16
16
This module contains device stubs for the old and new PennyLane device API, which facilitate
17
17
the application of decomposition and other device pre-processing routines.
18
18
"""
19
-
19
+ import os
20
+ import pathlib
21
+ import platform
22
+ import re
20
23
from copy import deepcopy
24
+ from dataclasses import dataclass
21
25
from functools import partial
22
- from typing import Optional , Set
26
+ from typing import Any , Dict , Optional , Set
23
27
24
28
import pennylane as qml
25
29
from pennylane .measurements import MidMeasureMP
32
36
)
33
37
from catalyst .utils .exceptions import CompileError
34
38
from catalyst .utils .patching import Patcher
35
- from catalyst .utils .runtime import BackendInfo , device_get_toml_config
39
+ from catalyst .utils .runtime_environment import get_lib_path
36
40
from catalyst .utils .toml import (
37
41
DeviceCapabilities ,
38
42
OperationProperties ,
39
- ProgramFeatures ,
40
- TOMLDocument ,
41
- get_device_capabilities ,
42
43
intersect_operations ,
43
44
pennylane_operation_set ,
44
45
)
84
85
for op in RUNTIME_OPERATIONS
85
86
}
86
87
88
+ # TODO: This should be removed after implementing `get_c_interface`
89
+ # for the following backend devices:
90
+ SUPPORTED_RT_DEVICES = {
91
+ "lightning.qubit" : ("LightningSimulator" , "librtd_lightning" ),
92
+ "lightning.kokkos" : ("LightningKokkosSimulator" , "librtd_lightning" ),
93
+ "braket.aws.qubit" : ("OpenQasmDevice" , "librtd_openqasm" ),
94
+ "braket.local.qubit" : ("OpenQasmDevice" , "librtd_openqasm" ),
95
+ }
96
+
97
+
98
+ @dataclass
99
+ class BackendInfo :
100
+ """Backend information"""
101
+
102
+ device_name : str
103
+ c_interface_name : str
104
+ lpath : str
105
+ kwargs : Dict [str , Any ]
106
+
107
+
108
+ def extract_backend_info (device : qml .QubitDevice , capabilities : DeviceCapabilities ) -> BackendInfo :
109
+ """Extract the backend info from a quantum device. The device is expected to carry a reference
110
+ to a valid TOML config file."""
111
+ # pylint: disable=too-many-branches
112
+
113
+ dname = device .name
114
+ if isinstance (device , qml .Device ):
115
+ dname = device .short_name
116
+
117
+ device_name = ""
118
+ device_lpath = ""
119
+ device_kwargs = {}
120
+
121
+ if dname in SUPPORTED_RT_DEVICES :
122
+ # Support backend devices without `get_c_interface`
123
+ device_name = SUPPORTED_RT_DEVICES [dname ][0 ]
124
+ device_lpath = get_lib_path ("runtime" , "RUNTIME_LIB_DIR" )
125
+ sys_platform = platform .system ()
126
+
127
+ if sys_platform == "Linux" :
128
+ device_lpath = os .path .join (device_lpath , SUPPORTED_RT_DEVICES [dname ][1 ] + ".so" )
129
+ elif sys_platform == "Darwin" : # pragma: no cover
130
+ device_lpath = os .path .join (device_lpath , SUPPORTED_RT_DEVICES [dname ][1 ] + ".dylib" )
131
+ else : # pragma: no cover
132
+ raise NotImplementedError (f"Platform not supported: { sys_platform } " )
133
+ elif hasattr (device , "get_c_interface" ):
134
+ # Support third party devices with `get_c_interface`
135
+ device_name , device_lpath = device .get_c_interface ()
136
+ else :
137
+ raise CompileError (f"The { dname } device does not provide C interface for compilation." )
138
+
139
+ if not pathlib .Path (device_lpath ).is_file ():
140
+ raise CompileError (f"Device at { device_lpath } cannot be found!" )
141
+
142
+ if hasattr (device , "shots" ):
143
+ if isinstance (device , qml .Device ):
144
+ device_kwargs ["shots" ] = device .shots if device .shots else 0
145
+ else :
146
+ # TODO: support shot vectors
147
+ device_kwargs ["shots" ] = device .shots .total_shots if device .shots else 0
148
+
149
+ if dname == "braket.local.qubit" : # pragma: no cover
150
+ device_kwargs ["device_type" ] = dname
151
+ device_kwargs ["backend" ] = (
152
+ # pylint: disable=protected-access
153
+ device ._device ._delegate .DEVICE_ID
154
+ )
155
+ elif dname == "braket.aws.qubit" : # pragma: no cover
156
+ device_kwargs ["device_type" ] = dname
157
+ device_kwargs ["device_arn" ] = device ._device ._arn # pylint: disable=protected-access
158
+ if device ._s3_folder : # pylint: disable=protected-access
159
+ device_kwargs ["s3_destination_folder" ] = str (
160
+ device ._s3_folder # pylint: disable=protected-access
161
+ )
162
+
163
+ for k , v in capabilities .options .items ():
164
+ if hasattr (device , v ):
165
+ device_kwargs [k ] = getattr (device , v )
166
+
167
+ return BackendInfo (dname , device_name , device_lpath , device_kwargs )
168
+
87
169
88
170
def get_qjit_device_capabilities (target_capabilities : DeviceCapabilities ) -> Set [str ]:
89
171
"""Calculate the set of supported quantum gates for the QJIT device from the gates
@@ -165,7 +247,7 @@ def _get_operations_to_convert_to_matrix(_capabilities: DeviceCapabilities) -> S
165
247
166
248
def __init__ (
167
249
self ,
168
- target_config : TOMLDocument ,
250
+ original_device_capabilities : DeviceCapabilities ,
169
251
shots = None ,
170
252
wires = None ,
171
253
backend : Optional [BackendInfo ] = None ,
@@ -175,23 +257,18 @@ def __init__(
175
257
self .backend_name = backend .c_interface_name if backend else "default"
176
258
self .backend_lib = backend .lpath if backend else ""
177
259
self .backend_kwargs = backend .kwargs if backend else {}
178
- device_name = backend .device_name if backend else "default"
179
260
180
- program_features = ProgramFeatures (shots is not None )
181
- target_device_capabilities = get_device_capabilities (
182
- target_config , program_features , device_name
183
- )
184
- self .capabilities = get_qjit_device_capabilities (target_device_capabilities )
261
+ self .qjit_capabilities = get_qjit_device_capabilities (original_device_capabilities )
185
262
186
263
@property
187
264
def operations (self ) -> Set [str ]:
188
265
"""Get the device operations using PennyLane's syntax"""
189
- return pennylane_operation_set (self .capabilities .native_ops )
266
+ return pennylane_operation_set (self .qjit_capabilities .native_ops )
190
267
191
268
@property
192
269
def observables (self ) -> Set [str ]:
193
270
"""Get the device observables"""
194
- return pennylane_operation_set (self .capabilities .native_obs )
271
+ return pennylane_operation_set (self .qjit_capabilities .native_obs )
195
272
196
273
def apply (self , operations , ** kwargs ):
197
274
"""
@@ -270,6 +347,7 @@ class QJITDeviceNewAPI(qml.devices.Device):
270
347
def __init__ (
271
348
self ,
272
349
original_device ,
350
+ original_device_capabilities : DeviceCapabilities ,
273
351
backend : Optional [BackendInfo ] = None ,
274
352
):
275
353
self .original_device = original_device
@@ -285,29 +363,23 @@ def __init__(
285
363
self .backend_name = backend .c_interface_name if backend else "default"
286
364
self .backend_lib = backend .lpath if backend else ""
287
365
self .backend_kwargs = backend .kwargs if backend else {}
288
- device_name = backend .device_name if backend else "default"
289
366
290
- target_config = device_get_toml_config (original_device )
291
- program_features = ProgramFeatures (original_device .shots is not None )
292
- target_device_capabilities = get_device_capabilities (
293
- target_config , program_features , device_name
294
- )
295
- self .capabilities = get_qjit_device_capabilities (target_device_capabilities )
367
+ self .qjit_capabilities = get_qjit_device_capabilities (original_device_capabilities )
296
368
297
369
@property
298
370
def operations (self ) -> Set [str ]:
299
371
"""Get the device operations"""
300
- return pennylane_operation_set (self .capabilities .native_ops )
372
+ return pennylane_operation_set (self .qjit_capabilities .native_ops )
301
373
302
374
@property
303
375
def observables (self ) -> Set [str ]:
304
376
"""Get the device observables"""
305
- return pennylane_operation_set (self .capabilities .native_obs )
377
+ return pennylane_operation_set (self .qjit_capabilities .native_obs )
306
378
307
379
@property
308
380
def measurement_processes (self ) -> Set [str ]:
309
381
"""Get the device measurement processes"""
310
- return self .capabilities .measurement_processes
382
+ return self .qjit_capabilities .measurement_processes
311
383
312
384
def preprocess (
313
385
self ,
@@ -334,3 +406,101 @@ def execute(self, circuits, execution_config):
334
406
Raises: RuntimeError
335
407
"""
336
408
raise RuntimeError ("QJIT devices cannot execute tapes." )
409
+
410
+
411
+ def filter_out_adjoint (operations ):
412
+ """Remove Adjoint from operations.
413
+
414
+ Args:
415
+ operations (List[Str]): List of strings with names of supported operations
416
+
417
+ Returns:
418
+ List: A list of strings with names of supported operations with Adjoint and C gates
419
+ removed.
420
+ """
421
+ adjoint = re .compile (r"^Adjoint\(.*\)$" )
422
+
423
+ def is_not_adj (op ):
424
+ return not re .match (adjoint , op )
425
+
426
+ operations_no_adj = filter (is_not_adj , operations )
427
+ return set (operations_no_adj )
428
+
429
+
430
+ def check_no_overlap (* args , device_name ):
431
+ """Check items in *args are mutually exclusive.
432
+
433
+ Args:
434
+ *args (List[Str]): List of strings.
435
+ device_name (str): Device name for error reporting.
436
+
437
+ Raises:
438
+ CompileError
439
+ """
440
+ set_of_sets = [set (arg ) for arg in args ]
441
+ union = set .union (* set_of_sets )
442
+ len_of_sets = [len (arg ) for arg in args ]
443
+ if sum (len_of_sets ) == len (union ):
444
+ return
445
+
446
+ overlaps = set ()
447
+ for s in set_of_sets :
448
+ overlaps .update (s - union )
449
+ union = union - s
450
+
451
+ msg = f"Device '{ device_name } ' has overlapping gates: { overlaps } "
452
+ raise CompileError (msg )
453
+
454
+
455
+ def validate_device_capabilities (
456
+ device : qml .QubitDevice , device_capabilities : DeviceCapabilities
457
+ ) -> None :
458
+ """Validate configuration document against the device attributes.
459
+ Raise CompileError in case of mismatch:
460
+ * If device is not qjit-compatible.
461
+ * If configuration file does not exists.
462
+ * If decomposable, matrix, and native gates have some overlap.
463
+ * If decomposable, matrix, and native gates do not match gates in ``device.operations`` and
464
+ ``device.observables``.
465
+
466
+ Args:
467
+ device (qml.Device): An instance of a quantum device.
468
+ config (TOMLDocument): A TOML document representation.
469
+
470
+ Raises: CompileError
471
+ """
472
+
473
+ if not device_capabilities .qjit_compatible_flag :
474
+ raise CompileError (
475
+ f"Attempting to compile program for incompatible device '{ device .name } ': "
476
+ f"Config is not marked as qjit-compatible"
477
+ )
478
+
479
+ device_name = device .short_name if isinstance (device , qml .Device ) else device .name
480
+
481
+ native = pennylane_operation_set (device_capabilities .native_ops )
482
+ decomposable = pennylane_operation_set (device_capabilities .to_decomp_ops )
483
+ matrix = pennylane_operation_set (device_capabilities .to_matrix_ops )
484
+
485
+ check_no_overlap (native , decomposable , matrix , device_name = device_name )
486
+
487
+ if hasattr (device , "operations" ) and hasattr (device , "observables" ):
488
+ # For gates, we require strict match
489
+ device_gates = filter_out_adjoint (set (device .operations ))
490
+ spec_gates = filter_out_adjoint (set .union (native , matrix , decomposable ))
491
+ if device_gates != spec_gates :
492
+ raise CompileError (
493
+ "Gates in qml.device.operations and specification file do not match.\n "
494
+ f"Gates that present only in the device: { device_gates - spec_gates } \n "
495
+ f"Gates that present only in spec: { spec_gates - device_gates } \n "
496
+ )
497
+
498
+ # For observables, we do not have `non-native` section in the config, so we check that
499
+ # device data supercedes the specification.
500
+ device_observables = set (device .observables )
501
+ spec_observables = pennylane_operation_set (device_capabilities .native_obs )
502
+ if (spec_observables - device_observables ) != set ():
503
+ raise CompileError (
504
+ "Observables in qml.device.observables and specification file do not match.\n "
505
+ f"Observables that present only in spec: { spec_observables - device_observables } \n "
506
+ )
0 commit comments