Skip to content

Commit 145c6d7

Browse files
maliasadiringo-but-quantumShiro-Ravenalbi3ro
authored andcommitted
Update generate_samples in LK and LGPU to support qml.measurements.Shots (#839)
**Context:** PR PennyLaneAI/pennylane#6046 wraps the legacy device API automatically in various device creation, qnode, and execute functions. As LK and LGPU plugins still rely on the legacy device API, the shots tests and the `generate_samples` logic in `lightning_kokkos.py` and `lightning_gpu.py` should be updated to adhere the new convention. **Related Shortcut Stories:** [sc-65998] --------- Co-authored-by: ringo-but-quantum <[email protected]> Co-authored-by: Shiro-Raven <[email protected]> Co-authored-by: albi3ro <[email protected]>
1 parent 882cc77 commit 145c6d7

14 files changed

+82
-238
lines changed

.github/CHANGELOG.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232

3333
* Multiple calls to the `append_mps_final_state()` API is allowed in `lightning.tensor`.
3434
[(#830)](https://github.com/PennyLaneAI/pennylane-lightning/pull/830)
35+
36+
* Update `generate_samples` in `LightningKokkos` and `LightningGPU` to support `qml.measurements.Shots` type instances.
37+
[(#839)](https://github.com/PennyLaneAI/pennylane-lightning/pull/839)
3538

3639
* LightningQubit gains native support for the `PauliRot` gate.
3740
[(#834)](https://github.com/PennyLaneAI/pennylane-lightning/pull/834)
@@ -142,7 +145,7 @@
142145

143146
This release contains contributions from (in alphabetical order):
144147

145-
Ali Asadi, Astral Cai, Amintor Dusko, Vincent Michaud-Rioux, Erick Ochoa Lopez, Lee J. O'Riordan, Mudit Pandey, Shuli Shu, Raul Torres, Paul Haochen Wang
148+
Ali Asadi, Astral Cai, Ahmed Darwish, Amintor Dusko, Vincent Michaud-Rioux, Erick Ochoa Lopez, Lee J. O'Riordan, Mudit Pandey, Shuli Shu, Raul Torres, Paul Haochen Wang
146149

147150
---
148151

pennylane_lightning/core/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
Version number (major.minor.patch[-label])
1717
"""
1818

19-
__version__ = "0.38.0-dev40"
19+
__version__ = "0.38.0-dev39"

pennylane_lightning/lightning_gpu/lightning_gpu.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
from pennylane_lightning.core.lightning_base import LightningBase
3939

4040
try:
41-
4241
from pennylane_lightning.lightning_gpu_ops import (
4342
DevPool,
4443
MeasurementsC64,
@@ -818,9 +817,9 @@ def generate_samples(self):
818817
array[int]: array of samples in binary representation with shape
819818
``(dev.shots, dev.num_wires)``
820819
"""
821-
return self.measurements.generate_samples(len(self.wires), self.shots).astype(
822-
int, copy=False
823-
)
820+
shots = self.shots if isinstance(self.shots, int) else self.shots.total_shots
821+
822+
return self.measurements.generate_samples(len(self.wires), shots).astype(int, copy=False)
824823

825824
# pylint: disable=protected-access
826825
def expval(self, observable, shot_range=None, bin_size=None):

pennylane_lightning/lightning_kokkos/lightning_kokkos.py

+3
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,9 @@ def generate_samples(self, shots=None):
614614
``(dev.shots, dev.num_wires)``
615615
"""
616616
shots = self.shots if shots is None else shots
617+
618+
shots = shots.total_shots if isinstance(shots, qml.measurements.Shots) else shots
619+
617620
measure = (
618621
MeasurementsC64(self._kokkos_state)
619622
if self.use_csingle

tests/lightning_qubit/test_measurements_samples_MCMC.py

+4-17
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,8 @@ def test_mcmc_sample_dimensions(self, dev, num_shots, measured_wires, operation,
4646
the correct dimensions
4747
"""
4848
ops = [qml.RX(1.5708, wires=[0]), qml.RX(1.5708, wires=[1])]
49-
if ld._new_API:
50-
tape = qml.tape.QuantumScript(ops, [qml.sample(op=operation)], shots=num_shots)
51-
s1 = dev.execute(tape)
52-
else:
53-
dev.apply(ops)
54-
dev.shots = num_shots
55-
dev._wires_measured = measured_wires
56-
dev._samples = dev.generate_samples()
57-
s1 = dev.sample(operation)
49+
tape = qml.tape.QuantumScript(ops, [qml.sample(op=operation)], shots=num_shots)
50+
s1 = dev.execute(tape)
5851

5952
assert np.array_equal(s1.shape, (shape,))
6053

@@ -67,14 +60,8 @@ def test_sample_values(self, tol, kernel):
6760
device_name, wires=2, shots=1000, mcmc=True, kernel_name=kernel, num_burnin=100
6861
)
6962
ops = [qml.RX(1.5708, wires=[0])]
70-
if ld._new_API:
71-
tape = qml.tape.QuantumScript(ops, [qml.sample(op=qml.PauliZ(0))], shots=1000)
72-
s1 = dev.execute(tape)
73-
else:
74-
dev.apply([qml.RX(1.5708, wires=[0])])
75-
dev._wires_measured = {0}
76-
dev._samples = dev.generate_samples()
77-
s1 = dev.sample(qml.PauliZ(0))
63+
tape = qml.tape.QuantumScript(ops, [qml.sample(op=qml.PauliZ(0))], shots=1000)
64+
s1 = dev.execute(tape)
7865

7966
# s1 should only contain 1 and -1, which is guaranteed if
8067
# they square to 1

tests/lightning_tensor/test_tensornet_class.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import pennylane as qml
2222
import pytest
2323
from conftest import LightningDevice, device_name # tested device
24-
from pennylane import DeviceError
2524
from pennylane.wires import Wires
2625

2726
if device_name != "lightning.tensor":
@@ -88,6 +87,7 @@ def test_errors_apply_operation_state_preparation(operation, par):
8887
tensornet = LightningTensorNet(wires, bondDims)
8988

9089
with pytest.raises(
91-
DeviceError, match="lightning.tensor does not support initialization with a state vector."
90+
qml.DeviceError,
91+
match="lightning.tensor does not support initialization with a state vector.",
9292
):
9393
tensornet.apply_operations([operation(np.array(par), Wires(range(wires)))])

tests/test_adjoint_jacobian.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -700,23 +700,20 @@ def dev(self, request):
700700
return qml.device(device_name, wires=2, c_dtype=request.param)
701701

702702
@pytest.mark.skipif(ld._new_API, reason="Old API required")
703-
def test_finite_shots_warning(self):
704-
"""Tests that a warning is raised when computing the adjoint diff on a device with finite shots"""
703+
def test_finite_shots_error(self):
704+
"""Tests that an error is raised when computing the adjoint diff on a device with finite shots"""
705705

706706
dev = qml.device(device_name, wires=1, shots=1)
707707

708-
with pytest.warns(
709-
UserWarning, match="Requested adjoint differentiation to be computed with finite shots."
708+
with pytest.raises(
709+
qml.QuantumFunctionError, match="does not support adjoint with requested circuit."
710710
):
711711

712712
@qml.qnode(dev, diff_method="adjoint")
713713
def circ(x):
714714
qml.RX(x, wires=0)
715715
return qml.expval(qml.PauliZ(0))
716716

717-
with pytest.warns(
718-
UserWarning, match="Requested adjoint differentiation to be computed with finite shots."
719-
):
720717
qml.grad(circ)(0.1)
721718

722719
def test_qnode(self, mocker, dev):
@@ -741,7 +738,7 @@ def circuit(x, y, z):
741738
spy = (
742739
mocker.spy(dev, "execute_and_compute_derivatives")
743740
if ld._new_API
744-
else mocker.spy(dev, "adjoint_jacobian")
741+
else mocker.spy(dev.target_device, "adjoint_jacobian")
745742
)
746743
tol, h = get_tolerance_and_stepsize(dev, step_size=True)
747744

@@ -926,7 +923,7 @@ def cost(p1, p2):
926923
if ld._new_API:
927924
spy = mocker.spy(dev, "execute_and_compute_derivatives")
928925
else:
929-
spy = mocker.spy(dev, "adjoint_jacobian")
926+
spy = mocker.spy(dev.target_device, "adjoint_jacobian")
930927

931928
# analytic gradient
932929
grad_fn = qml.grad(cost)
@@ -968,7 +965,7 @@ def circuit(params):
968965
spy_analytic = (
969966
mocker.spy(dev, "execute_and_compute_derivatives")
970967
if ld._new_API
971-
else mocker.spy(dev, "adjoint_jacobian")
968+
else mocker.spy(dev.target_device, "adjoint_jacobian")
972969
)
973970
tol, h = get_tolerance_and_stepsize(dev, step_size=True)
974971

tests/test_apply.py

+23-78
Original file line numberDiff line numberDiff line change
@@ -566,13 +566,8 @@ def test_expval_single_wire_no_parameters(
566566
dev = qubit_device(wires=1)
567567
obs = operation(wires=[0])
568568
ops = [stateprep(np.array(input), wires=[0])]
569-
if ld._new_API:
570-
tape = qml.tape.QuantumScript(ops, [qml.expval(op=obs)])
571-
res = dev.execute(tape)
572-
else:
573-
dev.reset()
574-
dev.apply(ops, obs.diagonalizing_gates())
575-
res = dev.expval(obs)
569+
tape = qml.tape.QuantumScript(ops, [qml.expval(op=obs)])
570+
res = dev.execute(tape)
576571

577572
assert np.isclose(res, expected_output, atol=tol, rtol=0)
578573

@@ -630,13 +625,8 @@ def test_var_single_wire_no_parameters(
630625
dev = qubit_device(wires=1)
631626
obs = operation(wires=[0])
632627
ops = [stateprep(np.array(input), wires=[0])]
633-
if ld._new_API:
634-
tape = qml.tape.QuantumScript(ops, [qml.var(op=obs)])
635-
res = dev.execute(tape)
636-
else:
637-
dev.reset()
638-
dev.apply(ops, obs.diagonalizing_gates())
639-
res = dev.var(obs)
628+
tape = qml.tape.QuantumScript(ops, [qml.var(op=obs)])
629+
res = dev.execute(tape)
640630

641631
assert np.isclose(res, expected_output, atol=tol, rtol=0)
642632

@@ -680,42 +670,22 @@ def test_sample_dimensions(self, qubit_device):
680670

681671
shots = 10
682672
obs = qml.PauliZ(wires=[0])
683-
if ld._new_API:
684-
tape = qml.tape.QuantumScript(ops, [qml.sample(op=obs)], shots=shots)
685-
s1 = dev.execute(tape)
686-
else:
687-
dev.reset()
688-
dev.apply(ops)
689-
dev.shots = shots
690-
dev._wires_measured = {0}
691-
dev._samples = dev.generate_samples()
692-
s1 = dev.sample(obs)
673+
tape = qml.tape.QuantumScript(ops, [qml.sample(op=obs)], shots=shots)
674+
s1 = dev.execute(tape)
675+
693676
assert np.array_equal(s1.shape, (shots,))
694677

695678
shots = 12
696679
obs = qml.PauliZ(wires=[1])
697-
if ld._new_API:
698-
tape = qml.tape.QuantumScript(ops, [qml.sample(op=obs)], shots=shots)
699-
s2 = dev.execute(tape)
700-
else:
701-
dev.reset()
702-
dev.shots = shots
703-
dev._wires_measured = {1}
704-
dev._samples = dev.generate_samples()
705-
s2 = dev.sample(qml.PauliZ(wires=[1]))
680+
tape = qml.tape.QuantumScript(ops, [qml.sample(op=obs)], shots=shots)
681+
s2 = dev.execute(tape)
706682
assert np.array_equal(s2.shape, (shots,))
707683

708684
shots = 17
709685
obs = qml.PauliX(0) @ qml.PauliZ(1)
710-
if ld._new_API:
711-
tape = qml.tape.QuantumScript(ops, [qml.sample(op=obs)], shots=shots)
712-
s3 = dev.execute(tape)
713-
else:
714-
dev.reset()
715-
dev.shots = shots
716-
dev._wires_measured = {0, 1}
717-
dev._samples = dev.generate_samples()
718-
s3 = dev.sample(qml.PauliZ(wires=[1]))
686+
tape = qml.tape.QuantumScript(ops, [qml.sample(op=obs)], shots=shots)
687+
s3 = dev.execute(tape)
688+
719689
assert np.array_equal(s3.shape, (shots,))
720690

721691
def test_sample_values(self, qubit_device, tol):
@@ -730,18 +700,10 @@ def test_sample_values(self, qubit_device, tol):
730700

731701
ops = [qml.RX(1.5708, wires=[0])]
732702

733-
shots = 1000
703+
shots = qml.measurements.Shots(1000)
734704
obs = qml.PauliZ(0)
735-
if ld._new_API:
736-
tape = qml.tape.QuantumScript(ops, [qml.sample(op=obs)], shots=shots)
737-
s1 = dev.execute(tape)
738-
else:
739-
dev.reset()
740-
dev.apply(ops)
741-
dev.shots = shots
742-
dev._wires_measured = {0}
743-
dev._samples = dev.generate_samples()
744-
s1 = dev.sample(obs)
705+
tape = qml.tape.QuantumScript(ops, [qml.sample(op=obs)], shots=shots)
706+
s1 = dev.execute(tape)
745707

746708
# s1 should only contain 1 and -1, which is guaranteed if
747709
# they square to 1
@@ -756,13 +718,8 @@ def test_load_default_qubit_device(self):
756718
"""Test that the default plugin loads correctly"""
757719

758720
dev = qml.device(device_name, wires=2)
759-
if dev._new_API:
760-
assert not dev.shots
761-
assert len(dev.wires) == 2
762-
else:
763-
assert dev.shots is None
764-
assert dev.num_wires == 2
765-
assert dev.short_name == device_name
721+
assert not dev.shots
722+
assert len(dev.wires) == 2
766723

767724
@pytest.mark.xfail(ld._new_API, reason="Old device API required.")
768725
def test_no_backprop(self):
@@ -1276,14 +1233,10 @@ def test_multi_samples_return_correlated_results(self, qubit_device):
12761233
def circuit():
12771234
qml.Hadamard(0)
12781235
qml.CNOT(wires=[0, 1])
1279-
if ld._new_API:
1280-
return qml.sample(wires=[0, 1])
1281-
else:
1282-
return qml.sample(qml.PauliZ(0)), qml.sample(qml.PauliZ(1))
1236+
return qml.sample(wires=[0, 1])
12831237

12841238
outcomes = circuit()
1285-
if ld._new_API:
1286-
outcomes = outcomes.T
1239+
outcomes = outcomes.T
12871240

12881241
assert np.array_equal(outcomes[0], outcomes[1])
12891242

@@ -1305,14 +1258,10 @@ def test_multi_samples_return_correlated_results_more_wires_than_size_of_observa
13051258
def circuit():
13061259
qml.Hadamard(0)
13071260
qml.CNOT(wires=[0, 1])
1308-
if ld._new_API:
1309-
return qml.sample(wires=[0, 1])
1310-
else:
1311-
return qml.sample(qml.PauliZ(0)), qml.sample(qml.PauliZ(1))
1261+
return qml.sample(wires=[0, 1])
13121262

13131263
outcomes = circuit()
1314-
if ld._new_API:
1315-
outcomes = outcomes.T
1264+
outcomes = outcomes.T
13161265

13171266
assert np.array_equal(outcomes[0], outcomes[1])
13181267

@@ -1350,14 +1299,10 @@ def circuit():
13501299
qml.Snapshot()
13511300
qml.adjoint(qml.Snapshot())
13521301
qml.CNOT(wires=[0, 1])
1353-
if ld._new_API:
1354-
return qml.sample(wires=[0, 1])
1355-
else:
1356-
return qml.sample(qml.PauliZ(0)), qml.sample(qml.PauliZ(1))
1302+
return qml.sample(wires=[0, 1])
13571303

13581304
outcomes = circuit()
1359-
if ld._new_API:
1360-
outcomes = outcomes.T
1305+
outcomes = outcomes.T
13611306

13621307
assert np.array_equal(outcomes[0], outcomes[1])
13631308

0 commit comments

Comments
 (0)