Added upsample() operator
cliffburdick committed Jun 8, 2023
1 parent 31298b6 commit 5ac5c78
Showing 4 changed files with 181 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs_input/api/tensorops.rst
@@ -101,6 +101,7 @@ Advanced Operators
.. doxygenfunction:: clone(Op t, const index_t (&shape)[Rank])
.. doxygenfunction:: clone(Op t, const std::array<index_t, Rank> &shape)
.. doxygenfunction:: stack
.. doxygenfunction:: upsample
.. doxygenfunction:: slice(const OpType opIn, const index_t (&starts)[OpType::Rank()], const index_t (&ends)[OpType::Rank()])
.. doxygenfunction:: slice(const OpType op, const index_t (&starts)[OpType::Rank()], const index_t (&ends)[OpType::Rank()], const index_t (&strides)[OpType::Rank()])
.. doxygenfunction:: permute(detail::tensor_impl_t<T, Rank> &out, const detail::tensor_impl_t<T, Rank> &in, const std::initializer_list<uint32_t> &dims, const cudaStream_t stream)
1 change: 1 addition & 0 deletions include/matx/operators/operators.h
@@ -68,3 +68,4 @@
#include "matx/operators/slice.h"
#include "matx/operators/sph2cart.h"
#include "matx/operators/stack.h"
#include "matx/operators/upsample.h"
118 changes: 118 additions & 0 deletions include/matx/operators/upsample.h
@@ -0,0 +1,118 @@
////////////////////////////////////////////////////////////////////////////////
// BSD 3-Clause License
//
// Copyright (c) 2021, NVIDIA Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
/////////////////////////////////////////////////////////////////////////////////

#pragma once

#include "matx/core/type_utils.h"
#include "matx/operators/base_operator.h"

namespace matx
{
/**
* Upsamples a tensor along one dimension by stuffing zeros
*/
namespace detail {
template <typename T>
class UpsampleOp : public BaseOp<UpsampleOp<T>>
{
private:
T op_;
int32_t dim_;
uint32_t n_;

public:
using matxop = bool;
using matxoplvalue = bool;
using scalar_type = typename T::scalar_type;

__MATX_INLINE__ std::string str() const { return "upsample(" + op_.str() + ")"; }

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
{
return T::Rank();
}

__MATX_INLINE__ UpsampleOp(const T &op, int32_t dim, uint32_t n) : op_(op), dim_(dim), n_(n) { }

template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
{
static_assert(sizeof...(Is)==Rank());
static_assert((std::is_convertible_v<Is, index_t> && ... ));

// convert variadic type to array so we can read/update
std::array<index_t, Rank()> ind{indices...};
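// Zero-stuffing index map: along dim_, output indices that are a multiple of
// n_ read input element ind[dim_] / n_; every other output index returns zero.
// e.g. with n_ == 2 on a 1D input the output is in(0), 0, in(1), 0, in(2), ...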
if ((ind[dim_] % n_) == 0) {
ind[dim_] /= n_;
return mapply(op_, ind);
}

return static_cast<decltype(mapply(op_, ind))>(0);
}

constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int32_t dim) const
{
if (dim == dim_) {
return op_.Size(dim) * n_;
}
else {
return op_.Size(dim);
}
}

template<typename R> __MATX_INLINE__ auto operator=(const R &rhs) {
return set(*this, rhs);
}

static_assert(Rank() > 0, "UpsampleOp: Rank of operator must be greater than 0.");
static_assert(T::Rank() > 0, "UpsampleOp: Rank of input operator must be greater than 0.");
};
}

/**
* @brief Upsample across one dimension with an integer rate
*
* Upsamples an input tensor across dimension `dim` by a factor of `n`. Upsampling is performed
* by inserting `n - 1` zeros between consecutive input samples (zero stuffing).
*
* @tparam T Input operator/tensor type
* @param op Input operator
* @param dim Dimension to upsample
* @param n Upsample rate
* @return Upsampled operator
*/
template <typename T>
__MATX_INLINE__ auto upsample(const T &op, int32_t dim, uint32_t n) {
return detail::UpsampleOp<T>(op, dim, n);
}
} // end namespace matx
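
A minimal usage sketch of the new operator, mirroring the unit test added below (the include, tensor sizes, and fill value are illustrative assumptions, not part of this commit):

#include "matx.h"

using namespace matx;

int main() {
  auto in  = make_tensor<float>({4});
  auto out = make_tensor<float>({8});  // input size * upsample rate of 2

  (in = 1.0f).run();                   // fill the input with ones on the default stream
  (out = upsample(in, 0, 2)).run();    // stuff one zero between samples along dim 0
  cudaStreamSynchronize(0);

  // out now holds {1, 0, 1, 0, 1, 0, 1, 0}
  return 0;
}
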
61 changes: 61 additions & 0 deletions test/00_operators/OperatorTests.cu
@@ -1249,6 +1249,67 @@ TYPED_TEST(OperatorTestsNumericAllExecs, RemapRankZero)
MATX_EXIT_HANDLER();
}

TYPED_TEST(OperatorTestsNumericAllExecs, Upsample)
{
MATX_ENTER_HANDLER();
using TestType = std::tuple_element_t<0, TypeParam>;
using ExecType = std::tuple_element_t<1, TypeParam>;
uint32_t us_rate;

ExecType exec{};
auto t1 = make_tensor<TestType>({10});
auto t1o2 = make_tensor<TestType>({t1.Size(0) * 2}); // Upsample 2
auto t1o3 = make_tensor<TestType>({t1.Size(0) * 3}); // Upsample 3

TestType c = GenerateData<TestType>();
(t1 = c).run(exec);

us_rate = 2;
(t1o2 = upsample(t1, 0, us_rate)).run(exec);
cudaStreamSynchronize(0);
for (int i = 0; i < t1o2.Size(0); i++) {
if ((i % us_rate) == 0) {
EXPECT_TRUE(MatXUtils::MatXTypeCompare(t1o2(i), t1(i/us_rate)));
}
else {
EXPECT_TRUE(MatXUtils::MatXTypeCompare(t1o2(i), (TestType)0));
}
}

us_rate = 3;
(t1o3 = upsample(t1, 0, us_rate)).run(exec);
cudaStreamSynchronize(0);
for (int i = 0; i < t1o3.Size(0); i++) {
if ((i % us_rate) == 0) {
EXPECT_TRUE(MatXUtils::MatXTypeCompare(t1o3(i), t1(i/us_rate)));
}
else {
EXPECT_TRUE(MatXUtils::MatXTypeCompare(t1o3(i), (TestType)0));
}
}

auto t2 = make_tensor<TestType>({10, 10});
auto t2o2 = make_tensor<TestType>({t2.Size(0) * 2, 10}); // Upsample 2

(t2 = c).run(exec);

us_rate = 2;
(t2o2 = upsample(t2, 0, us_rate)).run(exec);
cudaStreamSynchronize(0);
for (int i = 0; i < t2o2.Size(0); i++) {
for (int j = 0; j < t2o2.Size(1); j++) {
if ((i % us_rate) == 0) {
ASSERT_EQ(t2o2(i, j), t2(i/us_rate, j));
}
else {
ASSERT_EQ(t2o2(i, j), (TestType)0);
}
}
}

MATX_EXIT_HANDLER();
}

TYPED_TEST(OperatorTestsComplexTypesAllExecs, RealImagOp)
{
MATX_ENTER_HANDLER();
