Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ set(FAISS_HEADERS
utils/hamming_distance/avx512-inl.h
utils/simd_impl/distances_autovec-inl.h
utils/simd_impl/distances_simdlib256.h
utils/simd_impl/exhaustive_L2sqr_blas_cmax.h
utils/simd_impl/IVFFlatScanner-inl.h
utils/simd_impl/distances_sse-inl.h
)
Expand Down
4 changes: 3 additions & 1 deletion faiss/IndexPQ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ FlatCodesDistanceComputer* get_FlatCodesDistanceComputer1(
} // namespace

// Returns the FlatCodesDistanceComputer produced by the SIMD-level-templated
// helper get_FlatCodesDistanceComputer1<SL>, called with *this.
// NOTE(review): this listing appears to show both the removed macro-based
// call and its with_simd_level replacement -- confirm which one is live.
FlatCodesDistanceComputer* IndexPQ::get_FlatCodesDistanceComputer() const {
DISPATCH_SIMDLevel(get_FlatCodesDistanceComputer1, *this);
// with_simd_level supplies the selected SIMDLevel to the lambda as the
// compile-time template argument SL.
return with_simd_level([&]<SIMDLevel SL>() {
return get_FlatCodesDistanceComputer1<SL>(*this);
});
}

/*****************************************
Expand Down
70 changes: 27 additions & 43 deletions faiss/impl/fast_scan/fast_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,17 +426,10 @@ std::unique_ptr<FastScanCodeScanner> make_fast_scan_knn_scanner(
int64_t* ids,
const IDSelector* sel,
bool with_id_map) {
DISPATCH_SIMDLevel(
make_fast_scan_scanner_impl,
is_max,
impl,
nq,
ntotal,
k,
distances,
ids,
sel,
with_id_map);
return with_simd_level([&]<SIMDLevel SL>() {
return make_fast_scan_scanner_impl<SL>(
is_max, impl, nq, ntotal, k, distances, ids, sel, with_id_map);
});
}

std::unique_ptr<FastScanCodeScanner> make_range_scanner(
Expand All @@ -445,8 +438,9 @@ std::unique_ptr<FastScanCodeScanner> make_range_scanner(
float radius,
size_t ntotal,
const IDSelector* sel) {
DISPATCH_SIMDLevel(
make_range_scanner_impl, is_max, rres, radius, ntotal, sel);
return with_simd_level([&]<SIMDLevel SL>() {
return make_range_scanner_impl<SL>(is_max, rres, radius, ntotal, sel);
});
}

std::unique_ptr<FastScanCodeScanner> make_partial_range_scanner(
Expand All @@ -457,15 +451,10 @@ std::unique_ptr<FastScanCodeScanner> make_partial_range_scanner(
size_t q0,
size_t q1,
const IDSelector* sel) {
DISPATCH_SIMDLevel(
make_partial_range_scanner_impl,
is_max,
pres,
radius,
ntotal,
q0,
q1,
sel);
return with_simd_level([&]<SIMDLevel SL>() {
return make_partial_range_scanner_impl<SL>(
is_max, pres, radius, ntotal, q0, q1, sel);
});
}

std::unique_ptr<FastScanCodeScanner> rabitq_make_knn_scanner(
Expand All @@ -478,17 +467,18 @@ std::unique_ptr<FastScanCodeScanner> rabitq_make_knn_scanner(
const IDSelector* sel,
const FastScanDistancePostProcessing& context,
bool is_multi_bit) {
DISPATCH_SIMDLevel(
rabitq_make_knn_scanner_impl,
index,
is_max,
nq,
k,
distances,
ids,
sel,
context,
is_multi_bit);
return with_simd_level([&]<SIMDLevel SL>() {
return rabitq_make_knn_scanner_impl<SL>(
index,
is_max,
nq,
k,
distances,
ids,
sel,
context,
is_multi_bit);
});
}

std::unique_ptr<FastScanCodeScanner> rabitq_ivf_make_knn_scanner(
Expand All @@ -500,16 +490,10 @@ std::unique_ptr<FastScanCodeScanner> rabitq_ivf_make_knn_scanner(
int64_t* ids,
const FastScanDistancePostProcessing* context,
bool multi_bit) {
DISPATCH_SIMDLevel(
rabitq_ivf_make_knn_scanner_impl,
is_max,
index,
nq,
k,
distances,
ids,
context,
multi_bit);
return with_simd_level([&]<SIMDLevel SL>() {
return rabitq_ivf_make_knn_scanner_impl<SL>(
is_max, index, nq, k, distances, ids, context, multi_bit);
});
}

} // namespace faiss
2 changes: 1 addition & 1 deletion faiss/impl/fast_scan/fast_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ std::unique_ptr<FastScanCodeScanner> make_fast_scan_scanner_impl(
bool with_id_map);

/// Runtime dispatch wrapper: selects the best available SIMD level
/// (via DISPATCH_SIMDLevel) and delegates to the corresponding
/// (via with_simd_level) and delegates to the corresponding
/// make_fast_scan_scanner_impl<SL> specialization.
std::unique_ptr<FastScanCodeScanner> make_fast_scan_knn_scanner(
bool is_max,
Expand Down
33 changes: 18 additions & 15 deletions faiss/impl/pq_code_distance/pq_code_distance-generic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
// 1. _impl specializations for NONE (and ARM_NEON), using scalar code.
// 2. Non-templated PQ code distance dispatch wrappers
// (pq_code_distance_single, pq_code_distance_four) declared in
// pq_code_distance.h. These use DISPATCH_SIMDLevel to route to the
// pq_code_distance.h. These use with_simd_level to route to the
// best available SIMD implementation via pq_code_distance_*_impl
// function template specializations defined in the per-SIMD .cpp files.

Expand Down Expand Up @@ -107,7 +107,9 @@ float pq_code_distance_single(
size_t nbits,
const float* sim_table,
const uint8_t* code) {
DISPATCH_SIMDLevel(pq_code_distance_single_impl, M, nbits, sim_table, code);
return with_simd_level([&]<SIMDLevel SL>() {
return pq_code_distance_single_impl<SL>(M, nbits, sim_table, code);
});
}

void pq_code_distance_four(
Expand All @@ -122,19 +124,20 @@ void pq_code_distance_four(
float& result1,
float& result2,
float& result3) {
DISPATCH_SIMDLevel(
pq_code_distance_four_impl,
M,
nbits,
sim_table,
code0,
code1,
code2,
code3,
result0,
result1,
result2,
result3);
with_simd_level([&]<SIMDLevel SL>() {
pq_code_distance_four_impl<SL>(
M,
nbits,
sim_table,
code0,
code1,
code2,
code3,
result0,
result1,
result2,
result3);
});
}

} // namespace pq_code_distance
Expand Down
147 changes: 70 additions & 77 deletions faiss/impl/simd_dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,86 +23,85 @@

namespace faiss {

/*********************** x86 SIMD dispatch cases */
/** The SIMD levels for which a given function has an implementation are
* described by a bitmask with one bit per SIMDLevel value. The most common
* masks are predefined here.
* */

#ifdef COMPILE_SIMD_AVX2
#define DISPATCH_SIMDLevel_AVX2(f, ...) \
case SIMDLevel::AVX2: \
return f<SIMDLevel::AVX2>(__VA_ARGS__)
#else
#define DISPATCH_SIMDLevel_AVX2(f, ...)
#endif
constexpr int AVAILABLE_SIMD_LEVELS_NONE = (1 << int(SIMDLevel::NONE));

#ifdef COMPILE_SIMD_AVX512
#define DISPATCH_SIMDLevel_AVX512(f, ...) \
case SIMDLevel::AVX512: \
return f<SIMDLevel::AVX512>(__VA_ARGS__)
#else
#define DISPATCH_SIMDLevel_AVX512(f, ...)
#endif
constexpr int AVAILABLE_SIMD_LEVELS_AVX2_NEON = AVAILABLE_SIMD_LEVELS_NONE |
(1 << int(SIMDLevel::AVX2)) | (1 << int(SIMDLevel::ARM_NEON));

// A0: the AVX2_NEON mask above, plus AVX512
constexpr int AVAILABLE_SIMD_LEVELS_A0 =
AVAILABLE_SIMD_LEVELS_AVX2_NEON | (1 << int(SIMDLevel::AVX512));

constexpr int AVAILABLE_SIMD_LEVELS_ALL = -1;

