From e7318dbcb2a65377b8e62490d3a082a4ad33dce3 Mon Sep 17 00:00:00 2001 From: Shan Yu Date: Sun, 18 Aug 2024 23:37:10 +0000 Subject: [PATCH 01/10] support update model weights --- .../srt/managers/detokenizer_manager.py | 5 + python/sglang/srt/managers/io_struct.py | 11 +++ .../sglang/srt/managers/tokenizer_manager.py | 39 +++++++- python/sglang/srt/managers/tp_worker.py | 11 +++ .../sglang/srt/model_executor/model_runner.py | 98 +++++++++++++++++-- python/sglang/srt/server.py | 19 ++++ 6 files changed, 173 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 12511ac44e5..9d68f6a53e0 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -28,6 +28,7 @@ BatchEmbeddingOut, BatchStrOut, BatchTokenIDOut, + UpdateWeightReqOutput ) from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR from sglang.srt.server_args import PortArgs, ServerArgs @@ -84,6 +85,10 @@ async def handle_loop(self): ) continue + if isinstance(recv_obj, UpdateWeightReqOutput): + self.send_to_tokenizer.send_pyobj(recv_obj) + continue + assert isinstance(recv_obj, BatchTokenIDOut) bs = len(recv_obj.rids) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 82f280b6062..9c3a9d09bf2 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -278,6 +278,17 @@ class FlushCacheReq: pass +@dataclass +class UpdateWeightReqInput: + model_path: str + load_format: str + + +@dataclass +class UpdateWeightReqOutput: + success: bool + message: str + @dataclass class AbortReq: # The request id diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index d5fbfe05d3b..1685f07e451 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -46,6 +46,8 @@ GenerateReqInput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, + UpdateWeightReqInput, + UpdateWeightReqOutput ) from sglang.srt.mm_utils import expand2square, process_anyres_image from sglang.srt.sampling_params import SamplingParams @@ -65,6 +67,10 @@ class ReqState: event: asyncio.Event +model_update_lock = asyncio.Lock() +model_update_result = None + + class TokenizerManager: def __init__( self, @@ -146,6 +152,11 @@ async def generate_request( if self.to_create_loop: self.create_handle_loop() + global model_update_lock + + while model_update_lock.locked(): + await asyncio.sleep(0.001) + obj.post_init() is_single = obj.is_single @@ -493,6 +504,26 @@ def flush_cache(self): req = FlushCacheReq() self.send_to_router.send_pyobj(req) + async def update_weights(self, obj: UpdateWeightReqInput, request): + global model_update_lock + global model_update_result + if self.to_create_loop: + self.create_handle_loop() + if not model_update_lock.locked(): + async with model_update_lock: + while len(self.rid_to_state) > 0: + await asyncio.sleep(1) + self.send_to_router.send_pyobj(obj) + model_update_result = asyncio.Future() + result = await model_update_result + if result.success: + self.server_args.model_path = obj.model_path + self.server_args.load_format = obj.load_format + + return result.success, result.message + else: + return False, "Another update is in progress. Please try again later." 
+ def abort_request(self, rid: str): if rid not in self.rid_to_state: return @@ -521,12 +552,18 @@ def create_handle_loop(self): async def handle_loop(self): while True: - recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] = ( + recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput] = ( await self.recv_from_detokenizer.recv_pyobj() ) + + if isinstance(recv_obj, UpdateWeightReqOutput): + model_update_result.set_result(recv_obj) + continue + assert isinstance( recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut) ), f"Unexpected obj received: {type(recv_obj)}" + for i, rid in enumerate(recv_obj.rids): state = self.rid_to_state.get(rid, None) if state is None: diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index b6cfa68bd4a..c911f9f99c8 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -39,6 +39,8 @@ FlushCacheReq, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, + UpdateWeightReqInput, + UpdateWeightReqOutput ) from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder from sglang.srt.managers.schedule_batch import ( @@ -214,6 +216,9 @@ def exposed_step(self, recv_reqs: List): self.flush_cache() elif isinstance(recv_req, AbortReq): self.abort_request(recv_req) + elif isinstance(recv_req, UpdateWeightReqInput): + success, message = self.update_weights(recv_req) + self.out_pyobjs.append(UpdateWeightReqOutput(success, message)) else: raise ValueError(f"Invalid request: {recv_req}") @@ -797,6 +802,12 @@ def abort_request(self, recv_req): if req.rid == recv_req.rid: req.finished_reason = FINISH_ABORT() break + + def update_weights(self, recv_req): + success, message = self.model_runner.update_weights(recv_req.model_path, recv_req.load_format) + if success: + self.flush_cache() + return success, message def run_tp_server( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b74a19e60df..014840baa90 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -22,6 +22,7 @@ import warnings from functools import lru_cache from typing import Optional, Type +import gc import torch import torch.nn as nn @@ -155,9 +156,9 @@ def load_model(self): self.server_args.dtype = "float16" monkey_patch_vllm_dummy_weight_loader() - device_config = DeviceConfig() - load_config = LoadConfig(load_format=self.server_args.load_format) - vllm_model_config = VllmModelConfig( + self.device_config = DeviceConfig() + self.load_config = LoadConfig(load_format=self.server_args.load_format) + self.vllm_model_config = VllmModelConfig( model=self.server_args.model_path, quantization=self.server_args.quantization, tokenizer=None, @@ -171,17 +172,17 @@ def load_model(self): if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8: # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints self.model_config.hf_config.num_key_value_heads = 8 - vllm_model_config.hf_config.num_key_value_heads = 8 + self.vllm_model_config.hf_config.num_key_value_heads = 8 monkey_patch_vllm_qvk_linear_loader() - self.dtype = vllm_model_config.dtype + self.dtype = self.vllm_model_config.dtype if self.model_config.model_overide_args is not None: - vllm_model_config.hf_config.update(self.model_config.model_overide_args) + self.vllm_model_config.hf_config.update(self.model_config.model_overide_args) self.model = 
get_model( - model_config=vllm_model_config, - device_config=device_config, - load_config=load_config, + model_config=self.vllm_model_config, + device_config=self.device_config, + load_config=self.load_config, lora_config=None, multimodal_config=None, parallel_config=None, @@ -204,6 +205,85 @@ def load_model(self): f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" ) + def update_weights(self, model_path, load_format): + from vllm.model_executor.model_loader.utils import set_default_torch_dtype + from vllm.model_executor.model_loader.loader import get_model_loader + from vllm.model_executor.model_loader.loader import device_loading_context + from vllm.model_executor.model_loader.loader import DefaultModelLoader + from vllm.model_executor.model_loader.utils import get_model_architecture + target_device = torch.device(self.device_config.device) + + try: + model_config = VllmModelConfig( + model=model_path, + quantization=self.server_args.quantization, + tokenizer=None, + tokenizer_mode=None, + trust_remote_code=self.server_args.trust_remote_code, + dtype=self.server_args.dtype, + seed=42, + skip_tokenizer_init=True, + ) + except Exception as e: + logger.error(f"Failed to load model config: {e}") + return False, "Failed to update model weights" + + logger.info("start updating weights") + + load_config = LoadConfig(load_format=self.server_args.load_format) + loader = get_model_loader(load_config) + if not isinstance(loader, DefaultModelLoader): + logger.error("Failed to get weights iterator: Unsupported loader") + return False, "Failed to update model weights" + + + def get_weight_iter(config): + iter = loader._get_weights_iterator(config.model, + config.revision, + fall_back_to_pt=getattr( + self.model, + "fall_back_to_pt_during_load", + True)) + return iter + + def model_load_weights(model, iter): + model.load_weights(iter) + for _, module in self.model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + return model + + with set_default_torch_dtype(model_config.dtype): + try: + iter = get_weight_iter(model_config) + except Exception as e: + logger.error(f"Failed to get weights iterator: {e}") + return False, "Failed to update model" + try: + model = model_load_weights(self.model, iter) + except Exception as e: + logger.error(f"Failed to update weights: {e}. 
\n Rolling back to original weights") + del iter + gc.collect() + iter = get_weight_iter(self.vllm_model_config) + self.model = model_load_weights(self.model, iter) + return False, "Failed to update model" + + self.server_args.model_path = model_path + self.server_args.load_format = load_format + self.vllm_model_config = model_config + self.load_config = load_config + + logger.info("finish updating weights") + return True, "Updating model weights succeeded" + def profile_max_num_token(self, total_gpu_memory): available_gpu_memory = get_available_gpu_memory( self.gpu_id, distributed=self.tp_size > 1 diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 9028c12309b..bcca6d93223 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -52,6 +52,7 @@ ) from sglang.srt.managers.detokenizer_manager import start_detokenizer_process from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput +from sglang.srt.managers.io_struct import UpdateWeightReqInput from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.openai_api.adapter import ( load_chat_template_for_openai_api, @@ -118,6 +119,24 @@ async def flush_cache(): status_code=200, ) +@app.post("/update_weights") +async def update_weights(obj: UpdateWeightReqInput, request: Request): + + success, message = await tokenizer_manager.update_weights(obj, request) + content = {"message": message, "success": str(success)} + print(content) + if success: + return JSONResponse( + content, + status_code=HTTPStatus.OK, + ) + else: + return JSONResponse( + content, + status_code=HTTPStatus.BAD_REQUEST, + ) + + async def generate_request(obj: GenerateReqInput, request: Request): """Handle a generate request.""" From 8657c04407fd844af41c3d7acaee157bcd45fd14 Mon Sep 17 00:00:00 2001 From: Shan Yu Date: Sun, 18 Aug 2024 23:37:41 +0000 Subject: [PATCH 02/10] fix format --- .../srt/managers/detokenizer_manager.py | 2 +- python/sglang/srt/managers/io_struct.py | 1 + .../sglang/srt/managers/tokenizer_manager.py | 8 ++-- python/sglang/srt/managers/tp_worker.py | 8 ++-- .../sglang/srt/model_executor/model_runner.py | 45 +++++++++++-------- python/sglang/srt/server.py | 15 ++++--- 6 files changed, 47 insertions(+), 32 deletions(-) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 9d68f6a53e0..e1402795fbf 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -28,7 +28,7 @@ BatchEmbeddingOut, BatchStrOut, BatchTokenIDOut, - UpdateWeightReqOutput + UpdateWeightReqOutput, ) from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR from sglang.srt.server_args import PortArgs, ServerArgs diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 9c3a9d09bf2..08f02afda5d 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -289,6 +289,7 @@ class UpdateWeightReqOutput: success: bool message: str + @dataclass class AbortReq: # The request id diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 1685f07e451..efad297ab01 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -47,7 +47,7 @@ TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, UpdateWeightReqInput, - UpdateWeightReqOutput + UpdateWeightReqOutput, ) from 
sglang.srt.mm_utils import expand2square, process_anyres_image from sglang.srt.sampling_params import SamplingParams @@ -552,9 +552,9 @@ def create_handle_loop(self): async def handle_loop(self): while True: - recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput] = ( - await self.recv_from_detokenizer.recv_pyobj() - ) + recv_obj: Union[ + BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput + ] = await self.recv_from_detokenizer.recv_pyobj() if isinstance(recv_obj, UpdateWeightReqOutput): model_update_result.set_result(recv_obj) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index c911f9f99c8..a83ad2378fd 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -40,7 +40,7 @@ TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, UpdateWeightReqInput, - UpdateWeightReqOutput + UpdateWeightReqOutput, ) from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder from sglang.srt.managers.schedule_batch import ( @@ -802,9 +802,11 @@ def abort_request(self, recv_req): if req.rid == recv_req.rid: req.finished_reason = FINISH_ABORT() break - + def update_weights(self, recv_req): - success, message = self.model_runner.update_weights(recv_req.model_path, recv_req.load_format) + success, message = self.model_runner.update_weights( + recv_req.model_path, recv_req.load_format + ) if success: self.flush_cache() return success, message diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 014840baa90..8a130d56cd3 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -15,6 +15,7 @@ """ModelRunner runs the forward passes of the models.""" +import gc import importlib import importlib.resources import logging @@ -22,7 +23,6 @@ import warnings from functools import lru_cache from typing import Optional, Type -import gc import torch import torch.nn as nn @@ -177,7 +177,9 @@ def load_model(self): self.dtype = self.vllm_model_config.dtype if self.model_config.model_overide_args is not None: - self.vllm_model_config.hf_config.update(self.model_config.model_overide_args) + self.vllm_model_config.hf_config.update( + self.model_config.model_overide_args + ) self.model = get_model( model_config=self.vllm_model_config, @@ -206,11 +208,16 @@ def load_model(self): ) def update_weights(self, model_path, load_format): - from vllm.model_executor.model_loader.utils import set_default_torch_dtype - from vllm.model_executor.model_loader.loader import get_model_loader - from vllm.model_executor.model_loader.loader import device_loading_context - from vllm.model_executor.model_loader.loader import DefaultModelLoader - from vllm.model_executor.model_loader.utils import get_model_architecture + from vllm.model_executor.model_loader.loader import ( + DefaultModelLoader, + device_loading_context, + get_model_loader, + ) + from vllm.model_executor.model_loader.utils import ( + get_model_architecture, + set_default_torch_dtype, + ) + target_device = torch.device(self.device_config.device) try: @@ -229,23 +236,23 @@ def update_weights(self, model_path, load_format): return False, "Failed to update model weights" logger.info("start updating weights") - + load_config = LoadConfig(load_format=self.server_args.load_format) loader = get_model_loader(load_config) if not isinstance(loader, DefaultModelLoader): logger.error("Failed to get weights 
iterator: Unsupported loader") return False, "Failed to update model weights" - def get_weight_iter(config): - iter = loader._get_weights_iterator(config.model, - config.revision, - fall_back_to_pt=getattr( - self.model, - "fall_back_to_pt_during_load", - True)) + iter = loader._get_weights_iterator( + config.model, + config.revision, + fall_back_to_pt=getattr( + self.model, "fall_back_to_pt_during_load", True + ), + ) return iter - + def model_load_weights(model, iter): model.load_weights(iter) for _, module in self.model.named_modules(): @@ -269,7 +276,9 @@ def model_load_weights(model, iter): try: model = model_load_weights(self.model, iter) except Exception as e: - logger.error(f"Failed to update weights: {e}. \n Rolling back to original weights") + logger.error( + f"Failed to update weights: {e}. \n Rolling back to original weights" + ) del iter gc.collect() iter = get_weight_iter(self.vllm_model_config) @@ -280,7 +289,7 @@ def model_load_weights(model, iter): self.server_args.load_format = load_format self.vllm_model_config = model_config self.load_config = load_config - + logger.info("finish updating weights") return True, "Updating model weights succeeded" diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index bcca6d93223..621232a9418 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -51,8 +51,11 @@ start_controller_process as start_controller_process_single, ) from sglang.srt.managers.detokenizer_manager import start_detokenizer_process -from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput -from sglang.srt.managers.io_struct import UpdateWeightReqInput +from sglang.srt.managers.io_struct import ( + EmbeddingReqInput, + GenerateReqInput, + UpdateWeightReqInput, +) from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.openai_api.adapter import ( load_chat_template_for_openai_api, @@ -119,6 +122,7 @@ async def flush_cache(): status_code=200, ) + @app.post("/update_weights") async def update_weights(obj: UpdateWeightReqInput, request: Request): @@ -127,9 +131,9 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request): print(content) if success: return JSONResponse( - content, - status_code=HTTPStatus.OK, - ) + content, + status_code=HTTPStatus.OK, + ) else: return JSONResponse( content, @@ -137,7 +141,6 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request): ) - async def generate_request(obj: GenerateReqInput, request: Request): """Handle a generate request.""" if obj.stream: From 3bc24b21d3a11de5955c20b1a5fed53b6e50023c Mon Sep 17 00:00:00 2001 From: Shan Yu Date: Tue, 20 Aug 2024 01:06:47 +0000 Subject: [PATCH 03/10] refactor and add logs --- .../sglang/srt/managers/tokenizer_manager.py | 30 +++++++++---------- .../sglang/srt/model_executor/model_runner.py | 27 +++++++++-------- python/sglang/srt/server.py | 1 - 3 files changed, 28 insertions(+), 30 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index efad297ab01..d3c78b8f62d 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -67,10 +67,6 @@ class ReqState: event: asyncio.Event -model_update_lock = asyncio.Lock() -model_update_result = None - - class TokenizerManager: def __init__( self, @@ -127,6 +123,10 @@ def __init__( self.to_create_loop = True self.rid_to_state: Dict[str, ReqState] = {} + # for update model weights + self.model_update_lock = 
asyncio.Lock() + self.model_update_result = None + async def get_pixel_values(self, image_data): aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None) grid_pinpoints = ( @@ -152,9 +152,7 @@ async def generate_request( if self.to_create_loop: self.create_handle_loop() - global model_update_lock - - while model_update_lock.locked(): + while self.model_update_lock.locked(): await asyncio.sleep(0.001) obj.post_init() @@ -505,21 +503,21 @@ def flush_cache(self): self.send_to_router.send_pyobj(req) async def update_weights(self, obj: UpdateWeightReqInput, request): - global model_update_lock - global model_update_result if self.to_create_loop: self.create_handle_loop() - if not model_update_lock.locked(): - async with model_update_lock: + + if not self.model_update_lock.locked(): + async with self.model_update_lock: + # wait for the previous generation requests to finish while len(self.rid_to_state) > 0: - await asyncio.sleep(1) + await asyncio.sleep(0.1) self.send_to_router.send_pyobj(obj) - model_update_result = asyncio.Future() - result = await model_update_result + self.model_update_result = asyncio.Future() + result = await self.model_update_result if result.success: self.server_args.model_path = obj.model_path self.server_args.load_format = obj.load_format - + self.model_path = obj.model_path return result.success, result.message else: return False, "Another update is in progress. Please try again later." @@ -557,7 +555,7 @@ async def handle_loop(self): ] = await self.recv_from_detokenizer.recv_pyobj() if isinstance(recv_obj, UpdateWeightReqOutput): - model_update_result.set_result(recv_obj) + self.model_update_result.set_result(recv_obj) continue assert isinstance( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 8a130d56cd3..0db65069cab 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -214,14 +214,18 @@ def update_weights(self, model_path, load_format): get_model_loader, ) from vllm.model_executor.model_loader.utils import ( - get_model_architecture, set_default_torch_dtype, ) + logger.info( + f"[gpu={self.gpu_id}] Update weights begin. " + f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" + ) + target_device = torch.device(self.device_config.device) try: - model_config = VllmModelConfig( + vllm_model_config = VllmModelConfig( model=model_path, quantization=self.server_args.quantization, tokenizer=None, @@ -235,9 +239,9 @@ def update_weights(self, model_path, load_format): logger.error(f"Failed to load model config: {e}") return False, "Failed to update model weights" - logger.info("start updating weights") + load_config = LoadConfig(load_format=load_format) - load_config = LoadConfig(load_format=self.server_args.load_format) + # Only support vllm DefaultModelLoader for now loader = get_model_loader(load_config) if not isinstance(loader, DefaultModelLoader): logger.error("Failed to get weights iterator: Unsupported loader") @@ -258,18 +262,13 @@ def model_load_weights(model, iter): for _, module in self.model.named_modules(): quant_method = getattr(module, "quant_method", None) if quant_method is not None: - # When quant methods need to process weights after loading - # (for repacking, quantizing, etc), they expect parameters - # to be on the global target device. This scope is for the - # case where cpu offloading is used, where we will move the - # parameters onto device for processing and back off after. 
with device_loading_context(module, target_device): quant_method.process_weights_after_loading(module) return model - with set_default_torch_dtype(model_config.dtype): + with set_default_torch_dtype(vllm_model_config.dtype): try: - iter = get_weight_iter(model_config) + iter = get_weight_iter(vllm_model_config) except Exception as e: logger.error(f"Failed to get weights iterator: {e}") return False, "Failed to update model" @@ -285,12 +284,14 @@ def model_load_weights(model, iter): self.model = model_load_weights(self.model, iter) return False, "Failed to update model" + self.model = model self.server_args.model_path = model_path self.server_args.load_format = load_format - self.vllm_model_config = model_config + self.vllm_model_config = vllm_model_config self.load_config = load_config + self.model_config.path = model_path - logger.info("finish updating weights") + logger.info(f"[gpu={self.gpu_id}] Update weights end.") return True, "Updating model weights succeeded" def profile_max_num_token(self, total_gpu_memory): diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 621232a9418..ed20d0d4610 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -128,7 +128,6 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request): success, message = await tokenizer_manager.update_weights(obj, request) content = {"message": message, "success": str(success)} - print(content) if success: return JSONResponse( content, From c12d2992fb39caef3a4a96b2fd74c06eb848f18d Mon Sep 17 00:00:00 2001 From: Shan Yu Date: Tue, 20 Aug 2024 01:29:23 +0000 Subject: [PATCH 04/10] allow load_format to be omited --- python/sglang/srt/managers/io_struct.py | 4 +++- python/sglang/srt/managers/tokenizer_manager.py | 4 ++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 08f02afda5d..630439e4a97 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -280,8 +280,10 @@ class FlushCacheReq: @dataclass class UpdateWeightReqInput: + # The model path with the new weights model_path: str - load_format: str + # The format to load the weights + load_format: Optional[str] = None @dataclass diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index d3c78b8f62d..d1a6b1d4887 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -506,6 +506,10 @@ async def update_weights(self, obj: UpdateWeightReqInput, request): if self.to_create_loop: self.create_handle_loop() + # default the load format to the server_args + if obj.load_format is None: + obj.load_format = self.server_args.load_format + if not self.model_update_lock.locked(): async with self.model_update_lock: # wait for the previous generation requests to finish From 89200658e63eebcff4b539b87a390eb30bb16b3c Mon Sep 17 00:00:00 2001 From: Shan Yu Date: Tue, 20 Aug 2024 01:29:33 +0000 Subject: [PATCH 05/10] add tests --- test/srt/test_update_weights.py | 98 +++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 test/srt/test_update_weights.py diff --git a/test/srt/test_update_weights.py b/test/srt/test_update_weights.py new file mode 100644 index 00000000000..cdcd36dadca --- /dev/null +++ b/test/srt/test_update_weights.py @@ -0,0 +1,98 @@ +import json +import unittest + +import requests + +from sglang.srt.utils import kill_child_process 
+from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_URL_FOR_UNIT_TEST, + popen_launch_server, +) + + +class TestReplaceWeights(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_UNIT_TEST + cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def run_decode(self): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + "n": 1, + }, + "stream": False, + "return_logprob": False, + "top_logprobs_num": 0, + "return_text_in_logprobs": False, + "logprob_start_len": 0, + }, + ) + print(json.dumps(response.json())) + print("=" * 100) + # return the "text" in response + text = response.json()["text"] + return text + + def get_model_info(self): + response = requests.get(self.base_url + "/get_model_info") + model_path = response.json()["model_path"] + print(json.dumps(response.json())) + return model_path + + def run_update_weights(self, model_path): + response = requests.post( + self.base_url + "/update_weights", + json={ + "model_path": model_path, + }, + ) + print(json.dumps(response.json())) + + def test_replace_weights(self): + origin_model_path = self.get_model_info() + print(f"origin_model_path: {origin_model_path}") + origin_response = self.run_decode() + + # update weights + new_model_path = "meta-llama/Meta-Llama-3.1-8B" + self.run_update_weights(new_model_path) + + updated_model_path = self.get_model_info() + print(f"updated_model_path: {updated_model_path}") + assert updated_model_path == new_model_path + assert updated_model_path != origin_model_path + + updated_response = self.run_decode() + assert origin_response[:32] != updated_response[:32] + + def test_replace_weights_unexist_model(self): + origin_model_path = self.get_model_info() + print(f"origin_model_path: {origin_model_path}") + origin_response = self.run_decode() + + # update weights + new_model_path = "meta-llama/Meta-Llama-3.1-8B-1" + self.run_update_weights(new_model_path) + + updated_model_path = self.get_model_info() + print(f"updated_model_path: {updated_model_path}") + assert updated_model_path == origin_model_path + + updated_response = self.run_decode() + assert origin_response[:32] == updated_response[:32] + + +if __name__ == "__main__": + unittest.main() From 0a509ffafc7fabcf04c8bc1d3edede6c48bf53d4 Mon Sep 17 00:00:00 2001 From: Shan Yu Date: Tue, 20 Aug 2024 01:30:36 +0000 Subject: [PATCH 06/10] fix format --- python/sglang/srt/model_executor/model_runner.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 0db65069cab..dcef58bbf21 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -213,9 +213,7 @@ def update_weights(self, model_path, load_format): device_loading_context, get_model_loader, ) - from vllm.model_executor.model_loader.utils import ( - set_default_torch_dtype, - ) + from vllm.model_executor.model_loader.utils import set_default_torch_dtype logger.info( f"[gpu={self.gpu_id}] Update weights begin. 
" From b224c08a3d15a5a81ee84158e6e8a1dd8cf24d84 Mon Sep 17 00:00:00 2001 From: Shan Yu Date: Tue, 20 Aug 2024 17:23:24 +0000 Subject: [PATCH 07/10] address comments --- python/sglang/srt/managers/tp_worker.py | 6 +++++- python/sglang/srt/model_executor/model_runner.py | 12 ++++++------ test/srt/test_update_weights.py | 8 ++++++++ 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index a83ad2378fd..9d6f4917a2e 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -778,12 +778,15 @@ def flush_cache(self): self.token_to_kv_pool.clear() torch.cuda.empty_cache() logger.info("Cache flushed successfully!") + if_success = True else: warnings.warn( f"Cache not flushed because there are pending requests. " f"#queue-req: {len(self.waiting_queue)}, " f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}" ) + if_success = False + return if_success def abort_request(self, recv_req): # Delete requests in the waiting queue @@ -808,7 +811,8 @@ def update_weights(self, recv_req): recv_req.model_path, recv_req.load_format ) if success: - self.flush_cache() + flash_cache_success = self.flush_cache() + assert flash_cache_success, "Cache flush failed after updating weights" return success, message diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index dcef58bbf21..1dbd4a7ba7a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -268,19 +268,19 @@ def model_load_weights(model, iter): try: iter = get_weight_iter(vllm_model_config) except Exception as e: - logger.error(f"Failed to get weights iterator: {e}") - return False, "Failed to update model" + message = f"Failed to get weights iterator: {e}" + logger.error(message) + return False, message try: model = model_load_weights(self.model, iter) except Exception as e: - logger.error( - f"Failed to update weights: {e}. \n Rolling back to original weights" - ) + message = f"Failed to update weights: {e}. 
\n Rolling back to original weights" + logger.error(message) del iter gc.collect() iter = get_weight_iter(self.vllm_model_config) self.model = model_load_weights(self.model, iter) - return False, "Failed to update model" + return False, message self.model = model self.server_args.model_path = model_path diff --git a/test/srt/test_update_weights.py b/test/srt/test_update_weights.py index cdcd36dadca..64f84263aa9 100644 --- a/test/srt/test_update_weights.py +++ b/test/srt/test_update_weights.py @@ -77,6 +77,14 @@ def test_replace_weights(self): updated_response = self.run_decode() assert origin_response[:32] != updated_response[:32] + # update weights back + self.run_update_weights(origin_model_path) + updated_model_path = self.get_model_info() + assert updated_model_path == origin_model_path + + updated_response = self.run_decode() + assert origin_response[:32] == updated_response[:32] + def test_replace_weights_unexist_model(self): origin_model_path = self.get_model_info() print(f"origin_model_path: {origin_model_path}") From 8d8b8e92e304468fce111ff97d24b7e508db2a14 Mon Sep 17 00:00:00 2001 From: Shan Yu Date: Tue, 20 Aug 2024 17:29:32 +0000 Subject: [PATCH 08/10] change the wait time of generation requests while there is an ongoing update_weights request --- python/sglang/srt/managers/tokenizer_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 8a6b9783f6c..ebb28b6350b 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -153,7 +153,7 @@ async def generate_request( self.create_handle_loop() while self.model_update_lock.locked(): - await asyncio.sleep(0.001) + await asyncio.sleep(0.1) obj.post_init() is_single = obj.is_single From bb7275701d1ed5c563728a2fb8b4691f247110f4 Mon Sep 17 00:00:00 2001 From: Shan Yu Date: Tue, 20 Aug 2024 11:17:21 -0700 Subject: [PATCH 09/10] Apply suggestions from code review Co-authored-by: Yineng Zhang --- python/sglang/srt/managers/tokenizer_manager.py | 2 +- python/sglang/srt/model_executor/model_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index ee75e0e939e..15d167b6188 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -153,7 +153,7 @@ async def generate_request( self.create_handle_loop() while self.model_update_lock.locked(): - await asyncio.sleep(0.1) + await asyncio.sleep(0) obj.post_init() is_single = obj.is_single diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 22ecba729c4..4a3396cf2c3 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -292,7 +292,7 @@ def model_load_weights(model, iter): self.model_config.path = model_path logger.info(f"[gpu={self.gpu_id}] Update weights end.") - return True, "Updating model weights succeeded" + return True, "Succeeded to update model weights" def profile_max_num_token(self, total_gpu_memory): available_gpu_memory = get_available_gpu_memory( From 488120e9ab7ad386beb92e67793437ded2f0af8d Mon Sep 17 00:00:00 2001 From: Shan Yu Date: Tue, 20 Aug 2024 18:19:16 +0000 Subject: [PATCH 10/10] Apply suggestions from code review --- python/sglang/srt/managers/tokenizer_manager.py | 2 +- 1 file changed, 
1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 15d167b6188..ab375a39a95 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -534,7 +534,7 @@ async def update_weights(self, obj: UpdateWeightReqInput, request): async with self.model_update_lock: # wait for the previous generation requests to finish while len(self.rid_to_state) > 0: - await asyncio.sleep(0.1) + await asyncio.sleep(0) self.send_to_router.send_pyobj(obj) self.model_update_result = asyncio.Future() result = await self.model_update_result
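
Usage sketch. The /update_weights endpoint introduced in PATCH 01 is driven the same way the new test/srt/test_update_weights.py drives it. The snippet below is a minimal, illustrative sketch using the requests library: the server URL, port, and model paths are placeholders, and the launch command is only one example of starting a server with an initial checkpoint; none of these values are fixed by the patches themselves.

    import requests

    # Assumes an SGLang server is already running locally, e.g. started with:
    #   python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000
    base_url = "http://127.0.0.1:30000"

    # Swap in weights from another checkpoint with the same architecture.
    # load_format is optional (PATCH 04); when omitted it falls back to the
    # load format the server was originally launched with.
    resp = requests.post(
        f"{base_url}/update_weights",
        json={"model_path": "meta-llama/Meta-Llama-3.1-8B"},
    )
    # On success the endpoint returns {"success": "True", "message": "Succeeded to update model weights"}
    print(resp.json())

    # The TP worker flushes the KV cache after a successful update, so
    # subsequent generations run against the new weights.
    resp = requests.post(
        f"{base_url}/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {"temperature": 0, "max_new_tokens": 32},
        },
    )
    print(resp.json()["text"])

If the new model path cannot be loaded, the endpoint responds with HTTP 400 and the original weights stay in place (the model runner rolls back by re-reading the previous checkpoint's weight iterator), which is what the second test case, test_replace_weights_unexist_model, checks.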