Skip to content

Commit

Permalink
Small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Oct 21, 2024
1 parent a482271 commit f44dbd1
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 3 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,12 @@ You will probably need to install some fonts on your computer so that any pdfs y
sudo apt-get install ttf-mscorefonts-installer msttcorefonts fonts-crosextra-caladea fonts-crosextra-carlito gsfonts lcdf-typetools
```


### TODOs for future versions
- Equations could be specified to be in a more specific format (they are "LaTeX" now)
- Ask model to predict footnotes in a structured format separately
- Add training data for complex tables
- More training augmentations to improve performance
- Fix pages which are all-references sometimes rendering as empty-text

2 changes: 1 addition & 1 deletion pdelfin/data/convertsilver_birr.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool):
# Save the pdf to a temporary cache folder
local_pdf_path = cached_path(s3_path, quiet=True)

raw_page_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")
raw_page_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport", target_length=6000)
transformed["chat_messages"][0]["content"][0]["text"] = build_finetuning_prompt(raw_page_text)


Expand Down
4 changes: 2 additions & 2 deletions pdelfin/train/config/qwen2vl-7b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ hparams:
gradient_checkpointing: true
clip_grad_norm: 1.0
learning_rate: 1e-6
max_steps: 30000
max_steps: 10000
pad_multiple_of: 16
log_every_steps: 10
eval_every_steps: 100
Expand All @@ -59,6 +59,6 @@ hparams:

save:
path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/
save_every_steps: 29500
save_every_steps: 9500

max_workers: 10
103 changes: 103 additions & 0 deletions pdelfin/train/fixqwen2vlcheckpoint.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,89 @@
import argparse
import os
import json
import torch
import boto3
import tempfile
import concurrent.futures

from smart_open import smart_open

from transformers import Qwen2VLForConditionalGeneration
from pdelfin.s3_utils import parse_s3_path

s3_client = boto3.client('s3')

def download_file_from_s3(bucket_name, key, local_file_path):
    """Fetch one object from S3 and write it to the given local path.

    Blocking helper intended to be fanned out across a thread pool by
    download_model_from_s3; any boto3 error propagates to the caller.
    """
    # Uses the shared module-level client; boto3 clients are thread-safe.
    s3_client.download_file(bucket_name, key, local_file_path)
    print(f"Downloaded {key} to {local_file_path}")

def download_model_from_s3(bucket_name, model_s3_key, local_model_dir):
    """Download every object under an S3 prefix into a flat local directory.

    Files are fetched in parallel with a thread pool. Objects whose key ends
    with '/' (S3 "directory" placeholders) are skipped, and each object is
    written under its basename, so any prefix hierarchy is flattened locally.

    Args:
        bucket_name: S3 bucket holding the checkpoint.
        model_s3_key: key prefix under which the checkpoint files live.
        local_model_dir: destination directory; created if missing.
    """
    os.makedirs(local_model_dir, exist_ok=True)

    # list_objects_v2 returns at most 1000 keys per call; paginate so a
    # checkpoint with many shards/files is not silently truncated.
    paginator = s3_client.get_paginator('list_objects_v2')

    # Prepare list of download tasks
    download_tasks = []
    for page in paginator.paginate(Bucket=bucket_name, Prefix=model_s3_key):
        for obj in page.get('Contents', []):
            key = obj['Key']
            if key.endswith('/'):
                continue  # Skip directories

            local_file_path = os.path.join(local_model_dir, os.path.basename(key))
            download_tasks.append((bucket_name, key, local_file_path))

    # Use a ThreadPoolExecutor to download files in parallel
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(download_file_from_s3, bucket_name, key, local_file_path)
            for bucket_name, key, local_file_path in download_tasks
        ]

        # Wait for all downloads to complete and handle any exceptions
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()  # This will raise any exceptions encountered during download
            except Exception as e:
                # Best-effort: report and continue so one bad object does not
                # abort the rest of the checkpoint download.
                print(f"Error downloading file: {e}")


def upload_file_to_s3(local_file_path, bucket_name, s3_key):
    """Push one local file to s3://bucket_name/s3_key.

    Best-effort: a failed upload is reported on stdout rather than raised,
    so a thread-pool caller keeps uploading the remaining files.
    """
    try:
        s3_client.upload_file(local_file_path, bucket_name, s3_key)
    except Exception as e:
        print(f"Error uploading {local_file_path} to s3://{bucket_name}/{s3_key}: {e}")
    else:
        print(f"Uploaded {local_file_path} to s3://{bucket_name}/{s3_key}")


def save_model_to_s3(local_model_dir, bucket_name, s3_model_key):
    """Upload the model directory to S3 in parallel.

    Walks local_model_dir recursively and uploads every file under the
    s3_model_key prefix, preserving any subdirectory structure in the keys.
    Individual upload failures are printed (by upload_file_to_s3 and the
    future handling below) rather than raised.

    Args:
        local_model_dir: directory containing the saved checkpoint files.
        bucket_name: destination S3 bucket.
        s3_model_key: key prefix to upload under (trailing '/' optional).
    """
    # Collect all file paths to be uploaded
    upload_tasks = []
    for root, _dirs, files in os.walk(local_model_dir):
        for file in files:
            local_file_path = os.path.join(root, file)
            # Build the key from the path relative to local_model_dir and
            # always join with '/': os.path.join would emit '\\' on Windows
            # (invalid as an S3 separator), and the previous basename-only
            # key flattened nested files, risking silent collisions.
            rel_path = os.path.relpath(local_file_path, local_model_dir)
            s3_key = s3_model_key.rstrip('/') + '/' + rel_path.replace(os.sep, '/')
            upload_tasks.append((local_file_path, bucket_name, s3_key))

    # Use a ThreadPoolExecutor to upload files in parallel
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(upload_file_to_s3, local_file_path, bucket_name, s3_key)
            for local_file_path, bucket_name, s3_key in upload_tasks
        ]

        # Wait for all uploads to complete and handle any exceptions
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()  # This will raise any exceptions encountered during upload
            except Exception as e:
                print(f"Error during upload: {e}")

def main():
parser = argparse.ArgumentParser(description='Fix up a Qwen2VL checkpoint saved on s3 or otherwise, so that it will load properly in vllm/birr')
parser.add_argument('s3_path', type=str, help='S3 path to the Hugging Face checkpoint.')
Expand Down Expand Up @@ -30,6 +111,28 @@ def main():

assert config_data["architectures"] == ["Qwen2VLForConditionalGeneration"]

if config_data["torch_dtype"] == "float32":
print("Detected model is float32, this is probably an FSDP checkpoint")
print("Saving to _bf16 location with adjusted parameters")

bucket, prefix = parse_s3_path(args.s3_path)
td = "/tmp/qwen2_checkpoint_saving"
download_model_from_s3(bucket, prefix, td)

print("Downloaded entire model from s3, resaving as bfloat16")
model = Qwen2VLForConditionalGeneration.from_pretrained(td)
model = model.to(torch.bfloat16)
os.makedirs(os.path.join(td, "bf16_checkpoint"), exist_ok=True)

print("Saving...")
model.save_pretrained(os.path.join(td, "bf16_checkpoint"))

print("Uploading")
save_model_to_s3(os.path.join(td, "bf16_checkpoint"), bucket, prefix.rstrip('/') + "/bf16")

args.s3_path = args.s3_path.rstrip('/') + "/bf16"


# Iterate over each file in the replacement list
for replacement_file in qwen_replacement_files:
filename = os.path.basename(replacement_file)
Expand Down

0 comments on commit f44dbd1

Please sign in to comment.