-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheff_numpy.py
79 lines (60 loc) · 2.37 KB
/
eff_numpy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# # # # # # # ###### # #
# ## # # # ## ## # # # #
# # # # # # # # # # # # # #
# # # # # # # # # ###### #
# # # # # # # # # #
# # ## # # # # # #
# # # ##### # # # #
from tracer import Tracer
from xdsl.builder import ImplicitBuilder
from stablehlo import ConstantOp, IfOp, ReturnOp
from pennylane import pytrees
def absolute(arr):
    """Return the absolute value of ``arr`` via the builtin ``abs``.

    Dispatches to ``type(arr).__abs__``, so the result type is whatever the
    operand's type produces (for tracer operands this presumably emits an
    abs op — TODO confirm against the tracer implementation).
    """
    magnitude = abs(arr)
    return magnitude
def arccos(arr):
    """Placeholder for ``arccos``; not implemented yet.

    Raises:
        NotImplementedError: always. Raised explicitly so callers fail here
            instead of receiving a silent ``None`` result.
    """
    raise NotImplementedError("eff_numpy.arccos is not implemented")
def acosh(arr):
    """Placeholder for ``acosh``; not implemented yet.

    Raises:
        NotImplementedError: always. Raised explicitly so callers fail here
            instead of receiving a silent ``None`` result.
    """
    raise NotImplementedError("eff_numpy.acosh is not implemented")
def add(x, y):
    """Return ``x + y``.

    Uses the ``+`` operator, so ``x.__add__`` / ``y.__radd__`` decide the
    result (for tracer operands this presumably emits an add op — TODO
    confirm against the tracer implementation).
    """
    total = x + y
    return total
def all(arr, axis=None, out=None, keepdims=False, *, where=None):
    """Placeholder for ``all``; not implemented yet.

    The signature mirrors ``numpy.all``. All arguments are currently ignored.

    Note: defining ``all`` at module level shadows the builtin ``all`` for
    every other function in this module. Code here that needs the builtin
    must use ``builtins.all``.

    Raises:
        NotImplementedError: always. Raised explicitly so callers fail here
            instead of receiving a silent ``None`` result.
    """
    raise NotImplementedError("eff_numpy.all is not implemented")
# # # # #
# # # # # #
# # # # # #
# # # # #
# # # ####### # #
# # # # # # #
# ##### # # # #
def _trace_branch(region, fun, operands):
    """Trace ``fun(*operands)`` into ``region``'s single block and terminate it.

    Ops emitted while calling ``fun`` are appended to ``region.blocks[0]``,
    followed by a ``ReturnOp`` over the SSA values of the flattened outputs.

    Returns:
        ``(flat, treedef)``: the flattened output leaves (tracers exposing an
        ``ssaval`` attribute) and the pytree structure that rebuilds them.
    """
    with ImplicitBuilder(region.blocks[0]):
        out = fun(*operands)
        flat, treedef = pytrees.flatten(out)
        ReturnOp(*[leaf.ssaval for leaf in flat])
    return flat, treedef


def cond(pred, true_fun, false_fun, *operands):
    """Trace a two-way conditional into a StableHLO ``IfOp``.

    Each branch function is called once with ``*operands`` while building its
    own single-block region, and each region is terminated by a ``ReturnOp``
    over that branch's flattened outputs. Both branches are traced regardless
    of ``pred``.

    Args:
        pred: predicate value, embedded as a rank-0 ``i1`` constant via
            ``DenseIntOrFPElementsAttr.from_list``. It must therefore be a
            concrete scalar, not a traced value. NOTE(review): confirm whether
            traced predicates are meant to be supported.
        true_fun: callable returning a pytree whose leaves expose ``ssaval``.
        false_fun: like ``true_fun``; must return the same pytree structure
            with leaves of the same types.
        *operands: forwarded unchanged to both branch functions.

    Returns:
        A pytree with the branches' structure whose leaves are
        ``StableHLOTracer`` objects wrapping the ``IfOp`` results.

    Raises:
        TypeError: if the branches return different pytree structures, or if
            any pair of corresponding output leaves differs in type.
    """
    # Function-local imports as in the original (presumably to avoid import
    # cycles with the tracer/type modules — TODO confirm).
    from xdsl.ir import Region, Block
    from eff_types import i1
    from eff_type_aliases import TensorType
    from xdsl.dialects.builtin import DenseIntOrFPElementsAttr
    from tracer import StableHLOTracer

    # Materialize the predicate as a rank-0 i1 tensor constant.
    pred_ty = TensorType(i1, [])
    pred_attr = DenseIntOrFPElementsAttr.from_list(pred_ty, [pred])
    pred_ssa = ConstantOp(pred_attr, pred_ty).results[0]

    true_region = Region(Block([], arg_types=[]))
    false_region = Region(Block([], arg_types=[]))
    true_flat, true_tree = _trace_branch(true_region, true_fun, operands)
    false_flat, false_tree = _trace_branch(false_region, false_fun, operands)

    # Explicit raises (not assert) so these checks also run under `python -O`;
    # a mismatch would otherwise yield an ill-formed IfOp.
    if true_tree != false_tree:
        raise TypeError(
            "cond: true_fun and false_fun must return the same pytree "
            f"structure, got {true_tree!r} and {false_tree!r}"
        )

    # Take types from the SSA values the ReturnOps actually return.
    result_tys = [leaf.ssaval.type for leaf in true_flat]
    false_tys = [leaf.ssaval.type for leaf in false_flat]
    for index, (t_ty, f_ty) in enumerate(zip(result_tys, false_tys)):
        if t_ty != f_ty:
            raise TypeError(
                f"cond: output {index} has type {t_ty} in true_fun "
                f"but {f_ty} in false_fun"
            )

    if_op = IfOp(pred_ssa, true_region, false_region, result_tys)
    out_tracers = [StableHLOTracer(res) for res in if_op.results]
    return pytrees.unflatten(out_tracers, true_tree)