Skip to content
This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

Commit

Permalink
fix bug in PyTorchNetworkTrainer.predict
Browse files Browse the repository at this point in the history
  • Loading branch information
ORippler committed Jan 16, 2019
1 parent a4ac864 commit 99a7801
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions delira/training/pytorch_trainer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import os
import logging
import shutil
import numpy as np
import typing
from tqdm.auto import tqdm
import torch
from collections import OrderedDict
from batchgenerators.dataloading import MultiThreadedAugmenter
from .callbacks import AbstractCallback
from .abstract_trainer import AbstractNetworkTrainer
Expand Down Expand Up @@ -425,8 +422,16 @@ def predict(self, batchgen, batch_size=None):

batch_list = []

orig_batch_size = batch_size

for i, batch in pbar:

if not batch_list and (n_batches - i) < batch_size:

batch_size = n_batches - i
logger.debug("Set Batchsize down to %d to avoid cutting "
"of the last batches" % batch_size)

data_dict = self._prepare_batch(batch, self.input_device,
self.output_device)
# queue inputs and labels
Expand All @@ -435,11 +440,6 @@ def predict(self, batchgen, batch_size=None):
# if queue is full process queue:
if batch_size is None or len(batch_list) >= batch_size:

if not batch_list and (n_batches - i) < batch_size:
batch_size = n_batches - i
logger.debug("Set Batchsize down to %d to avoid cutting "
"of the last batches" % batch_size)

batch_dict = {}
for batch in batch_list:
for key, val in batch.items():
Expand Down Expand Up @@ -501,7 +501,7 @@ def predict(self, batchgen, batch_size=None):

# if virtual batchsize is given: calculate actual number of batches
if batch_size is not None:
div = n_batches / batch_size
div = np.ceil(n_batches / orig_batch_size)
else:
div = n_batches

Expand Down

0 comments on commit 99a7801

Please sign in to comment.