support embedding input as a list
Ying1123 committed Aug 10, 2024
1 parent 62757db commit 3e6e77a
Showing 3 changed files with 65 additions and 50 deletions.
python/sglang/srt/managers/tokenizer_manager.py (59 additions, 43 deletions)
@@ -153,9 +153,7 @@ async def generate_request(
             async for response in self._handle_single_request(obj, request):
                 yield response
         else:
-            if isinstance(obj, EmbeddingReqInput):
-                raise NotImplementedError("Please send only one prompt in each request")
-            if obj.stream:
+            if hasattr(obj, "stream") and obj.stream:
                 raise ValueError("Do not support stream for batch mode.")
 
             async for response in self._handle_batch_request(obj, request):
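The hunk above swaps the explicit EmbeddingReqInput rejection for a duck-typed check: EmbeddingReqInput does not define a stream field, so hasattr(obj, "stream") lets embedding batches fall through to _handle_batch_request while generate requests keep the no-streaming guard. A minimal self-contained sketch of that pattern (the two toy request classes are illustrative stand-ins, not the real sglang dataclasses):

from dataclasses import dataclass
from typing import List, Union

@dataclass
class ToyGenerateReq:          # stand-in for GenerateReqInput
    text: List[str]
    stream: bool = False

@dataclass
class ToyEmbeddingReq:         # stand-in for EmbeddingReqInput: no `stream` field
    text: List[str]

def check_batch(obj: Union[ToyGenerateReq, ToyEmbeddingReq]) -> None:
    # Mirrors the updated guard: only objects that define `stream` can trip it.
    if hasattr(obj, "stream") and obj.stream:
        raise ValueError("Do not support stream for batch mode.")

check_batch(ToyEmbeddingReq(text=["a", "b"]))           # ok: no stream attribute
check_batch(ToyGenerateReq(text=["a", "b"]))            # ok: stream=False
# check_batch(ToyGenerateReq(text=["a"], stream=True))  # would raise ValueError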
@@ -283,24 +281,29 @@ async def _handle_single_request(
             await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
             yield input_ids
 
-    async def _handle_batch_request(self, obj: GenerateReqInput, request):
+    async def _handle_batch_request(
+        self, obj: Union[GenerateReqInput, EmbeddingReqInput], request
+    ):
         batch_size = obj.batch_size
-        parallel_sample_num = obj.parallel_sample_num
-
-        if parallel_sample_num != 1:
-            # Send prefill requests to cache the common input
-            parallel_sample_num += 1
-            input_id_result = [] if obj.input_ids is None else None
-            for i in range(batch_size):
-                async for input_id in self._handle_single_request(
-                    obj, request, index=i, is_cache_for_prefill=True
-                ):
-                    if input_id_result is not None:
-                        input_id_result.append(input_id)
-            if input_id_result is not None and len(input_id_result) > 1:
-                obj.input_ids = input_id_result
-            elif input_id_result is not None:
-                obj.input_ids = input_id_result[0]
+        if self.is_generation:
+            parallel_sample_num = obj.parallel_sample_num
+
+            if parallel_sample_num != 1:
+                # Send prefill requests to cache the common input
+                parallel_sample_num += 1
+                input_id_result = [] if obj.input_ids is None else None
+                for i in range(batch_size):
+                    async for input_id in self._handle_single_request(
+                        obj, request, index=i, is_cache_for_prefill=True
+                    ):
+                        if input_id_result is not None:
+                            input_id_result.append(input_id)
+                if input_id_result is not None and len(input_id_result) > 1:
+                    obj.input_ids = input_id_result
+                elif input_id_result is not None:
+                    obj.input_ids = input_id_result[0]
+        else:
+            parallel_sample_num = 1
 
         # First send out all requests
         for i in range(batch_size):
@@ -329,28 +332,38 @@ async def _handle_batch_request(self, obj: GenerateReqInput, request):
                     input_text = None
                     input_ids = obj.input_ids[i]
             sampling_params = self._get_sampling_params(obj.sampling_params[index])
-            pixel_values, image_hash, image_size = await self._get_pixel_values(
-                obj.image_data[index]
-            )
-
-            tokenized_obj = TokenizedGenerateReqInput(
-                rid,
-                input_text,
-                input_ids,
-                pixel_values,
-                image_hash,
-                image_size,
-                sampling_params,
-                obj.return_logprob[index],
-                obj.logprob_start_len[index],
-                obj.top_logprobs_num[index],
-                obj.stream,
-            )
+            if self.is_generation:
+                pixel_values, image_hash, image_size = await self._get_pixel_values(
+                    obj.image_data[index]
+                )
+
+                tokenized_obj = TokenizedGenerateReqInput(
+                    rid,
+                    input_text,
+                    input_ids,
+                    pixel_values,
+                    image_hash,
+                    image_size,
+                    sampling_params,
+                    obj.return_logprob[index],
+                    obj.logprob_start_len[index],
+                    obj.top_logprobs_num[index],
+                    obj.stream,
+                )
+            else:
+                tokenized_obj = TokenizedEmbeddingReqInput(
+                    rid,
+                    input_text,
+                    input_ids,
+                    sampling_params,
+                )
             self.send_to_router.send_pyobj(tokenized_obj)
 
             event = asyncio.Event()
             state = ReqState([], False, event)
             self.rid_to_state[rid] = state
 
         # Then wait for all responses
         output_list = []
         for i in range(batch_size):
@@ -373,14 +386,17 @@ async def _handle_batch_request(self, obj: GenerateReqInput, request):
                             self.abort_request(rid)
                         raise ValueError(f"Abort request {rid}")
                     continue
-            output_list.append(
-                self.convert_logprob_style(
-                    state.out_list[-1],
-                    obj.return_logprob[index],
-                    obj.top_logprobs_num[index],
-                    obj.return_text_in_logprobs,
-                )
-            )
+            if self.is_generation:
+                output_list.append(
+                    self.convert_logprob_style(
+                        state.out_list[-1],
+                        obj.return_logprob[index],
+                        obj.top_logprobs_num[index],
+                        obj.return_text_in_logprobs,
+                    )
+                )
+            else:
+                output_list.append(state.out_list[-1])
             assert state.finished
             del self.rid_to_state[rid]
         yield output_list
python/sglang/test/runners.py (3 additions, 5 deletions)
@@ -219,11 +219,9 @@ def forward(
                 output_strs=output_strs, top_input_logprobs=top_input_logprobs
             )
         else:
-            logits = []
-            for prompt in prompts:
-                response = self.runtime.encode(prompt)
-                response = json.loads(response)
-                logits.append(response["embedding"])
+            response = self.runtime.encode(prompts)
+            response = json.loads(response)
+            logits = [x["embedding"] for x in response]
             return ModelOutput(embed_logits=logits)
 
     def __enter__(self):
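The runner now passes the whole prompt list to a single Runtime.encode call instead of encoding prompts one by one; the response is a JSON-encoded list with one entry per prompt, each carrying an "embedding" field. A rough usage sketch (the model path and constructor arguments below are illustrative assumptions; runners.py builds the Runtime around the model actually under test):

import json

import sglang as sgl

# Example embedding model path; an assumption for illustration only.
runtime = sgl.Runtime(model_path="intfloat/e5-mistral-7b-instruct")

prompts = [
    "The capital of France is Paris.",
    "SGLang batches embedding inputs in one request.",
]

# One call for the whole batch; json.loads yields a list aligned with `prompts`.
response = json.loads(runtime.encode(prompts))
embeddings = [item["embedding"] for item in response]
print(len(embeddings), len(embeddings[0]))

runtime.shutdown()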
test/srt/test_embedding_openai_server.py (3 additions, 2 deletions)
@@ -38,8 +38,9 @@ def run_embedding(self, use_list_input, token_input):
         num_prompt_tokens = len(self.tokenizer.encode(prompt))
 
         if use_list_input:
-            prompt_arg = [prompt_input, prompt_input]
+            prompt_arg = [prompt_input] * 2
             num_prompts = len(prompt_arg)
+            num_prompt_tokens *= num_prompts
         else:
             prompt_arg = prompt_input
             num_prompts = 1
@@ -70,7 +71,7 @@ def run_batch(self):
 
     def test_embedding(self):
         # TODO the fields of encoding_format, dimensions, user are skipped
         # TODO support use_list_input
-        for use_list_input in [False]:
+        for use_list_input in [False, True]:
             for token_input in [False, True]:
                 self.run_embedding(use_list_input, token_input)
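With list input enabled on the OpenAI-compatible server, the call the test exercises can be reproduced roughly as below. The base URL, port, and model name are assumptions for illustration; input accepts either a list of strings or, with token_input, a list of token-id lists:

import openai

# Assumed local sglang server address and model name; adjust to your deployment.
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

prompt = "The capital of France is"
response = client.embeddings.create(
    model="intfloat/e5-mistral-7b-instruct",
    input=[prompt, prompt],  # list input: one embedding comes back per item
)

assert len(response.data) == 2
print(len(response.data[0].embedding))
print(response.usage.prompt_tokens)  # token count covers both list items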

