Skip to content

Commit

Permalink
[Lambda] Fix missing private ip (#4635)
Browse files Browse the repository at this point in the history
* [Lambda] Fix missing private ip

* remove extra API call

* format test
  • Loading branch information
bend-works authored Feb 4, 2025
1 parent 3b4f31b commit e4ad98c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
18 changes: 17 additions & 1 deletion sky/provision/lambda_cloud/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,21 @@ def _get_ssh_key_name(prefix: str = '') -> str:
return name


def _get_private_ip(instance_info: Dict[str, Any], single_node: bool) -> str:
private_ip = instance_info.get('private_ip')
if private_ip is None:
if single_node:
# The Lambda cloud API may return an instance info without
# private IP. It does not align with their docs, but we still
# allow single-node cluster to proceed with provisioning, by using
# 127.0.0.1, as private IP is not critical for single-node case.
return '127.0.0.1'
msg = f'Failed to retrieve private IP for instance {instance_info}.'
logger.error(msg)
raise RuntimeError(msg)
return private_ip


def run_instances(region: str, cluster_name_on_cloud: str,
config: common.ProvisionConfig) -> common.ProvisionRecord:
"""Runs instances for the given cluster"""
Expand Down Expand Up @@ -197,13 +212,14 @@ def get_cluster_info(
) -> common.ClusterInfo:
del region # unused
running_instances = _filter_instances(cluster_name_on_cloud, ['active'])
single_node = len(running_instances) == 1
instances: Dict[str, List[common.InstanceInfo]] = {}
head_instance_id = None
for instance_id, instance_info in running_instances.items():
instances[instance_id] = [
common.InstanceInfo(
instance_id=instance_id,
internal_ip=instance_info['private_ip'],
internal_ip=_get_private_ip(instance_info, single_node),
external_ip=instance_info['ip'],
ssh_port=22,
tags={},
Expand Down
15 changes: 15 additions & 0 deletions tests/unit_tests/test_lambda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest

from sky.provision.lambda_cloud.instance import _get_private_ip


def test_get_private_ip():
valid_info = {'private_ip': '10.19.83.125'}
invalid_info = {}
assert _get_private_ip(valid_info,
single_node=True) == valid_info['private_ip']
assert _get_private_ip(valid_info,
single_node=False) == valid_info['private_ip']
assert _get_private_ip(invalid_info, single_node=True) == '127.0.0.1'
with pytest.raises(RuntimeError):
_get_private_ip(invalid_info, single_node=False)

0 comments on commit e4ad98c

Please sign in to comment.