Skip to content

Commit

Permalink
Add GCS token refresh for log_artifacts (mlflow#13397)
Browse files Browse the repository at this point in the history
Signed-off-by: mlflow-automation <[email protected]>
Co-authored-by: mlflow-automation <[email protected]>
  • Loading branch information
2 people authored and serena-ruan committed Oct 25, 2024
1 parent 29ad1b9 commit da13335
Show file tree
Hide file tree
Showing 11 changed files with 137 additions and 38 deletions.
18 changes: 18 additions & 0 deletions mlflow/store/artifact/artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,24 @@ def _truncate_error(err: str, max_length: int = 10_000) -> str:
return err[:half] + "\n\n*** Error message is too long, truncated ***\n\n" + err[-half:]


def _retry_with_new_creds(try_func, creds_func, orig_creds=None):
"""
Attempt the try_func with the original credentials (og_creds) if provided, or by generating the
credentials using creds_func. If the try_func throws, then try again with new credentials
provided by creds_func.
"""
try:
first_creds = creds_func() if orig_creds is None else orig_creds
return try_func(first_creds)
except Exception as e:
_logger.info(
f"Failed to complete request, possibly due to credential expiration (Error: {e})."
" Refreshing credentials and trying again..."
)
new_creds = creds_func()
return try_func(new_creds)


@developer_stable
class ArtifactRepository:
"""
Expand Down
6 changes: 3 additions & 3 deletions mlflow/store/artifact/azure_data_lake_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
)
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_artifacts_pb2 import ArtifactCredentialInfo
from mlflow.store.artifact.artifact_repo import _retry_with_new_creds
from mlflow.store.artifact.cloud_artifact_repo import (
CloudArtifactRepository,
_complete_futures,
_compute_num_chunks,
_retry_with_new_creds,
)


Expand Down Expand Up @@ -125,7 +125,7 @@ def try_func(creds):
file_client.upload_data(data=file, overwrite=True)

_retry_with_new_creds(
try_func=try_func, creds_func=self._refresh_credentials, og_creds=self.fs_client
try_func=try_func, creds_func=self._refresh_credentials, orig_creds=self.fs_client
)

def list_artifacts(self, path=None):
Expand Down Expand Up @@ -166,7 +166,7 @@ def try_func(creds):
file_client.download_file().readinto(file)

_retry_with_new_creds(
try_func=try_func, creds_func=self._refresh_credentials, og_creds=self.fs_client
try_func=try_func, creds_func=self._refresh_credentials, orig_creds=self.fs_client
)

def delete_artifacts(self, artifact_path=None):
Expand Down
18 changes: 0 additions & 18 deletions mlflow/store/artifact/cloud_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,24 +85,6 @@ def _complete_futures(futures_dict, file):
return results, errors


def _retry_with_new_creds(try_func, creds_func, og_creds=None):
"""
Attempt the try_func with the original credentials (og_creds) if provided, or by generating the
credentials using creds_func. If the try_func throws, then try again with new credentials
provided by creds_func.
"""
try:
first_creds = creds_func() if og_creds is None else og_creds
return try_func(first_creds)
except Exception as e:
_logger.info(
"Failed to complete request, possibly due to credential expiration."
f" Refreshing credentials and trying again... (Error: {e})"
)
new_creds = creds_func()
return try_func(new_creds)


StagedArtifactUpload = namedtuple(
"StagedArtifactUpload",
[
Expand Down
40 changes: 34 additions & 6 deletions mlflow/store/artifact/gcs_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
MLFLOW_GCS_UPLOAD_CHUNK_SIZE,
)
from mlflow.exceptions import _UnsupportedMultipartUploadException
from mlflow.store.artifact.artifact_repo import ArtifactRepository, MultipartUploadMixin
from mlflow.store.artifact.artifact_repo import (
ArtifactRepository,
MultipartUploadMixin,
_retry_with_new_creds,
)
from mlflow.utils.file_utils import relative_path_to_artifact_path

GCSMPUArguments = namedtuple("GCSMPUArguments", ["transport", "url", "headers", "content_type"])
Expand All @@ -36,7 +40,7 @@ class GCSArtifactRepository(ArtifactRepository, MultipartUploadMixin):
credentials as described in https://google-cloud.readthedocs.io/en/latest/core/auth.html
"""

def __init__(self, artifact_uri, client=None):
def __init__(self, artifact_uri, client=None, credential_refresh_def=None):
super().__init__(artifact_uri)
from google.auth.exceptions import DefaultCredentialsError
from google.cloud import storage as gcs_storage
Expand All @@ -49,7 +53,8 @@ def __init__(self, artifact_uri, client=None):
or MLFLOW_GCS_DEFAULT_TIMEOUT.get()
or _DEFAULT_TIMEOUT
)

