-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstablehlo.py
385 lines (327 loc) · 13.4 KB
/
stablehlo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
# ####### ###### ####### ###### # ####### ### ####### # # #####
# # # # # # # # # # # # # # ## # # #
# # # # # # # # # # # # # # # # # #
# # # ###### ##### ###### # # # # # # # # # #####
# # # # # # # ####### # # # # # # # #
# # # # # # # # # # # # # # ## # #
# ####### # ####### # # # # # ### ####### # # #####
from collections.abc import Sequence
import unittest
from xdsl.ir import SSAValue
from xdsl.irdl import IRDLOperation, Operand, OpResult, VarOperand, VarRegion
from xdsl.irdl import irdl_op_definition, operand_def, prop_def, region_def, result_def, var_operand_def, var_region_def, var_result_def
from xdsl.dialects.builtin import DenseIntOrFPElementsAttr, IntegerAttr
from xdsl.traits import IsTerminator
from eff_type_aliases import TypeAlias
from eff_type_aliases import TensorType
from eff_type_aliases import I1
from eff_type_aliases import SI2, SI4, SI8, SI16, SI32, SI64
from eff_type_aliases import UI2, UI4, UI8, UI16, UI32, UI64
from eff_type_aliases import BFloat16Type
from eff_type_aliases import Float16Type, Float32Type, Float64Type
from eff_type_aliases import TokenType
from eff_type_aliases import StableHLOBoolean
from eff_type_aliases import StableHLOSignedInteger, StableHLOUnsignedInteger
from eff_type_aliases import StableHLOFloat
from eff_type_aliases import StableHLOComplex
from eff_type_aliases import StableHLOElementType
from eff_type_aliases import StableHLOTensor
@irdl_op_definition
class AbsOp(IRDLOperation):
    """``stablehlo.abs`` on a single tensor operand.

    Verified constraints:
      * ``shape(result) = shape(operand)``;
      * the result element type equals the operand element type, except that
        a complex operand yields its component type (e.g. ``complex<f32>`` ->
        ``f32``).
    """

    name = "stablehlo.abs"
    operand = operand_def(StableHLOTensor)
    result = result_def(StableHLOTensor)

    def __init__(self, operand, result_ty):
        super().__init__(operands=(operand,),
                         result_types=(result_ty,))

    def verify_(self):
        """Check the shape and element-type constraints.

        Raises:
            AssertionError: if either constraint is violated. An explicit
                ``raise`` is used rather than ``assert`` so verification still
                runs under ``python -O``.
        """
        # Local import: only needed to recognise complex element types.
        from xdsl.dialects.builtin import ComplexType

        operandty = self.operand.type
        resultty = self.result.type

        if resultty.shape != operandty.shape:
            msg1 = f'''{AbsOp.name} has constraint "shape(result) = shape(operand)"'''
            msg1 += (f' instead has mismatched shape(result) = {resultty.get_shape()}'
                     f' and shape(operand) = {operandty.get_shape()}')
            raise AssertionError(msg1)

        operandety = operandty.element_type
        resultety = resultty.element_type
        # complex<T> operands produce T results; any other element type is kept.
        if isinstance(operandety, ComplexType):
            expectedety = operandety.element_type
        else:
            expectedety = operandety
        if resultety != expectedety:
            msg2 = f"{AbsOp.name} has constraint"
            msg2 += " baseline_element_type(result) is defined as:"
            msg2 += " complex_element_type(element_type(operand)) if is_complex(operand)"
            msg2 += " baseline_element_type(operand) otherwise."
            raise AssertionError(msg2)
class TestAbsOp(unittest.TestCase):
    """Printing and verifier tests for ``AbsOp``."""

    @staticmethod
    def _si64_constant():
        """Build a ``tensor<1xsi64>`` type and a ConstantOp holding ``[1]``."""
        from eff_types import si64
        tensor_ty = TensorType(si64, [1])
        dense = DenseIntOrFPElementsAttr.from_list(tensor_ty, [1])
        return tensor_ty, ConstantOp(dense, tensor_ty)

    def test_abs_op(self):
        tensor_ty, const = self._si64_constant()
        printed = str(AbsOp(const, tensor_ty))
        assert printed == """%0 = "stablehlo.abs"(%1) : (tensor<1xsi64>) -> tensor<1xsi64>"""

    def test_constraint_same_shapes(self):
        from eff_types import si64
        _, const = self._si64_constant()
        # The result has two elements while the operand has one.
        bad = AbsOp(const, TensorType(si64, [2]))
        with self.assertRaises(AssertionError):
            bad.verify_()

    def test_constraint_same_etype(self):
        from eff_types import si32
        _, const = self._si64_constant()
        # Same shape, but an si32 result for an si64 operand.
        bad = AbsOp(const, TensorType(si32, [1]))
        with self.assertRaises(AssertionError):
            bad.verify_()
