From 2091de7c210f483149cbc2400887ab75405c7c7f Mon Sep 17 00:00:00 2001
From: Ben Lewis
Date: Wed, 8 Jan 2025 13:04:12 +0200
Subject: [PATCH 1/6] Add checkpoint artifact path prefix to MLflow logger

Add a new `checkpoint_artifact_path_prefix` parameter to the MLflow logger.

* Modify `src/lightning/pytorch/loggers/mlflow.py` to include the new
  parameter in the `MLFlowLogger` class constructor and use it in the
  `after_save_checkpoint` method.
* Update the documentation in `docs/source-pytorch/visualize/loggers.rst`
  to include the new `checkpoint_artifact_path_prefix` parameter.
* Add a new test in `tests/tests_pytorch/loggers/test_mlflow.py` to verify
  the functionality of the `checkpoint_artifact_path_prefix` parameter and
  ensure it is used in the artifact path.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/Lightning-AI/pytorch-lightning?shareId=XXXX-XXXX-XXXX-XXXX).
---
 docs/source-pytorch/visualize/loggers.rst  | 34 ++++++++++++++++++++++
 src/lightning/pytorch/loggers/mlflow.py    |  6 ++--
 tests/tests_pytorch/loggers/test_mlflow.py | 29 ++++++++++++++++++
 3 files changed, 67 insertions(+), 2 deletions(-)

diff --git a/docs/source-pytorch/visualize/loggers.rst b/docs/source-pytorch/visualize/loggers.rst
index bdf95ec1b675e..361056b7a5a21 100644
--- a/docs/source-pytorch/visualize/loggers.rst
+++ b/docs/source-pytorch/visualize/loggers.rst
@@ -54,3 +54,37 @@ Track and Visualize Experiments
+
+.. _mlflow_logger:
+
+MLflow Logger
+-------------
+
+The MLflow logger in PyTorch Lightning now includes a `checkpoint_artifact_path_prefix` parameter. This parameter allows you to prefix the checkpoint artifact's path when logging checkpoints as artifacts.
+
+Example usage:
+
+.. code-block:: python
+
+    from lightning.pytorch import Trainer
+    from lightning.pytorch.loggers import MLFlowLogger
+
+    mlf_logger = MLFlowLogger(
+        experiment_name="lightning_logs",
+        tracking_uri="file:./ml-runs",
+        checkpoint_artifact_path_prefix="my_prefix"
+    )
+    trainer = Trainer(logger=mlf_logger)
+
+    # Your LightningModule definition
+    class LitModel(LightningModule):
+        def training_step(self, batch, batch_idx):
+            # example
+            self.logger.experiment.whatever_ml_flow_supports(...)
+
+        def any_lightning_module_function_or_hook(self):
+            self.logger.experiment.whatever_ml_flow_supports(...)
+
+    # Train your model
+    model = LitModel()
+    trainer.fit(model)
diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py
index e3d99987b7f58..b624f6795205a 100644
--- a/src/lightning/pytorch/loggers/mlflow.py
+++ b/src/lightning/pytorch/loggers/mlflow.py
@@ -97,7 +97,7 @@ def any_lightning_module_function_or_hook(self):
             :paramref:`~lightning.pytorch.callbacks.Checkpoint.save_top_k` ``== -1``
             which also logs every checkpoint during training.
           * if ``log_model == False`` (default), no checkpoint is logged.
-
+        checkpoint_artifact_path_prefix: A string to prefix the checkpoint artifact's path.
         prefix: A string to put at the beginning of metric keys.
         artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
             default.
@@ -121,6 +121,7 @@ def __init__(
         tags: Optional[dict[str, Any]] = None,
         save_dir: Optional[str] = "./mlruns",
         log_model: Literal[True, False, "all"] = False,
+        checkpoint_artifact_path_prefix: str = "",
         prefix: str = "",
         artifact_location: Optional[str] = None,
         run_id: Optional[str] = None,
@@ -147,6 +148,7 @@ def __init__(
         self._artifact_location = artifact_location
         self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
         self._initialized = False
+        self._checkpoint_artifact_path_prefix = checkpoint_artifact_path_prefix

         from mlflow.tracking import MlflowClient

@@ -361,7 +363,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non
             aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]

             # Artifact path on mlflow
-            artifact_path = Path(p).stem
+            artifact_path = Path(self._checkpoint_artifact_path_prefix) / Path(p).stem

             # Log the checkpoint
             self.experiment.log_artifact(self._run_id, p, artifact_path)
diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py
index c7f9dbe1fe2c6..a93ddc2f5f3ca 100644
--- a/tests/tests_pytorch/loggers/test_mlflow.py
+++ b/tests/tests_pytorch/loggers/test_mlflow.py
@@ -427,3 +427,32 @@ def test_set_tracking_uri(mlflow_mock):
     mlflow_mock.set_tracking_uri.assert_not_called()
     _ = logger.experiment
     mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri")
+
+
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+def test_mlflow_log_model_with_checkpoint_artifact_path_prefix(mlflow_mock, tmp_path):
+    """Test that the logger creates the folders and files in the right place with a prefix."""
+    client = mlflow_mock.tracking.MlflowClient
+
+    # Get model, logger, trainer and train
+    model = BoringModel()
+    logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model="all", checkpoint_artifact_path_prefix="my_prefix")
+    logger = mock_mlflow_run_creation(logger, experiment_id="test-id")
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        logger=logger,
+        max_epochs=2,
+        limit_train_batches=3,
+        limit_val_batches=3,
+    )
+    trainer.fit(model)
+
+    # Checkpoint log
+    assert client.return_value.log_artifact.call_count == 2
+    # Metadata and aliases log
+    assert client.return_value.log_artifacts.call_count == 2
+
+    # Check that the prefix is used in the artifact path
+    for call in client.return_value.log_artifact.call_args_list:
+        assert call[1]["artifact_path"].startswith("my_prefix")
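
A note on the path composition in this patch: because the artifact path is built as `Path(prefix) / Path(p).stem`, the default empty prefix leaves the path unchanged, so existing artifact layouts are unaffected. A minimal standalone sketch of that behavior (the helper name and checkpoint filename are illustrative, not part of the patch):

    from pathlib import Path

    def artifact_path_for(prefix: str, checkpoint_file: str) -> str:
        # Mirrors the logger's `Path(prefix) / Path(p).stem` composition;
        # as_posix() keeps the example platform-independent.
        return (Path(prefix) / Path(checkpoint_file).stem).as_posix()

    # With a prefix, checkpoints are grouped under that folder in the run:
    assert artifact_path_for("my_prefix", "epoch=0-step=3.ckpt") == "my_prefix/epoch=0-step=3"
    # With the default empty prefix, the previous layout is preserved:
    assert artifact_path_for("", "epoch=0-step=3.ckpt") == "epoch=0-step=3"
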
From 7059161f07d77b375f56546a7f56ab7572dd0434 Mon Sep 17 00:00:00 2001
From: Ben Lewis
Date: Thu, 9 Jan 2025 14:08:19 +0000
Subject: [PATCH 2/6] Add CHANGELOG

---
 src/lightning/pytorch/CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 9f7317c218c30..864eacc82c16d 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -4,6 +4,8 @@ All notable changes to this project will be documented in this file.

 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

+- Added a new `checkpoint_artifact_path_prefix` parameter to the MLflow logger, which controls where the MLflow artifacts for the model checkpoints are stored.
+
 ## [2.5.0] - 2024-12-19

 ### Added
From 182bef623d1a0e40c18a169cafbd48a833103f47 Mon Sep 17 00:00:00 2001
From: Ben Lewis
Date: Tue, 14 Jan 2025 08:02:08 +0000
Subject: [PATCH 3/6] Fix PR comments

---
 docs/source-pytorch/visualize/loggers.rst  | 8 ++++----
 src/lightning/pytorch/loggers/mlflow.py    | 8 ++++----
 tests/tests_pytorch/loggers/test_mlflow.py | 2 +-
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/docs/source-pytorch/visualize/loggers.rst b/docs/source-pytorch/visualize/loggers.rst
index 361056b7a5a21..76ebd5b15ec6e 100644
--- a/docs/source-pytorch/visualize/loggers.rst
+++ b/docs/source-pytorch/visualize/loggers.rst
@@ -60,13 +60,13 @@ Track and Visualize Experiments
 MLflow Logger
 -------------

-The MLflow logger in PyTorch Lightning now includes a `checkpoint_artifact_path_prefix` parameter. This parameter allows you to prefix the checkpoint artifact's path when logging checkpoints as artifacts.
+The MLflow logger in PyTorch Lightning now includes a `checkpoint_path_prefix` parameter. This parameter allows you to prefix the checkpoint artifact's path when logging checkpoints as artifacts.

 Example usage:

 .. code-block:: python

-    from lightning.pytorch import Trainer
+    import lightning as L
     from lightning.pytorch.loggers import MLFlowLogger

     mlf_logger = MLFlowLogger(
         experiment_name="lightning_logs",
         tracking_uri="file:./ml-runs",
         checkpoint_artifact_path_prefix="my_prefix"
     )
-    trainer = Trainer(logger=mlf_logger)
+    trainer = L.Trainer(logger=mlf_logger)

     # Your LightningModule definition
-    class LitModel(LightningModule):
+    class LitModel(L.LightningModule):
         def training_step(self, batch, batch_idx):
             # example
             self.logger.experiment.whatever_ml_flow_supports(...)
diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py
index b624f6795205a..1d158f41b52bc 100644
--- a/src/lightning/pytorch/loggers/mlflow.py
+++ b/src/lightning/pytorch/loggers/mlflow.py
@@ -97,7 +97,7 @@ def any_lightning_module_function_or_hook(self):
             :paramref:`~lightning.pytorch.callbacks.Checkpoint.save_top_k` ``== -1``
             which also logs every checkpoint during training.
           * if ``log_model == False`` (default), no checkpoint is logged.
-        checkpoint_artifact_path_prefix: A string to prefix the checkpoint artifact's path.
+        checkpoint_path_prefix: A string to prefix the checkpoint artifact's path.
         prefix: A string to put at the beginning of metric keys.
         artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
             default.
@@ -121,7 +121,7 @@ def __init__(
         tags: Optional[dict[str, Any]] = None,
         save_dir: Optional[str] = "./mlruns",
         log_model: Literal[True, False, "all"] = False,
-        checkpoint_artifact_path_prefix: str = "",
+        checkpoint_path_prefix: str = "",
         prefix: str = "",
         artifact_location: Optional[str] = None,
         run_id: Optional[str] = None,
@@ -148,7 +148,7 @@ def __init__(
         self._artifact_location = artifact_location
         self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
         self._initialized = False
-        self._checkpoint_artifact_path_prefix = checkpoint_artifact_path_prefix
+        self._checkpoint_path_prefix = checkpoint_path_prefix

         from mlflow.tracking import MlflowClient

@@ -363,7 +363,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non
             aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]

             # Artifact path on mlflow
-            artifact_path = Path(self._checkpoint_artifact_path_prefix) / Path(p).stem
+            artifact_path = Path(self._checkpoint_path_prefix) / Path(p).stem

             # Log the checkpoint
             self.experiment.log_artifact(self._run_id, p, artifact_path)
diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py
index a93ddc2f5f3ca..2b30f73cc4b68 100644
--- a/tests/tests_pytorch/loggers/test_mlflow.py
+++ b/tests/tests_pytorch/loggers/test_mlflow.py
@@ -436,7 +436,7 @@ def test_mlflow_log_model_with_checkpoint_artifact_path_prefix(mlflow_mock, tmp_

     # Get model, logger, trainer and train
     model = BoringModel()
-    logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model="all", checkpoint_artifact_path_prefix="my_prefix")
+    logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model="all", checkpoint_path_prefix="my_prefix")
     logger = mock_mlflow_run_creation(logger, experiment_id="test-id")

     trainer = Trainer(
From b748e5dc6b0103e29354f36bb8a04fc715d02fcb Mon Sep 17 00:00:00 2001
From: Ben Lewis
Date: Wed, 15 Jan 2025 10:24:06 +0000
Subject: [PATCH 4/6] Fix MLflow logger test for `checkpoint_path_prefix`

---
 tests/tests_pytorch/loggers/test_mlflow.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py
index 2b30f73cc4b68..8118349ea6721 100644
--- a/tests/tests_pytorch/loggers/test_mlflow.py
+++ b/tests/tests_pytorch/loggers/test_mlflow.py
@@ -430,7 +430,7 @@ def test_set_tracking_uri(mlflow_mock):


 @mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
-def test_mlflow_log_model_with_checkpoint_artifact_path_prefix(mlflow_mock, tmp_path):
+def test_mlflow_log_model_with_checkpoint_path_prefix(mlflow_mock, tmp_path):
     """Test that the logger creates the folders and files in the right place with a prefix."""
     client = mlflow_mock.tracking.MlflowClient

@@ -455,4 +455,5 @@ def test_mlflow_log_model_with_checkpoint_path_prefix(mlflow_mock, tmp_path):

     # Check that the prefix is used in the artifact path
     for call in client.return_value.log_artifact.call_args_list:
-        assert call[1]["artifact_path"].startswith("my_prefix")
+        args, _ = call
+        assert str(args[2]).startswith("my_prefix")
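
The test fix above turns on how `unittest.mock` records calls: the logger invokes `log_artifact(self._run_id, p, artifact_path)` with positional arguments only, so the recorded call has an empty kwargs dict, and the original assertion `call[1]["artifact_path"]` indexed that empty dict and raised a `KeyError`. A minimal sketch of the distinction (the values are illustrative):

    from unittest.mock import Mock

    log_artifact = Mock()
    log_artifact("run-id", "epoch=0.ckpt", "my_prefix/epoch=0")

    args, kwargs = log_artifact.call_args  # a call record unpacks into (args, kwargs)
    assert kwargs == {}                    # nothing was passed by keyword
    assert str(args[2]).startswith("my_prefix")  # the path is the third positional argument
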
From 0fce2cc3d39e54d1d80ae79f7a93ee75a00dfac7 Mon Sep 17 00:00:00 2001
From: Ben Lewis
Date: Sun, 9 Feb 2025 11:39:04 +0000
Subject: [PATCH 5/6] Update stale documentation

---
 docs/source-pytorch/visualize/loggers.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source-pytorch/visualize/loggers.rst b/docs/source-pytorch/visualize/loggers.rst
index 76ebd5b15ec6e..f4fd5b23b2311 100644
--- a/docs/source-pytorch/visualize/loggers.rst
+++ b/docs/source-pytorch/visualize/loggers.rst
@@ -72,7 +72,7 @@ Example usage:
     mlf_logger = MLFlowLogger(
         experiment_name="lightning_logs",
         tracking_uri="file:./ml-runs",
-        checkpoint_artifact_path_prefix="my_prefix"
+        checkpoint_path_prefix="my_prefix"
     )
     trainer = L.Trainer(logger=mlf_logger)


From 29ee953084f28b6631ea4d21ba7557fa12e1e186 Mon Sep 17 00:00:00 2001
From: Ben Lewis
Date: Sun, 9 Feb 2025 11:39:53 +0000
Subject: [PATCH 6/6] Update stale documentation II

---
 src/lightning/pytorch/CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 864eacc82c16d..1bd3ca46912d3 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -4,7 +4,7 @@ All notable changes to this project will be documented in this file.

 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

-- Added a new `checkpoint_artifact_path_prefix` parameter to the MLflow logger, which controls where the MLflow artifacts for the model checkpoints are stored.
+- Added a new `checkpoint_path_prefix` parameter to the MLflow logger, which controls where the MLflow artifacts for the model checkpoints are stored.

 ## [2.5.0] - 2024-12-19
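
Beyond the mocked tests, the resulting layout can be checked against a real tracking store. A quick sketch, assuming a local `file:./ml-runs` store and the run ID of a training run that used `checkpoint_path_prefix="my_prefix"` (the run ID below is a placeholder):

    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri="file:./ml-runs")
    run_id = "..."  # placeholder: the run ID of the finished training run

    # The run's top-level artifacts should contain the "my_prefix" directory,
    # and each checkpoint folder should sit beneath it.
    print([f.path for f in client.list_artifacts(run_id)])
    print([f.path for f in client.list_artifacts(run_id, "my_prefix")])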