Commit

address comments
jikunshang committed Apr 11, 2024
1 parent 5b4775a commit 455bbbd
Showing 15 changed files with 222 additions and 207 deletions.
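
The common thread of the commit: call sites stop importing the C-extension bindings (vllm.ops, vllm._C) directly and instead go through the new pure-Python facade vllm._custom_ops added below. As a rough illustration of the new convention, a minimal sketch (assuming a CUDA build of vLLM with the _C extension compiled; the shapes are arbitrary):

import torch

from vllm import _custom_ops as ops

# silu_and_mul writes SiLU(x[..., :d]) * x[..., d:] into out:
# x has shape (num_tokens, 2 * d), out has shape (num_tokens, d).
x = torch.randn(4, 2 * 128, dtype=torch.float16, device="cuda")
out = torch.empty(4, 128, dtype=torch.float16, device="cuda")
ops.silu_and_mul(out, x)
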
2 changes: 1 addition & 1 deletion benchmarks/kernels/benchmark_paged_attention.py
@@ -5,7 +5,7 @@

import torch

from vllm.ops import ops
from vllm import _custom_ops as ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random

NUM_BLOCKS = 1024
6 changes: 3 additions & 3 deletions tests/kernels/test_attention.py
@@ -7,7 +7,7 @@
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

from vllm.ops import cache_ops, ops
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
@@ -237,14 +237,14 @@ def test_paged_attention(
dequantized_key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device=device)
cache_ops.convert_fp8(key_cache, dequantized_key_cache)
ops.convert_fp8(key_cache, dequantized_key_cache)
key_cache = dequantized_key_cache

value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device=device)
cache_ops.convert_fp8(value_cache, dequantized_value_cache)
ops.convert_fp8(value_cache, dequantized_value_cache)
value_cache = dequantized_value_cache

ref_output = torch.empty_like(query)
25 changes: 12 additions & 13 deletions tests/kernels/test_cache.py
@@ -4,7 +4,7 @@
import pytest
import torch

from vllm.ops import cache_ops
from vllm import _custom_ops as ops
from vllm.utils import is_hip

COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
@@ -80,7 +80,7 @@ def test_copy_blocks(
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

# Call the copy blocks kernel.
cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
ops.copy_blocks(key_caches, value_caches, block_mapping)

# Run the reference implementation.
for src, dsts in block_mapping.items():
@@ -145,9 +145,9 @@ def test_reshape_and_cache(
# Clone the KV caches.
if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
cache_ops.convert_fp8(key_cache, cloned_key_cache)
ops.convert_fp8(key_cache, cloned_key_cache)
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
cache_ops.convert_fp8(value_cache, cloned_value_cache)
ops.convert_fp8(value_cache, cloned_value_cache)
else:
cloned_key_cache = key_cache.clone()
cloned_value_cache = value_cache.clone()
@@ -156,14 +156,14 @@
kv_scale = 1.0

# Call the reshape_and_cache kernel.
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, kv_scale)
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
kv_cache_dtype, kv_scale)

if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
cache_ops.convert_fp8(key_cache, result_key_cache)
ops.convert_fp8(key_cache, result_key_cache)
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
cache_ops.convert_fp8(value_cache, result_value_cache)
ops.convert_fp8(value_cache, result_value_cache)

# Run the reference implementation.
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
@@ -251,9 +251,8 @@ def test_swap_blocks(
src_value_caches_clone = src_value_caches[0].clone()

# Call the swap_blocks kernel.
cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
block_mapping)
ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping)

for src, dst in block_mapping.items():
assert torch.allclose(src_key_caches_clone[src].cpu(),
@@ -291,9 +290,9 @@ def test_fp8_conversion(
cache.uniform_(low, high)

cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
cache_ops.convert_fp8(cache, cache_fp8)
ops.convert_fp8(cache, cache_fp8)

converted_cache = torch.empty_like(cache)
cache_ops.convert_fp8(cache_fp8, converted_cache)
ops.convert_fp8(cache_fp8, converted_cache)

assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)
191 changes: 191 additions & 0 deletions vllm/_custom_ops.py
@@ -0,0 +1,191 @@
from typing import Dict, Optional

import torch

try:
from vllm._C import cache_ops as vllm_cache_ops
from vllm._C import ops as vllm_ops
except ImportError:
pass


# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.silu_and_mul(out, x)


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_and_mul(out, x)


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_tanh_and_mul(out, x)


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_fast(out, x)


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_new(out, x)


# page attention ops
def paged_attention_v1(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
block_size: int,
max_context_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
) -> None:
vllm_ops.paged_attention_v1(out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables,
context_lens, block_size, max_context_len,
alibi_slopes, kv_cache_dtype, kv_scale)


def paged_attention_v2(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
block_size: int,
max_context_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
) -> None:
vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads, scale,
block_tables, context_lens, block_size,
max_context_len, alibi_slopes, kv_cache_dtype,
kv_scale)


# pos encoding ops
def rotary_embedding(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
is_neox)


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, head_size: int,
cos_sin_cache: torch.Tensor, is_neox: bool,
rot_dim: int,
cos_sin_cache_offsets: torch.Tensor) -> None:
vllm_ops.batched_rotary_embedding(positions, query, key, head_size,
cos_sin_cache, is_neox, rot_dim,
cos_sin_cache_offsets)


# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None:
vllm_ops.rms_norm(out, input, weight, epsilon)


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon)


# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> None:
vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy)


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
scales: torch.Tensor, split_k_iters: int) -> None:
vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor, use_exllama: bool, bit: int) -> None:
vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
use_exllama, bit)


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None:
vllm_ops.gptq_shuffle(q_weight, q_perm, bit)


# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
lookup_table: torch.Tensor) -> None:
vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table)


# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
size_n: int, size_k: int) -> torch.Tensor:
return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
size_n, size_k)


# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor) -> None:
vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size,
sorted_token_ids, experts_ids,
num_tokens_post_pad)


def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
kv_scale: float,
) -> None:
vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, kv_scale)


def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
block_mapping: torch.Tensor) -> None:
vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)


def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
block_mapping: Dict[int, int]) -> None:
vllm_cache_ops.swap_blocks(src, dst, block_mapping)


def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None:
vllm_cache_ops.convert_fp8(output, input)


#TODO: cuda_utils, custom_ar
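
For an end-to-end use of the merged facade, the fp8 cache round-trip exercised in tests/kernels/test_cache.py above can be reproduced roughly as follows (a sketch, not the test itself; it assumes a CUDA device, a built vllm._C extension, and an arbitrary 16x128 cache shape):

import torch

from vllm import _custom_ops as ops

cache = torch.empty(16, 128, dtype=torch.float16, device="cuda")
cache.uniform_(-1.0, 1.0)

# fp16 -> fp8 (stored in a uint8 tensor), then back to fp16.
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
ops.convert_fp8(cache, cache_fp8)

restored = torch.empty_like(cache)
ops.convert_fp8(cache_fp8, restored)

# The conversion is lossy, so compare with loose tolerances.
assert torch.allclose(cache, restored, atol=0.001, rtol=0.1)
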
10 changes: 5 additions & 5 deletions vllm/attention/ops/paged_attn.py
@@ -3,8 +3,8 @@

import torch

from vllm import _custom_ops as ops
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.ops import cache_ops, ops

# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
@@ -75,7 +75,7 @@ def write_to_paged_cache(
kv_cache_dtype: str,
kv_scale: float,
) -> None:
cache_ops.reshape_and_cache(
ops.reshape_and_cache(
key,
value,
key_cache,
@@ -205,11 +205,11 @@ def swap_blocks(
) -> None:
src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0]
cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)

src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1]
cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)

@staticmethod
def copy_blocks(
@@ -218,4 +218,4 @@ def copy_blocks(
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
ops.copy_blocks(key_caches, value_caches, src_to_dists)
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/activation.py
@@ -6,11 +6,11 @@
import torch.nn as nn
import torch.nn.functional as F

from vllm import _custom_ops as ops
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
from vllm.ops import ops


class SiluAndMul(nn.Module):
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -8,8 +8,8 @@
import triton
import triton.language as tl

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.ops import ops
from vllm.utils import is_hip

logger = init_logger(__name__)
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/layernorm.py
@@ -4,7 +4,7 @@
import torch
import torch.nn as nn

from vllm.ops import ops
from vllm import _custom_ops as ops


class RMSNorm(nn.Module):
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/awq.py
@@ -3,11 +3,11 @@
import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.ops import ops


class AWQConfig(QuantizationConfig):
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/gptq.py
@@ -6,11 +6,11 @@
import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.ops import ops


class GPTQConfig(QuantizationConfig):
(Diffs for the remaining changed files are not shown.)
