diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index 12511ac44e5..e1402795fbf 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -28,6 +28,7 @@
     BatchEmbeddingOut,
     BatchStrOut,
     BatchTokenIDOut,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -84,6 +85,10 @@ async def handle_loop(self):
                 )
                 continue
 
+            if isinstance(recv_obj, UpdateWeightReqOutput):
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
+
             assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)
 
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 3a0ecd8f6c8..dc82245931d 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -278,6 +278,20 @@ class FlushCacheReq:
     pass
 
 
+@dataclass
+class UpdateWeightReqInput:
+    # The model path with the new weights
+    model_path: str
+    # The format to load the weights
+    load_format: Optional[str] = None
+
+
+@dataclass
+class UpdateWeightReqOutput:
+    success: bool
+    message: str
+
+
 @dataclass
 class AbortReq:
     # The request id
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index e157217e348..ab375a39a95 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -46,6 +46,8 @@
     GenerateReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UpdateWeightReqInput,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
@@ -121,6 +123,10 @@ def __init__(
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}
 
+        # For updating model weights
+        self.model_update_lock = asyncio.Lock()
+        self.model_update_result = None
+
     async def get_pixel_values(self, image_data):
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
         grid_pinpoints = (
@@ -146,6 +152,9 @@ async def generate_request(
         if self.to_create_loop:
             self.create_handle_loop()
 
+        while self.model_update_lock.locked():
+            await asyncio.sleep(0)
+
         obj.post_init()
         is_single = obj.is_single
 
@@ -513,6 +522,30 @@ def flush_cache(self):
         req = FlushCacheReq()
         self.send_to_router.send_pyobj(req)
 
+    async def update_weights(self, obj: UpdateWeightReqInput, request):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        # Default the load format to the one in server_args
+        if obj.load_format is None:
+            obj.load_format = self.server_args.load_format
+
+        if not self.model_update_lock.locked():
+            async with self.model_update_lock:
+                # Wait for the previous generation requests to finish
+                while len(self.rid_to_state) > 0:
+                    await asyncio.sleep(0)
+                self.send_to_router.send_pyobj(obj)
+                self.model_update_result = asyncio.Future()
+                result = await self.model_update_result
+                if result.success:
+                    self.server_args.model_path = obj.model_path
+                    self.server_args.load_format = obj.load_format
+                    self.model_path = obj.model_path
+                return result.success, result.message
+        else:
+            return False, "Another update is in progress. Please try again later."
+
     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return
@@ -541,12 +574,18 @@ def create_handle_loop(self):
 
     async def handle_loop(self):
         while True:
-            recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] = (
-                await self.recv_from_detokenizer.recv_pyobj()
-            )
+            recv_obj: Union[
+                BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
+            ] = await self.recv_from_detokenizer.recv_pyobj()
+
+            if isinstance(recv_obj, UpdateWeightReqOutput):
+                self.model_update_result.set_result(recv_obj)
+                continue
+
             assert isinstance(
                 recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
             ), f"Unexpected obj received: {type(recv_obj)}"
+
             for i, rid in enumerate(recv_obj.rids):
                 state = self.rid_to_state.get(rid, None)
                 if state is None:
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index b8a4576f736..7bd2e381297 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -39,6 +39,8 @@
     FlushCacheReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UpdateWeightReqInput,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (
@@ -214,6 +216,9 @@ def exposed_step(self, recv_reqs: List):
                     self.flush_cache()
                 elif isinstance(recv_req, AbortReq):
                     self.abort_request(recv_req)
+                elif isinstance(recv_req, UpdateWeightReqInput):
+                    success, message = self.update_weights(recv_req)
+                    self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
                 else:
                     raise ValueError(f"Invalid request: {recv_req}")
 
@@ -773,12 +778,15 @@ def flush_cache(self):
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()
             logger.info("Cache flushed successfully!")
+            if_success = True
         else:
             logging.warning(
                 f"Cache not flushed because there are pending requests. "
                 f"#queue-req: {len(self.waiting_queue)}, "
                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
             )
+            if_success = False
+        return if_success
 
     def abort_request(self, recv_req):
         # Delete requests in the waiting queue
@@ -798,6 +806,15 @@ def abort_request(self, recv_req):
                     req.finished_reason = FINISH_ABORT()
                     break
 
+    def update_weights(self, recv_req):
+        success, message = self.model_runner.update_weights(
+            recv_req.model_path, recv_req.load_format
+        )
+        if success:
+            flush_cache_success = self.flush_cache()
+            assert flush_cache_success, "Cache flush failed after updating weights"
+        return success, message
+
 
 def run_tp_server(
     gpu_id: int,
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index bf89c637d9f..4a3396cf2c3 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -15,6 +15,7 @@
 
 """ModelRunner runs the forward passes of the models."""
 
+import gc
 import importlib
 import importlib.resources
 import logging
@@ -157,9 +158,9 @@ def load_model(self):
             self.server_args.dtype = "float16"
 
         monkey_patch_vllm_dummy_weight_loader()
-        device_config = DeviceConfig()
-        load_config = LoadConfig(load_format=self.server_args.load_format)
-        vllm_model_config = VllmModelConfig(
+        self.device_config = DeviceConfig()
+        self.load_config = LoadConfig(load_format=self.server_args.load_format)
+        self.vllm_model_config = VllmModelConfig(
             model=self.server_args.model_path,
             quantization=self.server_args.quantization,
             tokenizer=None,
@@ -173,17 +174,19 @@
         if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
             # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
             self.model_config.hf_config.num_key_value_heads = 8
-            vllm_model_config.hf_config.num_key_value_heads = 8
+            self.vllm_model_config.hf_config.num_key_value_heads = 8
             monkey_patch_vllm_qvk_linear_loader()
 
-        self.dtype = vllm_model_config.dtype
+        self.dtype = self.vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
-            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
+            self.vllm_model_config.hf_config.update(
+                self.model_config.model_overide_args
+            )
 
         self.model = get_model(
-            model_config=vllm_model_config,
-            device_config=device_config,
-            load_config=load_config,
+            model_config=self.vllm_model_config,
+            device_config=self.device_config,
+            load_config=self.load_config,
             lora_config=None,
             multimodal_config=None,
             parallel_config=None,
@@ -206,6 +209,91 @@
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
 
+    def update_weights(self, model_path, load_format):
+        from vllm.model_executor.model_loader.loader import (
+            DefaultModelLoader,
+            device_loading_context,
+            get_model_loader,
+        )
+        from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+
+        logger.info(
+            f"[gpu={self.gpu_id}] Update weights begin. "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )
+
+        target_device = torch.device(self.device_config.device)
+
+        try:
+            vllm_model_config = VllmModelConfig(
+                model=model_path,
+                quantization=self.server_args.quantization,
+                tokenizer=None,
+                tokenizer_mode=None,
+                trust_remote_code=self.server_args.trust_remote_code,
+                dtype=self.server_args.dtype,
+                seed=42,
+                skip_tokenizer_init=True,
+            )
+        except Exception as e:
+            logger.error(f"Failed to load model config: {e}")
+            return False, "Failed to update model weights"
+
+        load_config = LoadConfig(load_format=load_format)
+
+        # Only support vllm DefaultModelLoader for now
+        loader = get_model_loader(load_config)
+        if not isinstance(loader, DefaultModelLoader):
+            logger.error("Failed to get weights iterator: Unsupported loader")
+            return False, "Failed to update model weights"
+
+        def get_weight_iter(config):
+            iter = loader._get_weights_iterator(
+                config.model,
+                config.revision,
+                fall_back_to_pt=getattr(
+                    self.model, "fall_back_to_pt_during_load", True
+                ),
+            )
+            return iter
+
+        def model_load_weights(model, iter):
+            model.load_weights(iter)
+            for _, module in self.model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+            return model
+
+        with set_default_torch_dtype(vllm_model_config.dtype):
+            try:
+                iter = get_weight_iter(vllm_model_config)
+            except Exception as e:
+                message = f"Failed to get weights iterator: {e}"
+                logger.error(message)
+                return False, message
+            try:
+                model = model_load_weights(self.model, iter)
+            except Exception as e:
+                message = f"Failed to update weights: {e}.\nRolling back to original weights."
+                logger.error(message)
+                del iter
+                gc.collect()
+                iter = get_weight_iter(self.vllm_model_config)
+                self.model = model_load_weights(self.model, iter)
+                return False, message
+
+        self.model = model
+        self.server_args.model_path = model_path
+        self.server_args.load_format = load_format
+        self.vllm_model_config = vllm_model_config
+        self.load_config = load_config
+        self.model_config.path = model_path
+
+        logger.info(f"[gpu={self.gpu_id}] Update weights end.")
+        return True, "Successfully updated model weights"
+
     def profile_max_num_token(self, total_gpu_memory):
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 55271c23526..0c5a3c706b0 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -51,7 +51,11 @@
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
+from sglang.srt.managers.io_struct import (
+    EmbeddingReqInput,
+    GenerateReqInput,
+    UpdateWeightReqInput,
+)
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
@@ -136,6 +140,23 @@
     )
 
 
+@app.post("/update_weights")
+async def update_weights(obj: UpdateWeightReqInput, request: Request):
+
+    success, message = await tokenizer_manager.update_weights(obj, request)
+    content = {"message": message, "success": str(success)}
+    if success:
+        return JSONResponse(
+            content,
+            status_code=HTTPStatus.OK,
+        )
+    else:
+        return JSONResponse(
+            content,
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:
diff --git a/test/srt/test_update_weights.py b/test/srt/test_update_weights.py
new file mode 100644
index 00000000000..64f84263aa9
--- /dev/null
+++ b/test/srt/test_update_weights.py
@@ -0,0 +1,106 @@
+import json
+import unittest
+
+import requests
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_URL_FOR_UNIT_TEST,
+    popen_launch_server,
+)
+
+
+class TestReplaceWeights(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_UNIT_TEST
+        cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid)
+
+    def run_decode(self):
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 32,
+                    "n": 1,
+                },
+                "stream": False,
+                "return_logprob": False,
+                "top_logprobs_num": 0,
+                "return_text_in_logprobs": False,
+                "logprob_start_len": 0,
+            },
+        )
+        print(json.dumps(response.json()))
+        print("=" * 100)
+        # Return the "text" field of the response
+        text = response.json()["text"]
+        return text
+
+    def get_model_info(self):
+        response = requests.get(self.base_url + "/get_model_info")
+        model_path = response.json()["model_path"]
+        print(json.dumps(response.json()))
+        return model_path
+
+    def run_update_weights(self, model_path):
+        response = requests.post(
+            self.base_url + "/update_weights",
+            json={
+                "model_path": model_path,
+            },
+        )
+        print(json.dumps(response.json()))
+
+    def test_replace_weights(self):
+        origin_model_path = self.get_model_info()
+        print(f"origin_model_path: {origin_model_path}")
+        origin_response = self.run_decode()
+
+        # Update weights
+        new_model_path = "meta-llama/Meta-Llama-3.1-8B"
+        self.run_update_weights(new_model_path)
+
+        updated_model_path = self.get_model_info()
+        print(f"updated_model_path: {updated_model_path}")
+        assert updated_model_path == new_model_path
+        assert updated_model_path != origin_model_path
+
+        updated_response = self.run_decode()
+        assert origin_response[:32] != updated_response[:32]
+
+        # Update weights back to the original model
+        self.run_update_weights(origin_model_path)
+        updated_model_path = self.get_model_info()
+        assert updated_model_path == origin_model_path
+
+        updated_response = self.run_decode()
+        assert origin_response[:32] == updated_response[:32]
+
+    def test_replace_weights_nonexistent_model(self):
+        origin_model_path = self.get_model_info()
+        print(f"origin_model_path: {origin_model_path}")
+        origin_response = self.run_decode()
+
+        # Try to update weights with a non-existent model path
+        new_model_path = "meta-llama/Meta-Llama-3.1-8B-1"
+        self.run_update_weights(new_model_path)
+
+        updated_model_path = self.get_model_info()
+        print(f"updated_model_path: {updated_model_path}")
+        assert updated_model_path == origin_model_path
+
+        updated_response = self.run_decode()
+        assert origin_response[:32] == updated_response[:32]
+
+
+if __name__ == "__main__":
+    unittest.main()
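Usage note: the test file above drives the new endpoint through popen_launch_server. The snippet below is a minimal standalone sketch of the same call against an already running sglang server; the base URL is an assumption (use whatever host and port the server was launched with), and load_format is optional, defaulting to the format the server was started with.

import requests

# Point this at your running sglang server; the host/port here are placeholders.
base_url = "http://localhost:30000"

response = requests.post(
    base_url + "/update_weights",
    json={
        # Path (or Hugging Face repo id) of the checkpoint to load the new weights from.
        "model_path": "meta-llama/Meta-Llama-3.1-8B",
        # "load_format" could also be passed here; when omitted, the server falls back
        # to the load format it was originally started with.
    },
)

# On success the server returns HTTP 200 and swaps in the new weights;
# on failure (e.g. a bad model path or another update in progress) it returns
# HTTP 400 and keeps serving the original weights.
print(response.status_code, response.json())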