From 6f409b2a1e1b4ed1990bb3b2c6083f1a6e0a463f Mon Sep 17 00:00:00 2001
From: Yuki Watanabe <31463517+B-Step62@users.noreply.github.com>
Date: Thu, 3 Oct 2024 12:32:59 +0900
Subject: [PATCH] Llama workflow Trace (#13305)

Signed-off-by: B-Step62
---
 mlflow/llama_index/tracer.py                 | 23 ++++-
 mlflow/ml-package-versions.yml               |  2 +
 mlflow/models/model.py                       |  2 +-
 tests/llama_index/test_llama_index_tracer.py | 95 ++++++++++++++++++++
 4 files changed, 118 insertions(+), 4 deletions(-)

diff --git a/mlflow/llama_index/tracer.py b/mlflow/llama_index/tracer.py
index d275b71ce0876..6ec4f17a1f450 100644
--- a/mlflow/llama_index/tracer.py
+++ b/mlflow/llama_index/tracer.py
@@ -4,6 +4,7 @@
 from functools import singledispatchmethod
 from typing import Any, Dict, Generator, Optional, Tuple, Union
 
+import llama_index.core
 import pydantic
 from llama_index.core.base.agent.types import BaseAgent, BaseAgentWorker, TaskStepOutput
 from llama_index.core.base.base_retriever import BaseRetriever
@@ -39,6 +40,7 @@
 _logger = logging.getLogger(__name__)
 
 IS_PYDANTIC_V1 = Version(pydantic.__version__).major < 2
+LLAMA_INDEX_VERSION = Version(llama_index.core.__version__)
 
 
 def set_llama_index_tracer():
@@ -127,11 +129,14 @@ def new_span(
         parent_span_id: Optional[str] = None,
         **kwargs: Any,
     ) -> _LlamaSpan:
+        with self.lock:
+            parent = self.open_spans.get(parent_span_id) if parent_span_id else None
+
         try:
             input_args = bound_args.arguments
             attributes = self._get_instance_attributes(instance)
             span_type = self._get_span_type(instance) or SpanType.UNKNOWN
-            if parent_span_id and (parent := self.open_spans.get(parent_span_id)):
+            if parent:
                 parent_span = parent._mlflow_span
                 # NB: Initiate the new client every time to handle tracking URI updates.
                 span = MlflowClient().start_span(
@@ -160,9 +165,11 @@ def prepare_to_exit_span(
         **kwargs: Any,
     ) -> _LlamaSpan:
         try:
-            llama_span = self.open_spans.get(id_)
+            with self.lock:
+                llama_span = self.open_spans.get(id_)
             if not llama_span:
                 return
+
             span = llama_span._mlflow_span
 
             if self._stream_resolver.is_streaming_result(result):
@@ -187,8 +194,18 @@ def resolve_pending_stream_span(self, span: LiveSpan, event: Any):
 
     def prepare_to_drop_span(self, id_: str, err: Optional[Exception], **kwargs) -> _LlamaSpan:
         """Logic for handling errors during the model execution."""
-        llama_span = self.open_spans.get(id_)
+        with self.lock:
+            llama_span = self.open_spans.get(id_)
         span = llama_span._mlflow_span
+
+        if LLAMA_INDEX_VERSION >= Version("0.10.59"):
+            # LlamaIndex determines if a workflow is terminated or not by propagating a special
+            # exception WorkflowDone. We should treat this exception as a successful termination.
+            from llama_index.core.workflow.errors import WorkflowDone
+
+            if err and isinstance(err, WorkflowDone):
+                return _end_span(span=span, status=SpanStatusCode.OK)
+
         span.add_event(SpanEvent.from_exception(err))
         _end_span(span=span, status="ERROR")
         return llama_span
diff --git a/mlflow/ml-package-versions.yml b/mlflow/ml-package-versions.yml
index 11e80cdc16b9b..703f09dd8c8bc 100644
--- a/mlflow/ml-package-versions.yml
+++ b/mlflow/ml-package-versions.yml
@@ -788,6 +788,8 @@ llama_index:
         # Required to run tests/openai/mock_openai.py
         "fastapi",
         "uvicorn",
+        # Required for testing LlamaIndex workflow
+        "pytest-asyncio",
       ]
       "< 0.11.0": ["pydantic<2"]
     run: pytest tests/llama_index/test_llama_index_autolog.py tests/llama_index/test_llama_index_tracer.py
diff --git a/mlflow/models/model.py b/mlflow/models/model.py
index ca97124331e18..3b083ae1ca1dd 100644
--- a/mlflow/models/model.py
+++ b/mlflow/models/model.py
@@ -1058,7 +1058,7 @@ def set_model(model):
         globals()["__mlflow_model__"] = model
         return
 
-    for validate_function in [_validate_llama_index_model]:
+    for validate_function in [_validate_llama_index_model, _validate_llama_index_model]:
         try:
             globals()["__mlflow_model__"] = validate_function(model)
             return
diff --git a/tests/llama_index/test_llama_index_tracer.py b/tests/llama_index/test_llama_index_tracer.py
index 4071ba4a9a93f..5ebd741683216 100644
--- a/tests/llama_index/test_llama_index_tracer.py
+++ b/tests/llama_index/test_llama_index_tracer.py
@@ -1,5 +1,6 @@
 import asyncio
 import inspect
+import random
 from dataclasses import asdict
 from typing import List
 from unittest.mock import ANY
@@ -21,12 +22,14 @@
 import mlflow
 import mlflow.tracking._tracking_service
 from mlflow.entities.span import SpanType
+from mlflow.entities.span_status import SpanStatusCode
 from mlflow.entities.trace import Trace
 from mlflow.entities.trace_status import TraceStatus
 from mlflow.llama_index.tracer import remove_llama_index_tracer, set_llama_index_tracer
 from mlflow.tracking._tracking_service.utils import _use_tracking_uri
 from mlflow.tracking.default_experiment import DEFAULT_EXPERIMENT_ID
 
+llama_core_version = Version(importlib_metadata.version("llama-index-core"))
 llama_oai_version = Version(importlib_metadata.version("llama-index-llms-openai"))
 
 
@@ -427,3 +430,95 @@ def test_tracer_handle_tracking_uri_update(tmp_path):
     # The new trace will be logged to the updated tracking URI
     OpenAI().complete("Hello")
     assert len(_get_all_traces()) == 1
+
+
+@pytest.mark.skipif(
+    llama_core_version >= Version("0.11.10"),
+    reason="Workflow tracing does not work correctly in >= 0.11.10 until "
+    "https://github.com/run-llama/llama_index/issues/16283 is fixed",
+)
+@pytest.mark.skipif(
+    llama_core_version < Version("0.11.0"),
+    reason="Workflow was introduced in 0.11.0",
+)
+@pytest.mark.asyncio
+async def test_tracer_simple_workflow():
+    from llama_index.core.workflow import StartEvent, StopEvent, Workflow, step
+
+    class MyWorkflow(Workflow):
+        @step
+        async def my_step(self, ev: StartEvent) -> StopEvent:
+            return StopEvent(result="Hi, world!")
+
+    w = MyWorkflow(timeout=10, verbose=False)
+    await w.run()
+
+    traces = _get_all_traces()
+    assert len(traces) == 1
+    assert traces[0].info.status == TraceStatus.OK
+    assert all(s.status.status_code == SpanStatusCode.OK for s in traces[0].data.spans)
+
+
+@pytest.mark.skipif(
+    llama_core_version >= Version("0.11.10"),
+    reason="Workflow tracing does not work correctly in >= 0.11.10 until "
+    "https://github.com/run-llama/llama_index/issues/16283 is fixed",
+)
+@pytest.mark.skipif(
+    llama_core_version < Version("0.11.0"),
+    reason="Workflow was introduced in 0.11.0",
+)
+@pytest.mark.asyncio
+async def test_tracer_parallel_workflow():
+    from llama_index.core.workflow import (
+        Context,
+        Event,
+        StartEvent,
+        StopEvent,
+        Workflow,
+        step,
+    )
+
+    class ProcessEvent(Event):
+        data: str
+
+    class ResultEvent(Event):
+        result: str
+
+    class ParallelWorkflow(Workflow):
+        @step
+        async def start(self, ctx: Context, ev: StartEvent) -> ProcessEvent:
+            await ctx.set("num_to_collect", len(ev.inputs))
+            for item in ev.inputs:
+                ctx.send_event(ProcessEvent(data=item))
+            return None
+
+        @step(num_workers=3)
+        async def process_data(self, ev: ProcessEvent) -> ResultEvent:
+            # Simulate some time-consuming processing
+            await asyncio.sleep(random.randint(1, 2))
+            return ResultEvent(result=ev.data)
+
+        @step
+        async def combine_results(self, ctx: Context, ev: ResultEvent) -> StopEvent:
+            num_to_collect = await ctx.get("num_to_collect")
+            results = ctx.collect_events(ev, [ResultEvent] * num_to_collect)
+            if results is None:
+                return None
+
+            combined_result = ", ".join(sorted([event.result for event in results]))
+            return StopEvent(result=combined_result)
+
+    w = ParallelWorkflow()
+    result = await w.run(inputs=["apple", "grape", "orange", "banana"])
+    assert result == "apple, banana, grape, orange"
+
+    traces = _get_all_traces()
+    assert len(traces) == 1
+    assert traces[0].info.status == TraceStatus.OK
+    for s in traces[0].data.spans:
+        assert s.status.status_code == SpanStatusCode.OK
+
+    root_span = traces[0].data.spans[0]
+    assert root_span.inputs == {"kwargs": {"inputs": ["apple", "grape", "orange", "banana"]}}
+    assert root_span.outputs == "apple, banana, grape, orange"
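
Usage sketch (illustrative, not part of the patch): the snippet below shows one way to exercise the new workflow tracing end to end outside the test suite. It assumes llama-index-core >= 0.11.0 and uses mlflow.llama_index.autolog() as the public entry point for installing the tracer (the tests above call set_llama_index_tracer() directly); the EchoWorkflow class and the mlflow.get_last_active_trace() check are only examples, not code added by this change.

import asyncio

import mlflow
from llama_index.core.workflow import StartEvent, StopEvent, Workflow, step

# Enable MLflow tracing for LlamaIndex; this installs the same tracer the tests set up.
mlflow.llama_index.autolog()


class EchoWorkflow(Workflow):
    # Hypothetical one-step workflow used only for illustration.
    @step
    async def reply(self, ev: StartEvent) -> StopEvent:
        # LlamaIndex finishes a workflow by raising WorkflowDone internally; with this
        # patch that exception is recorded as an OK span instead of an ERROR.
        return StopEvent(result=f"echo: {ev.message}")


async def main():
    result = await EchoWorkflow(timeout=10).run(message="hello")
    print(result)  # "echo: hello"

    # The workflow run should be captured as a single trace with OK status.
    trace = mlflow.get_last_active_trace()
    print(trace.info.status)


asyncio.run(main())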