Skip to content

Commit 74ba19e

Browse files
[Op. Arithm.] Improve performance #1 (#3022)
* add hash * use hash to compute eigendecomposition * use eigvals of factors if they commute * improve prod * merge * add hash to sum * add tests * revert * docs * refactor * remove adjoint * chore (changelog): ✏️ Add feature to changelog. * chore (changelog): ✏️ Add feature to changelog. * Update doc/releases/changelog-dev.md Co-authored-by: Christina Lee <[email protected]> Co-authored-by: Christina Lee <[email protected]>
1 parent 54f76a9 commit 74ba19e

File tree

5 files changed

+214
-68
lines changed

5 files changed

+214
-68
lines changed

doc/releases/changelog-dev.md

+7-2
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@
266266
* `Controlled` operators now work with `qml.is_commuting`.
267267
[(#2994)](https://github.com/PennyLaneAI/pennylane/pull/2994)
268268

269-
* `Prod` and `Sum` class now support the `sparse_matrix()` method.
269+
* `Prod` and `Sum` class now support the `sparse_matrix()` method.
270270
[(#3006)](https://github.com/PennyLaneAI/pennylane/pull/3006)
271271

272272
```pycon
@@ -291,6 +291,11 @@
291291
depth greater than 0. The `__repr__` for `Controlled` show `control_wires` instead of `wires`.
292292
[(#3013)](https://github.com/PennyLaneAI/pennylane/pull/3013)
293293

294+
* Use `Operator.hash` instead of `Operator.matrix` to cache the eigendecomposition results in `Prod` and
295+
`Sum` classes. When `Prod` and `Sum` operators have no overlapping wires, compute the eigenvalues
296+
and the diagonalising gates using the factors/summands instead of using the full matrix.
297+
[(#3022)](https://github.com/PennyLaneAI/pennylane/pull/3022)
298+
294299
<h3>Breaking changes</h3>
295300

296301
* Measuring an operator that might not be hermitian as an observable now raises a warning instead of an
@@ -323,7 +328,7 @@
323328
* Fixes a bug where the tape transform `single_qubit_fusion` computed wrong rotation angles
324329
for specific combinations of rotations.
325330
[(#3024)](https://github.com/PennyLaneAI/pennylane/pull/3024)
326-
331+
327332
* Jax gradients now work with a QNode when the quantum function was transformed by `qml.simplify`.
328333
[(#3017)](https://github.com/PennyLaneAI/pennylane/pull/3017)
329334

pennylane/ops/op_math/prod.py

+33-12
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def __init__(
167167
self.factors = factors
168168
self._wires = qml.wires.Wires.all_wires([f.wires for f in self.factors])
169169
self._hash = None
170+
self._has_overlapping_wires = None
170171

171172
if do_queue:
172173
self.queue()
@@ -226,6 +227,16 @@ def is_hermitian(self):
226227
return False
227228
return all(op.is_hermitian for op in self.factors)
228229

230+
@property
231+
def has_overlapping_wires(self) -> bool:
232+
"""Boolean expression that indicates if the factors have overlapping wires."""
233+
if self._has_overlapping_wires is None:
234+
wires = []
235+
for op in self.factors:
236+
wires.extend(list(op.wires))
237+
self._has_overlapping_wires = len(wires) != len(set(wires))
238+
return self._has_overlapping_wires
239+
229240
def decomposition(self):
230241
r"""Decomposition of the product operator is given by each factor applied in succession.
231242
@@ -250,14 +261,13 @@ def eigendecomposition(self):
250261
dict[str, array]: dictionary containing the eigenvalues and the
251262
eigenvectors of the operator.
252263
"""
253-
Hmat = self.matrix()
254-
Hmat = math.to_numpy(Hmat)
255-
Hkey = tuple(Hmat.flatten().tolist())
256-
if Hkey not in self._eigs:
264+
if self.hash not in self._eigs:
265+
Hmat = self.matrix()
266+
Hmat = math.to_numpy(Hmat)
257267
w, U = np.linalg.eigh(Hmat)
258-
self._eigs[Hkey] = {"eigvec": U, "eigval": w}
268+
self._eigs[self.hash] = {"eigvec": U, "eigval": w}
259269

260-
return self._eigs[Hkey]
270+
return self._eigs[self.hash]
261271

262272
def diagonalizing_gates(self):
263273
r"""Sequence of gates that diagonalize the operator in the computational basis.
@@ -276,20 +286,31 @@ def diagonalizing_gates(self):
276286
Returns:
277287
list[.Operator] or None: a list of operators
278288
"""
279-
280-
eigen_vectors = self.eigendecomposition["eigvec"]
281-
return [qml.QubitUnitary(eigen_vectors.conj().T, wires=self.wires)]
289+
if self.has_overlapping_wires:
290+
eigen_vectors = self.eigendecomposition["eigvec"]
291+
return [qml.QubitUnitary(eigen_vectors.conj().T, wires=self.wires)]
292+
diag_gates = []
293+
for factor in self.factors:
294+
diag_gates.extend(factor.diagonalizing_gates())
295+
return diag_gates
282296

283297
def eigvals(self):
284-
r"""Return the eigenvalues of the specified operator.
298+
"""Return the eigenvalues of the specified operator.
285299
286300
This method uses pre-stored eigenvalues for standard observables where
287301
possible and stores the corresponding eigenvectors from the eigendecomposition.
288302
289303
Returns:
290304
array: array containing the eigenvalues of the operator
291305
"""
292-
return self.eigendecomposition["eigval"]
306+
if self.has_overlapping_wires:
307+
return self.eigendecomposition["eigval"]
308+
eigvals = [
309+
qml.utils.expand_vector(factor.eigvals(), list(factor.wires), list(self.wires))
310+
for factor in self.factors
311+
]
312+
313+
return qml.math.prod(eigvals, axis=0)
293314

294315
def matrix(self, wire_order=None):
295316
"""Representation of the operator as a matrix in the computational basis."""
@@ -423,7 +444,7 @@ def hash(self):
423444
return self._hash
424445

425446

426-
def _prod_sort(op_list, wire_map: dict = None):
447+
def _prod_sort(op_list, wire_map: dict = None) -> List[Operator]:
427448
"""Insertion sort algorithm that sorts a list of product factors by their wire indices, taking
428449
into account the operator commutivity.
429450

pennylane/ops/op_math/sum.py

+73-12
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ def __init__(
153153
self._name = "Sum"
154154
self._id = id
155155
self.queue_idx = None
156+
self._hash = None
157+
self._has_overlapping_wires = None
156158

157159
if len(summands) < 2:
158160
raise ValueError(f"Require at least two operators to sum; got {len(summands)}")
@@ -202,6 +204,16 @@ def is_hermitian(self):
202204
"""If all of the terms in the sum are hermitian, then the Sum is hermitian."""
203205
return all(s.is_hermitian for s in self.summands)
204206

207+
@property
208+
def has_overlapping_wires(self) -> bool:
209+
"""Boolean expression that indicates if the factors have overlapping wires."""
210+
if self._has_overlapping_wires is None:
211+
wires = []
212+
for op in self.summands:
213+
wires.extend(list(op.wires))
214+
self._has_overlapping_wires = len(wires) != len(set(wires))
215+
return self._has_overlapping_wires
216+
205217
def terms(self):
206218
r"""Representation of the operator as a linear combination of other operators.
207219
@@ -219,7 +231,7 @@ def terms(self):
219231

220232
@property
221233
def eigendecomposition(self):
222-
r"""Return the eigendecomposition of the matrix specified by the Hermitian observable.
234+
r"""Return the eigendecomposition of the matrix specified by the operator.
223235
224236
This method uses pre-stored eigenvalues for standard observables where
225237
possible and stores the corresponding eigenvectors from the eigendecomposition.
@@ -228,16 +240,15 @@ def eigendecomposition(self):
228240
229241
Returns:
230242
dict[str, array]: dictionary containing the eigenvalues and the eigenvectors of the
231-
operator
243+
operator.
232244
"""
233-
Hmat = self.matrix()
234-
Hmat = qml.math.to_numpy(Hmat)
235-
Hkey = tuple(Hmat.flatten().tolist())
236-
if Hkey not in self._eigs:
245+
if self.hash not in self._eigs:
246+
Hmat = self.matrix()
247+
Hmat = math.to_numpy(Hmat)
237248
w, U = np.linalg.eigh(Hmat)
238-
self._eigs[Hkey] = {"eigvec": U, "eigval": w}
249+
self._eigs[self.hash] = {"eigvec": U, "eigval": w}
239250

240-
return self._eigs[Hkey]
251+
return self._eigs[self.hash]
241252

242253
def diagonalizing_gates(self):
243254
r"""Sequence of gates that diagonalize the operator in the computational basis.
@@ -256,9 +267,13 @@ def diagonalizing_gates(self):
256267
Returns:
257268
list[.Operator] or None: a list of operators
258269
"""
259-
260-
eigen_vectors = self.eigendecomposition["eigvec"]
261-
return [qml.QubitUnitary(eigen_vectors.conj().T, wires=self.wires)]
270+
if self.has_overlapping_wires:
271+
eigen_vectors = self.eigendecomposition["eigvec"]
272+
return [qml.QubitUnitary(eigen_vectors.conj().T, wires=self.wires)]
273+
diag_gates = []
274+
for summand in self.summands:
275+
diag_gates.extend(summand.diagonalizing_gates())
276+
return diag_gates
262277

263278
def eigvals(self):
264279
r"""Return the eigenvalues of the specified Hermitian observable.
@@ -269,7 +284,13 @@ def eigvals(self):
269284
Returns:
270285
array: array containing the eigenvalues of the Hermitian observable
271286
"""
272-
return self.eigendecomposition["eigval"]
287+
if self.has_overlapping_wires:
288+
return self.eigendecomposition["eigval"]
289+
eigvals = [
290+
qml.utils.expand_vector(summand.eigvals(), list(summand.wires), list(self.wires))
291+
for summand in self.summands
292+
]
293+
return qml.math.sum(eigvals, axis=0)
273294

274295
def matrix(self, wire_order=None):
275296
r"""Representation of the operator as a matrix in the computational basis.
@@ -414,6 +435,14 @@ def simplify(self, cutoff=1.0e-12) -> "Sum": # pylint: disable=arguments-differ
414435
else qml.Identity(self.wires[0]),
415436
)
416437

438+
@property
439+
def hash(self):
440+
if self._hash is None:
441+
self._hash = hash(
442+
(str(self.name), str([summand.hash for summand in _sum_sort(self.summands)]))
443+
)
444+
return self._hash
445+
417446

418447
class _SumSummandsGrouping:
419448
"""Utils class used for grouping sum summands together."""
@@ -457,3 +486,35 @@ def get_summands(self, cutoff=1.0e-12):
457486
new_summands.append(qml.s_prod(coeff, summand))
458487

459488
return new_summands
489+
490+
491+
def _sum_sort(op_list, wire_map: dict = None) -> List[Operator]:
492+
"""Sort algorithm that sorts a list of sum summands by their wire indices.
493+
494+
Args:
495+
op_list (List[.Operator]): list of operators to be sorted
496+
wire_map (dict): Dictionary containing the wire values as keys and its indexes as values.
497+
Defaults to None.
498+
499+
Returns:
500+
List[.Operator]: sorted list of operators
501+
"""
502+
503+
if isinstance(op_list, tuple):
504+
op_list = list(op_list)
505+
506+
def _sort_key(op) -> bool:
507+
"""Sorting key.
508+
509+
Args:
510+
op (.Operator): Operator.
511+
512+
Returns:
513+
int: Minimum wire value.
514+
"""
515+
wires = op.wires
516+
if wire_map is not None:
517+
wires = wires.map(wire_map)
518+
return np.min(wires)
519+
520+
return sorted(op_list, key=_sort_key)

tests/ops/op_math/test_prod.py

+9-25
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,7 @@ def test_eigen_caching(self):
220220
eig_vecs = eig_decomp["eigvec"]
221221
eig_vals = eig_decomp["eigval"]
222222

223-
eigs_cache = prod_op._eigs[
224-
(1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, -1.0)
225-
]
223+
eigs_cache = prod_op._eigs[prod_op.hash]
226224
cached_vecs = eigs_cache["eigvec"]
227225
cached_vals = eigs_cache["eigval"]
228226

@@ -696,9 +694,7 @@ def test_eigen_caching(self):
696694

697695
eig_vecs = eig_decomp["eigvec"]
698696
eig_vals = eig_decomp["eigval"]
699-
eigs_cache = diag_prod_op._eigs[
700-
(1.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 1.0)
701-
]
697+
eigs_cache = diag_prod_op._eigs[diag_prod_op.hash]
702698
cached_vecs = eigs_cache["eigvec"]
703699
cached_vals = eigs_cache["eigval"]
704700

@@ -708,19 +704,7 @@ def test_eigen_caching(self):
708704
def test_diagonalizing_gates(self):
709705
"""Test that the diagonalizing gates are correct."""
710706
diag_prod_op = Prod(qml.PauliZ(wires=0), qml.PauliZ(wires=1))
711-
diagonalizing_gates = diag_prod_op.diagonalizing_gates()[0].matrix()
712-
true_diagonalizing_gates = qnp.array(
713-
( # the gates that swap eigvals till they are ordered smallest --> largest
714-
[
715-
[0.0, 1.0, 0.0, 0.0],
716-
[0.0, 0.0, 1.0, 0.0],
717-
[1.0, 0.0, 0.0, 0.0],
718-
[0.0, 0.0, 0.0, 1.0],
719-
]
720-
)
721-
)
722-
723-
assert np.allclose(diagonalizing_gates, true_diagonalizing_gates)
707+
assert diag_prod_op.diagonalizing_gates() == []
724708

725709

726710
class TestSimplify:
@@ -1184,11 +1168,11 @@ def test_sorting_operators_with_wire_map(self):
11841168
qml.PauliX(5),
11851169
qml.Toffoli([2, "three", 4]),
11861170
qml.CNOT([2, 5]),
1187-
qml.RX("test", 5),
1171+
qml.RX(1, 5),
11881172
qml.PauliY(0),
1189-
qml.CRX("test", [0, 2]),
1173+
qml.CRX(1, [0, 2]),
11901174
qml.PauliZ("three"),
1191-
qml.CRY("test", ["test", 2]),
1175+
qml.CRY(1, ["test", 2]),
11921176
]
11931177
sorted_list = _prod_sort(op_list, wire_map={0: 0, "test": 1, 2: 2, "three": 3, 4: 4, 5: 5})
11941178
final_list = [
@@ -1197,10 +1181,10 @@ def test_sorting_operators_with_wire_map(self):
11971181
qml.Toffoli([2, "three", 4]),
11981182
qml.PauliX(5),
11991183
qml.CNOT([2, 5]),
1200-
qml.CRX("test", [0, 2]),
1201-
qml.CRY("test", ["test", 2]),
1184+
qml.CRX(1, [0, 2]),
1185+
qml.CRY(1, ["test", 2]),
12021186
qml.PauliZ("three"),
1203-
qml.RX("test", 5),
1187+
qml.RX(1, 5),
12041188
]
12051189

12061190
for op1, op2 in zip(final_list, sorted_list):

0 commit comments

Comments
 (0)