
Add the policy to run llama model from the official repo #4313

Merged 19 commits into master on Sep 19, 2023

Conversation

@RezaYazdaniAminabadi (Contributor) commented Sep 12, 2023

This PR adds support for Llama 2 using the official implementation from the llama repo.

This now works for all of the llama variants except the ones that require KV sharing.
Support for the KV-shared architecture will be added next.
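For context, a minimal sketch of how a model built from the official llama repo might be wired into DeepSpeed inference with kernel injection. The checkpoint paths, build arguments, and the exact injection call here are illustrative assumptions, not this PR's actual scripts (those live in the linked fork):

```python
# Hypothetical sketch only: paths and arguments are placeholders, and the
# exact integration used by this PR may differ.
import torch
import deepspeed
from llama import Llama  # the official meta-llama/llama repo

# Build the generator the same way the official repo's example does.
generator = Llama.build(
    ckpt_dir="llama-2-70b/",          # placeholder checkpoint directory
    tokenizer_path="tokenizer.model", # placeholder tokenizer path
    max_seq_len=1024,
    max_batch_size=4,
)

# Swap in DeepSpeed's fused transformer-inference kernels.
generator.model = deepspeed.init_inference(
    generator.model,
    mp_size=8,                        # tensor-parallel degree (e.g. 8 GPUs)
    dtype=torch.float16,
    replace_with_kernel_inject=True,
)

out = generator.text_completion(["DeepSpeed is"], max_gen_len=128)
print(out[0]["generation"])
```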

@mpjlu (Contributor) commented Sep 13, 2023

This PR is for the official llama repo, but it is named llama2; users will be confused by the file and class names.
Since the 7B and 13B model architectures of Llama and Llama 2 are the same, DeepSpeed can already run the Llama 2 7B and 13B HF models; only the 70B Llama 2 model still needs support.
It would be better to enable the 70B Llama 2 model in the current llama code.

@RezaYazdaniAminabadi (Contributor, Author)

This PR is for the official llama repo, but it is named llama2; users will be confused by the file and class names. Since the 7B and 13B model architectures of Llama and Llama 2 are the same, DeepSpeed can already run the Llama 2 7B and 13B HF models; only the 70B Llama 2 model still needs support. It would be better to enable the 70B Llama 2 model in the current llama code.

Hi @mpjlu,

Thanks for the comment.
Agreed on this; this policy is just a secondary one for me to try with the official repo, since getting access to the llama models on HF takes a long time. I will name them better!
Also, I am working on adding the 70B support.
Best,
Reza

@RezaYazdaniAminabadi RezaYazdaniAminabadi changed the title Add the llama2 support from the official llama repo Add the policy to run llama model from the official repo Sep 13, 2023
@RezaYazdaniAminabadi (Contributor, Author)

By the way, the models that HF and the llama repo use are a bit different! At least, I know that they use different rotary embeddings.
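To illustrate what "different rotary embeddings" means in practice, here is a small sketch (not code from this PR) contrasting the two common pairings: HF-style rotate_half over two contiguous halves versus the interleaved adjacent-pair rotation used by the original implementation. The function names are illustrative:

```python
# Minimal sketch (not this PR's code) of the two rotary pairings. The
# cos/sin tables must be laid out to match whichever pairing is used.
import torch

def rotate_half(x):
    # HF-style: pair element i with element i + d/2 (two contiguous halves).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotate_every_two(x):
    # Original-implementation style: rotate adjacent pairs (0,1), (2,3), ...
    x_even = x[..., ::2]
    x_odd = x[..., 1::2]
    return torch.stack((-x_odd, x_even), dim=-1).flatten(-2)

def apply_rotary(x, cos, sin, interleaved):
    rotated = rotate_every_two(x) if interleaved else rotate_half(x)
    return x * cos + rotated * sin
```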

@RezaYazdaniAminabadi (Contributor, Author) commented Sep 13, 2023

I added some tests to check the performance and accuracy of this PR using a fork of the llama code-base.
There are some scripts that you can use to run different model configurations, here.
I am seeing about a 2.8x performance speedup (using 8 A100 GPUs) when using ds-inference for the Llama-70B model with the same example used in the repo (I will add more tests to check the performance more extensively):

[2023-09-13 13:23:09,286] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed-Inference config: {'layer_id': 0, 'hidden_size': 8192, 'intermediate_size': 28672, 'heads': 64, 'num_hidden_layers': -1, 'dtype': torch.float16, 'pre_layer_norm': True, 'norm_type': <NormType.RMSNorm: 3>, 'local_rank': -1, 'stochastic_mode': False, 'epsilon': 1e-05, 'mp_size': 8, 'scale_attention': True, 'triangular_masking': True, 'local_attention': False, 'window_size': 1, 'rotary_dim': 128, 'rotate_half': False, 'rotate_every_two': True, 'return_tuple': False, 'mlp_after_attn': True, 'mlp_act_func_type': <ActivationFuncType.GATED_SILU: 4>, 'specialized_mode': False, 'training_mp_size': 1, 'bigscience_bloom': False, 'max_out_tokens': 1024, 'min_out_tokens': 1, 'scale_attn_by_inverse_layer_idx': False, 'enable_qkv_quantization': False, 'use_mup': False, 'return_single_tuple': False, 'set_empty_params': False, 'transposed_mode': False, 'use_triton': False, 'triton_autotune': False, 'num_kv': 8}
Loading extension module transformer_inference...
------------------------------------------------------
Free memory : 35.154175 (GigaBytes)  
Total memory: 79.169678 (GigaBytes)  
Requested memory: 3.875000 (GigaBytes) 
Setting maximum total tokens (input + output) to 1024 
WorkSpace: 0x7f9798000000 
------------------------------------------------------
baseline: generation time 10.940 sec for generating x tokens.
ds-inference: generation time 3.938 sec for generating x tokens.
speeup: 2.778x
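For anyone reproducing a comparison like the one above, a rough timing harness in the spirit of the numbers reported; `generator` and `prompts` are assumed to come from a setup like the one sketched earlier in the thread, not from the fork's benchmark script:

```python
# Rough sketch of a generation-timing comparison; the generation callable
# and prompts are placeholders.
import time
import torch

def time_generation(generate_fn, prompts, warmup=1, iters=3):
    """Average wall-clock seconds per call of a generation callable."""
    for _ in range(warmup):
        generate_fn(prompts)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        generate_fn(prompts)
    torch.cuda.synchronize()
    return (time.time() - start) / iters

# baseline = time_generation(lambda p: generator.text_completion(p, max_gen_len=256), prompts)
# ...inject the DeepSpeed kernels, then measure again...
# ds_time = time_generation(lambda p: generator.text_completion(p, max_gen_len=256), prompts)
# print(f"speedup: {baseline / ds_time:.3f}x")
```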

@mpjlu (Contributor) commented Sep 18, 2023

I am seeing about a 2.8x performance speedup (using 8 A100 GPUs) when using ds-inference for the Llama-70B model ...

Does this PR support the Llama-2-70B model? Is the "Llama-70B model" Llama 1 or Llama 2?

@RezaYazdaniAminabadi (Contributor, Author)

Does this PR support the Llama-2-70B model? Is the "Llama-70B model" Llama 1 or Llama 2?

It supports Llama-2-70B. Of course, it is a Llama-2 model.

@mpjlu (Contributor) commented Sep 18, 2023

It supports Llama-2-70B. Of course, it is a Llama-2 model.

The attention in Llama-2-70B is GQA (a KV-shared architecture), but this PR supports Llama 2 without KV sharing, so the KV-cache memory is the same as for MHA, right?

@RezaYazdaniAminabadi (Contributor, Author)

The attention in Llama-2-70B is GQA (a KV-shared architecture), but this PR supports Llama 2 without KV sharing, so the KV-cache memory is the same as for MHA, right?

Right, that will be added next.
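To make the memory implication concrete, a back-of-the-envelope sketch using the numbers printed in the config above (hidden_size 8192, 64 heads, num_kv 8, 1024 max tokens) plus Llama-2-70B's 80 layers; this is illustrative arithmetic, not DeepSpeed's allocator:

```python
# Back-of-the-envelope sketch: with MHA-style caching (no GQA), all 64 heads
# are cached; with GQA only the 8 KV heads are. Illustrative only.
hidden_size = 8192
num_heads = 64
num_kv_heads = 8                       # 'num_kv': 8 in the config above
head_dim = hidden_size // num_heads    # 128
layers = 80                            # Llama-2-70B
max_tokens = 1024                      # max_out_tokens in the config above
bytes_per_elem = 2                     # fp16

def kv_cache_bytes(kv_heads):
    # factor of 2 for keys and values
    return 2 * layers * max_tokens * kv_heads * head_dim * bytes_per_elem

print(f"MHA-style cache: {kv_cache_bytes(num_heads) / 2**30:.2f} GiB per sequence")
print(f"GQA cache:       {kv_cache_bytes(num_kv_heads) / 2**30:.2f} GiB per sequence")
```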

@RezaYazdaniAminabadi RezaYazdaniAminabadi added this pull request to the merge queue Sep 19, 2023
Merged via the queue into master with commit 468882f Sep 19, 2023
CurryRice233 pushed a commit to CurryRice233/DeepSpeed that referenced this pull request Sep 28, 2023
* origin/master:
  Allow multiple inference engines in single script (deepspeedai#4384)
  adds triton flash attention2 kernel (deepspeedai#4337)
  Fix llama meta tensor loading in AutoTP and kernel injected inference (deepspeedai#3608)
  Fix min torch version (deepspeedai#4375)
  Fix multinode runner to properly append to PDSH_SSH_ARGS_APPEND (deepspeedai#4373)
  add the missing method (deepspeedai#4363)
  Openfold fix (deepspeedai#4368)
  deepspeed4science japanese blog (deepspeedai#4369)
  deepspeed4science chinese blog (deepspeedai#4366)
  Enable workflow dispatch on Torch 1.10 CI tests (deepspeedai#4361)
  Update conda env to have max pydantic version (deepspeedai#4362)
  add deepspeed4science blog link (deepspeedai#4364)
  added check to avoid undefined behavior when the input_id length is greater than max_tokens (deepspeedai#4349)
  Add the policy to run llama model from the official repo (deepspeedai#4313)
  fix deepspeed4science links (deepspeedai#4358)
  DeepSpeed4Science (deepspeedai#4357)
  Support InternLM (deepspeedai#4137)
  Pass base_dir to model files can be loaded for auto-tp/meta-tensor. (deepspeedai#4348)
@ghost commented Oct 27, 2023

@RezaYazdaniAminabadi How is the progress on adding GQA support for the LLaMA2-70B model? We would like to know if any help is needed to expedite it.
