diff --git a/README.md b/README.md index c79afae..5904156 100644 --- a/README.md +++ b/README.md @@ -18,3 +18,12 @@ You will probably need to install some fonts on your computer so that any pdfs y sudo apt-get install ttf-mscorefonts-installer msttcorefonts fonts-crosextra-caladea fonts-crosextra-carlito gsfonts lcdf-typetools ``` + + +### TODOs for future versions + - Equations could be specified to be in a more specific format (they are "LaTeX" now) + - Ask model to predict footnotes in a structured format separately + - Add training data for complex tables + - More training augmentations to improve performance + - Fix pages which are all-references sometimes rendering as empty-text + diff --git a/pdelfin/data/convertsilver_birr.py b/pdelfin/data/convertsilver_birr.py index ec9352b..4f7c29a 100644 --- a/pdelfin/data/convertsilver_birr.py +++ b/pdelfin/data/convertsilver_birr.py @@ -101,7 +101,7 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool): # Save the pdf to a temporary cache folder local_pdf_path = cached_path(s3_path, quiet=True) - raw_page_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport") + raw_page_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport", target_length=6000) transformed["chat_messages"][0]["content"][0]["text"] = build_finetuning_prompt(raw_page_text) diff --git a/pdelfin/train/config/qwen2vl-7b.yaml b/pdelfin/train/config/qwen2vl-7b.yaml index 3eb1e78..f4f896d 100644 --- a/pdelfin/train/config/qwen2vl-7b.yaml +++ b/pdelfin/train/config/qwen2vl-7b.yaml @@ -47,7 +47,7 @@ hparams: gradient_checkpointing: true clip_grad_norm: 1.0 learning_rate: 1e-6 - max_steps: 30000 + max_steps: 10000 pad_multiple_of: 16 log_every_steps: 10 eval_every_steps: 100 @@ -59,6 +59,6 @@ hparams: save: path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/ - save_every_steps: 29500 + save_every_steps: 9500 max_workers: 10 \ No newline at end of file diff --git 
def download_file_from_s3(bucket_name, key, local_file_path):
    """Download a single object from S3 to a local path and log it."""
    s3_client.download_file(bucket_name, key, local_file_path)
    print(f"Downloaded {key} to {local_file_path}")


def download_model_from_s3(bucket_name, model_s3_key, local_model_dir):
    """Download every object under an S3 prefix into a local directory, in parallel.

    Fixes over the naive single-call version:
      - Uses a paginator: ``list_objects_v2`` returns at most 1000 keys per
        call, so one call would silently miss files in large checkpoints.
      - Preserves the key's path relative to the prefix instead of flattening
        everything to its basename, which would make nested files with the
        same name overwrite each other.
    """
    os.makedirs(local_model_dir, exist_ok=True)

    paginator = s3_client.get_paginator('list_objects_v2')

    download_tasks = []
    for page in paginator.paginate(Bucket=bucket_name, Prefix=model_s3_key):
        for obj in page.get('Contents', []):
            key = obj['Key']
            if key.endswith('/'):
                continue  # Skip directory placeholder objects

            # Key path relative to the prefix; fall back to the basename for
            # the degenerate case where the prefix IS the key.
            rel_key = key[len(model_s3_key):].lstrip('/')
            if rel_key:
                local_file_path = os.path.join(local_model_dir, *rel_key.split('/'))
            else:
                local_file_path = os.path.join(local_model_dir, os.path.basename(key))

            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
            download_tasks.append((bucket_name, key, local_file_path))

    # Fan the downloads out over a thread pool; S3 transfers are I/O bound.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(download_file_from_s3, bucket, key, path)
            for bucket, key, path in download_tasks
        ]

        # Surface (but do not re-raise) any per-file failures, matching the
        # best-effort logging style of the upload path.
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                print(f"Error downloading file: {e}")


def upload_file_to_s3(local_file_path, bucket_name, s3_key):
    """Upload a single file to S3, logging rather than raising on failure."""
    try:
        s3_client.upload_file(local_file_path, bucket_name, s3_key)
        print(f"Uploaded {local_file_path} to s3://{bucket_name}/{s3_key}")
    except Exception as e:
        print(f"Error uploading {local_file_path} to s3://{bucket_name}/{s3_key}: {e}")


def save_model_to_s3(local_model_dir, bucket_name, s3_model_key):
    """Upload a local model directory to S3 in parallel.

    Fixes over the naive version:
      - S3 keys are built with '/' separators; ``os.path.join`` would emit
        backslash-separated keys on Windows, which S3 treats as literal
        characters rather than path delimiters.
      - The directory structure below ``local_model_dir`` is preserved via
        ``os.path.relpath`` instead of flattening every file to its bare
        name, which would make same-named files in different subdirectories
        overwrite each other under a single key.
    """
    prefix = s3_model_key.rstrip('/')

    upload_tasks = []
    for root, _dirs, files in os.walk(local_model_dir):
        for file in files:
            local_file_path = os.path.join(root, file)
            rel_path = os.path.relpath(local_file_path, local_model_dir)
            s3_key = prefix + '/' + rel_path.replace(os.sep, '/')
            upload_tasks.append((local_file_path, bucket_name, s3_key))

    # Fan the uploads out over a thread pool; S3 transfers are I/O bound.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(upload_file_to_s3, path, bucket, key)
            for path, bucket, key in upload_tasks
        ]

        # Surface (but do not re-raise) any per-file failures so one bad file
        # does not abort the rest of the checkpoint upload.
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                print(f"Error during upload: {e}")
model.to(torch.bfloat16) + os.makedirs(os.path.join(td, "bf16_checkpoint"), exist_ok=True) + + print("Saving...") + model.save_pretrained(os.path.join(td, "bf16_checkpoint")) + + print("Uploading") + save_model_to_s3(os.path.join(td, "bf16_checkpoint"), bucket, prefix.rstrip('/') + "/bf16") + + args.s3_path = args.s3_path.rstrip('/') + "/bf16" + + # Iterate over each file in the replacement list for replacement_file in qwen_replacement_files: filename = os.path.basename(replacement_file)