
[Feat] Support update weights without restart server #1157

Merged · 12 commits · Aug 20, 2024
5 changes: 5 additions & 0 deletions python/sglang/srt/managers/detokenizer_manager.py
@@ -28,6 +28,7 @@
BatchEmbeddingOut,
BatchStrOut,
BatchTokenIDOut,
UpdateWeightReqOutput,
)
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
from sglang.srt.server_args import PortArgs, ServerArgs
@@ -84,6 +85,10 @@ async def handle_loop(self):
)
continue

if isinstance(recv_obj, UpdateWeightReqOutput):
self.send_to_tokenizer.send_pyobj(recv_obj)
continue

assert isinstance(recv_obj, BatchTokenIDOut)
bs = len(recv_obj.rids)

14 changes: 14 additions & 0 deletions python/sglang/srt/managers/io_struct.py
@@ -278,6 +278,20 @@ class FlushCacheReq:
pass


@dataclass
class UpdateWeightReqInput:
# The model path with the new weights
model_path: str
# The format to load the weights
load_format: Optional[str] = None


@dataclass
class UpdateWeightReqOutput:
success: bool
message: str


@dataclass
class AbortReq:
# The request id
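For reference (not part of the diff), the HTTP layer deserializes a JSON body into `UpdateWeightReqInput`. A minimal sketch, using a placeholder checkpoint path:

```python
from sglang.srt.managers.io_struct import UpdateWeightReqInput

# Placeholder path; when load_format is left as None, the tokenizer
# manager falls back to the server's configured --load-format.
req = UpdateWeightReqInput(model_path="/models/my-finetuned-checkpoint")
assert req.load_format is None
```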
45 changes: 42 additions & 3 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -46,6 +46,8 @@
GenerateReqInput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UpdateWeightReqInput,
UpdateWeightReqOutput,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams
@@ -121,6 +123,10 @@ def __init__(
self.to_create_loop = True
self.rid_to_state: Dict[str, ReqState] = {}

# for update model weights
self.model_update_lock = asyncio.Lock()
self.model_update_result = None

async def get_pixel_values(self, image_data):
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
grid_pinpoints = (
@@ -146,6 +152,9 @@ async def generate_request(
if self.to_create_loop:
self.create_handle_loop()

while self.model_update_lock.locked():
await asyncio.sleep(0.001)

obj.post_init()
is_single = obj.is_single

@@ -500,6 +509,30 @@ def flush_cache(self):
req = FlushCacheReq()
self.send_to_router.send_pyobj(req)

async def update_weights(self, obj: UpdateWeightReqInput, request):
if self.to_create_loop:
self.create_handle_loop()

# default the load format to the server_args
if obj.load_format is None:
obj.load_format = self.server_args.load_format

if not self.model_update_lock.locked():
async with self.model_update_lock:
# wait for the previous generation requests to finish
while len(self.rid_to_state) > 0:
await asyncio.sleep(0.1)
self.send_to_router.send_pyobj(obj)
self.model_update_result = asyncio.Future()
result = await self.model_update_result
if result.success:
self.server_args.model_path = obj.model_path
self.server_args.load_format = obj.load_format
self.model_path = obj.model_path
return result.success, result.message
else:
return False, "Another update is in progress. Please try again later."

def abort_request(self, rid: str):
if rid not in self.rid_to_state:
return
@@ -528,12 +561,18 @@ def create_handle_loop(self):

async def handle_loop(self):
while True:
-            recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] = (
-                await self.recv_from_detokenizer.recv_pyobj()
-            )
+            recv_obj: Union[
+                BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
+            ] = await self.recv_from_detokenizer.recv_pyobj()

if isinstance(recv_obj, UpdateWeightReqOutput):
self.model_update_result.set_result(recv_obj)
continue

assert isinstance(
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
), f"Unexpected obj received: {type(recv_obj)}"

for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
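Aside (not part of the diff): `update_weights` above serializes weight updates against in-flight generation with a drain-then-swap pattern — new requests spin while `model_update_lock` is held, and the updater waits for `rid_to_state` to empty before swapping. A distilled, self-contained sketch of the same pattern:

```python
import asyncio

lock = asyncio.Lock()  # stands in for model_update_lock
in_flight: dict = {}   # stands in for rid_to_state

async def generate(rid: str):
    # New work blocks while an update holds the lock.
    while lock.locked():
        await asyncio.sleep(0.001)
    in_flight[rid] = object()
    try:
        await asyncio.sleep(0.01)  # stand-in for real generation
    finally:
        del in_flight[rid]

async def update_weights():
    if lock.locked():
        return False, "Another update is in progress. Please try again later."
    async with lock:
        # Drain in-flight requests before touching the weights.
        while in_flight:
            await asyncio.sleep(0.1)
        # ... swap weights here ...
        return True, "ok"

async def main():
    print(await asyncio.gather(generate("req-1"), update_weights()))

asyncio.run(main())
```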
13 changes: 13 additions & 0 deletions python/sglang/srt/managers/tp_worker.py
@@ -39,6 +39,8 @@
FlushCacheReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UpdateWeightReqInput,
UpdateWeightReqOutput,
)
from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
from sglang.srt.managers.schedule_batch import (
@@ -214,6 +216,9 @@ def exposed_step(self, recv_reqs: List):
self.flush_cache()
elif isinstance(recv_req, AbortReq):
self.abort_request(recv_req)
elif isinstance(recv_req, UpdateWeightReqInput):
success, message = self.update_weights(recv_req)
self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
else:
raise ValueError(f"Invalid request: {recv_req}")

@@ -798,6 +803,14 @@ def abort_request(self, recv_req):
req.finished_reason = FINISH_ABORT()
break

def update_weights(self, recv_req):
success, message = self.model_runner.update_weights(
recv_req.model_path, recv_req.load_format
)
if success:
self.flush_cache()
return success, message


def run_tp_server(
gpu_id: int,
106 changes: 97 additions & 9 deletions python/sglang/srt/model_executor/model_runner.py
@@ -15,6 +15,7 @@

"""ModelRunner runs the forward passes of the models."""

import gc
import importlib
import importlib.resources
import logging
@@ -155,9 +156,9 @@ def load_model(self):
self.server_args.dtype = "float16"

monkey_patch_vllm_dummy_weight_loader()
-        device_config = DeviceConfig()
-        load_config = LoadConfig(load_format=self.server_args.load_format)
-        vllm_model_config = VllmModelConfig(
+        self.device_config = DeviceConfig()
+        self.load_config = LoadConfig(load_format=self.server_args.load_format)
+        self.vllm_model_config = VllmModelConfig(
model=self.server_args.model_path,
quantization=self.server_args.quantization,
tokenizer=None,
@@ -171,17 +172,19 @@
if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
self.model_config.hf_config.num_key_value_heads = 8
-            vllm_model_config.hf_config.num_key_value_heads = 8
+            self.vllm_model_config.hf_config.num_key_value_heads = 8
monkey_patch_vllm_qvk_linear_loader()

-        self.dtype = vllm_model_config.dtype
+        self.dtype = self.vllm_model_config.dtype
if self.model_config.model_overide_args is not None:
-            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
+            self.vllm_model_config.hf_config.update(
+                self.model_config.model_overide_args
+            )

self.model = get_model(
-            model_config=vllm_model_config,
-            device_config=device_config,
-            load_config=load_config,
+            model_config=self.vllm_model_config,
+            device_config=self.device_config,
+            load_config=self.load_config,
lora_config=None,
multimodal_config=None,
parallel_config=None,
@@ -204,6 +207,91 @@
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)

def update_weights(self, model_path, load_format):
from vllm.model_executor.model_loader.loader import (
DefaultModelLoader,
device_loading_context,
get_model_loader,
)
from vllm.model_executor.model_loader.utils import set_default_torch_dtype

logger.info(
f"[gpu={self.gpu_id}] Update weights begin. "
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)

target_device = torch.device(self.device_config.device)

try:
vllm_model_config = VllmModelConfig(
model=model_path,
quantization=self.server_args.quantization,
tokenizer=None,
tokenizer_mode=None,
trust_remote_code=self.server_args.trust_remote_code,
dtype=self.server_args.dtype,
seed=42,
skip_tokenizer_init=True,
)
except Exception as e:
logger.error(f"Failed to load model config: {e}")
return False, "Failed to update model weights"

load_config = LoadConfig(load_format=load_format)

# Only support vllm DefaultModelLoader for now
loader = get_model_loader(load_config)
if not isinstance(loader, DefaultModelLoader):
logger.error("Failed to get weights iterator: Unsupported loader")
return False, "Failed to update model weights"

def get_weight_iter(config):
iter = loader._get_weights_iterator(
config.model,
config.revision,
fall_back_to_pt=getattr(
self.model, "fall_back_to_pt_during_load", True
),
)
return iter

def model_load_weights(model, iter):
model.load_weights(iter)
for _, module in self.model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
return model

with set_default_torch_dtype(vllm_model_config.dtype):
try:
iter = get_weight_iter(vllm_model_config)
except Exception as e:
logger.error(f"Failed to get weights iterator: {e}")
return False, "Failed to update model"
try:
model = model_load_weights(self.model, iter)
except Exception as e:
logger.error(
f"Failed to update weights: {e}. \n Rolling back to original weights"
)
del iter
gc.collect()
iter = get_weight_iter(self.vllm_model_config)
self.model = model_load_weights(self.model, iter)
return False, "Failed to update model"

self.model = model
self.server_args.model_path = model_path
self.server_args.load_format = load_format
self.vllm_model_config = vllm_model_config
self.load_config = load_config
self.model_config.path = model_path

logger.info(f"[gpu={self.gpu_id}] Update weights end.")
return True, "Updating model weights succeeded"

def profile_max_num_token(self, total_gpu_memory):
available_gpu_memory = get_available_gpu_memory(
self.gpu_id, distributed=self.tp_size > 1
23 changes: 22 additions & 1 deletion python/sglang/srt/server.py
@@ -51,7 +51,11 @@
start_controller_process as start_controller_process_single,
)
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
+from sglang.srt.managers.io_struct import (
+    EmbeddingReqInput,
+    GenerateReqInput,
+    UpdateWeightReqInput,
+)
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
load_chat_template_for_openai_api,
@@ -119,6 +123,23 @@ async def flush_cache():
)


@app.post("/update_weights")
async def update_weights(obj: UpdateWeightReqInput, request: Request):
    success, message = await tokenizer_manager.update_weights(obj, request)
content = {"message": message, "success": str(success)}
if success:
return JSONResponse(
content,
status_code=HTTPStatus.OK,
)
else:
return JSONResponse(
content,
status_code=HTTPStatus.BAD_REQUEST,
)


async def generate_request(obj: GenerateReqInput, request: Request):
"""Handle a generate request."""
if obj.stream:
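For context (not part of the diff): once merged, the new endpoint can be exercised with a plain HTTP POST. A minimal client sketch, assuming a server listening on the default port 30000 and a placeholder model path:

```python
import requests

# Hypothetical checkpoint path; load_format is optional and defaults to
# the server's --load-format when omitted from the JSON body.
response = requests.post(
    "http://localhost:30000/update_weights",
    json={"model_path": "/models/my-finetuned-checkpoint"},
)
print(response.status_code, response.json())
# 200 on success; 400 with an explanatory message on failure.
```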