|
21 | 21 | except ImportError:
|
22 | 22 | pass
|
23 | 23 |
|
24 |
| -from itertools import product |
25 |
| - |
26 | 24 | import numpy as np
|
27 | 25 | import pennylane as qml
|
28 | 26 | from pennylane import BasisState, DeviceError, StatePrep
|
@@ -223,20 +221,46 @@ def _preprocess_state_vector(self, state, device_wires):
|
223 | 221 | if len(device_wires) == self._num_wires and Wires(sorted(device_wires)) == device_wires:
|
224 | 222 | return np.reshape(state, output_shape).ravel(order="C")
|
225 | 223 |
|
226 |
| - # generate basis states on subset of qubits via the cartesian product |
227 |
| - basis_states = np.array(list(product([0, 1], repeat=len(device_wires)))) |
| 224 | + local_dev_wires = device_wires.tolist().copy() |
| 225 | + local_dev_wires = local_dev_wires[::-1] |
| 226 | + |
| 227 | + # generate basis states on subset of qubits via broadcasting as substitute of cartesian product. |
| 228 | + |
| 229 | + # Allocate a single row as a base to avoid a large array allocation with |
| 230 | + # the cartesian product algorithm. |
| 231 | + # Initialize the base with the pattern [0 1 0 1 ...]. |
| 232 | + base = np.tile([0, 1], 2 ** (len(local_dev_wires) - 1)).astype(dtype=np.int64) |
| 233 | + # Allocate the array where it will accumulate the value of the indexes depending on |
| 234 | + # the value of the basis. |
| 235 | + indexes = np.zeros(2 ** (len(local_dev_wires)), dtype=np.int64) |
| 236 | + |
| 237 | + max_dev_wire = self._num_wires - 1 |
| 238 | + |
| 239 | + # Iterate over all device wires. |
| 240 | + for i, wire in enumerate(local_dev_wires): |
| 241 | + |
| 242 | + # Accumulate indexes from the basis. |
| 243 | + indexes += base * 2 ** (max_dev_wire - wire) |
| 244 | + |
| 245 | + if i == len(local_dev_wires) - 1: |
| 246 | + continue |
| 247 | + |
| 248 | + two_n = 2 ** (i + 1) # Compute the value of the base. |
228 | 249 |
|
229 |
| - # get basis states to alter on full set of qubits |
230 |
| - unravelled_indices = np.zeros((2 ** len(device_wires), self._num_wires), dtype=int) |
231 |
| - unravelled_indices[:, device_wires] = basis_states |
| 250 | + # Update the value of the base without reallocating a new array. |
| 251 | + # Reshape the basis to swap the internal columns. |
| 252 | + base = base.reshape(-1, two_n * 2) |
| 253 | + swapper_A = two_n // 2 |
| 254 | + swapper_B = swapper_A + two_n |
232 | 255 |
|
233 |
| - # get indices for which the state is changed to input state vector elements |
234 |
| - ravelled_indices = np.ravel_multi_index(unravelled_indices.T, [2] * self._num_wires) |
| 256 | + base[:, swapper_A:swapper_B] = base[:, swapper_A:swapper_B][:, ::-1] |
| 257 | + # Flatten the base array |
| 258 | + base = base.reshape(-1) |
235 | 259 |
|
236 | 260 | # get full state vector to be factorized into MPS
|
237 | 261 | full_state = np.zeros(2**self._num_wires, dtype=self.dtype)
|
238 | 262 | for i, value in enumerate(state):
|
239 |
| - full_state[ravelled_indices[i]] = value |
| 263 | + full_state[indexes[i]] = value |
240 | 264 | return np.reshape(full_state, output_shape).ravel(order="C")
|
241 | 265 |
|
242 | 266 | def _apply_state_vector(self, state, device_wires: Wires):
|
@@ -285,7 +309,7 @@ def _apply_MPO(self, gate_matrix, wires):
|
285 | 309 | None
|
286 | 310 | """
|
287 | 311 | # TODO: Discuss if public interface for max_mpo_bond_dim argument
|
288 |
| - max_mpo_bond_dim = 2 ** len(wires) # Exact SVD decomposition for MPO |
| 312 | + max_mpo_bond_dim = self._max_bond_dim |
289 | 313 |
|
290 | 314 | # Get sorted wires and MPO site tensor
|
291 | 315 | mpos, sorted_wires = gate_matrix_decompose(
|
|
0 commit comments