Add support for computing gradients using forward-mode autodiff
lululxvi committed Dec 16, 2023
1 parent c69cd34 commit b9adcb6
Showing 3 changed files with 180 additions and 1 deletion.
1 change: 1 addition & 0 deletions deepxde/gradients/__init__.py
@@ -1,3 +1,4 @@
__all__ = ["jacobian", "hessian"]

# from .gradients_forward import jacobian
from .gradients_reverse import clear, jacobian, hessian
178 changes: 178 additions & 0 deletions deepxde/gradients/gradients_forward.py
@@ -0,0 +1,178 @@
"""Compute gradients using forward-mode autodiff."""

__all__ = ["jacobian", "hessian"]

from ..backend import backend_name, jax


class Jacobian:
"""Compute Jacobian matrix J: J[i][j] = dy_i/dx_j, where i = 0, ..., dim_y-1 and
j = 0, ..., dim_x - 1.
It is lazy evaluation, i.e., it only computes J[i][j] when needed.
Args:
ys: Output Tensor of shape (batch_size, dim_y).
xs: Input Tensor of shape (batch_size, dim_x).
"""

def __init__(self, ys, xs):
self.ys = ys
self.xs = xs

if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "paddle"]:
            # TODO: Support other backends.
            raise NotImplementedError(
                f"Backend {backend_name} doesn't support forward-mode autodiff."
            )
elif backend_name == "jax":
            # For backend jax, a tuple of a jax array and a callable is passed as
            # one of the arguments, since jax does not expose an explicit
            # computational graph. The array is used to determine the dimensions,
            # and the callable is used to build the derivative function for
            # computing the derivatives.
self.dim_y = ys[0].shape[1]
self.dim_x = xs.shape[1]

self.J = {}
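        # Illustrative note (names below are examples, not part of the API):
        # with the jax backend, for a pointwise function f: (dim_x,) -> (dim_y,),
        # the expected argument is
        #     ys = (jax.vmap(f)(xs), f)
        # i.e., ys[0] carries the batched outputs and ys[1] the callable.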

def __call__(self, i=0, j=None):
"""Returns J[`i`][`j`]. If `j` is ``None``, returns the gradient of y_i, i.e.,
J[i].
"""
if not 0 <= i < self.dim_y:
raise ValueError("i={} is not valid.".format(i))
if j is not None and not 0 <= j < self.dim_x:
raise ValueError("j={} is not valid.".format(j))
        # Computing the full gradient of y_i is not supported in forward mode,
        # unless there is only one input.
        if j is None and self.dim_x > 1:
            raise NotImplementedError(
                "Forward-mode autodiff doesn't support computing the gradient."
            )
# Compute J[:, j]
if j not in self.J:
if backend_name == "jax":
                # Here, we use jax.jvp to compute the derivative of the function.
                # Unlike TensorFlow and PyTorch, the input of the function is not
                # a batch but a single point. Formally, backend jax computes
                # derivatives pointwise and then vectorizes over the batch via
                # jax.vmap; computationally, however, this is done batchwise and
                # efficiently. Note that, without jax.vmap, this can only handle
                # a single input point rather than a batch. The tangent below is
                # the j-th standard basis vector, so jax.jvp returns the j-th
                # Jacobian column dy/dx_j.
tangent = jax.numpy.zeros(self.dim_x).at[j].set(1)
grad_fn = lambda x: jax.jvp(self.ys[1], (x,), (tangent,))[1]
self.J[j] = (jax.vmap(grad_fn)(self.xs), grad_fn)

if backend_name == "jax":
            # Unlike other backends, for backend jax a tuple of a jax array and a
            # callable is returned, so that it is consistent with the argument,
            # which is also a tuple. This may be useful for further computation,
            # e.g., the Hessian.
return (
                self.J[j]
if self.dim_y == 1
else (
self.J[j][0][:, i : i + 1],
lambda inputs: self.J[j][1](inputs)[i : i + 1],
)
)
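# Illustrative only: the jvp-plus-vmap pattern above can be checked in
# isolation with plain jax; f, xs, and j below are arbitrary examples.
#
#     import jax
#     import jax.numpy as jnp
#
#     def f(x):  # one point of shape (dim_x,) -> output of shape (dim_y,)
#         return jnp.stack([x[0] * x[1], jnp.sin(x[0])])
#
#     xs = jnp.ones((5, 2))                  # batch of 5 points, dim_x = 2
#     j = 1                                  # Jacobian column to compute
#     tangent = jnp.zeros(2).at[j].set(1.0)  # j-th standard basis vector
#     grad_fn = lambda x: jax.jvp(f, (x,), (tangent,))[1]
#     col = jax.vmap(grad_fn)(xs)            # (5, dim_y): the j-th column
#     ref = jax.vmap(jax.jacfwd(f))(xs)[:, :, j]  # same column via jacfwd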


# TODO: Refactor duplicate code
class Jacobians:
"""Compute multiple Jacobians.
    A new instance is created for each new (output, input) pair. For an (output,
    input) pair that has been computed before, the previous instance is reused
    rather than creating a new one.
"""

def __init__(self):
self.Js = {}

def __call__(self, ys, xs, i=0, j=None):
        # For backends tensorflow and pytorch, self.Js cannot be reused across
        # iterations. For backend pytorch, we need to reset self.Js in each
        # iteration to avoid a memory leak.
#
# For backend tensorflow, in each iteration, self.Js is reset to {}.
#
# Example:
#
# mydict = {}
#
# @tf.function
# def f(x):
# tf.print(mydict) # always {}
# y = 1 * x
# tf.print(hash(y.ref()), hash(x.ref())) # Doesn't change
# mydict[(y.ref(), x.ref())] = 1
# tf.print(mydict)
#
# for _ in range(2):
# x = np.random.random((3, 4))
# f(x)
#
#
# For backend pytorch, in each iteration, ys and xs are new tensors
        # converted from np.ndarray, so self.Js would grow across iterations.
#
# Example:
#
# mydict = {}
#
# def f(x):
# print(mydict)
# y = 1 * x
# print(hash(y), hash(x))
# mydict[(y, x)] = 1
# print(mydict)
#
# for i in range(2):
# x = np.random.random((3, 4))
# x = torch.from_numpy(x)
# x.requires_grad_()
# f(x)
if backend_name in ["tensorflow.compat.v1", "tensorflow"]:
key = (ys.ref(), xs.ref())
elif backend_name in ["pytorch", "paddle"]:
key = (ys, xs)
elif backend_name == "jax":
key = (id(ys[0]), id(xs))
if key not in self.Js:
self.Js[key] = Jacobian(ys, xs)
return self.Js[key](i, j)

def clear(self):
"""Clear cached Jacobians."""
self.Js = {}
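    # Illustrative usage (jax backend assumed; f and xs are arbitrary examples):
    #
    #     jacobians = Jacobians()
    #     ys = (jax.vmap(f)(xs), f)
    #     a = jacobians(ys, xs, i=0, j=1)  # computed and cached
    #     b = jacobians(ys, xs, i=1, j=1)  # reuses the cached Jacobian instance
    #     jacobians.clear()                # reset between iterations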


# TODO: Refactor duplicate code
def jacobian(ys, xs, i=0, j=None):
"""Compute Jacobian matrix J: J[i][j] = dy_i / dx_j, where i = 0, ..., dim_y - 1 and
j = 0, ..., dim_x - 1.
Use this function to compute first-order derivatives instead of ``tf.gradients()``
or ``torch.autograd.grad()``, because
    - It is lazily evaluated, i.e., J[i][j] is computed only when needed.
- It will remember the gradients that have already been computed to avoid duplicate
computation.
Args:
ys: Output Tensor of shape (batch_size, dim_y).
xs: Input Tensor of shape (batch_size, dim_x).
        i (int): Row index of the Jacobian, i.e., the index of the output
            component y_i, with 0 <= i < dim_y.
        j (int or None): Column index of the Jacobian, i.e., the index of the
            input component x_j, with 0 <= j < dim_x. If ``None``, the gradient
            of y_i is returned, which in forward mode is supported only when
            dim_x is 1.
Returns:
J[`i`][`j`] in Jacobian matrix J. If `j` is ``None``, returns the gradient of
y_i, i.e., J[`i`].
"""
return jacobian._Jacobians(ys, xs, i=i, j=j)


# TODO: Refactor duplicate code
jacobian._Jacobians = Jacobians()
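
For reference, a minimal end-to-end sketch of the new jacobian entry point with the jax backend. It assumes DeepXDE is installed with the jax backend selected (e.g. via the DDE_BACKEND environment variable); the function u and the input batch are purely illustrative.

import jax
import jax.numpy as jnp
from deepxde.gradients import gradients_forward

def u(x):  # pointwise "network": one point of shape (dim_x,) -> (dim_y,)
    return jnp.stack([x[0] * x[1], jnp.sin(x[1])])

xs = jnp.ones((8, 2))       # batch of 8 points, dim_x = 2
ys = (jax.vmap(u)(xs), u)   # the (array, callable) pair expected for jax

# du_0/dx_1 as a (batch_size, 1) array, plus a callable for further derivatives
value, fn = gradients_forward.jacobian(ys, xs, i=0, j=1)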
2 changes: 1 addition & 1 deletion deepxde/gradients/gradients_reverse.py
@@ -1,4 +1,4 @@
"""Compute gradients using reverse mode."""
"""Compute gradients using reverse-mode autodiff."""

__all__ = ["jacobian", "hessian"]
