[RFC]: Pipeline-Parallelism for vLLM V1 #11945

Open · 1 task done
ruisearch42 opened this issue Jan 10, 2025 · 17 comments

@ruisearch42 (Collaborator)

Motivation.

This RFC describes the approach for supporting pipeline parallelism in the vLLM V1 architecture.

Pipeline parallelism was supported in V0 with the virtual-engine approach. In short, we create multiple virtual engines to match the number of pipeline stages, and each virtual engine has its own scheduler, block manager, and cache engine, so that multiple batches can be scheduled simultaneously to the same executor with pipeline parallelism, saturating all pipeline stages to improve efficiency. However, the virtual-engine approach introduces the following drawbacks:

  1. The lack of a centralized scheduler prevents global optimization from being applied.
  2. It introduces complexity to the engine architecture and implementation.

In this RFC, we aim to support pipeline parallelism in the V1 LLMEngineCore, with the following properties:

  • Good performance: throughput and TTFT
    • The design should minimize pipeline bubbles
  • KV-cache efficiency
    • The design should minimize KV-cache fragmentation
    • The design should facilitate KV-cache block reuse across different requests
    • The design should be compatible with the current prefix caching mechanism
  • Architecture
    • The design should align well with V1 architecture
  • Scheduling policy flexibility
    • The design should support existing policies (FCFS and priority) and future policies

Proposed Change.

The current V1 engine core runs a busy synchronous loop, and each iteration consists of three operations (a minimal code sketch follows the list):

  • schedule(): schedules a batch of requests to run considering new requests, existing requests and preempted requests.
  • execute(): accepts scheduler output as the execution plan, executes the model, and returns the output.
  • update(): updates the scheduler state based on finished batch execution output.
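
A minimal sketch of this loop, with illustrative scheduler and executor objects (not the actual vLLM classes):

def run_busy_loop(scheduler, executor):
    while True:
        scheduler_output = scheduler.schedule()             # schedule()
        model_output = executor.execute(scheduler_output)   # execute()
        scheduler.update(scheduler_output, model_output)    # update()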

In this section, we discuss available options of adopting the V1 engine core architecture to achieve pipeline parallelism.

Option 1: Atomic engine step

[Figure: Sync Scheduler]

Design sketch

Intuitively, it would be ideal to keep the current busy-loop mechanism in the engine core and isolate all changes required for pipeline parallelism inside the executor, as shown in the figure above.

  • LLMEngineCore
    • The busy loop remains the same.
    • The model output no longer corresponds to the scheduler output (i.e., the microbatch in the figure). The model output in this iteration is the output of the microbatch we submitted to the executor PP_SIZE iterations ago.
  • RayExecutor
    • microbatch_queue: The queue size is the same as PP_SIZE, and we need to guarantee that the queue is always full. If there is not a sufficient number of microbatches (e.g., cold start or idle), we push empty microbatches (i.e., None) to the queue.
    • execute() takes one new microbatch, then waits for and returns the execution result of the oldest microbatch. Since we guarantee that the queue is always full, we can always get the result of the oldest microbatch immediately (but it may be None). A code sketch of this behavior follows the list.
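
A minimal sketch of the executor-side behavior described above, assuming a hypothetical submit_to_workers() call that returns a future with a result() method (illustrative names, not the actual RayExecutor API):

from collections import deque

class PipelinedExecutorSketch:
    """Option 1 sketch: execute() pushes one microbatch and pops the oldest."""

    def __init__(self, pp_size, submit_to_workers):
        self.submit_to_workers = submit_to_workers
        # Pre-fill with empty microbatches so the queue always holds PP_SIZE entries.
        self.microbatch_queue = deque([None] * pp_size)

    def execute(self, scheduler_output):
        # Push the new microbatch (None when nothing was scheduled).
        if scheduler_output is None:
            self.microbatch_queue.append(None)
        else:
            self.microbatch_queue.append(self.submit_to_workers(scheduler_output))
        # Pop the oldest microbatch; its result is ready by now (or it is empty).
        oldest = self.microbatch_queue.popleft()
        return None if oldest is None else oldest.result()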

Pros

  • The existing busy loop is (largely) unchanged, and all complexity is hidden at the executor level.
  • We still follow the “(not really) synchronous schedule” paradigm that submits one microbatch and receives the result of a (different) microbatch in the same synchronous function.

Cons

  • Degraded performance: The oldest (finished) microbatch won’t be fetched unless a new microbatch is scheduled. Although we continuously push empty microbatches when no new requests come in, this may still introduce overheads.
  • Complexity in managing empty microbatches: To achieve the desired pipeline efficiency, we have to push empty microbatches (None) to the microbatch queue when there are no requests scheduled (e.g., cold start, system idle, etc). Once we fail to maintain a full microbatch queue, the pipeline efficiency cannot be recovered unless we restart the engine.

Option 2 (Recommended): Two-stage engine loop

[Figure: Async Scheduler (current)]

Since pipeline parallelism enables multiple in-flight executions in a pipelined fashion, scheduling and execution become asynchronous by nature: before one microbatch finishes execution, the engine needs to schedule and submit another microbatch. Therefore, in this option, execute() is separated into two operations: submission and completion of the microbatch. Specifically, four operations are involved:

  • schedule(): the scheduler considers new requests and schedulable existing requests and schedules the microbatch.
  • submit(): the engine submits the microbatch to the executor for execution.
  • finish(): the executor finishes the execution of the microbatch.
  • update(): the scheduler updates its state based on the finished microbatch's execution output.

Design sketch

  • LLMEngineCore

    • The busy loop is changed to use an async loop.
    • The async loop is driven by the following events:
      • New request comes in
      • Existing request becomes schedulable
      • Oldest microbatch finished
    • The same code can run in synchronous fashion when microbatch_queue size is 1.
  • Ray Executor

    • A pipeline executor that executes whatever microbatches it receives (a minimal loop sketch follows this list).
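
A minimal sketch of the two-stage loop, assuming an asyncio-based engine core with hypothetical wait_for_schedulable()/submit() methods (illustrative names, not the actual V1 classes):

import asyncio

async def run_engine_loop(scheduler, executor, pp_size):
    inflight = asyncio.Queue(maxsize=pp_size)  # bounds in-flight microbatches to the PP depth

    async def submit_stage():
        while True:
            await scheduler.wait_for_schedulable()    # new request arrives / request becomes schedulable
            microbatch = scheduler.schedule()         # schedule()
            future = executor.submit(microbatch)      # submit(): non-blocking
            await inflight.put((microbatch, future))  # back-pressure once the pipeline is full

    async def finish_stage():
        while True:
            microbatch, future = await inflight.get()  # oldest in-flight microbatch
            output = await future                      # finish()
            scheduler.update(microbatch, output)       # update()

    await asyncio.gather(submit_stage(), finish_stage())

With pp_size set to 1, at most one microbatch is in flight at a time, which matches the synchronous special case noted above.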

Pros

  • Event-driven and performant, because the oldest microbatch can finish as soon as possible.
  • A stepping stone to extend to a fully async scheduler.

Cons

  • Changes the current synchronous busy loop.

Option 3: Virtual Engine

This is similar to the virtual-engine solution in vLLM V0.

Pros

  • Convenient to implement.
  • Good isolation.

Cons

  • Needs multiple schedulers, which are hard to manage and maintain.
  • Cannot reuse KV-cache from a different virtual engine; possible internal fragmentation.

Milestones for Option 2

We have the following milestones for achieving option 2.

  • Introduce async loop in LLMEngineCore
  • [In parallel] Support multiple microbatches (disjoint requests)
  • Implement pipeline-parallel
  • Optimization: support scheduling the same prefill-stage request in multiple inflight microbatches
    • Note: a request in the decode stage can only be scheduled to one in-flight microbatch until we figure out how to deal with speculative decoding and jump decoding. However, a request in the prefill stage can naturally be scheduled to multiple in-flight microbatches, because prefill at a later PP stage only depends on that stage's KV cache from the previously scheduled tokens, not on those tokens having finished all stages.

Feedback Period.

No response

CC List.

@WoosukKwon @robertgshaw2-neuralmagic @tylertitsworth @youkaichao @simon-mo @comaniac @stephanie-wang

Any Other Things.

No response

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
youkaichao self-assigned this Jan 11, 2025
@joerunde (Collaborator)

For Option 2, would the two-stage async loop need to be implemented as two separate synchronous processes communicating over zmq? I thought one of the main goals of isolating the engine loop was to eliminate the overhead of using asyncio.

For this con of option 1:

Complexity in managing empty microbatches: To achieve the desired pipeline efficiency, we have to push empty microbatches (None) to the microbatch queue when there are no requests scheduled (e.g., cold start, system idle, etc). Once we fail to maintain a full microbatch queue, the pipeline efficiency cannot be recovered unless we restart the engine.

Why can't the efficiency be recovered without restarting the engine? I would assume that after PP_SIZE iterations with new input, the pipeline would be full again with no empty microbatches

@comaniac (Collaborator) commented Jan 14, 2025

For Option 2, would the two-stage async loop need to be implemented as two separate synchronous processes communicating over zmq? I thought one of the main goals of isolating the engine loop was to eliminate the overhead of using asyncio.

I don't think this would be a bottleneck in the short term, given that most other logic, such as the API server and output processors, has already been moved out. However, it is straightforward to extend Option 2 to use zmq in the future if we really have to.

Why can't the efficiency be recovered without restarting the engine? I would assume that after PP_SIZE iterations with new input, the pipeline would be full again with no empty microbatches

It's not the case for Option 1. Note that Option 1 is synchronous and not event-driven: when you push a batch to the pipeline, you are blocked until you get the output of a batch from the pipeline. Thus, if the pipeline is empty, you have to wait for the output of the batch you just pushed, meaning that the pipeline stalls.

Here is an example of 4 stages without empty batches:

microbatch_queue.submit(b0)
# P0  |  P1  |  P2  | P3
# b0  |      |      |      

output = microbatch_queue.get()
# Blocking until b0 is finished. Not pipelining.

Another example of 4 stages with empty batches:

microbatch_queue.submit(b0)
# P0  |  P1  |  P2  | P3
# b0  |  N/A |  N/A |  N/A

output = microbatch_queue.get()
# Immediately get None (N/A). Pipelining

So if we fail to push None (N/A) at any moment, it's hard to recover. Of course this can still be achieved with careful engineering and maintenance, but the engineering effort does not seem worthwhile, and this option is not friendly to potential future extensions (e.g., an async scheduler).

@ruisearch42 (Collaborator, Author)

For Option 2, would the two-stage async loop need to be implemented as two separate synchronous processes communicating over zmq? I thought one of the main goals of isolating the engine loop was to eliminate the overhead of using asyncio.

Yeah, so far these are async routines running in the same async loop.

@heheda12345 (Collaborator)

How do we ensure load balancing across different micro-batches?
DP+EP MoE also needs to split requests into micro-batches for DP. We may need to think now about how to make these two kinds of micro-batches compatible, in order to support DP+EP+PP.

@comaniac (Collaborator)

How do we ensure load balancing across different micro-batches? DP+EP MoE also needs to split requests into micro-batches for DP. We may need to think now about how to make these two kinds of micro-batches compatible, in order to support DP+EP+PP.

This is a good question but it is out of the scope of this RFC. Load balancing for PP is a general issue and is orthogonal to the PP implementation. For example, chunked prefill is one effective way to balance batch sizes; a better layer-partition algorithm is another. Since the design proposed by this RFC doesn't prevent the scheduler from batching certain requests, our future scheduler optimizations for PP efficiency should always be compatible.
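
As a toy illustration only (not vLLM's scheduler code): with chunked prefill, each microbatch can be capped at a fixed token budget, so long prefills are split across iterations and microbatch sizes stay roughly even:

def build_microbatch(waiting, token_budget):
    """Pack (request_id -> remaining_prefill_tokens) into one microbatch,
    chunking long prefills so no microbatch exceeds token_budget tokens."""
    microbatch, used = [], 0
    for req_id, remaining in list(waiting.items()):
        if used >= token_budget:
            break
        chunk = min(remaining, token_budget - used)
        microbatch.append((req_id, chunk))
        used += chunk
        waiting[req_id] -= chunk
        if waiting[req_id] == 0:
            del waiting[req_id]
    return microbatch

waiting = {"r0": 5000, "r1": 300}
print(build_microbatch(waiting, token_budget=2048))  # [('r0', 2048)]
print(build_microbatch(waiting, token_budget=2048))  # [('r0', 2048)]
print(build_microbatch(waiting, token_budget=2048))  # [('r0', 904), ('r1', 300)]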

@noooop (Contributor) commented Jan 18, 2025

Refer to DeepSpeed ZeRO-1.

Multiple prefill batches of the same request can be processed in order with pipeline parallelism,
as long as the corresponding layers of the KV cache from the previous batch have already been written.

It can further improve TTFT.

But it is difficult to implement.

Would anyone be interested in implementing it?

@comaniac (Collaborator)

Refer to DeepSpeed ZeRO-1.

Multiple prefill batches of the same request can be processed in order with pipeline parallelism,
as long as the corresponding layers of the KV cache from the previous batch have already been written.

It can further improve TTFT.

But it is difficult to implement.

Would anyone be interested in implementing it?

This is exactly the optimization mentioned in Option 2. This feature is critical for PP, but it requires more changes to the scheduler, so we didn't plan to include it in the first milestone.

@noooop (Contributor) commented Jan 18, 2025

This is exactly the optimization mentioned in Option 2. This feature is critical for PP, but it requires more changes to the scheduler, so we didn't plan to include it in the first milestone.

Let's make it happen; Option 2 is obviously cooler.


vLLM has only one implementation per module, which limits the possibilities for exploration.

Moreover, each module needs to be compatible with all models and features, so it ends up compromised, leading to suboptimal results.

If vLLM could support multiple implementations of a given feature, an implementation could, for the time being, be optimized for one very specific scenario. The implementer would not need to think about everything else and could optimize only that scenario.

@youkaichao (Member)

Thanks for the great RFC!

When I designed the collective_rpc API with @tlrmchlsmth in V1, I had pipeline parallelism in mind. What I have in mind (we can call it Option 4) is to make the collective_rpc call asynchronous:

batch1 = scheduler.schedule() # schedule into running queue 1
future_result1 = executor.collective_rpc("execute_model", batch1, wait_on_rank=[1])
batch2 = scheduler.schedule() # schedule into running queue 2
future_result2 = executor.collective_rpc("execute_model", batch2, wait_on_rank=[1])

output_batch1 = future_result1.wait()
scheduler.update(output_batch1)

output_batch2 = future_result2.wait()
scheduler.update(output_batch2)

This has the benefit that:

  • we only have one scheduler, and all requests can share the KV cache
  • different batches can have different batch sizes; their sizes are independent (of course we will try to balance them to reduce bubbles)

When @andoorve added pipeline-parallel support, we chose the virtual-engine design, mainly to minimize the code change, as the V0 scheduler is quite complicated and difficult to change. Now that we have a much simpler scheduler in V1, I think we should support multiple running queues in the scheduler.

I'm not sure how different it is from the proposed Option 2.

@noooop (Contributor) commented Jan 25, 2025

I'm not sure how different it is from the proposed Option 2.

Personally, I feel that the difference between Option 2 and Option 4 is whether to use Python's async syntax.

Python async is like another language.

Once one function uses async def, all of its callers must use async def, which is very annoying, and asyncio is really slow.

And sometimes, when starting a thread or process, the asyncio loop and the (zmq) socket may be forked; inadvertently using the forked loop or socket results in very strange bugs.

Futures are great.

@ruisearch42 (Collaborator, Author)

Personally, I feel that the difference between Option 2 and Option 4 is whether to use Python's async syntax.

Python async is like another language.

Futures are great.

In addition to futures, an event loop is also needed, because the tasks are essentially async. If we only use futures, the solution would be similar to Option 1, if not exactly the same.

I think the collective_rpc @youkaichao mentioned would be a possible implementation for executing the batch. It may return a future, but we'd still need an event loop.

asyncio is really slow.

Can you elaborate a bit? Is asyncio slow due to the mechanism, due to a particular implementation, or only for certain workloads (e.g., CPU-bound)? I think we can use uvloop or might also consider other frameworks.

@wedobetter commented Feb 3, 2025

I'm not sure how different it is from the proposed Option 2.

Personally, I feel that the difference between Option 2 and Option 4 is whether to use Python's async syntax.

Python async is like another language.

Once one function uses async def, all of its callers must use async def, which is very annoying, and asyncio is really slow.

And sometimes, when starting a thread or process, the asyncio loop and the (zmq) socket may be forked; inadvertently using the forked loop or socket results in very strange bugs.

Futures are great.

You have probably skipped reading the docs; you can mix and match async and sync functions using threads:

https://docs.python.org/3/library/asyncio-task.html#id13

The futures library was used before asyncio was made available in 3.6, although it still contains useful facilities such as the ThreadPool and ProcessPool executors.

njhill added the v1 label Feb 4, 2025
@njhill (Member) commented Feb 4, 2025

I'm wary of using asyncio inside the core engine process. The idea was to really isolate the critical path for the model loop as a single thread, at least for non-PP cases, so that it's not context switching with other async tasks or having critical path tasks wait behind other arbitrary tasks in the event loop queue.

@comaniac (Collaborator) commented Feb 5, 2025

I'm wary of using asyncio inside the core engine process. The idea was to really isolate the critical path for the model loop as a single thread, at least for non-PP cases, so that it's not context switching with other async tasks or having critical path tasks wait behind other arbitrary tasks in the event loop queue.

This is a fair point. Then it seems more ideal to use threading, with two threads: one scheduling requests and one receiving outputs, so that we never hand over to other async tasks within each thread.
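
A minimal sketch of that two-thread arrangement, assuming a bounded thread-safe queue and hypothetical scheduler/executor methods (illustrative only, not an actual implementation):

import queue
import threading

def run_engine_threads(scheduler, executor, pp_size):
    inflight = queue.Queue(maxsize=pp_size)  # bounds in-flight microbatches to the PP depth

    def submit_loop():
        while True:
            microbatch = scheduler.schedule()     # blocks until something is schedulable
            future = executor.submit(microbatch)  # non-blocking submission to the workers
            inflight.put((microbatch, future))    # blocks if the pipeline is already full

    def finish_loop():
        while True:
            microbatch, future = inflight.get()   # oldest in-flight microbatch
            output = future.result()              # wait for the execution output
            scheduler.update(microbatch, output)

    threading.Thread(target=submit_loop, daemon=True).start()
    threading.Thread(target=finish_loop, daemon=True).start()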

@noooop (Contributor) commented Feb 5, 2025

You have probably skipped reading the docs; you can mix and match async and sync functions using threads

The "future" I'm talking about here is a generalized waitable object, e.g. a ResultFuture or a concurrent.futures.Future; I haven't considered the specific implementation yet. I'm not specifically referring to an asyncio future.

What I want to express here is to use a future (a generalized waitable object) instead of asyncio syntax:

batch1 = scheduler.schedule() # schedule into running queue 1
future_result1 = executor.collective_rpc("execute_model", batch1, wait_on_rank=[1])
batch2 = scheduler.schedule() # schedule into running queue 2
future_result2 = executor.collective_rpc("execute_model", batch2, wait_on_rank=[1])

vs

batch1 = scheduler.schedule() # schedule into running queue 1
result1 = await executor.collective_rpc("execute_model", batch1, wait_on_rank=[1])
batch2 = scheduler.schedule() # schedule into running queue 2
result2 = await executor.collective_rpc("execute_model", batch2, wait_on_rank=[1])

I'm wary of using asyncio inside the core engine process.

Yes, I had a huge headache because of using asyncio in a multi-process environment; it can introduce weird bugs that are hard to track down.

Probably because neither the asyncio loop nor the zmq Context and Socket are thread-safe. A forked asyncio loop is still listening and even reading, and a forked zmq Socket uses the forked Context. Very weird things can happen.

Although asyncio has some useful utilities like queues, timeouts, etc., threads can do the same and can use multiple CPUs.

I think it is good to use zmq + threads, like the V1 multiproc_executor and the V0 mp_distributed_executor.

@ruisearch42 (Collaborator, Author)

OK thanks. Let me prototype with threading.

@noooop (Contributor) commented Feb 6, 2025

asyncio is really slow.

Can you elaborate a bit? Is asyncio slow due to the mechanism, due to a particular implementation, or only for certain workloads (e.g., CPU-bound)? I think we can use uvloop or might also consider other frameworks.

Dummy zmq loop test:

def server():
    import zmq

    context = zmq.Context()
    socket = context.socket(zmq.REP)
    socket.bind("tcp://*:5555")

    while True:
        socket.recv()
        socket.send(b"World")
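
For reference, a naive blocking REQ client to drive this server might look roughly like the following sketch (the actual benchmark client is not shown here):

def naive_client(n=100_000):
    import time
    import zmq

    context = zmq.Context()
    socket = context.socket(zmq.REQ)
    socket.connect("tcp://localhost:5555")

    start = time.perf_counter()
    for _ in range(n):
        socket.send(b"Hello")
        socket.recv()
    elapsed = time.perf_counter() - start
    print(f"QPS: {n / elapsed:.2f}")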

result (QPS)

                    server: naive    server: gevent    server: asyncio    server: uvloop    avg (all servers)
client: naive       53586.69         39127.87          34617.86           37090.49          41105.7275
client: gevent      38408.05         30665.28          26724.27           28110.91          30977.1275
client: asyncio     34353.76         27077.97          23354.54           26413.4           27799.9175
client: uvloop      37912.84         30066.15          25288.72           28999.83          30566.885
avg (all clients)   41065.335        31734.3175        27496.3475         30153.6575

conclusion

At least in this simple test, asyncio is slow.

The naive zeromq version uses optimized Cython code.

@ruisearch42


If asyncio is used, the API server QPS cannot exceed roughly 30K tokens.
In particular, the 13700KF has much better single-core performance than ordinary server CPUs.
This value is not high for PP, which is scary if you think about it.
