Experiment with simd_masked_load to read beyond without undefined behavior
ogxd committed Nov 6, 2024
1 parent 9eb19b0 commit 7c5afa5
Showing 4 changed files with 38 additions and 27 deletions.
18 changes: 11 additions & 7 deletions src/gxhash/platform/arm.rs
@@ -27,13 +27,17 @@ pub unsafe fn load_unaligned(p: *const State) -> State {
 
 #[inline(always)]
 pub unsafe fn get_partial_safe(data: *const State, len: usize) -> State {
-    // Temporary buffer filled with zeros
-    let mut buffer = [0i8; VECTOR_SIZE];
-    // Copy data into the buffer
-    core::ptr::copy(data as *const i8, buffer.as_mut_ptr(), len);
-    // Load the buffer into a __m256i vector
-    let partial_vector = vld1q_s8(buffer.as_ptr());
-    vaddq_s8(partial_vector, vdupq_n_s8(len as i8))
+    // // Temporary buffer filled with zeros
+    // let mut buffer = [0i8; VECTOR_SIZE];
+    // // Copy data into the buffer
+    // core::ptr::copy(data as *const i8, buffer.as_mut_ptr(), len);
+    // // Load the buffer into a __m256i vector
+    // let partial_vector = vld1q_s8(buffer.as_ptr());
+    // vaddq_s8(partial_vector, vdupq_n_s8(len as i8))
+
+    let indices = vld1q_s8([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15].as_ptr());
+    let mask = vreinterpretq_s8_u8(vcgtq_s8(vdupq_n_s8(len as i8), indices));
+    std::intrinsics::simd::simd_masked_load(mask, data as *const i8, vdupq_n_s8(len as i8))
 }
 
 #[inline(always)]
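For intuition, here is a scalar sketch of what the new NEON path computes, lane by lane (the `masked_load_model` name is hypothetical; the real code uses the `simd_masked_load` intrinsic on 16-byte vectors): lanes whose index is below `len` read a byte from `data`, and every remaining lane takes the fallback value, here a splat of `len`. Unlike a plain 16-byte load, the masked-out lanes are never dereferenced, which is what avoids the undefined behavior when reading near the end of an allocation.

```rust
/// Hypothetical scalar model of `simd_masked_load(mask, data, splat(len))`
/// as used above: lane i takes `data[i]` when `i < len`, else `len`.
fn masked_load_model(data: &[i8], len: usize) -> [i8; 16] {
    let mut out = [len as i8; 16]; // fallback lanes hold the splat of `len`
    for i in 0..len.min(16) {
        out[i] = data[i]; // masked-in lanes come from memory
    }
    out
}

fn main() {
    let bytes = [7i8, 8, 9]; // only 3 readable bytes
    assert_eq!(
        masked_load_model(&bytes, 3),
        [7, 8, 9, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
    );
}
```

One observable difference from the replaced buffer-copy path: the fallback only affects the tail lanes, so the loaded bytes themselves no longer have `len` added to them as they did under `vaddq_s8`.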
30 changes: 16 additions & 14 deletions src/gxhash/platform/mod.rs
@@ -12,27 +12,29 @@ use core::mem::size_of;
 
 pub(crate) const VECTOR_SIZE: usize = size_of::<State>();
 // 4KiB is the default page size for most systems, and conservative for other systems such as macOS ARM (16KiB)
-const PAGE_SIZE: usize = 0x1000;
+// const PAGE_SIZE: usize = 0x1000;
 
 #[inline(always)]
 pub unsafe fn get_partial(p: *const State, len: usize) -> State {
     // Safety check
-    if check_same_page(p) {
-        get_partial_unsafe(p, len)
-    } else {
-        get_partial_safe(p, len)
-    }
+    // if check_same_page(p) {
+    //     get_partial_unsafe(p, len)
+    // } else {
+    //     get_partial_safe(p, len)
+    // }
+
+    get_partial_safe(p, len)
 }
 
-#[inline(always)]
-unsafe fn check_same_page(ptr: *const State) -> bool {
-    let address = ptr as usize;
-    // Mask to keep only the last 12 bits
-    let offset_within_page = address & (PAGE_SIZE - 1);
-    // Check if the 16th byte from the current offset exceeds the page boundary
-    offset_within_page < PAGE_SIZE - VECTOR_SIZE
-}
+// #[inline(always)]
+// unsafe fn check_same_page(ptr: *const State) -> bool {
+//     let address = ptr as usize;
+//     // Mask to keep only the last 12 bits
+//     let offset_within_page = address & (PAGE_SIZE - 1);
+//     // Check if the 16th byte from the current offset exceeds the page boundary
+//     offset_within_page < PAGE_SIZE - VECTOR_SIZE
+// }
 
 #[inline(always)]
 pub unsafe fn finalize(hash: State) -> State {
     let mut hash = aes_encrypt(hash, ld(KEYS.as_ptr()));
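The `check_same_page` branch disabled here relied on memory protection having page granularity: an unaligned 16-byte load can only fault if it crosses into an unmapped page, so staying within one page made the over-read benign in practice, even though it remains formally undefined behavior in the Rust abstract machine (which is what the masked load addresses). A standalone sketch of that arithmetic, using the constants from the commented-out code:

```rust
const PAGE_SIZE: usize = 0x1000; // 4 KiB, a conservative lower bound
const VECTOR_SIZE: usize = 16;   // size of one 128-bit vector

/// Mirrors the commented-out `check_same_page`: true when a full
/// vector load starting at `address` cannot cross a page boundary.
fn same_page(address: usize) -> bool {
    let offset_within_page = address & (PAGE_SIZE - 1);
    offset_within_page < PAGE_SIZE - VECTOR_SIZE
}

fn main() {
    assert!(same_page(0x7000_0000));  // offset 0x000: room for 16 bytes
    assert!(same_page(0x7000_0FEF));  // offset 0xFEF: still fits
    assert!(!same_page(0x7000_0FF8)); // offset 0xFF8: load would cross pages
}
```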
16 changes: 10 additions & 6 deletions src/gxhash/platform/x86.rs
@@ -31,12 +31,16 @@ pub unsafe fn load_unaligned(p: *const State) -> State {
 #[inline(always)]
 pub unsafe fn get_partial_safe(data: *const State, len: usize) -> State {
     // Temporary buffer filled with zeros
-    let mut buffer = [0i8; VECTOR_SIZE];
-    // Copy data into the buffer
-    core::ptr::copy(data as *const i8, buffer.as_mut_ptr(), len);
-    // Load the buffer into a __m256i vector
-    let partial_vector = _mm_loadu_si128(buffer.as_ptr() as *const State);
-    _mm_add_epi8(partial_vector, _mm_set1_epi8(len as i8))
+    // let mut buffer = [0i8; VECTOR_SIZE];
+    // // Copy data into the buffer
+    // core::ptr::copy(data as *const i8, buffer.as_mut_ptr(), len);
+    // // Load the buffer into a __m256i vector
+    // let partial_vector = _mm_loadu_si128(buffer.as_ptr() as *const State);
+    // _mm_add_epi8(partial_vector, _mm_set1_epi8(len as i8))
+
+    let indices = _mm_set_epi8(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
+    let mask = _mm_cmpgt_epi8(_mm_set1_epi8(len as i8), indices);
+    std::intrinsics::simd::simd_masked_load(mask, data as *const i8, _mm_set1_epi8(len as i8))
 }
 
 #[inline(always)]
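One subtlety in the x86 version: `_mm_set_epi8` takes its sixteen arguments from the most significant byte down to the least, so the literal `(15, 14, …, 1, 0)` places value `i` in lane `i`, and `_mm_cmpgt_epi8` then produces an all-ones byte exactly in the lanes where `len > i`. A scalar model of that mask (the function name is illustrative only):

```rust
/// Scalar model of `_mm_cmpgt_epi8(_mm_set1_epi8(len), indices)` with
/// indices 0..=15: lane i is 0xFF when `len > i`, 0x00 otherwise.
fn cmpgt_mask_model(len: i8) -> [u8; 16] {
    core::array::from_fn(|i| if len > i as i8 { 0xFF } else { 0x00 })
}

fn main() {
    // With len = 4, the first four lanes are selected for the masked load.
    assert_eq!(
        cmpgt_mask_model(4),
        [0xFF, 0xFF, 0xFF, 0xFF, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    );
}
```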
1 change: 1 addition & 0 deletions src/lib.rs
@@ -1,3 +1,4 @@
+#![feature(core_intrinsics)]
 #![cfg_attr(not(feature = "std"), no_std)]
 // Hybrid SIMD width usage currently requires unstable 'stdsimd'
 #![cfg_attr(feature = "hybrid", feature(stdarch_x86_avx512))]
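Note that `core_intrinsics`, which gates `std::intrinsics::simd::simd_masked_load`, is a permanently unstable feature, so this experiment builds only on a nightly toolchain.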
