Optimize data movement #20
Conversation
LGTM!
  # Directly call FlashAttention's internal function to avoid allocating
  # a new tensor for the output.
  _flash_attn_forward(
      query,
      key,
      value,
      output,
      cumulative_prompt_lens,
      cumulative_prompt_lens,
      max_prompt_len,
      max_prompt_len,
      dropout_p=0.0,
      softmax_scale=self.scale,
      causal=True,
- )[0]
- # FIXME(woosuk): Unnecessary copy. Optimize this.
- output.copy_(out, non_blocking=True)
+     return_softmax=False,
+ )
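For readers unfamiliar with the pattern: passing a preallocated output buffer into the kernel removes both the temporary allocation and the output.copy_ that the FIXME pointed at. Below is a minimal generic sketch of the same idea using PyTorch's out= convention; it is illustrative only, not the FlashAttention API, and the shapes are made up.

import torch

num_tokens, num_heads, head_size = 16, 8, 64  # hypothetical sizes

# Preallocated destination buffer, analogous to `output` above.
output = torch.empty(num_tokens, num_heads, head_size)

a = torch.randn(num_tokens, num_heads, head_size)
b = torch.randn(num_tokens, num_heads, head_size)

# Old pattern: compute into a fresh tensor, then copy it into the buffer.
tmp = a + b                  # extra allocation
output.copy_(tmp)            # extra copy (what the removed lines did)

# New pattern: write the result directly into the preallocated buffer.
torch.add(a, b, out=output)  # no temporary, no copy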
Just curious, so flash attention natively supports non-contiguous QKV tensors?
Yes. It actually requires a qkv tensor of shape [num_tokens, 3, num_heads, head_size]. Previously, we inserted torch.stack to meet this shape requirement, and this PR eliminates this inefficiency.
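To make the shape requirement concrete, here is a minimal sketch (hypothetical shapes, not the PR's actual code): if the projection already produces a packed qkv tensor of shape [num_tokens, 3, num_heads, head_size], query/key/value can be taken as views of it, so no torch.stack is needed and no data is moved.

import torch

num_tokens, num_heads, head_size = 16, 8, 64  # hypothetical sizes

# Packed tensor in the layout FlashAttention expects:
# [num_tokens, 3, num_heads, head_size].
qkv = torch.randn(num_tokens, 3, num_heads, head_size)

# query, key, value are views into qkv -- no torch.stack, no copy.
query, key, value = qkv.unbind(dim=1)

# The views share storage with qkv and are non-contiguous,
# which is why the kernels must accept strided inputs.
assert query.data_ptr() == qkv.data_ptr()
assert not query.is_contiguous()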
Speed before this PR on 1 A100:
After:
Should be merged after #15.

The changes in this PR eliminate the need for redundant data movements such as torch.cat, torch.stack, and torch.contiguous, which were previously used to align input and output shapes. The PR modifies existing kernels and adds new kernels to accommodate non-contiguous tensors, making these data movement operators unnecessary.
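For context on why these operators matter: torch.cat, torch.stack, and contiguous() each allocate fresh storage and copy their inputs, so on the hot path they are pure data movement. A small generic sketch (not the modified kernels themselves) of what the PR avoids:

import torch

num_tokens, num_heads, head_size = 16, 8, 64  # hypothetical sizes
q = torch.randn(num_tokens, num_heads, head_size)
k = torch.randn(num_tokens, num_heads, head_size)
v = torch.randn(num_tokens, num_heads, head_size)

# Each of these materializes a new tensor and copies its inputs.
stacked = torch.stack([q, k, v], dim=1)      # [num_tokens, 3, heads, dim]
concatenated = torch.cat([q, k, v], dim=-1)  # [num_tokens, heads, 3 * dim]
recopied = stacked[:, 0].contiguous()        # copies an existing view

assert stacked.data_ptr() != q.data_ptr()
assert recopied.data_ptr() != stacked.data_ptr()

# A kernel that accepts strided (non-contiguous) tensors can consume a
# view such as stacked[:, 0] directly, so none of the copies above are
# required -- which is the data movement this PR eliminates.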