Skip to content

Commit

Permalink
Proper use of iterable_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Sep 26, 2024
1 parent 05fdb81 commit cf1aa01
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
10 changes: 7 additions & 3 deletions pdelfin/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,14 @@ def run_train(config: TrainConfig):
# formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))

# Convert to an iterable dataset, so we can apply map and filter without doing a full calculation in advance
formatted_dataset = dataset.to_iterable_dataset(num_shards=64)
formatted_dataset = formatted_dataset.map(partial(batch_prepare_data_for_qwen2_training, processor=processor)).filter(lambda x: x["input_ids"].shape[1] < 4500)
train_ds = dataset["train"].to_iterable_dataset(num_shards=64)
validation_ds = dataset["validation"]

print(formatted_dataset)
train_ds = train_ds.map(partial(batch_prepare_data_for_qwen2_training, processor=processor)).filter(lambda x: x["input_ids"].shape[1] < 4500)
validation_ds = validation_ds.map(partial(batch_prepare_data_for_qwen2_training, processor=processor))

print(train_ds)
print(validation_ds)
print("---------------")

save_path = join_path("", config.save.path, run_name.run)
Expand Down
10 changes: 10 additions & 0 deletions tests/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,13 @@ def testExtractResponse(self):

print(response_data)
print(response_data[0])

def testIterableDataset(self):
dataset = build_batch_query_response_vision_dataset(
query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl",
response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json",
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

formatted_dataset = dataset.to_iterable_dataset(num_shards=64)
formatted_dataset = formatted_dataset.map(partial(batch_prepare_data_for_qwen2_training, processor=processor)).filter(lambda x: x["input_ids"].shape[1] < 4500)

0 comments on commit cf1aa01

Please sign in to comment.