Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix compile error of OptionalCUDAGuard and device_of #613

Merged
merged 1 commit into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python/csrc_aot/batch_decode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>

#include <flashinfer/attention/decode_params.cuh>
Expand Down Expand Up @@ -42,7 +43,7 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
size_t int_workspace_size_in_bytes =
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
auto device = float_workspace_buffer.device();
const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. device_of is in a private header of torch.

  2. It seems that OptionalCUDAGuard is not necessary since the device is not optional.

  3. Adding include file for CUDAGuard

const at::cuda::CUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
TORCH_CHECK(indptr.device() == torch::kCPU, "indptr must be on CPU");

Expand Down Expand Up @@ -113,7 +114,7 @@ torch::Tensor BatchDecodeWithPagedKVCacheRun(
}
uint32_t head_dim = q.size(2);

const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
const at::cuda::CUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
torch::Tensor o = torch::empty_like(q);
if (maybe_lse) {
Expand Down
3 changes: 2 additions & 1 deletion python/csrc_aot/batch_prefill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>

#include <flashinfer/attention/mask.cuh>
Expand Down Expand Up @@ -50,7 +51,7 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();

auto device = float_workspace_buffer.device();
const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
const at::cuda::CUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
TORCH_CHECK(qo_indptr.device() == torch::kCPU, "qo_indptr must be on CPU");
TORCH_CHECK(kv_indptr.device() == torch::kCPU, "kv_indptr must be on CPU");
Expand Down
4 changes: 3 additions & 1 deletion python/csrc_aot/single_decode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <c10/cuda/CUDAGuard.h>

#include <flashinfer/attention/decode_params.cuh>
#include <flashinfer/attention/variants.cuh>
#include <optional>
Expand Down Expand Up @@ -60,7 +62,7 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc
kv_len = k.size(1);
}
CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
const at::cuda::CUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q);

Expand Down
3 changes: 2 additions & 1 deletion python/csrc_aot/single_prefill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>

#include <flashinfer/attention/mask.cuh>
Expand Down Expand Up @@ -56,7 +57,7 @@ torch::Tensor single_prefill_with_kv_cache(
kv_stride_h = k.stride(0);
kv_stride_n = k.stride(1);
}
const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
const at::cuda::CUDAGuard device_guard(device);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto o = torch::empty_like(q, q.options());
if (maybe_lse) {
Expand Down