Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Patch symbolic coefficients over cross derivatives #2248

Merged
merged 5 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion devito/finite_differences/coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,33 @@ def generate_subs(deriv_order, function, index):
# NOTE: Do we want to throw a warning if the same arg has
# been provided twice?
args_provided = list(set(args_provided))
not_provided = [i for i in args_present if i not in frozenset(args_provided)]

rules = {}
not_provided = []
for i0 in args_present:
if any(i0 == i1 for i1 in args_provided):
# Perfect match, as expected by the legacy custom coeffs API
continue

# TODO: to make cross-derivs work, we must relax `not_provided` by
# checking not for equality, but rather for inclusion. This is ugly,
# but basically a major revamp is the only alternative... and for now,
# it does the trick
mapper = {}
deriv_order, expr, dim = i0
try:
for k, v in subs.rules.items():
ofs, do, f, d = k.args
if deriv_order == do and dim is d and f in expr._functions:
mapper[k.func(ofs, do, expr, d)] = v
except AttributeError:
assert subs is None

if mapper:
rules.update(mapper)
else:
not_provided.append(i0)

for i in not_provided:
rules = {**rules, **generate_subs(*i)}

Expand Down
23 changes: 16 additions & 7 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@ def _symbolic_functions(self):
def _uses_symbolic_coefficients(self):
return bool(self._symbolic_functions)

@cached_property
def _coeff_symbol(self, *args, **kwargs):
if self._uses_symbolic_coefficients:
return W
else:
raise ValueError("Couldn't find any symbolic coefficients")

def _eval_at(self, func):
if not func.is_Staggered:
# Cartesian grid, do no waste time
Expand Down Expand Up @@ -327,6 +334,10 @@ def highest_priority(DiffOp):
return sorted(DiffOp._args_diff, key=prio, reverse=True)[0]


# Abstract symbol representing a symbolic coefficient
W = sympy.Function('W')


class DifferentiableOp(Differentiable):

__sympy_class__ = None
Expand Down Expand Up @@ -606,12 +617,13 @@ def __init_finalize__(self, *args, **kwargs):
assert isinstance(d, StencilDimension) and d.symbolic_size == len(weights)
assert isinstance(weights, (list, tuple, np.ndarray))

try:
self._spacings = set().union(*[i.find(Spacing) for i in weights])
except AttributeError:
self._spacing = set()
# Normalize `weights`
weights = tuple(sympy.sympify(i) for i in weights)

self._spacings = set().union(*[i.find(Spacing) for i in weights])

kwargs['scope'] = 'constant'
kwargs['initvalue'] = weights

super().__init_finalize__(*args, **kwargs)

Expand Down Expand Up @@ -766,9 +778,6 @@ def _new_rawargs(self, *args, **kwargs):
kwargs.pop('is_commutative', None)
return self.func(*args, **kwargs)

def _coeff_symbol(self, *args, **kwargs):
return self.base._coeff_symbol(*args, **kwargs)


class diffify(object):

Expand Down
8 changes: 0 additions & 8 deletions devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,6 @@ def coefficients(self):
"""Form of the coefficients of the function."""
return self._coefficients

@cached_property
def _coeff_symbol(self):
if self.coefficients == 'symbolic':
return sympy.Function('W')
else:
raise ValueError("Function was not declared with symbolic "
"coefficients.")

