Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add sampler custom logits processor #2396

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
ee8eaa3
add custom params
hongpeng-guo Dec 4, 2024
19719f4
fix typos
hongpeng-guo Dec 4, 2024
54204d2
Merge branch 'main' into hpguo/add_sampler_logit_processor
hongpeng-guo Dec 8, 2024
e25a145
Merge branch 'main' into hpguo/add_sampler_logit_processor
hongpeng-guo Dec 8, 2024
00d02fb
Merge branch 'main' into hpguo/add_sampler_logit_processor
hongpeng-guo Dec 18, 2024
59465a3
Merge remote-tracking branch 'origin' into hpguo/add_sampler_logit_pr…
hongpeng-guo Dec 20, 2024
23a3451
Merge remote-tracking branch 'origin' into hpguo/add_sampler_logit_pr…
hongpeng-guo Dec 27, 2024
5d77e98
add general custom logit processors
hongpeng-guo Dec 30, 2024
6dd2d94
fix typos
hongpeng-guo Dec 30, 2024
34b0005
fix bug
hongpeng-guo Dec 30, 2024
1062202
remove merge custom_param logic from sglang
hongpeng-guo Dec 30, 2024
1f16689
fix inconsistency
hongpeng-guo Dec 30, 2024
b1bdfdf
update function defination
hongpeng-guo Dec 30, 2024
f4c79ec
apply the processor on the sampler
hongpeng-guo Dec 30, 2024
c0a5905
apply the processor on the sampler
hongpeng-guo Dec 30, 2024
13a7bc6
refine the custom logit processor
hongpeng-guo Dec 30, 2024
530127f
update function arg name
hongpeng-guo Dec 30, 2024
4a6c6e2
update function arg name
hongpeng-guo Dec 30, 2024
7331e36
add logger in the sampler
hongpeng-guo Dec 30, 2024
96af4b1
Merge branch 'main' into hpguo/add_sampler_logit_processor
hongpeng-guo Dec 30, 2024
dfc1f44
fix merge conflicts
hongpeng-guo Dec 30, 2024
4a7eb4f
fix merge conflicts
hongpeng-guo Dec 30, 2024
5c9e697
Merge branch 'main' into hpguo/add_sampler_logit_processor
hongpeng-guo Dec 30, 2024
45a7df8
resolve merge conflict
hongpeng-guo Dec 31, 2024
d8e648b
resolve merge conflict
hongpeng-guo Dec 31, 2024
7aa3b7c
remove required args
hongpeng-guo Dec 31, 2024
808eec7
Merge branch 'main' into hpguo/add_sampler_logit_processor
hongpeng-guo Dec 31, 2024
9f0a287
update unittest in test_srt_endpoint
hongpeng-guo Dec 31, 2024
4686c33
Merge branch 'main' into hpguo/add_sampler_logit_processor
hongpeng-guo Dec 31, 2024
11cefea
Merge branch 'main' into hpguo/add_sampler_logit_processor
hongpeng-guo Jan 7, 2025
0d66c9f
resolve merge coflicts
hongpeng-guo Jan 9, 2025
e0b795a
wip handle comments
hongpeng-guo Jan 9, 2025
d457d82
Merge branch 'main' into hpguo/add_sampler_logit_processor
hongpeng-guo Jan 17, 2025
785cf6a
remove the required device arg in the logit processor
hongpeng-guo Jan 17, 2025
b50f62f
fix unnittest to always sample a given tokenid
hongpeng-guo Jan 17, 2025
6bb05aa
cahce deserialized custom logit processor
hongpeng-guo Jan 17, 2025
9bca98a
move custom params init inside condition
hongpeng-guo Jan 17, 2025
94d629a
add a flag to turn this feature off by default
hongpeng-guo Jan 17, 2025
faf23f9
Merge branch 'main' into hpguo/add_sampler_logit_processor
hongpeng-guo Jan 17, 2025
55d52aa
update doc string
hongpeng-guo Jan 17, 2025
57ebd5c
Merge branch 'hpguo/add_sampler_logit_processor' of https://github.co…
hongpeng-guo Jan 17, 2025
42407cd
moving clp deserialize into sampling_batch_info
hongpeng-guo Jan 18, 2025
895a916
Merge branch 'main' into hpguo/add_sampler_logit_processor
hongpeng-guo Jan 18, 2025
0b3a414
Merge branch 'main' into hpguo/add_sampler_logit_processor
hongpeng-guo Jan 18, 2025
0688823
Merge branch 'main' into hpguo/add_sampler_logit_processor
hongpeng-guo Jan 18, 2025
61967f1
making default clp being None in SamplingBatchInfo
hongpeng-guo Jan 19, 2025
bb4771e
reorg Req taking clp
hongpeng-guo Jan 19, 2025
849a221
Merge branch 'main' into hpguo/add_sampler_logit_processor
hongpeng-guo Jan 19, 2025
e3b689a
Merge branch 'main' into hpguo/add_sampler_logit_processor
hongpeng-guo Jan 19, 2025
d96b9a8
make flag check into scheduler
hongpeng-guo Jan 19, 2025
3205326
add lru cahce to clp class
hongpeng-guo Jan 19, 2025
252672d
cover unittest case for batch
hongpeng-guo Jan 19, 2025
dd63b2f
improve doc str
hongpeng-guo Jan 19, 2025
d3ac431
Merge branch 'main' into hpguo/add_sampler_logit_processor
hongpeng-guo Jan 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion python/sglang/srt/layers/sampler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import logging
from typing import List
from typing import Dict, List

import torch
from torch import nn

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.utils import crash_on_warnings, is_flashinfer_available

Expand Down Expand Up @@ -35,6 +36,10 @@ def forward(
):
logits = logits_output.next_token_logits

