Skip to content
This repository was archived by the owner on Nov 28, 2022. It is now read-only.

Commit

Permalink
Use fecr instead of fenics_numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanYashchuk committed Feb 28, 2021
1 parent 2a468eb commit 3cfb1a9
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
4 changes: 2 additions & 2 deletions fenics_pymc3/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .core import create_fenics_theano_op
from .core import create_fenics_theano_op, create_fem_theano_op
from .core import FenicsOp, FenicsVJPOp
from fenics_numpy import fenics_to_numpy, numpy_to_fenics
from fecr import to_numpy, from_numpy
23 changes: 14 additions & 9 deletions fenics_pymc3/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import functools

from fenics_numpy import evaluate_primal, evaluate_vjp
from fecr import evaluate_primal, evaluate_pullback


class FenicsVJPOp(Op):
Expand All @@ -30,14 +30,16 @@ def make_node(self, *inputs):
def perform(self, node, inputs, outputs, params):
Δfenics_output = inputs[0]
fenics_output, fenics_inputs, tape = params
numpy_grads = evaluate_vjp(Δfenics_output, fenics_output, fenics_inputs, tape)
numpy_grads = evaluate_pullback(
fenics_output, fenics_inputs, tape, Δfenics_output
)

theano_grads = (
theano.gradient.grad_undefined(self, i, inputs[i]) if ng is None else ng
for i, ng in enumerate(numpy_grads)
)

for i, tg in enumerate(numpy_grads):
for i, tg in enumerate(theano_grads):
outputs[i][0] = tg


Expand Down Expand Up @@ -72,17 +74,17 @@ def grad(self, inputs, output_grads):
return theano_grads


def create_fenics_theano_op(fenics_templates):
"""Return `f(*args) = create_fenics_theano_op(*args)(ofunc(*args))`.
Given the FEniCS-side function ofunc(*args), return the Theano Op,
def create_fem_theano_op(fenics_templates):
"""Return `f(*args) = create_fem_theano_op(*args)(ofunc(*args))`.
Given the FEniCS/Firedrake-side function ofunc(*args), return the Theano Op,
that is callable and differentiable in Theano programs,
`f(*args) = create_fenics_theano_op(*args)(ofunc(*args))` with
`f(*args) = create_fem_theano_op(*args)(ofunc(*args))` with
the VJP of `f`, where:
`*args` are all arguments to `ofunc`.
Args:
ofunc: The FEniCS-side function to be wrapped.
ofunc: The FEniCS/Firedrake-side function to be wrapped.
Returns:
`f(args) = create_fenics_theano_op(*args)(ofunc(*args))`
`f(args) = create_fem_theano_op(*args)(ofunc(*args))`
"""

def decorator(fenics_function):
Expand All @@ -96,3 +98,6 @@ def theano_fem_eval(*args):
return theano_fem_eval

return decorator


create_fenics_theano_op = create_fem_theano_op

0 comments on commit 3cfb1a9

Please sign in to comment.