Commit: Adding computation of gradients using forward-mode autodiff
Showing 3 changed files with 180 additions and 1 deletion.
@@ -1,3 +1,4 @@
__all__ = ["jacobian", "hessian"]

# from .gradients_forward import jacobian
from .gradients_reverse import clear, jacobian, hessian
@@ -0,0 +1,178 @@
"""Compute gradients using forward-mode autodiff.""" | ||
|
||
__all__ = ["jacobian", "hessian"] | ||
|
||
from ..backend import backend_name, jax | ||
|
||
|
||
class Jacobian: | ||
"""Compute Jacobian matrix J: J[i][j] = dy_i/dx_j, where i = 0, ..., dim_y-1 and | ||
j = 0, ..., dim_x - 1. | ||
It is lazy evaluation, i.e., it only computes J[i][j] when needed. | ||
Args: | ||
ys: Output Tensor of shape (batch_size, dim_y). | ||
xs: Input Tensor of shape (batch_size, dim_x). | ||
""" | ||
|
||
def __init__(self, ys, xs): | ||
self.ys = ys | ||
self.xs = xs | ||
|
||
if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "paddle"]: | ||
self.dim_y = ys.shape[1] | ||
# TODO: Other backends | ||
raise NotImplementedError( | ||
"Backend f{backend_name} doesn't support forward-mode autodiff." | ||
) | ||
elif backend_name == "jax": | ||
# For backend jax, a tuple of a jax array and a callable is passed as one of | ||
# the arguments, since jax does not support computational graph explicitly. | ||
# The array is used to control the dimensions and the callable is used to | ||
# obtain the derivative function, which can be used to compute the | ||
# derivatives. | ||
self.dim_y = ys[0].shape[1] | ||
self.dim_x = xs.shape[1] | ||
|
||
self.J = {} | ||
|
||
def __call__(self, i=0, j=None): | ||
"""Returns J[`i`][`j`]. If `j` is ``None``, returns the gradient of y_i, i.e., | ||
J[i]. | ||
""" | ||
if not 0 <= i < self.dim_y: | ||
raise ValueError("i={} is not valid.".format(i)) | ||
if j is not None and not 0 <= j < self.dim_x: | ||
raise ValueError("j={} is not valid.".format(j)) | ||
# Computing gradient is not supported in forward mode, unless there is only one input. | ||
if j is None and self.dim_x > 1: | ||
raise NotImplementedError( | ||
"Forward-mode autodiff doesn't support computing gradient." | ||
) | ||
# Compute J[:, j] | ||
if j not in self.J: | ||
if backend_name == "jax": | ||
# Here, we use jax.jvp to compute the gradient of a function. This is | ||
# different from TensorFlow and PyTorch that the input of a function is | ||
# no longer a batch. Instead, it is a single point. Formally, backend | ||
# jax computes gradients pointwisely and then vectorizes to batch, by | ||
# jax.vmap. However, computationally, this is in fact done batchwisely | ||
# and efficiently. It is very important to note that, without jax.vmap, | ||
# this can only deal with functions whose output is a scalar and input | ||
# is a single point. | ||
tangent = jax.numpy.zeros(self.dim_x).at[j].set(1) | ||
grad_fn = lambda x: jax.jvp(self.ys[1], (x,), (tangent,))[1] | ||
self.J[j] = (jax.vmap(grad_fn)(self.xs), grad_fn) | ||
|
||
if backend_name == "jax": | ||
# Unlike other backends, in backend jax, a tuple of a jax array and a callable is returned, so that | ||
# it is consistent with the argument, which is also a tuple. This may be useful for further computation, | ||
# e.g. Hessian. | ||
return ( | ||
self.J[i] | ||
if self.dim_y == 1 | ||
else ( | ||
self.J[j][0][:, i : i + 1], | ||
lambda inputs: self.J[j][1](inputs)[i : i + 1], | ||
) | ||
) | ||
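
The comment in Jacobian.__call__ above describes the core forward-mode trick: push a one-hot tangent vector through jax.jvp to get one Jacobian column at a single point, then vectorize over the batch with jax.vmap. Below is a minimal standalone sketch of that pattern; it is not part of the committed file, and the function f, the shapes, and the variable names are illustrative only.

import jax
import jax.numpy as jnp

def f(x):  # a single point of shape (dim_x,) -> an output of shape (dim_y,)
    return jnp.stack([x[0] * x[1], jnp.sin(x[0])])

xs = jnp.ones((8, 3))                                # batch of inputs, (batch_size, dim_x)
j = 1                                                # which Jacobian column to compute
tangent = jnp.zeros(xs.shape[1]).at[j].set(1.0)      # one-hot tangent selecting dx_j
grad_fn = lambda x: jax.jvp(f, (x,), (tangent,))[1]  # dy/dx_j at one point, shape (dim_y,)
col_j = jax.vmap(grad_fn)(xs)                        # whole batch at once, shape (batch_size, dim_y)

Because each jvp evaluates a single directional derivative, forward mode needs one tangent per input dimension, which is why the class above caches results per column, keyed by j.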

# TODO: Refactor duplicate code
class Jacobians:
    """Compute multiple Jacobians.

    A new instance will be created for a new pair of (output, input). For the
    (output, input) pair that has been computed before, it will reuse the previous
    instance, rather than creating a new one.
    """

    def __init__(self):
        self.Js = {}

    def __call__(self, ys, xs, i=0, j=None):
        # For backend tensorflow and pytorch, self.Js cannot be reused across
        # iterations. For backend pytorch, we need to reset self.Js in each iteration
        # to avoid a memory leak.
        #
        # For backend tensorflow, in each iteration, self.Js is reset to {}.
        #
        # Example:
        #
        # mydict = {}
        #
        # @tf.function
        # def f(x):
        #     tf.print(mydict)  # always {}
        #     y = 1 * x
        #     tf.print(hash(y.ref()), hash(x.ref()))  # Doesn't change
        #     mydict[(y.ref(), x.ref())] = 1
        #     tf.print(mydict)
        #
        # for _ in range(2):
        #     x = np.random.random((3, 4))
        #     f(x)
        #
        # For backend pytorch, in each iteration, ys and xs are new tensors converted
        # from np.ndarray, so self.Js will grow over iterations.
        #
        # Example:
        #
        # mydict = {}
        #
        # def f(x):
        #     print(mydict)
        #     y = 1 * x
        #     print(hash(y), hash(x))
        #     mydict[(y, x)] = 1
        #     print(mydict)
        #
        # for i in range(2):
        #     x = np.random.random((3, 4))
        #     x = torch.from_numpy(x)
        #     x.requires_grad_()
        #     f(x)
        if backend_name in ["tensorflow.compat.v1", "tensorflow"]:
            key = (ys.ref(), xs.ref())
        elif backend_name in ["pytorch", "paddle"]:
            key = (ys, xs)
        elif backend_name == "jax":
            key = (id(ys[0]), id(xs))
        if key not in self.Js:
            self.Js[key] = Jacobian(ys, xs)
        return self.Js[key](i, j)

    def clear(self):
        """Clear cached Jacobians."""
        self.Js = {}

# TODO: Refactor duplicate code
def jacobian(ys, xs, i=0, j=None):
    """Compute Jacobian matrix J: J[i][j] = dy_i / dx_j, where i = 0, ..., dim_y - 1
    and j = 0, ..., dim_x - 1.

    Use this function to compute first-order derivatives instead of ``tf.gradients()``
    or ``torch.autograd.grad()``, because

    - It is lazy evaluation, i.e., it only computes J[i][j] when needed.
    - It will remember the gradients that have already been computed to avoid
      duplicate computation.

    Args:
        ys: Output Tensor of shape (batch_size, dim_y).
        xs: Input Tensor of shape (batch_size, dim_x).
        i (int):
        j (int or None):

    Returns:
        J[`i`][`j`] in Jacobian matrix J. If `j` is ``None``, returns the gradient of
        y_i, i.e., J[`i`].
    """
    return jacobian._Jacobians(ys, xs, i=i, j=j)


# TODO: Refactor duplicate code
jacobian._Jacobians = Jacobians()
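
For reference, a hedged usage sketch of the jacobian entry point defined above. It assumes the jax backend is selected, so ys must be the (array, callable) pair that Jacobian expects; f, xs, and the shapes are illustrative and not part of the commit.

import jax
import jax.numpy as jnp

def f(x):                                # single point (dim_x,) -> (dim_y,)
    return jnp.stack([x[0] * x[1], jnp.sin(x[1])])

xs = jnp.ones((8, 2))                    # (batch_size, dim_x)
ys = (jax.vmap(f)(xs), f)                # (array, callable) pair for backend jax

dy0_dx1 = jacobian(ys, xs, i=0, j=1)     # (column of shape (8, 1), callable)
dy1_dx1 = jacobian(ys, xs, i=1, j=1)     # served from the cached Jacobian; J[:, 1] is not recomputed
jacobian._Jacobians.clear()              # drop the cache, e.g., between training iterations

The second call reuses the Jacobian instance cached in jacobian._Jacobians under the key (id(ys[0]), id(xs)), which is the memoization described in the docstring.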