
How to load ckpt files generated by torchtune.utils.FullModelHFCheckpointer into HF models #878

Closed
BMPixel opened this issue Apr 26, 2024 · 2 comments



BMPixel commented Apr 26, 2024

When torchtune.utils.FullModelHFCheckpointer loads Hugging Face models it reads *.safetensors files, but it writes its own checkpoints out as *.pt files. Those *.pt files cannot be loaded with the from_pretrained function.

Is there a way to convert the *.pt checkpoint files into something like pytorch_model.bin or *.safetensors?

This issue is similar to #832, which seems to focus on converting Meta checkpoint files like consolidated.xx.pth. I am wondering whether it would be good to have a CLI tool for converting checkpoints between the Meta, PyTorch, and Hugging Face formats. That would be helpful.


BMPixel commented Apr 26, 2024

I've figured it out. I'll post my understanding here in case anyone has the same question.

First of all, *.bin and *.pt are the same file format: both are processed with torch.load/torch.save. *.safetensors is simply another format used by Hugging Face. All of them hold a model's state_dict.

torchtune.utils.FullModelHFCheckpointer creates .pt files that correspond to the safetensors files, just in a different serialization format. Hugging Face reads them all the same, so the simplest way to from_pretrained your checkpoints is to edit model.safetensors.index.json like this:

{
  "metadata": {
    "total_size": 32121044992
  },
  "weight_map": {
    "lm_head.weight": "hf_model_0007_0.pt",
    "model.embed_tokens.weight": "hf_model_0001_0.pt",
    "model.layers.0.input_layernorm.weight": "hf_model_0001_0.pt"
... 
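That index edit can also be scripted. Below is a minimal sketch, assuming the torchtune .pt shards sort into the same order as the original .safetensors shards they replace; the function name and the `hf_model_*.pt` glob pattern are my own choices, not part of torchtune:

```python
import json
from pathlib import Path

def retarget_index(ckpt_dir: str, pattern: str = "hf_model_*.pt") -> None:
    """Point model.safetensors.index.json at the torchtune .pt shards.

    Assumes a one-to-one, order-preserving mapping between the original
    .safetensors shards and the .pt shards when both are sorted by name.
    """
    ckpt_path = Path(ckpt_dir)
    index_path = ckpt_path / "model.safetensors.index.json"
    index = json.loads(index_path.read_text())

    old_shards = sorted(set(index["weight_map"].values()))
    new_shards = sorted(p.name for p in ckpt_path.glob(pattern))
    if len(old_shards) != len(new_shards):
        raise RuntimeError("shard count mismatch between index and .pt files")

    # Map each original shard name to the .pt shard at the same sorted position.
    rename = dict(zip(old_shards, new_shards))
    index["weight_map"] = {k: rename[v] for k, v in index["weight_map"].items()}
    index_path.write_text(json.dumps(index, indent=2))
```

If the sorted orders don't actually line up for your checkpoint, the mapping has to be written by hand instead, as in the JSON above.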

from_pretrained looks for model.safetensors.index.json and reads all the checkpoint shards it lists. Furthermore, if you want a single pytorch_model.bin out of the *.pt files, you can merge the dicts from all the *.pt files and save the result as pytorch_model.bin:

import glob
import torch
from tqdm import tqdm

pt_to_merge = glob.glob("outputs/trial_one/hf_model_000*_1.pt")
state_dicts = [torch.load(p) for p in tqdm(pt_to_merge)]
merged_state_dict = {k: v for d in state_dicts for k, v in d.items()}
torch.save(merged_state_dict, "pytorch_model.bin")

I hope this helps!

@kartikayk
Contributor

@BMPixel absolutely spot on! Thanks so much for the detailed comment on this - all of this makes sense to me.

You can simply rename the .pt files to .bin and things should work. .pt is a more common pytorch extension and so that's what we output. Great point about modifying model.safetensors.index.json!
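For anyone else reading, that rename can be done in a few lines. A minimal sketch (the function name and directory layout are hypothetical); note that whichever index file from_pretrained consults must also list the new .bin names:

```python
from pathlib import Path

def pt_to_bin(ckpt_dir: str) -> list:
    """Rename every hf_model_*.pt shard in ckpt_dir to a .bin file."""
    renamed = []
    for p in sorted(Path(ckpt_dir).glob("hf_model_*.pt")):
        # e.g. hf_model_0001_0.pt -> hf_model_0001_0.bin
        target = p.with_suffix(".bin")
        p.rename(target)
        renamed.append(target.name)
    return renamed
```

Usage would look like `pt_to_bin("outputs/trial_one")`, pointing at wherever torchtune wrote the checkpoints.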

Did you need any other changes, or did the JSON change work on its own?
