Add handler for new lmi-dist #1595
Conversation
    :return: The same parameters dict, but with VLLM style parameter names.
    """
    parameters.pop('seed', None)
seed is supported by vllm now: vllm-project/vllm#2514
Okay, will remove. In this PR I kept what the vllm rolling batch currently does, and plan to tune params in the next PR.
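For reference, once seed is kept the translation could look like the following minimal sketch; the helper name and the max_new_tokens → max_tokens rename are illustrative assumptions, not code from this PR.

```python
# Hypothetical sketch, not the PR's code: stop popping 'seed' now that vLLM
# supports it (vllm-project/vllm#2514), and only rename parameters that differ.
def translate_vllm_params(parameters: dict) -> dict:
    """Translate LMI-style parameter names to vLLM-style names in place."""
    parameters.pop('do_sample', None)  # still dropped for now; see thread below
    if 'max_new_tokens' in parameters:  # assumed rename
        parameters['max_tokens'] = parameters.pop('max_new_tokens')
    return parameters
```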
    :return: The same parameters dict, but with VLLM style parameter names.
    """
    parameters.pop('seed', None)
    parameters.pop('do_sample', None)
Shouldn't do_sample=False map to temperature=0, basically? vllm does support greedy; it just uses temperature to accomplish that.
As I said above, I used the vllm config here. But it's a good point. I believe this is removed for vllm because it doesn't support the do_sample parameter, whereas lmi-dist should (for backwards compatibility), and we should set default sampling params.
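A minimal sketch of that mapping, assuming we want lmi-dist to keep v1's greedy default; the helper name is hypothetical.

```python
# Hypothetical sketch: translate do_sample into vLLM's temperature-based
# greedy mode instead of silently dropping it.
def apply_do_sample(parameters: dict) -> dict:
    do_sample = parameters.pop('do_sample', None)
    if do_sample is False:
        # vLLM performs greedy search when temperature == 0.
        parameters['temperature'] = 0.0
    elif do_sample is None:
        # Assumed default: lmi-dist v1 behaved like do_sample=False, so keep
        # greedy as the default for backwards compatibility.
        parameters.setdefault('temperature', 0.0)
    return parameters
```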
Sure, we can punt on this, but I think this is another point to bring up about how consistent an interface we want to provide across engines/backends, versus how closely the interface should change to match each engine/backend.
(Resolved review threads on engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py, now outdated)
(Resolved review threads on engines/python/setup/djl_python/rolling_batch/lmi_dist_v2_rolling_batch.py)
    self.request_cache.pop(key)
    return self.postprocess_results()

    return random_uuid()
Is there a reason not to just be consistent and use req.id here as well?
I was not sure if there's a specific reason why req.id is not used; I asked internally but did not get a response, so I kept it as is.
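If it turns out there is no reason, the consistent version is just the following sketch (assuming req.id is unique per in-flight request):

```python
# Hypothetical sketch: reuse the incoming request's id instead of minting a
# fresh uuid, so both code paths key the request cache the same way.
def get_request_id(self, req) -> str:
    # Assumption: req.id is unique across concurrent requests.
    return str(req.id)
```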
    def _record_speculative_decoding_metrics(self, request_output, req_id):
        completion_output = request_output.outputs[0]
        if self.engine_config.record_acceptance_rate and request_output.finished and completion_output.acceptance_history:
Doesn't this need a hasattr guard somewhere to work with a vanilla copy of the library?
FYI, I did not write this, just moved existing code to a function for better readability. Do we anticipate multiple versions of the library in the container?
This is in the base class, is it not? I anticipate a vanilla copy of the library in the container. I've been very consistent in saying this. But, who knows? hasattr would make it work either way; the current code will fail if the container has a vanilla copy of vllm.
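For concreteness, a sketch of the guarded version; getattr expresses the same check as hasattr here, and the rest mirrors the existing code.

```python
import logging

# Hypothetical sketch of the guard: getattr keeps this working against a
# vanilla vLLM CompletionOutput, which has no acceptance_history attribute.
def _record_speculative_decoding_metrics(self, request_output, req_id):
    completion_output = request_output.outputs[0]
    acceptance_history = getattr(completion_output, 'acceptance_history', None)
    if (self.engine_config.record_acceptance_rate and request_output.finished
            and acceptance_history):
        record = {"req_id": req_id}  # remaining fields as in the existing code
        record["output_size"] = len(completion_output.token_ids)
        logging.info(f"Speculative Decoding {record}")
```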
Yeah, this is in the base class; it was part of the inference function in the vllm rolling batch. Agreed that this needs an update to support vanilla vllm, which needs updates in other places anyway, like imports. Keeping it as is, since it's not related to this specific change.
I still think some things should be different here:
- I think our code should remain compatible with a vanilla install of vllm without throwing errors.
- I think "lmi_dist" should route to the installed lmi_dist, and the handler should figure out whether that is v1 or v2 and act accordingly (see the sketch below).
However, I'm ok with punting on some of this to a separate PR.
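A sketch of that routing; the version probe and the v2 class name are assumptions, and only the module paths appear in this PR.

```python
# Hypothetical sketch: pick the handler for whichever lmi_dist is installed.
def get_lmi_dist_rolling_batch_class():
    import lmi_dist  # assumed importable in both v1 and v2 containers
    version = getattr(lmi_dist, '__version__', '1')  # assumed version attribute
    if version.startswith('2'):
        from djl_python.rolling_batch.lmi_dist_v2_rolling_batch import (
            LmiDistRollingBatch as LmiDistV2RollingBatch)  # assumed class name
        return LmiDistV2RollingBatch
    from djl_python.rolling_batch.lmi_dist_rolling_batch import LmiDistRollingBatch
    return LmiDistRollingBatch
```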
        raise AssertionError(
            f"Need python engine to start vLLM RollingBatcher")
    return engine

    if rolling_batch == RollingBatchEnum.lmidist_v2 and engine != "MPI":
Could we just have individual parameter checking for each rolling batch implementation?
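Something like a declarative table, for example; this sketch assumes the existing RollingBatchEnum and is not code from the PR.

```python
# Hypothetical sketch: each rolling batch option declares its supported
# engines once, instead of scattering ad hoc if-checks.
SUPPORTED_ENGINES = {
    RollingBatchEnum.vllm: {"Python"},     # vLLM needs the Python engine
    RollingBatchEnum.lmidist_v2: {"MPI"},  # lmi-dist v2 needs MPI
}

def validate_engine(rolling_batch, engine: str) -> None:
    supported = SUPPORTED_ENGINES.get(rolling_batch)
    if supported is not None and engine not in supported:
        raise AssertionError(
            f"{rolling_batch} requires one of {supported}, got engine={engine}")
```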
    @@ -96,6 +96,9 @@ def get_rolling_batch_class_from_str(rolling_batch_type: str, is_mpi: bool,
        elif rolling_batch_type == "lmi-dist":
            from djl_python.rolling_batch.lmi_dist_rolling_batch import LmiDistRollingBatch
            return LmiDistRollingBatch
        elif rolling_batch_type == "lmi-dist-v2":
why not just replace lmi-dist?
This will be done soon, once the wheel is built and added to the container.
(Resolved review thread on engines/python/setup/djl_python/rolling_batch/lmi_dist_v2_rolling_batch.py)
Here are the things I suggested working on:
- Clean up the properties and keep vLLM and LMI-Dist V2 separate, or provide a way to allow different default configurations in the vLLM/Rubikon Engine settings.
- We can try to keep the common parts of the different rolling batchers together, but this needs some further cleanup. Maybe instead of making a base class, just create a utils module to store the commonly shared functions (see the sketch below); the step and other logic may change as vLLM versions grow, so loose function sharing is safer.
- Replace the default LMI-Dist V1 class with the V2 content.
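A sketch of the utils-module alternative; the module and function names are illustrative assumptions.

```python
# djl_python/rolling_batch/vllm_utils.py (hypothetical module)
# Shared helpers live in a plain module instead of a base class, so each
# handler opts in per function and engine-specific drift stays local.

def update_request_cache(request_cache: dict, request_output) -> None:
    """Helper shared by the vLLM and LMI-Dist V2 rolling batch handlers."""
    entry = request_cache[request_output.request_id]
    entry["text"] = request_output.outputs[0].text
    entry["finished"] = request_output.finished
```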
record["output_size"] = len(completion_output.token_ids) | ||
logging.info(f"Speculative Decoding {record}") | ||
|
||
def _is_t5_with_lmi_dist(self): |
Why does this show up in a base function? Could we make it LMI-Dist V2 only?
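For instance (a sketch; the base class and config attribute names are assumptions, not this PR's code):

```python
# Hypothetical sketch: keep the T5 check on the LMI-Dist V2 subclass so the
# shared base class stays engine-agnostic. Names below are assumptions.
class VllmRollingBatchBase:  # stand-in for the shared base class in this PR
    ...

class LmiDistRollingBatch(VllmRollingBatchBase):

    def _is_t5_with_lmi_dist(self) -> bool:
        # Assumption: the loaded HF config exposes model_type.
        return self.model_config.model_type == "t5"
```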
I will create a follow-up PR to address these concerns.
Description
Added lmi_dist_v2_rolling_batch.py, which implements rolling batch for the Rubikon engine. The implementation is similar to the vllm rolling batch, so I created a vllm rolling batch base class that will be shared by the vllm and Rubikon engines. As a result, I had to refactor the existing vllm rolling batch.
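As a sketch, the resulting hierarchy looks roughly like this; the class names are illustrative, and only the module names appear in this PR.

```python
from djl_python.rolling_batch.rolling_batch import RollingBatch  # assumed existing base

class VllmRollingBatchBase(RollingBatch):
    """Shared logic for engines that speak the vLLM request/response API."""

class VLLMRollingBatch(VllmRollingBatchBase):
    """Backed by a vanilla vLLM engine (vllm_rolling_batch.py)."""

class LmiDistRollingBatch(VllmRollingBatchBase):
    """Backed by the Rubikon engine (lmi_dist_v2_rolling_batch.py)."""
```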