From 109a3af901b36f9087c18d687c88dc5d43b4d6fe Mon Sep 17 00:00:00 2001 From: Jerry-Jzy Date: Tue, 31 Dec 2024 23:59:40 -0500 Subject: [PATCH] update code --- deepxde/data/pde_operator.py | 40 +++++++++++------------------------- 1 file changed, 12 insertions(+), 28 deletions(-) diff --git a/deepxde/data/pde_operator.py b/deepxde/data/pde_operator.py index dbaf95db9..78c7d5393 100644 --- a/deepxde/data/pde_operator.py +++ b/deepxde/data/pde_operator.py @@ -263,45 +263,29 @@ def _losses(self, outputs, loss_fn, inputs, model, num_func, aux=None): # Use stack instead of as_tensor to keep the gradients. losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses] elif config.autodiff == "forward": # forward mode AD - if model.net.num_outputs == 1: - shape0, shape1 = outputs.shape[0], outputs.shape[1] - else: - shape0, shape1, shape2 = ( - outputs.shape[0], - outputs.shape[1], - outputs.shape[2], - ) + shape0, shape1 = outputs.shape[0], outputs.shape[1] + shape2 = 1 if model.net.num_outputs == 1 else outputs.shape[2] def forward_call(trunk_input): output = aux[0]((inputs[0], trunk_input)) - if model.net.num_outputs == 1: - return bkd.reshape(output, (shape0 * shape1, 1)) return bkd.reshape(output, (shape0 * shape1, shape2)) - if model.net.num_outputs == 1: - outputs = bkd.reshape(outputs, (shape0 * shape1, 1)) - auxiliary_vars = bkd.reshape( - model.net.auxiliary_vars, (shape0 * shape1, 1) - ) - else: - outputs = bkd.reshape(outputs, (shape0 * shape1, shape2)) - auxiliary_vars = bkd.reshape( - model.net.auxiliary_vars, (shape0 * shape1, shape2) - ) - f = [] if self.pde.pde is not None: # Each f has the shape (N1, N2) - f = self.pde.pde(inputs[1], (outputs, forward_call), auxiliary_vars) + f = self.pde.pde( + inputs[1], + (bkd.reshape(outputs, (shape0 * shape1, shape2)), forward_call), + bkd.reshape(model.net.auxiliary_vars, (shape0 * shape1, shape2)), + ) if not isinstance(f, (list, tuple)): f = [f] + f = ( + [bkd.reshape(fi, (shape0, shape1)) for fi in f] + if model.net.num_outputs == 1 + else [bkd.reshape(fi, (shape0, shape1, shape2)) for fi in f] + ) - if model.net.num_outputs == 1: - outputs = bkd.reshape(outputs, (shape0, shape1)) - f = [bkd.reshape(fi, (shape0, shape1)) for fi in f] - else: - outputs = bkd.reshape(outputs, (shape0, shape1, shape2)) - f = [bkd.reshape(fi, (shape0, shape1, shape2)) for fi in f] # Each error has the shape (N1, ~N2) error_f = [fi[:, bcs_start[-1] :] for fi in f] for error in error_f: