Skip to content

Commit

Permalink
Adds overload for NumpyToTensorView and io::ReadMAT so that the shape…
Browse files Browse the repository at this point in the history
… of the array can be obtained from numpy
  • Loading branch information
nvjonwong authored and cliffburdick committed Jun 5, 2023
1 parent 612579b commit e016476
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 1 deletion.
38 changes: 38 additions & 0 deletions include/matx/core/file_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,44 @@ void ReadMAT(TensorType &t, const std::string fname,
pb->NumpyToTensorView(t, v);
}

/**
* @brief Read a MAT file into a tensor view
*
* MAT files use SciPy's loadmat() function to read various MATLAB file
* types in. MAT files are supersets of HDF5 files, and are allowed to
* have multiple fields in them.
*
* @tparam TensorType
* Data type of tensor
* @param t
* Tensor to read data into
* @param fname
* File name of .mat file
* @param var
* Variable name inside of .mat to read
*
**/
template <typename TensorType>
auto ReadMAT(const std::string fname,
const std::string var)
{

if (!std::filesystem::exists(fname)) {
const std::string errorMessage = "Failed to read [" + fname + "], Does not Exist";
MATX_THROW(matxIOError, errorMessage.c_str());
}

MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)

auto pb = std::make_unique<detail::MatXPybind>();

auto sp = pybind11::module_::import("scipy.io");
auto obj = (pybind11::dict)sp.attr("loadmat")("file_name"_a = fname);
auto v = obj[var.c_str()];

return pb->NumpyToTensorView<TensorType>(v);
}

/**
* @brief Write a MAT file from a tensor view
*
Expand Down
22 changes: 21 additions & 1 deletion include/matx/core/pybind.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#pragma once

#include "matx/core/type_utils.h"

#include "matx/core/make_tensor.h"

#if MATX_ENABLE_PYBIND11

Expand Down Expand Up @@ -358,6 +358,26 @@ class MatXPybind {
}
}

template <typename TensorType>
auto NumpyToTensorView(const pybind11::object &np_ten)
{
using T = typename TensorType::scalar_type;
constexpr int RANK = TensorType::Rank();
using ntype = matx_convert_complex_type<T>;
auto ften = pybind11::array_t<ntype, pybind11::array::c_style | pybind11::array::forcecast>(np_ten);

auto info = ften.request();

assert(info.ndim == RANK);

std::array<matx::index_t, RANK> shape;
std::copy_n(info.shape.begin(), RANK, std::begin(shape));

auto ten = make_tensor<T> (shape);
std::copy(ften.data(), ften.data() + ften.size(), ten.Data() );
return ten;
}

template <typename TensorType>
auto TensorViewToNumpy(const TensorType &ten)
{
Expand Down
31 changes: 31 additions & 0 deletions test/00_io/FileIOTests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,37 @@ TYPED_TEST(FileIoTestsNonComplexFloatTypes, MATWriteRank5)
// Read "myvar" from mat file
io::WriteMAT(t, "test_write.mat", "myvar");
io::ReadMAT(t2, "test_write.mat", "myvar");
for (index_t i = 0; i < t.Size(0); i++) {
for (index_t j = 0; j < t.Size(1); j++) {
for (index_t k = 0; k < t.Size(2); k++) {
for (index_t l = 0; l < t.Size(3); l++) {
for (index_t m = 0; m < t.Size(4); m++) {
ASSERT_EQ(t(i,j,k,l,m), t2(i,j,k,l,m));
}
}
}
}
}
MATX_EXIT_HANDLER();
}

TYPED_TEST(FileIoTestsNonComplexFloatTypes, MATWriteRank5GetShape)
{
MATX_ENTER_HANDLER();

auto t = make_tensor<TypeParam>({2,3,1,2,3});
tensor_t<TypeParam,5> t2;

randomGenerator_t<TypeParam> gen(t.TotalSize(), 0);
auto random = gen.GetTensorView(t.Shape(), UNIFORM);
(t = random).run();

cudaDeviceSynchronize();

// Read "myvar" from mat file
io::WriteMAT(t, "test_write.mat", "myvar");
t2.Shallow(io::ReadMAT<decltype(t2)>("test_write.mat", "myvar"));

for (index_t i = 0; i < t.Size(0); i++) {
for (index_t j = 0; j < t.Size(1); j++) {
for (index_t k = 0; k < t.Size(2); k++) {
Expand Down

0 comments on commit e016476

Please sign in to comment.