Skip to content

Commit

Permalink
[Provisioner] Fix open ports on Azure (#2891)
Browse files Browse the repository at this point in the history
* fix

* add timeout

* log & max backoff = 1
  • Loading branch information
cblmemo authored Dec 22, 2023
1 parent e93d400 commit a9c52e6
Showing 1 changed file with 31 additions and 4 deletions.
35 changes: 31 additions & 4 deletions sky/skylet/providers/azure/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@
import random
from hashlib import sha256
from pathlib import Path
import time
from typing import Any, Callable

from azure.common.credentials import get_cli_profile
from azure.identity import AzureCliCredential
from azure.mgmt.network import NetworkManagementClient
from azure.mgmt.resource import ResourceManagementClient
from azure.mgmt.resource.resources.models import DeploymentMode

from sky.utils import common_utils

UNIQUE_ID_LEN = 4
_WAIT_NSG_CREATION_NUM_TIMEOUT_SECONDS = 600

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,9 +55,8 @@ def _configure_resource_group(config):
# Increase the timeout to fix the Azure get-access-token (used by ray azure
# node_provider) timeout issue.
# Tracked in https://github.com/Azure/azure-cli/issues/20404#issuecomment-1249575110
resource_client = ResourceManagementClient(
AzureCliCredential(process_timeout=30), subscription_id
)
credentials = AzureCliCredential(process_timeout=30)
resource_client = ResourceManagementClient(credentials, subscription_id)
config["provider"]["subscription_id"] = subscription_id
logger.info("Using subscription id: %s", subscription_id)

Expand Down Expand Up @@ -126,9 +130,32 @@ def _configure_resource_group(config):
.properties.outputs
)

# Wait for the NSG to finish provisioning before opening any ports;
# otherwise the newly added NSG rules could be overwritten.
nsg_id = outputs["nsg"]["value"]
nsg_name = nsg_id.split("/")[-1]
network_client = NetworkManagementClient(credentials, subscription_id)
backoff = common_utils.Backoff(max_backoff_factor=1)
start_time = time.time()
while True:
nsg = network_client.network_security_groups.get(resource_group, nsg_name)
if nsg.provisioning_state == "Succeeded":
break
if time.time() - start_time > _WAIT_NSG_CREATION_NUM_TIMEOUT_SECONDS:
raise RuntimeError(
f"Fails to create NSG {nsg_name} in {resource_group} within "
f"{_WAIT_NSG_CREATION_NUM_TIMEOUT_SECONDS} seconds."
)
backoff_time = backoff.current_backoff()
logger.info(
f"NSG {nsg_name} is not created yet. Waiting for "
f"{backoff_time} seconds before checking again."
)
time.sleep(backoff_time)

# Append the output resource IDs so they can be used during VM creation.
config["provider"]["msi"] = outputs["msi"]["value"]
config["provider"]["nsg"] = outputs["nsg"]["value"]
config["provider"]["nsg"] = nsg_id
config["provider"]["subnet"] = outputs["subnet"]["value"]

return config
Expand Down

0 comments on commit a9c52e6

Please sign in to comment.