Support Deepseek MoE Model (#689)
hnyls2002 authored Jul 21, 2024
1 parent 5a4ef2b commit eedc12e
Showing 5 changed files with 519 additions and 23 deletions.
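
Context for the diffs below: per the commit title, this change adds support for the DeepSeek MoE model (the new model definition is among the changed files not shown on this page). A hedged usage sketch only, assuming the `sgl.Runtime` frontend API and the `deepseek-ai/deepseek-moe-16b-chat` checkpoint name from public sglang documentation of this era; neither appears in the diffs here.

import sglang as sgl

# Hypothetical usage sketch: the model path and frontend calls are assumptions
# taken from public docs, not from this commit's diff.
runtime = sgl.Runtime(model_path="deepseek-ai/deepseek-moe-16b-chat")
sgl.set_default_backend(runtime)


@sgl.function
def chat(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=64))


state = chat.run(question="What is a mixture-of-experts model?")
print(state["answer"])
runtime.shutdown()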
57 changes: 38 additions & 19 deletions python/sglang/srt/managers/controller/cuda_graph_runner.py
@@ -1,6 +1,7 @@
 """Run the model with cuda graph."""

 import bisect
+from contextlib import contextmanager

 import torch
 from flashinfer import BatchDecodeWithPagedKVCacheWrapper
@@ -15,9 +16,10 @@
     InputMetadata,
     init_flashinfer_args,
 )
+from sglang.srt.utils import monkey_patch_vllm_all_gather


-def _to_torch(model: torch.nn.Module, reverse=False):
+def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
@@ -28,13 +30,26 @@ def _to_torch(model: torch.nn.Module, reverse=False):
             _to_torch(sub, reverse)


-def get_forward(model: torch.nn.Module, use_torch: bool):
-    if use_torch:
-        _to_torch(model, reverse=False)
-        return torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
-    else:
-        _to_torch(model, reverse=True)
-        return model.forward
+@contextmanager
+def patch_model(
+    model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator"
+):
+    backup_ca_comm = None
+
+    try:
+        if use_compile:
+            _to_torch(model)
+            monkey_patch_vllm_all_gather()
+            backup_ca_comm = tp_group.ca_comm
+            tp_group.ca_comm = None
+            yield torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
+        else:
+            yield model.forward
+    finally:
+        if use_compile:
+            _to_torch(model, reverse=True)
+            monkey_patch_vllm_all_gather(reverse=True)
+            tp_group.ca_comm = backup_ca_comm


 class CudaGraphRunner:
@@ -86,17 +101,21 @@ def capture(self, batch_size_list):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
             for bs in batch_size_list:
-                forward = get_forward(self.model_runner.model, bs in self.compile_bs)
-                (
-                    graph,
-                    input_buffers,
-                    output_buffers,
-                    flashinfer_handler,
-                ) = self.capture_one_batch_size(bs, forward)
-                self.graphs[bs] = graph
-                self.input_buffers[bs] = input_buffers
-                self.output_buffers[bs] = output_buffers
-                self.flashinfer_handlers[bs] = flashinfer_handler
+                with patch_model(
+                    self.model_runner.model,
+                    bs in self.compile_bs,
+                    self.model_runner.tp_group,
+                ) as forward:
+                    (
+                        graph,
+                        input_buffers,
+                        output_buffers,
+                        flashinfer_handler,
+                    ) = self.capture_one_batch_size(bs, forward)
+                    self.graphs[bs] = graph
+                    self.input_buffers[bs] = input_buffers
+                    self.output_buffers[bs] = output_buffers
+                    self.flashinfer_handlers[bs] = flashinfer_handler

     def capture_one_batch_size(self, bs, forward):
         graph = torch.cuda.CUDAGraph()
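
The `patch_model` context manager added above ties its temporary state changes (switching `CustomOp` dispatch via `_to_torch`, monkey-patching vLLM's all-gather, and clearing `tp_group.ca_comm`) to a `with` block, so the `finally` clause restores everything even if graph capture fails partway through. Below is a minimal, self-contained sketch of that pattern only; `DummyGroup`, `patch_model_sketch`, and the toy linear model are hypothetical stand-ins, and the vLLM/sglang-specific patches are omitted.

from contextlib import contextmanager

import torch


class DummyGroup:
    """Stand-in for vLLM's GroupCoordinator; only the attribute we toggle."""

    def __init__(self):
        self.ca_comm = object()  # pretend custom all-reduce communicator


@contextmanager
def patch_model_sketch(model: torch.nn.Module, use_compile: bool, tp_group: DummyGroup):
    backup_ca_comm = None
    try:
        if use_compile:
            # Disable the custom all-reduce path while a compiled forward is in use.
            backup_ca_comm = tp_group.ca_comm
            tp_group.ca_comm = None
            yield torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
        else:
            yield model.forward
    finally:
        if use_compile:
            # Restore the original state no matter what happened inside the block.
            tp_group.ca_comm = backup_ca_comm


model = torch.nn.Linear(4, 4)
group = DummyGroup()
with patch_model_sketch(model, use_compile=False, tp_group=group) as forward:
    out = forward(torch.randn(2, 4))
assert group.ca_comm is not None  # custom all-reduce handle restored

With `use_compile=True`, the caller would receive a `torch.compile`d forward instead, and the communicator would still be restored on exit.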
7 changes: 4 additions & 3 deletions python/sglang/srt/managers/controller/model_runner.py
@@ -22,7 +22,6 @@
     init_distributed_environment,
     initialize_model_parallel,
 )
-from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
@@ -241,7 +240,9 @@ def init_cuda_graphs(self):
             self.cuda_graph_runner = None
             return

-        logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
+        )
         batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
         self.cuda_graph_runner = CudaGraphRunner(
             self,
@@ -252,7 +253,7 @@ def init_cuda_graphs(self):
             self.cuda_graph_runner.capture(batch_size_list)
         except RuntimeError as e:
             raise Exception(
-                f"Capture cuda graph failed {e}. Possible solutions:\n"
+                f"Capture cuda graph failed: {e}. Possible solutions:\n"
                 f"1. disable cuda graph by --disable-cuda-graph\n"
                 f"2. set --mem-fraction-static to a smaller value\n"
                 f"Open an issue on GitHub with reproducible scripts if you need help.\n"
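
For reference, the `batch_size_list` built in `init_cuda_graphs` above expands to 19 capture sizes: 1, 2, 4, and then every multiple of 8 up to 128. A quick check in plain Python (nothing sglang-specific):

# Same expression as in init_cuda_graphs above.
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]

print(batch_size_list)
# [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128]
print(len(batch_size_list))  # 19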
(Diffs for the remaining 3 changed files are not shown here.)
