Skip to content

Commit

Permalink
Backend PyTorch supports forward-mode automatic differentiation: Step…
Browse files Browse the repository at this point in the history
… 1 (#1607)
  • Loading branch information
ZongrenZou authored Jan 1, 2024
1 parent 247bb2c commit b78ccec
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
8 changes: 7 additions & 1 deletion deepxde/data/pde.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,14 @@ def __init__(
self.test()

def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "paddle"]:
if backend_name in ["tensorflow.compat.v1", "tensorflow", "paddle"]:
outputs_pde = outputs
elif backend_name == "pytorch":
if config.autodiff == "reverse":
outputs_pde = outputs
elif config.autodiff == "forward":
# forward-mode AD in PyTorch requires functions
outputs_pde = (outputs, aux[0])
elif backend_name == "jax":
# JAX requires pure functions
outputs_pde = (outputs, aux[0])
Expand Down
4 changes: 3 additions & 1 deletion deepxde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,9 @@ def outputs_losses(training, inputs, targets, auxiliary_vars, losses_fn):
# Data losses
if targets is not None:
targets = torch.as_tensor(targets)
losses = losses_fn(targets, outputs_, loss_fn, inputs, self)
# if forward-mode AD is used, then a forward call needs to be passed
aux = [self.net] if config.autodiff == "forward" else None
losses = losses_fn(targets, outputs_, loss_fn, inputs, self, aux=aux)
if not isinstance(losses, list):
losses = [losses]
losses = torch.stack(losses)
Expand Down

0 comments on commit b78ccec

Please sign in to comment.