Skip to content

Commit

Permalink
Add pwelch operator
Browse files Browse the repository at this point in the history
  • Loading branch information
tmartin-gh committed Sep 1, 2023
1 parent 327361a commit a4b76d7
Show file tree
Hide file tree
Showing 8 changed files with 483 additions and 22 deletions.
31 changes: 16 additions & 15 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
set(examples
simple_radar_pipeline
recursive_filter
set(examples
simple_radar_pipeline
recursive_filter
channelize_poly_bench
convolution
convolution
conv2d
cgsolve
fft_conv
resample
mvdr_beamformer
cgsolve
fft_conv
resample
mvdr_beamformer
pwelch
resample_poly_bench
spectrogram
spectrogram
spectrogram_graph
spherical_harmonics
spherical_harmonics
svd_power
qr
black_scholes)

add_library(example_lib INTERFACE)
target_include_directories(example_lib SYSTEM INTERFACE ${CUTLASS_INC} ${pybind11_INCLUDE_DIR} ${PYTHON_INCLUDE_DIRS})
target_include_directories(example_lib SYSTEM INTERFACE ${CUTLASS_INC} ${pybind11_INCLUDE_DIR} ${PYTHON_INCLUDE_DIRS})

target_link_libraries(example_lib INTERFACE matx::matx) # Transitive properties
target_link_libraries(example_lib INTERFACE matx::matx) # Transitive properties

set_property(TARGET example_lib PROPERTY ENABLE_EXPORTS 1)

Expand All @@ -28,7 +29,7 @@ if (MSVC)
else()
target_compile_options(example_lib INTERFACE ${WARN_FLAGS})
target_compile_options(example_lib INTERFACE ${MATX_CUDA_FLAGS})
endif()
endif()

if (MULTI_GPU)
set_target_properties(example_lib PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
Expand All @@ -37,15 +38,15 @@ endif()

foreach( example ${examples} )
string( CONCAT file ${example} ".cu" )
add_executable( ${example} ${file} )
add_executable( ${example} ${file} )
target_link_libraries(${example} example_lib)
endforeach()

# Build proprietary examples
file (GLOB_RECURSE proprietary_sources CONFIGURE_DEPENDS ${CMAKE_SOURCE_DIR}/proprietary/*/examples/*.cu)
foreach (pexample ${proprietary_sources})
get_filename_component(exename ${pexample} NAME_WE)
add_executable(${exename} ${pexample})
add_executable(${exename} ${pexample})
target_link_libraries(${exename} example_lib)
endforeach()

Expand Down
105 changes: 105 additions & 0 deletions examples/pwelch.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
////////////////////////////////////////////////////////////////////////////////
// BSD 3-Clause License
//
// Copyright (c) 2023, NVIDIA Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
/////////////////////////////////////////////////////////////////////////////////

#include "matx.h"
#include <cassert>
#include <cstdio>
#include <cuda/std/ccomplex>

using namespace matx;

/**
* PWelch
*
* This example shows how to estimate the power spectral density of a 1D tensor
* using Welch's method [1].
*
* [1] P. Welch, "The use of fast Fourier transform for the estimation of power spectra: A method based on time averaging over short, modified periodograms," in IEEE Transactions on Audio and Electroacoustics, vol. 15, no. 2, pp. 70-73, June 1967, doi: 10.1109/TAU.1967.1161901.
*
*/

int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
{
MATX_ENTER_HANDLER();
using complex = cuda::std::complex<float>;

float exec_time_ms;
const int num_iterations = 100;
index_t signal_size = 256;
int nperseg = 32;
int nfft = nperseg;
int noverlap = 8;
float ftone = 3.0;
cudaStream_t stream;
cudaStreamCreate(&stream);
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);

// Create buffers
auto x = make_tensor<complex>({signal_size});
auto Pxx = make_tensor<typename complex::value_type>({nfft});

for (index_t k = 0; k < signal_size; k++) {
// Create the time domain signal as a complex exponential
float phase = static_cast<float>(2.0*M_PI*ftone*static_cast<float>(k)/nfft);
x(k) = {static_cast<float>(cos(phase)), static_cast<float>(sin(phase))};
}

// Prefetch the data we just created
x.PrefetchDevice(0);

// Run one time to pre-cache the FFT plan
(Pxx = pwelch(x, nperseg, noverlap, nfft)).run(stream);

// Start the timing
cudaEventRecord(start, stream);

// Start the timing
cudaEventRecord(start, stream);

for (int iteration = 0; iteration < num_iterations; iteration++) {
// Use the PWelch operator
(Pxx = pwelch(x, nperseg, noverlap, nfft)).run(stream);
}

cudaEventRecord(stop, stream);
cudaStreamSynchronize(stream);
cudaEventElapsedTime(&exec_time_ms, start, stop);

printf("Output Pxx:\n");
print(Pxx);
printf("PWelchOp avg runtime = %.2f ms\n", exec_time_ms / num_iterations);

CUDA_CHECK_LAST_ERROR();
MATX_EXIT_HANDLER();
}
1 change: 1 addition & 0 deletions include/matx/operators/operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
#include "matx/operators/matmul.h"
#include "matx/operators/permute.h"
#include "matx/operators/planar.h"
#include "matx/operators/pwelch.h"
#include "matx/operators/qr.h"
#include "matx/operators/r2c.h"
#include "matx/operators/remap.h"
Expand Down
139 changes: 139 additions & 0 deletions include/matx/operators/pwelch.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
////////////////////////////////////////////////////////////////////////////////
// BSD 3-Clause License
//
// COpBright (c) 2023, NVIDIA Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above cOpBright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above cOpBright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the cOpBright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COpBRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COpBRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
/////////////////////////////////////////////////////////////////////////////////

#pragma once


#include "matx/core/type_utils.h"
#include "matx/operators/base_operator.h"
#include "matx/transforms/pwelch.h"

namespace matx
{
namespace detail {
template <typename Op_x>
class PWelchOp : public BaseOp<PWelchOp<Op_x>>
{
private:
Op_x x_;

int nperseg_;
int noverlap_;
int nfft_;
std::array<index_t, 2> out_dims_;
mutable matx::tensor_t<typename Op_x::scalar_type, 2> tmp_out_;

public:
using matxop = bool;
using scalar_type = typename Op_x::scalar_type;
using matx_transform_op = bool;
using pwelch_xform_op = bool;

__MATX_INLINE__ std::string str() const {
return "pwelch(" + get_type_str(x_) + ")";
}

__MATX_INLINE__ PWelchOp(const Op_x &x, int nperseg, int noverlap, int nfft) :
x_(x), nperseg_(nperseg), noverlap_(noverlap), nfft_(nfft) {

MATX_ASSERT_STR(x.Rank() == 1, matxInvalidDim, "pwelch: Only input rank of 1 is supported presently");
for (int r = 0; r < x.Rank(); r++) {
out_dims_[r] = x_.Size(r);
}
}

template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
{
return tmp_out_(indices...);
}

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
{
return remove_cvref_t<Op_x>::Rank();
}

constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const
{
return out_dims_[dim];
}

template <typename Out, typename Executor>
void Exec(Out &&out, Executor &&ex) const{
static_assert(is_device_executor_v<Executor>, "pwelch() only supports the CUDA executor currently");
pwelch_impl(std::get<0>(out), x_, nperseg_, noverlap_, nfft_, ex.getStream());
}

template <typename ShapeType, typename Executor>
__MATX_INLINE__ void PreRun([[maybe_unused]] ShapeType &&shape, Executor &&ex) const noexcept
{
if constexpr (is_matx_op<Op_x>()) {
x_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}

if constexpr (is_device_executor_v<Executor>) {
make_tensor(tmp_out_, out_dims_, MATX_ASYNC_DEVICE_MEMORY, ex.getStream());
}

Exec(std::make_tuple(tmp_out_), std::forward<Executor>(ex));
}
};
}

/**
* Estimate the power spectral density of a 1D tensor using Welch's method [1].
*
* @param x
* Input time domain tensor x
* @param nperseg
* Length of each segment
* @param stream
* cuda Stream to execute on
* @param noverlap
* Number of points to overlap between segments. Defaults to 0
* @param nfft
* Length of FFT used per segment. nfft >= nperseg. Defaults to nfft = nperseg
*
* @returns Operator with power spectral density of x
* [1] P. Welch, "The use of fast Fourier transform for the estimation of power spectra: A method based on time averaging over short, modified periodograms," in IEEE Transactions on Audio and Electroacoustics, vol. 15, no. 2, pp. 70-73, June 1967, doi: 10.1109/TAU.1967.1161901.
*
*/

template <typename xType>
__MATX_INLINE__ auto pwelch(xType x, int nperseg, int noverlap, int nfft)
{
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)

return detail::PWelchOp(x, nperseg, noverlap, nfft);
}
}
Loading

0 comments on commit a4b76d7

Please sign in to comment.