bugfix: Fix compile error of OptionalCUDAGuard and device_of (#613)
There are compile errors on the main branch, such as:

```
/app/python/csrc_aot/single_prefill.cu(59): error: namespace "at::cuda" has no member "OptionalCUDAGuard"
    const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
                    ^

/app/python/csrc_aot/single_prefill.cu(59): error: identifier "device_of" is undefined
    const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
```

cc @yzh119
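
For context: `at::cuda::OptionalCUDAGuard` is declared in `c10/cuda/CUDAGuard.h`, which these files never included, and `device_of` (declared in `ATen/DeviceGuard.h`) takes a `Tensor` rather than an `at::Device`, so `device_of(device)` would not compile even with the header present. The fix adds the missing include and constructs a plain `at::cuda::CUDAGuard` directly from the device. Below is a minimal sketch of the corrected pattern, assuming a standard PyTorch C++/CUDA extension; `run_on_device_of` is a hypothetical helper, not a function from this repo:

```cpp
#include <c10/cuda/CUDAGuard.h>  // declares at::cuda::CUDAGuard / OptionalCUDAGuard
#include <torch/extension.h>

// Hypothetical helper: do CUDA work on the device that owns `q`.
void run_on_device_of(const torch::Tensor& q) {
  auto device = q.device();  // an at::Device, already known to be valid
  // RAII guard: makes `device` the current CUDA device for this scope.
  const at::cuda::CUDAGuard device_guard(device);
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device.index());
  // ... launch kernels on `stream` ...
}
```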
reyoung authored Nov 18, 2024
1 parent b53a46f commit dd3c836
Showing 4 changed files with 10 additions and 5 deletions.
python/csrc_aot/batch_decode.cu (3 additions & 2 deletions)

```diff
@@ -13,6 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <c10/cuda/CUDAGuard.h>
 #include <torch/extension.h>
 
 #include <flashinfer/attention/decode_params.cuh>
@@ -42,7 +43,7 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
   size_t int_workspace_size_in_bytes =
       int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
   auto device = float_workspace_buffer.device();
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
+  const at::cuda::CUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   TORCH_CHECK(indptr.device() == torch::kCPU, "indptr must be on CPU");
 
@@ -113,7 +114,7 @@ torch::Tensor BatchDecodeWithPagedKVCacheRun(
   }
   uint32_t head_dim = q.size(2);
 
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
+  const at::cuda::CUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   torch::Tensor o = torch::empty_like(q);
   if (maybe_lse) {
```
python/csrc_aot/batch_prefill.cu (2 additions & 1 deletion)

```diff
@@ -13,6 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <c10/cuda/CUDAGuard.h>
 #include <torch/extension.h>
 
 #include <flashinfer/attention/mask.cuh>
@@ -50,7 +51,7 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(
       int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
 
   auto device = float_workspace_buffer.device();
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
+  const at::cuda::CUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   TORCH_CHECK(qo_indptr.device() == torch::kCPU, "qo_indptr must be on CPU");
   TORCH_CHECK(kv_indptr.device() == torch::kCPU, "kv_indptr must be on CPU");
```
python/csrc_aot/single_decode.cu (3 additions & 1 deletion)

```diff
@@ -13,6 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <c10/cuda/CUDAGuard.h>
+
 #include <flashinfer/attention/decode_params.cuh>
 #include <flashinfer/attention/variants.cuh>
 #include <optional>
@@ -60,7 +62,7 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc
     kv_len = k.size(1);
   }
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
+  const at::cuda::CUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto o = torch::empty_like(q);
```
python/csrc_aot/single_prefill.cu (2 additions & 1 deletion)

```diff
@@ -13,6 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <c10/cuda/CUDAGuard.h>
 #include <torch/extension.h>
 
 #include <flashinfer/attention/mask.cuh>
@@ -56,7 +57,7 @@ torch::Tensor single_prefill_with_kv_cache(
     kv_stride_h = k.stride(0);
     kv_stride_n = k.stride(1);
   }
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(device));
+  const at::cuda::CUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto o = torch::empty_like(q, q.options());
   if (maybe_lse) {
```
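
For reference, the `OptionalCUDAGuard` + `device_of` pairing is still well-formed once the right headers are included; `device_of` takes a `Tensor` and returns `nullopt` for an undefined tensor, making the guard a no-op in that case. A hedged sketch of that alternative pattern (`maybe_guard` is illustrative, not part of this commit):

```cpp
#include <ATen/DeviceGuard.h>    // declares at::device_of(const at::Tensor&)
#include <c10/cuda/CUDAGuard.h>  // declares at::cuda::OptionalCUDAGuard
#include <torch/extension.h>

// Guard on t's (CUDA) device; device_of returns nullopt for an
// undefined tensor, in which case the guard does nothing.
void maybe_guard(const torch::Tensor& t) {
  const at::cuda::OptionalCUDAGuard guard(at::device_of(t));
  // ... device-dependent work ...
}
```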
