Skip to content

Commit

Permalink
Fixing qwen checkpoint script
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Sep 30, 2024
1 parent 963e946 commit e179453
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions pdelfin/train/fixqwen2vlcheckpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import argparse
import os
import tempfile
import boto3
from tqdm import tqdm
from transformers import AutoModel, Qwen2VLForConditionalGeneration
from smart_open import smart_open


def main():
parser = argparse.ArgumentParser(description='Fix up a Qwen2VL checkpoint saved on s3 or otherwise, so that it will load properly in vllm/birr')
parser.add_argument('s3_path', type=str, help='S3 path to the Hugging Face checkpoint.')
args = parser.parse_args()

# Create a temporary directory to store the model files

# Rewrite the config.json from the official repo, this fixes a weird bug with the rope scaling configuration
with smart_open("https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/config.json", "r") as newf:
new_config = newf.read()

with smart_open(os.path.join(args.s3_path, "config.json"), "w") as oldf:
oldf.write(new_config)

print("Model updated successfully.")

if __name__ == '__main__':
main()

0 comments on commit e179453

Please sign in to comment.