Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Capture] add jit and vmap support to eval_jaxpr #1055

Merged
merged 20 commits into from
Feb 25, 2025
Merged

Conversation

albi3ro
Copy link
Contributor

@albi3ro albi3ro commented Jan 27, 2025

Note that this change isn't necessarily high priority, but I thought I'd just open the PR since I already had the code.

Context:

The lightning devices currently throw an error with jax.jit and jax.vmap and program capture. While there might be some discussion about where the jax.pure_callback should like long term, for now I'm just adding the jax.pure_callback call inside the relevant devices. Not many devices will be implementing Device.eval_jaxpr at the moment anyway.

Description of the Change:

Adds a jax.pure_callback layer in Device.eval_jaxpr.

Benefits:

jax.jit and jax.vmap can now be used with capture and lightning devices.

Possible Drawbacks:

I don't know if we want to have the pure_callback layer live on the device long term.

Also, it turns all errors into jaxlib.xla_extension.XlaRuntimeError, which a bit scarier to debug.

Related Shortcut Stories:
[sc-83589]

Copy link

codecov bot commented Jan 28, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 66.30%. Comparing base (7d24015) to head (046155c).
Report is 1 commits behind head on master.

❗ There is a different number of reports uploaded between BASE (7d24015) and HEAD (046155c). Click for more details.

HEAD has 29 uploads less than BASE
Flag BASE (7d24015) HEAD (046155c)
35 6
Additional details and impacted files
@@             Coverage Diff             @@
##           master    #1055       +/-   ##
===========================================
- Coverage   98.00%   66.30%   -31.71%     
===========================================
  Files         233       28      -205     
  Lines       39642     2511    -37131     
===========================================
- Hits        38853     1665    -37188     
- Misses        789      846       +57     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@mudit2812 mudit2812 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ready to approve except for the testing comment.

Copy link
Contributor

@JerryChen97 JerryChen97 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, though probably we also need to hear from Lightning team to see if any concerns regarding the dependecy introduced @josephleekl @maliasadi

Copy link
Contributor

@JerryChen97 JerryChen97 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Copy link
Contributor

@AmintorDusko AmintorDusko left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these codecov warnings all false positives?

@albi3ro
Copy link
Contributor Author

albi3ro commented Feb 25, 2025

Are these codecov warnings all false positives?

I also got them on #1067 as well. These tests are actually running, since I got failures in the CI I had to fix. So I don't really know what's causing code cov to be so confused.

@albi3ro
Copy link
Contributor Author

albi3ro commented Feb 25, 2025

UPDATE: FIXED IT: I was being stupid.

So I'm seeing in the CI:

FAILED tests/test_eval_jaxpr.py::test_vmap_in_axes[0-1] - assert False

but when I run:

in_axis=0
out_axis=1

@qml.qnode(qml.device('lightning.qubit', wires=1))
def circuit(mat):
    qml.QubitUnitary(mat, 0)
    return qml.expval(qml.Z(0)), qml.state()

mats = jax.numpy.stack(
    [qml.X.compute_matrix(), qml.Y.compute_matrix(), qml.Z.compute_matrix()], axis=in_axis
)
expval, state = jax.vmap(circuit, in_axes=in_axis, out_axes=(0, out_axis))(mats)

assert expval.shape == (3,)
assert qml.math.allclose(expval, jax.numpy.array([-1, -1, 1]))  # flip, flip, no flip

assert state.shape == (3, 4) if out_axis == 0 else (4, 3)

It works fine. If my install was bad, I would expect this to error out.

I'm confused.

@albi3ro albi3ro requested a review from maliasadi February 25, 2025 16:13
Copy link
Member

@maliasadi maliasadi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @albi3ro! Happy to approve

@albi3ro albi3ro merged commit dcebe62 into master Feb 25, 2025
69 of 70 checks passed
@albi3ro albi3ro deleted the capture-pure-callback branch February 25, 2025 18:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants