
Commit cb17a59

update

Signed-off-by: Daniel Lok <[email protected]>
daniellok-db committed Jan 25, 2024
1 parent 3be00df
Showing 5 changed files with 57 additions and 9 deletions.
mlflow/pyfunc/__init__.py (16 additions, 0 deletions)

@@ -269,6 +269,9 @@
     CHAT_MODEL_INPUT_EXAMPLE,
     CHAT_MODEL_INPUT_SCHEMA,
     CHAT_MODEL_OUTPUT_SCHEMA,
+    ChatMessage,
+    ChatParams,
+    ChatResponse,
 )
 from mlflow.utils import (
     PYTHON_VERSION,

@@ -2012,6 +2015,19 @@ def predict(model_input: List[str]) -> List[str]:
             CHAT_MODEL_OUTPUT_SCHEMA,
         )
         input_example = CHAT_MODEL_INPUT_EXAMPLE
+
+        # perform output validation and throw if
+        # output is not coercible to ChatResponse
+        messages = [ChatMessage(**m) for m in input_example["messages"]]
+        params = ChatParams(**{k: v for k, v in input_example.items() if k != "messages"})
+        output = python_model.predict(None, messages, params)
+        if not isinstance(output, ChatResponse):
+            raise MlflowException(
+                "Failed to save ChatModel. Please ensure that the model's predict() method "
+                "returns a ChatResponse object. If your predict() method currently returns "
+                "a dict, you can instantiate a ChatResponse by unpacking the output like "
+                "this: `ChatResponse(**output)`",
+            )
     elif isinstance(python_model, PythonModel):
         input_arg_index = 1  # second argument
         if signature := _infer_signature_from_type_hints(
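To make the intent of this new save-time check concrete, here is a minimal sketch of a ChatModel whose predict() satisfies it. The class name EchoChatModel and the exact response payload are illustrative assumptions; the choice and usage field names are inferred from the test fixtures later in this diff and from mlflow.types.llm at this commit, not from a definitive API reference.

import mlflow
from typing import List
from mlflow.types.llm import ChatMessage, ChatParams, ChatResponse


class EchoChatModel(mlflow.pyfunc.ChatModel):
    def predict(self, context, messages: List[ChatMessage], params: ChatParams) -> ChatResponse:
        # Build a plain dict and coerce it into the ChatResponse dataclass,
        # mirroring the `ChatResponse(**output)` advice in the error message.
        output = {
            "id": "echo-1",  # illustrative values
            "created": 1677652288,
            "model": "EchoChatModel",
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": messages[-1].content},
                    "finish_reason": "stop",
                }
            ],
            "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
        }
        return ChatResponse(**output)


# save_model() now invokes predict() on CHAT_MODEL_INPUT_EXAMPLE and raises
# MlflowException unless the return value is a ChatResponse.
mlflow.pyfunc.save_model(python_model=EchoChatModel(), path="echo_chat_model")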
mlflow/pyfunc/loaders/chat_model.py (1 addition, 1 deletion)

@@ -12,7 +12,7 @@
     CONFIG_KEY_PYTHON_MODEL,
     PythonModelContext,
 )
-from mlflow.types import ChatMessage, ChatParams, ChatRequest, ChatResponse
+from mlflow.types.llm import ChatMessage, ChatParams, ChatRequest, ChatResponse
 from mlflow.utils.annotations import experimental
 from mlflow.utils.model_utils import _get_flavor_configuration
mlflow/pyfunc/model.py (2 additions, 4 deletions)

@@ -21,7 +21,7 @@
 from mlflow.models.model import MLMODEL_FILE_NAME
 from mlflow.models.signature import _extract_type_hints
 from mlflow.tracking.artifact_utils import _download_artifact_from_uri
-from mlflow.types import ChatMessage, ChatParams, ChatResponse
+from mlflow.types.llm import ChatMessage, ChatParams, ChatResponse
 from mlflow.utils.annotations import experimental
 from mlflow.utils.environment import (
     _CONDA_ENV_FILE_NAME,

@@ -194,7 +194,7 @@ def model_config(self):


 @experimental
-class ChatModel(PythonModel):
+class ChatModel(PythonModel, metaclass=ABCMeta):
     """
     A subclass of :class:`~PythonModel` that makes it more convenient to implement models
     that are compatible with popular LLM chat APIs.

@@ -203,8 +203,6 @@ class ChatModel(PythonModel):
     method that is more convenient for chat tasks than the generic :class:`~PythonModel` API.
     """

-    __metaclass__ = ABCMeta
-
     @abstractmethod
     def predict(self, context, messages: List[ChatMessage], params: ChatParams) -> ChatResponse:
         """
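The metaclass change here is a real bug fix, not a style tweak: setting __metaclass__ as a class attribute is Python 2 syntax that Python 3 silently ignores, so @abstractmethod on ChatModel.predict was never actually enforced. Declaring metaclass=ABCMeta in the class header restores enforcement. A self-contained sketch (hypothetical class names) showing the difference:

from abc import ABCMeta, abstractmethod


class Py2Style:
    __metaclass__ = ABCMeta  # Python 2 idiom; a no-op attribute in Python 3

    @abstractmethod
    def predict(self):
        ...


class Py3Style(metaclass=ABCMeta):
    @abstractmethod
    def predict(self):
        ...


Py2Style()  # instantiates fine: abstractness was never enforced
try:
    Py3Style()
except TypeError as e:
    print(e)  # can't instantiate abstract class Py3Style with abstract method predict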
mlflow/types/llm.py (1 addition, 1 deletion)

@@ -219,7 +219,7 @@ def __post_init__(self):
         {"role": "user", "content": "Hello!"},
     ],
     "temperature": 1.0,
-    "max_tokens": 20,
+    "max_tokens": 10,
     "stop": ["\n"],
     "n": 1,
     "stream": False,
tests/pyfunc/test_chat_model.py (37 additions, 3 deletions)

@@ -1,7 +1,10 @@
 import json
 from typing import List

+import pytest
+
 import mlflow
+from mlflow.exceptions import MlflowException
 from mlflow.pyfunc.loaders.chat_model import _ChatModelPyfuncWrapper
 from mlflow.types.llm import (
     CHAT_MODEL_INPUT_SCHEMA,

@@ -24,9 +27,6 @@
 class TestChatModel(mlflow.pyfunc.ChatModel):
     def predict(self, context, messages: List[ChatMessage], params: ChatParams) -> ChatResponse:
         mock_response = {
-            "id": "123",
-            "object": "chat.completion",
-            "created": 1677652288,
             "model": "MyChatModel",
             "choices": [
                 {

@@ -67,6 +67,40 @@ def test_chat_model_save_load(tmp_path):
     assert output_schema == CHAT_MODEL_OUTPUT_SCHEMA


+@pytest.mark.parametrize(
+    "ret",
+    [
+        "not a ChatResponse",
+        {"dict": "with", "bad": "keys"},
+        {
+            "id": "1",
+            "created": 1,
+            "model": "m",
+            "choices": [{"bad": "choice"}],
+            "usage": {
+                "prompt_tokens": 10,
+                "completion_tokens": 10,
+                "total_tokens": 20,
+            },
+        },
+    ],
+)
+def test_save_throws_on_invalid_output(tmp_path, ret):
+    class BadChatModel(mlflow.pyfunc.ChatModel):
+        def predict(self, context, messages, params) -> ChatResponse:
+            return ret
+
+    model = BadChatModel()
+    with pytest.raises(
+        MlflowException,
+        match=(
+            "Failed to save ChatModel. Please ensure that the model's "
+            r"predict\(\) method returns a ChatResponse object"
+        ),
+    ):
+        mlflow.pyfunc.save_model(python_model=model, path=tmp_path)
+
+
 # test that we can predict with the model
 def test_chat_model_predict(tmp_path):
     model = TestChatModel()
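For context on the unchanged test_chat_model_predict below the fold, a hedged sketch of the round trip it exercises: the saved model is reloaded as a generic pyfunc and fed a dict shaped like CHAT_MODEL_INPUT_EXAMPLE from mlflow/types/llm.py. The payload keys (messages plus optional params such as temperature and max_tokens) are an assumption based on that example, and the path reuses the one from the earlier sketch.

import mlflow

# load the model saved in the earlier EchoChatModel sketch
loaded = mlflow.pyfunc.load_model("echo_chat_model")
response = loaded.predict(
    {
        "messages": [{"role": "user", "content": "Hello!"}],
        "temperature": 1.0,
        "max_tokens": 10,
    }
)
print(response)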
