Optimize data movement #20

Merged
WoosukKwon merged 19 commits from data-move into main on Apr 2, 2023
Conversation

WoosukKwon (Collaborator)

Should be merged after #15.

The changes in this PR eliminate the need for redundant data movements such as torch.cat, torch.stack, and torch.contiguous, which were previously used to align input and output shapes. The PR modifies existing kernels and adds new kernels to accommodate non-contiguous tensors, making these data movement operators unnecessary.
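
As a rough sketch of the kind of copy these operators imply (the shapes and variable names below are illustrative only, not taken from this PR): slicing one component out of a fused QKV buffer yields a strided, non-contiguous view, and a kernel that assumes dense storage forces an extra materializing copy via .contiguous() or torch.cat, whereas a kernel that honors arbitrary strides can read the view in place.

import torch

# Illustrative shapes only (not from the PR).
num_tokens, num_heads, head_size = 16, 32, 128

# Fused QKV buffer laid out as [num_tokens, 3, num_heads, head_size].
qkv = torch.randn(num_tokens, 3, num_heads, head_size)

query = qkv[:, 0]                 # a view into qkv; no data is moved
print(query.is_contiguous())      # False: strided view over the fused buffer

# A kernel that assumes dense storage needs an extra copy first:
query_dense = query.contiguous()  # allocates and fills a new tensor

# A kernel written to accept arbitrary strides can consume `query` directly,
# which is the idea behind the kernel changes described above.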

@zhuohan123 (Member) left a comment

LGTM!

Comment on lines +35 to +50
# Directly call FlashAttention's internal function to avoid allocating
# a new tensor for the output.
_flash_attn_forward(
    query,
    key,
    value,
    output,
    cumulative_prompt_lens,
    cumulative_prompt_lens,
    max_prompt_len,
    max_prompt_len,
    dropout_p=0.0,
    softmax_scale=self.scale,
    causal=True,
    return_softmax=False,
)
@zhuohan123 (Member)

Just curious, so flash attention natively supports non-contiguous QKV tensors?

@WoosukKwon (Collaborator, Author)

Yes. It actually requires a QKV tensor of shape [num_tokens, 3, num_heads, head_size]. Previously, we inserted torch.stack to meet this shape requirement; this PR eliminates that inefficiency.
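
For illustration, a minimal sketch of why the torch.stack becomes unnecessary, assuming a fused QKV projection whose output concatenates Q, K, and V along the feature dimension (the shapes here are made up): the projection output can simply be viewed as [num_tokens, 3, num_heads, head_size], so no stack or copy is required.

import torch

# Hypothetical shapes, for illustration only.
num_tokens, num_heads, head_size = 16, 32, 128
hidden_size = num_heads * head_size

# Output of a fused QKV linear layer: [num_tokens, 3 * hidden_size],
# with Q, K, and V concatenated along the feature dimension.
qkv_out = torch.randn(num_tokens, 3 * hidden_size)

# Old pattern: split, reshape, then torch.stack (an extra copy) to build
# the [num_tokens, 3, num_heads, head_size] tensor FlashAttention expects.
q, k, v = qkv_out.split(hidden_size, dim=-1)
qkv_stacked = torch.stack(
    [t.reshape(num_tokens, num_heads, head_size) for t in (q, k, v)], dim=1)

# New pattern: a plain view already has the required shape; no copy is made.
qkv_viewed = qkv_out.view(num_tokens, 3, num_heads, head_size)

assert torch.equal(qkv_stacked, qkv_viewed)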

@zhuohan123 (Member)

Speed before this PR on 1 A100:

ubuntu@ray-zhuohan-cf-head-2c23a277-compute:~/nfs/cacheflow/cacheflow/benchmark$ python benchmark_latency.py --model ~/hf-llama/llama-13b/
Namespace(batch_size=8, block_size=8, dtype='half', input_len=32, max_batch_size=2560, model='/home/ubuntu/hf-llama/llama-13b/', model_path='~/.cacheflow/model_weights', output_len=128, pipeline_parallel_size=1, seed=0, swap_space=20, tensor_parallel_size=1)
2023-04-02 06:25:10,878 INFO worker.py:1535 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8266/
# GPU blocks: 1977, # CPU blocks: 3276
Warm up step
Profile step: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:10<00:00,  3.53s/it]
Avg latency: 3.526289224624634 seconds
ubuntu@ray-zhuohan-cf-head-2c23a277-compute:~/nfs/cacheflow/cacheflow/benchmark$ python benchmark_latency.py --model facebook/opt-13b
Namespace(batch_size=8, block_size=8, dtype='half', input_len=32, max_batch_size=2560, model='facebook/opt-13b', model_path='~/.cacheflow/model_weights', output_len=128, pipeline_parallel_size=1, seed=0, swap_space=20, tensor_parallel_size=1)
2023-04-02 06:27:55,300 INFO worker.py:1535 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8266/
# GPU blocks: 1975, # CPU blocks: 3276
Warm up step
Profile step: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:10<00:00,  3.54s/it]
Avg latency: 3.5404738585154214 seconds

After:

ubuntu@ray-zhuohan-cf-head-2c23a277-compute:~/nfs/cacheflow/cacheflow/benchmark$ python benchmark_latency.py --model facebook/opt-13b
Namespace(batch_size=8, block_size=8, dtype='half', input_len=32, max_batch_size=2560, model='facebook/opt-13b', model_path='~/.cacheflow/model_weights', output_len=128, pipeline_parallel_size=1, seed=0, swap_space=20, tensor_parallel_size=1)
2023-04-02 07:17:35,120 INFO worker.py:1535 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8266/
# GPU blocks: 1975, # CPU blocks: 3276
Warm up step
Profile step: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:10<00:00,  3.43s/it]
Avg latency: 3.432361443837484 seconds
ubuntu@ray-zhuohan-cf-head-2c23a277-compute:~/nfs/cacheflow/cacheflow/benchmark$ python benchmark_latency.py --model ~/hf-llama/llama-13b/
Namespace(batch_size=8, block_size=8, dtype='half', input_len=32, max_batch_size=2560, model='/home/ubuntu/hf-llama/llama-13b/', model_path='~/.cacheflow/model_weights', output_len=128, pipeline_parallel_size=1, seed=0, swap_space=20, tensor_parallel_size=1)
2023-04-02 07:19:00,665 INFO worker.py:1535 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8266/
# GPU blocks: 1977, # CPU blocks: 3276
Warm up step
Profile step: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:09<00:00,  3.27s/it]
Avg latency: 3.2731640338897705 seconds

WoosukKwon merged commit 897cb2a into main on Apr 2, 2023
WoosukKwon deleted the data-move branch on April 2, 2023 07:30
hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
luo-cheng2021 pushed a commit to luo-cheng2021/vllm that referenced this pull request Apr 17, 2024
Produce artifacts for bare metal installation in Dockerfile.openvino
tdg5 pushed a commit to tdg5/vllm that referenced this pull request Apr 25, 2024
fxmarty pushed a commit to fxmarty/vllm-public that referenced this pull request May 31, 2024
…factor

Dockerfile improvements: multistage
tianyil1 pushed a commit to tianyil1/vllm that referenced this pull request Jun 5, 2024
alixiaodi mentioned this pull request Aug 2, 2024