Skip to content

Commit

Permalink
Merge branch 'master' into kmp5/experimental/upkeep_btas
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT committed Apr 14, 2024
2 parents b036eba + d72357a commit 3ec0502
Show file tree
Hide file tree
Showing 37 changed files with 2,485 additions and 796 deletions.
30 changes: 17 additions & 13 deletions src/TiledArray/dist_eval/array_eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,15 @@ class ArrayEvalImpl
/// \param pmap The process map for the result tensor tiles
/// \param perm The permutation that is applied to the tile coordinate index
/// \param op The operation that will be used to evaluate the tiles of array
template <typename Perm, typename = std::enable_if_t<
TiledArray::detail::is_permutation_v<Perm>>>
template <typename Perm,
typename = std::enable_if_t<TiledArray::detail::is_permutation_v<
std::remove_reference_t<Perm>>>>
ArrayEvalImpl(const array_type& array, World& world,
const trange_type& trange, const shape_type& shape,
const std::shared_ptr<const pmap_interface>& pmap,
const Perm& perm, const op_type& op)
: DistEvalImpl_(world, trange, shape, pmap, outer(perm)),
const std::shared_ptr<const pmap_interface>& pmap, Perm&& perm,
const op_type& op)
: DistEvalImpl_(world, trange, shape, pmap,
outer(std::forward<Perm>(perm))),
array_(array),
op_(std::make_shared<op_type>(op)),
block_range_()
Expand Down Expand Up @@ -273,17 +275,19 @@ class ArrayEvalImpl
/// \param op The operation that will be used to evaluate the tiles of array
/// \param lower_bound The sub-block lower bound
/// \param upper_bound The sub-block upper bound
template <typename Index1, typename Index2, typename Perm,
typename = std::enable_if_t<
TiledArray::detail::is_integral_range_v<Index1> &&
TiledArray::detail::is_integral_range_v<Index2> &&
TiledArray::detail::is_permutation_v<Perm>>>
template <
typename Index1, typename Index2, typename Perm,
typename = std::enable_if_t<
TiledArray::detail::is_integral_range_v<Index1> &&
TiledArray::detail::is_integral_range_v<Index2> &&
TiledArray::detail::is_permutation_v<std::remove_reference_t<Perm>>>>
ArrayEvalImpl(const array_type& array, World& world,
const trange_type& trange, const shape_type& shape,
const std::shared_ptr<const pmap_interface>& pmap,
const Perm& perm, const op_type& op, const Index1& lower_bound,
const std::shared_ptr<const pmap_interface>& pmap, Perm&& perm,
const op_type& op, const Index1& lower_bound,
const Index2& upper_bound)
: DistEvalImpl_(world, trange, shape, pmap, outer(perm)),
: DistEvalImpl_(world, trange, shape, pmap,
outer(std::forward<Perm>(perm))),
array_(array),
op_(std::make_shared<op_type>(op)),
block_range_(array.trange().tiles_range(), lower_bound, upper_bound)
Expand Down
37 changes: 17 additions & 20 deletions src/TiledArray/einsum/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

#include "TiledArray/expressions/fwd.h"

#include <TiledArray/einsum/string.h>
#include <TiledArray/error.h>
#include <TiledArray/permutation.h>
#include <TiledArray/util/vector.h>
#include <TiledArray/einsum/string.h>

