Skip to content

Commit

Permalink
Better error messaging for overrides (#1807)
Browse files Browse the repository at this point in the history
- using incorrect type of overrides
 - using incorrect type for resources
 - using promises in overrides

Signed-off-by: Ketan Umare <[email protected]>
Signed-off-by: Jeev B <[email protected]>
  • Loading branch information
kumare3 authored and jeevb committed Sep 20, 2023
1 parent 68f87d9 commit ca33b5f
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 4 deletions.
48 changes: 45 additions & 3 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,32 @@
from flytekit.models.task import Resources as _resources_model


def assert_not_promise(v: Any, location: str):
"""
This function will raise an exception if the value is a promise. This should be used to ensure that we don't
accidentally use a promise in a place where we don't support it.
"""
from flytekit.core.promise import Promise

if isinstance(v, Promise):
raise AssertionError(f"Cannot use a promise in the {location} Value: {v}")


def assert_no_promises_in_resources(resources: _resources_model):
"""
This function will raise an exception if any of the resources have promises in them. This is because we don't
support promises in resources / runtime overriding of resources through input values.
"""
if resources is None:
return
if resources.requests is not None:
for r in resources.requests:
assert_not_promise(r.value, "resources.requests")
if resources.limits is not None:
for r in resources.limits:
assert_not_promise(r.value, "resources.limits")


class Node(object):
"""
This class will hold all the things necessary to make an SdkNode but we won't make one until we know things like
Expand Down Expand Up @@ -86,23 +112,29 @@ def with_overrides(self, *args, **kwargs):
if "node_name" in kwargs:
# Convert the node name into a DNS-compliant.
# https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-subdomain-names
self._id = _dnsify(kwargs["node_name"])
v = kwargs["node_name"]
assert_not_promise(v, "node_name")
self._id = _dnsify(v)

if "aliases" in kwargs:
alias_dict = kwargs["aliases"]
if not isinstance(alias_dict, dict):
raise AssertionError("Aliases should be specified as dict[str, str]")
self._aliases = []
for k, v in alias_dict.items():
self._aliases.append(_workflow_model.Alias(var=k, alias=v))

if "requests" in kwargs or "limits" in kwargs:
requests = kwargs.get("requests")
if requests and not isinstance(requests, Resources):
raise AssertionError("requests should be specified as flytekit.Resources")
limits = kwargs.get("limits")
if limits and not isinstance(limits, Resources):
raise AssertionError("limits should be specified as flytekit.Resources")
resources = convert_resources_to_resource_model(requests=requests, limits=limits)
assert_no_promises_in_resources(resources)
self._resources = resources

self._resources = convert_resources_to_resource_model(requests=requests, limits=limits)
if "timeout" in kwargs:
timeout = kwargs["timeout"]
if timeout is None:
Expand All @@ -115,21 +147,31 @@ def with_overrides(self, *args, **kwargs):
raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")
if "retries" in kwargs:
retries = kwargs["retries"]
assert_not_promise(retries, "retries")
self._metadata._retries = (
_literal_models.RetryStrategy(0) if retries is None else _literal_models.RetryStrategy(retries)
)

if "interruptible" in kwargs:
v = kwargs["interruptible"]
assert_not_promise(v, "interruptible")
self._metadata._interruptible = kwargs["interruptible"]

if "name" in kwargs:
self._metadata._name = kwargs["name"]

if "task_config" in kwargs:
logger.warning("This override is beta. We may want to revisit this in the future.")
new_task_config = kwargs["task_config"]
if not isinstance(new_task_config, type(self.flyte_entity._task_config)):
raise ValueError("can't change the type of the task config")
self.flyte_entity._task_config = new_task_config

if "container_image" in kwargs:
self.flyte_entity._container_image = kwargs["container_image"]
v = kwargs["container_image"]
assert_not_promise(v, "container_image")
self.flyte_entity._container_image = v

return self


Expand Down
13 changes: 13 additions & 0 deletions flytekit/core/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,19 @@ class Resources(object):
storage: Optional[str] = None
ephemeral_storage: Optional[str] = None

def __post_init__(self):
def _check_none_or_str(value):
if value is None:
return
if not isinstance(value, str):
raise AssertionError(f"{value} should be a string")

_check_none_or_str(self.cpu)
_check_none_or_str(self.mem)
_check_none_or_str(self.gpu)
_check_none_or_str(self.storage)
_check_none_or_str(self.ephemeral_storage)


@dataclass
class ResourceSpec(object):
Expand Down
1 change: 0 additions & 1 deletion flytekit/core/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ def _candidate_name_matches(candidate) -> bool:
return k
except ValueError as err:
logger.warning(f"Caught ValueError {err} while attempting to auto-assign name")
pass

logger.error(f"Could not find LHS for {self} in {self._instantiated_in}")
raise _system_exceptions.FlyteSystemException(f"Error looking for LHS in {self._instantiated_in}")
Expand Down
13 changes: 13 additions & 0 deletions tests/flytekit/unit/core/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,16 @@ def test_convert_limits(resource_dict: Dict[str, str], expected_resource_name: _
assert limit.name == expected_resource_name
assert limit.value == expected_resource_value
assert len(resources_model.requests) == 0


def test_incorrect_type_resources():
with pytest.raises(AssertionError):
Resources(cpu=1) # type: ignore
with pytest.raises(AssertionError):
Resources(mem=1) # type: ignore
with pytest.raises(AssertionError):
Resources(gpu=1) # type: ignore
with pytest.raises(AssertionError):
Resources(storage=1) # type: ignore
with pytest.raises(AssertionError):
Resources(ephemeral_storage=1) # type: ignore
39 changes: 39 additions & 0 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1866,3 +1866,42 @@ def run_things() -> MyOutput:
return produce_things()

assert run_things().ref == "ref"


def test_promise_not_allowed_in_overrides():
@task
def t1(a: int) -> int:
return a + 1

@workflow
def my_wf(a: int, cpu: str) -> int:
return t1(a=a).with_overrides(requests=Resources(cpu=cpu))

with pytest.raises(AssertionError):
my_wf(a=1, cpu=1)


def test_promise_illegal_resources():
@task
def t1(a: int) -> int:
return a + 1

@workflow
def my_wf(a: int) -> int:
return t1(a=a).with_overrides(requests=Resources(cpu=1)) # type: ignore

with pytest.raises(AssertionError):
my_wf(a=1)


def test_promise_illegal_retries():
@task
def t1(a: int) -> int:
return a + 1

@workflow
def my_wf(a: int, retries: int) -> int:
return t1(a=a).with_overrides(retries=retries)

with pytest.raises(AssertionError):
my_wf(a=1, retries=1)

0 comments on commit ca33b5f

Please sign in to comment.