Implement ChatModel (pyfunc subclass) #10820
Conversation
mlflow/types/llm.py
Outdated
usage: TokenUsageStats
object: str = "chat.completion"
created: int = field(default_factory=lambda: int(time.time()))
id: str = field(default_factory=lambda: str(uuid.uuid4()))
Is there any constraint for id (e.g. must start with "chatcmpl-") in OpenAI?
No, but I don't think we should make up random IDs here that don't have meaning. Can we leave this as None for now?
changed it to None! my initial thought was that if people want it to have meaning, they can specify the ID directly when instantiating ChatRequest, e.g. ChatRequest(id=meaningful_id, ...) still works, but for people who just want it to be a UUID, this saves them a couple of lines of code
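i.e. something along these lines (a simplified stand-in for the real dataclass, just to show the two behaviors discussed):

from dataclasses import dataclass
from typing import Optional


@dataclass
class ChatRequestSketch:
    # post-change default: no auto-generated id
    id: Optional[str] = None


# callers who want the id to carry meaning can still pass one explicitly
req = ChatRequestSketch(id="chatcmpl-my-meaningful-id")

# callers who don't care simply omit it
req_without_id = ChatRequestSketch()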
# is not supported, so the code here is a little ugly.


class _BaseDataclass:
all this validation logic is mainly to support the output validation done here.
input validation shouldn't really be an issue, because it's handled by signature validation.
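for context, the validation helper is roughly of this shape (the exact name and error wording below are a sketch, based on calls shown elsewhere in this diff such as _validate_field("model", str, True)):

class _BaseDataclass:
    def _validate_field(self, name, expected_type, required):
        # confirm a field is present (if required) and has the expected type
        value = getattr(self, name, None)
        if value is None:
            if required:
                raise ValueError(f"`{name}` is a required field")
            return
        if not isinstance(value, expected_type):
            raise ValueError(
                f"`{name}` must be of type {expected_type.__name__}, got {type(value).__name__}"
            )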
mlflow/types/llm.py
Outdated
:param role: The role of the entity that sent the message (e.g. ``"user"``, ``"system"``).
:type role: str
:param content: The content of the message.
:type content: str
:param name: The name of the entity that sent the message. **Optional**
:type name: str
i can unindent all of this stuff to be consistent with the rest of the codebase, but i like the way docstrings look when they're aligned haha
Can we use the google docstring style?
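For reference, the same parameters in Google docstring style might look roughly like this (a sketch of the conversion, not the committed wording):

class ChatMessage:
    """A message in a chat request.

    Args:
        role (str): The role of the entity that sent the message (e.g. ``"user"``, ``"system"``).
        content (str): The content of the message.
        name (str): The name of the entity that sent the message. **Optional**.
    """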
def _get_pyfunc_loader_module(python_model):
    if isinstance(python_model, ChatModel):
        return mlflow.pyfunc.loaders.chat_model.__name__
    return __name__
what do we think of adding new pyfunc loaders to the mlflow.pyfunc.loaders module? i think it would be a clean way for us to implement future custom loaders (e.g. for RAGModel, CompletionModel).
Sounds good to me.
@@ -0,0 +1 @@
import mlflow.pyfunc.loaders.chat_model  # noqa: F401
is this necessary? or does python load all files in the subdirectory into the module by default?
Python doesn't. If we want to do from mlflow.pyfunc.loaders import chat_model, we need this line, otherwise we don't.
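For context, the practical effect of the re-export is to make the submodule reachable as an attribute of the package after importing it (a general Python packaging sketch, not a claim about MLflow-specific behavior):

# mlflow/pyfunc/loaders/__init__.py
import mlflow.pyfunc.loaders.chat_model  # noqa: F401

# elsewhere (e.g. the loader-module lookup quoted above), attribute-style access
# relies on the submodule having been imported somewhere, which the line above guarantees:
import mlflow.pyfunc.loaders

loader_name = mlflow.pyfunc.loaders.chat_model.__name__
# -> "mlflow.pyfunc.loaders.chat_model"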
# output is not coercible to ChatResponse
messages = [ChatMessage(**m) for m in input_example["messages"]]
params = ChatParams(**{k: v for k, v in input_example.items() if k != "messages"})
output = python_model.predict(None, messages, params)
is it a problem to perform inference during saving? i saw we do it when trying to infer output signature, but since this is kind of an LLM-specific API, inference can be kind of expensive. the input example specifies max_tokens=10, so hopefully it isn't too bad.

if it is a concern, maybe we can just skip output validation entirely (as far as i can tell, there wouldn't be another way to ensure the return type of the predict() method is actually a ChatResponse).
I think there are some risks:
- It may take a while (e.g. a few seconds) for the API request to finish.
- No guarantee that the LLM service is healthy. If OpenAI is down, this line would throw.
+1 We shouldn't predict while saving the model, the error message would be confusing.
from discussion offline, we'll keep the predict since we do it in transformers/other places already for output signature inference. i'll do some more testing here to make sure it's not a confusing experience
from mlflow.utils.model_utils import _get_flavor_configuration


def _load_pyfunc(model_path: str, model_config: Optional[Dict[str, Any]] = None):
This looks the same as PythonModel's _load_pyfunc function (except for the wrapper it returns). Could we reuse the function and extract the final class as a parameter?
refactored the common part to _load_context_model_and_signature
def _convert_input(self, model_input):
    # model_input should be correct from signature validation, so just convert it to dict here
    dict_input = {key: value[0] for key, value in model_input.to_dict(orient="list").items()}
Does to_dict accept an orient param?
it seems so: https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_dict.html#pandas.DataFrame.to_dict
but i'm kind of new to pandas—is there something else i should use?
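for reference, a tiny illustration of what that conversion does on a single-row frame (plain pandas behavior, nothing MLflow-specific):

import pandas as pd

# a single-row DataFrame, as produced by signature enforcement for one request
model_input = pd.DataFrame(
    {"messages": [[{"role": "user", "content": "hi"}]], "max_tokens": [10]}
)

# orient="list" gives {column: [values...]}; taking [0] unwraps the single row
dict_input = {key: value[0] for key, value in model_input.to_dict(orient="list").items()}
# dict_input == {"messages": [{"role": "user", "content": "hi"}], "max_tokens": 10}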
mlflow/types/llm.py
Outdated
elif all(isinstance(v, cls) for v in values):
    pass
else:
Suggested change:
- elif all(isinstance(v, cls) for v in values):
-     pass
- else:
+ elif any(not isinstance(v, cls) for v in values):
mlflow/types/llm.py
Outdated
if not isinstance(self.message, ChatMessage):
    self.message = ChatMessage(**self.message)
This might throw an error if self.message is not a dictionary
self._validate_field("model", str, True) | ||
self._convert_dataclass_list("choices", ChatChoice) | ||
if not isinstance(self.usage, TokenUsageStats): | ||
self.usage = TokenUsageStats(**self.usage) |
same here
changed this to check for dict, and to throw ValueError afterwards if the field is not an instance of the expected type
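roughly this pattern (the helper name and error wording here are a sketch, not the exact code):

class _BaseDataclass:
    def _convert_dataclass(self, name, cls):
        # accept either an already-constructed instance or a plain dict;
        # anything else is rejected up front with a clear error
        value = getattr(self, name)
        if isinstance(value, dict):
            setattr(self, name, cls(**value))
        elif not isinstance(value, cls):
            raise ValueError(
                f"`{name}` must be either a dict or an instance of {cls.__name__}"
            )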
total_tokens: int

def __post_init__(self):
    self._validate_field("prompt_tokens", int, True)
Is defining this as a required set of fields going to preclude using this interface in transformers?
i don't think it will preclude that, because we can generate these stats automatically for the user using the transformer's tokenizer. however, i can make it not required if it's a concern! it was unclear from the spec which fields are required and not.
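for illustration, something roughly like this with a transformers tokenizer (the model name and the completion_tokens field here are illustrative assumptions, not the actual implementation):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer choice

prompt = "Hello, how are you?"
completion = "I'm doing well, thanks!"

prompt_tokens = len(tokenizer.encode(prompt))
completion_tokens = len(tokenizer.encode(completion))

usage = {
    "prompt_tokens": prompt_tokens,
    "completion_tokens": completion_tokens,
    "total_tokens": prompt_tokens + completion_tokens,
}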
@@ -1999,6 +2009,25 @@ def predict(model_input: List[str]) -> List[str]:
        python_model, input_arg_index, input_example=input_example
    ):
        mlflow_model.signature = signature
    elif isinstance(python_model, ChatModel):
Can we do any validation/warning if the customer specifies a custom signature with ChatModel? If it doesn't comply with our pydantic schema, we may want to reject here rather than at runtime.
oh yes that's true, i'll throw a warning to say that the signature will be overridden and that it must conform to the spec
woah actually this brought up a bug in my implementation—if the user specifies a signature, the model actually doesn't get saved as a ChatModel due to the elif in line 2005 above. i guess it's elif because this block contains a lot of validation/signature inference logic that we can skip if the user provides the signature themself. however, for ChatModel we always want to do these validations (e.g. output validation)
cc @B-Step62 what do you think about raising an exception when trying to save a ChatModel subclass with a signature, e.g:
if signature is not None:
    if isinstance(python_model, ChatModel):
        raise MlflowException("ChatModel subclasses specify a signature automatically, please remove the provided signature from the log_model() or save_model() call.")
    mlflow_model.signature = signature
elif python_model is not None:
    # no change from this PR
another way is making a separate block for ChatModels, e.g:
if isinstance(python_model, ChatModel):
    # move ChatModel logic to this block
    ...
elif signature is not None:
    # no change
    ...
elif python_model is not None:
    # no change
    ...
Nice finding! I agree with throwing. A warning on the happy path can be easily overlooked and is almost invisible in automated environments.
mlflow/pyfunc/loaders/chat_model.py
Outdated
if isinstance(response, ChatResponse):
    return response.to_dict()

# shouldn't happen since there is validation at save time ensuring that
Should we raise instead? I'm not sure ignoring unexpected behavior is beneficial.
mlflow/pyfunc/loaders/chat_model.py
Outdated
    return messages, params

def predict(self, model_input: ChatRequest, params: Optional[Dict[str, Any]] = None):
Suggested change:
- def predict(self, model_input: ChatRequest, params: Optional[Dict[str, Any]] = None):
+ def predict(self, model_input: Dict[str, Any], params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:

super-nit
ah yes that's true haha, it won't be a ChatRequest when coming in
Left one very tiny comment, but otherwise LGTM! Awesome idea, it's always better to have typed objects than to handle dicts everywhere :)
mlflow/pyfunc/loaders/chat_model.py
Outdated
messages, params = self._convert_input(model_input)
response = self.chat_model.predict(self.context, messages, params)

if isinstance(response, ChatResponse):
if not isinstance(response, ChatResponse):
    raise MlflowException(...)
return response.to_dict()

super-minor thing, but probably a more common way to structure the block
assert isinstance(response.choices[0].message, ChatMessage)


def to_dict_converts_nested_dataclasses():
Suggested change:
- def to_dict_converts_nested_dataclasses():
+ def test_to_dict_converts_nested_dataclasses():
assert not isinstance(response["choices"][0]["message"], ChatMessage) | ||
|
||
|
||
def to_dict_excludes_nones(): |
Suggested change:
- def to_dict_excludes_nones():
+ def test_to_dict_excludes_nones():
def to_dict_converts_nested_dataclasses():
    response = ChatResponse(**MOCK_RESPONSE).to_dict()
    assert not isinstance(response["choices"][0], ChatChoice)
what's the expected class? dict?
yup it should be dict, i guess i should just assert that haha
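i.e. roughly (a sketch of the adjusted assertions, not the committed test):

def test_to_dict_converts_nested_dataclasses():
    response = ChatResponse(**MOCK_RESPONSE).to_dict()
    # assert the expected type directly instead of asserting what it is not
    assert isinstance(response["choices"][0], dict)
    assert isinstance(response["choices"][0]["message"], dict)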
Related Issues/PRs
What changes are proposed in this pull request?
This PR adds the ChatModel subclass to make it more seamless for users to implement and serve chat models. The ChatModel class requires users to fill out a predict method of the following type (corresponding to the OpenAI chat request format). This makes it so that the user doesn't have to implement any parsing logic, and can directly work with the pydantic objects that are passed in. Additionally, input/output signatures and an input example are automatically provided.
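For illustration, a minimal sketch of what such a subclass might look like (the type hints, import paths, and response field names below follow the snippets quoted in the review above plus the OpenAI format, so treat this as an approximation rather than the exact API):

from typing import List

from mlflow.pyfunc import ChatModel
from mlflow.types.llm import ChatMessage, ChatParams, ChatResponse


class EchoChatModel(ChatModel):
    def predict(self, context, messages: List[ChatMessage], params: ChatParams) -> ChatResponse:
        # messages and params arrive already parsed into typed objects,
        # so no request-parsing logic is needed here
        last = messages[-1].content
        return ChatResponse(
            model="echo-model",
            choices=[
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": f"echo: {last}"},
                    "finish_reason": "stop",
                }
            ],
            usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        )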
To support this, we implement a new custom loader for these types of models, defined in mlflow.pyfunc.loaders.chat_model. This loader wraps the ChatModel in a _ChatModelPyfuncWrapper class that accepts the standard chat request format, and breaks it up into messages and params for the user.

How is this PR tested?
Ran the following to create a chat model:
Then on the command line:
Also tried viewing the model in MLflow UI:
Validate that the MLmodel file looks as expected

Validate that the signature looks correct:
Screen.Recording.2024-01-15.at.12.57.57.PM.mov
Does this PR require documentation update?
Requires a tutorial, but we can work on this in a follow-up PR
Release Notes
Is this a user-facing change?
Added the ChatModel pyfunc class, which allows for more convenient definition of chat models conforming to the OpenAI request/response format.

What component(s), interfaces, languages, and integrations does this PR affect?
Components
area/artifacts: Artifact stores and artifact logging
area/build: Build and test infrastructure for MLflow
area/deployments: MLflow Deployments client APIs, server, and third-party Deployments integrations
area/docs: MLflow documentation pages
area/examples: Example code
area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
area/models: MLmodel format, model serialization/deserialization, flavors
area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
area/projects: MLproject format, project running backends
area/scoring: MLflow Model server, model deployment tools, Spark UDFs
area/server-infra: MLflow Tracking server backend
area/tracking: Tracking Service, tracking client APIs, autologging

Interface
area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
area/windows: Windows support

Language
language/r: R APIs and clients
language/java: Java APIs and clients
language/new: Proposals for new client languages

Integrations
integrations/azure: Azure and Azure ML integrations
integrations/sagemaker: SageMaker integrations
integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:
rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
rn/feature - A new user-facing feature worth mentioning in the release notes
rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
rn/documentation - A user-facing documentation change worth mentioning in the release notes