// chacha_gpu_kernel.cu
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: © 2021 Aalto University
// source: https://github.com/DPBayes/jax-chacha-prng/blob/master/lib/gpu_kernel.cpp.cu
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <stdexcept>
#include <string>
#include "chacha_gpu_kernel.hpp"
/**
 * Rotates a 32-bit word to the left.
 *
 * The rotation count is reduced modulo 32, so counts of 0 and 32 both return
 * `value` unchanged instead of shifting by the full word width (which is
 * undefined behaviour for a 32-bit operand).
 *
 * @param value     Word to rotate.
 * @param num_bits  Number of bit positions to rotate by.
 * @return `value` rotated left by `num_bits mod 32` bits.
 */
__device__ __inline__
uint32_t rotate_left(uint32_t value, uint num_bits)
{
    const uint left_shift = num_bits & 31u;
    // Masking the complementary shift keeps it in [0, 31] when left_shift is 0.
    const uint right_shift = (32u - left_shift) & 31u;
    return (value << left_shift) | (value >> right_shift);
}
/**
 * Four 32-bit words that can be addressed either by name (`comp.a` .. `comp.d`)
 * or by position (`arr[0]` .. `arr[3]`, or `operator[]` in device code).
 * Both views overlay the same 16 bytes of storage.
 */
union uint32_vec4
{
    // Named view of the four words.
    struct
    {
        uint32_t a;
        uint32_t b;
        uint32_t c;
        uint32_t d;
    } comp;

    // Positional view of the same four words.
    uint32_t arr[4];

    // Device-side positional access; no bounds checking is performed.
    __device__ uint32_t& operator[](int index)
    {
        return arr[index];
    }
};
/**
 * Applies one add-xor-rotate quarter round (rotation amounts 16, 12, 8, 7)
 * to the four words held in `vec`.
 *
 * The function itself performs no inter-thread communication; it only mixes
 * the four words passed in.
 *
 * @param vec  The four input words (a, b, c, d).
 * @return The four mixed words, in the same (a, b, c, d) order.
 */
__device__
uint32_vec4 quarterround_with_shuffle(uint32_vec4 vec)
{
    uint32_t w0 = vec.comp.a;
    uint32_t w1 = vec.comp.b;
    uint32_t w2 = vec.comp.c;
    uint32_t w3 = vec.comp.d;

    // Each step: add into one word, xor the sum into another, rotate that one.
    w0 += w1;
    w3 = rotate_left(w3 ^ w0, 16);

    w2 += w3;
    w1 = rotate_left(w1 ^ w2, 12);

    w0 += w1;
    w3 = rotate_left(w3 ^ w0, 8);

    w2 += w3;
    w1 = rotate_left(w1 ^ w2, 7);

    return uint32_vec4{ w0, w1, w2, w3 };
}
/**
 * Performs one double round: a quarter round on the words this thread holds,
 * a warp shuffle that fetches word i from the thread i positions further along
 * in the same ThreadsPerState-wide group, a second quarter round on those
 * words, and a reverse shuffle that restores the original arrangement.
 *
 * Word 0 never moves; words 1 .. WordsPerThread-1 are exchanged between the
 * threads of one group via __shfl_sync with a full-warp mask.
 *
 * @param state  The words currently held by the calling thread.
 * @return The words held by the calling thread after the double round.
 */
__device__
uint32_vec4 double_round_with_shuffle(uint32_vec4 state)
{
    const uint all_lanes = (uint)-1;
    const int lane_in_group = threadIdx.x % ThreadsPerState;

    // First half: quarter round on the column held by this thread.
    state = quarterround_with_shuffle(state);

    // Pull word `word` from the lane `word` positions ahead, so that this
    // thread ends up holding a diagonal.
    for (int word = 1; word < WordsPerThread; ++word)
    {
        const int source_lane = lane_in_group + word;
        state[word] = __shfl_sync(all_lanes, state[word], source_lane, ThreadsPerState);
    }

    // Second half: quarter round on the diagonal.
    state = quarterround_with_shuffle(state);

    // Undo the rotation: pull word `word` from the lane `word` positions behind,
    // returning every word to its column.
    for (int word = 1; word < WordsPerThread; ++word)
    {
        const int source_lane = lane_in_group - word;
        state[word] = __shfl_sync(all_lanes, state[word], source_lane, ThreadsPerState);
    }

    return state;
}
/**
 * Kernel: runs ChaChaDoubleRoundCount double rounds on every state in
 * `in_state` and writes `input + permuted` (word-wise) to `out_state`.
 *
 * Layout: one-dimensional grid and blocks. Each group of ThreadsPerState
 * consecutive threads handles one state; every thread in a group loads
 * WordsPerThread words of that state (word i of thread t is read from
 * offset i*WordsPerThread + t within the state).
 *
 * Threads with a global index >= num_threads return before any warp shuffle.
 * NOTE(review): this assumes num_threads and blockDim.x are multiples of
 * ThreadsPerState so that a state group is never split — the visible host
 * launcher derives num_threads as num_states * ThreadsPerState.
 *
 * @param out_state    Device buffer receiving ChaChaStateSizeInWords words per state.
 * @param in_state     Device buffer holding ChaChaStateSizeInWords words per state.
 * @param num_threads  Total number of threads that have work to do.
 */
