Support logit_bias in v1 Sampler #13079
Can we add a comment like
TODO: This is extremely slow. Optimize this.
?
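For reference, the pattern being flagged as slow is essentially a per-request, per-token Python loop over the bias dicts. A minimal sketch (the function name and the list-of-dicts layout are assumptions for illustration, not the PR's actual code):

```python
import torch

def apply_logit_bias_slow(logits: torch.Tensor,
                          logit_bias_per_req: list[dict[int, float] | None]) -> None:
    # One indexed write per (request, token) bias entry -- this is the
    # part that is "extremely slow" for large batches or large bias dicts.
    for req_idx, bias_dict in enumerate(logit_bias_per_req):
        if not bias_dict:
            continue
        for token_id, bias in bias_dict.items():
            logits[req_idx, token_id] += bias
```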
@njhill @robertgshaw2-redhat Although this implementation is a bit slow, I'm comfortable merging the PR since I haven't found a way to optimize it yet, and getting the functionality in is our top priority. What do you think?
Yeah, I agree. We can write a customized kernel in C++ to handle this. We may also change the representation of logit_bias from a dict to key-value pairs.
I can create some TODOs as a follow-up.
@houseroad Sounds good. Could you please update the PR?
@WoosukKwon I think that's fine since we need the functionality asap. But I think it should be simple to vectorize this without any custom kernel.
We just need to maintain three one-dim tensors of the same length in the batch:
- `s`: the row (request) index of each biased token
- `t`: the token id
- `b`: the bias value

We only need to update these when any requests with logit bias are added to or removed from the batch.
Then we can just do
`logits[(s, t)] += b`
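A rough sketch of that approach, assuming `logits` has shape `(num_reqs, vocab_size)`; the function name and the rebuild-on-every-call simplification are illustrative, not code from this PR:

```python
import torch

def apply_logit_bias(logits: torch.Tensor,
                     logit_bias_per_req: list[dict[int, float] | None]) -> None:
    # Flatten all bias entries into the three 1-D tensors described above.
    # In the persistent batch these would be cached and rebuilt only when
    # requests with logit_bias are added or removed; this sketch rebuilds
    # them on every call for clarity.
    rows, tokens, biases = [], [], []
    for req_idx, bias_dict in enumerate(logit_bias_per_req):
        if not bias_dict:
            continue
        for token_id, bias in bias_dict.items():
            rows.append(req_idx)
            tokens.append(token_id)
            biases.append(bias)
    if not rows:
        return
    s = torch.tensor(rows, dtype=torch.long, device=logits.device)
    t = torch.tensor(tokens, dtype=torch.long, device=logits.device)
    b = torch.tensor(biases, dtype=logits.dtype, device=logits.device)
    # Single vectorized scatter instead of a Python loop. Each (s, t)
    # pair is unique per request, so += is safe here.
    logits[s, t] += b
```

With `s`/`t`/`b` cached on the batch and only rebuilt on add/remove events, the per-step cost is a single scatter.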
Updated the PR. Please let me know if you want me to jump directly to the optimized solution.
@njhill @houseroad Considering that the code is pretty isolated, I think we can merge the PR first and have a follow-up PR to optimize it.
@njhill Re your idea: each request's `logits_bias` has a different length. How do we handle that (with the persistent batch)?
@WoosukKwon I could be missing something... let's merge this and then I can open another PR :)
I think we can have a ragged-format representation: 3 tensors, tensor a for the length/offset of each batch entry, tensor b for the token ids, and tensor c for the biases. Then one torch op takes these 3 tensors as inputs, and we leverage C++ logic to handle it. It should be much faster than the current logic.
Maybe in the SamplingMeta or param, we should just preprocess things like this.
The additional overhead is that once we update the batch, we may need to generate new tensors, which should be acceptable.
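A minimal sketch of that ragged (CSR-style) layout, with a pure-PyTorch fallback standing in for the proposed C++ op; the function name and the `repeat_interleave` expansion are assumptions for illustration:

```python
import torch

def apply_logit_bias_ragged(logits: torch.Tensor,
                            offsets: torch.Tensor,    # tensor a: (num_reqs + 1,) int64
                            token_ids: torch.Tensor,  # tensor b: (total_entries,) int64
                            biases: torch.Tensor) -> None:  # tensor c: (total_entries,)
    # CSR-style ragged layout: request i owns the slice
    # token_ids[offsets[i]:offsets[i+1]]. A custom C++/CUDA op could
    # consume these three tensors directly; this fallback expands the
    # offsets into per-entry row indices instead.
    lengths = offsets[1:] - offsets[:-1]
    rows = torch.repeat_interleave(
        torch.arange(lengths.numel(), device=logits.device), lengths)
    logits[rows, token_ids] += biases
```

Rebuilding the three tensors whenever the batch membership changes is the extra overhead mentioned above; the per-step apply stays a single vectorized scatter.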