Skip to content

Commit f26e986

Browse files
committed
mpi: Simplify custom approach using 'factorint' and 'array_split'
1 parent 9c7e8d3 commit f26e986

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

devito/mpi/distributed.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from ctypes import c_int, c_void_p, sizeof
33
from itertools import groupby, product
44
from math import ceil, pow
5-
from sympy import primefactors
5+
from sympy import factorint
66

77
import atexit
88

@@ -634,16 +634,15 @@ def __new__(cls, items, input_comm):
634634
alloc_procs = np.prod([i for i in items if i != '*'])
635635
rem_procs = int(input_comm.size // alloc_procs)
636636

637-
# Start by using the max prime factor at the first starred position,
638-
# then iteratively decompose as evenly as possible until decomposing
639-
# to the number of `rem_procs`
640-
star_vals = [1] * len(items)
641-
star_i = 0
642-
while rem_procs > 1:
643-
prime_factors = primefactors(rem_procs)
644-
rem_procs //= max(prime_factors)
645-
star_vals[star_i] *= max(prime_factors)
646-
star_i = (star_i + 1) % nstars
637+
# List of all factors of rem_procs in decreasing order
638+
factors = factorint(rem_procs)
639+
vals = [k for (k, v) in factors.items() for _ in range(v)][::-1]
640+
641+
# Split in number of stars
642+
split = np.array_split(vals, nstars)
643+
644+
# Reduce
645+
star_vals = [int(np.prod(s)) for s in split]
647646

648647
# Apply computed star values to the processed
649648
for index, value in zip(star_pos, star_vals):

tests/test_mpi.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def test_custom_topology(self):
219219
(256, ('*', '*', 2), (16, 8, 2)),
220220
(256, ('*', 32, 2), (4, 32, 2)),
221221
])
222-
def test_custom_topology_3d_dummy(self, comm_size, topology, dist_topology):
222+
def test_custom_topology_v2(self, comm_size, topology, dist_topology):
223223
dummy_comm = Bunch(size=comm_size)
224224
custom_topology = CustomTopology(topology, dummy_comm)
225225
assert custom_topology == dist_topology

0 commit comments

Comments
 (0)