Skip to content

Commit 1f55c88

Browse files
Shiro-Ravenalbi3romudit2812ringo-but-quantumdwierichs
authored
Wrapping of legacy device automatically in various device creation/qnode/execute functions (#6046)
**Context:** With the `LegacyDeviceFacade` now in place, we can add automatic wrapping of legacy devices. **Description of the Change:** Add automatic wrapping to `qml.device`, `qml.execute`, `QNode` constructor, and the `get_best_method` and `best_method_str` functions of the QNode class. The tests are also updated accordingly. **Benefits:** Users no longer need to worry about upgrading their devices to the new Device API and can use the facade to access the basic functions of the new API. **Possible Drawbacks:** The facade doesn't yet provide all potential advantages of fully upgrading to the new API [[sc-65998](https://app.shortcut.com/xanaduai/story/65998)] --------- Co-authored-by: albi3ro <[email protected]> Co-authored-by: Christina Lee <[email protected]> Co-authored-by: Mudit Pandey <[email protected]> Co-authored-by: ringo-but-quantum <[email protected]> Co-authored-by: David Wierichs <[email protected]> Co-authored-by: Thomas R. Bromley <[email protected]> Co-authored-by: Korbinian Kottmann <[email protected]>
1 parent dea7a2d commit 1f55c88

File tree

68 files changed

+1285
-2018
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+1285
-2018
lines changed

pennylane/capture/capture_qnode.py

+1-10
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,6 @@ def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts):
9090
return qnode_prim
9191

9292

93-
# pylint: disable=protected-access
94-
def _get_device_shots(device) -> "qml.measurements.Shots":
95-
if isinstance(device, qml.devices.LegacyDevice):
96-
if device._shot_vector:
97-
return qml.measurements.Shots(device._raw_shot_sequence)
98-
return qml.measurements.Shots(device.shots)
99-
return device.shots
100-
101-
10293
def qnode_call(qnode: "qml.QNode", *args, **kwargs) -> "qml.typing.Result":
10394
"""A capture compatible call to a QNode. This function is internally used by ``QNode.__call__``.
10495
@@ -166,7 +157,7 @@ def f(x):
166157
if "shots" in kwargs:
167158
shots = qml.measurements.Shots(kwargs.pop("shots"))
168159
else:
169-
shots = _get_device_shots(qnode.device)
160+
shots = qnode.device.shots
170161
if shots.has_partitioned_shots:
171162
# Questions over the pytrees and the nested result object shape
172163
raise NotImplementedError("shot vectors are not yet supported with plxpr capture.")

pennylane/debugging/snapshot.py

-5
Original file line numberDiff line numberDiff line change
@@ -228,16 +228,11 @@ def get_snapshots(*args, **kwargs):
228228

229229
with _SnapshotDebugger(qnode.device) as dbg:
230230
# pylint: disable=protected-access
231-
if qnode._original_device:
232-
qnode._original_device._debugger = qnode.device._debugger
233-
234231
results = qnode(*args, **kwargs)
235232

236233
# Reset interface
237234
if old_interface == "auto":
238235
qnode.interface = "auto"
239-
if qnode._original_device:
240-
qnode.device._debugger = None
241236

242237
dbg.snapshots["execution_results"] = results
243238
return dbg.snapshots

pennylane/devices/_legacy_device.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,25 @@ def _local_tape_expand(tape, depth, stop_at):
9797
return new_tape
9898

9999

100-
class Device(abc.ABC):
100+
class _LegacyMeta(abc.ABCMeta):
101+
"""
102+
A simple meta class added to circumvent the Legacy facade when
103+
checking the instance of a device against a Legacy device type.
104+
105+
To illustrate, if "dev" is of type LegacyDeviceFacade, and a user is
106+
checking "isinstance(dev, qml.devices.DefaultMixed)", the overridden
107+
"__instancecheck__" will look behind the facade, and will evaluate instead
108+
"isinstance(dev.target_device, qml.devices.DefaultMixed)"
109+
"""
110+
111+
def __instancecheck__(cls, instance):
112+
if isinstance(instance, qml.devices.LegacyDeviceFacade):
113+
return isinstance(instance.target_device, cls)
114+
115+
return super().__instancecheck__(instance)
116+
117+
118+
class Device(abc.ABC, metaclass=_LegacyMeta):
101119
"""Abstract base class for PennyLane devices.
102120
103121
Args:

