From 38461bc34893c293ae76b5398adfabbc2c29a5ee Mon Sep 17 00:00:00 2001 From: William <323163497@qq.com> Date: Tue, 29 Aug 2023 15:01:11 +0800 Subject: [PATCH] Add Bf16 Support (#136) --- .gitignore | 3 +- bmtrain/loss/_function.py | 74 ++++++--------- bmtrain/loss/cross_entropy.py | 30 ------ bmtrain/optim/_function.py | 81 ++++++++++++---- bmtrain/optim/adam.py | 63 +++++++------ bmtrain/optim/adam_offload.py | 83 +++++++++-------- bmtrain/optim/optim_manager.py | 6 +- csrc/bind.cpp | 18 ++-- csrc/cuda/adam_cuda.cu | 73 ++++++++++++++- csrc/cuda/bfloat16.cuh | 5 + csrc/cuda/cross_entropy.cu | 99 ++++++++++---------- csrc/cuda/has_inf_nan.cu | 84 ++++++++++++++--- csrc/cuda/reduce.cuh | 2 - csrc/include/adam_cpu.hpp | 163 ++++++++++++++++++++++----------- csrc/include/bind.hpp | 44 ++++++--- setup.py | 1 + tests/test_all.py | 2 + tests/test_has_inf_nan.py | 13 ++- tests/test_loss_func.py | 40 ++++---- tests/test_nccl_backward.py | 7 +- tests/test_optim.py | 87 ++++++++++++------ tests/test_optim_state.py | 2 +- 22 files changed, 613 insertions(+), 367 deletions(-) create mode 100644 csrc/cuda/bfloat16.cuh diff --git a/.gitignore b/.gitignore index 0222862f..2e8c0dcd 100644 --- a/.gitignore +++ b/.gitignore @@ -150,4 +150,5 @@ log .vscode !bmtrain/dist -tests/test_log.txt \ No newline at end of file +tests/test_log.txt +tests/*.opt \ No newline at end of file diff --git a/bmtrain/loss/_function.py b/bmtrain/loss/_function.py index 658ef242..e2b67bb8 100644 --- a/bmtrain/loss/_function.py +++ b/bmtrain/loss/_function.py @@ -2,16 +2,20 @@ from .. import C import torch CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda -def has_inf_nan(g_fp16: torch.Tensor, out: torch.Tensor) -> None: - assert g_fp16.dtype == torch.float16, "g_fp16 must be a half tensor" +def has_inf_nan(g_half: torch.Tensor, out: torch.Tensor) -> None: assert out.dtype == torch.uint8, "out must be a uint8 tensor" - assert CHECK_INPUT(g_fp16), "g_fp16 must be contiguous and on cuda" + assert CHECK_INPUT(g_half), "g_fp16 must be contiguous and on cuda" assert CHECK_INPUT(out), "out must be contiguous and on cuda" mid = torch.zeros(1024, device=out.device, dtype=out.dtype) stream = torch.cuda.current_stream().cuda_stream - C.has_nan_inf_launcher(g_fp16.numel(), g_fp16.data_ptr(), mid.data_ptr(), out.data_ptr(), stream) - - + if g_half.dtype == torch.float16: + C.has_nan_inf_fp16_launcher(g_half.numel(), g_half.data_ptr(), mid.data_ptr(), out.data_ptr(), stream) + elif g_half.dtype == torch.bfloat16: + if not C.is_bf16_supported(): + raise NotImplementedError(f"bfloat16 is not supported on current GPU") + C.has_nan_inf_bf16_launcher(g_half.numel(), g_half.data_ptr(), mid.data_ptr(), out.data_ptr(), stream) + else: + raise ValueError(f"has_inf_nan not supported for dtype {g_half.dtype}") def cross_entropy_forward(m: int, n: int, input: torch.Tensor, target: torch.Tensor, softmax: torch.Tensor, output: torch.Tensor, ignore_index: int) -> None: @@ -19,9 +23,7 @@ def cross_entropy_forward(m: int, n: int, input: torch.Tensor, target: torch.Ten CHECK_INPUT(target) CHECK_INPUT(softmax) CHECK_INPUT(output) - assert input.dtype == torch.float16, "input must be a half tensor" assert target.dtype == torch.int32, "target must be an int tensor" - assert softmax.dtype == torch.float16, "softmax must be a half tensor" assert output.dtype == torch.float32, "output must be a float tensor" assert input.numel() == softmax.numel(), "input and softmax must have the same number of elements" assert target.numel() == output.numel(), "target and 
output must have the same number of elements" @@ -30,43 +32,14 @@ def cross_entropy_forward(m: int, n: int, input: torch.Tensor, target: torch.Ten softmax_ptr = softmax.data_ptr() output_ptr = output.data_ptr() cuda_stream = torch.cuda.current_stream().cuda_stream - C.cross_entropy_forward_launcher(m, n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index, cuda_stream) - -def cross_entropy_backward(m: int, n: int, grad_output: torch.Tensor, target: torch.Tensor, - softmax: torch.Tensor, grad_input: torch.Tensor, ignore_index: int) -> None: - CHECK_INPUT(grad_output) - CHECK_INPUT(target) - CHECK_INPUT(softmax) - CHECK_INPUT(grad_input) - assert grad_output.dtype == torch.float32, "grad_output must be a float tensor" - assert target.dtype == torch.int32, "target must be an int tensor" - assert softmax.dtype == torch.float16, "softmax must be a half tensor" - assert grad_input.dtype == torch.float16, "grad_input must be a half tensor" - assert grad_input.numel() == softmax.numel(), "grad_input and softmax must have the same number of elements" - assert target.numel() == grad_output.numel(), "target and grad_output must have the same number of elements" - grad_output_ptr = grad_output.data_ptr() - target_ptr = target.data_ptr() - softmax_ptr = softmax.data_ptr() - grad_input_ptr = grad_input.data_ptr() - cuda_stream = torch.cuda.current_stream().cuda_stream - C.cross_entropy_backward_launcher(m, n, grad_output_ptr, target_ptr, softmax_ptr, grad_input_ptr, ignore_index, cuda_stream) - -def cross_entropy_forward_inplace(m: int, n: int, x: torch.Tensor, target: torch.Tensor, - output: torch.Tensor, ignore_index: int) -> None: - CHECK_INPUT(x) - CHECK_INPUT(target) - CHECK_INPUT(output) - assert x.dtype == torch.float16, "x must be a half tensor" - assert target.dtype == torch.int32, "target must be an int tensor" - assert output.dtype == torch.float32, "output must be a float tensor" - assert target.numel() == output.numel(), "target and output must have the same number of elements" - cuda_stream = torch.cuda.current_stream().cuda_stream - x_ptr = x.data_ptr() - output_ptr = output.data_ptr() - target_ptr = target.data_ptr() - output_ptr = output.data_ptr() - - C.cross_entropy_forward_inplace_launcher(m, n, x_ptr, target_ptr, output_ptr, ignore_index, cuda_stream) + if input.dtype == torch.float16: + C.cross_entropy_forward_fp16_launcher(m, n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index, cuda_stream) + elif input.dtype == torch.bfloat16: + if not C.is_bf16_supported(): + raise NotImplementedError(f"bfloat16 is not supported on current GPU") + C.cross_entropy_forward_bf16_launcher(m, n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index, cuda_stream) + else: + raise ValueError(f"cross_entropy_forward not supported for dtype {input.dtype}") def cross_entropy_backward_inplace(m: int, n: int, grad_output: torch.Tensor, target: torch.Tensor, x: torch.Tensor, ignore_index: int) -> None: @@ -75,12 +48,17 @@ def cross_entropy_backward_inplace(m: int, n: int, grad_output: torch.Tensor, ta CHECK_INPUT(x) assert grad_output.dtype == torch.float32, "grad_output must be a float tensor" assert target.dtype == torch.int32, "target must be an int tensor" - assert x.dtype == torch.float16, "x must be a half tensor" assert target.numel() == grad_output.numel(), "target and grad_output must have the same number of elements" cuda_stream = torch.cuda.current_stream().cuda_stream grad_output_ptr = grad_output.data_ptr() target_ptr = target.data_ptr() x_ptr = x.data_ptr() - 
C.cross_entropy_backward_inplace_launcher(m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream) - + if x.dtype == torch.float16: + C.cross_entropy_backward_inplace_fp16_launcher(m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream) + elif x.dtype == torch.bfloat16: + if not C.is_bf16_supported(): + raise NotImplementedError(f"bfloat16 is not supported on current GPU") + C.cross_entropy_backward_inplace_bf16_launcher(m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream) + else: + raise ValueError(f"cross_entropy_backward not supported for dtype {input.dtype}") diff --git a/bmtrain/loss/cross_entropy.py b/bmtrain/loss/cross_entropy.py index 982a6469..a2e123ad 100644 --- a/bmtrain/loss/cross_entropy.py +++ b/bmtrain/loss/cross_entropy.py @@ -36,36 +36,6 @@ def backward(ctx, grad_output : torch.Tensor): ) return (softmax, None, None) -class OpFusedCrossEntropyInplace(torch.autograd.Function): - """ - CrossEntropy dim = 1 - """ - @staticmethod - def forward(ctx, x : torch.Tensor, target : torch.Tensor, ignore_index: int): - assert x.ndim == 2 - out = torch.empty(x.size(0), device=x.device, dtype=torch.float) - F.cross_entropy_forward_inplace( - x.size(0), x.size(1), - x, target, - out, - ignore_index, - ) # x is inplace modify to softmax result - ctx.ignore_index = ignore_index - ctx.save_for_backward(x, target) - return out # float tensor - - @staticmethod - def backward(ctx, grad_output : torch.Tensor): - grad_output = grad_output.contiguous() - softmax, target = ctx.saved_tensors - F.cross_entropy_backward_inplace( - softmax.size(0), softmax.size(1), - grad_output, target, - softmax, - ctx.ignore_index, - ) # softmax is inplace modify to grad_input - return (softmax, None, None) - class FusedCrossEntropy(torch.nn.Module): r"""This criterion computes the cross entropy loss between input and target. diff --git a/bmtrain/optim/_function.py b/bmtrain/optim/_function.py index ee4b04a7..f04f9ca0 100644 --- a/bmtrain/optim/_function.py +++ b/bmtrain/optim/_function.py @@ -1,4 +1,3 @@ - from .. 
import C import torch CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda @@ -11,8 +10,8 @@ def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.T assert m_fp32.is_contiguous(), "m_fp32 must be contiguous" assert v_fp32.is_contiguous(), "v_fp32 must be contiguous" assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor" - assert param_fp16.dtype == torch.float16, "param_fp16 must be float16 tensor" - assert g_fp16.dtype == torch.float16, "g_fp16 must be float16 tensor" + assert param_fp16.dtype == torch.float16 or param_fp16.dtype == torch.bfloat16, "param_fp16 must be float16/bfloat16 tensor" + assert g_fp16.dtype == torch.float16 or g_fp16.dtype == torch.bfloat16, "g_fp16 must be float16/bfloat16 tensor" assert m_fp32.dtype == torch.float32, "m_fp32 must be float32 tensor" assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor" assert param_fp32.device == torch.device("cpu"), "param_fp32 must be a cpu tensor" @@ -26,22 +25,28 @@ def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.T assert param_fp32.numel() == v_fp32.numel(), "param_fp32 and v_fp32 must have the same number of elements" bias_correction1 = 1 - beta1 ** step bias_correction2 = 1 - beta2 ** step - C.adam_cpu_launcher( - param_fp32.numel(), - param_fp32.data_ptr(), - param_fp16.data_ptr(), - g_fp16.data_ptr(), - m_fp32.data_ptr(), - v_fp32.data_ptr(), - beta1, beta2, - eps, lr, - scale, - weight_decay, - bias_correction1, - bias_correction2, - ) + if g_fp16.dtype == torch.float16: + launcher = C.adam_cpu_fp16_launcher + elif g_fp16.dtype == torch.bfloat16: + if not C.is_bf16_supported(): + raise NotImplementedError(f"bfloat16 is not supported on current GPU") + launcher = C.adam_cpu_bf16_launcher + launcher( + param_fp32.numel(), + param_fp32.data_ptr(), + param_fp16.data_ptr(), + g_fp16.data_ptr(), + m_fp32.data_ptr(), + v_fp32.data_ptr(), + beta1, beta2, + eps, lr, + scale, + weight_decay, + bias_correction1, + bias_correction2, + ) -def adam(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tensor, m_fp16: torch.Tensor, +def adam_fp16(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tensor, m_fp16: torch.Tensor, v_fp32: torch.Tensor, beta1: float, beta2: float, eps: float, lr: float, scale: float, weight_decay: float, step: int) -> None: assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda" @@ -61,7 +66,7 @@ def adam(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tenso bias_correction1 = 1 - beta1 ** step bias_correction2 = 1 - beta2 ** step stream = torch.cuda.current_stream().cuda_stream - C.adam_launcher( + C.adam_fp16_launcher( param_fp32.numel(), param_fp32.data_ptr(), param_fp16.data_ptr(), @@ -76,3 +81,41 @@ def adam(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tenso bias_correction2, stream ) + +def adam_bf16(param_fp32: torch.Tensor, param_bf16: torch.Tensor, g_bf16: torch.Tensor, m_fp32: torch.Tensor, + v_fp32: torch.Tensor, beta1: float, beta2: float, eps: float, lr: float, scale: float, + weight_decay: float, step: int) -> None: + assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda" + assert CHECK_INPUT(param_bf16), "param_bf16 must be contiguous and on cuda" + assert CHECK_INPUT(g_bf16), "g_bf16 must be contiguous and on cuda" + assert CHECK_INPUT(m_fp32), "m_fp32 must be contiguous and on cuda" + assert CHECK_INPUT(v_fp32), "v_fp32 must be contiguous and on cuda" + assert param_fp32.dtype == 
torch.float32, "param_fp32 must be float32 tensor" + assert param_bf16.dtype == torch.bfloat16, "param_fp16 must be float16 tensor" + assert g_bf16.dtype == torch.bfloat16, "g_bf16 must be bfloat16 tensor" + assert m_fp32.dtype == torch.float32, "m_fp32 must be bfloat16 tensor" + assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor" + assert param_fp32.numel() == param_bf16.numel(), "param_fp32 and param_bf16 must have the same number of elements" + assert param_fp32.numel() == g_bf16.numel(), "param_fp32 and g_fp16 must have the same number of elements" + assert param_fp32.numel() == m_fp32.numel(), "param_fp32 and m_m_fp32 must have the same number of elements" + assert param_fp32.numel() == v_fp32.numel(), "param_fp32 and v_fp32 must have the same number of elements" + bias_correction1 = 1 - beta1 ** step + bias_correction2 = 1 - beta2 ** step + stream = torch.cuda.current_stream().cuda_stream + if not C.is_bf16_supported(): + raise NotImplementedError(f"bfloat16 is not supported on current GPU") + C.adam_bf16_launcher( + param_fp32.numel(), + param_fp32.data_ptr(), + param_bf16.data_ptr(), + g_bf16.data_ptr(), + m_fp32.data_ptr(), + v_fp32.data_ptr(), + beta1, beta2, + eps, lr, + scale, + weight_decay, + bias_correction1, + bias_correction2, + stream + ) \ No newline at end of file diff --git a/bmtrain/optim/adam.py b/bmtrain/optim/adam.py index b63a4f51..a3138980 100644 --- a/bmtrain/optim/adam.py +++ b/bmtrain/optim/adam.py @@ -40,8 +40,9 @@ def _on_justify_scale(self, old_scale, new_scale): if p in self.state: state = self.state[p] if len(state) > 0: - state['exp_avg'] *= delta - state['exp_avg_sq'] *= delta + if p.dtype == torch.float16: + state['exp_avg'] *= delta + state['exp_avg_sq'] *= delta @torch.no_grad() def step(self, closure=None, scale=1): @@ -63,45 +64,32 @@ def step(self, closure=None, scale=1): if p.grad is not None and p.requires_grad: if p.grad.is_sparse: raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') - if p.dtype not in [torch.float16, torch.float32]: - raise RuntimeError('Adam only supports fp32 or fp16 gradients') + if p.dtype not in [torch.float32, torch.half, torch.bfloat16]: + raise RuntimeError('Adam only supports fp32, fp16 and bf16 gradients') state = self.state[p] # Lazy state initialization if len(state) == 0: state['step'] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros(p.size(), dtype=p.dtype, device=p.device) # on device + if p.dtype == torch.float16: + state['exp_avg'] = torch.zeros(p.size(), dtype=torch.float16, device=p.device) # on device + else: + state['exp_avg'] = torch.zeros(p.size(), dtype=torch.float32, device=p.device) # on device # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros(p.size(), dtype=torch.float32, device=p.device) # on device - - if p.dtype == torch.half: + state['exp_avg_sq'] = torch.zeros(p.size(), dtype=torch.float32, device=p.device)# on device + + if p.dtype != torch.float32: state['_param_fp32'] = torch.empty(p.size(), dtype=torch.float32, device=p.device) # on device state['_param_fp32'].copy_(p) # update the steps for each param group update - state['step'] += 1 - if ('maximize' in group) and (group['maximize'] is True): grad = -p.grad else: grad = p.grad - if p.dtype == torch.half: - F.adam( - state["_param_fp32"], # fp32 - p, # fp16 - grad, # fp16 - state['exp_avg'], # fp16: m - state["exp_avg_sq"], # fp32: v - group['betas'][0], group['betas'][1], - group['eps'], - 0.0 if 
state["step"] <= self._hold_steps else group['lr'], - scale, - group['weight_decay'], - state['step'] - ) - else: + if p.dtype == torch.float32: other_kwargs = {} if 'maximize' in inspect.signature(torch.optim._functional.adam).parameters: other_kwargs['maximize'] = False @@ -116,11 +104,30 @@ def step(self, closure=None, scale=1): amsgrad=False, beta1=group['betas'][0], beta2=group['betas'][1], - lr=0.0 if state["step"] <= self._hold_steps else group['lr'], + lr=0.0 if state["step"] < self._hold_steps else group['lr'], weight_decay=group['weight_decay'], eps=group['eps'], **other_kwargs ) + state['step'] += 1 + else: + f = F.adam_fp16 if p.dtype == torch.float16 else F.adam_bf16 + state['step'] += 1 + f( + state["_param_fp32"], # fp32 + p, # fp16 + grad, # fp16 + state['exp_avg'], # fp16: m + state["exp_avg_sq"], # fp32: v + group['betas'][0], group['betas'][1], + group['eps'], + 0.0 if state["step"] < self._hold_steps else group['lr'], + scale, + group['weight_decay'], + state['step'] + ) + + return loss @@ -159,11 +166,11 @@ def load_state_dict(self, state_dict: dict) -> None: if k in id_map: param = id_map[k] - if param.dtype == torch.half and "_param_fp32" not in v: + if param.dtype != torch.float32 and "_param_fp32" not in v: v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device=param.device) v["_param_fp32"].copy_(param) - for name, dtype in [("exp_avg", param.dtype), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]: + for name, dtype in [("exp_avg", torch.float16 if param.dtype == torch.float16 else torch.float32), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]: if name in v: v[name] = v[name].to(param.device).to(dtype) diff --git a/bmtrain/optim/adam_offload.py b/bmtrain/optim/adam_offload.py index e33219bf..5b34a287 100644 --- a/bmtrain/optim/adam_offload.py +++ b/bmtrain/optim/adam_offload.py @@ -54,8 +54,8 @@ def step(self, closure=None, scale=1): if p.grad is not None and p.requires_grad: if p.grad.is_sparse: raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') - if p.dtype not in [torch.float16, torch.float32]: - raise RuntimeError('Adam only supports fp32 or fp16 gradients') + if p.dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise RuntimeError('Adam only supports fp32, fp16 and bf16 gradients') state = self.state[p] # Lazy state initialization @@ -66,19 +66,19 @@ def step(self, closure=None, scale=1): # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros(p.size(), dtype=torch.float32, device="cpu") # on host - if p.dtype == torch.half: - state['_param_fp32'] = torch.empty(p.size(), dtype=torch.float32, device="cpu") # on host + if p.dtype == torch.float32: + state['_param_fp32'] = torch.empty(p.size(), dtype=torch.float32, pin_memory=True) # on host state['_param_fp32'].copy_(p) # placeholder - state["_param_fp16"] = torch.empty(p.size(), dtype=torch.float16, pin_memory=True) # on host - state["_grad_fp16"] = torch.empty(p.size(), dtype=torch.float16, pin_memory=True) # on host + state["_grad_fp32"] = torch.empty(p.size(), dtype=torch.float32, pin_memory=True) # on host else: - state['_param_fp32'] = torch.empty(p.size(), dtype=torch.float32, pin_memory=True) # on host + state['_param_fp32'] = torch.empty(p.size(), dtype=torch.float32, device="cpu") # on host state['_param_fp32'].copy_(p) # placeholder - state["_grad_fp32"] = torch.empty(p.size(), dtype=torch.float32, pin_memory=True) # on host + state["_param_fp16"] = 
torch.empty(p.size(), dtype=p.dtype, pin_memory=True) # on host + state["_grad_fp16"] = torch.empty(p.size(), dtype=p.dtype, pin_memory=True) # on host if p not in self._events: self._events[p] = torch.cuda.Event() @@ -87,39 +87,18 @@ def step(self, closure=None, scale=1): # transfer parameters to host asynchronously for param, state, event, _, _, _, _, _ in update_params: - if param.dtype == torch.half: - state["_grad_fp16"].copy_(param.grad, non_blocking=True) - else: + if param.dtype == torch.float32: state["_grad_fp32"].copy_(param.grad, non_blocking=True) + else: + state["_grad_fp16"].copy_(param.grad, non_blocking=True) torch.cuda.current_stream().record_event(event) for param, state, event, beta1, beta2, eps, lr, weight_decay in update_params: # wait for transfer to host event.synchronize() - state["step"] += 1 - # update parameters - if param.dtype == torch.half: - if ('maximize' in group) and (group['maximize'] is True): - grad = -state["_grad_fp16"] - else: - grad = state["_grad_fp16"] - F.adam_cpu( - state["_param_fp32"].view(-1), - state["_param_fp16"].view(-1), - grad.view(-1), - state["exp_avg"].view(-1), - state["exp_avg_sq"].view(-1), - beta1, beta2, - eps, 0.0 if state["step"] <= self._hold_steps else lr, - scale, - weight_decay, - state["step"] - ) - # transfer parameters back to device asynchronously - param.copy_(state["_param_fp16"], non_blocking=True) - else: + if param.dtype == torch.float32: state["_grad_fp32"].mul_(1.0 / scale) if ('maximize' in group) and (group['maximize'] is True): grad = -state["_grad_fp32"] @@ -139,13 +118,35 @@ def step(self, closure=None, scale=1): amsgrad=False, beta1=beta1, beta2=beta2, - lr=0.0 if state["step"] <= self._hold_steps else lr, + lr=0.0 if state["step"] < self._hold_steps else lr, weight_decay=weight_decay, eps=eps, **other_kwargs ) # transfer parameters back to device asynchronously param.copy_(state["_param_fp32"], non_blocking=True) + state["step"] += 1 + else: + state["step"] += 1 + if ('maximize' in group) and (group['maximize'] is True): + grad = -state["_grad_fp16"] + else: + grad = state["_grad_fp16"] + F.adam_cpu( + state["_param_fp32"].view(-1), + state["_param_fp16"].view(-1), + grad.view(-1), + state["exp_avg"].view(-1), + state["exp_avg_sq"].view(-1), + beta1, beta2, + eps, 0.0 if state["step"] < self._hold_steps else lr, + scale, + weight_decay, + state["step"] + ) + # transfer parameters back to device asynchronously + param.copy_(state["_param_fp16"], non_blocking=True) + return loss @@ -193,15 +194,14 @@ def load_state_dict(self, state_dict: dict) -> None: v[name] = v[name].to("cpu").to(dtype) state[param] = v - if param.dtype == torch.half: + if param.dtype == torch.float32: + state[param]["_param_fp32"] = state[param]["_param_fp32"].pin_memory() # on host # initialize placeholders - state[param]["_param_fp16"] = torch.empty(param.size(), dtype=torch.float16, pin_memory=True) # on host - state[param]["_grad_fp16"] = torch.empty(param.size(), dtype=torch.float16, pin_memory=True) # on host + state[param]["_grad_fp32"] = torch.empty(param.size(), dtype=torch.float32, pin_memory=True) # on host else: - state[param]["_param_fp32"] = state[param]["_param_fp32"].pin_memory() - # initialize placeholders - state[param]["_grad_fp32"] = torch.empty(param.size(), dtype=torch.float32, pin_memory=True) # on host + state[param]["_param_fp16"] = torch.empty(param.size(), dtype=torch.float16, pin_memory=True) # on host + state[param]["_grad_fp16"] = torch.empty(param.size(), dtype=torch.float16, pin_memory=True) # on host 
else: state[k] = v @@ -254,5 +254,4 @@ def cut_states(state): #TODO zero_grad(set_to_none=True) makes optimizer crashed, maybe the reason of grad accu def zero_grad(self, set_to_none: bool = False): - super().zero_grad(set_to_none=set_to_none) - + super().zero_grad(set_to_none=set_to_none) \ No newline at end of file diff --git a/bmtrain/optim/optim_manager.py b/bmtrain/optim/optim_manager.py index 78ad15f8..9b7a3120 100644 --- a/bmtrain/optim/optim_manager.py +++ b/bmtrain/optim/optim_manager.py @@ -11,9 +11,9 @@ def check_overflow(param_groups): has_inf_or_nan = torch.zeros(1, dtype=torch.uint8, device="cuda")[0] for group in param_groups: for p in group['params']: - if p.grad is not None and p.dtype == torch.half: # TODO support other types - has_inf_nan(p.grad, has_inf_or_nan) - + if p.grad is not None: + if p.dtype != torch.float: + has_inf_nan(p.grad, has_inf_or_nan) if "comm" in config: nccl.allReduce(has_inf_or_nan.storage(), has_inf_or_nan.storage(), "max", config["comm"]) diff --git a/csrc/bind.cpp b/csrc/bind.cpp index 8324ba52..73f79a61 100644 --- a/csrc/bind.cpp +++ b/csrc/bind.cpp @@ -1,13 +1,17 @@ #include "include/bind.hpp" PYBIND11_MODULE(C, m) { - m.def("has_nan_inf_launcher",&has_nan_inf_launcher,"has nan inf"); - m.def("adam_launcher", &adam_launcher, "adam function cpu"); - m.def("adam_cpu_launcher", &adam_cpu_launcher, "adam function cpu"); - m.def("cross_entropy_forward_launcher", &cross_entropy_forward_launcher, "cross entropy forward"); - m.def("cross_entropy_backward_launcher", &cross_entropy_backward_launcher, "cross entropy backward"); - m.def("cross_entropy_forward_inplace_launcher", &cross_entropy_forward_inplace_launcher, "cross entropy forward inplace"); - m.def("cross_entropy_backward_inplace_launcher", &cross_entropy_backward_inplace_launcher, "cross entropy backward inplace"); + m.def("is_bf16_supported",&is_bf16_supported,"whether bf16 supported"); + m.def("has_nan_inf_fp16_launcher",&has_nan_inf_fp16_launcher,"has nan inf"); + m.def("has_nan_inf_bf16_launcher",&has_nan_inf_bf16_launcher,"has nan inf bf16"); + m.def("adam_fp16_launcher", &adam_fp16_launcher, "adam function cpu"); + m.def("adam_bf16_launcher", &adam_bf16_launcher, "adam function cpu"); + m.def("adam_cpu_fp16_launcher", &adam_cpu_fp16_launcher, "adam function cpu"); + m.def("adam_cpu_bf16_launcher", &adam_cpu_bf16_launcher, "adam function cpu"); + m.def("cross_entropy_forward_fp16_launcher", &cross_entropy_forward_fp16_launcher, "cross entropy forward"); + m.def("cross_entropy_forward_bf16_launcher", &cross_entropy_forward_bf16_launcher, "cross entropy forward"); + m.def("cross_entropy_backward_inplace_fp16_launcher", &cross_entropy_backward_inplace_fp16_launcher, "cross entropy backward inplace"); + m.def("cross_entropy_backward_inplace_bf16_launcher", &cross_entropy_backward_inplace_bf16_launcher, "cross entropy backward inplace"); m.def("ncclGetUniqueId", &pyNCCLGetUniqueID, "nccl get unique ID"); m.def("ncclCommInitRank", &pyNCCLCommInitRank, "nccl init rank"); m.def("ncclCommDestroy", &pyNCCLCommDestroy, "nccl delete rank"); diff --git a/csrc/cuda/adam_cuda.cu b/csrc/cuda/adam_cuda.cu index 0ab55934..0510ac12 100644 --- a/csrc/cuda/adam_cuda.cu +++ b/csrc/cuda/adam_cuda.cu @@ -1,5 +1,7 @@ -#include #include +#include +#include +#include "bfloat16.cuh" namespace { // blocks , threads @@ -8,8 +10,8 @@ __global__ void adam_fp32_accum( const half *g, // (n) half *m, // (n) float *v, // (n) - float* param, // (n) - half* param_h, // (n) + float *param, // (n) + half *param_h, // (n) float 
beta1, float beta2, float eps, @@ -33,9 +35,45 @@ __global__ void adam_fp32_accum( m[col] = __float2half(local_m); } } + +__global__ void adam_fp32_accum_bf16( + int32_t n, + const std::uintptr_t g_ptr, // (n) + float *m, // (n) + float *v, // (n) + float *param, // (n) + std::uintptr_t param_h_ptr, // (n) + float beta1, + float beta2, + float eps, + float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2 +) { +#ifdef BF16_SUPPORT + const __nv_bfloat16* g = reinterpret_cast(g_ptr); + __nv_bfloat16* param_h = reinterpret_cast<__nv_bfloat16*>(param_h_ptr); + int32_t col = blockIdx.x * blockDim.x + threadIdx.x; + if (col < n) { + float local_g = __bfloat162float(g[col]) / scale; // real_g + float local_m = beta1 * m[col] + (1 - beta1) * local_g; // real_m + float local_v = beta2 * v[col] + (1 - beta2) * local_g * local_g; // real_v + float local_p = param[col]; + local_p = local_p - lr * local_m / bias_correction1 / (sqrtf(local_v / bias_correction2) + eps) - lr * weight_decay * local_p; + + param_h[col] = __float2bfloat16(local_p); + param[col] = local_p; + v[col] = local_v; + m[col] = local_m; + } +#endif +} + } -void adam_launcher( +void adam_fp16_launcher( int n, std::uintptr_t param_fp32, std::uintptr_t param_fp16, @@ -60,4 +98,29 @@ void adam_launcher( dim3 block_size = dim3(threads, 1, 1); dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1); adam_fp32_accum<<(stream)>>>(n, g_ptr, m_ptr, v_fp32_ptr, param_fp32_ptr, param_h_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); -} \ No newline at end of file +} + +void adam_bf16_launcher( + int n, + std::uintptr_t param_fp32, + std::uintptr_t param_bf16, + std::uintptr_t g_bf16, + std::uintptr_t m_fp32, + std::uintptr_t v_fp32, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2, + uintptr_t stream +) { + if (n <= 0) return; + auto m_ptr = reinterpret_cast(m_fp32); + auto param_fp32_ptr = reinterpret_cast(param_fp32); + auto v_fp32_ptr = reinterpret_cast(v_fp32); + int32_t threads = 1024; + dim3 block_size = dim3(threads, 1, 1); + dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1); + adam_fp32_accum_bf16<<(stream)>>>(n, g_bf16, m_ptr, v_fp32_ptr, param_fp32_ptr, param_bf16, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); +} diff --git a/csrc/cuda/bfloat16.cuh b/csrc/cuda/bfloat16.cuh new file mode 100644 index 00000000..564d8bec --- /dev/null +++ b/csrc/cuda/bfloat16.cuh @@ -0,0 +1,5 @@ +#include +#if defined(__CUDACC__) && CUDA_VERSION >= 11000 +#include +#define BF16_SUPPORT +#endif \ No newline at end of file diff --git a/csrc/cuda/cross_entropy.cu b/csrc/cuda/cross_entropy.cu index c0b742ac..bdd5a08e 100644 --- a/csrc/cuda/cross_entropy.cu +++ b/csrc/cuda/cross_entropy.cu @@ -1,11 +1,12 @@ -#include #include "reduce.cuh" #include -#include +#include +#include +#include "bfloat16.cuh" namespace { // blocks , threads<1024> -__global__ void cross_entropy_forward( +__global__ void cross_entropy_forward_fp16( int64_t n, const half *input, // (m, n) const int32_t *target, // (m) @@ -42,12 +43,11 @@ __global__ void cross_entropy_forward( } // blocks , threads<1024> -__global__ void cross_entropy_backward( +__global__ void cross_entropy_backward_inplace_fp16( int64_t n, const float *grad_output, // (m) const int32_t *target, // (m) - const half *softmax, // (m, n) - half *grad_input, // (m, n) + half *x, // (m, n) int32_t ignore_index ) { 
int64_t base_idx = blockIdx.x * n; @@ -56,83 +56,99 @@ __global__ void cross_entropy_backward( if (t == ignore_index) { half v = __float2half(0.); for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { - grad_input[base_idx + i] = v; + x[base_idx + i] = v; } } else { half v = __float2half(grad_output[blockIdx.x]); + __syncthreads(); for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { - grad_input[base_idx + i] = i==t ? __hsub(__hmul(softmax[base_idx + i], v), v) : __hmul(softmax[base_idx + i], v); + x[base_idx + i] = i==t ? __hsub(__hmul(x[base_idx + i], v), v) : __hmul(x[base_idx + i], v); } } } // blocks , threads<1024> -__global__ void cross_entropy_forward_inplace( +__global__ void cross_entropy_forward_bf16( int64_t n, - half *x, // (m, n) + std::uintptr_t input_ptr, // (m, n) const int32_t *target, // (m) + std::uintptr_t softmax_ptr, // (m, n) float *output, // (m) int32_t ignore_index ) { +#ifdef BF16_SUPPORT + __nv_bfloat16* input = reinterpret_cast<__nv_bfloat16*>(input_ptr); + __nv_bfloat16* softmax = reinterpret_cast<__nv_bfloat16*>(softmax_ptr); int64_t base_idx = blockIdx.x * n; float local_max = -INFINITY; for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { - local_max = fmaxf(__half2float(x[base_idx + i]), local_max); + local_max = fmaxf(__bfloat162float(input[base_idx + i]), local_max); } + local_max = fmaxf(block_allreduce_max(local_max), -1e6); float local_sum = 0; for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { - local_sum += expf(__half2float(x[base_idx + i]) - local_max); + local_sum += expf(__bfloat162float(input[base_idx + i]) - local_max); } local_sum = block_allreduce_sum(local_sum) + 1e-10; // avoid nan + + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + softmax[base_idx + i] = __float2bfloat16( expf(__bfloat162float(input[base_idx + i]) - local_max) / local_sum ); + } if (threadIdx.x == 0) { if (target[blockIdx.x] != ignore_index) { - output[blockIdx.x] = -__half2float(x[base_idx + target[blockIdx.x]]) + local_max + logf(local_sum); + output[blockIdx.x] = -__bfloat162float(input[base_idx + target[blockIdx.x]]) + local_max + logf(local_sum); } else { output[blockIdx.x] = 0; } } - - __syncthreads(); - - for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { - x[base_idx + i] = __float2half( expf(__half2float(x[base_idx + i]) - local_max) / local_sum ); - } +#endif } // blocks , threads<1024> -__global__ void cross_entropy_backward_inplace( +__global__ void cross_entropy_backward_inplace_bf16( int64_t n, const float *grad_output, // (m) const int32_t *target, // (m) - half *x, // (m, n) + std::uintptr_t x_ptr, // (m, n) int32_t ignore_index ) { +#ifdef BF16_SUPPORT + __nv_bfloat16* x = reinterpret_cast<__nv_bfloat16*>(x_ptr); int64_t base_idx = blockIdx.x * n; int32_t t = target[blockIdx.x]; if (t == ignore_index) { - half v = __float2half(0.); + __nv_bfloat16 v = __float2bfloat16(0.); for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { x[base_idx + i] = v; } } else { - half v = __float2half(grad_output[blockIdx.x]); + #if __CUDA_ARCH__ >= 800 + __nv_bfloat16 v = __float2bfloat16(grad_output[blockIdx.x]); __syncthreads(); for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { x[base_idx + i] = i==t ? __hsub(__hmul(x[base_idx + i], v), v) : __hmul(x[base_idx + i], v); } + #else + float v = grad_output[blockIdx.x]; + __syncthreads(); + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + x[base_idx + i] = __float2bfloat16(i==t ? 
(__bfloat162float(x[base_idx + i])*v)-v : __bfloat162float(x[base_idx + i])*v); + } + #endif } +#endif } } -void cross_entropy_forward_launcher( +void cross_entropy_forward_fp16_launcher( int32_t m, int32_t n, std::uintptr_t input, std::uintptr_t target, @@ -146,48 +162,40 @@ void cross_entropy_forward_launcher( auto softmax_ptr = reinterpret_cast(softmax); auto output_ptr = reinterpret_cast(output); int32_t threads = 1024; - cross_entropy_forward<<(stream)>>>(n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index); + cross_entropy_forward_fp16<<(stream)>>>(n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index); } -void cross_entropy_backward_launcher( +void cross_entropy_backward_inplace_fp16_launcher( int32_t m, int32_t n, std::uintptr_t grad_output, std::uintptr_t target, - std::uintptr_t softmax, - std::uintptr_t grad_input, + std::uintptr_t x, int32_t ignore_index, std::uintptr_t stream ) { - // auto output_ptr = grad_output.data_ptr(); auto output_ptr = reinterpret_cast(grad_output); - // auto target_ptr = target.data_ptr(); auto target_ptr = reinterpret_cast(target); - auto softmax_ptr = reinterpret_cast(softmax); - auto input_ptr = reinterpret_cast(grad_input); + auto x_ptr = reinterpret_cast(x); int32_t threads = 1024; - cross_entropy_backward<<(stream)>>>(n, output_ptr, target_ptr, softmax_ptr, input_ptr, ignore_index); + cross_entropy_backward_inplace_fp16<<(stream)>>>(n, output_ptr, target_ptr, x_ptr, ignore_index); } -void cross_entropy_forward_inplace_launcher( +void cross_entropy_forward_bf16_launcher( int32_t m, int32_t n, - std::uintptr_t x, + std::uintptr_t input, std::uintptr_t target, + std::uintptr_t softmax, std::uintptr_t output, int32_t ignore_index, std::uintptr_t stream ) { - // auto x_ptr = reinterpret_cast(x.data_ptr()); - auto x_ptr = reinterpret_cast(x); - // auto target_ptr = target.data_ptr(); auto target_ptr = reinterpret_cast(target); - // auto output_ptr = output.data_ptr(); auto output_ptr = reinterpret_cast(output); int32_t threads = 1024; - // auto stream = at::cuda::getCurrentCUDAStream(); - cross_entropy_forward_inplace<<(stream)>>>(n, x_ptr, target_ptr, output_ptr, ignore_index); + cross_entropy_forward_bf16<<(stream)>>>(n, input, target_ptr, softmax, output_ptr, ignore_index); } -void cross_entropy_backward_inplace_launcher( +void cross_entropy_backward_inplace_bf16_launcher( int32_t m, int32_t n, std::uintptr_t grad_output, std::uintptr_t target, @@ -195,13 +203,8 @@ void cross_entropy_backward_inplace_launcher( int32_t ignore_index, std::uintptr_t stream ) { - // auto output_ptr = grad_output.data_ptr(); auto output_ptr = reinterpret_cast(grad_output); - // auto target_ptr = target.data_ptr(); auto target_ptr = reinterpret_cast(target); - // auto x_ptr = reinterpret_cast(x.data_ptr()); - auto x_ptr = reinterpret_cast(x); int32_t threads = 1024; - // auto stream = at::cuda::getCurrentCUDAStream(); - cross_entropy_backward_inplace<<(stream)>>>(n, output_ptr, target_ptr, x_ptr, ignore_index); + cross_entropy_backward_inplace_bf16<<(stream)>>>(n, output_ptr, target_ptr, x, ignore_index); } \ No newline at end of file diff --git a/csrc/cuda/has_inf_nan.cu b/csrc/cuda/has_inf_nan.cu index b0e906ff..32bc5a5f 100644 --- a/csrc/cuda/has_inf_nan.cu +++ b/csrc/cuda/has_inf_nan.cu @@ -1,16 +1,18 @@ -#include #include +#include +#include +#include +#include "bfloat16.cuh" namespace{ __inline__ __device__ bool isnan_(half v) { -#if __CUDA_ARCH__ >= 700 || __CUDA_ARCH__ == 600 - return __hisnan(v); -#else - - return !__heq(v, v); -#endif + 
#if __CUDA_ARCH__ >= 700 || __CUDA_ARCH__ == 600 + return __hisnan(v); + #else + return !__heq(v, v); + #endif } - + __inline__ __device__ int8_t warpReduceAny(int8_t x) { for (int offset = warpSize/2; offset > 0; offset /= 2) x |= __shfl_down_sync(0xFFFFFFFF, x, offset); @@ -30,7 +32,7 @@ __inline__ __device__ float blockReduceAny(int8_t x) { } // grid , thread<1024> -__global__ void bmt_has_nan_inf_1( +__global__ void bmt_has_nan_inf_fp16( int32_t n, const half* inp, // (n,) uint8_t* mid // (1024,) @@ -53,7 +55,7 @@ __global__ void bmt_has_nan_inf_1( } // grid <1>, thread<1024> -__global__ void bmt_has_nan_inf_2( +__global__ void bmt_has_nan_inf_reduce( const uint8_t* mid, // (1024,) uint8_t* out ) { @@ -64,9 +66,39 @@ __global__ void bmt_has_nan_inf_2( } } +// grid , thread<1024> +__global__ void bmt_has_nan_inf_bf16( + int32_t n, + const uintptr_t inp, // (n,) + uint8_t* mid // (1024,) +) { +#ifdef BF16_SUPPORT + const __nv_bfloat16* bf_inp = reinterpret_cast(inp); + int32_t gid = blockIdx.x * blockDim.x + threadIdx.x; + int32_t span = blockDim.x * gridDim.x; + + int8_t r = 0; + for (int i = gid; i < n; i += span) { + __nv_bfloat16 v = bf_inp[i]; + #if __CUDA_ARCH__ >= 800 + if (__hisinf(v) || __hisnan(v)) { + #else + if (isinf(__bfloat162float(v)) || isnan(__bfloat162float(v))) { + #endif + r = 1; + break; + } + } + r = blockReduceAny(r); + if (threadIdx.x == 0) { + mid[blockIdx.x] = r; + } +#endif +} + } -void has_nan_inf_launcher( +void has_nan_inf_fp16_launcher( int32_t n, std::uintptr_t g_fp16, std::uintptr_t mid, @@ -82,6 +114,32 @@ void has_nan_inf_launcher( dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1); dim3 clamp_grid_size = dim3(min((n + threads - 1) / threads, 1024), 1, 1); - bmt_has_nan_inf_1<<(stream)>>>(n, g_ptr, mid_ptr); - bmt_has_nan_inf_2<<<1, block_size, 0, reinterpret_cast(stream)>>>(mid_ptr, out_ptr); + bmt_has_nan_inf_fp16<<(stream)>>>(n, g_ptr, mid_ptr); + bmt_has_nan_inf_reduce<<<1, block_size, 0, reinterpret_cast(stream)>>>(mid_ptr, out_ptr); +} + +void has_nan_inf_bf16_launcher( + int32_t n, + std::uintptr_t g_bf16, + std::uintptr_t mid, + std::uintptr_t out, + std::uintptr_t stream +) { + if (n <= 0) return; + auto mid_ptr = reinterpret_cast(mid); + auto out_ptr = reinterpret_cast(out); + int32_t threads = 1024; + dim3 block_size = dim3(threads, 1, 1); + dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1); + dim3 clamp_grid_size = dim3(min((n + threads - 1) / threads, 1024), 1, 1); + + bmt_has_nan_inf_bf16<<(stream)>>>(n, g_bf16, mid_ptr); + bmt_has_nan_inf_reduce<<<1, block_size, 0, reinterpret_cast(stream)>>>(mid_ptr, out_ptr); +} + +int is_bf16_supported() { +#ifdef BF16_SUPPORT + return 1; +#endif + return 0; } \ No newline at end of file diff --git a/csrc/cuda/reduce.cuh b/csrc/cuda/reduce.cuh index 095e8593..a9c4c15b 100644 --- a/csrc/cuda/reduce.cuh +++ b/csrc/cuda/reduce.cuh @@ -1,5 +1,3 @@ -#include - namespace { const int WARP_SZ = 32; diff --git a/csrc/include/adam_cpu.hpp b/csrc/include/adam_cpu.hpp index d95cf637..1e497bb3 100644 --- a/csrc/include/adam_cpu.hpp +++ b/csrc/include/adam_cpu.hpp @@ -2,10 +2,9 @@ #include #include #include -#include +#include #include -#include -#include +#include #include #include #include @@ -69,8 +68,7 @@ inline void parallel_for(int64_t begin, int64_t end, int64_t grain_size, const F } } - - +// fp32 -> fp16 inline uint16_t fp16_ieee_from_fp32_value(float f) { // const float scale_to_inf = 0x1.0p+112f; // const float scale_to_zero = 0x1.0p-110f; @@ -84,45 +82,55 @@ inline uint16_t 
fp16_ieee_from_fp32_value(float f) { float base = (fabsf(f) * scale_to_inf) * scale_to_zero; - const uint32_t w = (uint32_t)fp32_to_bits(f); - const uint32_t shl1_w = w + w; - const uint32_t sign = w & UINT32_C(0x80000000); - uint32_t bias = shl1_w & UINT32_C(0xFF000000); - if (bias < UINT32_C(0x71000000)) { - bias = UINT32_C(0x71000000); - } + const uint32_t w = (uint32_t)fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } - base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; - const uint32_t bits = (uint32_t)fp32_to_bits(base); - const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); - const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); - const uint32_t nonsign = exp_bits + mantissa_bits; - return static_cast( - (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign) - ); - } + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = (uint32_t)fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return static_cast( + (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign) + ); +} +// fp16 -> fp32 inline float fp16_ieee_to_fp32_value(uint16_t h) { - const uint32_t w = (uint32_t)h << 16; const uint32_t sign = w & UINT32_C(0x80000000); const uint32_t two_w = w + w; const uint32_t exp_offset = UINT32_C(0xE0) << 23; const float exp_scale = 0x1.0p-112f; - const float normalized_value = - fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; const uint32_t magic_mask = UINT32_C(126) << 23; const float magic_bias = 0.5f; - const float denormalized_value = - fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; const uint32_t denormalized_cutoff = UINT32_C(1) << 27; const uint32_t result = - sign | (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) - : fp32_to_bits(normalized_value)); - return fp32_from_bits(result); + sign | (two_w < denormalized_cutoff ? 
fp32_to_bits(denormalized_value) + : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +// fp32 -> bf16 +inline uint16_t bf16_from_fp32_value(float f){ + return *reinterpret_cast(&f) >> 16; +} + +// bf16 -> fp32 +inline float bf16_to_fp32_value(uint16_t h){ + uint32_t src = h; + src <<= 16; + return *reinterpret_cast(&src); } void adam_cpu_0( @@ -141,23 +149,58 @@ void adam_cpu_0( ){ int64_t span = 1; parallel_for(0, n, 0, [&](int64_t start, int64_t end) { - for (int64_t j = start; j < end; j += span) { - for (int64_t i = j; i < end; i++) { - float g = fp16_ieee_to_fp32_value(g_fp16_ptr[i]) / scale; - float m = m_fp32_ptr[i]; - float v = v_fp32_ptr[i]; - float p = param_fp32_ptr[i]; - m = beta1 * m + (1 - beta1) * g; - v = beta2 * v + (1 - beta2) * g * g; - p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p; - param_fp32_ptr[i] = p; - param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p); - m_fp32_ptr[i] = m; - v_fp32_ptr[i] = v; - } + for (int64_t j = start; j < end; j += span) { + for (int64_t i = j; i < end; i++) { + float g = fp16_ieee_to_fp32_value(g_fp16_ptr[i]) / scale; + float m = m_fp32_ptr[i]; + float v = v_fp32_ptr[i]; + float p = param_fp32_ptr[i]; + m = beta1 * m + (1 - beta1) * g; + v = beta2 * v + (1 - beta2) * g * g; + p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p; + param_fp32_ptr[i] = p; + param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p); + m_fp32_ptr[i] = m; + v_fp32_ptr[i] = v; + } break; // must break here } - }); + }); +} + +void adam_cpu_bf16_0( + int64_t n, + float* param_fp32_ptr, + uint16_t* param_bf16_ptr, + uint16_t* g_bf16_ptr, + float* m_fp32_ptr, + float* v_fp32_ptr, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2 +){ + int64_t span = 1; + parallel_for(0, n, 0, [&](int64_t start, int64_t end) { + for (int64_t j = start; j < end; j += span) { + for (int64_t i = j; i < end; i++) { + float g = bf16_to_fp32_value(g_bf16_ptr[i]) / scale; + float m = m_fp32_ptr[i]; + float v = v_fp32_ptr[i]; + float p = param_fp32_ptr[i]; + m = beta1 * m + (1 - beta1) * g; + v = beta2 * v + (1 - beta2) * g * g; + p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p; + param_fp32_ptr[i] = p; + param_bf16_ptr[i] = bf16_from_fp32_value(p); + m_fp32_ptr[i] = m; + v_fp32_ptr[i] = v; + } + break; // must break here + } + }); } static void __attribute__ ((__target__ ("avx,fma,f16c"))) adam_cpu_1( @@ -223,7 +266,8 @@ static void __attribute__ ((__target__ ("avx,fma,f16c"))) adam_cpu_1( _mm256_storeu_ps(&m_fp32_ptr[j], m); _mm256_storeu_ps(&v_fp32_ptr[j], v); } - }}); + } + }); } static void __attribute__ ((__target__ ("avx512f"))) adam_cpu_2( @@ -293,13 +337,10 @@ static void __attribute__ ((__target__ ("avx512f"))) adam_cpu_2( _mm512_storeu_ps(&v_fp32_ptr[j], v); } } - }); + }); } - - - -void adam_cpu_launcher( +void adam_cpu_fp16_launcher( int64_t n, std::uintptr_t param_fp32, std::uintptr_t param_fp16, @@ -329,4 +370,24 @@ void adam_cpu_launcher( } } - +void adam_cpu_bf16_launcher( + int64_t n, + std::uintptr_t param_fp32, + std::uintptr_t param_bf16, + std::uintptr_t g_bf16, + std::uintptr_t m_fp32, + std::uintptr_t v_fp32, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2 +) { + auto param_fp32_ptr = reinterpret_cast(param_fp32); + auto m_fp32_ptr = 
reinterpret_cast(m_fp32); + auto v_fp32_ptr = reinterpret_cast(v_fp32); + auto param_bf16_ptr = reinterpret_cast(param_bf16); + auto g_bf16_ptr = reinterpret_cast(g_bf16); + adam_cpu_bf16_0(n, param_fp32_ptr, param_bf16_ptr, g_bf16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); +} diff --git a/csrc/include/bind.hpp b/csrc/include/bind.hpp index 0929de91..94d6af95 100644 --- a/csrc/include/bind.hpp +++ b/csrc/include/bind.hpp @@ -1,19 +1,22 @@ -#include +#include #include "nccl.hpp" #include "adam_cpu.hpp" -void has_nan_inf_launcher(int32_t n,std::uintptr_t g_fp16,std::uintptr_t mid,std::uintptr_t out,std::uintptr_t stream); +int is_bf16_supported(); -void cross_entropy_backward_launcher( +void has_nan_inf_fp16_launcher(int32_t n,std::uintptr_t g_fp16,std::uintptr_t mid,std::uintptr_t out,std::uintptr_t stream); +void has_nan_inf_bf16_launcher(int32_t n,std::uintptr_t g_bf16,std::uintptr_t mid,std::uintptr_t out,std::uintptr_t stream); + +void cross_entropy_forward_fp16_launcher( int32_t m, int32_t n, - std::uintptr_t grad_output, + std::uintptr_t input, std::uintptr_t target, std::uintptr_t softmax, - std::uintptr_t grad_input, + std::uintptr_t output, int32_t ignore_index, std::uintptr_t stream ); -void cross_entropy_backward_inplace_launcher( +void cross_entropy_backward_inplace_fp16_launcher( int32_t m, int32_t n, std::uintptr_t grad_output, std::uintptr_t target, @@ -21,24 +24,24 @@ void cross_entropy_backward_inplace_launcher( int32_t ignore_index, std::uintptr_t stream ); - void cross_entropy_forward_inplace_launcher( +void cross_entropy_forward_bf16_launcher( int32_t m, int32_t n, - std::uintptr_t x, + std::uintptr_t input, std::uintptr_t target, + std::uintptr_t softmax, std::uintptr_t output, int32_t ignore_index, std::uintptr_t stream ); -void cross_entropy_forward_launcher( +void cross_entropy_backward_inplace_bf16_launcher( int32_t m, int32_t n, - std::uintptr_t input, + std::uintptr_t grad_output, std::uintptr_t target, - std::uintptr_t softmax, - std::uintptr_t output, + std::uintptr_t x, int32_t ignore_index, std::uintptr_t stream ); -void adam_launcher( +void adam_fp16_launcher( int n, std::uintptr_t param_fp32, std::uintptr_t param_fp16, @@ -52,4 +55,19 @@ void adam_launcher( float bias_correction1, float bias_correction2, uintptr_t stream +); +void adam_bf16_launcher( + int n, + std::uintptr_t param_fp32, + std::uintptr_t param_bf16, + std::uintptr_t g_bf16, + std::uintptr_t m_fp32, + std::uintptr_t v_fp32, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2, + uintptr_t stream ); \ No newline at end of file diff --git a/setup.py b/setup.py index 2bbb55d8..1bac037e 100644 --- a/setup.py +++ b/setup.py @@ -83,6 +83,7 @@ def build_extension(self, ext): if os.path.exists(build_temp): shutil.rmtree(build_temp) os.makedirs(build_temp) + cmake_args += ["-DPython_ROOT_DIR=" + os.path.dirname(os.path.dirname(sys.executable))] subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp) subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=build_temp) diff --git a/tests/test_all.py b/tests/test_all.py index fc9ab3e9..07be4077 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -16,6 +16,8 @@ ("dropout", 1), ("loss_func", 1), + ("optim", 1), + ("multi_return", 2), ("middle_hidden", 4), ("other_hidden", 4), diff --git a/tests/test_has_inf_nan.py b/tests/test_has_inf_nan.py index fda85515..93ac8118 
100644 --- a/tests/test_has_inf_nan.py +++ b/tests/test_has_inf_nan.py @@ -1,5 +1,4 @@ from utils import * - import torch import bmtrain.loss._function as F import random @@ -9,9 +8,9 @@ def check(x, v): F.has_inf_nan(x, out) assert_eq(out.item(), v) -def test_main(): +def test_main(dtype): for i in list(range(1, 100)) + [1000]*10 + [10000]*10 + [100000]*10 + [1000000]*10: - x = torch.rand((i,)).half().cuda() + x = torch.rand((i,)).to(dtype).cuda() check(x, 0) p = random.randint(0, i-1) x[p] = x[p] / 0 @@ -27,6 +26,12 @@ def test_main(): p = random.randint(0, i-1) x[p] = x[p] / 0 check(x, 1) + print("That's right") if __name__ == "__main__": - test_main() + test_main(torch.float16) + print("==============================================================================") + try: + test_main(torch.bfloat16) + except NotImplementedError: + pass diff --git a/tests/test_loss_func.py b/tests/test_loss_func.py index a76be5f4..a448b6d1 100644 --- a/tests/test_loss_func.py +++ b/tests/test_loss_func.py @@ -27,59 +27,53 @@ def check(x, tgt, loss_func1, loss_func2, bigmodel=None): loss_2, grad_2 = run(x, tgt, loss_func2, bigmodel=bigmodel, use_float=True) assert_eq(grad_1.isnan().sum(), 0) assert_eq(grad_2.isnan().sum(), 0) + print(f"{(loss_1 - loss_2).abs().item():.6f} {(grad_1 - grad_2).abs().max().item():.6f}") assert_lt((loss_1 - loss_2).abs().item(), 1e-5) - assert_lt((grad_1 - grad_2).abs().max().item(), 1e-2) + assert_lt((grad_1 - grad_2).abs().max().item(), 1e-1) -def test_simple(): +def test_simple(dtype): loss_func1 = bmt.loss.FusedCrossEntropy() loss_func2 = torch.nn.CrossEntropyLoss() N = 32 * 512 for i in range(1, 10): C = i * 10 - x = torch.randn(N, C).cuda().half() + x = torch.randn(N, C).cuda().to(dtype) tgt = torch.randint(0, C, (N,)).cuda().long() check(x, tgt, loss_func1, loss_func2) for i in range(1, 10): C = i * 100 - x = torch.randn(N, C).cuda().half() + x = torch.randn(N, C).cuda().to(dtype) tgt = torch.randint(0, C, (N,)).cuda().long() check(x, tgt, loss_func1, loss_func2) for i in range(1, 31): C = i * 1000 - x = torch.randn(N, C).cuda().half() + x = torch.randn(N, C).cuda().to(dtype) tgt = torch.randint(0, C, (N,)).cuda().long() check(x, tgt, loss_func1, loss_func2) -def test_other(): +def test_other(dtype): N = 32 * 512 for i in range(1, 11): C = i * 10 weight = [i+1 for i in range(C)] random.shuffle(weight) weight = torch.tensor(weight, device="cuda") - loss_func1 = bmt.loss.FusedCrossEntropy(weight=weight.clone().half()) + loss_func1 = bmt.loss.FusedCrossEntropy(weight=weight.clone().to(dtype)) loss_func2 = torch.nn.CrossEntropyLoss(weight=weight.clone().float()) - x = torch.randn(N, C).cuda().half() + x = torch.randn(N, C).cuda().to(dtype) tgt = torch.randint(0, C, (N,)).cuda().long() mask = torch.randint(0, 2, (N,)).cuda().bool() tgt[mask] = -100 check(x, tgt, loss_func1, loss_func2) -def test_inplace(): - loss_func1 = bmt.loss.FusedCrossEntropy(inplace=True) - loss_func2 = torch.nn.CrossEntropyLoss() - N = 32 * 512 - - for i in range(1, 11): - C = i * 10 - bigmodel = torch.nn.Linear(5, C).cuda().half() - x = torch.randn(N, 5).cuda().half() - tgt = torch.randint(0, C, (N,)).cuda().long() - check(x, tgt, loss_func1, loss_func2, bigmodel=bigmodel) - if __name__ == "__main__": - test_other() - test_inplace() - test_simple() \ No newline at end of file + test_other(torch.float16) + test_simple(torch.float16) + print("==============================================================================") + try: + test_other(torch.bfloat16) + test_simple(torch.bfloat16) + except 
NotImplementedError: + pass \ No newline at end of file diff --git a/tests/test_nccl_backward.py b/tests/test_nccl_backward.py index 3dcd0560..5e7b22d8 100644 --- a/tests/test_nccl_backward.py +++ b/tests/test_nccl_backward.py @@ -3,8 +3,8 @@ import bmtrain as bmt import torch -def test_main(): - x = torch.full((1,), bmt.rank() + 1, dtype=torch.half, device="cuda").requires_grad_(True) +def test_main(dtype): + x = torch.full((1,), bmt.rank() + 1, dtype=dtype, device="cuda").requires_grad_(True) y = bmt.distributed.all_reduce(x, "prod").view(-1) loss = (y * y).sum() / 2 loss.backward() @@ -17,4 +17,5 @@ def test_main(): if __name__ == "__main__": bmt.init_distributed() - test_main() \ No newline at end of file + test_main(torch.half) + test_main(torch.bfloat16) \ No newline at end of file diff --git a/tests/test_optim.py b/tests/test_optim.py index fdb64521..0aca8c31 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -1,3 +1,4 @@ +from utils import * import torch import bmtrain as bmt from bmtrain import optim @@ -5,55 +6,89 @@ class TestModule(torch.nn.Module): def __init__(self): super(TestModule, self).__init__() - self.fc1 = torch.nn.Linear(128, 128) + self.fc1 = torch.nn.Linear(128, 128, bias=False) self.fc2 = torch.nn.Linear(128, 128) self.fc3 = torch.nn.Linear(128, 128) self.fc4 = torch.nn.Linear(128, 128) self.fc5 = torch.nn.Linear(128, 128) self.param = torch.nn.Parameter(torch.empty(1237)) -def main(): - # FIXME: this test script is not working + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + x = self.fc4(x) + x = self.fc5(x) + return x + +def main(dtype): model1 = TestModule() model2 = TestModule() model3 = TestModule() + model4 = TestModule() + model5 = TestModule() state_dict = model1.state_dict() for kw in state_dict.keys(): state_dict[kw] = torch.randn_like(state_dict[kw]) - + model1.load_state_dict(state_dict) model2.load_state_dict(state_dict) model3.load_state_dict(state_dict) + model4.load_state_dict(state_dict) + model5.load_state_dict(state_dict) - model1 = model1.cuda() - model2 = model2.cuda() + model1 = model1.cuda().to(dtype) + model2 = model2.cuda().to(dtype) model3 = model3.cuda() - - opt1 = optim.AdamOptimizer(model1.parameters(), weight_decay=1e-3) - opt2 = optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-3) - opt3 = torch.optim.Adam(model3.parameters(), weight_decay=1e-3) + model4 = model4.cuda() + model5 = model5.cuda() + + opt1 = bmt.optim.AdamOptimizer(model1.parameters(), lr=1) + opt2 = bmt.optim.AdamOffloadOptimizer(model2.parameters(), lr=1) + opt3 = torch.optim.Adam(model3.parameters(), lr=1) + opt4 = bmt.optim.AdamOptimizer(model4.parameters(), lr=1) + opt5 = bmt.optim.AdamOffloadOptimizer(model5.parameters(), lr=1) + + optim_manager = bmt.optim.OptimManager(loss_scale=4) + optim_manager.add_optimizer(opt1) + optim_manager.add_optimizer(opt2) + optim_manager.add_optimizer(opt3) + optim_manager.add_optimizer(opt4) + optim_manager.add_optimizer(opt5) for _ in range(100): - opt1.zero_grad() - opt2.zero_grad() - opt3.zero_grad() + optim_manager.zero_grad() - for p1, p2, p3 in zip(model1.parameters(), model2.parameters(), model3.parameters()): + for p1, p2, p3, p4, p5 in zip(model1.parameters(), model2.parameters(), model3.parameters(), model4.parameters(), model5.parameters()): grad = torch.randn_like(p1) - p1.grad = grad - p2.grad = grad + p1.grad = grad.to(dtype) + p2.grad = grad.to(dtype) p3.grad = grad.float() - - opt1.step() - opt2.step() - opt3.step() + p4.grad = grad.float() + p5.grad = 
grad.float() + + optim_manager.step() + torch.cuda.synchronize() - for p1, p2, p3 in zip(model1.parameters(), model2.parameters(), model3.parameters()): - diff1 = torch.abs(p1 - p2).max() - diff2 = torch.abs(p1 - p3).max() - diff3 = torch.abs(p2 - p3).max() - print(diff1, diff2, diff3) + for p1, p2, p3, p4, p5 in zip(model1.parameters(), model2.parameters(), model3.parameters(), model4.parameters(), model5.parameters()): + diff1 = torch.abs(p1 - p2).max().item() + diff2 = torch.abs(p1 - p3).max().item() + diff3 = torch.abs(p2 - p3).max().item() + diff4 = torch.abs(p3 - p4).max().item() + diff5 = torch.abs(p3 - p5).max().item() + print(f"{diff1:.6f}, {diff2:.6f}, {diff3:.6f}, {diff4:.6f}, {diff5:.6f}") + assert_lt(diff1, 1) + assert_lt(diff2, 1) + assert_lt(diff3, 1) + assert_eq(diff4, 0) + assert_lt(diff5, 0.00001) if __name__ == "__main__": - main() + bmt.init_distributed() + main(torch.float16) + print("==============================================================================") + try: + main(torch.bfloat16) + except NotImplementedError: + pass diff --git a/tests/test_optim_state.py b/tests/test_optim_state.py index cef06734..16833b42 100644 --- a/tests/test_optim_state.py +++ b/tests/test_optim_state.py @@ -10,7 +10,7 @@ def __init__(self): self.fc1 = bmt.BMTrainModelWrapper(torch.nn.Linear(768, 3072)) self.fc2 = bmt.BMTrainModelWrapper(torch.nn.Linear(3072, 1024)) self.fc3 = bmt.BMTrainModelWrapper(torch.nn.Linear(1024, 768)) - self.param = bmt.DistributedParameter(torch.empty(1237)) + self.param = bmt.DistributedParameter(torch.zeros(1237)) self.fc4 = bmt.BMTrainModelWrapper(torch.nn.Linear(768, 300)) self.fc5 = bmt.BMTrainModelWrapper(torch.nn.Linear(300, 768)) self.dropout = torch.nn.Dropout(0.0)
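
---

The new bfloat16 path can be exercised end to end much like the updated tests above. Below is a minimal sketch, not part of the patch, assuming a bf16-capable GPU, a single-process `bmt.init_distributed()` run, the usual `OptimManager` loss-scaling workflow from BMTrain's examples, and that the compiled extension is importable as `bmtrain.C`; shapes and hyperparameters are illustrative only.

```python
import torch
import bmtrain as bmt
from bmtrain import C  # compiled extension bound in csrc/bind.cpp (assumed importable this way)

bmt.init_distributed()

# Fall back to fp16 when the bf16 kernels are unavailable, mirroring the
# NotImplementedError guards added in bmtrain/loss/_function.py and bmtrain/optim/_function.py.
dtype = torch.bfloat16 if C.is_bf16_supported() else torch.float16

model = torch.nn.Linear(128, 128).cuda().to(dtype)
loss_func = bmt.loss.FusedCrossEntropy()          # dispatches to the *_bf16_launcher kernels for bf16 input
optimizer = bmt.optim.AdamOptimizer(model.parameters(), lr=1e-3)

optim_manager = bmt.optim.OptimManager(loss_scale=4)
optim_manager.add_optimizer(optimizer)

x = torch.randn(32, 128, device="cuda", dtype=dtype)
target = torch.randint(0, 128, (32,), device="cuda")

optim_manager.zero_grad()
loss = loss_func(model(x), target)
optim_manager.backward(loss)  # scaled backward; the overflow check now reaches bf16 grads via has_inf_nan
optim_manager.step()          # bf16 parameters take the F.adam_bf16 path added in bmtrain/optim/_function.py
```

The same pattern with `bmt.optim.AdamOffloadOptimizer` routes through the new `adam_cpu_bf16_launcher`, which is what `tests/test_optim.py` compares against the fp32 `torch.optim.Adam` reference.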