diff --git a/sky/provision/lambda_cloud/instance.py b/sky/provision/lambda_cloud/instance.py index d33c97df95c..13ee1aeb534 100644 --- a/sky/provision/lambda_cloud/instance.py +++ b/sky/provision/lambda_cloud/instance.py @@ -64,6 +64,21 @@ def _get_ssh_key_name(prefix: str = '') -> str: return name +def _get_private_ip(instance_info: Dict[str, Any], single_node: bool) -> str: + private_ip = instance_info.get('private_ip') + if private_ip is None: + if single_node: + # The Lambda cloud API may return an instance info without + # private IP. It does not align with their docs, but we still + # allow single-node cluster to proceed with provisioning, by using + # 127.0.0.1, as private IP is not critical for single-node case. + return '127.0.0.1' + msg = f'Failed to retrieve private IP for instance {instance_info}.' + logger.error(msg) + raise RuntimeError(msg) + return private_ip + + def run_instances(region: str, cluster_name_on_cloud: str, config: common.ProvisionConfig) -> common.ProvisionRecord: """Runs instances for the given cluster""" @@ -197,13 +212,14 @@ def get_cluster_info( ) -> common.ClusterInfo: del region # unused running_instances = _filter_instances(cluster_name_on_cloud, ['active']) + single_node = len(running_instances) == 1 instances: Dict[str, List[common.InstanceInfo]] = {} head_instance_id = None for instance_id, instance_info in running_instances.items(): instances[instance_id] = [ common.InstanceInfo( instance_id=instance_id, - internal_ip=instance_info['private_ip'], + internal_ip=_get_private_ip(instance_info, single_node), external_ip=instance_info['ip'], ssh_port=22, tags={}, diff --git a/tests/unit_tests/test_lambda.py b/tests/unit_tests/test_lambda.py new file mode 100644 index 00000000000..88c87958bf6 --- /dev/null +++ b/tests/unit_tests/test_lambda.py @@ -0,0 +1,15 @@ +import pytest + +from sky.provision.lambda_cloud.instance import _get_private_ip + + +def test_get_private_ip(): + valid_info = {'private_ip': '10.19.83.125'} + invalid_info = {} + assert _get_private_ip(valid_info, + single_node=True) == valid_info['private_ip'] + assert _get_private_ip(valid_info, + single_node=False) == valid_info['private_ip'] + assert _get_private_ip(invalid_info, single_node=True) == '127.0.0.1' + with pytest.raises(RuntimeError): + _get_private_ip(invalid_info, single_node=False)