Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Asadchev/feature/einsum ta dot #369

Merged
merged 12 commits into from
Nov 17, 2022
4 changes: 2 additions & 2 deletions src/TiledArray/einsum/eigen.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ void einsum(std::string expr,

using Index = TiledArray::Einsum::Index<char>;
using IndexDims = TiledArray::Einsum::IndexMap<char, size_t>;
using TiledArray::Einsum::string::split2;
using ::Einsum::string::split2;

auto permutation = [](auto src, auto dst) {
return TiledArray::Einsum::index::permutation(dst, src);
return TiledArray::Einsum::permutation(dst, src);
};

Index a, b, c;
Expand Down
4 changes: 2 additions & 2 deletions src/TiledArray/einsum/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include "TiledArray/einsum/string.h"
#include "TiledArray/util/annotation.h"

namespace TiledArray::Einsum::index {
namespace Einsum::index {

std::vector<std::string> validate(const std::vector<std::string> &v) {
return v;
Expand All @@ -13,7 +13,7 @@ small_vector<std::string> tokenize(const std::string &s) {
// std::vector<std::string> r;
// boost::split(r, s, boost::is_any_of(", \t"));
// return r;
auto r = detail::tokenize_index(s, ',');
auto r = TiledArray::detail::tokenize_index(s, ',');
if (r == std::vector<std::string>{""}) return {};
return small_vector<std::string> (r.begin(), r.end()); // correct?
}
Expand Down
45 changes: 13 additions & 32 deletions src/TiledArray/einsum/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include <iosfwd>
#include <string>

namespace TiledArray::Einsum::index {
namespace Einsum::index {

template <typename T>
using small_vector = TiledArray::container::svector<T>;
Expand Down Expand Up @@ -65,7 +65,7 @@ class Index {

template <typename U = void>
operator std::string() const {
return string::join(data_, ",");
return Einsum::string::join(data_, ",");
}

explicit operator bool() const { return !data_.empty(); }
Expand Down Expand Up @@ -186,6 +186,8 @@ Index<T> sorted(const Index<T> &a) {
return Index<T>(r);
}

using Permutation = TiledArray::Permutation;

/// @param[in] from original (preimage) indices
/// @param[in] to target (image) indices
/// @return Permutation mapping @p from to @p to
Expand All @@ -207,8 +209,11 @@ auto permute(const Permutation &p, const Index<T> &s,
if (!p) return s;
using R = typename Index<T>::container_type;
R r(p.size());
detail::permute_n(p.size(), p.begin(), s.begin(), r.begin(),
std::bool_constant<Inverse>{});
TiledArray::detail::permute_n(
p.size(),
p.begin(), s.begin(), r.begin(),
std::bool_constant<Inverse>{}
);
return Index<T>{r};
}

Expand Down Expand Up @@ -298,35 +303,11 @@ IndexMap<K, V> operator|(const IndexMap<K, V> &a, const IndexMap<K, V> &b) {
return IndexMap(d);
}

} // namespace TiledArray::Einsum::index

namespace TiledArray::Einsum {

using TiledArray::Einsum::index::Index;
using TiledArray::Einsum::index::IndexMap;

/// converts the annotation of an expression to an Index
template <typename Array>
auto idx(const std::string &s) {
using Index = Einsum::Index<std::string>;
if constexpr (detail::is_tensor_of_tensor_v<typename Array::value_type>) {
auto semi = std::find(s.begin(), s.end(), ';');
TA_ASSERT(semi != s.end());
auto [first,second] = string::split2(s, ";");
TA_ASSERT(!first.empty());
TA_ASSERT(!second.empty());
return std::tuple<Index, Index>{first, second};
} else {
return std::tuple<Index>{s};
}
}

/// converts the annotation of an expression to an Index
template <typename A, bool Alias>
auto idx(const TiledArray::expressions::TsrExpr<A, Alias> &e) {
return idx<A>(e.annotation());
}
} // namespace Einsum::index

namespace Einsum {
using index::Index;
using index::IndexMap;
} // namespace TiledArray::Einsum

#endif /* TILEDARRAY_EINSUM_INDEX_H__INCLUDED */
8 changes: 4 additions & 4 deletions src/TiledArray/einsum/range.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
#include <vector>
#include <boost/iterator/counting_iterator.hpp>

namespace TiledArray::Einsum::range {
namespace Einsum::range {

template<typename T>
using small_vector = container::svector<T>;
using small_vector = TiledArray::container::svector<T>;

struct Range {
using value_type = int64_t;
Expand Down Expand Up @@ -130,9 +130,9 @@ void cartesian_foreach(const std::vector<R>& rs, F f) {
}
}

} // namespace TiledArray::Einsum::range
} // namespace Einsum::range

namespace TiledArray::Einsum {
namespace Einsum {
using range::Range;
using range::RangeProduct;
}
Expand Down
2 changes: 1 addition & 1 deletion src/TiledArray/einsum/string.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <string>
#include <vector>

namespace TiledArray::Einsum::string {
namespace Einsum::string {
namespace {

// Split delimiter must match completely
Expand Down
Loading