-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathvalidation.py
117 lines (100 loc) · 4.14 KB
/
validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
import time
import sys
import pdb
import torch
import torch.distributed as dist
from utils import AverageMeter, calculate_accuracy
def _rpn_proposals(rpn, inputs, det_interval, nrois, pad_to):
    """Run *rpn* on every ``det_interval``-th frame of *inputs* and return its proposals.

    Frames ``0, det_interval, 2 * det_interval, ...`` along dim 2 of *inputs*
    are flattened into a single image batch for the RPN. If the clip batch is
    smaller than *pad_to*, the image batch is padded with repeated real frames
    so the RPN batch size never decreases. This is a workaround for an
    unexpected CUDNN_ERROR seen when the RPN batch size decreased. The padded
    rows are dropped from the result.

    Args:
        rpn: Region-proposal network, called on the flattened frame batch.
        inputs: 5-D clip batch, unpacked here as ``(N, C, T, H, W)``.
        det_interval: Temporal stride between frames fed to the RPN.
        nrois: Proposals per sampled frame; used only to reshape the RPN output.
        pad_to: Clip batch size to pad the RPN input up to. No padding is
            applied when ``N >= pad_to``.

    Returns:
        The RPN output viewed as ``(-1, n_sampled, nrois, 4)`` and truncated
        to the first ``N`` clips, where ``n_sampled = ceil(T / det_interval)``.
    """
    n, c, t, h, w = inputs.size()
    sample = torch.arange(0, t, det_interval)
    # ceil(t / det_interval) frames; t // det_interval would undercount
    # whenever t is not a multiple of det_interval.
    n_sampled = len(sample)

    rpn_inputs = inputs[:, :, sample].transpose(1, 2).contiguous()
    rpn_inputs = rpn_inputs.view(-1, c, h, w)

    if n < pad_to:
        print("Modified from {} to {}".format(n, pad_to))
        n_pad = (pad_to - n) * n_sampled
        # Cycle through the real frames so the padding is complete even when
        # the deficit is larger than the current batch itself.
        pad_idx = torch.arange(n_pad, device=rpn_inputs.device) % rpn_inputs.size(0)
        rpn_inputs = torch.cat((rpn_inputs, rpn_inputs[pad_idx]))

    proposals = rpn(rpn_inputs)
    proposals = proposals.view(-1, n_sampled, nrois, 4)
    # Drop proposals computed for padded rows (no-op when nothing was padded).
    return proposals[:n]


def _all_reduce_meters(meters, device):
    """Sum each meter's ``sum`` and ``count`` across ranks and recompute ``avg``.

    All meters are reduced in a single ``all_reduce`` collective, so every rank
    must call this function. A meter whose global count is zero keeps its
    local ``avg`` rather than dividing by zero.
    """
    local = []
    for meter in meters:
        local.extend((meter.sum, meter.count))
    stats = torch.tensor(local, dtype=torch.float32, device=device)
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    totals = stats.tolist()
    for meter, total, count in zip(meters, totals[0::2], totals[1::2]):
        if count > 0:
            meter.avg = total / count


def val_epoch(epoch,
              data_loader,
              model,
              criterion,
              device,
              logger,
              tb_writer=None,
              distributed=False,
              rpn=None,
              det_interval=2,
              nrois=10):
    """Evaluate *model* on *data_loader* for one epoch and return the mean loss.

    Args:
        epoch: Epoch index. Used only for printing and logging.
        data_loader: Yields ``(inputs, targets)`` batches; ``len()`` is used
            for progress output.
        model: Put into eval mode. Called as ``model(inputs)``, or as
            ``model(inputs, proposals)`` when *rpn* is given.
        criterion: Loss function, called as ``criterion(outputs, targets)``.
        device: Device that ``targets`` and the distributed reduction tensor
            are placed on.
        logger: Object with a ``log(dict)`` method that receives
            epoch/loss/acc, or ``None`` to skip logging.
        tb_writer: Writer with ``add_scalar`` that receives ``val/loss`` and
            ``val/acc``, or ``None``.
        distributed: If True, loss and accuracy averages are all-reduced
            across ranks. This requires an initialized ``torch.distributed``
            process group.
        rpn: Optional region-proposal network whose proposals are passed to
            *model*. Put into eval mode.
        det_interval: Temporal stride of the frames fed to *rpn*.
        nrois: Proposals per sampled frame, used to reshape the *rpn* output.

    Returns:
        ``losses.avg``: the averaged validation loss, computed across all
        ranks when *distributed* is True.
    """
    print('validation at epoch {}'.format(epoch))

    model.eval()
    if rpn is not None:
        rpn.eval()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()

    # Largest clip batch size seen so far; RPN inputs are padded up to it.
    max_n = 0
    end_time = time.time()
    # The whole loop runs without autograd, so no further no_grad/detach is needed.
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(data_loader):
            data_time.update(time.time() - end_time)

            # NOTE(review): only targets are moved to `device`; inputs are
            # presumably placed by a DataParallel/DDP wrapper — confirm.
            targets = targets.to(device, non_blocking=True)
            if rpn is not None:
                proposals = _rpn_proposals(rpn, inputs, det_interval, nrois,
                                           pad_to=max_n)
                outputs = model(inputs, proposals)
                max_n = max(max_n, inputs.size(0))
            else:
                outputs = model(inputs)

            loss = criterion(outputs, targets)
            acc = calculate_accuracy(outputs, targets)

            losses.update(loss.item(), inputs.size(0))
            accuracies.update(acc, inputs.size(0))

            batch_time.update(time.time() - end_time)
            end_time = time.time()

            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc {acc.val:.3f} ({acc.avg:.3f})'.format(
                      epoch,
                      i + 1,
                      len(data_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      acc=accuracies))

    if distributed:
        _all_reduce_meters((losses, accuracies), device)

    if logger is not None:
        logger.log({'epoch': epoch, 'loss': losses.avg, 'acc': accuracies.avg})

    if tb_writer is not None:
        tb_writer.add_scalar('val/loss', losses.avg, epoch)
        tb_writer.add_scalar('val/acc', accuracies.avg, epoch)

    return losses.avg