From 8cc57a6feae0f770ef4b7939d80cd446a261ea66 Mon Sep 17 00:00:00 2001
From: Pratyai Mazumder <pratyai.mazumder@gmail.com>
Date: Thu, 24 Oct 2024 18:34:50 +0200
Subject: [PATCH] JUST CHECKING CI

---
 dace/libraries/standard/nodes/reduce.py       |   3 +-
 .../dataflow/redundant_array.py               | 236 +++++-------------
 2 files changed, 62 insertions(+), 177 deletions(-)

diff --git a/dace/libraries/standard/nodes/reduce.py b/dace/libraries/standard/nodes/reduce.py
index fa231c07f2..68cb45e5a7 100644
--- a/dace/libraries/standard/nodes/reduce.py
+++ b/dace/libraries/standard/nodes/reduce.py
@@ -103,7 +103,7 @@ def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG):
                 'reduce_init', {'_o%d' % i: '0:%s' % symstr(d)
                                 for i, d in enumerate(outedge.data.subset.size())}, {},
                 '__out = %s' % node.identity,
-                {'__out': dace.Memlet.simple('_out', ','.join(['_o%d' % i for i in range(output_dims)]))},
+                {'__out': dace.Memlet.simple('_out', ','.join(['_o%d' % i for i in osqdim]))},
                 external_edges=True)
         else:
             nstate = nsdfg.add_state()
diff --git a/dace/transformation/dataflow/redundant_array.py b/dace/transformation/dataflow/redundant_array.py
index 7b241ff9cd..b0b9ecc843 100644
--- a/dace/transformation/dataflow/redundant_array.py
+++ b/dace/transformation/dataflow/redundant_array.py
@@ -7,9 +7,10 @@
 from typing import Dict, List, Optional, Tuple
 
 import networkx as nx
+from dace.subsets import Range, Indices, SubrangeMapper
 from networkx.exception import NetworkXError, NodeNotFound
 
-from dace import data, dtypes
+from dace import data, dtypes, Memlet
 from dace import memlet as mm
 from dace import subsets, symbolic
 from dace.config import Config
@@ -24,7 +25,7 @@
 def _validate_subsets(edge: graph.MultiConnectorEdge,
                       arrays: Dict[str, data.Data],
                       src_name: str = None,
-                      dst_name: str = None) -> Tuple[subsets.Subset]:
+                      dst_name: str = None) -> Tuple[subsets.Subset, subsets.Subset]:
     """ Extracts and validates src and dst subsets from the edge. """
 
     # Find src and dst names
