Add Bf16 Support (#136)
Achazwl authored Aug 29, 2023
1 parent df43d6d commit 38461bc
Showing 22 changed files with 613 additions and 367 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -150,4 +150,5 @@ log
.vscode

!bmtrain/dist
tests/test_log.txt
tests/test_log.txt
tests/*.opt
74 changes: 26 additions & 48 deletions bmtrain/loss/_function.py
@@ -2,26 +2,28 @@
from .. import C
import torch
CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda
def has_inf_nan(g_fp16: torch.Tensor, out: torch.Tensor) -> None:
assert g_fp16.dtype == torch.float16, "g_fp16 must be a half tensor"
def has_inf_nan(g_half: torch.Tensor, out: torch.Tensor) -> None:
assert out.dtype == torch.uint8, "out must be a uint8 tensor"
assert CHECK_INPUT(g_fp16), "g_fp16 must be contiguous and on cuda"
assert CHECK_INPUT(g_half), "g_half must be contiguous and on cuda"
assert CHECK_INPUT(out), "out must be contiguous and on cuda"
mid = torch.zeros(1024, device=out.device, dtype=out.dtype)
stream = torch.cuda.current_stream().cuda_stream
C.has_nan_inf_launcher(g_fp16.numel(), g_fp16.data_ptr(), mid.data_ptr(), out.data_ptr(), stream)


if g_half.dtype == torch.float16:
C.has_nan_inf_fp16_launcher(g_half.numel(), g_half.data_ptr(), mid.data_ptr(), out.data_ptr(), stream)
elif g_half.dtype == torch.bfloat16:
if not C.is_bf16_supported():
raise NotImplementedError("bfloat16 is not supported on the current GPU")
C.has_nan_inf_bf16_launcher(g_half.numel(), g_half.data_ptr(), mid.data_ptr(), out.data_ptr(), stream)
else:
raise ValueError(f"has_inf_nan not supported for dtype {g_half.dtype}")

def cross_entropy_forward(m: int, n: int, input: torch.Tensor, target: torch.Tensor,
softmax: torch.Tensor, output: torch.Tensor, ignore_index: int) -> None:
CHECK_INPUT(input)
CHECK_INPUT(target)
CHECK_INPUT(softmax)
CHECK_INPUT(output)
assert input.dtype == torch.float16, "input must be a half tensor"
assert target.dtype == torch.int32, "target must be an int tensor"
assert softmax.dtype == torch.float16, "softmax must be a half tensor"
assert output.dtype == torch.float32, "output must be a float tensor"
assert input.numel() == softmax.numel(), "input and softmax must have the same number of elements"
assert target.numel() == output.numel(), "target and output must have the same number of elements"
@@ -30,43 +32,14 @@ def cross_entropy_forward(m: int, n: int, input: torch.Tensor, target: torch.Tensor,
softmax_ptr = softmax.data_ptr()
output_ptr = output.data_ptr()
cuda_stream = torch.cuda.current_stream().cuda_stream
C.cross_entropy_forward_launcher(m, n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index, cuda_stream)

def cross_entropy_backward(m: int, n: int, grad_output: torch.Tensor, target: torch.Tensor,
softmax: torch.Tensor, grad_input: torch.Tensor, ignore_index: int) -> None:
CHECK_INPUT(grad_output)
CHECK_INPUT(target)
CHECK_INPUT(softmax)
CHECK_INPUT(grad_input)
assert grad_output.dtype == torch.float32, "grad_output must be a float tensor"
assert target.dtype == torch.int32, "target must be an int tensor"
assert softmax.dtype == torch.float16, "softmax must be a half tensor"
assert grad_input.dtype == torch.float16, "grad_input must be a half tensor"
assert grad_input.numel() == softmax.numel(), "grad_input and softmax must have the same number of elements"
assert target.numel() == grad_output.numel(), "target and grad_output must have the same number of elements"
grad_output_ptr = grad_output.data_ptr()
target_ptr = target.data_ptr()
softmax_ptr = softmax.data_ptr()
grad_input_ptr = grad_input.data_ptr()
cuda_stream = torch.cuda.current_stream().cuda_stream
C.cross_entropy_backward_launcher(m, n, grad_output_ptr, target_ptr, softmax_ptr, grad_input_ptr, ignore_index, cuda_stream)

def cross_entropy_forward_inplace(m: int, n: int, x: torch.Tensor, target: torch.Tensor,
output: torch.Tensor, ignore_index: int) -> None:
CHECK_INPUT(x)
CHECK_INPUT(target)
CHECK_INPUT(output)
assert x.dtype == torch.float16, "x must be a half tensor"
assert target.dtype == torch.int32, "target must be an int tensor"
assert output.dtype == torch.float32, "output must be a float tensor"
assert target.numel() == output.numel(), "target and output must have the same number of elements"
cuda_stream = torch.cuda.current_stream().cuda_stream
x_ptr = x.data_ptr()
output_ptr = output.data_ptr()
target_ptr = target.data_ptr()
output_ptr = output.data_ptr()

C.cross_entropy_forward_inplace_launcher(m, n, x_ptr, target_ptr, output_ptr, ignore_index, cuda_stream)
if input.dtype == torch.float16:
C.cross_entropy_forward_fp16_launcher(m, n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index, cuda_stream)
elif input.dtype == torch.bfloat16:
if not C.is_bf16_supported():
raise NotImplementedError("bfloat16 is not supported on the current GPU")
C.cross_entropy_forward_bf16_launcher(m, n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index, cuda_stream)
else:
raise ValueError(f"cross_entropy_forward not supported for dtype {input.dtype}")

def cross_entropy_backward_inplace(m: int, n: int, grad_output: torch.Tensor, target: torch.Tensor,
x: torch.Tensor, ignore_index: int) -> None:
@@ -75,12 +48,17 @@ def cross_entropy_backward_inplace(m: int, n: int, grad_output: torch.Tensor, target: torch.Tensor,
CHECK_INPUT(x)
assert grad_output.dtype == torch.float32, "grad_output must be a float tensor"
assert target.dtype == torch.int32, "target must be an int tensor"
assert x.dtype == torch.float16, "x must be a half tensor"
assert target.numel() == grad_output.numel(), "target and grad_output must have the same number of elements"
cuda_stream = torch.cuda.current_stream().cuda_stream
grad_output_ptr = grad_output.data_ptr()
target_ptr = target.data_ptr()
x_ptr = x.data_ptr()

C.cross_entropy_backward_inplace_launcher(m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream)

if x.dtype == torch.float16:
C.cross_entropy_backward_inplace_fp16_launcher(m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream)
elif x.dtype == torch.bfloat16:
if not C.is_bf16_supported():
raise NotImplementedError("bfloat16 is not supported on the current GPU")
C.cross_entropy_backward_inplace_bf16_launcher(m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream)
else:
raise ValueError(f"cross_entropy_backward not supported for dtype {input.dtype}")
30 changes: 0 additions & 30 deletions bmtrain/loss/cross_entropy.py
@@ -36,36 +36,6 @@ def backward(ctx, grad_output : torch.Tensor):
)
return (softmax, None, None)

class OpFusedCrossEntropyInplace(torch.autograd.Function):
"""
CrossEntropy dim = 1
"""
@staticmethod
def forward(ctx, x : torch.Tensor, target : torch.Tensor, ignore_index: int):
assert x.ndim == 2
out = torch.empty(x.size(0), device=x.device, dtype=torch.float)
F.cross_entropy_forward_inplace(
x.size(0), x.size(1),
x, target,
out,
ignore_index,
) # x is inplace modify to softmax result
ctx.ignore_index = ignore_index
ctx.save_for_backward(x, target)
return out # float tensor

@staticmethod
def backward(ctx, grad_output : torch.Tensor):
grad_output = grad_output.contiguous()
softmax, target = ctx.saved_tensors
F.cross_entropy_backward_inplace(
softmax.size(0), softmax.size(1),
grad_output, target,
softmax,
ctx.ignore_index,
) # softmax is inplace modify to grad_input
return (softmax, None, None)

class FusedCrossEntropy(torch.nn.Module):
r"""This criterion computes the cross entropy loss between input and target.
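
With the in-place op class removed, FusedCrossEntropy remains the public entry point, and the forward helper shown above now dispatches on fp16 or bf16 logits. A hedged usage sketch; the constructor arguments are assumed to mirror torch.nn.CrossEntropyLoss, and the int32 targets match the dtype the underlying kernels assert:

import torch
import bmtrain as bmt

loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)        # assumed constructor signature

logits = torch.randn(8, 32000, device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, 32000, (8,), device="cuda", dtype=torch.int32)

loss = loss_func(logits, labels)                                  # scalar fp32 loss
loss.backward()                                                   # gradient w.r.t. logits comes from the fused backward kernel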
81 changes: 62 additions & 19 deletions bmtrain/optim/_function.py
@@ -1,4 +1,3 @@

from .. import C
import torch
CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda
@@ -11,8 +10,8 @@ def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tensor,
assert m_fp32.is_contiguous(), "m_fp32 must be contiguous"
assert v_fp32.is_contiguous(), "v_fp32 must be contiguous"
assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor"
assert param_fp16.dtype == torch.float16, "param_fp16 must be float16 tensor"
assert g_fp16.dtype == torch.float16, "g_fp16 must be float16 tensor"
assert param_fp16.dtype == torch.float16 or param_fp16.dtype == torch.bfloat16, "param_fp16 must be float16/bfloat16 tensor"
assert g_fp16.dtype == torch.float16 or g_fp16.dtype == torch.bfloat16, "g_fp16 must be float16/bfloat16 tensor"
assert m_fp32.dtype == torch.float32, "m_fp32 must be float32 tensor"
assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor"
assert param_fp32.device == torch.device("cpu"), "param_fp32 must be a cpu tensor"
@@ -26,22 +25,28 @@ def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tensor,
assert param_fp32.numel() == v_fp32.numel(), "param_fp32 and v_fp32 must have the same number of elements"
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
C.adam_cpu_launcher(
param_fp32.numel(),
param_fp32.data_ptr(),
param_fp16.data_ptr(),
g_fp16.data_ptr(),
m_fp32.data_ptr(),
v_fp32.data_ptr(),
beta1, beta2,
eps, lr,
scale,
weight_decay,
bias_correction1,
bias_correction2,
)
if g_fp16.dtype == torch.float16:
launcher = C.adam_cpu_fp16_launcher
elif g_fp16.dtype == torch.bfloat16:
if not C.is_bf16_supported():
raise NotImplementedError("bfloat16 is not supported on the current GPU")
launcher = C.adam_cpu_bf16_launcher
launcher(
param_fp32.numel(),
param_fp32.data_ptr(),
param_fp16.data_ptr(),
g_fp16.data_ptr(),
m_fp32.data_ptr(),
v_fp32.data_ptr(),
beta1, beta2,
eps, lr,
scale,
weight_decay,
bias_correction1,
bias_correction2,
)

def adam(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tensor, m_fp16: torch.Tensor,
def adam_fp16(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tensor, m_fp16: torch.Tensor,
v_fp32: torch.Tensor, beta1: float, beta2: float, eps: float, lr: float, scale: float,
weight_decay: float, step: int) -> None:
assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda"
@@ -61,7 +66,7 @@ def adam(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tensor,
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
stream = torch.cuda.current_stream().cuda_stream
C.adam_launcher(
C.adam_fp16_launcher(
param_fp32.numel(),
param_fp32.data_ptr(),
param_fp16.data_ptr(),
@@ -76,3 +81,41 @@ def adam(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tensor,
bias_correction2,
stream
)

def adam_bf16(param_fp32: torch.Tensor, param_bf16: torch.Tensor, g_bf16: torch.Tensor, m_fp32: torch.Tensor,
v_fp32: torch.Tensor, beta1: float, beta2: float, eps: float, lr: float, scale: float,
weight_decay: float, step: int) -> None:
assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda"
assert CHECK_INPUT(param_bf16), "param_bf16 must be contiguous and on cuda"
assert CHECK_INPUT(g_bf16), "g_bf16 must be contiguous and on cuda"
assert CHECK_INPUT(m_fp32), "m_fp32 must be contiguous and on cuda"
assert CHECK_INPUT(v_fp32), "v_fp32 must be contiguous and on cuda"
assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor"
assert param_bf16.dtype == torch.bfloat16, "param_bf16 must be bfloat16 tensor"
assert g_bf16.dtype == torch.bfloat16, "g_bf16 must be bfloat16 tensor"
assert m_fp32.dtype == torch.float32, "m_fp32 must be float32 tensor"
assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor"
assert param_fp32.numel() == param_bf16.numel(), "param_fp32 and param_bf16 must have the same number of elements"
assert param_fp32.numel() == g_bf16.numel(), "param_fp32 and g_bf16 must have the same number of elements"
assert param_fp32.numel() == m_fp32.numel(), "param_fp32 and m_fp32 must have the same number of elements"
assert param_fp32.numel() == v_fp32.numel(), "param_fp32 and v_fp32 must have the same number of elements"
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
stream = torch.cuda.current_stream().cuda_stream
if not C.is_bf16_supported():
raise NotImplementedError("bfloat16 is not supported on the current GPU")
C.adam_bf16_launcher(
param_fp32.numel(),
param_fp32.data_ptr(),
param_bf16.data_ptr(),
g_bf16.data_ptr(),
m_fp32.data_ptr(),
v_fp32.data_ptr(),
beta1, beta2,
eps, lr,
scale,
weight_decay,
bias_correction1,
bias_correction2,
stream
)
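
The device-side optimizer now exposes separate fp16 and bf16 entry points with identical argument orders. A minimal sketch of the dispatch a caller might perform, mirroring what adam.py below does; all tensors are assumed to be contiguous CUDA tensors of matching numel, with the fp16 path keeping exp_avg in fp16 and the bf16 path keeping it in fp32:

import torch
from bmtrain.optim import _function as F

def adam_step(param_fp32, param_half, grad_half, exp_avg, exp_avg_sq,
              beta1=0.9, beta2=0.999, eps=1e-8, lr=1e-3,
              scale=1.0, weight_decay=0.0, step=1):
    # Hypothetical wrapper, not part of the commit: pick the launcher by gradient dtype.
    if grad_half.dtype == torch.float16:
        F.adam_fp16(param_fp32, param_half, grad_half, exp_avg, exp_avg_sq,
                    beta1, beta2, eps, lr, scale, weight_decay, step)
    elif grad_half.dtype == torch.bfloat16:
        F.adam_bf16(param_fp32, param_half, grad_half, exp_avg, exp_avg_sq,
                    beta1, beta2, eps, lr, scale, weight_decay, step)
    else:
        raise ValueError(f"unsupported dtype {grad_half.dtype}")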
63 changes: 35 additions & 28 deletions bmtrain/optim/adam.py
@@ -40,8 +40,9 @@ def _on_justify_scale(self, old_scale, new_scale):
if p in self.state:
state = self.state[p]
if len(state) > 0:
state['exp_avg'] *= delta
state['exp_avg_sq'] *= delta
if p.dtype == torch.float16:
state['exp_avg'] *= delta
state['exp_avg_sq'] *= delta

@torch.no_grad()
def step(self, closure=None, scale=1):
@@ -63,45 +64,32 @@ def step(self, closure=None, scale=1):
if p.grad is not None and p.requires_grad:
if p.grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
if p.dtype not in [torch.float16, torch.float32]:
raise RuntimeError('Adam only supports fp32 or fp16 gradients')
if p.dtype not in [torch.float32, torch.half, torch.bfloat16]:
raise RuntimeError('Adam only supports fp32, fp16 and bf16 gradients')

state = self.state[p]
# Lazy state initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros(p.size(), dtype=p.dtype, device=p.device) # on device
if p.dtype == torch.float16:
state['exp_avg'] = torch.zeros(p.size(), dtype=torch.float16, device=p.device) # on device
else:
state['exp_avg'] = torch.zeros(p.size(), dtype=torch.float32, device=p.device) # on device
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros(p.size(), dtype=torch.float32, device=p.device) # on device

if p.dtype == torch.half:
state['exp_avg_sq'] = torch.zeros(p.size(), dtype=torch.float32, device=p.device)# on device
if p.dtype != torch.float32:
state['_param_fp32'] = torch.empty(p.size(), dtype=torch.float32, device=p.device) # on device
state['_param_fp32'].copy_(p)

# update the steps for each param group update
state['step'] += 1

if ('maximize' in group) and (group['maximize'] is True):
grad = -p.grad
else:
grad = p.grad

if p.dtype == torch.half:
F.adam(
state["_param_fp32"], # fp32
p, # fp16
grad, # fp16
state['exp_avg'], # fp16: m
state["exp_avg_sq"], # fp32: v
group['betas'][0], group['betas'][1],
group['eps'],
0.0 if state["step"] <= self._hold_steps else group['lr'],
scale,
group['weight_decay'],
state['step']
)
else:
if p.dtype == torch.float32:
other_kwargs = {}
if 'maximize' in inspect.signature(torch.optim._functional.adam).parameters:
other_kwargs['maximize'] = False
@@ -116,11 +104,30 @@ def step(self, closure=None, scale=1):
amsgrad=False,
beta1=group['betas'][0],
beta2=group['betas'][1],
lr=0.0 if state["step"] <= self._hold_steps else group['lr'],
lr=0.0 if state["step"] < self._hold_steps else group['lr'],
weight_decay=group['weight_decay'],
eps=group['eps'],
**other_kwargs
)
state['step'] += 1
else:
f = F.adam_fp16 if p.dtype == torch.float16 else F.adam_bf16
state['step'] += 1
f(
state["_param_fp32"], # fp32
p, # fp16
grad, # fp16
state['exp_avg'], # fp16: m
state["exp_avg_sq"], # fp32: v
group['betas'][0], group['betas'][1],
group['eps'],
0.0 if state["step"] < self._hold_steps else group['lr'],
scale,
group['weight_decay'],
state['step']
)



return loss

@@ -159,11 +166,11 @@ def load_state_dict(self, state_dict: dict) -> None:
if k in id_map:
param = id_map[k]

if param.dtype == torch.half and "_param_fp32" not in v:
if param.dtype != torch.float32 and "_param_fp32" not in v:
v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device=param.device)
v["_param_fp32"].copy_(param)

for name, dtype in [("exp_avg", param.dtype), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]:
for name, dtype in [("exp_avg", torch.float16 if param.dtype == torch.float16 else torch.float32), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]:
if name in v:
v[name] = v[name].to(param.device).to(dtype)
