diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py
index e86411dd9c..9462504e78 100644
--- a/pytensor/tensor/rewriting/shape.py
+++ b/pytensor/tensor/rewriting/shape.py
@@ -966,16 +966,15 @@ def local_reshape_to_dimshuffle(fgraph, node):
     inp, output_shape = node.inputs
     [output] = node.outputs
 
-    # Remove any broadcastable dimensions from the input
-    squeeze_axes = [i for i, bcast in enumerate(inp.type.broadcastable) if bcast]
-
     # Trivial case, all dimensions of input/output are known to be broadcastable:
     # there's nothing to reshape
     if all(inp.type.broadcastable) or all(output.type.broadcastable):
+        squeeze_axes = tuple(range(inp.type.ndim))
         new_output_shape = []
         expand_axes = tuple(range(output.type.ndim))
     else:
+        squeeze_axes = [i for i, bcast in enumerate(inp.type.broadcastable) if bcast]
         unpacked_shape = _unpack_shape_vector(output_shape)
         new_output_shape = []
         expand_axes = []
diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py
index 27678bd630..43df9ffd23 100644
--- a/tests/tensor/rewriting/test_shape.py
+++ b/tests/tensor/rewriting/test_shape.py
@@ -445,6 +445,15 @@ def test_squeeze_of_alloc(self):
         new_out = rewrite_graph(out, include=("canonicalize", "ShapeOpt"))
         assert equal_computations([new_out], [pt.alloc(x, 12, 9)], strict_dtype=False)
 
+    def test_reshape_implies_size_1_input(self):
+        x = pt.matrix("x", shape=(None, None))
+        out = pt.reshape(x, (1, 1, 1))
+
+        new_out = rewrite_graph(out, include=("canonicalize",))
+        assert equal_computations(
+            [new_out], [x.dimshuffle("x", "x", "x")], strict_dtype=False
+        )
+
 
 def test_expand_dims_squeeze_reshape_fusion():
     x = pt.tensor("x", shape=(1, 9))