Support generation from input embedding #1265

Status: Closed. The pull request proposed merging 33 commits.

The diff below shows changes from 16 of the 33 commits.

Commits (33)
bed0e15 - feat: add prompt_embeds interface (pfldy2850, Oct 5, 2023)
3394d25 - fix: add get_input_embeddings (pfldy2850, Oct 5, 2023)
aa9b215 - feat: support all models to generate from embeds (pfldy2850, Oct 5, 2023)
ce70fe7 - Merge branch 'main' into feature-input-embeds (pfldy2850, Oct 5, 2023)
de4199d - fix: bugfix for inputs_embeds and add last line (pfldy2850, Oct 5, 2023)
9275b2d - fix: add prompt_embeds to async engine (pfldy2850, Oct 11, 2023)
e6963eb - Merge branch 'main' into feature-input-embeds (pfldy2850, Oct 11, 2023)
bd5539a - fix: bugfix of get_last_token_id (pfldy2850, Oct 12, 2023)
99605bc - fix: apply prompt_embeds to api_server (pfldy2850, Oct 12, 2023)
87162d2 - refact: refactor test_models (pfldy2850, Oct 12, 2023)
a3d9de6 - fix: apply style guide (pfldy2850, Oct 12, 2023)
44ff4ec - fix: improve comments (pfldy2850, Oct 12, 2023)
a37cef0 - refact: refactor prepare_inputs and models (pfldy2850, Oct 12, 2023)
9633148 - fix: apply style guide (pfldy2850, Oct 12, 2023)
eec19ed - refact: refactor zero embeds (pfldy2850, Oct 12, 2023)
bebc26b - fix: apply style guide (pfldy2850, Oct 12, 2023)
a2f2054 - Merge branch 'main' into feature-input-embeds (pfldy2850, Oct 17, 2023)
58391ac - Merge branch 'main' into feature-input-embeds (pfldy2850, Oct 18, 2023)
c28d8bf - fix: update for new prepare_inputs (pfldy2850, Oct 18, 2023)
117b47f - fix: rollback commented (pfldy2850, Oct 18, 2023)
c0fae79 - fix: update style (pfldy2850, Oct 18, 2023)
2151bc1 - Merge branch 'main' into feature-input-embeds (pfldy2850, Oct 31, 2023)
d613790 - Merge branch 'main' into feature-input-embeds (pfldy2850, Nov 6, 2023)
1956ce4 - Merge branch 'main' into feature-input-embeds (pfldy2850, Nov 7, 2023)
d26465a - Merge branch 'main' into feature-input-embeds (pfldy2850, Dec 4, 2023)
0790351 - fix: update model_runner with input_embeds (pfldy2850, Dec 4, 2023)
e313eae - fix: fix typo (pfldy2850, Dec 8, 2023)
57c1701 - fix: bug fix (pfldy2850, Dec 8, 2023)
d266c39 - fix: change input_embeds argument (pfldy2850, Dec 10, 2023)
662a658 - refact: refactor replace_prompt_embeds (pfldy2850, Dec 10, 2023)
1110834 - fix: bugfix (pfldy2850, Dec 10, 2023)
ff22471 - Merge branch 'main' into feature-input-embeds (pfldy2850, Dec 10, 2023)
f2b10c3 - Merge branch 'main' into feature-input-embeds (pfldy2850, Dec 15, 2023)
14 changes: 11 additions & 3 deletions tests/conftest.py
@@ -133,9 +133,11 @@ def generate(
self,
prompts: List[str],
sampling_params: SamplingParams,
prompt_embeds: List[torch.Tensor] = None,
) -> List[Tuple[List[int], str]]:
req_outputs = self.model.generate(prompts,
sampling_params=sampling_params)
sampling_params=sampling_params,
prompt_embeds=prompt_embeds)
outputs = []
for req_output in req_outputs:
prompt_str = req_output.prompt
@@ -154,9 +156,12 @@ def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
prompt_embeds: List[torch.Tensor] = None,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts, greedy_params)
outputs = self.generate(prompts,
greedy_params,
prompt_embeds=prompt_embeds)
return [(output_ids[0], output_str[0])
for output_ids, output_str in outputs]

@@ -165,12 +170,15 @@ def generate_beam_search(
prompts: List[str],
beam_width: int,
max_tokens: int,
prompt_embeds: List[torch.Tensor] = None,
) -> List[Tuple[List[int], str]]:
beam_search_params = SamplingParams(n=beam_width,
use_beam_search=True,
temperature=0.0,
max_tokens=max_tokens)
outputs = self.generate(prompts, beam_search_params)
outputs = self.generate(prompts,
beam_search_params,
prompt_embeds=prompt_embeds)
return outputs


50 changes: 50 additions & 0 deletions tests/models/test_models.py
@@ -43,3 +43,53 @@ def test_models(
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models_from_prompt_embeds(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

prompt_embeds = []
for prompt in example_prompts:
token_ids = hf_model.tokenizer(
prompt, return_tensors="pt").input_ids.to("cuda")
token_embeds = hf_model.model.get_input_embeddings()(token_ids)
prompt_embeds.append(token_embeds[0])
del hf_model

vllm_model = vllm_runner(model, dtype=dtype)
vllm_outputs_from_prompts = vllm_model.generate_greedy(example_prompts,
max_tokens,
prompt_embeds=None)
vllm_outputs_from_embeds = vllm_model.generate_greedy(
example_prompts, max_tokens, prompt_embeds=prompt_embeds)
del vllm_model

for i in range(len(example_prompts)):
prompt = example_prompts[i]
hf_output_str = hf_outputs[i][0]
vllm_output_str_from_prompts = vllm_outputs_from_prompts[i][0]
vllm_output_str_from_embeds = vllm_outputs_from_embeds[i][0]

assert hf_output_str == vllm_output_str_from_prompts, (
f"Test{i}:\n"
f"HF: {hf_output_str!r}\n"
f"vLLM_prompt: {vllm_output_str_from_prompts!r}")
assert hf_output_str == vllm_output_str_from_embeds, (
f"Test{i}:\n"
f"HF: {hf_output_str}\n"
f"vLLM_embeds: {vllm_output_str_from_embeds}")
assert vllm_output_str_from_prompts == vllm_output_str_from_embeds, (
f"Test{i}:\n"
f"vLLM_prompt: {vllm_output_str_from_prompts}\n"
f"vLLM_embeds: {vllm_output_str_from_embeds}")
8 changes: 4 additions & 4 deletions tests/samplers/test_sampler.py
@@ -63,7 +63,7 @@ def test_sampler_all_greedy(seed: int):
block_tables={0: [1]},
))

_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
_, _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
@@ -96,7 +96,7 @@ def test_sampler_all_random(seed: int):
block_tables={0: [1]},
))

_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
_, _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
@@ -126,7 +126,7 @@ def test_sampler_all_beam(seed: int):
block_tables={0: [1]},
))

_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
_, _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
@@ -174,7 +174,7 @@ def test_sampler_mixed(seed: int):
block_tables={0: [1]},
))

_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
_, _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
31 changes: 20 additions & 11 deletions vllm/engine/async_llm_engine.py
@@ -3,6 +3,7 @@
from functools import partial
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
Union)
import torch

from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -358,6 +359,7 @@ async def add_request(
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
prompt_embeds: Optional[torch.Tensor] = None,
) -> AsyncStream:
if self.log_requests:
shortened_prompt = prompt
@@ -388,16 +390,20 @@
prompt=prompt,
sampling_params=sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
arrival_time=arrival_time,
prompt_embeds=prompt_embeds,
)

return stream

async def generate(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None) -> RequestOutput:
self,
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
prompt_embeds: Optional[torch.Tensor] = None,
) -> RequestOutput:
"""Generate outputs for a request.

Generate outputs for a request. This method is a coroutine. It adds the
@@ -421,11 +427,14 @@ async def generate(
arrival_time = time.monotonic()

try:
stream = await self.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
stream = await self.add_request(
request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
prompt_embeds=prompt_embeds,
)

async for request_output in stream:
yield request_output
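A minimal sketch of driving this async API with embeddings, mirroring how the api_server change further below calls it. Only the generate(prompt, sampling_params, request_id, prompt_embeds=...) call shape comes from this diff; the engine arguments, model, and tensor shape are assumptions.

```python
import asyncio

import torch

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid


async def main() -> None:
    # Hypothetical engine setup; any supported model works.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m", dtype="half"))

    # [num_prompt_tokens, hidden_size] embeddings built elsewhere,
    # e.g. with get_input_embeddings() as in the tests above.
    prompt_embeds = torch.randn(5, 768, dtype=torch.half, device="cuda")

    results = engine.generate(
        prompt=None,  # ignored when prompt_embeds is given
        sampling_params=SamplingParams(temperature=0.0, max_tokens=16),
        request_id=random_uuid(),
        prompt_embeds=prompt_embeds,
    )
    # generate() is an async generator that yields incremental RequestOutputs.
    async for request_output in results:
        print(request_output.outputs[0].text)


asyncio.run(main())
```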
32 changes: 27 additions & 5 deletions vllm/engine/llm_engine.py
@@ -3,6 +3,8 @@
from functools import partial
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union

import torch

from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
@@ -241,6 +243,7 @@ def add_request(
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
) -> None:
"""Add a request to the engine's request pool.

@@ -250,24 +253,36 @@

Args:
request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
prompt: The prompt string. Can be None if prompt_token_ids
or prompt_embeds are provided.
sampling_params: The sampling parameters for text generation.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
prompt_embeds: The prompt embeddings. If set, the input
prompt and prompt_token_ids are ignored.
"""
if arrival_time is None:
arrival_time = time.monotonic()
if prompt_token_ids is None:

# If prompt_embeds is set, prompt_token_ids is filled with placeholder zeros (one per embedding)
if prompt_embeds is not None:
prompt_token_ids = [0] * prompt_embeds.size(0)
elif prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = self.tokenizer.encode(prompt)

# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
seq = Sequence(
seq_id,
prompt,
prompt_token_ids,
block_size,
prompt_embeds=prompt_embeds,
)

# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
@@ -629,10 +644,17 @@ def _log_system_stats(
def _decode_sequence(self, seq: Sequence,
sampling_params: SamplingParams) -> None:
"""Decodes the new token for a sequence."""

# If the sequence has prompt embeddings, its prompt token ids are placeholders, so only the output token ids are detokenized
if seq.data.has_prompt_embeds_forwarding():
all_input_ids = seq.get_output_token_ids()
else:
all_input_ids = seq.get_token_ids()

(new_tokens, new_output_text, prefix_offset,
read_offset) = detokenize_incrementally(
self.tokenizer,
all_input_ids=seq.get_token_ids(),
all_input_ids=all_input_ids,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
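A standalone illustration (not vLLM code) of the placeholder logic in add_request above: the engine still needs prompt token ids of the right length for scheduling and block allocation, so when only embeddings are supplied it substitutes one dummy id per embedding row.

```python
import torch


def placeholder_token_ids(prompt_embeds: torch.Tensor) -> list:
    # prompt_embeds has shape [num_prompt_tokens, hidden_size]
    return [0] * prompt_embeds.size(0)


# Four prompt tokens with a hypothetical hidden size of 4096.
embeds = torch.randn(4, 4096, dtype=torch.float16)
assert placeholder_token_ids(embeds) == [0, 0, 0, 0]
```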
21 changes: 19 additions & 2 deletions vllm/entrypoints/api_server.py
@@ -5,6 +5,7 @@
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn
import torch

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -23,16 +24,27 @@ async def generate(request: Request) -> Response:

The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- prompt_embeds: the prompt embeddings to use for generation
instead of the prompt.
- stream: whether to stream the results or not.
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
[Review comment on the line above] This throws an error when only prompt_embeds are passed.

prompt_embeds = request_dict.pop("prompt_embeds", None)
if prompt_embeds is not None:
prompt_embeds = torch.tensor(prompt_embeds).to("cuda")
[Review comment by bks5881, Mar 1, 2024] This loads the tensor in float32, which eats all the GPU memory.

prompt = None
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()

results_generator = engine.generate(prompt, sampling_params, request_id)
results_generator = engine.generate(
prompt,
sampling_params,
request_id,
prompt_embeds=prompt_embeds,
)

# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
@@ -58,7 +70,12 @@ async def stream_results() -> AsyncGenerator[bytes, None]:

assert final_output is not None
prompt = final_output.prompt
text_outputs = [prompt + output.text for output in final_output.outputs]
if prompt:
text_outputs = [
prompt + output.text for output in final_output.outputs
]
else:
text_outputs = [output.text for output in final_output.outputs]
ret = {"text": text_outputs}
return JSONResponse(ret)

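To close, a hedged client-side sketch of calling the demo api_server endpoint with embeddings. The prompt_embeds field comes from the handler above; the /generate route, host, port, model, and the nested-list JSON encoding are assumptions. Since the server rebuilds the tensor with torch.tensor(...), which defaults to float32 (the reviewer's memory concern), a likely follow-up is to cast to the model dtype there, e.g. torch.tensor(prompt_embeds, dtype=torch.float16).

```python
import requests
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "facebook/opt-125m"  # assumption: whatever model the server runs
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Build [num_tokens, hidden_size] embeddings for the prompt.
token_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
prompt_embeds = model.get_input_embeddings()(token_ids)[0].detach()

payload = {
    # The handler as written still pops "prompt" unconditionally (see the
    # review comment above), so include it even though it is ignored when
    # prompt_embeds is set.
    "prompt": "Hello, my name is",
    # JSON cannot carry tensors, so send nested lists; the server recreates
    # the tensor with torch.tensor(...) and moves it to CUDA.
    "prompt_embeds": prompt_embeds.tolist(),
    "max_tokens": 32,
    "temperature": 0.0,
}
# Assumed default api_server address; adjust host and port as needed.
resp = requests.post("http://localhost:8000/generate", json=payload)
print(resp.json()["text"])
```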