Skip to content

Commit

Permalink
Merge pull request #2143 from devitocodes/compute0_special
Browse files Browse the repository at this point in the history
mpi: Instrument compute0 core after specialising as ComputeCall
  • Loading branch information
mloubout authored Jun 28, 2023
2 parents 772243b + cf1ba20 commit 3d94984
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 50 deletions.
86 changes: 44 additions & 42 deletions devito/ir/iet/efunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,42 @@

# ElementalFunction machinery

class ElementalCall(Call):

def __init__(self, name, arguments=None, mapper=None, dynamic_args_mapper=None,
incr=False, retobj=None, is_indirect=False):
self._mapper = mapper or {}

arguments = list(as_tuple(arguments))
dynamic_args_mapper = dynamic_args_mapper or {}
for k, v in dynamic_args_mapper.items():
tv = as_tuple(v)

# Sanity check
if k not in self._mapper:
raise ValueError("`k` is not a dynamic parameter" % k)
if len(self._mapper[k]) != len(tv):
raise ValueError("Expected %d values for dynamic parameter `%s`, given %d"
% (len(self._mapper[k]), k, len(tv)))
# Create the argument list
for i, j in zip(self._mapper[k], tv):
arguments[i] = j if incr is False else (arguments[i] + j)

super(ElementalCall, self).__init__(name, arguments, retobj, is_indirect)

def _rebuild(self, *args, dynamic_args_mapper=None, incr=False,
retobj=None, **kwargs):
# This guarantees that `ec._rebuild(arguments=ec.arguments) == ec`
return super(ElementalCall, self)._rebuild(
*args, dynamic_args_mapper=dynamic_args_mapper, incr=incr,
retobj=retobj, **kwargs
)

@cached_property
def dynamic_defaults(self):
return {k: tuple(self.arguments[i] for i in v) for k, v in self._mapper.items()}


class ElementalFunction(Callable):

"""
Expand All @@ -21,6 +57,7 @@ class ElementalFunction(Callable):
supplying bounds and step increment for each Dimension listed in
``dynamic_parameters``.
"""
_Call_cls = ElementalCall

is_ElementalFunction = True

Expand All @@ -47,53 +84,18 @@ def dynamic_defaults(self):

def make_call(self, dynamic_args_mapper=None, incr=False, retobj=None,
is_indirect=False):
return ElementalCall(self.name, list(self.parameters), dict(self._mapper),
dynamic_args_mapper, incr, retobj, is_indirect)


class ElementalCall(Call):

def __init__(self, name, arguments=None, mapper=None, dynamic_args_mapper=None,
incr=False, retobj=None, is_indirect=False):
self._mapper = mapper or {}

arguments = list(as_tuple(arguments))
dynamic_args_mapper = dynamic_args_mapper or {}
for k, v in dynamic_args_mapper.items():
tv = as_tuple(v)

# Sanity check
if k not in self._mapper:
raise ValueError("`k` is not a dynamic parameter" % k)
if len(self._mapper[k]) != len(tv):
raise ValueError("Expected %d values for dynamic parameter `%s`, given %d"
% (len(self._mapper[k]), k, len(tv)))
# Create the argument list
for i, j in zip(self._mapper[k], tv):
arguments[i] = j if incr is False else (arguments[i] + j)

super(ElementalCall, self).__init__(name, arguments, retobj, is_indirect)

def _rebuild(self, *args, dynamic_args_mapper=None, incr=False,
retobj=None, **kwargs):
# This guarantees that `ec._rebuild(arguments=ec.arguments) == ec`
return super(ElementalCall, self)._rebuild(
*args, dynamic_args_mapper=dynamic_args_mapper, incr=incr,
retobj=retobj, **kwargs
)

@cached_property
def dynamic_defaults(self):
return {k: tuple(self.arguments[i] for i in v) for k, v in self._mapper.items()}
return self._Call_cls(self.name, list(self.parameters), dict(self._mapper),
dynamic_args_mapper, incr, retobj, is_indirect)


def make_efunc(name, iet, dynamic_parameters=None, retval='void', prefix='static'):
def make_efunc(name, iet, dynamic_parameters=None, retval='void', prefix='static',
efunc_type=ElementalFunction):
"""
Shortcut to create an ElementalFunction.
"""
return ElementalFunction(name, iet, retval=retval,
parameters=derive_parameters(iet), prefix=prefix,
dynamic_parameters=dynamic_parameters)
return efunc_type(name, iet, retval=retval,
parameters=derive_parameters(iet), prefix=prefix,
dynamic_parameters=dynamic_parameters)


# Callable machinery
Expand Down
16 changes: 13 additions & 3 deletions devito/mpi/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from devito.ir.iet import (Call, Callable, Conditional, ElementalFunction,
Expression, ExpressionBundle, AugmentedExpression,
Iteration, List, Prodder, Return, make_efunc, FindNodes,
Transformer)
Transformer, ElementalCall)
from devito.mpi import MPI
from devito.symbolics import (Byref, CondNe, FieldFromPointer, FieldFromComposite,
IndexedPointer, Macro, cast_mapper, subs_op_args)
Expand Down Expand Up @@ -572,6 +572,14 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
return HaloUpdate('haloupdate%s' % key, iet, parameters)