pennylane/devices/device_constructor.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,6 @@ def _safe_specifier_set(version_str):
281281

282282
# Once the device is constructed, we set its custom expansion function if
283283
# any custom decompositions were specified.
284-
285284
if custom_decomps is not None:
286285
if isinstance(dev, qml.devices.LegacyDevice):
287286
custom_decomp_expand_fn = qml.transforms.create_decomp_expand_fn(
@@ -294,6 +293,9 @@ def _safe_specifier_set(version_str):
294293
)
295294
dev.preprocess = custom_decomp_preprocess
296295

296+
if isinstance(dev, qml.devices.LegacyDevice):
297+
dev = qml.devices.LegacyDeviceFacade(dev)
298+
297299
return dev
298300

299301
raise qml.DeviceError(

pennylane/devices/legacy_facade.py

+64-20
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@
1717
"""
1818
import warnings
1919

20-
# pylint: disable=not-callable
20+
# pylint: disable=not-callable, unused-argument
2121
from contextlib import contextmanager
22+
from copy import copy, deepcopy
2223
from dataclasses import replace
2324

2425
import pennylane as qml
25-
from pennylane.measurements import Shots
26+
from pennylane.measurements import MidMeasureMP, Shots
2627
from pennylane.transforms.core.transform_program import TransformProgram
2728

28-
from .default_qubit import adjoint_observables, adjoint_ops
2929
from .device_api import Device
3030
from .execution_config import DefaultExecutionConfig
3131
from .modifiers import single_tape_support
@@ -34,10 +34,16 @@
3434
no_sampling,
3535
validate_adjoint_trainable_params,
3636
validate_measurements,
37-
validate_observables,
3837
)
3938

4039

40+
def _requests_adjoint(execution_config):
41+
return execution_config.gradient_method == "adjoint" or (
42+
execution_config.gradient_method == "device"
43+
and execution_config.gradient_keyword_arguments.get("method", None) == "adjoint_jacobian"
44+
)
45+
46+
4147
@contextmanager
4248
def _set_shots(device, shots):
4349
"""Context manager to temporarily change the shots
@@ -98,6 +104,15 @@ def legacy_device_batch_transform(tape, device):
98104
return _set_shots(device, tape.shots)(device.batch_transform)(tape)
99105

100106

107+
def adjoint_ops(op: qml.operation.Operator) -> bool:
108+
"""Specify whether or not an Operator is supported by adjoint differentiation."""
109+
if isinstance(op, qml.QubitUnitary) and not qml.operation.is_trainable(op):
110+
return True
111+
return not isinstance(op, MidMeasureMP) and (
112+
op.num_params == 0 or (op.num_params == 1 and op.has_generator)
113+
)
114+
115+
101116
def _add_adjoint_transforms(program: TransformProgram, name="adjoint"):
102117
"""Add the adjoint specific transforms to the transform program."""
103118
program.add_transform(no_sampling, name=name)
@@ -106,9 +121,13 @@ def _add_adjoint_transforms(program: TransformProgram, name="adjoint"):
106121
stopping_condition=adjoint_ops,
107122
name=name,
108123
)
109-
program.add_transform(validate_observables, adjoint_observables, name=name)
124+
125+
def accepted_adjoint_measurements(mp):
126+
return isinstance(mp, qml.measurements.ExpectationMP)
127+
110128
program.add_transform(
111129
validate_measurements,
130+
analytic_measurements=accepted_adjoint_measurements,
112131
name=name,
113132
)
114133
program.add_transform(qml.transforms.broadcast_expand)
@@ -141,10 +160,14 @@ class LegacyDeviceFacade(Device):
141160

142161
# pylint: disable=super-init-not-called
143162
def __init__(self, device: "qml.devices.LegacyDevice"):
163+
if isinstance(device, type(self)):
164+
raise RuntimeError("An already-facaded device can not be wrapped in a facade again.")
165+
144166
if not isinstance(device, qml.devices.LegacyDevice):
145167
raise ValueError(
146168
"The LegacyDeviceFacade only accepts a device of type qml.devices.LegacyDevice."
147169
)
170+
148171
self._device = device
149172

150173
@property
@@ -168,6 +191,13 @@ def __repr__(self):
168191
def __getattr__(self, name):
169192
return getattr(self._device, name)
170193

