From 455bbbd27e553b01a103b71112aac7dc6a51a5c8 Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Wed, 10 Apr 2024 22:14:45 +0800 Subject: [PATCH] address comments --- .../kernels/benchmark_paged_attention.py | 2 +- tests/kernels/test_attention.py | 6 +- tests/kernels/test_cache.py | 25 ++- vllm/_custom_ops.py | 191 ++++++++++++++++++ vllm/attention/ops/paged_attn.py | 10 +- vllm/model_executor/layers/activation.py | 2 +- .../layers/fused_moe/fused_moe.py | 2 +- vllm/model_executor/layers/layernorm.py | 2 +- .../model_executor/layers/quantization/awq.py | 2 +- .../layers/quantization/gptq.py | 2 +- .../layers/quantization/marlin.py | 2 +- .../layers/quantization/squeezellm.py | 2 +- .../model_executor/layers/rotary_embedding.py | 2 +- vllm/ops.py | 175 ---------------- vllm/utils.py | 4 +- 15 files changed, 222 insertions(+), 207 deletions(-) create mode 100644 vllm/_custom_ops.py delete mode 100644 vllm/ops.py diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index e0d8ac9e9ac11..5c3650fa72d17 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -5,7 +5,7 @@ import torch -from vllm.ops import ops +from vllm import _custom_ops as ops from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random NUM_BLOCKS = 1024 diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 75dee6adaed4b..9b1f3e30b6dca 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -7,7 +7,7 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask -from vllm.ops import cache_ops, ops +from vllm import _custom_ops as ops from vllm.utils import get_max_shared_memory_bytes, is_hip FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 @@ -237,14 +237,14 @@ def test_paged_attention( dequantized_key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device) - cache_ops.convert_fp8(key_cache, dequantized_key_cache) + ops.convert_fp8(key_cache, dequantized_key_cache) key_cache = dequantized_key_cache value_cache_shape = value_cache.shape dequantized_value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) - cache_ops.convert_fp8(value_cache, dequantized_value_cache) + ops.convert_fp8(value_cache, dequantized_value_cache) value_cache = dequantized_value_cache ref_output = torch.empty_like(query) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 98cab3d7a4813..d1051fd7e2f4d 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -4,7 +4,7 @@ import pytest import torch -from vllm.ops import cache_ops +from vllm import _custom_ops as ops from vllm.utils import is_hip COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] @@ -80,7 +80,7 @@ def test_copy_blocks( cloned_value_caches = [value_cache.clone() for value_cache in value_caches] # Call the copy blocks kernel. - cache_ops.copy_blocks(key_caches, value_caches, block_mapping) + ops.copy_blocks(key_caches, value_caches, block_mapping) # Run the reference implementation. for src, dsts in block_mapping.items(): @@ -145,9 +145,9 @@ def test_reshape_and_cache( # Clone the KV caches. 
if kv_cache_dtype == "fp8": cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - cache_ops.convert_fp8(key_cache, cloned_key_cache) + ops.convert_fp8(key_cache, cloned_key_cache) cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - cache_ops.convert_fp8(value_cache, cloned_value_cache) + ops.convert_fp8(value_cache, cloned_value_cache) else: cloned_key_cache = key_cache.clone() cloned_value_cache = value_cache.clone() @@ -156,14 +156,14 @@ def test_reshape_and_cache( kv_scale = 1.0 # Call the reshape_and_cache kernel. - cache_ops.reshape_and_cache(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, kv_scale) + ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, + kv_cache_dtype, kv_scale) if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - cache_ops.convert_fp8(key_cache, result_key_cache) + ops.convert_fp8(key_cache, result_key_cache) result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - cache_ops.convert_fp8(value_cache, result_value_cache) + ops.convert_fp8(value_cache, result_value_cache) # Run the reference implementation. reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) @@ -251,9 +251,8 @@ def test_swap_blocks( src_value_caches_clone = src_value_caches[0].clone() # Call the swap_blocks kernel. - cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping) - cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0], - block_mapping) + ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping) + ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping) for src, dst in block_mapping.items(): assert torch.allclose(src_key_caches_clone[src].cpu(), @@ -291,9 +290,9 @@ def test_fp8_conversion( cache.uniform_(low, high) cache_fp8 = torch.empty_like(cache, dtype=torch.uint8) - cache_ops.convert_fp8(cache, cache_fp8) + ops.convert_fp8(cache, cache_fp8) converted_cache = torch.empty_like(cache) - cache_ops.convert_fp8(cache_fp8, converted_cache) + ops.convert_fp8(cache_fp8, converted_cache) assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py new file mode 100644 index 0000000000000..b26e2d4369aba --- /dev/null +++ b/vllm/_custom_ops.py @@ -0,0 +1,191 @@ +from typing import Dict, Optional + +import torch + +try: + from vllm._C import cache_ops as vllm_cache_ops + from vllm._C import ops as vllm_ops +except ImportError: + pass + + +# activation ops +def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.silu_and_mul(out, x) + + +def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.gelu_and_mul(out, x) + + +def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.gelu_tanh_and_mul(out, x) + + +def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.gelu_fast(out, x) + + +def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.gelu_new(out, x) + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_size: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + kv_scale: float, +) -> None: + vllm_ops.paged_attention_v1(out, query, key_cache, value_cache, + num_kv_heads, scale, block_tables, + context_lens, 
block_size, max_context_len, + alibi_slopes, kv_cache_dtype, kv_scale) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_size: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + kv_scale: float, +) -> None: + vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, + key_cache, value_cache, num_kv_heads, scale, + block_tables, context_lens, block_size, + max_context_len, alibi_slopes, kv_cache_dtype, + kv_scale) + + +# pos encoding ops +def rotary_embedding( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool, +) -> None: + vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache, + is_neox) + + +def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, + key: torch.Tensor, head_size: int, + cos_sin_cache: torch.Tensor, is_neox: bool, + rot_dim: int, + cos_sin_cache_offsets: torch.Tensor) -> None: + vllm_ops.batched_rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox, rot_dim, + cos_sin_cache_offsets) + + +# layer norm ops +def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + epsilon: float) -> None: + vllm_ops.rms_norm(out, input, weight, epsilon) + + +def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, epsilon: float) -> None: + vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon) + + +# quantization ops +# awq +def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, + zeros: torch.Tensor, split_k_iters: int, thx: int, + thy: int) -> None: + vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy) + + +def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, + scales: torch.Tensor, split_k_iters: int) -> None: + vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters) + + +# gptq +def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, + b_g_idx: torch.Tensor, use_exllama: bool, bit: int) -> None: + vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + use_exllama, bit) + + +def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, + bit: int) -> None: + vllm_ops.gptq_shuffle(q_weight, q_perm, bit) + + +# squeezellm +def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor, + lookup_table: torch.Tensor) -> None: + vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table) + + +# marlin +def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, + size_n: int, size_k: int) -> torch.Tensor: + return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, + size_n, size_k) + + +# moe +def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, + block_size: int, sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor) -> None: + vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size, + sorted_token_ids, experts_ids, + num_tokens_post_pad) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: 
torch.Tensor, + kv_cache_dtype: str, + kv_scale: float, +) -> None: + vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype, kv_scale) + + +def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, + block_mapping: torch.Tensor) -> None: + vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks(src: torch.Tensor, dst: torch.Tensor, + block_mapping: Dict[int, int]) -> None: + vllm_cache_ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None: + vllm_cache_ops.convert_fp8(output, input) + + +#TODO: cuda_utils, custom_ar diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 9c74e53853ac5..080cd69209206 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -3,8 +3,8 @@ import torch +from vllm import _custom_ops as ops from vllm.attention.ops.prefix_prefill import context_attention_fwd -from vllm.ops import cache_ops, ops # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. _PARTITION_SIZE = 512 @@ -75,7 +75,7 @@ def write_to_paged_cache( kv_cache_dtype: str, kv_scale: float, ) -> None: - cache_ops.reshape_and_cache( + ops.reshape_and_cache( key, value, key_cache, @@ -205,11 +205,11 @@ def swap_blocks( ) -> None: src_key_cache = src_kv_cache[0] dst_key_cache = dst_kv_cache[0] - cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) src_value_cache = src_kv_cache[1] dst_value_cache = dst_kv_cache[1] - cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) @staticmethod def copy_blocks( @@ -218,4 +218,4 @@ def copy_blocks( ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches] - cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) + ops.copy_blocks(key_caches, value_caches, src_to_dists) diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 6ac0ca787d3ff..baf1d4f266181 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -6,11 +6,11 @@ import torch.nn as nn import torch.nn.functional as F +from vllm import _custom_ops as ops from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.utils import set_weight_attrs -from vllm.ops import ops class SiluAndMul(nn.Module): diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index fd47df7e9cb27..377b6588dbf47 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -8,8 +8,8 @@ import triton import triton.language as tl +from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.ops import ops from vllm.utils import is_hip logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index fb61f5de30c96..a6619714b8aab 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from vllm.ops import ops +from vllm import _custom_ops as ops class RMSNorm(nn.Module): diff 
--git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index b974046ea7e06..daea5ac73e429 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -3,11 +3,11 @@ import torch from torch.nn.parameter import Parameter +from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.ops import ops class AWQConfig(QuantizationConfig): diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index ed982589d1a2e..757ab1af8392e 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -6,11 +6,11 @@ import torch from torch.nn.parameter import Parameter +from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.ops import ops class GPTQConfig(QuantizationConfig): diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 6fa532c4f7c31..a6482c059cc41 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -3,11 +3,11 @@ import torch from torch.nn.parameter import Parameter +from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.ops import ops class MarlinConfig(QuantizationConfig): diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 1fbf5f9cdb305..bb295df2acc3f 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -3,11 +3,11 @@ import torch from torch.nn.parameter import Parameter +from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.ops import ops from vllm.utils import is_hip diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 476d495746ccc..eb8d5f6dfb2a9 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -27,7 +27,7 @@ import torch import torch.nn as nn -from vllm.ops import ops +from vllm import _custom_ops as ops def _rotate_neox(x: torch.Tensor) -> torch.Tensor: diff --git a/vllm/ops.py b/vllm/ops.py deleted file mode 100644 index 93eb7911cb60d..0000000000000 --- a/vllm/ops.py +++ /dev/null @@ -1,175 +0,0 @@ -from typing import Dict, Optional - -import torch - -try: - from vllm._C import cache_ops as vllm_cache_ops - from vllm._C import ops as vllm_ops -except ImportError: - pass - - -class ops: - - # activation ops - def silu_and_mul(out: torch.Tensor, x: torch.Tensor): - vllm_ops.silu_and_mul(out, x) - - def gelu_and_mul(out: torch.Tensor, x: torch.Tensor): - vllm_ops.gelu_and_mul(out, x) - - def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor): - vllm_ops.gelu_tanh_and_mul(out, x) - - def gelu_fast(out: torch.Tensor, x: torch.Tensor): - 
vllm_ops.gelu_fast(out, x) - - def gelu_new(out: torch.Tensor, x: torch.Tensor): - vllm_ops.gelu_new(out, x) - - # page attention ops - def paged_attention_v1( - out: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - num_kv_heads: int, - scale: float, - block_tables: torch.Tensor, - context_lens: torch.Tensor, - block_size: int, - max_context_len: int, - alibi_slopes: Optional[torch.Tensor], - kv_cache_dtype: str, - kv_scale: float, - ): - vllm_ops.paged_attention_v1(out, query, key_cache, value_cache, - num_kv_heads, scale, block_tables, - context_lens, block_size, max_context_len, - alibi_slopes, kv_cache_dtype, kv_scale) - - def paged_attention_v2( - out: torch.Tensor, - exp_sum: torch.Tensor, - max_logits: torch.Tensor, - tmp_out: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - num_kv_heads: int, - scale: float, - block_tables: torch.Tensor, - context_lens: torch.Tensor, - block_size: int, - max_context_len: int, - alibi_slopes: Optional[torch.Tensor], - kv_cache_dtype: str, - kv_scale: float, - ): - vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, - key_cache, value_cache, num_kv_heads, - scale, block_tables, context_lens, - block_size, max_context_len, alibi_slopes, - kv_cache_dtype, kv_scale) - - # pos encoding ops - def rotary_embedding( - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - head_size: int, - cos_sin_cache: torch.Tensor, - is_neox: bool, - ): - vllm_ops.rotary_embedding(positions, query, key, head_size, - cos_sin_cache, is_neox) - - def batched_rotary_embedding(positions: torch.tensor, query: torch.tensor, - key: torch.tensor, head_size: int, - cos_sin_cache: torch.tensor, is_neox: bool, - rot_dim: int, - cos_sin_cache_offsets: torch.tensor): - vllm_ops.batched_rotary_embedding(positions, query, key, head_size, - cos_sin_cache, is_neox, rot_dim, - cos_sin_cache_offsets) - - # layer norm ops - def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - epsilon: float): - vllm_ops.rms_norm(out, input, weight, epsilon) - - def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, - weight: torch.Tensor, epsilon: float): - vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon) - - # quantization ops - # awq - def awq_dequantize(qweight: torch.tensor, scales: torch.tensor, - zeros: torch.tensor, split_k_iters, thx, thy): - vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, - thy) - - def awq_gemm(input: torch.tensor, qweight: torch.tensor, - qzeros: torch.tensor, scales: torch.tensor, - split_k_iters: int): - vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters) - - # gptq - def gptq_gemm(a: torch.tensor, b_q_weight: torch.tensor, - b_gptq_qzeros: torch.tensor, b_gptq_scales: torch.tensor, - b_g_idx: torch.tensor, use_exllama: bool, bit: int): - vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, - b_g_idx, use_exllama, bit) - - def gptq_shuffle(q_weight: torch.tensor, q_perm: torch.Tensor, bit: int): - vllm_ops.gptq_shuffle(q_weight, q_perm, bit) - - # squeezellm - def squeezellm_gemm(vec: torch.tensor, mat: torch.tensor, - mul: torch.tensor, lookup_table: torch.tensor): - vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table) - - # marlin - def marlin_gemm(a: torch.tensor, b_q_weight: torch.tensor, - b_scales: torch.tensor, workspace: torch.tensor, - size_m: int, size_n: int, size_k: int): - return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, - 
size_n, size_k) - - # moe - def moe_align_block_size(topk_ids: torch.tensor, num_experts: int, - block_size: int, sorted_token_ids: torch.tensor, - experts_ids: torch.tensor, - num_tokens_post_pad: torch.tensor): - vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size, - sorted_token_ids, experts_ids, - num_tokens_post_pad) - - -class cache_ops: - - def reshape_and_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - kv_scale: float, - ) -> None: - vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, - kv_scale) - - def copy_blocks(key_caches, value_caches, block_mapping): - vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) - - def swap_blocks(src: torch.Tensor, dst: torch.Tensor, - block_mapping: Dict[int, int]): - vllm_cache_ops.swap_blocks(src, dst, block_mapping) - - def convert_fp8(output: torch.tensor, input: torch.tensor): - vllm_cache_ops.convert_fp8(output, input) - - -#TODO: cuda_utils, custom_ar diff --git a/vllm/utils.py b/vllm/utils.py index 1427995e367bb..8ab8927512cc9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -279,10 +279,10 @@ def _generate_random_fp8( #-----|-------------|------------------- # Inf | N/A | s.11111.00 # NaN | s.1111.111 | s.11111.{01,10,11} - from vllm.ops import cache_ops + from vllm import _custom_ops as ops tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) tensor_tmp.uniform_(low, high) - cache_ops.convert_fp8(tensor_tmp, tensor) + ops.convert_fp8(tensor_tmp, tensor) del tensor_tmp
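
Usage sketch (not part of the diff above; a minimal illustration assuming vLLM is built with its compiled vllm._C extension and that a CUDA device is available — tensor shapes are illustrative only): after this refactor, call sites import the flat module via "from vllm import _custom_ops as ops" in place of the removed "ops"/"cache_ops" class namespaces from vllm/ops.py.

    import torch

    from vllm import _custom_ops as ops

    # silu_and_mul expects the gate and up projections concatenated along the
    # last dimension: x has shape [..., 2 * d] and out has shape [..., d].
    x = torch.randn(4, 16, dtype=torch.float16, device="cuda")
    out = torch.empty(4, 8, dtype=torch.float16, device="cuda")
    ops.silu_and_mul(out, x)  # out = silu(x[..., :8]) * x[..., 8:]

Because _custom_ops.py wraps its "from vllm._C import ..." in a try/except ImportError, the module still imports cleanly when the compiled extension is missing; the wrappers only fail (with a NameError on vllm_ops / vllm_cache_ops) once they are actually called.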