Commit

map and filter on iterable dataset
jakep-allenai committed Sep 26, 2024
1 parent f14e910 commit 05fdb81
Showing 2 changed files with 7 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pdelfin/train/config/qwen2vl-7b-lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ hparams:
pad_multiple_of: 16
log_every_steps: 50
eval_every_steps: 500
optim: adamw_bnb_8bit
optim: adamw_torch
lr_scheduler: cosine
weight_decay: 0.01
warmup_ratio: 0.03
Expand Down
7 changes: 6 additions & 1 deletion pdelfin/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,12 @@ def run_train(config: TrainConfig):
model = get_peft_model(model=model, peft_config=peft_config)
log_trainable_parameters(model=model, logger=logger)

formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
# formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))

# Convert to an iterable dataset, so we can apply map and filter lazily without computing over the full dataset in advance
formatted_dataset = dataset.to_iterable_dataset(num_shards=64)
formatted_dataset = formatted_dataset.map(partial(batch_prepare_data_for_qwen2_training, processor=processor)).filter(lambda x: x["input_ids"].shape[1] < 4500)

print(formatted_dataset)
print("---------------")

Expand Down
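The change above swaps an eager `with_transform` for a lazy pipeline: `to_iterable_dataset` turns the dataset into a streaming iterable, so the `map` (tokenization/preparation) and the `filter` (dropping examples at or above 4500 tokens) are applied per example as the trainer consumes data, not in a full preprocessing pass. A minimal sketch of that lazy map-then-filter pattern, using plain generators as stand-ins for the `datasets` API (the `prepare` function and `input_ids_len` field here are hypothetical simplifications of `batch_prepare_data_for_qwen2_training` and its tokenized output):

```python
# Sketch of the lazy map -> filter pipeline the commit adopts.
# Items are produced, transformed, and filtered one at a time,
# so no upfront pass over the whole dataset is needed.

def to_iterable(examples):
    # Stand-in for dataset.to_iterable_dataset(): yields examples lazily.
    yield from examples

def prepare(example):
    # Hypothetical stand-in for batch_prepare_data_for_qwen2_training:
    # here we just count whitespace tokens instead of real tokenization.
    return {"input_ids_len": len(example["text"].split())}

examples = [
    {"text": "a short example"},
    {"text": " ".join(["tok"] * 5000)},  # too long, should be filtered out
]

pipeline = (prepare(x) for x in to_iterable(examples))          # lazy map
pipeline = (x for x in pipeline if x["input_ids_len"] < 4500)   # lazy filter

kept = list(pipeline)  # only the short example survives
```

Note the trade-off: an iterable dataset gives up random access and a known length, which is why the commit also shards it (`num_shards=64`) so the data loader can still parallelize across workers.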
