cuda_kernel = """
extern "C" __global__
void square_kernel(const float* __restrict__ input, float* __restrict__ output, int size) {
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < size) {
output[index] = input[index] * input[index];
}
}
"""
import torch
import torch.utils.cpp_extension
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
module = torch.utils.cpp_extension.load_inline(
name='square',
cpp_sources='',
cuda_sources=cuda_kernel,
functions=['square_kernel']
)
def square(input):
output = torch.empty_like(input)
threads_per_block = 1024
blocks_per_grid = (input.numel() + (threads_per_block - 1)) // threads_per_block
module.square_kernel(blocks_per_grid, threads_per_block, input, output, input.numel())
return output
# Example usage
input_tensor = torch.randn(100, device=device)
output_tensor = square(input_tensor)
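
# A minimal correctness check (a sketch that assumes a CUDA device is available,
# since the kernel only operates on GPU tensors); compares against the reference result.
expected = input_tensor * input_tensor
print("max abs error:", (output_tensor - expected).abs().max().item())
assert torch.allclose(output_tensor, expected)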