Starting on molmo changes
jakep-allenai committed Oct 30, 2024
1 parent 232c445 commit bcb4794
Showing 2 changed files with 101 additions and 5 deletions.
88 changes: 88 additions & 0 deletions pdelfin/train/config/molmo-o-lora.yaml
@@ -0,0 +1,88 @@
model:
  name_or_path: allenai/Molmo-7B-O-0924
  arch: causal
  use_flash_attn: true

wandb:
  project: pdelfin
  entity: ai2-llm

generate:
  max_length: 8192

train_data:
  seed: 1337
  cache_location: /data/jakep/pdfdata/pdelfin_cache
  sources:
    - name: openai_batch_data_v5_1_eval
      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
      target_longest_image_dim: [1024]
      target_anchor_text_len: [6000]
    # - name: openai_batch_data_v5_1_train
    #   response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
    #   target_longest_image_dim: 1024
    #   target_anchor_text_len: 6000
    # - name: openai_batch_data_v5_1_iabooks_train
    #   response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
    #   target_longest_image_dim: 1024
    #   target_anchor_text_len: 6000

valid_data:
  cache_location: /data/jakep/pdfdata/pdelfin_cache
  metric_for_best_model: openai_batch_data_v5_1_eval_loss
  sources:
    # These tend to be small, so loading them from s3 is no big deal
    # - name: openai_batch_data_v5_1_eval
    #   response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
    #   target_longest_image_dim: 1024
    #   target_anchor_text_len: 6000
    - name: openai_batch_data_v5_1_iabooks_eval
      response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
      target_longest_image_dim: [1024]
      target_anchor_text_len: [6000]


# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
hparams:
  batch_size: 1
  eval_batch_size: 1
  gradient_accumulation_steps: 4
  gradient_checkpointing: true
  clip_grad_norm: 1.0
  learning_rate: 1e-4
  max_steps: 10000
  pad_multiple_of: 16
  log_every_steps: 10
  eval_every_steps: 100
  optim: adamw_torch
  lr_scheduler: cosine
  weight_decay: 0.01
  warmup_ratio: 0.03

# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
lora:
  rank: 32
  alpha: 32
  dropout: 0.05
  task_type: causal_lm
  target_modules:
    - model.transformer.blocks.*.att_proj
    - model.transformer.blocks.*.ff_proj
    - model.transformer.blocks.*.attn_out
    - model.transformer.blocks.*.ff_out
    - model.vision_backbone.image_vit.transformer.resblocks.*.attention.wq
    - model.vision_backbone.image_vit.transformer.resblocks.*.attention.wk
    - model.vision_backbone.image_vit.transformer.resblocks.*.attention.wv
    - model.vision_backbone.image_vit.transformer.resblocks.*.attention.wo
    - model.vision_backbone.image_vit.transformer.resblocks.*.feed_forward.w1
    - model.vision_backbone.image_vit.transformer.resblocks.*.feed_forward.w2
    - model.vision_backbone.image_projector.w1
    - model.vision_backbone.image_projector.w2
    - model.vision_backbone.image_projector.w3

save:
  path: s3://ai2-oe-data/jakep/experiments/molmo-o-0924/v1/models/
  save_every_steps: 1000

max_workers: 10
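
The lora.target_modules entries in this config use "*" wildcards over fully qualified module names, covering both the language-model blocks and the vision backbone. The sketch below is not part of the commit: it shows one way such patterns could be expanded into concrete module names before being handed to peft.LoraConfig, assuming fnmatch-style matching over model.named_modules(); expand_target_modules is a hypothetical helper, not pdelfin code.

# Illustrative sketch only (not from this commit). Assumes the wildcard
# patterns in lora.target_modules are fnmatch-style globs over the model's
# fully qualified module names; expand_target_modules is a hypothetical helper.
import fnmatch

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM


def expand_target_modules(model: nn.Module, patterns: list[str]) -> list[str]:
    """Return every named submodule that matches one of the glob patterns."""
    names = [name for name, _ in model.named_modules()]
    return sorted({n for n in names for p in patterns if fnmatch.fnmatch(n, p)})


model = AutoModelForCausalLM.from_pretrained(
    "allenai/Molmo-7B-O-0924", torch_dtype=torch.bfloat16, trust_remote_code=True
)
patterns = [
    "model.transformer.blocks.*.att_proj",
    "model.transformer.blocks.*.ff_proj",
    "model.transformer.blocks.*.attn_out",
    "model.transformer.blocks.*.ff_out",
]
peft_config = LoraConfig(
    r=32, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM",
    target_modules=expand_target_modules(model, patterns),
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # sanity-check how many parameters LoRA trains
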
18 changes: 13 additions & 5 deletions pdelfin/train/train.py
@@ -110,13 +110,21 @@ def run_train(config: TrainConfig):
 
     setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)
 
-    processor = AutoProcessor.from_pretrained(config.model.name_or_path)
+    processor = AutoProcessor.from_pretrained(config.model.name_or_path, trust_remote_code=True)
     train_dataset, valid_dataset = make_dataset(config, processor)
 
-    model = Qwen2VLForConditionalGeneration.from_pretrained(
-        config.model.name_or_path, torch_dtype=torch.bfloat16,
-        _attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
-    )
+    if "qwen" in config.model.name_or_path.lower():
+        model = Qwen2VLForConditionalGeneration.from_pretrained(
+            config.model.name_or_path, torch_dtype=torch.bfloat16,
+            _attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
+        )
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            config.model.name_or_path, torch_dtype=torch.bfloat16,
+            trust_remote_code=True
+        )
 
     print(model)
 
     if config.lora is not None:
         peft_config = LoraConfig(
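
For context on the new else branch: Molmo ships its modeling and preprocessing code with the checkpoint, which is why both AutoProcessor and AutoModelForCausalLM are now loaded with trust_remote_code=True. The smoke-test sketch below follows the usage example on the Molmo model card rather than pdelfin; processor.process, generate_from_batch, and the page.png input are assumptions, not code from this repository.

# Rough smoke test, based on the Molmo model card's remote-code API
# (processor.process / model.generate_from_batch); these calls and the
# page.png input are assumptions, not part of this commit.
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig

name = "allenai/Molmo-7B-O-0924"
processor = AutoProcessor.from_pretrained(name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    name, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto"
)

# Feed one rendered PDF page through the model.
inputs = processor.process(
    images=[Image.open("page.png")],
    text="Transcribe the text on this page.",
)
inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}

with torch.inference_mode():
    output = model.generate_from_batch(
        inputs,
        GenerationConfig(max_new_tokens=512, stop_strings="<|endoftext|>"),
        tokenizer=processor.tokenizer,
    )

new_tokens = output[0, inputs["input_ids"].size(1):]
print(processor.tokenizer.decode(new_tokens, skip_special_tokens=True))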
