DeepSpeed Chat #3186

Merged 6 commits on Apr 11, 2023.

Changes from all commits
1 change: 0 additions & 1 deletion accelerator/cuda_accelerator.py

@@ -26,7 +26,6 @@ def __init__(self):
         # put all valid class name <--> class type mapping into class_dict
         op_builder_dir = self.op_builder_dir()
         op_builder_module = importlib.import_module(op_builder_dir)
-
         for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(op_builder_module.__file__)]):
             # avoid self references
             if module_name != 'all_ops' and module_name != 'builder':
10 changes: 5 additions & 5 deletions csrc/includes/context.h

@@ -44,9 +44,9 @@ inline int DS_GET_BLOCKS(const int N)
         1);
 }

-class Context {
+class TrainingContext {
 public:
-    Context() : _workspace(nullptr), _seed(42), _curr_offset(0)
+    TrainingContext() : _workspace(nullptr), _seed(42), _curr_offset(0)
     {
         curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
         curandSetPseudoRandomGeneratorSeed(_gen, 123);
@@ -57,15 +57,15 @@ class Context {
         }
     }

-    virtual ~Context()
+    virtual ~TrainingContext()
     {
         cublasDestroy(_cublasHandle);
         cudaFree(_workspace);
     }

-    static Context& Instance()
+    static TrainingContext& Instance()
     {
-        static Context _ctx;
+        static TrainingContext _ctx;
         return _ctx;
     }
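The central change in this PR is the rename of the `Context` singleton to `TrainingContext`; every caller in the files below follows suit. The class keeps the classic Meyers-singleton shape (`Instance()` returning a function-local static). As a reviewer's aside, here is a minimal, self-contained sketch of that pattern — `DemoContext` and its seed accessors are illustrative stand-ins, not the actual DeepSpeed members:

```cpp
#include <cstdint>
#include <iostream>

// Minimal sketch of the Meyers-singleton pattern that TrainingContext uses.
class DemoContext {
public:
    // The function-local static is constructed exactly once, on first use;
    // since C++11 this initialization is guaranteed thread-safe.
    static DemoContext& Instance()
    {
        static DemoContext ctx;
        return ctx;
    }

    // Non-copyable: there is only ever one instance.
    DemoContext(const DemoContext&) = delete;
    DemoContext& operator=(const DemoContext&) = delete;

    void SetSeed(uint64_t seed) { seed_ = seed; }
    uint64_t Seed() const { return seed_; }

private:
    DemoContext() = default;  // private ctor: only Instance() constructs
    uint64_t seed_ = 42;
};

int main()
{
    DemoContext::Instance().SetSeed(123);
    // Both calls refer to the same object.
    std::cout << DemoContext::Instance().Seed() << "\n";  // prints 123
}
```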
4 changes: 2 additions & 2 deletions csrc/includes/cpu_adagrad.h

@@ -39,8 +39,8 @@ class Adagrad_Optimizer {
         cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
         cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

-        _streams[0] = Context::Instance().GetCurrentStream();
-        _streams[1] = Context::Instance().GetNewStream();
+        _streams[0] = TrainingContext::Instance().GetCurrentStream();
+        _streams[1] = TrainingContext::Instance().GetNewStream();
         _buf_index = false;
 #endif
     }
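For context on what these two streams are for: the CPU optimizers stage parameter tiles through two pinned host buffers (`_doubled_buffer`) and alternate them across the streams, so a host-to-device copy on one buffer overlaps with CPU work filling the other. A hedged sketch of that double-buffering idea — `copy_params_tiled` and the fixed `TILE` are illustrative, not DeepSpeed's API:

```cpp
#include <cuda_runtime.h>
#include <cstring>

constexpr size_t TILE = 1 << 20;  // elements per pinned staging tile

// Stream a large float array to the device through two pinned buffers.
void copy_params_tiled(const float* src, float* dev_dst, size_t count)
{
    float* host_buf[2];
    cudaStream_t streams[2];
    for (int i = 0; i < 2; ++i) {
        cudaMallocHost((void**)&host_buf[i], TILE * sizeof(float));
        cudaStreamCreate(&streams[i]);
    }

    int buf = 0;  // plays the role of _buf_index
    for (size_t off = 0; off < count; off += TILE) {
        size_t n = (count - off < TILE) ? (count - off) : TILE;
        // Wait until this buffer's previous async copy has drained
        // before overwriting the pinned memory.
        cudaStreamSynchronize(streams[buf]);
        std::memcpy(host_buf[buf], src + off, n * sizeof(float));
        cudaMemcpyAsync(dev_dst + off, host_buf[buf], n * sizeof(float),
                        cudaMemcpyHostToDevice, streams[buf]);
        buf ^= 1;  // alternate buffers, as the optimizers do
    }

    for (int i = 0; i < 2; ++i) {
        cudaStreamSynchronize(streams[i]);
        cudaStreamDestroy(streams[i]);
        cudaFreeHost(host_buf[i]);
    }
}
```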
4 changes: 2 additions & 2 deletions csrc/includes/cpu_adam.h

@@ -54,8 +54,8 @@ class Adam_Optimizer {
         cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
         cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

-        _streams[0] = Context::Instance().GetCurrentStream();
-        _streams[1] = Context::Instance().GetNewStream();
+        _streams[0] = TrainingContext::Instance().GetCurrentStream();
+        _streams[1] = TrainingContext::Instance().GetNewStream();
         _buf_index = false;
 #endif
     }
4 changes: 2 additions & 2 deletions csrc/quantization/fake_quantizer.cu

@@ -457,7 +457,7 @@ void launch_sr_fake_quantize_kernel(T* vals,
     dim3 grid_dim(group_num);

     uint64_t inc = total_count / grid_dim.x / block_dim.x;
-    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
+    std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);

     sr_fake_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
         vals, (total_count / group_num) / 4, group_num, num_bits, seed);
@@ -1011,7 +1011,7 @@ void launch_sr_fake_quantize_kernel_asym(T* vals,
     dim3 grid_dim(group_num);

     uint64_t inc = total_count / grid_dim.x / block_dim.x;
-    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
+    std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);

     sr_fake_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
         vals, (total_count / group_num) / 4, group_num, num_bits, seed);
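`IncrementOffset(inc)` is the RNG bookkeeping these launchers rely on: it returns the current (seed, offset) pair and advances a shared counter by `inc` values per thread, so successive stochastic-rounding launches never reuse counter values. A minimal sketch of that bookkeeping under an assumed tracker class (`OffsetTracker` is illustrative):

```cpp
#include <cstdint>
#include <utility>

// Sketch of the seed/offset bookkeeping behind IncrementOffset().
class OffsetTracker {
public:
    // Returns the (seed, starting offset) pair for this launch and
    // advances the shared counter so the next launch gets fresh values.
    std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t inc)
    {
        uint64_t start = curr_offset_;
        curr_offset_ += inc;
        return {seed_, start};
    }

    void SetSeed(uint64_t seed) { seed_ = seed; }

private:
    uint64_t seed_ = 42;
    uint64_t curr_offset_ = 0;
};
```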
6 changes: 3 additions & 3 deletions csrc/transformer/dropout_kernels.cu

@@ -278,7 +278,7 @@ void launch_dropout(T* out,
         grid_dim.x <<= 1;
     }
     uint64_t inc = total_count / grid_dim.x / block_dim.x;
-    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
+    std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);
     if (bwd)
         dropout_kernel_bwd<<<grid_dim, block_dim, 0, stream>>>(
             total_count, ratio, vals, out, mask, seed);
@@ -625,7 +625,7 @@ void launch_dropout(T* out,
     dim3 block_dim = DS_CUDA_NUM_THREADS;

     uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
-    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
+    std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);

     dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
         total_count, dim, ratio, bias, out, mask, seed);
@@ -847,7 +847,7 @@ void launch_dropout(T* out,
     dim3 block_dim = DS_CUDA_NUM_THREADS;

     uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
-    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
+    std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);

     dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
         total_count, dim, ratio, input, residual, bias, out, mask, seed);
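On the device side, each thread can turn the launcher's (seed, offset) pair into its own counter-based generator state. A hedged sketch of a dropout kernel in that style — the kernel name, mask type, and inverted-dropout scaling are illustrative, not the actual DeepSpeed kernel:

```cpp
#include <curand_kernel.h>
#include <cstdint>
#include <utility>

// Each thread derives an independent Philox stream from (seed, offset),
// so the mask is reproducible without any global reseed.
__global__ void dropout_sketch(float* out,
                               const float* vals,
                               uint8_t* mask,
                               int total_count,
                               float ratio,
                               std::pair<uint64_t, uint64_t> seed)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= total_count) return;

    curandStatePhilox4_32_10_t state;
    // seed.first selects the stream family; seed.second is the offset
    // reserved for this launch by IncrementOffset().
    curand_init(seed.first, idx, seed.second, &state);

    float keep = (curand_uniform(&state) > ratio) ? 1.0f : 0.0f;
    mask[idx] = (uint8_t)keep;
    out[idx] = vals[idx] * keep / (1.0f - ratio);  // inverted dropout
}
```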
47 changes: 24 additions & 23 deletions csrc/transformer/ds_transformer_cuda.cpp

@@ -78,8 +78,8 @@ BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id,
       _normalize_invertible(normalize_invertible),
       _gelu_checkpoint(gelu_checkpoint),
       _stochastic_mode(stochastic_mode),
-      _stream(Context::Instance().GetCurrentStream()),
-      _cublasHandle(Context::Instance().GetCublasHandle()),
+      _stream(TrainingContext::Instance().GetCurrentStream()),
+      _cublasHandle(TrainingContext::Instance().GetCublasHandle()),
       _qkv_linear(typename FeedForward<T>::Config(batch_size * seq_length,
                                                   3 * hidden_size,
                                                   hidden_size,
@@ -183,7 +183,7 @@ void BertTransformerLayer<T>::Forward(unsigned bsz,

     if (!_stochastic_mode) cudaStreamSynchronize(_stream);

-    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
+    T* workspace = static_cast<T*>(TrainingContext::Instance().GetWorkSpace());
     size_t small_buf_size = bsz * _seq_length * _hidden_size;
     T* buf_0 = workspace;
     T* buf_1 = buf_0 + small_buf_size;
@@ -343,7 +343,7 @@ void BertTransformerLayer<T>::Backward(unsigned bsz,

     if (!_stochastic_mode) cudaStreamSynchronize(_stream);

-    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
+    T* workspace = static_cast<T*>(TrainingContext::Instance().GetWorkSpace());
     size_t small_buf_size = bsz * _seq_length * _hidden_size;
     T* buf_0 = workspace;
     T* buf_1 = buf_0 + small_buf_size;
@@ -609,25 +609,26 @@ int create_transformer_layer(unsigned layer_id,
                              bool gelu_checkpoint,
                              bool stochastic_mode)
 {
-    Context::Instance().SetSeed(seed);
-    Context::Instance().TestGemmFP16(
+    TrainingContext::Instance().SetSeed(seed);
+    TrainingContext::Instance().TestGemmFP16(
         test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);

-    auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
-                                                           batch_size,
-                                                           hidden_dim,
-                                                           num_heads,
-                                                           intermediate_size,
-                                                           init_seq_length,
-                                                           attn_dropout_ratio,
-                                                           hidden_dropout_ratio,
-                                                           layer_norm_eps,
-                                                           pre_or_postLayerNorm,
-                                                           Context::Instance().GetGemmAlgos(),
-                                                           attn_dropout_checkpoint,
-                                                           normalize_invertible,
-                                                           gelu_checkpoint,
-                                                           stochastic_mode);
+    auto layer =
+        std::make_shared<BertTransformerLayer<T>>(layer_id,
+                                                  batch_size,
+                                                  hidden_dim,
+                                                  num_heads,
+                                                  intermediate_size,
+                                                  init_seq_length,
+                                                  attn_dropout_ratio,
+                                                  hidden_dropout_ratio,
+                                                  layer_norm_eps,
+                                                  pre_or_postLayerNorm,
+                                                  TrainingContext::Instance().GetGemmAlgos(),
+                                                  attn_dropout_checkpoint,
+                                                  normalize_invertible,
+                                                  gelu_checkpoint,
+                                                  stochastic_mode);

     s_transformer_layers[layer_id] = layer;

@@ -725,7 +726,7 @@ std::vector<torch::Tensor> ds_transformer_forward(unsigned layer_id,
                                    layer->IsTrainingMode(),
                                    layer->GeluCheckpoint())},
         options);
-    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
+    TrainingContext::Instance().SetWorkSpace((T*)workspace.data_ptr());

     auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
     auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
@@ -909,7 +910,7 @@ std::vector<torch::Tensor> ds_transformer_backward(unsigned layer_id,
                                    layer->IsTrainingMode(),
                                    layer->GeluCheckpoint())},
         options);
-    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
+    TrainingContext::Instance().SetWorkSpace((T*)workspace.data_ptr());

     auto grad_input = torch::empty_like(input);
     auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
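One pattern worth noting in the two `SetWorkSpace` hunks: the scratch workspace is allocated as a `torch::Tensor` (so PyTorch's caching allocator owns the memory) and only the raw pointer is handed to the singleton for kernels to reuse. A minimal sketch of that hand-off under assumed names (`DemoWorkspaceContext`, `attach_workspace` are illustrative):

```cpp
#include <torch/extension.h>

// Illustrative stand-in for the workspace-holding singleton.
class DemoWorkspaceContext {
public:
    static DemoWorkspaceContext& Instance()
    {
        static DemoWorkspaceContext ctx;
        return ctx;
    }
    void SetWorkSpace(void* ptr) { workspace_ = ptr; }
    void* GetWorkSpace() const { return workspace_; }

private:
    void* workspace_ = nullptr;
};

void attach_workspace(int64_t num_elems)
{
    auto options = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA);
    // The tensor must outlive every kernel that reads the raw pointer;
    // a function-local static keeps this sketch simple (first call fixes the size).
    static torch::Tensor workspace = torch::empty({num_elems}, options);
    DemoWorkspaceContext::Instance().SetWorkSpace(workspace.data_ptr());
}
```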