Skip to content

Commit

Permalink
Implements MLFlowLogger (#2365)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-az authored Feb 12, 2025
1 parent 7f3e70e commit a965fb0
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/api_ref_training.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ Various logging utilities.
metric_logging.TensorBoardLogger
metric_logging.StdoutLogger
metric_logging.DiskLogger
metric_logging.MLFlowLogger

.. _perf_profiling_label:

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ tune = "torchtune._cli.tune:main"
dev = [
"bitsandbytes>=0.43.0",
"comet_ml>=3.44.2",
"mlflow",
"pre-commit",
"pytest==7.4.0",
"pytest-cov",
Expand Down
59 changes: 59 additions & 0 deletions tests/torchtune/training/test_metric_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from typing import cast
from unittest.mock import patch

import mlflow

import mlflow.artifacts
import pytest
from omegaconf import OmegaConf
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
Expand All @@ -19,6 +22,7 @@
from torchtune.training.metric_logging import (
CometLogger,
DiskLogger,
MLFlowLogger,
StdoutLogger,
TensorBoardLogger,
WandBLogger,
Expand Down Expand Up @@ -199,3 +203,58 @@ def test_log_config(self) -> None:
cfg = OmegaConf.create({"a": 1, "b": 2})
logger.log_config(cfg)
mock_experiment.return_value.log_parameters.assert_called_with(cfg)


@pytest.fixture(scope="class")
def mlflow_context_fixture():
    """Redirect mlflow to a temporary file-backed tracking store for the test class.

    Saves the current tracking URI, points mlflow at a throwaway directory for the
    duration of the class, and restores the original URI afterwards.
    """
    original_uri = mlflow.get_tracking_uri()

    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            mlflow.set_tracking_uri(f"file:{tmpdir}")
            yield
    finally:
        # Restore the original URI even if an exception is thrown into the
        # generator during teardown — otherwise the global tracking URI would
        # keep pointing at a deleted temporary directory for later tests.
        mlflow.set_tracking_uri(original_uri)


@pytest.mark.usefixtures("mlflow_context_fixture")
class TestMLFlowLogger:
    """Exercises MLFlowLogger against the file-backed tracking store provided
    by ``mlflow_context_fixture``."""

    def test_log(self):
        metric_logger = MLFlowLogger(experiment_name="my_experiment", run_name="run1")
        tracked_run_id = metric_logger._run_id
        metric_logger.log("test_metric", 1.0, step=1)
        metric_logger.close()

        # The single logged metric should be queryable from the finished run.
        finished_run = mlflow.get_run(tracked_run_id)
        assert finished_run.data.metrics == {"test_metric": 1}

    def test_log_dict(self):
        metric_logger = MLFlowLogger(experiment_name="my_experiment", run_name="run2")
        tracked_run_id = metric_logger._run_id
        expected_metrics = {f"log_dict_{i}": float(i) ** 2 for i in range(5)}
        metric_logger.log_dict(expected_metrics, step=2)
        metric_logger.close()

        # All metrics logged in one batch should round-trip unchanged.
        finished_run = mlflow.get_run(tracked_run_id)
        assert finished_run.data.metrics == expected_metrics

    def test_log_config(self) -> None:
        with tempfile.TemporaryDirectory() as output_dir:
            cfg = OmegaConf.create(
                {"foo": {"bar": "baz"}, "qux": "quux", "output_dir": output_dir}
            )
            metric_logger = MLFlowLogger(experiment_name="my_experiment", run_name="run2")
            tracked_run_id = metric_logger._run_id

            metric_logger.log_config(cfg)

            # Nested keys are flattened with "." before being stored as params.
            expected_params = {"foo.bar": "baz", "qux": "quux", "output_dir": output_dir}

            finished_run = mlflow.get_run(tracked_run_id)
            assert finished_run.data.params == expected_params

            # Exactly one artifact — the saved config file — should be present.
            logged_artifacts = mlflow.artifacts.list_artifacts(
                run_id=tracked_run_id, artifact_path=output_dir.lstrip("/")
            )
            assert len(logged_artifacts) == 1
            assert logged_artifacts[0].path.endswith("torchtune_config.yaml")
139 changes: 139 additions & 0 deletions torchtune/training/metric_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,34 @@ def save_config(config: DictConfig) -> Path:
log.warning(f"Error saving config.\nError: \n{e}.")


def flatten_dict(d: Dict[str, Any], *, sep: str = ".", parent_key: str = ""):
    """Recursively flattens a nested dictionary into one level of key-value pairs.

    Args:
        d (Dict[str, Any]): Any dictionary to flatten.
        sep (str, optional): Desired separator for flattening nested keys. Defaults to ".".
        parent_key (str, optional): Key prefix for children (nested keys), containing parent key names. Defaults to "".

    Example:
        >>> flatten_dict({"foo": {"bar": "baz"}, "qux": "quux"}, sep="--")
        {"foo--bar": "baz", "qux": "quux"}

    Returns:
        Dict[str, Any]: Flattened dictionary.

    Note:
        Does not unnest dictionaries within list values (i.e., {"foo": [{"bar": "baz"}]}).
    """
    flattened: Dict[str, Any] = {}
    for key, value in d.items():
        full_key = key if not parent_key else f"{parent_key}{sep}{key}"
        if isinstance(value, dict):
            # Merge the recursively flattened child dict under its prefixed keys.
            flattened.update(flatten_dict(value, sep=sep, parent_key=full_key))
        else:
            flattened[full_key] = value
    return flattened


class MetricLoggerInterface(Protocol):
"""Abstract metric logger."""

Expand Down Expand Up @@ -469,3 +497,114 @@ def close(self) -> None:

def __del__(self) -> None:
self.close()


class MLFlowLogger(MetricLoggerInterface):
    """Logger for use w/ MLFlow (https://mlflow.org/).

    Args:
        experiment_name (Optional[str]): MLFlow experiment name. If not specified, will
            default to MLFLOW_EXPERIMENT_NAME environment variable if set, or default.
        tracking_uri (Optional[str]): MLFlow tracking uri. If not specified, will default
            to MLFLOW_TRACKING_URI environment variable if set, or default.
        run_id (Optional[str]): MLFlow run ID to resume. If not specified, will default
            to MLFLOW_RUN_ID environment variable if set, or a new run will be created.
        run_name (Optional[str]): MLFlow run name. If not specified, will default
            to an mlflow-generated HRID. Unused if run_id is specified or MLFLOW_RUN_ID
            environment variable is found.

    Example:
        >>> logger = MLFlowLogger(experiment_name="my_experiment", run_name="run1")
        >>> logger.log("accuracy", 0.95, step=1)
        >>> logger.log_dict({"loss": 0.1, "accuracy": 0.95}, step=1)
        >>> logger.log_config(config)
        >>> logger.close()

    Raises:
        ImportError: If ``mlflow`` package is not installed.

    Note:
        This logger requires the mlflow package to be installed.
        You can install it with `pip install mlflow`.
    """

    def __init__(
        self,
        experiment_name: Optional[str] = None,
        tracking_uri: Optional[str] = None,
        run_id: Optional[str] = None,
        run_name: Optional[str] = None,
    ):
        try:
            import mlflow
        except ImportError as e:
            raise ImportError(
                "``mlflow`` package not found. Please install mlflow using `pip install mlflow` to use MLFlowLogger."
                "Alternatively, use the ``StdoutLogger``, which can be specified by setting metric_logger_type='stdout'."
            ) from e

        _, self.rank = get_world_size_and_rank()

        self._mlflow = mlflow

        # Explicit arguments take precedence over the corresponding env vars.
        self._tracking_uri = tracking_uri or os.getenv("MLFLOW_TRACKING_URI")
        self._experiment_name = experiment_name or os.getenv("MLFLOW_EXPERIMENT_NAME")
        self._run_id = run_id or os.getenv("MLFLOW_RUN_ID")

        # Only rank 0 talks to the tracking server; log calls on other ranks
        # are no-ops because they never hold an active run.
        if self.rank == 0:
            if not self._mlflow.is_tracking_uri_set() and self._tracking_uri is not None:
                self._mlflow.set_tracking_uri(self._tracking_uri)

            # Start a run unless one is already active and no explicit run was
            # requested. (Previously this condition also read an undefined
            # ``self._nested_run`` attribute, raising AttributeError whenever a
            # run was already active at construction time.)
            if self._mlflow.active_run() is None or self._run_id is not None:
                if self._experiment_name is not None:
                    # Use of set_experiment() ensures the Experiment is created
                    # if it does not already exist.
                    self._mlflow.set_experiment(self._experiment_name)
                # Passing run_id resumes the requested run per the documented
                # contract; when run_id is None a new run is created and its
                # generated ID is recorded for subsequent logging calls.
                run = self._mlflow.start_run(run_id=self._run_id, run_name=run_name)
                self._run_id = run.info.run_id

    def log_config(self, config: DictConfig) -> None:
        """Saves the config locally and also logs the config to mlflow. The config is
        stored in the same directory as the checkpoint.

        Args:
            config (DictConfig): config to log
        """
        if self._mlflow.active_run():
            resolved = OmegaConf.to_container(config, resolve=True)

            # mlflow's params must be flat key-value pairs
            config_as_params = flatten_dict(resolved)
            self._mlflow.log_params(config_as_params, run_id=self._run_id)

            output_config_fname = save_config(config)

            # this avoids break if config's output_dir is an absolute path
            artifact_path = str(output_config_fname.parent).lstrip("/")

            self._mlflow.log_artifact(
                output_config_fname,
                artifact_path=artifact_path,
                run_id=self._run_id,
            )

    def log(self, name: str, data: Scalar, step: int) -> None:
        """Log a single scalar metric to the active run.

        Args:
            name (str): metric name
            data (Scalar): metric value
            step (int): step at which the metric was recorded
        """
        if self._mlflow.active_run():
            self._mlflow.log_metric(name, data, step=step, run_id=self._run_id)

    def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
        """Log multiple scalar metrics to the active run in one call.

        Args:
            payload (Mapping[str, Scalar]): mapping of metric name to value
            step (int): step at which the metrics were recorded
        """
        if self._mlflow.active_run():
            self._mlflow.log_metrics(payload, step=step, run_id=self._run_id)

    def close(self) -> None:
        """
        Ends the MLflow run.
        After calling close, no further logging should be performed.
        """
        if self.rank == 0 and self._mlflow.active_run():
            self._mlflow.end_run()

    def __del__(self) -> None:
        # Ensure the MLflow run is closed when the logger is deleted.
        if hasattr(self, "_mlflow") and self._mlflow.active_run():
            self._mlflow.end_run()

0 comments on commit a965fb0

Please sign in to comment.