diff --git a/docs/source/api_ref_training.rst b/docs/source/api_ref_training.rst
index 747f312447..00509d4505 100644
--- a/docs/source/api_ref_training.rst
+++ b/docs/source/api_ref_training.rst
@@ -102,6 +102,7 @@ Various logging utilities.
     metric_logging.TensorBoardLogger
     metric_logging.StdoutLogger
     metric_logging.DiskLogger
+    metric_logging.MLFlowLogger
 
 .. _perf_profiling_label:
 
diff --git a/pyproject.toml b/pyproject.toml
index f94732b58a..9b5f386997 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,6 +48,7 @@ tune = "torchtune._cli.tune:main"
 dev = [
     "bitsandbytes>=0.43.0",
     "comet_ml>=3.44.2",
+    "mlflow",
     "pre-commit",
     "pytest==7.4.0",
     "pytest-cov",
diff --git a/tests/torchtune/training/test_metric_logging.py b/tests/torchtune/training/test_metric_logging.py
index 2fc29e72aa..1d4012ab55 100644
--- a/tests/torchtune/training/test_metric_logging.py
+++ b/tests/torchtune/training/test_metric_logging.py
@@ -10,6 +10,9 @@
 from typing import cast
 from unittest.mock import patch
 
+import mlflow
+
+import mlflow.artifacts
 import pytest
 from omegaconf import OmegaConf
 from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
@@ -19,6 +22,7 @@
 from torchtune.training.metric_logging import (
     CometLogger,
     DiskLogger,
+    MLFlowLogger,
     StdoutLogger,
     TensorBoardLogger,
     WandBLogger,
@@ -199,3 +203,58 @@ def test_log_config(self) -> None:
             cfg = OmegaConf.create({"a": 1, "b": 2})
             logger.log_config(cfg)
             mock_experiment.return_value.log_parameters.assert_called_with(cfg)
+
+
+@pytest.fixture(scope="class")
+def mlflow_context_fixture():
+    original_uri = mlflow.get_tracking_uri()
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        mlflow.set_tracking_uri(f"file:{tmpdir}")
+        yield
+
+    # Restore the original URI
+    mlflow.set_tracking_uri(original_uri)
+
+
+@pytest.mark.usefixtures("mlflow_context_fixture")
+class TestMLFlowLogger:
+    def test_log(self):
+        logger = MLFlowLogger(experiment_name="my_experiment", run_name="run1")
+        run_id = logger._run_id
+        logger.log("test_metric", 1.0, step=1)
+        logger.close()
+
+        run = mlflow.get_run(run_id)
+        assert run.data.metrics == {"test_metric": 1}
+
+    def test_log_dict(self):
+        logger = MLFlowLogger(experiment_name="my_experiment", run_name="run2")
+        run_id = logger._run_id
+        metric_dict = {f"log_dict_{i}": float(i) ** 2 for i in range(5)}
+        logger.log_dict(metric_dict, step=2)
+        logger.close()
+
+        run = mlflow.get_run(run_id)
+        assert run.data.metrics == metric_dict
+
+    def test_log_config(self) -> None:
+        with tempfile.TemporaryDirectory() as output_dir:
+            cfg = OmegaConf.create(
+                {"foo": {"bar": "baz"}, "qux": "quux", "output_dir": output_dir}
+            )
+            logger = MLFlowLogger(experiment_name="my_experiment", run_name="run2")
+            run_id = logger._run_id
+
+            logger.log_config(cfg)
+
+            expected = {"foo.bar": "baz", "qux": "quux", "output_dir": output_dir}
+
+            run = mlflow.get_run(run_id)
+            assert run.data.params == expected
+
+            artifacts = mlflow.artifacts.list_artifacts(
+                run_id=run_id, artifact_path=output_dir.lstrip("/")
+            )
+            assert len(artifacts) == 1
+            assert artifacts[0].path.endswith("torchtune_config.yaml")
diff --git a/torchtune/training/metric_logging.py b/torchtune/training/metric_logging.py
index dde1619194..c1d6483a57 100644
--- a/torchtune/training/metric_logging.py
+++ b/torchtune/training/metric_logging.py
@@ -48,6 +48,34 @@ def save_config(config: DictConfig) -> Path:
         log.warning(f"Error saving config.\nError: \n{e}.")
 
 
+def flatten_dict(d: Dict[str, Any], *, sep: str = ".", parent_key: str = ""):
+    """Recursively flattens a nested dictionary into one level of key-value pairs.
+
+    Args:
+        d (Dict[str, Any]): Any dictionary to flatten.
+        sep (str, optional): Desired separator for flattening nested keys. Defaults to ".".
+        parent_key (str, optional): Key prefix for children (nested keys), containing parent key names. Defaults to "".
+
+    Example:
+        >>> flatten_dict({"foo": {"bar": "baz"}, "qux": "quux"}, sep="--")
+        {"foo--bar": "baz", "qux": "quux"}
+
+    Returns:
+        Dict[str, Any]: Flattened dictionary.
+
+    Note:
+        Does not unnest dictionaries within list values (e.g., {"foo": [{"bar": "baz"}]}).
+    """
+    items = []
+    for k, v in d.items():
+        new_key = f"{parent_key}{sep}{k}" if parent_key else k
+        if isinstance(v, dict):
+            items.extend(flatten_dict(v, sep=sep, parent_key=new_key).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
+
+
 class MetricLoggerInterface(Protocol):
     """Abstract metric logger."""
 
@@ -469,3 +497,114 @@ def close(self) -> None:
 
     def __del__(self) -> None:
         self.close()
+
+
+class MLFlowLogger(MetricLoggerInterface):
+    """Logger for use w/ MLFlow (https://mlflow.org/).
+
+    Args:
+        experiment_name (Optional[str]): MLFlow experiment name. If not specified, defaults
+            to the MLFLOW_EXPERIMENT_NAME environment variable, or to the default experiment.
+        tracking_uri (Optional[str]): MLFlow tracking URI. If not specified, defaults to
+            the MLFLOW_TRACKING_URI environment variable, or to the default local tracking URI.
+        run_id (Optional[str]): MLFlow run ID. If not specified, defaults to the
+            MLFLOW_RUN_ID environment variable if set; otherwise a new run is created.
+        run_name (Optional[str]): MLFlow run name. If not specified, defaults to an
+            mlflow-generated HRID. Unused if ``run_id`` is specified or the MLFLOW_RUN_ID
+            environment variable is found.
+
+    Example:
+        >>> logger = MLFlowLogger(experiment_name="my_experiment", run_name="run1")
+        >>> logger.log("accuracy", 0.95, step=1)
+        >>> logger.log_dict({"loss": 0.1, "accuracy": 0.95}, step=1)
+        >>> logger.log_config(config)
+        >>> logger.close()
+
+    Raises:
+        ImportError: If ``mlflow`` package is not installed.
+
+    Note:
+        This logger requires the mlflow package to be installed.
+        You can install it with `pip install mlflow`.
+    """
+
+    def __init__(
+        self,
+        experiment_name: Optional[str] = None,
+        tracking_uri: Optional[str] = None,
+        run_id: Optional[str] = None,
+        run_name: Optional[str] = None,
+    ):
+        try:
+            import mlflow
+        except ImportError as e:
+            raise ImportError(
+                "``mlflow`` package not found. Please install mlflow using `pip install mlflow` to use MLFlowLogger. "
+                "Alternatively, use the ``StdoutLogger``, which can be specified by setting metric_logger_type='stdout'."
+            ) from e
+
+        _, self.rank = get_world_size_and_rank()
+
+        self._mlflow = mlflow
+
+        self._tracking_uri = tracking_uri or os.getenv("MLFLOW_TRACKING_URI")
+        self._experiment_name = experiment_name or os.getenv("MLFLOW_EXPERIMENT_NAME")
+        self._run_id = run_id or os.getenv("MLFLOW_RUN_ID")
+
+        if self.rank == 0:
+            if not self._mlflow.is_tracking_uri_set():
+                if self._tracking_uri is not None:
+                    self._mlflow.set_tracking_uri(self._tracking_uri)
+
+            if self._mlflow.active_run() is None or self._run_id:
+                if self._experiment_name is not None:
+                    # set_experiment() creates the experiment if it does not already exist
+                    self._mlflow.set_experiment(self._experiment_name)
+                run = self._mlflow.start_run(run_name=run_name)
+                self._run_id = run.info.run_id
+
+    def log_config(self, config: DictConfig) -> None:
+        """Saves the config locally and also logs the config to mlflow. The config is
+        stored in the same directory as the checkpoint.
+
+        Args:
+            config (DictConfig): config to log
+        """
+        if self._mlflow.active_run():
+            resolved = OmegaConf.to_container(config, resolve=True)
+
+            # mlflow's params must be flat key-value pairs
+            config_as_params = flatten_dict(resolved)
+            self._mlflow.log_params(config_as_params, run_id=self._run_id)
+
+            output_config_fname = save_config(config)
+
+            # lstrip("/") keeps the artifact path relative even when output_dir is an absolute path
+            artifact_path = str(output_config_fname.parent).lstrip("/")
+
+            self._mlflow.log_artifact(
+                output_config_fname,
+                artifact_path=artifact_path,
+                run_id=self._run_id,
+            )
+
+    def log(self, name: str, data: Scalar, step: int) -> None:
+        if self._mlflow.active_run():
+            self._mlflow.log_metric(name, data, step=step, run_id=self._run_id)
+
+    def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
+        if self._mlflow.active_run():
+            self._mlflow.log_metrics(payload, step=step, run_id=self._run_id)
+
+    def close(self) -> None:
+        """
+        Ends the MLflow run.
+        After calling close, no further logging should be performed.
+        """
+        if self.rank == 0 and self._mlflow.active_run():
+            self._mlflow.end_run()
+
+    def __del__(self) -> None:
+        # Ensure the MLflow run is closed when the logger is deleted.
+        if hasattr(self, "_mlflow") and self._mlflow.active_run():
+            self._mlflow.end_run()
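
For reference, a minimal sketch of how a recipe could pick up this logger through torchtune's usual ``_component_`` config convention. The experiment name, run name, and metric values below are placeholders, and the exact recipe wiring may differ.

from omegaconf import OmegaConf

from torchtune import config

# Hypothetical recipe-style config node: the _component_ path matches this PR,
# but the experiment/run names are placeholders.
cfg = OmegaConf.create(
    {
        "metric_logger": {
            "_component_": "torchtune.training.metric_logging.MLFlowLogger",
            "experiment_name": "my_experiment",
            "run_name": "run1",
        }
    }
)

logger = config.instantiate(cfg.metric_logger)  # builds the MLFlowLogger
logger.log("loss", 0.25, step=1)
logger.log_dict({"loss": 0.24, "lr": 2e-5}, step=2)
logger.close()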