Skip to content

Commit

Permalink
api: enforce interpolation radius to be smaller than any input space …
Browse files Browse the repository at this point in the history
…order
  • Loading branch information
mloubout committed Oct 12, 2023
1 parent 617ce82 commit df1ea3c
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 6 deletions.
17 changes: 16 additions & 1 deletion devito/operations/interpolators.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from abc import ABC, abstractmethod
from functools import wraps

import sympy
from cached_property import cached_property

from devito.finite_differences.differentiable import Mul
from devito.finite_differences.elementary import floor
from devito.symbolics import retrieve_function_carriers, INT
from devito.symbolics import retrieve_function_carriers, retrieve_functions, INT
from devito.tools import as_tuple, flatten
from devito.types import (ConditionalDimension, Eq, Inc, Evaluable, Symbol,
CustomDimension)
Expand All @@ -14,6 +15,18 @@
__all__ = ['LinearInterpolator', 'PrecomputedInterpolator']


def check_radius(func):
@wraps(func)
def wrapper(interp, *args, **kwargs):
r = interp.sfunction.r
funcs = set(retrieve_functions(args)) - {interp.sfunction}
so = min({f.space_order for f in funcs})
if so < r:
raise ValueError("Space order %d smaller than interpolation r %d" % (so, r))
return func(interp, *args, **kwargs)
return wrapper


class UnevaluatedSparseOperation(sympy.Expr, Evaluable):

"""
Expand Down Expand Up @@ -209,6 +222,7 @@ def _interp_idx(self, variables, implicit_dims=None):

return idx_subs, temps

@check_radius
def interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None):
"""
Generate equations interpolating an arbitrary expression into ``self``.
Expand All @@ -226,6 +240,7 @@ def interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None):
"""
return Interpolation(expr, increment, implicit_dims, self_subs, self)

@check_radius
def inject(self, field, expr, implicit_dims=None):
"""
Generate equations injecting an arbitrary expression into a field.
Expand Down
13 changes: 13 additions & 0 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,3 +761,16 @@ def test_inject_function():
for i in [0, 1, 3, 4]:
for j in [0, 1, 3, 4]:
assert u.data[1, i, j] == 0


def test_interpolation_radius():
nt = 11

grid = Grid(shape=(5, 5))
u = TimeFunction(name="u", grid=grid, space_order=0)
src = SparseTimeFunction(name="src", grid=grid, nt=nt, npoint=1)
try:
src.interpolate(u)
assert False
except ValueError:
assert True
10 changes: 5 additions & 5 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def test_sparsefunction_inject(self):
Test injection of a SparseFunction into a Function
"""
grid = Grid(shape=(11, 11))
u = Function(name='u', grid=grid, space_order=0)
u = Function(name='u', grid=grid, space_order=1)

sf1 = SparseFunction(name='s', grid=grid, npoint=1)
op = Operator(sf1.inject(u, expr=sf1))
Expand All @@ -542,7 +542,7 @@ def test_sparsefunction_interp(self):
Test interpolation of a SparseFunction from a Function
"""
grid = Grid(shape=(11, 11))
u = Function(name='u', grid=grid, space_order=0)
u = Function(name='u', grid=grid, space_order=1)

sf1 = SparseFunction(name='s', grid=grid, npoint=1)
op = Operator(sf1.interpolate(u))
Expand All @@ -563,7 +563,7 @@ def test_sparsetimefunction_interp(self):
Test injection of a SparseTimeFunction into a TimeFunction
"""
grid = Grid(shape=(11, 11))
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=0)
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=1)

sf1 = SparseTimeFunction(name='s', grid=grid, npoint=1, nt=5)
op = Operator(sf1.interpolate(u))
Expand All @@ -586,7 +586,7 @@ def test_sparsetimefunction_inject(self):
Test injection of a SparseTimeFunction from a TimeFunction
"""
grid = Grid(shape=(11, 11))
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=0)
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=1)

sf1 = SparseTimeFunction(name='s', grid=grid, npoint=1, nt=5)
op = Operator(sf1.inject(u, expr=3*sf1))
Expand All @@ -611,7 +611,7 @@ def test_sparsetimefunction_inject_dt(self):
Test injection of the time deivative of a SparseTimeFunction into a TimeFunction
"""
grid = Grid(shape=(11, 11))
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=0)
u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=1)

sf1 = SparseTimeFunction(name='s', grid=grid, npoint=1, nt=5, time_order=2)

Expand Down

0 comments on commit df1ea3c

Please sign in to comment.