21
21
from warnings import warn
22
22
23
23
import numpy as np
24
+ from pennylane .measurements import MidMeasureMP
25
+ from pennylane .ops import Conditional
24
26
25
27
from pennylane_lightning .core .lightning_base import (
26
28
LightningBase ,
@@ -136,6 +138,8 @@ def _kokkos_configuration():
136
138
"QFT" ,
137
139
"ECR" ,
138
140
"BlockEncode" ,
141
+ "MidMeasureMP" ,
142
+ "Conditional" ,
139
143
}
140
144
141
145
allowed_observables = {
@@ -210,6 +214,15 @@ def __init__(
210
214
if not LightningKokkos .kokkos_config :
211
215
LightningKokkos .kokkos_config = _kokkos_configuration ()
212
216
217
+ # pylint: disable=missing-function-docstring
218
+ @classmethod
219
+ def capabilities (cls ):
220
+ capabilities = super ().capabilities ().copy ()
221
+ capabilities .update (
222
+ supports_mid_measure = True ,
223
+ )
224
+ return capabilities
225
+
213
226
@staticmethod
214
227
def _asarray (arr , dtype = None ):
215
228
arr = np .asarray (arr ) # arr is not copied
@@ -370,7 +383,25 @@ def _apply_basis_state(self, state, wires):
370
383
num = self ._get_basis_state_index (state , wires )
371
384
self ._create_basis_state (num )
372
385
373
- def apply_lightning (self , operations ):
386
+ def _apply_lightning_midmeasure (self , operation : MidMeasureMP , mid_measurements : dict ):
387
+ """Execute a MidMeasureMP operation and return the sample in mid_measurements.
388
+ Args:
389
+ operation (~pennylane.operation.Operation): mid-circuit measurement
390
+ Returns:
391
+ None
392
+ """
393
+ wires = self .wires .indices (operation .wires )
394
+ wire = list (wires )[0 ]
395
+ sample = qml .math .reshape (self .generate_samples (shots = 1 ), (- 1 ,))[wire ]
396
+ if operation .postselect is not None and sample != operation .postselect :
397
+ mid_measurements [operation ] = - 1
398
+ return
399
+ mid_measurements [operation ] = sample
400
+ getattr (self .state_vector , "collapse" )(wire , bool (sample ))
401
+ if operation .reset and bool (sample ):
402
+ self .apply ([qml .PauliX (operation .wires )], mid_measurements = mid_measurements )
403
+
404
+ def apply_lightning (self , operations , mid_measurements = None ):
374
405
"""Apply a list of operations to the state tensor.
375
406
376
407
Args:
@@ -392,12 +423,17 @@ def apply_lightning(self, operations):
392
423
else :
393
424
name = ops .name
394
425
invert_param = False
395
- if name == " Identity" :
426
+ if isinstance ( ops , qml . Identity ) :
396
427
continue
397
428
method = getattr (state , name , None )
398
429
wires = self .wires .indices (ops .wires )
399
430
400
- if ops .name == "C(GlobalPhase)" :
431
+ if isinstance (ops , Conditional ):
432
+ if ops .meas_val .concretize (mid_measurements ):
433
+ self .apply_lightning ([ops .then_op ])
434
+ elif isinstance (ops , MidMeasureMP ):
435
+ self ._apply_lightning_midmeasure (ops , mid_measurements )
436
+ elif ops .name == "C(GlobalPhase)" :
401
437
controls = ops .control_wires
402
438
control_values = ops .control_values
403
439
param = ops .base .parameters [0 ]
@@ -425,7 +461,7 @@ def apply_lightning(self, operations):
425
461
method (wires , invert_param , param )
426
462
427
463
# pylint: disable=unused-argument
428
- def apply (self , operations , rotations = None , ** kwargs ):
464
+ def apply (self , operations , rotations = None , mid_measurements = None , ** kwargs ):
429
465
"""Applies a list of operations to the state tensor."""
430
466
# State preparation is currently done in Python
431
467
if operations : # make sure operations[0] exists
@@ -445,7 +481,9 @@ def apply(self, operations, rotations=None, **kwargs):
445
481
+ f"Operations have already been applied on a { self .short_name } device."
446
482
)
447
483
448
- self .apply_lightning (operations )
484
+ self .apply_lightning (operations , mid_measurements = mid_measurements )
485
+ if mid_measurements is not None and any (v == - 1 for v in mid_measurements .values ()):
486
+ self ._apply_basis_state (np .zeros (self .num_wires ), wires = self .wires )
449
487
450
488
# pylint: disable=protected-access
451
489
def expval (self , observable , shot_range = None , bin_size = None ):
@@ -575,19 +613,20 @@ def var(self, observable, shot_range=None, bin_size=None):
575
613
576
614
return measure .var (observable .name , observable_wires )
577
615
578
- def generate_samples (self ):
616
+ def generate_samples (self , shots = None ):
579
617
"""Generate samples
580
618
581
619
Returns:
582
620
array[int]: array of samples in binary representation with shape
583
621
``(dev.shots, dev.num_wires)``
584
622
"""
623
+ shots = self .shots if shots is None else shots
585
624
measure = (
586
625
MeasurementsC64 (self ._kokkos_state )
587
626
if self .use_csingle
588
627
else MeasurementsC128 (self ._kokkos_state )
589
628
)
590
- return measure .generate_samples (len (self .wires ), self . shots ).astype (int , copy = False )
629
+ return measure .generate_samples (len (self .wires ), shots ).astype (int , copy = False )
591
630
592
631
def probability_lightning (self , wires ):
593
632
"""Return the probability of each computational basis state.
0 commit comments