Commit fc94d92

Fix AVX2 circuit bootstrapping and CI test builds
1 parent 131ed02 commit fc94d92

9 files changed

Lines changed: 677 additions & 1182 deletions


.github/workflows/test.yml

Lines changed: 4 additions & 4 deletions
@@ -22,7 +22,7 @@ jobs:
         submodules: recursive
     - name: build and test
       run: |
-        cmake . -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DENABLE_TEST=ON
+        cmake . -B build -G Ninja -DCMAKE_BUILD_TYPE=Debug -DENABLE_TEST=ON
         cd build
         ninja
         test/test.sh
@@ -43,7 +43,7 @@ jobs:
         submodules: recursive
     - name: build and test
       run: |
-        cmake . -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DENABLE_TEST=ON
+        cmake . -B build -G Ninja -DCMAKE_BUILD_TYPE=Debug -DENABLE_TEST=ON
         cd build
         ninja
         test/test.sh
@@ -64,7 +64,7 @@ jobs:
         submodules: recursive
     - name: build and test
       run: |
-        cmake . -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DENABLE_TEST=ON -DUSE_CONCRETE=ON
+        cmake . -B build -G Ninja -DCMAKE_BUILD_TYPE=Debug -DENABLE_TEST=ON -DUSE_CONCRETE=ON
         cd build
         ninja
         test/test.sh
@@ -112,7 +112,7 @@ jobs:
         submodules: recursive
     - name: build and test
       run: |
-        /cmake-binary/bin/cmake . -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang -DCMAKE_LINKER_TYPE=LLD -DENABLE_TEST=ON
+        /cmake-binary/bin/cmake . -B build -G Ninja -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang -DCMAKE_LINKER_TYPE=LLD -DENABLE_TEST=ON
         cd build
         ninja
         test/test.sh
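
All four CI jobs now build with -DCMAKE_BUILD_TYPE=Debug instead of Release. CMake's default Release flags define NDEBUG, so the Debug switch keeps assert() checks active while test/test.sh runs, at the cost of slower, unoptimized test binaries. To reproduce the updated build locally (commands as in the diff; Ninja and a checkout with submodules assumed):

cmake . -B build -G Ninja -DCMAKE_BUILD_TYPE=Debug -DENABLE_TEST=ON
cd build
ninja
test/test.sh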

include/keyswitch.hpp

Lines changed: 2 additions & 2 deletions
@@ -244,7 +244,7 @@ void SubsetPrivKeySwitch(TRLWE<typename P::targetP> &res,
                 mask;

             if (aij != 0)
-                TRLWEAdd<typename P::targetP>(res, res, privksk[i][j][aij - 1]);
+                TRLWESub<typename P::targetP>(res, res, privksk[i][j][aij - 1]);
             // for (int p = 0; p < P::targetP::n; p++)
             //     for (int k = 0; k < P::targetP::k + 1; k++)
             //         res[k][p] -= privksk[i][j][aij - 1][k][p];
@@ -547,4 +547,4 @@ void TLWE2TRLWEPacking(TRLWE<P> &res, std::vector<TLWE<P>> &tlwe,
 {
     PackLWEsLSB<P>(res, tlwe, ahk, P::nbit, 0, 1);
 }
-}  // namespace TFHEpp
+}  // namespace TFHEpp
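
The functional fix here: SubsetPrivKeySwitch accumulated each selected private key-switching key entry with TRLWEAdd, but the sign convention recorded in the commented-out scalar loop requires subtraction, so the call becomes TRLWESub. A coefficient-wise C++ sketch of what the corrected call is expected to compute, reconstructed from that comment (index names as in the surrounding code):

// Equivalent scalar form of the TRLWESub call above (sketch, not library code):
for (int k = 0; k < P::targetP::k + 1; k++)
    for (int p = 0; p < P::targetP::n; p++)
        res[k][p] -= privksk[i][j][aij - 1][k][p];

The second hunk leaves the closing line of namespace TFHEpp textually identical; the diff is consistent with adding a missing newline at end of file.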

include/mulfft.hpp

Lines changed: 38 additions & 104 deletions
@@ -219,18 +219,7 @@ inline void TwistIFFTUInt(PolynomialInFD<P> &res, const Polynomial<P> &a)
 template <uint32_t N>
 inline void MulInFD(std::array<double, N> &res, const std::array<double, N> &b)
 {
-#if defined(USE_INTERLEAVED_FORMAT) && defined(__AVX2__)
-    // AVX2 interleaved complex multiply: 2 complex per YMM
-    for (uint32_t i = 0; i < N; i += 4) {
-        __m256d a = _mm256_load_pd(res.data() + i);
-        __m256d w = _mm256_load_pd(b.data() + i);
-        __m256d w_swap = _mm256_permute_pd(w, 0b0101);
-        __m256d a_re = _mm256_unpacklo_pd(a, a);
-        __m256d a_im = _mm256_unpackhi_pd(a, a);
-        _mm256_store_pd(res.data() + i,
-                        _mm256_fmaddsub_pd(a_re, w, _mm256_mul_pd(a_im, w_swap)));
-    }
-#elif defined(USE_INTERLEAVED_FORMAT)
+#ifdef USE_INTERLEAVED_FORMAT
     for (int i = 0; i < N / 2; i++) {
         const std::complex tmp = std::complex(res[2 * i], res[2 * i + 1]) *
                                  std::complex(b[2 * i], b[2 * i + 1]);
@@ -256,27 +245,17 @@ inline void MulInFD(std::array<double, N> &res, const std::array<double, N> &b)
 #elif defined(__AVX2__) && !defined(__AVX512F__)
     double *rre = res.data(), *rim = res.data() + N / 2;
     const double *bre = b.data(), *bim = b.data() + N / 2;
-    for (uint32_t i = 0; i < N / 2; i += 8) {
-        __m256d va_re0 = _mm256_load_pd(rre + i);
-        __m256d va_im0 = _mm256_load_pd(rim + i);
-        __m256d vb_re0 = _mm256_load_pd(bre + i);
-        __m256d vb_im0 = _mm256_load_pd(bim + i);
-        __m256d va_re1 = _mm256_load_pd(rre + i + 4);
-        __m256d va_im1 = _mm256_load_pd(rim + i + 4);
-        __m256d vb_re1 = _mm256_load_pd(bre + i + 4);
-        __m256d vb_im1 = _mm256_load_pd(bim + i + 4);
-        __m256d vr_re0 = _mm256_mul_pd(va_re0, vb_re0);
-        __m256d vr_re1 = _mm256_mul_pd(va_re1, vb_re1);
-        vr_re0 = _mm256_fnmadd_pd(va_im0, vb_im0, vr_re0);
-        vr_re1 = _mm256_fnmadd_pd(va_im1, vb_im1, vr_re1);
-        __m256d vr_im0 = _mm256_mul_pd(va_im0, vb_re0);
-        __m256d vr_im1 = _mm256_mul_pd(va_im1, vb_re1);
-        vr_im0 = _mm256_fmadd_pd(va_re0, vb_im0, vr_im0);
-        vr_im1 = _mm256_fmadd_pd(va_re1, vb_im1, vr_im1);
-        _mm256_store_pd(rre + i, vr_re0);
-        _mm256_store_pd(rre + i + 4, vr_re1);
-        _mm256_store_pd(rim + i, vr_im0);
-        _mm256_store_pd(rim + i + 4, vr_im1);
+    for (uint32_t i = 0; i < N / 2; i += 4) {
+        __m256d va_re = _mm256_load_pd(rre + i);
+        __m256d va_im = _mm256_load_pd(rim + i);
+        __m256d vb_re = _mm256_load_pd(bre + i);
+        __m256d vb_im = _mm256_load_pd(bim + i);
+        __m256d vr_re = _mm256_mul_pd(va_re, vb_re);
+        vr_re = _mm256_fnmadd_pd(va_im, vb_im, vr_re);
+        __m256d vr_im = _mm256_mul_pd(va_im, vb_re);
+        vr_im = _mm256_fmadd_pd(va_re, vb_im, vr_im);
+        _mm256_store_pd(rre + i, vr_re);
+        _mm256_store_pd(rim + i, vr_im);
     }
 #else
     for (int i = 0; i < N / 2; i++) {
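
Both changes to the two-argument MulInFD, which multiplies res by b elementwise over packed complex values, are reverts to simpler code: the interleaved-format AVX2 branch built on _mm256_fmaddsub_pd is dropped in favor of the plain std::complex loop, and the split-format AVX2 loop returns to one 4-double vector per iteration instead of the hand-unrolled eight. In the split format, the first N/2 doubles hold real parts and the last N/2 hold imaginary parts; a scalar C++ sketch of what the retained intrinsics compute, for verification only (layout assumption as just stated):

// res[i] *= b[i] as complex numbers in split format (sketch, not library code).
for (uint32_t i = 0; i < N / 2; i++) {
    const double re = res[i] * b[i] - res[i + N / 2] * b[i + N / 2];
    const double im = res[i + N / 2] * b[i] + res[i] * b[i + N / 2];
    res[i] = re;          // real part: matches the mul + fnmadd pair
    res[i + N / 2] = im;  // imaginary part: matches the mul + fmadd pair
}
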
@@ -292,17 +271,7 @@ template <uint32_t N>
 inline void MulInFD(std::array<double, N> &res, const std::array<double, N> &a,
                     const std::array<double, N> &b)
 {
-#if defined(USE_INTERLEAVED_FORMAT) && defined(__AVX2__)
-    for (uint32_t i = 0; i < N; i += 4) {
-        __m256d va = _mm256_load_pd(a.data() + i);
-        __m256d w = _mm256_load_pd(b.data() + i);
-        __m256d w_swap = _mm256_permute_pd(w, 0b0101);
-        __m256d a_re = _mm256_unpacklo_pd(va, va);
-        __m256d a_im = _mm256_unpackhi_pd(va, va);
-        _mm256_store_pd(res.data() + i,
-                        _mm256_fmaddsub_pd(a_re, w, _mm256_mul_pd(a_im, w_swap)));
-    }
-#elif defined(USE_INTERLEAVED_FORMAT)
+#ifdef USE_INTERLEAVED_FORMAT
     for (int i = 0; i < N / 2; i++) {
         const std::complex tmp = std::complex(a[2 * i], a[2 * i + 1]) *
                                  std::complex(b[2 * i], b[2 * i + 1]);
@@ -330,27 +299,17 @@ inline void MulInFD(std::array<double, N> &res, const std::array<double, N> &a,
     const double *are = a.data(), *aim = a.data() + N / 2;
     const double *bre = b.data(), *bim = b.data() + N / 2;
     double *rre = res.data(), *rim = res.data() + N / 2;
-    for (uint32_t i = 0; i < N / 2; i += 8) {
-        __m256d va_re0 = _mm256_load_pd(are + i);
-        __m256d va_im0 = _mm256_load_pd(aim + i);
-        __m256d vb_re0 = _mm256_load_pd(bre + i);
-        __m256d vb_im0 = _mm256_load_pd(bim + i);
-        __m256d va_re1 = _mm256_load_pd(are + i + 4);
-        __m256d va_im1 = _mm256_load_pd(aim + i + 4);
-        __m256d vb_re1 = _mm256_load_pd(bre + i + 4);
-        __m256d vb_im1 = _mm256_load_pd(bim + i + 4);
-        __m256d vr_re0 = _mm256_mul_pd(va_re0, vb_re0);
-        __m256d vr_re1 = _mm256_mul_pd(va_re1, vb_re1);
-        vr_re0 = _mm256_fnmadd_pd(va_im0, vb_im0, vr_re0);
-        vr_re1 = _mm256_fnmadd_pd(va_im1, vb_im1, vr_re1);
-        __m256d vr_im0 = _mm256_mul_pd(va_im0, vb_re0);
-        __m256d vr_im1 = _mm256_mul_pd(va_im1, vb_re1);
-        vr_im0 = _mm256_fmadd_pd(va_re0, vb_im0, vr_im0);
-        vr_im1 = _mm256_fmadd_pd(va_re1, vb_im1, vr_im1);
-        _mm256_store_pd(rre + i, vr_re0);
-        _mm256_store_pd(rre + i + 4, vr_re1);
-        _mm256_store_pd(rim + i, vr_im0);
-        _mm256_store_pd(rim + i + 4, vr_im1);
+    for (uint32_t i = 0; i < N / 2; i += 4) {
+        __m256d va_re = _mm256_load_pd(are + i);
+        __m256d va_im = _mm256_load_pd(aim + i);
+        __m256d vb_re = _mm256_load_pd(bre + i);
+        __m256d vb_im = _mm256_load_pd(bim + i);
+        __m256d vr_re = _mm256_mul_pd(va_re, vb_re);
+        vr_re = _mm256_fnmadd_pd(va_im, vb_im, vr_re);
+        __m256d vr_im = _mm256_mul_pd(va_im, vb_re);
+        vr_im = _mm256_fmadd_pd(va_re, vb_im, vr_im);
+        _mm256_store_pd(rre + i, vr_re);
+        _mm256_store_pd(rim + i, vr_im);
     }
 #else
     for (int i = 0; i < N / 2; i++) {
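
The three-argument overload is the out-of-place counterpart, res = a * b elementwise, and receives the same two reverts. Under the same split-format assumption, the retained loop corresponds to this scalar sketch (assuming res does not alias a or b):

// res[i] = a[i] * b[i] as complex numbers in split format (sketch).
for (uint32_t i = 0; i < N / 2; i++) {
    res[i] = a[i] * b[i] - a[i + N / 2] * b[i + N / 2];          // real part
    res[i + N / 2] = a[i + N / 2] * b[i] + a[i] * b[i + N / 2];  // imaginary part
}
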
@@ -370,19 +329,7 @@ template <uint32_t N>
 inline void FMAInFD(std::array<double, N> &res, const std::array<double, N> &a,
                     const std::array<double, N> &b)
 {
-#if defined(USE_INTERLEAVED_FORMAT) && defined(__AVX2__)
-    // AVX2 interleaved complex FMA: res += a * b
-    for (uint32_t i = 0; i < N; i += 4) {
-        __m256d va = _mm256_load_pd(a.data() + i);
-        __m256d w = _mm256_load_pd(b.data() + i);
-        __m256d r = _mm256_load_pd(res.data() + i);
-        __m256d w_swap = _mm256_permute_pd(w, 0b0101);
-        __m256d a_re = _mm256_unpacklo_pd(va, va);
-        __m256d a_im = _mm256_unpackhi_pd(va, va);
-        __m256d prod = _mm256_fmaddsub_pd(a_re, w, _mm256_mul_pd(a_im, w_swap));
-        _mm256_store_pd(res.data() + i, _mm256_add_pd(r, prod));
-    }
-#elif defined(USE_INTERLEAVED_FORMAT)
+#ifdef USE_INTERLEAVED_FORMAT
     for (int i = 0; i < N / 2; i++) {
         std::complex tmp = std::complex(a[2 * i], a[2 * i + 1]) *
                            std::complex(b[2 * i], b[2 * i + 1]);
@@ -414,32 +361,19 @@ inline void FMAInFD(std::array<double, N> &res, const std::array<double, N> &a,
     const double *are = a.data(), *aim = a.data() + N / 2;
     const double *bre = b.data(), *bim = b.data() + N / 2;
     double *rre = res.data(), *rim = res.data() + N / 2;
-    // 2x unrolled to improve ILP on Zen 2 (2 FMA units)
-    for (uint32_t i = 0; i < N / 2; i += 8) {
-        __m256d va_re0 = _mm256_load_pd(are + i);
-        __m256d va_im0 = _mm256_load_pd(aim + i);
-        __m256d vb_re0 = _mm256_load_pd(bre + i);
-        __m256d vb_im0 = _mm256_load_pd(bim + i);
-        __m256d vr_re0 = _mm256_load_pd(rre + i);
-        __m256d vr_im0 = _mm256_load_pd(rim + i);
-        __m256d va_re1 = _mm256_load_pd(are + i + 4);
-        __m256d va_im1 = _mm256_load_pd(aim + i + 4);
-        __m256d vb_re1 = _mm256_load_pd(bre + i + 4);
-        __m256d vb_im1 = _mm256_load_pd(bim + i + 4);
-        __m256d vr_re1 = _mm256_load_pd(rre + i + 4);
-        __m256d vr_im1 = _mm256_load_pd(rim + i + 4);
-        vr_re0 = _mm256_fmadd_pd(va_re0, vb_re0, vr_re0);
-        vr_re1 = _mm256_fmadd_pd(va_re1, vb_re1, vr_re1);
-        vr_re0 = _mm256_fnmadd_pd(va_im0, vb_im0, vr_re0);
-        vr_re1 = _mm256_fnmadd_pd(va_im1, vb_im1, vr_re1);
-        vr_im0 = _mm256_fmadd_pd(va_im0, vb_re0, vr_im0);
-        vr_im1 = _mm256_fmadd_pd(va_im1, vb_re1, vr_im1);
-        vr_im0 = _mm256_fmadd_pd(va_re0, vb_im0, vr_im0);
-        vr_im1 = _mm256_fmadd_pd(va_re1, vb_im1, vr_im1);
-        _mm256_store_pd(rre + i, vr_re0);
-        _mm256_store_pd(rre + i + 4, vr_re1);
-        _mm256_store_pd(rim + i, vr_im0);
-        _mm256_store_pd(rim + i + 4, vr_im1);
+    for (uint32_t i = 0; i < N / 2; i += 4) {
+        __m256d va_re = _mm256_load_pd(are + i);
+        __m256d va_im = _mm256_load_pd(aim + i);
+        __m256d vb_re = _mm256_load_pd(bre + i);
+        __m256d vb_im = _mm256_load_pd(bim + i);
+        __m256d vr_re = _mm256_load_pd(rre + i);
+        __m256d vr_im = _mm256_load_pd(rim + i);
+        vr_re = _mm256_fmadd_pd(va_re, vb_re, vr_re);
+        vr_re = _mm256_fnmadd_pd(va_im, vb_im, vr_re);
+        vr_im = _mm256_fmadd_pd(va_im, vb_re, vr_im);
+        vr_im = _mm256_fmadd_pd(va_re, vb_im, vr_im);
+        _mm256_store_pd(rre + i, vr_re);
+        _mm256_store_pd(rim + i, vr_im);
     }
 #else
     for (int i = 0; i < N / 2; i++) {
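
FMAInFD accumulates res += a * b over packed complex values. The revert removes both the interleaved fmaddsub branch and the 2x unrolling that the deleted comment motivated by instruction-level parallelism on Zen 2; the retained loop processes one 4-double vector per iteration. A scalar sketch of the accumulation (same split-format assumption as above):

// res[i] += a[i] * b[i] as complex numbers in split format (sketch).
for (uint32_t i = 0; i < N / 2; i++) {
    res[i] += a[i] * b[i] - a[i + N / 2] * b[i + N / 2];          // real part
    res[i + N / 2] += a[i + N / 2] * b[i] + a[i] * b[i + N / 2];  // imaginary part
}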

include/trgsw.hpp

Lines changed: 0 additions & 49 deletions
@@ -155,55 +155,6 @@ inline void DecompositionImpl(DecPolyType &decpoly, const Polynomial<P> &poly)
     constexpr typename P::T halfBg =
         static_cast<typename P::T>(1) << (D::Bgbit - 1);

-#if defined(__AVX2__) && !defined(USE_AVX512)
-    // AVX2 vectorized path for uint32_t with l̅=1 (the common case)
-    if constexpr (std::is_same_v<typename P::T, uint32_t> && D::l̅ == 1) {
-        const __m256i voffset = _mm256_set1_epi32(
-            static_cast<int32_t>(offset + roundoffset));
-        const __m256i vmask = _mm256_set1_epi32(static_cast<int32_t>(maskBg));
-        const __m256i vhalf = _mm256_set1_epi32(static_cast<int32_t>(halfBg));
-        for (int i = 0; i < D::l; i++) {
-            const int shift = std::numeric_limits<uint32_t>::digits -
-                              (i + 1) * D::Bgbit;
-            const __m128i vshift = _mm_cvtsi32_si128(shift);
-            for (int n = 0; n < P::n; n += 8) {
-                __m256i va = _mm256_loadu_si256(
-                    reinterpret_cast<const __m256i *>(poly.data() + n));
-                va = _mm256_add_epi32(va, voffset);
-                va = _mm256_srl_epi32(va, vshift);
-                va = _mm256_and_si256(va, vmask);
-                va = _mm256_sub_epi32(va, vhalf);
-                _mm256_storeu_si256(
-                    reinterpret_cast<__m256i *>(&decpoly[i][n]), va);
-            }
-        }
-        return;
-    }
-    // AVX2 vectorized path for uint64_t with l̅=1
-    if constexpr (std::is_same_v<typename P::T, uint64_t> && D::l̅ == 1) {
-        const __m256i voffset = _mm256_set1_epi64x(
-            static_cast<int64_t>(offset + roundoffset));
-        const __m256i vmask = _mm256_set1_epi64x(static_cast<int64_t>(maskBg));
-        const __m256i vhalf = _mm256_set1_epi64x(static_cast<int64_t>(halfBg));
-        for (int i = 0; i < D::l; i++) {
-            const int shift = std::numeric_limits<uint64_t>::digits -
-                              (i + 1) * D::Bgbit;
-            const __m128i vshift = _mm_cvtsi32_si128(shift);
-            for (int n = 0; n < P::n; n += 4) {
-                __m256i va = _mm256_loadu_si256(
-                    reinterpret_cast<const __m256i *>(poly.data() + n));
-                va = _mm256_add_epi64(va, voffset);
-                va = _mm256_srl_epi64(va, vshift);
-                va = _mm256_and_si256(va, vmask);
-                va = _mm256_sub_epi64(va, vhalf);
-                _mm256_storeu_si256(
-                    reinterpret_cast<__m256i *>(&decpoly[i][n]), va);
-            }
-        }
-        return;
-    }
-#endif
-
     for (int n = 0; n < P::n; n++) {
         typename P::T a = poly[n] + offset + roundoffset;
         for (int i = 0; i < D::l; i++) {