__global__
void chacha20_block_with_shuffle(uint32_t* out_state, const uint32_t* in_state, uint num_threads)
{
    // Each block consists of TargetThreadsPerBlock threads and each group of ThreadsPerState threads
    // handles a single state (for a total of StatesPerBlock states in a block).
    // We index into the state buffer by block id and thread group count:
    const uint thread_id = blockIdx.x * blockDim.x + threadIdx.x;
    if (thread_id >= num_threads) return;

    const uint state_thread_id = threadIdx.x % ThreadsPerState;
    const uint block_state_index = threadIdx.x / ThreadsPerState;
    const uint global_state_index = blockIdx.x * StatesPerBlock + block_state_index;

    // Widen before multiplying: the word offset exceeds 32 bits for large
    // state counts and would otherwise wrap around silently.
    const std::size_t global_buffer_offset =
        static_cast<std::size_t>(global_state_index) * ChaChaStateSizeInWords;

    // Load the words this thread is responsible for.
    uint32_vec4 in_state_vec;
    for (uint i = 0; i < WordsPerThread; ++i)
    {
        in_state_vec[i] = in_state[global_buffer_offset + i*WordsPerThread + state_thread_id];
    }

    // Permute a working copy; the untouched input is needed for the final addition.
    uint32_vec4 state_vec = in_state_vec;
    for (uint i = 0; i < ChaChaDoubleRoundCount; ++i)
    {
        state_vec = double_round_with_shuffle(state_vec);
    }

    // Output is the word-wise sum of the original and the permuted state.
    for (uint i = 0; i < WordsPerThread; ++i)
    {
        out_state[global_buffer_offset + i*WordsPerThread + state_thread_id] = in_state_vec[i] + state_vec[i];
    }
}
/**
 * Host entry point: enqueues the ChaCha20 block kernel on `stream`.
 *
 * @param stream         CUDA stream on which the kernel is launched (asynchronously).
 * @param buffers        buffers[0] is read as the input state buffer and buffers[1]
 *                       as the output state buffer; both are passed to the kernel
 *                       as device pointers.
 * @param opaque         Either unused (opaque_length == 0, one state is processed)
 *                       or a pointer to a 32-bit unsigned state count.
 * @param opaque_length  Size of `opaque` in bytes; must be 0 or sizeof(uint32_t).
 *
 * @throws std::runtime_error if `opaque_length` is neither 0 nor sizeof(uint32_t),
 *         if the requested number of states does not fit a single launch, or if
 *         the kernel launch itself is rejected by the CUDA runtime.
 *
 * The function does not synchronise the stream; errors that occur while the
 * kernel executes surface at a later synchronising CUDA call.
 */
void gpu_chacha20_block(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_length)
{
    uint32_t num_states = 1;
    if (opaque_length > 0)
    {
        if (opaque_length != sizeof(uint32_t))
        {
            throw std::runtime_error(
                "gpu_chacha20_block requires the opaque argument to be either null or a pointer to a 32-bit integer "
                "indicating the number of states on which to operate."
            );
        }
        // opaque is a raw byte buffer with no alignment guarantee; copy instead of
        // dereferencing it as a uint32_t.
        std::memcpy(&num_states, opaque, sizeof(num_states));
    }

    // Nothing to do; a launch with zero blocks/threads is an invalid configuration.
    if (num_states == 0) return;

    // num_states * ThreadsPerState must be representable in the kernel's 32-bit thread count.
    if (num_states > UINT32_MAX / ThreadsPerState)
    {
        throw std::runtime_error(
            "gpu_chacha20_block: the requested number of states is too large for a single kernel launch."
        );
    }

    const uint32_t* in_states = reinterpret_cast<const uint32_t*>(buffers[0]);
    uint32_t* out_state = reinterpret_cast<uint32_t*>(buffers[1]);

    uint num_threads = (num_states * ThreadsPerState);
    // = ceil(num_threads / TargetThreadsPerBlock), written so the intermediate sum cannot overflow
    uint num_blocks = num_threads / TargetThreadsPerBlock
        + ((num_threads % TargetThreadsPerBlock) != 0 ? 1 : 0);
    uint threads_per_block = std::min(num_threads, TargetThreadsPerBlock);

    chacha20_block_with_shuffle<<<num_blocks, threads_per_block, 0, stream>>>(out_state, in_states, num_threads);

    // Kernel launches do not return a status; configuration errors are only
    // visible through cudaGetLastError().
    const cudaError_t launch_status = cudaGetLastError();
    if (launch_status != cudaSuccess)
    {
        throw std::runtime_error(
            std::string("gpu_chacha20_block: kernel launch failed: ") + cudaGetErrorString(launch_status)
        );
    }
}