add meta onDevice support for LLAMA2 #4147

Merged: 3 commits into master, Aug 25, 2023

Conversation

dc3671 (Contributor) commented Aug 15, 2023

Problem

Currently, LlamaRMSNorm inside LLAMA2's LlamaDecoderLayer is not handled correctly during autoTP weight loading when using deepspeed.OnDevice(device="meta").

Solution

Add another special case to the if-clause in the loading logic, and abstract it into a method on the Loading class alongside Loading.load, since it is used in two places: once for children inside an autoTP policy module, and once for modules outside it.

I also aligned the way special cases are added for this scenario: matching on the class-name string rather than importing the actual module class inside a try/except clause. A sketch of the idea follows.
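
A rough sketch of the resulting check (modeled on the description above, not the exact diff; the helper name is_load_module and the exact list contents are illustrative):

import torch.nn as nn

class Loading:

    @staticmethod
    def is_load_module(module):
        # Modules whose weights must be materialized from the checkpoint
        # during autoTP loading, matched by class object...
        load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
        # ...or by class-name string, so classes like transformers'
        # LlamaRMSNorm need not be imported inside a try/except.
        load_layer_names = ["LPLayerNorm", "SharedEmbedding",
                            "OPTLearnedPositionalEmbedding", "LlamaRMSNorm"]
        return (module.__class__ in load_layers
                or module._get_name() in load_layer_names)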

The method name can be changed if there is a better one. @mrwyattii @jeffra Please help review this, thanks~

awan-10 (Contributor) commented Aug 21, 2023

@molly-smith - please review this PR when you get a chance.

molly-smith self-requested a review August 21, 2023 21:24
molly-smith (Contributor) previously approved these changes Aug 21, 2023 and left a comment:

LGTM

molly-smith (Contributor) commented

@dc3671 Can you elaborate on the issue you were seeing and for what use case? Maybe share your reproducer script? Meta tensor is not supported with autotp for any model and not supported in the llama container.

molly-smith self-requested a review August 21, 2023 23:50
molly-smith dismissed their stale review August 21, 2023 23:51: "Pending response from user"

dc3671 (Contributor, Author) commented Aug 22, 2023

> @dc3671 Can you elaborate on the issue you were seeing and for what use case? Maybe share your reproducer script? Meta tensor is not supported with autotp for any model and not supported in the llama container.

@molly-smith I think the container is only related to kernel injection, because policy_to_ds_container is only used in replace_with_policy, on this line: https://github.com/microsoft/DeepSpeed/blob/7f3e82fe0902ede54201d50e35d11ae4f3954791/deepspeed/module_inject/replace_module.py#L222

For autoTP, the only thing that matters is that DeepSpeed ensures the replaced Linear (and other) modules can find and load the correct checkpoint weights from the state_dict, which involves only these two locations:

So I just added LlamaRMSNorm to make sure this module's weight is loaded correctly, i.e. it is no longer a meta tensor.

I'm using this modified Python script for launching: https://github.com/dc3671/intel-extension-for-transformers/blob/llm/examples/huggingface/pytorch/text-generation/inference/run_generation_with_deepspeed.py#L199

If I remove llama from the is_meta_support list and run, I get errors like NotImplementedError: Cannot copy out of meta tensor; no data!, because LlamaRMSNorm's weight is not loaded correctly.
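
For illustration, a minimal snippet (independent of DeepSpeed and of this PR) that reproduces that error class: a meta tensor carries only shape and dtype, no storage, so any attempt to copy its data out fails.

import torch

w = torch.empty(4096, device="meta")  # weight placeholder with no backing data
try:
    w.to("cpu")  # materializing requires real data, which meta tensors lack
except NotImplementedError as e:
    print(e)  # Cannot copy out of meta tensor; no data!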

dc3671 (Contributor, Author) commented Aug 23, 2023

@molly-smith any update?

mrwyattii (Contributor) commented
@dc3671 I still see NotImplementedError: Cannot copy out of meta tensor; no data! when trying to load a llama2 model with meta tensor. I don't think we have support for meta tensor with AutoTP. Can you provide a minimal reproducer to demonstrate this working?

dc3671 (Contributor, Author) commented Aug 24, 2023

@mrwyattii I updated some details of my script for llama2. It can be run with:

mpirun -np 4 python -u run_generation_with_deepspeed.py -m /localdisk/llama2 --benchmark

I use the following code to load the model first with meta tensors:

config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
with deepspeed.OnDevice(dtype=load_dtype, device="meta"):
    model = AutoModelForCausalLM.from_config(config, torch_dtype=load_dtype, trust_remote_code=True)

Then I pass the checkpoint JSON file to init_inference so that it actually loads the state_dict:

model = deepspeed.init_inference(
    model,
    mp_size=world_size,
    base_dir=repo_root,
    dtype=infer_dtype,
    checkpoint=checkpoints_json if is_meta_support else None,
    **kwargs,
)

The original way of building checkpoints_json comes from https://github.com/huggingface/transformers-bloom-inference/blob/main/bloom-inference-scripts/bloom-ds-inference.py#L93 (a sketch of the idea is below).
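
A minimal sketch of that approach, assuming a directory of *.bin shards (the helper name write_checkpoints_json is illustrative, not from this PR; the dict layout follows the linked bloom script):

import json
from pathlib import Path

def write_checkpoints_json(repo_root, out_path="checkpoints.json"):
    # List every weight shard in the model directory.
    shards = [str(p) for p in Path(repo_root).glob("*.bin")]
    data = {"type": "BLOOM", "checkpoints": shards, "version": 1.0}
    with open(out_path, "w") as f:
        json.dump(data, f)
    return out_path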

I guess you may not have set the checkpoint argument in init_inference, so the checkpoint is not loaded after autoTP.

molly-smith (Contributor) commented
@dc3671 Thank you for your patience and detailed response. I was able to recreate your issue and successfully test your changes. Some of us were not aware that meta tensor support had been added to AutoTP, but I'm glad to see that it is in fact working. I will merge these changes soon. Thanks again.

molly-smith enabled auto-merge August 24, 2023 22:12
molly-smith added this pull request to the merge queue Aug 24, 2023
Merged via the queue into deepspeedai:master with commit 0712e29 Aug 25, 2023
lashoun commented Aug 30, 2023

Hi, does this merge mean that we can now successfully use with deepspeed.OnDevice(dtype=load_dtype, device="meta"): with llama models?
If so, could you please add details on what the checkpoint file should look like? I tried

{
    "type": "LLAMA",
    "checkpoints": [
        "pytorch_model-00001-of-00015.bin",
        ...
        "pytorch_model-00015-of-00015.bin"
    ]
}

but I got AssertionError: LLAMA checkpoint type is not supported.
And if I do not load any checkpoint, I get AssertionError: Meta tensors are not supported for this model currently.
@dc3671 @awan-10 @molly-smith

dc3671 (Contributor, Author) commented Aug 30, 2023

@lashoun This error is probably because you are running with replace_with_kernel_inject=True. Kernel-injection mode needs changes to the corresponding llama container, which are not included in this PR.

ZaVang commented Sep 6, 2023

@dc3671 Hi, sorry to bother you, but I encountered the same issue as #3452. When I load the llama2-13b-hf model normally and enable replace_with_kernel_inject, as follows:

tokenizer = LlamaTokenizer.from_pretrained(args.ckpt)
model = LlamaForCausalLM.from_pretrained(args.ckpt)
ds_engine = deepspeed.init_inference(model,
                                     tensor_parallel=TPConfig,
                                     dtype=torch.float16,
                                     replace_with_kernel_inject=True)

the model output seems fine:
[screenshot: normal model output]

However, when I tried to load it with meta OnDevice, I found that replace_with_kernel_inject does not currently support the llama2 model. Based on this PR, I modified the code and set replace_with_kernel_inject to False, like:

checkpoint_json = {
    "type": "ds_model",
    "version": "2",
    "checkpoints": [
        f'/llama2/13b_chat_hf/pytorch_model-0000{i}-of-00003.bin'
        for i in range(1, 4)
    ]
}
config = LlamaConfig.from_pretrained(args.ckpt)
tokenizer = LlamaTokenizer.from_pretrained(args.ckpt)
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
    model = LlamaForCausalLM._from_config(
        config, torch_dtype=torch.float16
    )

ds_engine = deepspeed.init_inference(model,
                                     tensor_parallel=TPConfig,
                                     dtype=torch.float16,
                                     checkpoint=checkpoint_json,
                                     replace_with_kernel_inject=False)

but the output became very anomalous.
[screenshot: anomalous model output]

I suspect there's an issue with how the weights are loaded. Could the 'ds_model' type setting in checkpoint_json be causing incorrect loading? I see that only 'megatronlm', 'ds_model', and 'bloom' types are supported. If my understanding is flawed, which type should I use? Any insights would be greatly appreciated.

By the way, I'm using multiple nodes for inference:

deepspeed --hostfile hostfile --no_local_rank inference.py 

dc3671 (Contributor, Author) commented Sep 7, 2023

@ZaVang I'm not that familiar with this part, but I think "ds_model" is OK according to this function (I'm using "bloom"):

class SDLoaderFactory:

    @staticmethod
    def get_sd_loader_json(json_file, checkpoint_engine):
        if isinstance(json_file, str):
            with open(json_file) as f:
                data = json.load(f)
        else:
            assert isinstance(json_file, dict)
            data = json_file
        sd_type = data['type']
        ckpt_list = data['checkpoints']
        version = data['version']
        ckpt_type = data.get('parallelization', 'pp')
        mp_size = data.get('mp_size', 0)
        if sd_type.lower() in ['bloom', 'ds_model']:
            return data
        return SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine, sd_type, version)

https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/state_dict_factory.py#L36
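
For instance, based on the function quoted above, a dict whose "type" is "bloom" or "ds_model" is returned unchanged, so init_inference consumes the shard list directly; any other type is routed to a dedicated SDLoader. A hypothetical check, runnable against that code:

checkpoint_json = {
    "type": "ds_model",
    "version": "2",
    "checkpoints": ["pytorch_model-00001-of-00003.bin"],
}
# checkpoint_engine is unused on the dict pass-through path, so None is fine here.
data = SDLoaderFactory.get_sd_loader_json(checkpoint_json, checkpoint_engine=None)
assert data is checkpoint_json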

ZaVang commented Sep 7, 2023

> @ZaVang I'm not that familiar with this part, but I think "ds_model" is OK according to this function (I'm using "bloom"): […]

Fixed by #4259.
