From 0d1b929d2e65c7c73f827da4e3e4959ecdda62d7 Mon Sep 17 00:00:00 2001 From: Clark Hollar <79466504+clarkh-ncino@users.noreply.github.com> Date: Tue, 2 Jan 2024 17:20:12 -0500 Subject: [PATCH] Adds tagging for sagemaker endpoints and sagemaker config. Issue #9159 (#9310) Signed-off-by: Clark Hollar --- mlflow/sagemaker/__init__.py | 35 ++++++++++++++--- tests/sagemaker/mock/__init__.py | 19 +++++---- .../test_sagemaker_deployment_client.py | 39 +++++++++++++++++-- 3 files changed, 76 insertions(+), 17 deletions(-) diff --git a/mlflow/sagemaker/__init__.py b/mlflow/sagemaker/__init__.py index b647dc2405c1d2..9405fdd69506e8 100644 --- a/mlflow/sagemaker/__init__.py +++ b/mlflow/sagemaker/__init__.py @@ -11,7 +11,7 @@ import time import urllib.parse from subprocess import Popen -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import mlflow import mlflow.version @@ -50,6 +50,7 @@ DEFAULT_REGION_NAME = "us-west-2" SAGEMAKER_SERVING_ENVIRONMENT = "SageMaker" +SAGEMAKER_APP_NAME_TAG_KEY = "app_name" _logger = logging.getLogger(__name__) @@ -1346,6 +1347,26 @@ def _get_sagemaker_config_name(endpoint_name): return f"{endpoint_name}-config-{get_unique_resource_id()}" +def _get_sagemaker_config_tags(endpoint_name): + return [{"Key": SAGEMAKER_APP_NAME_TAG_KEY, "Value": endpoint_name}] + + +def _prepare_sagemaker_tags( + config_tags: List[Dict[str, str]], + sagemaker_tags: Optional[Dict[str, str]] = None, +): + if not sagemaker_tags: + return config_tags + + if SAGEMAKER_APP_NAME_TAG_KEY in sagemaker_tags: + raise MlflowException.invalid_parameter_value( + f"Duplicate tag provided for '{SAGEMAKER_APP_NAME_TAG_KEY}'" + ) + parsed = [{"Key": key, "Value": str(value)} for key, value in sagemaker_tags.items()] + + return config_tags + parsed + + def _create_sagemaker_transform_job( job_name, model_name, @@ -1559,10 +1580,12 @@ def _create_sagemaker_endpoint( production_variant["InitialInstanceCount"] = instance_count config_name = _get_sagemaker_config_name(endpoint_name) + config_tags = _get_sagemaker_config_tags(endpoint_name) + tags_list = _prepare_sagemaker_tags(config_tags, tags) endpoint_config_kwargs = { "EndpointConfigName": config_name, "ProductionVariants": [production_variant], - "Tags": [{"Key": "app_name", "Value": endpoint_name}], + "Tags": config_tags, } if async_inference_config: endpoint_config_kwargs["AsyncInferenceConfig"] = async_inference_config @@ -1576,7 +1599,7 @@ def _create_sagemaker_endpoint( endpoint_response = sage_client.create_endpoint( EndpointName=endpoint_name, EndpointConfigName=config_name, - Tags=[], + Tags=tags_list or [], ) _logger.info("Created endpoint with arn: %s", endpoint_response["EndpointArn"]) @@ -1656,7 +1679,7 @@ def _update_sagemaker_endpoint( For more information, see https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DataCaptureConfig.html. Defaults to ``None``. :param env: A dictionary of environment variables to set for the model. - :param tags: A dictionary of tags to apply to the endpoint. + :param tags: A dictionary of tags to apply to the endpoint configuration. """ if mode not in [DEPLOYMENT_MODE_ADD, DEPLOYMENT_MODE_REPLACE]: msg = f"Invalid mode `{mode}` for deployment to a pre-existing application" @@ -1714,11 +1737,11 @@ def _update_sagemaker_endpoint( # Create the new endpoint configuration and update the endpoint # to adopt the new configuration new_config_name = _get_sagemaker_config_name(endpoint_name) - # This is the hardcoded config for endpoint + config_tags = _get_sagemaker_config_tags(endpoint_name) endpoint_config_kwargs = { "EndpointConfigName": new_config_name, "ProductionVariants": production_variants, - "Tags": [{"Key": "app_name", "Value": endpoint_name}], + "Tags": config_tags, } if async_inference_config: endpoint_config_kwargs["AsyncInferenceConfig"] = async_inference_config diff --git a/tests/sagemaker/mock/__init__.py b/tests/sagemaker/mock/__init__.py index caeec377441a2c..ea7e841043ee75 100644 --- a/tests/sagemaker/mock/__init__.py +++ b/tests/sagemaker/mock/__init__.py @@ -187,11 +187,12 @@ def list_tags(self): Handler for the SageMaker "ListTags" API call documented here: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ListTags.html """ - model_arn = self.request_params["ResourceArn"] - + arn = self.request_params["ResourceArn"] + sagemaker_resource = ( + "models" if "model" in arn else "endpoints" if "endpoint" in arn else None + ) results = self.sagemaker_backend.list_tags( - resource_arn=model_arn, - region_name=self.region, + resource_arn=arn, region_name=self.region, resource_type=sagemaker_resource ) return json.dumps({"Tags": results, "NextToken": None}) @@ -511,13 +512,17 @@ def list_models(self): summaries.append(summary) return summaries - def list_tags(self, resource_arn, region_name): # pylint: disable=unused-argument + def list_tags( + self, resource_arn, region_name, resource_type + ): # pylint: disable=unused-argument """ Modifies backend state during calls to the SageMaker "ListTags" API https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ListTags.html """ - model = next(model for model in self.models.values() if model.arn == resource_arn) - return model.resource.tags + resource_values = getattr(self, resource_type).values() + for sagemaker_resource in resource_values: + if sagemaker_resource.arn == resource_arn: + return sagemaker_resource.resource.tags def create_model( self, model_name, primary_container, execution_role_arn, tags, region_name, vpc_config=None diff --git a/tests/sagemaker/test_sagemaker_deployment_client.py b/tests/sagemaker/test_sagemaker_deployment_client.py index 970c88e82875dd..11f36a7c2bcd7c 100644 --- a/tests/sagemaker/test_sagemaker_deployment_client.py +++ b/tests/sagemaker/test_sagemaker_deployment_client.py @@ -508,10 +508,39 @@ def test_create_deployment_create_sagemaker_and_s3_resources_with_expected_tags_ model_name = endpoint_production_variants[0]["VariantName"] description = sagemaker_client.describe_model(ModelName=model_name) - tags = sagemaker_client.list_tags(ResourceArn=description["ModelArn"]) + model_tags = sagemaker_client.list_tags(ResourceArn=description["ModelArn"]) + endpoint_tags = sagemaker_client.list_tags(ResourceArn=endpoint_description["EndpointArn"]) # Extra tags exist besides the ones we set, so avoid strict equality - assert all(tag in tags["Tags"] for tag in expected_tags) + assert all(tag in model_tags["Tags"] for tag in expected_tags) + assert all(tag in endpoint_tags["Tags"] for tag in expected_tags) + + +def test_prepare_sagemaker_tags_without_custom_tags(): + config_tags = [{"Key": "tag1", "Value": "value1"}] + tags = mfs._prepare_sagemaker_tags(config_tags, None) + assert tags == config_tags + + +def test_prepare_sagemaker_tags_when_custom_tags_are_added(): + config_tags = [{"Key": "tag1", "Value": "value1"}] + sagemaker_tags = {"tag2": "value2", "tag3": "123"} + expected_tags = [ + {"Key": "tag1", "Value": "value1"}, + {"Key": "tag2", "Value": "value2"}, + {"Key": "tag3", "Value": "123"}, + ] + tags = mfs._prepare_sagemaker_tags(config_tags, sagemaker_tags) + assert tags == expected_tags + + +def test_prepare_sagemaker_tags_duplicate_key_raises_exception(): + config_tags = [{"Key": "app_name", "Value": "a_cool_name"}] + sagemaker_tags = {"app_name": "a_cooler_name", "tag2": "value2", "tag3": "123"} + match = "Duplicate tag provided for 'app_name'" + with pytest.raises(MlflowException, match=match) as exc: + mfs._prepare_sagemaker_tags(config_tags, sagemaker_tags) + assert exc.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) @pytest.mark.parametrize("proxies_enabled", [True, False]) @@ -673,10 +702,12 @@ def test_deploy_cli_creates_sagemaker_and_s3_resources_with_expected_tags_from_l model_name = endpoint_production_variants[0]["VariantName"] description = sagemaker_client.describe_model(ModelName=model_name) - tags = sagemaker_client.list_tags(ResourceArn=description["ModelArn"]) + model_tags = sagemaker_client.list_tags(ResourceArn=description["ModelArn"]) + endpoint_tags = sagemaker_client.list_tags(ResourceArn=endpoint_description["EndpointArn"]) # Extra tags exist besides the ones we set, so avoid strict equality - assert all(tag in tags["Tags"] for tag in expected_tags) + assert all(tag in model_tags["Tags"] for tag in expected_tags) + assert all(tag in endpoint_tags["Tags"] for tag in expected_tags) @pytest.mark.parametrize("proxies_enabled", [True, False])