|
36 | 36 | import numpy as np
|
37 | 37 | import pennylane as qml
|
38 | 38 | from pennylane import DeviceError
|
| 39 | +from pennylane.measurements import MidMeasureMP |
| 40 | +from pennylane.ops import Conditional |
39 | 41 | from pennylane.ops.op_math import Adjoint
|
| 42 | +from pennylane.tape import QuantumScript |
40 | 43 | from pennylane.wires import Wires
|
41 | 44 |
|
42 | 45 | # pylint: disable=ungrouped-imports
|
43 | 46 | from pennylane_lightning.core._serialize import global_phase_diagonal
|
44 | 47 | from pennylane_lightning.core._state_vector_base import LightningBaseStateVector
|
45 | 48 |
|
| 49 | +from ._measurements import LightningGPUMeasurements |
46 | 50 | from ._mpi_handler import MPIHandler
|
47 | 51 |
|
48 | 52 | gate_cache_needs_hash = (
|
@@ -247,15 +251,33 @@ def _apply_lightning_controlled(self, operation):
|
247 | 251 | matrix = global_phase_diagonal(param, self.wires, control_wires, control_values)
|
248 | 252 | state.apply(name, wires, inv, [[param]], matrix)
|
249 | 253 |
|
250 |
| - def _apply_lightning_midmeasure(self): |
| 254 | + def _apply_lightning_midmeasure( |
| 255 | + self, operation: MidMeasureMP, mid_measurements: dict, postselect_mode: str |
| 256 | + ): |
251 | 257 | """Execute a MidMeasureMP operation and return the sample in mid_measurements.
|
252 | 258 |
|
253 | 259 | Args:
|
| 260 | + operation (~pennylane.operation.Operation): mid-circuit measurement |
| 261 | + mid_measurements (None, dict): Dictionary of mid-circuit measurements |
| 262 | + postselect_mode (str): Configuration for handling shots with mid-circuit measurement |
| 263 | + postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to |
| 264 | + keep the same number of shots. |
254 | 265 |
|
255 | 266 | Returns:
|
256 | 267 | None
|
257 | 268 | """
|
258 |
| - raise DeviceError("LightningGPU does not support Mid-circuit measurements.") |
| 269 | + wires = self.wires.indices(operation.wires) |
| 270 | + wire = list(wires)[0] |
| 271 | + if postselect_mode == "fill-shots" and operation.postselect is not None: |
| 272 | + sample = operation.postselect |
| 273 | + else: |
| 274 | + circuit = QuantumScript([], [qml.sample(wires=operation.wires)], shots=1) |
| 275 | + sample = LightningGPUMeasurements(self).measure_final_state(circuit) |
| 276 | + sample = np.squeeze(sample) |
| 277 | + mid_measurements[operation] = sample |
| 278 | + getattr(self.state_vector, "collapse")(wire, bool(sample)) |
| 279 | + if operation.reset and bool(sample): |
| 280 | + self.apply_operations([qml.PauliX(operation.wires)], mid_measurements=mid_measurements) |
259 | 281 |
|
260 | 282 | # pylint: disable=unused-argument
|
261 | 283 | def _apply_lightning(
|
@@ -289,7 +311,14 @@ def _apply_lightning(
|
289 | 311 | method = getattr(state, name, None)
|
290 | 312 | wires = list(operation.wires)
|
291 | 313 |
|
292 |
| - if method is not None: # apply specialized gate |
| 314 | + if isinstance(operation, Conditional): |
| 315 | + if operation.meas_val.concretize(mid_measurements): |
| 316 | + self._apply_lightning([operation.base]) |
| 317 | + elif isinstance(operation, MidMeasureMP): |
| 318 | + self._apply_lightning_midmeasure( |
| 319 | + operation, mid_measurements, postselect_mode=postselect_mode |
| 320 | + ) |
| 321 | + elif method is not None: # apply specialized gate |
293 | 322 | param = operation.parameters
|
294 | 323 | method(wires, invert_param, param)
|
295 | 324 | elif isinstance(operation, qml.ops.Controlled) and isinstance(
|
|
0 commit comments