Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Blockwise and RandomVariable in Numba with repeated arguments #1222

Merged

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Feb 19, 2025

To implement Blockwise and RandomVariables in numba we use an OpWithCoreShape OpFromGraph subclass, that extends the original node with a core_shape input. This approach would fail whenever the original node had repeated inputs, as OpFromGraph raises by default when that's the case.

For now I think simply allowing OpFromGraph to ignore those disconnected inputs should work fine. We can revisit later and use a node.clone_with_new_inputs or whatever is called to mask the fact the inputs are identical. I have a 90% confidence this won't be needed so I'm trying the simple approach until reality proves it wrong.


📚 Documentation preview 📚: https://pytensor--1222.org.readthedocs.build/en/1222/

x = tensor3("x")
x_test = np.full((1, 1, 1), 2.0, dtype=x.type.dtype)
out = x @ x
fn, _ = compare_numba_and_py([x], [out], [x_test], eval_obj_mode=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't get it, where's the repeated arg?

Copy link
Member Author

@ricardoV94 ricardoV94 Feb 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

matmul with x twice

Copy link

codecov bot commented Feb 19, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.03%. Comparing base (e25e8a2) to head (2e036b2).
Report is 4 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1222   +/-   ##
=======================================
  Coverage   82.03%   82.03%           
=======================================
  Files         188      188           
  Lines       48535    48537    +2     
  Branches     8676     8676           
=======================================
+ Hits        39814    39816    +2     
  Misses       6558     6558           
  Partials     2163     2163           
Files with missing lines Coverage Δ
pytensor/tensor/blockwise.py 85.64% <100.00%> (+0.13%) ⬆️

@ricardoV94 ricardoV94 force-pushed the numba_rvs_with_repeated_inputs branch from 680086c to 2e036b2 Compare February 19, 2025 12:13
@ricardoV94 ricardoV94 merged commit 3cdcfde into pymc-devs:main Feb 19, 2025
67 checks passed
@ricardoV94 ricardoV94 deleted the numba_rvs_with_repeated_inputs branch February 19, 2025 12:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants