-- DataLoaderOneThread.lua (72 lines, 1.99 KB)
-- Module table; DataLoader is registered as a torch class inside it so
-- instances can carry per-split sampler state.
local M = {}
local DataLoader = torch.class('DataLoader', M)
-- Load the sampler definition into the current environment (Torch-style
-- include); __init below instantiates DataSampler from it.
paths.dofile('DataSampler.lua')
--------------------------------------------------------------------------------
-- function: create train/val data loaders
-- Builds one DataLoader per split and returns them as two values:
-- trainLoader, valLoader.
function DataLoader.create(config)
  local splits = {'train', 'val'}
  local loaders = {}
  for k = 1, #splits do
    loaders[k] = M.DataLoader(config, splits[k])
  end
  return table.unpack(loaders)
end
--------------------------------------------------------------------------------
-- function: init
-- Creates the per-split sampler and caches batching parameters on self.
function DataLoader:__init(config, split)
  -- Samplers and batches below use FloatTensor; make it the default type.
  torch.setdefaulttensortype('torch.FloatTensor')
  torch.manualSeed(config.seed)
  self.ds = DataSampler(config, split)
  self.__size = self.ds:size()  -- number of samples in this split
  self.batch = config.batch     -- samples per batch
  self.hfreq = config.hfreq     -- probability threshold for head selection in run()
end
--------------------------------------------------------------------------------
-- function: return size of dataset
-- Number of batches per epoch; the last batch may be partial, hence ceil.
function DataLoader:size()
  local batches = self.__size / self.batch
  return math.ceil(batches)
end
--------------------------------------------------------------------------------
-- function: run
-- Returns a stateful iterator over the split. Each call yields
-- (n, sample) where sample = {inputs, labels, head}, and nil at the end.
-- A single head is drawn per batch: 1 when a uniform draw exceeds
-- self.hfreq, otherwise 2, and every sample in the batch uses it.
function DataLoader:run()
  local total, bs = self.__size, self.batch
  local cursor = 1      -- index of the first sample of the next batch
  local batchCount = 0  -- batches yielded so far

  return function()
    if cursor > total then return nil end
    local thisBatch = math.min(bs, total - cursor + 1)

    -- Choose which head this whole batch targets.
    local head = (torch.uniform() > self.hfreq) and 1 or 2

    local inputs, labels
    for j = 1, thisBatch do
      local input, label = self.ds:get(head)
      if inputs == nil then
        -- Lazily allocate the batch tensors from the first sample's shape.
        inputs = torch.FloatTensor(thisBatch, table.unpack(input:size():totable()))
        labels = torch.FloatTensor(thisBatch, table.unpack(label:size():totable()))
      end
      inputs[j]:copy(input)
      labels[j]:copy(label)
    end

    -- Advance by the full batch size; overshooting past `total` on the
    -- final (partial) batch still terminates the iterator correctly.
    cursor = cursor + bs
    collectgarbage()
    batchCount = batchCount + 1
    return batchCount, {inputs = inputs, labels = labels, head = head}
  end
end
-- Export the class itself (not the module table M): callers of
-- paths.dofile/require on this file receive the DataLoader class directly.
return M.DataLoader