@cached_property
def shape(self):
"""
Expand Down
70 changes: 70 additions & 0 deletions tests/test_symbolic_coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def test_aggregate_w_custom_coeffs(self):

def test_cross_derivs(self):
grid = Grid(shape=(11, 11, 11))

q = TimeFunction(name='q', grid=grid, space_order=8, time_order=2,
coefficients='symbolic')
q0 = TimeFunction(name='q', grid=grid, space_order=8, time_order=2)
Expand All @@ -389,3 +390,72 @@ def test_cross_derivs(self):

assert(eq0.evaluate.evalf(_PRECISION).__repr__() ==
eq1.evaluate.evalf(_PRECISION).__repr__())

def test_cross_derivs_imperfect(self):
grid = Grid(shape=(11, 11, 11))

p = TimeFunction(name='p', grid=grid, space_order=4, time_order=2,
coefficients='symbolic')
q = TimeFunction(name='q', grid=grid, space_order=4, time_order=2,
coefficients='symbolic')

p0 = TimeFunction(name='p', grid=grid, space_order=4, time_order=2)
q0 = TimeFunction(name='q', grid=grid, space_order=4, time_order=2)

eq0 = Eq(q0.forward, (q0.dx + p0.dx).dy)
eq1 = Eq(q.forward, (q.dx + p.dx).dy)

assert(eq0.evaluate.evalf(_PRECISION).__repr__() ==
eq1.evaluate.evalf(_PRECISION).__repr__())

def test_nested_subs(self):
grid = Grid(shape=(11, 11))
x, y = grid.dimensions
hx, hy = grid.spacing_symbols

p = TimeFunction(name='p', grid=grid, space_order=2,
coefficients='symbolic')

coeffs0 = np.array([100, 100, 100])
coeffs1 = np.array([200, 200, 200])

subs = Substitutions(Coefficient(1, p, x, coeffs0),
Coefficient(1, p, y, coeffs1))

eq = Eq(p.forward, p.dx.dy, coefficients=subs)

mul = lambda e: sp.Mul(e, 200, evaluate=False)
term0 = mul(p*100 +
p.subs(x, x-hx)*100 +
p.subs(x, x+hx)*100)
term1 = mul(p.subs(y, y-hy)*100 +
p.subs({x: x-hx, y: y-hy})*100 +
p.subs({x: x+hx, y: y-hy})*100)
term2 = mul(p.subs(y, y+hy)*100 +
p.subs({x: x-hx, y: y+hy})*100 +
p.subs({x: x+hx, y: y+hy})*100)

# `str` simply because some objects are of type EvalDerivative
assert str(eq.evaluate.rhs) == str(term0 + term1 + term2)

def test_compound_subs(self):
grid = Grid(shape=(11,))
x, = grid.dimensions
hx, = grid.spacing_symbols

f = Function(name='f', grid=grid, space_order=2)
p = TimeFunction(name='p', grid=grid, space_order=2,
coefficients='symbolic')

coeffs0 = np.array([100, 100, 100])

subs = Substitutions(Coefficient(1, p, x, coeffs0))

eq = Eq(p.forward, (f*p).dx, coefficients=subs)

term0 = f*p*100
term1 = (f*p*100).subs(x, x-hx)
term2 = (f*p*100).subs(x, x+hx)

# `str` simply because some objects are of type EvalDerivative
assert str(eq.evaluate.rhs) == str(term0 + term1 + term2)
16 changes: 13 additions & 3 deletions tests/test_unexpansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_backward_dt2(self):
assert_structure(op, ['t,x,y'], 't,x,y')


class TestSymbolicCoefficients(object):
class TestSymbolicCoeffs(object):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think we should also test the case where different coefficients are passed for x and y derivatives with cross-derivatives, mixed derivatives, and non-perfect mixed derivatives. Also the case where only one of the x or y derivatives gets coefficients specified but bother are defined as symbolic.


def test_fallback_to_default(self):
grid = Grid(shape=(8, 8, 8))
Expand All @@ -40,12 +40,22 @@ def test_fallback_to_default(self):

def test_numeric_coeffs(self):
grid = Grid(shape=(11,), extent=(10.,))

u = Function(name='u', grid=grid, coefficients='symbolic', space_order=2)
v = Function(name='v', grid=grid, coefficients='symbolic', space_order=2)

coeffs = Substitutions(Coefficient(2, u, grid.dimensions[0], np.zeros(3)))

op = Operator(Eq(u, u.dx2, coefficients=coeffs), opt=({'expand': False},))
op.cfunction
opt = ('advanced', {'expand': False})

# Pure derivative
Operator(Eq(u, u.dx2, coefficients=coeffs), opt=opt).cfunction

# Mixed derivative
Operator(Eq(u, u.dx.dx, coefficients=coeffs), opt=opt).cfunction

# Non-perfect mixed derivative
Operator(Eq(u, (u.dx + v.dx).dx, coefficients=coeffs), opt=opt).cfunction
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to sanity check these outputs. Print them and you will see that no errors are raised currently, but the resultant stencils are incorrect



class Test1Pass(object):
Expand Down