Skip to content

Commit

Permalink
Convolution for SSE and AVX
Browse files Browse the repository at this point in the history
  • Loading branch information
jatinchowdhury18 committed Oct 25, 2024
1 parent e785cad commit 946a80e
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 3 deletions.
1 change: 1 addition & 0 deletions chowdsp_fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ struct FFT_Setup;
FFT_Setup* fft_new_setup (int N, fft_transform_t transform);
void fft_destroy_setup (FFT_Setup* s);
void pffft_transform_internal (FFT_Setup* setup, const float* finput, float* foutput, void* scratch, fft_direction_t direction, int ordered);
void pffft_convolve_internal (FFT_Setup* setup, const float* a, const float* b, float* ab, float scaling);
} // namespace chowdsp::fft::avx
static constexpr uintptr_t address_mask = ~static_cast<uintptr_t> (3);
static constexpr uintptr_t typeid_mask = static_cast<uintptr_t> (3);
Expand Down
41 changes: 40 additions & 1 deletion simd/chowdsp_fft_impl_avx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1987,7 +1987,46 @@ void pffft_transform_internal (FFT_Setup* setup, const float* finput, float* fou

void pffft_convolve_internal (FFT_Setup* setup, const float* a, const float* b, float* ab, float scaling)
{
// TODO!
int Ncvec = setup->Ncvec;
auto* va = reinterpret_cast<const __m256*> (a);
auto* vb = reinterpret_cast<const __m256*> (b);
auto* vab = reinterpret_cast<__m256*> (ab);

float ar0, ai0, br0, bi0, abr0, abi0;
const auto vscal = _mm256_set1_ps (scaling);
int i;

ar0 = reinterpret_cast<const float*> (&va[0])[0];
ai0 = reinterpret_cast<const float*> (&va[1])[0];
br0 = reinterpret_cast<const float*> (&vb[0])[0];
bi0 = reinterpret_cast<const float*> (&vb[1])[0];
abr0 = reinterpret_cast<const float*> (&vab[0])[0];
abi0 = reinterpret_cast<const float*> (&vab[1])[0];

for (i = 0; i < Ncvec; i += 2)
{
__m256 ar, ai, br, bi;
ar = va[2 * i + 0];
ai = va[2 * i + 1];
br = vb[2 * i + 0];
bi = vb[2 * i + 1];
cplx_mul_v (ar, ai, br, bi);
vab[2 * i + 0] = _mm256_fmadd_ps (ar, vscal, vab[2 * i + 0]);
vab[2 * i + 1] = _mm256_fmadd_ps (ai, vscal, vab[2 * i + 1]);
ar = va[2 * i + 2];
ai = va[2 * i + 3];
br = vb[2 * i + 2];
bi = vb[2 * i + 3];
cplx_mul_v (ar, ai, br, bi);
vab[2 * i + 2] = _mm256_fmadd_ps (ar, vscal, vab[2 * i + 2]);
vab[2 * i + 3] = _mm256_fmadd_ps (ai, vscal, vab[2 * i + 3]);
}

if (setup->transform == FFT_REAL)
{
reinterpret_cast<float*> (&vab[0])[0] = abr0 + ar0 * br0 * scaling;
reinterpret_cast<float*> (&vab[1])[0] = abi0 + ai0 * bi0 * scaling;
}
}
} // namespace chowdsp::fft::avx
#endif
41 changes: 40 additions & 1 deletion simd/chowdsp_fft_impl_sse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1648,6 +1648,45 @@ void pffft_transform_internal (FFT_Setup* setup, const float* finput, float* fou

void pffft_convolve_internal (FFT_Setup* setup, const float* a, const float* b, float* ab, float scaling)
{
// TODO!
int Ncvec = setup->Ncvec;
auto* va = reinterpret_cast<const __m128*> (a);
auto* vb = reinterpret_cast<const __m128*> (b);
auto* vab = reinterpret_cast<__m128*> (ab);

float ar0, ai0, br0, bi0, abr0, abi0;
const auto vscal = _mm_set1_ps (scaling);
int i;

ar0 = reinterpret_cast<const float*> (&va[0])[0];
ai0 = reinterpret_cast<const float*> (&va[1])[0];
br0 = reinterpret_cast<const float*> (&vb[0])[0];
bi0 = reinterpret_cast<const float*> (&vb[1])[0];
abr0 = reinterpret_cast<const float*> (&vab[0])[0];
abi0 = reinterpret_cast<const float*> (&vab[1])[0];

for (i = 0; i < Ncvec; i += 2)
{
__m128 ar, ai, br, bi;
ar = va[2 * i + 0];
ai = va[2 * i + 1];
br = vb[2 * i + 0];
bi = vb[2 * i + 1];
std::tie (ar, ai) = cplx_mul_v (ar, ai, br, bi);
vab[2 * i + 0] = _mm_add_ps (vab[2 * i + 0], _mm_mul_ps (ar, vscal));
vab[2 * i + 1] = _mm_add_ps (vab[2 * i + 1], _mm_mul_ps (ai, vscal));
ar = va[2 * i + 2];
ai = va[2 * i + 3];
br = vb[2 * i + 2];
bi = vb[2 * i + 3];
std::tie (ar, ai) = cplx_mul_v (ar, ai, br, bi);
vab[2 * i + 2] = _mm_add_ps (vab[2 * i + 2], _mm_mul_ps (ar, vscal));
vab[2 * i + 3] = _mm_add_ps (vab[2 * i + 3], _mm_mul_ps (ai, vscal));
}

if (setup->transform == FFT_REAL)
{
reinterpret_cast<float*> (&vab[0])[0] = abr0 + ar0 * br0 * scaling;
reinterpret_cast<float*> (&vab[1])[0] = abi0 + ai0 * bi0 * scaling;
}
}
} // namespace chowdsp::fft::sse
12 changes: 11 additions & 1 deletion test/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

void compare (const float* ref, const float* test, int N)
{
const auto tol = 1.0e-6f * (float) N / 8.0f;
const auto tol = 2.0e-7f * (float) N;
for (int n = 0; n < N; ++n)
REQUIRE (test[n] == Catch::Approx { ref[n] }.margin(tol));
}
Expand Down Expand Up @@ -240,6 +240,16 @@ TEST_CASE("FFT AVX")
{
test_fft_real (fft_size, true);
}

SECTION ("Testing Complex Convolution with size: " + std::to_string (fft_size))
{
test_convolution_complex (fft_size, true);
}

SECTION ("Testing Real Convolution with size: " + std::to_string (fft_size))
{
test_convolution_real (fft_size, true);
}
}
}
#endif

0 comments on commit 946a80e

Please sign in to comment.