Skip to content

Commit 3e1521b

Browse files
[Capture] Add a QmlPrimitive class to differentiate between different types of primitives (#6847)
This PR adds a `QmlPrimitive` subclass of `jax.core.Primitive`. This class contains a `prim_type` property set using a new `PrimitiveType` enum. `PrimitiveType`s currently available are "default", "operator", "measurement", "transform", and "higher_order". This can be made more or less fine grained as needed, but should be enough to differentiate between different types of primitives for now. Additionally, this PR: * updates `NonInterpPrimitive` to be a subclass of `QmlPrimitive` * updates all existing PennyLane primitives to be either `QmlPrimitive` or `NonInterpPrimitive`. See [this comment](#6851 (comment)) to see the logic used to determine which `Primitive` subclass is used for each primitive. * updates `PlxprInterpreter.eval` and `CancelInversesInterpreter.eval` to use this `prim_type` property. [sc-82420] --------- Co-authored-by: Pietropaolo Frisoni <[email protected]>
1 parent fdf34ec commit 3e1521b

16 files changed

+209
-73
lines changed

doc/releases/changelog-dev.md

+6
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@
9090

9191
<h3>Internal changes ⚙️</h3>
9292

93+
* Added a `QmlPrimitive` class that inherits `jax.core.Primitive` to a new `qml.capture.custom_primitives` module.
94+
This class contains a `prim_type` property so that we can differentiate between different sets of PennyLane primitives.
95+
Consequently, `QmlPrimitive` is now used to define all PennyLane primitives.
96+
[(#6847)](https://github.com/PennyLaneAI/pennylane/pull/6847)
97+
9398
<h3>Documentation 📝</h3>
9499

95100
* The docstrings for `qml.unary_mapping`, `qml.binary_mapping`, `qml.christiansen_mapping`,
@@ -115,4 +120,5 @@ Diksha Dhawan,
115120
Pietropaolo Frisoni,
116121
Marcus Gisslén,
117122
Christina Lee,
123+
Mudit Pandey,
118124
Andrija Paurevic

pennylane/capture/base_interpreter.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@
2525

2626
from .flatfn import FlatFn
2727
from .primitives import (
28-
AbstractMeasurement,
29-
AbstractOperator,
3028
adjoint_transform_prim,
3129
cond_prim,
3230
ctrl_transform_prim,
@@ -311,20 +309,21 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list:
311309
self._env[constvar] = const
312310

313311
for eqn in jaxpr.eqns:
312+
primitive = eqn.primitive
313+
custom_handler = self._primitive_registrations.get(primitive, None)
314314

315-
custom_handler = self._primitive_registrations.get(eqn.primitive, None)
316315
if custom_handler:
317316
invals = [self.read(invar) for invar in eqn.invars]
318317
outvals = custom_handler(self, *invals, **eqn.params)
319-
elif isinstance(eqn.outvars[0].aval, AbstractOperator):
318+
elif getattr(primitive, "prim_type", "") == "operator":
320319
outvals = self.interpret_operation_eqn(eqn)
321-
elif isinstance(eqn.outvars[0].aval, AbstractMeasurement):
320+
elif getattr(primitive, "prim_type", "") == "measurement":
322321
outvals = self.interpret_measurement_eqn(eqn)
323322
else:
324323
invals = [self.read(invar) for invar in eqn.invars]
325-
outvals = eqn.primitive.bind(*invals, **eqn.params)
324+
outvals = primitive.bind(*invals, **eqn.params)
326325

327-
if not eqn.primitive.multiple_results:
326+
if not primitive.multiple_results:
328327
outvals = [outvals]
329328
for outvar, outval in zip(eqn.outvars, outvals, strict=True):
330329
self._env[outvar] = outval

pennylane/capture/capture_diff.py

+11-30
Original file line numberDiff line numberDiff line change
@@ -24,34 +24,6 @@
2424
has_jax = False
2525

2626

27-
@lru_cache
28-
def create_non_interpreted_prim():
29-
"""Create a primitive type ``NonInterpPrimitive``, which binds to JAX's JVPTrace
30-
and BatchTrace objects like a standard Python function and otherwise behaves like jax.core.Primitive.
31-
"""
32-
33-
if not has_jax: # pragma: no cover
34-
return None
35-
36-
# pylint: disable=too-few-public-methods
37-
class NonInterpPrimitive(jax.core.Primitive):
38-
"""A subclass to JAX's Primitive that works like a Python function
39-
when evaluating JVPTracers and BatchTracers."""
40-
41-
def bind_with_trace(self, trace, args, params):
42-
"""Bind the ``NonInterpPrimitive`` with a trace.
43-
44-
If the trace is a ``JVPTrace``or a ``BatchTrace``, binding falls back to a standard Python function call.
45-
Otherwise, the bind call of JAX's standard Primitive is used."""
46-
if isinstance(
47-
trace, (jax.interpreters.ad.JVPTrace, jax.interpreters.batching.BatchTrace)
48-
):
49-
return self.impl(*args, **params)
50-
return super().bind_with_trace(trace, args, params)
51-
52-
return NonInterpPrimitive
53-
54-
5527
@lru_cache
5628
def _get_grad_prim():
5729
"""Create a primitive for gradient computations.
@@ -60,8 +32,11 @@ def _get_grad_prim():
6032
if not has_jax: # pragma: no cover
6133
return None
6234

63-
grad_prim = create_non_interpreted_prim()("grad")
35+
from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel
36+
37+
grad_prim = NonInterpPrimitive("grad")
6438
grad_prim.multiple_results = True # pylint: disable=attribute-defined-outside-init
39+
grad_prim.prim_type = "higher_order"
6540

6641
# pylint: disable=too-many-arguments
6742
@grad_prim.def_impl
@@ -91,8 +66,14 @@ def _get_jacobian_prim():
9166
"""Create a primitive for Jacobian computations.
9267
This primitive is used when capturing ``qml.jacobian``.
9368
"""
94-
jacobian_prim = create_non_interpreted_prim()("jacobian")
69+
if not has_jax: # pragma: no cover
70+
return None
71+
72+
from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel
73+
74+
jacobian_prim = NonInterpPrimitive("jacobian")
9575
jacobian_prim.multiple_results = True # pylint: disable=attribute-defined-outside-init
76+
jacobian_prim.prim_type = "higher_order"
9677

9778
# pylint: disable=too-many-arguments
9879
@jacobian_prim.def_impl

pennylane/capture/capture_measurements.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,10 @@ def create_measurement_obs_primitive(
128128
if not has_jax:
129129
return None
130130

131-
primitive = jax.core.Primitive(name + "_obs")
131+
from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel
132+
133+
primitive = NonInterpPrimitive(name + "_obs")
134+
primitive.prim_type = "measurement"
132135

133136
@primitive.def_impl
134137
def _(obs, **kwargs):
@@ -165,7 +168,10 @@ def create_measurement_mcm_primitive(
165168
if not has_jax:
166169
return None
167170

168-
primitive = jax.core.Primitive(name + "_mcm")
171+
from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel
172+
173+
primitive = NonInterpPrimitive(name + "_mcm")
174+
primitive.prim_type = "measurement"
169175

170176
@primitive.def_impl
171177
def _(*mcms, single_mcm=True, **kwargs):
@@ -200,7 +206,10 @@ def create_measurement_wires_primitive(
200206
if not has_jax:
201207
return None
202208

203-
primitive = jax.core.Primitive(name + "_wires")
209+
from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel
210+
211+
primitive = NonInterpPrimitive(name + "_wires")
212+
primitive.prim_type = "measurement"
204213

205214
@primitive.def_impl
206215
def _(*args, has_eigvals=False, **kwargs):

pennylane/capture/capture_operators.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020

2121
import pennylane as qml
2222

23-
from .capture_diff import create_non_interpreted_prim
24-
2523
has_jax = True
2624
try:
2725
import jax
@@ -103,7 +101,10 @@ def create_operator_primitive(
103101
if not has_jax:
104102
return None
105103

106-
primitive = create_non_interpreted_prim()(operator_type.__name__)
104+
from .custom_primitives import NonInterpPrimitive # pylint: disable=import-outside-toplevel
105+
106+
primitive = NonInterpPrimitive(operator_type.__name__)
107+
primitive.prim_type = "operator"
107108

108109
@primitive.def_impl
109110
def _(*args, **kwargs):
+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2024 Xanadu Quantum Technologies Inc.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
This submodule offers custom primitives for the PennyLane capture module.
16+
"""
17+
from enum import Enum
18+
from typing import Union
19+
20+
import jax
21+
22+
23+
class PrimitiveType(Enum):
24+
"""Enum to define valid set of primitive classes"""
25+
26+
DEFAULT = "default"
27+
OPERATOR = "operator"
28+
MEASUREMENT = "measurement"
29+
HIGHER_ORDER = "higher_order"
30+
TRANSFORM = "transform"
31+
32+
33+
# pylint: disable=too-few-public-methods,abstract-method
34+
class QmlPrimitive(jax.core.Primitive):
35+
"""A subclass for JAX's Primitive that differentiates between different
36+
classes of primitives."""
37+
38+
_prim_type: PrimitiveType = PrimitiveType.DEFAULT
39+
40+
@property
41+
def prim_type(self):
42+
"""Value of Enum representing the primitive type to differentiate between various
43+
sets of PennyLane primitives."""
44+
return self._prim_type.value
45+
46+
@prim_type.setter
47+
def prim_type(self, value: Union[str, PrimitiveType]):
48+
"""Setter for QmlPrimitive.prim_type."""
49+
self._prim_type = PrimitiveType(value)
50+
51+
52+
# pylint: disable=too-few-public-methods,abstract-method
53+
class NonInterpPrimitive(QmlPrimitive):
54+
"""A subclass to JAX's Primitive that works like a Python function
55+
when evaluating JVPTracers and BatchTracers."""
56+
57+
def bind_with_trace(self, trace, args, params):
58+
"""Bind the ``NonInterpPrimitive`` with a trace.
59+
60+
If the trace is a ``JVPTrace``or a ``BatchTrace``, binding falls back to a standard Python function call.
61+
Otherwise, the bind call of JAX's standard Primitive is used."""
62+
if isinstance(trace, (jax.interpreters.ad.JVPTrace, jax.interpreters.batching.BatchTrace)):
63+
return self.impl(*args, **params)
64+
return super().bind_with_trace(trace, args, params)

pennylane/compiler/qjit_api.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from collections.abc import Callable
1818

1919
import pennylane as qml
20-
from pennylane.capture.capture_diff import create_non_interpreted_prim
2120
from pennylane.capture.flatfn import FlatFn
2221

2322
from .compiler import (
@@ -405,10 +404,14 @@ def _decorator(body_fn: Callable) -> Callable:
405404
def _get_while_loop_qfunc_prim():
406405
"""Get the while_loop primitive for quantum functions."""
407406

408-
import jax # pylint: disable=import-outside-toplevel
407+
# pylint: disable=import-outside-toplevel
408+
import jax
409409

410-
while_loop_prim = create_non_interpreted_prim()("while_loop")
410+
from pennylane.capture.custom_primitives import NonInterpPrimitive
411+
412+
while_loop_prim = NonInterpPrimitive("while_loop")
411413
while_loop_prim.multiple_results = True
414+
while_loop_prim.prim_type = "higher_order"
412415

413416
@while_loop_prim.def_impl
414417
def _(*args, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice):
@@ -626,10 +629,14 @@ def _decorator(body_fn):
626629
def _get_for_loop_qfunc_prim():
627630
"""Get the loop_for primitive for quantum functions."""
628631

629-
import jax # pylint: disable=import-outside-toplevel
632+
# pylint: disable=import-outside-toplevel
633+
import jax
634+
635+
from pennylane.capture.custom_primitives import NonInterpPrimitive
630636

631-
for_loop_prim = create_non_interpreted_prim()("for_loop")
637+
for_loop_prim = NonInterpPrimitive("for_loop")
632638
for_loop_prim.multiple_results = True
639+
for_loop_prim.prim_type = "higher_order"
633640

634641
# pylint: disable=too-many-arguments
635642
@for_loop_prim.def_impl

pennylane/measurements/mid_measure.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,12 @@ def _create_mid_measure_primitive():
243243
measurement.
244244
245245
"""
246-
import jax # pylint: disable=import-outside-toplevel
246+
# pylint: disable=import-outside-toplevel
247+
import jax
247248

248-
mid_measure_p = jax.core.Primitive("measure")
249+
from pennylane.capture.custom_primitives import NonInterpPrimitive
250+
251+
mid_measure_p = NonInterpPrimitive("measure")
249252

250253
@mid_measure_p.def_impl
251254
def _(wires, reset=False, postselect=None):

pennylane/ops/op_math/adjoint.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from typing import Callable, overload
1919

2020
import pennylane as qml
21-
from pennylane.capture.capture_diff import create_non_interpreted_prim
2221
from pennylane.compiler import compiler
2322
from pennylane.math import conj, moveaxis, transpose
2423
from pennylane.operation import Observable, Operation, Operator
@@ -190,10 +189,14 @@ def create_adjoint_op(fn, lazy):
190189
def _get_adjoint_qfunc_prim():
191190
"""See capture/explanations.md : Higher Order primitives for more information on this code."""
192191
# if capture is enabled, jax should be installed
193-
import jax # pylint: disable=import-outside-toplevel
192+
# pylint: disable=import-outside-toplevel
193+
import jax
194+
195+
from pennylane.capture.custom_primitives import NonInterpPrimitive
194196

195-
adjoint_prim = create_non_interpreted_prim()("adjoint_transform")
197+
adjoint_prim = NonInterpPrimitive("adjoint_transform")
196198
adjoint_prim.multiple_results = True
199+
adjoint_prim.prim_type = "higher_order"
197200

198201
@adjoint_prim.def_impl
199202
def _(*args, jaxpr, lazy, n_consts):

pennylane/ops/op_math/condition.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
import pennylane as qml
2222
from pennylane import QueuingManager
23-
from pennylane.capture.capture_diff import create_non_interpreted_prim
2423
from pennylane.capture.flatfn import FlatFn
2524
from pennylane.compiler import compiler
2625
from pennylane.measurements import MeasurementValue
@@ -681,10 +680,14 @@ def _get_mcm_predicates(conditions: tuple[MeasurementValue]) -> list[Measurement
681680
def _get_cond_qfunc_prim():
682681
"""Get the cond primitive for quantum functions."""
683682

684-
import jax # pylint: disable=import-outside-toplevel
683+
# pylint: disable=import-outside-toplevel
684+
import jax
685685

686-
cond_prim = create_non_interpreted_prim()("cond")
686+
from pennylane.capture.custom_primitives import NonInterpPrimitive
687+
688+
cond_prim = NonInterpPrimitive("cond")
687689
cond_prim.multiple_results = True
690+
cond_prim.prim_type = "higher_order"
688691

689692
@cond_prim.def_impl
690693
def _(*all_args, jaxpr_branches, consts_slices, args_slice):

pennylane/ops/op_math/controlled.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import pennylane as qml
2929
from pennylane import math as qmlmath
3030
from pennylane import operation
31-
from pennylane.capture.capture_diff import create_non_interpreted_prim
3231
from pennylane.compiler import compiler
3332
from pennylane.operation import Operator
3433
from pennylane.wires import Wires, WiresLike
@@ -233,10 +232,15 @@ def wrapper(*args, **kwargs):
233232
def _get_ctrl_qfunc_prim():
234233
"""See capture/explanations.md : Higher Order primitives for more information on this code."""
235234
# if capture is enabled, jax should be installed
236-
import jax # pylint: disable=import-outside-toplevel
237235

238-
ctrl_prim = create_non_interpreted_prim()("ctrl_transform")
236+
# pylint: disable=import-outside-toplevel
237+
import jax
238+
239+
from pennylane.capture.custom_primitives import NonInterpPrimitive
240+
241+
ctrl_prim = NonInterpPrimitive("ctrl_transform")
239242
ctrl_prim.multiple_results = True
243+
ctrl_prim.prim_type = "higher_order"
240244

241245
@ctrl_prim.def_impl
242246
def _(*args, n_control, jaxpr, control_values, work_wires, n_consts):

pennylane/transforms/core/transform_dispatcher.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -540,12 +540,13 @@ def final_transform(self):
540540
def _create_transform_primitive(name):
541541
try:
542542
# pylint: disable=import-outside-toplevel
543-
import jax
543+
from pennylane.capture.custom_primitives import NonInterpPrimitive
544544
except ImportError:
545545
return None
546546

547-
transform_prim = jax.core.Primitive(name + "_transform")
547+
transform_prim = NonInterpPrimitive(name + "_transform")
548548
transform_prim.multiple_results = True
549+
transform_prim.prim_type = "transform"
549550

550551
@transform_prim.def_impl
551552
def _(

0 commit comments

Comments
 (0)