From 39fc7c954d338f1276e0d44869a9ed32005f8638 Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Sat, 8 Feb 2025 15:23:16 +1100 Subject: [PATCH 1/8] Implements MLFlowLogger and basic testing suite Signed-off-by: Nathan Azrak --- docs/source/api_ref_training.rst | 1 + .../torchtune/training/test_metric_logging.py | 59 +++++++ torchtune/training/metric_logging.py | 167 ++++++++++++++++++ 3 files changed, 227 insertions(+) diff --git a/docs/source/api_ref_training.rst b/docs/source/api_ref_training.rst index 747f312447..00509d4505 100644 --- a/docs/source/api_ref_training.rst +++ b/docs/source/api_ref_training.rst @@ -102,6 +102,7 @@ Various logging utilities. metric_logging.TensorBoardLogger metric_logging.StdoutLogger metric_logging.DiskLogger + metric_logging.MLFlowLogger .. _perf_profiling_label: diff --git a/tests/torchtune/training/test_metric_logging.py b/tests/torchtune/training/test_metric_logging.py index 2fc29e72aa..1d4012ab55 100644 --- a/tests/torchtune/training/test_metric_logging.py +++ b/tests/torchtune/training/test_metric_logging.py @@ -10,6 +10,9 @@ from typing import cast from unittest.mock import patch +import mlflow + +import mlflow.artifacts import pytest from omegaconf import OmegaConf from tensorboard.backend.event_processing.event_accumulator import EventAccumulator @@ -19,6 +22,7 @@ from torchtune.training.metric_logging import ( CometLogger, DiskLogger, + MLFlowLogger, StdoutLogger, TensorBoardLogger, WandBLogger, @@ -199,3 +203,58 @@ def test_log_config(self) -> None: cfg = OmegaConf.create({"a": 1, "b": 2}) logger.log_config(cfg) mock_experiment.return_value.log_parameters.assert_called_with(cfg) + + +@pytest.fixture(scope="class") +def mlflow_context_fixture(): + original_uri = mlflow.get_tracking_uri() + + with tempfile.TemporaryDirectory() as tmpdir: + mlflow.set_tracking_uri(f"file:{tmpdir}") + yield + + # Restore the original URI + mlflow.set_tracking_uri(original_uri) + + +@pytest.mark.usefixtures("mlflow_context_fixture") +class TestMLFlowLogger: + def test_log(self): + logger = MLFlowLogger(experiment_name="my_experiment", run_name="run1") + run_id = logger._run_id + logger.log("test_metric", 1.0, step=1) + logger.close() + + run = mlflow.get_run(run_id) + assert run.data.metrics == {"test_metric": 1} + + def test_log_dict(self): + logger = MLFlowLogger(experiment_name="my_experiment", run_name="run2") + run_id = logger._run_id + metric_dict = {f"log_dict_{i}": float(i) ** 2 for i in range(5)} + logger.log_dict(metric_dict, step=2) + logger.close() + + run = mlflow.get_run(run_id) + assert run.data.metrics == metric_dict + + def test_log_config(self) -> None: + with tempfile.TemporaryDirectory() as output_dir: + cfg = OmegaConf.create( + {"foo": {"bar": "baz"}, "qux": "quux", "output_dir": output_dir} + ) + logger = MLFlowLogger(experiment_name="my_experiment", run_name="run2") + run_id = logger._run_id + + logger.log_config(cfg) + + expected = {"foo.bar": "baz", "qux": "quux", "output_dir": output_dir} + + run = mlflow.get_run(run_id) + assert run.data.params == expected + + artifacts = mlflow.artifacts.list_artifacts( + run_id=run_id, artifact_path=output_dir.lstrip("/") + ) + assert len(artifacts) == 1 + assert artifacts[0].path.endswith("torchtune_config.yaml") diff --git a/torchtune/training/metric_logging.py b/torchtune/training/metric_logging.py index dde1619194..a8f21482b8 100644 --- a/torchtune/training/metric_logging.py +++ b/torchtune/training/metric_logging.py @@ -48,6 +48,17 @@ def save_config(config: DictConfig) -> Path: log.warning(f"Error 
saving config.\nError: \n{e}.")
 
 
+def flatten_dict(d, parent_key="", sep="."):
+    items = []
+    for k, v in d.items():
+        new_key = f"{parent_key}{sep}{k}" if parent_key else k
+        if isinstance(v, dict):
+            items.extend(flatten_dict(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
+
+
 class MetricLoggerInterface(Protocol):
     """Abstract metric logger."""
 
@@ -469,3 +480,159 @@ def close(self) -> None:
 
     def __del__(self) -> None:
         self.close()
+
+
+class MetricLoggerInterface(Protocol):
+    """Abstract metric logger."""
+
+    def log(
+        self,
+        name: str,
+        data: Scalar,
+        step: int,
+    ) -> None:
+        """Log scalar data.
+
+        Args:
+            name (str): tag name used to group scalars
+            data (Scalar): scalar data to log
+            step (int): step value to record
+        """
+        pass
+
+    def log_config(self, config: DictConfig) -> None:
+        """Logs the config as file
+
+        Args:
+            config (DictConfig): config to log
+        """
+        pass
+
+    def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
+        """Log multiple scalar values.
+
+        Args:
+            payload (Mapping[str, Scalar]): dictionary of tag name and scalar value
+            step (int): step value to record
+        """
+        pass
+
+    def close(self) -> None:
+        """
+        Close log resource, flushing if necessary.
+        Logs should not be written after `close` is called.
+        """
+        pass
+
+
+class MLFlowLogger(MetricLoggerInterface):
+    """Logger for use w/ MLFlow.
+
+    Args:
+        experiment_name (Optional[str]): MLFlow experiment name. If not specified, will
+            default to the MLFLOW_EXPERIMENT_NAME environment variable if set, or to the MLflow default.
+        tracking_uri (Optional[str]): MLFlow tracking URI. If not specified, will default
+            to the MLFLOW_TRACKING_URI environment variable if set, or to the MLflow default.
+        run_name (Optional[str]): MLFlow run name. If not specified, will default
+            to an mlflow-generated HRID. Unused if run_id is specified or the MLFLOW_RUN_ID
+            environment variable is found.
+        run_id (Optional[str]): MLFlow run ID. If not specified, will default
+            to MLFLOW_RUN_ID environment variable if set, or a new run will be created.
+
+    Example:
+        >>> logger = MLFlowLogger(experiment_name="my_experiment", run_name="run1", log_dir="./mlruns")
+        >>> logger.log("accuracy", 0.95, step=1)
+        >>> logger.log_dict({"loss": 0.1, "accuracy": 0.95}, step=1)
+        >>> logger.log_config(config)
+        >>> logger.close()
+
+    Raises:
+        ImportError: If ``mlflow`` package is not installed.
+
+    Note:
+        This logger requires the mlflow package to be installed.
+        You can install it with `pip install mlflow`.
+    """
+
+    def __init__(
+        self,
+        experiment_name: Optional[str] = None,
+        tracking_uri: Optional[str] = None,
+        run_id: Optional[str] = None,
+        run_name: Optional[str] = None,
+    ):
+        try:
+            import mlflow
+        except ImportError as e:
+            raise ImportError(
+                "``mlflow`` package not found. Please install mlflow using `pip install mlflow` to use MLFlowLogger. "
+                "Alternatively, use the ``StdoutLogger``, which can be specified by setting metric_logger_type='stdout'."
+            ) from e
+
+        _, self.rank = get_world_size_and_rank()
+
+        self._mlflow = mlflow
+
+        self._tracking_uri = tracking_uri or os.getenv("MLFLOW_TRACKING_URI", None)
+        self._experiment_name = experiment_name or os.getenv(
+            "MLFLOW_EXPERIMENT_NAME", None
+        )
+        self._run_id = run_id or os.getenv("MLFLOW_RUN_ID", None)
+
+        if self.rank == 0:
+            if not self._mlflow.is_tracking_uri_set():
+                if self._tracking_uri:
+                    self._mlflow.set_tracking_uri(self._tracking_uri)
+
+            if self._mlflow.active_run() is None or self._run_id:
+                if self._experiment_name:
+                    # set_experiment() creates the experiment if it does not already exist
+                    self._mlflow.set_experiment(self._experiment_name)
+                run = self._mlflow.start_run(run_name=run_name)
+                self._run_id = run.info.run_id
+
+    def log_config(self, config: DictConfig) -> None:
+        """Saves the config locally and also logs the config to mlflow. The config is
+        stored in the same directory as the checkpoint.
+
+        Args:
+            config (DictConfig): config to log
+        """
+        if self._mlflow.active_run():
+            resolved = OmegaConf.to_container(config, resolve=True)
+
+            # mlflow's params must be flat key-value pairs
+            config_as_params = flatten_dict(resolved)
+            self._mlflow.log_params(config_as_params, run_id=self._run_id)
+
+            output_config_fname = save_config(config)
+
+            # stripping the leading slash avoids breakage when the config's output_dir is an absolute path
+            artifact_path = str(output_config_fname.parent).lstrip("/")
+
+            self._mlflow.log_artifact(
+                output_config_fname,
+                artifact_path=artifact_path,
+                run_id=self._run_id,
+            )
+
+    def log(self, name: str, data: Scalar, step: int) -> None:
+        if self._mlflow.active_run():
+            self._mlflow.log_metric(name, data, step=step, run_id=self._run_id)
+
+    def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
+        if self._mlflow.active_run():
+            self._mlflow.log_metrics(payload, step=step, run_id=self._run_id)
+
+    def close(self) -> None:
+        """
+        Ends the MLflow run.
+        After calling close, no further logging should be performed.
+        """
+        if self.rank == 0 and self._mlflow.active_run():
+            self._mlflow.end_run()
+
+    def __del__(self) -> None:
+        # Ensure the MLflow run is closed when the logger is deleted.
+        if hasattr(self, "_mlflow") and self._mlflow.active_run():
+            self._mlflow.end_run()

From 6df00224241e153bc4800d7797e863467cc25ef8 Mon Sep 17 00:00:00 2001
From: Nathan Azrak
Date: Sat, 8 Feb 2025 15:29:07 +1100
Subject: [PATCH 2/8] Remove duplicated interface

Signed-off-by: Nathan Azrak
---
 torchtune/training/metric_logging.py | 43 ----------------------------------------
 1 file changed, 43 deletions(-)

diff --git a/torchtune/training/metric_logging.py b/torchtune/training/metric_logging.py
index a8f21482b8..ee776e63a9 100644
--- a/torchtune/training/metric_logging.py
+++ b/torchtune/training/metric_logging.py
@@ -482,49 +482,6 @@ def __del__(self) -> None:
         self.close()
 
 
-class MetricLoggerInterface(Protocol):
-    """Abstract metric logger."""
-
-    def log(
-        self,
-        name: str,
-        data: Scalar,
-        step: int,
-    ) -> None:
-        """Log scalar data.
-
-        Args:
-            name (str): tag name used to group scalars
-            data (Scalar): scalar data to log
-            step (int): step value to record
-        """
-        pass
-
-    def log_config(self, config: DictConfig) -> None:
-        """Logs the config as file
-
-        Args:
-            config (DictConfig): config to log
-        """
-        pass
-
-    def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
-        """Log multiple scalar values.
- - Args: - payload (Mapping[str, Scalar]): dictionary of tag name and scalar value - step (int): step value to record - """ - pass - - def close(self) -> None: - """ - Close log resource, flushing if necessary. - Logs should not be written after `close` is called. - """ - pass - - class MLFlowLogger(MetricLoggerInterface): """Logger for use w/ MLFlow. From cc54013c12edc7550ecced442096a78bddebfece Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Tue, 11 Feb 2025 11:14:35 +1100 Subject: [PATCH 3/8] Add mlflow to dev dependencies Signed-off-by: Nathan Azrak --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index f94732b58a..9b5f386997 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ tune = "torchtune._cli.tune:main" dev = [ "bitsandbytes>=0.43.0", "comet_ml>=3.44.2", + "mlflow", "pre-commit", "pytest==7.4.0", "pytest-cov", From 6d4f859813819411c92a0015f5518aa800fa2d3b Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Wed, 12 Feb 2025 08:56:15 +1100 Subject: [PATCH 4/8] Remove unnecessary None default from getenv Signed-off-by: Nathan Azrak --- torchtune/training/metric_logging.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchtune/training/metric_logging.py b/torchtune/training/metric_logging.py index ee776e63a9..4420fde2d9 100644 --- a/torchtune/training/metric_logging.py +++ b/torchtune/training/metric_logging.py @@ -530,11 +530,9 @@ def __init__( self._mlflow = mlflow - self._tracking_uri = tracking_uri or os.getenv("MLFLOW_TRACKING_URI", None) - self._experiment_name = experiment_name or os.getenv( - "MLFLOW_EXPERIMENT_NAME", None - ) - self._run_id = run_id or os.getenv("MLFLOW_RUN_ID", None) + self._tracking_uri = tracking_uri or os.getenv("MLFLOW_TRACKING_URI") + self._experiment_name = experiment_name or os.getenv("MLFLOW_EXPERIMENT_NAME") + self._run_id = run_id or os.getenv("MLFLOW_RUN_ID") if self.rank == 0: if not self._mlflow.is_tracking_uri_set(): From 162bc3c1068dbc248953326145de7c7789d842ca Mon Sep 17 00:00:00 2001 From: Nathan Azrak Date: Wed, 12 Feb 2025 08:57:39 +1100 Subject: [PATCH 5/8] Fix logger docstring - delete removed param, add mlflow website Signed-off-by: Nathan Azrak --- torchtune/training/metric_logging.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtune/training/metric_logging.py b/torchtune/training/metric_logging.py index 4420fde2d9..55d22d531f 100644 --- a/torchtune/training/metric_logging.py +++ b/torchtune/training/metric_logging.py @@ -483,7 +483,7 @@ def __del__(self) -> None: class MLFlowLogger(MetricLoggerInterface): - """Logger for use w/ MLFlow. + """Logger for use w/ MLFlow (https://mlflow.org/). Args: experiment_name (Optional[str]): MLFlow experiment name. If not specified, will @@ -497,7 +497,7 @@ class MLFlowLogger(MetricLoggerInterface): to MLFLOW_RUN_ID environment variable if set, or a new run will be created. 
     Example:
-        >>> logger = MLFlowLogger(experiment_name="my_experiment", run_name="run1", log_dir="./mlruns")
+        >>> logger = MLFlowLogger(experiment_name="my_experiment", run_name="run1")
         >>> logger.log("accuracy", 0.95, step=1)
         >>> logger.log_dict({"loss": 0.1, "accuracy": 0.95}, step=1)

From 675eadf0af56830d8a4708d83b84e5b97a836dc4 Mon Sep 17 00:00:00 2001
From: Nathan Azrak
Date: Wed, 12 Feb 2025 09:35:06 +1100
Subject: [PATCH 6/8] Update `flatten_dict` with reordered args and docstring

Signed-off-by: Nathan Azrak
---
 torchtune/training/metric_logging.py | 22 ++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/torchtune/training/metric_logging.py b/torchtune/training/metric_logging.py
index 55d22d531f..6e6d22cf87 100644
--- a/torchtune/training/metric_logging.py
+++ b/torchtune/training/metric_logging.py
@@ -48,12 +48,30 @@ def save_config(config: DictConfig) -> Path:
     log.warning(f"Error saving config.\nError: \n{e}.")
 
 
-def flatten_dict(d, parent_key="", sep="."):
+def flatten_dict(d: Dict[str, Any], sep: str = ".", parent_key: str = ""):
+    """Recursively flattens a nested dictionary into one level of key-value pairs.
+
+    Args:
+        d (Dict[str, Any]): Any dictionary to flatten
+        sep (str): Desired separator for flattening nested keys
+        parent_key (str): Key prefix for children (nested keys), containing parent key
+            names
+
+    Example:
+        >>> flatten_dict({"foo": {"bar": "baz"}, "qux"; "quux"}, parent_key="--")
+        {"foo--bar": "baz", "qux", "quux"}
+
+    Returns:
+        Flattened dictionary
+
+    Note:
+        Does not unnest dictionaries within list values (i.e. {"foo": [{"bar": "baz}]}})
+    """
     items = []
     for k, v in d.items():
         new_key = f"{parent_key}{sep}{k}" if parent_key else k
         if isinstance(v, dict):
-            items.extend(flatten_dict(v, new_key, sep=sep).items())
+            items.extend(flatten_dict(v, sep=sep, parent_key=new_key).items())
         else:
             items.append((new_key, v))
     return dict(items)

From 753170082e1063e6069f8d349a7d3506c37c241b Mon Sep 17 00:00:00 2001
From: Nathan Azrak <42650258+nathan-az@users.noreply.github.com>
Date: Wed, 12 Feb 2025 09:44:53 +1100
Subject: [PATCH 7/8] `if x` -> `if x is not None`

Co-authored-by: Joe Cummings
---
 torchtune/training/metric_logging.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchtune/training/metric_logging.py b/torchtune/training/metric_logging.py
index 6e6d22cf87..200f632571 100644
--- a/torchtune/training/metric_logging.py
+++ b/torchtune/training/metric_logging.py
@@ -554,11 +554,11 @@ def __init__(
 
         if self.rank == 0:
             if not self._mlflow.is_tracking_uri_set():
-                if self._tracking_uri:
+                if self._tracking_uri is not None:
                     self._mlflow.set_tracking_uri(self._tracking_uri)
 
             if self._mlflow.active_run() is None or self._run_id:
-                if self._experiment_name:
+                if self._experiment_name is not None:
                     # set_experiment() creates the experiment if it does not already exist
                     self._mlflow.set_experiment(self._experiment_name)
                 run = self._mlflow.start_run(run_name=run_name)
                 self._run_id = run.info.run_id

From c06943894cdc0302be9ce9e16f7befcf95572133 Mon Sep 17 00:00:00 2001
From: Nathan Azrak
Date: Wed, 12 Feb 2025 10:43:23 +1100
Subject: [PATCH 8/8] Force kwarg specification in flatten_dict and fix
 docstring

Signed-off-by: Nathan Azrak
---
 torchtune/training/metric_logging.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/torchtune/training/metric_logging.py b/torchtune/training/metric_logging.py
index 200f632571..c1d6483a57 100644
--- a/torchtune/training/metric_logging.py
+++ b/torchtune/training/metric_logging.py
@@ -48,24 +48,23 @@ def save_config(config: DictConfig) -> Path:
     log.warning(f"Error saving config.\nError: \n{e}.")
 
 
-def flatten_dict(d: Dict[str, Any], sep: str = ".", parent_key: str = ""):
+def flatten_dict(d: Dict[str, Any], *, sep: str = ".", parent_key: str = ""):
     """Recursively flattens a nested dictionary into one level of key-value pairs.
 
     Args:
-        d (Dict[str, Any]): Any dictionary to flatten
-        sep (str): Desired separator for flattening nested keys
-        parent_key (str): Key prefix for children (nested keys), containing parent key
-            names
+        d (Dict[str, Any]): Any dictionary to flatten.
+        sep (str, optional): Desired separator for flattening nested keys. Defaults to ".".
+        parent_key (str, optional): Key prefix for children (nested keys), containing parent key names. Defaults to "".
 
     Example:
-        >>> flatten_dict({"foo": {"bar": "baz"}, "qux"; "quux"}, parent_key="--")
-        {"foo--bar": "baz", "qux", "quux"}
+        >>> flatten_dict({"foo": {"bar": "baz"}, "qux": "quux"}, sep="--")
+        {"foo--bar": "baz", "qux": "quux"}
 
     Returns:
-        Flattened dictionary
+        Dict[str, Any]: Flattened dictionary.
 
     Note:
-        Does not unnest dictionaries within list values (i.e. {"foo": [{"bar": "baz}]}})
+        Does not unnest dictionaries within list values (i.e., {"foo": [{"bar": "baz"}]}).
     """
     items = []
     for k, v in d.items():
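
A quick end-to-end sketch of the logger added by this series (reviewer note, not part
of the patches). It mirrors the test suite above; the file-based tracking URI, the
experiment/run names, and reading the private `logger._run_id` handle (as the tests
do) are illustrative assumptions, not requirements of the API.

    import tempfile

    import mlflow
    from omegaconf import OmegaConf

    from torchtune.training.metric_logging import MLFlowLogger

    with tempfile.TemporaryDirectory() as output_dir:
        # Assumed local file store. __init__ only applies this URI when MLflow
        # has no tracking URI configured yet (see is_tracking_uri_set() above).
        logger = MLFlowLogger(
            experiment_name="my_experiment",
            run_name="demo",
            tracking_uri=f"file:{output_dir}/mlruns",
        )
        run_id = logger._run_id

        logger.log("loss", 0.25, step=1)                    # single scalar
        logger.log_dict({"loss": 0.2, "lr": 1e-4}, step=2)  # several scalars at once

        # log_config flattens the config into MLflow params and uploads the YAML
        # written by save_config (under output_dir) as a run artifact.
        cfg = OmegaConf.create({"model": {"lr": 1e-4}, "output_dir": output_dir})
        logger.log_config(cfg)
        logger.close()

        # The run remains queryable after close(), as the tests assert.
        print(mlflow.get_run(run_id).data.metrics)  # {'loss': 0.2, 'lr': 0.0001}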