Skip to content

Commit

Permalink
Removing lambda due to pickling errors
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Sep 26, 2024
1 parent 61dd7bb commit 84e9da6
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
6 changes: 6 additions & 0 deletions pdelfin/train/dataprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
import base64
import torch # Make sure to import torch as it's used in the DataCollator


def filter_by_max_seq_len(example, max_seq_len=4500):
    """Return True when the example's tokenized sequence fits within *max_seq_len*.

    Defined at module level (rather than inline as a lambda) so it can be
    pickled when the dataset filter runs in worker processes.

    Args:
        example: Mapping with an ``"input_ids"`` tensor; the last dimension
            of its shape is taken as the sequence length.
        max_seq_len: Maximum number of tokens to keep (inclusive).
    """
    return example["input_ids"].shape[-1] <= max_seq_len


def prepare_data_for_qwen2_training(example, processor):
# Prepare messages
messages = [
Expand Down
4 changes: 2 additions & 2 deletions pdelfin/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@


from pdelfin.train.dataloader import make_dataset
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training, prepare_data_for_qwen2_training
from pdelfin.train.dataprep import filter_by_max_seq_len, prepare_data_for_qwen2_training


class CheckpointUploadCallback(TrainerCallback):
Expand Down Expand Up @@ -143,7 +143,7 @@ def run_train(config: TrainConfig):
train_ds = dataset["train"].to_iterable_dataset(num_shards=64)
validation_ds = dataset["validation"]

train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor)).filter(lambda x: x["input_ids"].shape[0] < 4500)
train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor)).filter(filter_by_max_seq_len)
validation_ds = validation_ds.map(partial(prepare_data_for_qwen2_training, processor=processor))

print(train_ds)
Expand Down

0 comments on commit 84e9da6

Please sign in to comment.