Commit

Refactor dde.grad

lululxvi committed Dec 28, 2023
1 parent 7388c0f commit 194c4f1
Showing 4 changed files with 31 additions and 46 deletions.
27 changes: 26 additions & 1 deletion deepxde/gradients/jacobian.py → deepxde/gradients/gradients.py
@@ -1,4 +1,6 @@
-"""Compute Jacobian matrix."""
+"""Compute gradients using reverse-mode or forward-mode autodiff."""
 
+__all__ = ["jacobian", "hessian"]
+
 from abc import ABC, abstractmethod
 
@@ -145,3 +147,26 @@ def jacobian(ys, xs, i=None, j=None):
 
 
 jacobian._Jacobians = Jacobians(Jacobian)
+
+
+def hessian(ys, xs, component=0, i=0, j=0):
+    """Compute `Hessian matrix <https://en.wikipedia.org/wiki/Hessian_matrix>`_ H as
+    H[i, j] = d^2y / dx_i dx_j, where i, j = 0, ..., dim_x - 1.
+
+    Use this function to compute second-order derivatives instead of ``tf.gradients()``
+    or ``torch.autograd.grad()``, because
+
+    - It is lazy evaluation, i.e., it only computes H[i, j] when needed.
+    - It will remember the gradients that have already been computed to avoid duplicate
+      computation.
+
+    Args:
+        ys: Output Tensor of shape (batch_size, dim_y).
+        xs: Input Tensor of shape (batch_size, dim_x).
+        component: `ys[:, component]` is used as y to compute the Hessian.
+        i (int): `i`th row.
+        j (int): `j`th column.
+
+    Returns:
+        H[`i`, `j`].
+    """
22 changes: 1 addition & 21 deletions deepxde/gradients/gradients_forward.py
@@ -2,7 +2,7 @@
 
 __all__ = ["jacobian", "hessian"]
 
-from .jacobian import Jacobian, Jacobians, jacobian
+from .gradients import Jacobian, Jacobians, jacobian
 from ..backend import backend_name, jax
 
 
@@ -64,25 +64,5 @@ def __call__(self, i=None, j=None):
 
 
 def hessian(ys, xs, component=0, i=0, j=0):
-    """Compute `Hessian matrix <https://en.wikipedia.org/wiki/Hessian_matrix>`_ H as
-    H[i, j] = d^2y / dx_i dx_j, where i,j = 0, ..., dim_x - 1.
-
-    Use this function to compute second-order derivatives instead of ``tf.gradients()``
-    or ``torch.autograd.grad()``, because
-
-    - It is lazy evaluation, i.e., it only computes H[i, j] when needed.
-    - It will remember the gradients that have already been computed to avoid duplicate
-      computation.
-
-    Args:
-        ys: Output Tensor of shape (batch_size, dim_y).
-        xs: Input Tensor of shape (batch_size, dim_x).
-        component: `ys[:, component]` is used as y to compute the Hessian.
-        i (int): `i`th row.
-        j (int): `j`th column.
-
-    Returns:
-        H[`i`, `j`].
-    """
     dys_xj = jacobian(ys, xs, i=None, j=j)
     return jacobian(dys_xj, xs, i=component, j=i)
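Editorial note (not part of the diff): the two calls retained above compute H[i, j] as d/dx_i of (d y_component / d x_j), i.e. two nested first-order Jacobians. A standalone JAX sketch of that identity (ignoring DeepXDE's batching and caching; f, hessian_entry, and the test point are made up for illustration):

import jax
import jax.numpy as jnp

def hessian_entry(f, x, component=0, i=0, j=0):
    # First Jacobian: d y_component / d x_j, as a scalar function of x.
    dy_dxj = lambda x: jax.jacfwd(f)(x)[component, j]
    # Second Jacobian: differentiate that scalar once more w.r.t. x_i.
    return jax.jacfwd(dy_dxj)(x)[i]

f = lambda x: jnp.array([x[0] ** 2 * x[1]])  # y(x) with dim_y = 1
x = jnp.array([1.5, -2.0])
print(hessian_entry(f, x, i=1, j=0))         # analytic value: 2 * x0 = 3.0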
22 changes: 1 addition & 21 deletions deepxde/gradients/gradients_reverse.py
@@ -2,7 +2,7 @@
 
 __all__ = ["jacobian", "hessian"]
 
-from .jacobian import Jacobian, Jacobians, jacobian
+from .gradients import Jacobian, Jacobians, jacobian
 from ..backend import backend_name, tf, torch, jax, paddle
 
 
@@ -133,26 +133,6 @@ def clear(self):
 
 
 def hessian(ys, xs, component=0, i=0, j=0):
-    """Compute `Hessian matrix <https://en.wikipedia.org/wiki/Hessian_matrix>`_ H as
-    H[i, j] = d^2y / dx_i dx_j, where i,j = 0, ..., dim_x - 1.
-
-    Use this function to compute second-order derivatives instead of ``tf.gradients()``
-    or ``torch.autograd.grad()``, because
-
-    - It is lazy evaluation, i.e., it only computes H[i, j] when needed.
-    - It will remember the gradients that have already been computed to avoid duplicate
-      computation.
-
-    Args:
-        ys: Output Tensor of shape (batch_size, dim_y).
-        xs: Input Tensor of shape (batch_size, dim_x).
-        component: `ys[:, component]` is used as y to compute the Hessian.
-        i (int): `i`th row.
-        j (int): `j`th column.
-
-    Returns:
-        H[`i`, `j`].
-    """
     return hessian._Hessians(ys, xs, component=component, i=i, j=j)


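Editorial note (not part of the diff): in the reverse-mode backend, hessian() delegates to the module-level hessian._Hessians container, whose job is the lazy evaluation and caching promised in the docstring. A simplified stand-in sketching that idea with plain PyTorch (GradientCache and its keying scheme are hypothetical, not the class used in this commit):

import torch

class GradientCache:
    """Remember first-order gradients so repeated H[i, j] queries reuse them."""

    def __init__(self):
        # (id(ys), id(xs), component) -> d y_component / d xs, shape (batch, dim_x)
        self._first_order = {}

    def hessian_entry(self, ys, xs, component=0, i=0, j=0):
        key = (id(ys), id(xs), component)
        if key not in self._first_order:
            # One backward pass gives d y_component / d x_k for every k; cache it.
            self._first_order[key] = torch.autograd.grad(
                ys[:, component].sum(), xs, create_graph=True
            )[0]
        dy_xj = self._first_order[key][:, j : j + 1]
        # Differentiate the cached column once more to obtain H[i, j].
        return torch.autograd.grad(dy_xj.sum(), xs, create_graph=True)[0][:, i : i + 1]

# Usage on y = x0^2 * x1, so H[0, 1] = 2 * x0:
xs = torch.tensor([[1.5, -2.0]], requires_grad=True)
ys = (xs[:, 0:1] ** 2) * xs[:, 1:2]
print(GradientCache().hessian_entry(ys, xs, component=0, i=0, j=1))  # ~3.0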
6 changes: 3 additions & 3 deletions docs/modules/deepxde.gradients.rst
@@ -1,10 +1,10 @@
 deepxde.gradients
 =================
 
-deepxde.gradients.gradients\_reverse module
--------------------------------------------
+deepxde.gradients.gradients module
+----------------------------------
 
-.. automodule:: deepxde.gradients.gradients_reverse
+.. automodule:: deepxde.gradients.gradients
    :members:
    :undoc-members:
    :show-inheritance:

3 comments on commit 194c4f1

@lijialin03 (Contributor)

An error occurs after this commit:

[screenshot of the error traceback]

The reason is that dde.grad.jacobian(y, x) without setting i, j is no longer supported. You can reproduce the error by running the examples pinn_forward/Euler_beam.py, pinn_forward/ode_2nd.py, and pinn_forward/Poisson_PointSetOperator_1d.py.
If this commit is to be merged, further modifications to the examples or to the files in the PR may be needed.
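For concreteness (editorial illustration, not code from the thread): the pattern that now fails is a jacobian call relying on the default i=None, j=None, roughly as in the listed examples:

import deepxde as dde

def pde(x, y):
    dy_x = dde.grad.jacobian(y, x)   # i, j left at their defaults -- the case that now errors
    return dy_x + y                  # placeholder residual; the real examples differ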

@lululxvi (Owner, Author) commented on 194c4f1 Dec 28, 2023

I will update the code to handle this special case.

@lululxvi (Owner, Author) commented on 194c4f1 Dec 28, 2023

Fixed in the new commit.
