diff --git a/mlflow/llama_index/__init__.py b/mlflow/llama_index/__init__.py
index 7e8fd969be377..83ea960d31b81 100644
--- a/mlflow/llama_index/__init__.py
+++ b/mlflow/llama_index/__init__.py
@@ -8,7 +8,7 @@
 import mlflow
 from mlflow import pyfunc
 from mlflow.exceptions import MlflowException
-from mlflow.llama_index.pyfunc_wrapper import create_engine_wrapper
+from mlflow.llama_index.pyfunc_wrapper import create_pyfunc_wrapper
 from mlflow.models import Model, ModelInputExample, ModelSignature
 from mlflow.models.model import MLMODEL_FILE_NAME, MODEL_CODE_PATH
 from mlflow.models.signature import _infer_signature_from_input_example
@@ -100,7 +100,16 @@ def _supported_classes():
     from llama_index.core.indices.base import BaseIndex
     from llama_index.core.retrievers import BaseRetriever
 
-    return BaseIndex, BaseChatEngine, BaseQueryEngine, BaseRetriever
+    supported = (BaseIndex, BaseChatEngine, BaseQueryEngine, BaseRetriever)
+
+    try:
+        from llama_index.core.workflow import Workflow
+
+        supported += (Workflow,)
+    except ImportError:
+        pass
+
+    return supported
 
 
 @experimental
@@ -123,13 +132,25 @@ def save_model(
     """
    Save a LlamaIndex model to a path on the local file system.
 
+    .. attention::
+
+        Saving a non-index object is only supported in the 'Model-from-Code' saving mode.
+        Please refer to the `Models From Code Guide `_
+        for more information.
+
     Args:
-        llama_index_model: An LlamaIndex object to be saved, or a string representing the path to
-            a script contains LlamaIndex index/engine definition.
+        llama_index_model: A LlamaIndex object to be saved. Supported model types are:
+
+            1. An Index object.
+            2. An Engine object, e.g. ChatEngine, QueryEngine, Retriever.
+            3. A `Workflow `_ object.
+            4. A string representing the path to a script that contains a LlamaIndex model
+               definition of one of the above types.
+
         path: Local path where the serialized model (as YAML) is to be saved.
-        engine_type: Required when saving an index object to determine the inference interface
+        engine_type: Required when saving an Index object to determine the inference interface
             for the index when loaded as a pyfunc model. This field is **not** required when
-            saving an engine directly. The supported types are as follows:
+            saving other LlamaIndex objects. The supported values are as follows:
 
             - ``"chat"``: load the index as an instance of the LlamaIndex
              `ChatEngine `_.
@@ -177,7 +198,9 @@
 
     # Warn when user provides `engine_type` argument while saving an engine directly
     if not isinstance(llama_index_model, BaseIndex) and engine_type is not None:
-        _logger.warning("The `engine_type` argument is ignored when saving an engine.")
+        _logger.warning(
+            "The `engine_type` argument is ignored when saving a non-index object."
+        )
 
     elif isinstance(model_or_code_path, BaseIndex):
         _validate_engine_type(engine_type)
@@ -185,9 +208,9 @@
 
     elif isinstance(model_or_code_path, _supported_classes()):
         raise MlflowException.invalid_parameter_value(
-            "Saving an engine object is only supported in the 'Model-from-Code' saving mode. "
+            "Saving a non-index object is only supported in the 'Model-from-Code' saving mode. "
             "The legacy serialization method is exclusively for saving index objects. Please "
-            "pass the path to the script containing the engine definition to save an engine "
+            "pass the path to the script containing the model definition to save a non-index "
             "object. For more information, see "
             "https://www.mlflow.org/docs/latest/model/models-from-code.html",
         )
@@ -199,7 +222,7 @@
 
     saved_example = _save_example(mlflow_model, input_example, path)
     if signature is None and saved_example is not None:
-        wrapped_model = create_engine_wrapper(llama_index_model, engine_type, model_config)
+        wrapped_model = create_pyfunc_wrapper(llama_index_model, engine_type, model_config)
         signature = _infer_signature_from_input_example(saved_example, wrapped_model)
     elif signature is False:
         signature = None
@@ -292,13 +315,25 @@ def log_model(
     """
     Log a LlamaIndex model as an MLflow artifact for the current run.
 
+    .. attention::
+
+        Saving a non-index object is only supported in the 'Model-from-Code' saving mode.
+        Please refer to the `Models From Code Guide `_
+        for more information.
+
     Args:
-        llama_index_model: An LlamaIndex object to be saved, or a string representing the path to
-            a script contains LlamaIndex index/engine definition.
+        llama_index_model: A LlamaIndex object to be saved. Supported model types are:
+
+            1. An Index object.
+            2. An Engine object, e.g. ChatEngine, QueryEngine, Retriever.
+            3. A `Workflow `_ object.
+            4. A string representing the path to a script that contains a LlamaIndex model
+               definition of one of the above types.
+
         artifact_path: Local path where the serialized model (as YAML) is to be saved.
-        engine_type: Required when saving an index object to determine the inference interface
+        engine_type: Required when saving an Index object to determine the inference interface
             for the index when loaded as a pyfunc model. This field is **not** required when
-            saving an engine directly. The supported types are as follows:
+            saving other LlamaIndex objects. The supported values are as follows:
 
             - ``"chat"``: load the index as an instance of the LlamaIndex
              `ChatEngine `_.
@@ -372,7 +407,7 @@ def _save_index(index, path):
 
 
 def _load_llama_model(path, flavor_conf):
-    """Load the LlamaIndex index or engine from either model code or serialized index."""
+    """Load the LlamaIndex index/engine/workflow from either model code or serialized index."""
     from llama_index.core import StorageContext, load_index_from_storage
 
     _add_code_from_conf_to_system_path(path, flavor_conf)
@@ -398,7 +433,7 @@ def _load_llama_model(path, flavor_conf):
 @trace_disabled  # Suppress traces while loading model
 def load_model(model_uri, dst_path=None):
     """
-    Load a LlamaIndex index or engine from a local file or a run.
+    Load a LlamaIndex index/engine/workflow from a local file or a run.
 
     Args:
         model_uri: The location, in URI format, of the MLflow model. For example:
@@ -431,12 +466,14 @@ def load_model(model_uri, dst_path=None):
 
 
 def _load_pyfunc(path, model_config: Optional[Dict[str, Any]] = None):
-    from mlflow.llama_index.pyfunc_wrapper import create_engine_wrapper
+    from mlflow.llama_index.pyfunc_wrapper import create_pyfunc_wrapper
 
     index = load_model(path)
     flavor_conf = _get_flavor_configuration(model_path=path, flavor_name=FLAVOR_NAME)
-    engine_type = flavor_conf.pop("engine_type", None)  # Not present when saving an engine object
-    return create_engine_wrapper(index, engine_type, model_config)
+    engine_type = flavor_conf.pop(
+        "engine_type", None
+    )  # Not present when saving a non-index object
+    return create_pyfunc_wrapper(index, engine_type, model_config)
 
 
 @experimental
diff --git a/mlflow/llama_index/pyfunc_wrapper.py b/mlflow/llama_index/pyfunc_wrapper.py
index 9310d326bb5e4..38295d30b554a 100644
--- a/mlflow/llama_index/pyfunc_wrapper.py
+++ b/mlflow/llama_index/pyfunc_wrapper.py
@@ -1,3 +1,5 @@
+import asyncio
+import threading
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 if TYPE_CHECKING:
@@ -56,18 +58,18 @@ def _format_predict_input_query_engine_and_retriever(data) -> "QueryBundle":
 class _LlamaIndexModelWrapperBase:
     def __init__(
         self,
-        engine,
+        llama_model,  # Engine or Workflow
         model_config: Optional[Dict[str, Any]] = None,
     ):
-        self.engine = engine
+        self._llama_model = llama_model
         self.model_config = model_config or {}
 
     @property
     def index(self):
-        return self.engine.index
+        return self._llama_model.index
 
     def get_raw_model(self):
-        return self.engine
+        return self._llama_model
 
     def _predict_single(self, *args, **kwargs) -> Any:
         raise NotImplementedError
@@ -101,7 +103,7 @@ def engine_type(self):
         return CHAT_ENGINE_NAME
 
     def _predict_single(self, *args, **kwargs) -> str:
-        return self.engine.chat(*args, **kwargs).response
+        return self._llama_model.chat(*args, **kwargs).response
 
     @staticmethod
     def _convert_chat_message_history_to_chat_message_objects(data: Dict) -> Dict:
@@ -145,7 +147,7 @@ def engine_type(self):
         return QUERY_ENGINE_NAME
 
     def _predict_single(self, *args, **kwargs) -> str:
-        return self.engine.query(*args, **kwargs).response
+        return self._llama_model.query(*args, **kwargs).response
 
     def _format_predict_input(self, data) -> "QueryBundle":
         return _format_predict_input_query_engine_and_retriever(data)
@@ -157,27 +159,130 @@ def engine_type(self):
         return RETRIEVER_ENGINE_NAME
 
     def _predict_single(self, *args, **kwargs) -> List[Dict]:
-        response = self.engine.retrieve(*args, **kwargs)
+        response = self._llama_model.retrieve(*args, **kwargs)
         return [node.dict() for node in response]
 
     def _format_predict_input(self, data) -> "QueryBundle":
         return _format_predict_input_query_engine_and_retriever(data)
 
 
-def create_engine_wrapper(
-    index_or_engine: Any,
+class WorkflowWrapper(_LlamaIndexModelWrapperBase):
+    @property
+    def index(self):
+        raise NotImplementedError("LlamaIndex Workflow does not have an index")
+
+    @property
+    def engine_type(self):
+        raise NotImplementedError("LlamaIndex Workflow is not an engine")
+
+    def predict(self, data, params: Optional[Dict[str, Any]] = None) -> Union[List[str], str]:
+        inputs = self._format_predict_input(data, params)
+
+        # LlamaIndex Workflow runs async but MLflow pyfunc doesn't support async inference yet.
+        predictions = self._wait_async_task(self._run_predictions(inputs))
+
+        # Even if the input is a single instance, signature enforcement converts it to a
+        # Pandas DataFrame with a single row. In this case, we should unwrap the result (list)
+        # so it won't be inconsistent with the output without signature enforcement.
+        should_unwrap = len(data) == 1 and isinstance(predictions, list)
+        return predictions[0] if should_unwrap else predictions
+
+    def _format_predict_input(self, data, params: Optional[Dict[str, Any]] = None) -> List[Dict]:
+        inputs = _convert_llm_input_data_with_unwrapping(data)
+        params = params or {}
+        if isinstance(inputs, dict):
+            return [{**inputs, **params}]
+        return [{**x, **params} for x in inputs]
+
+    async def _run_predictions(self, inputs: List[Dict[str, Any]]) -> asyncio.Future:
+        tasks = [self._predict_single(x) for x in inputs]
+        return await asyncio.gather(*tasks)
+
+    async def _predict_single(self, x: Dict[str, Any]) -> Any:
+        if not isinstance(x, dict):
+            raise ValueError(f"Unsupported input type: {type(x)}. It must be a dictionary.")
+        return await self._llama_model.run(**x)
+
+    def _wait_async_task(self, task: asyncio.Future) -> Any:
+        """
+        A utility function to run async tasks in a blocking manner.
+
+        If there is no event loop running already, for example, in a model serving endpoint,
+        we can simply create a new event loop and run the task there. However, in a notebook
+        environment (or pytest with asyncio decoration), there is already an event loop running
+        at the root level and we cannot start a new one.
+        """
+        if not self._is_event_loop_running():
+            return asyncio.new_event_loop().run_until_complete(task)
+        else:
+            # NB: The popular way to run an async task where an event loop is already running
+            # is to use nest_asyncio. However, nest_asyncio.apply() breaks the async OpenAI
+            # client somehow, which is used for most of the LLM calls in LlamaIndex, including
+            # Databricks LLMs. Therefore, we use a hacky workaround that creates a new thread
+            # and runs the new event loop there. This may degrade performance compared to
+            # native asyncio, but it should be fine because this is only used in the notebook env.
+            results = None
+            exception = None
+
+            def _run():
+                nonlocal results, exception
+
+                try:
+                    loop = asyncio.new_event_loop()
+                    asyncio.set_event_loop(loop)
+                    results = loop.run_until_complete(task)
+                except Exception as e:
+                    exception = e
+                finally:
+                    loop.close()
+
+            thread = threading.Thread(target=_run)
+            thread.start()
+            thread.join()
+
+            if exception:
+                raise exception
+
+            return results
+
+    def _is_event_loop_running(self) -> bool:
+        try:
+            loop = asyncio.get_running_loop()
+            return loop is not None
+        except Exception:
+            return False
+
+
+def create_pyfunc_wrapper(
+    model: Any,
     engine_type: Optional[str] = None,
     model_config: Optional[Dict[str, Any]] = None,
 ):
     """
-    A factory function that creates a Pyfunc wrapper around a LlamaIndex index or engine.
+    A factory function that creates a Pyfunc wrapper around a LlamaIndex index/engine/workflow.
+
+    Args:
+        model: A LlamaIndex index/engine/workflow.
+        engine_type: The type of the engine. Only required if `model` is an index
+            and must be one of [chat, query, retriever].
+        model_config: A dictionary of model configuration parameters.
""" + try: + from llama_index.core.workflow import Workflow + + if isinstance(model, Workflow): + return _create_wrapper_from_workflow(model, model_config) + except ImportError: + pass + from llama_index.core.indices.base import BaseIndex - if isinstance(index_or_engine, BaseIndex): - return _create_wrapper_from_index(index_or_engine, engine_type, model_config) + if isinstance(model, BaseIndex): + return _create_wrapper_from_index(model, engine_type, model_config) else: - return _create_wrapper_from_engine(index_or_engine, model_config) + # Engine does not have a common base class so we assume + # everything else is an engine + return _create_wrapper_from_engine(model, model_config) def _create_wrapper_from_index( @@ -214,3 +319,7 @@ def _create_wrapper_from_engine(engine: Any, model_config: Optional[Dict[str, An raise ValueError( f"Unsupported engine type: {type(engine)}. It must be one of {SUPPORTED_ENGINES}" ) + + +def _create_wrapper_from_workflow(workflow: Any, model_config: Optional[Dict[str, Any]] = None): + return WorkflowWrapper(workflow, model_config) diff --git a/mlflow/ml-package-versions.yml b/mlflow/ml-package-versions.yml index 73b11091ac3a0..11e80cdc16b9b 100644 --- a/mlflow/ml-package-versions.yml +++ b/mlflow/ml-package-versions.yml @@ -773,6 +773,8 @@ llama_index: # Required to run tests/openai/mock_openai.py "fastapi", "uvicorn", + # Required for testing LlamaIndex workflow + "pytest-asyncio", ] run: pytest tests/llama_index --ignore tests/llama_index/test_llama_index_autolog.py --ignore tests/llama_index/test_llama_index_tracer.py autologging: diff --git a/mlflow/models/model.py b/mlflow/models/model.py index b209ebac445d5..ca97124331e18 100644 --- a/mlflow/models/model.py +++ b/mlflow/models/model.py @@ -1058,7 +1058,7 @@ def set_model(model): globals()["__mlflow_model__"] = model return - for validate_function in [_validate_langchain_model, _validate_llama_index_model]: + for validate_function in [_validate_llama_index_model]: try: globals()["__mlflow_model__"] = validate_function(model) return diff --git a/tests/llama_index/sample_code/simple_workflow.py b/tests/llama_index/sample_code/simple_workflow.py new file mode 100644 index 0000000000000..8eba21b32953a --- /dev/null +++ b/tests/llama_index/sample_code/simple_workflow.py @@ -0,0 +1,37 @@ +from llama_index.core.workflow import ( + Event, + StartEvent, + StopEvent, + Workflow, + step, +) +from llama_index.llms.openai import OpenAI + +import mlflow + + +class JokeEvent(Event): + joke: str + + +class JokeFlow(Workflow): + llm = OpenAI() + + @step + async def generate_joke(self, ev: StartEvent) -> JokeEvent: + topic = ev.topic + prompt = f"Write your best joke about {topic}." 
+ response = await self.llm.acomplete(prompt) + return JokeEvent(joke=str(response)) + + @step + async def critique_joke(self, ev: JokeEvent) -> StopEvent: + joke = ev.joke + + prompt = f"Give a thorough analysis and critique of the following joke: {joke}" + response = await self.llm.acomplete(prompt) + return StopEvent(result=str(response)) + + +w = JokeFlow(timeout=10, verbose=False) +mlflow.models.set_model(w) diff --git a/tests/llama_index/test_llama_index_model_export.py b/tests/llama_index/test_llama_index_model_export.py index a3a1f98278410..dc0ae7756be95 100644 --- a/tests/llama_index/test_llama_index_model_export.py +++ b/tests/llama_index/test_llama_index_model_export.py @@ -1,3 +1,4 @@ +import json import os from pathlib import Path from typing import Any @@ -27,11 +28,15 @@ _CHAT_MESSAGE_HISTORY_PARAMETER_NAME, ChatEngineWrapper, QueryEngineWrapper, - create_engine_wrapper, + create_pyfunc_wrapper, ) +from mlflow.models.utils import load_serving_example +from mlflow.pyfunc.scoring_server import CONTENT_TYPE_JSON from mlflow.tracking.artifact_utils import _download_artifact_from_uri from mlflow.types.schema import ColSpec, DataType, Schema +from tests.helper_functions import pyfunc_scoring_endpoint + _EMBEDDING_DIM = 1536 _TEST_QUERY = "Spell llamaindex" @@ -84,7 +89,7 @@ def test_llama_index_save_invalid_object_raise(single_index): with pytest.raises(MlflowException, match="The provided object of type "): mlflow.llama_index.save_model(llama_index_model=OpenAI(), path="model", engine_type="query") - with pytest.raises(MlflowException, match="Saving an engine object is only supported"): + with pytest.raises(MlflowException, match="Saving a non-index object is only supported"): mlflow.llama_index.save_model( llama_index_model=single_index.as_query_engine(), path="model", @@ -96,7 +101,7 @@ def test_llama_index_save_invalid_object_raise(single_index): ["query", "retriever"], ) def test_format_predict_input_correct(single_index, engine_type): - wrapped_model = create_engine_wrapper(single_index, engine_type) + wrapped_model = create_pyfunc_wrapper(single_index, engine_type) assert isinstance( wrapped_model._format_predict_input(pd.DataFrame({"query_str": ["hi"]})), QueryBundle @@ -113,7 +118,7 @@ def test_format_predict_input_correct(single_index, engine_type): ["query", "retriever"], ) def test_format_predict_input_incorrect_schema(single_index, engine_type): - wrapped_model = create_engine_wrapper(single_index, engine_type) + wrapped_model = create_pyfunc_wrapper(single_index, engine_type) exception_error = ( r"__init__\(\) got an unexpected keyword argument 'incorrect'" @@ -132,7 +137,7 @@ def test_format_predict_input_incorrect_schema(single_index, engine_type): ["query", "retriever"], ) def test_format_predict_input_correct_schema_complex(single_index, engine_type): - wrapped_model = create_engine_wrapper(single_index, engine_type) + wrapped_model = create_pyfunc_wrapper(single_index, engine_type) payload = { "query_str": "hi", @@ -471,3 +476,70 @@ def test_save_engine_with_engine_type_issues_warning(model_path): assert mock_logger.warning.call_count == 1 assert "The `engine_type` argument" in mock_logger.warning.call_args[0][0] + + +@pytest.mark.skipif( + Version(llama_index.core.__version__) < Version("0.11.0"), + reason="Workflow was introduced in 0.11.0", +) +@pytest.mark.asyncio +async def test_save_load_workflow_as_code(): + from llama_index.core.workflow import Workflow + + index_code_path = "tests/llama_index/sample_code/simple_workflow.py" + with mlflow.start_run(): + 
+        model_info = mlflow.llama_index.log_model(
+            llama_index_model=index_code_path,
+            artifact_path="model",
+            input_example={"topic": "pirates"},
+        )
+
+    # Signature
+    assert model_info.signature.inputs == Schema([ColSpec(type=DataType.string, name="topic")])
+    assert model_info.signature.outputs == Schema([ColSpec(DataType.string)])
+
+    # Native inference
+    loaded_workflow = mlflow.llama_index.load_model(model_info.model_uri)
+    assert isinstance(loaded_workflow, Workflow)
+    result = await loaded_workflow.run(topic="pirates")
+    assert isinstance(result, str)
+    assert "pirates" in result
+
+    # Pyfunc inference
+    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
+    assert isinstance(pyfunc_loaded_model.get_raw_model(), Workflow)
+    result = pyfunc_loaded_model.predict({"topic": "pirates"})
+    assert isinstance(result, str)
+    assert "pirates" in result
+
+    # Batch inference
+    batch_result = pyfunc_loaded_model.predict(
+        [
+            {"topic": "pirates"},
+            {"topic": "ninjas"},
+            {"topic": "robots"},
+        ]
+    )
+    assert len(batch_result) == 3
+    assert all(isinstance(r, str) for r in batch_result)
+
+    # Serve
+    inference_payload = load_serving_example(model_info.model_uri)
+
+    with pyfunc_scoring_endpoint(
+        model_uri=model_info.model_uri,
+        extra_args=["--env-manager", "local"],
+    ) as endpoint:
+        # Single input
+        response = endpoint.invoke(inference_payload, content_type=CONTENT_TYPE_JSON)
+        assert response.status_code == 200, response.text
+        assert response.json()["predictions"] == result
+
+        # Batch input
+        df = pd.DataFrame({"topic": ["pirates", "ninjas", "robots"]})
+        response = endpoint.invoke(
+            json.dumps({"dataframe_split": df.to_dict(orient="split")}),
+            content_type=CONTENT_TYPE_JSON,
+        )
+        assert response.status_code == 200, response.text
+        assert response.json()["predictions"] == batch_result
diff --git a/tests/llama_index/test_llama_index_pyfunc_wrapper.py b/tests/llama_index/test_llama_index_pyfunc_wrapper.py
index 108b479f49746..30cced6319d0c 100644
--- a/tests/llama_index/test_llama_index_pyfunc_wrapper.py
+++ b/tests/llama_index/test_llama_index_pyfunc_wrapper.py
@@ -1,8 +1,10 @@
+import llama_index.core
 import numpy as np
 import pandas as pd
 import pytest
 from llama_index.core import QueryBundle
 from llama_index.core.llms import ChatMessage
+from packaging.version import Version
 
 import mlflow
 from mlflow.llama_index.pyfunc_wrapper import (
@@ -10,19 +12,19 @@
     CHAT_ENGINE_NAME,
     QUERY_ENGINE_NAME,
     RETRIEVER_ENGINE_NAME,
-    create_engine_wrapper,
+    create_pyfunc_wrapper,
 )
 
 
 ################## Inferece Input #################
 def test_format_predict_input_str_chat(single_index):
-    wrapped_model = create_engine_wrapper(single_index, CHAT_ENGINE_NAME)
+    wrapped_model = create_pyfunc_wrapper(single_index, CHAT_ENGINE_NAME)
     formatted_data = wrapped_model._format_predict_input("string")
     assert formatted_data == "string"
 
 
 def test_format_predict_input_dict_chat(single_index):
-    wrapped_model = create_engine_wrapper(single_index, CHAT_ENGINE_NAME)
+    wrapped_model = create_pyfunc_wrapper(single_index, CHAT_ENGINE_NAME)
     formatted_data = wrapped_model._format_predict_input({"query": "string"})
     assert isinstance(formatted_data, dict)
@@ -32,7 +34,7 @@ def test_format_predict_input_message_history_chat(single_index):
         "message": "string",
         _CHAT_MESSAGE_HISTORY_PARAMETER_NAME: [{"role": "user", "content": "hi"}] * 3,
     }
-    wrapped_model = create_engine_wrapper(single_index, CHAT_ENGINE_NAME)
+    wrapped_model = create_pyfunc_wrapper(single_index, CHAT_ENGINE_NAME)
     formatted_data = wrapped_model._format_predict_input(payload)
 
     assert isinstance(formatted_data, dict)
@@ -65,7 +67,7 @@
     ],
 )
 def test_format_predict_input_message_history_chat_iterable(single_index, data):
-    wrapped_model = create_engine_wrapper(single_index, CHAT_ENGINE_NAME)
+    wrapped_model = create_pyfunc_wrapper(single_index, CHAT_ENGINE_NAME)
     formatted_data = wrapped_model._format_predict_input(data)
 
     if isinstance(data, pd.DataFrame):
@@ -84,7 +86,7 @@ def test_format_predict_input_message_history_chat_invalid_type(single_index):
         "message": "string",
         _CHAT_MESSAGE_HISTORY_PARAMETER_NAME: ["invalid history string", "user: hi"],
     }
-    wrapped_model = create_engine_wrapper(single_index, CHAT_ENGINE_NAME)
+    wrapped_model = create_pyfunc_wrapper(single_index, CHAT_ENGINE_NAME)
 
     with pytest.raises(ValueError, match="It must be a list of dicts"):
         _ = wrapped_model._format_predict_input(payload)
@@ -102,7 +104,7 @@
     ],
 )
 def test_format_predict_input_no_iterable_query(single_index, data):
-    wrapped_model = create_engine_wrapper(single_index, QUERY_ENGINE_NAME)
+    wrapped_model = create_pyfunc_wrapper(single_index, QUERY_ENGINE_NAME)
     formatted_data = wrapped_model._format_predict_input(data)
 
     assert isinstance(formatted_data, QueryBundle)
@@ -126,7 +128,7 @@
     ],
 )
 def test_format_predict_input_iterable_query(single_index, data):
-    wrapped_model = create_engine_wrapper(single_index, QUERY_ENGINE_NAME)
+    wrapped_model = create_pyfunc_wrapper(single_index, QUERY_ENGINE_NAME)
     formatted_data = wrapped_model._format_predict_input(data)
 
     assert isinstance(formatted_data, list)
@@ -146,7 +148,7 @@
     ],
 )
 def test_format_predict_input_no_iterable_retriever(single_index, data):
-    wrapped_model = create_engine_wrapper(single_index, RETRIEVER_ENGINE_NAME)
+    wrapped_model = create_pyfunc_wrapper(single_index, RETRIEVER_ENGINE_NAME)
     formatted_data = wrapped_model._format_predict_input(data)
 
     assert isinstance(formatted_data, QueryBundle)
@@ -170,7 +172,7 @@
     ],
 )
 def test_format_predict_input_iterable_retriever(single_index, data):
-    wrapped_model = create_engine_wrapper(single_index, RETRIEVER_ENGINE_NAME)
+    wrapped_model = create_pyfunc_wrapper(single_index, RETRIEVER_ENGINE_NAME)
     formatted_data = wrapped_model._format_predict_input(data)
     assert isinstance(formatted_data, list)
     assert all(isinstance(x, QueryBundle) for x in formatted_data)
@@ -181,7 +183,7 @@
     ["query", "retriever"],
 )
 def test_format_predict_input_correct(single_index, engine_type):
-    wrapped_model = create_engine_wrapper(single_index, engine_type)
+    wrapped_model = create_pyfunc_wrapper(single_index, engine_type)
 
     assert isinstance(
         wrapped_model._format_predict_input(pd.DataFrame({"query_str": ["hi"]})), QueryBundle
@@ -198,7 +200,7 @@
     ["query", "retriever"],
 )
 def test_format_predict_input_correct_schema_complex(single_index, engine_type):
-    wrapped_model = create_engine_wrapper(single_index, engine_type)
+    wrapped_model = create_pyfunc_wrapper(single_index, engine_type)
 
     payload = {
         "query_str": "hi",
@@ -258,3 +260,58 @@ def test_spark_udf_chat(model_path, spark, single_index):
     pdf = df.toPandas()
     assert len(pdf["predictions"].tolist()) == 1
     assert isinstance(pdf["predictions"].tolist()[0], str)
+
+
+@pytest.mark.skipif(
+    Version(llama_index.core.__version__) < Version("0.11.0"),
+    reason="Workflow was introduced in 0.11.0",
+)
+@pytest.mark.asyncio
+async def test_wrap_workflow():
+    from llama_index.core.workflow import StartEvent, StopEvent, Workflow, step
+
+    class MyWorkflow(Workflow):
+        @step
+        async def my_step(self, ev: StartEvent) -> StopEvent:
+            return StopEvent(result=f"Hi, {ev.name}!")
+
+    w = MyWorkflow(timeout=10, verbose=False)
+    wrapper = create_pyfunc_wrapper(w)
+    assert wrapper.get_raw_model() == w
+
+    result = wrapper.predict({"name": "Alice"})
+    assert result == "Hi, Alice!"
+
+    results = wrapper.predict(
+        [
+            {"name": "Bob"},
+            {"name": "Charlie"},
+        ]
+    )
+    assert results == ["Hi, Bob!", "Hi, Charlie!"]
+
+    results = wrapper.predict(pd.DataFrame({"name": ["David"]}))
+    assert results == "Hi, David!"
+
+    results = wrapper.predict(pd.DataFrame({"name": ["Eve", "Frank"]}))
+    assert results == ["Hi, Eve!", "Hi, Frank!"]
+
+
+@pytest.mark.skipif(
+    Version(llama_index.core.__version__) < Version("0.11.0"),
+    reason="Workflow was introduced in 0.11.0",
+)
+@pytest.mark.asyncio
+async def test_wrap_workflow_raise_exception():
+    from llama_index.core.workflow import StartEvent, StopEvent, Workflow, step
+
+    class MyWorkflow(Workflow):
+        @step
+        async def my_step(self, ev: StartEvent) -> StopEvent:
+            raise ValueError("Expected error")
+
+    w = MyWorkflow(timeout=10, verbose=False)
+    wrapper = create_pyfunc_wrapper(w)
+
+    with pytest.raises(ValueError, match="Expected error"):
+        wrapper.predict({"name": "Alice"})
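
For reference, end-to-end usage of the workflow support added above looks roughly like the following sketch of a model-from-code script. It mirrors tests/llama_index/sample_code/simple_workflow.py, but uses a trivial step so it runs without an LLM or API key; the file name greeting_workflow.py and the GreetingFlow class are illustrative, not part of the change.

# greeting_workflow.py -- illustrative model-from-code script (hypothetical file name)
from llama_index.core.workflow import StartEvent, StopEvent, Workflow, step

import mlflow


class GreetingFlow(Workflow):
    @step
    async def greet(self, ev: StartEvent) -> StopEvent:
        # Trivial step so this sketch does not need an LLM or API key.
        return StopEvent(result=f"Hi, {ev.name}!")


# set_model registers the Workflow instance as the model defined by this script.
mlflow.models.set_model(GreetingFlow(timeout=10, verbose=False))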
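
Logging that script and querying it through the pyfunc flavor then follows the same calls exercised in test_save_load_workflow_as_code; the script path, artifact path, and input values below are illustrative.

import mlflow

with mlflow.start_run():
    model_info = mlflow.llama_index.log_model(
        llama_index_model="greeting_workflow.py",  # path to the model-from-code script
        artifact_path="model",
        input_example={"name": "Alice"},
    )

# The pyfunc wrapper runs the async Workflow to completion behind a blocking predict().
loaded = mlflow.pyfunc.load_model(model_info.model_uri)
print(loaded.predict({"name": "Alice"}))                       # single input -> "Hi, Alice!"
print(loaded.predict([{"name": "Bob"}, {"name": "Charlie"}]))  # batch input -> list of strings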