@@ -216,14 +216,34 @@ def cached_data_wrapper_if_present(ary):
216
216
return dag
217
217
218
218
219
+ def remove_redundant_tensor_product_reshapes (ary ):
220
+ # FIXME: variable names can be more clear
221
+ if isinstance (ary , pt .Reshape ):
222
+ if isinstance (ary .array , pt .Reshape ):
223
+ if ary .array .array .shape == ary .shape :
224
+ return ary .array .array
225
+
226
+ return ary
227
+
228
+
229
+ def remove_redundant_index_lambda_expressions (ary ):
230
+ # FIXME: this can be made much more robust
231
+ if isinstance (ary , pt .IndexLambda ):
232
+ if len (ary .shape ) >= 3 :
233
+ if 0.0 in ary .expr .children :
234
+ return list (ary .bindings .values ())[0 ]
235
+
236
+ return ary
237
+
238
+
219
239
class PytatoPyOpenCLArrayContext (_PytatoPyOpenCLArrayContextBase ):
220
240
"""Inherits from :class:`meshmode.array_context.PytatoPyOpenCLArrayContext`.
221
241
Extends it to understand :mod:`grudge`-specific transform metadata. (Of
222
242
which there isn't any, for now.)
223
243
"""
224
244
225
- dot_codes_before : list [str ]
226
- dot_codes_after : list [str ]
245
+ dot_codes_before : list [str ] = []
246
+ dot_codes_after : list [str ] = []
227
247
228
248
def __init__ (self , queue , allocator = None ,
229
249
* ,
@@ -261,6 +281,12 @@ def transform_dag(self, dag):
261
281
# step 3: create new operator out of inverse mass times stiffness
262
282
dag = MassInverseTimesStiffnessSimplifier ()(dag )
263
283
284
+ dag = pt .transform .map_and_copy (
285
+ dag , remove_redundant_tensor_product_reshapes )
286
+
287
+ dag = pt .transform .map_and_copy (
288
+ dag , remove_redundant_index_lambda_expressions )
289
+
264
290
# }}}
265
291
266
292
# dag = pt.transform.materialize_with_mpms(dag)
0 commit comments