
Commit cf2375b

Merge branch 'main' into fix_gpu_mem_clear_function_of_test_utils
ShangmingCai committed Sep 10, 2024
2 parents: 2fabef0 + da1a844
Showing 34 changed files with 1,104 additions and 295 deletions.
13 changes: 12 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -386,7 +386,18 @@ steps:
- vllm/
- tests/weight_loading
commands:
- bash weight_loading/run_model_weight_loading_test.sh
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt

- label: Weight Loading Multiple GPU Test - Large Models # optional
working_dir: "/vllm-workspace/tests"
num_gpus: 2
gpu: a100
optional: true
source_file_dependencies:
- vllm/
- tests/weight_loading
commands:
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt


##### multi gpus test #####
16 changes: 12 additions & 4 deletions README.md
@@ -17,15 +17,16 @@ Easy, fast, and cheap LLM serving for everyone

---

**vLLM & NVIDIA Triton User Meetup (Monday, September 9, 5pm-9pm PT) at Fort Mason, San Francisco**
**vLLM, AMD, Anyscale Meet & Greet at [Ray Summit 2024](http://raysummit.anyscale.com) (Monday, Sept 30th, 5-7pm PT) at Marriott Marquis San Francisco**

We are excited to announce our sixth vLLM Meetup, in collaboration with NVIDIA Triton Team.
Join us to hear the vLLM's recent update about performance.
Register now [here](https://lu.ma/87q3nvnh) and be part of the event!
We are excited to announce our special vLLM event in collaboration with AMD and Anyscale.
Join us to learn more about recent advancements of vLLM on MI300X.
Register [here](https://lu.ma/db5ld9n5) and be a part of the event!

---

*Latest News* 🔥
- [2024/09] We hosted [the sixth vLLM meetup](https://lu.ma/87q3nvnh) with NVIDIA! Please find the meetup slides [here](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing).
- [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing).
- [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html).
- [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing).
@@ -130,3 +131,10 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs
year={2023}
}
```

## Contact Us

* For technical questions and feature requests, please use GitHub issues or discussions.
* For discussing with fellow users, please use Discord.
* For security disclosures, please use GitHub's security advisory feature.
* For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu.
2 changes: 1 addition & 1 deletion csrc/moe/marlin_moe_ops.cu
@@ -1737,4 +1737,4 @@ torch::Tensor marlin_gemm_moe(
moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
thread_n, sms, max_par, replicate_input, apply_weights);
return c;
}
}
2 changes: 1 addition & 1 deletion csrc/moe/marlin_moe_ops.h
@@ -9,4 +9,4 @@ torch::Tensor marlin_gemm_moe(
const torch::Tensor& g_idx, const torch::Tensor& perm,
torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size,
bool replicate_input, bool apply_weights);
bool replicate_input, bool apply_weights);
1 change: 0 additions & 1 deletion csrc/moe/torch_bindings.cpp
@@ -16,7 +16,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
"size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
"bool replicate_input, bool apply_weights) -> Tensor");

m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
#endif
}
1 change: 1 addition & 0 deletions docs/source/community/meetups.rst
@@ -5,6 +5,7 @@ vLLM Meetups

We host regular meetups in the San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:

- `The sixth vLLM meetup <https://lu.ma/87q3nvnh>`__, with NVIDIA, September 9th 2024. `[Slides] <https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing>`__
- `The fifth vLLM meetup <https://lu.ma/lp0gyjqr>`__, with AWS, July 24th 2024. `[Slides] <https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing>`__
- `The fourth vLLM meetup <https://lu.ma/agivllm>`__, with Cloudflare and BentoML, June 11th 2024. `[Slides] <https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing>`__
- `The third vLLM meetup <https://robloxandvllmmeetup2024.splashthat.com/>`__, with Roblox, April 2nd 2024. `[Slides] <https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing>`__
221 changes: 217 additions & 4 deletions tests/kernels/test_moe.py
@@ -2,14 +2,22 @@
Run `pytest tests/kernels/test_moe.py`.
"""
from typing import List

import pytest
import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.scalar_type import scalar_types


def torch_moe(a, w1, w2, score, topk):
@@ -29,6 +37,20 @@ def torch_moe(a, w1, w2, score, topk):
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


def torch_moe_single(a, w, score, topk):
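    # Reference MoE with a single expert matmul: route each token to its
    # top-k experts, apply that expert's weight, and sum the k copies
    # (the gating weights are not applied here).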
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
_, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.view(-1)
for i in range(w.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = a[mask] @ w[i].transpose(0, 1)
return (out.view(B, -1, w.shape[1])).sum(dim=1)


@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
@@ -43,11 +65,11 @@ def test_fused_moe(
topk: int,
dtype: torch.dtype,
):
a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

score = torch.randn((m, e), device='cuda', dtype=dtype)
score = torch.randn((m, e), device="cuda", dtype=dtype)
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk)
torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
@@ -99,3 +121,194 @@ def test_mixtral_moe(dtype: torch.dtype):
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])


def stack_and_dev(tensors: List[torch.Tensor]):
dev = tensors[0].device
return torch.stack(tensors, dim=0).to(dev)


def compute_max_diff(output, output_ref):
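    # Normalized mean absolute error: mean(|output - output_ref|) / mean(|output_ref|).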
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref))


@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
def test_fused_marlin_moe(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
act_order: bool,
):
torch.manual_seed(7)

if topk > e:
return

# Filter act_order
if act_order:
if group_size == -1:
return
if group_size in (k, n):
return

quant_type = scalar_types.uint4b8
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    for i in range(w2.shape[0]):
        w2[i] = torch.eye(k, n, device="cuda", dtype=dtype)

w_ref1_l = []
qweight1_l = []
scales1_l = []
g_idx1_l = []
sort_indices1_l = []

for i in range(w1.shape[0]):
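        # Marlin-quantize each expert of w1 separately; the reference weights
        # (w_ref1) feed the Triton fused_moe baseline below.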
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref1_l.append(w_ref1)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)

w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
g_idx1 = stack_and_dev(g_idx1_l)
sort_indices1 = stack_and_dev(sort_indices1_l)

w_ref2_l = []
qweight2_l = []
scales2_l = []
g_idx2_l = []
sort_indices2_l = []

for i in range(w2.shape[0]):
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref2_l.append(w_ref2)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)

w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
g_idx2 = stack_and_dev(g_idx2_l)
sort_indices2 = stack_and_dev(sort_indices2_l)

score = torch.randn((m, e), device="cuda", dtype=dtype)

topk_weights, topk_ids = fused_topk(a, score, topk, False)

triton_output = fused_moe(
a,
w_ref1.transpose(1, 2).contiguous(),
w_ref2.transpose(1, 2).contiguous(),
score,
topk,
renormalize=False,
)
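    # Marlin path: quantized expert weights plus the precomputed top-k routing
    # (topk_weights/topk_ids from fused_topk above).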
marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,
score,
g_idx1,
g_idx2,
sort_indices1,
sort_indices2,
topk_weights,
topk_ids,
w1_scale=scales1,
w2_scale=scales2,
)

assert compute_max_diff(marlin_output, triton_output) < 4e-2


@pytest.mark.skip("This test is here for the sake of debugging, "
"don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
def test_marlin_moe_mmm(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
act_order: bool,
):
if topk > e:
return

# Filter act_order
if act_order:
if group_size == -1:
return
if group_size == k:
return

quant_type = scalar_types.uint4b8
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

w_ref_l = []
qweights_l = []
scales_l = []
g_idx_l = []
sort_indices_l = []

for i in range(w.shape[0]):
test_perm = torch.randperm(k)
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm)
w_ref_l.append(w_ref)
qweights_l.append(qweight)
scales_l.append(scales)
g_idx_l.append(g_idx)
sort_indices_l.append(sort_indices)

w_ref = stack_and_dev(w_ref_l)
qweight = stack_and_dev(qweights_l).contiguous()
scales = stack_and_dev(scales_l)
g_idx = stack_and_dev(g_idx_l)
sort_indices = stack_and_dev(sort_indices_l)

score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = single_marlin_moe(a,
qweight,
scales,
score,
g_idx,
sort_indices,
topk,
renormalize=False)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

assert compute_max_diff(marlin_output, torch_output) < 1e-2
2 changes: 1 addition & 1 deletion tests/tool_use/utils.py
@@ -19,7 +19,7 @@ class ServerConfig(TypedDict):
CONFIGS: Dict[str, ServerConfig] = {
"hermes": {
"model":
"NousResearch/Hermes-2-Pro-Llama-3-8B",
"NousResearch/Hermes-3-Llama-3.1-8B",
"arguments": [
"--tool-call-parser", "hermes", "--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja")
3 changes: 3 additions & 0 deletions tests/weight_loading/models-large.txt
@@ -0,0 +1,3 @@
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
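
Each entry in these model lists is a comma-separated triple of quantization method, Hugging Face model ID, and revision; the new `-c` flag in the Buildkite steps above selects which list `run_model_weight_loading_test.sh` iterates over. Below is a minimal, hypothetical Python sketch of a consumer for this format (the real driver is the shell script; the names here are illustrative only):

from dataclasses import dataclass
from pathlib import Path
from typing import List


@dataclass
class WeightLoadingCase:
    quantization: str  # e.g. "gptq_marlin"
    model: str         # Hugging Face model ID
    revision: str      # git revision, e.g. "main"


def load_cases(config_path: str) -> List[WeightLoadingCase]:
    # Parse "<quantization>, <model>, <revision>" lines, skipping blanks and comments.
    cases = []
    for raw in Path(config_path).read_text().splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        quantization, model, revision = (part.strip() for part in line.split(","))
        cases.append(WeightLoadingCase(quantization, model, revision))
    return cases


if __name__ == "__main__":
    for case in load_cases("tests/weight_loading/models-large.txt"):
        print(f"{case.quantization}: {case.model}@{case.revision}")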
3 changes: 1 addition & 2 deletions tests/weight_loading/models.txt
@@ -19,8 +19,7 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
7 changes: 0 additions & 7 deletions vllm/entrypoints/openai/protocol.py
@@ -713,13 +713,6 @@ class DeltaToolCall(OpenAIBaseModel):
function: Optional[DeltaFunctionCall] = None


# the initial delta that gets sent once a new tool call is started;
class InitialDeltaToolCall(DeltaToolCall):
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
type: Literal["function"] = "function"
index: int


class ExtractedToolCallInformation(BaseModel):
# indicate if tools were called
tools_called: bool
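
For context, the removed InitialDeltaToolCall carried what the first streamed chunk of a tool call needs: a generated `chatcmpl-tool-...` id, the literal type `"function"`, and the tool call's index. The following is a self-contained, hypothetical sketch of building such an initial delta with plain Pydantic models; these are not vLLM's actual classes (the remaining fields of DeltaToolCall are not shown in this hunk), and the field layout follows the OpenAI streaming tool-call format:

import uuid
from typing import Literal, Optional

from pydantic import BaseModel


class DeltaFunctionCall(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


class DeltaToolCall(BaseModel):
    id: Optional[str] = None
    type: Optional[Literal["function"]] = None
    index: int = 0
    function: Optional[DeltaFunctionCall] = None


def initial_tool_call_delta(index: int, function_name: str) -> DeltaToolCall:
    # Mirrors the fields the removed InitialDeltaToolCall populated:
    # a fresh "chatcmpl-tool-..." id, type "function", and the call index.
    return DeltaToolCall(
        id=f"chatcmpl-tool-{uuid.uuid4().hex}",
        type="function",
        index=index,
        function=DeltaFunctionCall(name=function_name, arguments=""),
    )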