/** The complete dispatching function. It takes into account:
* - the currently selected SIMD level
* - the compiled-in SIMD levels (given by COMPILE_SIMD_XXX)
* - the available SIMD implementations for that particular function (given by
* available_levels)
*
* With FAISS_ENABLE_DD, the function switches on SIMDConfig::level. Each
* compiled-in case invokes the action at that level when the level's bit is
* set in available_levels; otherwise control falls through to the next case
* label, ending at `default`, which invokes the action with SIMDLevel::NONE.
* Without FAISS_ENABLE_DD, the action is invoked directly with
* SINGLE_SIMD_LEVEL, and a static_assert requires that level's bit to be set
* in available_levels.
*
* NOTE(review): this listing appears to interleave the removed
* DISPATCH_SIMDLevel* macro definitions with the function body -- confirm
* against the actual header before relying on the exact line order.
*
* @tparam available_levels bitmask with bit (1 << int(level)) set for every
* SIMDLevel the action may be instantiated with
* @param action generic lambda called with no arguments as
* `action.template operator()<level>()`
* @return the return value of the action
*/

template <int available_levels, typename LambdaType>
inline auto with_selected_simd_levels(LambdaType&& action) {
#ifdef FAISS_ENABLE_DD
switch (SIMDConfig::level) {
// For x86 -- try from highest to lowest level

#ifdef COMPILE_SIMD_AVX512_SPR
#define DISPATCH_SIMDLevel_AVX512_SPR(f, ...) \
case SIMDLevel::AVX512_SPR: \
return f<SIMDLevel::AVX512_SPR>(__VA_ARGS__)
#else
#define DISPATCH_SIMDLevel_AVX512_SPR(f, ...)
case SIMDLevel::AVX512_SPR:
if constexpr (
available_levels & (1 << int(SIMDLevel::AVX512_SPR))) {
return action.template operator()<SIMDLevel::AVX512_SPR>();
}
[[fallthrough]];
#endif

/*********************** ARM SIMD dispatch cases */
#ifdef COMPILE_SIMD_AVX512
case SIMDLevel::AVX512:
if constexpr (available_levels & (1 << int(SIMDLevel::AVX512))) {
return action.template operator()<SIMDLevel::AVX512>();
}
[[fallthrough]];
#endif

#ifdef COMPILE_SIMD_ARM_NEON
#define DISPATCH_SIMDLevel_ARM_NEON(f, ...) \
case SIMDLevel::ARM_NEON: \
return f<SIMDLevel::ARM_NEON>(__VA_ARGS__)
#else
#define DISPATCH_SIMDLevel_ARM_NEON(f, ...)
#ifdef COMPILE_SIMD_AVX2
case SIMDLevel::AVX2:
if constexpr (available_levels & (1 << int(SIMDLevel::AVX2))) {
return action.template operator()<SIMDLevel::AVX2>();
}
[[fallthrough]];
#endif

// For ARM, try from highest to lowest level
#ifdef COMPILE_SIMD_ARM_SVE
#define DISPATCH_SIMDLevel_ARM_SVE(f, ...) \
case SIMDLevel::ARM_SVE: \
return f<SIMDLevel::ARM_SVE>(__VA_ARGS__)
#else
#define DISPATCH_SIMDLevel_ARM_SVE(f, ...)
case SIMDLevel::ARM_SVE:
if constexpr (available_levels & (1 << int(SIMDLevel::ARM_SVE))) {
return action.template operator()<SIMDLevel::ARM_SVE>();
}
[[fallthrough]];
#endif

/*********************** Main dispatch macro */

#ifdef FAISS_ENABLE_DD

// DD mode: runtime dispatch based on SIMDConfig::level
#define DISPATCH_SIMDLevel(f, ...) \
switch (SIMDConfig::level) { \
case SIMDLevel::NONE: \
return f<SIMDLevel::NONE>(__VA_ARGS__); \
DISPATCH_SIMDLevel_AVX2(f, __VA_ARGS__); \
DISPATCH_SIMDLevel_AVX512(f, __VA_ARGS__); \
DISPATCH_SIMDLevel_AVX512_SPR(f, __VA_ARGS__); \
DISPATCH_SIMDLevel_ARM_NEON(f, __VA_ARGS__); \
DISPATCH_SIMDLevel_ARM_SVE(f, __VA_ARGS__); \
default: \
FAISS_THROW_MSG("Invalid SIMD level"); \
#ifdef COMPILE_SIMD_ARM_NEON
case SIMDLevel::ARM_NEON:
if constexpr (available_levels & (1 << int(SIMDLevel::ARM_NEON))) {
return action.template operator()<SIMDLevel::ARM_NEON>();
}
[[fallthrough]];
#endif
// Reached for SIMDLevel::NONE, for any level that has no case label
// compiled in, and by fall-through from cases whose bit is not set in
// available_levels.
default:
return action.template operator()<SIMDLevel::NONE>();
}

#else // Static mode

// Static mode: direct call to compiled-in SIMD level (no runtime switch)
#if defined(COMPILE_SIMD_AVX512_SPR)
#define DISPATCH_SIMDLevel(f, ...) return f<SIMDLevel::AVX512_SPR>(__VA_ARGS__)
#elif defined(COMPILE_SIMD_AVX512)
#define DISPATCH_SIMDLevel(f, ...) return f<SIMDLevel::AVX512>(__VA_ARGS__)
#elif defined(COMPILE_SIMD_AVX2)
#define DISPATCH_SIMDLevel(f, ...) return f<SIMDLevel::AVX2>(__VA_ARGS__)
#elif defined(COMPILE_SIMD_ARM_SVE)
#define DISPATCH_SIMDLevel(f, ...) return f<SIMDLevel::ARM_SVE>(__VA_ARGS__)
#elif defined(COMPILE_SIMD_ARM_NEON)
#define DISPATCH_SIMDLevel(f, ...) return f<SIMDLevel::ARM_NEON>(__VA_ARGS__)
#else
#define DISPATCH_SIMDLevel(f, ...) return f<SIMDLevel::NONE>(__VA_ARGS__)
#else // static dispatch
// In static mode, SINGLE_SIMD_LEVEL is a constexpr resolved at compile
// time, so this is a direct call with no runtime switch.
static_assert(available_levels & (1 << int(SINGLE_SIMD_LEVEL)));
return action.template operator()<SINGLE_SIMD_LEVEL>();
#endif

#endif // FAISS_ENABLE_DD
}

/**
* Dispatch to a lambda with SIMDLevel as a compile-time constant.
Expand All @@ -126,31 +125,25 @@ namespace faiss {
* });
*
* The lambda must be a generic lambda with a SIMDLevel template parameter.
* By default, dispatch considers the AVX2, AVX512 and NEON levels (plus the
* NONE fallback), since these are the most common cases.
*
* @param action A generic lambda with signature `template<SIMDLevel> T
* operator()()`
* @return The return value of the lambda
*/
// NOTE(review): this listing appears to show both the removed macro-based
// body line and its replacement -- confirm which one is live.
template <typename LambdaType>
inline auto with_simd_level(LambdaType&& action) {
DISPATCH_SIMDLevel(action.template operator());
// Dispatch over the AVAILABLE_SIMD_LEVELS_A0 mask, i.e. the NONE, AVX2,
// ARM_NEON and AVX512 bits.
return with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_A0>(action);
}

/**
* Like with_simd_level, but maps to the 256-bit SIMD equivalent:
* AVX512, AVX512_SPR -> AVX2
* ARM_SVE -> ARM_NEON
* AVX2, ARM_NEON, NONE -> unchanged
*
* Use for functions implemented with simd8float32 (256-bit) operations
* Use for functions implemented with simdXintY (256-bit) operations
* that don't have dedicated AVX512 or SVE implementations.
*/
// NOTE(review): this listing appears to show both the removed
// simd256_level_selector-based body and its replacement -- confirm which one
// is live.
template <typename LambdaType>
inline auto with_simd_level_256bit(LambdaType&& action) {
return with_simd_level([&]<SIMDLevel level>() {
constexpr SIMDLevel level256 = simd256_level_selector<level>::value;
return action.template operator()<level256>();
});
// Dispatch over the AVAILABLE_SIMD_LEVELS_AVX2_NEON mask, i.e. only the
// NONE, AVX2 and ARM_NEON bits. In with_selected_simd_levels, a selected
// level whose bit is not in the mask falls through to the next case label.
return with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_AVX2_NEON>(action);
}

} // namespace faiss
Loading
Loading