class ComputeCall(ElementalCall):
pass


class ComputeFunction(ElementalFunction):
_Call_cls = ComputeCall


class OverlapHaloExchangeBuilder(DiagHaloExchangeBuilder):

"""
Expand Down Expand Up @@ -647,7 +655,8 @@ def _make_compute(self, hs, key, *args):
if hs.body.is_Call:
return None
else:
return make_efunc('compute%d' % key, hs.body, hs.arguments)
return make_efunc('compute%d' % key, hs.body, hs.arguments,
efunc_type=ComputeFunction)

def _call_compute(self, hs, compute, *args):
if compute is None:
Expand Down Expand Up @@ -952,7 +961,8 @@ def _make_compute(self, hs, key, msgs, callpoke):
mapper = {i: List(body=[callpoke, i]) for i in
FindNodes(ExpressionBundle).visit(hs.body)}
iet = Transformer(mapper).visit(hs.body)
return make_efunc('compute%d' % key, iet, hs.arguments)
return make_efunc('compute%d' % key, iet, hs.arguments,
efunc_type=ComputeFunction)

def _make_poke(self, hs, key, msgs):
lflag = Symbol(name='lflag')
Expand Down
4 changes: 2 additions & 2 deletions devito/operator/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from devito.ir.support import IntervalGroup
from devito.logger import warning, error
from devito.mpi import MPI
from devito.mpi.routines import MPICall, MPIList, RemainderCall
from devito.mpi.routines import MPICall, MPIList, RemainderCall, ComputeCall
from devito.parameters import configuration
from devito.symbolics import subs_op_args
from devito.tools import DefaultOrderedDict, flatten
Expand Down Expand Up @@ -332,7 +332,7 @@ class AdvancedProfilerVerbose2(AdvancedProfilerVerbose):

@property
def trackable_subsections(self):
return (MPICall, BusyWait)
return (MPICall, BusyWait, ComputeCall)


class AdvisorProfiler(AdvancedProfiler):
Expand Down
6 changes: 4 additions & 2 deletions devito/passes/iet/instrument.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from devito.ir.iet import (BusyWait, FindNodes, FindSymbols, MapNodes, Section,
TimedList, Transformer)
from devito.mpi.routines import (HaloUpdateCall, HaloWaitCall, MPICall, MPIList,
HaloUpdateList, HaloWaitList, RemainderCall)
HaloUpdateList, HaloWaitList, RemainderCall,
ComputeCall)
from devito.passes.iet.engine import iet_pass
from devito.types import Timer

Expand Down Expand Up @@ -36,14 +37,15 @@ def track_subsections(iet, **kwargs):
HaloUpdateCall: 'haloupdate',
HaloWaitCall: 'halowait',
RemainderCall: 'remainder',
ComputeCall: 'compute',
HaloUpdateList: 'haloupdate',
HaloWaitList: 'halowait',
BusyWait: 'busywait'
}

mapper = {}

for NodeType in [MPIList, MPICall, BusyWait]:
for NodeType in [MPIList, MPICall, BusyWait, ComputeCall]:
for k, v in MapNodes(Section, NodeType).visit(iet).items():
for i in v:
if i in mapper or not any(issubclass(i.__class__, n)
Expand Down
21 changes: 20 additions & 1 deletion tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
from devito.ir.iet import (Call, Conditional, Iteration, FindNodes, FindSymbols,
retrieve_iteration_tree)
from devito.mpi import MPI
from devito.mpi.routines import HaloUpdateCall, HaloUpdateList, MPICall
from devito.mpi.routines import HaloUpdateCall, HaloUpdateList, MPICall, ComputeCall
from devito.mpi.distributed import CustomTopology
from devito.tools import Bunch

from examples.seismic.acoustic import acoustic_setup

pytestmark = skipif(['nompi'], whole_module=True)
Expand Down Expand Up @@ -1400,6 +1401,7 @@ def test_min_code_size(self):
assert len(op._func_table) == 7
assert len(calls) == 4
assert 'haloupdate1' not in op._func_table
assert len(FindNodes(ComputeCall).visit(op)) == 1

@pytest.mark.parallel(mode=[(1, 'diag2')])
def test_many_functions(self):
Expand All @@ -1418,6 +1420,23 @@ def test_many_functions(self):
assert len(calls) == 2
assert calls[0].ncomps == 7

@switchconfig(profiling='advanced2')
@pytest.mark.parallel(mode=[
(1, 'full'),
])
def test_profiled_regions(self):
grid = Grid(shape=(10, 10, 10))

f = TimeFunction(name='f', grid=grid, space_order=2)
g = TimeFunction(name='g', grid=grid, space_order=2)

eqns = [Eq(f.forward, f.dx2 + 1.),
Eq(g.forward, g.dx2 + 1.)]

op = Operator(eqns)
assert op._profiler.all_sections == ['section0', 'haloupdate0', 'halowait0',
'remainder0', 'compute0']

@pytest.mark.parallel(mode=1)
def test_enforce_haloupdate_if_unwritten_function(self):
grid = Grid(shape=(16, 16))
Expand Down

0 comments on commit 3d94984

Please sign in to comment.