Skip to content

Commit

Permalink
Add set_destination API (mlflow#14249)
Browse files Browse the repository at this point in the history
Signed-off-by: B-Step62 <[email protected]>
Signed-off-by: k99kurella <[email protected]>
  • Loading branch information
B-Step62 authored and karthikkurella committed Jan 30, 2025
1 parent 97b606d commit c308a4b
Show file tree
Hide file tree
Showing 13 changed files with 410 additions and 40 deletions.
3 changes: 3 additions & 0 deletions docs/source/python_api/mlflow.tracing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ mlflow.tracing
:members:
:undoc-members:
:show-inheritance:

.. automodule:: mlflow.tracing.destination
:members:
4 changes: 3 additions & 1 deletion mlflow/tracing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from mlflow.tracing.display import disable_notebook_display, enable_notebook_display
from mlflow.tracing.provider import disable, enable
from mlflow.tracing.provider import disable, enable, reset, set_destination
from mlflow.tracing.utils import set_span_chat_messages, set_span_chat_tools

__all__ = [
Expand All @@ -9,4 +9,6 @@
"enable_notebook_display",
"set_span_chat_messages",
"set_span_chat_tools",
"set_destination",
"reset",
]
39 changes: 39 additions & 0 deletions mlflow/tracing/destination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from dataclasses import dataclass
from typing import Optional

from mlflow.utils.annotations import experimental


@experimental
@dataclass
class TraceDestination:
    """Base configuration describing where MLflow should send trace data.

    Concrete subclasses (e.g. an MLflow experiment destination) identify
    themselves through the :attr:`type` property.
    """

    @property
    def type(self) -> str:
        """Short string identifier of this destination kind.

        Raises:
            NotImplementedError: always, on the base class; subclasses override.
        """
        raise NotImplementedError


@experimental
@dataclass
class MlflowExperiment(TraceDestination):
    """Trace destination that targets an MLflow experiment.

    Pass an instance of this class to :py:func:`mlflow.tracing.set_destination`
    to direct trace logging to the specified experiment.

    Attributes:
        experiment_id: The ID of the experiment to log traces to. If not specified,
            the current active experiment will be used.
        tracking_uri: The tracking URI of the MLflow server to log traces to.
            If not specified, the current tracking URI will be used.
    """

    experiment_id: Optional[str] = None
    tracking_uri: Optional[str] = None

    @property
    def type(self) -> str:
        # Destination discriminator consumed by the tracing provider.
        return "experiment"
52 changes: 52 additions & 0 deletions mlflow/tracing/export/databricks_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import logging
from typing import Sequence

from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanExporter

from mlflow.deployments import get_deploy_client
from mlflow.tracing.destination import TraceDestination
from mlflow.tracing.trace_manager import InMemoryTraceManager

_logger = logging.getLogger(__name__)


class DatabricksAgentSpanExporter(SpanExporter):
    """
    An exporter implementation that logs the traces to Databricks Agent Monitoring.

    Traces are serialized to JSON and sent to a model-serving endpoint, which
    writes them to an Inference Table.

    Args:
        trace_destination: The destination of the traces.

    TODO: This class will be migrated under databricks-agents package.
    """

    def __init__(self, trace_destination: TraceDestination):
        # NOTE(review): assumes the concrete destination exposes a
        # `databricks_monitor_id` attribute — it is not declared on the
        # TraceDestination base class; confirm against the caller.
        self._databricks_monitor_id = trace_destination.databricks_monitor_id
        self._trace_manager = InMemoryTraceManager.get_instance()
        self._deploy_client = get_deploy_client("databricks")

    def export(self, root_spans: Sequence[ReadableSpan]):
        """
        Export the spans to the destination.

        Args:
            root_spans: A sequence of OpenTelemetry ReadableSpan objects to be exported.
                Only root spans for each trace are passed to this method.
        """
        for span in root_spans:
            # Defensive guard: only root spans (no parent) carry a complete
            # trace; anything else is skipped silently.
            if span._parent is not None:
                _logger.debug("Received a non-root span. Skipping export.")
                continue

            # Remove the fully-assembled trace from the in-memory buffer so it
            # is not exported twice.
            trace = self._trace_manager.pop_trace(span.context.trace_id)
            if trace is None:
                _logger.debug(f"Trace for span {span} not found. Skipping export.")
                continue

            # Traces are exported via a serving endpoint that accepts trace JSON as
            # an input payload, and then will be written to the Inference Table.
            self._deploy_client.predict(
                endpoint=self._databricks_monitor_id,
                inputs={"inputs": [trace.to_json()]},
            )
94 changes: 94 additions & 0 deletions mlflow/tracing/processor/databricks_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import json
import logging
import uuid
from typing import Optional

from opentelemetry.context import Context
from opentelemetry.sdk.trace import ReadableSpan as OTelReadableSpan
from opentelemetry.sdk.trace import Span as OTelSpan
from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter

from mlflow.entities.trace_info import TraceInfo
from mlflow.entities.trace_status import TraceStatus
from mlflow.tracing.constant import TRACE_SCHEMA_VERSION, TRACE_SCHEMA_VERSION_KEY, SpanAttributeKey
from mlflow.tracing.trace_manager import InMemoryTraceManager
from mlflow.tracing.utils import (
deduplicate_span_names_in_place,
get_otel_attribute,
maybe_get_dependencies_schemas,
)

_logger = logging.getLogger(__name__)


class DatabricksAgentSpanProcessor(SimpleSpanProcessor):
    """
    Defines custom hooks to be executed when a span is started or ended (before exporting).

    This process implements simple responsibilities to generate MLflow-style trace
    object from OpenTelemetry spans and store them in memory.

    TODO: This class will be migrated under databricks-agents package.
    """

    def __init__(self, span_exporter: SpanExporter):
        # Exporter invoked by the parent SimpleSpanProcessor when spans end.
        self.span_exporter = span_exporter
        self._trace_manager = InMemoryTraceManager.get_instance()

    def on_start(self, span: OTelSpan, parent_context: Optional[Context] = None):
        """
        Handle the start of a span. This method is called when an OpenTelemetry span is started.

        Args:
            span: An OpenTelemetry Span object that is started.
            parent_context: The context of the span. Note that this is only passed when the context
                object is explicitly specified to OpenTelemetry start_span call. If the parent
                span is obtained from the global context, it won't be passed here so we should not
                rely on it.
        """

        request_id = self._create_or_get_request_id(span)
        # Persist the MLflow request ID on the span itself (JSON-encoded) so it
        # can be recovered in on_end via get_otel_attribute.
        span.set_attribute(SpanAttributeKey.REQUEST_ID, json.dumps(request_id))

        tags = {}
        if dependencies_schema := maybe_get_dependencies_schemas():
            tags.update(dependencies_schema)

        # A root span (no parent) starts a new trace: register an in-progress
        # TraceInfo keyed by the OTel trace ID so child spans can look it up.
        if span._parent is None:
            trace_info = TraceInfo(
                request_id=request_id,
                experiment_id=None,
                timestamp_ms=span.start_time // 1_000_000,  # nanosecond to millisecond
                execution_time_ms=None,
                status=TraceStatus.IN_PROGRESS,
                request_metadata={TRACE_SCHEMA_VERSION_KEY: str(TRACE_SCHEMA_VERSION)},
                tags=tags,
            )
            self._trace_manager.register_trace(span.context.trace_id, trace_info)

    def _create_or_get_request_id(self, span: OTelSpan) -> str:
        # Root spans mint a fresh "tr-" prefixed ID; child spans reuse the ID
        # registered by their root via the trace manager.
        if span._parent is None:
            return "tr-" + uuid.uuid4().hex
        else:
            return self._trace_manager.get_request_id_from_trace_id(span.context.trace_id)

    def on_end(self, span: OTelReadableSpan) -> None:
        """
        Handle the end of a span. This method is called when an OpenTelemetry span is ended.

        Args:
            span: An OpenTelemetry ReadableSpan object that is ended.
        """
        # Processing the trace only when it is a root span.
        if span._parent is None:
            request_id = get_otel_attribute(span, SpanAttributeKey.REQUEST_ID)
            with self._trace_manager.get_trace(request_id) as trace:
                if trace is None:
                    _logger.debug(f"Trace data with request ID {request_id} not found.")
                    return

                # Finalize timing/status on the aggregated trace before export.
                trace.info.execution_time_ms = (span.end_time - span.start_time) // 1_000_000
                trace.info.status = TraceStatus.from_otel_status(span.status)
                deduplicate_span_names_in_place(list(trace.span_dict.values()))

        # Delegate to SimpleSpanProcessor, which forwards the span to the exporter.
        super().on_end(span)
61 changes: 33 additions & 28 deletions mlflow/tracing/processor/mlflow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
import time
from typing import Any, Optional
from typing import Optional

from opentelemetry.context import Context
from opentelemetry.sdk.trace import ReadableSpan as OTelReadableSpan
Expand Down Expand Up @@ -44,9 +44,15 @@ class MlflowSpanProcessor(SimpleSpanProcessor):
This processor is used when the tracing destination is MLflow Tracking Server.
"""

def __init__(self, span_exporter: SpanExporter, client: Optional[MlflowClient] = None):
def __init__(
self,
span_exporter: SpanExporter,
client: Optional[MlflowClient] = None,
experiment_id: Optional[str] = None,
):
self.span_exporter = span_exporter
self._client = client or MlflowClient()
self._experiment_id = experiment_id
self._trace_manager = InMemoryTraceManager.get_instance()

# We issue a warning when a trace is created under the default experiment.
Expand Down Expand Up @@ -87,8 +93,8 @@ def on_start(self, span: OTelSpan, parent_context: Optional[Context] = None):
def _start_trace(self, span: OTelSpan, start_time_ns: Optional[int]) -> TraceInfo:
from mlflow.tracking.fluent import _get_latest_active_run

experiment_id = get_otel_attribute(span, SpanAttributeKey.EXPERIMENT_ID)
metadata = {TRACE_SCHEMA_VERSION_KEY: str(TRACE_SCHEMA_VERSION)}

# If the span is started within an active MLflow run, we should record it as a trace tag
# Note `mlflow.active_run()` can only get thread-local active run,
# but tracing routine might be applied to model inference worker threads
Expand All @@ -99,14 +105,8 @@ def _start_trace(self, span: OTelSpan, start_time_ns: Optional[int]) -> TraceInf
# all threads and set it as the tracing source run.
if run := _get_latest_active_run():
metadata[TraceMetadataKey.SOURCE_RUN] = run.info.run_id
if experiment_id is None:
# if we're inside a run, the run's experiment id should
# take precedence over the environment experiment id
experiment_id = run.info.experiment_id

if experiment_id is None:
experiment_id = _get_experiment_id()

experiment_id = self._get_experiment_id_for_trace(span)
if experiment_id == DEFAULT_EXPERIMENT_ID and not self._issued_default_exp_warning:
_logger.warning(
"Creating a trace within the default experiment with id "
Expand Down Expand Up @@ -169,6 +169,29 @@ def on_end(self, span: OTelReadableSpan) -> None:

super().on_end(span)

def _get_experiment_id_for_trace(self, span: OTelReadableSpan) -> str:
    """
    Determine the experiment ID to associate with the trace.

    The experiment ID can be configured in multiple ways, in order of precedence:

        1. An experiment ID specified via the span creation API i.e. MlflowClient().start_trace()
        2. An experiment ID specified via the processor constructor
        3. An experiment ID of an active run.
        4. The default experiment ID
    """
    from mlflow.tracking.fluent import _get_latest_active_run

    # Highest precedence: an ID explicitly attached to the span at creation time.
    span_experiment_id = get_otel_attribute(span, SpanAttributeKey.EXPERIMENT_ID)
    if span_experiment_id:
        return span_experiment_id

    # Next: the ID this processor instance was configured with.
    if self._experiment_id:
        return self._experiment_id

    # Next: inherit the experiment of the most recent active run, if any.
    active_run = _get_latest_active_run()
    if active_run:
        return active_run.info.experiment_id

    # Fall back to the environment/default experiment.
    return _get_experiment_id()

def _update_trace_info(self, trace: _Trace, root_span: OTelReadableSpan):
"""Update the trace info with the final values from the root span."""
# The trace/span start time needs adjustment to exclude the latency of
Expand Down Expand Up @@ -197,21 +220,3 @@ def _truncate_metadata(self, value: Optional[str]) -> str:
trunc_length = MAX_CHARS_IN_TRACE_INFO_METADATA_AND_TAGS - len(TRUNCATION_SUFFIX)
value = value[:trunc_length] + TRUNCATION_SUFFIX
return value

def _create_trace_info(
self,
request_id: str,
span: OTelSpan,
experiment_id: Optional[str] = None,
request_metadata: Optional[dict[str, Any]] = None,
tags: Optional[dict[str, str]] = None,
) -> TraceInfo:
return TraceInfo(
request_id=request_id,
experiment_id=experiment_id,
timestamp_ms=span.start_time // 1_000_000, # nanosecond to millisecond
execution_time_ms=None,
status=TraceStatus.IN_PROGRESS,
request_metadata=request_metadata or {},
tags=tags or {},
)
Loading

0 comments on commit c308a4b

Please sign in to comment.