Skip to content

Commit 7e77c5d

Browse files
mudit2812github-actions[bot]maliasadivincentmr
authored
Fix state prep operation decomposition with LightningQubit (#661)
* Fixed LQ adjoint decomp * Auto update version * Trigger CI * Apply suggestions from code review Co-authored-by: Ali Asadi <[email protected]> * Auto update version * trigger ci * trigger ci * Added fix for MCM with adjoint * Fixed LQ tests --------- Co-authored-by: Dev version update bot <github-actions[bot]@users.noreply.github.com> Co-authored-by: Ali Asadi <[email protected]> Co-authored-by: Vincent Michaud-Rioux <[email protected]> Co-authored-by: Vincent Michaud-Rioux <[email protected]>
1 parent 5f95a2f commit 7e77c5d

File tree

4 files changed

+63
-3
lines changed

4 files changed

+63
-3
lines changed

.github/CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@
5050
* Fix random `coverage xml` CI issues.
5151
[(#635)](https://github.com/PennyLaneAI/pennylane-lightning/pull/635)
5252

53+
* `lightning.qubit` correctly decomposed state preparation operations with adjoint differentiation.
54+
[(#661)](https://github.com/PennyLaneAI/pennylane-lightning/pull/661)
55+
5356
### Contributors
5457

5558
This release contains contributions from (in alphabetical order):

pennylane_lightning/core/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
Version number (major.minor.patch[-label])
1717
"""
1818

19-
__version__ = "0.36.0-dev17"
19+
__version__ = "0.36.0-dev18"

pennylane_lightning/lightning_qubit/lightning_qubit.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def _supports_adjoint(circuit):
263263

264264
try:
265265
prog((circuit,))
266-
except (qml.operation.DecompositionUndefinedError, qml.DeviceError):
266+
except (qml.operation.DecompositionUndefinedError, qml.DeviceError, AttributeError):
267267
return False
268268
return True
269269

@@ -282,7 +282,9 @@ def _add_adjoint_transforms(program: TransformProgram) -> None:
282282

283283
name = "adjoint + lightning.qubit"
284284
program.add_transform(no_sampling, name=name)
285-
program.add_transform(decompose, stopping_condition=adjoint_ops, name=name)
285+
program.add_transform(
286+
decompose, stopping_condition=adjoint_ops, name=name, skip_initial_state_prep=False
287+
)
286288
program.add_transform(validate_observables, accepted_observables, name=name)
287289
program.add_transform(
288290
validate_measurements, analytic_measurements=adjoint_measurements, name=name

tests/new_api/test_device.py

+55
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def test_add_adjoint_transforms(self):
8686
decompose,
8787
stopping_condition=adjoint_ops,
8888
name=name,
89+
skip_initial_state_prep=False,
8990
)
9091
expected_program.add_transform(validate_observables, accepted_observables, name=name)
9192
expected_program.add_transform(
@@ -248,6 +249,7 @@ def test_preprocess(self, adjoint):
248249
decompose,
249250
stopping_condition=adjoint_ops,
250251
name=name,
252+
skip_initial_state_prep=False,
251253
)
252254
expected_program.add_transform(validate_observables, accepted_observables, name=name)
253255
expected_program.add_transform(
@@ -554,6 +556,59 @@ def test_derivatives_no_trainable_params(self, dev, execute_and_derivatives, bat
554556
assert len(jac) == 1
555557
assert qml.math.shape(jac[0]) == (0,)
556558

559+
@pytest.mark.parametrize("execute_and_derivatives", [True, False])
560+
@pytest.mark.parametrize(
561+
"state_prep, params, wires",
562+
[
563+
(qml.BasisState, [1, 1], [0, 1]),
564+
(qml.StatePrep, [0.0, 0.0, 0.0, 1.0], [0, 1]),
565+
(qml.StatePrep, qml.numpy.array([0.0, 1.0]), [1]),
566+
],
567+
)
568+
@pytest.mark.parametrize(
569+
"trainable_params",
570+
[(0, 1, 2), (1, 2)],
571+
)
572+
def test_state_prep_ops(
573+
self, dev, state_prep, params, wires, execute_and_derivatives, batch_obs, trainable_params
574+
):
575+
"""Test that a circuit containing state prep operations is differentiated correctly."""
576+
qs = QuantumScript(
577+
[state_prep(params, wires), qml.RX(1.23, 0), qml.CNOT([0, 1]), qml.RX(4.56, 1)],
578+
[qml.expval(qml.PauliZ(1))],
579+
)
580+
581+
config = ExecutionConfig(gradient_method="adjoint", device_options={"batch_obs": batch_obs})
582+
program, new_config = dev.preprocess(config)
583+
tapes, fn = program([qs])
584+
tapes[0].trainable_params = trainable_params
585+
if execute_and_derivatives:
586+
res, jac = dev.execute_and_compute_derivatives(tapes, new_config)
587+
res = fn(res)
588+
else:
589+
res, jac = (
590+
fn(dev.execute(tapes, new_config)),
591+
dev.compute_derivatives(tapes, new_config),
592+
)
593+
594+
dev_ref = DefaultQubit(max_workers=1)
595+
config = ExecutionConfig(gradient_method="adjoint")
596+
program, new_config = dev_ref.preprocess(config)
597+
tapes, fn = program([qs])
598+
tapes[0].trainable_params = trainable_params
599+
if execute_and_derivatives:
600+
expected, expected_jac = dev_ref.execute_and_compute_derivatives(tapes, new_config)
601+
expected = fn(expected)
602+
else:
603+
expected, expected_jac = (
604+
fn(dev_ref.execute(tapes, new_config)),
605+
dev_ref.compute_derivatives(tapes, new_config),
606+
)
607+
608+
tol = 1e-5 if dev.c_dtype == np.complex64 else 1e-7
609+
assert np.allclose(res, expected, atol=tol, rtol=0)
610+
assert np.allclose(jac, expected_jac, atol=tol, rtol=0)
611+
557612
def test_state_jacobian_not_supported(self, dev, batch_obs):
558613
"""Test that an error is raised if derivatives are requested for state measurement"""
559614
qs = QuantumScript([qml.RX(1.23, 0)], [qml.state()], trainable_params=[0])

0 commit comments

Comments
 (0)