From 2f4e340d2903d5518d90d8a3db8e8cdf88613707 Mon Sep 17 00:00:00 2001
From: Jerry-Jzy
Date: Tue, 24 Dec 2024 23:43:55 -0500
Subject: [PATCH] update comment

---
 deepxde/gradients/gradients_forward.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/deepxde/gradients/gradients_forward.py b/deepxde/gradients/gradients_forward.py
index 8aaf3e3c2..2f9aa10f7 100644
--- a/deepxde/gradients/gradients_forward.py
+++ b/deepxde/gradients/gradients_forward.py
@@ -92,7 +92,9 @@ def grad_fn(x):
             # In backend tensorflow/pytorch/jax, a tuple of a tensor/tensor/array
             # and a callable is returned, so that it is consistent with the argument,
             # which is also a tuple. This is useful for further computation, e.g.,
-            # Hessian. The code still works even the output dim is > 2.
+            # Hessian. The code is designed for outputs of shape (batch size, dim),
+            # but we find that it also works for outputs of shape (batch size 1,
+            # batch size 2, dim), e.g., the multiple-output DeepONet.
             self.J[i, j] = (
                 self.J[j][0][..., i : i + 1],
                 lambda x: self.J[j][1](x)[i : i + 1],
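
Below is a minimal, illustrative sketch (not part of the patch) of the shape claim in the updated comment, assuming the JAX backend; the toy network net and the inputs x2d/x3d are hypothetical. It shows that slicing the last axis of a jax.jvp result with [..., i : i + 1] behaves identically for outputs of shape (batch size, dim) and (batch size 1, batch size 2, dim), which is why the Jacobian entries above stay valid for DeepONet-style outputs with two batch axes.

import jax
import jax.numpy as jnp

def net(x):
    # Toy "network" with 3 output components stacked on the last axis.
    return jnp.stack(
        [x[..., 0] ** 2, jnp.sin(x[..., 1]), x[..., 0] * x[..., 1]], axis=-1
    )

x2d = jnp.ones((5, 2))     # output shape (batch size, dim) = (5, 3)
x3d = jnp.ones((4, 5, 2))  # output shape (batch size 1, batch size 2, dim) = (4, 5, 3)

for x in (x2d, x3d):
    # Forward-mode derivative w.r.t. input component j = 0.
    tangent = jnp.zeros_like(x).at[..., 0].set(1.0)
    _, dy_dxj = jax.jvp(net, (x,), (tangent,))
    # Slicing the last axis selects output component i regardless of how many
    # leading batch axes the output has.
    i = 1
    dyi_dxj = dy_dxj[..., i : i + 1]
    print(dyi_dxj.shape)   # (5, 1), then (4, 5, 1)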