Skip to content

Commit

Permalink
Store span names & types, input names & types as internal trace tag (m…
Browse files Browse the repository at this point in the history
…lflow#12015)

Signed-off-by: Jesse Chan <[email protected]>
  • Loading branch information
jessechancy authored May 16, 2024
1 parent bd32a79 commit f1a49d4
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 0 deletions.
1 change: 1 addition & 0 deletions mlflow/tracing/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class TraceMetadataKey:
# Names of the tag keys used on traces.
class TraceTagKey:
TRACE_NAME = "mlflow.traceName"
EVAL_REQUEST_ID = "eval.requestId"
# Tag whose value is a JSON-encoded list with one entry per span: the span's name and
# type, plus the key names of its inputs/outputs when those are non-empty dicts.
# Written by `_upload_trace_spans_as_tag` in mlflow/tracking/client.py.
TRACE_SPANS = "mlflow.traceSpans"


# A set of reserved attribute keys
Expand Down
5 changes: 5 additions & 0 deletions mlflow/tracing/export/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ def export(self, root_spans: Sequence[ReadableSpan]):
self._log_trace(trace)

def _log_trace(self, trace: Trace):
try:
self._client._upload_trace_spans_as_tag(trace.info, trace.data)
except Exception as e:
_logger.debug(f"Failed to log trace spans as tag to MLflow backend: {e}", exc_info=True)

# TODO: Make this async
# The trace is already updated in processor.on_end method
# so we just log to backend store here
Expand Down
9 changes: 9 additions & 0 deletions mlflow/tracing/trace_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,15 @@ def get_request_id_from_trace_id(self, trace_id: int) -> Optional[str]:
"""
return self._trace_id_to_request_id.get(trace_id)

def get_mlflow_trace(self, request_id: str) -> Optional[Trace]:
    """
    Look up the in-memory trace stored under ``request_id`` and convert it to a Trace.

    The entry is only read, not removed: it stays in ``self._traces`` after the call.

    Args:
        request_id: Key of the trace in ``self._traces``. This is the request ID
            (a string, as returned by ``get_request_id_from_trace_id``), not the
            integer trace ID.

    Returns:
        The result of ``to_mlflow_trace()`` on the stored entry, or None when no
        entry is registered under ``request_id``.
    """
    # Hold the lock only for the dictionary lookup; the conversion runs after release.
    with self._lock:
        trace = self._traces.get(request_id)

    return trace.to_mlflow_trace() if trace else None

def pop_trace(self, trace_id: int) -> Optional[Trace]:
"""
Pop the trace data for the given id and return it as a ready-to-publish Trace object.
Expand Down
24 changes: 24 additions & 0 deletions mlflow/tracking/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
TRACE_SCHEMA_VERSION,
TRACE_SCHEMA_VERSION_KEY,
SpanAttributeKey,
TraceTagKey,
)
from mlflow.tracing.display import get_display_handler
from mlflow.tracing.trace_manager import InMemoryTraceManager
Expand Down Expand Up @@ -650,6 +651,29 @@ def end_trace(

self.end_span(request_id, root_span_id, outputs, attributes, status)

def _upload_trace_spans_as_tag(self, trace_info: TraceInfo, trace_data: TraceData):
    """
    Summarize the spans of a trace and store the summary as the ``mlflow.traceSpans`` tag.

    Each span contributes one JSON object holding its name and span type. When the
    span's inputs (or outputs) are a non-empty dict, the dict's key names are added
    under "inputs" (or "outputs"); the values themselves are never included.

    Args:
        trace_info: Supplies the request ID of the trace that receives the tag.
        trace_data: Supplies the spans to summarize.
    """
    # Set via the SetTraceTag API whenever a trace is logged.
    # https://databricks.atlassian.net/browse/ML-40306
    keyed_fields = (
        ("inputs", SpanAttributeKey.INPUTS),
        ("outputs", SpanAttributeKey.OUTPUTS),
    )

    def summarize(span):
        summary = {
            "name": span.name,
            "type": span.get_attribute(SpanAttributeKey.SPAN_TYPE),
        }
        for field, attribute_key in keyed_fields:
            payload = span.get_attribute(attribute_key)
            # Only non-empty dicts have key names worth recording.
            if payload and isinstance(payload, dict):
                summary[field] = list(payload.keys())
        return summary

    span_summaries = [summarize(span) for span in trace_data.spans]

    # Write the tag straight to the backend for this trace.
    self._tracking_client.set_trace_tag(
        trace_info.request_id, TraceTagKey.TRACE_SPANS, json.dumps(span_summaries)
    )

@experimental
def start_span(
self,
Expand Down
75 changes: 75 additions & 0 deletions tests/tracking/test_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import pickle
import time
from unittest import mock
Expand Down Expand Up @@ -540,6 +541,8 @@ def _mock_update_trace_info(trace_info):
"mlflow.tracking._tracking_service.client.TrackingServiceClient._upload_trace_data"
) as mock_upload_trace_data, mock.patch(
"mlflow.tracking._tracking_service.client.TrackingServiceClient.set_trace_tags",
), mock.patch(
"mlflow.tracking._tracking_service.client.TrackingServiceClient.set_trace_tag",
), mock.patch(
"mlflow.tracking._tracking_service.client.TrackingServiceClient.get_trace_info",
), mock.patch(
Expand Down Expand Up @@ -1360,3 +1363,75 @@ def test_file_store_download_upload_trace_data(clear_singleton, tmp_path):
trace_data = client.get_trace(span.request_id).data
assert trace_data.request == trace.data.request
assert trace_data.response == trace.data.response


def test_store_trace_spans_tag():
    # With dict inputs/outputs, the tag must carry the span name, type and the key names.
    expected_root_span = {
        "name": "test",
        "type": "UNKNOWN",
        "inputs": ["test"],
        "outputs": ["result"],
    }
    client = MlflowClient()

    with mock.patch(
        "mlflow.tracking._tracking_service.client.TrackingServiceClient.set_trace_tag",
    ) as mock_set_trace_tag:
        root = client.start_trace("test", inputs={"test": 1})
        client.end_trace(root.request_id, outputs={"result": 2})

        mock_set_trace_tag.assert_called_once()
        positional = mock_set_trace_tag.call_args[0]
        assert positional[0] == root.request_id
        assert positional[1] == "mlflow.traceSpans"

        first_span = json.loads(positional[2])[0]
        # Compare only the expected fields; extra fields in the tag are tolerated.
        assert {key: first_span[key] for key in expected_root_span} == expected_root_span


def test_store_trace_span_tag_when_not_dict_input_outputs():
    client = MlflowClient()

    def logged_root_span(inputs, outputs):
        # Run one trace with set_trace_tag patched, check the call's fixed parts,
        # and hand back the first span entry decoded from the tag value.
        with mock.patch(
            "mlflow.tracking._tracking_service.client.TrackingServiceClient.set_trace_tag",
        ) as mock_set_trace_tag:
            root = client.start_trace("trace_name", inputs=inputs)
            client.end_trace(root.request_id, outputs=outputs)

            mock_set_trace_tag.assert_called_once()
            positional = mock_set_trace_tag.call_args[0]
            assert positional[0] == root.request_id
            assert positional[1] == "mlflow.traceSpans"

            entry = json.loads(positional[2])[0]
            assert entry["name"] == "trace_name"
            assert entry["type"] == "UNKNOWN"
            return entry

    # String inputs are skipped, dict outputs contribute their key names.
    entry = logged_root_span(inputs="test", outputs={"result": 2})
    assert entry["outputs"] == ["result"]
    assert "inputs" not in entry

    # A string that merely looks like a dict is still a string: nothing is recorded.
    entry = logged_root_span(inputs="{'input': 2}", outputs="result")
    assert "outputs" not in entry
    assert "inputs" not in entry


# When the JSON is too large, logging the tag is skipped. The trace itself should
# still be logged.
def test_store_trace_span_tag_when_exception_raised():
    client = MlflowClient()
    tag_failure = MlflowException("Failed to log parameters")

    with mock.patch(
        "mlflow.tracking._tracking_service.client.TrackingServiceClient._upload_trace_data"
    ) as mock_upload_trace_data:
        with mock.patch(
            "mlflow.tracking._tracking_service.client.TrackingServiceClient.set_trace_tag",
            side_effect=tag_failure,
        ) as mock_set_trace_tag:
            # Neither call may raise even though set_trace_tag fails.
            root = client.start_trace("trace_name", inputs={"input": "a" * 1000000})
            client.end_trace(root.request_id, outputs={"result": "b" * 1000000})

            mock_set_trace_tag.assert_called_once()
            mock_upload_trace_data.assert_called_once()

0 comments on commit f1a49d4

Please sign in to comment.