Fix potential bugs in FastAPI frontend and add comments (vllm-project#28)

zhuohan123 authored Apr 6, 2023
1 parent 6376304 commit 1097c78
Showing 1 changed file with 30 additions and 5 deletions.

cacheflow/http_frontend/fastapi_frontend.py
@@ -17,8 +17,10 @@
 from cacheflow.worker.controller import DeviceID
 from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
+
+TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
 app = FastAPI()


 class FastAPIFrontend:
     def __init__(
         self,
@@ -30,7 +32,7 @@ def __init__(
         dtype: str,
         seed: int,
         swap_space: int,
-        max_batch_size: int,
+        max_num_batched_tokens: int,
         num_nodes: int,
         num_devices_per_node: int,
         distributed_init_method: str,
@@ -51,7 +53,7 @@ def __init__(
             dtype=dtype,
             seed=seed,
             swap_space=swap_space,
-            max_batch_size=max_batch_size,
+            max_num_batched_tokens=max_num_batched_tokens,
             num_nodes=num_nodes,
             num_devices_per_node=num_devices_per_node,
             distributed_init_method=distributed_init_method,
@@ -68,12 +70,14 @@ async def server_step(self):
         self.is_server_running = True
         updated_seq_groups = await self.server.step.remote()
         self.is_server_running = False
+        # Notify the waiting coroutines that there are new outputs ready.
         for seq_group in updated_seq_groups:
             group_id = seq_group.group_id
             self.running_seq_groups[group_id] = seq_group
             self.sequence_group_events[group_id].set()

     async def generate(self, request_dict: Dict):
+        # Preprocess the request.
         prompt = request_dict["prompt"]
         sampling_params = SamplingParams.from_dict(request_dict)
         sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
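
Editor's note: the two methods above form a classic asyncio.Event handshake. server_step publishes new outputs and sets each sequence group's event; generate (continued below) waits on its event, clears it, and streams the result. The following is a minimal, self-contained sketch of that pattern; the coroutine names (produce, consume) and the sleep standing in for a server step are illustrative, not part of the codebase.

import asyncio

async def produce(event: asyncio.Event, outputs: list) -> None:
    # Stand-in for server_step: compute a new output, then notify waiters.
    for i in range(3):
        await asyncio.sleep(0.1)      # pretend to run one server step
        outputs.append(f"token-{i}")
        event.set()                   # wake up the waiting coroutine

async def consume(event: asyncio.Event, outputs: list) -> None:
    # Stand-in for generate: wait for each new output, then reset the event.
    for _ in range(3):
        await asyncio.wait_for(event.wait(), timeout=1)
        event.clear()                 # reset before waiting for the next one
        print("got:", outputs[-1])

async def main() -> None:
    event, outputs = asyncio.Event(), []
    await asyncio.gather(produce(event, outputs), consume(event, outputs))

asyncio.run(main())

Because an Event only carries a "something changed" signal, the payload travels out of band: the shared outputs list here, self.running_seq_groups in the frontend.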
@@ -87,15 +91,27 @@ async def generate(self, request_dict: Dict):
         arrival_time = time.time()
         group_id = next(self.seq_group_counter)
         seq_group = SequenceGroup(group_id, seqs, arrival_time)
+        # Create an event to notify us that there is new output from the
+        # cacheflow server.
         group_event = asyncio.Event()
         self.running_seq_groups[group_id] = seq_group
         self.sequence_group_events[group_id] = group_event
+        # Add the request to the cacheflow server's waiting queue.
         await self.server.add_sequence_groups.remote([(seq_group, sampling_params)])
+        # The cacheflow server does not have a background loop that keeps
+        # processing incoming requests. Therefore, we need to keep kicking
+        # the server to process the requests.
         while True:
+            # Kick the server if the server is not running.
             if not self.is_server_running:
                 await self.server_step()
-            # Wait for new output. Add a 1s timeout to prevent dead lock.
-            await asyncio.wait_for(group_event.wait(), timeout=1)
+            # Wait for new output. The group_event will be set in server_step
+            # when there is new output available for the sequence group.
+            # Add a timeout to prevent deadlock.
+            await asyncio.wait_for(group_event.wait(), timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
+            # Reset the event to wait for the next output.
+            group_event.clear()
+            # Decode and return the new outputs.
             seq_group = self.running_seq_groups[group_id]
             all_outputs = []
             for seq in seq_group.seqs:
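
Editor's note: one subtlety in the loop above is that asyncio.wait_for raises asyncio.TimeoutError when the timeout expires, and the diff shows no handler for it, so a timed-out wait would propagate out of generate. A defensive variant of the same wait, assuming the TIMEOUT_TO_PREVENT_DEADLOCK constant added in this commit, swallows the timeout and lets the caller loop around to kick the server again:

import asyncio

TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds, mirroring the constant above

async def wait_for_new_output(group_event: asyncio.Event) -> bool:
    """Wait for the server to signal new output for this sequence group.

    Returns True if the event fired; False on timeout, in which case the
    caller should kick the server again before the next wait.
    """
    try:
        await asyncio.wait_for(group_event.wait(),
                               timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
    except asyncio.TimeoutError:
        return False
    group_event.clear()  # reset so the next wait blocks until new output
    return True

In the committed loop, the timeout acts mainly as a safety valve: the kick at the top of each iteration is what normally keeps the server stepping forward.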
@@ -107,7 +123,16 @@ async def generate(self, request_dict: Dict):
                 "error": 0,
             }
             yield (json.dumps(ret) + "\0").encode("utf-8")
+
+            # Once finished, release the resources of the sequence group.
             if seq_group.is_finished():
+                del self.running_seq_groups[group_id]
+                del self.sequence_group_events[group_id]
+                # Kick the server once more if it is not running, so that
+                # any requests still in the server's waiting queue get
+                # executed.
+                if not self.is_server_running:
+                    await self.server_step()
                 break
@@ -143,7 +168,7 @@ async def generate_stream(request: Request):
         dtype=args.dtype,
         seed=args.seed,
         swap_space=args.swap_space,
-        max_batch_size=args.max_batch_size,
+        max_num_batched_tokens=args.max_num_batched_tokens,
         num_nodes=num_nodes,
         num_devices_per_node=num_devices_per_node,
         distributed_init_method=distributed_init_method,
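
Editor's note: the endpoint streams JSON messages separated by NUL bytes (the yield above appends "\0" to each message). Below is a client-side sketch for consuming that stream; the host, port, and the choice of the requests library are assumptions for illustration, and "prompt" is the only request field the diff shows explicitly (the rest of the body is parsed by SamplingParams.from_dict on the server).

import json
import requests  # any streaming HTTP client works

def stream_generate(prompt: str,
                    url: str = "http://localhost:8001/generate"):  # hypothetical host/port
    with requests.post(url, json={"prompt": prompt}, stream=True) as resp:
        buf = b""
        for chunk in resp.iter_content(chunk_size=None):
            buf += chunk
            # Messages are delimited by a trailing "\0" byte.
            while b"\0" in buf:
                msg, _, buf = buf.partition(b"\0")
                if msg:
                    yield json.loads(msg.decode("utf-8"))

for out in stream_generate("Hello, my name is"):
    print(out)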
