Skip to content

Commit

Permalink
Remove target specific versions of rsum, rdsum params
Browse files: browse the repository at this point in the history
PiperOrigin-RevId: 658913770
  • Loading branch information
dsharletg authored and xnnpack-bot committed Aug 12, 2024
1 parent 9e5b24e commit 33dcaec
Show file tree
Hide file tree
Showing 87 changed files with 470 additions and 499 deletions.
8 changes: 4 additions & 4 deletions bench/f16-f32acc-rdsum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f16_f32acc_rdsum, f16c_c16,
xnn_f16_f32acc_rdsum_ukernel_7p7x__f16c_c16,
xnn_init_f16_f32acc_scale_avx_params,
xnn_init_f16_f32acc_scale_scalar_params,
benchmark::utils::CheckF16C)
->Apply(BenchmarkRDSUM)
->UseRealTime();
Expand All @@ -62,7 +62,7 @@
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f16_f32acc_rdsum, f16c_c32,
xnn_f16_f32acc_rdsum_ukernel_7p7x__f16c_c32,
xnn_init_f16_f32acc_scale_avx_params,
xnn_init_f16_f32acc_scale_scalar_params,
benchmark::utils::CheckF16C)
->Apply(BenchmarkRDSUM)
->UseRealTime();
Expand All @@ -72,7 +72,7 @@
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f16_f32acc_rdsum, f16c_c64,
xnn_f16_f32acc_rdsum_ukernel_7p7x__f16c_c64,
xnn_init_f16_f32acc_scale_avx_params,
xnn_init_f16_f32acc_scale_scalar_params,
benchmark::utils::CheckF16C)
->Apply(BenchmarkRDSUM)
->UseRealTime();
Expand All @@ -82,7 +82,7 @@
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f16_f32acc_rdsum, f16c_c128,
xnn_f16_f32acc_rdsum_ukernel_7p7x__f16c_c128,
xnn_init_f16_f32acc_scale_avx_params,
xnn_init_f16_f32acc_scale_scalar_params,
benchmark::utils::CheckF16C)
->Apply(BenchmarkRDSUM)
->UseRealTime();
Expand Down
10 changes: 5 additions & 5 deletions bench/f16-f32acc-rsum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f16_f32acc_rsum, f16c_u8,
xnn_f16_f32acc_rsum_ukernel__f16c_u8,
xnn_init_f16_f32acc_scale_avx_params,
xnn_init_f16_f32acc_scale_scalar_params,
benchmark::utils::CheckF16C)
->Apply(BenchmarkRSUM)
->UseRealTime();
Expand All @@ -92,7 +92,7 @@
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f16_f32acc_rsum, f16c_u16_acc2,
xnn_f16_f32acc_rsum_ukernel__f16c_u16_acc2,
xnn_init_f16_f32acc_scale_avx_params,
xnn_init_f16_f32acc_scale_scalar_params,
benchmark::utils::CheckF16C)
->Apply(BenchmarkRSUM)
->UseRealTime();
Expand All @@ -102,7 +102,7 @@
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f16_f32acc_rsum, f16c_u24_acc3,
xnn_f16_f32acc_rsum_ukernel__f16c_u24_acc3,
xnn_init_f16_f32acc_scale_avx_params,
xnn_init_f16_f32acc_scale_scalar_params,
benchmark::utils::CheckF16C)
->Apply(BenchmarkRSUM)
->UseRealTime();
Expand All @@ -112,7 +112,7 @@
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f16_f32acc_rsum, f16c_u32_acc2,
xnn_f16_f32acc_rsum_ukernel__f16c_u32_acc2,
xnn_init_f16_f32acc_scale_avx_params,
xnn_init_f16_f32acc_scale_scalar_params,
benchmark::utils::CheckF16C)
->Apply(BenchmarkRSUM)
->UseRealTime();
Expand All @@ -122,7 +122,7 @@
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f16_f32acc_rsum, f16c_u32_acc4,
xnn_f16_f32acc_rsum_ukernel__f16c_u32_acc4,
xnn_init_f16_f32acc_scale_avx_params,
xnn_init_f16_f32acc_scale_scalar_params,
benchmark::utils::CheckF16C)
->Apply(BenchmarkRSUM)
->UseRealTime();
Expand Down
12 changes: 6 additions & 6 deletions bench/f32-rdsum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ BENCHMARK_CAPTURE(f32_rdsum, scalar_c4,
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f32_rdsum, sse_c16,
xnn_f32_rdsum_ukernel_7p7x__sse_c16,
xnn_init_f32_scaleminmax_sse_params)
xnn_init_f32_scaleminmax_scalar_params)
->Apply(BenchmarkRDSUM)
->UseRealTime();
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
Expand All @@ -67,7 +67,7 @@ BENCHMARK_CAPTURE(f32_rdsum, scalar_c4,
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f32_rdsum, sse_c32,
xnn_f32_rdsum_ukernel_7p7x__sse_c32,
xnn_init_f32_scaleminmax_sse_params)
xnn_init_f32_scaleminmax_scalar_params)
->Apply(BenchmarkRDSUM)
->UseRealTime();
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
Expand All @@ -76,7 +76,7 @@ BENCHMARK_CAPTURE(f32_rdsum, scalar_c4,
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f32_rdsum, sse_c64,
xnn_f32_rdsum_ukernel_7p7x__sse_c64,
xnn_init_f32_scaleminmax_sse_params)
xnn_init_f32_scaleminmax_scalar_params)
->Apply(BenchmarkRDSUM)
->UseRealTime();
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
Expand All @@ -85,7 +85,7 @@ BENCHMARK_CAPTURE(f32_rdsum, scalar_c4,
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f32_rdsum, avx_c16,
xnn_f32_rdsum_ukernel_7p7x__avx_c16,
xnn_init_f32_scaleminmax_avx_params,
xnn_init_f32_scaleminmax_scalar_params,
benchmark::utils::CheckAVX)
->Apply(BenchmarkRDSUM)
->UseRealTime();
Expand All @@ -95,7 +95,7 @@ BENCHMARK_CAPTURE(f32_rdsum, scalar_c4,
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f32_rdsum, avx_c32,
xnn_f32_rdsum_ukernel_7p7x__avx_c32,
xnn_init_f32_scaleminmax_avx_params,
xnn_init_f32_scaleminmax_scalar_params,
benchmark::utils::CheckAVX)
->Apply(BenchmarkRDSUM)
->UseRealTime();
Expand All @@ -105,7 +105,7 @@ BENCHMARK_CAPTURE(f32_rdsum, scalar_c4,
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f32_rdsum, avx_c64,
xnn_f32_rdsum_ukernel_7p7x__avx_c64,
xnn_init_f32_scaleminmax_avx_params,
xnn_init_f32_scaleminmax_scalar_params,
benchmark::utils::CheckAVX)
->Apply(BenchmarkRDSUM)
->UseRealTime();
Expand Down
10 changes: 5 additions & 5 deletions bench/f32-rsum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f32_rsum, avx_u8,
xnn_f32_rsum_ukernel__avx_u8,
xnn_init_f32_scaleminmax_avx_params,
xnn_init_f32_scaleminmax_scalar_params,
benchmark::utils::CheckAVX)
->Apply(BenchmarkRSUM)
->UseRealTime();
Expand All @@ -127,7 +127,7 @@
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f32_rsum, avx_u16_acc2,
xnn_f32_rsum_ukernel__avx_u16_acc2,
xnn_init_f32_scaleminmax_avx_params,
xnn_init_f32_scaleminmax_scalar_params,
benchmark::utils::CheckAVX)
->Apply(BenchmarkRSUM)
->UseRealTime();
Expand All @@ -137,7 +137,7 @@
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f32_rsum, avx_u24_acc3,
xnn_f32_rsum_ukernel__avx_u24_acc3,
xnn_init_f32_scaleminmax_avx_params,
xnn_init_f32_scaleminmax_scalar_params,
benchmark::utils::CheckAVX)
->Apply(BenchmarkRSUM)
->UseRealTime();
Expand All @@ -147,7 +147,7 @@
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f32_rsum, avx_u32_acc2,
xnn_f32_rsum_ukernel__avx_u32_acc2,
xnn_init_f32_scaleminmax_avx_params,
xnn_init_f32_scaleminmax_scalar_params,
benchmark::utils::CheckAVX)
->Apply(BenchmarkRSUM)
->UseRealTime();
Expand All @@ -157,7 +157,7 @@
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
BENCHMARK_CAPTURE(f32_rsum, avx_u32_acc4,
xnn_f32_rsum_ukernel__avx_u32_acc4,
xnn_init_f32_scaleminmax_avx_params,
xnn_init_f32_scaleminmax_scalar_params,
benchmark::utils::CheckAVX)
->Apply(BenchmarkRSUM)
->UseRealTime();
Expand Down
20 changes: 12 additions & 8 deletions src/amalgam/gen/avx.c
Original file line number Diff line number Diff line change
Expand Up @@ -3655,14 +3655,16 @@ void xnn_f32_rdsum_ukernel_7p7x__avx_c32(
float* output,
const union xnn_f32_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0};

