Skip to content

Commit 9e26a91

Browse files
Optimize the input state-vector copy into the LGPU (#1071)
**Context:** After running different algorithm with LGPU and perform a memory profile. Show a memory bottleneck for LGPU on the Python layer because the peak of memory is 3 times the need for the computation. ![image](https://github.com/user-attachments/assets/24eebcf0-49f2-45e1-a63f-71da26cb8dd4) **Description of the Change:** Remove tmp allocation and skip indexes computation for common cases. * Remove temporal GPU allocation for input values and indexes. * The input state vector is copied directly from the host if **the target wires are contiguous and start in the most/least significant wires** (which are the most common cases). * In the case of custom target wires, LGPU follow the previous algorithm but with a speedup in the index computation thought parallel computing with OpenMP **Benefits:** Using a test algorithm with 31 qubits produce the following memory profile: ![newplot(3)](https://github.com/user-attachments/assets/8c74bfb7-f7ed-4759-bd7d-54e23e8e23df) Reduction of the memory peak from 100GB to 66GB Note: `memray` measures all the memory allocation, even for the GPU `cudaMallocX`. Using the following toy circuit ```python state_init = random_normalize_sv(wires-1) target_wires = wires[:-1] dev = qml.device("lightning.gpu", wires=wires) def circuit(): qml.StatePrep(input_state, wires=target_wires) return qml.expval(qml.PauliZ(0)) ``` Produce the following times ![image](https://github.com/user-attachments/assets/9fff27a9-bd3b-42b4-9dd6-88ce3119ee07) ![image](https://github.com/user-attachments/assets/e151ed5e-4a0f-445e-af1c-3de502617f44) **Possible Drawbacks:** **Related GitHub Issues:** [sc-58833] --------- Co-authored-by: ringo-but-quantum <[email protected]>
1 parent 109db9f commit 9e26a91

File tree

7 files changed

+239
-26
lines changed

7 files changed

+239
-26
lines changed

.github/CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
### Improvements
88

9+
* Optimize the copy of a input state-vector into the LGPU #1071
10+
[(#1071)](https://github.com/PennyLaneAI/pennylane-lightning/pull/1071)
11+
912
* Fix wheel naming in Docker builds for `setuptools v75.8.1` compatibility.
1013
[(#1075)](https://github.com/PennyLaneAI/pennylane-lightning/pull/1075)
1114

pennylane_lightning/core/_state_vector_base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def apply_operations(
184184
# State preparation is currently done in Python
185185
if operations: # make sure operations[0] exists
186186
if isinstance(operations[0], StatePrep):
187-
self._apply_state_vector(operations[0].parameters[0].copy(), operations[0].wires)
187+
self._apply_state_vector(operations[0].parameters[0], operations[0].wires)
188188
operations = operations[1:]
189189
elif isinstance(operations[0], BasisState):
190190
self._apply_basis_state(operations[0].parameters[0], operations[0].wires)

pennylane_lightning/core/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
Version number (major.minor.patch[-label])
1717
"""
1818

19-
__version__ = "0.41.0-dev28"
19+
__version__ = "0.41.0-dev29"

pennylane_lightning/core/src/simulators/lightning_gpu/StateVectorCudaManaged.hpp

+78-19
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
*/
1717
#pragma once
1818

19+
#include <algorithm>
1920
#include <random>
2021
#include <type_traits>
2122
#include <unordered_map>
@@ -159,8 +160,8 @@ class StateVectorCudaManaged
159160
*/
160161
void resetStateVector(bool use_async = false) {
161162
BaseType::getDataBuffer().zeroInit();
162-
std::size_t index = 0;
163-
ComplexT value(1.0, 0.0);
163+
constexpr std::size_t index = 0;
164+
constexpr ComplexT value(1.0, 0.0);
164165
setBasisState_(value, index, use_async);
165166
};
166167

@@ -200,14 +201,14 @@ class StateVectorCudaManaged
200201
* @brief Set values for a batch of elements of the state-vector.
201202
*
202203
* @param state_ptr Pointer to the initial state data.
203-
* @param num_states Length of the initial state data.
204+
* @param state_size Length of the initial state data.
204205
* @param wires Wires.
205206
* @param use_async Use an asynchronous memory copy. Default is false.
206207
*/
207-
void setStateVector(const ComplexT *state_ptr, const std::size_t num_states,
208+
void setStateVector(const ComplexT *state_ptr, const std::size_t state_size,
208209
const std::vector<std::size_t> &wires,
209210
bool use_async = false) {
210-
PL_ABORT_IF_NOT(num_states == Pennylane::Util::exp2(wires.size()),
211+
PL_ABORT_IF_NOT(state_size == Pennylane::Util::exp2(wires.size()),
211212
"Inconsistent state and wires dimensions.");
212213

213214
const auto num_qubits = BaseType::getNumQubits();
@@ -222,21 +223,45 @@ class StateVectorCudaManaged
222223
typename std::conditional<std::is_same<PrecisionT, float>::value,
223224
int32_t, int64_t>::type;
224225

225-
// Calculate the indices of the state-vector to be set.
226-
// TODO: Could move to GPU calculation if the state size is large.
227-
std::vector<index_type> indices(num_states);
228-
const std::size_t num_wires = wires.size();
229-
constexpr std::size_t one{1U};
230-
for (std::size_t i = 0; i < num_states; i++) {
231-
std::size_t index{0U};
232-
for (std::size_t j = 0; j < num_wires; j++) {
233-
const std::size_t bit = (i & (one << j)) >> j;
234-
index |= bit << (num_qubits - 1 - wires[num_wires - 1 - j]);
226+
const bool is_wires_sorted_contiguous =
227+
std::is_sorted(wires.begin(), wires.end()) &&
228+
wires.front() + wires.size() - 1 == wires.back();
229+
230+
const bool is_left_significant = wires.front() == 0;
231+
const bool is_side_significant =
232+
is_left_significant || wires.back() == num_qubits - 1;
233+
234+
if (is_wires_sorted_contiguous && is_side_significant) {
235+
// Set most common case: contiguous wires
236+
setSortedContiguousStateVector_<index_type>(
237+
state_size, state_ptr, wires, is_left_significant, use_async);
238+
} else {
239+
// Set the state-vector for non-contiguous wires
240+
std::vector<index_type> indices(state_size);
241+
242+
// Calculate the indices of the state-vector to be set.
243+
// TODO: Could move to GPU calculation if the state size is large.
244+
#pragma omp parallel shared(state_size, num_qubits, indices, wires)
245+
{
246+
const std::size_t num_wires = wires.size();
247+
auto local_wires = wires;
248+
249+
#pragma omp for
250+
for (std::size_t i = 0; i < state_size; i++) {
251+
constexpr std::size_t one{1U};
252+
std::size_t index{0U};
253+
for (std::size_t j = 0; j < num_wires; j++) {
254+
const std::size_t bit = (i & (one << j)) >> j;
255+
index |= bit << (num_qubits - 1 -
256+
local_wires[num_wires - 1 - j]);
257+
}
258+
indices[i] = static_cast<index_type>(index);
259+
}
235260
}
236-
indices[i] = static_cast<index_type>(index);
261+
// set the state-vector
262+
setStateVector_<index_type>(state_size, state_ptr, indices.data(),
263+
use_async);
237264
}
238-
setStateVector_<index_type>(num_states, state_ptr, indices.data(),
239-
use_async);
240265
}
241266

242267
/**
@@ -2128,6 +2153,40 @@ class StateVectorCudaManaged
21282153
stream_id);
21292154
}
21302155

2156+
/**
2157+
* @brief Set values for a batch of elements of the state-vector. This
2158+
* method is implemented by the customized CUDA kernel defined in the
2159+
* DataBuffer class.
2160+
*
2161+
* @tparam index_type Integer value type.
2162+
*
2163+
* @param num_indices Number of elements to be passed to the state vector.
2164+
* @param values Pointer to values to be set for the target elements.
2165+
* @param wires Wires of the target elements.
2166+
* @param is_left_significant If true, the target wires start from zero.
2167+
* Otherwise, the last target wire matches the last qubit.
2168+
* @param async Use an asynchronous memory copy.
2169+
*/
2170+
template <class index_type>
2171+
void setSortedContiguousStateVector_(const index_type num_indices,
2172+
const std::complex<PrecisionT> *values,
2173+
const std::vector<std::size_t> &wires,
2174+
const bool is_left_significant = false,
2175+
const bool async = false) {
2176+
BaseType::getDataBuffer().zeroInit();
2177+
2178+
if (is_left_significant) {
2179+
size_t stride = std::size_t(1)
2180+
<< (BaseType::getNumQubits() - wires.size());
2181+
BaseType::getDataBuffer().CopyHostDataToGpuWithStride(
2182+
values, num_indices, stride, async);
2183+
} else {
2184+
BaseType::getDataBuffer().CopyHostDataToGpu(values, num_indices,
2185+
std::size_t(0), async);
2186+
}
2187+
PL_CUDA_IS_SUCCESS(cudaDeviceSynchronize());
2188+
}
2189+
21312190
/**
21322191
* @brief Set values for a batch of elements of the state-vector. This
21332192
* method is implemented by the customized CUDA kernel defined in the
@@ -2140,7 +2199,7 @@ class StateVectorCudaManaged
21402199
*/
21412200
template <class index_type, std::size_t thread_per_block = 256>
21422201
void setStateVector_(const index_type num_indices,
2143-
const std::complex<Precision> *values,
2202+
const std::complex<PrecisionT> *values,
21442203
const index_type *indices, const bool async = false) {
21452204
BaseType::getDataBuffer().zeroInit();
21462205

pennylane_lightning/core/src/simulators/lightning_gpu/gates/tests/Test_StateVectorCudaManaged_NonParam.cpp

+88-4
Original file line numberDiff line numberDiff line change
@@ -1108,11 +1108,50 @@ TEMPLATE_TEST_CASE("StateVectorCudaManaged::SetStateVector",
11081108
"the host") {
11091109
auto init_state =
11101110
createRandomStateVectorData<PrecisionT>(re, num_qubits);
1111-
auto expected_state = init_state;
11121111

1112+
StateVectorCudaManaged<TestType> sv{num_qubits};
1113+
1114+
std::vector<std::complex<PrecisionT>> values(init_state.begin(),
1115+
init_state.end());
1116+
1117+
sv.setStateVector(values.data(), values.size(),
1118+
std::vector<std::size_t>{0, 1, 2});
1119+
CHECK(init_state == Pennylane::Util::approx(sv.getDataVector()));
1120+
}
1121+
1122+
SECTION("Set state vector with values and their corresponding indices on "
1123+
"the host for a subset of wires right significant") {
1124+
auto init_state =
1125+
createRandomStateVectorData<PrecisionT>(re, num_qubits - 1);
1126+
1127+
std::vector<std::complex<PrecisionT>> expected_state(
1128+
Pennylane::Util::exp2(num_qubits), {0, 0});
1129+
1130+
std::copy(init_state.begin(), init_state.end(), expected_state.begin());
1131+
1132+
StateVectorCudaManaged<TestType> sv{num_qubits};
1133+
1134+
std::vector<std::complex<PrecisionT>> values(init_state.begin(),
1135+
init_state.end());
1136+
1137+
sv.setStateVector(values.data(), values.size(),
1138+
std::vector<std::size_t>{1, 2});
1139+
CHECK(expected_state == Pennylane::Util::approx(sv.getDataVector()));
1140+
}
1141+
1142+
SECTION("Set state vector with values and their corresponding indices on "
1143+
"the host for a subset of wires left significant") {
1144+
auto init_state =
1145+
createRandomStateVectorData<PrecisionT>(re, num_qubits - 1);
1146+
1147+
std::vector<std::complex<PrecisionT>> expected_state(
1148+
Pennylane::Util::exp2(num_qubits), {0, 0});
1149+
1150+
// Distributing along the base vector with a stride.
1151+
// Stride is 2**(n_qubits - n_target_wires)
11131152
for (std::size_t i = 0; i < Pennylane::Util::exp2(num_qubits - 1);
11141153
i++) {
1115-
std::swap(expected_state[i * 2], expected_state[i * 2 + 1]);
1154+
expected_state[i * 2] = init_state[i];
11161155
}
11171156

11181157
StateVectorCudaManaged<TestType> sv{num_qubits};
@@ -1121,8 +1160,53 @@ TEMPLATE_TEST_CASE("StateVectorCudaManaged::SetStateVector",
11211160
init_state.end());
11221161

11231162
sv.setStateVector(values.data(), values.size(),
1124-
std::vector<std::size_t>{0, 1, 2});
1125-
CHECK(init_state == Pennylane::Util::approx(sv.getDataVector()));
1163+
std::vector<std::size_t>{0, 1});
1164+
CHECK(expected_state == Pennylane::Util::approx(sv.getDataVector()));
1165+
}
1166+
SECTION("Set state vector with values and their corresponding indices on "
1167+
"the host for a subset of wires non-consecutive") {
1168+
auto init_state =
1169+
createRandomStateVectorData<PrecisionT>(re, num_qubits - 1);
1170+
1171+
std::vector<std::complex<PrecisionT>> expected_state(
1172+
Pennylane::Util::exp2(num_qubits), {0, 0});
1173+
1174+
expected_state[0] = init_state[0];
1175+
expected_state[1] = init_state[1];
1176+
expected_state[4] = init_state[2];
1177+
expected_state[5] = init_state[3];
1178+
1179+
StateVectorCudaManaged<TestType> sv{num_qubits};
1180+
1181+
std::vector<std::complex<PrecisionT>> values(init_state.begin(),
1182+
init_state.end());
1183+
1184+
sv.setStateVector(values.data(), values.size(),
1185+
std::vector<std::size_t>{0, 2});
1186+
CHECK(expected_state == Pennylane::Util::approx(sv.getDataVector()));
1187+
}
1188+
SECTION("Set state vector with values and their corresponding indices on "
1189+
"the host for a subset of wires consecutive and not significant") {
1190+
std::size_t num_qubits_local = 4;
1191+
auto init_state =
1192+
createRandomStateVectorData<PrecisionT>(re, num_qubits_local - 2);
1193+
1194+
std::vector<std::complex<PrecisionT>> expected_state(
1195+
Pennylane::Util::exp2(num_qubits_local), {0, 0});
1196+
1197+
expected_state[0] = init_state[0];
1198+
expected_state[2] = init_state[1];
1199+
expected_state[4] = init_state[2];
1200+
expected_state[6] = init_state[3];
1201+
1202+
StateVectorCudaManaged<TestType> sv{num_qubits_local};
1203+
1204+
std::vector<std::complex<PrecisionT>> values(init_state.begin(),
1205+
init_state.end());
1206+
1207+
sv.setStateVector(values.data(), values.size(),
1208+
std::vector<std::size_t>{1, 2});
1209+
CHECK(expected_state == Pennylane::Util::approx(sv.getDataVector()));
11261210
}
11271211
}
11281212

pennylane_lightning/core/src/utils/cuda_utils/DataBuffer.hpp

+67
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,73 @@ template <class GPUDataT, class DevTagT = int> class DataBuffer {
213213
}
214214
}
215215

216+
/**
217+
* @brief Explicitly copy data from host memory to GPU device with an
218+
* offset.
219+
*
220+
* @tparam HostDataT Host data type.
221+
*
222+
* @param host_in Host data buffer.
223+
* @param length Number of elements to copy.
224+
* @param offset Offset in the GPU buffer.
225+
* @param async Asynchronous copy flag.
226+
*
227+
*/
228+
template <class HostDataT = GPUDataT>
229+
void CopyHostDataToGpu(const HostDataT *host_in, std::size_t length,
230+
std::size_t offset, bool async = false) {
231+
PL_ABORT_IF(
232+
(getLength() * sizeof(GPUDataT)) <
233+
((offset + length) * sizeof(HostDataT)),
234+
"Sizes do not match for host & GPU data. Please ensure the source "
235+
"buffer is out of bounds of the destination buffer");
236+
237+
if (async) {
238+
PL_CUDA_IS_SUCCESS(cudaMemcpyAsync(
239+
getData() + offset, host_in, sizeof(GPUDataT) * length,
240+
cudaMemcpyHostToDevice, getStream()));
241+
} else {
242+
PL_CUDA_IS_SUCCESS(cudaMemcpy(getData() + offset, host_in,
243+
sizeof(GPUDataT) * length,
244+
cudaMemcpyDefault));
245+
}
246+
}
247+
248+
/**
249+
* @brief Explicitly copy data from host memory to GPU device with a stride.
250+
*
251+
* @tparam HostDataT Host data type.
252+
*
253+
* @param host_in Host data buffer.
254+
* @param length Number of elements to copy.
255+
* @param stride Stride in the GPU buffer.
256+
* @param async Asynchronous copy flag.
257+
*
258+
*/
259+
template <class HostDataT = GPUDataT>
260+
void CopyHostDataToGpuWithStride(const HostDataT *host_in,
261+
std::size_t length, std::size_t stride,
262+
bool async = false) {
263+
PL_ABORT_IF(
264+
(getLength() * sizeof(GPUDataT)) <
265+
((stride * length) * sizeof(HostDataT)),
266+
"Sizes do not match for host & GPU data. Please ensure the source "
267+
"buffer is out of bounds of the destination buffer or the stride "
268+
"is too large");
269+
270+
if (async) {
271+
PL_CUDA_IS_SUCCESS(
272+
cudaMemcpy2DAsync(getData(), sizeof(GPUDataT) * stride, host_in,
273+
sizeof(HostDataT), sizeof(HostDataT), length,
274+
cudaMemcpyHostToDevice, getStream()));
275+
} else {
276+
PL_CUDA_IS_SUCCESS(
277+
cudaMemcpy2D(getData(), sizeof(GPUDataT) * stride, host_in,
278+
sizeof(HostDataT), sizeof(HostDataT), length,
279+
cudaMemcpyHostToDevice));
280+
}
281+
}
282+
216283
/**
217284
* @brief Explicitly copy data from GPU device to host memory.
218285
*

pennylane_lightning/lightning_gpu/_state_vector.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,10 @@ def _apply_state_vector(self, state, device_wires, use_async: bool = False):
214214
# state = state_data
215215

216216
state = self._asarray(state, dtype=self.dtype) # this operation on host
217-
output_shape = [2] * self._num_local_wires
218217

219218
if len(device_wires) == self.num_wires and Wires(sorted(device_wires)) == device_wires:
220219
# Initialize the entire device state with the input state
220+
output_shape = [2] * self._num_local_wires
221221
if self.num_wires == self._num_local_wires:
222222
self.syncH2D(np.reshape(state, output_shape))
223223
return

0 commit comments

Comments
 (0)