
Rewrite reshapes that only expand or squeeze dims #1200

Merged 6 commits into pymc-devs:main on Feb 17, 2025

Conversation

ricardoV94 (Member) commented Feb 10, 2025

We canonicalize so that Reshape is only used for the behavior that is unique to it (mixing dimensions). Expand_dims- and squeeze-like behavior is canonicalized into the respective forms of DimShuffle. This lets other rewrites that know how to reason about DimShuffle (but not about Reshape, which is too flexible) apply, as in the following graph:
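As a concrete sketch of the equivalence this canonicalization relies on (shown here in NumPy for illustration, not the PyTensor rewrite itself): a reshape that only inserts or drops length-1 axes is just an expand_dims or squeeze.

```python
import numpy as np

x = np.arange(9.0)

# A reshape that only inserts a length-1 axis is an expand_dims
# (in PyTensor terms, a DimShuffle with an "x" in its new pattern).
assert np.array_equal(x.reshape((1, 9)), np.expand_dims(x, axis=0))

# Conversely, a reshape that only drops length-1 axes is a squeeze.
y = x.reshape((1, 9, 1))
assert np.array_equal(y.reshape((9,)), np.squeeze(y, axis=(0, 2)))
```

Only reshapes that actually mix dimensions (e.g. `(2, 6)` to `(3, 4)`) cannot be expressed this way and stay as Reshape.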

import pytensor
import pytensor.tensor as pt

x = pt.vector("x", shape=(9,))
out = pt.repeat(x[None], 12, axis=0)
pytensor.function([x], out).dprint(print_type=True)   

Which before this PR generated this graph:

# Reshape{2} [id A] <Matrix(float64, shape=(12, 9))> 1
#  ├─ Alloc [id B] <Tensor3(float64, shape=(1, 12, 9))> 0
#  │  ├─ x [id C] <Vector(float64, shape=(9,))>
#  │  ├─ 1 [id D] <Scalar(int64, shape=())>
#  │  ├─ 12 [id E] <Scalar(int64, shape=())>
#  │  └─ 9 [id F] <Scalar(int64, shape=())>
#  └─ [12  9] [id G] <Vector(int64, shape=(2,))>

And now simplifies to:

# Alloc [id A] <Matrix(float64, shape=(12, 9))> 0
#  ├─ x [id B] <Vector(float64, shape=(9,))>
#  ├─ 12 [id C] <Scalar(int64, shape=())>
#  └─ 9 [id D] <Scalar(int64, shape=())>
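A quick NumPy check (illustrative only, not the PyTensor graph) that the single-Alloc graph computes the same values as the original repeat:

```python
import numpy as np

x = np.arange(9.0)

# What the user wrote: expand to (1, 9), then repeat along axis 0.
before = np.repeat(x[None], 12, axis=0)

# What the rewritten graph's single Alloc produces: x broadcast to (12, 9).
after = np.broadcast_to(x, (12, 9))

assert before.shape == after.shape == (12, 9)
assert np.array_equal(before, after)
```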

Or a naive broadcast + elemwise comparison:

import pytensor
import pytensor.tensor as pt

x = pt.vector("x", shape=(3,))
y = pt.vector("y", shape=(9,))
out = x[None, :].repeat(9, axis=0) <= y[:, None].repeat(3, axis=1)
pytensor.function([x, y], out).dprint(print_type=True)   

Which used to generate this graph:

Le [id A] <Matrix(bool, shape=(9, 3))> 5
 ├─ Reshape{2} [id B] <Matrix(float64, shape=(9, 3))> 4
 │  ├─ Alloc [id C] <Tensor3(float64, shape=(1, 9, 3))> 3
 │  │  ├─ x [id D] <Vector(float64, shape=(3,))>
 │  │  ├─ 1 [id E] <Scalar(int64, shape=())>
 │  │  ├─ 9 [id F] <Scalar(int64, shape=())>
 │  │  └─ 3 [id G] <Scalar(int64, shape=())>
 │  └─ [9 3] [id H] <Vector(int64, shape=(2,))>
 └─ Reshape{2} [id I] <Matrix(float64, shape=(9, 3))> 2
    ├─ Alloc [id J] <Tensor3(float64, shape=(9, 1, 3))> 1
    │  ├─ ExpandDims{axes=[1, 2]} [id K] <Tensor3(float64, shape=(9, 1, 1))> 0
    │  │  └─ y [id L] <Vector(float64, shape=(9,))>
    │  ├─ 9 [id F] <Scalar(int64, shape=())>
    │  ├─ 1 [id E] <Scalar(int64, shape=())>
    │  └─ 3 [id G] <Scalar(int64, shape=())>
    └─ [9 3] [id H] <Vector(int64, shape=(2,))>

And now generates:

# Le [id A] <Matrix(bool, shape=(9, 3))> 2
#  ├─ ExpandDims{axis=0} [id B] <Matrix(float64, shape=(1, 3))> 1
#  │  └─ x [id C] <Vector(float64, shape=(3,))>
#  └─ ExpandDims{axis=1} [id D] <Matrix(float64, shape=(9, 1))> 0
#     └─ y [id E] <Vector(float64, shape=(9,))>

Which is great because it avoids materializing the full broadcasted inputs!
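The same point in NumPy terms (a sketch, not the PyTensor graph): the repeats materialize full (9, 3) copies of both inputs, while the rewritten graph only inserts length-1 axes and lets Elemwise broadcasting do the comparison directly.

```python
import numpy as np

x = np.arange(3.0)
y = np.arange(9.0)

# Naive version: allocates two full (9, 3) arrays before comparing.
naive = np.repeat(x[None, :], 9, axis=0) <= np.repeat(y[:, None], 3, axis=1)

# Rewritten version: only length-1 axes are inserted; broadcasting
# happens inside the comparison, so no intermediate copies are made.
fused = x[None, :] <= y[:, None]

assert fused.shape == (9, 3)
assert np.array_equal(naive, fused)
```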

This example is not contrived; it showed up in this PyMC model: https://gist.github.com/ricardoV94/f986686ce86511b293c5dd6be374e51d

Also fixes a bug in local_useless_reshape that may be behind some of the failures @tanish1729 detected.
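The thread does not show the bug itself; as a hypothetical illustration of the kind of pattern local_useless_reshape targets (sketched in NumPy), a reshape to a tensor's own shape is a no-op and can be removed from the graph:

```python
import numpy as np

x = np.arange(12.0).reshape(3, 4)

# Reshaping an array to its own shape changes nothing; in PyTensor,
# local_useless_reshape removes such Reshape nodes entirely.
same = x.reshape(x.shape)
assert same.shape == (3, 4)
assert np.array_equal(same, x)
```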

Closes #845
Closes #1123
Related to #883


📚 Documentation preview 📚: https://pytensor--1200.org.readthedocs.build/en/1200/

ricardoV94 added the bug (Something isn't working) label on Feb 10, 2025
ricardoV94 force-pushed the reshape_rewrite branch 2 times, most recently from 3e23871 to 60340c7 on February 11, 2025, 15:21

codecov bot commented Feb 11, 2025

Codecov Report

Attention: Patch coverage is 97.67442% with 3 lines in your changes missing coverage. Please review.

Project coverage is 82.27%. Comparing base (b5a64c7) to head (516f5b9).
Report is 6 commits behind head on main.

Files with missing lines              Patch %   Lines
pytensor/tensor/rewriting/shape.py    97.22%    1 Missing and 2 partials ⚠️

❌ Your patch status has failed because the patch coverage (97.67%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1200      +/-   ##
==========================================
- Coverage   82.28%   82.27%   -0.02%     
==========================================
  Files         186      186              
  Lines       47987    48018      +31     
  Branches     8629     8634       +5     
==========================================
+ Hits        39486    39505      +19     
- Misses       6347     6362      +15     
+ Partials     2154     2151       -3     
Files with missing lines              Coverage Δ
pytensor/graph/rewriting/basic.py     69.62% <ø> (-0.70%) ⬇️
pytensor/tensor/elemwise.py           89.01% <100.00%> (+0.04%) ⬆️
pytensor/tensor/shape.py              90.14% <100.00%> (-0.34%) ⬇️
pytensor/tensor/slinalg.py            93.52% <100.00%> (ø)
pytensor/tensor/rewriting/shape.py    82.71% <97.22%> (+0.68%) ⬆️

jessegrabowski (Member) left a comment

Left some questions/nitpicks -- I think it could be merged as-is but I want answers!

ricardoV94 (Member, Author) commented:

@jessegrabowski I addressed your comments, let me know if it makes you happy :)

jessegrabowski (Member) left a comment

Your code always makes me happy, I just hope my reviews always make you happy too : )

ricardoV94 merged commit 65b96c1 into pymc-devs:main on Feb 17, 2025 (63 of 64 checks passed)
Successfully merging this pull request may close these issues.

Rewrite away reshape that drop dims
Rewrite to remove useless Reshape