Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify the Function class #955

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pytensor/compile/function/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def function(
| dict[Variable, Variable]
| None = None,
no_default_updates: bool = False,
trust_input: bool = False,
accept_inplace: bool = False,
name: str | None = None,
rebuild_strict: bool = True,
Expand Down Expand Up @@ -310,7 +311,7 @@ def opt_log1p(node):
"semantics, which disallow using updates and givens"
)
fn = orig_function(
inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name
inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name, trust_input=trust_input
)
else:
# note: pfunc will also call orig_function -- orig_function is
Expand All @@ -322,6 +323,7 @@ def opt_log1p(node):
updates=updates,
givens=givens,
no_default_updates=no_default_updates,
trust_input=trust_input,
accept_inplace=accept_inplace,
name=name,
rebuild_strict=rebuild_strict,
Expand Down
3 changes: 2 additions & 1 deletion pytensor/compile/function/pfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ def pfunc(
givens=None,
no_default_updates=False,
accept_inplace=False,
trust_input=False,
name=None,
rebuild_strict=True,
allow_input_downcast=None,
Expand Down Expand Up @@ -467,6 +468,7 @@ def pfunc(
cloned_outputs,
mode,
accept_inplace=accept_inplace,
trust_input=trust_input,
name=name,
profile=profile,
on_unused_input=on_unused_input,
Expand All @@ -478,7 +480,6 @@ def pfunc(
def construct_pfunc_ins_and_outs(
params,
outputs=None,
mode=None,
updates=None,
givens=None,
no_default_updates=False,
Expand Down
44 changes: 13 additions & 31 deletions pytensor/compile/function/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ def __init__(
return_none: bool,
output_keys,
maker: "FunctionMaker",
trust_input: bool,
name: str | None = None,
):
"""
Expand Down Expand Up @@ -383,7 +384,7 @@ def __init__(
self.return_none = return_none
self.maker = maker
self.profile = None # reassigned in FunctionMaker.create
self.trust_input = False # If True, we don't check the input parameter
self.trust_input = trust_input # If True, we don't check the input parameter
self.name = name
self.nodes_with_inner_function = []
self.output_keys = output_keys
Expand Down Expand Up @@ -792,18 +793,14 @@ def __call__(self, *args, **kwargs):
The function inputs can be passed as keyword argument. For this, use
the name of the input or the input instance as the key.

Keyword argument ``output_subset`` is a list of either indices of the
function's outputs or the keys belonging to the `output_keys` dict
and represent outputs that are requested to be calculated. Regardless
of the presence of ``output_subset``, the updates are always calculated
and processed. To disable the updates, you should use the ``copy``
The updates are always calculated and processed.
To disable the updates, you should use the ``copy``
method with ``delete_updates=True``.

Returns
-------
list
List of outputs on indices/keys from ``output_subset`` or all of them,
if ``output_subset`` is not passed.
List of outputs.
"""

def restore_defaults():
Expand All @@ -816,10 +813,6 @@ def restore_defaults():
profile = self.profile
t0 = time.perf_counter()

output_subset = kwargs.pop("output_subset", None)
if output_subset is not None and self.output_keys is not None:
output_subset = [self.output_keys.index(key) for key in output_subset]

# Reinitialize each container's 'provided' counter
if self.trust_input:
i = 0
Expand Down Expand Up @@ -955,11 +948,7 @@ def restore_defaults():
# Do the actual work
t0_fn = time.perf_counter()
try:
outputs = (
self.vm()
if output_subset is None
else self.vm(output_subset=output_subset)
)
outputs = self.vm()
except Exception:
restore_defaults()
if hasattr(self.vm, "position_of_error"):
Expand Down Expand Up @@ -1040,24 +1029,13 @@ def restore_defaults():
profile.ignore_first_call = False
if self.return_none:
return None
elif self.unpack_single and len(outputs) == 1 and output_subset is None:
elif self.unpack_single and len(outputs) == 1:
return outputs[0]
else:
if self.output_keys is not None:
assert len(self.output_keys) == len(outputs)

if output_subset is None:
return dict(zip(self.output_keys, outputs))
else:
return {
self.output_keys[index]: outputs[index]
for index in output_subset
}

if output_subset is None:
return outputs
else:
return [outputs[i] for i in output_subset]
return dict(zip(self.output_keys, outputs))
return outputs

value = property(
lambda self: self._value,
Expand Down Expand Up @@ -1455,6 +1433,7 @@ def __init__(
outputs,
mode=None,
accept_inplace=False,
trust_input=False,
function_builder=Function,
profile=None,
on_unused_input=None,
Expand Down Expand Up @@ -1558,6 +1537,7 @@ def __init__(
self.unpack_single = unpack_single
self.return_none = return_none
self.accept_inplace = accept_inplace
self.trust_input = trust_input
self.function_builder = function_builder
self.on_unused_input = on_unused_input # Used for the pickling/copy
self.output_keys = output_keys
Expand Down Expand Up @@ -1689,6 +1669,7 @@ def orig_function(
outputs,
mode=None,
accept_inplace=False,
trust_input=False,
name=None,
profile=None,
on_unused_input=None,
Expand Down Expand Up @@ -1752,6 +1733,7 @@ def orig_function(
outputs,
mode,
accept_inplace=accept_inplace,
trust_input=trust_input,
profile=profile,
on_unused_input=on_unused_input,
output_keys=output_keys,
Expand Down
Loading