Skip to content

Commit 04b1365

Browse files
committed
Drop comm_tag_to_mpi_tag
1 parent 6e9a758 commit 04b1365

File tree

5 files changed

+13
-34
lines changed

5 files changed

+13
-34
lines changed

examples/wave/wave-op-mpi.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,7 @@ def main(ctx_factory, dim=2, order=3,
192192
else:
193193
actx = actx_class(comm, queue,
194194
allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)),
195-
force_device_scalars=True,
196-
comm_tag_to_mpi_tag={
197-
_WaveStateTag: 1234,
198-
})
195+
force_device_scalars=True)
199196

200197
from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis
201198
mesh_dist = MPIMeshDistributor(comm)

grudge/array_context.py

+2-14
Original file line numberDiff line numberDiff line change
@@ -240,36 +240,24 @@ def __init__(self,
240240
queue: "pyopencl.CommandQueue",
241241
*, allocator: Optional["pyopencl.tools.AllocatorInterface"] = None,
242242
wait_event_queue_length: Optional[int] = None,
243-
force_device_scalars: bool = False,
244-
comm_tag_to_mpi_tag: Optional[Mapping[Hashable, int]] = None) -> None:
243+
force_device_scalars: bool = False) -> None:
245244
"""
246245
See :class:`arraycontext.impl.pyopencl.PyOpenCLArrayContext` for most
247246
arguments.
248-
249-
:arg comm_tag_to_mpi_tag: A mapping from symbolic tags used
250-
in the *comm_tag* argument of
251-
:func:`grudge.trace_pair.cross_rank_trace_pairs` to numeric values
252-
to be used with MPI.
253247
"""
254248
super().__init__(queue, allocator=allocator,
255249
wait_event_queue_length=wait_event_queue_length,
256250
force_device_scalars=force_device_scalars)
257251

258252
self.mpi_communicator = mpi_communicator
259253

260-
if comm_tag_to_mpi_tag is None:
261-
comm_tag_to_mpi_tag = {}
262-
263-
self.comm_tag_to_mpi_tag = comm_tag_to_mpi_tag
264-
265254
def clone(self):
266255
# type-ignore-reason: 'DistributedLazyArrayContext' has no 'queue' member
267256
# pylint: disable=no-member
268257
return type(self)(self.mpi_communicator, self.queue,
269258
allocator=self.allocator,
270259
wait_event_queue_length=self._wait_event_queue_length,
271-
force_device_scalars=self._force_device_scalars,
272-
comm_tag_to_mpi_tag=self.comm_tag_to_mpi_tag)
260+
force_device_scalars=self._force_device_scalars)
273261

274262
# }}}
275263

grudge/trace_pair.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -492,19 +492,16 @@ def cross_rank_trace_pairs(
492492
if isinstance(comm_tag, int):
493493
num_tag = comm_tag
494494

495-
from grudge.array_context import MPIPyOpenCLArrayContext
496-
if isinstance(actx, MPIPyOpenCLArrayContext):
497-
num_tag = actx.comm_tag_to_mpi_tag.get(comm_tag)
498-
499495
if num_tag is None:
496+
# FIXME: This isn't guaranteed to be correct.
497+
# See here for discussion:
498+
# https://github.com/illinois-ceesd/mirgecom/issues/617#issuecomment-1057082716
500499
num_tag = hash(comm_tag)
501500
from warnings import warn
502501
warn("Encountered unknown symbolic tag "
503502
f"'{comm_tag}', assigning a value of '{num_tag}'. "
504-
"To use a different value, "
505-
"use 'grudge.array_context.MPIPyOpenCLArrayContext' and "
506-
"assign this tag a numerical value via its "
507-
"comm_tag_to_mpi_tag attribute.")
503+
"This is a temporary workaround, please ensure that "
504+
"tags are sufficiently distinct for your use case.")
508505

509506
comm_tag = num_tag
510507

test/test_mpi_communication.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@
4646
from pytools.obj_array import flat_obj_array
4747

4848
import grudge.op as op
49-
from testlib import SimpleTag
49+
50+
51+
class SimpleTag:
52+
pass
5053

5154

5255
# {{{ mpi test infrastructure
@@ -86,8 +89,7 @@ def run_test_with_mpi_inner():
8689
if actx_class is MPIPytatoArrayContext:
8790
actx = actx_class(comm, queue, mpi_base_tag=15000)
8891
elif actx_class is MPIPyOpenCLArrayContext:
89-
actx = actx_class(comm, queue, force_device_scalars=True,
90-
comm_tag_to_mpi_tag={SimpleTag: 15000})
92+
actx = actx_class(comm, queue, force_device_scalars=True)
9193
else:
9294
raise ValueError("unknown actx_class")
9395

test/testlib.py

-5
This file was deleted.

0 commit comments

Comments
 (0)