diff --git a/CHANGELOG.md b/CHANGELOG.md index 2de70d371..eb7ac0dd9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -228,3 +228,14 @@ Deprecated: Features: - Added Int8 SwitchBack layers - Added Fake FP8 layers for research purposes (available under `bnb.research.nn. ...`) + + +### 0.39.0 + + +Features: + - 4-bit matrix multiplication for Float4 and NormalFloat4 data types. + - Added 4-bit quantization routines + - Doubled quantization routines for 4-bit quantization + - Paged optimizers for Adam and Lion. + - bfloat16 gradient / weight support for Adam and Lion with 8 or 32-bit states. diff --git a/Makefile.previous b/Makefile.previous index fd0e5057f..39730b12a 100644 --- a/Makefile.previous +++ b/Makefile.previous @@ -2,6 +2,7 @@ MKFILE_PATH := $(abspath $(lastword $(MAKEFILE_LIST))) ROOT_DIR := $(patsubst %/,%,$(dir $(MKFILE_PATH))) GPP:= /usr/bin/g++ +#GPP:= /sw/gcc/11.2.0/bin/g++ ifeq ($(CUDA_HOME),) CUDA_HOME:= $(shell which nvcc | rev | cut -d'/' -f3- | rev) endif @@ -39,6 +40,7 @@ ifneq ($(TARGET_ARCH),$(NATIVE_ARCH)) endif + NVCC := $(CUDA_HOME)/bin/nvcc ########################################### @@ -50,8 +52,7 @@ FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.cpp INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include -INCLUDE_10x := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include -LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib +LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcusparse -L $(CONDA_PREFIX)/lib # NVIDIA NVCC compilation flags COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell @@ -59,17 +60,11 @@ COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta -COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta CC_KEPLER := -gencode arch=compute_35,code=sm_35 # Kepler CC_KEPLER += -gencode arch=compute_37,code=sm_37 # Kepler # Later versions of CUDA support the new architectures -CC_CUDA10x += -gencode arch=compute_75,code=sm_75 - -CC_CUDA110 := -gencode arch=compute_75,code=sm_75 -CC_CUDA110 += -gencode arch=compute_80,code=sm_80 - CC_CUDA11x := -gencode arch=compute_75,code=sm_75 CC_CUDA11x += -gencode arch=compute_80,code=sm_80 CC_CUDA11x += -gencode arch=compute_86,code=sm_86 @@ -86,23 +81,10 @@ CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89 CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90 -all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env - echo "Specify a target to make (arch=$(TARGET_ARCH), flags:$(EXTRA_FLAGS))" && false - -cuda: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env - $(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) - $(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) $(EXTRA_FLAGS) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)$(SHLIB_EXTENSION) $(LIB) - -cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) $(CC_KEPLER) -Xcompiler '-fPIC' 
--use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) $(EXTRA_FLAGS) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt$(SHLIB_EXTENSION) $(LIB) - -cuda10x_nomatmul: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE_10x) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) $(EXTRA_FLAGS) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt$(SHLIB_EXTENSION) $(LIB) +all: $(BUILD_DIR) env + $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) cuda110_nomatmul: $(BUILD_DIR) env $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index dcbc42319..f35a3b582 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -10,6 +10,7 @@ matmul, matmul_cublas, mm_cublas, + matmul_4bit ) from .cextension import COMPILED_WITH_CUDA from .nn import modules diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index d1e033101..baa633041 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -2,7 +2,7 @@ import warnings from dataclasses import dataclass from functools import reduce # Required in Python 3 -from typing import Tuple, Optional +from typing import Tuple, Optional, List import torch @@ -427,10 +427,10 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype if any(ctx.needs_input_grad[:2]): - ctx.tensors = (CAt, subA) + ctx.tensors = (CAt, subA, A) ctx.tensor_states = (SCAt, state.idx) else: - ctx.tensors = [None, None] + ctx.tensors = [None, None, A] ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) @@ -443,7 +443,7 @@ def backward(ctx, grad_output): bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad - CAt, subA = ctx.tensors + CAt, subA, A = ctx.tensors SCAt, idx = ctx.tensor_states formatB = ctx.formatB state = ctx.state @@ -489,6 +489,65 @@ def backward(ctx, grad_output): return grad_A, grad_B, None, grad_bias, None + +class MatMul4Bit(torch.autograd.Function): + # forward is 
the same, but we added the fallback for pre-turing GPUs + # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") + + @staticmethod + def forward(ctx, A, B, out=None, bias=None, state=None): + # default of pytorch behavior if inputs are empty + ctx.is_empty = False + if prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + ctx.bias = bias + B_shape = state[1] + if A.shape[-1] == B_shape[0]: + return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device) + else: + return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) + + + # 1. Dequantize + # 2. MatmulnN + output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias) + + # 3. Save state + ctx.state = state + ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype + + if any(ctx.needs_input_grad[:2]): + ctx.tensors = (A, B) + else: + ctx.tensors = (None, None) + + return output + + @staticmethod + def backward(ctx, grad_output): + if ctx.is_empty: + bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None + + req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad + A, B = ctx.tensors + state = ctx.state + + grad_A, grad_B, grad_bias = None, None, None + + if req_gradBias: + # compute grad_bias first before changing grad_output dtype + grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) + + # not supported by PyTorch. TODO: create work-around + #if req_gradB: grad_B = torch.matmul(grad_output.t(), A) + if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(grad_output.dtype).t()) + + return grad_A, grad_B, None, grad_bias, None + + def matmul( A: tensor, B: tensor, @@ -501,3 +560,8 @@ def matmul( if threshold > 0.0: state.threshold = threshold return MatMul8bitLt.apply(A, B, out, bias, state) + + +def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None): + assert quant_state is not None + return MatMul4Bit.apply(A, B, out, bias, quant_state) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index e3cb9a504..5d84c1555 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -30,15 +30,13 @@ Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''') - lib.cadam32bit_g32 + lib.cadam32bit_grad_fp32 # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False lib.get_context.restype = ct.c_void_p lib.get_cusparse.restype = ct.c_void_p + lib.cget_managed_ptr.restype = ct.c_void_p COMPILED_WITH_CUDA = True - except AttributeError: + except AttributeError as ex: warn("The installed version of bitsandbytes was compiled without GPU support. 
" "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.") COMPILED_WITH_CUDA = False - - # print the setup details after checking for errors so we do not print twice - if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': - setup.print_log_stack() + print(str(ex)) diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py index 4ed3c6d6a..d6ca212a2 100644 --- a/bitsandbytes/cuda_setup/main.py +++ b/bitsandbytes/cuda_setup/main.py @@ -44,6 +44,9 @@ def __init__(self): raise RuntimeError("Call get_instance() instead") def generate_instructions(self): + if getattr(self, 'error', False): return + print(self.error) + self.error = True if self.cuda is None: self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected.') self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.') @@ -93,6 +96,7 @@ def initialize(self): self.has_printed = False self.lib = None self.initialized = False + self.error = False self.cuda_setup_log = [] def run_cuda_setup(self): diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index c58ddac41..2f4187d3e 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -9,6 +9,8 @@ import torch import itertools import math +from scipy.stats import norm +import numpy as np from functools import reduce # Required in Python 3 from typing import Tuple @@ -26,77 +28,95 @@ def prod(iterable): if COMPILED_WITH_CUDA: """C FUNCTIONS FOR OPTIMIZERS""" str2optimizer32bit = {} - str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16) + str2optimizer32bit["adam"] = (lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16, lib.cadam32bit_grad_bf16) str2optimizer32bit["momentum"] = ( - lib.cmomentum32bit_g32, - lib.cmomentum32bit_g16, + lib.cmomentum32bit_grad_32, + lib.cmomentum32bit_grad_16, ) str2optimizer32bit["rmsprop"] = ( - lib.crmsprop32bit_g32, - lib.crmsprop32bit_g16, - ) - str2optimizer32bit["lion"] = ( - lib.clion32bit_g32, - lib.clion32bit_g16, + lib.crmsprop32bit_grad_32, + lib.crmsprop32bit_grad_16, ) + str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16, lib.clion32bit_grad_bf16) str2optimizer32bit["adagrad"] = ( - lib.cadagrad32bit_g32, - lib.cadagrad32bit_g16, - ) - str2optimizer32bit["lars"] = ( - lib.cmomentum32bit_g32, - lib.cmomentum32bit_g16, + lib.cadagrad32bit_grad_32, + lib.cadagrad32bit_grad_16, ) - str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16) str2optimizer8bit = {} str2optimizer8bit["adam"] = ( - lib.cadam_static_8bit_g32, - lib.cadam_static_8bit_g16, + lib.cadam_static_8bit_grad_32, + lib.cadam_static_8bit_grad_16, ) str2optimizer8bit["momentum"] = ( - lib.cmomentum_static_8bit_g32, - lib.cmomentum_static_8bit_g16, + lib.cmomentum_static_8bit_grad_32, + lib.cmomentum_static_8bit_grad_16, ) str2optimizer8bit["rmsprop"] = ( - lib.crmsprop_static_8bit_g32, - lib.crmsprop_static_8bit_g16, + lib.crmsprop_static_8bit_grad_32, + lib.crmsprop_static_8bit_grad_16, ) str2optimizer8bit["lion"] = ( - lib.clion_static_8bit_g32, - lib.clion_static_8bit_g16, + lib.clion_static_8bit_grad_32, + lib.clion_static_8bit_grad_16, ) str2optimizer8bit["lamb"] = ( - lib.cadam_static_8bit_g32, - lib.cadam_static_8bit_g16, + lib.cadam_static_8bit_grad_32, + lib.cadam_static_8bit_grad_16, ) str2optimizer8bit["lars"] = ( - lib.cmomentum_static_8bit_g32, - lib.cmomentum_static_8bit_g16, + 
lib.cmomentum_static_8bit_grad_32, + lib.cmomentum_static_8bit_grad_16, ) str2optimizer8bit_blockwise = {} str2optimizer8bit_blockwise["adam"] = ( - lib.cadam_8bit_blockwise_fp32, - lib.cadam_8bit_blockwise_fp16, + lib.cadam_8bit_blockwise_grad_fp32, + lib.cadam_8bit_blockwise_grad_fp16, + lib.cadam_8bit_blockwise_grad_bf16, ) str2optimizer8bit_blockwise["momentum"] = ( - lib.cmomentum_8bit_blockwise_fp32, - lib.cmomentum_8bit_blockwise_fp16, + lib.cmomentum_8bit_blockwise_grad_fp32, + lib.cmomentum_8bit_blockwise_grad_fp16, ) str2optimizer8bit_blockwise["rmsprop"] = ( - lib.crmsprop_8bit_blockwise_fp32, - lib.crmsprop_8bit_blockwise_fp16, + lib.crmsprop_8bit_blockwise_grad_fp32, + lib.crmsprop_8bit_blockwise_grad_fp16, ) str2optimizer8bit_blockwise["lion"] = ( - lib.clion_8bit_blockwise_fp32, - lib.clion_8bit_blockwise_fp16, + lib.clion_8bit_blockwise_grad_fp32, + lib.clion_8bit_blockwise_grad_fp16, + lib.clion_8bit_blockwise_grad_bf16, ) str2optimizer8bit_blockwise["adagrad"] = ( - lib.cadagrad_8bit_blockwise_fp32, - lib.cadagrad_8bit_blockwise_fp16, + lib.cadagrad_8bit_blockwise_grad_fp32, + lib.cadagrad_8bit_blockwise_grad_fp16, ) +class GlobalPageManager: + _instance = None + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + def initialize(self): + self.paged_tensors = [] + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance + + def prefetch_all(self, to_cpu=False): + # assume the first added, will be hte + # ones that are used first, so swap them in last + # in the case they are evicted again + for t in self.paged_tensors[::-1]: + prefetch_tensor(t, to_cpu) + + class CUBLAS_Context: _instance = None @@ -106,11 +126,6 @@ def __init__(self): def initialize(self): self.context = {} - # prev_device = torch.cuda.current_device() - # for i in range(torch.cuda.device_count()): - # torch.cuda.set_device(torch.device('cuda', i)) - # self.context.append(ct.c_void_p(lib.get_context())) - # torch.cuda.set_device(prev_device) @classmethod def get_instance(cls): @@ -146,6 +161,61 @@ def get_instance(cls): cls._instance.initialize() return cls._instance +dtype2bytes = {} +dtype2bytes[torch.float32] = 4 +dtype2bytes[torch.float16] = 2 +dtype2bytes[torch.bfloat16] = 2 +dtype2bytes[torch.uint8] = 1 +dtype2bytes[torch.int8] = 1 + +def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0)): + num_bytes = dtype2bytes[dtype]*prod(shape) + cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes)) + c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int)) + new_array = np.ctypeslib.as_array(c_ptr, shape=shape) + out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape)).view(shape) + out.is_paged = True + out.page_deviceid = device.index + return out + +def prefetch_tensor(A, to_cpu=False): + assert A.is_paged, 'Only paged tensors can be prefetched!' 
+ if to_cpu: + deviceid = -1 + else: + deviceid = A.page_deviceid + + num_bytes = dtype2bytes[A.dtype]*A.numel() + lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid)) + +def elementwise_func(func_name, A, B, value, prefetch=True): + func = None + if A.dtype == torch.float32: + func = getattr(lib, f'c{func_name}_fp32', None) + cvalue = ct.c_float(value) + elif A.dtype == torch.uint8: + func = getattr(lib, f'c{func_name}_uint8', None) + cvalue = ct.c_uint8(value) + + if func is None: raise NotImplementedError(f'Function not implemented: {func_name}') + + is_managed = getattr(A, 'is_managed', False) + if is_managed and prefetch: + prefetch_tensor(A) + if B is not None: prefetch_tensor(B) + + func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel())) + if A.is_paged or B.is_paged: + # paged function are fully asynchronous + # if we return from this function, we want to the tensor + # to be in the correct state, that is the final state after the + # operation occured. So we synchronize. + torch.cuda.synchronize() + +def fill(A, value, device=None, prefetch=True): elementwise_func('fill', A, None, value) +def arange(A, device=None): elementwise_func('arange', A, None, 0) +def _mul(A, B, device=None): elementwise_func('_mul', A, B, 0) + def create_linear_map(signed=True, total_bits=8, add_zero=True): sign = (-1.0 if signed else 0.0) @@ -163,9 +233,27 @@ def create_linear_map(signed=True, total_bits=8, add_zero=True): return values else: l = values.numel()//2 - #return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist()) return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) +def create_normal_map(offset=0.9677083, use_extra_value=True): + + if use_extra_value: + # one more positive value, this is an asymmetric type + v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist() + v2 = [0]*(256-15) ## we have 15 non-zero values in this data type + v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() + v = v1 + v2 + v3 + else: + v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist() + v2 = [0]*(256-14) ## we have 14 non-zero values in this data type + v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() + v = v1 + v2 + v3 + + values = torch.Tensor(v) + values = values.sort().values + values /= values.max() + assert values.numel() == 256 + return values def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8): e = exponent_bits @@ -291,9 +379,17 @@ def get_special_format_str(): def is_on_gpu(tensors): on_gpu = True + gpu_ids = set() for t in tensors: if t is None: continue # NULL pointers are fine - on_gpu &= t.device.type == 'cuda' + is_paged = getattr(t, 'is_paged', False) + on_gpu &= (t.device.type == 'cuda' or is_paged) + if not is_paged: + gpu_ids.add(t.device.index) + if not on_gpu: + raise TypeError(f'All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}') + if len(gpu_ids) > 1: + raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}') return on_gpu def get_ptr(A: Tensor) -> ct.c_void_p: @@ -475,7 +571,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n return out -def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096) -> Tensor: +def quantize_blockwise(A: Tensor, code: 
Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor: """ Quantize tensor A in blocks of size 4096 values. @@ -491,8 +587,6 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra The quantization map. absmax : torch.Tensor The absmax values. - rand : torch.Tensor - The tensor for stochastic rounding. out : torch.Tensor The output tensor (8-bit). @@ -524,32 +618,30 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra cblocksize = ct.c_int32(blocksize) prev_device = pre_call(A.device) code = code.to(A.device) - if rand is not None: - is_on_gpu([code, A, out, absmax, rand]) - assert rand.numel() >= 1024 - rand_offset = random.randint(0, 1023) - if A.dtype == torch.float32: - lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), cblocksize, ct.c_int(A.numel())) - elif A.dtype == torch.float16: - lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), cblocksize, ct.c_int(A.numel())) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + is_on_gpu([code, A, out, absmax]) + if A.dtype == torch.float32: + lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) + elif A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) else: - is_on_gpu([code, A, out, absmax]) - if A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) - elif A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) else: # cpu code = code.cpu() - assert rand is None lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) - return out, (absmax, code) + if nested: + offset = absmax.mean() + absmax -= offset + qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False) + state = [qabsmax, code, blocksize, nested, offset, state2] + else: + state = [absmax, code, blocksize, nested, None, None] + + + + return out, state def dequantize_blockwise( @@ -559,6 +651,7 @@ def dequantize_blockwise( code: Tensor = None, out: Tensor = None, blocksize: int = 4096, + nested=False ) -> Tensor: """ Dequantizes blockwise quantized values. 
@@ -593,10 +686,15 @@ def dequantize_blockwise( if out is None: out = torch.zeros_like(A, dtype=torch.float32) + if quant_state is None: - quant_state = (absmax, code) + quant_state = (absmax, code, blocksize) + assert absmax is not None and out is not None else: - absmax, code = quant_state + absmax, code, blocksize, nested, offset, state2 = quant_state + if nested: + absmax = dequantize_blockwise(absmax, state2) + absmax += offset if A.device.type == 'cuda': @@ -604,7 +702,7 @@ def dequantize_blockwise( code = code.to(A.device) if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") - is_on_gpu([A, out]) + is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) elif out.dtype == torch.float16: @@ -620,6 +718,164 @@ def dequantize_blockwise( return out +def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4') + +def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4') + +def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor: + """ + Quantize tensor A in blocks of 4-bit values. + + Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. + + Parameters + ---------- + A : torch.Tensor + The input tensor. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + The output tensor (8-bit). + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4} + + Returns + ------- + torch.Tensor: + The 8-bit tensor with packed 4-bit values. + tuple(torch.Tensor, torch.Size, torch.dtype, int): + The quantization state to undo the quantization. 
+ """ + if A.device.type != 'cuda': + raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}') + if quant_type not in ['fp4', 'nf4']: + raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') + + n = A.numel() + input_shape = A.shape + + if absmax is None: + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + absmax = torch.zeros((blocks,), device=A.device) + + + if out is None: + out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device) + + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + + prev_device = pre_call(A.device) + is_on_gpu([A, out, absmax]) + + if A.dtype == torch.float32: + if quant_type == 'fp4': + lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + else: + lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + elif A.dtype == torch.float16: + if quant_type == 'fp4': + lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + else: + lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + post_call(A.device) + + if compress_statistics: + offset = absmax.mean() + absmax -= offset + #code = create_custom_map().to(absmax.device) + #qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256) + qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) + del absmax + state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type] + else: + state = [absmax, input_shape, A.dtype, blocksize, None, quant_type] + + return out, state + +def dequantize_fp4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4') + +def dequantize_nf4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4') + +def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor: + """ + Dequantizes FP4 blockwise quantized values. + + Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. + + Parameters + ---------- + A : torch.Tensor + The input 8-bit tensor (packed 4-bit values). + quant_state : tuple(torch.Tensor, torch.Size, torch.dtype) + Tuple of absmax values, original tensor shape and original dtype. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + Dequantized output tensor. + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4} + + + Returns + ------- + torch.Tensor: + Dequantized tensor. + """ + if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: + raise ValueError(f"The blockwise of {blocksize} is not supported. 
Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") + if quant_type not in ['fp4', 'nf4']: + raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') + + if quant_state is None: + assert absmax is not None and out is not None + shape = out.shape + dtype = out.dtype + else: + absmax, shape, dtype, blocksize, compressed_stats, quant_type = quant_state + + + if compressed_stats is not None: + offset, state2 = compressed_stats + absmax = dequantize_blockwise(absmax, state2) + absmax += offset + + if out is None: + out = torch.empty(shape, dtype=dtype, device=A.device) + + n = out.numel() + + + device = pre_call(A.device) + is_on_gpu([A, absmax, out]) + if out.dtype == torch.float32: + if quant_type == 'fp4': + lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + else: + lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + elif out.dtype == torch.float16: + if quant_type == 'fp4': + lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + else: + lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + post_call(A.device) + + is_transposed = (True if A.shape[0] == 1 else False) + if is_transposed: return out.t() + else: return out + def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: if code is None: @@ -772,55 +1028,36 @@ def optimizer_update_32bit( if max_unorm > 0.0: param_norm = torch.norm(p.data.float()) - if optimizer_name not in str2optimizer32bit: - raise NotImplementedError( - f'Optimizer not implemented: {optimizer_name}. 
Choices: {",".join(str2optimizer32bit.keys())}' - ) - prev_device = pre_call(g.device) - is_on_gpu([g, p, state1, state2, unorm_vec]) - if g.dtype == torch.float32 and state1.dtype == torch.float32: - str2optimizer32bit[optimizer_name][0]( - get_ptr(g), - get_ptr(p), - get_ptr(state1), - get_ptr(state2), - get_ptr(unorm_vec), - ct.c_float(max_unorm), - ct.c_float(param_norm), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_float(weight_decay), - ct.c_int32(step), - ct.c_float(lr), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) - elif g.dtype == torch.float16 and state1.dtype == torch.float32: - str2optimizer32bit[optimizer_name][1]( - get_ptr(g), - get_ptr(p), - get_ptr(state1), - get_ptr(state2), - get_ptr(unorm_vec), - ct.c_float(max_unorm), - ct.c_float(param_norm), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_float(weight_decay), - ct.c_int32(step), - ct.c_float(lr), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) + optim_func = None + if g.dtype == torch.float32: + optim_func = str2optimizer32bit[optimizer_name][0] + elif g.dtype == torch.float16: + optim_func = str2optimizer32bit[optimizer_name][1] + elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3): + optim_func = str2optimizer32bit[optimizer_name][2] else: - raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" - ) + raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}") + + is_on_gpu([g, p, state1, state2, unorm_vec]) + prev_device = pre_call(g.device) + optim_func( + get_ptr(g), + get_ptr(p), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_float(weight_decay), + ct.c_int32(step), + ct.c_float(lr), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel())) post_call(prev_device) @@ -977,54 +1214,45 @@ def optimizer_update_8bit_blockwise( skip_zeros=False, ) -> None: + optim_func = None prev_device = pre_call(g.device) is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2]) if g.dtype == torch.float32 and state1.dtype == torch.uint8: - str2optimizer8bit_blockwise[optimizer_name][0]( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(absmax1), - get_ptr(absmax2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) + optim_func = str2optimizer8bit_blockwise[optimizer_name][0] elif g.dtype == torch.float16 and state1.dtype == torch.uint8: - str2optimizer8bit_blockwise[optimizer_name][1]( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(absmax1), - get_ptr(absmax2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) + optim_func = str2optimizer8bit_blockwise[optimizer_name][1] + elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and + len(str2optimizer8bit_blockwise[optimizer_name])==3): + optim_func = str2optimizer8bit_blockwise[optimizer_name][2] else: raise 
ValueError( f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" ) post_call(prev_device) + is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2]) + + prev_device = pre_call(g.device) + optim_func( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(absmax1), + get_ptr(absmax2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) + post_call(prev_device) def percentile_clipping( grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5 @@ -1178,6 +1406,123 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 return sout +def cutlass3_gemm( + A: Tensor, + B: Tensor, + out: Tensor = None, + transposed_A=False, + transposed_B=False, + state=None +): + #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) + if state is None: + Bshape = B.shape + bout = Bshape[1] + else: + Bshape = state[1] + bout = Bshape[0] + if out is None: + out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device) + + sA = A.shape + sB = B.shape + if transposed_A and len(sA) == 2: + sA = (sA[1], sA[0]) + elif transposed_A and len(sA) == 3: + sA = (sA[0], sA[2], sA[0]) + if transposed_B and len(sB) == 2: + sB = (sB[1], sB[0]) + elif transposed_B and len(sB) == 3: + sB = (sB[0], sB[2], sB[0]) + # this is a mess: cuBLAS expect column major, but PyTorch is row major. + # So to perform the matrix multiplication, we have to treat A, B, and C matrices + # (transpose of row major is column major) + # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these + + # matrices in the input arguments for cuBLAS + # column major: A @ B = C: [m, k] @ [k, n] = [m, n] + # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n] + # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m] + if len(sB) == 2: + if B.stride()[0] == B.shape[1]: + transposed_B = False + elif B.stride()[1] == B.shape[0]: + transposed_B = True + if len(A.shape) == 2: + if A.stride()[0] == A.shape[1]: + transposed_A = False + elif A.stride()[1] == A.shape[0]: + transposed_A = True + else: + if A.stride()[1] == A.shape[2]: + transposed_A = False + elif A.stride()[2] == A.shape[1]: + transposed_A = True + + if len(sA) == 2: + n = sA[0] + ldb = A.stride()[1 if transposed_A else 0] + elif len(sA) == 3 and len(sB) == 2: + n = sA[0] * sA[1] + ldb = sA[2] + + m = sB[1] + k = sB[0] + lda = B.stride()[0] + ldc = sB[1] + elif len(sB) == 3: + # special case + assert len(sA) == 3 + if not (sA[0] == sB[0] and sA[1] == sB[1]): + raise ValueError( + f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}" + ) + + transposed_A = True + transposed_B = False + + m = sB[2] + n = sA[2] + k = sB[0] * sB[1] + + lda = n + ldb = sA[2] + ldc = m + + ptr = CUBLAS_Context.get_instance().get_context(A.device) + + # B^T @ A^T = C^T + # [km, nk -> mn] + #lda = ldb = ldc = 1 + #lda = 1 + if state is not None: + m = Bshape[0] + k = Bshape[1] + lda = Bshape[0] + ldc = Bshape[0] + ldb = (ldb+1)//2 + #print(m, n, k, lda, ldb, ldc) + is_on_gpu([B, A, out]) + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) + + if B.dtype == torch.uint8: + lib.cgemm_4bit_inference(m, n, k, 
get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) + elif A.dtype == torch.float32: + lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) + elif A.dtype == torch.float16: + lib.cgemm_host_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) + else: + raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') + + return out + + + def igemm( A: Tensor, @@ -1852,8 +2197,6 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): ccolsB = ct.c_int32(B.shape[1]) cldb = ct.c_int32(ldb) cldc = ct.c_int32(ldc) - # print(cooA.rowidx[:64]) - # print(cooA.colidx[:64].sort()[0]) is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats]) if B.dtype == torch.float16: @@ -2051,3 +2394,8 @@ def extract_outliers(A, SA, idx): post_call(prev_device) return out + +def pipeline_test(A, batch_size): + out = torch.zeros_like(A) + lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) + return out diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index f51f60078..49d7b5ced 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -2,5 +2,5 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, SwitchBackLinearBnb +from .modules import Int8Params, Linear8bitLt, StableEmbedding, Linear4bit, LinearNF4, LinearFP4, Params4bit, OutlierAwareLinear, SwitchBackLinearBnb from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorwise, StandardLinear diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 1f730e66d..32849212d 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -136,33 +136,100 @@ def forward(self, input: Tensor) -> Tensor: return emb -class OutlierAwareLinear(nn.Linear): - def __init__(self, input_features, output_features, bias=True): +class Params4bit(torch.nn.Parameter): + def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'): + if data is None: + data = torch.empty(0) + + self = torch.Tensor._make_subclass(cls, data, requires_grad) + self.blocksize = blocksize + self.compress_statistics = compress_statistics + self.quant_type = quant_type + self.quant_state = quant_state + self.data = data + return self + + def cuda(self, device): + w = self.data.contiguous().half().cuda(device) + w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) + self.data = w_4bit + self.quant_state = quant_state + + return self + + @overload + def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T: + ... + + @overload + def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: + ... + + @overload + def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: + ... 
+ + def to(self, *args, **kwargs): + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) + + if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"): + return self.cuda(device) + else: + s = self.quant_state + if s is not None: + # make sure the quantization state is on the right device + s[0] = s[0].to(device) + if self.compress_statistics: + # TODO: refactor this. This is a nightmare + # for 4-bit: + # state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type] + # state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type] + #s[-2][0] = s[-2][0].to(device) # offset + #s[-2][1][0] = s[-2][1][0].to(device) # nested absmax + + # for 8-bit + s[-2][0] = s[-2][0].to(device) # offset + s[-2][1][0] = s[-2][1][0].to(device) # nested quantiation state statitics + s[-2][1][1] = s[-2][1][1].to(device) # nested quantiation codebook + new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking), + requires_grad=self.requires_grad, quant_state=self.quant_state, + blocksize=self.blocksize, compress_statistics=self.compress_statistics, + quant_type=self.quant_type) + + return new_param + +class Linear4bit(nn.Linear): + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4'): super().__init__(input_features, output_features, bias) - self.outlier_dim = None - self.is_quantized = False + self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type) + self.compute_dtype = compute_dtype - def forward_with_outliers(self, x, outlier_idx): - raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function') + def forward(self, x: torch.Tensor): + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) - def quantize_weight(self, w, outlier_idx): - raise NotImplementedError('Please override the `quantize_weights(self, w, outlier_idx)` function') + if getattr(self.weight, 'quant_state', None) is None: + print('FP4 quantization state not initialized. 
Please call .cuda() or .to(device) on the LinearFP4 layer first.') + inp_dtype = x.dtype + if self.compute_dtype is not None: + x = x.to(self.compute_dtype) - def forward(self, x): - if self.outlier_dim is None: - tracer = OutlierTracer.get_instance() - if not tracer.is_initialized(): - print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer') - outlier_idx = tracer.get_outliers(self.weight) - #print(outlier_idx, tracer.get_hvalue(self.weight)) - self.outlier_dim = outlier_idx + bias = None if self.bias is None else self.bias.to(self.compute_dtype) + out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state) - if not self.is_quantized: - w = self.quantize_weight(self.weight, self.outlier_dim) - self.weight.data.copy_(w) - self.is_quantized = True + out = out.to(inp_dtype) + + return out + +class LinearFP4(Linear4bit): + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True): + super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4') + +class LinearNF4(Linear4bit): + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True): + super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4') - return self.forward_with_outliers(x, self.outlier_dim) class Int8Params(torch.nn.Parameter): @@ -239,6 +306,7 @@ def to(self, *args, **kwargs): return new_param + class Linear8bitLt(nn.Linear): def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0, index=None): @@ -328,6 +396,32 @@ def forward(self, x: torch.Tensor): return out +class OutlierAwareLinear(nn.Linear): + def __init__(self, input_features, output_features, bias=True): + super().__init__(input_features, output_features, bias) + self.outlier_dim = None + self.is_quantized = False + + def forward_with_outliers(self, x, outlier_idx): + raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function') + + def quantize_weight(self, w, outlier_idx): + raise NotImplementedError('Please override the `quantize_weights(self, w, outlier_idx)` function') + + def forward(self, x): + if self.outlier_dim is None: + tracer = OutlierTracer.get_instance() + if not tracer.is_initialized(): + print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer') + outlier_idx = tracer.get_outliers(self.weight) + #print(outlier_idx, tracer.get_hvalue(self.weight)) + self.outlier_dim = outlier_idx + + if not self.is_quantized: + w = self.quantize_weight(self.weight, self.outlier_dim) + self.weight.data.copy_(w) + self.is_quantized = True + class SwitchBackLinearBnb(nn.Linear): def __init__( self, diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py index 53533ee90..83a57bd9f 100644 --- a/bitsandbytes/optim/__init__.py +++ b/bitsandbytes/optim/__init__.py @@ -6,11 +6,11 @@ from bitsandbytes.cextension import COMPILED_WITH_CUDA from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit -from .adam import Adam, Adam8bit, Adam32bit -from .adamw import AdamW, AdamW8bit, AdamW32bit +from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit +from .adamw import AdamW, AdamW8bit, AdamW32bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit from .lamb import LAMB, LAMB8bit, LAMB32bit from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS from 
.optimizer import GlobalOptimManager from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit -from .lion import Lion, Lion8bit, Lion32bit +from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit from .sgd import SGD, SGD8bit, SGD32bit diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py index 396aeb85f..86981eb86 100644 --- a/bitsandbytes/optim/adam.py +++ b/bitsandbytes/optim/adam.py @@ -14,92 +14,34 @@ class Adam(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - optim_bits, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) - + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Adam8bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 8, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) - + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Adam32bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 32, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) - + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + +class PagedAdam(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedAdam8bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedAdam32bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 
0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) class AnalysisAdam(torch.optim.Optimizer): """Adam that performs 8-bit vs 32-bit error analysis. diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py index 022e64c46..21077f1a0 100644 --- a/bitsandbytes/optim/adamw.py +++ b/bitsandbytes/optim/adamw.py @@ -5,89 +5,35 @@ from bitsandbytes.optim.optimizer import Optimizer2State -class AdamW(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - optim_bits, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) +class AdamW(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) class AdamW8bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 8, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) - + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) class AdamW32bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 32, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + + +class PagedAdamW(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedAdamW8bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + 
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedAdamW32bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py index 2551b68e1..2bde1a447 100644 --- a/bitsandbytes/optim/lion.py +++ b/bitsandbytes/optim/lion.py @@ -4,84 +4,27 @@ # LICENSE file in the root directory of this source tree. from bitsandbytes.optim.optimizer import Optimizer1State - class Lion(Optimizer1State): - def __init__( - self, - params, - lr=1e-4, - betas=(0.9, 0.99), - weight_decay=0, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "lion", - params, - lr, - betas, - 0., - weight_decay, - optim_bits, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) - + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Lion8bit(Optimizer1State): - def __init__( - self, - params, - lr=1e-4, - betas=(0.9, 0.99), - weight_decay=0, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "lion", - params, - lr, - betas, - 0., - weight_decay, - 8, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) - + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Lion32bit(Optimizer1State): - def __init__( - self, - params, - lr=1e-4, - betas=(0.9, 0.99), - weight_decay=0, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "lion", - params, - lr, - betas, - 0., - weight_decay, - 32, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + + +class PagedLion(Optimizer1State): + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedLion8bit(Optimizer1State): + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedLion32bit(Optimizer1State): + def 
__init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index 1adf5d424..fb83eddf0 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -92,10 +92,12 @@ def register_module_override(self, module, param_name, config): class Optimizer8bit(torch.optim.Optimizer): - def __init__(self, params, defaults, optim_bits=32): + def __init__(self, params, defaults, optim_bits=32, is_paged=False): super().__init__(params, defaults) self.initialized = False self.name2qmap = {} + self.is_paged = is_paged + self.page_mng = F.GlobalPageManager.get_instance() self.mng = GlobalOptimManager.get_instance() self.non_castable_tensor_keys = { @@ -207,7 +209,9 @@ def to_gpu(self): values = self.state[p] for k, v in values.items(): if isinstance(v, torch.Tensor): - self.state[p][k] = v.to(p.device) + is_paged = getattr(v, 'is_paged', False) + if not is_paged: + self.state[p][k] = v.to(p.device) def check_overrides(self): for module, attr, config in self.mng.module_weight_config_triple: @@ -252,6 +256,7 @@ def step(self, closure=None): self.to_gpu() # needed for fairseq pure fp16 training self.initialized = True + #if self.is_paged: self.page_mng.prefetch_all() for gindex, group in enumerate(self.param_groups): for pindex, p in enumerate(group["params"]): if p.grad is None: @@ -260,7 +265,14 @@ def step(self, closure=None): if len(state) == 0: self.init_state(group, p, gindex, pindex) + self.prefetch_state(p) self.update_step(group, p, gindex, pindex) + torch.cuda.synchronize() + if self.is_paged: + # all paged operation are asynchronous, we need + # to sync to make sure all tensors are in the right state + torch.cuda.synchronize() + return loss @@ -289,6 +301,26 @@ def update_step(self, group, p, gindex, pindex): "The update_step method needs to be overridden" ) + def get_state_buffer(self, p, dtype=torch.float32): + if not self.is_paged or p.numel() < 1e5: + return torch.zeros_like(p, dtype=dtype, device=p.device) + else: + # > 1 MB + buff = F.get_paged(*p.shape, dtype=dtype, device=p.device) + F.fill(buff, 0) + self.page_mng.paged_tensors.append(buff) + return buff + + def prefetch_state(self, p): + if self.is_paged: + state = self.state[p] + s1 = state['state1'] + is_paged = getattr(s1, 'is_paged', False) + if is_paged: + F.prefetch_tensor(state['state1']) + if 'state2' in state: + F.prefetch_tensor(state['state2']) + class Optimizer2State(Optimizer8bit): def __init__( @@ -306,6 +338,7 @@ def __init__( block_wise=True, max_unorm=0.0, skip_zeros=False, + is_paged=False ): if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") @@ -325,7 +358,7 @@ def __init__( f"Invalid weight_decay value: {weight_decay}" ) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - super().__init__(params, defaults, optim_bits) + super().__init__(params, defaults, optim_bits, is_paged) if args is None: args = {} @@ -365,18 +398,8 @@ def init_state(self, group, p, gindex, pindex): if dtype == torch.float32 or ( dtype == torch.uint8 and p.numel() < 4096 ): - state["state1"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.float32, - device=p.device, - ) - state["state2"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - 
dtype=torch.float32, - device=p.device, - ) + state["state1"] = self.get_state_buffer(p, dtype=torch.float32) + state["state2"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: if state["step"] == 0: if "dynamic" not in self.name2qmap: @@ -388,20 +411,10 @@ def init_state(self, group, p, gindex, pindex): p.device ) - state["state1"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.uint8, - device=p.device, - ) + state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap1"] = self.name2qmap["dynamic"] - state["state2"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.uint8, - device=p.device, - ) + state["state2"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap2"] = self.name2qmap["udynamic"] if config["block_wise"]: @@ -538,6 +551,7 @@ def __init__( block_wise=True, max_unorm=0.0, skip_zeros=False, + is_paged=False ): if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") @@ -553,7 +567,7 @@ def __init__( f"Invalid weight_decay value: {weight_decay}" ) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - super().__init__(params, defaults, optim_bits) + super().__init__(params, defaults, optim_bits, is_paged) if args is None: args = {} @@ -593,12 +607,7 @@ def init_state(self, group, p, gindex, pindex): if dtype == torch.float32 or ( dtype == torch.uint8 and p.numel() < 4096 ): - state["state1"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.float32, - device=p.device, - ) + state["state1"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: if state["step"] == 0: if "dynamic" not in self.name2qmap: @@ -607,12 +616,7 @@ def init_state(self, group, p, gindex, pindex): p.device ) - state["state1"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.uint8, - device=p.device, - ) + state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap1"] = self.name2qmap["dynamic"] if config["block_wise"]: diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 946c2ab2d..6b97b24fd 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -171,3 +171,42 @@ def get_cuda_devices(): def is_cuda_device(device): return device in get_cuda_devices() + + +def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None): + """ + Replace linear modules with a new Linear module. + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + linear_replacement (`torch.nn.Module`): + The linear module that replaces the old one. Only expects standard arguments. + If other arguments need to be passed, use a lambda. + skip_modules (`List[str]`, *optional*, defaults to `lm_head`): + List of modules names not to convert. Defaults to `lm_head`. + copy_weights (`bool`): + Copy the weights from the old linear module to the new one + post_processing_fun_name (`str`): + A function name of the replacement linear class that is called + after processing. 
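A minimal usage sketch for the replace_linear helper introduced in this hunk. It is illustrative only: bnb.nn.LinearFP4 is used as one possible replacement class (any module with a Linear-like (in_features, out_features, bias) constructor works), and the toy model is hypothetical.

import torch
import bitsandbytes as bnb
from bitsandbytes.utils import replace_linear

# Every torch.nn.Linear whose name is not in skip_modules is swapped for
# linear_replacement(in_features, out_features, bias is not None).
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 128),
)
model = replace_linear(model, bnb.nn.LinearFP4, skip_modules=["lm_head"])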
+ """ + for name, module in model.named_children(): + if len(list(module.children())) > 0: + replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function) + + if isinstance(module, torch.nn.Linear) and name not in skip_modules: + old_module = model._modules[name] + model._modules[name] = linear_replacement( + module.in_features, + module.out_features, + module.bias is not None, + ) + if copy_weights: + model._modules[name].weight = old_module.weight + model._modules[name].bias = old_module.bias + + if post_processing_function is not None: + func = getattr(module, post_processing_function, None) + if func is not None: func(module) + return model + diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 10c4eaaf5..33163fd68 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -12,12 +12,17 @@ #include #include #include +#include +#include +#include + #define HLF_MAX 65504 #define TH 1024 #define NUM 4 #define NUM_BLOCK 4096 + // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda __device__ float atomicMax(float* address, float val) { int* address_as_i = reinterpret_cast(address); @@ -43,11 +48,289 @@ __device__ float atomicMin(float* address, float val) { return __int_as_float(old); } +__device__ float dDequantizeFP4(unsigned char val, float absmax) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0110) == 0) + { + // subnormal + if((val & 0b0001) == 0) + return 0.0f; + else + return sign*0.0625f*absmax; + } + else + { + // normal + float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f); + float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f; + + return sign*exponent*fraction*absmax; + } +} + +__device__ float d2DequantizeFP4(unsigned char val) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0110) == 0) + { + // subnormal + if((val & 0b0001) == 0) + return 0.0f; + else + return sign*0.0625f; + } + else + { + // normal + float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f); + float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f; + + return sign*exponent*fraction; + } +} + +__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 111 + return 0.25000000f*absmax*sign; // 1111 + else + return 0.16666667f*absmax*sign; // 1110 + else + if((val & 0b0001) == 1) // 110 + return 0.50000000f*absmax*sign; // 1101 + else + return 0.33333333f*absmax*sign; // 1100 + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 1.00000000f*absmax*sign; // 1011 + else + return 0.66666667f*absmax*sign; // 1010 + else + if((val & 0b0001) == 1) // 100 + return 5.208333333e-03f*absmax*sign; // 1001 + else + return 0.00000000f*absmax*sign; // 1000 +} + +__device__ unsigned char dQuantizeFP4(float x) +{ + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + + + // we do a binary search + // the pivots are divided by 12 (the FP4 absmax) + // since we assum input data is in [-1.0, 1.0] + + // !be careful here, its easy to make a mistake + // that is difficult to noice if you add an extra + // zero somewhere! + + int sign = x < 0 ? 
0b1000 : 0b0000; + x = fabsf(x); + if(x > 0.29166667f) + if( x > 0.583333f) + if( x > 0.8333333f) + return 0b0011+sign; + else + return 0b0010+sign; + else + if(x > 0.4166667f) + return 0b101+sign; + else + return 0b100+sign; + else + if(x > 0.0859375f) + if(x > 0.20833333f) + return 0b0111+sign; + else + return 0b0110+sign; + else + if(x > 0.00260417f) + return 0b0001+sign; + else + return 0b0000+sign; +} + +__device__ half dhDequantizeNF4(unsigned char val) +{ + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if((val & 0b1000) == 8) + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) // 11 + if((val & 0b0001) == 1) // 111 + return 1.0f; + else + return 0.7229568362236023f; + else + if((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; + else + return 0.44070982933044434f; + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; + else + return 0.24611230194568634f; + else + if((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; + else + return 0.07958029955625534f; + + else + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 011 + return 0.0f; + else + return -0.09105003625154495f; + else + if((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; + else + return -0.28444138169288635f; + else + if((val & 0b0010) == 2) //00 + if((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; + else + return -0.5250730514526367f; + else + if((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; + else + return -1.0f; + +} + +__device__ float dDequantizeNF4(unsigned char val) +{ + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if((val & 0b1000) == 8) + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) // 11 + if((val & 0b0001) == 1) // 111 + return 1.0f; + else + return 0.7229568362236023f; + else + if((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; + else + return 0.44070982933044434f; + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; + else + return 0.24611230194568634f; + else + if((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; + else + return 0.07958029955625534f; + + else + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 011 + return 0.0f; + else + return -0.09105003625154495f; + else + if((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; + else + return -0.28444138169288635f; + else + if((val & 0b0010) == 2) //00 + if((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; + else + return -0.5250730514526367f; + else + if((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; + else + return -1.0f; + +} + +__device__ unsigned char dQuantizeNF4(float x) +{ + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if(x > 0.03979014977812767f) + if(x > 0.3893125355243683f) // 1 + if(x > 0.6427869200706482f) // 11 + if(x > 0.8614784181118011f) // 111 + return 0b1111; + else + return 0b1110; + else + if(x > 0.5016634166240692f) // 110 + return 0b1101; + else + return 0b1100; + else + if(x > 0.2035212516784668f) // 10 + if(x > 0.2920137718319893f) // 101 + return 0b1011; + else + return 0b1010; + else + if(x > 0.1202552504837513f) // 100 + return 0b1001; + else + return 0b1000; + else + if(x > -0.33967943489551544f) // 0 + if(x > -0.13791173323988914f) // 01 + if(x > 
-0.045525018125772476f) // 011 + return 0b0111; + else + return 0b0110; + else + if(x > -0.23460740596055984f) // 010 + return 0b0101; + else + return 0b0100; + else + if(x > -0.6106329262256622f) // 00 + if(x > -0.4599952697753906f) // 001 + return 0b0011; + else + return 0b0010; + else + if(x > -0.8480964004993439f) // 000 + return 0b0001; + else + return 0b0000; +} // sign function for lion // taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA -template -__device__ int sgn(T val) { +template __device__ int sgn(T val) +{ return (T(0) < val) - (val < T(0)); } @@ -435,7 +718,7 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c } } -template +template //__launch_bounds__(TH, 4) __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n) { @@ -445,13 +728,13 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float T vals[NUM_PER_TH]; float rand_vals[NUM_PER_TH]; - unsigned char qvals[NUM_PER_TH]; + unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH]; //float local_abs_max = -FLT_MAX; float local_abs_max = 0.0f; int local_rand_idx = 0; typedef cub::BlockLoad LoadT; - typedef cub::BlockStore StoreChar; + typedef cub::BlockStore 0) ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar; typedef cub::BlockReduce BlockReduce; typedef cub::BlockLoad LoadFloat; @@ -462,8 +745,9 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float __shared__ float smem_code[256]; __shared__ float smem_absmax_value[1]; - for(int i = threadIdx.x; i < 256; i+=blockDim.x) - smem_code[i] = code[i]; + if(DATA_TYPE == General8bit) + for(int i = threadIdx.x; i < 256; i+=blockDim.x) + smem_code[i] = code[i]; for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) { @@ -503,64 +787,111 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); } - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) + unsigned char packed_4bit = 0; + switch(DATA_TYPE) { - if(!STOCHASTIC) - qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); - else - qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); + case General8bit: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + if(!STOCHASTIC) + qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); + else + qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); + } + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; } __syncthreads(); - StoreChar(storec).Store(&(out[i]), qvals, valid_items); + StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? 
(valid_items+1)/2 : valid_items); } } -template -__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n) +template +__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n) { - const int n_full = gridDim.x * BLOCK_SIZE; - int valid_items = 0; - const int base_idx = (blockIdx.x * BLOCK_SIZE); + const int n_load = (gridDim.x * TILE_SIZE); + int valid_items_load = 0; + int valid_items_store = 0; + const int base_idx = (blockIdx.x * TILE_SIZE); - T vals[NUM_PER_TH]; + T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)]; unsigned char qvals[NUM_PER_TH]; float local_abs_max = -FLT_MAX; typedef cub::BlockLoad LoadChar; - typedef cub::BlockStore StoreT; + typedef cub::BlockStore 0) ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; __shared__ typename LoadChar::TempStorage loadchar; __shared__ typename StoreT::TempStorage storet; - //__shared__ float smem_code[256]; - //float local_code[16]; - //if(threadIdx.x < 256) - //smem_code[threadIdx.x] = code[threadIdx.x]; - - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) { - valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; - local_abs_max = absmax[i/BLOCK_SIZE]; + if(DATA_TYPE > 0) + { + valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i; + valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2; + } + else + { + valid_items_load = n - i > TILE_SIZE ? TILE_SIZE : n - i; + valid_items_store = n - i > TILE_SIZE ? TILE_SIZE : n - i; + } + local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]); - __syncthreads(); - LoadChar(loadchar).Load(&(A[i]), qvals, valid_items, 128); + __syncthreads(); + LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); - // load code through read-only cache via __ldg - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) - { - vals[j] = __ldg(&code[qvals[j]])*local_abs_max; - } + switch(DATA_TYPE) + { + case General8bit: + // load code through read-only cache via __ldg + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + vals[j] = __ldg(&code[qvals[j]])*local_abs_max; + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); + vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max; + vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max; + } + break; + } - __syncthreads(); - StoreT(storet).Store(&(out[i]), vals, valid_items); + __syncthreads(); + StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? 
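For intuition about the storage format these blockwise 4-bit kernels produce, here is a small pure-Python sketch. It is not library API, only an illustration grounded in the device code above: the FP4/NF4 paths pack two 4-bit codes per output byte, high nibble first, and dequantization maps each nibble through a 16-entry table (here the NF4 table from dDequantizeNF4) and scales by the per-block absmax.

# Illustration only: mirrors the byte layout used by the 4-bit paths of
# kQuantizeBlockwise / kDequantizeBlockwise (two codes per byte).
NF4_CODE = [
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
]

def pack_pair(code_first, code_second):
    # first value of the pair goes into the high nibble, second into the low nibble
    return (code_first << 4) | code_second

def unpack_and_dequantize(byte, absmax):
    hi, lo = byte >> 4, byte & 0x0F
    return NF4_CODE[hi] * absmax, NF4_CODE[lo] * absmax

packed = pack_pair(0b1111, 0b0000)                 # codes for +1.0 and -1.0
print(unpack_and_dequantize(packed, absmax=2.5))   # -> (2.5, -2.5)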
i*2 : i]), vals, valid_items_store); } } - __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n) { const unsigned int numThreads = blockDim.x * gridDim.x; @@ -1462,6 +1793,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char unsigned char c1s[N_PER_TH]; unsigned char c2s[N_PER_TH]; T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; typedef cub::BlockLoad LoadT; typedef cub::BlockLoad LoadChar; @@ -1523,16 +1855,24 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) { - g_val = float(g_vals[j]); - g_val *= gnorm_scale; - if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) { - s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; - s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); - s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE]; + g_val = g_vals[j]; + //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); + //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; + g_val *= gnorm_scale; + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); } + else + { + s1_vals[j] = 0.0f; + s2_vals[j] = 0.0f; + } new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j])); @@ -1563,22 +1903,23 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char } __syncthreads(); - LoadT(temp_storage.loadh).Load(&(p[i]), g_vals, valid_items, (T)0.0f); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); // reduce: 2.67/1.69 -> 2.67/1.70 # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) { - if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + //if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) { - g_vals[j] = (T)(((float)g_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); + p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); if(weight_decay > 0.0f) - g_vals[j] = ((float)g_vals[j])*(1.0f-(lr*weight_decay)); + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); } } // store: 0.85/1.44 -> 2.48/1.57 __syncthreads(); - StoreT(temp_storage.storeh).Store(&(p[i]), g_vals, valid_items); + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); // quantizaztion: 2.67/1.70 -> 3.4/3.3 # pragma unroll N_PER_TH @@ -2498,7 +2839,7 @@ template @@ -2577,7 +2918,6 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o // 4. 
Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j; if(idx >= colsB){ break; } - //printf("%i %i\n", (row_offset+idx) % num_items, row_offset+idx); if((idx+num_items < colsB)) { if(BITS == 8) @@ -2597,15 +2937,13 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o #pragma unroll num_items for(int k = 0; k < num_items; k++) { - //if((float)local_valsB[k] != 0.0) - // printf("%f %i %i %i\n", (float)local_valsB[k], k, idx, colsB); if(BITS == 8 && dequant_stats != NULL) // we do texture cache reads (__ldg) on dequant_stats which should be super fast { float valB = local_valsB[k]; float valA = local_valA[i]; if(valB != 0.0 && valA != 0.0) - local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*C*valB*valA; + local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*DENORM*valB*valA; } else local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i]; @@ -2711,10 +3049,587 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * } } + +//template __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB) +//{ +//// element-wise kernel +//// 1. Load batch x k into registers +//// 2. Load k x k into registers +//// 3. dequantize and store in second pair of k x k +//// 4. matmul +//// 5. sum with cub +//// 6. store outputs +//// TC kernel +//// use k warps per thread block +//// 1. threadblock use read-only cache to read in register tile for A into shared memory +//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments +//// 3. each warp reads a segment of values 16x32 from B +//// 4. do dequantization from register of B into second pair of registers +//// 5. store (4) into fragment +//// 6. matmul aggregate into fragment C +//// 7. aggreecate files of C into shared memroy block C +//// 8. sum (7) +//// 9. 
write outputs to matmul output matrix +//} + +template __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit, float zero_value = 0.0f) +{ + if(limit_base + ITEMS <= limit) + reinterpret_cast(local)[0] = reinterpret_cast(buffer)[idx/ITEMS]; + else + { + for(int k = 0; k < ITEMS; k++) + { + if(limit_base + k < limit) + local[k] = buffer[idx+k]; + else + local[k] = (T)zero_value; + } + } +} + +#define WARPS 5 +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) +{ + +#if __CUDA_ARCH__ >= 750 + using namespace nvcuda; + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; + const int val_per_iter = blockDim.x-32; + + T local_A[4]; + T local_B[128]; + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + //__shared__ T smem_C[8*32]; + + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0f); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 
1 : 0; + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + + __syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + //local_A[0] = A[idx]; + + //#pragma unroll 32 + //for(int col = 0; col < 32; col++) + // local_B[col] = B[(col_offset+col)*ldb+idx]; + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 
1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + // 129 mu + if(warp_id == (WARPS-1)) + wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); + + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_A[warp_lane]; +#endif +} + +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + +#if __CUDA_ARCH__ >= 750 + using namespace nvcuda; + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; + + T local_A[2]; + T local_B[64]; + unsigned char local_B_4bit[32]; + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + //__shared__ T smem_C[8*32]; + + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0f); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); + local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); + } + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 
1 : 0; + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + + __syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx]; + } + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + int absidx = (idx + col_offset)/blocksize; + half local_absmax = __ldg(&(absmax[absidx])); + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); + local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); + } + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + // 129 mu + if(warp_id == (WARPS-1)) + wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); + + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_A[warp_lane]; +#endif +} + +//#define ROWS 2 +//template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc) +//{ +//// 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp +//// 1. Load dataB into register +//// 2. Dequantize B +//// 3. 
Fetch data from A and multiply +// +// typedef cub::BlockLoad LoadA; +// //__shared__ typename LoadA::TempStorage loada; +// typedef cub::BlockLoad LoadB; +// //__shared__ typename LoadB::TempStorage loadb; +// typedef cub::BlockReduce BlockReduce; +// // Allocate shared memory for BlockReduce +// //__shared__ typename BlockReduce::TempStorage reduce; +// +// __shared__ union { +// typename BlockReduce::TempStorage reduce; +// typename LoadB::TempStorage loadb; +// typename LoadA::TempStorage loada; +// } temp_storage; +// +// +// T dataA[ITEMS]; +// T local_B[ITEMS]; +// T local_accC[ROWS]; +// int valid_items = 0; +// const int col_offset = blockIdx.x * 8; +// +// __shared__ T tileA[ROWS*THREADS*ITEMS]; +// __shared__ T accumulatorC[ROWS*8]; +// +// //#pragma unroll 8 +// //for(int i = 0; i < 8; i++) +// // tileA[threadIdx.x + (i*256)] = 0.0f; +// //__syncthreads(); +// if(threadIdx.x < 64) +// accumulatorC[threadIdx.x] = 0.0f; +// __syncthreads(); +// +// +// for(int inner_idx = 0; inner_idx < K; inner_idx+= THREADS*ITEMS) +// { +// valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx; +// int baserow = 0; +// for(int row = baserow; row < (baserow+ROWS) && row < N; row++) +// { +// LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f); +// +// #pragma unroll ITEMS +// for(int k = 0; k < ITEMS; k++) +// tileA[row*THREADS*ITEMS + threadIdx.x + (k*THREADS)] = dataA[k]; +// +// __syncthreads(); +// } +// baserow += ROWS; +// +// // load 16 columns from B at a time. B is transposed, so its like loading rows +// // each warp loads one row +// // each thread loads 128 byte +// +// // col: inner_idx + warp_lane +// // row: ldb*(offset + warp_id) +// for(int col = 0; col < 8 && (col_offset + col) < M; col++) +// { +// int colB = col_offset + col; +// +// for(int k = 0; k < ROWS; k++) +// local_accC[k] = 0.0f; +// +// int base_idxB = ldb*colB; +// valid_items = K - inner_idx > THREADS*ITEMS ? 
THREADS*ITEMS : K - inner_idx; +// LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f); +// __syncthreads(); +// +// for(int row = 0; row < ROWS && row < N; row++) +// { +// #pragma unroll ITEMS +// for(int k = 0; k < ITEMS; k++) +// { +// int idxA = row*THREADS*ITEMS + threadIdx.x + (THREADS*k); +// local_accC[row] += tileA[idxA]*local_B[k]; +// } +// +// local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum()); +// if(threadIdx.x == 0) +// atomicAdd(&accumulatorC[row*8 + col], local_accC[row]); +// } +// } +// } +// +// for(int row = 0; row < ROWS && row < N; row++) +// { +// int out_idx = ldc*row + col_offset; +// +// //if(threadIdx.x < 8) +// // if(accumulatorC[row*8 + threadIdx.x] != 0.0) +// // printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x); +// +// if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M) +// { +// //printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx); +// out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x]; +// } +// } +// +// +// +//} + + +template __global__ void kfunc(T *A, T *B, T value, long n) +{ + for(long i = (blockDim.x*blockIdx.x) + threadIdx.x; i < n; i+=(blockDim.x*gridDim.x)) + { + switch(FUNC) + { + case FILL: + A[i] = (T)value; + break; + case ARANGE: + A[i] = (T)i; + break; + case _MUL: + A[i] = A[i]*B[i]; + break; + } + } +} + + //============================================================== // TEMPLATE DEFINITIONS //============================================================== +template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(unsigned char *A, unsigned char *B, unsigned char value, long n); +template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(float *A, float *B, float value, long n); + +// these are not used and make no sense, but the compiler needs them +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +// these are not used and make no sense, but the compiler needs them + +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void 
gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); + +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); + template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); @@ -2755,6 +3670,7 @@ MAKE_PreconditionOptimizer32bit1State(RMSPROP, half) MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) MAKE_PreconditionOptimizer32bit1State(LION, half) MAKE_PreconditionOptimizer32bit1State(LION, float) +MAKE_PreconditionOptimizer32bit1State(LION, __nv_bfloat16) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) @@ -2768,6 +3684,7 @@ MAKE_Optimizer32bit1State(RMSPROP, half) MAKE_Optimizer32bit1State(RMSPROP, float) MAKE_Optimizer32bit1State(LION, half) MAKE_Optimizer32bit1State(LION, float) +MAKE_Optimizer32bit1State(LION, __nv_bfloat16) MAKE_Optimizer32bit1State(ADAGRAD, half) MAKE_Optimizer32bit1State(ADAGRAD, float) @@ -2777,12 +3694,15 @@ template __global__ void kPreconditionOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, +template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, 
const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); #define MAKE_PreconditionStatic8bit1State(oname, gtype) \ @@ -2853,51 +3773,60 @@ MAKE_optimizerStatic8bit2State(ADAM, float) template __global__ void kPercentileClipping(float * __restrict__ g, float *gnorm_vec, int step, const int n); template __global__ void kPercentileClipping(half * __restrict__ g, float *gnorm_vec, int step, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const 
int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); - -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template 
__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); - - +#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ +template __global__ void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \ + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) + +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int 
blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); #define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ template __global__ void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \ @@ -2910,6 +3839,8 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise( \ @@ -2927,5 +3858,6 @@ MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8) diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index a8aa3fca5..30faf4a80 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -9,13 +9,15 @@ #ifndef kernels #define kernels +//template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); + template__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n); __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n); __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n); -template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n); +template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n); template __global__ void kPreconditionOptimizer32bit2State(T* g, T* p, @@ -120,4 +122,9 @@ template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kfunc(T *A, T *B, T value, long n); + #endif diff --git a/csrc/ops.cu b/csrc/ops.cu index 86af18133..9c042fa66 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -50,52 +50,53 @@ void dequantize(float *code, unsigned char *A, float *out, int n) CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) { int num_blocks 
= n/blocksize; num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; if(blocksize == 4096) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 2048) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 1024) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 512) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 256) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 128) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 64) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) { int num_blocks = n/blocksize; num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; - if(blocksize == 4096) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 2048) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 1024) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 512) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 256) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 128) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 64) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); + int tile_size = (DATA_TYPE > 0) ? 
1024 : 512; + + if(DATA_TYPE > 0) + kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n); + else + kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } + +//void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB) +//{ +// int num_blocks = (colsB+32-1)/32; +// kMatmul_inference_4bit<<>>(A, B, out, lda, ldb, rowsA, colsA, colsB); +// CUDA_CHECK_RETURN(cudaPeekAtLastError()); +//} + + template void optimizer32bit(T* g, T* p, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, @@ -681,10 +682,73 @@ template void extractOutliers(char * A, int *idx, char *out, int id CUDA_CHECK_RETURN(cudaPeekAtLastError()); } + + + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) +{ + + int num_blocks = (m+31)/32; + + //cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; + + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; + //if(bits == 32) + //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + if(bits == 16) + //gemm_device<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + gemm_device<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); +} + +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + int num_blocks = (m+31)/32; + + //cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; + + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; + kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); +} + +template void func(T *A, T *B, T value, long n) +{ + int threads = 512; + int blocks = n/threads; + blocks = n % threads == 0 ? blocks : blocks + 1; + blocks = blocks > 65535 ? 
65535 : blocks; + kfunc<<>>(A, B, value, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + //============================================================== // TEMPLATE DEFINITIONS //============================================================== +template void func(float *A, float *B, float value, long n); +template void func(unsigned char *A, unsigned char *B, unsigned char value, long n); +template void func(float *A, float *B, float value, long n); +template void func(float *A, float *B, float value, long n); + +template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +//template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); +template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); @@ -708,12 +772,20 @@ template void transformRowToFormat(char * A, char *out, int rows, template void estimateQuantiles(half *A, float *code, float offset, int n); template void estimateQuantiles(float *A, float *code, float offset, int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float 
*code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); #define MAKE_optimizer32bit(name, gtype) \ template void optimizer32bit(gtype* g, gtype* p, \ @@ -723,12 +795,14 @@ template void optimizer32bit(gtype* g, gtype* p, \ MAKE_optimizer32bit(ADAM, half) MAKE_optimizer32bit(ADAM, float) +MAKE_optimizer32bit(ADAM, __nv_bfloat16) MAKE_optimizer32bit(MOMENTUM, half) MAKE_optimizer32bit(MOMENTUM, float) MAKE_optimizer32bit(RMSPROP, half) MAKE_optimizer32bit(RMSPROP, float) MAKE_optimizer32bit(LION, half) MAKE_optimizer32bit(LION, float) +MAKE_optimizer32bit(LION, __nv_bfloat16) MAKE_optimizer32bit(ADAGRAD, half) MAKE_optimizer32bit(ADAGRAD, float) @@ -764,8 +838,11 @@ MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); MAKE_optimizerStatic8bitBlockwise(half, LION); MAKE_optimizerStatic8bitBlockwise(float, LION); +MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, LION); MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); + +MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 7351e3261..ff342b5bc 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -22,6 +22,11 @@ #include #include +#include +#include + + + #define CUDA_CHECK_RETURN(value) { \ cudaError_t _m_cudaStat = value; \ if (_m_cudaStat != cudaSuccess) { \ @@ -84,6 +89,20 @@ typedef enum Transform_t COL_AMPERE = 4, } Transform_t; +typedef enum DataType_t +{ + General8bit = 0, + FP4 = 1, + NF4 = 2, +} DataType_t; + +typedef enum Funcs_t +{ + FILL = 0, + ARANGE = 1, + _MUL = 2, +} Funcs_t; + class Context { public: @@ -131,8 +150,8 @@ template void estimateQuantiles(T *A, float *code, float offset, in void quantize(float *code, float *A, unsigned char *out, int n); void dequantize(float *code, unsigned char *A, float *out, int n); -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); template void optimizer32bit(T* g, T* p, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, @@ -179,4 +198,11 @@ template void spmm_coo_very_sparse_naive(int *max_count, template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); +void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, 
int ldc, int bits); +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); + +template void func(T *A, T *B, T value, long n); + #endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index bef02e82f..8b6b27d58 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -23,8 +23,25 @@ void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimat void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } +//void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) +//{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 32); } +void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc) +{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 16); } + +void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) +{ gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + +#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ +void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func(A, B, value, n); } \ + +MAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL) +MAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL) +MAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE) +MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL) + + #define MAKE_FUNC32(fname, oname, gtype, gbits) \ -void fname##32bit_g##gbits(gtype *g, gtype *p, \ +void fname##32bit_grad_##gbits(gtype *g, gtype *p, \ float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \ @@ -32,17 +49,19 @@ void fname##32bit_g##gbits(gtype *g, gtype *p, \ MAKE_FUNC32(momentum, MOMENTUM, float, 32) MAKE_FUNC32(momentum, MOMENTUM, half, 16) -MAKE_FUNC32(adam, ADAM, float, 32) -MAKE_FUNC32(adam, ADAM, half, 16) +MAKE_FUNC32(adam, ADAM, float, fp32) +MAKE_FUNC32(adam, ADAM, half, fp16) +MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16) MAKE_FUNC32(rmsprop, RMSPROP, float, 32) MAKE_FUNC32(rmsprop, RMSPROP, half, 16) -MAKE_FUNC32(lion, LION, float, 32) -MAKE_FUNC32(lion, LION, half, 16) +MAKE_FUNC32(lion, LION, float, fp32) +MAKE_FUNC32(lion, LION, half, fp16) +MAKE_FUNC32(lion, LION, __nv_bfloat16, bf16) MAKE_FUNC32(adagrad, ADAGRAD, float, 32) MAKE_FUNC32(adagrad, ADAGRAD, half, 16) #define MAKE_FUNC8(fname, oname, gtype, gbits) \ -void fname##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ +void fname##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ float *unorm, float max_unorm, float param_norm, \ float beta1, float beta2, \ float eps, int step, float lr, \ @@ -64,33 +83,42 @@ MAKE_FUNC8(lion, LION, float, 32) MAKE_FUNC8(lion, LION, half, 16) #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \ -void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \ +void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\ { optimizerStatic8bitBlockwise(p, g, state1, 
state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\ -MAKE_BLOCKWISE8(adam, ADAM, half, 16) -MAKE_BLOCKWISE8(adam, ADAM, float, 32) -MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16) -MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32) -MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16) -MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32) -MAKE_BLOCKWISE8(lion, LION, half, 16) -MAKE_BLOCKWISE8(lion, LION, float, 32) -MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16) -MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32) +MAKE_BLOCKWISE8(adam, ADAM, half, fp16) +MAKE_BLOCKWISE8(adam, ADAM, float, fp32) +MAKE_BLOCKWISE8(momentum, MOMENTUM, half, fp16) +MAKE_BLOCKWISE8(momentum, MOMENTUM, float, fp32) +MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16) +MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32) +MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16) +MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32) +MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) +MAKE_BLOCKWISE8(lion, LION, half, fp16) +MAKE_BLOCKWISE8(lion, LION, float, fp32) +MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16) void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } -void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } -void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } -void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, blocksize, n); } -void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, blocksize, n); } +void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } + +void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } \ +void dequantizeBlockwise_fp32(float *code, unsigned char *A, 
float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } +void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } \ +void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } +void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } \ +void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } -void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } \ -void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } #define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \ @@ -151,32 +179,41 @@ extern "C" void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); } void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } - void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, blocksize, n); } - void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, blocksize, n); } void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } + 
void cquantize_blockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } + #define MAKE_CFUNC32(name, gtype, gbits) \ - void c##name##32bit_g##gbits(gtype *g, gtype *p, \ + void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \ float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \ - { name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ + { name##32bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ - MAKE_CFUNC32(adam, float, 32) - MAKE_CFUNC32(adam, half, 16) + MAKE_CFUNC32(adam, float, fp32) + MAKE_CFUNC32(adam, half, fp16) + MAKE_CFUNC32(adam, __nv_bfloat16, bf16) MAKE_CFUNC32(momentum, float, 32) MAKE_CFUNC32(momentum, half, 16) MAKE_CFUNC32(rmsprop, float, 32) MAKE_CFUNC32(rmsprop, half, 16) - MAKE_CFUNC32(lion, float, 32) - MAKE_CFUNC32(lion, half, 16) + MAKE_CFUNC32(lion, float, fp32) + MAKE_CFUNC32(lion, half, fp16) + MAKE_CFUNC32(lion, __nv_bfloat16, bf16) MAKE_CFUNC32(adagrad, float, 32) MAKE_CFUNC32(adagrad, half, 16) #define MAKE_CFUNC8(name, gtype, gbits) \ - void c##name##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ + void c##name##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ float *unorm, float max_unorm, float param_norm, \ float beta1, float beta2, \ float eps, int step, float lr, \ @@ -184,7 +221,7 @@ extern "C" float* max1, float* max2, float* new_max1, float* new_max2, \ float weight_decay, float gnorm_scale, int n) \ { \ - name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \ + name##_static_8bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \ quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \ } \ @@ -198,22 +235,23 @@ extern "C" MAKE_CFUNC8(lion, half, 16) #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \ - void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \ + void c##fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \ - { fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \ - - 
MAKE_CBLOCKWISE8(adam, ADAM, half, 16) - MAKE_CBLOCKWISE8(adam, ADAM, float, 32) - MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, 16) - MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32) - MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16) - MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32) - MAKE_CBLOCKWISE8(lion, LION, half, 16) - MAKE_CBLOCKWISE8(lion, LION, float, 32) - MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16) - MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32) - + { fname##_8bit_blockwise_grad_##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \ + + MAKE_CBLOCKWISE8(adam, ADAM, half, fp16) + MAKE_CBLOCKWISE8(adam, ADAM, float, fp32) + MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16) + MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32) + MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16) + MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32) + MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16) + MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32) + MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) + MAKE_CBLOCKWISE8(lion, LION, half, fp16) + MAKE_CBLOCKWISE8(lion, LION, float, fp32) + MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16) void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); } void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); } @@ -301,6 +339,38 @@ extern "C" void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); } void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); } + //void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) + //{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); } + + void cgemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc) + { gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); } + + void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) + { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + + void *cget_managed_ptr(size_t bytes) + { + void *ptr; + CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost)); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + + return ptr; + } + + void cprefetch(void *ptr, size_t bytes, int device) + { + CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0)); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + } + + #define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ + void c##fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ fname##_##type_name(A, B, value, n); } \ + + CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL) + CMAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL) + CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE) + CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL) + #endif void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } diff --git a/deploy.sh b/deploy.sh index 24d6cbf6b..a2257a2bb 100644 --- a/deploy.sh +++ b/deploy.sh @@ -139,17 +139,6 @@ if [ ! 
-f "./bitsandbytes/libbitsandbytes_cuda121.so" ]; then fi -make clean -export CUDA_HOME=$BASE_PATH/cuda-10.2 -make cuda10x_nomatmul CUDA_VERSION=102 - -if [ ! -f "./bitsandbytes/libbitsandbytes_cuda102_nocublaslt.so" ]; then - # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 - exit 64 -fi - - make clean export CUDA_HOME=$BASE_PATH/cuda-11.0 make cuda110_nomatmul CUDA_VERSION=110 diff --git a/setup.py b/setup.py index b6e543c97..21b0f47cf 100644 --- a/setup.py +++ b/setup.py @@ -23,10 +23,10 @@ def has_ext_modules(foo): setup( name=f"bitsandbytes", - version=f"0.38.1", + version=f"0.39.0", author="Tim Dettmers", author_email="dettmers@cs.washington.edu", - description="8-bit optimizers and matrix multiplication routines.", + description="k-bit optimizers and matrix multiplication routines.", license="MIT", keywords="gpu optimizers optimization 8-bit quantization compression", url="https://github.com/TimDettmers/bitsandbytes", diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 3e4b2b8e2..0da7aa5a8 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -99,7 +99,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_allclose( + torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1 ) if req_grad[1]: @@ -108,7 +108,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): assert (idx == 0).sum().item() < n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() < n * 0.02 - torch.testing.assert_allclose( + torch.testing.assert_close( gradB1, gradB2, atol=0.18, rtol=0.3 ) @@ -137,7 +137,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): n = out_bnb.numel() idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) assert (idx == 0).sum().item() < n * 0.01 - torch.testing.assert_allclose( + torch.testing.assert_close( out_bnb, out_torch, atol=0.027, rtol=0.2 ) @@ -161,7 +161,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_allclose( + torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1 ) if req_grad[1]: @@ -220,7 +220,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_allclose( + torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1 ) if req_grad[1]: @@ -409,7 +409,7 @@ def test_matmullt( bias.grad = None if req_grad[0]: - torch.testing.assert_allclose( + torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1 ) if req_grad[1]: @@ -425,13 +425,12 @@ def test_matmullt( assert (idx == 0).sum().item() <= n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() <= n * 0.02 - torch.testing.assert_allclose( + torch.testing.assert_close( gradB1, gradB2, atol=0.18, rtol=0.3 ) if req_grad[2]: - torch.testing.assert_allclose(gradBias1, gradBias2) - + torch.testing.assert_close(gradBias1, gradBias2) n = 1 @@ -443,6 +442,99 @@ def test_matmullt( dim2.append(0) +funcs = [(torch.matmul, bnb.matmul_4bit)] +str_funcs = ["matmul"] +req_grad = list(product([True, False], repeat=3)) +req_grad_str = [] +for c in req_grad: + strval = '' + for v in c: + if v == True: strval += 'T' + else: strval += 'F' + req_grad_str.append(strval) + +transpose = [(False, True), (False, False)] +str_transpose = ["NT", "NN"] +dtype = 
[torch.float16, torch.float32] +compress_statistics = [False, True] +has_fp16_weights = [True, False] +has_bias = [True, False] +quant_type = ['fp4', 'nf4'] +values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type)) +str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics, quant_type)) +names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics_{}_quant_type_{}".format(*vals) for vals in str_values] +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type", values, ids=names) +def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type): + dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) + dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) + if has_bias == False: + req_grad = list(req_grad) + req_grad[2] = False + + for i in range(k): + # normal multiply + if funcs[0] in [torch.mm, torch.matmul]: + A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) + B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype) + target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype) + bias = None + bias2 = None + if has_bias: + bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2]) + bias2 = bias.clone() + torch.nn.init.xavier_uniform_(B) + + B2, quant_state = bnb.functional.quantize_4bit(B, compress_statistics=compress_statistics, quant_type=quant_type) + + if not transpose[0] and transpose[1]: + out_torch = funcs[0](A, B.t()) + out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2) + elif not transpose[0] and not transpose[1]: + out_torch = funcs[0](A, B) + out_bnb = funcs[1](A, B2, quant_state, bias=bias2) + + if has_bias: + out_torch += bias + + assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}" + + n = out_bnb.numel() + err = torch.abs(out_bnb - out_torch).float().mean().item() + if n > 0: + assert err < 0.115 + + #assert err < 0.20 + if any(req_grad): + out_bnb.data.copy_(out_torch) + torch.cuda.synchronize() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb.backward() + gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None + if has_bias: + gradBias1 = bias.grad + bias.grad = None + + loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean() + loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None + if has_bias: + gradBias2 = bias.grad + bias.grad = None + + if req_grad[0]: + torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1) + + if req_grad[2]: + torch.testing.assert_close(gradBias1, gradBias2) + + funcs = [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)] str_funcs = ["matmul_fp8_mixed", 'matmul_fp8_global'] req_grad = list(product([True, False], repeat=3)) @@ -475,6 +567,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype) target = torch.randn(size=(dim2, dim4), 
device="cuda", requires_grad=req_grad[1], dtype=dtype) + torch.nn.init.xavier_uniform_(B) fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device) @@ -492,7 +585,8 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): n = out_bnb.numel() err = torch.abs(out_bnb - out_torch).float().mean().item() if n > 0: - assert err < 0.20 + assert err < 0.115 + #assert err < 0.20 if any(req_grad): out_bnb.data.copy_(out_torch) torch.cuda.synchronize() @@ -511,7 +605,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_allclose( gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[1]: n = gradB1.numel() @@ -528,7 +622,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): assert (idx == 0).sum().item() <= n * 0.02 grad_err = (gradB1-gradB2).abs().mean() assert grad_err.item() < 0.003 - torch.testing.assert_allclose( + torch.testing.assert_close( gradB1, gradB2, atol=0.18, rtol=0.3 ) diff --git a/tests/test_functional.py b/tests/test_functional.py index 402977a4b..820a78d24 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -20,12 +20,15 @@ k = 20 -def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0): +def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True): idx = torch.isclose(a, b, rtol, atol) sumval = (idx == 0).sum().item() if sumval > count: - print(f"Too many values not close: assert {sumval} < {count}") - torch.testing.assert_allclose(a, b, rtol, atol) + if throw: + print(f"Too many values not close: assert {sumval} < {count}") + torch.testing.assert_close(a, b, rtol, atol) + + return sumval class FFN(torch.nn.Module): @@ -100,7 +103,7 @@ def test_estimate_quantiles(dtype): code = F.estimate_quantiles(A) percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device) - torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2) + torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2) A = torch.randn(1024, 1024, device="cuda") A = A.to(dtype) @@ -126,7 +129,7 @@ def test_quantile_quantization(): C = F.quantize_no_absmax(A1, code) A2 = F.dequantize_no_absmax(C, code) diff = torch.abs(A1 - A2).mean().item() - torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0) + torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0) assert diff < 0.001 @@ -151,67 +154,49 @@ def test_dynamic_quantization(device): C, S = F.quantize(A1) A2 = F.dequantize(C, S) diff = torch.abs(A1 - A2).mean().item() - torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) + torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) assert diff < 0.004 -def test_dynamic_blockwise_quantization(device): - for blocksize in [4096, 2048, 1024, 512]: - diffs = [] - reldiffs = [] - for i in range(100): - A1 = torch.randn(1024, 1024, device=device) - C, S = F.quantize_blockwise(A1, blocksize=blocksize) - A2 = F.dequantize_blockwise(C, S, blocksize=blocksize) - diff = torch.abs(A1 - A2) - reldiff = diff / torch.abs(A1 + 1e-8) - diffs.append(diff.mean().item()) - reldiffs.append(reldiff.mean().item()) - abserr = sum(diffs)/len(diffs) - relerr = sum(reldiffs)/len(reldiffs) - assert abserr < 0.011 - assert relerr < 0.018 - #print('randn', blocksize, sum(diffs)/len(diffs)) - #print('randn', blocksize, sum(reldiffs)/len(reldiffs)) - - diffs = [] - for i in range(100): - A1 = torch.rand(1024, 1024, device=device) - C, S = F.quantize_blockwise(A1, blocksize=blocksize) 
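# The replacement test below parameterizes the blockwise round trip over
# `blocksize` and `nested` (which additionally quantizes the per-block absmax
# statistics). A minimal stand-alone sketch of that round trip, assuming a CUDA
# device and only the quantize_blockwise/dequantize_blockwise calls that the new
# test itself uses:

import torch
import bitsandbytes.functional as F

A1 = torch.randn(1024, 1024, device="cuda")
C, S = F.quantize_blockwise(A1, blocksize=512, nested=True)  # 8-bit codes plus quantized absmax stats
A2 = F.dequantize_blockwise(C, S)                            # the quant state S carries blocksize/nested info
err = (A1 - A2).abs().mean().item()                          # the new test keeps the 100-run average of this below 0.011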
- A2 = F.dequantize_blockwise(C, S, blocksize=blocksize) - diff = torch.abs(A1 - A2) - reldiff = diff / torch.abs(A1 + 1e-8) - diffs.append(diff.mean().item()) - reldiffs.append(reldiff.mean().item()) - #torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) - abserr = sum(diffs)/len(diffs) - relerr = sum(reldiffs)/len(reldiffs) - assert abserr < 0.0035 - assert relerr < 0.015 - #print('rand', blocksize, sum(diffs)/len(diffs)) - #print('rand', blocksize, sum(reldiffs)/len(reldiffs)) - - @skip_if_no_cuda() +@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"]) @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) -@pytest.mark.skip("Stochastic has some bugs, but will be deprecated soon anyways.") -def test_dynamic_blockwise_stochastic_quantization(blocksize): +def test_dynamic_blockwise_quantization(nested, blocksize): + #print('') diffs = [] reldiffs = [] - rand = torch.rand(1024).cuda() - err = 0 for i in range(100): A1 = torch.randn(1024, 1024, device="cuda") - C1, S1 = F.quantize_blockwise(A1, rand=rand, blocksize=blocksize) - C2, S2 = F.quantize_blockwise(A1, blocksize=blocksize) - A2 = F.dequantize_blockwise(C1, S1, blocksize=blocksize) - err += (A1-A2).abs().mean().item()/100 - # a maximunm distance of quantized values of 1 - torch.testing.assert_allclose(C1, C2, atol=1, rtol=0) - fraction_smaller = (C1 < C2).float().sum() / C1.numel() - fraction_larger = (C1 > C2).float().sum() / C1.numel() - torch.testing.assert_allclose(fraction_larger, fraction_smaller, atol=0.01, rtol=0) - assert err < 0.019 + C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + abserr = sum(diffs)/len(diffs) + relerr = sum(reldiffs)/len(reldiffs) + assert abserr < 0.011 + assert relerr < 0.018 + #print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs)) + #print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs)) + + diffs = [] + for i in range(100): + A1 = torch.rand(1024, 1024, device="cuda") + C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + #torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) + abserr = sum(diffs)/len(diffs) + relerr = sum(reldiffs)/len(reldiffs) + assert abserr < 0.0035 + assert relerr < 0.015 + #print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) + #print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) + @skip_if_no_cuda() @@ -241,9 +226,9 @@ def test_percentile_clipping(gtype): vals, idx = torch.sort(gnorm_vec1) clip1 = vals[percentile] - torch.testing.assert_allclose(gnorm_vec1, torch.sqrt(gnorm_vec2)) - torch.testing.assert_allclose(clip1, clip2) - torch.testing.assert_allclose(gnorm1, gnorm2) + torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2)) + torch.testing.assert_close(clip1, clip2) + torch.testing.assert_close(gnorm1, gnorm2) def quant(x): @@ -326,7 +311,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): dim2 = dim2 - (dim2 % 32) errors = [] relerrors = [] - print("") + #print("") for i in range(5): if batched: A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda") @@ -338,7 +323,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): B = 
torch.normal(0, 0.5, size=(dim2, dim1), device="cuda") maxA, Ac = quant_methods[0](A, 1) maxB, Bc = quant_methods[1](B, 0) - torch.testing.assert_allclose( + torch.testing.assert_close( quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05 ) if batched: @@ -355,8 +340,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): relerr = err / torch.abs(out2) errors.append(err.mean().item()) relerrors.append(relerr.mean().item()) - print(mean(errors)) - print(mean(relerrors)) + #print(mean(errors)) + #print(mean(relerrors)) def test_stable_embedding(): @@ -410,7 +395,7 @@ def test_igemm(device, hidden_dim, batch_dim, transpose, seq_dim): out2 = torch.matmul(A.t().float(), B.t().float()) out = F.igemm(A.t(), B.t()) - torch.testing.assert_allclose(out.float(), out2) + torch.testing.assert_close(out.float(), out2) for i in range(k): shapeA = (batch_dim, seq_dim, hidden_dim) @@ -428,7 +413,7 @@ def test_igemm(device, hidden_dim, batch_dim, transpose, seq_dim): out2 = torch.matmul(A.float(), B.t().float()) out = F.igemm(A, B.t()) - torch.testing.assert_allclose(out.float(), out2) + torch.testing.assert_close(out.float(), out2) n = 3 @@ -460,7 +445,7 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim): ) out = F.igemm(A, B, out=iout) - torch.testing.assert_allclose(out.float(), out2) + torch.testing.assert_close(out.float(), out2) n = 2 @@ -587,7 +572,7 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose): A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float() ) out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1])) - torch.testing.assert_allclose(out.float(), out2.float()) + torch.testing.assert_close(out.float(), out2.float()) n = 1 @@ -647,9 +632,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans out, S = F.nvidia_transform(A, to_order=orderOut) if orderOut == "row": - torch.testing.assert_allclose(A.flatten(), out.flatten()) + torch.testing.assert_close(A.flatten(), out.flatten()) elif orderOut == "col": - torch.testing.assert_allclose(A.t().flatten(), out.flatten()) + torch.testing.assert_close(A.t().flatten(), out.flatten()) elif orderOut == "col32": if dims == 2: n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32))) @@ -682,14 +667,14 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans assert A.flatten()[i + j] == A[row, col] # assert A.flatten()[i+j] == out.flatten()[row2+col2] - # torch.testing.assert_allclose(A.flatten()[i+j], A[row, col]) - # torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) + # torch.testing.assert_close(A.flatten()[i+j], A[row, col]) + # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) if orderOut == "col32": out2, S = F.nvidia_transform( out, from_order=orderOut, to_order="row", state=S ) - torch.testing.assert_allclose(A, out2) + torch.testing.assert_close(A, out2) n = 1 @@ -734,7 +719,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): B2, SB = F.transform(B, "col_turing") C2, SC = F.igemmlt(A2, B2, SA, SB) C3, S = F.nvidia_transform(C2, "row", state=SC) - torch.testing.assert_allclose(C1, C3.float()) + torch.testing.assert_close(C1, C3.float()) # transpose B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to( @@ -745,7 +730,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): B2t, SBt = F.transform(B, "col_turing", transpose=True) C2, SC = F.igemmlt(A2, B2t, SA, SBt) C3, S = F.nvidia_transform(C2, "row", state=SC) - torch.testing.assert_allclose(C1, C3.float()) + 
torch.testing.assert_close(C1, C3.float()) dim1 = [32] @@ -792,7 +777,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): # print(C1.flatten()[:10]) # print(C2.flatten()[:10]) - # torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) + # torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) # transpose # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8) @@ -801,7 +786,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): # B2t, SBt = F.transform2(B, 'col_turing', transpose=True) # C2, SC = F.igemmlt(A2, B2t, SA, SBt) # C3, S = F.transform(C2, 'row', state=SC) - # torch.testing.assert_allclose(C1, C3.float()) + # torch.testing.assert_close(C1, C3.float()) batch_size = 2 @@ -1022,7 +1007,7 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): #assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}" C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias) - #torch.testing.assert_allclose(C5, C4, atol=0.015, rtol=0.1) + #torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1) n = C5.numel() assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01*n)) @@ -1073,16 +1058,16 @@ def test_colrow_absmax(dim1, dim2, dims): ) nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0) - torch.testing.assert_allclose(col_stats1_trunc, col_stats2) - torch.testing.assert_allclose(row_stats1_trunc, row_stats2) - torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2) + torch.testing.assert_close(col_stats1_trunc, col_stats2) + torch.testing.assert_close(row_stats1_trunc, row_stats2) + torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2) row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax( A, threshold=0.0 ) - torch.testing.assert_allclose(col_stats1, col_stats2) - torch.testing.assert_allclose(row_stats1, row_stats2) + torch.testing.assert_close(col_stats1, col_stats2) + torch.testing.assert_close(row_stats1, row_stats2) assert nnz_block_ptr2 is None @@ -1107,8 +1092,8 @@ def test_double_quant(dim1, dim2): CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) # max difference is 1 due to rounding differences - torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0) - torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0) + torch.testing.assert_close(CA, out_row1, atol=1, rtol=0) + torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0) n = CAt.numel() num_not_close_rows = ( @@ -1131,8 +1116,8 @@ def test_double_quant(dim1, dim2): ) assert False - torch.testing.assert_allclose(Srow.flatten(), statsA) - torch.testing.assert_allclose(Scol.flatten(), statsAt) + torch.testing.assert_close(Srow.flatten().float(), statsA) + torch.testing.assert_close(Scol.flatten().float(), statsAt) n = 4 @@ -1158,10 +1143,10 @@ def test_integrated_igemmlt(dim1, dim4, inner): A1, maxA = F.vectorwise_quant(A, dim=1) B1, maxB = F.vectorwise_quant(B, dim=1) - torch.testing.assert_allclose(maxA.flatten(), stats1a) - torch.testing.assert_allclose(maxB.flatten(), stats2a) - torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1) - torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1) + torch.testing.assert_close(maxA.flatten().float(), stats1a) + torch.testing.assert_close(maxB.flatten().float(), stats2a) + torch.testing.assert_close(C1a, A1, rtol=0, atol=1) + torch.testing.assert_close(C2a, B1, rtol=0, atol=1) A2, SA = F.nvidia_transform(C1a, "col32") B2, SB = F.nvidia_transform(C2a, "col_turing") @@ -1366,7 +1351,7 @@ def test_transform(dim1, 
dim2, dim3, dims, dtype, orderA, orderOut, transpose): # print(out1) # print(out2) - torch.testing.assert_allclose(out1, out2) + torch.testing.assert_close(out1, out2) n = 2 @@ -1430,11 +1415,11 @@ def test_coo_double_quant(dim1, dim2): A2[ coo_tensor.rowidx.long(), coo_tensor.colidx.long() ] = coo_tensor.values - torch.testing.assert_allclose(A1, A2) + torch.testing.assert_close(A1, A2) A1 = A * (idx == 0) A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() - torch.testing.assert_allclose( + torch.testing.assert_close( A * (idx == 0), A2, rtol=0.05, atol=1.5e-2 ) @@ -1647,7 +1632,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): idx_col = torch.randint(0, A2.shape[-1], size=(15,)) - # torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001) + # torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001) # Bt = torch.randn(dim2*4, dim2, device='cuda').half() # torch.cuda.synchronize() @@ -1679,9 +1664,9 @@ def test_coo2csr(): counts = csrA.rowptr[1:] - csrA.rowptr[:-1] assert counts.numel() == A.shape[0] - torch.testing.assert_allclose(counts, (A2 != 0).sum(1)) + torch.testing.assert_close(counts.long(), (A2 != 0).sum(1)) idx = A2 != 0 - torch.testing.assert_allclose(A2[idx], csrA.values) + torch.testing.assert_close(A2[idx], csrA.values) @skip_if_no_cuda() @@ -1700,10 +1685,10 @@ def test_coo2csc(): counts = cscA.colptr[1:] - cscA.colptr[:-1] assert counts.numel() == A.shape[1] - torch.testing.assert_allclose(counts, (A2 != 0).sum(0)) + torch.testing.assert_close(counts.long(), (A2 != 0).sum(0)) # torch uses row-major -> use transpose to transfer to col-major idx = A2.t() != 0 - torch.testing.assert_allclose(A2.t()[idx], cscA.values) + torch.testing.assert_close(A2.t()[idx], cscA.values) n = 2 @@ -1754,7 +1739,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): max_count, max_idx = torch.sort(counts, descending=True) print(torch.median(max_count.float())) - torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001) + torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001) p = 200 / (2048 * 12288 * 4) n = out1.numel() @@ -1824,39 +1809,45 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): batch_size = 1 seqdim = 1 values = [] -values.append((batch_size, seqdim, 768, 4 * 768)) -# values.append((batch_size, seqdim, 1024, 4*1024)) -# values.append((batch_size, seqdim, 1536, 4*1536)) -# values.append((batch_size, seqdim, 2048, 4*2048)) -# values.append((batch_size, seqdim, 2560, 4*2560)) -# values.append((batch_size, seqdim, 4096, 4*4096)) -# values.append((batch_size, seqdim, 5140, 4*5140)) +#values.append((batch_size, seqdim, 768, 4 * 768)) +#values.append((batch_size, seqdim, 1024, 4*1024)) +#values.append((batch_size, seqdim, 1536, 4*1536)) +#values.append((batch_size, seqdim, 2048, 4*2048)) +#values.append((batch_size, seqdim, 2560, 4*2560)) +values.append((batch_size, seqdim, 4096, 4*4096)) +values.append((batch_size, seqdim, 5120, 4*5120)) +values.append((batch_size, seqdim, 6656, 4*6656)) +values.append((batch_size, seqdim, 8192, 4*8192)) +#values.append((batch_size, seqdim, 5140, 4*5140)) #values.append((batch_size, seqdim, 12288, 4*12288)) -names = [ - "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values -] - +names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values] @skip_if_no_cuda() @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) def test_bench_matmul(batch, seq, model, hidden): - iters = 128 + iters = 80 formatB = F.get_special_format_str() A = torch.randn(batch, seq, 
model, device="cuda").half() B = torch.empty(hidden, model, dtype=torch.float16, device="cuda") torch.nn.init.xavier_uniform_(B) - linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() + B_fp4, state = F.quantize_fp4(B) + B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True) + + B_nf4, state_nf4= F.quantize_nf4(B) + + linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half() linear8bit.eval() outliers = torch.randint(0, model, size=(5,)).cuda() A[:, :, outliers] = 8.0 - linearMixedBit = ( - bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() - ) - linearMixedBit.eval() + linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half()) + #linearMixedBit.eval() + + linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() + linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() # warmup for i in range(iters): @@ -1869,61 +1860,80 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(iters): torch.matmul(A, B.t()) torch.cuda.synchronize() - print( - f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" - ) + print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) torch.cuda.synchronize() t0 = time.time() for i in range(iters): - bnb.matmul(A, B) + bnb.matmul_4bit(A, B_fp4.t(), quant_state=state) torch.cuda.synchronize() - print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) torch.cuda.synchronize() t0 = time.time() for i in range(iters): - bnb.matmul(A, B, threshold=6.0) + bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c) torch.cuda.synchronize() - print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) - C32A, SA = F.transform(CA, "col32") - CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) - CxB, SB = F.transform(CB, to_order=formatB) torch.cuda.synchronize() t0 = time.time() for i in range(iters): - out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) - torch.cuda.synchronize() - print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - BA, statsB = F.vectorwise_quant(B, dim=1) - CxB, SB = F.nvidia_transform(CB, to_order=formatB) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - A2 = A.view(-1, A.shape[-1]).contiguous() - CA, statsA = F.vectorwise_quant(A2, dim=1) - C32A, SA = F.nvidia_transform(CA, "col32") - out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) - Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) - F.vectorwise_mm_dequant(Cout, statsA, statsB.t()) + bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4) torch.cuda.synchronize() + print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # bnb.matmul(A, B) + #torch.cuda.synchronize() + #print(f"CB -> CxB conversion (each iteration): 
[{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # bnb.matmul(A, B, threshold=6.0) + #torch.cuda.synchronize() + #print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) + #C32A, SA = F.transform(CA, "col32") + #CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) + #CxB, SB = F.transform(CB, to_order=formatB) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + #torch.cuda.synchronize() + #print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #BA, statsB = F.vectorwise_quant(B, dim=1) + #CxB, SB = F.nvidia_transform(CB, to_order=formatB) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # A2 = A.view(-1, A.shape[-1]).contiguous() + # CA, statsA = F.vectorwise_quant(A2, dim=1) + # C32A, SA = F.nvidia_transform(CA, "col32") + # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) + # F.vectorwise_mm_dequant(Cout, statsA, statsB.t()) + #torch.cuda.synchronize() #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") - CxB, SB = F.nvidia_transform(CB, to_order=formatB) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - A2 = A.view(-1, A.shape[-1]).contiguous() - CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear") - C32A, SA = F.nvidia_transform(CA, "col32") - out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) - Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) - out = Cout * statsB * statsA * (1.0 / (127 * 127)) - torch.cuda.synchronize() + #BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") + #CxB, SB = F.nvidia_transform(CB, to_order=formatB) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # A2 = A.view(-1, A.shape[-1]).contiguous() + # CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear") + # C32A, SA = F.nvidia_transform(CA, "col32") + # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) + # out = Cout * statsB * statsA * (1.0 / (127 * 127)) + #torch.cuda.synchronize() #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") linear8bit(A) @@ -1932,9 +1942,7 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(iters): linear8bit(A) torch.cuda.synchronize() - print( - f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" - ) + print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") linearMixedBit(A) torch.cuda.synchronize() @@ -1942,9 +1950,23 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(iters): linearMixedBit(A) torch.cuda.synchronize() - print( - f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" - ) + print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], 
[{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #linear8bit_train(A) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # linear8bit_train(A) + #torch.cuda.synchronize() + #print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #linear8bit_train_thresh(A) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # linear8bit_train(A) + #torch.cuda.synchronize() + #print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") @skip_if_no_cuda() def test_zeropoint(): @@ -2049,7 +2071,7 @@ def test_extract_outliers(): assert outliers2.shape[0] == shapeA[0] assert outliers2.shape[1] == idx.numel() - torch.testing.assert_allclose(outliers1, outliers2) + torch.testing.assert_close(outliers1, outliers2) CA, SA = F.transform(A, "col_ampere") @@ -2058,7 +2080,7 @@ def test_extract_outliers(): assert outliers2.shape[0] == shapeA[0] assert outliers2.shape[1] == idx.numel() - torch.testing.assert_allclose(outliers1, outliers2) + torch.testing.assert_close(outliers1, outliers2) @@ -2091,7 +2113,6 @@ def test_fp8_quant(): p_bits = 7-e_bits code = F.create_fp8_map(True, e_bits, p_bits).cuda() - print(e_bits, p_bits) abserr = [] relerr = [] for i in range(100): @@ -2150,7 +2171,6 @@ def test_few_bit_quant(): ebits = math.ceil(bits/2) pbits = bits-ebits-1 code = F.create_fp8_map(True, ebits, pbits, bits).cuda() - print(code) elif method == 'dynamic': code = F.create_dynamic_map(True, bits-0, bits).cuda() elif method == 'quantile': @@ -2191,7 +2211,7 @@ def test_few_bit_quant(): #assert err2.mean() <= err1 else: - torch.testing.assert_allclose(q1, q2) + torch.testing.assert_close(q1, q2) #print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) #assert False @@ -2235,8 +2255,302 @@ def test_bench_dequantization(): torch.cuda.synchronize() t0 = time.time() for i in range(100): - F.dequantize_blockwise(qa, SA, blocksize=2048) + qa, SA = F.quantize_blockwise(a) torch.cuda.synchronize() #print((time.time()-t0)/1e6) + +def test_fp4_quant(): + vals = list(product([0, 1], repeat=4)) + + code = {} + for bits in vals: + result = 0 + bias = 3 + sign, e1, e2, p1 = bits + idx = sign*8 + e1*4 + e2*2 + p1*1 + sign = -1.0 if sign else 1.0 + exp = e1*2 + e2*1 + if exp == 0: + # sub-normal + if p1 == 0: result = 0 + else: result = sign*0.0625 + else: + # normal + exp = 2**(-exp + bias + 1) + frac = 1.5 if p1 else 1.0 + result = sign*exp*frac + code[idx] = result + + A1 = torch.randn(1024, 1024, device='cuda').half() + qa, SA = F.quantize_fp4(A1, blocksize=64) + A2 = F.dequantize_fp4(qa, SA) + + err = (A1 - A2).abs().float() + relerr = (err/A1.abs().float()).mean() + idx = err > 1.0 + err = err.mean() + + + assert err.item() < 0.1 + assert relerr.item() < 0.28 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) +def test_4bit_compressed_stats(quant_type): + for blocksize in [128, 64]: + errs1 = [] + errs2 = [] + for i in range(10): + A1 = torch.randn(1024, 1024, device='cuda').half() + q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) + q3, SA3= F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) + A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type) + A3 = F.dequantize_4bit(q3, SA3, 
quant_type=quant_type) + + + err = (A1 - A2).abs().float() + relerr = (err/(A1.abs().float()+1e-15)).mean() + err = err.mean() + + errs1.append(err.item()) + + + assert err.item() < 0.11 + assert relerr.item() < 0.28 + + err = (A1 - A3).abs().float() + relerr = (err/(A1.abs().float()+1e-15)).mean() + err = err.mean() + + errs2.append(err.item()) + + assert err.item() < 0.11 + assert relerr.item() < 0.28 + + #print(sum(errs1)/len(errs1), blocksize, quant_type) + #print(sum(errs2)/len(errs2), blocksize, quant_type) + + + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) +def test_bench_4bit_dequant(quant_type): + blocksize = 256 + a = torch.rand(1024*12*4, 1024*12, device='cuda').half() + qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type) + + input_size = a.numel()/2 + output_size = a.numel()*2 + num_bytes = input_size+output_size + GB = num_bytes/1e9 + max_theoretical_s = GB/768 + #print(max_theoretical_s*1e6) + b = torch.randn(128, 1024*12, device='cuda').half() + + iters = 5 + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) + #b.copy_(a) + torch.cuda.synchronize() + #print((time.time()-t0)/iters*1e6) + + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # torch.matmul(b, a.t()) + #torch.cuda.synchronize() + #print((time.time()-t0)/iters*1e6) + + + +def test_normal_map_tree(): + code = F.create_normal_map() + values =code[:8].tolist() + code[-8:].tolist() + num_pivots = 1 + print(values) + while num_pivots <16: + idx = list(range(16//num_pivots//2, 16, 16//num_pivots)) + print(idx) + num_pivots *= 2 + pivots = [] + for i in idx: + pivots.append((values[i-1]+values[i])/2) + print(pivots) + + +#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) +@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) +def test_cutlass3_gemm(dtype): + debug = True + #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: + #for dim in [4096, 5120, 6656, 8192]: + for dim in [4096]: + #for dim in [128+1]: + errs = [] + relerrs = [] + max_err = 0 + max_relerr = 0 + for i in range(100): + A = torch.randn(1, dim, dtype=dtype, device='cuda') + B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim) + #B = torch.randn(1, dim, dtype=dtype, device='cuda')/math.sqrt(dim) + + #print('') + #print(A) + #print(B.t()) + #A[:, :-1] = 0 + #B[:, :-1] = 0 + + + C1 = torch.matmul(A, B.t()) + C2 = F.cutlass3_gemm(A, B.t()) + + # tensor cores are non-deterministic + # so we need to analyze errors around the mean + # to test our implementation + err = torch.abs(C1-C2) + mag = torch.abs(C1)+1e-8 + relerr = err/mag + max_err = max(err.max(), max_err) + max_relerr = max(relerr.max(), max_relerr) + err = err.mean().item() + relerr = relerr.mean().item() + + errs.append(err) + relerrs.append(relerr) + + #if not debug and err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: + # print('') + # print(i, err, relerr) + # print(A.flatten()[-6:]) + # print(B.flatten()[-6:]) + # out = A.flatten()[-6:]*B.flatten()[-6:] + # print(out) + # print(out[:-1].sum()) + # print('='*80) + # print(C1.flatten()[-6:]) + # print(C2.flatten()[-6:]) + # #assert False, 'ERROR' + + c = int(C1.numel()*0.0014*(dim/256))+1 + + c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=not debug) + #print(c/math.sqrt(dim)) + print('') + print(dim, 
sum(errs)/len(errs)/math.sqrt(dim)) + print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim)) + print(dim, (max_err.item(), max_relerr.item())) + +#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) +@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) +def test_gemm_4bit(dtype): + #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: + #for dim in [4096, 5120, 6656, 8192]: + #for dim in [32]: + for dim in [4096]: + errs = [] + relerrs = [] + max_err = 0 + max_relerr = 0 + for i in range(1): + #A = torch.rand(2, 4092, dtype=dtype, device='cuda') + #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') + #A = torch.rand(1, 4096, dtype=dtype, device='cuda') + #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') + A = torch.randn(1, dim+0, dtype=dtype, device='cuda') + B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim) + + #print('') + #print(A) + #print(B.t()) + #A[:, :-1] = 0 + #B[:, :-1] = 0 + + qB, state = F.quantize_nf4(B) + F.dequantize_nf4(qB, state) + + C3 = torch.matmul(A, B.t()) + C2 = F.cutlass3_gemm(A, qB.t(), state=state) + C1 = bnb.matmul_4bit(A, qB.t(), state) + C2 = F.cutlass3_gemm(A, qB.t(), state=state) + + print(C1.shape, C2.shape) + + # tensor cores are non-deterministic + # so we need to analyze errors around the mean + # to test our implementation + err = torch.abs(C1-C2) + mag = torch.abs(C1)+1e-8 + relerr = err/mag + max_err = max(err.max(), max_err) + max_relerr = max(relerr.max(), max_relerr) + err = err.mean().item() + relerr = relerr.mean().item() + + errs.append(err) + relerrs.append(relerr) + + if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: + print('') + print(i, err, relerr) + print(A.flatten()[-6:]) + print(B.flatten()[-6:]) + out = A.flatten()[-6:]*B.flatten()[-6:] + print(out) + print(out[:-1].sum()) + print('='*80) + print(C1.flatten()[-6:]) + print(C2.flatten()[-6:]) + #assert False, 'ERROR' + + c = int(C1.numel()*0.0014*(dim/256))+1 + + c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False) + #print(c/math.sqrt(dim)) + print('') + print(dim, sum(errs)/len(errs)/math.sqrt(dim)) + print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim)) + print(dim, (max_err.item(), max_relerr.item())) + +@pytest.mark.skip("Row scale has some bugs for ampere") +def test_managed(): + n = 32*10 + A = F.get_paged(n, n, dtype=torch.float32) + B = F.get_paged(n, n, dtype=torch.uint8) + B2 = F.get_paged(n, n, dtype=torch.float32) + assert A.is_paged + assert B.is_paged + assert A.page_deviceid==0 + assert B.page_deviceid==0 + F.fill(A, 17.0) + F.fill(B, 17) + F.fill(B2, 2) + assert (A==17).sum().item() == n*n + assert (B==17).sum().item() == n*n + C = A*B.float() + assert (C==289).sum().item() == n*n + F._mul(A, B2) + F._mul(A, B2) + F._mul(A, B2) + assert (A==17*(2**3)).sum().item() == n*n + # F.prefetch_tensor(A) + # F.prefetch_tensor(B) + + + # F.fill(B2, 17.0) + # F._mul(A, B2) + + # F.prefetch_tensor(A, to_cpu=True) + # F.prefetch_tensor(B, to_cpu=True) + # F.prefetch_tensor(B2, to_cpu=True) + # torch.cuda.synchronize() + + # assert (A==17).sum().item() == n*n + + # torch.testing.assert_close(A, torch.ones(A.shape)*289) diff --git a/tests/test_modules.py b/tests/test_modules.py index 9ef109dcb..94c49e564 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -47,7 +47,7 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10): sumval = (idx == 0).sum().item() if sumval > count: print(f"Too many values not close: assert {sumval} < {count}") - 
torch.testing.assert_allclose(a, b, rtol, atol) + torch.testing.assert_close(a, b, rtol, atol) class LinearFunction(torch.autograd.Function): @@ -334,19 +334,16 @@ def test_linear8bitlt_inference(threshold): @skip_if_no_cuda() -def test_linear8bitlt_accumulated_gradient(device): - l1 = torch.nn.Sequential( - *[bnb.nn.Linear8bitLt(32, 32).to(device).half() for i in range(2)] - ) - l2 = torch.nn.Sequential( - *[torch.nn.Linear(32, 32).to(device).half() for i in range(2)] - ) - l2[0].weight = torch.nn.Parameter(l1[0].weight.clone()) - l2[0].bias = torch.nn.Parameter(l1[0].bias.clone()) - l2[1].weight = torch.nn.Parameter(l1[1].weight.clone()) - l2[1].bias = torch.nn.Parameter(l1[1].bias.clone()) - opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001) - opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001) +def test_linear8bitlt_accumulated_gradient(): + l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]) + l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]) + l1[0].weight.data.copy_(l2[0].weight.data) + l1[1].weight.data.copy_(l2[1].weight.data) + l1[0].bias.data.copy_(l2[0].bias.data) + l1[1].bias.data.copy_(l2[1].bias.data) + + opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001) + opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001) acc_steps = 10 @@ -376,27 +373,17 @@ def test_linear8bitlt_accumulated_gradient(device): # we do this copy because otherwise we have small divergences over time that add up l1[0].weight.data.copy_(l2[0].weight.data) l1[1].weight.data.copy_(l2[1].weight.data) + l1[0].bias.data.copy_(l2[0].bias.data) + l1[1].bias.data.copy_(l2[1].bias.data) else: - torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad) - torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad) - - -threshold = [0.0, 2.0] -values = threshold -names = [f"threshold_{vals}" for vals in values] - + torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3) @skip_if_no_cuda() -@pytest.mark.parametrize("threshold", values, ids=names) +@pytest.mark.parametrize("threshold", [0.0, 2.0]) @pytest.mark.parametrize("memory_efficient_backward", [False]) def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): - l1 = ( - bnb.nn.Linear8bitLt( - 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward - ) - .cuda() - .half() - ) + l1 = (bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half()) assert l1.weight.dtype == torch.int8 l1.eval() @@ -452,13 +439,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 - mlp = ( - MLP8bit( - 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward - ) - .half() - .to("cuda") - ) + mlp = ( MLP8bit( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda")) for i in range(100): b1 = torch.randn(16, 8, 32, device="cuda").half() @@ -505,15 +486,17 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half() scale = grad_ref.abs().mean() - torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale) + 
torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale) idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1) assert (idx == 0).sum().item() <= b1.numel() * 0.005 + @skip_if_no_cuda() -def test_linear8bitlt_fp32_bias(): +@pytest.mark.parametrize("module", [lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False), bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4']) +def test_linear_kbit_fp32_bias(module): # casts model to fp16 -> int8 automatically - l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda() - assert l1.weight.dtype == torch.int8 + l1 = module(32, 64).cuda() + assert l1.weight.dtype in [torch.int8, torch.uint8] assert l1.bias.dtype == torch.float32 for i in range(100): @@ -523,8 +506,8 @@ def test_linear8bitlt_fp32_bias(): assert l1.bias.dtype == torch.float16 # casts model to fp16 -> int8 automatically - l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda() - assert l1.weight.dtype == torch.int8 + l1 = module(32, 64, bias=False).cuda() + assert l1.weight.dtype in [torch.int8, torch.uint8] assert l1.bias is None for i in range(100): @@ -532,6 +515,74 @@ def test_linear8bitlt_fp32_bias(): o1 = l1(b1) assert l1.bias is None +modules = [] +modules.append(bnb.nn.Linear8bitLt) +modules.append(bnb.nn.Linear4bit) +modules.append(bnb.nn.LinearFP4) +modules.append(bnb.nn.LinearNF4) +modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True)) +modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True)) +names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C'] +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +@pytest.mark.parametrize("module", modules, ids=names) +def test_kbit_backprop(module): + b = 17 + dim1 = 37 + dim2 = 83 + + ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)]) + ref[1].weight.requires_grad = False + torch.nn.init.kaiming_normal_(ref[0].weight) + torch.nn.init.kaiming_normal_(ref[1].weight) + kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)]) + kbit[0].weight.detach().copy_(ref[0].weight) + kbit[1].weight.detach().copy_(ref[1].weight) + kbit[0].bias.detach().copy_(ref[0].bias) + kbit[1].bias.detach().copy_(ref[1].bias) + ref = ref.half().cuda() + kbit = kbit.half().cuda() + + errs1 = [] + errs2 = [] + relerrs1 = [] + relerrs2 = [] + for i in range(100): + batch = torch.randn(b, dim1).half().cuda() + out1 = ref(batch) + out2 = kbit(batch) + out1.mean().backward() + out2.mean().backward() + + grad1 = ref[0].weight.grad + grad2 = kbit[0].weight.grad + bgrad1 = ref[0].bias.grad + bgrad2 = kbit[0].bias.grad + + err1 = (out1-out2).abs().float() + err2 = (grad1-grad2).abs().float() + relerr1 = (err1/(out1.abs().float()+1e-9)) + relerr2 = (err2/(grad1.abs().float()+1e-9)) + errs1.append(err1.mean().item()) + errs2.append(err2.mean().item()) + relerrs1.append(relerr1.mean().item()) + relerrs2.append(relerr2.mean().item()) + + if isinstance(module, bnb.nn.Linear8bitLt): + torch.testing.assert_close(grad1, grad2, atol=0.008, rtol=0.05) + torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05) + else: + torch.testing.assert_close(grad1, grad2, atol=0.015, rtol=0.05) + torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05) + ref.zero_grad() + kbit.zero_grad() + + assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0 + assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0 + print('out', 
sum(errs1)/len(errs1)) + print('grad', sum(errs2)/len(errs2)) + print('rel out', sum(relerrs1)/len(relerrs1)) + print('rel grad', sum(relerrs2)/len(relerrs2)) + @skip_if_no_cuda() def test_fp8linear(): diff --git a/tests/test_optim.py b/tests/test_optim.py index 33ed65a93..ed9754c6b 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -20,11 +20,11 @@ k = 20 def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0): - idx = torch.isclose(a, b, rtol, atol) + idx = torch.isclose(a, b, rtol=rtol, atol=atol) error_count = (idx == 0).sum().item() if error_count > max_error_count: print(f"Too many values not close: assert {error_count} < {max_error_count}") - torch.testing.assert_allclose(a, b, rtol, atol) + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) def get_temp_dir(): @@ -36,11 +36,8 @@ def get_temp_dir(): def rm_path(path): shutil.rmtree(path) - str2optimizers = {} str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam) -# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam) -# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam) str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion) str2optimizers["momentum_pytorch"] = ( None, @@ -48,28 +45,20 @@ def rm_path(path): bnb.optim.Adam, ) str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam) -# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) +str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW) +str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam) str2optimizers["lion"] = (Lion, bnb.optim.Lion) +str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion) str2optimizers["momentum"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False), ) -str2optimizers["lars"] = ( - lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), - lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9), -) str2optimizers["rmsprop"] = ( lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False), ) -str2optimizers["adam8bit"] = ( - torch.optim.Adam, - lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False), -) -str2optimizers["lion8bit"] = ( - Lion, - lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False), -) +str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False)) +str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False)) str2optimizers["momentum8bit"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False), @@ -78,19 +67,12 @@ def rm_path(path): lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False), ) -str2optimizers["lars8bit"] = ( - lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), - lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9), -) -str2optimizers["adam8bit_blockwise"] = ( - torch.optim.Adam, - lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True), -) -str2optimizers["lion8bit_blockwise"] = ( - Lion, - lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True), -) +str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) +str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True)) +str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: 
bnb.optim.PagedAdam8bit(pxx, block_wise=True)) +str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True)) +str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True)) str2optimizers["momentum8bit_blockwise"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True), @@ -102,54 +84,36 @@ def rm_path(path): str2statenames = {} str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] +str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] +str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["lion"] = [("exp_avg", "state1")] +str2statenames["paged_lion"] = [("exp_avg", "state1")] str2statenames["momentum"] = [("momentum_buffer", "state1")] -str2statenames["lars"] = [("momentum_buffer", "state1")] str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["rmsprop"] = [("square_avg", "state1")] -str2statenames["adam8bit"] = [ - ("exp_avg", "state1", "qmap1", "max1"), - ("exp_avg_sq", "state2", "qmap2", "max2"), -] -str2statenames["lion8bit"] = [ - ("exp_avg", "state1", "qmap1", "max1") -] -str2statenames["lamb8bit"] = [ - ("exp_avg", "state1", "qmap1", "max1"), - ("exp_avg_sq", "state2", "qmap2", "max2"), -] -str2statenames["adam8bit_blockwise"] = [ - ("exp_avg", "state1", "qmap1", "absmax1"), - ("exp_avg_sq", "state2", "qmap2", "absmax2"), -] -str2statenames["lion8bit_blockwise"] = [ - ("exp_avg", "state1", "qmap1", "absmax1") -] -str2statenames["momentum8bit"] = [ - ("momentum_buffer", "state1", "qmap1", "max1") -] -str2statenames["momentum8bit_blockwise"] = [ - ("momentum_buffer", "state1", "qmap1", "absmax1") -] -str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] +str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] +str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] +str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] +str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] +str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] +str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] +str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")] +str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")] str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")] -str2statenames["rmsprop8bit_blockwise"] = [ - ("square_avg", "state1", "qmap1", "absmax1") -] +str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")] +str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] +str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] dim1 = [1024] dim2 = [32, 1024, 4097, 1] -gtype = [torch.float32, torch.float16] -optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lion"] +gtype = [torch.float32, torch.float16, torch.bfloat16] +optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion'] values = list(product(dim1, dim2, gtype, 
optimizer_names)) -names = [ - "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values -] - - +names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values] @skip_if_no_cuda() @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) def test_optimizer32bit(dim1, dim2, gtype, optim_name): + if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: pytest.skip() if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 @@ -161,6 +125,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): if gtype == torch.float32: atol, rtol = 1e-6, 1e-5 + elif gtype == torch.bfloat16: + atol, rtol = 1e-3, 1e-2 else: atol, rtol = 1e-4, 1e-3 @@ -174,9 +140,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): for name1, name2 in str2statenames[optim_name]: - torch.testing.assert_allclose( + torch.testing.assert_close( torch_optimizer.state[p1][name1], - bnb_optimizer.state[p2][name2], + bnb_optimizer.state[p2][name2].cuda(), atol=atol, rtol=rtol, ) @@ -203,14 +169,14 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): atol=atol, rtol=rtol, max_error_count=10) - if gtype == torch.float16: + if gtype != torch.float32: # the adam buffers should also be close because they are 32-bit # but the paramters can diverge because they are 16-bit # the difference grow larger and larger with each update # --> copy the state to keep weights close - p1.data = p1.data.half().float() + p1.data = p1.data.to(p2.dtype).float() p2.copy_(p1.data) - torch.testing.assert_allclose(p1.half(), p2) + torch.testing.assert_close(p1.to(p2.dtype), p2) if optim_name in ["lars", "lamb"]: assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0 @@ -271,7 +237,7 @@ def test_global_config(dim1, dim2, gtype): dim1 = [1024] dim2 = [32, 1024, 4097] -gtype = [torch.float32, torch.float16] +gtype = [torch.float32, torch.float16, torch.bfloat16] optimizer_names = [ "adam8bit", "lion8bit", @@ -279,7 +245,6 @@ def test_global_config(dim1, dim2, gtype): "rmsprop8bit", "adam8bit_blockwise", "lion8bit_blockwise", - "lars8bit", "momentum8bit_blockwise", "rmsprop8bit_blockwise", ] @@ -292,6 +257,7 @@ def test_global_config(dim1, dim2, gtype): @skip_if_no_cuda() @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) def test_optimizer8bit(dim1, dim2, gtype, optim_name): + if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip() if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 @@ -305,7 +271,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): if gtype == torch.float32: atol, rtol = 3e-3, 1e-3 patol, prtol = 1e-5, 1e-3 - + elif gtype == torch.bfloat16: + atol, rtol = 3e-3, 1e-3 + patol, prtol = 1e-4, 1e-2 else: atol, rtol = 3e-3, 1e-3 patol, prtol = 1e-5, 1e-3 @@ -313,7 +281,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): errors = [] relerrors = [] - for i in range(50): + for i in range(100): g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 p1.grad = g.clone().float() p2.grad = g.clone() @@ -347,13 +315,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ) == 0 ) - assert num_not_close.sum().item() < 20 + #assert num_not_close.sum().item() < 20 dequant_states.append(s1.clone()) err = torch.abs(p1 - p2) relerr = err / (torch.abs(p1)+1e-9) - assert err.mean() < 0.0001 - assert relerr.mean() < 0.001 + if g.dtype == torch.bfloat16: + assert err.mean() < 0.00015 + assert relerr.mean() < 
0.0016 + else: + assert err.mean() < 0.00012 + assert relerr.mean() < 0.0012 errors.append(err.mean().item()) relerrors.append(relerr.mean().item()) @@ -373,12 +345,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): bnb_optimizer = str2optimizers[optim_name][1]([p2]) bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) rm_path(path) - torch.testing.assert_allclose( - raws1cpy, bnb_optimizer.state[p2][name2] - ) - torch.testing.assert_allclose( - qmap1, bnb_optimizer.state[p2][qmap] - ) + torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2]) + torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap]) if "blockwise" in optim_name: s1 = F.dequantize_blockwise( @@ -393,17 +361,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], ) - torch.testing.assert_allclose(s1cpy, s1) - - num_not_close = ( - torch.isclose( - torch_optimizer.state[p1][name1], - s1, - atol=atol, - rtol=rtol, - ) - == 0 - ) + torch.testing.assert_close(s1cpy, s1) + + num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0) assert num_not_close.sum().item() < 20 # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 5 errors for Lion @@ -413,10 +373,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): # together so we can test against the Adam error p1.data = p1.data.to(gtype).float() p2.copy_(p1.data) - torch.testing.assert_allclose(p1.to(gtype), p2) - for (name1, name2, qmap, max_val), s in zip( - str2statenames[optim_name], dequant_states - ): + torch.testing.assert_close(p1.to(gtype), p2) + for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): torch_optimizer.state[p1][name1].copy_(s.data) # print(sum(errors)/len(errors)) @@ -478,28 +436,28 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state if optim_bits == 32: - torch.testing.assert_allclose(p1, p2) - torch.testing.assert_allclose( + torch.testing.assert_close(p1, p2) + torch.testing.assert_close( adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=5e-5, rtol=1e-4, ) - torch.testing.assert_allclose( + torch.testing.assert_close( adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=5e-5, rtol=1e-4, ) elif optim_bits == 8: - torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3) - torch.testing.assert_allclose( + torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-3) + torch.testing.assert_close( adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=2, rtol=1e-3, ) - torch.testing.assert_allclose( + torch.testing.assert_close( adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=2, @@ -531,7 +489,7 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): # optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch'] # optimizer_names = ['lamb_apex', 'lamb8bit'] # optimizer_names = ['lars_apex', 'lars8bit'] -optimizer_names = ["adam8bit_blockwise"] +optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise', 'paged_lion8bit_blockwise'] values = list(product(dim1, dim2, gtype, optimizer_names)) names = [ "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values @@ -563,3 +521,47 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): params = (k - k // 5) * dim1 * dim2 print(optim_name, gtype, s / params) # 
assert s < 3.9 + +dim1 = [2*1024] +gtype = [torch.float16] +#mode = ['torch', 'bnb'] +mode = ['bnb'] +optimizer_names = ['paged_adamw'] +#optimizer_names = ['paged_adamw8bit_blockwise'] +values = list(product(dim1, gtype, optimizer_names, mode)) +names = ['dim1_{0}_gtype_{1}_optim_{2}_mode_{3}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, gtype, optim_name, mode", values, ids=names) +def test_stream_optimizer_bench(dim1, gtype, optim_name, mode): + layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)])) + layers1 = layers1.to(gtype) + layers1 = layers1.cuda() + + large_tensor = None + if mode == 'torch': + optim = str2optimizers[optim_name][0](layers1.parameters()) + else: + optim = str2optimizers[optim_name][1](layers1.parameters()) + # ~18 GB dummy allocation (4.5e9 float32 values) to create GPU memory pressure + large_tensor = torch.empty((int(4.5e9),), device='cuda') + + torch.cuda.synchronize() + time.sleep(5) + + num_batches = 5 + batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype) + lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda() + + for i in range(num_batches): + print(i) + b = batches[i] + if i == 2: + torch.cuda.synchronize() + t0 = time.time() + + out1 = layers1(b) + + loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean() + loss1.backward() + optim.step() + torch.cuda.synchronize() + print(mode, time.time() - t0)
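# Minimal sketch of the 4-bit quantize -> matmul round trip exercised by the new
# tests above (test_bench_matmul / test_gemm_4bit). This assumes a CUDA device and
# a bitsandbytes build that provides F.quantize_nf4 / F.dequantize_nf4 and
# bnb.matmul_4bit as added in this diff; the shapes and the printed error are
# illustrative only, not a prescribed check.
import math
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

dim = 4096
A = torch.randn(1, dim, dtype=torch.float16, device='cuda')
B = torch.randn(4 * dim, dim, dtype=torch.float16, device='cuda') / math.sqrt(dim)

# Blockwise NF4 quantization of the weight matrix; `state` carries the per-block
# absmax statistics that dequantization and matmul_4bit need.
qB, state = F.quantize_nf4(B)

# 4-bit matmul against the fp16 reference; the error is expected to be small but
# nonzero, which is why the tests above compare mean and relative errors.
out_4bit = bnb.matmul_4bit(A, qB.t(), quant_state=state)
out_ref = torch.matmul(A, B.t())
print((out_4bit - out_ref).abs().mean().item())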
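# A similar sketch for the paged optimizers covered by the new test_optim.py
# entries (paged_adam / paged_adamw / paged_lion). It assumes bnb.optim.PagedAdamW
# is available as registered in str2optimizers above and that default optimizer
# arguments are acceptable; the layer sizes are arbitrary and only show the
# drop-in usage that test_stream_optimizer_bench benchmarks.
import torch
import bitsandbytes as bnb

model = torch.nn.Sequential(*[torch.nn.Linear(2048, 2048) for _ in range(4)]).cuda()

# Paged optimizers keep their state in pageable (unified) memory, so the state can
# be evicted to CPU under GPU memory pressure, which the benchmark provokes with a
# large dummy allocation.
optim = bnb.optim.PagedAdamW(model.parameters())

x = torch.randn(16, 2048, device='cuda')
loss = model(x).pow(2).mean()
loss.backward()
optim.step()
optim.zero_grad()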