ANE-friendly static llama #8436
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/8436
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit 92a1be8 with merge base 52a3a9a.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
if not self.generate_full_logits:
    # Only the last logit is used for the new generated token
    h = h[:, input_length - 1, :]
Add .squeeze(1) to make h 2D?
Generally LGTM! The major changes I found are:
- The dedicated KV cache
- The split linear
Anything I missed?
)
self.v_caches[:, :, :, (self.cache_pos) : (self.cache_pos + length), :] = (
    new_v_caches[:, :, :, start : (start + length), :]
)
This still looks like an index put to me? Does it successfully run on ANE?
InputManager and its methods (_update_cache, etc.) are not part of the model. We will have a C++ implementation of it that runs on CPU. This Python implementation is intended to serve as a reference for the C++ one.
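For anyone skimming the thread, here is a rough, self-contained sketch of what such a reference cache update can look like. The tensor layout, class name, and method signature are assumptions for illustration, not the PR's actual InputManager.

```python
import torch

class InputManagerSketch:
    """Illustrative stand-in for the PR's InputManager (layout is assumed)."""

    def __init__(self, n_layers: int, n_heads: int, cache_size: int, head_dim: int):
        self.cache_pos = 0
        # Assumed layout: (1, n_layers, n_heads, cache positions, head_dim)
        self.k_caches = torch.zeros(1, n_layers, n_heads, cache_size, head_dim)
        self.v_caches = torch.zeros(1, n_layers, n_heads, cache_size, head_dim)

    def _update_cache(self, start: int, length: int, new_k_caches, new_v_caches):
        # A plain slice assignment (i.e., an index_put) is fine here: this
        # bookkeeping runs on CPU outside the exported model, not on ANE.
        self.k_caches[:, :, :, self.cache_pos : self.cache_pos + length, :] = (
            new_k_caches[:, :, :, start : start + length, :]
        )
        self.v_caches[:, :, :, self.cache_pos : self.cache_pos + length, :] = (
            new_v_caches[:, :, :, start : start + length, :]
        )
        self.cache_pos += length
```

The point being that only the exported model has to avoid ops that lower poorly to ANE; the cache bookkeeping does not.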
        torch.nn.Linear(in_features, self.common_size)
        for _ in range(self.num_splits)
    ]
)
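The excerpt above is part of the split-linear construction. For context, here is a minimal, self-contained sketch of the general idea; the class name SplitLinear and the max_split parameter are illustrative, not the PR's actual API.

```python
import torch

class SplitLinear(torch.nn.Module):
    """Replace one wide Linear with several narrower ones whose outputs are
    concatenated, keeping each matmul at an ANE-friendly width."""

    def __init__(self, in_features: int, out_features: int, max_split: int = 1024):
        super().__init__()
        self.num_splits = max(1, out_features // max_split)
        assert out_features % self.num_splits == 0, "out_features must split evenly"
        self.common_size = out_features // self.num_splits
        self.splits = torch.nn.ModuleList(
            [
                torch.nn.Linear(in_features, self.common_size)
                for _ in range(self.num_splits)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cat([linear(x) for linear in self.splits], dim=-1)
```

With max_split=1024, SplitLinear(4096, 4096) computes the same kind of mapping as nn.Linear(4096, 4096) but runs as four 1024-wide matmuls.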
Awesome! So split linear is found to be more performant on ANE? And 1024 is empirically found to be the best split size?
PS: On our end we found that split softmax is more performant as well: apple/coremltools#2418
We haven't done extensive testing on different hardware to find the right splitting values yet. I just recently noticed that 1024 works better on my M1 Pro.
On the SDPA pass (apple/coremltools#2418):
We observed something similar. We can get better Llama performance by processing tokens in smaller seq_length chunks (e.g., 256); this chunks not only the SDPA, but all ops. This is easy enough to do, but it only chunks the Q seq_length (target_seq_length) in SDPA. It doesn't chunk the source_seq_length, which is realistically the larger value because it comes from the K/V caches (e.g., max_context_length). I suspect chunking will help there too, but unlike chunking the target_seq_length, chunking the source_seq_length will require decomposing the SDPA op. Do you have plans to add support for this?
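For concreteness, decomposing SDPA so the source (K/V) dimension can be chunked amounts to an online-softmax accumulation over K/V chunks. A rough, unmasked sketch of that decomposition (just illustrating the math, not the coremltools pass):

```python
import math
import torch

def sdpa_source_chunked(q, k, v, chunk: int = 256):
    # q: (B, H, target_seq_length, D); k, v: (B, H, source_seq_length, D).
    # No attention mask, for brevity.
    scale = 1.0 / math.sqrt(q.shape[-1])
    B, H, Tq, D = q.shape
    m = q.new_full((B, H, Tq, 1), float("-inf"))  # running max of scores
    denom = q.new_zeros((B, H, Tq, 1))            # running softmax denominator
    acc = q.new_zeros((B, H, Tq, D))              # running weighted sum of V
    for s in range(0, k.shape[2], chunk):
        k_c = k[:, :, s : s + chunk, :]
        v_c = v[:, :, s : s + chunk, :]
        scores = torch.matmul(q, k_c.transpose(-1, -2)) * scale
        m_new = torch.maximum(m, scores.amax(dim=-1, keepdim=True))
        correction = torch.exp(m - m_new)         # rescale previously accumulated stats
        p = torch.exp(scores - m_new)
        denom = denom * correction + p.sum(dim=-1, keepdim=True)
        acc = acc * correction + torch.matmul(p, v_c)
        m = m_new
    return acc / denom
```

For unmasked inputs this matches torch.nn.functional.scaled_dot_product_attention(q, k, v) up to floating-point error.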
    target_function_name="model2",
)
desc.default_function_name = "model1"
ct.utils.save_multifunction(desc, f"{output_dir}/combined.mlpackage")
Is Core ML multifunction already runnable via ExecuTorch?
No, I need to remove this script from the PR. I'll do an update.
Those are the main changes. The structure of the KV caches and the attention mask is different from how they're usually constructed. For example, the current tokens are always on the right-most side of the attention mask, whereas usually they are somewhere in the middle. This means we can get the K value for attention with one static concat (k = concat(k_cache, k_curr)). The model also distinguishes between max_seq_length and cache_size, and cache_size can be less than max_seq_length. The effect of this is that older tokens are evicted from the cache and do not participate in attention.
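A minimal sketch of that layout, with shapes assumed for illustration (they may not match the PR exactly):

```python
import torch

# Assumed shapes for illustration only; the PR's tensors may be laid out differently.
B, H, D = 1, 8, 64
cache_size, seq_length = 512, 32

k_cache = torch.zeros(B, H, cache_size, D)  # keys for older (cached) tokens
k_curr = torch.randn(B, H, seq_length, D)   # keys for the tokens being processed now

# Because the current tokens always occupy the right-most slots of the attention
# window, K is produced with one static concat instead of scattering k_curr into
# a max_seq_length buffer with index_put:
k = torch.cat([k_cache, k_curr], dim=2)     # (B, H, cache_size + seq_length, D)

# The attention mask has shape (seq_length, cache_size + seq_length): the last
# seq_length columns are causal over the current tokens, and the first cache_size
# columns expose only the cached positions that are actually populated.
```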
LGTM! Let's build on top of it. Also, there are lots of errors in CI.
I'll take a look at the CI failures.
This directory contains ANE-friendly Llama models.
Export model with:
The runner is written in Python and is only intended to serve as an example of how the model inputs should be processed; it is not performant.
Run model with:
The model here is based on a "sliding" cache, where old tokens are evicted from the cache. By default, the cache size is max_seq_length - seq_length, but you can explicitly pass in a smaller cache size (e.g., --cache_size 512). This can speed up computation and reduce memory. Keep in mind that once cache_size is reached, older tokens get evicted from the cache and do not participate in attention.
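Conceptually, the eviction works like the sketch below (illustrative only, not the runner's actual code): once the cache is full, appending new tokens pushes the oldest ones off the left edge.

```python
import torch

def slide_cache(cache: torch.Tensor, new_entries: torch.Tensor, cache_size: int) -> torch.Tensor:
    # cache: (..., t_cached, head_dim); new_entries: (..., t_new, head_dim)
    combined = torch.cat([cache, new_entries], dim=-2)
    # Keep only the newest cache_size positions; evicted tokens no longer
    # participate in attention.
    return combined[..., -cache_size:, :]
```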
cc @kimishpatel @YifanShenSZ @cymbalrush