Llama workflow Trace (mlflow#13305)
Signed-off-by: B-Step62 <[email protected]>
B-Step62 authored and BenWilson2 committed Oct 11, 2024
1 parent 6420a1e commit 6f409b2
Showing 4 changed files with 118 additions and 4 deletions.
23 changes: 20 additions & 3 deletions mlflow/llama_index/tracer.py
@@ -4,6 +4,7 @@
from functools import singledispatchmethod
from typing import Any, Dict, Generator, Optional, Tuple, Union

+import llama_index.core
import pydantic
from llama_index.core.base.agent.types import BaseAgent, BaseAgentWorker, TaskStepOutput
from llama_index.core.base.base_retriever import BaseRetriever
@@ -39,6 +40,7 @@
_logger = logging.getLogger(__name__)

IS_PYDANTIC_V1 = Version(pydantic.__version__).major < 2
+LLAMA_INDEX_VERSION = Version(llama_index.core.__version__)


def set_llama_index_tracer():
@@ -127,11 +129,14 @@ def new_span(
        parent_span_id: Optional[str] = None,
        **kwargs: Any,
    ) -> _LlamaSpan:
+       with self.lock:
+           parent = self.open_spans.get(parent_span_id) if parent_span_id else None
+
        try:
            input_args = bound_args.arguments
            attributes = self._get_instance_attributes(instance)
            span_type = self._get_span_type(instance) or SpanType.UNKNOWN
-           if parent_span_id and (parent := self.open_spans.get(parent_span_id)):
+           if parent:
                parent_span = parent._mlflow_span
                # NB: Initiate the new client every time to handle tracking URI updates.
                span = MlflowClient().start_span(
@@ -160,9 +165,11 @@ def prepare_to_exit_span(
        **kwargs: Any,
    ) -> _LlamaSpan:
        try:
-           llama_span = self.open_spans.get(id_)
+           with self.lock:
+               llama_span = self.open_spans.get(id_)
            if not llama_span:
                return
+
            span = llama_span._mlflow_span

            if self._stream_resolver.is_streaming_result(result):
@@ -187,8 +194,18 @@ def resolve_pending_stream_span(self, span: LiveSpan, event: Any):

    def prepare_to_drop_span(self, id_: str, err: Optional[Exception], **kwargs) -> _LlamaSpan:
        """Logic for handling errors during the model execution."""
-       llama_span = self.open_spans.get(id_)
+       with self.lock:
+           llama_span = self.open_spans.get(id_)
        span = llama_span._mlflow_span
+
+       if LLAMA_INDEX_VERSION >= Version("0.10.59"):
+           # LlamaIndex determines whether a workflow has terminated by propagating a special
+           # exception, WorkflowDone. We should treat this exception as a successful termination.
+           from llama_index.core.workflow.errors import WorkflowDone
+
+           if err and isinstance(err, WorkflowDone):
+               return _end_span(span=span, status=SpanStatusCode.OK)
+
        span.add_event(SpanEvent.from_exception(err))
        _end_span(span=span, status="ERROR")
        return llama_span
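
The behavioral change in prepare_to_drop_span is easiest to see in isolation: a WorkflowDone "error" is LlamaIndex's control-flow signal for normal completion, so the handler must close the span with an OK status rather than recording an exception. The following is a minimal, runnable sketch of that pattern; FakeSpan and the local WorkflowDone class are simplified stand-ins, not MLflow's or LlamaIndex's actual types.

from typing import List, Optional


class WorkflowDone(Exception):
    """Stand-in for llama_index.core.workflow.errors.WorkflowDone."""


class FakeSpan:
    def __init__(self) -> None:
        self.status = "IN_PROGRESS"
        self.events: List[str] = []

    def end(self, status: str) -> None:
        self.status = status


def drop_span(span: FakeSpan, err: Optional[Exception]) -> FakeSpan:
    # WorkflowDone is a termination signal, not a failure: end the span as OK.
    if err is not None and isinstance(err, WorkflowDone):
        span.end("OK")
        return span

    # Any other exception is a genuine error: record it, then end as ERROR.
    if err is not None:
        span.events.append(repr(err))
    span.end("ERROR")
    return span


assert drop_span(FakeSpan(), WorkflowDone()).status == "OK"
assert drop_span(FakeSpan(), ValueError("boom")).status == "ERROR"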
2 changes: 2 additions & 0 deletions mlflow/ml-package-versions.yml
@@ -788,6 +788,8 @@ llama_index:
        # Required to run tests/openai/mock_openai.py
        "fastapi",
        "uvicorn",
+       # Required for testing LlamaIndex workflow
+       "pytest-asyncio",
      ]
      "< 0.11.0": ["pydantic<2"]
    run: pytest tests/llama_index/test_llama_index_autolog.py tests/llama_index/test_llama_index_tracer.py
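
One note on the pytest-asyncio requirement added above: plain pytest collects async def tests but has no event loop to drive them, so the new workflow tests would not run without the plugin. A standalone illustration, not a file from this commit:

# test_asyncio_demo.py: requires the pytest-asyncio plugin to execute.
import asyncio

import pytest


@pytest.mark.asyncio
async def test_sleep_runs_on_plugin_event_loop():
    # Without pytest-asyncio, the coroutine returned by this test function
    # would never be awaited, and the test could not exercise any await points.
    await asyncio.sleep(0)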
2 changes: 1 addition & 1 deletion mlflow/models/model.py
@@ -1058,7 +1058,7 @@ def set_model(model):
        globals()["__mlflow_model__"] = model
        return

    for validate_function in [_validate_llama_index_model]:
        try:
            globals()["__mlflow_model__"] = validate_function(model)
            return
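
The loop above backs MLflow's model-from-code flow: a model script calls set_model, and each validator in the list gets a chance to recognize the object and coerce it into a loggable model. A usage sketch follows; the index construction is illustrative only and assumes an embedding model is configured in the environment:

# model.py: a model-from-code script that MLflow executes, picking up whatever
# object was handed to set_model. Illustrative sketch, not part of this commit.
from llama_index.core import Document, VectorStoreIndex

from mlflow.models import set_model

index = VectorStoreIndex.from_documents([Document(text="hello world")])
set_model(index)  # _validate_llama_index_model recognizes LlamaIndex objects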
95 changes: 95 additions & 0 deletions tests/llama_index/test_llama_index_tracer.py
@@ -1,5 +1,6 @@
import asyncio
import inspect
+import random
from dataclasses import asdict
from typing import List
from unittest.mock import ANY
@@ -21,12 +22,14 @@
import mlflow
import mlflow.tracking._tracking_service
from mlflow.entities.span import SpanType
+from mlflow.entities.span_status import SpanStatusCode
from mlflow.entities.trace import Trace
from mlflow.entities.trace_status import TraceStatus
from mlflow.llama_index.tracer import remove_llama_index_tracer, set_llama_index_tracer
from mlflow.tracking._tracking_service.utils import _use_tracking_uri
from mlflow.tracking.default_experiment import DEFAULT_EXPERIMENT_ID

+llama_core_version = Version(importlib_metadata.version("llama-index-core"))
llama_oai_version = Version(importlib_metadata.version("llama-index-llms-openai"))


@@ -427,3 +430,95 @@ def test_tracer_handle_tracking_uri_update(tmp_path):
    # The new trace will be logged to the updated tracking URI
    OpenAI().complete("Hello")
    assert len(_get_all_traces()) == 1


@pytest.mark.skipif(
    llama_core_version >= Version("0.11.10"),
    reason="Workflow tracing does not work correctly in >= 0.11.10 until "
    "https://github.com/run-llama/llama_index/issues/16283 is fixed",
)
@pytest.mark.skipif(
    llama_core_version < Version("0.11.0"),
    reason="Workflow was introduced in 0.11.0",
)
@pytest.mark.asyncio
async def test_tracer_simple_workflow():
    from llama_index.core.workflow import StartEvent, StopEvent, Workflow, step

    class MyWorkflow(Workflow):
        @step
        async def my_step(self, ev: StartEvent) -> StopEvent:
            return StopEvent(result="Hi, world!")

    w = MyWorkflow(timeout=10, verbose=False)
    await w.run()

    traces = _get_all_traces()
    assert len(traces) == 1
    assert traces[0].info.status == TraceStatus.OK
    assert all(s.status.status_code == SpanStatusCode.OK for s in traces[0].data.spans)


@pytest.mark.skipif(
    llama_core_version >= Version("0.11.10"),
    reason="Workflow tracing does not work correctly in >= 0.11.10 until "
    "https://github.com/run-llama/llama_index/issues/16283 is fixed",
)
@pytest.mark.skipif(
    llama_core_version < Version("0.11.0"),
    reason="Workflow was introduced in 0.11.0",
)
@pytest.mark.asyncio
async def test_tracer_parallel_workflow():
    from llama_index.core.workflow import (
        Context,
        Event,
        StartEvent,
        StopEvent,
        Workflow,
        step,
    )

    class ProcessEvent(Event):
        data: str

    class ResultEvent(Event):
        result: str

    class ParallelWorkflow(Workflow):
        @step
        async def start(self, ctx: Context, ev: StartEvent) -> ProcessEvent:
            await ctx.set("num_to_collect", len(ev.inputs))
            for item in ev.inputs:
                ctx.send_event(ProcessEvent(data=item))
            return None

        @step(num_workers=3)
        async def process_data(self, ev: ProcessEvent) -> ResultEvent:
            # Simulate some time-consuming processing
            await asyncio.sleep(random.randint(1, 2))
            return ResultEvent(result=ev.data)

        @step
        async def combine_results(self, ctx: Context, ev: ResultEvent) -> StopEvent:
            num_to_collect = await ctx.get("num_to_collect")
            results = ctx.collect_events(ev, [ResultEvent] * num_to_collect)
            if results is None:
                return None

            combined_result = ", ".join(sorted([event.result for event in results]))
            return StopEvent(result=combined_result)

    w = ParallelWorkflow()
    result = await w.run(inputs=["apple", "grape", "orange", "banana"])
    assert result == "apple, banana, grape, orange"

    traces = _get_all_traces()
    assert len(traces) == 1
    assert traces[0].info.status == TraceStatus.OK
    for s in traces[0].data.spans:
        assert s.status.status_code == SpanStatusCode.OK

    root_span = traces[0].data.spans[0]
    assert root_span.inputs == {"kwargs": {"inputs": ["apple", "grape", "orange", "banana"]}}
    assert root_span.outputs == "apple, banana, grape, orange"
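
Outside the test harness, the same workflow traces can be produced with autologging instead of installing the tracer manually through set_llama_index_tracer. A sketch, assuming mlflow.llama_index.autolog() and mlflow.search_traces() behave as in MLflow 2.x:

import asyncio

import mlflow
from llama_index.core.workflow import StartEvent, StopEvent, Workflow, step

mlflow.llama_index.autolog()  # registers MLflow's tracer for LlamaIndex


class EchoWorkflow(Workflow):
    @step
    async def reply(self, ev: StartEvent) -> StopEvent:
        return StopEvent(result=f"echo: {ev.message}")


async def main() -> None:
    print(await EchoWorkflow(timeout=10).run(message="hi"))  # echo: hi


asyncio.run(main())
print(mlflow.search_traces(max_results=1))  # the run appears as an OK trace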
