Fix potential bugs in FastAPI frontend and add comments #28

Merged · 2 commits · Apr 6, 2023
35 changes: 30 additions & 5 deletions cacheflow/http_frontend/fastapi_frontend.py
@@ -17,8 +17,10 @@
from cacheflow.worker.controller import DeviceID
from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory

TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
app = FastAPI()


class FastAPIFrontend:
def __init__(
self,
@@ -30,7 +32,7 @@ def __init__(
dtype: str,
seed: int,
swap_space: int,
max_batch_size: int,
max_num_batched_tokens: int,
num_nodes: int,
num_devices_per_node: int,
distributed_init_method: str,
@@ -51,7 +53,7 @@ def __init__(
dtype=dtype,
seed=seed,
swap_space=swap_space,
max_batch_size=max_batch_size,
max_num_batched_tokens=max_num_batched_tokens,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
@@ -68,12 +70,14 @@ async def server_step(self):
self.is_server_running = True
updated_seq_groups = await self.server.step.remote()
self.is_server_running = False
# Notify the waiting coroutines that there are new outputs ready.
for seq_group in updated_seq_groups:
group_id = seq_group.group_id
self.running_seq_groups[group_id] = seq_group
self.sequence_group_events[group_id].set()

async def generate(self, request_dict: Dict):
# Preprocess the request.
prompt = request_dict["prompt"]
sampling_params = SamplingParams.from_dict(request_dict)
sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
@@ -87,15 +91,27 @@ async def generate(self, request_dict: Dict):
arrival_time = time.time()
group_id = next(self.seq_group_counter)
seq_group = SequenceGroup(group_id, seqs, arrival_time)
# Create an event to notify us that there is new output from the
# cacheflow server.
group_event = asyncio.Event()
self.running_seq_groups[group_id] = seq_group
self.sequence_group_events[group_id] = group_event
# Add the request into the cacheflow server's waiting queue.
await self.server.add_sequence_groups.remote([(seq_group, sampling_params)])
# The cacheflow server does not have a background loop that keeps
# processing incoming requests. Therefore, we need to keep kicking
# the server to process the requests.
while True:
# Kick the server if the server is not running.
if not self.is_server_running:
await self.server_step()
# Wait for new output. Add a 1s timeout to prevent dead lock.
await asyncio.wait_for(group_event.wait(), timeout=1)
# Wait for new output. The group_event will be set in server_step
# when there is new output available for the sequence group.
# Added a timeout to prevent deadlock.
await asyncio.wait_for(group_event.wait(), timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
# Reset the event to wait for the next output.
group_event.clear()
# Decode and return new outputs
seq_group = self.running_seq_groups[group_id]
all_outputs = []
for seq in seq_group.seqs:
Expand All @@ -107,7 +123,16 @@ async def generate(self, request_dict: Dict):
"error": 0,
}
yield (json.dumps(ret) + "\0").encode("utf-8")

# Once finished, release the resources of the sequence group.
if seq_group.is_finished():
del self.running_seq_groups[group_id]
del self.sequence_group_events[group_id]
# Kick the server if the server is not running. This prevents
# requests left in the server's waiting queue from never being
# executed.
if not self.is_server_running:
await self.server_step()
break

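The loop in generate() above pairs a per-group asyncio.Event with a bounded wait, so a coroutine that misses a notification still wakes up and kicks the server again. The following is a minimal, self-contained sketch of that wait-with-timeout pattern, not the cacheflow code itself; the producer/consumer names and the sleep duration are purely illustrative.

import asyncio

TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds

async def producer(event: asyncio.Event) -> None:
    # Stands in for server_step() publishing new output for a sequence group.
    await asyncio.sleep(2.5)
    event.set()

async def consumer(event: asyncio.Event) -> None:
    while True:
        try:
            # Bound the wait so a missed notification cannot block forever.
            await asyncio.wait_for(event.wait(),
                                   timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
        except asyncio.TimeoutError:
            # No output yet; this is the point where the frontend would
            # kick the server again before retrying the wait.
            continue
        # Reset the event so the next wait blocks until new output arrives.
        event.clear()
        print("got new output")
        break

async def main() -> None:
    event = asyncio.Event()
    await asyncio.gather(producer(event), consumer(event))

asyncio.run(main())
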

@@ -143,7 +168,7 @@ async def generate_stream(request: Request):
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
max_batch_size=args.max_batch_size,
max_num_batched_tokens=args.max_num_batched_tokens,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
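
For reference, the frontend streams each result as a JSON object terminated by a NUL byte ("\0"), so a client has to split the byte stream on that delimiter. The snippet below is one way to consume it; the host, port, endpoint path, and any extra sampling fields in the payload are assumptions for illustration and should be checked against the @app.post(...) route and argument parsing in fastapi_frontend.py.

import json

import requests

# Host, port, and path are assumptions for illustration; check the
# @app.post(...) route and uvicorn settings in fastapi_frontend.py.
url = "http://localhost:8001/generate"
# "prompt" is read from the request dict; any additional sampling
# parameters accepted by SamplingParams.from_dict() can be added here.
payload = {"prompt": "Hello, my name is"}

with requests.post(url, json=payload, stream=True) as response:
    buffer = b""
    for chunk in response.iter_content(chunk_size=None):
        buffer += chunk
        # Each server message is a JSON object followed by a NUL byte.
        while b"\0" in buffer:
            message, _, buffer = buffer.partition(b"\0")
            print(json.loads(message.decode("utf-8")))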