Skip to content

Commit

Permalink
don't set fixed inputs under input bindings
Browse files Browse the repository at this point in the history
Signed-off-by: Paul Dittamo <[email protected]>
  • Loading branch information
pvditt committed Feb 4, 2025
1 parent f0c48bc commit 0fa70ee
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 23 deletions.
12 changes: 7 additions & 5 deletions flytekit/core/array_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,17 +126,19 @@ def interface(self) -> _interface_models.TypedInterface:
@property
def bindings(self) -> List[_literal_models.Binding]:
# Required in get_serializable_node
return self._bindings
bindings = []

Check warning on line 129 in flytekit/core/array_node.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node.py#L129

Added line #L129 was not covered by tests
for binding in self._bindings:
if binding.var not in self._bound_inputs:
bindings.append(binding)
return bindings

Check warning on line 133 in flytekit/core/array_node.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node.py#L132-L133

Added lines #L132 - L133 were not covered by tests

@property
def fixed_inputs(self) -> List[_literal_models.Binding]:
# TODO - clean this up
# Required in get_serializable_node
fixed_inputs = []

Check warning on line 138 in flytekit/core/array_node.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node.py#L138

Added line #L138 was not covered by tests
for binding in self.bindings:
for binding in self._bindings:
if binding.var in self._bound_inputs:
fixed_inputs.append(binding)
if len(fixed_inputs) != len(self._bound_inputs):
raise ValueError("Error binding fixed inputs")
return fixed_inputs

Check warning on line 142 in flytekit/core/array_node.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node.py#L141-L142

Added lines #L141 - L142 were not covered by tests

@property
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,10 @@ def _raw_execute(self, **kwargs) -> Any:

