forked from mlflow/mlflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement ChatModel (pyfunc subclass) (mlflow#10820)
Signed-off-by: Daniel Lok <[email protected]> Signed-off-by: lu-wang-dl <[email protected]>
- Loading branch information
1 parent
ffa3064
commit ed038cd
Showing
10 changed files
with
767 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
import mlflow.pyfunc.loaders.chat_model # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from typing import Any, Dict, Optional | ||
|
||
from mlflow.exceptions import MlflowException | ||
from mlflow.protos.databricks_pb2 import INTERNAL_ERROR | ||
from mlflow.pyfunc.model import ( | ||
_load_context_model_and_signature, | ||
) | ||
from mlflow.types.llm import ChatMessage, ChatParams, ChatResponse | ||
from mlflow.utils.annotations import experimental | ||
|
||
|
||
def _load_pyfunc(model_path: str, model_config: Optional[Dict[str, Any]] = None):
    """
    Load a saved ChatModel from ``model_path`` and wrap it for pyfunc inference.

    Args:
        model_path: Local filesystem path to the saved model directory.
        model_config: Optional model configuration passed through to the loader.

    Returns:
        A :class:`~_ChatModelPyfuncWrapper` exposing a dict-based ``predict`` API.
    """
    loaded_context, loaded_model, loaded_signature = _load_context_model_and_signature(
        model_path, model_config
    )
    return _ChatModelPyfuncWrapper(
        chat_model=loaded_model,
        context=loaded_context,
        signature=loaded_signature,
    )
|
||
|
||
@experimental
class _ChatModelPyfuncWrapper:
    """
    Wrapper class that converts dict inputs to pydantic objects accepted by :class:`~ChatModel`.
    """

    def __init__(self, chat_model, context, signature):
        """
        Args:
            chat_model: An instance of a subclass of :class:`~ChatModel`.
            context: A :class:`~PythonModelContext` instance containing artifacts that
                ``chat_model`` may use when performing inference.
            signature: :class:`~ModelSignature` instance describing model input and output.
        """
        self.chat_model = chat_model
        self.context = context
        self.signature = signature

    def _convert_input(self, model_input):
        # Signature validation upstream guarantees the expected structure, so this
        # only needs to unwrap the single-row columnar frame into a plain dict.
        columns = model_input.to_dict(orient="list")
        flattened = {}
        for key, column in columns.items():
            flattened[key] = column[0]

        raw_messages = flattened.pop("messages", [])
        chat_messages = [ChatMessage(**m) for m in raw_messages]

        # Everything left over after removing "messages" is interpreted as params.
        return chat_messages, ChatParams(**flattened)

    def predict(
        self, model_input: Dict[str, Any], params: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Args:
            model_input: Model input data in the form of a chat request.
            params: Additional parameters to pass to the model for inference.
                    Unused in this implementation, as the params are handled
                    via ``self._convert_input()``.
        Returns:
            Model predictions in :py:class:`~ChatResponse` format.
        """
        chat_messages, chat_params = self._convert_input(model_input)
        response = self.chat_model.predict(self.context, chat_messages, chat_params)

        if isinstance(response, ChatResponse):
            return response.to_dict()

        # shouldn't happen since there is validation at save time ensuring that
        # the output is a ChatResponse, so raise an exception if it isn't
        raise MlflowException(
            "Model returned an invalid response. Expected a ChatResponse, but "
            f"got {type(response)} instead.",
            error_code=INTERNAL_ERROR,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.