Skip to content

Commit

Permalink
Fixing the refiner input prompt to something simpler that doesn't dep…
Browse files Browse the repository at this point in the history
…end on the training data. Fixing beaker job workspace and bumping priority to high.
  • Loading branch information
jakep-allenai committed Sep 27, 2024
1 parent 22b765e commit decfd7f
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 7 deletions.
16 changes: 15 additions & 1 deletion pdelfin/train/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,24 @@ def extract_openai_batch_query(query: Dict[str, Any]) -> Dict[str, Any]:
except IndexError:
input_prompt_image_base64 = ""

# At this point, the input_prompt_text is the raw text that was passed to the OpenAI model
# to generate our silver data. But, we want to have a simplified prompt for this fine-tune,
# so we're going to extract out just the raw extracted prompt text
pattern = r"RAW_TEXT_START\s*\n(.*?)\nRAW_TEXT_END"

# Use re.DOTALL to ensure that the dot matches newline characters
match = re.search(pattern, input_prompt_text, re.DOTALL)

if match:
raw_page_text = match.group(1).strip()
else:
raw_page_text = ""

return {
"custom_id": custom_id,
"input_prompt_text": input_prompt_text,
"input_prompt_image_base64": input_prompt_image_base64,
"raw_page_text": raw_page_text,
}


Expand Down Expand Up @@ -223,7 +237,7 @@ def pick_image_sizes(x):
final_dataset = final_dataset.filter(pick_image_sizes)

# Limit the size of the input text not to explode the context size
final_dataset = final_dataset.filter(lambda x: len(x["input_prompt_text"]) < 4000)
final_dataset = final_dataset.filter(lambda x: len(x["raw_page_text"]) < 4000)

return final_dataset

Expand Down
12 changes: 11 additions & 1 deletion pdelfin/train/dataprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@ def filter_by_max_seq_len(example, processor, max_prompt_len: int=2000, max_resp
return True


# This is a base prompt that will be used for training and running the fine tuned model
# It's simplified from the prompt which was used to generate the silver data, and can change from dataset to dataset
def _build_finetuning_prompt(base_text: str) -> str:
return (
f"Below is the image of one page of a document, as well as some raw textual content that was previously extracted for it. "
f"Just return the plain text representation of this document as if you were reading it naturally.\n"
f"Do not hallucinate.\n"
f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
)


def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
# Prepare messages
Expand All @@ -26,7 +36,7 @@ def prepare_data_for_qwen2_training(example, processor, add_batch_dim=False):
"type": "image",
"image": example["input_prompt_image_base64"] # Placeholder
},
{"type": "text", "text": example["input_prompt_text"]},
{"type": "text", "text": _build_finetuning_prompt(example["raw_page_text"])},
],
}
]
Expand Down
4 changes: 2 additions & 2 deletions scripts/qwen2vl-7b-gantry.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ gantry run \
--task-name "${run_name}"\
--allow-dirty \
--host-networking \
--workspace ai2/oe-data-pdf \
--workspace ai2/oe-data-model-based-cleanup \
--beaker-image 'jakep/jakep-pdf-finetunev1.1' \
--venv 'base' \
--pip gantry-requirements.txt \
--priority normal \
--priority high \
--gpus 8 \
--preemptible \
--cluster "ai2/${CLUSTER}*" \
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def testPlotSequenceLengthHistogram(self):
fig.write_image("sequence_lengths_histogram.png")

def testExtractBatch(self):
query_data = load_jsonl_from_s3("s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl", first_n_files=3)
query_data = load_jsonl_from_s3("s3://ai2-oe-data/jakep/openai_batch_data_v2_mini/*.jsonl", first_n_files=3)
query_data = query_data["train"]
query_data = query_data.map(extract_openai_batch_query, remove_columns=query_data.column_names)

Expand Down
9 changes: 7 additions & 2 deletions tests/test_dataprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
)

from pdelfin.train.dataprep import (
prepare_data_for_qwen2_training
prepare_data_for_qwen2_training, _build_finetuning_prompt
)


Expand All @@ -32,7 +32,7 @@ def testTokenizationMatches(self):
"type": "image",
"image": example["input_prompt_image_base64"] # Placeholder
},
{"type": "text", "text": example["input_prompt_text"]},
{"type": "text", "text": _build_finetuning_prompt(example["raw_page_text"])},
],
},

Expand All @@ -47,6 +47,11 @@ def testTokenizationMatches(self):
# Decode image from base64
main_image = Image.open(BytesIO(base64.b64decode(example["input_prompt_image_base64"])))

width, height = main_image.size
assert 1800 <= max(width, height) <= 2200, f"Image size {width}x{height} invalid"
main_image = main_image.resize((width // 2, height // 2), Image.LANCZOS)


# Process inputs using processor
inference_inputs = processor(
text=[text],
Expand Down

0 comments on commit decfd7f

Please sign in to comment.