@irdl_op_definition
class AddOp(IRDLOperation):
    """``stablehlo.add`` on two tensor operands.

    Only the non-quantized constraint C1 is verified:
    ``type(lhs) = type(rhs) = type(result)``.
    """

    name = "stablehlo.add"
    lhs = operand_def(StableHLOTensor)
    rhs = operand_def(StableHLOTensor)
    result = result_def(StableHLOTensor)

    def __init__(self, lhs, rhs, result):
        super().__init__(operands=(lhs, rhs),
                         result_types=(result,))

    @staticmethod
    def msgC1():
        """Return the diagnostic text for constraint C1."""
        msg = 'If the operation uses non-quantized tensors:'
        msg += ' type(lhs) = type(rhs) = type(result)'
        return msg

    def C1(self):
        """Check ``type(lhs) = type(rhs) = type(result)``.

        Raises:
            AssertionError: if any of the three types differ.
        """
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        if self.lhs.type != self.rhs.type or self.lhs.type != self.result.type:
            raise AssertionError(AddOp.msgC1())

    def verify_(self):
        self.C1()
class TestAddOp(unittest.TestCase):
    """Printing and constraint-C1 tests for ``AddOp``."""

    def test_add_op(self):
        from eff_types import si64
        tensor_ty = TensorType(si64, [1])
        dense = DenseIntOrFPElementsAttr.from_list(tensor_ty, [1])
        const = ConstantOp(dense, tensor_ty)
        printed = str(AddOp(const, const, tensor_ty))
        assert printed == '%0 = "stablehlo.add"(%1, %1) : (tensor<1xsi64>, tensor<1xsi64>) -> tensor<1xsi64>'

    def test_bad_type(self):
        from eff_types import si64
        one_elem = TensorType(si64, [1])
        two_elem = TensorType(si64, [2])
        dense = DenseIntOrFPElementsAttr.from_list(one_elem, [1])
        # lhs and rhs disagree on shape, so C1 must fail.
        lhs = ConstantOp(dense, one_elem)
        rhs = ConstantOp(dense, two_elem)
        add = AddOp(lhs, rhs, one_elem)
        with self.assertRaises(AssertionError, msg=AddOp.msgC1()):
            add.C1()
@irdl_op_definition
class AfterAllOp(IRDLOperation):
    """``stablehlo.after_all``: variadic token operands, one token result."""

    name = "stablehlo.after_all"
    inputs = var_operand_def(TokenType)
    result = result_def(TokenType)

    def __init__(self, inputs, result):
        # The variadic ``inputs`` sequence forms the op's single operand group.
        operand_groups = [inputs]
        super().__init__(result_types=(result,), operands=operand_groups)
class TestAfterAllOp(unittest.TestCase):
    """Printing test for ``AfterAllOp`` over two token block arguments."""

    def test_after_all(self):
        from xdsl.ir import Block
        token = TokenType()
        holder = Block([], arg_types=[token] * 2)
        printed = str(AfterAllOp(holder.args, token))
        assert printed == '%0 = "stablehlo.after_all"(%1, %2) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token'
ReplicaGroupsType : TypeAlias = TensorType[SI64]


