Skip to content

Commit

Permalink
More fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Sep 26, 2024
1 parent d098a87 commit c00e40d
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 12 deletions.
2 changes: 1 addition & 1 deletion pdelfin/silver_data/buildsilver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def _build_prompt(base_text: str) -> str:
return (
f"Below is the image of one page of a PDF document, as well as some raw textual content that was previously extracted for it. "
f"Just return the plain text representation of this document as if you were reading it naturally.\n"
f"Turn equations into a LaTeX representation. Remove the headers and footers, but keep references and footnotes.\n"
f"Turn equations into a LaTeX representation, and tables into markdown format. Remove the headers and footers, but keep references and footnotes.\n"
f"Read any natural handwriting.\n"
f"This is likely one page out of several in the document, so be sure to preserve any sentences that come from the previous page, or continue onto the next page, exactly as they are.\n"
f"If there is no text at all that you think you should read, just output [NO TEXT].\n"
Expand Down
25 changes: 17 additions & 8 deletions pdelfin/train/dataprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def filter_by_max_seq_len(example, max_seq_len=4500):
return sizes[-1] <= max_seq_len


def prepare_data_for_qwen2_training(example, processor):
def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
# Prepare messages
messages = [
{
Expand Down Expand Up @@ -71,13 +71,22 @@ def prepare_data_for_qwen2_training(example, processor):
labels_full[len(inputs.input_ids[0]):] = labels.input_ids[0]

# Return as dict, including pixel_values
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels_full,
"pixel_values": inputs.pixel_values,
"image_grid_thw": inputs["image_grid_thw"][0]
}
if add_batch_dim:
return {
"input_ids": input_ids[np.newaxis, ...],
"attention_mask": attention_mask[np.newaxis, ...],
"labels": labels_full[np.newaxis, ...],
"pixel_values": inputs.pixel_values[np.newaxis, ...],
"image_grid_thw": inputs["image_grid_thw"]
}
else:
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels_full,
"pixel_values": inputs.pixel_values,
"image_grid_thw": inputs["image_grid_thw"][0]
}


def batch_prepare_data_for_qwen2_training(batch, processor):
Expand Down
4 changes: 2 additions & 2 deletions pdelfin/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ def run_train(config: TrainConfig):
train_ds = dataset["train"].to_iterable_dataset(num_shards=64)
validation_ds = dataset["validation"]

train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor), remove_columns=train_ds.column_names).filter(filter_by_max_seq_len)
validation_ds = validation_ds.map(partial(prepare_data_for_qwen2_training, processor=processor), remove_columns=validation_ds.column_names)
train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor, add_batch_dim=True), remove_columns=train_ds.column_names).filter(filter_by_max_seq_len)
validation_ds = validation_ds.map(partial(prepare_data_for_qwen2_training, processor=processor, add_batch_dim=True)), remove_columns=validation_ds.column_names)

print(train_ds)
print(validation_ds)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def testIterableDataset(self):
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

formatted_dataset = dataset.to_iterable_dataset(num_shards=64)
formatted_dataset = formatted_dataset.map(partial(prepare_data_for_qwen2_training, processor=processor), remove_columns=formatted_dataset.column_names).filter(lambda x: x["input_ids"].shape[0] < 4500)
formatted_dataset = formatted_dataset.map(partial(prepare_data_for_qwen2_training, processor=processor, add_batch_dim=True), remove_columns=formatted_dataset.column_names).filter(lambda x: x["input_ids"].shape[0] < 4500)

for entry in formatted_dataset:
print(entry)
Expand Down

0 comments on commit c00e40d

Please sign in to comment.