bugfix: fix the compilation issue of pip wheels #115

Merged · 3 commits · Feb 16, 2024
5 changes: 3 additions & 2 deletions include/flashinfer/wrapper.cuh
@@ -207,8 +207,9 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper(
       return BatchPrefillWithRaggedKVCacheWrapperDispatched<
           GROUP_SIZE, HEAD_DIM, KV_LAYOUT, ROTARY_MODE,
           ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
-          handler, q, qo_indptr, k, v, kv_indptr, o, lse, batch_size,
-          num_kv_heads, rope_scale, rope_theta, stream);
+          handler, q, qo_indptr, k, v, kv_indptr, /*q_rope_position=*/nullptr,
+          /*k_rope_pos_offset=*/nullptr, o, lse, batch_size, num_kv_heads,
+          rope_scale, rope_theta, stream);
   })})})})})});
   return cudaSuccess;
 }
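Note: the wrapper now threads two extra arguments, q_rope_position and k_rope_pos_offset, through to the dispatched kernel, passing nullptr to keep the previous behavior. As a rough illustration of that optional-pointer convention (hypothetical names and semantics, not flashinfer's actual kernel), a null position array can mean "derive each token's RoPE position from the ragged indptr offsets":

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical sketch of the nullptr-defaulting convention seen above: when
// q_rope_position is null, fall back to the token's offset within its request.
void rope_positions(const int32_t* qo_indptr, uint32_t batch_size,
                    const int32_t* q_rope_position /* may be nullptr */) {
  for (uint32_t b = 0; b < batch_size; ++b) {
    for (int32_t i = qo_indptr[b]; i < qo_indptr[b + 1]; ++i) {
      int32_t pos = q_rope_position ? q_rope_position[i] : i - qo_indptr[b];
      std::printf("token %d -> position %d\n", i, pos);
    }
  }
}

int main() {
  std::vector<int32_t> qo_indptr = {0, 2, 5};  // two requests of lengths 2 and 3
  rope_positions(qo_indptr.data(), 2, /*q_rope_position=*/nullptr);
  return 0;
}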
1 change: 1 addition & 0 deletions python/csrc/batch_prefill.cu
@@ -216,6 +216,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
         &handler_, static_cast<c_type*>(q.data_ptr()),
         static_cast<int32_t*>(qo_indptr.data_ptr()), static_cast<c_type*>(k.data_ptr()),
         static_cast<c_type*>(v.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
+        /*q_rope_position=*/nullptr, /*k_rope_pos_offset=*/nullptr,
         static_cast<c_type*>(o.data_ptr()),
         /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr, batch_size,
         num_kv_heads, rope_scale, rope_theta,
18 changes: 9 additions & 9 deletions python/csrc/cascade.cu
@@ -44,11 +44,11 @@ std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, tor
   auto s_merged = torch::empty({seq_len, num_heads}, s_a.options());
 
   bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(v_a.scalar_type(), c_type, [&] {
-    cudaError_t status =
-        MergeState(static_cast<c_type*>(v_a.data_ptr()), static_cast<float*>(s_a.data_ptr()),
-                   static_cast<c_type*>(v_b.data_ptr()), static_cast<float*>(s_b.data_ptr()),
-                   static_cast<c_type*>(v_merged.data_ptr()),
-                   static_cast<float*>(s_merged.data_ptr()), seq_len, num_heads, head_dim, torch_current_stream);
+    cudaError_t status = MergeState(
+        static_cast<c_type*>(v_a.data_ptr()), static_cast<float*>(s_a.data_ptr()),
+        static_cast<c_type*>(v_b.data_ptr()), static_cast<float*>(s_b.data_ptr()),
+        static_cast<c_type*>(v_merged.data_ptr()), static_cast<float*>(s_merged.data_ptr()),
+        seq_len, num_heads, head_dim, torch_current_stream);
     TORCH_CHECK(status == cudaSuccess,
                 "MergeState kernel launch failed: ", cudaGetErrorString(status));
     return true;
@@ -80,10 +80,10 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
 
   bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(v.scalar_type(), c_type, [&] {
-    cudaError_t status =
-        MergeStateInPlace(static_cast<c_type*>(v.data_ptr()), static_cast<float*>(s.data_ptr()),
-                          static_cast<c_type*>(v_other.data_ptr()),
-                          static_cast<float*>(s_other.data_ptr()), seq_len, num_heads, head_dim, torch_current_stream);
+    cudaError_t status = MergeStateInPlace(
+        static_cast<c_type*>(v.data_ptr()), static_cast<float*>(s.data_ptr()),
+        static_cast<c_type*>(v_other.data_ptr()), static_cast<float*>(s_other.data_ptr()), seq_len,
+        num_heads, head_dim, torch_current_stream);
     TORCH_CHECK(status == cudaSuccess,
                 "MergeStateInPlace kernel launch failed: ", cudaGetErrorString(status));
     return true;
25 changes: 13 additions & 12 deletions python/csrc/flashinfer_decl.h
@@ -24,18 +24,19 @@
   template cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched<                             \
       PageStorage::kIndices, LAYOUT, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION,  \
       CAUSAL, T, T, int32_t>(BatchPrefillHandler * handler, T* q, int32_t* qo_indptr,             \
+                             int32_t* q_rope_position,                                            \
                              paged_kv_t<PageStorage::kIndices, LAYOUT, T, int32_t> paged_kv, T* o, \
                              float* lse, float rope_scale, float rope_theta, cudaStream_t stream); \
   }
 
-#define INST_BatchPrefillRaggedWrapper(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION,  \
-                                       LAYOUT, ROTARY_MODE)                                       \
-  namespace flashinfer {                                                                          \
-  template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched<                            \
-      GROUP_SIZE, HEAD_DIM, LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, T, T, int32_t>( \
-      BatchPrefillHandler * handler, T* q, int32_t* qo_indptr, T* k, T* v, int32_t* kv_indptr,    \
-      T* o, float* lse, uint32_t batch_size, uint32_t num_kv_heads, float rope_scale,             \
-      float rope_theta, cudaStream_t stream);                                                     \
+#define INST_BatchPrefillRaggedWrapper(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION,  \
+                                       LAYOUT, ROTARY_MODE)                                       \
+  namespace flashinfer {                                                                          \
+  template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched<                            \
+      GROUP_SIZE, HEAD_DIM, LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, T, T, int32_t>( \
+      BatchPrefillHandler * handler, T* q, int32_t* qo_indptr, T* k, T* v, int32_t* kv_indptr,    \
+      int32_t* q_rope_position, int32_t* k_rope_pos_offset, T* o, float* lse, uint32_t batch_size, \
+      uint32_t num_kv_heads, float rope_scale, float rope_theta, cudaStream_t stream);            \
   }
 
 #define INST_SinglePrefill(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, LAYOUT,      \
@@ -56,15 +57,15 @@ template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT, RotaryMod
           typename IdType>
 cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
     BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v,
-    IdType* kv_indptr, DTypeOut* o, float* lse, const uint32_t batch_size,
-    const uint32_t num_kv_heads, const float rope_scale, const float rope_theta,
-    cudaStream_t stream);
+    IdType* kv_indptr, IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o, float* lse,
+    const uint32_t batch_size, const uint32_t num_kv_heads, const float rope_scale,
+    const float rope_theta, cudaStream_t stream);
 
 template <PageStorage page_storage, QKVLayout kv_layout, uint32_t GROUP_SIZE, uint32_t HEAD_DIM,
           RotaryMode ROTARY_MODE, bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL, typename DTypeIn,
           typename DTypeOut, typename IdType>
 cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
-    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr,
+    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position,
     paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
     float rope_scale, float rope_theta, cudaStream_t stream);
 
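This header is the likely source of the pip-wheel compilation failure the PR title names: it declares the dispatched templates and defines the INST_* explicit-instantiation macros used by the ahead-of-time-compiled wheels, and its signatures had fallen out of sync with the kernels once the RoPE-position arguments were added. An explicit instantiation must match the template declaration exactly; a minimal sketch of that failure mode (hypothetical names, not flashinfer's):

// repro.cc -- sketch of a stale explicit instantiation. The template gained
// a third parameter, but a lagging instantiation still names the old form.
template <typename T>
int dispatch(T* q, int* qo_indptr, int* q_rope_position) {
  return (q && qo_indptr && q_rope_position) ? 0 : 1;  // placeholder body
}

// Matches the declaration above: compiles cleanly with `c++ -c repro.cc`.
template int dispatch<float>(float* q, int* qo_indptr, int* q_rope_position);

// Stale two-argument form, analogous to the pre-PR macros; uncommenting it
// fails to compile because no dispatch template has this signature:
// template int dispatch<float>(float* q, int* qo_indptr);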
7 changes: 4 additions & 3 deletions python/csrc/page.cu
@@ -73,9 +73,10 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
         num_heads, page_size, head_dim, batch_size, static_cast<c_type*>(kv_data.data_ptr()),
         static_cast<int32_t*>(kv_indices.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
         static_cast<int32_t*>(kv_last_page_len.data_ptr()));
-    cudaError_t status = AppendPagedKVCache(paged_kv, static_cast<c_type*>(append_key.data_ptr()),
-                                            static_cast<c_type*>(append_value.data_ptr()),
-                                            static_cast<int32_t*>(append_indptr.data_ptr()), torch_current_stream);
+    cudaError_t status =
+        AppendPagedKVCache(paged_kv, static_cast<c_type*>(append_key.data_ptr()),
+                           static_cast<c_type*>(append_value.data_ptr()),
+                           static_cast<int32_t*>(append_indptr.data_ptr()), torch_current_stream);
     TORCH_CHECK(status == cudaSuccess,
                 "AppendPagedKVCache failed with error: ", cudaGetErrorString(status));
     return true;
2 changes: 1 addition & 1 deletion python/csrc/pytorch_extension_utils.h
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 #pragma once
-#include <torch/extension.h>
 #include <c10/cuda/CUDAStream.h>
+#include <torch/extension.h>
 
 #include "generated/dispatch.inc"
 
1 change: 1 addition & 0 deletions python/flashinfer/__init__.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 from .decode import (
     single_decode_with_kv_cache,
     batch_decode_with_padded_kv_cache,
1 change: 1 addition & 0 deletions python/flashinfer/cascade.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import math
 from typing import Optional
 import torch
7 changes: 4 additions & 3 deletions python/flashinfer/decode.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import math
 from typing import Optional, Union
 import torch
@@ -477,9 +478,9 @@ def begin_forward(
         # NOTE(Zihao): the following tensor acts as placeholder to pass dtype info
         empty_data = torch.empty(
             0,
-            dtype=getattr(torch, data_type)
-            if isinstance(data_type, str)
-            else data_type,
+            dtype=(
+                getattr(torch, data_type) if isinstance(data_type, str) else data_type
+            ),
         )
         self._wrapper.begin_forward(
             self._workspace_buffer,
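The NOTE in this hunk points at a small idiom worth spelling out: a zero-element tensor is allocated purely so its dtype can be forwarded to the handler. The same trick written against libtorch (a sketch, assuming the C++ factory functions mirror the Python API):

#include <torch/torch.h>
#include <iostream>

int main() {
  // Zero-element tensor: no meaningful storage, but it still carries dtype
  // metadata that a planning routine can inspect before real buffers exist.
  auto empty_data = torch::empty({0}, torch::dtype(torch::kHalf));
  std::cout << empty_data.scalar_type() << std::endl;  // prints "Half"
  return 0;
}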
1 change: 1 addition & 0 deletions python/flashinfer/page.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import torch
 
 try:
1 change: 1 addition & 0 deletions python/flashinfer/prefill.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import math
 from typing import Optional
 import torch
1 change: 1 addition & 0 deletions python/flashinfer/utils.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import torch
 
1 change: 1 addition & 0 deletions python/setup.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import pathlib
 import os
 import re
1 change: 1 addition & 0 deletions python/tests/test_batch_decode_kernels.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import numpy
 import pytest
 import torch
1 change: 1 addition & 0 deletions python/tests/test_batch_prefill_kernels.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import numpy
 import pytest
 import torch
1 change: 1 addition & 0 deletions python/tests/test_shared_prefix_kernels.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import numpy
 import pytest
 import torch