Add the policy to run llama model from the official repo #4313

Merged · 19 commits · Sep 19, 2023
8 changes: 6 additions & 2 deletions csrc/transformer/inference/csrc/pt_binding.cpp
@@ -437,6 +437,7 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
bool rotate_half,
bool rotate_every_two,
int heads,
+ int num_kv,
float norm_factor,
bool triangular,
bool local_attention,
@@ -448,14 +449,14 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
{
unsigned bsz = query_key_value.size(0);
unsigned seq_len = query_key_value.size(1);
- unsigned hidden_dim = query_key_value.size(2) / 3;
+ int k = query_key_value.size(2) / (heads + 2 * (num_kv > 0 ? num_kv : heads));
+ unsigned hidden_dim = heads * k;

bool is_prompt = (seq_len > 1);

if (is_prompt) InferenceContext::Instance().reset_tokens(seq_len);
unsigned soft_len = InferenceContext::Instance().current_tokens();

- int k = hidden_dim / heads;
auto options = at::TensorOptions()
.dtype(query_key_value.options().dtype())
.layout(at::kStrided)
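
The width arithmetic is the heart of this hunk: with grouped-query attention (GQA), the fused QKV tensor holds `heads` query heads but only `num_kv` key/value heads, so the per-head size `k` is `width / (heads + 2 * num_kv)` instead of the old `hidden_dim / heads`. A minimal Python sketch of that arithmetic (illustrative only; the helper and the Llama-2-70B-style numbers are ours, not DeepSpeed code):

def head_size(qkv_width: int, heads: int, num_kv: int) -> int:
    # num_kv <= 0 is the sentinel for plain multi-head attention:
    # every query head then has its own K/V head.
    kv = num_kv if num_kv > 0 else heads
    return qkv_width // (heads + 2 * kv)

# GQA, Llama-2-70B-style shapes: 64 query heads sharing 8 KV heads, head size 128
assert head_size((64 + 2 * 8) * 128, heads=64, num_kv=8) == 128
# Plain MHA (num_kv <= 0): the width reduces to the old 3 * hidden_dim layout
assert head_size(3 * 64 * 128, heads=64, num_kv=-1) == 128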
@@ -486,6 +487,7 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
soft_len,
hidden_dim,
heads,
+ (num_kv > 0 ? num_kv : heads),
rotary_dim,
rotate_half,
rotate_every_two,
@@ -1167,6 +1169,7 @@ at::Tensor ds_linear_layer(at::Tensor& input,
(num_heads * padded_head_size),
num_heads,
-1,
+ -1,
false,
false,
InferenceContext::Instance().GetCurrentStream(),
@@ -1192,6 +1195,7 @@ at::Tensor ds_linear_layer(at::Tensor& input,
input_cont.size(2),
num_heads,
-1,
+ -1,
false,
false,
InferenceContext::Instance().GetCurrentStream(),
27 changes: 19 additions & 8 deletions csrc/transformer/inference/csrc/transform.cu
@@ -26,6 +26,8 @@ __global__ void bias_add_transform_0213(float* output,
int seq_length,
unsigned seq_offset,
int heads,
+ int head_stride,
+ int num_kv,
int rotary_dim,
bool rotate_half,
bool rotate_every_two,
@@ -49,10 +51,10 @@
float4* output_vec =
reinterpret_cast<float4*>(cnt == 0 ? output : (cnt == 1 ? k_cache : v_cache));

- vals_vec += (d0 * d0_stride * (gridDim.z / head_ext));
- vals_vec += (d1 * d1_stride * (gridDim.z / head_ext));
- vals_vec += (cnt * d1_stride);
- vals_vec += (d2 * d2_stride);
+ vals_vec += (d0 * (d1_stride + num_kv * 2 * d2_stride) * seq_length);
+ vals_vec += d1 * (d1_stride + num_kv * 2 * d2_stride);
+ vals_vec += (cnt == 0 ? 0 : d1_stride) + (cnt == 0 ? 0 : (cnt - 1) * num_kv * d2_stride);
+ vals_vec += ((cnt == 0 ? d2 : (d2 / head_stride)) * d2_stride);

output_vec += (d1 * d2_stride);
output_vec += (d0 * d0_out_stride);
@@ -92,6 +94,8 @@ __global__ void bias_add_transform_0213(T* output, // q
unsigned seq_offset,
int all_tokens,
int heads,
+ int head_stride,
+ int num_kv,
int rotary_dim,
bool rotate_half,
bool rotate_every_two,
@@ -124,10 +128,10 @@
float4* output_vec =
reinterpret_cast<float4*>(cnt == 0 ? output : (cnt == 1 ? k_cache : v_cache));

- vals_vec += (d0 * d0_stride * (gridDim.z / head_ext));
- vals_vec += (d1 * d1_stride * (gridDim.z / head_ext));
- vals_vec += (cnt * d1_stride);
- vals_vec += (d2 * d2_stride);
+ vals_vec += (d0 * (d1_stride + num_kv * 2 * d2_stride) * seq_length);
+ vals_vec += (d1 * (d1_stride + num_kv * 2 * d2_stride));
+ vals_vec += (cnt == 0 ? 0 : d1_stride) + (cnt == 0 ? 0 : (cnt - 1) * num_kv * d2_stride);
+ vals_vec += ((cnt == 0 ? d2 : (d2 / head_stride)) * d2_stride);

output_vec += (d1 * d2_stride);
output_vec += (d0 * d0_out_stride);
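
Both the float kernel above and this templated half/bf16 kernel now index the fused buffer identically: a token's row is no longer three equal Q/K/V blocks but one `heads`-wide Q block followed by two `num_kv`-wide K and V blocks. A hypothetical pure-Python mirror of the new pointer arithmetic (the helper and its name are ours, for illustration only; offsets are in float4 units, as in the kernel):

def qkv_src_offset(d0, d1, d2, cnt, seq_length, d1_stride, d2_stride,
                   num_kv, head_stride):
    """Source offset of head d2, token d1, batch d0, with cnt = 0 (Q), 1 (K), 2 (V)."""
    row = d1_stride + num_kv * 2 * d2_stride      # one token's packed Q|K|V row
    off = d0 * row * seq_length + d1 * row        # jump to batch d0, then token d1
    if cnt > 0:                                   # skip the Q block, plus the K block for V
        off += d1_stride + (cnt - 1) * num_kv * d2_stride
    head = d2 if cnt == 0 else d2 // head_stride  # head_stride query heads share one KV head
    return off + head * d2_stride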
@@ -171,6 +175,7 @@ void launch_bias_add_transform_0213<float>(float* output,
int all_tokens,
int hidden_dim,
int heads,
+ int num_kv,
int rotary_dim,
bool rotate_half,
bool rotate_every_two,
@@ -193,6 +198,8 @@ void launch_bias_add_transform_0213<float>(float* output,
seq_length,
seq_offset,
heads,
+ num_kv > 0 ? (heads / num_kv) : 1,
+ num_kv > 0 ? num_kv : heads,
rotary_dim >> 2,
rotate_half,
rotate_every_two,
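
The launcher passes two derived values to the kernel: `head_stride = heads / num_kv`, i.e. how many query heads map onto each KV head, and the effective `num_kv`, falling back to classic MHA (`head_stride = 1`, `num_kv = heads`) when the caller passes `num_kv <= 0`. A small illustrative sketch of the mapping (shapes are made up):

heads, num_kv = 8, 2                                # hypothetical small shapes
head_stride = heads // num_kv if num_kv > 0 else 1  # query heads per KV head
kv_head_of = [q // head_stride for q in range(heads)]
print(kv_head_of)  # [0, 0, 0, 0, 1, 1, 1, 1]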
Expand All @@ -212,6 +219,7 @@ void launch_bias_add_transform_0213(T* output,
int all_tokens,
int hidden_dim,
int heads,
+ int num_kv,
int rotary_dim,
bool rotate_half,
bool rotate_every_two,
@@ -233,6 +241,8 @@ void launch_bias_add_transform_0213(T* output,
seq_offset,
all_tokens,
heads,
+ num_kv > 0 ? (heads / num_kv) : 1,
+ num_kv > 0 ? num_kv : heads,
rotary_dim >> 3,
rotate_half,
rotate_every_two,
@@ -253,6 +263,7 @@ void launch_bias_add_transform_0213(T* output,
int, \
int, \
int, \
+ int, \
bool, \
bool, \
cudaStream_t, \
csrc/transformer/inference/includes/inference_cuda_layers.h
@@ -201,6 +201,7 @@ void launch_bias_add_transform_0213(T* outputs,
int seq_length1,
int hidden_dim,
int heads,
+ int num_kv,
int rotary_dim,
bool rotate_half,
bool rotate_every_two,
69 changes: 69 additions & 0 deletions deepspeed/model_implementations/transformers/ds_llama2.py
@@ -0,0 +1,69 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed import comm as dist
from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference

inference_module = None


class DeepSpeedLlama2Inference(DeepSpeedTransformerInference):
    """Initialize the DeepSpeed Llama2 Transformer Layer."""

    def __init__(self,
                 config,
                 mp_group=None,
                 quantize_scales=None,
                 quantize_groups=1,
                 merge_count=1,
                 mlp_extra_grouping=False):
        super().__init__(config, mp_group, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping)

    def forward(self, *args, **kwargs):

        input = args[0]
        input_mask = None
        # Allocate memory only on first layer forward
        if self.config.layer_id == 0 and self._alloc_workspace:
            self.allocate_workspace(self.config.hidden_size, self.config.heads,
                                    input.size()[1],
                                    input.size()[0], DeepSpeedTransformerInference.layer_id, self.config.mp_size,
                                    self.config.bigscience_bloom,
                                    dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens,
                                    self.config.min_out_tokens)
            self._alloc_workspace = False

        get_present = True

        # We set the prev key/value to None when there is a prompt
        if input.shape[1] > 1:
            self.layer_past = None
        layer_past = self.layer_past

        input_type = input.dtype

        if (self.config.dtype in [torch.float16, torch.bfloat16, torch.int8]) \
                and input.dtype == torch.float:
            target_dtype = torch.half if self.dtype == torch.int8 else self.dtype
            input = input.to(target_dtype)

        with torch.no_grad():
            attention_output, key, value, context_outputtn_ctx, inp_norm = \
                self.attention(input,
                               input_mask,
                               None,
                               layer_past,
                               get_present,
                               None, None, None,
                               self.norm_w,
                               self.norm_b,
                               None)
            self.layer_past = (key, value)
            output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob)

        output = output.to(input_type)
        return output
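
With the container, policy, and kernels above wired together, an unmodified Hugging Face Llama-2 checkpoint can be run through DeepSpeed kernel injection. A rough usage sketch, assuming a standard `deepspeed.init_inference` call (the checkpoint name and exact kwargs are illustrative and may differ across DeepSpeed versions):

import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             torch_dtype=torch.float16)
# replace_with_kernel_inject=True routes Llama-2 layers through the
# DeepSpeedLlama2Inference module defined in this file.
engine = deepspeed.init_inference(model,
                                  tensor_parallel={"tp_size": 1},
                                  dtype=torch.float16,
                                  replace_with_kernel_inject=True)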
6 changes: 5 additions & 1 deletion deepspeed/module_inject/auto_tp.py
@@ -52,7 +52,11 @@ def strided_copy(self,
            src_split = torch.split(src.data, src.shape[outer_dim] // num_splits, dim=outer_dim)
        if (len(src_shape) == 2 and len(dst_shape) == 2):
            if src_shape[outer_dim] == dst_shape[self.out_dim]:
-               dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
+               try:
+                   dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
+               except:
+                   print(dst.shape, src.shape)
+                   exit()
                dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
                if hasattr(src, 'scale'):
                    dst.scale = src.scale
1 change: 1 addition & 0 deletions deepspeed/module_inject/containers/__init__.py
@@ -11,6 +11,7 @@
from .gptneo import DS_GPTNEOContainer, HFGPTNEOLayerPolicy
from .gptneox import DS_GPTNEOXContainer, GPTNEOXLayerPolicy
from .llama import DS_LLAMAContainer, LLAMALayerPolicy
+ from .llama2 import LLAMA2LayerPolicy, DS_LLAMA2Container
from .internlm import DS_InternLMContainer, InternLMLayerPolicy
from .megatron_gpt import DS_MegatronGPTContainer, MegatronLayerPolicy
from .megatron_gpt_moe import DS_MegatronGPTMoEContainer, MegatronMoELayerPolicy
2 changes: 1 addition & 1 deletion deepspeed/module_inject/containers/base.py
@@ -142,7 +142,7 @@ def initialize_tensors(self, enable_training=False):
        self.set_attention(*self.policy.attention(enable_training=enable_training))
        self.set_mlp(*self.policy.mlp(enable_training=enable_training))
        self.set_layernorm(*self.policy.layernorm())
-       self.check_meta_tensor_support()
+       #self.check_meta_tensor_support()

    def convert_to_required_dtype(self):
        # Note: converting tensors to fp16 requires that we do it in-place using self.__dict__ and not make a list/dict copy