-
Notifications
You must be signed in to change notification settings - Fork 42
/
Copy pathinfer.py
65 lines (47 loc) · 1.78 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from config import Config
from s2s_dataset import gen
from torch.utils.data import DataLoader, Dataset
from models.Seq2Seq import Attention, Encoder, Decoder, Seq2Seq
def count_parameters(model):
    """Return how many trainable scalar values *model* holds.

    Only parameters with ``requires_grad`` set are counted; each contributes
    its element count (``numel()``).
    """
    total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    return total
def init_weights(m):
    """Re-initialise every parameter reachable from module *m*, in place.

    A parameter whose qualified name contains ``'weight'`` is drawn from a
    normal distribution with mean 0 and std 0.01. Every other parameter
    (e.g. biases) is set to the constant 0.
    """
    for param_name, tensor in m.named_parameters():
        if 'weight' not in param_name:
            nn.init.constant_(tensor.data, 0)
            continue
        nn.init.normal_(tensor.data, mean=0, std=0.01)
def build_model(config, device):
    """Construct the attention Seq2Seq model described by *config* on *device*."""
    attn = Attention(config.s2s_enc_hid, config.s2s_dec_hid)
    enc = Encoder(config.s2s_emb_dim,
                  config.s2s_enc_hid,
                  config.s2s_dec_hid,
                  config.s2s_enc_dropout)
    # NOTE(review): the decoder is given s2s_enc_dropout rather than a
    # decoder-specific value — confirm this is intended. It has no effect at
    # inference time because the model is put in eval() mode before use.
    dec = Decoder(len(config.class_char),
                  config.s2s_emb_dim,
                  config.s2s_enc_hid,
                  config.s2s_dec_hid,
                  config.s2s_enc_dropout,
                  attn)
    return Seq2Seq(enc, dec, device).to(device)


def decode_predictions(output, num_classes):
    """Return per-position ``(max softmax probability, argmax class id)``.

    *output* is permuted with ``permute(1, 0, 2)`` before flattening, so it is
    presumably laid out as (seq_len, batch, num_classes) — TODO confirm
    against Seq2Seq.forward. Both returned 1-D tensors are batch-major: all
    positions of sample 0, then all positions of sample 1, and so on.
    """
    flat = output.permute(1, 0, 2).reshape(-1, num_classes)
    possible, label = F.softmax(flat, dim=1).max(dim=1)
    return possible, label


def sequence_accuracy(label, trg):
    """Return the fraction of positions where *label* equals *trg*.

    *trg* is flattened before comparison so it is matched element-wise
    against the 1-D *label*. Without flattening, a (seq_len, 1) target would
    broadcast against *label* into a (seq_len, seq_len) comparison and give a
    meaningless score.

    Raises:
        ValueError: if *trg* and *label* hold different numbers of elements.
    """
    target = trg.reshape(-1).long()
    if target.numel() != label.numel():
        raise ValueError(
            f"target has {target.numel()} elements but prediction has "
            f"{label.numel()}"
        )
    return np.mean((label == target).cpu().numpy())


if __name__ == "__main__":
    config = Config()
    device = config.device
    model = build_model(config, device)
    # load_state_dict (strict by default) overwrites every parameter, so no
    # random initialisation is needed first. map_location lets a checkpoint
    # saved on GPU be loaded when config.device is the CPU.
    model.load_state_dict(torch.load('weight/s2s.pt', map_location=device))
    model.eval()
    data = gen(["data/test/1.json"], 1, config.max_box_num, device)
    with torch.no_grad():
        src, trg = next(data)
        output = model(src)
        possible, label = decode_predictions(output, len(config.class_char))
        acc = sequence_accuracy(label, trg)
    print("target:", trg.long())
    print("label:", label)
    print("possible:", possible)
    print("acc:", acc)