forked from torch/cutorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathFFI.lua
81 lines (65 loc) · 1.91 KB
/
FFI.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
--- LuaJIT FFI bindings for cutorch's C structures.
-- When the `ffi` module can be required (i.e. running under LuaJIT), this
-- declares the C layouts of THCState, THCudaStorage and THCudaTensor and
-- installs `cdata()` / `data()` accessors on torch.CudaStorage and
-- torch.CudaTensor. If `require 'ffi'` fails (plain Lua), nothing happens.
--
-- NOTE(review): ffi does not validate these layouts; they must match the
-- THC C headers this file ships with field-for-field — TODO keep in sync.
local ok, ffi = pcall(require, 'ffi')
if ok then
   -- Only one definition of cublasHandle_t is kept: a second, conflicting
   -- typedef would be silently ignored by LuaJIT and rejected by a C compiler.
   -- THAllocator is referenced via `struct THAllocator` so this cdef does not
   -- depend on another module having declared a `THAllocator` typedef first.
   local cdefs = [[
typedef struct CUstream_st *cudaStream_t;

struct cublasContext;
typedef struct cublasContext *cublasHandle_t;

typedef struct _THCCudaResourcesPerDevice {
  cudaStream_t* streams;
  cublasHandle_t* blasHandles;
  size_t scratchSpacePerStream;
  void** devScratchSpacePerStream;
} THCCudaResourcesPerDevice;

typedef struct THCState
{
  struct THCRNGState* rngState;
  struct cudaDeviceProp* deviceProperties;
  cudaStream_t currentStream;
  cublasHandle_t currentBlasHandle;
  THCCudaResourcesPerDevice* resourcesPerDevice;
  int numDevices;
  int numUserStreams;
  int numUserBlasHandles;
  int currentPerDeviceStream;
  int currentPerDeviceBlasHandle;
  struct THAllocator* cudaHostAllocator;
} THCState;

cudaStream_t THCState_getCurrentStream(THCState *state);

typedef struct THCudaStorage
{
  float *data;
  long size;
  int refcount;
  char flag;
  struct THAllocator *allocator;
  void *allocatorContext;
  struct THCudaStorage *view;
} THCudaStorage;

typedef struct THCudaTensor
{
  long *size;
  long *stride;
  int nDimension;
  THCudaStorage *storage;
  long storageOffset;
  int refcount;
  char flag;
} THCudaTensor;
]]
   ffi.cdef(cdefs)

   -- Storage
   local Storage = torch.getmetatable('torch.CudaStorage')
   -- The userdata payload is treated as holding a THCudaStorage*, so casting
   -- it to THCudaStorage** and indexing [0] yields the C struct pointer.
   local Storage_tt = ffi.typeof('THCudaStorage**')

   --- Return the underlying `THCudaStorage*` as FFI cdata.
   rawset(Storage, "cdata", function(self)
      return Storage_tt(self)[0]
   end)

   --- Return the storage's raw `float*` data pointer (may be NULL cdata).
   rawset(Storage, "data", function(self)
      return Storage_tt(self)[0].data
   end)

   -- Tensor
   local Tensor = torch.getmetatable('torch.CudaTensor')
   local Tensor_tt = ffi.typeof('THCudaTensor**')

   --- Return the underlying `THCudaTensor*` as FFI cdata.
   rawset(Tensor, "cdata", function(self)
      return Tensor_tt(self)[0]
   end)

   --- Return a `float*` to the tensor's first element, or nil if the tensor
   -- has no storage. `storageOffset` is added with FFI pointer arithmetic,
   -- so it is counted in elements, not bytes.
   rawset(Tensor, "data", function(self)
      local tensor = Tensor_tt(self)[0]
      -- Comparing a cdata pointer with nil tests for NULL.
      if tensor.storage == nil then
         return nil
      end
      return tensor.storage.data + tensor.storageOffset
   end)
end