Skip to content

Commit

Permalink
Rework einsum for new cache style. Fix for issue #597 (#599)
Browse files Browse the repository at this point in the history
* Rework einsum for new cache style.  Fix for issue #597

* Rework MatX cache to use CacheId instead of CacheName.

* Switch from GetCacheIdFromFunction to GetCacheIdFromType
  • Loading branch information
tmartin-gh authored Mar 24, 2024
1 parent cc6cd07 commit 2c34e34
Show file tree
Hide file tree
Showing 9 changed files with 329 additions and 332 deletions.
41 changes: 19 additions & 22 deletions include/matx/core/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,24 @@
#include <optional>
#include <any>
#include <unordered_map>
#include <cuda/atomic>

#include "matx/core/error.h"

namespace matx {
namespace detail {

enum class CacheName {
FFT_1D,
FFT_2D,
CHOL,
LU,
QR,
SVD,
EIG,
CUB,
GEMM,
COV,
FILTER,
INV
};
// Unique identifier used as the key for a transform's parameter cache
using CacheId = uint64_t;

// Process-wide monotonically-increasing counter that hands out cache IDs.
// cuda::std::atomic makes concurrent fetch_add calls safe.
inline cuda::std::atomic<CacheId> CacheIdCounter{0};

/**
 * Return a unique CacheId for the given CacheType.
 *
 * The function-local static is initialized exactly once per CacheType
 * instantiation (thread-safe since C++11), so repeated calls with the same
 * type always return the same id, and each distinct CacheType draws a
 * distinct value from CacheIdCounter.
 */
template<typename CacheType>
CacheId GetCacheIdFromType()
{
static CacheId id = CacheIdCounter.fetch_add(1);

return id;
}

/**
* Generic caching object for caching parameters. This class is used for
Expand All @@ -72,7 +69,7 @@ class matxCache_t {
/**
 * Destructor: release every cached object so its memory is freed.
 *
 * Each cache entry is a std::any holding a per-transform map; reset()
 * destroys the contained object. (The diff scrape had duplicated the
 * reset() line; a single call per entry is the intended code.)
 */
~matxCache_t() {
  // Destroy all outstanding objects in the cache to free memory
  for (auto &[k, v] : cache) {
    v.reset();
  }
}

Expand All @@ -81,22 +78,22 @@ class matxCache_t {
*
*/
template <typename CacheType>
void Clear(const CacheName &name) {
auto el = cache.find(name);
void Clear(const CacheId &id) {
auto el = cache.find(id);
MATX_ASSERT_STR(el != cache.end(), matxInvalidType, "Cache type not found");

std::any_cast<CacheType>(el->second).clear();
}

template <typename CacheType, typename InParams, typename MakeFun, typename ExecFun>
void LookupAndExec(const CacheName &name, const InParams &params, const MakeFun &mfun, const ExecFun &efun) {
void LookupAndExec(const CacheId &id, const InParams &params, const MakeFun &mfun, const ExecFun &efun) {
// Create named cache if it doesn't exist
auto el = cache.find(name);
auto el = cache.find(id);
if (el == cache.end()) {
cache[name] = CacheType{};
cache[id] = CacheType{};
}

auto &cval = cache[name];
auto &cval = cache[id];
auto &rmap = std::any_cast<CacheType&>(cval);
auto cache_el = rmap.find(params);
if (cache_el == rmap.end()) {
Expand All @@ -110,7 +107,7 @@ class matxCache_t {
}

private:
std::unordered_map<CacheName, std::any> cache;
std::unordered_map<CacheId, std::any> cache;
};

/**
Expand Down
14 changes: 7 additions & 7 deletions include/matx/transforms/cov.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ template <typename TensorTypeC, typename TensorTypeA> class matxCovHandle_t {
static_assert(RANK >= 2);
MATX_ASSERT(c.Size(RANK - 1) == c.Size(RANK - 2), matxInvalidSize);
MATX_ASSERT(a.Size(RANK - 1) == c.Size(RANK - 1), matxInvalidSize);
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

// Ensure batch dimensions are equal
for (int i = 2; i < RANK - 2; i++) {
Expand Down Expand Up @@ -144,7 +144,7 @@ template <typename TensorTypeC, typename TensorTypeA> class matxCovHandle_t {
inline void Exec(TensorTypeC &c, const TensorTypeA &a,
cudaStream_t stream)
{
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
// Calculate a matrix of means
matmul_impl(means, onesM, a, stream,
1.0f / static_cast<float>(a.Size(RANK - 2)));
Expand All @@ -167,7 +167,7 @@ template <typename TensorTypeC, typename TensorTypeA> class matxCovHandle_t {
// Multiply by itself and scale by N-1 for the final covariance
matmul_impl(c, devsT, devs, stream,
1.0f / static_cast<float>(a.Size(RANK - 2) - 1));
}
}

private:
// Member variables
Expand Down Expand Up @@ -231,21 +231,21 @@ template <typename TensorTypeC, typename TensorTypeA>
void cov_impl(TensorTypeC &c, const TensorTypeA &a,
cudaStream_t stream = 0)
{
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
// Get parameters required by these tensors
auto params = detail::matxCovHandle_t<TensorTypeC, TensorTypeA>::GetCovParams(c, a, stream);

using cache_val_type = detail::matxCovHandle_t<TensorTypeC, TensorTypeA>;
detail::GetCache().LookupAndExec<detail::cov_cache_t>(
detail::CacheName::COV,
detail::GetCacheIdFromType<detail::cov_cache_t>(),
params,
[&]() {
return std::make_shared<cache_val_type>(c, a);
},
[&](std::shared_ptr<cache_val_type> ctype) {
ctype->Exec(c, a, stream);
}
);
);
}

} // end namespace matx
Loading

0 comments on commit 2c34e34

Please sign in to comment.