Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 42 additions & 8 deletions src/db/collection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1667,22 +1667,30 @@ Result<DocPtrList> CollectionImpl::Query(const SearchQuery &query) const {

CHECK_DESTROY_RETURN_STATUS_EXPECTED(destroyed_, false);

SearchQuery sanitized = query;
// When field_name_ is set, use get_field to retrieve the schema uniformly.
// validate_and_sanitize checks that the field type matches the query type
// validate checks that the field type matches the query type
// (FTS query requires an FTS field, vector query requires a vector field).
const auto &field_name = sanitized.target_.field_name_;
const auto &field_name = query.target_.field_name_;
const FieldSchema *field_schema =
field_name.empty() ? nullptr : schema_->get_field(field_name);
auto s = sanitized.validate_and_sanitize(field_schema);
bool need_sanitize = false;
auto s = query.validate(field_schema, &need_sanitize);
CHECK_RETURN_STATUS_EXPECTED(s);

auto segments = get_all_segments();
if (segments.empty()) {
return DocPtrList();
}

return sql_engine_->execute(schema_, std::move(sanitized), segments);
if (!need_sanitize) {
return sql_engine_->execute(schema_, query, segments);
}

// Sparse needs sanitization: make a mutable copy and sort indices in place.
SearchQuery sanitized_query = query;
auto ss = sanitize_sparse_vector(sanitized_query.target_, field_schema);
CHECK_RETURN_STATUS_EXPECTED(ss);
return sql_engine_->execute(schema_, std::move(sanitized_query), segments);
}

Result<DocPtrList> CollectionImpl::Query(const MultiQuery &query) const {
Expand All @@ -1696,6 +1704,11 @@ Result<DocPtrList> CollectionImpl::Query(const MultiQuery &query) const {
query.queries.size()));
}

if (auto s = validate_topk_and_output_fields(query.topk, query.output_fields);
!s.ok()) {
return tl::make_unexpected(s);
}

auto segments = get_all_segments();
if (segments.empty()) {
return DocPtrList();
Expand All @@ -1716,6 +1729,10 @@ Result<DocPtrList> CollectionImpl::Query(const MultiQuery &query) const {
}
auto *field_schema = field_ptr.get();

bool need_sanitize = false;
auto s = target.validate(field_schema, &need_sanitize);
CHECK_RETURN_STATUS_EXPECTED(s);

SearchQuery sq;
sq.target_ = target;
sq.topk_ = sub.num_candidates_;
Expand All @@ -1724,8 +1741,10 @@ Result<DocPtrList> CollectionImpl::Query(const MultiQuery &query) const {
sq.include_doc_id_ = query.include_doc_id_;
sq.output_fields_ = query.output_fields;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是不是漏了对SearchQuery的validate?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed


auto s = sq.validate_and_sanitize(field_schema);
CHECK_RETURN_STATUS_EXPECTED(s);
if (need_sanitize) {
auto ss = sanitize_sparse_vector(sq.target_, field_schema);
CHECK_RETURN_STATUS_EXPECTED(ss);
}
pending_queries.push_back(std::move(sq));
field_schemas.push_back(std::move(field_ptr));
}
Expand Down Expand Up @@ -1777,7 +1796,22 @@ Result<GroupResults> CollectionImpl::GroupByQuery(
return GroupResults();
}

return sql_engine_->execute_group_by(schema_, query, segments);
// Determine vector data source (zero-copy for dense, copy+sort for sparse)
const FieldSchema *field_schema =
schema_->get_field(query.target_.field_name_);
bool need_sanitize = false;
auto s = query.target_.validate(field_schema, &need_sanitize);
CHECK_RETURN_STATUS_EXPECTED(s);

if (!need_sanitize) {
return sql_engine_->execute_group_by(schema_, query, segments);
}

// Sparse needs sanitization: make a mutable copy and sort indices in place.
GroupByVectorQuery sanitized_query = query;
auto ss = sanitize_sparse_vector(sanitized_query.target_, field_schema);
CHECK_RETURN_STATUS_EXPECTED(ss);
return sql_engine_->execute_group_by(schema_, sanitized_query, segments);
}

