
Add handler for new lmi-dist #1595

Merged on Mar 11, 2024 (9 commits)
Conversation

@rohithkrn (Contributor) commented on Mar 2, 2024

Description

Added lmi_dist_v2_rolling_batch.py, which implements rolling batch for the Rubikon engine. The implementation is similar to the vLLM rolling batch, so this PR creates a vLLM rolling batch base class that will be shared by the vLLM and Rubikon engines. As a result, the existing vLLM rolling batch had to be refactored.
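The refactor described above could be sketched roughly as follows. This is a hypothetical illustration, not the PR's actual code; the class and method names (`VllmRollingBatchBase`, `inference`, and the subclasses) are assumptions based on the description.

```python
# Hypothetical sketch of the shared base-class refactor described above.
# Class and method names are illustrative, not the PR's actual code.
class VllmRollingBatchBase:
    """Logic common to the vLLM and Rubikon (lmi-dist v2) rolling batches."""

    def __init__(self, engine):
        self.engine = engine
        self.request_cache = {}

    def inference(self, requests):
        # The shared step/translate/postprocess loop would live here.
        raise NotImplementedError


class VllmRollingBatch(VllmRollingBatchBase):
    """vLLM-specific engine construction and defaults."""


class LmiDistV2RollingBatch(VllmRollingBatchBase):
    """Rubikon-engine-specific construction and defaults."""
```

Both subclasses inherit the request cache and shared inference loop, so engine-specific code stays in the leaves.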

@rohithkrn rohithkrn requested review from zachgk, frankfliu and a team as code owners March 2, 2024 02:20
@rohithkrn rohithkrn changed the title [WIP]Add handler for new lmi-dist Add handler for new lmi-dist Mar 6, 2024

:return: The same parameters dict, but with VLLM style parameter names.
"""
parameters.pop('seed', None)
Contributor:
seed is supported by vllm now: vllm-project/vllm#2514

rohithkrn (author):

Okay, will remove. In this PR I kept what the vLLM rolling batch currently does, and planned to tune the params in the next PR.

:return: The same parameters dict, but with VLLM style parameter names.
"""
parameters.pop('seed', None)
parameters.pop('do_sample', None)
Contributor:

Shouldn't do_sample=False map to temperature=0, basically? vllm does support greedy, it just uses temperature to accomplish that.

rohithkrn (author):

As I said above, I used the vLLM config here. But it's a good point: I believe this was removed for vLLM because it doesn't support the do_sample parameter, whereas lmi-dist should keep it (for backwards compatibility), and we should set default sampling params.

Contributor:

Sure, we can punt on this, but I think it raises another question worth discussing: how consistent an interface we want to provide across engines/backends, versus how closely the interface should match each individual engine/backend.
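The mapping the reviewer suggests could be sketched as below. This is a hypothetical helper, not the handler's actual code; the function name `map_do_sample` is an assumption.

```python
def map_do_sample(parameters: dict) -> dict:
    """Map the HF-style do_sample flag onto vLLM-style sampling params.

    do_sample=False means greedy decoding, which vLLM expresses as
    temperature=0.0; do_sample=True simply drops the flag and keeps any
    user-supplied temperature. Hypothetical sketch, not the PR's code.
    """
    if parameters.pop("do_sample", True) is False:
        parameters["temperature"] = 0.0
    return parameters
```

This preserves backwards compatibility for callers passing do_sample while still speaking vLLM's temperature-based dialect internally.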

self.request_cache.pop(key)

return self.postprocess_results()
return random_uuid()
Contributor:

Is there a reason not to just be consistent and use req.id here as well?

rohithkrn (author):

I was not sure if there's a specific reason why req.id is not used; I asked internally but did not get a response, so I kept it as is.


def _record_speculative_decoding_metrics(self, request_output, req_id):
completion_output = request_output.outputs[0]
if self.engine_config.record_acceptance_rate and request_output.finished and completion_output.acceptance_history:
Contributor:

Doesn't this need a hasattr guard somewhere to work with vanilla copy of the library?

rohithkrn (author):

FYI, I did not write this; I just moved existing code into a function for better readability.

Do we anticipate multiple versions of the library in the container?

Contributor:

This is in the base class, is it not? I anticipate a vanilla copy of the library in the container. I've been very consistent in saying this. But, who knows?

'hasattr' would make it work either way. The current code will fail if the container has a vanilla copy of vllm.

rohithkrn (author):

Yeah, this is in the base class; it was part of the inference function in the vLLM rolling batch. Agreed that this needs an update to support vanilla vLLM, which requires changes elsewhere anyway (imports, etc.). Keeping it as is, since it's not related to this specific change.
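The guard the reviewer asks for could look like the sketch below. This is a hypothetical helper, not the PR's code; it assumes (per the discussion) that `acceptance_history` exists only on the patched lmi-dist fork's CompletionOutput and is absent on a vanilla vLLM build.

```python
def acceptance_rate(completion_output):
    """Return the mean speculative-decoding acceptance rate, or None when
    the attribute is absent (e.g. on a vanilla vLLM build, where
    CompletionOutput has no acceptance_history field).

    Hypothetical sketch of the hasattr/getattr guard suggested above.
    """
    history = getattr(completion_output, "acceptance_history", None)
    if not history:
        return None
    return sum(history) / len(history)
```

Using `getattr` with a default means the same code path runs unchanged against either build of the library.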

davidthomas426 (Contributor) left a comment:

I still think some things should be different here:

  • I think our code should remain compatible with a vanilla install of vllm without throwing errors.
  • I think "lmi_dist" should route to the installed lmi_dist, and the handler should figure out whether that is v1 or v2 and act accordingly.

However, I'm ok with punting on some of this to a separate PR.

@rohithkrn rohithkrn merged commit 82d2bae into deepjavalibrary:master Mar 11, 2024
8 checks passed
raise AssertionError(
f"Need python engine to start vLLM RollingBatcher")
return engine

if rolling_batch == RollingBatchEnum.lmidist_v2 and engine != "MPI":
Contributor:

Could we just have individual parameter checking for each rolling batch implementation?
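The per-implementation checking suggested here could be sketched as a lookup table. This is a hypothetical illustration, not the PR's code; the table contents and the `validate_engine` helper are assumptions.

```python
# Hypothetical sketch of per-implementation parameter checking: each
# rolling-batch type declares the engines it supports, replacing scattered
# ad-hoc checks. Names and the engine table are illustrative.
SUPPORTED_ENGINES = {
    "vllm": {"Python"},
    "lmi-dist": {"MPI"},
    "lmi-dist-v2": {"MPI"},
}


def validate_engine(rolling_batch: str, engine: str) -> None:
    """Raise if the chosen rolling batch does not support the engine."""
    allowed = SUPPORTED_ENGINES.get(rolling_batch)
    if allowed is not None and engine not in allowed:
        raise ValueError(
            f"rolling batch {rolling_batch!r} requires engine in "
            f"{sorted(allowed)}, got {engine!r}")
```

Centralizing the constraints in one table makes adding a new rolling-batch implementation a one-line change.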

@@ -96,6 +96,9 @@ def get_rolling_batch_class_from_str(rolling_batch_type: str, is_mpi: bool,
elif rolling_batch_type == "lmi-dist":
from djl_python.rolling_batch.lmi_dist_rolling_batch import LmiDistRollingBatch
return LmiDistRollingBatch
elif rolling_batch_type == "lmi-dist-v2":
Contributor:

Why not just replace lmi-dist?

rohithkrn (author):

This will be done soon, once the wheel is built and added to the container.

lanking520 (Contributor) left a comment:

Here are the items I suggested working on:

  • Clean up the properties and keep vLLM and LMI-Dist V2 separate, or provide a way to allow different default configurations in the vLLM/Rubikon engine settings.
  • We can try to keep the common parts of the different rolling batchers together, but this needs further cleanup. Instead of a base class, consider a utils module for commonly shared functions; the step and other logic may change as vLLM versions evolve, so sharing functions is safer than sharing a class hierarchy.
  • Replace the default LMI-Dist V1 class with the V2 content.
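The utils-module alternative in the second bullet could be sketched as below. The module name `rolling_batch_utils` and both helpers are hypothetical illustrations, not the project's actual code.

```python
# rolling_batch_utils.py -- hypothetical sketch of the utils-module
# approach suggested above: shared helpers are free functions, so each
# rolling batcher keeps its own step loop and only imports what it needs.
from uuid import uuid4


def random_uuid() -> str:
    """Generate a request id usable by any rolling batcher."""
    return str(uuid4())


def translate_max_tokens(parameters: dict) -> dict:
    """Rename the HF-style max_new_tokens key to the vLLM-style max_tokens."""
    if "max_new_tokens" in parameters:
        parameters["max_tokens"] = parameters.pop("max_new_tokens")
    return parameters
```

Because nothing here is stateful, a change to one engine's step loop can never break another engine through a shared superclass.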

record["output_size"] = len(completion_output.token_ids)
logging.info(f"Speculative Decoding {record}")

def _is_t5_with_lmi_dist(self):
Contributor:

Why does this show up in a base-class function? Could we make it LMI-Dist V2 only?

rohithkrn (author):

I will create a follow-up PR to address these concerns.
