diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py
index c9dbfb46b0..55414a94d0 100644
--- a/pytensor/scalar/basic.py
+++ b/pytensor/scalar/basic.py
@@ -4246,7 +4246,11 @@ def __str__(self):
             r.name = f"o{int(i)}"
         io = set(self.fgraph.inputs + self.fgraph.outputs)
         for i, r in enumerate(self.fgraph.variables):
-            if r not in io and len(self.fgraph.clients[r]) > 1:
+            if (
+                not isinstance(r, Constant)
+                and r not in io
+                and len(self.fgraph.clients[r]) > 1
+            ):
                 r.name = f"t{int(i)}"

         if len(self.fgraph.outputs) > 1 or len(self.fgraph.apply_nodes) > 10:
@@ -4345,7 +4349,7 @@ def c_code_template(self):
             if var not in self.fgraph.inputs:
                 # This is an orphan
                 if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
-                    subd[var] = var.type.c_literal(var.data)
+                    subd[var] = f"({var.type.c_literal(var.data)})"
                 else:
                     raise ValueError(
                         "All orphans in the fgraph to Composite must"
@@ -4404,7 +4408,7 @@ def c_code(self, node, nodename, inames, onames, sub):
         return self.c_code_template % d

     def c_code_cache_version_outer(self) -> tuple[int, ...]:
-        return (4,)
+        return (5,)


 class Compositef32:
diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py
index 189cd461c7..59664374f9 100644
--- a/pytensor/scalar/loop.py
+++ b/pytensor/scalar/loop.py
@@ -239,7 +239,7 @@ def c_code_template(self):
             if var not in self.fgraph.inputs:
                 # This is an orphan
                 if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
-                    subd[var] = var.type.c_literal(var.data)
+                    subd[var] = f"({var.type.c_literal(var.data)})"
                 else:
                     raise ValueError(
                         "All orphans in the fgraph to ScalarLoop must"
@@ -342,4 +342,4 @@ def c_code(self, node, nodename, inames, onames, sub):
         return res

     def c_code_cache_version_outer(self):
-        return (2,)
+        return (3,)
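
Note on the `c_literal` change: without the wrapping parentheses, substituting a negative constant next to a unary minus produces `--1.5` in the generated C, which the lexer greedily tokenizes as the (invalid here) decrement operator; `(-1.5)` keeps the two minuses as separate tokens. The `c_code_cache_version_outer` bumps force cached C modules to be recompiled with the new template, and the `__str__` guard stops assigning temporary `t{i}` names to constants, consistent with the updated debugprint expectations further down where constants appear inline in `Composite` graphs. A minimal repro sketch mirroring the new `test_negative_constant` test below (assumes a PyTensor dev install with a working C compiler):

```python
from pytensor.scalar.basic import Composite, constant, int64, neg

x = int64("x")
e = neg(constant(-1.5)) % x  # must compile to -(-1.5), not --1.5
comp_op = Composite([x], [e])
comp_node = comp_op.make_node(x)

# The orphan constant is now rendered as (-1.5) in the C template
print(comp_op.c_code(comp_node, "dummy", ["x"], ["z"], dict(id=0)))
```
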
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 2e30e1399b..68cc0e5e96 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -621,65 +621,43 @@ def local_mul_switch_sink(fgraph, node):
     part of the graph.

     """
-    for idx, i in enumerate(node.inputs):
-        if i.owner and i.owner.op == switch:
-            switch_node = i.owner
-            try:
-                if (
-                    get_underlying_scalar_constant_value(
-                        switch_node.inputs[1], only_process_constants=True
-                    )
-                    == 0.0
-                ):
-                    listmul = node.inputs[:idx] + node.inputs[idx + 1 :]
-                    fmul = mul(*([*listmul, switch_node.inputs[2]]))
-
-                    # Copy over stacktrace for elementwise multiplication op
-                    # from previous elementwise multiplication op.
-                    # An error in the multiplication (e.g. errors due to
-                    # inconsistent shapes), will point to the
-                    # multiplication op.
-                    copy_stack_trace(node.outputs, fmul)
-
-                    fct = [switch(switch_node.inputs[0], 0, fmul)]
-                    fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
-
-                    # Copy over stacktrace for switch op from both previous
-                    # elementwise multiplication op and previous switch op,
-                    # because an error in this part can be caused by either
-                    # of the two previous ops.
-                    copy_stack_trace(node.outputs + switch_node.outputs, fct)
-                    return fct
-            except NotScalarConstantError:
-                pass
-            try:
-                if (
-                    get_underlying_scalar_constant_value(
-                        switch_node.inputs[2], only_process_constants=True
-                    )
-                    == 0.0
-                ):
-                    listmul = node.inputs[:idx] + node.inputs[idx + 1 :]
-                    fmul = mul(*([*listmul, switch_node.inputs[1]]))
-                    # Copy over stacktrace for elementwise multiplication op
-                    # from previous elementwise multiplication op.
-                    # An error in the multiplication (e.g. errors due to
-                    # inconsistent shapes), will point to the
-                    # multiplication op.
-                    copy_stack_trace(node.outputs, fmul)
-
-                    fct = [switch(switch_node.inputs[0], fmul, 0)]
-                    fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
-
-                    # Copy over stacktrace for switch op from both previous
-                    # elementwise multiplication op and previous switch op,
-                    # because an error in this part can be caused by either
-                    # of the two previous ops.
-                    copy_stack_trace(node.outputs + switch_node.outputs, fct)
-                    return fct
-            except NotScalarConstantError:
-                pass
-    return False
+    for mul_inp_idx, mul_inp in enumerate(node.inputs):
+        if mul_inp.owner and mul_inp.owner.op == switch:
+            switch_node = mul_inp.owner
+            # Look for a zero as the first or second branch of the switch
+            for branch in range(2):
+                zero_switch_input = switch_node.inputs[1 + branch]
+                if not get_unique_constant_value(zero_switch_input) == 0.0:
+                    continue
+
+                switch_cond = switch_node.inputs[0]
+                other_switch_input = switch_node.inputs[1 + (1 - branch)]
+
+                listmul = list(node.inputs)
+                listmul[mul_inp_idx] = other_switch_input
+                fmul = mul(*listmul)
+
+                # Copy over stacktrace for elementwise multiplication op
+                # from previous elementwise multiplication op.
+                # An error in the multiplication (e.g. errors due to
+                # inconsistent shapes) will point to the
+                # multiplication op.
+                copy_stack_trace(node.outputs, fmul)
+
+                if branch == 0:
+                    fct = switch(switch_cond, zero_switch_input, fmul)
+                else:
+                    fct = switch(switch_cond, fmul, zero_switch_input)
+
+                # Tell debug_mode that the output is correct, even if NaNs disappear
+                fct.tag.values_eq_approx = values_eq_approx_remove_nan
+
+                # Copy over stacktrace for switch op from both previous
+                # elementwise multiplication op and previous switch op,
+                # because an error in this part can be caused by either
+                # of the two previous ops.
+                copy_stack_trace(node.outputs + switch_node.outputs, fct)
+                return [fct]


 @register_canonicalize
@@ -699,62 +677,42 @@ def local_div_switch_sink(fgraph, node):
     See `local_mul_switch_sink` for more details.

     """
-    op = node.op
-    if node.inputs[0].owner and node.inputs[0].owner.op == switch:
-        switch_node = node.inputs[0].owner
-        try:
-            if (
-                get_underlying_scalar_constant_value(
-                    switch_node.inputs[1], only_process_constants=True
-                )
-                == 0.0
-            ):
-                fdiv = op(switch_node.inputs[2], node.inputs[1])
-                # Copy over stacktrace for elementwise division op
-                # from previous elementwise multiplication op.
-                # An error in the division (e.g. errors due to
-                # inconsistent shapes or division by zero),
-                # will point to the new division op.
-                copy_stack_trace(node.outputs, fdiv)
-
-                fct = [switch(switch_node.inputs[0], 0, fdiv)]
-                fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
-
-                # Copy over stacktrace for switch op from both previous
-                # elementwise division op and previous switch op,
-                # because an error in this part can be caused by either
-                # of the two previous ops.
-                copy_stack_trace(node.outputs + switch_node.outputs, fct)
-                return fct
-        except NotScalarConstantError:
-            pass
-        try:
-            if (
-                get_underlying_scalar_constant_value(
-                    switch_node.inputs[2], only_process_constants=True
-                )
-                == 0.0
-            ):
-                fdiv = op(switch_node.inputs[1], node.inputs[1])
-                # Copy over stacktrace for elementwise division op
-                # from previous elementwise multiplication op.
-                # An error in the division (e.g. errors due to
-                # inconsistent shapes or division by zero),
-                # will point to the new division op.
-                copy_stack_trace(node.outputs, fdiv)
-
-                fct = [switch(switch_node.inputs[0], fdiv, 0)]
-                fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
+    num, denom = node.inputs

-                # Copy over stacktrace for switch op from both previous
-                # elementwise division op and previous switch op,
-                # because an error in this part can be caused by either
-                # of the two previous ops.
-                copy_stack_trace(node.outputs + switch_node.outputs, fct)
-                return fct
-        except NotScalarConstantError:
-            pass
-    return False
+    if num.owner and num.owner.op == switch:
+        switch_node = num.owner
+        # Look for a zero as the first or second branch of the switch
+        for branch in range(2):
+            zero_switch_input = switch_node.inputs[1 + branch]
+            if not get_unique_constant_value(zero_switch_input) == 0.0:
+                continue
+
+            switch_cond = switch_node.inputs[0]
+            other_switch_input = switch_node.inputs[1 + (1 - branch)]
+
+            fdiv = node.op(other_switch_input, denom)
+
+            # Copy over stacktrace for elementwise division op
+            # from previous elementwise division op.
+            # An error in the division (e.g. errors due to
+            # inconsistent shapes or division by zero)
+            # will point to the new division op.
+            copy_stack_trace(node.outputs, fdiv)
+
+            if branch == 0:
+                fct = switch(switch_cond, zero_switch_input, fdiv)
+            else:
+                fct = switch(switch_cond, fdiv, zero_switch_input)
+
+            # Tell debug_mode that the output is correct, even if NaNs disappear
+            fct.tag.values_eq_approx = values_eq_approx_remove_nan
+
+            # Copy over stacktrace for switch op from both previous
+            # elementwise division op and previous switch op,
+            # because an error in this part can be caused by either
+            # of the two previous ops.
+            copy_stack_trace(node.outputs + switch_node.outputs, fct)
+            return [fct]


 class AlgebraicCanonizer(NodeRewriter):
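
The key behavioral point of the rewritten sink rewrites is that the zero branch of the `switch` is reused verbatim rather than replaced by a bare `0`, so its dtype still participates in the output dtype. A rough before/after sketch of what the rewrite produces (assumes PyTensor is importable; `zero64` is an illustrative name, not from the patch):

```python
import numpy as np
import pytensor.tensor as pt

cond = pt.scalar("cond", dtype="bool")
x = pt.scalar("x", dtype="float32")
y = pt.scalar("y", dtype="float32")
zero64 = pt.constant(np.array(0.0, dtype="float64"), name="zero64")

before = pt.switch(cond, zero64, x) * y  # mul sits above the switch
after = pt.switch(cond, zero64, x * y)   # what the rewrite now builds

# Reusing the original float64 zero preserves the dtype; substituting a
# plain `0` would downcast the output to float32 (issue #1037).
assert before.type.dtype == after.type.dtype == "float64"
```
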
diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py
index 261a8bbc4a..db87d04f93 100644
--- a/pytensor/tensor/variable.py
+++ b/pytensor/tensor/variable.py
@@ -1045,11 +1045,13 @@ def get_unique_constant_value(x: TensorVariable) -> Number | None:
     if isinstance(x, Constant):
         data = x.data

-        if isinstance(data, np.ndarray) and data.ndim > 0:
+        if isinstance(data, np.ndarray) and data.size > 0:
+            if data.size == 1:
+                return data.squeeze()
+
             flat_data = data.ravel()
-            if flat_data.shape[0]:
-                if (flat_data == flat_data[0]).all():
-                    return flat_data[0]
+            if (flat_data == flat_data[0]).all():
+                return flat_data[0]

     return None

diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py
index c8f0fc335b..e648869d4c 100644
--- a/tests/scalar/test_basic.py
+++ b/tests/scalar/test_basic.py
@@ -36,6 +36,7 @@
     floats,
     int8,
     int32,
+    int64,
     ints,
     invert,
     log,
@@ -44,6 +45,7 @@
     log10,
     mean,
     mul,
+    neg,
     neq,
     rad2deg,
     reciprocal,
@@ -156,6 +158,21 @@ def checker(x, y):
             (literal_value + test_y) * (test_x / test_y),
         )

+    def test_negative_constant(self):
+        # Test that a negative constant is parenthesized, so the generated C contains -(-1.5) rather than --1.5, which C lexes as the decrement operator
+        x = int64("x")
+        e = neg(constant(-1.5)) % x
+        comp_op = Composite([x], [e])
+        comp_node = comp_op.make_node(x)
+
+        c_code = comp_node.op.c_code(comp_node, "dummy", ["x", "y"], ["z"], dict(id=0))
+        assert "-1.5" in c_code
+
+        g = FunctionGraph([x], [comp_node.out])
+        fn = make_function(DualLinker().accept(g))
+        assert fn(2) == 1.5
+        assert fn(1) == 0.5
+
     def test_many_outputs(self):
         x, y, z = floats("xyz")
         e0 = x + y + z
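
The `get_unique_constant_value` change above is what lets the rewrites drop the two `get_underlying_scalar_constant_value` try/except probes: with the guard relaxed from `data.ndim > 0` to `data.size > 0`, 0-d constants now report their value instead of `None`, and size-1 arrays short-circuit through `squeeze()`. A quick sketch of the intended behavior (assumes the imports match the touched module):

```python
import numpy as np
from pytensor.tensor import constant
from pytensor.tensor.variable import get_unique_constant_value

assert get_unique_constant_value(constant(np.array(0.0))) == 0.0          # 0-d: was None before
assert get_unique_constant_value(constant(np.full((3, 2), 5.0))) == 5.0   # uniform array
assert get_unique_constant_value(constant(np.array([1.0, 2.0]))) is None  # no unique value
```
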
diff --git a/tests/scan/test_printing.py b/tests/scan/test_printing.py
index 42d81fbf11..9df0966b78 100644
--- a/tests/scan/test_printing.py
+++ b/tests/scan/test_printing.py
@@ -654,24 +654,22 @@ def no_shared_fn(n, x_tm1, M):
     Inner graphs:

     Scan{scan_fn, while_loop=False, inplace=all} [id A]
-     ← Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0)
-        ├─ 0 [id J]
-        ├─ Subtensor{i, j, k} [id K]
-        │  ├─ *2- [id L] -> [id H] (inner_in_non_seqs-0)
-        │  ├─ ScalarFromTensor [id M]
-        │  │  └─ *0- [id N] -> [id C] (inner_in_seqs-0)
-        │  ├─ ScalarFromTensor [id O]
-        │  │  └─ *1- [id P] -> [id D] (inner_in_sit_sot-0)
-        │  └─ 0 [id Q]
-        └─ 1 [id R]
-
-    Composite{switch(lt(i0, i1), i2, i0)} [id I]
-     ← Switch [id S] 'o0'
-        ├─ LT [id T]
-        │  ├─ i0 [id U]
-        │  └─ i1 [id V]
-        ├─ i2 [id W]
-        └─ i0 [id U]
+     ← Composite{switch(lt(0, i0), 1, 0)} [id I] (inner_out_sit_sot-0)
+        └─ Subtensor{i, j, k} [id J]
+           ├─ *2- [id K] -> [id H] (inner_in_non_seqs-0)
+           ├─ ScalarFromTensor [id L]
+           │  └─ *0- [id M] -> [id C] (inner_in_seqs-0)
+           ├─ ScalarFromTensor [id N]
+           │  └─ *1- [id O] -> [id D] (inner_in_sit_sot-0)
+           └─ 0 [id P]
+
+    Composite{switch(lt(0, i0), 1, 0)} [id I]
+     ← Switch [id Q] 'o0'
+        ├─ LT [id R]
+        │  ├─ 0 [id S]
+        │  └─ i0 [id T]
+        ├─ 1 [id U]
+        └─ 0 [id S]
     """

     output_str = debugprint(out, file="str", print_op_info=True)
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index 019833a9d5..1212ee4fbd 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -97,9 +97,11 @@
 from pytensor.tensor.rewriting.math import (
     compute_mul,
     is_1pexp,
+    local_div_switch_sink,
     local_grad_log_erfc_neg,
     local_greedy_distributor,
     local_mul_canonizer,
+    local_mul_switch_sink,
     local_reduce_chain,
     local_sum_prod_of_mul_or_div,
     mul_canonizer,
@@ -2115,7 +2117,6 @@ def test_local_mul_switch_sink(self):
         f = self.function_remove_nan([x], pytensor.gradient.grad(y, x), self.mode)
         assert f(5) == 1, f(5)

-    @pytest.mark.slow
     def test_local_div_switch_sink(self):
         c = dscalar()
         idx = 0
@@ -2149,6 +2150,28 @@ def test_local_div_switch_sink(self):
             ].size
             idx += 1

+    @pytest.mark.parametrize(
+        "op, rewrite", [(mul, local_mul_switch_sink), (true_div, local_div_switch_sink)]
+    )
+    def test_local_mul_div_switch_sink_cast(self, op, rewrite):
+        """Check that we don't downcast during the rewrite.
+
+        Regression test for: https://github.com/pymc-devs/pytensor/issues/1037
+        """
+        cond = scalar("cond", dtype="bool")
+        # The zero branch upcasts the output, so we can't ignore its dtype
+        zero_branch = constant(np.array(0, dtype="float64"), name="zero_branch")
+        other_branch = scalar("other_branch", dtype="float32")
+        outer_var = scalar("mul_var", dtype="bool")
+
+        out = op(switch(cond, zero_branch, other_branch), outer_var)
+        fgraph = FunctionGraph(outputs=[out], clone=False)
+        [new_out] = rewrite.transform(fgraph, out.owner)
+        assert new_out.type.dtype == out.type.dtype
+
+        expected_out = switch(cond, zero_branch, op(other_branch, outer_var))
+        assert equal_computations([new_out], [expected_out])
+
     @pytest.mark.skipif(
         config.cxx == "",