diff --git a/functionary/vllm_monkey_patch/async_llm_engine.py b/functionary/vllm_monkey_patch/async_llm_engine.py
index c63a8df2..77c62d69 100644
--- a/functionary/vllm_monkey_patch/async_llm_engine.py
+++ b/functionary/vllm_monkey_patch/async_llm_engine.py
@@ -422,9 +422,10 @@ async def step_async(
     outputs = await self.model_executor.execute_model_async(execute_model_req)

     # Loop through all the output in the batch
-    for i in range(len(outputs)):
+    for i in range(len(outputs[0])):
         # Check whether grammar sampling is needed
-        model_sampled_token_id = outputs[i].outputs[-1].samples[-1].output_token
+        model_sampled_token_id = outputs[0].outputs[i].samples[-1].output_token
+        request_id = seq_group_metadata_list[i].request_id
         if (
             tokenizer.decode(model_sampled_token_id) == tokenizer.eos_token
             or request_id not in self.prompt_templates
@@ -432,7 +433,6 @@ async def step_async(
             continue

         # Get all the required variables for grammar sampling
-        request_id = seq_group_metadata_list[i].request_id
         prompt_template = self.prompt_templates[request_id]
         gen_state = self.gen_states[request_id]
         tools_or_functions = self.tools_or_functions[request_id]
@@ -440,10 +440,10 @@ async def step_async(
         # Slot the first entry of logprobs into its original position
         # before getting delta_token_ids_by_logprobs
         delta_token_id_by_logprobs = list(
-            outputs[i].outputs[-1].samples[-1].logprobs.keys()
+            outputs[0].outputs[i].samples[-1].logprobs.keys()
         )
         delta_logprobs = list(
-            outputs[i].outputs[-1].samples[-1].logprobs.values()
+            outputs[0].outputs[i].samples[-1].logprobs.values()
        )
         chosen_token_id = delta_token_id_by_logprobs[0]
         chosen_logprob = delta_logprobs[0]
@@ -468,7 +468,7 @@ async def step_async(
         )

         # Update the output token to vllm with the newly sampled one
-        outputs[i].outputs[-1].samples[
+        outputs[0].outputs[i].samples[
             -1
         ].output_token = grammar_sampled_token_id
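
Note on the re-indexing above (explanatory, not part of the patch): the change appears to track vLLM returning a list of SamplerOutput objects from execute_model_async, so the batch is now iterated through outputs[0].outputs rather than the top-level list, and request_id is hoisted above the early-continue check so the EOS/unknown-request test can use it. A minimal sketch of the output nesting this assumes; the shape annotations are inferred from the diff, not quoted from vLLM:

    # Assumed nesting (sketch, not the patched code itself):
    #   outputs: List[SamplerOutput]            -- one entry per model step
    #   outputs[0].outputs: List[CompletionSequenceGroupOutput]
    #                                           -- one entry per sequence group i
    #   outputs[0].outputs[i].samples[-1]: SequenceOutput
    # This assumes SamplerOutput supports len() over its per-group outputs,
    # which is why len(outputs[0]) gives the batch size in the loop above.
    for i in range(len(outputs[0])):
        sample = outputs[0].outputs[i].samples[-1]
        token_id = sample.output_token  # token chosen by vLLM's sampler
        logprobs = sample.logprobs      # mapping of token id -> logprob; the
                                        # patch assumes the sampled token is
                                        # the first entry when listed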