diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index b6d80639..b779893b 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -273,6 +273,7 @@ def storage_type_cuda(storage_type): torch.FloatStorage: torch.cuda.FloatStorage, torch.DoubleStorage: torch.cuda.DoubleStorage, torch.HalfStorage: torch.cuda.HalfStorage, + torch.BFloat16Storage: torch.cuda.BFloat16Storage, torch.CharStorage: torch.cuda.CharStorage, torch.ByteStorage: torch.cuda.ByteStorage, torch.ShortStorage: torch.cuda.ShortStorage, @@ -280,6 +281,7 @@ def storage_type_cuda(storage_type): torch.cuda.FloatStorage: torch.cuda.FloatStorage, torch.cuda.DoubleStorage: torch.cuda.DoubleStorage, torch.cuda.HalfStorage: torch.cuda.HalfStorage, + torch.cuda.BFloat16Storage: torch.cuda.BFloat16Storage, torch.cuda.CharStorage: torch.cuda.CharStorage, torch.cuda.ByteStorage: torch.cuda.ByteStorage, torch.cuda.ShortStorage: torch.cuda.ShortStorage, diff --git a/bmtrain/nccl/__init__.py b/bmtrain/nccl/__init__.py index bccae34b..b6b2de91 100644 --- a/bmtrain/nccl/__init__.py +++ b/bmtrain/nccl/__init__.py @@ -35,6 +35,7 @@ def dtype2nccl(dtype : torch.dtype) -> int: torch.int64 : ncclInt64, torch.float16 : ncclFloat16, torch.half : ncclHalf, + torch.bfloat16 : ncclBFloat16, torch.float32 : ncclFloat32, torch.float : ncclFloat, torch.float64 : ncclFloat64, diff --git a/bmtrain/nccl/enums.py b/bmtrain/nccl/enums.py index dc2adbd4..67411f0e 100644 --- a/bmtrain/nccl/enums.py +++ b/bmtrain/nccl/enums.py @@ -16,6 +16,7 @@ ncclFloat = 7 ncclFloat64 = 8 ncclDouble = 8 +ncclBFloat16 = 9 ### ncclRedOp_t