194+
# These custom copy methods are needed for Catalyst
195+
def __copy__(self):
196+
return type(self)(copy(self.target_device))
197+
198+
def __deepcopy__(self, memo):
199+
return type(self)(deepcopy(self.target_device, memo))
200+
171201
@property
172202
def target_device(self) -> "qml.devices.LegacyDevice":
173203
"""The device wrapped by the facade."""
@@ -195,13 +225,20 @@ def _debugger(self, new_debugger):
195225
def preprocess(self, execution_config=DefaultExecutionConfig):
196226
execution_config = self._setup_execution_config(execution_config)
197227
program = qml.transforms.core.TransformProgram()
198-
# note: need to wrap these methods with a set_shots modifier
228+
199229
program.add_transform(legacy_device_batch_transform, device=self._device)
200230
program.add_transform(legacy_device_expand_fn, device=self._device)
201-
if execution_config.gradient_method == "adjoint":
231+
232+
if _requests_adjoint(execution_config):
202233
_add_adjoint_transforms(program, name=f"{self.name} + adjoint")
203234

204-
if not self._device.capabilities().get("supports_mid_measure", False):
235+
if self._device.capabilities().get("supports_mid_measure", False):
236+
program.add_transform(
237+
qml.devices.preprocess.mid_circuit_measurements,
238+
device=self,
239+
mcm_config=execution_config.mcm_config,
240+
)
241+
else:
205242
program.add_transform(qml.defer_measurements, device=self)
206243

207244
return program, execution_config
@@ -230,8 +267,10 @@ def _setup_adjoint_config(self, execution_config):
230267

231268
def _setup_device_config(self, execution_config):
232269
tape = qml.tape.QuantumScript([], [])
270+
233271
if not self._validate_device_method(tape):
234272
raise qml.DeviceError("device does not support device derivatives")
273+
235274
updated_values = {}
236275
if execution_config.use_device_gradient is None:
237276
updated_values["use_device_gradient"] = True
@@ -243,19 +282,17 @@ def _setup_device_config(self, execution_config):
243282
def _setup_execution_config(self, execution_config):
244283
if execution_config.gradient_method == "best":
245284
tape = qml.tape.QuantumScript([], [])
246-
if self._validate_backprop_method(tape):
247-
config = replace(execution_config, gradient_method="backprop")
248-
return self._setup_backprop_config(config)
249-
if self._validate_adjoint_method(tape):
250-
config = replace(execution_config, gradient_method="adjoint")
251-
return self._setup_adjoint_config(config)
252285
if self._validate_device_method(tape):
253286
config = replace(execution_config, gradient_method="device")
254287
return self._setup_execution_config(config)
255288

289+
if self._validate_backprop_method(tape):
290+
config = replace(execution_config, gradient_method="backprop")
291+
return self._setup_backprop_config(config)
292+
256293
if execution_config.gradient_method == "backprop":
257294
return self._setup_backprop_config(execution_config)
258-
if execution_config.gradient_method == "adjoint":
295+
if _requests_adjoint(execution_config):
259296
return self._setup_adjoint_config(execution_config)
260297
if execution_config.gradient_method == "device":
261298
return self._setup_device_config(execution_config)
@@ -268,17 +305,17 @@ def supports_derivatives(self, execution_config=None, circuit=None) -> bool:
268305
if execution_config is None or execution_config.gradient_method == "best":
269306
validation_methods = (
270307
self._validate_backprop_method,
271-
self._validate_adjoint_method,
272308
self._validate_device_method,
273309
)
274310
return any(validate(circuit) for validate in validation_methods)
275311

276312
if execution_config.gradient_method == "backprop":
277313
return self._validate_backprop_method(circuit)
278-
if execution_config.gradient_method == "adjoint":
314+
if _requests_adjoint(execution_config):
279315
return self._validate_adjoint_method(circuit)
280316
if execution_config.gradient_method == "device":
281317
return self._validate_device_method(circuit)
318+
282319
return False
283320

284321
# pylint: disable=protected-access
@@ -333,7 +370,7 @@ def _create_temp_device(self, batch):
333370
backprop_devices[mapped_interface],
334371
wires=self._device.wires,
335372
shots=self._device.shots,
336-
)
373+
).target_device
337374

338375
new_device.expand_fn = expand_fn
339376
new_device.batch_transform = batch_transform
@@ -368,6 +405,7 @@ def _validate_backprop_method(self, tape):
368405

