Commit 28bcf72

Hoping to get a quick batch inference pipeline rolling
jakep-allenai committed Sep 24, 2024
1 parent 45f691c commit 28bcf72
Showing 2 changed files with 89 additions and 29 deletions.
49 changes: 20 additions & 29 deletions pdelfin/train/batch_inference.py
@@ -48,53 +48,44 @@
)


-from pdelfin.train.dataloader import make_dataset
-from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training
+from pdelfin.train.dataloader import load_jsonl_from_s3, extract_openai_batch_query
+from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_inference


-def run_train(model_name: str, dataset_path: str):
-    if get_rank() == 0:
-        logger_level = logging.INFO
-    else:
-        logger_level = logging.WARN
-        disable_progress_bars()
-
-    logger = get_logger(__name__, level=logger_level)
-    set_verbosity(logger_level)
-
-    dataset = make_dataset(
-        train_data_config=config.train_data,
-        valid_data_config=config.valid_data,
-        num_proc=config.num_proc,
-        logger=logger,
-    )
+def run_inference(model_name: str, query_dataset_path: str):
+    logger = get_logger(__name__, level=logging.INFO)
+    set_verbosity(logging.INFO)


     model = Qwen2VLForConditionalGeneration.from_pretrained(
         model_name, torch_dtype=torch.bfloat16, device_map="auto",
-        _attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
+        _attn_implementation="flash_attention_2",
     )
     processor = AutoProcessor.from_pretrained(model_name)

+    query_data = load_jsonl_from_s3(query_dataset_path)
+
+    # Map the dataset down to the core fields we need, so it is easier to process
+    logger.info("Mapping query data")
+    query_data = query_data["train"]
+    query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names)

-    formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
+    formatted_dataset = query_data.with_transform(partial(batch_prepare_data_for_qwen2_inference, processor=processor))
+    print(formatted_dataset)
+    print("---------------")


     with TemporaryDirectory() as output_dir:
-        # Uncomment to test speed of data loader
-        # train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
-        # for entry in tqdm(train_dataloader):
-        #     print("Step!")
-        #     model.forward(**{k: v.to("cuda:0") for (k,v) in entry.items()})
+        train_dataloader = DataLoader(formatted_dataset, batch_size=1, num_workers=4, shuffle=False)
+        for entry in tqdm(train_dataloader):
+            print("Step!")
+            model.forward(**{k: v.to("cuda:0") for (k,v) in entry.items()})


 def main():
     run_inference(model_name="Qwen/Qwen2-VL-2B-Instruct",
-                  dataset_path="s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl")
+                  query_dataset_path="s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl")


if __name__ == "__main__":
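Note that the loop above only calls model.forward, so the commit as written exercises data loading and forward-pass speed rather than producing any text. Below is a minimal sketch of what a decoding step might look like instead, reusing formatted_dataset, model, and processor from run_inference; the max_new_tokens cap and the prompt-stripping slice are illustrative assumptions, not part of this commit.

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# Sketch only: decode completions instead of timing forward passes.
# Assumes formatted_dataset, model, and processor exist as in run_inference.
dataloader = DataLoader(formatted_dataset, batch_size=1, num_workers=4, shuffle=False)
with torch.no_grad():
    for entry in tqdm(dataloader):
        inputs = {k: v.to("cuda:0") for (k, v) in entry.items()}
        # Illustrative cap on generation length (not a value from this commit)
        output_ids = model.generate(**inputs, max_new_tokens=512)
        # Keep only the newly generated tokens, dropping the prompt
        completions = output_ids[:, inputs["input_ids"].shape[1]:]
        print(processor.batch_decode(completions, skip_special_tokens=True))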
69 changes: 69 additions & 0 deletions pdelfin/train/dataprep.py
@@ -95,6 +95,75 @@ def batch_prepare_data_for_qwen2_training(batch, processor):
}


+def prepare_data_for_qwen2_inference(example, processor):
+    # Prepare messages
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
+                    "image": example["input_prompt_image_base64"]  # Placeholder
+                },
+                {"type": "text", "text": example["input_prompt_text"]},
+            ],
+        }
+    ]
+    # Apply chat template to get the text
+    text = processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+
+    # Decode image from base64
+    main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))
+
+    # Right now, we downsample to 1024 on the longest dimension, because the
+    # 2048 we passed to OpenAI is too large for training
+    width, height = main_image.size
+    assert max(width, height) == 2048
+    main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)
+
+    # Process inputs using processor
+    inputs = processor(
+        text=[text],
+        images=[main_image],
+        padding=True,
+        return_tensors="np",
+    )
+
+    input_ids = inputs["input_ids"]
+
+    # All columns will participate in attention fully
+    attention_mask = np.ones_like(input_ids)
+
+    # Return as dict, including pixel_values
+    return {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+        "pixel_values": inputs["pixel_values"],
+        "image_grid_thw": inputs["image_grid_thw"][0],
+    }
+
+
+def batch_prepare_data_for_qwen2_inference(batch, processor):
+    # Process each example in the batch using the helper function
+    processed_examples = []
+    for i in range(len(batch["input_prompt_image_base64"])):
+        example = {
+            "input_prompt_image_base64": batch["input_prompt_image_base64"][i],
+            "input_prompt_text": batch["input_prompt_text"][i],
+        }
+        processed_example = prepare_data_for_qwen2_inference(example, processor)
+        processed_examples.append(processed_example)
+
+    return {
+        "input_ids": [x["input_ids"] for x in processed_examples],
+        "attention_mask": [x["attention_mask"] for x in processed_examples],
+        "pixel_values": [x["pixel_values"] for x in processed_examples],
+        "image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
+    }

 # Define a custom data collator
 class DataCollatorForVisionLanguageModeling:
     def __init__(self, processor):
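As a quick sanity check, here is a minimal smoke test for the new prep function. The record below is fabricated to match the two column names used above, the white 2048x1536 test image exists only to satisfy the assert on the longest dimension, and the model name mirrors the one used in batch_inference.py.

import base64
from io import BytesIO

from PIL import Image
from transformers import AutoProcessor

from pdelfin.train.dataprep import prepare_data_for_qwen2_inference

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

# Fabricated example with the fields extract_openai_batch_query is expected
# to produce; 2048 on the longest side satisfies the assert in the function.
buf = BytesIO()
Image.new("RGB", (2048, 1536), "white").save(buf, format="PNG")
example = {
    "input_prompt_image_base64": base64.b64encode(buf.getvalue()).decode(),
    "input_prompt_text": "Transcribe the text on this page.",
}

out = prepare_data_for_qwen2_inference(example, processor)
print(out["input_ids"].shape, out["attention_mask"].shape,
      out["pixel_values"].shape, out["image_grid_thw"])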
