Skip to content

Commit 9c7e8d3

Browse files
committed
mpi: Simplify Custom domain decomposition
1 parent 422b3f3 commit 9c7e8d3

File tree

2 files changed

+40
-65
lines changed

2 files changed

+40
-65
lines changed

devito/mpi/distributed.py

+37-49
Original file line numberDiff line numberDiff line change
@@ -567,31 +567,29 @@ def _arg_values(self, *args, **kwargs):
567567
class CustomTopology(tuple):
568568

569569
"""
570-
A CustomTopology is a mechanism to describe parametric domain decompositions.
570+
The CustomTopology class provides a mechanism to describe parametric domain
571+
decompositions. It allows users to specify how the dimensions of a domain are
572+
decomposed into chunks based on certain parameters.
571573
572574
Examples
573575
--------
574-
Assuming a domain consisting of three distributed Dimensions x, y, and z, and
575-
an MPI communicator comprising N processes, a CustomTopology might be:
576+
For example, let's consider a domain with three distributed dimensions: x, y, and z,
577+
and an MPI communicator with N processes. Here are a few examples of CustomTopology:
576578
577579
With N known, say N=4:
578-
579580
* `(1, 1, 4)`: the z Dimension is decomposed into 4 chunks
580-
* `(2, 1, 2)`: the x Dimension is decomposed into 2 chunks; the z Dimension
581+
* `(2, 1, 2)`: the x Dimension is decomposed into 2 chunks and the z Dimension
581582
is decomposed into 2 chunks
582583
583584
With N unknown:
584-
585-
* `(1, '*', 1)`: the wildcard `'*'` tells the runtime to decompose the y
585+
* `(1, '*', 1)`: the wildcard `'*'` indicates that the runtime should decompose the y
586586
Dimension into N chunks
587-
* `('*', '*', 1)`: the wildcard `'*'` tells the runtime to decompose both
587+
* `('*', '*', 1)`: the wildcard `'*'` indicates that the runtime should decompose both
588588
the x and y Dimensions in `nstars` factors of N, prioritizing
589589
the outermost dimension
590590
591-
Assuming N=6 and requested topology is `('*', '*', 1)`,
592-
since there is no integer k, so that k*k=6, we resort to the closest factors to
593-
the nstars-th root (usually square or cubic) that satisfies that the decomposed
594-
domains are equal to the number of MPI processes.
591+
Assuming that the number of ranks `N` cannot evenly be decomposed to the requested
592+
stars=6 we decompose as evenly as possible by prioritising the outermost dimension
595593
596594
For N=3
597595
* `('*', '*', 1)` gives: (3, 1, 1)
@@ -611,58 +609,48 @@ class CustomTopology(tuple):
611609
612610
Notes
613611
-----
614-
Users shouldn't use this class directly. It's up to the Devito runtime to
615-
instantiate it based on the user input.
612+
Users should not directly use the CustomTopology class. It is instantiated
613+
by the Devito runtime based on user input.
616614
"""
617615

618616
def __new__(cls, items, input_comm):
619617
# Keep track of nstars and already defined decompositions
620-
nstars = len([i for i in items if i == '*'])
618+
nstars = items.count('*')
621619

