diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index cac59e50706..72fdad36b67 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -690,6 +690,7 @@ def binding_data_from_python_std( t_value: Any, t_value_type: type, nodes: List[Node], + is_union_type_variants_expanded: bool = True, ) -> _literals_models.BindingData: # This handles the case where the given value is the output of another task if isinstance(t_value, Promise): @@ -702,7 +703,7 @@ def binding_data_from_python_std( f"Cannot pass output from task {t_value.task_name} that produces no outputs to a downstream task" ) - elif t_value is not None and expected_literal_type.union_type is not None: + elif t_value is not None and is_union_type_variants_expanded and expected_literal_type.union_type is not None: for i in range(len(expected_literal_type.union_type.variants)): try: lt_type = expected_literal_type.union_type.variants[i] @@ -769,9 +770,17 @@ def binding_from_python_std( expected_literal_type: _type_models.LiteralType, t_value: Any, t_value_type: type, + is_union_type_variants_expanded: bool = True, ) -> Tuple[_literals_models.Binding, List[Node]]: nodes: List[Node] = [] - binding_data = binding_data_from_python_std(ctx, expected_literal_type, t_value, t_value_type, nodes) + binding_data = binding_data_from_python_std( + ctx, + expected_literal_type, + t_value, + t_value_type, + nodes, + is_union_type_variants_expanded=is_union_type_variants_expanded, + ) return _literals_models.Binding(var=var_name, binding=binding_data), nodes @@ -1063,6 +1072,7 @@ def create_and_link_node( for k in sorted(interface.inputs): var = typed_interface.inputs[k] + is_default_arg_used = False if var.type.simple == SimpleType.NONE: raise TypeError("Arguments do not have type annotation") if k not in kwargs: @@ -1076,6 +1086,7 @@ def create_and_link_node( if not isinstance(default_val, Hashable): raise _user_exceptions.FlyteAssertion("Cannot use non-hashable object as default argument") kwargs[k] = default_val + is_default_arg_used = True else: error_msg = f"Input {k} of type {interface.inputs[k]} was not specified for function {entity.name}" raise _user_exceptions.FlyteAssertion(error_msg) @@ -1095,6 +1106,7 @@ def create_and_link_node( expected_literal_type=var.type, t_value=v, t_value_type=interface.inputs[k], + is_union_type_variants_expanded=not is_default_arg_used, ) bindings.append(b) nodes.extend(n) diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index ab28a668133..667a7a17bd0 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -12,8 +12,8 @@ from flytekit.core.python_auto_container import get_registerable_container_image from flytekit.core.task import task from flytekit.core.workflow import workflow -from flytekit.image_spec.image_spec import ImageBuildEngine, _calculate_deduped_hash_from_image_spec from flytekit.exceptions.user import FlyteAssertion +from flytekit.image_spec.image_spec import ImageBuildEngine, _calculate_deduped_hash_from_image_spec from flytekit.models.admin.workflow import WorkflowSpec from flytekit.models.literals import ( BindingData, @@ -586,7 +586,7 @@ def wf_with_sub_wf() -> tuple[str, str]: assert wf_with_sub_wf() == (default_val, input_val) -def test_default_args_task_optional_type_default_none(): +def test_default_args_task_optional_int_type_default_none(): default_val = None input_val = 100 @@ -648,7 +648,7 @@ def wf_with_sub_wf() -> tuple[typing.Optional[int], typing.Optional[int]]: assert wf_with_sub_wf() == (default_val, input_val) -def test_default_args_task_optional_type_default_int(): +def test_default_args_task_optional_int_type_default_int(): default_val = 10 input_val = 100 @@ -672,7 +672,17 @@ def wf_with_sub_wf() -> tuple[typing.Optional[int], typing.Optional[int]]: wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) assert wf_no_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( - primitive=Primitive(integer=default_val) + union=Union( + value=Literal( + scalar=Scalar( + primitive=Primitive(integer=default_val), + ), + ), + stored_type=LiteralType( + simple=SimpleType.INTEGER, + structure=TypeStructure(tag="int"), + ), + ), ) assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( primitive=Primitive(integer=input_val) @@ -805,3 +815,115 @@ def wf_with_input() -> dict[str, int]: ) assert wf_with_input() == input_val + + +def test_default_args_task_optional_list_type_default_none(): + default_val = None + input_val = [1, 2, 3] + + @task + def t1(a: typing.Optional[typing.List[int]] = default_val) -> typing.Optional[typing.List[int]]: + return a + + @workflow + def wf_no_input() -> typing.Optional[typing.List[int]]: + return t1() + + @workflow + def wf_with_input() -> typing.Optional[typing.List[int]]: + return t1(a=input_val) + + @workflow + def wf_with_sub_wf() -> tuple[typing.Optional[typing.List[int]], typing.Optional[typing.List[int]]]: + return (wf_no_input(), wf_with_input()) + + wf_no_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_no_input) + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + assert wf_no_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( + union=Union( + value=Literal( + scalar=Scalar( + none_type=Void(), + ), + ), + stored_type=LiteralType( + simple=SimpleType.NONE, + structure=TypeStructure(tag="none"), + ), + ), + ) + assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == BindingDataCollection( + bindings=[ + BindingData(scalar=Scalar(primitive=Primitive(integer=1))), + BindingData(scalar=Scalar(primitive=Primitive(integer=2))), + BindingData(scalar=Scalar(primitive=Primitive(integer=3))), + ] + ) + + output_type = LiteralType( + union_type=UnionType( + [ + LiteralType( + collection_type=LiteralType(simple=SimpleType.INTEGER), + structure=TypeStructure(tag="Typed List"), + ), + LiteralType( + simple=SimpleType.NONE, + structure=TypeStructure(tag="none"), + ), + ] + ) + ) + assert wf_no_input_spec.template.interface.outputs["o0"].type == output_type + assert wf_with_input_spec.template.interface.outputs["o0"].type == output_type + + assert wf_no_input() == default_val + assert wf_with_input() == input_val + assert wf_with_sub_wf() == (default_val, input_val) + + +def test_default_args_task_optional_list_type_default_list(): + input_val = [1, 2, 3] + + @task + def t1(a: typing.Optional[typing.List[int]] = []) -> typing.Optional[typing.List[int]]: + return a + + @workflow + def wf_no_input() -> typing.Optional[typing.List[int]]: + return t1() + + @workflow + def wf_with_input() -> typing.Optional[typing.List[int]]: + return t1(a=input_val) + + with pytest.raises(FlyteAssertion, match="Cannot use non-hashable object as default argument"): + get_serializable(OrderedDict(), serialization_settings, wf_no_input) + + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == BindingDataCollection( + bindings=[ + BindingData(scalar=Scalar(primitive=Primitive(integer=1))), + BindingData(scalar=Scalar(primitive=Primitive(integer=2))), + BindingData(scalar=Scalar(primitive=Primitive(integer=3))), + ] + ) + + assert wf_with_input_spec.template.interface.outputs["o0"].type == LiteralType( + union_type=UnionType( + [ + LiteralType( + collection_type=LiteralType(simple=SimpleType.INTEGER), + structure=TypeStructure(tag="Typed List"), + ), + LiteralType( + simple=SimpleType.NONE, + structure=TypeStructure(tag="none"), + ), + ] + ) + ) + + assert wf_with_input() == input_val