
FullModelHFCheckpointer saved checkpoint isn't compatible with Huggingface transformers model loading #2048

Closed
vancoykendall opened this issue Nov 21, 2024 · 11 comments

@vancoykendall (Contributor)

Loading huggingface transformers models is done with the from_pretrained() method. For pytorch or safetensors checkpoints, this method expects a pytorch_model.bin or model.safetensors file for single file checkpoints. For sharded checkpoints, it expects either a pytorch_model.bin.index.json or model.safetensors.index.json file that maps each weight to the shard file (I think sharded checkpoint files can have arbitrary naming).

Currently, the FullModelHFCheckpointer doesn't name single-file checkpoints pytorch_model.bin or model.safetensors, and it doesn't create an index.json file for sharded checkpoints. Thus, saved checkpoints can't be loaded with from_pretrained().
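For reference, here is a minimal sketch of the index file that from_pretrained() resolves shards through (the weight names and shard filenames below are hypothetical, just to show the shape):

```python
import json

# Hypothetical two-shard checkpoint. transformers reads the index file first,
# then loads only the shard files listed in "weight_map" -- the shard filenames
# themselves can be arbitrary, but the index filename is fixed.
index = {
    "metadata": {"total_size": 0},  # sum of numel * element_size over all weights
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "lm_head.weight": "model-00002-of-00002.safetensors",
    },
}

# This name (or pytorch_model.bin.index.json for .bin shards) is what
# from_pretrained() looks for in the checkpoint directory.
index_filename = "model.safetensors.index.json"
index_json = json.dumps(index, indent=2)
```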

# write the partitioned state dicts to the right checkpoint file
for cpt_idx, model_state_dict in split_state_dicts.items():
    if not self._safe_serialization:
        output_path = Path.joinpath(
            self._output_dir, f"hf_model_{cpt_idx}_{epoch}"
        ).with_suffix(".pt")
        torch.save(model_state_dict, output_path)
    else:
        output_path = Path.joinpath(
            self._output_dir,
            f"model-0{cpt_idx}-of-0{list(split_state_dicts.keys())[-1]}_{epoch}",
        ).with_suffix(".safetensors")
        save_file(model_state_dict, output_path, metadata={"format": "pt"})
    logger.info(
        "Model checkpoint of size "
        f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
        f"saved to {output_path}"
    )

@vancoykendall (Contributor, Author)

Here are some potentially useful links related to Hugging Face checkpoint creation:
Huggingface documents index file creation here: https://huggingface.co/docs/transformers/v4.46.3/en/big_models#sharded-checkpoints

Code for sharding state dict here: https://github.com/huggingface/huggingface_hub/blob/v0.26.2/src/huggingface_hub/serialization/_torch.py

@felipemello1 (Contributor)

felipemello1 commented Nov 22, 2024

hey @vancoykendall, thank you so much for this issue! I am one of the maintainers and am working on fixing this. @joecummings is also working on refactoring our checkpointing. This will become much easier in the coming days.

Meanwhile, to unblock you, we give an example here: https://pytorch.org/torchtune/main/tutorials/e2e_flow.html#using-torchtune-checkpoints-with-other-libraries

You can manually make it .bin (i know, that's not fun) and select the files to keep in the folder. Then from_pretrained will work.
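That manual workaround can be sketched as below (the hf_model_0001_0.pt source name is an assumption about the checkpointer's default output; adjust it to whatever your run actually produced):

```python
from pathlib import Path


def rename_for_from_pretrained(ckpt_dir: str, src_name: str = "hf_model_0001_0.pt") -> Path:
    """Rename a single-file torchtune checkpoint to the fixed
    pytorch_model.bin name that from_pretrained() looks for."""
    src = Path(ckpt_dir) / src_name
    dst = src.with_name("pytorch_model.bin")
    src.rename(dst)
    return dst
```

After the rename (and with config.json etc. left in place), pointing from_pretrained() at the folder should work.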

@vancoykendall (Contributor, Author)

Nice, thanks! I also just patched the save_checkpoint method locally so I don't have to convert them anymore:

# split the state_dict into separate dicts, one for each output checkpoint file
split_state_dicts: Dict[str, Dict[str, torch.Tensor]] = {}
total_size = 0
for key, weight in state_dict[training.MODEL_KEY].items():
    cpt_idx = self._weight_map[key]
    if cpt_idx not in split_state_dicts:
        split_state_dicts[cpt_idx] = {}
    split_state_dicts[cpt_idx].update({key: weight})
    total_size += weight.numel() * weight.element_size()

# write the partitioned state dicts to the right checkpoint file
num_shards = len(split_state_dicts)
for cpt_idx, model_state_dict in split_state_dicts.items():
    if not self._safe_serialization:
        shard_name = f"pytorch_model-{int(cpt_idx):05d}-of-{num_shards:05d}"
        output_path = Path.joinpath(
            self._output_dir, f"{shard_name}_{epoch}"
        ).with_suffix(".bin")
        torch.save(model_state_dict, output_path)
    else:
        shard_name = f"model-{int(cpt_idx):05d}-of-{num_shards:05d}_{epoch}"
        output_path = Path.joinpath(
            self._output_dir, shard_name
        ).with_suffix(".safetensors")
        save_file(model_state_dict, output_path, metadata={"format": "pt"})
    logger.info(
        "Model checkpoint of size "
        f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
        f"saved to {output_path}"
    )

# Save the appropriate index file based on serialization format
if self._safe_serialization:
    index_path = Path.joinpath(self._output_dir, "model.safetensors.index.json")
    weight_map = {
        k: f"model-{int(v):05d}-of-{num_shards:05d}_{epoch}.safetensors"
        for k, v in self._weight_map.items()
    }
else:
    index_path = Path.joinpath(self._output_dir, "pytorch_model.bin.index.json")
    weight_map = {
        k: f"pytorch_model-{int(v):05d}-of-{num_shards:05d}_{epoch}.bin"
        for k, v in self._weight_map.items()
    }

index_data = {
    "metadata": {"total_size": total_size},
    "weight_map": weight_map,
}
with open(index_path, "w") as f:
    json.dump(index_data, f, indent=2)

@joecummings (Contributor)

@vancoykendall This is awesome! Would you like to open a PR on our repo adding this patch? I think it would definitely benefit our entire community :)

@vancoykendall (Contributor, Author)

Sure, I'd be happy to. Although I've realized this current method would overwrite the index.json file each epoch, since the index file name can't be modified. I could instead save each epoch's checkpoint in a separate subfolder? Any thoughts? @joecummings
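The subfolder idea could be sketched like this (purely illustrative; the epoch_{n} naming is an assumption, not something torchtune does):

```python
from pathlib import Path


def epoch_output_dir(base_dir: str, epoch: int) -> Path:
    """Give each epoch its own folder so the fixed-name index file
    (model.safetensors.index.json) is not overwritten between epochs."""
    out = Path(base_dir) / f"epoch_{epoch}"
    out.mkdir(parents=True, exist_ok=True)
    return out
```

Each epoch folder would then be a self-contained checkpoint that from_pretrained() can load directly.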

@felipemello1 (Contributor)

hey @vancoykendall, I am working on a related issue to handle how we save/load files. Would it be fine if I drive the PR and have you as co-author? I am afraid that my changes may undo/conflict with yours. If so, please send me your email on Discord, and I can add you as a co-author on the commit, to make sure you get credit for it.

adding as co-author: https://docs.github.com/en/pull-requests/committing-changes-to-your-project/creating-and-editing-commits/creating-a-commit-with-multiple-authors
my discord: whynot9753

@vancoykendall (Contributor, Author)

Cool, thanks. I just sent you my email on Discord.

@gordicaleksa

Hey guys, I'm hitting this as well - any progress on this? :)

@felipemello1 (Contributor)

felipemello1 commented Nov 26, 2024

@gordicaleksa, I have a draft, but it is not ready to be used: https://github.com/pytorch/torchtune/pull/2074/files. Once merged, the output dir should be more organized and ready to be used by vLLM and HF.

TLDR:
Assuming you are doing full finetuning, you can follow the script shared above by @vancoykendall:

  1. create an empty folder

  2. save the checkpoints as .safetensors or .bin

  3. create the model.safetensors.index.json file (or pytorch_model.bin.index.json if you saved .bin)

  4. add the tokenizer to this folder (not sure if this is necessary)

  5. now you can pass the folder to .from_pretrained or vLLM

Sorry you hit this. We are ironing out our integration with HF/vllm.
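Steps 1-3 above can be sketched as a small pure-stdlib helper that writes the index file into a fresh folder (the shard filenames you pass in are whatever you produced in step 2; write_hf_index is a hypothetical name, not a torchtune or transformers API):

```python
import json
from pathlib import Path


def write_hf_index(out_dir: str, shard_contents: dict, total_size: int,
                   safe_serialization: bool = True) -> Path:
    """Write the index file from_pretrained() expects next to the shards.

    shard_contents maps each shard filename to the list of weight names
    stored in it, e.g. {"model-00001-of-00002.safetensors": [...]}.
    """
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)  # step 1: the output folder
    name = ("model.safetensors.index.json" if safe_serialization
            else "pytorch_model.bin.index.json")
    # step 3: invert shard -> weights into the weight -> shard map HF reads
    weight_map = {w: shard for shard, weights in shard_contents.items()
                  for w in weights}
    index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
    path = out / name
    path.write_text(json.dumps(index, indent=2))
    return path
```

With the shards, the index file, and (per step 4) the tokenizer files in one folder, that folder should be loadable as a checkpoint directory.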

@gordicaleksa

I think I did manage to work around it.

The snippet from @vancoykendall helped speed it up.

Thanks guys!

I think the tokenizer will also be necessary for vLLM, etc.

@felipemello1 (Contributor)

felipemello1 commented Dec 6, 2024

hey folks, the PR is merged: #2074

Now it should be much easier to use vLLM/Hugging Face. Instructions are in the PR description.

We will update the docs soon. Let us know if you find any issues and thanks for your patience :).
