Skip to content

Commit 8b8ef73

Browse files
Support the generalization of basis state preparation and the facade legacy device for MPI LGPU (#864)
### Before submitting Please complete the following checklist when submitting a PR: - [x] All new features must include a unit test. If you've fixed a bug or added code that should be tested, add a test to the [`tests`](../tests) directory! - [x] All new functions and code must be clearly commented and documented. If you do make documentation changes, make sure that the docs build and render correctly by running `make docs`. - [x] Ensure that the test suite passes, by running `make test`. - [x] Add a new entry to the `.github/CHANGELOG.md` file, summarizing the change, and including a link back to the PR. - [x] Ensure that code is properly formatted by running `make format`. When all the above are checked, delete everything above the dashed line and fill in the pull request template. ------------------------------------------------------------------------------------------------------------ **Context:** - PR PennyLaneAI/pennylane#6021 removed code duplication for `BasisEmbedding` and `BasisState`. As the result `BasisState` no longer decomposes to `BasisStatePreparation`. This PR updates Python unit tests to support this generalization of basis state preparation. - PR PennyLaneAI/pennylane#6046 added a facade wrapper class for "legacy" devices. This PR is a follow up to PR #839 updating Multi-GPU LGPU device and tests. **Description of the Change:** **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** --------- Co-authored-by: ringo-but-quantum <[email protected]>
1 parent 4b79587 commit 8b8ef73

File tree

8 files changed

+54
-58
lines changed

8 files changed

+54
-58
lines changed

.github/CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,15 @@
3333

3434
### Improvements
3535

36+
* Update Lightning tests to support the generalization of basis state preparation.
37+
[(#864)](https://github.com/PennyLaneAI/pennylane-lightning/pull/864)
38+
3639
* Multiple calls to the `append_mps_final_state()` API is allowed in `lightning.tensor`.
3740
[(#830)](https://github.com/PennyLaneAI/pennylane-lightning/pull/830)
3841

3942
* Update `generate_samples` in `LightningKokkos` and `LightningGPU` to support `qml.measurements.Shots` type instances.
4043
[(#839)](https://github.com/PennyLaneAI/pennylane-lightning/pull/839)
44+
[(#864)](https://github.com/PennyLaneAI/pennylane-lightning/pull/864)
4145

4246
* LightningQubit gains native support for the `PauliRot` gate.
4347
[(#834)](https://github.com/PennyLaneAI/pennylane-lightning/pull/834)

mpitests/test_adjoint_jacobian.py

+7-12
Original file line numberDiff line numberDiff line change
@@ -649,25 +649,20 @@ def dev(self, request):
649649
batch_obs=request.param[1],
650650
)
651651

652-
def test_finite_shots_warning(self):
653-
"""Tests that a warning is raised when computing the adjoint diff on a device with finite shots"""
652+
def test_finite_shots_error(self):
653+
"""Tests that an error is raised when computing the adjoint diff on a device with finite shots"""
654654

655655
dev = qml.device(device_name, wires=8, mpi=True, shots=1)
656656

657-
with pytest.warns(
658-
UserWarning,
659-
match="Requested adjoint differentiation to be computed with finite shots.",
657+
with pytest.raises(
658+
qml.QuantumFunctionError, match="does not support adjoint with requested circuit."
660659
):
661660

662661
@qml.qnode(dev, diff_method="adjoint")
663662
def circ(x):
664663
qml.RX(x, wires=0)
665664
return qml.expval(qml.PauliZ(0))
666665

667-
with pytest.warns(
668-
UserWarning,
669-
match="Requested adjoint differentiation to be computed with finite shots.",
670-
):
671666
qml.grad(circ)(0.1)
672667

673668
def test_qnode(self, mocker, dev):
@@ -689,7 +684,7 @@ def circuit(x, y, z):
689684
return qml.expval(qml.PauliX(0) @ qml.PauliZ(1))
690685

691686
qnode1 = QNode(circuit, dev, diff_method="adjoint")
692-
spy = mocker.spy(dev, "adjoint_jacobian")
687+
spy = mocker.spy(dev.target_device, "adjoint_jacobian")
693688

694689
grad_fn = qml.grad(qnode1)
695690
grad_A = grad_fn(*args)
@@ -731,7 +726,7 @@ def cost(p1, p2):
731726
zero_state = np.array([1.0, 0.0])
732727
cost(reused_p, other_p)
733728

734-
spy = mocker.spy(dev, "adjoint_jacobian")
729+
spy = mocker.spy(dev.target_device, "adjoint_jacobian")
735730

736731
# analytic gradient
737732
grad_fn = qml.grad(cost)
@@ -770,7 +765,7 @@ def circuit(params):
770765
qml.Rot(params[1], params[0], 2 * params[0], wires=[0])
771766
return qml.expval(qml.PauliX(0))
772767

773-
spy_analytic = mocker.spy(dev, "adjoint_jacobian")
768+
spy_analytic = mocker.spy(dev.target_device, "adjoint_jacobian")
774769

775770
h = 1e-3 if dev.R_DTYPE == np.float32 else 1e-7
776771
tol = 1e-3 if dev.R_DTYPE == np.float32 else 1e-7

mpitests/test_apply.py

+30-29
Original file line numberDiff line numberDiff line change
@@ -672,29 +672,30 @@ def test_sample_dimensions(self, C_DTYPE):
672672
"""
673673
num_wires = numQubits
674674

675-
dev = qml.device("lightning.gpu", wires=num_wires, mpi=True, shots=1000, c_dtype=C_DTYPE)
676-
677-
dev.apply([qml.RX(1.5708, wires=[0]), qml.RX(1.5708, wires=[1])])
678-
679-
dev.shots = 10
680-
dev._wires_measured = {0}
681-
dev._samples = dev.generate_samples()
682-
s1 = dev.sample(qml.PauliZ(wires=[0]))
683-
assert np.array_equal(s1.shape, (10,))
684-
685-
dev.reset()
686-
dev.shots = 12
687-
dev._wires_measured = {1}
688-
dev._samples = dev.generate_samples()
689-
s2 = dev.sample(qml.PauliZ(wires=[1]))
690-
assert np.array_equal(s2.shape, (12,))
691-
692-
dev.reset()
693-
dev.shots = 17
694-
dev._wires_measured = {0, 1}
695-
dev._samples = dev.generate_samples()
696-
s3 = dev.sample(qml.PauliX(0) @ qml.PauliZ(1))
697-
assert np.array_equal(s3.shape, (17,))
675+
dev = qml.device("lightning.gpu", wires=num_wires, mpi=True, c_dtype=C_DTYPE)
676+
677+
ops = [qml.RX(1.5708, wires=[0]), qml.RX(1.5708, wires=[1])]
678+
679+
shots = 10
680+
obs = qml.PauliZ(wires=[0])
681+
tape = qml.tape.QuantumScript(ops, [qml.sample(op=obs)], shots=shots)
682+
s1 = dev.execute(tape)
683+
684+
assert np.array_equal(s1.shape, (shots,))
685+
686+
shots = 12
687+
obs = qml.PauliZ(wires=[1])
688+
tape = qml.tape.QuantumScript(ops, [qml.sample(op=obs)], shots=shots)
689+
s2 = dev.execute(tape)
690+
691+
assert np.array_equal(s2.shape, (shots,))
692+
693+
shots = 17
694+
obs = qml.PauliX(0) @ qml.PauliZ(1)
695+
tape = qml.tape.QuantumScript(ops, [qml.sample(op=obs)], shots=shots)
696+
s3 = dev.execute(tape)
697+
698+
assert np.array_equal(s3.shape, (shots,))
698699

699700
@pytest.mark.parametrize("C_DTYPE", [np.complex128, np.complex64])
700701
def test_sample_values(self, tol, C_DTYPE):
@@ -703,13 +704,13 @@ def test_sample_values(self, tol, C_DTYPE):
703704
"""
704705
num_wires = numQubits
705706

706-
dev = qml.device("lightning.gpu", wires=num_wires, mpi=True, shots=1000, c_dtype=C_DTYPE)
707-
dev.reset()
708-
dev.apply([qml.RX(1.5708, wires=[0])])
709-
dev._wires_measured = {0}
710-
dev._samples = dev.generate_samples()
707+
dev = qml.device("lightning.gpu", wires=num_wires, mpi=True, c_dtype=C_DTYPE)
711708

712-
s1 = dev.sample(qml.PauliZ(0))
709+
shots = qml.measurements.Shots(1000)
710+
ops = [qml.RX(1.5708, wires=[0])]
711+
obs = qml.PauliZ(0)
712+
tape = qml.tape.QuantumScript(ops, [qml.sample(op=obs)], shots=shots)
713+
s1 = dev.execute(tape)
713714

714715
# s1 should only contain 1 and -1, which is guaranteed if
715716
# they square to 1

pennylane_lightning/core/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
Version number (major.minor.patch[-label])
1717
"""
1818

19-
__version__ = "0.38.0-dev42"
19+
__version__ = "0.38.0-dev43"

tests/lightning_qubit/test_state_vector_class.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ def test_wrong_dtype(dtype):
5858

5959

6060
def test_errors_basis_state():
61-
with pytest.raises(ValueError, match="BasisState parameter must consist of 0 or 1 integers."):
61+
with pytest.raises(ValueError, match="Basis state must only consist of 0s and 1s;"):
6262
state_vector = LightningStateVector(2)
6363
state_vector.apply_operations([qml.BasisState(np.array([-0.2, 4.2]), wires=[0, 1])])
64-
with pytest.raises(ValueError, match="BasisState parameter and wires must be of equal length."):
64+
with pytest.raises(ValueError, match="State must be of length 1;"):
6565
state_vector = LightningStateVector(1)
6666
state_vector.apply_operations([qml.BasisState(np.array([0, 1]), wires=[0])])
6767

tests/lightning_tensor/test_tensornet_class.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ def test_wrong_device_name():
5757

5858
def test_errors_basis_state():
5959
"""Test that errors are raised when applying a BasisState operation."""
60-
with pytest.raises(ValueError, match="BasisState parameter must consist of 0 or 1 integers."):
60+
with pytest.raises(ValueError, match="Basis state must only consist of 0s and 1s;"):
6161
tensornet = LightningTensorNet(3, 5)
6262
tensornet.apply_operations([qml.BasisState(np.array([-0.2, 4.2]), wires=[0, 1])])
63-
with pytest.raises(ValueError, match="BasisState parameter and wires must be of equal length."):
63+
with pytest.raises(ValueError, match="State must be of length 1;"):
6464
tensornet = LightningTensorNet(3, 5)
6565
tensornet.apply_operations([qml.BasisState(np.array([0, 1]), wires=[0])])
6666

tests/new_api/test_device.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -349,8 +349,9 @@ def test_preprocess_state_prep_first_op_decomposition(self, op, is_trainable):
349349
device = LightningDevice(wires=3)
350350

351351
if is_trainable:
352-
# Need to decompose twice as the state prep ops we use first decompose into a template
353-
decomp = op.decomposition()[0].decomposition()
352+
decomp = op.decomposition()
353+
# decompose one more time if it's decomposed into a template:
354+
decomp = decomp[0].decomposition() if len(decomp) == 1 else decomp
354355
else:
355356
decomp = [op]
356357

@@ -367,7 +368,7 @@ def test_preprocess_state_prep_first_op_decomposition(self, op, is_trainable):
367368
(qml.StatePrep(np.array([1, 0]), wires=0), 1),
368369
(qml.BasisState([1, 1], wires=[0, 1]), 1),
369370
(qml.BasisState(qml.numpy.array([1, 1]), wires=[0, 1]), 1),
370-
(qml.AmplitudeEmbedding([1 / np.sqrt(2), 1 / np.sqrt(2)], wires=0), 2),
371+
(qml.AmplitudeEmbedding([1 / np.sqrt(2), 1 / np.sqrt(2)], wires=0), 1),
371372
(qml.MottonenStatePreparation([1 / np.sqrt(2), 1 / np.sqrt(2)], wires=0), 0),
372373
],
373374
)
@@ -378,8 +379,7 @@ def test_preprocess_state_prep_middle_op_decomposition(self, op, decomp_depth):
378379
)
379380
device = LightningDevice(wires=3)
380381

381-
for _ in range(decomp_depth):
382-
op = op.decomposition()[0]
382+
op = op.decomposition()[0] if decomp_depth and len(op.decomposition()) == 1 else op
383383
decomp = op.decomposition()
384384

385385
program, _ = device.preprocess()

tests/test_apply.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def test_apply_operation_preserve_pointer_two_wires_with_parameters(
488488
def test_apply_errors_qubit_state_vector(self, stateprep, qubit_device):
489489
"""Test that apply fails for incorrect state preparation, and > 2 qubit gates"""
490490
dev = qubit_device(wires=2)
491-
with pytest.raises(ValueError, match="Sum of amplitudes-squared does not equal one."):
491+
with pytest.raises(ValueError, match="The state must be a vector of norm 1.0;"):
492492
dev.apply([stateprep(np.array([1, -1]), wires=[0])])
493493

494494
with pytest.raises(
@@ -500,14 +500,10 @@ def test_apply_errors_qubit_state_vector(self, stateprep, qubit_device):
500500

501501
def test_apply_errors_basis_state(self, qubit_device):
502502
dev = qubit_device(wires=2)
503-
with pytest.raises(
504-
ValueError, match="BasisState parameter must consist of 0 or 1 integers."
505-
):
503+
with pytest.raises(ValueError, match="Basis state must only consist of 0s and 1s;"):
506504
dev.apply([qml.BasisState(np.array([-0.2, 4.2]), wires=[0, 1])])
507505

508-
with pytest.raises(
509-
ValueError, match="BasisState parameter and wires must be of equal length."
510-
):
506+
with pytest.raises(ValueError, match="State must be of length 1;"):
511507
dev.apply([qml.BasisState(np.array([0, 1]), wires=[0])])
512508

513509
with pytest.raises(

0 commit comments

Comments
 (0)