gshtras committed Sep 3, 2024
1 parent be9f84e commit 05e67ab
Showing 28 changed files with 237 additions and 374 deletions.
45 changes: 5 additions & 40 deletions ROCm_performance.md
@@ -1,14 +1,9 @@
# Overview of the optional performance features unique to https://github.com/ROCm/vllm
## Multi-GPU torchrun
On ROCm, the default multi-GPU executor is `torchrun`, as opposed to `ray` on NVIDIA.
This can be overridden with the `--worker-use-ray` flag to vllm or its benchmarks.
To utilize torchrun parallelism, the run command should be modified from
`python <command>`
to
`torchrun --standalone --nnodes=1 --nproc-per-node=<world-size> <command>`

## Triton attention
The default attention function on ROCm uses the Triton attention kernel. To fall back to the https://github.com/ROCm/flash-attention implementation, set the following environment variable:
`VLLM_USE_TRITON_FLASH_ATTN=0`
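A minimal sketch of toggling this flag from Python (the model name is a placeholder, and it is assumed the variable must be set before vLLM is imported so the attention backend picks it up):

```python
import os

# Fall back to the CK flash-attention implementation instead of the Triton kernel.
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"

from vllm import LLM

llm = LLM(model="meta-llama/Llama-2-7b-hf")  # placeholder model
outputs = llm.generate("Hello, ROCm!")
print(outputs[0].outputs[0].text)
```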

## Tunable ops
PyTorch tunable ops are supported.
Define the following environment variable: `PYTORCH_TUNABLEOP_ENABLED=1` to enable both runtime tuning and the subsequent use of tuned results. To only use tuned results without tuning any newly encountered shapes, set `PYTORCH_TUNABLEOP_TUNING=0`.
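For illustration, a hedged sketch of a launcher script that enables tunable ops before starting the usual vLLM workload (the flag values simply restate the description above):

```python
import os

# Enable PyTorch TunableOp: tune newly encountered GEMM shapes at runtime and reuse tuned results.
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
# To consume existing tuned results only, without tuning new shapes, uncomment:
# os.environ["PYTORCH_TUNABLEOP_TUNING"] = "0"

# ...run the vLLM server or benchmark as usual from this point on...
```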
@@ -17,39 +12,9 @@ Define the following environment symbol: `PYTORCH_TUNABLEOP_ENABLED=1` in order

On ROCm, for better performance, a custom paged attention kernel is available; it is enabled via the env variable `VLLM_USE_ROCM_CUSTOM_PAGED_ATTN=1`.
Currently, this env variable is enabled by default. To fall back to the PagedAttention v2 kernel, set the env variable to 0.
The custom PagedAttention kernel is enabled for dtype: fp16, block-size=16, head-size=128, and max context length <= 16k, with GQA ratio (num_heads//num_kv_heads) between 1 and 16. In all other cases, we fall back to the PagedAttention v2 kernel.
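As a reading aid, a small sketch that mirrors the documented eligibility conditions (this is not the actual dispatch code in the attention backend):

```python
def would_use_custom_paged_attention(dtype: str, block_size: int, head_size: int,
                                     max_context_len: int, num_heads: int,
                                     num_kv_heads: int) -> bool:
    """Approximate the documented conditions for the ROCm custom PagedAttention kernel."""
    gqa_ratio = num_heads // num_kv_heads
    return (dtype == "fp16"  # the updated documentation later in this diff also lists bf16
            and block_size == 16
            and head_size == 128
            and max_context_len <= 16 * 1024
            and 1 <= gqa_ratio <= 16)
```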

## Fp8 Quantization

To use fp8 quantization, the first step is to quantize your model to fp8 format.

By default, rocm-vllm accepts the quantized weights generated by the Quark quantizer. To do this, install Quark and run the command:

```
python3 quantize_quark.py --model_dir [llama2 checkpoint folder] \
--output_dir output_dir \
--quant_scheme w_fp8_a_fp8_o_fp8 \
--num_calib_data 128 \
--model_export vllm_adopted_safetensors \
--no_weight_matrix_merge
```
For more details, please refer to Quark's documentation.

To use AMMO, please follow these [instructions](https://github.com/ROCm/vllm/tree/main/examples/fp8/quantizer) and set `VLLM_FP8_USE_AMMO=1`.

Both quantizers generate a safetensors file that contains the quantized weights and the corresponding scaling factors of your model. The safetensors file should be placed under your model folder. We can then run the model with fp8 quantization using vLLM: when creating the `vllm.LLM` object, two additional parameters should be added, `quantization="fp8"` and `quantized_weights_path={relative path of the safetensors within your model folder}`.
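A hedged example of what this might look like (the checkpoint path and safetensors file name are placeholders, not fixed conventions):

```python
from vllm import LLM, SamplingParams

# Assumption: the Quark/AMMO-generated safetensors file was placed under the model folder,
# here at the hypothetical relative path "llama2_fp8.safetensors".
llm = LLM(
    model="/path/to/llama2-checkpoint",
    quantization="fp8",
    quantized_weights_path="llama2_fp8.safetensors",
)

outputs = llm.generate(["Hello, ROCm!"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```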

## Gemm Tuning for Fp8

To get better performance with fp8 quantization, we need to tune the gemm kernels using all the shapes encountered during the execution of the model.

To obtain all the gemm shapes used during the execution of the model, set the env variable `TUNE_FP8=1` and then run the model as usual. This produces a file called `/tmp/fp8_shapes.csv`.
Next, run gradlib to obtain the best solutions for these shapes:

```
python3 gradlib/gradlib/gemm_tuner.py --input_file /tmp/fp8_shapes.csv --tuned_file /tmp/tuned_fp8_16.csv --indtype fp8 --outdtype f16
```
where `/tmp/tuned_fp8_16.csv` will be used by our fp8 gemm linear layer.

Now, when running inference with fp8, we are using the tuned gemm for best performance.

The custom PagedAttention kernel is enabled for dtype: bf16, fp16, block-size=16, head-size=128, and max context length <= 16k, with GQA ratio (num_heads//num_kv_heads) between 1 and 16. In all other cases, we fall back to the PagedAttention v2 kernel.

## NCCL Performance environment variable

For MI300x, setting the environment variable `NCCL_MIN_NCHANNELS=112` is expected to improve performance.
31 changes: 15 additions & 16 deletions csrc/custom/custom.cu
@@ -1,16 +1,15 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

namespace py = pybind11;
#include "core/registration.h"

// declare templates for front (cpp) and back (cuda) sides of function:
// template <typename T>

void LLGemm_Silu(void* in_a, void* in_b, void* out_c, const int M, const int K,
cudaStream_t stream, const int rows_per_block);
void LLMM_Silu(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c,
int64_t rows_per_block) {
void LLMM_Silu(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
const int64_t rows_per_block) {
auto M = in_a.size(0);
auto K = in_a.size(1);
LLGemm_Silu(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K,
@@ -21,10 +20,10 @@ void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K,
cudaStream_t stream, const int rows_per_block);

// template <typename T>
void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c,
int64_t rows_per_block) {
int M = in_a.size(0);
int K = in_a.size(1);
void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
const int64_t rows_per_block) {
auto M = in_a.size(0);
auto K = in_a.size(1);
// if (N != in_b.numel())
// throw std::invalid_argument("Size mismatch A.numel(): " +
// std::to_string(in_a.numel())
@@ -41,10 +40,10 @@ void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c,
void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M, const int K,
const int N, cudaStream_t stream, const int CuCount);

void wvSpltK(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, int64_t N_in,
int64_t CuCount) {
int M = in_a.size(0);
int K = in_a.size(1);
void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
const int64_t N_in, const int64_t CuCount) {
auto M = in_a.size(0);
auto K = in_a.size(1);
int N = N_in;
wvSpltK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N,
at::cuda::getCurrentCUDAStream(), CuCount);
@@ -54,9 +53,9 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K,
cudaStream_t stream, const int solidx);

void LLZZ(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c,
const int solidx = 0) {
int M = in_a.size(0);
int K = in_a.size(1);
const int64_t solidx = 0) {
auto M = in_a.size(0);
auto K = in_a.size(1);

LLGemmZZ(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K,
at::cuda::getCurrentCUDAStream(), solidx);
@@ -69,7 +68,7 @@ void MMGPUKernel(float* in_a, float* in_b, float* out_c, int numARows,
int numAColumns, int numBRows, int numBColumns, int numCRows,
int numCColumns, cudaStream_t stream);

void MMCustomGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c) {
void MMCustomGPU(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c) {
auto matA_sizes{in_a.sizes()};
auto matB_sizes{in_b.sizes()};
auto matO_sizes{out_c.sizes()};
3 changes: 1 addition & 2 deletions csrc/custom/custom_kernels.cu
@@ -2,6 +2,7 @@
#include <cuda_fp16.h>
#include <stdexcept>
#include <algorithm>
#include "cuda_compat.h"

#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \
defined(__gfx941__) || defined(__gfx942__))
@@ -17,8 +18,6 @@
#define UNREACHABLE_CODE assert(false);
#endif

constexpr int WARP_SIZE = 64;

template <typename T>
__device__ __forceinline__ T loadnt(T* addr) {
return __builtin_nontemporal_load(addr);
12 changes: 6 additions & 6 deletions csrc/custom/custom_ops.h
@@ -1,14 +1,14 @@
#pragma once
#include <torch/all.h>

void LLMM_Silu(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c,
int64_t rows_per_block);
void LLMM_Silu(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
const int64_t rows_per_block);

void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c,
int64_t rows_per_block);
void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
const int64_t rows_per_block);

void wvSpltK(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, int64_t N_in,
int64_t CuCount);
void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
const int64_t N_in, const int64_t CuCount);

void paged_attention_custom(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& max_logits, torch::Tensor& tmp_out,
2 changes: 1 addition & 1 deletion csrc/custom/paged_attention/attention_ll4mi.cu
@@ -3,6 +3,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <hip/hip_bf16.h>
#include "cuda_compat.h"

#include <algorithm>

@@ -23,7 +24,6 @@
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#define WARP_SIZE 64

#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support

35 changes: 17 additions & 18 deletions csrc/custom/torch_bindings.cpp
@@ -3,29 +3,28 @@

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, custom_ops) {
custom_ops.def(
"LLMM1(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) -> ()"
);
"LLMM1(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) -> "
"()");
custom_ops.impl("LLMM1", torch::kCUDA, &LLMM1);
custom_ops.def(
"LLMM_Silu(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) -> ()"
);
"LLMM_Silu(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) "
"-> ()");
custom_ops.impl("LLMM_Silu", torch::kCUDA, &LLMM_Silu);
custom_ops.def(
"paged_attention_custom(Tensor! out, Tensor exp_sums,"
" Tensor max_logits, Tensor tmp_out,"
" Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads,"
" float scale, Tensor block_tables,"
" Tensor context_lens, int block_size,"
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype) -> ()"
);
custom_ops.impl("paged_attention_custom", torch::kCUDA, &paged_attention_custom);
"paged_attention_custom(Tensor! out, Tensor exp_sums,"
" Tensor max_logits, Tensor tmp_out,"
" Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads,"
" float scale, Tensor block_tables,"
" Tensor context_lens, int block_size,"
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype) -> ()");
custom_ops.impl("paged_attention_custom", torch::kCUDA,
&paged_attention_custom);
custom_ops.def(
"wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in,"
" int CuCount) -> ()"
);
"wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in,"
" int CuCount) -> ()");
custom_ops.impl("wvSpltK", torch::kCUDA, &wvSpltK);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
21 changes: 8 additions & 13 deletions csrc/custom_all_reduce.cu
@@ -154,16 +154,19 @@ void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,

#ifdef USE_ROCM

void free_meta_buffer(void* buffer) { hipFree(buffer); }
void free_meta_buffer(void* buffer) { CUDACHECK(cudaFree(buffer)); }

std::vector<uint8_t> get_meta_buffer_ipc_handle(torch::Tensor inp) {
std::vector<uint8_t> data_handle(sizeof(cudaIpcMemHandle_t), 0);
CUDACHECK(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)data_handle.data(),
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp) {
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto data_handle =
torch::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, options);
CUDACHECK(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)data_handle.data_ptr(),
inp.data_ptr()));
return data_handle;
}

torch::Tensor allocate_meta_buffer(int size) {
torch::Tensor allocate_meta_buffer(int64_t size) {
auto device_index = c10::cuda::current_device();
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
void* buffer;
@@ -181,12 +184,4 @@ torch::Tensor allocate_meta_buffer(int size) {
return torch::from_blob(buffer, {size}, free_meta_buffer, options);
}

std::vector<uint8_t> get_device_bdf(int dev) {
char busIdStr[] = "0000:00:00.0";
std::vector<uint8_t> bdf(sizeof(busIdStr), 0);
CUDACHECK(cudaDeviceGetPCIBusId((char*)bdf.data(), sizeof(busIdStr), dev));
bdf.resize(bdf.size() - 1); // remove trailing NULL
return bdf;
}

#endif
5 changes: 2 additions & 3 deletions csrc/ops.h
@@ -237,7 +237,6 @@ std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets);
#ifdef USE_ROCM
torch::Tensor allocate_meta_buffer(int size);
std::vector<uint8_t> get_meta_buffer_ipc_handle(torch::Tensor inp);
std::vector<uint8_t> get_device_bdf(int dev);
torch::Tensor allocate_meta_buffer(int64_t size);
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
#endif
9 changes: 7 additions & 2 deletions csrc/torch_bindings.cpp
@@ -340,7 +340,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
&get_max_shared_memory_per_block_device_attribute);
}

#ifndef USE_ROCM
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
// Custom all-reduce kernels
custom_ar.def("init_custom_ar", &init_custom_ar);
@@ -373,7 +372,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
custom_ar.def("register_graph_buffers", &register_graph_buffers);
custom_ar.impl("register_graph_buffers", torch::kCPU,
&register_graph_buffers);
}
#ifdef USE_ROCM
custom_ar.def("allocate_meta_buffer", &allocate_meta_buffer);
custom_ar.impl("allocate_meta_buffer", torch::kCUDA, &allocate_meta_buffer);
custom_ar.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle);
custom_ar.impl("get_meta_buffer_ipc_handle", torch::kCPU,
&get_meta_buffer_ipc_handle);
#endif
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
14 changes: 5 additions & 9 deletions docs/source/getting_started/amd-installation.rst
@@ -65,13 +65,6 @@ To build vllm on ROCm 6.1 for Radeon RX7900 series (gfx1100), you should specify
$ DOCKER_BUILDKIT=1 docker build --build-arg BUILD_FA="0" -f Dockerfile.rocm -t vllm-rocm .
To build a docker image for vllm on ROCm 5.7, you can specify ``BASE_IMAGE`` as below:

.. code-block:: console
$ DOCKER_BUILDKIT=1 docker build --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
-f Dockerfile.rocm -t vllm-rocm .
To run the above docker image ``vllm-rocm``, use the below command:

.. code-block:: console
@@ -160,10 +153,13 @@ Alternatively, wheels intended for vLLM use can be accessed under the releases.
.. tip::

- Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
- To use CK flash-attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention.
- The ROCm version of pytorch, ideally, should match the ROCm driver version.
- Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support.
- To use CK flash-attention or PyTorch naive attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention.
- The ROCm version of PyTorch, ideally, should match the ROCm driver version.


.. tip::
- For MI300x (gfx942) users, to achieve optimal performance, please refer to `MI300x tuning guide <https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/index.html>`_ for performance optimization and tuning tips on system and workflow level.
For vLLM, please refer to `vLLM performance optimization <https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#vllm-performance-optimization>`_.


14 changes: 7 additions & 7 deletions tests/kernels/test_moe.py
@@ -9,7 +9,7 @@
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from vllm import envs
import vllm.envs as envs
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.models.mixtral import MixtralMoE
@@ -97,14 +97,14 @@ def test_mixtral_moe(dtype: torch.dtype):

# pad the weight if using padding
if envs.VLLM_MOE_PADDING:
w13_weight = F.pad(vllm_moe.experts.w13_weight, (0, 128), "constant",
0)
vllm_moe.experts.w13_weight = Parameter(F.pad(
vllm_moe.experts.w13_weight, (0, 128), "constant", 0),
requires_grad=False)
torch.cuda.empty_cache()
w2_weight = F.pad(vllm_moe.experts.w2_weight, (0, 128), "constant", 0)
vllm_moe.experts.w2_weight = Parameter(F.pad(
vllm_moe.experts.w2_weight, (0, 128), "constant", 0),
requires_grad=False)
torch.cuda.empty_cache()
vllm_moe.experts.w13_weight = Parameter(w13_weight,
requires_grad=False)
vllm_moe.experts.w2_weight = Parameter(w2_weight, requires_grad=False)

# Run forward passes for both MoE blocks
hf_states, _ = hf_moe.forward(hf_inputs)