@@ -219,18 +219,7 @@ inline void TwistIFFTUInt(PolynomialInFD<P> &res, const Polynomial<P> &a)
 template <uint32_t N>
 inline void MulInFD(std::array<double, N> &res, const std::array<double, N> &b)
 {
-#if defined(USE_INTERLEAVED_FORMAT) && defined(__AVX2__)
-    // AVX2 interleaved complex multiply: 2 complex per YMM
-    for (uint32_t i = 0; i < N; i += 4) {
-        __m256d a = _mm256_load_pd(res.data() + i);
-        __m256d w = _mm256_load_pd(b.data() + i);
-        __m256d w_swap = _mm256_permute_pd(w, 0b0101);
-        __m256d a_re = _mm256_unpacklo_pd(a, a);
-        __m256d a_im = _mm256_unpackhi_pd(a, a);
-        _mm256_store_pd(res.data() + i,
-                        _mm256_fmaddsub_pd(a_re, w, _mm256_mul_pd(a_im, w_swap)));
-    }
-#elif defined(USE_INTERLEAVED_FORMAT)
+#ifdef USE_INTERLEAVED_FORMAT
     for (int i = 0; i < N / 2; i++) {
         const std::complex tmp = std::complex(res[2 * i], res[2 * i + 1]) *
                                  std::complex(b[2 * i], b[2 * i + 1]);
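
For readers comparing the two paths: the removed interleaved kernel leans on _mm256_fmaddsub_pd, which subtracts in even lanes and adds in odd lanes, so one instruction yields both re·re − im·im and re·im + im·re for two packed complex values at once. A minimal standalone sketch of that trick (illustrative demo, not part of the patch; assumes -mavx2 -mfma and 32-byte-aligned buffers):

// Hypothetical demo of the interleaved fmaddsub complex multiply.
// Build with: g++ -O2 -mavx2 -mfma demo.cpp
#include <immintrin.h>
#include <complex>
#include <cstdio>

int main()
{
    alignas(32) double a[4] = {1.0, 2.0, 3.0, 4.0};  // (1+2i), (3+4i)
    alignas(32) double b[4] = {5.0, 6.0, 7.0, 8.0};  // (5+6i), (7+8i)
    __m256d va = _mm256_load_pd(a);
    __m256d w = _mm256_load_pd(b);
    __m256d w_swap = _mm256_permute_pd(w, 0b0101);  // swap re/im within each complex
    __m256d a_re = _mm256_unpacklo_pd(va, va);      // duplicate real parts
    __m256d a_im = _mm256_unpackhi_pd(va, va);      // duplicate imag parts
    // Even lanes: a_re*w_re - a_im*w_im; odd lanes: a_re*w_im + a_im*w_re
    _mm256_store_pd(a, _mm256_fmaddsub_pd(a_re, w, _mm256_mul_pd(a_im, w_swap)));
    const std::complex<double> ref =
        std::complex<double>(1, 2) * std::complex<double>(5, 6);
    std::printf("simd: %g%+gi  ref: %g%+gi\n", a[0], a[1], ref.real(), ref.imag());
}
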
@@ -256,27 +245,17 @@ inline void MulInFD(std::array<double, N> &res, const std::array<double, N> &b)
 #elif defined(__AVX2__) && !defined(__AVX512F__)
     double *rre = res.data(), *rim = res.data() + N / 2;
     const double *bre = b.data(), *bim = b.data() + N / 2;
-    for (uint32_t i = 0; i < N / 2; i += 8) {
-        __m256d va_re0 = _mm256_load_pd(rre + i);
-        __m256d va_im0 = _mm256_load_pd(rim + i);
-        __m256d vb_re0 = _mm256_load_pd(bre + i);
-        __m256d vb_im0 = _mm256_load_pd(bim + i);
-        __m256d va_re1 = _mm256_load_pd(rre + i + 4);
-        __m256d va_im1 = _mm256_load_pd(rim + i + 4);
-        __m256d vb_re1 = _mm256_load_pd(bre + i + 4);
-        __m256d vb_im1 = _mm256_load_pd(bim + i + 4);
-        __m256d vr_re0 = _mm256_mul_pd(va_re0, vb_re0);
-        __m256d vr_re1 = _mm256_mul_pd(va_re1, vb_re1);
-        vr_re0 = _mm256_fnmadd_pd(va_im0, vb_im0, vr_re0);
-        vr_re1 = _mm256_fnmadd_pd(va_im1, vb_im1, vr_re1);
-        __m256d vr_im0 = _mm256_mul_pd(va_im0, vb_re0);
-        __m256d vr_im1 = _mm256_mul_pd(va_im1, vb_re1);
-        vr_im0 = _mm256_fmadd_pd(va_re0, vb_im0, vr_im0);
-        vr_im1 = _mm256_fmadd_pd(va_re1, vb_im1, vr_im1);
-        _mm256_store_pd(rre + i, vr_re0);
-        _mm256_store_pd(rre + i + 4, vr_re1);
-        _mm256_store_pd(rim + i, vr_im0);
-        _mm256_store_pd(rim + i + 4, vr_im1);
+    for (uint32_t i = 0; i < N / 2; i += 4) {
+        __m256d va_re = _mm256_load_pd(rre + i);
+        __m256d va_im = _mm256_load_pd(rim + i);
+        __m256d vb_re = _mm256_load_pd(bre + i);
+        __m256d vb_im = _mm256_load_pd(bim + i);
+        __m256d vr_re = _mm256_mul_pd(va_re, vb_re);
+        vr_re = _mm256_fnmadd_pd(va_im, vb_im, vr_re);
+        __m256d vr_im = _mm256_mul_pd(va_im, vb_re);
+        vr_im = _mm256_fmadd_pd(va_re, vb_im, vr_im);
+        _mm256_store_pd(rre + i, vr_re);
+        _mm256_store_pd(rim + i, vr_im);
     }
 #else
     for (int i = 0; i < N / 2; i++) {
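
The retained AVX2 path works on the split layout instead: real parts in the first N/2 doubles, imaginary parts in the second. A scalar reference that computes the same in-place product, handy for checking the vector kernel (a sketch; MulInFDRef is a hypothetical test helper, not part of the patch):

#include <array>
#include <cstdint>

// Scalar reference for the split-format in-place multiply: res *= b,
// complex element-wise. Mirrors the mul/fnmadd/fmadd math in the AVX2 loop.
template <uint32_t N>
void MulInFDRef(std::array<double, N> &res, const std::array<double, N> &b)
{
    for (uint32_t i = 0; i < N / 2; i++) {
        // Read both halves before overwriting either one.
        const double re = res[i] * b[i] - res[i + N / 2] * b[i + N / 2];
        res[i + N / 2] = res[i] * b[i + N / 2] + res[i + N / 2] * b[i];
        res[i] = re;
    }
}
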
@@ -292,17 +271,7 @@ template <uint32_t N>
 inline void MulInFD(std::array<double, N> &res, const std::array<double, N> &a,
                     const std::array<double, N> &b)
 {
-#if defined(USE_INTERLEAVED_FORMAT) && defined(__AVX2__)
-    for (uint32_t i = 0; i < N; i += 4) {
-        __m256d va = _mm256_load_pd(a.data() + i);
-        __m256d w = _mm256_load_pd(b.data() + i);
-        __m256d w_swap = _mm256_permute_pd(w, 0b0101);
-        __m256d a_re = _mm256_unpacklo_pd(va, va);
-        __m256d a_im = _mm256_unpackhi_pd(va, va);
-        _mm256_store_pd(res.data() + i,
-                        _mm256_fmaddsub_pd(a_re, w, _mm256_mul_pd(a_im, w_swap)));
-    }
-#elif defined(USE_INTERLEAVED_FORMAT)
+#ifdef USE_INTERLEAVED_FORMAT
     for (int i = 0; i < N / 2; i++) {
         const std::complex tmp = std::complex(a[2 * i], a[2 * i + 1]) *
                                  std::complex(b[2 * i], b[2 * i + 1]);
@@ -330,27 +299,17 @@ inline void MulInFD(std::array<double, N> &res, const std::array<double, N> &a,
     const double *are = a.data(), *aim = a.data() + N / 2;
     const double *bre = b.data(), *bim = b.data() + N / 2;
     double *rre = res.data(), *rim = res.data() + N / 2;
-    for (uint32_t i = 0; i < N / 2; i += 8) {
-        __m256d va_re0 = _mm256_load_pd(are + i);
-        __m256d va_im0 = _mm256_load_pd(aim + i);
-        __m256d vb_re0 = _mm256_load_pd(bre + i);
-        __m256d vb_im0 = _mm256_load_pd(bim + i);
-        __m256d va_re1 = _mm256_load_pd(are + i + 4);
-        __m256d va_im1 = _mm256_load_pd(aim + i + 4);
-        __m256d vb_re1 = _mm256_load_pd(bre + i + 4);
-        __m256d vb_im1 = _mm256_load_pd(bim + i + 4);
-        __m256d vr_re0 = _mm256_mul_pd(va_re0, vb_re0);
-        __m256d vr_re1 = _mm256_mul_pd(va_re1, vb_re1);
-        vr_re0 = _mm256_fnmadd_pd(va_im0, vb_im0, vr_re0);
-        vr_re1 = _mm256_fnmadd_pd(va_im1, vb_im1, vr_re1);
-        __m256d vr_im0 = _mm256_mul_pd(va_im0, vb_re0);
-        __m256d vr_im1 = _mm256_mul_pd(va_im1, vb_re1);
-        vr_im0 = _mm256_fmadd_pd(va_re0, vb_im0, vr_im0);
-        vr_im1 = _mm256_fmadd_pd(va_re1, vb_im1, vr_im1);
-        _mm256_store_pd(rre + i, vr_re0);
-        _mm256_store_pd(rre + i + 4, vr_re1);
-        _mm256_store_pd(rim + i, vr_im0);
-        _mm256_store_pd(rim + i + 4, vr_im1);
+    for (uint32_t i = 0; i < N / 2; i += 4) {
+        __m256d va_re = _mm256_load_pd(are + i);
+        __m256d va_im = _mm256_load_pd(aim + i);
+        __m256d vb_re = _mm256_load_pd(bre + i);
+        __m256d vb_im = _mm256_load_pd(bim + i);
+        __m256d vr_re = _mm256_mul_pd(va_re, vb_re);
+        vr_re = _mm256_fnmadd_pd(va_im, vb_im, vr_re);
+        __m256d vr_im = _mm256_mul_pd(va_im, vb_re);
+        vr_im = _mm256_fmadd_pd(va_re, vb_im, vr_im);
+        _mm256_store_pd(rre + i, vr_re);
+        _mm256_store_pd(rim + i, vr_im);
     }
 #else
     for (int i = 0; i < N / 2; i++) {
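
One caveat that applies to every AVX2 branch in this file: _mm256_load_pd and _mm256_store_pd are aligned accesses and fault on an address that is not 32-byte aligned, while a plain std::array<double, N> only guarantees alignof(double). Presumably the surrounding code allocates these arrays with sufficient alignment; if not, one illustrative way to enforce it (hypothetical type, not part of the patch):

#include <array>
#include <cstdint>

// Hypothetical 32-byte-aligned drop-in for std::array<double, N>, so the
// aligned loads/stores above cannot fault on a misaligned base pointer.
template <uint32_t N>
struct alignas(32) AlignedFD : std::array<double, N> {};
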
@@ -370,19 +329,7 @@ template <uint32_t N>
 inline void FMAInFD(std::array<double, N> &res, const std::array<double, N> &a,
                     const std::array<double, N> &b)
 {
-#if defined(USE_INTERLEAVED_FORMAT) && defined(__AVX2__)
-    // AVX2 interleaved complex FMA: res += a * b
-    for (uint32_t i = 0; i < N; i += 4) {
-        __m256d va = _mm256_load_pd(a.data() + i);
-        __m256d w = _mm256_load_pd(b.data() + i);
-        __m256d r = _mm256_load_pd(res.data() + i);
-        __m256d w_swap = _mm256_permute_pd(w, 0b0101);
-        __m256d a_re = _mm256_unpacklo_pd(va, va);
-        __m256d a_im = _mm256_unpackhi_pd(va, va);
-        __m256d prod = _mm256_fmaddsub_pd(a_re, w, _mm256_mul_pd(a_im, w_swap));
-        _mm256_store_pd(res.data() + i, _mm256_add_pd(r, prod));
-    }
-#elif defined(USE_INTERLEAVED_FORMAT)
+#ifdef USE_INTERLEAVED_FORMAT
     for (int i = 0; i < N / 2; i++) {
         std::complex tmp = std::complex(a[2 * i], a[2 * i + 1]) *
                            std::complex(b[2 * i], b[2 * i + 1]);
@@ -414,32 +361,19 @@ inline void FMAInFD(std::array<double, N> &res, const std::array<double, N> &a,
     const double *are = a.data(), *aim = a.data() + N / 2;
     const double *bre = b.data(), *bim = b.data() + N / 2;
     double *rre = res.data(), *rim = res.data() + N / 2;
-    // 2x unrolled to improve ILP on Zen 2 (2 FMA units)
-    for (uint32_t i = 0; i < N / 2; i += 8) {
-        __m256d va_re0 = _mm256_load_pd(are + i);
-        __m256d va_im0 = _mm256_load_pd(aim + i);
-        __m256d vb_re0 = _mm256_load_pd(bre + i);
-        __m256d vb_im0 = _mm256_load_pd(bim + i);
-        __m256d vr_re0 = _mm256_load_pd(rre + i);
-        __m256d vr_im0 = _mm256_load_pd(rim + i);
-        __m256d va_re1 = _mm256_load_pd(are + i + 4);
-        __m256d va_im1 = _mm256_load_pd(aim + i + 4);
-        __m256d vb_re1 = _mm256_load_pd(bre + i + 4);
-        __m256d vb_im1 = _mm256_load_pd(bim + i + 4);
-        __m256d vr_re1 = _mm256_load_pd(rre + i + 4);
-        __m256d vr_im1 = _mm256_load_pd(rim + i + 4);
-        vr_re0 = _mm256_fmadd_pd(va_re0, vb_re0, vr_re0);
-        vr_re1 = _mm256_fmadd_pd(va_re1, vb_re1, vr_re1);
-        vr_re0 = _mm256_fnmadd_pd(va_im0, vb_im0, vr_re0);
-        vr_re1 = _mm256_fnmadd_pd(va_im1, vb_im1, vr_re1);
-        vr_im0 = _mm256_fmadd_pd(va_im0, vb_re0, vr_im0);
-        vr_im1 = _mm256_fmadd_pd(va_im1, vb_re1, vr_im1);
-        vr_im0 = _mm256_fmadd_pd(va_re0, vb_im0, vr_im0);
-        vr_im1 = _mm256_fmadd_pd(va_re1, vb_im1, vr_im1);
-        _mm256_store_pd(rre + i, vr_re0);
-        _mm256_store_pd(rre + i + 4, vr_re1);
-        _mm256_store_pd(rim + i, vr_im0);
-        _mm256_store_pd(rim + i + 4, vr_im1);
+    for (uint32_t i = 0; i < N / 2; i += 4) {
+        __m256d va_re = _mm256_load_pd(are + i);
+        __m256d va_im = _mm256_load_pd(aim + i);
+        __m256d vb_re = _mm256_load_pd(bre + i);
+        __m256d vb_im = _mm256_load_pd(bim + i);
+        __m256d vr_re = _mm256_load_pd(rre + i);
+        __m256d vr_im = _mm256_load_pd(rim + i);
+        vr_re = _mm256_fmadd_pd(va_re, vb_re, vr_re);
+        vr_re = _mm256_fnmadd_pd(va_im, vb_im, vr_re);
+        vr_im = _mm256_fmadd_pd(va_im, vb_re, vr_im);
+        vr_im = _mm256_fmadd_pd(va_re, vb_im, vr_im);
+        _mm256_store_pd(rre + i, vr_re);
+        _mm256_store_pd(rim + i, vr_im);
     }
 #else
     for (int i = 0; i < N / 2; i++) {
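
FMAInFD accumulates res += a * b in the same split format. A matching scalar reference (a sketch; FMAInFDRef is a hypothetical test helper, not part of the patch):

#include <array>
#include <cstdint>

// Scalar reference for the split-format complex FMA: res += a * b.
// Mirrors the fmadd/fnmadd chain in the AVX2 loop above.
template <uint32_t N>
void FMAInFDRef(std::array<double, N> &res, const std::array<double, N> &a,
                const std::array<double, N> &b)
{
    for (uint32_t i = 0; i < N / 2; i++) {
        res[i] += a[i] * b[i] - a[i + N / 2] * b[i + N / 2];
        res[i + N / 2] += a[i] * b[i + N / 2] + a[i + N / 2] * b[i];
    }
}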