From fbe4557ee3222a8492f17faf3160798b2d16e7a6 Mon Sep 17 00:00:00 2001 From: Kris Concepcion <84737625+kriscon-db@users.noreply.github.com> Date: Thu, 17 Oct 2024 21:52:57 -0400 Subject: [PATCH] Add GCS token refresh for log_artifacts (#13397) Signed-off-by: mlflow-automation Co-authored-by: mlflow-automation --- mlflow/store/artifact/artifact_repo.py | 18 ++++++ .../artifact/azure_data_lake_artifact_repo.py | 6 +- mlflow/store/artifact/cloud_artifact_repo.py | 18 ------ mlflow/store/artifact/gcs_artifact_repo.py | 40 +++++++++++-- .../artifact/optimized_s3_artifact_repo.py | 8 +-- .../artifact/presigned_url_artifact_repo.py | 5 +- mlflow/utils/_unity_catalog_utils.py | 14 ++++- .../test_unity_catalog_rest_store.py | 2 +- .../store/artifact/test_gcs_artifact_repo.py | 58 +++++++++++++++++++ .../test_presigned_url_artifact_repo.py | 4 +- ...test_unity_catalog_models_artifact_repo.py | 2 +- 11 files changed, 137 insertions(+), 38 deletions(-) diff --git a/mlflow/store/artifact/artifact_repo.py b/mlflow/store/artifact/artifact_repo.py index 5eb929ff05fd6..91779be8781f5 100644 --- a/mlflow/store/artifact/artifact_repo.py +++ b/mlflow/store/artifact/artifact_repo.py @@ -39,6 +39,24 @@ def _truncate_error(err: str, max_length: int = 10_000) -> str: return err[:half] + "\n\n*** Error message is too long, truncated ***\n\n" + err[-half:] +def _retry_with_new_creds(try_func, creds_func, orig_creds=None): + """ + Attempt the try_func with the original credentials (og_creds) if provided, or by generating the + credentials using creds_func. If the try_func throws, then try again with new credentials + provided by creds_func. + """ + try: + first_creds = creds_func() if orig_creds is None else orig_creds + return try_func(first_creds) + except Exception as e: + _logger.info( + f"Failed to complete request, possibly due to credential expiration (Error: {e})." + " Refreshing credentials and trying again..." + ) + new_creds = creds_func() + return try_func(new_creds) + + @developer_stable class ArtifactRepository: """ diff --git a/mlflow/store/artifact/azure_data_lake_artifact_repo.py b/mlflow/store/artifact/azure_data_lake_artifact_repo.py index 02291e139b49b..d8a1b613d2f73 100644 --- a/mlflow/store/artifact/azure_data_lake_artifact_repo.py +++ b/mlflow/store/artifact/azure_data_lake_artifact_repo.py @@ -15,11 +15,11 @@ ) from mlflow.exceptions import MlflowException from mlflow.protos.databricks_artifacts_pb2 import ArtifactCredentialInfo +from mlflow.store.artifact.artifact_repo import _retry_with_new_creds from mlflow.store.artifact.cloud_artifact_repo import ( CloudArtifactRepository, _complete_futures, _compute_num_chunks, - _retry_with_new_creds, ) @@ -125,7 +125,7 @@ def try_func(creds): file_client.upload_data(data=file, overwrite=True) _retry_with_new_creds( - try_func=try_func, creds_func=self._refresh_credentials, og_creds=self.fs_client + try_func=try_func, creds_func=self._refresh_credentials, orig_creds=self.fs_client ) def list_artifacts(self, path=None): @@ -166,7 +166,7 @@ def try_func(creds): file_client.download_file().readinto(file) _retry_with_new_creds( - try_func=try_func, creds_func=self._refresh_credentials, og_creds=self.fs_client + try_func=try_func, creds_func=self._refresh_credentials, orig_creds=self.fs_client ) def delete_artifacts(self, artifact_path=None): diff --git a/mlflow/store/artifact/cloud_artifact_repo.py b/mlflow/store/artifact/cloud_artifact_repo.py index c0a34b1e45f4f..65de4da42f9c5 100644 --- a/mlflow/store/artifact/cloud_artifact_repo.py +++ b/mlflow/store/artifact/cloud_artifact_repo.py @@ -85,24 +85,6 @@ def _complete_futures(futures_dict, file): return results, errors -def _retry_with_new_creds(try_func, creds_func, og_creds=None): - """ - Attempt the try_func with the original credentials (og_creds) if provided, or by generating the - credentials using creds_func. If the try_func throws, then try again with new credentials - provided by creds_func. - """ - try: - first_creds = creds_func() if og_creds is None else og_creds - return try_func(first_creds) - except Exception as e: - _logger.info( - "Failed to complete request, possibly due to credential expiration." - f" Refreshing credentials and trying again... (Error: {e})" - ) - new_creds = creds_func() - return try_func(new_creds) - - StagedArtifactUpload = namedtuple( "StagedArtifactUpload", [ diff --git a/mlflow/store/artifact/gcs_artifact_repo.py b/mlflow/store/artifact/gcs_artifact_repo.py index c5f5524f4a0ba..6a0b99818acae 100644 --- a/mlflow/store/artifact/gcs_artifact_repo.py +++ b/mlflow/store/artifact/gcs_artifact_repo.py @@ -19,7 +19,11 @@ MLFLOW_GCS_UPLOAD_CHUNK_SIZE, ) from mlflow.exceptions import _UnsupportedMultipartUploadException -from mlflow.store.artifact.artifact_repo import ArtifactRepository, MultipartUploadMixin +from mlflow.store.artifact.artifact_repo import ( + ArtifactRepository, + MultipartUploadMixin, + _retry_with_new_creds, +) from mlflow.utils.file_utils import relative_path_to_artifact_path GCSMPUArguments = namedtuple("GCSMPUArguments", ["transport", "url", "headers", "content_type"]) @@ -36,7 +40,7 @@ class GCSArtifactRepository(ArtifactRepository, MultipartUploadMixin): credentials as described in https://google-cloud.readthedocs.io/en/latest/core/auth.html """ - def __init__(self, artifact_uri, client=None): + def __init__(self, artifact_uri, client=None, credential_refresh_def=None): super().__init__(artifact_uri) from google.auth.exceptions import DefaultCredentialsError from google.cloud import storage as gcs_storage @@ -49,7 +53,8 @@ def __init__(self, artifact_uri, client=None): or MLFLOW_GCS_DEFAULT_TIMEOUT.get() or _DEFAULT_TIMEOUT ) - + # Method to use for refresh + self.credential_refresh_def = credential_refresh_def # If the user-supplied timeout environment variable value is -1, # use `None` for `self._GCS_DEFAULT_TIMEOUT` # to use indefinite timeout @@ -78,6 +83,18 @@ def parse_gcs_uri(uri): def _get_bucket(self, bucket): return self.client.bucket(bucket) + def _refresh_credentials(self): + from google.cloud.storage import Client + from google.oauth2.credentials import Credentials + + (bucket, _) = self.parse_gcs_uri(self.artifact_uri) + if not self.credential_refresh_def: + return self._get_bucket(bucket) + new_token = self.credential_refresh_def() + credentials = Credentials(new_token["oauth_token"]) + self.client = Client(project="mlflow", credentials=credentials) + return self._get_bucket(bucket) + def log_artifact(self, local_file, artifact_path=None): (bucket, dest_path) = self.parse_gcs_uri(self.artifact_uri) if artifact_path: @@ -92,9 +109,9 @@ def log_artifacts(self, local_dir, artifact_path=None): (bucket, dest_path) = self.parse_gcs_uri(self.artifact_uri) if artifact_path: dest_path = posixpath.join(dest_path, artifact_path) - gcs_bucket = self._get_bucket(bucket) local_dir = os.path.abspath(local_dir) + for root, _, filenames in os.walk(local_dir): upload_path = dest_path if root != local_dir: @@ -102,9 +119,20 @@ def log_artifacts(self, local_dir, artifact_path=None): rel_path = relative_path_to_artifact_path(rel_path) upload_path = posixpath.join(dest_path, rel_path) for f in filenames: + gcs_bucket = self._get_bucket(bucket) path = posixpath.join(upload_path, f) - gcs_bucket.blob(path, chunk_size=self._GCS_UPLOAD_CHUNK_SIZE).upload_from_filename( - os.path.join(root, f), timeout=self._GCS_DEFAULT_TIMEOUT + # For large models, we need to speculatively retry a credential refresh + # and throw if it still fails. We cannot use the built-in refresh because UC + # does not return a refresh token with the oauth token + file_name = os.path.join(root, f) + + def try_func(gcs_bucket): + gcs_bucket.blob( + path, chunk_size=self._GCS_UPLOAD_CHUNK_SIZE + ).upload_from_filename(file_name, timeout=self._GCS_DEFAULT_TIMEOUT) + + _retry_with_new_creds( + try_func=try_func, creds_func=self._refresh_credentials, orig_creds=gcs_bucket ) def list_artifacts(self, path=None): diff --git a/mlflow/store/artifact/optimized_s3_artifact_repo.py b/mlflow/store/artifact/optimized_s3_artifact_repo.py index 545db71a7fb6d..5ae9b23c98bf2 100644 --- a/mlflow/store/artifact/optimized_s3_artifact_repo.py +++ b/mlflow/store/artifact/optimized_s3_artifact_repo.py @@ -13,11 +13,11 @@ ) from mlflow.exceptions import MlflowException from mlflow.protos.databricks_artifacts_pb2 import ArtifactCredentialInfo +from mlflow.store.artifact.artifact_repo import _retry_with_new_creds from mlflow.store.artifact.cloud_artifact_repo import ( CloudArtifactRepository, _complete_futures, _compute_num_chunks, - _retry_with_new_creds, _validate_chunk_size_aws, ) from mlflow.store.artifact.s3_artifact_repo import _get_s3_client @@ -163,7 +163,7 @@ def try_func(creds): creds.upload_file(Filename=local_file, Bucket=bucket, Key=key, ExtraArgs=extra_args) _retry_with_new_creds( - try_func=try_func, creds_func=self._refresh_credentials, og_creds=s3_client + try_func=try_func, creds_func=self._refresh_credentials, orig_creds=s3_client ) def log_artifact(self, local_file, artifact_path=None): @@ -223,7 +223,7 @@ def try_func(creds): return response.headers["ETag"] return _retry_with_new_creds( - try_func=try_func, creds_func=self._refresh_credentials, og_creds=s3_client + try_func=try_func, creds_func=self._refresh_credentials, orig_creds=s3_client ) try: @@ -331,7 +331,7 @@ def try_func(creds): creds.download_file(self.bucket, s3_full_path, local_path) _retry_with_new_creds( - try_func=try_func, creds_func=self._refresh_credentials, og_creds=s3_client + try_func=try_func, creds_func=self._refresh_credentials, orig_creds=s3_client ) def delete_artifacts(self, artifact_path=None): diff --git a/mlflow/store/artifact/presigned_url_artifact_repo.py b/mlflow/store/artifact/presigned_url_artifact_repo.py index dc45c2d163495..280995484ef69 100644 --- a/mlflow/store/artifact/presigned_url_artifact_repo.py +++ b/mlflow/store/artifact/presigned_url_artifact_repo.py @@ -12,7 +12,8 @@ FilesystemService, ListDirectoryResponse, ) -from mlflow.store.artifact.cloud_artifact_repo import CloudArtifactRepository, _retry_with_new_creds +from mlflow.store.artifact.artifact_repo import _retry_with_new_creds +from mlflow.store.artifact.cloud_artifact_repo import CloudArtifactRepository from mlflow.utils.file_utils import download_file_using_http_uri from mlflow.utils.proto_json_utils import message_to_json from mlflow.utils.request_utils import augmented_raise_for_status, cloud_storage_http_request @@ -89,7 +90,7 @@ def creds_func(): return self._get_write_credential_infos(remote_file_paths=[artifact_file_path])[0] _retry_with_new_creds( - try_func=try_func, creds_func=creds_func, og_creds=cloud_credential_info + try_func=try_func, creds_func=creds_func, orig_creds=cloud_credential_info ) def list_artifacts(self, path=""): diff --git a/mlflow/utils/_unity_catalog_utils.py b/mlflow/utils/_unity_catalog_utils.py index 2436ebdf2561d..c5af7c45c7584 100644 --- a/mlflow/utils/_unity_catalog_utils.py +++ b/mlflow/utils/_unity_catalog_utils.py @@ -220,8 +220,20 @@ def azure_credential_refresh(): from mlflow.store.artifact.gcs_artifact_repo import GCSArtifactRepository credentials = Credentials(scoped_token.gcp_oauth_token.oauth_token) + + def gcp_credential_refresh(): + new_scoped_token = base_credential_refresh_def() + new_gcp_creds = new_scoped_token.gcp_oauth_token + return { + "oauth_token": new_gcp_creds.oauth_token, + } + client = Client(project="mlflow", credentials=credentials) - return GCSArtifactRepository(artifact_uri=storage_location, client=client) + return GCSArtifactRepository( + artifact_uri=storage_location, + client=client, + credential_refresh_def=gcp_credential_refresh, + ) elif credential_type == "r2_temp_credentials": from mlflow.store.artifact.r2_artifact_repo import R2ArtifactRepository diff --git a/tests/store/_unity_catalog/model_registry/test_unity_catalog_rest_store.py b/tests/store/_unity_catalog/model_registry/test_unity_catalog_rest_store.py index 1695ae4079ff5..6160ebc0f5fa7 100644 --- a/tests/store/_unity_catalog/model_registry/test_unity_catalog_rest_store.py +++ b/tests/store/_unity_catalog/model_registry/test_unity_catalog_rest_store.py @@ -1449,7 +1449,7 @@ def test_create_model_version_gcp(store, local_model_dir, create_args): store.create_model_version(**create_kwargs) # Verify that gcs artifact repo mock was called with expected args gcs_artifact_repo_class_mock.assert_called_once_with( - artifact_uri=storage_location, client=ANY + artifact_uri=storage_location, client=ANY, credential_refresh_def=ANY ) mock_gcs_repo.log_artifacts.assert_called_once_with(local_dir=ANY, artifact_path="") gcs_client_args = gcs_client_class_mock.call_args_list[0] diff --git a/tests/store/artifact/test_gcs_artifact_repo.py b/tests/store/artifact/test_gcs_artifact_repo.py index 5e0b8d982b80a..9d7043adbc80a 100644 --- a/tests/store/artifact/test_gcs_artifact_repo.py +++ b/tests/store/artifact/test_gcs_artifact_repo.py @@ -483,3 +483,61 @@ def test_abort_multipart_upload(mock_client): f"{gcs_base_url}/{bucket_name}/{artifact_root_path}/{file_name}?uploadId={upload_id}", ) assert kwargs["data"] is None + + +@pytest.mark.parametrize("throw", [True, False]) +def test_retryable_log_artifacts(throw, tmp_path): + with mock.patch("google.cloud.storage.Client") as mock_gcs_client_factory, mock.patch( + "google.oauth2.credentials.Credentials" + ) as mock_gcs_credentials_factory: + gcs_client_mock = mock.Mock() + gcs_bucket_mock = mock.Mock() + gcs_client_mock.bucket.return_value = gcs_bucket_mock + + gcs_refreshed_client_mock = mock.Mock() + gcs_refreshed_bucket_mock = mock.Mock() + gcs_refreshed_client_mock.bucket.return_value = gcs_refreshed_bucket_mock + mock_gcs_client_factory.return_value = gcs_refreshed_client_mock + + def exception_thrown_side_effect_func(*args, **kwargs): + if throw: + raise Exception("Test Exception") + return None + + def success_side_effect_func(*args, **kwargs): + return None + + def creds_func(): + return {"oauth_token": "new_creds"} + + gcs_bucket_mock.blob.return_value.upload_from_filename.side_effect = ( + exception_thrown_side_effect_func + ) + gcs_refreshed_bucket_mock.blob.return_value.upload_from_filename.side_effect = ( + success_side_effect_func + ) + + repo = GCSArtifactRepository( + artifact_uri="gs://test_bucket/test_root/", + client=gcs_client_mock, + credential_refresh_def=creds_func, + ) + + data = tmp_path.joinpath("data") + data.mkdir() + subd = data.joinpath("subdir") + subd.mkdir() + subd.joinpath("a.txt").write_text("A") + + repo.log_artifacts(subd) + + if throw: + gcs_bucket_mock.blob.assert_called_once() + gcs_refreshed_bucket_mock.blob.assert_called_once() + mock_gcs_client_factory.assert_called_once() + mock_gcs_credentials_factory.assert_called_once() + else: + gcs_bucket_mock.blob.assert_called_once() + gcs_refreshed_bucket_mock.blob.assert_not_called() + mock_gcs_client_factory.assert_not_called() + mock_gcs_credentials_factory.assert_not_called() diff --git a/tests/store/artifact/test_presigned_url_artifact_repo.py b/tests/store/artifact/test_presigned_url_artifact_repo.py index cef3343e2707a..9c662dbb4f656 100644 --- a/tests/store/artifact/test_presigned_url_artifact_repo.py +++ b/tests/store/artifact/test_presigned_url_artifact_repo.py @@ -19,7 +19,7 @@ HttpHeader, ListDirectoryResponse, ) -from mlflow.store.artifact.cloud_artifact_repo import _retry_with_new_creds +from mlflow.store.artifact.artifact_repo import _retry_with_new_creds from mlflow.store.artifact.presigned_url_artifact_repo import ( DIRECTORIES_ENDPOINT, FILESYSTEM_METHOD_TO_INFO, @@ -342,7 +342,7 @@ def try_func(creds): mock_creds = mock.Mock(side_effect=creds_func) mock_func = mock.Mock(side_effect=try_func) if use_og_creds: - _retry_with_new_creds(try_func=mock_func, creds_func=mock_creds, og_creds=credentials) + _retry_with_new_creds(try_func=mock_func, creds_func=mock_creds, orig_creds=credentials) else: _retry_with_new_creds(try_func=mock_func, creds_func=mock_creds) diff --git a/tests/store/artifact/test_unity_catalog_models_artifact_repo.py b/tests/store/artifact/test_unity_catalog_models_artifact_repo.py index 514e7a698229b..9f1299c732cac 100644 --- a/tests/store/artifact/test_unity_catalog_models_artifact_repo.py +++ b/tests/store/artifact/test_unity_catalog_models_artifact_repo.py @@ -231,7 +231,7 @@ def test_uc_models_artifact_repo_download_artifacts_uses_temporary_creds_gcp(mon ) assert models_repo.download_artifacts("artifact_path", "dst_path") == fake_local_path gcs_artifact_repo_class_mock.assert_called_once_with( - artifact_uri=artifact_location, client=ANY + artifact_uri=artifact_location, client=ANY, credential_refresh_def=ANY ) mock_gcs_repo.download_artifacts.assert_called_once_with("artifact_path", "dst_path") gcs_client_args = gcs_client_class_mock.call_args_list[0]