# Apply the custom logit processors if registered in the sampling info.
if sampling_info.has_custom_logit_processor:
self._apply_custom_logit_processor(logits, sampling_info)

if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
logger.warning("Detected errors during sampling! NaN in the logits.")
logits = torch.where(
Expand Down Expand Up @@ -121,6 +126,29 @@ def forward(

return batch_next_token_ids

def _apply_custom_logit_processor(
self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
):
"""Apply custom logit processors to the logits.
This function will modify the logits in-place."""

for _, (
processor,
batch_mask,
) in sampling_batch_info.custom_logit_processor.items():
# Get the batch indices that need to be processed
batch_indices = batch_mask.nonzero(as_tuple=True)[0]

# Apply the processor to the logits
logits[batch_mask] = processor(
logits[batch_mask],
[sampling_batch_info.custom_params[i] for i in batch_indices],
)

logger.debug(
f"Custom logit processor {processor.__class__.__name__} is applied."
)


def top_k_top_p_min_p_sampling_from_probs_torch(
probs: torch.Tensor,
Expand Down
19 changes: 19 additions & 0 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import Dict, List, Optional, Union

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.sampling_params import SamplingParams


Expand Down Expand Up @@ -69,6 +70,8 @@ class GenerateReqInput:

# Session info for continual prompting
session_params: Optional[Union[List[Dict], Dict]] = None
# Custom logit processor (serialized function)
custom_logit_processor: Optional[Union[List[Optional[str]], Optional[str]]] = None

def normalize_batch_and_arguments(self):
if (
Expand Down Expand Up @@ -183,6 +186,13 @@ def normalize_batch_and_arguments(self):
else:
assert self.parallel_sample_num == 1

if self.custom_logit_processor is None:
self.custom_logit_processor = [None] * num
elif not isinstance(self.custom_logit_processor, list):
self.custom_logit_processor = [self.custom_logit_processor] * num
else:
assert self.parallel_sample_num == 1

def regenerate_rid(self):
self.rid = uuid.uuid4().hex
return self.rid
Expand All @@ -202,6 +212,11 @@ def __getitem__(self, i):
log_metrics=self.log_metrics,
modalities=self.modalities[i] if self.modalities else None,
lora_path=self.lora_path[i] if self.lora_path is not None else None,
custom_logit_processor=(
self.custom_logit_processor[i]
if self.custom_logit_processor is not None
else None
),
)


Expand Down Expand Up @@ -234,6 +249,10 @@ class TokenizedGenerateReqInput:
# Session info for continual prompting
session_params: Optional[SessionParams] = None

# Custom logit processor (serialized function)
# TODO (hpguo): Add an example and update doc string here
custom_logit_processor: Optional[str] = None


@dataclass
class EmbeddingReqInput:
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def __init__(
lora_path: Optional[str] = None,
input_embeds: Optional[List[List[float]]] = None,
session_id: Optional[str] = None,
custom_logit_processor: Optional[str] = None,
eos_token_ids: Optional[Set[int]] = None,
):
# Input and output info
Expand All @@ -252,6 +253,7 @@ def __init__(
# Sampling info
self.sampling_params = sampling_params
self.lora_path = lora_path
self.custom_logit_processor = custom_logit_processor

# Memory pool info
self.req_pool_idx = None
Expand Down
14 changes: 14 additions & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,19 @@ def handle_generate_request(
fake_input_ids = [1] * seq_length
recv_req.input_ids = fake_input_ids

# Handle custom logit processor passed to the request
custom_logit_processor = recv_req.custom_logit_processor
if (
not self.server_args.enable_custom_logit_processor
and custom_logit_processor is not None
):
logger.warning(
"The SGLang server is not configured to enable custom logit processor."
"The custom logit processor passed in will be ignored."
"Please set --enable-custom-logits-processor to enable this feature."
)
custom_logit_processor = None

req = Req(
recv_req.rid,
recv_req.input_text,
Expand All @@ -622,6 +635,7 @@ def handle_generate_request(
stream=recv_req.stream,
lora_path=recv_req.lora_path,
input_embeds=recv_req.input_embeds,
custom_logit_processor=custom_logit_processor,
eos_token_ids=self.model_config.hf_eos_token_id,
)
req.tokenizer = self.tokenizer
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/managers/session_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
sampling_params=req.sampling_params,
lora_path=req.lora_path,
session_id=self.session_id,
custom_logit_processor=req.custom_logit_processor,
)
if last_req is not None:
new_req.image_inputs = last_req.image_inputs
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ async def _tokenize_one_request(
lora_path=obj.lora_path,
input_embeds=input_embeds,
session_params=session_params,
custom_logit_processor=obj.custom_logit_processor,
)
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
Expand Down
38 changes: 38 additions & 0 deletions python/sglang/srt/sampling/custom_logit_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import json
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, Dict, List, Optional

import dill
import torch


@lru_cache(maxsize=None)
def _cache_from_str(json_str: str):
"""Deserialize a json string to a Callable object.
This function is cached to avoid redundant deserialization.
"""
data = json.loads(json_str)
return dill.loads(bytes.fromhex(data["callable"]))


class CustomLogitProcessor(ABC):
"""Abstract base class for callable functions."""

@abstractmethod
def __call__(
self,
logits: torch.Tensor,
custom_param_list: Optional[List[Dict[str, Any]]] = None,
) -> torch.Tensor:
"""Define the callable behavior."""
raise NotImplementedError

def to_str(self) -> str:
"""Serialize the callable function to a JSON-compatible string."""
return json.dumps({"callable": dill.dumps(self).hex()})

@classmethod
def from_str(cls, json_str: str):
"""Deserialize a callable function from a JSON string."""
return _cache_from_str(json_str)
Loading
Loading