Skip to content

Commit 4ddcdf2

Browse files
Adding unit test for measurement with shots for LT with tn method (#1027)
**Context:** With Lightning Tensor, ensure `tn` has feature parity with `mps` method. **Description of the Change:** Using the current testing suite for measurement, we change only the `conftest.py` file to make available the `tn` method with Lightning Tensor **Benefits:** Change only a few lines in the code. **Possible Drawbacks:** The testing time will increase double for Lightning Tensor :disappointed: **Related GitHub Issues:** [sc-65726] --------- Co-authored-by: ringo-but-quantum <[email protected]>
1 parent 1782bbd commit 4ddcdf2

File tree

4 files changed

+34
-4
lines changed

4 files changed

+34
-4
lines changed

.github/CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929

3030
### Improvements
3131

32+
* Add unit test for measurement with shots for Lightning Tensor with `tn` method.
33+
[(#1027)](https://github.com/PennyLaneAI/pennylane-lightning/pull/1027)
34+
3235
* Update the python layer UI of Lightning Tensor.
3336
[(#1022)](https://github.com/PennyLaneAI/pennylane-lightning/pull/1022/)
3437

pennylane_lightning/core/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
Version number (major.minor.patch[-label])
1717
"""
1818

19-
__version__ = "0.40.0-dev37"
19+
__version__ = "0.40.0-dev38"

tests/conftest.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -195,12 +195,22 @@ def _device(wires, shots=None):
195195
# General LightningStateVector fixture, for any number of wires.
196196
@pytest.fixture(
197197
scope="function",
198-
params=[np.complex64, np.complex128],
198+
params=(
199+
[np.complex64, np.complex128]
200+
if device_name != "lightning.tensor"
201+
else [
202+
[c_dtype, method]
203+
for c_dtype in [np.complex64, np.complex128]
204+
for method in ["mps", "tn"]
205+
]
206+
),
199207
)
200208
def lightning_sv(request):
201209
def _statevector(num_wires):
202210
if device_name == "lightning.tensor":
203-
return LightningStateVector(num_wires=num_wires, c_dtype=request.param)
211+
return LightningStateVector(
212+
num_wires=num_wires, c_dtype=request.param[0], method=request.param[1]
213+
)
204214
return LightningStateVector(num_wires=num_wires, dtype=request.param)
205215

206216
return _statevector

tests/lightning_qubit/test_measurements_class.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,12 @@ def test_single_return_value(self, shots, measurement, observable, lightning_sv,
545545

546546
# a few tests may fail in single precision, and hence we increase the tolerance
547547
if shots is None:
548-
assert np.allclose(result, expected, max(tol, 1.0e-4))
548+
assert np.allclose(
549+
result,
550+
expected,
551+
max(tol, 1.0e-4),
552+
1e-6 if statevector.dtype == np.complex64 else 1e-8,
553+
)
549554
else:
550555
# TODO Set better atol and rtol
551556
dtol = max(tol, 1.0e-2)
@@ -788,6 +793,9 @@ def test_controlled_qubit_gates(self, operation, n_qubits, control_value, tol, l
788793
tape = qml.tape.QuantumScript(ops, measurements)
789794

790795
statevector = lightning_sv(n_qubits)
796+
if device_name == "lightning.tensor" and statevector.method == "tn":
797+
pytest.skip("StatePrep not supported in lightning.tensor with the tn method.")
798+
791799
statevector = get_final_state(statevector, tape)
792800
m = LightningMeasurements(statevector)
793801
result = measure_final_state(m, tape)
@@ -845,6 +853,9 @@ def test_cnot_controlled_qubit_unitary(self, control_wires, target_wires, tol, l
845853
)
846854

847855
statevector = lightning_sv(n_qubits)
856+
if device_name == "lightning.tensor" and statevector.method == "tn":
857+
pytest.skip("StatePrep not supported in lightning.tensor with the tn method.")
858+
848859
statevector = get_final_state(statevector, tape)
849860
m = LightningMeasurements(statevector)
850861
result = measure_final_state(m, tape)
@@ -889,6 +900,9 @@ def test_controlled_globalphase(self, n_qubits, control_value, tol, lightning_sv
889900
[qml.state()],
890901
)
891902
statevector = lightning_sv(n_qubits)
903+
if device_name == "lightning.tensor" and statevector.method == "tn":
904+
pytest.skip("StatePrep not supported in lightning.tensor with the tn method.")
905+
892906
statevector = get_final_state(statevector, tape)
893907
m = LightningMeasurements(statevector)
894908
result = measure_final_state(m, tape)
@@ -967,6 +981,9 @@ def test_state_vector_2_qubit_subset(tol, op, par, wires, expected, lightning_sv
967981
)
968982

969983
statevector = lightning_sv(2)
984+
if device_name == "lightning.tensor" and statevector.method == "tn":
985+
pytest.skip("StatePrep not supported in lightning.tensor with the tn method.")
986+
970987
statevector = get_final_state(statevector, tape)
971988

972989
m = LightningMeasurements(statevector)

0 commit comments

Comments
 (0)