Skip to content

Commit

Permalink
Added gather_topk, separated base_iteration.
Browse files Browse the repository at this point in the history
  • Loading branch information
rogerpearce committed Aug 3, 2024
1 parent 32fcffa commit c01c6fd
Show file tree
Hide file tree
Showing 7 changed files with 342 additions and 56 deletions.
5 changes: 5 additions & 0 deletions include/ygm/container/counting_set.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ class counting_set
m_map.local_for_all(fn);
}

// Const overload of local_for_all: hands `fn` unchanged to the underlying
// m_map's local_for_all.
template <typename Function>
void local_for_all(Function fn) const {
m_map.local_for_all(fn);
}

void local_clear() { // What to do here
m_map.local_clear();
clear_cache();
Expand Down
259 changes: 238 additions & 21 deletions include/ygm/container/detail/base_iteration.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#pragma once

#include <algorithm>
#include <functional>
#include <tuple>
#include <utility>
#include <vector>
#include <ygm/collective.hpp>
#include <ygm/container/detail/base_concepts.hpp>

Expand All @@ -22,6 +23,14 @@ class flatten_proxy;

// Primary template: selected only when for_all_args matches none of the
// constrained partial specializations of base_iteration. The static_assert
// condition is always false, but it is written in terms of for_all_args so it
// is only evaluated when this template is actually instantiated, turning an
// unsupported for_all_args into a readable compile-time error.
template <typename derived_type, typename for_all_args>
struct base_iteration {
static_assert(sizeof(for_all_args) != sizeof(for_all_args),
"Unsupported for_all_args");
};

template <typename derived_type, SingleItemTuple for_all_args>
struct base_iteration<derived_type, for_all_args> {
using value_type = typename std::tuple_element<0, for_all_args>::type;

template <typename Function>
void for_all(Function fn) {
derived_type* derived_this = static_cast<derived_type*>(this);
Expand All @@ -38,6 +47,8 @@ struct base_iteration {

template <typename STLContainer>
void gather(STLContainer& gto, int rank) const {
static_assert(
std::is_same_v<typename STLContainer::value_type, value_type>);
// TODO, make an all gather version that defaults to rank = -1 & uses a temp
// container.
bool all_gather = (rank == -1);
Expand All @@ -56,14 +67,47 @@ struct base_iteration {
derived_this->comm().barrier();
}

template <typename MergeFunction>
std::tuple_element<0, for_all_args>::type reduce(MergeFunction merge) const
template <typename Compare = std::greater<value_type>>
std::vector<value_type> gather_topk(
size_t k, Compare comp = std::greater<value_type>()) const
requires SingleItemTuple<for_all_args>
{
const derived_type* derived_this = static_cast<const derived_type*>(this);
const ygm::comm& mycomm = derived_this->comm();
std::vector<value_type> local_topk;

//
// Find local top_k
for_all([&local_topk, comp, k](const value_type& value) {
local_topk.push_back(value);
std::sort(local_topk.begin(), local_topk.end(), comp);
if (local_topk.size() > k) {
local_topk.pop_back();
}
});

//
// All reduce global top_k
auto to_return = mycomm.all_reduce(
local_topk, [comp, k](const std::vector<value_type>& va,
const std::vector<value_type>& vb) {
std::vector<value_type> out(va.begin(), va.end());
out.insert(out.end(), vb.begin(), vb.end());
std::sort(out.begin(), out.end(), comp);
while (out.size() > k) {
out.pop_back();
}
return out;
});
return to_return;
}

template <typename MergeFunction>
value_type reduce(MergeFunction merge) const {
const derived_type* derived_this = static_cast<const derived_type*>(this);
derived_this->comm().barrier();
YGM_ASSERT_RELEASE(derived_this->local_size() >
0); // empty partition not handled yet
0); // empty partition not handled yet

using value_type = typename std::tuple_element<0, for_all_args>::type;
bool first = true;
Expand All @@ -89,8 +133,162 @@ struct base_iteration {
template <typename YGMContainer>
void collect(YGMContainer& c) const {
const derived_type* derived_this = static_cast<const derived_type*>(this);
auto clambda = [&c](const std::tuple_element<0, for_all_args>::type& item) {
c.async_insert(item);
auto clambda = [&c](const value_type& item) { c.async_insert(item); };
derived_this->for_all(clambda);
}

/**
 * @brief Feeds every (key, value) item visited by for_all() into
 *        map.async_reduce(key, value, reducer).
 *
 * @tparam MapType     Type providing key_type, mapped_type and async_reduce.
 * @tparam ReductionOp Callable forwarded to async_reduce.
 * @param map     Destination map; captured by reference by the visitor.
 * @param reducer Reduction operator, copied into the visitor.
 *
 * Compilation fails unless this container's value_type is exactly
 * std::pair<MapType::key_type, MapType::mapped_type>.
 */
template <typename MapType, typename ReductionOp>
void reduce_by_key(MapType& map, ReductionOp reducer) const {
  // TODO: static_assert MapType is ygm::container::map
  using reduce_key_type   = typename MapType::key_type;
  using reduce_value_type = typename MapType::mapped_type;
  using item_type         = std::pair<reduce_key_type, reduce_value_type>;
  static_assert(
      std::is_same_v<value_type, item_type>,
      "value_type must be std::pair<MapType::key_type, MapType::mapped_type>");

  const derived_type* derived_this = static_cast<const derived_type*>(this);

  // Take each item by const reference: the visitor only reads it, so there is
  // no reason to copy a pair per element.
  derived_this->for_all([&map, reducer](const item_type& kvp) {
    map.async_reduce(kvp.first, kvp.second, reducer);
  });
}

// Declared here, defined out-of-class further down in this header: returns a
// map_proxy constructed from this object (viewed as derived_type) and ffn.
template <typename MapFunction>
map_proxy<derived_type, MapFunction> map(MapFunction ffn);

// Defined out-of-class: returns a flatten_proxy constructed from this object
// (viewed as derived_type).
flatten_proxy<derived_type> flatten();

// Defined out-of-class: returns a filter_proxy constructed from this object
// (viewed as derived_type) and ffn.
template <typename FilterFunction>
filter_proxy<derived_type, FilterFunction> filter(FilterFunction ffn);

private:
// Insertion helper, push_back flavour: this overload is only viable when
// `stc.push_back(v)` is a valid expression for the given container and value
// types. Appends `value` to `stc`.
template <typename STLContainer, typename Value>
requires requires(STLContainer stc, Value v) { stc.push_back(v); }
static void generic_insert(STLContainer& stc, const Value& value) {
stc.push_back(value);
}

// Insertion helper, insert flavour: this overload is only viable when
// `stc.insert(v)` (single-argument insert) is a valid expression for the given
// container and value types. Inserts `value` into `stc`.
template <typename STLContainer, typename Value>
requires requires(STLContainer stc, Value v) { stc.insert(v); }
static void generic_insert(STLContainer& stc, const Value& value) {
stc.insert(value);
}
};

template <typename derived_type, DoubleItemTuple for_all_args>
struct base_iteration<derived_type, for_all_args> {
using key_type = typename std::tuple_element<0, for_all_args>::type;
using mapped_type = typename std::tuple_element<1, for_all_args>::type;

// Visits the local items with `fn`: first waits on the communicator barrier,
// then forwards `fn` to the derived container's local_for_all.
template <typename Function>
void for_all(Function fn) {
  auto& self = *static_cast<derived_type*>(this);
  self.comm().barrier();
  self.local_for_all(fn);
}

template <typename Function>
void for_all(Function fn) const {
const derived_type* derived_this = static_cast<const derived_type*>(this);
derived_this->comm().barrier();
derived_this->local_for_all(fn);
}

/**
 * @brief Sends every (key, value) item visited by for_all() to `rank`, where
 *        it is inserted into that rank's `gto` as a std::pair.
 *
 * @tparam STLContainer Container whose value_type is exactly
 *                      std::pair<key_type, mapped_type>.
 * @param gto  Destination container; only the copy on `rank` receives items.
 * @param rank Destination rank passed to comm().async().
 *
 * Ends with a comm().barrier() after all sends have been issued.
 */
template <typename STLContainer>
void gather(STLContainer& gto, int rank) const {
  static_assert(std::is_same_v<typename STLContainer::value_type,
                               std::pair<key_type, mapped_type>>);
  // TODO, make an all gather version that defaults to rank = -1 & uses a temp
  // container.

  // The handler passed to async() is captureless, so it reaches the
  // destination through this function-local static. It must be assigned on
  // every call: a static's initializer runs only the first time, which would
  // leave later calls writing into the first call's (possibly destroyed)
  // container.
  static STLContainer* spgto = nullptr;
  spgto                      = &gto;

  const derived_type* derived_this = static_cast<const derived_type*>(this);
  const ygm::comm&    mycomm       = derived_this->comm();

  auto glambda = [&mycomm, rank](const key_type&    key,
                                 const mapped_type& value) {
    mycomm.async(
        rank,
        [](const key_type& key, const mapped_type& value) {
          generic_insert(*spgto, std::make_pair(key, value));
        },
        key, value);
  };

  for_all(glambda);

  derived_this->comm().barrier();
}

template <typename Compare = std::greater<std::pair<key_type, mapped_type>>>
std::vector<std::pair<key_type, mapped_type>> gather_topk(
size_t k, Compare comp = Compare()) const {
const derived_type* derived_this = static_cast<const derived_type*>(this);
const ygm::comm& mycomm = derived_this->comm();
using vec_type = std::vector<std::pair<key_type, mapped_type>>;
vec_type local_topk;

//
// Find local top_k
for_all(
[&local_topk, comp, k](const key_type& key, const mapped_type& mapped) {
local_topk.push_back(std::make_pair(key, mapped));
std::sort(local_topk.begin(), local_topk.end(), comp);
if (local_topk.size() > k) {
local_topk.pop_back();
}
});

//
// All reduce global top_k
auto to_return = mycomm.all_reduce(
local_topk, [comp, k](const vec_type& va, const vec_type& vb) {
vec_type out(va.begin(), va.end());
out.insert(out.end(), vb.begin(), vb.end());
std::sort(out.begin(), out.end(), comp);
while (out.size() > k) {
out.pop_back();
}
return out;
});
return to_return;
}

template <typename MergeFunction>
std::pair<key_type, mapped_type> reduce(MergeFunction merge) const {
const derived_type* derived_this = static_cast<const derived_type*>(this);
derived_this->comm().barrier();

bool first = true;

std::pair<key_type, mapped_type> local_reduce;

auto rlambda = [&local_reduce, &first,
&merge](const std::pair<key_type, mapped_type>& value) {
if (first) {
local_reduce = value;
first = false;
} else {
local_reduce = merge(local_reduce, value);
}
};

derived_this->for_all(rlambda);

std::optional<std::pair<key_type, mapped_type>> to_reduce;
if (first) { // local partition was empty!
to_reduce = std::move(local_reduce);
}

return ::ygm::all_reduce(to_reduce, merge, derived_this->comm());
}

template <typename YGMContainer>
void collect(YGMContainer& c) const {
const derived_type* derived_this = static_cast<const derived_type*>(this);
auto clambda = [&c](const key_type& key, const mapped_type& value) {
c.async_insert(std::make_pair(key, value));
};
derived_this->for_all(clambda);
}
Expand All @@ -101,19 +299,13 @@ struct base_iteration {
// static_assert ygm::map
using reduce_key_type = typename MapType::key_type;
using reduce_value_type = typename MapType::mapped_type;
if constexpr (std::tuple_size<for_all_args>::value == 1) {
// must be a std::pair
auto rbklambda = [&map, reducer](std::pair<reduce_key_type, reduce_value_type> kvp) {
map.async_reduce(kvp.first, kvp.second, reducer);
};
derived_this->for_all(rbklambda);
} else {
static_assert(std::tuple_size<for_all_args>::value == 2);
auto rbklambda = [&map, reducer](const reduce_key_type& key, const reduce_value_type& value) {
map.async_reduce(key, value, reducer);
};
derived_this->for_all(rbklambda);
}

static_assert(std::tuple_size<for_all_args>::value == 2);
auto rbklambda = [&map, reducer](const reduce_key_type& key,
const reduce_value_type& value) {
map.async_reduce(key, value, reducer);
};
derived_this->for_all(rbklambda);
}

template <typename MapFunction>
Expand Down Expand Up @@ -146,15 +338,15 @@ struct base_iteration {

namespace ygm::container::detail {

template <typename derived_type, typename for_all_args>
// Single-item specialization: wraps this container (as derived_type) and ffn
// in a map_proxy.
template <typename derived_type, SingleItemTuple for_all_args>
template <typename MapFunction>
map_proxy<derived_type, MapFunction>
base_iteration<derived_type, for_all_args>::map(MapFunction ffn) {
  using proxy_type = map_proxy<derived_type, MapFunction>;
  auto& self       = *static_cast<derived_type*>(this);
  return proxy_type(self, ffn);
}

template <typename derived_type, typename for_all_args>
template <typename derived_type, SingleItemTuple for_all_args>
inline flatten_proxy<derived_type>
base_iteration<derived_type, for_all_args>::flatten() {
// static_assert(
Expand All @@ -163,7 +355,32 @@ base_iteration<derived_type, for_all_args>::flatten() {
return flatten_proxy<derived_type>(*derived_this);
}

template <typename derived_type, typename for_all_args>
// Single-item specialization: wraps this container (as derived_type) and ffn
// in a filter_proxy.
template <typename derived_type, SingleItemTuple for_all_args>
template <typename FilterFunction>
filter_proxy<derived_type, FilterFunction>
base_iteration<derived_type, for_all_args>::filter(FilterFunction ffn) {
  using proxy_type = filter_proxy<derived_type, FilterFunction>;
  auto& self       = *static_cast<derived_type*>(this);
  return proxy_type(self, ffn);
}

// Double-item specialization: wraps this container (as derived_type) and ffn
// in a map_proxy.
template <typename derived_type, DoubleItemTuple for_all_args>
template <typename MapFunction>
map_proxy<derived_type, MapFunction>
base_iteration<derived_type, for_all_args>::map(MapFunction ffn) {
  using proxy_type = map_proxy<derived_type, MapFunction>;
  auto& self       = *static_cast<derived_type*>(this);
  return proxy_type(self, ffn);
}

// Double-item specialization: wraps this container (as derived_type) in a
// flatten_proxy.
template <typename derived_type, DoubleItemTuple for_all_args>
inline flatten_proxy<derived_type>
base_iteration<derived_type, for_all_args>::flatten() {
  // static_assert(
  //   type_traits::is_vector<std::tuple_element<0, for_all_args>>::value);
  auto& self = *static_cast<derived_type*>(this);
  return flatten_proxy<derived_type>(self);
}

template <typename derived_type, DoubleItemTuple for_all_args>
template <typename FilterFunction>
filter_proxy<derived_type, FilterFunction>
base_iteration<derived_type, for_all_args>::filter(FilterFunction ffn) {
Expand Down
33 changes: 1 addition & 32 deletions include/ygm/container/detail/map_proxy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,42 +8,11 @@
#include <functional>
#include <tuple>
#include <utility>
#include <ygm/container/detail/type_traits.hpp>

namespace ygm::container::detail {

namespace type_traits {

/// True when U is an instantiation of the class template T (primary: false).
template <template <typename...> class T, typename U>
struct is_specialization_of : std::false_type {};

/// Matches U == T<Us...>.
template <template <typename...> class T, typename... Us>
struct is_specialization_of<T, T<Us...>> : std::true_type {};

/// True when T, after std::decay, is some std::vector<...>.
template <typename T>
struct is_vector
    : is_specialization_of<std::vector, typename std::decay<T>::type> {};

/// True when T, after std::decay, is some std::tuple<...>.
template <typename T>
struct is_tuple
    : is_specialization_of<std::tuple, typename std::decay<T>::type> {};

/// Helper for tuple_wrapper; this primary handles "T is not a tuple" by
/// wrapping it in a one-element tuple.
template <class T, bool isTuple>
struct tuple_wrapper_helper  // T is not a tuple
{
  using type = std::tuple<T>;
};

/// "T is already a tuple": passed through unchanged.
template <class T>
struct tuple_wrapper_helper<T, true>  // T is a tuple
{
  using type = T;
};

/// Yields T itself when is_tuple<T> holds, otherwise std::tuple<T>.
template <class T>
struct tuple_wrapper {
  // `typename` is required here: the nested ::type is a dependent name, and
  // omitting the keyword is only accepted from C++20 onwards.
  using type = typename tuple_wrapper_helper<T, is_tuple<T>::value>::type;
};
}  // namespace type_traits

template <typename Container, typename MapFunction>
class map_proxy
Expand Down
Loading

0 comments on commit c01c6fd

Please sign in to comment.