diff --git a/.gitignore b/.gitignore index 38c769e2a..f2039e4a7 100644 --- a/.gitignore +++ b/.gitignore @@ -52,3 +52,9 @@ allure-* !build_android.sh !build_ios.sh +# congfig +doc/ +config/ +examples/python/ +examples/c_api/ +logs/ \ No newline at end of file diff --git a/examples/c/index_example.c b/examples/c/index_example.c index 403c0bef9..e281a7368 100644 --- a/examples/c/index_example.c +++ b/examples/c/index_example.c @@ -85,6 +85,11 @@ int main() { zvec_index_params_set_metric_type(hnsw_params_fast, ZVEC_METRIC_TYPE_L2); zvec_index_params_set_hnsw_params(hnsw_params_fast, 16, 100); + // Demonstrate INT8 quantization with random rotation preprocessing + // (enable_rotate rotates vectors before INT8 quantization to reduce error) + zvec_index_params_set_quantize_type(hnsw_params_fast, ZVEC_QUANTIZE_TYPE_INT8); + zvec_index_params_set_quantizer_enable_rotate(hnsw_params_fast, true); + zvec_index_params_t *hnsw_params_balanced = zvec_index_params_create(ZVEC_INDEX_TYPE_HNSW); if (!hnsw_params_balanced) { diff --git a/python/zvec/__init__.py b/python/zvec/__init__.py index 5fdf9732c..929c75fa2 100644 --- a/python/zvec/__init__.py +++ b/python/zvec/__init__.py @@ -108,6 +108,7 @@ IVFIndexParam, IVFQueryParam, OptimizeOption, + QuantizerParam, VamanaIndexParam, VamanaQueryParam, ) @@ -171,6 +172,7 @@ "HnswQueryParam", "HnswRabitqQueryParam", "IVFQueryParam", + "QuantizerParam", "VamanaIndexParam", "VamanaQueryParam", # Extensions diff --git a/python/zvec/__init__.pyi b/python/zvec/__init__.pyi index dd468cae1..7177a5ec0 100644 --- a/python/zvec/__init__.pyi +++ b/python/zvec/__init__.pyi @@ -28,6 +28,7 @@ from .model.param import ( IVFIndexParam, IVFQueryParam, OptimizeOption, + QuantizerParam, VamanaIndexParam, VamanaQueryParam, ) @@ -74,6 +75,7 @@ __all__: list = [ "MetricType", "OptimizeOption", "QuantizeType", + "QuantizerParam", "Query", "ReRanker", "RrfReRanker", diff --git a/python/zvec/model/param/__init__.py b/python/zvec/model/param/__init__.py index 43fc1ddce..084a79d00 100644 --- a/python/zvec/model/param/__init__.py +++ b/python/zvec/model/param/__init__.py @@ -31,6 +31,7 @@ IVFIndexParam, IVFQueryParam, OptimizeOption, + QuantizerParam, VamanaIndexParam, VamanaQueryParam, ) @@ -53,6 +54,7 @@ "IndexOption", "InvertIndexParam", "OptimizeOption", + "QuantizerParam", "VamanaIndexParam", "VamanaQueryParam", ] diff --git a/python/zvec/model/param/__init__.pyi b/python/zvec/model/param/__init__.pyi index 759b41348..86c517e41 100644 --- a/python/zvec/model/param/__init__.pyi +++ b/python/zvec/model/param/__init__.pyi @@ -24,6 +24,7 @@ __all__: list[str] = [ "IndexParam", "InvertIndexParam", "OptimizeOption", + "QuantizerParam", "QueryParam", "SegmentOption", "VectorIndexParam", @@ -145,6 +146,8 @@ class FlatIndexParam(VectorIndexParam): quantize_type (QuantizeType): Optional quantization type for vector compression (e.g., FP16, INT8). Use ``QuantizeType.UNDEFINED`` to disable quantization. Default is ``QuantizeType.UNDEFINED``. + quantizer_param (QuantizerParam): Quantizer configuration (e.g., enable_rotate). + Default is ``QuantizerParam()``. Examples: >>> from zvec.typing import MetricType, QuantizeType @@ -161,6 +164,7 @@ class FlatIndexParam(VectorIndexParam): self, metric_type: _zvec.typing.MetricType = ..., quantize_type: _zvec.typing.QuantizeType = ..., + quantizer_param: QuantizerParam = ..., ) -> None: """ Constructs a FlatIndexParam instance. @@ -169,6 +173,8 @@ class FlatIndexParam(VectorIndexParam): metric_type (MetricType, optional): Distance metric. Defaults to MetricType.IP. quantize_type (QuantizeType, optional): Vector quantization type. Defaults to QuantizeType.UNDEFINED (no quantization). + quantizer_param (QuantizerParam, optional): Quantizer configuration. + Defaults to QuantizerParam(). """ def __repr__(self) -> str: ... @@ -224,6 +230,7 @@ class HnswIndexParam(VectorIndexParam): ef_construction: typing.SupportsInt = 500, quantize_type: _zvec.typing.QuantizeType = ..., use_contiguous_memory: bool = False, + quantizer_param: QuantizerParam = ..., ) -> None: ... def __repr__(self) -> str: ... def __setstate__(self, arg0: tuple) -> None: ... @@ -496,6 +503,7 @@ class IVFIndexParam(VectorIndexParam): n_iters: typing.SupportsInt = 10, use_soar: bool = False, quantize_type: _zvec.typing.QuantizeType = ..., + quantizer_param: QuantizerParam = ..., ) -> None: """ Constructs an IVFIndexParam instance. @@ -509,6 +517,8 @@ class IVFIndexParam(VectorIndexParam): use_soar (bool, optional): Enable SOAR optimization. Defaults to False. quantize_type (QuantizeType, optional): Vector quantization type. Defaults to QuantizeType.UNDEFINED. + quantizer_param (QuantizerParam, optional): Quantizer configuration. + Defaults to QuantizerParam(). """ def __repr__(self) -> str: ... @@ -912,6 +922,49 @@ class SegmentOption: bool: Whether the segment is read-only. """ +class QuantizerParam: + """ + + Parameters for quantizer configuration. + + Encapsulates quantization-related settings such as enable_rotate. + Designed for future extensibility. + + Attributes: + enable_rotate (bool): Whether to apply random rotation before INT8 + quantization to reduce quantization error. + Only effective with quantize_type=INT8. Defaults to False. + + Examples: + >>> qp = QuantizerParam(enable_rotate=True) + >>> print(qp.enable_rotate) + True + """ + + def __getstate__(self) -> tuple: ... + def __init__(self, enable_rotate: bool = False) -> None: + """ + Constructs a QuantizerParam instance. + + Args: + enable_rotate (bool, optional): Whether to apply random rotation + before INT8 quantization. Defaults to False. + """ + + def __repr__(self) -> str: ... + def __setstate__(self, arg0: tuple) -> None: ... + def __eq__(self, arg0: typing.Any) -> bool: ... + def to_dict(self) -> dict: + """ + Convert to dictionary with all fields + """ + + @property + def enable_rotate(self) -> bool: + """ + bool: Whether random rotation is enabled before INT8 quantization. + """ + class VectorIndexParam(IndexParam): """ @@ -923,6 +976,7 @@ class VectorIndexParam(IndexParam): type (IndexType): The specific vector index type (e.g., HNSW, FLAT). metric_type (MetricType): Distance metric used for similarity search. quantize_type (QuantizeType): Optional vector quantization type. + quantizer_param (QuantizerParam): Quantizer configuration (e.g., enable_rotate). """ def __getstate__(self) -> tuple: ... @@ -944,6 +998,12 @@ class VectorIndexParam(IndexParam): QuantizeType: Vector quantization type (e.g., FP16, INT8). """ + @property + def quantizer_param(self) -> QuantizerParam: + """ + QuantizerParam: Quantizer configuration including enable_rotate. + """ + class _SearchQuery: field_name: str filter: str diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5fc4e34ea..6196aedfa 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -64,6 +64,7 @@ function(zvec_add_all_in_one_shared TARGET_NAME OUTPUT_NAME) ${ZVEC_ALLIN_LIBS} Threads::Threads ) + target_link_options(${TARGET_NAME} PRIVATE /OPT:REF /OPT:ICF) elseif(APPLE) foreach(ZVEC_ALLIN_LIB ${ZVEC_ALLIN_LIBS}) list(APPEND ZVEC_ALLIN_WA_OPTIONS diff --git a/src/binding/c/c_api.cc b/src/binding/c/c_api.cc index a81cc3864..6c7f226db 100644 --- a/src/binding/c/c_api.cc +++ b/src/binding/c/c_api.cc @@ -1476,6 +1476,60 @@ zvec_quantize_type_t zvec_index_params_get_quantize_type( return ZVEC_QUANTIZE_TYPE_UNDEFINED; } +/** + * @brief Set enable_rotate for quantizer parameters + * @param params Index parameters (must be vector index type) + * @param enable_rotate Whether to enable random rotation before quantization + * @return ZVEC_OK on success, error code on failure + */ +zvec_error_code_t zvec_index_params_set_quantizer_enable_rotate( + zvec_index_params_t *params, bool enable_rotate) { + if (!params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Index params pointer cannot be null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *cpp_params = reinterpret_cast(params); + + if (!cpp_params->is_vector_index_type()) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Index params is not a vector index type"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *vec_params = dynamic_cast(cpp_params); + if (!vec_params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Failed to cast to VectorIndexParams"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + zvec::QuantizerParam qp = vec_params->quantizer_param(); + qp.set_enable_rotate(enable_rotate); + vec_params->set_quantizer_param(qp); + return ZVEC_OK; +} + +/** + * @brief Get enable_rotate setting from quantizer parameters + * @param params Index parameters + * @return true if rotation is enabled, false otherwise + */ +bool zvec_index_params_get_quantizer_enable_rotate( + const zvec_index_params_t *params) { + if (!params) { + return false; + } + auto *cpp_params = reinterpret_cast(params); + + if (cpp_params->is_vector_index_type()) { + auto *vec_params = + dynamic_cast(cpp_params); + if (vec_params) { + return vec_params->quantizer_param().enable_rotate(); + } + } + return false; +} + /** * @brief Get index type from index parameters * @param params Index parameters diff --git a/src/binding/python/model/param/python_param.cc b/src/binding/python/model/param/python_param.cc index 268214dbd..80dc203af 100644 --- a/src/binding/python/model/param/python_param.cc +++ b/src/binding/python/model/param/python_param.cc @@ -338,6 +338,61 @@ Constructs an FtsIndexParam instance. t[2].cast()); })); + // binding QuantizerParam + py::class_> quantizer_param( + m, "QuantizerParam", R"pbdoc( +Parameters for quantizer configuration. + +Encapsulates quantization-related settings such as enable_rotate. +Designed for future extensibility. + +Attributes: + enable_rotate (bool): Whether to apply random rotation before INT8 + quantization to reduce quantization error. + Only effective with quantize_type=INT8. Defaults to False. + +Examples: + >>> qp = QuantizerParam(enable_rotate=True) + >>> print(qp.enable_rotate) + True +)pbdoc"); + quantizer_param.def(py::init(), py::arg("enable_rotate") = false) + .def_property_readonly( + "enable_rotate", + [](const QuantizerParam &self) -> bool { + return self.enable_rotate(); + }, + "bool: Whether random rotation is enabled before INT8 quantization.") + .def( + "to_dict", + [](const QuantizerParam &self) -> py::dict { + py::dict dict; + dict["enable_rotate"] = self.enable_rotate(); + return dict; + }, + "Convert to dictionary with all fields") + .def("__repr__", + [](const QuantizerParam &self) -> std::string { + return "{\"enable_rotate\":" + + std::string(self.enable_rotate() ? "true" : "false") + "}"; + }) + .def( + "__eq__", + [](const QuantizerParam &self, const py::object &other) { + if (!py::isinstance(other)) return false; + return self == other.cast(); + }, + py::is_operator()) + .def(py::pickle( + [](const QuantizerParam &self) { + return py::make_tuple(self.enable_rotate()); + }, + [](py::tuple t) { + if (t.size() != 1) + throw std::runtime_error("Invalid state for QuantizerParam"); + return std::make_shared(t[0].cast()); + })); + // binding base vector index params py::class_> vector_params(m, "VectorIndexParam", R"pbdoc( @@ -349,6 +404,7 @@ Encapsulates common settings for all vector index types. type (IndexType): The specific vector index type (e.g., HNSW, FLAT). metric_type (MetricType): Distance metric used for similarity search. quantize_type (QuantizeType): Optional vector quantization type. + quantizer_param (QuantizerParam): Quantizer configuration (e.g., enable_rotate). )pbdoc"); vector_params .def_property_readonly( @@ -363,6 +419,12 @@ Encapsulates common settings for all vector index types. return self.quantize_type(); }, "QuantizeType: Vector quantization type (e.g., FP16, INT8).") + .def_property_readonly( + "quantizer_param", + [](const VectorIndexParams &self) -> QuantizerParam { + return self.quantizer_param(); + }, + "QuantizerParam: Quantizer configuration including enable_rotate.") .def( "to_dict", [](const VectorIndexParams &self) -> py::dict { @@ -371,6 +433,9 @@ Encapsulates common settings for all vector index types. dict["metric_type"] = metric_type_to_string(self.metric_type()); dict["quantize_type"] = quantize_type_to_string(self.quantize_type()); + py::dict qp_dict; + qp_dict["enable_rotate"] = self.quantizer_param().enable_rotate(); + dict["quantizer_param"] = qp_dict; return dict; }, "Convert to dictionary with all fields") @@ -382,7 +447,7 @@ Encapsulates common settings for all vector index types. [](py::tuple t) { // __setstate__ if (t.size() != 3) throw std::runtime_error("Invalid state for VectorIndexParams"); - // 基类,不能直接实例化,用于子类 + // Base class, cannot instantiate directly, used by subclasses return std::shared_ptr(); })); @@ -421,13 +486,20 @@ encapsulates its construction hyperparameters. {'metric_type': 'IP', 'm': 16, 'ef_construction': 200, 'quantize_type': 'INT8', 'use_contiguous_memory': True} )pbdoc"); hnsw_params - .def(py::init(), + .def(py::init([](MetricType metric_type, int m, int ef_construction, + QuantizeType quantize_type, bool use_contiguous_memory, + QuantizerParam quantizer_param) { + return std::make_shared( + metric_type, m, ef_construction, quantize_type, + use_contiguous_memory, quantizer_param); + }), py::arg("metric_type") = MetricType::IP, py::arg("m") = core_interface::kDefaultHnswNeighborCnt, py::arg("ef_construction") = core_interface::kDefaultHnswEfConstruction, py::arg("quantize_type") = QuantizeType::UNDEFINED, - py::arg("use_contiguous_memory") = false) + py::arg("use_contiguous_memory") = false, + py::arg("quantizer_param") = QuantizerParam()) .def_property_readonly( "m", &HnswIndexParams::m, "int: Maximum number of neighbors per node in upper layers.") @@ -450,34 +522,43 @@ encapsulates its construction hyperparameters. dict["quantize_type"] = quantize_type_to_string(self.quantize_type()); dict["use_contiguous_memory"] = self.use_contiguous_memory(); + py::dict qp_dict; + qp_dict["enable_rotate"] = self.quantizer_param().enable_rotate(); + dict["quantizer_param"] = qp_dict; return dict; }, "Convert to dictionary with all fields") - .def("__repr__", - [](const HnswIndexParams &self) -> std::string { - return "{" - "\"metric_type\":" + - metric_type_to_string(self.metric_type()) + - ", \"m\":" + std::to_string(self.m()) + - ", \"ef_construction\":" + - std::to_string(self.ef_construction()) + - ", \"quantize_type\":" + - quantize_type_to_string(self.quantize_type()) + - ", \"use_contiguous_memory\":" + - (self.use_contiguous_memory() ? "true" : "false") + "}"; - }) + .def( + "__repr__", + [](const HnswIndexParams &self) -> std::string { + return "{" + "\"metric_type\":" + + metric_type_to_string(self.metric_type()) + + ", \"m\":" + std::to_string(self.m()) + + ", \"ef_construction\":" + + std::to_string(self.ef_construction()) + + ", \"quantize_type\":" + + quantize_type_to_string(self.quantize_type()) + + ", \"use_contiguous_memory\":" + + (self.use_contiguous_memory() ? "true" : "false") + + ", \"quantizer_param\":{" + "\"enable_rotate\":" + + (self.quantizer_param().enable_rotate() ? "true" : "false") + + "}}"; + }) .def(py::pickle( [](const HnswIndexParams &self) { return py::make_tuple(self.metric_type(), self.m(), self.ef_construction(), self.quantize_type(), - self.use_contiguous_memory()); + self.use_contiguous_memory(), + self.quantizer_param().enable_rotate()); }, [](py::tuple t) { - if (t.size() != 5) + if (t.size() != 5 && t.size() != 6) throw std::runtime_error("Invalid state for HnswIndexParams"); + QuantizerParam qp(t.size() >= 6 ? t[5].cast() : false); return std::make_shared( t[0].cast(), t[1].cast(), t[2].cast(), - t[3].cast(), t[4].cast()); + t[3].cast(), t[4].cast(), qp); })); // binding hnsw rabitq index params @@ -626,8 +707,16 @@ its construction hyperparameters. ... ) )pbdoc"); vamana_params - .def(py::init(), + .def(py::init([](MetricType metric_type, int max_degree, + int search_list_size, float alpha, bool saturate_graph, + bool use_contiguous_memory, bool use_id_map, + QuantizeType quantize_type, + QuantizerParam quantizer_param) { + return std::make_shared( + metric_type, max_degree, search_list_size, alpha, + saturate_graph, use_contiguous_memory, use_id_map, + quantize_type, quantizer_param); + }), py::arg("metric_type") = MetricType::IP, py::arg("max_degree") = core_interface::kDefaultVamanaMaxDegree, py::arg("search_list_size") = @@ -637,7 +726,8 @@ its construction hyperparameters. core_interface::kDefaultVamanaSaturateGraph, py::arg("use_contiguous_memory") = false, py::arg("use_id_map") = false, - py::arg("quantize_type") = QuantizeType::UNDEFINED) + py::arg("quantize_type") = QuantizeType::UNDEFINED, + py::arg("quantizer_param") = QuantizerParam()) .def_property_readonly( "max_degree", &VamanaIndexParams::max_degree, "int: Maximum out-degree (R) of every node in the Vamana graph.") @@ -673,45 +763,53 @@ its construction hyperparameters. dict["use_id_map"] = self.use_id_map(); dict["quantize_type"] = quantize_type_to_string(self.quantize_type()); + py::dict qp_dict; + qp_dict["enable_rotate"] = self.quantizer_param().enable_rotate(); + dict["quantizer_param"] = qp_dict; return dict; }, "Convert to dictionary with all fields") - .def("__repr__", - [](const VamanaIndexParams &self) -> std::string { - return "{" - "\"type\":\"" + - index_type_to_string(self.type()) + - "\", \"metric_type\":\"" + - metric_type_to_string(self.metric_type()) + - "\", \"max_degree\":" + std::to_string(self.max_degree()) + - ", \"search_list_size\":" + - std::to_string(self.search_list_size()) + - ", \"alpha\":" + std::to_string(self.alpha()) + - ", \"saturate_graph\":" + - std::string(self.saturate_graph() ? "true" : "false") + - ", \"use_contiguous_memory\":" + - std::string(self.use_contiguous_memory() ? "true" - : "false") + - ", \"use_id_map\":" + - std::string(self.use_id_map() ? "true" : "false") + - ", \"quantize_type\":\"" + - quantize_type_to_string(self.quantize_type()) + "\"}"; - }) + .def( + "__repr__", + [](const VamanaIndexParams &self) -> std::string { + return "{" + "\"type\":\"" + + index_type_to_string(self.type()) + + "\", \"metric_type\":\"" + + metric_type_to_string(self.metric_type()) + + "\", \"max_degree\":" + std::to_string(self.max_degree()) + + ", \"search_list_size\":" + + std::to_string(self.search_list_size()) + + ", \"alpha\":" + std::to_string(self.alpha()) + + ", \"saturate_graph\":" + + std::string(self.saturate_graph() ? "true" : "false") + + ", \"use_contiguous_memory\":" + + std::string(self.use_contiguous_memory() ? "true" + : "false") + + ", \"use_id_map\":" + + std::string(self.use_id_map() ? "true" : "false") + + ", \"quantize_type\":\"" + + quantize_type_to_string(self.quantize_type()) + + "\", \"quantizer_param\":{" + "\"enable_rotate\":" + + (self.quantizer_param().enable_rotate() ? "true" : "false") + + "}}"; + }) .def(py::pickle( [](const VamanaIndexParams &self) { - return py::make_tuple(self.metric_type(), self.max_degree(), - self.search_list_size(), self.alpha(), - self.saturate_graph(), - self.use_contiguous_memory(), - self.use_id_map(), self.quantize_type()); + return py::make_tuple( + self.metric_type(), self.max_degree(), self.search_list_size(), + self.alpha(), self.saturate_graph(), + self.use_contiguous_memory(), self.use_id_map(), + self.quantize_type(), self.quantizer_param().enable_rotate()); }, [](py::tuple t) { - if (t.size() != 8) + if (t.size() != 8 && t.size() != 9) throw std::runtime_error("Invalid state for VamanaIndexParams"); + QuantizerParam qp(t.size() >= 9 ? t[8].cast() : false); return std::make_shared( t[0].cast(), t[1].cast(), t[2].cast(), t[3].cast(), t[4].cast(), t[5].cast(), - t[6].cast(), t[7].cast()); + t[6].cast(), t[7].cast(), qp); })); // FlatIndexParams @@ -741,9 +839,14 @@ suitable for small to medium datasets or as a baseline. {'metric_type': 'L2', 'quantize_type': 'FP16'} )pbdoc"); flat_params - .def(py::init(), + .def(py::init([](MetricType metric_type, QuantizeType quantize_type, + QuantizerParam quantizer_param) { + return std::make_shared( + metric_type, quantize_type, quantizer_param); + }), py::arg("metric_type") = MetricType::IP, py::arg("quantize_type") = QuantizeType::UNDEFINED, + py::arg("quantizer_param") = QuantizerParam(), R"pbdoc( Constructs a FlatIndexParam instance. @@ -751,6 +854,8 @@ Constructs a FlatIndexParam instance. metric_type (MetricType, optional): Distance metric. Defaults to MetricType.IP. quantize_type (QuantizeType, optional): Vector quantization type. Defaults to QuantizeType.UNDEFINED (no quantization). + quantizer_param (QuantizerParam, optional): Quantizer configuration. + Defaults to QuantizerParam(). )pbdoc") .def( "to_dict", @@ -759,26 +864,35 @@ Constructs a FlatIndexParam instance. dict["metric_type"] = metric_type_to_string(self.metric_type()); dict["quantize_type"] = quantize_type_to_string(self.quantize_type()); + py::dict qp_dict; + qp_dict["enable_rotate"] = self.quantizer_param().enable_rotate(); + dict["quantizer_param"] = qp_dict; return dict; }, "Convert to dictionary with all fields") - .def("__repr__", - [](const FlatIndexParams &self) -> std::string { - return "{" - "\"metric_type\":" + - metric_type_to_string(self.metric_type()) + - ", \"quantize_type\":" + - quantize_type_to_string(self.quantize_type()) + "}"; - }) + .def( + "__repr__", + [](const FlatIndexParams &self) -> std::string { + return "{" + "\"metric_type\":" + + metric_type_to_string(self.metric_type()) + + ", \"quantize_type\":" + + quantize_type_to_string(self.quantize_type()) + + ", \"quantizer_param\":{" + "\"enable_rotate\":" + + (self.quantizer_param().enable_rotate() ? "true" : "false") + + "}}"; + }) .def(py::pickle( [](const FlatIndexParams &self) { - return py::make_tuple(self.metric_type(), self.quantize_type()); + return py::make_tuple(self.metric_type(), self.quantize_type(), + self.quantizer_param().enable_rotate()); }, [](py::tuple t) { - if (t.size() != 2) + if (t.size() != 2 && t.size() != 3) throw std::runtime_error("Invalid state for FlatIndexParams"); - return std::make_shared(t[0].cast(), - t[1].cast()); + QuantizerParam qp(t.size() >= 3 ? t[2].cast() : false); + return std::make_shared( + t[0].cast(), t[1].cast(), qp); })); // IVFIndexParams @@ -815,10 +929,17 @@ and accuracy. 100 )pbdoc"); ivf_params - .def(py::init(), + .def(py::init([](MetricType metric_type, int n_list, int n_iters, + bool use_soar, QuantizeType quantize_type, + QuantizerParam quantizer_param) { + return std::make_shared( + metric_type, n_list, n_iters, use_soar, quantize_type, + quantizer_param); + }), py::arg("metric_type") = MetricType::IP, py::arg("n_list") = 10, py::arg("n_iters") = 10, py::arg("use_soar") = false, py::arg("quantize_type") = QuantizeType::UNDEFINED, + py::arg("quantizer_param") = QuantizerParam(), R"pbdoc( Constructs an IVFIndexParam instance. @@ -831,6 +952,8 @@ Constructs an IVFIndexParam instance. use_soar (bool, optional): Enable SOAR optimization. Defaults to False. quantize_type (QuantizeType, optional): Vector quantization type. Defaults to QuantizeType.UNDEFINED. + quantizer_param (QuantizerParam, optional): Quantizer configuration. + Defaults to QuantizerParam(). )pbdoc") .def_property_readonly("n_list", &IVFIndexParams::n_list, "int: Number of inverted lists.") @@ -850,32 +973,41 @@ Constructs an IVFIndexParam instance. dict["use_soar"] = self.use_soar(); dict["quantize_type"] = quantize_type_to_string(self.quantize_type()); + py::dict qp_dict; + qp_dict["enable_rotate"] = self.quantizer_param().enable_rotate(); + dict["quantizer_param"] = qp_dict; return dict; }, "Convert to dictionary with all fields") - .def("__repr__", - [](const IVFIndexParams &self) { - return "{" - "\"metric_type\":" + - metric_type_to_string(self.metric_type()) + - ", \"n_list\":" + std::to_string(self.n_list()) + - ", \"n_iters\":" + std::to_string(self.n_iters()) + - ", \"use_soar\":" + std::to_string(self.use_soar()) + - ", \"quantize_type\":" + - quantize_type_to_string(self.quantize_type()) + "}"; - }) + .def( + "__repr__", + [](const IVFIndexParams &self) { + return "{" + "\"metric_type\":" + + metric_type_to_string(self.metric_type()) + + ", \"n_list\":" + std::to_string(self.n_list()) + + ", \"n_iters\":" + std::to_string(self.n_iters()) + + ", \"use_soar\":" + std::to_string(self.use_soar()) + + ", \"quantize_type\":" + + quantize_type_to_string(self.quantize_type()) + + ", \"quantizer_param\":{" + "\"enable_rotate\":" + + (self.quantizer_param().enable_rotate() ? "true" : "false") + + "}}"; + }) .def(py::pickle( [](const IVFIndexParams &self) { return py::make_tuple(self.metric_type(), self.n_list(), self.n_iters(), self.use_soar(), - self.quantize_type()); + self.quantize_type(), + self.quantizer_param().enable_rotate()); }, [](py::tuple t) { - if (t.size() != 5) + if (t.size() != 5 && t.size() != 6) throw std::runtime_error("Invalid state for IVFIndexParams"); + QuantizerParam qp(t.size() >= 6 ? t[5].cast() : false); return std::make_shared( t[0].cast(), t[1].cast(), t[2].cast(), - t[3].cast(), t[4].cast()); + t[3].cast(), t[4].cast(), qp); })); // DiskAnnIndexParams @@ -915,10 +1047,17 @@ only compressed vector will be loaded into memory. By this way, search memory at 100 )pbdoc"); diskann_params - .def(py::init(), + .def(py::init([](MetricType metric_type, int max_degree, int list_size, + int pq_chunk_num, QuantizeType quantize_type, + QuantizerParam quantizer_param) { + return std::make_shared( + metric_type, max_degree, list_size, pq_chunk_num, + quantize_type, quantizer_param); + }), py::arg("metric_type") = MetricType::IP, py::arg("max_degree") = 100, py::arg("list_size") = 50, py::arg("pq_chunk_num") = 0, py::arg("quantize_type") = QuantizeType::UNDEFINED, + py::arg("quantizer_param") = QuantizerParam(), R"pbdoc( Constructs an DiskAnnIndexParams instance. @@ -933,6 +1072,8 @@ Constructs an DiskAnnIndexParams instance. Clamped to [1, 1024]. Defaults to 0. quantize_type (QuantizeType, optional): Vector quantization type. Defaults to QuantizeType.UNDEFINED. + quantizer_param (QuantizerParam, optional): Quantizer configuration. + Defaults to QuantizerParam(). )pbdoc") .def_property_readonly("max_degree", &DiskAnnIndexParams::max_degree, "int: max node degree.") @@ -955,6 +1096,9 @@ Constructs an DiskAnnIndexParams instance. dict["pq_chunk_num"] = self.pq_chunk_num(); dict["quantize_type"] = quantize_type_to_string(self.quantize_type()); + py::dict qp_dict; + qp_dict["enable_rotate"] = self.quantizer_param().enable_rotate(); + dict["quantizer_param"] = qp_dict; return dict; }, "Convert to dictionary with all fields") @@ -968,20 +1112,25 @@ Constructs an DiskAnnIndexParams instance. ", \"list_size\":" + std::to_string(self.list_size()) + ", \"pq_chunk_num\":" + std::to_string(self.pq_chunk_num()) + ", \"quantize_type\":" + - quantize_type_to_string(self.quantize_type()) + "}"; + quantize_type_to_string(self.quantize_type()) + + ", \"quantizer_param\":{" + "\"enable_rotate\":" + + (self.quantizer_param().enable_rotate() ? "true" : "false") + + "}}"; }) .def(py::pickle( [](const DiskAnnIndexParams &self) { return py::make_tuple(self.metric_type(), self.max_degree(), self.list_size(), self.pq_chunk_num(), - self.quantize_type()); + self.quantize_type(), + self.quantizer_param().enable_rotate()); }, [](py::tuple t) { - if (t.size() != 5) + if (t.size() != 5 && t.size() != 6) throw std::runtime_error("Invalid state for DiskAnnIndexParams"); + QuantizerParam qp(t.size() >= 6 ? t[5].cast() : false); return std::make_shared( t[0].cast(), t[1].cast(), t[2].cast(), - t[3].cast(), t[4].cast()); + t[3].cast(), t[4].cast(), qp); })); } diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 6405c3220..cac3d1840 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -12,6 +12,7 @@ if(RABITQ_SUPPORTED AND AUTO_DETECT_ARCH) ) set(HNSW_RABITQ_FILES_FULL ${HNSW_RABITQ_FILES}) list(TRANSFORM HNSW_RABITQ_FILES_FULL PREPEND "algorithm/hnsw_rabitq/") + foreach(FILE ${HNSW_RABITQ_FILES_FULL}) set_source_files_properties( ${FILE} @@ -42,6 +43,26 @@ if(NOT ANDROID AND AUTO_DETECT_ARCH) endif() endif() +# quantizer/record_rotator.cc uses FFHT AVX inline assembly guarded by +# __AVX2__/__AVX512F__. zvec_core glob-collects this source, so per-file +# AVX2 flags must be set here as well (in addition to the core_quantizer +# target in quantizer/CMakeLists.txt). Without this the zvec_core copy +# would compile without AVX2 and the fast path would never be emitted. +if(NOT ANDROID AND AUTO_DETECT_ARCH) + if(HOST_ARCH MATCHES "^(x86|x64)$") + setup_compiler_march_for_x86( + _ROTATOR_MARCH_SSE _ROTATOR_MARCH_AVX2 + _ROTATOR_MARCH_AVX512 _ROTATOR_MARCH_AVX512FP16) + if(_ROTATOR_MARCH_AVX2) + set_source_files_properties( + quantizer/record_rotator.cc + PROPERTIES + COMPILE_FLAGS "${_ROTATOR_MARCH_AVX2}" + ) + endif() + endif() +endif() + cc_directory(framework) cc_directory(algorithm) cc_directory(metric) diff --git a/src/core/algorithm/ivf/ivf_entity.cc b/src/core/algorithm/ivf/ivf_entity.cc index 6dccc2b2c..decc86d22 100644 --- a/src/core/algorithm/ivf/ivf_entity.cc +++ b/src/core/algorithm/ivf/ivf_entity.cc @@ -71,6 +71,16 @@ int IVFEntity::IVFReformerWrapper::init(const IndexMeta &imeta) { return 0; } +//! Load reformer state (e.g. rotation matrix) from storage +int IVFEntity::IVFReformerWrapper::load(const IndexStorage::Pointer &storage) { + if (!reformer_) { + return 0; + } + int ret = reformer_->load(storage); + ivf_check_with_msg(ret, "Failed to load reformer state"); + return 0; +} + //! Update the params, Called by gpu searcher only int IVFEntity::IVFReformerWrapper::update(const IndexMeta &meta) { auto &name = meta.reformer_name(); @@ -503,6 +513,12 @@ int IVFEntity::load(const IndexStorage::Pointer &container) { //! Load the remaining segments container_ = container; + + //! Load reformer state (e.g. rotation matrix) from the main container, + //! which holds the rotator segment dumped at build time. + ret = reformer_.load(container); + ivf_check_error_code(ret); + size_t expect_size = header_.inverted_body_size; inverted_ = load_segment(IVF_INVERTED_BODY_SEG_ID, expect_size); if (!inverted_) { diff --git a/src/core/algorithm/ivf/ivf_entity.h b/src/core/algorithm/ivf/ivf_entity.h index e6fd4b6c4..e0265b6eb 100644 --- a/src/core/algorithm/ivf/ivf_entity.h +++ b/src/core/algorithm/ivf/ivf_entity.h @@ -267,6 +267,9 @@ class IVFEntity { //! Initialize int init(const IndexMeta &imeta); + //! Load reformer state (e.g. rotation matrix) from storage + int load(const IndexStorage::Pointer &storage); + //! Update int update(const IndexMeta &meta); diff --git a/src/core/algorithm/ivf/ivf_params.h b/src/core/algorithm/ivf/ivf_params.h index a33a7aa50..6cd66b474 100644 --- a/src/core/algorithm/ivf/ivf_params.h +++ b/src/core/algorithm/ivf/ivf_params.h @@ -62,6 +62,8 @@ static const std::string PARAM_IVF_BUILDER_BLOCK_VECTOR_COUNT( // searcher params static const std::string PARAM_IVF_SEARCHER_SCAN_RATIO( "proxima.ivf.searcher.scan_ratio"); +static const std::string PARAM_IVF_SEARCHER_NPROBE( + "proxima.ivf.searcher.nprobe"); static const std::string PARAM_IVF_SEARCHER_BRUTE_FORCE_THRESHOLD( "proxima.ivf.searcher.brute_force_threshold"); static const std::string PARAM_IVF_SEARCHER_OPTIMIZER( diff --git a/src/core/algorithm/ivf/ivf_searcher.cc b/src/core/algorithm/ivf/ivf_searcher.cc index 972fc8680..047046701 100644 --- a/src/core/algorithm/ivf/ivf_searcher.cc +++ b/src/core/algorithm/ivf/ivf_searcher.cc @@ -86,6 +86,13 @@ int IVFSearcher::load(IndexStorage::Pointer container, } auto reformer = centroid_index_->reformer(); + if (reformer) { + //! The centroid index is loaded from the centroid sub-segment which does + //! not contain the rotator segment. Load the reformer state (e.g. rotation + //! matrix) from the main container instead. + ret = reformer->load(container); + ivf_check_error_code(ret); + } params_.set(PARAM_IVF_SEARCHER_CONVERTER_REFORMER, reformer); //! load iverted index diff --git a/src/core/algorithm/ivf/ivf_searcher_context.h b/src/core/algorithm/ivf/ivf_searcher_context.h index d9ccc45c1..dbd2b7ae1 100644 --- a/src/core/algorithm/ivf/ivf_searcher_context.h +++ b/src/core/algorithm/ivf/ivf_searcher_context.h @@ -62,18 +62,27 @@ class IVFSearcherContext : public IndexSearcher::Context { params.get(PARAM_IVF_SEARCHER_BRUTE_FORCE_THRESHOLD, &bruteforce_threshold_); params.get(PARAM_IVF_SEARCHER_SCAN_RATIO, &scan_ratio_); + params.get(PARAM_IVF_SEARCHER_NPROBE, &nprobe_); if (scan_ratio_ <= 0.0) { LOG_ERROR("Invalid params %s=%f", PARAM_IVF_SEARCHER_SCAN_RATIO.c_str(), scan_ratio_); return IndexError_InvalidArgument; } - size_t topk_val = - std::max(static_cast( - std::round(entity_->inverted_list_count() * scan_ratio_)), - 1u); - centroid_searcher_ctx_->set_topk(topk_val); - max_scan_count_ = - static_cast(std::ceil(entity_->vector_count() * scan_ratio_)); + size_t nlist = entity_->inverted_list_count(); + size_t topk_val; + if (nprobe_ > 0) { + //! nprobe explicitly controls how many inverted lists (centroids) to + //! probe. Do not let max_scan_count_ cut off the probed lists. + topk_val = std::min(static_cast(nprobe_), nlist); + topk_val = std::max(topk_val, static_cast(1)); + max_scan_count_ = static_cast(entity_->vector_count()); + } else { + topk_val = + std::max(static_cast(std::round(nlist * scan_ratio_)), 1u); + max_scan_count_ = static_cast( + std::ceil(entity_->vector_count() * scan_ratio_)); + } + centroid_searcher_ctx_->set_topk(static_cast(topk_val)); max_scan_count_ = std::max(bruteforce_threshold_, max_scan_count_); return 0; } @@ -215,6 +224,7 @@ class IVFSearcherContext : public IndexSearcher::Context { uint32_t topk_{0}; uint32_t magic_{0}; float scan_ratio_{kDefaultScanRatio}; + int nprobe_{0}; uint32_t max_scan_count_{0}; uint32_t bruteforce_threshold_{kDefaultBfThreshold}; }; diff --git a/src/core/algorithm/ivf/ivf_streamer.cc b/src/core/algorithm/ivf/ivf_streamer.cc index a2c924141..e42728e9a 100644 --- a/src/core/algorithm/ivf/ivf_streamer.cc +++ b/src/core/algorithm/ivf/ivf_streamer.cc @@ -86,6 +86,13 @@ int IVFStreamer::open(IndexStorage::Pointer storage) { } auto reformer = centroid_index_->reformer(); + if (reformer) { + //! The centroid index is loaded from the centroid sub-segment which does + //! not contain the rotator segment. Load the reformer state (e.g. rotation + //! matrix) from the main storage instead. + ret = reformer->load(storage); + ivf_check_error_code(ret); + } params_.set(PARAM_IVF_SEARCHER_CONVERTER_REFORMER, reformer); //! load iverted index diff --git a/src/core/interface/index.cc b/src/core/interface/index.cc index 332d4526f..4cb266d3a 100644 --- a/src/core/interface/index.cc +++ b/src/core/interface/index.cc @@ -182,6 +182,21 @@ int Index::CreateAndInitConverterReformer(const QuantizerParam ¶m, } } + // Pass enable_rotate to converter_params (only effective for INT8) + if (param.enable_rotate) { + if (param.type == QuantizerType::kInt8) { + if (index_param.metric_type == MetricType::kCosine) { + converter_params.set("cosine.converter.enable_rotate", true); + } else { + converter_params.set("integer_streaming.converter.enable_rotate", true); + } + } else { + LOG_WARN( + "enable_rotate is only supported for INT8 quantizer, " + "ignoring for current quantizer type"); + } + } + proxima_index_meta_.set_converter(converter_name, 0, converter_params); converter_ = core::IndexFactory::CreateConverter(converter_name); if (converter_ == nullptr || @@ -336,6 +351,21 @@ int Index::Open(const std::string &file_path, StorageOptions storage_options) { // converter/reformer/metric are created in IndexFactory::CreateIndex // TODO: init + // Load reformer data from storage (e.g., rotation matrix for + // IntegerStreaming) + if (reformer_ != nullptr) { + // When building a new index, dump converter state (e.g., rotator) to + // storage so the reformer can load it. This is needed for + // enable_rotate with INT8 quantization. + if (storage_options.create_new && converter_ != nullptr) { + converter_->dump_to_storage(storage_); + } + if (reformer_->load(storage_) != 0) { + LOG_ERROR("Failed to load reformer, path: %s", file_path.c_str()); + return core::IndexError_Runtime; + } + } + // TODO: context pool if (!init_context()) { // to validate if any error, will be overwritten LOG_ERROR("Failed to init context"); diff --git a/src/core/interface/index_param.cc b/src/core/interface/index_param.cc index 9226eeed0..5d75276fd 100644 --- a/src/core/interface/index_param.cc +++ b/src/core/interface/index_param.cc @@ -251,12 +251,16 @@ ailego::JsonObject QuantizerParam::SerializeToJsonObject( json_obj.set("type", zvec::ailego::JsonValue(magic_enum::enum_name(type).data())); } + if (!omit_empty_value || enable_rotate) { + json_obj.set("enable_rotate", ailego::JsonValue(enable_rotate)); + } return json_obj; } bool QuantizerParam::DeserializeFromJsonObject( const ailego::JsonObject &json_obj) { DESERIALIZE_ENUM_FIELD(json_obj, type, QuantizerType); + DESERIALIZE_VALUE_FIELD(json_obj, enable_rotate); return true; } diff --git a/src/core/interface/indexes/ivf_index.cc b/src/core/interface/indexes/ivf_index.cc index d85acce62..6bd793b2a 100644 --- a/src/core/interface/indexes/ivf_index.cc +++ b/src/core/interface/indexes/ivf_index.cc @@ -121,6 +121,10 @@ int IVFIndex::Open(const std::string &file_path, LOG_ERROR("Failed to open streamer, path: %s", file_path_.c_str()); return core::IndexError_Runtime; } + // Load reformer data from storage (e.g., rotation matrix for INT8+rotate) + if (reformer_ != nullptr) { + reformer_->load(storage_); + } is_trained_ = true; } is_open_ = true; @@ -164,6 +168,10 @@ int IVFIndex::Train() { dumper->create(file_path_); builder_->dump(dumper); + // Dump converter state (e.g., rotator for INT8+rotate) to dumper + if (converter_) { + converter_->dump(dumper); + } dumper->close(); int ret = storage_->open(file_path_, false); if (ret != 0) { @@ -175,6 +183,10 @@ int IVFIndex::Train() { LOG_ERROR("Failed to open streamer, path: %s", file_path_.c_str()); return core::IndexError_Runtime; } + // Load reformer data from storage (e.g., rotation matrix) + if (reformer_ != nullptr) { + reformer_->load(storage_); + } is_trained_ = true; return 0; } @@ -209,11 +221,8 @@ int IVFIndex::_prepare_for_search( } if (ivf_search_param->nprobe > 0) { - // TODO: 1. sparse; 2. default ef ailego::Params params; - // need fix - params.set(core::PARAM_IVF_BUILDER_CENTROID_COUNT, - ivf_search_param->nprobe); + params.set(core::PARAM_IVF_SEARCHER_NPROBE, ivf_search_param->nprobe); context->update(params); } return 0; @@ -229,6 +238,10 @@ int IVFIndex::Merge(const std::vector &indexes, dumper->create(file_path_); builder_->dump(dumper); + // Dump converter state (e.g., rotator for INT8+rotate) to dumper + if (converter_) { + converter_->dump(dumper); + } dumper->close(); int ret = storage_->open(file_path_, false); if (ret != 0) { @@ -240,6 +253,10 @@ int IVFIndex::Merge(const std::vector &indexes, LOG_ERROR("Failed to open streamer, path: %s", file_path_.c_str()); return core::IndexError_Runtime; } + // Load reformer data from storage (e.g., rotation matrix) + if (reformer_ != nullptr) { + reformer_->load(storage_); + } is_trained_ = true; return 0; } diff --git a/src/core/quantizer/CMakeLists.txt b/src/core/quantizer/CMakeLists.txt index f5c9ad898..4c1558735 100644 --- a/src/core/quantizer/CMakeLists.txt +++ b/src/core/quantizer/CMakeLists.txt @@ -6,11 +6,28 @@ if(NOT APPLE) "-Wl,--exclude-libs,libparquet.a:libarrow.a:libarrow_bundled_dependencies.a") endif() +# x86: use AVX2/AVX512 arch flag from RABITQ detection +if(RABITQ_SUPPORTED AND RABITQ_ARCH_FLAG) + set_source_files_properties( + record_rotator.cc + PROPERTIES + COMPILE_FLAGS "${RABITQ_ARCH_FLAG}" + ) +# ARM aarch64: use armv8-a to enable NEON intrinsics +elseif(HOST_ARCH MATCHES "^(arm|arm64)$" AND NOT MSVC) + set_source_files_properties( + record_rotator.cc + PROPERTIES + COMPILE_FLAGS "-march=armv8-a" + ) +endif() + cc_library( NAME core_quantizer STATIC SHARED STRICT ALWAYS_LINK SRCS *.cc - LIBS zvec_ailego zvec_turbo core_framework + LIBS zvec_ailego core_framework + LIBS zvec_ailego zvec_turbo core_framework rabitqlib INCS . ${PROJECT_ROOT_DIR}/src/core LDFLAGS "${CORE_QUANTIZER_LDFLAGS}" VERSION "${PROXIMA_ZVEC_VERSION}" diff --git a/src/core/quantizer/cosine_converter.cc b/src/core/quantizer/cosine_converter.cc index ded1e3eb5..df5231534 100644 --- a/src/core/quantizer/cosine_converter.cc +++ b/src/core/quantizer/cosine_converter.cc @@ -19,6 +19,7 @@ #include #include #include "record_quantizer.h" +#include "record_rotator.h" #include "../metric/metric_params.h" namespace zvec { @@ -54,6 +55,11 @@ class CosineConverterHolder : public IndexHolder { type_ == IndexMeta::DataType::DT_INT8) { buffer_.resize(element_size, 0); } + + // Allocate rotate buffer if owner has a rotator + if (owner_->rotator_) { + rotate_buffer_.resize(owner_->rotator_->dimension()); + } } this->convert_record(); @@ -116,17 +122,27 @@ class CosineConverterHolder : public IndexHolder { original_element_size); float *buf = reinterpret_cast(&normalize_buffer_[0]); + const float *vec = buf; + + // Apply rotation if enabled + if (owner_->rotator_) { + owner_->rotator_->rotate(vec, rotate_buffer_.data()); + vec = rotate_buffer_.data(); + } float norm = 0.0f; - ailego::Normalizer::L2(buf, original_dimension_, &norm); + ailego::Normalizer::L2(const_cast(vec), + original_dimension_, &norm); if (type_ == IndexMeta::DataType::DT_FP32) { + ::memcpy(reinterpret_cast(&normalize_buffer_[0]), vec, + original_dimension_ * sizeof(float)); ::memcpy(reinterpret_cast(&normalize_buffer_[0]) + original_dimension_, &norm, NORM_SIZE); } else if (type_ == IndexMeta::DataType::DT_FP16) { ailego::FloatHelper::ToFP16( - buf, original_dimension_, + const_cast(vec), original_dimension_, reinterpret_cast(&buffer_[0])); ::memcpy( @@ -134,9 +150,8 @@ class CosineConverterHolder : public IndexHolder { &norm, NORM_SIZE); } else if (type_ == IndexMeta::DataType::DT_INT4 || type_ == IndexMeta::DataType::DT_INT8) { - RecordQuantizer::quantize_record( - reinterpret_cast(normalize_buffer_.data()), - original_dimension_, type_, false, &buffer_[0]); + RecordQuantizer::quantize_record(vec, original_dimension_, type_, + false, &buffer_[0]); ::memcpy(reinterpret_cast(&buffer_[0]) + element_size - NORM_SIZE, @@ -149,6 +164,7 @@ class CosineConverterHolder : public IndexHolder { const CosineConverterHolder *owner_{nullptr}; std::string buffer_{}; std::string normalize_buffer_{}; + std::vector rotate_buffer_; IndexHolder::Iterator::Pointer front_iter_{}; size_t dimension_{0u}; size_t original_dimension_{0u}; @@ -159,11 +175,13 @@ class CosineConverterHolder : public IndexHolder { //! Constructor CosineConverterHolder(IndexHolder::Pointer front, IndexMeta::DataType original_type, - IndexMeta::DataType type) + IndexMeta::DataType type, + std::shared_ptr rotator = nullptr) : front_(std::move(front)), original_type_(original_type), type_(type), - dimension_(front_->dimension()) {} + dimension_(front_->dimension()), + rotator_(std::move(rotator)) {} //! Retrieve count of elements in holder (-1 indicates unknown) size_t count(void) const override { @@ -222,6 +240,7 @@ class CosineConverterHolder : public IndexHolder { IndexMeta::DataType original_type_{}; IndexMeta::DataType type_{}; uint32_t dimension_{0}; + std::shared_ptr rotator_{}; }; /*! Converter of Cosine @@ -264,8 +283,19 @@ class CosineConverter : public IndexConverter { return IndexError_Unsupported; } + // Read rotation config + params.get(COSINE_CONVERTER_ENABLE_ROTATE, &enable_rotate_); + ailego::Params reformer_params; + // Create rotator if rotation is enabled + if (enable_rotate_) { + size_t dim = index_meta.dimension(); + rotator_ = std::make_shared(); + rotator_->init(dim); + LOG_DEBUG("CosineConverter: rotation enabled, dim=%zu", dim); + } + if (dst_type_ == IndexMeta::DataType::DT_INT8) { meta_.set_converter("CosineInt8Converter", 0, params); meta_.set_reformer("CosineInt8Reformer", 0, reformer_params); @@ -333,12 +363,23 @@ class CosineConverter : public IndexConverter { *stats_.mutable_transformed_count() += holder->count(); holder_ = std::make_shared( - holder, holder->data_type(), dst_type_); + holder, holder->data_type(), dst_type_, rotator_); + return 0; + } + + //! Dump index into storage (writes rotator segment if present) + int dump(const IndexDumper::Pointer &dumper) override { + if (rotator_) { + return rotator_->dump(dumper); + } return 0; } - //! Dump index into storage - int dump(const IndexDumper::Pointer & /*dumper*/) override { + //! Dump converter state to storage (rotator) + int dump_to_storage(const IndexStorage::Pointer &storage) override { + if (rotator_) { + return rotator_->dump(storage); + } return 0; } @@ -378,6 +419,8 @@ class CosineConverter : public IndexConverter { IndexHolder::Pointer holder_{}; IndexMeta::DataType original_type_{IndexMeta::DataType::DT_UNDEFINED}; IndexMeta::DataType dst_type_{IndexMeta::DataType::DT_UNDEFINED}; + bool enable_rotate_{false}; + std::shared_ptr rotator_{}; }; INDEX_FACTORY_REGISTER_CONVERTER_ALIAS(CosineNormalizeConverter, diff --git a/src/core/quantizer/cosine_reformer.cc b/src/core/quantizer/cosine_reformer.cc index d6080b8d9..5670f1221 100644 --- a/src/core/quantizer/cosine_reformer.cc +++ b/src/core/quantizer/cosine_reformer.cc @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. #include +#include #include #include #include #include #include #include "record_quantizer.h" +#include "record_rotator.h" namespace zvec { namespace core { @@ -53,7 +55,24 @@ class CosineReformer : public IndexReformer { } //! Load index from container - int load(IndexStorage::Pointer) override { + //! Auto-detects rotation by checking for rotator segment in storage. + int load(IndexStorage::Pointer storage) override { + if (enable_rotate_ || storage->get(RECORD_ROTATOR_SEG_ID)) { + rotator_ = std::make_shared(); + int ret = rotator_->open(storage); + if (ret != 0) { + if (enable_rotate_) { + LOG_ERROR("CosineReformer: load rotator failed, ret=%d", ret); + rotator_.reset(); + return ret; + } + rotator_.reset(); + } else { + enable_rotate_ = true; + LOG_DEBUG("CosineReformer: rotator auto-loaded, dim=%zu", + rotator_->dimension()); + } + } return 0; } @@ -83,28 +102,42 @@ class CosineReformer : public IndexReformer { ometa->set_meta(dst_type_, qmeta.dimension() + ExtraDimension(dst_type_)); out->resize(ometa->element_size()); - float norm = 0.0f; size_t origin_dimension = qmeta.dimension(); + const float *vec = reinterpret_cast(query); + float norm = 0.0f; + + // Fast path: no rotation — matches main branch behavior exactly std::string normalized_buffer(reinterpret_cast(query), qmeta.element_size()); - float *buf = reinterpret_cast(&normalized_buffer[0]); - ailego::Normalizer::L2(buf, origin_dimension, &norm); + if (enable_rotate_ && rotator_) { + // Rotate then normalize the rotated vector + std::vector rotate_buffer(rotator_->dimension()); + rotator_->rotate(vec, rotate_buffer.data()); + std::memcpy(buf, rotate_buffer.data(), + origin_dimension * sizeof(float)); + ailego::Normalizer::L2(buf, origin_dimension, &norm); + vec = buf; + } else { + ailego::Normalizer::L2(buf, origin_dimension, &norm); + vec = buf; + } ::memcpy(reinterpret_cast(&(*out)[0]) + ometa->element_size() - NORM_SIZE, &norm, NORM_SIZE); if (dst_type_ == IndexMeta::DataType::DT_FP32) { - ::memcpy(reinterpret_cast(&(*out)[0]), buf, + ::memcpy(reinterpret_cast(&(*out)[0]), vec, ometa->element_size() - NORM_SIZE); } else if (dst_type_ == IndexMeta::DataType::DT_FP16) { - RecordQuantizer::quantize_record(buf, origin_dimension, dst_type_, - false, &(*out)[0]); + RecordQuantizer::quantize_record(const_cast(vec), + qmeta.dimension(), dst_type_, false, + &(*out)[0]); } else if (dst_type_ == IndexMeta::DataType::DT_INT4 || dst_type_ == IndexMeta::DataType::DT_INT8) { - RecordQuantizer::quantize_record(buf, qmeta.dimension(), dst_type_, + RecordQuantizer::quantize_record(vec, qmeta.dimension(), dst_type_, false, &(*out)[0]); } } else if (type == IndexMeta::DataType::DT_FP16) { @@ -186,6 +219,11 @@ class CosineReformer : public IndexReformer { NORM_SIZE, NORM_SIZE); + // Rotation was applied in transform() for all FP32-origin paths (FP32, + // INT8, INT4 stored types). FP16 input path never rotates. + const bool need_inv_rotate = + (type != IndexMeta::DataType::DT_FP16 && enable_rotate_ && rotator_); + if (type == IndexMeta::DataType::DT_FP32) { if (dst_type_ != IndexMeta::DataType::DT_FP32) { return IndexError_Unsupported; @@ -195,6 +233,11 @@ class CosineReformer : public IndexReformer { const float *in_buf = reinterpret_cast(in); this->denormalize(in_buf, out_buf, qmeta, norm); + if (need_inv_rotate) { + std::vector tmp(dimension); + rotator_->unrotate(out_buf, tmp.data()); + std::memcpy(out_buf, tmp.data(), dimension * sizeof(float)); + } } else if (type == IndexMeta::DataType::DT_FP16) { if (dst_type_ != IndexMeta::DataType::DT_FP16) { return IndexError_Unsupported; @@ -210,6 +253,7 @@ class CosineReformer : public IndexReformer { RecordQuantizer::unquantize_record(in, dimension, dst_type_, out_buf); this->denormalize(out_buf, out_buf, qmeta, norm); + // FP16 type path: no rotation was applied, skip inverse } else { ailego::Float16 *out_buf = reinterpret_cast(&(*out)[0]); @@ -228,6 +272,11 @@ class CosineReformer : public IndexReformer { RecordQuantizer::unquantize_record(in, dimension, dst_type_, out_buf); this->denormalize(out_buf, out_buf, qmeta, norm); + if (need_inv_rotate) { + std::vector tmp(dimension); + rotator_->unrotate(out_buf, tmp.data()); + std::memcpy(out_buf, tmp.data(), dimension * sizeof(float)); + } } return 0; @@ -262,6 +311,8 @@ class CosineReformer : public IndexReformer { //! Members IndexMeta::DataType original_type_{IndexMeta::DataType::DT_UNDEFINED}; IndexMeta::DataType dst_type_{IndexMeta::DataType::DT_UNDEFINED}; + bool enable_rotate_{false}; + std::shared_ptr rotator_{}; }; INDEX_FACTORY_REGISTER_REFORMER_ALIAS(CosineNormalizeReformer, CosineReformer, diff --git a/src/core/quantizer/integer_quantizer_converter.cc b/src/core/quantizer/integer_quantizer_converter.cc index f812b6e3c..b67914162 100644 --- a/src/core/quantizer/integer_quantizer_converter.cc +++ b/src/core/quantizer/integer_quantizer_converter.cc @@ -19,6 +19,7 @@ #include #include #include "record_quantizer.h" +#include "record_rotator.h" #include "../metric/metric_params.h" namespace zvec { @@ -378,6 +379,7 @@ class IntegerStreamingConverter : public IndexConverter { meta_ = index_meta; params.get(INTEGER_STREAMING_CONVERTER_ENABLE_NORMALIZE, &enable_normalize_); + params.get(INTEGER_STREAMING_CONVERTER_ENABLE_ROTATE, &enable_rotate_); ailego::Params reformer_params; if (enable_normalize_) { reformer_params.set(INTEGER_STREAMING_REFORMER_ENABLE_NORMALIZE, true); @@ -390,6 +392,13 @@ class IntegerStreamingConverter : public IndexConverter { reformer_params.set(INTEGER_STREAMING_REFORMER_IS_EUCLIDEAN, true); } + // Create rotator if rotation is enabled + if (enable_rotate_) { + rotator_ = std::make_shared(); + rotator_->init(index_meta.dimension()); + LOG_DEBUG("IntegerStreamingConverter: rotation enabled, dim=%zu", + static_cast(index_meta.dimension())); + } if (data_type_ == IndexMeta::DataType::DT_INT8) { meta_.set_converter("Int8StreamingConverter", 0, params); @@ -433,12 +442,30 @@ class IntegerStreamingConverter : public IndexConverter { *stats_.mutable_transformed_count() += holder->count(); holder_ = std::make_shared( - holder, data_type_, enable_normalize_, is_euclidean_); + holder, data_type_, enable_normalize_, is_euclidean_, rotator_); return 0; } - //! Dump index into storage - int dump(const IndexDumper::Pointer & /*dumper*/) override { + //! Dump index into storage (writes rotator segment if rotate enabled) + int dump(const IndexDumper::Pointer &dumper) override { + if (enable_rotate_ && rotator_) { + return rotator_->dump(dumper); + } + return 0; + } + + //! Dump converter state to IndexStorage for streaming build + int dump_to_storage(const IndexStorage::Pointer &storage) override { + if (enable_rotate_ && rotator_) { + int ret = rotator_->dump(storage); + if (ret != 0) { + LOG_ERROR( + "IntegerStreamingConverter: dump rotator to storage failed, ret=%d", + ret); + return ret; + } + LOG_DEBUG("IntegerStreamingConverter: rotator dumped to storage"); + } return 0; } @@ -468,7 +495,8 @@ class IntegerStreamingConverter : public IndexConverter { IndexHolder::Iterator::Pointer &&iter) : owner_(owner), buffer_(owner->element_size(), 0), - normalize_buffer_(owner->front_->element_size(), 0), + normalize_buffer_(owner->dimension_ * sizeof(float), 0), + rotate_buffer_(owner->dimension_ * sizeof(float), 0), front_iter_(std::move(iter)) { this->encode_record(); } @@ -503,18 +531,24 @@ class IntegerStreamingConverter : public IndexConverter { if (front_iter_->is_valid()) { const float *vec = reinterpret_cast(front_iter_->data()); + size_t dim = owner_->dimension_; + if (owner_->rotator_) { + float *rotate_buf = + reinterpret_cast(rotate_buffer_.data()); + owner_->rotator_->rotate(vec, rotate_buf); + vec = rotate_buf; + } if (owner_->enable_normalize_) { float norm = 0.0; - memcpy((void *)normalize_buffer_.data(), vec, - owner_->front_->element_size()); + memcpy((void *)normalize_buffer_.data(), vec, dim * sizeof(float)); ailego::Normalizer::L2((float *)normalize_buffer_.data(), - owner_->dimension_, &norm); + dim, &norm); vec = (float *)normalize_buffer_.data(); } - RecordQuantizer::quantize_record( - vec, owner_->dimension_, owner_->data_type(), - owner_->is_euclidean_, buffer_.data()); + RecordQuantizer::quantize_record(vec, dim, owner_->data_type(), + owner_->is_euclidean_, + buffer_.data()); } } @@ -522,18 +556,21 @@ class IntegerStreamingConverter : public IndexConverter { const IntegerStreamingConverterHolder *owner_{nullptr}; std::vector buffer_{}; std::string normalize_buffer_{}; + std::string rotate_buffer_{}; IndexHolder::Iterator::Pointer front_iter_{}; }; //! Constructor IntegerStreamingConverterHolder(IndexHolder::Pointer front, IndexMeta::DataType tp, - bool enable_normalize, bool is_euclidean) + bool enable_normalize, bool is_euclidean, + std::shared_ptr rotator) : front_(std::move(front)), data_type_(tp), dimension_(front_->dimension()), enable_normalize_(enable_normalize), - is_euclidean_(is_euclidean) {} + is_euclidean_(is_euclidean), + rotator_(std::move(rotator)) {} //! Retrieve count of elements in holder (-1 indicates unknown) size_t count(void) const override { @@ -576,6 +613,7 @@ class IntegerStreamingConverter : public IndexConverter { uint32_t dimension_{0}; bool enable_normalize_{false}; bool is_euclidean_{false}; + std::shared_ptr rotator_{}; }; static size_t ExtraDimension(IndexMeta::DataType type) { @@ -593,7 +631,9 @@ class IntegerStreamingConverter : public IndexConverter { IndexHolder::Pointer holder_{}; IndexMeta::DataType data_type_{}; bool enable_normalize_{false}; + bool enable_rotate_{false}; bool is_euclidean_{false}; + std::shared_ptr rotator_{}; }; INDEX_FACTORY_REGISTER_CONVERTER_ALIAS( diff --git a/src/core/quantizer/integer_quantizer_reformer.cc b/src/core/quantizer/integer_quantizer_reformer.cc index 4228d0fda..eaa858a10 100644 --- a/src/core/quantizer/integer_quantizer_reformer.cc +++ b/src/core/quantizer/integer_quantizer_reformer.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include @@ -19,6 +20,7 @@ #include #include #include "record_quantizer.h" +#include "record_rotator.h" namespace zvec { namespace core { @@ -295,7 +297,30 @@ class IntegerStreamingReformer : public IndexReformer { } //! Load index from container - int load(IndexStorage::Pointer) override { + //! Auto-detects rotation by checking for rotator segment in storage. + //! No need for enable_rotate in search config. + int load(IndexStorage::Pointer storage) override { + // If config explicitly enables rotate but rotator not yet loaded, try + // storage If config doesn't enable rotate, still try storage (auto-detect) + if (enable_rotate_ || storage->get(RECORD_ROTATOR_SEG_ID)) { + rotator_ = std::make_shared(); + int ret = rotator_->open(storage); + if (ret != 0) { + if (enable_rotate_) { + // Config said enable_rotate but storage has no rotator — error + LOG_ERROR("IntegerStreamingReformer: load rotator failed, ret=%d", + ret); + rotator_.reset(); + return ret; + } + // No rotator in storage, rotation not available + rotator_.reset(); + } else { + enable_rotate_ = true; + LOG_DEBUG("IntegerStreamingReformer: rotator auto-loaded, dim=%zu", + rotator_->dimension()); + } + } return 0; } @@ -319,10 +344,16 @@ class IntegerStreamingReformer : public IndexReformer { ometa->set_meta(data_type_, qmeta.dimension() + extra_dimension_); out->resize(ometa->element_size()); const float *vec = reinterpret_cast(query); + std::unique_ptr rotate_buffer; + if (enable_rotate_ && rotator_) { + rotate_buffer.reset(new float[rotator_->dimension()]); + rotator_->rotate(vec, rotate_buffer.get()); + vec = rotate_buffer.get(); + } std::unique_ptr normalized; if (enable_normalize_) { normalized.reset(new float[qmeta.dimension()]); - vec = normalize(query, qmeta, normalized.get()); + vec = normalize(vec, qmeta, normalized.get()); } RecordQuantizer::quantize_record(vec, qmeta.dimension(), data_type_, @@ -344,13 +375,21 @@ class IntegerStreamingReformer : public IndexReformer { *ometa = qmeta; ometa->set_meta(data_type_, qmeta.dimension() + extra_dimension_); out->resize(count * ometa->element_size()); + std::unique_ptr rotate_buffer; std::unique_ptr normalized; + if (enable_rotate_ && rotator_) { + rotate_buffer.reset(new float[rotator_->dimension()]); + } if (enable_normalize_) { normalized.reset(new float[qmeta.dimension()]); } for (size_t i = 0; i < count; ++i) { const float *vec = reinterpret_cast(query) + i * qmeta.dimension(); + if (enable_rotate_ && rotator_) { + rotator_->rotate(vec, rotate_buffer.get()); + vec = rotate_buffer.get(); + } if (enable_normalize_) { vec = normalize(vec, qmeta, normalized.get()); } @@ -378,10 +417,16 @@ class IntegerStreamingReformer : public IndexReformer { ometa->set_meta(data_type_, rmeta.dimension() + extra_dimension_); out->resize(ometa->element_size()); const float *vec = reinterpret_cast(record); + std::unique_ptr rotate_buffer; + if (enable_rotate_ && rotator_) { + rotate_buffer.reset(new float[rotator_->dimension()]); + rotator_->rotate(vec, rotate_buffer.get()); + vec = rotate_buffer.get(); + } std::unique_ptr normalized; if (enable_normalize_) { normalized.reset(new float[rmeta.dimension()]); - vec = normalize(record, rmeta, normalized.get()); + vec = normalize(vec, rmeta, normalized.get()); } RecordQuantizer::quantize_record(vec, rmeta.dimension(), data_type_, @@ -404,13 +449,21 @@ class IntegerStreamingReformer : public IndexReformer { *ometa = rmeta; ometa->set_meta(data_type_, rmeta.dimension() + extra_dimension_); out->resize(count * ometa->element_size()); + std::unique_ptr rotate_buffer; std::unique_ptr normalized; + if (enable_rotate_ && rotator_) { + rotate_buffer.reset(new float[rotator_->dimension()]); + } if (enable_normalize_) { normalized.reset(new float[rmeta.dimension()]); } for (size_t i = 0; i < count; ++i) { const float *vec = reinterpret_cast(records) + i * rmeta.dimension(); + if (enable_rotate_ && rotator_) { + rotator_->rotate(vec, rotate_buffer.get()); + vec = rotate_buffer.get(); + } if (enable_normalize_) { vec = normalize(vec, rmeta, normalized.get()); } @@ -447,15 +500,23 @@ class IntegerStreamingReformer : public IndexReformer { std::string *out) const override { if (enable_normalize_) { LOG_ERROR("Unsupported revert for normalized value"); - return IndexError_Unsupported; } - out->resize((qmeta.dimension() - extra_dimension_) * sizeof(float)); - float *out_buf = reinterpret_cast(out->data()); + const size_t stored_dim = qmeta.dimension() - extra_dimension_; - RecordQuantizer::unquantize_record(in, qmeta.dimension() - extra_dimension_, - data_type_, out_buf); + // Step 1: Unquantize into out buffer (stored_dim floats) + out->resize(stored_dim * sizeof(float)); + float *out_buf = reinterpret_cast(out->data()); + RecordQuantizer::unquantize_record(in, stored_dim, data_type_, out_buf); + + // Step 2: Inverse rotate in-place if rotation was applied + if (enable_rotate_ && rotator_) { + std::vector tmp(rotator_->dimension()); + rotator_->unrotate(out_buf, tmp.data()); + out->assign(reinterpret_cast(tmp.data()), + tmp.size() * sizeof(float)); + } return 0; } @@ -465,6 +526,8 @@ class IntegerStreamingReformer : public IndexReformer { uint32_t extra_dimension_{0}; bool enable_normalize_{false}; bool is_euclidean_{false}; + bool enable_rotate_{false}; + std::shared_ptr rotator_{}; }; INDEX_FACTORY_REGISTER_REFORMER_ALIAS( diff --git a/src/core/quantizer/quantizer_params.h b/src/core/quantizer/quantizer_params.h index a089a2d9f..d56c8591d 100644 --- a/src/core/quantizer/quantizer_params.h +++ b/src/core/quantizer/quantizer_params.h @@ -100,6 +100,8 @@ static const std::string INT4_QUANTIZER_REFORMER_METRIC = //! CosineConverter static const std::string COSINE_CONVERTER_FORCED_HALF_FLOAT = "cosine.converter.forced_half_float"; +static const std::string COSINE_CONVERTER_ENABLE_ROTATE = + "cosine.converter.enable_rotate"; //! CosineReformer static const std::string COSINE_REFORMER_FORCED_HALF_FLOAT = @@ -108,8 +110,10 @@ static const std::string COSINE_REFORMER_FORCED_HALF_FLOAT = //! IntegerStreamingConverter static const std::string INTEGER_STREAMING_CONVERTER_ENABLE_NORMALIZE = "integer_streaming.converter.enable_normalize"; +static const std::string INTEGER_STREAMING_CONVERTER_ENABLE_ROTATE = + "integer_streaming.converter.enable_rotate"; -//! IntegerStreamingConverter +//! IntegerStreamingReformer static const std::string INTEGER_STREAMING_REFORMER_ENABLE_NORMALIZE = "integer_streaming.reformer.enable_normalize"; static const std::string INTEGER_STREAMING_REFORMER_IS_EUCLIDEAN = diff --git a/src/core/quantizer/record_rotator.cc b/src/core/quantizer/record_rotator.cc new file mode 100644 index 000000000..d2d51fd7f --- /dev/null +++ b/src/core/quantizer/record_rotator.cc @@ -0,0 +1,917 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "record_rotator.h" +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__AVX2__) || defined(__AVX512F__) +#include +// FFHT (Fastest Fast Hadamard Transform) — hand-tuned AVX inline assembly +// from https://github.com/FALCONN-LIB/FFHT, originally bundled in rabitqlib. +// Provides fht_float(buf, log_n) with per-size helper_float_N specialisations. +#if defined(__GNUC__) +#include "rabitqlib/utils/fht_avx.hpp" +#endif +#elif defined(__SSE2__) +#include +#endif + +#if defined(__ARM_NEON) && defined(__aarch64__) +#include +#endif + +#include +#include "zvec/core/framework/index_error.h" +#include "zvec/core/framework/index_logger.h" + +namespace zvec { +namespace core { + +namespace { + +template +using RowMajorMatrix = + Eigen::Matrix; + +template +using RowMajorMatrixMap = Eigen::Map>; + +template +using ConstRowMajorMatrixMap = Eigen::Map>; + +template +RowMajorMatrix random_gaussian_matrix(size_t rows, size_t cols) { + RowMajorMatrix rand(rows, cols); + static std::random_device rd; + static std::mt19937 gen(rd()); + std::normal_distribution dist(0, 1); + + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + rand(i, j) = dist(gen); + } + } + + return rand; +} + +// ============================================================================ +// Scalar / SIMD helper functions for rotation +// ============================================================================ + +//! In-place Fast Hadamard Transform on a power-of-2 length array. +//! Uses FFHT hand-tuned AVX assembly when available; generic scalar loop +//! otherwise (ARM NEON / SSE2 / pure scalar). +void fht_inplace(float *data, size_t n) { +#if (defined(__AVX2__) || defined(__AVX512F__)) && defined(__GNUC__) + // Compute floor(log2(n)) for power-of-2 n. + int log_n = 0; + for (size_t v = n; v > 1; v >>= 1) ++log_n; + fht_float(data, log_n); +#else + for (size_t len = 1; len < n; len <<= 1) { + for (size_t i = 0; i < n; i += len << 1) { + for (size_t j = i; j < i + len; ++j) { + float u = data[j]; + float v = data[j + len]; + data[j] = u + v; + data[j + len] = u - v; + } + } + } +#endif +} + +//! Flip the sign of elements based on a packed bit-array. +void flip_sign(const uint8_t *flip, float *data, size_t dim) { +#if defined(__AVX512F__) && defined(__AVX512DQ__) + size_t simd_end = dim & ~63u; +#elif defined(__AVX2__) + size_t simd_end = dim & ~31u; +#else + size_t simd_end = dim; // SSE2/NEON/scalar: chunk divides 4, no tail +#endif + +#if defined(__AVX512F__) && defined(__AVX512DQ__) + constexpr size_t kChunk = 64; + const __m512 sign_flip = _mm512_castsi512_ps(_mm512_set1_epi32(0x80000000)); + for (size_t i = 0; i < simd_end; i += kChunk) { + uint64_t mask_bits; + std::memcpy(&mask_bits, &flip[i / 8], sizeof(mask_bits)); + const __mmask16 m0 = _cvtu32_mask16(mask_bits & 0xFFFF); + const __mmask16 m1 = _cvtu32_mask16((mask_bits >> 16) & 0xFFFF); + const __mmask16 m2 = _cvtu32_mask16((mask_bits >> 32) & 0xFFFF); + const __mmask16 m3 = _cvtu32_mask16((mask_bits >> 48) & 0xFFFF); + __m512 v0 = _mm512_loadu_ps(&data[i]); + v0 = _mm512_mask_xor_ps(v0, m0, v0, sign_flip); + _mm512_storeu_ps(&data[i], v0); + __m512 v1 = _mm512_loadu_ps(&data[i + 16]); + v1 = _mm512_mask_xor_ps(v1, m1, v1, sign_flip); + _mm512_storeu_ps(&data[i + 16], v1); + __m512 v2 = _mm512_loadu_ps(&data[i + 32]); + v2 = _mm512_mask_xor_ps(v2, m2, v2, sign_flip); + _mm512_storeu_ps(&data[i + 32], v2); + __m512 v3 = _mm512_loadu_ps(&data[i + 48]); + v3 = _mm512_mask_xor_ps(v3, m3, v3, sign_flip); + _mm512_storeu_ps(&data[i + 48], v3); + } +#elif defined(__AVX2__) + constexpr size_t kChunk = 32; + const __m256i bit_select = + _mm256_setr_epi32(0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80); + const __m256 sign_flip = _mm256_castsi256_ps(_mm256_set1_epi32(0x80000000)); + for (size_t i = 0; i < simd_end; i += kChunk) { + uint32_t mask_bits; + std::memcpy(&mask_bits, &flip[i / 8], sizeof(mask_bits)); + for (int b = 0; b < 4; ++b) { + __m256i mb = _mm256_set1_epi32((mask_bits >> (b * 8)) & 0xFF); + __m256i test = _mm256_and_si256(mb, bit_select); + __m256i cmp = _mm256_cmpeq_epi32(test, bit_select); + __m256 xor_mask = _mm256_and_ps(_mm256_castsi256_ps(cmp), sign_flip); + __m256 v = _mm256_loadu_ps(&data[i + b * 8]); + v = _mm256_xor_ps(v, xor_mask); + _mm256_storeu_ps(&data[i + b * 8], v); + } + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + // 128-bit NEON: process 4 floats per iteration. + // Load 2 bytes (16 bits) to safely handle cross-byte boundaries. + const uint32x4_t sign_bit = vdupq_n_u32(0x80000000u); + for (size_t i = 0; i < dim; i += 4) { + uint16_t bits16; + std::memcpy(&bits16, &flip[i / 8], sizeof(bits16)); + bits16 >>= (i % 8); + uint32_t b0 = bits16 & 1u; + uint32_t b1 = (bits16 >> 1) & 1u; + uint32_t b2 = (bits16 >> 2) & 1u; + uint32_t b3 = (bits16 >> 3) & 1u; + uint32x4_t bit_mask = {b0, b1, b2, b3}; + uint32x4_t sign_mask = vmulq_u32(bit_mask, sign_bit); + float32x4_t v = vld1q_f32(&data[i]); + v = vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(v), sign_mask)); + vst1q_f32(&data[i], v); + } +#elif defined(__SSE2__) + // 128-bit SSE2: process 4 floats per iteration. + // Load 2 bytes (16 bits) to safely handle cross-byte boundaries. + for (size_t i = 0; i < dim; i += 4) { + uint16_t bits16; + std::memcpy(&bits16, &flip[i / 8], sizeof(bits16)); + bits16 >>= (i % 8); + uint32_t b0 = bits16 & 1u; + uint32_t b1 = (bits16 >> 1) & 1u; + uint32_t b2 = (bits16 >> 2) & 1u; + uint32_t b3 = (bits16 >> 3) & 1u; + __m128i bit_mask = _mm_set_epi32(b3, b2, b1, b0); + __m128i sign_mask = _mm_slli_epi32(bit_mask, 31); + __m128 v = _mm_loadu_ps(&data[i]); + v = _mm_xor_ps(v, _mm_castsi128_ps(sign_mask)); + _mm_storeu_ps(&data[i], v); + } +#else + for (size_t i = 0; i < dim; ++i) { + if (flip[i / 8] & (1u << (i % 8))) { + data[i] = -data[i]; + } + } +#endif + // Scalar tail: handle remaining elements when dim is not SIMD-aligned. + for (size_t i = simd_end; i < dim; ++i) { + if (flip[i / 8] & (1u << (i % 8))) { + data[i] = -data[i]; + } + } +} + +//! Kac random walk: butterfly add/sub between first and second halves. +void kacs_walk(float *data, size_t len) { + size_t half = len / 2; +#if defined(__AVX512F__) + size_t half_end = half & ~15u; +#elif defined(__AVX2__) + size_t half_end = half & ~7u; +#elif defined(__SSE2__) || (defined(__ARM_NEON) && defined(__aarch64__)) + size_t half_end = half & ~3u; +#else + size_t half_end = half; +#endif + +#if defined(__AVX512F__) + for (size_t i = 0; i < half_end; i += 16) { + __m512 x = _mm512_loadu_ps(&data[i]); + __m512 y = _mm512_loadu_ps(&data[i + half]); + _mm512_storeu_ps(&data[i], _mm512_add_ps(x, y)); + _mm512_storeu_ps(&data[i + half], _mm512_sub_ps(x, y)); + } +#elif defined(__AVX2__) + for (size_t i = 0; i < half_end; i += 8) { + __m256 x = _mm256_loadu_ps(&data[i]); + __m256 y = _mm256_loadu_ps(&data[i + half]); + _mm256_storeu_ps(&data[i], _mm256_add_ps(x, y)); + _mm256_storeu_ps(&data[i + half], _mm256_sub_ps(x, y)); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (size_t i = 0; i < half_end; i += 4) { + float32x4_t x = vld1q_f32(&data[i]); + float32x4_t y = vld1q_f32(&data[i + half]); + vst1q_f32(&data[i], vaddq_f32(x, y)); + vst1q_f32(&data[i + half], vsubq_f32(x, y)); + } +#elif defined(__SSE2__) + for (size_t i = 0; i < half_end; i += 4) { + __m128 x = _mm_loadu_ps(&data[i]); + __m128 y = _mm_loadu_ps(&data[i + half]); + _mm_storeu_ps(&data[i], _mm_add_ps(x, y)); + _mm_storeu_ps(&data[i + half], _mm_sub_ps(x, y)); + } +#else + for (size_t i = 0; i < half; ++i) { + float x = data[i]; + float y = data[i + half]; + data[i] = x + y; + data[i + half] = x - y; + } +#endif + // Scalar tail: handle remaining pairs when half is not SIMD-aligned. + for (size_t i = half_end; i < half; ++i) { + float x = data[i]; + float y = data[i + half]; + data[i] = x + y; + data[i + half] = x - y; + } +} + +//! Inverse Kac walk: undo butterfly add/sub with 0.5 factor. +//! If forward maps (x,y) -> (x+y, x-y), inverse maps (a,b) -> ((a+b)/2, +//! (a-b)/2). +void inv_kacs_walk(float *data, size_t len) { + size_t half = len / 2; +#if defined(__AVX512F__) + size_t half_end = half & ~15u; +#elif defined(__AVX2__) + size_t half_end = half & ~7u; +#elif defined(__SSE2__) || (defined(__ARM_NEON) && defined(__aarch64__)) + size_t half_end = half & ~3u; +#else + size_t half_end = half; +#endif + +#if defined(__AVX512F__) + const __m512 half_fac = _mm512_set1_ps(0.5f); + for (size_t i = 0; i < half_end; i += 16) { + __m512 a = _mm512_loadu_ps(&data[i]); + __m512 b = _mm512_loadu_ps(&data[i + half]); + _mm512_storeu_ps(&data[i], _mm512_mul_ps(_mm512_add_ps(a, b), half_fac)); + _mm512_storeu_ps(&data[i + half], + _mm512_mul_ps(_mm512_sub_ps(a, b), half_fac)); + } +#elif defined(__AVX2__) + const __m256 half_fac = _mm256_set1_ps(0.5f); + for (size_t i = 0; i < half_end; i += 8) { + __m256 a = _mm256_loadu_ps(&data[i]); + __m256 b = _mm256_loadu_ps(&data[i + half]); + _mm256_storeu_ps(&data[i], _mm256_mul_ps(_mm256_add_ps(a, b), half_fac)); + _mm256_storeu_ps(&data[i + half], + _mm256_mul_ps(_mm256_sub_ps(a, b), half_fac)); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + const float32x4_t half_fac = vdupq_n_f32(0.5f); + for (size_t i = 0; i < half_end; i += 4) { + float32x4_t a = vld1q_f32(&data[i]); + float32x4_t b = vld1q_f32(&data[i + half]); + vst1q_f32(&data[i], vmulq_f32(vaddq_f32(a, b), half_fac)); + vst1q_f32(&data[i + half], vmulq_f32(vsubq_f32(a, b), half_fac)); + } +#elif defined(__SSE2__) + const __m128 half_fac = _mm_set1_ps(0.5f); + for (size_t i = 0; i < half_end; i += 4) { + __m128 a = _mm_loadu_ps(&data[i]); + __m128 b = _mm_loadu_ps(&data[i + half]); + _mm_storeu_ps(&data[i], _mm_mul_ps(_mm_add_ps(a, b), half_fac)); + _mm_storeu_ps(&data[i + half], _mm_mul_ps(_mm_sub_ps(a, b), half_fac)); + } +#else + for (size_t i = 0; i < half; ++i) { + float a = data[i]; + float b = data[i + half]; + data[i] = (a + b) * 0.5f; + data[i + half] = (a - b) * 0.5f; + } +#endif + // Scalar tail: handle remaining pairs when half is not SIMD-aligned. + for (size_t i = half_end; i < half; ++i) { + float a = data[i]; + float b = data[i + half]; + data[i] = (a + b) * 0.5f; + data[i + half] = (a - b) * 0.5f; + } +} + +//! Scale each element by a constant factor. +void vec_rescale(float *data, size_t n, float factor) { + for (size_t i = 0; i < n; ++i) { + data[i] *= factor; + } +} + +//! Largest power-of-2 not exceeding n. +size_t floor_pow2(size_t n) { + size_t p = 1; + while ((p << 1) <= n) p <<= 1; + return p; +} + +//! Read a little-endian uint32 from raw bytes. +uint32_t read_u32_le(const char *p) { + return static_cast(static_cast(p[0])) | + (static_cast(static_cast(p[1])) << 8) | + (static_cast(static_cast(p[2])) << 16) | + (static_cast(static_cast(p[3])) << 24); +} + +//! Write a uint32 in little-endian to raw bytes. +void write_u32_le(char *p, uint32_t v) { + p[0] = static_cast(v & 0xFF); + p[1] = static_cast((v >> 8) & 0xFF); + p[2] = static_cast((v >> 16) & 0xFF); + p[3] = static_cast((v >> 24) & 0xFF); +} + +// ============================================================================ +// FhtKacRotatorImpl - O(d log d) FHT-based Kac random rotation +// +// Requires dimension % 4 == 0 (scalar tails handle SIMD remainder). +// When dimension is a power of 2, uses 4 rounds of (flip -> FHT -> +// rescale). When dimension is NOT a power of 2 (e.g. 96, 100, 192), +// uses kacs_walk reduction to handle the non-power-of-2 case. +// ============================================================================ + +struct FhtKacRotatorImpl { + std::vector flip; + size_t trunc_dim{0}; + float fac{0}; + + static constexpr size_t kByteLen = 8; + + void init(size_t dim) { + flip.resize(4 * dim / kByteLen); + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dist(0, 255); + for (auto &b : flip) b = static_cast(dist(gen)); + } + + void rotate(const float *in, float *out, size_t dim) const { + std::memcpy(out, in, sizeof(float) * dim); + + if (trunc_dim == dim) { + // Exact power-of-2: 4 rounds of (flip -> FHT -> rescale) + flip_sign(flip.data(), out, dim); + fht_inplace(out, trunc_dim); + vec_rescale(out, trunc_dim, fac); + + flip_sign(flip.data() + dim / kByteLen, out, dim); + fht_inplace(out, trunc_dim); + vec_rescale(out, trunc_dim, fac); + + flip_sign(flip.data() + 2 * dim / kByteLen, out, dim); + fht_inplace(out, trunc_dim); + vec_rescale(out, trunc_dim, fac); + + flip_sign(flip.data() + 3 * dim / kByteLen, out, dim); + fht_inplace(out, trunc_dim); + vec_rescale(out, trunc_dim, fac); + + return; + } + + // Non-power-of-2 (64-aligned, e.g. 192, 320): 4 rounds with kacs_walk + // reduction. FHT always operates on trunc_dim (largest power-of-2 <= dim). + size_t start = dim - trunc_dim; + float *trunc_ptr = out + start; + + // Round 1: FHT on [0, trunc_dim) + flip_sign(flip.data(), out, dim); + fht_inplace(out, trunc_dim); + vec_rescale(out, trunc_dim, fac); + kacs_walk(out, dim); + + // Round 2: FHT on [start, start + trunc_dim) + flip_sign(flip.data() + dim / kByteLen, out, dim); + fht_inplace(trunc_ptr, trunc_dim); + vec_rescale(trunc_ptr, trunc_dim, fac); + kacs_walk(out, dim); + + // Round 3: FHT on [0, trunc_dim) + flip_sign(flip.data() + 2 * dim / kByteLen, out, dim); + fht_inplace(out, trunc_dim); + vec_rescale(out, trunc_dim, fac); + kacs_walk(out, dim); + + // Round 4: FHT on [start, start + trunc_dim) + flip_sign(flip.data() + 3 * dim / kByteLen, out, dim); + fht_inplace(trunc_ptr, trunc_dim); + vec_rescale(trunc_ptr, trunc_dim, fac); + kacs_walk(out, dim); + + // Final rescale: combine the 4 kacs_walk reductions + vec_rescale(out, dim, 0.25f); + } + + void save(char *data) const { + std::memcpy(data, flip.data(), flip.size()); + } + + void load(const char *data) { + std::memcpy(flip.data(), data, flip.size()); + } + + size_t dump_bytes() const { + return flip.size(); + } + + void unrotate(const float *in, float *out, size_t dim) const { + // Copy input into working buffer + std::vector data(in, in + dim); + + if (trunc_dim == dim) { + // Exact power-of-2: reverse 4 rounds in reverse order. + // Forward per round: flip -> fht -> rescale(fac) + // Reverse per round: rescale(1/fac) -> inv_fht -> flip + // Combined: fht + rescale(1/sqrt(trunc_dim)) + flip + const float inv_fac = 1.0f / std::sqrt(static_cast(trunc_dim)); + for (int round = 3; round >= 0; --round) { + fht_inplace(data.data(), trunc_dim); + vec_rescale(data.data(), trunc_dim, inv_fac); + flip_sign(flip.data() + round * dim / kByteLen, data.data(), dim); + } + std::memcpy(out, data.data(), dim * sizeof(float)); + return; + } + + // Non-power-of-2: undo final rescale(0.25) first + vec_rescale(data.data(), dim, 4.0f); + + // Reverse 4 rounds in reverse order. + // Forward round: flip -> fht -> rescale(fac) -> kacs_walk + // Reverse: inv_kacs_walk -> rescale(1/fac) -> inv_fht -> flip + // Combined inv_fht: fht + rescale(1/sqrt(trunc_dim)) + const float inv_fac = 1.0f / std::sqrt(static_cast(trunc_dim)); + size_t start = dim - trunc_dim; + float *trunc_ptr = data.data() + start; + + // Undo Round 4 (FHT on [start, start+trunc_dim)) + inv_kacs_walk(data.data(), dim); + fht_inplace(trunc_ptr, trunc_dim); + vec_rescale(trunc_ptr, trunc_dim, inv_fac); + flip_sign(flip.data() + 3 * dim / kByteLen, data.data(), dim); + + // Undo Round 3 (FHT on [0, trunc_dim)) + inv_kacs_walk(data.data(), dim); + fht_inplace(data.data(), trunc_dim); + vec_rescale(data.data(), trunc_dim, inv_fac); + flip_sign(flip.data() + 2 * dim / kByteLen, data.data(), dim); + + // Undo Round 2 (FHT on [start, start+trunc_dim)) + inv_kacs_walk(data.data(), dim); + fht_inplace(trunc_ptr, trunc_dim); + vec_rescale(trunc_ptr, trunc_dim, inv_fac); + flip_sign(flip.data() + dim / kByteLen, data.data(), dim); + + // Undo Round 1 (FHT on [0, trunc_dim)) + inv_kacs_walk(data.data(), dim); + fht_inplace(data.data(), trunc_dim); + vec_rescale(data.data(), trunc_dim, inv_fac); + flip_sign(flip.data(), data.data(), dim); + + std::memcpy(out, data.data(), dim * sizeof(float)); + } +}; + +// ============================================================================ +// MatrixRotatorImpl - O(d^2) random orthogonal matrix rotation +// +// No alignment requirement on dimension. Uses a dim x dim square orthogonal +// matrix generated via Householder QR on a random Gaussian matrix. +// ============================================================================ + +struct MatrixRotatorImpl { + std::vector matrix; // dim x dim, row-major + + void init(size_t dim) { + // Generate dim x dim random Gaussian matrix + RowMajorMatrix rand_mat = random_gaussian_matrix(dim, dim); + + // Householder QR: numerically stable orthogonalisation + Eigen::HouseholderQR> qr(rand_mat); + RowMajorMatrix q_inv = qr.householderQ().transpose(); + + matrix.resize(dim * dim); + std::memcpy(matrix.data(), &q_inv(0, 0), sizeof(float) * dim * dim); + } + + void rotate(const float *in, float *out, size_t dim) const { + // v (1 x dim) * M (dim x dim) -> rv (1 x dim) + ConstRowMajorMatrixMap v(in, 1, dim); + RowMajorMatrixMap rv(out, 1, dim); + rv = v * ConstRowMajorMatrixMap(matrix.data(), dim, dim); + } + + void save(char *data) const { + std::memcpy(data, matrix.data(), matrix.size() * sizeof(float)); + } + + void load(const char *data) { + std::memcpy(matrix.data(), data, matrix.size() * sizeof(float)); + } + + size_t dump_bytes() const { + return matrix.size() * sizeof(float); + } + + //! Inverse rotate using M^T (transpose of the dim x dim orthogonal matrix). + void unrotate(const float *in, float *out, size_t dim) const { + // in (1 x dim) * M^T (dim x dim) -> out (1 x dim) + ConstRowMajorMatrixMap v(in, 1, dim); + RowMajorMatrixMap rv(out, 1, dim); + rv = v * ConstRowMajorMatrixMap(matrix.data(), dim, dim).transpose(); + } +}; + +} // anonymous namespace + +// ============================================================================ +// RecordRotator::Impl +// ============================================================================ + +struct RecordRotator::Impl { + //! Header layout (12 bytes, backward-compatible with older serialised data): + //! type(1B) + padding(3B) + origin_dim(4B) + reserved(4B) = 12B + //! The reserved field previously stored padded_dim; it now mirrors + //! origin_dim. + static constexpr size_t kHeaderSize = 12; + + struct Header { + uint8_t type; + uint32_t origin_dim; + uint32_t reserved; // backward-compat placeholder (was padded_dim) + + void write_to(char *buf) const { + // Write fields individually (avoids GCC -Warray-bounds false positive + // on memset when inlined through vector::data() at -O3). + buf[0] = static_cast(type); + buf[1] = buf[2] = buf[3] = 0; // padding + write_u32_le(buf + 4, origin_dim); + write_u32_le(buf + 8, reserved); + } + + void read_from(const char *buf) { + type = static_cast(buf[0]); + origin_dim = read_u32_le(buf + 4); + // reserved (buf+8) is intentionally ignored for forward compatibility + } + }; + + size_t dimension{0}; + RecordRotatorType type{RecordRotatorType::FhtKac}; + + std::unique_ptr fht_impl; + std::unique_ptr mat_impl; + + void do_rotate(const float *in, float *out) const { + if (fht_impl) { + fht_impl->rotate(in, out, dimension); + } else { + mat_impl->rotate(in, out, dimension); + } + } + + void do_unrotate(const float *in, float *out) const { + if (fht_impl) { + fht_impl->unrotate(in, out, dimension); + } else { + mat_impl->unrotate(in, out, dimension); + } + } + + size_t blob_bytes() const { + if (fht_impl) return fht_impl->dump_bytes(); + return mat_impl->dump_bytes(); + } + + void save_blob(char *data) const { + if (fht_impl) { + fht_impl->save(data); + } else { + mat_impl->save(data); + } + } + + void load_blob(const char *data) { + if (fht_impl) { + fht_impl->load(data); + } else { + mat_impl->load(data); + } + } +}; + +// ============================================================================ +// RecordRotator public methods +// ============================================================================ + +RecordRotator::RecordRotator() : impl_(std::make_unique()) {} + +RecordRotator::~RecordRotator() = default; + +RecordRotator::RecordRotator(RecordRotator &&) noexcept = default; +RecordRotator &RecordRotator::operator=(RecordRotator &&) noexcept = default; + +void RecordRotator::init(size_t dimension, RecordRotatorType rotator_type) { + impl_->dimension = dimension; + + // Auto-select implementation based on dimension alignment when FhtKac + // is requested. FhtKac requires the dimension to be a multiple of 4; + // scalar tails handle the SIMD remainder. When the dimension is not + // 4-aligned we transparently fall back to the O(d^2) Matrix rotator. + bool use_fht = + (rotator_type == RecordRotatorType::FhtKac) && (dimension % 4 == 0); + + if (use_fht) { + impl_->type = RecordRotatorType::FhtKac; + impl_->fht_impl = std::make_unique(); + impl_->fht_impl->trunc_dim = floor_pow2(dimension); + impl_->fht_impl->fac = + 1.0f / std::sqrt(static_cast(impl_->fht_impl->trunc_dim)); + impl_->fht_impl->init(dimension); + } else { + if (rotator_type == RecordRotatorType::FhtKac) { + LOG_DEBUG( + "RecordRotator::init: dimension %zu is not 4-aligned, " + "falling back from FhtKac to Matrix rotator", + dimension); + } + impl_->type = RecordRotatorType::Matrix; + impl_->mat_impl = std::make_unique(); + impl_->mat_impl->init(dimension); + } +} + +void RecordRotator::rotate(const float *in, float *out) const { + impl_->do_rotate(in, out); +} + +std::vector RecordRotator::rotate(const float *in) const { + std::vector out(impl_->dimension); + impl_->do_rotate(in, out.data()); + return out; +} + +void RecordRotator::unrotate(const float *in, float *out) const { + if (!impl_->fht_impl && !impl_->mat_impl) { + LOG_ERROR("RecordRotator::unrotate: rotator not initialized"); + return; + } + impl_->do_unrotate(in, out); +} + +std::vector RecordRotator::unrotate(const float *in) const { + std::vector out(impl_->dimension); + unrotate(in, out.data()); + return out; +} + +size_t RecordRotator::dump_bytes() const { + return Impl::kHeaderSize + impl_->blob_bytes(); +} + +int RecordRotator::dump(const IndexStorage::Pointer &storage, + const std::string &seg_id) const { + if (!storage) { + LOG_ERROR("RecordRotator::dump(storage): null storage"); + return IndexError_InvalidArgument; + } + if (!impl_->fht_impl && !impl_->mat_impl) { + LOG_ERROR("RecordRotator::dump(storage): rotator not initialized"); + return IndexError_NoReady; + } + + auto align_size = [](size_t size) -> size_t { + return (size + 0x1F) & (~0x1F); + }; + + // Serialize: [Header: type|origin_dim|reserved] [rotation blob] + const size_t blob_size = impl_->blob_bytes(); + const size_t data_size = Impl::kHeaderSize + blob_size; + const size_t total_size = align_size(data_size); + std::vector buffer(data_size); + + Impl::Header header; + header.type = static_cast(impl_->type); + header.origin_dim = static_cast(impl_->dimension); + header.reserved = static_cast(impl_->dimension); // backward compat + header.write_to(buffer.data()); + impl_->save_blob(buffer.data() + Impl::kHeaderSize); + + // Append segment to storage + int ret = storage->append(seg_id, total_size); + if (ret != 0) { + LOG_ERROR( + "RecordRotator::dump(storage): append segment '%s' failed, ret=%d", + seg_id.c_str(), ret); + return ret; + } + + auto segment = storage->get(seg_id); + if (!segment) { + LOG_ERROR("RecordRotator::dump(storage): get segment '%s' failed", + seg_id.c_str()); + return IndexError_WriteData; + } + + size_t written = segment->write(0, buffer.data(), data_size); + if (written != data_size) { + LOG_ERROR( + "RecordRotator::dump(storage): write failed, written=%zu, expected=%zu", + written, data_size); + return IndexError_WriteData; + } + segment->resize(data_size); + segment->update_data_crc(ailego::Crc32c::Hash(buffer.data(), data_size, 0)); + + LOG_DEBUG( + "RecordRotator::dump(storage) done: seg=%s, data_size=%zu, total=%zu", + seg_id.c_str(), data_size, total_size); + return 0; +} + +int RecordRotator::dump(const IndexDumper::Pointer &dumper, + const std::string &seg_id) const { + if (!dumper) { + LOG_ERROR("RecordRotator::dump(dumper): null dumper"); + return IndexError_InvalidArgument; + } + if (!impl_->fht_impl && !impl_->mat_impl) { + LOG_ERROR("RecordRotator::dump(dumper): rotator not initialized"); + return IndexError_NoReady; + } + + // Serialize: [Header: type|origin_dim|reserved] [rotation blob] + const size_t blob_size = impl_->blob_bytes(); + const size_t data_size = Impl::kHeaderSize + blob_size; + const size_t total_size = (data_size + 0x1F) & (~0x1F); + + std::vector buffer(total_size, 0); + Impl::Header header; + header.type = static_cast(impl_->type); + header.origin_dim = static_cast(impl_->dimension); + header.reserved = static_cast(impl_->dimension); // backward compat + header.write_to(buffer.data()); + impl_->save_blob(buffer.data() + Impl::kHeaderSize); + + const uint32_t crc = ailego::Crc32c::Hash(buffer.data(), data_size, 0); + const size_t padding_size = total_size - data_size; + + // Write data + padding to dumper + if (dumper->write(buffer.data(), total_size) != total_size) { + LOG_ERROR("RecordRotator::dump(dumper): write failed, seg=%s", + seg_id.c_str()); + return IndexError_WriteData; + } + + // Register segment + int ret = dumper->append(seg_id, data_size, padding_size, crc); + if (ret != 0) { + LOG_ERROR("RecordRotator::dump(dumper): append failed, seg=%s, ret=%d", + seg_id.c_str(), ret); + return ret; + } + + LOG_DEBUG( + "RecordRotator::dump(dumper) done: seg=%s, data_size=%zu, padding=%zu", + seg_id.c_str(), data_size, padding_size); + return 0; +} + +int RecordRotator::open(IndexStorage::Pointer storage, + const std::string &seg_id) { + if (!storage) { + LOG_ERROR("RecordRotator::open: null storage"); + return IndexError_InvalidArgument; + } + + auto segment = storage->get(seg_id); + if (!segment) { + LOG_ERROR("RecordRotator::open: segment '%s' not found", seg_id.c_str()); + return IndexError_InvalidFormat; + } + + // Read the rotator data from the segment (header + blob) + const size_t data_size = segment->data_size(); + if (data_size <= Impl::kHeaderSize) { + LOG_ERROR("RecordRotator::open: data too small (%zu bytes)", data_size); + return IndexError_InvalidFormat; + } + + IndexStorage::MemoryBlock block; + size_t read_size = segment->read(0, block, data_size); + if (read_size != data_size) { + LOG_ERROR("RecordRotator::open: read failed, read=%zu, expected=%zu", + read_size, data_size); + return IndexError_InvalidFormat; + } + + // Verify CRC if available (covers header + blob) + uint32_t expected_crc = segment->data_crc(); + if (expected_crc != 0) { + uint32_t actual_crc = ailego::Crc32c::Hash(block.data(), data_size, 0); + if (actual_crc != expected_crc) { + LOG_ERROR( + "RecordRotator::open: CRC mismatch, expected=0x%08x, actual=0x%08x", + expected_crc, actual_crc); + return IndexError_InvalidFormat; + } + } + + // Parse self-describing header (reserved field is ignored) + const char *raw = reinterpret_cast(block.data()); + Impl::Header header; + header.read_from(raw); + + impl_->type = static_cast(header.type); + impl_->dimension = static_cast(header.origin_dim); + + // Reconstruct the rotator from header info and load blob + if (impl_->type == RecordRotatorType::FhtKac) { + impl_->fht_impl = std::make_unique(); + impl_->fht_impl->flip.resize(4 * impl_->dimension / + FhtKacRotatorImpl::kByteLen); + impl_->fht_impl->trunc_dim = floor_pow2(impl_->dimension); + impl_->fht_impl->fac = + 1.0f / std::sqrt(static_cast(impl_->fht_impl->trunc_dim)); + impl_->fht_impl->load(raw + Impl::kHeaderSize); + } else { + impl_->mat_impl = std::make_unique(); + impl_->mat_impl->matrix.resize(impl_->dimension * impl_->dimension); + impl_->mat_impl->load(raw + Impl::kHeaderSize); + } + + LOG_DEBUG("RecordRotator::open done: seg=%s, dim=%zu, data_size=%zu", + seg_id.c_str(), impl_->dimension, data_size); + + return 0; +} + +int RecordRotator::load(const float *matrix, size_t dimension) { + if (!matrix) { + LOG_ERROR("RecordRotator::load: null matrix"); + return IndexError_InvalidArgument; + } + if (dimension == 0) { + LOG_ERROR("RecordRotator::load: invalid dim %zu", dimension); + return IndexError_InvalidArgument; + } + + impl_->dimension = dimension; + impl_->type = RecordRotatorType::Matrix; + impl_->mat_impl = std::make_unique(); + impl_->mat_impl->matrix.resize(dimension * dimension); + impl_->mat_impl->load(reinterpret_cast(matrix)); + + LOG_DEBUG("RecordRotator::load done: dim=%zu", dimension); + + return 0; +} + +size_t RecordRotator::dimension() const { + return impl_->dimension; +} + +RecordRotatorType RecordRotator::rotator_type() const { + return impl_->type; +} + +bool RecordRotator::initialized() const { + return impl_->fht_impl != nullptr || impl_->mat_impl != nullptr; +} + +} // namespace core +} // namespace zvec diff --git a/src/core/quantizer/record_rotator.h b/src/core/quantizer/record_rotator.h new file mode 100644 index 000000000..cd60118ed --- /dev/null +++ b/src/core/quantizer/record_rotator.h @@ -0,0 +1,131 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include +#include +#include "zvec/core/framework/index_dumper.h" +#include "zvec/core/framework/index_storage.h" + +namespace zvec { +namespace core { + +//! Segment ID used when dumping/loading the rotator data +inline const std::string RECORD_ROTATOR_SEG_ID{"enable_rotate"}; + +//! Rotator type exposed without rabitqlib dependency +enum class RecordRotatorType : uint8_t { + FhtKac = 0, //!< O(d log d) FHT-based Kac random rotation (default) + Matrix = 1, //!< O(d^2) explicit random matrix rotation +}; + +/*! RecordRotator provides per-vector rotation without external dependencies. + * + * All rotation algorithms are implemented inline (FHT-based Kac walk and + * explicit random matrix), so no rabitqlib headers are required. + * + * Auto-selects the rotation algorithm based on dimension alignment: + * - dimension % 4 == 0 -> FhtKac (O(d log d), with scalar tails) + * - otherwise -> Matrix (O(d^2), no alignment requirement) + * + * Rotation preserves dimension: output size == input size (no padding). + * + * Used by IntegerStreamingConverter/Reformer and CosineConverter/Reformer + * when enable_rotate is true. + */ +class RecordRotator { + public: + RecordRotator(); + ~RecordRotator(); + + //! Move-only (pimpl with unique_ptr) + RecordRotator(RecordRotator &&) noexcept; + RecordRotator &operator=(RecordRotator &&) noexcept; + RecordRotator(const RecordRotator &) = delete; + RecordRotator &operator=(const RecordRotator &) = delete; + + //! Initialize the rotator. + //! Auto-selects FhtKac when dimension is 4-aligned, else falls back to + //! Matrix. The @p rotator_type parameter can force Matrix explicitly. + //! @param dimension vector dimension (input and output size) + //! @param rotator_type rotation algorithm (default: FhtKac, auto-degrades + //! to Matrix when dimension is not 4-aligned) + void init(size_t dimension, + RecordRotatorType rotator_type = RecordRotatorType::FhtKac); + + //! Rotate a single vector + //! @param in input vector of size >= dimension + //! @param out output buffer of size >= dimension + void rotate(const float *in, float *out) const; + + //! Rotate a single vector into a managed buffer + //! @param in input vector of size >= dimension + //! @return vector of size dimension containing rotated result + std::vector rotate(const float *in) const; + + //! Inverse-rotate a single vector (from rotated space back to original) + //! @param in input vector of size >= dimension (rotated vector) + //! @param out output buffer of size >= dimension (original space) + void unrotate(const float *in, float *out) const; + + //! Inverse-rotate a single vector into a managed buffer + //! @param in input vector of size >= dimension (rotated vector) + //! @return vector of size dimension containing inverse-rotated + //! result + std::vector unrotate(const float *in) const; + + //! Return the serialized size of the rotator in bytes (header + blob) + size_t dump_bytes() const; + + //! Dump the rotator to an IndexStorage as a named segment. + //! Same self-describing format as the dumper variant. + int dump(const IndexStorage::Pointer &storage, + const std::string &seg_id = RECORD_ROTATOR_SEG_ID) const; + + //! Dump the rotator to an IndexDumper as a named segment. + //! Format: [Header: type(1B)|origin_dim(4B)|reserved(4B)] [rotation blob] + //! Appends padding for 32-byte alignment. + int dump(const IndexDumper::Pointer &dumper, + const std::string &seg_id = RECORD_ROTATOR_SEG_ID) const; + + //! Open the rotator from an IndexStorage segment (self-describing, no init + //! needed). Parses header to get type/dimension, then reconstructs the + //! rotator. + int open(IndexStorage::Pointer storage, + const std::string &seg_id = RECORD_ROTATOR_SEG_ID); + + //! Load a user-specified rotation matrix. + //! Always uses MatrixRotator internally. + //! @param matrix row-major square matrix of shape dimension x dimension + //! @param dimension vector dimension + int load(const float *matrix, size_t dimension); + + //! Return the vector dimension + size_t dimension() const; + + //! Return the rotator type + RecordRotatorType rotator_type() const; + + //! Check if the rotator is initialized + bool initialized() const; + + private: + struct Impl; + std::unique_ptr impl_; +}; + +} // namespace core +} // namespace zvec diff --git a/src/db/index/column/vector_column/engine_helper.hpp b/src/db/index/column/vector_column/engine_helper.hpp index f7b727b7d..7fa0c16d3 100644 --- a/src/db/index/column/vector_column/engine_helper.hpp +++ b/src/db/index/column/vector_column/engine_helper.hpp @@ -352,6 +352,7 @@ class ProximaEngineHelper { return tl::make_unexpected( Status::InvalidArgument("unsupported quantize type")); } + index_param_builder->WithEnableRotate(db_index_params->enable_rotate()); return index_param_builder; } diff --git a/src/db/index/common/proto_converter.cc b/src/db/index/common/proto_converter.cc index faf0cf0e3..80b4c61ca 100644 --- a/src/db/index/common/proto_converter.cc +++ b/src/db/index/common/proto_converter.cc @@ -18,11 +18,12 @@ namespace zvec { HnswIndexParams::OPtr ProtoConverter::FromPb( const proto::HnswIndexParams ¶ms_pb) { + bool enable_rotate = params_pb.base().quantizer_param().enable_rotate(); auto params = std::make_shared( MetricTypeCodeBook::Get(params_pb.base().metric_type()), params_pb.m(), params_pb.ef_construction(), QuantizeTypeCodeBook::Get(params_pb.base().quantize_type()), - params_pb.use_contiguous_memory()); + params_pb.use_contiguous_memory(), QuantizerParam(enable_rotate)); return params; } @@ -33,6 +34,8 @@ proto::HnswIndexParams ProtoConverter::ToPb(const HnswIndexParams *params) { MetricTypeCodeBook::Get(params->metric_type())); params_pb.mutable_base()->set_quantize_type( QuantizeTypeCodeBook::Get(params->quantize_type())); + params_pb.mutable_base()->mutable_quantizer_param()->set_enable_rotate( + params->quantizer_param().enable_rotate()); params_pb.set_ef_construction(params->ef_construction()); params_pb.set_m(params->m()); params_pb.set_use_contiguous_memory(params->use_contiguous_memory()); @@ -68,9 +71,11 @@ proto::HnswRabitqIndexParams ProtoConverter::ToPb( // FlatIndexParams FlatIndexParams::OPtr ProtoConverter::FromPb( const proto::FlatIndexParams ¶ms_pb) { + bool enable_rotate = params_pb.base().quantizer_param().enable_rotate(); return std::make_shared( MetricTypeCodeBook::Get(params_pb.base().metric_type()), - QuantizeTypeCodeBook::Get(params_pb.base().quantize_type())); + QuantizeTypeCodeBook::Get(params_pb.base().quantize_type()), + QuantizerParam(enable_rotate)); } proto::FlatIndexParams ProtoConverter::ToPb(const FlatIndexParams *params) { @@ -79,16 +84,20 @@ proto::FlatIndexParams ProtoConverter::ToPb(const FlatIndexParams *params) { MetricTypeCodeBook::Get(params->metric_type())); params_pb.mutable_base()->set_quantize_type( QuantizeTypeCodeBook::Get(params->quantize_type())); + params_pb.mutable_base()->mutable_quantizer_param()->set_enable_rotate( + params->quantizer_param().enable_rotate()); return params_pb; } // IVFIndexParams IVFIndexParams::OPtr ProtoConverter::FromPb( const proto::IVFIndexParams ¶ms_pb) { + bool enable_rotate = params_pb.base().quantizer_param().enable_rotate(); return std::make_shared( MetricTypeCodeBook::Get(params_pb.base().metric_type()), params_pb.n_list(), params_pb.n_iters(), params_pb.use_soar(), - QuantizeTypeCodeBook::Get(params_pb.base().quantize_type())); + QuantizeTypeCodeBook::Get(params_pb.base().quantize_type()), + QuantizerParam(enable_rotate)); } proto::IVFIndexParams ProtoConverter::ToPb(const IVFIndexParams *params) { @@ -97,6 +106,8 @@ proto::IVFIndexParams ProtoConverter::ToPb(const IVFIndexParams *params) { MetricTypeCodeBook::Get(params->metric_type())); params_pb.mutable_base()->set_quantize_type( QuantizeTypeCodeBook::Get(params->quantize_type())); + params_pb.mutable_base()->mutable_quantizer_param()->set_enable_rotate( + params->quantizer_param().enable_rotate()); params_pb.set_n_list(params->n_list()); params_pb.set_n_iters(params->n_iters()); params_pb.set_use_soar(params->use_soar()); @@ -106,12 +117,14 @@ proto::IVFIndexParams ProtoConverter::ToPb(const IVFIndexParams *params) { // VamanaIndexParams VamanaIndexParams::OPtr ProtoConverter::FromPb( const proto::VamanaIndexParams ¶ms_pb) { + bool enable_rotate = params_pb.base().quantizer_param().enable_rotate(); return std::make_shared( MetricTypeCodeBook::Get(params_pb.base().metric_type()), params_pb.max_degree(), params_pb.search_list_size(), params_pb.alpha(), params_pb.saturate_graph(), params_pb.use_contiguous_memory(), params_pb.use_id_map(), - QuantizeTypeCodeBook::Get(params_pb.base().quantize_type())); + QuantizeTypeCodeBook::Get(params_pb.base().quantize_type()), + QuantizerParam(enable_rotate)); } proto::VamanaIndexParams ProtoConverter::ToPb(const VamanaIndexParams *params) { @@ -120,6 +133,8 @@ proto::VamanaIndexParams ProtoConverter::ToPb(const VamanaIndexParams *params) { MetricTypeCodeBook::Get(params->metric_type())); params_pb.mutable_base()->set_quantize_type( QuantizeTypeCodeBook::Get(params->quantize_type())); + params_pb.mutable_base()->mutable_quantizer_param()->set_enable_rotate( + params->quantizer_param().enable_rotate()); params_pb.set_max_degree(params->max_degree()); params_pb.set_search_list_size(params->search_list_size()); params_pb.set_alpha(params->alpha()); @@ -147,10 +162,12 @@ proto::InvertIndexParams ProtoConverter::ToPb(const InvertIndexParams *params) { // DiskAnnIndexParams DiskAnnIndexParams::OPtr ProtoConverter::FromPb( const proto::DiskAnnIndexParams ¶ms_pb) { + bool enable_rotate = params_pb.base().quantizer_param().enable_rotate(); return std::make_shared( MetricTypeCodeBook::Get(params_pb.base().metric_type()), params_pb.max_degree(), params_pb.list_size(), params_pb.pq_chunk_num(), - QuantizeTypeCodeBook::Get(params_pb.base().quantize_type())); + QuantizeTypeCodeBook::Get(params_pb.base().quantize_type()), + QuantizerParam(enable_rotate)); } proto::DiskAnnIndexParams ProtoConverter::ToPb( @@ -160,6 +177,8 @@ proto::DiskAnnIndexParams ProtoConverter::ToPb( MetricTypeCodeBook::Get(params->metric_type())); params_pb.mutable_base()->set_quantize_type( QuantizeTypeCodeBook::Get(params->quantize_type())); + params_pb.mutable_base()->mutable_quantizer_param()->set_enable_rotate( + params->quantizer_param().enable_rotate()); params_pb.set_max_degree(params->max_degree()); params_pb.set_list_size(params->list_size()); params_pb.set_pq_chunk_num(params->pq_chunk_num()); diff --git a/src/db/index/segment/segment.cc b/src/db/index/segment/segment.cc index 2e334c3fc..9b613ae97 100644 --- a/src/db/index/segment/segment.cc +++ b/src/db/index/segment/segment.cc @@ -4028,7 +4028,8 @@ Status SegmentImpl::load_vector_index_blocks() { if (!segment_meta_->vector_indexed(column)) { new_field_params.set_index_params(MakeDefaultQuantVectorIndexParams( vector_index_params->metric_type(), - vector_index_params->quantize_type())); + vector_index_params->quantize_type(), + vector_index_params->quantizer_param())); } } @@ -4163,7 +4164,8 @@ Status SegmentImpl::init_memory_components() { block_id = allocate_block_id(); FieldSchema normal_quant_field(*field); normal_quant_field.set_index_params(MakeDefaultQuantVectorIndexParams( - index_params->metric_type(), index_params->quantize_type())); + index_params->metric_type(), index_params->quantize_type(), + index_params->quantizer_param())); auto quant_vector_indexer = create_vector_indexer( field->name(), normal_quant_field, block_id, true); diff --git a/src/db/proto/zvec.proto b/src/db/proto/zvec.proto index ad6cfb158..f2c18f5ad 100644 --- a/src/db/proto/zvec.proto +++ b/src/db/proto/zvec.proto @@ -87,9 +87,19 @@ message InvertIndexParams { bool enable_range_optimization = 1; }; +// Quantizer-related parameters for vector indexes. +// Designed for future extensibility. +message QuantizerParam { + // When enabled, vectors are rotated before INT8 quantization to reduce + // quantization error. Only effective with quantize_type=INT8. + bool enable_rotate = 1; +}; + message BaseIndexParams { MetricType metric_type = 1; QuantizeType quantize_type = 2; + // Quantizer parameters (enable_rotate, etc.) + QuantizerParam quantizer_param = 4; }; message HnswIndexParams { diff --git a/src/include/zvec/c_api.h b/src/include/zvec/c_api.h index d02335cb3..3f3e38638 100644 --- a/src/include/zvec/c_api.h +++ b/src/include/zvec/c_api.h @@ -961,6 +961,30 @@ ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_index_params_set_quantize_type( ZVEC_EXPORT zvec_quantize_type_t ZVEC_CALL zvec_index_params_get_quantize_type(const zvec_index_params_t *params); +/** + * @brief Set enable_rotate for quantizer (only effective with INT8 quantize + * type) + * + * When enabled, vectors are randomly rotated before INT8 quantization to + * reduce quantization error. The rotation matrix is stored with the index + * and automatically applied to query vectors at search time. + * + * @param params Index parameters (must be vector index type) + * @param enable_rotate Whether to enable random rotation before quantization + * @return ZVEC_OK on success, error code on failure + */ +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_index_params_set_quantizer_enable_rotate(zvec_index_params_t *params, + bool enable_rotate); + +/** + * @brief Get enable_rotate setting from quantizer parameters + * @param params Index parameters (must not be NULL) + * @return true if rotation is enabled, false otherwise (default) + */ +ZVEC_EXPORT bool ZVEC_CALL zvec_index_params_get_quantizer_enable_rotate( + const zvec_index_params_t *params); + /** * @brief Set HNSW specific parameters * @param params Index parameters (must be HNSW type) diff --git a/src/include/zvec/core/framework/index_converter.h b/src/include/zvec/core/framework/index_converter.h index 53ac1c7a2..4dc26468f 100644 --- a/src/include/zvec/core/framework/index_converter.h +++ b/src/include/zvec/core/framework/index_converter.h @@ -18,6 +18,7 @@ #include #include #include +#include #include "zvec/core/framework/index_reformer.h" namespace zvec { @@ -196,6 +197,13 @@ class IndexConverter : public IndexModule { //! Dump index into storage virtual int dump(const IndexDumper::Pointer &dumper) = 0; + //! Dump converter state (e.g. rotator) to IndexStorage for streaming build. + //! Default is no-op; override in subclasses that need storage persistence. + virtual int dump_to_storage(const IndexStorage::Pointer &storage) { + (void)storage; + return 0; + } + //! Retrieve statistics virtual const Stats &stats(void) const = 0; diff --git a/src/include/zvec/core/interface/index_param.h b/src/include/zvec/core/interface/index_param.h index 186c160f6..5d4e8a206 100644 --- a/src/include/zvec/core/interface/index_param.h +++ b/src/include/zvec/core/interface/index_param.h @@ -122,12 +122,17 @@ struct QuantizerParam : public SerializableBase { QuantizerType type = QuantizerType::kNone; int num_subquantizers = 8; // M int num_bits = 8; // bits per subquantizer + bool enable_rotate = + false; // rotate vectors before quantization to reduce error // Constructors // QuantizerParam() = default; QuantizerParam(QuantizerType t = QuantizerType::kNone, int subquantizers = 8, - int bits = 8) - : type(t), num_subquantizers(subquantizers), num_bits(bits) {} + int bits = 8, bool rotate = false) + : type(t), + num_subquantizers(subquantizers), + num_bits(bits), + enable_rotate(rotate) {} protected: diff --git a/src/include/zvec/core/interface/index_param_builders.h b/src/include/zvec/core/interface/index_param_builders.h index 8b009c135..971fc615a 100644 --- a/src/include/zvec/core/interface/index_param_builders.h +++ b/src/include/zvec/core/interface/index_param_builders.h @@ -88,6 +88,11 @@ class BaseIndexParamBuilder { // : public return static_cast(*this); } + ActualIndexParamBuilderType &WithEnableRotate(bool enable_rotate) { + param->quantizer_param.enable_rotate = enable_rotate; + return static_cast(*this); + } + ActualIndexParamBuilderType &WithUseExternalVector(bool use_external_vector) { param->use_external_vector = use_external_vector; return static_cast(*this); diff --git a/src/include/zvec/db/index_params.h b/src/include/zvec/db/index_params.h index c19cf8028..a4c2654d9 100644 --- a/src/include/zvec/db/index_params.h +++ b/src/include/zvec/db/index_params.h @@ -118,16 +118,50 @@ class InvertIndexParams : public IndexParams { bool enable_extended_wildcard_{false}; }; +/* + * Quantizer parameters for vector indexes. + * Encapsulates quantization-related settings such as enable_rotate. + * Designed for future extensibility (e.g., num_bits, calibration_size). + */ +class QuantizerParam { + public: + QuantizerParam() = default; + explicit QuantizerParam(bool enable_rotate) : enable_rotate_(enable_rotate) {} + + bool enable_rotate() const { + return enable_rotate_; + } + + void set_enable_rotate(bool v) { + enable_rotate_ = v; + } + + bool operator==(const QuantizerParam &other) const { + return enable_rotate_ == other.enable_rotate_; + } + + bool operator!=(const QuantizerParam &other) const { + return !(*this == other); + } + + private: + // When enabled, vectors are rotated before INT8 quantization to reduce + // quantization error. Only effective with quantize_type=INT8. + bool enable_rotate_{false}; +}; + /* * Column index params */ class VectorIndexParams : public IndexParams { public: VectorIndexParams(IndexType type, MetricType metric_type, - QuantizeType quantize_type = QuantizeType::UNDEFINED) + QuantizeType quantize_type = QuantizeType::UNDEFINED, + QuantizerParam quantizer_param = {}) : IndexParams(type), metric_type_(metric_type), - quantize_type_(quantize_type) {} + quantize_type_(quantize_type), + quantizer_param_(quantizer_param) {} ~VectorIndexParams() override = default; @@ -151,9 +185,23 @@ class VectorIndexParams : public IndexParams { quantize_type_ = quantize_type; } + const QuantizerParam &quantizer_param() const { + return quantizer_param_; + } + + void set_quantizer_param(const QuantizerParam &quantizer_param) { + quantizer_param_ = quantizer_param; + } + + // Convenience getter for internal use (engine_helper, segment, etc.) + bool enable_rotate() const { + return quantizer_param_.enable_rotate(); + } + protected: MetricType metric_type_; QuantizeType quantize_type_; + QuantizerParam quantizer_param_; }; /* @@ -165,8 +213,9 @@ class HnswIndexParams : public VectorIndexParams { MetricType metric_type, int m = core_interface::kDefaultHnswNeighborCnt, int ef_construction = core_interface::kDefaultHnswEfConstruction, QuantizeType quantize_type = QuantizeType::UNDEFINED, - bool use_contiguous_memory = false) - : VectorIndexParams(IndexType::HNSW, metric_type, quantize_type), + bool use_contiguous_memory = false, QuantizerParam quantizer_param = {}) + : VectorIndexParams(IndexType::HNSW, metric_type, quantize_type, + quantizer_param), m_(m), ef_construction_(ef_construction), use_contiguous_memory_(use_contiguous_memory) {} @@ -175,9 +224,9 @@ class HnswIndexParams : public VectorIndexParams { public: Ptr clone() const override { - return std::make_shared(metric_type_, m_, ef_construction_, - quantize_type_, - use_contiguous_memory_); + return std::make_shared( + metric_type_, m_, ef_construction_, quantize_type_, + use_contiguous_memory_, quantizer_param_); } std::string to_string() const override { @@ -186,7 +235,8 @@ class HnswIndexParams : public VectorIndexParams { std::ostringstream oss; oss << base_str << ",m:" << m_ << ",ef_construction:" << ef_construction_ << ",use_contiguous_memory:" - << (use_contiguous_memory_ ? "true" : "false") << "}"; + << (use_contiguous_memory_ ? "true" : "false") << ",enable_rotate:" + << (quantizer_param_.enable_rotate() ? "true" : "false") << "}"; return oss.str(); } @@ -200,7 +250,9 @@ class HnswIndexParams : public VectorIndexParams { quantize_type() == static_cast(other).quantize_type() && use_contiguous_memory_ == static_cast(other) - .use_contiguous_memory_; + .use_contiguous_memory_ && + quantizer_param_ == + static_cast(other).quantizer_param_; } void set_m(int m) { @@ -348,21 +400,25 @@ class HnswRabitqIndexParams : public VectorIndexParams { class FlatIndexParams : public VectorIndexParams { public: FlatIndexParams(MetricType metric_type, - QuantizeType quantize_type = QuantizeType::UNDEFINED) - : VectorIndexParams(IndexType::FLAT, metric_type, quantize_type) {} + QuantizeType quantize_type = QuantizeType::UNDEFINED, + QuantizerParam quantizer_param = {}) + : VectorIndexParams(IndexType::FLAT, metric_type, quantize_type, + quantizer_param) {} using OPtr = std::shared_ptr; public: Ptr clone() const override { - return std::make_shared(metric_type_, quantize_type_); + return std::make_shared(metric_type_, quantize_type_, + quantizer_param_); } std::string to_string() const override { auto base_str = vector_index_params_to_string("FlatIndexParams", metric_type_, quantize_type_); std::ostringstream oss; - oss << base_str << "}"; + oss << base_str << ",enable_rotate:" + << (quantizer_param_.enable_rotate() ? "true" : "false") << "}"; return oss.str(); } @@ -371,7 +427,9 @@ class FlatIndexParams : public VectorIndexParams { metric_type() == static_cast(other).metric_type() && quantize_type() == - static_cast(other).quantize_type(); + static_cast(other).quantize_type() && + quantizer_param() == + static_cast(other).quantizer_param(); } }; @@ -383,16 +441,19 @@ inline FlatIndexParams MakeDefaultVectorIndexParams(MetricType metric_type) { } inline FlatIndexParams MakeDefaultQuantVectorIndexParams( - MetricType metric_type, QuantizeType quantize_type) { - return FlatIndexParams(metric_type, quantize_type); + MetricType metric_type, QuantizeType quantize_type, + QuantizerParam quantizer_param = {}) { + return FlatIndexParams(metric_type, quantize_type, quantizer_param); } class IVFIndexParams : public VectorIndexParams { public: IVFIndexParams(MetricType metric_type, int n_list = 1024, int n_iters = 10, bool use_soar = false, - QuantizeType quantize_type = QuantizeType::UNDEFINED) - : VectorIndexParams(IndexType::IVF, metric_type, quantize_type), + QuantizeType quantize_type = QuantizeType::UNDEFINED, + QuantizerParam quantizer_param = {}) + : VectorIndexParams(IndexType::IVF, metric_type, quantize_type, + quantizer_param), n_list_(n_list), n_iters_(n_iters), use_soar_(use_soar) {} @@ -402,14 +463,17 @@ class IVFIndexParams : public VectorIndexParams { public: Ptr clone() const override { return std::make_shared(metric_type_, n_list_, n_iters_, - use_soar_, quantize_type_); + use_soar_, quantize_type_, + quantizer_param_); } std::string to_string() const override { auto base_str = vector_index_params_to_string("IVFIndexParams", metric_type_, quantize_type_); std::ostringstream oss; - oss << base_str << ",n_list:" << n_list_ << ",n_iters:" << n_iters_ << "}"; + oss << base_str << ",n_list:" << n_list_ << ",n_iters:" << n_iters_ + << ",enable_rotate:" + << (quantizer_param_.enable_rotate() ? "true" : "false") << "}"; return oss.str(); } @@ -445,7 +509,9 @@ class IVFIndexParams : public VectorIndexParams { n_iters_ == static_cast(other).n_iters_ && use_soar_ == static_cast(other).use_soar_ && quantize_type() == - static_cast(other).quantize_type(); + static_cast(other).quantize_type() && + quantizer_param_ == + static_cast(other).quantizer_param_; } private: @@ -458,8 +524,10 @@ class DiskAnnIndexParams : public VectorIndexParams { public: DiskAnnIndexParams(MetricType metric_type, int max_degree = 100, int list_size = 50, int pq_chunk_num = 0, - QuantizeType quantize_type = QuantizeType::UNDEFINED) - : VectorIndexParams(IndexType::DISKANN, metric_type, quantize_type), + QuantizeType quantize_type = QuantizeType::UNDEFINED, + QuantizerParam quantizer_param = {}) + : VectorIndexParams(IndexType::DISKANN, metric_type, quantize_type, + quantizer_param), max_degree_{max_degree}, list_size_{list_size}, pq_chunk_num_{pq_chunk_num} {} @@ -469,7 +537,8 @@ class DiskAnnIndexParams : public VectorIndexParams { public: Ptr clone() const override { return std::make_shared( - metric_type_, max_degree_, list_size_, pq_chunk_num_, quantize_type_); + metric_type_, max_degree_, list_size_, pq_chunk_num_, quantize_type_, + quantizer_param_); } std::string to_string() const override { @@ -478,7 +547,8 @@ class DiskAnnIndexParams : public VectorIndexParams { std::ostringstream oss; oss << base_str << ",max_degree:" << max_degree_ << ",list_size:" << list_size_ << ", pq_chunk_num:" << pq_chunk_num_ - << "}"; + << ",enable_rotate:" + << (quantizer_param_.enable_rotate() ? "true" : "false") << "}"; return oss.str(); } @@ -517,7 +587,9 @@ class DiskAnnIndexParams : public VectorIndexParams { pq_chunk_num_ == static_cast(other).pq_chunk_num_ && quantize_type() == - static_cast(other).quantize_type(); + static_cast(other).quantize_type() && + quantizer_param_ == + static_cast(other).quantizer_param_; } private: @@ -538,8 +610,10 @@ class VamanaIndexParams : public VectorIndexParams { float alpha = core_interface::kDefaultVamanaAlpha, bool saturate_graph = core_interface::kDefaultVamanaSaturateGraph, bool use_contiguous_memory = false, bool use_id_map = false, - QuantizeType quantize_type = QuantizeType::UNDEFINED) - : VectorIndexParams(IndexType::VAMANA, metric_type, quantize_type), + QuantizeType quantize_type = QuantizeType::UNDEFINED, + QuantizerParam quantizer_param = {}) + : VectorIndexParams(IndexType::VAMANA, metric_type, quantize_type, + quantizer_param), max_degree_(max_degree), search_list_size_(search_list_size), alpha_(alpha), @@ -553,7 +627,7 @@ class VamanaIndexParams : public VectorIndexParams { Ptr clone() const override { return std::make_shared( metric_type_, max_degree_, search_list_size_, alpha_, saturate_graph_, - use_contiguous_memory_, use_id_map_, quantize_type_); + use_contiguous_memory_, use_id_map_, quantize_type_, quantizer_param_); } std::string to_string() const override { @@ -565,7 +639,9 @@ class VamanaIndexParams : public VectorIndexParams { << ",saturate_graph:" << (saturate_graph_ ? "true" : "false") << ",use_contiguous_memory:" << (use_contiguous_memory_ ? "true" : "false") - << ",use_id_map:" << (use_id_map_ ? "true" : "false") << "}"; + << ",use_id_map:" << (use_id_map_ ? "true" : "false") + << ",enable_rotate:" + << (quantizer_param_.enable_rotate() ? "true" : "false") << "}"; return oss.str(); } @@ -580,7 +656,8 @@ class VamanaIndexParams : public VectorIndexParams { search_list_size_ == rhs.search_list_size_ && alpha_ == rhs.alpha_ && saturate_graph_ == rhs.saturate_graph_ && use_contiguous_memory_ == rhs.use_contiguous_memory_ && - use_id_map_ == rhs.use_id_map_; + use_id_map_ == rhs.use_id_map_ && + quantizer_param_ == rhs.quantizer_param_; } int max_degree() const { diff --git a/tests/c/c_api_test.c b/tests/c/c_api_test.c index 8670ff845..746075d68 100644 --- a/tests/c/c_api_test.c +++ b/tests/c/c_api_test.c @@ -3491,6 +3491,61 @@ void test_index_params_functions(void) { TEST_END(); } +void test_quantizer_enable_rotate(void) { + TEST_START(); + + // Test 1: set enable_rotate=true on HNSW params and verify + zvec_index_params_t *hnsw_params = + zvec_index_params_create(ZVEC_INDEX_TYPE_HNSW); + TEST_ASSERT(hnsw_params != NULL); + + // Default should be false + TEST_ASSERT(zvec_index_params_get_quantizer_enable_rotate(hnsw_params) == + false); + + // Set to true and verify + zvec_error_code_t err = + zvec_index_params_set_quantizer_enable_rotate(hnsw_params, true); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(zvec_index_params_get_quantizer_enable_rotate(hnsw_params) == + true); + + // Set back to false and verify + err = zvec_index_params_set_quantizer_enable_rotate(hnsw_params, false); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(zvec_index_params_get_quantizer_enable_rotate(hnsw_params) == + false); + + zvec_index_params_destroy(hnsw_params); + + // Test 2: set enable_rotate on FLAT index params (also a vector index) + zvec_index_params_t *flat_params = + zvec_index_params_create(ZVEC_INDEX_TYPE_FLAT); + TEST_ASSERT(flat_params != NULL); + err = zvec_index_params_set_quantizer_enable_rotate(flat_params, true); + TEST_ASSERT(err == ZVEC_OK); + TEST_ASSERT(zvec_index_params_get_quantizer_enable_rotate(flat_params) == + true); + zvec_index_params_destroy(flat_params); + + // Test 3: set enable_rotate on non-vector index (INVERT) should fail + zvec_index_params_t *invert_params = + zvec_index_params_create(ZVEC_INDEX_TYPE_INVERT); + TEST_ASSERT(invert_params != NULL); + err = zvec_index_params_set_quantizer_enable_rotate(invert_params, true); + TEST_ASSERT(err != ZVEC_OK); + zvec_index_params_destroy(invert_params); + + // Test 4: NULL params should return false for getter + TEST_ASSERT(zvec_index_params_get_quantizer_enable_rotate(NULL) == false); + + // Test 5: NULL params should return error for setter + err = zvec_index_params_set_quantizer_enable_rotate(NULL, true); + TEST_ASSERT(err != ZVEC_OK); + + TEST_END(); +} + void test_index_params_api_functions(void) { TEST_START(); @@ -5953,6 +6008,7 @@ int main(void) { // Index tests test_index_params(); test_index_params_functions(); + test_quantizer_enable_rotate(); test_index_params_api_functions(); test_index_creation_and_management(); diff --git a/tests/core/quantizer/integer_quantizer_reformer_test.cc b/tests/core/quantizer/integer_quantizer_reformer_test.cc index 21967bb23..104c3c28d 100644 --- a/tests/core/quantizer/integer_quantizer_reformer_test.cc +++ b/tests/core/quantizer/integer_quantizer_reformer_test.cc @@ -12,10 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include +#include #include #include +#include "quantizer/record_rotator.h" #include "zvec/core/framework/index_factory.h" #include "zvec/core/framework/index_holder.h" @@ -821,3 +824,133 @@ TEST(IntegerReformer, Int4InitConverterWithTrainedParams) { EXPECT_EQ(buffer, buffer2); } } + +// Test FhtKac rotator (dim=64, power of 2, hot path) +TEST(RecordRotatorTest, RotateUnrotateFhtKac) { + const size_t dim = 64; + RecordRotator rotator; + rotator.init(dim); + EXPECT_EQ(rotator.rotator_type(), RecordRotatorType::FhtKac); + + std::mt19937 gen(42); + std::uniform_real_distribution dist(-10.0f, 10.0f); + + std::vector original(dim); + for (size_t j = 0; j < dim; ++j) original[j] = dist(gen); + + std::vector rotated(dim); + rotator.rotate(original.data(), rotated.data()); + + std::vector recovered(dim); + rotator.unrotate(rotated.data(), recovered.data()); + + float max_err = 0.0f; + for (size_t j = 0; j < dim; ++j) + max_err = std::max(max_err, std::abs(recovered[j] - original[j])); + std::cout << "FhtKac (dim=64) max error: " << max_err << std::endl; + EXPECT_LT(max_err, 1e-3f); +} + +// Test Matrix rotator (dim=15, odd, not 4-aligned, auto-fallback) +TEST(RecordRotatorTest, RotateUnrotateMatrix) { + const size_t dim = 15; + RecordRotator rotator; + rotator.init(dim); + EXPECT_EQ(rotator.rotator_type(), RecordRotatorType::Matrix); + + std::mt19937 gen(42); + std::uniform_real_distribution dist(-10.0f, 10.0f); + + std::vector original(dim); + for (size_t j = 0; j < dim; ++j) original[j] = dist(gen); + + std::vector rotated(dim); + rotator.rotate(original.data(), rotated.data()); + + std::vector recovered(dim); + rotator.unrotate(rotated.data(), recovered.data()); + + float max_err = 0.0f; + for (size_t j = 0; j < dim; ++j) + max_err = std::max(max_err, std::abs(recovered[j] - original[j])); + std::cout << "Matrix (dim=15) max error: " << max_err << std::endl; + EXPECT_LT(max_err, 1e-3f); +} + +// Test FhtKac rotator (dim=100, 4-aligned but not 16/32/64-aligned) +TEST(RecordRotatorTest, RotateUnrotateFhtKac_Dim100) { + const size_t dim = 100; + RecordRotator rotator; + rotator.init(dim); + EXPECT_EQ(rotator.rotator_type(), RecordRotatorType::FhtKac); + + std::mt19937 gen(42); + std::uniform_real_distribution dist(-10.0f, 10.0f); + + std::vector original(dim); + for (size_t j = 0; j < dim; ++j) original[j] = dist(gen); + + std::vector rotated(dim); + rotator.rotate(original.data(), rotated.data()); + + std::vector recovered(dim); + rotator.unrotate(rotated.data(), recovered.data()); + + float max_err = 0.0f; + for (size_t j = 0; j < dim; ++j) + max_err = std::max(max_err, std::abs(recovered[j] - original[j])); + std::cout << "FhtKac (dim=100) max error: " << max_err << std::endl; + EXPECT_LT(max_err, 1e-3f); +} + +// Test FhtKac rotator (dim=200, 4-aligned, non-power-of-2 kacs_walk path) +TEST(RecordRotatorTest, RotateUnrotateFhtKac_Dim200) { + const size_t dim = 200; + RecordRotator rotator; + rotator.init(dim); + EXPECT_EQ(rotator.rotator_type(), RecordRotatorType::FhtKac); + + std::mt19937 gen(42); + std::uniform_real_distribution dist(-10.0f, 10.0f); + + std::vector original(dim); + for (size_t j = 0; j < dim; ++j) original[j] = dist(gen); + + std::vector rotated(dim); + rotator.rotate(original.data(), rotated.data()); + + std::vector recovered(dim); + rotator.unrotate(rotated.data(), recovered.data()); + + float max_err = 0.0f; + for (size_t j = 0; j < dim; ++j) + max_err = std::max(max_err, std::abs(recovered[j] - original[j])); + std::cout << "FhtKac (dim=200) max error: " << max_err << std::endl; + EXPECT_LT(max_err, 1e-3f); +} + +// Test FhtKac rotator (dim=96, 32-aligned but not 64-aligned, kacs_walk path) +TEST(RecordRotatorTest, RotateUnrotateFhtKac_Dim96) { + const size_t dim = 96; + RecordRotator rotator; + rotator.init(dim); + EXPECT_EQ(rotator.rotator_type(), RecordRotatorType::FhtKac); + + std::mt19937 gen(42); + std::uniform_real_distribution dist(-10.0f, 10.0f); + + std::vector original(dim); + for (size_t j = 0; j < dim; ++j) original[j] = dist(gen); + + std::vector rotated(dim); + rotator.rotate(original.data(), rotated.data()); + + std::vector recovered(dim); + rotator.unrotate(rotated.data(), recovered.data()); + + float max_err = 0.0f; + for (size_t j = 0; j < dim; ++j) + max_err = std::max(max_err, std::abs(recovered[j] - original[j])); + std::cout << "FhtKac (dim=96) max error: " << max_err << std::endl; + EXPECT_LT(max_err, 1e-3f); +} diff --git a/tests/db/index/common/db_proto_converter_test.cc b/tests/db/index/common/db_proto_converter_test.cc index dff93e9dd..9c71c3c89 100644 --- a/tests/db/index/common/db_proto_converter_test.cc +++ b/tests/db/index/common/db_proto_converter_test.cc @@ -470,4 +470,81 @@ TEST(ConverterTest, SegmentMetaWithEmptyFields) { EXPECT_EQ(pb_result.persisted_blocks_size(), 0); EXPECT_FALSE(pb_result.has_writing_forward_block()); EXPECT_EQ(pb_result.indexed_vector_fields_size(), 0); +} + +// ==================== enable_rotate roundtrip tests ==================== + +TEST(ConverterTest, HnswIndexParamsWithEnableRotate) { + // C++ -> PB -> C++ roundtrip with enable_rotate = true + HnswIndexParams original(MetricType::COSINE, 16, 200, QuantizeType::INT8, + false, QuantizerParam(true)); + EXPECT_TRUE(original.quantizer_param().enable_rotate()); + + auto pb = ProtoConverter::ToPb(&original); + EXPECT_TRUE(pb.base().quantizer_param().enable_rotate()); + + auto restored = ProtoConverter::FromPb(pb); + ASSERT_NE(restored, nullptr); + EXPECT_TRUE(restored->quantizer_param().enable_rotate()); + EXPECT_TRUE(restored->enable_rotate()); // convenience getter + EXPECT_EQ(restored->metric_type(), MetricType::COSINE); + EXPECT_EQ(restored->m(), 16); + EXPECT_EQ(restored->ef_construction(), 200); + EXPECT_EQ(restored->quantize_type(), QuantizeType::INT8); + + // C++ -> PB -> C++ roundtrip with enable_rotate = false + HnswIndexParams original_no_rot(MetricType::L2, 32, 100, QuantizeType::FP16); + auto pb2 = ProtoConverter::ToPb(&original_no_rot); + EXPECT_FALSE(pb2.base().quantizer_param().enable_rotate()); + auto restored2 = ProtoConverter::FromPb(pb2); + ASSERT_NE(restored2, nullptr); + EXPECT_FALSE(restored2->quantizer_param().enable_rotate()); +} + +TEST(ConverterTest, FlatIndexParamsWithEnableRotate) { + FlatIndexParams original(MetricType::IP, QuantizeType::INT8, + QuantizerParam(true)); + EXPECT_TRUE(original.quantizer_param().enable_rotate()); + + auto pb = ProtoConverter::ToPb(&original); + EXPECT_TRUE(pb.base().quantizer_param().enable_rotate()); + + auto restored = ProtoConverter::FromPb(pb); + ASSERT_NE(restored, nullptr); + EXPECT_TRUE(restored->quantizer_param().enable_rotate()); + EXPECT_EQ(restored->metric_type(), MetricType::IP); + EXPECT_EQ(restored->quantize_type(), QuantizeType::INT8); + + // enable_rotate = false + FlatIndexParams original_no_rot(MetricType::L2, QuantizeType::FP16); + auto pb2 = ProtoConverter::ToPb(&original_no_rot); + EXPECT_FALSE(pb2.base().quantizer_param().enable_rotate()); + auto restored2 = ProtoConverter::FromPb(pb2); + EXPECT_FALSE(restored2->quantizer_param().enable_rotate()); +} + +TEST(ConverterTest, IVFIndexParamsWithEnableRotate) { + IVFIndexParams original(MetricType::COSINE, 256, 20, true, QuantizeType::INT8, + QuantizerParam(true)); + EXPECT_TRUE(original.quantizer_param().enable_rotate()); + + auto pb = ProtoConverter::ToPb(&original); + EXPECT_TRUE(pb.base().quantizer_param().enable_rotate()); + + auto restored = ProtoConverter::FromPb(pb); + ASSERT_NE(restored, nullptr); + EXPECT_TRUE(restored->quantizer_param().enable_rotate()); + EXPECT_EQ(restored->metric_type(), MetricType::COSINE); + EXPECT_EQ(restored->n_list(), 256); + EXPECT_EQ(restored->n_iters(), 20); + EXPECT_TRUE(restored->use_soar()); + EXPECT_EQ(restored->quantize_type(), QuantizeType::INT8); + + // enable_rotate = false + IVFIndexParams original_no_rot(MetricType::L2, 128, 10, false, + QuantizeType::FP16); + auto pb2 = ProtoConverter::ToPb(&original_no_rot); + EXPECT_FALSE(pb2.base().quantizer_param().enable_rotate()); + auto restored2 = ProtoConverter::FromPb(pb2); + EXPECT_FALSE(restored2->quantizer_param().enable_rotate()); } \ No newline at end of file diff --git a/tests/db/index/common/index_params_test.cc b/tests/db/index/common/index_params_test.cc index af67e7398..d5a85aeb9 100644 --- a/tests/db/index/common/index_params_test.cc +++ b/tests/db/index/common/index_params_test.cc @@ -186,4 +186,96 @@ TEST(IndexParamsTest, DynamicPointerCast) { IndexParams &base_ref = *base_ptr; auto &hnsw_ref = dynamic_cast(base_ref); EXPECT_EQ(hnsw_ref.type(), IndexType::HNSW); +} + +// ==================== QuantizerParam tests ==================== + +TEST(IndexParamsTest, QuantizerParamBasic) { + // Default constructor: enable_rotate should be false + QuantizerParam qp_default; + EXPECT_FALSE(qp_default.enable_rotate()); + + // Constructor with true + QuantizerParam qp_true(true); + EXPECT_TRUE(qp_true.enable_rotate()); + + // Constructor with false + QuantizerParam qp_false(false); + EXPECT_FALSE(qp_false.enable_rotate()); + + // Setter + qp_default.set_enable_rotate(true); + EXPECT_TRUE(qp_default.enable_rotate()); + qp_default.set_enable_rotate(false); + EXPECT_FALSE(qp_default.enable_rotate()); + + // Equality + EXPECT_TRUE(qp_true == QuantizerParam(true)); + EXPECT_TRUE(qp_false == QuantizerParam(false)); + EXPECT_FALSE(qp_true == qp_false); + + // Inequality + EXPECT_TRUE(qp_true != qp_false); + EXPECT_FALSE(qp_true != QuantizerParam(true)); +} + +TEST(IndexParamsTest, QuantizerParamWithVectorIndex) { + // HnswIndexParams + { + HnswIndexParams params(MetricType::COSINE, 16, 100, QuantizeType::INT8); + EXPECT_FALSE(params.quantizer_param().enable_rotate()); + EXPECT_FALSE(params.enable_rotate()); // convenience getter + + params.set_quantizer_param(QuantizerParam(true)); + EXPECT_TRUE(params.quantizer_param().enable_rotate()); + EXPECT_TRUE(params.enable_rotate()); + + // Clone preserves quantizer_param + auto cloned = params.clone(); + auto *cloned_hnsw = dynamic_cast(cloned.get()); + ASSERT_NE(cloned_hnsw, nullptr); + EXPECT_TRUE(cloned_hnsw->quantizer_param().enable_rotate()); + EXPECT_TRUE(*cloned == params); + + // Equality: different enable_rotate -> not equal + HnswIndexParams params2(MetricType::COSINE, 16, 100, QuantizeType::INT8); + params2.set_quantizer_param(QuantizerParam(false)); + EXPECT_FALSE(params == params2); + } + + // FlatIndexParams + { + FlatIndexParams params(MetricType::L2, QuantizeType::INT8); + EXPECT_FALSE(params.quantizer_param().enable_rotate()); + + params.set_quantizer_param(QuantizerParam(true)); + EXPECT_TRUE(params.quantizer_param().enable_rotate()); + EXPECT_TRUE(params.enable_rotate()); + + auto cloned = params.clone(); + auto *cloned_flat = dynamic_cast(cloned.get()); + ASSERT_NE(cloned_flat, nullptr); + EXPECT_TRUE(cloned_flat->quantizer_param().enable_rotate()); + + FlatIndexParams params2(MetricType::L2, QuantizeType::INT8); + EXPECT_FALSE(params == params2); + } + + // IVFIndexParams + { + IVFIndexParams params(MetricType::IP, 128, 10, false, QuantizeType::INT8); + EXPECT_FALSE(params.quantizer_param().enable_rotate()); + + params.set_quantizer_param(QuantizerParam(true)); + EXPECT_TRUE(params.quantizer_param().enable_rotate()); + EXPECT_TRUE(params.enable_rotate()); + + auto cloned = params.clone(); + auto *cloned_ivf = dynamic_cast(cloned.get()); + ASSERT_NE(cloned_ivf, nullptr); + EXPECT_TRUE(cloned_ivf->quantizer_param().enable_rotate()); + + IVFIndexParams params2(MetricType::IP, 128, 10, false, QuantizeType::INT8); + EXPECT_FALSE(params == params2); + } } \ No newline at end of file diff --git a/tools/core/local_builder.cc b/tools/core/local_builder.cc index 52ae8321d..7d3b7bf0a 100644 --- a/tools/core/local_builder.cc +++ b/tools/core/local_builder.cc @@ -35,7 +35,6 @@ #include "zvec/core/framework/index_reformer.h" #include "zvec/core/framework/index_streamer.h" #include "index_meta_helper.h" -#include "meta_segment_common.h" #include "vecs_index_holder.h" #ifdef __clang__ @@ -206,10 +205,6 @@ bool check_config(YAML::Node &config_root) { return false; } } - if (!common["DumpPath"]) { - LOG_ERROR("Can not find [DumpPath] in config"); - return false; - } if (!config_root["BuilderParams"]) { LOG_ERROR("Can not find [BuilderParams] in config"); return false; @@ -217,75 +212,6 @@ bool check_config(YAML::Node &config_root) { return true; } -static inline size_t AlignSize(size_t size) { - return (size + 0x1F) & (~0x1F); -} - -bool dump_meta_segment(const IndexDumper::Pointer &dumper, - const std::string &segment_id, const void *data, - size_t size, size_t &writes) { - size_t len = dumper->write(data, size); - if (len != size) { - LOG_ERROR("Dump segment %s data failed, expect: %lu, actual: %lu", - segment_id.c_str(), size, len); - return false; - } - - size_t padding_size = AlignSize(size) - size; - if (padding_size > 0) { - std::string padding(padding_size, '\0'); - if (dumper->write(padding.data(), padding_size) != padding_size) { - LOG_ERROR("Append padding failed, size %lu", padding_size); - return false; - } - } - - uint32_t crc = ailego::Crc32c::Hash(data, size); - int ret = dumper->append(segment_id, size, padding_size, crc); - if (ret != 0) { - LOG_ERROR("Dump segment %s meta failed, ret=%d", segment_id.c_str(), ret); - return false; - } - - writes = len + padding_size; - - return true; -} - -int dump_taglist(IndexDumper::Pointer dumper, size_t num_vecs, - const void *key_base, const void *taglist_data, - uint64_t taglist_size) { - TagListHeader taglist_header; - - taglist_header.num_vecs = num_vecs; - - size_t total_writes; - - bool ret = - dump_meta_segment(dumper, TAGLIST_HEADER_SEGMENT_NAME, &taglist_header, - sizeof(TagListHeader), total_writes); - if (ret == false) { - LOG_ERROR("dump taglist meta failed"); - return IndexError_WriteData; - } - - ret = dump_meta_segment(dumper, TAGLIST_KEY_SEGMENT_NAME, key_base, - num_vecs * sizeof(uint64_t), total_writes); - if (ret == false) { - LOG_ERROR("dump taglist key failed"); - return IndexError_WriteData; - } - - ret = dump_meta_segment(dumper, TAGLIST_DATA_SEGMENT_NAME, taglist_data, - taglist_size, total_writes); - if (ret == false) { - LOG_ERROR("dump taglist data failed"); - return IndexError_WriteData; - } - - return 0; -} - int do_build_sparse_by_streamer(IndexStreamer::Pointer &streamer, uint32_t thread_count) { int ret; @@ -422,7 +348,8 @@ int do_build_sparse_by_streamer(IndexStreamer::Pointer &streamer, } int build_sparse_by_streamer(IndexStreamer::Pointer &streamer, - YAML::Node &config_common) { + YAML::Node &config_common, + const IndexConverter::Pointer &converter) { if (!config_common["IndexPath"]) { LOG_ERROR("Miss params IndexPath for Streamer"); return IndexError_InvalidArgument; @@ -451,6 +378,15 @@ int build_sparse_by_streamer(IndexStreamer::Pointer &streamer, return IndexError_Runtime; } + // Dump converter state (e.g. rotator) to storage for streaming build + if (converter) { + ret = converter->dump_to_storage(storage); + if (ret != 0) { + LOG_ERROR("Failed to dump converter to storage, ret=%d", ret); + return ret; + } + } + size_t thread_count = config_common["ThreadCount"] ? config_common["ThreadCount"].as() : std::thread::hardware_concurrency(); @@ -464,7 +400,8 @@ int build_sparse_by_streamer(IndexStreamer::Pointer &streamer, } int do_build_by_streamer(IndexStreamer::Pointer &streamer, - uint32_t thread_count, RetrievalMode retrieval_mode) { + uint32_t thread_count, RetrievalMode retrieval_mode, + const IndexStorage::Pointer &storage = nullptr) { int ret; ailego::ThreadPool pool(thread_count, false); std::atomic finished{0}; @@ -486,6 +423,14 @@ int do_build_by_streamer(IndexStreamer::Pointer &streamer, return IndexError_NoExist; } reformer->init(meta.reformer_params()); + // Load reformer state from storage (e.g. rotator for IntegerStreaming) + if (storage) { + ret = reformer->load(storage); + if (ret != 0) { + LOG_ERROR("Failed to load reformer from storage, ret=%d", ret); + return ret; + } + } } } @@ -593,7 +538,8 @@ int do_build_by_streamer(IndexStreamer::Pointer &streamer, } int build_by_streamer(IndexStreamer::Pointer &streamer, - YAML::Node &config_common) { + YAML::Node &config_common, + const IndexConverter::Pointer &converter) { if (!config_common["IndexPath"]) { LOG_ERROR("Miss params IndexPath for Streamer"); return IndexError_InvalidArgument; @@ -624,6 +570,15 @@ int build_by_streamer(IndexStreamer::Pointer &streamer, return IndexError_Runtime; } + // Dump converter state (e.g. rotator) to storage for streaming build + if (converter) { + ret = converter->dump_to_storage(storage); + if (ret != 0) { + LOG_ERROR("Failed to dump converter to storage, ret=%d", ret); + return ret; + } + } + size_t thread_count = config_common["ThreadCount"] ? config_common["ThreadCount"].as() : std::thread::hardware_concurrency(); @@ -639,14 +594,15 @@ int build_by_streamer(IndexStreamer::Pointer &streamer, LOG_DEBUG("thread count: %zu, retrieval mode: %s", thread_count, retrieval_mode == 1 ? "Dense" : "Sparse"); - do_build_by_streamer(streamer, thread_count, retrieval_mode); + do_build_by_streamer(streamer, thread_count, retrieval_mode, storage); return 0; } IndexSparseHolder::Pointer convert_sparse_holder( const std::string &name, const ailego::Params ¶ms, - VecsIndexSparseHolder::Pointer &in_holder, IndexMeta &index_meta) { + VecsIndexSparseHolder::Pointer &in_holder, IndexMeta &index_meta, + IndexConverter::Pointer *out_converter) { IndexSparseHolder::Pointer cast_holder = std::dynamic_pointer_cast(in_holder); if (name.empty()) { @@ -679,13 +635,17 @@ IndexSparseHolder::Pointer convert_sparse_holder( index_meta = converter->meta(); + if (out_converter) { + *out_converter = converter; + } return converter->sparse_result(); } IndexHolder::Pointer convert_holder(const std::string &name, const ailego::Params ¶ms, VecsIndexHolder::Pointer &in_holder, - IndexMeta &index_meta) { + IndexMeta &index_meta, + IndexConverter::Pointer *out_converter) { IndexHolder::Pointer cast_holder = std::dynamic_pointer_cast(in_holder); if (name.empty()) { @@ -718,6 +678,9 @@ IndexHolder::Pointer convert_holder(const std::string &name, index_meta = converter->meta(); + if (out_converter) { + *out_converter = converter; + } return converter->result(); } @@ -782,8 +745,9 @@ int do_build_sparse(YAML::Node &config_root, YAML::Node &config_common) { } cout << "Created builder " << builder_class << endl; + IndexConverter::Pointer build_converter; IndexSparseHolder::Pointer cv_build_holder = convert_sparse_holder( - converter_name, converter_params, build_holder, meta); + converter_name, converter_params, build_holder, meta, &build_converter); if (!cv_build_holder) { LOG_ERROR("Convert holder failed."); return -1; @@ -819,7 +783,7 @@ int do_build_sparse(YAML::Node &config_root, YAML::Node &config_common) { } IndexSparseHolder::Pointer cv_train_holder = convert_sparse_holder( - converter_name, converter_params, train_holder, meta); + converter_name, converter_params, train_holder, meta, nullptr); if (!cv_train_holder) { LOG_ERROR("Convert train holder failed."); return -1; @@ -846,7 +810,7 @@ int do_build_sparse(YAML::Node &config_root, YAML::Node &config_common) { if (builder != nullptr) { ret = builder->build(std::move(cv_build_holder)); } else { - ret = build_sparse_by_streamer(streamer, config_common); + ret = build_sparse_by_streamer(streamer, config_common, build_converter); } size_t build_time = timer.milli_seconds(); if (ret < 0) { @@ -856,45 +820,6 @@ int do_build_sparse(YAML::Node &config_root, YAML::Node &config_common) { cout << "Build finished, consume " << build_time << "ms." << endl; signal(SIGINT, SIG_DFL); - // DUMP - IndexDumper::Pointer dumper = IndexFactory::CreateDumper("FileDumper"); - if (!dumper) { - LOG_ERROR("Failed to create FileDumper."); - return -1; - } - string dump_prefix = config_common["DumpPath"].as(); - ret = dumper->create(dump_prefix); - if (ret != 0) { - LOG_ERROR("Failed to create in dumper, ret=%d", ret); - return -1; - } - timer.reset(); - ret = streamer ? streamer->dump(dumper) : builder->dump(dumper); - size_t dump_time = timer.milli_seconds(); - if (ret == IndexError_NotImplemented) { - LOG_WARN("Dump index not implemented"); - } else if (ret < 0) { - LOG_ERROR("Failed to dump in builder, ret=%d", ret); - return -1; - } - - if (build_holder->has_taglist()) { - size_t taglist_size{0}; - const void *taglist_data = build_holder->get_taglist_data(taglist_size); - const void *key_base = build_holder->get_key_base(); - - dump_taglist(dumper, build_holder->get_num_vecs(), key_base, taglist_data, - taglist_size); - } - - ret = dumper->close(); - if (ret != 0) { - LOG_ERROR("Dumper failed to close, ret=%d", ret); - return -1; - } - std::cout << "Dump to [" << dump_prefix << "] finished, consume " << dump_time - << "ms." << std::endl; - if (builder) { auto &stats = reinterpret_cast(builder.get())->stats(); @@ -987,8 +912,9 @@ int do_build(YAML::Node &config_root, YAML::Node &config_common) { cout << "Created builder " << builder_class << endl; - IndexHolder::Pointer cv_build_holder = - convert_holder(converter_name, converter_params, build_holder, meta); + IndexConverter::Pointer build_converter; + IndexHolder::Pointer cv_build_holder = convert_holder( + converter_name, converter_params, build_holder, meta, &build_converter); if (!cv_build_holder) { LOG_ERROR("Convert holder failed."); return -1; @@ -1079,8 +1005,8 @@ int do_build(YAML::Node &config_root, YAML::Node &config_common) { // support fp16 convert - IndexHolder::Pointer cv_train_holder = - convert_holder(converter_name, converter_params, train_holder, meta); + IndexHolder::Pointer cv_train_holder = convert_holder( + converter_name, converter_params, train_holder, meta, nullptr); if (!cv_train_holder) { LOG_ERROR("Convert train holder failed."); return -1; @@ -1136,8 +1062,8 @@ int do_build(YAML::Node &config_root, YAML::Node &config_common) { if (!metric_name.empty()) { train_holder->set_metric(metric_name, metric_params); } - IndexHolder::Pointer cv_train_holder = - convert_holder(converter_name, converter_params, train_holder, meta); + IndexHolder::Pointer cv_train_holder = convert_holder( + converter_name, converter_params, train_holder, meta, nullptr); if (!cv_train_holder) { LOG_ERROR("Convert train holder failed."); return -1; @@ -1177,7 +1103,7 @@ int do_build(YAML::Node &config_root, YAML::Node &config_common) { retrieval_mode = "dense"; } - ret = build_by_streamer(streamer, config_common); + ret = build_by_streamer(streamer, config_common, build_converter); } size_t build_time = timer.milli_seconds(); if (ret < 0) { @@ -1187,45 +1113,6 @@ int do_build(YAML::Node &config_root, YAML::Node &config_common) { cout << "Build finished, consume " << build_time << "ms." << endl; signal(SIGINT, SIG_DFL); - // DUMP - IndexDumper::Pointer dumper = IndexFactory::CreateDumper("FileDumper"); - if (!dumper) { - LOG_ERROR("Failed to create FileDumper."); - return -1; - } - string dump_prefix = config_common["DumpPath"].as(); - ret = dumper->create(dump_prefix); - if (ret != 0) { - LOG_ERROR("Failed to create in dumper, ret=%d", ret); - return -1; - } - timer.reset(); - ret = streamer ? streamer->dump(dumper) : builder->dump(dumper); - size_t dump_time = timer.milli_seconds(); - if (ret == IndexError_NotImplemented) { - LOG_WARN("Dump index not implemented"); - } else if (ret < 0) { - LOG_ERROR("Failed to dump in builder, ret=%d", ret); - return -1; - } - - if (build_holder->has_taglist()) { - size_t taglist_size{0}; - const void *taglist_data = build_holder->get_taglist_data(taglist_size); - const void *key_base = build_holder->get_key_base(); - - dump_taglist(dumper, build_holder->get_num_vecs(), key_base, taglist_data, - taglist_size); - } - - ret = dumper->close(); - if (ret != 0) { - LOG_ERROR("Dumper failed to close, ret=%d", ret); - return -1; - } - std::cout << "Dump to [" << dump_prefix << "] finished, consume " << dump_time - << "ms." << std::endl; - if (builder) { auto &stats = reinterpret_cast(builder.get())->stats();