From 6a6daddb4c8d5241df4fcf836af518df793e8fde Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 30 Jan 2025 23:12:42 -0800 Subject: [PATCH] Split sgemm for lora_a and lora_b --- python/sglang/srt/lora/backend/__init__.py | 1 + .../sglang/srt/lora/backend/base_backend.py | 20 +- .../srt/lora/backend/flashinfer_backend.py | 23 +- .../sglang/srt/lora/backend/triton_backend.py | 351 +----------------- python/sglang/srt/lora/lora.py | 16 +- python/sglang/srt/lora/triton_ops/__init__.py | 5 + .../sglang/srt/lora/triton_ops/qkv_lora_b.py | 180 +++++++++ .../srt/lora/triton_ops/sgemm_lora_a.py | 144 +++++++ .../srt/lora/triton_ops/sgemm_lora_b.py | 142 +++++++ 9 files changed, 528 insertions(+), 354 deletions(-) create mode 100644 python/sglang/srt/lora/triton_ops/__init__.py create mode 100644 python/sglang/srt/lora/triton_ops/qkv_lora_b.py create mode 100644 python/sglang/srt/lora/triton_ops/sgemm_lora_a.py create mode 100644 python/sglang/srt/lora/triton_ops/sgemm_lora_b.py diff --git a/python/sglang/srt/lora/backend/__init__.py b/python/sglang/srt/lora/backend/__init__.py index 9a3ea52dc157..ed377b4b4adb 100644 --- a/python/sglang/srt/lora/backend/__init__.py +++ b/python/sglang/srt/lora/backend/__init__.py @@ -1,3 +1,4 @@ +from .base_backend import BaseLoraBackend from .flashinfer_backend import FlashInferLoraBackend from .triton_backend import TritonLoraBackend diff --git a/python/sglang/srt/lora/backend/base_backend.py b/python/sglang/srt/lora/backend/base_backend.py index 3f04aabb10d7..f1c41f1204b6 100644 --- a/python/sglang/srt/lora/backend/base_backend.py +++ b/python/sglang/srt/lora/backend/base_backend.py @@ -16,13 +16,27 @@ def __init__(self, name: str, batch_info: LoraBatchInfo = None): self.name = name self.batch_info = batch_info - def run_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: - """Run segment Gemm with current backend. + def run_lora_a_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + """Run segment Gemm of lora a modules with current backend. The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html. Args: x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths - weights: a set of lora weights with shape (num_lora, output_dim, input_dim) + weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank + Usually input_dim is much larger than r + Returns: + result with shape (s, r) + """ + pass + + def run_lora_b_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + """Run segment Gemm of lora b modules with current backend. + The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html. + + Args: + x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank + weights: a set of lora weights with shape (num_lora, output_dim, r) + Usually output_dim is much larger than r Returns: result with shape (s, output_dim) """ diff --git a/python/sglang/srt/lora/backend/flashinfer_backend.py b/python/sglang/srt/lora/backend/flashinfer_backend.py index 5e4273603134..88136555ad42 100644 --- a/python/sglang/srt/lora/backend/flashinfer_backend.py +++ b/python/sglang/srt/lora/backend/flashinfer_backend.py @@ -1,7 +1,7 @@ import torch from flashinfer import SegmentGEMMWrapper -from sglang.srt.lora.backend.base_backend import BaseLoraBackend +from sglang.srt.lora.backend import BaseLoraBackend from sglang.srt.lora.lora import LoraBatchInfo @@ -15,7 +15,18 @@ def __init__(self, name: str, batch_info: LoraBatchInfo = None): workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda") self.segment_gemm = SegmentGEMMWrapper(workspace_buffer) - def run_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + def run_lora_a_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + + return self.segment_gemm.run( + x=x, + weights=weights, + batch_size=self.batch_info.bs, + weight_column_major=True, + seg_indptr=self.batch_info.seg_indptr, + weight_indices=self.batch_info.weight_indices, + ) + + def run_lora_b_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: return self.segment_gemm.run( x=x, @@ -35,7 +46,7 @@ def run_qkv_lora( ) -> torch.Tensor: # Shape of lora_a_output: (s, 3 * r) - lora_a_output = self.run_sgemm(x=x, weights=qkv_lora_a) + lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a) lora_rank = kv_lora_b.shape[-1] output_dim_q = q_lora_b.shape[-2] @@ -48,19 +59,19 @@ def run_qkv_lora( # FIXME parallelize qkv # q - lora_output[:, :output_dim_q] = self.run_sgemm( + lora_output[:, :output_dim_q] = self.run_lora_b_sgemm( x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0] ) # kv - lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = self.run_sgemm( + lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = self.run_lora_b_sgemm( x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(), weights=kv_lora_b[0], ) lora_output[ :, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv - ] = self.run_sgemm( + ] = self.run_lora_b_sgemm( x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(), weights=kv_lora_b[1], ) diff --git a/python/sglang/srt/lora/backend/triton_backend.py b/python/sglang/srt/lora/backend/triton_backend.py index d61cb296b93c..34db604e8add 100644 --- a/python/sglang/srt/lora/backend/triton_backend.py +++ b/python/sglang/srt/lora/backend/triton_backend.py @@ -2,188 +2,13 @@ import triton import triton.language as tl -from sglang.srt.lora.backend.base_backend import BaseLoraBackend +from sglang.srt.lora.backend import BaseLoraBackend from sglang.srt.lora.lora import LoraBatchInfo - - -@triton.jit -def _sgemm_kernel( - # Pointers to matrices - x, - weights, - output, - # Matrix dimensions - N, - K, - # Strides - x_stride_0, - x_stride_1, - w_stride_0, - w_stride_1, - w_stride_2, - output_stride_0, - output_stride_1, - # Information on sequence lengths and weight id - seg_lens, - seg_indptr, - weight_indices, - # Meta parameters - BLOCK_S: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, -): - - # x: (s, K), s is the sum of sequence lengths - # weights: (num_lora, N, K) - # output: (s, N) - - # Current block computes sequence with batch_id, - # which starts from row seg_start of x with length seg_len - batch_id = tl.program_id(axis=1) - pid = tl.program_id(axis=0) - seg_len = tl.load(seg_lens + batch_id) - w_index = tl.load(weight_indices + batch_id) - seg_start = tl.load(seg_indptr + batch_id) - - # The tile in output matrix will have (pid_s, pid_n) as id - # FIXME: might be replaced by super-grouping - num_pid_n = tl.cdiv(N, BLOCK_N) - pid_s = pid // num_pid_n - pid_n = pid % num_pid_n - - # Create pointers for the first block of x and weights[batch_id] - # The pointers will be advanced as we move in the K direction - # and accumulate - s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S - n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N - k_offset = tl.arange(0, BLOCK_K) - x_ptrs = (x + seg_start * x_stride_0) + ( - s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1 - ) - w_ptrs = (weights + w_index * w_stride_0) + ( - k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 - ) - - # Iteate to compute the block in output matrix - partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_K)): - x_tile = tl.load( - x_ptrs, - mask=(s_offset[:, None] < seg_len) - and (k_offset[None, :] < K - k * BLOCK_K), - other=0.0, - ) - w_tile = tl.load( - w_ptrs, - mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < N), - other=0.0, - ) - partial_sum += tl.dot(x_tile, w_tile) - - x_ptrs += BLOCK_K * x_stride_1 - w_ptrs += BLOCK_K * w_stride_2 - - # Store result to output matrix - partial_sum = partial_sum.to(x.dtype.element_ty) - output_ptr = (output + seg_start * output_stride_0) + ( - s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 - ) - output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < N) - tl.store(output_ptr, partial_sum, mask=output_mask) - - -@triton.jit -def _qkv_lora_b_kernel( - # Pointers to matrices - x, - weights, - output, - # Parameters of size - K, # K = R - N_Q, - N_KV, - # Strides - x_stride_0, - x_stride_1, - w_stride_0, - w_stride_1, - w_stride_2, - output_stride_0, - output_stride_1, - # Information on sequence lengths and weight id - seg_lens, - seg_indptr, - weight_indices, - # Offsets of q/k/v slice on output dimension - n_indptr, - # Meta parameters - BLOCK_S: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, -): - # This kernel packs 3 sgemms (q/k/v) into a single kernel. - - # x: (s, 3 * K), s is the sum of sequence lengths, K equals to lora rank - # weights: (num_lora, N_Q + 2 * N_KV, K) - # output: (s, N_Q + 2 * N_KV) - - # Current block computes sequence with batch_id, - # which starts from row seg_start of x with length seg_len. - # qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v) - batch_id = tl.program_id(axis=2) - qkv_id = tl.program_id(axis=1) - pid = tl.program_id(axis=0) - seg_len = tl.load(seg_lens + batch_id) - w_index = tl.load(weight_indices + batch_id) - seg_start = tl.load(seg_indptr + batch_id) - n_start = tl.load(n_indptr + qkv_id) - n_size = tl.load(n_indptr + qkv_id + 1) - n_start - - # The tile in output matrix will have (pid_s, pid_n) as id - # FIXME: might be replaced by super-grouping - num_pid_n = tl.cdiv(max(N_Q, N_KV), BLOCK_N) - pid_s = pid // num_pid_n - pid_n = pid % num_pid_n - - # Create pointers for the first block of x and weights[batch_id][n_start: n_end][:] - # The pointers will be advanced as we move in the K direction - # and accumulate - s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S - n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N - k_offset = tl.arange(0, BLOCK_K) - x_ptrs = (x + seg_start * x_stride_0 + (qkv_id * K) * x_stride_1) + ( - s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1 - ) - w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + ( - k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 - ) - - # Iteate to compute the block in output matrix - partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_K)): - x_tile = tl.load( - x_ptrs, - mask=(s_offset[:, None] < seg_len) - and (k_offset[None, :] < K - k * BLOCK_K), - other=0.0, - ) - w_tile = tl.load( - w_ptrs, - mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size), - other=0.0, - ) - partial_sum += tl.dot(x_tile, w_tile) - - x_ptrs += BLOCK_K * x_stride_1 - w_ptrs += BLOCK_K * w_stride_2 - - # Store result to output matrix - partial_sum = partial_sum.to(x.dtype.element_ty) - output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + ( - s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 - ) - output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size) - tl.store(output_ptr, partial_sum, mask=output_mask) +from sglang.srt.lora.triton_ops import ( + qkv_lora_b_fwd, + sgemm_lora_a_fwd, + sgemm_lora_b_fwd, +) class TritonLoraBackend(BaseLoraBackend): @@ -191,54 +16,11 @@ class TritonLoraBackend(BaseLoraBackend): def __init__(self, name: str, batch_info: LoraBatchInfo = None): super().__init__(name, batch_info) - def run_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: - - assert x.is_contiguous() - assert weights.is_contiguous() - assert len(x.shape) == 2 - assert len(weights.shape) == 3 - - # x: (s, k) - # weights: (num_lora, n, k) - # output: (s, n) - S = x.shape[0] - N = weights.shape[-2] - K = weights.shape[-1] - assert x.shape[-1] == K - - # Block shapes - # Autotuning tried but not effective - BLOCK_S = 16 - BLOCK_K = 16 - BLOCK_N = 16 - - grid = ( - triton.cdiv(self.batch_info.max_len, BLOCK_S) * triton.cdiv(N, BLOCK_N), - self.batch_info.bs, - ) + def run_lora_a_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + return sgemm_lora_a_fwd(x, weights, self.batch_info) - output = torch.empty((S, N), device=x.device, dtype=x.dtype) - _sgemm_kernel[grid]( - x, - weights, - output, - N, - K, - x.stride(0), - x.stride(1), - weights.stride(0), - weights.stride(1), - weights.stride(2), - output.stride(0), - output.stride(1), - self.batch_info.seg_lens, - self.batch_info.seg_indptr, - self.batch_info.weight_indices, - BLOCK_S, - BLOCK_N, - BLOCK_K, - ) - return output + def run_lora_b_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + return sgemm_lora_b_fwd(x, weights, self.batch_info) def run_qkv_lora( self, @@ -253,115 +35,8 @@ def run_qkv_lora( # q_lora_b: (1, num_lora, output_dim_q, r) # kv_lora_b: (2, num_lora, output_dim_kv, r) - assert x.is_contiguous() - assert qkv_lora_a.is_contiguous() - assert q_lora_b.is_contiguous() - assert kv_lora_b.is_contiguous - assert len(x.shape) == 2 - assert len(qkv_lora_a.shape) == 3 - assert len(q_lora_b.shape) == 4 - assert len(kv_lora_b.shape) == 4 - - # Get dims - s = x.shape[0] - input_dim = x.shape[1] - output_dim_q = q_lora_b.shape[-2] - output_dim_kv = kv_lora_b.shape[-2] - r = q_lora_b.shape[-1] - middle_dim = 3 * r - output_dim = output_dim_q + 2 * output_dim_kv - assert qkv_lora_a.shape[-2] == middle_dim - assert qkv_lora_a.shape[-1] == input_dim - assert kv_lora_b.shape[-1] == r - - # Compute lora_a_output = sgemm(x, qkv_lora_a) - # shape of lora_a_output: (s, middle_dim) - BLOCK_S = 16 - BLOCK_IN = 64 - BLOCK_R = 16 - - grid_a = ( - triton.cdiv(self.batch_info.max_len, BLOCK_S) - * triton.cdiv(middle_dim, BLOCK_R), - self.batch_info.bs, + lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info) + lora_output = qkv_lora_b_fwd( + lora_a_output, q_lora_b, kv_lora_b, self.batch_info ) - - lora_a_output = torch.empty((s, middle_dim), device=x.device, dtype=x.dtype) - _sgemm_kernel[grid_a]( - x, - qkv_lora_a, - lora_a_output, - middle_dim, - input_dim, - x.stride(0), - x.stride(1), - qkv_lora_a.stride(0), - qkv_lora_a.stride(1), - qkv_lora_a.stride(2), - lora_a_output.stride(0), - lora_a_output.stride(1), - self.batch_info.seg_lens, - self.batch_info.seg_indptr, - self.batch_info.weight_indices, - BLOCK_S, - BLOCK_R, - BLOCK_IN, - ) - - # Compute lora_output with shape (s, output_dim) as follows: - # lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], q_lora_b[0]) - # lora_output[:, output_dim_q: output_dim_q + output_dim_kv] - # = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0]) - # lora_output[:, output_dim_q + output_dim_kv: ] - # = sgemm(lora_output_a[:, 2 * r: 3 * r], kv_lora_b[1]) - BLOCK_S = 16 - BLOCK_R = 16 - BLOCK_OUT = 64 - - grid_b = ( - triton.cdiv(self.batch_info.max_len, BLOCK_S) - * triton.cdiv(max(output_dim_q, output_dim_kv), BLOCK_OUT), - 3, # this dimension decides current block computes on q, k or v - self.batch_info.bs, - ) - - # w_lora_b with shape (num_lora, output_dim_q + 2 * output_dim_kv, r) is passed to kernel - w_lora_b = torch.cat( - (q_lora_b[0], kv_lora_b[0], kv_lora_b[1]), dim=-2 - ).contiguous() - lora_output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype) - n_indptr = torch.tensor( - [ - 0, - output_dim_q, - output_dim_q + output_dim_kv, - output_dim_q + 2 * output_dim_kv, - ], - dtype=torch.int32, - device=x.device, - ) - - _qkv_lora_b_kernel[grid_b]( - lora_a_output, - w_lora_b, - lora_output, - r, - output_dim_q, - output_dim_kv, - lora_a_output.stride(0), - lora_a_output.stride(1), - w_lora_b.stride(0), - w_lora_b.stride(1), - w_lora_b.stride(2), - lora_output.stride(0), - lora_output.stride(1), - self.batch_info.seg_lens, - self.batch_info.seg_indptr, - self.batch_info.weight_indices, - n_indptr, - BLOCK_S, - BLOCK_OUT, - BLOCK_R, - ) - return lora_output diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index 801ffd397a6c..b1db04e7f65e 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -120,18 +120,20 @@ def set_lora_info( self.B_buffer = B_buffer def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - lora_a_output = self.lora_backend.run_sgemm(x=x, weights=self.A_buffer) + lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer) output_dim = base_output.shape[-1] lora_output = torch.empty_like(base_output) - lora_output[:, :output_dim] = self.lora_backend.run_sgemm( + lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm( x=lora_a_output[:, 0 : self.lora_rank].contiguous(), weights=self.B_buffer[0], ) - lora_output[:, output_dim : 2 * output_dim] = self.lora_backend.run_sgemm( - x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(), - weights=self.B_buffer[1], + lora_output[:, output_dim : 2 * output_dim] = ( + self.lora_backend.run_lora_b_sgemm( + x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(), + weights=self.B_buffer[1], + ) ) return base_output + lora_output * self.scaling @@ -176,8 +178,8 @@ def set_lora_info(self, A_buffer, B_buffer): self.B_buffer = B_buffer def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - lora_a_output = self.lora_backend.run_sgemm(x=x, weights=self.A_buffer) - lora_output = self.lora_backend.run_sgemm( + lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer) + lora_output = self.lora_backend.run_lora_b_sgemm( x=lora_a_output, weights=self.B_buffer[0] ) return base_output + lora_output * self.scaling diff --git a/python/sglang/srt/lora/triton_ops/__init__.py b/python/sglang/srt/lora/triton_ops/__init__.py new file mode 100644 index 000000000000..efc76bb8b472 --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/__init__.py @@ -0,0 +1,5 @@ +from .qkv_lora_b import qkv_lora_b_fwd +from .sgemm_lora_a import sgemm_lora_a_fwd +from .sgemm_lora_b import sgemm_lora_b_fwd + +__all__ = ["qkv_lora_b_fwd", "sgemm_lora_a_fwd", "sgemm_lora_b_fwd"] diff --git a/python/sglang/srt/lora/triton_ops/qkv_lora_b.py b/python/sglang/srt/lora/triton_ops/qkv_lora_b.py new file mode 100644 index 000000000000..14bbcd10592b --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/qkv_lora_b.py @@ -0,0 +1,180 @@ +import torch +import triton +import triton.language as tl + +from sglang.srt.lora.lora import LoraBatchInfo + + +@triton.jit +def _qkv_lora_b_kernel( + # Pointers to matrices + x, + weights, + output, + # Parameters of size + K, # K = R + N_Q, + N_KV, + # Strides + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + # Information on sequence lengths and weight id + seg_lens, + seg_indptr, + weight_indices, + # Offsets of q/k/v slice on output dimension + n_indptr, + # Meta parameters + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + # This kernel packs 3 sgemms (q/k/v) into a single kernel. + + # x: (s, 3 * K), s is the sum of sequence lengths, K equals to lora rank + # weights: (num_lora, N_Q + 2 * N_KV, K) + # output: (s, N_Q + 2 * N_KV) + # N_Q >> K, N_KV >> K + + # Current block computes sequence with batch_id, + # which starts from row seg_start of x with length seg_len. + # qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v) + batch_id = tl.program_id(axis=2) + qkv_id = tl.program_id(axis=1) + pid = tl.program_id(axis=0) + seg_len = tl.load(seg_lens + batch_id) + w_index = tl.load(weight_indices + batch_id) + seg_start = tl.load(seg_indptr + batch_id) + n_start = tl.load(n_indptr + qkv_id) + n_size = tl.load(n_indptr + qkv_id + 1) - n_start + + # The tile in output matrix will have (pid_s, pid_n) as id + num_pid_n = tl.cdiv(max(N_Q, N_KV), BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + + # Create pointers for the first block of x and weights[batch_id][n_start: n_end][:] + # The pointers will be advanced as we move in the K direction + # and accumulate + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + + x_ptrs = (x + seg_start * x_stride_0 + (qkv_id * K) * x_stride_1) + ( + s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1 + ) + w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + # Iteate to compute the block in output matrix + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + x_tile = tl.load( + x_ptrs, + mask=(s_offset[:, None] < seg_len) + and (k_offset[None, :] < K - k * BLOCK_K), + other=0.0, + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size), + other=0.0, + ) + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + # Store result to output matrix + partial_sum = partial_sum.to(x.dtype.element_ty) + output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + ( + s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 + ) + output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def qkv_lora_b_fwd( + x: torch.Tensor, + q_lora_b: torch.Tensor, + kv_lora_b: torch.Tensor, + batch_info: LoraBatchInfo, +) -> torch.Tensor: + + # x: (s, 3 * r) + # q_lora_b: (1, num_lora, output_dim_q, r) + # kv_lora_b: (2, num_lora, output_dim_kv, r) + # output: (s, output_dim_q + 2 * output_dim_kv) + + # Compute lora_output with shape (s, output_dim) as follows: + # lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], q_lora_b[0]) + # lora_output[:, output_dim_q: output_dim_q + output_dim_kv] + # = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0]) + # lora_output[:, output_dim_q + output_dim_kv: ] + # = sgemm(lora_output_a[:, 2 * r: 3 * r], kv_lora_b[1]) + + # Get dims + s = x.shape[0] + input_dim = x.shape[1] + r = q_lora_b.shape[-1] + output_dim_q = q_lora_b.shape[-2] + output_dim_kv = kv_lora_b.shape[-2] + output_dim = output_dim_q + 2 * output_dim_kv + assert input_dim == 3 * r + + # FIXME: replace with autotune + BLOCK_S = 16 + BLOCK_R = 16 + BLOCK_OUT = 64 + + grid_b = ( + triton.cdiv(batch_info.max_len, BLOCK_S) + * triton.cdiv(max(output_dim_q, output_dim_kv), BLOCK_OUT), + 3, # this dimension decides current block computes on q, k or v + batch_info.bs, + ) + + # w_lora_b with shape (num_lora, output_dim_q + 2 * output_dim_kv, r) is passed to kernel + w_lora_b = torch.cat((q_lora_b[0], kv_lora_b[0], kv_lora_b[1]), dim=-2).contiguous() + lora_output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype) + n_indptr = torch.tensor( + [ + 0, + output_dim_q, + output_dim_q + output_dim_kv, + output_dim_q + 2 * output_dim_kv, + ], + dtype=torch.int32, + device=x.device, + ) + + _qkv_lora_b_kernel[grid_b]( + x, + w_lora_b, + lora_output, + r, + output_dim_q, + output_dim_kv, + x.stride(0), + x.stride(1), + w_lora_b.stride(0), + w_lora_b.stride(1), + w_lora_b.stride(2), + lora_output.stride(0), + lora_output.stride(1), + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + n_indptr, + BLOCK_S, + BLOCK_OUT, + BLOCK_R, + ) + + return lora_output diff --git a/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py b/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py new file mode 100644 index 000000000000..fec73f7679e2 --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py @@ -0,0 +1,144 @@ +import torch +import triton +import triton.language as tl + +from sglang.srt.lora.lora import LoraBatchInfo + + +@triton.jit +def _sgemm_lora_a_kernel( + # Pointers to matrices + x, + weights, + output, + # Matrix dimensions + N, # r + K, # input_dim + # Strides + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + # Information on sequence lengths and weight id + seg_lens, + seg_indptr, + weight_indices, + # Meta parameters + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + + # x: (s, K), s is the sum of sequence lengths + # weights: (num_lora, N, K) + # output: (s, N) + + # Current block computes sequence with batch_id, + # which starts from row seg_start of x with length seg_len + batch_id = tl.program_id(axis=1) + pid = tl.program_id(axis=0) + seg_len = tl.load(seg_lens + batch_id) + w_index = tl.load(weight_indices + batch_id) + seg_start = tl.load(seg_indptr + batch_id) + + # The tile in output matrix will have (pid_s, pid_n) as id + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + + # Create pointers for the first block of x and weights[batch_id] + # The pointers will be advanced as we move in the K direction + # and accumulate + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + x_ptrs = (x + seg_start * x_stride_0) + ( + s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1 + ) + w_ptrs = (weights + w_index * w_stride_0) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + # Iteate to compute the block in output matrix + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + x_tile = tl.load( + x_ptrs, + mask=(s_offset[:, None] < seg_len) + and (k_offset[None, :] < K - k * BLOCK_K), + other=0.0, + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < N), + other=0.0, + ) + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + # Store result to output matrix + partial_sum = partial_sum.to(x.dtype.element_ty) + output_ptr = (output + seg_start * output_stride_0) + ( + s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 + ) + output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < N) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def sgemm_lora_a_fwd( + x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo +) -> torch.Tensor: + # x: (s, input_dim) + # weights: (num_lora, r, input_dim) + # output: (s, r) + # when called by run_qkv_lora, the weights.shape[-2] will be 3 * r + # input_dim is much larger than r + + assert x.is_contiguous() + assert weights.is_contiguous() + assert len(x.shape) == 2 + assert len(weights.shape) == 3 + + S = x.shape[0] + R = weights.shape[-2] + K = weights.shape[-1] + assert x.shape[-1] == K + + # Block shapes + # FIXME: Add autotune + BLOCK_S = 16 + BLOCK_K = 256 + BLOCK_R = 16 + + grid = ( + triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(R, BLOCK_R), + batch_info.bs, + ) + + output = torch.empty((S, R), device=x.device, dtype=x.dtype) + _sgemm_lora_a_kernel[grid]( + x, + weights, + output, + R, + K, + x.stride(0), + x.stride(1), + weights.stride(0), + weights.stride(1), + weights.stride(2), + output.stride(0), + output.stride(1), + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + BLOCK_S, + BLOCK_R, + BLOCK_K, + ) + return output diff --git a/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py b/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py new file mode 100644 index 000000000000..d2961eb0532f --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py @@ -0,0 +1,142 @@ +import torch +import triton +import triton.language as tl + +from sglang.srt.lora.lora import LoraBatchInfo + + +@triton.jit +def _sgemm_lora_b_kernel( + # Pointers to matrices + x, + weights, + output, + # Matrix dimensions + N, # output_dim + K, # r + # Strides + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + # Information on sequence lengths and weight id + seg_lens, + seg_indptr, + weight_indices, + # Meta parameters + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + # x: (s, K), s is the sum of sequence lengths + # weights: (num_lora, N, K) + # output: (s, N) + + # Current block computes sequence with batch_id, + # which starts from row seg_start of x with length seg_len + batch_id = tl.program_id(axis=1) + pid = tl.program_id(axis=0) + seg_len = tl.load(seg_lens + batch_id) + w_index = tl.load(weight_indices + batch_id) + seg_start = tl.load(seg_indptr + batch_id) + + # The tile in output matrix will have (pid_s, pid_n) as id + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + + # Create pointers for the first block of x and weights[batch_id] + # The pointers will be advanced as we move in the K direction + # and accumulate + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + x_ptrs = (x + seg_start * x_stride_0) + ( + s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1 + ) + w_ptrs = (weights + w_index * w_stride_0) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + # Iteate to compute the block in output matrix + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + x_tile = tl.load( + x_ptrs, + mask=(s_offset[:, None] < seg_len) + and (k_offset[None, :] < K - k * BLOCK_K), + other=0.0, + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < K - k * BLOCK_K), + other=0.0, + ) + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + # Store result to output matrix + partial_sum = partial_sum.to(x.dtype.element_ty) + output_ptr = (output + seg_start * output_stride_0) + ( + s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 + ) + output_mask = s_offset[:, None] < seg_len + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def sgemm_lora_b_fwd( + x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo +) -> torch.Tensor: + # x: (s, r) + # weights: (num_lora, output_dim, r) + # output: (s, output_dim) + # output_dim is much larger than r + + assert x.is_contiguous() + assert weights.is_contiguous() + assert len(x.shape) == 2 + assert len(weights.shape) == 3 + + S = x.shape[0] + N = weights.shape[-2] + R = weights.shape[-1] + assert x.shape[-1] == R + + # Block shapes + # FIXME: Add autotune + BLOCK_S = 16 + BLOCK_R = 16 + BLOCK_N = 256 + + grid = ( + triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(N, BLOCK_N), + batch_info.bs, + ) + + output = torch.empty((S, N), device=x.device, dtype=x.dtype) + _sgemm_lora_b_kernel[grid]( + x, + weights, + output, + N, + R, + x.stride(0), + x.stride(1), + weights.stride(0), + weights.stride(1), + weights.stride(2), + output.stride(0), + output.stride(1), + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + BLOCK_S, + BLOCK_N, + BLOCK_R, + ) + return output