Skip to content

Commit

Permalink
fix: Value error when dump table with none TILE_SIZE integer multiple…
Browse files Browse the repository at this point in the history
…s of offset and search length
  • Loading branch information
Lifann committed Jan 3, 2025
1 parent fa590c7 commit 3f04c9e
Showing 1 changed file with 26 additions and 23 deletions.
49 changes: 26 additions & 23 deletions include/merlin_hashtable.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2617,56 +2617,59 @@ class HashTable : public HashTableBase<K, V, S> {
return;
}

bool match_fast_cond = options_.max_bucket_size % TILE_SIZE == 0 &&
options_.max_bucket_size >= TILE_SIZE &&
bool basic_fast_cond = options_.max_bucket_size % 32 == 0 &&
options_.max_bucket_size >= 32 &&
offset % TILE_SIZE == 0 && n % TILE_SIZE == 0;
bool use_fast_mode = false;

if (match_fast_cond) {
if (basic_fast_cond) {
int grid_size = std::min(
sm_cnt_ * max_threads_per_block_ / options_.block_size,
static_cast<int>(SAFE_GET_GRID_SIZE(n, options_.block_size)));
if (sizeof(V) == sizeof(float) && dim() >= 32 && dim() % 4 == 0) {
if (dim() >= 128) {
const int TILE_SIZE = 32;
dump_kernel_v2_vectorized<key_type, value_type, score_type, PredFunctor, TILE_SIZE>
if (dim() >= 128 && offset % 32 == 0 && n % 32 == 0) {
dump_kernel_v2_vectorized<key_type, value_type, score_type, PredFunctor, 32>
<<<grid_size, options_.block_size, 0, stream>>>(
d_table_, table_->buckets, pattern, threshold, keys, values,
scores, offset, n, d_counter);
} else if (dim() >= 64) {
const int TILE_SIZE = 16;
dump_kernel_v2_vectorized<key_type, value_type, score_type, PredFunctor, TILE_SIZE>
use_fast_mode = true;
} else if (dim() >= 64 && offset % 16 == 0 && n % 16 == 0) {
dump_kernel_v2_vectorized<key_type, value_type, score_type, PredFunctor, 16>
<<<grid_size, options_.block_size, 0, stream>>>(
d_table_, table_->buckets, pattern, threshold, keys, values,
scores, offset, n, d_counter);
} else {
const int TILE_SIZE = 8;
dump_kernel_v2_vectorized<key_type, value_type, score_type, PredFunctor, TILE_SIZE>
use_fast_mode = true;
} else if (dim() >= 8 && offset % 8 == 0 && n % 8 == 0) {
dump_kernel_v2_vectorized<key_type, value_type, score_type, PredFunctor, 8>
<<<grid_size, options_.block_size, 0, stream>>>(
d_table_, table_->buckets, pattern, threshold, keys, values,
scores, offset, n, d_counter);
use_fast_mode = true;
}
} else {
if (dim() >= 32) {
const int TILE_SIZE = 32;
dump_kernel_v2<key_type, value_type, score_type, PredFunctor, TILE_SIZE>
}
if (!use_fast_mode) {
if (dim() >= 32 && offset % 32 == 0 && n % 32 == 0) {
dump_kernel_v2<key_type, value_type, score_type, PredFunctor, 32>
<<<grid_size, options_.block_size, 0, stream>>>(
d_table_, table_->buckets, pattern, threshold, keys, values,
scores, offset, n, d_counter);
} else if (dim() >= 16) {
const int TILE_SIZE = 16;
dump_kernel_v2<key_type, value_type, score_type, PredFunctor, TILE_SIZE>
use_fast_mode = true;
} else if (dim() >= 16 && offset % 16 == 0 && n % 16 == 0) {
dump_kernel_v2<key_type, value_type, score_type, PredFunctor, 16>
<<<grid_size, options_.block_size, 0, stream>>>(
d_table_, table_->buckets, pattern, threshold, keys, values,
scores, offset, n, d_counter);
} else {
const int TILE_SIZE = 8;
dump_kernel_v2<key_type, value_type, score_type, PredFunctor, TILE_SIZE>
use_fast_mode = true;
} else if (dim() >= 8 && offset % 8 == 0 && n % 8 == 0) {
dump_kernel_v2<key_type, value_type, score_type, PredFunctor, 8>
<<<grid_size, options_.block_size, 0, stream>>>(
d_table_, table_->buckets, pattern, threshold, keys, values,
scores, offset, n, d_counter);
use_fast_mode = true;
}
}
} else {
}
if (!use_fast_mode) {
const size_t score_size = scores ? sizeof(score_type) : 0;
const size_t kvm_size =
sizeof(key_type) + sizeof(value_type) * dim() + score_size;
Expand Down

0 comments on commit 3f04c9e

Please sign in to comment.