assert(rows != 0);
assert(channels != 0);
assert(input != NULL);
assert(output != NULL);

const __m256 vscale = _mm256_set1_ps(params->avx.scale);
const __m256 vmin = _mm256_set1_ps(params->avx.min);
const __m256 vmax = _mm256_set1_ps(params->avx.max);
const __m256 vscale = _mm256_set1_ps(params->scalar.scale);
const __m256 vmin = _mm256_set1_ps(params->scalar.min);
const __m256 vmax = _mm256_set1_ps(params->scalar.max);

size_t input_increment = 7 * input_stride;
for (; channels >= 32; channels -= 32) {
Expand Down Expand Up @@ -3844,7 +3846,7 @@ void xnn_f32_rdsum_ukernel_7p7x__avx_c32(
}

if (remainder) {
vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &params->avx.mask_table[7] - (channels & 0x7) * sizeof(float)));
vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &mask_table[7] - (channels & 0x7) * sizeof(float)));
vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i0[num_full_chunks*8], vmask), vacc[num_full_chunks]);
vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i1[num_full_chunks*8], vmask), vacc[num_full_chunks]);
vacc[num_full_chunks] = _mm256_add_ps(_mm256_maskload_ps(&i2[num_full_chunks*8], vmask), vacc[num_full_chunks]);
Expand Down Expand Up @@ -4028,6 +4030,8 @@ void xnn_f32_rsum_ukernel__avx_u32_acc4(
float* output,
const union xnn_f32_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0};

