LlamaIndex workflow logging (mlflow#13277)
Signed-off-by: B-Step62 <[email protected]>
B-Step62 authored and BenWilson2 committed Oct 11, 2024
1 parent 1abb24a commit 6420a1e
Showing 7 changed files with 364 additions and 50 deletions.
75 changes: 56 additions & 19 deletions mlflow/llama_index/__init__.py
@@ -8,7 +8,7 @@
import mlflow
from mlflow import pyfunc
from mlflow.exceptions import MlflowException
from mlflow.llama_index.pyfunc_wrapper import create_engine_wrapper
from mlflow.llama_index.pyfunc_wrapper import create_pyfunc_wrapper
from mlflow.models import Model, ModelInputExample, ModelSignature
from mlflow.models.model import MLMODEL_FILE_NAME, MODEL_CODE_PATH
from mlflow.models.signature import _infer_signature_from_input_example
@@ -100,7 +100,16 @@ def _supported_classes():
from llama_index.core.indices.base import BaseIndex
from llama_index.core.retrievers import BaseRetriever

return BaseIndex, BaseChatEngine, BaseQueryEngine, BaseRetriever
supported = (BaseIndex, BaseChatEngine, BaseQueryEngine, BaseRetriever)

try:
from llama_index.core.workflow import Workflow

supported += (Workflow,)
except ImportError:
pass

return supported


@experimental
@@ -123,13 +132,25 @@ def save_model(
"""
Save a LlamaIndex model to a path on the local file system.
.. attention::
Saving a non-index object is only supported in the 'Model-from-Code' saving mode.
Please refer to the `Models From Code Guide <https://www.mlflow.org/docs/latest/model/models-from-code.html>`_
for more information.
Args:
llama_index_model: An LlamaIndex object to be saved, or a string representing the path to
a script contains LlamaIndex index/engine definition.
llama_index_model: A LlamaIndex object to be saved. Supported model types are:
1. An Index object.
2. An Engine object, e.g. ChatEngine, QueryEngine, Retriever.
3. A `Workflow <https://docs.llamaindex.ai/en/stable/module_guides/workflow/>`_ object.
4. A string representing the path to a script that contains a LlamaIndex model definition
of one of the above types.
path: Local path where the serialized model (as YAML) is to be saved.
engine_type: Required when saving an index object to determine the inference interface
engine_type: Required when saving an Index object to determine the inference interface
for the index when loaded as a pyfunc model. This field is **not** required when
saving an engine directly. The supported types are as follows:
saving other LlamaIndex objects. The supported values are as follows:
- ``"chat"``: load the index as an instance of the LlamaIndex
`ChatEngine <https://docs.llamaindex.ai/en/stable/module_guides/deploying/chat_engines/>`_.
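
A minimal sketch of how ``engine_type`` might be used when saving an Index object, assuming a default embedding model is configured; the document text, local path, and engine type are illustrative only:

import mlflow
from llama_index.core import Document, VectorStoreIndex

# Build a small illustrative index (requires an embedding model to be configured).
index = VectorStoreIndex.from_documents([Document(text="MLflow and LlamaIndex")])

mlflow.llama_index.save_model(
    llama_index_model=index,
    path="/tmp/llama_index_model",  # illustrative local path
    engine_type="chat",  # load the index back as a ChatEngine when used as a pyfunc model
)
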
@@ -177,17 +198,19 @@ def save_model(

# Warn when user provides `engine_type` argument while saving an engine directly
if not isinstance(llama_index_model, BaseIndex) and engine_type is not None:
_logger.warning("The `engine_type` argument is ignored when saving an engine.")
_logger.warning(
"The `engine_type` argument is ignored when saving a non-index object."
)

elif isinstance(model_or_code_path, BaseIndex):
_validate_engine_type(engine_type)
llama_index_model = model_or_code_path

elif isinstance(model_or_code_path, _supported_classes()):
raise MlflowException.invalid_parameter_value(
"Saving an engine object is only supported in the 'Model-from-Code' saving mode. "
"Saving a non-index object is only supported in the 'Model-from-Code' saving mode. "
"The legacy serialization method is exclusively for saving index objects. Please "
"pass the path to the script containing the engine definition to save an engine "
"pass the path to the script containing the model definition to save a non-index "
"object. For more information, see "
"https://www.mlflow.org/docs/latest/model/models-from-code.html",
)
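
A minimal sketch of the Model-from-Code mode referenced above, assuming a hypothetical ``workflow.py`` script; the ``EchoWorkflow`` class, the ``topic`` keyword, and the paths are illustrative only:

# workflow.py -- hypothetical model-from-code script
import mlflow
from llama_index.core.workflow import StartEvent, StopEvent, Workflow, step


class EchoWorkflow(Workflow):
    @step
    async def echo(self, ev: StartEvent) -> StopEvent:
        # `topic` is an illustrative keyword argument passed via Workflow.run(topic=...)
        return StopEvent(result=f"echo: {ev.topic}")


mlflow.models.set_model(EchoWorkflow())

# In the logging script or notebook
import mlflow

with mlflow.start_run():
    model_info = mlflow.llama_index.log_model(
        llama_index_model="workflow.py",  # path to the script above
        artifact_path="workflow_model",  # illustrative artifact path
    )
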
@@ -199,7 +222,7 @@ def save_model(
saved_example = _save_example(mlflow_model, input_example, path)

if signature is None and saved_example is not None:
wrapped_model = create_engine_wrapper(llama_index_model, engine_type, model_config)
wrapped_model = create_pyfunc_wrapper(llama_index_model, engine_type, model_config)
signature = _infer_signature_from_input_example(saved_example, wrapped_model)
elif signature is False:
signature = None
@@ -292,13 +315,25 @@ def log_model(
"""
Log a LlamaIndex model as an MLflow artifact for the current run.
.. attention::
Saving a non-index object is only supported in the 'Model-from-Code' saving mode.
Please refer to the `Models From Code Guide <https://www.mlflow.org/docs/latest/model/models-from-code.html>`_
for more information.
Args:
llama_index_model: An LlamaIndex object to be saved, or a string representing the path to
a script contains LlamaIndex index/engine definition.
llama_index_model: A LlamaIndex object to be saved. Supported model types are:
1. An Index object.
2. An Engine object, e.g. ChatEngine, QueryEngine, Retriever.
3. A `Workflow <https://docs.llamaindex.ai/en/stable/module_guides/workflow/>`_ object.
4. A string representing the path to a script that contains a LlamaIndex model definition
of one of the above types.
artifact_path: Local path where the serialized model (as YAML) is to be saved.
engine_type: Required when saving an index object to determine the inference interface
engine_type: Required when saving an Index object to determine the inference interface
for the index when loaded as a pyfunc model. This field is **not** required when
saving an engine directly. The supported types are as follows:
saving other LlamaIndex objects. The supported values are as follows:
- ``"chat"``: load the index as an instance of the LlamaIndex
`ChatEngine <https://docs.llamaindex.ai/en/stable/module_guides/deploying/chat_engines/>`_.
Expand Down Expand Up @@ -372,7 +407,7 @@ def _save_index(index, path):


def _load_llama_model(path, flavor_conf):
"""Load the LlamaIndex index or engine from either model code or serialized index."""
"""Load the LlamaIndex index/engine/workflow from either model code or serialized index."""
from llama_index.core import StorageContext, load_index_from_storage

_add_code_from_conf_to_system_path(path, flavor_conf)
@@ -398,7 +433,7 @@ def _load_llama_model(path, flavor_conf):
@trace_disabled # Suppress traces while loading model
def load_model(model_uri, dst_path=None):
"""
Load a LlamaIndex index or engine from a local file or a run.
Load a LlamaIndex index/engine/workflow from a local file or a run.
Args:
model_uri: The location, in URI format, of the MLflow model. For example:
@@ -431,12 +466,14 @@ def load_model(model_uri, dst_path=None):


def _load_pyfunc(path, model_config: Optional[Dict[str, Any]] = None):
from mlflow.llama_index.pyfunc_wrapper import create_engine_wrapper
from mlflow.llama_index.pyfunc_wrapper import create_pyfunc_wrapper

index = load_model(path)
flavor_conf = _get_flavor_configuration(model_path=path, flavor_name=FLAVOR_NAME)
engine_type = flavor_conf.pop("engine_type", None) # Not present when saving an engine object
return create_engine_wrapper(index, engine_type, model_config)
engine_type = flavor_conf.pop(
"engine_type", None
) # Not present when saving a non-index object
return create_pyfunc_wrapper(index, engine_type, model_config)
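
A minimal sketch of consuming a logged workflow through the pyfunc wrapper created above, assuming ``model_info`` from an earlier ``log_model`` call; the ``topic`` key is illustrative and is forwarded to ``Workflow.run(**kwargs)``:

import mlflow

loaded = mlflow.pyfunc.load_model(model_info.model_uri)
# Dictionary keys become keyword arguments of Workflow.run()
print(loaded.predict({"topic": "MLflow"}))
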


@experimental
135 changes: 122 additions & 13 deletions mlflow/llama_index/pyfunc_wrapper.py
@@ -1,3 +1,5 @@
import asyncio
import threading
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

if TYPE_CHECKING:
@@ -56,18 +58,18 @@ def _format_predict_input_query_engine_and_retriever(data) -> "QueryBundle":
class _LlamaIndexModelWrapperBase:
def __init__(
self,
engine,
llama_model, # Engine or Workflow
model_config: Optional[Dict[str, Any]] = None,
):
self.engine = engine
self._llama_model = llama_model
self.model_config = model_config or {}

@property
def index(self):
return self.engine.index
return self._llama_model.index

def get_raw_model(self):
return self.engine
return self._llama_model

def _predict_single(self, *args, **kwargs) -> Any:
raise NotImplementedError
@@ -101,7 +103,7 @@ def engine_type(self):
return CHAT_ENGINE_NAME

def _predict_single(self, *args, **kwargs) -> str:
return self.engine.chat(*args, **kwargs).response
return self._llama_model.chat(*args, **kwargs).response

@staticmethod
def _convert_chat_message_history_to_chat_message_objects(data: Dict) -> Dict:
@@ -145,7 +147,7 @@ def engine_type(self):
return QUERY_ENGINE_NAME

def _predict_single(self, *args, **kwargs) -> str:
return self.engine.query(*args, **kwargs).response
return self._llama_model.query(*args, **kwargs).response

def _format_predict_input(self, data) -> "QueryBundle":
return _format_predict_input_query_engine_and_retriever(data)
@@ -157,27 +159,130 @@ def engine_type(self):
return RETRIEVER_ENGINE_NAME

def _predict_single(self, *args, **kwargs) -> List[Dict]:
response = self.engine.retrieve(*args, **kwargs)
response = self._llama_model.retrieve(*args, **kwargs)
return [node.dict() for node in response]

def _format_predict_input(self, data) -> "QueryBundle":
return _format_predict_input_query_engine_and_retriever(data)


def create_engine_wrapper(
index_or_engine: Any,
class WorkflowWrapper(_LlamaIndexModelWrapperBase):
@property
def index(self):
raise NotImplementedError("LlamaIndex Workflow does not have an index")

@property
def engine_type(self):
raise NotImplementedError("LlamaIndex Workflow is not an engine")

def predict(self, data, params: Optional[Dict[str, Any]] = None) -> Union[List[str], str]:
inputs = self._format_predict_input(data, params)

# LlamaIndex Workflow runs async but MLflow pyfunc doesn't support async inference yet.
predictions = self._wait_async_task(self._run_predictions(inputs))

# Even if the input is a single instance, signature enforcement converts it to a Pandas
# DataFrame with a single row. In this case, we should unwrap the result (list) so it
# stays consistent with the output produced without signature enforcement.
should_unwrap = len(data) == 1 and isinstance(predictions, list)
return predictions[0] if should_unwrap else predictions

def _format_predict_input(self, data, params: Optional[Dict[str, Any]] = None) -> List[Dict]:
inputs = _convert_llm_input_data_with_unwrapping(data)
params = params or {}
if isinstance(inputs, dict):
return [{**inputs, **params}]
return [{**x, **params} for x in inputs]

async def _run_predictions(self, inputs: List[Dict[str, Any]]) -> asyncio.Future:
tasks = [self._predict_single(x) for x in inputs]
return await asyncio.gather(*tasks)

async def _predict_single(self, x: Dict[str, Any]) -> Any:
if not isinstance(x, dict):
raise ValueError(f"Unsupported input type: {type(x)}. It must be a dictionary.")
return await self._llama_model.run(**x)

def _wait_async_task(self, task: asyncio.Future) -> Any:
"""
A utility function to run async tasks in a blocking manner.
If there is no event loop running already, for example, in a model serving endpoint,
we can simply create a new event loop and run the task there. However, in a notebook
environment (or pytest with asyncio decoration), there is already an event loop running
at the root level and we cannot start a new one.
"""
if not self._is_event_loop_running():
return asyncio.new_event_loop().run_until_complete(task)
else:
# NB: The popular way to run an async task where an event loop is already running is to
# use nest_asyncio. However, nest_asyncio.apply() breaks the async OpenAI client
# somehow, which is used for most of the LLM calls in LlamaIndex, including Databricks
# LLMs. Therefore, we use a hacky workaround that creates a new thread and runs the
# new event loop there. This may degrade performance compared to native
# asyncio, but it should be fine because this is only used in the notebook env.
results = None
exception = None

def _run():
nonlocal results, exception

try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
results = loop.run_until_complete(task)
except Exception as e:
exception = e
finally:
loop.close()

thread = threading.Thread(target=_run)
thread.start()
thread.join()

if exception:
raise exception

return results

def _is_event_loop_running(self) -> bool:
try:
loop = asyncio.get_running_loop()
return loop is not None
except Exception:
return False


def create_pyfunc_wrapper(
model: Any,
engine_type: Optional[str] = None,
model_config: Optional[Dict[str, Any]] = None,
):
"""
A factory function that creates a Pyfunc wrapper around a LlamaIndex index or engine.
A factory function that creates a Pyfunc wrapper around a LlamaIndex index/engine/workflow.
Args:
model: A LlamaIndex index/engine/workflow.
engine_type: The type of the engine. Only required if `model` is an index
and must be one of [chat, query, retriever].
model_config: A dictionary of model configuration parameters.
"""
try:
from llama_index.core.workflow import Workflow

if isinstance(model, Workflow):
return _create_wrapper_from_workflow(model, model_config)
except ImportError:
pass

from llama_index.core.indices.base import BaseIndex

if isinstance(index_or_engine, BaseIndex):
return _create_wrapper_from_index(index_or_engine, engine_type, model_config)
if isinstance(model, BaseIndex):
return _create_wrapper_from_index(model, engine_type, model_config)
else:
return _create_wrapper_from_engine(index_or_engine, model_config)
# Engine does not have a common base class so we assume
# everything else is an engine
return _create_wrapper_from_engine(model, model_config)


def _create_wrapper_from_index(
Expand Down Expand Up @@ -214,3 +319,7 @@ def _create_wrapper_from_engine(engine: Any, model_config: Optional[Dict[str, An
raise ValueError(
f"Unsupported engine type: {type(engine)}. It must be one of {SUPPORTED_ENGINES}"
)


def _create_wrapper_from_workflow(workflow: Any, model_config: Optional[Dict[str, Any]] = None):
return WorkflowWrapper(workflow, model_config)
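
A minimal sketch of the factory in use, assuming the hypothetical ``EchoWorkflow`` from the earlier sketch; in practice the wrapper is created by ``_load_pyfunc`` rather than called directly:

from mlflow.llama_index.pyfunc_wrapper import create_pyfunc_wrapper

wrapper = create_pyfunc_wrapper(EchoWorkflow())  # returns a WorkflowWrapper
# predict() runs the async workflow synchronously via _wait_async_task
print(wrapper.predict({"topic": "MLflow"}))
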
2 changes: 2 additions & 0 deletions mlflow/ml-package-versions.yml
@@ -773,6 +773,8 @@ llama_index:
# Required to run tests/openai/mock_openai.py
"fastapi",
"uvicorn",
# Required for testing LlamaIndex workflow
"pytest-asyncio",
]
run: pytest tests/llama_index --ignore tests/llama_index/test_llama_index_autolog.py --ignore tests/llama_index/test_llama_index_tracer.py
autologging:
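
A minimal sketch of the kind of async test that the added pytest-asyncio dependency enables, assuming the hypothetical ``EchoWorkflow`` from the earlier sketch; the test name and assertion are illustrative:

import pytest


@pytest.mark.asyncio
async def test_workflow_runs():
    result = await EchoWorkflow().run(topic="MLflow")
    assert result == "echo: MLflow"
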
2 changes: 1 addition & 1 deletion mlflow/models/model.py
@@ -1058,7 +1058,7 @@ def set_model(model):
globals()["__mlflow_model__"] = model
return

for validate_function in [_validate_langchain_model, _validate_llama_index_model]:
for validate_function in [_validate_llama_index_model]:
try:
globals()["__mlflow_model__"] = validate_function(model)
return