forked from mlflow/mlflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
set_destination
API (mlflow#14249)
Signed-off-by: B-Step62 <[email protected]> Signed-off-by: k99kurella <[email protected]>
- Loading branch information
1 parent
4a9e2b6
commit 575fe00
Showing
13 changed files
with
410 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
from mlflow.utils.annotations import experimental | ||
|
||
|
||
@experimental | ||
@dataclass | ||
class TraceDestination: | ||
"""A configuration object for specifying the destination of trace data.""" | ||
|
||
@property | ||
def type(self) -> str: | ||
"""Type of the destination.""" | ||
raise NotImplementedError | ||
|
||
|
||
@experimental | ||
@dataclass | ||
class MlflowExperiment(TraceDestination): | ||
""" | ||
A destination representing an MLflow experiment. | ||
By setting this destination in the :py:func:`mlflow.tracing.set_destination` function, | ||
MLflow will log traces to the specified experiment. | ||
Attributes: | ||
experiment_id: The ID of the experiment to log traces to. If not specified, | ||
the current active experiment will be used. | ||
tracking_uri: The tracking URI of the MLflow server to log traces to. | ||
If not specified, the current tracking URI will be used. | ||
""" | ||
|
||
experiment_id: Optional[str] = None | ||
tracking_uri: Optional[str] = None | ||
|
||
@property | ||
def type(self) -> str: | ||
return "experiment" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import logging | ||
from typing import Sequence | ||
|
||
from opentelemetry.sdk.trace import ReadableSpan | ||
from opentelemetry.sdk.trace.export import SpanExporter | ||
|
||
from mlflow.deployments import get_deploy_client | ||
from mlflow.tracing.destination import TraceDestination | ||
from mlflow.tracing.trace_manager import InMemoryTraceManager | ||
|
||
_logger = logging.getLogger(__name__) | ||
|
||
|
||
class DatabricksAgentSpanExporter(SpanExporter): | ||
""" | ||
An exporter implementation that logs the traces to Databricks Agent Monitoring. | ||
Args: | ||
trace_destination: The destination of the traces. | ||
TODO: This class will be migrated under databricks-agents package. | ||
""" | ||
|
||
def __init__(self, trace_destination: TraceDestination): | ||
self._databricks_monitor_id = trace_destination.databricks_monitor_id | ||
self._trace_manager = InMemoryTraceManager.get_instance() | ||
self._deploy_client = get_deploy_client("databricks") | ||
|
||
def export(self, root_spans: Sequence[ReadableSpan]): | ||
""" | ||
Export the spans to the destination. | ||
Args: | ||
root_spans: A sequence of OpenTelemetry ReadableSpan objects to be exported. | ||
Only root spans for each trace are passed to this method. | ||
""" | ||
for span in root_spans: | ||
if span._parent is not None: | ||
_logger.debug("Received a non-root span. Skipping export.") | ||
continue | ||
|
||
trace = self._trace_manager.pop_trace(span.context.trace_id) | ||
if trace is None: | ||
_logger.debug(f"Trace for span {span} not found. Skipping export.") | ||
continue | ||
|
||
# Traces are exported via a serving endpoint that accepts trace JSON as | ||
# an input payload, and then will be written to the Inference Table. | ||
self._deploy_client.predict( | ||
endpoint=self._databricks_monitor_id, | ||
inputs={"inputs": [trace.to_json()]}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import json | ||
import logging | ||
import uuid | ||
from typing import Optional | ||
|
||
from opentelemetry.context import Context | ||
from opentelemetry.sdk.trace import ReadableSpan as OTelReadableSpan | ||
from opentelemetry.sdk.trace import Span as OTelSpan | ||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter | ||
|
||
from mlflow.entities.trace_info import TraceInfo | ||
from mlflow.entities.trace_status import TraceStatus | ||
from mlflow.tracing.constant import TRACE_SCHEMA_VERSION, TRACE_SCHEMA_VERSION_KEY, SpanAttributeKey | ||
from mlflow.tracing.trace_manager import InMemoryTraceManager | ||
from mlflow.tracing.utils import ( | ||
deduplicate_span_names_in_place, | ||
get_otel_attribute, | ||
maybe_get_dependencies_schemas, | ||
) | ||
|
||
_logger = logging.getLogger(__name__) | ||
|
||
|
||
class DatabricksAgentSpanProcessor(SimpleSpanProcessor): | ||
""" | ||
Defines custom hooks to be executed when a span is started or ended (before exporting). | ||
This process implements simple responsibilities to generate MLflow-style trace | ||
object from OpenTelemetry spans and store them in memory. | ||
TODO: This class will be migrated under databricks-agents package. | ||
""" | ||
|
||
def __init__(self, span_exporter: SpanExporter): | ||
self.span_exporter = span_exporter | ||
self._trace_manager = InMemoryTraceManager.get_instance() | ||
|
||
def on_start(self, span: OTelSpan, parent_context: Optional[Context] = None): | ||
""" | ||
Handle the start of a span. This method is called when an OpenTelemetry span is started. | ||
Args: | ||
span: An OpenTelemetry Span object that is started. | ||
parent_context: The context of the span. Note that this is only passed when the context | ||
object is explicitly specified to OpenTelemetry start_span call. If the parent | ||
span is obtained from the global context, it won't be passed here so we should not | ||
rely on it. | ||
""" | ||
|
||
request_id = self._create_or_get_request_id(span) | ||
span.set_attribute(SpanAttributeKey.REQUEST_ID, json.dumps(request_id)) | ||
|
||
tags = {} | ||
if dependencies_schema := maybe_get_dependencies_schemas(): | ||
tags.update(dependencies_schema) | ||
|
||
if span._parent is None: | ||
trace_info = TraceInfo( | ||
request_id=request_id, | ||
experiment_id=None, | ||
timestamp_ms=span.start_time // 1_000_000, # nanosecond to millisecond | ||
execution_time_ms=None, | ||
status=TraceStatus.IN_PROGRESS, | ||
request_metadata={TRACE_SCHEMA_VERSION_KEY: str(TRACE_SCHEMA_VERSION)}, | ||
tags=tags, | ||
) | ||
self._trace_manager.register_trace(span.context.trace_id, trace_info) | ||
|
||
def _create_or_get_request_id(self, span: OTelSpan) -> str: | ||
if span._parent is None: | ||
return "tr-" + uuid.uuid4().hex | ||
else: | ||
return self._trace_manager.get_request_id_from_trace_id(span.context.trace_id) | ||
|
||
def on_end(self, span: OTelReadableSpan) -> None: | ||
""" | ||
Handle the end of a span. This method is called when an OpenTelemetry span is ended. | ||
Args: | ||
span: An OpenTelemetry ReadableSpan object that is ended. | ||
""" | ||
# Processing the trace only when it is a root span. | ||
if span._parent is None: | ||
request_id = get_otel_attribute(span, SpanAttributeKey.REQUEST_ID) | ||
with self._trace_manager.get_trace(request_id) as trace: | ||
if trace is None: | ||
_logger.debug(f"Trace data with request ID {request_id} not found.") | ||
return | ||
|
||
trace.info.execution_time_ms = (span.end_time - span.start_time) // 1_000_000 | ||
trace.info.status = TraceStatus.from_otel_status(span.status) | ||
deduplicate_span_names_in_place(list(trace.span_dict.values())) | ||
|
||
super().on_end(span) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.