Skip to content

Commit 3f3ab86

Browse files
Add MCM bindings and tests for L-Kokkos. (#672)
* Add MCM bindings and tests for L-Kokkos. * Auto update version * Refactor collapse. * Try parallelizing tests. * Only run test_native_mcm once amongst all workflows since it takes quite a bit of time. * trigger ci --------- Co-authored-by: Dev version update bot <github-actions[bot]@users.noreply.github.com>
1 parent 31bc86f commit 3f3ab86

File tree

10 files changed

+97
-74
lines changed

10 files changed

+97
-74
lines changed

.github/CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
### New features since last release
44

5+
* `lightning.kokkos` supports mid-circuit measurements.
6+
[(#672)](https://github.com/PennyLaneAI/pennylane-lightning/pull/672)
7+
58
* `lightning.qubit` supports mid-circuit measurements.
69
[(#650)](https://github.com/PennyLaneAI/pennylane-lightning/pull/650)
710

.github/workflows/tests_gpu_cuda.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ jobs:
316316
OMP_PROC_BIND: false
317317
run: |
318318
cd main/
319-
PL_DEVICE=lightning.qubit python -m pytest tests/ $COVERAGE_FLAGS
319+
PL_DEVICE=lightning.qubit python -m pytest tests/ -k "not test_native_mcm" $COVERAGE_FLAGS
320320
pl-device-test --device lightning.qubit --skip-ops --shots=20000 $COVERAGE_FLAGS --cov-append
321321
pl-device-test --device lightning.qubit --shots=None --skip-ops $COVERAGE_FLAGS --cov-append
322322
PL_DEVICE=lightning.gpu python -m pytest tests/ $COVERAGE_FLAGS

.github/workflows/tests_gpu_kokkos.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ jobs:
325325
OMP_PROC_BIND: false
326326
run: |
327327
cd main/
328-
PL_DEVICE=lightning.qubit python -m pytest tests/ $COVERAGE_FLAGS
328+
PL_DEVICE=lightning.qubit python -m pytest tests/ -k "not test_native_mcm" $COVERAGE_FLAGS
329329
pl-device-test --device lightning.qubit --skip-ops --shots=20000 $COVERAGE_FLAGS --cov-append
330330
pl-device-test --device lightning.qubit --shots=None --skip-ops $COVERAGE_FLAGS --cov-append
331331
PL_DEVICE=lightning.kokkos python -m pytest tests/ $COVERAGE_FLAGS

.github/workflows/tests_linux.yml

+8-6
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ env:
2121
COVERAGE_FLAGS: "--cov=pennylane_lightning --cov-report=term-missing --no-flaky-report -p no:warnings --tb=native"
2222
GCC_VERSION: 11
2323
OMP_NUM_THREADS: "2"
24+
OMP_PROC_BIND: "false"
2425

2526
concurrency:
2627
group: tests_linux-${{ github.ref }}-${{ inputs.lightning-version }}-${{ inputs.pennylane-version }}
@@ -393,8 +394,8 @@ jobs:
393394
run: |
394395
cd main/
395396
DEVICENAME=`echo ${{ matrix.pl_backend }} | sed "s/_/./g"`
396-
OMP_NUM_THREADS=1 PL_DEVICE=${DEVICENAME} python -m pytest -n auto tests/ -k "not unitary_correct" $COVERAGE_FLAGS
397-
PL_DEVICE=${DEVICENAME} python -m pytest tests/ -k "unitary_correct" $COVERAGE_FLAGS --cov-append
397+
OMP_NUM_THREADS=1 PL_DEVICE=${DEVICENAME} python -m pytest -n auto tests/ -k "not unitary_correct and not test_native_mcm" $COVERAGE_FLAGS
398+
PL_DEVICE=${DEVICENAME} python -m pytest tests/ -k "unitary_correct and not test_native_mcm" $COVERAGE_FLAGS --cov-append
398399
pl-device-test --device ${DEVICENAME} --skip-ops --shots=20000 $COVERAGE_FLAGS --cov-append
399400
pl-device-test --device ${DEVICENAME} --shots=None --skip-ops $COVERAGE_FLAGS --cov-append
400401
mv .coverage .coverage-${{ github.job }}-${{ matrix.pl_backend }}
@@ -592,7 +593,8 @@ jobs:
592593
run: |
593594
cd main/
594595
DEVICENAME=`echo ${{ matrix.pl_backend }} | sed "s/_/./g"`
595-
PL_DEVICE=${DEVICENAME} python -m pytest tests/ $COVERAGE_FLAGS
596+
PL_DEVICE=${DEVICENAME} python -m pytest tests/ -k "not test_native_mcm" $COVERAGE_FLAGS
597+
OMP_NUM_THREADS=1 PL_DEVICE=${DEVICENAME} python -m pytest -n auto tests/ -k "test_native_mcm" $COVERAGE_FLAGS --cov-append
596598
pl-device-test --device ${DEVICENAME} --skip-ops --shots=20000 $COVERAGE_FLAGS --cov-append
597599
pl-device-test --device ${DEVICENAME} --shots=None --skip-ops $COVERAGE_FLAGS --cov-append
598600
mv .coverage .coverage-${{ github.job }}-${{ matrix.pl_backend }}
@@ -611,11 +613,11 @@ jobs:
611613
if: ${{ matrix.pl_backend == 'all' }}
612614
run: |
613615
cd main/
614-
OMP_NUM_THREADS=1 PL_DEVICE=lightning.qubit python -m pytest -n auto tests/ -k "not unitary_correct" $COVERAGE_FLAGS
615-
PL_DEVICE=lightning.qubit python -m pytest tests/ -k "unitary_correct" $COVERAGE_FLAGS --cov-append
616+
OMP_NUM_THREADS=1 PL_DEVICE=lightning.qubit python -m pytest -n auto tests/ -k "not unitary_correct and not test_native_mcm" $COVERAGE_FLAGS
617+
PL_DEVICE=lightning.qubit python -m pytest tests/ -k "unitary_correct and not test_native_mcm" $COVERAGE_FLAGS --cov-append
616618
pl-device-test --device lightning.qubit --skip-ops --shots=20000 $COVERAGE_FLAGS --cov-append
617619
pl-device-test --device lightning.qubit --shots=None --skip-ops $COVERAGE_FLAGS --cov-append
618-
PL_DEVICE=lightning.kokkos python -m pytest tests/ $COVERAGE_FLAGS --cov-append
620+
PL_DEVICE=lightning.kokkos python -m pytest tests/ -k "not test_native_mcm" $COVERAGE_FLAGS --cov-append
619621
pl-device-test --device lightning.kokkos --skip-ops --shots=20000 $COVERAGE_FLAGS --cov-append
620622
pl-device-test --device lightning.kokkos --shots=None --skip-ops $COVERAGE_FLAGS --cov-append
621623
mv .coverage .coverage-${{ github.job }}-${{ matrix.pl_backend }}

pennylane_lightning/core/src/simulators/lightning_kokkos/StateVectorKokkos.hpp

+14-20
Original file line numberDiff line numberDiff line change
@@ -767,48 +767,42 @@ class StateVectorKokkos final
767767
* @param branch Branch 0 or 1.
768768
*/
769769
void collapse(const std::size_t wire, const bool branch) {
770-
auto &&num_qubits = this->getNumQubits();
771-
772-
const size_t stride = pow(2, num_qubits_ - (1 + wire));
773-
const size_t vec_size = pow(2, num_qubits_);
774-
const auto section_size = vec_size / stride;
775-
const auto half_section_size = section_size / 2;
776-
777-
const size_t negbranch = branch ? 0 : 1;
778-
779-
Kokkos::MDRangePolicy<DoubleLoopRank> policy_2d(
780-
{0, 0}, {half_section_size, stride});
770+
KokkosVector matrix("gate_matrix", 4);
781771
Kokkos::parallel_for(
782-
policy_2d,
783-
collapseFunctor<fp_t>(*data_, num_qubits, stride, negbranch));
784-
772+
matrix.size(), KOKKOS_LAMBDA(const std::size_t k) {
773+
matrix(k) = ((k == 0 && branch == 0) || (k == 3 && branch == 1))
774+
? ComplexT{1.0, 0.0}
775+
: ComplexT{0.0, 0.0};
776+
});
777+
applyMultiQubitOp(matrix, {wire}, false);
785778
normalize();
786779
}
787780

788781
/**
789782
* @brief Normalize vector (to have norm 1).
790783
*/
791784
void normalize() {
792-
KokkosVector sv_view =
793-
getView(); // circumvent error capturing this with KOKKOS_LAMBDA
785+
auto sv_view = getView();
794786

795787
// TODO: @tomlqc what about squaredNorm()
796788
PrecisionT squaredNorm = 0.0;
797789
Kokkos::parallel_reduce(
798790
sv_view.size(),
799-
KOKKOS_LAMBDA(const size_t i, PrecisionT &sum) {
800-
sum += std::norm<PrecisionT>(sv_view(i));
791+
KOKKOS_LAMBDA(const std::size_t i, PrecisionT &sum) {
792+
const PrecisionT norm = Kokkos::abs(sv_view(i));
793+
sum += norm * norm;
801794
},
802795
squaredNorm);
803796

804797
PL_ABORT_IF(squaredNorm <
805798
std::numeric_limits<PrecisionT>::epsilon() * 1e2,
806799
"vector has norm close to zero and can't be normalized");
807800

808-
std::complex<PrecisionT> inv_norm = 1. / std::sqrt(squaredNorm);
801+
const std::complex<PrecisionT> inv_norm =
802+
1. / Kokkos::sqrt(squaredNorm);
809803
Kokkos::parallel_for(
810804
sv_view.size(),
811-
KOKKOS_LAMBDA(const size_t i) { sv_view(i) *= inv_norm; });
805+
KOKKOS_LAMBDA(const std::size_t i) { sv_view(i) *= inv_norm; });
812806
}
813807

814808
/**

pennylane_lightning/core/src/simulators/lightning_kokkos/bindings/LKokkosBindings.hpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,11 @@ void registerBackendClassSpecificBindings(PyClass &pyclass) {
137137
sv.applyOperation(str, wires, inv, std::vector<ParamT>{},
138138
conv_matrix);
139139
},
140-
"Apply operation via the gate matrix");
140+
"Apply operation via the gate matrix")
141+
.def("collapse", &StateVectorT::collapse,
142+
"Collapse the statevector onto the 0 or 1 branch of a given wire.")
143+
.def("normalize", &StateVectorT::normalize,
144+
"Normalizes the statevector to norm 1.");
141145
}
142146

143147
/**

pennylane_lightning/core/src/simulators/lightning_kokkos/gates/GateFunctorsGenerator.hpp

-30
Original file line numberDiff line numberDiff line change
@@ -973,34 +973,4 @@ template <class PrecisionT, bool adj = false> struct generatorMultiRZFunctor {
973973
}
974974
};
975975

976-
template <class PrecisionT> struct collapseFunctor {
977-
using ComplexT = Kokkos::complex<PrecisionT>;
978-
using KokkosComplexVector = Kokkos::View<ComplexT *>;
979-
980-
KokkosComplexVector arr;
981-
std::size_t num_qubits;
982-
std::size_t stride;
983-
std::size_t negbranch;
984-
985-
collapseFunctor(KokkosComplexVector &arr_, std::size_t num_qubits_,
986-
std::size_t stride_, std::size_t negbranch_) {
987-
arr = arr_;
988-
num_qubits = num_qubits_;
989-
stride = stride_;
990-
negbranch = negbranch_;
991-
}
992-
993-
// zero half the entries
994-
// the "half" entries depend on the stride
995-
// *_*_*_*_ for stride 1
996-
// **__**__ for stride 2
997-
// ****____ for stride 4
998-
999-
KOKKOS_INLINE_FUNCTION
1000-
void operator()(const std::size_t left, const std::size_t right) const {
1001-
const size_t offset = stride * (negbranch + 2 * left);
1002-
arr[offset + right] = ComplexT{0., 0.};
1003-
}
1004-
};
1005-
1006976
} // namespace Pennylane::LightningKokkos::Functors

pennylane_lightning/core/src/simulators/lightning_kokkos/measurements/MeasurementsKokkos.hpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414
#pragma once
15+
#include <chrono>
1516

1617
#include <Kokkos_Core.hpp>
1718
#include <Kokkos_Random.hpp>
@@ -682,7 +683,10 @@ class Measurements final
682683
});
683684

684685
// Sampling using Random_XorShift64_Pool
685-
Kokkos::Random_XorShift64_Pool<> rand_pool(5374857);
686+
Kokkos::Random_XorShift64_Pool<> rand_pool(
687+
std::chrono::high_resolution_clock::now()
688+
.time_since_epoch()
689+
.count());
686690

687691
Kokkos::parallel_for(
688692
Kokkos::RangePolicy<KokkosExecSpace>(0, num_samples),

pennylane_lightning/lightning_kokkos/lightning_kokkos.py

+46-7
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from warnings import warn
2222

2323
import numpy as np
24+
from pennylane.measurements import MidMeasureMP
25+
from pennylane.ops import Conditional
2426

2527
from pennylane_lightning.core.lightning_base import (
2628
LightningBase,
@@ -136,6 +138,8 @@ def _kokkos_configuration():
136138
"QFT",
137139
"ECR",
138140
"BlockEncode",
141+
"MidMeasureMP",
142+
"Conditional",
139143
}
140144

141145
allowed_observables = {
@@ -210,6 +214,15 @@ def __init__(
210214
if not LightningKokkos.kokkos_config:
211215
LightningKokkos.kokkos_config = _kokkos_configuration()
212216

217+
# pylint: disable=missing-function-docstring
218+
@classmethod
219+
def capabilities(cls):
220+
capabilities = super().capabilities().copy()
221+
capabilities.update(
222+
supports_mid_measure=True,
223+
)
224+
return capabilities
225+
213226
@staticmethod
214227
def _asarray(arr, dtype=None):
215228
arr = np.asarray(arr) # arr is not copied
@@ -370,7 +383,25 @@ def _apply_basis_state(self, state, wires):
370383
num = self._get_basis_state_index(state, wires)
371384
self._create_basis_state(num)
372385

373-
def apply_lightning(self, operations):
386+
def _apply_lightning_midmeasure(self, operation: MidMeasureMP, mid_measurements: dict):
387+
"""Execute a MidMeasureMP operation and return the sample in mid_measurements.
388+
Args:
389+
operation (~pennylane.operation.Operation): mid-circuit measurement
390+
Returns:
391+
None
392+
"""
393+
wires = self.wires.indices(operation.wires)
394+
wire = list(wires)[0]
395+
sample = qml.math.reshape(self.generate_samples(shots=1), (-1,))[wire]
396+
if operation.postselect is not None and sample != operation.postselect:
397+
mid_measurements[operation] = -1
398+
return
399+
mid_measurements[operation] = sample
400+
getattr(self.state_vector, "collapse")(wire, bool(sample))
401+
if operation.reset and bool(sample):
402+
self.apply([qml.PauliX(operation.wires)], mid_measurements=mid_measurements)
403+
404+
def apply_lightning(self, operations, mid_measurements=None):
374405
"""Apply a list of operations to the state tensor.
375406
376407
Args:
@@ -392,12 +423,17 @@ def apply_lightning(self, operations):
392423
else:
393424
name = ops.name
394425
invert_param = False
395-
if name == "Identity":
426+
if isinstance(ops, qml.Identity):
396427
continue
397428
method = getattr(state, name, None)
398429
wires = self.wires.indices(ops.wires)
399430

400-
if ops.name == "C(GlobalPhase)":
431+
if isinstance(ops, Conditional):
432+
if ops.meas_val.concretize(mid_measurements):
433+
self.apply_lightning([ops.then_op])
434+
elif isinstance(ops, MidMeasureMP):
435+
self._apply_lightning_midmeasure(ops, mid_measurements)
436+
elif ops.name == "C(GlobalPhase)":
401437
controls = ops.control_wires
402438
control_values = ops.control_values
403439
param = ops.base.parameters[0]
@@ -425,7 +461,7 @@ def apply_lightning(self, operations):
425461
method(wires, invert_param, param)
426462

427463
# pylint: disable=unused-argument
428-
def apply(self, operations, rotations=None, **kwargs):
464+
def apply(self, operations, rotations=None, mid_measurements=None, **kwargs):
429465
"""Applies a list of operations to the state tensor."""
430466
# State preparation is currently done in Python
431467
if operations: # make sure operations[0] exists
@@ -445,7 +481,9 @@ def apply(self, operations, rotations=None, **kwargs):
445481
+ f"Operations have already been applied on a {self.short_name} device."
446482
)
447483

448-
self.apply_lightning(operations)
484+
self.apply_lightning(operations, mid_measurements=mid_measurements)
485+
if mid_measurements is not None and any(v == -1 for v in mid_measurements.values()):
486+
self._apply_basis_state(np.zeros(self.num_wires), wires=self.wires)
449487

450488
# pylint: disable=protected-access
451489
def expval(self, observable, shot_range=None, bin_size=None):
@@ -575,19 +613,20 @@ def var(self, observable, shot_range=None, bin_size=None):
575613

576614
return measure.var(observable.name, observable_wires)
577615

578-
def generate_samples(self):
616+
def generate_samples(self, shots=None):
579617
"""Generate samples
580618
581619
Returns:
582620
array[int]: array of samples in binary representation with shape
583621
``(dev.shots, dev.num_wires)``
584622
"""
623+
shots = self.shots if shots is None else shots
585624
measure = (
586625
MeasurementsC64(self._kokkos_state)
587626
if self.use_csingle
588627
else MeasurementsC128(self._kokkos_state)
589628
)
590-
return measure.generate_samples(len(self.wires), self.shots).astype(int, copy=False)
629+
return measure.generate_samples(len(self.wires), shots).astype(int, copy=False)
591630

592631
def probability_lightning(self, wires):
593632
"""Return the probability of each computational basis state.

tests/test_native_mcm.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
from flaky import flaky
2323
from pennylane._device import DeviceError
2424

25-
if not LightningDevice._new_API:
26-
pytest.skip("Exclusive tests for new device API. Skipping.", allow_module_level=True)
25+
if device_name not in ("lightning.qubit", "lightning.kokkos"):
26+
pytest.skip("Native MCM not supported. Skipping.", allow_module_level=True)
2727

2828
if not LightningDevice._CPP_BINARY_AVAILABLE: # pylint: disable=protected-access
2929
pytest.skip("No binary module found. Skipping.", allow_module_level=True)
@@ -79,11 +79,18 @@ def func(x, y):
7979
qml.cond(m0, qml.RY)(y, wires=1)
8080
return qml.classical_shadow(wires=0)
8181

82-
with pytest.raises(
83-
DeviceError,
84-
match=f"not accepted with finite shots on lightning.qubit",
85-
):
86-
func(*params)
82+
if device_name == "lightning.qubit":
83+
with pytest.raises(
84+
DeviceError,
85+
match=f"not accepted with finite shots on lightning.qubit",
86+
):
87+
func(*params)
88+
else:
89+
with pytest.raises(
90+
TypeError,
91+
match=f"Native mid-circuit measurement mode does not support ClassicalShadowMP measurements.",
92+
):
93+
func(*params)
8794

8895

8996
@flaky(max_runs=5)

0 commit comments

Comments
 (0)