[RFC]: Pipeline-Parallelism for vLLM V1 #11945
Comments
For Option 2, would the two-stage async loop need to be implemented as two separate synchronous processes communicating over zmq? I thought one of the main goals of isolating the engine loop was to eliminate the overhead of using asyncio. For this con of option 1:
Why can't the efficiency be recovered without restarting the engine? I would assume that after |
I don't think this would be a bottleneck in the short term, given that most of the other logic, such as the API server and output processors, has already been moved out. However, it is straightforward to extend Option 2 to use zmq in the future if we really have to. |
It's not the case for Option 3. Note that Option 3 is synchronous and not event-driven. That means when you push a batch into the pipeline, you are blocked until you get the output of a batch from the pipeline. Thus, if the pipeline is empty, you have to wait for the output of the batch you just pushed, meaning that the pipeline is now blocking. Here is an example of 4 stages without empty batches:
Another example of 4 stages with empty batches:
So if we fail to push None (N/A) at any moment, it's hard to recover. Of course this can still be achieved with careful engineering and maintenance, but the engineering effort doesn't seem worthwhile, and this option is not friendly to potential future extensions (e.g., an async scheduler). |
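To make the empty-batch issue concrete, here is a toy model (not vLLM code) of a synchronous, fixed-depth pipeline in which every push returns the result of the batch that entered the pipeline a fixed number of steps earlier:

```python
from collections import deque

# Toy model of a synchronous, fixed-depth pipeline as described for Option 3.
PIPELINE_DEPTH = 4
inflight = deque([None] * PIPELINE_DEPTH)   # slots pre-filled with empty batches

def pipeline_step(batch):
    """Push one (micro)batch (or None) and pop the oldest in-flight slot."""
    inflight.append(batch)
    return inflight.popleft()                # None means that slot was an empty batch

# The engine loop must push *something* every iteration. If the scheduler has
# no work, it still has to push None; if it skips the push, the results of
# already-submitted batches never come out, because an output is only returned
# when a new slot is pushed in.
for step in range(8):
    batch = f"batch-{step}" if step % 3 else None   # pretend scheduler output
    finished = pipeline_step(batch)
    print(f"step {step}: pushed={batch!r}, finished={finished!r}")
```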
Yeah, so far those are async routines running in the same async loop. |
How do we ensure load balancing across different microbatches? |
This is a good question, but it is out of scope for this RFC. Load balancing for PP is a general issue and is orthogonal to the PP implementation. For example, chunked prefill is one effective way to balance batch sizes; a better layer-partition algorithm is another. Since the design proposed by this RFC doesn't prevent the scheduler from batching certain requests, our future scheduler optimizations for PP efficiency should always remain compatible. |
Refer to deepspeed zero-1: multiple prefill batches of the same request can be processed in order using pipeline parallelism, which can further improve TTFT. But it is difficult to implement. Would anyone be interested in implementing it? |
This is exactly the optimization mentioned in Option 2. This feature is critical for PP, but it requires more changes to the scheduler, so we didn't plan to include it in the first milestone. |
Let's make it happen; option 2 is obviously cooler. vllm currently has only one implementation per module, which limits the possibilities for exploration. Moreover, each module needs to be compatible with all models and features, so it ends up compromised, leading to suboptimal results. If vllm could support multiple implementations of a given function, an implementation could be optimized for one very specific scenario for the time being, and the implementer would not need to think beyond that scenario. |
Thanks for the great RFC! When I design the

```python
batch1 = scheduler.schedule()  # schedule into running queue 1
future_result1 = executor.collective_rpc("execute_model", batch1, wait_on_rank=[1])
batch2 = scheduler.schedule()  # schedule into running queue 2
future_result2 = executor.collective_rpc("execute_model", batch2, wait_on_rank=[1])
output_batch1 = future_result1.wait()
scheduler.update(output_batch1)
output_batch2 = future_result2.wait()
scheduler.update(output_batch2)
```

This has the benefit that:
When @andoorve added pipeline-parallel support, we chose the virtual engine design mainly to minimize the code change, as the v0 scheduler is quite complicated and difficult to change. Now that we have a much simpler scheduler in v1, I think we should support multiple running queues in the scheduler. I'm not sure how different it is from the proposed Option 2. |
Personally, I feel that the difference between option 2 and option 4 is whether to use Python async syntax. Python async is like another language: once one function uses async def, all functions on the call path must use async def, which is very annoying, and asyncio is really slow.
future is great. |
In addition to futures, an event loop is also needed, because the tasks are essentially async. If we only use futures, the solution would be similar to Option 1, if not exactly the same. I think the
Can you elaborate a bit? Is asyncio slow due to the mechanism, due to a particular implementation, or only for certain workloads (e.g., CPU-bound)? I think we can use uvloop or might also consider other frameworks. |
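For reference, if asyncio stays in the picture, swapping in uvloop is a one-line change (a generic example, not vLLM code):

```python
import asyncio
import uvloop

async def main():
    await asyncio.sleep(0)   # stand-in for the real engine coroutine

uvloop.install()             # replace the default asyncio event loop policy
asyncio.run(main())
```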
You have probably skipped reading the docs; you can mix and match async and sync functions via threads: https://docs.python.org/3/library/asyncio-task.html#id13 The futures library was used before asyncio was made available in 3.6, and it still contains useful pieces such as the ThreadPool and ProcessPool executors. |
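For example, running an ordinary blocking function from async code without rewriting it as a coroutine, which is presumably the mechanism the linked section refers to, can be done with asyncio.to_thread:

```python
import asyncio

def blocking_model_step():
    # Ordinary synchronous code, e.g. a blocking executor call.
    return "output"

async def driver():
    # Run the sync function in a worker thread instead of rewriting it as async.
    result = await asyncio.to_thread(blocking_model_step)
    print(result)

asyncio.run(driver())
```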
I'm wary of using asyncio inside the core engine process. The idea was to really isolate the critical path for the model loop as a single thread, at least for non-PP cases, so that it's not context switching with other async tasks or having critical path tasks wait behind other arbitrary tasks in the event loop queue. |
This is a fair point. Then it seems more ideal to use threading (with two threads: one scheduling requests and one receiving outputs), so that we never hand control over to other async tasks within each thread. |
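A minimal sketch of that two-thread arrangement, assuming a non-blocking submit API such as execute_model_async() (hypothetical name) on the executor and a shared scheduler guarded by a lock:

```python
import queue
import threading

inflight = queue.Queue(maxsize=2)        # bounds the number of in-flight microbatches
scheduler_lock = threading.Lock()        # scheduler state is shared by both threads

def schedule_loop(scheduler, executor):
    while True:
        with scheduler_lock:
            batch = scheduler.schedule()
        future = executor.execute_model_async(batch)   # assumed non-blocking submit
        inflight.put(future)             # blocks only when the pipeline is full

def output_loop(scheduler):
    while True:
        future = inflight.get()
        output = future.result()         # blocks this thread only, not the whole engine
        with scheduler_lock:
            scheduler.update_from_output(output)

# Started as daemon threads, e.g.:
# threading.Thread(target=schedule_loop, args=(scheduler, executor), daemon=True).start()
# threading.Thread(target=output_loop, args=(scheduler,), daemon=True).start()
```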
The "future" I'm talking about here is a generalize waitable object, e.g. ResultFuture or concurrent.future, and haven't considered the specific implementation yet. Not specifically referring to asyncio-future. What I want to express here is to use future (generalize waitable object) instead of asyncio syntax.
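The two snippets being compared are not preserved in this thread; as a generic illustration of the contrast between a future/waitable-object style and asyncio syntax (assumed method names, not vLLM APIs):

```python
# Future / waitable-object style: plain synchronous functions that block
# only at an explicit wait point.
def step_with_future(scheduler, executor):
    batch = scheduler.schedule()
    future = executor.execute_model(batch)   # returns a waitable object
    output = future.result()                 # explicit, local wait
    scheduler.update_from_output(output)

# asyncio style: every function on the call path becomes a coroutine.
async def step_with_asyncio(scheduler, executor):
    batch = scheduler.schedule()
    output = await executor.execute_model(batch)
    scheduler.update_from_output(output)
```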
Yes, I had a huge headache from using asyncio in a multi-process environment; it can introduce weird bugs that are hard to track down.
Although asyncio has some useful utilities like queues, timeouts, etc., threads can do the same and can use multiple CPUs. I think it is good to use zmq+thread like the v1 multiproc_executor and the v0 mp_distributed_executor. |
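A minimal sketch of the zmq+thread pattern being referred to (the endpoint and message handling here are made-up placeholders, not the actual multiproc_executor code): a dedicated thread owns a socket and blocks on recv without any asyncio event loop.

```python
import threading
import zmq

def output_listener(endpoint="ipc:///tmp/pp_outputs_demo", on_output=print):
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.PULL)
    sock.bind(endpoint)
    while True:
        msg = sock.recv_pyobj()   # blocks this thread only, no event loop involved
        on_output(msg)

listener = threading.Thread(target=output_listener, daemon=True)
listener.start()
```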
OK thanks. Let me prototype with threading. |
Dummy zmq loop test, with results in QPS. Conclusion: at least in this simple test, asyncio is slow, while naive ZeroMQ uses optimized Cython code. |
Motivation.
This RFC describes the approach for supporting pipeline parallelism in vLLM V1 architecture.
Pipeline parallelism was supported in V0 with the virtual-engine approach. In short, we create multiple virtual engines to match the number of pipeline stages, and each virtual engine has its own scheduler, block manager, and cache engine, so that multiple batches can be scheduled simultaneously to the same executor with pipeline parallelism, saturating all pipeline stages to improve efficiency. However, the virtual-engine approach introduces the following drawbacks:
In this RFC, we aim to support pipeline parallelism in the V1 LLMEngineCore, with the following properties:
Proposed Change.
The current V1 engine core runs a busy synchronous loop, and each iteration consists of 3 operations (sketched below):
1. Schedule a batch of requests.
2. Execute the scheduled batch on the model executor.
3. Process the model output and update the scheduler state.
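A minimal sketch of this loop, with method names assumed from the discussion above rather than taken verbatim from the V1 code:

```python
def engine_core_loop(scheduler, executor):
    """Rough sketch of the V1 busy loop (assumed method names)."""
    while True:
        batch = scheduler.schedule()               # 1. schedule a batch
        output = executor.execute_model(batch)     # 2. execute it (blocking until done)
        scheduler.update_from_output(output)       # 3. process outputs / update state
```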
In this section, we discuss the available options for adapting the V1 engine core architecture to achieve pipeline parallelism.
Option 1: Atomic engine step
Design sketch
Intuitively, it would be ideal to keep the current busy-loop mechanism in the engine core and isolate all changes required for pipeline parallelism to the executor, as shown in the above figure.
Pros
Cons
Option 2 (Recommended): Two-stage engine loop
Since pipeline parallelism enables multiple in-flight executions in a pipelined fashion, scheduling and execution become asynchronous by nature: before one microbatch finishes execution, the engine needs to schedule and submit another microbatch. Therefore, in this option, execute() is separated into two operations: submission and completion of a microbatch. Specifically, 4 operations are involved:
1. Schedule a microbatch.
2. Submit the microbatch to the executor (non-blocking).
3. Wait for a previously submitted microbatch to finish.
4. Process its output and update the scheduler state.
Design sketch
(Figure: interaction between the LLMEngineCore and the Ray Executor.)
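As a rough, non-authoritative sketch of this two-stage loop (method names like execute_model_async() and update_from_output() are assumptions, not the actual V1 interface):

```python
def two_stage_engine_loop(scheduler, executor, max_inflight):
    """Sketch of Option 2: keep up to `max_inflight` microbatches in the pipeline."""
    inflight = []                                  # futures of submitted microbatches
    while True:
        # Stage 1: schedule and submit while the pipeline has free slots.
        while len(inflight) < max_inflight:
            batch = scheduler.schedule()
            inflight.append(executor.execute_model_async(batch))   # non-blocking submit
        # Stage 2: wait for the oldest microbatch and feed its output back.
        output = inflight.pop(0).result()
        scheduler.update_from_output(output)
```

Here max_inflight would typically equal the number of pipeline stages, so that all stages stay busy.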
Pros
Cons
Option 3: Virtual Engine
This is similar to the virtual-engine solution in vLLM V0.
Pros
Cons
Milestones for Option 2
We have the following milestones for achieving option 2.
Feedback Period.
No response
CC List.
@WoosukKwon @robertgshaw2-neuralmagic @tylertitsworth @youkaichao @simon-mo @comaniac @stephanie-wang
Any Other Things.
No response