Skip to content

Commit

Permalink
Partial functions still expose bound parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
dagardner-nv committed Sep 16, 2024
1 parent 0929970 commit dda4739
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 26 deletions.
9 changes: 4 additions & 5 deletions python/mrc/_pymrc/src/segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,10 @@ std::shared_ptr<mrc::segment::ObjectProperties> BuilderProxy::make_source(mrc::s
py::function gen_factory)
{
// Determine if the gen_factory is expecting to receive a subscription object
auto inspect_mod = py::module::import("inspect");
auto signature = inspect_mod.attr("signature")(gen_factory);
auto num_params = py::len(signature.attr("parameters"));
auto inspect_mod = py::module::import("inspect");
auto signature = inspect_mod.attr("signature")(gen_factory);
auto num_params = py::len(signature.attr("parameters"));
bool expects_subscription = false;

if (num_params == 1)
{
Expand All @@ -375,8 +376,6 @@ std::shared_ptr<mrc::segment::ObjectProperties> BuilderProxy::make_source(mrc::s
{
return build_source(self, name, PyIteratorWrapper(std::move(gen_factory)));
}

throw py::value_error("Invalid number of parameters for source generator function. Expected 0 or 1");
}

std::shared_ptr<mrc::segment::ObjectProperties> BuilderProxy::make_source_component(mrc::segment::IBuilder& self,
Expand Down
21 changes: 0 additions & 21 deletions python/tests/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,27 +489,6 @@ def on_completed():
assert on_completed_count == 1


def test_invalid_source():

def segment_init(seg: mrc.Builder):

def source_gen(a, b):
yield a + b

seg.make_source("my_src", source_gen)

pipeline = mrc.Pipeline()
pipeline.make_segment("my_seg", segment_init)

options = mrc.Options()
executor = mrc.Executor(options)
executor.register_pipeline(pipeline)

with pytest.raises(ValueError):
executor.start()
executor.join()


def test_source_with_bound_value():
"""
This test ensures that the bound values isn't confused with a subscription object
Expand Down

0 comments on commit dda4739

Please sign in to comment.