From aaac7da6c5ec43361a4e54e2f6555b8c5f2dfd8c Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Sun, 27 Aug 2023 21:52:40 -0700 Subject: [PATCH] Better error messaging for overrides - using incorrect type of overrides - using incorrect type for resources - using promises in overrides Signed-off-by: Ketan Umare --- flytekit/core/node.py | 48 +++++++++++++++++++-- flytekit/core/resources.py | 13 ++++++ flytekit/core/tracker.py | 1 - tests/flytekit/unit/core/test_resources.py | 13 ++++++ tests/flytekit/unit/core/test_type_hints.py | 39 +++++++++++++++++ 5 files changed, 110 insertions(+), 4 deletions(-) diff --git a/flytekit/core/node.py b/flytekit/core/node.py index bf5c97ba60..1038c00521 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -12,6 +12,32 @@ from flytekit.models.task import Resources as _resources_model +def assert_not_promise(v: Any, location: str): + """ + This function will raise an exception if the value is a promise. This should be used to ensure that we don't + accidentally use a promise in a place where we don't support it. + """ + from flytekit.core.promise import Promise + + if isinstance(v, Promise): + raise AssertionError(f"Cannot use a promise in the {location} Value: {v}") + + +def assert_no_promises_in_resources(resources: _resources_model): + """ + This function will raise an exception if any of the resources have promises in them. This is because we don't + support promises in resources / runtime overriding of resources through input values. + """ + if resources is None: + return + if resources.requests is not None: + for r in resources.requests: + assert_not_promise(r.value, "resources.requests") + if resources.limits is not None: + for r in resources.limits: + assert_not_promise(r.value, "resources.limits") + + class Node(object): """ This class will hold all the things necessary to make an SdkNode but we won't make one until we know things like @@ -86,7 +112,10 @@ def with_overrides(self, *args, **kwargs): if "node_name" in kwargs: # Convert the node name into a DNS-compliant. # https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-subdomain-names - self._id = _dnsify(kwargs["node_name"]) + v = kwargs["node_name"] + assert_not_promise(v, "node_name") + self._id = _dnsify(v) + if "aliases" in kwargs: alias_dict = kwargs["aliases"] if not isinstance(alias_dict, dict): @@ -94,6 +123,7 @@ def with_overrides(self, *args, **kwargs): self._aliases = [] for k, v in alias_dict.items(): self._aliases.append(_workflow_model.Alias(var=k, alias=v)) + if "requests" in kwargs or "limits" in kwargs: requests = kwargs.get("requests") if requests and not isinstance(requests, Resources): @@ -101,8 +131,10 @@ def with_overrides(self, *args, **kwargs): limits = kwargs.get("limits") if limits and not isinstance(limits, Resources): raise AssertionError("limits should be specified as flytekit.Resources") + resources = convert_resources_to_resource_model(requests=requests, limits=limits) + assert_no_promises_in_resources(resources) + self._resources = resources - self._resources = convert_resources_to_resource_model(requests=requests, limits=limits) if "timeout" in kwargs: timeout = kwargs["timeout"] if timeout is None: @@ -115,21 +147,31 @@ def with_overrides(self, *args, **kwargs): raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds") if "retries" in kwargs: retries = kwargs["retries"] + assert_not_promise(retries, "retries") self._metadata._retries = ( _literal_models.RetryStrategy(0) if retries is None else _literal_models.RetryStrategy(retries) ) + if "interruptible" in kwargs: + v = kwargs["interruptible"] + assert_not_promise(v, "interruptible") self._metadata._interruptible = kwargs["interruptible"] + if "name" in kwargs: self._metadata._name = kwargs["name"] + if "task_config" in kwargs: logger.warning("This override is beta. We may want to revisit this in the future.") new_task_config = kwargs["task_config"] if not isinstance(new_task_config, type(self.flyte_entity._task_config)): raise ValueError("can't change the type of the task config") self.flyte_entity._task_config = new_task_config + if "container_image" in kwargs: - self.flyte_entity._container_image = kwargs["container_image"] + v = kwargs["container_image"] + assert_not_promise(v, "container_image") + self.flyte_entity._container_image = v + return self diff --git a/flytekit/core/resources.py b/flytekit/core/resources.py index 4cf2523f6a..62b880f6ed 100644 --- a/flytekit/core/resources.py +++ b/flytekit/core/resources.py @@ -32,6 +32,19 @@ class Resources(object): storage: Optional[str] = None ephemeral_storage: Optional[str] = None + def __post_init__(self): + def _check_none_or_str(value): + if value is None: + return + if not isinstance(value, str): + raise AssertionError(f"{value} should be a string") + + _check_none_or_str(self.cpu) + _check_none_or_str(self.mem) + _check_none_or_str(self.gpu) + _check_none_or_str(self.storage) + _check_none_or_str(self.ephemeral_storage) + @dataclass class ResourceSpec(object): diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index 13eecf84a9..b9dedc99af 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -158,7 +158,6 @@ def _candidate_name_matches(candidate) -> bool: return k except ValueError as err: logger.warning(f"Caught ValueError {err} while attempting to auto-assign name") - pass logger.error(f"Could not find LHS for {self} in {self._instantiated_in}") raise _system_exceptions.FlyteSystemException(f"Error looking for LHS in {self._instantiated_in}") diff --git a/tests/flytekit/unit/core/test_resources.py b/tests/flytekit/unit/core/test_resources.py index 1a3bf64dee..25a637b2d6 100644 --- a/tests/flytekit/unit/core/test_resources.py +++ b/tests/flytekit/unit/core/test_resources.py @@ -66,3 +66,16 @@ def test_convert_limits(resource_dict: Dict[str, str], expected_resource_name: _ assert limit.name == expected_resource_name assert limit.value == expected_resource_value assert len(resources_model.requests) == 0 + + +def test_incorrect_type_resources(): + with pytest.raises(AssertionError): + Resources(cpu=1) # type: ignore + with pytest.raises(AssertionError): + Resources(mem=1) # type: ignore + with pytest.raises(AssertionError): + Resources(gpu=1) # type: ignore + with pytest.raises(AssertionError): + Resources(storage=1) # type: ignore + with pytest.raises(AssertionError): + Resources(ephemeral_storage=1) # type: ignore diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index f81f9f9ebe..4e32070f9f 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -1866,3 +1866,42 @@ def run_things() -> MyOutput: return produce_things() assert run_things().ref == "ref" + + +def test_promise_not_allowed_in_overrides(): + @task + def t1(a: int) -> int: + return a + 1 + + @workflow + def my_wf(a: int, cpu: str) -> int: + return t1(a=a).with_overrides(requests=Resources(cpu=cpu)) + + with pytest.raises(AssertionError): + my_wf(a=1, cpu=1) + + +def test_promise_illegal_resources(): + @task + def t1(a: int) -> int: + return a + 1 + + @workflow + def my_wf(a: int) -> int: + return t1(a=a).with_overrides(requests=Resources(cpu=1)) # type: ignore + + with pytest.raises(AssertionError): + my_wf(a=1) + + +def test_promise_illegal_retries(): + @task + def t1(a: int) -> int: + return a + 1 + + @workflow + def my_wf(a: int, retries: int) -> int: + return t1(a=a).with_overrides(retries=retries) + + with pytest.raises(AssertionError): + my_wf(a=1, retries=1)