Backend jax: Bug fix in forward-mode gradients (#1591)
ZongrenZou authored Dec 18, 2023
1 parent 9d9d0b0 commit 46e2c2e
Showing 1 changed file with 3 additions and 1 deletion.
deepxde/gradients/gradients_forward.py (3 additions, 1 deletion)

@@ -73,7 +73,7 @@ def __call__(self, i=0, j=None):
         # it is consistent with the argument, which is also a tuple. This may be useful for further computation,
         # e.g. Hessian.
         return (
-            self.J[i]
+            self.J[j]
             if self.dim_y == 1
             else (
                 self.J[j][0][:, i : i + 1],
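The indexing fix reflects how forward-mode autodiff works: each `jvp` call produces one *column* of the Jacobian, selected by the input coordinate `j`, so when `dim_y == 1` the cached entry must be looked up by `j`, not the output index `i`. A minimal sketch of this per-input-column computation (the helper `jacobian_column` is hypothetical, not DeepXDE's API):

```python
import jax
import jax.numpy as jnp

# Sketch: forward-mode AD computes d(outputs)/d(x_j) in one pass, i.e. one
# Jacobian column per input index j, which is why results are keyed by j.
def jacobian_column(f, x, j):
    # Tangent vector selecting input coordinate j.
    tangent = jnp.zeros_like(x).at[j].set(1.0)
    # jvp returns (f(x), J @ tangent); the second element is column j.
    _, col = jax.jvp(f, (x,), (tangent,))
    return col

f = lambda x: jnp.sin(x[0]) * x[1]  # scalar output, i.e. dim_y == 1
x = jnp.array([0.0, 2.0])
col0 = jacobian_column(f, x, 0)  # d f / d x_0 = cos(x_0) * x_1
col1 = jacobian_column(f, x, 1)  # d f / d x_1 = sin(x_0)
```

For a scalar output the column is itself a scalar, so the cache holds exactly one value per `j`, matching the `self.J[j]` lookup in the patched branch.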
@@ -206,6 +206,8 @@ def hessian(ys, xs, component=None, i=0, j=0, grad_y=None):
     Returns:
         H[`i`][`j`].
     """
+    if component is None:
+        component = 0
+
     # TODO: Naive implementation. To be improved.
     # This jacobian is OK, as it will reuse cached Jacobians.
     dy_xi = jacobian(ys, xs, i=component, j=i)
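The second hunk makes `component=None` fall back to output component 0 before `component` is passed to `jacobian`. The surrounding Hessian-as-Jacobian-of-Jacobian pattern, with the same fallback, can be sketched as follows (the helper `hessian_entry` is hypothetical, not DeepXDE's implementation):

```python
import jax
import jax.numpy as jnp

# Sketch of H[i][j] = d/dx_j ( d y[component] / dx_i ), with the same
# component-defaulting behavior as the committed change.
def hessian_entry(f, x, component=None, i=0, j=0):
    if component is None:
        component = 0  # fallback added by this commit
    # First derivative: d y[component] / d x_i.
    dy_xi = lambda x: jax.grad(lambda x: f(x)[component])(x)[i]
    # Second derivative: d(dy_xi) / d x_j.
    return jax.grad(dy_xi)(x)[j]

f = lambda x: jnp.array([x[0] ** 2 * x[1]])  # single output component
x = jnp.array([1.0, 3.0])
h00 = hessian_entry(f, x, i=0, j=0)  # d^2/dx_0^2 (x_0^2 x_1) = 2 * x_1
```

Without the fallback, calling `hessian(ys, xs)` on a single-output network would pass `i=None` into `jacobian`, so defaulting to 0 keeps the common scalar-output case working.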
