Skip to content

Commit 50b7fa9

Browse files
mudit2812github-actions[bot]maliasadi
authored
Fix decompositions with LightningQubit (#687)
* Fix bug; add tests; update changelog * Auto update version * Update tests * Update qft/grover decomp * Addressing code review * Update pennylane_lightning/lightning_qubit/lightning_qubit.py Co-authored-by: Ali Asadi <[email protected]> * Trigger CI * Auto update version --------- Co-authored-by: Dev version update bot <github-actions[bot]@users.noreply.github.com> Co-authored-by: Ali Asadi <[email protected]>
1 parent 201d14a commit 50b7fa9

File tree

5 files changed

+110
-16
lines changed

5 files changed

+110
-16
lines changed

.github/CHANGELOG.md

+6
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@
7777

7878
### Bug fixes
7979

80+
* `LightningQubit` correctly decomposes state prep operations when used in the middle of a circuit.
81+
[(#687)](https://github.com/PennyLaneAI/pennylane/pull/687)
82+
83+
* `LightningQubit` correctly decomposes `qml.QFT` and `qml.GroverOperator` if `len(wires)` is greater than 9 and 12 respectively.
84+
[(#687)](https://github.com/PennyLaneAI/pennylane/pull/687)
85+
8086
* Specify `isort` `--py` (Python version) and `-l` (max line length) to stabilize `isort` across Python versions and environments.
8187
[(#647)](https://github.com/PennyLaneAI/pennylane-lightning/pull/647)
8288

pennylane_lightning/core/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
Version number (major.minor.patch[-label])
1717
"""
1818

19-
__version__ = "0.36.0-dev31"
19+
__version__ = "0.36.0-dev32"

pennylane_lightning/lightning_qubit/lightning_qubit.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,6 @@ def simulate_and_vjp(
182182
_operations = frozenset(
183183
{
184184
"Identity",
185-
"BasisState",
186-
"QubitStateVector",
187-
"StatePrep",
188185
"QubitUnitary",
189186
"ControlledQubitUnitary",
190187
"MultiControlledX",
@@ -292,13 +289,20 @@ def simulate_and_vjp(
292289

293290
def stopping_condition(op: Operator) -> bool:
294291
"""A function that determines whether or not an operation is supported by ``lightning.qubit``."""
292+
# These thresholds are adapted from `lightning_base.py`
293+
# To avoid building matrices beyond the given thresholds.
294+
# This should reduce runtime overheads for larger systems.
295+
if isinstance(op, qml.QFT):
296+
return len(op.wires) < 10
297+
if isinstance(op, qml.GroverOperator):
298+
return len(op.wires) < 13
295299
return op.name in _operations
296300

297301

298302
def stopping_condition_shots(op: Operator) -> bool:
299303
"""A function that determines whether or not an operation is supported by ``lightning.qubit``
300304
with finite shots."""
301-
return op.name in _operations or isinstance(op, (MidMeasureMP, qml.ops.op_math.Conditional))
305+
return stopping_condition(op) or isinstance(op, (MidMeasureMP, qml.ops.op_math.Conditional))
302306

303307

304308
def accepted_observables(obs: Operator) -> bool:
@@ -536,6 +540,7 @@ def preprocess(self, execution_config: ExecutionConfig = DefaultExecutionConfig)
536540
decompose,
537541
stopping_condition=stopping_condition,
538542
stopping_condition_shots=stopping_condition_shots,
543+
skip_initial_state_prep=True,
539544
name=self.name,
540545
)
541546
program.add_transform(qml.transforms.broadcast_expand)

tests/new_api/test_device.py

+60-3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
adjoint_measurements,
3232
adjoint_observables,
3333
decompose,
34+
mid_circuit_measurements,
3435
no_sampling,
3536
stopping_condition,
3637
stopping_condition_shots,
@@ -258,13 +259,12 @@ def test_preprocess(self, adjoint):
258259
expected_program.add_transform(validate_measurements, name=device.name)
259260
expected_program.add_transform(validate_observables, accepted_observables, name=device.name)
260261
expected_program.add_transform(validate_device_wires, device.wires, name=device.name)
261-
expected_program.add_transform(
262-
qml.devices.preprocess.mid_circuit_measurements, device=device
263-
)
262+
expected_program.add_transform(mid_circuit_measurements, device=device)
264263
expected_program.add_transform(
265264
decompose,
266265
stopping_condition=stopping_condition,
267266
stopping_condition_shots=stopping_condition_shots,
267+
skip_initial_state_prep=True,
268268
name=device.name,
269269
)
270270
expected_program.add_transform(qml.transforms.broadcast_expand)
@@ -293,6 +293,63 @@ def test_preprocess(self, adjoint):
293293
actual_program, _ = device.preprocess(config)
294294
assert actual_program == expected_program
295295

296+
@pytest.mark.parametrize(
297+
"op, is_trainable",
298+
[
299+
(qml.StatePrep([1 / np.sqrt(2), 1 / np.sqrt(2)], wires=0), False),
300+
(qml.StatePrep(qml.numpy.array([1 / np.sqrt(2), 1 / np.sqrt(2)]), wires=0), True),
301+
(qml.StatePrep(np.array([1, 0]), wires=0), False),
302+
(qml.BasisState([1, 1], wires=[0, 1]), False),
303+
(qml.BasisState(qml.numpy.array([1, 1]), wires=[0, 1]), True),
304+
],
305+
)
306+
def test_preprocess_state_prep_first_op_decomposition(self, op, is_trainable):
307+
"""Test that state prep ops in the beginning of a tape are decomposed with adjoint
308+
but not otherwise."""
309+
tape = qml.tape.QuantumScript([op, qml.RX(1.23, wires=0)], [qml.expval(qml.PauliZ(0))])
310+
device = LightningDevice(wires=3)
311+
312+
if is_trainable:
313+
# Need to decompose twice as the state prep ops we use first decompose into a template
314+
decomp = op.decomposition()[0].decomposition()
315+
else:
316+
decomp = [op]
317+
318+
config = ExecutionConfig(gradient_method="best" if is_trainable else None)
319+
program, _ = device.preprocess(config)
320+
[new_tape], _ = program([tape])
321+
expected_tape = qml.tape.QuantumScript([*decomp, qml.RX(1.23, wires=0)], tape.measurements)
322+
assert qml.equal(new_tape, expected_tape)
323+
324+
@pytest.mark.parametrize(
325+
"op, decomp_depth",
326+
[
327+
(qml.StatePrep([1 / np.sqrt(2), 1 / np.sqrt(2)], wires=0), 1),
328+
(qml.StatePrep(np.array([1, 0]), wires=0), 1),
329+
(qml.BasisState([1, 1], wires=[0, 1]), 1),
330+
(qml.BasisState(qml.numpy.array([1, 1]), wires=[0, 1]), 1),
331+
(qml.AmplitudeEmbedding([1 / np.sqrt(2), 1 / np.sqrt(2)], wires=0), 2),
332+
(qml.MottonenStatePreparation([1 / np.sqrt(2), 1 / np.sqrt(2)], wires=0), 0),
333+
],
334+
)
335+
def test_preprocess_state_prep_middle_op_decomposition(self, op, decomp_depth):
336+
"""Test that state prep ops in the middle of a tape are always decomposed."""
337+
tape = qml.tape.QuantumScript(
338+
[qml.RX(1.23, wires=0), op, qml.CNOT([0, 1])], [qml.expval(qml.PauliZ(0))]
339+
)
340+
device = LightningDevice(wires=3)
341+
342+
for _ in range(decomp_depth):
343+
op = op.decomposition()[0]
344+
decomp = op.decomposition()
345+
346+
program, _ = device.preprocess()
347+
[new_tape], _ = program([tape])
348+
expected_tape = qml.tape.QuantumScript(
349+
[qml.RX(1.23, wires=0), *decomp, qml.CNOT([0, 1])], tape.measurements
350+
)
351+
assert qml.equal(new_tape, expected_tape)
352+
296353
@pytest.mark.usefixtures("use_legacy_and_new_opmath")
297354
@pytest.mark.parametrize("theta, phi", list(zip(THETA, PHI)))
298355
@pytest.mark.parametrize(

tests/test_templates.py

+34-8
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,23 @@ def circuit(omega):
6262
assert np.allclose(np.sum(prob), 1.0)
6363
assert prob[index] > 0.95
6464

65+
@pytest.mark.skipif(not LightningDevice._new_API, reason="New API required.")
66+
@pytest.mark.parametrize("wires", [5, 10, 13, 15])
67+
def test_preprocess_grover_operator_decomposition(self, wires):
68+
"""Test that qml.GroverOperator is not decomposed for less than 10 wires."""
69+
tape = qml.tape.QuantumScript(
70+
[qml.GroverOperator(wires=list(range(wires)))], [qml.expval(qml.PauliZ(0))]
71+
)
72+
dev = LightningDevice(wires=wires)
73+
74+
program, _ = dev.preprocess()
75+
[new_tape], _ = program([tape])
76+
77+
if wires >= 13:
78+
assert all(not isinstance(op, qml.GroverOperator) for op in new_tape.operations)
79+
else:
80+
assert tape.operations == [qml.GroverOperator(wires=list(range(wires)))]
81+
6582

6683
class TestAngleEmbedding:
6784
"""Test the AngleEmbedding algorithm."""
@@ -416,7 +433,6 @@ class TestGateFabric:
416433
"""Test the GateFabric algorithm."""
417434

418435
def test_gatefabric(self):
419-
420436
# Build the electronic Hamiltonian
421437
symbols = ["H", "H"]
422438
coordinates = np.array([0.0, 0.0, -0.6614, 0.0, 0.0, 0.6614])
@@ -446,7 +462,6 @@ class TestUCCSD:
446462
"""Test the UCCSD algorithm."""
447463

448464
def test_uccsd(self):
449-
450465
# Define the molecule
451466
symbols = ["H", "H", "H"]
452467
geometry = np.array(
@@ -490,7 +505,6 @@ class TestkUpCCGSD:
490505
"""Test the kUpCCGSD algorithm."""
491506

492507
def test_kupccgsd(self):
493-
494508
# Define the molecule
495509
symbols = ["H", "H", "H"]
496510
geometry = np.array(
@@ -533,7 +547,6 @@ class TestParticleConservingU1:
533547
"""Test the ParticleConservingU1 algorithm."""
534548

535549
def test_particleconservingu1(self):
536-
537550
# Build the electronic Hamiltonian
538551
symbols, coordinates = (["H", "H"], np.array([0.0, 0.0, -0.66140414, 0.0, 0.0, 0.66140414]))
539552
_, n_qubits = qml.qchem.molecular_hamiltonian(symbols, coordinates)
@@ -567,7 +580,6 @@ class TestParticleConservingU2:
567580
"""Test the ParticleConservingU2 algorithm."""
568581

569582
def test_particleconservingu2(self):
570-
571583
# Build the electronic Hamiltonian
572584
symbols, coordinates = (["H", "H"], np.array([0.0, 0.0, -0.66140414, 0.0, 0.0, 0.66140414]))
573585
_, n_qubits = qml.qchem.molecular_hamiltonian(symbols, coordinates)
@@ -668,7 +680,6 @@ class TestQuantumPhaseEstimation:
668680

669681
@pytest.mark.parametrize("n_qubits", range(2, 14, 2))
670682
def test_quantumphaseestimation(self, n_qubits):
671-
672683
phase = 5
673684
target_wires = [0]
674685
unitary = qml.RX(phase, wires=0).matrix()
@@ -701,7 +712,6 @@ class TestQFT:
701712

702713
@pytest.mark.parametrize("n_qubits", range(2, 14, 2))
703714
def test_qft(self, n_qubits):
704-
705715
dev = qml.device(device_name, wires=n_qubits)
706716
dq = qml.device("default.qubit")
707717

@@ -717,13 +727,29 @@ def circuit(basis_state):
717727

718728
assert np.allclose(res, ref)
719729

730+
@pytest.mark.skipif(not LightningDevice._new_API, reason="New API required")
731+
@pytest.mark.parametrize("wires", [5, 9, 10, 13])
732+
def test_preprocess_qft_decomposition(self, wires):
733+
"""Test that qml.QFT is not decomposed for less than 10 wires."""
734+
tape = qml.tape.QuantumScript(
735+
[qml.QFT(wires=list(range(wires)))], [qml.expval(qml.PauliZ(0))]
736+
)
737+
dev = LightningDevice(wires=wires)
738+
739+
program, _ = dev.preprocess()
740+
[new_tape], _ = program([tape])
741+
742+
if wires >= 10:
743+
assert all(not isinstance(op, qml.QFT) for op in new_tape.operations)
744+
else:
745+
assert tape.operations == [qml.QFT(wires=list(range(wires)))]
746+
720747

721748
class TestAQFT:
722749
"""Test the AQFT algorithm."""
723750

724751
@pytest.mark.parametrize("n_qubits", range(4, 14, 2))
725752
def test_aqft(self, n_qubits):
726-
727753
dev = qml.device(device_name, wires=n_qubits)
728754
dq = qml.device("default.qubit")
729755

0 commit comments

Comments
 (0)