Skip to content

Commit 1293b96

Browse files
Add Mid-Measurement support to single-GPU LGPU (#931)
### Before submitting Please complete the following checklist when submitting a PR: - [ ] All new features must include a unit test. If you've fixed a bug or added code that should be tested, add a test to the [`tests`](../tests) directory! - [ ] All new functions and code must be clearly commented and documented. If you do make documentation changes, make sure that the docs build and render correctly by running `make docs`. - [ ] Ensure that the test suite passes, by running `make test`. - [x] Add a new entry to the `.github/CHANGELOG.md` file, summarizing the change, and including a link back to the PR. - [x] Ensure that code is properly formatted by running `make format`. When all the above are checked, delete everything above the dashed line and fill in the pull request template. ------------------------------------------------------------------------------------------------------------ **Context:** [SC-74842] **Description of the Change:** **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** --------- Co-authored-by: ringo-but-quantum <[email protected]>
1 parent e1ec3ad commit 1293b96

12 files changed

+208
-14
lines changed

.github/CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
### New features since last release
44

5+
* Add `mid-circuit measurements` support to `lightning.gpu`'s single-GPU backend.
6+
[(#931)](https://github.com/PennyLaneAI/pennylane-lightning/pull/931)
7+
58
* Add Matrix Product Operator (MPO) for all gates support to `lightning.tensor`. Note current C++ implementation only works for MPO sites data provided by users.
69
[(#859)](https://github.com/PennyLaneAI/pennylane-lightning/pull/859)
710

mpitests/test_native_mcm.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2024 Xanadu Quantum Technologies Inc.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for default qubit preprocessing."""
15+
import numpy as np
16+
import pennylane as qml
17+
import pytest
18+
from conftest import LightningDevice, device_name
19+
from mpi4py import MPI
20+
21+
if not LightningDevice._CPP_BINARY_AVAILABLE: # pylint: disable=protected-access
22+
pytest.skip("No binary module found. Skipping.", allow_module_level=True)
23+
24+
25+
def test_unspported_mid_measurement():
26+
"""Test unsupported mid_measurement for Lightning-GPU-MPI."""
27+
comm = MPI.COMM_WORLD
28+
dev = qml.device(device_name, wires=2, mpi=True, shots=1000)
29+
params = np.pi / 4 * np.ones(2)
30+
31+
@qml.qnode(dev)
32+
def func(x, y):
33+
qml.RX(x, wires=0)
34+
m0 = qml.measure(0)
35+
qml.cond(m0, qml.RY)(y, wires=1)
36+
return qml.probs(wires=0)
37+
38+
comm.Barrier()
39+
40+
with pytest.raises(
41+
qml.DeviceError, match="Lightning-GPU-MPI does not support Mid-circuit measurements."
42+
):
43+
func(*params)

pennylane_lightning/core/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
Version number (major.minor.patch[-label])
1717
"""
1818

19-
__version__ = "0.39.0-dev43"
19+
__version__ = "0.39.0-dev44"

pennylane_lightning/core/src/simulators/lightning_gpu/StateVectorCudaManaged.hpp

+50
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,56 @@ class StateVectorCudaManaged
487487
applyMatrix(gate_matrix.data(), wires, adjoint);
488488
}
489489

490+
/**
491+
* @brief Collapse the state vector after having measured one of the qubit.
492+
*
493+
* Note: The branch parameter imposes the measurement result on the given
494+
* wire.
495+
*
496+
* @param wire Wire to measure.
497+
* @param branch Branch 0 or 1.
498+
*/
499+
void collapse(std::size_t wire, bool branch) {
500+
PL_ABORT_IF_NOT(wire < BaseType::getNumQubits(), "Invalid wire index.");
501+
cudaDataType_t data_type;
502+
503+
if constexpr (std::is_same_v<CFP_t, cuDoubleComplex> ||
504+
std::is_same_v<CFP_t, double2>) {
505+
data_type = CUDA_C_64F;
506+
} else {
507+
data_type = CUDA_C_32F;
508+
}
509+
510+
std::vector<int> basisBits(1, BaseType::getNumQubits() - 1 - wire);
511+
512+
double abs2sum0;
513+
double abs2sum1;
514+
515+
PL_CUSTATEVEC_IS_SUCCESS(custatevecAbs2SumOnZBasis(
516+
/* custatevecHandle_t */ handle_.get(),
517+
/* void *sv */ BaseType::getData(),
518+
/* cudaDataType_t */ data_type,
519+
/* const uint32_t nIndexBits */ BaseType::getNumQubits(),
520+
/* double * */ &abs2sum0,
521+
/* double * */ &abs2sum1,
522+
/* const int32_t * */ basisBits.data(),
523+
/* const uint32_t nBasisBits */ basisBits.size()));
524+
525+
const double norm = branch ? abs2sum1 : abs2sum0;
526+
527+
const int parity = static_cast<int>(branch);
528+
529+
PL_CUSTATEVEC_IS_SUCCESS(custatevecCollapseOnZBasis(
530+
/* custatevecHandle_t */ handle_.get(),
531+
/* void *sv */ BaseType::getData(),
532+
/* cudaDataType_t */ data_type,
533+
/* const uint32_t nIndexBits */ BaseType::getNumQubits(),
534+
/* const int32_t parity */ parity,
535+
/* const int32_t *basisBits */ basisBits.data(),
536+
/* const uint32_t nBasisBits */ basisBits.size(),
537+
/* double norm */ norm));
538+
}
539+
490540
//****************************************************************************//
491541
// Explicit gate calls for bindings
492542
//****************************************************************************//

pennylane_lightning/core/src/simulators/lightning_gpu/bindings/LGPUBindings.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ void registerBackendClassSpecificBindings(PyClass &pyclass) {
150150
},
151151
py::arg("async") = false,
152152
"Initialize the statevector data to the |0...0> state")
153+
.def("collapse", &StateVectorT::collapse,
154+
"Collapse the statevector onto the 0 or 1 branch of a given wire.")
153155
.def(
154156
"apply",
155157
[](StateVectorT &sv, const std::string &str,

pennylane_lightning/core/src/simulators/lightning_gpu/initSV.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ void setBasisState_CUDA(cuDoubleComplex *sv, cuDoubleComplex &value,
5959
cudaStream_t stream_id);
6060

6161
/**
62-
* @brief The CUDA kernel that setS state vector data on GPU device from the
62+
* @brief The CUDA kernel that sets state vector data on GPU device from the
6363
* input values (on device) and their corresponding indices (on device)
6464
* information.
6565
*

pennylane_lightning/core/src/simulators/lightning_gpu/tests/Test_StateVectorCudaManaged.cpp

+42
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,45 @@ TEMPLATE_TEST_CASE("StateVectorCudaManaged::StateVectorCudaManaged",
266266
REQUIRE(std::is_constructible_v<StateVectorT, const StateVectorT &>);
267267
}
268268
}
269+
270+
TEMPLATE_TEST_CASE("StateVectorCudaManaged::collapse",
271+
"[StateVectorCudaManaged]", float, double) {
272+
using PrecisionT = TestType;
273+
using ComplexT = typename StateVectorCudaManaged<PrecisionT>::ComplexT;
274+
using CFP_t = typename StateVectorCudaManaged<PrecisionT>::CFP_t;
275+
using TestVectorT = TestVector<ComplexT>;
276+
277+
std::size_t wire = GENERATE(0, 1, 2);
278+
std::size_t branch = GENERATE(0, 1);
279+
constexpr std::size_t num_qubits = 3;
280+
281+
// TODO @tomlqc use same template for testing all Lightning flavours?
282+
283+
SECTION("Collapse the state vector after having measured one of the "
284+
"qubits.") {
285+
TestVectorT init_state = createPlusState_<ComplexT>(num_qubits);
286+
287+
const ComplexT coef{0.5, PrecisionT{0.0}};
288+
const ComplexT zero{PrecisionT{0.0}, PrecisionT{0.0}};
289+
290+
std::vector<std::vector<std::vector<ComplexT>>> expected_state = {
291+
{{coef, coef, coef, coef, zero, zero, zero, zero},
292+
{coef, coef, zero, zero, coef, coef, zero, zero},
293+
{coef, zero, coef, zero, coef, zero, coef, zero}},
294+
{{zero, zero, zero, zero, coef, coef, coef, coef},
295+
{zero, zero, coef, coef, zero, zero, coef, coef},
296+
{zero, coef, zero, coef, zero, coef, zero, coef}},
297+
};
298+
299+
StateVectorCudaManaged<PrecisionT> sv(
300+
reinterpret_cast<CFP_t *>(init_state.data()), init_state.size());
301+
302+
sv.collapse(wire, branch);
303+
304+
PrecisionT eps = std::numeric_limits<PrecisionT>::epsilon() * 1e2;
305+
REQUIRE(isApproxEqual(sv.getDataVector().data(),
306+
sv.getDataVector().size(),
307+
expected_state[branch][wire].data(),
308+
expected_state[branch][wire].size(), eps));
309+
}
310+
}

pennylane_lightning/core/src/simulators/lightning_gpu/tests/mpi/Test_StateVectorCudaMPI.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -317,4 +317,4 @@ TEMPLATE_PRODUCT_TEST_CASE("StateVectorCudaMPI::applyOperations",
317317
{false, false}, {{0.0}}),
318318
LightningException, "must all be equal"); // invalid parameters
319319
}
320-
}
320+
}

pennylane_lightning/lightning_gpu/_state_vector.py

+32-3
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,17 @@
3636
import numpy as np
3737
import pennylane as qml
3838
from pennylane import DeviceError
39+
from pennylane.measurements import MidMeasureMP
40+
from pennylane.ops import Conditional
3941
from pennylane.ops.op_math import Adjoint
42+
from pennylane.tape import QuantumScript
4043
from pennylane.wires import Wires
4144

4245
# pylint: disable=ungrouped-imports
4346
from pennylane_lightning.core._serialize import global_phase_diagonal
4447
from pennylane_lightning.core._state_vector_base import LightningBaseStateVector
4548

49+
from ._measurements import LightningGPUMeasurements
4650
from ._mpi_handler import MPIHandler
4751

4852
gate_cache_needs_hash = (
@@ -247,15 +251,33 @@ def _apply_lightning_controlled(self, operation):
247251
matrix = global_phase_diagonal(param, self.wires, control_wires, control_values)
248252
state.apply(name, wires, inv, [[param]], matrix)
249253

250-
def _apply_lightning_midmeasure(self):
254+
def _apply_lightning_midmeasure(
255+
self, operation: MidMeasureMP, mid_measurements: dict, postselect_mode: str
256+
):
251257
"""Execute a MidMeasureMP operation and return the sample in mid_measurements.
252258
253259
Args:
260+
operation (~pennylane.operation.Operation): mid-circuit measurement
261+
mid_measurements (None, dict): Dictionary of mid-circuit measurements
262+
postselect_mode (str): Configuration for handling shots with mid-circuit measurement
263+
postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
264+
keep the same number of shots.
254265
255266
Returns:
256267
None
257268
"""
258-
raise DeviceError("LightningGPU does not support Mid-circuit measurements.")
269+
wires = self.wires.indices(operation.wires)
270+
wire = list(wires)[0]
271+
if postselect_mode == "fill-shots" and operation.postselect is not None:
272+
sample = operation.postselect
273+
else:
274+
circuit = QuantumScript([], [qml.sample(wires=operation.wires)], shots=1)
275+
sample = LightningGPUMeasurements(self).measure_final_state(circuit)
276+
sample = np.squeeze(sample)
277+
mid_measurements[operation] = sample
278+
getattr(self.state_vector, "collapse")(wire, bool(sample))
279+
if operation.reset and bool(sample):
280+
self.apply_operations([qml.PauliX(operation.wires)], mid_measurements=mid_measurements)
259281

260282
# pylint: disable=unused-argument
261283
def _apply_lightning(
@@ -289,7 +311,14 @@ def _apply_lightning(
289311
method = getattr(state, name, None)
290312
wires = list(operation.wires)
291313

292-
if method is not None: # apply specialized gate
314+
if isinstance(operation, Conditional):
315+
if operation.meas_val.concretize(mid_measurements):
316+
self._apply_lightning([operation.base])
317+
elif isinstance(operation, MidMeasureMP):
318+
self._apply_lightning_midmeasure(
319+
operation, mid_measurements, postselect_mode=postselect_mode
320+
)
321+
elif method is not None: # apply specialized gate
293322
param = operation.parameters
294323
method(wires, invert_param, param)
295324
elif isinstance(operation, qml.ops.Controlled) and isinstance(

pennylane_lightning/lightning_gpu/lightning_gpu.py

+30-5
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,7 @@ def stopping_condition(op: Operator) -> bool:
173173
def stopping_condition_shots(op: Operator) -> bool:
174174
"""A function that determines whether or not an operation is supported by ``lightning.gpu``
175175
with finite shots."""
176-
if isinstance(op, (MidMeasureMP, qml.ops.op_math.Conditional)):
177-
# LightningGPU does not support Mid-circuit measurements.
178-
return False
179-
return stopping_condition(op)
176+
return stopping_condition(op) or isinstance(op, (MidMeasureMP, qml.ops.op_math.Conditional))
180177

181178

182179
def accepted_observables(obs: Operator) -> bool:
@@ -460,6 +457,7 @@ def execute(
460457
self.simulate(
461458
circuit,
462459
self._statevector,
460+
postselect_mode=execution_config.mcm_config.postselect_mode,
463461
)
464462
)
465463

@@ -494,20 +492,47 @@ def simulate(
494492
self,
495493
circuit: QuantumScript,
496494
state: LightningGPUStateVector,
495+
postselect_mode: Optional[str] = None,
497496
) -> Result:
498497
"""Simulate a single quantum script.
499498
500499
Args:
501500
circuit (QuantumTape): The single circuit to simulate
502501
state (LightningGPUStateVector): handle to Lightning state vector
502+
postselect_mode (str): Configuration for handling shots with mid-circuit measurement
503+
postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
504+
keep the same number of shots. Default is ``None``.
503505
504506
Returns:
505507
Tuple[TensorLike]: The results of the simulation
506508
507509
Note that this function can return measurements for non-commuting observables simultaneously.
508510
"""
509511
if circuit.shots and (any(isinstance(op, MidMeasureMP) for op in circuit.operations)):
510-
raise qml.DeviceError("LightningGPU does not support Mid-circuit measurements.")
512+
if self._mpi_handler.use_mpi:
513+
raise qml.DeviceError(
514+
"Lightning-GPU-MPI does not support Mid-circuit measurements."
515+
)
516+
517+
results = []
518+
aux_circ = QuantumScript(
519+
circuit.operations,
520+
circuit.measurements,
521+
shots=[1],
522+
trainable_params=circuit.trainable_params,
523+
)
524+
for _ in range(circuit.shots.total_shots):
525+
state.reset_state()
526+
mid_measurements = {}
527+
final_state = state.get_final_state(
528+
aux_circ, mid_measurements=mid_measurements, postselect_mode=postselect_mode
529+
)
530+
results.append(
531+
self.LightningMeasurements(final_state).measure_final_state(
532+
aux_circ, mid_measurements=mid_measurements
533+
)
534+
)
535+
return tuple(results)
511536

512537
state.reset_state()
513538
final_state = state.get_final_state(circuit)

pennylane_lightning/lightning_gpu/lightning_gpu.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ qjit_compatible = false
9898
# If the device requires run time generation of the quantum circuit.
9999
runtime_code_generation = false
100100
# If the device supports mid circuit measurements natively
101-
mid_circuit_measurement = false
101+
mid_circuit_measurement = true
102102

103103
# This field is currently unchecked but it is reserved for the purpose of
104104
# determining if the device supports dynamic qubit allocation/deallocation.

tests/test_native_mcm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from conftest import LightningDevice, device_name, validate_measurements
2222
from flaky import flaky
2323

24-
if device_name not in ("lightning.qubit", "lightning.kokkos"):
24+
if device_name not in ("lightning.qubit", "lightning.kokkos", "lightning.gpu"):
2525
pytest.skip("Native MCM not supported. Skipping.", allow_module_level=True)
2626

2727
if not LightningDevice._CPP_BINARY_AVAILABLE: # pylint: disable=protected-access
@@ -89,7 +89,7 @@ def func(x, y):
8989
match=f"not accepted with finite shots on lightning.qubit",
9090
):
9191
func(*params)
92-
if device_name == "lightning.kokkos":
92+
if device_name in ("lightning.kokkos", "lightning.gpu"):
9393
with pytest.raises(
9494
qml.DeviceError,
9595
match=r"Measurement shadow\(wires=\[0\]\) not accepted with finite shots on "

0 commit comments

Comments
 (0)