Skip to content
This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

Commit

Permalink
First version of test pipeline in experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
mibaumgartner committed Jan 20, 2019
1 parent a834aad commit 45d30ad
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
45 changes: 45 additions & 0 deletions delira/training/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,51 @@ def run(self,
"lowest")
)

@staticmethod
def test(network: AbstractPyTorchNetwork,
         datamgr_test: typing.Union[BaseDataManager, ConcatDataManager],
         trainer_cls=PTNetworkTrainer,
         **kwargs):
    """
    Tests a single (trained) model on the given testset

    Parameters
    ----------
    network : :class:`AbstractPyTorchNetwork`
        the network to test
    datamgr_test : BaseDataManager or ConcatDataManager
        holds the testset
    trainer_cls :
        class defining the actual trainer,
        defaults to :class:`PTNetworkTrainer`,
        which should be suitable for most cases,
        but can easily be overwritten and exchanged if necessary
    **kwargs :
        holds additional keyword arguments
        (which are completely passed to the trainer's init)

    Returns
    -------
    tuple
        ``(outputs, labels, metrics_val)`` as returned by
        ``trainer.predict``.
        NOTE(review): the sibling call site in ``pytorch_trainer.train``
        unpacks ``predict`` as ``(labels, preds, metrics)`` -- confirm
        the actual return order of ``predict``
    """
    criterions = kwargs.pop('criterions', {})
    # no optimizer needed: the trainer is only used for inference here;
    # remaining kwargs are forwarded to the trainer as documented
    trainer = trainer_cls(network=network,
                          save_path='',
                          criterions=criterions,
                          optimizer_cls=None,
                          **kwargs)

    # testing with batchsize 1 and 1 augmentation process to
    # avoid dropping of last elements
    orig_num_aug_processes = datamgr_test.n_process_augmentation
    orig_batch_size = datamgr_test.batch_size

    datamgr_test.batch_size = 1
    datamgr_test.n_process_augmentation = 1

    try:
        outputs, labels, metrics_val = trainer.predict(
            datamgr_test.get_batchgen(), batch_size=orig_batch_size)
    finally:
        # restore the caller's datamanager settings even if predict raises
        datamgr_test.batch_size = orig_batch_size
        datamgr_test.n_process_augmentation = orig_num_aug_processes
    return outputs, labels, metrics_val

def save(self):
"""
Saves the Whole experiments
Expand Down
1 change: 1 addition & 0 deletions delira/training/pytorch_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ def train(self, num_epochs, datamgr_train, datamgr_valid=None,
datamgr_valid.batch_size = 1
datamgr_valid.n_process_augmentation = 1

# TODO: wrong order of returns???
labels_val, pred_val, metrics_val = self.predict(
datamgr_valid.get_batchgen(), batch_size=orig_batch_size)

Expand Down

0 comments on commit 45d30ad

Please sign in to comment.