Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Lok <[email protected]>
  • Loading branch information
daniellok-db committed Dec 18, 2024
1 parent 2e06c97 commit 8d7ca9b
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 34 deletions.
4 changes: 2 additions & 2 deletions mlflow/entities/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def _repr_mimebundle_(self, include=None, exclude=None):
in Databricks notebooks to display the Trace object in a nicer UI.
"""
from mlflow.tracing.display import (
_get_notebook_iframe_html,
get_display_handler,
get_notebook_iframe_html,
is_using_tracking_server,
)
from mlflow.utils.databricks_utils import is_in_databricks_runtime
Expand All @@ -95,7 +95,7 @@ def _repr_mimebundle_(self, include=None, exclude=None):
if is_in_databricks_runtime():
bundle["application/databricks.mlflow.trace"] = self._serialize_for_mimebundle()
elif is_using_tracking_server():
bundle["text/html"] = _get_notebook_iframe_html([self])
bundle["text/html"] = get_notebook_iframe_html([self])

return bundle

Expand Down
4 changes: 2 additions & 2 deletions mlflow/tracing/display/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from mlflow.tracing.display.display_handler import (
IPythonTraceDisplayHandler,
_get_notebook_iframe_html,
get_notebook_iframe_html,
is_using_tracking_server,
)

__all__ = [
"IPythonTraceDisplayHandler",
"get_display_handler",
"is_using_tracking_server",
"_get_notebook_iframe_html",
"get_notebook_iframe_html",
]


Expand Down
51 changes: 24 additions & 27 deletions mlflow/tracing/display/display_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@
}}
</style>
<button
onclick="const display = this.nextElementSibling.style.display;
const isCollapsed = display === 'none';
this.nextElementSibling.style.display = isCollapsed ? null : 'none';
const verb = isCollapsed ? 'Collapse' : 'Expand';
this.innerText = `${{verb}} MLflow Trace`;"
onclick="
const display = this.nextElementSibling.style.display;
const isCollapsed = display === 'none';
this.nextElementSibling.style.display = isCollapsed ? null : 'none';
const verb = isCollapsed ? 'Collapse' : 'Expand';
this.innerText = `${{verb}} MLflow Trace`;
"
>Collapse MLflow Trace</button>
<iframe
id="trace-renderer"
Expand All @@ -52,7 +55,7 @@
"""


def _get_notebook_iframe_html(traces: list["Trace"]):
def get_notebook_iframe_html(traces: list["Trace"]):
# fetch assets from tracking server
uri = urljoin(mlflow.get_tracking_uri(), TRACE_RENDERER_ASSET_PATH)
query_string = _get_query_string_for_traces(traces)
Expand Down Expand Up @@ -82,7 +85,7 @@ def _get_query_string_for_traces(traces: list["Trace"]):
return urlencode(query_params)


def is_jupyter():
def _is_jupyter():
try:
from IPython import get_ipython

Expand All @@ -95,11 +98,11 @@ def is_using_tracking_server():
return is_http_uri(mlflow.get_tracking_uri())


def validate_environment():
def is_trace_ui_available():
# the notebook display feature only works in
# Databricks notebooks, or in Jupyter notebooks
# with a tracking server
return is_jupyter() and (is_in_databricks_runtime() or is_using_tracking_server())
return _is_jupyter() and (is_in_databricks_runtime() or is_using_tracking_server())


class IPythonTraceDisplayHandler:
Expand All @@ -124,7 +127,7 @@ def enable(cls):

def __init__(self):
self.traces_to_display = {}
if not is_jupyter():
if not _is_jupyter():
return

try:
Expand All @@ -143,28 +146,20 @@ def __init__(self):
_logger.debug("Failed to register post-run cell display hook", exc_info=True)

def _display_traces_post_run(self, result):
if self.disabled or not validate_environment():
if self.disabled or not is_trace_ui_available():
self.traces_to_display = {}
return

try:
from IPython.display import HTML, display
from IPython.display import display

MAX_TRACES_TO_DISPLAY = MLFLOW_MAX_TRACES_TO_DISPLAY_IN_NOTEBOOK.get()
traces_to_display = list(self.traces_to_display.values())[:MAX_TRACES_TO_DISPLAY]
if len(traces_to_display) == 0:
self.traces_to_display = {}
return

if is_in_databricks_runtime():
display(
self.get_databricks_mimebundle(traces_to_display),
display_id=True,
raw=True,
)
else:
html = HTML(_get_notebook_iframe_html(traces_to_display))
display(html)
display(self.get_mimebundle(traces_to_display), raw=True)

# reset state
self.traces_to_display = {}
Expand All @@ -176,17 +171,19 @@ def _display_traces_post_run(self, result):
_logger.error("Failed to display traces", exc_info=True)
self.traces_to_display = {}

def get_databricks_mimebundle(self, traces: list["Trace"]):
def get_mimebundle(self, traces: list["Trace"]):
if len(traces) == 1:
return traces[0]._repr_mimebundle_()
else:
return {
"application/databricks.mlflow.trace": _serialize_trace_list(traces),
"text/plain": repr(traces),
}
bundle = {"text/plain": repr(traces)}
if is_in_databricks_runtime():
bundle["application/databricks.mlflow.trace"] = _serialize_trace_list(traces)
else:
bundle["text/html"] = get_notebook_iframe_html(traces)
return bundle

def display_traces(self, traces: list["Trace"]):
if self.disabled or not validate_environment():
if self.disabled or not is_trace_ui_available():
return

try:
Expand Down
6 changes: 3 additions & 3 deletions tests/tracing/display/test_ipython.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import mlflow
from mlflow.tracing.display import (
IPythonTraceDisplayHandler,
_get_notebook_iframe_html,
get_display_handler,
get_notebook_iframe_html,
)

from tests.tracing.helper import create_trace
Expand Down Expand Up @@ -248,7 +248,7 @@ def test_mimebundle_in_oss():
mlflow.set_tracking_uri("http://localhost:5000")
assert trace._repr_mimebundle_() == {
"text/plain": repr(trace),
"text/html": _get_notebook_iframe_html([trace]),
"text/html": get_notebook_iframe_html([trace]),
}

# disabling should remove this key, even if tracking server is used
Expand Down Expand Up @@ -283,4 +283,4 @@ def test_display_in_oss(monkeypatch):
mock_ipython.mock_run_cell()

assert mock_display.call_count == 1
assert "<iframe" in mock_display.call_args[0][0]
assert "<iframe" in mock_display.call_args[0][0]["text/html"]

0 comments on commit 8d7ca9b

Please sign in to comment.