diff --git a/.gitignore b/.gitignore index 03c801a68f..c225aacf3f 100644 --- a/.gitignore +++ b/.gitignore @@ -143,6 +143,8 @@ src.VC.VC.opendb .dacecache/ # Ignore dacecache if added as a symlink .dacecache +# Local configuration file +.dace.conf out.sdfg *.out results.log diff --git a/dace/codegen/tools/type_inference.py b/dace/codegen/tools/type_inference.py index 26b369fa9d..6c64688f5b 100644 --- a/dace/codegen/tools/type_inference.py +++ b/dace/codegen/tools/type_inference.py @@ -420,6 +420,8 @@ def _Attribute(t, symbols, inferred_symbols): if (isinstance(inferred_type, dtypes.pointer) and isinstance(inferred_type.base_type, dtypes.struct) and t.attr in inferred_type.base_type.fields): return inferred_type.base_type.fields[t.attr] + elif isinstance(inferred_type, dtypes.struct) and t.attr in inferred_type.fields: + return inferred_type.fields[t.attr] return inferred_type diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 62023f1d1b..2d65a82423 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -3287,11 +3287,82 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): if isinstance(node.value, (ast.Tuple, ast.List)): for n in node.value.elts: results.extend(self._gettype(n)) + elif isinstance(node.value, ast.Name) and node.value.id in self.sdfg.arrays and isinstance( + self.sdfg.arrays[node.value.id], + data.Array) and self.sdfg.arrays[node.value.id].total_size == len(elts): + # In the case where the rhs is an array (not being accessed with a slice) of exactly the same length as the + # number of elements in the lhs, the array can be expanded with a series of slice/subscript accesses to + # constant indexes (according to the number of elements in the lhs). These expansions can then be used to + # perform an unpacking assignment, similar to what Python does natively. + for i in range(len(elts)): + const_node = NumConstant(i) + ast.copy_location(const_node, node) + slice_node = ast.Subscript(node.value, const_node, node.value.ctx) + ast.copy_location(slice_node, node) + results.extend(self._gettype(slice_node)) else: results.extend(self._gettype(node.value)) if len(results) != len(elts): - raise DaceSyntaxError(self, node, 'Function returns %d values but %d provided' % (len(results), len(elts))) + if len(elts) == 1 and len(results) > 1 and isinstance(elts[0], ast.Name): + # If multiple results are being assigned to one element, attempt to perform a packing assignment, + # i.e., similar to Python. This constructs a tuple / array of the correct size for the lhs according to + # the number of elements on the rhs, and then assigns to individual array / tuple positions using the + # correct slice accesses. If the datacontainer on the lhs is not defined yet, it is created here. + # If it already exists, only succeed if the size matches the number of elements on the rhs, similar to + # Python. All elements on the rhs must have a common datatype for this to work. + elt = elts[0] + desc = None + if elt.id in self.sdfg.arrays: + desc = self.sdfg.arrays[elt.id] + if desc is not None and not isinstance(desc, data.Array): + raise DaceSyntaxError( + self, node, 'Cannot assign %d function return values to %s due to incompatible type' % + (len(results), elt.id)) + elif desc is not None and desc.total_size != len(results): + raise DaceSyntaxError( + self, node, 'Cannot assign %d function return values to a data container of size %s' % + (len(results), str(desc.total_size))) + + # Determine the result data type and make sure there is only one. + res_dtype = None + for res, _ in results: + if not (isinstance(res, str) and res in self.sdfg.arrays): + res_dtype = None + break + res_data = self.sdfg.arrays[res] + if res_dtype is None: + res_dtype = res_data.dtype + elif res_dtype != res_data.dtype: + res_dtype = None + break + if res_dtype is None: + raise DaceSyntaxError( + self, node, + 'Cannot determine common result datatype for %d function return values' % (len(results))) + + res_name = elt.id + if desc is None: + # If no data container exists yet, create it. + res_name, desc = self.sdfg.add_transient(res_name, (len(results), ), res_dtype) + self.variables[res_name] = res_name + + # Create the correct slice accesses. + new_elts = [] + for i in range(len(results)): + name_node = ast.Name(res_name, elt.ctx) + ast.copy_location(name_node, elt) + const_node = NumConstant(i) + ast.copy_location(const_node, elt) + slice_node = ast.Subscript(name_node, const_node, elt.ctx) + ast.copy_location(slice_node, elt) + new_elts.append(slice_node) + + elts = new_elts + else: + raise DaceSyntaxError( + self, node, + 'Function returns %d values but assigning to %d expected values' % (len(results), len(elts))) defined_vars = {**self.variables, **self.scope_vars} defined_arrays = dace.sdfg.NestedDict({**self.sdfg.arrays, **self.scope_arrays}) diff --git a/dace/memlet.py b/dace/memlet.py index 85bd0a348d..ae09d4da43 100644 --- a/dace/memlet.py +++ b/dace/memlet.py @@ -147,8 +147,20 @@ def __init__(self, @staticmethod def from_memlet(memlet: 'Memlet') -> 'Memlet': - sbs = subsets.Range(memlet.subset.ndrange()) if memlet.subset is not None else None - osbs = subsets.Range(memlet.other_subset.ndrange()) if memlet.other_subset is not None else None + if memlet.subset is not None: + if isinstance(memlet.subset, subsets.SubsetUnion): + sbs = subsets.SubsetUnion(memlet.subset.subset_list) + else: + sbs = subsets.Range(memlet.subset.ndrange()) + else: + sbs = None + if memlet.other_subset is not None: + if isinstance(memlet.other_subset, subsets.SubsetUnion): + osbs = subsets.SubsetUnion(memlet.other_subset.subset_list) + else: + osbs = subsets.Range(memlet.other_subset.ndrange()) + else: + osbs = None result = Memlet(data=memlet.data, subset=sbs, other_subset=osbs, diff --git a/dace/properties.py b/dace/properties.py index 82be72f9fd..ee5170ecbd 100644 --- a/dace/properties.py +++ b/dace/properties.py @@ -1116,6 +1116,8 @@ def to_string(val): return sbs.Range.ndslice_to_string(val) elif isinstance(val, sbs.Indices): return sbs.Indices.__str__(val) + elif isinstance(val, sbs.SubsetUnion): + return sbs.SubsetUnion.__str__(val) elif val is None: return 'None' raise TypeError diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index 205222e8b2..77bedaa0c7 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -5,6 +5,8 @@ import ast from copy import deepcopy as dcpy from collections.abc import KeysView + +from matplotlib import category import dace import itertools import dace.serialize @@ -275,9 +277,11 @@ class AccessNode(Node): instrument = EnumProperty(dtype=dtypes.DataInstrumentationType, desc="Instrument data contents at this access", - default=dtypes.DataInstrumentationType.No_Instrumentation) + default=dtypes.DataInstrumentationType.No_Instrumentation, + category='Instrumentation') instrument_condition = CodeProperty(desc="Condition under which to trigger the instrumentation", - default=CodeBlock("1", language=dtypes.Language.CPP)) + default=CodeBlock("1", language=dtypes.Language.CPP), + category='Instrumentation') def __init__(self, data, debuginfo=None): super(AccessNode, self).__init__() diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 2983ec3c63..d8c811f431 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -747,7 +747,7 @@ def propagate_states(sdfg: 'SDFG', concretize_dynamic_unbounded: bool = False) - if sdfg.using_explicit_control_flow: # Avoid cyclic imports from dace.transformation.pass_pipeline import Pipeline - from dace.transformation.passes.analysis import StatePropagation + from dace.transformation.passes.analysis.propagation import StatePropagation state_prop_pipeline = Pipeline([StatePropagation()]) state_prop_pipeline.apply_pass(sdfg, {}) @@ -1456,24 +1456,48 @@ def propagate_subset(memlets: List[Memlet], else: subset = md.subset - for pclass in MemletPattern.extensions(): - pattern = pclass() - if pattern.can_be_applied([subset], variable_context, rng, [md]): - tmp_subset = pattern.propagate(arr, [subset], rng) - break + if isinstance(subset, subsets.SubsetUnion): + tmp_subset_list = [] + for sub in subset.subset_list: + for pclass in MemletPattern.extensions(): + pattern = pclass() + if pattern.can_be_applied([sub], variable_context, rng, [md]): + sub_tmp_subset = pattern.propagate(arr, [sub], rng) + break + else: + # No patterns found. Emit a warning and propagate the entire + # array whenever symbols are used + warnings.warn('Cannot find appropriate memlet pattern to ' + 'propagate %s through %s' % (str(sub), str(rng))) + entire_array = subsets.Range.from_array(arr) + paramset = set(map(str, params)) + # Fill in the entire array only if one of the parameters appears in the + # free symbols list of the subset dimension + sub_tmp_subset = subsets.Range([ + ea if any(set(map(str, _freesyms(sd))) & paramset for sd in s) else s + for s, ea in zip(sub, entire_array) + ]) + tmp_subset_list.append(sub_tmp_subset) + tmp_subset = subsets.SubsetUnion(tmp_subset_list) else: - # No patterns found. Emit a warning and propagate the entire - # array whenever symbols are used - warnings.warn('Cannot find appropriate memlet pattern to ' - 'propagate %s through %s' % (str(subset), str(rng))) - entire_array = subsets.Range.from_array(arr) - paramset = set(map(str, params)) - # Fill in the entire array only if one of the parameters appears in the - # free symbols list of the subset dimension - tmp_subset = subsets.Range([ - ea if any(set(map(str, _freesyms(sd))) & paramset for sd in s) else s - for s, ea in zip(subset, entire_array) - ]) + for pclass in MemletPattern.extensions(): + pattern = pclass() + if pattern.can_be_applied([subset], variable_context, rng, [md]): + tmp_subset = pattern.propagate(arr, [subset], rng) + break + else: + # No patterns found. Emit a warning and propagate the entire + # array whenever symbols are used + warnings.warn('Cannot find appropriate memlet pattern to ' + 'propagate %s through %s' % (str(subset), str(rng))) + entire_array = subsets.Range.from_array(arr) + paramset = set(map(str, params)) + # Fill in the entire array only if one of the parameters appears in the + # free symbols list of the subset dimension + tmp_subset = subsets.Range([ + ea if any(set(map(str, _freesyms(sd))) & paramset for sd in s) else s + for s, ea in zip(subset, entire_array) + ]) # Union edges as necessary if new_subset is None: diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 09b2325d1c..f099625be4 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import ast import collections import copy diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 30640306cd..e439e0369f 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -1208,6 +1208,11 @@ class ControlFlowBlock(BlockGraphView, abc.ABC): _sdfg: Optional['SDFG'] = None _parent_graph: Optional['ControlFlowRegion'] = None + certain_reads = DictProperty(key_type=str, value_type=mm.Memlet) + possible_reads = DictProperty(key_type=str, value_type=mm.Memlet) + certain_writes = DictProperty(key_type=str, value_type=mm.Memlet) + possible_writes = DictProperty(key_type=str, value_type=mm.Memlet) + def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None, parent: Optional['ControlFlowRegion'] = None): super(ControlFlowBlock, self).__init__() self._label = label @@ -1218,6 +1223,10 @@ def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None, parent: Optio self.pre_conditions = {} self.post_conditions = {} self.invariant_conditions = {} + self.certain_reads = {} + self.possible_reads = {} + self.certain_writes = {} + self.possible_writes = {} self.guid = generate_element_id(self) diff --git a/dace/subsets.py b/dace/subsets.py index 0fdc36c22e..a176ecf958 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -1,6 +1,7 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +from copy import deepcopy import dace.serialize -from dace import data, symbolic, dtypes +from dace import symbolic import re import sympy as sp from functools import reduce @@ -106,7 +107,7 @@ def covers_precise(self, other): # Subsets of different dimensionality can never cover each other. if self.dims() != other.dims(): return ValueError( - f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}" + f"A subset of dimensionality {self.dims()} cannot test covering a subset of dimensionality {other.dims()}" ) # If self does not cover other with a bounding box union, return false. @@ -820,35 +821,107 @@ def replace(self, repl_dict): rs.subs(repl_dict) if symbolic.issymbolic(rs) else rs) self.tile_sizes[i] = (ts.subs(repl_dict) if symbolic.issymbolic(ts) else ts) - def intersects(self, other: 'Range'): + def intersection(self, other: 'Range') -> Optional['Range']: type_error = False + expected_length = len(self.ranges) + if expected_length != len(other.ranges): + raise ValueError('Unable to compute the intersection of different length ranges.') + + intersected_ranges = [] for i, (rng, orng) in enumerate(zip(self.ranges, other.ranges)): if (rng[2] != 1 or orng[2] != 1 or self.tile_sizes[i] != 1 or other.tile_sizes[i] != 1): # TODO: This function does not consider strides or tiles - return None + raise NotImplementedError('^This function does not yet consider strides or tiles') # Special case: ranges match - if rng[0] == orng[0] or rng[1] == orng[1]: + if rng[0] == orng[0] and rng[1] == orng[1]: + intersected_ranges.append(rng) continue # Since conditions can be indeterminate, we check them separately # for being False, then make a check that may raise a TypeError cond1 = (rng[0] <= orng[1]) cond2 = (orng[0] <= rng[1]) + cond3 = (rng[0] <= orng[0]) + cond4 = (rng[1] <= orng[1]) # NOTE: We have to use the "==" operator because of SymPy returning # a special boolean type! try: if cond1 == False or cond2 == False: - return False + return None if not (cond1 and cond2): - return False + return None + + if cond3 == True: + rng_start = orng[0] + else: + rng_start = rng[0] + if cond4 == True: + rng_end = rng[1] + else: + rng_end = orng[1] + intersected_ranges.append([rng_start, rng_end, rng[2]]) except TypeError: # cannot determine truth value of Relational type_error = True if type_error: raise TypeError("cannot determine truth value of Relational") - return True + if len(intersected_ranges) != expected_length: + return None + + return Range(intersected_ranges) + + def intersects(self, other: 'Range'): + return self.intersection(other) is not None + + def difference(self, other: 'Range') -> Subset: + isect = self.intersection(other) + if isect is None: + return self + diff_ranges = [[]] + dims_cleared = [] + for i, (r1, r2) in enumerate(zip(self.ranges, isect.ranges)): + if r2[0] == r1[0]: + # Intersection over the start of the current range. + if r2[1] == r1[1]: + for dr in diff_ranges: + dr.append(r1) + dims_cleared.append(True) + else: + for dr in diff_ranges: + dr.append((r2[1] + self.ranges[i][2], r1[1], r1[2])) + dims_cleared.append(False) + elif r2[1] == r1[1]: + # Intersection over the end of the current range. + if r2[0] == r1[0]: + for dr in diff_ranges: + dr.append(r1) + dims_cleared.append(True) + else: + for dr in diff_ranges: + dr.append((r1[0], r2[0] - self.ranges[i][2], r1[2])) + dims_cleared.append(False) + else: + # Intersection completely contained inside the current range, split into subset union is necessary. + split_left = (r1[0], r2[0] - self.ranges[i][2], r1[2]) + split_right = (r2[1] + self.ranges[i][2], r1[1], r1[2]) + for dr in diff_ranges: + dr.append(split_left) + dr_copy = deepcopy(diff_ranges) + for dr in dr_copy: + dr[-1] = split_right + diff_ranges.append(dr) + dims_cleared.append(False) + if all(dims_cleared): + return Range([]) + elif len(diff_ranges) == 1: + return Range(diff_ranges[0]) + else: + subset_list = [] + for dr in diff_ranges: + subset_list.append(Range(dr)) + return SubsetUnion(subset_list) @dace.serialize.serializable @@ -1096,21 +1169,46 @@ class SubsetUnion(Subset): Wrapper subset type that stores multiple Subsets in a list. """ + subset_list: List[Range] + def __init__(self, subset): - self.subset_list: list[Subset] = [] + self.subset_list = [] if isinstance(subset, SubsetUnion): self.subset_list = subset.subset_list elif isinstance(subset, list): for subset in subset: if not subset: break - if isinstance(subset, (Range, Indices)): + if isinstance(subset, Range): self.subset_list.append(subset) + elif isinstance(subset, Indices): + self.subset_list.append(Range.from_indices(subset)) + elif isinstance(subset, SubsetUnion): + self.subset_list.extend(subset.subset_list) else: raise NotImplementedError elif isinstance(subset, (Range, Indices)): self.subset_list = [subset] + def offset(self, other, negative, indices=None, offset_end=True): + for subs in self.subset_list: + subs.offset(other, negative, indices, offset_end) + + def offset_new(self, other, negative, indices=None, offset_end=True): + new_subsets = [] + for subs in self.subset_list: + new_subsets.append(subs.offset_new(other, negative, indices, offset_end)) + return SubsetUnion(new_subsets) + + def unsqueeze(self, axes: Sequence[int]) -> List[List[int]]: + result = [] + for subs in self.subset_list: + if isinstance(subs, Range): + result.append(subs.unsqueeze(axes)) + else: + result.append([]) + return result + def covers(self, other): """ Returns True if this SubsetUnion covers another subset (using a bounding box). @@ -1154,6 +1252,37 @@ def __str__(self): string += " " string += subset.__str__() return string + + def __len__(self): + return len(self.subset_list[0]) + + def __getitem__(self, key): + dim_ranges = [] + for subs in self.subset_list: + dim_ranges.append(Range([subs[key]])) + return SubsetUnion(dim_ranges) + + def to_json(self): + ret = [] + + for sbs in self.subset_list: + ret.append(sbs.to_json()) + + return {'type': 'SubsetUnion', 'subset_list': ret} + + @staticmethod + def from_json(obj, context=None): + if not isinstance(obj, dict): + raise TypeError("Expected dict, got {}".format(type(obj))) + if obj['type'] != 'SubsetUnion': + raise TypeError("from_json of class \"SubsetUnion\" called on json " + "with type %s (expected 'SubsetUnion')" % obj['type']) + + subset_list = [] + for r in obj['subset_list']: + subset_list.append(Range.from_json(r, context)) + + return SubsetUnion(subset_list) def dims(self): if not self.subset_list: @@ -1166,12 +1295,74 @@ def union(self, other: Subset): if isinstance(other, SubsetUnion): self.subset_list += other.subset_list elif isinstance(other, Indices) or isinstance(other, Range): - self.subset_list.append(other) + if not other in self.subset_list: + self.subset_list.append(other) else: raise TypeError except TypeError: # cannot determine truth value of Relational return None + def intersection(self, other: Subset) -> 'SubsetUnion': + try: + if isinstance(other, SubsetUnion): + intersections = [] + for subs in self.subset_list: + for osubs in other.subset_list: + isect = intersection(subs, osubs) + if isect is not None: + intersections.append(isect) + if intersections: + return SubsetUnion(intersections) + elif isinstance(other, (Indices, Range)): + intersections = [] + for subs in self.subset_list: + isect = intersection(subs, other) + if isect is not None: + intersections.append(isect) + if intersections: + return SubsetUnion(intersections) + else: + raise TypeError + except TypeError: + return None + + def intersects(self, other: Subset): + return self.intersection(other) is not None + + def difference(self, other: Subset) -> 'SubsetUnion': + try: + if isinstance(other, SubsetUnion): + differences = [] + for subs in self.subset_list: + sub_diff = subs + for osubs in other.subset_list: + sub_diff = difference(sub_diff, osubs) + if sub_diff is not None: + if isinstance(sub_diff, SubsetUnion): + for s in sub_diff.subset_list: + differences.append(s) + else: + differences.append(sub_diff) + if differences: + return SubsetUnion(differences) + elif isinstance(other, (Indices, Range)): + differences = [] + for subs in self.subset_list: + diff = difference(subs, other) + if diff is not None: + if isinstance(diff, SubsetUnion): + for sub_diff in diff.subset_list: + differences.append(sub_diff) + else: + differences.append(diff) + if differences: + return SubsetUnion(differences) + else: + raise TypeError + except TypeError: + return None + pass + @property def free_symbols(self) -> Set[str]: result = set() @@ -1195,6 +1386,23 @@ def num_elements(self): return min + def to_bounding_box_subset(self) -> Union[Range, None]: + min_elem = [None] * len(self) + max_elem = [None] * len(self) + for subs in self.subset_list: + for i, rng in enumerate(subs): + try: + if min_elem[i] is None or rng[0] < min_elem[i]: + min_elem[i] = rng[0] + if max_elem[i] is None or rng[1] > max_elem[i]: + max_elem[i] = rng[1] + except: + return None + if any([x is None for x in min_elem]) or any([x is None for x in max_elem]): + return None + + new_rngs = [(mini, maxi, 1) for mini, maxi in zip(min_elem, max_elem)] + return Range(new_rngs) def _union_special_cases(arb: symbolic.SymbolicType, brb: symbolic.SymbolicType, are: symbolic.SymbolicType, @@ -1261,17 +1469,13 @@ def bounding_box_union(subset_a: Subset, subset_b: Subset) -> Range: return Range(result) - - def union(subset_a: Subset, subset_b: Subset) -> Subset: """ Compute the union of two Subset objects. - If the subsets are not of the same type, degenerates to bounding-box - union. + If the subsets are not of the same type, degenerates to bounding-box union. :param subset_a: The first subset. :param subset_b: The second subset. - :return: A Subset object whose size is at least the union of the two - inputs. If union failed, returns None. + :return: A Subset object whose size is at least the union of the two inputs. If union failed, returns None. """ try: @@ -1281,8 +1485,7 @@ def union(subset_a: Subset, subset_b: Subset) -> Subset: return subset_b elif subset_a is None and subset_b is None: raise TypeError('Both subsets cannot be None') - elif isinstance(subset_a, SubsetUnion) or isinstance( - subset_b, SubsetUnion): + elif isinstance(subset_a, SubsetUnion) or isinstance(subset_b, SubsetUnion): return list_union(subset_a, subset_b) elif type(subset_a) != type(subset_b): return bounding_box_union(subset_a, subset_b) @@ -1349,6 +1552,46 @@ def intersects(subset_a: Subset, subset_b: Subset) -> Union[bool, None]: subset_b = Range.from_indices(subset_b) if type(subset_a) is type(subset_b): return subset_a.intersects(subset_b) + elif isinstance(subset_a, SubsetUnion): + return subset_a.intersects(subset_b) + elif isinstance(subset_b, SubsetUnion): + return subset_b.intersects(subset_a) return None except TypeError: # cannot determine truth value of Relational return None + +def intersection(subset_a: Subset, subset_b: Subset) -> Optional[Subset]: + try: + if subset_a is None or subset_b is None: + return None + if isinstance(subset_a, Indices): + subset_a = Range.from_indices(subset_a) + if isinstance(subset_b, Indices): + subset_b = Range.from_indices(subset_b) + if type(subset_a) is type(subset_b): + return subset_a.intersection(subset_b) + elif isinstance(subset_a, SubsetUnion): + return subset_a.intersection(subset_b) + elif isinstance(subset_b, SubsetUnion): + return subset_b.intersection(subset_a) + return None + except TypeError: + return None + +def difference(subset_a: Subset, subset_b: Subset) -> Optional[Subset]: + try: + if subset_a is None or subset_b is None: + return None + if isinstance(subset_a, Indices): + subset_a = Range.from_indices(subset_a) + if isinstance(subset_b, Indices): + subset_b = Range.from_indices(subset_b) + if type(subset_a) is type(subset_b): + return subset_a.difference(subset_b) + elif isinstance(subset_a, SubsetUnion): + return subset_a.difference(subset_b) + elif isinstance(subset_b, SubsetUnion): + return subset_b.difference(subset_a) + return None + except TypeError: + return None diff --git a/dace/symbolic.py b/dace/symbolic.py index 98ffa008d3..fdcfdef270 100644 --- a/dace/symbolic.py +++ b/dace/symbolic.py @@ -1175,6 +1175,8 @@ def pystr_to_symbolic(expr, symbol_map=None, simplify=None) -> sympy.Basic: 'var': sympy.Symbol('var'), 'root': sympy.Symbol('root'), 'arg': sympy.Symbol('arg'), + 'id': sympy.Symbol('id'), + 'diag': sympy.Symbol('diag'), 'Is': Is, 'IsNot': IsNot, 'BitwiseAnd': bitwise_and, diff --git a/dace/transformation/dataflow/redundant_array.py b/dace/transformation/dataflow/redundant_array.py index 5e5072ff32..c69d692922 100644 --- a/dace/transformation/dataflow/redundant_array.py +++ b/dace/transformation/dataflow/redundant_array.py @@ -4,7 +4,7 @@ import copy import warnings -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple import networkx as nx from networkx.exception import NetworkXError, NodeNotFound @@ -12,7 +12,6 @@ from dace import data, dtypes from dace import memlet as mm from dace import subsets, symbolic -from dace.config import Config from dace.sdfg import SDFG, SDFGState, graph, nodes from dace.sdfg import utils as sdutil from dace.transformation import helpers diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index b703dd402d..a2655341ee 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -6,7 +6,7 @@ from dace.properties import CodeBlock from dace.sdfg.state import AbstractControlFlowRegion, ConditionalBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion, ReturnBlock -from dace.subsets import Range, Subset, union +from dace.subsets import Range, Subset, SubsetUnion, union import dace.subsets as subsets from typing import Dict, Iterable, List, Optional, Tuple, Set, Union @@ -784,8 +784,13 @@ def unsqueeze_memlet(internal_memlet: Memlet, 'External memlet: %s\nInternal memlet: %s' % (external_memlet, internal_memlet)) original_minima = external_memlet.subset.min_element() for i in set(range(len(original_minima))): - rb, re, rs = result.subset.ranges[i] - result.subset.ranges[i] = (original_minima[i], re, rs) + if isinstance(result.subset, SubsetUnion): + for subs in result.subset.subset_list: + rb, re, rs = subs.ranges[i] + subs.ranges[i] = (original_minima[i], re, rs) + else: + rb, re, rs = result.subset.ranges[i] + result.subset.ranges[i] = (original_minima[i], re, rs) # TODO: Offset rest of memlet according to other_subset if external_memlet.other_subset is not None: raise NotImplementedError diff --git a/dace/transformation/interstate/sdfg_nesting.py b/dace/transformation/interstate/sdfg_nesting.py index 31e751bb6a..a6ff90c491 100644 --- a/dace/transformation/interstate/sdfg_nesting.py +++ b/dace/transformation/interstate/sdfg_nesting.py @@ -427,11 +427,26 @@ def apply(self, state: SDFGState, sdfg: SDFG): orig_data: Dict[Union[nodes.AccessNode, MultiConnectorEdge], str] = {} for node in nstate.nodes(): - if isinstance(node, nodes.AccessNode) and node.data in repldict: - orig_data[node] = node.data - node.data = repldict[node.data] + if isinstance(node, nodes.AccessNode): + if '.' in node.data: + parts = node.data.split('.') + root_container = parts[0] + if root_container in repldict: + orig_data[node] = node.data + full_data = [repldict[root_container]] + parts[1:] + node.data = '.'.join(full_data) + elif node.data in repldict: + orig_data[node] = node.data + node.data = repldict[node.data] for edge in nstate.edges(): - if edge.data.data in repldict: + if '.' in edge.data.data: + parts = edge.data.data.split('.') + root_container = parts[0] + if root_container in repldict: + orig_data[edge] = edge.data.data + full_data = [repldict[root_container]] + parts[1:] + edge.data.data = '.'.join(full_data) + elif edge.data.data in repldict: orig_data[edge] = edge.data.data edge.data.data = repldict[edge.data.data] diff --git a/dace/transformation/pass_pipeline.py b/dace/transformation/pass_pipeline.py index bca7626b85..51691c329c 100644 --- a/dace/transformation/pass_pipeline.py +++ b/dace/transformation/pass_pipeline.py @@ -227,6 +227,9 @@ class StatePass(Pass): CATEGORY: str = 'Helper' + top_down = properties.Property(dtype=bool, default=False, + desc='Whether or not to apply top down (i.e., parents before children)') + def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Dict[SDFGState, Optional[Any]]]: """ Applies the pass to states of the given SDFG by calling ``apply`` on each state. @@ -239,11 +242,14 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[D if nothing was returned. """ result = {} - for sd in sdfg.all_sdfgs_recursive(): - for state in sd.nodes(): - retval = self.apply(state, pipeline_results) - if retval is not None: - result[state] = retval + for cfr in sdfg.all_control_flow_regions(recursive=True, parent_first=self.top_down): + if isinstance(cfr, ConditionalBlock): + continue + for state in cfr.nodes(): + if isinstance(state, SDFGState): + retval = self.apply(state, pipeline_results) + if retval is not None: + result[state] = retval if not result: return None diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index 94c24399ee..9542b3cf31 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -2,21 +2,25 @@ from collections import defaultdict, deque from dataclasses import dataclass +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +import networkx as nx import sympy +from networkx.algorithms import shortest_paths as nxsp -from dace.sdfg.state import AbstractControlFlowRegion, ConditionalBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion -from dace.subsets import Range -from dace.transformation import pass_pipeline as ppl, transformation -from dace import SDFG, SDFGState, properties, InterstateEdge, Memlet, data as dt, symbolic -from dace.sdfg.graph import Edge -from dace.sdfg import nodes as nd, utils as sdutil +from dace import SDFG, InterstateEdge, Memlet, SDFGState +from dace import data as dt +from dace import properties, symbolic +from dace.sdfg import nodes as nd +from dace.sdfg import utils as sdutil from dace.sdfg.analysis import cfg as cfg_analysis +from dace.sdfg.graph import Edge from dace.sdfg.propagation import align_memlet -from typing import Dict, Iterable, List, Set, Tuple, Any, Optional, Union -import networkx as nx -from networkx.algorithms import shortest_paths as nxsp - +from dace.sdfg.state import (AbstractControlFlowRegion, ConditionalBlock, + ControlFlowBlock, ControlFlowRegion, LoopRegion) +from dace.subsets import Range +from dace.transformation import pass_pipeline as ppl +from dace.transformation import transformation from dace.transformation.passes.analysis import loop_analysis WriteScopeDict = Dict[str, Dict[Optional[Tuple[SDFGState, nd.AccessNode]], @@ -111,6 +115,8 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[ :return: For each control flow region, a dictionary mapping each control flow block to its other reachable control flow blocks. """ + top_sdfg.reset_cfg_list() + single_level_reachable: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = defaultdict( lambda: defaultdict(set) ) diff --git a/dace/transformation/passes/analysis/loop_analysis.py b/dace/transformation/passes/analysis/loop_analysis.py index 69a77422e8..28bcb112f5 100644 --- a/dace/transformation/passes/analysis/loop_analysis.py +++ b/dace/transformation/passes/analysis/loop_analysis.py @@ -3,13 +3,16 @@ Various analyses concerning LopoRegions, and utility functions to get information about LoopRegions for other passes. """ -from typing import Dict, Optional +from typing import Dict, Optional, Union from dace.frontend.python import astutils import sympy from dace import symbolic +from dace.frontend.python import astutils +from dace.memlet import Memlet from dace.sdfg.state import LoopRegion +from dace.subsets import Range, SubsetUnion, intersects def get_loop_end(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: @@ -97,3 +100,59 @@ def get_loop_stride(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: if update_assignment: return update_assignment - symbolic.pystr_to_symbolic(loop.loop_variable) return None + + +def _loop_read_intersects_loop_write(loop: LoopRegion, write_subset: Union[SubsetUnion, Range], + read_subset: Union[SubsetUnion, Range], update: sympy.Basic) -> bool: + """ + Check if a write subset intersects a read subset after being offset by the loop stride. The offset is performed + based on the symbolic loop update assignment expression. + """ + offset = update - symbolic.symbol(loop.loop_variable) + offset_list = [] + for i in range(write_subset.dims()): + if loop.loop_variable in write_subset.get_free_symbols_by_indices([i]): + offset_list.append(offset) + else: + offset_list.append(0) + offset_write = write_subset.offset_new(offset_list, True) + return intersects(offset_write, read_subset) + +def get_loop_carry_dependencies(loop: LoopRegion) -> Optional[Dict[Memlet, Memlet]]: + """ + Compute loop carry dependencies. + :return: A dictionary mapping loop reads to writes in the same loop, from which they may carry a RAW dependency. + None if the loop cannot be analyzed. + """ + update_assignment = None + raw_deps: Dict[Memlet, Memlet] = dict() + for data in loop.possible_reads: + if not data in loop.possible_writes: + continue + + input = loop.possible_reads[data] + read_subset = input.src_subset or input.subset + if loop.loop_variable and loop.loop_variable in input.free_symbols: + # If the iteration variable is involved in an access, we need to first offset it by the loop + # stride and then check for an overlap/intersection. If one is found after offsetting, there + # is a RAW loop carry dependency. + output = loop.possible_writes[data] + # Get and cache the update assignment for the loop. + if update_assignment is None: + update_assignment = get_update_assignment(loop) + if update_assignment is None: + return None + + if isinstance(output.subset, SubsetUnion): + if any([_loop_read_intersects_loop_write(loop, s, read_subset, update_assignment) + for s in output.subset.subset_list]): + raw_deps[input] = output + elif _loop_read_intersects_loop_write(loop, output.subset, read_subset, update_assignment): + raw_deps[input] = output + else: + # Check for basic overlaps/intersections in RAW loop carry dependencies, when there is no + # iteration variable involved. + output = loop.possible_writes[data] + if intersects(output.subset, read_subset): + raw_deps[input] = output + return raw_deps diff --git a/dace/transformation/passes/analysis/propagation.py b/dace/transformation/passes/analysis/propagation.py new file mode 100644 index 0000000000..4998e9c4da --- /dev/null +++ b/dace/transformation/passes/analysis/propagation.py @@ -0,0 +1,646 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +from collections import OrderedDict, defaultdict, deque +import copy +from typing import Dict, Iterable, List, Set, Tuple, Union + +import networkx as nx +import sympy + +from dace import data as dt +from dace import properties, symbolic, subsets +from dace.memlet import Memlet +from dace.sdfg import nodes +from dace.sdfg.analysis import cfg as cfg_analysis +from dace.sdfg.propagation import align_memlet, propagate_memlet, propagate_subset +from dace.sdfg.scope import ScopeTree +from dace.sdfg.sdfg import SDFG, memlets_in_ast +from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion, SDFGState +from dace.transformation import pass_pipeline as ppl +from dace.transformation import transformation +from dace.transformation.helpers import unsqueeze_memlet +from dace.transformation.passes.analysis import loop_analysis +from dace.transformation.passes.analysis.analysis import ControlFlowBlockReachability + + +@transformation.explicit_cf_compatible +class StatePropagation(ppl.ControlFlowRegionPass): + """ + Analyze a control flow region to determine the number of times each block inside of it is executed in the form of a + symbolic expression, or a concrete number where possible. + Each control flow block is marked with a symbolic expression for the number of executions, and a boolean flag to + indicate whether the number of executions is dynamic or not. A combination of dynamic being set to true and the + number of executions being 0 indicates that the number of executions is dynamically unbounded. + Additionally, the pass annotates each block with a `ranges` property, which indicates for loop variables defined + at that block what range of values the variable may take on. + Note: This path directly annotates the graph. + This pass supersedes ``dace.sdfg.propagation.propagate_states`` and is based on its algorithm, with significant + simplifications thanks to the use of control flow regions. + """ + + CATEGORY: str = 'Analysis' + + def __init__(self): + super().__init__() + self.top_down = True + self.apply_to_conditionals = True + + def depends_on(self): + return {ControlFlowBlockReachability} + + def _propagate_in_cfg(self, cfg: ControlFlowRegion, reachable: Dict[ControlFlowBlock, Set[ControlFlowBlock]], + starting_executions: int, starting_dynamic_executions: bool): + visited_blocks: Set[ControlFlowBlock] = set() + traversal_q: deque[Tuple[ControlFlowBlock, int, bool, List[str]]] = deque() + traversal_q.append((cfg.start_block, starting_executions, starting_dynamic_executions, [])) + while traversal_q: + (block, proposed_executions, proposed_dynamic, itvar_stack) = traversal_q.pop() + out_edges = cfg.out_edges(block) + if block in visited_blocks: + # This block has already been visited, meaning there are multiple paths towards this block. + if proposed_executions == 0 and proposed_dynamic: + block.executions = 0 + block.dynamic_executions = True + else: + block.executions = sympy.Max(block.executions, proposed_executions).doit() + block.dynamic_executions = (block.dynamic_executions or proposed_dynamic) + elif proposed_dynamic and proposed_executions == 0: + # We're propagating a dynamic unbounded number of executions, which always gets propagated + # unconditionally. Propagate to all children. + visited_blocks.add(block) + block.executions = proposed_executions + block.dynamic_executions = proposed_dynamic + # This gets pushed through to all children unconditionally. + if len(out_edges) > 0: + for oedge in out_edges: + traversal_q.append((oedge.dst, proposed_executions, proposed_dynamic, itvar_stack)) + else: + # If the state hasn't been visited yet and we're not propagating a dynamic unbounded number of + # executions, we calculate the number of executions for the next state(s) and continue propagating. + visited_blocks.add(block) + block.executions = proposed_executions + block.dynamic_executions = proposed_dynamic + if len(out_edges) == 1: + # Continue with the only child state. + if not out_edges[0].data.is_unconditional(): + # If the transition to the child state is based on a condition, this state could be an implicit + # exit state. The child state's number of executions is thus only given as an upper bound and + # marked as dynamic. + proposed_dynamic = True + traversal_q.append((out_edges[0].dst, proposed_executions, proposed_dynamic, itvar_stack)) + elif len(out_edges) > 1: + # Conditional split + for oedge in out_edges: + traversal_q.append((oedge.dst, block.executions, True, itvar_stack)) + + # Check if the CFG contains any cycles. Any cycles left in the graph (after control flow raising) are + # irreducible control flow and thus lead to a dynamically unbounded number of executions. Mark any block + # inside and reachable from any block inside the cycle as dynamically unbounded, irrespectively of what it was + # marked as before. + cycles: Iterable[Iterable[ControlFlowBlock]] = cfg.find_cycles() + for cycle in cycles: + for blk in cycle: + blk.executions = 0 + blk.dynamic_executions = True + for reached in reachable[blk]: + reached.executions = 0 + blk.dynamic_executions = True + + def apply(self, region, pipeline_results) -> None: + if isinstance(region, ConditionalBlock): + # In a conditional block, each branch is executed up to as many times as the conditional block itself is. + # TODO(later): We may be able to derive ranges here based on the branch conditions too. + for _, b in region.branches: + b.executions = region.executions + b.dynamic_executions = True + b.ranges = region.ranges + else: + if isinstance(region, SDFG): + # The root SDFG is executed exactly once, any other, nested SDFG is executed as many times as the parent + # state is. + if region is region.root_sdfg: + region.executions = 1 + region.dynamic_executions = False + elif region.parent: + region.executions = region.parent.executions + region.dynamic_executions = region.parent.dynamic_executions + + # Clear existing annotations. + for blk in region.nodes(): + blk.executions = 0 + blk.dynamic_executions = True + blk.ranges = region.ranges + + # Determine the number of executions for the start block within this region. In the case of loops, this + # is dependent on the number of loop iterations - where they can be determined. Where they may not be + # determined, the number of iterations is assumed to be dynamically unbounded. For any other control flow + # region, the start block is executed as many times as the region itself is. + starting_execs = region.executions + starting_dynamic = region.dynamic_executions + if isinstance(region, LoopRegion): + # If inside a loop, add range information if possible. + start = loop_analysis.get_init_assignment(region) + stop = loop_analysis.get_loop_end(region) + stride = loop_analysis.get_loop_stride(region) + if start is not None and stop is not None and stride is not None and region.loop_variable: + # This inequality needs to be checked exactly like this due to constraints in sympy/symbolic + # expressions, do not simplify! + if (stride < 0) == True: + rng = (stop, start, -stride) + else: + rng = (start, stop, stride) + for blk in region.nodes(): + blk.ranges[str(region.loop_variable)] = subsets.Range([rng]) + + # Get surrounding iteration variables for the case of nested loops. + itvar_stack = [] + par = region.parent_graph + while par is not None and not isinstance(par, SDFG): + if isinstance(par, LoopRegion) and par.loop_variable: + itvar_stack.append(par.loop_variable) + par = par.parent_graph + + # Calculate the number of loop executions. + # This resolves ranges based on the order of iteration variables from surrounding loops. + loop_executions = sympy.ceiling(((stop + 1) - start) / stride) + for outer_itvar_string in itvar_stack: + outer_range = region.ranges[outer_itvar_string] + outer_start = outer_range[0][0] + outer_stop = outer_range[0][1] + outer_stride = outer_range[0][2] + outer_itvar = symbolic.pystr_to_symbolic(outer_itvar_string) + exec_repl = loop_executions.subs({outer_itvar: (outer_itvar * outer_stride + outer_start)}) + sum_rng = (outer_itvar, 0, sympy.ceiling((outer_stop - outer_start) / outer_stride)) + loop_executions = sympy.Sum(exec_repl, sum_rng) + starting_execs = loop_executions.doit() + starting_dynamic = region.dynamic_executions + else: + starting_execs = 0 + starting_dynamic = True + + # Propagate the number of executions. + self._propagate_in_cfg(region, pipeline_results[ControlFlowBlockReachability.__name__][region.cfg_id], + starting_execs, starting_dynamic) + +@properties.make_properties +@transformation.explicit_cf_compatible +class MemletPropagation(ppl.ControlFlowRegionPass): + """ + TODO + """ + + CATEGORY: str = 'Analysis' + + def __init__(self) -> None: + super().__init__() + self.top_down = False + self.apply_to_conditionals = True + + def modifies(self): + return ppl.Modifies.Memlets + + def should_reapply(self, modified): + return modified & (ppl.Modifies.Nodes | ppl.Modifies.Memlets) + + def _propagate_node(self, state: SDFGState, node: Union[nodes.EntryNode, nodes.ExitNode]): + if isinstance(node, nodes.EntryNode): + internal_edges = [e for e in state.out_edges(node) if e.src_conn and e.src_conn.startswith('OUT_')] + external_edges = [e for e in state.in_edges(node) if e.dst_conn and e.dst_conn.startswith('IN_')] + geticonn = lambda e: e.src_conn[4:] + geteconn = lambda e: e.dst_conn[3:] + use_dst = False + else: + internal_edges = [e for e in state.in_edges(node) if e.dst_conn and e.dst_conn.startswith('IN_')] + external_edges = [e for e in state.out_edges(node) if e.src_conn and e.src_conn.startswith('OUT_')] + geticonn = lambda e: e.dst_conn[3:] + geteconn = lambda e: e.src_conn[4:] + use_dst = True + + for edge in external_edges: + if edge.data.is_empty(): + new_memlet = Memlet() + else: + internal_edge = next(e for e in internal_edges if geticonn(e) == geteconn(edge)) + aligned_memlet = align_memlet(state, internal_edge, dst=use_dst) + new_memlet = propagate_memlet(state, aligned_memlet, node, True, connector=geteconn(edge)) + edge.data = new_memlet + + def _propagate_scope(self, state: SDFGState, scopes: List[ScopeTree], propagate_entry: bool = True, + propagate_exit: bool = True) -> None: + scopes_to_process = scopes + next_scopes = set() + + # Process scopes from the inputs upwards, propagating edges at the + # entry and exit nodes + while len(scopes_to_process) > 0: + for scope in scopes_to_process: + if scope.entry is None: + continue + + # Propagate out of entry + if propagate_entry: + self._propagate_node(state, scope.entry) + + # Propagate out of exit + if propagate_exit: + self._propagate_node(state, scope.exit) + + # Add parent to next frontier + next_scopes.add(scope.parent) + scopes_to_process = next_scopes + next_scopes = set() + + def _propagate_nsdfg(self, parent_sdfg: SDFG, parent_state: SDFGState, nsdfg_node: nodes.NestedSDFG): + outer_symbols = parent_state.symbols_defined_at(nsdfg_node) + sdfg = nsdfg_node.sdfg + + possible_reads = copy.deepcopy(sdfg.possible_reads) + possible_writes = copy.deepcopy(sdfg.possible_writes) + certain_reads = copy.deepcopy(sdfg.certain_reads) + certain_writes = copy.deepcopy(sdfg.certain_writes) + + # Make sure any potential NSDFG symbol mapping is correctly reversed when propagating out. + for mapping in [possible_reads, possible_writes, certain_reads, certain_writes]: + for border_memlet in mapping.values(): + border_memlet.replace(nsdfg_node.symbol_mapping) + + # Also make sure that there's no symbol in the border memlet's range that only exists inside the + # nested SDFG. If that's the case, use the entire range. + if border_memlet.src_subset is not None: + if any(str(s) not in outer_symbols.keys() for s in border_memlet.src_subset.free_symbols): + border_memlet.src_subset = subsets.Range.from_array(sdfg.arrays[border_memlet.data]) + if border_memlet.dst_subset is not None: + if any(str(s) not in outer_symbols.keys() for s in border_memlet.dst_subset.free_symbols): + border_memlet.dst_subset = subsets.Range.from_array(sdfg.arrays[border_memlet.data]) + + # Propagate the inside 'border' memlets outside the SDFG by offsetting, and unsqueezing if necessary. + for iedge in parent_state.in_edges(nsdfg_node): + if iedge.dst_conn in possible_reads: + try: + inner_memlet = possible_reads[iedge.dst_conn] + iedge.data = unsqueeze_memlet(inner_memlet, iedge.data, True) + if isinstance(iedge.data.subset, subsets.SubsetUnion): + iedge.data.subset = iedge.data.subset.to_bounding_box_subset() + # If no appropriate memlet found, use array dimension + for i, (rng, s) in enumerate(zip(iedge.data.subset, parent_sdfg.arrays[iedge.data.data].shape)): + if rng[1] + 1 == s: + iedge.data.subset[i] = (iedge.data.subset[i][0], s - 1, 1) + if symbolic.issymbolic(iedge.data.volume): + if any(str(s) not in outer_symbols for s in iedge.data.volume.free_symbols): + iedge.data.volume = 0 + iedge.data.dynamic = True + except (ValueError, NotImplementedError): + # In any case of memlets that cannot be unsqueezed (i.e., reshapes), use dynamic unbounded memlets. + iedge.data.volume = 0 + iedge.data.dynamic = True + for oedge in parent_state.out_edges(nsdfg_node): + if oedge.src_conn in possible_writes: + try: + inner_memlet = possible_writes[oedge.src_conn] + oedge.data = unsqueeze_memlet(inner_memlet, oedge.data, True) + if isinstance(oedge.data.subset, subsets.SubsetUnion): + oedge.data.subset = oedge.data.subset.to_bounding_box_subset() + # If no appropriate memlet found, use array dimension + for i, (rng, s) in enumerate(zip(oedge.data.subset, parent_sdfg.arrays[oedge.data.data].shape)): + if rng[1] + 1 == s: + oedge.data.subset[i] = (oedge.data.subset[i][0], s - 1, 1) + if symbolic.issymbolic(oedge.data.volume): + if any(str(s) not in outer_symbols for s in oedge.data.volume.free_symbols): + oedge.data.volume = 0 + oedge.data.dynamic = True + except (ValueError, NotImplementedError): + # In any case of memlets that cannot be unsqueezed (i.e., reshapes), use dynamic unbounded memlets. + oedge.data.volume = 0 + oedge.data.dynamic = True + + def _propagate_state(self, state: SDFGState) -> None: + # Ensure memlets around nested SDFGs are propagated correctly. + for nd in state.nodes(): + if isinstance(nd, nodes.NestedSDFG): + self._propagate_nsdfg(state.sdfg, state, nd) + + # Propagate memlets through the scopes, bottom up, starting at the scope leaves. + # TODO: Make sure this propagation happens without overapproximation, i.e., using SubsetUnions. + self._propagate_scope(state, state.scope_leaves()) + + # Gather all writes and reads inside this state now to determine the state-wide reads and writes. + # Collect write memlets. + writes: Dict[str, List[Tuple[Memlet, nodes.AccessNode]]] = defaultdict(lambda: []) + for anode in state.data_nodes(): + is_view = isinstance(state.sdfg.data(anode.data), dt.View) + for iedge in state.in_edges(anode): + if not iedge.data.is_empty() and not (is_view and iedge.dst_conn == 'views'): + root_edge = state.memlet_tree(iedge).root().edge + writes[anode.data].append((root_edge.data, anode)) + + # Go over (overapproximated) reads and check if they are covered by writes. + not_covered_reads: Dict[str, Set[Memlet]] = defaultdict(set) + for anode in state.data_nodes(): + for oedge in state.out_edges(anode): + if not oedge.data.is_empty() and not (isinstance(oedge.dst, nodes.AccessNode) and + oedge.dst_conn == 'views'): + if oedge.data.data != anode.data: + # Special case for memlets copying data out of the scope, which are by default aligned with the + # outside data container. In this case, the source container must either be a scalar, or the + # read subset is contained in the memlet's `other_subset` property. + # See `dace.sdfg.propagation.align_memlet` for more. + desc = state.sdfg.data(oedge.data.data) + if oedge.data.other_subset is not None: + read_subset = oedge.data.other_subset + elif oedge.data.dst_subset is not None: + read_subset = oedge.data.dst_subset + elif isinstance(desc, dt.Scalar) or (isinstance(desc, dt.Array) and desc.total_size == 1): + read_subset = subsets.Range([(0, 0, 1)] * len(desc.shape)) + else: + raise RuntimeError('Invalid memlet range detected in MemletPropagation') + else: + read_subset = oedge.data.src_subset or oedge.data.subset + covered = False + for [write, to] in writes[anode.data]: + if write.subset.covers_precise(read_subset) and nx.has_path(state.nx, to, anode): + covered = True + break + if not covered: + not_covered_reads[anode.data].add(oedge.data) + + state.certain_writes = {} + state.possible_writes = {} + for data in writes: + if len(writes[data]) > 0: + subset = None + volume = None + is_dynamic = False + for memlet, _ in writes[data]: + is_dynamic |= memlet.dynamic + if subset is None: + subset = subsets.SubsetUnion(memlet.dst_subset or memlet.subset) + else: + subset.union(memlet.dst_subset or memlet.subset) + if memlet.volume == 0: + volume = 0 + else: + if volume is None: + volume = memlet.volume + elif volume != 0: + volume += memlet.volume + new_memlet = Memlet(data=data, subset=subset) + new_memlet.dynamic = is_dynamic + new_memlet.volume = volume if volume is not None else 0 + state.certain_writes[data] = new_memlet + state.possible_writes[data] = new_memlet + + state.certain_reads = {} + state.possible_reads = {} + for data in not_covered_reads: + subset = None + volume = None + is_dynamic = False + for memlet in not_covered_reads[data]: + is_dynamic |= memlet.dynamic + if subset is None: + subset = subsets.SubsetUnion(memlet.dst_subset or memlet.subset) + else: + subset.union(memlet.dst_subset or memlet.subset) + if memlet.volume == 0: + volume = 0 + else: + if volume is None: + volume = memlet.volume + elif volume != 0: + volume += memlet.volume + new_memlet = Memlet(data=data, subset=subset) + new_memlet.dynamic = is_dynamic + new_memlet.volume = volume + state.certain_reads[data] = new_memlet + state.possible_reads[data] = new_memlet + + def _propagate_conditional(self, conditional: ConditionalBlock) -> None: + # The union of all reads between all conditions and branches gives the set of _possible_ reads, while the + # intersection gives the set of _guaranteed_ reads. The first condition can also be counted as a guaranteed + # read. The same applies for writes, except that conditions do not contain writes. + + def add_memlet(memlet: Memlet, mlt_dict: Dict[str, Memlet], use_intersection: bool = False): + if memlet.data not in mlt_dict: + propagated_memlet = Memlet(data=memlet.data, subset=(memlet.src_subset or memlet.subset)) + propagated_memlet.volume = memlet.volume + propagated_memlet.dynamic = memlet.dynamic + mlt_dict[memlet.data] = propagated_memlet + else: + propagated_memlet = mlt_dict[memlet.data] + if use_intersection: + isect = propagated_memlet.subset.intersection(memlet.src_subset or memlet.subset) + if isect is not None: + propagated_memlet.subset = isect + else: + propagated_memlet.subset = subsets.SubsetUnion([]) + + propagated_memlet.volume = sympy.Min(memlet.volume, propagated_memlet.volume) + propagated_memlet.dynamic |= memlet.dynamic + else: + mlt_subset = propagated_memlet.subset + if not isinstance(mlt_subset, subsets.SubsetUnion): + mlt_subset = subsets.SubsetUnion([mlt_subset]) + mlt_subset.union(memlet.src_subset or memlet.subset) + propagated_memlet.subset = mlt_subset + + if propagated_memlet.volume != 0: + if memlet.volume == 0: + propagated_memlet.volume = memlet.volume + else: + propagated_memlet.volume = sympy.Max(memlet.volume, propagated_memlet.volume) + propagated_memlet.dynamic |= memlet.dynamic + + conditional.possible_reads = {} + conditional.certain_reads = {} + conditional.possible_writes = {} + conditional.certain_writes = {} + + # Gather the union of possible reads and writes. At the same time, determine if there is an else branch present. + has_else = False + for cond, branch in conditional.branches: + if cond is not None: + read_memlets = memlets_in_ast(cond.code[0], conditional.sdfg.arrays) + for read_memlet in read_memlets: + add_memlet(read_memlet, conditional.possible_reads) + else: + has_else = True + for read_data in branch.possible_reads: + read_memlet = branch.possible_reads[read_data] + add_memlet(read_memlet, conditional.possible_reads) + for write_data in branch.possible_writes: + write_memlet = branch.possible_writes[write_data] + add_memlet(write_memlet, conditional.possible_writes) + + # If there is no else branch or only one branch exists, there are no certain reads or writes. + if len(conditional.branches) > 1 and has_else: + # Gather the certain reads (= Intersection of certain reads for each branch) + for container in conditional.possible_reads.keys(): + candidate_memlets = [] + skip = False + for cond, branch in conditional.branches: + found = False + if cond is not None: + read_memlets = memlets_in_ast(cond.code[0], conditional.sdfg.arrays) + for read_memlet in read_memlets: + if read_memlet.data == container: + found = True + candidate_memlets.append(read_memlet) + if container in branch.certain_reads: + found = True + candidate_memlets.append(branch.certain_reads[container]) + if not found: + skip = True + break + if skip: + continue + for cand_memlet in candidate_memlets: + add_memlet(cand_memlet, conditional.certain_reads, use_intersection=True) + # Gather the certain writes (= Intersection of certain writes for each branch) + for container in conditional.possible_writes.keys(): + candidate_memlets = [] + skip = False + for _, branch in conditional.branches: + if container in branch.certain_writes: + candidate_memlets.append(branch.certain_writes[container]) + else: + skip = True + break + if skip: + continue + for cand_memlet in candidate_memlets: + add_memlet(cand_memlet, conditional.certain_writes, use_intersection=True) + + # Ensure the first condition's reads are part of the certain reads. + first_cond = conditional.branches[0][0] + if first_cond is not None: + read_memlets = memlets_in_ast(first_cond.code[0], conditional.sdfg.arrays) + for read_memlet in read_memlets: + add_memlet(read_memlet, conditional.certain_reads) + + def _propagate_loop(self, loop: LoopRegion) -> None: + # First propagate the contents of the loop for one iteration. + self._propagate_cfg(loop) + + # Propagate memlets from inside the loop through the loop ranges. + # Collect loop information and form the loop variable range first. + itvar = loop.loop_variable + start = loop_analysis.get_init_assignment(loop) + end = loop_analysis.get_loop_end(loop) + stride = loop_analysis.get_loop_stride(loop) + if itvar and start is not None and end is not None and stride is not None: + loop_range = subsets.Range([(start, end, stride)]) + deps = loop_analysis.get_loop_carry_dependencies(loop) + else: + loop_range = None + deps = {} + + # Collect all symbols and variables (i.e., scalar data containers) defined at this point, particularly by + # looking at defined loop variables in the parent chain up the control flow tree. + variables_at_loop = OrderedDict(loop.sdfg.symbols) + for k, v in loop.sdfg.arrays.items(): + if isinstance(v, dt.Scalar): + variables_at_loop[k] = v + pivot = loop + while pivot is not None: + if isinstance(pivot, LoopRegion): + new_symbols = pivot.new_symbols(loop.sdfg.symbols) + variables_at_loop.update(new_symbols) + pivot = pivot.parent_graph + defined_variables = [symbolic.pystr_to_symbolic(s) for s in variables_at_loop.keys()] + # Propagate memlet subsets through the loop variable and its range. + # TODO: Remove loop-carried dependencies from the writes (i.e., only the first read would be a true read) + for memlet_repo in [loop.certain_reads, loop.possible_reads]: + for dat in memlet_repo.keys(): + read = memlet_repo[dat] + arr = loop.sdfg.data(dat) + if read in deps: + dep_write = deps[read] + #diff = subsets.difference(read.subset, dep_write.subset) + #if isinstance(diff, subsets.SubsetUnion): + # diff = diff.to_bounding_box_subset() + #tgt_expr = symbolic.pystr_to_symbolic(itvar) - loop_range.ranges[0][-1] + #for i in range(diff.dims()): + # dim = diff.dim_to_string(i) + # ... + ## Check if the remaining read subset is only in the direction opposing the loop iteration. If + #diff.__getitem__(0) + new_read = propagate_subset([read], arr, [itvar], loop_range, defined_variables, use_dst=False) + memlet_repo[dat] = new_read + else: + new_read = propagate_subset([read], arr, [itvar], loop_range, defined_variables, use_dst=False) + memlet_repo[dat] = new_read + for memlet_repo in [loop.certain_writes, loop.possible_writes]: + for dat in memlet_repo.keys(): + write = memlet_repo[dat] + arr = loop.sdfg.data(dat) + new_write = propagate_subset([write], arr, [itvar], loop_range, defined_variables, use_dst=True) + memlet_repo[dat] = new_write + + def _propagate_cfg(self, cfg: ControlFlowRegion) -> None: + cfg.possible_reads = {} + cfg.possible_writes = {} + cfg.certain_reads = {} + cfg.certain_writes = {} + + alldoms = cfg_analysis.all_dominators(cfg) + + # For each node in the CFG, check what reads are covered by exactly covering writes in dominating nodes. If such + # a dominating write is found, the read is contained to read data originating from within the same CFG, and thus + # is not counted as an input to the CFG. + for nd in cfg.nodes(): + # For each node, also determine possible reads from interstate edges. For this, any read from any outgoing + # interstate edge is counted as a possible read. The only time it is NOT counted as a read, is when there + # is a certain write in the block itself tha covers the read. + for oedge in cfg.out_edges(nd): + for read_memlet in oedge.data.get_read_memlets(cfg.sdfg.arrays): + covered = False + if (read_memlet.data not in nd.certain_writes or + not nd.certain_writes[read_memlet.data].subset.covers_precise(read_memlet.subset)): + if read_memlet.data in nd.possible_reads: + existing_memlet = nd.possible_reads[read_memlet.data] + if isinstance(existing_memlet.subset, subsets.SubsetUnion): + existing_memlet.subset.union(read_memlet.subset) + else: + subset = subsets.SubsetUnion(read_memlet.subset) + subset.union(existing_memlet.subset) + existing_memlet.subset = subset + else: + nd.possible_reads[read_memlet.data] = read_memlet + for read_data in nd.possible_reads: + read_memlet = nd.possible_reads[read_data] + covered = False + for dom in alldoms[nd]: + if read_data in dom.certain_writes: + write_memlet = dom.certain_writes[read_data] + if write_memlet.subset.covers_precise(read_memlet.subset): + covered = True + break + if not covered: + cfg.possible_reads[read_data] = copy.deepcopy(read_memlet) + cfg.certain_reads[read_data] = cfg.possible_reads[read_data] + for nd in cfg.nodes(): + for cont in nd.possible_writes: + if cont in cfg.possible_writes: + union = subsets.SubsetUnion(cfg.possible_writes[cont].subset) + union.union(nd.possible_writes[cont].subset) + cfg.possible_writes[cont] = Memlet(data=cont, subset=union) + else: + cfg.possible_writes[cont] = copy.deepcopy(nd.possible_writes[cont]) + for cont in nd.certain_writes: + if cont in cfg.certain_writes: + union = subsets.SubsetUnion(cfg.certain_writes[cont].subset) + union.union(nd.certain_writes[cont].subset) + cfg.certain_writes[cont] = Memlet(data=cont, subset=union) + else: + cfg.certain_writes[cont] = copy.deepcopy(nd.certain_writes[cont]) + + def apply(self, region: ControlFlowRegion, _) -> None: + for nd in region.nodes(): + if isinstance(nd, SDFGState): + self._propagate_state(nd) + if isinstance(region, ConditionalBlock): + self._propagate_conditional(region) + elif isinstance(region, LoopRegion): + self._propagate_loop(region) + else: + self._propagate_cfg(region) diff --git a/dace/transformation/passes/lift_struct_views.py b/dace/transformation/passes/lift_struct_views.py index 6744161000..cd1d454d8a 100644 --- a/dace/transformation/passes/lift_struct_views.py +++ b/dace/transformation/passes/lift_struct_views.py @@ -16,6 +16,8 @@ import sys + +from dace.transformation.transformation import explicit_cf_compatible if sys.version_info >= (3, 8): from typing import Literal dirtype = Literal['in', 'out'] @@ -348,6 +350,7 @@ def _data_containers_in_ast(node: ast.AST, arrnames: Set[str]) -> Set[str]: result.add(data) return result +@explicit_cf_compatible class LiftStructViews(ppl.Pass): """ Lift direct accesses to struct members to accesses to views pointing to that struct member. diff --git a/tests/passes/analysis/propagation_test.py b/tests/passes/analysis/propagation_test.py new file mode 100644 index 0000000000..d94f78dd5b --- /dev/null +++ b/tests/passes/analysis/propagation_test.py @@ -0,0 +1,181 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +import dace +from dace.transformation.passes.analysis.propagation import MemletPropagation + + +def test_nested_conditional_in_map(): + N = dace.symbol('N') + M = dace.symbol('M') + + @dace.program + def nested_conditional_in_map(A: dace.int32[M, N]): + for i in dace.map[0:M]: + if A[0][0]: + A[i, :] = 1 + else: + A[i, :] = 2 + + sdfg = nested_conditional_in_map.to_sdfg(simplify=True) + + MemletPropagation().apply_pass(sdfg, {}) + + assert 'A' in sdfg.possible_reads + assert str(sdfg.possible_reads['A'].subset) == '0, 0' + assert sdfg.possible_reads['A'].dynamic == False + assert sdfg.possible_reads['A'].volume == M + assert 'A' in sdfg.certain_reads + assert str(sdfg.certain_reads['A'].subset) == '0, 0' + assert sdfg.certain_reads['A'].dynamic == False + assert sdfg.certain_reads['A'].volume == M + assert 'A' in sdfg.possible_writes + assert str(sdfg.possible_writes['A'].subset) == '0:M, 0:N' + assert sdfg.possible_writes['A'].dynamic == False + assert sdfg.possible_writes['A'].volume == M * N + assert 'A' in sdfg.certain_writes + assert str(sdfg.certain_writes['A'].subset) == '0:M, 0:N' + assert sdfg.certain_writes['A'].dynamic == False + assert sdfg.certain_writes['A'].volume == M * N + +def test_nested_conditional_in_loop_in_map(): + N = dace.symbol('N') + M = dace.symbol('M') + + @dace.program + def nested_conditional_in_loop_in_map(A: dace.int32[M, N]): + for i in dace.map[0:M]: + for j in range(0, N - 2, 1): + if A[0][0]: + A[i, j] = 1 + else: + A[i, j] = 2 + A[i, j] = A[i, j] * A[i, j] + + sdfg = nested_conditional_in_loop_in_map.to_sdfg(simplify=True) + + MemletPropagation().apply_pass(sdfg, {}) + + assert 'A' in sdfg.possible_reads + assert str(sdfg.possible_reads['A'].subset) == '0, 0' + assert sdfg.possible_reads['A'].dynamic == False + assert sdfg.possible_reads['A'].volume == M * (N - 2) + assert 'A' in sdfg.certain_reads + assert str(sdfg.certain_reads['A'].subset) == '0, 0' + assert sdfg.certain_reads['A'].dynamic == False + assert sdfg.certain_reads['A'].volume == M * (N - 2) + assert 'A' in sdfg.possible_writes + assert str(sdfg.possible_writes['A'].subset) == '0:M, 0:N - 2' + assert sdfg.possible_writes['A'].dynamic == False + assert sdfg.possible_writes['A'].volume == M * (N - 2) + assert 'A' in sdfg.certain_writes + assert str(sdfg.certain_writes['A'].subset) == '0:M, 0:N - 2' + assert sdfg.certain_writes['A'].dynamic == False + assert sdfg.certain_writes['A'].volume == M * (N - 2) + +def test_runtime_conditional(): + @dace.program + def rconditional(in1: dace.float64[10], out: dace.float64[10], mask: dace.int32[10]): + for i in dace.map[1:10]: + if mask[i] > 0: + out[i] = in1[i - 1] + else: + out[i] = in1[i] + + sdfg = rconditional.to_sdfg(simplify=True) + + MemletPropagation().apply_pass(sdfg, {}) + + assert 'mask' in sdfg.possible_reads + assert str(sdfg.possible_reads['mask'].subset) == '1:10' + assert sdfg.possible_reads['mask'].dynamic == False + assert sdfg.possible_reads['mask'].volume == 9 + assert 'in1' in sdfg.possible_reads + assert str(sdfg.possible_reads['in1'].subset) == '0:10' + assert sdfg.possible_reads['in1'].dynamic == False + assert sdfg.possible_reads['in1'].volume == 18 + + assert 'mask' in sdfg.certain_reads + assert str(sdfg.certain_reads['mask'].subset) == '1:10' + assert sdfg.certain_reads['mask'].dynamic == False + assert sdfg.certain_reads['mask'].volume == 9 + assert 'in1' in sdfg.certain_reads + assert str(sdfg.certain_reads['in1'].subset) == '0:10' + assert sdfg.certain_reads['in1'].dynamic == False + assert sdfg.certain_reads['in1'].volume == 18 + + assert 'out' in sdfg.possible_writes + assert str(sdfg.possible_writes['out'].subset) == '1:10' + assert sdfg.possible_writes['out'].dynamic == False + assert sdfg.possible_writes['out'].volume == 9 + + assert 'out' in sdfg.certain_writes + assert str(sdfg.certain_writes['out'].subset) == '1:10' + assert sdfg.certain_writes['out'].dynamic == False + assert sdfg.certain_writes['out'].volume == 9 + +def test_nsdfg_memlet_propagation_with_one_sparse_dimension(): + N = dace.symbol('N') + M = dace.symbol('M') + @dace.program + def sparse(A: dace.float32[M, N], ind: dace.int32[M, N]): + for i, j in dace.map[0:M, 0:N]: + A[i, ind[i, j]] += 1 + + sdfg = sparse.to_sdfg(simplify=False) + + MemletPropagation().apply_pass(sdfg, {}) + + assert 'ind' in sdfg.possible_reads + assert str(sdfg.possible_reads['ind'].subset) == '0:M, 0:N' + assert sdfg.possible_reads['ind'].dynamic == False + assert sdfg.possible_reads['ind'].volume == N * M + + assert 'ind' in sdfg.certain_reads + assert str(sdfg.certain_reads['ind'].subset) == '0:M, 0:N' + assert sdfg.certain_reads['ind'].dynamic == False + assert sdfg.certain_reads['ind'].volume == N * M + + assert 'A' in sdfg.possible_writes + assert str(sdfg.possible_writes['A'].subset) == '0:M, 0:N' + assert sdfg.possible_writes['A'].dynamic == False + assert sdfg.possible_writes['A'].volume == N * M + + assert 'A' in sdfg.certain_writes + assert str(sdfg.certain_writes['A'].subset) == '0:M, 0:N' + assert sdfg.certain_writes['A'].dynamic == False + assert sdfg.certain_writes['A'].volume == N * M + +def test_nested_loop_in_map(): + N = dace.symbol('N') + M = dace.symbol('M') + + @dace.program + def nested_loop_in_map(A: dace.float64[N, M]): + for i in dace.map[0:N]: + for j in range(M): + A[i, j] = 0 + + sdfg = nested_loop_in_map.to_sdfg(simplify=True) + + MemletPropagation().apply_pass(sdfg, {}) + + assert sdfg.possible_reads == {} + assert sdfg.certain_reads == {} + + assert 'A' in sdfg.possible_writes + assert str(sdfg.possible_writes['A'].subset) == '0:N, 0:M' + assert sdfg.possible_writes['A'].dynamic == False + assert sdfg.possible_writes['A'].volume == N * M + + assert 'A' in sdfg.certain_writes + assert str(sdfg.certain_writes['A'].subset) == '0:N, 0:M' + assert sdfg.certain_writes['A'].dynamic == False + assert sdfg.certain_writes['A'].volume == N * M + + +if __name__ == '__main__': + test_nested_conditional_in_map() + test_nested_conditional_in_loop_in_map() + test_runtime_conditional() + test_nsdfg_memlet_propagation_with_one_sparse_dimension() + test_nested_loop_in_map() diff --git a/tests/python_frontend/assignment_statements_test.py b/tests/python_frontend/assignment_statements_test.py index f8538aa848..905e4f3d8f 100644 --- a/tests/python_frontend/assignment_statements_test.py +++ b/tests/python_frontend/assignment_statements_test.py @@ -38,6 +38,38 @@ def test_single_target_parentheses(): assert (b[0] == np.float32(np.pi)) +@dace.program +def single_target_tuple(a: dace.float32[1], b: dace.float32[1], c: dace.float32[2]): + c = (a, b) + + +def test_single_target_tuple(): + a = np.zeros((1, ), dtype=np.float32) + b = np.zeros((1, ), dtype=np.float32) + c = np.zeros((2, ), dtype=np.float32) + a[0] = np.pi + b[0] = 2 * np.pi + single_target_tuple(a=a, b=b, c=c) + assert (c[0] == a[0]) + assert (c[1] == b[0]) + + +@dace.program +def single_target_tuple_with_definition(a: dace.float32[1], b: dace.float32[1]): + c = (a, b) + return c + + +def test_single_target_tuple_with_definition(): + a = np.zeros((1, ), dtype=np.float32) + b = np.zeros((1, ), dtype=np.float32) + a[0] = np.pi + b[0] = 2 * np.pi + c = single_target_tuple_with_definition(a=a, b=b) + assert (c[0] == a[0]) + assert (c[1] == b[0]) + + @dace.program def multiple_targets(a: dace.float32[1]): b, c = a, 2 * a @@ -66,6 +98,21 @@ def test_multiple_targets_parentheses(): assert (c[0] == np.float32(2) * np.float32(np.pi)) +@dace.program +def multiple_targets_unpacking(a: dace.float32[2]): + b, c = a + return b, c + + +def test_multiple_targets_unpacking(): + a = np.zeros((2, ), dtype=np.float32) + a[0] = np.pi + a[1] = 2 * np.pi + b, c = multiple_targets_unpacking(a=a) + assert (b[0] == a[0]) + assert (c[0] == a[1]) + + @dace.program def starred_target(a: dace.float32[1]): b, *c, d, e = a, 2 * a, 3 * a, 4 * a, 5 * a, 6 * a @@ -173,8 +220,11 @@ def method(self): if __name__ == "__main__": test_single_target() test_single_target_parentheses() + test_single_target_tuple() + test_single_target_tuple_with_definition() test_multiple_targets() test_multiple_targets_parentheses() + test_multiple_targets_unpacking() # test_starred_target() # test_attribute_reference() diff --git a/tests/subset_intersects_test.py b/tests/subset_intersects_test.py deleted file mode 100644 index c7a01705b9..0000000000 --- a/tests/subset_intersects_test.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import dace -from dace import subsets - - -def test_intersects_symbolic(): - N, M = dace.symbol('N', positive=True), dace.symbol('M', positive=True) - rng1 = subsets.Range([(0, N - 1, 1), (0, M - 1, 1)]) - rng2 = subsets.Range([(0, 0, 1), (0, 0, 1)]) - rng3_1 = subsets.Range([(N, N, 1), (0, 1, 1)]) - rng3_2 = subsets.Range([(0, 1, 1), (M, M, 1)]) - rng4 = subsets.Range([(N, N, 1), (M, M, 1)]) - rng5 = subsets.Range([(0, 0, 1), (M, M, 1)]) - rng6 = subsets.Range([(0, N, 1), (0, M, 1)]) - rng7 = subsets.Range([(0, N - 1, 1), (N - 1, N, 1)]) - ind1 = subsets.Indices([0, 1]) - - assert subsets.intersects(rng1, rng2) is True - assert subsets.intersects(rng1, rng3_1) is False - assert subsets.intersects(rng1, rng3_2) is False - assert subsets.intersects(rng1, rng4) is False - assert subsets.intersects(rng1, rng5) is False - assert subsets.intersects(rng6, rng1) is True - assert subsets.intersects(rng1, rng7) is None - assert subsets.intersects(rng7, rng1) is None - assert subsets.intersects(rng1, ind1) is None - assert subsets.intersects(ind1, rng1) is None - - -def test_intersects_constant(): - rng1 = subsets.Range([(0, 4, 1)]) - rng2 = subsets.Range([(3, 4, 1)]) - rng3 = subsets.Range([(1, 5, 1)]) - rng4 = subsets.Range([(5, 7, 1)]) - ind1 = subsets.Indices([0]) - ind2 = subsets.Indices([1]) - ind3 = subsets.Indices([5]) - - assert subsets.intersects(rng1, rng2) is True - assert subsets.intersects(rng1, rng3) is True - assert subsets.intersects(rng1, rng4) is False - assert subsets.intersects(ind1, rng1) is True - assert subsets.intersects(rng1, ind2) is True - assert subsets.intersects(rng1, ind3) is False - - -def test_covers_symbolic(): - N, M = dace.symbol('N', positive=True), dace.symbol('M', positive=True) - rng1 = subsets.Range([(0, N - 1, 1), (0, M - 1, 1)]) - rng2 = subsets.Range([(0, 0, 1), (0, 0, 1)]) - rng3_1 = subsets.Range([(N, N, 1), (0, 1, 1)]) - rng3_2 = subsets.Range([(0, 1, 1), (M, M, 1)]) - rng4 = subsets.Range([(N, N, 1), (M, M, 1)]) - rng5 = subsets.Range([(0, 0, 1), (M, M, 1)]) - rng6 = subsets.Range([(0, N, 1), (0, M, 1)]) - rng7 = subsets.Range([(0, N - 1, 1), (N - 1, N, 1)]) - ind1 = subsets.Indices([0, 1]) - - assert rng1.covers(rng2) is True - assert rng1.covers(rng3_1) is False - assert rng1.covers(rng3_2) is False - assert rng1.covers(rng4) is False - assert rng1.covers(rng5) is False - assert rng6.covers(rng1) is True - assert rng1.covers(rng7) is False - assert rng7.covers(rng1) is False - assert rng1.covers(ind1) is True - assert ind1.covers(rng1) is False - - rng8 = subsets.Range([(0, dace.symbolic.pystr_to_symbolic('int_ceil(M, N)'), 1)]) - - assert rng8.covers(rng8) is True - - -if __name__ == '__main__': - test_intersects_symbolic() - test_intersects_constant() - test_covers_symbolic() diff --git a/tests/subsets_squeeze_test.py b/tests/subsets_squeeze_test.py deleted file mode 100644 index c60269c680..0000000000 --- a/tests/subsets_squeeze_test.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from copy import deepcopy -from dace.subsets import Range, Indices - - -def test_squeeze_unsqueeze_indices(): - - a1 = Indices.from_string('i, 0') - expected_squeezed = [1] - a2 = deepcopy(a1) - not_squeezed = a2.squeeze(ignore_indices=[0]) - squeezed = [i for i in range(len(a1)) if i not in not_squeezed] - unsqueezed = a2.unsqueeze(squeezed) - assert (squeezed == unsqueezed) - assert (expected_squeezed == squeezed) - assert (a1 == a2) - - b1 = Indices.from_string('0, i') - expected_squeezed = [0] - b2 = deepcopy(b1) - not_squeezed = b2.squeeze(ignore_indices=[1]) - squeezed = [i for i in range(len(b1)) if i not in not_squeezed] - unsqueezed = b2.unsqueeze(squeezed) - assert (squeezed == unsqueezed) - assert (expected_squeezed == squeezed) - assert (b1 == b2) - - c1 = Indices.from_string('i, 0, 0') - expected_squeezed = [1, 2] - c2 = deepcopy(c1) - not_squeezed = c2.squeeze(ignore_indices=[0]) - squeezed = [i for i in range(len(c1)) if i not in not_squeezed] - unsqueezed = c2.unsqueeze(squeezed) - assert (squeezed == unsqueezed) - assert (expected_squeezed == squeezed) - assert (c1 == c2) - - d1 = Indices.from_string('0, i, 0') - expected_squeezed = [0, 2] - d2 = deepcopy(d1) - not_squeezed = d2.squeeze(ignore_indices=[1]) - squeezed = [i for i in range(len(d1)) if i not in not_squeezed] - unsqueezed = d2.unsqueeze(squeezed) - assert (squeezed == unsqueezed) - assert (expected_squeezed == squeezed) - assert (d1 == d2) - - e1 = Indices.from_string('0, 0, i') - expected_squeezed = [0, 1] - e2 = deepcopy(e1) - not_squeezed = e2.squeeze(ignore_indices=[2]) - squeezed = [i for i in range(len(e1)) if i not in not_squeezed] - unsqueezed = e2.unsqueeze(squeezed) - assert (squeezed == unsqueezed) - assert (expected_squeezed == squeezed) - assert (e1 == e2) - - -def test_squeeze_unsqueeze_ranges(): - - a1 = Range.from_string('0:10, 0') - expected_squeezed = [1] - a2 = deepcopy(a1) - not_squeezed = a2.squeeze() - squeezed = [i for i in range(len(a1)) if i not in not_squeezed] - unsqueezed = a2.unsqueeze(squeezed) - assert (squeezed == unsqueezed) - assert (expected_squeezed == squeezed) - assert (a1 == a2) - - b1 = Range.from_string('0, 0:10') - expected_squeezed = [0] - b2 = deepcopy(b1) - not_squeezed = b2.squeeze() - squeezed = [i for i in range(len(b1)) if i not in not_squeezed] - unsqueezed = b2.unsqueeze(squeezed) - assert (squeezed == unsqueezed) - assert (expected_squeezed == squeezed) - assert (b1 == b2) - - c1 = Range.from_string('0:10, 0, 0') - expected_squeezed = [1, 2] - c2 = deepcopy(c1) - not_squeezed = c2.squeeze() - squeezed = [i for i in range(len(c1)) if i not in not_squeezed] - unsqueezed = c2.unsqueeze(squeezed) - assert (squeezed == unsqueezed) - assert (expected_squeezed == squeezed) - assert (c1 == c2) - - d1 = Range.from_string('0, 0:10, 0') - expected_squeezed = [0, 2] - d2 = deepcopy(d1) - not_squeezed = d2.squeeze() - squeezed = [i for i in range(len(d1)) if i not in not_squeezed] - unsqueezed = d2.unsqueeze(squeezed) - assert (squeezed == unsqueezed) - assert (expected_squeezed == squeezed) - assert (d1 == d2) - - e1 = Range.from_string('0, 0, 0:10') - expected_squeezed = [0, 1] - e2 = deepcopy(e1) - not_squeezed = e2.squeeze() - squeezed = [i for i in range(len(e1)) if i not in not_squeezed] - unsqueezed = e2.unsqueeze(squeezed) - assert (squeezed == unsqueezed) - assert (expected_squeezed == squeezed) - assert (e1 == e2) - - -if __name__ == '__main__': - test_squeeze_unsqueeze_indices() - test_squeeze_unsqueeze_ranges() diff --git a/tests/subsets_test.py b/tests/subsets_test.py new file mode 100644 index 0000000000..d397d781c5 --- /dev/null +++ b/tests/subsets_test.py @@ -0,0 +1,247 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +from copy import deepcopy + +import dace +from dace import subsets + + +def test_intersects_symbolic(): + N, M = dace.symbol('N', positive=True), dace.symbol('M', positive=True) + rng1 = subsets.Range([(0, N - 1, 1), (0, M - 1, 1)]) + rng2 = subsets.Range([(0, 0, 1), (0, 0, 1)]) + rng3_1 = subsets.Range([(N, N, 1), (0, 1, 1)]) + rng3_2 = subsets.Range([(0, 1, 1), (M, M, 1)]) + rng4 = subsets.Range([(N, N, 1), (M, M, 1)]) + rng5 = subsets.Range([(0, 0, 1), (M, M, 1)]) + rng6 = subsets.Range([(0, N, 1), (0, M, 1)]) + rng7 = subsets.Range([(0, N - 1, 1), (N - 1, N, 1)]) + ind1 = subsets.Indices([0, 1]) + + assert subsets.intersects(rng1, rng2) is True + assert subsets.intersects(rng1, rng3_1) is False + assert subsets.intersects(rng1, rng3_2) is False + assert subsets.intersects(rng1, rng4) is False + assert subsets.intersects(rng1, rng5) is False + assert subsets.intersects(rng6, rng1) is True + assert subsets.intersects(rng1, rng7) is None + assert subsets.intersects(rng7, rng1) is None + assert subsets.intersects(rng1, ind1) is None + assert subsets.intersects(ind1, rng1) is None + + +def test_intersects_constant(): + rng1 = subsets.Range([(0, 4, 1)]) + rng2 = subsets.Range([(3, 4, 1)]) + rng3 = subsets.Range([(1, 5, 1)]) + rng4 = subsets.Range([(5, 7, 1)]) + ind1 = subsets.Indices([0]) + ind2 = subsets.Indices([1]) + ind3 = subsets.Indices([5]) + + assert subsets.intersects(rng1, rng2) is True + assert subsets.intersects(rng1, rng3) is True + assert subsets.intersects(rng1, rng4) is False + assert subsets.intersects(ind1, rng1) is True + assert subsets.intersects(rng1, ind2) is True + assert subsets.intersects(rng1, ind3) is False + + +def test_covers_symbolic(): + N, M = dace.symbol('N', positive=True), dace.symbol('M', positive=True) + rng1 = subsets.Range([(0, N - 1, 1), (0, M - 1, 1)]) + rng2 = subsets.Range([(0, 0, 1), (0, 0, 1)]) + rng3_1 = subsets.Range([(N, N, 1), (0, 1, 1)]) + rng3_2 = subsets.Range([(0, 1, 1), (M, M, 1)]) + rng4 = subsets.Range([(N, N, 1), (M, M, 1)]) + rng5 = subsets.Range([(0, 0, 1), (M, M, 1)]) + rng6 = subsets.Range([(0, N, 1), (0, M, 1)]) + rng7 = subsets.Range([(0, N - 1, 1), (N - 1, N, 1)]) + ind1 = subsets.Indices([0, 1]) + + assert rng1.covers(rng2) is True + assert rng1.covers(rng3_1) is False + assert rng1.covers(rng3_2) is False + assert rng1.covers(rng4) is False + assert rng1.covers(rng5) is False + assert rng6.covers(rng1) is True + assert rng1.covers(rng7) is False + assert rng7.covers(rng1) is False + assert rng1.covers(ind1) is True + assert ind1.covers(rng1) is False + + rng8 = subsets.Range([(0, dace.symbolic.pystr_to_symbolic('int_ceil(M, N)'), 1)]) + + assert rng8.covers(rng8) is True + + +def test_squeeze_unsqueeze_indices(): + + a1 = subsets.Indices.from_string('i, 0') + expected_squeezed = [1] + a2 = deepcopy(a1) + not_squeezed = a2.squeeze(ignore_indices=[0]) + squeezed = [i for i in range(len(a1)) if i not in not_squeezed] + unsqueezed = a2.unsqueeze(squeezed) + assert (squeezed == unsqueezed) + assert (expected_squeezed == squeezed) + assert (a1 == a2) + + b1 = subsets.Indices.from_string('0, i') + expected_squeezed = [0] + b2 = deepcopy(b1) + not_squeezed = b2.squeeze(ignore_indices=[1]) + squeezed = [i for i in range(len(b1)) if i not in not_squeezed] + unsqueezed = b2.unsqueeze(squeezed) + assert (squeezed == unsqueezed) + assert (expected_squeezed == squeezed) + assert (b1 == b2) + + c1 = subsets.Indices.from_string('i, 0, 0') + expected_squeezed = [1, 2] + c2 = deepcopy(c1) + not_squeezed = c2.squeeze(ignore_indices=[0]) + squeezed = [i for i in range(len(c1)) if i not in not_squeezed] + unsqueezed = c2.unsqueeze(squeezed) + assert (squeezed == unsqueezed) + assert (expected_squeezed == squeezed) + assert (c1 == c2) + + d1 = subsets.Indices.from_string('0, i, 0') + expected_squeezed = [0, 2] + d2 = deepcopy(d1) + not_squeezed = d2.squeeze(ignore_indices=[1]) + squeezed = [i for i in range(len(d1)) if i not in not_squeezed] + unsqueezed = d2.unsqueeze(squeezed) + assert (squeezed == unsqueezed) + assert (expected_squeezed == squeezed) + assert (d1 == d2) + + e1 = subsets.Indices.from_string('0, 0, i') + expected_squeezed = [0, 1] + e2 = deepcopy(e1) + not_squeezed = e2.squeeze(ignore_indices=[2]) + squeezed = [i for i in range(len(e1)) if i not in not_squeezed] + unsqueezed = e2.unsqueeze(squeezed) + assert (squeezed == unsqueezed) + assert (expected_squeezed == squeezed) + assert (e1 == e2) + + +def test_squeeze_unsqueeze_ranges(): + + a1 = subsets.Range.from_string('0:10, 0') + expected_squeezed = [1] + a2 = deepcopy(a1) + not_squeezed = a2.squeeze() + squeezed = [i for i in range(len(a1)) if i not in not_squeezed] + unsqueezed = a2.unsqueeze(squeezed) + assert (squeezed == unsqueezed) + assert (expected_squeezed == squeezed) + assert (a1 == a2) + + b1 = subsets.Range.from_string('0, 0:10') + expected_squeezed = [0] + b2 = deepcopy(b1) + not_squeezed = b2.squeeze() + squeezed = [i for i in range(len(b1)) if i not in not_squeezed] + unsqueezed = b2.unsqueeze(squeezed) + assert (squeezed == unsqueezed) + assert (expected_squeezed == squeezed) + assert (b1 == b2) + + c1 = subsets.Range.from_string('0:10, 0, 0') + expected_squeezed = [1, 2] + c2 = deepcopy(c1) + not_squeezed = c2.squeeze() + squeezed = [i for i in range(len(c1)) if i not in not_squeezed] + unsqueezed = c2.unsqueeze(squeezed) + assert (squeezed == unsqueezed) + assert (expected_squeezed == squeezed) + assert (c1 == c2) + + d1 = subsets.Range.from_string('0, 0:10, 0') + expected_squeezed = [0, 2] + d2 = deepcopy(d1) + not_squeezed = d2.squeeze() + squeezed = [i for i in range(len(d1)) if i not in not_squeezed] + unsqueezed = d2.unsqueeze(squeezed) + assert (squeezed == unsqueezed) + assert (expected_squeezed == squeezed) + assert (d1 == d2) + + e1 = subsets.Range.from_string('0, 0, 0:10') + expected_squeezed = [0, 1] + e2 = deepcopy(e1) + not_squeezed = e2.squeeze() + squeezed = [i for i in range(len(e1)) if i not in not_squeezed] + unsqueezed = e2.unsqueeze(squeezed) + assert (squeezed == unsqueezed) + assert (expected_squeezed == squeezed) + assert (e1 == e2) + + +def test_difference_symbolic(): + N, M = dace.symbol('N', positive=True), dace.symbol('M', positive=True) + rng1 = subsets.Range([(0, N - 1, 1), (0, M - 1, 1)]) + rng2 = subsets.Range([(0, 0, 1), (0, 0, 1)]) + rng3_1 = subsets.Range([(N, N, 1), (0, 1, 1)]) + rng3_2 = subsets.Range([(0, 1, 1), (M, M, 1)]) + rng4 = subsets.Range([(N, N, 1), (M, M, 1)]) + rng5 = subsets.Range([(0, 0, 1), (M, M, 1)]) + rng6 = subsets.Range([(0, N, 1), (0, M, 1)]) + rng7 = subsets.Range([(0, N - 1, 1), (N - 1, N, 1)]) + rng8 = subsets.Range([(0, N, 1), (0, 5, 1)]) + rng9 = subsets.Range([(0, N, 1), (0, 0, 1)]) + ind1 = subsets.Indices([0, 1]) + + assert subsets.difference(rng1, rng2) == subsets.Range([(1, N - 1, 1), (1, M - 1, 1)]) + assert subsets.difference(rng1, rng3_1) == rng1 + assert subsets.difference(rng1, rng3_2) == rng1 + assert subsets.difference(rng1, rng4) == rng1 + assert subsets.difference(rng1, rng5) == rng1 + assert subsets.difference(rng6, rng1) == rng4 + assert subsets.difference(rng1, rng7) is None + assert subsets.difference(rng7, rng1) is None + assert subsets.difference(rng1, ind1) is None + assert subsets.difference(ind1, rng1) is None + assert subsets.difference(rng8, rng9) == subsets.Range([(0, N, 1), (1, 5, 1)]) + + +def test_difference_constant(): + rng1 = subsets.Range([(0, 4, 1)]) + rng2 = subsets.Range([(3, 4, 1)]) + rng3 = subsets.Range([(1, 5, 1)]) + rng4 = subsets.Range([(5, 7, 1)]) + rng5_1 = subsets.Range([(0, 6, 1)]) + rng5_2 = subsets.Range([(3, 6, 1)]) + ind1 = subsets.Indices([0]) + ind2 = subsets.Indices([1]) + ind3 = subsets.Indices([5]) + ind4 = subsets.Indices([3]) + ind5 = subsets.Indices([6]) + ind6 = subsets.Indices([4]) + + assert subsets.difference(rng1, rng2) == subsets.Range([(0, 2, 1)]) + assert subsets.difference(rng1, rng3) == subsets.Range([(0, 0, 1)]) + assert subsets.difference(rng1, rng4) == rng1 + assert subsets.difference(ind1, rng1) == subsets.Range([]) + assert str(subsets.difference(rng1, ind2)) == '0 2:5' + assert str(subsets.difference(subsets.difference(rng1, ind2), ind4)) == '0 2 4' + assert str(subsets.difference(ind4, subsets.difference(rng1, ind2))) == '0 2 4' + assert subsets.difference(rng1, ind3) == rng1 + + first_diff = subsets.difference(rng5_1, ind2) + second_diff = subsets.difference(subsets.difference(rng5_2, ind6), ind5) + assert str(subsets.difference(first_diff, second_diff)) == '0 2 4 6' + + +if __name__ == '__main__': + test_intersects_symbolic() + test_intersects_constant() + test_covers_symbolic() + + test_squeeze_unsqueeze_indices() + test_squeeze_unsqueeze_ranges() + + test_difference_symbolic() + test_difference_constant()