From 54ecc23af6562f029bc60998f08c6c399ddf9e99 Mon Sep 17 00:00:00 2001 From: dwierichs Date: Fri, 17 Jan 2025 10:11:35 +0100 Subject: [PATCH 1/4] add explanations --- pennylane/capture/explanations.md | 133 +++++++++++++++++++++++++++++- 1 file changed, 129 insertions(+), 4 deletions(-) diff --git a/pennylane/capture/explanations.md b/pennylane/capture/explanations.md index 84feef9786f..1d4b5066bf5 100644 --- a/pennylane/capture/explanations.md +++ b/pennylane/capture/explanations.md @@ -1,4 +1,4 @@ -This documentation explains the principles behind `qml.capture.CaptureMeta` and higher order primitives. +[This](This) documentation explains the principles behind `qml.capture.CaptureMeta` and higher order primitives. ```python @@ -255,7 +255,7 @@ class MyClass(metaclass=MyMetaClass): self.kwargs = kwargs ``` - Creating a new type with ('MyClass', (), {'__module__': '__main__', '__qualname__': 'MyClass', '__init__': }), {}. + Creating a new type with ('MyClass', (), {'__module__': '__main__', '__qualname__': 'MyClass', '__init__': }), {}. And that we have set a class property `a` @@ -272,7 +272,7 @@ But can we actually create instances of these classes? ```python >> obj = MyClass(0.1, a=2) >>> obj -creating an instance of type with (0.1,), {'a': 2}. +creating an instance of type with (0.1,), {'a': 2}. now creating an instance in __init__ <__main__.MyClass at 0x11c5a2810> ``` @@ -294,7 +294,7 @@ class MyClass2(metaclass=MetaClass2): self.args = args ``` -You can see now that instead of actually getting an instance of `MyClass2`, we just get `2.0`. +You can see now that instead of actually getting an instance of `MyClass2`, we just get `2.0`. Using a metaclass, we can hijack what happens when a type is called. @@ -425,3 +425,128 @@ Now in our jaxpr, we can see thet `PrimitiveClass2` returns something of type `A >>> jax.make_jaxpr(PrimitiveClass2)(0.1) { lambda ; a:f32[]. 
let b:AbstractPrimitiveClass() = PrimitiveClass2 a in (b,) } ``` + +# Non-interpreted primitives + +Some of the primitives in the capture module have a somewhat non-standard requirement for the +behaviour under differentiation or batching: they should ignore that an input is a differentiation +or batching tracer and just execute the standard implementation on them. + +We will look at an example to make the necessity for such a non-interpreted primitive clear. + +Consider a finite-difference differentiation routine together with some test function `fun`. + +```python +def finite_diff_impl(x, fun, delta): + """Finite difference differentiation routine. Only supports differentiating + a function `fun` with a single scalar argument, for simplicity.""" + + out_plus = fun(x + delta) + out_minus = fun(x - delta) + return tuple((out_p - out_m) / (2 * delta) for out_p, out_m in zip(out_plus, out_minus)) + +def fun(x): + return (x**2, 4 * x - 3, x**23) +``` + +Now suppose we want to turn this into a primitive. We could just promote it to a standard +`jax.core.Primitive` as + +```python +import jax + +fd_prim = jax.core.Primitive("finite_diff") +fd_prim.multiple_results = True +fd_prim.def_impl(finite_diff_impl) + +def finite_diff(x, fun, delta=1e-5): + return fd_prim.bind(x, fun, delta) +``` + +This allows us to use the forward pass as usual (to compute the first-order derivative): + +```pycon +>>> finite_diff(1., fun, delta=1e-6) +(2.000000000002, 3.999999999892978, 23.000000001216492) +``` + +Now if we want to make this primitive differentiable (with automatic +differentiation/backprop, not by using a higher order finite difference scheme), +we need to specify a JVP rule. (Note that there are multiple rather simple fixes for this example +that we could use to implement a finite difference scheme that is readily differentiable. This is +somewhat besides the point, because we did not identify a possibility to use any of those +alternatives in the PennyLane code). 
+ +However, the finite difference rule is just a standard +algebraic function making use of calls to `fun` and some elementary operations, so ideally +we would like to just use the chain rule as it is known to the AD engine. A JVP rule would +then just manually re-implement this chain rule, which we'd rather not do. + +Instead, we define a non-interpreted type of primitives, and create such a primitive +for our finite difference method. We also create the usual method that binds the +primitive to inputs. + +```python +class NonInterpPrimitive(jax.core.Primitive): + """A subclass to JAX's Primitive that works like a Python function + when evaluating JVPTracers.""" + + def bind_with_trace(self, trace, args, params): + """Bind the ``NonInterpPrimitive`` with a trace. + If the trace is a ``JVPTrace``, it falls back to a standard Python function call. + Otherwise, the bind call of JAX's standard Primitive is used.""" + if isinstance(trace, jax.interpreters.ad.JVPTrace): + return self.impl(*args, **params) + return super().bind_with_trace(trace, args, params) + +fd_prim_2 = NonInterpPrimitive("finite_diff_2") +fd_prim_2.multiple_results = True +fd_prim_2.def_impl(finite_diff_impl) # This also defines the behaviour with a JVP tracer + +def finite_diff_2(fun, delta=1e-5): + return fd_prim_2.bind(fun, delta) +``` + +Now we can use the primitive in a differentiable workflow, without defining a JVP rule +that just repeats the chain rule: + +```pycon +>>> # Forward execution of finite_diff_2 (-> first-order derivative) +>>> finite_diff_2(fun, delta=1e-6)(1.) 
+(2.000000000002, 3.999999999892978, 23.000000001216492) +>>> # Differentiation of finite_diff_2 (-> second-order derivative) +>>> jax.jacobian(finite_diff_2)(1., fun, delta=1e-6) +(Array(1.9375, dtype=float32, weak_type=True), Array(0., dtype=float32, weak_type=True), Array(498., dtype=float32, weak_type=True)) +``` + +In addition to the differentiation primitives for `qml.jacobian` and `qml.grad`, quantum operators +have non-interpreted primitives as well. This is because their differentiation is performed +by the surrounding QNode primitive rather than through the standard chain rule that acts +"locally" (in the circuit). In short, we only want gates to store their tracers (which will help +determining differentiability of gate arguments, for example), but not to do anything with them. + + + + + + + + + + + + + + + + + + + + + + + + + + From 421c5ab91db3b38f8c298e9183530678b8da969e Mon Sep 17 00:00:00 2001 From: David Wierichs Date: Fri, 17 Jan 2025 21:58:04 +0100 Subject: [PATCH 2/4] Apply suggestions from code review Co-authored-by: Mudit Pandey --- pennylane/capture/explanations.md | 31 +++---------------------------- 1 file changed, 3 insertions(+), 28 deletions(-) diff --git a/pennylane/capture/explanations.md b/pennylane/capture/explanations.md index 1d4b5066bf5..0c1c9d81bae 100644 --- a/pennylane/capture/explanations.md +++ b/pennylane/capture/explanations.md @@ -1,4 +1,4 @@ -[This](This) documentation explains the principles behind `qml.capture.CaptureMeta` and higher order primitives. +This documentation explains the principles behind `qml.capture.CaptureMeta` and higher order primitives. 
```python @@ -503,8 +503,8 @@ fd_prim_2 = NonInterpPrimitive("finite_diff_2") fd_prim_2.multiple_results = True fd_prim_2.def_impl(finite_diff_impl) # This also defines the behaviour with a JVP tracer -def finite_diff_2(fun, delta=1e-5): - return fd_prim_2.bind(fun, delta) +def finite_diff_2(x, fun, delta=1e-5): + return fd_prim_2.bind(x, fun, delta) ``` Now we can use the primitive in a differentiable workflow, without defining a JVP rule @@ -525,28 +525,3 @@ by the surrounding QNode primitive rather than through the standard chain rule t "locally" (in the circuit). In short, we only want gates to store their tracers (which will help determining differentiability of gate arguments, for example), but not to do anything with them. - - - - - - - - - - - - - - - - - - - - - - - - - From 8695396f10f5396e1f76766207675d4f93b3f0c4 Mon Sep 17 00:00:00 2001 From: David Wierichs Date: Mon, 20 Jan 2025 20:51:32 +0100 Subject: [PATCH 3/4] Apply suggestions from code review Co-authored-by: Pietropaolo Frisoni --- pennylane/capture/explanations.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pennylane/capture/explanations.md b/pennylane/capture/explanations.md index 0c1c9d81bae..3160cd6db1c 100644 --- a/pennylane/capture/explanations.md +++ b/pennylane/capture/explanations.md @@ -471,10 +471,10 @@ This allows us to use the forward pass as usual (to compute the first-order deri ``` Now if we want to make this primitive differentiable (with automatic -differentiation/backprop, not by using a higher order finite difference scheme), +differentiation/backprop, not by using a higher-order finite difference scheme), we need to specify a JVP rule. (Note that there are multiple rather simple fixes for this example that we could use to implement a finite difference scheme that is readily differentiable. 
This is -somewhat besides the point, because we did not identify a possibility to use any of those +somewhat beside the point because we did not identify the possibility of using any of those alternatives in the PennyLane code). However, the finite difference rule is just a standard @@ -482,7 +482,7 @@ algebraic function making use of calls to `fun` and some elementary operations, we would like to just use the chain rule as it is known to the AD engine. A JVP rule would then just manually re-implement this chain rule, which we'd rather not do. -Instead, we define a non-interpreted type of primitives, and create such a primitive +Instead, we define a non-interpreted type of primitive and create such a primitive for our finite difference method. We also create the usual method that binds the primitive to inputs. @@ -512,7 +512,7 @@ that just repeats the chain rule: ```pycon >>> # Forward execution of finite_diff_2 (-> first-order derivative) ->>> finite_diff_2(fun, delta=1e-6)(1.) +>>> finite_diff_2(1., fun, delta=1e-6) (2.000000000002, 3.999999999892978, 23.000000001216492) >>> # Differentiation of finite_diff_2 (-> second-order derivative) >>> jax.jacobian(finite_diff_2)(1., fun, delta=1e-6) @@ -523,5 +523,5 @@ In addition to the differentiation primitives for `qml.jacobian` and `qml.grad`, have non-interpreted primitives as well. This is because their differentiation is performed by the surrounding QNode primitive rather than through the standard chain rule that acts "locally" (in the circuit). In short, we only want gates to store their tracers (which will help -determining differentiability of gate arguments, for example), but not to do anything with them. +determine the differentiability of gate arguments, for example), but not to do anything with them. 
From c9536091f199cfe3ac3dc98bbc05df22904b518a Mon Sep 17 00:00:00 2001 From: David Wierichs Date: Mon, 20 Jan 2025 20:51:41 +0100 Subject: [PATCH 4/4] Apply suggestions from code review --- pennylane/capture/explanations.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/capture/explanations.md b/pennylane/capture/explanations.md index 3160cd6db1c..71033aac3ee 100644 --- a/pennylane/capture/explanations.md +++ b/pennylane/capture/explanations.md @@ -479,7 +479,7 @@ alternatives in the PennyLane code). However, the finite difference rule is just a standard algebraic function making use of calls to `fun` and some elementary operations, so ideally -we would like to just use the chain rule as it is known to the AD engine. A JVP rule would +we would like to just use the chain rule as it is known to the automatic differentiation framework. A JVP rule would then just manually re-implement this chain rule, which we'd rather not do. Instead, we define a non-interpreted type of primitive and create such a primitive
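The series relies on the idea that the finite-difference rule is ordinary arithmetic that the automatic differentiation framework can handle by the chain rule. That idea can be checked without any primitive at all. The following is a minimal sketch, not the PennyLane implementation: `finite_diff_plain` is a hypothetical name, the test function uses `x**3` instead of `x**23`, and `delta=1e-3` is an assumed step size chosen so that float32 round-off stays small.

```python
import jax


def fun(x):
    # Small test function (hypothetical variant of the one in the patch)
    return (x**2, 4 * x - 3, x**3)


def finite_diff_plain(x, fun, delta=1e-3):
    """Central finite difference of each output of `fun`, as plain Python."""
    out_plus = fun(x + delta)
    out_minus = fun(x - delta)
    return tuple((p - m) / (2 * delta) for p, m in zip(out_plus, out_minus))


# Forward execution: approximate first derivatives at x = 1
first = finite_diff_plain(1.0, fun)

# JAX traces through the arithmetic, so differentiating applies the chain rule
# to the finite-difference expression itself (approximate second derivatives).
second = jax.jacobian(lambda x: finite_diff_plain(x, fun))(1.0)
```

Here `first` approximates the first derivatives `(2, 4, 3)` and `second` approximates the second derivatives `(2, 0, 6)`. The non-interpreted primitive in the patch aims to keep this behaviour once the routine is wrapped in a primitive, without a hand-written JVP rule.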