Skip to content

Commit

Permalink
Move constants to microkernels for sqrt/rsqrt
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 662021516
  • Loading branch information
dsharletg authored and xnnpack-bot committed Aug 12, 2024
1 parent 825fdd6 commit 7256bcc
Show file tree
Hide file tree
Showing 54 changed files with 342 additions and 520 deletions.
24 changes: 12 additions & 12 deletions bench/f32-vrsqrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,70 +90,70 @@ BENCHMARK_CAPTURE(f32_vrsqrt, scalar_rsqrt_u4,
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f32_vrsqrt, sse_rsqrt_u4,
xnn_f32_vrsqrt_ukernel__sse_rsqrt_u4,
xnn_init_f32_rsqrt_sse_params)
/*init_params=*/nullptr)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vrsqrt, sse_rsqrt_u8,
xnn_f32_vrsqrt_ukernel__sse_rsqrt_u8,
xnn_init_f32_rsqrt_sse_params)
/*init_params=*/nullptr)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vrsqrt, sse_rsqrt_u16,
xnn_f32_vrsqrt_ukernel__sse_rsqrt_u16,
xnn_init_f32_rsqrt_sse_params)
/*init_params=*/nullptr)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vrsqrt, avx_rsqrt_u8,
xnn_f32_vrsqrt_ukernel__avx_rsqrt_u8,
xnn_init_f32_rsqrt_avx_params,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vrsqrt, avx_rsqrt_u16,
xnn_f32_vrsqrt_ukernel__avx_rsqrt_u16,
xnn_init_f32_rsqrt_avx_params,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vrsqrt, avx_rsqrt_u32,
xnn_f32_vrsqrt_ukernel__avx_rsqrt_u32,
xnn_init_f32_rsqrt_avx_params,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vrsqrt, fma3_rsqrt_u8,
xnn_f32_vrsqrt_ukernel__fma3_rsqrt_u8,
xnn_init_f32_rsqrt_fma3_params,
/*init_params=*/nullptr,
benchmark::utils::CheckFMA3)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vrsqrt, fma3_rsqrt_u16,
xnn_f32_vrsqrt_ukernel__fma3_rsqrt_u16,
xnn_init_f32_rsqrt_fma3_params,
/*init_params=*/nullptr,
benchmark::utils::CheckFMA3)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vrsqrt, fma3_rsqrt_u32,
xnn_f32_vrsqrt_ukernel__fma3_rsqrt_u32,
xnn_init_f32_rsqrt_fma3_params,
/*init_params=*/nullptr,
benchmark::utils::CheckFMA3)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vrsqrt, avx512f_rsqrt_u16,
xnn_f32_vrsqrt_ukernel__avx512f_rsqrt_u16,
xnn_init_f32_rsqrt_avx512_params,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX512F)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vrsqrt, avx512f_rsqrt_u32,
xnn_f32_vrsqrt_ukernel__avx512f_rsqrt_u32,
xnn_init_f32_rsqrt_avx512_params,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX512F)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vrsqrt, avx512f_rsqrt_u64,
xnn_f32_vrsqrt_ukernel__avx512f_rsqrt_u64,
xnn_init_f32_rsqrt_avx512_params,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX512F)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
Expand Down
30 changes: 15 additions & 15 deletions bench/f32-vsqrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,88 +96,88 @@ void f32_vsqrt(benchmark::State& state, xnn_f32_vsqrt_ukernel_fn ukernel,
->UseRealTime();
BENCHMARK_CAPTURE(f32_vsqrt, sse_rsqrt_u4,
xnn_f32_vsqrt_ukernel__sse_rsqrt_u4,
xnn_init_f32_sqrt_sse_params)
/*init_params=*/nullptr)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vsqrt, sse_rsqrt_u8,
xnn_f32_vsqrt_ukernel__sse_rsqrt_u8,
xnn_init_f32_sqrt_sse_params)
/*init_params=*/nullptr)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vsqrt, sse_rsqrt_u12,
xnn_f32_vsqrt_ukernel__sse_rsqrt_u12,
xnn_init_f32_sqrt_sse_params)
/*init_params=*/nullptr)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vsqrt, avx_sqrt_u8,
xnn_f32_vsqrt_ukernel__avx_sqrt_u8,
xnn_init_f32_sqrt_avx_params,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vsqrt, avx_sqrt_u16,
xnn_f32_vsqrt_ukernel__avx_sqrt_u16,
xnn_init_f32_sqrt_avx_params,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vsqrt, avx_sqrt_u32,
xnn_f32_vsqrt_ukernel__avx_sqrt_u32,
xnn_init_f32_sqrt_avx_params,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vsqrt, avx_rsqrt_u8,
xnn_f32_vsqrt_ukernel__avx_rsqrt_u8,
xnn_init_f32_sqrt_avx_params,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vsqrt, avx_rsqrt_u16,
xnn_f32_vsqrt_ukernel__avx_rsqrt_u16,
xnn_init_f32_sqrt_avx_params,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vsqrt, avx_rsqrt_u32,
xnn_f32_vsqrt_ukernel__avx_rsqrt_u32,
xnn_init_f32_sqrt_avx_params,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vsqrt, fma3_rsqrt_u8,
xnn_f32_vsqrt_ukernel__fma3_rsqrt_u8,
xnn_init_f32_sqrt_fma_params,
/*init_params=*/nullptr,
benchmark::utils::CheckFMA3)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vsqrt, fma3_rsqrt_u16,
xnn_f32_vsqrt_ukernel__fma3_rsqrt_u16,
xnn_init_f32_sqrt_fma_params,
/*init_params=*/nullptr,
benchmark::utils::CheckFMA3)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vsqrt, fma3_rsqrt_u32,
xnn_f32_vsqrt_ukernel__fma3_rsqrt_u32,
xnn_init_f32_sqrt_fma_params,
/*init_params=*/nullptr,
benchmark::utils::CheckFMA3)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vsqrt, avx512f_rsqrt_u16,
xnn_f32_vsqrt_ukernel__avx512f_rsqrt_u16,
xnn_init_f32_sqrt_avx512_params,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX512F)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vsqrt, avx512f_rsqrt_u32,
xnn_f32_vsqrt_ukernel__avx512f_rsqrt_u32,
xnn_init_f32_sqrt_avx512_params,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX512F)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
BENCHMARK_CAPTURE(f32_vsqrt, avx512f_rsqrt_u48,
xnn_f32_vsqrt_ukernel__avx512f_rsqrt_u48,
xnn_init_f32_sqrt_avx512_params,
/*init_params=*/nullptr,
benchmark::utils::CheckAVX512F)
->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
->UseRealTime();
Expand Down
16 changes: 10 additions & 6 deletions src/amalgam/gen/avx.c
Original file line number Diff line number Diff line change
Expand Up @@ -5881,14 +5881,16 @@ void xnn_f32_vrsqrt_ukernel__avx_rsqrt_u16(
float* output,
const union xnn_f32_rsqrt_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0};

