Skip to content

Commit

Permalink
Impl fast dump kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
Lifann authored and oppenheimli committed Aug 13, 2024
1 parent 4f38be5 commit 46c9f89
Showing 1 changed file with 75 additions and 0 deletions.
75 changes: 75 additions & 0 deletions include/merlin/core_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -910,5 +910,80 @@ __global__ void dump_kernel(const Table<K, V, S>* __restrict table,
}
}

/* Dump with score. */
template <class K, class V, class S,
template <typename, typename> class PredFunctor,
int TILE_SIZE>
__global__ void dump_kernel_v2(const Table<K, V, S>* __restrict table,
Bucket<K, V, S>* buckets, const K pattern,
const S threshold, K* d_key, V* __restrict d_val,
S* __restrict d_score, const size_t offset,
const size_t search_length,
size_t* d_dump_counter) {
const size_t bucket_max_size = table->bucket_max_size;
int dim = table->dim;
auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());

__shared__ block_acc;
if (threadIdx.x == 0) {
block_acc = 0;
}
__syncthreads();

size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
size_t N = n * TILE_SIZE;

for (size_t ii = tid; i < N; i += gridDim.x * blockDim.x) {
size_t i = ii / TILE_SIZE;

/*
(4 keys per bucket)
(TILE_SIZE: 2)
origin:
0 1 2 3 4 5 6 7, 8 9 10 11 12 13 14 15
A A A A B B B B C C C C D D D D
new:
0 1 2 3 4 5 6 7, 8 9 10 11 12 13 14 15
A A A A B B B B C C C C D D D D
*/

size_t bkt_idx = (ii + offset) / bucket_max_size;
int key_idx = (ii + offset) % bucket_max_size;
int leading_key_idx = key_idx % TILE_SIZE;
Bucket<K, V, S>* bucket = &(buckets[bkt_idx]);

const K key =
(bucket->keys(key_idx))->load(cuda::std::memory_order_relaxed);
S score = bucket->scores(key_idx)->load(cuda::std::memory_order_relaxed);
bool match = !IS_RESERVED_KEY<K>(key) && pred(key, score, pattern, threshold);
int vote = g.ballot(match);
int tile_cnt = __popc(vote);
int tile_offset = 0;
if (g.rank() == 0) {
tile_offset = atomicAdd(d_dump_counter, static_cast<size_t>(tile_cnt));
}
tile_offset = g.shfl(tile_offset, 0);

if (match) {
d_key[tile_offset + key_idx] = key;
if (d_score) {
d_score[tile_offset + key_idx] = score;
}
}

#pragma unroll
for (int r = 0; r < TILE_SIZE; r++) {
bool cur_match = vote >> r & 1;
if (match) {
int cur_idx = leading_key_idx + r;
for (int j = g.rank(); j < dim; j += TILE_SIZE) {
d_val[(tile_offset + cur_idx) * dim + j] = bucket->vector[cur_idx * dim + j];
}
}
}
}
}

} // namespace merlin
} // namespace nv

0 comments on commit 46c9f89

Please sign in to comment.