@@ -142,4 +142,201 @@ TEMPLATE_TEST_CASE("Utility math functions", "[Util]", float, double) {
142
142
CHECK (isApproxEqual (result, expected_result, 1e-5 ));
143
143
}
144
144
}
145
+ SECTION (" matrixVecProd" ) {
146
+ SECTION (" Simple Iterative" ) {
147
+ for (size_t m = 2 ; m < 8 ; m++) {
148
+ std::vector<std::complex<double >> mat (m * m, {1 , 1 });
149
+ std::vector<std::complex<double >> v_in (m, {1 , 1 });
150
+ std::vector<std::complex<double >> v_expected (
151
+ m, {0 , static_cast <double >(2 * m)});
152
+ std::vector<std::complex<double >> v_out =
153
+ Util::matrixVecProd (mat, v_in, m, m);
154
+ CAPTURE (v_out);
155
+ CAPTURE (v_expected);
156
+ for (size_t i = 0 ; i < m; i++) {
157
+ CHECK (isApproxEqual (v_out[i], v_expected[i]));
158
+ }
159
+ }
160
+ }
161
+ SECTION (" Random Complex" ) {
162
+ std::vector<std::complex<double >> mat{
163
+ {0.417876 , 0.27448 }, {0.601209 , 0.723548 },
164
+ {0.781624 , 0.538222 }, {0.0597232 , 0.27755 },
165
+ {0.0431741 , 0.593319 }, {0.224124 , 0.130335 },
166
+ {0.237877 , 0.01557 }, {0.931634 , 0.786367 },
167
+ {0.378397 , 0.894381 }, {0.840747 , 0.889789 },
168
+ {0.530623 , 0.463644 }, {0.868736 , 0.760685 },
169
+ {0.258175 , 0.836569 }, {0.495012 , 0.667726 },
170
+ {0.298962 , 0.384992 }, {0.659472 , 0.232696 }};
171
+ std::vector<std::complex<double >> v_in{{0.417876 , 0.27448 },
172
+ {0.601209 , 0.723548 },
173
+ {0.781624 , 0.538222 },
174
+ {0.0597232 , 0.27755 }};
175
+ std::vector<std::complex<double >> v_expected{{0.184998 , 1.97393 },
176
+ {-0.0894368 , 0.946047 },
177
+ {-0.219747 , 2.55541 },
178
+ {-0.305997 , 1.83881 }};
179
+ std::vector<std::complex<double >> v_out =
180
+ Util::matrixVecProd (mat, v_in, 4 , 4 );
181
+ CAPTURE (v_out);
182
+ for (size_t i = 0 ; i < 4 ; i++) {
183
+ CHECK (isApproxEqual (v_out[i], v_out[i]));
184
+ }
185
+ }
186
+ SECTION (" Invalid Arguments" ) {
187
+ using namespace Catch ::Matchers;
188
+ std::vector<std::complex<double >> mat (2 * 3 , {1 , 1 });
189
+ std::vector<std::complex<double >> v_in (2 , {1 , 1 });
190
+ CHECK_THROWS_AS (Util::matrixVecProd (mat, v_in, 2 , 3 ),
191
+ std::invalid_argument);
192
+ CHECK_THROWS_WITH (Util::matrixVecProd (mat, v_in, 2 , 3 ),
193
+ Contains (" Invalid size for the input vector" ));
194
+ CHECK_THROWS_AS (Util::matrixVecProd (mat, v_in, 2 , 2 ),
195
+ std::invalid_argument);
196
+ CHECK_THROWS_WITH (Util::matrixVecProd (mat, v_in, 2 , 2 ),
197
+ Contains (" Invalid m & n for the input matrix" ));
198
+ }
199
+ }
200
+ SECTION (" Transpose" ) {
201
+ SECTION (" Simple Matrix" ) {
202
+ for (size_t m = 2 ; m < 8 ; m++) {
203
+ std::vector<std::complex<double >> mat (m * m, {0 , 0 });
204
+ for (size_t i = 0 ; i < m; i++) {
205
+ mat[i * m + i] = {1 , 1 };
206
+ }
207
+ std::vector<std::complex<double >> mat_t =
208
+ Util::Transpose (mat, m, m);
209
+ CAPTURE (mat_t );
210
+ CAPTURE (mat);
211
+ for (size_t i = 0 ; i < m * m; i++) {
212
+ CHECK (isApproxEqual (mat[i], mat_t [i]));
213
+ }
214
+ }
215
+ }
216
+ SECTION (" Random Complex" ) {
217
+ std::vector<std::complex<double >> mat{
218
+ {0.417876 , 0.27448 }, {0.601209 , 0.723548 },
219
+ {0.781624 , 0.538222 }, {0.0597232 , 0.27755 },
220
+ {0.0431741 , 0.593319 }, {0.224124 , 0.130335 },
221
+ {0.237877 , 0.01557 }, {0.931634 , 0.786367 },
222
+ {0.378397 , 0.894381 }, {0.840747 , 0.889789 },
223
+ {0.530623 , 0.463644 }, {0.868736 , 0.760685 },
224
+ {0.258175 , 0.836569 }, {0.495012 , 0.667726 },
225
+ {0.298962 , 0.384992 }, {0.659472 , 0.232696 }};
226
+ std::vector<std::complex<double >> mat_t_exp{
227
+ {0.417876 , 0.27448 }, {0.0431741 , 0.593319 },
228
+ {0.378397 , 0.894381 }, {0.258175 , 0.836569 },
229
+ {0.601209 , 0.723548 }, {0.224124 , 0.130335 },
230
+ {0.840747 , 0.889789 }, {0.495012 , 0.667726 },
231
+ {0.781624 , 0.538222 }, {0.237877 , 0.01557 },
232
+ {0.530623 , 0.463644 }, {0.298962 , 0.384992 },
233
+ {0.0597232 , 0.27755 }, {0.931634 , 0.786367 },
234
+ {0.868736 , 0.760685 }, {0.659472 , 0.232696 }};
235
+ std::vector<std::complex<double >> mat_t =
236
+ Util::Transpose (mat, 4 , 4 );
237
+ CAPTURE (mat_t );
238
+ CAPTURE (mat_t_exp);
239
+ for (size_t i = 0 ; i < 16 ; i++) {
240
+ CHECK (isApproxEqual (mat_t [i], mat_t_exp[i]));
241
+ }
242
+ }
243
+ SECTION (" Invalid Arguments" ) {
244
+ using namespace Catch ::Matchers;
245
+ std::vector<std::complex<double >> mat (2 * 3 , {1 , 1 });
246
+ CHECK_THROWS_AS (Util::Transpose (mat, 2 , 2 ), std::invalid_argument);
247
+ CHECK_THROWS_WITH (Util::Transpose (mat, 2 , 2 ),
248
+ Contains (" Invalid m & n for the input matrix" ));
249
+ }
250
+ }
251
+ SECTION (" matrixMatProd" ) {
252
+ SECTION (" Simple Iterative" ) {
253
+ for (size_t m = 2 ; m < 8 ; m++) {
254
+ std::vector<std::complex<double >> m_left (m * m, {1 , 1 });
255
+ std::vector<std::complex<double >> m_right (m * m, {1 , 1 });
256
+ std::vector<std::complex<double >> m_out_exp (
257
+ m * m, {0 , static_cast <double >(2 * m)});
258
+ std::vector<std::complex<double >> m_out =
259
+ Util::matrixMatProd (m_left, m_right, m, m, m, true );
260
+ CAPTURE (m_out);
261
+ CAPTURE (m_out_exp);
262
+ for (size_t i = 0 ; i < m * m; i++) {
263
+ CHECK (isApproxEqual (m_out[i], m_out_exp[i]));
264
+ }
265
+ }
266
+ }
267
+ SECTION (" Random Complex" ) {
268
+ std::vector<std::complex<double >> m_left{
269
+ {0.94007 , 0.424517 }, {0.256163 , 0.0615097 },
270
+ {0.505297 , 0.343107 }, {0.729021 , 0.241991 },
271
+ {0.860825 , 0.633264 }, {0.987668 , 0.195166 },
272
+ {0.606897 , 0.144482 }, {0.0183697 , 0.375071 },
273
+ {0.355853 , 0.152383 }, {0.985341 , 0.0888863 },
274
+ {0.608352 , 0.653375 }, {0.268477 , 0.58398 },
275
+ {0.960381 , 0.786669 }, {0.498357 , 0.185307 },
276
+ {0.283511 , 0.844801 }, {0.269318 , 0.792981 }};
277
+ std::vector<std::complex<double >> m_right{
278
+ {0.94007 , 0.424517 }, {0.256163 , 0.0615097 },
279
+ {0.505297 , 0.343107 }, {0.729021 , 0.241991 },
280
+ {0.860825 , 0.633264 }, {0.987668 , 0.195166 },
281
+ {0.606897 , 0.144482 }, {0.0183697 , 0.375071 },
282
+ {0.355853 , 0.152383 }, {0.985341 , 0.0888863 },
283
+ {0.608352 , 0.653375 }, {0.268477 , 0.58398 },
284
+ {0.960381 , 0.786669 }, {0.498357 , 0.185307 },
285
+ {0.283511 , 0.844801 }, {0.269318 , 0.792981 }};
286
+ std::vector<std::complex<double >> m_right_tp{
287
+ {0.94007 , 0.424517 }, {0.860825 , 0.633264 },
288
+ {0.355853 , 0.152383 }, {0.960381 , 0.786669 },
289
+ {0.256163 , 0.0615097 }, {0.987668 , 0.195166 },
290
+ {0.985341 , 0.0888863 }, {0.498357 , 0.185307 },
291
+ {0.505297 , 0.343107 }, {0.606897 , 0.144482 },
292
+ {0.608352 , 0.653375 }, {0.283511 , 0.844801 },
293
+ {0.729021 , 0.241991 }, {0.0183697 , 0.375071 },
294
+ {0.268477 , 0.58398 }, {0.269318 , 0.792981 }};
295
+ std::vector<std::complex<double >> m_out_exp{
296
+ {1.522375435807200 , 2.018315393556500 },
297
+ {1.241561065671800 , 0.915996420839700 },
298
+ {0.561409446565600 , 1.834755796266900 },
299
+ {0.503973820211400 , 1.664651528374090 },
300
+ {1.183556828429700 , 2.272762769584300 },
301
+ {1.643767359748500 , 0.987318478828500 },
302
+ {0.752063484100700 , 1.482770126810700 },
303
+ {0.205343773497200 , 1.552791421044900 },
304
+ {0.977117116888800 , 2.092066653216500 },
305
+ {1.604565422784600 , 1.379671036009100 },
306
+ {0.238648365886400 , 1.582741563052100 },
307
+ {-0.401698027789600 , 1.469264325654110 },
308
+ {0.487510164243000 , 2.939585667799000 },
309
+ {0.845207296911400 , 1.843583823364000 },
310
+ {-0.482010055957000 , 2.062995137499000 },
311
+ {-0.524094900662100 , 1.815727577737900 }};
312
+ std::vector<std::complex<double >> m_out_1 =
313
+ Util::matrixMatProd (m_left, m_right_tp, 4 , 4 , 4 , true );
314
+ std::vector<std::complex<double >> m_out_2 =
315
+ Util::matrixMatProd (m_left, m_right, 4 , 4 , 4 , false );
316
+ CAPTURE (m_out_1);
317
+ CAPTURE (m_out_2);
318
+ CAPTURE (m_out_exp);
319
+ for (size_t i = 0 ; i < 16 ; i++) {
320
+ CHECK (isApproxEqual (m_out_1[i], m_out_2[i]));
321
+ }
322
+ for (size_t i = 0 ; i < 16 ; i++) {
323
+ CHECK (isApproxEqual (m_out_1[i], m_out_exp[i]));
324
+ }
325
+ }
326
+ SECTION (" Invalid Arguments" ) {
327
+ using namespace Catch ::Matchers;
328
+ std::vector<std::complex<double >> m_left (2 * 3 , {1 , 1 });
329
+ std::vector<std::complex<double >> m_right (3 * 4 , {1 , 1 });
330
+ CHECK_THROWS_AS (Util::matrixMatProd (m_left, m_right, 2 , 3 , 4 ),
331
+ std::invalid_argument);
332
+ CHECK_THROWS_WITH (
333
+ Util::matrixMatProd (m_left, m_right, 2 , 3 , 4 ),
334
+ Contains (" Invalid m & k for the input left matrix" ));
335
+ CHECK_THROWS_AS (Util::matrixMatProd (m_left, m_right, 2 , 3 , 3 ),
336
+ std::invalid_argument);
337
+ CHECK_THROWS_WITH (
338
+ Util::matrixMatProd (m_left, m_right, 2 , 3 , 3 ),
339
+ Contains (" Invalid k & n for the input right matrix" ));
340
+ }
341
+ }
145
342
}
0 commit comments