Skip to content

Commit

Permalink
Add unit dimensions even in the middle. And a typo fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
pratyai committed Oct 25, 2024
1 parent 3194cc4 commit 7cece3d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
4 changes: 2 additions & 2 deletions dace/subsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def size_exact(self):
]

def volume_exact(self) -> int:
""" Returns the total number of elements in all dimenssions together. """
""" Returns the total number of elements in all dimensions together. """
return reduce(operator.mul, self.size_exact())

def bounding_box_size(self):
Expand Down Expand Up @@ -910,7 +910,7 @@ def size_exact(self):
return self.size()

def volume_exact(self) -> int:
""" Returns the total number of elements in all dimenssions together. """
""" Returns the total number of elements in all dimensions together. """
return reduce(operator.mul, self.size_exact())

def min_element(self):
Expand Down
12 changes: 9 additions & 3 deletions tests/subsets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def test_mapping_with_reshaping_unit_dims(self):
# A regular cube.
src = Range([(0, K - 1, 1), (0, N - 1, 1), (0, M - 1, 1), (0, 0, 1)])
# A regular cube with different shape.
dst = Range([(0, K - 1, 1), (0, N * M - 1, 1), (0, 0, 1), (0, 0, 1)])
dst = Range([(0, K - 1, 1), (0, 0, 1), (0, N * M - 1, 1), (0, 0, 1), (0, 0, 1)])
# A Mapper
sm = SubrangeMapper(src, dst)
sm_inv = SubrangeMapper(dst, src)
Expand All @@ -179,7 +179,10 @@ def test_mapping_with_reshaping_unit_dims(self):
# Pick a point K//2, N//2, M//2.
for args in argslist:
orig = Range([(K // 2, K // 2, 1), (N // 2, N // 2, 1), (M // 2, M // 2, 1), (0, 0, 1)])
orig_maps_to = Range([(K // 2, K // 2, 1), ((N // 2) + (M // 2) * N, (N // 2) + (M // 2) * N, 1), (0, 0, 1), (0, 0, 1)])
orig_maps_to = Range([(K // 2, K // 2, 1),
(0, 0, 1),
((N // 2) + (M // 2) * N, (N // 2) + (M // 2) * N, 1),
(0, 0, 1), (0, 0, 1)])
want, got = eval_range(orig_maps_to, args), eval_range(sm.map(orig), args)
self.assertEqual(want, got)
want, got = eval_range(orig, args), eval_range(sm_inv.map(orig_maps_to), args)
Expand All @@ -191,7 +194,10 @@ def test_mapping_with_reshaping_unit_dims(self):
# Pick only points in problematic quadrants, but larger subsets elsewhere.
for args in argslist:
orig = Range([(0, K // 2, 1), (N // 2, N // 2, 1), (M // 2, M // 2, 1), (0, 0, 1)])
orig_maps_to = Range([(0, K // 2, 1), ((N // 2) + (M // 2) * N, (N // 2) + (M // 2) * N, 1), (0, 0, 1), (0, 0, 1)])
orig_maps_to = Range([(0, K // 2, 1),
(0, 0, 1),
((N // 2) + (M // 2) * N, (N // 2) + (M // 2) * N, 1),
(0, 0, 1), (0, 0, 1)])
want, got = eval_range(orig_maps_to, args), eval_range(sm.map(orig), args)
self.assertEqual(want, got)
want, got = eval_range(orig, args), eval_range(sm_inv.map(orig_maps_to), args)
Expand Down

0 comments on commit 7cece3d

Please sign in to comment.