Fix downcasting bug in local_[mul|div]_switch_sink rewrite #1059

Merged 4 commits on Oct 31, 2024.
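Context: the `local_mul_switch_sink` and `local_div_switch_sink` rewrites replaced a `switch` whose zero branch is a constant with a bare literal `0`, discarding that branch's dtype and downcasting the output (reported in pymc-devs/pytensor#1037). A minimal sketch of the symptom, using only public pytensor APIs (illustrative, not code from this PR):

    import numpy as np
    import pytensor.tensor as pt

    cond = pt.scalar("cond", dtype="bool")
    # The float64 zero branch upcasts the switch (and hence the mul) to float64
    zero_branch = pt.constant(np.array(0.0, dtype="float64"), name="zero_branch")
    other_branch = pt.scalar("other_branch", dtype="float32")
    y = pt.scalar("y", dtype="float32")

    out = pt.switch(cond, zero_branch, other_branch) * y
    print(out.type.dtype)  # float64
    # Before this fix, the rewrite produced switch(cond, 0, other_branch * y):
    # the bare `0` no longer carries float64, so the output downcast to float32.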
10 changes: 7 additions & 3 deletions pytensor/scalar/basic.py
@@ -4246,7 +4246,11 @@ def __str__(self):
             r.name = f"o{int(i)}"
         io = set(self.fgraph.inputs + self.fgraph.outputs)
         for i, r in enumerate(self.fgraph.variables):
-            if r not in io and len(self.fgraph.clients[r]) > 1:
+            if (
+                not isinstance(r, Constant)
+                and r not in io
+                and len(self.fgraph.clients[r]) > 1
+            ):
                 r.name = f"t{int(i)}"

         if len(self.fgraph.outputs) > 1 or len(self.fgraph.apply_nodes) > 10:
@@ -4345,7 +4349,7 @@ def c_code_template(self):
             if var not in self.fgraph.inputs:
                 # This is an orphan
                 if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
-                    subd[var] = var.type.c_literal(var.data)
+                    subd[var] = f"({var.type.c_literal(var.data)})"
                 else:
                     raise ValueError(
                         "All orphans in the fgraph to Composite must"
@@ -4404,7 +4408,7 @@ def c_code(self, node, nodename, inames, onames, sub):
         return self.c_code_template % d

     def c_code_cache_version_outer(self) -> tuple[int, ...]:
-        return (4,)
+        return (5,)


 class Compositef32:
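Bumping `c_code_cache_version_outer` invalidates previously compiled C modules for these Ops, so code cached from the old, unparenthesized literals is regenerated.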
4 changes: 2 additions & 2 deletions pytensor/scalar/loop.py
@@ -239,7 +239,7 @@ def c_code_template(self):
             if var not in self.fgraph.inputs:
                 # This is an orphan
                 if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
-                    subd[var] = var.type.c_literal(var.data)
+                    subd[var] = f"({var.type.c_literal(var.data)})"
                 else:
                     raise ValueError(
                         "All orphans in the fgraph to ScalarLoop must"
@@ -342,4 +342,4 @@ def c_code(self, node, nodename, inames, onames, sub):
         return res

     def c_code_cache_version_outer(self):
-        return (2,)
+        return (3,)
183 changes: 69 additions & 114 deletions pytensor/tensor/rewriting/math.py
@@ -621,65 +621,43 @@ def local_mul_switch_sink(fgraph, node):
     part of the graph.

     """
-    for idx, i in enumerate(node.inputs):
-        if i.owner and i.owner.op == switch:
-            switch_node = i.owner
-            try:
-                if (
-                    get_underlying_scalar_constant_value(
-                        switch_node.inputs[1], only_process_constants=True
-                    )
-                    == 0.0
-                ):
-                    listmul = node.inputs[:idx] + node.inputs[idx + 1 :]
-                    fmul = mul(*([*listmul, switch_node.inputs[2]]))
-
-                    # Copy over stacktrace for elementwise multiplication op
-                    # from previous elementwise multiplication op.
-                    # An error in the multiplication (e.g. errors due to
-                    # inconsistent shapes), will point to the
-                    # multiplication op.
-                    copy_stack_trace(node.outputs, fmul)
-
-                    fct = [switch(switch_node.inputs[0], 0, fmul)]
-                    fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
-
-                    # Copy over stacktrace for switch op from both previous
-                    # elementwise multiplication op and previous switch op,
-                    # because an error in this part can be caused by either
-                    # of the two previous ops.
-                    copy_stack_trace(node.outputs + switch_node.outputs, fct)
-                    return fct
-            except NotScalarConstantError:
-                pass
-            try:
-                if (
-                    get_underlying_scalar_constant_value(
-                        switch_node.inputs[2], only_process_constants=True
-                    )
-                    == 0.0
-                ):
-                    listmul = node.inputs[:idx] + node.inputs[idx + 1 :]
-                    fmul = mul(*([*listmul, switch_node.inputs[1]]))
-                    # Copy over stacktrace for elementwise multiplication op
-                    # from previous elementwise multiplication op.
-                    # An error in the multiplication (e.g. errors due to
-                    # inconsistent shapes), will point to the
-                    # multiplication op.
-                    copy_stack_trace(node.outputs, fmul)
-
-                    fct = [switch(switch_node.inputs[0], fmul, 0)]
-                    fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
-
-                    # Copy over stacktrace for switch op from both previous
-                    # elementwise multiplication op and previous switch op,
-                    # because an error in this part can be caused by either
-                    # of the two previous ops.
-                    copy_stack_trace(node.outputs + switch_node.outputs, fct)
-                    return fct
-            except NotScalarConstantError:
-                pass
-    return False
+    for mul_inp_idx, mul_inp in enumerate(node.inputs):
+        if mul_inp.owner and mul_inp.owner.op == switch:
+            switch_node = mul_inp.owner
+            # Look for a zero as the first or second branch of the switch
+            for branch in range(2):
+                zero_switch_input = switch_node.inputs[1 + branch]
+                if not get_unique_constant_value(zero_switch_input) == 0.0:
+                    continue
+
+                switch_cond = switch_node.inputs[0]
+                other_switch_input = switch_node.inputs[1 + (1 - branch)]
+
+                listmul = list(node.inputs)
+                listmul[mul_inp_idx] = other_switch_input
+                fmul = mul(*listmul)
+
+                # Copy over stacktrace for elementwise multiplication op
+                # from previous elementwise multiplication op.
+                # An error in the multiplication (e.g. errors due to
+                # inconsistent shapes), will point to the
+                # multiplication op.
+                copy_stack_trace(node.outputs, fmul)
+
+                if branch == 0:
+                    fct = switch(switch_cond, zero_switch_input, fmul)
+                else:
+                    fct = switch(switch_cond, fmul, zero_switch_input)
+
+                # Tell debug_mode that the output is correct, even if nans disappear
+                fct.tag.values_eq_approx = values_eq_approx_remove_nan
+
+                # Copy over stacktrace for switch op from both previous
+                # elementwise multiplication op and previous switch op,
+                # because an error in this part can be caused by either
+                # of the two previous ops.
+                copy_stack_trace(node.outputs + switch_node.outputs, fct)
+                return [fct]


 @register_canonicalize
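The net effect of the new branch handling, shown on a small graph (a hedged sketch using public APIs, not part of the diff): the zero branch is kept as a variable rather than replaced by a literal `0`, so its dtype still participates in the output type.

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    cond = pt.scalar("cond", dtype="bool")
    x = pt.scalar("x", dtype="float32")
    y = pt.scalar("y", dtype="float32")
    zero64 = pt.constant(np.array(0.0, dtype="float64"), name="zero64")

    out = pt.switch(cond, zero64, x) * y  # float64 before rewriting
    fn = pytensor.function([cond, x, y], out)
    # After canonicalization the graph is roughly
    # switch(cond, zero64, mul(x, y)) -- still float64
    pytensor.dprint(fn)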
@@ -699,62 +677,39 @@ def local_div_switch_sink(fgraph, node):
     See `local_mul_switch_sink` for more details.

     """
-    op = node.op
-    if node.inputs[0].owner and node.inputs[0].owner.op == switch:
-        switch_node = node.inputs[0].owner
-        try:
-            if (
-                get_underlying_scalar_constant_value(
-                    switch_node.inputs[1], only_process_constants=True
-                )
-                == 0.0
-            ):
-                fdiv = op(switch_node.inputs[2], node.inputs[1])
-                # Copy over stacktrace for elementwise division op
-                # from previous elementwise multiplication op.
-                # An error in the division (e.g. errors due to
-                # inconsistent shapes or division by zero),
-                # will point to the new division op.
-                copy_stack_trace(node.outputs, fdiv)
-
-                fct = [switch(switch_node.inputs[0], 0, fdiv)]
-                fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
-
-                # Copy over stacktrace for switch op from both previous
-                # elementwise division op and previous switch op,
-                # because an error in this part can be caused by either
-                # of the two previous ops.
-                copy_stack_trace(node.outputs + switch_node.outputs, fct)
-                return fct
-        except NotScalarConstantError:
-            pass
-        try:
-            if (
-                get_underlying_scalar_constant_value(
-                    switch_node.inputs[2], only_process_constants=True
-                )
-                == 0.0
-            ):
-                fdiv = op(switch_node.inputs[1], node.inputs[1])
-                # Copy over stacktrace for elementwise division op
-                # from previous elementwise multiplication op.
-                # An error in the division (e.g. errors due to
-                # inconsistent shapes or division by zero),
-                # will point to the new division op.
-                copy_stack_trace(node.outputs, fdiv)
-
-                fct = [switch(switch_node.inputs[0], fdiv, 0)]
-                fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
-
-                # Copy over stacktrace for switch op from both previous
-                # elementwise division op and previous switch op,
-                # because an error in this part can be caused by either
-                # of the two previous ops.
-                copy_stack_trace(node.outputs + switch_node.outputs, fct)
-                return fct
-        except NotScalarConstantError:
-            pass
-    return False
+    num, denom = node.inputs
+
+    if num.owner and num.owner.op == switch:
+        switch_node = num.owner
+        # Look for a zero as the first or second branch of the switch
+        for branch in range(2):
+            zero_switch_input = switch_node.inputs[1 + branch]
+            if not get_unique_constant_value(zero_switch_input) == 0.0:
+                continue
+
+            switch_cond = switch_node.inputs[0]
+            other_switch_input = switch_node.inputs[1 + (1 - branch)]
+
+            fdiv = node.op(other_switch_input, denom)
+
+            # Copy over stacktrace for elementwise division op
+            # from previous elementwise multiplication op.
+            # An error in the division (e.g. errors due to
+            # inconsistent shapes or division by zero),
+            # will point to the new division op.
+            copy_stack_trace(node.outputs, fdiv)
+
+            if branch == 0:
+                fct = switch(switch_cond, zero_switch_input, fdiv)
+            else:
+                fct = switch(switch_cond, fdiv, zero_switch_input)
+
+            # Tell debug_mode that the output is correct, even if nans disappear
+            fct.tag.values_eq_approx = values_eq_approx_remove_nan
+
+            # Copy over stacktrace for switch op from both previous
+            # elementwise division op and previous switch op,
+            # because an error in this part can be caused by either
+            # of the two previous ops.
+            copy_stack_trace(node.outputs + switch_node.outputs, fct)
+            return [fct]


 class AlgebraicCanonizer(NodeRewriter):
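Unlike the mul version, the division variant only sinks the switch through the numerator: a zero in the denominator yields inf or nan rather than zero, so that case cannot be rewritten. A sketch of an eligible graph (illustrative only):

    import pytensor.tensor as pt

    cond = pt.scalar("cond", dtype="bool")
    x, y = pt.scalars("x", "y")
    out = pt.switch(cond, 0, x) / y  # rewritten to switch(cond, 0, x / y)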
10 changes: 6 additions & 4 deletions pytensor/tensor/variable.py
@@ -1045,11 +1045,13 @@ def get_unique_constant_value(x: TensorVariable) -> Number | None:
     if isinstance(x, Constant):
         data = x.data

-        if isinstance(data, np.ndarray) and data.ndim > 0:
+        if isinstance(data, np.ndarray) and data.size > 0:
+            if data.size == 1:
+                return data.squeeze()
+
             flat_data = data.ravel()
-            if flat_data.shape[0]:
-                if (flat_data == flat_data[0]).all():
-                    return flat_data[0]
+            if (flat_data == flat_data[0]).all():
+                return flat_data[0]

     return None
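The behavioral change here: 0-d constants (`ndim == 0`) previously fell through to `return None`; checking `data.size` instead handles them. A standalone sketch of the new logic (hypothetical helper name, numpy only):

    import numpy as np

    def unique_value(data):
        # Mirrors the updated ndarray handling in get_unique_constant_value
        if isinstance(data, np.ndarray) and data.size > 0:
            if data.size == 1:
                return data.squeeze()  # now covers 0-d and single-element arrays
            flat_data = data.ravel()
            if (flat_data == flat_data[0]).all():
                return flat_data[0]  # all entries share one value
        return None

    print(unique_value(np.array(2.0)))       # 2.0 (was None before the fix)
    print(unique_value(np.full((3, 3), 7)))  # 7
    print(unique_value(np.array([1, 2])))    # None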
17 changes: 17 additions & 0 deletions tests/scalar/test_basic.py
@@ -36,6 +36,7 @@
     floats,
     int8,
     int32,
+    int64,
     ints,
     invert,
     log,
@@ -44,6 +45,7 @@
     log10,
     mean,
     mul,
+    neg,
     neq,
     rad2deg,
     reciprocal,
@@ -156,6 +158,21 @@ def checker(x, y):
             (literal_value + test_y) * (test_x / test_y),
         )

+    def test_negative_constant(self):
+        # Test that a negative constant is wrapped in parentheses, to avoid
+        # confusing - (unary minus) with -- (decrement)
+        x = int64("x")
+        e = neg(constant(-1.5)) % x
+        comp_op = Composite([x], [e])
+        comp_node = comp_op.make_node(x)
+
+        c_code = comp_node.op.c_code(comp_node, "dummy", ["x", "y"], ["z"], dict(id=0))
+        assert "-1.5" in c_code
+
+        g = FunctionGraph([x], [comp_node.out])
+        fn = make_function(DualLinker().accept(g))
+        assert fn(2) == 1.5
+        assert fn(1) == 0.5
+
     def test_many_outputs(self):
         x, y, z = floats("xyz")
         e0 = x + y + z
34 changes: 16 additions & 18 deletions tests/scan/test_printing.py
@@ -654,24 +654,22 @@ def no_shared_fn(n, x_tm1, M):
     Inner graphs:

     Scan{scan_fn, while_loop=False, inplace=all} [id A]
-     ← Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0)
-        ├─ 0 [id J]
-        ├─ Subtensor{i, j, k} [id K]
-        │  ├─ *2-<Tensor3(float64, shape=(20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
-        │  ├─ ScalarFromTensor [id M]
-        │  │  └─ *0-<Scalar(int64, shape=())> [id N] -> [id C] (inner_in_seqs-0)
-        │  ├─ ScalarFromTensor [id O]
-        │  │  └─ *1-<Scalar(int64, shape=())> [id P] -> [id D] (inner_in_sit_sot-0)
-        │  └─ 0 [id Q]
-        └─ 1 [id R]
-
-    Composite{switch(lt(i0, i1), i2, i0)} [id I]
-     ← Switch [id S] 'o0'
-        ├─ LT [id T]
-        │  ├─ i0 [id U]
-        │  └─ i1 [id V]
-        ├─ i2 [id W]
-        └─ i0 [id U]
+     ← Composite{switch(lt(0, i0), 1, 0)} [id I] (inner_out_sit_sot-0)
+        └─ Subtensor{i, j, k} [id J]
+           ├─ *2-<Tensor3(float64, shape=(20000, 2, 2))> [id K] -> [id H] (inner_in_non_seqs-0)
+           ├─ ScalarFromTensor [id L]
+           │  └─ *0-<Scalar(int64, shape=())> [id M] -> [id C] (inner_in_seqs-0)
+           ├─ ScalarFromTensor [id N]
+           │  └─ *1-<Scalar(int64, shape=())> [id O] -> [id D] (inner_in_sit_sot-0)
+           └─ 0 [id P]
+
+    Composite{switch(lt(0, i0), 1, 0)} [id I]
+     ← Switch [id Q] 'o0'
+        ├─ LT [id R]
+        │  ├─ 0 [id S]
+        │  └─ i0 [id T]
+        ├─ 1 [id U]
+        └─ 0 [id S]
     """

     output_str = debugprint(out, file="str", print_op_info=True)
25 changes: 24 additions & 1 deletion tests/tensor/rewriting/test_math.py
@@ -97,9 +97,11 @@
 from pytensor.tensor.rewriting.math import (
     compute_mul,
     is_1pexp,
+    local_div_switch_sink,
     local_grad_log_erfc_neg,
     local_greedy_distributor,
     local_mul_canonizer,
+    local_mul_switch_sink,
     local_reduce_chain,
     local_sum_prod_of_mul_or_div,
     mul_canonizer,
@@ -2115,7 +2117,6 @@ def test_local_mul_switch_sink(self):
         f = self.function_remove_nan([x], pytensor.gradient.grad(y, x), self.mode)
         assert f(5) == 1, f(5)

-    @pytest.mark.slow
     def test_local_div_switch_sink(self):
         c = dscalar()
         idx = 0

Member Author, on removing @pytest.mark.slow: It was not particularly slow and not slower than the test above.
@@ -2149,6 +2150,28 @@ def test_local_div_switch_sink(self):
             ].size
             idx += 1

+    @pytest.mark.parametrize(
+        "op, rewrite",
+        [(mul, local_mul_switch_sink), (true_div, local_div_switch_sink)],
+    )
+    def test_local_mul_div_switch_sink_cast(self, op, rewrite):
+        """Check that we don't downcast during the rewrite.
+
+        Regression test for: https://github.com/pymc-devs/pytensor/issues/1037
+        """
+        cond = scalar("cond", dtype="bool")
+        # The zero branch upcasts the output, so we can't ignore its dtype
+        zero_branch = constant(np.array(0, dtype="float64"), name="zero_branch")
+        other_branch = scalar("other_branch", dtype="float32")
+        outer_var = scalar("mul_var", dtype="bool")
+
+        out = op(switch(cond, zero_branch, other_branch), outer_var)
+        fgraph = FunctionGraph(outputs=[out], clone=False)
+        [new_out] = rewrite.transform(fgraph, out.owner)
+        assert new_out.type.dtype == out.type.dtype
+
+        expected_out = switch(cond, zero_branch, op(other_branch, outer_var))
+        assert equal_computations([new_out], [expected_out])


 @pytest.mark.skipif(
     config.cxx == "",
Loading