[Neuron] add custom_ops for neuron backend
Co-authored-by: George Novack <[email protected]>
Co-authored-by: Aoyu Zhang <[email protected]>
Signed-off-by: Liangfu Chen <[email protected]>
3 people committed Feb 17, 2025
1 parent d84cef7 commit ef47583
Showing 9 changed files with 324 additions and 10 deletions.
48 changes: 48 additions & 0 deletions tests/neuron/test_activation.py
@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.activation import FastGELU, SiluAndMul
from vllm.platforms import current_platform


@pytest.mark.parametrize("activation", ["silu_and_mul", "gelu_fast"])
@pytest.mark.parametrize("num_tokens,d,dtype", [
(7, 512, torch.half),
(7, 512, torch.float),
(83, 512, torch.half),
])
@torch.inference_mode()
def test_act_and_mul(
activation: str,
num_tokens: int,
d: int,
dtype: torch.dtype,
) -> None:
import torch_xla.core.xla_model as xm

device = xm.xla_device()
current_platform.seed_everything(0)
torch.set_default_device("cpu")
x = torch.randn(num_tokens, 2 * d, dtype=dtype).to(device=device)
if activation == "silu_and_mul":
layer = SiluAndMul()

def _silu_and_mul(x: torch.Tensor) -> torch.Tensor:
assert x.is_cpu, "reference input is expected to be executed on CPU."
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]

fn = _silu_and_mul
elif activation == "gelu_fast":
layer = FastGELU()
fn = F.gelu
else:
raise NotImplementedError(
f"activation {activation} is not implemented.")
assert x.is_xla, "input tensor under test is expected to be an XLA tensor."
out = layer.to(device=device).forward_neuron(x)
ref_out = fn(x.cpu())
torch.testing.assert_close(out.cpu(), ref_out, atol=0.01, rtol=0.0)
56 changes: 56 additions & 0 deletions tests/neuron/test_layernorm.py
@@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform


@pytest.mark.parametrize("num_tokens,hidden_size,add_residual,dtype", [
(7, 8, False, torch.half),
(83, 768, False, torch.half),
(83, 768, True, torch.half),
(83, 768, True, torch.bfloat16),
(83, 768, True, torch.float32),
])
@torch.inference_mode()
def test_rms_norm(
num_tokens: int,
hidden_size: int,
add_residual: bool,
dtype: torch.dtype,
) -> None:
import torch_xla.core.xla_model as xm

device = xm.xla_device()
current_platform.seed_everything(0)
torch.set_default_device("cpu")
layer = RMSNorm(hidden_size).to(dtype=dtype)
layer.weight.data.normal_(mean=1.0, std=0.1)
scale = 1 / (2 * hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype).to(device=device)
x *= scale
residual = torch.randn_like(x) * scale if add_residual else None

residual_cpu = residual.cpu() if add_residual else None
ref_out = layer.to(device="cpu").forward_native(x.cpu(), residual_cpu)
assert x.is_xla, "input tensor under test is expected to be an XLA tensor."
out = layer.to(device=device)(x, residual)

# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
# numerical errors than other operators because they involve reductions.
# Therefore, we use a larger tolerance.
if add_residual:
assert out[0].is_xla, "output tensor is expected to be an XLA tensor"
torch.testing.assert_close(out[0].cpu(),
ref_out[0],
atol=1e-2,
rtol=1e-2)
torch.testing.assert_close(out[1].cpu(),
ref_out[1],
atol=1e-2,
rtol=1e-2)
else:
assert out.is_xla, "output tensor is expected to be an XLA tensor"
torch.testing.assert_close(out.cpu(), ref_out, atol=1e-2, rtol=1e-2)
95 changes: 95 additions & 0 deletions tests/neuron/test_logits_processor.py
@@ -0,0 +1,95 @@
# SPDX-License-Identifier: Apache-2.0

import random
from typing import Tuple
from unittest.mock import patch

import pytest
import torch

from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available


class MockLogitsProcessor(LogitsProcessor):

def __init__(self, vocab_size: int, scale: float,
fake_logits: torch.Tensor):
super().__init__(vocab_size=vocab_size, scale=scale)
self.fake_logits = fake_logits.clone()

def forward(self, *args, **kwargs):
with patch(
"vllm.model_executor.layers.logits_processor._prune_hidden_states",
lambda x, y: x
), patch(
"vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits",
lambda *args, **kwargs: self.fake_logits):
return super().forward(*args, **kwargs)


def _prepare_test(
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]:
vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
fake_logits = torch.full((batch_size, vocab_size),
1e-2,
dtype=input_tensor.dtype)
logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
return input_tensor, fake_logits, logits_processor


RANDOM_SEEDS = list(range(8))


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_logits_processors(seed: int):
import torch_xla.core.xla_model as xm

device = xm.xla_device()
set_random_seed(seed)
torch.set_default_device("cpu")
batch_size = random.randint(1, 256)
input_tensor, fake_logits, logits_processor = _prepare_test(batch_size)

# This sample logits processor gives infinite score to the i-th token,
# where i is the length of the input sequence.
# We therefore expect the output token sequence to be [0, 1, 2, ...]
def pick_ith(token_ids, logits):
logits[len(token_ids)] = float("inf")
return logits

seq_group_metadata_list = []
seq_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
sampling_params=SamplingParams(temperature=0,
logits_processors=[pick_ith]),
block_tables={0: [1]},
))
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=is_pin_memory_available())
logits_processor_output = logits_processor(
lm_head=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)

fake_logits *= logits_processor.scale
torch.testing.assert_close(logits_processor_output[:, 1],
fake_logits[:, 1],
rtol=1e-4,
atol=0.0)
17 changes: 7 additions & 10 deletions tests/neuron/test_prefix_prefill.py
@@ -207,16 +207,13 @@ def test_contexted_kv_attention(

device = xm.xla_device()

compiler_flags = [
"--model-type=transformer -O1",
"--internal-hlo2tensorizer-options='--verify-hlo'",
"--retry_failed_compilation",
]
compiler_flags_str = " ".join(compiler_flags)
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str
os.environ["NEURON_CC_FLAGS"] = (
" --model-type=transformer -O1 --retry_failed_compilation "
" --internal-hlo2tensorizer-options='--verify-hlo' ")

torch.manual_seed(0)
torch.set_printoptions(sci_mode=False)
torch.set_default_device("cpu")

min_ctx_len = 32
max_ctx_len = 1024
@@ -394,9 +391,9 @@ def pad_to_next_power_of_2(a):

# transform block table
active_block_table = get_active_block_tables(
block_table,
torch.tensor(query_lens),
torch.tensor(seq_lens),
block_table.cpu(),
torch.tensor(query_lens).cpu(),
torch.tensor(seq_lens).cpu(),
block_size,
num_active_blocks,
)
56 changes: 56 additions & 0 deletions tests/neuron/test_rotary_embedding.py
@@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0
"""
Tests for miscellaneous utilities
"""

import pytest
import torch

from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform


@pytest.mark.parametrize(
"max_position,is_neox_style,rotary_dim,head_size,seq_len", [
(11, False, 32, 32, 1024),
])
def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
head_size, seq_len):
import torch_xla.core.xla_model as xm

device = xm.xla_device()
current_platform.seed_everything(0)
torch.set_default_device("cpu")

batch_size = 1
base = 10000
num_heads = 7

rot = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, torch.float32)

positions = torch.randint(0,
max_position, (batch_size, seq_len),
device="cpu")
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=torch.float32,
device="cpu")
key = torch.randn_like(query)

# rotary_embedding_opcheck(rot, positions, query, key)
assert positions.is_cpu, \
"reference input tensor is expected to be a CPU tensor."
ref_query, ref_key = rot.to(device="cpu").forward_native(
positions, query, key)
out_query, out_key = rot.to(device=device).forward_neuron(
positions.to(device=device), query.to(device=device),
key.to(device=device))
assert out_query.is_xla and out_key.is_xla, \
"output tensors are expected to be XLA tensors"
torch.testing.assert_close(out_query.cpu(),
ref_query,
atol=1e-2,
rtol=1e-2)
torch.testing.assert_close(out_key.cpu(), ref_key, atol=1e-2, rtol=1e-2)
7 changes: 7 additions & 0 deletions vllm/model_executor/custom_op.py
@@ -59,6 +59,11 @@ def forward_hpu(self, *args, **kwargs):
# PyTorch-native implementation.
return self.forward_native(*args, **kwargs)

def forward_neuron(self, *args, **kwargs):
# By default, we assume that Neuron ops are compatible with the
# PyTorch-native implementation.
return self.forward_native(*args, **kwargs)

def forward_oot(self, *args, **kwargs):
# By default, we assume that OOT ops are compatible with the
# PyTorch-native implementation.
@@ -88,6 +93,8 @@ def dispatch_forward(self):
return self.forward_tpu
elif current_platform.is_xpu():
return self.forward_xpu
elif current_platform.is_neuron():
return self.forward_neuron
elif current_platform.is_out_of_tree():
return self.forward_oot
else:
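Note: with the hunk above, every CustomOp subclass picks up a Neuron path automatically — dispatch_forward() resolves to forward_neuron on Neuron platforms, which in turn falls back to the PyTorch-native code unless a subclass overrides it. A minimal sketch of the resulting behavior (MyScale is a hypothetical op used only for illustration, not part of this commit):

# Hypothetical example: an op with no Neuron-specific kernel still runs on
# Neuron, because forward_neuron defaults to forward_native and
# dispatch_forward() selects forward_neuron when current_platform.is_neuron().
import torch

from vllm.model_executor.custom_op import CustomOp


@CustomOp.register("my_scale")
class MyScale(CustomOp):

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch implementation; torch-xla traces and compiles this
        # when the input lives on an XLA/Neuron device.
        return 2 * x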
13 changes: 13 additions & 0 deletions vllm/model_executor/layers/activation.py
@@ -89,6 +89,19 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
self.op(out, x)
return out

def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
# TODO(gnovack): clean this up
d = x.shape[-1] // 2
if len(x.shape) == 3:
s = x[:, :, :d] * torch.nn.functional.sigmoid(x[:, :, :d])
return s * x[:, :, d:]
elif len(x.shape) == 2:
s = x[:, :d] * torch.nn.functional.sigmoid(x[:, :d])
return s * x[:, d:]
else:
raise NotImplementedError(
"Expected input to have either 3 or 2 dims")


@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
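Note: a minimal usage sketch for the new SiluAndMul.forward_neuron path, mirroring tests/neuron/test_activation.py above (assumes torch_xla is installed and a Neuron/XLA device is available; shapes are illustrative):

import torch
import torch_xla.core.xla_model as xm

from vllm.model_executor.layers.activation import SiluAndMul

device = xm.xla_device()
layer = SiluAndMul().to(device=device)

# Last dimension is 2 * d: silu of the first half is multiplied
# elementwise with the second half, giving an output of size d.
x = torch.randn(2, 7, 2 * 512, dtype=torch.float32).to(device=device)
out = layer.forward_neuron(x)
assert out.shape == (2, 7, 512)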
1 change: 1 addition & 0 deletions vllm/model_executor/layers/logits_processor.py
@@ -53,6 +53,7 @@ def __init__(self,
# Whether to use gather or all-gather to gather the logits.
parallel_config = get_current_vllm_config().parallel_config
self.use_all_gather = current_platform.is_tpu() \
or current_platform.is_neuron() \
or envs.VLLM_USE_V1 \
or parallel_config.distributed_executor_backend == "external_launcher" # noqa

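Note: the one-line change above routes Neuron through the all-gather path when collecting logits across tensor-parallel ranks. A simplified sketch of the distinction the flag controls — roughly what LogitsProcessor._get_logits does with it; the actual call site may differ in detail:

import torch

from vllm.distributed import (tensor_model_parallel_all_gather,
                              tensor_model_parallel_gather)


def gather_logits(logits: torch.Tensor, use_all_gather: bool) -> torch.Tensor:
    if use_all_gather:
        # Neuron/TPU (and V1): every rank keeps a full copy of the logits.
        return tensor_model_parallel_all_gather(logits)
    # Default path: only the destination rank receives the gathered logits.
    return tensor_model_parallel_gather(logits)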
41 changes: 41 additions & 0 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -129,6 +129,8 @@ def forward_native(
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-native implementation of forward()."""
if offsets is not None:
positions = positions + offsets
positions = positions.flatten()
@@ -253,6 +255,45 @@ def forward_hpu(
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key

def forward_neuron(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# TODO(gnovack): handle edge cases
if offsets is not None:
positions = positions + offsets

self.cos_sin_cache = self.cos_sin_cache.to(query.device,
dtype=query.dtype)

positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)

query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)

if self.rotary_dim == self.head_size:
query = _apply_rotary_emb(query, cos, sin, self.is_neox_style)
query = query.reshape(query_shape)

key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)

if self.rotary_dim == self.head_size:
key = _apply_rotary_emb(key, cos, sin, self.is_neox_style)
key = key.reshape(key_shape)
else:
key_pass = key[..., self.rotary_dim:]
key_rot = key[..., :self.rotary_dim]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key

def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}"
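Note: a minimal usage sketch for the new RotaryEmbedding.forward_neuron path, mirroring tests/neuron/test_rotary_embedding.py above (assumes torch_xla and a Neuron/XLA device; all sizes are illustrative):

import torch
import torch_xla.core.xla_model as xm

from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

device = xm.xla_device()
head_size, rotary_dim, max_position, base = 32, 32, 2048, 10000
num_heads, seq_len = 7, 16

rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                       True, torch.float32).to(device=device)

positions = torch.randint(0, max_position, (1, seq_len)).to(device=device)
query = torch.randn(1, seq_len, num_heads * head_size).to(device=device)
key = torch.randn_like(query)

# Applies rotary position embeddings to query and key on the XLA device.
out_query, out_key = rope.forward_neuron(positions, query, key)
assert out_query.is_xla and out_key.is_xla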
