Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Build] Guard against older CUDA versions when building CUTLASS 3.x kernels #5168

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12000

#include <torch/extension.h>

#include <ATen/cuda/CUDAContext.h>
Expand All @@ -6,8 +12,6 @@
#include <sstream>
#include <vector>

// clang-format will break include orders
// clang-format off
#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
Expand Down Expand Up @@ -241,3 +245,5 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
}
}
}

#endif
11 changes: 10 additions & 1 deletion csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
Expand All @@ -17,10 +18,12 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);

// Forward declaration of the sm90 entry point. It is only declared when
// CUDA_VERSION is defined and >= 12000; the dispatcher below applies the same
// condition before calling it and uses the sm80 kernel otherwise.
// NOTE(review): the definition in scaled_mm_dq_c3x.cu appears to sit behind
// the same guard, so the symbol would not exist in builds where this check
// fails -- keep the two conditions in sync.
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
#endif

void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
Expand Down Expand Up @@ -51,7 +54,13 @@ void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,

if (version_num >= 90) {
// Hopper

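// Guard against compilation issues for sm90 kernels: they are only available
// when CUDA_VERSION >= 12000, so fall back to the sm80 kernels otherwise.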
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales);
#else
cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
#endif
} else if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales);
Expand Down
Loading