Skip to content

Commit 98a9292

Browse files
josephleeklringo-but-quantumalbi3roAmintorDusko
authored
Use native implementation for adjoints in (control) operations (#1063)
### Before submitting Please complete the following checklist when submitting a PR: - [ ] All new features must include a unit test. If you've fixed a bug or added code that should be tested, add a test to the [`tests`](../tests) directory! - [ ] All new functions and code must be clearly commented and documented. If you do make documentation changes, make sure that the docs build and render correctly by running `make docs`. - [ ] Ensure that the test suite passes, by running `make test`. - [ ] Add a new entry to the `.github/CHANGELOG.md` file, summarizing the change, and including a link back to the PR. - [ ] Ensure that code is properly formatted by running `make format`. When all the above are checked, delete everything above the dashed line and fill in the pull request template. ------------------------------------------------------------------------------------------------------------ **Context:** Currently in `_apply_lightning`, we check for whether an operation is `Adjoint`, then we apply the operation with an adjoint (`inv_param`) flag. However, in cases where we have: - adjoint(s) within control - e.g. `control(adjoint(gate))` - control within adjoint - e.g. `adjoint(control(gate))`, these are all applied as matrices. **Description of the Change:** `_apply_lightning` and `_apply_lightning_controlled` checks for adjoint in an operation, and if it's an adjoint it applies the base operation with an adjoint flag, instead of treating everything as a matrix. So in effect we have: `control(adjoint(gate))` -> `control(gate with adjoint)` `adjoint(control(gate))` -> `control(gate with adjoint)` which are implemented natively in C++ (if the `gate` is supported), yielding better performance **Benefits:** adjoint(ctrl()) will see the most speedup, especially with large number of control wires, since we use native control operation which contains less wires than the equivalent matrix, and needs to be operated on less wires. adjoint(ctrl()) will see some speed-up, since we are now able to use the native named gate implementation in C++. Example timing improvement: 4 ctrl wires LQ: | LQ, 25 qubits, 500 repeats | master | branch | |-------------------------------------|--------|-------| | ctrl(adjoint(IsingXX)) | 9.6s | 6.0s | | ctrl(adjoint(DoubleExcitationPlus)) | 27.6s | 9.2s | | LQ, 25 qubits, 100 repeats | master | branch | |-------------------------------------|------------------|--------| | adjoint(ctrl(IsingXX)) | 267s | 2.9s| | adjoint(ctrl(DoubleExcitationPlus)) | 1002s| 3.6s | Baseline: | LQ, 25 qubits, 500 repeats | master | branch | |-------------------------------------|--------|--------| | ctrl(IsingXX) | 6.1s |6.1s | | ctrl(DoubleExcitationPlus)| 9.1s | 9.1s | LG: | LG, 31 qubits, 1000 repeats | master | branch | |-------------------------------------|--------|--------| | ctrl(adjoint(IsingXX)) | 4.9s | 4.8s | | ctrl(adjoint(DoubleExcitationPlus)) | 5.0s | 4.9s | | LG, 31 qubits, 1000 repeats | master | branch | |-------------------------------------|-------------|--------| | adjoint(ctrl(IsingXX)) | 119s | 4.8s| | adjoint(ctrl(DoubleExcitationPlus)) | 208s | 4.9s | Baseline: | LG, 31 qubits, 1000 repeats | master | branch | |-------------------------------------|-------------------|--------| | ctrl(IsingXX) | 4.8s | 4.8s | | ctrl(DoubleExcitationPlus)| 4.9s | 4.9s | LK: | LK, 25 qubits, 500 repeats | master | branch | |-------------------------------------|--------|--------| | ctrl(adjoint(IsingXX)) | 8.5s |5.7s | | ctrl(adjoint(DoubleExcitationPlus)) | 24.5s | 7.6s | | LK, 25 qubits, 100 repeats | master | branch | |-------------------------------------|-----|--------| | adjoint(ctrl(IsingXX)) | 235s | 2.6s | | adjoint(ctrl(DoubleExcitationPlus)) | 867s | 2.9s | Baseline: | LK, 25 qubits, 500 repeats | master |branch | |-------------------------------------|-------------|--------| | ctrl(IsingXX) | 5.6s |5.8s | | ctrl(DoubleExcitationPlus)| 7.7s | 7.6 s | **Possible Drawbacks:** **Related GitHub Issues:** [sc-79430] --------- Co-authored-by: ringo-but-quantum <[email protected]> Co-authored-by: Christina Lee <[email protected]> Co-authored-by: Amintor Dusko <[email protected]>
1 parent aad3e59 commit 98a9292

File tree

9 files changed

+408
-70
lines changed

9 files changed

+408
-70
lines changed

.github/CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
### Improvements
88

9+
* Use native C++ kernels for controlled-adjoint and adjoint-controlled of supported operations.
10+
[(#1063)](https://github.com/PennyLaneAI/pennylane-lightning/pull/1063)
11+
912
* In Lightning-Tensor, allow `qml.MPSPrep` to accept an MPS with `len(MPS) = n_wires-1`.
1013
[(#1064)](https://github.com/PennyLaneAI/pennylane-lightning/pull/1064)
1114

pennylane_lightning/core/_serialize.py

+51-20
Original file line numberDiff line numberDiff line change
@@ -443,8 +443,18 @@ def serialize_ops(self, tape: QuantumTape, wires_map: dict = None) -> Tuple[
443443
uses_stateprep = False
444444

445445
def get_wires(operation, single_op):
446-
if isinstance(operation, qml.ops.op_math.Controlled) and not isinstance(
447-
operation,
446+
# Serialize adjoint(op) and adjoint(ctrl(op))
447+
if isinstance(operation, qml.ops.op_math.Adjoint):
448+
inverse = True
449+
op_base = operation.base
450+
single_op_base = single_op.base
451+
else:
452+
inverse = False
453+
op_base = operation
454+
single_op_base = single_op
455+
456+
if isinstance(op_base, qml.ops.op_math.Controlled) and not isinstance(
457+
op_base,
448458
(
449459
qml.CNOT,
450460
qml.CY,
@@ -457,19 +467,41 @@ def get_wires(operation, single_op):
457467
qml.CSWAP,
458468
),
459469
):
460-
name = operation.base.name
461-
wires_list = list(operation.target_wires)
462-
controlled_wires_list = list(operation.control_wires)
463-
control_values_list = operation.control_values
470+
wires_list = list(op_base.target_wires)
471+
controlled_wires_list = list(op_base.control_wires)
472+
control_values_list = op_base.control_values
473+
# Serialize ctrl(adjoint(op))
474+
if isinstance(op_base.base, qml.ops.op_math.Adjoint):
475+
ctrl_adjoint = True
476+
name = op_base.base.base.name
477+
else:
478+
ctrl_adjoint = False
479+
name = op_base.base.name
480+
481+
# Inside the controlled operation, if the base operation (of the adjoint)
482+
# is supported natively, we apply the the base operation and invert the
483+
# inverse flag; otherwise we apply the QubitUnitary of a matrix which
484+
# contains the inverse and leave the inverse flag as is.
464485
if not hasattr(self.sv_type, name):
465-
single_op = QubitUnitary(matrix(single_op.base), single_op.base.wires)
466-
name = single_op.name
486+
single_op_base = QubitUnitary(
487+
matrix(single_op_base.base), single_op_base.base.wires
488+
)
489+
name = single_op_base.name
490+
else:
491+
inverse ^= ctrl_adjoint
467492
else:
468-
name = single_op.name
469-
wires_list = single_op.wires.tolist()
493+
name = single_op_base.name
494+
wires_list = single_op_base.wires.tolist()
470495
controlled_wires_list = []
471496
control_values_list = []
472-
return single_op, name, list(wires_list), controlled_wires_list, control_values_list
497+
return (
498+
single_op_base,
499+
name,
500+
inverse,
501+
list(wires_list),
502+
controlled_wires_list,
503+
control_values_list,
504+
)
473505

474506
for operation in tape.operations:
475507
if isinstance(operation, (BasisState, StatePrep)):
@@ -480,30 +512,29 @@ def get_wires(operation, single_op):
480512
else:
481513
op_list = [operation]
482514

483-
inverse = isinstance(operation, qml.ops.op_math.Adjoint)
484-
485515
for single_op in op_list:
486516
(
487-
single_op,
517+
single_op_base,
488518
name,
519+
inverse,
489520
wires_list,
490521
controlled_wires_list,
491522
controlled_values_list,
492523
) = get_wires(operation, single_op)
493524
inverses.append(inverse)
494-
names.append(single_op.base.name if inverse else name)
525+
names.append(name)
495526
# QubitUnitary is a special case, it has a parameter which is not differentiable.
496527
# We thus pass a dummy 0.0 parameter which will not be referenced
497-
if isinstance(single_op, qml.QubitUnitary):
528+
if isinstance(single_op_base, qml.QubitUnitary):
498529
params.append([0.0])
499-
mats.append(matrix(single_op))
530+
mats.append(matrix(single_op_base))
500531
else:
501-
if hasattr(self.sv_type, single_op.base.name if inverse else name):
502-
params.append(single_op.parameters)
532+
if hasattr(self.sv_type, name):
533+
params.append(single_op_base.parameters)
503534
mats.append([])
504535
else:
505536
params.append([])
506-
mats.append(matrix(single_op))
537+
mats.append(matrix(single_op_base))
507538

508539
controlled_values.append(controlled_values_list)
509540
controlled_wires.append(

pennylane_lightning/core/_state_vector_base.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import numpy as np
2222
from pennylane import BasisState, StatePrep
2323
from pennylane.measurements import MidMeasureMP
24+
from pennylane.ops import Controlled
2425
from pennylane.tape import QuantumScript
2526
from pennylane.wires import Wires
2627

@@ -131,11 +132,12 @@ def _apply_basis_state(self, state, wires):
131132
self._qubit_state.setBasisState(list(state), list(wires))
132133

133134
@abstractmethod
134-
def _apply_lightning_controlled(self, operation):
135+
def _apply_lightning_controlled(self, operation: Controlled, adjoint: bool):
135136
"""Apply an arbitrary controlled operation to the state tensor.
136137
137138
Args:
138139
operation (~pennylane.operation.Operation): controlled operation to apply
140+
adjoint (bool): Apply the adjoint of the operation if True
139141
140142
Returns:
141143
None

pennylane_lightning/core/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
Version number (major.minor.patch[-label])
1717
"""
1818

19-
__version__ = "0.41.0-dev24"
19+
__version__ = "0.41.0-dev25"

pennylane_lightning/lightning_gpu/_state_vector.py

+20-14
Original file line numberDiff line numberDiff line change
@@ -229,34 +229,39 @@ def _apply_state_vector(self, state, device_wires, use_async: bool = False):
229229
# set the state vector on GPU with provided state and their corresponding wires
230230
self._qubit_state.setStateVector(state, list(device_wires), use_async)
231231

232-
def _apply_lightning_controlled(self, operation):
232+
def _apply_lightning_controlled(self, operation, adjoint):
233233
"""Apply an arbitrary controlled operation to the state tensor.
234234
235235
Args:
236236
operation (~pennylane.operation.Operation): controlled operation to apply
237+
adjoint (bool): Apply the adjoint of the operation if True
237238
238239
Returns:
239240
None
240241
"""
241242
state = self.state_vector
242243

243-
basename = operation.base.name
244-
method = getattr(state, f"{basename}", None)
244+
if isinstance(operation.base, Adjoint):
245+
base_operation = operation.base.base
246+
adjoint = not adjoint
247+
else:
248+
base_operation = operation.base
249+
250+
method = getattr(state, f"{base_operation.name}", None)
245251
control_wires = list(operation.control_wires)
246252
control_values = operation.control_values
247253
target_wires = list(operation.target_wires)
248254
if method: # apply n-controlled specialized gate
249-
inv = False
250255
param = operation.parameters
251-
method(control_wires, control_values, target_wires, inv, param)
256+
method(control_wires, control_values, target_wires, adjoint, param)
252257
else: # apply gate as an n-controlled matrix
253258
method = getattr(state, "applyControlledMatrix")
254259
method(
255-
qml.matrix(operation.base),
260+
qml.matrix(base_operation),
256261
control_wires,
257262
control_values,
258263
target_wires,
259-
False,
264+
adjoint,
260265
)
261266

262267
def _apply_lightning_midmeasure(
@@ -300,6 +305,7 @@ def _apply_lightning(
300305
postselection. Use ``"hw-like"`` to discard invalid shots and ``"fill-shots"`` to
301306
keep the same number of shots. Default is ``None``.
302307
308+
303309
Returns:
304310
None
305311
"""
@@ -311,11 +317,12 @@ def _apply_lightning(
311317
if isinstance(operation, qml.Identity):
312318
continue
313319
if isinstance(operation, Adjoint):
314-
name = operation.base.name
320+
op_adjoint_base = operation.base
315321
invert_param = True
316322
else:
317-
name = operation.name
323+
op_adjoint_base = operation
318324
invert_param = False
325+
name = op_adjoint_base.name
319326
method = getattr(state, name, None)
320327
wires = list(operation.wires)
321328

@@ -330,13 +337,13 @@ def _apply_lightning(
330337
param = operation.parameters
331338
method(wires, invert_param, param)
332339
elif (
333-
isinstance(operation, qml.ops.Controlled) and not self._mpi_handler.use_mpi
340+
isinstance(op_adjoint_base, qml.ops.Controlled) and not self._mpi_handler.use_mpi
334341
): # MPI backend does not have native controlled gates support
335-
self._apply_lightning_controlled(operation)
342+
self._apply_lightning_controlled(op_adjoint_base, invert_param)
336343
elif (
337344
self._mpi_handler.use_mpi
338-
and isinstance(operation, qml.ops.Controlled)
339-
and isinstance(operation.base, qml.GlobalPhase)
345+
and isinstance(op_adjoint_base, qml.ops.Controlled)
346+
and isinstance(op_adjoint_base.base, qml.GlobalPhase)
340347
):
341348
# TODO: To move this line to the _apply_lightning_controlled method once the MPI backend supports controlled gates natively
342349
raise DeviceError(
@@ -348,7 +355,6 @@ def _apply_lightning(
348355
except AttributeError: # pragma: no cover
349356
# To support older versions of PL
350357
mat = operation.matrix
351-
352358
r_dtype = np.float32 if self.dtype == np.complex64 else np.float64
353359
param = (
354360
[[r_dtype(operation.hash)]]

pennylane_lightning/lightning_kokkos/_state_vector.py

+19-14
Original file line numberDiff line numberDiff line change
@@ -181,34 +181,39 @@ def _apply_state_vector(self, state, device_wires: Wires):
181181
# This operate on device
182182
self._qubit_state.setStateVector(state, list(device_wires))
183183

184-
def _apply_lightning_controlled(self, operation):
184+
def _apply_lightning_controlled(self, operation, adjoint):
185185
"""Apply an arbitrary controlled operation to the state tensor.
186186
187187
Args:
188188
operation (~pennylane.operation.Operation): controlled operation to apply
189+
adjoint (bool): Apply the adjoint of the operation if True
189190
190191
Returns:
191192
None
192193
"""
193194
state = self.state_vector
194195

195-
basename = operation.base.name
196-
method = getattr(state, f"{basename}", None)
196+
if isinstance(operation.base, Adjoint):
197+
base_operation = operation.base.base
198+
adjoint = not adjoint
199+
else:
200+
base_operation = operation.base
201+
202+
method = getattr(state, f"{base_operation.name}", None)
197203
control_wires = list(operation.control_wires)
198204
control_values = operation.control_values
199205
target_wires = list(operation.target_wires)
200-
inv = False # TODO: update to use recursive _apply_lightning to handle nested adjoint/ctrl
201206
if method is not None: # apply n-controlled specialized gate
202207
param = operation.parameters
203-
method(control_wires, control_values, target_wires, inv, param)
208+
method(control_wires, control_values, target_wires, adjoint, param)
204209
else: # apply gate as an n-controlled matrix
205210
method = getattr(state, "applyControlledMatrix")
206211
method(
207-
qml.matrix(operation.base),
212+
qml.matrix(base_operation),
208213
control_wires,
209214
control_values,
210215
target_wires,
211-
inv,
216+
adjoint,
212217
)
213218

214219
def _apply_lightning_midmeasure(
@@ -262,11 +267,12 @@ def _apply_lightning(
262267
if isinstance(operation, qml.Identity):
263268
continue
264269
if isinstance(operation, Adjoint):
265-
name = operation.base.name
270+
op_adjoint_base = operation.base
266271
invert_param = True
267272
else:
268-
name = operation.name
273+
op_adjoint_base = operation
269274
invert_param = False
275+
name = op_adjoint_base.name
270276
method = getattr(state, name, None)
271277
wires = list(operation.wires)
272278

@@ -279,18 +285,17 @@ def _apply_lightning(
279285
)
280286
elif isinstance(operation, qml.PauliRot):
281287
method = getattr(state, "applyPauliRot")
282-
# pylint: disable=protected-access
283-
paulis = operation._hyperparameters[
288+
paulis = operation._hyperparameters[ # pylint: disable=protected-access
284289
"pauli_word"
285-
] # pylint: disable=protected-access
290+
]
286291
wires = [i for i, w in zip(wires, paulis) if w != "I"]
287292
word = "".join(p for p in paulis if p != "I")
288293
method(wires, invert_param, operation.parameters, word)
289294
elif method is not None: # apply specialized gate
290295
param = operation.parameters
291296
method(wires, invert_param, param)
292-
elif isinstance(operation, qml.ops.Controlled): # apply n-controlled gate
293-
self._apply_lightning_controlled(operation)
297+
elif isinstance(op_adjoint_base, qml.ops.Controlled): # apply n-controlled gate
298+
self._apply_lightning_controlled(op_adjoint_base, invert_param)
294299
else: # apply gate as a matrix
295300
# Inverse can be set to False since qml.matrix(operation) is already in
296301
# inverted form

0 commit comments

Comments
 (0)