@irdl_op_definition
class AllGatherOp(IRDLOperation):
    """``stablehlo.all_gather`` over a variadic list of tensor operands.

    Its properties are ``all_gather_dim``, ``replica_groups``, ``channel_id``
    and ``use_global_device_ids``. Only constraint C1 is verified.
    """

    name = "stablehlo.all_gather"
    inputs: VarOperand = var_operand_def(StableHLOTensor)
    all_gather_dim = prop_def(SI64)
    replica_groups = prop_def(ReplicaGroupsType)
    channel_id = prop_def(SI64)
    use_global_device_ids = prop_def(I1)
    result = var_result_def(StableHLOTensor)

    def __init__(self, operands : Sequence[SSAValue], all_gather_dim, replica_groups, channel_id, use_global_device_ids, result):
        properties = {"all_gather_dim" : all_gather_dim,
                      "replica_groups" : replica_groups,
                      "channel_id" : channel_id,
                      "use_global_device_ids" : use_global_device_ids}
        super().__init__(operands=(operands,),
                         result_types=(result,),
                         properties=properties)

    @staticmethod
    def msgC1():
        """Return the diagnostic text for constraint C1."""
        return "0 <= all_gather_dim < rank(operands...)"

    def C1(self):
        """Check ``0 <= all_gather_dim < rank(operand)`` for every operand.

        NOTE(review): assumes the ``all_gather_dim`` property is an
        ``IntegerAttr`` (``.value.data`` is its Python int) — confirm against
        the ``SI64`` constraint in ``eff_type_aliases``.

        Raises:
            AssertionError: if the dimension is out of range for any operand.
        """
        dim = self.all_gather_dim.value.data
        for operand in self.inputs:
            rank = len(operand.type.get_shape())
            if not 0 <= dim < rank:
                raise AssertionError(
                    f"{AllGatherOp.msgC1()}; got all_gather_dim = {dim}, rank = {rank}"
                )

    def verify_(self):
        self.C1()
class TestAllGatherOp(unittest.TestCase):
    """Construction test for ``AllGatherOp``."""

    def test_all_gather_op(self):
        from eff_types import i1, si64
        ty = TensorType(si64, [1])
        val = DenseIntOrFPElementsAttr.from_list(ty, [1])
        c1 = ConstantOp(val, ty)
        si_c1 = IntegerAttr(1, si64)
        i1_c1 = IntegerAttr(1, i1)
        all_gather_op = AllGatherOp([c1], si_c1, val, si_c1, i1_c1, [ty])
        # Check that operands, results and properties are wired as passed.
        self.assertEqual(len(all_gather_op.inputs), 1)
        self.assertIs(all_gather_op.inputs[0], c1.output)
        self.assertEqual(len(all_gather_op.result), 1)
        self.assertEqual(all_gather_op.result[0].type, ty)
        self.assertEqual(all_gather_op.all_gather_dim, si_c1)
        self.assertEqual(all_gather_op.use_global_device_ids, i1_c1)
# ALL_REDUCE
# ALL_TO_ALL
TensorIntegerType : TypeAlias = TensorType[StableHLOBoolean | StableHLOSignedInteger | StableHLOUnsignedInteger]


@irdl_op_definition
class AndOp(IRDLOperation):
    """``stablehlo.and`` on boolean or integer tensors.

    Verified constraint C1: ``type(lhs) = type(rhs) = type(result)``.
    """

    name = "stablehlo.and"
    lhs : Operand = operand_def(TensorIntegerType)
    rhs : Operand = operand_def(TensorIntegerType)
    result : OpResult = result_def(TensorIntegerType)

    def __init__(self, lhs, rhs, result):
        super().__init__(operands=(lhs, rhs),
                         result_types=(result,))

    @staticmethod
    def msgC1():
        """Return the diagnostic text for constraint C1."""
        return f'{AndOp.name} has constraint "type(lhs) = type(rhs) = type(result)"'

    def C1(self):
        """Check ``type(lhs) = type(rhs) = type(result)``.

        Raises:
            AssertionError: if the types differ. An explicit raise is used so
                the check also runs under ``python -O``.
        """
        if self.lhs.type != self.rhs.type or self.rhs.type != self.result.type:
            raise AssertionError(AndOp.msgC1())

    def verify_(self):
        self.C1()
class TestAndOp(unittest.TestCase):
    """Printing and verifier test for ``AndOp``."""

    def test_and_op(self):
        from eff_types import si64
        ty = TensorType(si64, [1])
        val = DenseIntOrFPElementsAttr.from_list(ty, [1])
        c1 = ConstantOp(val, ty)
        and_op = AndOp(c1, c1, ty)
        # ``assert and_op`` would only test truthiness; check the printed form
        # and run the custom verifier instead.
        expected = '%0 = "stablehlo.and"(%1, %1) : (tensor<1xsi64>, tensor<1xsi64>) -> tensor<1xsi64>'
        self.assertEqual(str(and_op), expected)
        and_op.verify_()
TensorFloatAndComplex : TypeAlias = TensorType[StableHLOFloat | StableHLOComplex]


