Basic LORA trainer, doesn't seem to make any speed difference
jakep-allenai committed Sep 23, 2024
1 parent 3ed14a9 commit ab9458b
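
A note on the commit message, not part of the commit itself: LoRA freezes the base weights and trains only small adapter matrices, so it mainly cuts optimizer state and gradient memory; the forward and backward passes still run through the full 2B-parameter base model, which is why per-step wall-clock time often barely changes. A generic PyTorch sketch of the quantity LoRA does change, the trainable-parameter count (this is an illustration, not the log_trainable_parameters helper used in the diff below):

# Generic sketch: LoRA mostly shrinks the *trainable* parameter count,
# not the amount of compute per training step.
def count_parameters(model) -> None:
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable {trainable:,} / total {total:,} ({100 * trainable / total:.2f}%)")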
Showing 2 changed files with 48 additions and 22 deletions.
12 changes: 6 additions & 6 deletions pdelfin/train/config/qwen2vl-2b.yaml
@@ -3,8 +3,8 @@ model:
  arch: causal

wandb:
-  project: refine
-  entity: pdf-qwen2vl
+  project: pdelfin
+  entity: ai2-llm

# TODO This is not used
format:
@@ -93,10 +93,10 @@ hparams:
  gradient_checkpointing: true
  clip_grad_norm: 1.0
  learning_rate: 3e-4
-  max_steps: 10000
+  max_steps: 200
  pad_multiple_of: 16
  log_every_steps: 5
-  eval_every_steps: 250
+  eval_every_steps: 100
  optim: adamw_torch
  lr_scheduler: cosine
  weight_decay: 0.01
@@ -118,7 +118,7 @@ lora:
    - down_proj

save:
-  path: s3://ai2-tylerm-experimental/experiments/rephrase/v1/models/lucas
-  save_every_steps: 500
+  path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/
+  save_every_steps: 100

max_workers: 1
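
For orientation only: a rough sketch of how the hparams: and save: values above could be fed into transformers.TrainingArguments. The field names on the YAML side come from the config; the exact mapping used inside pdelfin is an assumption and is not shown in this commit.

# Illustrative mapping of the config values above onto TrainingArguments;
# the real pdelfin wiring may differ.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",            # placeholder; checkpoints are uploaded to S3 separately
    max_steps=200,
    learning_rate=3e-4,
    lr_scheduler_type="cosine",
    optim="adamw_torch",
    weight_decay=0.01,
    max_grad_norm=1.0,           # clip_grad_norm
    gradient_checkpointing=True,
    logging_steps=5,             # log_every_steps
    eval_steps=100,              # eval_every_steps
    save_steps=100,              # save_every_steps
)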
58 changes: 42 additions & 16 deletions pdelfin/train/train.py
@@ -123,6 +123,8 @@ def run_train(config: TrainConfig):

    accelerator = accelerate.Accelerator()

+    setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)
+
    train_ds = build_batch_query_response_vision_dataset(
        query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
        response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
@@ -133,10 +135,25 @@
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

+    if config.lora is not None:
+        peft_config = LoraConfig(
+            r=config.lora.rank,
+            lora_alpha=config.lora.alpha,
+            lora_dropout=config.lora.dropout,
+            bias=config.lora.bias, # pyright: ignore
+            task_type=config.lora.task_type,
+            target_modules=list(config.lora.target_modules),
+        )
+        model = get_peft_model(model=model, peft_config=peft_config)
+        log_trainable_parameters(model=model, logger=logger)
+
    train_ds = train_ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
    print(train_ds)
    print("---------------")

    save_path = join_path("", config.save.path, run_name.run)

+    save_config(config, join_path("", save_path, "config.yaml")) # pyright: ignore
+
    with TemporaryDirectory() as output_dir:

@@ -177,22 +194,31 @@ def run_train(config: TrainConfig):

        # Set the collator
        collator = partial(packing_collator, pad_multiple_of=config.hparams.pad_multiple_of, do_shrink=False)
-        #checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)
-
-        # # Initialize Trainer
-        # trainer = Trainer(
-        # model=model,
-        # args=training_args,
-        # train_dataset=train_ds,
-        #eval_dataset=formatted_dataset["validation"], # pyright: ignore
-        # tokenizer=processor.tokenizer,
-        #data_collator=collator,
-        #callbacks=[checkpoint_callback],
-        # )
-
-
-        # # Train the model
-        # trainer.train() # pyright: ignore
+        checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)
+
+        # Initialize Trainer
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            train_dataset=train_ds,
+            #eval_dataset=formatted_dataset["validation"], # pyright: ignore
+            tokenizer=processor.tokenizer,
+            #data_collator=collator,
+            #callbacks=[checkpoint_callback],
+        )
+
+        # Could not get this to work
+        # if get_rank() == 0:
+        # # this is a hack to add script and peft config to wandb config
+        # update_wandb_config(config, trainer, model)
+
+        # Train the model
+        trainer.train() # pyright: ignore
+
+        with get_local_dir(join_path("", save_path, "best")) as best_dir:
+            model.save_pretrained(best_dir)
+            tokenizer.tokenizer.save_pretrained(best_dir)
+            logger.info("Saved best model to %s", best_dir)

        # Uncomment to test speed of data loader
        # train_dataloader = DataLoader(train_ds, batch_size=1, num_workers=2, shuffle=False)
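
The commented-out tail of the diff points at a data-loader speed check. A standalone sketch along those lines (the time_dataloader helper and the 50-batch budget are illustrative, not from the commit):

# Illustrative only: time raw dataloader throughput to see whether data loading,
# rather than the model step, dominates wall-clock time.
import time
from torch.utils.data import DataLoader

def time_dataloader(dataset, num_batches: int = 50) -> None:
    loader = DataLoader(dataset, batch_size=1, num_workers=2, shuffle=False)
    start = time.perf_counter()
    for i, _batch in enumerate(loader):
        if i + 1 >= num_batches:
            break
    elapsed = time.perf_counter() - start
    print(f"{num_batches / elapsed:.2f} batches/sec over {num_batches} batches")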
