Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerry-Jzy committed Jan 1, 2025
1 parent 7d36bb4 commit 109a3af
Showing 1 changed file with 12 additions and 28 deletions.
40 changes: 12 additions & 28 deletions deepxde/data/pde_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,45 +263,29 @@ def _losses(self, outputs, loss_fn, inputs, model, num_func, aux=None):
# Use stack instead of as_tensor to keep the gradients.
losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses]
elif config.autodiff == "forward": # forward mode AD
if model.net.num_outputs == 1:
shape0, shape1 = outputs.shape[0], outputs.shape[1]
else:
shape0, shape1, shape2 = (
outputs.shape[0],
outputs.shape[1],
outputs.shape[2],
)
shape0, shape1 = outputs.shape[0], outputs.shape[1]
shape2 = 1 if model.net.num_outputs == 1 else outputs.shape[2]

def forward_call(trunk_input):
output = aux[0]((inputs[0], trunk_input))
if model.net.num_outputs == 1:
return bkd.reshape(output, (shape0 * shape1, 1))
return bkd.reshape(output, (shape0 * shape1, shape2))

if model.net.num_outputs == 1:
outputs = bkd.reshape(outputs, (shape0 * shape1, 1))
auxiliary_vars = bkd.reshape(
model.net.auxiliary_vars, (shape0 * shape1, 1)
)
else:
outputs = bkd.reshape(outputs, (shape0 * shape1, shape2))
auxiliary_vars = bkd.reshape(
model.net.auxiliary_vars, (shape0 * shape1, shape2)
)

f = []
if self.pde.pde is not None:
# Each f has the shape (N1, N2)
f = self.pde.pde(inputs[1], (outputs, forward_call), auxiliary_vars)
f = self.pde.pde(
inputs[1],
(bkd.reshape(outputs, (shape0 * shape1, shape2)), forward_call),
bkd.reshape(model.net.auxiliary_vars, (shape0 * shape1, shape2)),
)
if not isinstance(f, (list, tuple)):
f = [f]
f = (
[bkd.reshape(fi, (shape0, shape1)) for fi in f]
if model.net.num_outputs == 1
else [bkd.reshape(fi, (shape0, shape1, shape2)) for fi in f]
)

if model.net.num_outputs == 1:
outputs = bkd.reshape(outputs, (shape0, shape1))
f = [bkd.reshape(fi, (shape0, shape1)) for fi in f]
else:
outputs = bkd.reshape(outputs, (shape0, shape1, shape2))
f = [bkd.reshape(fi, (shape0, shape1, shape2)) for fi in f]
# Each error has the shape (N1, ~N2)
error_f = [fi[:, bcs_start[-1] :] for fi in f]
for error in error_f:
Expand Down

0 comments on commit 109a3af

Please sign in to comment.