Commit e1ec3ad

LuisAlfredoNu, ringo-but-quantum, and mlxd authored
Optimize memory peak for _preprocess_state_vector in LightningTensor (#943)
### Before submitting

Please complete the following checklist when submitting a PR:

- [ ] All new features must include a unit test. If you've fixed a bug or added code that should be tested, add a test to the [`tests`](../tests) directory!
- [X] All new functions and code must be clearly commented and documented. If you do make documentation changes, make sure that the docs build and render correctly by running `make docs`.
- [X] Ensure that the test suite passes, by running `make test`.
- [X] Add a new entry to the `.github/CHANGELOG.md` file, summarizing the change, and including a link back to the PR.
- [X] Ensure that code is properly formatted by running `make format`.

When all the above are checked, delete everything above the dashed line and fill in the pull request template.

------------------------------------------------------------------------------------------------------------

**Context:** After profiling Lightning Tensor with the [scalene profiler](https://github.com/plasma-umass/scalene), we found a memory bottleneck in the function `_preprocess_state_vector` ([code](https://github.com/PennyLaneAI/pennylane-lightning/blob/4945ed08d5475d04add8d69a3cf5978ba31d1b39/pennylane_lightning/lightning_tensor/_tensornet.py#L208-L240)), which allocates 3 arrays with dimension 2 ** wires * wires.

**Description of the Change:** Optimize the cartesian product to reduce the amount of memory needed to apply `StatePrep` with Lightning Tensor.

**Benefits:** Reduces the memory peak by half for large systems close to 30 qubits.

![image](https://github.com/user-attachments/assets/431c7b1c-3877-472c-b2d7-e4bc38293d20)

Benchmark code

``` python
import pennylane as qml
import numpy as np

wires = 27
# Random normalized state on every wire except wire 0
state = np.random.rand(2**(wires-1))
state = state / np.linalg.norm(state)

dev = qml.device('lightning.tensor', wires=wires)
dev_wires = dev.wires.tolist()

@qml.qnode(dev)
def circuit(state=state, dev_wires=dev_wires):
    qml.StatePrep(state, wires=dev_wires[1:])
    return qml.expval(qml.Z(0)), qml.state()

# Run the circuit (a bare `return` is not valid at module level)
circuit(state, dev_wires)
```

**Possible Drawbacks:** This change reduces readability in exchange for a substantial memory improvement.

**Related GitHub Issues:** [sc-75692]

---------

Co-authored-by: ringo-but-quantum <[email protected]>
Co-authored-by: Lee James O'Riordan <[email protected]>
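To give a feel for the scale of the index arrays mentioned in the context, here is a rough back-of-the-envelope sketch. It is not a measurement: the helper names `old_index_bytes` and `new_index_bytes` are hypothetical, the array shapes are read off the removed and added lines of `_preprocess_state_vector`, and an 8-byte integer is assumed. It ignores the state vector itself and the temporary Python list built by `itertools.product`, so the ratio it gives will not match the profiled peak.

```python
def old_index_bytes(n_target, n_total, itemsize=8):
    """Approximate bytes used by the index arrays of the product-based code.

    Assumed shapes (taken from the removed lines of the diff):
      basis_states       -> (2**n_target, n_target)
      unravelled_indices -> (2**n_target, n_total)
      ravelled_indices   -> (2**n_target,)
    """
    rows = 2**n_target
    basis_states = rows * n_target * itemsize
    unravelled = rows * n_total * itemsize
    ravelled = rows * itemsize
    return basis_states + unravelled + ravelled


def new_index_bytes(n_target, itemsize=8):
    """Approximate bytes used by `base` and `indexes`, both of length 2**n_target."""
    rows = 2**n_target
    return 2 * rows * itemsize


if __name__ == "__main__":
    # Same sizes as the benchmark: 27 device wires, StatePrep on 26 of them.
    gib = 1024**3
    print(f"product-based index arrays:   {old_index_bytes(26, 27) / gib:.1f} GiB")
    print(f"broadcast-based index arrays: {new_index_bytes(26) / gib:.1f} GiB")
```

The estimate only covers the bookkeeping arrays; how much of the total peak they account for depends on the dtype of the state and on the platform's default integer size.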
1 parent 4945ed0 commit e1ec3ad

5 files changed, with 51 additions and 14 deletions

.github/CHANGELOG.md

+3
@@ -43,6 +43,9 @@

 ### Improvements

+* Optimize the cartesian product to reduce the amount of memory necessary to set the StatePrep with LightningTensor.
+  [(#943)](https://github.com/PennyLaneAI/pennylane-lightning/pull/943)
+
 * The `prob` data return `lightning.gpu` C++ layer is aligned with other state-vector backends and `lightning.gpu` supports out-of-order `qml.prob`.
   [(#941)](https://github.com/PennyLaneAI/pennylane-lightning/pull/941)

.github/workflows/wheel_linux_aarch64.yml

+6-1
@@ -123,8 +123,13 @@ jobs:
           mkdir Kokkos
           cp -rf ${{ github.workspace }}/Kokkos_install/${{ matrix.exec_model }}/* Kokkos/

+      - name: Install Python 3.10
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.10'
+
       - name: Install dependencies
-        run: python -m pip install cibuildwheel~=2.20.0 tomlkit
+        run: python3.10 -m pip install cibuildwheel~=2.20.0 tomlkit

       - name: Configure pyproject.toml file
         run: PL_BACKEND="${{ matrix.pl_backend }}" python scripts/configure_pyproject_toml.py

.github/workflows/wheel_linux_aarch64_cuda.yml

+6-1
@@ -48,8 +48,13 @@ jobs:
       - name: Checkout PennyLane-Lightning
         uses: actions/checkout@v4

+      - name: Install Python 3.10
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.10'
+
       - name: Install cibuildwheel
-        run: python -m pip install cibuildwheel~=2.20.0 tomlkit
+        run: python3.10 -m pip install cibuildwheel~=2.20.0 tomlkit

       - name: Configure pyproject.toml file
         run: PL_BACKEND="${{ matrix.pl_backend }}" python scripts/configure_pyproject_toml.py

pennylane_lightning/core/_version.py

+1-1
@@ -16,4 +16,4 @@
 Version number (major.minor.patch[-label])
 """

-__version__ = "0.39.0-dev42"
+__version__ = "0.39.0-dev43"

pennylane_lightning/lightning_tensor/_tensornet.py

+35-11
@@ -21,8 +21,6 @@
 except ImportError:
     pass

-from itertools import product
-
 import numpy as np
 import pennylane as qml
 from pennylane import BasisState, DeviceError, StatePrep
@@ -223,20 +221,46 @@ def _preprocess_state_vector(self, state, device_wires):
         if len(device_wires) == self._num_wires and Wires(sorted(device_wires)) == device_wires:
             return np.reshape(state, output_shape).ravel(order="C")

-        # generate basis states on subset of qubits via the cartesian product
-        basis_states = np.array(list(product([0, 1], repeat=len(device_wires))))
+        local_dev_wires = device_wires.tolist().copy()
+        local_dev_wires = local_dev_wires[::-1]
+
+        # generate basis states on subset of qubits via broadcasting as substitute of cartesian product.
+
+        # Allocate a single row as a base to avoid a large array allocation with
+        # the cartesian product algorithm.
+        # Initialize the base with the pattern [0 1 0 1 ...].
+        base = np.tile([0, 1], 2 ** (len(local_dev_wires) - 1)).astype(dtype=np.int64)
+        # Allocate the array where it will accumulate the value of the indexes depending on
+        # the value of the basis.
+        indexes = np.zeros(2 ** (len(local_dev_wires)), dtype=np.int64)
+
+        max_dev_wire = self._num_wires - 1
+
+        # Iterate over all device wires.
+        for i, wire in enumerate(local_dev_wires):
+
+            # Accumulate indexes from the basis.
+            indexes += base * 2 ** (max_dev_wire - wire)
+
+            if i == len(local_dev_wires) - 1:
+                continue
+
+            two_n = 2 ** (i + 1) # Compute the value of the base.

-        # get basis states to alter on full set of qubits
-        unravelled_indices = np.zeros((2 ** len(device_wires), self._num_wires), dtype=int)
-        unravelled_indices[:, device_wires] = basis_states
+            # Update the value of the base without reallocating a new array.
+            # Reshape the basis to swap the internal columns.
+            base = base.reshape(-1, two_n * 2)
+            swapper_A = two_n // 2
+            swapper_B = swapper_A + two_n

-        # get indices for which the state is changed to input state vector elements
-        ravelled_indices = np.ravel_multi_index(unravelled_indices.T, [2] * self._num_wires)
+            base[:, swapper_A:swapper_B] = base[:, swapper_A:swapper_B][:, ::-1]
+            # Flatten the base array
+            base = base.reshape(-1)

         # get full state vector to be factorized into MPS
         full_state = np.zeros(2**self._num_wires, dtype=self.dtype)
         for i, value in enumerate(state):
-            full_state[ravelled_indices[i]] = value
+            full_state[indexes[i]] = value
         return np.reshape(full_state, output_shape).ravel(order="C")

     def _apply_state_vector(self, state, device_wires: Wires):
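The hunk above swaps a cartesian-product construction for an in-place update of a single bit-pattern row. As a minimal standalone sketch of why the two give the same indices, the block below lifts both versions out of the class. The function names `indices_via_product` and `indices_via_broadcast` are hypothetical, plain Python lists stand in for PennyLane `Wires`, and at least one target wire is assumed.

```python
from itertools import product

import numpy as np


def indices_via_product(device_wires, num_wires):
    """Removed approach: enumerate every basis state, then ravel to flat indices."""
    basis_states = np.array(list(product([0, 1], repeat=len(device_wires))))
    unravelled = np.zeros((2 ** len(device_wires), num_wires), dtype=int)
    unravelled[:, device_wires] = basis_states
    return np.ravel_multi_index(unravelled.T, [2] * num_wires)


def indices_via_broadcast(device_wires, num_wires):
    """Added approach: keep one bit row per wire and accumulate weighted bits."""
    wires = list(device_wires)[::-1]  # process the fastest-varying wire first
    n = len(wires)
    # Bit pattern of the fastest-varying wire: [0 1 0 1 ...].
    base = np.tile([0, 1], 2 ** (n - 1)).astype(np.int64)
    indexes = np.zeros(2**n, dtype=np.int64)
    max_wire = num_wires - 1

    for i, wire in enumerate(wires):
        # Each set bit on `wire` contributes 2**(max_wire - wire) to the flat index.
        indexes += base * 2 ** (max_wire - wire)
        if i == n - 1:
            continue
        # Double the period of the pattern: [0 1 0 1] -> [0 0 1 1] -> ...
        two_n = 2 ** (i + 1)
        base = base.reshape(-1, two_n * 2)
        start = two_n // 2
        stop = start + two_n
        base[:, start:stop] = base[:, start:stop][:, ::-1]
        base = base.reshape(-1)
    return indexes


if __name__ == "__main__":
    # Target wires 1 and 2 of a 3-wire device, and an unsorted variant.
    print(indices_via_broadcast([1, 2], 3))
    print(indices_via_broadcast([2, 0], 3))
```

Reversing the middle block of each row turns a bit pattern of period `two_n` into one of period `2 * two_n`, which is the pattern of the next, slower-varying wire. That lets a single row be reused instead of storing one column per wire.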
@@ -285,7 +309,7 @@ def _apply_MPO(self, gate_matrix, wires):
             None
         """
         # TODO: Discuss if public interface for max_mpo_bond_dim argument
-        max_mpo_bond_dim = 2 ** len(wires) # Exact SVD decomposition for MPO
+        max_mpo_bond_dim = self._max_bond_dim

         # Get sorted wires and MPO site tensor
         mpos, sorted_wires = gate_matrix_decompose(
