Skip to content

Commit

Permalink
Prepping data to be in a trainable format
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Sep 20, 2024
1 parent dc86a99 commit fcb67eb
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 41 deletions.
82 changes: 82 additions & 0 deletions pdelfin/train/dataprep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import numpy as np
from io import BytesIO
from PIL import Image
import base64


def prepare_data_for_qwen2_training(example, processor):
# Prepare messages
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": example["input_prompt_image_base64"] # Placeholder
},
{"type": "text", "text": example["input_prompt_text"]},
],
}
]
# Apply chat template to get the text
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

# Decode image from base64
main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))

# Process inputs using processor
inputs = processor(
text=[text],
images=[main_image],
padding=True,
return_tensors="np",
)

# Get labels by tokenizing the output text
labels = processor(
text=[example["response"]],
padding=True,
return_tensors="np"
)

# Concatenate input_ids and labels
input_ids = np.concatenate([inputs.input_ids[0], labels.input_ids[0]], axis=0)
attention_mask = np.concatenate([inputs.attention_mask[0], labels.attention_mask[0]], axis=0)

# Create labels, masking the input portion with -100
labels_full = np.full_like(input_ids, fill_value=-100)
labels_full[len(inputs.input_ids[0]):] = labels.input_ids[0]

# Return as dict, including pixel_values
return {
"input_ids": input_ids.tolist(),
"attention_mask": attention_mask.tolist(),
"labels": labels_full.tolist(),
"pixel_values": inputs.pixel_values[0]
}


# Define a custom data collator
class DataCollatorForVisionLanguageModeling:
def __init__(self, processor):
self.processor = processor

def __call__(self, features):
input_ids = [f['input_ids'] for f in features]
attention_mask = [f['attention_mask'] for f in features]
labels = [f['labels'] for f in features]
pixel_values = [f['pixel_values'] for f in features]

# Pad input_ids, attention_mask, labels
batch = self.processor.pad(
{"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels},
return_tensors="pt",
padding=True,
)

# Stack pixel_values
batch['pixel_values'] = torch.stack([torch.tensor(pv) for pv in pixel_values])

return batch
51 changes: 10 additions & 41 deletions pdelfin/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# Step 5. Move over from interactive session to gantry launch script

import os
import json
import base64
import logging
from io import BytesIO
Expand All @@ -19,6 +20,7 @@
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional
from tqdm import tqdm

import accelerate
import torch
Expand Down Expand Up @@ -56,58 +58,25 @@


from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
from pdelfin.train.dataprep import prepare_data_for_qwen2_training


def run_train(config: TrainConfig):
train_ds = build_batch_query_response_vision_dataset(
query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl",
response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json",
query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
)

model = Qwen2VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
"Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

for entry in train_ds:
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": entry["input_prompt_image_base64"]
},
{"type": "text", "text": entry["input_prompt_text"]},
],
}
]

# Preparation for inference
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

main_image = Image.open(BytesIO(base64.b64decode(entry["input_prompt_image_base64"])))
train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor),
remove_columns=train_ds.column_names)

print(train_ds)

inputs = processor(
text=[text],
images=[main_image],
#videos=video_inputs,
padding=True,
return_tensors="pt",
)
#inputs = inputs.to("cuda")

# Inference: Generation of the output
generated_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)



Expand Down

0 comments on commit fcb67eb

Please sign in to comment.