Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimization for quantized gemm skinny sizes #411

Merged
merged 13 commits into from
Feb 19, 2025
18 changes: 18 additions & 0 deletions csrc/rocm/custom.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,24 @@ void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::cuda::getCurrentCUDAStream(), CuCount);
}

void wvSpltKQ_(void* in_a, void* in_b, void* out_c, void* scale_a,
void* scale_b, const int M, const int K, const int Kp,
const int N, const int Otp_in, cudaStream_t stream,
const int CuCount);

void wvSpltKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::Tensor& scale_a, at::Tensor& scale_b, const int64_t N_in,
const int64_t Otp_in, const int64_t CuCount) {
auto M = in_a.size(0);
auto K = in_a.size(1);
auto Kp = in_a.stride(0);
int N = N_in;
int Otp = Otp_in;
wvSpltKQ_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(),
scale_a.data_ptr(), scale_b.data_ptr(), M, K, Kp, N, Otp,
at::cuda::getCurrentCUDAStream(), CuCount);
}

void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K,
cudaStream_t stream, const int solidx);

Expand Down
Loading