Skip to content

Commit 654dd36

Browse files
committed
remove ghost nodes left over after TP DAG xforms
1 parent 7f8fb0e commit 654dd36

File tree

1 file changed

+28
-2
lines changed

1 file changed

+28
-2
lines changed

grudge/array_context.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -216,14 +216,34 @@ def cached_data_wrapper_if_present(ary):
216216
return dag
217217

218218

219+
def remove_redundant_tensor_product_reshapes(ary):
220+
# FIXME: variable names can be more clear
221+
if isinstance(ary, pt.Reshape):
222+
if isinstance(ary.array, pt.Reshape):
223+
if ary.array.array.shape == ary.shape:
224+
return ary.array.array
225+
226+
return ary
227+
228+
229+
def remove_redundant_index_lambda_expressions(ary):
230+
# FIXME: this can be made much more robust
231+
if isinstance(ary, pt.IndexLambda):
232+
if len(ary.shape) >= 3:
233+
if 0.0 in ary.expr.children:
234+
return list(ary.bindings.values())[0]
235+
236+
return ary
237+
238+
219239
class PytatoPyOpenCLArrayContext(_PytatoPyOpenCLArrayContextBase):
220240
"""Inherits from :class:`meshmode.array_context.PytatoPyOpenCLArrayContext`.
221241
Extends it to understand :mod:`grudge`-specific transform metadata. (Of
222242
which there isn't any, for now.)
223243
"""
224244

225-
dot_codes_before: list[str]
226-
dot_codes_after: list[str]
245+
dot_codes_before: list[str] = []
246+
dot_codes_after: list[str] = []
227247

228248
def __init__(self, queue, allocator=None,
229249
*,
@@ -261,6 +281,12 @@ def transform_dag(self, dag):
261281
# step 3: create new operator out of inverse mass times stiffness
262282
dag = MassInverseTimesStiffnessSimplifier()(dag)
263283

284+
dag = pt.transform.map_and_copy(
285+
dag, remove_redundant_tensor_product_reshapes)
286+
287+
dag = pt.transform.map_and_copy(
288+
dag, remove_redundant_index_lambda_expressions)
289+
264290
# }}}
265291

266292
# dag = pt.transform.materialize_with_mpms(dag)

0 commit comments

Comments
 (0)