From ee8eaa39fc69f0696429575eafeeea2e62c0a89c Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Wed, 4 Dec 2024 07:59:36 +0000 Subject: [PATCH 01/33] add custom params Signed-off-by: Hongpeng Guo --- python/sglang/srt/layers/sampler.py | 11 ++++++++++- python/sglang/srt/sampling/sampling_batch_info.py | 6 +++++- python/sglang/srt/sampling/sampling_params.py | 4 +++- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index d7db6036ca9..e3ee7ba5fc0 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -1,5 +1,5 @@ import logging -from typing import Union +from typing import Callable, Optional, Union import torch from torch import nn @@ -25,6 +25,11 @@ class Sampler(nn.Module): def __init__(self): super().__init__() self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"] + self.custom_logits_processor: Optional[Callable] = None + + def register_logit_processor(self, processor: Callable): + """Register a custom logit processor function.""" + self.custom_logit_processor = processor def forward( self, @@ -36,6 +41,10 @@ def forward( logits = logits.contiguous() + # Apply custom logit processor if registered + if self.custom_logit_processor is not None: + logits = self.custom_logit_processor(logits, sampling_info) + if self.use_nan_detectioin and torch.any(torch.isnan(logits)): logger.warning("Detected errors during sampling! NaN in the logits.") logits = torch.where( diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 1624fd255f9..b24ab5e77a0 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -3,7 +3,7 @@ import dataclasses import logging import threading -from typing import TYPE_CHECKING, Callable, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional import torch @@ -46,6 +46,9 @@ class SamplingBatchInfo: # Device device: str = "cuda" + # Custom Parameters + custom_params: Optional[List[Dict[str, Any]]] = None + @classmethod def from_schedule_batch( cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool @@ -79,6 +82,7 @@ def from_schedule_batch( is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs), vocab_size=vocab_size, device=device, + custom_params=[r.sampling_params.custom_params for r in reqs], ) # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge. 
diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py index 64d5e0783ea..88f0dedfbae 100644 --- a/python/sglang/srt/sampling/sampling_params.py +++ b/python/sglang/srt/sampling/sampling_params.py @@ -13,7 +13,7 @@ # ============================================================================== """Sampling parameters for text generation.""" -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union _SAMPLING_EPS = 1e-6 @@ -39,6 +39,7 @@ def __init__( no_stop_trim: bool = False, ignore_eos: bool = False, skip_special_tokens: bool = True, + custom_params: Optional[Dict[str, Any]] = None, ) -> None: self.temperature = temperature self.top_p = top_p @@ -61,6 +62,7 @@ def __init__( self.n = n self.json_schema = json_schema self.no_stop_trim = no_stop_trim + self.custom_params = custom_params # Process some special cases if self.temperature < _SAMPLING_EPS: From 19719f4f6b3a49266be59bbaa0a0c1d8927e9653 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Wed, 4 Dec 2024 08:01:56 +0000 Subject: [PATCH 02/33] fix typos Signed-off-by: Hongpeng Guo --- python/sglang/srt/layers/sampler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index e3ee7ba5fc0..42d5cbd7d18 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -29,7 +29,7 @@ def __init__(self): def register_logit_processor(self, processor: Callable): """Register a custom logit processor function.""" - self.custom_logit_processor = processor + self.custom_logits_processor = processor def forward( self, @@ -42,8 +42,8 @@ def forward( logits = logits.contiguous() # Apply custom logit processor if registered - if self.custom_logit_processor is not None: - logits = self.custom_logit_processor(logits, sampling_info) + if self.custom_logits_processor is not None: + logits = self.custom_logits_processor(logits, sampling_info) if self.use_nan_detectioin and torch.any(torch.isnan(logits)): logger.warning("Detected errors during sampling! 
NaN in the logits.") From 5d77e98ee9da86aaafd7c561d525c4f3d2189488 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 30 Dec 2024 02:00:02 +0000 Subject: [PATCH 03/33] add general custom logit processors Signed-off-by: Hongpeng Guo --- python/sglang/srt/layers/sampler.py | 17 ++--- python/sglang/srt/managers/io_struct.py | 22 +++++++ python/sglang/srt/managers/schedule_batch.py | 2 + python/sglang/srt/managers/scheduler.py | 1 + .../sglang/srt/managers/tokenizer_manager.py | 1 + .../srt/sampling/custom_logit_processor.py | 32 ++++++++++ .../srt/sampling/sampling_batch_info.py | 64 ++++++++++++++++++- python/sglang/srt/server.py | 4 ++ 8 files changed, 132 insertions(+), 11 deletions(-) create mode 100644 python/sglang/srt/sampling/custom_logit_processor.py diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 995abeac275..807f0d0b421 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -6,6 +6,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.utils import crash_on_warnings, is_flashinfer_available @@ -25,11 +26,6 @@ class Sampler(nn.Module): def __init__(self): super().__init__() self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"] - self.custom_logits_processor: Optional[Callable] = None - - def register_logit_processor(self, processor: Callable): - """Register a custom logit processor function.""" - self.custom_logits_processor = processor def forward( self, @@ -41,9 +37,14 @@ def forward( logits = logits.contiguous() - # Apply custom logit processor if registered - if self.custom_logits_processor is not None: - logits = self.custom_logits_processor(logits, sampling_info) + # Apply the custom logit processors if registered in the sampling info. + if sampling_info.custom_logit_processors is not None: + for ( + processor_str, + batch_mask, + ) in sampling_info.custom_logit_processors.items(): + processor = CustomLogitProcessor.from_str(processor_str) + logits = processor(logits, batch_mask, sampling_info.custom_params) if self.use_nan_detectioin and torch.any(torch.isnan(logits)): logger.warning("Detected errors during sampling! 
NaN in the logits.") diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index c5884b5f0f6..b3be7617067 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -22,6 +22,7 @@ from typing import Dict, List, Optional, Tuple, Union from sglang.srt.managers.schedule_batch import BaseFinishReason +from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor from sglang.srt.sampling.sampling_params import SamplingParams @@ -60,6 +61,10 @@ class GenerateReqInput: session: Optional[ Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]] ] = None + # Custom logit processor (serialized function) + customized_logit_processor: Optional[Union[List[Optional[str]], Optional[str]]] = ( + None + ) def normalize_batch_and_arguments(self): if ( @@ -174,6 +179,15 @@ def normalize_batch_and_arguments(self): else: assert self.parallel_sample_num == 1 + if self.customized_logit_processor is None: + self.customized_logit_processor = [None] * num + elif not isinstance(self.customized_logit_processor, list): + self.customized_logit_processor = [ + self.customized_logit_processor + ] * num + else: + assert self.parallel_sample_num == 1 + def regenerate_rid(self): self.rid = uuid.uuid4().hex return self.rid @@ -192,6 +206,11 @@ def __getitem__(self, i): stream=self.stream, modalities=self.modalities[i] if self.modalities else None, lora_path=self.lora_path[i] if self.lora_path is not None else None, + customized_logit_processor=( + self.customized_logit_processor[i] + if self.customized_logit_processor is not None + else None + ), ) @@ -225,6 +244,9 @@ class TokenizedGenerateReqInput: session_id: Optional[str] = None session_rid: Optional[str] = None + # Custom logit processor (serialized function) + customized_logit_processor: Optional[str] = None + @dataclass class EmbeddingReqInput: diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 7081f0d0c9a..f08232b464d 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -209,6 +209,7 @@ def __init__( lora_path: Optional[str] = None, input_embeds: Optional[List[List[float]]] = None, session_id: Optional[str] = None, + custom_logit_processor: Optional[str] = None, ): # Input and output info self.rid = rid @@ -227,6 +228,7 @@ def __init__( # Sampling info self.sampling_params = sampling_params self.lora_path = lora_path + self.custom_logit_processor = custom_logit_processor # Memory pool info self.req_pool_idx = None diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 39c0b6af870..63ce63361bc 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -517,6 +517,7 @@ def handle_generate_request( stream=recv_req.stream, lora_path=recv_req.lora_path, input_embeds=recv_req.input_embeds, + custom_logit_processor=recv_req.custom_logit_processor, ) req.tokenizer = self.tokenizer diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 020e96e65de..f7b873ce8e9 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -289,6 +289,7 @@ async def _tokenize_one_request( input_embeds=input_embeds, session_id=session_id, session_rid=session_rid, + custom_logit_processor=obj.custom_logit_processor, ) elif isinstance(obj, EmbeddingReqInput): tokenized_obj 
= TokenizedEmbeddingReqInput( diff --git a/python/sglang/srt/sampling/custom_logit_processor.py b/python/sglang/srt/sampling/custom_logit_processor.py new file mode 100644 index 00000000000..a4d449faf22 --- /dev/null +++ b/python/sglang/srt/sampling/custom_logit_processor.py @@ -0,0 +1,32 @@ +import json +from abc import ABC, abstractmethod +from typing import Any, Dict, List + +import dill +from torch import Tensor + + +class CustomLogitProcessor(ABC): + """Abstract base class for callable functions.""" + + required_args: List[str] + + @abstractmethod + def __call__( + self, + logits: Tensor, + batch_mask: List[bool], + custom_params: Dict[str, List[Any]], + ) -> Tensor: + """Define the callable behavior.""" + raise NotImplementedError + + def to_str(self) -> str: + """Serialize the callable function to a JSON-compatible string.""" + return json.dumps({"callable": dill.dumps(self).hex()}) + + @classmethod + def from_str(cls, json_str: str): + """Deserialize a callable function from a JSON string.""" + data = json.loads(json_str) + return dill.loads(bytes.fromhex(data["callable"])) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index c05d53b21ca..ae660e77ba4 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -3,7 +3,7 @@ import dataclasses import logging import threading -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple import torch @@ -47,7 +47,10 @@ class SamplingBatchInfo: device: str = "cuda" # Custom Parameters - custom_params: Optional[List[Dict[str, Any]]] = None + custom_params: Optional[Dict[str, Any]] = None + + # Custom Logit Processor + custom_logit_processors: Optional[Dict[str, torch.Tensor]] = None @classmethod def from_schedule_batch( @@ -73,6 +76,30 @@ def from_schedule_batch( [r.sampling_params.min_p for r in reqs], dtype=torch.float ).to(device, non_blocking=True) + # Merge custom params as a dictionary for one batch of requests + custom_params_list = [r.sampling_params.custom_params for r in reqs] + merged_custom_params = { + key: [d.get(key, None) for d in custom_params_list] + for key in set(key for d in custom_params_list for key in d) + } + + # Merge the same type of customlogit processors together + processor_dict = {} + for i, r in enumerate(reqs): + if r.custom_logit_processor is None: + continue + processor_str = r.custom_logit_processor + if processor_str not in processor_dict: + processor_dict[processor_str] = [] + processor_dict[processor_str].append(i) + + merged_custom_logit_processors = { + processor_str: torch.zeros(len(reqs), dtype=torch.bool).scatter_( + 0, torch.tensor(true_indices), True + ) + for processor_str, true_indices in processor_dict.items() + } + ret = cls( temperatures=temperatures, top_ps=top_ps, @@ -82,7 +109,8 @@ def from_schedule_batch( is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs), vocab_size=vocab_size, device=device, - custom_params=[r.sampling_params.custom_params for r in reqs], + custom_params=merged_custom_params, + custom_logit_processors=merged_custom_logit_processors, ) # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge. 
@@ -219,6 +247,30 @@ def merge_bias_tensor( return None + @staticmethod + def merge_custom_params(lhs: Dict[str, List[Any]], rhs: Dict[str, List[Any]]): + keys = set(lhs.keys()).union(set(rhs.keys())) + merged_dict = {} + + for k in keys: + left_values = lhs.get(k, [None] * len(lhs)) + right_values = rhs.get(k, [None] * len(rhs)) + merged_dict[k] = left_values + right_values + return merged_dict + + @staticmethod + def merge_custom_logit_processors( + lhs: Dict[str, torch.Tensor], rhs: Dict[str, torch.Tensor] + ): + keys = set(lhs.keys()).union(set(rhs.keys())) + merged_dict = {} + + for k in keys: + left_values = lhs.get(k, torch.zeros(len(lhs), dtype=torch.bool)) + right_values = rhs.get(k, torch.zeros(len(rhs), dtype=torch.bool)) + merged_dict[k] = torch.cat([left_values, right_values]) + return merged_dict + def merge_batch(self, other: "SamplingBatchInfo"): self.penalizer_orchestrator.merge(other.penalizer_orchestrator) @@ -236,3 +288,9 @@ def merge_batch(self, other: "SamplingBatchInfo"): self.logit_bias = SamplingBatchInfo.merge_bias_tensor( self.logit_bias, other.logit_bias, len(self), len(other), self.device ) + self.custom_params = SamplingBatchInfo.merge_custom_params( + self.custom_params, other.custom_params + ) + self.custom_logit_processors = SamplingBatchInfo.merge_custom_logit_processors( + self.custom_logit_processors, other.custom_logit_processors + ) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 4814c8c6f05..e4d9eec333c 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -700,6 +700,7 @@ def generate( logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, lora_path: Optional[List[Optional[str]]] = None, + custom_logit_processor: Optional[Union[str, List[str]]] = None, stream: bool = False, ): obj = GenerateReqInput( @@ -711,6 +712,7 @@ def generate( top_logprobs_num=top_logprobs_num, lora_path=lora_path, stream=stream, + custom_logit_processor=custom_logit_processor, ) # get the current event loop @@ -751,6 +753,7 @@ async def async_generate( logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, lora_path: Optional[List[Optional[str]]] = None, + custom_logit_processor: Optional[Union[str, List[str]]] = None, stream: bool = False, ): obj = GenerateReqInput( @@ -762,6 +765,7 @@ async def async_generate( top_logprobs_num=top_logprobs_num, lora_path=lora_path, stream=stream, + custom_logit_processor=custom_logit_processor, ) ret = await generate_request(obj, None) From 6dd2d945ba097a81fe3d138a1461b5a5701651dc Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 30 Dec 2024 03:20:25 +0000 Subject: [PATCH 04/33] fix typos Signed-off-by: Hongpeng Guo --- python/sglang/srt/layers/sampler.py | 4 ++-- python/sglang/srt/managers/io_struct.py | 22 ++++++++----------- .../srt/sampling/sampling_batch_info.py | 12 +++++----- 3 files changed, 17 insertions(+), 21 deletions(-) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 807f0d0b421..5e0d2262182 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -38,11 +38,11 @@ def forward( logits = logits.contiguous() # Apply the custom logit processors if registered in the sampling info. 
- if sampling_info.custom_logit_processors is not None: + if sampling_info.custom_logit_processor is not None: for ( processor_str, batch_mask, - ) in sampling_info.custom_logit_processors.items(): + ) in sampling_info.custom_logit_processor.items(): processor = CustomLogitProcessor.from_str(processor_str) logits = processor(logits, batch_mask, sampling_info.custom_params) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index b3be7617067..5c436a7fabf 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -62,9 +62,7 @@ class GenerateReqInput: Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]] ] = None # Custom logit processor (serialized function) - customized_logit_processor: Optional[Union[List[Optional[str]], Optional[str]]] = ( - None - ) + custom_logit_processor: Optional[Union[List[Optional[str]], Optional[str]]] = None def normalize_batch_and_arguments(self): if ( @@ -179,12 +177,10 @@ def normalize_batch_and_arguments(self): else: assert self.parallel_sample_num == 1 - if self.customized_logit_processor is None: - self.customized_logit_processor = [None] * num - elif not isinstance(self.customized_logit_processor, list): - self.customized_logit_processor = [ - self.customized_logit_processor - ] * num + if self.custom_logit_processor is None: + self.custom_logit_processor = [None] * num + elif not isinstance(self.custom_logit_processor, list): + self.custom_logit_processor = [self.custom_logit_processor] * num else: assert self.parallel_sample_num == 1 @@ -206,9 +202,9 @@ def __getitem__(self, i): stream=self.stream, modalities=self.modalities[i] if self.modalities else None, lora_path=self.lora_path[i] if self.lora_path is not None else None, - customized_logit_processor=( - self.customized_logit_processor[i] - if self.customized_logit_processor is not None + custom_logit_processor=( + self.custom_logit_processor[i] + if self.custom_logit_processor is not None else None ), ) @@ -245,7 +241,7 @@ class TokenizedGenerateReqInput: session_rid: Optional[str] = None # Custom logit processor (serialized function) - customized_logit_processor: Optional[str] = None + custom_logit_processor: Optional[str] = None @dataclass diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index ae660e77ba4..6f63ad1c308 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -50,7 +50,7 @@ class SamplingBatchInfo: custom_params: Optional[Dict[str, Any]] = None # Custom Logit Processor - custom_logit_processors: Optional[Dict[str, torch.Tensor]] = None + custom_logit_processor: Optional[Dict[str, torch.Tensor]] = None @classmethod def from_schedule_batch( @@ -93,7 +93,7 @@ def from_schedule_batch( processor_dict[processor_str] = [] processor_dict[processor_str].append(i) - merged_custom_logit_processors = { + merged_custom_logit_processor = { processor_str: torch.zeros(len(reqs), dtype=torch.bool).scatter_( 0, torch.tensor(true_indices), True ) @@ -110,7 +110,7 @@ def from_schedule_batch( vocab_size=vocab_size, device=device, custom_params=merged_custom_params, - custom_logit_processors=merged_custom_logit_processors, + custom_logit_processor=merged_custom_logit_processor, ) # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge. 
@@ -259,7 +259,7 @@ def merge_custom_params(lhs: Dict[str, List[Any]], rhs: Dict[str, List[Any]]): return merged_dict @staticmethod - def merge_custom_logit_processors( + def merge_custom_logit_processor( lhs: Dict[str, torch.Tensor], rhs: Dict[str, torch.Tensor] ): keys = set(lhs.keys()).union(set(rhs.keys())) @@ -291,6 +291,6 @@ def merge_batch(self, other: "SamplingBatchInfo"): self.custom_params = SamplingBatchInfo.merge_custom_params( self.custom_params, other.custom_params ) - self.custom_logit_processors = SamplingBatchInfo.merge_custom_logit_processors( - self.custom_logit_processors, other.custom_logit_processors + self.custom_logit_processor = SamplingBatchInfo.merge_custom_logit_processor( + self.custom_logit_processor, other.custom_logit_processor ) From 34b00056f8a1d001f9764f7e9292c258298bf383 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 30 Dec 2024 03:27:42 +0000 Subject: [PATCH 05/33] fix bug Signed-off-by: Hongpeng Guo --- .../sglang/srt/sampling/sampling_batch_info.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 6f63ad1c308..e35f3001b1a 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -78,10 +78,12 @@ def from_schedule_batch( # Merge custom params as a dictionary for one batch of requests custom_params_list = [r.sampling_params.custom_params for r in reqs] - merged_custom_params = { - key: [d.get(key, None) for d in custom_params_list] - for key in set(key for d in custom_params_list for key in d) - } + merged_custom_params = {} + if custom_params_list: + merged_custom_params = { + key: [d.get(key, None) for d in custom_params_list] + for key in set(key for d in custom_params_list for key in d) + } # Merge the same type of customlogit processors together processor_dict = {} @@ -94,9 +96,9 @@ def from_schedule_batch( processor_dict[processor_str].append(i) merged_custom_logit_processor = { - processor_str: torch.zeros(len(reqs), dtype=torch.bool).scatter_( - 0, torch.tensor(true_indices), True - ) + processor_str: torch.zeros(len(reqs), dtype=torch.bool) + .scatter_(0, torch.tensor(true_indices), True) + .to(device, non_blocking=True) for processor_str, true_indices in processor_dict.items() } From 10622020fe6ba3c92f6387a6cd4378157a90af51 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 30 Dec 2024 03:48:12 +0000 Subject: [PATCH 06/33] remove merge custom_param logic from sglang Signed-off-by: Hongpeng Guo --- .../srt/sampling/sampling_batch_info.py | 24 ++----------------- 1 file changed, 2 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index e35f3001b1a..98c79611042 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -3,7 +3,7 @@ import dataclasses import logging import threading -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional import torch @@ -76,14 +76,7 @@ def from_schedule_batch( [r.sampling_params.min_p for r in reqs], dtype=torch.float ).to(device, non_blocking=True) - # Merge custom params as a dictionary for one batch of requests custom_params_list = [r.sampling_params.custom_params for r in reqs] - merged_custom_params = {} - if custom_params_list: - 
merged_custom_params = { - key: [d.get(key, None) for d in custom_params_list] - for key in set(key for d in custom_params_list for key in d) - } # Merge the same type of customlogit processors together processor_dict = {} @@ -249,17 +242,6 @@ def merge_bias_tensor( return None - @staticmethod - def merge_custom_params(lhs: Dict[str, List[Any]], rhs: Dict[str, List[Any]]): - keys = set(lhs.keys()).union(set(rhs.keys())) - merged_dict = {} - - for k in keys: - left_values = lhs.get(k, [None] * len(lhs)) - right_values = rhs.get(k, [None] * len(rhs)) - merged_dict[k] = left_values + right_values - return merged_dict - @staticmethod def merge_custom_logit_processor( lhs: Dict[str, torch.Tensor], rhs: Dict[str, torch.Tensor] @@ -290,9 +272,7 @@ def merge_batch(self, other: "SamplingBatchInfo"): self.logit_bias = SamplingBatchInfo.merge_bias_tensor( self.logit_bias, other.logit_bias, len(self), len(other), self.device ) - self.custom_params = SamplingBatchInfo.merge_custom_params( - self.custom_params, other.custom_params - ) + self.custom_params = self.custom_params + other.custom_params self.custom_logit_processor = SamplingBatchInfo.merge_custom_logit_processor( self.custom_logit_processor, other.custom_logit_processor ) From 1f16689e73cc356ecc7c84fdb5cc2d385f0309dc Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 30 Dec 2024 04:34:10 +0000 Subject: [PATCH 07/33] fix inconsistency Signed-off-by: Hongpeng Guo --- python/sglang/srt/managers/session_controller.py | 1 + python/sglang/srt/sampling/sampling_batch_info.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/session_controller.py b/python/sglang/srt/managers/session_controller.py index dc5a1b670ea..5e6d53107b7 100644 --- a/python/sglang/srt/managers/session_controller.py +++ b/python/sglang/srt/managers/session_controller.py @@ -56,6 +56,7 @@ def create_req(self, req: TokenizedGenerateReqInput, tokenizer): sampling_params=req.sampling_params, lora_path=req.lora_path, session_id=self.session_id, + custom_logit_processor=req.custom_logit_processor, ) if len(self.reqs) > 0: new_req.image_inputs = self.reqs[-1].image_inputs diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 98c79611042..09018e12df6 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -47,7 +47,7 @@ class SamplingBatchInfo: device: str = "cuda" # Custom Parameters - custom_params: Optional[Dict[str, Any]] = None + custom_params: Optional[List[Optional[Dict[str, Any]]]] = None # Custom Logit Processor custom_logit_processor: Optional[Dict[str, torch.Tensor]] = None @@ -104,7 +104,7 @@ def from_schedule_batch( is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs), vocab_size=vocab_size, device=device, - custom_params=merged_custom_params, + custom_params=custom_params_list, custom_logit_processor=merged_custom_logit_processor, ) # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge. 
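At this point in the series the end-to-end flow exists: a request carries a dill-serialized CustomLogitProcessor (sent as the string produced by `to_str()`) together with per-request `custom_params`, and the Sampler deserializes the processor and applies it only to the rows of the logits batch whose requests supplied it. The following client-side sketch illustrates that flow. It is written against the interface the later patches in this series converge on (a `__call__(logits, custom_param_list)` signature and a server launched with `--enable-custom-logit-processor`); the class name, token id, and server URL are illustrative placeholders, not part of the patches.

    import requests

    from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


    class BanTokenProcessor(CustomLogitProcessor):
        """Illustrative processor: forbid one token id per request."""

        def __call__(self, logits, custom_param_list):
            # custom_param_list holds one params dict per request routed to this processor.
            for i, params in enumerate(custom_param_list):
                logits[i, params["banned_token_id"]] = -float("inf")
            return logits


    payload = {
        "text": "Question: Is Paris the capital of France? Answer:",
        "sampling_params": {
            "temperature": 0.0,
            # Per-request arguments that the processor reads back out.
            "custom_params": {"banned_token_id": 5},
        },
        # Serialized with dill via CustomLogitProcessor.to_str().
        "custom_logit_processor": BanTokenProcessor().to_str(),
    }
    response = requests.post("http://localhost:30000/generate", json=payload)
    print(response.json()["text"])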
From b1bdfdfd8f7c5298134b340af8a2d6b674d56623 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 30 Dec 2024 06:39:46 +0000 Subject: [PATCH 08/33] update function definition Signed-off-by: Hongpeng Guo --- python/sglang/srt/sampling/custom_logit_processor.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/sampling/custom_logit_processor.py b/python/sglang/srt/sampling/custom_logit_processor.py index a4d449faf22..352f8b4d903 100644 --- a/python/sglang/srt/sampling/custom_logit_processor.py +++ b/python/sglang/srt/sampling/custom_logit_processor.py @@ -1,9 +1,9 @@ import json from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import dill -from torch import Tensor +import torch class CustomLogitProcessor(ABC): """Abstract base class for callable functions.""" @@ -14,10 +14,10 @@ class CustomLogitProcessor(ABC): @abstractmethod def __call__( self, - logits: Tensor, - batch_mask: List[bool], - custom_params: Dict[str, List[Any]], - ) -> Tensor: + logits: torch.Tensor, + batch_mask: torch.Tensor, + custom_params: Optional[List[Dict[str, Any]]] = None, + ) -> torch.Tensor: """Define the callable behavior.""" raise NotImplementedError From f4c79ecff6c934b1e1ff06d9b772fe06c51a0f56 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 30 Dec 2024 07:15:52 +0000 Subject: [PATCH 09/33] apply the processor on the sampler Signed-off-by: Hongpeng Guo --- python/sglang/srt/layers/sampler.py | 6 +++++- python/sglang/srt/sampling/custom_logit_processor.py | 1 - 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 5e0d2262182..b5488f56202 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -44,7 +44,11 @@ def forward( batch_mask, ) in sampling_info.custom_logit_processor.items(): processor = CustomLogitProcessor.from_str(processor_str) - logits = processor(logits, batch_mask, sampling_info.custom_params) + batch_indices = batch_mask.nonzero(as_tuple=True)[0] + logits[batch_mask] = processor( + logits[batch_mask], + [sampling_info.custom_params[i] for i in batch_indices], + ) if self.use_nan_detectioin and torch.any(torch.isnan(logits)): logger.warning("Detected errors during sampling!
NaN in the logits.") diff --git a/python/sglang/srt/sampling/custom_logit_processor.py b/python/sglang/srt/sampling/custom_logit_processor.py index 352f8b4d903..2b5917ba2e8 100644 --- a/python/sglang/srt/sampling/custom_logit_processor.py +++ b/python/sglang/srt/sampling/custom_logit_processor.py @@ -15,7 +15,6 @@ class CustomLogitProcessor(ABC): def __call__( self, logits: torch.Tensor, - batch_mask: torch.Tensor, custom_params: Optional[List[Dict[str, Any]]] = None, ) -> torch.Tensor: """Define the callable behavior.""" From c0a59051930dc00000e085a6138f42a5fde0987a Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 30 Dec 2024 07:16:16 +0000 Subject: [PATCH 10/33] apply the processor on the sampler Signed-off-by: Hongpeng Guo --- python/sglang/srt/layers/sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index b5488f56202..314698bef0a 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Optional, Union +from typing import Union import torch from torch import nn From 13a7bc6678f8711d7fb3b72158b77eb46b44e77a Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 30 Dec 2024 07:18:42 +0000 Subject: [PATCH 11/33] refine the custom logit processor Signed-off-by: Hongpeng Guo --- python/sglang/srt/sampling/custom_logit_processor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/sampling/custom_logit_processor.py b/python/sglang/srt/sampling/custom_logit_processor.py index 2b5917ba2e8..dd67ee1175c 100644 --- a/python/sglang/srt/sampling/custom_logit_processor.py +++ b/python/sglang/srt/sampling/custom_logit_processor.py @@ -9,7 +9,11 @@ class CustomLogitProcessor(ABC): """Abstract base class for callable functions.""" - required_args: List[str] + @property + @abstractmethod + def required_args(self) -> List[str]: + """List of required arguments for this processor.""" + raise NotImplementedError @abstractmethod def __call__( From 530127f2f2a7bf57055ea7cbe4dfe74ea1d39960 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 30 Dec 2024 07:24:31 +0000 Subject: [PATCH 12/33] update function arg name Signed-off-by: Hongpeng Guo --- python/sglang/srt/sampling/custom_logit_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/sampling/custom_logit_processor.py b/python/sglang/srt/sampling/custom_logit_processor.py index dd67ee1175c..e6731c789ab 100644 --- a/python/sglang/srt/sampling/custom_logit_processor.py +++ b/python/sglang/srt/sampling/custom_logit_processor.py @@ -19,7 +19,7 @@ def required_args(self) -> List[str]: def __call__( self, logits: torch.Tensor, - custom_params: Optional[List[Dict[str, Any]]] = None, + custom_param_list: Optional[List[Dict[str, Any]]] = None, ) -> torch.Tensor: """Define the callable behavior.""" raise NotImplementedError From 4a6c6e274f835ee8781c7e3f5c6f12f8bfabbf83 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 30 Dec 2024 07:35:11 +0000 Subject: [PATCH 13/33] update function arg name Signed-off-by: Hongpeng Guo --- python/sglang/srt/layers/sampler.py | 1 + python/sglang/srt/sampling/custom_logit_processor.py | 1 + 2 files changed, 2 insertions(+) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 314698bef0a..fc05d1b6363 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -48,6 +48,7 @@ def 
forward( logits[batch_mask] = processor( logits[batch_mask], [sampling_info.custom_params[i] for i in batch_indices], + sampling_info.device, ) if self.use_nan_detectioin and torch.any(torch.isnan(logits)): diff --git a/python/sglang/srt/sampling/custom_logit_processor.py b/python/sglang/srt/sampling/custom_logit_processor.py index e6731c789ab..5158e7e91cb 100644 --- a/python/sglang/srt/sampling/custom_logit_processor.py +++ b/python/sglang/srt/sampling/custom_logit_processor.py @@ -20,6 +20,7 @@ def __call__( self, logits: torch.Tensor, custom_param_list: Optional[List[Dict[str, Any]]] = None, + device: Optional[torch.device] = None, ) -> torch.Tensor: """Define the callable behavior.""" raise NotImplementedError From 7331e36f35c00be677130c4fe918752f458c6282 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 30 Dec 2024 08:30:46 +0000 Subject: [PATCH 14/33] add logger in the sampler Signed-off-by: Hongpeng Guo --- python/sglang/srt/layers/sampler.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index fc05d1b6363..6040956b31d 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -51,6 +51,10 @@ def forward( sampling_info.device, ) + logger.debug( + f"Custom logit processor {processor.__class__.__name__} is applied." + ) + if self.use_nan_detectioin and torch.any(torch.isnan(logits)): logger.warning("Detected errors during sampling! NaN in the logits.") logits = torch.where( From dfc1f445ca38df4e0482496787bde34d75030463 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 30 Dec 2024 09:25:13 +0000 Subject: [PATCH 15/33] fix merge conflicts Signed-off-by: Hongpeng Guo --- python/sglang/srt/managers/io_struct.py | 4 ++-- python/sglang/srt/managers/tokenizer_manager.py | 4 +--- python/sglang/srt/server.py | 2 +- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index acde11304f7..cf15759708c 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -71,10 +71,10 @@ class GenerateReqInput: session: Optional[ Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]] ] = None - # Custom logit processor (serialized function) - custom_logit_processor: Optional[Union[List[Optional[str]], Optional[str]]] = None # Session info for continual prompting session_params: Optional[Union[List[Dict], Dict]] = None + # Custom logit processor (serialized function) + custom_logit_processor: Optional[Union[List[Optional[str]], Optional[str]]] = None def normalize_batch_and_arguments(self): if ( diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 6988682a672..b7e9a59d2b6 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -294,10 +294,8 @@ async def _tokenize_one_request( obj.stream, lora_path=obj.lora_path, input_embeds=input_embeds, - session_id=session_id, - session_rid=session_rid, - custom_logit_processor=obj.custom_logit_processor, session_params=session_params, + custom_logit_processor=obj.custom_logit_processor, ) elif isinstance(obj, EmbeddingReqInput): tokenized_obj = TokenizedEmbeddingReqInput( diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index ff440fbc07e..3c5a366dd7c 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -717,7 +717,7 @@ def generate( 
logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, lora_path: Optional[List[Optional[str]]] = None, - custom_logit_processor: Optional[Union[str, List[str]]] = None, + custom_logit_processor: Optional[Union[List[str], str]] = None, stream: bool = False, ): obj = GenerateReqInput( From 4a7eb4fde730339b8cbfb404e5838f5cac28fc0d Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Mon, 30 Dec 2024 09:33:36 +0000 Subject: [PATCH 16/33] fix merge conflicts Signed-off-by: Hongpeng Guo --- python/sglang/srt/managers/io_struct.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index cf15759708c..c56078c0390 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -67,10 +67,6 @@ class GenerateReqInput: # LoRA related lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None - # Session id info for continual prompting - session: Optional[ - Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]] - ] = None # Session info for continual prompting session_params: Optional[Union[List[Dict], Dict]] = None # Custom logit processor (serialized function) From d8e648b91d12997aee0f3f7ef37cd0b05df9bb96 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Tue, 31 Dec 2024 07:41:43 +0000 Subject: [PATCH 17/33] resolve merge conflict Signed-off-by: Hongpeng Guo --- python/sglang/srt/sampling/sampling_batch_info.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 7382e48cc00..25df9372373 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -272,6 +272,10 @@ def merge_batch(self, other: "SamplingBatchInfo"): self.logit_bias = SamplingBatchInfo.merge_bias_tensor( self.logit_bias, other.logit_bias, len(self), len(other), self.device ) + self.custom_params = self.custom_params + other.custom_params + self.custom_logit_processor = SamplingBatchInfo.merge_custom_logit_processor( + self.custom_logit_processor, other.custom_logit_processor + ) def apply_logits_bias(self, logits: torch.Tensor): # Apply logit_bias @@ -293,7 +297,3 @@ def apply_logits_bias(self, logits: torch.Tensor): # Apply regex vocab_mask if self.vocab_mask is not None: self.apply_mask(logits=logits, vocab_mask=self.vocab_mask) - self.custom_params = self.custom_params + other.custom_params - self.custom_logit_processor = SamplingBatchInfo.merge_custom_logit_processor( - self.custom_logit_processor, other.custom_logit_processor - ) From 7aa3b7c5fadba5062ed1ce66ceab3b56654c6bb2 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Tue, 31 Dec 2024 09:08:10 +0000 Subject: [PATCH 18/33] remove required args Signed-off-by: Hongpeng Guo --- python/sglang/srt/sampling/custom_logit_processor.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/sglang/srt/sampling/custom_logit_processor.py b/python/sglang/srt/sampling/custom_logit_processor.py index 5158e7e91cb..d5460eb782e 100644 --- a/python/sglang/srt/sampling/custom_logit_processor.py +++ b/python/sglang/srt/sampling/custom_logit_processor.py @@ -9,12 +9,6 @@ class CustomLogitProcessor(ABC): """Abstract base class for callable functions.""" - @property - @abstractmethod - def required_args(self) -> List[str]: - """List of required arguments for this processor.""" - raise NotImplementedError - 
@abstractmethod def __call__( self, From 9f0a2874dd3fb37c86ba86395fb692a7893731de Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Tue, 31 Dec 2024 09:36:10 +0000 Subject: [PATCH 19/33] update unittest in test_srt_endpoint Signed-off-by: Hongpeng Guo --- test/srt/test_srt_endpoint.py | 46 +++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index 0fd71efcb0b..b87389abd8a 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -9,6 +9,7 @@ import numpy as np import requests +from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MODEL_NAME_FOR_TEST, @@ -248,6 +249,51 @@ def test_logprob_grammar(self): self.assertTrue(all(x is not None for x in logprobs)) + def test_custom_logit_processor(self): + """Test custom logit processor with custom params.""" + + class AddLogitProcessor(CustomLogitProcessor): + def __call__(self, logits, custom_param_list, device): + import torch + + assert logits.shape[0] == len(custom_param_list) + key = "arg1" + merged_params = { + key: torch.tensor( + [ + custom_param_list[i][key] + for i in range(len(custom_param_list)) + ], + dtype=torch.float, + ).to(device=device, non_blocking=True) + } + return logits + merged_params[key] + + prompts = "Question: Is Paris the Capital of France? Answer:" + + # Base case json data to be posted to the server. + base_json = { + "text": prompts, + "sampling_params": {"temperature": 1.0}, + } + + # Custom json data with custom logit processor and params. + custom_json = base_json.copy() + custom_json["custom_logit_processor"] = AddLogitProcessor().to_str() + custom_json["sampling_params"]["custom_params"] = {"arg1": 5.0} + + base_response = requests.post( + self.base_url + "/generate", + json=base_json, + ).json()["text"] + + custom_response = requests.post( + self.base_url + "/generate", + json=custom_json, + ).json()["text"] + + self.assertNotEqual(base_response, custom_response) + def test_get_server_info(self): response = requests.get(self.base_url + "/get_server_info") response_json = response.json() From e0b795a37276ea0a74249c536e31910d3165d103 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Thu, 9 Jan 2025 08:02:08 +0000 Subject: [PATCH 20/33] wip handle comments Signed-off-by: Hongpeng Guo --- python/sglang/srt/layers/sampler.py | 45 ++++++++----- .../srt/sampling/sampling_batch_info.py | 67 ++++++++++++------- test/srt/test_srt_endpoint.py | 45 +++++++------ 3 files changed, 94 insertions(+), 63 deletions(-) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 2cd295cd279..c35495a35c9 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -37,22 +37,8 @@ def forward( logits = logits_output.next_token_logits # Apply the custom logit processors if registered in the sampling info. - if sampling_info.custom_logit_processor is not None: - for ( - processor_str, - batch_mask, - ) in sampling_info.custom_logit_processor.items(): - processor = CustomLogitProcessor.from_str(processor_str) - batch_indices = batch_mask.nonzero(as_tuple=True)[0] - logits[batch_mask] = processor( - logits[batch_mask], - [sampling_info.custom_params[i] for i in batch_indices], - sampling_info.device, - ) - - logger.debug( - f"Custom logit processor {processor.__class__.__name__} is applied." 
- ) + if sampling_info.has_custom_logit_processor: + self._apply_custom_logit_processor(logits, sampling_info) if self.use_nan_detectioin and torch.any(torch.isnan(logits)): logger.warning("Detected errors during sampling! NaN in the logits.") @@ -140,6 +126,33 @@ def forward( return batch_next_token_ids + def _apply_custom_logit_processor( + self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo + ): + """Apply custom logit processors to the logits. + This function will modify the logits in-place.""" + + for ( + processor_str, + batch_mask, + ) in sampling_batch_info.custom_logit_processor.items(): + # Get the processor from the string representation + processor = CustomLogitProcessor.from_str(processor_str) + + # Get the batch indices that need to be processed + batch_indices = batch_mask.nonzero(as_tuple=True)[0] + + # Apply the processor to the logits + logits[batch_mask] = processor( + logits[batch_mask], + [sampling_batch_info.custom_params[i] for i in batch_indices], + sampling_batch_info.device, + ) + + logger.debug( + f"Custom logit processor {processor.__class__.__name__} is applied." + ) + def top_k_top_p_min_p_sampling_from_probs_torch( probs: torch.Tensor, diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index cb7a0f5c774..08b5d539649 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -30,6 +30,9 @@ class SamplingBatchInfo: # Dispatch in CUDA graph need_min_p_sampling: bool + # Whether any request has custom logit processor + has_custom_logit_processor: bool + # Bias Tensors vocab_size: int grammars: Optional[List] = None @@ -76,24 +79,28 @@ def from_schedule_batch( [r.sampling_params.min_p for r in reqs], dtype=torch.float ).to(device, non_blocking=True) - custom_params_list = [r.sampling_params.custom_params for r in reqs] - - # Merge the same type of customlogit processors together - processor_dict = {} - for i, r in enumerate(reqs): - if r.custom_logit_processor is None: - continue - processor_str = r.custom_logit_processor - if processor_str not in processor_dict: - processor_dict[processor_str] = [] - processor_dict[processor_str].append(i) - - merged_custom_logit_processor = { - processor_str: torch.zeros(len(reqs), dtype=torch.bool) - .scatter_(0, torch.tensor(true_indices), True) - .to(device, non_blocking=True) - for processor_str, true_indices in processor_dict.items() - } + # Check if any request has custom logit processor + has_custom_logit_processor = any(r.custom_logit_processor for r in reqs) + + if has_custom_logit_processor: + # Merge the same type of custom logit processors together + processor_dict = {} + for i, r in enumerate(reqs): + if r.custom_logit_processor is None: + continue + processor_str = r.custom_logit_processor + if processor_str not in processor_dict: + processor_dict[processor_str] = [] + processor_dict[processor_str].append(i) + + merged_custom_logit_processor = { + processor_str: torch.zeros(len(reqs), dtype=torch.bool) + .scatter_(0, torch.tensor(true_indices), True) + .to(device, non_blocking=True) + for processor_str, true_indices in processor_dict.items() + } + else: + merged_custom_logit_processor = {} ret = cls( temperatures=temperatures, @@ -102,9 +109,10 @@ def from_schedule_batch( min_ps=min_ps, need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs), is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs), + has_custom_logit_processor=has_custom_logit_processor, 
vocab_size=vocab_size, device=device, - custom_params=custom_params_list, + custom_params=[r.sampling_params.custom_params for r in reqs], custom_logit_processor=merged_custom_logit_processor, ) # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge. @@ -244,14 +252,18 @@ def merge_bias_tensor( @staticmethod def merge_custom_logit_processor( - lhs: Dict[str, torch.Tensor], rhs: Dict[str, torch.Tensor] + lhs: Dict[str, torch.Tensor], rhs: Dict[str, torch.Tensor], device: str ): keys = set(lhs.keys()).union(set(rhs.keys())) merged_dict = {} for k in keys: - left_values = lhs.get(k, torch.zeros(len(lhs), dtype=torch.bool)) - right_values = rhs.get(k, torch.zeros(len(rhs), dtype=torch.bool)) + left_values = lhs.get( + k, torch.zeros(len(lhs), dtype=torch.bool, device=device) + ) + right_values = rhs.get( + k, torch.zeros(len(rhs), dtype=torch.bool, device=device) + ) merged_dict[k] = torch.cat([left_values, right_values]) return merged_dict @@ -273,9 +285,14 @@ def merge_batch(self, other: "SamplingBatchInfo"): self.logit_bias, other.logit_bias, len(self), len(other), self.device ) self.custom_params = self.custom_params + other.custom_params - self.custom_logit_processor = SamplingBatchInfo.merge_custom_logit_processor( - self.custom_logit_processor, other.custom_logit_processor - ) + if self.has_custom_logit_processor or other.has_custom_logit_processor: + self.custom_logit_processor = ( + SamplingBatchInfo.merge_custom_logit_processor( + self.custom_logit_processor, + other.custom_logit_processor, + self.device, + ) + ) self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling def apply_logits_bias(self, logits: torch.Tensor): diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index b87389abd8a..2a8e072c629 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -252,47 +252,48 @@ def test_logprob_grammar(self): def test_custom_logit_processor(self): """Test custom logit processor with custom params.""" - class AddLogitProcessor(CustomLogitProcessor): + class DummyLogitProcessor(CustomLogitProcessor): + """A dummy logit processor that changes the logits into a tensor of + the same shape with a single value input from the custom params. + """ + def __call__(self, logits, custom_param_list, device): import torch assert logits.shape[0] == len(custom_param_list) - key = "arg1" - merged_params = { - key: torch.tensor( - [ - custom_param_list[i][key] - for i in range(len(custom_param_list)) - ], - dtype=torch.float, - ).to(device=device, non_blocking=True) - } - return logits + merged_params[key] + key = "value" + + merged_params = torch.tensor( + [custom_param_list[i][key] for i in range(len(custom_param_list))], + dtype=torch.float, + ).to(device=device, non_blocking=True) + + return merged_params.unsqueeze(1) * torch.ones_like(logits) prompts = "Question: Is Paris the Capital of France? Answer:" # Base case json data to be posted to the server. base_json = { "text": prompts, - "sampling_params": {"temperature": 1.0}, + "sampling_params": {"temperature": 0.0}, + "return_logprob": True, } # Custom json data with custom logit processor and params. 
custom_json = base_json.copy() - custom_json["custom_logit_processor"] = AddLogitProcessor().to_str() - custom_json["sampling_params"]["custom_params"] = {"arg1": 5.0} - - base_response = requests.post( - self.base_url + "/generate", - json=base_json, - ).json()["text"] + custom_json["custom_logit_processor"] = DummyLogitProcessor().to_str() + custom_json["sampling_params"]["custom_params"] = {"value": 5.0} custom_response = requests.post( self.base_url + "/generate", json=custom_json, - ).json()["text"] + ).json() + + output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"] + sampled_tokens = [x[1] for x in output_token_logprobs] - self.assertNotEqual(base_response, custom_response) + # The logit processor should always sample the same token as the logits is deterministic. + self.assertEqual(len(set(sampled_tokens)), 1) def test_get_server_info(self): response = requests.get(self.base_url + "/get_server_info") From 785cf6a2bcb47b491e6f698a4683e3c39411e9bf Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Fri, 17 Jan 2025 20:13:07 +0000 Subject: [PATCH 21/33] remove the required device arg in the logit processor Signed-off-by: Hongpeng Guo --- python/sglang/srt/layers/sampler.py | 1 - python/sglang/srt/sampling/custom_logit_processor.py | 1 - test/srt/test_srt_endpoint.py | 4 ++-- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index c35495a35c9..cfab22aa450 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -146,7 +146,6 @@ def _apply_custom_logit_processor( logits[batch_mask] = processor( logits[batch_mask], [sampling_batch_info.custom_params[i] for i in batch_indices], - sampling_batch_info.device, ) logger.debug( diff --git a/python/sglang/srt/sampling/custom_logit_processor.py b/python/sglang/srt/sampling/custom_logit_processor.py index d5460eb782e..a9fcd39b595 100644 --- a/python/sglang/srt/sampling/custom_logit_processor.py +++ b/python/sglang/srt/sampling/custom_logit_processor.py @@ -14,7 +14,6 @@ def __call__( self, logits: torch.Tensor, custom_param_list: Optional[List[Dict[str, Any]]] = None, - device: Optional[torch.device] = None, ) -> torch.Tensor: """Define the callable behavior.""" raise NotImplementedError diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index 2a8e072c629..9d3f27c77fd 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -257,7 +257,7 @@ class DummyLogitProcessor(CustomLogitProcessor): the same shape with a single value input from the custom params. 
""" - def __call__(self, logits, custom_param_list, device): + def __call__(self, logits, custom_param_list): import torch assert logits.shape[0] == len(custom_param_list) @@ -266,7 +266,7 @@ def __call__(self, logits, custom_param_list, device): merged_params = torch.tensor( [custom_param_list[i][key] for i in range(len(custom_param_list))], dtype=torch.float, - ).to(device=device, non_blocking=True) + ).to(device=logits.device, non_blocking=True) return merged_params.unsqueeze(1) * torch.ones_like(logits) From b50f62ff57338ef83ea6bd955b207f98f9eb4cce Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Fri, 17 Jan 2025 20:38:52 +0000 Subject: [PATCH 22/33] fix unnittest to always sample a given tokenid Signed-off-by: Hongpeng Guo --- test/srt/test_srt_endpoint.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index 9d3f27c77fd..d0f0587435c 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -252,23 +252,23 @@ def test_logprob_grammar(self): def test_custom_logit_processor(self): """Test custom logit processor with custom params.""" - class DummyLogitProcessor(CustomLogitProcessor): - """A dummy logit processor that changes the logits into a tensor of - the same shape with a single value input from the custom params. + custom_params = {"token_id": 5} + + class DeterministicLogitProcessor(CustomLogitProcessor): + """A dummy logit processor that changes the logits to always + sample the given token id. """ def __call__(self, logits, custom_param_list): - import torch - assert logits.shape[0] == len(custom_param_list) - key = "value" - - merged_params = torch.tensor( - [custom_param_list[i][key] for i in range(len(custom_param_list))], - dtype=torch.float, - ).to(device=logits.device, non_blocking=True) + key = "token_id" - return merged_params.unsqueeze(1) * torch.ones_like(logits) + for i, param_dict in enumerate(custom_param_list): + # Mask all other tokens + logits[i, :] = -float("inf") + # Assign highest probability to the specified token + logits[i, param_dict[key]] = 0.0 + return logits prompts = "Question: Is Paris the Capital of France? Answer:" @@ -281,8 +281,8 @@ def __call__(self, logits, custom_param_list): # Custom json data with custom logit processor and params. custom_json = base_json.copy() - custom_json["custom_logit_processor"] = DummyLogitProcessor().to_str() - custom_json["sampling_params"]["custom_params"] = {"value": 5.0} + custom_json["custom_logit_processor"] = DeterministicLogitProcessor().to_str() + custom_json["sampling_params"]["custom_params"] = custom_params custom_response = requests.post( self.base_url + "/generate", @@ -292,8 +292,8 @@ def __call__(self, logits, custom_param_list): output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"] sampled_tokens = [x[1] for x in output_token_logprobs] - # The logit processor should always sample the same token as the logits is deterministic. - self.assertEqual(len(set(sampled_tokens)), 1) + # The logit processor should always sample the given token as the logits is deterministic. 
+ self.assertTrue(all(x == custom_params["token_id"] for x in sampled_tokens)) def test_get_server_info(self): response = requests.get(self.base_url + "/get_server_info") response_json = response.json() From 6bb05aae148734f5d5389bc6d6652a8f025992a0 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Fri, 17 Jan 2025 21:05:25 +0000 Subject: [PATCH 23/33] cache deserialized custom logit processor Signed-off-by: Hongpeng Guo --- python/sglang/srt/layers/sampler.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index cfab22aa450..561b9e5c873 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -1,5 +1,5 @@ import logging -from typing import List +from typing import Dict, List import torch from torch import nn @@ -26,6 +26,7 @@ class Sampler(nn.Module): def __init__(self): super().__init__() self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"] + self.custom_logit_processor_cache: Dict[int, CustomLogitProcessor] = {} def forward( self, @@ -136,8 +137,13 @@ def _apply_custom_logit_processor( processor_str, batch_mask, ) in sampling_batch_info.custom_logit_processor.items(): - # Get the processor from the string representation - processor = CustomLogitProcessor.from_str(processor_str) + # Get the processor from the string representation or the local cache + processor_hash = hash(processor_str) + if processor_hash in self.custom_logit_processor_cache: + processor = self.custom_logit_processor_cache[processor_hash] + else: + processor = CustomLogitProcessor.from_str(processor_str) + self.custom_logit_processor_cache[processor_hash] = processor # Get the batch indices that need to be processed batch_indices = batch_mask.nonzero(as_tuple=True)[0] From 9bca98a4efb548ff85c4808c06179e8ed83891e4 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Fri, 17 Jan 2025 21:25:35 +0000 Subject: [PATCH 24/33] move custom params init inside condition Signed-off-by: Hongpeng Guo --- .../srt/sampling/sampling_batch_info.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 4bdbf5a62c9..d1bb539f173 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -105,8 +105,10 @@ def from_schedule_batch( .to(device, non_blocking=True) for processor_str, true_indices in processor_dict.items() } + custom_params = [r.sampling_params.custom_params for r in reqs] else: merged_custom_logit_processor = {} + custom_params = None ret = cls( temperatures=temperatures, @@ -118,7 +120,7 @@ def from_schedule_batch( has_custom_logit_processor=has_custom_logit_processor, vocab_size=vocab_size, device=device, - custom_params=[r.sampling_params.custom_params for r in reqs], + custom_params=custom_params, custom_logit_processor=merged_custom_logit_processor, ) # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
@@ -290,16 +292,25 @@ def merge_batch(self, other: "SamplingBatchInfo"): self.logit_bias = SamplingBatchInfo.merge_bias_tensor( self.logit_bias, other.logit_bias, len(self), len(other), self.device ) - self.custom_params = self.custom_params + other.custom_params + self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling + + # Merge the custom logit processors and custom params lists if self.has_custom_logit_processor or other.has_custom_logit_processor: + # Merge the custom logit processors self.custom_logit_processor = ( SamplingBatchInfo.merge_custom_logit_processor( - self.custom_logit_processor, - other.custom_logit_processor, + self.custom_logit_processor or {}, + other.custom_logit_processor or {}, self.device, ) ) - self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling + # Merge the custom params lists + self.custom_params = self.custom_params or [None] * len(self) + other.custom_params = other.custom_params or [None] * len(other) + self.custom_params.extend(other.custom_params) + + # Set the flag to True if any of the two has custom logit processor + self.has_custom_logit_processor = True def apply_logits_bias(self, logits: torch.Tensor): # Apply logit_bias From 94d629ad1728c93db7230f99d96144041e1179af Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Fri, 17 Jan 2025 23:27:11 +0000 Subject: [PATCH 25/33] add a flag to turn this feature off by default Signed-off-by: Hongpeng Guo --- python/sglang/srt/managers/schedule_batch.py | 27 +++++++++++++++++++- python/sglang/srt/managers/scheduler.py | 1 + python/sglang/srt/server_args.py | 8 ++++++ test/srt/test_srt_endpoint.py | 5 +++- 4 files changed, 39 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index d907d4253dc..001e0a2639e 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -212,6 +212,7 @@ def __init__( return_logprob: bool = False, top_logprobs_num: int = 0, stream: bool = False, + enable_custom_logit_processor: bool = False, origin_input_ids_unpadded: Optional[Tuple[int]] = None, lora_path: Optional[str] = None, input_embeds: Optional[List[List[float]]] = None, @@ -237,7 +238,10 @@ def __init__( # Sampling info self.sampling_params = sampling_params self.lora_path = lora_path - self.custom_logit_processor = custom_logit_processor + self._set_custom_logit_processor( + custom_logit_processor, + enable_custom_logit_processor, + ) # Memory pool info self.req_pool_idx = None @@ -316,6 +320,27 @@ def __init__( # The number of cached tokens, that were already cached in the KV cache self.cached_tokens = 0 + def _set_custom_logit_processor( + self, custom_logit_processor: Optional[str], enable_custom_logit_processor: bool + ) -> Optional[str]: + """ + Validate and set the custom logit processor. Set to None if the server is not + configured to enable this feature. + """ + if not custom_logit_processor: + self.custom_logit_processor = None + return + + if enable_custom_logit_processor: + self.custom_logit_processor = custom_logit_processor + else: + logger.warning( + "The SGLang server is not configured to enable custom logit processor." + "The custom logit processor passed in will be ignored." + "Please set --enable-custom-logits-processor to enable this feature." 
+ ) + self.custom_logit_processor = None + def extend_image_inputs(self, image_inputs): if self.image_inputs is None: self.image_inputs = image_inputs diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e875b3b85e1..ff8b59b79cb 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -640,6 +640,7 @@ def handle_generate_request( return_logprob=recv_req.return_logprob, top_logprobs_num=recv_req.top_logprobs_num, stream=recv_req.stream, + enable_custom_logit_processor=self.server_args.enable_custom_logit_processor, lora_path=recv_req.lora_path, input_embeds=recv_req.input_embeds, custom_logit_processor=recv_req.custom_logit_processor, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 052e316b7c4..6dd0b945654 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -159,6 +159,9 @@ class ServerArgs: enable_memory_saver: bool = False allow_auto_truncate: bool = False + # Custom logit processor + enable_custom_logit_processor: bool = False + def __post_init__(self): # Set missing default values if self.tokenizer_path is None: @@ -865,6 +868,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.", ) + parser.add_argument( + "--enable-custom-logit-processor", + action="store_true", + help="Enable users to pass custom logit processors to the server (disabled by default for security)", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index d0f0587435c..fd7ea4fe87a 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -25,7 +25,10 @@ def setUpClass(cls): cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( - cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=("--enable-custom-logit-processor",), ) @classmethod From 55d52aaa68a9be9ed15fa8fbaa02b46b9e80b1c6 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Fri, 17 Jan 2025 23:35:48 +0000 Subject: [PATCH 26/33] update doc string Signed-off-by: Hongpeng Guo --- python/sglang/srt/managers/io_struct.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 02bc96e334c..9b1cc368dd6 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -250,6 +250,7 @@ class TokenizedGenerateReqInput: session_params: Optional[SessionParams] = None # Custom logit processor (serialized function) + # TODO (hpguo): Add an example and update doc string here custom_logit_processor: Optional[str] = None From 42407cd38d12caa21fbbadf070337bcc168d178b Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sat, 18 Jan 2025 14:45:24 +0000 Subject: [PATCH 27/33] moving clp deserialize into sampling_batch_info Signed-off-by: Hongpeng Guo --- python/sglang/srt/layers/sampler.py | 13 +---- .../srt/sampling/sampling_batch_info.py | 52 ++++++++++++++----- 2 files changed, 41 insertions(+), 24 deletions(-) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 561b9e5c873..e8b25da0704 100644 --- a/python/sglang/srt/layers/sampler.py +++ 
b/python/sglang/srt/layers/sampler.py @@ -26,7 +26,6 @@ class Sampler(nn.Module): def __init__(self): super().__init__() self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"] - self.custom_logit_processor_cache: Dict[int, CustomLogitProcessor] = {} def forward( self, @@ -133,18 +132,10 @@ def _apply_custom_logit_processor( """Apply custom logit processors to the logits. This function will modify the logits in-place.""" - for ( - processor_str, + for _, ( + processor, batch_mask, ) in sampling_batch_info.custom_logit_processor.items(): - # Get the processor from the string representation or the local cache - processor_hash = hash(processor_str) - if processor_hash in self.custom_logit_processor_cache: - processor = self.custom_logit_processor_cache[processor_hash] - else: - processor = CustomLogitProcessor.from_str(processor_str) - self.custom_logit_processor_cache[processor_hash] = processor - # Get the batch indices that need to be processed batch_indices = batch_mask.nonzero(as_tuple=True)[0] diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index d1bb539f173..67771dbd495 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -3,7 +3,7 @@ import dataclasses import logging import threading -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple import torch @@ -14,6 +14,7 @@ from sgl_kernel import sampling_scaling_penalties import sglang.srt.sampling.penaltylib as penaltylib +from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor logger = logging.getLogger(__name__) @@ -59,7 +60,9 @@ class SamplingBatchInfo: custom_params: Optional[List[Optional[Dict[str, Any]]]] = None # Custom Logit Processor - custom_logit_processor: Optional[Dict[str, torch.Tensor]] = None + custom_logit_processor: Optional[ + Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]] + ] = None @classmethod def from_schedule_batch( @@ -100,9 +103,14 @@ def from_schedule_batch( processor_dict[processor_str].append(i) merged_custom_logit_processor = { - processor_str: torch.zeros(len(reqs), dtype=torch.bool) - .scatter_(0, torch.tensor(true_indices), True) - .to(device, non_blocking=True) + hash(processor_str): ( + # The deserialized custom logit processor object + CustomLogitProcessor.from_str(processor_str), + # The mask tensor for the requests that use this custom logit processor + torch.zeros(len(reqs), dtype=torch.bool) + .scatter_(0, torch.tensor(true_indices), True) + .to(device, non_blocking=True), + ) for processor_str, true_indices in processor_dict.items() } custom_params = [r.sampling_params.custom_params for r in reqs] @@ -260,19 +268,35 @@ def merge_bias_tensor( @staticmethod def merge_custom_logit_processor( - lhs: Dict[str, torch.Tensor], rhs: Dict[str, torch.Tensor], device: str + lhs: Optional[Dict[str, torch.Tensor]], + rhs: Optional[Dict[str, torch.Tensor]], + bs1: int, + bs2: int, + device: str, ): + if lhs is None and rhs is None: + return None + lhs, rhs = lhs or {}, rhs or {} + keys = set(lhs.keys()).union(set(rhs.keys())) merged_dict = {} for k in keys: - left_values = lhs.get( - k, torch.zeros(len(lhs), dtype=torch.bool, device=device) + # Get the logit processor object + processor = lhs[k][0] if k in lhs else rhs[k][0] + # Get and merge the mask tensors from the two dicts + left_mask = ( + lhs[k][1] + if k in lhs + else torch.zeros(bs1, 
dtype=torch.bool, device=device) ) - right_values = rhs.get( - k, torch.zeros(len(rhs), dtype=torch.bool, device=device) + right_mask = ( + rhs[k][1] + if k in rhs + else torch.zeros(bs2, dtype=torch.bool, device=device) ) - merged_dict[k] = torch.cat([left_values, right_values]) + merged_dict[k] = (processor, torch.cat([left_mask, right_mask])) + return merged_dict def merge_batch(self, other: "SamplingBatchInfo"): @@ -299,8 +323,10 @@ def merge_batch(self, other: "SamplingBatchInfo"): # Merge the custom logit processors self.custom_logit_processor = ( SamplingBatchInfo.merge_custom_logit_processor( - self.custom_logit_processor or {}, - other.custom_logit_processor or {}, + self.custom_logit_processor, + other.custom_logit_processor, + len(self), + len(other), self.device, ) ) From 61967f1fcf3c8dafb5f60df20efbb411aa394f67 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sun, 19 Jan 2025 02:20:49 +0000 Subject: [PATCH 28/33] making default clp being None in SamplingBatchInfo Signed-off-by: Hongpeng Guo --- python/sglang/srt/sampling/sampling_batch_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 67771dbd495..38f98cb8d7d 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -115,7 +115,7 @@ def from_schedule_batch( } custom_params = [r.sampling_params.custom_params for r in reqs] else: - merged_custom_logit_processor = {} + merged_custom_logit_processor = None custom_params = None ret = cls( From bb4771e8ab3acb71ae1cf931df2c2f469639efe4 Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sun, 19 Jan 2025 03:02:51 +0000 Subject: [PATCH 29/33] reorg Req taking clp Signed-off-by: Hongpeng Guo --- python/sglang/srt/managers/schedule_batch.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index b3842b2d0fb..b0bdc79e907 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -340,19 +340,16 @@ def _set_custom_logit_processor( Validate and set the custom logit processor. Set to None if the server is not configured to enable this feature. """ - if not custom_logit_processor: + if not enable_custom_logit_processor: + if custom_logit_processor: + logger.warning( + "The SGLang server is not configured to enable custom logit processor." + "The custom logit processor passed in will be ignored." + "Please set --enable-custom-logits-processor to enable this feature." + ) self.custom_logit_processor = None - return - - if enable_custom_logit_processor: - self.custom_logit_processor = custom_logit_processor else: - logger.warning( - "The SGLang server is not configured to enable custom logit processor." - "The custom logit processor passed in will be ignored." - "Please set --enable-custom-logits-processor to enable this feature." 
- ) - self.custom_logit_processor = None + self.custom_logit_processor = custom_logit_processor def extend_image_inputs(self, image_inputs): if self.image_inputs is None: From d96b9a8e450ab14964f9f65d438b6b2c2e25418f Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sun, 19 Jan 2025 07:03:34 +0000 Subject: [PATCH 30/33] make flag check into scheduler Signed-off-by: Hongpeng Guo --- python/sglang/srt/managers/schedule_batch.py | 24 +------------------- python/sglang/srt/managers/scheduler.py | 16 +++++++++++-- 2 files changed, 15 insertions(+), 25 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 2a8bce7262c..4e7a5c61bba 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -225,7 +225,6 @@ def __init__( return_logprob: bool = False, top_logprobs_num: int = 0, stream: bool = False, - enable_custom_logit_processor: bool = False, origin_input_ids_unpadded: Optional[Tuple[int]] = None, lora_path: Optional[str] = None, input_embeds: Optional[List[List[float]]] = None, @@ -251,10 +250,7 @@ def __init__( # Sampling info self.sampling_params = sampling_params self.lora_path = lora_path - self._set_custom_logit_processor( - custom_logit_processor, - enable_custom_logit_processor, - ) + self.custom_logit_processor = custom_logit_processor # Memory pool info self.req_pool_idx = None @@ -333,24 +329,6 @@ def __init__( # The number of cached tokens, that were already cached in the KV cache self.cached_tokens = 0 - def _set_custom_logit_processor( - self, custom_logit_processor: Optional[str], enable_custom_logit_processor: bool - ) -> Optional[str]: - """ - Validate and set the custom logit processor. Set to None if the server is not - configured to enable this feature. - """ - if not enable_custom_logit_processor: - if custom_logit_processor: - logger.warning( - "The SGLang server is not configured to enable custom logit processor." - "The custom logit processor passed in will be ignored." - "Please set --enable-custom-logits-processor to enable this feature." - ) - self.custom_logit_processor = None - else: - self.custom_logit_processor = custom_logit_processor - def extend_image_inputs(self, image_inputs): if self.image_inputs is None: self.image_inputs = image_inputs diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 163a8917d30..9ef1a75d628 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -631,6 +631,19 @@ def handle_generate_request( fake_input_ids = [1] * seq_length recv_req.input_ids = fake_input_ids + # Handle custom logit processor passed to the request + custom_logit_processor = recv_req.custom_logit_processor + if ( + not self.server_args.enable_custom_logit_processor + and custom_logit_processor is not None + ): + logger.warning( + "The SGLang server is not configured to enable custom logit processor." + "The custom logit processor passed in will be ignored." + "Please set --enable-custom-logits-processor to enable this feature." 
+ ) + custom_logit_processor = None + req = Req( recv_req.rid, recv_req.input_text, @@ -639,10 +652,9 @@ def handle_generate_request( return_logprob=recv_req.return_logprob, top_logprobs_num=recv_req.top_logprobs_num, stream=recv_req.stream, - enable_custom_logit_processor=self.server_args.enable_custom_logit_processor, lora_path=recv_req.lora_path, input_embeds=recv_req.input_embeds, - custom_logit_processor=recv_req.custom_logit_processor, + custom_logit_processor=custom_logit_processor, eos_token_ids=self.model_config.hf_eos_token_id, ) req.tokenizer = self.tokenizer From 320532667c42e32501bd7803b3c6a9799e42d02b Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sun, 19 Jan 2025 07:08:17 +0000 Subject: [PATCH 31/33] add lru cahce to clp class Signed-off-by: Hongpeng Guo --- .../sglang/srt/sampling/custom_logit_processor.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/sampling/custom_logit_processor.py b/python/sglang/srt/sampling/custom_logit_processor.py index a9fcd39b595..a64b2498f23 100644 --- a/python/sglang/srt/sampling/custom_logit_processor.py +++ b/python/sglang/srt/sampling/custom_logit_processor.py @@ -1,11 +1,21 @@ import json from abc import ABC, abstractmethod +from functools import lru_cache from typing import Any, Dict, List, Optional import dill import torch +@lru_cache(maxsize=None) +def _cache_from_str(json_str: str): + """Deserialize a json string to a Callable object. + This function is cached to avoid redundant deserialization. + """ + data = json.loads(json_str) + return dill.loads(bytes.fromhex(data["callable"])) + + class CustomLogitProcessor(ABC): """Abstract base class for callable functions.""" @@ -25,5 +35,4 @@ def to_str(self) -> str: @classmethod def from_str(cls, json_str: str): """Deserialize a callable function from a JSON string.""" - data = json.loads(json_str) - return dill.loads(bytes.fromhex(data["callable"])) + return _cache_from_str(json_str) From 252672df98ea6ba9c101e289d63997a4671b9afd Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sun, 19 Jan 2025 07:21:47 +0000 Subject: [PATCH 32/33] cover unittest case for batch Signed-off-by: Hongpeng Guo --- .../srt/sampling/sampling_batch_info.py | 22 +++++++++++++++++++ test/srt/test_srt_endpoint.py | 15 +++++++++++-- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 38f98cb8d7d..d4c5c32386a 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -229,6 +229,8 @@ def update_regex_vocab_mask(self): def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor): self.penalizer_orchestrator.filter(unfinished_indices, new_indices) + if self.has_custom_logit_processor: + self._filter_batch_custom_logit_processor(unfinished_indices, new_indices) for item in [ "temperatures", @@ -241,6 +243,26 @@ def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor) if value is not None: # logit_bias can be None setattr(self, item, value[new_indices]) + def _filter_batch_custom_logit_processor( + self, unfinished_indices: List[int], new_indices: torch.Tensor + ): + """Filter the custom logit processor and custom params""" + if not self.custom_logit_processor: + return + self.custom_logit_processor = { + k: (p, mask[new_indices]) + for k, (p, mask) in self.custom_logit_processor.items() + if any( + mask[new_indices] + ) # ignore the custom logit 
processor whose mask is all False + } + self.custom_params = [self.custom_params[i] for i in unfinished_indices] + + if len(self) == 0: + self.custom_logit_processor = None + self.custom_params = None + self.has_custom_logit_processor = False + @staticmethod def merge_bias_tensor( lhs: torch.Tensor, diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index fd7ea4fe87a..f25cd66edf6 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -5,6 +5,7 @@ import json import unittest +from concurrent.futures import ThreadPoolExecutor import numpy as np import requests @@ -252,10 +253,10 @@ def test_logprob_grammar(self): self.assertTrue(all(x is not None for x in logprobs)) - def test_custom_logit_processor(self): + def run_custom_logit_processor(self, target_token_id: int): """Test custom logit processor with custom params.""" - custom_params = {"token_id": 5} + custom_params = {"token_id": target_token_id} class DeterministicLogitProcessor(CustomLogitProcessor): """A dummy logit processor that changes the logits to always @@ -298,6 +299,16 @@ def __call__(self, logits, custom_param_list): # The logit processor should always sample the given token as the logits is deterministic. self.assertTrue(all(x == custom_params["token_id"] for x in sampled_tokens)) + def test_custom_logit_processor(self): + """Test custom logit processor with a single target token id.""" + self.run_custom_logit_processor(target_token_id=5) + + def test_custom_logit_processor_batch(self): + """Test custom logit processor with multiple target token ids.""" + target_token_ids = list(range(32)) + with ThreadPoolExecutor(len(target_token_ids)) as executor: + list(executor.map(self.run_custom_logit_processor, target_token_ids)) + def test_get_server_info(self): response = requests.get(self.base_url + "/get_server_info") response_json = response.json() From dd63b2f8cb6c9b396d7aae996cc7a0762092129e Mon Sep 17 00:00:00 2001 From: Hongpeng Guo Date: Sun, 19 Jan 2025 07:24:20 +0000 Subject: [PATCH 33/33] improve doc str Signed-off-by: Hongpeng Guo --- test/srt/test_srt_endpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index f25cd66edf6..7afdc9bf41c 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -300,11 +300,11 @@ def __call__(self, logits, custom_param_list): self.assertTrue(all(x == custom_params["token_id"] for x in sampled_tokens)) def test_custom_logit_processor(self): - """Test custom logit processor with a single target token id.""" + """Test custom logit processor with a single request.""" self.run_custom_logit_processor(target_token_id=5) def test_custom_logit_processor_batch(self): - """Test custom logit processor with multiple target token ids.""" + """Test custom logit processor with a batch of requests.""" target_token_ids = list(range(32)) with ThreadPoolExecutor(len(target_token_ids)) as executor: list(executor.map(self.run_custom_logit_processor, target_token_ids))
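End-to-end, the series works as follows: the server is started with --enable-custom-logit-processor, each request may carry a dill-serialized CustomLogitProcessor plus a custom_params dict, the scheduler groups identical processor strings into per-batch boolean masks, and the sampler applies each deserialized processor only to the logit rows of the requests that registered it. Below is a minimal client-side sketch, modeled on the unit test above; the server address, prompt, and sampling parameters are illustrative placeholders, not values taken from the patches.

    import requests

    from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


    class ForceTokenLogitProcessor(CustomLogitProcessor):
        """Force every decode step to sample the token id given in custom_params."""

        def __call__(self, logits, custom_param_list):
            # One custom_params dict per request in the running batch.
            assert logits.shape[0] == len(custom_param_list)
            for i, param_dict in enumerate(custom_param_list):
                logits[i, :] = -float("inf")             # mask all tokens ...
                logits[i, param_dict["token_id"]] = 0.0  # ... except the requested one
            return logits


    base_url = "http://127.0.0.1:30000"  # assumes a local server launched with
                                         # --enable-custom-logit-processor
    payload = {
        "text": "Question: Is Paris the Capital of France? Answer:",
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": 8,         # illustrative sampling parameters
            "custom_params": {"token_id": 5},
        },
        # Serialized with dill + JSON and shipped alongside the request.
        "custom_logit_processor": ForceTokenLogitProcessor().to_str(),
    }
    response = requests.post(base_url + "/generate", json=payload).json()
    print(response)

Because from_str() caches deserialization with functools.lru_cache and the scheduler merges requests that share the same serialized string into a single boolean mask, a batch that reuses one processor pays the dill.loads cost only once.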