Skip to content

Commit

Permalink
Tries to run a forward pass but OOMs
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Sep 20, 2024
1 parent 4eddb1b commit 55035b0
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
24 changes: 23 additions & 1 deletion pdelfin/train/dataprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,29 @@ def prepare_data_for_qwen2_training(example, processor):
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels_full,
"pixel_values": inputs.pixel_values[0]
"pixel_values": inputs.pixel_values,
"image_grid_thw": inputs["image_grid_thw"][0]
}


def batch_prepare_data_for_qwen2_training(batch, processor):
    """Run ``prepare_data_for_qwen2_training`` over every row of a columnar batch.

    Args:
        batch: Mapping from column name to a sequence of per-example values.
            The columns "input_prompt_image_base64", "input_prompt_text" and
            "response" are read; the number of rows is taken from
            "input_prompt_image_base64".
        processor: Forwarded unchanged to ``prepare_data_for_qwen2_training``.

    Returns:
        A dict with the keys "input_ids", "attention_mask", "labels",
        "pixel_values" and "image_grid_thw", each mapping to a list holding
        the helper's value for that key, one entry per row in input order.
    """
    # Rebuild each row as the single-example dict the per-example helper takes.
    prepared_rows = [
        prepare_data_for_qwen2_training(
            {
                "input_prompt_image_base64": image_b64,
                "input_prompt_text": batch["input_prompt_text"][row],
                "response": batch["response"][row],
            },
            processor,
        )
        for row, image_b64 in enumerate(batch["input_prompt_image_base64"])
    ]

    # Transpose the list of per-row dicts back into a dict of columns.
    output_columns = (
        "input_ids",
        "attention_mask",
        "labels",
        "pixel_values",
        "image_grid_thw",
    )
    return {
        column: [row_out[column] for row_out in prepared_rows]
        for column in output_columns
    }


Expand Down
15 changes: 10 additions & 5 deletions pdelfin/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from transformers.integrations import WandbCallback
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.trainer_utils import get_last_checkpoint
from torch.utils.data import DataLoader

import wandb

Expand All @@ -58,7 +59,7 @@


from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
from pdelfin.train.dataprep import prepare_data_for_qwen2_training
from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training


def run_train(config: TrainConfig):
Expand All @@ -72,11 +73,15 @@ def run_train(config: TrainConfig):
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor),
remove_columns=train_ds.column_names)

train_ds = train_ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
print(train_ds)


dataloader = DataLoader(train_ds, batch_size=1, shuffle=False)

for batch in dataloader:
print(batch)

result = model.forward(**batch)



Expand Down

0 comments on commit 55035b0

Please sign in to comment.