15
15
# RUN: %PYTHON %s | FileCheck %s
16
16
# pylint: disable=line-too-long
17
17
18
+ import platform
18
19
from copy import deepcopy
19
20
20
21
import jax
21
22
import pennylane as qml
22
23
23
24
from catalyst import cond , for_loop , measure , qjit , while_loop
25
+ from catalyst .compiler import get_lib_path
24
26
from catalyst .utils .toml import (
27
+ OperationProperties ,
25
28
ProgramFeatures ,
26
29
get_device_capabilities ,
27
30
pennylane_operation_set ,
28
31
)
29
32
30
33
31
- def get_custom_device_without (num_wires , discards ):
34
+ def get_custom_device_without (num_wires , discards = frozenset (), force_matrix = frozenset () ):
32
35
"""Generate a custom device without gates in discards."""
33
36
34
- class CustomDevice (qml .QubitDevice ):
37
+ class CustomDevice (qml .devices . Device ):
35
38
"""Custom Gate Set Device"""
36
39
37
40
name = "Custom Device"
38
- short_name = "lightning.qubit"
39
41
pennylane_requires = "0.35.0"
40
42
version = "0.0.2"
41
43
author = "Tester"
@@ -55,12 +57,13 @@ def __init__(self, shots=None, wires=None):
55
57
)
56
58
custom_capabilities = deepcopy (lightning_capabilities )
57
59
for gate in discards :
58
- if gate in custom_capabilities .native_ops :
59
- custom_capabilities .native_ops .pop (gate )
60
- if gate in custom_capabilities .to_decomp_ops :
61
- custom_capabilities .to_decomp_ops .pop (gate )
62
- if gate in custom_capabilities .to_matrix_ops :
63
- custom_capabilities .to_matrix_ops .pop (gate )
60
+ custom_capabilities .native_ops .pop (gate , None )
61
+ custom_capabilities .to_decomp_ops .pop (gate , None )
62
+ custom_capabilities .to_matrix_ops .pop (gate , None )
63
+ for gate in force_matrix :
64
+ custom_capabilities .native_ops .pop (gate , None )
65
+ custom_capabilities .to_decomp_ops .pop (gate , None )
66
+ custom_capabilities .to_matrix_ops [gate ] = OperationProperties (False , False , False )
64
67
self .qjit_capabilities = custom_capabilities
65
68
66
69
def apply (self , operations , ** kwargs ):
@@ -81,12 +84,27 @@ def observables(self):
81
84
"""Return PennyLane observables"""
82
85
return pennylane_operation_set (self .qjit_capabilities .native_obs )
83
86
87
+ @staticmethod
88
+ def get_c_interface ():
89
+ """Returns a tuple consisting of the device name, and
90
+ the location to the shared object with the C/C++ device implementation.
91
+ """
92
+ system_extension = ".dylib" if platform .system () == "Darwin" else ".so"
93
+ lib_path = (
94
+ get_lib_path ("runtime" , "RUNTIME_LIB_DIR" ) + "/librtd_dummy" + system_extension
95
+ )
96
+ return "dummy.remote" , lib_path
97
+
98
+ def execute (self , circuits , execution_config ):
99
+ """Execution."""
100
+ return circuits , execution_config
101
+
84
102
return CustomDevice (wires = num_wires )
85
103
86
104
87
105
def test_decompose_multicontrolledx ():
88
106
"""Test decomposition of MultiControlledX."""
89
- dev = get_custom_device_without (5 , {"MultiControlledX" })
107
+ dev = get_custom_device_without (5 , discards = {"MultiControlledX" })
90
108
91
109
@qjit (target = "mlir" )
92
110
@qml .qnode (dev )
@@ -107,7 +125,7 @@ def decompose_multicontrolled_x1(theta: float):
107
125
108
126
def test_decompose_multicontrolledx_in_conditional ():
109
127
"""Test decomposition of MultiControlledX in conditional."""
110
- dev = get_custom_device_without (5 , {"MultiControlledX" })
128
+ dev = get_custom_device_without (5 , discards = {"MultiControlledX" })
111
129
112
130
@qjit (target = "mlir" )
113
131
@qml .qnode (dev )
@@ -133,7 +151,7 @@ def cond_fn():
133
151
134
152
def test_decompose_multicontrolledx_in_while_loop ():
135
153
"""Test decomposition of MultiControlledX in while loop."""
136
- dev = get_custom_device_without (5 , {"MultiControlledX" })
154
+ dev = get_custom_device_without (5 , discards = {"MultiControlledX" })
137
155
138
156
@qjit (target = "mlir" )
139
157
@qml .qnode (dev )
@@ -160,7 +178,7 @@ def loop(v):
160
178
161
179
def test_decompose_multicontrolledx_in_for_loop ():
162
180
"""Test decomposition of MultiControlledX in for loop."""
163
- dev = get_custom_device_without (5 , {"MultiControlledX" })
181
+ dev = get_custom_device_without (5 , discards = {"MultiControlledX" })
164
182
165
183
@qjit (target = "mlir" )
166
184
@qml .qnode (dev )
@@ -186,7 +204,7 @@ def loop(_):
186
204
187
205
def test_decompose_rot ():
188
206
"""Test decomposition of Rot gate."""
189
- dev = get_custom_device_without (1 , {"Rot" , "C(Rot)" })
207
+ dev = get_custom_device_without (1 , discards = {"Rot" , "C(Rot)" })
190
208
191
209
@qjit (target = "mlir" )
192
210
@qml .qnode (dev )
@@ -216,7 +234,7 @@ def decompose_rot(phi: float, theta: float, omega: float):
216
234
217
235
def test_decompose_s ():
218
236
"""Test decomposition of S gate."""
219
- dev = get_custom_device_without (1 , {"S" , "C(S)" })
237
+ dev = get_custom_device_without (1 , discards = {"S" , "C(S)" })
220
238
221
239
@qjit (target = "mlir" )
222
240
@qml .qnode (dev )
@@ -238,7 +256,7 @@ def decompose_s():
238
256
239
257
def test_decompose_qubitunitary ():
240
258
"""Test decomposition of QubitUnitary"""
241
- dev = get_custom_device_without (1 , {"QubitUnitary" })
259
+ dev = get_custom_device_without (1 , discards = {"QubitUnitary" })
242
260
243
261
@qjit (target = "mlir" )
244
262
@qml .qnode (dev )
@@ -260,7 +278,7 @@ def decompose_qubit_unitary(U: jax.core.ShapedArray([2, 2], float)):
260
278
261
279
def test_decompose_singleexcitationplus ():
262
280
"""Test decomposition of single excitation plus."""
263
- dev = get_custom_device_without (2 , {"SingleExcitationPlus" , "C(SingleExcitationPlus)" })
281
+ dev = get_custom_device_without (2 , discards = {"SingleExcitationPlus" , "C(SingleExcitationPlus)" })
264
282
265
283
@qjit (target = "mlir" )
266
284
@qml .qnode (dev )
@@ -304,3 +322,25 @@ def decompose_singleexcitationplus(theta: float):
304
322
305
323
306
324
test_decompose_singleexcitationplus ()
325
+
326
+
327
+ def test_decompose_to_matrix ():
328
+ """Test decomposition of QubitUnitary"""
329
+ dev = get_custom_device_without (1 , force_matrix = {"PauliY" })
330
+
331
+ @qjit (target = "mlir" )
332
+ @qml .qnode (dev )
333
+ # CHECK-LABEL: public @jit_decompose_to_matrix
334
+ def decompose_to_matrix ():
335
+ # CHECK: quantum.custom "PauliX"
336
+ qml .PauliX (wires = 0 )
337
+ # CHECK: quantum.unitary
338
+ qml .PauliY (wires = 0 )
339
+ # CHECK: quantum.custom "PauliZ"
340
+ qml .PauliZ (wires = 0 )
341
+ return measure (wires = 0 )
342
+
343
+ print (decompose_to_matrix .mlir )
344
+
345
+
346
+ test_decompose_to_matrix ()
0 commit comments