Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better error messaging for overrides #1807

Merged
merged 1 commit into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,32 @@
from flytekit.models.task import Resources as _resources_model


def assert_not_promise(v: Any, location: str):
    """
    Reject promises in places where only a concrete value is supported.

    :param v: The value to inspect.
    :param location: Name of the place the value was supplied; it is embedded in the error message.
    :raises AssertionError: If ``v`` is an instance of ``Promise``.
    """
    # NOTE(review): imported at call time rather than at module level, presumably to avoid a
    # circular import between the node and promise modules -- confirm.
    from flytekit.core.promise import Promise

    if not isinstance(v, Promise):
        return
    raise AssertionError(f"Cannot use a promise in the {location} Value: {v}")


def assert_no_promises_in_resources(resources: _resources_model):
    """
    Verify that no entry in the requests or limits of ``resources`` carries a promise as its value.

    Promises are not supported in resources, i.e. resources cannot be overridden at runtime
    through input values.

    :param resources: The resources model to inspect; ``None`` is accepted and ignored.
    :raises AssertionError: If any request or limit entry has a promise as its value.
    """
    if resources is None:
        return

    # Requests are checked before limits; each section is only read when its turn comes.
    for section in ("requests", "limits"):
        entries = getattr(resources, section)
        if entries is None:
            continue
        for entry in entries:
            assert_not_promise(entry.value, f"resources.{section}")


class Node(object):
"""
This class will hold all the things necessary to make an SdkNode but we won't make one until we know things like
Expand Down Expand Up @@ -86,23 +112,29 @@ def with_overrides(self, *args, **kwargs):
if "node_name" in kwargs:
# Convert the node name into a DNS-compliant name.
# https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-subdomain-names
self._id = _dnsify(kwargs["node_name"])
v = kwargs["node_name"]
assert_not_promise(v, "node_name")
self._id = _dnsify(v)

if "aliases" in kwargs:
alias_dict = kwargs["aliases"]
if not isinstance(alias_dict, dict):
raise AssertionError("Aliases should be specified as dict[str, str]")
self._aliases = []
for k, v in alias_dict.items():
self._aliases.append(_workflow_model.Alias(var=k, alias=v))

if "requests" in kwargs or "limits" in kwargs:
requests = kwargs.get("requests")
if requests and not isinstance(requests, Resources):
raise AssertionError("requests should be specified as flytekit.Resources")
limits = kwargs.get("limits")
if limits and not isinstance(limits, Resources):
raise AssertionError("limits should be specified as flytekit.Resources")
resources = convert_resources_to_resource_model(requests=requests, limits=limits)
assert_no_promises_in_resources(resources)
self._resources = resources

self._resources = convert_resources_to_resource_model(requests=requests, limits=limits)
if "timeout" in kwargs:
timeout = kwargs["timeout"]
if timeout is None:
Expand All @@ -115,21 +147,31 @@ def with_overrides(self, *args, **kwargs):
raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")
if "retries" in kwargs:
retries = kwargs["retries"]
assert_not_promise(retries, "retries")
self._metadata._retries = (
_literal_models.RetryStrategy(0) if retries is None else _literal_models.RetryStrategy(retries)
)

if "interruptible" in kwargs:
v = kwargs["interruptible"]
assert_not_promise(v, "interruptible")
self._metadata._interruptible = kwargs["interruptible"]

if "name" in kwargs:
self._metadata._name = kwargs["name"]

if "task_config" in kwargs:
logger.warning("This override is beta. We may want to revisit this in the future.")
new_task_config = kwargs["task_config"]
if not isinstance(new_task_config, type(self.flyte_entity._task_config)):
raise ValueError("can't change the type of the task config")
self.flyte_entity._task_config = new_task_config

if "container_image" in kwargs:
self.flyte_entity._container_image = kwargs["container_image"]
v = kwargs["container_image"]
assert_not_promise(v, "container_image")
self.flyte_entity._container_image = v

return self


Expand Down
13 changes: 13 additions & 0 deletions flytekit/core/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,19 @@ class Resources(object):
storage: Optional[str] = None
ephemeral_storage: Optional[str] = None

def __post_init__(self):
    """
    Validate the resource fields after construction.

    Each of ``cpu``, ``mem``, ``gpu``, ``storage`` and ``ephemeral_storage`` must be either
    ``None`` or a ``str``.

    :raises AssertionError: For the first field (in the order above) holding a non-string value.
    """
    for field_name in ("cpu", "mem", "gpu", "storage", "ephemeral_storage"):
        value = getattr(self, field_name)
        if value is not None and not isinstance(value, str):
            raise AssertionError(f"{value} should be a string")


@dataclass
class ResourceSpec(object):
Expand Down
1 change: 0 additions & 1 deletion flytekit/core/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ def _candidate_name_matches(candidate) -> bool:
return k
except ValueError as err:
logger.warning(f"Caught ValueError {err} while attempting to auto-assign name")
pass

logger.error(f"Could not find LHS for {self} in {self._instantiated_in}")
raise _system_exceptions.FlyteSystemException(f"Error looking for LHS in {self._instantiated_in}")
Expand Down
13 changes: 13 additions & 0 deletions tests/flytekit/unit/core/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,16 @@ def test_convert_limits(resource_dict: Dict[str, str], expected_resource_name: _
assert limit.name == expected_resource_name
assert limit.value == expected_resource_value
assert len(resources_model.requests) == 0


def test_incorrect_type_resources():
    """Each ``Resources`` field must reject a non-string value with an ``AssertionError``."""
    for field_name in ("cpu", "mem", "gpu", "storage", "ephemeral_storage"):
        with pytest.raises(AssertionError):
            Resources(**{field_name: 1})  # type: ignore
39 changes: 39 additions & 0 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1866,3 +1866,42 @@ def run_things() -> MyOutput:
return produce_things()

assert run_things().ref == "ref"


def test_promise_not_allowed_in_overrides():
    """A workflow input (a promise inside the workflow body) must not be usable in a resource override."""

    @task
    def t1(a: int) -> int:
        return a + 1

    @workflow
    def my_wf(a: int, cpu: str) -> int:
        # Inside the workflow body ``cpu`` is a promise, not a concrete string.
        return t1(a=a).with_overrides(requests=Resources(cpu=cpu))

    with pytest.raises(AssertionError):
        # Pass a value matching the declared ``str`` annotation so the expected AssertionError
        # cannot be satisfied by an unrelated input type mismatch.
        my_wf(a=1, cpu="1")


def test_promise_illegal_resources():
    """Overriding requests with a non-string ``cpu`` value must fail with an ``AssertionError``."""

    @task
    def t1(a: int) -> int:
        return a + 1

    @workflow
    def my_wf(a: int) -> int:
        node_output = t1(a=a)
        # ``cpu=1`` is deliberately an int instead of a string.
        return node_output.with_overrides(requests=Resources(cpu=1))  # type: ignore

    with pytest.raises(AssertionError):
        my_wf(a=1)


def test_promise_illegal_retries():
    """Using a workflow input as the ``retries`` override must fail with an ``AssertionError``."""

    @task
    def t1(a: int) -> int:
        return a + 1

    @workflow
    def my_wf(a: int, retries: int) -> int:
        node_output = t1(a=a)
        # Inside the workflow body ``retries`` is a workflow input, not a concrete int.
        return node_output.with_overrides(retries=retries)

    with pytest.raises(AssertionError):
        my_wf(a=1, retries=1)