
Commit

Merge edec802 into 49c579f
tchaton authored Mar 4, 2021
2 parents 49c579f + edec802 · commit e2d179d
Showing 3 changed files with 7 additions and 2 deletions.
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -203,6 +203,10 @@ def __run_eval_epoch_end(self, num_dataloaders):
 
         # with a single dataloader don't pass an array
         outputs = self.outputs
+
+        # free memory
+        self.outputs = []
+
         eval_results = outputs
         if num_dataloaders == 1:
             eval_results = outputs[0]
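For context on what the change in __run_eval_epoch_end accomplishes: the loop hands its accumulated step outputs to a local variable and then resets self.outputs, so the loop itself no longer keeps the per-step results alive after the epoch-end processing has consumed them. A minimal standalone sketch of the same pattern (the EvalLoop class and run_epoch_end method below are illustrative stand-ins, not Lightning's actual API):

class EvalLoop:
    """Illustrative stand-in for an evaluation loop that accumulates step outputs."""

    def __init__(self):
        # one list of step outputs per dataloader
        self.outputs = []

    def run_epoch_end(self, num_dataloaders):
        # Bind the accumulated outputs to a local name ...
        outputs = self.outputs
        # ... and drop the loop's own reference so the objects can be
        # garbage-collected once epoch-end processing is done.
        self.outputs = []

        # with a single dataloader, don't return a list of lists
        if num_dataloaders == 1:
            return outputs[0]
        return outputs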
4 changes: 2 additions & 2 deletions tests/callbacks/test_pruning.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from collections import OrderedDict
 from logging import INFO
 
@@ -22,7 +21,7 @@
 from torch.nn import Sequential
 
 from pytorch_lightning import seed_everything, Trainer
-from pytorch_lightning.callbacks import ModelPruning, ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint, ModelPruning
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -274,6 +273,7 @@ def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog):
     seed_everything(0)
 
     class TestPruning(ModelPruning):
+
        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            super().on_save_checkpoint(trainer, pl_module, checkpoint)
            assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
1 change: 1 addition & 0 deletions tests/trainer/logging_/test_eval_loop_logging_1_0.py
@@ -126,6 +126,7 @@ def validation_step_end(self, acc):
     def validation_epoch_end(self, outputs):
         self.log('g', torch.tensor(2, device=self.device), on_epoch=True)
         self.validation_epoch_end_called = True
+        assert len(self.trainer.evaluation_loop.outputs) == 0
 
     def backward(self, loss, optimizer, optimizer_idx):
         return LightningModule.backward(self, loss, optimizer, optimizer_idx)
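The assertion added to validation_epoch_end checks, from inside the user hook, that the evaluation loop has already released its stored outputs by the time the hook runs. A rough self-contained analogue of that check, reusing the illustrative EvalLoop from the sketch above (hypothetical names, not Lightning's Trainer API):

def test_outputs_are_released():
    loop = EvalLoop()
    loop.outputs = [["step_0_result", "step_1_result"]]  # results from a single dataloader

    results = loop.run_epoch_end(num_dataloaders=1)

    assert results == ["step_0_result", "step_1_result"]
    # the loop no longer holds references to the step outputs
    assert loop.outputs == []


test_outputs_are_released()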