@irdl_op_definition
class Atan2Op(IRDLOperation):
    """``stablehlo.atan2`` on float or complex tensors.

    Verified constraint C1: ``type(lhs) = type(rhs) = type(result)``.
    """

    name = "stablehlo.atan2"
    lhs = operand_def(TensorFloatAndComplex)
    rhs = operand_def(TensorFloatAndComplex)
    result = result_def(TensorFloatAndComplex)

    def __init__(self, lhs, rhs, result):
        super().__init__(operands=(lhs, rhs),
                         result_types=(result,))

    @staticmethod
    def msgC1():
        """Return the diagnostic text for constraint C1."""
        return f'{Atan2Op.name} has constraint "type(lhs) = type(rhs) = type(result)"'

    def C1(self):
        """Check ``type(lhs) = type(rhs) = type(result)``.

        Raises:
            AssertionError: if the types differ. An explicit raise is used so
                the check also runs under ``python -O``.
        """
        if self.lhs.type != self.rhs.type or self.rhs.type != self.result.type:
            raise AssertionError(Atan2Op.msgC1())

    def verify_(self):
        self.C1()
class TestAtan2Op(unittest.TestCase):
    """Printing and verifier test for ``Atan2Op``."""

    def test_atan2_op(self):
        from eff_types import f64
        ty = TensorType(f64, [1])
        val = DenseIntOrFPElementsAttr.from_list(ty, [1])
        c1 = ConstantOp(val, ty)
        atan2_op = Atan2Op(c1, c1, ty)
        # ``assert atan2_op`` would only test truthiness; check the printed
        # form and run the custom verifier instead.
        expected = '%0 = "stablehlo.atan2"(%1, %1) : (tensor<1xf64>, tensor<1xf64>) -> tensor<1xf64>'
        self.assertEqual(str(atan2_op), expected)
        atan2_op.verify_()
# BATCH_NORM_GRAD
# BATCH_NORM_INFERENCE
# BATCH_NORM_TRAINING
@irdl_op_definition
class BitcastConvertOp(IRDLOperation):
    """``stablehlo.bitcast_convert``: one tensor ``input``, one tensor ``result``.

    No ``verify_`` override is defined, so only the IRDL operand/result
    constraints apply. TODO(review): the op-specific constraints are not checked.
    """

    name = "stablehlo.bitcast_convert"
    input = operand_def(StableHLOTensor)
    result = result_def(StableHLOTensor)

    def __init__(self, input, result):
        super().__init__(
            result_types=(result,),
            operands=(input,),
        )
class TestBitcastConvertOp(unittest.TestCase):
    """Printing test for ``BitcastConvertOp``."""

    def test_bitcast_convert_op(self):
        from eff_types import f64
        ty = TensorType(f64, [1])
        val = DenseIntOrFPElementsAttr.from_list(ty, [1])
        c1 = ConstantOp(val, ty)
        op = BitcastConvertOp(c1, ty)
        # ``assert op`` would only test truthiness; check the printed form.
        expected = '%0 = "stablehlo.bitcast_convert"(%1) : (tensor<1xf64>) -> tensor<1xf64>'
        self.assertEqual(str(op), expected)
# BROADCAST_IN_DIM
TensorSI32 : TypeAlias = TensorType[SI32]


@irdl_op_definition
class CaseOp(IRDLOperation):
    """``stablehlo.case`` op.

    Takes an ``si32`` tensor ``index`` operand and a variadic list of
    single-block ``branches`` regions. It produces variadic tensor or token
    ``results``. No custom verification is defined.
    """

    name = "stablehlo.case"
    index = operand_def(TensorSI32)
    branches : VarRegion = var_region_def("single_block")
    results = var_result_def(StableHLOTensor | TokenType)

    def __init__(self, index, branches, results):
        # Each variadic argument is passed as a single group.
        super().__init__(
            regions=(branches,),
            operands=(index,),
            result_types=(results,),
        )
