Skip to content

Commit 715523e

Browse files
Sergei Mironovdime10
Sergei Mironov
andauthored
[Frontend] Make tests toml-schema independent (#712)
**Context:** Transition to the quantum device config schema 2 **Description of the Change:** Solve a regarding toml schema 2 udpate in tests by switching our test custom devices from toml text manipulations to the device capability manipulations **Benefits:** * Tests no longer require toml text manipulations. * Tests now contain simple examples of custom devices. * toml-specific code is now locates in `catalyst.utils.toml`. **Possible Drawbacks:** **Related GitHub Issues:** PennyLaneAI/pennylane-lightning#642 --------- Co-authored-by: David Ittah <[email protected]>
1 parent 47b3482 commit 715523e

17 files changed

+704
-696
lines changed

doc/changelog.md

+8-4
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,14 @@
9292
annotations.
9393
[(#751)](https://github.com/PennyLaneAI/catalyst/pull/751)
9494

95-
* Refactored `vmap` decorator in order to follow a unified pattern that uses a callable
96-
class that implements the decorator's logic. This prevents having to excessively define
95+
* Refactored `vmap` decorator in order to follow a unified pattern that uses a callable
96+
class that implements the decorator's logic. This prevents having to excessively define
9797
functions in a nested fashion.
9898
[(#758)](https://github.com/PennyLaneAI/catalyst/pull/758)
9999

100+
* Catalyst tests now manipulate device capabilities rather than text configurations files.
101+
[(#712)](https://github.com/PennyLaneAI/catalyst/pull/712)
102+
100103
<h3>Breaking changes</h3>
101104

102105
* Binary distributions for Linux are now based on `manylinux_2_28` instead of `manylinux_2014`.
@@ -198,11 +201,12 @@
198201
This release contains contributions from (in alphabetical order):
199202

200203
David Ittah,
201-
Mehrdad Malekmohammadi,
202204
Erick Ochoa,
205+
Haochen Paul Wang,
203206
Lee James O'Riordan,
207+
Mehrdad Malekmohammadi,
204208
Raul Torres,
205-
Haochen Paul Wang.
209+
Sergei Mironov.
206210

207211
# Release 0.6.0
208212

frontend/catalyst/compiler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
from catalyst.utils.exceptions import CompileError
3535
from catalyst.utils.filesystem import Directory
36-
from catalyst.utils.runtime import get_lib_path
36+
from catalyst.utils.runtime_environment import get_lib_path
3737

3838
package_root = os.path.dirname(__file__)
3939

frontend/catalyst/device/__init__.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,18 @@
1616
Internal API for the device module.
1717
"""
1818

19-
from catalyst.device.qjit_device import QJITDevice, QJITDeviceNewAPI
19+
from catalyst.device.qjit_device import (
20+
BackendInfo,
21+
QJITDevice,
22+
QJITDeviceNewAPI,
23+
extract_backend_info,
24+
validate_device_capabilities,
25+
)
2026

2127
__all__ = (
2228
"QJITDevice",
2329
"QJITDeviceNewAPI",
30+
"BackendInfo",
31+
"extract_backend_info",
32+
"validate_device_capabilities",
2433
)

frontend/catalyst/device/qjit_device.py

+195-25
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,14 @@
1616
This module contains device stubs for the old and new PennyLane device API, which facilitate
1717
the application of decomposition and other device pre-processing routines.
1818
"""
19-
19+
import os
20+
import pathlib
21+
import platform
22+
import re
2023
from copy import deepcopy
24+
from dataclasses import dataclass
2125
from functools import partial
22-
from typing import Optional, Set
26+
from typing import Any, Dict, Optional, Set
2327

2428
import pennylane as qml
2529
from pennylane.measurements import MidMeasureMP
@@ -32,13 +36,10 @@
3236
)
3337
from catalyst.utils.exceptions import CompileError
3438
from catalyst.utils.patching import Patcher
35-
from catalyst.utils.runtime import BackendInfo, device_get_toml_config
39+
from catalyst.utils.runtime_environment import get_lib_path
3640
from catalyst.utils.toml import (
3741
DeviceCapabilities,
3842
OperationProperties,
39-
ProgramFeatures,
40-
TOMLDocument,
41-
get_device_capabilities,
4243
intersect_operations,
4344
pennylane_operation_set,
4445
)
@@ -84,6 +85,87 @@
8485
for op in RUNTIME_OPERATIONS
8586
}
8687

88+
# TODO: This should be removed after implementing `get_c_interface`
89+
# for the following backend devices:
90+
SUPPORTED_RT_DEVICES = {
91+
"lightning.qubit": ("LightningSimulator", "librtd_lightning"),
92+
"lightning.kokkos": ("LightningKokkosSimulator", "librtd_lightning"),
93+
"braket.aws.qubit": ("OpenQasmDevice", "librtd_openqasm"),
94+
"braket.local.qubit": ("OpenQasmDevice", "librtd_openqasm"),
95+
}
96+
97+
98+
@dataclass
99+
class BackendInfo:
100+
"""Backend information"""
101+
102+
device_name: str
103+
c_interface_name: str
104+
lpath: str
105+
kwargs: Dict[str, Any]
106+
107+
108+
def extract_backend_info(device: qml.QubitDevice, capabilities: DeviceCapabilities) -> BackendInfo:
109+
"""Extract the backend info from a quantum device. The device is expected to carry a reference
110+
to a valid TOML config file."""
111+
# pylint: disable=too-many-branches
112+
113+
dname = device.name
114+
if isinstance(device, qml.Device):
115+
dname = device.short_name
116+
117+
device_name = ""
118+
device_lpath = ""
119+
device_kwargs = {}
120+
121+
if dname in SUPPORTED_RT_DEVICES:
122+
# Support backend devices without `get_c_interface`
123+
device_name = SUPPORTED_RT_DEVICES[dname][0]
124+
device_lpath = get_lib_path("runtime", "RUNTIME_LIB_DIR")
125+
sys_platform = platform.system()
126+
127+
if sys_platform == "Linux":
128+
device_lpath = os.path.join(device_lpath, SUPPORTED_RT_DEVICES[dname][1] + ".so")
129+
elif sys_platform == "Darwin": # pragma: no cover
130+
device_lpath = os.path.join(device_lpath, SUPPORTED_RT_DEVICES[dname][1] + ".dylib")
131+
else: # pragma: no cover
132+
raise NotImplementedError(f"Platform not supported: {sys_platform}")
133+
elif hasattr(device, "get_c_interface"):
134+
# Support third party devices with `get_c_interface`
135+
device_name, device_lpath = device.get_c_interface()
136+
else:
137+
raise CompileError(f"The {dname} device does not provide C interface for compilation.")
138+
139+
if not pathlib.Path(device_lpath).is_file():
140+
raise CompileError(f"Device at {device_lpath} cannot be found!")
141+
142+
if hasattr(device, "shots"):
143+
if isinstance(device, qml.Device):
144+
device_kwargs["shots"] = device.shots if device.shots else 0
145+
else:
146+
# TODO: support shot vectors
147+
device_kwargs["shots"] = device.shots.total_shots if device.shots else 0
148+
149+
if dname == "braket.local.qubit": # pragma: no cover
150+
device_kwargs["device_type"] = dname
151+
device_kwargs["backend"] = (
152+
# pylint: disable=protected-access
153+
device._device._delegate.DEVICE_ID
154+
)
155+
elif dname == "braket.aws.qubit": # pragma: no cover
156+
device_kwargs["device_type"] = dname
157+
device_kwargs["device_arn"] = device._device._arn # pylint: disable=protected-access
158+
if device._s3_folder: # pylint: disable=protected-access
159+
device_kwargs["s3_destination_folder"] = str(
160+
device._s3_folder # pylint: disable=protected-access
161+
)
162+
163+
for k, v in capabilities.options.items():
164+
if hasattr(device, v):
165+
device_kwargs[k] = getattr(device, v)
166+
167+
return BackendInfo(dname, device_name, device_lpath, device_kwargs)
168+
87169

88170
def get_qjit_device_capabilities(target_capabilities: DeviceCapabilities) -> Set[str]:
89171
"""Calculate the set of supported quantum gates for the QJIT device from the gates
@@ -165,7 +247,7 @@ def _get_operations_to_convert_to_matrix(_capabilities: DeviceCapabilities) -> S
165247

166248
def __init__(
167249
self,
168-
target_config: TOMLDocument,
250+
original_device_capabilities: DeviceCapabilities,
169251
shots=None,
170252
wires=None,
171253
backend: Optional[BackendInfo] = None,
@@ -175,23 +257,18 @@ def __init__(
175257
self.backend_name = backend.c_interface_name if backend else "default"
176258
self.backend_lib = backend.lpath if backend else ""
177259
self.backend_kwargs = backend.kwargs if backend else {}
178-
device_name = backend.device_name if backend else "default"
179260

180-
program_features = ProgramFeatures(shots is not None)
181-
target_device_capabilities = get_device_capabilities(
182-
target_config, program_features, device_name
183-
)
184-
self.capabilities = get_qjit_device_capabilities(target_device_capabilities)
261+
self.qjit_capabilities = get_qjit_device_capabilities(original_device_capabilities)
185262

186263
@property
187264
def operations(self) -> Set[str]:
188265
"""Get the device operations using PennyLane's syntax"""
189-
return pennylane_operation_set(self.capabilities.native_ops)
266+
return pennylane_operation_set(self.qjit_capabilities.native_ops)
190267

191268
@property
192269
def observables(self) -> Set[str]:
193270
"""Get the device observables"""
194-
return pennylane_operation_set(self.capabilities.native_obs)
271+
return pennylane_operation_set(self.qjit_capabilities.native_obs)
195272

196273
def apply(self, operations, **kwargs):
197274
"""
@@ -270,6 +347,7 @@ class QJITDeviceNewAPI(qml.devices.Device):
270347
def __init__(
271348
self,
272349
original_device,
350+
original_device_capabilities: DeviceCapabilities,
273351
backend: Optional[BackendInfo] = None,
274352
):
275353
self.original_device = original_device
@@ -285,29 +363,23 @@ def __init__(
285363
self.backend_name = backend.c_interface_name if backend else "default"
286364
self.backend_lib = backend.lpath if backend else ""
287365
self.backend_kwargs = backend.kwargs if backend else {}
288-
device_name = backend.device_name if backend else "default"
289366

290-
target_config = device_get_toml_config(original_device)
291-
program_features = ProgramFeatures(original_device.shots is not None)
292-
target_device_capabilities = get_device_capabilities(
293-
target_config, program_features, device_name
294-
)
295-
self.capabilities = get_qjit_device_capabilities(target_device_capabilities)
367+
self.qjit_capabilities = get_qjit_device_capabilities(original_device_capabilities)
296368

297369
@property
298370
def operations(self) -> Set[str]:
299371
"""Get the device operations"""
300-
return pennylane_operation_set(self.capabilities.native_ops)
372+
return pennylane_operation_set(self.qjit_capabilities.native_ops)
301373

302374
@property
303375
def observables(self) -> Set[str]:
304376
"""Get the device observables"""
305-
return pennylane_operation_set(self.capabilities.native_obs)
377+
return pennylane_operation_set(self.qjit_capabilities.native_obs)
306378

307379
@property
308380
def measurement_processes(self) -> Set[str]:
309381
"""Get the device measurement processes"""
310-
return self.capabilities.measurement_processes
382+
return self.qjit_capabilities.measurement_processes
311383

312384
def preprocess(
313385
self,
@@ -334,3 +406,101 @@ def execute(self, circuits, execution_config):
334406
Raises: RuntimeError
335407
"""
336408
raise RuntimeError("QJIT devices cannot execute tapes.")
409+
410+
411+
def filter_out_adjoint(operations):
412+
"""Remove Adjoint from operations.
413+
414+
Args:
415+
operations (List[Str]): List of strings with names of supported operations
416+
417+
Returns:
418+
List: A list of strings with names of supported operations with Adjoint and C gates
419+
removed.
420+
"""
421+
adjoint = re.compile(r"^Adjoint\(.*\)$")
422+
423+
def is_not_adj(op):
424+
return not re.match(adjoint, op)
425+
426+
operations_no_adj = filter(is_not_adj, operations)
427+
return set(operations_no_adj)
428+
429+
430+
def check_no_overlap(*args, device_name):
431+
"""Check items in *args are mutually exclusive.
432+
433+
Args:
434+
*args (List[Str]): List of strings.
435+
device_name (str): Device name for error reporting.
436+
437+
Raises:
438+
CompileError
439+
"""
440+
set_of_sets = [set(arg) for arg in args]
441+
union = set.union(*set_of_sets)
442+
len_of_sets = [len(arg) for arg in args]
443+
if sum(len_of_sets) == len(union):
444+
return
445+
446+
overlaps = set()
447+
for s in set_of_sets:
448+
overlaps.update(s - union)
449+
union = union - s
450+
451+
msg = f"Device '{device_name}' has overlapping gates: {overlaps}"
452+
raise CompileError(msg)
453+
454+
455+
def validate_device_capabilities(
456+
device: qml.QubitDevice, device_capabilities: DeviceCapabilities
457+
) -> None:
458+
"""Validate configuration document against the device attributes.
459+
Raise CompileError in case of mismatch:
460+
* If device is not qjit-compatible.
461+
* If configuration file does not exists.
462+
* If decomposable, matrix, and native gates have some overlap.
463+
* If decomposable, matrix, and native gates do not match gates in ``device.operations`` and
464+
``device.observables``.
465+
466+
Args:
467+
device (qml.Device): An instance of a quantum device.
468+
config (TOMLDocument): A TOML document representation.
469+
470+
Raises: CompileError
471+
"""
472+
473+
if not device_capabilities.qjit_compatible_flag:
474+
raise CompileError(
475+
f"Attempting to compile program for incompatible device '{device.name}': "
476+
f"Config is not marked as qjit-compatible"
477+
)
478+
479+
device_name = device.short_name if isinstance(device, qml.Device) else device.name
480+
481+
native = pennylane_operation_set(device_capabilities.native_ops)
482+
decomposable = pennylane_operation_set(device_capabilities.to_decomp_ops)
483+
matrix = pennylane_operation_set(device_capabilities.to_matrix_ops)
484+
485+
check_no_overlap(native, decomposable, matrix, device_name=device_name)
486+
487+
if hasattr(device, "operations") and hasattr(device, "observables"):
488+
# For gates, we require strict match
489+
device_gates = filter_out_adjoint(set(device.operations))
490+
spec_gates = filter_out_adjoint(set.union(native, matrix, decomposable))
491+
if device_gates != spec_gates:
492+
raise CompileError(
493+
"Gates in qml.device.operations and specification file do not match.\n"
494+
f"Gates that present only in the device: {device_gates - spec_gates}\n"
495+
f"Gates that present only in spec: {spec_gates - device_gates}\n"
496+
)
497+
498+
# For observables, we do not have `non-native` section in the config, so we check that
499+
# device data supercedes the specification.
500+
device_observables = set(device.observables)
501+
spec_observables = pennylane_operation_set(device_capabilities.native_obs)
502+
if (spec_observables - device_observables) != set():
503+
raise CompileError(
504+
"Observables in qml.device.observables and specification file do not match.\n"
505+
f"Observables that present only in spec: {spec_observables - device_observables}\n"
506+
)

0 commit comments

Comments
 (0)