assert(batch != 0);
assert(batch % sizeof(float) == 0);
assert(input != NULL);
assert(output != NULL);

// Constants for the Newton-Raphson iteration.
const __m256 vthree = _mm256_load_ps(params->avx.three);
const __m256 vhalf = _mm256_load_ps(params->avx.half);
const __m256 vthree = _mm256_set1_ps(3.0f);
const __m256 vhalf = _mm256_set1_ps(0.5f);

for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) {
const __m256 vx0 = _mm256_loadu_ps(input);
Expand Down Expand Up @@ -5936,7 +5938,7 @@ void xnn_f32_vrsqrt_ukernel__avx_rsqrt_u16(
if XNN_UNLIKELY(batch != 0) {
assert(batch >= 1 * sizeof(float));
assert(batch <= 7 * sizeof(float));
const __m256i vmask = _mm256_loadu_si256((const __m256i*)((uintptr_t) &params->avx.mask_table[7] - batch));
const __m256i vmask = _mm256_loadu_si256((const __m256i*)((uintptr_t) &mask_table[7] - batch));

const __m256 vx = _mm256_maskload_ps(input, vmask);

Expand Down Expand Up @@ -6224,14 +6226,16 @@ void xnn_f32_vsqrt_ukernel__avx_rsqrt_u16(
size_t batch, const float* input, float* output,
const union xnn_f32_sqrt_params params[restrict XNN_MIN_ELEMENTS(1)])
XNN_OOB_READS {
static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0};

assert(batch != 0);
assert(batch % sizeof(float) == 0);
assert(input != NULL);
assert(output != NULL);

// Constants for the Newton-Raphson iteration.
const __m256 kThree = _mm256_load_ps(params->avx.three);
const __m256 kHalf = _mm256_load_ps(params->avx.half);
const __m256 kThree = _mm256_set1_ps(3.0f);
const __m256 kHalf = _mm256_set1_ps(0.5f);

for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) {
const __m256 vx0 = _mm256_loadu_ps(input);
Expand Down Expand Up @@ -6293,7 +6297,7 @@ void xnn_f32_vsqrt_ukernel__avx_rsqrt_u16(
assert(batch >= 1 * sizeof(float));
assert(batch <= 7 * sizeof(float));
const __m256i vmask = _mm256_loadu_si256(
(const __m256i*)((uintptr_t)&params->avx.mask_table[7] - batch));
(const __m256i*)((uintptr_t)&mask_table[7] - batch));

const __m256 vx = _mm256_maskload_ps(input, vmask);

Expand Down
8 changes: 4 additions & 4 deletions src/amalgam/gen/avx512f.c
Original file line number Diff line number Diff line change
Expand Up @@ -3893,8 +3893,8 @@ void xnn_f32_vrsqrt_ukernel__avx512f_rsqrt_u32(
assert(output != NULL);

// Constants for the Newton-Raphson iteration.
const __m512 vthree = _mm512_set1_ps(params->avx512.three);
const __m512 vneg_half = _mm512_set1_ps(params->avx512.neg_half);
const __m512 vthree = _mm512_set1_ps(3.0f);
const __m512 vneg_half = _mm512_set1_ps(-0.5f);

for (; batch >= 32 * sizeof(float); batch -= 32 * sizeof(float)) {
const __m512 vx0 = _mm512_loadu_ps(input);
Expand Down Expand Up @@ -4126,8 +4126,8 @@ void xnn_f32_vsqrt_ukernel__avx512f_rsqrt_u16(
assert(output != NULL);

// Constants for the Newton-Raphson iteration.
const __m512 vneg_three = _mm512_set1_ps(params->avx512.neg_three);
const __m512 vneg_half = _mm512_set1_ps(params->avx512.neg_half);
const __m512 vneg_three = _mm512_set1_ps(-3.0f);
const __m512 vneg_half = _mm512_set1_ps(-0.5f);

for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) {
const __m512 vx = _mm512_loadu_ps(input);
Expand Down
16 changes: 10 additions & 6 deletions src/amalgam/gen/fma3.c
Original file line number Diff line number Diff line change
Expand Up @@ -5376,14 +5376,16 @@ void xnn_f32_vrsqrt_ukernel__fma3_rsqrt_u16(
float* output,
const union xnn_f32_rsqrt_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0};

assert(batch != 0);
assert(batch % sizeof(float) == 0);
assert(input != NULL);
assert(output != NULL);

// Constants for the Newton-Raphson iteration.
const __m256 vthree = _mm256_load_ps(params->fma3.three);
const __m256 vneg_half = _mm256_load_ps(params->fma3.neg_half);
const __m256 vthree = _mm256_set1_ps(3.0f);
const __m256 vneg_half = _mm256_set1_ps(-0.5f);

for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) {
const __m256 vx0 = _mm256_loadu_ps(input);
Expand Down Expand Up @@ -5428,7 +5430,7 @@ void xnn_f32_vrsqrt_ukernel__fma3_rsqrt_u16(
if XNN_UNLIKELY(batch != 0) {
assert(batch >= 1 * sizeof(float));
assert(batch <= 7 * sizeof(float));
const __m256i vmask = _mm256_loadu_si256((const __m256i*)((uintptr_t) &params->fma3.mask_table[7] - batch));
const __m256i vmask = _mm256_loadu_si256((const __m256i*)((uintptr_t) &mask_table[7] - batch));

const __m256 vx = _mm256_maskload_ps(input, vmask);

Expand Down Expand Up @@ -5462,14 +5464,16 @@ void xnn_f32_vsqrt_ukernel__fma3_rsqrt_u16(
size_t batch, const float* input, float* output,
const union xnn_f32_sqrt_params params[restrict XNN_MIN_ELEMENTS(1)])
XNN_OOB_READS {
static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0};

assert(batch != 0);
assert(batch % sizeof(float) == 0);
assert(input != NULL);
assert(output != NULL);

// Constants for the Newton-Raphson iteration.
const __m256 vthree = _mm256_load_ps(params->fma3.three);
const __m256 vneg_half = _mm256_load_ps(params->fma3.neg_half);
const __m256 vthree = _mm256_set1_ps(3.0f);
const __m256 vneg_half = _mm256_set1_ps(-0.5f);

for (; batch >= 16 * sizeof(float); batch -= 16 * sizeof(float)) {
const __m256 vx0 = _mm256_loadu_ps(input);
Expand Down Expand Up @@ -5528,7 +5532,7 @@ void xnn_f32_vsqrt_ukernel__fma3_rsqrt_u16(
assert(batch >= 1 * sizeof(float));
assert(batch <= 7 * sizeof(float));
const __m256i vmask = _mm256_loadu_si256(
(const __m256i*)((uintptr_t)&params->fma3.mask_table[7] - batch));
(const __m256i*)((uintptr_t)&mask_table[7] - batch));

const __m256 vx = _mm256_maskload_ps(input, vmask);

Expand Down
8 changes: 4 additions & 4 deletions src/amalgam/gen/sse.c
Original file line number Diff line number Diff line change
Expand Up @@ -9453,8 +9453,8 @@ void xnn_f32_vrsqrt_ukernel__sse_rsqrt_u8(
assert(output != NULL);

// Constants for the Newton-Raphson iteration.
const __m128 vthree = _mm_load_ps(params->sse.three);
const __m128 vhalf = _mm_load_ps(params->sse.half);
const __m128 vthree = _mm_set1_ps(3.0f);
const __m128 vhalf = _mm_set1_ps(0.5f);

for (; batch >= 8 * sizeof(float); batch -= 8 * sizeof(float)) {
const __m128 vx0123 = _mm_loadu_ps(input);
Expand Down Expand Up @@ -9535,8 +9535,8 @@ void xnn_f32_vsqrt_ukernel__sse_rsqrt_u12(
assert(output != NULL);

// Constants for the Newton-Raphson iteration.
const __m128 vthree = _mm_load_ps(params->sse.three);
const __m128 vhalf = _mm_load_ps(params->sse.half);
const __m128 vthree = _mm_set1_ps(3.0f);
const __m128 vhalf = _mm_set1_ps(0.5f);

for (; batch >= 12 * sizeof(float); batch -= 12 * sizeof(float)) {
const __m128 vx0 = _mm_loadu_ps(input);
Expand Down
8 changes: 0 additions & 8 deletions src/configs/unary-elementwise-config.c
Original file line number Diff line number Diff line change
Expand Up @@ -1406,19 +1406,15 @@ static void init_f32_sqrt_config(void) {
assert(hardware_config != NULL);
if (hardware_config->use_x86_avx512f) {
f32_sqrt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vsqrt_ukernel__avx512f_rsqrt_u16;
f32_sqrt_config.init.f32_sqrt = xnn_init_f32_sqrt_avx512_params;
f32_sqrt_config.element_tile = 16;
} else if (hardware_config->use_x86_fma3) {
f32_sqrt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vsqrt_ukernel__fma3_rsqrt_u16;
f32_sqrt_config.init.f32_sqrt = xnn_init_f32_sqrt_fma_params;
f32_sqrt_config.element_tile = 16;
} else if (hardware_config->use_x86_avx) {
f32_sqrt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vsqrt_ukernel__avx_rsqrt_u16;
f32_sqrt_config.init.f32_sqrt = xnn_init_f32_sqrt_avx_params;
f32_sqrt_config.element_tile = 16;
} else {
f32_sqrt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vsqrt_ukernel__sse_rsqrt_u12;
f32_sqrt_config.init.f32_sqrt = xnn_init_f32_sqrt_sse_params;
f32_sqrt_config.element_tile = 12;
}
#elif XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
Expand Down Expand Up @@ -1452,19 +1448,15 @@ static void init_f32_rsqrt_config(void) {
assert(hardware_config != NULL);
if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512f) {
f32_rsqrt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vrsqrt_ukernel__avx512f_rsqrt_u32;
f32_rsqrt_config.init.f32_rsqrt = xnn_init_f32_rsqrt_avx512_params;
f32_rsqrt_config.element_tile = 32;
} else if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_fma3) {
f32_rsqrt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vrsqrt_ukernel__fma3_rsqrt_u16;
f32_rsqrt_config.init.f32_rsqrt = xnn_init_f32_rsqrt_fma3_params;
f32_rsqrt_config.element_tile = 16;
} else if (hardware_config->use_x86_avx) {
f32_rsqrt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vrsqrt_ukernel__avx_rsqrt_u16;
f32_rsqrt_config.init.f32_rsqrt = xnn_init_f32_rsqrt_avx_params;
f32_rsqrt_config.element_tile = 16;
} else {
f32_rsqrt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_f32_vrsqrt_ukernel__sse_rsqrt_u8;
f32_rsqrt_config.init.f32_rsqrt = xnn_init_f32_rsqrt_sse_params;
f32_rsqrt_config.element_tile = 8;
}
#elif XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR
Expand Down
Loading

0 comments on commit 7256bcc

Please sign in to comment.