def map_task(
target: Union[LaunchPlan, PythonFunctionTask, "FlyteLaunchPlan"],
fixed_inputs: Optional[Dict[str, Any]] = None,
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: float = 1.0,
fixed_inputs: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""
Expand Down
23 changes: 22 additions & 1 deletion flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ def get_serializable_node(
id=_dnsify(entity.id),
metadata=entity.metadata,
inputs=entity.bindings,
fixed_inputs=entity.fixed_inputs,
upstream_node_ids=[n.id for n in upstream_nodes],
output_aliases=[],
)
Expand All @@ -435,7 +436,7 @@ def get_serializable_node(
node_model = workflow_model.Node(
id=_dnsify(entity.id),
metadata=entity.metadata,
inputs=entity.bindings,
inputs=entity.flyte_entity.bindings,
fixed_inputs=entity.flyte_entity.fixed_inputs,
upstream_node_ids=[n.id for n in upstream_nodes],
output_aliases=[],
Expand All @@ -446,6 +447,7 @@ def get_serializable_node(
id=_dnsify(entity.id),
metadata=entity.metadata,
inputs=entity.bindings,
fixed_inputs=entity.fixed_inputs,
upstream_node_ids=[n.id for n in upstream_nodes],
output_aliases=[],
array_node=get_serializable_array_node_map_task(entity_mapping, settings, entity, options=options),
Expand All @@ -459,6 +461,7 @@ def get_serializable_node(
id=_dnsify(entity.id),
metadata=entity.metadata,
inputs=entity.bindings,
fixed_inputs=entity.fixed_inputs,
upstream_node_ids=[n.id for n in upstream_nodes],
output_aliases=[],
task_node=workflow_model.TaskNode(
Expand All @@ -479,6 +482,7 @@ def get_serializable_node(
id=_dnsify(entity.id),
metadata=entity.metadata,
inputs=entity.bindings,
fixed_inputs=entity.fixed_inputs,
upstream_node_ids=[n.id for n in upstream_nodes],
output_aliases=[],
workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_spec.template.id),
Expand All @@ -489,6 +493,7 @@ def get_serializable_node(
id=_dnsify(entity.id),
metadata=entity.metadata,
inputs=entity.bindings,
fixed_inputs=entity.fixed_inputs,
upstream_node_ids=[n.id for n in upstream_nodes],
output_aliases=[],
branch_node=get_serializable(entity_mapping, settings, entity.flyte_entity, options=options),
Expand All @@ -502,11 +507,17 @@ def get_serializable_node(
for b in entity.bindings:
if b.var not in entity.flyte_entity.fixed_inputs.literals:
node_input.append(b)
# TODO clean up and explain
fixed_node_inputs = []

Check warning on line 511 in flytekit/tools/translator.py

View check run for this annotation

Codecov / codecov/patch

flytekit/tools/translator.py#L511

Added line #L511 was not covered by tests
for b in entity.fixed_inputs:
if b.var not in entity.flyte_entity.fixed_inputs.literals:
fixed_node_inputs.append(b)

Check warning on line 514 in flytekit/tools/translator.py

View check run for this annotation

Codecov / codecov/patch

flytekit/tools/translator.py#L514

Added line #L514 was not covered by tests

node_model = workflow_model.Node(
id=_dnsify(entity.id),
metadata=entity.metadata,
inputs=node_input,
fixed_inputs=fixed_node_inputs,
upstream_node_ids=[n.id for n in upstream_nodes],
output_aliases=[],
workflow_node=workflow_model.WorkflowNode(launchplan_ref=lp_spec.id),
Expand All @@ -528,6 +539,7 @@ def get_serializable_node(
id=_dnsify(entity.id),
metadata=entity.metadata,
inputs=entity.bindings,
fixed_inputs=entity.fixed_inputs,
upstream_node_ids=[n.id for n in upstream_nodes],
output_aliases=[],
gate_node=gn,
Expand All @@ -540,6 +552,7 @@ def get_serializable_node(
id=_dnsify(entity.id),
metadata=entity.metadata,
inputs=entity.bindings,
fixed_inputs=entity.fixed_inputs,
upstream_node_ids=[n.id for n in upstream_nodes],
output_aliases=[],
task_node=workflow_model.TaskNode(
Expand All @@ -559,6 +572,7 @@ def get_serializable_node(
id=_dnsify(entity.id),
metadata=entity.metadata,
inputs=entity.bindings,
fixed_inputs=entity.fixed_inputs,
upstream_node_ids=[n.id for n in upstream_nodes],
output_aliases=[],
workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_spec.id),
Expand All @@ -571,11 +585,17 @@ def get_serializable_node(
for b in entity.bindings:
if b.var not in entity.flyte_entity.fixed_inputs.literals:
node_input.append(b)
# TODO clean up and explain
fixed_node_inputs = []

Check warning on line 589 in flytekit/tools/translator.py

View check run for this annotation

Codecov / codecov/patch

flytekit/tools/translator.py#L589

Added line #L589 was not covered by tests
for b in entity.fixed_inputs:
if b.var not in entity.flyte_entity.fixed_inputs.literals:
fixed_node_inputs.append(b)

Check warning on line 592 in flytekit/tools/translator.py

View check run for this annotation

Codecov / codecov/patch

flytekit/tools/translator.py#L592

Added line #L592 was not covered by tests

node_model = workflow_model.Node(
id=_dnsify(entity.id),
metadata=entity.metadata,
inputs=node_input,
fixed_inputs=fixed_node_inputs,
upstream_node_ids=[n.id for n in upstream_nodes],
output_aliases=[],
workflow_node=workflow_model.WorkflowNode(launchplan_ref=entity.flyte_entity.id),
Expand Down Expand Up @@ -627,6 +647,7 @@ def get_serializable_array_node_map_task(
id=entity.name,
metadata=entity.sub_node_metadata,
inputs=node.bindings,
fixed_inputs=entity.fixed_inputs,
upstream_node_ids=[],
output_aliases=[],
task_node=task_node,
Expand Down
40 changes: 24 additions & 16 deletions tests/flytekit/unit/core/test_array_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,33 +154,41 @@ def test_lp_serialization(target, overrides_metadata, upstream_nodes, fixed_inpu
assert set(parent_node.upstream_node_ids) == set(upstream_nodes)
assert len(parent_node.fixed_inputs) == len(fixed_inputs)

assert parent_node.inputs[0].var == "a"
inputs_map = {x.var: x for x in parent_node.inputs}
fixed_inputs_map = {x.var: x for x in parent_node.fixed_inputs}

if "a" in fixed_inputs:
for fixed_input in parent_node.fixed_inputs:
if fixed_input.var == "a":
assert fixed_input == parent_node.inputs[0]
fixed_input = fixed_inputs_map["a"]
assert fixed_input
else:
assert len(parent_node.inputs[0].binding.collection.bindings) == 3
for i, binding in enumerate(parent_node.inputs[0].binding.collection.bindings):
node_input = inputs_map["a"]
assert node_input
assert len(node_input.binding.collection.bindings) == 3
for i, binding in enumerate(node_input.binding.collection.bindings):
if upstream_nodes and i == 0:
assert binding.promise.node_id == upstream_nodes[0]
else:
assert (binding.scalar.primitive.integer is not None)
assert parent_node.inputs[1].var == "b"
if "b" in fixed_inputs:
for fixed_input in parent_node.fixed_inputs:
if fixed_input.var == "b":
assert fixed_input == parent_node.inputs[1]
fixed_input = fixed_inputs_map["b"]
assert fixed_input
else:
for binding in parent_node.inputs[1].binding.collection.bindings:
node_input = inputs_map["b"]
assert node_input
for binding in node_input.binding.collection.bindings:
assert (binding.scalar.union is not None or
binding.scalar.primitive.integer is not None or
binding.scalar.primitive.string_value is not None)
assert len(parent_node.inputs[1].binding.collection.bindings) == 3
assert parent_node.inputs[2].var == "c"
assert len(parent_node.inputs[2].binding.collection.bindings) == 3
for binding in parent_node.inputs[2].binding.collection.bindings:
assert binding.scalar.primitive.integer is not None
assert len(node_input.binding.collection.bindings) == 3
if "c" in fixed_inputs:
fixed_input = fixed_inputs_map["c"]
assert fixed_input
else:
node_input = inputs_map["c"]
assert node_input
assert len(node_input.binding.collection.bindings) == 3
for binding in node_input.binding.collection.bindings:
assert binding.scalar.primitive.integer is not None

serialized_array_node = parent_node.array_node
assert (
Expand Down

0 comments on commit 0fa70ee

Please sign in to comment.