From 6e9a758b1877926f564d0683d13e5fe9bd57a8db Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 2 Mar 2022 09:44:17 -0600 Subject: [PATCH 1/6] use hash of comm_tag if not numeric --- grudge/trace_pair.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index 2b9cf14f7..4e27f07d8 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -497,9 +497,12 @@ def cross_rank_trace_pairs( num_tag = actx.comm_tag_to_mpi_tag.get(comm_tag) if num_tag is None: - raise ValueError("Encountered unknown symbolic tag " - f"'{comm_tag}'. To make this symbolic tag work, " - f"use 'grudge.array_context.MPIPyOpenCLArrayContext' and " + num_tag = hash(comm_tag) + from warnings import warn + warn("Encountered unknown symbolic tag " + f"'{comm_tag}', assigning a value of '{num_tag}'. " + "To use a different value, " + "use 'grudge.array_context.MPIPyOpenCLArrayContext' and " "assign this tag a numerical value via its " "comm_tag_to_mpi_tag attribute.") From 6b2ec0cd7bd82b0035af6e682c8bffc893074f36 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 2 Mar 2022 10:17:03 -0600 Subject: [PATCH 2/6] Drop comm_tag_to_mpi_tag --- examples/wave/wave-op-mpi.py | 5 +---- grudge/array_context.py | 19 +++---------------- grudge/trace_pair.py | 13 +++++-------- test/test_mpi_communication.py | 8 +++++--- test/testlib.py | 5 ----- 5 files changed, 14 insertions(+), 36 deletions(-) delete mode 100644 test/testlib.py diff --git a/examples/wave/wave-op-mpi.py b/examples/wave/wave-op-mpi.py index 7f9059515..ab6c41e5e 100644 --- a/examples/wave/wave-op-mpi.py +++ b/examples/wave/wave-op-mpi.py @@ -192,10 +192,7 @@ def main(ctx_factory, dim=2, order=3, else: actx = actx_class(comm, queue, allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)), - force_device_scalars=True, - comm_tag_to_mpi_tag={ - _WaveStateTag: 1234, - }) + force_device_scalars=True) from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis mesh_dist = MPIMeshDistributor(comm) diff --git a/grudge/array_context.py b/grudge/array_context.py index 7be890e67..675e47a29 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -32,8 +32,7 @@ # {{{ imports from typing import ( - TYPE_CHECKING, Mapping, Tuple, Any, Callable, Optional, - Hashable, Type) + TYPE_CHECKING, Mapping, Tuple, Any, Callable, Optional, Type) from dataclasses import dataclass from meshmode.array_context import ( @@ -240,16 +239,10 @@ def __init__(self, queue: "pyopencl.CommandQueue", *, allocator: Optional["pyopencl.tools.AllocatorInterface"] = None, wait_event_queue_length: Optional[int] = None, - force_device_scalars: bool = False, - comm_tag_to_mpi_tag: Optional[Mapping[Hashable, int]] = None) -> None: + force_device_scalars: bool = False) -> None: """ See :class:`arraycontext.impl.pyopencl.PyOpenCLArrayContext` for most arguments. - - :arg comm_tag_to_mpi_tag: A mapping from symbolic tags used - in the *comm_tag* argument of - :func:`grudge.trace_pair.cross_rank_trace_pairs` to numeric values - to be used with MPI. """ super().__init__(queue, allocator=allocator, wait_event_queue_length=wait_event_queue_length, @@ -257,19 +250,13 @@ def __init__(self, self.mpi_communicator = mpi_communicator - if comm_tag_to_mpi_tag is None: - comm_tag_to_mpi_tag = {} - - self.comm_tag_to_mpi_tag = comm_tag_to_mpi_tag - def clone(self): # type-ignore-reason: 'DistributedLazyArrayContext' has no 'queue' member # pylint: disable=no-member return type(self)(self.mpi_communicator, self.queue, allocator=self.allocator, wait_event_queue_length=self._wait_event_queue_length, - force_device_scalars=self._force_device_scalars, - comm_tag_to_mpi_tag=self.comm_tag_to_mpi_tag) + force_device_scalars=self._force_device_scalars) # }}} diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index 4e27f07d8..9275e7cf7 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -492,19 +492,16 @@ def cross_rank_trace_pairs( if isinstance(comm_tag, int): num_tag = comm_tag - from grudge.array_context import MPIPyOpenCLArrayContext - if isinstance(actx, MPIPyOpenCLArrayContext): - num_tag = actx.comm_tag_to_mpi_tag.get(comm_tag) - if num_tag is None: + # FIXME: This isn't guaranteed to be correct. + # See here for discussion: + # https://github.com/illinois-ceesd/mirgecom/issues/617#issuecomment-1057082716 num_tag = hash(comm_tag) from warnings import warn warn("Encountered unknown symbolic tag " f"'{comm_tag}', assigning a value of '{num_tag}'. " - "To use a different value, " - "use 'grudge.array_context.MPIPyOpenCLArrayContext' and " - "assign this tag a numerical value via its " - "comm_tag_to_mpi_tag attribute.") + "This is a temporary workaround, please ensure that " + "tags are sufficiently distinct for your use case.") comm_tag = num_tag diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py index 2dcf73339..82a6cc234 100644 --- a/test/test_mpi_communication.py +++ b/test/test_mpi_communication.py @@ -46,7 +46,10 @@ from pytools.obj_array import flat_obj_array import grudge.op as op -from testlib import SimpleTag + + +class SimpleTag: + pass # {{{ mpi test infrastructure @@ -86,8 +89,7 @@ def run_test_with_mpi_inner(): if actx_class is MPIPytatoArrayContext: actx = actx_class(comm, queue, mpi_base_tag=15000) elif actx_class is MPIPyOpenCLArrayContext: - actx = actx_class(comm, queue, force_device_scalars=True, - comm_tag_to_mpi_tag={SimpleTag: 15000}) + actx = actx_class(comm, queue, force_device_scalars=True) else: raise ValueError("unknown actx_class") diff --git a/test/testlib.py b/test/testlib.py deleted file mode 100644 index 85d8bb395..000000000 --- a/test/testlib.py +++ /dev/null @@ -1,5 +0,0 @@ -# Needed here because MPI test orchestration imports the test module twice, -# leading to two nominally different tag types. Grrr. - -class SimpleTag: - pass From 5041828d34b2343b1efc0597bd0889fd1a23247c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 2 Mar 2022 10:48:41 -0600 Subject: [PATCH 3/6] modulo tag_ub --- grudge/trace_pair.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index 9275e7cf7..b24d15eda 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -495,8 +495,11 @@ def cross_rank_trace_pairs( if num_tag is None: # FIXME: This isn't guaranteed to be correct. # See here for discussion: - # https://github.com/illinois-ceesd/mirgecom/issues/617#issuecomment-1057082716 - num_tag = hash(comm_tag) + # - https://github.com/illinois-ceesd/mirgecom/issues/617#issuecomment-1057082716 + # - https://github.com/inducer/grudge/pull/222 + tag_ub = MPI.COMM_WORLD.Get_attr(MPI.TAG_UB) + num_tag = hash(comm_tag) % tag_ub + from warnings import warn warn("Encountered unknown symbolic tag " f"'{comm_tag}', assigning a value of '{num_tag}'. " From 42b4fd323cb0451fb8aeb70b9a8b193edf2fc354 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 2 Mar 2022 10:54:44 -0600 Subject: [PATCH 4/6] flake8 --- grudge/trace_pair.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index b24d15eda..2f3372137 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -495,8 +495,9 @@ def cross_rank_trace_pairs( if num_tag is None: # FIXME: This isn't guaranteed to be correct. # See here for discussion: - # - https://github.com/illinois-ceesd/mirgecom/issues/617#issuecomment-1057082716 + # - https://github.com/illinois-ceesd/mirgecom/issues/617#issuecomment-1057082716 # noqa # - https://github.com/inducer/grudge/pull/222 + from mpi4py import MPI tag_ub = MPI.COMM_WORLD.Get_attr(MPI.TAG_UB) num_tag = hash(comm_tag) % tag_ub From 2caeff58bab30a3158ac8f9cd5239d08d7a940a8 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 2 Mar 2022 13:19:06 -0600 Subject: [PATCH 5/6] use actx mpi_comm MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Andreas Klöckner --- grudge/trace_pair.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index 2f3372137..e142a954e 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -498,7 +498,7 @@ def cross_rank_trace_pairs( # - https://github.com/illinois-ceesd/mirgecom/issues/617#issuecomment-1057082716 # noqa # - https://github.com/inducer/grudge/pull/222 from mpi4py import MPI - tag_ub = MPI.COMM_WORLD.Get_attr(MPI.TAG_UB) + tag_ub = actx.mpi_communicator.Get_attr(MPI.TAG_UB) num_tag = hash(comm_tag) % tag_ub from warnings import warn From 9199fe2e2075251b6c67a263de46ff9bf0f82a45 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 2 Mar 2022 14:13:16 -0600 Subject: [PATCH 6/6] User persistent_dict hashing to guess a tag --- grudge/trace_pair.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index e142a954e..3e6278492 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -46,7 +46,9 @@ """ -from typing import List, Hashable, Optional +from typing import List, Hashable, Optional, Type, Any + +from pytools.persistent_dict import KeyBuilder from arraycontext import ( ArrayContainer, @@ -432,6 +434,11 @@ def finish(self): exterior=bdry_conn(self.remote_data)) +class _TagKeyBuilder(KeyBuilder): + def update_for_type(self, key_hash, key: Type[Any]): + self.rec(key_hash, (key.__module__, key.__name__, key.__name__,)) + + def cross_rank_trace_pairs( dcoll: DiscretizationCollection, ary, comm_tag: Hashable = None, @@ -499,7 +506,9 @@ def cross_rank_trace_pairs( # - https://github.com/inducer/grudge/pull/222 from mpi4py import MPI tag_ub = actx.mpi_communicator.Get_attr(MPI.TAG_UB) - num_tag = hash(comm_tag) % tag_ub + key_builder = _TagKeyBuilder() + digest = key_builder(comm_tag) + num_tag = sum(ord(ch) << i for i, ch in enumerate(digest)) % tag_ub from warnings import warn warn("Encountered unknown symbolic tag "