# Method to use for refresh
self.credential_refresh_def = credential_refresh_def
# If the user-supplied timeout environment variable value is -1,
# use `None` for `self._GCS_DEFAULT_TIMEOUT`
# to use indefinite timeout
Expand Down Expand Up @@ -78,6 +83,18 @@ def parse_gcs_uri(uri):
def _get_bucket(self, bucket):
return self.client.bucket(bucket)

def _refresh_credentials(self):
from google.cloud.storage import Client
from google.oauth2.credentials import Credentials

(bucket, _) = self.parse_gcs_uri(self.artifact_uri)
if not self.credential_refresh_def:
return self._get_bucket(bucket)
new_token = self.credential_refresh_def()
credentials = Credentials(new_token["oauth_token"])
self.client = Client(project="mlflow", credentials=credentials)
return self._get_bucket(bucket)

def log_artifact(self, local_file, artifact_path=None):
(bucket, dest_path) = self.parse_gcs_uri(self.artifact_uri)
if artifact_path:
Expand All @@ -92,19 +109,30 @@ def log_artifacts(self, local_dir, artifact_path=None):
(bucket, dest_path) = self.parse_gcs_uri(self.artifact_uri)
if artifact_path:
dest_path = posixpath.join(dest_path, artifact_path)
gcs_bucket = self._get_bucket(bucket)

local_dir = os.path.abspath(local_dir)

for root, _, filenames in os.walk(local_dir):
upload_path = dest_path
if root != local_dir:
rel_path = os.path.relpath(root, local_dir)
rel_path = relative_path_to_artifact_path(rel_path)
upload_path = posixpath.join(dest_path, rel_path)
for f in filenames:
gcs_bucket = self._get_bucket(bucket)
path = posixpath.join(upload_path, f)
gcs_bucket.blob(path, chunk_size=self._GCS_UPLOAD_CHUNK_SIZE).upload_from_filename(
os.path.join(root, f), timeout=self._GCS_DEFAULT_TIMEOUT
# For large models, we need to speculatively retry a credential refresh
# and throw if it still fails. We cannot use the built-in refresh because UC
# does not return a refresh token with the oauth token
file_name = os.path.join(root, f)

def try_func(gcs_bucket):
gcs_bucket.blob(
path, chunk_size=self._GCS_UPLOAD_CHUNK_SIZE
).upload_from_filename(file_name, timeout=self._GCS_DEFAULT_TIMEOUT)

_retry_with_new_creds(
try_func=try_func, creds_func=self._refresh_credentials, orig_creds=gcs_bucket
)

def list_artifacts(self, path=None):
Expand Down
8 changes: 4 additions & 4 deletions mlflow/store/artifact/optimized_s3_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
)
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_artifacts_pb2 import ArtifactCredentialInfo
from mlflow.store.artifact.artifact_repo import _retry_with_new_creds
from mlflow.store.artifact.cloud_artifact_repo import (
CloudArtifactRepository,
_complete_futures,
_compute_num_chunks,
_retry_with_new_creds,
_validate_chunk_size_aws,
)
from mlflow.store.artifact.s3_artifact_repo import _get_s3_client
Expand Down Expand Up @@ -163,7 +163,7 @@ def try_func(creds):
creds.upload_file(Filename=local_file, Bucket=bucket, Key=key, ExtraArgs=extra_args)

