Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffreymeetkai committed Nov 5, 2024
1 parent 749bf1a commit a6ce1c8
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions functionary/vllm_monkey_patch/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,28 +422,28 @@ async def step_async(
outputs = await self.model_executor.execute_model_async(execute_model_req)

# Loop through all the output in the batch
for i in range(len(outputs)):
for i in range(len(outputs[0])):
# Check whether grammar sampling is needed
model_sampled_token_id = outputs[i].outputs[-1].samples[-1].output_token
model_sampled_token_id = outputs[0].outputs[i].samples[-1].output_token
request_id = seq_group_metadata_list[i].request_id
if (
tokenizer.decode(model_sampled_token_id) == tokenizer.eos_token
or request_id not in self.prompt_templates
):
continue

# Get all the required variables for grammar sampling
request_id = seq_group_metadata_list[i].request_id
prompt_template = self.prompt_templates[request_id]
gen_state = self.gen_states[request_id]
tools_or_functions = self.tools_or_functions[request_id]

# Slot the first entry of logprobs into its original position
# before getting delta_token_ids_by_logprobs
delta_token_id_by_logprobs = list(
outputs[i].outputs[-1].samples[-1].logprobs.keys()
outputs[0].outputs[i].samples[-1].logprobs.keys()
)
delta_logprobs = list(
outputs[i].outputs[-1].samples[-1].logprobs.values()
outputs[0].outputs[i].samples[-1].logprobs.values()
)
chosen_token_id = delta_token_id_by_logprobs[0]
chosen_logprob = delta_logprobs[0]
Expand All @@ -468,7 +468,7 @@ async def step_async(
)

# Update the output token to vllm with the newly sampled one
outputs[i].outputs[-1].samples[
outputs[0].outputs[i].samples[
-1
].output_token = grammar_sampled_token_id

Expand Down

0 comments on commit a6ce1c8

Please sign in to comment.