Skip to content

Commit

Permalink
Adds tagging for sagemaker endpoints and sagemaker config. Issue mlflow#9159 (mlflow#9310)
Browse files Browse the repository at this point in the history

Signed-off-by: Clark Hollar <[email protected]>
  • Loading branch information
clarkh-ncino authored and B-Step62 committed Jan 9, 2024
1 parent d51e71d commit d3f6faa
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 17 deletions.
35 changes: 29 additions & 6 deletions mlflow/sagemaker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import time
import urllib.parse
from subprocess import Popen
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

import mlflow
import mlflow.version
Expand Down Expand Up @@ -50,6 +50,7 @@

DEFAULT_REGION_NAME = "us-west-2"
SAGEMAKER_SERVING_ENVIRONMENT = "SageMaker"
SAGEMAKER_APP_NAME_TAG_KEY = "app_name"

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1346,6 +1347,26 @@ def _get_sagemaker_config_name(endpoint_name):
return f"{endpoint_name}-config-{get_unique_resource_id()}"


def _get_sagemaker_config_tags(endpoint_name):
    """
    Build the default tag list for a SageMaker endpoint configuration.

    :param endpoint_name: Name of the endpoint; used as the value of the tag.
    :return: A single-element list containing a ``{"Key": ..., "Value": ...}`` dictionary
             keyed by ``SAGEMAKER_APP_NAME_TAG_KEY``.
    """
    app_name_tag = {"Key": SAGEMAKER_APP_NAME_TAG_KEY, "Value": endpoint_name}
    return [app_name_tag]


def _prepare_sagemaker_tags(
config_tags: List[Dict[str, str]],
sagemaker_tags: Optional[Dict[str, str]] = None,
):
if not sagemaker_tags:
return config_tags

if SAGEMAKER_APP_NAME_TAG_KEY in sagemaker_tags:
raise MlflowException.invalid_parameter_value(
f"Duplicate tag provided for '{SAGEMAKER_APP_NAME_TAG_KEY}'"
)
parsed = [{"Key": key, "Value": str(value)} for key, value in sagemaker_tags.items()]

return config_tags + parsed