622620
# If no stars exist we are ready
623621
if nstars == 0:
624622
processed = items
625623
else:
626-
# Init decomposition list
624+
# Init decomposition list and track star positions
627625
processed = [1] * len(items)
628-
629-
# Get star and integer indices
630-
int_pos = [i for i, item in enumerate(items) if isinstance(item, int)]
631-
int_vals = [item for item in items if isinstance(item, int)]
632-
star_pos = [i for i, item in enumerate(items) if not isinstance(item, int)]
633-
634-
# Decompose the processes remaining for allocation to prime factors
626+
star_pos = []
627+
for i, item in enumerate(items):
628+
if isinstance(item, int):
629+
processed[i] = item
630+
else:
631+
star_pos.append(i)
632+
633+
# Compute the remaining procs to be allocated
635634
alloc_procs = np.prod([i for i in items if i != '*'])
636-
remprocs = int(input_comm.size // alloc_procs)
637-
prime_factors = primefactors(remprocs)
638-
639-
star_i = -1
640-
dd_list = [1] * nstars
635+
rem_procs = int(input_comm.size // alloc_procs)
641636

642637
# Start by using the max prime factor at the first starred position,
643-
# then cyclically-iteratively decompose as evenly as possible until
644-
# decomposing to the number of `remprocs`
645-
while remprocs != 1:
646-
star_i = star_i + 1
647-
star_i = star_i % nstars
648-
prime_factors = primefactors(remprocs)
649-
dd_list[star_i] = dd_list[star_i]*max(prime_factors)
650-
remprocs = remprocs // max(prime_factors)
651-
652-
if int_pos:
653-
for index, value in zip(int_pos, int_vals):
654-
processed[index] = value
655-
656-
if dd_list:
657-
for index, value in zip(star_pos, dd_list):
658-
processed[index] = value
638+
# then iteratively decompose as evenly as possible until decomposing
639+
# to the number of `rem_procs`
640+
star_vals = [1] * len(items)
641+
star_i = 0
642+
while rem_procs > 1:
643+
prime_factors = primefactors(rem_procs)
644+
rem_procs //= max(prime_factors)
645+
star_vals[star_i] *= max(prime_factors)
646+
star_i = (star_i + 1) % nstars
647+
648+
# Apply computed star values to the processed
649+
for index, value in zip(star_pos, star_vals):
650+
processed[index] = value
659651

660652
# Final check that topology matches the communicator size
661-
try:
662-
assert np.prod(processed) == input_comm.size
663-
except:
664-
raise ValueError("Invalid `topology`", processed, " for given nprocs:",
665-
input_comm.size)
653+
assert np.prod(processed) == input_comm.size
666654

667655
obj = super().__new__(cls, processed)
668656
obj.logical = items

tests/test_mpi.py

+3-16
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,13 @@
1212
retrieve_iteration_tree)
1313
from devito.mpi import MPI
1414
from devito.mpi.routines import HaloUpdateCall, HaloUpdateList, MPICall
15+
from devito.mpi.distributed import CustomTopology
16+
from devito.tools import Bunch
1517
from examples.seismic.acoustic import acoustic_setup
1618

1719
pytestmark = skipif(['nompi'], whole_module=True)
1820

1921

20-
class DummyInputComm():
21-
"Helper class for modelling a communicator with a specific size"
22-
def __init__(self, size):
23-
self.size = size
24-
25-
2622
class TestDistributor(object):
2723

2824
@pytest.mark.parallel(mode=[2, 4])
@@ -193,12 +189,9 @@ def test_custom_topology(self):
193189
(2, (1, '*', '*'), (1, 2, 1)),
194190
(2, (2, '*', '*'), (2, 1, 1)),
195191
(3, (1, '*', '*'), (1, 3, 1)),
196-
(3, ('*', 1, '*'), (3, 1, 1)),
197192
(3, ('*', '*', 1), (3, 1, 1)),
198193
(4, (2, '*', '*'), (2, 2, 1)),
199-
(4, ('*', 2, '*'), (2, 2, 1)),
200194
(4, ('*', '*', 2), (2, 1, 2)),
201-
(6, ('*', 1, '*'), (3, 1, 2)),
202195
(6, ('*', '*', 1), (3, 2, 1)),
203196
(6, (1, '*', '*'), (1, 3, 2)),
204197
(6, ('*', '*', '*'), (3, 2, 1)),
@@ -211,29 +204,23 @@ def test_custom_topology(self):
211204
(32, ('*', '*', '*'), (4, 4, 2)),
212205
(8, ('*', 1, '*'), (4, 1, 2)),
213206
(8, ('*', '*', 1), (4, 2, 1)),
214-
(8, (1, '*', '*'), (1, 4, 2)),
215207
(8, ('*', '*', '*'), (2, 2, 2)),
216208
(9, ('*', '*', '*'), (3, 3, 1)),
217209
(11, (1, '*', '*'), (1, 11, 1)),
218210
(22, ('*', '*', '*'), (11, 2, 1)),
219-
(16, ('*', '*', 1), (4, 4, 1)),
220211
(16, ('*', 1, '*'), (4, 1, 4)),
221212
(32, ('*', '*', 1), (8, 4, 1)),
222-
(64, ('*', '*', '*'), (4, 4, 4)),
223213
(64, ('*', '*', 1), (8, 8, 1)),
224-
(64, ('*', 2, 1), (32, 2, 1)),
225214
(64, ('*', 2, 4), (8, 2, 4)),
226215
(128, ('*', '*', 1), (16, 8, 1)),
227216
(231, ('*', '*', '*'), (11, 7, 3)),
228217
(256, (1, '*', '*'), (1, 16, 16)),
229-
(256, ('*', 1, '*'), (16, 1, 16)),
230-
(256, ('*', '*', 1), (16, 16, 1)),
231218
(256, ('*', '*', '*'), (8, 8, 4)),
232219
(256, ('*', '*', 2), (16, 8, 2)),
233220
(256, ('*', 32, 2), (4, 32, 2)),
234221
])
235222
def test_custom_topology_3d_dummy(self, comm_size, topology, dist_topology):
236-
dummy_comm = DummyInputComm(comm_size)
223+
dummy_comm = Bunch(size=comm_size)
237224
custom_topology = CustomTopology(topology, dummy_comm)
238225
assert custom_topology == dist_topology
239226

0 commit comments

Comments
 (0)