@irdl_op_definition
class ConvertOp(IRDLOperation):
    """``stablehlo.convert``: maps an ``input`` tensor to a ``result`` tensor.

    Verified constraint: ``shape(result) = shape(operand)``.
    """

    name = "stablehlo.convert"
    input = operand_def(StableHLOTensor)
    result = result_def(StableHLOTensor)

    def __init__(self, input, result):
        super().__init__(operands=(input,),
                         result_types=(result,))

    def verify_(self):
        """Check that input and result shapes match.

        Raises:
            AssertionError: if the shapes differ. An explicit raise is used so
                the check also runs under ``python -O``.
        """
        inputty = self.input.type
        resultty = self.result.type
        if inputty.shape != resultty.shape:
            raise AssertionError(
                f'{ConvertOp.name} has constraint "shape(result) = shape(operand)"'
                f' instead has mismatched shape(result) = {resultty.get_shape()}'
                f' and shape(operand) = {inputty.get_shape()}'
            )
TensorI1 : TypeAlias = TensorType[I1]


@irdl_op_definition
class IfOp(IRDLOperation):
    """``stablehlo.if`` op.

    Takes an ``i1`` tensor ``pred`` operand and two single-block regions,
    ``true_branch`` and ``false_branch``. It produces variadic tensor or token
    ``outputs``. No custom verification is defined.
    """

    name = "stablehlo.if"
    pred = operand_def(TensorI1)
    true_branch = region_def("single_block")
    false_branch = region_def("single_block")
    outputs = var_result_def(StableHLOTensor | TokenType)

    def __init__(self, pred, true_branch, false_branch, results):
        both_branches = (true_branch, false_branch)
        super().__init__(operands=(pred,), result_types=(results,), regions=both_branches)
@irdl_op_definition
class ReturnOp(IRDLOperation):
    """``stablehlo.return``: terminator op (``IsTerminator`` trait).

    Takes a variadic list of operands. Tokens are accepted alongside tensors
    so that a region of ``CaseOp`` / ``IfOp`` can be terminated. Those ops
    declare ``StableHLOTensor | TokenType`` results.
    """

    name = "stablehlo.return"
    input = var_operand_def(StableHLOTensor | TokenType)
    traits = frozenset([IsTerminator()])

    def __init__(self, input):
        super().__init__(operands=(input,))
@irdl_op_definition
class ConstantOp(IRDLOperation):
    """``stablehlo.constant``: a dense literal held in the ``value`` property.

    The single result ``output`` must have the same type as ``value``.
    """

    name = "stablehlo.constant"
    value = prop_def(DenseIntOrFPElementsAttr)
    output = result_def(StableHLOTensor)

    def __init__(self, value, tensor_type):
        super().__init__(
            properties={"value": value},
            result_types=(tensor_type,),
        )

    def verify_(self):
        types_agree = self.value.type == self.output.type
        assert types_agree, f'''{ConstantOp.name} has constraint "type(value) = type(output)"'''
class TestConstantOp(unittest.TestCase):
    """Printing and verifier tests for ``ConstantOp``."""

    def test_constant_op(self):
        from eff_types import si64
        ty = TensorType(si64, [1])
        val = DenseIntOrFPElementsAttr.from_list(ty, [1])
        expected = """%0 = "stablehlo.constant"() <{"value" = dense<1> : tensor<1xsi64>}> : () -> tensor<1xsi64>"""
        observed = str(ConstantOp(val, ty))
        self.assertEqual(expected, observed)

    def test_raises_error(self):
        """A result type that differs from the value's type must fail verify_."""
        from eff_types import si64
        ty = TensorType(si64, [1])
        val = DenseIntOrFPElementsAttr.from_list(ty, [1])
        # Deliberate mismatch: value is tensor<1xsi64>, result is tensor<2xsi64>.
        ty2 = TensorType(si64, [2])
        op = ConstantOp(val, ty2)
        with self.assertRaises(AssertionError):
            op.verify_()
# ###### ### # # ####### ##### #######
# # # # # # # # # # #
# # # # # # # # # #
# # # # # # # ##### # #
# # # # ####### # # # #
# # # # # # # # # # #
# ###### ### # # ####### ####### ##### #
from xdsl.ir import Dialect

# The "stablehlo" dialect: every operation class defined in this module
# (including IfOp, which was previously left out) plus the token type.
StableHLO = Dialect(
    "stablehlo",
    [
        AbsOp,
        AddOp,
        AfterAllOp,
        AllGatherOp,
        AndOp,
        Atan2Op,
        BitcastConvertOp,
        CaseOp,
        ConstantOp,
        ConvertOp,
        IfOp,
        ReturnOp,
    ],
    [
        TokenType,
    ],
)
if __name__ == "__main__":
    # Executing the module directly runs the unittest cases defined above.
    unittest.main()