diff --git a/sky/skylet/providers/azure/config.py b/sky/skylet/providers/azure/config.py index a937102f579..0c1827a1141 100644 --- a/sky/skylet/providers/azure/config.py +++ b/sky/skylet/providers/azure/config.py @@ -3,14 +3,19 @@ import random from hashlib import sha256 from pathlib import Path +import time from typing import Any, Callable from azure.common.credentials import get_cli_profile from azure.identity import AzureCliCredential +from azure.mgmt.network import NetworkManagementClient from azure.mgmt.resource import ResourceManagementClient from azure.mgmt.resource.resources.models import DeploymentMode +from sky.utils import common_utils + UNIQUE_ID_LEN = 4 +_WAIT_NSG_CREATION_NUM_TIMEOUT_SECONDS = 600 logger = logging.getLogger(__name__) @@ -50,9 +55,8 @@ def _configure_resource_group(config): # Increase the timeout to fix the Azure get-access-token (used by ray azure # node_provider) timeout issue. # Tracked in https://github.com/Azure/azure-cli/issues/20404#issuecomment-1249575110 - resource_client = ResourceManagementClient( - AzureCliCredential(process_timeout=30), subscription_id - ) + credentials = AzureCliCredential(process_timeout=30) + resource_client = ResourceManagementClient(credentials, subscription_id) config["provider"]["subscription_id"] = subscription_id logger.info("Using subscription id: %s", subscription_id) @@ -126,9 +130,32 @@ def _configure_resource_group(config): .properties.outputs ) + # We should wait for the NSG to be created before opening any ports + # to avoid overriding the newly-added NSG rules. + nsg_id = outputs["nsg"]["value"] + nsg_name = nsg_id.split("/")[-1] + network_client = NetworkManagementClient(credentials, subscription_id) + backoff = common_utils.Backoff(max_backoff_factor=1) + start_time = time.time() + while True: + nsg = network_client.network_security_groups.get(resource_group, nsg_name) + if nsg.provisioning_state == "Succeeded": + break + if time.time() - start_time > _WAIT_NSG_CREATION_NUM_TIMEOUT_SECONDS: + raise RuntimeError( + f"Fails to create NSG {nsg_name} in {resource_group} within " + f"{_WAIT_NSG_CREATION_NUM_TIMEOUT_SECONDS} seconds." + ) + backoff_time = backoff.current_backoff() + logger.info( + f"NSG {nsg_name} is not created yet. Waiting for " + f"{backoff_time} seconds before checking again." + ) + time.sleep(backoff_time) + # append output resource ids to be used with vm creation config["provider"]["msi"] = outputs["msi"]["value"] - config["provider"]["nsg"] = outputs["nsg"]["value"] + config["provider"]["nsg"] = nsg_id config["provider"]["subnet"] = outputs["subnet"]["value"] return config