diff --git a/pytensor/compile/function/__init__.py b/pytensor/compile/function/__init__.py index 7fa3a179ac..c33108951d 100644 --- a/pytensor/compile/function/__init__.py +++ b/pytensor/compile/function/__init__.py @@ -101,6 +101,7 @@ def function( | dict[Variable, Variable] | None = None, no_default_updates: bool = False, + trust_input: bool = False, accept_inplace: bool = False, name: str | None = None, rebuild_strict: bool = True, @@ -310,7 +311,7 @@ def opt_log1p(node): "semantics, which disallow using updates and givens" ) fn = orig_function( - inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name + inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name, trust_input=trust_input ) else: # note: pfunc will also call orig_function -- orig_function is @@ -322,6 +323,7 @@ def opt_log1p(node): updates=updates, givens=givens, no_default_updates=no_default_updates, + trust_input=trust_input, accept_inplace=accept_inplace, name=name, rebuild_strict=rebuild_strict, diff --git a/pytensor/compile/function/pfunc.py b/pytensor/compile/function/pfunc.py index 49a6840719..b263375ce8 100644 --- a/pytensor/compile/function/pfunc.py +++ b/pytensor/compile/function/pfunc.py @@ -370,6 +370,7 @@ def pfunc( givens=None, no_default_updates=False, accept_inplace=False, + trust_input=False, name=None, rebuild_strict=True, allow_input_downcast=None, @@ -467,6 +468,7 @@ def pfunc( cloned_outputs, mode, accept_inplace=accept_inplace, + trust_input=trust_input, name=name, profile=profile, on_unused_input=on_unused_input, @@ -478,7 +480,6 @@ def pfunc( def construct_pfunc_ins_and_outs( params, outputs=None, - mode=None, updates=None, givens=None, no_default_updates=False, diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py index 43199328a3..d711282581 100644 --- a/pytensor/compile/function/types.py +++ b/pytensor/compile/function/types.py @@ -335,6 +335,7 @@ def __init__( return_none: bool, output_keys, maker: "FunctionMaker", + trust_input: bool, name: str | None = None, ): """ @@ -383,7 +384,7 @@ def __init__( self.return_none = return_none self.maker = maker self.profile = None # reassigned in FunctionMaker.create - self.trust_input = False # If True, we don't check the input parameter + self.trust_input = trust_input # If True, we don't check the input parameter self.name = name self.nodes_with_inner_function = [] self.output_keys = output_keys @@ -792,18 +793,14 @@ def __call__(self, *args, **kwargs): The function inputs can be passed as keyword argument. For this, use the name of the input or the input instance as the key. - Keyword argument ``output_subset`` is a list of either indices of the - function's outputs or the keys belonging to the `output_keys` dict - and represent outputs that are requested to be calculated. Regardless - of the presence of ``output_subset``, the updates are always calculated - and processed. To disable the updates, you should use the ``copy`` + The updates are always calculated and processed. + To disable the updates, you should use the ``copy`` method with ``delete_updates=True``. Returns ------- list - List of outputs on indices/keys from ``output_subset`` or all of them, - if ``output_subset`` is not passed. + List of outputs. """ def restore_defaults(): @@ -816,10 +813,6 @@ def restore_defaults(): profile = self.profile t0 = time.perf_counter() - output_subset = kwargs.pop("output_subset", None) - if output_subset is not None and self.output_keys is not None: - output_subset = [self.output_keys.index(key) for key in output_subset] - # Reinitialize each container's 'provided' counter if self.trust_input: i = 0 @@ -955,11 +948,7 @@ def restore_defaults(): # Do the actual work t0_fn = time.perf_counter() try: - outputs = ( - self.vm() - if output_subset is None - else self.vm(output_subset=output_subset) - ) + outputs = self.vm() except Exception: restore_defaults() if hasattr(self.vm, "position_of_error"): @@ -1040,24 +1029,13 @@ def restore_defaults(): profile.ignore_first_call = False if self.return_none: return None - elif self.unpack_single and len(outputs) == 1 and output_subset is None: + elif self.unpack_single and len(outputs) == 1: return outputs[0] else: if self.output_keys is not None: assert len(self.output_keys) == len(outputs) - - if output_subset is None: - return dict(zip(self.output_keys, outputs)) - else: - return { - self.output_keys[index]: outputs[index] - for index in output_subset - } - - if output_subset is None: - return outputs - else: - return [outputs[i] for i in output_subset] + return dict(zip(self.output_keys, outputs)) + return outputs value = property( lambda self: self._value, @@ -1455,6 +1433,7 @@ def __init__( outputs, mode=None, accept_inplace=False, + trust_input=False, function_builder=Function, profile=None, on_unused_input=None, @@ -1558,6 +1537,7 @@ def __init__( self.unpack_single = unpack_single self.return_none = return_none self.accept_inplace = accept_inplace + self.trust_input = trust_input self.function_builder = function_builder self.on_unused_input = on_unused_input # Used for the pickling/copy self.output_keys = output_keys @@ -1689,6 +1669,7 @@ def orig_function( outputs, mode=None, accept_inplace=False, + trust_input=False, name=None, profile=None, on_unused_input=None, @@ -1752,6 +1733,7 @@ def orig_function( outputs, mode, accept_inplace=accept_inplace, + trust_input=trust_input, profile=profile, on_unused_input=on_unused_input, output_keys=output_keys,