
Commit

[Frontend] Add backend-specific options for guided decoding (vllm-project#13505)

Signed-off-by: Joe Runde <[email protected]>
joerunde authored Feb 20, 2025
1 parent 6a417b8 commit bfbc0b3
Showing 8 changed files with 123 additions and 42 deletions.
2 changes: 1 addition & 1 deletion docs/source/features/structured_outputs.md
@@ -16,7 +16,7 @@ The following parameters are supported, which must be added as extra parameters:
- `guided_json`: the output will follow the JSON schema.
- `guided_grammar`: the output will follow the context free grammar.
- `guided_whitespace_pattern`: used to override the default whitespace pattern for guided json decoding.
- `guided_decoding_backend`: used to select the guided decoding backend to use.
- `guided_decoding_backend`: used to select the guided decoding backend to use. Additional backend-specific options can be supplied in a comma-separated list following a colon after the backend name. For example, `"xgrammar:no-fallback"` will not allow vLLM to fall back to a different backend on error.
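
As a minimal offline sketch of the same syntax (assuming vLLM is installed and the model below is available locally; the schema and prompt are illustrative), the backend string with options can also be passed per request through `GuidedDecodingParams`:

```python
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Request-level backend selection with the no-fallback option: if xgrammar
# cannot handle the request, vLLM raises an error instead of switching
# backends behind the scenes.
llm = LLM(model="Qwen/Qwen2.5-3B-Instruct")
params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    guided_decoding=GuidedDecodingParams(
        json={"type": "object",
              "properties": {"name": {"type": "string"}},
              "required": ["name"]},
        backend="xgrammar:no-fallback"))
outputs = llm.generate(prompts="Describe a person named Ada as JSON.",
                       sampling_params=params)
print(outputs[0].outputs[0].text)
```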

You can see the complete list of supported parameters on the [OpenAI-Compatible Server](#openai-compatible-server) page.

@@ -2,7 +2,7 @@

from enum import Enum

from openai import OpenAI
from openai import BadRequestError, OpenAI
from pydantic import BaseModel

client = OpenAI(
@@ -94,3 +94,26 @@ class CarDescription(BaseModel):
extra_body={"guided_grammar": simplified_sql_grammar},
)
print(completion.choices[0].message.content)

# Extra backend options
prompt = ("Generate an email address for Alan Turing, who works in Enigma."
"End in .com and new line. Example result:"
"[email protected]\n")

try:
# The no-fallback option forces vLLM to use xgrammar, so when it fails
# you get a 400 with the reason why
completion = client.chat.completions.create(
model="Qwen/Qwen2.5-3B-Instruct",
messages=[{
"role": "user",
"content": prompt,
}],
extra_body={
"guided_regex": "\w+@\w+\.com\n",
"stop": ["\n"],
"guided_decoding_backend": "xgrammar:no-fallback"
},
)
except BadRequestError as e:
print("This error is expected:", e)
16 changes: 16 additions & 0 deletions tests/entrypoints/llm/test_guided_generate.py
@@ -280,6 +280,22 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm):
guided_options_request=dict(guided_regex=sample_regex))


@pytest.mark.skip_global_cleanup
def test_disable_guided_decoding_fallback(sample_regex, llm):
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
regex=sample_regex,
backend="xgrammar:no-fallback"))

with pytest.raises(
ValueError,
match="xgrammar does not support regex guided decoding"):
llm.generate(prompts="This should fail",
sampling_params=sampling_params,
use_tqdm=True)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_json_object(llm, guided_decoding_backend: str):
10 changes: 10 additions & 0 deletions tests/model_executor/test_guided_processors.py
@@ -109,6 +109,16 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")


def test_guided_decoding_backend_options():
"""Test backend-specific options"""
params = GuidedDecodingParams(
backend="xgrammar:option-1,option-2,option-3")
assert params.backend_options() == ["option-1", "option-2", "option-3"]

no_fallback = GuidedDecodingParams(backend="xgrammar:option-1,no-fallback")
assert no_fallback.no_fallback()


def test_pickle_xgrammar_tokenizer_data():

# TODO: move to another test file for xgrammar
5 changes: 4 additions & 1 deletion vllm/config.py
@@ -25,6 +25,7 @@
get_quantization_config)
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import CpuArchEnum
from vllm.sampling_params import GuidedDecodingParams
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
@@ -2631,7 +2632,9 @@ def compute_hash(self) -> str:

def __post_init__(self):
valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
backend = self.guided_decoding_backend

backend = GuidedDecodingParams(
backend=self.guided_decoding_backend).backend_name
if backend not in valid_guided_backends:
raise ValueError(f"Invalid guided_decoding_backend '{backend},"
f"must be one of {valid_guided_backends}")
7 changes: 5 additions & 2 deletions vllm/engine/arg_utils.py
@@ -372,14 +372,17 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'--guided-decoding-backend',
type=str,
default='xgrammar',
choices=['outlines', 'lm-format-enforcer', 'xgrammar'],
help='Which engine will be used for guided decoding'
' (JSON schema / regex etc) by default. Currently supports '
'https://github.com/outlines-dev/outlines, '
'https://github.com/mlc-ai/xgrammar, and '
'https://github.com/noamgat/lm-format-enforcer.'
' Can be overridden per request via guided_decoding_backend'
' parameter.')
' parameter.\n'
'Backend-specific options can be supplied in a comma-separated '
'list following a colon after the backend name. Valid backends and '
'all available options are: [xgrammar:no-fallback, '
'outlines:no-fallback, lm-format-enforcer:no-fallback]')
parser.add_argument(
'--logits-processor-pattern',
type=nullable_str,
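Hypothetically, the same engine-wide default can also be set when constructing the engine in Python; this assumes the `LLM` constructor forwards extra keyword arguments to `EngineArgs` in the same way the CLI flag does:

```python
from vllm import LLM

# Engine-wide default: use xgrammar and never fall back to another backend.
# Individual requests may still override this via
# GuidedDecodingParams(backend=...).
llm = LLM(model="Qwen/Qwen2.5-3B-Instruct",
          guided_decoding_backend="xgrammar:no-fallback")
```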
81 changes: 44 additions & 37 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -22,47 +22,56 @@

def maybe_backend_fallback(
guided_params: GuidedDecodingParams) -> GuidedDecodingParams:

def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
fallback: str) -> None:
"""Change the backend to the specified fallback with a warning log,
or raise a ValueError if the `no-fallback` option is specified."""
if guided_params.no_fallback():
raise ValueError(message)

logger.warning("%s Falling back to use %s instead.", message, fallback)
guided_params.backend = fallback

# lm-format-enforce doesn't support grammar, fallback to xgrammar
if guided_params.backend == "lm-format-enforcer":
if guided_params.backend_name == "lm-format-enforcer":
if guided_params.grammar is not None:
logger.warning(
"lm-format-enforcer does not support grammar guided decoding. "
"Falling back to use xgrammar instead.")
guided_params.backend = "xgrammar"
fallback_or_error(
guided_params,
"lm-format-enforcer does not support grammar guided decoding.",
"xgrammar")

# lm-format-enforcer doesn't support some JSON schema features
elif (guided_params.json is not None
and has_lmf_unsupported_json_features(guided_params.json)):
logger.warning(
fallback_or_error(
guided_params,
"lm-format-enforcer does not support advanced JSON schema "
"features like patterns or numeric ranges. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
"features like patterns or numeric ranges.", "outlines")

if guided_params.backend == "xgrammar":
if guided_params.backend_name == "xgrammar":
from vllm.model_executor.guided_decoding.xgrammar_decoding import (
xgr_installed)
# xgrammar only has x86 wheels for linux, fallback to outlines
from vllm.platforms import current_platform
if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
logger.warning("xgrammar is only supported on x86 CPUs. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
fallback_or_error(guided_params,
"xgrammar is only supported on x86 CPUs.",
"outlines")

# xgrammar doesn't support regex, fallback to outlines
if guided_params.regex is not None:
logger.warning("xgrammar does not support regex guided decoding. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
fallback_or_error(
guided_params,
"xgrammar does not support regex guided decoding.", "outlines")

# xgrammar doesn't support some JSON schema features
elif (guided_params.json is not None
and has_xgrammar_unsupported_json_features(guided_params.json)):
logger.warning(
fallback_or_error(
guided_params,
"xgrammar does not support advanced JSON schema features like "
"patterns or numeric ranges. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
"enums, patterns or numeric ranges.", "outlines")

# xgrammar only supports GBNF grammars, so we must convert Lark.
# We must check if the grammar is likely Lark and if that
@@ -72,25 +81,23 @@ def maybe_backend_fallback(
try:
convert_lark_to_gbnf(guided_params.grammar)
except Exception:
logger.warning(
fallback_or_error(
guided_params,
"xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
"grammar failed to convert to GBNF.", "outlines")

# If the xgrammar module cannot be imported successfully,
# we should still allow users to use guided decoding with a fallback.
elif not xgr_installed:
logger.warning("xgrammar module cannot be imported successfully. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
fallback_or_error(
guided_params,
"xgrammar module cannot be imported successfully.", "outlines")

if (guided_params.backend == "outlines"
if (guided_params.backend_name == "outlines"
and guided_params.json_object is not None):
# outlines doesn't support json_object, fallback to xgrammar
logger.warning("outlines does not support json_object. "
"Falling back to use xgrammar instead.")
guided_params.backend = "xgrammar"
fallback_or_error(guided_params,
"outlines does not support json_object.", "xgrammar")

return guided_params

Expand All @@ -100,18 +107,18 @@ async def get_guided_decoding_logits_processor(
model_config: ModelConfig) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines':
if guided_params.backend_name == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'lm-format-enforcer':
if guided_params.backend_name == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'xgrammar':
if guided_params.backend_name == 'xgrammar':
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
@@ -127,18 +134,18 @@ def get_local_guided_decoding_logits_processor(
model_config: ModelConfig) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines':
if guided_params.backend_name == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'lm-format-enforcer':
if guided_params.backend_name == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'xgrammar':
if guided_params.backend_name == 'xgrammar':
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
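To illustrate the new `fallback_or_error` behaviour, a small sketch (assuming an x86 host with xgrammar installed): the same unsupported request is either silently moved to outlines or rejected, depending on whether the no-fallback option is set:

```python
from vllm.model_executor.guided_decoding import maybe_backend_fallback
from vllm.sampling_params import GuidedDecodingParams

# xgrammar does not support regex, so the backend is swapped with a warning.
params = maybe_backend_fallback(
    GuidedDecodingParams(regex=r"\w+", backend="xgrammar"))
print(params.backend)  # "outlines"

# With no-fallback, the same request raises instead of falling back.
try:
    maybe_backend_fallback(
        GuidedDecodingParams(regex=r"\w+", backend="xgrammar:no-fallback"))
except ValueError as e:
    print("Expected:", e)
```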
19 changes: 19 additions & 0 deletions vllm/sampling_params.py
@@ -64,6 +64,25 @@ def from_optional(
whitespace_pattern=whitespace_pattern,
)

@property
def backend_name(self) -> str:
"""Return the backend name without any options.
For example, if the backend is "xgrammar:no-fallback", this returns "xgrammar".
"""
return (self.backend or "").split(":")[0]

def backend_options(self) -> List[str]:
"""Return the backend options as a list of strings."""
if not self.backend or ":" not in self.backend:
return []
return self.backend.split(":")[1].split(",")

def no_fallback(self) -> bool:
"""Returns True if the "no-fallback" option is supplied for the guided
decoding backend"""
return "no-fallback" in self.backend_options()

def __post_init__(self):
"""Validate that some fields are mutually exclusive."""
guide_count = sum([
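The option parsing itself is plain string splitting; a small sketch of what the new helpers return, mirroring the unit test above:

```python
from vllm.sampling_params import GuidedDecodingParams

params = GuidedDecodingParams(backend="xgrammar:option-1,no-fallback")
print(params.backend_name)       # "xgrammar"
print(params.backend_options())  # ["option-1", "no-fallback"]
print(params.no_fallback())      # True
```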
