Add Bf16 Support #136

Merged (31 commits, Aug 29, 2023)
Changes from all commits (31 commits)
3210632
modified: bmtrain/optim/adam.py
Aug 2, 2023
4e4d32a
modified: add bf16 in adam.py
Aug 2, 2023
ef038ae
modified: bmtrain/optim/_function.py
Aug 2, 2023
a1e0e42
modified: bmtrain/optim/_function.py
JerryYin777 Aug 2, 2023
efe3cdc
modified: add bf16.h to csrc/cuda/cross_entropy.cu
JerryYin777 Aug 3, 2023
6125f39
modified: bmtrain/optim/_function.py
JerryYin777 Aug 3, 2023
e2fdcc5
modified: add adam_fp32_accum_bf16 function
JerryYin777 Aug 3, 2023
00eee39
modified: add adam_fp32_accum_bf16 function
JerryYin777 Aug 3, 2023
2eb6d72
modified: add adam_fp32_accum_bf16 function
JerryYin777 Aug 3, 2023
6dfd739
modified: bmtrain/loss/_function.py
JerryYin777 Aug 7, 2023
8a6f686
modified: add bf16 to is_nan_inf()
JerryYin777 Aug 7, 2023
0b22ed1
FIX: csrc/bind.cpp
JerryYin777 Aug 7, 2023
f8885cf
modified: tests/test_has_inf_nan.py
JerryYin777 Aug 7, 2023
9a9d526
modified: bmtrain/optim/_function.py
JerryYin777 Aug 7, 2023
77c3585
add pybind11 in Update other_requirements.txt
JerryYin777 Aug 7, 2023
5cc3611
Update adam_cuda.cu
JerryYin777 Aug 7, 2023
870c613
Merge branch 'OpenBMB:main' into main
JerryYin777 Aug 8, 2023
2b414b8
Update test_optim_bf16.py
JerryYin777 Aug 8, 2023
40441f8
Merge branch 'OpenBMB:main' into main
JerryYin777 Aug 8, 2023
c5f7e49
FIX
JerryYin777 Aug 9, 2023
148ed85
Merge branch 'main' of https://github.com/JerryYin777/BMTrain
JerryYin777 Aug 9, 2023
55839be
refactor has_inf_nan_bf16
Achazwl Aug 10, 2023
4d44dce
refactor has_inf_nan_bf16
Achazwl Aug 10, 2023
d008d7f
refactor adam_offload
Achazwl Aug 10, 2023
145d90f
refactor adam
Achazwl Aug 10, 2023
da14e7e
fix adam_cuda
Achazwl Aug 11, 2023
21d5218
test nccl
Achazwl Aug 11, 2023
d29b22a
fix optim state test
Achazwl Aug 11, 2023
4aeb638
fix cuda version if; refactor cross_entropy
Achazwl Aug 11, 2023
53d8f8b
fix bf16 not support info
Achazwl Aug 11, 2023
eef5542
Merge branch 'dev' into bf16
Achazwl Aug 29, 2023
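Before the file-by-file changes, a short self-contained illustration (not part of this PR) of the numerical trade-off that motivates bf16 support: bfloat16 keeps float32's 8-bit exponent but only 7 mantissa bits, so values that overflow in fp16 stay finite in bf16, while bf16 rounds away small increments that fp16 can still resolve. This is why bf16 training commonly runs without aggressive loss scaling, although the entry points added in this PR still accept a scale argument.

import torch

big = torch.tensor(70000.0)
print(big.to(torch.float16))    # inf: fp16 overflows above ~65504
print(big.to(torch.bfloat16))   # tensor(70144., dtype=torch.bfloat16): coarse but finite

bump = torch.tensor(1.0 + 1e-3)
print(bump.to(torch.float16))   # ~1.0010: fp16 keeps 10 mantissa bits
print(bump.to(torch.bfloat16))  # 1.0: bf16's 7 mantissa bits round the increment away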
3 changes: 2 additions & 1 deletion .gitignore
@@ -150,4 +150,5 @@ log
.vscode

!bmtrain/dist
tests/test_log.txt
tests/test_log.txt
tests/*.opt
74 changes: 26 additions & 48 deletions bmtrain/loss/_function.py
@@ -2,26 +2,28 @@
from .. import C
import torch
CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda
def has_inf_nan(g_fp16: torch.Tensor, out: torch.Tensor) -> None:
assert g_fp16.dtype == torch.float16, "g_fp16 must be a half tensor"
def has_inf_nan(g_half: torch.Tensor, out: torch.Tensor) -> None:
assert out.dtype == torch.uint8, "out must be a uint8 tensor"
assert CHECK_INPUT(g_fp16), "g_fp16 must be contiguous and on cuda"
assert CHECK_INPUT(g_half), "g_half must be contiguous and on cuda"
assert CHECK_INPUT(out), "out must be contiguous and on cuda"
mid = torch.zeros(1024, device=out.device, dtype=out.dtype)
stream = torch.cuda.current_stream().cuda_stream
C.has_nan_inf_launcher(g_fp16.numel(), g_fp16.data_ptr(), mid.data_ptr(), out.data_ptr(), stream)


if g_half.dtype == torch.float16:
C.has_nan_inf_fp16_launcher(g_half.numel(), g_half.data_ptr(), mid.data_ptr(), out.data_ptr(), stream)
elif g_half.dtype == torch.bfloat16:
if not C.is_bf16_supported():
raise NotImplementedError(f"bfloat16 is not supported on current GPU")
C.has_nan_inf_bf16_launcher(g_half.numel(), g_half.data_ptr(), mid.data_ptr(), out.data_ptr(), stream)
else:
raise ValueError(f"has_inf_nan not supported for dtype {g_half.dtype}")

def cross_entropy_forward(m: int, n: int, input: torch.Tensor, target: torch.Tensor,
softmax: torch.Tensor, output: torch.Tensor, ignore_index: int) -> None:
CHECK_INPUT(input)
CHECK_INPUT(target)
CHECK_INPUT(softmax)
CHECK_INPUT(output)
assert input.dtype == torch.float16, "input must be a half tensor"
assert target.dtype == torch.int32, "target must be an int tensor"
assert softmax.dtype == torch.float16, "softmax must be a half tensor"
assert output.dtype == torch.float32, "output must be a float tensor"
assert input.numel() == softmax.numel(), "input and softmax must have the same number of elements"
assert target.numel() == output.numel(), "target and output must have the same number of elements"
@@ -30,43 +32,14 @@ def cross_entropy_forward(m: int, n: int, input: torch.Tensor, target: torch.Ten
softmax_ptr = softmax.data_ptr()
output_ptr = output.data_ptr()
cuda_stream = torch.cuda.current_stream().cuda_stream
C.cross_entropy_forward_launcher(m, n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index, cuda_stream)

def cross_entropy_backward(m: int, n: int, grad_output: torch.Tensor, target: torch.Tensor,
softmax: torch.Tensor, grad_input: torch.Tensor, ignore_index: int) -> None:
CHECK_INPUT(grad_output)
CHECK_INPUT(target)
CHECK_INPUT(softmax)
CHECK_INPUT(grad_input)
assert grad_output.dtype == torch.float32, "grad_output must be a float tensor"
assert target.dtype == torch.int32, "target must be an int tensor"
assert softmax.dtype == torch.float16, "softmax must be a half tensor"
assert grad_input.dtype == torch.float16, "grad_input must be a half tensor"
assert grad_input.numel() == softmax.numel(), "grad_input and softmax must have the same number of elements"
assert target.numel() == grad_output.numel(), "target and grad_output must have the same number of elements"
grad_output_ptr = grad_output.data_ptr()
target_ptr = target.data_ptr()
softmax_ptr = softmax.data_ptr()
grad_input_ptr = grad_input.data_ptr()
cuda_stream = torch.cuda.current_stream().cuda_stream
C.cross_entropy_backward_launcher(m, n, grad_output_ptr, target_ptr, softmax_ptr, grad_input_ptr, ignore_index, cuda_stream)

def cross_entropy_forward_inplace(m: int, n: int, x: torch.Tensor, target: torch.Tensor,
output: torch.Tensor, ignore_index: int) -> None:
CHECK_INPUT(x)
CHECK_INPUT(target)
CHECK_INPUT(output)
assert x.dtype == torch.float16, "x must be a half tensor"
assert target.dtype == torch.int32, "target must be an int tensor"
assert output.dtype == torch.float32, "output must be a float tensor"
assert target.numel() == output.numel(), "target and output must have the same number of elements"
cuda_stream = torch.cuda.current_stream().cuda_stream
x_ptr = x.data_ptr()
output_ptr = output.data_ptr()
target_ptr = target.data_ptr()
output_ptr = output.data_ptr()

C.cross_entropy_forward_inplace_launcher(m, n, x_ptr, target_ptr, output_ptr, ignore_index, cuda_stream)
if input.dtype == torch.float16:
C.cross_entropy_forward_fp16_launcher(m, n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index, cuda_stream)
elif input.dtype == torch.bfloat16:
if not C.is_bf16_supported():
raise NotImplementedError(f"bfloat16 is not supported on current GPU")
C.cross_entropy_forward_bf16_launcher(m, n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index, cuda_stream)
else:
raise ValueError(f"cross_entropy_forward not supported for dtype {input.dtype}")

def cross_entropy_backward_inplace(m: int, n: int, grad_output: torch.Tensor, target: torch.Tensor,
x: torch.Tensor, ignore_index: int) -> None:
@@ -75,12 +48,17 @@ def cross_entropy_backward_inplace(m: int, n: int, grad_output: torch.Tensor, ta
CHECK_INPUT(x)
assert grad_output.dtype == torch.float32, "grad_output must be a float tensor"
assert target.dtype == torch.int32, "target must be an int tensor"
assert x.dtype == torch.float16, "x must be a half tensor"
assert target.numel() == grad_output.numel(), "target and grad_output must have the same number of elements"
cuda_stream = torch.cuda.current_stream().cuda_stream
grad_output_ptr = grad_output.data_ptr()
target_ptr = target.data_ptr()
x_ptr = x.data_ptr()

C.cross_entropy_backward_inplace_launcher(m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream)

if x.dtype == torch.float16:
C.cross_entropy_backward_inplace_fp16_launcher(m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream)
elif x.dtype == torch.bfloat16:
if not C.is_bf16_supported():
raise NotImplementedError(f"bfloat16 is not supported on current GPU")
C.cross_entropy_backward_inplace_bf16_launcher(m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream)
else:
raise ValueError(f"cross_entropy_backward not supported for dtype {input.dtype}")
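A hypothetical usage sketch of the new dtype dispatch in has_inf_nan. The function and launcher names come from this file; the CUDA device, the compiled C extension with the bf16 launchers, and the one-element uint8 output buffer are assumptions, not something this diff shows.

import torch
from bmtrain.loss._function import has_inf_nan

g = torch.randn(1024, device="cuda", dtype=torch.bfloat16)
g[17] = float("inf")                                    # inject a bad value
out = torch.zeros(1, device="cuda", dtype=torch.uint8)  # flag buffer
has_inf_nan(g, out)      # bf16 input routes to C.has_nan_inf_bf16_launcher
print(bool(out.item()))  # expected True: the kernel should flag the injected inf

On a GPU without bf16 support, the same call should instead raise NotImplementedError through the C.is_bf16_supported() check.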
30 changes: 0 additions & 30 deletions bmtrain/loss/cross_entropy.py
@@ -36,36 +36,6 @@ def backward(ctx, grad_output : torch.Tensor):
)
return (softmax, None, None)

class OpFusedCrossEntropyInplace(torch.autograd.Function):
"""
CrossEntropy dim = 1
"""
@staticmethod
def forward(ctx, x : torch.Tensor, target : torch.Tensor, ignore_index: int):
assert x.ndim == 2
out = torch.empty(x.size(0), device=x.device, dtype=torch.float)
F.cross_entropy_forward_inplace(
x.size(0), x.size(1),
x, target,
out,
ignore_index,
) # x is inplace modify to softmax result
ctx.ignore_index = ignore_index
ctx.save_for_backward(x, target)
return out # float tensor

@staticmethod
def backward(ctx, grad_output : torch.Tensor):
grad_output = grad_output.contiguous()
softmax, target = ctx.saved_tensors
F.cross_entropy_backward_inplace(
softmax.size(0), softmax.size(1),
grad_output, target,
softmax,
ctx.ignore_index,
) # softmax is inplace modify to grad_input
return (softmax, None, None)

class FusedCrossEntropy(torch.nn.Module):
r"""This criterion computes the cross entropy loss between input and target.

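With the in-place autograd op removed, bf16 cross entropy is still reachable through the FusedCrossEntropy module kept in this file. A hypothetical call, assuming the module keeps its (input, target) forward signature, an ignore_index keyword, and a CUDA device with bf16 support; none of that is shown in this diff.

import torch
import bmtrain as bmt

logits = torch.randn(8, 32000, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, 32000, (8,), device="cuda", dtype=torch.int32)
loss_fn = bmt.loss.FusedCrossEntropy(ignore_index=-100)
loss = loss_fn(logits, target)   # forward dispatches to the bf16 launcher
loss.backward()                  # backward goes through cross_entropy_backward_inplace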
81 changes: 62 additions & 19 deletions bmtrain/optim/_function.py
@@ -1,4 +1,3 @@

from .. import C
import torch
CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda
@@ -11,8 +10,8 @@ def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.T
assert m_fp32.is_contiguous(), "m_fp32 must be contiguous"
assert v_fp32.is_contiguous(), "v_fp32 must be contiguous"
assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor"
assert param_fp16.dtype == torch.float16, "param_fp16 must be float16 tensor"
assert g_fp16.dtype == torch.float16, "g_fp16 must be float16 tensor"
assert param_fp16.dtype == torch.float16 or param_fp16.dtype == torch.bfloat16, "param_fp16 must be float16/bfloat16 tensor"
assert g_fp16.dtype == torch.float16 or g_fp16.dtype == torch.bfloat16, "g_fp16 must be float16/bfloat16 tensor"
assert m_fp32.dtype == torch.float32, "m_fp32 must be float32 tensor"
assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor"
assert param_fp32.device == torch.device("cpu"), "param_fp32 must be a cpu tensor"
@@ -26,22 +25,28 @@ def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.T
assert param_fp32.numel() == v_fp32.numel(), "param_fp32 and v_fp32 must have the same number of elements"
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
C.adam_cpu_launcher(
param_fp32.numel(),
param_fp32.data_ptr(),
param_fp16.data_ptr(),
g_fp16.data_ptr(),
m_fp32.data_ptr(),
v_fp32.data_ptr(),
beta1, beta2,
eps, lr,
scale,
weight_decay,
bias_correction1,
bias_correction2,
)
if g_fp16.dtype == torch.float16:
launcher = C.adam_cpu_fp16_launcher
elif g_fp16.dtype == torch.bfloat16:
if not C.is_bf16_supported():
raise NotImplementedError(f"bfloat16 is not supported on current GPU")
launcher = C.adam_cpu_bf16_launcher
launcher(
param_fp32.numel(),
param_fp32.data_ptr(),
param_fp16.data_ptr(),
g_fp16.data_ptr(),
m_fp32.data_ptr(),
v_fp32.data_ptr(),
beta1, beta2,
eps, lr,
scale,
weight_decay,
bias_correction1,
bias_correction2,
)

def adam(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tensor, m_fp16: torch.Tensor,
def adam_fp16(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tensor, m_fp16: torch.Tensor,
v_fp32: torch.Tensor, beta1: float, beta2: float, eps: float, lr: float, scale: float,
weight_decay: float, step: int) -> None:
assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda"
@@ -61,7 +66,7 @@ def adam(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tenso
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
stream = torch.cuda.current_stream().cuda_stream
C.adam_launcher(
C.adam_fp16_launcher(
param_fp32.numel(),
param_fp32.data_ptr(),
param_fp16.data_ptr(),
@@ -76,3 +81,41 @@ def adam(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tenso
bias_correction2,
stream
)

def adam_bf16(param_fp32: torch.Tensor, param_bf16: torch.Tensor, g_bf16: torch.Tensor, m_fp32: torch.Tensor,
v_fp32: torch.Tensor, beta1: float, beta2: float, eps: float, lr: float, scale: float,
weight_decay: float, step: int) -> None:
assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda"
assert CHECK_INPUT(param_bf16), "param_bf16 must be contiguous and on cuda"
assert CHECK_INPUT(g_bf16), "g_bf16 must be contiguous and on cuda"
assert CHECK_INPUT(m_fp32), "m_fp32 must be contiguous and on cuda"
assert CHECK_INPUT(v_fp32), "v_fp32 must be contiguous and on cuda"
assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor"
assert param_bf16.dtype == torch.bfloat16, "param_bf16 must be bfloat16 tensor"
assert g_bf16.dtype == torch.bfloat16, "g_bf16 must be bfloat16 tensor"
assert m_fp32.dtype == torch.float32, "m_fp32 must be float32 tensor"
assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor"
assert param_fp32.numel() == param_bf16.numel(), "param_fp32 and param_bf16 must have the same number of elements"
assert param_fp32.numel() == g_bf16.numel(), "param_fp32 and g_bf16 must have the same number of elements"
assert param_fp32.numel() == m_fp32.numel(), "param_fp32 and m_fp32 must have the same number of elements"
assert param_fp32.numel() == v_fp32.numel(), "param_fp32 and v_fp32 must have the same number of elements"
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
stream = torch.cuda.current_stream().cuda_stream
if not C.is_bf16_supported():
raise NotImplementedError(f"bfloat16 is not supported on current GPU")
C.adam_bf16_launcher(
param_fp32.numel(),
param_fp32.data_ptr(),
param_bf16.data_ptr(),
g_bf16.data_ptr(),
m_fp32.data_ptr(),
v_fp32.data_ptr(),
beta1, beta2,
eps, lr,
scale,
weight_decay,
bias_correction1,
bias_correction2,
stream
)
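To make the contract of the new adam_bf16 entry point easier to review, here is a pure-PyTorch reference of the update it is expected to perform: fp32 master weights and moments, a bf16 parameter copy and gradient, and the gradient divided by scale. This is a sketch assuming the usual AdamW-style decoupled weight decay; the authoritative behaviour is the CUDA kernel behind C.adam_bf16_launcher, not this snippet.

import torch

def adam_bf16_reference(param_fp32, param_bf16, g_bf16, m_fp32, v_fp32,
                        beta1, beta2, eps, lr, scale, weight_decay, step):
    g = g_bf16.float() / scale                           # unscale gradient in fp32
    m_fp32.mul_(beta1).add_(g, alpha=1 - beta1)          # first moment
    v_fp32.mul_(beta2).addcmul_(g, g, value=1 - beta2)   # second moment
    denom = (v_fp32 / (1 - beta2 ** step)).sqrt_().add_(eps)
    update = m_fp32 / (1 - beta1 ** step) / denom + weight_decay * param_fp32
    param_fp32.add_(update, alpha=-lr)                   # fp32 master update
    param_bf16.copy_(param_fp32)                         # refresh the bf16 copy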
63 changes: 35 additions & 28 deletions bmtrain/optim/adam.py
@@ -40,8 +40,9 @@ def _on_justify_scale(self, old_scale, new_scale):
if p in self.state:
state = self.state[p]
if len(state) > 0:
state['exp_avg'] *= delta
state['exp_avg_sq'] *= delta
if p.dtype == torch.float16:
state['exp_avg'] *= delta
state['exp_avg_sq'] *= delta

@torch.no_grad()
def step(self, closure=None, scale=1):
@@ -63,45 +64,32 @@ def step(self, closure=None, scale=1):
if p.grad is not None and p.requires_grad:
if p.grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
if p.dtype not in [torch.float16, torch.float32]:
raise RuntimeError('Adam only supports fp32 or fp16 gradients')
if p.dtype not in [torch.float32, torch.half, torch.bfloat16]:
raise RuntimeError('Adam only supports fp32, fp16 and bf16 gradients')

state = self.state[p]
# Lazy state initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros(p.size(), dtype=p.dtype, device=p.device) # on device
if p.dtype == torch.float16:
state['exp_avg'] = torch.zeros(p.size(), dtype=torch.float16, device=p.device) # on device
else:
state['exp_avg'] = torch.zeros(p.size(), dtype=torch.float32, device=p.device) # on device
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros(p.size(), dtype=torch.float32, device=p.device) # on device

if p.dtype == torch.half:
state['exp_avg_sq'] = torch.zeros(p.size(), dtype=torch.float32, device=p.device) # on device
if p.dtype != torch.float32:
state['_param_fp32'] = torch.empty(p.size(), dtype=torch.float32, device=p.device) # on device
state['_param_fp32'].copy_(p)

# update the steps for each param group update
state['step'] += 1

if ('maximize' in group) and (group['maximize'] is True):
grad = -p.grad
else:
grad = p.grad

if p.dtype == torch.half:
F.adam(
state["_param_fp32"], # fp32
p, # fp16
grad, # fp16
state['exp_avg'], # fp16: m
state["exp_avg_sq"], # fp32: v
group['betas'][0], group['betas'][1],
group['eps'],
0.0 if state["step"] <= self._hold_steps else group['lr'],
scale,
group['weight_decay'],
state['step']
)
else:
if p.dtype == torch.float32:
other_kwargs = {}
if 'maximize' in inspect.signature(torch.optim._functional.adam).parameters:
other_kwargs['maximize'] = False
@@ -116,11 +104,30 @@ def step(self, closure=None, scale=1):
amsgrad=False,
beta1=group['betas'][0],
beta2=group['betas'][1],
lr=0.0 if state["step"] <= self._hold_steps else group['lr'],
lr=0.0 if state["step"] < self._hold_steps else group['lr'],
weight_decay=group['weight_decay'],
eps=group['eps'],
**other_kwargs
)
state['step'] += 1
else:
f = F.adam_fp16 if p.dtype == torch.float16 else F.adam_bf16
state['step'] += 1
f(
state["_param_fp32"], # fp32
p, # fp16
grad, # fp16
state['exp_avg'], # fp16: m
state["exp_avg_sq"], # fp32: v
group['betas'][0], group['betas'][1],
group['eps'],
0.0 if state["step"] < self._hold_steps else group['lr'],
scale,
group['weight_decay'],
state['step']
)



return loss

@@ -159,11 +166,11 @@ def load_state_dict(self, state_dict: dict) -> None:
if k in id_map:
param = id_map[k]

if param.dtype == torch.half and "_param_fp32" not in v:
if param.dtype != torch.float32 and "_param_fp32" not in v:
v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device=param.device)
v["_param_fp32"].copy_(param)

for name, dtype in [("exp_avg", param.dtype), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]:
for name, dtype in [("exp_avg", torch.float16 if param.dtype == torch.float16 else torch.float32), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]:
if name in v:
v[name] = v[name].to(param.device).to(dtype)

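The step() and load_state_dict() changes above imply a per-dtype state layout: fp16 parameters keep exp_avg in fp16, bf16 and fp32 parameters keep both moments in fp32, and every non-fp32 parameter carries an fp32 master copy in _param_fp32. A hypothetical check of that layout, assuming the optimizer class in bmtrain/optim/adam.py is exposed as bmtrain.optim.AdamOptimizer (the class name is not visible in this diff), a bf16-capable GPU, and a process started under BMTrain's usual distributed launcher.

import torch
import bmtrain as bmt

bmt.init_distributed()  # assumes torchrun-style environment variables are set

p = torch.nn.Parameter(torch.randn(16, device="cuda", dtype=torch.bfloat16))
opt = bmt.optim.AdamOptimizer([p], lr=1e-3)
p.grad = torch.randn_like(p)
opt.step()

state = opt.state[p]
print(state["exp_avg"].dtype)      # torch.float32: fp32 moments for bf16 params
print(state["exp_avg_sq"].dtype)   # torch.float32
print(state["_param_fp32"].dtype)  # torch.float32: master weights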