def _create_sagemaker_transform_job(
job_name,
model_name,
Expand Down Expand Up @@ -1559,10 +1580,12 @@ def _create_sagemaker_endpoint(
production_variant["InitialInstanceCount"] = instance_count

config_name = _get_sagemaker_config_name(endpoint_name)
config_tags = _get_sagemaker_config_tags(endpoint_name)
tags_list = _prepare_sagemaker_tags(config_tags, tags)
endpoint_config_kwargs = {
"EndpointConfigName": config_name,
"ProductionVariants": [production_variant],
"Tags": [{"Key": "app_name", "Value": endpoint_name}],
"Tags": config_tags,
}
if async_inference_config:
endpoint_config_kwargs["AsyncInferenceConfig"] = async_inference_config
Expand All @@ -1576,7 +1599,7 @@ def _create_sagemaker_endpoint(
endpoint_response = sage_client.create_endpoint(
EndpointName=endpoint_name,
EndpointConfigName=config_name,
Tags=[],
Tags=tags_list or [],
)
_logger.info("Created endpoint with arn: %s", endpoint_response["EndpointArn"])

Expand Down Expand Up @@ -1656,7 +1679,7 @@ def _update_sagemaker_endpoint(
For more information, see https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DataCaptureConfig.html.
Defaults to ``None``.
:param env: A dictionary of environment variables to set for the model.
:param tags: A dictionary of tags to apply to the endpoint.
:param tags: A dictionary of tags to apply to the endpoint configuration.
"""
if mode not in [DEPLOYMENT_MODE_ADD, DEPLOYMENT_MODE_REPLACE]:
msg = f"Invalid mode `{mode}` for deployment to a pre-existing application"
Expand Down Expand Up @@ -1714,11 +1737,11 @@ def _update_sagemaker_endpoint(
# Create the new endpoint configuration and update the endpoint
# to adopt the new configuration
new_config_name = _get_sagemaker_config_name(endpoint_name)
# This is the hardcoded config for endpoint
config_tags = _get_sagemaker_config_tags(endpoint_name)
endpoint_config_kwargs = {
"EndpointConfigName": new_config_name,
"ProductionVariants": production_variants,
"Tags": [{"Key": "app_name", "Value": endpoint_name}],
"Tags": config_tags,
}
if async_inference_config:
endpoint_config_kwargs["AsyncInferenceConfig"] = async_inference_config
Expand Down
19 changes: 12 additions & 7 deletions tests/sagemaker/mock/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,12 @@ def list_tags(self):
Handler for the SageMaker "ListTags" API call documented here:
https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ListTags.html
"""
model_arn = self.request_params["ResourceArn"]

arn = self.request_params["ResourceArn"]
sagemaker_resource = (
"models" if "model" in arn else "endpoints" if "endpoint" in arn else None
)
results = self.sagemaker_backend.list_tags(
resource_arn=model_arn,
region_name=self.region,
resource_arn=arn, region_name=self.region, resource_type=sagemaker_resource
)

return json.dumps({"Tags": results, "NextToken": None})
Expand Down Expand Up @@ -511,13 +512,17 @@ def list_models(self):
summaries.append(summary)
return summaries

def list_tags(self, resource_arn, region_name): # pylint: disable=unused-argument
def list_tags(
self, resource_arn, region_name, resource_type
): # pylint: disable=unused-argument
"""
Modifies backend state during calls to the SageMaker "ListTags" API
https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ListTags.html
"""
model = next(model for model in self.models.values() if model.arn == resource_arn)
return model.resource.tags
resource_values = getattr(self, resource_type).values()
for sagemaker_resource in resource_values:
if sagemaker_resource.arn == resource_arn:
return sagemaker_resource.resource.tags

def create_model(
self, model_name, primary_container, execution_role_arn, tags, region_name, vpc_config=None
Expand Down
39 changes: 35 additions & 4 deletions tests/sagemaker/test_sagemaker_deployment_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,10 +508,39 @@ def test_create_deployment_create_sagemaker_and_s3_resources_with_expected_tags_
model_name = endpoint_production_variants[0]["VariantName"]
description = sagemaker_client.describe_model(ModelName=model_name)

tags = sagemaker_client.list_tags(ResourceArn=description["ModelArn"])
model_tags = sagemaker_client.list_tags(ResourceArn=description["ModelArn"])
endpoint_tags = sagemaker_client.list_tags(ResourceArn=endpoint_description["EndpointArn"])

# Extra tags exist besides the ones we set, so avoid strict equality
assert all(tag in tags["Tags"] for tag in expected_tags)
assert all(tag in model_tags["Tags"] for tag in expected_tags)
assert all(tag in endpoint_tags["Tags"] for tag in expected_tags)


def test_prepare_sagemaker_tags_without_custom_tags():
    """Without user tags, the prepared tags equal the configuration tags."""
    base_tags = [{"Key": "tag1", "Value": "value1"}]
    assert mfs._prepare_sagemaker_tags(base_tags, None) == base_tags


def test_prepare_sagemaker_tags_when_custom_tags_are_added():
    """User tags become Key/Value entries appended after the configuration tags."""
    base_tags = [{"Key": "tag1", "Value": "value1"}]
    custom_tags = {"tag2": "value2", "tag3": "123"}

    prepared = mfs._prepare_sagemaker_tags(base_tags, custom_tags)

    assert prepared == [
        {"Key": "tag1", "Value": "value1"},
        {"Key": "tag2", "Value": "value2"},
        {"Key": "tag3", "Value": "123"},
    ]


def test_prepare_sagemaker_tags_duplicate_key_raises_exception():
    """Supplying the reserved ``app_name`` key as a user tag is rejected."""
    base_tags = [{"Key": "app_name", "Value": "a_cool_name"}]
    custom_tags = {"app_name": "a_cooler_name", "tag2": "value2", "tag3": "123"}
    expected_message = "Duplicate tag provided for 'app_name'"

    with pytest.raises(MlflowException, match=expected_message) as exc_info:
        mfs._prepare_sagemaker_tags(base_tags, custom_tags)

    # Checked after the context manager exits, once the exception info is populated.
    assert exc_info.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


@pytest.mark.parametrize("proxies_enabled", [True, False])
Expand Down Expand Up @@ -673,10 +702,12 @@ def test_deploy_cli_creates_sagemaker_and_s3_resources_with_expected_tags_from_l
model_name = endpoint_production_variants[0]["VariantName"]
description = sagemaker_client.describe_model(ModelName=model_name)

tags = sagemaker_client.list_tags(ResourceArn=description["ModelArn"])
model_tags = sagemaker_client.list_tags(ResourceArn=description["ModelArn"])
endpoint_tags = sagemaker_client.list_tags(ResourceArn=endpoint_description["EndpointArn"])

# Extra tags exist besides the ones we set, so avoid strict equality
assert all(tag in tags["Tags"] for tag in expected_tags)
assert all(tag in model_tags["Tags"] for tag in expected_tags)
assert all(tag in endpoint_tags["Tags"] for tag in expected_tags)


@pytest.mark.parametrize("proxies_enabled", [True, False])
Expand Down

0 comments on commit d3f6faa

Please sign in to comment.