Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some explanations on NonInterpPrimitive class #6851

Merged
merged 6 commits into from
Jan 21, 2025
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 103 additions & 3 deletions pennylane/capture/explanations.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ class MyClass(metaclass=MyMetaClass):
self.kwargs = kwargs
```

Creating a new type <class '__main__.MyClass'> with ('MyClass', (), {'__module__': '__main__', '__qualname__': 'MyClass', '__init__': <function MyClass.__init__ at 0x11c59cae0>}), {}.
Creating a new type <class '__main__.MyClass'> with ('MyClass', (), {'__module__': '__main__', '__qualname__': 'MyClass', '__init__': <function MyClass.__init__ at 0x11c59cae0>}), {}.


And that we have set a class property `a`
Expand All @@ -272,7 +272,7 @@ But can we actually create instances of these classes?
```python
>> obj = MyClass(0.1, a=2)
>>> obj
creating an instance of type <class '__main__.MyClass'> with (0.1,), {'a': 2}.
creating an instance of type <class '__main__.MyClass'> with (0.1,), {'a': 2}.
now creating an instance in __init__
<__main__.MyClass at 0x11c5a2810>
```
Expand All @@ -294,7 +294,7 @@ class MyClass2(metaclass=MetaClass2):
self.args = args
```

You can see now that instead of actually getting an instance of `MyClass2`, we just get `2.0`.
You can see now that instead of actually getting an instance of `MyClass2`, we just get `2.0`.

Using a metaclass, we can hijack what happens when a type is called.

Expand Down Expand Up @@ -425,3 +425,103 @@ Now in our jaxpr, we can see thet `PrimitiveClass2` returns something of type `A
>>> jax.make_jaxpr(PrimitiveClass2)(0.1)
{ lambda ; a:f32[]. let b:AbstractPrimitiveClass() = PrimitiveClass2 a in (b,) }
```

# Non-interpreted primitives

Some of the primitives in the capture module have a somewhat non-standard requirement for the
behaviour under differentiation or batching: they should ignore that an input is a differentiation
or batching tracer and just execute the standard implementation on them.

We will look at an example to make the necessity for such a non-interpreted primitive clear.

Consider a finite-difference differentiation routine together with some test function `fun`.

```python
def finite_diff_impl(x, fun, delta):
"""Finite difference differentiation routine. Only supports differentiating
a function `fun` with a single scalar argument, for simplicity."""

out_plus = fun(x + delta)
out_minus = fun(x - delta)
return tuple((out_p - out_m) / (2 * delta) for out_p, out_m in zip(out_plus, out_minus))

def fun(x):
return (x**2, 4 * x - 3, x**23)
```

Now suppose we want to turn this into a primitive. We could just promote it to a standard
`jax.core.Primitive` as

```python
import jax

fd_prim = jax.core.Primitive("finite_diff")
fd_prim.multiple_results = True
fd_prim.def_impl(finite_diff_impl)

def finite_diff(x, fun, delta=1e-5):
return fd_prim.bind(x, fun, delta)
```

This allows us to use the forward pass as usual (to compute the first-order derivative):

```pycon
>>> finite_diff(1., fun, delta=1e-6)
(2.000000000002, 3.999999999892978, 23.000000001216492)
```

Now if we want to make this primitive differentiable (with automatic
differentiation/backprop, not by using a higher order finite difference scheme),
we need to specify a JVP rule. (Note that there are multiple rather simple fixes for this example
that we could use to implement a finite difference scheme that is readily differentiable. This is
somewhat besides the point, because we did not identify a possibility to use any of those
alternatives in the PennyLane code).

However, the finite difference rule is just a standard
algebraic function making use of calls to `fun` and some elementary operations, so ideally
we would like to just use the chain rule as it is known to the AD engine. A JVP rule would
then just manually re-implement this chain rule, which we'd rather not do.

Instead, we define a non-interpreted type of primitives, and create such a primitive
for our finite difference method. We also create the usual method that binds the
primitive to inputs.

```python
class NonInterpPrimitive(jax.core.Primitive):
"""A subclass to JAX's Primitive that works like a Python function
when evaluating JVPTracers."""

def bind_with_trace(self, trace, args, params):
"""Bind the ``NonInterpPrimitive`` with a trace.
If the trace is a ``JVPTrace``, it falls back to a standard Python function call.
Otherwise, the bind call of JAX's standard Primitive is used."""
if isinstance(trace, jax.interpreters.ad.JVPTrace):
return self.impl(*args, **params)
return super().bind_with_trace(trace, args, params)

fd_prim_2 = NonInterpPrimitive("finite_diff_2")
fd_prim_2.multiple_results = True
fd_prim_2.def_impl(finite_diff_impl) # This also defines the behaviour with a JVP tracer

def finite_diff_2(x, fun, delta=1e-5):
return fd_prim_2.bind(x, fun, delta)
```

Now we can use the primitive in a differentiable workflow, without defining a JVP rule
that just repeats the chain rule:

```pycon
>>> # Forward execution of finite_diff_2 (-> first-order derivative)
>>> finite_diff_2(fun, delta=1e-6)(1.)
(2.000000000002, 3.999999999892978, 23.000000001216492)
>>> # Differentiation of finite_diff_2 (-> second-order derivative)
>>> jax.jacobian(finite_diff_2)(1., fun, delta=1e-6)
(Array(1.9375, dtype=float32, weak_type=True), Array(0., dtype=float32, weak_type=True), Array(498., dtype=float32, weak_type=True))
```

In addition to the differentiation primitives for `qml.jacobian` and `qml.grad`, quantum operators
have non-interpreted primitives as well. This is because their differentiation is performed
by the surrounding QNode primitive rather than through the standard chain rule that acts
"locally" (in the circuit). In short, we only want gates to store their tracers (which will help
determining differentiability of gate arguments, for example), but not to do anything with them.

Loading