Adds outlines performance improvement
lynkz-matt-psaltis committed May 23, 2024
1 parent 6066253 commit 4a9e16a
Showing 1 changed file with 40 additions and 14 deletions.
54 changes: 40 additions & 14 deletions vllm/model_executor/guided_decoding/outlines_logits_processors.py
@@ -18,12 +18,13 @@
import math
from collections import defaultdict
from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Union
from typing import Callable, DefaultDict, Dict, List, Optional, Union

import torch
import numpy as np
from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel
from torch import Tensor, from_numpy, full_like, int64
from transformers import PreTrainedTokenizerBase


@@ -32,13 +33,14 @@ class BaseLogitsProcessor:
    def __init__(self):
        # Child class should use initialize in their init.
        self.fsm: FSM
        self.mask: Optional[Tensor] = None
        self.allowed_tokens_cache: Dict[int, Tensor] = {}

    def init_state(self):
        """Initialize the FSM states."""
        """Initialize the FSM states"""
        self.fsm_state: DefaultDict[int, int] = defaultdict(int)

    def __call__(self, input_ids: List[int],
                 scores: torch.Tensor) -> torch.Tensor:
    def __call__(self, input_ids: List[int], scores: Tensor) -> Tensor:
        """Use the FSM to bias the logits before sampling the next token."""
        seq_id = hash(tuple(input_ids))

@@ -50,13 +52,34 @@ def __call__(self, input_ids: List[int],
            self.fsm_state[seq_id] = self.fsm.next_state(
                self.fsm_state[last_seq_id], last_token)

        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
        state = self.fsm_state[seq_id]

        # Retrieve allowed tokens from cache using the current state
        if state not in self.allowed_tokens_cache:
            # Cache miss, calculate allowed tokens and cache them
            allowed_tokens = self.fsm.allowed_token_ids(state)
            np_allowed_tokens = np.array(allowed_tokens, dtype=np.int32)
            allowed_tokens_tensor = from_numpy(np_allowed_tokens)

            if allowed_tokens_tensor.device != scores.device:
                allowed_tokens_tensor = allowed_tokens_tensor.to(
                    scores.device, dtype=int64, non_blocking=True)
            else:
                allowed_tokens_tensor = allowed_tokens_tensor.to(int64)

            self.allowed_tokens_cache[state] = allowed_tokens_tensor

        else:
            allowed_tokens_tensor = self.allowed_tokens_cache[state]

        if self.mask is None:
            self.mask = full_like(scores, -math.inf)
        else:
            self.mask.fill_(-math.inf)

        self.mask.index_fill_(0, allowed_tokens_tensor, 0)
        scores.add_(self.mask)

        mask = torch.full((scores.shape[-1], ),
                          -math.inf,
                          device=scores.device)
        mask[allowed_tokens] = 0
        scores.add_(mask)
        return scores


@@ -80,9 +103,12 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):

class JSONLogitsProcessor(RegexLogitsProcessor):

    def __init__(self, schema: Union[str, Dict, BaseModel],
                 tokenizer: PreTrainedTokenizerBase,
                 whitespace_pattern: Union[str, None]):
    def __init__(
        self,
        schema: Union[str, Dict, BaseModel],
        tokenizer: PreTrainedTokenizerBase,
        whitespace_pattern: Union[str, None],
    ):
        """Compile the FSM that drives the JSON-guided generation.

        Parameters

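As a quick illustration of what the diff does, below is a minimal standalone sketch of the same pattern: memoizing allowed-token index tensors per FSM state and reusing a single -inf mask buffer across calls. The CachedMaskBuilder class and its dict-based stand-in for outlines' FSM are illustrative assumptions, not vLLM or outlines APIs.

# A minimal, runnable sketch of the pattern this commit introduces, using a
# hypothetical CachedMaskBuilder and a plain dict in place of outlines' FSM:
# allowed-token index tensors are memoized per FSM state, and one -inf mask
# buffer is allocated once and refilled on every call instead of rebuilt.
import math
from typing import Dict, List, Optional

import numpy as np
import torch


class CachedMaskBuilder:

    def __init__(self, allowed_ids_by_state: Dict[int, List[int]]):
        # Stand-in for fsm.allowed_token_ids(state); purely illustrative.
        self.allowed_ids_by_state = allowed_ids_by_state
        self.mask: Optional[torch.Tensor] = None
        self.cache: Dict[int, torch.Tensor] = {}

    def bias(self, state: int, scores: torch.Tensor) -> torch.Tensor:
        if state not in self.cache:
            # Cache miss: build the index tensor once per state and move it to
            # the scores device so later calls skip the host-to-device copy.
            ids = np.array(self.allowed_ids_by_state[state], dtype=np.int64)
            self.cache[state] = torch.from_numpy(ids).to(scores.device)
        allowed = self.cache[state]

        # Reuse a single mask buffer rather than allocating a new tensor
        # for every decoding step.
        if self.mask is None:
            self.mask = torch.full_like(scores, -math.inf)
        else:
            self.mask.fill_(-math.inf)
        self.mask.index_fill_(0, allowed, 0)
        return scores.add_(self.mask)


# Only tokens 1 and 3 stay finite for state 0; all other logits become -inf.
builder = CachedMaskBuilder({0: [1, 3]})
print(builder.bias(0, torch.zeros(5)))  # tensor([-inf, 0., -inf, 0., -inf])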