forked from facebookresearch/adaptive-softmax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgen_cont_sent_map_funky.lua
108 lines (97 loc) · 3.03 KB
/
gen_cont_sent_map_funky.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
require 'math'
require 'cutorch'
require 'nn'
require 'cunn'
require 'rnnlib'
local tablex = require 'pl.tablex'
local stringx = require 'pl.stringx'
local tnt = require 'torchnet'
local data = require 'data'
local utils = require 'utils'
local decoders = dofile 'decoders/init.lua'
-- Turn on torch's heap tracking before any tensors are allocated.
torch.setheaptracking(true)

-- Command-line interface. Only -seed, -dicpath, -modelpath, -contextpath,
-- -devid and -k are read by this script; -r, -g, -maxsteps, -cr and -v are
-- parsed but not referenced in the visible code.
local cmd = torch.CmdLine('-', '-')
cmd:option('-seed', 1, 'random gen seed')
cmd:option('-dicpath', '', 'Path to dictionary txt file')
cmd:option('-modelpath', '', 'Path to the model')
cmd:option('-contextpath', '', 'Path to text file with context words')
cmd:option('-devid', 1, 'GPU to use')
cmd:option('-k', 128, 'guesses to rerank')
cmd:option('-r', 0, 'reward per word')
cmd:option('-g', 0.5, 'reward decay')
-- The help text used to be a copy of the '-r' description.
cmd:option('-maxsteps', 100, 'max number of steps')
cmd:option('-cr', 0, 'reward per context word')
cmd:option('-v', 0, 'verbosity')
local config = cmd:parse(arg)

-- Select the GPU *before* seeding: cutorch.manualSeed seeds the RNG of the
-- current device, so seeding first would seed the default device instead of
-- the one requested with -devid.
cutorch.setDevice(config.devid)
torch.manualSeed(config.seed)
cutorch.manualSeed(config.seed)
-- Load the dictionary, failing fast when the path does not point to a file.
if not paths.filep(config.dicpath) then
   error('Dictionary not found!')
end
local dic = data.loaddictionary(config.dicpath)

-- Load the saved checkpoint and re-sort/threshold the dictionary with the
-- threshold stored in the checkpoint's config (falls back to 2 when absent).
local checkpoint = torch.load(config.modelpath)
dic = data.sortthresholddictionary(dic, checkpoint.config.threshold or 2)
collectgarbage()

-- Keep the model local: the original assignment leaked a global `model`.
local model = checkpoint['model']

-- Grab references to the sub-modules reused further down. The indices are
-- tied to the layout of the saved network.
-- NOTE(review): presumably lookup table / recurrent net / decoder, going by
-- the variable names -- confirm against the training script.
local lut = model.modules[1].modules[2].modules[1]
local rnn = model.modules[2]
local dec = model.modules[7]

model:cuda()
-- With no index, remove() drops the container's last module (assumes `model`
-- is an nn.Sequential). The references captured above remain valid.
model:remove()
model:evaluate()
collectgarbage()
-- Assemble the network handed to the beam search as its first argument.
-- It reuses `lut` and `rnn` from the loaded model.

-- Second input branch: lookup table followed by a split along dimension 1.
local lookup_branch = nn.Sequential()
lookup_branch:add(lut)
lookup_branch:add(nn.SplitTable(1))

-- The first input is passed through untouched; the second goes through the
-- lookup branch.
local input_stage = nn.ParallelTable()
input_stage:add(nn.Identity())
input_stage:add(lookup_branch)

-- After the rnn: take element 2 of its output table, then the last entry of
-- that table, and join the result along dimension 1.
local model2 = nn.Sequential()
model2:add(input_stage)
model2:add(rnn)
model2:add(nn.SelectTable(2))
model2:add(nn.SelectTable(-1))
model2:add(nn.JoinTable(1))
model2 = model2:cuda()
--- Turn the first tab-separated column of a line into dictionary indices.
-- Words in that column are separated by spaces. Words that map to index 2
-- are dropped.
-- NOTE(review): index 2 is presumably the unknown-word id -- confirm against
-- data.getidx.
-- @param line one raw line from the context file
-- @return a sequence (possibly empty) of dictionary indices
local function encode_context(line)
   local indices = {}
   -- Only the first non-empty tab-delimited field is used; the rest of the
   -- line is ignored.
   local column = line:match("[^\t]+")
   if column then
      for word in column:gmatch("[^ ]+") do
         local idx = data.getidx(dic, word)
         if idx ~= 2 then
            indices[#indices + 1] = idx
         end
      end
   end
   return indices
end

-- Decode one output line per line of the context file.
local ne = 0
local f = assert(io.open(config.contextpath, "r"))
for line in f:lines() do
   -- The context is always terminated with the end-of-sentence token.
   local term = data.getidx(dic, '</s>')
   local init_seq = encode_context(line)
   init_seq[#init_seq + 1] = term

   -- Template given to the beam search: the context followed by a lone
   -- end-of-sentence token.
   local template = {
      torch.CudaTensor(init_seq),
      torch.CudaTensor({term})
   }
   local best = decoders.template_beam_search(model2,
                                              rnn,
                                              dec,
                                              config.k,
                                              template,
                                              dic)

   -- Print the best hypothesis as space-separated words, one line per input.
   for j = 1, #best do
      io.write(dic.idx2word[best[j]] .. ' ')
   end
   print('')

   -- Run the collector every ten lines.
   ne = ne + 1
   if ne % 10 == 0 then collectgarbage() end
end
-- The handle used to be left open until process exit.
f:close()