From 8448ca4458be5edd1d05cf11cd6f6c723ec9e52a Mon Sep 17 00:00:00 2001 From: cliffburdick Date: Thu, 8 Sep 2022 15:50:19 -0700 Subject: [PATCH] Fixed bug with FFT size shorter than length of tensor --- include/matx/transforms/fft.h | 2 +- test/00_transform/FFT.cu | 36 +++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/include/matx/transforms/fft.h b/include/matx/transforms/fft.h index 2a3b62be..c2bd6ab1 100644 --- a/include/matx/transforms/fft.h +++ b/include/matx/transforms/fft.h @@ -704,7 +704,7 @@ auto GetFFTInputView([[maybe_unused]] OutputTensor &o, // FFT shorter than the size of the input signal. Create a new view of this // slice. if (act_fft_size < nom_fft_size) { - ends[RANK - 1] = nom_fft_size; + ends[RANK - 1] = act_fft_size; return i.Slice(starts, ends); } else { // FFT length is longer than the input. Pad input diff --git a/test/00_transform/FFT.cu b/test/00_transform/FFT.cu index 88414eea..e26b9bb2 100644 --- a/test/00_transform/FFT.cu +++ b/test/00_transform/FFT.cu @@ -253,3 +253,39 @@ TYPED_TEST(FFTTestComplexTypes, IFFT2D16C2C) MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); MATX_EXIT_HANDLER(); } + + +TYPED_TEST(FFTTestComplexNonHalfTypes, FFT1D1024C2CShort) +{ + MATX_ENTER_HANDLER(); + const index_t fft_dim = 1024; + this->pb->template InitAndRunTVGenerator( + "00_transforms", "fft_operators", "fft_1d", {fft_dim, fft_dim - 16}); + + tensor_t av{{fft_dim}}; + tensor_t avo{{fft_dim - 16}}; + this->pb->NumpyToTensorView(av, "a_in"); + + fft(avo, av); + cudaStreamSynchronize(0); + + MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); + MATX_EXIT_HANDLER(); +} + +TYPED_TEST(FFTTestComplexNonHalfTypes, IFFT1D1024C2CShort) +{ + MATX_ENTER_HANDLER(); + const index_t fft_dim = 1024; + this->pb->template InitAndRunTVGenerator( + "00_transforms", "fft_operators", "ifft_1d", {fft_dim, fft_dim - 16}); + tensor_t av{{fft_dim}}; + tensor_t avo{{fft_dim - 16}}; + this->pb->NumpyToTensorView(av, "a_in"); + + ifft(avo, av); + cudaStreamSynchronize(0); + + MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); + MATX_EXIT_HANDLER(); +}