From a4b76d78455bf26bc2b4763d40cf04a7de1e6e8d Mon Sep 17 00:00:00 2001 From: Tim Martin <38798827+tmartin-gh@users.noreply.github.com> Date: Fri, 1 Sep 2023 14:07:58 -0700 Subject: [PATCH] Add pwelch operator --- examples/CMakeLists.txt | 31 ++-- examples/pwelch.cu | 105 +++++++++++++ include/matx/operators/operators.h | 1 + include/matx/operators/pwelch.h | 139 ++++++++++++++++++ include/matx/transforms/pwelch.h | 83 +++++++++++ test/00_transform/PWelch.cu | 101 +++++++++++++ test/CMakeLists.txt | 7 +- test/test_vectors/generators/00_transforms.py | 38 ++++- 8 files changed, 483 insertions(+), 22 deletions(-) create mode 100644 examples/pwelch.cu create mode 100644 include/matx/operators/pwelch.h create mode 100644 include/matx/transforms/pwelch.h create mode 100644 test/00_transform/PWelch.cu diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index e92c9e26d..bd16e3a99 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -1,25 +1,26 @@ -set(examples - simple_radar_pipeline - recursive_filter +set(examples + simple_radar_pipeline + recursive_filter channelize_poly_bench - convolution + convolution conv2d - cgsolve - fft_conv - resample - mvdr_beamformer + cgsolve + fft_conv + resample + mvdr_beamformer + pwelch resample_poly_bench - spectrogram + spectrogram spectrogram_graph - spherical_harmonics + spherical_harmonics svd_power qr black_scholes) add_library(example_lib INTERFACE) -target_include_directories(example_lib SYSTEM INTERFACE ${CUTLASS_INC} ${pybind11_INCLUDE_DIR} ${PYTHON_INCLUDE_DIRS}) +target_include_directories(example_lib SYSTEM INTERFACE ${CUTLASS_INC} ${pybind11_INCLUDE_DIR} ${PYTHON_INCLUDE_DIRS}) -target_link_libraries(example_lib INTERFACE matx::matx) # Transitive properties +target_link_libraries(example_lib INTERFACE matx::matx) # Transitive properties set_property(TARGET example_lib PROPERTY ENABLE_EXPORTS 1) @@ -28,7 +29,7 @@ if (MSVC) else() target_compile_options(example_lib INTERFACE ${WARN_FLAGS}) target_compile_options(example_lib INTERFACE ${MATX_CUDA_FLAGS}) -endif() +endif() if (MULTI_GPU) set_target_properties(example_lib PROPERTIES CUDA_SEPARABLE_COMPILATION ON) @@ -37,7 +38,7 @@ endif() foreach( example ${examples} ) string( CONCAT file ${example} ".cu" ) - add_executable( ${example} ${file} ) + add_executable( ${example} ${file} ) target_link_libraries(${example} example_lib) endforeach() @@ -45,7 +46,7 @@ endforeach() file (GLOB_RECURSE proprietary_sources CONFIGURE_DEPENDS ${CMAKE_SOURCE_DIR}/proprietary/*/examples/*.cu) foreach (pexample ${proprietary_sources}) get_filename_component(exename ${pexample} NAME_WE) - add_executable(${exename} ${pexample}) + add_executable(${exename} ${pexample}) target_link_libraries(${exename} example_lib) endforeach() diff --git a/examples/pwelch.cu b/examples/pwelch.cu new file mode 100644 index 000000000..128bab0d0 --- /dev/null +++ b/examples/pwelch.cu @@ -0,0 +1,105 @@ +//////////////////////////////////////////////////////////////////////////////// +// BSD 3-Clause License +// +// Copyright (c) 2023, NVIDIA Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +///////////////////////////////////////////////////////////////////////////////// + +#include "matx.h" +#include +#include +#include + +using namespace matx; + +/** + * PWelch + * + * This example shows how to estimate the power spectral density of a 1D tensor + * using Welch's method [1]. + * + * [1] P. Welch, "The use of fast Fourier transform for the estimation of power spectra: A method based on time averaging over short, modified periodograms," in IEEE Transactions on Audio and Electroacoustics, vol. 15, no. 2, pp. 70-73, June 1967, doi: 10.1109/TAU.1967.1161901. + * + */ + +int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv) +{ + MATX_ENTER_HANDLER(); + using complex = cuda::std::complex; + + float exec_time_ms; + const int num_iterations = 100; + index_t signal_size = 256; + int nperseg = 32; + int nfft = nperseg; + int noverlap = 8; + float ftone = 3.0; + cudaStream_t stream; + cudaStreamCreate(&stream); + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + // Create buffers + auto x = make_tensor({signal_size}); + auto Pxx = make_tensor({nfft}); + + for (index_t k = 0; k < signal_size; k++) { + // Create the time domain signal as a complex exponential + float phase = static_cast(2.0*M_PI*ftone*static_cast(k)/nfft); + x(k) = {static_cast(cos(phase)), static_cast(sin(phase))}; + } + + // Prefetch the data we just created + x.PrefetchDevice(0); + + // Run one time to pre-cache the FFT plan + (Pxx = pwelch(x, nperseg, noverlap, nfft)).run(stream); + + // Start the timing + cudaEventRecord(start, stream); + + // Start the timing + cudaEventRecord(start, stream); + + for (int iteration = 0; iteration < num_iterations; iteration++) { + // Use the PWelch operator + (Pxx = pwelch(x, nperseg, noverlap, nfft)).run(stream); + } + + cudaEventRecord(stop, stream); + cudaStreamSynchronize(stream); + cudaEventElapsedTime(&exec_time_ms, start, stop); + + printf("Output Pxx:\n"); + print(Pxx); + printf("PWelchOp avg runtime = %.2f ms\n", exec_time_ms / num_iterations); + + CUDA_CHECK_LAST_ERROR(); + MATX_EXIT_HANDLER(); +} diff --git a/include/matx/operators/operators.h b/include/matx/operators/operators.h index 2d1f68691..fda8f71c2 100644 --- a/include/matx/operators/operators.h +++ b/include/matx/operators/operators.h @@ -75,6 +75,7 @@ #include "matx/operators/matmul.h" #include "matx/operators/permute.h" #include "matx/operators/planar.h" +#include "matx/operators/pwelch.h" #include "matx/operators/qr.h" #include "matx/operators/r2c.h" #include "matx/operators/remap.h" diff --git a/include/matx/operators/pwelch.h b/include/matx/operators/pwelch.h new file mode 100644 index 000000000..b7df7dde0 --- /dev/null +++ b/include/matx/operators/pwelch.h @@ -0,0 +1,139 @@ +//////////////////////////////////////////////////////////////////////////////// +// BSD 3-Clause License +// +// COpBright (c) 2023, NVIDIA Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above cOpBright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above cOpBright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the cOpBright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COpBRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COpBRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +///////////////////////////////////////////////////////////////////////////////// + +#pragma once + + +#include "matx/core/type_utils.h" +#include "matx/operators/base_operator.h" +#include "matx/transforms/pwelch.h" + +namespace matx +{ + namespace detail { + template + class PWelchOp : public BaseOp> + { + private: + Op_x x_; + + int nperseg_; + int noverlap_; + int nfft_; + std::array out_dims_; + mutable matx::tensor_t tmp_out_; + + public: + using matxop = bool; + using scalar_type = typename Op_x::scalar_type; + using matx_transform_op = bool; + using pwelch_xform_op = bool; + + __MATX_INLINE__ std::string str() const { + return "pwelch(" + get_type_str(x_) + ")"; + } + + __MATX_INLINE__ PWelchOp(const Op_x &x, int nperseg, int noverlap, int nfft) : + x_(x), nperseg_(nperseg), noverlap_(noverlap), nfft_(nfft) { + + MATX_ASSERT_STR(x.Rank() == 1, matxInvalidDim, "pwelch: Only input rank of 1 is supported presently"); + for (int r = 0; r < x.Rank(); r++) { + out_dims_[r] = x_.Size(r); + } + } + + template + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const + { + return tmp_out_(indices...); + } + + static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() + { + return remove_cvref_t::Rank(); + } + + constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const + { + return out_dims_[dim]; + } + + template + void Exec(Out &&out, Executor &&ex) const{ + static_assert(is_device_executor_v, "pwelch() only supports the CUDA executor currently"); + pwelch_impl(std::get<0>(out), x_, nperseg_, noverlap_, nfft_, ex.getStream()); + } + + template + __MATX_INLINE__ void PreRun([[maybe_unused]] ShapeType &&shape, Executor &&ex) const noexcept + { + if constexpr (is_matx_op()) { + x_.PreRun(std::forward(shape), std::forward(ex)); + } + + if constexpr (is_device_executor_v) { + make_tensor(tmp_out_, out_dims_, MATX_ASYNC_DEVICE_MEMORY, ex.getStream()); + } + + Exec(std::make_tuple(tmp_out_), std::forward(ex)); + } + }; + } + + /** + * Estimate the power spectral density of a 1D tensor using Welch's method [1]. + * + * @param x + * Input time domain tensor x + * @param nperseg + * Length of each segment + * @param stream + * cuda Stream to execute on + * @param noverlap + * Number of points to overlap between segments. Defaults to 0 + * @param nfft + * Length of FFT used per segment. nfft >= nperseg. Defaults to nfft = nperseg + * + * @returns Operator with power spectral density of x + + * [1] P. Welch, "The use of fast Fourier transform for the estimation of power spectra: A method based on time averaging over short, modified periodograms," in IEEE Transactions on Audio and Electroacoustics, vol. 15, no. 2, pp. 70-73, June 1967, doi: 10.1109/TAU.1967.1161901. + * + */ + + template + __MATX_INLINE__ auto pwelch(xType x, int nperseg, int noverlap, int nfft) + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) + + return detail::PWelchOp(x, nperseg, noverlap, nfft); + } +} diff --git a/include/matx/transforms/pwelch.h b/include/matx/transforms/pwelch.h new file mode 100644 index 000000000..acf93863f --- /dev/null +++ b/include/matx/transforms/pwelch.h @@ -0,0 +1,83 @@ +//////////////////////////////////////////////////////////////////////////////// +// BSD 3-Clause License +// +// Copyright (c) 2023, NVIDIA Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +///////////////////////////////////////////////////////////////////////////////// + +#pragma once + +namespace matx +{ + /** + * Estimate the power spectral density of a 1D tensor using Welch's method [1]. + * + * @param Pxx + * Output power spectral density of time domain signal x + * @param x + * Input time domain signal x + * @param nperseg + * Length of each segment + * @param stream + * cuda Stream to execute on + * @param noverlap + * Number of points to overlap between segments. Defaults to 0 + * @param nfft + * Length of FFT used per segment. nfft >= nperseg. Defaults to nfft = nperseg + * + * [1] P. Welch, "The use of fast Fourier transform for the estimation of power spectra: A method based on time averaging over short, modified periodograms," in IEEE Transactions on Audio and Electroacoustics, vol. 15, no. 2, pp. 70-73, June 1967, doi: 10.1109/TAU.1967.1161901. + * + */ + template + __MATX_INLINE__ void pwelch_impl(PxxType Pxx, xType x, int nperseg, int noverlap, int nfft, cudaStream_t stream=0) + { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) + + MATX_ASSERT_STR(Pxx.Rank() == x.Rank(), matxInvalidDim, "pwelch: Pxx rank must be the same as x rank"); + MATX_ASSERT_STR(nfft >= nperseg, matxInvalidDim, "pwelch: nfft must be >= nperseg"); + MATX_ASSERT_STR((noverlap >= 0) && (noverlap < nperseg), matxInvalidDim, "pwelch: Must have 0 <= noverlap < nperseg"); + + // Create overlapping view + auto x_with_overlaps = x.OverlapView({nfft}, {noverlap}); + + // Create temporary space for fft outputs + index_t batches = x_with_overlaps.Shape()[0]; + auto X_with_overlaps = make_tensor>({batches,static_cast(nfft)}); + + (X_with_overlaps = fft(x_with_overlaps,nfft)).run(stream); + + // Compute magnitude squared in-place + (X_with_overlaps = conj(X_with_overlaps) * X_with_overlaps).run(stream); + auto mag_sq_X_with_overlaps = X_with_overlaps.RealView(); + + // Perform the reduction across 'batches' rows and normalize + float norm_factor = 1.f / static_cast(batches); + (Pxx = sum(mag_sq_X_with_overlaps, {0}) * norm_factor).run(stream); + } + +} // end namespace matx diff --git a/test/00_transform/PWelch.cu b/test/00_transform/PWelch.cu new file mode 100644 index 000000000..a7b807afa --- /dev/null +++ b/test/00_transform/PWelch.cu @@ -0,0 +1,101 @@ +//////////////////////////////////////////////////////////////////////////////// +// BSD 3-Clause License +// +// Copyright (c) 2023, NVIDIA Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +///////////////////////////////////////////////////////////////////////////////// + +#include "assert.h" +#include "matx.h" +#include "test_types.h" +#include "utilities.h" +#include "gtest/gtest.h" + +using namespace matx; + +struct TestParams { + index_t signal_size; + int nperseg; + int noverlap; + int nfft; + int ftone; + int sigma; +}; +const std::vector CONFIGS = { + {16, 8, 4, 8, 0, 0}, + {16, 8, 4, 8, 1, 0}, + {16, 8, 4, 8, 2, 1}, + {16384, 256, 64, 256, 63, 0} +}; + +class PWelchComplexExponentialTest : public ::testing::TestWithParam +{ +public: + void SetUp() override + { + pb = std::make_unique(); + } + + void TearDown() { pb.reset(); } + + std::unique_ptr pb; + float thresh = 0.01f; + TestParams params = ::testing::TestWithParam::GetParam(); +}; + +template +void helper(PWelchComplexExponentialTest& test) +{ + MATX_ENTER_HANDLER(); + test.pb->template InitAndRunTVGenerator( + "00_transforms", "pwelch_operators", "pwelch_complex_exponential", {test.params.signal_size, test.params.nperseg, test.params.noverlap, test.params.nfft, test.params.ftone, test.params.sigma}); + + tensor_t x{{test.params.signal_size}}; + test.pb->NumpyToTensorView(x, "x_in"); + + auto Pxx = make_tensor({test.params.nfft}); + (Pxx = pwelch(x, test.params.nperseg, test.params.noverlap, test.params.nfft)).run(); + + cudaStreamSynchronize(0); + + MATX_TEST_ASSERT_COMPARE(test.pb, Pxx, "Pxx_out", test.thresh); + MATX_EXIT_HANDLER(); +} + + +TEST_P(PWelchComplexExponentialTest, xin_complex_float) +{ + helper>(*this); +} + +TEST_P(PWelchComplexExponentialTest, xin_complex_double) +{ + helper>(*this); +} + +INSTANTIATE_TEST_CASE_P(PWelchComplexExponentialTests, PWelchComplexExponentialTest,::testing::ValuesIn(CONFIGS)); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 1adcb67d3..eed53d132 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -16,6 +16,7 @@ set (test_sources 00_transform/Copy.cu 00_transform/Cov.cu 00_transform/FFT.cu + 00_transform/PWelch.cu 00_transform/ResamplePoly.cu 00_transform/Solve.cu 00_solver/Cholesky.cu @@ -35,11 +36,11 @@ set (test_sources main.cu ) -# Some of <00_io> tests need csv files and the binary 'test.mat' which all +# Some of <00_io> tests need csv files and the binary 'test.mat' which all # are located under 'CMAKE_SOURCE_DIR/test/00_io'. When calling the test # executable from its location in 'CMAKE_BINARY_DIR/test' the -# search paths according are -# '../test/00_io/small_csv_comma_nh.csv' and +# search paths according are +# '../test/00_io/small_csv_comma_nh.csv' and # '../test/00_io/small_csv_complex_comma_nh.csv' respectively. Therefore # they must be copied to the correct location: file(MAKE_DIRECTORY ${CMAKE_BINARY_DIR}/test/00_io) diff --git a/test/test_vectors/generators/00_transforms.py b/test/test_vectors/generators/00_transforms.py index d7b0a67fd..1b1605dc3 100755 --- a/test/test_vectors/generators/00_transforms.py +++ b/test/test_vectors/generators/00_transforms.py @@ -66,7 +66,7 @@ def __init__(self, dtype: str, size: List[int]): self.res = { 'a': matx_common.randn_ndarray((*size[:-3], size[-3], size[-2]), dtype), 'b': matx_common.randn_ndarray((*size[:-3], size[-2], size[-1]), dtype) - } + } def run(self) -> Dict[str, np.ndarray]: self.res['c'] = self.res['a'] @ self.res['b'] @@ -87,7 +87,7 @@ def run_b_transpose(self) -> Dict[str, np.ndarray]: def run_transpose(self) -> Dict[str, np.ndarray]: self.res['c'] = np.transpose(self.res['a'] @ self.res['b']) - return self.res + return self.res def run_mixed(self) -> Dict[str, np.ndarray]: float_to_complex_dtype = {np.float32 : np.complex64, np.float64 : np.complex128} @@ -270,7 +270,7 @@ def fft_1d_batched(self) -> Dict[str, np.ndarray]: return { 'a_in': seq, 'a_out': np.fft.fft(seq, self.size[2]) - } + } def ifft_1d(self) -> Dict[str, np.ndarray]: seq = matx_common.randn_ndarray((self.size[0],), self.dtype) @@ -305,7 +305,7 @@ def rfft_1d_batched(self) -> Dict[str, np.ndarray]: return { 'a_in': seq, 'a_out': np.fft.rfft(seq, self.size[2]) - } + } def irfft_1d(self) -> Dict[str, np.ndarray]: seq = matx_common.randn_ndarray((self.size[0],), self.dtype) @@ -329,3 +329,33 @@ def ifft_2d(self) -> Dict[str, np.ndarray]: 'a_in': seq, 'a_out': np.fft.ifft2(seq, (self.size[1], self.size[1])) } + +class pwelch_operators: + def __init__(self, dtype: str, params: List[int]): + self.dtype = dtype + self.signal_size = params[0] + self.nperseg = params[1] + self.noverlap = params[2] + self.nfft = params[3] + self.ftone = params[4] + self.sigma = params[5] + + np.random.seed(1234) + + def pwelch_complex_exponential(self) -> Dict[str, np.ndarray]: + s = np.exp(2j*np.pi*self.ftone*np.linspace(0,self.signal_size-1,self.signal_size)/self.nfft) + n = np.random.normal(loc=0,scale=self.sigma,size=self.signal_size) + 1j*np.random.normal(loc=0,scale=self.sigma,size=self.signal_size) + x = s + n + f, Pxx = signal.welch(x, + fs=1./self.nfft, + window=np.ones(self.nperseg), + nperseg=self.nperseg, + noverlap=self.noverlap, + nfft=self.nfft, + return_onesided=False, + scaling = 'density', + detrend=False) + return { + 'x_in': x, + 'Pxx_out': Pxx + } \ No newline at end of file