
Basic forward generation pass with OpenAI dataset and Qwen2-VL
jakep-allenai committed Sep 19, 2024
1 parent 7d2c447 commit 84e68f3
Showing 2 changed files with 45 additions and 5 deletions.
pdelfin/train/core/config.py (1 addition & 1 deletion)

@@ -58,7 +58,7 @@ class GenerateConfig:
 @dataclass
 class WandbConfig:
     entity: str = field(help="The wandb entity to use for logging", default="ai2-llm")
-    project: str = field(help="The wandb project to use for logging", default="refine")
+    project: str = field(help="The wandb project to use for logging", default="pdf-qwen2vl")
     wandb_api_key: Optional[str] = field(help="The wandb api key to use for logging", default=None)
     mode: str = field(help="The wandb mode to use for logging. Set it to `offline`", default="online")
     watch: str = field(help="The wandb watch to use for logging", default="false")
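Note that `field(help=..., default=...)` is not the stock `dataclasses.field` signature, so the config module presumably defines or imports a thin wrapper. A minimal sketch of what such a helper could look like, assuming the help text is stashed in the field metadata; this wrapper is an assumption for illustration, not code from this repository:

from dataclasses import dataclass, field as dc_field

def field(help: str = "", **kwargs):
    # Hypothetical wrapper: keep the help string in the dataclass field
    # metadata so a CLI or config layer can surface it later.
    return dc_field(metadata={"help": help}, **kwargs)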
pdelfin/train/train.py (44 additions & 4 deletions)

@@ -10,8 +10,10 @@
 # Step 5. Move over from interactive session to gantry launch script
 
 import os
+import base64
 import logging
-import os
+from io import BytesIO
+from PIL import Image
 from functools import partial
 from logging import Logger
 from pathlib import Path
@@ -53,12 +55,10 @@
 )
 
 
-from qwen_vl_utils import process_vision_info
-
 from pdelfin.train.dataloader import build_batch_query_response_vision_dataset
 
 
-def run_train():
+def run_train(config: TrainConfig):
     train_ds = build_batch_query_response_vision_dataset(
         query_glob_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl",
         response_glob_path="s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json",
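The generation loop added in the next hunk reads two keys from each dataset entry, so `build_batch_query_response_vision_dataset` presumably yields records shaped roughly like this (the values here are invented placeholders, not data from the repository):

entry = {
    "input_prompt_text": "...",                     # the text prompt for this page
    "input_prompt_image_base64": "iVBORw0KGgo...",  # base64-encoded page image (truncated placeholder)
}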
@@ -69,6 +69,46 @@ def run_train():
     )
     processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
 
+    for entry in train_ds:
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": entry["input_prompt_image_base64"]
+                    },
+                    {"type": "text", "text": entry["input_prompt_text"]},
+                ],
+            }
+        ]
+
+        # Preparation for inference
+        text = processor.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+
+        main_image = Image.open(BytesIO(base64.b64decode(entry["input_prompt_image_base64"])))
+
+        inputs = processor(
+            text=[text],
+            images=[main_image],
+            #videos=video_inputs,
+            padding=True,
+            return_tensors="pt",
+        )
+        #inputs = inputs.to("cuda")
+
+        # Inference: Generation of the output
+        generated_ids = model.generate(**inputs, max_new_tokens=128)
+        generated_ids_trimmed = [
+            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        output_text = processor.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+        print(output_text)
+
 
 
 def main():
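As committed, the loop calls `model.generate(...)` but `model` is never constructed anywhere in this diff, so it must be defined elsewhere in the script or in a later commit. A minimal sketch of the loading step the forward pass appears to assume, following the standard Transformers recipe for Qwen2-VL; the dtype and device placement below are assumptions, not choices taken from this commit:

import torch
from transformers import Qwen2VLForConditionalGeneration

# Load the same checkpoint the processor uses; bfloat16 and device_map
# are assumptions made here to fit the 7B model on a single GPU.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.eval()

With `inputs.to("cuda")` commented out, the processed tensors stay on CPU; if the model is placed on a GPU, the inputs would still need to be moved to its device (e.g. `inputs = inputs.to(model.device)`) before calling `generate`. The `generated_ids_trimmed` slice then drops each prompt's tokens from the output sequence, so `batch_decode` returns only the newly generated text.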