369406
# determine if the device supports backpropagation
370407
backprop_interface = self._device.capabilities().get("passthru_interface", None)
408+
371409
if backprop_interface is not None:
372410
# device supports backpropagation natively
373411
return mapped_interface in [backprop_interface, "Numpy"]
@@ -388,9 +426,15 @@ def _validate_adjoint_method(self, tape):
388426
supported_device = all(hasattr(self._device, attr) for attr in required_attrs)
389427
supported_device = supported_device and self._device.capabilities().get("returns_state")
390428

391-
if not supported_device:
429+
if not supported_device or bool(tape.shots):
430+
return False
431+
program = TransformProgram()
432+
_add_adjoint_transforms(program, name=f"{self.name} + adjoint")
433+
try:
434+
program((tape,))
435+
except (qml.operation.DecompositionUndefinedError, qml.DeviceError, AttributeError):
392436
return False
393-
return not bool(tape.shots)
437+
return True
394438

395439
def _validate_device_method(self, _):
396440
# determine if the device provides its own jacobian method

pennylane/devices/tests/conftest.py

+22
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""Contains shared fixtures for the device tests."""
1515
import argparse
1616
import os
17+
import warnings
1718

1819
import numpy as np
1920
import pytest
@@ -41,6 +42,27 @@
4142
}
4243

4344

45+
@pytest.fixture(scope="function", autouse=True)
46+
def capture_legacy_device_deprecation_warnings():
47+
"""Catches all warnings raised by a test and verifies that any Deprecation
48+
warnings released are related to the legacy devices. Otherwise, it re-raises
49+
any unrelated warnings"""
50+
51+
with warnings.catch_warnings(record=True) as recwarn:
52+
warnings.simplefilter("always")
53+
yield
54+
55+
for w in recwarn:
56+
if isinstance(w, qml.PennyLaneDeprecationWarning):
57+
assert "Use of 'default.qubit." in str(w.message)
58+
assert "is deprecated" in str(w.message)
59+
assert "use 'default.qubit'" in str(w.message)
60+
61+
for w in recwarn:
62+
if "Use of 'default.qubit." not in str(w.message):
63+
warnings.warn(message=w.message, category=w.category)
64+
65+
4466
@pytest.fixture(scope="function")
4567
def tol():
4668
"""Numerical tolerance for equality tests. Returns a different tolerance for tests

pennylane/devices/tests/test_gates.py

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
# pylint: disable=too-many-arguments
2020
# pylint: disable=pointless-statement
2121
# pylint: disable=unnecessary-lambda-assignment
22+
# pylint: disable=no-name-in-module
2223
from cmath import exp
2324
from math import cos, sin, sqrt
2425

pennylane/devices/tests/test_measurements.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -1783,15 +1783,15 @@ def circuit():
17831783
qml.X(0)
17841784
return MyMeasurement()
17851785

1786-
if isinstance(dev, qml.Device):
1787-
with pytest.warns(
1786+
with (
1787+
pytest.warns(
17881788
UserWarning,
1789-
match="Requested measurement MyMeasurement with finite shots",
1790-
):
1791-
circuit()
1792-
else:
1793-
with pytest.raises(qml.DeviceError):
1794-
circuit()
1789+
match="MyMeasurement with finite shots; the returned state information is analytic",
1790+
)
1791+
if isinstance(dev, qml.devices.LegacyDevice)
1792+
else pytest.raises(qml.DeviceError, match="not accepted with finite shots")
1793+
):
1794+
circuit()
17951795

17961796
def test_method_overriden_by_device(self, device):
17971797
"""Test that the device can override a measurement process."""

pennylane/optimize/qnspsa.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -442,15 +442,11 @@ def _apply_blocking(self, cost, args, kwargs, params_next):
442442
cost.construct(params_next, kwargs)
443443
tape_loss_next = cost.tape.copy(copy_operations=True)
444444

445-
if isinstance(cost.device, qml.devices.Device):
446-
program, _ = cost.device.preprocess()
447-
448-
loss_curr, loss_next = qml.execute(
449-
[tape_loss_curr, tape_loss_next], cost.device, None, transform_program=program
450-
)
445+
program, _ = cost.device.preprocess()
451446

452-
else:
453-
loss_curr, loss_next = qml.execute([tape_loss_curr, tape_loss_next], cost.device, None)
447+
loss_curr, loss_next = qml.execute(
448+
[tape_loss_curr, tape_loss_next], cost.device, None, transform_program=program
449+
)
454450

455451
# self.k has been updated earlier
456452
ind = (self.k - 2) % self.last_n_steps.size

0 commit comments

Comments
 (0)