Skip to content

Commit 87bcd10

Browse files
Optimize lightning.tensor by adding direct MPS sites data set (#983)
**Context:** Optimize `lightning.tensor` by adding direct MPS sites data set **Description of the Change:** Adding the `MPSPrep` gate to be able to pass an MPS directly to the Tensor Network. The `MPSPrep` gate frontend was developed on this [PR](PennyLaneAI/pennylane#6431) **Benefits:** Avoid the decomposition from state vector to MPS sites, which is expensive and inefficient **Possible Drawbacks:** **Related GitHub Issues:** [sc-74709] --------- Co-authored-by: ringo-but-quantum <[email protected]>
1 parent 11b03d4 commit 87bcd10

File tree

12 files changed

+281
-36
lines changed

12 files changed

+281
-36
lines changed

.github/CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929

3030
### Improvements
3131

32+
* Optimize lightning.tensor by adding direct MPS sites data set with `qml.MPSPrep`.
33+
[(#983)](https://github.com/PennyLaneAI/pennylane-lightning/pull/983)
34+
3235
* Replace the `dummy_tensor_update` method with the `cutensornetStateCaptureMPS` API to ensure that further gates apply is allowed after the `cutensornetStateCompute` call.
3336
[(#1028)](https://github.com/PennyLaneAI/pennylane-lightning/pull/1028/)
3437

pennylane_lightning/core/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
Version number (major.minor.patch[-label])
1717
"""
1818

19-
__version__ = "0.40.0-dev41"
19+
__version__ = "0.40.0-dev42"

pennylane_lightning/core/src/simulators/lightning_tensor/tncuda/TNCuda.hpp

+34-26
Original file line numberDiff line numberDiff line change
@@ -60,32 +60,6 @@ class TNCuda : public TNCudaBase<PrecisionT, Derived> {
6060
using ComplexT = std::complex<PrecisionT>;
6161
using BaseType = TNCudaBase<PrecisionT, Derived>;
6262

63-
protected:
64-
// Note both maxBondDim_ and bondDims_ are used for both MPS and Exact
65-
// Tensor Network. Per Exact Tensor Network, maxBondDim_ is 1 and bondDims_
66-
// is {1}. Per Exact Tensor Network, setting bondDims_ allows call to
67-
// appendInitialMPSState_() to append the initial state to the Exact Tensor
68-
// Network state.
69-
const std::size_t
70-
maxBondDim_; // maxBondDim_ default is 1 for Exact Tensor Network
71-
const std::vector<std::size_t>
72-
bondDims_; // bondDims_ default is {1} for Exact Tensor Network
73-
74-
private:
75-
const std::vector<std::vector<std::size_t>> sitesModes_;
76-
const std::vector<std::vector<std::size_t>> sitesExtents_;
77-
const std::vector<std::vector<int64_t>> sitesExtents_int64_;
78-
79-
SharedCublasCaller cublascaller_;
80-
81-
std::shared_ptr<TNCudaGateCache<PrecisionT>> gate_cache_;
82-
std::set<int64_t> gate_ids_;
83-
84-
std::vector<std::size_t> identiy_gate_ids_;
85-
86-
std::vector<TensorCuda<PrecisionT>> tensors_;
87-
std::vector<TensorCuda<PrecisionT>> tensors_out_;
88-
8963
public:
9064
TNCuda() = delete;
9165

@@ -499,7 +473,27 @@ class TNCuda : public TNCudaBase<PrecisionT, Derived> {
499473
projected_mode_values, numHyperSamples);
500474
}
501475

476+
/**
477+
* @brief Get a const vector reference of sitesExtents_.
478+
*
479+
* @return const std::vector<std::vector<std::size_t>>
480+
*/
481+
[[nodiscard]] auto getSitesExtents() const
482+
-> const std::vector<std::vector<std::size_t>> & {
483+
return sitesExtents_;
484+
}
485+
502486
protected:
487+
// Note both maxBondDim_ and bondDims_ are used for both MPS and Exact
488+
// Tensor Network. For Exact Tensor Network, maxBondDim_ is 1 and bondDims_
489+
// is {1}. For Exact Tensor Network, setting bondDims_ allows call to
490+
// appendInitialMPSState_() to append the initial state to the Exact Tensor
491+
// Network state.
492+
const std::size_t
493+
maxBondDim_; // maxBondDim_ default is 1 for Exact Tensor Network
494+
const std::vector<std::size_t>
495+
bondDims_; // bondDims_ default is {1} for Exact Tensor Network
496+
503497
/**
504498
* @brief Get a vector of pointers to tensor data of each site.
505499
*
@@ -578,6 +572,20 @@ class TNCuda : public TNCudaBase<PrecisionT, Derived> {
578572
}
579573

580574
private:
575+
const std::vector<std::vector<std::size_t>> sitesModes_;
576+
const std::vector<std::vector<std::size_t>> sitesExtents_;
577+
const std::vector<std::vector<int64_t>> sitesExtents_int64_;
578+
579+
SharedCublasCaller cublascaller_;
580+
581+
std::shared_ptr<TNCudaGateCache<PrecisionT>> gate_cache_;
582+
std::set<int64_t> gate_ids_;
583+
584+
std::vector<std::size_t> identiy_gate_ids_;
585+
586+
std::vector<TensorCuda<PrecisionT>> tensors_;
587+
std::vector<TensorCuda<PrecisionT>> tensors_out_;
588+
581589
/**
582590
* @brief Get accessor of a state tensor
583591
*

pennylane_lightning/core/src/simulators/lightning_tensor/tncuda/bindings/LTensorTNCudaBindings.hpp

+16-1
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,14 @@
3333
#include "TypeList.hpp"
3434
#include "Util.hpp"
3535
#include "cuda_helpers.hpp"
36+
#include "tncuda_helpers.hpp"
3637

3738
/// @cond DEV
3839
namespace {
3940
using namespace Pennylane;
4041
using namespace Pennylane::Bindings;
4142
using namespace Pennylane::LightningGPU::Util;
42-
using Pennylane::LightningTensor::TNCuda::MPSTNCuda;
43+
using namespace Pennylane::LightningTensor::TNCuda::Util;
4344
} // namespace
4445
/// @endcond
4546

@@ -137,6 +138,20 @@ void registerBackendClassSpecificBindingsMPS(PyClass &pyclass) {
137138
.def(
138139
"updateMPSSitesData",
139140
[](TensorNet &tensor_network, std::vector<np_arr_c> &tensors) {
141+
// Extract the incoming MPS shape
142+
std::vector<std::vector<std::size_t>> MPS_shape_source;
143+
for (std::size_t idx = 0; idx < tensors.size(); idx++) {
144+
py::buffer_info numpyArrayInfo = tensors[idx].request();
145+
auto MPS_site_source_shape = numpyArrayInfo.shape;
146+
std::vector<std::size_t> MPS_site_source(
147+
MPS_site_source_shape.begin(),
148+
MPS_site_source_shape.end());
149+
MPS_shape_source.emplace_back(std::move(MPS_site_source));
150+
}
151+
152+
const auto &MPS_shape_dest = tensor_network.getSitesExtents();
153+
MPSShapeCheck(MPS_shape_dest, MPS_shape_source);
154+
140155
for (std::size_t idx = 0; idx < tensors.size(); idx++) {
141156
py::buffer_info numpyArrayInfo = tensors[idx].request();
142157
auto *data_ptr = static_cast<std::complex<PrecisionT> *>(

pennylane_lightning/core/src/simulators/lightning_tensor/tncuda/tests/Tests_MPSTNCuda.cpp

+62-2
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ TEMPLATE_TEST_CASE("MPSTNCuda::getDataVector()", "[MPSTNCuda]", float, double) {
291291

292292
TEMPLATE_TEST_CASE("MPOTNCuda::getBondDims()", "[MPOTNCuda]", float, double) {
293293
using cp_t = std::complex<TestType>;
294-
SECTION("Check if bondDims is correctly set") {
294+
SECTION("Check if bondDims is correctly set") {
295295
const std::size_t num_qubits = 3;
296296
const std::size_t maxBondDim = 128;
297297
const DevTag<int> dev_tag{0, 0};
@@ -323,4 +323,64 @@ TEMPLATE_TEST_CASE("MPOTNCuda::getBondDims()", "[MPOTNCuda]", float, double) {
323323

324324
CHECK(bondDims == expected_bondDims);
325325
}
326-
}
326+
}
327+
328+
TEMPLATE_TEST_CASE("MPSTNCuda::getSitesExtents()", "[MPSTNCuda]", float,
329+
double) {
330+
SECTION("Check if sitesExtents return is correct with 3 qubits") {
331+
const std::size_t num_qubits = 3;
332+
const std::size_t maxBondDim = 128;
333+
const DevTag<int> dev_tag{0, 0};
334+
335+
const std::vector<std::vector<std::size_t>> reference{
336+
{{2, 2}, {2, 2, 2}, {2, 2}}};
337+
338+
MPSTNCuda<TestType> mps{num_qubits, maxBondDim, dev_tag};
339+
340+
const auto &sitesExtents = mps.getSitesExtents();
341+
342+
CHECK(reference == sitesExtents);
343+
}
344+
345+
SECTION("Check if sitesExtents return is correct with 8 qubits") {
346+
const std::size_t num_qubits = 8;
347+
const std::size_t maxBondDim = 128;
348+
const DevTag<int> dev_tag{0, 0};
349+
350+
const std::vector<std::vector<std::size_t>> reference{{{2, 2},
351+
{2, 2, 4},
352+
{4, 2, 8},
353+
{8, 2, 16},
354+
{16, 2, 8},
355+
{8, 2, 4},
356+
{4, 2, 2},
357+
{2, 2}}};
358+
359+
MPSTNCuda<TestType> mps{num_qubits, maxBondDim, dev_tag};
360+
361+
const auto &sitesExtents = mps.getSitesExtents();
362+
363+
CHECK(reference == sitesExtents);
364+
}
365+
SECTION("Check if sitesExtents return is correct with 8 qubits and "
366+
"maxBondDim=8") {
367+
const std::size_t num_qubits = 8;
368+
const std::size_t maxBondDim = 8;
369+
const DevTag<int> dev_tag{0, 0};
370+
371+
const std::vector<std::vector<std::size_t>> reference{{{2, 2},
372+
{2, 2, 4},
373+
{4, 2, 8},
374+
{8, 2, 8},
375+
{8, 2, 8},
376+
{8, 2, 4},
377+
{4, 2, 2},
378+
{2, 2}}};
379+
380+
MPSTNCuda<TestType> mps{num_qubits, maxBondDim, dev_tag};
381+
382+
const auto &sitesExtents = mps.getSitesExtents();
383+
384+
CHECK(reference == sitesExtents);
385+
}
386+
}

pennylane_lightning/core/src/simulators/lightning_tensor/utils/tncuda_utils/tests/Test_TNCuda_utils.cpp

+49
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,52 @@ TEST_CASE("swap_op_wires_queue", "[TNCuda_utils]") {
8484
REQUIRE(swap_wires_queue[1] == swap_wires_queue_ref1);
8585
}
8686
}
87+
88+
TEST_CASE("MPSShapeCheck", "[TNCuda_utils]") {
89+
SECTION("Correct incoming MPS shape") {
90+
std::vector<std::vector<std::size_t>> MPS_shape_dest{
91+
{2, 2}, {2, 2, 4}, {4, 2, 2}, {2, 2}};
92+
93+
std::vector<std::vector<std::size_t>> MPS_shape_source{
94+
{2, 2}, {2, 2, 4}, {4, 2, 2}, {2, 2}};
95+
96+
REQUIRE_NOTHROW(MPSShapeCheck(MPS_shape_dest, MPS_shape_source));
97+
}
98+
99+
SECTION("Incorrect incoming MPS shape, bond dimension") {
100+
std::vector<std::vector<std::size_t>> MPS_shape_dest{
101+
{2, 2}, {2, 2, 4}, {4, 2, 2}, {2, 2}};
102+
103+
std::vector<std::vector<std::size_t>> incorrect_MPS_shape{
104+
{2, 2}, {2, 2, 2}, {2, 2, 2}, {2, 2}};
105+
106+
REQUIRE_THROWS_WITH(
107+
MPSShapeCheck(MPS_shape_dest, incorrect_MPS_shape),
108+
Catch::Matchers::Contains("The incoming MPS does not have the "
109+
"correct layout for lightning.tensor"));
110+
}
111+
SECTION("Incorrect incoming MPS shape, physical dimension") {
112+
std::vector<std::vector<std::size_t>> MPS_shape_dest{
113+
{2, 2}, {2, 2, 4}, {4, 2, 2}, {2, 2}};
114+
115+
std::vector<std::vector<std::size_t>> incorrect_shape{
116+
{4, 2}, {2, 4, 4}, {4, 4, 2}, {2, 4}};
117+
118+
REQUIRE_THROWS_WITH(
119+
MPSShapeCheck(MPS_shape_dest, incorrect_shape),
120+
Catch::Matchers::Contains("The incoming MPS does not have the "
121+
"correct layout for lightning.tensor"));
122+
}
123+
SECTION("Incorrect incoming MPS shape, number sites") {
124+
std::vector<std::vector<std::size_t>> MPS_shape_dest{
125+
{2, 2}, {2, 2, 4}, {4, 2, 2}, {2, 2}};
126+
127+
std::vector<std::vector<std::size_t>> incorrect_shape{
128+
{2, 2}, {2, 2, 2}, {2, 2}};
129+
130+
REQUIRE_THROWS_WITH(
131+
MPSShapeCheck(MPS_shape_dest, incorrect_shape),
132+
Catch::Matchers::Contains("The incoming MPS does not have the "
133+
"correct layout for lightning.tensor"));
134+
}
135+
}

pennylane_lightning/core/src/simulators/lightning_tensor/utils/tncuda_utils/tncuda_helpers.hpp

+15
Original file line numberDiff line numberDiff line change
@@ -194,4 +194,19 @@ inline auto create_swap_wire_pair_queue(const std::vector<std::size_t> &wires)
194194
return {local_wires, swap_wires_queue};
195195
}
196196

197+
/**
198+
* @brief Check if the provided MPS has the correct dimension for C++
199+
* backend.
200+
*
201+
* @param MPS_shape_dest Dimension list of destination MPS.
202+
* @param MPS_shape_source Dimension list of incoming MPS.
203+
*/
204+
inline void
205+
MPSShapeCheck(const std::vector<std::vector<std::size_t>> &MPS_shape_dest,
206+
const std::vector<std::vector<std::size_t>> &MPS_shape_source) {
207+
PL_ABORT_IF_NOT(MPS_shape_dest == MPS_shape_source,
208+
"The incoming MPS does not have the correct layout for "
209+
"lightning.tensor.")
210+
}
211+
197212
} // namespace Pennylane::LightningTensor::TNCuda::Util

pennylane_lightning/core/src/utils/Util.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -592,4 +592,5 @@ bool areVecsDisjoint(const std::vector<T> &v1, const std::vector<T> &v2) {
592592
}
593593
return true;
594594
}
595+
595596
} // namespace Pennylane::Util

pennylane_lightning/lightning_tensor/_tensornet.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
import numpy as np
3030
import pennylane as qml
31-
from pennylane import BasisState, DeviceError, StatePrep
31+
from pennylane import BasisState, DeviceError, MPSPrep, StatePrep
3232
from pennylane.ops.op_math import Adjoint
3333
from pennylane.tape import QuantumScript
3434
from pennylane.wires import Wires
@@ -433,17 +433,24 @@ def apply_operations(self, operations):
433433
# State preparation is currently done in Python
434434
if operations: # make sure operations[0] exists
435435
if isinstance(operations[0], StatePrep):
436-
if self.method == "tn":
437-
raise DeviceError("Exact Tensor Network does not support StatePrep")
438-
439436
if self.method == "mps":
440437
self._apply_state_vector(
441438
operations[0].parameters[0].copy(), operations[0].wires
442439
)
443440
operations = operations[1:]
441+
if self.method == "tn":
442+
raise DeviceError("Exact Tensor Network does not support StatePrep")
444443
elif isinstance(operations[0], BasisState):
445444
self._apply_basis_state(operations[0].parameters[0], operations[0].wires)
446445
operations = operations[1:]
446+
elif isinstance(operations[0], MPSPrep):
447+
if self.method == "mps":
448+
mps = operations[0].mps
449+
self._tensornet.updateMPSSitesData(mps)
450+
operations = operations[1:]
451+
452+
if self.method == "tn":
453+
raise DeviceError("Exact Tensor Network does not support MPSPrep")
447454

448455
self._apply_lightning(operations)
449456

pennylane_lightning/lightning_tensor/lightning_tensor.py

+4
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
{
7272
"Identity",
7373
"BasisState",
74+
"MPSPrep",
7475
"QubitUnitary",
7576
"ControlledQubitUnitary",
7677
"DiagonalQubitUnitary",
@@ -169,6 +170,9 @@ def stopping_condition(op: Operator) -> bool:
169170
if isinstance(op, qml.ControlledQubitUnitary):
170171
return True
171172

173+
if isinstance(op, qml.MPSPrep):
174+
return True
175+
172176
return op.has_matrix and op.name in _operations
173177

174178

0 commit comments

Comments
 (0)