From 0ce390e03f32f3b00e5e760cc26e65bad18992d7 Mon Sep 17 00:00:00 2001
From: Ammar Ahmad Awan
Date: Mon, 10 Apr 2023 13:41:21 -0700
Subject: [PATCH 1/3] create branch for public release (#489)

* Merge chatgpt v2 to v3 - finalized (#484)

* [squash] staging chatgpt v1 (#463)

Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
Co-authored-by: yaozhewei
Co-authored-by: Tunji Ruwase

* [partial] formatting fixes
* quantizer fixes
* fix for bert tests
* formatting fixes
* re-enable _param_slice_mappings in z2
* Enable the QKV requires_grad when in training mode (#466)

Co-authored-by: Jeff Rasley

* fixes for attention enable_training flag
* commit to trigger CI
* fix for distil-bert param
* fixes for training context errors
* remove reza's qkv-optimization (#469)

Co-authored-by: Jeff Rasley

* Chatgpt - Fuse lora params at HybridEngine (#472)

Co-authored-by: Jeff Rasley

* add option to enable non-pin mode (#473)
* Chatgpt - fuse lora non pinned case (#474)

* Fix fuse/unfuse lora for Z3 and non-pinned parameter
* unfuse_lora_weight for non-pinned case
* fix the multiple issue for lora parameters
* formatting
* fuse lora only when available

---------

Co-authored-by: Jeff Rasley

* Chatgpt/release inference cache (#475)

* Fix fuse/unfuse lora for Z3 and non-pinned parameter
* unfuse_lora_weight for non-pinned case
* release/retake the inference cache after/before generate
* remove duplicated _fuse_lora function
* fix formatting
* fix hybrid-engine config issue
* update formatting
* Chatgpt - fuse qkv v2 (#478)

Co-authored-by: Jeff Rasley

* ChatGPT: Refactor Hybrid Engine Config (#477)

Co-authored-by: Lok Chand Koppaka

* Inference Workspace Tweaks (#481)

* Safety checks around inference workspace allocation, extra flushing
* Formatting fixes
* Merge fix

* Chatgpt/inference tp (#480)

* Update the merged-QKV weights only if there is difference with the model parameter
* remove the hard-coded size
* always reset qkv params to updated ones after running step
* Add the infernce-tp group and tensor sharding to run inference in model-parallel mode
* optimize the gather/mp-sharding part
* Add hybrid_engine changes
* fix config issue
* Formatting fixes. Reset_qkv duplicate removal.
* fix bloom container.
* fix format.

---------

Co-authored-by: Ammar Ahmad Awan
Co-authored-by: Lok Chand Koppaka

* fix formatting
* more clean-up

---------

Co-authored-by: Jeff Rasley
Co-authored-by: yaozhewei
Co-authored-by: Tunji Ruwase
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Michael Wyatt
Co-authored-by: Lok Chand Koppaka
Co-authored-by: Connor Holmes
Co-authored-by: Ammar Ahmad Awan

* fix a bug on lora-fusion (#487)
* Cholmes/v3 workspace bugfixes (#488)

* Miscellaneous workspace fixes, new config param
* Fix typo

---------

Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
Co-authored-by: Jeff Rasley
Co-authored-by: yaozhewei
Co-authored-by: Tunji Ruwase
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Michael Wyatt
Co-authored-by: Lok Chand Koppaka
Co-authored-by: Connor Holmes
---
 accelerator/cuda_accelerator.py | 1 -
 csrc/includes/context.h | 10 +-
 csrc/includes/cpu_adagrad.h | 4 +-
 csrc/includes/cpu_adam.h | 4 +-
 csrc/quantization/fake_quantizer.cu | 4 +-
 csrc/transformer/dropout_kernels.cu | 6 +-
 csrc/transformer/ds_transformer_cuda.cpp | 47 +-
 .../transformer/inference/csrc/pt_binding.cpp | 346 ++++++++-------
 .../inference/includes/inference_context.h | 63 ++-
 deepspeed/__init__.py | 64 ++-
 deepspeed/inference/config.py | 15 +
 .../transformers/ds_transformer.py | 19 +-
 deepspeed/module_inject/containers/base.py | 302 +++++++++++--
 deepspeed/module_inject/containers/bert.py | 19 +-
 deepspeed/module_inject/containers/bloom.py | 13 +-
 deepspeed/module_inject/containers/clip.py | 9 +-
 .../module_inject/containers/distil_bert.py | 15 +-
 .../containers/features/meta_tensor.py | 8 +-
 deepspeed/module_inject/containers/gpt2.py | 11 +-
 deepspeed/module_inject/containers/gptj.py | 13 +-
 deepspeed/module_inject/containers/gptneo.py | 13 +-
 deepspeed/module_inject/containers/gptneox.py | 11 +-
 .../module_inject/containers/megatron_gpt.py | 11 +-
 deepspeed/module_inject/containers/opt.py | 37 +-
 deepspeed/module_inject/layers.py | 32 +-
 deepspeed/module_inject/policy.py | 33 +-
 deepspeed/module_inject/replace_module.py | 18 +-
 deepspeed/ops/transformer/inference/config.py | 8 +-
 .../ops/transformer/inference/ds_attention.py | 82 +++-
 deepspeed/ops/transformer/inference/ds_mlp.py | 48 ++-
 .../inference/op_binding/gelu_gemm.py | 3 +-
 .../inference/op_binding/linear.py | 3 +-
 .../inference/op_binding/mlp_gemm.py | 10 +-
 .../inference/op_binding/qkv_gemm.py | 5 +-
 .../inference/op_binding/vector_matmul.py | 4 +-
 deepspeed/runtime/config.py | 18 +
 deepspeed/runtime/engine.py | 24 +-
 deepspeed/runtime/fp16/loss_scaler.py | 1 +
 deepspeed/runtime/hybrid_engine.py | 404 ++++++++++++++++++
 deepspeed/runtime/utils.py | 29 ++
 deepspeed/runtime/zero/config.py | 8 +-
 deepspeed/runtime/zero/linear.py | 1 -
 deepspeed/runtime/zero/parameter_offload.py | 6 +-
 .../runtime/zero/partition_parameters.py | 4 +-
 .../zero/partitioned_param_coordinator.py | 18 +-
 deepspeed/runtime/zero/stage3.py | 3 +-
 tests/hybrid_engine/hybrid_engine_config.json | 19 +
 tests/hybrid_engine/hybrid_engine_test.py | 30 ++
 48 files changed, 1448 insertions(+), 408 deletions(-)
 create mode 100644 deepspeed/runtime/hybrid_engine.py
 create mode 100644 tests/hybrid_engine/hybrid_engine_config.json
 create mode 100644 tests/hybrid_engine/hybrid_engine_test.py

diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py
index a613e0d5bef5..36341a3c19b3 100644
--- a/accelerator/cuda_accelerator.py
+++
b/accelerator/cuda_accelerator.py @@ -26,7 +26,6 @@ def __init__(self): # put all valid class name <--> class type mapping into class_dict op_builder_dir = self.op_builder_dir() op_builder_module = importlib.import_module(op_builder_dir) - for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(op_builder_module.__file__)]): # avoid self references if module_name != 'all_ops' and module_name != 'builder': diff --git a/csrc/includes/context.h b/csrc/includes/context.h index 99da84b94e8b..3a9067dc3b9f 100644 --- a/csrc/includes/context.h +++ b/csrc/includes/context.h @@ -44,9 +44,9 @@ inline int DS_GET_BLOCKS(const int N) 1); } -class Context { +class TrainingContext { public: - Context() : _workspace(nullptr), _seed(42), _curr_offset(0) + TrainingContext() : _workspace(nullptr), _seed(42), _curr_offset(0) { curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT); curandSetPseudoRandomGeneratorSeed(_gen, 123); @@ -57,15 +57,15 @@ class Context { } } - virtual ~Context() + virtual ~TrainingContext() { cublasDestroy(_cublasHandle); cudaFree(_workspace); } - static Context& Instance() + static TrainingContext& Instance() { - static Context _ctx; + static TrainingContext _ctx; return _ctx; } diff --git a/csrc/includes/cpu_adagrad.h b/csrc/includes/cpu_adagrad.h index a968a2ba4f16..ba40fcf7b62a 100644 --- a/csrc/includes/cpu_adagrad.h +++ b/csrc/includes/cpu_adagrad.h @@ -39,8 +39,8 @@ class Adagrad_Optimizer { cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float)); cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float)); - _streams[0] = Context::Instance().GetCurrentStream(); - _streams[1] = Context::Instance().GetNewStream(); + _streams[0] = TrainingContext::Instance().GetCurrentStream(); + _streams[1] = TrainingContext::Instance().GetNewStream(); _buf_index = false; #endif } diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h index f28303cbbf45..4648aede93ee 100644 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -54,8 +54,8 @@ class Adam_Optimizer { cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float)); cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float)); - _streams[0] = Context::Instance().GetCurrentStream(); - _streams[1] = Context::Instance().GetNewStream(); + _streams[0] = TrainingContext::Instance().GetCurrentStream(); + _streams[1] = TrainingContext::Instance().GetNewStream(); _buf_index = false; #endif } diff --git a/csrc/quantization/fake_quantizer.cu b/csrc/quantization/fake_quantizer.cu index b7a413423a79..0f53e5235921 100644 --- a/csrc/quantization/fake_quantizer.cu +++ b/csrc/quantization/fake_quantizer.cu @@ -457,7 +457,7 @@ void launch_sr_fake_quantize_kernel(T* vals, dim3 grid_dim(group_num); uint64_t inc = total_count / grid_dim.x / block_dim.x; - std::pair seed = Context::Instance().IncrementOffset(inc); + std::pair seed = TrainingContext::Instance().IncrementOffset(inc); sr_fake_quantize_kernel<<>>( vals, (total_count / group_num) / 4, group_num, num_bits, seed); @@ -1011,7 +1011,7 @@ void launch_sr_fake_quantize_kernel_asym(T* vals, dim3 grid_dim(group_num); uint64_t inc = total_count / grid_dim.x / block_dim.x; - std::pair seed = Context::Instance().IncrementOffset(inc); + std::pair seed = TrainingContext::Instance().IncrementOffset(inc); sr_fake_quantize_kernel<<>>( vals, (total_count / group_num) / 4, group_num, num_bits, seed); diff --git a/csrc/transformer/dropout_kernels.cu b/csrc/transformer/dropout_kernels.cu index c57dc7875805..97cb3592dcdc 100644 --- 
a/csrc/transformer/dropout_kernels.cu +++ b/csrc/transformer/dropout_kernels.cu @@ -278,7 +278,7 @@ void launch_dropout(T* out, grid_dim.x <<= 1; } uint64_t inc = total_count / grid_dim.x / block_dim.x; - std::pair seed = Context::Instance().IncrementOffset(inc); + std::pair seed = TrainingContext::Instance().IncrementOffset(inc); if (bwd) dropout_kernel_bwd<<>>( total_count, ratio, vals, out, mask, seed); @@ -625,7 +625,7 @@ void launch_dropout(T* out, dim3 block_dim = DS_CUDA_NUM_THREADS; uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x; - std::pair seed = Context::Instance().IncrementOffset(inc); + std::pair seed = TrainingContext::Instance().IncrementOffset(inc); dropout_kernel<<>>( total_count, dim, ratio, bias, out, mask, seed); @@ -847,7 +847,7 @@ void launch_dropout(T* out, dim3 block_dim = DS_CUDA_NUM_THREADS; uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x; - std::pair seed = Context::Instance().IncrementOffset(inc); + std::pair seed = TrainingContext::Instance().IncrementOffset(inc); dropout_kernel<<>>( total_count, dim, ratio, input, residual, bias, out, mask, seed); diff --git a/csrc/transformer/ds_transformer_cuda.cpp b/csrc/transformer/ds_transformer_cuda.cpp index c1ec7b58b7b8..80cbd72d09a4 100644 --- a/csrc/transformer/ds_transformer_cuda.cpp +++ b/csrc/transformer/ds_transformer_cuda.cpp @@ -78,8 +78,8 @@ BertTransformerLayer::BertTransformerLayer(unsigned layer_id, _normalize_invertible(normalize_invertible), _gelu_checkpoint(gelu_checkpoint), _stochastic_mode(stochastic_mode), - _stream(Context::Instance().GetCurrentStream()), - _cublasHandle(Context::Instance().GetCublasHandle()), + _stream(TrainingContext::Instance().GetCurrentStream()), + _cublasHandle(TrainingContext::Instance().GetCublasHandle()), _qkv_linear(typename FeedForward::Config(batch_size * seq_length, 3 * hidden_size, hidden_size, @@ -183,7 +183,7 @@ void BertTransformerLayer::Forward(unsigned bsz, if (!_stochastic_mode) cudaStreamSynchronize(_stream); - T* workspace = static_cast(Context::Instance().GetWorkSpace()); + T* workspace = static_cast(TrainingContext::Instance().GetWorkSpace()); size_t small_buf_size = bsz * _seq_length * _hidden_size; T* buf_0 = workspace; T* buf_1 = buf_0 + small_buf_size; @@ -343,7 +343,7 @@ void BertTransformerLayer::Backward(unsigned bsz, if (!_stochastic_mode) cudaStreamSynchronize(_stream); - T* workspace = static_cast(Context::Instance().GetWorkSpace()); + T* workspace = static_cast(TrainingContext::Instance().GetWorkSpace()); size_t small_buf_size = bsz * _seq_length * _hidden_size; T* buf_0 = workspace; T* buf_1 = buf_0 + small_buf_size; @@ -609,25 +609,26 @@ int create_transformer_layer(unsigned layer_id, bool gelu_checkpoint, bool stochastic_mode) { - Context::Instance().SetSeed(seed); - Context::Instance().TestGemmFP16( + TrainingContext::Instance().SetSeed(seed); + TrainingContext::Instance().TestGemmFP16( test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads); - auto layer = std::make_shared>(layer_id, - batch_size, - hidden_dim, - num_heads, - intermediate_size, - init_seq_length, - attn_dropout_ratio, - hidden_dropout_ratio, - layer_norm_eps, - pre_or_postLayerNorm, - Context::Instance().GetGemmAlgos(), - attn_dropout_checkpoint, - normalize_invertible, - gelu_checkpoint, - stochastic_mode); + auto layer = + std::make_shared>(layer_id, + batch_size, + hidden_dim, + num_heads, + intermediate_size, + init_seq_length, + attn_dropout_ratio, + hidden_dropout_ratio, + layer_norm_eps, + pre_or_postLayerNorm, + 
TrainingContext::Instance().GetGemmAlgos(), + attn_dropout_checkpoint, + normalize_invertible, + gelu_checkpoint, + stochastic_mode); s_transformer_layers[layer_id] = layer; @@ -725,7 +726,7 @@ std::vector ds_transformer_forward(unsigned layer_id, layer->IsTrainingMode(), layer->GeluCheckpoint())}, options); - Context::Instance().SetWorkSpace((T*)workspace.data_ptr()); + TrainingContext::Instance().SetWorkSpace((T*)workspace.data_ptr()); auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output); auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input)); @@ -909,7 +910,7 @@ std::vector ds_transformer_backward(unsigned layer_id, layer->IsTrainingMode(), layer->GeluCheckpoint())}, options); - Context::Instance().SetWorkSpace((T*)workspace.data_ptr()); + TrainingContext::Instance().SetWorkSpace((T*)workspace.data_ptr()); auto grad_input = torch::empty_like(input); auto grad_attn_qkvw = torch::empty_like(attn_qkvw); diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index d91649133fc9..3de59e11377a 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -96,7 +96,7 @@ at::Tensor ds_softmax(at::Tensor& attn_scores, head_offset, mask_stride, mp_size, - Context::Instance().GetCurrentStream(async_op)); + InferenceContext::Instance().GetCurrentStream(async_op)); return attn_scores_c; } @@ -110,18 +110,20 @@ void allocate_workspace(unsigned hidden_dim, unsigned mp_size = 1, bool external_cache = false, unsigned rank = 0, - unsigned max_out_tokens = 1024) + unsigned max_out_tokens = 1024, + unsigned min_out_tokens = 1) { - Context::Instance().GenWorkSpace(num_layers, - num_heads, - batch_size, - prompt_length, - hidden_dim, - mp_size, - external_cache, - sizeof(T), - rank, - max_out_tokens); + InferenceContext::Instance().GenWorkSpace(num_layers, + num_heads, + batch_size, + prompt_length, + hidden_dim, + mp_size, + external_cache, + sizeof(T), + rank, + max_out_tokens, + min_out_tokens); } template @@ -132,15 +134,15 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W) .layout(at::kStrided) .device(at::kCUDA) .requires_grad(false); - T* workspace = (T*)Context::Instance().GetWorkSpace(); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); float alpha = 1; float gemm_beta = 0.0; /* // Reallocate memory if we received a new prompt if (!workspace || input.size(1) != 1) { - allocate_workspace(W.size(1), Context::Instance().GetMaxTokenLenght(), Q.size(0), 1, - head_size); workspace = (T*)Context::Instance().GetWorkSpace(); + allocate_workspace(W.size(1), InferenceContext::Instance().GetMaxTokenLenght(), + Q.size(0), 1, head_size); workspace = (T*)InferenceContext::Instance().GetWorkSpace(); } */ @@ -148,7 +150,7 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W) unsigned m = W.size(1); unsigned n = Q.size(1) * Q.size(2); unsigned k = Q.size(0); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), CUBLAS_OP_N, CUBLAS_OP_T, m, @@ -195,8 +197,9 @@ void attention_unfused(at::Tensor& prev_key_cont, auto mask_stride = get_attn_mask_stride(attn_mask); - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream()); + 
cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(), soft_len, seq_len, k, @@ -231,9 +234,9 @@ void attention_unfused(at::Tensor& prev_key_cont, 0, mask_stride, 1, - Context::Instance().GetCurrentStream(false)); + InferenceContext::Instance().GetCurrentStream(false)); alpha = 1.0; - cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), + cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(), k, seq_len, soft_len, @@ -364,10 +367,11 @@ void attention_unfused(T* prev_key_cont, float layer_scale = alibi.sizes().size() > 1 ? std::max(1, layer_id) : 1.0; float alpha = norm_factor * norm_factor / layer_scale; float gemm_beta = 0.0; - T* workspace = (T*)Context::Instance().GetAttentionUnfusedWorkspace(); + T* workspace = (T*)InferenceContext::Instance().GetAttentionUnfusedWorkspace(); - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream()); + cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(), soft_len, seq_len, k, @@ -378,7 +382,7 @@ void attention_unfused(T* prev_key_cont, workspace, CUBLAS_OP_T, CUBLAS_OP_N, - Context::Instance().GetMaxTokenLenght() * k, + InferenceContext::Instance().GetMaxTokenLenght() * k, seq_len * k, seq_len * soft_len, bsz * heads, @@ -400,7 +404,7 @@ void attention_unfused(T* prev_key_cont, soft_len, heads); alpha = 1.0; - cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), + cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(), k, seq_len, soft_len, @@ -411,7 +415,7 @@ void attention_unfused(T* prev_key_cont, (T*)output, CUBLAS_OP_N, CUBLAS_OP_N, - Context::Instance().GetMaxTokenLenght() * k, + InferenceContext::Instance().GetMaxTokenLenght() * k, seq_len * soft_len, seq_len * k, bsz * heads, @@ -422,7 +426,7 @@ void attention_unfused(T* prev_key_cont, #endif } -void reset_cache() { Context::Instance().reset_tokens(); } +void reset_cache() { InferenceContext::Instance().reset_tokens(); } template std::vector ds_softmax_context(at::Tensor& query_key_value, @@ -446,8 +450,8 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, bool is_prompt = (seq_len > 1); - if (is_prompt) Context::Instance().reset_tokens(seq_len); - unsigned soft_len = Context::Instance().current_tokens(); + if (is_prompt) InferenceContext::Instance().reset_tokens(seq_len); + unsigned soft_len = InferenceContext::Instance().current_tokens(); int k = hidden_dim / heads; auto options = at::TensorOptions() @@ -456,16 +460,17 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, .device(at::kCUDA) .requires_grad(false); - T* workspace = (T*)Context::Instance().GetWorkSpace(); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); size_t buf_size = bsz * seq_len * hidden_dim; - auto output = torch::from_blob(workspace + 4 * buf_size, {bsz, seq_len, hidden_dim}, options); + auto output = torch::from_blob(workspace + 3 * buf_size, {bsz, seq_len, hidden_dim}, options); - auto query_cont = workspace + 8 * buf_size; - size_t offset = 16 * (hidden_dim * bsz * Context::Instance().GetMaxTokenLenght()) + - layer_id * 2 * bsz * Context::Instance().GetMaxTokenLenght() * hidden_dim; + auto query_cont = workspace + 4 * buf_size; + size_t offset = + 10 * (hidden_dim * bsz * InferenceContext::Instance().GetMaxTokenLenght()) + + layer_id * 2 * bsz * 
InferenceContext::Instance().GetMaxTokenLenght() * hidden_dim; unsigned all_tokens = soft_len; auto kv_cache = workspace + offset + (hidden_dim / heads) * (is_prompt ? 0 : soft_len - 1); - size_t value_offset = bsz * Context::Instance().GetMaxTokenLenght() * hidden_dim; + size_t value_offset = bsz * InferenceContext::Instance().GetMaxTokenLenght() * hidden_dim; T* temp_buf = (T*)output.data_ptr() + at::numel(output); launch_bias_add_transform_0213((T*)query_cont, @@ -482,9 +487,9 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, rotary_dim, rotate_half, rotate_every_two, - Context::Instance().GetCurrentStream(), + InferenceContext::Instance().GetCurrentStream(), 3, - Context::Instance().GetMaxTokenLenght()); + InferenceContext::Instance().GetMaxTokenLenght()); if (rotary_dim > 0 && rotate_half) launch_apply_rotary_pos_emb(query_cont, kv_cache, @@ -496,8 +501,8 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, bsz, rotate_half, rotate_every_two, - Context::Instance().GetCurrentStream(), - Context::Instance().GetMaxTokenLenght()); + InferenceContext::Instance().GetCurrentStream(), + InferenceContext::Instance().GetMaxTokenLenght()); attention_unfused(workspace + offset, (T*)query_cont, @@ -522,25 +527,26 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, heads, seq_len, output.size(2), - Context::Instance().GetCurrentStream(false), + InferenceContext::Instance().GetCurrentStream(false), 1); - if (layer_id == num_layers - 1) Context::Instance().advance_tokens(); + if (layer_id == num_layers - 1) InferenceContext::Instance().advance_tokens(); auto prev_key = torch::from_blob(workspace + offset, {bsz, heads, all_tokens, k}, - {hidden_dim * Context::Instance().GetMaxTokenLenght(), - k * Context::Instance().GetMaxTokenLenght(), + {hidden_dim * InferenceContext::Instance().GetMaxTokenLenght(), + k * InferenceContext::Instance().GetMaxTokenLenght(), k, 1}, options); - auto prev_value = torch::from_blob(workspace + offset + value_offset, - {bsz, heads, all_tokens, k}, - {hidden_dim * Context::Instance().GetMaxTokenLenght(), - k * Context::Instance().GetMaxTokenLenght(), - k, - 1}, - options); + auto prev_value = + torch::from_blob(workspace + offset + value_offset, + {bsz, heads, all_tokens, k}, + {hidden_dim * InferenceContext::Instance().GetMaxTokenLenght(), + k * InferenceContext::Instance().GetMaxTokenLenght(), + k, + 1}, + options); return {output, prev_key, prev_value}; } @@ -557,7 +563,7 @@ at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias) (T*)bias.data_ptr(), intermediate_size, bsz, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); return input_cont; } @@ -583,14 +589,14 @@ at::Tensor ds_bias_geglu(at::Tensor& activation, at::Tensor& bias) (const float*)bias.data_ptr(), rows, channels, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); } else { launch_fused_bias_geglu((__half*)output.data_ptr(), (const __half*)activation.data_ptr(), (const __half*)bias.data_ptr(), rows, channels, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); } return output; @@ -608,7 +614,7 @@ at::Tensor ds_bias_relu(at::Tensor& input, at::Tensor& bias) (T*)bias.data_ptr(), intermediate_size, bsz, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); return input_cont; } @@ -624,7 +630,7 @@ at::Tensor ds_bias_add(at::Tensor& input, at::Tensor& bias) (T*)bias.data_ptr(), hidden_size, bsz, - 
Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); return input_cont; } @@ -641,7 +647,7 @@ at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& // bsz, // input_cont.size(2), // (bias.size(0) > 1), - // Context::Instance().GetCurrentStream()); + // InferenceContext::Instance().GetCurrentStream()); return input_cont; } @@ -659,7 +665,7 @@ at::Tensor ds_layer_norm(at::Tensor& input, at::Tensor& gamma, at::Tensor& beta, epsilon, rows, elems_per_row, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); } else { launch_fused_ln((float*)output.data_ptr(), (const float*)input.data_ptr(), @@ -668,7 +674,7 @@ at::Tensor ds_layer_norm(at::Tensor& input, at::Tensor& gamma, at::Tensor& beta, epsilon, rows, elems_per_row, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); } return output; @@ -689,7 +695,7 @@ void ds_layer_norm_internal(T* workspace, epsilon, bsz, input.size(2), - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); } /* Currently only used in unit testing */ @@ -714,7 +720,7 @@ at::Tensor ds_layer_norm_residual(at::Tensor& input, epsilon, rows, elems_per_row, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); } else { launch_fused_residual_ln((float*)output.data_ptr(), (const float*)input.data_ptr(), @@ -725,7 +731,7 @@ at::Tensor ds_layer_norm_residual(at::Tensor& input, epsilon, rows, elems_per_row, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); } return output; @@ -755,7 +761,7 @@ std::vector ds_layer_norm_residual_store_pre_ln_res(at::Tensor& inpu epsilon, rows, elems_per_row, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); } else { launch_fused_residual_ln_store_pre_ln_res((float*)norm_output.data_ptr(), (float*)res_output.data_ptr(), @@ -767,7 +773,7 @@ std::vector ds_layer_norm_residual_store_pre_ln_res(at::Tensor& inpu epsilon, rows, elems_per_row, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); } return {norm_output, res_output}; @@ -782,7 +788,7 @@ void quantized_gemm(void* output, int bsz, int hidden_size) { - // T* weight16 = (T*)Context::Instance().GetWorkSpace() + 12 * hidden_size * bsz; + // T* weight16 = (T*)InferenceContext::Instance().GetWorkSpace() + 12 * hidden_size * bsz; auto options = at::TensorOptions() .dtype(at::kHalf) @@ -797,11 +803,11 @@ void quantized_gemm(void* output, weight.size(0), weight.size(1), groups, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); float alpha = (T)1.0; float gemm_beta = (T)0.0; - cublas_gemm_ex(Context::Instance().GetCublasHandle(), + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), CUBLAS_OP_T, CUBLAS_OP_N, weight.size(0), @@ -829,10 +835,11 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, at::Tensor& beta, const float epsilon, bool add_bias, - bool q_int8) + bool q_int8, + bool transposed_mode) { int bsz = input.size(0) * input.size(1); - T* workspace = (T*)Context::Instance().GetWorkSpace(); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); workspace += (3 * bsz * input.size(2)); ds_layer_norm_internal(workspace, input, gamma, beta, epsilon); @@ -843,12 +850,12 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, float alpha = (T)1.0; float gemm_beta = (T)0.0; - 
cublasSetStream(Context::Instance().GetCublasHandle(), - Context::Instance().GetCurrentStream()); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream()); + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N), CUBLAS_OP_N, - weight.size(1), + weight.size(transposed_mode ? 0 : 1), bsz, input.size(2), &alpha, @@ -865,9 +872,9 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, if (add_bias) launch_bias_add((T*)output.data_ptr(), (T*)bias.data_ptr(), - q_int8 ? weight.size(0) : weight.size(1), + (transposed_mode || q_int8) ? weight.size(0) : weight.size(1), bsz, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); return torch::from_blob(workspace, input.sizes(), input.options()); } @@ -884,11 +891,12 @@ std::vector ds_qkv_gemm(at::Tensor& input, bool external_cache, unsigned mp_size, unsigned rank, - bool q_int8) + bool q_int8, + bool transposed_mode) { int bsz = input.size(0) * input.size(1); - T* workspace = (T*)Context::Instance().GetWorkSpace(); - int out_size = q_int8 ? weight.size(0) : weight.size(1); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); + int out_size = (transposed_mode || q_int8) ? weight.size(0) : weight.size(1); auto options = at::TensorOptions() .dtype(input.options().dtype()) @@ -897,8 +905,17 @@ std::vector ds_qkv_gemm(at::Tensor& input, .requires_grad(false); auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options); - auto inp_norm = qkv_unfused_cublas( - output, input, weight, q_scale, bias, gamma, beta, epsilon, add_bias, q_int8); + auto inp_norm = qkv_unfused_cublas(output, + input, + weight, + q_scale, + bias, + gamma, + beta, + epsilon, + add_bias, + q_int8, + transposed_mode); return {output, inp_norm}; } @@ -926,11 +943,11 @@ void quantized_gemm(at::Tensor& output, weight.size(1), groups, merge_count, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); float alpha = (T)1.0; float gemm_beta = (T)0.0; - cublas_gemm_ex(Context::Instance().GetCublasHandle(), + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), CUBLAS_OP_T, CUBLAS_OP_N, weight.size(0), @@ -977,7 +994,7 @@ at::Tensor ds_qkv_gemm_int8(at::Tensor& input, (T*)bias.data_ptr(), weight.size(1), bsz, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); return output; } @@ -988,7 +1005,8 @@ at::Tensor ds_linear_layer(at::Tensor& input, at::Tensor& bias, bool add_bias, bool do_flash_attn, - int num_heads) + int num_heads, + bool transposed_mode) { auto input_cont = input.contiguous(); auto options = at::TensorOptions() @@ -999,17 +1017,18 @@ at::Tensor ds_linear_layer(at::Tensor& input, int head_size = input_cont.size(2) / num_heads; int bsz = input.size(0) * input.size(1); - T* workspace = (T*)Context::Instance().GetWorkSpace(); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options); float alpha = (T)1.0; float gemm_beta = (T)0.0; - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream()); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), + 
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N), CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), + weight.size(transposed_mode ? 0 : 1), bsz, input_cont.size(2), &alpha, @@ -1025,9 +1044,9 @@ at::Tensor ds_linear_layer(at::Tensor& input, if (add_bias) launch_bias_add((T*)output.data_ptr(), (T*)bias.data_ptr(), - weight.size(1), + weight.size(transposed_mode ? 0 : 1), bsz, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); bool add_padding = (head_size % 32 != 0 && head_size < 64) || (head_size % 64 != 0); if (do_flash_attn) { if (add_padding) { @@ -1040,7 +1059,7 @@ at::Tensor ds_linear_layer(at::Tensor& input, 3 * bsz * num_heads, head_size, padded_head_size, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); launch_bias_add_transform_0213( final_output, @@ -1057,7 +1076,7 @@ at::Tensor ds_linear_layer(at::Tensor& input, -1, false, false, - Context::Instance().GetCurrentStream(), + InferenceContext::Instance().GetCurrentStream(), 3, input.size(1)); return at::from_blob(final_output, @@ -1082,7 +1101,7 @@ at::Tensor ds_linear_layer(at::Tensor& input, -1, false, false, - Context::Instance().GetCurrentStream(), + InferenceContext::Instance().GetCurrentStream(), 3, input.size(1)); return at::from_blob( @@ -1100,7 +1119,7 @@ std::vector add_padding(at::Tensor& query, at::Tensor& key, at::Tens { int head_size = query.size(3); int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128); - T* workspace = (T*)Context::Instance().GetWorkSpace(); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); T* key_pad_ptr = workspace + padded_head_size * query.size(0) * query.size(1) * query.size(2); T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * query.size(1) * 128; pad_head_seq(workspace, @@ -1110,7 +1129,7 @@ std::vector add_padding(at::Tensor& query, at::Tensor& key, at::Tens query.size(2), head_size, padded_head_size, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); pad_head_seq(key_pad_ptr, (T*)key.data_ptr(), query.size(0) * query.size(1), @@ -1118,7 +1137,7 @@ std::vector add_padding(at::Tensor& query, at::Tensor& key, at::Tens 128, head_size, padded_head_size, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); pad_head_seq(value_pad_ptr, (T*)value.data_ptr(), query.size(0) * query.size(1), @@ -1126,7 +1145,7 @@ std::vector add_padding(at::Tensor& query, at::Tensor& key, at::Tens 128, head_size, padded_head_size, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); return { at::from_blob(workspace, {query.size(0), query.size(1), query.size(2), padded_head_size}, @@ -1148,7 +1167,7 @@ std::vector padd_add_transform(at::Tensor& query, int key_value_length = add_padding ? 128 : key.size(1); int padded_head_size = add_padding ? (head_size < 32 ? 32 : (head_size < 64 ? 
64 : 128)) : head_size; - T* workspace = (T*)Context::Instance().GetWorkSpace(); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); T* key_pad_ptr = workspace + padded_head_size * query.size(0) * heads * query.size(1); T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * heads * key_value_length; launch_pad_add_transform_0213(workspace, @@ -1159,7 +1178,7 @@ std::vector padd_add_transform(at::Tensor& query, query.size(1), heads, padded_head_size, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); launch_pad_add_transform_0213(key_pad_ptr, (T*)key.data_ptr(), key.size(0), @@ -1168,7 +1187,7 @@ std::vector padd_add_transform(at::Tensor& query, key_value_length, heads, padded_head_size, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); launch_pad_add_transform_0213(value_pad_ptr, (T*)value.data_ptr(), value.size(0), @@ -1177,7 +1196,7 @@ std::vector padd_add_transform(at::Tensor& query, key_value_length, heads, padded_head_size, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); return { at::from_blob( workspace, {query.size(0), heads, query.size(1), padded_head_size}, query.options()), @@ -1210,7 +1229,7 @@ at::Tensor ds_linear_layer_int8(at::Tensor& input, (T*)bias.data_ptr(), weight.size(1), bsz, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); return output; } @@ -1219,7 +1238,8 @@ at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight, bool async_op, at::Tensor& q_scale, - bool q_int8) + bool q_int8, + bool transposed_mode) { auto options = at::TensorOptions() .dtype(input.options().dtype()) @@ -1229,7 +1249,7 @@ at::Tensor ds_vector_matmul(at::Tensor& input, int out_size = q_int8 ? weight.size(0) : weight.size(1); int bsz = input.size(0) * input.size(1); - T* workspace = (T*)Context::Instance().GetWorkSpace(); + T* workspace = (T*)InferenceContext::Instance().GetWorkSpace(); auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options); if (q_int8) { quantized_gemm(output.data_ptr(), @@ -1242,12 +1262,12 @@ at::Tensor ds_vector_matmul(at::Tensor& input, } else { float alpha = (T)1.0; float gemm_beta = (T)0.0; - cublasSetStream(Context::Instance().GetCublasHandle(), - Context::Instance().GetCurrentStream(async_op)); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream(async_op)); + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N), CUBLAS_OP_N, - weight.size(1), + weight.size(transposed_mode ? 
0 : 1), bsz, input.size(2), &alpha, @@ -1300,11 +1320,12 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, at::Tensor& q_scale, at::Tensor& q_scale1, bool q_int8, - ActivationFuncType act_func_type) + ActivationFuncType act_func_type, + bool transposed_mode) { int bsz = input.size(0) * input.size(1); - T* inp_norm = - (T*)Context::Instance().GetWorkSpace() + torch::numel(input) + torch::numel(output); + T* inp_norm = (T*)InferenceContext::Instance().GetWorkSpace() + torch::numel(input) + + torch::numel(output); T* intermediate = inp_norm + torch::numel(input); if (mlp_after_attn) { @@ -1317,7 +1338,7 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, epsilon, bsz, input.size(2), - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); } else { ds_layer_norm_internal(inp_norm, input, gamma, beta, epsilon); } @@ -1327,12 +1348,12 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, } else { float alpha = (T)1.0; float gemm_beta = (T)0.0; - cublasSetStream(Context::Instance().GetCublasHandle(), - Context::Instance().GetCurrentStream()); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream()); + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N), CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), + weight.size(transposed_mode ? 0 : 1), bsz, input.size(2), &alpha, @@ -1349,15 +1370,15 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, if (act_func_type == ActivationFuncType::GELU) { launch_bias_gelu(intermediate, (T*)bias.data_ptr(), - q_int8 ? weight.size(0) : weight.size(1), + (transposed_mode || q_int8) ? weight.size(0) : weight.size(1), bsz, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); } else if (act_func_type == ActivationFuncType::ReLU) { launch_bias_relu(intermediate, (T*)bias.data_ptr(), - q_int8 ? weight.size(0) : weight.size(1), + (transposed_mode || q_int8) ? weight.size(0) : weight.size(1), bsz, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); } if (q_int8) { @@ -1371,14 +1392,14 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, } else { float alpha = (T)1.0; float gemm_beta = (T)0.0; - cublasSetStream(Context::Instance().GetCublasHandle(), - Context::Instance().GetCurrentStream()); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream()); + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N), CUBLAS_OP_N, - weight1.size(1), + weight1.size(transposed_mode ? 0 : 1), bsz, - weight1.size(0), + weight1.size(transposed_mode ? 1 : 0), &alpha, &gemm_beta, (T*)weight1.data_ptr(), @@ -1409,7 +1430,8 @@ std::vector ds_mlp_gemm(at::Tensor& input, at::Tensor& q_scale, at::Tensor& q_scale1, bool q_int8, - int activation_type) + int activation_type, + bool transposed_mode) { auto options = at::TensorOptions() .dtype(input.options().dtype()) @@ -1417,10 +1439,11 @@ std::vector ds_mlp_gemm(at::Tensor& input, .device(at::kCUDA) .requires_grad(false); - int out_size = q_int8 ? 
weight_out.size(0) : weight_out.size(1); - auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input), - {input.size(0), input.size(1), out_size}, - options); + int out_size = (q_int8 || transposed_mode) ? weight_out.size(0) : weight_out.size(1); + auto output = + at::from_blob((T*)InferenceContext::Instance().GetWorkSpace() + torch::numel(input), + {input.size(0), input.size(1), out_size}, + options); int bsz = input.size(0) * input.size(1); auto act_func_type = static_cast(activation_type); @@ -1439,7 +1462,8 @@ std::vector ds_mlp_gemm(at::Tensor& input, q_scale, q_scale1, q_int8, - act_func_type); + act_func_type, + transposed_mode); return {output, res_add}; } @@ -1475,7 +1499,7 @@ std::vector ds_mlp_gemm_int8(at::Tensor& input, (T*)bias.data_ptr(), weight.size(1), bsz, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); return {output, residual_add}; } @@ -1490,7 +1514,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, const float epsilon, bool preLayerNorm, bool q_int8, - bool async_op) + bool async_op, + bool transposed_mode) { auto options = at::TensorOptions() .dtype(input.options().dtype()) @@ -1498,9 +1523,10 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, .device(at::kCUDA) .requires_grad(false); - int intm_dim = q_int8 ? weight.size(0) : weight.size(1); + int intm_dim = (transposed_mode || q_int8) ? weight.size(0) : weight.size(1); - // auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input), + // auto output = at::from_blob((T*)InferenceContext::Instance().GetWorkSpace() + + // torch::numel(input), // {input.size(0), input.size(1), out_size}, // options); // T* intermediate = (T*)input.data_ptr() + torch::numel(input); @@ -1519,10 +1545,10 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, bsz, input.size(2)); } else { - cublasSetStream(Context::Instance().GetCublasHandle(), - Context::Instance().GetCurrentStream()); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, + cublasSetStream(InferenceContext::Instance().GetCublasHandle(), + InferenceContext::Instance().GetCurrentStream()); + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N), CUBLAS_OP_N, intm_dim, bsz, @@ -1542,9 +1568,9 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, (T*)bias.data_ptr(), intm_dim, bsz, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); - int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1); + int out_size = (transposed_mode || q_int8) ? weight_out.size(0) : weight_out.size(1); auto output = at::empty({input.size(0), input.size(1), out_size}, options); if (q_int8) { quantized_gemm(output.data_ptr(), @@ -1555,8 +1581,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, bsz, input.size(2)); } else { - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, + cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(), + (transposed_mode ? 
CUBLAS_OP_T : CUBLAS_OP_N), CUBLAS_OP_N, out_size, bsz, @@ -1572,8 +1598,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, CUBLAS_GEMM_DEFAULT_TENSOR_OP); #endif } - // cudaEventRecord(Context::Instance().GetCompEvent(2), - // Context::Instance().GetCurrentStream(true)); + // cudaEventRecord(InferenceContext::Instance().GetCompEvent(2), + // InferenceContext::Instance().GetCurrentStream(true)); return output; } @@ -1600,7 +1626,7 @@ at::Tensor& residual_add_bias(at::Tensor& hidden_state, hidden_size, mp_size, preln, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); else launch_gptj_residual_add( static_cast(residual.data_ptr()), @@ -1611,7 +1637,7 @@ at::Tensor& residual_add_bias(at::Tensor& hidden_state, hidden_size, bsz, mp_size, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); return residual; } @@ -1641,8 +1667,8 @@ std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, bsz, rotate_half, rotate_every_two, - Context::Instance().GetCurrentStream(), - Context::Instance().GetMaxTokenLenght()); + InferenceContext::Instance().GetCurrentStream(), + InferenceContext::Instance().GetMaxTokenLenght()); else launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(), (__half*)key_cont.data_ptr(), @@ -1654,8 +1680,8 @@ std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, bsz, rotate_half, rotate_every_two, - Context::Instance().GetCurrentStream(), - Context::Instance().GetMaxTokenLenght()); + InferenceContext::Instance().GetCurrentStream(), + InferenceContext::Instance().GetMaxTokenLenght()); return {query_cont, key_cont}; } @@ -1684,7 +1710,7 @@ at::Tensor fused_gemm_gelu_int8(at::Tensor& input, (T*)bias.data_ptr(), weight.size(1), bsz, - Context::Instance().GetCurrentStream()); + InferenceContext::Instance().GetCurrentStream()); return output; } @@ -1693,7 +1719,7 @@ at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& out { int M = moe_res.size(0) * moe_res.size(1); int N = moe_res.size(2); - Context::Instance().SynchComm(); + InferenceContext::Instance().SynchComm(); if (moe_res.scalar_type() == at::kFloat) { launch_moe_res_matmul((float*)moe_res.data_ptr(), (float*)coef.data_ptr(), @@ -1712,6 +1738,10 @@ at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& out return output; } +void ds_release_workspace() { InferenceContext::Instance().release_workspace(); } + +bool ds_retake_workspace() { return InferenceContext::Instance().retake_workspace(); } + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("softmax_fp32", &ds_softmax, "DeepSpeed SoftMax with fp32 (CUDA)"); @@ -1791,4 +1821,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) &allocate_workspace<__half>, "DeepSpeed memory allocation for GPT inference with fp16 (CUDA)"); m.def("reset_cache", &reset_cache, "Reset Cache for generation tasks"); + m.def("release_workspace", &ds_release_workspace, "DeepSpeed Release Workspace"); + m.def("retake_workspace", &ds_retake_workspace, "DeepSpeed Retake Workspace"); } diff --git a/csrc/transformer/inference/includes/inference_context.h b/csrc/transformer/inference/includes/inference_context.h index 99eb03392631..f7bbcad91e2a 100644 --- a/csrc/transformer/inference/includes/inference_context.h +++ b/csrc/transformer/inference/includes/inference_context.h @@ -46,17 +46,20 @@ inline int DS_GET_BLOCKS(const int N) 1); } -class Context { +class InferenceContext { public: - Context() + InferenceContext() : _workspace(nullptr), _seed(42), _curr_offset(0), 
_stream(0), _free_memory_size(0), _num_tokens(1), - _attention_unfused_workspace_offset(0) + _attention_unfused_workspace_offset(0), + _workSpaceSize(0) { + _workSpaceSize = 0; + _workspace = 0; if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) { auto message = std::string("Fail to create cublas handle."); std::cerr << message << std::endl; @@ -71,7 +74,7 @@ class Context { cudaEventCreate(&_comm_event); } - virtual ~Context() + virtual ~InferenceContext() { cublasDestroy(_cublasHandle); cudaFree(_workspace); @@ -81,9 +84,9 @@ class Context { cudaEventDestroy(_comm_event); } - static Context& Instance() + static InferenceContext& Instance() { - static Context _ctx; + static InferenceContext _ctx; return _ctx; } @@ -96,7 +99,8 @@ class Context { const bool& external_cache, const size_t& elem_size, const unsigned& rank, - unsigned max_out_tokens) + unsigned max_out_tokens, + unsigned min_out_tokens) { size_t total_size; if (!_free_memory_size) { cudaMemGetInfo(&_free_memory_size, &total_size); } @@ -107,9 +111,9 @@ class Context { const int padded_head_size = head_size <= 32 ? 32 : (head_size <= 64 ? 64 : 128); const int effective_head_size = (head_size > 128) ? head_size : padded_head_size; - size_t activation_size = 16 * (num_heads * effective_head_size) * batch_size; + size_t activation_size = 10 * (num_heads * effective_head_size) * batch_size; // Other sequence length dimension is added when the final workSpaceSize is calculated - size_t temp_size = batch_size * num_heads * max_out_tokens * 2; + size_t temp_size = batch_size * (num_heads / mp_size) * max_out_tokens; size_t cache_size = num_layers * batch_size * ((num_heads * effective_head_size) / mp_size) * 2; size_t minimal_requirements = @@ -129,25 +133,37 @@ class Context { : (activation_size + temp_size + cache_size))) * _max_seq_len * elem_size; temp_size *= _max_seq_len * elem_size; - if (rank == 0 && !_workspace) + + if (_max_seq_len < min_out_tokens) { + printf( + "Allocatable workspace available (%d tokens) is less than minimum requested " + "workspace (%d tokens)\n", + _max_seq_len, + min_out_tokens); + throw std::runtime_error("Workspace can't be allocated, not enough memory"); + } + + if (!_workspace) { + assert(_workspace == nullptr); + cudaMalloc(&_workspace, workSpaceSize); + } else if (_workSpaceSize < workSpaceSize) { + cudaFree(_workspace); + cudaMalloc(&_workspace, workSpaceSize); + } + if (rank == 0 && (!_workspace || _workSpaceSize < workSpaceSize)) printf( "------------------------------------------------------\n" "Free memory : %f (GigaBytes) \n" "Total memory: %f (GigaBytes) \n" "Requested memory: %f (GigaBytes) \n" "Setting maximum total tokens (input + output) to %lu \n" + "WorkSpace: %p \n" "------------------------------------------------------\n", (float)_free_memory_size / GIGABYTE, (float)total_size / GIGABYTE, (float)workSpaceSize / GIGABYTE, - _max_seq_len); - if (!_workspace) { - assert(_workspace == nullptr); - cudaMalloc(&_workspace, workSpaceSize); - } else if (_workSpaceSize < workSpaceSize) { - cudaFree(_workspace); - cudaMalloc(&_workspace, workSpaceSize); - } + _max_seq_len, + _workspace); if (!_workspace) { printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n", @@ -203,6 +219,17 @@ class Context { return stream; } + void release_workspace() + { + cudaFree(_workspace); + _workspace = nullptr; + } + bool retake_workspace() + { + if (_workspace != nullptr || _workSpaceSize == 0) return true; + cudaMalloc(&_workspace, _workSpaceSize); + return _workspace != nullptr; + } cublasHandle_t 
GetCublasHandle() { return _cublasHandle; } std::pair IncrementOffset(uint64_t offset_inc) diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 5329d83d01ae..12f26b1927af 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -17,6 +17,7 @@ from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER +from .runtime.hybrid_engine import DeepSpeedHybridEngine from .runtime.pipe.engine import PipelineEngine from .inference.engine import InferenceEngine from .inference.config import DeepSpeedInferenceConfig @@ -26,7 +27,7 @@ from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig from .module_inject import replace_transformer_layer, revert_transformer_layer -from .utils import log_dist, OnDevice +from .utils import log_dist, OnDevice, logger from .comm.comm import init_distributed from .runtime import zero @@ -118,31 +119,66 @@ def initialize(args=None, assert model is not None, "deepspeed.initialize requires a model" + # Set config using config_params for backwards compat + if config is None and config_params is not None: + config = config_params + + # Check for deepscale_config for backwards compat + if hasattr(args, "deepscale_config") and args.deepscale_config is not None: + logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************") + if hasattr(args, "deepspeed_config"): + assert (args.deepspeed_config is + None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config" + args.deepspeed_config = args.deepscale_config + args.deepscale_config = None + + # Check that we have only one config passed + if hasattr(args, "deepspeed_config") and args.deepspeed_config is not None: + assert config is None, "Not sure how to proceed, we were given deepspeed configs in the deepspeed arguments and deepspeed.initialize() function call" + config = args.deepspeed_config + assert config != None, "DeepSpeed requires --deepspeed_config to specify configuration file" + if not isinstance(model, PipelineModule): - engine = DeepSpeedEngine(args=args, - model=model, - optimizer=optimizer, - model_parameters=model_parameters, - training_data=training_data, - lr_scheduler=lr_scheduler, - mpu=mpu, - dist_init_required=dist_init_required, - collate_fn=collate_fn, - config=config, - config_params=config_params) + config_class = DeepSpeedConfig(config, mpu) + if config_class.hybrid_engine.enabled: + engine = DeepSpeedHybridEngine(args=args, + model=model, + optimizer=optimizer, + model_parameters=model_parameters, + training_data=training_data, + lr_scheduler=lr_scheduler, + mpu=mpu, + dist_init_required=dist_init_required, + collate_fn=collate_fn, + config=config, + config_class=config_class) + else: + engine = DeepSpeedEngine(args=args, + model=model, + optimizer=optimizer, + model_parameters=model_parameters, + training_data=training_data, + lr_scheduler=lr_scheduler, + mpu=mpu, + dist_init_required=dist_init_required, + collate_fn=collate_fn, + config=config, + config_class=config_class) else: assert mpu is None, "mpu must be None with pipeline parallelism" + mpu = model.mpu() + config_class = DeepSpeedConfig(config, mpu) engine = PipelineEngine(args=args, model=model, optimizer=optimizer, model_parameters=model_parameters, training_data=training_data, lr_scheduler=lr_scheduler, - mpu=model.mpu(), + mpu=mpu, dist_init_required=dist_init_required, collate_fn=collate_fn, config=config, - 
config_params=config_params) + config_class=config_class) return_items = [engine, engine.optimizer, engine.training_dataloader, engine.lr_scheduler] return tuple(return_items) diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py index cf112eb3c571..70a67c062ad2 100644 --- a/deepspeed/inference/config.py +++ b/deepspeed/inference/config.py @@ -197,6 +197,11 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): This can be passed through the json config too. """ + set_empty_params: bool = False + """ + specifying whether the inference-module is created with empty or real Tensor + """ + save_mp_checkpoint_path: str = None """ The path for which we want to save the loaded model with a checkpoint. This @@ -247,6 +252,16 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): to the required token-length required for your use-case. """ + min_out_tokens: int = Field(1, alias="min_tokens") + """ + This argument communicates to the runtime the minimum number of tokens you + expect you will need to generate. This will cause the runtime to error + if it unable to provide this and provide context on the memory pressure + rather than seg-faulting or providing corrupted output. + """ + + transposed_mode: bool = Field(False, alias="transposed_mode") + mp_size: int = Field(1, deprecated=True, new_param="tensor_parallel.tp_size") """ Desired model parallel size, default is 1 meaning no model parallelism. diff --git a/deepspeed/model_implementations/transformers/ds_transformer.py b/deepspeed/model_implementations/transformers/ds_transformer.py index 6c5dd3d45478..6ef838cea741 100644 --- a/deepspeed/model_implementations/transformers/ds_transformer.py +++ b/deepspeed/model_implementations/transformers/ds_transformer.py @@ -65,13 +65,18 @@ def __init__(self, mlp_extra_grouping) device = get_accelerator().current_device_name() # if config.bigscience_bloom else 'cpu' - self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device), - requires_grad=False) - self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device), - requires_grad=False) + if self.config.set_empty_params: + self.norm_w = None + self.norm_b = None + else: + self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device), + requires_grad=False) + self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device), + requires_grad=False) self.layer_past = None self.allocate_workspace = inference_cuda_module.allocate_workspace_fp32 if (not config.fp16) else \ inference_cuda_module.allocate_workspace_fp16 + self._alloc_workspace = True @classmethod def reset_cache(cls): @@ -110,12 +115,14 @@ def forward( input_mask = (input_mask if attn_mask is None else attn_mask) if attention_mask is None else attention_mask # Allocate memory only on first layer forward - if self.config.layer_id == 0: + if self.config.layer_id == 0 and self._alloc_workspace: self.allocate_workspace(self.config.hidden_size, self.config.heads, input.size()[1], input.size()[0], DeepSpeedTransformerInference.layer_id, self.config.mp_size, self.config.bigscience_bloom, - dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens) + dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens, + self.config.min_out_tokens) + self._alloc_workspace = False get_present = (get_present or get_key_value or use_cache) input_mask = input_mask if attention_mask is None else attention_mask diff --git 
a/deepspeed/module_inject/containers/base.py b/deepspeed/module_inject/containers/base.py index ddb1b7afef52..20a664668f87 100644 --- a/deepspeed/module_inject/containers/base.py +++ b/deepspeed/module_inject/containers/base.py @@ -34,14 +34,11 @@ def __init__(self, policy, config, model_config, layer_id, child): self.hidden_size = None self.num_attention_heads = None self.mp_size = self.config.tensor_parallel.tp_size - self.pre_layer_norm = self.policy.pre_attn_norm + self.pre_layer_norm = self.model_config.do_layer_norm_before if \ + hasattr(self.model_config, 'do_layer_norm_before') else self.policy.pre_attn_norm self.fp16 = False self.attn_linear_layer = self.policy.linear_layer self.mlp_linear_layer = self.policy.linear_layer - self.layer_norm_eps = self.model_config.layer_norm_eps if \ - hasattr(self.model_config, 'layer_norm_eps') else (self.model_config.layer_norm_epsilon if \ - hasattr(self.model_config, 'layer_norm_epsilon') else self.model_config.layernorm_epsilon if \ - hasattr(self.model_config, 'layernorm_epsilon') else 1.0e-12) self.return_tuple = self.config.return_tuple self.triangular_masking = True self.local_attention = ((self.model_config.attention_layers[self.layer_id] == "local") if hasattr( @@ -51,6 +48,7 @@ def __init__(self, policy, config, model_config, layer_id, child): self.training_mp_size = self.config.training_mp_size self.bigscience_bloom = False self.max_out_tokens = self.config.max_out_tokens + self.min_out_tokens = self.config.min_out_tokens self.scale_attn_by_inverse_layer_idx = getattr(self.config, "scale_attn_by_inverse_layer_idx", False) self.use_mup = self.policy.use_mup self.return_single_tuple = False @@ -75,6 +73,8 @@ def __init__(self, policy, config, model_config, layer_id, child): self.input_nw = None self.input_nb = None + self.mp_group = None + def create_ds_model_config(self): self.set_hidden_heads(*self.policy.get_hidden_heads()) assert self.num_attention_heads % self.mp_size == 0,\ @@ -84,11 +84,11 @@ def create_ds_model_config(self): self.ds_model_config = DeepSpeedInferenceConfig( hidden_size=self.hidden_size, heads=self.num_attention_heads, - layer_norm_eps=self.layer_norm_eps, + layer_norm_eps=self.layernorm_epsilon, fp16=self.fp16, pre_layer_norm=self.pre_layer_norm, mp_size=self.mp_size, - q_int8=self.quantize, + q_int8=self.quantize if hasattr(self, 'quantize') else False, return_tuple=self.return_tuple, triangular_masking=self.triangular_masking, local_attention=self.local_attention, @@ -99,18 +99,24 @@ def create_ds_model_config(self): training_mp_size=self.training_mp_size, bigscience_bloom=self.bigscience_bloom, max_out_tokens=self.max_out_tokens, + min_out_tokens=self.min_out_tokens, scale_attn_by_inverse_layer_idx=self.scale_attn_by_inverse_layer_idx, use_mup=self.use_mup, return_single_tuple=self.return_single_tuple, - ) + set_empty_params=self.config.set_empty_params, + transposed_mode=self.config.transposed_mode) return self.ds_model_config - def initialize_tensors(self): + def initialize_tensors(self, enable_training=False): # Set the tensors from policy (user module) to container (DS module) - self.set_attention(*self.policy.attention()) + self.set_attention(*self.policy.attention(enable_training=enable_training)) self.set_mlp(*self.policy.mlp()) self.set_layernorm(*self.policy.layernorm()) + self.set_lora_params(self.policy.get_lora_params()) + self.q_k_v = self.policy.get_q_k_v() + if self.q_k_v is not None: + self.set_q_k_v(*self.q_k_v) def convert_to_required_dtype(self, dtype): # Note: converting tensors to fp16 
requires that we do it in-place using self.__dict__ and not make a list/dict copy @@ -138,9 +144,10 @@ def set_quantization_config(self, quantize, quantizer): self.quantize = quantize self.quantizer = quantizer - def set_hidden_heads(self, hidden_size, num_attention_heads): + def set_hidden_heads(self, hidden_size, num_attention_heads, epsilon): self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads + self.layernorm_epsilon = epsilon def set_attention(self, qkvw, qkvb, dense_w, dense_b): self.qkvw = qkvw @@ -148,6 +155,17 @@ def set_attention(self, qkvw, qkvb, dense_w, dense_b): self.dense_w = dense_w self.dense_b = dense_b + def set_lora_params(self, lora_params): + self.lora_params = lora_params + + def set_q_k_v(self, qw, qb, kw, kb, vw, vb): + self.qw = qw + self.qb = qb + self.kw = kw + self.kb = kb + self.vw = vw + self.vb = vb + def set_mlp(self, _h4h_w, _h4h_b, _4hh_w, _4hh_b): self._h4h_w = _h4h_w self._h4h_b = _h4h_b @@ -175,33 +193,148 @@ def mlp_quantization(self): self.module.mlp.inter_w = self.quantizer.quantize(self.module.mlp.inter_w) self.module.mlp.output_w = self.quantizer.quantize(self.module.mlp.output_w) - def apply_tensor_parallelism(self, mp_replace): + def apply_tensor_parallelism(self, mp_replace=None, mp_group=None, tp_size=None): + reversed_dim = False + if mp_replace is None: + from deepspeed.module_inject import ReplaceWithTensorSlicing + mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group, mp_size=tp_size, out_dim=0, in_dim=1) + reversed_dim = True # setup the new Attention module - self.attention_qkv_mp(mp_replace) - self.attention_o_mp(mp_replace) + if self.module.attention.attn_qkvw is None: + self.attention_q_k_v_mp(mp_replace, reversed_dim=reversed_dim) + else: + self.attention_qkv_mp(mp_replace, reversed_dim=reversed_dim) + self.attention_o_mp(mp_replace, reversed_dim=reversed_dim) # setup the new MLP module - self.mlp_inter_mp(mp_replace) - self.mlp_output_mp(mp_replace) + self.mlp_inter_mp(mp_replace, reversed_dim=reversed_dim) + self.mlp_output_mp(mp_replace, reversed_dim=reversed_dim) # Apply weight quantization - self.apply_weight_quantization() - - def attention_qkv_mp(self, mp_replace): - self.module.attention.attn_qkvw = mp_replace.qkv_copy(self.module.attention.attn_qkvw, self.qkvw) - self.module.attention.attn_qkvb = mp_replace.qkv_copy(self.module.attention.attn_qkvb, self.qkvb) - - def attention_o_mp(self, mp_replace): - self.module.attention.attn_ow = mp_replace.copy(self.module.attention.attn_ow, self.dense_w) - self.module.attention.attn_ob = mp_replace.copy(self.module.attention.attn_ob, self.dense_b) - - def mlp_inter_mp(self, mp_replace): - self.module.mlp.inter_w = mp_replace.copy(self.module.mlp.inter_w, self._h4h_w) - self.module.mlp.inter_b = mp_replace.copy(self.module.mlp.inter_b, self._h4h_b) - - def mlp_output_mp(self, mp_replace): - self.module.mlp.output_w = mp_replace.copy(self.module.mlp.output_w, self._4hh_w) - self.module.mlp.output_b = mp_replace.copy(self.module.mlp.output_b, self._4hh_b) + #self.apply_weight_quantization() + + def attention_qkv_mp(self, mp_replace, reversed_dim=False): + self.module.attention.attn_qkvw = mp_replace.qkv_copy(self.module.attention.attn_qkvw, + self.qkvw, + int8=reversed_dim) + self.module.attention.attn_qkvb = mp_replace.qkv_copy(self.module.attention.attn_qkvb, + self.qkvb, + int8=reversed_dim) + + def attention_q_k_v_mp(self, mp_replace, reversed_dim=False): + self.module.attention.attn_qw = mp_replace.copy(self.module.attention.attn_qw[:self.qw.shape[0] // + 
mp_replace.mp_size], + self.qw, + int8=reversed_dim, + allocat_tensor=reversed_dim) + self.module.attention.attn_kw = mp_replace.copy(self.module.attention.attn_kw[:self.qw.shape[0] // + mp_replace.mp_size], + self.kw, + int8=reversed_dim, + allocat_tensor=reversed_dim) + self.module.attention.attn_vw = mp_replace.copy(self.module.attention.attn_vw[:self.qw.shape[0] // + mp_replace.mp_size], + self.vw, + int8=reversed_dim, + allocat_tensor=reversed_dim) + self.module.attention.attn_qb = mp_replace.copy(self.module.attention.attn_qb[:self.qw.shape[0] // + mp_replace.mp_size], + self.qb, + int8=reversed_dim, + allocat_tensor=reversed_dim) + self.module.attention.attn_kb = mp_replace.copy(self.module.attention.attn_kb[:self.qw.shape[0] // + mp_replace.mp_size], + self.kb, + int8=reversed_dim, + allocat_tensor=reversed_dim) + self.module.attention.attn_vb = mp_replace.copy(self.module.attention.attn_vb[:self.qw.shape[0] // + mp_replace.mp_size], + self.vb, + int8=reversed_dim, + allocat_tensor=reversed_dim) + + def attention_o_mp(self, mp_replace, reversed_dim=False): + if reversed_dim: + self.module.attention.attn_ow = mp_replace.copy(self.module.attention.attn_ow[:, :self.dense_w.shape[1] // + mp_replace.mp_size], + self.dense_w, + int8=reversed_dim, + allocat_tensor=reversed_dim) + else: + self.module.attention.attn_ow = mp_replace.copy(self.module.attention.attn_ow, + self.dense_w, + int8=reversed_dim) + self.module.attention.attn_ob = mp_replace.copy(self.module.attention.attn_ob, + self.dense_b, + int8=reversed_dim, + allocat_tensor=reversed_dim) + + def mlp_inter_mp(self, mp_replace, reversed_dim=False): + if reversed_dim: + self.module.mlp.inter_w = mp_replace.copy(self.module.mlp.inter_w[:self._h4h_w.shape[0] // + mp_replace.mp_size], + self._h4h_w, + int8=reversed_dim, + allocat_tensor=reversed_dim) + self.module.mlp.inter_b = mp_replace.copy(self.module.mlp.inter_b[:self._h4h_w.shape[0] // + mp_replace.mp_size], + self._h4h_b, + int8=reversed_dim, + allocat_tensor=reversed_dim) + else: + self.module.mlp.inter_w = mp_replace.copy(self.module.mlp.inter_w, self._h4h_w, int8=reversed_dim) + self.module.mlp.inter_b = mp_replace.copy(self.module.mlp.inter_b, self._h4h_b, int8=reversed_dim) + + def mlp_output_mp(self, mp_replace, reversed_dim=False): + if reversed_dim: + self.module.mlp.output_w = mp_replace.copy(self.module.mlp.output_w[:, :self._4hh_w.shape[1] // + mp_replace.mp_size], + self._4hh_w, + int8=reversed_dim, + allocat_tensor=reversed_dim) + else: + self.module.mlp.output_w = mp_replace.copy(self.module.mlp.output_w, self._4hh_w, int8=reversed_dim) + self.module.mlp.output_b = mp_replace.copy(self.module.mlp.output_b, + self._4hh_b, + int8=reversed_dim, + allocat_tensor=reversed_dim) + + def release_qkv(self): + del self.module.attention.attn_qkvw + del self.module.attention.attn_qkvb + self.module.attention.attn_qkvw = None + self.module.attention.attn_qkvb = None + + qkv_data = [self.module.attention.attn_qw.data, \ + self.module.attention.attn_qb.data, \ + self.module.attention.attn_kw.data, \ + self.module.attention.attn_kb.data, \ + self.module.attention.attn_vw.data, \ + self.module.attention.attn_vb.data] + for data in qkv_data: + del data + + self.module.attention.attn_qw = self.qw + self.module.attention.attn_qb = self.qb + self.module.attention.attn_kw = self.kw + self.module.attention.attn_kb = self.kb + self.module.attention.attn_vw = self.vw + self.module.attention.attn_vb = self.vb + + def release_memory(self): + self.release_qkv() + del 
self.module.attention.attn_ow + del self.module.attention.attn_ob + self.module.attention.attn_ow = self.dense_w + self.module.attention.attn_ob = self.dense_b + del self.module.mlp.inter_w + del self.module.mlp.inter_b + del self.module.mlp.output_w + del self.module.mlp.output_b + self.module.mlp.inter_w = self._h4h_w + self.module.mlp.inter_b = self._h4h_b + self.module.mlp.output_w = self._4hh_w + self.module.mlp.output_b = self._4hh_b def copy_data_to_new_module(self): if self.attn_nw is None: @@ -234,3 +367,106 @@ def transpose_impl(self, data): data = data.reshape(data.shape[-1], data.shape[-2]) data.to(get_accelerator().current_device_name()) return data + + def reset_qkv_experimental(self): + if self.module.attention.attn_qkvw is None: + self.module.attention.attn_qkvw = torch.empty(self.qw.shape[0] * 3, + self.qw.shape[0], + dtype=self.qw.dtype, + device=self.qw.device) + self.module.attention.attn_qkvb = torch.empty(self.qw.shape[0] * 3, + dtype=self.qw.dtype, + device=self.qw.device) + self.module.attention.attn_qkvw.data[:self.qw.shape[0]] = self.qw.data + self.module.attention.attn_qkvb.data[:self.qw.shape[0]] = self.qb.data + self.module.attention.attn_qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kw.data + self.module.attention.attn_qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kb.data + self.module.attention.attn_qkvw.data[2 * self.qw.shape[0]:] = self.vw.data + self.module.attention.attn_qkvb.data[2 * self.qw.shape[0]:] = self.vb.data + + qkv_data = [self.qw.data, \ + self.qb.data, \ + self.kw.data, \ + self.kb.data, \ + self.vw.data, \ + self.vb.data] + + self.qw.data = self.module.attention.attn_qkvw.data[:self.qw.shape[0]] + self.qb.data = self.module.attention.attn_qkvb.data[:self.qw.shape[0]] + self.kw.data = self.module.attention.attn_qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]] + self.kb.data = self.module.attention.attn_qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]] + self.vw.data = self.module.attention.attn_qkvw.data[2 * self.qw.shape[0]:] + self.vb.data = self.module.attention.attn_qkvb.data[2 * self.qw.shape[0]:] + + for data in qkv_data: + del data + + def reset_qkv(self): + self.qkvw.data[:self.qw.shape[0]] = self.qw.data + self.qkvb.data[:self.qw.shape[0]] = self.qb.data + self.qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kw.data + self.qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kb.data + self.qkvw.data[2 * self.qw.shape[0]:] = self.vw.data + self.qkvb.data[2 * self.qw.shape[0]:] = self.vb.data + + qkv_data = [self.qw.data, \ + self.qb.data, \ + self.kw.data, \ + self.kb.data, \ + self.vw.data, \ + self.vb.data] + + self.qw.data = self.qkvw.data[:self.qw.shape[0]] + self.qb.data = self.qkvb.data[:self.qw.shape[0]] + self.kw.data = self.qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]] + self.kb.data = self.qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]] + self.vw.data = self.qkvw.data[2 * self.qw.shape[0]:] + self.vb.data = self.qkvb.data[2 * self.qw.shape[0]:] + + for data in qkv_data: + del data + + def set_params_wo_copy(self, Z3_enabled=False): + self.module.mlp.attn_nw = self.attn_nw + self.module.mlp.attn_nb = self.attn_nb + self.module.norm_w = self.input_nw + self.module.norm_b = self.input_nb + self.module.mlp.inter_w = self._h4h_w + self.module.mlp.inter_b = self._h4h_b + self.module.mlp.output_w = self._4hh_w + self.module.mlp.output_b = self._4hh_b + self.module.attention.attn_ow = self.dense_w + self.module.attention.attn_ob = self.dense_b + if not Z3_enabled or self.q_k_v is None: + 
self.module.attention.attn_qkvw = self.qkvw + self.module.attention.attn_qkvb = self.qkvb + if self.q_k_v is not None: + if Z3_enabled: + self.module.attention.attn_qw = self.qw + self.module.attention.attn_qb = self.qb + self.module.attention.attn_kw = self.kw + self.module.attention.attn_kb = self.kb + self.module.attention.attn_vw = self.vw + self.module.attention.attn_vb = self.vb + else: + self.qw.data = self.qkvw[:self.qw.shape[0], :] + self.qb.data = self.qkvb[:self.qw.shape[0]] + self.kw.data = self.qkvw[self.qw.shape[0]:2 * self.qw.shape[0], :] + self.kb.data = self.qkvb[self.qw.shape[0]:2 * self.qw.shape[0]] + self.vw.data = self.qkvw[self.qw.shape[0] * 2:, :] + self.vb.data = self.qkvb[self.qw.shape[0] * 2:] + + def get_lora_params(self): + return self.lora_params + + def get_all_params(self): + if self.q_k_v is not None: + return [ + self.attn_nw, self.attn_nb, self.input_nw, self.input_nb, self._h4h_w, self._h4h_b, self._4hh_w, + self._4hh_b, self.qw, self.qb, self.kw, self.kb, self.vw, self.vb, self.dense_w, self.dense_b + ] + else: + return [ + self.attn_nw, self.attn_nb, self.input_nw, self.input_nb, self._h4h_w, self._h4h_b, self._4hh_w, + self._4hh_b, self.qkvw, self.qkvb, self.dense_w, self.dense_b + ] diff --git a/deepspeed/module_inject/containers/bert.py b/deepspeed/module_inject/containers/bert.py index 2f74db82d245..f8070655283e 100644 --- a/deepspeed/module_inject/containers/bert.py +++ b/deepspeed/module_inject/containers/bert.py @@ -44,10 +44,18 @@ def __init__(self, client_module, inference=False): HFBertLayerPolicy._orig_layer_class = None def get_hidden_heads(self): + if self.pre_attn_norm: + attention_layernorm = self.client_module.PostAttentionLayerNorm + else: + attention_layernorm = self.client_module.attention.output.LayerNorm return self.client_module.attention.self.query.weight.shape[1], \ - self.client_module.attention.self.num_attention_heads + self.client_module.attention.self.num_attention_heads, \ + attention_layernorm.eps + + def get_q_k_v(self): + return None - def attention(self): + def attention(self, enable_training=False): qw = self.client_module.attention.self.query.weight qb = self.client_module.attention.self.query.bias kw = self.client_module.attention.self.key.weight @@ -55,8 +63,8 @@ def attention(self): vw = self.client_module.attention.self.value.weight vb = self.client_module.attention.self.value.bias - qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False) - qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=False) + qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training) + qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training) return qkvw, \ qkvb, \ @@ -84,3 +92,6 @@ def layernorm(self): attention_layernorm.bias, \ transformer_layernorm.weight, \ transformer_layernorm.bias + + def get_lora_params(self): + return [] diff --git a/deepspeed/module_inject/containers/bloom.py b/deepspeed/module_inject/containers/bloom.py index 26daba515c93..7bcf6943de60 100644 --- a/deepspeed/module_inject/containers/bloom.py +++ b/deepspeed/module_inject/containers/bloom.py @@ -28,7 +28,7 @@ def create_module(self, config=None): self.module.config.scale_attention = self.scale_attention return self.module - def attention_qkv_mp(self, mp_replace): + def attention_qkv_mp(self, mp_replace, reversed_dim=False): self.module.attention.attn_qkvw = mp_replace.copy(self.module.attention.attn_qkvw, self.qkvw) self.module.attention.attn_qkvb = 
mp_replace.copy(self.module.attention.attn_qkvb, self.qkvb) @@ -84,9 +84,13 @@ def __init__(self, client_module, inference=True, use_load_prefix=True, split_qk def get_hidden_heads(self): return self.client_module.self_attention.hidden_size, \ - self.client_module.self_attention.num_heads + self.client_module.self_attention.num_heads, \ + self.client_module.input_layernorm.eps - def attention(self): + def get_q_k_v(self): + return None + + def attention(self, enable_training=False): return self.client_module.self_attention.query_key_value.weight, \ self.client_module.self_attention.query_key_value.bias, \ self.client_module.self_attention.dense.weight, \ @@ -103,3 +107,6 @@ def layernorm(self): self.client_module.post_attention_layernorm.bias, \ self.client_module.input_layernorm.weight, \ self.client_module.input_layernorm.bias + + def get_lora_params(self): + return [] diff --git a/deepspeed/module_inject/containers/clip.py b/deepspeed/module_inject/containers/clip.py index 698103cd69e5..144f1b823a1a 100644 --- a/deepspeed/module_inject/containers/clip.py +++ b/deepspeed/module_inject/containers/clip.py @@ -40,7 +40,11 @@ def __init__(self, client_module, inference=False): def get_hidden_heads(self): return self.client_module.self_attn.q_proj.weight.shape[1], \ - self.client_module.self_attn.num_heads + self.client_module.self_attn.num_heads, \ + self.client_module.layer_norm1.eps + + def get_q_k_v(self): + return None def attention(self): qw = self.client_module.self_attn.q_proj.weight @@ -69,3 +73,6 @@ def layernorm(self): self.client_module.layer_norm2.bias, \ self.client_module.layer_norm1.weight, \ self.client_module.layer_norm1.bias + + def get_lora_params(self): + return [] diff --git a/deepspeed/module_inject/containers/distil_bert.py b/deepspeed/module_inject/containers/distil_bert.py index ea1a7cdba115..792b965399e2 100644 --- a/deepspeed/module_inject/containers/distil_bert.py +++ b/deepspeed/module_inject/containers/distil_bert.py @@ -45,9 +45,13 @@ def __init__(self, client_module, inference=False, preln=False): def get_hidden_heads(self): return self.client_module.attention.q_lin.weight.shape[1], \ - self.client_module.attention.n_heads + self.client_module.attention.n_heads, \ + self.client_module.sa_layer_norm.eps - def attention(self): + def get_q_k_v(self): + return None + + def attention(self, enable_training=False): qw = self.client_module.attention.q_lin.weight qb = self.client_module.attention.q_lin.bias kw = self.client_module.attention.k_lin.weight @@ -55,8 +59,8 @@ def attention(self): vw = self.client_module.attention.v_lin.weight vb = self.client_module.attention.v_lin.bias - qkvw = Parameter(torch.cat((qw, kw, vw), dim=0)) - qkvb = Parameter(torch.cat((qb, kb, vb), dim=0)) + qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training) + qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training) return qkvw, \ qkvb, \ @@ -77,3 +81,6 @@ def layernorm(self): attention_layernorm.bias, \ transformer_layernorm.weight, \ transformer_layernorm.bias + + def get_lora_params(self): + return [] diff --git a/deepspeed/module_inject/containers/features/meta_tensor.py b/deepspeed/module_inject/containers/features/meta_tensor.py index d572573f31d2..7aa507ca2e44 100644 --- a/deepspeed/module_inject/containers/features/meta_tensor.py +++ b/deepspeed/module_inject/containers/features/meta_tensor.py @@ -13,18 +13,18 @@ def __init__(self, **kwargs): self.is_meta = False self.ckpt_load_enabled = True - def initialize_tensors(self): - 
super().initialize_tensors() + def initialize_tensors(self, enable_training=False): + super().initialize_tensors(enable_training=enable_training) self.is_meta = self.qkvw.is_meta - def apply_tensor_parallelism(self, mp_replace): + def apply_tensor_parallelism(self, mp_replace=None, mp_group=None, tp_size=None): if self.is_meta: if self.qkvb is None: self.module.attention.attn_qkvb = None if self.dense_b is None: self.module.attention.attn_ob = None else: - super().apply_tensor_parallelism(mp_replace) + super().apply_tensor_parallelism(mp_replace, mp_group, tp_size) def copy_data_to_new_module(self): if self.is_meta: diff --git a/deepspeed/module_inject/containers/gpt2.py b/deepspeed/module_inject/containers/gpt2.py index 56b2ab3caf35..3f6373897c58 100644 --- a/deepspeed/module_inject/containers/gpt2.py +++ b/deepspeed/module_inject/containers/gpt2.py @@ -37,9 +37,13 @@ def __init__(self, client_module, inference=True): def get_hidden_heads(self): return self.client_module.attn.embed_dim, \ - self.client_module.attn.num_heads + self.client_module.attn.num_heads, \ + self.client_module.ln_1.eps - def attention(self): + def get_q_k_v(self): + return None + + def attention(self, enable_training=False): return self.client_module.attn.c_attn.weight, \ self.client_module.attn.c_attn.bias, \ self.client_module.attn.c_proj.weight, \ @@ -56,3 +60,6 @@ def layernorm(self): self.client_module.ln_2.bias, \ self.client_module.ln_1.weight, \ self.client_module.ln_1.bias + + def get_lora_params(self): + return [] diff --git a/deepspeed/module_inject/containers/gptj.py b/deepspeed/module_inject/containers/gptj.py index 816e38c8efc5..e7883105dde9 100644 --- a/deepspeed/module_inject/containers/gptj.py +++ b/deepspeed/module_inject/containers/gptj.py @@ -71,14 +71,18 @@ def __init__(self, client_module, inference=True): def get_hidden_heads(self): return self.client_module.attn.q_proj.weight.shape[1], \ - self.client_module.attn.num_attention_heads + self.client_module.attn.num_attention_heads, \ + self.client_module.ln_1.eps - def attention(self): + def get_q_k_v(self): + return None + + def attention(self, enable_training=False): qw = self.client_module.attn.q_proj.weight kw = self.client_module.attn.k_proj.weight vw = self.client_module.attn.v_proj.weight - qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False) + qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training) return qkvw, \ None, \ @@ -96,3 +100,6 @@ def layernorm(self): None, \ self.client_module.ln_1.weight, \ self.client_module.ln_1.bias + + def get_lora_params(self): + return [] diff --git a/deepspeed/module_inject/containers/gptneo.py b/deepspeed/module_inject/containers/gptneo.py index 6b880160b3d8..b9261b8c0b3b 100644 --- a/deepspeed/module_inject/containers/gptneo.py +++ b/deepspeed/module_inject/containers/gptneo.py @@ -73,14 +73,18 @@ def __init__(self, client_module, inference=True): def get_hidden_heads(self): return self.client_module.attn.attention.q_proj.weight.shape[1], \ - self.client_module.attn.attention.num_heads + self.client_module.attn.attention.num_heads, \ + self.client_module.ln_1.eps - def attention(self): + def get_q_k_v(self): + return None + + def attention(self, enable_training=False): qw = self.client_module.attn.attention.q_proj.weight kw = self.client_module.attn.attention.k_proj.weight vw = self.client_module.attn.attention.v_proj.weight - qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False) + qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), 
requires_grad=enable_training) return qkvw, \ None, \ @@ -98,3 +102,6 @@ def layernorm(self): self.client_module.ln_2.bias, \ self.client_module.ln_1.weight, \ self.client_module.ln_1.bias + + def get_lora_params(self): + return [] diff --git a/deepspeed/module_inject/containers/gptneox.py b/deepspeed/module_inject/containers/gptneox.py index fdcf4d9420fd..61fd2b14dab0 100644 --- a/deepspeed/module_inject/containers/gptneox.py +++ b/deepspeed/module_inject/containers/gptneox.py @@ -92,9 +92,13 @@ def get_hidden_heads(self): attention = self.client_module.self_attention return self.client_module.attention.query_key_value.weight.shape[1], \ - self.client_module.attention.num_attention_heads + self.client_module.attention.num_attention_heads, \ + self.client_module.input_layernorm.eps - def attention(self): + def get_q_k_v(self): + return None + + def attention(self, enable_training=False): if GPTNEOXLayerPolicy.version == 0: attention = self.client_module.attention else: @@ -116,3 +120,6 @@ def layernorm(self): self.client_module.post_attention_layernorm.bias, \ self.client_module.input_layernorm.weight, \ self.client_module.input_layernorm.bias + + def get_lora_params(self): + return [] diff --git a/deepspeed/module_inject/containers/megatron_gpt.py b/deepspeed/module_inject/containers/megatron_gpt.py index 036bdea5acd8..28b7df5b48d5 100644 --- a/deepspeed/module_inject/containers/megatron_gpt.py +++ b/deepspeed/module_inject/containers/megatron_gpt.py @@ -56,9 +56,13 @@ def __init__(self, client_module, inference=True): def get_hidden_heads(self): return self.client_module.attention.query_key_value.weight.shape[1], \ - self.client_module.attention.num_attention_heads + self.client_module.attention.num_attention_heads, \ + self.client_module.input_layernorm.eps - def attention(self): + def get_q_k_v(self): + return None + + def attention(self, enable_training=False): if self.inference: if MegatronLayerPolicy.version == 0: attention = self.client_module.attention @@ -106,3 +110,6 @@ def layernorm(self): self.client_module.post_attention_layernorm.bias, \ self.client_module.input_layernorm.weight, \ self.client_module.input_layernorm.bias + + def get_lora_params(self): + return [] diff --git a/deepspeed/module_inject/containers/opt.py b/deepspeed/module_inject/containers/opt.py index 57d9e3835148..142b92744103 100644 --- a/deepspeed/module_inject/containers/opt.py +++ b/deepspeed/module_inject/containers/opt.py @@ -12,6 +12,7 @@ from ..policy import transformer_param_names from ..policy import maybe_copy from ..policy import maybe_copy_qkv +from ..policy import maybe_get_lora from deepspeed.utils.types import ActivationFuncType @@ -77,20 +78,26 @@ def __init__(self, client_module, inference=True, use_load_prefix=True): pre_attn_norm=True, use_load_prefix=use_load_prefix) self.client_module = client_module - try: import transformers HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer - if isinstance(TransformerPolicy.hf_model_config, transformers.models.opt.configuration_opt.OPTConfig): - self.pre_attn_norm = TransformerPolicy.hf_model_config.do_layer_norm_before except: HFOPTLayerPolicy._orig_layer_class = None def get_hidden_heads(self): return self.client_module.self_attn.embed_dim, \ - self.client_module.self_attn.num_heads - - def attention(self): + self.client_module.self_attn.num_heads, \ + self.client_module.self_attn_layer_norm.eps + + def get_q_k_v(self): + return self.client_module.self_attn.q_proj.weight, \ + 
self.client_module.self_attn.q_proj.bias, \ + self.client_module.self_attn.k_proj.weight, \ + self.client_module.self_attn.k_proj.bias, \ + self.client_module.self_attn.v_proj.weight, \ + self.client_module.self_attn.v_proj.bias + + def attention(self, enable_training=False): qw = self.client_module.self_attn.q_proj.weight qb = self.client_module.self_attn.q_proj.bias @@ -100,9 +107,8 @@ def attention(self): vw = self.client_module.self_attn.v_proj.weight vb = self.client_module.self_attn.v_proj.bias - qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False) - qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=False) - + qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training) + qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training) return qkvw, \ qkvb, \ self.client_module.self_attn.out_proj.weight, \ @@ -119,3 +125,16 @@ def layernorm(self): self.client_module.final_layer_norm.bias, \ self.client_module.self_attn_layer_norm.weight, \ self.client_module.self_attn_layer_norm.bias + + def get_lora_params(self): + all_lora_params = [] + for p in [ + self.client_module.fc1, \ + self.client_module.fc2, \ + self.client_module.self_attn.q_proj, \ + self.client_module.self_attn.k_proj, \ + self.client_module.self_attn.v_proj, \ + self.client_module.self_attn.out_proj, \ + ]: + all_lora_params.append(maybe_get_lora(p)) + return all_lora_params diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py index 790ce99f05fb..70dd1a3af0e1 100644 --- a/deepspeed/module_inject/layers.py +++ b/deepspeed/module_inject/layers.py @@ -55,22 +55,34 @@ def forward(self, input): class Normalize(nn.Module): - def __init__(self, dim, dtype=torch.float, eps=1e-5): + def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None, bias=None): super(Normalize, self).__init__() - self.norm = nn.LayerNorm(dim, eps=eps).to(dtype).to(get_accelerator().current_device_name()) - self.weight = self.norm.weight - self.bias = self.norm.bias + if weight is not None: + self.weight = weight + self.bias = bias + else: + self.norm = nn.LayerNorm(dim, eps=eps).to(dtype).to(get_accelerator().current_device_name()) + self.weight = self.norm.weight + self.bias = self.norm.bias + + self.eps = eps def forward(self, input): - return self.norm(input) + return nn.functional.layer_norm(input, input.shape[-1:], self.weight, self.bias, eps=self.eps) class EmbeddingLayer(nn.Module): - def __init__(self, weight_shape, dtype=torch.half): + def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None): super(EmbeddingLayer, self).__init__() - self.weight = Parameter( - torch.empty(weight_shape[0], weight_shape[1], dtype=dtype, device=get_accelerator().current_device_name())) + if weight is None: + self.weight = Parameter( + torch.empty(weight_shape[0], + weight_shape[1], + dtype=dtype, + device=get_accelerator().current_device_name())) + else: + self.weight = weight def forward(self, input): return F.embedding(input, self.weight) @@ -81,11 +93,11 @@ class OPTEmbedding(EmbeddingLayer): This module learns positional embeddings up to a fixed maximum size. """ - def __init__(self, weight_shape): + def __init__(self, weight_shape=None, weight=None, bias=None): # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 # and adjust num_embeddings appropriately. 
Other models don't have this hack self.offset = 2 - super().__init__(weight_shape) + super().__init__(weight_shape, weight=weight) def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0): """`input_ids_shape` is expected to be [bsz x seqlen].""" diff --git a/deepspeed/module_inject/policy.py b/deepspeed/module_inject/policy.py index 095dbb7af9c8..87b34e5aab5a 100644 --- a/deepspeed/module_inject/policy.py +++ b/deepspeed/module_inject/policy.py @@ -72,7 +72,7 @@ def __init__( self.split_qkv = split_qkv @abstractmethod - def attention(self): + def attention(self, enable_training=False): """ Returns attention qkv and dense parameters weight: (3*hidden, hidden) and (hidden, hidden) @@ -80,6 +80,13 @@ def attention(self): """ raise NotImplementedError + @abstractmethod + def get_q_k_v(self): + """ + return all q,k,v parameters without merging them together + """ + raise NotImplementedError + @abstractmethod def get_hidden_heads(self): """ @@ -105,6 +112,14 @@ def layernorm(self): """ raise NotImplementedError + @abstractmethod + def get_lora_params(self): + """ + Returns lora parameters used in transformer layer + + """ + raise NotImplementedError + # TODO (lekurile): This function exists in base container as well, consolidate as some point def transpose(data): @@ -189,3 +204,19 @@ def maybe_copy_qkv(module, sd, weight_quantizer, mp_replace, dst_name, src_names dst = mp_replace.copy(dst, weight_quantizer.quantize(qkv_data.to(get_accelerator().device_name()) if weight_quantizer.q_int8 else \ transpose(qkv_data)), int8=weight_quantizer.q_int8) setattr(module, dst_name, dst) + + +def pack_lora_weights(p): + return [ + p.lora_right_weight, \ + p.lora_left_weight, \ + p.lora_scaling + ] + + +def maybe_get_lora(p): + if hasattr(p, 'lora_right_weight'): + lora_param = pack_lora_weights(p) + else: + lora_param = [] + return lora_param diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index d8c1a501406c..b6f20845dda0 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -87,10 +87,12 @@ def qkv_copy(self, dst, src, int8=False): dst.scale = src.scale return dst - def copy(self, dst, src, int8=False): + def copy(self, dst, src, int8=False, allocat_tensor=False): if src is None: return src assert not dst.data.is_meta # the torch.Tensor.copy_ method used below will silently fail on meta tensors + if allocat_tensor: + dst = torch.empty_like(dst) outer_dim = 0 if int8 else 1 inner_dim = 1 if int8 else 0 src_shape = src.shape @@ -102,21 +104,21 @@ def copy(self, dst, src, int8=False): else: if src_shape[inner_dim] != dst_shape[self.in_dim]: self.merge_assert(src_shape[inner_dim], dst_shape[self.in_dim]) - weight_split = torch.split(src, dst_shape[self.in_dim], dim=inner_dim)[self.gpu_index].contiguous() + dst.data.copy_(src[:, self.gpu_index * dst_shape[self.in_dim]: (self.gpu_index + 1) * dst_shape[self.in_dim]] if inner_dim == 1 else \ + src[self.gpu_index * dst_shape[self.in_dim]: (self.gpu_index + 1) * dst_shape[self.in_dim], :]) else: self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim]) - weight_split = torch.split(src.data, dst_shape[self.out_dim], - dim=outer_dim)[self.gpu_index].contiguous() - dst = dst.reshape(-1).data.copy_(weight_split.reshape(-1)).reshape(weight_split.shape) + dst.data.copy_(src[:, self.gpu_index * dst_shape[self.out_dim]: (self.gpu_index + 1) * dst_shape[self.out_dim]] if outer_dim == 1 else \ + src[self.gpu_index * dst_shape[self.out_dim]: 
(self.gpu_index + 1) * dst_shape[self.out_dim], :]) else: if src_shape[0] == dst_shape[0]: - dst.data.copy_(src) + dst = src else: - bias_split = torch.split(src.data, dst_shape[-1])[self.gpu_index].contiguous() - dst.data.copy_(bias_split) + dst.data.copy_(src[self.gpu_index * dst_shape[-1]:(self.gpu_index + 1) * dst_shape[-1]]) dst = torch.nn.parameter.Parameter(dst, requires_grad=False) if hasattr(src, 'scale'): dst.scale = src.scale + return dst diff --git a/deepspeed/ops/transformer/inference/config.py b/deepspeed/ops/transformer/inference/config.py index 8015d5f0c814..549a03a70f19 100644 --- a/deepspeed/ops/transformer/inference/config.py +++ b/deepspeed/ops/transformer/inference/config.py @@ -69,10 +69,13 @@ def __init__(self, training_mp_size=1, bigscience_bloom=False, max_out_tokens=1024, + min_out_tokens=1, enable_qkv_quantization=False, use_mup=False, scale_attn_by_inverse_layer_idx=False, - return_single_tuple=False): + return_single_tuple=False, + set_empty_params=False, + transposed_mode=False): super(DeepSpeedInferenceConfig, self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads, num_hidden_layers) @@ -97,10 +100,13 @@ def __init__(self, self.training_mp_size = training_mp_size self.bigscience_bloom = bigscience_bloom self.max_out_tokens = max_out_tokens + self.min_out_tokens = min_out_tokens self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx self.enable_qkv_quantization = enable_qkv_quantization self.use_mup = use_mup self.return_single_tuple = return_single_tuple + self.set_empty_params = set_empty_params + self.transposed_mode = transposed_mode @classmethod def from_dict(cls, json_object): diff --git a/deepspeed/ops/transformer/inference/ds_attention.py b/deepspeed/ops/transformer/inference/ds_attention.py index 26f947275653..46c36d337428 100644 --- a/deepspeed/ops/transformer/inference/ds_attention.py +++ b/deepspeed/ops/transformer/inference/ds_attention.py @@ -15,6 +15,7 @@ class DeepSpeedSelfAttention(nn.Module): num_layers = 0 + _qkv_buffers = [] def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1): super(DeepSpeedSelfAttention, self).__init__() @@ -24,23 +25,35 @@ def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count self.config.layer_id = DeepSpeedSelfAttention.num_layers DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1 device = get_accelerator().current_device_name() #if config.bigscience_bloom else 'cpu' - qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3 - self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size, - qkv_size_per_partition, - dtype=data_type, - device=device), - requires_grad=False) - self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, dtype=data_type_fp, device=device), - requires_grad=False) - out_size_per_partition = self.config.hidden_size // self.config.mp_size - self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition, - self.config.hidden_size, - dtype=data_type, - device=device), - requires_grad=False) - - self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device), - requires_grad=False) + if self.config.set_empty_params: + self.attn_qw = None + self.attn_qb = None + self.attn_kw = None + self.attn_kb = None + self.attn_vw = None + self.attn_vb = None + self.attn_qkvw = None + self.attn_qkvb = None + self.attn_ow = None + self.attn_ob = None + else: + qkv_size_per_partition = 
(self.config.hidden_size // self.config.mp_size) * 3 + self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size, + qkv_size_per_partition, + dtype=data_type, + device=device), + requires_grad=False) + self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, dtype=data_type_fp, device=device), + requires_grad=False) + out_size_per_partition = self.config.hidden_size // self.config.mp_size + self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition, + self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) + + self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device), + requires_grad=False) self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size @@ -65,6 +78,14 @@ def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count self.score_context_func = SoftmaxContextOp(config) self.linear_func = LinearOp(config) self.vector_matmul_func = VectorMatMulOp(config) + if len(DeepSpeedSelfAttention._qkv_buffers) == 0: + DeepSpeedSelfAttention._qkv_buffers = [ + torch.empty(self.hidden_size_per_partition * 3, + self.config.hidden_size, + dtype=data_type_fp, + device=device), + torch.empty(self.hidden_size_per_partition * 3, dtype=data_type_fp, device=device) + ] def compute_attention(self, qkv_out, input_mask, layer_past, alibi): if isinstance(qkv_out, list): @@ -89,6 +110,18 @@ def compute_attention(self, qkv_out, input_mask, layer_past, alibi): context_layer, key_layer, value_layer = attn_key_value return context_layer, key_layer, value_layer + def _merge_qkv(self): + qvkw = DeepSpeedSelfAttention._qkv_buffers[0] + qvkw[:self.hidden_size_per_partition, :] = self.attn_qw + qvkw[self.hidden_size_per_partition:2 * self.hidden_size_per_partition, :] = self.attn_kw + qvkw[2 * self.hidden_size_per_partition:, :] = self.attn_vw + if self.attn_qb is not None: + qvkb = DeepSpeedSelfAttention._qkv_buffers[1] + qvkb[:self.hidden_size_per_partition] = self.attn_qb + qvkb[self.hidden_size_per_partition:2 * self.hidden_size_per_partition] = self.attn_kb + qvkb[2 * self.hidden_size_per_partition:] = self.attn_vb + return DeepSpeedSelfAttention._qkv_buffers + def forward(self, input, input_mask, @@ -101,30 +134,33 @@ def forward(self, norm_w=None, norm_b=None, alibi=None): + if self.attn_qkvw is None: + self._attn_qkvw, self._attn_qkvb = self._merge_qkv() + else: + self._attn_qkvw = self.attn_qkvw + self._attn_qkvb = self.attn_qkvb if not self.config.pre_layer_norm: qkv_out = self.linear_func(input=input, - weight=self.attn_qkvw, - bias=self.attn_qkvb, + weight=self._attn_qkvw, + bias=self._attn_qkvb, add_bias=self.attn_qkvb is not None, do_flash_attn=False, num_heads=self.num_attention_heads_per_partition, num_layers=DeepSpeedSelfAttention.num_layers) else: qkv_out = self.qkv_func(input=input, - weight=self.attn_qkvw, - bias=(self.attn_qkvb if self.attn_qkvb is not None else norm_b), + weight=self._attn_qkvw, + bias=(self._attn_qkvb if self._attn_qkvb is not None else norm_b), gamma=norm_w, beta=norm_b, add_bias=(self.attn_qkvb is not None), num_layers=DeepSpeedSelfAttention.num_layers, num_heads=self.num_attention_heads_per_partition) - context_layer, key_layer, value_layer = self.compute_attention(qkv_out=qkv_out, input_mask=input_mask, layer_past=layer_past, alibi=alibi) - output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow) inp_norm = qkv_out[-1] diff --git 
a/deepspeed/ops/transformer/inference/ds_mlp.py b/deepspeed/ops/transformer/inference/ds_mlp.py index 43e3449c5886..a4375178347a 100644 --- a/deepspeed/ops/transformer/inference/ds_mlp.py +++ b/deepspeed/ops/transformer/inference/ds_mlp.py @@ -20,25 +20,33 @@ def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float data_type_fp = torch.half if config.fp16 else torch.float device = get_accelerator().current_device_name() - self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device), - requires_grad=False) - self.attn_nb = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device), - requires_grad=False) - intm_size_per_partition = self.config.intermediate_size // self.config.mp_size - self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size, - intm_size_per_partition, - dtype=data_type, - device=device), - requires_grad=False) - self.inter_b = nn.Parameter(torch.empty(intm_size_per_partition, dtype=data_type_fp, device=device), - requires_grad=False) - self.output_w = nn.Parameter(torch.empty(intm_size_per_partition, - self.config.hidden_size, - dtype=data_type, - device=device), - requires_grad=False) - self.output_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device), - requires_grad=False) + if self.config.set_empty_params: + self.attn_nw = None + self.attn_nb = None + self.inter_w = None + self.inter_b = None + self.output_w = None + self.output_b = None + else: + self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device), + requires_grad=False) + self.attn_nb = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device), + requires_grad=False) + intm_size_per_partition = self.config.intermediate_size // self.config.mp_size + self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size, + intm_size_per_partition, + dtype=data_type, + device=device), + requires_grad=False) + self.inter_b = nn.Parameter(torch.empty(intm_size_per_partition, dtype=data_type_fp, device=device), + requires_grad=False) + self.output_w = nn.Parameter(torch.empty(intm_size_per_partition, + self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) + self.output_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device), + requires_grad=False) # used for quantization self.q_scales = q_scales @@ -74,8 +82,6 @@ def forward(self, input, residual, residual_norm, bias): final_bias=self.output_b, add_bias=bias is not None, residual_add=residual_add) - if self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1: dist.all_reduce(residual, group=self.mp_group) - return residual diff --git a/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py b/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py index 4da22c0aa610..89ef0b517c49 100644 --- a/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py +++ b/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py @@ -24,5 +24,6 @@ def forward(self, weight_out: torch.Tensor, async_op: bool = False): output = self.fused_gemm_gelu(input, weight, weight.scale, bias, weight_out, weight_out.scale, - self.config.epsilon, self.config.pre_layer_norm, self.config.q_int8, async_op) + self.config.epsilon, self.config.pre_layer_norm, self.config.q_int8, async_op, + self.config.transposed_mode) return output diff 
--git a/deepspeed/ops/transformer/inference/op_binding/linear.py b/deepspeed/ops/transformer/inference/op_binding/linear.py index d5f095661519..9178c5f1fc5b 100644 --- a/deepspeed/ops/transformer/inference/op_binding/linear.py +++ b/deepspeed/ops/transformer/inference/op_binding/linear.py @@ -26,5 +26,6 @@ def forward(self, num_heads: int, external_cache: bool = None, num_layers: int = None): - qkv_out = self.linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads) + qkv_out = self.linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads, + self.config.transposed_mode) return qkv_out diff --git a/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py b/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py index 2e9c89daf03b..e7ca40219c34 100644 --- a/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py +++ b/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py @@ -20,8 +20,10 @@ def __init__(self, config: DeepSpeedInferenceConfig): def forward(self, input: torch.Tensor, residual: torch.Tensor, input_bias: torch.Tensor, weight_interm: torch.Tensor, weight_out: torch.Tensor, bias: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor): - output, residual_add = self.mlp_gemm_func(input, residual, input_bias, weight_interm, weight_out, bias, gamma, - beta, self.config.epsilon, self.config.pre_layer_norm, - self.config.mlp_after_attn, weight_interm.scale, weight_out.scale, - self.config.q_int8, self.config.mlp_act_func_type) + output, residual_add = self.mlp_gemm_func( + input, residual, input_bias, weight_interm, weight_out, bias, gamma, beta, self.config.epsilon, + self.config.pre_layer_norm, self.config.mlp_after_attn, + weight_interm.scale if hasattr(weight_interm, 'scale') else torch.empty(1), + weight_out.scale if hasattr(weight_out, 'scale') else torch.empty(1), self.config.q_int8, + self.config.mlp_act_func_type, self.config.transposed_mode) return output, residual_add diff --git a/deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py b/deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py index 7d6b646b0a8e..6b338b9041d9 100644 --- a/deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py +++ b/deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py @@ -28,10 +28,11 @@ def forward(self, num_layers: int, num_heads: int = None, max_out_tokens: int = None): - q_scale = weight.scale + q_scale = weight.scale if hasattr(weight, 'scale') else torch.empty(1) external_cache = self.config.bigscience_bloom rank = dist.get_rank() if dist.is_initialized() else 0 q_int8 = self.config.q_int8 output = self.qkv_gemm_func(input, weight, q_scale, bias, gamma, beta, self.config.epsilon, add_bias, - num_layers, external_cache, self.config.mp_size, rank, q_int8) + num_layers, external_cache, self.config.mp_size, rank, q_int8, + self.config.transposed_mode) return output diff --git a/deepspeed/ops/transformer/inference/op_binding/vector_matmul.py b/deepspeed/ops/transformer/inference/op_binding/vector_matmul.py index 631696c8ab57..f916020baa9e 100644 --- a/deepspeed/ops/transformer/inference/op_binding/vector_matmul.py +++ b/deepspeed/ops/transformer/inference/op_binding/vector_matmul.py @@ -18,7 +18,7 @@ def __init__(self, config: DeepSpeedInferenceConfig): self.vector_matmul_func = self.inference_cuda_module.vector_matmul_fp32 def forward(self, input: torch.Tensor, weight: torch.Tensor, async_op: bool = False): - q_scale = weight.scale + q_scale = weight.scale if hasattr(weight, 'scale') else torch.empty(1) q_int8 = self.config.q_int8 - output = 
self.vector_matmul_func(input, weight, async_op, q_scale, q_int8) + output = self.vector_matmul_func(input, weight, async_op, q_scale, q_int8, self.config.transposed_mode) return output diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 4d1d4883bd16..3c202a9acd07 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -31,6 +31,7 @@ from ..monitor.config import get_monitor_config from deepspeed import comm as dist +from deepspeed.runtime.config_utils import DeepSpeedConfigModel from ..git_version_info import version as __version__ from ..utils import logger @@ -514,6 +515,21 @@ def get_memory_breakdown(param_dict): return get_scalar_param(param_dict, MEMORY_BREAKDOWN, MEMORY_BREAKDOWN_DEFAULT) +class HybridEngineConfig(DeepSpeedConfigModel): + enabled: bool = False + max_out_tokens: int = 512 + inference_tp_size: int = 1 + release_inference_cache: bool = False + pin_parameters: bool = True + tp_gather_partition_size: int = 8 + + +def get_hybrid_engine_config(param_dict): + hybrid_engine_config_dict = param_dict.get("hybrid_engine", {}) + hybrid_engine_config = HybridEngineConfig(**hybrid_engine_config_dict) + return hybrid_engine_config + + def get_eigenvalue_config(param_dict): if get_quantize_enabled(param_dict): param_dict = param_dict[QUANTIZE_TRAINING] @@ -816,6 +832,8 @@ def _initialize_params(self, param_dict): self.eigenvalue_layer_num, ) = get_eigenvalue_config(param_dict) + self.hybrid_engine = get_hybrid_engine_config(param_dict) + self.sparse_attention = get_sparse_attention(param_dict) self.pipeline = get_pipeline_config(param_dict) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index cc495636fe50..e953938c06a4 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -33,7 +33,7 @@ from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer from deepspeed.runtime.bf16_optimizer import BF16_Optimizer -from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \ +from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \ ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \ TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER @@ -195,7 +195,7 @@ def __init__( dist_init_required=None, collate_fn=None, config=None, - config_params=None, + config_class=None, dont_change_device=False, ): super(DeepSpeedEngine, self).__init__() @@ -213,6 +213,7 @@ def __init__( self.gradient_average = True self.warn_unscaled_loss = True self.config = config + self._config = config_class self.loaded_checkpoint_mp_world_size = None self.loaded_checkpoint_dp_world_size = None self.enable_backward_allreduce = True @@ -242,10 +243,6 @@ def __init__( # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict self.param_names = {param: name for name, param in model.named_parameters()} - # Set config using config_params for backwards compat - if self.config is None and config_params is not None: - self.config = config_params - from deepspeed.comm import supported_torch_version # This supported_torch_version check is for torch1.2 compatibility only if supported_torch_version: @@ -949,19 +946,8 @@ def _configure_with_arguments(self, args, mpu): if hasattr(args, 'local_rank'): args.local_rank = self.local_rank - if self.config is None: - self.config = (args.deepspeed_config if hasattr(args, "deepspeed_config") else None) - self._config = DeepSpeedConfig(self.config, mpu) - # 
Validate command line arguments def _do_args_sanity_check(self, args): - if hasattr(args, "deepscale_config") and args.deepscale_config is not None: - logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************") - if hasattr(args, "deepspeed_config"): - assert (args.deepspeed_config is - None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config" - args.deepspeed_config = args.deepscale_config - assert "LOCAL_RANK" in os.environ or "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ, "DeepSpeed requires the LOCAL_RANK environment " \ "variable, it is set by the deepspeed launcher, deepspeed.init_distributed, or the torch's launcher. If using a " \ "different launcher please ensure LOCAL_RANK is set prior to initializing deepspeed." @@ -975,10 +961,6 @@ def _do_args_sanity_check(self, args): env_local_rank == args.local_rank ), f"Mismatch in local rank setting, args.local_rank={args.local_rank} but env['LOCAL_RANK']={env_local_rank}." - if self.config is None: - assert (hasattr(args, "deepspeed_config") and args.deepspeed_config - is not None), "DeepSpeed requires --deepspeed_config to specify configuration file" - def _is_supported_optimizer(self, optimizer_name): return (optimizer_name in DEEPSPEED_OPTIMIZERS or getattr(torch.optim, optimizer_name, None) is not None) diff --git a/deepspeed/runtime/fp16/loss_scaler.py b/deepspeed/runtime/fp16/loss_scaler.py index 4d75af4b6314..e12ee92fdf98 100755 --- a/deepspeed/runtime/fp16/loss_scaler.py +++ b/deepspeed/runtime/fp16/loss_scaler.py @@ -60,6 +60,7 @@ def update_scale(self, overflow): def backward(self, loss, retain_graph=False): scaled_loss = loss * self.loss_scale scaled_loss.backward(retain_graph=retain_graph) + # print(f'LossScalerBackward: {scaled_loss=}') class LossScaler(LossScalerBase): diff --git a/deepspeed/runtime/hybrid_engine.py b/deepspeed/runtime/hybrid_engine.py new file mode 100644 index 000000000000..c8e36b464fdc --- /dev/null +++ b/deepspeed/runtime/hybrid_engine.py @@ -0,0 +1,404 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from deepspeed.inference.config import DeepSpeedInferenceConfig +from deepspeed.module_inject.replace_policy import replace_policies +from deepspeed.module_inject.utils import policy_to_ds_container +from .engine import DeepSpeedEngine +from .utils import TLinear, get_inactive_params +from deepspeed.runtime.zero import GatheredParameters +import time +import gc + +from deepspeed import comm as dist +from deepspeed.accelerator import get_accelerator +from torch import nn +from deepspeed.utils import logger + +from deepspeed.ops.op_builder import InferenceBuilder + +from deepspeed.module_inject.layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding +try: + import transformers + OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding +except: + OPTLearnedPositionalEmbedding = None +inference_cuda_module = None + + +class DeepSpeedHybridEngine(DeepSpeedEngine): + r"""DeepSpeed engine for training and inference.""" + inference_mp_group = None + + def __init__(self, args, model, **kwargs): + + super().__init__(args, model, **kwargs) + + # synch seed between all GPUs + _rng_state = get_accelerator().get_rng_state().to(get_accelerator().current_device_name()) + dist.broadcast(_rng_state, 0) + get_accelerator().set_rng_state(_rng_state.cpu()) + + self.Z3_enabled = (self._config.zero_config.stage == 3) + self.gather_all_layers = self._config.hybrid_engine.pin_parameters + + # inference containers / fwds + self._inference_containers = [] + self._orig_modules = [] + self._orig_fwds = [] + self.create_inference_module() + + # Performance stats + self._t_start = None + self._total_latency = 0 + self._iters = 0 + self._training_start_time = None + self._generate_latency = 0 + self._training_latency = 0 + self._total_batch_size = None + self._gather_latency = 0 + + global inference_cuda_module + if inference_cuda_module is None: + builder = InferenceBuilder() + inference_cuda_module = builder.load() + + self.is_lora_fused = False + + def convert_to_linear_transposed(self, model): + + def _replace_linear_layer(r_module, parent_type=None, prev_type=None): + for name, child in r_module.named_children(): + if child.__class__ in [torch.nn.Linear] and \ + (parent_type is torch.nn.ModuleList or prev_type is torch.nn.ModuleList): + setattr(r_module, name, TLinear(child, name)) + else: + _replace_linear_layer(child, type(r_module), prev_type=parent_type) + return r_module + + _replace_linear_layer(model) + + def new_inference_container(self, orig_layer, policy_cls, layer_id): + policy = policy_cls(orig_layer, inference=True) + _container = policy_to_ds_container( + policy=policy, + config=DeepSpeedInferenceConfig(set_empty_params=True, + max_out_tokens=self._config.hybrid_engine.max_out_tokens, + min_out_tokens=self._config.hybrid_engine.max_out_tokens, + transposed_mode=True), + model_config=self.module.config if hasattr(self.module, 'config') else None, + layer_id=layer_id, + child=orig_layer) + _container.set_dtype(self._config.fp16_enabled) + + _container.set_tensor_parallel_config(self._config.hybrid_engine.inference_tp_size, self.mp_group) + _container.initialize_tensors(enable_training=True) + _container.create_ds_model_config() + _container.create_module() + _container.set_params_wo_copy(Z3_enabled=self.Z3_enabled) + return _container + + def populate_all_inference_policies(self): + self.inference_policies = {} + for plcy in replace_policies: + _ = plcy(None) + if 
isinstance(plcy._orig_layer_class, list):
+                for orig_layer_class in plcy._orig_layer_class:
+                    self.inference_policies.update({orig_layer_class: (self.new_inference_container, plcy)})
+            elif plcy._orig_layer_class is not None:
+                self.inference_policies.update({plcy._orig_layer_class: (self.new_inference_container, plcy)})
+        self.inference_policies.update({
+            nn.Linear: (LinearLayer, ),
+            nn.Embedding: (EmbeddingLayer, ),
+            nn.LayerNorm: (Normalize, ),
+            OPTLearnedPositionalEmbedding: (OPTEmbedding, )
+        })
+
+    def _fuse_lora(self, params, lora_params):
+        maybe_has_lora_params = [p for p in params if len(p.shape) > 1]
+        for lora_param, weight in zip(lora_params, maybe_has_lora_params):
+            if len(lora_params) > 0:
+                lora_right_weight, \
+                    lora_left_weight, \
+                    lora_scaling = lora_param
+                weight.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())
+
+    def fuse_lora_weight(self):
+        for layer_id in range(len(self.layer_params)):
+            self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
+
+    def _unfuse_lora(self, params, lora_params):
+        maybe_has_lora_params = [p for p in params if len(p.shape) > 1]
+        for lora_param, weight in zip(lora_params, maybe_has_lora_params):
+            if len(lora_params) > 0:
+                lora_right_weight, \
+                    lora_left_weight, \
+                    lora_scaling = lora_param
+                weight.data -= lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())
+
+    def unfuse_lora_weight(self):
+        for layer_id in range(len(self.layer_params)):
+            self._unfuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
+
+    def unfuse_lora_weight_non_pinned(self):
+        for layer_id in range(len(self.layer_params)):
+            non_active_params = get_inactive_params(self.layer_params[layer_id])
+            non_active_lora_params = get_inactive_params(self.layer_lora_params[layer_id])
+            non_active_params.extend(non_active_lora_params)
+
+            with GatheredParameters(non_active_params):
+                self._unfuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
+
+    def retake_inference_cache(self):
+        if self._config.hybrid_engine.release_inference_cache:
+            retake_success = inference_cuda_module.retake_workspace()
+
+            if not retake_success:
+                logger.warning("Unable to acquire workspace on first attempt, emptying cache and retrying.")
+                gc.collect()
+                get_accelerator().empty_cache()
+                retake_success = inference_cuda_module.retake_workspace()
+
+                if not retake_success:
+                    raise RuntimeError("Unable to retake inference workspace.")
+
+    def generate(self, *inputs, **kwargs):
+        if self._total_batch_size is None:
+            bsz = inputs[0].shape[0] if len(inputs) > 0 else \
+                kwargs['input_ids'].shape[0]
+            self._total_batch_size = bsz * dist.get_world_size()
+
+        self._t0 = time.time()
+
+        if self.Z3_enabled and self.gather_all_layers:
+            if self._config.hybrid_engine.inference_tp_size > 1:
+                non_tp_params = []
+                for other_layer in self._other_layers:
+                    non_tp_params.extend(list(other_layer.parameters()))
+
+                partition_size = self._config.hybrid_engine.tp_gather_partition_size
+
+                layer_groups = len(self.layer_params) // partition_size
+                for lg in range(layer_groups):
+                    non_active_params = []
+                    non_active_lora_params = []
+                    for layer_id in range(lg * partition_size, min(len(self.layer_params), (lg + 1) * partition_size),
+                                          1):
+                        non_tp_params.extend(self.layer_params[layer_id][:4])
+                        non_active_params.extend(get_inactive_params(self.layer_params[layer_id]))
+                        non_active_params.extend(get_inactive_params(self.layer_lora_params[layer_id]))
+                    with GatheredParameters(non_active_params):
+                        for layer_id in range(lg 
* partition_size, + min(len(self.layer_params), (lg + 1) * partition_size), 1): + if len(self.all_lora_params) > 0: + self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id]) + self._inference_containers[layer_id].apply_tensor_parallelism( + mp_group=self.mp_group, tp_size=self._config.hybrid_engine.inference_tp_size) + + # TODO(cmikeh2) Evaluate if this can be deferred when release_inference_cache + # is enabled. + gc.collect() + get_accelerator().empty_cache() + + self._gather_latency = time.time() - self._t0 + + input_shape = inputs[0].shape if len(inputs) > 0 else \ + kwargs['input_ids'].shape + output = torch.zeros( + (input_shape[0] * self._config.hybrid_engine.inference_tp_size, ) + input_shape[1:], + dtype=inputs[0].dtype if len(inputs) > 0 else kwargs['input_ids'].dtype, + device=inputs[0].device if len(inputs) > 0 else kwargs['input_ids'].device) + input_cont = inputs[0].contiguous() if len(inputs) > 0 else kwargs['input_ids'].contiguous() + dist.all_gather_base(output, input_cont, group=self.mp_group) + + if len(inputs) > 0: + inputs = (output, ) + else: + kwargs['input_ids'] = output + + self.retake_inference_cache() + + non_active_params = get_inactive_params(non_tp_params) + with GatheredParameters(non_active_params): + generate_ret_vals = self._generate(*inputs, **kwargs) + + for layer_id in range(len(self.layer_params)): + self._inference_containers[layer_id].release_memory() + + rank = dist.get_rank(group=self.mp_group) + generate_ret_vals = generate_ret_vals[input_shape[0] * rank:input_shape[0] * (rank + 1)] + + else: + non_active_layers = get_inactive_params(self.all_layers_params) + non_active_lora_params = get_inactive_params(self.all_lora_params) + non_active_layers.extend(non_active_lora_params) + with GatheredParameters(non_active_layers): + self._gather_latency = time.time() - self._t0 + + if len(self.all_lora_params) > 0: + self.fuse_lora_weight() + + self.retake_inference_cache() + generate_ret_vals = self._generate(*inputs, **kwargs) + + if len(self.all_lora_params) > 0: + self.unfuse_lora_weight() + else: + if len(self.all_lora_params) > 0 and (not self.Z3_enabled): + self.fuse_lora_weight() + + self.retake_inference_cache() + generate_ret_vals = self._generate(*inputs, **kwargs) + + if len(self.all_lora_params) > 0: + if (not self.Z3_enabled): + self.unfuse_lora_weight() + else: + self.unfuse_lora_weight_non_pinned() + self.is_lora_fused = False + + if self._config.hybrid_engine.release_inference_cache: + inference_cuda_module.release_workspace() + gc.collect() + get_accelerator().empty_cache() + + self._generate_latency = time.time() - self._t0 - self._gather_latency + + return generate_ret_vals + + def create_inference_containers(self, module, layer_id=0): + for name, child in module.named_children(): + if child.__class__ in self.inference_policies: + if self.inference_policies[child.__class__][0] == self.new_inference_container: + self._inference_containers.append(self.inference_policies[child.__class__][0]( + child, self.inference_policies[child.__class__][-1], layer_id)) + self._orig_modules.append(child) + self._orig_fwds.append(child.forward) + + self.layer_params.append(self._inference_containers[layer_id].get_all_params()) + + self.lora_params.append(self._inference_containers[layer_id].get_lora_params()) + self.layer_lora_params.append([]) + for lora_param in self.lora_params[layer_id]: + self.layer_lora_params[layer_id].extend(lora_param[:-1]) + self.all_lora_params.extend(lora_param[:-1]) + + layer_id += 1 + else: + 
self._other_layers.append(self.inference_policies[child.__class__][0]( + weight=child.weight, bias=child.bias if hasattr(child, 'bias') else None)) + self._orig_modules_others.append(child) + self._orig_fwds_others.append(child.forward) + else: + self.create_inference_containers(child, layer_id=layer_id) + + def create_inference_module(self): + self.layer_params = [] + self.layer_lora_params = [] + self.lora_params = [] + self.all_lora_params = [] + + self._other_layers = [] + self._orig_modules_others = [] + self._orig_fwds_others = [] + + if self._config.hybrid_engine.inference_tp_size > 1: + global_rank = dist.get_rank() + world_size = dist.get_world_size() + mp_group_id = global_rank // self._config.hybrid_engine.inference_tp_size + num_mp_groups = world_size // self._config.hybrid_engine.inference_tp_size + for mp_group_id in range(num_mp_groups): + ranks = list( + range(mp_group_id * self._config.hybrid_engine.inference_tp_size, \ + (mp_group_id + 1) * self._config.hybrid_engine.inference_tp_size, \ + 1) + ) + mp_group = dist.new_group(ranks) + if global_rank in ranks: + self.mp_group = mp_group + else: + self.mp_group = None + self.populate_all_inference_policies() + self.all_layers_params = list(self.module.parameters()) + self.create_inference_containers(self.module) + + self._generate = self.module.generate + self.module.generate = self.generate + + self._t0 = time.time() + + def _zero3_forward(self, layer_id): + + def run_forward(*inputs, **kwargs): + non_active_params = get_inactive_params(self.layer_params[layer_id]) + non_active_lora_params = get_inactive_params(self.layer_lora_params[layer_id]) + non_active_params.extend(non_active_lora_params) + + with GatheredParameters(non_active_params): + if len(self.all_lora_params) > 0: + # Use the is_lora_fused flag to prevent multiple fusion in Z3 with non-pinned memory + if not self.is_lora_fused: + self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id]) + # Set the is_lora_fused to true when reaching the last layer + if layer_id == len(self.layer_params) - 1: + self.is_lora_fused = True + return self._inference_containers[layer_id].module.forward(*inputs, **kwargs) + + return run_forward + + def eval(self): + if self._t_start is not None: + latency = time.time() - self._t_start + self._total_latency = self._total_latency + latency + self._iters = self._iters + 1 + if not dist.is_initialized() or dist.get_rank() == 0: + others = latency - (self._generate_latency + self._training_latency) + print(f'|E2E latency={(latency):.2f}s ' + \ + f'|Gather latency={self._gather_latency:.2f}s ({(self._gather_latency / latency * 100):.2f}%) ' + f'|Generate time={(self._generate_latency):.2f}s ({(self._generate_latency / latency * 100):.2f}%) ' + \ + f'|Training time={(self._training_latency):.2f}s ({(self._training_latency / latency * 100):.2f}%) ' + \ + f'|Others={others:.2f} ({(others / latency * 100):.2f}%)' + f'|CurSamplesPerSec={(1 / latency * self._total_batch_size):.2f} ' + \ + f'|AvgSamplesPerSec={(1 / (self._total_latency / self._iters) * self._total_batch_size):.2f}') + self._t_start = time.time() + self._training_latency = 0 + super().eval() + if len(self._inference_containers) > 0: + for i, (orig_module, inference_container) in enumerate(zip(self._orig_modules, + self._inference_containers)): + if self.Z3_enabled and not self.gather_all_layers: + orig_module.forward = self._zero3_forward(i) + else: + orig_module.forward = inference_container.module.forward + + if not self.Z3_enabled or self.gather_all_layers: + for 
orig_module, inference_layer in zip(self._orig_modules_others, self._other_layers): + orig_module.forward = inference_layer.forward + if self.Z3_enabled: + gc.collect() + get_accelerator().empty_cache() + if self._t_start is None: + self._t_start = time.time() + + def train(self, mode=True): + if mode and len(self._orig_modules) > 0: + for orig_module, orig_fwd in zip(self._orig_modules, self._orig_fwds): + orig_module.forward = orig_fwd + for orig_module, orig_fwd in zip(self._orig_modules_others, self._orig_fwds_others): + orig_module.forward = orig_fwd + super().train(mode) + if mode: + self._training_start_time = time.time() + + def step(self, lr_kwargs=None): + super().step(lr_kwargs=lr_kwargs) + if(self._inference_containers[0].module.attention.attn_qkvw is not None and \ + self._inference_containers[0].q_k_v is not None): + for inference_container in self._inference_containers: + inference_container.reset_qkv() + if self._training_start_time is not None: + self._training_latency += (time.time() - self._training_start_time) + self._training_start_time = time.time() diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index bbdbee192452..1cb6a7ceb76f 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -30,6 +30,9 @@ from numpy import prod from deepspeed.accelerator import get_accelerator +from deepspeed.module_inject.policy import transpose +from torch.nn import functional as F + torch_memory_reserved = get_accelerator().memory_reserved torch_max_memory_reserved = get_accelerator().max_memory_reserved @@ -300,6 +303,7 @@ def get_global_norm(norm_list): total_norm = 0.0 for norm in norm_list: total_norm += norm**2.0 + # logger.info(f'norm_list = {norm_list} global = {sqrt(total_norm)}') return sqrt(total_norm) @@ -943,3 +947,28 @@ def all_gather_dp_groups(partitioned_param_groups, dp_process_group, start_align shard_list.append(curr_shard) dist.all_gather(shard_list, shard_list[partition_id], dp_process_group[group_id]) + + +class TLinear(torch.nn.Linear): + + def __init__(self, orig_layer, name=""): + self.name = name + super().__init__(orig_layer.weight.shape[1], orig_layer.weight.shape[0], bias=(orig_layer.bias is not None)) + self.weight.data = transpose(orig_layer.weight.data) + self.bias = orig_layer.bias + self._fwd_func = self._fwd_bias_add if self.bias is not None else self._fwd + + def _fwd(self, input): + return F.linear(input, self.weight) + + def _fwd_bias_add(self, input): + return F.linear(input, self.weight, bias=self.bias) + + def forward(self, input): + return self._fwd_func(input) + + +def get_inactive_params(param_list): + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + return [param for param in param_list if (hasattr(param, 'ds_id') and \ + param.ds_status == ZeroParamStatus.NOT_AVAILABLE)] diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 4ad508513816..cc204afa7bf2 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -34,7 +34,8 @@ "offload_param": {...}, "offload_optimizer": {...}, "ignore_unused_parameters": [true|false], - "round_robin_gradients": [true|false] + "round_robin_gradients": [true|false], + "memory_efficient_linear": [true|false] } } """ @@ -248,6 +249,11 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): between optimizer steps) or GPU count (increased parallelism). """ + memory_efficient_linear: bool = True + """ + Use memory efficient linear implementation, for Stage 3. 
+ """ + # Validators @validator("overlap_comm") def overlap_comm_valid(cls, field_value, values): diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py index d5c7b8d3176c..b97a833beacb 100644 --- a/deepspeed/runtime/zero/linear.py +++ b/deepspeed/runtime/zero/linear.py @@ -49,7 +49,6 @@ class LinearFunctionForZeroStage3(torch.autograd.Function): @autocast_custom_fwd # bias is an optional argument def forward(ctx, input, weight, bias=None): - #print("In ZeRO Linear Function") weight_id = id(weight) bias_id = id(bias) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index eabff446bfa5..1bf4540d1439 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -461,7 +461,8 @@ def post_sub_module_forward_function(self, sub_module): @torch.no_grad() def pre_sub_module_backward_function(self, sub_module): - param_coordinator = self.get_param_coordinator(training=sub_module.training) + assert sub_module.training, "backward pass is invalid for module in evaluation mode" + param_coordinator = self.get_param_coordinator(training=True) param_coordinator.trace_prologue(sub_module) if param_coordinator.is_record_trace(): param_coordinator.record_module(sub_module) @@ -469,11 +470,12 @@ def pre_sub_module_backward_function(self, sub_module): @torch.no_grad() def post_sub_module_backward_function(self, sub_module): + assert sub_module.training, "backward pass is invalid for module in evaluation mode" see_memory_usage( f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release", force=False) - self.get_param_coordinator(training=sub_module.training).release_sub_module(sub_module) + self.get_param_coordinator(training=True).release_sub_module(sub_module) see_memory_usage( f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release", diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 9b572a0227a2..84e628ef487c 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -11,7 +11,6 @@ import functools import itertools from typing import List - import torch from torch import Tensor from deepspeed import comm as dist @@ -688,9 +687,10 @@ def get_model(): config_dict_or_path = config logger.warning( f'zero.Init: the `config` argument is deprecated. 
Please use `config_dict_or_path` instead.') - _ds_config = deepspeed.runtime.config.DeepSpeedConfig(config_dict_or_path, mpu) if config_dict_or_path is not None else None + if _ds_config is not None: + mem_efficient_linear = _ds_config.zero_config.memory_efficient_linear super().__init__(enabled=enabled, mem_efficient_linear=mem_efficient_linear, ds_config=_ds_config, dtype=dtype) if not dist.is_initialized(): init_distributed() diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index fa40a3476d92..949c54f5e806 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -138,10 +138,20 @@ def _invalidate_trace(self) -> None: def trace_prologue(self, sub_module: Module) -> None: if self.is_complete_trace(): # sub_module must match expectation else invalidate trace cache + if len(self.__submodule_order) <= self.__step_id: + print_rank_0( + f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.id}: " + f"cache has only {len(self.__submodule_order)} modules", + force=True) + self._invalidate_trace() + return + if sub_module != self.__submodule_order[self.__step_id]: expected_module_id = self.__submodule_order[self.__step_id].id - debug_rank0(f"Invalidate trace cache @ step {self.__step_id}: " - f"expected module {expected_module_id}, but got module {sub_module.id}") + print_rank_0( + f"Invalidate trace cache @ step {self.__step_id}: " + f"expected module {expected_module_id}, but got module {sub_module.id}", + force=True) self._invalidate_trace() def record_module(self, sub_module: Module) -> None: @@ -187,7 +197,9 @@ def reset_step(self) -> None: self.__submodule_order = tuple(self.__submodule_order) # freeze self.__param_order = tuple(self.__param_order) # freeze self.__trace_mode = ZeRoTraceMode.COMPLETE - print_rank_0(f"completed record trace: {[m.id for m in self.__submodule_order]}", force=False) + print_rank_0( + f"completed record trace of {len(self.__submodule_order)} sub modules: {[m.id for m in self.__submodule_order]}", + force=False) else: # Enable trace recording for next forward/backward pass self.__trace_mode = ZeRoTraceMode.RECORD diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index b3a66f25f1f2..e3b6be65ed2b 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -33,7 +33,7 @@ def print_rank_0(message, debug=False, force=False): rank = dist.get_rank() if rank == 0 and (debug or force): - print(message) + logger.info(message) # other variations # - print for all ranks w/o interleaving # printflock(f"[{rank}] {message}") @@ -1015,6 +1015,7 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): self.__reduce_and_partition_ipg_grads() param_id = self.get_param_id(param) + assert self.params_already_reduced[param_id] == False, \ f"The parameter {param_id} has already been reduced. \ Gradient computed twice for this partition. 
\ diff --git a/tests/hybrid_engine/hybrid_engine_config.json b/tests/hybrid_engine/hybrid_engine_config.json new file mode 100644 index 000000000000..1d418ae8e019 --- /dev/null +++ b/tests/hybrid_engine/hybrid_engine_config.json @@ -0,0 +1,19 @@ +{ + "train_batch_size" : 32, + "train_micro_batch_size_per_gpu": 2, + "steps_per_print": 10, + "zero_optimization": { + "stage": 0, + "offload_param": { + "device": "cpu" + }, + "stage3_param_persistence_threshold": 0 + }, + "fp16":{ + "enabled": true, + "loss_scale_window": 100 + }, + "gradient_clipping": 1.0, + "prescale_gradients": false, + "wall_clock_breakdown" : false +} diff --git a/tests/hybrid_engine/hybrid_engine_test.py b/tests/hybrid_engine/hybrid_engine_test.py new file mode 100644 index 000000000000..1b8958a387e7 --- /dev/null +++ b/tests/hybrid_engine/hybrid_engine_test.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from transformers import AutoModelForCausalLM +import deepspeed +import argparse +from deepspeed.accelerator import get_accelerator + +deepspeed.runtime.utils.see_memory_usage('pre test', force=True) + +model = AutoModelForCausalLM.from_pretrained('facebook/opt-350M').half().to(get_accelerator().device_name()) +parser = argparse.ArgumentParser() +parser = deepspeed.add_config_arguments(parser) +args = parser.parse_args() + +deepspeed.runtime.utils.see_memory_usage('post test', force=True) + +m, _, _, _ = deepspeed.initialize(model=model, args=args, enable_hybrid_engine=True) + +m.eval() +input = torch.ones(1, 16, device='cuda', dtype=torch.long) +out = m(input) + +m.train() +out = m(input) +print(out['logits'], out['logits'].norm()) From 65344e64c60034d9332ab2fc01ed26fdd32099e9 Mon Sep 17 00:00:00 2001 From: Connor Holmes Date: Mon, 10 Apr 2023 18:15:24 -0700 Subject: [PATCH 2/3] Use current AG method from comm (#492) --- deepspeed/runtime/hybrid_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/hybrid_engine.py b/deepspeed/runtime/hybrid_engine.py index c8e36b464fdc..3d7538ac4f3b 100644 --- a/deepspeed/runtime/hybrid_engine.py +++ b/deepspeed/runtime/hybrid_engine.py @@ -212,7 +212,7 @@ def generate(self, *inputs, **kwargs): dtype=inputs[0].dtype if len(inputs) > 0 else kwargs['input_ids'].dtype, device=inputs[0].device if len(inputs) > 0 else kwargs['input_ids'].device) input_cont = inputs[0].contiguous() if len(inputs) > 0 else kwargs['input_ids'].contiguous() - dist.all_gather_base(output, input_cont, group=self.mp_group) + dist.all_gather_into_tensor(output, input_cont, group=self.mp_group) if len(inputs) > 0: inputs = (output, ) From 36a3659c28a5b1500a2317c5c2ebdd6383671c77 Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Tue, 11 Apr 2023 09:30:07 -0700 Subject: [PATCH 3/3] bump to 0.9.0 (#498) --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index ee94dd834b53..ac39a106c485 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.8.3 +0.9.0
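
For readers wiring the new engine into a config: deepspeed/runtime/hybrid_engine.py above reads its knobs from self._config.hybrid_engine.* (max_out_tokens, inference_tp_size, release_inference_cache, pin_parameters, tp_gather_partition_size) and from self._config.zero_config.memory_efficient_linear added in deepspeed/runtime/zero/config.py. The JSON below is only a hedged sketch of how those fields might sit next to tests/hybrid_engine/hybrid_engine_config.json: the key names are inferred from the attribute accesses in this patch, the values are placeholders rather than recommendations, and the authoritative schema lives in deepspeed/runtime/config.py and deepspeed/runtime/zero/config.py, which are only partially shown here. Enabling the engine itself is done through deepspeed.initialize(..., enable_hybrid_engine=True), as in tests/hybrid_engine/hybrid_engine_test.py.

{
    "train_batch_size": 32,
    "train_micro_batch_size_per_gpu": 2,
    "zero_optimization": {
        "stage": 3,
        "memory_efficient_linear": true
    },
    "hybrid_engine": {
        "max_out_tokens": 512,
        "inference_tp_size": 1,
        "release_inference_cache": false,
        "pin_parameters": true,
        "tp_gather_partition_size": 8
    },
    "fp16": {
        "enabled": true
    }
}

pin_parameters maps onto self.gather_all_layers in the engine, and in the code shown here it is only consulted together with Z3_enabled, so it has no effect unless ZeRO stage 3 is active.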