-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathp3a_scan.hpp
155 lines (142 loc) · 3.77 KB
/
p3a_scan.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#pragma once
#include <stdexcept>
#include <Kokkos_StdAlgorithms.hpp>
#include "p3a_execution.hpp"
namespace p3a {
namespace details {
/// Functor passed to Kokkos::parallel_scan to compute an exclusive prefix sum.
///
/// For index `i` it stores the running total accumulated so far into
/// `d_first[i]` (only when `is_final_pass` is true) and then adds `first[i]`
/// to the running total.
///
/// Iterator1 / Iterator2 must be indexable with `difference_type`;
/// T is the running-total type and must support `+=` with Iterator1's
/// value type and be assignable to the output element.
template <class Iterator1, class Iterator2, class T>
class kokkos_exclusive_scan_functor {
  Iterator1 m_first;
  Iterator2 m_d_first;
 public:
  /// \param first_arg   start of the input range
  /// \param d_first_arg start of the output range
  kokkos_exclusive_scan_functor(
      Iterator1 first_arg,
      Iterator2 d_first_arg)
    :m_first(first_arg)
    ,m_d_first(d_first_arg)
  {
  }
  using difference_type = typename std::iterator_traits<Iterator1>::difference_type;
  using input_value_type = typename std::iterator_traits<Iterator1>::value_type;
  /// \param i             index of the element being visited
  /// \param update        running total of all inputs before index `i`
  /// \param is_final_pass when true, the prefix for `i` is written to the output
  P3A_ALWAYS_INLINE P3A_HOST_DEVICE inline
  void operator()(difference_type const i, T& update, bool const is_final_pass) const
  {
    // Load the input element BEFORE the output is stored. If the output range
    // aliases the input range (in-place scan), the store below overwrites
    // first[i]; reading it afterwards would add the prefix instead of the
    // original element and corrupt every subsequent result.
    input_value_type const addend = m_first[i];
    if (is_final_pass) {
      m_d_first[i] = update;
    }
    update += addend;
  }
};
/// Exclusive prefix sum of [first, last) into the range starting at d_first,
/// dispatched through Kokkos::parallel_scan on ExecutionSpace.
///
/// Only a zero initial value is supported.
///
/// \param first   start of the input range
/// \param last    end of the input range
/// \param d_first start of the output range
/// \param init    initial value; must compare equal to T(0)
/// \throws std::runtime_error if `init != T(0)`
template <
  class ExecutionSpace,
  class Iterator1,
  class Iterator2,
  class T>
void kokkos_exclusive_scan(
    Iterator1 first,
    Iterator1 last,
    Iterator2 d_first,
    T init)
{
  using index_type = typename std::iterator_traits<Iterator1>::difference_type;
  using scan_body = kokkos_exclusive_scan_functor<Iterator1, Iterator2, T>;
  using range_policy =
    Kokkos::RangePolicy<ExecutionSpace, Kokkos::IndexType<index_type>>;

  // The scan functor starts from the scan's own initial total, so a
  // non-zero seed cannot be honoured here; reject it up front.
  if (init != T(0))
    throw std::runtime_error("p3a::details::kokkos_exclusive_scan only supports zero init");

  // One scan iteration per input element.
  auto const element_count = last - first;
  Kokkos::parallel_scan(
      "p3a::details::kokkos_exclusive_scan",
      range_policy(0, element_count),
      scan_body(first, d_first));
}
/// Functor passed to Kokkos::parallel_scan to implement copy_if.
///
/// The scanned value `update` counts how many elements before index `i`
/// satisfied the predicate; that count is the output slot for element `i`
/// when it is selected.
template <class InputIt, class OutputIt, class UnaryPredicate>
class kokkos_copy_if_functor {
  InputIt m_first;
  OutputIt m_d_first;
  UnaryPredicate m_pred;
 public:
  /// \param first_arg   start of the input range
  /// \param d_first_arg start of the output range
  /// \param pred_arg    predicate deciding which elements are copied
  kokkos_copy_if_functor(
      InputIt first_arg,
      OutputIt d_first_arg,
      UnaryPredicate pred_arg)
    :m_first(first_arg)
    ,m_d_first(d_first_arg)
    ,m_pred(pred_arg)
  {
  }
  using input_difference_type = typename std::iterator_traits<InputIt>::difference_type;
  using output_difference_type = typename std::iterator_traits<OutputIt>::difference_type;
  /// \param i             index of the input element being visited
  /// \param update        number of selected elements before index `i`
  /// \param is_final_pass when true, a selected element is written to the output
  P3A_ALWAYS_INLINE P3A_HOST_DEVICE inline
  void operator()(
      input_difference_type const i,
      output_difference_type& update,
      bool const is_final_pass) const
  {
    // Evaluate the predicate exactly once per call: it guards both the store
    // and the count increment, so there is no need to run it a second time
    // on the final pass.
    if (m_pred(m_first[i])) {
      if (is_final_pass) {
        // `update` still holds the count of earlier selected elements here,
        // i.e. the destination index for this element.
        m_d_first[update] = m_first[i];
      }
      update += 1;
    }
  }
};
/// Copies the elements of [first, last) that satisfy `pred` into the range
/// starting at d_first, dispatched through Kokkos::parallel_scan on
/// ExecutionSpace.
///
/// \param first   start of the input range
/// \param last    end of the input range
/// \param d_first start of the output range
/// \param pred    unary predicate selecting the elements to copy
/// \return `d_first` advanced by the value parallel_scan stores in its result
///         argument (per the Kokkos result overload, the scan total, which
///         for this functor is the number of selected elements)
template <
  class ExecutionSpace,
  class InputIt,
  class OutputIt,
  class UnaryPredicate>
OutputIt kokkos_copy_if(
    InputIt first,
    InputIt last,
    OutputIt d_first,
    UnaryPredicate pred)
{
  using source_index = typename std::iterator_traits<InputIt>::difference_type;
  using target_index = typename std::iterator_traits<OutputIt>::difference_type;
  using copy_body = kokkos_copy_if_functor<InputIt, OutputIt, UnaryPredicate>;
  using range_policy =
    Kokkos::RangePolicy<ExecutionSpace, Kokkos::IndexType<source_index>>;

  // One scan iteration per input element.
  auto const element_count = last - first;

  // Result slot handed to parallel_scan.
  target_index selected_total = 0;
  Kokkos::parallel_scan(
      "p3a::details::kokkos_copy_if",
      range_policy(0, element_count),
      copy_body(first, d_first, pred),
      selected_total);

  return d_first + selected_total;
}
}
template <
class ExecutionPolicy,
class Iterator1,
class Iterator2,
class T>
void exclusive_scan(
ExecutionPolicy policy,
Iterator1 first,
Iterator1 last,
Iterator2 d_first,
T init)
{
details::kokkos_exclusive_scan<typename ExecutionPolicy::kokkos_execution_space>(
first, last, d_first, init);
}
template <
class ExecutionPolicy,
class InputIt,
class OutputIt,
class UnaryPredicate>
OutputIt copy_if(
ExecutionPolicy policy,
InputIt first,
InputIt last,
OutputIt d_first,
UnaryPredicate pred)
{
return details::kokkos_copy_if<typename ExecutionPolicy::kokkos_execution_space>(
first, last, d_first, pred);
}
}