_retry_with_new_creds(
try_func=try_func, creds_func=self._refresh_credentials, og_creds=s3_client
try_func=try_func, creds_func=self._refresh_credentials, orig_creds=s3_client
)

def log_artifact(self, local_file, artifact_path=None):
Expand Down Expand Up @@ -223,7 +223,7 @@ def try_func(creds):
return response.headers["ETag"]

return _retry_with_new_creds(
try_func=try_func, creds_func=self._refresh_credentials, og_creds=s3_client
try_func=try_func, creds_func=self._refresh_credentials, orig_creds=s3_client
)

try:
Expand Down Expand Up @@ -331,7 +331,7 @@ def try_func(creds):
creds.download_file(self.bucket, s3_full_path, local_path)

_retry_with_new_creds(
try_func=try_func, creds_func=self._refresh_credentials, og_creds=s3_client
try_func=try_func, creds_func=self._refresh_credentials, orig_creds=s3_client
)

def delete_artifacts(self, artifact_path=None):
Expand Down
5 changes: 3 additions & 2 deletions mlflow/store/artifact/presigned_url_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
FilesystemService,
ListDirectoryResponse,
)
from mlflow.store.artifact.cloud_artifact_repo import CloudArtifactRepository, _retry_with_new_creds
from mlflow.store.artifact.artifact_repo import _retry_with_new_creds
from mlflow.store.artifact.cloud_artifact_repo import CloudArtifactRepository
from mlflow.utils.file_utils import download_file_using_http_uri
from mlflow.utils.proto_json_utils import message_to_json
from mlflow.utils.request_utils import augmented_raise_for_status, cloud_storage_http_request
Expand Down Expand Up @@ -89,7 +90,7 @@ def creds_func():
return self._get_write_credential_infos(remote_file_paths=[artifact_file_path])[0]

_retry_with_new_creds(
try_func=try_func, creds_func=creds_func, og_creds=cloud_credential_info
try_func=try_func, creds_func=creds_func, orig_creds=cloud_credential_info
)

def list_artifacts(self, path=""):
Expand Down
14 changes: 13 additions & 1 deletion mlflow/utils/_unity_catalog_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,20 @@ def azure_credential_refresh():
from mlflow.store.artifact.gcs_artifact_repo import GCSArtifactRepository

credentials = Credentials(scoped_token.gcp_oauth_token.oauth_token)

def gcp_credential_refresh():
new_scoped_token = base_credential_refresh_def()
new_gcp_creds = new_scoped_token.gcp_oauth_token
return {
"oauth_token": new_gcp_creds.oauth_token,
}

client = Client(project="mlflow", credentials=credentials)
return GCSArtifactRepository(artifact_uri=storage_location, client=client)
return GCSArtifactRepository(
artifact_uri=storage_location,
client=client,
credential_refresh_def=gcp_credential_refresh,
)
elif credential_type == "r2_temp_credentials":
from mlflow.store.artifact.r2_artifact_repo import R2ArtifactRepository

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1449,7 +1449,7 @@ def test_create_model_version_gcp(store, local_model_dir, create_args):
store.create_model_version(**create_kwargs)
# Verify that gcs artifact repo mock was called with expected args
gcs_artifact_repo_class_mock.assert_called_once_with(
artifact_uri=storage_location, client=ANY
artifact_uri=storage_location, client=ANY, credential_refresh_def=ANY
)
mock_gcs_repo.log_artifacts.assert_called_once_with(local_dir=ANY, artifact_path="")
gcs_client_args = gcs_client_class_mock.call_args_list[0]
Expand Down
58 changes: 58 additions & 0 deletions tests/store/artifact/test_gcs_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,3 +483,61 @@ def test_abort_multipart_upload(mock_client):
f"{gcs_base_url}/{bucket_name}/{artifact_root_path}/{file_name}?uploadId={upload_id}",
)
assert kwargs["data"] is None


