diff --git a/docs/source-pytorch/visualize/loggers.rst b/docs/source-pytorch/visualize/loggers.rst
index bdf95ec1b675e..f4fd5b23b2311 100644
--- a/docs/source-pytorch/visualize/loggers.rst
+++ b/docs/source-pytorch/visualize/loggers.rst
@@ -54,3 +54,37 @@ Track and Visualize Experiments
+
+.. _mlflow_logger:
+
+MLflow Logger
+-------------
+
+The MLflow logger in PyTorch Lightning accepts a ``checkpoint_path_prefix`` parameter that prefixes the artifact path under which model checkpoints are logged (checkpoints are logged as MLflow artifacts only when ``log_model=True`` or ``log_model="all"``).
+
+Example usage:
+
+.. code-block:: python
+
+    import lightning as L
+    from lightning.pytorch.loggers import MLFlowLogger
+
+    mlf_logger = MLFlowLogger(
+        experiment_name="lightning_logs",
+        tracking_uri="file:./ml-runs",
+        checkpoint_path_prefix="my_prefix"
+    )
+    trainer = L.Trainer(logger=mlf_logger)
+
+    # Your LightningModule definition
+    class LitModel(L.LightningModule):
+        def training_step(self, batch, batch_idx):
+            # example
+            self.logger.experiment.whatever_ml_flow_supports(...)
+
+        def any_lightning_module_function_or_hook(self):
+            self.logger.experiment.whatever_ml_flow_supports(...)
+
+    # Train your model
+    model = LitModel()
+    trainer.fit(model)
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 8bc8e45989f77..ed4c854f7ef3f 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Changed

+- Added a `checkpoint_path_prefix` parameter to the MLflow logger that controls the artifact path under which model checkpoints are stored.
+
 ### Removed

 ### Fixed
diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py
index e3d99987b7f58..1d158f41b52bc 100644
--- a/src/lightning/pytorch/loggers/mlflow.py
+++ b/src/lightning/pytorch/loggers/mlflow.py
@@ -97,7 +97,7 @@ def any_lightning_module_function_or_hook(self):
               :paramref:`~lightning.pytorch.callbacks.Checkpoint.save_top_k` ``== -1``
              which also logs every checkpoint during training.
             * if ``log_model == False`` (default), no checkpoint is logged.
-
+        checkpoint_path_prefix: A string to prefix the checkpoint artifact's path.
         prefix: A string to put at the beginning of metric keys.
         artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
             default.
@@ -121,6 +121,7 @@ def __init__(
         tags: Optional[dict[str, Any]] = None,
         save_dir: Optional[str] = "./mlruns",
         log_model: Literal[True, False, "all"] = False,
+        checkpoint_path_prefix: str = "",
         prefix: str = "",
         artifact_location: Optional[str] = None,
         run_id: Optional[str] = None,
@@ -147,6 +148,7 @@
         self._artifact_location = artifact_location
         self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
         self._initialized = False
+        self._checkpoint_path_prefix = checkpoint_path_prefix

         from mlflow.tracking import MlflowClient

@@ -361,7 +363,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
             aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]

             # Artifact path on mlflow
-            artifact_path = Path(p).stem
+            artifact_path = Path(self._checkpoint_path_prefix) / Path(p).stem

             # Log the checkpoint
             self.experiment.log_artifact(self._run_id, p, artifact_path)
diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py
index c7f9dbe1fe2c6..8118349ea6721 100644
--- a/tests/tests_pytorch/loggers/test_mlflow.py
+++ b/tests/tests_pytorch/loggers/test_mlflow.py
@@ -427,3 +427,33 @@ def test_set_tracking_uri(mlflow_mock):
     mlflow_mock.set_tracking_uri.assert_not_called()
     _ = logger.experiment
     mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri")
+
+
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+def test_mlflow_log_model_with_checkpoint_path_prefix(mlflow_mock, tmp_path):
+    """Test that the logger creates the folders and files in the right place with a prefix."""
+    client = mlflow_mock.tracking.MlflowClient
+
+    # Get model, logger, trainer and train
+    model = BoringModel()
+    logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model="all", checkpoint_path_prefix="my_prefix")
+    logger = mock_mlflow_run_creation(logger, experiment_id="test-id")
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        logger=logger,
+        max_epochs=2,
+        limit_train_batches=3,
+        limit_val_batches=3,
+    )
+    trainer.fit(model)
+
+    # Checkpoint log
+    assert client.return_value.log_artifact.call_count == 2
+    # Metadata and aliases log
+    assert client.return_value.log_artifacts.call_count == 2
+
+    # Check that the prefix is used in the artifact path
+    for call in client.return_value.log_artifact.call_args_list:
+        args, _ = call
+        assert str(args[2]).startswith("my_prefix")
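A quick illustration of what the changed ``artifact_path`` composition in ``_scan_and_log_checkpoints`` does, since the behavior hinges on ``pathlib`` semantics. This snippet is a standalone sketch, not part of the diff; the checkpoint filename is a hypothetical example of the default ``ModelCheckpoint`` naming. Note the prefix only matters when ``log_model`` is ``True`` or ``"all"``, since otherwise no checkpoint is logged as an artifact.

.. code-block:: python

    from pathlib import Path

    # Mirrors the changed line in _scan_and_log_checkpoints:
    #     artifact_path = Path(self._checkpoint_path_prefix) / Path(p).stem
    checkpoint_file = "epoch=1-step=6.ckpt"  # hypothetical checkpoint filename

    # With checkpoint_path_prefix="my_prefix", checkpoints land under
    # "my_prefix/<checkpoint-stem>" in the MLflow run's artifact store.
    print(Path("my_prefix") / Path(checkpoint_file).stem)  # my_prefix/epoch=1-step=6

    # With the default empty prefix, Path("") / "x" collapses to just "x",
    # so the existing artifact layout is preserved for current users.
    print(Path("") / Path(checkpoint_file).stem)  # epoch=1-step=6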
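And a hedged sketch of how the prefixed checkpoints could be retrieved after training. ``MlflowClient.list_artifacts`` and ``mlflow.artifacts.download_artifacts`` are standard MLflow client calls; the tracking URI, run id, and prefix values below are assumptions carried over from the docs example, not values defined by this PR.

.. code-block:: python

    from mlflow.artifacts import download_artifacts
    from mlflow.tracking import MlflowClient

    # Assumed values: tracking URI and prefix mirror the docs example above;
    # the run id would normally come from MLFlowLogger.run_id after training.
    tracking_uri = "file:./ml-runs"
    run_id = "abc123"  # hypothetical; replace with mlf_logger.run_id

    client = MlflowClient(tracking_uri=tracking_uri)

    # Checkpoints logged with checkpoint_path_prefix="my_prefix" are listed under that folder.
    for artifact in client.list_artifacts(run_id, "my_prefix"):
        print(artifact.path)  # e.g. my_prefix/epoch=1-step=6

    # Download the prefixed folder to a local directory and print its path.
    local_dir = download_artifacts(
        run_id=run_id,
        artifact_path="my_prefix",
        tracking_uri=tracking_uri,
    )
    print(local_dir)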