Proper selection of LORA weights
jakep-allenai committed Oct 30, 2024
1 parent bcb4794 commit 43aa4f2
Showing 2 changed files with 18 additions and 14 deletions.
29 changes: 16 additions & 13 deletions pdelfin/train/config/molmo-o-lora.yaml
@@ -67,19 +67,22 @@ lora:
   dropout: 0.05
   task_type: causal_lm
   target_modules:
-    - model.transformer.blocks.*.att_proj
-    - model.transformer.blocks.*.ff_proj
-    - model.transformer.blocks.*.attn_out
-    - model.transformer.blocks.*.ff_out
-    - model.vision_backbone.image_vit.transformer.resblocks.*.attention.wq
-    - model.vision_backbone.image_vit.transformer.resblocks.*.attention.wk
-    - model.vision_backbone.image_vit.transformer.resblocks.*.attention.wv
-    - model.vision_backbone.image_vit.transformer.resblocks.*.attention.wo
-    - model.vision_backbone.image_vit.transformer.resblocks.*.feed_forward.w1
-    - model.vision_backbone.image_vit.transformer.resblocks.*.feed_forward.w2
-    - model.vision_backbone.image_projector.w1
-    - model.vision_backbone.image_projector.w2
-    - model.vision_backbone.image_projector.w3
+    # attention layers in main transformer
+    - att_proj
+    - ff_proj
+    - attn_out
+    - ff_out
+    # vision transformer attention and FF
+    - attention.wq
+    - attention.wk
+    - attention.wv
+    - attention.wo
+    - feed_forward.w1
+    - feed_forward.w2
+    # vision image projector
+    - vision_backbone.image_projector.w1
+    - vision_backbone.image_projector.w2
+    - vision_backbone.image_projector.w3
 
 save:
   path: s3://ai2-oe-data/jakep/experiments/molmo-o-0924/v1/models/
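Why the shorter names work: PEFT's list-form target_modules is matched by suffix against each fully qualified module name, so a bare att_proj selects every model.transformer.blocks.N.att_proj layer, whereas a literal wildcard path like model.transformer.blocks.*.att_proj is not expanded in list form and likely matched nothing. Below is a minimal sketch of that rule, assuming the equality-or-dot-suffix check used by recent PEFT releases (see peft.tuners.tuners_utils.check_target_module_exists for the real logic); the module names are illustrative:

```python
# Sketch of PEFT's list-form target_modules rule (assumption: recent PEFT
# behavior): an entry matches a module whose fully qualified name equals the
# entry or ends with "." + entry. Wildcards are NOT expanded in list form.

def matches_target(module_name: str, target_modules: list[str]) -> bool:
    return any(
        module_name == t or module_name.endswith("." + t)
        for t in target_modules
    )

old = ["model.transformer.blocks.*.att_proj"]          # wildcard path, pre-commit
new = ["att_proj", "attention.wq", "feed_forward.w1"]  # suffixes, post-commit

name = "model.transformer.blocks.7.att_proj"
print(matches_target(name, old))  # False -- "*" is treated literally
print(matches_target(name, new))  # True  -- plain suffix match on att_proj
```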
3 changes: 2 additions & 1 deletion pdelfin/train/utils.py
@@ -168,9 +168,10 @@ def log_trainable_parameters(model: torch.nn.Module, logger: Optional[Logger] =
     """
     trainable_params = 0
     all_param = 0
-    for _, param in model.named_parameters():
+    for name, param in model.named_parameters():
         all_param += param.numel()
         if param.requires_grad:
+            (logger or get_logger(__name__)).info(f"training with {name}")
             trainable_params += param.numel()
 
     (logger or get_logger(__name__)).info(
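For context, a self-contained sketch of the updated helper: it logs the name of every trainable parameter, which makes target_modules mistakes visible directly in the training logs. The imports, the get_logger stand-in, and the final summary message are assumptions (the function is truncated in the diff above); only the loop body reflects the commit:

```python
import logging
from typing import Optional

import torch


def get_logger(name: str) -> logging.Logger:
    # Assumed stand-in for pdelfin's get_logger helper.
    return logging.getLogger(name)


def log_trainable_parameters(model: torch.nn.Module, logger: Optional[logging.Logger] = None) -> None:
    """Log each trainable parameter name plus an overall count."""
    log = logger or get_logger(__name__)
    trainable_params = 0
    all_param = 0
    for name, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            # New in this commit: surface exactly which weights LoRA left trainable.
            log.info(f"training with {name}")
            trainable_params += param.numel()

    # Summary message is an assumption; the original call is truncated in the diff.
    log.info(
        f"trainable params: {trainable_params} || all params: {all_param} "
        f"|| trainable%: {100 * trainable_params / all_param:.2f}"
    )
```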
