17
17
"""
18
18
import warnings
19
19
20
- # pylint: disable=not-callable
20
+ # pylint: disable=not-callable, unused-argument
21
21
from contextlib import contextmanager
22
+ from copy import copy , deepcopy
22
23
from dataclasses import replace
23
24
24
25
import pennylane as qml
25
- from pennylane .measurements import Shots
26
+ from pennylane .measurements import MidMeasureMP , Shots
26
27
from pennylane .transforms .core .transform_program import TransformProgram
27
28
28
- from .default_qubit import adjoint_observables , adjoint_ops
29
29
from .device_api import Device
30
30
from .execution_config import DefaultExecutionConfig
31
31
from .modifiers import single_tape_support
34
34
no_sampling ,
35
35
validate_adjoint_trainable_params ,
36
36
validate_measurements ,
37
- validate_observables ,
38
37
)
39
38
40
39
40
def _requests_adjoint(execution_config):
    """Return ``True`` if the execution config requests adjoint differentiation.

    Adjoint can be requested either directly (``gradient_method == "adjoint"``) or
    through the device-derivative pathway (``gradient_method == "device"`` with the
    ``"method": "adjoint_jacobian"`` keyword argument).
    """
    return execution_config.gradient_method == "adjoint" or (
        execution_config.gradient_method == "device"
        and execution_config.gradient_keyword_arguments.get("method", None) == "adjoint_jacobian"
    )
46
+
41
47
@contextmanager
42
48
def _set_shots (device , shots ):
43
49
"""Context manager to temporarily change the shots
@@ -98,6 +104,15 @@ def legacy_device_batch_transform(tape, device):
98
104
return _set_shots (device , tape .shots )(device .batch_transform )(tape )
99
105
100
106
107
def adjoint_ops(op: "qml.operation.Operator") -> bool:
    """Specify whether or not an Operator is supported by adjoint differentiation.

    Args:
        op: the operator to check.

    Returns:
        bool: whether adjoint differentiation can handle ``op``.
    """
    # A non-trainable QubitUnitary is always supported, regardless of parameter count.
    if isinstance(op, qml.QubitUnitary) and not qml.operation.is_trainable(op):
        return True
    # Mid-circuit measurements are not supported. Otherwise the operator must have
    # no parameters, or exactly one parameter together with a generator.
    return not isinstance(op, MidMeasureMP) and (
        op.num_params == 0 or (op.num_params == 1 and op.has_generator)
    )
