Skip to content

Commit

Permalink
Adds ygm::container::array::sort() (#221)
Browse files Browse the repository at this point in the history
  • Loading branch information
steiltre authored Jul 30, 2024
1 parent c0e3b88 commit a10a0a5
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 0 deletions.
71 changes: 71 additions & 0 deletions include/ygm/container/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#pragma once

#include <concepts>
#include <random>

#include <ygm/collective.hpp>
#include <ygm/comm.hpp>
Expand Down Expand Up @@ -412,6 +413,76 @@ class array
}
}

void sort() {
const key_type samples_per_pivot = std::max<key_type>(
std::min<key_type>(20, m_global_size / m_comm.size()), 1);
std::vector<mapped_type> to_sort;
to_sort.reserve(local_size() * 1.1f);

//
// Choose pivots, uses index as 3rd sorting argument to solve issue with
// lots of duplicate items
std::vector<std::pair<mapped_type, key_type>> samples;
std::vector<std::pair<mapped_type, key_type>> pivots;
static auto& s_samples = samples;
static auto& s_to_sort = to_sort;
samples.reserve((m_comm.size() - 1) * samples_per_pivot);

std::default_random_engine rng;

std::uniform_int_distribution<size_t> uintdist{0, size() - 1};

for (size_t i = 0; i < samples_per_pivot * (m_comm.size() - 1); ++i) {
size_t index = uintdist(rng);
if (index >= partitioner.local_start() &&
index < partitioner.local_start() + partitioner.local_size()) {
m_comm.async_bcast(
[](const std::pair<mapped_type, key_type>& sample) {
s_samples.push_back(sample);
},
std::make_pair(m_local_vec[index - partitioner.local_start()],
index));
}
}
m_comm.barrier();

ASSERT_RELEASE(samples.size() == samples_per_pivot * (m_comm.size() - 1));
std::sort(samples.begin(), samples.end());
for (size_t i = samples_per_pivot - 1; i < samples.size();
i += samples_per_pivot) {
pivots.push_back(samples[i]);
}
samples.clear();
samples.shrink_to_fit();

ASSERT_RELEASE(pivots.size() == m_comm.size() - 1);

//
// Partition using pivots
for (size_t i = 0; i < m_local_vec.size(); ++i) {
auto itr = std::lower_bound(
pivots.begin(), pivots.end(),
std::make_pair(m_local_vec[i], partitioner.local_start() + i));
size_t owner = std::distance(pivots.begin(), itr);

m_comm.async(
owner, [](const mapped_type& val) { s_to_sort.push_back(val); },
m_local_vec[i]);
}
m_comm.barrier();

if (not to_sort.empty()) {
std::sort(to_sort.begin(), to_sort.end());
}

size_t my_prefix = ygm::prefix_sum(to_sort.size(), m_comm);
for (key_type i = 0; i < to_sort.size(); ++i) {
async_insert(my_prefix + i, to_sort[i]);
}

m_comm.barrier();
}

detail::block_partitioner<key_type> partitioner;

private:
Expand Down
2 changes: 2 additions & 0 deletions include/ygm/container/detail/block_partitioner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ struct block_partitioner {

index_type local_size() { return m_local_size; }

index_type local_start() { return m_local_start_index; }

private:
int m_comm_size;
int m_comm_rank;
Expand Down
28 changes: 28 additions & 0 deletions test/test_array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -503,5 +503,33 @@ int main(int argc, char **argv) {
});
}

// Test sort
{
int num_values = 100;
ygm::container::array<int> arr(world, num_values);

if (world.rank0()) {
std::vector<int> values;
for (int i = 0; i < num_values; ++i) {
values.push_back(i);
}
std::random_device rd;
std::shuffle(values.begin(), values.end(), rd);

int index{0};
for (const auto v : values) {
arr.async_insert(index++, v);
}
}

world.barrier();

arr.sort();

arr.for_all([](const auto index, const auto &value) {
ASSERT_RELEASE(index == value);
});
}

return 0;
}

0 comments on commit a10a0a5

Please sign in to comment.