#include <iosfwd>
#include <string>
Expand All @@ -29,10 +29,11 @@ class Index {
public:
using container_type = small_vector<T>;
using value_type = typename container_type::value_type;
using iterator = typename container_type::iterator;

Index() = default;
Index(const container_type &s) : data_(s) {}
Index(const std::initializer_list<T> &s) : data_(s) {}
explicit Index(const std::initializer_list<T> &s) : data_(s) {}

template <typename S, typename U = void>
Index(const S &s) {
Expand All @@ -45,18 +46,14 @@ class Index {
Index(const char (&s)[N]) : Index(std::string(s)) {}

template <typename U = void>
explicit Index(const char* &s) : Index(std::string(s)) {}
explicit Index(const char *&s) : Index(std::string(s)) {}

template <typename U = void>
explicit Index(const std::string &s) {
static_assert(
std::is_same_v<T,char> ||
std::is_same_v<T,std::string>
);
if constexpr (std::is_same_v<T,std::string>) {
static_assert(std::is_same_v<T, char> || std::is_same_v<T, std::string>);
if constexpr (std::is_same_v<T, std::string>) {
data_ = index::tokenize(s);
}
else {
} else {
using std::begin;
using std::end;
data_.assign(begin(s), end(s));
Expand All @@ -78,8 +75,11 @@ class Index {

size_t size() const { return data_.size(); }

auto begin() const { return data_.begin(); }
auto end() const { return data_.end(); }
auto begin() const { return data_.cbegin(); }
auto end() const { return data_.cend(); }

auto begin() { return data_.begin(); }
auto end() { return data_.end(); }

auto find(const T &v) const {
return std::find(this->begin(), this->end(), v);
Expand Down Expand Up @@ -209,11 +209,8 @@ auto permute(const Permutation &p, const Index<T> &s,
if (!p) return s;
using R = typename Index<T>::container_type;
R r(p.size());
TiledArray::detail::permute_n(
p.size(),
p.begin(), s.begin(), r.begin(),
std::bool_constant<Inverse>{}
);
TiledArray::detail::permute_n(p.size(), p.begin(), s.begin(), r.begin(),
std::bool_constant<Inverse>{});
return Index<T>{r};
}

Expand Down Expand Up @@ -306,8 +303,8 @@ IndexMap<K, V> operator|(const IndexMap<K, V> &a, const IndexMap<K, V> &b) {
} // namespace Einsum::index

namespace Einsum {
using index::Index;
using index::IndexMap;
} // namespace TiledArray::Einsum
using index::Index;
using index::IndexMap;
} // namespace Einsum

#endif /* TILEDARRAY_EINSUM_INDEX_H__INCLUDED */
64 changes: 32 additions & 32 deletions src/TiledArray/einsum/string.h
Original file line number Diff line number Diff line change
@@ -1,50 +1,50 @@
#ifndef TILEDARRAY_EINSUM_STRING_H
#define TILEDARRAY_EINSUM_STRING_H

#include <boost/algorithm/string/join.hpp>
#include <boost/algorithm/string/split.hpp>
#include <boost/algorithm/string/trim.hpp>
#include <boost/algorithm/string/join.hpp>
#include <sstream>
#include <string>
#include <vector>

namespace Einsum::string {
namespace {

// Split delimiter must match completely
template<typename T = std::string, typename U = T>
std::pair<T,U> split2(const std::string& s, const std::string &d) {
auto pos = s.find(d);
if (pos == s.npos) return { T(s), U("") };
return { T(s.substr(0,pos)), U(s.substr(pos+d.size())) };
}
// Split delimiter must match completely
template <typename T = std::string, typename U = T>
std::pair<T, U> split2(const std::string& s, const std::string& d) {
auto pos = s.find(d);
if (pos == s.npos) return {T(s), U("")};
return {T(s.substr(0, pos)), U(s.substr(pos + d.size()))};
}

// Split delimiter must match completely
std::vector<std::string> split(const std::string& s, char d) {
std::vector<std::string> res;
return boost::split(res, s, [&d](char c) { return c == d; } /*boost::is_any_of(d)*/);
}
// Split delimiter must match completely
std::vector<std::string> split(const std::string& s, char d) {
std::vector<std::string> res;
return boost::split(res, s,
[&d](char c) { return c == d; } /*boost::is_any_of(d)*/);
}

std::string trim(const std::string& s) {
return boost::trim_copy(s);
}
std::string trim(const std::string& s) { return boost::trim_copy(s); }

template <typename T>
std::string str(const T& obj) {
std::stringstream ss;
ss << obj;
return ss.str();
}
template <typename T>
std::string str(const T& obj) {
std::stringstream ss;
ss << obj;
return ss.str();
}

template<typename T, typename U = std::string>
std::string join(const T &s, const U& j = U("")) {
std::vector<std::string> strings;
for (auto e : s) {
strings.push_back(str(e));
}
return boost::join(strings, j);
template <typename T, typename U = std::string>
std::string join(const T& s, const U& j = U("")) {
std::vector<std::string> strings;
for (auto e : s) {
strings.push_back(str(e));
}

}
return boost::join(strings, j);
}

#endif //TILEDARRAY_EINSUM_STRING_H
} // namespace
} // namespace Einsum::string

#endif // TILEDARRAY_EINSUM_STRING_H
Loading

0 comments on commit 3ec0502

Please sign in to comment.