114
+
115
+
101
116
def _add_adjoint_transforms (program : TransformProgram , name = "adjoint" ):
102
117
"""Add the adjoint specific transforms to the transform program."""
103
118
program .add_transform (no_sampling , name = name )
@@ -106,9 +121,13 @@ def _add_adjoint_transforms(program: TransformProgram, name="adjoint"):
106
121
stopping_condition = adjoint_ops ,
107
122
name = name ,
108
123
)
109
- program .add_transform (validate_observables , adjoint_observables , name = name )
124
+
125
+ def accepted_adjoint_measurements (mp ):
126
+ return isinstance (mp , qml .measurements .ExpectationMP )
127
+
110
128
program .add_transform (
111
129
validate_measurements ,
130
+ analytic_measurements = accepted_adjoint_measurements ,
112
131
name = name ,
113
132
)
114
133
program .add_transform (qml .transforms .broadcast_expand )
@@ -141,10 +160,14 @@ class LegacyDeviceFacade(Device):
141
160
142
161
# pylint: disable=super-init-not-called
143
162
def __init__ (self , device : "qml.devices.LegacyDevice" ):
163
+ if isinstance (device , type (self )):
164
+ raise RuntimeError ("An already-facaded device can not be wrapped in a facade again." )
165
+
144
166
if not isinstance (device , qml .devices .LegacyDevice ):
145
167
raise ValueError (
146
168
"The LegacyDeviceFacade only accepts a device of type qml.devices.LegacyDevice."
147
169
)
170
+
148
171
self ._device = device
149
172
150
173
@property
@@ -168,6 +191,13 @@ def __repr__(self):
168
191
def __getattr__ (self , name ):
169
192
return getattr (self ._device , name )
170
193
194
+ # These custom copy methods are needed for Catalyst
195
+ def __copy__ (self ):
196
+ return type (self )(copy (self .target_device ))
197
+
198
+ def __deepcopy__ (self , memo ):
199
+ return type (self )(deepcopy (self .target_device , memo ))
200
+
171
201
@property
172
202
def target_device (self ) -> "qml.devices.LegacyDevice" :
173
203
"""The device wrapped by the facade."""
@@ -195,13 +225,20 @@ def _debugger(self, new_debugger):
195
225
def preprocess (self , execution_config = DefaultExecutionConfig ):
196
226
execution_config = self ._setup_execution_config (execution_config )
197
227
program = qml .transforms .core .TransformProgram ()
198
- # note: need to wrap these methods with a set_shots modifier
228
+
199
229
program .add_transform (legacy_device_batch_transform , device = self ._device )
200
230
program .add_transform (legacy_device_expand_fn , device = self ._device )
201
- if execution_config .gradient_method == "adjoint" :
231
+
232
+ if _requests_adjoint (execution_config ):
202
233
_add_adjoint_transforms (program , name = f"{ self .name } + adjoint" )
203
234
204
- if not self ._device .capabilities ().get ("supports_mid_measure" , False ):
235
+ if self ._device .capabilities ().get ("supports_mid_measure" , False ):
236
+ program .add_transform (
237
+ qml .devices .preprocess .mid_circuit_measurements ,
238
+ device = self ,
239
+ mcm_config = execution_config .mcm_config ,
240
+ )
241
+ else :
205
242
program .add_transform (qml .defer_measurements , device = self )
206
243
207
244
return program , execution_config
@@ -230,8 +267,10 @@ def _setup_adjoint_config(self, execution_config):
230
267
231
268
def _setup_device_config (self , execution_config ):
232
269
tape = qml .tape .QuantumScript ([], [])
270
+
233
271
if not self ._validate_device_method (tape ):
234
272
raise qml .DeviceError ("device does not support device derivatives" )
273
+
235
274
updated_values = {}
236
275
if execution_config .use_device_gradient is None :
237
276
updated_values ["use_device_gradient" ] = True
@@ -243,19 +282,17 @@ def _setup_device_config(self, execution_config):
243
282
def _setup_execution_config (self , execution_config ):
244
283
if execution_config .gradient_method == "best" :
245
284
tape = qml .tape .QuantumScript ([], [])
246
- if self ._validate_backprop_method (tape ):
247
- config = replace (execution_config , gradient_method = "backprop" )
248
- return self ._setup_backprop_config (config )
249
- if self ._validate_adjoint_method (tape ):
250
- config = replace (execution_config , gradient_method = "adjoint" )
251
- return self ._setup_adjoint_config (config )
252
285
if self ._validate_device_method (tape ):
253
286
config = replace (execution_config , gradient_method = "device" )
254
287
return self ._setup_execution_config (config )
255
288
289
+ if self ._validate_backprop_method (tape ):
290
+ config = replace (execution_config , gradient_method = "backprop" )
291
+ return self ._setup_backprop_config (config )
292
+
256
293
if execution_config .gradient_method == "backprop" :
257
294
return self ._setup_backprop_config (execution_config )
258
- if execution_config . gradient_method == "adjoint" :
295
+ if _requests_adjoint ( execution_config ) :
259
296
return self ._setup_adjoint_config (execution_config )
260
297
if execution_config .gradient_method == "device" :
261
298
return self ._setup_device_config (execution_config )
@@ -268,17 +305,17 @@ def supports_derivatives(self, execution_config=None, circuit=None) -> bool:
268
305
if execution_config is None or execution_config .gradient_method == "best" :
269
306
validation_methods = (
270
307
self ._validate_backprop_method ,
271
- self ._validate_adjoint_method ,
272
308
self ._validate_device_method ,
273
309
)
274
310
return any (validate (circuit ) for validate in validation_methods )
275
311
276
312
if execution_config .gradient_method == "backprop" :
277
313
return self ._validate_backprop_method (circuit )
278
- if execution_config . gradient_method == "adjoint" :
314
+ if _requests_adjoint ( execution_config ) :
279
315
return self ._validate_adjoint_method (circuit )
280
316
if execution_config .gradient_method == "device" :
281
317
return self ._validate_device_method (circuit )
318
+
282
319
return False
283
320
284
321
# pylint: disable=protected-access
@@ -333,7 +370,7 @@ def _create_temp_device(self, batch):
333
370
backprop_devices [mapped_interface ],
334
371
wires = self ._device .wires ,
335
372
shots = self ._device .shots ,
336
- )
373
+ ). target_device
337
374
338
375
new_device .expand_fn = expand_fn
339
376
new_device .batch_transform = batch_transform
@@ -368,6 +405,7 @@ def _validate_backprop_method(self, tape):
368
405
369
406
# determine if the device supports backpropagation
370
407
backprop_interface = self ._device .capabilities ().get ("passthru_interface" , None )
408
+
371
409
if backprop_interface is not None :
372
410
# device supports backpropagation natively
373
411
return mapped_interface in [backprop_interface , "Numpy" ]
@@ -388,9 +426,15 @@ def _validate_adjoint_method(self, tape):
388
426
supported_device = all (hasattr (self ._device , attr ) for attr in required_attrs )
389
427
supported_device = supported_device and self ._device .capabilities ().get ("returns_state" )
390
428
391
- if not supported_device :
429
+ if not supported_device or bool (tape .shots ):
430
+ return False
431
+ program = TransformProgram ()
432
+ _add_adjoint_transforms (program , name = f"{ self .name } + adjoint" )
433
+ try :
434
+ program ((tape ,))
435
+ except (qml .operation .DecompositionUndefinedError , qml .DeviceError , AttributeError ):
392
436
return False
393
- return not bool ( tape . shots )
437
+ return True
394
438
395
439
def _validate_device_method (self , _ ):
396
440
# determine if the device provides its own jacobian method
0 commit comments