@@ -499,182 +500,72 @@ def _is_reshaping_memlet(
 
         return True
 
-    def apply(self, graph, sdfg):
-        in_array = self.in_array
-        out_array = self.out_array
-        in_desc = sdfg.arrays[in_array.data]
-        out_desc = sdfg.arrays[out_array.data]
+    def apply(self, graph: SDFGState, sdfg: SDFG):
+        """
+        Removes the redundant array ``A`` (``self.in_array``) from the pattern ``A ---> B`` by re-routing the
+        memlet paths that end in ``A`` so that they end in ``B`` (``self.out_array``) instead.
+
+        :param graph: The state that contains the matched access nodes.
+        :param sdfg: The SDFG whose ``arrays`` hold the descriptors of ``A`` and ``B``.
+        """
+        # The pattern is A ---> B, and we want to remove A
+        A, B = self.in_array, self.out_array
 
         # 1. Get edge e1 and extract subsets for arrays A and B
-        e1 = graph.edges_between(in_array, out_array)[0]
-        a1_subset, b_subset = _validate_subsets(e1, sdfg.arrays)
-
-        # View connected to a view: simple case
-        if (isinstance(in_desc, data.View) and isinstance(out_desc, data.View)):
-            simple_case = True
-            for e in graph.in_edges(in_array):
-                if e.data.dst_subset is not None and a1_subset != e.data.dst_subset:
-                    simple_case = False
-                    break
-            if simple_case:
-                for e in graph.in_edges(in_array):
-                    for e2 in graph.memlet_tree(e):
-                        if e2 is e:
-                            continue
-                        if e2.data.data == in_array.data:
-                            e2.data.data = out_array.data
-                    new_memlet = copy.deepcopy(e.data)
-                    if new_memlet.data == in_array.data:
-                        new_memlet.data = out_array.data
-                    new_memlet.dst_subset = b_subset
-                    graph.add_edge(e.src, e.src_conn, out_array, e.dst_conn, new_memlet)
-                graph.remove_node(in_array)
-                try:
-                    if in_array.data in sdfg.arrays:
-                        sdfg.remove_data(in_array.data)
-                except ValueError:  # Used somewhere else
-                    pass
-                return
-
-        # Find extraneous A or B subset dimensions
-        a_dims_to_pop = []
-        b_dims_to_pop = []
-        bset = b_subset
-        popped = []
-        if a1_subset and b_subset and a1_subset.dims() != b_subset.dims():
-            a_size = a1_subset.size_exact()
-            b_size = b_subset.size_exact()
-            if a1_subset.dims() > b_subset.dims():
-                a_dims_to_pop = find_dims_to_pop(a_size, b_size)
-            else:
-                b_dims_to_pop = find_dims_to_pop(b_size, a_size)
-                bset, popped = pop_dims(b_subset, b_dims_to_pop)
-
-        from dace.libraries.standard import Reduce
-        reduction = False
-        for e in graph.in_edges(in_array):
-            if isinstance(e.src, Reduce) or (isinstance(e.src, nodes.NestedSDFG)
-                                             and len(in_desc.shape) != len(out_desc.shape)):
-                reduction = True
-
-        # If:
-        # 1. A reduce node is involved; or
-        # 2. A NestedSDFG node is involved and the arrays have different dimensionality; or
-        # 3. The memlet does not cover the removed array; or
-        # 4. Dimensions are mismatching (all dimensions are popped);
-        # create a view.
-        if (
-                reduction
-                or len(a_dims_to_pop) == len(in_desc.shape)
-                or any(m != a for m, a in zip(a1_subset.size(), in_desc.shape))
-        ):
-            self._make_view(sdfg, graph, in_array, out_array, e1, b_subset, b_dims_to_pop)
-            return in_array
-
-        # TODO: Fix me.
-        #  As described in [issue 1595](https://github.com/spcl/dace/issues/1595) the
-        #  transformation is unable to handle certain cases of reshaping Memlets
-        #  correctly and fixing this case has proven rather difficult. In a first
-        #  attempt the case of reshaping Memlets was forbidden (in the
-        #  `can_be_applied()` method), however, this caused other (useful) cases to
-        #  fail. For that reason such Memlets are transformed to Views.
-        #  This is a fix and it should be addressed.
-        if self._is_reshaping_memlet(graph=graph, edge=e1):
-            self._make_view(sdfg, graph, in_array, out_array, e1, b_subset, b_dims_to_pop)
-            return in_array
-
-        # Validate that subsets are composable. If not, make a view
-        try:
-            for e2 in graph.in_edges(in_array):
-                path = graph.memlet_tree(e2)
-                wcr = e1.data.wcr
-                wcr_nonatomic = e1.data.wcr_nonatomic
-                for e3 in path:
-                    # 2-a. Extract subsets for array B and others
-                    other_subset, a3_subset = _validate_subsets(e3, sdfg.arrays, dst_name=in_array.data)
-                    # 2-b. Modify memlet to match array B.
-                    dname = out_array.data
-                    src_is_data = False
-                    a3_subset.offset(a1_subset, negative=True)
-
-                    if a3_subset and a_dims_to_pop:
-                        aset, _ = pop_dims(a3_subset, a_dims_to_pop)
-                    else:
-                        aset = a3_subset
-
-                    compose_and_push_back(bset, aset, b_dims_to_pop, popped)
-        except (ValueError, NotImplementedError):
-            self._make_view(sdfg, graph, in_array, out_array, e1, b_subset, b_dims_to_pop)
-            print(f"CREATED VIEW(2): {in_array}")
-            return in_array
-
-        # 2. Iterate over the e2 edges and traverse the memlet tree
-        for e2 in graph.in_edges(in_array):
-            path = graph.memlet_tree(e2)
-            wcr = e1.data.wcr
-            wcr_nonatomic = e1.data.wcr_nonatomic
-            for e3 in path:
-                # 2-a. Extract subsets for array B and others
-                other_subset, a3_subset = _validate_subsets(e3, sdfg.arrays, dst_name=in_array.data)
-                # 2-b. Modify memlet to match array B.
-                dname = out_array.data
-                src_is_data = False
-                a3_subset.offset(a1_subset, negative=True)
-
-                if a3_subset and a_dims_to_pop:
-                    aset, _ = pop_dims(a3_subset, a_dims_to_pop)
-                else:
-                    aset = a3_subset
-
-                dst_subset = compose_and_push_back(bset, aset, b_dims_to_pop, popped)
-                # NOTE: This fixes the following case:
-                # Tasklet ----> A[subset] ----> ... -----> A
-                # Tasklet is not data, so it doesn't have an other subset.
-                if isinstance(e3.src, nodes.AccessNode):
-                    if e3.src.data == out_array.data:
-                        dname = e3.src.data
-                        src_is_data = True
-                    src_subset = other_subset
-                else:
-                    src_subset = None
-
-                subset = src_subset if src_is_data else dst_subset
-                other_subset = dst_subset if src_is_data else src_subset
-                e3.data.data = dname
-                e3.data.subset = subset
-                e3.data.other_subset = other_subset
-                wcr = wcr or e3.data.wcr
-                wcr_nonatomic = wcr_nonatomic or e3.data.wcr_nonatomic
-                e3.data.wcr = wcr
-                e3.data.wcr_nonatomic = wcr_nonatomic
-
-            # 2-c. Remove edge and add new one
-            graph.remove_edge(e2)
-            e2.data.wcr = wcr
-            e2.data.wcr_nonatomic = wcr_nonatomic
-            graph.add_edge(e2.src, e2.src_conn, out_array, e2.dst_conn, e2.data)
-
-            # 2-d. Fix strides in nested SDFGs
-            if in_desc.strides != out_desc.strides:
-                sources = []
-                if path.downwards:
-                    sources = [path.root().edge]
-                else:
-                    sources = [e for e in path.leaves()]
-                for source_edge in sources:
-                    if not isinstance(source_edge.src, nodes.NestedSDFG):
-                        continue
-                    conn = source_edge.src_conn
-                    inner_desc = source_edge.src.sdfg.arrays[conn]
-                    inner_desc.strides = out_desc.strides
-
-        # Finally, remove in_array node
-        graph.remove_node(in_array)
-        try:
-            if in_array.data in sdfg.arrays:
-                sdfg.remove_data(in_array.data)
-        except ValueError:  # Already in use (e.g., with Views)
-            pass
+        e_ab = graph.edges_between(A, B)
+        assert len(e_ab) == 1
+        e_ab = e_ab[0]
+        a_subset, b_subset = _validate_subsets(e_ab, sdfg.arrays)
+        # Other cases should have been handled in `can_be_applied()`.
+        assert isinstance(a_subset, (Range, Indices))
+        assert isinstance(b_subset, (Range, Indices))
+        # And this should be self-evident.
+        assert a_subset.volume_exact() == b_subset.volume_exact()
+
+        # 2. Re-route every memlet path that ends in A so that it ends in B. The incoming edges are copied into a
+        # list first, because edges are removed from the graph inside the loop.
+        for ie in list(graph.in_edges(A)):
+            # The pattern is now: C -(ie)-> A ---> B
+            path = graph.memlet_tree(ie)
+            for pe in path:
+                # The pattern is now: C -(pe)-> C1 ---> ... ---> A ---> B
+                c_subset, a0_subset = _validate_subsets(pe, sdfg.arrays, dst_name=A.data)
+                if a0_subset is None:
+                    continue
+
+                # Other cases should have been handled already in `can_be_applied()`.
+                assert c_subset is None or isinstance(c_subset, (Range, Indices))
+                if c_subset is not None:
+                    assert c_subset.volume_exact() == a0_subset.volume_exact()
+                assert isinstance(a0_subset, (Range, Indices))
+                assert a_subset.dims() == a0_subset.dims()
+                # NOTE(review): nothing here checks that `a0_subset` lies within `a_subset`; confirm that
+                # `can_be_applied()` guarantees it.
+
+                # Find out where `a0_subset` maps to, given that `a_subset` precisely maps to `b_subset`.
+                # `reshapr` describes how `a_subset` maps to `b_subset`.
+                reshapr = SubrangeMapper(a_subset, b_subset)
+                # `b0_subset` is the mapping for `a0_subset`.
+                b0_subset = reshapr.map(a0_subset)
+                assert isinstance(b0_subset, (Range, Indices))
+                assert b0_subset.volume_exact() == a0_subset.volume_exact()
+
+                # Now we can replace the path: C -(pe)-> C1 ---> ... ---> A ---> B
+                # with an equivalent path: C -(pe)-> C1 ---> ... ---> B
+                # NOTE(review): the new memlet is built from data and subsets only; the previous implementation
+                # also propagated `wcr` / `wcr_nonatomic` -- confirm that dropping them is intended.
+                dst, dst_conn = (B, None) if pe.dst is A else (pe.dst, pe.dst_conn)
+                graph.add_edge(pe.src, pe.src_conn, dst, dst_conn,
+                               memlet=Memlet(data=B.data, subset=b0_subset, other_subset=c_subset))
+                graph.remove_edge(pe)
+
+        # 3. Finally, remove A. Dropping its descriptor is best-effort: a `ValueError` from `remove_data` is ignored.
+        graph.remove_node(A)
+        try:
+            if A.data in sdfg.arrays:
+                sdfg.remove_data(A.data)
+        except ValueError:  # Already in use (e.g., with Views)
+            pass
 
 
 class RedundantSecondArray(pm.SingleStateTransformation):