Result<DocPtrMap> CollectionImpl::Fetch(
Expand Down
34 changes: 26 additions & 8 deletions src/db/index/common/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -907,14 +907,23 @@ Status Doc::validate_and_sanitize(const CollectionSchema::Ptr &schema,
"] exceeds the maximum number of sparse indices (",
kSparseMaxDimSize, ")");
}
if (sort_and_find_duplicates(
sparse_indices.data(),
reinterpret_cast<char *>(sparse_values.data()),
sparse_indices.size(), sizeof(float16_t))) {
auto status = need_sanitize_sparse(sparse_indices.data(),
sparse_indices.size());
if (status == SparseIndicesStatus::kHasDuplicate) {
return Status::InvalidArgument(
"Invalid doc[", pk_, "]: sparse vector field[", field_name,
"] contains duplicate indices");
}
if (status == SparseIndicesStatus::kNeedSort) {
if (sort_and_find_duplicates(
sparse_indices.data(),
reinterpret_cast<char *>(sparse_values.data()),
sparse_indices.size(), sizeof(float16_t))) {
return Status::InvalidArgument(
"Invalid doc[", pk_, "]: sparse vector field[", field_name,
"] contains duplicate indices");
}
}
}
break;
}
Expand All @@ -936,14 +945,23 @@ Status Doc::validate_and_sanitize(const CollectionSchema::Ptr &schema,
"] exceeds the maximum number of sparse indices (",
kSparseMaxDimSize, ")");
}
if (sort_and_find_duplicates(
sparse_indices.data(),
reinterpret_cast<char *>(sparse_values.data()),
sparse_indices.size(), sizeof(float))) {
auto status = need_sanitize_sparse(sparse_indices.data(),
sparse_indices.size());
if (status == SparseIndicesStatus::kHasDuplicate) {
return Status::InvalidArgument(
"Invalid doc[", pk_, "]: sparse vector field[", field_name,
"] contains duplicate indices");
}
if (status == SparseIndicesStatus::kNeedSort) {
if (sort_and_find_duplicates(
sparse_indices.data(),
reinterpret_cast<char *>(sparse_values.data()),
sparse_indices.size(), sizeof(float))) {
return Status::InvalidArgument(
"Invalid doc[", pk_, "]: sparse vector field[", field_name,
"] contains duplicate indices");
}
}
}
break;
}
Expand Down
117 changes: 89 additions & 28 deletions src/db/index/common/query.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,16 @@

