From 3ac5840c03c829f8a77f740a3ce1887df472d1fa Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sat, 4 Feb 2023 14:52:04 -0800 Subject: [PATCH 01/63] Added fp4 quant/dequant and dequant optimizations. --- bitsandbytes/cextension.py | 2 +- bitsandbytes/cuda_setup/main.py | 4 + bitsandbytes/functional.py | 123 +++++++++++++- csrc/kernels.cu | 288 ++++++++++++++++++++++++-------- csrc/kernels.cuh | 4 +- csrc/ops.cu | 54 +++--- csrc/ops.cuh | 4 +- csrc/pythonInterface.c | 23 ++- tests/test_functional.py | 85 +++++++++- 9 files changed, 471 insertions(+), 116 deletions(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 7a62c1e12..e2ca978eb 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -9,7 +9,7 @@ setup = CUDASetup.get_instance() -if setup.initialized != True: +if not setup.initialized: setup.run_cuda_setup() if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': setup.print_log_stack() diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py index cd9573fe2..6bebd9318 100644 --- a/bitsandbytes/cuda_setup/main.py +++ b/bitsandbytes/cuda_setup/main.py @@ -35,6 +35,9 @@ def __init__(self): raise RuntimeError("Call get_instance() instead") def generate_instructions(self): + if getattr(self, 'error', False): return + print(self.error) + self.error = True if self.cuda is None: self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected.') self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.') @@ -84,6 +87,7 @@ def initialize(self): self.has_printed = False self.lib = None self.initialized = False + self.error = False def run_cuda_setup(self): self.initialized = True diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 95a7c4f20..da9e7430b 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -168,7 +168,8 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) values = [] lst = list(itertools.product([0, 1], repeat=precision_bits)) #for ev in evalues: - bias = 2**(exponent_bits-1)-1 + bias = 2**(exponent_bits-1)+1 + print(bias) for evalue in range(2**(exponent_bits)): for bit_pattern in lst: value = (1 if evalue != 0 else 0) @@ -176,10 +177,12 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) value += pval*(2**-(i+1)) if evalue == 0: # subnormals - value = value*2**-(bias-1) + value = value*2**-(bias) else: # normals - value = value*2**-(evalue-bias-2) + print(value, 1) + value = value*2**-(evalue-bias-1) + print(value, 2) values.append(value) if signed: values.append(-value) @@ -193,7 +196,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) values.append(0) values.sort() code = torch.Tensor(values) - code /= code.max() + #code /= code.max() return code @@ -587,7 +590,7 @@ def dequantize_blockwise( code = code.to(A.device) if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: raise ValueError(f"The blockwise of {blocksize} is not supported. 
Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") - is_on_gpu([A, out]) + is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) elif out.dtype == torch.float16: @@ -602,6 +605,116 @@ def dequantize_blockwise( return out +def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64) -> Tensor: + """ + Quantize tensor A in blocks of FP4 values. + + Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. + + Parameters + ---------- + A : torch.Tensor + The input tensor. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + The output tensor (8-bit). + blocksize : int + The blocksize used in quantization. + + Returns + ------- + torch.Tensor: + The 8-bit tensor with packed 4-bit values. + tuple(torch.Tensor, torch.Size, torch.dtype): + The quantization state to undo the quantization. + """ + if A.device.type != 'cuda': + raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}') + + n = A.numel() + input_shape = A.shape + + if absmax is None: + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + absmax = torch.zeros((blocks,), device=A.device) + + state = (absmax, input_shape, A.dtype) + + if out is None: + out = torch.zeros(((n+1)//2,), dtype=torch.uint8, device=A.device) + + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + + prev_device = pre_call(A.device) + is_on_gpu([A, out, absmax]) + + if A.dtype == torch.float32: + lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + elif A.dtype == torch.float16: + lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + post_call(A.device) + + return out, state + + +def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: + """ + Dequantizes FP4 blockwise quantized values. + + Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. + + Parameters + ---------- + A : torch.Tensor + The input 8-bit tensor (packed 4-bit values). + quant_state : tuple(torch.Tensor, torch.Size, torch.dtype) + Tuple of absmax values, original tensor shape and original dtype. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + Dequantized output tensor. + + + Returns + ------- + torch.Tensor: + Dequantized tensor. + """ + if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: + raise ValueError(f"The blockwise of {blocksize} is not supported. 
Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") + + if quant_state is None: + assert absmax is not None and out is not None + shape = out.shape + dtype = out.dtype + else: + absmax, shape, dtype = quant_state + + + if out is None: + out = torch.empty(shape, dtype=dtype, device=A.device) + + n = out.numel() + + device = pre_call(A.device) + is_on_gpu([A, absmax, out]) + if out.dtype == torch.float32: + lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + elif out.dtype == torch.float16: + lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + post_call(A.device) + + return out + + + + def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: if code is None: if "dynamic" not in name2qmap: diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 08b9b44f5..a1eec6880 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -43,6 +43,79 @@ __device__ float atomicMin(float* address, float val) { return __int_as_float(old); } +__device__ float dDequantizeFP4(unsigned char val, float absmax) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0110) == 0) + { + // subnormal + if((val & 0b0001) == 0) + return 0.0f; + else + return sign*0.0625f*absmax; + } + else + { + // normal + float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f); + float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f; + + return sign*exponent*fraction*absmax; + } +} + +__device__ unsigned char dQuantizeFP4(float x) +{ + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + + int sign = x < 0 ? 0b1000 : 0b0000; + x = fabsf(x); + if(x > 3.5f) + { + if( x > 7.0f) + { + if( x > 10.0f) + return 0b0011+sign; + else + return 0b0010+sign; + } + else + { + if(x > 5.0f) + return 0b101+sign; + else + return 0b100+sign; + } + } + else + { + if(x > 1.03125f) + { + if(x > 2.5f) + return 0b0111+sign; + else + return 0b0110+sign; + } + else + { + if(x > 0.03125f) + return 0b0001+sign; + else + return 0b0000+sign; + } + } +} + template __device__ unsigned char dQuantize(float* smem_code, const float rand, float x) { @@ -427,7 +500,7 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c } } -template +template //__launch_bounds__(TH, 4) __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n) { @@ -437,13 +510,13 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float T vals[NUM_PER_TH]; float rand_vals[NUM_PER_TH]; - unsigned char qvals[NUM_PER_TH]; + unsigned char qvals[FP4 ? 
NUM_PER_TH/2 : NUM_PER_TH]; //float local_abs_max = -FLT_MAX; float local_abs_max = 0.0f; int local_rand_idx = 0; typedef cub::BlockLoad LoadT; - typedef cub::BlockStore StoreChar; + typedef cub::BlockStore StoreChar; typedef cub::BlockReduce BlockReduce; typedef cub::BlockLoad LoadFloat; @@ -454,8 +527,9 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float __shared__ float smem_code[256]; __shared__ float smem_absmax_value[1]; - for(int i = threadIdx.x; i < 256; i+=blockDim.x) - smem_code[i] = code[i]; + if(!FP4) + for(int i = threadIdx.x; i < 256; i+=blockDim.x) + smem_code[i] = code[i]; for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) { @@ -495,61 +569,138 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); } - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) + if(FP4) { - if(!STOCHASTIC) - qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); - else - qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + unsigned char packed_fp4 = 0; + packed_fp4 |= dQuantizeFP4(((float)vals[2*j])*local_abs_max*12.0f) << 4; + packed_fp4 |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max*12.0f); + qvals[j] = packed_fp4; + } + } + else + { + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + if(!STOCHASTIC) + qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); + else + qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); + } } __syncthreads(); - StoreChar(storec).Store(&(out[i]), qvals, valid_items); + StoreChar(storec).Store(&(out[FP4 ? i/2 : i]), qvals, FP4 ? (valid_items+1)/2 : valid_items); } } -template -__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n) +template +__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n) { - const int n_full = gridDim.x * BLOCK_SIZE; - int valid_items = 0; - const int base_idx = (blockIdx.x * BLOCK_SIZE); + const int n_load = (gridDim.x * TILE_SIZE); + int valid_items_load = 0; + int valid_items_store = 0; + const int base_idx = (blockIdx.x * TILE_SIZE); - T vals[NUM_PER_TH]; + T vals[NUM_PER_TH*(FP4 ? 2 : 1)]; unsigned char qvals[NUM_PER_TH]; float local_abs_max = -FLT_MAX; typedef cub::BlockLoad LoadChar; - typedef cub::BlockStore StoreT; + typedef cub::BlockStore StoreT; __shared__ typename LoadChar::TempStorage loadchar; __shared__ typename StoreT::TempStorage storet; - //__shared__ float smem_code[256]; - //float local_code[16]; - - //if(threadIdx.x < 256) - //smem_code[threadIdx.x] = code[threadIdx.x]; - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) { - valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; - local_abs_max = absmax[i/BLOCK_SIZE]; + if(FP4) + { + valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i; + valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2; + } + else + { + valid_items_load = n - i > TILE_SIZE ? TILE_SIZE : n - i; + valid_items_store = n - i > TILE_SIZE ? 
TILE_SIZE : n - i; + } + local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]); __syncthreads(); - LoadChar(loadchar).Load(&(A[i]), qvals, valid_items, 128); + LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); - // load code through read-only cache via __ldg - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) - vals[j] = __ldg(&code[qvals[j]])*local_abs_max; + + if(FP4) + { + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeFP4(qvals[j] >> 4, local_abs_max*0.083333f); + vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*0.083333); + } + } + else + { + // load code through read-only cache via __ldg + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + vals[j] = __ldg(&code[qvals[j]])*local_abs_max; + } __syncthreads(); - StoreT(storet).Store(&(out[i]), vals, valid_items); + StoreT(storet).Store(&(out[FP4 ? i*2 : i]), vals, valid_items_store); } } +//template +//__global__ void kDequantizeBlockwiseFP4(unsigned char * A, float * absmax, T *out, const int n_store) +//{ +// +// const int n_load = n_store/2; +// const int base_idx = (blockIdx.x * TILE_SIZE); +// +// T vals[NUM_PER_TH*2]; +// unsigned char qvals[NUM_PER_TH]; +// +// int valid_items = (base_idx + TILE_SIZE) > n_load ? ((base_idx+TILE_SIZE) - n_load) : TILE_SIZE; +// int idx = base_idx + (threadIdx.x*NUM_PER_TH); +// +// float local_abs_max = __ldg(&absmax[idx/BLOCK_SIZE]); +// +// if(valid_items == TILE_SIZE) +// { +// // we do 64 byte loads so we can 128 byte stores +// reinterpret_cast(qvals)[0] = reinterpret_cast(A)[idx/8]; +// } +// else +// { +// #pragma unroll +// for(int j = 0; j < NUM_PER_TH; j++) +// if(idx+j < n_load) +// qvals[j] = A[idx+j]; +// else +// qvals[j] = 0; +// } +// +// +// #pragma unroll NUM_PER_TH +// for(int j = 0; j < NUM_PER_TH; j++) +// { +// vals[j*2] = dDequantizeFP4(qvals[j] & 0xF0, local_abs_max*12.0f); +// vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*12.0f); +// } +// +// +// reinterpret_cast(qvals)[0] = reinterpret_cast(A)[idx/8]; +// reinterpret_cast(A)[idx/16] = reinterpret_cast(local_valC)[j/num_items]; +// +// +//} + __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n) { @@ -2523,7 +2674,6 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o // 4. 
Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j; if(idx >= colsB){ break; } - //printf("%i %i\n", (row_offset+idx) % num_items, row_offset+idx); if((idx+num_items < colsB)) { if(BITS == 8) @@ -2543,8 +2693,6 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o #pragma unroll num_items for(int k = 0; k < num_items; k++) { - //if((float)local_valsB[k] != 0.0) - // printf("%f %i %i %i\n", (float)local_valsB[k], k, idx, colsB); if(BITS == 8 && dequant_stats != NULL) // we do texture cache reads (__ldg) on dequant_stats which should be super fast { @@ -2789,38 +2937,42 @@ MAKE_optimizerStatic8bit2State(ADAM, float) template __global__ void kPercentileClipping(float * __restrict__ g, float *gnorm_vec, int step, const int n); template __global__ void kPercentileClipping(half * __restrict__ g, float *gnorm_vec, int step, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const 
A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); - -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); - +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template 
__global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); + +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ 
const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); + +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); #define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index d90ea138e..23aad6c84 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -14,8 +14,8 @@ template__global__ void kEstimateQuantiles(T *__restrict__ const A, __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n); __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n); -template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n); +template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n); template __global__ void kPreconditionOptimizer32bit2State(T* g, T* p, diff --git a/csrc/ops.cu b/csrc/ops.cu index e770e107f..483d915f5 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -50,7 +50,7 @@ void dequantize(float *code, unsigned char *A, float *out, int n) CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) { int num_blocks = n/blocksize; num_blocks = n % blocksize == 0 ? 
num_blocks : num_blocks + 1; @@ -58,42 +58,34 @@ template void quantizeBlockwise(float * code, T *A, assert(blocksize == 4096); if(blocksize == 4096) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 2048) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 1024) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 512) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 256) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 128) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 64) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) { int num_blocks = n/blocksize; num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; - if(blocksize == 4096) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 2048) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 1024) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 512) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 256) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 128) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 64) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); + int tile_size = FP4 ? 
1024 : 512; + + if(FP4) + kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n); + else + kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -688,12 +680,16 @@ template void transformRowToFormat(char * A, char *out, int rows, template void estimateQuantiles(half *A, float *code, float offset, int n); template void estimateQuantiles(float *A, float *code, float offset, int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); #define MAKE_optimizer32bit(name, gtype) \ template void optimizer32bit(gtype* g, gtype* p, \ diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 31d4dd87c..b3e242419 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -128,8 +128,8 @@ template void estimateQuantiles(T *A, float *code, float offset, in void quantize(float *code, float *A, unsigned char *out, int n); void dequantize(float *code, unsigned char *A, float *out, int n); -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, 
int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); template void optimizer32bit(T* g, T* p, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index d8b2290f7..6a4bb0d96 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -75,13 +75,17 @@ MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32) void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } -void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } -void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } -void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, 4096, n); } -void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, 4096, n); } - -void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } \ -void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } +void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, 4096, n); } +void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, 4096, n); } +void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } + +void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } \ +void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } +void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(NULL, 
A, absmax, out, blocksize, n); } \ +void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } #define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \ @@ -148,6 +152,11 @@ extern "C" void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } + #define MAKE_CFUNC32(name, gtype, gbits) \ void c##name##32bit_g##gbits(gtype *g, gtype *p, \ float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ diff --git a/tests/test_functional.py b/tests/test_functional.py index 69c200a7e..efdda5455 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -152,7 +152,7 @@ def test_dynamic_quantization(): def test_dynamic_blockwise_quantization(): #print('') - for blocksize in [4096, 2048, 1024, 512]: + for blocksize in [4096, 2048, 1024, 512, 256, 128, 64]: diffs = [] reldiffs = [] for i in range(100): @@ -2189,7 +2189,88 @@ def test_bench_dequantization(): torch.cuda.synchronize() t0 = time.time() for i in range(100): - F.dequantize_blockwise(qa, SA, blocksize=2048) + #F.dequantize_blockwise(qa, SA, blocksize=2048) + qa, SA = F.quantize_blockwise(a) torch.cuda.synchronize() #print((time.time()-t0)/1e6) + + +def test_fp4_quant(): + vals = list(product([0, 1], repeat=4)) + + code = {} + for bits in vals: + result = 0 + bias = 3 + sign, e1, e2, p1 = bits + idx = sign*8 + e1*4 + e2*2 + p1*1 + sign = -1.0 if sign else 1.0 + exp = e1*2 + e2*1 + if exp == 0: + # sub-normal + if p1 == 0: result = 0 + else: result = sign*0.0625 + else: + # normal + exp = 2**(-exp + bias + 1) + frac = 1.5 if p1 else 1.0 + result = sign*exp*frac + code[idx] = result + + A1 = torch.randn(1024, 1024, device='cuda').half() + qa, SA = F.quantize_fp4(A1, blocksize=64) + A2 = F.dequantize_fp4(qa, SA) + #qa, SA = F.quantize_fp4(A1, blocksize=128) + #A2 = F.dequantize_fp4(qa, SA, blocksize=128) + + #A1 = A1.flatten().sort()[0] + #A2 = A2.flatten().sort()[0] + + #print(A1) + #print(A2) + + err = (A1 - A2).abs().float() + relerr = (err/A1.abs().float()).mean() + err = err.mean() + + print(err, relerr) + + + + + #assert err.item() < 0.1 + #assert relerr.item() < 0.28 + + +def test_bench_fp4_dequant(): + blocksize = 256 + a = torch.rand(1024*12*4, 1024*12, 
device='cuda').half() + qa, SA = F.quantize_fp4(a, blocksize=blocksize) + + input_size = a.numel()/2 + output_size = a.numel()*2 + num_bytes = input_size+output_size + GB = num_bytes/1e9 + max_theoretical_s = GB/768 + print(max_theoretical_s*1e6) + b = torch.randn(128, 1024*12, device='cuda').half() + + iters = 5 + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + F.dequantize_fp4(qa, SA, blocksize=blocksize) + #b.copy_(a) + torch.cuda.synchronize() + print((time.time()-t0)/iters*1e6) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + torch.matmul(b, a.t()) + torch.cuda.synchronize() + print((time.time()-t0)/iters*1e6) + + + From 160a83580d3e159d00fa3004c8b98a64d08fb732 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sat, 4 Feb 2023 21:11:21 -0800 Subject: [PATCH 02/63] Forward matmul_fp4 tests pass. --- bitsandbytes/__init__.py | 1 + bitsandbytes/autograd/_functions.py | 67 +++++++++++++++- bitsandbytes/functional.py | 15 ++-- bitsandbytes/nn/modules.py | 62 +++++++++++++++ tests/test_autograd.py | 115 ++++++++++++++++++++++++++++ tests/test_functional.py | 17 +--- 6 files changed, 254 insertions(+), 23 deletions(-) diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 041df4bac..c83b7ff40 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -10,6 +10,7 @@ matmul, matmul_cublas, mm_cublas, + matmul_fp4 ) from .cextension import COMPILED_WITH_CUDA from .nn import modules diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 376fb8a29..a098d4b07 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -2,7 +2,7 @@ import warnings from dataclasses import dataclass from functools import reduce # Required in Python 3 -from typing import Tuple, Optional +from typing import Tuple, Optional, List import torch @@ -474,6 +474,67 @@ def backward(ctx, grad_output): return grad_A, grad_B, None, grad_bias, None +class MatMulFP4(torch.autograd.Function): + # forward is the same, but we added the fallback for pre-turing GPUs + # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") + + @staticmethod + def forward(ctx, A, B, out=None, bias=None, state=None): + # default of pytorch behavior if inputs are empty + ctx.is_empty = False + if prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + ctx.bias = bias + B_shape = state[1] + if A.shape[-1] == B_shape[0]: + return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device) + else: + return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) + + + # 1. Dequantize + # 2. Matmul + output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype), bias) + + # 3. 
Save state + ctx.state = state + ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype + + if any(ctx.needs_input_grad[:2]): + ctx.tensors = A + else: + ctx.tensors = [None, None] + ctx.tensor_states = (None, None) + ctx.save_for_backward(None, None) + + return output + + @staticmethod + def backward(ctx, grad_output): + if ctx.is_empty: + bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None + + req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad + A = ctx.tensors + state = ctx.state + + if req_gradBias: + # compute grad_bias first before changing grad_output dtype + grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) + + # Cast grad_output to fp16 + if len(grad_output.shape) == 3: + grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() + + if req_gradB: grad_B = torch.matmul(grad_output.t(), A) + if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(ctx.dtype_A)) + + return grad_A, grad_B, None, grad_bias, None + + def matmul( A: tensor, B: tensor, @@ -486,3 +547,7 @@ def matmul( if threshold > 0.0: state.threshold = threshold return MatMul8bitLt.apply(A, B, out, bias, state) + + +def matmul_fp4(A: tensor, B: tensor, out: tensor = None, quant_state: List = None, bias=None): + return MatMulFP4.apply(A, B, out, bias, quant_state) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index da9e7430b..92ac67063 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -626,7 +626,7 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize ------- torch.Tensor: The 8-bit tensor with packed 4-bit values. - tuple(torch.Tensor, torch.Size, torch.dtype): + tuple(torch.Tensor, torch.Size, torch.dtype, int): The quantization state to undo the quantization. 
""" if A.device.type != 'cuda': @@ -640,10 +640,10 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize blocks += 1 if n % blocksize > 0 else 0 absmax = torch.zeros((blocks,), device=A.device) - state = (absmax, input_shape, A.dtype) + state = (absmax, input_shape, A.dtype, blocksize) if out is None: - out = torch.zeros(((n+1)//2,), dtype=torch.uint8, device=A.device) + out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device) assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] @@ -692,7 +692,7 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: shape = out.shape dtype = out.dtype else: - absmax, shape, dtype = quant_state + absmax, shape, dtype, blocksize = quant_state if out is None: @@ -700,6 +700,7 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: n = out.numel() + device = pre_call(A.device) is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: @@ -710,9 +711,9 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) - return out - - + is_transposed = (True if A.shape[0] == 1 else False) + if is_transposed: return out.t() + else: return out def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 45df35e89..6dfb06cc5 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -133,6 +133,67 @@ def forward(self, input: Tensor) -> Tensor: return emb +class FP4Params(torch.nn.Parameter): + def __new__(cls, data=None, requires_grad=True, quant_state=None): + cls.quant_state = None + if data is None: + data = torch.empty(0) + return torch.Tensor._make_subclass(cls, data, requires_grad) + + def cuda(self, device): + w = self.data.contiguous().half().cuda(device) + w_fp4, quant_state = bnb.functional.quantize_fp4(w) + self.data = w_fp4 + self.quant_state = quant_state + + return self + + @overload + def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T: + ... + + @overload + def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: + ... + + @overload + def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: + ... + + def to(self, *args, **kwargs): + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) + + if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"): + return self.cuda(device) + else: + new_param = FP4Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking), + requires_grad=self.requires_grad, quant_state=self.quant_state) + + return new_param + + +class LinearFP4(nn.Linear): + def __init__(self, input_features, output_features, bias=True): + super().__init__(input_features, output_features, bias) + self.state = bnb.MatmulLtState() + self.weight = FP4Params(self.weight.data, requires_grad=False) + + def init_8bit_state(self): + pass + + def forward(self, x: torch.Tensor): + self.state.is_training = self.training + + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + if getattr(self.weight, 'state', None) is None: + print('FP4 state not initialized. 
Please call .cuda() or .to(device) on the LinearFP4 layer first.') + out = bnb.matmul_fp(x, self.weight, bias=self.bias, state=self.weight.state) + + return out + class Int8Params(torch.nn.Parameter): def __new__( @@ -208,6 +269,7 @@ def to(self, *args, **kwargs): return new_param + class Linear8bitLt(nn.Linear): def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0, index=None): diff --git a/tests/test_autograd.py b/tests/test_autograd.py index c67126d6b..ba75d76a7 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -429,3 +429,118 @@ def test_matmullt( if req_grad[2]: torch.testing.assert_allclose(gradBias1, gradBias2) + + +n = 1 +k = 3 +dim1 = torch.randint(16, 64, size=(n,)).tolist() +dim2 = torch.randint(32, 96, size=(n,)).tolist() +dim3 = torch.randint(32, 96, size=(n,)).tolist() +dim4 = torch.randint(32, 96, size=(n,)).tolist() + +dim2.append(0) + +funcs = [(torch.matmul, bnb.matmul_fp4)] +str_funcs = ["matmul"] +req_grad = list(product([True, False], repeat=3)) +req_grad_str = [] +for c in req_grad: + strval = '' + for v in c: + if v == True: strval += 'T' + else: strval += 'F' + req_grad_str.append(strval) + +transpose = [(False, True), (False, False)] +str_transpose = ["NT", "NN"] +dtype = [torch.float16, torch.float32] +has_fp16_weights = [True, False] +has_bias = [True, False] +values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias)) +str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias)) +names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}".format(*vals) for vals in str_values] +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias", values, ids=names) +def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias): + dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) + dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) + if has_bias == False: + req_grad = list(req_grad) + req_grad[2] = False + + for i in range(k): + # normal multiply + if funcs[0] in [torch.mm, torch.matmul]: + A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) + B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype) + target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype) + bias = None + bias2 = None + if has_bias: + bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2]) + bias2 = bias.clone() + torch.nn.init.xavier_uniform_(B) + B2 = B.clone() + + B2, quant_state = bnb.functional.quantize_fp4(B) + + if not transpose[0] and transpose[1]: + out_torch = funcs[0](A, B.t()) + out_bnb = funcs[1](A, B2, quant_state=quant_state, bias=bias2) + elif not transpose[0] and not transpose[1]: + out_torch = funcs[0](A, B) + out_bnb = funcs[1](A, B2.t(), quant_state=quant_state, bias=bias2) + + if has_bias: + out_torch += bias + + assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}" + + n = out_bnb.numel() + err = torch.abs(out_bnb - out_torch).float().mean().item() + if n > 0: + assert err < 0.11 + + if any(req_grad): + out_bnb.data.copy_(out_torch) + torch.cuda.synchronize() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb.backward() + 
gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None + if has_bias: + gradBias1 = bias.grad + bias.grad = None + + loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean() + loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None + if has_bias: + gradBias2 = bias.grad + bias.grad = None + + if req_grad[0]: + torch.testing.assert_allclose( gradA1, gradA2, atol=0.015, rtol=0.1) + if req_grad[1]: + n = gradB1.numel() + if dim2 > 0: + assert torch.abs(gradB1).sum() > 0.0 + assert torch.abs(gradB2).sum() > 0.0 + else: + assert torch.abs(gradB1).sum() == 0.0 + assert torch.abs(gradB2).sum() == 0.0 + idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) + + assert (idx == 0).sum().item() <= n * 0.1 + idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) + assert (idx == 0).sum().item() <= n * 0.02 + torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3 + ) + + if req_grad[2]: + torch.testing.assert_allclose(gradBias1, gradBias2) diff --git a/tests/test_functional.py b/tests/test_functional.py index efdda5455..e6b7b811b 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2221,26 +2221,13 @@ def test_fp4_quant(): A1 = torch.randn(1024, 1024, device='cuda').half() qa, SA = F.quantize_fp4(A1, blocksize=64) A2 = F.dequantize_fp4(qa, SA) - #qa, SA = F.quantize_fp4(A1, blocksize=128) - #A2 = F.dequantize_fp4(qa, SA, blocksize=128) - - #A1 = A1.flatten().sort()[0] - #A2 = A2.flatten().sort()[0] - - #print(A1) - #print(A2) err = (A1 - A2).abs().float() relerr = (err/A1.abs().float()).mean() err = err.mean() - print(err, relerr) - - - - - #assert err.item() < 0.1 - #assert relerr.item() < 0.28 + assert err.item() < 0.1 + assert relerr.item() < 0.28 def test_bench_fp4_dequant(): From 13c0a4dc5d4be33bf0461d8bcc24e982b17dcb11 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sat, 4 Feb 2023 21:35:43 -0800 Subject: [PATCH 03/63] Backward matmul_fp4 passes. --- bitsandbytes/autograd/_functions.py | 15 ++++++++------- tests/test_autograd.py | 16 ---------------- 2 files changed, 8 insertions(+), 23 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index a098d4b07..29c0b9308 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -503,11 +503,9 @@ def forward(ctx, A, B, out=None, bias=None, state=None): ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype if any(ctx.needs_input_grad[:2]): - ctx.tensors = A + ctx.tensors = (A, B) else: - ctx.tensors = [None, None] - ctx.tensor_states = (None, None) - ctx.save_for_backward(None, None) + ctx.tensors = (None, None) return output @@ -517,10 +515,12 @@ def backward(ctx, grad_output): bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None - req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad - A = ctx.tensors + req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad + A, B = ctx.tensors state = ctx.state + grad_A, grad_B, grad_bias = None, None, None + if req_gradBias: # compute grad_bias first before changing grad_output dtype grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) @@ -529,7 +529,8 @@ def backward(ctx, grad_output): if len(grad_output.shape) == 3: grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() - if req_gradB: grad_B = torch.matmul(grad_output.t(), A) + # not supported by PyTorch. 
TODO: create work-around + #if req_gradB: grad_B = torch.matmul(grad_output.t(), A) if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(ctx.dtype_A)) return grad_A, grad_B, None, grad_bias, None diff --git a/tests/test_autograd.py b/tests/test_autograd.py index ba75d76a7..ccbcc871f 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -480,7 +480,6 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2]) bias2 = bias.clone() torch.nn.init.xavier_uniform_(B) - B2 = B.clone() B2, quant_state = bnb.functional.quantize_fp4(B) @@ -526,21 +525,6 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, if req_grad[0]: torch.testing.assert_allclose( gradA1, gradA2, atol=0.015, rtol=0.1) - if req_grad[1]: - n = gradB1.numel() - if dim2 > 0: - assert torch.abs(gradB1).sum() > 0.0 - assert torch.abs(gradB2).sum() > 0.0 - else: - assert torch.abs(gradB1).sum() == 0.0 - assert torch.abs(gradB2).sum() == 0.0 - idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) - - assert (idx == 0).sum().item() <= n * 0.1 - idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) - assert (idx == 0).sum().item() <= n * 0.02 - torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3 - ) if req_grad[2]: torch.testing.assert_allclose(gradBias1, gradBias2) From cfe4705e321d884bae48ce785f29d4a0aff5518b Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sat, 4 Feb 2023 22:00:04 -0800 Subject: [PATCH 04/63] Added matmul_fp4 to the benchmark. --- bitsandbytes/autograd/_functions.py | 5 +- bitsandbytes/functional.py | 5 +- tests/test_autograd.py | 6 +- tests/test_functional.py | 86 +++++++++++++++++------------ 4 files changed, 57 insertions(+), 45 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 29c0b9308..01d1eb2dd 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -495,7 +495,7 @@ def forward(ctx, A, B, out=None, bias=None, state=None): # 1. Dequantize - # 2. Matmul + # 2. MatmulnN output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype), bias) # 3. 
Save state @@ -550,5 +550,6 @@ def matmul( return MatMul8bitLt.apply(A, B, out, bias, state) -def matmul_fp4(A: tensor, B: tensor, out: tensor = None, quant_state: List = None, bias=None): +def matmul_fp4(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None): + assert quant_state is not None return MatMulFP4.apply(A, B, out, bias, quant_state) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 92ac67063..b38ba1db1 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -169,7 +169,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) lst = list(itertools.product([0, 1], repeat=precision_bits)) #for ev in evalues: bias = 2**(exponent_bits-1)+1 - print(bias) for evalue in range(2**(exponent_bits)): for bit_pattern in lst: value = (1 if evalue != 0 else 0) @@ -180,9 +179,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) value = value*2**-(bias) else: # normals - print(value, 1) value = value*2**-(evalue-bias-1) - print(value, 2) values.append(value) if signed: values.append(-value) @@ -196,7 +193,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) values.append(0) values.sort() code = torch.Tensor(values) - #code /= code.max() + code /= code.max() return code diff --git a/tests/test_autograd.py b/tests/test_autograd.py index ccbcc871f..a8b920761 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -485,10 +485,10 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, if not transpose[0] and transpose[1]: out_torch = funcs[0](A, B.t()) - out_bnb = funcs[1](A, B2, quant_state=quant_state, bias=bias2) + out_bnb = funcs[1](A, B2, quant_state, bias=bias2) elif not transpose[0] and not transpose[1]: out_torch = funcs[0](A, B) - out_bnb = funcs[1](A, B2.t(), quant_state=quant_state, bias=bias2) + out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2) if has_bias: out_torch += bias @@ -498,7 +498,7 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, n = out_bnb.numel() err = torch.abs(out_bnb - out_torch).float().mean().item() if n > 0: - assert err < 0.11 + assert err < 0.115 if any(req_grad): out_bnb.data.copy_(out_torch) diff --git a/tests/test_functional.py b/tests/test_functional.py index e6b7b811b..49022dc00 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1788,18 +1788,14 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): seqdim = 1 values = [] values.append((batch_size, seqdim, 768, 4 * 768)) -# values.append((batch_size, seqdim, 1024, 4*1024)) -# values.append((batch_size, seqdim, 1536, 4*1536)) -# values.append((batch_size, seqdim, 2048, 4*2048)) -# values.append((batch_size, seqdim, 2560, 4*2560)) -# values.append((batch_size, seqdim, 4096, 4*4096)) -# values.append((batch_size, seqdim, 5140, 4*5140)) +#values.append((batch_size, seqdim, 1024, 4*1024)) +#values.append((batch_size, seqdim, 1536, 4*1536)) +#values.append((batch_size, seqdim, 2048, 4*2048)) +#values.append((batch_size, seqdim, 2560, 4*2560)) +#values.append((batch_size, seqdim, 4096, 4*4096)) +#values.append((batch_size, seqdim, 5140, 4*5140)) #values.append((batch_size, seqdim, 12288, 4*12288)) -names = [ - "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values -] - - +names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values] @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) def test_bench_matmul(batch, seq, 
model, hidden): iters = 128 @@ -1809,17 +1805,20 @@ def test_bench_matmul(batch, seq, model, hidden): B = torch.empty(hidden, model, dtype=torch.float16, device="cuda") torch.nn.init.xavier_uniform_(B) + B_fp4, state = F.quantize_fp4(B) + linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() linear8bit.eval() outliers = torch.randint(0, model, size=(5,)).cuda() A[:, :, outliers] = 8.0 - linearMixedBit = ( - bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() - ) + linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()) linearMixedBit.eval() + linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() + linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() + # warmup for i in range(iters): torch.matmul(A, B.t()) @@ -1831,9 +1830,14 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(iters): torch.matmul(A, B.t()) torch.cuda.synchronize() - print( - f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" - ) + print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + bnb.matmul_fp4(A, B_fp4, quant_state=state) + torch.cuda.synchronize() + print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) torch.cuda.synchronize() t0 = time.time() @@ -1872,7 +1876,7 @@ def test_bench_matmul(batch, seq, model, hidden): Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) F.vectorwise_mm_dequant(Cout, statsA, statsB.t()) torch.cuda.synchronize() - #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") CxB, SB = F.nvidia_transform(CB, to_order=formatB) @@ -1886,7 +1890,7 @@ def test_bench_matmul(batch, seq, model, hidden): Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) out = Cout * statsB * statsA * (1.0 / (127 * 127)) torch.cuda.synchronize() - #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") linear8bit(A) torch.cuda.synchronize() @@ -1894,9 +1898,7 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(iters): linear8bit(A) torch.cuda.synchronize() - print( - f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" - ) + print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") linearMixedBit(A) torch.cuda.synchronize() @@ -1904,9 +1906,23 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(iters): linearMixedBit(A) torch.cuda.synchronize() - print( - f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" - ) + print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + linear8bit_train(A) + torch.cuda.synchronize() + t0 
= time.time() + for i in range(iters): + linear8bit_train(A) + torch.cuda.synchronize() + print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + linear8bit_train_thresh(A) + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + linear8bit_train(A) + torch.cuda.synchronize() + print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") def test_zeropoint(): def quant_zp(x): @@ -2050,7 +2066,6 @@ def test_fp8_quant(): p_bits = 7-e_bits code = F.create_fp8_map(True, e_bits, p_bits).cuda() - print(e_bits, p_bits) abserr = [] relerr = [] for i in range(100): @@ -2189,7 +2204,6 @@ def test_bench_dequantization(): torch.cuda.synchronize() t0 = time.time() for i in range(100): - #F.dequantize_blockwise(qa, SA, blocksize=2048) qa, SA = F.quantize_blockwise(a) torch.cuda.synchronize() #print((time.time()-t0)/1e6) @@ -2240,7 +2254,7 @@ def test_bench_fp4_dequant(): num_bytes = input_size+output_size GB = num_bytes/1e9 max_theoretical_s = GB/768 - print(max_theoretical_s*1e6) + #print(max_theoretical_s*1e6) b = torch.randn(128, 1024*12, device='cuda').half() iters = 5 @@ -2250,14 +2264,14 @@ def test_bench_fp4_dequant(): F.dequantize_fp4(qa, SA, blocksize=blocksize) #b.copy_(a) torch.cuda.synchronize() - print((time.time()-t0)/iters*1e6) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - torch.matmul(b, a.t()) - torch.cuda.synchronize() - print((time.time()-t0)/iters*1e6) + #print((time.time()-t0)/iters*1e6) + + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # torch.matmul(b, a.t()) + #torch.cuda.synchronize() + #print((time.time()-t0)/iters*1e6) From c361f84239d52844ddae724e40c2c9a5d49284d5 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 5 Feb 2023 06:16:56 -0800 Subject: [PATCH 05/63] Fixed matmul_fp4 transpose. --- bitsandbytes/autograd/_functions.py | 4 ++-- tests/test_autograd.py | 4 ++-- tests/test_functional.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 01d1eb2dd..6db90f5e1 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -496,7 +496,7 @@ def forward(ctx, A, B, out=None, bias=None, state=None): # 1. Dequantize # 2. MatmulnN - output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype), bias) + output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias) # 3. Save state ctx.state = state @@ -531,7 +531,7 @@ def backward(ctx, grad_output): # not supported by PyTorch. 
TODO: create work-around #if req_gradB: grad_B = torch.matmul(grad_output.t(), A) - if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(ctx.dtype_A)) + if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(ctx.dtype_A).t()) return grad_A, grad_B, None, grad_bias, None diff --git a/tests/test_autograd.py b/tests/test_autograd.py index a8b920761..436c6b126 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -485,10 +485,10 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, if not transpose[0] and transpose[1]: out_torch = funcs[0](A, B.t()) - out_bnb = funcs[1](A, B2, quant_state, bias=bias2) + out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2) elif not transpose[0] and not transpose[1]: out_torch = funcs[0](A, B) - out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2) + out_bnb = funcs[1](A, B2, quant_state, bias=bias2) if has_bias: out_torch += bias diff --git a/tests/test_functional.py b/tests/test_functional.py index 49022dc00..23b75587a 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1835,7 +1835,7 @@ def test_bench_matmul(batch, seq, model, hidden): torch.cuda.synchronize() t0 = time.time() for i in range(iters): - bnb.matmul_fp4(A, B_fp4, quant_state=state) + bnb.matmul_fp4(A, B_fp4.t(), quant_state=state) torch.cuda.synchronize() print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) From c0c352b3791a5aab14263108595479b9db58fa1f Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 5 Feb 2023 06:29:52 -0800 Subject: [PATCH 06/63] Added bias test for LinearFP4 and basic test. --- bitsandbytes/nn/__init__.py | 2 +- bitsandbytes/nn/modules.py | 6 +++--- tests/test_modules.py | 43 +++++++++++-------------------------- 3 files changed, 16 insertions(+), 35 deletions(-) diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index edc595aea..79fb51e17 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -2,4 +2,4 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .modules import Int8Params, Linear8bitLt, StableEmbedding +from .modules import Int8Params, Linear8bitLt, StableEmbedding, LinearFP4 diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 6dfb06cc5..4c719c676 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -188,9 +188,9 @@ def forward(self, x: torch.Tensor): if self.bias is not None and self.bias.dtype != x.dtype: self.bias.data = self.bias.data.to(x.dtype) - if getattr(self.weight, 'state', None) is None: - print('FP4 state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.') - out = bnb.matmul_fp(x, self.weight, bias=self.bias, state=self.weight.state) + if getattr(self.weight, 'quant_state', None) is None: + print('FP4 quantization state not initialized. 
Please call .cuda() or .to(device) on the LinearFP4 layer first.') + out = bnb.matmul_fp4(x, self.weight.t(), bias=self.bias, quant_state=self.weight.quant_state) return out diff --git a/tests/test_modules.py b/tests/test_modules.py index d78f0c9f9..ba67bfc95 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -330,12 +330,8 @@ def test_linear8bitlt_inference(threshold): def test_linear8bitlt_accumulated_gradient(): - l1 = torch.nn.Sequential( - *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)] - ) - l2 = torch.nn.Sequential( - *[torch.nn.Linear(32, 32).cuda().half() for i in range(2)] - ) + l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]) + l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]) l2[0].weight = torch.nn.Parameter(l1[0].weight.clone()) l2[0].bias = torch.nn.Parameter(l1[0].bias.clone()) l2[1].weight = torch.nn.Parameter(l1[1].weight.clone()) @@ -376,21 +372,10 @@ def test_linear8bitlt_accumulated_gradient(): torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad) -threshold = [0.0, 2.0] -values = threshold -names = [f"threshold_{vals}" for vals in values] - - -@pytest.mark.parametrize("threshold", values, ids=names) +@pytest.mark.parametrize("threshold", [0.0, 2.0]) @pytest.mark.parametrize("memory_efficient_backward", [False]) def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): - l1 = ( - bnb.nn.Linear8bitLt( - 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward - ) - .cuda() - .half() - ) + l1 = ( bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half()) assert l1.weight.dtype == torch.int8 l1.eval() @@ -446,13 +431,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 - mlp = ( - MLP8bit( - 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward - ) - .half() - .to("cuda") - ) + mlp = ( MLP8bit( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda")) for i in range(100): b1 = torch.randn(16, 8, 32, device="cuda").half() @@ -504,10 +483,11 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): assert (idx == 0).sum().item() <= b1.numel() * 0.005 -def test_linear8bitlt_fp32_bias(): +@pytest.mark.parametrize("module", [lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False), bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4']) +def test_linear_kbit_fp32_bias(module): # casts model to fp16 -> int8 automatically - l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda() - assert l1.weight.dtype == torch.int8 + l1 = module(32, 64).cuda() + assert l1.weight.dtype in [torch.int8, torch.uint8] assert l1.bias.dtype == torch.float32 for i in range(100): @@ -517,11 +497,12 @@ def test_linear8bitlt_fp32_bias(): assert l1.bias.dtype == torch.float16 # casts model to fp16 -> int8 automatically - l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda() - assert l1.weight.dtype == torch.int8 + l1 = module(32, 64, bias=False).cuda() + assert l1.weight.dtype in [torch.int8, torch.uint8] assert l1.bias is None for i in range(100): b1 = torch.randn(16, 8, 32, device="cuda").half() o1 = l1(b1) assert l1.bias is None + From 
7f0773aede92a8be5bf0645185de4f5707b3a2a8 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 5 Feb 2023 06:49:54 -0800 Subject: [PATCH 07/63] Added backprop test for Linear8bitLt and LinearFP4. --- tests/test_modules.py | 40 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/tests/test_modules.py b/tests/test_modules.py index ba67bfc95..41cc050f2 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -375,7 +375,7 @@ def test_linear8bitlt_accumulated_gradient(): @pytest.mark.parametrize("threshold", [0.0, 2.0]) @pytest.mark.parametrize("memory_efficient_backward", [False]) def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): - l1 = ( bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half()) + l1 = (bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half()) assert l1.weight.dtype == torch.int8 l1.eval() @@ -506,3 +506,41 @@ def test_linear_kbit_fp32_bias(module): o1 = l1(b1) assert l1.bias is None +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +@pytest.mark.parametrize("module", [bnb.nn.Linear8bitLt, bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4']) +def test_kbit_backprop(module): + b = 17 + dim1 = 37 + dim2 = 83 + + ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)]) + ref[1].weight.requires_grad = False + kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)]) + kbit[0].weight.detach().copy_(ref[0].weight) + kbit[1].weight.detach().copy_(ref[1].weight) + kbit[0].bias.detach().copy_(ref[0].bias) + kbit[1].bias.detach().copy_(ref[1].bias) + ref = ref.half().cuda() + kbit = kbit.half().cuda() + + for i in range(100): + batch = torch.randn(b, dim1).half().cuda() + out1 = ref(batch) + out2 = kbit(batch) + out1.mean().backward() + out2.mean().backward() + + grad1 = ref[0].weight.grad + grad2 = kbit[0].weight.grad + bgrad1 = ref[0].bias.grad + bgrad2 = kbit[0].bias.grad + + torch.testing.assert_allclose(grad1, grad2, atol=0.008, rtol=0.05) + torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.008, rtol=0.05) + ref.zero_grad() + kbit.zero_grad() + + assert kbit[0].weight.grad.sum().item() == 0 + assert kbit[0].bias.grad.sum().item() == 0 + + From c93a90d07595c143e87831228815d88a1e6d32e7 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 14 Feb 2023 13:31:39 -0800 Subject: [PATCH 08/63] Fixed FP4 import and data type conversion in backward. --- bitsandbytes/autograd/_functions.py | 6 +----- bitsandbytes/nn/__init__.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 6db90f5e1..ffe19c5ca 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -525,13 +525,9 @@ def backward(ctx, grad_output): # compute grad_bias first before changing grad_output dtype grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) - # Cast grad_output to fp16 - if len(grad_output.shape) == 3: - grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() - # not supported by PyTorch. 
TODO: create work-around #if req_gradB: grad_B = torch.matmul(grad_output.t(), A) - if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(ctx.dtype_A).t()) + if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(grad_output.dtype).t()) return grad_A, grad_B, None, grad_bias, None diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index 79fb51e17..954a67f79 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -2,4 +2,4 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .modules import Int8Params, Linear8bitLt, StableEmbedding, LinearFP4 +from .modules import Int8Params, Linear8bitLt, StableEmbedding, LinearFP4, FP4Params From 9851a10b46d54bf1b2ae9b37d59f55f3d6580625 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Fri, 24 Feb 2023 10:17:57 -0800 Subject: [PATCH 09/63] Added cast to fp4 layer for speed. --- bitsandbytes/autograd/_functions.py | 7 ++++--- bitsandbytes/nn/modules.py | 6 +++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index ffe19c5ca..8070ff8a2 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -404,10 +404,10 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype if any(ctx.needs_input_grad[:2]): - ctx.tensors = (CAt, subA) + ctx.tensors = (CAt, subA, A) ctx.tensor_states = (SCAt, state.idx) else: - ctx.tensors = [None, None] + ctx.tensors = [None, None, A] ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) @@ -420,7 +420,7 @@ def backward(ctx, grad_output): bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad - CAt, subA = ctx.tensors + CAt, subA, A = ctx.tensors SCAt, idx = ctx.tensor_states formatB = ctx.formatB state = ctx.state @@ -436,6 +436,7 @@ def backward(ctx, grad_output): Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) if req_gradB: + #grad_B = torch.matmul(grad_output.t(), A) CxAt, SAt = F.transform(CAt, formatB, transpose=True) C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 4c719c676..ad3f4f71c 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -190,7 +190,11 @@ def forward(self, x: torch.Tensor): if getattr(self.weight, 'quant_state', None) is None: print('FP4 quantization state not initialized. 
Please call .cuda() or .to(device) on the LinearFP4 layer first.') - out = bnb.matmul_fp4(x, self.weight.t(), bias=self.bias, quant_state=self.weight.quant_state) + + inp_dtype = x.dtype + x = x.to(torch.float16) + out = bnb.matmul_fp4(x, self.weight.t(), bias=self.bias.half(), quant_state=self.weight.quant_state) + out = out.to(inp_dtype) return out From 6c31a5fe991169d1caad2426b1cee479af6afd13 Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Mon, 27 Feb 2023 14:23:21 -0800 Subject: [PATCH 10/63] t5 model fix --- bitsandbytes/nn/modules.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index ad3f4f71c..5d6d19cae 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -190,10 +190,10 @@ def forward(self, x: torch.Tensor): if getattr(self.weight, 'quant_state', None) is None: print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.') - inp_dtype = x.dtype x = x.to(torch.float16) - out = bnb.matmul_fp4(x, self.weight.t(), bias=self.bias.half(), quant_state=self.weight.quant_state) + bias = None if self.bias is None else self.bias.half() + out = bnb.matmul_fp4(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state) out = out.to(inp_dtype) return out From 69810521d37ed419452aac573f1c3b283290668c Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 27 Mar 2023 09:12:57 -0700 Subject: [PATCH 11/63] Some small changes. --- bitsandbytes/nn/modules.py | 8 +- bitsandbytes/utils.py | 40 +++++++++ csrc/kernels.cu | 2 + csrc/ops.cu | 2 + tests/test_functional.py | 170 ++++++++++++++++++------------------- 5 files changed, 135 insertions(+), 87 deletions(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 5d6d19cae..a550ec1f4 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -173,10 +173,11 @@ def to(self, *args, **kwargs): class LinearFP4(nn.Linear): - def __init__(self, input_features, output_features, bias=True): + def __init__(self, input_features, output_features, bias=True, compute_dtype=None): super().__init__(input_features, output_features, bias) self.state = bnb.MatmulLtState() self.weight = FP4Params(self.weight.data, requires_grad=False) + self.compute_dtype = compute_dtype def init_8bit_state(self): pass @@ -191,9 +192,12 @@ def forward(self, x: torch.Tensor): if getattr(self.weight, 'quant_state', None) is None: print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.') inp_dtype = x.dtype - x = x.to(torch.float16) + if self.compute_dtype is not None: + x = x.to(self.compute_dtype) + bias = None if self.bias is None else self.bias.half() out = bnb.matmul_fp4(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state) + out = out.to(inp_dtype) return out diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 1cd90e377..d6cc9660b 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -21,3 +21,43 @@ def execute_and_return_decoded_std_streams(command_string): std_out, std_err = execute_and_return_decoded_std_streams(command_string) return std_out, std_err + + + +def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None): + """ + Replace linear modules with a new Linear module. + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. 
+ linear_replacement (`torch.nn.Module`): + The linear module that replaces the old one. Only expects standard arguments. + If other arguments need to be passed, use a lambda. + skip_modules (`List[str]`, *optional*, defaults to `lm_head`): + List of modules names not to convert. Defaults to `lm_head`. + copy_weights (`bool`): + Copy the weights from the old linear module to the new one + post_processing_fun_name (`str`): + A function name of the replacement linear class that is called + after processing. + """ + for name, module in model.named_children(): + if len(list(module.children())) > 0: + replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function) + + if isinstance(module, torch.nn.Linear) and name not in skip_modules: + old_module = model._modules[name] + model._modules[name] = linear_replacement( + module.in_features, + module.out_features, + module.bias is not None, + ) + if copy_weights: + model._modules[name].weight = old_module.weight + model._modules[name].bias = old_module.bias + + if post_processing_function is not None: + func = getattr(module, post_processing_function, None) + if func is not None: func(module) + return model + diff --git a/csrc/kernels.cu b/csrc/kernels.cu index a1eec6880..a2691be7f 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2968,6 +2968,8 @@ template __global__ void kQuantizeBlockwise(float * code, ha template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +//template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +//template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); diff --git a/csrc/ops.cu b/csrc/ops.cu index 483d915f5..07ef85074 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -71,6 +71,8 @@ template void quantizeBlockwise(float * co kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 64) kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + //else if(blocksize == 32) + //kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); diff --git a/tests/test_functional.py b/tests/test_functional.py index 23b75587a..54cecca51 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1784,17 +1784,17 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): print("partial matmul", time.time() - t0) -batch_size = 1 -seqdim = 1 +batch_size = 4 +seqdim = 256 values = [] values.append((batch_size, seqdim, 768, 4 * 768)) -#values.append((batch_size, 
seqdim, 1024, 4*1024)) -#values.append((batch_size, seqdim, 1536, 4*1536)) -#values.append((batch_size, seqdim, 2048, 4*2048)) -#values.append((batch_size, seqdim, 2560, 4*2560)) -#values.append((batch_size, seqdim, 4096, 4*4096)) -#values.append((batch_size, seqdim, 5140, 4*5140)) -#values.append((batch_size, seqdim, 12288, 4*12288)) +values.append((batch_size, seqdim, 1024, 4*1024)) +values.append((batch_size, seqdim, 1536, 4*1536)) +values.append((batch_size, seqdim, 2048, 4*2048)) +values.append((batch_size, seqdim, 2560, 4*2560)) +values.append((batch_size, seqdim, 4096, 4*4096)) +values.append((batch_size, seqdim, 5140, 4*5140)) +values.append((batch_size, seqdim, 12288, 4*12288)) names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values] @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) def test_bench_matmul(batch, seq, model, hidden): @@ -1839,90 +1839,90 @@ def test_bench_matmul(batch, seq, model, hidden): torch.cuda.synchronize() print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - bnb.matmul(A, B) - torch.cuda.synchronize() - print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # bnb.matmul(A, B) + #torch.cuda.synchronize() + #print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - bnb.matmul(A, B, threshold=6.0) - torch.cuda.synchronize() - print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # bnb.matmul(A, B, threshold=6.0) + #torch.cuda.synchronize() + #print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) - C32A, SA = F.transform(CA, "col32") - CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) - CxB, SB = F.transform(CB, to_order=formatB) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) - torch.cuda.synchronize() - print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + #CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) + #C32A, SA = F.transform(CA, "col32") + #CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) + #CxB, SB = F.transform(CB, to_order=formatB) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + #torch.cuda.synchronize() + #print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - BA, statsB = F.vectorwise_quant(B, dim=1) - CxB, SB = F.nvidia_transform(CB, to_order=formatB) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - A2 = A.view(-1, A.shape[-1]).contiguous() - CA, statsA = F.vectorwise_quant(A2, dim=1) - C32A, SA = F.nvidia_transform(CA, "col32") - out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) - Cout, Sout = F.nvidia_transform(out32, 
"row", state=Sout32) - F.vectorwise_mm_dequant(Cout, statsA, statsB.t()) - torch.cuda.synchronize() - print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + #BA, statsB = F.vectorwise_quant(B, dim=1) + #CxB, SB = F.nvidia_transform(CB, to_order=formatB) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # A2 = A.view(-1, A.shape[-1]).contiguous() + # CA, statsA = F.vectorwise_quant(A2, dim=1) + # C32A, SA = F.nvidia_transform(CA, "col32") + # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) + # F.vectorwise_mm_dequant(Cout, statsA, statsB.t()) + #torch.cuda.synchronize() + #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") - CxB, SB = F.nvidia_transform(CB, to_order=formatB) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - A2 = A.view(-1, A.shape[-1]).contiguous() - CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear") - C32A, SA = F.nvidia_transform(CA, "col32") - out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) - Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) - out = Cout * statsB * statsA * (1.0 / (127 * 127)) - torch.cuda.synchronize() - print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + #BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") + #CxB, SB = F.nvidia_transform(CB, to_order=formatB) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # A2 = A.view(-1, A.shape[-1]).contiguous() + # CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear") + # C32A, SA = F.nvidia_transform(CA, "col32") + # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) + # out = Cout * statsB * statsA * (1.0 / (127 * 127)) + #torch.cuda.synchronize() + #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - linear8bit(A) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - linear8bit(A) - torch.cuda.synchronize() - print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + #linear8bit(A) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # linear8bit(A) + #torch.cuda.synchronize() + #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - linearMixedBit(A) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - linearMixedBit(A) - torch.cuda.synchronize() - print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + #linearMixedBit(A) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # linearMixedBit(A) + #torch.cuda.synchronize() + #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - linear8bit_train(A) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - linear8bit_train(A) - torch.cuda.synchronize() - print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], 
[{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + #linear8bit_train(A) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # linear8bit_train(A) + #torch.cuda.synchronize() + #print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - linear8bit_train_thresh(A) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - linear8bit_train(A) - torch.cuda.synchronize() - print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + #linear8bit_train_thresh(A) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # linear8bit_train(A) + #torch.cuda.synchronize() + #print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") def test_zeropoint(): def quant_zp(x): From 8645d1f71cc78155887bc3ba082b1a610a05e31f Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 29 Mar 2023 18:41:37 -0700 Subject: [PATCH 12/63] Added normal quant. --- bitsandbytes/functional.py | 76 +++++++++++++++++++++++++++++++++++--- csrc/kernels.cu | 4 +- csrc/ops.cu | 4 +- tests/test_functional.py | 10 ++--- 4 files changed, 80 insertions(+), 14 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index b38ba1db1..969250a50 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -9,6 +9,8 @@ import torch import itertools import math +import scipy.stats +import numpy as np from functools import reduce # Required in Python 3 from typing import Tuple @@ -152,6 +154,70 @@ def create_linear_map(signed=True, total_bits=8, add_zero=True): #return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist()) return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) +def custom_map(seed=0, scale=0.01): + v = [12, 10, 8, 6, 3, 2, 1] + # 16-bit 7B 22.33, 4-bit best 22.88, FP4 23.25, 4-bit 95 22.97, 4-bit evo 22.45 + # 16-bit 13B 70.35, 4-bit best 67.16, FP4 100.78, 4-bit-95 69.39, 4-bit evo 70.48 + + # 13B 100 steps: + # - 4-bit evo: 86.02 + # - 4-bit norm: 78.73 + # - 4-bit FP4: + # - 16-bit: + + # interval search on normal distribution + #v = [3.090232306167813, 1.4589770349449647, 1.064410327932115, 0.7896806653244509, 0.5646884166925807, 0.3653406435875121, 0.17964844284441311] # 0.999 26.5 + #v = [2.3263478740408408, 1.4050715603096329, 1.0364333894937898, 0.7721932141886848, 0.5533847195556727, 0.3584587932511938, 0.1763741647808615] # 0.99 24.99 + #v = [1.6448536269514722, 1.2040469600267016, 0.9208229763683788, 0.6971414348463417, 0.5039653672113453, 0.3280721075316511, 0.16184416680396213] # 0.95 24.53 22.97 + #v = [1.4050715603096329, 1.0803193408149558, 0.8416212335729143, 0.643345405392917, 0.4676987991145084, 0.3054807880993974, 0.1509692154967774] # 0.92 24.81 + #v = [1.2815515655446004, 1.0062699858608395, 0.7916386077433746, 0.6084981344998837, 0.4438613119262478, 0.29050677112339396, 0.14372923370582416] # 0.9 24.68 + #v = [1.8807936081512509, 1.2980047163986055, 0.9769954022693226, 0.7341502955472268, 0.5285136765472481, 0.343225833559403, 0.16910470304375366] # 0.97 25.03 + #v = [1.7506860712521692, 1.2496468758017434, 0.9485350408266378, 0.7155233557034365, 0.5162006366043174, 0.3356393360829622, 0.16547334454641704] # 0.96 24.85 23.01 + #v = [1.5547735945968535, 
1.1608220210715001, 0.893800631179489, 0.6789921163940618, 0.4918050830048072, 0.3205236191093902, 0.15821711945563585] # 0.94 24.47 + #v = [1.475791028179171, 1.1196635980209986, 0.8674156943957149, 0.6610637542614526, 0.4797170937629045, 0.31299335020578195, 0.15459215234139795] # 0.93 24.85 + #v = [1.5981931399228175, 1.1821583959486879, 0.9072289939325966, 0.6880384454306778, 0.49787602226482025, 0.3242955535308664, 0.160030379970179] # 0.945 24.287 + ##v = [1.6164363711150211, 1.1908453913294612, 0.9126463450304729, 0.6916727602238111, 0.5003095327012462, 0.3258056171348078, 0.1607558311941979] # 0.947 24.293 + #v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.207 + #v = [1.6118251211466303, 1.188665228776879, 0.9112895004060624, 0.690763326564427, 0.4997008778346997, 0.3254280317127771, 0.16057446047146948] # 0.9465 24.30 + #v = [1.6027040905517569, 1.184321770169049, 0.9085808314549837, 0.6889461706317986, 0.4984841229538408, 0.32467299997597887, 0.1602117348657326] # 0.9455 24.293 + #v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.37 22.88 + + # 7B evo start + #v = [1.62129629, 1.18870191, 0.90848106, 0.69108646, 0.50515268, 0.34927819905, 0.14122701] # 22.06 + #v = [1.6143079205628337, 1.1888081407660314, 0.8990131955745421, 0.694373759813679, 0.5083033257326773, 0.3452499746844963, 0.1148939728228951] + #v = [1.614442766030303, 1.189401918639665, 0.8998038168964273, 0.6953094818279475, 0.5073264599048384, 0.3449003790823619, 0.11428378427205564] + + # 13B evo start + #v = [1.6077535089716468, 1.1914902148179205, 0.8999752421085561, 0.6967904489387543, 0.4949093928311768, 0.30920472033044544, 0.15391602735952042] + #v = [1.586363722436466, 1.202610827188916, 0.9003332576346587, 0.6904888715206972, 0.49490974688233724, 0.2971151461329376, 0.15683230810738283] + v = [1.5842247437829478, 1.2037228884260156, 0.900369059187269, 0.6898587137788914, 0.4949097822874533, 0.2959061887131868, 0.15712393618216908] + + # mean evo 7B + 13B + #v = [1.5993337549066253, 1.1965624035328402, 0.9000864380418481, 0.6925840978034195, 0.5011181210961458, 0.32040328389777434, 0.13570386022711237] + + # theoretically optiomal (0.93333) + # v = [1.501085946044025, 1.1331700302595604, 0.8761428492468408, 0.6670160135425023, 0.48373855304610314, 0.3155014472579608, 0.15580024666388428] # 0.9333333333333333 + + + + if seed > 0: + v = np.array(v) + np.random.seed(seed) + v += np.random.randn(7)*scale + print(v.tolist()) + #v[0] += (np.random.randn(1)*0.001)[0] + #v[-1] += (np.random.randn(1)*0.001)[0] + #print(v[0], v[-1]) + v = v.tolist() + values = v + [0]*(256-14) + \ + v[::-1] + + values = torch.Tensor(values) + values[0:7] *= -1 + values = values.sort().values + values /= values.max() + assert values.numel() == 256 + return values def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8): e = exponent_bits @@ -168,7 +234,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) values = [] lst = list(itertools.product([0, 1], repeat=precision_bits)) #for ev in evalues: - bias = 2**(exponent_bits-1)+1 + bias = 2**(exponent_bits-1)-1 for evalue in range(2**(exponent_bits)): for bit_pattern in lst: value = (1 if evalue != 0 else 0) @@ -176,10 +242,10 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) value += 
pval*(2**-(i+1)) if evalue == 0: # subnormals - value = value*2**-(bias) + value = value*2**-(bias-1) else: # normals - value = value*2**-(evalue-bias-1) + value = value*2**-(evalue-bias-2) values.append(value) if signed: values.append(-value) @@ -502,7 +568,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra out = torch.zeros_like(A, dtype=torch.uint8) if A.device.type != 'cpu': - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32] cblocksize = ct.c_int32(blocksize) prev_device = pre_call(A.device) code = code.to(A.device) @@ -585,7 +651,7 @@ def dequantize_blockwise( if A.device.type != 'cpu': device = pre_call(A.device) code = code.to(A.device) - if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: + if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64, 32]: raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: diff --git a/csrc/kernels.cu b/csrc/kernels.cu index a2691be7f..8f331616c 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2953,6 +2953,8 @@ template __global__ void kQuantizeBlockwise(float * code, ha template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); @@ -2968,8 +2970,6 @@ template __global__ void kQuantizeBlockwise(float * code, ha template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -//template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -//template __global__ void kQuantizeBlockwise(float * code, float * 
__restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); diff --git a/csrc/ops.cu b/csrc/ops.cu index 07ef85074..8044c6671 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -71,8 +71,8 @@ template void quantizeBlockwise(float * co kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 64) kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); - //else if(blocksize == 32) - //kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 32 and FP4 == 0) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); diff --git a/tests/test_functional.py b/tests/test_functional.py index 54cecca51..cd4728e1d 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -152,7 +152,7 @@ def test_dynamic_quantization(): def test_dynamic_blockwise_quantization(): #print('') - for blocksize in [4096, 2048, 1024, 512, 256, 128, 64]: + for blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]: diffs = [] reldiffs = [] for i in range(100): @@ -167,8 +167,8 @@ def test_dynamic_blockwise_quantization(): relerr = sum(reldiffs)/len(reldiffs) assert abserr < 0.011 assert relerr < 0.018 - #print('randn', blocksize, sum(diffs)/len(diffs)) - #print('randn', blocksize, sum(reldiffs)/len(reldiffs)) + print('randn', blocksize, sum(diffs)/len(diffs)) + print('randn', blocksize, sum(reldiffs)/len(reldiffs)) diffs = [] for i in range(100): @@ -184,8 +184,8 @@ def test_dynamic_blockwise_quantization(): relerr = sum(reldiffs)/len(reldiffs) assert abserr < 0.0035 assert relerr < 0.015 - #print('rand', blocksize, sum(diffs)/len(diffs)) - #print('rand', blocksize, sum(reldiffs)/len(reldiffs)) + print('rand', blocksize, sum(diffs)/len(diffs)) + print('rand', blocksize, sum(reldiffs)/len(reldiffs)) def test_dynamic_blockwise_stochastic_quantization(): From c4cfe4fbdd70088c2ff0db1ae81bfe01c35fd2ae Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sat, 1 Apr 2023 10:33:03 -0700 Subject: [PATCH 13/63] Added bf16 Adam. 
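
This patch wires a bfloat16 path into the blockwise 8-bit Adam optimizer: the CUDA kernel is instantiated for __nv_bfloat16, the C entry points move from the fp##gbits naming to a dtype suffix (fp16/fp32/bf16), the Python dispatch in optimizer_update_8bit_blockwise selects the third table entry when gradients are bf16, and test_optimizer8bit gains a bf16 case with slightly looser tolerances. Only Adam is instantiated for bf16, so the other 8-bit blockwise optimizers still raise for bfloat16 gradients.

As a rough illustration only (not part of this diff), the new path would be exercised along these lines; the bnb.optim.Adam8bit wrapper name and its default blockwise behaviour are assumptions about the surrounding Python API rather than something this patch adds:

    # Hedged sketch: assumes bnb.optim.Adam8bit defaults to blockwise 8-bit
    # state and therefore reaches the new bf16 kernel for bf16 parameters.
    import torch
    import bitsandbytes as bnb

    p = torch.nn.Parameter(torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16))
    opt = bnb.optim.Adam8bit([p], lr=1e-3)

    p.grad = torch.randn_like(p)  # bf16 gradient
    opt.step()                    # 8-bit state update via the bf16 blockwise kernel
    opt.zero_grad()
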
--- Makefile | 7 ++-- bitsandbytes/functional.py | 68 +++++++++++++++++--------------------- csrc/kernels.cu | 2 ++ csrc/ops.cu | 2 ++ csrc/pythonInterface.c | 43 ++++++++++++------------ tests/test_optim.py | 43 ++++++++++-------------- 6 files changed, 78 insertions(+), 87 deletions(-) diff --git a/Makefile b/Makefile index 7bee7ef36..e1141607f 100644 --- a/Makefile +++ b/Makefile @@ -12,6 +12,7 @@ CUDA_VERSION:= endif + NVCC := $(CUDA_HOME)/bin/nvcc ########################################### @@ -59,9 +60,9 @@ CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89 CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90 -all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env - $(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) - $(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o +all: $(BUILD_DIR) env + $(NVCC) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 969250a50..8bfd66860 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -73,6 +73,7 @@ def prod(iterable): str2optimizer8bit_blockwise["adam"] = ( lib.cadam_8bit_blockwise_fp32, lib.cadam_8bit_blockwise_fp16, + lib.cadam_8bit_blockwise_bf16, ) str2optimizer8bit_blockwise["momentum"] = ( lib.cmomentum_8bit_blockwise_fp32, @@ -1125,51 +1126,42 @@ def optimizer_update_8bit_blockwise( skip_zeros=False, ) -> None: + optim_func = None if g.dtype == torch.float32 and state1.dtype == torch.uint8: - str2optimizer8bit_blockwise[optimizer_name][0]( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(absmax1), - get_ptr(absmax2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) + optimizer_func = str2optimizer8bit_blockwise[optimizer_name][0] elif g.dtype == torch.float16 and state1.dtype == torch.uint8: - str2optimizer8bit_blockwise[optimizer_name][1]( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(absmax1), - get_ptr(absmax2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) + optimizer_func = str2optimizer8bit_blockwise[optimizer_name][1] + elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and + len(str2optimizer8bit_blockwise[optimizer_name])==3): + optimizer_func = str2optimizer8bit_blockwise[optimizer_name][2] else: raise ValueError( f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" ) + is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2]) + + prev_device = pre_call(g.device) + optimizer_func( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), 
+ ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(absmax1), + get_ptr(absmax2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) + post_call(prev_device) def percentile_clipping( grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5 diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 8f331616c..e7e57d7a9 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2988,6 +2988,8 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise( \ diff --git a/csrc/ops.cu b/csrc/ops.cu index 8044c6671..a5a23b5ec 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -741,3 +741,5 @@ MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); + +MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 6a4bb0d96..a485a098c 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -57,19 +57,20 @@ MAKE_FUNC8(rmsprop, RMSPROP, float, 32) MAKE_FUNC8(rmsprop, RMSPROP, half, 16) #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \ -void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \ +void fname##_8bit_blockwise_##gbits(gtype* p, gtype* g, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\ { optimizerStatic8bitBlockwise(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\ -MAKE_BLOCKWISE8(adam, ADAM, half, 16) -MAKE_BLOCKWISE8(adam, ADAM, float, 32) -MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16) -MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32) -MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16) -MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32) -MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16) -MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32) +MAKE_BLOCKWISE8(adam, ADAM, half, fp16) +MAKE_BLOCKWISE8(adam, ADAM, float, fp32) +MAKE_BLOCKWISE8(momentum, MOMENTUM, half, fp16) +MAKE_BLOCKWISE8(momentum, MOMENTUM, float, fp32) +MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16) +MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32) +MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16) +MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32) +MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } @@ -194,20 +195,20 @@ extern "C" MAKE_CFUNC8(rmsprop, half, 16) #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \ - void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \ + void c##fname##_8bit_blockwise_##gbits(gtype* p, gtype* g, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \ - { fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \ - - MAKE_CBLOCKWISE8(adam, ADAM, half, 16) - MAKE_CBLOCKWISE8(adam, ADAM, float, 32) - 
MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, 16) - MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32) - MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16) - MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32) - MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16) - MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32) - + { fname##_8bit_blockwise_##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \ + + MAKE_CBLOCKWISE8(adam, ADAM, half, fp16) + MAKE_CBLOCKWISE8(adam, ADAM, float, fp32) + MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16) + MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32) + MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16) + MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32) + MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16) + MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32) + MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); } void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); } diff --git a/tests/test_optim.py b/tests/test_optim.py index 3df2dada2..92e3ed260 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -26,6 +26,8 @@ def get_temp_dir(): def rm_path(path): shutil.rmtree(path) +str2bf16support = {} +str2bf16support['adam8bit_blockwise'] = True str2optimizers = {} str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam) @@ -238,7 +240,7 @@ def test_global_config(dim1, dim2, gtype): dim1 = [1024] dim2 = [32, 1024, 4097] -gtype = [torch.float32, torch.float16] +gtype = [torch.float32, torch.float16, torch.bfloat16] optimizer_names = [ "adam8bit", "momentum8bit", @@ -256,6 +258,7 @@ def test_global_config(dim1, dim2, gtype): @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) def test_optimizer8bit(dim1, dim2, gtype, optim_name): + if gtype == torch.bfloat16 and optim_name not in str2bf16support: return if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 @@ -269,7 +272,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): if gtype == torch.float32: atol, rtol = 3e-3, 1e-3 patol, prtol = 1e-5, 1e-3 - + elif gtype == torch.bfloat16: + atol, rtol = 3e-3, 1e-3 + patol, prtol = 1e-4, 1e-2 else: atol, rtol = 3e-3, 1e-3 patol, prtol = 1e-5, 1e-3 @@ -314,8 +319,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): err = torch.abs(p1 - p2) relerr = err / torch.abs(p1) - assert err.mean() < 0.0001 - assert relerr.mean() < 0.001 + if g.dtype == torch.bfloat16: + assert err.mean() < 0.00015 + assert relerr.mean() < 0.0015 + else: + assert err.mean() < 0.0001 + assert relerr.mean() < 0.001 errors.append(err.mean().item()) relerrors.append(relerr.mean().item()) @@ -335,12 +344,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): bnb_optimizer = str2optimizers[optim_name][1]([p2]) bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) rm_path(path) - torch.testing.assert_allclose( - raws1cpy, bnb_optimizer.state[p2][name2] - ) - torch.testing.assert_allclose( - qmap1, bnb_optimizer.state[p2][qmap] - ) + torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2]) + torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap]) if "blockwise" in optim_name: s1 = F.dequantize_blockwise( @@ -357,28 +362,16 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ) 
torch.testing.assert_allclose(s1cpy, s1) - num_not_close = ( - torch.isclose( - torch_optimizer.state[p1][name1], - s1, - atol=atol, - rtol=rtol, - ) - == 0 - ) + num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0) assert num_not_close.sum().item() < 20 - torch.testing.assert_allclose( - p1, p2.float(), atol=patol, rtol=prtol - ) + torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol) # the parameters diverge quickly. Here we keep them close # together so we can test against the Adam error p1.data = p1.data.to(gtype).float() p2.copy_(p1.data) torch.testing.assert_allclose(p1.to(gtype), p2) - for (name1, name2, qmap, max_val), s in zip( - str2statenames[optim_name], dequant_states - ): + for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): torch_optimizer.state[p1][name1].copy_(s.data) # print(sum(errors)/len(errors)) From 51a21df7288a7e2f78c10778493f9ba554694e81 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sat, 1 Apr 2023 16:10:18 -0700 Subject: [PATCH 14/63] Added 8-bit compression to quantization statistics. --- bitsandbytes/functional.py | 38 +++++++++++++++++++--------- bitsandbytes/nn/modules.py | 10 +++++--- tests/test_autograd.py | 13 +++++----- tests/test_functional.py | 52 +++++++++++++++++++++++++++++++++++--- tests/test_modules.py | 2 +- 5 files changed, 88 insertions(+), 27 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 8bfd66860..8234c46bb 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -155,7 +155,7 @@ def create_linear_map(signed=True, total_bits=8, add_zero=True): #return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist()) return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) -def custom_map(seed=0, scale=0.01): +def create_custom_map(seed=0, scale=0.01): v = [12, 10, 8, 6, 3, 2, 1] # 16-bit 7B 22.33, 4-bit best 22.88, FP4 23.25, 4-bit 95 22.97, 4-bit evo 22.45 # 16-bit 13B 70.35, 4-bit best 67.16, FP4 100.78, 4-bit-95 69.39, 4-bit evo 70.48 @@ -191,13 +191,13 @@ def custom_map(seed=0, scale=0.01): # 13B evo start #v = [1.6077535089716468, 1.1914902148179205, 0.8999752421085561, 0.6967904489387543, 0.4949093928311768, 0.30920472033044544, 0.15391602735952042] #v = [1.586363722436466, 1.202610827188916, 0.9003332576346587, 0.6904888715206972, 0.49490974688233724, 0.2971151461329376, 0.15683230810738283] - v = [1.5842247437829478, 1.2037228884260156, 0.900369059187269, 0.6898587137788914, 0.4949097822874533, 0.2959061887131868, 0.15712393618216908] + #v = [1.5842247437829478, 1.2037228884260156, 0.900369059187269, 0.6898587137788914, 0.4949097822874533, 0.2959061887131868, 0.15712393618216908] # mean evo 7B + 13B #v = [1.5993337549066253, 1.1965624035328402, 0.9000864380418481, 0.6925840978034195, 0.5011181210961458, 0.32040328389777434, 0.13570386022711237] # theoretically optiomal (0.93333) - # v = [1.501085946044025, 1.1331700302595604, 0.8761428492468408, 0.6670160135425023, 0.48373855304610314, 0.3155014472579608, 0.15580024666388428] # 0.9333333333333333 + v = [1.501085946044025, 1.1331700302595604, 0.8761428492468408, 0.6670160135425023, 0.48373855304610314, 0.3155014472579608, 0.15580024666388428] # 0.9333333333333333 @@ -599,7 +599,9 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra assert rand is None lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), 
ct.c_longlong(blocksize), ct.c_longlong(A.numel())) - return out, (absmax, code) + state = (absmax, code, blocksize) + + return out, state def dequantize_blockwise( @@ -644,9 +646,9 @@ def dequantize_blockwise( if out is None: out = torch.zeros_like(A, dtype=torch.float32) if quant_state is None: - quant_state = (absmax, code) + quant_state = (absmax, code, blocksize) else: - absmax, code = quant_state + absmax, code, blocksize = quant_state if A.device.type != 'cpu': @@ -669,7 +671,7 @@ def dequantize_blockwise( return out -def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64) -> Tensor: +def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False) -> Tensor: """ Quantize tensor A in blocks of FP4 values. @@ -704,12 +706,11 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize blocks += 1 if n % blocksize > 0 else 0 absmax = torch.zeros((blocks,), device=A.device) - state = (absmax, input_shape, A.dtype, blocksize) if out is None: out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device) - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32] prev_device = pre_call(A.device) is_on_gpu([A, out, absmax]) @@ -722,6 +723,17 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) + if compress_statistics: + offset = absmax.mean() + absmax -= offset + #code = create_custom_map().to(absmax.device) + #qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256) + qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) + del absmax + state = (qabsmax, input_shape, A.dtype, blocksize, (offset, state2)) + else: + state = (absmax, input_shape, A.dtype, blocksize, None) + return out, state @@ -756,8 +768,12 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: shape = out.shape dtype = out.dtype else: - absmax, shape, dtype, blocksize = quant_state + absmax, shape, dtype, blocksize, compressed_stats = quant_state + if compressed_stats is not None: + offset, state2 = compressed_stats + absmax = dequantize_blockwise(absmax, state2) + absmax += offset if out is None: out = torch.empty(shape, dtype=dtype, device=A.device) @@ -1986,8 +2002,6 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): ccolsB = ct.c_int32(B.shape[1]) cldb = ct.c_int32(ldb) cldc = ct.c_int32(ldc) - # print(cooA.rowidx[:64]) - # print(cooA.colidx[:64].sort()[0]) is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats]) if B.dtype == torch.float16: diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index a550ec1f4..45eef4256 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -134,15 +134,17 @@ def forward(self, input: Tensor) -> Tensor: return emb class FP4Params(torch.nn.Parameter): - def __new__(cls, data=None, requires_grad=True, quant_state=None): + def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True): cls.quant_state = None + cls.blocksize = blocksize + cls.compress_statistics = compress_statistics if data is None: data = torch.empty(0) return torch.Tensor._make_subclass(cls, data, requires_grad) def cuda(self, device): w = self.data.contiguous().half().cuda(device) - w_fp4, quant_state = bnb.functional.quantize_fp4(w) + w_fp4, 
quant_state = bnb.functional.quantize_fp4(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics) self.data = w_fp4 self.quant_state = quant_state @@ -173,10 +175,10 @@ def to(self, *args, **kwargs): class LinearFP4(nn.Linear): - def __init__(self, input_features, output_features, bias=True, compute_dtype=None): + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True): super().__init__(input_features, output_features, bias) self.state = bnb.MatmulLtState() - self.weight = FP4Params(self.weight.data, requires_grad=False) + self.weight = FP4Params(self.weight.data, requires_grad=False, compress_statistics=compress_statistics) self.compute_dtype = compute_dtype def init_8bit_state(self): diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 436c6b126..4356c1d2a 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -454,14 +454,15 @@ def test_matmullt( transpose = [(False, True), (False, False)] str_transpose = ["NT", "NN"] dtype = [torch.float16, torch.float32] +compress_statistics = [False, True] has_fp16_weights = [True, False] has_bias = [True, False] -values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias)) -str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias)) -names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}".format(*vals) for vals in str_values] +values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics)) +str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics)) +names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics".format(*vals) for vals in str_values] @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") -@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias", values, ids=names) -def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias): +@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics", values, ids=names) +def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) if has_bias == False: @@ -481,7 +482,7 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, bias2 = bias.clone() torch.nn.init.xavier_uniform_(B) - B2, quant_state = bnb.functional.quantize_fp4(B) + B2, quant_state = bnb.functional.quantize_fp4(B, compress_statistics=compress_statistics) if not transpose[0] and transpose[1]: out_torch = funcs[0](A, B.t()) diff --git a/tests/test_functional.py b/tests/test_functional.py index cd4728e1d..a97470153 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -167,8 +167,8 @@ def test_dynamic_blockwise_quantization(): relerr = sum(reldiffs)/len(reldiffs) assert abserr < 0.011 assert relerr < 0.018 - print('randn', blocksize, sum(diffs)/len(diffs)) - print('randn', blocksize, sum(reldiffs)/len(reldiffs)) + #print('randn', blocksize, sum(diffs)/len(diffs)) + #print('randn', blocksize, sum(reldiffs)/len(reldiffs)) diffs = [] for i in range(100): @@ -184,8 +184,8 @@ def 
test_dynamic_blockwise_quantization(): relerr = sum(reldiffs)/len(reldiffs) assert abserr < 0.0035 assert relerr < 0.015 - print('rand', blocksize, sum(diffs)/len(diffs)) - print('rand', blocksize, sum(reldiffs)/len(reldiffs)) + #print('rand', blocksize, sum(diffs)/len(diffs)) + #print('rand', blocksize, sum(reldiffs)/len(reldiffs)) def test_dynamic_blockwise_stochastic_quantization(): @@ -1806,6 +1806,7 @@ def test_bench_matmul(batch, seq, model, hidden): torch.nn.init.xavier_uniform_(B) B_fp4, state = F.quantize_fp4(B) + B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True) linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() linear8bit.eval() @@ -1839,6 +1840,13 @@ def test_bench_matmul(batch, seq, model, hidden): torch.cuda.synchronize() print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + bnb.matmul_fp4(A, B_fp4.t(), quant_state=state_c) + torch.cuda.synchronize() + print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + #torch.cuda.synchronize() #t0 = time.time() #for i in range(iters): @@ -2244,6 +2252,42 @@ def test_fp4_quant(): assert relerr.item() < 0.28 +def test_fp4_compressed_stats(): + for blocksize in [128, 64]: + errs1 = [] + errs2 = [] + for i in range(10): + A1 = torch.randn(1024, 1024, device='cuda').half() + q2, SA2 = F.quantize_fp4(A1, blocksize=blocksize) + q3, SA3= F.quantize_fp4(A1, blocksize=blocksize, compress_statistics=True) + A2 = F.dequantize_fp4(q2, SA2) + A3 = F.dequantize_fp4(q3, SA3) + + + err = (A1 - A2).abs().float() + relerr = (err/(A1.abs().float()+1e-15)).mean() + err = err.mean() + + errs1.append(err.item()) + + assert err.item() < 0.11 + assert relerr.item() < 0.28 + + err = (A1 - A3).abs().float() + relerr = (err/(A1.abs().float()+1e-15)).mean() + err = err.mean() + + errs2.append(err.item()) + + assert err.item() < 0.11 + assert relerr.item() < 0.28 + + #print(sum(errs1)/len(errs1), blocksize) + #print(sum(errs2)/len(errs2), blocksize) + + + + def test_bench_fp4_dequant(): blocksize = 256 a = torch.rand(1024*12*4, 1024*12, device='cuda').half() diff --git a/tests/test_modules.py b/tests/test_modules.py index 41cc050f2..d0f5ca2c6 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -507,7 +507,7 @@ def test_linear_kbit_fp32_bias(module): assert l1.bias is None @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") -@pytest.mark.parametrize("module", [bnb.nn.Linear8bitLt, bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4']) +@pytest.mark.parametrize("module", [bnb.nn.Linear8bitLt, bnb.nn.LinearFP4, lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True)], ids=['Int8Lt', 'FP4', 'FP4+C']) def test_kbit_backprop(module): b = 17 dim1 = 37 From 2dd5d69056e3b94f0462dd9ce6aaff7a89294d23 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 2 Apr 2023 12:42:01 -0700 Subject: [PATCH 15/63] Generalized FP4 data type. --- csrc/kernels.cu | 144 ++++++++++++++++++++++----------------- tests/test_functional.py | 10 +-- 2 files changed, 88 insertions(+), 66 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index e7e57d7a9..2e61297dd 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -64,6 +64,33 @@ __device__ float dDequantizeFP4(unsigned char val, float absmax) } } +__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) +{ + float sign = (val & 0b1000) == 8 ? 
-1.0f : 1.0f; + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 111 + return 0.25000000f*absmax*sign; // 1111 + else + return 0.16666667f*absmax*sign; // 1110 + else + if((val & 0b0001) == 1) // 110 + return 0.50000000f*absmax*sign; // 1101 + else + return 0.33333333f*absmax*sign; // 1100 + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 1.00000000f*absmax*sign; // 1011 + else + return 0.66666667f*absmax*sign; // 1010 + else + if((val & 0b0001) == 1) // 100 + return 5.208333333e-03f*absmax*sign; // 1001 + else + return 0.00000000f*absmax*sign; // 1000 +} + __device__ unsigned char dQuantizeFP4(float x) { // FP4 with bias of 3 @@ -78,42 +105,79 @@ __device__ unsigned char dQuantizeFP4(float x) // 0b010 = 8 // 0b011 = 12 + + // we do a binary search + // the pivots are divided by 12 (the FP4 absmax) + // since we assum input data is in [-1.0, 1.0] + + // !be careful here, its easy to make a mistake + // that is difficult to noice if you add an extra + // zero somewhere! + + int sign = x < 0 ? 0b1000 : 0b0000; + x = fabsf(x); + if(x > 0.29166667f) + if( x > 0.583333f) + if( x > 0.8333333f) + return 0b0011+sign; + else + return 0b0010+sign; + else + if(x > 0.4166667f) + return 0b101+sign; + else + return 0b100+sign; + else + if(x > 0.0859375f) + if(x > 0.20833333f) + return 0b0111+sign; + else + return 0b0110+sign; + else + if(x > 0.00260417f) + return 0b0001+sign; + else + return 0b0000+sign; +} + +__device__ unsigned char dQuantizeNormal(float x) +{ + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + int sign = x < 0 ? 0b1000 : 0b0000; x = fabsf(x); if(x > 3.5f) - { if( x > 7.0f) - { if( x > 10.0f) return 0b0011+sign; else return 0b0010+sign; - } else - { if(x > 5.0f) return 0b101+sign; else return 0b100+sign; - } - } else - { if(x > 1.03125f) - { if(x > 2.5f) return 0b0111+sign; else return 0b0110+sign; - } else - { if(x > 0.03125f) return 0b0001+sign; else return 0b0000+sign; - } - } } template @@ -575,8 +639,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float for(int j = 0; j < NUM_PER_TH/2; j++) { unsigned char packed_fp4 = 0; - packed_fp4 |= dQuantizeFP4(((float)vals[2*j])*local_abs_max*12.0f) << 4; - packed_fp4 |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max*12.0f); + packed_fp4 |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4; + packed_fp4 |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max); qvals[j] = packed_fp4; } } @@ -639,8 +703,10 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH; j++) { - vals[j*2] = dDequantizeFP4(qvals[j] >> 4, local_abs_max*0.083333f); - vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*0.083333); + //vals[j*2] = dDequantizeFP4(qvals[j] >> 4, local_abs_max*0.083333f); + //vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*0.083333); + vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); + vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); } } else @@ -656,52 +722,6 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs } } -//template -//__global__ void kDequantizeBlockwiseFP4(unsigned char * A, float * absmax, T *out, const int n_store) -//{ -// -// const int n_load = n_store/2; -// const int base_idx = (blockIdx.x * TILE_SIZE); -// -// T 
vals[NUM_PER_TH*2]; -// unsigned char qvals[NUM_PER_TH]; -// -// int valid_items = (base_idx + TILE_SIZE) > n_load ? ((base_idx+TILE_SIZE) - n_load) : TILE_SIZE; -// int idx = base_idx + (threadIdx.x*NUM_PER_TH); -// -// float local_abs_max = __ldg(&absmax[idx/BLOCK_SIZE]); -// -// if(valid_items == TILE_SIZE) -// { -// // we do 64 byte loads so we can 128 byte stores -// reinterpret_cast(qvals)[0] = reinterpret_cast(A)[idx/8]; -// } -// else -// { -// #pragma unroll -// for(int j = 0; j < NUM_PER_TH; j++) -// if(idx+j < n_load) -// qvals[j] = A[idx+j]; -// else -// qvals[j] = 0; -// } -// -// -// #pragma unroll NUM_PER_TH -// for(int j = 0; j < NUM_PER_TH; j++) -// { -// vals[j*2] = dDequantizeFP4(qvals[j] & 0xF0, local_abs_max*12.0f); -// vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*12.0f); -// } -// -// -// reinterpret_cast(qvals)[0] = reinterpret_cast(A)[idx/8]; -// reinterpret_cast(A)[idx/16] = reinterpret_cast(local_valC)[j/num_items]; -// -// -//} - - __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n) { const unsigned int numThreads = blockDim.x * gridDim.x; diff --git a/tests/test_functional.py b/tests/test_functional.py index a97470153..12411e31f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2246,8 +2246,10 @@ def test_fp4_quant(): err = (A1 - A2).abs().float() relerr = (err/A1.abs().float()).mean() + idx = err > 1.0 err = err.mean() + assert err.item() < 0.1 assert relerr.item() < 0.28 @@ -2256,7 +2258,7 @@ def test_fp4_compressed_stats(): for blocksize in [128, 64]: errs1 = [] errs2 = [] - for i in range(10): + for i in range(10000): A1 = torch.randn(1024, 1024, device='cuda').half() q2, SA2 = F.quantize_fp4(A1, blocksize=blocksize) q3, SA3= F.quantize_fp4(A1, blocksize=blocksize, compress_statistics=True) @@ -2268,7 +2270,7 @@ def test_fp4_compressed_stats(): relerr = (err/(A1.abs().float()+1e-15)).mean() err = err.mean() - errs1.append(err.item()) + errs1.append(relerr.item()) assert err.item() < 0.11 assert relerr.item() < 0.28 @@ -2277,7 +2279,7 @@ def test_fp4_compressed_stats(): relerr = (err/(A1.abs().float()+1e-15)).mean() err = err.mean() - errs2.append(err.item()) + errs2.append(relerr.item()) assert err.item() < 0.11 assert relerr.item() < 0.28 @@ -2301,7 +2303,7 @@ def test_bench_fp4_dequant(): #print(max_theoretical_s*1e6) b = torch.randn(128, 1024*12, device='cuda').half() - iters = 5 + iters = 500 torch.cuda.synchronize() t0 = time.time() for i in range(iters): From 0d332a641ff6b28e71b2a9ab5e641f8cf4a2ec99 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 2 Apr 2023 14:09:08 -0700 Subject: [PATCH 16/63] Added normal with extra value. 
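This patch adds create_normal_map(): a 16-value "normal float" code built from Gaussian quantiles, with one extra positive value so the data type is asymmetric (8 positive and 7 negative non-zero codes plus zero). The list is padded with zeros to 256 entries so it can flow through the existing 8-bit blockwise machinery. Below is a minimal standalone sketch of the same construction, not the committed code; it assumes scipy and torch are installed and mirrors the defaults used in this patch.

import torch
from scipy.stats import norm

def normal_map_sketch(offset=0.966666, use_extra_value=True):
    # quantiles of the standard normal between `offset` and 0.5, dropping ppf(0.5) == 0
    if use_extra_value:
        pos = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist()     # 8 positive codes
        neg = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()  # 7 negative codes
        pad = [0] * (256 - 15)  # 15 non-zero values in this data type
    else:
        pos = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist()
        neg = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
        pad = [0] * (256 - 14)
    values = torch.Tensor(pos + pad + neg).sort().values
    return values / values.max()  # normalize so the largest-magnitude codes become +/-1.0

The resulting map has exactly 16 distinct values (zero plus 15 non-zero codes, with the endpoints normalized to -1.0 and 1.0), which is what the 4-bit kernels consume.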
--- bitsandbytes/functional.py | 28 +++++++++++++++++++++++----- tests/test_functional.py | 3 --- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 8234c46bb..161f58f2f 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -9,7 +9,7 @@ import torch import itertools import math -import scipy.stats +from scipy.stats import norm import numpy as np from functools import reduce # Required in Python 3 @@ -181,7 +181,7 @@ def create_custom_map(seed=0, scale=0.01): #v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.207 #v = [1.6118251211466303, 1.188665228776879, 0.9112895004060624, 0.690763326564427, 0.4997008778346997, 0.3254280317127771, 0.16057446047146948] # 0.9465 24.30 #v = [1.6027040905517569, 1.184321770169049, 0.9085808314549837, 0.6889461706317986, 0.4984841229538408, 0.32467299997597887, 0.1602117348657326] # 0.9455 24.293 - #v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.37 22.88 + v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.37 22.88 # 7B evo start #v = [1.62129629, 1.18870191, 0.90848106, 0.69108646, 0.50515268, 0.34927819905, 0.14122701] # 22.06 @@ -197,9 +197,7 @@ def create_custom_map(seed=0, scale=0.01): #v = [1.5993337549066253, 1.1965624035328402, 0.9000864380418481, 0.6925840978034195, 0.5011181210961458, 0.32040328389777434, 0.13570386022711237] # theoretically optiomal (0.93333) - v = [1.501085946044025, 1.1331700302595604, 0.8761428492468408, 0.6670160135425023, 0.48373855304610314, 0.3155014472579608, 0.15580024666388428] # 0.9333333333333333 - - + #v = [1.501085946044025, 1.1331700302595604, 0.8761428492468408, 0.6670160135425023, 0.48373855304610314, 0.3155014472579608, 0.15580024666388428] # 0.9333333333333333 if seed > 0: v = np.array(v) @@ -220,6 +218,26 @@ def create_custom_map(seed=0, scale=0.01): assert values.numel() == 256 return values +def create_normal_map(offset=0.966666, use_extra_value=True): + + if use_extra_value: + # one more positive value, this is an asymmetric type + v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist() + v2 = [0]*(256-15) ## we have 15 non-zero values in this data type + v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() + v = v1 + v2 + v3 + else: + v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist() + v2 = [0]*(256-14) ## we have 14 non-zero values in this data type + v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() + v = v1 + v2 + v3 + + values = torch.Tensor(v) + values = values.sort().values + values /= values.max() + assert values.numel() == 256 + return values + def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8): e = exponent_bits p = precision_bits diff --git a/tests/test_functional.py b/tests/test_functional.py index 12411e31f..47a30a607 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2318,6 +2318,3 @@ def test_bench_fp4_dequant(): # torch.matmul(b, a.t()) #torch.cuda.synchronize() #print((time.time()-t0)/iters*1e6) - - - From 4ad999d1440e896abec3f3c7029f292ce46cc820 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 2 Apr 2023 14:42:45 -0700 Subject: [PATCH 17/63] Added quantization tree generation. 
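This patch records how the hard-coded comparison trees in csrc/kernels.cu are generated: the new test walks the 16 sorted code values level by level and prints the midpoints between neighbouring codes, which become the pivots of a 4-level binary search (so encoding one value in dQuantizeNormal costs four comparisons, and dDequantizeNF4 walks the same tree bit by bit). A small sketch of the pivot derivation follows; it is not the committed test (which also prints the index lists) and assumes `values` is a plain Python list of the 16 distinct code values, i.e. the first and last eight entries of the 256-entry map, sorted ascending.

def tree_pivots(values):
    assert len(values) == 16 and values == sorted(values)
    levels = []
    num_pivots = 1
    while num_pivots < 16:
        step = 16 // num_pivots
        idx = range(step // 2, 16, step)  # split indices at this tree depth
        levels.append([(values[i - 1] + values[i]) / 2 for i in idx])
        num_pivots *= 2
    return levels  # levels[0] holds the root pivot, levels[3] the eight leaf-level pivots

The printed pivots are the float literals that ended up in dQuantizeNormal(); the dequantization side simply returns the code value selected by the same bit path.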
--- bitsandbytes/functional.py | 2 +- tests/test_functional.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 161f58f2f..5198526cf 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -218,7 +218,7 @@ def create_custom_map(seed=0, scale=0.01): assert values.numel() == 256 return values -def create_normal_map(offset=0.966666, use_extra_value=True): +def create_normal_map(offset=0.9677083, use_extra_value=True): if use_extra_value: # one more positive value, this is an asymmetric type diff --git a/tests/test_functional.py b/tests/test_functional.py index 47a30a607..074135e18 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2318,3 +2318,19 @@ def test_bench_fp4_dequant(): # torch.matmul(b, a.t()) #torch.cuda.synchronize() #print((time.time()-t0)/iters*1e6) + + + +def test_normal_map_tree(): + code = F.create_normal_map() + values =code[:8].tolist() + code[-8:].tolist() + num_pivots = 1 + while num_pivots <16: + idx = list(range(16//num_pivots//2, 16, 16//num_pivots)) + print(idx) + num_pivots *= 2 + pivots = [] + for i in idx: + pivots.append((values[i-1]+values[i])/2) + print(pivots) + From 64cc05920d0e506e41e814b9ef6053923d967a95 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 2 Apr 2023 16:10:35 -0700 Subject: [PATCH 18/63] First draft of NF4. --- bitsandbytes/functional.py | 44 +++++- csrc/kernels.cu | 275 +++++++++++++++++++++++++------------ csrc/kernels.cuh | 4 +- csrc/ops.cu | 50 +++---- csrc/ops.cuh | 11 +- csrc/pythonInterface.c | 30 ++-- tests/test_functional.py | 23 ++-- 7 files changed, 292 insertions(+), 145 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 5198526cf..83c26054e 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -688,8 +688,13 @@ def dequantize_blockwise( return out +def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): + return quantize_4bit_packed(A, absmax, out, blocksize, compress_statistics, 'fp4') -def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False) -> Tensor: +def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): + return quantize_4bit_packed(A, absmax, out, blocksize, compress_statistics, 'nf4') + +def quantize_4bit_packed(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor: """ Quantize tensor A in blocks of FP4 values. @@ -705,6 +710,8 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize The output tensor (8-bit). blocksize : int The blocksize used in quantization. 
+ quant_type : str + The 4-bit quantization data type {fp4, nf4} Returns ------- @@ -715,6 +722,8 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize """ if A.device.type != 'cuda': raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}') + if quant_type not in ['fp4', 'nf4']: + raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') n = A.numel() input_shape = A.shape @@ -734,9 +743,15 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize is_on_gpu([A, out, absmax]) if A.dtype == torch.float32: - lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + if quant_type == 'fp4': + lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + else: + lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) elif A.dtype == torch.float16: - lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + if quant_type == 'fp4': + lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + else: + lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) @@ -754,8 +769,13 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize return out, state +def dequantize_fp4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: + return dequantize_4bit_packed(A, quant_state, absmax, out, blocksize, 'fp4') + +def dequantize_nf4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: + return dequantize_4bit_packed(A, quant_state, absmax, out, blocksize, 'nf4') -def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: +def dequantize_4bit_packed(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor: """ Dequantizes FP4 blockwise quantized values. @@ -771,6 +791,10 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: The absmax values. out : torch.Tensor Dequantized output tensor. + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4} Returns @@ -780,6 +804,8 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: """ if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: raise ValueError(f"The blockwise of {blocksize} is not supported. 
Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") + if quant_type not in ['fp4', 'nf4']: + raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') if quant_state is None: assert absmax is not None and out is not None @@ -802,9 +828,15 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: device = pre_call(A.device) is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: - lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + if quant_type == 'fp4': + lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + else: + lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) elif out.dtype == torch.float16: - lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + if quant_type == 'fp4': + lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + else: + lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 2e61297dd..0ed413f69 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -140,44 +140,111 @@ __device__ unsigned char dQuantizeFP4(float x) return 0b0000+sign; } +__device__ float dDequantizeNF4(unsigned char val, float absmax) +{ + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if((val & 0b1000) == 8) + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) // 11 + if((val & 0b0001) == 1) // 111 + return 1.0f*absmax; + else + return 0.7229568362236023f*absmax; + else + if((val & 0b0001) == 1) // 110 + return 0.5626170039176941f*absmax; + else + return 0.44070982933044434f*absmax; + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 0.33791524171829224f*absmax; + else + return 0.24611230194568634f*absmax; + else + if((val & 0b0001) == 1) // 100 + return 0.16093020141124725f*absmax; + else + return 0.07958029955625534f*absmax; + + else + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 011 + return 0.0f*absmax; + else + return -0.09105003625154495f*absmax; + else + if((val & 0b0001) == 1) // 010 + return -0.18477343022823334f*absmax; + else + return -0.28444138169288635f*absmax; + else + if((val & 0b0010) == 2) //00 + if((val & 0b0001) == 1) // 001 + return -0.39491748809814453f*absmax; + else + return -0.5250730514526367f*absmax; + else + if((val & 0b0001) == 1) // 000 + return -0.6961928009986877f*absmax; + else + return -1.0f*absmax; + +} + __device__ unsigned char dQuantizeNormal(float x) { - // FP4 with bias of 3 - // first bit is a sign - // subnormals - // 0b000 = 0 - // 0b001 = 0.0625 - // 0b110 = 2 - // 0b111 = 3 - // 0b100 = 4 - // 0b101 = 6 - // 0b010 = 8 - // 0b011 = 12 - int sign = x < 0 ? 
0b1000 : 0b0000; - x = fabsf(x); - if(x > 3.5f) - if( x > 7.0f) - if( x > 10.0f) - return 0b0011+sign; + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if(x > 0.03979014977812767f) + if(x > 0.3893125355243683f) // 1 + if(x > 0.6427869200706482f) // 11 + if(x > 0.8614784181118011f) // 111 + return 0b1111; + else + return 0b1110; else - return 0b0010+sign; + if(x > 0.5016634166240692f) // 110 + return 0b1101; + else + return 0b1100; else - if(x > 5.0f) - return 0b101+sign; + if(x > 0.2035212516784668f) // 10 + if(x > 0.2920137718319893f) // 101 + return 0b1011; + else + return 0b1010; else - return 0b100+sign; + if(x > 0.1202552504837513f) // 100 + return 0b1001; + else + return 0b1100; else - if(x > 1.03125f) - if(x > 2.5f) - return 0b0111+sign; + if(x > -0.33967943489551544f) // 0 + if(x > -0.13791173323988914f) // 01 + if(x > -0.045525018125772476f) // 011 + return 0b0111; + else + return 0b0110; else - return 0b0110+sign; + if(x > -0.23460740596055984f) // 010 + return 0b0101; + else + return 0b0100; else - if(x > 0.03125f) - return 0b0001+sign; + if(x > -0.6106329262256622f) // 00 + if(x > -0.4599952697753906f) // 001 + return 0b0011; + else + return 0b0010; else - return 0b0000+sign; + if(x > -0.8480964004993439f) // 000 + return 0b0001; + else + return 0b0000; } template @@ -564,7 +631,7 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c } } -template +template //__launch_bounds__(TH, 4) __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n) { @@ -574,13 +641,13 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float T vals[NUM_PER_TH]; float rand_vals[NUM_PER_TH]; - unsigned char qvals[FP4 ? NUM_PER_TH/2 : NUM_PER_TH]; + unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH]; //float local_abs_max = -FLT_MAX; float local_abs_max = 0.0f; int local_rand_idx = 0; typedef cub::BlockLoad LoadT; - typedef cub::BlockStore StoreChar; + typedef cub::BlockStore 0) ? 
NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar; typedef cub::BlockReduce BlockReduce; typedef cub::BlockLoad LoadFloat; @@ -591,7 +658,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float __shared__ float smem_code[256]; __shared__ float smem_absmax_value[1]; - if(!FP4) + if(DATA_TYPE == General8bit) for(int i = threadIdx.x; i < 256; i+=blockDim.x) smem_code[i] = code[i]; @@ -633,31 +700,41 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); } - if(FP4) + unsigned char packed_4bit = 0; + switch(DATA_TYPE) { - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH/2; j++) - { - unsigned char packed_fp4 = 0; - packed_fp4 |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4; - packed_fp4 |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max); - qvals[j] = packed_fp4; - } - } - else - { - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) - { - if(!STOCHASTIC) - qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); - else - qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); - } + case General8bit: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + if(!STOCHASTIC) + qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); + else + qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); + } + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeNormal(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeNormal(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; } __syncthreads(); - StoreChar(storec).Store(&(out[FP4 ? i/2 : i]), qvals, FP4 ? (valid_items+1)/2 : valid_items); + StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? 
(valid_items+1)/2 : valid_items); } } @@ -2957,44 +3034,60 @@ MAKE_optimizerStatic8bit2State(ADAM, float) template __global__ void kPercentileClipping(float * __restrict__ g, float *gnorm_vec, int step, const int n); template __global__ void kPercentileClipping(half * __restrict__ g, float *gnorm_vec, int step, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * 
__restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); - -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); - -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +#define 
MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ +template __global__ void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \ + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) + +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); #define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 23aad6c84..ed549cb3f 100644 --- 
a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -14,8 +14,8 @@ template__global__ void kEstimateQuantiles(T *__restrict__ const A, __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n); __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n); -template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n); +template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n); template __global__ void kPreconditionOptimizer32bit2State(T* g, T* p, diff --git a/csrc/ops.cu b/csrc/ops.cu index a5a23b5ec..de14039da 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -50,7 +50,7 @@ void dequantize(float *code, unsigned char *A, float *out, int n) CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) { int num_blocks = n/blocksize; num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; @@ -60,34 +60,32 @@ template void quantizeBlockwise(float * co if(blocksize == 4096) kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 2048) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 1024) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 512) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 256) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 128) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 64) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); - else if(blocksize == 32 and FP4 == 0) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) { int num_blocks = n/blocksize; num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; - int tile_size = FP4 ? 1024 : 512; + int tile_size = (DATA_TYPE > 0) ? 
1024 : 512; - if(FP4) - kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n); + if(DATA_TYPE > 0) + kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n); else - kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n); + kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -682,16 +680,20 @@ template void transformRowToFormat(char * A, char *out, int rows, template void estimateQuantiles(half *A, float *code, float offset, int n); template void estimateQuantiles(float *A, float *code, float offset, int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); 
+template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); #define MAKE_optimizer32bit(name, gtype) \ template void optimizer32bit(gtype* g, gtype* p, \ diff --git a/csrc/ops.cuh b/csrc/ops.cuh index b3e242419..f73d4e0c4 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -81,6 +81,13 @@ typedef enum Transform_t COL_AMPERE = 4, } Transform_t; +typedef enum DataType_t +{ + General8bit = 0, + FP4 = 1, + NF4 = 2, +} DataType_t; + class Context { public: @@ -128,8 +135,8 @@ template void estimateQuantiles(T *A, float *code, float offset, in void quantize(float *code, float *A, unsigned char *out, int n); void dequantize(float *code, unsigned char *A, float *out, int n); -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); template void optimizer32bit(T* g, T* p, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index a485a098c..d16917891 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -76,17 +76,21 @@ MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } -void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } -void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } -void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, 4096, n); } -void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, 4096, n); } -void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } -void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } - -void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, 
blocksize, n); } \ -void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } -void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } \ -void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } +void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, 4096, n); } +void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, 4096, n); } +void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } + +void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } \ +void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } +void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } \ +void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } +void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } \ +void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } #define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \ @@ -157,6 +161,10 @@ extern "C" void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, 
const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } #define MAKE_CFUNC32(name, gtype, gbits) \ void c##name##32bit_g##gbits(gtype *g, gtype *p, \ diff --git a/tests/test_functional.py b/tests/test_functional.py index 074135e18..98edb7c4f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2254,16 +2254,18 @@ def test_fp4_quant(): assert relerr.item() < 0.28 -def test_fp4_compressed_stats(): +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) +def test_4bit_compressed_stats(quant_type): for blocksize in [128, 64]: errs1 = [] errs2 = [] - for i in range(10000): + for i in range(10): A1 = torch.randn(1024, 1024, device='cuda').half() - q2, SA2 = F.quantize_fp4(A1, blocksize=blocksize) - q3, SA3= F.quantize_fp4(A1, blocksize=blocksize, compress_statistics=True) - A2 = F.dequantize_fp4(q2, SA2) - A3 = F.dequantize_fp4(q3, SA3) + q2, SA2 = F.quantize_4bit_packed(A1, blocksize=blocksize, quant_type=quant_type) + q3, SA3= F.quantize_4bit_packed(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) + A2 = F.dequantize_4bit_packed(q2, SA2, quant_type=quant_type) + A3 = F.dequantize_4bit_packed(q3, SA3, quant_type=quant_type) err = (A1 - A2).abs().float() @@ -2290,10 +2292,12 @@ def test_fp4_compressed_stats(): -def test_bench_fp4_dequant(): +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) +def test_bench_fp4_dequant(quant_type): blocksize = 256 a = torch.rand(1024*12*4, 1024*12, device='cuda').half() - qa, SA = F.quantize_fp4(a, blocksize=blocksize) + qa, SA = F.quantize_4bit_packed(a, blocksize=blocksize, quant_type=quant_type) input_size = a.numel()/2 output_size = a.numel()*2 @@ -2307,7 +2311,7 @@ def test_bench_fp4_dequant(): torch.cuda.synchronize() t0 = time.time() for i in range(iters): - F.dequantize_fp4(qa, SA, blocksize=blocksize) + F.dequantize_4bit_packed(qa, SA, blocksize=blocksize, quant_type=quant_type) #b.copy_(a) torch.cuda.synchronize() #print((time.time()-t0)/iters*1e6) @@ -2325,6 +2329,7 @@ def test_normal_map_tree(): code = F.create_normal_map() values =code[:8].tolist() + code[-8:].tolist() num_pivots = 1 + print(values) while num_pivots <16: idx = list(range(16//num_pivots//2, 16, 16//num_pivots)) 
print(idx) From 4ea489d3bfc119ab4ceb50f999ce611690dc21e2 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 3 Apr 2023 11:00:12 -0700 Subject: [PATCH 19/63] Refactor FP4 into 4Bit and integrate NF4 data type. --- bitsandbytes/__init__.py | 2 +- bitsandbytes/autograd/_functions.py | 6 +- bitsandbytes/functional.py | 21 +++---- bitsandbytes/nn/__init__.py | 2 +- bitsandbytes/nn/modules.py | 26 ++++++--- csrc/kernels.cu | 87 ++++++++++++++++------------- tests/test_autograd.py | 15 ++--- tests/test_functional.py | 42 ++++++++------ tests/test_modules.py | 34 ++++++++++- 9 files changed, 145 insertions(+), 90 deletions(-) diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index c83b7ff40..fd83532cd 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -10,7 +10,7 @@ matmul, matmul_cublas, mm_cublas, - matmul_fp4 + matmul_4bit ) from .cextension import COMPILED_WITH_CUDA from .nn import modules diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 8070ff8a2..a9c3a53de 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -475,7 +475,7 @@ def backward(ctx, grad_output): return grad_A, grad_B, None, grad_bias, None -class MatMulFP4(torch.autograd.Function): +class MatMul4Bit(torch.autograd.Function): # forward is the same, but we added the fallback for pre-turing GPUs # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") @@ -547,6 +547,6 @@ def matmul( return MatMul8bitLt.apply(A, B, out, bias, state) -def matmul_fp4(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None): +def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None): assert quant_state is not None - return MatMulFP4.apply(A, B, out, bias, quant_state) + return MatMul4Bit.apply(A, B, out, bias, quant_state) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 83c26054e..20841ebbc 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -689,14 +689,14 @@ def dequantize_blockwise( return out def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): - return quantize_4bit_packed(A, absmax, out, blocksize, compress_statistics, 'fp4') + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4') def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): - return quantize_4bit_packed(A, absmax, out, blocksize, compress_statistics, 'nf4') + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4') -def quantize_4bit_packed(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor: +def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor: """ - Quantize tensor A in blocks of FP4 values. + Quantize tensor A in blocks of 4-bit values. Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. 
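A minimal usage sketch of the quantize_4bit/dequantize_4bit API introduced by this refactor; it is not part of the patch itself, and the tensor shape, dtype and CUDA device are illustrative assumptions.

# Usage sketch only (not patch content); assumes a CUDA device and this
# patched bitsandbytes with quantize_4bit/dequantize_4bit available.
import torch
import bitsandbytes.functional as F

A = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
for quant_type in ('fp4', 'nf4'):
    # two 4-bit values are packed per byte; `state` carries absmax, shape,
    # dtype, blocksize, optional nested statistics and the quant_type
    qA, state = F.quantize_4bit(A, blocksize=64, quant_type=quant_type)
    A_deq = F.dequantize_4bit(qA, state, quant_type=quant_type)
    print(quant_type, (A - A_deq).abs().mean().item())

Note that with this refactor the quantization state also records quant_type as its last element, which dequantize_4bit reads back when a quant_state is supplied.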
@@ -763,19 +763,19 @@ def quantize_4bit_packed(A: Tensor, absmax: Tensor = None, out: Tensor = None, b #qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256) qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) del absmax - state = (qabsmax, input_shape, A.dtype, blocksize, (offset, state2)) + state = (qabsmax, input_shape, A.dtype, blocksize, (offset, state2), quant_type) else: - state = (absmax, input_shape, A.dtype, blocksize, None) + state = (absmax, input_shape, A.dtype, blocksize, None, quant_type) return out, state def dequantize_fp4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: - return dequantize_4bit_packed(A, quant_state, absmax, out, blocksize, 'fp4') + return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4') def dequantize_nf4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: - return dequantize_4bit_packed(A, quant_state, absmax, out, blocksize, 'nf4') + return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4') -def dequantize_4bit_packed(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor: +def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor: """ Dequantizes FP4 blockwise quantized values. @@ -812,7 +812,8 @@ def dequantize_4bit_packed(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, shape = out.shape dtype = out.dtype else: - absmax, shape, dtype, blocksize, compressed_stats = quant_state + absmax, shape, dtype, blocksize, compressed_stats, quant_type = quant_state + if compressed_stats is not None: offset, state2 = compressed_stats diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index 954a67f79..439f75077 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -2,4 +2,4 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from .modules import Int8Params, Linear8bitLt, StableEmbedding, LinearFP4, FP4Params +from .modules import Int8Params, Linear8bitLt, StableEmbedding, Linear4bit, LinearNF4, LinearFP4, Params4bit diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 45eef4256..86ea342ec 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -133,18 +133,19 @@ def forward(self, input: Tensor) -> Tensor: return emb -class FP4Params(torch.nn.Parameter): - def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True): +class Params4bit(torch.nn.Parameter): + def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'): cls.quant_state = None cls.blocksize = blocksize cls.compress_statistics = compress_statistics + cls.quant_type = quant_type if data is None: data = torch.empty(0) return torch.Tensor._make_subclass(cls, data, requires_grad) def cuda(self, device): w = self.data.contiguous().half().cuda(device) - w_fp4, quant_state = bnb.functional.quantize_fp4(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics) + w_fp4, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) self.data = w_fp4 self.quant_state = quant_state @@ -168,17 +169,16 @@ def to(self, *args, **kwargs): if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"): return self.cuda(device) else: - new_param = FP4Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking), + new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, quant_state=self.quant_state) return new_param - -class LinearFP4(nn.Linear): - def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True): +class Linear4bit(nn.Linear): + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4'): super().__init__(input_features, output_features, bias) self.state = bnb.MatmulLtState() - self.weight = FP4Params(self.weight.data, requires_grad=False, compress_statistics=compress_statistics) + self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type) self.compute_dtype = compute_dtype def init_8bit_state(self): @@ -198,12 +198,20 @@ def forward(self, x: torch.Tensor): x = x.to(self.compute_dtype) bias = None if self.bias is None else self.bias.half() - out = bnb.matmul_fp4(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state) + out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state) out = out.to(inp_dtype) return out +class LinearFP4(Linear4bit): + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True): + super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4') + +class LinearNF4(Linear4bit): + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True): + super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4') + class Int8Params(torch.nn.Parameter): def __new__( diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 0ed413f69..86a93ae24 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -194,7 +194,7 @@ 
__device__ float dDequantizeNF4(unsigned char val, float absmax) } -__device__ unsigned char dQuantizeNormal(float x) +__device__ unsigned char dQuantizeNF4(float x) { // the values for this tree was generated by test_normal_map_tree @@ -221,7 +221,7 @@ __device__ unsigned char dQuantizeNormal(float x) if(x > 0.1202552504837513f) // 100 return 0b1001; else - return 0b1100; + return 0b1000; else if(x > -0.33967943489551544f) // 0 if(x > -0.13791173323988914f) // 01 @@ -726,8 +726,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH/2; j++) { - packed_4bit |= dQuantizeNormal(((float)vals[2*j])*local_abs_max) << 4; - packed_4bit |= dQuantizeNormal(((float)vals[2*j+1])*local_abs_max); + packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max); qvals[j] = packed_4bit; } break; @@ -738,7 +738,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float } } -template +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n) { @@ -747,55 +747,62 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs int valid_items_store = 0; const int base_idx = (blockIdx.x * TILE_SIZE); - T vals[NUM_PER_TH*(FP4 ? 2 : 1)]; + T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)]; unsigned char qvals[NUM_PER_TH]; float local_abs_max = -FLT_MAX; typedef cub::BlockLoad LoadChar; - typedef cub::BlockStore StoreT; + typedef cub::BlockStore 0) ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; __shared__ typename LoadChar::TempStorage loadchar; __shared__ typename StoreT::TempStorage storet; for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) { - if(FP4) - { - valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i; - valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2; - } - else - { - valid_items_load = n - i > TILE_SIZE ? TILE_SIZE : n - i; - valid_items_store = n - i > TILE_SIZE ? TILE_SIZE : n - i; - } - local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]); + if(DATA_TYPE > 0) + { + valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i; + valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2; + } + else + { + valid_items_load = n - i > TILE_SIZE ? TILE_SIZE : n - i; + valid_items_store = n - i > TILE_SIZE ? 
TILE_SIZE : n - i; + } + local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]); - __syncthreads(); - LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); + __syncthreads(); + LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); - if(FP4) - { - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) - { - //vals[j*2] = dDequantizeFP4(qvals[j] >> 4, local_abs_max*0.083333f); - //vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*0.083333); - vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); - vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); - } - } - else - { - // load code through read-only cache via __ldg - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) - vals[j] = __ldg(&code[qvals[j]])*local_abs_max; - } + switch(DATA_TYPE) + { + case General8bit: + // load code through read-only cache via __ldg + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + vals[j] = __ldg(&code[qvals[j]])*local_abs_max; + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); + vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeNF4(qvals[j] >> 4, local_abs_max); + vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F, local_abs_max); + } + break; + } - __syncthreads(); - StoreT(storet).Store(&(out[FP4 ? i*2 : i]), vals, valid_items_store); + __syncthreads(); + StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store); } } diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 4356c1d2a..db333753f 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -440,7 +440,7 @@ def test_matmullt( dim2.append(0) -funcs = [(torch.matmul, bnb.matmul_fp4)] +funcs = [(torch.matmul, bnb.matmul_4bit)] str_funcs = ["matmul"] req_grad = list(product([True, False], repeat=3)) req_grad_str = [] @@ -457,12 +457,13 @@ def test_matmullt( compress_statistics = [False, True] has_fp16_weights = [True, False] has_bias = [True, False] -values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics)) -str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics)) -names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics".format(*vals) for vals in str_values] +quant_type = ['fp4', 'nf4'] +values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type)) +str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics, quant_type)) +names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics_{}_quant_type_{}".format(*vals) for vals in str_values] @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") -@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics", values, ids=names) -def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics): +@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, 
quant_type", values, ids=names) +def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) if has_bias == False: @@ -482,7 +483,7 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, bias2 = bias.clone() torch.nn.init.xavier_uniform_(B) - B2, quant_state = bnb.functional.quantize_fp4(B, compress_statistics=compress_statistics) + B2, quant_state = bnb.functional.quantize_4bit(B, compress_statistics=compress_statistics, quant_type=quant_type) if not transpose[0] and transpose[1]: out_torch = funcs[0](A, B.t()) diff --git a/tests/test_functional.py b/tests/test_functional.py index 98edb7c4f..1f19d43b7 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1784,8 +1784,8 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): print("partial matmul", time.time() - t0) -batch_size = 4 -seqdim = 256 +batch_size = 2 +seqdim = 2048 values = [] values.append((batch_size, seqdim, 768, 4 * 768)) values.append((batch_size, seqdim, 1024, 4*1024)) @@ -1798,7 +1798,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values] @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) def test_bench_matmul(batch, seq, model, hidden): - iters = 128 + iters = 32 formatB = F.get_special_format_str() A = torch.randn(batch, seq, model, device="cuda").half() @@ -1808,6 +1808,8 @@ def test_bench_matmul(batch, seq, model, hidden): B_fp4, state = F.quantize_fp4(B) B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True) + B_nf4, state_nf4= F.quantize_nf4(B) + linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() linear8bit.eval() @@ -1836,17 +1838,24 @@ def test_bench_matmul(batch, seq, model, hidden): torch.cuda.synchronize() t0 = time.time() for i in range(iters): - bnb.matmul_fp4(A, B_fp4.t(), quant_state=state) + bnb.matmul_4bit(A, B_fp4.t(), quant_state=state) torch.cuda.synchronize() print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) torch.cuda.synchronize() t0 = time.time() for i in range(iters): - bnb.matmul_fp4(A, B_fp4.t(), quant_state=state_c) + bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c) torch.cuda.synchronize() print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4) + torch.cuda.synchronize() + print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + #torch.cuda.synchronize() #t0 = time.time() #for i in range(iters): @@ -2262,17 +2271,18 @@ def test_4bit_compressed_stats(quant_type): errs2 = [] for i in range(10): A1 = torch.randn(1024, 1024, device='cuda').half() - q2, SA2 = F.quantize_4bit_packed(A1, blocksize=blocksize, quant_type=quant_type) - q3, SA3= F.quantize_4bit_packed(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) - A2 = F.dequantize_4bit_packed(q2, SA2, quant_type=quant_type) - A3 = F.dequantize_4bit_packed(q3, SA3, quant_type=quant_type) + q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) + q3, SA3= F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) + A2 
= F.dequantize_4bit(q2, SA2, quant_type=quant_type) + A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type) err = (A1 - A2).abs().float() relerr = (err/(A1.abs().float()+1e-15)).mean() err = err.mean() - errs1.append(relerr.item()) + errs1.append(err.item()) + assert err.item() < 0.11 assert relerr.item() < 0.28 @@ -2281,23 +2291,23 @@ def test_4bit_compressed_stats(quant_type): relerr = (err/(A1.abs().float()+1e-15)).mean() err = err.mean() - errs2.append(relerr.item()) + errs2.append(err.item()) assert err.item() < 0.11 assert relerr.item() < 0.28 - #print(sum(errs1)/len(errs1), blocksize) - #print(sum(errs2)/len(errs2), blocksize) + #print(sum(errs1)/len(errs1), blocksize, quant_type) + #print(sum(errs2)/len(errs2), blocksize, quant_type) @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") @pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) -def test_bench_fp4_dequant(quant_type): +def test_bench_4bit_dequant(quant_type): blocksize = 256 a = torch.rand(1024*12*4, 1024*12, device='cuda').half() - qa, SA = F.quantize_4bit_packed(a, blocksize=blocksize, quant_type=quant_type) + qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type) input_size = a.numel()/2 output_size = a.numel()*2 @@ -2311,7 +2321,7 @@ def test_bench_fp4_dequant(quant_type): torch.cuda.synchronize() t0 = time.time() for i in range(iters): - F.dequantize_4bit_packed(qa, SA, blocksize=blocksize, quant_type=quant_type) + F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) #b.copy_(a) torch.cuda.synchronize() #print((time.time()-t0)/iters*1e6) diff --git a/tests/test_modules.py b/tests/test_modules.py index d0f5ca2c6..94cf36b5a 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -506,8 +506,16 @@ def test_linear_kbit_fp32_bias(module): o1 = l1(b1) assert l1.bias is None +modules = [] +modules.append(bnb.nn.Linear8bitLt) +modules.append(bnb.nn.Linear4bit) +modules.append(bnb.nn.LinearFP4) +modules.append(bnb.nn.LinearNF4) +modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True)) +modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True)) +names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C'] @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") -@pytest.mark.parametrize("module", [bnb.nn.Linear8bitLt, bnb.nn.LinearFP4, lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True)], ids=['Int8Lt', 'FP4', 'FP4+C']) +@pytest.mark.parametrize("module", modules, ids=names) def test_kbit_backprop(module): b = 17 dim1 = 37 @@ -515,6 +523,8 @@ def test_kbit_backprop(module): ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)]) ref[1].weight.requires_grad = False + torch.nn.init.kaiming_normal_(ref[0].weight) + torch.nn.init.kaiming_normal_(ref[1].weight) kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)]) kbit[0].weight.detach().copy_(ref[0].weight) kbit[1].weight.detach().copy_(ref[1].weight) @@ -523,6 +533,10 @@ def test_kbit_backprop(module): ref = ref.half().cuda() kbit = kbit.half().cuda() + errs1 = [] + errs2 = [] + relerrs1 = [] + relerrs2 = [] for i in range(100): batch = torch.randn(b, dim1).half().cuda() out1 = ref(batch) @@ -535,12 +549,26 @@ def test_kbit_backprop(module): bgrad1 = ref[0].bias.grad bgrad2 = kbit[0].bias.grad - torch.testing.assert_allclose(grad1, grad2, atol=0.008, rtol=0.05) - torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.008, rtol=0.05) + err1 = (out1-out2).abs().float() + err2 = 
(grad1-grad2).abs().float() + relerr1 = (err1/(out1.abs().float()+1e-9)) + relerr2 = (err2/(grad1.abs().float()+1e-9)) + errs1.append(err1.mean().item()) + errs2.append(err2.mean().item()) + relerrs1.append(relerr1.mean().item()) + relerrs2.append(relerr2.mean().item()) + + + #torch.testing.assert_allclose(grad1, grad2, atol=0.008, rtol=0.05) + #torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.008, rtol=0.05) ref.zero_grad() kbit.zero_grad() assert kbit[0].weight.grad.sum().item() == 0 assert kbit[0].bias.grad.sum().item() == 0 + print('out', sum(errs1)/len(errs1)) + print('grad', sum(errs2)/len(errs2)) + print('rel out', sum(relerrs1)/len(relerrs1)) + print('rel grad', sum(relerrs2)/len(relerrs2)) From 1ccb7bdec6c9afe8eccf23bea0619ef7d962f279 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 3 Apr 2023 18:47:00 -0700 Subject: [PATCH 20/63] Fixed ParamsIn4 init; fixed PyTorch 2.0 test failure. --- bitsandbytes/nn/modules.py | 18 +++++++----------- tests/test_functional.py | 4 ++-- tests/test_modules.py | 13 ++++++++----- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 86ea342ec..30f92ce8d 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -136,12 +136,14 @@ def forward(self, input: Tensor) -> Tensor: class Params4bit(torch.nn.Parameter): def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'): cls.quant_state = None - cls.blocksize = blocksize - cls.compress_statistics = compress_statistics - cls.quant_type = quant_type if data is None: data = torch.empty(0) - return torch.Tensor._make_subclass(cls, data, requires_grad) + + self = torch.Tensor._make_subclass(cls, data, requires_grad) + self.blocksize = blocksize + self.compress_statistics = compress_statistics + self.quant_type = quant_type + return self def cuda(self, device): w = self.data.contiguous().half().cuda(device) @@ -177,16 +179,10 @@ def to(self, *args, **kwargs): class Linear4bit(nn.Linear): def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4'): super().__init__(input_features, output_features, bias) - self.state = bnb.MatmulLtState() self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type) self.compute_dtype = compute_dtype - def init_8bit_state(self): - pass - def forward(self, x: torch.Tensor): - self.state.is_training = self.training - # weights are cast automatically as Int8Params, but the bias has to be cast manually if self.bias is not None and self.bias.dtype != x.dtype: self.bias.data = self.bias.data.to(x.dtype) @@ -197,7 +193,7 @@ def forward(self, x: torch.Tensor): if self.compute_dtype is not None: x = x.to(self.compute_dtype) - bias = None if self.bias is None else self.bias.half() + bias = None if self.bias is None else self.bias.half(self.compute_dtype) out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state) out = out.to(inp_dtype) diff --git a/tests/test_functional.py b/tests/test_functional.py index 1f19d43b7..61ea712e2 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1798,7 +1798,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values] @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) def test_bench_matmul(batch, seq, model, hidden): - 
iters = 32 + iters = 1 formatB = F.get_special_format_str() A = torch.randn(batch, seq, model, device="cuda").half() @@ -2317,7 +2317,7 @@ def test_bench_4bit_dequant(quant_type): #print(max_theoretical_s*1e6) b = torch.randn(128, 1024*12, device='cuda').half() - iters = 500 + iters = 5 torch.cuda.synchronize() t0 = time.time() for i in range(iters): diff --git a/tests/test_modules.py b/tests/test_modules.py index 94cf36b5a..89c319c35 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -558,14 +558,17 @@ def test_kbit_backprop(module): relerrs1.append(relerr1.mean().item()) relerrs2.append(relerr2.mean().item()) - - #torch.testing.assert_allclose(grad1, grad2, atol=0.008, rtol=0.05) - #torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.008, rtol=0.05) + if isinstance(module, bnb.nn.Linear8bitLt): + torch.testing.assert_allclose(grad1, grad2, atol=0.008, rtol=0.05) + torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.008, rtol=0.05) + else: + torch.testing.assert_allclose(grad1, grad2, atol=0.015, rtol=0.05) + torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.02, rtol=0.05) ref.zero_grad() kbit.zero_grad() - assert kbit[0].weight.grad.sum().item() == 0 - assert kbit[0].bias.grad.sum().item() == 0 + assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0 + assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0 print('out', sum(errs1)/len(errs1)) print('grad', sum(errs2)/len(errs2)) print('rel out', sum(relerrs1)/len(relerrs1)) From e9fa03b7176d51fa23d23616b16ef389db18ab02 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Fri, 7 Apr 2023 09:59:21 -0700 Subject: [PATCH 21/63] Some fixed for loading PEFT modules with Params4bit. --- bitsandbytes/functional.py | 10 +++++--- bitsandbytes/nn/modules.py | 52 +++++++++++++++++++++++++++++++++++--- csrc/kernels.cu | 32 +++++++++++++++-------- tests/test_optim.py | 4 +-- 4 files changed, 78 insertions(+), 20 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 20841ebbc..b16860609 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -362,9 +362,13 @@ def get_special_format_str(): def is_on_gpu(tensors): on_gpu = True + gpu_ids = set() for t in tensors: if t is None: continue # NULL pointers are fine on_gpu &= t.device.type == 'cuda' + gpu_ids.add(t.device.index) + if len(gpu_ids) > 1: + raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:{[(t.shape, t.device) for t in tensors]}') return on_gpu def get_ptr(A: Tensor) -> ct.c_void_p: @@ -617,7 +621,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra assert rand is None lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) - state = (absmax, code, blocksize) + state = [absmax, code, blocksize] return out, state @@ -763,9 +767,9 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz #qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256) qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) del absmax - state = (qabsmax, input_shape, A.dtype, blocksize, (offset, state2), quant_type) + state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type] else: - state = (absmax, input_shape, A.dtype, blocksize, None, quant_type) + state = [absmax, input_shape, A.dtype, blocksize, None, quant_type] return out, state diff --git 
a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 30f92ce8d..de9e4ac81 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -135,7 +135,6 @@ def forward(self, input: Tensor) -> Tensor: class Params4bit(torch.nn.Parameter): def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'): - cls.quant_state = None if data is None: data = torch.empty(0) @@ -143,12 +142,14 @@ def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, self.blocksize = blocksize self.compress_statistics = compress_statistics self.quant_type = quant_type + self.quant_state = quant_state + self.data = data return self def cuda(self, device): w = self.data.contiguous().half().cuda(device) - w_fp4, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) - self.data = w_fp4 + w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) + self.data = w_4bit self.quant_state = quant_state return self @@ -171,8 +172,19 @@ def to(self, *args, **kwargs): if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"): return self.cuda(device) else: + s = self.quant_state + if s is not None: + # make sure the quantization state is on the right device + s[0] = s[0].to(device) + if self.compress_statistics: + # TODO: refactor this. This is a nightmare + s[-2][0] = s[-2][0].to(device) # offset + s[-2][1][0] = s[-2][1][0].to(device) # nested quantiation state statitics + s[-2][1][1] = s[-2][1][1].to(device) # nested quantiation codebook new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking), - requires_grad=self.requires_grad, quant_state=self.quant_state) + requires_grad=self.requires_grad, quant_state=self.quant_state, + blocksize=self.blocksize, compress_statistics=self.compress_statistics, + quant_type=self.quant_type) return new_param @@ -200,6 +212,38 @@ def forward(self, x: torch.Tensor): return out + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + + # we only need to save extra state if .cuda was called + # then we have the (1) quantization weight and the (2) quantization config + + #quant_state = getattr(self.weight, 'quant_state', None) + #if quant_state is not None: + # # 2. quantization state + # destination[prefix + 'quant_state'] = quant_state + + #destination[prefix + 'weight'] = self.weight.detach() + + + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs) + #for key in unexpected_keys: + # input_name = key[len(prefix):] + # if input_name == "quant_state": + # if getattr(self.weight, 'quant_state', None) is None: + # # buffers not yet initialized, can't call them directly without + # raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear4bit is " + # "not supported. 
Please call module.cuda() before module.load_state_dict()") + + # input_param = state_dict[key] + # self.weight.quant_state = input_param + # assert isinstance(self.weight, Param4bit) + # unexpected_keys.remove(key) + class LinearFP4(Linear4bit): def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True): super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4') diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 86a93ae24..c35acc8cf 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1681,6 +1681,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char unsigned char c1s[N_PER_TH]; unsigned char c2s[N_PER_TH]; T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; typedef cub::BlockLoad LoadT; typedef cub::BlockLoad LoadChar; @@ -1742,16 +1743,24 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) { - g_val = float(g_vals[j]); - g_val *= gnorm_scale; - if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) { - s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; - s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); - s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE]; + g_val = g_vals[j]; + //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); + //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; + g_val *= gnorm_scale; + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); } + else + { + s1_vals[j] = 0.0f; + s2_vals[j] = 0.0f; + } new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j])); @@ -1782,22 +1791,23 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char } __syncthreads(); - LoadT(temp_storage.loadh).Load(&(p[i]), g_vals, valid_items, (T)0.0f); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); // reduce: 2.67/1.69 -> 2.67/1.70 # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) { - if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + //if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) { - g_vals[j] = (T)(((float)g_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); + p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); if(weight_decay > 0.0f) - g_vals[j] = ((float)g_vals[j])*(1.0f-(lr*weight_decay)); + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); } } // store: 0.85/1.44 -> 2.48/1.57 __syncthreads(); - StoreT(temp_storage.storeh).Store(&(p[i]), g_vals, valid_items); + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); // quantizaztion: 2.67/1.70 -> 3.4/3.3 # pragma unroll N_PER_TH diff --git a/tests/test_optim.py b/tests/test_optim.py index 92e3ed260..83390a475 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -282,7 +282,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): errors = [] relerrors = [] - for i in range(50): + for i in range(100): g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 p1.grad = g.clone().float() p2.grad = g.clone() @@ -314,7 
+314,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ) == 0 ) - assert num_not_close.sum().item() < 20 + #assert num_not_close.sum().item() < 20 dequant_states.append(s1.clone()) err = torch.abs(p1 - p2) From b8ea2b416d25130ed32a3cf436b8a9f8fd1d412f Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 12 Apr 2023 12:28:35 -0700 Subject: [PATCH 22/63] Fixed bias conversion in Linear4bit --- bitsandbytes/nn/modules.py | 34 +--------------------------------- 1 file changed, 1 insertion(+), 33 deletions(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index de9e4ac81..ab16e01e0 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -205,45 +205,13 @@ def forward(self, x: torch.Tensor): if self.compute_dtype is not None: x = x.to(self.compute_dtype) - bias = None if self.bias is None else self.bias.half(self.compute_dtype) + bias = None if self.bias is None else self.bias.to(self.compute_dtype) out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state) out = out.to(inp_dtype) return out - def _save_to_state_dict(self, destination, prefix, keep_vars): - super()._save_to_state_dict(destination, prefix, keep_vars) - - # we only need to save extra state if .cuda was called - # then we have the (1) quantization weight and the (2) quantization config - - #quant_state = getattr(self.weight, 'quant_state', None) - #if quant_state is not None: - # # 2. quantization state - # destination[prefix + 'quant_state'] = quant_state - - #destination[prefix + 'weight'] = self.weight.detach() - - - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs) - #for key in unexpected_keys: - # input_name = key[len(prefix):] - # if input_name == "quant_state": - # if getattr(self.weight, 'quant_state', None) is None: - # # buffers not yet initialized, can't call them directly without - # raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear4bit is " - # "not supported. Please call module.cuda() before module.load_state_dict()") - - # input_param = state_dict[key] - # self.weight.quant_state = input_param - # assert isinstance(self.weight, Param4bit) - # unexpected_keys.remove(key) - class LinearFP4(Linear4bit): def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True): super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4') From 7dc198feb7b68f08790823a06d42c7500ff446fa Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 17 Apr 2023 18:01:49 -0700 Subject: [PATCH 23/63] Added 32-bit optimizer for bfloat16 gradients. --- bitsandbytes/cextension.py | 2 +- bitsandbytes/functional.py | 89 +++++++++++++++----------------------- bitsandbytes/nn/modules.py | 7 +++ csrc/kernels.cu | 7 ++- csrc/ops.cu | 1 + csrc/pythonInterface.c | 10 +++-- tests/test_optim.py | 35 +++++---------- 7 files changed, 65 insertions(+), 86 deletions(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index e2ca978eb..8adca9312 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -23,7 +23,7 @@ CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs above to fix your environment! 
If you cannot find any issues and suspect a bug, please open an issue with detals about your environment: https://github.com/TimDettmers/bitsandbytes/issues''') - lib.cadam32bit_g32 + lib.cadam_8bit_blockwise_fp32 lib.get_context.restype = ct.c_void_p lib.get_cusparse.restype = ct.c_void_p COMPILED_WITH_CUDA = True diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index b16860609..ff0eb7ec2 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -28,7 +28,7 @@ def prod(iterable): if COMPILED_WITH_CUDA: """C FUNCTIONS FOR OPTIMIZERS""" str2optimizer32bit = {} - str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16) + str2optimizer32bit["adam"] = (lib.cadam32bit_gfp32, lib.cadam32bit_gfp16, lib.cadam32bit_gbf16) str2optimizer32bit["momentum"] = ( lib.cmomentum32bit_g32, lib.cmomentum32bit_g16, @@ -41,11 +41,6 @@ def prod(iterable): lib.cadagrad32bit_g32, lib.cadagrad32bit_g16, ) - str2optimizer32bit["lars"] = ( - lib.cmomentum32bit_g32, - lib.cmomentum32bit_g16, - ) - str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16) str2optimizer8bit = {} str2optimizer8bit["adam"] = ( @@ -998,53 +993,37 @@ def optimizer_update_32bit( if max_unorm > 0.0: param_norm = torch.norm(p.data.float()) - if optimizer_name not in str2optimizer32bit: - raise NotImplementedError( - f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}' - ) - if g.dtype == torch.float32 and state1.dtype == torch.float32: - str2optimizer32bit[optimizer_name][0]( - get_ptr(g), - get_ptr(p), - get_ptr(state1), - get_ptr(state2), - get_ptr(unorm_vec), - ct.c_float(max_unorm), - ct.c_float(param_norm), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_float(weight_decay), - ct.c_int32(step), - ct.c_float(lr), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) - elif g.dtype == torch.float16 and state1.dtype == torch.float32: - str2optimizer32bit[optimizer_name][1]( - get_ptr(g), - get_ptr(p), - get_ptr(state1), - get_ptr(state2), - get_ptr(unorm_vec), - ct.c_float(max_unorm), - ct.c_float(param_norm), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_float(weight_decay), - ct.c_int32(step), - ct.c_float(lr), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) + optim_func = None + if g.dtype == torch.float32: + optim_func = str2optimizer32bit[optimizer_name][0] + elif g.dtype == torch.float16: + optim_func = str2optimizer32bit[optimizer_name][1] + elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3): + optim_func = str2optimizer32bit[optimizer_name][2] else: - raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" - ) + raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}") + + is_on_gpu([g, p, state1, state2, unorm_vec]) + prev_device = pre_call(g.device) + optim_func( + get_ptr(g), + get_ptr(p), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_float(weight_decay), + ct.c_int32(step), + ct.c_float(lr), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel())) + post_call(prev_device) def optimizer_update_8bit( @@ -1199,12 +1178,12 @@ def optimizer_update_8bit_blockwise( optim_func = None if g.dtype == torch.float32 and 
state1.dtype == torch.uint8: - optimizer_func = str2optimizer8bit_blockwise[optimizer_name][0] + optim_func = str2optimizer8bit_blockwise[optimizer_name][0] elif g.dtype == torch.float16 and state1.dtype == torch.uint8: - optimizer_func = str2optimizer8bit_blockwise[optimizer_name][1] + optim_func = str2optimizer8bit_blockwise[optimizer_name][1] elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and len(str2optimizer8bit_blockwise[optimizer_name])==3): - optimizer_func = str2optimizer8bit_blockwise[optimizer_name][2] + optim_func = str2optimizer8bit_blockwise[optimizer_name][2] else: raise ValueError( f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" @@ -1213,7 +1192,7 @@ def optimizer_update_8bit_blockwise( is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2]) prev_device = pre_call(g.device) - optimizer_func( + optim_func( get_ptr(p), get_ptr(g), get_ptr(state1), diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index ab16e01e0..24f50707b 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -178,6 +178,13 @@ def to(self, *args, **kwargs): s[0] = s[0].to(device) if self.compress_statistics: # TODO: refactor this. This is a nightmare + # for 4-bit: + # state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type] + # state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type] + #s[-2][0] = s[-2][0].to(device) # offset + #s[-2][1][0] = s[-2][1][0].to(device) # nested absmax + + # for 8-bit s[-2][0] = s[-2][0].to(device) # offset s[-2][1][0] = s[-2][1][0].to(device) # nested quantiation state statitics s[-2][1][1] = s[-2][1][1].to(device) # nested quantiation codebook diff --git a/csrc/kernels.cu b/csrc/kernels.cu index c35acc8cf..2d940be1b 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2981,12 +2981,15 @@ template __global__ void kPreconditionOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, +template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); #define MAKE_PreconditionStatic8bit1State(oname, gtype) \ diff --git a/csrc/ops.cu b/csrc/ops.cu index de14039da..76777ae6c 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -703,6 +703,7 @@ template void optimizer32bit(gtype* g, gtype* p, \ MAKE_optimizer32bit(ADAM, half) MAKE_optimizer32bit(ADAM, float) +MAKE_optimizer32bit(ADAM, __nv_bfloat16) MAKE_optimizer32bit(MOMENTUM, half) MAKE_optimizer32bit(MOMENTUM, float) MAKE_optimizer32bit(RMSPROP, half) diff 
--git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index d16917891..0e9106c40 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -29,8 +29,9 @@ void fname##32bit_g##gbits(gtype *g, gtype *p, \ MAKE_FUNC32(momentum, MOMENTUM, float, 32) MAKE_FUNC32(momentum, MOMENTUM, half, 16) -MAKE_FUNC32(adam, ADAM, float, 32) -MAKE_FUNC32(adam, ADAM, half, 16) +MAKE_FUNC32(adam, ADAM, float, fp32) +MAKE_FUNC32(adam, ADAM, half, fp16) +MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16) MAKE_FUNC32(rmsprop, RMSPROP, float, 32) MAKE_FUNC32(rmsprop, RMSPROP, half, 16) MAKE_FUNC32(adagrad, ADAGRAD, float, 32) @@ -173,8 +174,9 @@ extern "C" const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \ { name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ - MAKE_CFUNC32(adam, float, 32) - MAKE_CFUNC32(adam, half, 16) + MAKE_CFUNC32(adam, float, fp32) + MAKE_CFUNC32(adam, half, fp16) + MAKE_CFUNC32(adam, __nv_bfloat16, bf16) MAKE_CFUNC32(momentum, float, 32) MAKE_CFUNC32(momentum, half, 16) MAKE_CFUNC32(rmsprop, float, 32) diff --git a/tests/test_optim.py b/tests/test_optim.py index 83390a475..a13b33207 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -44,10 +44,6 @@ def rm_path(path): lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False), ) -str2optimizers["lars"] = ( - lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), - lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9), -) str2optimizers["rmsprop"] = ( lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False), @@ -64,10 +60,6 @@ def rm_path(path): lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False), ) -str2optimizers["lars8bit"] = ( - lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), - lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9), -) str2optimizers["adam8bit_blockwise"] = ( torch.optim.Adam, @@ -85,7 +77,6 @@ def rm_path(path): str2statenames = {} str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["momentum"] = [("momentum_buffer", "state1")] -str2statenames["lars"] = [("momentum_buffer", "state1")] str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["rmsprop"] = [("square_avg", "state1")] str2statenames["adam8bit"] = [ @@ -106,7 +97,6 @@ def rm_path(path): str2statenames["momentum8bit_blockwise"] = [ ("momentum_buffer", "state1", "qmap1", "absmax1") ] -str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")] str2statenames["rmsprop8bit_blockwise"] = [ ("square_avg", "state1", "qmap1", "absmax1") @@ -114,14 +104,10 @@ def rm_path(path): dim1 = [1024] dim2 = [32, 1024, 4097, 1] -gtype = [torch.float32, torch.float16] -optimizer_names = ["adam", "momentum", "rmsprop", "lars"] +gtype = [torch.float32, torch.float16, torch.bfloat16] +optimizer_names = ["adam", "momentum", "rmsprop"] values = list(product(dim1, dim2, gtype, optimizer_names)) -names = [ - "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values -] - - +names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) def test_optimizer32bit(dim1, dim2, gtype, optim_name): if dim1 == 1 and dim2 == 1: @@ -135,6 
+121,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): if gtype == torch.float32: atol, rtol = 1e-6, 1e-5 + elif gtype == torch.bfloat16: + atol, rtol = 1e-3, 1e-2 else: atol, rtol = 1e-4, 1e-3 @@ -173,14 +161,14 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): rtol=rtol, ) - if gtype == torch.float16: + if gtype != torch.float32: # the adam buffers should also be close because they are 32-bit # but the paramters can diverge because they are 16-bit # the difference grow larger and larger with each update # --> copy the state to keep weights close - p1.data = p1.data.half().float() + p1.data = p1.data.to(p2.dtype).float() p2.copy_(p1.data) - torch.testing.assert_allclose(p1.half(), p2) + torch.testing.assert_allclose(p1.to(p2.dtype), p2) if optim_name in ["lars", "lamb"]: assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0 @@ -246,7 +234,6 @@ def test_global_config(dim1, dim2, gtype): "momentum8bit", "rmsprop8bit", "adam8bit_blockwise", - "lars8bit", "momentum8bit_blockwise", "rmsprop8bit_blockwise", ] @@ -321,10 +308,10 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): relerr = err / torch.abs(p1) if g.dtype == torch.bfloat16: assert err.mean() < 0.00015 - assert relerr.mean() < 0.0015 + assert relerr.mean() < 0.0016 else: - assert err.mean() < 0.0001 - assert relerr.mean() < 0.001 + assert err.mean() < 0.00012 + assert relerr.mean() < 0.0012 errors.append(err.mean().item()) relerrors.append(relerr.mean().item()) From 0f9d30207f7a86c6be17f8fd897f0716db32cdfd Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 19 Apr 2023 11:48:47 -0700 Subject: [PATCH 24/63] Added nested quantization for blockwise quantization. --- bitsandbytes/functional.py | 25 +++++++++---- tests/test_functional.py | 72 ++++++++++++++++++++------------------ 2 files changed, 55 insertions(+), 42 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index ff0eb7ec2..eb4980021 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -541,7 +541,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n return out -def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096) -> Tensor: +def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor: """ Quantize tensor A in blocks of size 4096 values. 
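The remaining hunks of this patch implement the nested quantization named in the commit message: the per-block absmax values are themselves blockwise-quantized after subtracting their mean, which shrinks the memory overhead of the quantization state. A minimal PyTorch sketch of the idea, using a simplified linear 8-bit code in place of the dynamic codebook and purely illustrative helper names:

import torch

def _quant(x, blocksize):
    # simplified linear 8-bit blockwise quantization (stand-in for the dynamic codebook)
    flat = x.flatten().float()
    pad = (-flat.numel()) % blocksize
    if pad:
        flat = torch.cat([flat, flat.new_zeros(pad)])
    blocks = flat.view(-1, blocksize)
    absmax = blocks.abs().max(dim=1).values
    q = torch.round(blocks / absmax.clamp(min=1e-12).unsqueeze(1) * 127).to(torch.int8)
    return q, absmax

def _dequant(q, absmax, numel, shape, blocksize):
    blocks = q.float() / 127 * absmax.unsqueeze(1)
    return blocks.flatten()[:numel].view(shape)

def quantize_nested(x, blocksize=64):
    q, absmax = _quant(x, blocksize)
    offset = absmax.mean()                        # center the absmax values first
    q_abs, absmax2 = _quant(absmax - offset, blocksize)
    state = (q_abs, absmax2, offset, blocksize, x.numel(), x.shape, absmax.numel())
    return q, state

def dequantize_nested(q, state):
    q_abs, absmax2, offset, blocksize, numel, shape, n_abs = state
    absmax = _dequant(q_abs, absmax2, n_abs, (n_abs,), blocksize) + offset   # undo the nesting first
    return _dequant(q, absmax, numel, shape, blocksize)

x = torch.randn(1024, 1024)
q, state = quantize_nested(x)
print((x - dequantize_nested(q, state)).abs().mean())   # small reconstruction error

Storing one byte per block for the quantized absmax plus a small second-level state, instead of a float32 absmax per block, is what the offset/state2 entries in the state list below correspond to.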
@@ -586,7 +586,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra out = torch.zeros_like(A, dtype=torch.uint8) if A.device.type != 'cpu': - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32] + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] cblocksize = ct.c_int32(blocksize) prev_device = pre_call(A.device) code = code.to(A.device) @@ -616,7 +616,15 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra assert rand is None lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) - state = [absmax, code, blocksize] + if nested: + offset = absmax.mean() + absmax -= offset + qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False) + state = [qabsmax, code, blocksize, nested, offset, state2] + else: + state = [absmax, code, blocksize, nested, None, None] + + return out, state @@ -628,6 +636,7 @@ def dequantize_blockwise( code: Tensor = None, out: Tensor = None, blocksize: int = 4096, + nested=False ) -> Tensor: """ Dequantizes blockwise quantized values. @@ -665,13 +674,15 @@ def dequantize_blockwise( if quant_state is None: quant_state = (absmax, code, blocksize) else: - absmax, code, blocksize = quant_state - + absmax, code, blocksize, nested, offset, state2 = quant_state + if nested: + absmax = dequantize_blockwise(absmax, state2) + absmax += offset if A.device.type != 'cpu': device = pre_call(A.device) code = code.to(A.device) - if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64, 32]: + if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: @@ -736,7 +747,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz if out is None: out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device) - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32] + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] prev_device = pre_call(A.device) is_on_gpu([A, out, absmax]) diff --git a/tests/test_functional.py b/tests/test_functional.py index 61ea712e2..82f6a71df 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -150,42 +150,44 @@ def test_dynamic_quantization(): assert diff < 0.004 -def test_dynamic_blockwise_quantization(): + +@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"]) +@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) +def test_dynamic_blockwise_quantization(nested, blocksize): #print('') - for blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]: - diffs = [] - reldiffs = [] - for i in range(100): - A1 = torch.randn(1024, 1024, device="cuda") - C, S = F.quantize_blockwise(A1, blocksize=blocksize) - A2 = F.dequantize_blockwise(C, S, blocksize=blocksize) - diff = torch.abs(A1 - A2) - reldiff = diff / torch.abs(A1 + 1e-8) - diffs.append(diff.mean().item()) - reldiffs.append(reldiff.mean().item()) - abserr = sum(diffs)/len(diffs) - relerr = sum(reldiffs)/len(reldiffs) - assert abserr < 0.011 - assert relerr < 0.018 - #print('randn', blocksize, sum(diffs)/len(diffs)) - #print('randn', blocksize, sum(reldiffs)/len(reldiffs)) - - diffs = [] - for i in range(100): - A1 = torch.rand(1024, 1024, device="cuda") - C, S = F.quantize_blockwise(A1, blocksize=blocksize) - A2 = F.dequantize_blockwise(C, S, 
blocksize=blocksize) - diff = torch.abs(A1 - A2) - reldiff = diff / torch.abs(A1 + 1e-8) - diffs.append(diff.mean().item()) - reldiffs.append(reldiff.mean().item()) - #torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) - abserr = sum(diffs)/len(diffs) - relerr = sum(reldiffs)/len(reldiffs) - assert abserr < 0.0035 - assert relerr < 0.015 - #print('rand', blocksize, sum(diffs)/len(diffs)) - #print('rand', blocksize, sum(reldiffs)/len(reldiffs)) + diffs = [] + reldiffs = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device="cuda") + C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + abserr = sum(diffs)/len(diffs) + relerr = sum(reldiffs)/len(reldiffs) + assert abserr < 0.011 + assert relerr < 0.018 + print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs)) + print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs)) + + diffs = [] + for i in range(100): + A1 = torch.rand(1024, 1024, device="cuda") + C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + #torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) + abserr = sum(diffs)/len(diffs) + relerr = sum(reldiffs)/len(reldiffs) + assert abserr < 0.0035 + assert relerr < 0.015 + print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) + print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) def test_dynamic_blockwise_stochastic_quantization(): From 6bfd7a405f7ccea4c40fb54c8fd0c179984ac506 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 25 Apr 2023 16:13:43 -0700 Subject: [PATCH 25/63] Initial template. 
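The template introduced below only sketches the kernel plan in comments; the computation it targets is a matmul in which B is stored as packed 4-bit codebook indices with one absmax scale per block. A rough PyTorch reference of that computation (the 16-entry codebook, the nibble packing order, and the blocking here are illustrative assumptions, not the kernel's actual data layout):

import torch

def dequant_4bit_ref(packed, absmax, codebook, blocksize, shape):
    hi = (packed >> 4).long()                              # assumed packing: high nibble first
    lo = (packed & 0x0F).long()
    idx = torch.stack([hi, lo], dim=1).flatten()
    vals = codebook[idx].view(-1, blocksize)               # look up normalized values
    vals = vals * absmax.unsqueeze(1)                      # rescale each block
    return vals.flatten()[: shape[0] * shape[1]].view(shape)

def matmul_4bit_ref(A, packed_B, absmax, codebook, blocksize, B_shape):
    return A @ dequant_4bit_ref(packed_B, absmax, codebook, blocksize, B_shape).to(A.dtype)

# toy quantization round trip with a hypothetical evenly spaced codebook
codebook = torch.linspace(-1.0, 1.0, 16)
B, blocksize = torch.randn(64, 32), 64
blocks = B.flatten().view(-1, blocksize)
absmax = blocks.abs().max(dim=1).values
idx = (blocks / absmax.unsqueeze(1)).unsqueeze(-1).sub(codebook).abs().argmin(-1).flatten()
packed = ((idx[0::2] << 4) | idx[1::2]).to(torch.uint8)   # two 4-bit indices per byte

A = torch.randn(8, 64)
out = matmul_4bit_ref(A, packed, absmax, codebook, blocksize, B.shape)
print((out - A @ B).abs().mean())                         # quantization error only

The comment plan in the kernel body below describes how the same dequantize-then-multiply is mapped onto warps and shared-memory tiles.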
--- Makefile | 13 ++++++++++++- csrc/kernels.cu | 25 +++++++++++++++++++++++++ csrc/kernels.cuh | 2 ++ csrc/ops.cu | 12 ++++++++++++ csrc/ops.cuh | 2 ++ 5 files changed, 53 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index e1141607f..a377f6548 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,7 @@ FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include INCLUDE_10x := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include +INCLUDE_cutlass := -I $(ROOT_DIR)/dependencies/cutlass/include LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib # NVIDIA NVCC compilation flags @@ -61,7 +62,7 @@ CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90 all: $(BUILD_DIR) env - $(NVCC) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) $(NVCC) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) @@ -100,6 +101,11 @@ cuda11x: $(BUILD_DIR) env $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) +cuda11x_cutlass: $(BUILD_DIR) env cutlass + $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(INCLUDE_cutlass) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(GPP) -std=c++20 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) + cuda12x: $(BUILD_DIR) env $(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) $(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o @@ -121,6 +127,11 @@ env: @echo "LD_LIBRARY_PATH: $(LD_LIBRARY_PATH)" @echo "============================" +cutlass: + if [ ! -d "$(ROOT_DIR)/dependencies/cutlass" ]; then \ + git clone https://github.com/NVIDIA/cutlass.git $(ROOT_DIR)/dependencies/cutlass; \ + fi \ + $(BUILD_DIR): mkdir -p build mkdir -p dependencies diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 2d940be1b..5d2a58ec5 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2919,10 +2919,35 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * } } + +template __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB) +{ +// element-wise kernel +// 1. Load batch x k into registers +// 2. Load k x k into registers +// 3. dequantize and store in second pair of k x k +// 4. matmul +// 5. 
sum with cub +// 6. store outputs +// TC kernel +// use k warps per thread block +// 1. threadblock use read-only cache to read in register tile for A into shared memory +// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments +// 3. each warp reads a segment of values 16x32 from B +// 4. do dequantization from register of B into second pair of registers +// 5. store (4) into fragment +// 6. matmul aggregate into fragment C +// 7. aggreecate files of C into shared memroy block C +// 8. sum (7) +// 9. write outputs to matmul output matrix +} + + //============================================================== // TEMPLATE DEFINITIONS //============================================================== +template __global__ void kMatmul_inference_4bit(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index ed549cb3f..ecf3a0991 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -9,6 +9,8 @@ #ifndef kernels #define kernels +template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); + template__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n); __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n); diff --git a/csrc/ops.cu b/csrc/ops.cu index 76777ae6c..022f39785 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -90,6 +90,17 @@ template void dequantizeBlockwise(float *code, unsign CUDA_CHECK_RETURN(cudaPeekAtLastError()); } + +void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB) +{ + int num_blocks = (colsB+32-1)/32; + kMatmul_inference_4bit<<>>(A, B, out, lda, ldb, rowsA, colsA, colsB); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *C, int lda, int ldb, int rowsA, int colsA, int colsB); + + template void optimizer32bit(T* g, T* p, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, @@ -653,6 +664,7 @@ template void extractOutliers(char * A, int *idx, char *out, int id CUDA_CHECK_RETURN(cudaPeekAtLastError()); } + //============================================================== // TEMPLATE DEFINITIONS //============================================================== diff --git a/csrc/ops.cuh b/csrc/ops.cuh index f73d4e0c4..137320ba1 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -183,4 +183,6 @@ template void spmm_coo_very_sparse_naive(int *max_count, template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); +void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); + #endif From 6e2544da251ccf281d5d88611d2cb5c13bcf42a6 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 25 Apr 2023 16:15:44 -0700 Subject: [PATCH 26/63] Added cutlass example. 
--- csrc/kernels.cu | 134 ++++++++++++++++++++++++++++++++++++++++++++++++ csrc/ops.cu | 57 ++++++++++++++++++++ 2 files changed, 191 insertions(+) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 5d2a58ec5..a108772db 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2942,6 +2942,140 @@ template __global // 9. write outputs to matmul output matrix } +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0 +# include "cutlass/util/cublas_wrappers.hpp" +#endif +#include "cutlass/util/helper_cuda.hpp" + +template +__global__ static +__launch_bounds__(decltype(size(CThreadLayout{}))::value) +void +gemm_device(MShape M, NShape N, KShape K, + TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA, + TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB, + TC * C, CStride dC, CBlockLayout , CThreadLayout tC, + Alpha alpha, Beta beta) +{ + using namespace cute; + using X = Underscore; + + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + + CUTE_STATIC_ASSERT_V(size(tA) == size(tC)); + CUTE_STATIC_ASSERT_V(size(tB) == size(tC)); + + //CUTE_STATIC_ASSERT_V(shape<0>(blockA) == shape<0>(blockC)); // BLK_M + //CUTE_STATIC_ASSERT_V(shape<0>(blockB) == shape<1>(blockC)); // BLK_N + CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB)); // BLK_K + + // Shared memory buffers + __shared__ TA smemA[cosize_v]; + __shared__ TB smemB[cosize_v]; + auto sA = make_tensor(make_smem_ptr(smemA), blockA); // (BLK_M,BLK_K) + auto sB = make_tensor(make_smem_ptr(smemB), blockB); // (BLK_N,BLK_K) + + // Represent the full tensors + auto mA = make_tensor(make_gmem_ptr(A), make_shape(M,K), dA); // (M,K) + auto mB = make_tensor(make_gmem_ptr(B), make_shape(N,K), dB); // (N,K) + auto mC = make_tensor(make_gmem_ptr(C), make_shape(M,N), dC); // (M,N) + + // Get the appropriate blocks for this thread block -- + // potential for thread block locality + auto blk_shape = make_shape(size<0>(sA), size<0>(sB), size<1>(sB));// (BLK_M,BLK_N,BLK_K) + auto blk_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k) + + auto gA = local_tile(mA, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) + auto gB = local_tile(mB, blk_shape, blk_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + auto gC = local_tile(mC, blk_shape, blk_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + // + // Partition the copying of A and B tiles across the threads + // + + // TUTORIAL: Example of simple partitioning of A|B tiles over tA|tB + // Default is a raked partition, but can be changed with Step parameter + + auto tAgA = local_partition(gA, tA, threadIdx.x); // (THR_M,THR_K,k) + auto tAsA = local_partition(sA, tA, threadIdx.x); // (THR_M,THR_K) + + auto tBgB = local_partition(gB, tB, threadIdx.x); // (THR_N,THR_K,k) + auto tBsB = local_partition(sB, tB, threadIdx.x); // (THR_N,THR_K) + + // + // Define C accumulators and A/B partitioning + // + + // TUTORIAL: Example of partitioning via projections of tC + + // Partition sA (M,K) by the rows of tC + auto tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{}); // (THR_M,BLK_K) + // Partition sB (N,K) by the cols of tC + auto tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{}); // (THR_N,BLK_K) + // Partition gC (M,N) by the tile of tC + auto tCgC = local_partition(gC, 
tC, threadIdx.x, Step<_1,_1>{}); // (THR_M,THR_N) + + // Allocate the accumulators -- same size as the projected data + auto tCrC = make_fragment_like(tCgC); // (THR_M,THR_N) + + // Clear the accumulators + clear(tCrC); + +#if 1 + + // TUTORIAL: Example of a very simple compute loop + // Data is read from global to shared memory via the tA|tB partitioning + // gemm(.) operates on the shared memory directly via the tC partitioning + + auto k_max = size<2>(tAgA); + + for (int k = 0; k < k_max; ++k) + { + // Copy gmem to smem + copy(tAgA(_,_,k), tAsA); + copy(tBgB(_,_,k), tBsB); + + // In case copy uses cp.async, make sure that the cp.async + // instructions are ordered with respect to other cp.async + // instructions (fence), then wait on all the outstanding copy + // operations (wait<0>()). __syncthreads() alone does not do + // this. + // + // NOTE: cp_async_wait<0>() currently issues cp.async.wait_all. + // This is equivalent to cp.async.commit_group followed by + // cp.async_wait_group 0. This should make the first + // cp_async_fence() (which also issues cp.async.commit_group) + // redundant. The tutorial works as-is, so we'll leave the + // redundant fence in for now and study its removal later. + cp_async_fence(); + cp_async_wait<0>(); + + __syncthreads(); + + // Compute gemm on smem + gemm(tCsA, tCsB, tCrC); + + __syncthreads(); + } + +#endif + + axpby(alpha, tCrC, beta, tCgC); +} + //============================================================== // TEMPLATE DEFINITIONS diff --git a/csrc/ops.cu b/csrc/ops.cu index 022f39785..1204cbda6 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -665,6 +665,63 @@ template void extractOutliers(char * A, int *idx, char *out, int id } + +#include +#include + +#include + +template +void +gemm(int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + using namespace cute; + + // Define shapes (dynamic) + auto M = int(m); + auto N = int(n); + auto K = int(k); + + // Define strides (mixed) + auto dA = make_stride(Int<1>{}, ldA); + auto dB = make_stride(Int<1>{}, ldB); + auto dC = make_stride(Int<1>{}, ldC); + + // Define block sizes (static) + auto bM = Int<128>{}; + auto bN = Int<128>{}; + auto bK = Int< 8>{}; + + // Define the block layouts (static) + auto sA = make_layout(make_shape(bM,bK)); + auto sB = make_layout(make_shape(bN,bK)); + auto sC = make_layout(make_shape(bM,bN)); + + // Define the thread layouts (static) + auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{})); + auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{})); + auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); + + dim3 dimBlock(size(tC)); + dim3 dimGrid(ceil_div(size(M), size(bM)), + ceil_div(size(N), size(bN))); + gemm_device + <<< dimGrid, dimBlock, 0, stream >>> + (M, N, K, + A, dA, sA, tA, + B, dB, sB, tB, + C, dC, sC, tC, + alpha, beta); +} + + //============================================================== // TEMPLATE DEFINITIONS //============================================================== From 84964db93789c66fbe8b2c150fb1f9f953781137 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 25 Apr 2023 17:15:51 -0700 Subject: [PATCH 27/63] CUTLASS compiles. 
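The cute kernel above computes each 128x128 output tile by marching over K in steps of 8, staging one (BLK_M,BLK_K) tile of A and one (BLK_N,BLK_K) tile of B in shared memory per step and accumulating the partial products in registers. A plain PyTorch emulation of that accumulation structure, ignoring the thread/warp partitioning and the column-major (N,K) view of B, is roughly:

import torch

def tiled_gemm_ref(A, B, bM=128, bN=128, bK=8):
    # C[m, n] = sum_k A[m, k] * B[k, n], accumulated tile by tile as in the kernel above
    M, K = A.shape
    _, N = B.shape
    C = torch.zeros(M, N)
    for m0 in range(0, M, bM):
        for n0 in range(0, N, bN):
            acc = torch.zeros(min(bM, M - m0), min(bN, N - n0))   # per-tile accumulator (tCrC)
            for k0 in range(0, K, bK):
                a_tile = A[m0:m0 + bM, k0:k0 + bK]                # gmem -> smem copy of the A tile
                b_tile = B[k0:k0 + bK, n0:n0 + bN]                # gmem -> smem copy of the B tile
                acc += a_tile @ b_tile                            # gemm on the staged tiles
            C[m0:m0 + bM, n0:n0 + bN] = acc                       # axpby with alpha=1, beta=0
    return C

A, B = torch.randn(256, 64), torch.randn(64, 256)
print(torch.allclose(tiled_gemm_ref(A, B), A @ B, atol=1e-4))     # True

In the kernel, the two inner copies and the on-tile gemm are what the tA/tB/tC thread layouts partition across the 16x16 thread block.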
--- Makefile | 7 ++++--- bitsandbytes/functional.py | 4 ++-- bitsandbytes/nn/modules.py | 1 + csrc/kernels.cu | 18 ++++++++++++------ csrc/ops.cu | 4 +--- 5 files changed, 20 insertions(+), 14 deletions(-) diff --git a/Makefile b/Makefile index a377f6548..7e8be4191 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,8 @@ MKFILE_PATH := $(abspath $(lastword $(MAKEFILE_LIST))) ROOT_DIR := $(patsubst %/,%,$(dir $(MKFILE_PATH))) -GPP:= /usr/bin/g++ +#GPP:= /usr/bin/g++ +GPP:= /sw/gcc/11.2.0/bin/g++ ifeq ($(CUDA_HOME),) CUDA_HOME:= $(shell which nvcc | rev | cut -d'/' -f3- | rev) endif @@ -25,7 +26,7 @@ FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include INCLUDE_10x := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include -INCLUDE_cutlass := -I $(ROOT_DIR)/dependencies/cutlass/include +INCLUDE_cutlass := -I $(ROOT_DIR)/dependencies/cutlass/include -I $(ROOT_DIR)/dependencies/cutlass/tools/util/include/ -I $(ROOT_DIR)/dependencies/cutlass/include/cute/util/ LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib # NVIDIA NVCC compilation flags @@ -104,7 +105,7 @@ cuda11x: $(BUILD_DIR) env cuda11x_cutlass: $(BUILD_DIR) env cutlass $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(INCLUDE_cutlass) $(LIB) --output-directory $(BUILD_DIR) $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) -std=c++20 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) + $(GPP) -std=c++17 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) cuda12x: $(BUILD_DIR) env $(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index eb4980021..80725b1e1 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -176,7 +176,7 @@ def create_custom_map(seed=0, scale=0.01): #v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.207 #v = [1.6118251211466303, 1.188665228776879, 0.9112895004060624, 0.690763326564427, 0.4997008778346997, 0.3254280317127771, 0.16057446047146948] # 0.9465 24.30 #v = [1.6027040905517569, 1.184321770169049, 0.9085808314549837, 0.6889461706317986, 0.4984841229538408, 0.32467299997597887, 0.1602117348657326] # 0.9455 24.293 - v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.37 22.88 + #v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.37 22.88 # 7B evo start #v = [1.62129629, 1.18870191, 0.90848106, 0.69108646, 0.50515268, 0.34927819905, 0.14122701] # 22.06 @@ -186,7 +186,7 @@ def create_custom_map(seed=0, scale=0.01): # 13B evo start #v = [1.6077535089716468, 1.1914902148179205, 0.8999752421085561, 0.6967904489387543, 
0.4949093928311768, 0.30920472033044544, 0.15391602735952042] #v = [1.586363722436466, 1.202610827188916, 0.9003332576346587, 0.6904888715206972, 0.49490974688233724, 0.2971151461329376, 0.15683230810738283] - #v = [1.5842247437829478, 1.2037228884260156, 0.900369059187269, 0.6898587137788914, 0.4949097822874533, 0.2959061887131868, 0.15712393618216908] + v = [1.5842247437829478, 1.2037228884260156, 0.900369059187269, 0.6898587137788914, 0.4949097822874533, 0.2959061887131868, 0.15712393618216908] # mean evo 7B + 13B #v = [1.5993337549066253, 1.1965624035328402, 0.9000864380418481, 0.6925840978034195, 0.5011181210961458, 0.32040328389777434, 0.13570386022711237] diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 24f50707b..287a46703 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -228,6 +228,7 @@ def __init__(self, input_features, output_features, bias=True, compute_dtype=Non super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4') + class Int8Params(torch.nn.Parameter): def __new__( cls, diff --git a/csrc/kernels.cu b/csrc/kernels.cu index a108772db..ed7d6b2f3 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -12,6 +12,14 @@ #include #include #include +#include +#include + +#include +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/cublas_wrappers.hpp" +#include "cutlass/util/helper_cuda.hpp" #define HLF_MAX 65504 #define TH 1024 @@ -2709,7 +2717,7 @@ template @@ -2813,7 +2821,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o float valB = local_valsB[k]; float valA = local_valA[i]; if(valB != 0.0 && valA != 0.0) - local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*C*valB*valA; + local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*DENORM*valB*valA; } else local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i]; @@ -2960,7 +2968,7 @@ void gemm_device(MShape M, NShape N, KShape K, TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA, TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB, - TC * C, CStride dC, CBlockLayout , CThreadLayout tC, + TC * out, CStride dC, CBlockLayout , CThreadLayout tC, Alpha alpha, Beta beta) { using namespace cute; @@ -2991,7 +2999,7 @@ gemm_device(MShape M, NShape N, KShape K, // Represent the full tensors auto mA = make_tensor(make_gmem_ptr(A), make_shape(M,K), dA); // (M,K) auto mB = make_tensor(make_gmem_ptr(B), make_shape(N,K), dB); // (N,K) - auto mC = make_tensor(make_gmem_ptr(C), make_shape(M,N), dC); // (M,N) + auto mC = make_tensor(make_gmem_ptr(out), make_shape(M,N), dC); // (M,N) // Get the appropriate blocks for this thread block -- // potential for thread block locality @@ -3034,7 +3042,6 @@ gemm_device(MShape M, NShape N, KShape K, // Clear the accumulators clear(tCrC); -#if 1 // TUTORIAL: Example of a very simple compute loop // Data is read from global to shared memory via the tA|tB partitioning @@ -3071,7 +3078,6 @@ gemm_device(MShape M, NShape N, KShape K, __syncthreads(); } -#endif axpby(alpha, tCrC, beta, tCgC); } diff --git a/csrc/ops.cu b/csrc/ops.cu index 1204cbda6..a3a7c29ca 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -666,11 +666,9 @@ template void extractOutliers(char * A, int *idx, char *out, int id -#include -#include - #include + template void From 0afc8e9e2f2a0a2ca707057fe6523bed98451bb6 Mon Sep 17 
00:00:00 2001 From: Tim Dettmers Date: Wed, 26 Apr 2023 17:12:34 -0700 Subject: [PATCH 28/63] Best attempt at cutlass3. --- Makefile | 8 +-- bitsandbytes/functional.py | 98 ++++++++++++++++++++++++++++ csrc/kernels.cu | 128 ++++++++++++++++++++++--------------- csrc/kernels.cuh | 22 ++++++- csrc/ops.cu | 77 +++++++++------------- csrc/ops.cuh | 12 ++++ csrc/pythonInterface.c | 18 ++++++ tests/test_functional.py | 21 ++++++ 8 files changed, 282 insertions(+), 102 deletions(-) diff --git a/Makefile b/Makefile index 7e8be4191..059545c55 100644 --- a/Makefile +++ b/Makefile @@ -55,8 +55,8 @@ CC_cublasLt110 := -gencode arch=compute_75,code=sm_75 CC_cublasLt110 += -gencode arch=compute_80,code=sm_80 CC_cublasLt111 := -gencode arch=compute_75,code=sm_75 -CC_cublasLt111 += -gencode arch=compute_80,code=sm_80 -CC_cublasLt111 += -gencode arch=compute_86,code=sm_86 +#CC_cublasLt111 += -gencode arch=compute_80,code=sm_80 +#CC_cublasLt111 += -gencode arch=compute_86,code=sm_86 CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89 CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90 @@ -103,9 +103,9 @@ cuda11x: $(BUILD_DIR) env $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) cuda11x_cutlass: $(BUILD_DIR) env cutlass - $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(INCLUDE_cutlass) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math --expt-relaxed-constexpr -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(INCLUDE_cutlass) $(LIB) --output-directory $(BUILD_DIR) $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) -std=c++17 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) + $(GPP) -std=c++17 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(INCLUDE_cutlass) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) cuda12x: $(BUILD_DIR) env $(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 80725b1e1..7e4874a2c 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1374,6 +1374,104 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 return sout +def cutlass3_gemm( + A: Tensor, + B: Tensor, + out: Tensor = None, + transposed_A=False, + transposed_B=False, +): + sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.float32) + if out is None: + out = torch.zeros(size=sout, dtype=torch.float32, device=A.device) + + sA = A.shape + sB = B.shape + if transposed_A and len(sA) == 2: + sA = (sA[1], sA[0]) + elif transposed_A and len(sA) == 3: + sA = (sA[0], sA[2], sA[0]) + if transposed_B and len(sB) == 2: + sB = (sB[1], sB[0]) + elif transposed_B and len(sB) == 3: + sB = (sB[0], sB[2], sB[0]) + # this is a mess: cuBLAS expect column major, but PyTorch is row major. 
+ # So to perform the matrix multiplication, we have to treat A, B, and C matrices + # (transpose of row major is column major) + # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these + + # matrices in the input arguments for cuBLAS + # column major: A @ B = C: [m, k] @ [k, n] = [m, n] + # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n] + # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m] + if len(sB) == 2: + if B.stride()[0] == B.shape[1]: + transposed_B = False + elif B.stride()[1] == B.shape[0]: + transposed_B = True + if len(A.shape) == 2: + if A.stride()[0] == A.shape[1]: + transposed_A = False + elif A.stride()[1] == A.shape[0]: + transposed_A = True + else: + if A.stride()[1] == A.shape[2]: + transposed_A = False + elif A.stride()[2] == A.shape[1]: + transposed_A = True + + if len(sA) == 2: + n = sA[0] + ldb = A.stride()[1 if transposed_A else 0] + elif len(sA) == 3 and len(sB) == 2: + n = sA[0] * sA[1] + ldb = sA[2] + + m = sB[1] + k = sB[0] + lda = B.stride()[(1 if transposed_B else 0)] + ldc = sB[1] + elif len(sB) == 3: + # special case + assert len(sA) == 3 + if not (sA[0] == sB[0] and sA[1] == sB[1]): + raise ValueError( + f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}" + ) + + transposed_A = True + transposed_B = False + + m = sB[2] + n = sA[2] + k = sB[0] * sB[1] + + lda = m + ldb = sA[2] + ldc = m + + ptr = CUBLAS_Context.get_instance().get_context(A.device) + + # B^T @ A^T = C^T + # [km, nk -> mn] + lda = ldb = ldc = 1 + #lda = 1 + print(m, n, k, lda, ldb, ldc) + is_on_gpu([B, A, out]) + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) + alpha = ct.c_float(1.0) + beta = ct.c_float(0.0) + lib.ccutlass_gemm(m, n, k, alpha, get_ptr(B), lda, get_ptr(A), ldb, beta, get_ptr(out), ldc) + + return out + + + def igemm( A: Tensor, diff --git a/csrc/kernels.cu b/csrc/kernels.cu index ed7d6b2f3..4c835732c 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -19,7 +19,6 @@ #include "cutlass/util/print_error.hpp" #include "cutlass/util/GPU_Clock.hpp" #include "cutlass/util/cublas_wrappers.hpp" -#include "cutlass/util/helper_cuda.hpp" #define HLF_MAX 65504 #define TH 1024 @@ -2928,73 +2927,84 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * } -template __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB) -{ -// element-wise kernel -// 1. Load batch x k into registers -// 2. Load k x k into registers -// 3. dequantize and store in second pair of k x k -// 4. matmul -// 5. sum with cub -// 6. store outputs -// TC kernel -// use k warps per thread block -// 1. threadblock use read-only cache to read in register tile for A into shared memory -// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments -// 3. each warp reads a segment of values 16x32 from B -// 4. do dequantization from register of B into second pair of registers -// 5. store (4) into fragment -// 6. matmul aggregate into fragment C -// 7. aggreecate files of C into shared memroy block C -// 8. sum (7) -// 9. write outputs to matmul output matrix -} +//template __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB) +//{ +//// element-wise kernel +//// 1. Load batch x k into registers +//// 2. Load k x k into registers +//// 3. 
dequantize and store in second pair of k x k +//// 4. matmul +//// 5. sum with cub +//// 6. store outputs +//// TC kernel +//// use k warps per thread block +//// 1. threadblock use read-only cache to read in register tile for A into shared memory +//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments +//// 3. each warp reads a segment of values 16x32 from B +//// 4. do dequantization from register of B into second pair of registers +//// 5. store (4) into fragment +//// 6. matmul aggregate into fragment C +//// 7. aggreecate files of C into shared memroy block C +//// 8. sum (7) +//// 9. write outputs to matmul output matrix +//} #include "cutlass/util/print_error.hpp" #include "cutlass/util/GPU_Clock.hpp" #if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0 # include "cutlass/util/cublas_wrappers.hpp" #endif -#include "cutlass/util/helper_cuda.hpp" - -template -__global__ static -__launch_bounds__(decltype(size(CThreadLayout{}))::value) -void -gemm_device(MShape M, NShape N, KShape K, - TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA, - TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB, - TC * out, CStride dC, CBlockLayout , CThreadLayout tC, - Alpha alpha, Beta beta) +//#include "cutlass/util/helper_cuda.hpp" + +__global__ void gemm_device(int M, int N, int K, + float const* A, + float const* B, + float * out, int lda, int ldb, int ldc, + float alpha, float beta) { using namespace cute; using X = Underscore; // Preconditions - CUTE_STATIC_ASSERT(is_static::value); - CUTE_STATIC_ASSERT(is_static::value); - CUTE_STATIC_ASSERT(is_static::value); + //CUTE_STATIC_ASSERT(is_static::value); + //CUTE_STATIC_ASSERT(is_static::value); + //CUTE_STATIC_ASSERT(is_static::value); + + //CUTE_STATIC_ASSERT(is_static::value); + //CUTE_STATIC_ASSERT(is_static::value); + //CUTE_STATIC_ASSERT(is_static::value); - CUTE_STATIC_ASSERT(is_static::value); - CUTE_STATIC_ASSERT(is_static::value); - CUTE_STATIC_ASSERT(is_static::value); + //CUTE_STATIC_ASSERT_V(size(tA) == size(tC)); + //CUTE_STATIC_ASSERT_V(size(tB) == size(tC)); - CUTE_STATIC_ASSERT_V(size(tA) == size(tC)); - CUTE_STATIC_ASSERT_V(size(tB) == size(tC)); + // Define block sizes (static) + auto bM = Int<128>{}; + auto bN = Int<128>{}; + auto bK = Int< 8>{}; + + // Define the block layouts (static) + auto bA = make_layout(make_shape(bM,bK)); + auto bB = make_layout(make_shape(bN,bK)); + auto bC = make_layout(make_shape(bM,bN)); + + // Define the thread layouts (static) + auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{})); + auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{})); + auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); //CUTE_STATIC_ASSERT_V(shape<0>(blockA) == shape<0>(blockC)); // BLK_M //CUTE_STATIC_ASSERT_V(shape<0>(blockB) == shape<1>(blockC)); // BLK_N - CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB)); // BLK_K + //CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB)); // BLK_K // Shared memory buffers - __shared__ TA smemA[cosize_v]; - __shared__ TB smemB[cosize_v]; - auto sA = make_tensor(make_smem_ptr(smemA), blockA); // (BLK_M,BLK_K) - auto sB = make_tensor(make_smem_ptr(smemB), blockB); // (BLK_N,BLK_K) + __shared__ float smemA[128*8]; + __shared__ float smemB[128*8]; + auto sA = make_tensor(make_smem_ptr(smemA), bA); // (BLK_M,BLK_K) + auto sB = make_tensor(make_smem_ptr(smemB), bB); // (BLK_N,BLK_K) + + auto dA = make_stride(Int<1>{}, lda); + auto dB = make_stride(Int<1>{}, ldb); + auto dC = make_stride(Int<1>{}, 
ldc); // Represent the full tensors auto mA = make_tensor(make_gmem_ptr(A), make_shape(M,K), dA); // (M,K) @@ -3083,11 +3093,27 @@ gemm_device(MShape M, NShape N, KShape K, } + //============================================================== // TEMPLATE DEFINITIONS //============================================================== -template __global__ void kMatmul_inference_4bit(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB); +//template +//__global__ static +//__launch_bounds__(decltype(size(CThreadLayout{}))::value) +//void +//gemm_device(MShape M, NShape N, KShape K, +// TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA, +// TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB, +// TC * out, CStride dC, CBlockLayout , CThreadLayout tC, +// half alpha, half beta); + + +//template __global__ void kMatmul_inference_4bit(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index ecf3a0991..ba6de590a 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -9,7 +9,7 @@ #ifndef kernels #define kernels -template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); +//template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); template__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n); @@ -122,4 +122,24 @@ template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); +//template +//__global__ static +//__launch_bounds__(decltype(size(CThreadLayout{}))::value) +//void +//gemm_device(MShape M, NShape N, KShape K, +// TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA, +// TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB, +// TC * out, CStride dC, CBlockLayout , CThreadLayout tC, +// Alpha alpha, Beta beta); + +__global__ void gemm_device(int M, int N, int K, + float const* A, + float const* B, + float * out, int lda, int ldb, int ldc, + float alpha, float beta); + #endif diff --git a/csrc/ops.cu b/csrc/ops.cu index a3a7c29ca..ca56faea7 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -91,14 +91,12 @@ template void dequantizeBlockwise(float *code, unsign } -void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB) -{ - int num_blocks = (colsB+32-1)/32; - kMatmul_inference_4bit<<>>(A, B, out, lda, ldb, rowsA, colsA, colsB); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); -} - -template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *C, int lda, int ldb, int rowsA, int colsA, int colsB); +//void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB) +//{ +// int num_blocks = (colsB+32-1)/32; +// kMatmul_inference_4bit<<>>(A, B, out, lda, ldb, rowsA, colsA, colsB); +// CUDA_CHECK_RETURN(cudaPeekAtLastError()); +//} template void optimizer32bit(T* g, T* p, @@ -666,60 +664,47 @@ template void extractOutliers(char * A, int 
*idx, char *out, int id + #include +#include "cutlass/util/helper_cuda.hpp" -template -void -gemm(int m, int n, int k, - Alpha alpha, - TA const* A, int ldA, - TB const* B, int ldB, - Beta beta, - TC * C, int ldC, - cudaStream_t stream = 0) +void gemm_host(int m, int n, int k, + float alpha, + float const* A, int lda, + float const* B, int ldb, + float beta, + float * C, int ldc) { + cute::device_init(0); using namespace cute; + + // Define shapes (dynamic) auto M = int(m); auto N = int(n); auto K = int(k); - // Define strides (mixed) - auto dA = make_stride(Int<1>{}, ldA); - auto dB = make_stride(Int<1>{}, ldB); - auto dC = make_stride(Int<1>{}, ldC); - - // Define block sizes (static) - auto bM = Int<128>{}; - auto bN = Int<128>{}; - auto bK = Int< 8>{}; - - // Define the block layouts (static) - auto sA = make_layout(make_shape(bM,bK)); - auto sB = make_layout(make_shape(bN,bK)); - auto sC = make_layout(make_shape(bM,bN)); - - // Define the thread layouts (static) - auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{})); - auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{})); - auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); - - dim3 dimBlock(size(tC)); - dim3 dimGrid(ceil_div(size(M), size(bM)), - ceil_div(size(N), size(bN))); + + printf("%i %i %i %i %i %i\n", m, n, k, lda, ldb, ldc); + + dim3 dimBlock(16, 16); + dim3 dimGrid((M+127)/128, (N+127)/128); +// auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); +//- +//- dim3 dimBlock(size(tC)); +//- dim3 dimGrid(ceil_div(size(M), size(bM)), +//- ceil_div(size(N), size(bN))); gemm_device - <<< dimGrid, dimBlock, 0, stream >>> + <<< dimGrid, dimBlock, 0, 0 >>> (M, N, K, - A, dA, sA, tA, - B, dB, sB, tB, - C, dC, sC, tC, + A, + B, + C, lda, ldb, ldc, alpha, beta); } - //============================================================== // TEMPLATE DEFINITIONS //============================================================== diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 137320ba1..843a9bbee 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -20,6 +20,11 @@ #include #include +#include +#include + + + #define CUDA_CHECK_RETURN(value) { \ cudaError_t _m_cudaStat = value; \ if (_m_cudaStat != cudaSuccess) { \ @@ -185,4 +190,11 @@ template void extractOutliers(char * A, int *idx, char *out, int id void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); +void gemm_host(int m, int n, int k, + float alpha, + float const* A, int ldA, + float const* B, int ldB, + float beta, + float * C, int ldC); + #endif diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 0e9106c40..c6de62d1d 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -20,6 +20,16 @@ void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimat void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } +void +cppgemm(int m, int n, int k, + float alpha, + float const* A, int ldA, + float const* B, int ldB, + float beta, + float * C, int ldC) +{ gemm_host(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);} + + #define MAKE_FUNC32(fname, oname, gtype, gbits) \ void fname##32bit_g##gbits(gtype *g, gtype *p, \ float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ @@ -306,6 +316,14 @@ extern "C" void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); } void cextractOutliers_ampere(char * A, int *idx, char *out, int 
idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); } + void ccutlass_gemm(int m, int n, int k, + float alpha, + float const* A, int ldA, + float const* B, int ldB, + float beta, + float * C, int ldC) + { cppgemm(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);} + #endif void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); } diff --git a/tests/test_functional.py b/tests/test_functional.py index 82f6a71df..128c80347 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2351,3 +2351,24 @@ def test_normal_map_tree(): pivots.append((values[i-1]+values[i])/2) print(pivots) + +def test_cutlass3_gemm(): + #A = torch.rand(2, 2).cuda() + #B = torch.rand(2, 2).cuda() + A = torch.arange(4).reshape(2, 2).float().cuda().contiguous() + B = torch.ones(2, 2).float().cuda() + + print('') + print(A) + print(B) + + C1 = torch.matmul(A, B) + print(C1) + C2 = F.cutlass3_gemm(A, B.t()) + print(C2) + C2 = F.cutlass3_gemm(A, B) + print(C2) + C2 = F.cutlass3_gemm(B.t(), A.t().contiguous()) + print(C2) + + From d1c4c2056893c35a7ca8e55a1b2beebeeeaee679 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Thu, 27 Apr 2023 15:11:26 -0700 Subject: [PATCH 29/63] Added non-cutlass template. --- Makefile | 14 +--- bitsandbytes/functional.py | 4 +- csrc/kernels.cu | 152 +++++-------------------------------- csrc/ops.cu | 28 ++----- tests/test_functional.py | 6 -- 5 files changed, 32 insertions(+), 172 deletions(-) diff --git a/Makefile b/Makefile index 059545c55..ea6ee87d5 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,8 @@ MKFILE_PATH := $(abspath $(lastword $(MAKEFILE_LIST))) ROOT_DIR := $(patsubst %/,%,$(dir $(MKFILE_PATH))) -#GPP:= /usr/bin/g++ -GPP:= /sw/gcc/11.2.0/bin/g++ +GPP:= /usr/bin/g++ +#GPP:= /sw/gcc/11.2.0/bin/g++ ifeq ($(CUDA_HOME),) CUDA_HOME:= $(shell which nvcc | rev | cut -d'/' -f3- | rev) endif @@ -26,7 +26,6 @@ FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include INCLUDE_10x := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include -INCLUDE_cutlass := -I $(ROOT_DIR)/dependencies/cutlass/include -I $(ROOT_DIR)/dependencies/cutlass/tools/util/include/ -I $(ROOT_DIR)/dependencies/cutlass/include/cute/util/ LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib # NVIDIA NVCC compilation flags @@ -63,8 +62,8 @@ CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90 all: $(BUILD_DIR) env - $(NVCC) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) - $(NVCC) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o 
./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env @@ -102,11 +101,6 @@ cuda11x: $(BUILD_DIR) env $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) -cuda11x_cutlass: $(BUILD_DIR) env cutlass - $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math --expt-relaxed-constexpr -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(INCLUDE_cutlass) $(LIB) --output-directory $(BUILD_DIR) - $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) -std=c++17 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(INCLUDE_cutlass) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) - cuda12x: $(BUILD_DIR) env $(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) $(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 7e4874a2c..54a08a15d 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1456,7 +1456,7 @@ def cutlass3_gemm( # [km, nk -> mn] lda = ldb = ldc = 1 #lda = 1 - print(m, n, k, lda, ldb, ldc) + #print(m, n, k, lda, ldb, ldc) is_on_gpu([B, A, out]) m = ct.c_int32(m) n = ct.c_int32(n) @@ -1466,7 +1466,7 @@ def cutlass3_gemm( ldc = ct.c_int32(ldc) alpha = ct.c_float(1.0) beta = ct.c_float(0.0) - lib.ccutlass_gemm(m, n, k, alpha, get_ptr(B), lda, get_ptr(A), ldb, beta, get_ptr(out), ldc) + lib.ccutlass_gemm(m, n, k, alpha, get_ptr(A), ldb, get_ptr(B), lda, beta, get_ptr(out), ldc) return out diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 4c835732c..ed87c69c6 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -15,11 +15,6 @@ #include #include -#include -#include "cutlass/util/print_error.hpp" -#include "cutlass/util/GPU_Clock.hpp" -#include "cutlass/util/cublas_wrappers.hpp" - #define HLF_MAX 65504 #define TH 1024 #define NUM 4 @@ -2949,147 +2944,42 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * //// 9. 
write outputs to matmul output matrix //} -#include "cutlass/util/print_error.hpp" -#include "cutlass/util/GPU_Clock.hpp" -#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0 -# include "cutlass/util/cublas_wrappers.hpp" -#endif -//#include "cutlass/util/helper_cuda.hpp" - __global__ void gemm_device(int M, int N, int K, float const* A, float const* B, float * out, int lda, int ldb, int ldc, float alpha, float beta) { - using namespace cute; - using X = Underscore; - - // Preconditions - //CUTE_STATIC_ASSERT(is_static::value); - //CUTE_STATIC_ASSERT(is_static::value); - //CUTE_STATIC_ASSERT(is_static::value); - - //CUTE_STATIC_ASSERT(is_static::value); - //CUTE_STATIC_ASSERT(is_static::value); - //CUTE_STATIC_ASSERT(is_static::value); - - //CUTE_STATIC_ASSERT_V(size(tA) == size(tC)); - //CUTE_STATIC_ASSERT_V(size(tB) == size(tC)); - - // Define block sizes (static) - auto bM = Int<128>{}; - auto bN = Int<128>{}; - auto bK = Int< 8>{}; - - // Define the block layouts (static) - auto bA = make_layout(make_shape(bM,bK)); - auto bB = make_layout(make_shape(bN,bK)); - auto bC = make_layout(make_shape(bM,bN)); - - // Define the thread layouts (static) - auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{})); - auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{})); - auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); - - //CUTE_STATIC_ASSERT_V(shape<0>(blockA) == shape<0>(blockC)); // BLK_M - //CUTE_STATIC_ASSERT_V(shape<0>(blockB) == shape<1>(blockC)); // BLK_N - //CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB)); // BLK_K - - // Shared memory buffers - __shared__ float smemA[128*8]; - __shared__ float smemB[128*8]; - auto sA = make_tensor(make_smem_ptr(smemA), bA); // (BLK_M,BLK_K) - auto sB = make_tensor(make_smem_ptr(smemB), bB); // (BLK_N,BLK_K) - - auto dA = make_stride(Int<1>{}, lda); - auto dB = make_stride(Int<1>{}, ldb); - auto dC = make_stride(Int<1>{}, ldc); - - // Represent the full tensors - auto mA = make_tensor(make_gmem_ptr(A), make_shape(M,K), dA); // (M,K) - auto mB = make_tensor(make_gmem_ptr(B), make_shape(N,K), dB); // (N,K) - auto mC = make_tensor(make_gmem_ptr(out), make_shape(M,N), dC); // (M,N) - - // Get the appropriate blocks for this thread block -- - // potential for thread block locality - auto blk_shape = make_shape(size<0>(sA), size<0>(sB), size<1>(sB));// (BLK_M,BLK_N,BLK_K) - auto blk_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k) - - auto gA = local_tile(mA, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) - auto gB = local_tile(mB, blk_shape, blk_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) - auto gC = local_tile(mC, blk_shape, blk_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) - - // - // Partition the copying of A and B tiles across the threads - // - - // TUTORIAL: Example of simple partitioning of A|B tiles over tA|tB - // Default is a raked partition, but can be changed with Step parameter - - auto tAgA = local_partition(gA, tA, threadIdx.x); // (THR_M,THR_K,k) - auto tAsA = local_partition(sA, tA, threadIdx.x); // (THR_M,THR_K) - - auto tBgB = local_partition(gB, tB, threadIdx.x); // (THR_N,THR_K,k) - auto tBsB = local_partition(sB, tB, threadIdx.x); // (THR_N,THR_K) - - // - // Define C accumulators and A/B partitioning - // - - // TUTORIAL: Example of partitioning via projections of tC +// 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp +// 1. Load dataB into register +// 2. Dequantize B +// 3. 
Fetch data from A and multiply - // Partition sA (M,K) by the rows of tC - auto tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{}); // (THR_M,BLK_K) - // Partition sB (N,K) by the cols of tC - auto tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{}); // (THR_N,BLK_K) - // Partition gC (M,N) by the tile of tC - auto tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{}); // (THR_M,THR_N) + typedef cub::BlockLoad LoadA; + __shared__ typename LoadA::TempStorage loada; + float dataA[1]; + int valid_items = 0; - // Allocate the accumulators -- same size as the projected data - auto tCrC = make_fragment_like(tCgC); // (THR_M,THR_N) + __shared__ float[16*256] tileA; - // Clear the accumulators - clear(tCrC); + for(int idxA = 0; idxA < M*K; idxA+= 256) + { + valid_items = M*K - idxA > 256 ? 256 : M*K - idxA; + int baserow = 0; + for(int row = baserow; row < baserow+16 && row < M + ; row++) + { + LoadA(loada).Load(&(A[(row*lda) + i]), dataA, valid_items, 0.0f); + tileA[row*256 + threadIdx.x] = dataA[0]; + __syncthreads(); + } + baserow += 16; - // TUTORIAL: Example of a very simple compute loop - // Data is read from global to shared memory via the tA|tB partitioning - // gemm(.) operates on the shared memory directly via the tC partitioning - - auto k_max = size<2>(tAgA); - - for (int k = 0; k < k_max; ++k) - { - // Copy gmem to smem - copy(tAgA(_,_,k), tAsA); - copy(tBgB(_,_,k), tBsB); - - // In case copy uses cp.async, make sure that the cp.async - // instructions are ordered with respect to other cp.async - // instructions (fence), then wait on all the outstanding copy - // operations (wait<0>()). __syncthreads() alone does not do - // this. - // - // NOTE: cp_async_wait<0>() currently issues cp.async.wait_all. - // This is equivalent to cp.async.commit_group followed by - // cp.async_wait_group 0. This should make the first - // cp_async_fence() (which also issues cp.async.commit_group) - // redundant. The tutorial works as-is, so we'll leave the - // redundant fence in for now and study its removal later. 
- cp_async_fence(); - cp_async_wait<0>(); - - __syncthreads(); - // Compute gemm on smem - gemm(tCsA, tCsB, tCrC); + } - __syncthreads(); - } - axpby(alpha, tCrC, beta, tCgC); } diff --git a/csrc/ops.cu b/csrc/ops.cu index ca56faea7..8933927f6 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -665,9 +665,6 @@ template void extractOutliers(char * A, int *idx, char *out, int id -#include -#include "cutlass/util/helper_cuda.hpp" - void gemm_host(int m, int n, int k, float alpha, @@ -676,29 +673,14 @@ void gemm_host(int m, int n, int k, float beta, float * C, int ldc) { - cute::device_init(0); - using namespace cute; - - - - // Define shapes (dynamic) - auto M = int(m); - auto N = int(n); - auto K = int(k); - - printf("%i %i %i %i %i %i\n", m, n, k, lda, ldb, ldc); + dim3 dimBlock(256); + int num_blocks = (n+31)/32; - dim3 dimBlock(16, 16); - dim3 dimGrid((M+127)/128, (N+127)/128); -// auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); -//- -//- dim3 dimBlock(size(tC)); -//- dim3 dimGrid(ceil_div(size(M), size(bM)), -//- ceil_div(size(N), size(bN))); + cout << num_blocks << endl; gemm_device - <<< dimGrid, dimBlock, 0, 0 >>> - (M, N, K, + <<< num_blocks, dimBlock, 0, 0 >>> + (m, n, k, A, B, C, lda, ldb, ldc, diff --git a/tests/test_functional.py b/tests/test_functional.py index 128c80347..dd41972e9 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2363,12 +2363,6 @@ def test_cutlass3_gemm(): print(B) C1 = torch.matmul(A, B) - print(C1) - C2 = F.cutlass3_gemm(A, B.t()) - print(C2) C2 = F.cutlass3_gemm(A, B) - print(C2) - C2 = F.cutlass3_gemm(B.t(), A.t().contiguous()) - print(C2) From 9cab14a3ff920a153fb450e299329a473f1416a4 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Thu, 27 Apr 2023 15:12:49 -0700 Subject: [PATCH 30/63] Adedd pipeline draft. 
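Note on this draft: the kernel added below stages batches of the input through shared memory with cuda::memcpy_async and a multi-stage cuda::pipeline, so the copy of the next batch can overlap with compute on the current one. As a minimal, single-stage sketch of the same acquire/commit/wait/release protocol (the kernel name and launch shape here are illustrative only, not part of this patch):

    // stage one block-sized tile into shared memory before touching it;
    // launch with dynamic smem: staged_copy<<<n/threads, threads, threads*sizeof(float)>>>(in, out, n)
    #include <cooperative_groups.h>
    #include <cuda/pipeline>

    __global__ void staged_copy(const float *in, float *out, size_t n)
    {
      auto block = cooperative_groups::this_thread_block();
      extern __shared__ float tile[];                    // block.size() floats

      __shared__ cuda::pipeline_shared_state<cuda::thread_scope::thread_scope_block, 1> state;
      auto pipe = cuda::make_pipeline(block, &state);

      size_t base = (size_t)blockIdx.x * block.size();
      if (base + block.size() > n) return;               // full tiles only in this sketch

      pipe.producer_acquire();
      cuda::memcpy_async(block, tile, in + base, sizeof(float) * block.size(), pipe);
      pipe.producer_commit();

      pipe.consumer_wait();                              // tile is now resident in smem
      out[base + threadIdx.x] = tile[threadIdx.x];       // stand-in for real compute
      pipe.consumer_release();
    }

With stages_count > 1, as in with_staging_unified below, producer_acquire/commit for batch i+1 can be issued before consumer_wait on batch i, which is what hides the global-memory latency.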
--- bitsandbytes/functional.py | 5 ++++ csrc/kernels.cu | 49 ++++++++++++++++++++++++++++++++++++++ csrc/kernels.cuh | 2 ++ csrc/ops.cu | 11 +++++++++ csrc/ops.cuh | 2 ++ csrc/pythonInterface.c | 1 + tests/test_functional.py | 5 ++++ 7 files changed, 75 insertions(+) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 54a08a15d..bb3cde3dd 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2341,3 +2341,8 @@ def extract_outliers(A, SA, idx): post_call(prev_device) return out + +def pipeline_test(A, batch_size): + out = torch.zeros_like(A) + lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) + return out diff --git a/csrc/kernels.cu b/csrc/kernels.cu index ed87c69c6..775716f80 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -15,6 +15,9 @@ #include #include +#include +#include + #define HLF_MAX 65504 #define TH 1024 #define NUM 4 @@ -2983,6 +2986,51 @@ __global__ void gemm_device(int M, int N, int K, } +__device__ void compute(float* global_out, float const* shared_in) +{ + +} +template +__global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz) { + auto grid = cooperative_groups::this_grid(); + auto block = cooperative_groups::this_thread_block(); + assert(size == batch_sz * grid.size()); // Assume input size fits batch_sz * grid_size + + extern __shared__ float shared[]; // stages_count * block.size() * sizeof(int) bytes + size_t shared_offset[stages_count]; + for (int s = 0; s < stages_count; ++s) shared_offset[s] = s * block.size(); + + __shared__ cuda::pipeline_shared_state< + cuda::thread_scope::thread_scope_block, + stages_count + > shared_state; + auto pipeline = cuda::make_pipeline(block, &shared_state); + + auto block_batch = [&](size_t batch) -> int { + return block.group_index().x * block.size() + grid.size() * batch; + }; + + // compute_batch: next batch to process + // fetch_batch: next batch to fetch from global memory + for (size_t compute_batch = 0, fetch_batch = 0; compute_batch < batch_sz; ++compute_batch) { + // The outer loop iterates over the computation of the batches + for (; fetch_batch < batch_sz && fetch_batch < (compute_batch + stages_count); ++fetch_batch) { + // This inner loop iterates over the memory transfers, making sure that the pipeline is always full + pipeline.producer_acquire(); + size_t shared_idx = fetch_batch % stages_count; + size_t batch_idx = fetch_batch; + size_t block_batch_idx = block_batch(batch_idx); + cuda::memcpy_async(block, shared + shared_offset[shared_idx], global_in + block_batch_idx, sizeof(float) * block.size(), pipeline); + pipeline.producer_commit(); + } + pipeline.consumer_wait(); + int shared_idx = compute_batch % stages_count; + int batch_idx = compute_batch; + compute(global_out + block_batch(batch_idx), shared + shared_offset[shared_idx]); + pipeline.consumer_release(); + } +} + //============================================================== // TEMPLATE DEFINITIONS @@ -3004,6 +3052,7 @@ __global__ void gemm_device(int M, int N, int K, //template __global__ void kMatmul_inference_4bit(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB); +template __global__ void with_staging_unified<2>(float const* global_in, float * global_out, size_t size, size_t batch_sz); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template __global__ void kExtractOutliers(char 
*A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index ba6de590a..37e214a4a 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -135,6 +135,8 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * // TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB, // TC * out, CStride dC, CBlockLayout , CThreadLayout tC, // Alpha alpha, Beta beta); +template +__global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz); __global__ void gemm_device(int M, int N, int K, float const* A, diff --git a/csrc/ops.cu b/csrc/ops.cu index 8933927f6..ee585bb4c 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -663,6 +663,17 @@ template void extractOutliers(char * A, int *idx, char *out, int id } +void pipeline_test(float *A, float *B, size_t n, size_t batch_size) +{ + + int threads = 256; + int num_blocks = (n+(256*batch_size)+1)/(batch_size*256); + + printf("%i %i\n", num_blocks, batch_size); + + with_staging_unified<2><<>>(A, B, n, batch_size); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 843a9bbee..83dd4e515 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -197,4 +197,6 @@ void gemm_host(int m, int n, int k, float beta, float * C, int ldC); + +void pipeline_test(float *A, float *B, size_t n, size_t batch_size); #endif diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index c6de62d1d..170093f0f 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -315,6 +315,7 @@ extern "C" void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); } void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); } + void cpipeline_test(float *A, float *B, size_t n, size_t batch_size){ pipeline_test(A, B, n, batch_size); } void ccutlass_gemm(int m, int n, int k, float alpha, diff --git a/tests/test_functional.py b/tests/test_functional.py index dd41972e9..7dec375f6 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2366,3 +2366,8 @@ def test_cutlass3_gemm(): C2 = F.cutlass3_gemm(A, B) +def test_pipeline_func(): + a = torch.rand(2, 4).cuda() + out = F.pipeline_test(a, 2) + print(a) + print(out) From c1bfb210c59dc56559b571a927714ca13cea80c5 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Fri, 28 Apr 2023 17:19:02 -0700 Subject: [PATCH 31/63] First baseline kernel. 
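This baseline treats the matmul as a set of block-wide dot products: each block owns 8 output columns, threads accumulate strided partial products from a shared-memory tile of A, and cub::BlockReduce collapses the partials. A stripped-down sketch of just that reduction idiom (no A tile, one output per block; row_dot and the launch shape are illustrative names, not part of the patch):

    // out[blockIdx.x] = dot(A[0, :], B[blockIdx.x, :]) for row-major B (M x K);
    // launch as row_dot<256><<<M, 256>>>(A, B, out, K)
    #include <cub/cub.cuh>

    template <int THREADS>
    __global__ void row_dot(const float *A, const float *B, float *out, int K)
    {
      typedef cub::BlockReduce<float, THREADS> BlockReduce;
      __shared__ typename BlockReduce::TempStorage temp_storage;

      const float *Brow = B + (size_t)blockIdx.x * K;
      float partial = 0.0f;
      for (int i = threadIdx.x; i < K; i += THREADS)      // strided partial products
        partial += A[i] * Brow[i];

      float total = BlockReduce(temp_storage).Reduce(partial, cub::Sum());
      if (threadIdx.x == 0)                               // reduction result is only valid in thread 0
        out[blockIdx.x] = total;
    }

The kernel in this patch additionally tiles A through shared memory and accumulates into a small per-block buffer with atomicAdd so the same A tile can be reused across its 8 columns.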
--- bitsandbytes/functional.py | 8 +-- csrc/kernels.cu | 103 ++++++++++++++++++++++++++++++++----- csrc/kernels.cuh | 2 +- csrc/ops.cu | 13 +++-- csrc/ops.cuh | 2 +- csrc/pythonInterface.c | 4 +- tests/test_functional.py | 20 +++---- 7 files changed, 119 insertions(+), 33 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index bb3cde3dd..774e9547a 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1429,7 +1429,7 @@ def cutlass3_gemm( m = sB[1] k = sB[0] - lda = B.stride()[(1 if transposed_B else 0)] + lda = B.stride()[0] ldc = sB[1] elif len(sB) == 3: # special case @@ -1446,7 +1446,7 @@ def cutlass3_gemm( n = sA[2] k = sB[0] * sB[1] - lda = m + lda = n ldb = sA[2] ldc = m @@ -1454,7 +1454,7 @@ def cutlass3_gemm( # B^T @ A^T = C^T # [km, nk -> mn] - lda = ldb = ldc = 1 + #lda = ldb = ldc = 1 #lda = 1 #print(m, n, k, lda, ldb, ldc) is_on_gpu([B, A, out]) @@ -1466,7 +1466,7 @@ def cutlass3_gemm( ldc = ct.c_int32(ldc) alpha = ct.c_float(1.0) beta = ct.c_float(0.0) - lib.ccutlass_gemm(m, n, k, alpha, get_ptr(A), ldb, get_ptr(B), lda, beta, get_ptr(out), ldc) + lib.ccutlass_gemm(m, n, k, alpha, get_ptr(A), lda, get_ptr(B), ldb, beta, get_ptr(out), ldc) return out diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 775716f80..91169dd72 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2947,9 +2947,11 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * //// 9. write outputs to matmul output matrix //} + +#define ROWS 2 __global__ void gemm_device(int M, int N, int K, float const* A, - float const* B, + float* B, float * out, int lda, int ldb, int ldc, float alpha, float beta) { @@ -2958,29 +2960,106 @@ __global__ void gemm_device(int M, int N, int K, // 2. Dequantize B // 3. Fetch data from A and multiply - typedef cub::BlockLoad LoadA; - __shared__ typename LoadA::TempStorage loada; - float dataA[1]; - int valid_items = 0; + typedef cub::BlockLoad LoadA; + //__shared__ typename LoadA::TempStorage loada; + typedef cub::BlockLoad LoadB; + //__shared__ typename LoadB::TempStorage loadb; + typedef cub::BlockReduce BlockReduce; + // Allocate shared memory for BlockReduce + //__shared__ typename BlockReduce::TempStorage reduce; + + __shared__ union { + typename BlockReduce::TempStorage reduce; + typename LoadB::TempStorage loadb; + typename LoadA::TempStorage loada; + } temp_storage; + - __shared__ float[16*256] tileA; + float dataA[4]; + float local_B[4]; + float local_accC[ROWS]; + int valid_items = 0; + const int warp_id = threadIdx.x/32; + const int warp_lane = threadIdx.x % 32; + const int col_offset = blockIdx.x * 8; + + __shared__ float tileA[ROWS*1024]; + __shared__ float accumulatorC[ROWS*8]; + + //#pragma unroll 8 + //for(int i = 0; i < 8; i++) + // tileA[threadIdx.x + (i*256)] = 0.0f; + //__syncthreads(); + if(threadIdx.x < 64) + accumulatorC[threadIdx.x] = 0.0f; + __syncthreads(); - for(int idxA = 0; idxA < M*K; idxA+= 256) + for(int inner_idx = 0; inner_idx < K; inner_idx+= 1024) { - valid_items = M*K - idxA > 256 ? 256 : M*K - idxA; + valid_items = K - inner_idx > 1024 ? 
1024 : K - inner_idx; int baserow = 0; - for(int row = baserow; row < baserow+16 && row < M + ; row++) + for(int row = baserow; row < (baserow+ROWS) && row < N; row++) { - LoadA(loada).Load(&(A[(row*lda) + i]), dataA, valid_items, 0.0f); - tileA[row*256 + threadIdx.x] = dataA[0]; + LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f); + + #pragma unroll 4 + for(int k = 0; k < 4; k++) + tileA[row*1024 + threadIdx.x + (k*blockDim.x)] = dataA[k]; + __syncthreads(); } - baserow += 16; + baserow += ROWS; + + // load 16 columns from B at a time. B is transposed, so its like loading rows + // each warp loads one row + // each thread loads 128 byte + + // col: inner_idx + warp_lane + // row: ldb*(offset + warp_id) + for(int col = 0; col < 8 && (col_offset + col) < M; col++) + { + int colB = col_offset + col; + + for(int k = 0; k < ROWS; k++) + local_accC[k] = 0.0f; + int base_idxB = ldb*colB; + valid_items = K - inner_idx > 1024 ? 1024 : K - inner_idx; + LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f); + __syncthreads(); + + for(int row = 0; row < ROWS && row < N; row++) + { + #pragma unroll 4 + for(int k = 0; k < 4; k++) + { + int idxA = row*1024 + threadIdx.x + (blockDim.x*k); + local_accC[row] += tileA[idxA]*local_B[k]; + } + local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum()); + if(threadIdx.x == 0) + atomicAdd(&accumulatorC[row*8 + col], local_accC[row]); + } + } } + for(int row = 0; row < ROWS && row < N; row++) + { + int out_idx = ldc*row + col_offset; + + //if(threadIdx.x < 8) + // if(accumulatorC[row*8 + threadIdx.x] != 0.0) + // printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x); + + if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M) + { + //printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx); + out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x]; + } + } + } diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 37e214a4a..55397e7dc 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -140,7 +140,7 @@ __global__ void with_staging_unified(float const* global_in, float * global_out, __global__ void gemm_device(int M, int N, int K, float const* A, - float const* B, + float * B, float * out, int lda, int ldb, int ldc, float alpha, float beta); diff --git a/csrc/ops.cu b/csrc/ops.cu index ee585bb4c..dd8fadebe 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -669,8 +669,6 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size) int threads = 256; int num_blocks = (n+(256*batch_size)+1)/(batch_size*256); - printf("%i %i\n", num_blocks, batch_size); - with_staging_unified<2><<>>(A, B, n, batch_size); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -680,15 +678,22 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size) void gemm_host(int m, int n, int k, float alpha, float const* A, int lda, - float const* B, int ldb, + float * B, int ldb, float beta, float * C, int ldc) { dim3 dimBlock(256); - int num_blocks = (n+31)/32; + int num_blocks = (m+7)/8; cout << num_blocks << endl; + cout << lda << endl; + cout << ldb << endl; + cout << ldc << endl; + + cout << m << endl; + cout << n << endl; + cout << k << endl; gemm_device <<< num_blocks, dimBlock, 0, 0 >>> (m, n, k, diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 83dd4e515..2f71966c7 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -193,7 +193,7 @@ void 
matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows void gemm_host(int m, int n, int k, float alpha, float const* A, int ldA, - float const* B, int ldB, + float * B, int ldB, float beta, float * C, int ldC); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 170093f0f..6ec550197 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -24,7 +24,7 @@ void cppgemm(int m, int n, int k, float alpha, float const* A, int ldA, - float const* B, int ldB, + float * B, int ldB, float beta, float * C, int ldC) { gemm_host(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);} @@ -320,7 +320,7 @@ extern "C" void ccutlass_gemm(int m, int n, int k, float alpha, float const* A, int ldA, - float const* B, int ldB, + float * B, int ldB, float beta, float * C, int ldC) { cppgemm(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);} diff --git a/tests/test_functional.py b/tests/test_functional.py index 7dec375f6..087bc849b 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2353,17 +2353,19 @@ def test_normal_map_tree(): def test_cutlass3_gemm(): - #A = torch.rand(2, 2).cuda() - #B = torch.rand(2, 2).cuda() - A = torch.arange(4).reshape(2, 2).float().cuda().contiguous() - B = torch.ones(2, 2).float().cuda() + A = torch.rand(2, 4092).cuda() + B = torch.rand(4*4092, 4092).cuda() - print('') - print(A) - print(B) + #print('') + #print(A) + #print(B.t()) - C1 = torch.matmul(A, B) - C2 = F.cutlass3_gemm(A, B) + C1 = torch.matmul(A, B.t()) + C2 = F.cutlass3_gemm(A, B.t()) + #print(C1) + #print(C2) + + torch.testing.assert_close(C1, C2) def test_pipeline_func(): From 3aef78342aec4fff1922c0c2cdd83bdda928b536 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Fri, 28 Apr 2023 17:34:08 -0700 Subject: [PATCH 32/63] Added template refactor. --- bitsandbytes/functional.py | 4 +--- csrc/kernels.cu | 23 ++++++++++------------- csrc/kernels.cuh | 6 +----- csrc/ops.cu | 11 +++-------- csrc/ops.cuh | 7 +------ csrc/pythonInterface.c | 19 ++++--------------- 6 files changed, 20 insertions(+), 50 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 774e9547a..da4e66c70 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1464,9 +1464,7 @@ def cutlass3_gemm( lda = ct.c_int32(lda) ldb = ct.c_int32(ldb) ldc = ct.c_int32(ldc) - alpha = ct.c_float(1.0) - beta = ct.c_float(0.0) - lib.ccutlass_gemm(m, n, k, alpha, get_ptr(A), lda, get_ptr(B), ldb, beta, get_ptr(out), ldc) + lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) return out diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 91169dd72..45db4485f 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2949,22 +2949,18 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * #define ROWS 2 -__global__ void gemm_device(int M, int N, int K, - float const* A, - float* B, - float * out, int lda, int ldb, int ldc, - float alpha, float beta) +template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc) { // 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp // 1. Load dataB into register // 2. Dequantize B // 3. 
Fetch data from A and multiply - typedef cub::BlockLoad LoadA; + typedef cub::BlockLoad LoadA; //__shared__ typename LoadA::TempStorage loada; - typedef cub::BlockLoad LoadB; + typedef cub::BlockLoad LoadB; //__shared__ typename LoadB::TempStorage loadb; - typedef cub::BlockReduce BlockReduce; + typedef cub::BlockReduce BlockReduce; // Allocate shared memory for BlockReduce //__shared__ typename BlockReduce::TempStorage reduce; @@ -2975,16 +2971,16 @@ __global__ void gemm_device(int M, int N, int K, } temp_storage; - float dataA[4]; - float local_B[4]; - float local_accC[ROWS]; + T dataA[4]; + T local_B[4]; + T local_accC[ROWS]; int valid_items = 0; const int warp_id = threadIdx.x/32; const int warp_lane = threadIdx.x % 32; const int col_offset = blockIdx.x * 8; - __shared__ float tileA[ROWS*1024]; - __shared__ float accumulatorC[ROWS*8]; + __shared__ T tileA[ROWS*1024]; + __shared__ T accumulatorC[ROWS*8]; //#pragma unroll 8 //for(int i = 0; i < 8; i++) @@ -3128,6 +3124,7 @@ __global__ void with_staging_unified(float const* global_in, float * global_out, // TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB, // TC * out, CStride dC, CBlockLayout , CThreadLayout tC, // half alpha, half beta); +template __global__ void gemm_device(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc); //template __global__ void kMatmul_inference_4bit(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 55397e7dc..900af908b 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -138,10 +138,6 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * template __global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz); -__global__ void gemm_device(int M, int N, int K, - float const* A, - float * B, - float * out, int lda, int ldb, int ldc, - float alpha, float beta); +template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc); #endif diff --git a/csrc/ops.cu b/csrc/ops.cu index dd8fadebe..6aaa2414b 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -675,12 +675,7 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size) -void gemm_host(int m, int n, int k, - float alpha, - float const* A, int lda, - float * B, int ldb, - float beta, - float * C, int ldc) +template void gemm_host(int m, int n, int k, T const* A, T* B, T * out, int lda, int ldb, int ldc) { dim3 dimBlock(256); @@ -699,14 +694,14 @@ void gemm_host(int m, int n, int k, (m, n, k, A, B, - C, lda, ldb, ldc, - alpha, beta); + out, lda, ldb, ldc); } //============================================================== // TEMPLATE DEFINITIONS //============================================================== +template void gemm_host(int m, int n, int k, float const* A, float* B, float * out, int lda, int ldb, int ldc); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 2f71966c7..b7ef9a3bd 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -190,12 +190,7 @@ template void extractOutliers(char * A, int *idx, char *out, int id void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); -void gemm_host(int m, int n, int k, - float alpha, - float const* A, int ldA, - float * B, 
int ldB, - float beta, - float * C, int ldC); +template void gemm_host(int m, int n, int k, T const* A, T* B, T * out, int lda, int ldb, int ldc); void pipeline_test(float *A, float *B, size_t n, size_t batch_size); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 6ec550197..a7c4787b4 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -20,14 +20,8 @@ void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimat void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } -void -cppgemm(int m, int n, int k, - float alpha, - float const* A, int ldA, - float * B, int ldB, - float beta, - float * C, int ldC) -{ gemm_host(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);} +void gemm_host_fp32(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc) +{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc); } #define MAKE_FUNC32(fname, oname, gtype, gbits) \ @@ -317,13 +311,8 @@ extern "C" void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); } void cpipeline_test(float *A, float *B, size_t n, size_t batch_size){ pipeline_test(A, B, n, batch_size); } - void ccutlass_gemm(int m, int n, int k, - float alpha, - float const* A, int ldA, - float * B, int ldB, - float beta, - float * C, int ldC) - { cppgemm(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);} + void cgemm_host_fp32(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc) + { gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); } #endif void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } From f6df4aef6a7b9c4636061c2701de0a9c3ab10098 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Fri, 28 Apr 2023 18:26:52 -0700 Subject: [PATCH 33/63] Added fp16 and thread/item template. 
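The kernel is now templated on the element type and on a THREADS/ITEMS tiling, with explicit instantiations for float and half so all definitions can stay in kernels.cu. A minimal sketch of that pattern, assuming nothing beyond cuda_fp16.h (scale_kernel is a made-up example, not part of this patch):

    // one templated kernel body, per-dtype explicit instantiations
    #include <cuda_fp16.h>

    template <typename T, int THREADS, int ITEMS>
    __global__ void scale_kernel(const T *in, T *out, float s, int n)
    {
      int base = (blockIdx.x * THREADS + threadIdx.x) * ITEMS;
      #pragma unroll
      for (int k = 0; k < ITEMS; k++)
        if (base + k < n)
          out[base + k] = (T)((float)in[base + k] * s);   // do the arithmetic in fp32

      // (no tail handling needed beyond the bounds check above)
    }

    // explicit instantiations keep the definitions in one translation unit,
    // mirroring the template-definition lists at the bottom of kernels.cu
    template __global__ void scale_kernel<float, 256, 4>(const float *, float *, float, int);
    template __global__ void scale_kernel<half,  256, 4>(const half *,  half *,  float, int);

On the Python side the wrapper then only has to branch on A.dtype to pick cgemm_host_fp32 or cgemm_host_fp16, as the cutlass3_gemm change below does.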
--- bitsandbytes/functional.py | 11 ++++++++--- csrc/kernels.cu | 39 +++++++++++++++++++------------------- csrc/kernels.cuh | 2 +- csrc/ops.cu | 3 ++- csrc/pythonInterface.c | 5 +++++ tests/test_functional.py | 28 ++++++++++++++++----------- 6 files changed, 53 insertions(+), 35 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index da4e66c70..b5c622bbf 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1381,9 +1381,9 @@ def cutlass3_gemm( transposed_A=False, transposed_B=False, ): - sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.float32) + sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) if out is None: - out = torch.zeros(size=sout, dtype=torch.float32, device=A.device) + out = torch.zeros(size=sout, dtype=A.dtype, device=A.device) sA = A.shape sB = B.shape @@ -1464,7 +1464,12 @@ def cutlass3_gemm( lda = ct.c_int32(lda) ldb = ct.c_int32(ldb) ldc = ct.c_int32(ldc) - lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) + if A.dtype == torch.float32: + lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) + elif A.dtype == torch.float16: + lib.cgemm_host_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) + else: + raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') return out diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 45db4485f..67f9a3c67 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2949,18 +2949,18 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * #define ROWS 2 -template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc) +template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc) { // 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp // 1. Load dataB into register // 2. Dequantize B // 3. Fetch data from A and multiply - typedef cub::BlockLoad LoadA; + typedef cub::BlockLoad LoadA; //__shared__ typename LoadA::TempStorage loada; - typedef cub::BlockLoad LoadB; + typedef cub::BlockLoad LoadB; //__shared__ typename LoadB::TempStorage loadb; - typedef cub::BlockReduce BlockReduce; + typedef cub::BlockReduce BlockReduce; // Allocate shared memory for BlockReduce //__shared__ typename BlockReduce::TempStorage reduce; @@ -2971,15 +2971,13 @@ template __global__ void gemm_device(int M, int N, int K, T const* } temp_storage; - T dataA[4]; - T local_B[4]; + T dataA[ITEMS]; + T local_B[ITEMS]; T local_accC[ROWS]; int valid_items = 0; - const int warp_id = threadIdx.x/32; - const int warp_lane = threadIdx.x % 32; const int col_offset = blockIdx.x * 8; - __shared__ T tileA[ROWS*1024]; + __shared__ T tileA[ROWS*THREADS*ITEMS]; __shared__ T accumulatorC[ROWS*8]; //#pragma unroll 8 @@ -2991,17 +2989,17 @@ template __global__ void gemm_device(int M, int N, int K, T const* __syncthreads(); - for(int inner_idx = 0; inner_idx < K; inner_idx+= 1024) + for(int inner_idx = 0; inner_idx < K; inner_idx+= THREADS*ITEMS) { - valid_items = K - inner_idx > 1024 ? 1024 : K - inner_idx; + valid_items = K - inner_idx > THREADS*ITEMS ? 
THREADS*ITEMS : K - inner_idx; int baserow = 0; for(int row = baserow; row < (baserow+ROWS) && row < N; row++) { LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f); - #pragma unroll 4 - for(int k = 0; k < 4; k++) - tileA[row*1024 + threadIdx.x + (k*blockDim.x)] = dataA[k]; + #pragma unroll ITEMS + for(int k = 0; k < ITEMS; k++) + tileA[row*THREADS*ITEMS + threadIdx.x + (k*THREADS)] = dataA[k]; __syncthreads(); } @@ -3021,16 +3019,16 @@ template __global__ void gemm_device(int M, int N, int K, T const* local_accC[k] = 0.0f; int base_idxB = ldb*colB; - valid_items = K - inner_idx > 1024 ? 1024 : K - inner_idx; + valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx; LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f); __syncthreads(); for(int row = 0; row < ROWS && row < N; row++) { - #pragma unroll 4 - for(int k = 0; k < 4; k++) + #pragma unroll ITEMS + for(int k = 0; k < ITEMS; k++) { - int idxA = row*1024 + threadIdx.x + (blockDim.x*k); + int idxA = row*THREADS*ITEMS + threadIdx.x + (THREADS*k); local_accC[row] += tileA[idxA]*local_B[k]; } @@ -3124,7 +3122,10 @@ __global__ void with_staging_unified(float const* global_in, float * global_out, // TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB, // TC * out, CStride dC, CBlockLayout , CThreadLayout tC, // half alpha, half beta); -template __global__ void gemm_device(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half const* A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half const* A, half* B, half * out, int lda, int ldb, int ldc); //template __global__ void kMatmul_inference_4bit(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 900af908b..9603e93e3 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -138,6 +138,6 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * template __global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz); -template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc); #endif diff --git a/csrc/ops.cu b/csrc/ops.cu index 6aaa2414b..aa3dacfe9 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -689,7 +689,7 @@ template void gemm_host(int m, int n, int k, T const* A, T* B, T cout << m << endl; cout << n << endl; cout << k << endl; - gemm_device + gemm_device <<< num_blocks, dimBlock, 0, 0 >>> (m, n, k, A, @@ -702,6 +702,7 @@ template void gemm_host(int m, int n, int k, T const* A, T* B, T //============================================================== template void gemm_host(int m, int n, int k, float const* A, float* B, float * out, int lda, int ldb, int ldc); +template void gemm_host(int m, int n, int k, half const* A, half* B, half * out, int lda, int ldb, int ldc); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); template void 
extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index a7c4787b4..3dd0b05ab 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -22,6 +22,8 @@ void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimate void gemm_host_fp32(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc) { gemm_host(M, N, K, A, B, out, lda, ldb, ldc); } +void gemm_host_fp16(int M, int N, int K, half const* A, half* B, half * out, int lda, int ldb, int ldc) +{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc); } #define MAKE_FUNC32(fname, oname, gtype, gbits) \ @@ -314,6 +316,9 @@ extern "C" void cgemm_host_fp32(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc) { gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); } + void cgemm_host_fp16(int M, int N, int K, half const* A, half* B, half * out, int lda, int ldb, int ldc) + { gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); } + #endif void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); } diff --git a/tests/test_functional.py b/tests/test_functional.py index 087bc849b..156430689 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2352,20 +2352,26 @@ def test_normal_map_tree(): print(pivots) -def test_cutlass3_gemm(): - A = torch.rand(2, 4092).cuda() - B = torch.rand(4*4092, 4092).cuda() +#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) +@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) +def test_cutlass3_gemm(dtype): + for i in range(2): + A = torch.rand(2, 4092, dtype=dtype, device='cuda') + B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') + #A = torch.rand(2, 4, dtype=dtype, device='cuda') + #B = torch.rand(4, 4, dtype=dtype, device='cuda') - #print('') - #print(A) - #print(B.t()) + #print('') + #print(A) + #print(B.t()) - C1 = torch.matmul(A, B.t()) - C2 = F.cutlass3_gemm(A, B.t()) - #print(C1) - #print(C2) - torch.testing.assert_close(C1, C2) + C1 = torch.matmul(A, B.t()) + C2 = F.cutlass3_gemm(A, B.t()) + #print(C1) + #print(C2) + + #torch.testing.assert_close(C1, C2) def test_pipeline_func(): From f3e97ccbd2cdc1f40fe32e027fb3b5c22a92f09a Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Fri, 28 Apr 2023 21:29:40 -0700 Subject: [PATCH 34/63] New implementation for batch size 1. --- csrc/kernels.cu | 265 ++++++++++++++++++++++++++------------- csrc/kernels.cuh | 2 +- csrc/ops.cu | 10 +- csrc/ops.cuh | 2 +- csrc/pythonInterface.c | 8 +- tests/test_functional.py | 12 +- 6 files changed, 196 insertions(+), 103 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 67f9a3c67..33102850f 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2947,117 +2947,212 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * //// 9. write outputs to matmul output matrix //} - #define ROWS 2 -template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc) +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) { -// 0. 
We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp -// 1. Load dataB into register -// 2. Dequantize B -// 3. Fetch data from A and multiply - - typedef cub::BlockLoad LoadA; - //__shared__ typename LoadA::TempStorage loada; - typedef cub::BlockLoad LoadB; - //__shared__ typename LoadB::TempStorage loadb; - typedef cub::BlockReduce BlockReduce; - // Allocate shared memory for BlockReduce - //__shared__ typename BlockReduce::TempStorage reduce; - - __shared__ union { - typename BlockReduce::TempStorage reduce; - typename LoadB::TempStorage loadb; - typename LoadA::TempStorage loada; - } temp_storage; + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage reduce; + int col_offset = blockIdx.x *8; - T dataA[ITEMS]; - T local_B[ITEMS]; - T local_accC[ROWS]; - int valid_items = 0; - const int col_offset = blockIdx.x * 8; + T local_A[8]; + T local_B[8]; + T local_C[8]; - __shared__ T tileA[ROWS*THREADS*ITEMS]; - __shared__ T accumulatorC[ROWS*8]; + __shared__ T smem_C[8]; - //#pragma unroll 8 - //for(int i = 0; i < 8; i++) - // tileA[threadIdx.x + (i*256)] = 0.0f; - //__syncthreads(); - if(threadIdx.x < 64) - accumulatorC[threadIdx.x] = 0.0f; + if(threadIdx.x < 8) + smem_C[threadIdx.x] = T(0); __syncthreads(); + #pragma unroll 8 + for(int k = 0; k < 8; k++) + local_C[k] = T(0); - for(int inner_idx = 0; inner_idx < K; inner_idx+= THREADS*ITEMS) - { - valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx; - int baserow = 0; - for(int row = baserow; row < (baserow+ROWS) && row < N; row++) - { - LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f); - - #pragma unroll ITEMS - for(int k = 0; k < ITEMS; k++) - tileA[row*THREADS*ITEMS + threadIdx.x + (k*THREADS)] = dataA[k]; - - __syncthreads(); - } - baserow += ROWS; - // load 16 columns from B at a time. B is transposed, so its like loading rows - // each warp loads one row - // each thread loads 128 byte + for(int idx = threadIdx.x*8; idx < K; idx+=blockDim.x*8) + { - // col: inner_idx + warp_lane - // row: ldb*(offset + warp_id) - for(int col = 0; col < 8 && (col_offset + col) < M; col++) + if(idx + 8 <= K) + reinterpret_cast(local_A)[0] = reinterpret_cast(A)[idx/8]; + else { - int colB = col_offset + col; - - for(int k = 0; k < ROWS; k++) - local_accC[k] = 0.0f; + for(int k = 0; k < 8; k++) + { + if(idx + k < K) + local_A[k] = A[idx+k]; + else + local_A[k] = 0.0f; + } + } - int base_idxB = ldb*colB; - valid_items = K - inner_idx > THREADS*ITEMS ? 
THREADS*ITEMS : K - inner_idx; - LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f); - __syncthreads(); - for(int row = 0; row < ROWS && row < N; row++) + for(int col = 0; col < 8; col++) + { + int offset_B = (col_offset+col)*ldb; + if(idx + 8 <= K) + reinterpret_cast(local_B)[0] = reinterpret_cast(B)[(offset_B+idx)/8]; + else { - #pragma unroll ITEMS - for(int k = 0; k < ITEMS; k++) + for(int k = 0; k < 8; k++) { - int idxA = row*THREADS*ITEMS + threadIdx.x + (THREADS*k); - local_accC[row] += tileA[idxA]*local_B[k]; + if(idx + k < K) + local_B[k] = B[(offset_B+idx)+k]; + else + local_B[k] = 0.0f; } + } - local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum()); - if(threadIdx.x == 0) - atomicAdd(&accumulatorC[row*8 + col], local_accC[row]); + #pragma unroll 8 + for(int k = 0; k < 8; k++) + { + local_C[col] += local_A[k]*local_B[k]; + //if((float)local_A[k] != 0.0 && (float)local_B[k] != 0.0) + // printf("%i %i %f %f %f\n", k, threadIdx.x, (float)local_A[k], (float)local_B[k], (float)local_C[col]); } + } - } + } - for(int row = 0; row < ROWS && row < N; row++) + #pragma unroll 8 + for(int k = 0; k < 8; k++) { - int out_idx = ldc*row + col_offset; + local_C[k] = BlockReduce(reduce).Reduce(local_C[k], cub::Sum()); + __syncthreads(); + } - //if(threadIdx.x < 8) - // if(accumulatorC[row*8 + threadIdx.x] != 0.0) - // printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x); + if(threadIdx.x == 0) + #pragma unroll 8 + for(int k = 0; k < 8; k++) + smem_C[k] = local_C[k]; + else if(threadIdx.x >= 32) + // early return for unused warps + return; - if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M) - { - //printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx); - out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x]; - } - } + __syncwarp(); + + + //for(int k = 0; k < 8; k++) + // if((float)local_C[k] != 0.0f) + // printf("%i %f\n", threadIdx.x, (float)local_C[k]); + + if(threadIdx.x < 8 && col_offset + threadIdx.x < M) + out[col_offset + threadIdx.x ] = smem_C[threadIdx.x]; } +//#define ROWS 2 +//template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc) +//{ +//// 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp +//// 1. Load dataB into register +//// 2. Dequantize B +//// 3. 
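The new batch-size-1 path drops cub::BlockLoad and instead pulls 8 contiguous values of A and of the current B row into registers through a reinterpret_cast, with a guarded scalar loop for the tail. A minimal sketch of that load idiom for fp32, shown here with a 4-wide float4 load (load4 and vec_copy are illustrative names, not the patch's code, which loads 8 elements per thread):

    // 16-byte vectorized load with a scalar fallback at the tail;
    // both pointers must be 16-byte aligned on the vectorized path
    __device__ inline void load4(float *local, const float *buf, int idx, int n)
    {
      if (idx + 4 <= n)
        reinterpret_cast<float4 *>(local)[0] = reinterpret_cast<const float4 *>(buf)[idx / 4];
      else
        for (int k = 0; k < 4; k++)
          local[k] = (idx + k < n) ? buf[idx + k] : 0.0f;
    }

    __global__ void vec_copy(const float *in, float *out, int n)
    {
      alignas(16) float local[4];
      int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
      if (idx >= n) return;
      load4(local, in, idx, n);
      for (int k = 0; k < 4 && idx + k < n; k++)
        out[idx + k] = local[k];
    }

In the kernel itself the loaded fragments feed eight per-column accumulators, which are then combined with cub::BlockReduce much as in the previous baseline.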
Fetch data from A and multiply +// +// typedef cub::BlockLoad LoadA; +// //__shared__ typename LoadA::TempStorage loada; +// typedef cub::BlockLoad LoadB; +// //__shared__ typename LoadB::TempStorage loadb; +// typedef cub::BlockReduce BlockReduce; +// // Allocate shared memory for BlockReduce +// //__shared__ typename BlockReduce::TempStorage reduce; +// +// __shared__ union { +// typename BlockReduce::TempStorage reduce; +// typename LoadB::TempStorage loadb; +// typename LoadA::TempStorage loada; +// } temp_storage; +// +// +// T dataA[ITEMS]; +// T local_B[ITEMS]; +// T local_accC[ROWS]; +// int valid_items = 0; +// const int col_offset = blockIdx.x * 8; +// +// __shared__ T tileA[ROWS*THREADS*ITEMS]; +// __shared__ T accumulatorC[ROWS*8]; +// +// //#pragma unroll 8 +// //for(int i = 0; i < 8; i++) +// // tileA[threadIdx.x + (i*256)] = 0.0f; +// //__syncthreads(); +// if(threadIdx.x < 64) +// accumulatorC[threadIdx.x] = 0.0f; +// __syncthreads(); +// +// +// for(int inner_idx = 0; inner_idx < K; inner_idx+= THREADS*ITEMS) +// { +// valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx; +// int baserow = 0; +// for(int row = baserow; row < (baserow+ROWS) && row < N; row++) +// { +// LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f); +// +// #pragma unroll ITEMS +// for(int k = 0; k < ITEMS; k++) +// tileA[row*THREADS*ITEMS + threadIdx.x + (k*THREADS)] = dataA[k]; +// +// __syncthreads(); +// } +// baserow += ROWS; +// +// // load 16 columns from B at a time. B is transposed, so its like loading rows +// // each warp loads one row +// // each thread loads 128 byte +// +// // col: inner_idx + warp_lane +// // row: ldb*(offset + warp_id) +// for(int col = 0; col < 8 && (col_offset + col) < M; col++) +// { +// int colB = col_offset + col; +// +// for(int k = 0; k < ROWS; k++) +// local_accC[k] = 0.0f; +// +// int base_idxB = ldb*colB; +// valid_items = K - inner_idx > THREADS*ITEMS ? 
THREADS*ITEMS : K - inner_idx; +// LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f); +// __syncthreads(); +// +// for(int row = 0; row < ROWS && row < N; row++) +// { +// #pragma unroll ITEMS +// for(int k = 0; k < ITEMS; k++) +// { +// int idxA = row*THREADS*ITEMS + threadIdx.x + (THREADS*k); +// local_accC[row] += tileA[idxA]*local_B[k]; +// } +// +// local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum()); +// if(threadIdx.x == 0) +// atomicAdd(&accumulatorC[row*8 + col], local_accC[row]); +// } +// } +// } +// +// for(int row = 0; row < ROWS && row < N; row++) +// { +// int out_idx = ldc*row + col_offset; +// +// //if(threadIdx.x < 8) +// // if(accumulatorC[row*8 + threadIdx.x] != 0.0) +// // printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x); +// +// if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M) +// { +// //printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx); +// out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x]; +// } +// } +// +// +// +//} + __device__ void compute(float* global_out, float const* shared_in) { @@ -3122,10 +3217,8 @@ __global__ void with_staging_unified(float const* global_in, float * global_out, // TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB, // TC * out, CStride dC, CBlockLayout , CThreadLayout tC, // half alpha, half beta); -template __global__ void gemm_device(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half const* A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half const* A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); //template __global__ void kMatmul_inference_4bit(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 9603e93e3..23ecf454e 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -138,6 +138,6 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * template __global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz); -template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); #endif diff --git a/csrc/ops.cu b/csrc/ops.cu index aa3dacfe9..c0c26588e 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -675,10 +675,10 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size) -template void gemm_host(int m, int n, int k, T const* A, T* B, T * out, int lda, int ldb, int ldc) +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc) { - dim3 dimBlock(256); + dim3 dimBlock(128); int num_blocks = (m+7)/8; cout << num_blocks << endl; @@ -689,7 +689,7 @@ template 
void gemm_host(int m, int n, int k, T const* A, T* B, T cout << m << endl; cout << n << endl; cout << k << endl; - gemm_device + gemm_device <<< num_blocks, dimBlock, 0, 0 >>> (m, n, k, A, @@ -701,8 +701,8 @@ template void gemm_host(int m, int n, int k, T const* A, T* B, T // TEMPLATE DEFINITIONS //============================================================== -template void gemm_host(int m, int n, int k, float const* A, float* B, float * out, int lda, int ldb, int ldc); -template void gemm_host(int m, int n, int k, half const* A, half* B, half * out, int lda, int ldb, int ldc); +template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc); +template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index b7ef9a3bd..882264095 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -190,7 +190,7 @@ template void extractOutliers(char * A, int *idx, char *out, int id void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); -template void gemm_host(int m, int n, int k, T const* A, T* B, T * out, int lda, int ldb, int ldc); +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc); void pipeline_test(float *A, float *B, size_t n, size_t batch_size); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 3dd0b05ab..f92b52f9b 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -20,9 +20,9 @@ void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimat void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } -void gemm_host_fp32(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc) +void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) { gemm_host(M, N, K, A, B, out, lda, ldb, ldc); } -void gemm_host_fp16(int M, int N, int K, half const* A, half* B, half * out, int lda, int ldb, int ldc) +void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc) { gemm_host(M, N, K, A, B, out, lda, ldb, ldc); } @@ -313,10 +313,10 @@ extern "C" void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); } void cpipeline_test(float *A, float *B, size_t n, size_t batch_size){ pipeline_test(A, B, n, batch_size); } - void cgemm_host_fp32(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc) + void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) { gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); } - void cgemm_host_fp16(int M, int N, int K, half const* A, half* B, half * out, int lda, int ldb, int ldc) + void cgemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc) { gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); } #endif diff --git a/tests/test_functional.py b/tests/test_functional.py index 156430689..f08c4a238 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2355,11 +2355,11 @@ def test_normal_map_tree(): #@pytest.mark.parametrize("dtype", [torch.float32, 
torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) def test_cutlass3_gemm(dtype): - for i in range(2): - A = torch.rand(2, 4092, dtype=dtype, device='cuda') - B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') - #A = torch.rand(2, 4, dtype=dtype, device='cuda') - #B = torch.rand(4, 4, dtype=dtype, device='cuda') + for i in range(1): + #A = torch.rand(2, 4092, dtype=dtype, device='cuda') + #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') + A = torch.rand(1, 4096, dtype=dtype, device='cuda') + B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') #print('') #print(A) @@ -2371,7 +2371,7 @@ def test_cutlass3_gemm(dtype): #print(C1) #print(C2) - #torch.testing.assert_close(C1, C2) + torch.testing.assert_close(C1, C2, atol=1e-05, rtol=0.005) def test_pipeline_func(): From cad839941b2c0a013525be339f6e9c157caa925d Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Fri, 28 Apr 2023 22:10:42 -0700 Subject: [PATCH 35/63] Added bit template. --- csrc/kernels.cu | 77 +++++++++++++++++----------------------- csrc/kernels.cuh | 2 +- csrc/ops.cu | 16 ++++----- csrc/ops.cuh | 2 +- csrc/pythonInterface.c | 4 +-- tests/test_functional.py | 4 +-- 6 files changed, 45 insertions(+), 60 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 33102850f..a5697eeaa 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2947,16 +2947,31 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * //// 9. write outputs to matmul output matrix //} -#define ROWS 2 -template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) +template __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit) +{ + if(limit_base + ITEMS <= limit) + reinterpret_cast(local)[0] = reinterpret_cast(buffer)[idx/ITEMS]; + else + { + for(int k = 0; k < ITEMS; k++) + { + if(limit_base + k < limit) + local[k] = buffer[idx+k]; + else + local[k] = 0.0f; + } + } +} + +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) { typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage reduce; int col_offset = blockIdx.x *8; - T local_A[8]; - T local_B[8]; + T local_A[128/BITS]; + T local_B[128/BITS]; T local_C[8]; __shared__ T smem_C[8]; @@ -2970,47 +2985,18 @@ template __global__ void gemm_device(int M, local_C[k] = T(0); - for(int idx = threadIdx.x*8; idx < K; idx+=blockDim.x*8) + for(int idx = threadIdx.x*128/BITS; idx < K; idx+=blockDim.x*128/BITS) { - - if(idx + 8 <= K) - reinterpret_cast(local_A)[0] = reinterpret_cast(A)[idx/8]; - else - { - for(int k = 0; k < 8; k++) - { - if(idx + k < K) - local_A[k] = A[idx+k]; - else - local_A[k] = 0.0f; - } - } - + vector_load(local_A, A, idx, idx, K); for(int col = 0; col < 8; col++) { int offset_B = (col_offset+col)*ldb; - if(idx + 8 <= K) - reinterpret_cast(local_B)[0] = reinterpret_cast(B)[(offset_B+idx)/8]; - else - { - for(int k = 0; k < 8; k++) - { - if(idx + k < K) - local_B[k] = B[(offset_B+idx)+k]; - else - local_B[k] = 0.0f; - } - } + vector_load(local_B, B, offset_B+idx, idx, K); - #pragma unroll 8 - for(int k = 0; k < 8; k++) - { + #pragma unroll 128/BITS + for(int k = 0; k < 128/BITS; k++) local_C[col] += local_A[k]*local_B[k]; - //if((float)local_A[k] != 0.0 && (float)local_B[k] != 0.0) - // printf("%i %i %f %f %f\n", k, threadIdx.x, (float)local_A[k], (float)local_B[k], (float)local_C[col]); - } 
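With the BITS parameter, one 128-bit transaction carries 128/BITS elements, so the same kernel body moves 4 floats or 8 halves per load and the host picks the instantiation from a runtime bits flag. A sketch of the generalized helper (load_128bit is an illustrative stand-in for the patch's vector_load, which takes the item count directly; alignment of both pointers is assumed):

    #include <cuda_fp16.h>

    // one 16-byte load moves 128/BITS elements: 4 for fp32, 8 for fp16
    template <typename T, int BITS>
    __device__ inline void load_128bit(T *local, const T *buf, int idx, int n)
    {
      const int ITEMS = 128 / BITS;
      if (idx + ITEMS <= n)
        reinterpret_cast<int4 *>(local)[0] = reinterpret_cast<const int4 *>(buf)[idx / ITEMS];
      else
        for (int k = 0; k < ITEMS; k++)                   // guarded scalar tail
          local[k] = (idx + k < n) ? buf[idx + k] : T(0.0f);
    }

    // usage inside a kernel body (Brow points at the start of the current B row):
    //   alignas(16) float fragA[4];  load_128bit<float, 32>(fragA, A,    idx, K);
    //   alignas(16) half  fragB[8];  load_128bit<half,  16>(fragB, Brow, idx, K);

Dispatching on bits at launch time (gemm_device<T, 128, 32> vs gemm_device<T, 128, 16>) keeps the Python interface unchanged while the fp16 path gets the wider per-thread loads.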
- } } @@ -3022,9 +3008,11 @@ template __global__ void gemm_device(int M, } if(threadIdx.x == 0) + { #pragma unroll 8 for(int k = 0; k < 8; k++) smem_C[k] = local_C[k]; + } else if(threadIdx.x >= 32) // early return for unused warps return; @@ -3032,15 +3020,8 @@ template __global__ void gemm_device(int M, __syncwarp(); - //for(int k = 0; k < 8; k++) - // if((float)local_C[k] != 0.0f) - // printf("%i %f\n", threadIdx.x, (float)local_C[k]); - if(threadIdx.x < 8 && col_offset + threadIdx.x < M) out[col_offset + threadIdx.x ] = smem_C[threadIdx.x]; - - - } //#define ROWS 2 @@ -3217,7 +3198,13 @@ __global__ void with_staging_unified(float const* global_in, float * global_out, // TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB, // TC * out, CStride dC, CBlockLayout , CThreadLayout tC, // half alpha, half beta); + +// these are not used and make no sense, but the compiler needs them template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +// these are not used and make no sense, but the compiler needs them + +template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 23ecf454e..aab7b9554 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -138,6 +138,6 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * template __global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz); -template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); #endif diff --git a/csrc/ops.cu b/csrc/ops.cu index c0c26588e..221969089 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -675,7 +675,7 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size) -template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc) +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) { dim3 dimBlock(128); @@ -689,20 +689,18 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out cout << m << endl; cout << n << endl; cout << k << endl; - gemm_device - <<< num_blocks, dimBlock, 0, 0 >>> - (m, n, k, - A, - B, - out, lda, ldb, ldc); + if(bits == 32) + gemm_device<<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + else if(bits == 16) + gemm_device<<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); } //============================================================== // TEMPLATE DEFINITIONS //============================================================== -template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc); -template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc); +template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); +template void gemm_host(int m, int n, int k, half * A, half* B, half * 
out, int lda, int ldb, int ldc, int bits); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 882264095..ffc9e874c 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -190,7 +190,7 @@ template void extractOutliers(char * A, int *idx, char *out, int id void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); -template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc); +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); void pipeline_test(float *A, float *B, size_t n, size_t batch_size); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index f92b52f9b..1ece3e63f 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -21,9 +21,9 @@ void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimate void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) -{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc); } +{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 32); } void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc) -{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc); } +{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 16); } #define MAKE_FUNC32(fname, oname, gtype, gbits) \ diff --git a/tests/test_functional.py b/tests/test_functional.py index f08c4a238..b256af9ad 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2352,8 +2352,8 @@ def test_normal_map_tree(): print(pivots) -#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) -@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) +#@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) def test_cutlass3_gemm(dtype): for i in range(1): #A = torch.rand(2, 4092, dtype=dtype, device='cuda') From 21723f796a3951e56b77460e7d572c76619b773f Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sat, 29 Apr 2023 21:52:47 -0700 Subject: [PATCH 36/63] 4-bit draft. 
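
For reference, the NF4 decode convention that the updated dDequantizeNF4/dhDequantizeNF4 helpers and the new kgemm_4bit_inference draft assume can be sketched on the host like this: two 4-bit codes per byte, high nibble first, each mapped through a 16-entry code book and scaled by the absmax of its block. This is an illustrative plain-C++ sketch with hypothetical names, not code from the kernels below.

#include <cstdint>
#include <vector>

// 16-entry NF4 code book, indexed by the 4-bit value (same numbers as the
// dDequantizeNF4 decision tree in csrc/kernels.cu).
static const float kNF4[16] = {
    -1.0f, -0.6961928009986877f, -0.5250730514526367f, -0.39491748809814453f,
    -0.28444138169288635f, -0.18477343022823334f, -0.09105003625154495f, 0.0f,
     0.07958029955625534f, 0.16093020141124725f, 0.24611230194568634f,
     0.33791524171829224f, 0.44070982933044434f, 0.5626170039176941f,
     0.7229568362236023f, 1.0f};

// Unpack n NF4 values (two per byte, high nibble first) and scale each element
// by the absmax of the block it belongs to.
std::vector<float> dequantize_nf4_host(const std::uint8_t* packed, const float* absmax,
                                       int n, int blocksize)
{
    std::vector<float> out(n);
    for (int i = 0; i < n; i += 2)
    {
        std::uint8_t byte = packed[i / 2];
        out[i] = kNF4[byte >> 4] * absmax[i / blocksize];
        if (i + 1 < n)
            out[i + 1] = kNF4[byte & 0x0F] * absmax[(i + 1) / blocksize];
    }
    return out;
}

This is only the reference convention; the draft kernel below experiments with fusing the per-nibble decode into the GEMM inner loop.
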
--- bitsandbytes/functional.py | 22 +++- csrc/kernels.cu | 222 +++++++++++++++++++++++++++++++++---- csrc/kernels.cuh | 1 + csrc/ops.cu | 18 +++ csrc/ops.cuh | 1 + csrc/pythonInterface.c | 6 + tests/test_functional.py | 30 ++++- 7 files changed, 273 insertions(+), 27 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index b5c622bbf..f725c1c63 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1380,10 +1380,15 @@ def cutlass3_gemm( out: Tensor = None, transposed_A=False, transposed_B=False, + state=None ): - sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) + #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) + if state is None: + Bshape = B.shape + else: + Bshape = state[1] if out is None: - out = torch.zeros(size=sout, dtype=A.dtype, device=A.device) + out = torch.zeros(size=(A.shape[0], Bshape[1]), dtype=A.dtype, device=A.device) sA = A.shape sB = B.shape @@ -1456,7 +1461,13 @@ def cutlass3_gemm( # [km, nk -> mn] #lda = ldb = ldc = 1 #lda = 1 - #print(m, n, k, lda, ldb, ldc) + if state is not None: + m = Bshape[0] + k = Bshape[1] + lda = Bshape[1] + ldc = Bshape[0] + ldb = (ldb+1)//2 + print(m, n, k, lda, ldb, ldc) is_on_gpu([B, A, out]) m = ct.c_int32(m) n = ct.c_int32(n) @@ -1464,7 +1475,10 @@ def cutlass3_gemm( lda = ct.c_int32(lda) ldb = ct.c_int32(ldb) ldc = ct.c_int32(ldc) - if A.dtype == torch.float32: + + if B.dtype == torch.uint8: + lib.cgemm_4bit_inference(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) + elif A.dtype == torch.float32: lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) elif A.dtype == torch.float16: lib.cgemm_host_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index a5697eeaa..53a183d98 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -69,6 +69,27 @@ __device__ float dDequantizeFP4(unsigned char val, float absmax) } } +__device__ float d2DequantizeFP4(unsigned char val) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0110) == 0) + { + // subnormal + if((val & 0b0001) == 0) + return 0.0f; + else + return sign*0.0625f; + } + else + { + // normal + float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f); + float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f; + + return sign*exponent*fraction; + } +} + __device__ float dDequantizeFP4Tree(unsigned char val, float absmax) { float sign = (val & 0b1000) == 8 ? 
-1.0f : 1.0f; @@ -145,7 +166,61 @@ __device__ unsigned char dQuantizeFP4(float x) return 0b0000+sign; } -__device__ float dDequantizeNF4(unsigned char val, float absmax) +__device__ half dhDequantizeNF4(unsigned char val) +{ + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if((val & 0b1000) == 8) + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) // 11 + if((val & 0b0001) == 1) // 111 + return 1.0f; + else + return 0.7229568362236023f; + else + if((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; + else + return 0.44070982933044434f; + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; + else + return 0.24611230194568634f; + else + if((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; + else + return 0.07958029955625534f; + + else + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 011 + return 0.0f; + else + return -0.09105003625154495f; + else + if((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; + else + return -0.28444138169288635f; + else + if((val & 0b0010) == 2) //00 + if((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; + else + return -0.5250730514526367f; + else + if((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; + else + return -1.0f; + +} + +__device__ float dDequantizeNF4(unsigned char val) { // the values for this tree was generated by test_normal_map_tree // in the file tests/test_functional.py @@ -153,49 +228,49 @@ __device__ float dDequantizeNF4(unsigned char val, float absmax) if((val & 0b0100) == 4) // 1 if((val & 0b0010) == 2) // 11 if((val & 0b0001) == 1) // 111 - return 1.0f*absmax; + return 1.0f; else - return 0.7229568362236023f*absmax; + return 0.7229568362236023f; else if((val & 0b0001) == 1) // 110 - return 0.5626170039176941f*absmax; + return 0.5626170039176941f; else - return 0.44070982933044434f*absmax; + return 0.44070982933044434f; else if((val & 0b0010) == 2) //10 if((val & 0b0001) == 1) // 101 - return 0.33791524171829224f*absmax; + return 0.33791524171829224f; else - return 0.24611230194568634f*absmax; + return 0.24611230194568634f; else if((val & 0b0001) == 1) // 100 - return 0.16093020141124725f*absmax; + return 0.16093020141124725f; else - return 0.07958029955625534f*absmax; + return 0.07958029955625534f; else if((val & 0b0100) == 4) // 0 if((val & 0b0010) == 2) //01 if((val & 0b0001) == 1) // 011 - return 0.0f*absmax; + return 0.0f; else - return -0.09105003625154495f*absmax; + return -0.09105003625154495f; else if((val & 0b0001) == 1) // 010 - return -0.18477343022823334f*absmax; + return -0.18477343022823334f; else - return -0.28444138169288635f*absmax; + return -0.28444138169288635f; else if((val & 0b0010) == 2) //00 if((val & 0b0001) == 1) // 001 - return -0.39491748809814453f*absmax; + return -0.39491748809814453f; else - return -0.5250730514526367f*absmax; + return -0.5250730514526367f; else if((val & 0b0001) == 1) // 000 - return -0.6961928009986877f*absmax; + return -0.6961928009986877f; else - return -1.0f*absmax; + return -1.0f; } @@ -800,8 +875,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH; j++) { - vals[j*2] = dDequantizeNF4(qvals[j] >> 4, local_abs_max); - vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F, local_abs_max); + vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max; + vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max; 
} break; } @@ -2947,7 +3022,7 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * //// 9. write outputs to matmul output matrix //} -template __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit) +template __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit, float zero_value = 0.0f) { if(limit_base + ITEMS <= limit) reinterpret_cast(local)[0] = reinterpret_cast(buffer)[idx/ITEMS]; @@ -2958,7 +3033,7 @@ template __device__ inline void vector_l if(limit_base + k < limit) local[k] = buffer[idx+k]; else - local[k] = 0.0f; + local[k] = (T)zero_value; } } } @@ -3024,6 +3099,109 @@ template __global__ void gemm_device(int M, out[col_offset + threadIdx.x ] = smem_C[threadIdx.x]; } +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage reduce; + int col_offset = blockIdx.x *8; + + T local_A[32]; + unsigned char local_B_4bit[16]; + T local_B[32]; + T local_C[8]; + + __shared__ T smem_C[8]; + + if(threadIdx.x < 8) + smem_C[threadIdx.x] = T(0); + __syncthreads(); + + #pragma unroll 8 + for(int k = 0; k < 8; k++) + local_C[k] = T(0); + + + for(int idx = threadIdx.x*32; idx < K; idx+=blockDim.x*32) + { + + // we load only 8 values per iteration from A, so we + // need to do 4 loads for every single load from B + // for B, we have packed values, so the 16 8-bit values + // turn into 32 4-bit values to 4x 4 loads turns into 4x 8 loads + vector_load(local_A, A, idx, idx, K); + vector_load(&(local_A[8]), A, idx+8, idx+8, K); + vector_load(&(local_A[16]), A, idx+16, idx+16, K); + vector_load(&(local_A[24]), A, idx+24, idx+24, K); + + for(int col = 0; col < 8; col++) + { + if((col + col_offset) >= M){ break; } + + int offset_B = (col_offset+col)*ldb; + // 0111 -> 0.0f in NF4 + // since we have packed 8-bits, we need cat(0b0111, 0b0111) = 0b01110111 + vector_load(local_B_4bit, B, (offset_B+idx+1)/2, (idx+1)/2, (K+1)/2, 0b01110111); + + int absidx = (idx + offset_B)/blocksize; + half local_absmax = __ldg(&(absmax[absidx])); + //for(int k = 0; k < 16; k++) + //printf("%i %i ", local_B_4bit[k] >> 4, local_B_4bit[k] & 0x0F); + //printf("\n"); + + //vector_load(local_A, A, idx, idx, K); + + #pragma unroll 16 + for(int k = 0; k < 16; k++) + { + + //if(local_B_4bit[k ] != 0b01110111) + //printf("(%i %i %i) %i -> %f, %i -> %f\n", threadIdx.x , k, K, local_B_4bit[k ] >> 4, dDequantizeNF4(local_B_4bit[k ] >> 4, local_absmax), + //local_B_4bit[k ] & 0x0F, dDequantizeNF4(local_B_4bit[k ] & 0x0F, local_absmax)); + //local_B[k*2] = d2DequantizeFP4(local_B_4bit[k] >> 4);//*local_absmax; + //local_B[k*2 + 1] = d2DequantizeFP4(local_B_4bit[k] & 0x0F);//*local_absmax; + local_B[k*2] = (half)(local_B_4bit[k] >> 4)*local_absmax; + local_B[k*2 + 1] = (half)(local_B_4bit[k] & 0x0F)*local_absmax; + //local_B[k*2] = (half)dDequantizeNF4(local_B_4bit[k ] >> 4);//*local_absmax; + //local_B[k*2 + 1] = (half)dDequantizeNF4(local_B_4bit[k ] & 0x0F);//*local_absmax; + } + + #pragma unroll 32 + //for(int k = 0; k < 8; k++) + for(int k = 0; k < 32; k++) + { + local_C[col] += local_A[k]*local_B[k]; + //if((float)local_A[k] != 0.0 && (float)local_B[k] != 0.0) + //if((float)local_B[k] != 0.0) + //printf("%i %i %i %i %f*%f\n", threadIdx.x, k, col, (float)local_A[k], (float)local_B[k]); + } + } + } + + 
#pragma unroll 8 + for(int k = 0; k < 8; k++) + { + local_C[k] = BlockReduce(reduce).Reduce(local_C[k], cub::Sum()); + __syncthreads(); + } + + if(threadIdx.x == 0) + { + #pragma unroll 8 + for(int k = 0; k < 8; k++) + smem_C[k] = local_C[k]; + } + else if(threadIdx.x >= 32) + // early return for unused warps + return; + + __syncwarp(); + + + if(threadIdx.x < 8 && col_offset + threadIdx.x < M) + out[col_offset + threadIdx.x ] = smem_C[threadIdx.x]; +} + //#define ROWS 2 //template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc) //{ @@ -3207,6 +3385,8 @@ template __global__ void gemm_device(int M, int N, int K, half * template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); + //template __global__ void kMatmul_inference_4bit(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB); template __global__ void with_staging_unified<2>(float const* global_in, float * global_out, size_t size, size_t batch_sz); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index aab7b9554..4951031a2 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -139,5 +139,6 @@ template __global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz); template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); #endif diff --git a/csrc/ops.cu b/csrc/ops.cu index 221969089..07e710741 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -695,10 +695,28 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out gemm_device<<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); } +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + dim3 dimBlock(128); + int num_blocks = (m+7)/8; + + cout << num_blocks << endl; + cout << lda << endl; + cout << ldb << endl; + cout << ldc << endl; + + cout << m << endl; + cout << n << endl; + cout << k << endl; + kgemm_4bit_inference<<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); +} + //============================================================== // TEMPLATE DEFINITIONS //============================================================== +template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index ffc9e874c..8919c6016 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -191,6 +191,7 @@ template void 
extractOutliers(char * A, int *idx, char *out, int id void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); void pipeline_test(float *A, float *B, size_t n, size_t batch_size); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 1ece3e63f..bdf821c8a 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -25,6 +25,9 @@ void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, in void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc) { gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 16); } +void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) +{ gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + #define MAKE_FUNC32(fname, oname, gtype, gbits) \ void fname##32bit_g##gbits(gtype *g, gtype *p, \ @@ -319,6 +322,9 @@ extern "C" void cgemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc) { gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); } + void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) + { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + #endif void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); } diff --git a/tests/test_functional.py b/tests/test_functional.py index b256af9ad..f58cd43c2 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2352,8 +2352,8 @@ def test_normal_map_tree(): print(pivots) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) -#@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) +#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) +@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) def test_cutlass3_gemm(dtype): for i in range(1): #A = torch.rand(2, 4092, dtype=dtype, device='cuda') @@ -2373,6 +2373,32 @@ def test_cutlass3_gemm(dtype): torch.testing.assert_close(C1, C2, atol=1e-05, rtol=0.005) +#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) +@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) +def test_gemm_4bit(dtype): + for i in range(1): + #A = torch.rand(2, 4092, dtype=dtype, device='cuda') + #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') + #torch.random.manual_seed(17) + A = torch.rand(1, 4096, dtype=dtype, device='cuda') + B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') + + #print('') + #print(A) + #print(B) + + qB, state = F.quantize_nf4(B) + F.dequantize_nf4(qB, state) + + + C1 = torch.matmul(A, B.t()) + #C1 = bnb.matmul_4bit(A, qB.t(), state) + C2 = F.cutlass3_gemm(A, qB.t(), state=state) + #print(C1) + #print(C2) + + #torch.testing.assert_close(C1, C2, atol=1e-5, rtol=0.005) + def test_pipeline_func(): 
a = torch.rand(2, 4).cuda() From ad07d254fb5cefadf8dcb6020b24fb0baee4e936 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 30 Apr 2023 17:43:02 -0700 Subject: [PATCH 37/63] Slow tensor core solution. --- csrc/kernels.cu | 177 +++++++++++++++++++++++++++++++-------- csrc/ops.cu | 17 ++-- csrc/pythonInterface.c | 8 +- tests/test_functional.py | 2 + 4 files changed, 158 insertions(+), 46 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 53a183d98..24b004b2f 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -23,6 +24,8 @@ #define NUM 4 #define NUM_BLOCK 4096 +using namespace nvcuda; + // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda __device__ float atomicMax(float* address, float val) { int* address_as_i = reinterpret_cast(address); @@ -3041,62 +3044,164 @@ template __device__ inline void vector_l template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) { - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage reduce; + typedef cub::WarpReduce WarpReduce; + // Allocate WarpReduce shared memory for one warp + //__shared__ typename WarpReduce::TempStorage temp_storage; + + //typedef cub::BlockReduce BlockReduce; + //// Allocate shared memory for BlockReduce + //__shared__ typename BlockReduce::TempStorage reduce; int col_offset = blockIdx.x *8; + const int warp_id = threadIdx.x / 32; + const int warp_lane = threadIdx.x % 32; - T local_A[128/BITS]; - T local_B[128/BITS]; + T local_A[64/BITS]; + T local_B[64/BITS]; T local_C[8]; - __shared__ T smem_C[8]; + __shared__ T smem_A[4*32*16]; + __shared__ T smem_B[4*16*8]; + __shared__ T smem_C[4*32*8]; - if(threadIdx.x < 8) - smem_C[threadIdx.x] = T(0); + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + + wmma::fill_fragment(c_frag, 0.0f); + + + for(int i = threadIdx.x; i < 32*16*4; i+=blockDim.x) + smem_A[i] = T(0); + + for(int i = threadIdx.x; i < 32*8*4; i+=blockDim.x) + smem_B[i] = T(0); + + for(int i = threadIdx.x; i < 32*8*THREADS/32; i+=blockDim.x) + smem_C[i] = T(0); __syncthreads(); #pragma unroll 8 for(int k = 0; k < 8; k++) local_C[k] = T(0); - - for(int idx = threadIdx.x*128/BITS; idx < K; idx+=blockDim.x*128/BITS) + int block_idx = 0; + //for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x) + for(int base_idx = 0; base_idx < K; base_idx+=64) { - vector_load(local_A, A, idx, idx, K); - for(int col = 0; col < 8; col++) + int tidx = threadIdx.x*4; + + if(base_idx % (4*blockDim.x) == 0) { - int offset_B = (col_offset+col)*ldb; - vector_load(local_B, B, offset_B+idx, idx, K); + vector_load(local_A, A, base_idx+tidx, base_idx+tidx, K); // 54 mu + block_idx = 0; + } - #pragma unroll 128/BITS - for(int k = 0; k < 128/BITS; k++) - local_C[col] += local_A[k]*local_B[k]; + for(int k = 0; k < 4; k++) + { + if((threadIdx.x >= block_idx*16) && (threadIdx.x < (block_idx+1)*16)) + smem_A[(threadIdx.x % 16) + (32*16*k)] = local_A[k]; // 54 mu } - } + block_idx += 1; + + // 4 warps, 1 warps loads in total 4*32=64 values -> 4 columns at a time + // we need 8 columns, so 2 loads and smem stores + // we need a half-warp to load one column at a time + for(int j = 0; j < 2; j++) + { + int col = warp_id + (j*4); + int offset_B = (col_offset+col)*ldb; + vector_load(local_B, B, offset_B+base_idx+warp_lane*4, base_idx+warp_lane*4, K); // 171 mu + + + //#pragma unroll 4 + 
//for(int k = 0; k < 4; k++) + // if((float)local_B[k] != 0.0) + // printf("%i %i %i %i %f\n", j, warp_id, warp_lane, k, (float)local_B[k]); + + // load and store is different + // we wnat to load 64 consequitive values with one warp + // but we need to store those across 4 fragments since + // the max column width is 16. + + // each 16 values a new tile for each warp + //int tile_idx = warp_lane/16; + #pragma unroll 4 + for(int k = 0; k < 4; k++) + smem_B[(warp_lane % 16) + (col*16) + (k*16*8)] = local_B[k]; // 171 mu + } + + - #pragma unroll 8 - for(int k = 0; k < 8; k++) - { - local_C[k] = BlockReduce(reduce).Reduce(local_C[k], cub::Sum()); __syncthreads(); - } - if(threadIdx.x == 0) - { - #pragma unroll 8 - for(int k = 0; k < 8; k++) - smem_C[k] = local_C[k]; + //if(threadIdx.x == 0) + // for(int w = 0; w < 4; w++) + // for(int trow = 0; trow < 32; trow++) + // for(int tcol = 0; tcol < 16; tcol++) + // if((float)smem_A[trow + tcol*32 + (w*32*16)] != 0.0) + // printf("A %i %i %i = %f\n", w, trow, tcol, (float) smem_B[trow + tcol*16]); + + //if(threadIdx.x == 0) + // for(int w = 0; w < 4; w++) + // for(int trow = 0; trow < 16; trow++) + // for(int tcol = 0; tcol < 8; tcol++) + // if((float)smem_B[trow + tcol*16 + (w*16*8)] != 0.0) + // printf("B %i %i %i = %f\n", w, trow, tcol, (float) smem_B[trow + tcol*16]); + + + //__syncthreads(); + + wmma::load_matrix_sync(a_frag, &(smem_A[warp_id*32*16]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[warp_id*16*8]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); } - else if(threadIdx.x >= 32) - // early return for unused warps - return; - __syncwarp(); + // 129 mu + wmma::store_matrix_sync(&(smem_C[warp_id*32*8]), c_frag, 8, wmma::mem_row_major); + __syncthreads(); + //if(threadIdx.x >= 16){ return; } + //printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]); + //if(threadIdx.x < 32) + if(warp_lane < 8 && warp_id > 0) + //local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)]; + atomicAdd(&(smem_C[warp_lane]), smem_C[warp_lane + (warp_id*32*8)]); + __syncthreads(); + + //local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum()); + //if(threadIdx.x == 0) + // for(int row = 0; row < 32; row++) + // { + // printf("row %i ", row); + // for(int id = 0; id < 4; id++) + // { + // printf(" id %i: ", id); + // for(int k = 0; k < 8; k++) + // printf("%f ", (float)smem_C[k + (row*8) + (id*32*8)]); + // printf("\n"); + // } + // } + + //__syncthreads(); + + //if((float)local_C[0] !=0.0f) + // printf("%i %i %f\n", warp_lane, warp_id, (float)local_C[0]); + //local_C[0] = WarpReduce(temp_storage).Sum(local_C[0]); + + //__syncwarp(); + + ////for(int i = threadIdx.x; i < 32*8; i+=blockDim.x) + ////{ + // if((float)local_C[0] !=0.0f) + // printf("%i %f\n", 0, (float)local_C[0]); + //} + + //if(threadIdx.x < 8 && col_offset + threadIdx.x < M) + //out[col_offset + threadIdx.x ] = smem_C[threadIdx.x]; if(threadIdx.x < 8 && col_offset + threadIdx.x < M) - out[col_offset + threadIdx.x ] = smem_C[threadIdx.x]; + out[col_offset + threadIdx.x] = smem_C[threadIdx.x]; } template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) @@ -3378,12 +3483,16 @@ __global__ void with_staging_unified(float const* global_in, float * global_out, // half alpha, half beta); // these are not used and make no sense, but the compiler needs them -template __global__ void gemm_device(int M, int N, int K, float * 
__restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); // these are not used and make no sense, but the compiler needs them -template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); diff --git a/csrc/ops.cu b/csrc/ops.cu index 07e710741..d83fc6e11 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -678,7 +678,6 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size) template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) { - dim3 dimBlock(128); int num_blocks = (m+7)/8; cout << num_blocks << endl; @@ -689,16 +688,17 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out cout << m << endl; cout << n << endl; cout << k << endl; - if(bits == 32) - gemm_device<<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - else if(bits == 16) - gemm_device<<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //if(bits == 32) + //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + if(bits == 16) + gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); } template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) { - dim3 dimBlock(128); int num_blocks = (m+7)/8; cout << num_blocks << endl; @@ -709,7 +709,8 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi cout << m << endl; cout << n << endl; cout << k << endl; - kgemm_4bit_inference<<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + kgemm_4bit_inference<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } //============================================================== @@ -717,7 +718,7 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi 
//============================================================== template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); +//template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index bdf821c8a..26f16f218 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -20,8 +20,8 @@ void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimat void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } -void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) -{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 32); } +//void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) +//{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 32); } void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc) { gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 16); } @@ -316,8 +316,8 @@ extern "C" void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); } void cpipeline_test(float *A, float *B, size_t n, size_t batch_size){ pipeline_test(A, B, n, batch_size); } - void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) - { gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); } + //void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) + //{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); } void cgemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc) { gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); } diff --git a/tests/test_functional.py b/tests/test_functional.py index f58cd43c2..e2ecdcb55 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2358,6 +2358,8 @@ def test_cutlass3_gemm(dtype): for i in range(1): #A = torch.rand(2, 4092, dtype=dtype, device='cuda') #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') + #A = torch.rand(1, 4096, dtype=dtype, device='cuda') + #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') A = torch.rand(1, 4096, dtype=dtype, device='cuda') B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') From 604bb3fb573eee2437c2ed51efbd0e3c1382e060 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 30 Apr 2023 18:06:01 -0700 Subject: [PATCH 38/63] Slow non-vector 530. 
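
This patch keeps the WMMA path from the previous commit but drops the vectorized loads: a single warp stages one 16-element slice of A and a 16x8 tile of B into shared memory per iteration and issues one mma per slice. For reference, the basic fragment pattern these kernels rely on looks like the sketch below (illustrative only, using the standard 16x16x16 half fragments rather than the exact tile shape and layouts used here; needs sm_70+ and a full warp).

#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

// One warp computes C(16x16, float) = A(16x16, half) * B(16x16, half).
// All three matrices are row-major in memory with leading dimension 16.
__global__ void wmma_tile_16x16x16(const half* A, const half* B, float* C)
{
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;

    wmma::fill_fragment(c_frag, 0.0f);            // zero the accumulator
    wmma::load_matrix_sync(a_frag, A, 16);        // ldm = 16
    wmma::load_matrix_sync(b_frag, B, 16);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    wmma::store_matrix_sync(C, c_frag, 16, wmma::mem_row_major);
}

// launch: wmma_tile_16x16x16<<<1, 32>>>(dA, dB, dC);

The kernels below apply the same three calls per K-slice, staging the A and B tiles through shared memory first.
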
--- csrc/kernels.cu | 102 +++++++++++++++--------------------------------- csrc/ops.cu | 4 +- 2 files changed, 33 insertions(+), 73 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 24b004b2f..5a6db7d48 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3041,6 +3041,7 @@ template __device__ inline void vector_l } } +#define WARPS 1 template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) { @@ -3059,9 +3060,9 @@ template __global__ void gemm_device(int M, T local_B[64/BITS]; T local_C[8]; - __shared__ T smem_A[4*32*16]; - __shared__ T smem_B[4*16*8]; - __shared__ T smem_C[4*32*8]; + __shared__ T smem_A[WARPS*32*16]; + __shared__ T smem_B[WARPS*16*8]; + __shared__ T smem_C[WARPS*32*8]; wmma::fragment a_frag; wmma::fragment b_frag; @@ -3070,13 +3071,13 @@ template __global__ void gemm_device(int M, wmma::fill_fragment(c_frag, 0.0f); - for(int i = threadIdx.x; i < 32*16*4; i+=blockDim.x) + for(int i = threadIdx.x; i < 32*16*WARPS; i+=blockDim.x) smem_A[i] = T(0); - for(int i = threadIdx.x; i < 32*8*4; i+=blockDim.x) + for(int i = threadIdx.x; i < 32*8*WARPS; i+=blockDim.x) smem_B[i] = T(0); - for(int i = threadIdx.x; i < 32*8*THREADS/32; i+=blockDim.x) + for(int i = threadIdx.x; i < 32*8*WARPS; i+=blockDim.x) smem_C[i] = T(0); __syncthreads(); @@ -3084,91 +3085,48 @@ template __global__ void gemm_device(int M, for(int k = 0; k < 8; k++) local_C[k] = T(0); - int block_idx = 0; + //int block_idx = 0; //for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x) - for(int base_idx = 0; base_idx < K; base_idx+=64) + for(int base_idx = 0; base_idx < K; base_idx+=16) { + int idx = base_idx + threadIdx.x; - int tidx = threadIdx.x*4; - - if(base_idx % (4*blockDim.x) == 0) - { - vector_load(local_A, A, base_idx+tidx, base_idx+tidx, K); // 54 mu - block_idx = 0; - } - - for(int k = 0; k < 4; k++) - { - if((threadIdx.x >= block_idx*16) && (threadIdx.x < (block_idx+1)*16)) - smem_A[(threadIdx.x % 16) + (32*16*k)] = local_A[k]; // 54 mu - } - block_idx += 1; - - // 4 warps, 1 warps loads in total 4*32=64 values -> 4 columns at a time - // we need 8 columns, so 2 loads and smem stores - // we need a half-warp to load one column at a time - for(int j = 0; j < 2; j++) + if(threadIdx.x < 16) { - int col = warp_id + (j*4); - int offset_B = (col_offset+col)*ldb; - vector_load(local_B, B, offset_B+base_idx+warp_lane*4, base_idx+warp_lane*4, K); // 171 mu - - - //#pragma unroll 4 - //for(int k = 0; k < 4; k++) - // if((float)local_B[k] != 0.0) - // printf("%i %i %i %i %f\n", j, warp_id, warp_lane, k, (float)local_B[k]); + if(idx >= K) + { + smem_A[threadIdx.x] = 0.0f; + smem_B[threadIdx.x] = 0.0f; + } + else + { - // load and store is different - // we wnat to load 64 consequitive values with one warp - // but we need to store those across 4 fragments since - // the max column width is 16. 
+ smem_A[threadIdx.x] = A[idx]; - // each 16 values a new tile for each warp - //int tile_idx = warp_lane/16; - #pragma unroll 4 - for(int k = 0; k < 4; k++) - smem_B[(warp_lane % 16) + (col*16) + (k*16*8)] = local_B[k]; // 171 mu + for(int col = 0; col < 8; col++) + smem_B[threadIdx.x + (col*16)] = B[(col_offset+col)*ldb+idx]; + } } - - __syncthreads(); - //if(threadIdx.x == 0) - // for(int w = 0; w < 4; w++) - // for(int trow = 0; trow < 32; trow++) - // for(int tcol = 0; tcol < 16; tcol++) - // if((float)smem_A[trow + tcol*32 + (w*32*16)] != 0.0) - // printf("A %i %i %i = %f\n", w, trow, tcol, (float) smem_B[trow + tcol*16]); - - //if(threadIdx.x == 0) - // for(int w = 0; w < 4; w++) - // for(int trow = 0; trow < 16; trow++) - // for(int tcol = 0; tcol < 8; tcol++) - // if((float)smem_B[trow + tcol*16 + (w*16*8)] != 0.0) - // printf("B %i %i %i = %f\n", w, trow, tcol, (float) smem_B[trow + tcol*16]); - - - //__syncthreads(); - - wmma::load_matrix_sync(a_frag, &(smem_A[warp_id*32*16]), 16); // 111 mu - wmma::load_matrix_sync(b_frag, &(smem_B[warp_id*16*8]), 16); // 35 mu + wmma::load_matrix_sync(a_frag, &(smem_A[0]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[0]), 16); // 35 mu wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); } // 129 mu - wmma::store_matrix_sync(&(smem_C[warp_id*32*8]), c_frag, 8, wmma::mem_row_major); + wmma::store_matrix_sync(&(smem_C[0]), c_frag, 8, wmma::mem_row_major); __syncthreads(); //if(threadIdx.x >= 16){ return; } //printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]); //if(threadIdx.x < 32) - if(warp_lane < 8 && warp_id > 0) - //local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)]; - atomicAdd(&(smem_C[warp_lane]), smem_C[warp_lane + (warp_id*32*8)]); - __syncthreads(); + //if(warp_lane < 8 && warp_id > 0) + // //local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)]; + // atomicAdd(&(smem_C[warp_lane]), smem_C[warp_lane + (warp_id*32*8)]); + //__syncthreads(); //local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum()); //if(threadIdx.x == 0) @@ -3487,12 +3445,14 @@ __global__ void with_staging_unified(float const* global_in, float * global_out, template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); // these are not used and make no sense, but the compiler needs them //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void kgemm_4bit_inference(int M, int N, int K, half * 
__restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); diff --git a/csrc/ops.cu b/csrc/ops.cu index d83fc6e11..5c4f9c042 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -692,8 +692,8 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); if(bits == 16) - gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); } template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) From c35ed09b668db43da967ddeff88c13d92a5cb02a Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 30 Apr 2023 18:19:30 -0700 Subject: [PATCH 39/63] Double frag 440. --- csrc/kernels.cu | 27 ++++++++++++++++----------- tests/test_functional.py | 2 +- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 5a6db7d48..5d1982da5 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3053,19 +3053,24 @@ template __global__ void gemm_device(int M, //// Allocate shared memory for BlockReduce //__shared__ typename BlockReduce::TempStorage reduce; int col_offset = blockIdx.x *8; - const int warp_id = threadIdx.x / 32; - const int warp_lane = threadIdx.x % 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; T local_A[64/BITS]; T local_B[64/BITS]; T local_C[8]; - __shared__ T smem_A[WARPS*32*16]; - __shared__ T smem_B[WARPS*16*8]; + const int a_tile_offset = 32*16; + const int b_tile_offset = 16*8; + + __shared__ T smem_A[WARPS*32*16*2]; + __shared__ T smem_B[WARPS*16*8*2]; __shared__ T smem_C[WARPS*32*8]; wmma::fragment a_frag; wmma::fragment b_frag; + wmma::fragment a2_frag; + wmma::fragment b2_frag; wmma::fragment c_frag; wmma::fill_fragment(c_frag, 0.0f); @@ -3087,32 +3092,32 @@ template __global__ void gemm_device(int M, //int block_idx = 0; //for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x) - for(int base_idx = 0; base_idx < K; base_idx+=16) + for(int base_idx = 0; base_idx < K; base_idx+=32) { int idx = base_idx + threadIdx.x; - if(threadIdx.x < 16) - { if(idx >= K) { smem_A[threadIdx.x] = 0.0f; - smem_B[threadIdx.x] = 0.0f; + //smem_B[threadIdx.x] = 0.0f; } else { - smem_A[threadIdx.x] = A[idx]; + smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = A[idx]; for(int col = 0; col < 8; col++) - smem_B[threadIdx.x + (col*16)] = B[(col_offset+col)*ldb+idx]; + smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = B[(col_offset+col)*ldb+idx]; } - } __syncthreads(); wmma::load_matrix_sync(a_frag, &(smem_A[0]), 16); // 111 mu wmma::load_matrix_sync(b_frag, &(smem_B[0]), 16); // 35 mu + wmma::load_matrix_sync(a2_frag, &(smem_A[32*16]), 16); // 111 mu + wmma::load_matrix_sync(b2_frag, &(smem_B[16*8]), 16); // 35 mu wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + wmma::mma_sync(c_frag, a2_frag, b2_frag, c_frag); } // 129 mu diff --git a/tests/test_functional.py b/tests/test_functional.py index e2ecdcb55..f31e9b4d7 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2373,7 +2373,7 @@ def test_cutlass3_gemm(dtype): #print(C1) #print(C2) - 
torch.testing.assert_close(C1, C2, atol=1e-05, rtol=0.005) + torch.testing.assert_close(C1, C2, atol=1e-05, rtol=0.05) #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) From e01d4e033df8f94b28ae4e38608c621653673338 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 30 Apr 2023 18:28:52 -0700 Subject: [PATCH 40/63] Fixed bank conflicts in non-vector load 422. --- csrc/kernels.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 5d1982da5..dffd40cef 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3060,11 +3060,11 @@ template __global__ void gemm_device(int M, T local_B[64/BITS]; T local_C[8]; - const int a_tile_offset = 32*16; - const int b_tile_offset = 16*8; + const int a_tile_offset = 32*16 + 16; + const int b_tile_offset = 16*8 + 16; - __shared__ T smem_A[WARPS*32*16*2]; - __shared__ T smem_B[WARPS*16*8*2]; + __shared__ T smem_A[WARPS*32*16*2 + (16*1)]; + __shared__ T smem_B[WARPS*16*8*2 + (16*1)]; __shared__ T smem_C[WARPS*32*8]; wmma::fragment a_frag; @@ -3114,8 +3114,8 @@ template __global__ void gemm_device(int M, wmma::load_matrix_sync(a_frag, &(smem_A[0]), 16); // 111 mu wmma::load_matrix_sync(b_frag, &(smem_B[0]), 16); // 35 mu - wmma::load_matrix_sync(a2_frag, &(smem_A[32*16]), 16); // 111 mu - wmma::load_matrix_sync(b2_frag, &(smem_B[16*8]), 16); // 35 mu + wmma::load_matrix_sync(a2_frag, &(smem_A[a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b2_frag, &(smem_B[b_tile_offset]), 16); // 35 mu wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); wmma::mma_sync(c_frag, a2_frag, b2_frag, c_frag); } From 30d03e0254f9868f29392f318787667d5bdff891 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 30 Apr 2023 18:55:12 -0700 Subject: [PATCH 41/63] 64 threads, high smem, 434. 
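
The fix is the standard shared-memory padding trick: 16 extra half elements between consecutive tiles (a_tile_offset = 32*16 + 16 and so on) so that successive tiles no longer start on the same bank. For reference, the classic float version of the same idea (illustrative only, not code from this repository):

// 32x32 transpose tile: without the +1 padding, reading a column would hit the
// same 4-byte bank for every row (32 banks); one extra word per row staggers them.
__global__ void transpose32(const float* in, float* out)
{
    __shared__ float tile[32][33];            // 33 = 32 + 1 word of padding per row
    int x = threadIdx.x, y = threadIdx.y;     // launch with a 32x32 thread block
    tile[y][x] = in[y * 32 + x];              // coalesced, conflict-free write
    __syncthreads();
    out[y * 32 + x] = tile[x][y];             // column read, conflict-free with padding
}

// launch: transpose32<<<1, dim3(32, 32)>>>(d_in, d_out);  // one 32x32 tile
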
--- csrc/kernels.cu | 48 ++++++++++++++++++++++++------------------------ csrc/ops.cu | 3 ++- 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index dffd40cef..400211739 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3041,7 +3041,7 @@ template __device__ inline void vector_l } } -#define WARPS 1 +#define WARPS 2 template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) { @@ -3062,10 +3062,11 @@ template __global__ void gemm_device(int M, const int a_tile_offset = 32*16 + 16; const int b_tile_offset = 16*8 + 16; + const int c_tile_offset = 32*8 + 24; - __shared__ T smem_A[WARPS*32*16*2 + (16*1)]; - __shared__ T smem_B[WARPS*16*8*2 + (16*1)]; - __shared__ T smem_C[WARPS*32*8]; + __shared__ T smem_A[WARPS*32*16*2 + (16*(WARPS-1))]; + __shared__ T smem_B[WARPS*16*8*2 + (16*(WARPS-1))]; + __shared__ T smem_C[WARPS*32*8 + (24*(WARPS-1))]; wmma::fragment a_frag; wmma::fragment b_frag; @@ -3092,46 +3093,45 @@ template __global__ void gemm_device(int M, //int block_idx = 0; //for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x) - for(int base_idx = 0; base_idx < K; base_idx+=32) + for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x) { int idx = base_idx + threadIdx.x; - if(idx >= K) - { - smem_A[threadIdx.x] = 0.0f; - //smem_B[threadIdx.x] = 0.0f; - } - else - { - - smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = A[idx]; + if(idx >= K) + { + smem_A[threadIdx.x] = 0.0f; + //smem_B[threadIdx.x] = 0.0f; + } + else + { + smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = A[idx]; - for(int col = 0; col < 8; col++) - smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = B[(col_offset+col)*ldb+idx]; - } + for(int col = 0; col < 8; col++) + smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = B[(col_offset+col)*ldb+idx]; + } __syncthreads(); wmma::load_matrix_sync(a_frag, &(smem_A[0]), 16); // 111 mu wmma::load_matrix_sync(b_frag, &(smem_B[0]), 16); // 35 mu - wmma::load_matrix_sync(a2_frag, &(smem_A[a_tile_offset]), 16); // 111 mu - wmma::load_matrix_sync(b2_frag, &(smem_B[b_tile_offset]), 16); // 35 mu + wmma::load_matrix_sync(a2_frag, &(smem_A[half_warp_id*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b2_frag, &(smem_B[half_warp_id*b_tile_offset]), 16); // 35 mu wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); wmma::mma_sync(c_frag, a2_frag, b2_frag, c_frag); } // 129 mu - wmma::store_matrix_sync(&(smem_C[0]), c_frag, 8, wmma::mem_row_major); + wmma::store_matrix_sync(&(smem_C[half_warp_id*c_tile_offset]), c_frag, 8, wmma::mem_row_major); __syncthreads(); //if(threadIdx.x >= 16){ return; } //printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]); //if(threadIdx.x < 32) - //if(warp_lane < 8 && warp_id > 0) - // //local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)]; - // atomicAdd(&(smem_C[warp_lane]), smem_C[warp_lane + (warp_id*32*8)]); - //__syncthreads(); + if(half_warp_lane < 8 && half_warp_id > 0) + //local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)]; + atomicAdd(&(smem_C[half_warp_lane]), smem_C[half_warp_lane + (half_warp_id*c_tile_offset)]); + __syncthreads(); //local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum()); //if(threadIdx.x == 0) diff --git a/csrc/ops.cu b/csrc/ops.cu index 5c4f9c042..57d5cca0b 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -693,7 +693,8 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out //gemm_device<<< num_blocks, 32, 0, 0 
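
With two warps, each half-warp stages its own padded A/B tile, and the partial output tiles are folded into the first tile with a shared-memory atomicAdd before the final store. A stripped-down float sketch of that combine step follows (warp count, tile width and names are placeholders, not the kernel's actual values):

// Each warp has written a TILE_N-wide slice of partial results into its own
// region of shared memory; warps 1..NUM_WARPS-1 fold their slice into warp 0's
// region with atomicAdd, then the first TILE_N threads write the combined row out.
__global__ void combine_warp_partials(const float* partials, float* out)
{
    const int TILE_N = 8, NUM_WARPS = 2;
    __shared__ float smem_C[NUM_WARPS * TILE_N];

    int warp = threadIdx.x / 32;
    int lane = threadIdx.x % 32;

    // stand-in for the per-warp WMMA result: each warp just copies its slice in
    if (lane < TILE_N)
        smem_C[warp * TILE_N + lane] = partials[warp * TILE_N + lane];
    __syncthreads();

    if (warp > 0 && lane < TILE_N)
        atomicAdd(&smem_C[lane], smem_C[warp * TILE_N + lane]);
    __syncthreads();

    if (threadIdx.x < TILE_N)
        out[threadIdx.x] = smem_C[threadIdx.x];
}

// launch: combine_warp_partials<<<1, 64>>>(d_partials, d_out);  // 64 threads = 2 warps
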
>>>(m, n, k, A, B, out, lda, ldb, ldc); if(bits == 16) //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); } template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) From cabcd9b9d5c986b5c3c58318f9c1185ea8d8eff5 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 30 Apr 2023 19:12:42 -0700 Subject: [PATCH 42/63] Halved shared memory 466. --- csrc/kernels.cu | 70 ++++++++++++++++++++++++++++++------------------- 1 file changed, 43 insertions(+), 27 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 400211739..301221c06 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3053,25 +3053,23 @@ template __global__ void gemm_device(int M, //// Allocate shared memory for BlockReduce //__shared__ typename BlockReduce::TempStorage reduce; int col_offset = blockIdx.x *8; + const int warp_id = threadIdx.x / 32; const int half_warp_id = threadIdx.x / 16; const int half_warp_lane = threadIdx.x % 16; - T local_A[64/BITS]; - T local_B[64/BITS]; - T local_C[8]; + T local_A[1]; + T local_B[8]; const int a_tile_offset = 32*16 + 16; const int b_tile_offset = 16*8 + 16; const int c_tile_offset = 32*8 + 24; - __shared__ T smem_A[WARPS*32*16*2 + (16*(WARPS-1))]; - __shared__ T smem_B[WARPS*16*8*2 + (16*(WARPS-1))]; + __shared__ T smem_A[WARPS*32*16 + (16*(WARPS-1))]; + __shared__ T smem_B[WARPS*16*8 + (16*(WARPS-1))]; __shared__ T smem_C[WARPS*32*8 + (24*(WARPS-1))]; wmma::fragment a_frag; wmma::fragment b_frag; - wmma::fragment a2_frag; - wmma::fragment b2_frag; wmma::fragment c_frag; wmma::fill_fragment(c_frag, 0.0f); @@ -3087,9 +3085,9 @@ template __global__ void gemm_device(int M, smem_C[i] = T(0); __syncthreads(); - #pragma unroll 8 - for(int k = 0; k < 8; k++) - local_C[k] = T(0); + //#pragma unroll 8 + //for(int k = 0; k < 8; k++) + //local_C[k] = T(0); //int block_idx = 0; //for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x) @@ -3097,27 +3095,45 @@ template __global__ void gemm_device(int M, { int idx = base_idx + threadIdx.x; - if(idx >= K) + for(int k = 0; k < 2; k++) { - smem_A[threadIdx.x] = 0.0f; - //smem_B[threadIdx.x] = 0.0f; - } - else - { - smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = A[idx]; + if(k == 0) + { + if(idx < K) + { + local_A[0] = A[idx]; - for(int col = 0; col < 8; col++) - smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = B[(col_offset+col)*ldb+idx]; - } + #pragma unroll 8 + for(int col = 0; col < 8; col++) + local_B[col] = B[(col_offset+col)*ldb+idx]; + } - __syncthreads(); + } + + if(idx >= K) + { + smem_A[threadIdx.x] = 0.0f; + //smem_B[threadIdx.x] = 0.0f; + } + else + { + if((k == 0 && half_warp_id % 2 == 0) || + (k == 1 && half_warp_id % 2 == 1)) + { + smem_A[half_warp_lane + (warp_id*a_tile_offset)] = local_A[0]; - wmma::load_matrix_sync(a_frag, &(smem_A[0]), 16); // 111 mu - wmma::load_matrix_sync(b_frag, &(smem_B[0]), 16); // 35 mu - wmma::load_matrix_sync(a2_frag, &(smem_A[half_warp_id*a_tile_offset]), 16); // 111 mu - wmma::load_matrix_sync(b2_frag, &(smem_B[half_warp_id*b_tile_offset]), 16); // 35 mu - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - wmma::mma_sync(c_frag, a2_frag, b2_frag, c_frag); + #pragma unroll 8 + for(int col = 0; col < 8; col++) + 
smem_B[half_warp_lane + (warp_id*b_tile_offset) + (col*16)] = local_B[col]; + } + } + + __syncthreads(); + + wmma::load_matrix_sync(a_frag, &(smem_A[warp_id*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[warp_id*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } } // 129 mu From 7cc8ff4727e9e1094937b59aef96777c4818ae8a Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 1 May 2023 08:21:12 -0700 Subject: [PATCH 43/63] Warp specalization 362. --- csrc/kernels.cu | 100 +++++++++++++++++++++------------------ csrc/ops.cu | 5 +- tests/test_functional.py | 6 +-- 3 files changed, 60 insertions(+), 51 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 301221c06..2c0737d0f 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3041,7 +3041,7 @@ template __device__ inline void vector_l } } -#define WARPS 2 +#define WARPS 4 template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) { @@ -3056,17 +3056,18 @@ template __global__ void gemm_device(int M, const int warp_id = threadIdx.x / 32; const int half_warp_id = threadIdx.x / 16; const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; T local_A[1]; T local_B[8]; - const int a_tile_offset = 32*16 + 16; - const int b_tile_offset = 16*8 + 16; + const int a_tile_offset = (32*16 + 16); + const int b_tile_offset = (16*8 + 16); const int c_tile_offset = 32*8 + 24; - __shared__ T smem_A[WARPS*32*16 + (16*(WARPS-1))]; - __shared__ T smem_B[WARPS*16*8 + (16*(WARPS-1))]; - __shared__ T smem_C[WARPS*32*8 + (24*(WARPS-1))]; + __shared__ T smem_A[2*batch_size_warps*32*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*8 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_C[32*8]; wmma::fragment a_frag; wmma::fragment b_frag; @@ -3091,63 +3092,68 @@ template __global__ void gemm_device(int M, //int block_idx = 0; //for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x) - for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x) + int ticktock = 0; + int idx = 0 + threadIdx.x; + // prefetch + if(idx < K && warp_id < (WARPS-1)) { - int idx = base_idx + threadIdx.x; + local_A[0] = A[idx]; - for(int k = 0; k < 2; k++) - { - if(k == 0) - { - if(idx < K) - { - local_A[0] = A[idx]; + #pragma unroll 8 + for(int col = 0; col < 8; col++) + local_B[col] = B[(col_offset+col)*ldb+idx]; - #pragma unroll 8 - for(int col = 0; col < 8; col++) - local_B[col] = B[(col_offset+col)*ldb+idx]; - } + smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = local_A[0]; - } - - if(idx >= K) - { - smem_A[threadIdx.x] = 0.0f; - //smem_B[threadIdx.x] = 0.0f; - } - else - { - if((k == 0 && half_warp_id % 2 == 0) || - (k == 1 && half_warp_id % 2 == 1)) - { - smem_A[half_warp_lane + (warp_id*a_tile_offset)] = local_A[0]; + #pragma unroll 8 + for(int col = 0; col < 8; col++) + smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = local_B[col]; + } + ticktock = ticktock == 0 ? 
1 : 0; - #pragma unroll 8 - for(int col = 0; col < 8; col++) - smem_B[half_warp_lane + (warp_id*b_tile_offset) + (col*16)] = local_B[col]; - } - } + for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; - __syncthreads(); + __syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + local_A[0] = A[idx]; - wmma::load_matrix_sync(a_frag, &(smem_A[warp_id*a_tile_offset]), 16); // 111 mu - wmma::load_matrix_sync(b_frag, &(smem_B[warp_id*b_tile_offset]), 16); // 35 mu - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + #pragma unroll 8 + for(int col = 0; col < 8; col++) + local_B[col] = B[(col_offset+col)*ldb+idx]; + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 8 + for(int col = 0; col < 8; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; } + ticktock = ticktock == 0 ? 1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } } // 129 mu - wmma::store_matrix_sync(&(smem_C[half_warp_id*c_tile_offset]), c_frag, 8, wmma::mem_row_major); + if(warp_id == (WARPS-1)) + wmma::store_matrix_sync(&(smem_C[0]), c_frag, 8, wmma::mem_row_major); __syncthreads(); //if(threadIdx.x >= 16){ return; } //printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]); //if(threadIdx.x < 32) - if(half_warp_lane < 8 && half_warp_id > 0) - //local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)]; - atomicAdd(&(smem_C[half_warp_lane]), smem_C[half_warp_lane + (half_warp_id*c_tile_offset)]); - __syncthreads(); + //if(half_warp_lane < 8 && half_warp_id > 0) + // //local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)]; + // atomicAdd(&(smem_C[half_warp_lane]), smem_C[half_warp_lane + (half_warp_id*c_tile_offset)]); + //__syncthreads(); //local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum()); //if(threadIdx.x == 0) @@ -3463,6 +3469,7 @@ __global__ void with_staging_unified(float const* global_in, float * global_out, // these are not used and make no sense, but the compiler needs them //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); @@ -3470,6 +3477,7 @@ template __global__ void gemm_device(int M, int N, int K, half * _ // these are not used and make no sense, but the compiler needs them //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void 
gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); diff --git a/csrc/ops.cu b/csrc/ops.cu index 57d5cca0b..c1c27b8b1 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -692,9 +692,10 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); if(bits == 16) - //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); } template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) diff --git a/tests/test_functional.py b/tests/test_functional.py index f31e9b4d7..5f90f693c 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2370,10 +2370,10 @@ def test_cutlass3_gemm(dtype): C1 = torch.matmul(A, B.t()) C2 = F.cutlass3_gemm(A, B.t()) - #print(C1) - #print(C2) + print(C1) + print(C2) - torch.testing.assert_close(C1, C2, atol=1e-05, rtol=0.05) + torch.testing.assert_close(C1, C2, atol=1e-05, rtol=0.06) #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) From 3d4a2eadd3c1481447b8e885018ed24341ea91a5 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 1 May 2023 16:23:45 -0700 Subject: [PATCH 44/63] 16x16 240. 
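
This moves the block tile to the native 16x16x16 wmma shape, so a single warp's tensor-core op now covers the whole per-block output slice. The bare fragment choreography the kernel builds on looks like the standalone sketch below (illustrative only: the names are made up, the accumulator is fp32 here for clarity while the kernel accumulates in fp16, and an sm_70+ device plus a one-warp launch are assumed):

    #include <mma.h>
    #include <cuda_fp16.h>
    using namespace nvcuda;

    // One warp computes C(16x16) += A(16x16) * B(16x16) on the tensor cores.
    // Launch with a single warp (32 threads); A is row-major, B is col-major.
    __global__ void tile_16x16x16(const half *A, const half *B, float *C)
    {
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
        wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;

        wmma::fill_fragment(c_frag, 0.0f);
        wmma::load_matrix_sync(a_frag, A, 16);                 // leading dimension 16
        wmma::load_matrix_sync(b_frag, B, 16);
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
        wmma::store_matrix_sync(C, c_frag, 16, wmma::mem_row_major);
    }
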
--- csrc/kernels.cu | 52 ++++++++++++++++++++++++------------------------- csrc/ops.cu | 2 +- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 2c0737d0f..4e3a4a3d6 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3052,37 +3052,37 @@ template __global__ void gemm_device(int M, //typedef cub::BlockReduce BlockReduce; //// Allocate shared memory for BlockReduce //__shared__ typename BlockReduce::TempStorage reduce; - int col_offset = blockIdx.x *8; + int col_offset = blockIdx.x *16; const int warp_id = threadIdx.x / 32; const int half_warp_id = threadIdx.x / 16; const int half_warp_lane = threadIdx.x % 16; const int batch_size_warps = (WARPS-1)*2; T local_A[1]; - T local_B[8]; + T local_B[16]; - const int a_tile_offset = (32*16 + 16); - const int b_tile_offset = (16*8 + 16); - const int c_tile_offset = 32*8 + 24; + const int a_tile_offset = (16*16 + 16); + const int b_tile_offset = (16*16 + 16); + const int c_tile_offset = 16*16 + 24; - __shared__ T smem_A[2*batch_size_warps*32*16 + (2*16*(batch_size_warps-1))]; - __shared__ T smem_B[2*batch_size_warps*16*8 + (2*16*(batch_size_warps-1))]; - __shared__ T smem_C[32*8]; + __shared__ T smem_A[2*batch_size_warps*16*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_C[16*16]; - wmma::fragment a_frag; - wmma::fragment b_frag; - wmma::fragment c_frag; + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; wmma::fill_fragment(c_frag, 0.0f); - for(int i = threadIdx.x; i < 32*16*WARPS; i+=blockDim.x) - smem_A[i] = T(0); + //for(int i = threadIdx.x; i < 16*16*WARPS; i+=blockDim.x) + // smem_A[i] = T(0); - for(int i = threadIdx.x; i < 32*8*WARPS; i+=blockDim.x) - smem_B[i] = T(0); + //for(int i = threadIdx.x; i < 16*16*WARPS; i+=blockDim.x) + // smem_B[i] = T(0); - for(int i = threadIdx.x; i < 32*8*WARPS; i+=blockDim.x) + for(int i = threadIdx.x; i < 16*16; i+=blockDim.x) smem_C[i] = T(0); __syncthreads(); @@ -3099,14 +3099,14 @@ template __global__ void gemm_device(int M, { local_A[0] = A[idx]; - #pragma unroll 8 - for(int col = 0; col < 8; col++) + #pragma unroll 16 + for(int col = 0; col < 16; col++) local_B[col] = B[(col_offset+col)*ldb+idx]; smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = local_A[0]; - #pragma unroll 8 - for(int col = 0; col < 8; col++) + #pragma unroll 16 + for(int col = 0; col < 16; col++) smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = local_B[col]; } ticktock = ticktock == 0 ? 1 : 0; @@ -3120,14 +3120,14 @@ template __global__ void gemm_device(int M, { local_A[0] = A[idx]; - #pragma unroll 8 - for(int col = 0; col < 8; col++) + #pragma unroll 16 + for(int col = 0; col < 16; col++) local_B[col] = B[(col_offset+col)*ldb+idx]; smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; - #pragma unroll 8 - for(int col = 0; col < 8; col++) + #pragma unroll 16 + for(int col = 0; col < 16; col++) smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; } ticktock = ticktock == 0 ? 
1 : 0; @@ -3143,7 +3143,7 @@ template __global__ void gemm_device(int M, // 129 mu if(warp_id == (WARPS-1)) - wmma::store_matrix_sync(&(smem_C[0]), c_frag, 8, wmma::mem_row_major); + wmma::store_matrix_sync(&(smem_C[0]), c_frag, 16, wmma::mem_row_major); __syncthreads(); //if(threadIdx.x >= 16){ return; } @@ -3185,7 +3185,7 @@ template __global__ void gemm_device(int M, //if(threadIdx.x < 8 && col_offset + threadIdx.x < M) //out[col_offset + threadIdx.x ] = smem_C[threadIdx.x]; - if(threadIdx.x < 8 && col_offset + threadIdx.x < M) + if(threadIdx.x < 16 && col_offset + threadIdx.x < M) out[col_offset + threadIdx.x] = smem_C[threadIdx.x]; } diff --git a/csrc/ops.cu b/csrc/ops.cu index c1c27b8b1..d0e903fe8 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -678,7 +678,7 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size) template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) { - int num_blocks = (m+7)/8; + int num_blocks = (m+15)/16; cout << num_blocks << endl; cout << lda << endl; From 7bfa09d0fcaa524863bcc8ea71436f99423bbd3f Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 1 May 2023 16:38:09 -0700 Subject: [PATCH 45/63] 8x32 240 6 warps. --- csrc/kernels.cu | 50 ++++++++++++++++++++++++++----------------------- csrc/ops.cu | 6 ++++-- 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 4e3a4a3d6..b03c6ca6f 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3041,7 +3041,7 @@ template __device__ inline void vector_l } } -#define WARPS 4 +#define WARPS 6 template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) { @@ -3052,26 +3052,26 @@ template __global__ void gemm_device(int M, //typedef cub::BlockReduce BlockReduce; //// Allocate shared memory for BlockReduce //__shared__ typename BlockReduce::TempStorage reduce; - int col_offset = blockIdx.x *16; + int col_offset = blockIdx.x *32; const int warp_id = threadIdx.x / 32; const int half_warp_id = threadIdx.x / 16; const int half_warp_lane = threadIdx.x % 16; const int batch_size_warps = (WARPS-1)*2; T local_A[1]; - T local_B[16]; + T local_B[32]; - const int a_tile_offset = (16*16 + 16); - const int b_tile_offset = (16*16 + 16); - const int c_tile_offset = 16*16 + 24; + const int a_tile_offset = (8*16 + 16); + const int b_tile_offset = (16*32 + 16); + const int c_tile_offset = 8*32 + 24; - __shared__ T smem_A[2*batch_size_warps*16*16 + (2*16*(batch_size_warps-1))]; - __shared__ T smem_B[2*batch_size_warps*16*16 + (2*16*(batch_size_warps-1))]; - __shared__ T smem_C[16*16]; + __shared__ T smem_A[2*batch_size_warps*8*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_C[8*32]; - wmma::fragment a_frag; - wmma::fragment b_frag; - wmma::fragment c_frag; + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; wmma::fill_fragment(c_frag, 0.0f); @@ -3082,7 +3082,7 @@ template __global__ void gemm_device(int M, //for(int i = threadIdx.x; i < 16*16*WARPS; i+=blockDim.x) // smem_B[i] = T(0); - for(int i = threadIdx.x; i < 16*16; i+=blockDim.x) + for(int i = threadIdx.x; i < 8*32; i+=blockDim.x) smem_C[i] = T(0); __syncthreads(); @@ -3099,14 +3099,14 @@ template __global__ void gemm_device(int M, { local_A[0] = A[idx]; - #pragma unroll 16 - for(int col = 0; col < 16; col++) + #pragma unroll 32 + for(int col = 0; col < 32; col++) local_B[col] = 
B[(col_offset+col)*ldb+idx]; smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = local_A[0]; - #pragma unroll 16 - for(int col = 0; col < 16; col++) + #pragma unroll 32 + for(int col = 0; col < 32; col++) smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = local_B[col]; } ticktock = ticktock == 0 ? 1 : 0; @@ -3120,14 +3120,14 @@ template __global__ void gemm_device(int M, { local_A[0] = A[idx]; - #pragma unroll 16 - for(int col = 0; col < 16; col++) + #pragma unroll 32 + for(int col = 0; col < 32; col++) local_B[col] = B[(col_offset+col)*ldb+idx]; smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; - #pragma unroll 16 - for(int col = 0; col < 16; col++) + #pragma unroll 32 + for(int col = 0; col < 32; col++) smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; } ticktock = ticktock == 0 ? 1 : 0; @@ -3143,7 +3143,7 @@ template __global__ void gemm_device(int M, // 129 mu if(warp_id == (WARPS-1)) - wmma::store_matrix_sync(&(smem_C[0]), c_frag, 16, wmma::mem_row_major); + wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major); __syncthreads(); //if(threadIdx.x >= 16){ return; } @@ -3185,7 +3185,7 @@ template __global__ void gemm_device(int M, //if(threadIdx.x < 8 && col_offset + threadIdx.x < M) //out[col_offset + threadIdx.x ] = smem_C[threadIdx.x]; - if(threadIdx.x < 16 && col_offset + threadIdx.x < M) + if(threadIdx.x < 32 && col_offset + threadIdx.x < M) out[col_offset + threadIdx.x] = smem_C[threadIdx.x]; } @@ -3470,18 +3470,22 @@ __global__ void with_staging_unified(float const* global_in, float * global_out, // these are not used and make no sense, but the compiler needs them //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); // these are not used and make no sense, but the compiler needs them //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, 
float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); diff --git a/csrc/ops.cu b/csrc/ops.cu index d0e903fe8..2ccb4182b 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -678,7 +678,7 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size) template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) { - int num_blocks = (m+15)/16; + int num_blocks = (m+31)/32; cout << num_blocks << endl; cout << lda << endl; @@ -693,7 +693,9 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); if(bits == 16) //gemm_device<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + gemm_device<<< num_blocks, 192, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); } From f9bfea8f2335a63fbb7b24175e1fa2951ee55bf1 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 2 May 2023 07:24:12 -0700 Subject: [PATCH 46/63] Baseline for debugging. 
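
Two things happen in this baseline: producer warps now write an explicit zero into every shared-memory slot whose index falls past K, so a ragged tail feeds zeros rather than stale values into the mma, and the test compares errors statistically because tensor-core reductions are not bit-deterministic. The padding idea in isolation (single buffer, no wmma, illustrative names):

    #include <cuda_fp16.h>

    #define TILE 16

    // Every slot of the staged tile is written each iteration: in-range threads
    // store a real element, out-of-range threads store zero, so a ragged K never
    // leaks old data into the accumulation.
    // Launch with one block of TILE threads; out has TILE entries.
    __global__ void consume_zero_padded(const half *A, float *out, int K)
    {
        __shared__ half tile[TILE];
        float acc = 0.0f;

        for (int base = 0; base < K; base += TILE) {
            int idx = base + threadIdx.x;
            tile[threadIdx.x] = (idx < K) ? A[idx] : __float2half(0.0f);
            __syncthreads();
            acc += __half2float(tile[threadIdx.x]);    // stand-in for the wmma consume
            __syncthreads();
        }
        out[threadIdx.x] = acc;
    }
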
--- bitsandbytes/functional.py | 2 +- csrc/kernels.cu | 31 ++++++++++++++++++++++++++++--- csrc/ops.cu | 16 ++++++++-------- tests/test_functional.py | 36 +++++++++++++++++++++++++++++------- 4 files changed, 66 insertions(+), 19 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index f725c1c63..b4cbd28d8 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1467,7 +1467,7 @@ def cutlass3_gemm( lda = Bshape[1] ldc = Bshape[0] ldb = (ldb+1)//2 - print(m, n, k, lda, ldb, ldc) + #print(m, n, k, lda, ldb, ldc) is_on_gpu([B, A, out]) m = ct.c_int32(m) n = ct.c_int32(n) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index b03c6ca6f..477904cf2 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3061,9 +3061,8 @@ template __global__ void gemm_device(int M, T local_A[1]; T local_B[32]; - const int a_tile_offset = (8*16 + 16); - const int b_tile_offset = (16*32 + 16); - const int c_tile_offset = 8*32 + 24; + const int a_tile_offset = (8*16); + const int b_tile_offset = (16*32); __shared__ T smem_A[2*batch_size_warps*8*16 + (2*16*(batch_size_warps-1))]; __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; @@ -3109,6 +3108,19 @@ template __global__ void gemm_device(int M, for(int col = 0; col < 32; col++) smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = local_B[col]; } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = T(0.0); + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = T(0.0f); + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = T(0.0f); + } ticktock = ticktock == 0 ? 1 : 0; for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x-32) @@ -3130,6 +3142,19 @@ template __global__ void gemm_device(int M, for(int col = 0; col < 32; col++) smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } ticktock = ticktock == 0 ? 
1 : 0; if(warp_id == (WARPS-1)) diff --git a/csrc/ops.cu b/csrc/ops.cu index 2ccb4182b..6bf1e89c9 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -680,14 +680,14 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out int num_blocks = (m+31)/32; - cout << num_blocks << endl; - cout << lda << endl; - cout << ldb << endl; - cout << ldc << endl; - - cout << m << endl; - cout << n << endl; - cout << k << endl; + //cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; + + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; //if(bits == 32) //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); diff --git a/tests/test_functional.py b/tests/test_functional.py index 5f90f693c..25fbb5ba5 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2355,25 +2355,47 @@ def test_normal_map_tree(): #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) def test_cutlass3_gemm(dtype): - for i in range(1): + for i in range(100): #A = torch.rand(2, 4092, dtype=dtype, device='cuda') #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') #A = torch.rand(1, 4096, dtype=dtype, device='cuda') #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') - A = torch.rand(1, 4096, dtype=dtype, device='cuda') - B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') + A = torch.randn(1, 128+32, dtype=dtype, device='cuda') + B = torch.randn(4096, 128+32, dtype=dtype, device='cuda')/math.sqrt(128) #print('') #print(A) #print(B.t()) + #A[:, :-3] = 0 + #B[:, :-3] = 0 C1 = torch.matmul(A, B.t()) C2 = F.cutlass3_gemm(A, B.t()) - print(C1) - print(C2) - - torch.testing.assert_close(C1, C2, atol=1e-05, rtol=0.06) + err = C1-C2 + + # tensor cores are non-deterministic + # so we need to analyze errors around the mean + # to test our implementation + err = torch.abs(err.mean()).item() + mag = torch.abs(C1).mean() + relerr = err/mag + + if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: + print('') + print(i, err, mag.item(), relerr.item()) + print(A.flatten()[-6:]) + print(B.flatten()[-6:]) + out = A.flatten()[-6:]*B.flatten()[-6:] + print(out) + print(out[:-1].sum()) + print('='*80) + print(C1.flatten()[-6:]) + print(C2.flatten()[-6:]) + #assert False, 'ERROR' + + c = int(C1.numel()*0.001) + assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c) #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) From 9192c9de648338dd9281368ed0bff20dc123490b Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 2 May 2023 07:50:32 -0700 Subject: [PATCH 47/63] Tighter and scaled error analysis. --- csrc/kernels.cu | 15 ++++++- tests/test_functional.py | 97 +++++++++++++++++++++++----------------- 2 files changed, 70 insertions(+), 42 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 477904cf2..2fa288f1d 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3123,6 +3123,7 @@ template __global__ void gemm_device(int M, } ticktock = ticktock == 0 ? 
1 : 0; + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x-32) { idx = base_idx + threadIdx.x; @@ -3155,8 +3156,9 @@ template __global__ void gemm_device(int M, for(int col = 0; col < 32; col++) smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; } - ticktock = ticktock == 0 ? 1 : 0; + //ticktock = ticktock == 0 ? 1 : 0; + __syncthreads(); if(warp_id == (WARPS-1)) for(int k = 0; k < batch_size_warps; k++) { @@ -3166,11 +3168,22 @@ template __global__ void gemm_device(int M, } } + //__syncthreads(); + //if(warp_id == (WARPS-1)) + // for(int k = 0; k < batch_size_warps; k++) + // { + // wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + // wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + // wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + // } + __syncthreads(); + // 129 mu if(warp_id == (WARPS-1)) wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major); __syncthreads(); + //if(threadIdx.x >= 16){ return; } //printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]); diff --git a/tests/test_functional.py b/tests/test_functional.py index 25fbb5ba5..050098479 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2355,47 +2355,62 @@ def test_normal_map_tree(): #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) def test_cutlass3_gemm(dtype): - for i in range(100): - #A = torch.rand(2, 4092, dtype=dtype, device='cuda') - #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') - #A = torch.rand(1, 4096, dtype=dtype, device='cuda') - #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') - A = torch.randn(1, 128+32, dtype=dtype, device='cuda') - B = torch.randn(4096, 128+32, dtype=dtype, device='cuda')/math.sqrt(128) - - #print('') - #print(A) - #print(B.t()) - #A[:, :-3] = 0 - #B[:, :-3] = 0 - - - C1 = torch.matmul(A, B.t()) - C2 = F.cutlass3_gemm(A, B.t()) - err = C1-C2 - - # tensor cores are non-deterministic - # so we need to analyze errors around the mean - # to test our implementation - err = torch.abs(err.mean()).item() - mag = torch.abs(C1).mean() - relerr = err/mag - - if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: - print('') - print(i, err, mag.item(), relerr.item()) - print(A.flatten()[-6:]) - print(B.flatten()[-6:]) - out = A.flatten()[-6:]*B.flatten()[-6:] - print(out) - print(out[:-1].sum()) - print('='*80) - print(C1.flatten()[-6:]) - print(C2.flatten()[-6:]) - #assert False, 'ERROR' - - c = int(C1.numel()*0.001) - assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c) + for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: + errs = [] + relerrs = [] + max_err = 0 + max_relerr = 0 + for i in range(100): + #A = torch.rand(2, 4092, dtype=dtype, device='cuda') + #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') + #A = torch.rand(1, 4096, dtype=dtype, device='cuda') + #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') + A = torch.randn(1, dim+0, dtype=dtype, device='cuda') + B = torch.randn(4*496, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim) + + #print('') + #print(A) + #print(B.t()) + #A[:, :-3] = 0 + #B[:, :-3] = 0 + + + C1 = torch.matmul(A, B.t()) + C2 = F.cutlass3_gemm(A, B.t()) + + # tensor cores are non-deterministic + # so we need to analyze errors around the mean + # to test 
our implementation + err = torch.abs(C1-C2) + mag = torch.abs(C1)+1e-8 + relerr = err/mag + max_err = max(err.max(), max_err) + max_relerr = max(relerr.max(), max_relerr) + err = err.mean().item() + relerr = relerr.mean().item() + + errs.append(err) + relerrs.append(relerr) + + #if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: + # print('') + # print(i, err, mag.item(), relerr.item()) + # print(A.flatten()[-6:]) + # print(B.flatten()[-6:]) + # out = A.flatten()[-6:]*B.flatten()[-6:] + # print(out) + # print(out[:-1].sum()) + # print('='*80) + # print(C1.flatten()[-6:]) + # print(C2.flatten()[-6:]) + # #assert False, 'ERROR' + + c = int(C1.numel()*0.00125*(dim/256))+1 + assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c) + print('') + print(dim, sum(errs)/len(errs)/math.sqrt(dim)) + print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim)) + print(dim, (max_err.item(), max_relerr.item())) #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) From 9aa232cc3918ef722791c2a6775aaa807ad72109 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 2 May 2023 07:53:29 -0700 Subject: [PATCH 48/63] Initial. --- tests/test_functional.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_functional.py b/tests/test_functional.py index 050098479..808c1ce63 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2406,6 +2406,7 @@ def test_cutlass3_gemm(dtype): # #assert False, 'ERROR' c = int(C1.numel()*0.00125*(dim/256))+1 + assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c) print('') print(dim, sum(errs)/len(errs)/math.sqrt(dim)) From 394749db718526aa7810333f0f90caa2b6af8554 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 2 May 2023 08:58:59 -0700 Subject: [PATCH 49/63] Correct implementation 240. 
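
The pipeline bookkeeping is what gets corrected here: the prefetch fills buffer 0, the main loop starts one stride in and always fills the buffer the consumer warp is not reading, ticktock flips exactly once per stage, and the last buffer is drained after the loop. The same double-buffering skeleton with the wmma and the warp specialization stripped out (CHUNK and all names are illustrative):

    #define CHUNK 256

    // "Ticktock" double buffering: while this stage's data is consumed, the next
    // chunk is loaded into the other buffer; one barrier per stage plus a final
    // drain. Launch with one block of CHUNK threads; *out must start at 0.
    __global__ void dot_double_buffered(const float *a, const float *b,
                                        float *out, int K)
    {
        __shared__ float bufA[2][CHUNK], bufB[2][CHUNK];
        int tick = 0;
        float acc = 0.0f;

        int idx = threadIdx.x;                                    // prefetch chunk 0
        bufA[tick][threadIdx.x] = (idx < K) ? a[idx] : 0.0f;
        bufB[tick][threadIdx.x] = (idx < K) ? b[idx] : 0.0f;

        for (int base = CHUNK; base < K; base += CHUNK) {
            __syncthreads();                                      // previous fill done
            int prev = tick;
            tick ^= 1;
            idx = base + threadIdx.x;
            bufA[tick][threadIdx.x] = (idx < K) ? a[idx] : 0.0f;  // fill next buffer
            bufB[tick][threadIdx.x] = (idx < K) ? b[idx] : 0.0f;
            acc += bufA[prev][threadIdx.x] * bufB[prev][threadIdx.x]; // consume previous
        }
        __syncthreads();                                          // last fill done
        acc += bufA[tick][threadIdx.x] * bufB[tick][threadIdx.x]; // drain
        atomicAdd(out, acc);
    }

In the kernel above the producers and the consumer are different warps, so the per-stage barrier also hands the tiles across warps, but the buffer rotation is identical.
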
--- csrc/kernels.cu | 48 +++++++++++++++------------------------- tests/test_functional.py | 20 +++++++++++------ 2 files changed, 31 insertions(+), 37 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 2fa288f1d..8ce881c32 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3061,8 +3061,8 @@ template __global__ void gemm_device(int M, T local_A[1]; T local_B[32]; - const int a_tile_offset = (8*16); - const int b_tile_offset = (16*32); + const int a_tile_offset = (8*16 + 16); + const int b_tile_offset = (16*32 + 16); __shared__ T smem_A[2*batch_size_warps*8*16 + (2*16*(batch_size_warps-1))]; __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; @@ -3074,23 +3074,10 @@ template __global__ void gemm_device(int M, wmma::fill_fragment(c_frag, 0.0f); - - //for(int i = threadIdx.x; i < 16*16*WARPS; i+=blockDim.x) - // smem_A[i] = T(0); - - //for(int i = threadIdx.x; i < 16*16*WARPS; i+=blockDim.x) - // smem_B[i] = T(0); - for(int i = threadIdx.x; i < 8*32; i+=blockDim.x) smem_C[i] = T(0); __syncthreads(); - //#pragma unroll 8 - //for(int k = 0; k < 8; k++) - //local_C[k] = T(0); - - //int block_idx = 0; - //for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x) int ticktock = 0; int idx = 0 + threadIdx.x; // prefetch @@ -3102,29 +3089,29 @@ template __global__ void gemm_device(int M, for(int col = 0; col < 32; col++) local_B[col] = B[(col_offset+col)*ldb+idx]; - smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = local_A[0]; + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; #pragma unroll 32 for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = local_B[col]; + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; } else if(warp_id < (WARPS-1)) { local_A[0] = T(0.0); - smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; #pragma unroll 32 for(int col = 0; col < 32; col++) - local_B[col] = T(0.0f); + local_B[col] = 0.0f; #pragma unroll 32 for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = T(0.0f); + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; } ticktock = ticktock == 0 ? 1 : 0; //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) - for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) { idx = base_idx + threadIdx.x; @@ -3156,7 +3143,7 @@ template __global__ void gemm_device(int M, for(int col = 0; col < 32; col++) smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; } - //ticktock = ticktock == 0 ? 1 : 0; + ticktock = ticktock == 0 ? 1 : 0; __syncthreads(); if(warp_id == (WARPS-1)) @@ -3168,14 +3155,15 @@ template __global__ void gemm_device(int M, } } - //__syncthreads(); - //if(warp_id == (WARPS-1)) - // for(int k = 0; k < batch_size_warps; k++) - // { - // wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu - // wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu - // wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - // } + __syncthreads(); + ticktock = ticktock == 0 ? 
1 : 0; + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } __syncthreads(); // 129 mu diff --git a/tests/test_functional.py b/tests/test_functional.py index 808c1ce63..4c86d83e1 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -18,12 +18,15 @@ k = 20 -def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0): +def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True): idx = torch.isclose(a, b, rtol, atol) sumval = (idx == 0).sum().item() if sumval > count: - print(f"Too many values not close: assert {sumval} < {count}") - torch.testing.assert_allclose(a, b, rtol, atol) + if throw: + print(f"Too many values not close: assert {sumval} < {count}") + torch.testing.assert_allclose(a, b, rtol, atol) + + return sumval class FFN(torch.nn.Module): @@ -2355,7 +2358,9 @@ def test_normal_map_tree(): #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) def test_cutlass3_gemm(dtype): - for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: + #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: + #for dim in [4096, 5120, 6656, 8192]: + for dim in [4096]: errs = [] relerrs = [] max_err = 0 @@ -2366,7 +2371,7 @@ def test_cutlass3_gemm(dtype): #A = torch.rand(1, 4096, dtype=dtype, device='cuda') #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') A = torch.randn(1, dim+0, dtype=dtype, device='cuda') - B = torch.randn(4*496, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim) + B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim) #print('') #print(A) @@ -2405,9 +2410,10 @@ def test_cutlass3_gemm(dtype): # print(C2.flatten()[-6:]) # #assert False, 'ERROR' - c = int(C1.numel()*0.00125*(dim/256))+1 + c = int(C1.numel()*0.0014*(dim/256))+1 - assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c) + c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False) + #print(c/math.sqrt(dim)) print('') print(dim, sum(errs)/len(errs)/math.sqrt(dim)) print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim)) From 4decb3cc6878a7d51e92dd5f48ec0fb25ec8ba19 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 2 May 2023 09:38:14 -0700 Subject: [PATCH 50/63] Removed uncessary sync. --- csrc/kernels.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 8ce881c32..d09f78a87 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3145,7 +3145,6 @@ template __global__ void gemm_device(int M, } ticktock = ticktock == 0 ? 1 : 0; - __syncthreads(); if(warp_id == (WARPS-1)) for(int k = 0; k < batch_size_warps; k++) { From 89cccd8196b885de777cc6f627bd05c96c700300 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 2 May 2023 09:40:31 -0700 Subject: [PATCH 51/63] A tile multi-tiling. 
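
Because A contributes a single row per block in this kernel, each pipeline stage only needs a 16-element slice of A, so the slices can be packed at a 32-element stride instead of one padded 8x16 tile per stage. Rough footprint arithmetic with the constants visible in the surrounding hunks (WARPS == 6 at this point, T == half); illustrative, not measured:

    #include <cstdio>

    int main()
    {
        const int warps = 6;
        const int batch = (warps - 1) * 2;            // batch_size_warps
        const int bytes = 2;                          // sizeof(half)

        long a_before = (2L * batch * 8 * 16 + 2 * 16 * (batch - 1)) * bytes;
        long a_after  = (8L * 16 + 4 * 16 * (batch - 1)) * bytes;
        long b_tiles  = (2L * batch * 16 * 32 + 2 * 16 * (batch - 1)) * bytes;

        printf("smem_A before: %ld B, after: %ld B, smem_B: %ld B\n",
               a_before, a_after, b_tiles);
        return 0;
    }
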
--- csrc/kernels.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index d09f78a87..a528d16de 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3061,10 +3061,10 @@ template __global__ void gemm_device(int M, T local_A[1]; T local_B[32]; - const int a_tile_offset = (8*16 + 16); + const int a_tile_offset = (16 + 16); const int b_tile_offset = (16*32 + 16); - __shared__ T smem_A[2*batch_size_warps*8*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_A[8*16 + (4*16*(batch_size_warps-1))]; __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; __shared__ T smem_C[8*32]; From 77f15fdce9f11324f6616e4fccc03d16f61347e6 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 2 May 2023 11:38:11 -0700 Subject: [PATCH 52/63] Shared memory efficient 240. --- csrc/kernels.cu | 80 ++++++++++------------------------------ csrc/ops.cu | 2 +- tests/test_functional.py | 4 +- 3 files changed, 22 insertions(+), 64 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index a528d16de..8b5544a1e 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3041,7 +3041,7 @@ template __device__ inline void vector_l } } -#define WARPS 6 +#define WARPS 5 template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) { @@ -3061,23 +3061,18 @@ template __global__ void gemm_device(int M, T local_A[1]; T local_B[32]; - const int a_tile_offset = (16 + 16); + const int a_tile_offset = 16; const int b_tile_offset = (16*32 + 16); - __shared__ T smem_A[8*16 + (4*16*(batch_size_warps-1))]; + __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; - __shared__ T smem_C[8*32]; + //__shared__ T smem_C[8*32]; wmma::fragment a_frag; wmma::fragment b_frag; wmma::fragment c_frag; - wmma::fill_fragment(c_frag, 0.0f); - for(int i = threadIdx.x; i < 8*32; i+=blockDim.x) - smem_C[i] = T(0); - __syncthreads(); - int ticktock = 0; int idx = 0 + threadIdx.x; // prefetch @@ -3155,63 +3150,24 @@ template __global__ void gemm_device(int M, } __syncthreads(); + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + ticktock = ticktock == 0 ? 
1 : 0; - if(warp_id == (WARPS-1)) - for(int k = 0; k < batch_size_warps; k++) - { - wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu - wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - } - __syncthreads(); + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } // 129 mu if(warp_id == (WARPS-1)) - wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major); - __syncthreads(); - + wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); - //if(threadIdx.x >= 16){ return; } - //printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]); - - //if(threadIdx.x < 32) - //if(half_warp_lane < 8 && half_warp_id > 0) - // //local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)]; - // atomicAdd(&(smem_C[half_warp_lane]), smem_C[half_warp_lane + (half_warp_id*c_tile_offset)]); - //__syncthreads(); - - //local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum()); - //if(threadIdx.x == 0) - // for(int row = 0; row < 32; row++) - // { - // printf("row %i ", row); - // for(int id = 0; id < 4; id++) - // { - // printf(" id %i: ", id); - // for(int k = 0; k < 8; k++) - // printf("%f ", (float)smem_C[k + (row*8) + (id*32*8)]); - // printf("\n"); - // } - // } - - //__syncthreads(); - - //if((float)local_C[0] !=0.0f) - // printf("%i %i %f\n", warp_lane, warp_id, (float)local_C[0]); - //local_C[0] = WarpReduce(temp_storage).Sum(local_C[0]); - - //__syncwarp(); - - ////for(int i = threadIdx.x; i < 32*8; i+=blockDim.x) - ////{ - // if((float)local_C[0] !=0.0f) - // printf("%i %f\n", 0, (float)local_C[0]); - //} - - //if(threadIdx.x < 8 && col_offset + threadIdx.x < M) - //out[col_offset + threadIdx.x ] = smem_C[threadIdx.x]; - if(threadIdx.x < 32 && col_offset + threadIdx.x < M) - out[col_offset + threadIdx.x] = smem_C[threadIdx.x]; + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_A[warp_lane]; } template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) @@ -3496,6 +3452,7 @@ __global__ void with_staging_unified(float const* global_in, float * global_out, //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); @@ -3506,6 
+3463,7 @@ template __global__ void gemm_device(int M, int N, int K, half * _ //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); diff --git a/csrc/ops.cu b/csrc/ops.cu index 6bf1e89c9..16d82f953 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -693,7 +693,7 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); if(bits == 16) //gemm_device<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - gemm_device<<< num_blocks, 192, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + gemm_device<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); diff --git a/tests/test_functional.py b/tests/test_functional.py index 4c86d83e1..62dd1cb76 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2358,9 +2358,9 @@ def test_normal_map_tree(): #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) def test_cutlass3_gemm(dtype): - #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: + for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: #for dim in [4096, 5120, 6656, 8192]: - for dim in [4096]: + #for dim in [4096]: errs = [] relerrs = [] max_err = 0 From 869b7e83b506cdb7e342e4939580104b486ed9ba Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 2 May 2023 12:10:32 -0700 Subject: [PATCH 53/63] Warp multi-specialization 240. 
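
Producer threads now fetch two A elements and two 32-wide B slivers per trip to global memory and release them to the pipeline one stage at a time, tracked by a small counter, so only every other stage issues loads. The register-staging pattern on its own (scalar toy, illustrative names; stride is the per-stage advance):

    // Each thread grabs two elements per global-memory trip and feeds them to
    // consecutive stages; `loaded` records whether a staged value is still pending.
    __global__ void staged_sum(const float *A, float *out, int K, int stride)
    {
        float staged[2];
        int   loaded = 0;
        float acc = 0.0f;

        for (int base = 0; base < K; base += stride) {
            int idx = base + threadIdx.x;
            if (idx >= K) break;
            float v;
            if (loaded == 0) {                                    // wide trip
                staged[0] = A[idx];
                staged[1] = (idx + stride < K) ? A[idx + stride] : 0.0f;
                loaded = 1;
                v = staged[0];
            } else {                                              // drain the staged value
                loaded -= 1;
                v = staged[1];
            }
            acc += v;                      // smem store + mma in the real kernel
        }
        atomicAdd(out, acc);
    }
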
--- csrc/kernels.cu | 62 +++++++++++++++++++++++++++++++++------- tests/test_functional.py | 8 +++--- 2 files changed, 56 insertions(+), 14 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 8b5544a1e..65ed19ecd 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3058,8 +3058,8 @@ template __global__ void gemm_device(int M, const int half_warp_lane = threadIdx.x % 16; const int batch_size_warps = (WARPS-1)*2; - T local_A[1]; - T local_B[32]; + T local_A[2]; + T local_B[64]; const int a_tile_offset = 16; const int b_tile_offset = (16*32 + 16); @@ -3075,14 +3075,32 @@ template __global__ void gemm_device(int M, int ticktock = 0; int idx = 0 + threadIdx.x; + int loaded_values = 0; // prefetch if(idx < K && warp_id < (WARPS-1)) { - local_A[0] = A[idx]; + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = B[(col_offset+col)*ldb+idx]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32]; + } + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+32]; + } smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; @@ -3113,11 +3131,35 @@ template __global__ void gemm_device(int M, __syncthreads(); if(idx < K && warp_id < (WARPS-1)) { - local_A[0] = A[idx]; + //local_A[0] = A[idx]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = B[(col_offset+col)*ldb+idx]; + //#pragma unroll 32 + //for(int col = 0; col < 32; col++) + // local_B[col] = B[(col_offset+col)*ldb+idx]; + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32]; + } + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+32]; + + + } smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; diff --git a/tests/test_functional.py b/tests/test_functional.py index 62dd1cb76..e9a67f5c9 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2376,8 +2376,8 @@ def test_cutlass3_gemm(dtype): #print('') #print(A) #print(B.t()) - #A[:, :-3] = 0 - #B[:, :-3] = 0 + #A[:, :-1] = 0 + #B[:, :-1] = 0 C1 = torch.matmul(A, B.t()) @@ -2399,7 +2399,7 @@ def test_cutlass3_gemm(dtype): #if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: # print('') - # print(i, err, mag.item(), relerr.item()) + # print(i, err, relerr) # print(A.flatten()[-6:]) # print(B.flatten()[-6:]) # out = A.flatten()[-6:]*B.flatten()[-6:] @@ -2412,7 +2412,7 @@ def test_cutlass3_gemm(dtype): c = int(C1.numel()*0.0014*(dim/256))+1 - c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False) + c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=True) #print(c/math.sqrt(dim)) print('') print(dim, sum(errs)/len(errs)/math.sqrt(dim)) From 264a948539d219e6b9a8fc8b9d92120d76b8878b Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 2 May 2023 16:15:38 -0700 Subject: [PATCH 54/63] 4-bit draft; 128 vector load 240. 
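
Two changes land together: the fp16 path widens the register staging to four values per trip, and kgemm_4bit_inference gets a first wmma-based draft in which B stays packed as two 4-bit codes per byte and is dequantized in registers (nibble, code value, per-block scale) before being staged. The unpack-and-scale step by itself, with a placeholder code table (the real FP4/NF4 tables live elsewhere in kernels.cu):

    #include <cuda_fp16.h>

    // Two 4-bit codes per byte: high nibble first, then low nibble, each mapped
    // through a 16-entry table and scaled by the per-block absmax.
    // The table values below are placeholders, not the actual NF4 codebook.
    __constant__ float kCode4bit[16] = {
        -1.00f, -0.70f, -0.53f, -0.39f, -0.28f, -0.18f, -0.09f, 0.00f,
         0.08f,  0.16f,  0.25f,  0.34f,  0.44f,  0.56f,  0.72f, 1.00f };

    __global__ void dequant_4bit_sketch(const unsigned char *B4, const float *absmax,
                                        half *B, int n, int blocksize)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;     // one thread per packed byte
        if (2 * i >= n) return;
        float scale = absmax[(2 * i) / blocksize];         // per-block scale factor
        B[2 * i] = __float2half(kCode4bit[B4[i] >> 4] * scale);
        if (2 * i + 1 < n)
            B[2 * i + 1] = __float2half(kCode4bit[B4[i] & 0x0F] * scale);
    }
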
--- bitsandbytes/functional.py | 6 +- csrc/kernels.cu | 295 ++++++++++++++++++++++++------------- csrc/ops.cu | 18 +-- tests/test_functional.py | 95 ++++++++---- 4 files changed, 278 insertions(+), 136 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index b4cbd28d8..e5b1bf7f1 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1385,10 +1385,12 @@ def cutlass3_gemm( #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) if state is None: Bshape = B.shape + bout = Bshape[1] else: Bshape = state[1] + bout = Bshape[0] if out is None: - out = torch.zeros(size=(A.shape[0], Bshape[1]), dtype=A.dtype, device=A.device) + out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device) sA = A.shape sB = B.shape @@ -1464,7 +1466,7 @@ def cutlass3_gemm( if state is not None: m = Bshape[0] k = Bshape[1] - lda = Bshape[1] + lda = Bshape[0] ldc = Bshape[0] ldb = (ldb+1)//2 #print(m, n, k, lda, ldb, ldc) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 65ed19ecd..2373b911a 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3044,22 +3044,15 @@ template __device__ inline void vector_l #define WARPS 5 template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) { - - typedef cub::WarpReduce WarpReduce; - // Allocate WarpReduce shared memory for one warp - //__shared__ typename WarpReduce::TempStorage temp_storage; - - //typedef cub::BlockReduce BlockReduce; - //// Allocate shared memory for BlockReduce - //__shared__ typename BlockReduce::TempStorage reduce; int col_offset = blockIdx.x *32; const int warp_id = threadIdx.x / 32; const int half_warp_id = threadIdx.x / 16; const int half_warp_lane = threadIdx.x % 16; const int batch_size_warps = (WARPS-1)*2; + const int val_per_iter = blockDim.x-32; - T local_A[2]; - T local_B[64]; + T local_A[4]; + T local_B[128]; const int a_tile_offset = 16; const int b_tile_offset = (16*32 + 16); @@ -3082,24 +3075,45 @@ template __global__ void gemm_device(int M, if(loaded_values == 0) { local_A[0] = A[idx]; - local_A[1] = A[idx+blockDim.x-32]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; #pragma unroll 32 for(int col = 0; col < 32; col++) { local_B[col] = B[(col_offset+col)*ldb+idx]; - local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; } - loaded_values = 1; + loaded_values = 3; } else { - local_A[0] = local_A[1]; - loaded_values--; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+32]; + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; } smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; @@ -3139,26 +3153,46 @@ template __global__ void gemm_device(int M, if(loaded_values == 0) { local_A[0] = A[idx]; - local_A[1] = A[idx+blockDim.x-32]; + 
local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; #pragma unroll 32 for(int col = 0; col < 32; col++) { local_B[col] = B[(col_offset+col)*ldb+idx]; - local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; } - loaded_values = 1; + loaded_values = 3; + } else { - local_A[0] = local_A[1]; - loaded_values--; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+32]; - + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; } smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; @@ -3215,104 +3249,166 @@ template __global__ void gemm_device(int M, template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) { - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage reduce; - int col_offset = blockIdx.x *8; - - T local_A[32]; - unsigned char local_B_4bit[16]; - T local_B[32]; - T local_C[8]; + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; - __shared__ T smem_C[8]; + T local_A[2]; + T local_B[64]; + unsigned char local_B_4bit[32]; - if(threadIdx.x < 8) - smem_C[threadIdx.x] = T(0); - __syncthreads(); + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); - #pragma unroll 8 - for(int k = 0; k < 8; k++) - local_C[k] = T(0); + __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + //__shared__ T smem_C[8*32]; + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0f); - for(int idx = threadIdx.x*32; idx < K; idx+=blockDim.x*32) + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; - // we load only 8 values per iteration from A, so we - // need to do 4 loads for every single load from B - // for B, we have packed values, so the 16 8-bit values - // turn into 32 4-bit values to 4x 4 loads turns into 4x 8 loads - vector_load(local_A, A, idx, idx, K); - vector_load(&(local_A[8]), A, idx+8, idx+8, K); - vector_load(&(local_A[16]), A, idx+16, idx+16, K); - vector_load(&(local_A[24]), A, idx+24, idx+24, K); + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; - for(int col = 0; col < 8; col++) + loaded_values = 1; + } + else { - if((col + col_offset) >= M){ break; } + local_A[0] = local_A[1]; + loaded_values--; - int offset_B = (col_offset+col)*ldb; - // 0111 -> 0.0f in NF4 - // since we have packed 
8-bits, we need cat(0b0111, 0b0111) = 0b01110111 - vector_load(local_B_4bit, B, (offset_B+idx+1)/2, (idx+1)/2, (K+1)/2, 0b01110111); + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); + local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); + } + } - int absidx = (idx + offset_B)/blocksize; - half local_absmax = __ldg(&(absmax[absidx])); - //for(int k = 0; k < 16; k++) - //printf("%i %i ", local_B_4bit[k] >> 4, local_B_4bit[k] & 0x0F); - //printf("\n"); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; - //vector_load(local_A, A, idx, idx, K); + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; - #pragma unroll 16 - for(int k = 0; k < 16; k++) + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + + __syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; - //if(local_B_4bit[k ] != 0b01110111) - //printf("(%i %i %i) %i -> %f, %i -> %f\n", threadIdx.x , k, K, local_B_4bit[k ] >> 4, dDequantizeNF4(local_B_4bit[k ] >> 4, local_absmax), - //local_B_4bit[k ] & 0x0F, dDequantizeNF4(local_B_4bit[k ] & 0x0F, local_absmax)); - //local_B[k*2] = d2DequantizeFP4(local_B_4bit[k] >> 4);//*local_absmax; - //local_B[k*2 + 1] = d2DequantizeFP4(local_B_4bit[k] & 0x0F);//*local_absmax; - local_B[k*2] = (half)(local_B_4bit[k] >> 4)*local_absmax; - local_B[k*2 + 1] = (half)(local_B_4bit[k] & 0x0F)*local_absmax; - //local_B[k*2] = (half)dDequantizeNF4(local_B_4bit[k ] >> 4);//*local_absmax; - //local_B[k*2 + 1] = (half)dDequantizeNF4(local_B_4bit[k ] & 0x0F);//*local_absmax; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx]; + } + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + int absidx = (idx + col_offset)/blocksize; + half local_absmax = __ldg(&(absmax[absidx])); + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); + local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); + } } + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + #pragma unroll 32 - //for(int k = 0; k < 8; k++) - for(int k = 0; k < 32; k++) + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + 
for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) { - local_C[col] += local_A[k]*local_B[k]; - //if((float)local_A[k] != 0.0 && (float)local_B[k] != 0.0) - //if((float)local_B[k] != 0.0) - //printf("%i %i %i %i %f*%f\n", threadIdx.x, k, col, (float)local_A[k], (float)local_B[k]); + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); } - } } - #pragma unroll 8 - for(int k = 0; k < 8; k++) - { - local_C[k] = BlockReduce(reduce).Reduce(local_C[k], cub::Sum()); - __syncthreads(); - } + __syncthreads(); + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; - if(threadIdx.x == 0) + ticktock = ticktock == 0 ? 1 : 0; + for(int k = 0; k < batch_size_warps; k++) { - #pragma unroll 8 - for(int k = 0; k < 8; k++) - smem_C[k] = local_C[k]; + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); } - else if(threadIdx.x >= 32) - // early return for unused warps - return; - - __syncwarp(); + // 129 mu + if(warp_id == (WARPS-1)) + wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); - if(threadIdx.x < 8 && col_offset + threadIdx.x < M) - out[col_offset + threadIdx.x ] = smem_C[threadIdx.x]; + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_A[warp_lane]; } //#define ROWS 2 @@ -3513,6 +3609,7 @@ template __global__ void gemm_device(int M, int N, int K, half * _ template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); //template __global__ void kMatmul_inference_4bit(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB); diff --git a/csrc/ops.cu b/csrc/ops.cu index 16d82f953..4d68436c4 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -703,17 +703,17 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) { - int num_blocks = (m+7)/8; + int num_blocks = (m+31)/32; - cout << num_blocks << endl; - cout << lda << endl; - cout << ldb << endl; - cout << ldc << endl; + //cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; - cout << m << endl; - cout << n << endl; - cout << k << endl; - kgemm_4bit_inference<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; + kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, 
ldb, ldc, blocksize); //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } diff --git a/tests/test_functional.py b/tests/test_functional.py index e9a67f5c9..dc4e40d94 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2358,20 +2358,19 @@ def test_normal_map_tree(): #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) def test_cutlass3_gemm(dtype): - for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: + debug = True + #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: #for dim in [4096, 5120, 6656, 8192]: - #for dim in [4096]: + for dim in [4096]: + #for dim in [128+1]: errs = [] relerrs = [] max_err = 0 max_relerr = 0 for i in range(100): - #A = torch.rand(2, 4092, dtype=dtype, device='cuda') - #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') - #A = torch.rand(1, 4096, dtype=dtype, device='cuda') - #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') - A = torch.randn(1, dim+0, dtype=dtype, device='cuda') + A = torch.randn(1, dim, dtype=dtype, device='cuda') B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim) + #B = torch.randn(1, dim, dtype=dtype, device='cuda')/math.sqrt(dim) #print('') #print(A) @@ -2397,7 +2396,7 @@ def test_cutlass3_gemm(dtype): errs.append(err) relerrs.append(relerr) - #if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: + #if not debug and err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: # print('') # print(i, err, relerr) # print(A.flatten()[-6:]) @@ -2412,7 +2411,7 @@ def test_cutlass3_gemm(dtype): c = int(C1.numel()*0.0014*(dim/256))+1 - c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=True) + c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=not debug) #print(c/math.sqrt(dim)) print('') print(dim, sum(errs)/len(errs)/math.sqrt(dim)) @@ -2422,29 +2421,73 @@ def test_cutlass3_gemm(dtype): #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) def test_gemm_4bit(dtype): - for i in range(1): - #A = torch.rand(2, 4092, dtype=dtype, device='cuda') - #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') - #torch.random.manual_seed(17) - A = torch.rand(1, 4096, dtype=dtype, device='cuda') - B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') + #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: + #for dim in [4096, 5120, 6656, 8192]: + #for dim in [32]: + for dim in [4096]: + errs = [] + relerrs = [] + max_err = 0 + max_relerr = 0 + for i in range(1): + #A = torch.rand(2, 4092, dtype=dtype, device='cuda') + #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') + #A = torch.rand(1, 4096, dtype=dtype, device='cuda') + #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') + A = torch.randn(1, dim+0, dtype=dtype, device='cuda') + B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim) + + #print('') + #print(A) + #print(B.t()) + #A[:, :-1] = 0 + #B[:, :-1] = 0 - #print('') - #print(A) - #print(B) + qB, state = F.quantize_nf4(B) + F.dequantize_nf4(qB, state) - qB, state = F.quantize_nf4(B) - F.dequantize_nf4(qB, state) + C3 = torch.matmul(A, B.t()) + C2 = F.cutlass3_gemm(A, qB.t(), state=state) + C1 = bnb.matmul_4bit(A, qB.t(), state) + C2 = F.cutlass3_gemm(A, qB.t(), state=state) + print(C1.shape, C2.shape) - C1 = torch.matmul(A, B.t()) - #C1 = bnb.matmul_4bit(A, qB.t(), state) - C2 = 
F.cutlass3_gemm(A, qB.t(), state=state) - #print(C1) - #print(C2) + # tensor cores are non-deterministic + # so we need to analyze errors around the mean + # to test our implementation + err = torch.abs(C1-C2) + mag = torch.abs(C1)+1e-8 + relerr = err/mag + max_err = max(err.max(), max_err) + max_relerr = max(relerr.max(), max_relerr) + err = err.mean().item() + relerr = relerr.mean().item() + + errs.append(err) + relerrs.append(relerr) + + if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: + print('') + print(i, err, relerr) + print(A.flatten()[-6:]) + print(B.flatten()[-6:]) + out = A.flatten()[-6:]*B.flatten()[-6:] + print(out) + print(out[:-1].sum()) + print('='*80) + print(C1.flatten()[-6:]) + print(C2.flatten()[-6:]) + #assert False, 'ERROR' - #torch.testing.assert_close(C1, C2, atol=1e-5, rtol=0.005) + c = int(C1.numel()*0.0014*(dim/256))+1 + c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False) + #print(c/math.sqrt(dim)) + print('') + print(dim, sum(errs)/len(errs)/math.sqrt(dim)) + print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim)) + print(dim, (max_err.item(), max_relerr.item())) def test_pipeline_func(): a = torch.rand(2, 4).cuda() From ec38ba95b0cd6bf3dadfccf366cd8917acf59c4b Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sat, 6 May 2023 11:14:06 -0700 Subject: [PATCH 55/63] Added paging. --- bitsandbytes/cextension.py | 2 + bitsandbytes/functional.py | 55 +++++++++++++++++++++++++++ csrc/kernels.cu | 76 ++++++++++---------------------------- csrc/kernels.cuh | 18 +-------- csrc/ops.cu | 25 ++++++++----- csrc/ops.cuh | 9 ++++- csrc/pythonInterface.c | 32 +++++++++++++++- tests/test_functional.py | 40 +++++++++++++++++--- 8 files changed, 167 insertions(+), 90 deletions(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 8adca9312..17c2a464e 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -26,6 +26,8 @@ lib.cadam_8bit_blockwise_fp32 lib.get_context.restype = ct.c_void_p lib.get_cusparse.restype = ct.c_void_p + lib.cget_managed_ptr.restype = ct.c_void_p + lib.cget_stream.restype = ct.c_void_p COMPILED_WITH_CUDA = True except AttributeError: warn("The installed version of bitsandbytes was compiled without GPU support. " diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index e5b1bf7f1..f54847545 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -130,6 +130,61 @@ def get_instance(cls): cls._instance.initialize() return cls._instance +dtype2bytes = {} +dtype2bytes[torch.float32] = 4 +dtype2bytes[torch.float16] = 2 +dtype2bytes[torch.bfloat16] = 2 +dtype2bytes[torch.uint8] = 1 +dtype2bytes[torch.int8] = 1 + +def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0)): + num_bytes = dtype2bytes[dtype]*prod(shape) + cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes)) + c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int)) + new_array = np.ctypeslib.as_array(c_ptr, shape=shape) + out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape)) + out.is_paged = True + out.page_deviceid = device.index + return out + +def prefetch_tensor(A, to_cpu=False): + assert A.is_paged, 'Only paged tensors can be prefetched!' 
+ if to_cpu: + deviceid = -1 + else: + deviceid = A.page_deviceid + + num_bytes = dtype2bytes[A.dtype]*A.numel() + lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid)) + +def elementwise_func(func_name, A, B, value, prefetch=True): + func = None + if A.dtype == torch.float32: + func = getattr(lib, f'c{func_name}_fp32', None) + cvalue = ct.c_float(value) + elif A.dtype == torch.uint8: + func = getattr(lib, f'c{func_name}_uint8', None) + cvalue = ct.c_uint8(value) + + if func is None: raise NotImplementedError(f'Function not implemented: {func_name}') + + is_managed = getattr(A, 'is_managed', False) + if is_managed and prefetch: + prefetch_tensor(A) + if B is not None: prefetch_tensor(B) + + func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel())) + if A.is_paged or B.is_paged: + # paged function are fully asynchronous + # if we return from this function, we want to the tensor + # to be in the correct state, that is the final state after the + # operation occured. So we synchronize. + torch.cuda.synchronize() + +def fill(A, value, device=None, prefetch=True): elementwise_func('fill', A, None, value) +def arange(A, device=None): elementwise_func('arange', A, None, 0) +def _mul(A, B, device=None): elementwise_func('_mul', A, B, 0) + def create_linear_map(signed=True, total_bits=8, add_zero=True): sign = (-1.0 if signed else 0.0) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 2373b911a..e1a315522 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3522,49 +3522,23 @@ template __global__ void kgemm_4bit_inference(int M, i //} -__device__ void compute(float* global_out, float const* shared_in) +template __global__ void kfunc(T *A, T *B, T value, long n) { - -} -template -__global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz) { - auto grid = cooperative_groups::this_grid(); - auto block = cooperative_groups::this_thread_block(); - assert(size == batch_sz * grid.size()); // Assume input size fits batch_sz * grid_size - - extern __shared__ float shared[]; // stages_count * block.size() * sizeof(int) bytes - size_t shared_offset[stages_count]; - for (int s = 0; s < stages_count; ++s) shared_offset[s] = s * block.size(); - - __shared__ cuda::pipeline_shared_state< - cuda::thread_scope::thread_scope_block, - stages_count - > shared_state; - auto pipeline = cuda::make_pipeline(block, &shared_state); - - auto block_batch = [&](size_t batch) -> int { - return block.group_index().x * block.size() + grid.size() * batch; - }; - - // compute_batch: next batch to process - // fetch_batch: next batch to fetch from global memory - for (size_t compute_batch = 0, fetch_batch = 0; compute_batch < batch_sz; ++compute_batch) { - // The outer loop iterates over the computation of the batches - for (; fetch_batch < batch_sz && fetch_batch < (compute_batch + stages_count); ++fetch_batch) { - // This inner loop iterates over the memory transfers, making sure that the pipeline is always full - pipeline.producer_acquire(); - size_t shared_idx = fetch_batch % stages_count; - size_t batch_idx = fetch_batch; - size_t block_batch_idx = block_batch(batch_idx); - cuda::memcpy_async(block, shared + shared_offset[shared_idx], global_in + block_batch_idx, sizeof(float) * block.size(), pipeline); - pipeline.producer_commit(); - } - pipeline.consumer_wait(); - int shared_idx = compute_batch % stages_count; - int batch_idx = compute_batch; - compute(global_out + block_batch(batch_idx), shared + shared_offset[shared_idx]); - 
pipeline.consumer_release(); + for(long i = (blockDim.x*blockIdx.x) + threadIdx.x; i < n; i+=(blockDim.x*gridDim.x)) + { + switch(FUNC) + { + case FILL: + A[i] = (T)value; + break; + case ARANGE: + A[i] = (T)i; + break; + case _MUL: + A[i] = A[i]*B[i]; + break; } + } } @@ -3572,19 +3546,10 @@ __global__ void with_staging_unified(float const* global_in, float * global_out, // TEMPLATE DEFINITIONS //============================================================== -//template -//__global__ static -//__launch_bounds__(decltype(size(CThreadLayout{}))::value) -//void -//gemm_device(MShape M, NShape N, KShape K, -// TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA, -// TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB, -// TC * out, CStride dC, CBlockLayout , CThreadLayout tC, -// half alpha, half beta); +template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(unsigned char *A, unsigned char *B, unsigned char value, long n); +template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(float *A, float *B, float value, long n); // these are not used and make no sense, but the compiler needs them //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); @@ -3611,9 +3576,6 @@ template __global__ void gemm_device(int M, int N, int K, half * _ template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); - -//template __global__ void kMatmul_inference_4bit(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB); -template __global__ void with_staging_unified<2>(float const* global_in, float * global_out, size_t size, size_t batch_sz); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 4951031a2..29c6683d7 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -122,23 +122,9 @@ template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); -//template -//__global__ static -//__launch_bounds__(decltype(size(CThreadLayout{}))::value) -//void -//gemm_device(MShape M, NShape N, KShape K, -// TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA, -// TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB, -// TC * out, CStride dC, CBlockLayout , CThreadLayout tC, -// Alpha alpha, Beta beta); -template -__global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz); - template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kfunc(T *A, T *B, T value, long n); + #endif diff --git 
a/csrc/ops.cu b/csrc/ops.cu index 4d68436c4..7d13b7142 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -663,16 +663,6 @@ template void extractOutliers(char * A, int *idx, char *out, int id } -void pipeline_test(float *A, float *B, size_t n, size_t batch_size) -{ - - int threads = 256; - int num_blocks = (n+(256*batch_size)+1)/(batch_size*256); - - with_staging_unified<2><<>>(A, B, n, batch_size); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); -} - template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) @@ -717,10 +707,25 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } +template void func(T *A, T *B, T value, long n) +{ + int threads = 512; + int blocks = n/threads; + blocks = n % threads == 0 ? blocks : blocks + 1; + blocks = blocks > 65535 ? 65535 : blocks; + kfunc<<>>(A, B, value, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + //============================================================== // TEMPLATE DEFINITIONS //============================================================== +template void func(float *A, float *B, float value, long n); +template void func(unsigned char *A, unsigned char *B, unsigned char value, long n); +template void func(float *A, float *B, float value, long n); +template void func(float *A, float *B, float value, long n); + template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); //template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 8919c6016..e9d2e229c 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -93,6 +93,13 @@ typedef enum DataType_t NF4 = 2, } DataType_t; +typedef enum Funcs_t +{ + FILL = 0, + ARANGE = 1, + _MUL = 2, +} Funcs_t; + class Context { public: @@ -193,6 +200,6 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template void func(T *A, T *B, T value, long n); -void pipeline_test(float *A, float *B, size_t n, size_t batch_size); #endif diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 26f16f218..7271430a7 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -28,6 +28,14 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } +#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ +void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func(A, B, value, n); } \ + +MAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL) +MAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL) +MAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE) +MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL) + #define MAKE_FUNC32(fname, oname, gtype, gbits) \ void fname##32bit_g##gbits(gtype *g, gtype *p, \ @@ -314,7 +322,6 @@ 
extern "C" void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); } void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); } - void cpipeline_test(float *A, float *B, size_t n, size_t batch_size){ pipeline_test(A, B, n, batch_size); } //void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) //{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); } @@ -325,6 +332,29 @@ extern "C" void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + void *cget_managed_ptr(size_t bytes) + { + void *ptr; + CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost)); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + + return ptr; + } + + void cprefetch(void *ptr, size_t bytes, int device) + { + CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0)); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + } + + #define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ + void c##fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ fname##_##type_name(A, B, value, n); } \ + + CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL) + CMAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL) + CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE) + CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL) + #endif void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); } diff --git a/tests/test_functional.py b/tests/test_functional.py index dc4e40d94..145c26786 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2489,8 +2489,38 @@ def test_gemm_4bit(dtype): print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim)) print(dim, (max_err.item(), max_relerr.item())) -def test_pipeline_func(): - a = torch.rand(2, 4).cuda() - out = F.pipeline_test(a, 2) - print(a) - print(out) +def test_managed(): + n = 32*10 + A = F.get_paged(n, n, dtype=torch.float32) + B = F.get_paged(n, n, dtype=torch.uint8) + B2 = F.get_paged(n, n, dtype=torch.float32) + assert A.is_paged + assert B.is_paged + assert A.page_deviceid==0 + assert B.page_deviceid==0 + F.fill(A, 17.0) + F.fill(B, 17) + F.fill(B2, 2) + assert (A==17).sum().item() == n*n + assert (B==17).sum().item() == n*n + C = A*B.float() + assert (C==289).sum().item() == n*n + F._mul(A, B2) + F._mul(A, B2) + F._mul(A, B2) + assert (A==17*(2**3)).sum().item() == n*n + # F.prefetch_tensor(A) + # F.prefetch_tensor(B) + + + # F.fill(B2, 17.0) + # F._mul(A, B2) + + # F.prefetch_tensor(A, to_cpu=True) + # F.prefetch_tensor(B, to_cpu=True) + # F.prefetch_tensor(B2, to_cpu=True) + # torch.cuda.synchronize() + + # assert (A==17).sum().item() == n*n + + # torch.testing.assert_allclose(A, torch.ones(A.shape)*289) From 44d68ff29cc19e54db13242e7f8cff3c7e4c5196 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sat, 6 May 2023 14:59:29 -0700 Subject: [PATCH 56/63] Added paged optimizers. 
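
The paged optimizers added here mirror the existing Adam/AdamW classes but allocate their state buffers with get_paged() (CUDA unified memory), so large optimizer states can be evicted to CPU RAM under GPU memory pressure. A minimal usage sketch; the model, sizes and learning rate are placeholders and not part of this patch:

    import torch
    import bitsandbytes as bnb

    model = torch.nn.Linear(4096, 4096).cuda().half()
    # state1/state2 for large parameters live in paged (unified) memory
    optim = bnb.optim.PagedAdamW8bit(model.parameters(), lr=1e-3)

    out = model(torch.randn(16, 4096, device='cuda', dtype=torch.float16))
    out.float().mean().backward()
    optim.step()        # paged updates are asynchronous; step() synchronizes before returning
    optim.zero_grad()
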
--- bitsandbytes/cextension.py | 1 - bitsandbytes/functional.py | 33 ++++++++-- bitsandbytes/optim/__init__.py | 4 +- bitsandbytes/optim/adam.py | 106 +++++++------------------------ bitsandbytes/optim/adamw.py | 108 ++++++++------------------------ bitsandbytes/optim/optimizer.py | 72 ++++++++++----------- tests/test_functional.py | 14 ++--- tests/test_optim.py | 87 +++++++++++-------------- 8 files changed, 158 insertions(+), 267 deletions(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 17c2a464e..29621c9b3 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -27,7 +27,6 @@ lib.get_context.restype = ct.c_void_p lib.get_cusparse.restype = ct.c_void_p lib.cget_managed_ptr.restype = ct.c_void_p - lib.cget_stream.restype = ct.c_void_p COMPILED_WITH_CUDA = True except AttributeError: warn("The installed version of bitsandbytes was compiled without GPU support. " diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index f54847545..a6ed67513 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -83,6 +83,27 @@ def prod(iterable): lib.cadagrad_8bit_blockwise_fp16, ) +class GlobalPageManager: + _instance = None + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + def initialize(self): + self.paged_tensors = [] + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance + + def prefetch_all(self, to_cpu=False): + for t in self.paged_tensors: + prefetch_tensor(t, to_cpu) + + class CUBLAS_Context: _instance = None @@ -142,7 +163,7 @@ def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0)) cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes)) c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int)) new_array = np.ctypeslib.as_array(c_ptr, shape=shape) - out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape)) + out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape)).view(shape) out.is_paged = True out.page_deviceid = device.index return out @@ -415,10 +436,14 @@ def is_on_gpu(tensors): gpu_ids = set() for t in tensors: if t is None: continue # NULL pointers are fine - on_gpu &= t.device.type == 'cuda' - gpu_ids.add(t.device.index) + is_paged = getattr(t, 'is_paged', False) + on_gpu &= (t.device.type == 'cuda' or is_paged) + if not is_paged: + gpu_ids.add(t.device.index) + if not on_gpu: + raise TypeError(f'All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}') if len(gpu_ids) > 1: - raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:{[(t.shape, t.device) for t in tensors]}') + raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}') return on_gpu def get_ptr(A: Tensor) -> ct.c_void_p: diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py index 8c8a8f41e..994dae580 100644 --- a/bitsandbytes/optim/__init__.py +++ b/bitsandbytes/optim/__init__.py @@ -6,8 +6,8 @@ from bitsandbytes.cextension import COMPILED_WITH_CUDA from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit -from .adam import Adam, Adam8bit, Adam32bit -from .adamw import AdamW, AdamW8bit, AdamW32bit +from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit +from .adamw import AdamW, AdamW8bit, 
AdamW32bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit from .lamb import LAMB, LAMB8bit, LAMB32bit from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS from .optimizer import GlobalOptimManager diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py index 396aeb85f..86981eb86 100644 --- a/bitsandbytes/optim/adam.py +++ b/bitsandbytes/optim/adam.py @@ -14,92 +14,34 @@ class Adam(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - optim_bits, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) - + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Adam8bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 8, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) - + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Adam32bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 32, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) - + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + +class PagedAdam(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedAdam8bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedAdam32bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + 
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) class AnalysisAdam(torch.optim.Optimizer): """Adam that performs 8-bit vs 32-bit error analysis. diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py index 022e64c46..21077f1a0 100644 --- a/bitsandbytes/optim/adamw.py +++ b/bitsandbytes/optim/adamw.py @@ -5,89 +5,35 @@ from bitsandbytes.optim.optimizer import Optimizer2State -class AdamW(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - optim_bits, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) +class AdamW(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) class AdamW8bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 8, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) - + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) class AdamW32bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 32, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + + +class PagedAdamW(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedAdamW8bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, 
min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedAdamW32bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index 867ad3dd8..4f8dcc7bc 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -92,10 +92,12 @@ def register_module_override(self, module, param_name, config): class Optimizer8bit(torch.optim.Optimizer): - def __init__(self, params, defaults, optim_bits=32): + def __init__(self, params, defaults, optim_bits=32, is_paged=False): super().__init__(params, defaults) self.initialized = False self.name2qmap = {} + self.is_paged = is_paged + self.page_mng = F.GlobalPageManager.get_instance() self.mng = GlobalOptimManager.get_instance() self.non_castable_tensor_keys = { @@ -207,7 +209,9 @@ def to_gpu(self): values = self.state[p] for k, v in values.items(): if isinstance(v, torch.Tensor): - self.state[p][k] = v.to(p.device) + is_paged = getattr(v, 'is_paged', False) + if not is_paged: + self.state[p][k] = v.to(p.device) def check_overrides(self): for module, attr, config in self.mng.module_weight_config_triple: @@ -252,6 +256,7 @@ def step(self, closure=None): self.to_gpu() # needed for fairseq pure fp16 training self.initialized = True + if self.is_paged: self.page_mng.prefetch_all() for gindex, group in enumerate(self.param_groups): for pindex, p in enumerate(group["params"]): if p.grad is None: @@ -261,6 +266,11 @@ def step(self, closure=None): self.init_state(group, p, gindex, pindex) self.update_step(group, p, gindex, pindex) + if self.is_paged: + # all paged operation are asynchronous, we need + # to sync to make sure all tensors are in the right state + torch.cuda.synchronize() + return loss @@ -289,6 +299,16 @@ def update_step(self, group, p, gindex, pindex): "The update_step method needs to be overridden" ) + def get_state_buffer(self, p, dtype=torch.float32): + if not self.is_paged or p.numel() < 1e5: + return torch.zeros_like(p, dtype=dtype, device=p.device) + else: + # > 1 MB + buff = F.get_paged(*p.shape, dtype=dtype, device=p.device) + F.fill(buff, 0) + self.page_mng.paged_tensors.append(buff) + return buff + class Optimizer2State(Optimizer8bit): def __init__( @@ -306,6 +326,7 @@ def __init__( block_wise=True, max_unorm=0.0, skip_zeros=False, + is_paged=False ): if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") @@ -325,7 +346,7 @@ def __init__( f"Invalid weight_decay value: {weight_decay}" ) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - super().__init__(params, defaults, optim_bits) + super().__init__(params, defaults, optim_bits, is_paged) if args is None: args = {} @@ -365,18 +386,8 @@ def init_state(self, group, p, gindex, pindex): if dtype == torch.float32 or ( dtype == torch.uint8 and p.numel() < 4096 ): - state["state1"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.float32, - device=p.device, - ) - state["state2"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.float32, - device=p.device, - ) + state["state1"] = self.get_state_buffer(p, dtype=torch.float32) + state["state2"] = self.get_state_buffer(p, dtype=torch.float32) elif 
dtype == torch.uint8: if state["step"] == 0: if "dynamic" not in self.name2qmap: @@ -388,20 +399,10 @@ def init_state(self, group, p, gindex, pindex): p.device ) - state["state1"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.uint8, - device=p.device, - ) + state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap1"] = self.name2qmap["dynamic"] - state["state2"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.uint8, - device=p.device, - ) + state["state2"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap2"] = self.name2qmap["udynamic"] if config["block_wise"]: @@ -538,6 +539,7 @@ def __init__( block_wise=True, max_unorm=0.0, skip_zeros=False, + is_paged=False ): if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") @@ -553,7 +555,7 @@ def __init__( f"Invalid weight_decay value: {weight_decay}" ) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - super().__init__(params, defaults, optim_bits) + super().__init__(params, defaults, optim_bits, is_paged) if args is None: args = {} @@ -593,12 +595,7 @@ def init_state(self, group, p, gindex, pindex): if dtype == torch.float32 or ( dtype == torch.uint8 and p.numel() < 4096 ): - state["state1"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.float32, - device=p.device, - ) + state["state1"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: if state["step"] == 0: if "dynamic" not in self.name2qmap: @@ -607,12 +604,7 @@ def init_state(self, group, p, gindex, pindex): p.device ) - state["state1"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.uint8, - device=p.device, - ) + state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap1"] = self.name2qmap["dynamic"] if config["block_wise"]: diff --git a/tests/test_functional.py b/tests/test_functional.py index 145c26786..6bda1a8a5 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -172,8 +172,8 @@ def test_dynamic_blockwise_quantization(nested, blocksize): relerr = sum(reldiffs)/len(reldiffs) assert abserr < 0.011 assert relerr < 0.018 - print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs)) - print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs)) + #print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs)) + #print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs)) diffs = [] for i in range(100): @@ -189,8 +189,8 @@ def test_dynamic_blockwise_quantization(nested, blocksize): relerr = sum(reldiffs)/len(reldiffs) assert abserr < 0.0035 assert relerr < 0.015 - print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) - print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) + #print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) + #print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) def test_dynamic_blockwise_stochastic_quantization(): @@ -320,7 +320,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): dim2 = dim2 - (dim2 % 32) errors = [] relerrors = [] - print("") + #print("") for i in range(5): if batched: A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda") @@ -349,8 +349,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): relerr = err / torch.abs(out2) errors.append(err.mean().item()) relerrors.append(relerr.mean().item()) - print(mean(errors)) - print(mean(relerrors)) + #print(mean(errors)) + 
#print(mean(relerrors)) def test_stable_embedding(): diff --git a/tests/test_optim.py b/tests/test_optim.py index a13b33207..a5ecb6e1d 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -39,6 +39,8 @@ def rm_path(path): bnb.optim.Adam, ) str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam) +str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW) +str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam) # str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) str2optimizers["momentum"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), @@ -48,10 +50,7 @@ def rm_path(path): lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False), ) -str2optimizers["adam8bit"] = ( - torch.optim.Adam, - lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False), -) +str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False)) str2optimizers["momentum8bit"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False), @@ -61,10 +60,9 @@ def rm_path(path): lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False), ) -str2optimizers["adam8bit_blockwise"] = ( - torch.optim.Adam, - lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True), -) +str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) +str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True)) +str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True)) str2optimizers["momentum8bit_blockwise"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True), @@ -76,36 +74,25 @@ def rm_path(path): str2statenames = {} str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] +str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] +str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["momentum"] = [("momentum_buffer", "state1")] str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["rmsprop"] = [("square_avg", "state1")] -str2statenames["adam8bit"] = [ - ("exp_avg", "state1", "qmap1", "max1"), - ("exp_avg_sq", "state2", "qmap2", "max2"), -] -str2statenames["lamb8bit"] = [ - ("exp_avg", "state1", "qmap1", "max1"), - ("exp_avg_sq", "state2", "qmap2", "max2"), -] -str2statenames["adam8bit_blockwise"] = [ - ("exp_avg", "state1", "qmap1", "absmax1"), - ("exp_avg_sq", "state2", "qmap2", "absmax2"), -] -str2statenames["momentum8bit"] = [ - ("momentum_buffer", "state1", "qmap1", "max1") -] -str2statenames["momentum8bit_blockwise"] = [ - ("momentum_buffer", "state1", "qmap1", "absmax1") -] +str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] +str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] +str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] +str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] +str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", 
"absmax2")] +str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] +str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")] str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")] -str2statenames["rmsprop8bit_blockwise"] = [ - ("square_avg", "state1", "qmap1", "absmax1") -] +str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")] dim1 = [1024] dim2 = [32, 1024, 4097, 1] -gtype = [torch.float32, torch.float16, torch.bfloat16] -optimizer_names = ["adam", "momentum", "rmsprop"] +gtype = [torch.float32, torch.float16] +optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam'] values = list(product(dim1, dim2, gtype, optimizer_names)) names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @@ -135,14 +122,14 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): torch_optimizer.step() for name1, name2 in str2statenames[optim_name]: - torch.testing.assert_allclose( + torch.testing.assert_close( torch_optimizer.state[p1][name1], - bnb_optimizer.state[p2][name2], + bnb_optimizer.state[p2][name2].cuda(), atol=atol, rtol=rtol, ) - torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol) + torch.testing.assert_close(p1, p2.float(), atol=atol, rtol=rtol) if i % (k // 5) == 0 and i > 0: path = get_temp_dir() @@ -152,9 +139,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): bnb_optimizer = str2optimizers[optim_name][1]([p2]) bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) rm_path(path) - torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol) + torch.testing.assert_close(p1, p2.float(), atol=atol, rtol=rtol) for name1, name2 in str2statenames[optim_name]: - torch.testing.assert_allclose( + torch.testing.assert_close( torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, @@ -168,7 +155,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): # --> copy the state to keep weights close p1.data = p1.data.to(p2.dtype).float() p2.copy_(p1.data) - torch.testing.assert_allclose(p1.to(p2.dtype), p2) + torch.testing.assert_close(p1.to(p2.dtype), p2) if optim_name in ["lars", "lamb"]: assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0 @@ -277,7 +264,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): bnb_optimizer.step() torch_optimizer.step() - torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol) + torch.testing.assert_close(p1, p2.float(), atol=patol, rtol=prtol) dequant_states = [] for name1, name2, qmap, max_val in str2statenames[optim_name]: @@ -331,8 +318,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): bnb_optimizer = str2optimizers[optim_name][1]([p2]) bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) rm_path(path) - torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2]) - torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap]) + torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2]) + torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap]) if "blockwise" in optim_name: s1 = F.dequantize_blockwise( @@ -347,17 +334,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], ) - torch.testing.assert_allclose(s1cpy, s1) + torch.testing.assert_close(s1cpy, s1) num_not_close = 
(torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0) assert num_not_close.sum().item() < 20 - torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol) + torch.testing.assert_close(p1, p2.float(), atol=patol, rtol=prtol) # the parameters diverge quickly. Here we keep them close # together so we can test against the Adam error p1.data = p1.data.to(gtype).float() p2.copy_(p1.data) - torch.testing.assert_allclose(p1.to(gtype), p2) + torch.testing.assert_close(p1.to(gtype), p2) for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): torch_optimizer.state[p1][name1].copy_(s.data) @@ -419,28 +406,28 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state if optim_bits == 32: - torch.testing.assert_allclose(p1, p2) - torch.testing.assert_allclose( + torch.testing.assert_close(p1, p2) + torch.testing.assert_close( adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=5e-5, rtol=1e-4, ) - torch.testing.assert_allclose( + torch.testing.assert_close( adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=5e-5, rtol=1e-4, ) elif optim_bits == 8: - torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3) - torch.testing.assert_allclose( + torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-3) + torch.testing.assert_close( adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=2, rtol=1e-3, ) - torch.testing.assert_allclose( + torch.testing.assert_close( adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=2, @@ -472,7 +459,7 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): # optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch'] # optimizer_names = ['lamb_apex', 'lamb8bit'] # optimizer_names = ['lars_apex', 'lars8bit'] -optimizer_names = ["adam8bit_blockwise"] +optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise'] values = list(product(dim1, dim2, gtype, optimizer_names)) names = [ "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values From 41a9c708148c4a16675244de88352d0437e2d87a Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sat, 6 May 2023 18:59:59 -0700 Subject: [PATCH 57/63] Changed prefetching. 
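
The reversed iteration in prefetch_all() below assumes that tensors registered first are also needed first during the step, so they are prefetched last and are therefore the least likely to have been evicted again by the time they are used; in addition, step() now prefetches each parameter's state right before its update. For reference, the paging primitives from the earlier paging commit can be exercised directly; a small sketch with placeholder shapes:

    import torch
    import bitsandbytes.functional as F

    A = F.get_paged(1024, 1024, dtype=torch.float32)   # cudaMallocManaged-backed tensor
    F.fill(A, 0.0)
    F.prefetch_tensor(A)                # async prefetch onto the GPU (cudaMemPrefetchAsync)
    # ... run GPU work that touches A ...
    F.prefetch_tensor(A, to_cpu=True)   # page it back out to CPU memory
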
--- bitsandbytes/functional.py | 5 ++++- bitsandbytes/optim/optimizer.py | 11 ++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index a6ed67513..2542e4bb5 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -100,7 +100,10 @@ def get_instance(cls): return cls._instance def prefetch_all(self, to_cpu=False): - for t in self.paged_tensors: + # assume the first added, will be hte + # ones that are used first, so swap them in last + # in the case they are evicted again + for t in self.paged_tensors[::-1]: prefetch_tensor(t, to_cpu) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index 4f8dcc7bc..921ec0ac2 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -256,7 +256,7 @@ def step(self, closure=None): self.to_gpu() # needed for fairseq pure fp16 training self.initialized = True - if self.is_paged: self.page_mng.prefetch_all() + #if self.is_paged: self.page_mng.prefetch_all() for gindex, group in enumerate(self.param_groups): for pindex, p in enumerate(group["params"]): if p.grad is None: @@ -265,7 +265,9 @@ def step(self, closure=None): if len(state) == 0: self.init_state(group, p, gindex, pindex) + self.prefetch_state(p) self.update_step(group, p, gindex, pindex) + torch.cuda.synchronize() if self.is_paged: # all paged operation are asynchronous, we need # to sync to make sure all tensors are in the right state @@ -309,6 +311,13 @@ def get_state_buffer(self, p, dtype=torch.float32): self.page_mng.paged_tensors.append(buff) return buff + def prefetch_state(self, p): + if self.is_paged: + state = self.state[p] + F.prefetch_tensor(state['state1']) + if 'state2' in state: + F.prefetch_tensor(state['state2']) + class Optimizer2State(Optimizer8bit): def __init__( From f64cfe65aad56751cabf87c2a9a610e8c43bb981 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sat, 6 May 2023 21:49:16 -0700 Subject: [PATCH 58/63] Fixed prefetch bug for non-paged tensors; added benchmark. 
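
Background for the fix: get_state_buffer() only returns paged buffers for parameters above roughly 1e5 elements, so even with is_paged=True the optimizer can hold ordinary CUDA tensors as state for small parameters (e.g. biases), and prefetch_tensor() only accepts paged tensors. A short sketch of the distinction, with hypothetical shapes:

    import torch
    import bitsandbytes.functional as F

    big   = F.get_paged(4096, 4096)              # paged (unified-memory) tensor
    small = torch.zeros(4096, device='cuda')     # ordinary CUDA tensor, no is_paged flag

    F.prefetch_tensor(big)                       # fine
    # F.prefetch_tensor(small)                   # would fail: only paged tensors can be prefetched

The new benchmark below (test_stream_optimizer_bench) times optimizer steps while a large dummy allocation creates GPU memory pressure, so the paged states actually get evicted and paged back in.
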
--- bitsandbytes/optim/optimizer.py | 9 ++++--- tests/test_optim.py | 44 +++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index 921ec0ac2..41c8d278b 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -314,9 +314,12 @@ def get_state_buffer(self, p, dtype=torch.float32): def prefetch_state(self, p): if self.is_paged: state = self.state[p] - F.prefetch_tensor(state['state1']) - if 'state2' in state: - F.prefetch_tensor(state['state2']) + s1 = state['state1'] + is_paged = getattr(s1, 'is_paged', False) + if is_paged: + F.prefetch_tensor(state['state1']) + if 'state2' in state: + F.prefetch_tensor(state['state2']) class Optimizer2State(Optimizer8bit): diff --git a/tests/test_optim.py b/tests/test_optim.py index a5ecb6e1d..e35408e14 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -490,3 +490,47 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): params = (k - k // 5) * dim1 * dim2 print(optim_name, gtype, s / params) # assert s < 3.9 + +dim1 = [10*1024] +gtype = [torch.float16] +#mode = ['torch', 'bnb'] +mode = ['bnb'] +optimizer_names = ['paged_adamw'] +#optimizer_names = ['paged_adamw8bit_blockwise'] +values = list(product(dim1,gtype, optimizer_names, mode)) +names = ['dim1_{0}_gtype_{1}_optim_{2}_mode_{3}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, gtype, optim_name, mode", values, ids=names) +def test_stream_optimizer_bench(dim1, gtype, optim_name, mode): + layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)])) + layers1 = layers1.to(gtype) + layers1 = layers1.cuda() + + large_tensor = None + if mode == 'torch': + optim = str2optimizers[optim_name][0](layers1.parameters()) + else: + optim = str2optimizers[optim_name][1](layers1.parameters()) + # 12 GB + large_tensor = torch.empty((int(4.5e9),), device='cuda') + + torch.cuda.synchronize() + time.sleep(5) + + num_batches = 5 + batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype) + lbls = torch.randint(0, 10, size=(num_batches,128)).cuda() + + for i in range(num_batches): + print(i) + b = batches[i] + if i ==2: + torch.cuda.synchronize() + t0 = time.time() + + out1 = layers1(b) + + loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean() + loss1.backward() + optim.step() + torch.cuda.synchronize() + print(mode, time.time() - t0) From 4bd11518293ea30c6792a5baf64f0715739a09ca Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 7 May 2023 15:06:17 -0700 Subject: [PATCH 59/63] Fixed gradient accumulation test. 
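
The fixed test copies both the weights and the biases between the Linear8bitLt stack and the fp16 reference stack, uses 32-bit Adam on both sides, and compares the accumulated gradients with small tolerances. The gradient-accumulation pattern being exercised is, schematically (shapes and step counts are placeholders):

    import torch
    import bitsandbytes as bnb

    acc_steps = 10
    model = bnb.nn.Linear8bitLt(32, 32).cuda().half()
    opt = bnb.optim.Adam32bit(model.parameters(), lr=0.001)

    for i in range(50):
        x = torch.randn(16, 8, 32, device='cuda', dtype=torch.float16)
        model(x).mean().backward()           # gradients accumulate across micro-batches
        if i > 0 and i % acc_steps == 0:     # weights are only updated every acc_steps steps
            opt.step()
            opt.zero_grad(True)
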
--- bitsandbytes/autograd/_functions.py | 1 - tests/test_modules.py | 20 +++++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index acd90f54b..63b7156b4 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -456,7 +456,6 @@ def backward(ctx, grad_output): Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) if req_gradB: - #grad_B = torch.matmul(grad_output.t(), A) CxAt, SAt = F.transform(CAt, formatB, transpose=True) C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) diff --git a/tests/test_modules.py b/tests/test_modules.py index 1319cf7f9..d0a905197 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -332,12 +332,13 @@ def test_linear8bitlt_inference(threshold): def test_linear8bitlt_accumulated_gradient(): l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]) l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]) - l2[0].weight = torch.nn.Parameter(l1[0].weight.clone()) - l2[0].bias = torch.nn.Parameter(l1[0].bias.clone()) - l2[1].weight = torch.nn.Parameter(l1[1].weight.clone()) - l2[1].bias = torch.nn.Parameter(l1[1].bias.clone()) - opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001) - opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001) + l1[0].weight.data.copy_(l2[0].weight.data) + l1[1].weight.data.copy_(l2[1].weight.data) + l1[0].bias.data.copy_(l2[0].bias.data) + l1[1].bias.data.copy_(l2[1].bias.data) + + opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001) + opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001) acc_steps = 10 @@ -353,7 +354,6 @@ def test_linear8bitlt_accumulated_gradient(): assert l1[0].state.CxB is not None assert l1[1].state.CxB is not None - print(i) if i > 0 and i % acc_steps == 0: opt1.step() opt1.zero_grad(True) @@ -368,9 +368,11 @@ def test_linear8bitlt_accumulated_gradient(): # we do this copy because otherwise we have small divergences over time that add up l1[0].weight.data.copy_(l2[0].weight.data) l1[1].weight.data.copy_(l2[1].weight.data) + l1[0].bias.data.copy_(l2[0].bias.data) + l1[1].bias.data.copy_(l2[1].bias.data) else: - torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad) - torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad) + torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3) @pytest.mark.parametrize("threshold", [0.0, 2.0]) From 2bce175d156b5c5c1be925cb57fe33215675fafd Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 23 May 2023 18:42:19 -0700 Subject: [PATCH 60/63] Fixed Makefile. 
--- Makefile | 24 ++----------- bitsandbytes/functional.py | 69 -------------------------------------- tests/test_functional.py | 47 ++++++++++++++------------ 3 files changed, 27 insertions(+), 113 deletions(-) diff --git a/Makefile b/Makefile index ea6ee87d5..c113a3d5c 100644 --- a/Makefile +++ b/Makefile @@ -40,11 +40,6 @@ CC_KEPLER := -gencode arch=compute_35,code=sm_35 # Kepler CC_KEPLER += -gencode arch=compute_37,code=sm_37 # Kepler # Later versions of CUDA support the new architectures -CC_CUDA10x += -gencode arch=compute_75,code=sm_75 - -CC_CUDA110 := -gencode arch=compute_75,code=sm_75 -CC_CUDA110 += -gencode arch=compute_80,code=sm_80 - CC_CUDA11x := -gencode arch=compute_75,code=sm_75 CC_CUDA11x += -gencode arch=compute_80,code=sm_80 CC_CUDA11x += -gencode arch=compute_86,code=sm_86 @@ -54,8 +49,8 @@ CC_cublasLt110 := -gencode arch=compute_75,code=sm_75 CC_cublasLt110 += -gencode arch=compute_80,code=sm_80 CC_cublasLt111 := -gencode arch=compute_75,code=sm_75 -#CC_cublasLt111 += -gencode arch=compute_80,code=sm_80 -#CC_cublasLt111 += -gencode arch=compute_86,code=sm_86 +CC_cublasLt111 += -gencode arch=compute_80,code=sm_80 +CC_cublasLt111 += -gencode arch=compute_86,code=sm_86 CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89 CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90 @@ -66,16 +61,6 @@ all: $(BUILD_DIR) env $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) -cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) - -cuda10x_nomatmul: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE_10x) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) - cuda110_nomatmul: $(BUILD_DIR) env $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o @@ -122,11 +107,6 @@ env: @echo "LD_LIBRARY_PATH: $(LD_LIBRARY_PATH)" @echo "============================" -cutlass: - if [ ! 
-d "$(ROOT_DIR)/dependencies/cutlass" ]; then \ - git clone https://github.com/NVIDIA/cutlass.git $(ROOT_DIR)/dependencies/cutlass; \ - fi \ - $(BUILD_DIR): mkdir -p build mkdir -p dependencies diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index cc82943b8..c0eb2dee3 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -128,11 +128,6 @@ def __init__(self): def initialize(self): self.context = {} - # prev_device = torch.cuda.current_device() - # for i in range(torch.cuda.device_count()): - # torch.cuda.set_device(torch.device('cuda', i)) - # self.context.append(ct.c_void_p(lib.get_context())) - # torch.cuda.set_device(prev_device) @classmethod def get_instance(cls): @@ -238,72 +233,8 @@ def create_linear_map(signed=True, total_bits=8, add_zero=True): return values else: l = values.numel()//2 - #return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist()) return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) -def create_custom_map(seed=0, scale=0.01): - v = [12, 10, 8, 6, 3, 2, 1] - # 16-bit 7B 22.33, 4-bit best 22.88, FP4 23.25, 4-bit 95 22.97, 4-bit evo 22.45 - # 16-bit 13B 70.35, 4-bit best 67.16, FP4 100.78, 4-bit-95 69.39, 4-bit evo 70.48 - - # 13B 100 steps: - # - 4-bit evo: 86.02 - # - 4-bit norm: 78.73 - # - 4-bit FP4: - # - 16-bit: - - # interval search on normal distribution - #v = [3.090232306167813, 1.4589770349449647, 1.064410327932115, 0.7896806653244509, 0.5646884166925807, 0.3653406435875121, 0.17964844284441311] # 0.999 26.5 - #v = [2.3263478740408408, 1.4050715603096329, 1.0364333894937898, 0.7721932141886848, 0.5533847195556727, 0.3584587932511938, 0.1763741647808615] # 0.99 24.99 - #v = [1.6448536269514722, 1.2040469600267016, 0.9208229763683788, 0.6971414348463417, 0.5039653672113453, 0.3280721075316511, 0.16184416680396213] # 0.95 24.53 22.97 - #v = [1.4050715603096329, 1.0803193408149558, 0.8416212335729143, 0.643345405392917, 0.4676987991145084, 0.3054807880993974, 0.1509692154967774] # 0.92 24.81 - #v = [1.2815515655446004, 1.0062699858608395, 0.7916386077433746, 0.6084981344998837, 0.4438613119262478, 0.29050677112339396, 0.14372923370582416] # 0.9 24.68 - #v = [1.8807936081512509, 1.2980047163986055, 0.9769954022693226, 0.7341502955472268, 0.5285136765472481, 0.343225833559403, 0.16910470304375366] # 0.97 25.03 - #v = [1.7506860712521692, 1.2496468758017434, 0.9485350408266378, 0.7155233557034365, 0.5162006366043174, 0.3356393360829622, 0.16547334454641704] # 0.96 24.85 23.01 - #v = [1.5547735945968535, 1.1608220210715001, 0.893800631179489, 0.6789921163940618, 0.4918050830048072, 0.3205236191093902, 0.15821711945563585] # 0.94 24.47 - #v = [1.475791028179171, 1.1196635980209986, 0.8674156943957149, 0.6610637542614526, 0.4797170937629045, 0.31299335020578195, 0.15459215234139795] # 0.93 24.85 - #v = [1.5981931399228175, 1.1821583959486879, 0.9072289939325966, 0.6880384454306778, 0.49787602226482025, 0.3242955535308664, 0.160030379970179] # 0.945 24.287 - ##v = [1.6164363711150211, 1.1908453913294612, 0.9126463450304729, 0.6916727602238111, 0.5003095327012462, 0.3258056171348078, 0.1607558311941979] # 0.947 24.293 - #v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.207 - #v = [1.6118251211466303, 1.188665228776879, 0.9112895004060624, 0.690763326564427, 0.4997008778346997, 0.3254280317127771, 0.16057446047146948] # 0.9465 24.30 - #v = 
[1.6027040905517569, 1.184321770169049, 0.9085808314549837, 0.6889461706317986, 0.4984841229538408, 0.32467299997597887, 0.1602117348657326] # 0.9455 24.293 - #v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.37 22.88 - - # 7B evo start - #v = [1.62129629, 1.18870191, 0.90848106, 0.69108646, 0.50515268, 0.34927819905, 0.14122701] # 22.06 - #v = [1.6143079205628337, 1.1888081407660314, 0.8990131955745421, 0.694373759813679, 0.5083033257326773, 0.3452499746844963, 0.1148939728228951] - #v = [1.614442766030303, 1.189401918639665, 0.8998038168964273, 0.6953094818279475, 0.5073264599048384, 0.3449003790823619, 0.11428378427205564] - - # 13B evo start - #v = [1.6077535089716468, 1.1914902148179205, 0.8999752421085561, 0.6967904489387543, 0.4949093928311768, 0.30920472033044544, 0.15391602735952042] - #v = [1.586363722436466, 1.202610827188916, 0.9003332576346587, 0.6904888715206972, 0.49490974688233724, 0.2971151461329376, 0.15683230810738283] - v = [1.5842247437829478, 1.2037228884260156, 0.900369059187269, 0.6898587137788914, 0.4949097822874533, 0.2959061887131868, 0.15712393618216908] - - # mean evo 7B + 13B - #v = [1.5993337549066253, 1.1965624035328402, 0.9000864380418481, 0.6925840978034195, 0.5011181210961458, 0.32040328389777434, 0.13570386022711237] - - # theoretically optiomal (0.93333) - #v = [1.501085946044025, 1.1331700302595604, 0.8761428492468408, 0.6670160135425023, 0.48373855304610314, 0.3155014472579608, 0.15580024666388428] # 0.9333333333333333 - - if seed > 0: - v = np.array(v) - np.random.seed(seed) - v += np.random.randn(7)*scale - print(v.tolist()) - #v[0] += (np.random.randn(1)*0.001)[0] - #v[-1] += (np.random.randn(1)*0.001)[0] - #print(v[0], v[-1]) - v = v.tolist() - values = v + [0]*(256-14) + \ - v[::-1] - - values = torch.Tensor(values) - values[0:7] *= -1 - values = values.sort().values - values /= values.max() - assert values.numel() == 256 - return values - def create_normal_map(offset=0.9677083, use_extra_value=True): if use_extra_value: diff --git a/tests/test_functional.py b/tests/test_functional.py index c2d47961f..cc58324e4 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1773,21 +1773,24 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): print("partial matmul", time.time() - t0) -batch_size = 2 -seqdim = 2048 +batch_size = 1 +seqdim = 1 values = [] -values.append((batch_size, seqdim, 768, 4 * 768)) +#values.append((batch_size, seqdim, 768, 4 * 768)) #values.append((batch_size, seqdim, 1024, 4*1024)) #values.append((batch_size, seqdim, 1536, 4*1536)) #values.append((batch_size, seqdim, 2048, 4*2048)) #values.append((batch_size, seqdim, 2560, 4*2560)) -#values.append((batch_size, seqdim, 4096, 4*4096)) +values.append((batch_size, seqdim, 4096, 4*4096)) +values.append((batch_size, seqdim, 5120, 4*5120)) +values.append((batch_size, seqdim, 6656, 4*6656)) +values.append((batch_size, seqdim, 8192, 4*8192)) #values.append((batch_size, seqdim, 5140, 4*5140)) #values.append((batch_size, seqdim, 12288, 4*12288)) names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values] @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) def test_bench_matmul(batch, seq, model, hidden): - iters = 1 + iters = 80 formatB = F.get_special_format_str() A = torch.randn(batch, seq, model, device="cuda").half() @@ -1799,14 +1802,14 @@ def test_bench_matmul(batch, seq, model, hidden): B_nf4, state_nf4= F.quantize_nf4(B) - 
linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() + linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half() linear8bit.eval() outliers = torch.randint(0, model, size=(5,)).cuda() A[:, :, outliers] = 8.0 - linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()) - linearMixedBit.eval() + linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half()) + #linearMixedBit.eval() linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() @@ -1898,21 +1901,21 @@ def test_bench_matmul(batch, seq, model, hidden): #torch.cuda.synchronize() #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - #linear8bit(A) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): - # linear8bit(A) - #torch.cuda.synchronize() - #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + linear8bit(A) + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + linear8bit(A) + torch.cuda.synchronize() + print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - #linearMixedBit(A) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): - # linearMixedBit(A) - #torch.cuda.synchronize() - #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + linearMixedBit(A) + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + linearMixedBit(A) + torch.cuda.synchronize() + print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") #linear8bit_train(A) #torch.cuda.synchronize() From 1b8772a8f33fdb47df0c849302cbb7e703571b8c Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 23 May 2023 19:37:38 -0700 Subject: [PATCH 61/63] Added PagedLion and bf16 Lion. 
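Minimal usage sketch for the optimizers added in this patch (the toy model, sizes and hyperparameters are illustrative; only the optimizer classes and the bf16 support come from the changes below):

    import torch
    import bitsandbytes as bnb

    # bf16 weights/gradients are now handled by the 32-bit and 8-bit blockwise Lion kernels.
    model = torch.nn.Linear(4096, 4096, bias=False).cuda().to(torch.bfloat16)

    # PagedLion is Lion with is_paged=True, i.e. optimizer state is kept in paged memory.
    opt = bnb.optim.PagedLion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0)
    # 8-bit variant: opt = bnb.optim.PagedLion8bit(model.parameters(), lr=1e-4)

    out = model(torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16))
    out.mean().backward()
    opt.step()
    opt.zero_grad()
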
--- bitsandbytes/functional.py | 6 +-- bitsandbytes/optim/__init__.py | 2 +- bitsandbytes/optim/lion.py | 95 +++++++--------------------------- csrc/kernels.cu | 3 ++ csrc/ops.cu | 2 + csrc/pythonInterface.c | 12 +++-- tests/test_optim.py | 23 ++++---- 7 files changed, 46 insertions(+), 97 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index c0eb2dee3..afa346e6e 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -37,10 +37,7 @@ def prod(iterable): lib.crmsprop32bit_grad_32, lib.crmsprop32bit_grad_16, ) - str2optimizer32bit["lion"] = ( - lib.clion32bit_grad_32, - lib.clion32bit_grad_16, - ) + str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16, lib.clion32bit_grad_bf16) str2optimizer32bit["adagrad"] = ( lib.cadagrad32bit_grad_32, lib.cadagrad32bit_grad_16, @@ -89,6 +86,7 @@ def prod(iterable): str2optimizer8bit_blockwise["lion"] = ( lib.clion_8bit_blockwise_grad_fp32, lib.clion_8bit_blockwise_grad_fp16, + lib.clion_8bit_blockwise_grad_bf16, ) str2optimizer8bit_blockwise["adagrad"] = ( lib.cadagrad_8bit_blockwise_grad_fp32, diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py index 1cfe2410e..83a57bd9f 100644 --- a/bitsandbytes/optim/__init__.py +++ b/bitsandbytes/optim/__init__.py @@ -12,5 +12,5 @@ from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS from .optimizer import GlobalOptimManager from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit -from .lion import Lion, Lion8bit, Lion32bit +from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit from .sgd import SGD, SGD8bit, SGD32bit diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py index 2551b68e1..2bde1a447 100644 --- a/bitsandbytes/optim/lion.py +++ b/bitsandbytes/optim/lion.py @@ -4,84 +4,27 @@ # LICENSE file in the root directory of this source tree. 
from bitsandbytes.optim.optimizer import Optimizer1State - class Lion(Optimizer1State): - def __init__( - self, - params, - lr=1e-4, - betas=(0.9, 0.99), - weight_decay=0, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "lion", - params, - lr, - betas, - 0., - weight_decay, - optim_bits, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) - + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Lion8bit(Optimizer1State): - def __init__( - self, - params, - lr=1e-4, - betas=(0.9, 0.99), - weight_decay=0, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "lion", - params, - lr, - betas, - 0., - weight_decay, - 8, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) - + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Lion32bit(Optimizer1State): - def __init__( - self, - params, - lr=1e-4, - betas=(0.9, 0.99), - weight_decay=0, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "lion", - params, - lr, - betas, - 0., - weight_decay, - 32, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + + +class PagedLion(Optimizer1State): + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedLion8bit(Optimizer1State): + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedLion32bit(Optimizer1State): + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 30e5e2e8d..11ad63f23 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3666,6 +3666,7 @@ MAKE_PreconditionOptimizer32bit1State(RMSPROP, half) MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) MAKE_PreconditionOptimizer32bit1State(LION, half) MAKE_PreconditionOptimizer32bit1State(LION, float) +MAKE_PreconditionOptimizer32bit1State(LION, __nv_bfloat16) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) 
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) @@ -3679,6 +3680,7 @@ MAKE_Optimizer32bit1State(RMSPROP, half) MAKE_Optimizer32bit1State(RMSPROP, float) MAKE_Optimizer32bit1State(LION, half) MAKE_Optimizer32bit1State(LION, float) +MAKE_Optimizer32bit1State(LION, __nv_bfloat16) MAKE_Optimizer32bit1State(ADAGRAD, half) MAKE_Optimizer32bit1State(ADAGRAD, float) @@ -3852,5 +3854,6 @@ MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8) diff --git a/csrc/ops.cu b/csrc/ops.cu index 7f3a83152..9c042fa66 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -802,6 +802,7 @@ MAKE_optimizer32bit(RMSPROP, half) MAKE_optimizer32bit(RMSPROP, float) MAKE_optimizer32bit(LION, half) MAKE_optimizer32bit(LION, float) +MAKE_optimizer32bit(LION, __nv_bfloat16) MAKE_optimizer32bit(ADAGRAD, half) MAKE_optimizer32bit(ADAGRAD, float) @@ -837,6 +838,7 @@ MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); MAKE_optimizerStatic8bitBlockwise(half, LION); MAKE_optimizerStatic8bitBlockwise(float, LION); +MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, LION); MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 776497b67..23a0364cc 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -51,8 +51,9 @@ MAKE_FUNC32(adam, ADAM, half, fp16) MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16) MAKE_FUNC32(rmsprop, RMSPROP, float, 32) MAKE_FUNC32(rmsprop, RMSPROP, half, 16) -MAKE_FUNC32(lion, LION, float, 32) -MAKE_FUNC32(lion, LION, half, 16) +MAKE_FUNC32(lion, LION, float, fp32) +MAKE_FUNC32(lion, LION, half, fp16) +MAKE_FUNC32(lion, LION, __nv_bfloat16, bf16) MAKE_FUNC32(adagrad, ADAGRAD, float, 32) MAKE_FUNC32(adagrad, ADAGRAD, half, 16) @@ -95,6 +96,7 @@ MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32) MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) MAKE_BLOCKWISE8(lion, LION, half, fp16) MAKE_BLOCKWISE8(lion, LION, float, fp32) +MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16) void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } @@ -201,8 +203,9 @@ extern "C" MAKE_CFUNC32(momentum, half, 16) MAKE_CFUNC32(rmsprop, float, 32) MAKE_CFUNC32(rmsprop, half, 16) - MAKE_CFUNC32(lion, float, 32) - MAKE_CFUNC32(lion, half, 16) + MAKE_CFUNC32(lion, float, fp32) + MAKE_CFUNC32(lion, half, fp16) + MAKE_CFUNC32(lion, __nv_bfloat16, bf16) MAKE_CFUNC32(adagrad, float, 32) MAKE_CFUNC32(adagrad, half, 16) @@ -245,6 +248,7 @@ extern "C" MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) MAKE_CBLOCKWISE8(lion, LION, half, fp16) MAKE_CBLOCKWISE8(lion, LION, float, fp32) + MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16) void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); } void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); } diff --git a/tests/test_optim.py b/tests/test_optim.py index 98e4289dd..9e90083a9 100644 --- a/tests/test_optim.py +++ 
b/tests/test_optim.py @@ -19,11 +19,11 @@ k = 20 def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0): - idx = torch.isclose(a, b, rtol, atol) + idx = torch.isclose(a, b, rtol=rtol, atol=atol) error_count = (idx == 0).sum().item() if error_count > max_error_count: print(f"Too many values not close: assert {error_count} < {max_error_count}") - torch.testing.assert_close(a, b, rtol, atol) + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) def get_temp_dir(): @@ -35,13 +35,8 @@ def get_temp_dir(): def rm_path(path): shutil.rmtree(path) -str2bf16support = {} -str2bf16support['adam8bit_blockwise'] = True - str2optimizers = {} str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam) -# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam) -# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam) str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion) str2optimizers["momentum_pytorch"] = ( None, @@ -51,8 +46,8 @@ def rm_path(path): str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam) str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW) str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam) -# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) str2optimizers["lion"] = (Lion, bnb.optim.Lion) +str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion) str2optimizers["momentum"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False), @@ -76,6 +71,7 @@ def rm_path(path): str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True)) str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True)) str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True)) +str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True)) str2optimizers["momentum8bit_blockwise"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True), @@ -90,6 +86,7 @@ def rm_path(path): str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["lion"] = [("exp_avg", "state1")] +str2statenames["paged_lion"] = [("exp_avg", "state1")] str2statenames["momentum"] = [("momentum_buffer", "state1")] str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["rmsprop"] = [("square_avg", "state1")] @@ -104,15 +101,17 @@ def rm_path(path): str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")] str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")] str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] +str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] dim1 = [1024] dim2 = [32, 1024, 4097, 1] -gtype = [torch.float32, torch.float16] -optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion'] +gtype = [torch.float32, torch.float16, torch.bfloat16] +optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion'] values = list(product(dim1, dim2, gtype, optimizer_names)) names = 
["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) def test_optimizer32bit(dim1, dim2, gtype, optim_name): + if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: pytest.skip() if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 @@ -254,7 +253,7 @@ def test_global_config(dim1, dim2, gtype): @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) def test_optimizer8bit(dim1, dim2, gtype, optim_name): - if gtype == torch.bfloat16 and optim_name not in str2bf16support: return + if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip() if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 @@ -485,7 +484,7 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): # optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch'] # optimizer_names = ['lamb_apex', 'lamb8bit'] # optimizer_names = ['lars_apex', 'lars8bit'] -optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise'] +optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise', 'paged_lion8bit_blockwise'] values = list(product(dim1, dim2, gtype, optimizer_names)) names = [ "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values From 0f40fa3f0a198802056e29ba183eaabc6751d565 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 23 May 2023 19:55:52 -0700 Subject: [PATCH 62/63] Bumped version. --- CHANGELOG.md | 11 +++++++++++ Makefile | 3 +-- setup.py | 4 ++-- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2de70d371..eb7ac0dd9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -228,3 +228,14 @@ Deprecated: Features: - Added Int8 SwitchBack layers - Added Fake FP8 layers for research purposes (available under `bnb.research.nn. ...`) + + +### 0.39.0 + + +Features: + - 4-bit matrix multiplication for Float4 and NormalFloat4 data types. + - Added 4-bit quantization routines + - Doubled quantization routines for 4-bit quantization + - Paged optimizers for Adam and Lion. + - bfloat16 gradient / weight support for Adam and Lion with 8 or 32-bit states. 
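A rough round-trip sketch of the 4-bit quantization routines listed above (shapes, blocksize and the error print are illustrative only):

    import torch
    import bitsandbytes.functional as F

    B = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)

    # Blockwise NormalFloat4 quantization: packed 4-bit values plus per-block absmax state.
    qB, state = F.quantize_nf4(B, blocksize=64)
    # Reconstruct an approximation of B from the packed values and the state.
    B_deq = F.dequantize_nf4(qB, state)

    print((B - B_deq).abs().mean())  # small but nonzero: 4-bit quantization is lossy
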
diff --git a/Makefile b/Makefile index c113a3d5c..1f2b281af 100644 --- a/Makefile +++ b/Makefile @@ -25,8 +25,7 @@ FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include -INCLUDE_10x := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include -LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib +LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcusparse -L $(CONDA_PREFIX)/lib # NVIDIA NVCC compilation flags COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell diff --git a/setup.py b/setup.py index 009fd3d94..b683bfcf1 100644 --- a/setup.py +++ b/setup.py @@ -18,10 +18,10 @@ def read(fname): setup( name=f"bitsandbytes", - version=f"0.38.1", + version=f"0.39.0", author="Tim Dettmers", author_email="dettmers@cs.washington.edu", - description="8-bit optimizers and matrix multiplication routines.", + description="k-bit optimizers and matrix multiplication routines.", license="MIT", keywords="gpu optimizers optimization 8-bit quantization compression", url="https://github.com/TimDettmers/bitsandbytes", From ac5550a0238286377ee3f58a85aeba1c40493e17 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 30 May 2023 19:06:59 -0700 Subject: [PATCH 63/63] Added changes for deployment. --- Makefile | 1 - csrc/kernels.cu | 10 +++++++--- deploy.sh | 11 ----------- 3 files changed, 7 insertions(+), 15 deletions(-) diff --git a/Makefile b/Makefile index 1f2b281af..5fa1f1736 100644 --- a/Makefile +++ b/Makefile @@ -33,7 +33,6 @@ COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta -COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta CC_KEPLER := -gencode arch=compute_35,code=sm_35 # Kepler CC_KEPLER += -gencode arch=compute_37,code=sm_37 # Kepler diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 11ad63f23..ab12c3702 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -16,15 +16,12 @@ #include #include -#include -#include #define HLF_MAX 65504 #define TH 1024 #define NUM 4 #define NUM_BLOCK 4096 -using namespace nvcuda; // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda __device__ float atomicMax(float* address, float val) { @@ -3094,6 +3091,9 @@ template __device__ inline void vector_l #define WARPS 5 template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) { + +#if __CUDA_ARCH__ >= 750 + using namespace nvcuda; int col_offset = blockIdx.x *32; const int warp_id = threadIdx.x / 32; const int half_warp_id = threadIdx.x / 16; @@ -3294,11 +3294,14 @@ template __global__ void gemm_device(int M, if(col_offset + warp_lane < M) out[col_offset + warp_lane] = smem_A[warp_lane]; +#endif } template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) { +#if __CUDA_ARCH__ >= 750 + using namespace nvcuda; int col_offset = blockIdx.x *32; const int warp_id = threadIdx.x / 32; const int half_warp_id = threadIdx.x / 16; @@ -3459,6 +3462,7 @@ template __global__ void 
kgemm_4bit_inference(int M, i if(col_offset + warp_lane < M) out[col_offset + warp_lane] = smem_A[warp_lane]; +#endif } //#define ROWS 2 diff --git a/deploy.sh b/deploy.sh index 24d6cbf6b..a2257a2bb 100644 --- a/deploy.sh +++ b/deploy.sh @@ -139,17 +139,6 @@ if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121.so" ]; then fi -make clean -export CUDA_HOME=$BASE_PATH/cuda-10.2 -make cuda10x_nomatmul CUDA_VERSION=102 - -if [ ! -f "./bitsandbytes/libbitsandbytes_cuda102_nocublaslt.so" ]; then - # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 - exit 64 -fi - - make clean export CUDA_HOME=$BASE_PATH/cuda-11.0 make cuda110_nomatmul CUDA_VERSION=110