Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support fused silu mul #427

Merged
merged 2 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions include/flashinfer/activation.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef FLASHINFER_ACTIVATION_CUH_
#define FLASHINFER_ACTIVATION_CUH_

#include "utils.cuh"
#include "vec_dtypes.cuh"

namespace flashinfer {

namespace activation {

// https://github.com/NVIDIA/FasterTransformer/blob/d21dc02bc5f70bc7dc0d18ba5801ae263565e68e/src/fastertransformer/kernels/activation_kernels.cu#L126-L133
__device__ __forceinline__ float silu_kernel(const float& val) {
// NOTE(Zihao): use __fdividef might be faster, at the cost of precision
return val / (1.0f + __expf(-val));
}

template <typename T, float (*Activation)(const float&)>
__global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
constexpr uint32_t vec_size = 16 / sizeof(T);
const int64_t token_idx = blockIdx.x;
const int64_t thread_idx = threadIdx.x;
const int64_t stride = blockDim.x;
const int64_t offset = token_idx * 2 * d;

#pragma unroll 1
for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
vec_t<float, vec_size> x_vec, y_vec, out_vec;
x_vec.cast_load(input + offset + idx * vec_size);
y_vec.cast_load(input + offset + d + idx * vec_size);
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
out_vec[i] = Activation(x_vec[i]) * y_vec[i];
}
out_vec.cast_store(out + token_idx * d + idx * vec_size);
}

const int64_t remaining_offset = d - d % (stride * vec_size);
// process the remaining elements
#pragma unroll 1
for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) {
float x = input[offset + remaining_offset + idx],
y = input[offset + remaining_offset + d + idx];
out[token_idx * d + remaining_offset + idx] = Activation(x) * y;
}
}

} // namespace activation
} // namespace flashinfer

#endif // FLASHINFER_ACTIVATION_CUH_
42 changes: 42 additions & 0 deletions python/csrc/activation.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <flashinfer/activation.cuh>

#include "flashinfer_ops.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
int d = input.size(-1) / 2;
int64_t num_tokens = input.numel() / input.size(-1);
dim3 grid(num_tokens);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
uint32_t vec_size = 16 / sizeof(c_type);
dim3 block(std::min(d / vec_size, 1024U));
flashinfer::activation::act_and_mul_kernel<c_type, flashinfer::activation::silu_kernel>
<<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()),
static_cast<c_type*>(input.data_ptr()), d);

return true;
});
}
1 change: 1 addition & 0 deletions python/csrc/flashinfer_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Speculative sampling from sequence of probabilities");
m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization");
m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul");
m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place");
m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace,
"Apply Llama 3.1 style RoPE in-place");
Expand Down
2 changes: 2 additions & 0 deletions python/csrc/flashinfer_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ torch::Tensor rmsnorm(torch::Tensor input, torch::Tensor weight, double eps);
void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
double eps);

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta);

Expand Down
1 change: 1 addition & 0 deletions python/flashinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
single_decode_with_kv_cache,
)
from .activation import silu_and_mul
from .group_gemm import SegmentGEMMWrapper
from .norm import fused_add_rmsnorm, rmsnorm
from .page import append_paged_kv_cache
Expand Down
71 changes: 71 additions & 0 deletions python/flashinfer/activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import torch
from typing import Optional

# mypy: disable-error-code="attr-defined"
try:
from . import _kernels
except ImportError as e:
import logging
import os

if os.environ.get("BUILD_DOC", "0") == "1":
_kernels = None
logging.warning("Kernels are not loaded in documentation build mode.")
else:
raise e


def _check_shape(input: torch.Tensor, output: torch.Tensor):
assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
assert (
input.shape[:-1] == output.shape[:-1]
), f"{input.shape[:-1]} != {output.shape[:-1]}"
assert (
input.shape[-1] == 2 * output.shape[-1]
), f"{input.shape[-1]} != {2 * output.shape[-1]}"


def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
r"""Fused SiLU and Mul operation.

Parameters
----------
input: torch.Tensor
Input tensor, shape (..., 2 * hidden_size).

out: Optional[torch.Tensor]
The the output tensor, if specified, the kernel will update this tensor inplace.

Returns
-------
output: torch.Tensor
Output tensor, shape (..., hidden_size).
"""
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError("The pointers must be multiple of 16 bytes.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
_kernels.silu_and_mul(out, input)
return out
2 changes: 2 additions & 0 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def __init__(self, *args, **kwargs) -> None:
"1",
"-Xfatbin",
"-compress-all",
"-use_fast_math",
],
}
ext_modules = []
Expand All @@ -341,6 +342,7 @@ def __init__(self, *args, **kwargs) -> None:
"csrc/flashinfer_ops.cu",
"csrc/sampling.cu",
"csrc/norm.cu",
"csrc/activation.cu",
"csrc/rope.cu",
"csrc/group_gemm.cu",
"csrc/quantization.cu",
Expand Down
33 changes: 33 additions & 0 deletions python/tests/test_activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy
import pytest
import torch

import flashinfer


@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
def test_fused_silu_mul(dim, batch_size, seq_len):
x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
y_ref = x[..., dim:] * torch.nn.functional.silu(x[..., :dim])
y = flashinfer.activation.silu_and_mul(x)
numpy.testing.assert_allclose(
y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3
)