namespace zvec {

Status SearchQuery::validate_and_sanitize(const FieldSchema *schema) {
if ((uint32_t)topk_ > kMaxQueryTopk) {
return Status::InvalidArgument("Invalid query: topk[", topk_,
"] exceeds the maximum allowed value of ",
kMaxQueryTopk);
}
if (output_fields_.has_value() &&
output_fields_->size() > kMaxOutputFieldSize) {
return Status::InvalidArgument(
"Invalid query: too many output fields, the maximum allowed is ",
kMaxOutputFieldSize);
}

auto *vc = target_.get_vector_clause();
auto *fc = target_.get_fts_clause();
auto &field_name = target_.field_name_;
auto &query_params = target_.query_params_;
Status QueryTarget::validate(const FieldSchema *schema,
bool *need_sanitize) const {
auto opt_view = get_vector_view();
auto *fc = get_fts_clause();
auto &field_name = field_name_;
auto &query_params = query_params_;
// A "scalar-only filter" query has no vector payload — either the clause
// is not a VectorClause (e.g., FtsClause) or its fields are all empty.
bool no_vector_payload = (vc == nullptr) || (vc->query_vector_.empty() &&
vc->sparse_indices_.empty());
bool no_vector_payload = !opt_view || (opt_view->query_vector_.empty() &&
opt_view->sparse_indices_.empty());

if (schema == nullptr) {
if (fc != nullptr) {
Expand Down Expand Up @@ -87,9 +76,9 @@ Status SearchQuery::validate_and_sanitize(const FieldSchema *schema) {
"Invalid query: missing query clause for field[", field_name, "]");
}

auto &query_vector = vc->query_vector_;
auto &query_sparse_indices = vc->sparse_indices_;
auto &query_sparse_values = vc->sparse_values_;
auto &query_vector = opt_view->query_vector_;
auto &query_sparse_indices = opt_view->sparse_indices_;
auto &query_sparse_values = opt_view->sparse_values_;

// Vector query
if (schema->is_dense_vector()) {
Expand Down Expand Up @@ -163,12 +152,18 @@ Status SearchQuery::validate_and_sanitize(const FieldSchema *schema) {
"Invalid query: too many sparse indices, the maximum allowed is ",
kSparseMaxDimSize);
}
if (sort_and_find_duplicates(
reinterpret_cast<uint32_t *>(query_sparse_indices.data()),
query_sparse_values.data(), n_indices, value_byte_size)) {
return Status::InvalidArgument(
"Invalid query: sparse vector query for field[", field_name,
"] contains duplicate indices");
if (n_indices > 1 && need_sanitize) {
const auto *idx =
reinterpret_cast<const uint32_t *>(query_sparse_indices.data());
auto status = need_sanitize_sparse(idx, n_indices);
if (status == SparseIndicesStatus::kHasDuplicate) {
return Status::InvalidArgument(
"Invalid query: sparse vector query for field[", field_name,
"] contains duplicate indices");
}
if (status == SparseIndicesStatus::kNeedSort) {
*need_sanitize = true;
}
}
} else {
return Status::InvalidArgument("Invalid query: field[", field_name,
Expand All @@ -186,4 +181,70 @@ Status SearchQuery::validate_and_sanitize(const FieldSchema *schema) {
return Status::OK();
}

Status validate_topk_and_output_fields(
int topk, const std::optional<std::vector<std::string>> &output_fields) {
if ((uint32_t)topk > kMaxQueryTopk) {
return Status::InvalidArgument("Invalid query: topk[", topk,
"] exceeds the maximum allowed value of ",
kMaxQueryTopk);
}
if (output_fields.has_value() &&
output_fields->size() > kMaxOutputFieldSize) {
return Status::InvalidArgument(
"Invalid query: too many output fields, the maximum allowed is ",
kMaxOutputFieldSize);
}
return Status::OK();
}

Status SearchQuery::validate(const FieldSchema *schema,
bool *need_sanitize) const {
if (need_sanitize) {
*need_sanitize = false;
}
auto s = validate_topk_and_output_fields(topk_, output_fields_);
if (!s.ok()) {
return s;
}
return target_.validate(schema, need_sanitize);
}

Status sanitize_sparse_vector(VectorClause &vc, const FieldSchema *schema) {
if (!schema || !schema->is_sparse_vector()) {
return Status::OK();
}
size_t value_byte_size = 0;
switch (schema->data_type()) {
case DataType::SPARSE_VECTOR_FP32:
value_byte_size = sizeof(float);
break;
case DataType::SPARSE_VECTOR_FP16:
value_byte_size = sizeof(float16_t);
break;
default:
return Status::OK();
}
size_t n_indices = vc.sparse_indices_.size() / sizeof(uint32_t);
if (n_indices <= 1) {
return Status::OK();
}
if (sort_and_find_duplicates(
reinterpret_cast<uint32_t *>(vc.sparse_indices_.data()),
vc.sparse_values_.data(), n_indices, value_byte_size)) {
return Status::InvalidArgument(
"Invalid query: sparse vector query for field[", schema->name(),
"] contains duplicate indices");
}
return Status::OK();
}

Status sanitize_sparse_vector(QueryTarget &target, const FieldSchema *schema) {
if (auto *vvc = target.get_vector_view_clause()) {
target.clause_ = VectorClause{std::string(vvc->query_vector_),
std::string(vvc->sparse_indices_),
std::string(vvc->sparse_values_)};
}
return sanitize_sparse_vector(*target.get_vector_clause(), schema);
}

} // namespace zvec
32 changes: 19 additions & 13 deletions src/db/index/common/type_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,34 @@
#include "type_helper.h"
#include <algorithm>
#include <cstring>
#include <functional>
#include <numeric>
#include <vector>
#include <zvec/core/framework/index_meta.h>

namespace zvec {

bool sort_and_find_duplicates(uint32_t *indices, char *values, size_t n,
size_t value_byte_size) {
SparseIndicesStatus need_sanitize_sparse(const uint32_t *indices, size_t n) {
if (n <= 1) {
return false;
return SparseIndicesStatus::kOk;
}
bool already_sorted = true;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

写入过程应该还需要这段逻辑

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改为先调用need_sanitize_sparse判断是否需要排序

for (size_t i = 1; i < n; ++i) {
if (indices[i] == indices[i - 1]) {
return true;
}
if (indices[i] < indices[i - 1]) {
already_sorted = false;
break;
}
auto it =
std::adjacent_find(indices, indices + n, std::greater_equal<uint32_t>());
if (it == indices + n) {
return SparseIndicesStatus::kOk;
}
if (*it == *(it + 1)) {
return SparseIndicesStatus::kHasDuplicate;
}
if (already_sorted) {
// First non-strictly-increasing pair is a > b (not equal), so unsorted.
// But there may still be duplicates elsewhere — we only know sorting is
// needed; duplicates will be detected after sort_and_find_duplicates.
return SparseIndicesStatus::kNeedSort;
}

bool sort_and_find_duplicates(uint32_t *indices, char *values, size_t n,
size_t value_byte_size) {
if (n <= 1) {
return false;
}
std::vector<size_t> perm(n);
Expand Down
10 changes: 10 additions & 0 deletions src/db/index/common/type_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@

namespace zvec {

enum class SparseIndicesStatus {
kOk,
kNeedSort,
kHasDuplicate,
};

//! Read-only check of sparse indices: sorted-unique, unsorted, or has
//! duplicates. Uses adjacent_find so it is O(n) and non-mutating.
SparseIndicesStatus need_sanitize_sparse(const uint32_t *indices, size_t n);

//! Sort sparse (indices, values) pairs in place by index ascending and report
//! whether any duplicate index exists. value_byte_size is the per-value stride.
bool sort_and_find_duplicates(uint32_t *indices, char *values, size_t n,
Expand Down
4 changes: 2 additions & 2 deletions src/db/sqlengine/analyzer/query_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -533,8 +533,8 @@ Status QueryAnalyzer::check_and_convert_vector(
}

*vector_cond = std::make_shared<QueryInfo::QueryVectorCondInfo>(
vector_meta, vector_data->take_matrix(), core_data_type, dimension,
vector_data->take_sparse_indices(), vector_data->take_sparse_values(),
vector_meta, vector_data->matrix(), core_data_type, dimension,
vector_data->sparse_indices(), vector_data->sparse_values(),
vector_data->take_query_params());
return Status::OK();
} else {
Expand Down
Loading
Loading