-
Notifications
You must be signed in to change notification settings - Fork 2.7k
/
Copy patheval_hooks.py
107 lines (93 loc) · 3.77 KB
/
eval_hooks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os.path as osp
from mmcv.runner import Hook
from torch.utils.data import DataLoader
class EvalHook(Hook):
    """Evaluation hook for single-process training.

    Periodically runs ``single_gpu_test`` on ``dataloader`` and writes the
    metrics returned by ``dataloader.dataset.evaluate`` into
    ``runner.log_buffer.output``.

    Args:
        dataloader (DataLoader): A PyTorch dataloader.
        interval (int): Evaluation interval, counted in iterations, or in
            epochs when ``by_epoch`` is True. Default: 1.
        by_epoch (bool): If True, evaluate after every ``interval`` epochs;
            otherwise after every ``interval`` iterations. Default: False.
        **eval_kwargs: Extra keyword arguments forwarded unchanged to
            ``dataloader.dataset.evaluate``.

    Raises:
        TypeError: If ``dataloader`` is not a ``DataLoader`` instance.
    """

    def __init__(self, dataloader, interval=1, by_epoch=False, **eval_kwargs):
        if not isinstance(dataloader, DataLoader):
            raise TypeError('dataloader must be a pytorch DataLoader, but got '
                            f'{type(dataloader)}')
        self.dataloader = dataloader
        self.interval = interval
        self.by_epoch = by_epoch
        self.eval_kwargs = eval_kwargs

    def after_train_iter(self, runner):
        """After train iter hook."""
        # Iteration-based evaluation is disabled when running by epoch.
        if self.by_epoch or not self.every_n_iters(runner, self.interval):
            return
        self._run_evaluation(runner)

    def after_train_epoch(self, runner):
        """After train epoch hook."""
        # Epoch-based evaluation only applies when running by epoch.
        if not self.by_epoch or not self.every_n_epochs(runner,
                                                        self.interval):
            return
        self._run_evaluation(runner)

    def _run_evaluation(self, runner):
        """Run inference over the dataloader and evaluate the results."""
        # Imported lazily rather than at module level; presumably to avoid a
        # circular import with mmseg.apis -- confirm before hoisting.
        from mmseg.apis import single_gpu_test
        runner.log_buffer.clear()
        results = single_gpu_test(runner.model, self.dataloader, show=False)
        self.evaluate(runner, results)

    def evaluate(self, runner, results):
        """Call evaluate function of dataset.

        Args:
            runner: The runner whose ``logger`` is handed to the dataset and
                whose ``log_buffer`` receives the metrics.
            results: Test results as returned by the test function.
        """
        eval_res = self.dataloader.dataset.evaluate(
            results, logger=runner.logger, **self.eval_kwargs)
        for name, val in eval_res.items():
            runner.log_buffer.output[name] = val
        # Mark the buffer as ready so the stored metrics get picked up.
        runner.log_buffer.ready = True
class DistEvalHook(EvalHook):
    """Distributed evaluation hook.

    Periodically runs ``multi_gpu_test`` on ``dataloader`` and, on rank 0
    only, evaluates the results through ``dataloader.dataset.evaluate``.
    The temporary directory handed to ``multi_gpu_test`` is not
    configurable: it is always ``<runner.work_dir>/.eval_hook``.

    Args:
        dataloader (DataLoader): A PyTorch dataloader.
        interval (int): Evaluation interval, counted in iterations, or in
            epochs when ``by_epoch`` is True. Default: 1.
        gpu_collect (bool): Whether to use gpu or cpu to collect results.
            Forwarded to ``multi_gpu_test``. Default: False.
        by_epoch (bool): If True, evaluate after every ``interval`` epochs;
            otherwise after every ``interval`` iterations. Default: False.
        **eval_kwargs: Extra keyword arguments forwarded unchanged to
            ``dataloader.dataset.evaluate``.

    Raises:
        TypeError: If ``dataloader`` is not a ``DataLoader`` instance.
    """

    def __init__(self,
                 dataloader,
                 interval=1,
                 gpu_collect=False,
                 by_epoch=False,
                 **eval_kwargs):
        # The parent validates the dataloader and stores the shared
        # attributes (dataloader, interval, by_epoch, eval_kwargs).
        super().__init__(
            dataloader, interval=interval, by_epoch=by_epoch, **eval_kwargs)
        self.gpu_collect = gpu_collect

    def after_train_iter(self, runner):
        """After train iter hook."""
        # Iteration-based evaluation is disabled when running by epoch.
        if self.by_epoch or not self.every_n_iters(runner, self.interval):
            return
        self._collect_and_evaluate(runner)

    def after_train_epoch(self, runner):
        """After train epoch hook."""
        # Epoch-based evaluation only applies when running by epoch.
        if not self.by_epoch or not self.every_n_epochs(runner,
                                                        self.interval):
            return
        self._collect_and_evaluate(runner)

    def _collect_and_evaluate(self, runner):
        """Run distributed inference, then evaluate on rank 0."""
        # Imported lazily rather than at module level; presumably to avoid a
        # circular import with mmseg.apis -- confirm before hoisting.
        from mmseg.apis import multi_gpu_test
        runner.log_buffer.clear()
        results = multi_gpu_test(
            runner.model,
            self.dataloader,
            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
            gpu_collect=self.gpu_collect)
        # Every rank takes part in the test call above; only rank 0 goes on
        # to compute and log the metrics.
        if runner.rank == 0:
            # Presumably separates the metrics from preceding progress
            # output on the console.
            print('\n')
            self.evaluate(runner, results)