assert(batch != 0);
assert(batch % sizeof(float) == 0);
assert(input != NULL);
Expand Down Expand Up @@ -4061,16 +4065,16 @@ void xnn_f32_rsum_ukernel__avx_u32_acc4(
if XNN_UNLIKELY(batch != 0) {
assert(batch >= 1 * sizeof(float));
assert(batch <= 7 * sizeof(float));
const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &params->avx.mask_table[7] - batch));
const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &mask_table[7] - batch));
const __m256 vt = _mm256_maskload_ps(input, vmask);
vacc0 = _mm256_add_ps(vacc0, vt);
}
__m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc0), _mm256_extractf128_ps(vacc0, 1));
vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc));
vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc));
vacc = _mm_mul_ss(vacc, _mm_load_ss(&params->avx.scale));
vacc = _mm_max_ss(vacc, _mm_load_ss(&params->avx.min));
vacc = _mm_min_ss(vacc, _mm_load_ss(&params->avx.max));
vacc = _mm_mul_ss(vacc, _mm_load_ss(&params->scalar.scale));
vacc = _mm_max_ss(vacc, _mm_load_ss(&params->scalar.min));
vacc = _mm_min_ss(vacc, _mm_load_ss(&params->scalar.max));
*output += _mm_cvtss_f32(vacc);
}

Expand Down
4 changes: 2 additions & 2 deletions src/amalgam/gen/avx512skx.c
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ void xnn_f16_f32acc_rdsum_ukernel_7p7x__avx512skx_c64(
assert(input != NULL);
assert(output != NULL);

const __m512 vscale = _mm512_set1_ps(params->scalar.scale);
const __m512 vscale = _mm512_set1_ps(params->scale);

size_t input_increment = 7 * input_stride;
for (; channels >= 64; channels -= 64) {
Expand Down Expand Up @@ -343,7 +343,7 @@ void xnn_f16_f32acc_rsum_ukernel__avx512skx_u64_acc4(
__m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc256), _mm256_extractf128_ps(vacc256, 1));
vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc));
vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc));
vacc = _mm_mul_ss(vacc, _mm_load_ss(&params->scalar.scale));
vacc = _mm_mul_ss(vacc, _mm_load_ss(&params->scale));

