Enhance JCat0 kernel to handle empty JIdx tensor case and add unit test for concatenation of JaggedTensors
fwilliams committed Dec 19, 2024
1 parent 18fc7dc commit 200cf24
Showing 2 changed files with 40 additions and 6 deletions.
17 changes: 11 additions & 6 deletions fvdb/src/detail/ops/JCat0.cu
@@ -2,11 +2,11 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "Ops.h"

#include <detail/utils/Utils.h>
#include <detail/utils/cuda/Utils.cuh>

#include <ATen/cuda/Atomic.cuh>

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

@@ -40,20 +40,24 @@ template <typename IdxT>
__global__ void
computeIndexPutArg(
const size_t jti, const JOffsetsType *__restrict__ const *__restrict__ offsets,
const size_t numOffsets,
const size_t numOffsets, const size_t numElements,
const TorchRAcc32<JIdxType, 1> inJIdxI, // Jidx of the i^th input tensor
const TorchRAcc32<JOffsetsType, 1> inJoffsetsI, // JOffsets of the i^th input tensor
const TorchRAcc32<JOffsetsType, 1> outJOffsets, // Output JOffsets (already computed earlier)
TorchRAcc32<IdxT, 1> outSelIdx, // Output selection indices
TorchRAcc32<JIdxType, 1> outJIdx) { // Output Jidx
int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t numElements = inJIdxI.size(0);
TorchRAcc32<JIdxType, 1> outJIdx // Output Jidx
) {
int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;

if (idx >= numElements) {
return;
}

const JIdxType jidx = inJIdxI[idx]; // Which tensor this element belongs to
// When you have a JaggedTensor that has only one tensor in it, the JIdx tensor is empty to
// save memory (it's effectively all zeros anyway). We need to handle this case, which is
// what this flag does.
const bool emptyJidx = inJIdxI.size(0) == 0 && numElements != 0;
const JIdxType jidx = emptyJidx ? 0 : inJIdxI[idx]; // Which tensor this element belongs to

// Where in the output tensor we're going to write to
JOffsetsType tensorWriteOffset = 0;
@@ -146,6 +150,7 @@ dispatchJCat0<torch::kCUDA>(const std::vector<JaggedTensor> &vec) {
GET_BLOCKS(numElements, numThreadsComputeIndexPutArg);
computeIndexPutArg<<<numBlocksComputeIndexPutArg, numThreadsComputeIndexPutArg>>>(
jti, thrust::raw_pointer_cast(offsets_d.data()), offsets_d.size(),
jt.jdata().size(0),
jt.jidx().packed_accessor32<JIdxType, 1, torch::RestrictPtrTraits>(),
jt.joffsets().packed_accessor32<JOffsetsType, 1, torch::RestrictPtrTraits>(),
outJOffsets.packed_accessor32<JOffsetsType, 1, torch::RestrictPtrTraits>(),
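The kernel change above leans on a JaggedTensor convention: when a JaggedTensor wraps a single tensor, its per-element JIdx tensor is left empty rather than materialized as all zeros, so the element count must now be passed in separately (jt.jdata().size(0)) instead of being read from jidx.size(0). Below is a minimal Python sketch of that convention and of the guard the kernel applies, assuming the Python bindings expose jidx and jdata properties analogous to the C++ jt.jidx() and jt.jdata() accessors; the helper element_to_tensor_index is purely illustrative.

import torch
import fvdb

# A JaggedTensor holding a single tensor stores an empty jidx; one holding
# several tensors stores one tensor index per element.
single = fvdb.JaggedTensor([torch.rand(5, 3)])
multi = fvdb.JaggedTensor([torch.rand(5, 3), torch.rand(2, 3)])

def element_to_tensor_index(jt):
    # Mirror of the kernel's emptyJidx guard: an empty jidx combined with a
    # non-empty jdata means every element belongs to tensor 0.
    if jt.jidx.numel() == 0 and jt.jdata.shape[0] != 0:
        return torch.zeros(jt.jdata.shape[0], dtype=jt.jidx.dtype, device=jt.jdata.device)
    return jt.jidx

print(element_to_tensor_index(single))  # zeros of length 5
print(element_to_tensor_index(multi))   # [0, 0, 0, 0, 0, 1, 1]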
29 changes: 29 additions & 0 deletions fvdb/tests/unit/test_jagged_tensor.py
@@ -85,6 +85,35 @@ def check_lshape(self, jt: fvdb.JaggedTensor, lt: List[torch.Tensor] | List[List
else:
assert False, "jagged tensor ldim should be 1 or 2"

@parameterized.expand(all_device_dtype_combos)
def test_jcat_along_dim_0_with_one_tensor(self, device, dtype):
batch_size = 1

# Make a point cloud with a random number of points
def get_pc(num_pc_list: list):
pc_list = []
for num_pc in num_pc_list:
pc_list.append(torch.rand((num_pc, 3)).to(device))
return pc_list

num_pc_list = torch.randint(low=50, high=1000, size=(batch_size,), device=device).cpu().tolist()

pc1_tensor_list = get_pc(num_pc_list)
pc2_tensor_list = get_pc(num_pc_list)

pc1_jagged = fvdb.JaggedTensor(pc1_tensor_list)
pc2_jagged = fvdb.JaggedTensor(pc2_tensor_list)

cat_dim = 0
concat_tensor_list = [
torch.cat([pc1_tensor_list[i], pc2_tensor_list[i]], dim=cat_dim) for i in range(batch_size)
]

jagged_from_concat_list = fvdb.JaggedTensor(concat_tensor_list)
jcat_result = fvdb.jcat([pc1_jagged, pc2_jagged], dim=cat_dim)

self.assertTrue(torch.equal(jagged_from_concat_list.jdata, jcat_result.jdata))

@parameterized.expand(all_device_dtype_combos)
def test_pickle(self, device, dtype):
jt, _ = self.mklol(7, 4, 8, device, dtype)
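For reference, the single-tensor case this test exercises reduces to the following usage; a short sketch under the same assumptions as above (single-tensor JaggedTensors, jdata holding the flat per-element data), mirroring the torch.cat reference the test builds.

import torch
import fvdb

a = fvdb.JaggedTensor([torch.rand(4, 3)])
b = fvdb.JaggedTensor([torch.rand(6, 3)])

out = fvdb.jcat([a, b], dim=0)

# Concatenation along dim 0 joins corresponding tensors row-wise, so the
# single output tensor holds a's rows followed by b's rows.
assert out.jdata.shape[0] == 10
assert torch.equal(out.jdata[:4], a.jdata)
assert torch.equal(out.jdata[4:], b.jdata)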
