inside voxel bug fix #634

Merged
merged 1 commit into from Oct 18, 2022
156 changes: 98 additions & 58 deletions kaolin/csrc/render/spc/raytrace_cuda.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES.
+// Copyright (c) 2021,22 NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
@@ -84,14 +84,12 @@ decide_cuda_kernel(
const float3* __restrict__ ray_o,
const float3* __restrict__ ray_d,
const uint2* __restrict__ nuggets,
-float* depth,
uint* __restrict__ info,
const uint8_t* __restrict__ octree,
const uint32_t level,
const uint32_t not_done) {

uint tidx = blockDim.x * blockIdx.x + threadIdx.x;
-const float eps = 1e-8;

if (tidx < num) {
uint ridx = nuggets[tidx].x;
@@ -101,7 +99,59 @@ decide_cuda_kernel(
float3 d = ray_d[ridx];

// Radius of voxel
-float r = 1.0 / ((float)(0x1 << level)) + eps;
+float r = 1.0 / ((float)(0x1 << level));

// Transform to [-1, 1]
const float3 vc = make_float3(
fmaf(r, fmaf(2.0, p.x, 1.0), -1.0f),
fmaf(r, fmaf(2.0, p.y, 1.0), -1.0f),
fmaf(r, fmaf(2.0, p.z, 1.0), -1.0f));

// Compute aux info (precompute to optimize)
float3 sgn = ray_sgn(d);
float3 ray_inv = make_float3(1.0 / d.x, 1.0 / d.y, 1.0 / d.z);

float depth = ray_aabb(o, d, ray_inv, sgn, vc, r);

if (not_done){
if (depth != 0.0)
info[tidx] = __popc(octree[pidx]);
else
info[tidx] = 0;
}
else { // at bottom
if (depth > 0.0)
info[tidx] = 1;
else
info[tidx] = 0;
}
}
}
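The kernel above is the core of the fix. A voxel with integer coordinates p at this level has half-extent r = 1/2^level and center r*(2p+1) - 1 in the normalized [-1, 1] cube (exactly what the fmaf calls compute), and the old eps padding of r is gone. While levels remain (not_done nonzero), any nonzero depth keeps the proposal alive, so a voxel containing the ray origin is still expanded; only the bottom level requires a strictly positive entry depth. The bodies of ray_sgn and ray_aabb are outside this diff; below is a minimal slab-method sketch consistent with the call site, assuming the convention the branch above relies on: positive entry t when the origin is outside the box, negative t when the origin is inside, 0.0f on a miss.

// Sketch only: kaolin's real ray_aabb is defined elsewhere in the sources.
// Assumed convention: t > 0 entry from outside, t < 0 origin inside, 0.0f miss.
static __device__ float ray_aabb_sketch(
    const float3 o,    // ray origin
    const float3 d,    // ray direction
    const float3 inv,  // precomputed 1/d per component
    const float3 sgn,  // precomputed component-wise sign of d (+1/-1)
    const float3 c,    // box center
    const float r) {   // box half-extent
  // Slab method: the sign picks each axis' near/far face without branching.
  float tmin = (c.x - sgn.x * r - o.x) * inv.x;
  float tmax = (c.x + sgn.x * r - o.x) * inv.x;
  float tymin = (c.y - sgn.y * r - o.y) * inv.y;
  float tymax = (c.y + sgn.y * r - o.y) * inv.y;
  if (tmin > tymax || tymin > tmax) return 0.0f;   // miss
  tmin = fmaxf(tmin, tymin);
  tmax = fminf(tmax, tymax);
  float tzmin = (c.z - sgn.z * r - o.z) * inv.z;
  float tzmax = (c.z + sgn.z * r - o.z) * inv.z;
  if (tmin > tzmax || tzmin > tmax) return 0.0f;   // miss
  tmin = fmaxf(tmin, tzmin);
  tmax = fminf(tmax, tzmax);
  if (tmax < 0.0f) return 0.0f;  // box entirely behind the origin
  return tmin;                   // negative exactly when the origin is inside
}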

// Overloaded version of function above that returns depth of voxel/ ray entry points
__global__ void
decide_cuda_kernel(
const uint num,
const point_data* __restrict__ points,
const float3* __restrict__ ray_o,
const float3* __restrict__ ray_d,
const uint2* __restrict__ nuggets,
float* depth,
uint* __restrict__ info,
const uint8_t* __restrict__ octree,
const uint32_t level) {

uint tidx = blockDim.x * blockIdx.x + threadIdx.x;

if (tidx < num) {
uint ridx = nuggets[tidx].x;
uint pidx = nuggets[tidx].y;
point_data p = points[pidx];
float3 o = ray_o[ridx];
float3 d = ray_d[ridx];

// Radius of voxel
float r = 1.0 / ((float)(0x1 << level));

// Transform to [-1, 1]
const float3 vc = make_float3(
@@ -117,14 +167,14 @@

// Perform AABB check
if (depth[tidx] > 0.0){
-// Count # of occupied voxels for expansion, if more levels are left
-info[tidx] = not_done ? __popc(octree[pidx]) : 1;
+info[tidx] = 1; // mark to keep
} else {
info[tidx] = 0;
}
}
}
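All three decide_cuda_kernel overloads share one traversal contract: info[tidx] is the number of outputs the nugget contributes to the next round. During descent that count comes from the SPC occupancy byte, where each set bit of octree[pidx] marks an occupied child octant, so __popc yields 0 to 8 children; at the target level it collapses to a keep/discard flag. A small illustrative sketch of that counting rule (the helper name is hypothetical, not kaolin's API):

// Illustrative sketch: how a proposal's output count is derived.
__device__ uint proposal_count_sketch(
    const uint8_t* __restrict__ octree,  // one occupancy byte per node
    const uint pidx,                     // node index of this nugget
    const bool hit,                      // did the ray touch the voxel?
    const bool at_bottom) {              // reached the target level?
  if (!hit) return 0;                        // culled: writes nothing
  return at_bottom ? 1                       // keep the nugget itself
                   : __popc(octree[pidx]);   // expand into occupied children
}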

// Overloaded version of function above that returns depth of voxel/ ray entry and exit points
__global__ void
decide_cuda_kernel(
const uint num,
@@ -135,11 +185,9 @@ decide_cuda_kernel(
float2* __restrict__ depth,
uint* __restrict__ info,
const uint8_t* __restrict__ octree,
-const uint32_t level,
-const uint32_t not_done) {
+const uint32_t level) {

uint tidx = blockDim.x * blockIdx.x + threadIdx.x;
-const float eps = 1e-8;

if (tidx < num) {
uint ridx = nuggets[tidx].x;
@@ -149,7 +197,7 @@
float3 d = ray_d[ridx];

// Radius of voxel
-float r = 1.0 / ((float)(0x1 << level)) + eps;
+float r = 1.0 / ((float)(0x1 << level));

// Transform to [-1, 1]
const float3 vc = make_float3(
@@ -165,8 +213,7 @@

// Perform AABB check
if (depth[tidx].x > 0.0 && depth[tidx].y > 0.0){
-// Count # of occupied voxels for expansion, if more levels are left
-info[tidx] = not_done ? __popc(octree[pidx]) : 1;
+info[tidx] = 1; // mark to keep
} else {
info[tidx] = 0;
}
@@ -435,10 +482,6 @@ cumsum_reverse_cuda_kernel(
}
}

-////////////////////////////////////////////////////////////////////////////////////////////////
-/// CUDA Implementations
-////////////////////////////////////////////////////////////////////////////////////////////////

std::vector<at::Tensor> raytrace_cuda_impl(
at::Tensor octree,
at::Tensor points,
@@ -455,25 +498,21 @@ std::vector<at::Tensor> raytrace_cuda_impl(

uint8_t* octree_ptr = octree.data_ptr<uint8_t>();
point_data* points_ptr = reinterpret_cast<point_data*>(points.data_ptr<short>());
-uint* pyramid_ptr = (uint*)pyramid.data_ptr<int>();
-uint* pyramid_sum = pyramid_ptr + max_level + 2;
uint* exclusive_sum_ptr = reinterpret_cast<uint*>(exclusive_sum.data_ptr<int>());
float3* ray_o_ptr = reinterpret_cast<float3*>(ray_o.data_ptr<float>());
float3* ray_d_ptr = reinterpret_cast<float3*>(ray_d.data_ptr<float>());


// allocate local GPU storage
at::Tensor nuggets0 = at::empty({num, 2}, octree.options().dtype(at::kInt));
-uint2* nuggets0_ptr = reinterpret_cast<uint2*>(nuggets0.data_ptr<int>());
at::Tensor nuggets1;

uint depth_dim = with_exit ? 2 : 1;
-at::Tensor depths0 = at::empty({num, depth_dim}, octree.options().dtype(at::kFloat));
-float* depth0_ptr = depths0.data_ptr<float>();
+at::Tensor depths0;
+at::Tensor depths1;

// Generate proposals (first proposal is root node)
-init_nuggets_cuda_kernel<<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(num, nuggets0_ptr);
+init_nuggets_cuda_kernel<<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(
+num, reinterpret_cast<uint2*>(nuggets0.data_ptr<int>()));

uint cnt, buffer = 0;
for (uint32_t l = 0; l <= target_level; l++) {
@@ -482,24 +521,31 @@
uint* info_ptr = reinterpret_cast<uint*>(info.data_ptr<int>());

// Do the proposals hit?
-if (with_exit) {
-decide_cuda_kernel<<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(
-num, points_ptr, ray_o_ptr, ray_d_ptr, nuggets0_ptr,
-reinterpret_cast<float2*>(depth0_ptr),
-info_ptr, octree_ptr, l, target_level - l);
-} else {
-decide_cuda_kernel<<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(
-num, points_ptr, ray_o_ptr, ray_d_ptr, nuggets0_ptr,
-depth0_ptr, info_ptr, octree_ptr, l, target_level - l);
+if (l == target_level && return_depth) {
+depths0 = at::empty({num, depth_dim}, octree.options().dtype(at::kFloat));

+if (with_exit) {
+decide_cuda_kernel<<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(
+num, points_ptr, ray_o_ptr, ray_d_ptr, reinterpret_cast<uint2*>(nuggets0.data_ptr<int>()),
+reinterpret_cast<float2*>(l == target_level ? depths0.data_ptr<float>() : 0),
+info_ptr, octree_ptr, l);
+} else {
+decide_cuda_kernel<<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(
+num, points_ptr, ray_o_ptr, ray_d_ptr, reinterpret_cast<uint2*>(nuggets0.data_ptr<int>()),
+l == target_level ? depths0.data_ptr<float>() : 0, info_ptr, octree_ptr, l);
+}
+}
+else {
+decide_cuda_kernel<<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(
+num, points_ptr, ray_o_ptr, ray_d_ptr, reinterpret_cast<uint2*>(nuggets0.data_ptr<int>()),
+info_ptr, octree_ptr, l, target_level - l);
+}
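Note the restructuring of the launch site: depth output is now allocated and written only at the target level, and intermediate levels call the new depth-free overload. The per-level depth buffers drop out of the hot loop entirely, and when return_depth is false no depth tensor is allocated at all.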


at::Tensor prefix_sum = at::empty({num+1}, octree.options().dtype(at::kInt));
uint* prefix_sum_ptr = reinterpret_cast<uint*>(prefix_sum.data_ptr<int>());

// set first element to zero
-CubDebugExit(cudaMemcpy(prefix_sum_ptr, &buffer, sizeof(uint),
-cudaMemcpyHostToDevice));
+CubDebugExit(cudaMemcpy(prefix_sum_ptr, &buffer, sizeof(uint), cudaMemcpyHostToDevice));

// set up memory for DeviceScan calls
void* temp_storage_ptr = NULL;
@@ -508,61 +554,55 @@
at::Tensor temp_storage = at::empty({(int64_t)temp_storage_bytes}, octree.options());
temp_storage_ptr = (void*)temp_storage.data_ptr<uint8_t>();


CubDebugExit(cub::DeviceScan::InclusiveSum(
temp_storage_ptr, temp_storage_bytes, info_ptr,
prefix_sum_ptr + 1, num)); //start sum on second element
cudaMemcpy(&cnt, prefix_sum_ptr + num, sizeof(uint), cudaMemcpyDeviceToHost);
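The two InclusiveSum calls follow CUB's standard size-query idiom: the first call, with a NULL temp-storage pointer, only writes the required scratch size into temp_storage_bytes; the tensor allocated between the calls provides that scratch; the second call performs the scan, and the final element read back into cnt is the total survivor count. Combined with the zero written into prefix_sum[0] beforehand, the inclusive scan starting at element 1 yields an exclusive offset table. A self-contained sketch of the same pattern (the helper name is illustrative, and it assumes this file's CUB and ATen includes):

// Sketch of the scan idiom used above. After it runs, out[0] == 0 and
// out[i] == in[0] + ... + in[i-1] for i >= 1, i.e. an exclusive prefix
// sum with the grand total in out[n].
void exclusive_offsets_sketch(const uint* d_in, uint* d_out, int n,
                              at::TensorOptions byte_opts) {
  const uint zero = 0;
  cudaMemcpy(d_out, &zero, sizeof(uint), cudaMemcpyHostToDevice);  // out[0] = 0
  void* d_temp = NULL;
  size_t temp_bytes = 0;
  // First call: NULL temp storage, so CUB only computes temp_bytes.
  cub::DeviceScan::InclusiveSum(d_temp, temp_bytes, d_in, d_out + 1, n);
  at::Tensor temp = at::empty({(int64_t)temp_bytes}, byte_opts);   // uint8 scratch
  d_temp = (void*)temp.data_ptr<uint8_t>();
  // Second call: the actual scan, written starting at out[1].
  cub::DeviceScan::InclusiveSum(d_temp, temp_bytes, d_in, d_out + 1, n);
}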

// allocate local GPU storage
nuggets1 = at::empty({cnt, 2}, octree.options().dtype(at::kInt));
-uint2* nuggets1_ptr = reinterpret_cast<uint2*>(nuggets1.data_ptr<int>());

-depths1 = at::empty({cnt, depth_dim}, octree.options().dtype(at::kFloat));
-float* depth1_ptr = depths1.data_ptr<float>();

-if (cnt == 0)
-{
-num = cnt;
-break; // miss everything
+// miss everything
+if (cnt == 0) {
+num = 0;
+nuggets0 = nuggets1;
+if (return_depth) depths1 = at::empty({0, depth_dim}, octree.options().dtype(at::kFloat));
+break;
+}

// Subdivide if more levels remain, repeat
if (l < target_level) {
subdivide_cuda_kernel<<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(
-num, nuggets0_ptr, nuggets1_ptr, ray_o_ptr, points_ptr,
+num, reinterpret_cast<uint2*>(nuggets0.data_ptr<int>()), reinterpret_cast<uint2*>(nuggets1.data_ptr<int>()), ray_o_ptr, points_ptr,
octree_ptr, exclusive_sum_ptr, info_ptr, prefix_sum_ptr, l);
} else {
compactify_cuda_kernel<uint2><<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(
-num, nuggets0_ptr, nuggets1_ptr,
+num, reinterpret_cast<uint2*>(nuggets0.data_ptr<int>()), reinterpret_cast<uint2*>(nuggets1.data_ptr<int>()),
info_ptr, prefix_sum_ptr);
if (return_depth) {
+depths1 = at::empty({cnt, depth_dim}, octree.options().dtype(at::kFloat));

if (with_exit) {
compactify_cuda_kernel<float2><<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(
-num, reinterpret_cast<float2*>(depth0_ptr),
-reinterpret_cast<float2*>(depth1_ptr),
+num, reinterpret_cast<float2*>(depths0.data_ptr<float>()),
+reinterpret_cast<float2*>(depths1.data_ptr<float>()),
info_ptr, prefix_sum_ptr);
} else {
compactify_cuda_kernel<float><<<(num + RT_NUM_THREADS - 1) / RT_NUM_THREADS, RT_NUM_THREADS>>>(
-num, depth0_ptr, depth1_ptr,
+num, depths0.data_ptr<float>(), depths1.data_ptr<float>(),
info_ptr, prefix_sum_ptr);
}
}
}
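compactify_cuda_kernel is instantiated here for uint2 nuggets and for float/float2 depths, but its body is outside this diff. Its contract follows from the scan above: prefix_sum[i] is the exclusive prefix of info at i, i.e. the output slot of element i if it survives. A minimal sketch under that assumption (not necessarily the file's exact implementation):

// Sketch: scatter surviving elements to their scanned offsets.
// prefix_sum[i] is the exclusive prefix of info at i because the host
// wrote a leading zero and started the inclusive scan at element 1.
template<typename T>
__global__ void compactify_sketch(
    const uint num,
    const T* __restrict__ in,
    T* __restrict__ out,
    const uint* __restrict__ info,         // 1 = keep, 0 = drop
    const uint* __restrict__ prefix_sum) { // exclusive offsets
  uint tidx = blockDim.x * blockIdx.x + threadIdx.x;
  if (tidx < num && info[tidx])
    out[prefix_sum[tidx]] = in[tidx];
}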

CubDebugExit(cudaGetLastError());

-nuggets0_ptr = nuggets1_ptr;
-depth0_ptr = depth1_ptr;
+nuggets0 = nuggets1;
num = cnt;
}

if (return_depth) {
-return { nuggets1.index({Slice(0, num)}).contiguous(),
+return { nuggets0.index({Slice(0, num)}).contiguous(),
depths1.index({Slice(0, num)}).contiguous() };
} else {
-return { nuggets1.index({Slice(0, num)}).contiguous() };
+return { nuggets0.index({Slice(0, num)}).contiguous() };
}
}
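The returns also switch from nuggets1 to nuggets0: the loop now ends each iteration with nuggets0 = nuggets1, so nuggets0 always names the newest buffer, and the cnt == 0 early exit can hand back a valid empty tensor. The raw nuggets0_ptr/depth0_ptr swap it replaces risked stale device pointers once the next iteration reallocated nuggets1 and depths1.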

57 changes: 56 additions & 1 deletion tests/python/kaolin/render/spc/test_raytrace.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES.
+# Copyright (c) 2021,22 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -230,6 +230,61 @@ def test_raytrace_with_depth_with_exit(self, octree, point_hierarchy, pyramid, e

assert torch.equal(depth, expected_depth)

@pytest.mark.parametrize('return_depth,with_exit', [(False, False), (True, False), (True, True)])
def test_raytrace_inside(self, octree, point_hierarchy, pyramid, exsum, return_depth, with_exit):
height = 4
width = 4
direction = torch.tensor([[0., 0., -1.]], dtype=torch.float,
device='cuda').repeat(height * width , 1)
origin = self._generate_rays_origin(height, width, 0.9)
outputs = unbatched_raytrace(
octree, point_hierarchy, pyramid, exsum, origin, direction, 2,
return_depth=return_depth, with_exit=with_exit)

ridx = outputs[0]
pidx = outputs[1]

expected_nuggets = torch.tensor([
[ 0, 13],
[ 0, 6],
[ 0, 5],
[ 1, 8],
[ 1, 7],
[ 2, 15],
[ 4, 10],
[ 4, 9],
[ 5, 12],
[ 5, 11]], device='cuda', dtype=torch.int)
assert torch.equal(ridx, expected_nuggets[...,0])
assert torch.equal(pidx, expected_nuggets[...,1])
if return_depth:
depth = outputs[2]
if with_exit:
expected_depth = torch.tensor([
[0.4, 0.9],
[0.9, 1.4],
[1.4, 1.9],
[0.9, 1.4],
[1.4, 1.9],
[1.4, 1.9],
[0.9, 1.4],
[1.4, 1.9],
[0.9, 1.4],
[1.4, 1.9]], device='cuda', dtype=torch.float)
else:
expected_depth = torch.tensor([
[0.4],
[0.9],
[1.4],
[0.9],
[1.4],
[1.4],
[0.9],
[1.4],
[0.9],
[1.4]], device='cuda', dtype=torch.float)
assert torch.allclose(depth, expected_depth)
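A worked check of those constants, assuming the level-2 grid spans [-1, 1] so each voxel has edge length 0.5: the rays start at z = 0.9 marching along -z, so an occupied voxel whose near face lies at z = 0.5 is entered at t = 0.9 - 0.5 = 0.4 and exited at its far face z = 0.0 at t = 0.9; successive voxels follow at 0.5 increments, giving the (0.4, 0.9), (0.9, 1.4), (1.4, 1.9) rows. Before this fix, rays originating inside the octree extent, the case this test adds, could be culled during descent.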

def test_ambiguous_raytrace(self):
# TODO(ttakikawa):
# Since 0.10.0, the behaviour of raytracing exactly between voxels