Skip to content

Commit

Permalink
Adding test to make sure the traning and inference time tokenization …
Browse files Browse the repository at this point in the history
…stays identical, currenlty failing
  • Loading branch information
jakep-allenai committed Sep 20, 2024
1 parent fcb67eb commit a47afe5
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 5 deletions.
6 changes: 3 additions & 3 deletions pdelfin/train/dataprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ def prepare_data_for_qwen2_training(example, processor):

# Return as dict, including pixel_values
return {
"input_ids": input_ids.tolist(),
"attention_mask": attention_mask.tolist(),
"labels": labels_full.tolist(),
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels_full,
"pixel_values": inputs.pixel_values[0]
}

Expand Down
4 changes: 2 additions & 2 deletions pdelfin/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def run_train(config: TrainConfig):
)

model = Qwen2VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
"Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

train_ds = train_ds.map(partial(prepare_data_for_qwen2_training, processor=processor),
remove_columns=train_ds.column_names)
Expand Down
65 changes: 65 additions & 0 deletions tests/test_dataprep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import unittest
import base64
from io import BytesIO
from PIL import Image
from transformers import AutoProcessor

from pdelfin.train.dataloader import (
build_batch_query_response_vision_dataset,
)

from pdelfin.train.dataprep import (
prepare_data_for_qwen2_training
)


class TestDataprep(unittest.TestCase):
def testTokenizationMatches(self):
ds = build_batch_query_response_vision_dataset(
query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl",
response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2_mini/*.json",
)

example = ds[0]

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

full_messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": example["input_prompt_image_base64"] # Placeholder
},
{"type": "text", "text": example["input_prompt_text"]},
],
},

{
"role": "assistant",
"content": example["response"]
}
]

text = processor.apply_chat_template(full_messages, tokenize=False, add_generation_prompt=True)

# Decode image from base64
main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))

# Process inputs using processor
inference_inputs = processor(
text=[text],
images=[main_image],
padding=True,
return_tensors="np",
)

print(inference_inputs)
print(inference_inputs["input_ids"].shape)

training_inputs = prepare_data_for_qwen2_training(example, processor=processor)

print(training_inputs)
print(training_inputs["input_ids"].shape)

0 comments on commit a47afe5

Please sign in to comment.