Getting ready to launch a new training run
jakep-allenai committed Oct 2, 2024
1 parent 1686790 commit 0ddaf90
Showing 4 changed files with 44 additions and 37 deletions.
27 changes: 13 additions & 14 deletions pdelfin/train/config/qwen2vl-7b-lora.yaml
@@ -28,21 +28,20 @@ generate:
train_data:
  seed: 1337
  sources:
    - name: openai_batch_data_v2
      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl
      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json
      backend:
        - openai
      size: 100_000
    - name: openai_batch_data_v5_1_eval # TODO This is just for testing the job, once ready change to a real train dataset
      query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl
      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.jsonl

valid_data:
  sources:
    - name: openai_batch_data_eval_mini
      query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl
      response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json
      backend:
        - openai
      size: 100_000
    - name: openai_batch_data_v5_1_eval
      query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl
      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.jsonl
    - name: openai_batch_data_v5_1_iabooks_eval
      query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_eval/*.jsonl
      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.jsonl



# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
hparams:
@@ -52,10 +51,10 @@ hparams:
  gradient_checkpointing: false
  clip_grad_norm: 1.0
  learning_rate: 3e-4
  max_steps: 5000
  max_steps: 2000
  pad_multiple_of: 16
  log_every_steps: 50
  eval_every_steps: 500
  eval_every_steps: 100
  optim: adamw_torch
  lr_scheduler: cosine
  weight_decay: 0.01
2 changes: 0 additions & 2 deletions pdelfin/train/core/config.py
@@ -75,10 +75,8 @@ class AwsConfig:
@dataclass
class SourceConfig:
    name: str = field(help="The name of the source")
    size: int = field(help="Limit size for the source")
    query_glob_path: str = field(help="The s3 bucket pointing to the inputs sent to OpenAI to generate the silver data")
    response_glob_path: str = field(help="The s3 bucket pointing to the batch api response json's sent back from open ai")
    backend: List[str] = field(help="The data generation backend to use to train the model")


@dataclass
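For reference, with size and backend removed, a SourceConfig is now just a name plus the two S3 glob patterns. A minimal standalone sketch of the trimmed dataclass (hypothetical, using plain dataclasses metadata in place of the repo's field(help=...) helper):

from dataclasses import dataclass, field

@dataclass
class SourceConfig:
    # Each entry under train_data.sources / valid_data.sources in the YAML maps to one of these.
    name: str = field(metadata={"help": "The name of the source"})
    query_glob_path: str = field(metadata={"help": "S3 glob for the inputs sent to OpenAI to generate the silver data"})
    response_glob_path: str = field(metadata={"help": "S3 glob for the batch API responses returned by OpenAI"})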
50 changes: 30 additions & 20 deletions pdelfin/train/train.py
@@ -15,7 +15,7 @@
import accelerate
import torch
import torch.distributed
from datasets import DatasetDict
from datasets import DatasetDict, concatenate_datasets
from datasets.utils import disable_progress_bars
from datasets.utils.logging import set_verbosity
from peft import LoraConfig, get_peft_model # pyright: ignore
@@ -49,7 +49,7 @@
)


from pdelfin.train.dataloader import make_dataset
from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, filter_by_max_seq_len


Expand Down Expand Up @@ -113,13 +113,6 @@ def run_train(config: TrainConfig):

    setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)

    dataset = make_dataset(
        train_data_config=config.train_data,
        valid_data_config=config.valid_data,
        num_proc=config.num_proc,
        logger=logger,
    )

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        config.model.name_or_path, torch_dtype=torch.bfloat16,
        _attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
@@ -139,15 +132,33 @@ def run_train(config: TrainConfig):
    log_trainable_parameters(model=model, logger=logger)

    # Do final filtering, and prep for running model forward()
    filtered_dataset = DatasetDict(**{split: dataset[split].filter(partial(filter_by_max_seq_len, processor=processor)) for split in dataset})
    formatted_dataset = filtered_dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
    print(formatted_dataset)
    print("---------------")


    # Training sets get all concatenated and shuffled
    train_dataset = (
        concatenate_datasets(
            [
                build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
                for source in config.train_data.sources
            ]
        )
        .filter(partial(filter_by_max_seq_len, processor=processor))
        .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
    )

    # Validation sets get put into a datasetdict so each can report a loss separately
    valid_dataset = DatasetDict(
        **{
            source.name: build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
            .filter(partial(filter_by_max_seq_len, processor=processor))
            .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
            for source in config.valid_data.sources
        }
    )

    save_path = join_path("", config.save.path, run_name.run)

    save_config(config, join_path("", save_path, "config.yaml")) # pyright: ignore

    with TemporaryDirectory() as output_dir:

        training_args = TrainingArguments(
@@ -192,8 +203,8 @@ def run_train(config: TrainConfig):
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=formatted_dataset["train"],
            eval_dataset=formatted_dataset["validation"], # pyright: ignore
            train_dataset=train_dataset,
            eval_dataset=valid_dataset,
            tokenizer=processor.tokenizer,
            #Collator is not needed as we are doing batch size 1 for now...
            #data_collator=collator,
@@ -215,9 +226,8 @@ def run_train(config: TrainConfig):
logger.info("LoRA adapters merged successfully.")

model.save_pretrained(best_dir)

logger.info("Saved best model to %s", best_dir)

logger.info("Saved best model to %s", best_dir)

# Uncomment to test speed of data loader
# train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
@@ -232,4 +242,4 @@ def main():


if __name__ == "__main__":
    main()
    main()
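Condensed, the new data path in train.py amounts to the sketch below: all training sources are concatenated into one stream, while each validation source is kept as its own entry in a DatasetDict so the Trainer can report a separate loss per source. This is a sketch, not the verbatim code, and the per-source metric naming assumes a transformers Trainer recent enough to accept a dict of eval datasets.

from functools import partial

from datasets import DatasetDict, concatenate_datasets

from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, filter_by_max_seq_len


def build_datasets(config, processor):
    # Training sources: load, concatenate, drop over-long samples, then attach the Qwen2 prep transform.
    train_dataset = (
        concatenate_datasets(
            [
                build_batch_query_response_vision_dataset(s.query_glob_path, s.response_glob_path)
                for s in config.train_data.sources
            ]
        )
        .filter(partial(filter_by_max_seq_len, processor=processor))
        .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
    )

    # Validation sources: one DatasetDict entry per source, so eval metrics come back keyed by
    # source name, e.g. eval_openai_batch_data_v5_1_eval_loss (assumed Trainer behavior for dict eval sets).
    valid_dataset = DatasetDict(
        **{
            s.name: build_batch_query_response_vision_dataset(s.query_glob_path, s.response_glob_path)
            .filter(partial(filter_by_max_seq_len, processor=processor))
            .with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
            for s in config.valid_data.sources
        }
    )
    return train_dataset, valid_dataset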
2 changes: 1 addition & 1 deletion scripts/qwen2vl-7b-gantry.sh
@@ -19,7 +19,7 @@ run_name=$(basename "$0" .sh)
# --cluster 'ai2/allennlp-cirrascale' \
# --priority high \

CLUSTER='jupiter'
CLUSTER='pluto'

gantry run \
--description "${run_name}"\