@pytest.mark.parametrize("throw", [True, False])
def test_retryable_log_artifacts(throw, tmp_path):
with mock.patch("google.cloud.storage.Client") as mock_gcs_client_factory, mock.patch(
"google.oauth2.credentials.Credentials"
) as mock_gcs_credentials_factory:
gcs_client_mock = mock.Mock()
gcs_bucket_mock = mock.Mock()
gcs_client_mock.bucket.return_value = gcs_bucket_mock

gcs_refreshed_client_mock = mock.Mock()
gcs_refreshed_bucket_mock = mock.Mock()
gcs_refreshed_client_mock.bucket.return_value = gcs_refreshed_bucket_mock
mock_gcs_client_factory.return_value = gcs_refreshed_client_mock

def exception_thrown_side_effect_func(*args, **kwargs):
if throw:
raise Exception("Test Exception")
return None

def success_side_effect_func(*args, **kwargs):
return None

def creds_func():
return {"oauth_token": "new_creds"}

gcs_bucket_mock.blob.return_value.upload_from_filename.side_effect = (
exception_thrown_side_effect_func
)
gcs_refreshed_bucket_mock.blob.return_value.upload_from_filename.side_effect = (
success_side_effect_func
)

repo = GCSArtifactRepository(
artifact_uri="gs://test_bucket/test_root/",
client=gcs_client_mock,
credential_refresh_def=creds_func,
)

data = tmp_path.joinpath("data")
data.mkdir()
subd = data.joinpath("subdir")
subd.mkdir()
subd.joinpath("a.txt").write_text("A")

repo.log_artifacts(subd)

if throw:
gcs_bucket_mock.blob.assert_called_once()
gcs_refreshed_bucket_mock.blob.assert_called_once()
mock_gcs_client_factory.assert_called_once()
mock_gcs_credentials_factory.assert_called_once()
else:
gcs_bucket_mock.blob.assert_called_once()
gcs_refreshed_bucket_mock.blob.assert_not_called()
mock_gcs_client_factory.assert_not_called()
mock_gcs_credentials_factory.assert_not_called()
4 changes: 2 additions & 2 deletions tests/store/artifact/test_presigned_url_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
HttpHeader,
ListDirectoryResponse,
)
from mlflow.store.artifact.cloud_artifact_repo import _retry_with_new_creds
from mlflow.store.artifact.artifact_repo import _retry_with_new_creds
from mlflow.store.artifact.presigned_url_artifact_repo import (
DIRECTORIES_ENDPOINT,
FILESYSTEM_METHOD_TO_INFO,
Expand Down Expand Up @@ -342,7 +342,7 @@ def try_func(creds):
mock_creds = mock.Mock(side_effect=creds_func)
mock_func = mock.Mock(side_effect=try_func)
if use_og_creds:
_retry_with_new_creds(try_func=mock_func, creds_func=mock_creds, og_creds=credentials)
_retry_with_new_creds(try_func=mock_func, creds_func=mock_creds, orig_creds=credentials)
else:
_retry_with_new_creds(try_func=mock_func, creds_func=mock_creds)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def test_uc_models_artifact_repo_download_artifacts_uses_temporary_creds_gcp(mon
)
assert models_repo.download_artifacts("artifact_path", "dst_path") == fake_local_path
gcs_artifact_repo_class_mock.assert_called_once_with(
artifact_uri=artifact_location, client=ANY
artifact_uri=artifact_location, client=ANY, credential_refresh_def=ANY
)
mock_gcs_repo.download_artifacts.assert_called_once_with("artifact_path", "dst_path")
gcs_client_args = gcs_client_class_mock.call_args_list[0]
Expand Down

0 comments on commit da13335

Please sign in to comment.