From 52629f17c63c14b294385de1c62f46fb67507bdc Mon Sep 17 00:00:00 2001 From: "jiliang.ljl" Date: Tue, 9 Jun 2026 14:58:39 +0800 Subject: [PATCH 1/3] feat(query): add VectorViewClause zero-copy path and unify validate - Add VectorViewClause (string_view-based) as zero-copy counterpart to VectorClause; variant now holds VectorClause | VectorViewClause | FtsClause - Add QueryTarget::get_vector_view() unified accessor via std::visit, returns optional regardless of which variant is held - Split validate_and_sanitize into QueryTarget::validate (read-only) + sanitize_sparse_vector (mutate); validate handles both VectorClause and VectorViewClause via get_vector_view() - Collection::Query passes original request directly to sqlengine when no sparse sanitization is needed; only copies when sort is required - Change build_query_info/BuildSQLInfoFromSearchQuery to take const SearchQuery& so VectorMatrixNode string_views point to caller's data - sqlengine internals use get_vector_view() instead of get_vector_view_clause() --- src/db/collection.cc | 45 +++++- src/db/index/common/query.cc | 107 +++++++++---- src/db/index/common/type_helper.cc | 13 -- src/db/sqlengine/analyzer/query_analyzer.cc | 4 +- src/db/sqlengine/analyzer/query_info.h | 25 +-- src/db/sqlengine/analyzer/query_node.h | 6 +- src/db/sqlengine/parser/node.h | 36 ++--- src/db/sqlengine/parser/sql_info_helper.cc | 21 ++- src/db/sqlengine/parser/sql_info_helper.h | 4 +- src/db/sqlengine/sqlengine_impl.cc | 4 +- src/db/sqlengine/sqlengine_impl.h | 2 +- src/include/zvec/db/query.h | 64 +++++++- tests/db/index/common/doc_test.cc | 165 +++++++++++++++----- tests/db/sqlengine/query_info_test.cc | 28 ++-- tests/db/sqlengine/sqlengine_test.cc | 3 + tests/db/sqlengine/vector_recall_test.cc | 8 + 16 files changed, 373 insertions(+), 162 deletions(-) diff --git a/src/db/collection.cc b/src/db/collection.cc index bab103e5b..004852305 100644 --- a/src/db/collection.cc +++ b/src/db/collection.cc @@ -1667,14 +1667,14 @@ Result CollectionImpl::Query(const SearchQuery &query) const { CHECK_DESTROY_RETURN_STATUS_EXPECTED(destroyed_, false); - SearchQuery sanitized = query; // When field_name_ is set, use get_field to retrieve the schema uniformly. - // validate_and_sanitize checks that the field type matches the query type + // validate checks that the field type matches the query type // (FTS query requires an FTS field, vector query requires a vector field). - const auto &field_name = sanitized.target_.field_name_; + const auto &field_name = query.target_.field_name_; const FieldSchema *field_schema = field_name.empty() ? nullptr : schema_->get_field(field_name); - auto s = sanitized.validate_and_sanitize(field_schema); + bool need_sanitize = false; + auto s = query.validate(field_schema, &need_sanitize); CHECK_RETURN_STATUS_EXPECTED(s); auto segments = get_all_segments(); @@ -1682,7 +1682,15 @@ Result CollectionImpl::Query(const SearchQuery &query) const { return DocPtrList(); } - return sql_engine_->execute(schema_, std::move(sanitized), segments); + if (!need_sanitize) { + return sql_engine_->execute(schema_, query, segments); + } + + // Sparse needs sanitization: make a mutable copy and sort indices in place. + SearchQuery sanitized_query = query; + auto ss = sanitize_sparse_vector(sanitized_query.target_, field_schema); + CHECK_RETURN_STATUS_EXPECTED(ss); + return sql_engine_->execute(schema_, std::move(sanitized_query), segments); } Result CollectionImpl::Query(const MultiQuery &query) const { @@ -1716,6 +1724,10 @@ Result CollectionImpl::Query(const MultiQuery &query) const { } auto *field_schema = field_ptr.get(); + bool need_sanitize = false; + auto s = target.validate(field_schema, &need_sanitize); + CHECK_RETURN_STATUS_EXPECTED(s); + SearchQuery sq; sq.target_ = target; sq.topk_ = sub.num_candidates_; @@ -1724,8 +1736,10 @@ Result CollectionImpl::Query(const MultiQuery &query) const { sq.include_doc_id_ = query.include_doc_id_; sq.output_fields_ = query.output_fields; - auto s = sq.validate_and_sanitize(field_schema); - CHECK_RETURN_STATUS_EXPECTED(s); + if (need_sanitize) { + auto ss = sanitize_sparse_vector(sq.target_, field_schema); + CHECK_RETURN_STATUS_EXPECTED(ss); + } pending_queries.push_back(std::move(sq)); field_schemas.push_back(std::move(field_ptr)); } @@ -1777,7 +1791,22 @@ Result CollectionImpl::GroupByQuery( return GroupResults(); } - return sql_engine_->execute_group_by(schema_, query, segments); + // Determine vector data source (zero-copy for dense, copy+sort for sparse) + const FieldSchema *field_schema = + schema_->get_field(query.target_.field_name_); + bool need_sanitize = false; + auto s = query.target_.validate(field_schema, &need_sanitize); + CHECK_RETURN_STATUS_EXPECTED(s); + + if (!need_sanitize) { + return sql_engine_->execute_group_by(schema_, query, segments); + } + + // Sparse needs sanitization: make a mutable copy and sort indices in place. + GroupByVectorQuery sanitized_query = query; + auto ss = sanitize_sparse_vector(sanitized_query.target_, field_schema); + CHECK_RETURN_STATUS_EXPECTED(ss); + return sql_engine_->execute_group_by(schema_, sanitized_query, segments); } Result CollectionImpl::Fetch( diff --git a/src/db/index/common/query.cc b/src/db/index/common/query.cc index e9f451ffb..494d06cf9 100644 --- a/src/db/index/common/query.cc +++ b/src/db/index/common/query.cc @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include +#include #include #include #include "db/common/constants.h" @@ -20,27 +22,16 @@ namespace zvec { -Status SearchQuery::validate_and_sanitize(const FieldSchema *schema) { - if ((uint32_t)topk_ > kMaxQueryTopk) { - return Status::InvalidArgument("Invalid query: topk[", topk_, - "] exceeds the maximum allowed value of ", - kMaxQueryTopk); - } - if (output_fields_.has_value() && - output_fields_->size() > kMaxOutputFieldSize) { - return Status::InvalidArgument( - "Invalid query: too many output fields, the maximum allowed is ", - kMaxOutputFieldSize); - } - - auto *vc = target_.get_vector_clause(); - auto *fc = target_.get_fts_clause(); - auto &field_name = target_.field_name_; - auto &query_params = target_.query_params_; +Status QueryTarget::validate(const FieldSchema *schema, + bool *need_sanitize) const { + auto opt_view = get_vector_view(); + auto *fc = get_fts_clause(); + auto &field_name = field_name_; + auto &query_params = query_params_; // A "scalar-only filter" query has no vector payload — either the clause // is not a VectorClause (e.g., FtsClause) or its fields are all empty. - bool no_vector_payload = (vc == nullptr) || (vc->query_vector_.empty() && - vc->sparse_indices_.empty()); + bool no_vector_payload = !opt_view || (opt_view->query_vector_.empty() && + opt_view->sparse_indices_.empty()); if (schema == nullptr) { if (fc != nullptr) { @@ -87,9 +78,9 @@ Status SearchQuery::validate_and_sanitize(const FieldSchema *schema) { "Invalid query: missing query clause for field[", field_name, "]"); } - auto &query_vector = vc->query_vector_; - auto &query_sparse_indices = vc->sparse_indices_; - auto &query_sparse_values = vc->sparse_values_; + auto &query_vector = opt_view->query_vector_; + auto &query_sparse_indices = opt_view->sparse_indices_; + auto &query_sparse_values = opt_view->sparse_values_; // Vector query if (schema->is_dense_vector()) { @@ -163,12 +154,15 @@ Status SearchQuery::validate_and_sanitize(const FieldSchema *schema) { "Invalid query: too many sparse indices, the maximum allowed is ", kSparseMaxDimSize); } - if (sort_and_find_duplicates( - reinterpret_cast(query_sparse_indices.data()), - query_sparse_values.data(), n_indices, value_byte_size)) { - return Status::InvalidArgument( - "Invalid query: sparse vector query for field[", field_name, - "] contains duplicate indices"); + if (n_indices > 1 && need_sanitize) { + const auto *idx = + reinterpret_cast(query_sparse_indices.data()); + // Detect any non-strictly-increasing pair (unsorted or duplicate). + if (std::adjacent_find(idx, idx + n_indices, + std::greater_equal()) != + idx + n_indices) { + *need_sanitize = true; + } } } else { return Status::InvalidArgument("Invalid query: field[", field_name, @@ -186,4 +180,61 @@ Status SearchQuery::validate_and_sanitize(const FieldSchema *schema) { return Status::OK(); } +Status SearchQuery::validate(const FieldSchema *schema, + bool *need_sanitize) const { + if (need_sanitize) { + *need_sanitize = false; + } + if ((uint32_t)topk_ > kMaxQueryTopk) { + return Status::InvalidArgument("Invalid query: topk[", topk_, + "] exceeds the maximum allowed value of ", + kMaxQueryTopk); + } + if (output_fields_.has_value() && + output_fields_->size() > kMaxOutputFieldSize) { + return Status::InvalidArgument( + "Invalid query: too many output fields, the maximum allowed is ", + kMaxOutputFieldSize); + } + return target_.validate(schema, need_sanitize); +} + +Status sanitize_sparse_vector(VectorClause &vc, const FieldSchema *schema) { + if (!schema || !schema->is_sparse_vector()) { + return Status::OK(); + } + size_t value_byte_size = 0; + switch (schema->data_type()) { + case DataType::SPARSE_VECTOR_FP32: + value_byte_size = sizeof(float); + break; + case DataType::SPARSE_VECTOR_FP16: + value_byte_size = sizeof(float16_t); + break; + default: + return Status::OK(); + } + size_t n_indices = vc.sparse_indices_.size() / sizeof(uint32_t); + if (n_indices <= 1) { + return Status::OK(); + } + if (sort_and_find_duplicates( + reinterpret_cast(vc.sparse_indices_.data()), + vc.sparse_values_.data(), n_indices, value_byte_size)) { + return Status::InvalidArgument( + "Invalid query: sparse vector query for field[", schema->name(), + "] contains duplicate indices"); + } + return Status::OK(); +} + +Status sanitize_sparse_vector(QueryTarget &target, const FieldSchema *schema) { + if (auto *vvc = target.get_vector_view_clause()) { + target.clause_ = VectorClause{std::string(vvc->query_vector_), + std::string(vvc->sparse_indices_), + std::string(vvc->sparse_values_)}; + } + return sanitize_sparse_vector(*target.get_vector_clause(), schema); +} + } // namespace zvec diff --git a/src/db/index/common/type_helper.cc b/src/db/index/common/type_helper.cc index 45f2c24a5..3360c89ce 100644 --- a/src/db/index/common/type_helper.cc +++ b/src/db/index/common/type_helper.cc @@ -26,19 +26,6 @@ bool sort_and_find_duplicates(uint32_t *indices, char *values, size_t n, if (n <= 1) { return false; } - bool already_sorted = true; - for (size_t i = 1; i < n; ++i) { - if (indices[i] == indices[i - 1]) { - return true; - } - if (indices[i] < indices[i - 1]) { - already_sorted = false; - break; - } - } - if (already_sorted) { - return false; - } std::vector perm(n); std::iota(perm.begin(), perm.end(), size_t{0}); std::sort(perm.begin(), perm.end(), diff --git a/src/db/sqlengine/analyzer/query_analyzer.cc b/src/db/sqlengine/analyzer/query_analyzer.cc index 2e144ddb7..e410ffd25 100644 --- a/src/db/sqlengine/analyzer/query_analyzer.cc +++ b/src/db/sqlengine/analyzer/query_analyzer.cc @@ -533,8 +533,8 @@ Status QueryAnalyzer::check_and_convert_vector( } *vector_cond = std::make_shared( - vector_meta, vector_data->take_matrix(), core_data_type, dimension, - vector_data->take_sparse_indices(), vector_data->take_sparse_values(), + vector_meta, vector_data->matrix(), core_data_type, dimension, + vector_data->sparse_indices(), vector_data->sparse_values(), vector_data->take_query_params()); return Status::OK(); } else { diff --git a/src/db/sqlengine/analyzer/query_info.h b/src/db/sqlengine/analyzer/query_info.h index c21ea48a5..b22eef717 100644 --- a/src/db/sqlengine/analyzer/query_info.h +++ b/src/db/sqlengine/analyzer/query_info.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include @@ -47,17 +48,17 @@ class QueryInfo { using Ptr = std::shared_ptr; QueryVectorCondInfo(const FieldSchema *vector_schema, - std::string vector_term, + std::string_view vector_term, core::IndexMeta::DataType core_data_type, int dimension, - std::string vector_sparse_indices, - std::string vector_sparse_values, + std::string_view vector_sparse_indices, + std::string_view vector_sparse_values, QueryParams::Ptr query_params) : vector_schema_(vector_schema), - vector_term_(std::move(vector_term)), + vector_term_(vector_term), data_type_(core_data_type), dimension_(dimension), - vector_sparse_indices_(std::move(vector_sparse_indices)), - vector_sparse_values_(std::move(vector_sparse_values)), + vector_sparse_indices_(vector_sparse_indices), + vector_sparse_values_(vector_sparse_values), query_params_(std::move(query_params)) { auto *vector_params = dynamic_cast( vector_schema_->index_params().get()); @@ -75,7 +76,7 @@ class QueryInfo { return vector_schema_; } - const std::string &vector_term() const { + std::string_view vector_term() const { return vector_term_; } @@ -99,11 +100,11 @@ class QueryInfo { return vector_sparse_indices_.size() / sizeof(uint32_t); } - const std::string &vector_sparse_indices() const { + std::string_view vector_sparse_indices() const { return vector_sparse_indices_; } - const std::string &vector_sparse_values() const { + std::string_view vector_sparse_values() const { return vector_sparse_values_; } @@ -117,11 +118,11 @@ class QueryInfo { private: const FieldSchema *vector_schema_{nullptr}; - std::string vector_term_{""}; + std::string_view vector_term_; core::IndexMeta::DataType data_type_; uint32_t dimension_{0}; - std::string vector_sparse_indices_{""}; - std::string vector_sparse_values_{""}; + std::string_view vector_sparse_indices_; + std::string_view vector_sparse_values_; QueryParams::Ptr query_params_; bool reverse_sort_{false}; }; diff --git a/src/db/sqlengine/analyzer/query_node.h b/src/db/sqlengine/analyzer/query_node.h index 6d2e352ab..6f250fbe9 100644 --- a/src/db/sqlengine/analyzer/query_node.h +++ b/src/db/sqlengine/analyzer/query_node.h @@ -202,15 +202,15 @@ class QueryVectorMatrixNode : public QueryNode { std::string text() const override; - const std::string &matrix() const { + std::string_view matrix() const { return node_->matrix(); } - const std::string &sparse_indices() const { + std::string_view sparse_indices() const { return node_->sparse_indices(); } - const std::string &sparse_values() const { + std::string_view sparse_values() const { return node_->sparse_values(); } diff --git a/src/db/sqlengine/parser/node.h b/src/db/sqlengine/parser/node.h index 344b1eac8..e83ae7884 100644 --- a/src/db/sqlengine/parser/node.h +++ b/src/db/sqlengine/parser/node.h @@ -16,6 +16,7 @@ #include #include +#include #include #include #include "db/sqlengine/common/generic_node.h" @@ -158,39 +159,28 @@ class VectorMatrixNode : public Node { public: using Ptr = std::shared_ptr; - VectorMatrixNode(std::string matrix, std::string sparse_indices, - std::string sparse_values, QueryParams::Ptr query_params) - : matrix_(std::move(matrix)), - sparse_indices_(std::move(sparse_indices)), - sparse_values_(std::move(sparse_values)), + VectorMatrixNode(std::string_view matrix, std::string_view sparse_indices, + std::string_view sparse_values, + QueryParams::Ptr query_params) + : matrix_(matrix), + sparse_indices_(sparse_indices), + sparse_values_(sparse_values), query_params_(std::move(query_params)) { set_op(NodeOp::T_VECTOR_MATRIX_VALUE); } - const std::string &matrix() const { + std::string_view matrix() const { return matrix_; } - std::string take_matrix() { - return std::move(matrix_); - } - - const std::string &sparse_indices() const { + std::string_view sparse_indices() const { return sparse_indices_; } - std::string take_sparse_indices() { - return std::move(sparse_indices_); - } - - const std::string &sparse_values() const { + std::string_view sparse_values() const { return sparse_values_; } - std::string take_sparse_values() { - return std::move(sparse_values_); - } - const QueryParams::Ptr &query_params() const { return query_params_; } @@ -206,9 +196,9 @@ class VectorMatrixNode : public Node { } private: - std::string matrix_; - std::string sparse_indices_; - std::string sparse_values_; + std::string_view matrix_; + std::string_view sparse_indices_; + std::string_view sparse_values_; QueryParams::Ptr query_params_; }; diff --git a/src/db/sqlengine/parser/sql_info_helper.cc b/src/db/sqlengine/parser/sql_info_helper.cc index 8b2c9379a..6f8071790 100644 --- a/src/db/sqlengine/parser/sql_info_helper.cc +++ b/src/db/sqlengine/parser/sql_info_helper.cc @@ -26,17 +26,16 @@ namespace zvec::sqlengine { using namespace zvec; -Node::Ptr handle_vector(SearchQuery *request) { - auto *vc = request->target_.get_vector_clause(); - if (vc == nullptr) { +Node::Ptr handle_vector(const SearchQuery &request) { + auto opt_view = request.target_.get_vector_view(); + if (!opt_view) { return nullptr; } Node::Ptr rel_exp = std::make_shared(NodeOp::T_EQ); - rel_exp->set_left(std::make_shared(request->target_.field_name_)); + rel_exp->set_left(std::make_shared(request.target_.field_name_)); rel_exp->set_right(std::make_shared( - std::move(vc->query_vector_), std::move(vc->sparse_indices_), - std::move(vc->sparse_values_), - std::move(request->target_.query_params_))); + opt_view->query_vector_, opt_view->sparse_indices_, + opt_view->sparse_values_, request.target_.query_params_)); return rel_exp; } @@ -67,13 +66,13 @@ void handle_query_field(const SearchQuery *query, SelectInfo *selected_info) { } Result SQLInfoHelper::BuildSQLInfoFromSearchQuery( - SearchQuery query, Node::Ptr filter_node, + const SearchQuery &query, Node::Ptr filter_node, std::shared_ptr group_by) { Node::Ptr index_params_node_ptr = nullptr; - if (const auto *vc = query.target_.get_vector_clause(); - vc != nullptr && + if (auto vc = query.target_.get_vector_view(); + vc.has_value() && (!vc->query_vector_.empty() || !vc->sparse_indices_.empty())) { - index_params_node_ptr = handle_vector(&query); + index_params_node_ptr = handle_vector(query); if (index_params_node_ptr == nullptr) { return tl::make_unexpected(Status::InvalidArgument( "Failed to build vector condition for field: ", diff --git a/src/db/sqlengine/parser/sql_info_helper.h b/src/db/sqlengine/parser/sql_info_helper.h index 09d45fa5a..84b75efd9 100644 --- a/src/db/sqlengine/parser/sql_info_helper.h +++ b/src/db/sqlengine/parser/sql_info_helper.h @@ -24,10 +24,8 @@ namespace zvec::sqlengine { class SQLInfoHelper { public: - //! Build SQLInfo from SearchQuery. Takes query by value so callers may copy - //! or move it; vector payloads can be moved while building SQLInfo. static Result BuildSQLInfoFromSearchQuery( - SearchQuery query, Node::Ptr filter_node, + const SearchQuery &query, Node::Ptr filter_node, std::shared_ptr group_by); }; diff --git a/src/db/sqlengine/sqlengine_impl.cc b/src/db/sqlengine/sqlengine_impl.cc index 69d8c7fe7..c99f5c735 100644 --- a/src/db/sqlengine/sqlengine_impl.cc +++ b/src/db/sqlengine/sqlengine_impl.cc @@ -280,7 +280,7 @@ Result SQLEngineImpl::parse_sql_info( } Result SQLEngineImpl::build_query_info( - CollectionSchema::Ptr collection, SearchQuery request, + CollectionSchema::Ptr collection, const SearchQuery &request, std::shared_ptr group_by) { ScopedProfilerStage stage_guard(profiler_, "build_sql_info"); Node::Ptr filter_node; @@ -317,7 +317,7 @@ Result SQLEngineImpl::build_query_info( } auto sql_info = sqlengine::SQLInfoHelper::BuildSQLInfoFromSearchQuery( - std::move(request), std::move(filter_node), std::move(group_by)); + request, std::move(filter_node), std::move(group_by)); if (!sql_info) { return tl::make_unexpected(sql_info.error()); } diff --git a/src/db/sqlengine/sqlengine_impl.h b/src/db/sqlengine/sqlengine_impl.h index bcc616b64..602253208 100644 --- a/src/db/sqlengine/sqlengine_impl.h +++ b/src/db/sqlengine/sqlengine_impl.h @@ -36,7 +36,7 @@ class SQLEngineImpl : public SQLEngine { //! Build analyzed query info from a structured search query. Result build_query_info(CollectionSchema::Ptr collection, - SearchQuery request, + const SearchQuery &request, std::shared_ptr group_by); //! Perform search with given query_info, segments and index filter diff --git a/src/include/zvec/db/query.h b/src/include/zvec/db/query.h index abfa66f22..a3499513d 100644 --- a/src/include/zvec/db/query.h +++ b/src/include/zvec/db/query.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -25,12 +26,30 @@ namespace zvec { +struct VectorViewClause; + struct VectorClause { std::string query_vector_; std::string sparse_indices_; std::string sparse_values_; + + // Return a non-owning view of this clause. Caller must ensure this + // VectorClause outlives the returned VectorViewClause. + VectorViewClause to_view() const; +}; + +// Non-owning view counterpart of VectorClause. The referenced strings +// must outlive any VectorViewClause instance. +struct VectorViewClause { + std::string_view query_vector_; + std::string_view sparse_indices_; + std::string_view sparse_values_; }; +inline VectorViewClause VectorClause::to_view() const { + return VectorViewClause{query_vector_, sparse_indices_, sparse_values_}; +} + struct FtsClause { std::string query_string_; std::string match_string_; @@ -38,7 +57,7 @@ struct FtsClause { struct QueryTarget { std::string field_name_; - std::variant clause_; + std::variant clause_; QueryParams::Ptr query_params_; // Mutators ensure clause_ holds a VectorClause. @@ -53,6 +72,14 @@ struct QueryTarget { return std::get_if(&clause_); } + // nullptr when clause_ holds a non-VectorViewClause alternative. + VectorViewClause *get_vector_view_clause() { + return std::get_if(&clause_); + } + const VectorViewClause *get_vector_view_clause() const { + return std::get_if(&clause_); + } + // nullptr when clause_ holds a non-FtsClause alternative. FtsClause *get_fts_clause() { return std::get_if(&clause_); @@ -61,6 +88,28 @@ struct QueryTarget { return std::get_if(&clause_); } + // Unified accessor: returns a view regardless of whether the variant + // holds VectorClause or VectorViewClause. Returns nullopt for FTS. + std::optional get_vector_view() const { + return std::visit( + [](auto &&arg) -> std::optional { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return arg.to_view(); + } else if constexpr (std::is_same_v) { + return arg; + } else { + return std::nullopt; + } + }, + clause_); + } + + // Read-only validation of vector/FTS target (dimension, types, sizes). + // For sparse vectors, checks if indices are sorted; sets *need_sanitize=true + // if sorting is needed (when need_sanitize is non-null). + Status validate(const FieldSchema *schema, bool *need_sanitize) const; + private: // Resets clause_ to an empty VectorClause unless it already holds one. VectorClause &ensure_vector_clause() { @@ -94,10 +143,19 @@ struct SearchQuery { // non-empty -> select only the listed fields std::optional> output_fields_; - // FtsClause currently bypasses validation (FTS not yet implemented). - Status validate_and_sanitize(const FieldSchema *schema); + // Read-only validation (topk, output_fields, target). + // For sparse vectors: sets *need_sanitize=true if indices are not sorted. + Status validate(const FieldSchema *schema, bool *need_sanitize) const; }; +// Sort sparse indices in-place and check for duplicates. +// Returns error if duplicates are found after sorting. +Status sanitize_sparse_vector(VectorClause &vc, const FieldSchema *schema); + +// Materializes VectorViewClause into VectorClause if needed, then sorts +// sparse indices in place. Operates on the QueryTarget's clause_ variant. +Status sanitize_sparse_vector(QueryTarget &target, const FieldSchema *schema); + struct GroupByVectorQuery { QueryTarget target_; std::string filter_; diff --git a/tests/db/index/common/doc_test.cc b/tests/db/index/common/doc_test.cc index 895bd276b..63c703d96 100644 --- a/tests/db/index/common/doc_test.cc +++ b/tests/db/index/common/doc_test.cc @@ -1225,7 +1225,7 @@ TEST(SearchQuery, ValidateAndSanitize) { SearchQuery query; query.topk_ = 10; query.target_.field_name_ = "field_name"; - auto s = query.validate_and_sanitize(nullptr); + auto s = query.validate(nullptr, nullptr); EXPECT_TRUE(s.ok()); } @@ -1239,7 +1239,7 @@ TEST(SearchQuery, ValidateAndSanitize) { std::string(reinterpret_cast(query_vector.data()), query_vector.size() * sizeof(float)); query.target_.set_vector(query_vector_str); - auto s = query.validate_and_sanitize(nullptr); + auto s = query.validate(nullptr, nullptr); EXPECT_FALSE(s.ok()); EXPECT_EQ(s.code(), StatusCode::INVALID_ARGUMENT); } @@ -1251,7 +1251,7 @@ TEST(SearchQuery, ValidateAndSanitize) { query.topk_ = 10; query.output_fields_ = std::vector(1025); FieldSchema schema = FieldSchema("field_name", DataType::INT32); - auto s = query.validate_and_sanitize(&schema); + auto s = query.validate(&schema, nullptr); EXPECT_FALSE(s.ok()); EXPECT_EQ(s.code(), StatusCode::INVALID_ARGUMENT); } @@ -1268,11 +1268,11 @@ TEST(SearchQuery, ValidateAndSanitize) { query.target_.set_vector(query_vector_str); FieldSchema schema = FieldSchema("field_name", DataType::VECTOR_FP32, 4, true); - auto s = query.validate_and_sanitize(&schema); + auto s = query.validate(&schema, nullptr); EXPECT_TRUE(s.ok()); query.target_.set_vector(query_vector_str.substr(0, 3)); - s = query.validate_and_sanitize(&schema); + s = query.validate(&schema, nullptr); EXPECT_FALSE(s.ok()); EXPECT_EQ(s.code(), StatusCode::INVALID_ARGUMENT); } @@ -1291,7 +1291,7 @@ TEST(SearchQuery, ValidateAndSanitize) { query_values.size() * sizeof(float))); FieldSchema schema = FieldSchema("field_name", DataType::SPARSE_VECTOR_FP32); - auto s = query.validate_and_sanitize(&schema); + auto s = query.validate(&schema, nullptr); EXPECT_FALSE(s.ok()); EXPECT_EQ(s.code(), StatusCode::INVALID_ARGUMENT); @@ -1301,12 +1301,12 @@ TEST(SearchQuery, ValidateAndSanitize) { query.target_.set_sparse_vector( std::string(reinterpret_cast(&one_index), sizeof(uint32_t)), std::string(reinterpret_cast(&one_value), sizeof(float))); - s = query.validate_and_sanitize(&schema); + s = query.validate(&schema, nullptr); EXPECT_TRUE(s.ok()); } - // sparse query must have matching counts, and indices must be strictly - // ascending and unique + // sparse: validate sets need_sanitize for unsorted, sanitize sorts and + // detects duplicates { auto pack_idx = [](const std::vector &v) { return std::string(reinterpret_cast(v.data()), @@ -1327,61 +1327,85 @@ TEST(SearchQuery, ValidateAndSanitize) { FieldSchema schema = FieldSchema("field_name", DataType::SPARSE_VECTOR_FP32); - // unsorted indices are sorted in place + // unsorted indices: validate sets need_sanitize, sanitize sorts in place { SearchQuery query; query.target_.field_name_ = "field_name"; query.topk_ = 100; query.target_.set_sparse_vector(pack_idx({42u, 7u, 128u, 3u, 99u}), pack_val({0.1f, 0.2f, 0.3f, 0.4f, 0.5f})); - auto s = query.validate_and_sanitize(&schema); + bool need_sanitize = false; + auto s = query.validate(&schema, &need_sanitize); EXPECT_TRUE(s.ok()) << s.message(); - EXPECT_EQ( - decode_idx( - std::get(query.target_.clause_).sparse_indices_), - (std::vector{3u, 7u, 42u, 99u, 128u})); - EXPECT_EQ( - decode_val( - std::get(query.target_.clause_).sparse_values_), - (std::vector{0.4f, 0.2f, 0.1f, 0.5f, 0.3f})); + EXPECT_TRUE(need_sanitize); + + VectorClause vc = *query.target_.get_vector_clause(); + s = sanitize_sparse_vector(vc, &schema); + EXPECT_TRUE(s.ok()) << s.message(); + EXPECT_EQ(decode_idx(vc.sparse_indices_), + (std::vector{3u, 7u, 42u, 99u, 128u})); + EXPECT_EQ(decode_val(vc.sparse_values_), + (std::vector{0.4f, 0.2f, 0.1f, 0.5f, 0.3f})); } - // duplicates are rejected + // duplicates (sorted): validate detects as unsorted (equal == not strictly + // less), sanitize sorts and reports duplicates { SearchQuery query; query.target_.field_name_ = "field_name"; query.topk_ = 100; query.target_.set_sparse_vector(pack_idx({3u, 7u, 42u, 42u, 99u}), pack_val({0.1f, 0.2f, 0.3f, 0.4f, 0.5f})); - auto s = query.validate_and_sanitize(&schema); + bool need_sanitize = false; + auto s = query.validate(&schema, &need_sanitize); + EXPECT_TRUE(s.ok()); + EXPECT_TRUE(need_sanitize); + + VectorClause vc = *query.target_.get_vector_clause(); + s = sanitize_sparse_vector(vc, &schema); EXPECT_FALSE(s.ok()); EXPECT_EQ(s.code(), StatusCode::INVALID_ARGUMENT); + } + // duplicates (unsorted): sanitize sorts then reports duplicates + { + SearchQuery query; + query.target_.field_name_ = "field_name"; + query.topk_ = 100; query.target_.set_sparse_vector(pack_idx({42u, 3u, 7u, 42u, 99u}), pack_val({0.1f, 0.2f, 0.3f, 0.4f, 0.5f})); - s = query.validate_and_sanitize(&schema); + bool need_sanitize = false; + auto s = query.validate(&schema, &need_sanitize); + EXPECT_TRUE(s.ok()); + EXPECT_TRUE(need_sanitize); + + VectorClause vc = *query.target_.get_vector_clause(); + s = sanitize_sparse_vector(vc, &schema); EXPECT_FALSE(s.ok()); EXPECT_EQ(s.code(), StatusCode::INVALID_ARGUMENT); } - // mismatched counts are rejected + // sorted without duplicates: need_sanitize is false { SearchQuery query; query.target_.field_name_ = "field_name"; query.topk_ = 100; - const auto idx_before = pack_idx({3u, 2u, 1u, 4u}); - const auto val_before = pack_val({0.1f, 0.2f, 0.3f, 0.4f}); - query.target_.set_sparse_vector(idx_before, val_before); - auto s = query.validate_and_sanitize(&schema); + query.target_.set_sparse_vector(pack_idx({1u, 2u, 3u, 4u}), + pack_val({0.1f, 0.2f, 0.3f, 0.4f})); + bool need_sanitize = false; + auto s = query.validate(&schema, &need_sanitize); EXPECT_TRUE(s.ok()) << s.message(); - EXPECT_EQ(std::get(query.target_.clause_).sparse_indices_, - pack_idx({1u, 2u, 3u, 4u})); - EXPECT_EQ(std::get(query.target_.clause_).sparse_values_, - pack_val({0.3f, 0.2f, 0.1f, 0.4f})); - - std::get(query.target_.clause_).sparse_values_ = - pack_val({0.1f, 0.2f, 0.3f}); - s = query.validate_and_sanitize(&schema); + EXPECT_FALSE(need_sanitize); + } + + // mismatched counts are rejected by validate + { + SearchQuery query; + query.target_.field_name_ = "field_name"; + query.topk_ = 100; + query.target_.set_sparse_vector(pack_idx({1u, 2u, 3u, 4u}), + pack_val({0.1f, 0.2f, 0.3f})); + auto s = query.validate(&schema, nullptr); EXPECT_FALSE(s.ok()); EXPECT_EQ(s.code(), StatusCode::INVALID_ARGUMENT); } @@ -1401,21 +1425,20 @@ TEST(SearchQuery, ValidateAndSanitize) { std::make_shared(MetricType::L2)); query.target_.query_params_ = std::make_shared(150); - auto s = query.validate_and_sanitize(&schema); + auto s = query.validate(&schema, nullptr); EXPECT_TRUE(s.ok()); query.target_.query_params_ = std::make_shared(50); - s = query.validate_and_sanitize(&schema); + s = query.validate(&schema, nullptr); EXPECT_FALSE(s.ok()); EXPECT_EQ(s.code(), StatusCode::INVALID_ARGUMENT); query.target_.query_params_ = nullptr; - s = query.validate_and_sanitize(&schema); + s = query.validate(&schema, nullptr); EXPECT_TRUE(s.ok()); } - // A vector clause and an FTS clause are mutually exclusive by construction: - // target_.clause_ is a variant that holds exactly one of them. + // FTS clause validation { auto fts_params = std::make_shared(); FieldSchema fts_schema("content", DataType::STRING, false, fts_params); @@ -1427,21 +1450,77 @@ TEST(SearchQuery, ValidateAndSanitize) { FtsClause fts_test; fts_test.query_string_ = "test"; fts_only.target_.clause_ = fts_test; - auto s = fts_only.validate_and_sanitize(&fts_schema); + auto s = fts_only.validate(&fts_schema, nullptr); EXPECT_TRUE(s.ok()); // FTS query with nullptr schema -> fail (field not found) - s = fts_only.validate_and_sanitize(nullptr); + s = fts_only.validate(nullptr, nullptr); EXPECT_FALSE(s.ok()); EXPECT_EQ(s.code(), StatusCode::INVALID_ARGUMENT); // FTS query with vector field schema -> fail (type mismatch) FieldSchema vec_schema("embedding", DataType::VECTOR_FP32, 128, false, std::make_shared(MetricType::L2)); - s = fts_only.validate_and_sanitize(&vec_schema); + s = fts_only.validate(&vec_schema, nullptr); EXPECT_FALSE(s.ok()); EXPECT_EQ(s.code(), StatusCode::INVALID_ARGUMENT); } + + // VectorViewClause: validate handles VectorViewClause the same as + // VectorClause + { + FieldSchema schema = + FieldSchema("field_name", DataType::VECTOR_FP32, 4, true); + std::vector query_vector = {1.0f, 2.0f, 3.0f, 4.0f}; + std::string vec_data(reinterpret_cast(query_vector.data()), + query_vector.size() * sizeof(float)); + + // Dense VectorViewClause: valid dimension + { + SearchQuery query; + query.target_.field_name_ = "field_name"; + query.topk_ = 10; + query.target_.clause_ = + VectorViewClause{vec_data, std::string_view{}, std::string_view{}}; + auto s = query.validate(&schema, nullptr); + EXPECT_TRUE(s.ok()) << s.message(); + } + + // Dense VectorViewClause: wrong dimension + { + SearchQuery query; + query.target_.field_name_ = "field_name"; + query.topk_ = 10; + std::string short_vec = vec_data.substr(0, sizeof(float) * 2); + query.target_.clause_ = + VectorViewClause{short_vec, std::string_view{}, std::string_view{}}; + auto s = query.validate(&schema, nullptr); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.code(), StatusCode::INVALID_ARGUMENT); + } + + // Sparse VectorViewClause: unsorted triggers need_sanitize + { + FieldSchema sparse_schema( + "field_name", DataType::SPARSE_VECTOR_FP32, false, + std::make_shared(MetricType::IP)); + std::vector idx_vec = {3u, 1u, 2u}; + std::vector val_vec = {0.3f, 0.1f, 0.2f}; + std::string idx_data(reinterpret_cast(idx_vec.data()), + idx_vec.size() * sizeof(uint32_t)); + std::string val_data(reinterpret_cast(val_vec.data()), + val_vec.size() * sizeof(float)); + SearchQuery query; + query.target_.field_name_ = "field_name"; + query.topk_ = 10; + query.target_.clause_ = + VectorViewClause{std::string_view{}, idx_data, val_data}; + bool need_sanitize = false; + auto s = query.validate(&sparse_schema, &need_sanitize); + EXPECT_TRUE(s.ok()) << s.message(); + EXPECT_TRUE(need_sanitize); + } + } } // Test null value diff --git a/tests/db/sqlengine/query_info_test.cc b/tests/db/sqlengine/query_info_test.cc index b32fcd52b..8645f0b18 100644 --- a/tests/db/sqlengine/query_info_test.cc +++ b/tests/db/sqlengine/query_info_test.cc @@ -95,6 +95,7 @@ TEST_F(QueryInfoTest, BasicQueryRequest) { query.target_.query_params_ = std::make_shared(IndexType::FLAT); query.target_.query_params_->set_radius(0.8F); + auto engine = std::make_shared(std::make_shared()); auto ret = engine->build_query_info(schema, query, nullptr); ASSERT_TRUE(ret.has_value()) << ret.error().c_str(); @@ -116,11 +117,11 @@ TEST_F(QueryInfoTest, BasicQueryRequest) { auto vector_cond = new_query_info->vector_cond_info(); EXPECT_EQ(1, vector_cond->batch()); EXPECT_EQ("face_feature", vector_cond->vector_field_name()); - EXPECT_EQ(std::get(query.target_.clause_).query_vector_, + EXPECT_EQ(query.target_.get_vector_clause()->query_vector_, vector_cond->vector_term()); - EXPECT_EQ(std::get(query.target_.clause_).sparse_indices_, + EXPECT_EQ(query.target_.get_vector_clause()->sparse_indices_, vector_cond->vector_sparse_indices()); - EXPECT_EQ(std::get(query.target_.clause_).sparse_values_, + EXPECT_EQ(query.target_.get_vector_clause()->sparse_values_, vector_cond->vector_sparse_values()); EXPECT_EQ(query.target_.query_params_, vector_cond->query_params()); } @@ -136,6 +137,7 @@ TEST_F(QueryInfoTest, QueryRequestWithFilter) { query.target_.query_params_->set_radius(0.8F); query.filter_ = "name<3 or name=4 or 1-dash_score_field='test'"; + auto engine = std::make_shared(std::make_shared()); auto ret = engine->build_query_info(schema, query, nullptr); ASSERT_TRUE(ret.has_value()); @@ -157,11 +159,11 @@ TEST_F(QueryInfoTest, QueryRequestWithFilter) { auto vector_cond = new_query_info->vector_cond_info(); EXPECT_EQ(1, vector_cond->batch()); EXPECT_EQ("face_feature", vector_cond->vector_field_name()); - EXPECT_EQ(std::get(query.target_.clause_).query_vector_, + EXPECT_EQ(query.target_.get_vector_clause()->query_vector_, vector_cond->vector_term()); - EXPECT_EQ(std::get(query.target_.clause_).sparse_indices_, + EXPECT_EQ(query.target_.get_vector_clause()->sparse_indices_, vector_cond->vector_sparse_indices()); - EXPECT_EQ(std::get(query.target_.clause_).sparse_values_, + EXPECT_EQ(query.target_.get_vector_clause()->sparse_values_, vector_cond->vector_sparse_values()); EXPECT_EQ(query.target_.query_params_, vector_cond->query_params()); @@ -220,6 +222,7 @@ TEST_F(QueryInfoTest, QueryRequestWithIncludeVector) { query.target_.query_params_->set_radius(0.8F); query.include_vector_ = true; + auto engine = std::make_shared(std::make_shared()); auto ret = engine->build_query_info(schema, query, nullptr); ASSERT_TRUE(ret.has_value()); @@ -242,11 +245,11 @@ TEST_F(QueryInfoTest, QueryRequestWithIncludeVector) { auto vector_cond = new_query_info->vector_cond_info(); EXPECT_EQ(1, vector_cond->batch()); EXPECT_EQ("face_feature", vector_cond->vector_field_name()); - EXPECT_EQ(std::get(query.target_.clause_).query_vector_, + EXPECT_EQ(query.target_.get_vector_clause()->query_vector_, vector_cond->vector_term()); - EXPECT_EQ(std::get(query.target_.clause_).sparse_indices_, + EXPECT_EQ(query.target_.get_vector_clause()->sparse_indices_, vector_cond->vector_sparse_indices()); - EXPECT_EQ(std::get(query.target_.clause_).sparse_values_, + EXPECT_EQ(query.target_.get_vector_clause()->sparse_values_, vector_cond->vector_sparse_values()); EXPECT_EQ(query.target_.query_params_, vector_cond->query_params()); } @@ -262,6 +265,7 @@ TEST_F(QueryInfoTest, OR_ANCESTOR) { query.target_.query_params_->set_radius(0.8F); query.filter_ = "name=1 and (name=2 or name=3)"; + auto engine = std::make_shared(std::make_shared()); auto ret = engine->build_query_info(schema, query, nullptr); ASSERT_TRUE(ret.has_value()); @@ -280,6 +284,7 @@ TEST_F(QueryInfoTest, QueryRequestWithInFilter) { query.filter_ = "name=3 or name in (1, 2, 3) or category not in (\"a\", \"b\", \"c\")"; + auto engine = std::make_shared(std::make_shared()); auto ret = engine->build_query_info(schema, query, nullptr); ASSERT_TRUE(ret.has_value()); @@ -303,7 +308,7 @@ TEST_F(QueryInfoTest, QueryRequestWithInFilter) { EXPECT_EQ(1, vector_cond->batch()); EXPECT_EQ("face_feature", vector_cond->vector_field_name()); std::vector data{1.1, 2.2, 3.3, 4.4}; - EXPECT_EQ(std::get(query.target_.clause_).query_vector_, + EXPECT_EQ(query.target_.get_vector_clause()->query_vector_, vector_cond->vector_term()); EXPECT_TRUE(new_query_info->filter_cond()); @@ -369,6 +374,7 @@ TEST_F(QueryInfoTest, QueryRequestWithInFilterWrong) { query.target_.query_params_ = std::make_shared(IndexType::FLAT); query.target_.query_params_->set_radius(0.8F); + auto engine = std::make_shared(std::make_shared()); auto ret = engine->build_query_info(schema, query, nullptr); ASSERT_TRUE(ret.has_value()); @@ -409,6 +415,7 @@ TEST_F(QueryInfoTest, QueryRequestWithInFilterNum1024) { } query.filter_ = filter_str; + auto engine = std::make_shared(std::make_shared()); auto ret = engine->build_query_info(schema, query, nullptr); ASSERT_TRUE(ret.has_value()); @@ -450,6 +457,7 @@ TEST_F(QueryInfoTest, QueryRequestWithFilter_contain) { R"( or category_array not contain_any ("c", "d", "e") )"; + auto engine = std::make_shared(std::make_shared()); auto ret = engine->build_query_info(schema, query, nullptr); ASSERT_TRUE(ret.has_value()); diff --git a/tests/db/sqlengine/sqlengine_test.cc b/tests/db/sqlengine/sqlengine_test.cc index f03a27e13..9f4894be1 100644 --- a/tests/db/sqlengine/sqlengine_test.cc +++ b/tests/db/sqlengine/sqlengine_test.cc @@ -90,6 +90,7 @@ TEST_F(SqlEngineTest, Vector) { query.target_.query_params_ = std::make_shared(IndexType::FLAT); query.target_.query_params_->set_radius(0.8F); + auto engine = SQLEngine::create(std::make_shared()); auto ret = engine->execute(schema_, query, segments); if (!ret) { @@ -131,6 +132,7 @@ TEST_F(SqlEngineTest, MultiSegments) { query.filter_ = env_var; } + auto engine = SQLEngine::create(std::make_shared()); auto ret = engine->execute(schema_, query, segments); if (!ret) { @@ -157,6 +159,7 @@ TEST_F(SqlEngineTest, GroupBy) { query.target_.query_params_ = std::make_shared(IndexType::FLAT); query.target_.query_params_->set_radius(0.8F); + auto engine = SQLEngine::create(std::make_shared()); auto ret = engine->execute_group_by(schema_, query, segments); if (!ret) { diff --git a/tests/db/sqlengine/vector_recall_test.cc b/tests/db/sqlengine/vector_recall_test.cc index f034c7159..fdc49a50a 100644 --- a/tests/db/sqlengine/vector_recall_test.cc +++ b/tests/db/sqlengine/vector_recall_test.cc @@ -31,6 +31,7 @@ TEST_F(VectorRecallTest, Basic) { feature.size() * sizeof(float))); query.target_.field_name_ = "dense"; + auto engine = SQLEngine::create(std::make_shared()); auto ret = engine->execute(collection_schema_, query, segments_); if (!ret) { @@ -61,6 +62,7 @@ TEST_F(VectorRecallTest, HybridInvertFilter) { feature.size() * sizeof(float))); query.target_.field_name_ = "dense"; + auto engine = SQLEngine::create(std::make_shared()); auto ret = engine->execute(collection_schema_, query, segments_); if (!ret) { @@ -92,6 +94,7 @@ TEST_F(VectorRecallTest, HybridInvertFilterBfByKeys) { feature.size() * sizeof(float))); query.target_.field_name_ = "dense"; + auto engine = SQLEngine::create(std::make_shared()); auto ret = engine->execute(collection_schema_, query, segments_); if (!ret) { @@ -122,6 +125,7 @@ TEST_F(VectorRecallTest, HybridForwardFilter) { feature.size() * sizeof(float))); query.target_.field_name_ = "dense"; + auto engine = SQLEngine::create(std::make_shared()); auto ret = engine->execute(collection_schema_, query, segments_); if (!ret) { @@ -153,6 +157,7 @@ TEST_F(VectorRecallTest, HybridInvertForwardFilter) { feature.size() * sizeof(float))); query.target_.field_name_ = "dense"; + auto engine = SQLEngine::create(std::make_shared()); auto ret = engine->execute(collection_schema_, query, segments_); if (!ret) { @@ -187,6 +192,7 @@ TEST_F(VectorRecallTest, Sparse) { feature.size() * sizeof(float))); query.target_.field_name_ = "sparse"; + auto engine = SQLEngine::create(std::make_shared()); auto ret = engine->execute(collection_schema_, query, segments_); if (!ret) { @@ -227,6 +233,7 @@ TEST_F(VectorRecallTest, DeleteFilter) { feature.size() * sizeof(float))); query.target_.field_name_ = "dense"; + auto engine = SQLEngine::create(std::make_shared()); auto ret = engine->execute(collection_schema_, query, segments_); if (!ret) { @@ -259,6 +266,7 @@ TEST_F(VectorRecallTest, HybridInvertForwardDeleteFilter) { feature.size() * sizeof(float))); query.target_.field_name_ = "dense"; + auto engine = SQLEngine::create(std::make_shared()); auto ret = engine->execute(collection_schema_, query, segments_); if (!ret) { From 0c7394a2155338c5350c8f93566378a0b8f0d7a4 Mon Sep 17 00:00:00 2001 From: "jiliang.ljl" Date: Tue, 16 Jun 2026 22:27:52 +0800 Subject: [PATCH 2/3] address comments --- src/db/collection.cc | 5 ++++ src/db/index/common/doc.cc | 34 +++++++++++++++++++------ src/db/index/common/query.cc | 40 +++++++++++++++++++----------- src/db/index/common/type_helper.cc | 19 ++++++++++++++ src/db/index/common/type_helper.h | 10 ++++++++ src/include/zvec/db/query.h | 4 +++ 6 files changed, 89 insertions(+), 23 deletions(-) diff --git a/src/db/collection.cc b/src/db/collection.cc index 004852305..72a181a08 100644 --- a/src/db/collection.cc +++ b/src/db/collection.cc @@ -1704,6 +1704,11 @@ Result CollectionImpl::Query(const MultiQuery &query) const { query.queries.size())); } + if (auto s = validate_topk_and_output_fields(query.topk, query.output_fields); + !s.ok()) { + return tl::make_unexpected(s); + } + auto segments = get_all_segments(); if (segments.empty()) { return DocPtrList(); diff --git a/src/db/index/common/doc.cc b/src/db/index/common/doc.cc index b16b4038b..f76ad1d86 100644 --- a/src/db/index/common/doc.cc +++ b/src/db/index/common/doc.cc @@ -907,14 +907,23 @@ Status Doc::validate_and_sanitize(const CollectionSchema::Ptr &schema, "] exceeds the maximum number of sparse indices (", kSparseMaxDimSize, ")"); } - if (sort_and_find_duplicates( - sparse_indices.data(), - reinterpret_cast(sparse_values.data()), - sparse_indices.size(), sizeof(float16_t))) { + auto status = need_sanitize_sparse(sparse_indices.data(), + sparse_indices.size()); + if (status == SparseIndicesStatus::kHasDuplicate) { return Status::InvalidArgument( "Invalid doc[", pk_, "]: sparse vector field[", field_name, "] contains duplicate indices"); } + if (status == SparseIndicesStatus::kNeedSort) { + if (sort_and_find_duplicates( + sparse_indices.data(), + reinterpret_cast(sparse_values.data()), + sparse_indices.size(), sizeof(float16_t))) { + return Status::InvalidArgument( + "Invalid doc[", pk_, "]: sparse vector field[", field_name, + "] contains duplicate indices"); + } + } } break; } @@ -936,14 +945,23 @@ Status Doc::validate_and_sanitize(const CollectionSchema::Ptr &schema, "] exceeds the maximum number of sparse indices (", kSparseMaxDimSize, ")"); } - if (sort_and_find_duplicates( - sparse_indices.data(), - reinterpret_cast(sparse_values.data()), - sparse_indices.size(), sizeof(float))) { + auto status = need_sanitize_sparse(sparse_indices.data(), + sparse_indices.size()); + if (status == SparseIndicesStatus::kHasDuplicate) { return Status::InvalidArgument( "Invalid doc[", pk_, "]: sparse vector field[", field_name, "] contains duplicate indices"); } + if (status == SparseIndicesStatus::kNeedSort) { + if (sort_and_find_duplicates( + sparse_indices.data(), + reinterpret_cast(sparse_values.data()), + sparse_indices.size(), sizeof(float))) { + return Status::InvalidArgument( + "Invalid doc[", pk_, "]: sparse vector field[", field_name, + "] contains duplicate indices"); + } + } } break; } diff --git a/src/db/index/common/query.cc b/src/db/index/common/query.cc index 494d06cf9..ca94a2db9 100644 --- a/src/db/index/common/query.cc +++ b/src/db/index/common/query.cc @@ -12,9 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include -#include #include #include #include "db/common/constants.h" @@ -157,10 +155,13 @@ Status QueryTarget::validate(const FieldSchema *schema, if (n_indices > 1 && need_sanitize) { const auto *idx = reinterpret_cast(query_sparse_indices.data()); - // Detect any non-strictly-increasing pair (unsorted or duplicate). - if (std::adjacent_find(idx, idx + n_indices, - std::greater_equal()) != - idx + n_indices) { + auto status = need_sanitize_sparse(idx, n_indices); + if (status == SparseIndicesStatus::kHasDuplicate) { + return Status::InvalidArgument( + "Invalid query: sparse vector query for field[", field_name, + "] contains duplicate indices"); + } + if (status == SparseIndicesStatus::kNeedSort) { *need_sanitize = true; } } @@ -180,22 +181,31 @@ Status QueryTarget::validate(const FieldSchema *schema, return Status::OK(); } -Status SearchQuery::validate(const FieldSchema *schema, - bool *need_sanitize) const { - if (need_sanitize) { - *need_sanitize = false; - } - if ((uint32_t)topk_ > kMaxQueryTopk) { - return Status::InvalidArgument("Invalid query: topk[", topk_, +Status validate_topk_and_output_fields( + int topk, const std::optional> &output_fields) { + if ((uint32_t)topk > kMaxQueryTopk) { + return Status::InvalidArgument("Invalid query: topk[", topk, "] exceeds the maximum allowed value of ", kMaxQueryTopk); } - if (output_fields_.has_value() && - output_fields_->size() > kMaxOutputFieldSize) { + if (output_fields.has_value() && + output_fields->size() > kMaxOutputFieldSize) { return Status::InvalidArgument( "Invalid query: too many output fields, the maximum allowed is ", kMaxOutputFieldSize); } + return Status::OK(); +} + +Status SearchQuery::validate(const FieldSchema *schema, + bool *need_sanitize) const { + if (need_sanitize) { + *need_sanitize = false; + } + auto s = validate_topk_and_output_fields(topk_, output_fields_); + if (!s.ok()) { + return s; + } return target_.validate(schema, need_sanitize); } diff --git a/src/db/index/common/type_helper.cc b/src/db/index/common/type_helper.cc index 3360c89ce..8d26a065d 100644 --- a/src/db/index/common/type_helper.cc +++ b/src/db/index/common/type_helper.cc @@ -15,12 +15,31 @@ #include "type_helper.h" #include #include +#include #include #include #include namespace zvec { +SparseIndicesStatus need_sanitize_sparse(const uint32_t *indices, size_t n) { + if (n <= 1) { + return SparseIndicesStatus::kOk; + } + auto it = + std::adjacent_find(indices, indices + n, std::greater_equal()); + if (it == indices + n) { + return SparseIndicesStatus::kOk; + } + if (*it == *(it + 1)) { + return SparseIndicesStatus::kHasDuplicate; + } + // First non-strictly-increasing pair is a > b (not equal), so unsorted. + // But there may still be duplicates elsewhere — we only know sorting is + // needed; duplicates will be detected after sort_and_find_duplicates. + return SparseIndicesStatus::kNeedSort; +} + bool sort_and_find_duplicates(uint32_t *indices, char *values, size_t n, size_t value_byte_size) { if (n <= 1) { diff --git a/src/db/index/common/type_helper.h b/src/db/index/common/type_helper.h index f5630f64c..d24f1ee52 100644 --- a/src/db/index/common/type_helper.h +++ b/src/db/index/common/type_helper.h @@ -22,6 +22,16 @@ namespace zvec { +enum class SparseIndicesStatus { + kOk, + kNeedSort, + kHasDuplicate, +}; + +//! Read-only check of sparse indices: sorted-unique, unsorted, or has +//! duplicates. Uses adjacent_find so it is O(n) and non-mutating. +SparseIndicesStatus need_sanitize_sparse(const uint32_t *indices, size_t n); + //! Sort sparse (indices, values) pairs in place by index ascending and report //! whether any duplicate index exists. value_byte_size is the per-value stride. bool sort_and_find_duplicates(uint32_t *indices, char *values, size_t n, diff --git a/src/include/zvec/db/query.h b/src/include/zvec/db/query.h index a3499513d..6de901d93 100644 --- a/src/include/zvec/db/query.h +++ b/src/include/zvec/db/query.h @@ -148,6 +148,10 @@ struct SearchQuery { Status validate(const FieldSchema *schema, bool *need_sanitize) const; }; +// Validate topk and output_fields bounds. +Status validate_topk_and_output_fields( + int topk, const std::optional> &output_fields); + // Sort sparse indices in-place and check for duplicates. // Returns error if duplicates are found after sorting. Status sanitize_sparse_vector(VectorClause &vc, const FieldSchema *schema); From a057644f5b63918a1e3b22204f1e26a58209d7f5 Mon Sep 17 00:00:00 2001 From: "jiliang.ljl" Date: Wed, 17 Jun 2026 19:51:13 +0800 Subject: [PATCH 3/3] fix ut --- tests/db/index/common/doc_test.cc | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/db/index/common/doc_test.cc b/tests/db/index/common/doc_test.cc index 63c703d96..b1f7a5b7f 100644 --- a/tests/db/index/common/doc_test.cc +++ b/tests/db/index/common/doc_test.cc @@ -1348,8 +1348,7 @@ TEST(SearchQuery, ValidateAndSanitize) { (std::vector{0.4f, 0.2f, 0.1f, 0.5f, 0.3f})); } - // duplicates (sorted): validate detects as unsorted (equal == not strictly - // less), sanitize sorts and reports duplicates + // duplicates (sorted): validate detects duplicates directly { SearchQuery query; query.target_.field_name_ = "field_name"; @@ -1358,13 +1357,9 @@ TEST(SearchQuery, ValidateAndSanitize) { pack_val({0.1f, 0.2f, 0.3f, 0.4f, 0.5f})); bool need_sanitize = false; auto s = query.validate(&schema, &need_sanitize); - EXPECT_TRUE(s.ok()); - EXPECT_TRUE(need_sanitize); - - VectorClause vc = *query.target_.get_vector_clause(); - s = sanitize_sparse_vector(vc, &schema); EXPECT_FALSE(s.ok()); EXPECT_EQ(s.code(), StatusCode::INVALID_ARGUMENT); + EXPECT_FALSE(need_sanitize); } // duplicates (unsorted): sanitize sorts then reports duplicates