-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathinit.lua
78 lines (67 loc) · 2.41 KB
/
init.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
local ffi = require 'ffi'
local cutorch = require 'cutorch'
local C = require 'cutorch-rtc/ffi'
-- FFI box type used for each supported CUDA tensor argument: a one-slot
-- array holding the tensor's device data pointer (v:data()).
local tensor_arg_ctypes = {
   ['torch.CudaTensor']       = 'float*[1]',
   ['torch.CudaHalfTensor']   = 'half*[1]',
   ['torch.CudaDoubleTensor'] = 'double*[1]',
   ['torch.CudaIntTensor']    = 'int*[1]',
   ['torch.CudaByteTensor']   = 'uint8*[1]',
   ['torch.CudaCharTensor']   = 'int8*[1]',
}

-- Validates a launch-dimension table and packs it into an int[3]; trailing
-- dimensions that are not given stay at 1.
local function pack_launch_dims(dims, what)
   -- The destination is a fixed int[3] and FFI arrays are not bounds
   -- checked, so more than 3 entries must be rejected up front.
   assert(torch.type(dims) == 'table' and #dims > 0 and #dims <= 3,
          what .. ' must be a table of 1 to 3 dimensions')
   local packed = ffi.new('int[3]', 1)
   for i, v in ipairs(dims) do
      packed[i - 1] = v
   end
   return packed
end

--- Launch kernel `kernel_name` from the PTX module `ptx` via C.launchPTX.
-- @param ptx PTX module source, passed unchanged to C.launchPTX
-- @param kernel_name name of the kernel entry point inside `ptx`
-- @param arguments array of kernel arguments, in kernel parameter order.
--   Each element is one of:
--   * a torch.Cuda{,Half,Double,Int,Byte,Char}Tensor: its data pointer is
--     boxed and passed;
--   * a table {ctype, value}, e.g. {'int', n}: `value` is boxed as `ctype`;
--   * a cdata: stored into the argument array unchanged (presumably a
--     pointer to the argument value, cuLaunchKernel-style; TODO confirm).
-- @param gridDim table of 1 to 3 grid dimensions (missing ones default to 1)
-- @param blockDim table of 1 to 3 block dimensions (missing ones default
--   to 1); their product must not exceed 1024
function cutorch.launchPTX(ptx, kernel_name, arguments, gridDim, blockDim)
   local grid = pack_launch_dims(gridDim, 'gridDim')
   local block = pack_launch_dims(blockDim, 'blockDim')
   assert(torch.Tensor(blockDim):prod() <= 1024,
          'blockDim must not exceed 1024 threads in total')

   local args = ffi.new('void*[?]', #arguments)
   -- Storing a cdata array into `args` only records its address. The GC does
   -- not follow that pointer, so every boxed value is also referenced from
   -- this Lua table until the launch has returned. Without this, a later
   -- ffi.new in the loop could collect an earlier box.
   local keepalive = {}
   for i, v in ipairs(arguments) do
      local vtype = torch.type(v)
      local box_ctype = tensor_arg_ctypes[vtype]
      local boxed
      if box_ctype then
         boxed = ffi.new(box_ctype, v:data())
      elseif vtype == 'table' then
         -- {ctype, value}: box a single value of the given C type
         boxed = ffi.new(v[1] .. '[1]', v[2])
      elseif vtype == 'cdata' then
         -- caller-provided cdata, still referenced by `arguments`
         boxed = v
      else
         --TODO: add textures
         error('unsupported kernel argument #' .. i .. ': ' .. vtype)
      end
      keepalive[i] = boxed
      args[i - 1] = boxed
   end

   C.launchPTX(cutorch.getState(), ptx, kernel_name, args, grid, block)
end
-- Tensor classes that receive the pointwise apply1/apply2/apply3 methods.
local types = {
   'CudaTensor',
   'CudaHalfTensor',
   'CudaDoubleTensor',
}

for _, tname in ipairs(types) do
   -- All C entry points for this class share this prefix. The arity digit is
   -- appended inside each method, so the FFI symbol in `C` is looked up only
   -- when the method is actually called.
   local symbol_prefix = 'TH' .. tname .. '_pointwiseApply'
   local tensor_class = torch[tname]

   --- Run the expression source `lambda` through TH<type>_pointwiseApply1
   -- on self. Whether self is modified in place is decided by the C side
   -- (TODO confirm).
   -- @param lambda expression source string
   -- @return self, for chaining
   function tensor_class:apply1(lambda)
      assert(type(lambda) == 'string')
      C[symbol_prefix .. '1'](cutorch.getState(), self:cdata(), lambda)
      return self
   end

   --- Like apply1, but also passes tensor `b` to
   -- TH<type>_pointwiseApply2.
   -- @param b second tensor (its cdata() handle is forwarded)
   -- @param lambda expression source string
   -- @return self, for chaining
   function tensor_class:apply2(b, lambda)
      assert(type(lambda) == 'string')
      C[symbol_prefix .. '2'](cutorch.getState(), self:cdata(), b:cdata(),
                              lambda)
      return self
   end

   --- Like apply1, but also passes tensors `b` and `c` to
   -- TH<type>_pointwiseApply3.
   -- @param b second tensor (its cdata() handle is forwarded)
   -- @param c third tensor (its cdata() handle is forwarded)
   -- @param lambda expression source string
   -- @return self, for chaining
   function tensor_class:apply3(b, c, lambda)
      assert(type(lambda) == 'string')
      C[symbol_prefix .. '3'](cutorch.getState(), self:cdata(), b:cdata(),
                              c:cdata(), lambda)
      return self
   end
end