[Feature]: Improve the compile times of gptq_marlin.cu #7317

Closed
tlrmchlsmth opened this issue Aug 8, 2024 · 6 comments
Comments

@tlrmchlsmth
Collaborator

tlrmchlsmth commented Aug 8, 2024

🚀 The feature, motivation and pitch

The compile times for the GPTQ Marlin kernels are quite long, and have become painful for developers. gptq_marlin.cu is monolithic and heavily templated, and many cases of the marlin kernel are instantiated. Compile times got particularly bad in #6612, where the number of cases being compiled approximately doubled.

Specifically, the following block of code in gptq_marlin.cu instantiates 320 kernels:

    GPTQ_CALL_IF(4, 16, 4, 256)
    GPTQ_CALL_IF(4, 8, 8, 256)
    GPTQ_CALL_IF(4, 8, 4, 128)
    GPTQ_CALL_IF(4, 4, 8, 128)
    GPTQ_CALL_IF(8, 16, 4, 256)
    GPTQ_CALL_IF(8, 8, 8, 256)
    GPTQ_CALL_IF(8, 8, 4, 128)
    GPTQ_CALL_IF(8, 4, 8, 128)

    AWQ_CALL_IF(4, 16, 4, 256)
    AWQ_CALL_IF(4, 8, 8, 256)
    AWQ_CALL_IF(4, 8, 4, 128)
    AWQ_CALL_IF(4, 4, 8, 128)
    AWQ_CALL_IF(8, 16, 4, 256)
    AWQ_CALL_IF(8, 8, 8, 256)
    AWQ_CALL_IF(8, 8, 4, 128)
    AWQ_CALL_IF(8, 4, 8, 128)

I think the best option for improving this right now is to split gptq_marlin.cu into multiple files so that compilation can be parallelized.

Details:

First, this function and its dependencies should be moved into a file called something like gptq_marlin_kernel.cuh

  template <typename scalar_t,          // compute dtype, half or nv_bfloat16
            const int num_bits,         // number of bits used for weights
            const int threads,          // number of threads in a threadblock
            const int thread_m_blocks,  // number of 16x16 blocks in the m
                                        // dimension (batchsize) of the
                                        // threadblock
            const int thread_n_blocks,  // same for n dimension (output)
            const int thread_k_blocks,  // same for k dimension (reduction)
            const int stages,  // number of stages for the async global->shared
                               // fetch pipeline
            const bool has_act_order,    // whether act_order is enabled
            const bool has_zp,           // whether zero-points are enabled
            const int group_blocks = -1  // number of consecutive 16x16 blocks
                                         // with a separate quantization scale
            >
  __global__ void Marlin(/* kernel arguments */) { ... }

Next, we need to spread the instantiations of the template function across a sensible number of .cu files. Too many will likely be counter-productive, so some experimentation will be needed. Each new .cu file should include gptq_marlin_kernel.cuh, but to create a firewall, I think it's best if gptq_marlin.cu does not include gptq_marlin_kernel.cuh. Instead we can add a new file called gptq_marlin.cuh that just contains the declarations of the template specializations that have been defined.

Summary of the proposed file structure

I think gptq_marlin.cu should be broken down into something like the following:

  • gptq_marlin_kernel.cuh -- The bulk of gptq_marlin.cu should go in here.
  • gptq_marlin_a.cu through gptq_marlin_d.cu -- These should include gptq_marlin_kernel.cuh and each define some number of Marlin configs. Name them better than I did here.
  • gptq_marlin.cu -- Drastically smaller than it is currently. Should contain the dispatch logic and include gptq_marlin.cuh.
  • gptq_marlin.cuh -- Should only declare the Marlin configs that have been defined in the new files, e.g. gptq_marlin_a.cu through gptq_marlin_d.cu.
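The split can be sketched in plain host-side C++. This is a minimal sketch under stated assumptions: `marlin_like` and `dispatch` are hypothetical stand-ins for the Marlin kernel template and its dispatch logic, only four template parameters are kept, and all sections share one translation unit so the example compiles on its own. The comments mark which proposed file each section would live in. The `extern template` lines are explicit instantiation declarations: they tell any file that sees them that these specializations are compiled somewhere else.

```cpp
// Hypothetical sketch: marlin_like stands in for the Marlin kernel template.
// Each section would live in a separate file in the proposed layout.

// ---- gptq_marlin.cuh: declarations only (what the dispatch file sees) ----
template <int num_bits, int thread_n_blocks, int thread_k_blocks, int threads>
int marlin_like(int x);

// Explicit instantiation declarations: these specializations are
// instantiated in another file, so includers must not instantiate them.
extern template int marlin_like<4, 16, 4, 256>(int);
extern template int marlin_like<8, 8, 8, 256>(int);

// ---- gptq_marlin_kernel.cuh: the heavy template definition ----
template <int num_bits, int thread_n_blocks, int thread_k_blocks, int threads>
int marlin_like(int x) {
  // Placeholder body; in the real kernel this is where compile time goes.
  return x * num_bits + thread_n_blocks * thread_k_blocks + threads;
}

// ---- gptq_marlin_a.cu / gptq_marlin_b.cu: explicit instantiation definitions ----
template int marlin_like<4, 16, 4, 256>(int);  // would live in gptq_marlin_a.cu
template int marlin_like<8, 8, 8, 256>(int);   // would live in gptq_marlin_b.cu

// ---- gptq_marlin.cu: dispatch, which only needs the declarations ----
int dispatch(int num_bits, int x) {
  if (num_bits == 4) return marlin_like<4, 16, 4, 256>(x);
  if (num_bits == 8) return marlin_like<8, 8, 8, 256>(x);
  return -1;  // configuration not instantiated
}
```

In plain C++ terms, gptq_marlin.cu would never see the template body in the real layout, so its calls compile to symbol references that the linker resolves against gptq_marlin_a.cu and gptq_marlin_b.cu, and those two files can be compiled in parallel.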

Alternatives

If there are any template parameters that can be made dynamic without losing any performance, that is an even better option.
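As a hypothetical illustration of that alternative (not Marlin code), here is one parameter, `group_blocks`, handled at compile time and at runtime. `scale_index_static` and `scale_index_dynamic` are made-up helpers, and treating -1 as "one scale group spanning all of K" is an assumption modeled on the default in the template signature above. The runtime form needs a single instantiation for every group size, but the divisor is no longer a compile-time constant and the branch is evaluated at runtime, so whether it costs performance would have to be measured.

```cpp
// Hypothetical helpers: map a k-block index to the index of its scale group.

// Compile-time version: one instantiation per group_blocks value.
template <int group_blocks>
int scale_index_static(int k_block) {
  if constexpr (group_blocks == -1) {
    return 0;  // assumed: a single scale group covering all of K
  } else {
    return k_block / group_blocks;  // divisor known at compile time
  }
}

// Runtime version: one function serves every group size, at the cost of
// a runtime branch and a non-constant divisor.
inline int scale_index_dynamic(int group_blocks, int k_block) {
  return group_blocks == -1 ? 0 : k_block / group_blocks;
}
```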

Additional context

No response

@bnellnm
Contributor

bnellnm commented Aug 8, 2024

It might be nice to do something like the following: define all the combinations in the header file and use them in both the dispatching section and the manual instantiation files. I don't know how fine-grained the instantiations need to be, so maybe a range is overkill. For example, maybe processing a single combination tuple, such as (4, 16, 4, 256), is enough.

Boost has a lot of nice preprocessor utilities for things like this. Bringing in Boost just for this is a bit overkill, but it might be easy enough to extract the bits that are needed.

    gptq_marlin_kernel.cuh:
        #define COMBOS ((4, 16, 4, 256), ..., (8, 4, 8, 128))

    gptq_marlin_kernel.cu:
        GPTQ_CALL_IF_COMBOS(COMBOS)
        AWQ_CALL_IF_COMBOS(COMBOS)

    // where N is how many combinations each file instantiates
    // (the total number of combinations should be evenly divisible by N)
    gptq_marlin_1.cu:
        GPTQ_INSTANTIATE_COMBOS(COMBOS, 0, N)

    gptq_marlin_2.cu:
        GPTQ_INSTANTIATE_COMBOS(COMBOS, N, 2*N)
    ...
    gptq_marlin_M.cu:
        GPTQ_INSTANTIATE_COMBOS(COMBOS, (M-1)*N, M*N)
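The same idea can also be expressed without the preprocessor. The sketch below swaps in a different technique from the one proposed (C++17 `std::index_sequence` instead of Boost preprocessor utilities): a constexpr table plays the role of COMBOS, and `instantiate_range<Begin, End>` instantiates the kernels for one index range by taking their addresses. `Combo`, `kCombos`, `kernel_like`, and `instantiate_range` are hypothetical names, the kernel is a host function, and whether the pattern carries over cleanly to `__global__` templates under nvcc is untested here.

```cpp
#include <array>
#include <cstddef>
#include <utility>

// Hypothetical config tuple: (num_bits, thread_n_blocks, thread_k_blocks, threads).
struct Combo {
  int num_bits, thread_n_blocks, thread_k_blocks, threads;
};

// Single source of truth for the configurations (the role of COMBOS).
inline constexpr std::array<Combo, 8> kCombos{{
    {4, 16, 4, 256}, {4, 8, 8, 256}, {4, 8, 4, 128}, {4, 4, 8, 128},
    {8, 16, 4, 256}, {8, 8, 8, 256}, {8, 8, 4, 128}, {8, 4, 8, 128}}};

// Stand-in for the kernel: one specialization per index into kCombos.
template <std::size_t I>
int kernel_like() {
  constexpr Combo c = kCombos[I];
  return c.num_bits * c.thread_n_blocks * c.thread_k_blocks + c.threads;
}

using KernelFn = int (*)();

// Taking the address of each specialization forces it to be instantiated
// in whichever translation unit calls instantiate_range.
template <std::size_t Begin, std::size_t... Is>
constexpr std::array<KernelFn, sizeof...(Is)> make_table(std::index_sequence<Is...>) {
  return {&kernel_like<Begin + Is>...};
}

// Instantiate the configs with indices in [Begin, End).
template <std::size_t Begin, std::size_t End>
constexpr std::array<KernelFn, End - Begin> instantiate_range() {
  static_assert(Begin <= End && End <= kCombos.size(), "range out of bounds");
  return make_table<Begin>(std::make_index_sequence<End - Begin>{});
}
```

Each gptq_marlin_M.cu could then export a table such as `extern const std::array<KernelFn, 4> kTableM = instantiate_range<4, 8>();` (hypothetical), and the dispatch file would only need the `extern` declaration of that table, not the kernel template itself.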

@LucasWilkinson
Collaborator

LucasWilkinson commented Aug 9, 2024

I think it should be sufficient to just break down by a combination of weight-type and zero-point support. For example that would mean breaking down the current gptq_marlin.cu kernel into:

marlin_u4b8.cu    // weight-type: u4b8, zero-points: no   | i.e. GPTQ 4bit / Orig-Marlin
marlin_u8b128.cu  // weight-type: u8b128, zero-points: no | i.e. GPTQ 8bit
marlin_u4_zp.cu   // weight-type: u4, zero-points: yes    | i.e. AWQ 4bit 
marlin_u8_zp.cu   // weight-type: u8, zero-points: yes    | i.e. AWQ 8bit 

(Note the rest of this comment assumes #7323 is merged)

As @tlrmchlsmth mentioned, I would expect gptq_marlin_kernel.cuh to contain the implementation/definition:

namespace marlin {

template <typename scalar_t,          // compute dtype, half or nv_bfloat16
          int64_t w_type_id,         // weight type ScalarType id
          ....
          const bool has_zp,           // whether zero-points are enabled
          const int group_blocks = -1  // number of consecutive 16x16 blocks
                                        // with a separate quantization scale
          >
__global__ void Marlin(/* kernel arguments */) { ... }

}  // namespace marlin

Then the cu files would instantiate the appropriate kernels, for example marlin_u4b8.cu would contain:

#include "gptq_marlin_kernel.cuh"
#include "scalar_type.hpp"

namespace marlin {

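// explicit instantiation definitions: `template`, not `template<>`, which
// would declare an explicit specialization instead
template __global__ void Marlin<half, vllm::kU4B8.id(), ..., false, -1>(/* kernel arguments */);
...
template __global__ void Marlin<half, vllm::kU4B8.id(), ..., false, 128>(/* kernel arguments */);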
...

}  // namespace marlin

I suspect some macro magic like @bnellnm suggested would be appropriate here.

Then gptq_marlin.cu would just contain the current dispatching logic, i.e.:

// there should be no need to include `gptq_marlin_kernel.cuh` in this file

namespace marlin {

// Kernel Template Declaration
template <typename scalar_t,
          int64_t w_type_id,
          ....
          const bool has_zp,
          const int group_blocks = -1
          >
__global__ void Marlin(/* kernel arguments */);

template <typename scalar_t>
void marlin_mm(const void* A, ...) {
    ...

    GPTQ_CALL_IF(vllm::kU4B8, 16, 4, 256)
    GPTQ_CALL_IF(vllm::kU4B8, 8, 8, 256)
    ...

    AWQ_CALL_IF(vllm::kU4, 16, 4, 256)
    AWQ_CALL_IF(vllm::kU4, 8, 8, 256)
    ...

}

}  // namespace marlin

torch::Tensor gptq_marlin_gemm(torch::Tensor& a, ....) {

  ...
  if (a.scalar_type() == at::ScalarType::Half) {
    marlin::marlin_mm<half>(a.data_ptr<at::Half>(), ...);
  } else if (a.scalar_type() == at::ScalarType::BFloat16) {
    marlin::marlin_mm<nv_bfloat16>(a.data_ptr<at::BFloat16>(), ...);
  } else {
    TORCH_CHECK(false, "gptq_marlin_gemm only supports bfloat16 and float16");
  }

  return c;
}
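One possible shape for the "macro magic" mentioned above is an X-macro: a single configuration list expanded once into explicit instantiations and once into dispatch branches, so the two cannot drift apart. This is a minimal hypothetical sketch: `MARLIN_U4B8_CONFIGS`, `marlin_like`, and `run_u4b8` are made-up names, the kernel is replaced by a host function, and both expansions sit in one file so it compiles standalone. In the layout above, the instantiation expansion would go in marlin_u4b8.cu and the dispatch expansion in gptq_marlin.cu.

```cpp
// Hypothetical X-macro: X(thread_n_blocks, thread_k_blocks, threads, group_blocks)
#define MARLIN_U4B8_CONFIGS(X) \
  X(16, 4, 256, -1)            \
  X(8, 8, 256, -1)             \
  X(16, 4, 256, 8)             \
  X(8, 8, 256, 8)

// Stand-in for the kernel template.
template <int n_blocks, int k_blocks, int threads, int group_blocks>
int marlin_like() {
  return n_blocks * k_blocks * threads + group_blocks;
}

// marlin_u4b8.cu: expand the list into explicit instantiation definitions.
#define MARLIN_INSTANTIATE(N, K, T, G) template int marlin_like<N, K, T, G>();
MARLIN_U4B8_CONFIGS(MARLIN_INSTANTIATE)
#undef MARLIN_INSTANTIATE

// gptq_marlin.cu: expand the same list into runtime dispatch branches.
int run_u4b8(int n, int k, int t, int g) {
#define MARLIN_DISPATCH(N, K, T, G) \
  if (n == N && k == K && t == T && g == G) return marlin_like<N, K, T, G>();
  MARLIN_U4B8_CONFIGS(MARLIN_DISPATCH)
#undef MARLIN_DISPATCH
  return -1;  // configuration not instantiated
}
```

The same list could also be expanded a third time into `extern template` declarations for a shared header.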

github-actions bot commented Nov 8, 2024

This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

github-actions bot added the stale label Nov 8, 2024

github-actions bot commented Dec 8, 2024

This issue has been automatically closed due to inactivity. Please feel free to reopen if you feel it is still relevant. Thank you!

github-actions bot closed this as not planned Dec 8, 2024
@tlrmchlsmth
Collaborator Author

@LucasWilkinson is there still stuff we can do here?

@LucasWilkinson
Collaborator

I think this could still be an improvement (especially now that HQQ is part of it), albeit a lower priority given the Python-only development paths and better arch code handling added over the last couple of months. Part of this could also be unifying QQQ Marlin and FP8 Marlin into whatever the new structure ends up being.
