Skip to content

Commit 8c42cc9

Browse files
Sergei Mironovlillian542
Sergei Mironov
andauthored
[Frontend] Pass capabilities to the decomposer (#749)
**Context:** Catalyst needs to deal with operations which are not supported by the target device. This PR addresses decomposition-to-matrix branch. **Description of the Change:** We read gates to be decomposed to matrix from the device capabilities (which in turn are described in the device toml file). [sc-62011] **Benefits:** Device authors gain control over the decomposition strategies **Possible Drawbacks:** **Related GitHub Issues:** Requires #712 --------- Co-authored-by: lillian542 <[email protected]>
1 parent 2636257 commit 8c42cc9

File tree

5 files changed

+158
-47
lines changed

5 files changed

+158
-47
lines changed

frontend/catalyst/device/decomposition.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
compilation & execution on devices.
1818
"""
1919

20+
from functools import partial
21+
2022
import jax
2123
import pennylane as qml
2224
from pennylane import transform
@@ -37,15 +39,16 @@
3739
from catalyst.jax_tracer import HybridOpRegion, has_nested_tapes
3840
from catalyst.tracing.contexts import EvaluationContext
3941
from catalyst.utils.exceptions import CompileError
42+
from catalyst.utils.toml import DeviceCapabilities
4043

4144

42-
def catalyst_decomposer(op):
45+
def catalyst_decomposer(op, capabilities: DeviceCapabilities):
4346
"""A decomposer for catalyst, to be passed to the decompose transform. Takes an operator and
4447
returns the default decomposition, unless the operator should decompose to a QubitUnitary.
4548
Raises a CompileError for MidMeasureMP"""
4649
if isinstance(op, MidMeasureMP):
4750
raise CompileError("Must use 'measure' from Catalyst instead of PennyLane.")
48-
if op.name in {"MultiControlledX", "BlockEncode"} or isinstance(op, qml.ops.Controlled):
51+
if capabilities.to_matrix_ops.get(op.name) or isinstance(op, qml.ops.Controlled):
4952
return _decompose_to_matrix(op)
5053
return op.decomposition()
5154

@@ -55,7 +58,7 @@ def catalyst_decompose(
5558
tape: qml.tape.QuantumTape,
5659
ctx,
5760
stopping_condition,
58-
decomposer=catalyst_decomposer,
61+
capabilities,
5962
max_expansion=None,
6063
):
6164
"""Decompose operations until the stopping condition is met.
@@ -76,7 +79,7 @@ def catalyst_decompose(
7679
tape,
7780
stopping_condition,
7881
skip_initial_state_prep=False,
79-
decomposer=decomposer,
82+
decomposer=partial(catalyst_decomposer, capabilities=capabilities),
8083
max_expansion=max_expansion,
8184
name="catalyst on this device",
8285
error=CompileError,
@@ -85,7 +88,7 @@ def catalyst_decompose(
8588
new_ops = []
8689
for op in toplevel_tape.operations:
8790
if has_nested_tapes(op):
88-
op = _decompose_nested_tapes(op, ctx, stopping_condition, decomposer, max_expansion)
91+
op = _decompose_nested_tapes(op, ctx, stopping_condition, capabilities, max_expansion)
8992
new_ops.append(op)
9093
tape = qml.tape.QuantumScript(new_ops, tape.measurements, shots=tape.shots)
9194

@@ -103,7 +106,7 @@ def _decompose_to_matrix(op):
103106
return [op]
104107

105108

106-
def _decompose_nested_tapes(op, ctx, stopping_condition, decomposer, max_expansion):
109+
def _decompose_nested_tapes(op, ctx, stopping_condition, capabilities, max_expansion):
107110
new_regions = []
108111
for region in op.regions:
109112
if region.quantum_tape is None:
@@ -114,7 +117,7 @@ def _decompose_nested_tapes(op, ctx, stopping_condition, decomposer, max_expansi
114117
region.quantum_tape,
115118
ctx=ctx,
116119
stopping_condition=stopping_condition,
117-
decomposer=decomposer,
120+
capabilities=capabilities,
118121
max_expansion=max_expansion,
119122
)
120123
new_tape = tapes[0]

frontend/catalyst/device/qjit_device.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,12 @@ def preprocess(
393393
program = TransformProgram()
394394

395395
ops_acceptance = partial(catalyst_acceptance, operations=self.operations)
396-
program.add_transform(catalyst_decompose, ctx=ctx, stopping_condition=ops_acceptance)
396+
program.add_transform(
397+
catalyst_decompose,
398+
ctx=ctx,
399+
stopping_condition=ops_acceptance,
400+
capabilities=self.qjit_capabilities,
401+
)
397402

398403
if self.measurement_processes == {"Counts"}:
399404
program.add_transform(measurements_from_counts)

frontend/test/lit/test_decomposition.py

+57-17
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,29 @@
1515
# RUN: %PYTHON %s | FileCheck %s
1616
# pylint: disable=line-too-long
1717

18+
import platform
1819
from copy import deepcopy
1920

2021
import jax
2122
import pennylane as qml
2223

2324
from catalyst import cond, for_loop, measure, qjit, while_loop
25+
from catalyst.compiler import get_lib_path
2426
from catalyst.utils.toml import (
27+
OperationProperties,
2528
ProgramFeatures,
2629
get_device_capabilities,
2730
pennylane_operation_set,
2831
)
2932

3033

31-
def get_custom_device_without(num_wires, discards):
34+
def get_custom_device_without(num_wires, discards=frozenset(), force_matrix=frozenset()):
3235
"""Generate a custom device without gates in discards."""
3336

34-
class CustomDevice(qml.QubitDevice):
37+
class CustomDevice(qml.devices.Device):
3538
"""Custom Gate Set Device"""
3639

3740
name = "Custom Device"
38-
short_name = "lightning.qubit"
3941
pennylane_requires = "0.35.0"
4042
version = "0.0.2"
4143
author = "Tester"
@@ -55,12 +57,13 @@ def __init__(self, shots=None, wires=None):
5557
)
5658
custom_capabilities = deepcopy(lightning_capabilities)
5759
for gate in discards:
58-
if gate in custom_capabilities.native_ops:
59-
custom_capabilities.native_ops.pop(gate)
60-
if gate in custom_capabilities.to_decomp_ops:
61-
custom_capabilities.to_decomp_ops.pop(gate)
62-
if gate in custom_capabilities.to_matrix_ops:
63-
custom_capabilities.to_matrix_ops.pop(gate)
60+
custom_capabilities.native_ops.pop(gate, None)
61+
custom_capabilities.to_decomp_ops.pop(gate, None)
62+
custom_capabilities.to_matrix_ops.pop(gate, None)
63+
for gate in force_matrix:
64+
custom_capabilities.native_ops.pop(gate, None)
65+
custom_capabilities.to_decomp_ops.pop(gate, None)
66+
custom_capabilities.to_matrix_ops[gate] = OperationProperties(False, False, False)
6467
self.qjit_capabilities = custom_capabilities
6568

6669
def apply(self, operations, **kwargs):
@@ -81,12 +84,27 @@ def observables(self):
8184
"""Return PennyLane observables"""
8285
return pennylane_operation_set(self.qjit_capabilities.native_obs)
8386

87+
@staticmethod
88+
def get_c_interface():
89+
"""Returns a tuple consisting of the device name, and
90+
the location to the shared object with the C/C++ device implementation.
91+
"""
92+
system_extension = ".dylib" if platform.system() == "Darwin" else ".so"
93+
lib_path = (
94+
get_lib_path("runtime", "RUNTIME_LIB_DIR") + "/librtd_dummy" + system_extension
95+
)
96+
return "dummy.remote", lib_path
97+
98+
def execute(self, circuits, execution_config):
99+
"""Execution."""
100+
return circuits, execution_config
101+
84102
return CustomDevice(wires=num_wires)
85103

86104

87105
def test_decompose_multicontrolledx():
88106
"""Test decomposition of MultiControlledX."""
89-
dev = get_custom_device_without(5, {"MultiControlledX"})
107+
dev = get_custom_device_without(5, discards={"MultiControlledX"})
90108

91109
@qjit(target="mlir")
92110
@qml.qnode(dev)
@@ -107,7 +125,7 @@ def decompose_multicontrolled_x1(theta: float):
107125

108126
def test_decompose_multicontrolledx_in_conditional():
109127
"""Test decomposition of MultiControlledX in conditional."""
110-
dev = get_custom_device_without(5, {"MultiControlledX"})
128+
dev = get_custom_device_without(5, discards={"MultiControlledX"})
111129

112130
@qjit(target="mlir")
113131
@qml.qnode(dev)
@@ -133,7 +151,7 @@ def cond_fn():
133151

134152
def test_decompose_multicontrolledx_in_while_loop():
135153
"""Test decomposition of MultiControlledX in while loop."""
136-
dev = get_custom_device_without(5, {"MultiControlledX"})
154+
dev = get_custom_device_without(5, discards={"MultiControlledX"})
137155

138156
@qjit(target="mlir")
139157
@qml.qnode(dev)
@@ -160,7 +178,7 @@ def loop(v):
160178

161179
def test_decompose_multicontrolledx_in_for_loop():
162180
"""Test decomposition of MultiControlledX in for loop."""
163-
dev = get_custom_device_without(5, {"MultiControlledX"})
181+
dev = get_custom_device_without(5, discards={"MultiControlledX"})
164182

165183
@qjit(target="mlir")
166184
@qml.qnode(dev)
@@ -186,7 +204,7 @@ def loop(_):
186204

187205
def test_decompose_rot():
188206
"""Test decomposition of Rot gate."""
189-
dev = get_custom_device_without(1, {"Rot", "C(Rot)"})
207+
dev = get_custom_device_without(1, discards={"Rot", "C(Rot)"})
190208

191209
@qjit(target="mlir")
192210
@qml.qnode(dev)
@@ -216,7 +234,7 @@ def decompose_rot(phi: float, theta: float, omega: float):
216234

217235
def test_decompose_s():
218236
"""Test decomposition of S gate."""
219-
dev = get_custom_device_without(1, {"S", "C(S)"})
237+
dev = get_custom_device_without(1, discards={"S", "C(S)"})
220238

221239
@qjit(target="mlir")
222240
@qml.qnode(dev)
@@ -238,7 +256,7 @@ def decompose_s():
238256

239257
def test_decompose_qubitunitary():
240258
"""Test decomposition of QubitUnitary"""
241-
dev = get_custom_device_without(1, {"QubitUnitary"})
259+
dev = get_custom_device_without(1, discards={"QubitUnitary"})
242260

243261
@qjit(target="mlir")
244262
@qml.qnode(dev)
@@ -260,7 +278,7 @@ def decompose_qubit_unitary(U: jax.core.ShapedArray([2, 2], float)):
260278

261279
def test_decompose_singleexcitationplus():
262280
"""Test decomposition of single excitation plus."""
263-
dev = get_custom_device_without(2, {"SingleExcitationPlus", "C(SingleExcitationPlus)"})
281+
dev = get_custom_device_without(2, discards={"SingleExcitationPlus", "C(SingleExcitationPlus)"})
264282

265283
@qjit(target="mlir")
266284
@qml.qnode(dev)
@@ -304,3 +322,25 @@ def decompose_singleexcitationplus(theta: float):
304322

305323

306324
test_decompose_singleexcitationplus()
325+
326+
327+
def test_decompose_to_matrix():
328+
"""Test decomposition of QubitUnitary"""
329+
dev = get_custom_device_without(1, force_matrix={"PauliY"})
330+
331+
@qjit(target="mlir")
332+
@qml.qnode(dev)
333+
# CHECK-LABEL: public @jit_decompose_to_matrix
334+
def decompose_to_matrix():
335+
# CHECK: quantum.custom "PauliX"
336+
qml.PauliX(wires=0)
337+
# CHECK: quantum.unitary
338+
qml.PauliY(wires=0)
339+
# CHECK: quantum.custom "PauliZ"
340+
qml.PauliZ(wires=0)
341+
return measure(wires=0)
342+
343+
print(decompose_to_matrix.mlir)
344+
345+
346+
test_decompose_to_matrix()

frontend/test/pytest/test_config_functions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import pytest
2323

2424
from catalyst.device import QJITDevice, validate_device_capabilities
25-
from catalyst.device.qjit_device import check_no_overlap
25+
from catalyst.device.qjit_device import check_no_overlap, validate_device_capabilities
2626
from catalyst.utils.exceptions import CompileError
2727
from catalyst.utils.toml import (
2828
DeviceCapabilities,

0 commit comments

Comments
 (0)