diff --git a/examples/wave/wave-op-mpi.py b/examples/wave/wave-op-mpi.py index 7f9059515..ab6c41e5e 100644 --- a/examples/wave/wave-op-mpi.py +++ b/examples/wave/wave-op-mpi.py @@ -192,10 +192,7 @@ def main(ctx_factory, dim=2, order=3, else: actx = actx_class(comm, queue, allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)), - force_device_scalars=True, - comm_tag_to_mpi_tag={ - _WaveStateTag: 1234, - }) + force_device_scalars=True) from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis mesh_dist = MPIMeshDistributor(comm) diff --git a/grudge/array_context.py b/grudge/array_context.py index 7be890e67..675e47a29 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -32,8 +32,7 @@ # {{{ imports from typing import ( - TYPE_CHECKING, Mapping, Tuple, Any, Callable, Optional, - Hashable, Type) + TYPE_CHECKING, Mapping, Tuple, Any, Callable, Optional, Type) from dataclasses import dataclass from meshmode.array_context import ( @@ -240,16 +239,10 @@ def __init__(self, queue: "pyopencl.CommandQueue", *, allocator: Optional["pyopencl.tools.AllocatorInterface"] = None, wait_event_queue_length: Optional[int] = None, - force_device_scalars: bool = False, - comm_tag_to_mpi_tag: Optional[Mapping[Hashable, int]] = None) -> None: + force_device_scalars: bool = False) -> None: """ See :class:`arraycontext.impl.pyopencl.PyOpenCLArrayContext` for most arguments. - - :arg comm_tag_to_mpi_tag: A mapping from symbolic tags used - in the *comm_tag* argument of - :func:`grudge.trace_pair.cross_rank_trace_pairs` to numeric values - to be used with MPI. """ super().__init__(queue, allocator=allocator, wait_event_queue_length=wait_event_queue_length, @@ -257,19 +250,13 @@ def __init__(self, self.mpi_communicator = mpi_communicator - if comm_tag_to_mpi_tag is None: - comm_tag_to_mpi_tag = {} - - self.comm_tag_to_mpi_tag = comm_tag_to_mpi_tag - def clone(self): # type-ignore-reason: 'DistributedLazyArrayContext' has no 'queue' member # pylint: disable=no-member return type(self)(self.mpi_communicator, self.queue, allocator=self.allocator, wait_event_queue_length=self._wait_event_queue_length, - force_device_scalars=self._force_device_scalars, - comm_tag_to_mpi_tag=self.comm_tag_to_mpi_tag) + force_device_scalars=self._force_device_scalars) # }}} diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index 2b9cf14f7..3e6278492 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -46,7 +46,9 @@ """ -from typing import List, Hashable, Optional +from typing import List, Hashable, Optional, Type, Any + +from pytools.persistent_dict import KeyBuilder from arraycontext import ( ArrayContainer, @@ -432,6 +434,11 @@ def finish(self): exterior=bdry_conn(self.remote_data)) +class _TagKeyBuilder(KeyBuilder): + def update_for_type(self, key_hash, key: Type[Any]): + self.rec(key_hash, (key.__module__, key.__name__, key.__name__,)) + + def cross_rank_trace_pairs( dcoll: DiscretizationCollection, ary, comm_tag: Hashable = None, @@ -492,16 +499,22 @@ def cross_rank_trace_pairs( if isinstance(comm_tag, int): num_tag = comm_tag - from grudge.array_context import MPIPyOpenCLArrayContext - if isinstance(actx, MPIPyOpenCLArrayContext): - num_tag = actx.comm_tag_to_mpi_tag.get(comm_tag) - if num_tag is None: - raise ValueError("Encountered unknown symbolic tag " - f"'{comm_tag}'. To make this symbolic tag work, " - f"use 'grudge.array_context.MPIPyOpenCLArrayContext' and " - "assign this tag a numerical value via its " - "comm_tag_to_mpi_tag attribute.") + # FIXME: This isn't guaranteed to be correct. + # See here for discussion: + # - https://github.com/illinois-ceesd/mirgecom/issues/617#issuecomment-1057082716 # noqa + # - https://github.com/inducer/grudge/pull/222 + from mpi4py import MPI + tag_ub = actx.mpi_communicator.Get_attr(MPI.TAG_UB) + key_builder = _TagKeyBuilder() + digest = key_builder(comm_tag) + num_tag = sum(ord(ch) << i for i, ch in enumerate(digest)) % tag_ub + + from warnings import warn + warn("Encountered unknown symbolic tag " + f"'{comm_tag}', assigning a value of '{num_tag}'. " + "This is a temporary workaround, please ensure that " + "tags are sufficiently distinct for your use case.") comm_tag = num_tag diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py index 2dcf73339..82a6cc234 100644 --- a/test/test_mpi_communication.py +++ b/test/test_mpi_communication.py @@ -46,7 +46,10 @@ from pytools.obj_array import flat_obj_array import grudge.op as op -from testlib import SimpleTag + + +class SimpleTag: + pass # {{{ mpi test infrastructure @@ -86,8 +89,7 @@ def run_test_with_mpi_inner(): if actx_class is MPIPytatoArrayContext: actx = actx_class(comm, queue, mpi_base_tag=15000) elif actx_class is MPIPyOpenCLArrayContext: - actx = actx_class(comm, queue, force_device_scalars=True, - comm_tag_to_mpi_tag={SimpleTag: 15000}) + actx = actx_class(comm, queue, force_device_scalars=True) else: raise ValueError("unknown actx_class") diff --git a/test/testlib.py b/test/testlib.py deleted file mode 100644 index 85d8bb395..000000000 --- a/test/testlib.py +++ /dev/null @@ -1,5 +0,0 @@ -# Needed here because MPI test orchestration imports the test module twice, -# leading to two nominally different tag types. Grrr. - -class SimpleTag: - pass