float vout = _mm_cvtss_f32(vacc);
*output += vout;
Expand Down
8 changes: 5 additions & 3 deletions src/amalgam/gen/f16c.c
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ void xnn_f16_f32acc_rdsum_ukernel_7p7x__f16c_c32(
assert(input != NULL);
assert(output != NULL);

const __m256 vscale = _mm256_set1_ps(params->avx.scale);
const __m256 vscale = _mm256_set1_ps(params->scale);

size_t input_increment = 7 * input_stride;
for (; channels >= 32; channels -= 32) {
Expand Down Expand Up @@ -837,6 +837,8 @@ void xnn_f16_f32acc_rsum_ukernel__f16c_u32_acc4(
float* output,
const union xnn_f16_f32acc_scale_params params[restrict XNN_MIN_ELEMENTS(1)])
{
static const int32_t mask_table[14] = {-1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0};

assert(batch != 0);
assert(batch % sizeof(uint16_t) == 0);
assert(input != NULL);
Expand Down Expand Up @@ -871,7 +873,7 @@ void xnn_f16_f32acc_rsum_ukernel__f16c_u32_acc4(
if XNN_UNLIKELY(batch != 0) {
assert(batch >= 1 * sizeof(uint16_t));
assert(batch <= 7 * sizeof(uint16_t));
const __m128i vmask = _mm_loadu_si128((const __m128i*) ((uintptr_t) &params->avx.mask_table[7] - batch));
const __m128i vmask = _mm_loadu_si128((const __m128i*) ((uintptr_t) &mask_table[7] - batch));
const __m128i vh = _mm_castps_si128(_mm_maskload_ps((const float*) i, vmask));
const __m256 vt = _mm256_cvtph_ps(vh);
vacc0 = _mm256_add_ps(vacc0, vt);
Expand All @@ -885,7 +887,7 @@ void xnn_f16_f32acc_rsum_ukernel__f16c_u32_acc4(
__m128 vacc = _mm_add_ps(_mm256_castps256_ps128(vacc0), _mm256_extractf128_ps(vacc0, 1));
vacc = _mm_add_ps(vacc, _mm_movehl_ps(vacc, vacc));
vacc = _mm_add_ss(vacc, _mm_movehdup_ps(vacc));
vacc = _mm_mul_ss(vacc, _mm_load_ss(&params->avx.scale));
vacc = _mm_mul_ss(vacc, _mm_load_ss(&params->scale));

float vout = _mm_cvtss_f32(vacc);
*output += vout;
Expand Down
4 changes: 2 additions & 2 deletions src/amalgam/gen/neonfp16arith.c
Original file line number Diff line number Diff line change
Expand Up @@ -4276,7 +4276,7 @@ void xnn_f16_f32acc_rdsum_ukernel_7p7x__neonfp16arith_c16(
assert(input != NULL);
assert(output != NULL);

const float32x4_t vscale = vdupq_n_f32(params->scalar.scale);
const float32x4_t vscale = vld1q_dup_f32(&params->scale);

size_t input_increment = 7 * input_stride;
for (; channels >= 16; channels -= 16) {
Expand Down Expand Up @@ -4532,7 +4532,7 @@ void xnn_f16_f32acc_rsum_ukernel__neonfp16arith_u32_acc4(
const float32x4_t vt = vcvt_f32_f16(vh);
vacc0 = vaddq_f32(vacc0, vt);
}
const float32x2_t vscale = vld1_dup_f32(&params->scalar.scale);
const float32x2_t vscale = vld1_dup_f32(&params->scale);
float32x2_t vacc = vadd_f32(vget_low_f32(vacc0), vget_high_f32(vacc0));
if XNN_UNLIKELY(batch & (2 * sizeof(uint16_t))) {
const float16x4_t vh = vreinterpret_f16_u32(vld1_dup_u32((const void*) i)); i += 2;
Expand Down
6 changes: 3 additions & 3 deletions src/amalgam/gen/sse.c
Original file line number Diff line number Diff line change
Expand Up @@ -7390,9 +7390,9 @@ void xnn_f32_rdsum_ukernel_7p7x__sse_c16(
assert(input != NULL);
assert(output != NULL);

const __m128 vscale = _mm_load_ps(params->sse.scale);
const __m128 vmin = _mm_load_ps(params->sse.min);
const __m128 vmax = _mm_load_ps(params->sse.max);
const __m128 vscale = _mm_set1_ps(params->scalar.scale);
const __m128 vmin = _mm_set1_ps(params->scalar.min);
const __m128 vmax = _mm_set1_ps(params->scalar.max);

size_t input_increment = 7 * input_stride;
for (; channels >= 16; channels -= 16) {
Expand Down
10 changes: 5 additions & 5 deletions src/configs/reduce-config.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ static void init_f16_f32acc_rsum_config(void) {
} else if (hardware_config->use_x86_f16c) {
f16_f32acc_rsum_config = (struct xnn_reduce_config) {
.ukernel = (xnn_reduce_ukernel_fn) xnn_f16_f32acc_rsum_ukernel__f16c_u32_acc4,
.init.f16_f32acc_scale = xnn_init_f16_f32acc_scale_avx_params,
.init.f16_f32acc_scale = xnn_init_f16_f32acc_scale_scalar_params,
.element_tile = 32,
};
}
Expand Down Expand Up @@ -184,7 +184,7 @@ static void init_f32_rsum_config(void) {
} else if (hardware_config->use_x86_avx) {
f32_rsum_config = (struct xnn_reduce_config) {
.ukernel = (xnn_reduce_ukernel_fn) xnn_f32_rsum_ukernel__avx_u32_acc4,
.init.f32_scaleminmax = xnn_init_f32_scaleminmax_avx_params,
.init.f32_scaleminmax = xnn_init_f32_scaleminmax_scalar_params,
.element_tile = 32,
};
} else {
Expand Down Expand Up @@ -232,7 +232,7 @@ static void init_f16_f32acc_rdsum_config(void) {
} else if (hardware_config->use_x86_f16c) {
f16_f32acc_rdsum_config = (struct xnn_reduce_config) {
.rd_ukernel = (xnn_rdsum_ukernel_fn) xnn_f16_f32acc_rdsum_ukernel_7p7x__f16c_c32,
.init.f16_f32acc_scale = xnn_init_f16_f32acc_scale_avx_params,
.init.f16_f32acc_scale = xnn_init_f16_f32acc_scale_scalar_params,
.element_tile = 32,
};
}
Expand Down Expand Up @@ -274,13 +274,13 @@ static void init_f32_rdsum_config(void) {
} else if (hardware_config->use_x86_avx) {
f32_rdsum_config = (struct xnn_reduce_config) {
.rd_ukernel = (xnn_rdsum_ukernel_fn) xnn_f32_rdsum_ukernel_7p7x__avx_c32,
.init.f32_scaleminmax = xnn_init_f32_scaleminmax_avx_params,
.init.f32_scaleminmax = xnn_init_f32_scaleminmax_scalar_params,
.element_tile = 32,
};
} else {
f32_rdsum_config = (struct xnn_reduce_config) {
.rd_ukernel = (xnn_rdsum_ukernel_fn) xnn_f32_rdsum_ukernel_7p7x__sse_c16,
.init.f32_scaleminmax = xnn_init_f32_scaleminmax_sse_params,
.init.f32_scaleminmax = xnn_init_f32_scaleminmax_scalar_params,
.element_tile = 16,
};
}
Expand Down
2 changes: 1 addition & 1 deletion src/f16-f32acc-rdsum/avx.c.in
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ void xnn_f16_f32acc_rdsum_ukernel_${ACCUMULATORS}p${ACCUMULATORS}x__f16c_c${CHAN
assert(input != NULL);
assert(output != NULL);

const __m256 vscale = _mm256_set1_ps(params->avx.scale);
const __m256 vscale = _mm256_set1_ps(params->scale);

size_t input_increment = ${ACCUMULATORS} * input_stride;
for (; channels >= ${CHANNELS_BATCH}; channels -= ${CHANNELS_BATCH}) {
Expand Down
2 changes: 1 addition & 1 deletion src/f16-f32acc-rdsum/avx512skx.c.in
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ void xnn_f16_f32acc_rdsum_ukernel_${ACCUMULATORS}p${ACCUMULATORS}x__avx512skx_c$
assert(input != NULL);
assert(output != NULL);

const __m512 vscale = _mm512_set1_ps(params->scalar.scale);
const __m512 vscale = _mm512_set1_ps(params->scale);

size_t input_increment = ${ACCUMULATORS} * input_stride;
for (; channels >= ${CHANNELS_BATCH}; channels -= ${CHANNELS_BATCH}) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ void xnn_f16_f32acc_rdsum_ukernel_7p7x__avx512skx_c128(
assert(input != NULL);
assert(output != NULL);

const __m512 vscale = _mm512_set1_ps(params->scalar.scale);
const __m512 vscale = _mm512_set1_ps(params->scale);

size_t input_increment = 7 * input_stride;
for (; channels >= 128; channels -= 128) {
Expand Down
Loading

0 comments on commit 33dcaec

Please sign in to comment.