Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a SubrangeMapper helper class which maps a _subrange_ of src range to its counterpart in dst range, if possible. #1702

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
Add some explanatory comments.
  • Loading branch information
pratyai committed Oct 25, 2024
commit 01b9e7073e12f188082256815bb78990b32c4a93
13 changes: 9 additions & 4 deletions dace/subsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1393,6 +1393,9 @@ def map(self, r: Range) -> Optional[Range]:
while src_i < self.src.dims():
assert dst_i < self.dst.dims()

# Find the next smallest segments of `src` and `dst` whose volumes matches (and therefore can possibly have
# a mapping).
# TODO: It's possible to do this in a O(max(|src|, |dst|)) loop instead of O(|src| * |dst|).
src_j, dst_j = None, None
for sj in range(src_i + 1, self.src.dims() + 1):
for dj in range(dst_i + 1, self.dst.dims() + 1):
Expand All @@ -1404,12 +1407,14 @@ def map(self, r: Range) -> Optional[Range]:
continue
break
if src_j is None:
# Somehow, we couldn't find a matching segment. This should have been caught earlier.
return None

# If we are selecting just a single point in this segment, we can just pick the mapping of that point.
src_segment, dst_segment, r_segment = Range(self.src.ranges[src_i: src_j]), Range(
self.dst.ranges[dst_i: dst_j]), Range(r.ranges[src_i: src_j])
src_segment = Range(self.src.ranges[src_i: src_j])
dst_segment = Range(self.dst.ranges[dst_i: dst_j])
r_segment = Range(r.ranges[src_i: src_j])
if r_segment.volume_exact() == 1:
# If we are selecting just a single point in this segment, we can just pick the mapping of that point.
# Compute the local 1D coordinate of the point on `src`.
loc = 0
for (idx, _, _), (ridx, _, _), s in zip(reversed(src_segment.ranges),
Expand All @@ -1427,7 +1432,7 @@ def map(self, r: Range) -> Optional[Range]:
# its entirety too.
out.extend(self.dst.ranges[dst_i:dst_j])
elif src_j - src_i == 1 and dst_j - dst_i == 1:
# If the segment lengths on both sides are just 1, the mapping is easy to compute.
# If the segment lengths on both sides are just 1, the mapping is easy to compute -- it's just a shift.
sb, se, ss = self.src.ranges[src_i]
db, de, ds = self.dst.ranges[dst_i]
b, e, s = r.ranges[src_i]
Expand Down