-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Capture] add jit
and vmap
support to eval_jaxpr
#1055
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #1055 +/- ##
===========================================
- Coverage 98.00% 66.30% -31.71%
===========================================
Files 233 28 -205
Lines 39642 2511 -37131
===========================================
- Hits 38853 1665 -37188
- Misses 789 846 +57 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ready to approve except for the testing comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, though probably we also need to hear from Lightning team to see if any concerns regarding the dependecy introduced @josephleekl @maliasadi
…I/pennylane-lightning into capture-pure-callback
Co-authored-by: Mudit Pandey <[email protected]>
…I/pennylane-lightning into capture-pure-callback
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these codecov warnings all false positives?
I also got them on #1067 as well. These tests are actually running, since I got failures in the CI I had to fix. So I don't really know what's causing code cov to be so confused. |
UPDATE: FIXED IT: I was being stupid. So I'm seeing in the CI:
but when I run:
It works fine. If my install was bad, I would expect this to error out. I'm confused. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @albi3ro! Happy to approve
Note that this change isn't necessarily high priority, but I thought I'd just open the PR since I already had the code.
Context:
The lightning devices currently throw an error with
jax.jit
andjax.vmap
and program capture. While there might be some discussion about where thejax.pure_callback
should like long term, for now I'm just adding thejax.pure_callback
call inside the relevant devices. Not many devices will be implementingDevice.eval_jaxpr
at the moment anyway.Description of the Change:
Adds a
jax.pure_callback
layer inDevice.eval_jaxpr
.Benefits:
jax.jit
andjax.vmap
can now be used with capture and lightning devices.Possible Drawbacks:
I don't know if we want to have the
pure_callback
layer live on the device long term.Also, it turns all errors into
jaxlib.xla_extension.XlaRuntimeError
, which a bit scarier to debug.Related Shortcut Stories:
[sc-83589]