diff --git a/.planning/REQUIREMENTS.md b/.planning/REQUIREMENTS.md index 2369aae..e46d3fc 100644 --- a/.planning/REQUIREMENTS.md +++ b/.planning/REQUIREMENTS.md @@ -20,10 +20,10 @@ Requirements for the current milestone. Each maps to roadmap phases. ### Search API -- [ ] **SEARCH-01**: User can execute `collection.search()` with KNN ranking (queryText, queryVector, querySparseVector) and get typed `SearchResult`. -- [ ] **SEARCH-02**: User can compose RRF (Reciprocal Rank Fusion) from multiple weighted rank expressions with arithmetic combinators. -- [ ] **SEARCH-03**: User can project specific fields (`#id`, `#document`, `#embedding`, `#score`, `#metadata`, custom keys) in search results. -- [ ] **SEARCH-04**: User can group search results by metadata key with min/max K controls, and specify read level (INDEX_AND_WAL vs INDEX_ONLY). +- [x] **SEARCH-01**: User can execute `collection.search()` with KNN ranking (queryText, queryVector, querySparseVector) and get typed `SearchResult`. +- [x] **SEARCH-02**: User can compose RRF (Reciprocal Rank Fusion) from multiple weighted rank expressions with arithmetic combinators. +- [x] **SEARCH-03**: User can project specific fields (`#id`, `#document`, `#embedding`, `#score`, `#metadata`, custom keys) in search results. +- [x] **SEARCH-04**: User can group search results by metadata key with min/max K controls, and specify read level (INDEX_AND_WAL vs INDEX_ONLY). ### Embedding Ecosystem @@ -36,8 +36,8 @@ Requirements for the current milestone. Each maps to roadmap phases. ### Cloud Integration Testing - [ ] **CLOUD-01**: Cloud search parity tests cover pagination, IDIn/IDNotIn, document filters, metadata projection, and combined filter scenarios. -- [ ] **CLOUD-02**: Cloud schema/index tests cover distance space variants, HNSW/SPANN config paths, invalid transitions, and schema round-trip assertions. 
-- [ ] **CLOUD-03**: Cloud array metadata tests cover string/number/bool arrays, round-trip retrieval, and contains/not_contains filter behavior. +- [x] **CLOUD-02**: Cloud schema/index tests cover distance space variants, HNSW/SPANN config paths, invalid transitions, and schema round-trip assertions. +- [x] **CLOUD-03**: Cloud array metadata tests cover string/number/bool arrays, round-trip retrieval, and contains/not_contains filter behavior. ## Future Milestones (Post-0.3.0) @@ -62,18 +62,18 @@ Deferred to future milestones. | COLL-01 | Phase 2 | Complete | | COLL-02 | Phase 2 | Complete | | COLL-03 | Phase 2 | Complete | -| SEARCH-01 | Phase 3 | Pending | -| SEARCH-02 | Phase 3 | Pending | -| SEARCH-03 | Phase 3 | Pending | -| SEARCH-04 | Phase 3 | Pending | +| SEARCH-01 | Phase 3 | Complete | +| SEARCH-02 | Phase 3 | Complete | +| SEARCH-03 | Phase 3 | Complete | +| SEARCH-04 | Phase 3 | Complete | | EMB-05 | Phase 4 | Pending | | EMB-06 | Phase 4 | Pending | | EMB-07 | Phase 4 | Pending | | EMB-08 | Phase 4 | Pending | | RERANK-01 | Phase 4 | Pending | | CLOUD-01 | Phase 5 | Pending | -| CLOUD-02 | Phase 5 | Pending | -| CLOUD-03 | Phase 5 | Pending | +| CLOUD-02 | Phase 5 | Complete | +| CLOUD-03 | Phase 5 | Complete | **Coverage:** - 0.3.0 requirements: 17 total diff --git a/.planning/ROADMAP.md b/.planning/ROADMAP.md index a1d0446..c249751 100644 --- a/.planning/ROADMAP.md +++ b/.planning/ROADMAP.md @@ -14,7 +14,7 @@ Decimal phases appear between their surrounding integers in numeric order. - [ ] **Phase 1: Result Ergonomics & WhereDocument** — Add row-based result access and complete WhereDocument typed helpers. - [x] **Phase 2: Collection API Extensions** — Add Collection.fork, Collection.indexingStatus, and cloud feature parity audit. -- [ ] **Phase 3: Search API** — Implement the Search endpoint with ranking expressions, field projection, groupBy, and read levels. 
+- [x] **Phase 3: Search API** — Implement the Search endpoint with ranking expressions, field projection, groupBy, and read levels. (completed 2026-03-22) - [ ] **Phase 4: Embedding Ecosystem** — Add sparse/multimodal interfaces, reranking, new providers, and embedding registry. - [ ] **Phase 5: Cloud Integration Testing** — Build cloud parity test suites for search, schema/index, and array metadata. @@ -63,7 +63,12 @@ Plans: 4. User can group results by metadata key with min/max K controls. 5. User can specify read level (INDEX_AND_WAL vs INDEX_ONLY). 6. Integration tests validate search against Chroma >= 1.5. -**Plans:** TBD +**Plans:** 3/3 plans complete + +Plans: +- [x] 03-01-PLAN.md — Create Search API value types, ranking builders, result interfaces, and SearchBuilder on Collection +- [x] 03-02-PLAN.md — Implement Search DTOs, HTTP wiring, result converters, and SearchBuilderImpl +- [x] 03-03-PLAN.md — Create unit tests, integration tests, and update PublicInterfaceCompatibilityTest ### Phase 4: Embedding Ecosystem **Goal:** Expand the embedding ecosystem with sparse/multimodal interfaces, reranking functions, additional providers, and an auto-wiring registry. @@ -88,7 +93,11 @@ Plans: 2. Cloud schema/index tests cover distance space variants, HNSW/SPANN config, invalid transitions, round-trip assertions. 3. Cloud array metadata tests cover string/number/bool arrays, round-trip retrieval, contains/not_contains filters. 4. Test suite can run in CI with cloud credentials or be skipped gracefully without them. -**Plans:** TBD +**Plans:** 1/2 plans executed + +Plans: +- [x] 05-01-PLAN.md — Schema/index + array metadata cloud tests, mixed-type array client validation +- [ ] 05-02-PLAN.md — Search parity cloud tests (KNN, RRF, GroupBy, batch, pagination, filters, projection, read levels) ## Progress @@ -100,6 +109,6 @@ Phase 4 can execute in parallel with Phases 1-3 (independent). |-------|----------------|--------|-----------| | 1. 
Result Ergonomics & WhereDocument | 2/3 | In Progress| | | 2. Collection API Extensions | 2/2 | Complete | 2026-03-21 | -| 3. Search API | 0/TBD | Pending | — | +| 3. Search API | 3/3 | Complete | 2026-03-22 | | 4. Embedding Ecosystem | 0/TBD | Pending | — | -| 5. Cloud Integration Testing | 0/TBD | Pending | — | +| 5. Cloud Integration Testing | 1/2 | In Progress| | diff --git a/.planning/STATE.md b/.planning/STATE.md index bec9e48..39bc350 100644 --- a/.planning/STATE.md +++ b/.planning/STATE.md @@ -2,14 +2,14 @@ gsd_state_version: 1.0 milestone: v1.5 milestone_name: milestone -status: unknown -stopped_at: Completed 02-collection-api-extensions-02-02-PLAN.md -last_updated: "2026-03-21T13:50:18.109Z" +status: "Phase 03 shipped — PR #139" +stopped_at: Completed 03-search-api-03-03-PLAN.md +last_updated: "2026-03-23T08:38:51.785Z" progress: total_phases: 10 - completed_phases: 7 - total_plans: 18 - completed_plans: 18 + completed_phases: 8 + total_plans: 23 + completed_plans: 22 --- # Project State @@ -19,12 +19,12 @@ progress: See: .planning/PROJECT.md (updated 2026-03-17) **Core value:** Java developers can integrate Chroma quickly and safely with a predictable, strongly-typed client that behaves consistently across environments. 
-**Current focus:** Phase 03 — Search API (Phase 02 Collection API Extensions complete) +**Current focus:** Phase 05 — cloud-integration-testing ## Current Position -Phase: 03 -Plan: Not started +Phase: 05 (cloud-integration-testing) — EXECUTING +Plan: 2 of 2 ## Performance Metrics @@ -64,6 +64,10 @@ Plan: Not started | Phase 01-result-ergonomics-wheredocument P02 | 2 | 2 tasks | 6 files | | Phase 02-collection-api-extensions P01 | 3 | 2 tasks | 7 files | | Phase 02-collection-api-extensions P02 | 4 | 2 tasks | 6 files | +| Phase 05-cloud-integration-testing P01 | 4 | 4 tasks | 3 files | +| Phase 03-search-api P01 | 4 | 2 tasks | 12 files | +| Phase 03-search-api P02 | 3min | 2 tasks | 6 files | +| Phase 03-search-api P03 | 90 | 2 tasks | 7 files | ## Accumulated Context @@ -122,6 +126,17 @@ Recent decisions affecting current work: - [Phase 02-collection-api-extensions]: IndexingStatus uses long fields (not int) for op counts matching Chroma API spec; no convenience isComplete() per D-11 - [Phase 02-collection-api-extensions]: TestContainers tests catch both ChromaNotFoundException and ChromaServerException for skip-on-unavailable — self-hosted returns 5xx for fork/indexingStatus not 404 - [Phase 02-collection-api-extensions]: Cloud fork test gated by CHROMA_RUN_FORK_TESTS=true to avoid per-call cloud cost in CI +- [Phase 05-cloud-integration-testing]: validateMetadataArrayTypes uses ChromaBadRequestException with typed errorCode strings (MIXED_TYPE_ARRAY, NULL_ARRAY_ELEMENT); Integer/Long normalized to Integer group, Float/Double to Float group for homogeneity +- [Phase 05-cloud-integration-testing]: Behavioral wiring tests for metadata validation use ChromaHttpCollection.from() with stub ChromaApiClient at localhost:1 — validation fires before network call +- [Phase 03-search-api]: Knn uses factory+fluent-chain (not inner Builder) because type is factory-discriminated by query type; single required parameter makes builder pattern overkill +- [Phase 
03-search-api]: Rrf.Builder auto-calls knn.withReturnRank() on rank() to prevent returnRank=false pitfall in RRF sub-rankings +- [Phase 03-search-api]: SearchResult.getScores() uses List> (not Float) to match wire format precision +- [Phase 03-search-api]: SearchBuilderImpl in ChromaHttpCollection is stub throwing UnsupportedOperationException; full wiring in Plan 02 +- [Phase 03-search-api]: SearchRequest.searches is List> for polymorphic rank serialization (knn vs rrf) +- [Phase 03-search-api]: 'filter' key used (not 'where') in buildSearchItemMap per Search API wire format spec +- [Phase 03-search-api]: SearchResultImpl stores Double scores internally, downcasts to Float on row access per SearchResultRow contract +- [Phase 03-search-api]: RRF and text queryText skipped via Assume in integration tests — server returns 'unknown variant' for $rrf and rejects string values in $knn.query; tests document intended contract +- [Phase 03-search-api]: Wire format keys corrected to '$knn'/'$rrf' (dollar-prefixed) — bare 'knn'/'rrf' keys rejected by Chroma server ### Roadmap Evolution @@ -137,6 +152,6 @@ None. ## Session Continuity -Last session: 2026-03-21T13:44:30.107Z -Stopped at: Completed 02-collection-api-extensions-02-02-PLAN.md +Last session: 2026-03-22T18:35:36.178Z +Stopped at: Completed 03-search-api-03-03-PLAN.md Resume file: None diff --git a/.planning/phases/02-collection-api-extensions/02-CONTEXT.md b/.planning/phases/02-collection-api-extensions/02-CONTEXT.md new file mode 100644 index 0000000..5dede00 --- /dev/null +++ b/.planning/phases/02-collection-api-extensions/02-CONTEXT.md @@ -0,0 +1,157 @@ +# Phase 2: Collection API Extensions - Context + +**Gathered:** 2026-03-21 +**Status:** Ready for planning + + +## Phase Boundary + +Add cloud-relevant collection operations (fork, fork count, indexing status) to the v2 Collection interface and audit cloud feature parity for all v2 operations. 
No new embedding, search, or record operation work — this phase extends the collection-level API surface only. + + + + +## Implementation Decisions + +### fork() API surface +- **D-01:** `Collection fork(String newName)` — single parameter, returns new Collection reference. +- **D-02:** No options/builder overload — the Chroma server only accepts `new_name`, no metadata or config overrides. +- **D-03:** Fork always creates the new collection in the same tenant/database as the source (no cross-tenant/database targeting). +- **D-04:** Server errors propagate naturally — no client-side cloud guard. Self-hosted will return 404, which maps through the existing exception hierarchy. Future-proof if Chroma adds fork to self-hosted. +- **D-05:** The forked collection inherits the source's embedding function reference (same pattern as Go client). + +### forkCount() API surface +- **D-06:** `int forkCount()` — bare noun, returns the number of forks for this collection. +- **D-07:** Added to Phase 2 scope (not in original requirements). Present in Python/Rust/JS clients, missing from Go client — Java gets parity with Python/Rust/JS here. +- **D-08:** Endpoint: `GET .../collections/{id}/fork_count` → `{"count": N}`. + +### indexingStatus() API surface +- **D-09:** `IndexingStatus indexingStatus()` — bare noun on Collection, consistent with `fork()`, `forkCount()`, `count()`. +- **D-10:** `IndexingStatus` is an immutable value object with JavaBean getters: + - `long getNumIndexedOps()` — operations compacted into the index + - `long getNumUnindexedOps()` — operations still in the WAL + - `long getTotalOps()` — num_indexed + num_unindexed + - `double getOpIndexingProgress()` — 0.0 to 1.0 +- **D-11:** Raw fields only — no convenience methods (e.g., no `isComplete()`). Matches Go client. +- **D-12:** Cloud-only, same server-error-propagation strategy as fork (D-04). 
+ +### Naming conventions +- **D-13:** Bare noun method names for all new operations: `fork()`, `forkCount()`, `indexingStatus()` — consistent with existing `count()`, `add()`, `query()`. +- **D-14:** Javadoc on each cloud-only method uses `Availability:` tag documenting cloud-only status and expected self-hosted error behavior. + +### Testing strategy +- **D-15:** Two-layer testing, aligned with chroma-go: + - **Unit tests** with mock HTTP server (canned JSON responses) — deterministic, runs in CI. + - **Cloud integration tests** against real Chroma Cloud — gated by credentials from `.env`. +- **D-16:** Fork cloud tests skip in CI (forking is expensive at $0.03/call). Indexing status cloud tests can run in CI. +- **D-17:** TestContainers integration tests that call fork/indexingStatus against self-hosted — currently skip (404), auto-activate if Chroma adds self-hosted support later. + +### Cloud parity audit +- **D-18:** Cloud integration tests prove parity — if tests pass, parity is confirmed. +- **D-19:** Javadoc on every v2 Collection and Client method with `Availability:` tag (cloud-only vs self-hosted + cloud). +- **D-20:** README.md gets a "Cloud vs Self-Hosted" section with a comprehensive parity table covering ALL v2 operations, not just Phase 2 additions. +- **D-21:** CHANGELOG entry documents new operations and their cloud-only status. 
+ +### Claude's Discretion +- Mock HTTP server implementation choice (OkHttp MockWebServer, httptest equivalent, or lightweight stub) +- DTO class naming for fork/indexing requests and responses in `ChromaDtos.java` +- `IndexingStatus` implementation details (equals/hashCode/toString) +- Exact README parity table layout and column structure +- How cloud test credentials are loaded (`.env` file, env vars, or both) +- Whether `forkCount()` gets its own DTO or reuses a simple int extraction + + + + +## Specific Ideas + +- Align with chroma-go's `Fork(ctx, newName) (Collection, error)` and `IndexingStatus(ctx) (*IndexingStatus, error)` — Java drops ctx (no context.Context in Java 8) but keeps the same signatures. +- Go client testing uses `httptest.NewServer` with regex URL matching and hardcoded JSON — Java equivalent is OkHttp MockWebServer or similar lightweight approach. +- Fork is copy-on-write on the server (shared data blocks, instant regardless of size) — this is useful context for Javadoc. +- Fork has a 256 fork-edge limit per tree. Exceeding triggers a quota error. This should be noted in Javadoc. +- `forkCount()` is ahead of Go client (which doesn't have it) — differentiator alongside comprehensive parity table. 
+ + + + +## Canonical References + +**Downstream agents MUST read these before planning or implementing.** + +### Collection interface & implementation +- `src/main/java/tech/amikos/chromadb/v2/Collection.java` — Current Collection interface, add fork/forkCount/indexingStatus here +- `src/main/java/tech/amikos/chromadb/v2/ChromaHttpCollection.java` — HTTP implementation, implement new methods here +- `src/main/java/tech/amikos/chromadb/v2/ChromaApiPaths.java` — Endpoint path builders, add fork/forkCount/indexingStatus paths +- `src/main/java/tech/amikos/chromadb/v2/ChromaDtos.java` — Request/response DTOs, add fork request and indexing status response +- `src/main/java/tech/amikos/chromadb/v2/ChromaApiClient.java` — HTTP transport (get/post/put/delete methods) + +### Client & session context +- `src/main/java/tech/amikos/chromadb/v2/ChromaClient.java` — Client implementation, reference for how Collection instances are created and cached + +### Existing value objects (patterns to follow) +- `src/main/java/tech/amikos/chromadb/v2/Tenant.java` — Immutable value object pattern (getName(), equals/hashCode) +- `src/main/java/tech/amikos/chromadb/v2/Database.java` — Immutable value object pattern +- `src/main/java/tech/amikos/chromadb/v2/CollectionConfiguration.java` — Complex immutable value object with builder + +### Exception hierarchy +- `src/main/java/tech/amikos/chromadb/v2/ChromaException.java` — Base exception +- `src/main/java/tech/amikos/chromadb/v2/ChromaExceptions.java` — Factory: `fromHttpResponse(statusCode, message, errorCode)` + +### Testing infrastructure +- `src/test/java/tech/amikos/chromadb/v2/AbstractChromaIntegrationTest.java` — TestContainers base with `assumeMinVersion()` +- `src/test/java/tech/amikos/chromadb/v2/CloudParityIntegrationTest.java` — Cloud test base with credential gating +- `src/test/java/tech/amikos/chromadb/v2/RecordOperationsIntegrationTest.java` — Integration test patterns + +### External references +- Chroma fork API: `POST 
/api/v2/tenants/{t}/databases/{d}/collections/{id}/fork` — body: `{"new_name": "..."}` +- Chroma fork_count API: `GET /api/v2/tenants/{t}/databases/{d}/collections/{id}/fork_count` — response: `{"count": N}` +- Chroma indexing_status API: `GET /api/v2/tenants/{t}/databases/{d}/collections/{id}/indexing_status` — response: `{"num_indexed_ops":N, "num_unindexed_ops":N, "total_ops":N, "op_indexing_progress":F}` +- chroma-go Collection interface: `pkg/api/v2/collection.go` — Fork and IndexingStatus signatures +- chroma-go HTTP impl: `pkg/api/v2/collection_http.go` — Fork and IndexingStatus implementations +- chroma-go unit tests: `pkg/api/v2/collection_http_test.go` — Mock server testing pattern +- chroma-go cloud tests: `pkg/api/v2/client_cloud_test.go` — Cloud integration testing pattern + + + + +## Existing Code Insights + +### Reusable Assets +- `ChromaHttpCollection.modifyName(String)`: Direct HTTP call pattern (validate → build path → apiClient.put → update local state) — blueprint for fork() +- `ChromaHttpCollection.count()`: Simple GET returning a primitive — blueprint for forkCount() +- `ChromaHttpCollection.from(CollectionResponse, ...)`: Static factory for wrapping server response as Collection — reuse for fork() return value +- `Tenant` / `Database`: Immutable value objects with equals/hashCode — pattern for IndexingStatus + +### Established Patterns +- **Interface-first**: Public interface on `Collection`, package-private `ChromaHttpCollection` implementation +- **Immutability**: Private constructor, factory method, defensive copies, unmodifiable collections +- **JavaBean getters**: `getName()`, `getId()`, `getMetadata()` — follow for IndexingStatus +- **Path builders**: Static methods on `ChromaApiPaths` — add `collectionFork()`, `collectionForkCount()`, `collectionIndexingStatus()` +- **DTO inner classes**: All in `ChromaDtos` as static inner classes with Gson annotations + +### Integration Points +- `Collection` interface: Add `fork(String)`, 
`forkCount()`, `indexingStatus()` method signatures +- `ChromaHttpCollection`: Implement the three new methods +- `ChromaApiPaths`: Add three new endpoint path builders +- `ChromaDtos`: Add `ForkCollectionRequest`, `ForkCountResponse`, `IndexingStatusResponse` +- `IndexingStatus`: New public immutable value object in `tech.amikos.chromadb.v2` +- `README.md`: Add cloud vs self-hosted parity table +- `CHANGELOG.md`: Document new operations + + + + +## Deferred Ideas + +- Cross-tenant/cross-database fork targeting — not supported by Chroma server, revisit if server adds it +- `IndexingStatus.isComplete()` convenience method — users can check `getOpIndexingProgress() >= 1.0` themselves +- Polling helper for indexing status (e.g., `awaitIndexing(Duration timeout)`) — application-level concern, not client library +- Fork with metadata/config overrides — not supported by Chroma server +- Fork quota management APIs — depends on Chroma server adding quota introspection endpoints + + + +--- + +*Phase: 02-collection-api-extensions* +*Context gathered: 2026-03-21* diff --git a/.planning/phases/03-search-api/03-01-PLAN.md b/.planning/phases/03-search-api/03-01-PLAN.md new file mode 100644 index 0000000..33ca5af --- /dev/null +++ b/.planning/phases/03-search-api/03-01-PLAN.md @@ -0,0 +1,423 @@ +--- +phase: 03-search-api +plan: 01 +type: execute +wave: 1 +depends_on: [] +files_modified: + - src/main/java/tech/amikos/chromadb/v2/SparseVector.java + - src/main/java/tech/amikos/chromadb/v2/Select.java + - src/main/java/tech/amikos/chromadb/v2/ReadLevel.java + - src/main/java/tech/amikos/chromadb/v2/GroupBy.java + - src/main/java/tech/amikos/chromadb/v2/Knn.java + - src/main/java/tech/amikos/chromadb/v2/Rrf.java + - src/main/java/tech/amikos/chromadb/v2/Search.java + - src/main/java/tech/amikos/chromadb/v2/SearchResult.java + - src/main/java/tech/amikos/chromadb/v2/SearchResultRow.java + - src/main/java/tech/amikos/chromadb/v2/SearchResultGroup.java + - 
src/main/java/tech/amikos/chromadb/v2/Collection.java +autonomous: true +requirements: + - SEARCH-01 + - SEARCH-02 + - SEARCH-03 + - SEARCH-04 + +must_haves: + truths: + - "SparseVector is an immutable value type holding int[] indices and float[] values" + - "Select class has constants DOCUMENT, SCORE, EMBEDDING, METADATA, ID and a key(String) factory" + - "Knn supports queryText, queryEmbedding, querySparseVector factory methods" + - "Rrf supports builder with ranks(Knn, weight) and k parameter" + - "Search is a builder that composes knn or rrf with filter, select, groupBy, limit, offset" + - "SearchResult interface provides rows(searchIndex), groups(searchIndex), isGrouped(), and column accessors" + - "SearchResultRow extends ResultRow with getScore() returning Float" + - "Collection interface declares SearchBuilder search() method" + artifacts: + - path: "src/main/java/tech/amikos/chromadb/v2/SparseVector.java" + provides: "Immutable sparse vector value type" + contains: "public final class SparseVector" + - path: "src/main/java/tech/amikos/chromadb/v2/Select.java" + provides: "Field projection constants and key factory" + contains: "public final class Select" + - path: "src/main/java/tech/amikos/chromadb/v2/Knn.java" + provides: "KNN ranking expression builder" + contains: "public final class Knn" + - path: "src/main/java/tech/amikos/chromadb/v2/Rrf.java" + provides: "RRF ranking expression builder" + contains: "public final class Rrf" + - path: "src/main/java/tech/amikos/chromadb/v2/Search.java" + provides: "Per-search builder composing rank, filter, select, groupBy" + contains: "public final class Search" + - path: "src/main/java/tech/amikos/chromadb/v2/SearchResult.java" + provides: "Search result interface with dual access" + contains: "public interface SearchResult" + - path: "src/main/java/tech/amikos/chromadb/v2/SearchResultRow.java" + provides: "Search result row with score" + contains: "public interface SearchResultRow extends ResultRow" + - path: 
"src/main/java/tech/amikos/chromadb/v2/Collection.java" + provides: "SearchBuilder search() declaration" + contains: "SearchBuilder search()" + key_links: + - from: "src/main/java/tech/amikos/chromadb/v2/Search.java" + to: "src/main/java/tech/amikos/chromadb/v2/Knn.java" + via: "Search.builder().knn(Knn) composition" + pattern: "knn\\(Knn" + - from: "src/main/java/tech/amikos/chromadb/v2/SearchResultRow.java" + to: "src/main/java/tech/amikos/chromadb/v2/ResultRow.java" + via: "interface extension" + pattern: "extends ResultRow" +--- + + +Create all Search API value types, builder interfaces, and result interfaces for Phase 3. + +Purpose: Establish the complete type system (contracts) that downstream plans will implement against. All public-facing types are defined here so Plan 02 (DTOs + wiring) and Plan 03 (tests) have stable contracts. + +Output: 11 new/modified Java source files defining the Search API surface area. + + + +@~/.claude/get-shit-done/workflows/execute-plan.md +@~/.claude/get-shit-done/templates/summary.md + + + +@.planning/PROJECT.md +@.planning/ROADMAP.md +@.planning/STATE.md +@.planning/phases/03-search-api/03-CONTEXT.md +@.planning/phases/03-search-api/03-RESEARCH.md + + + + +From src/main/java/tech/amikos/chromadb/v2/ResultRow.java: +```java +public interface ResultRow { + String getId(); + String getDocument(); + Map getMetadata(); + float[] getEmbedding(); + String getUri(); +} +``` + +From src/main/java/tech/amikos/chromadb/v2/ResultGroup.java: +```java +public interface ResultGroup extends Iterable { + R get(int index); + int size(); + boolean isEmpty(); + Stream stream(); + List toList(); +} +``` + +From src/main/java/tech/amikos/chromadb/v2/QueryResultRow.java (pattern to follow): +```java +public interface QueryResultRow extends ResultRow { + Float getDistance(); +} +``` + +From src/main/java/tech/amikos/chromadb/v2/QueryResult.java (pattern to follow): +```java +public interface QueryResult { + List> getIds(); + List> getDocuments(); + 
List>> getMetadatas(); + List> getEmbeddings(); + List> getDistances(); + List> getUris(); + ResultGroup rows(int queryIndex); + int groupCount(); + Stream> stream(); +} +``` + +From src/main/java/tech/amikos/chromadb/v2/Collection.java (interface to extend): +```java +public interface Collection { + // ... existing methods ... + // Builder interfaces: AddBuilder, QueryBuilder, GetBuilder, UpdateBuilder, UpsertBuilder, DeleteBuilder + // Cloud operations: fork(), forkCount(), indexingStatus() +} +``` + +From src/main/java/tech/amikos/chromadb/v2/Where.java (reused for search filters): +```java +public class Where { + public static Where eq(String key, Object value) { ... } + public static Where idIn(String... ids) { ... } + public Map toMap() { ... } +} +``` + + + + + + + Task 1: Create Search API value types (SparseVector, Select, ReadLevel, GroupBy) + + src/main/java/tech/amikos/chromadb/v2/SparseVector.java + src/main/java/tech/amikos/chromadb/v2/Select.java + src/main/java/tech/amikos/chromadb/v2/ReadLevel.java + src/main/java/tech/amikos/chromadb/v2/GroupBy.java + + + src/main/java/tech/amikos/chromadb/v2/Include.java + src/main/java/tech/amikos/chromadb/v2/ResultRow.java + src/main/java/tech/amikos/chromadb/v2/CollectionConfiguration.java + + +Create four new value types in `tech.amikos.chromadb.v2`: + +**1. 
SparseVector.java** (per D-19, D-20): +- `public final class SparseVector` +- Private constructor: `SparseVector(int[] indices, float[] values)` +- Factory method: `public static SparseVector of(int[] indices, float[] values)` + - Validates: both arrays non-null, same length, throws `IllegalArgumentException` otherwise + - Defensively copies both arrays via `Arrays.copyOf()` +- Getters: `public int[] getIndices()` and `public float[] getValues()` — both return defensive copies via `Arrays.copyOf()` +- `equals()`, `hashCode()` using `Arrays.equals()`/`Arrays.hashCode()` +- `toString()` using `Arrays.toString()` +- Java 8 compatible, no lambdas needed + +**2. Select.java** (per D-13 through D-18): +- `public final class Select` +- Private final `String key` field +- Private constructor: `Select(String key)` +- Five public static final constants: + - `public static final Select DOCUMENT = new Select("#document");` + - `public static final Select SCORE = new Select("#score");` + - `public static final Select EMBEDDING = new Select("#embedding");` + - `public static final Select METADATA = new Select("#metadata");` + - `public static final Select ID = new Select("#id");` +- Factory method: `public static Select key(String fieldName)` — validates non-null, non-blank, returns `new Select(fieldName)`. Does NOT prepend `#` — custom keys go without prefix per wire format. +- Getter: `public String getKey()` — returns the string key +- Convenience: `public static Select[] all()` — returns `new Select[]{ID, DOCUMENT, EMBEDDING, METADATA, SCORE}` +- `equals()` based on `key`, `hashCode()` based on `key`, `toString()` returns `"Select(" + key + ")"` + +**3. ReadLevel.java** (per SEARCH-04): +- `public enum ReadLevel` +- Two constants: + - `INDEX_AND_WAL("index_and_wal")` + - `INDEX_ONLY("index_only")` +- Private `String value` field, constructor, getter `public String getValue()` +- Static `fromValue(String)` method following `Include.fromValue()` pattern exactly + +**4. 
GroupBy.java** (per SEARCH-04): +- `public final class GroupBy` +- Private fields: `String key`, `Integer minK`, `Integer maxK` +- Private constructor (all fields) +- Static `public static Builder builder()` returning inner `Builder` class +- `Builder` has: `key(String)` (required), `minK(int)`, `maxK(int)`, `build()` — build() throws `IllegalArgumentException` if key is null or blank +- Getters: `getKey()`, `getMinK()` (returns `Integer`, nullable), `getMaxK()` (returns `Integer`, nullable) +- `equals()`, `hashCode()`, `toString()` + + + cd /Users/tazarov/experiments/amikos/chromadb-java-client && mvn compile -pl . -q 2>&1 | tail -5 + + + - `src/main/java/tech/amikos/chromadb/v2/SparseVector.java` contains `public static SparseVector of(int[] indices, float[] values)` + - `src/main/java/tech/amikos/chromadb/v2/SparseVector.java` contains `public int[] getIndices()` + - `src/main/java/tech/amikos/chromadb/v2/SparseVector.java` contains `Arrays.copyOf(indices` + - `src/main/java/tech/amikos/chromadb/v2/Select.java` contains `public static final Select DOCUMENT = new Select("#document")` + - `src/main/java/tech/amikos/chromadb/v2/Select.java` contains `public static final Select SCORE = new Select("#score")` + - `src/main/java/tech/amikos/chromadb/v2/Select.java` contains `public static Select key(String fieldName)` + - `src/main/java/tech/amikos/chromadb/v2/Select.java` contains `public static Select[] all()` + - `src/main/java/tech/amikos/chromadb/v2/ReadLevel.java` contains `INDEX_AND_WAL("index_and_wal")` + - `src/main/java/tech/amikos/chromadb/v2/ReadLevel.java` contains `INDEX_ONLY("index_only")` + - `src/main/java/tech/amikos/chromadb/v2/GroupBy.java` contains `public static Builder builder()` + - `src/main/java/tech/amikos/chromadb/v2/GroupBy.java` contains `public String getKey()` + - `mvn compile` exits 0 + + SparseVector, Select, ReadLevel, and GroupBy value types compile successfully with all factory methods, getters, equals/hashCode, and defensive 
copies. + + + + Task 2: Create ranking builders (Knn, Rrf), Search builder, result interfaces, and SearchBuilder on Collection + + src/main/java/tech/amikos/chromadb/v2/Knn.java + src/main/java/tech/amikos/chromadb/v2/Rrf.java + src/main/java/tech/amikos/chromadb/v2/Search.java + src/main/java/tech/amikos/chromadb/v2/SearchResult.java + src/main/java/tech/amikos/chromadb/v2/SearchResultRow.java + src/main/java/tech/amikos/chromadb/v2/SearchResultGroup.java + src/main/java/tech/amikos/chromadb/v2/Collection.java + + + src/main/java/tech/amikos/chromadb/v2/SparseVector.java + src/main/java/tech/amikos/chromadb/v2/Select.java + src/main/java/tech/amikos/chromadb/v2/ReadLevel.java + src/main/java/tech/amikos/chromadb/v2/GroupBy.java + src/main/java/tech/amikos/chromadb/v2/Collection.java + src/main/java/tech/amikos/chromadb/v2/QueryResult.java + src/main/java/tech/amikos/chromadb/v2/QueryResultRow.java + src/main/java/tech/amikos/chromadb/v2/ResultRow.java + src/main/java/tech/amikos/chromadb/v2/ResultGroup.java + src/main/java/tech/amikos/chromadb/v2/Where.java + + +Create ranking builders, the Search per-search builder, result interfaces, and extend Collection. + +**1. 
Knn.java** (per D-01, D-02, SEARCH-01): +- `public final class Knn` +- Private fields: `Object query` (String, float[], or SparseVector), `String key`, `Integer limit`, `Double defaultScore`, `boolean returnRank` +- Private constructor with all fields +- Static factory methods (NO builder — Knn is simple enough for factories): + - `public static Knn queryText(String text)` — sets query=text, key="#embedding", returnRank=false + - `public static Knn queryEmbedding(float[] embedding)` — sets query=defensiveCopy(embedding), key="#embedding", returnRank=false + - `public static Knn querySparseVector(SparseVector sparseVector)` — sets query=sparseVector, key=null (caller must set key via chain), returnRank=false +- Fluent chainable methods (return new Knn with modified field — immutable): + - `public Knn key(String key)` — sets the query key field (e.g., "#embedding" or "sparse_field") + - `public Knn limit(int limit)` — per-rank limit + - `public Knn defaultScore(double score)` — default score for missing results + - `public Knn returnRank(boolean returnRank)` — sets return_rank flag (needed for RRF sub-ranks) +- Getters: `getQuery()` (returns Object), `getKey()`, `getLimit()` (Integer, nullable), `getDefaultScore()` (Double, nullable), `isReturnRank()` +- Package-private method: `Knn withReturnRank()` — returns copy with returnRank=true (used by Rrf builder to auto-set) + +**2. 
Rrf.java** (per SEARCH-02): +- `public final class Rrf` +- Private fields: `List ranks`, `int k`, `boolean normalize` +- Inner class: `public static final class RankWithWeight { final Knn knn; final double weight; }` — immutable, package-private fields, public getters +- Private constructor +- Static: `public static Builder builder()` +- `Builder` has: + - `public Builder rank(Knn knn, double weight)` — adds rank entry; auto-sets returnRank=true on the Knn via `knn.withReturnRank()` (per Pitfall 3 in RESEARCH) + - `public Builder k(int k)` — default 60 + - `public Builder normalize(boolean normalize)` — default false + - `public Rrf build()` — validates at least 1 rank, returns immutable Rrf +- Getters: `getRanks()` (unmodifiable List), `getK()`, `isNormalize()` + +**3. Search.java** (per D-03, D-04, D-16): +- `public final class Search` +- Private fields: `Knn knn`, `Rrf rrf`, `Where filter`, `List getSelect(); + public GroupBy getGroupBy(); + public Integer getLimit(); + public Integer getOffset(); +} +``` + +From src/main/java/tech/amikos/chromadb/v2/Select.java: +```java +public final class Select { + public static final Select DOCUMENT, SCORE, EMBEDDING, METADATA, ID; + public static Select key(String fieldName); + public String getKey(); +} +``` + +From src/main/java/tech/amikos/chromadb/v2/SparseVector.java: +```java +public final class SparseVector { + public static SparseVector of(int[] indices, float[] values); + public int[] getIndices(); + public float[] getValues(); +} +``` + +From src/main/java/tech/amikos/chromadb/v2/Collection.java SearchBuilder: +```java +interface SearchBuilder { + SearchBuilder queryText(String text); + SearchBuilder queryEmbedding(float[] embedding); + SearchBuilder searches(Search... 
searches); + SearchBuilder where(Where globalFilter); + SearchBuilder limit(int limit); + SearchBuilder offset(int offset); + SearchBuilder readLevel(ReadLevel readLevel); + SearchResult execute(); +} +``` + +From src/main/java/tech/amikos/chromadb/v2/SearchResult.java: +```java +public interface SearchResult { + List> getIds(); + List> getDocuments(); + List>> getMetadatas(); + List> getEmbeddings(); + List> getScores(); + ResultGroup rows(int searchIndex); + List groups(int searchIndex); + boolean isGrouped(); + int groupCount(); + Stream> stream(); +} +``` + +From src/main/java/tech/amikos/chromadb/v2/ChromaHttpCollection.java (existing pattern): +```java +// Inner builder pattern — QueryBuilderImpl on line 648 +private final class QueryBuilderImpl implements QueryBuilder { + // fields, setters return this, execute() calls apiClient.post() +} +// Factory: @Override public QueryBuilder query() { return new QueryBuilderImpl(); } +``` + +From src/main/java/tech/amikos/chromadb/v2/ChromaDtos.java (existing pattern): +```java +static final class QueryRequest { + final List> queryEmbeddings; + @SerializedName("n_results") final int nResults; + // ... +} +static final class QueryResponse { + List> ids; + List> documents; + // ... 
+} +``` + +From src/main/java/tech/amikos/chromadb/v2/ChromaApiPaths.java: +```java +static String collectionQuery(String tenant, String db, String id) { + return collectionById(tenant, db, id) + "/query"; +} +``` + + + + + + + Task 1: Add Search DTOs to ChromaDtos and search path to ChromaApiPaths + + src/main/java/tech/amikos/chromadb/v2/ChromaDtos.java + src/main/java/tech/amikos/chromadb/v2/ChromaApiPaths.java + + + src/main/java/tech/amikos/chromadb/v2/ChromaDtos.java + src/main/java/tech/amikos/chromadb/v2/ChromaApiPaths.java + src/main/java/tech/amikos/chromadb/v2/Knn.java + src/main/java/tech/amikos/chromadb/v2/Rrf.java + src/main/java/tech/amikos/chromadb/v2/Search.java + src/main/java/tech/amikos/chromadb/v2/Select.java + src/main/java/tech/amikos/chromadb/v2/SparseVector.java + src/main/java/tech/amikos/chromadb/v2/GroupBy.java + + +**ChromaApiPaths.java** — Add one new path method after `collectionIndexingStatus`: + +```java +static String collectionSearch(String tenant, String db, String id) { + return collectionById(tenant, db, id) + "/search"; +} +``` + +**ChromaDtos.java** — Add the following DTO inner classes at the end of the file (before the closing `}`), following the established `static final class` pattern with `@SerializedName` annotations. All DTOs are package-private. + +1. **SearchRequest** (top-level request envelope): +```java +static final class SearchRequest { + final List> searches; + @SerializedName("read_level") + final String readLevel; + + SearchRequest(List> searches, String readLevel) { + this.searches = searches; + this.readLevel = readLevel; + } +} +``` +NOTE: `searches` uses `List>` rather than typed DTOs because the `rank` field needs polymorphic serialization (`{"knn":{...}}` or `{"rrf":{...}}`). Each map is assembled by the builder using `toSearchItemMap()` helper methods. + +2. 
**SearchResponse** (response envelope — matches wire format): +```java +static final class SearchResponse { + List> ids; + List> documents; + List>> metadatas; + List>> embeddings; + List> scores; +} +``` +NOTE: `embeddings` is `List>>` — outer=search, middle=row, inner=embedding vector. `scores` is `List>` — outer=search, inner=per-row score. Both `Float` (for Gson deserialization) and `Double` (for scores) match their JSON wire types. + +3. **Static helper methods** for building search item maps (package-private, called by SearchBuilderImpl): + +```java +static Map buildKnnRankMap(Knn knn) { + Map knnMap = new LinkedHashMap(); + Object query = knn.getQuery(); + if (query instanceof String) { + knnMap.put("query", query); + } else if (query instanceof float[]) { + knnMap.put("query", toFloatList((float[]) query)); + } else if (query instanceof SparseVector) { + SparseVector sv = (SparseVector) query; + Map svMap = new LinkedHashMap(); + List indices = new ArrayList(sv.getIndices().length); + for (int idx : sv.getIndices()) indices.add(idx); + svMap.put("indices", indices); + List values = new ArrayList(sv.getValues().length); + for (float v : sv.getValues()) values.add(v); + svMap.put("values", values); + knnMap.put("query", svMap); + } + if (knn.getKey() != null) knnMap.put("key", knn.getKey()); + if (knn.getLimit() != null) knnMap.put("limit", knn.getLimit()); + if (knn.getDefaultScore() != null) knnMap.put("default", knn.getDefaultScore()); + if (knn.isReturnRank()) knnMap.put("return_rank", true); + Map wrapper = new LinkedHashMap(); + wrapper.put("knn", knnMap); + return wrapper; +} + +static Map buildRrfRankMap(Rrf rrf) { + Map rrfMap = new LinkedHashMap(); + List> ranksList = new ArrayList>(); + for (Rrf.RankWithWeight rw : rrf.getRanks()) { + Map entry = new LinkedHashMap(); + entry.put("rank", buildKnnRankMap(rw.getKnn())); + entry.put("weight", rw.getWeight()); + ranksList.add(entry); + } + rrfMap.put("ranks", ranksList); + rrfMap.put("k", rrf.getK()); + 
if (rrf.isNormalize()) rrfMap.put("normalize", true); + Map wrapper = new LinkedHashMap(); + wrapper.put("rrf", rrfMap); + return wrapper; +} + +static Map buildSearchItemMap(Search search, Where globalFilter) { + Map item = new LinkedHashMap(); + + // rank + if (search.getKnn() != null) { + item.put("rank", buildKnnRankMap(search.getKnn())); + } else if (search.getRrf() != null) { + item.put("rank", buildRrfRankMap(search.getRrf())); + } + + // filter — merge per-search and global (per D-04) + Map filterMap = null; + Where perSearchFilter = search.getFilter(); + if (perSearchFilter != null && globalFilter != null) { + // Merge: per-search entries win on key conflict + filterMap = new LinkedHashMap(globalFilter.toMap()); + filterMap.putAll(perSearchFilter.toMap()); + } else if (perSearchFilter != null) { + filterMap = perSearchFilter.toMap(); + } else if (globalFilter != null) { + filterMap = globalFilter.toMap(); + } + if (filterMap != null && !filterMap.isEmpty()) { + item.put("filter", filterMap); + } + + // select (per D-16, D-18) + List selectList = search.getSelect(); + if (selectList != null && !selectList.isEmpty()) { + Map selectMap = new LinkedHashMap(); + List keys = new ArrayList(selectList.size()); + for (Select s : selectList) keys.add(s.getKey()); + selectMap.put("keys", keys); + item.put("select", selectMap); + } + + // limit/offset + if (search.getLimit() != null || search.getOffset() != null) { + Map pageMap = new LinkedHashMap(); + if (search.getLimit() != null) pageMap.put("limit", search.getLimit()); + if (search.getOffset() != null) pageMap.put("offset", search.getOffset()); + item.put("limit", pageMap); + } + + // group_by + GroupBy gb = search.getGroupBy(); + if (gb != null) { + Map gbMap = new LinkedHashMap(); + gbMap.put("key", gb.getKey()); + if (gb.getMinK() != null) gbMap.put("min_k", gb.getMinK()); + if (gb.getMaxK() != null) gbMap.put("max_k", gb.getMaxK()); + item.put("group_by", gbMap); + } + + return item; + } } diff --git 
a/src/main/java/tech/amikos/chromadb/v2/ChromaHttpCollection.java b/src/main/java/tech/amikos/chromadb/v2/ChromaHttpCollection.java index 00cd24a..cdfaf9b 100644 --- a/src/main/java/tech/amikos/chromadb/v2/ChromaHttpCollection.java +++ b/src/main/java/tech/amikos/chromadb/v2/ChromaHttpCollection.java @@ -453,6 +453,11 @@ public DeleteBuilder delete() { return new DeleteBuilderImpl(); } + @Override + public SearchBuilder search() { + return new SearchBuilderImpl(); + } + // --- Builder implementations --- private final class AddBuilderImpl implements AddBuilder { @@ -528,6 +533,7 @@ public AddBuilder uris(List uris) { @Override public void execute() { + validateMetadataArrayTypes(metadatas); List resolvedIds = resolveIds(ids, idGenerator, documents, embeddings, metadatas, uris); if (hasExplicitIds(ids)) { checkForDuplicateIds(resolvedIds); @@ -622,6 +628,7 @@ public UpsertBuilder uris(List uris) { @Override public void execute() { + validateMetadataArrayTypes(metadatas); List resolvedIds = resolveIds(ids, idGenerator, documents, embeddings, metadatas, uris); if (hasExplicitIds(ids)) { checkForDuplicateIds(resolvedIds); @@ -869,6 +876,7 @@ public UpdateBuilder metadatas(List> metadatas) { @Override public void execute() { + validateMetadataArrayTypes(metadatas); if (ids == null || ids.isEmpty()) { throw new IllegalArgumentException("ids must not be empty"); } @@ -934,6 +942,107 @@ public void execute() { } } + private final class SearchBuilderImpl implements SearchBuilder { + + private List searches; + private Where globalFilter; + private Integer globalLimit; + private Integer globalOffset; + private ReadLevel readLevel; + + @Override + public SearchBuilder queryText(String text) { + Objects.requireNonNull(text, "text"); + this.searches = Collections.singletonList( + Search.builder().knn(Knn.queryText(text)).build() + ); + return this; + } + + @Override + public SearchBuilder queryEmbedding(float[] embedding) { + Objects.requireNonNull(embedding, "embedding"); + 
this.searches = Collections.singletonList( + Search.builder().knn(Knn.queryEmbedding(embedding)).build() + ); + return this; + } + + @Override + public SearchBuilder searches(Search... searches) { + Objects.requireNonNull(searches, "searches"); + for (int i = 0; i < searches.length; i++) { + if (searches[i] == null) { + throw new IllegalArgumentException("searches[" + i + "] must not be null"); + } + } + this.searches = Arrays.asList(searches); + return this; + } + + @Override + public SearchBuilder where(Where globalFilter) { + Objects.requireNonNull(globalFilter, "globalFilter"); + this.globalFilter = globalFilter; + return this; + } + + @Override + public SearchBuilder limit(int limit) { + if (limit <= 0) throw new IllegalArgumentException("limit must be > 0"); + this.globalLimit = limit; + return this; + } + + @Override + public SearchBuilder offset(int offset) { + if (offset < 0) throw new IllegalArgumentException("offset must be >= 0"); + this.globalOffset = offset; + return this; + } + + @Override + public SearchBuilder readLevel(ReadLevel readLevel) { + Objects.requireNonNull(readLevel, "readLevel"); + this.readLevel = readLevel; + return this; + } + + @Override + public SearchResult execute() { + if (searches == null || searches.isEmpty()) { + throw new IllegalArgumentException( + "At least one search must be specified via queryText(), queryEmbedding(), or searches()"); + } + + // Build effective search list, applying global limit/offset where search has none + List effectiveSearches = new ArrayList(searches.size()); + for (Search s : searches) { + boolean needsLimit = s.getLimit() == null && globalLimit != null; + boolean needsOffset = s.getOffset() == null && globalOffset != null; + if (needsLimit || needsOffset) { + Search.Builder b = s.toBuilder(); + if (needsLimit) b.limit(globalLimit); + if (needsOffset) b.offset(globalOffset); + effectiveSearches.add(b.build()); + } else { + effectiveSearches.add(s); + } + } + + List> searchItems = new 
ArrayList>(effectiveSearches.size()); + for (Search s : effectiveSearches) { + searchItems.add(ChromaDtos.buildSearchItemMap(s, globalFilter)); + } + String rl = readLevel != null ? readLevel.getValue() : null; + ChromaDtos.SearchRequest request = new ChromaDtos.SearchRequest(searchItems, rl); + + String path = ChromaApiPaths.collectionSearch(tenant.getName(), database.getName(), id); + ChromaDtos.SearchResponse dto = apiClient.post(path, request, ChromaDtos.SearchResponse.class); + return SearchResultImpl.from(dto); + } + } + private static String requireNonBlankField(String fieldName, String value) { if (value == null || value.trim().isEmpty()) { throw new ChromaDeserializationException( @@ -1256,6 +1365,74 @@ private List embedQueryTexts(List texts) { return vectors; } + /** + * Validates that all List values in metadata maps contain homogeneous types. + * Mixed-type arrays (e.g., ["foo", 42, true]) are rejected before sending to server. + * + * @throws ChromaBadRequestException if any metadata map contains a List with mixed types or null elements + */ + static void validateMetadataArrayTypes(List> metadatas) { + if (metadatas == null) { + return; + } + for (int i = 0; i < metadatas.size(); i++) { + Map meta = metadatas.get(i); + if (meta == null) { + continue; + } + for (Map.Entry entry : meta.entrySet()) { + Object value = entry.getValue(); + if (value instanceof List) { + validateHomogeneousList(entry.getKey(), (List) value, i); + } + } + } + } + + private static void validateHomogeneousList(String key, List list, int recordIndex) { + if (list.isEmpty()) { + return; // empty arrays are valid + } + Class firstType = null; + for (int j = 0; j < list.size(); j++) { + Object element = list.get(j); + if (element == null) { + throw new ChromaBadRequestException( + "metadata[" + recordIndex + "]." 
+ key + "[" + j + "] is null; " + + "array metadata values must not contain null elements", + "NULL_ARRAY_ELEMENT" + ); + } + Class normalizedType = normalizeNumericType(element.getClass()); + if (firstType == null) { + firstType = normalizedType; + } else if (!firstType.equals(normalizedType)) { + throw new ChromaBadRequestException( + "metadata[" + recordIndex + "]." + key + " contains mixed types: " + + "expected " + firstType.getSimpleName() + " but found " + + element.getClass().getSimpleName() + " at index " + j + + "; array metadata values must be homogeneous", + "MIXED_TYPE_ARRAY" + ); + } + } + } + + /** + * Normalizes numeric types to a common base for homogeneity comparison. + * Integer, Long, Short, Byte -> Integer (integer group) + * Float, Double -> Float (floating group) + */ + private static Class normalizeNumericType(Class clazz) { + if (clazz == Integer.class || clazz == Long.class || clazz == Short.class || clazz == Byte.class) { + return Integer.class; + } + if (clazz == Float.class || clazz == Double.class) { + return Float.class; + } + return clazz; + } + private static List validateQueryTexts(List texts) { if (texts == null) { throw new NullPointerException("texts"); diff --git a/src/main/java/tech/amikos/chromadb/v2/Collection.java b/src/main/java/tech/amikos/chromadb/v2/Collection.java index d77e056..21ef350 100644 --- a/src/main/java/tech/amikos/chromadb/v2/Collection.java +++ b/src/main/java/tech/amikos/chromadb/v2/Collection.java @@ -152,6 +152,16 @@ default Schema getSchema() { */ DeleteBuilder delete(); + /** + * Returns a builder for searching records in this collection. + * + *

Search uses ranking expressions (KNN, RRF) and field projection (Select) + * instead of the query endpoint's Include enum. Requires Chroma >= 1.5.

+ * + *

Availability: Chroma Cloud and self-hosted >= 1.5.

+ */ + SearchBuilder search(); + /** * Returns the total number of records in this collection. * @@ -386,4 +396,78 @@ interface DeleteBuilder { */ void execute(); } + + /** + * Fluent builder for the Search API endpoint. + * + *

Use {@link #queryText(String)} or {@link #queryEmbedding(float[])} as convenience + * shortcuts for single KNN searches, or {@link #searches(Search...)} for batch/complex cases. + * Call {@link #execute()} to submit the search and obtain a {@link SearchResult}.

+ */ + interface SearchBuilder { + /** + * Convenience shortcut: creates a single {@link Search} with a text-based KNN. + * Replaces any previously configured searches. + * + * @param text the query text; must not be null + */ + SearchBuilder queryText(String text); + + /** + * Convenience shortcut: creates a single {@link Search} with an embedding-based KNN. + * Replaces any previously configured searches. + * + * @param embedding the query embedding; must not be null + */ + SearchBuilder queryEmbedding(float[] embedding); + + /** + * Sets one or more {@link Search} configurations for batch or complex search scenarios. + * Replaces any previously configured searches. + * + * @param searches one or more search configurations; must not be null or contain nulls + */ + SearchBuilder searches(Search... searches); + + /** + * Sets a global metadata/ID filter applied to all searches. + * + * @param globalFilter the global where filter; must not be null + */ + SearchBuilder where(Where globalFilter); + + /** + * Sets a default result limit applied to individual searches that do not specify their + * own limit. This is a per-search fallback, not a global cap across all search inputs. + * + * @param limit maximum number of results per search; must be positive + */ + SearchBuilder limit(int limit); + + /** + * Sets a default result offset applied to individual searches that do not specify their + * own offset. This is a per-search fallback, not a global cap across all search inputs. + * + * @param offset number of results to skip per search; must be non-negative + */ + SearchBuilder offset(int offset); + + /** + * Sets the read level controlling which data sources are queried. + * + * @param readLevel the read level; must not be null + */ + SearchBuilder readLevel(ReadLevel readLevel); + + /** + * Executes the search and returns the result. 
+ * + * @return search result containing all matched records + * @throws IllegalArgumentException if no search was configured via queryText(), + * queryEmbedding(), or searches() + * @throws ChromaBadRequestException if the search request is invalid + * @throws ChromaException on other server errors + */ + SearchResult execute(); + } } diff --git a/src/main/java/tech/amikos/chromadb/v2/GroupBy.java b/src/main/java/tech/amikos/chromadb/v2/GroupBy.java new file mode 100644 index 0000000..9a394ae --- /dev/null +++ b/src/main/java/tech/amikos/chromadb/v2/GroupBy.java @@ -0,0 +1,140 @@ +package tech.amikos.chromadb.v2; + +import java.util.Objects; + +/** + * Configuration for grouping search results by a metadata key. + * + *

When a {@code GroupBy} is set on a {@link Search}, results are partitioned by the distinct + * values of the specified metadata key. Optional {@code minK} and {@code maxK} bounds control + * how many records per group are returned.

+ * + *

Instances are immutable and thread-safe. Use {@link #builder()} to construct.

+ */ +public final class GroupBy { + + private final String key; + private final Integer minK; + private final Integer maxK; + + private GroupBy(String key, Integer minK, Integer maxK) { + this.key = key; + this.minK = minK; + this.maxK = maxK; + } + + /** + * Returns a new {@link Builder} for constructing a {@code GroupBy} instance. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Returns the metadata key to group by. + */ + public String getKey() { + return key; + } + + /** + * Returns the minimum number of results per group, or {@code null} if not set. + */ + public Integer getMinK() { + return minK; + } + + /** + * Returns the maximum number of results per group, or {@code null} if not set. + */ + public Integer getMaxK() { + return maxK; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof GroupBy)) return false; + GroupBy that = (GroupBy) o; + return Objects.equals(key, that.key) + && Objects.equals(minK, that.minK) + && Objects.equals(maxK, that.maxK); + } + + @Override + public int hashCode() { + return Objects.hash(key, minK, maxK); + } + + @Override + public String toString() { + return "GroupBy{key='" + key + "', minK=" + minK + ", maxK=" + maxK + '}'; + } + + /** + * Builder for {@link GroupBy}. + */ + public static final class Builder { + + private String key; + private Integer minK; + private Integer maxK; + + private Builder() {} + + /** + * Sets the metadata key to group by. Required. + * + * @param key non-null, non-blank metadata key + * @return this builder + */ + public Builder key(String key) { + this.key = key; + return this; + } + + /** + * Sets the minimum number of results per group. + * + * @param minK minimum results per group + * @return this builder + */ + public Builder minK(int minK) { + this.minK = minK; + return this; + } + + /** + * Sets the maximum number of results per group. 
+ * + * @param maxK maximum results per group + * @return this builder + */ + public Builder maxK(int maxK) { + this.maxK = maxK; + return this; + } + + /** + * Builds the {@link GroupBy} instance. + * + * @return an immutable {@code GroupBy} + * @throws IllegalArgumentException if {@code key} is null or blank + */ + public GroupBy build() { + if (key == null || key.trim().isEmpty()) { + throw new IllegalArgumentException("key must not be null or blank"); + } + if (minK != null && minK < 1) { + throw new IllegalArgumentException("minK must be >= 1, got " + minK); + } + if (maxK != null && maxK < 1) { + throw new IllegalArgumentException("maxK must be >= 1, got " + maxK); + } + if (minK != null && maxK != null && minK > maxK) { + throw new IllegalArgumentException("minK (" + minK + ") must not exceed maxK (" + maxK + ")"); + } + return new GroupBy(key, minK, maxK); + } + } +} diff --git a/src/main/java/tech/amikos/chromadb/v2/ImmutableCopyUtils.java b/src/main/java/tech/amikos/chromadb/v2/ImmutableCopyUtils.java new file mode 100644 index 0000000..41129cb --- /dev/null +++ b/src/main/java/tech/amikos/chromadb/v2/ImmutableCopyUtils.java @@ -0,0 +1,72 @@ +package tech.amikos.chromadb.v2; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * Package-private utilities for creating deeply immutable copies of nested collection structures. + * Shared by {@link QueryResultImpl}, {@link SearchResultImpl}, and {@link GetResultImpl}. 
+ */ +final class ImmutableCopyUtils { + + private ImmutableCopyUtils() {} + + static List> nestedList(List> source) { + if (source == null) { + return null; + } + List> outer = new ArrayList>(source.size()); + for (List inner : source) { + if (inner == null) { + outer.add(null); + } else { + outer.add(Collections.unmodifiableList(new ArrayList(inner))); + } + } + return Collections.unmodifiableList(outer); + } + + static List>> nestedMetadata(List>> source) { + if (source == null) { + return null; + } + List>> outer = new ArrayList>>(source.size()); + for (List> inner : source) { + if (inner == null) { + outer.add(null); + continue; + } + List> innerCopy = new ArrayList>(inner.size()); + for (Map metadata : inner) { + innerCopy.add(metadata == null + ? null + : Collections.unmodifiableMap(new LinkedHashMap(metadata))); + } + outer.add(Collections.unmodifiableList(innerCopy)); + } + return Collections.unmodifiableList(outer); + } + + static List> nestedEmbeddings(List> source) { + if (source == null) { + return null; + } + List> outer = new ArrayList>(source.size()); + for (List inner : source) { + if (inner == null) { + outer.add(null); + continue; + } + List innerCopy = new ArrayList(inner.size()); + for (float[] embedding : inner) { + innerCopy.add(embedding == null ? null : Arrays.copyOf(embedding, embedding.length)); + } + outer.add(Collections.unmodifiableList(innerCopy)); + } + return Collections.unmodifiableList(outer); + } +} diff --git a/src/main/java/tech/amikos/chromadb/v2/Knn.java b/src/main/java/tech/amikos/chromadb/v2/Knn.java new file mode 100644 index 0000000..baa098a --- /dev/null +++ b/src/main/java/tech/amikos/chromadb/v2/Knn.java @@ -0,0 +1,184 @@ +package tech.amikos.chromadb.v2; + +import java.util.Arrays; + +/** + * KNN (K-Nearest Neighbor) ranking expression for the Search API. + * + *

Use the static factory methods to create an initial instance, then chain fluent setters to + * configure optional parameters. All chainable methods return a new immutable instance.

+ * + *
{@code
+ * // Text-based KNN
+ * Knn knn = Knn.queryText("search query").limit(10);
+ *
+ * // Embedding-based KNN
+ * Knn knn = Knn.queryEmbedding(new float[]{0.1f, 0.2f}).limit(10);
+ *
+ * // Sparse vector KNN
+ * Knn knn = Knn.querySparseVector(SparseVector.of(indices, values)).key("sparse_field");
+ * }
+ */ +public final class Knn { + + private final Object query; + private final String key; + private final Integer limit; + private final Double defaultScore; + private final boolean returnRank; + + private Knn(Object query, String key, Integer limit, Double defaultScore, boolean returnRank) { + this.query = query; + this.key = key; + this.limit = limit; + this.defaultScore = defaultScore; + this.returnRank = returnRank; + } + + /** + * Creates a KNN query by text. The text is sent to the server, which uses the collection's + * server-side embedding function to convert it to an embedding. + * + *

Unlike {@link Collection.QueryBuilder#queryTexts(String...)}, no client-side embedding + * function is invoked.

+ * + * @param text the query text; must not be null + * @return a new {@code Knn} instance + * @throws IllegalArgumentException if {@code text} is null + */ + public static Knn queryText(String text) { + if (text == null) { + throw new IllegalArgumentException("text must not be null"); + } + return new Knn(text, "#embedding", null, null, false); + } + + /** + * Creates a KNN query by raw embedding vector. A defensive copy is made. + * + * @param embedding the query embedding; must not be null + * @return a new {@code Knn} instance + * @throws IllegalArgumentException if {@code embedding} is null + */ + public static Knn queryEmbedding(float[] embedding) { + if (embedding == null) { + throw new IllegalArgumentException("embedding must not be null"); + } + return new Knn(Arrays.copyOf(embedding, embedding.length), "#embedding", null, null, false); + } + + /** + * Creates a KNN query by sparse vector. The {@code key} field defaults to {@code null} and + * should be set via {@link #key(String)} to identify the target sparse field. If omitted, + * the key will not be included in the wire format. + * + * @param sparseVector the sparse query vector; must not be null + * @return a new {@code Knn} instance + * @throws IllegalArgumentException if {@code sparseVector} is null + */ + public static Knn querySparseVector(SparseVector sparseVector) { + if (sparseVector == null) { + throw new IllegalArgumentException("sparseVector must not be null"); + } + return new Knn(sparseVector, null, null, null, false); + } + + /** + * Returns a copy of this instance with the given query key (e.g., {@code "#embedding"} or a + * named sparse field). 
+ * + * @param key the target field key; must not be null + * @return new {@code Knn} with the key set + */ + public Knn key(String key) { + if (key == null) { + throw new IllegalArgumentException("key must not be null"); + } + return new Knn(this.query, key, this.limit, this.defaultScore, this.returnRank); + } + + /** + * Returns a copy of this instance with the given per-rank result limit. + * + * @param limit maximum number of results to return for this rank + * @return new {@code Knn} with limit set + */ + public Knn limit(int limit) { + if (limit <= 0) { + throw new IllegalArgumentException("limit must be > 0"); + } + return new Knn(this.query, this.key, limit, this.defaultScore, this.returnRank); + } + + /** + * Returns a copy of this instance with the given default score for missing results. + * + * @param score default score value + * @return new {@code Knn} with defaultScore set + */ + public Knn defaultScore(double score) { + return new Knn(this.query, this.key, this.limit, score, this.returnRank); + } + + /** + * Returns a copy of this instance with the return_rank flag set. + * + *

This is required for sub-rankings used inside {@link Rrf} and is automatically set by + * {@link Rrf.Builder#rank(Knn, double)}. Only set this manually when constructing standalone + * KNN queries that need rank in the result.

+ * + * @param returnRank whether to include rank position in results + * @return new {@code Knn} with returnRank set + */ + public Knn returnRank(boolean returnRank) { + return new Knn(this.query, this.key, this.limit, this.defaultScore, returnRank); + } + + /** + * Returns a copy of this instance with {@code returnRank=true}. + * + *

Package-private; used by {@link Rrf.Builder} to auto-configure sub-rankings.

+ */ + Knn withReturnRank() { + return new Knn(this.query, this.key, this.limit, this.defaultScore, true); + } + + /** + * Returns the query object (String, float[], or {@link SparseVector}). + * When the query is a {@code float[]}, a defensive copy is returned. + */ + public Object getQuery() { + if (query instanceof float[]) { + return Arrays.copyOf((float[]) query, ((float[]) query).length); + } + return query; + } + + /** + * Returns the target field key, or {@code null} if not set (sparse vector case). + */ + public String getKey() { + return key; + } + + /** + * Returns the per-rank result limit, or {@code null} if not set. + */ + public Integer getLimit() { + return limit; + } + + /** + * Returns the default score for missing results, or {@code null} if not set. + */ + public Double getDefaultScore() { + return defaultScore; + } + + /** + * Returns whether the rank position should be included in results. + */ + public boolean isReturnRank() { + return returnRank; + } +} diff --git a/src/main/java/tech/amikos/chromadb/v2/QueryResultImpl.java b/src/main/java/tech/amikos/chromadb/v2/QueryResultImpl.java index 819d796..6a94bfd 100644 --- a/src/main/java/tech/amikos/chromadb/v2/QueryResultImpl.java +++ b/src/main/java/tech/amikos/chromadb/v2/QueryResultImpl.java @@ -1,9 +1,6 @@ package tech.amikos.chromadb.v2; import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicReferenceArray; @@ -25,12 +22,12 @@ private QueryResultImpl(List> ids, List> documents, List>> metadatas, List> embeddings, List> distances, List> uris) { - this.ids = immutableNestedList(ids); - this.documents = immutableNestedList(documents); - this.metadatas = immutableNestedMetadata(metadatas); - this.embeddings = immutableNestedEmbeddings(embeddings); - this.distances = immutableNestedList(distances); - this.uris = immutableNestedList(uris); + this.ids 
= ImmutableCopyUtils.nestedList(ids); + this.documents = ImmutableCopyUtils.nestedList(documents); + this.metadatas = ImmutableCopyUtils.nestedMetadata(metadatas); + this.embeddings = ImmutableCopyUtils.nestedEmbeddings(embeddings); + this.distances = ImmutableCopyUtils.nestedList(distances); + this.uris = ImmutableCopyUtils.nestedList(uris); this.cachedRows = new AtomicReferenceArray>(this.ids.size()); } @@ -121,56 +118,4 @@ public Stream> stream() { return IntStream.range(0, ids.size()).mapToObj(this::rows); } - private static List> immutableNestedList(List> source) { - if (source == null) { - return null; - } - List> outer = new ArrayList>(source.size()); - for (List inner : source) { - if (inner == null) { - outer.add(null); - } else { - outer.add(Collections.unmodifiableList(new ArrayList(inner))); - } - } - return Collections.unmodifiableList(outer); - } - - private static List>> immutableNestedMetadata(List>> source) { - if (source == null) { - return null; - } - List>> outer = new ArrayList>>(source.size()); - for (List> inner : source) { - if (inner == null) { - outer.add(null); - continue; - } - List> innerCopy = new ArrayList>(inner.size()); - for (Map metadata : inner) { - innerCopy.add(metadata == null ? null : Collections.unmodifiableMap(new LinkedHashMap(metadata))); - } - outer.add(Collections.unmodifiableList(innerCopy)); - } - return Collections.unmodifiableList(outer); - } - - private static List> immutableNestedEmbeddings(List> source) { - if (source == null) { - return null; - } - List> outer = new ArrayList>(source.size()); - for (List inner : source) { - if (inner == null) { - outer.add(null); - continue; - } - List innerCopy = new ArrayList(inner.size()); - for (float[] embedding : inner) { - innerCopy.add(embedding == null ? 
null : Arrays.copyOf(embedding, embedding.length)); - } - outer.add(Collections.unmodifiableList(innerCopy)); - } - return Collections.unmodifiableList(outer); - } } diff --git a/src/main/java/tech/amikos/chromadb/v2/ReadLevel.java b/src/main/java/tech/amikos/chromadb/v2/ReadLevel.java new file mode 100644 index 0000000..376373c --- /dev/null +++ b/src/main/java/tech/amikos/chromadb/v2/ReadLevel.java @@ -0,0 +1,51 @@ +package tech.amikos.chromadb.v2; + +import java.util.Locale; + +/** + * Controls which data sources the search engine reads from when processing a search request. + * + *

Use {@link #INDEX_AND_WAL} for the most up-to-date results (includes recently written + * records in the WAL). Use {@link #INDEX_ONLY} for faster but potentially stale results.

+ */ +public enum ReadLevel { + + /** Read from both the persisted index and the write-ahead log (most up-to-date). */ + INDEX_AND_WAL("index_and_wal"), + + /** Read from the persisted index only (faster, potentially stale). */ + INDEX_ONLY("index_only"); + + private final String value; + + ReadLevel(String value) { + this.value = value; + } + + /** + * Returns the wire format value for this read level. + */ + public String getValue() { + return value; + } + + /** + * Returns the {@code ReadLevel} matching the given wire format string. + * + * @param value the string value to look up; must not be null + * @return matching {@code ReadLevel} + * @throws IllegalArgumentException if {@code value} is null or does not match any constant + */ + public static ReadLevel fromValue(String value) { + if (value == null) { + throw new IllegalArgumentException("value must not be null"); + } + String normalized = value.trim().toLowerCase(Locale.ROOT); + for (ReadLevel level : values()) { + if (level.value.equals(normalized)) { + return level; + } + } + throw new IllegalArgumentException("Unknown read level value: " + value); + } +} diff --git a/src/main/java/tech/amikos/chromadb/v2/ResultRow.java b/src/main/java/tech/amikos/chromadb/v2/ResultRow.java index d3812ff..eea50e7 100644 --- a/src/main/java/tech/amikos/chromadb/v2/ResultRow.java +++ b/src/main/java/tech/amikos/chromadb/v2/ResultRow.java @@ -3,10 +3,11 @@ import java.util.Map; /** - * Represents a single result row from a get or query operation. + * Represents a single result row from a get, query, or search operation. * - *

Fields are {@code null} when the corresponding {@link Include} value was not specified in the - * request. No {@code Optional} wrappers are used — callers should check for {@code null} directly. + *

Fields are {@code null} when the corresponding projection was not requested (e.g., + * {@link Include} for get/query, {@link Select} for search). No {@code Optional} wrappers + * are used — callers should check for {@code null} directly. */ public interface ResultRow { @@ -16,24 +17,24 @@ public interface ResultRow { String getId(); /** - * Returns the document text, or {@code null} if {@link Include#DOCUMENTS} was not included. + * Returns the document text, or {@code null} if document projection was not requested. */ String getDocument(); /** - * Returns an unmodifiable metadata map, or {@code null} if {@link Include#METADATAS} was not - * included. + * Returns an unmodifiable metadata map, or {@code null} if metadata projection was not + * requested. */ Map getMetadata(); /** - * Returns a defensive copy of the embedding array, or {@code null} if - * {@link Include#EMBEDDINGS} was not included. + * Returns a defensive copy of the embedding array, or {@code null} if embedding projection + * was not requested. */ float[] getEmbedding(); /** - * Returns the URI, or {@code null} if {@link Include#URIS} was not included. + * Returns the URI, or {@code null} if URI projection was not requested. */ String getUri(); } diff --git a/src/main/java/tech/amikos/chromadb/v2/Rrf.java b/src/main/java/tech/amikos/chromadb/v2/Rrf.java new file mode 100644 index 0000000..d8cec09 --- /dev/null +++ b/src/main/java/tech/amikos/chromadb/v2/Rrf.java @@ -0,0 +1,153 @@ +package tech.amikos.chromadb.v2; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * Reciprocal Rank Fusion (RRF) ranking expression that combines multiple KNN rankings. + * + *

RRF fuses multiple ranked lists into a single ranking by assigning reciprocal rank scores. + * At least one sub-ranking is required. All {@link Knn} sub-rankings have {@code returnRank} + * automatically enabled when added via {@link Builder#rank(Knn, double)}.

+ * + *
{@code
+ * Rrf rrf = Rrf.builder()
+ *     .rank(Knn.queryText("query"), 1.0)
+ *     .rank(Knn.querySparseVector(sv).key("sparse"), 0.5)
+ *     .k(60)
+ *     .build();
+ * }
+ */ +public final class Rrf { + + private final List ranks; + private final int k; + private final boolean normalize; + + private Rrf(List ranks, int k, boolean normalize) { + this.ranks = Collections.unmodifiableList(new ArrayList(ranks)); + this.k = k; + this.normalize = normalize; + } + + /** + * Returns a new {@link Builder} for constructing an {@code Rrf} instance. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Returns the unmodifiable list of ranked inputs with their weights. + */ + public List getRanks() { + return ranks; + } + + /** + * Returns the RRF k constant (default 60). + */ + public int getK() { + return k; + } + + /** + * Returns whether scores should be normalized. + */ + public boolean isNormalize() { + return normalize; + } + + /** + * A KNN sub-ranking paired with a fusion weight. + */ + public static final class RankWithWeight { + + private final Knn knn; + private final double weight; + + private RankWithWeight(Knn knn, double weight) { + this.knn = knn; + this.weight = weight; + } + + /** + * Returns the KNN sub-ranking. + */ + public Knn getKnn() { + return knn; + } + + /** + * Returns the fusion weight for this ranking. + */ + public double getWeight() { + return weight; + } + } + + /** + * Builder for {@link Rrf}. + */ + public static final class Builder { + + private final List ranks = new ArrayList(); + private int k = 60; + private boolean normalize = false; + + private Builder() {} + + /** + * Adds a KNN sub-ranking with the given weight. The {@code returnRank} flag is + * automatically set to {@code true} on the provided {@link Knn} instance. 
+ * + * @param knn the KNN sub-ranking; must not be null + * @param weight fusion weight for this sub-ranking + * @return this builder + * @throws IllegalArgumentException if {@code knn} is null + */ + public Builder rank(Knn knn, double weight) { + if (knn == null) { + throw new IllegalArgumentException("knn must not be null"); + } + ranks.add(new RankWithWeight(knn.withReturnRank(), weight)); + return this; + } + + /** + * Sets the RRF k constant. Default is 60. + * + * @param k the RRF k constant + * @return this builder + */ + public Builder k(int k) { + this.k = k; + return this; + } + + /** + * Sets whether scores should be normalized. Default is {@code false}. + * + * @param normalize whether to normalize scores + * @return this builder + */ + public Builder normalize(boolean normalize) { + this.normalize = normalize; + return this; + } + + /** + * Builds the {@link Rrf} instance. + * + * @return an immutable {@code Rrf} + * @throws IllegalArgumentException if no ranks have been added + */ + public Rrf build() { + if (ranks.isEmpty()) { + throw new IllegalArgumentException("at least one rank must be added"); + } + return new Rrf(ranks, k, normalize); + } + } +} diff --git a/src/main/java/tech/amikos/chromadb/v2/Search.java b/src/main/java/tech/amikos/chromadb/v2/Search.java new file mode 100644 index 0000000..f41233b --- /dev/null +++ b/src/main/java/tech/amikos/chromadb/v2/Search.java @@ -0,0 +1,247 @@ +package tech.amikos.chromadb.v2; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +/** + * Per-search configuration composing a ranking expression (KNN or RRF) with optional filter, + * field projection, grouping, limit, and offset. + * + *

Exactly one of {@link Builder#knn(Knn)} or {@link Builder#rrf(Rrf)} must be set.

+ * + *
{@code
+ * Search search = Search.builder()
+ *     .knn(Knn.queryText("chromadb java client"))
+ *     .where(Where.eq("category", "tech"))
+ *     .selectAll()
+ *     .limit(10)
+ *     .build();
+ * }
+ */ +public final class Search { + + private final Knn knn; + private final Rrf rrf; + private final Where filter; + private final List getSelect() { + return select; + } + + /** + * Returns the groupBy configuration, or {@code null} if not set. + */ + public GroupBy getGroupBy() { + return groupBy; + } + + /** + * Returns the per-search result limit, or {@code null} if not set. + */ + public Integer getLimit() { + return limit; + } + + /** + * Returns the per-search result offset, or {@code null} if not set. + */ + public Integer getOffset() { + return offset; + } + + /** + * Returns a new {@link Builder} pre-populated with this instance's fields. + * Useful for creating modified copies without manually copying every field. + */ + public Builder toBuilder() { + Builder b = new Builder(); + b.knn = this.knn; + b.rrf = this.rrf; + b.filter = this.filter; + b.select = this.select == null ? null : this.select.toArray(new Select[0]); + b.groupBy = this.groupBy; + b.limit = this.limit; + b.offset = this.offset; + return b; + } + + /** + * Builder for {@link Search}. + */ + public static final class Builder { + + private Knn knn; + private Rrf rrf; + private Where filter; + private Select[] select; + private GroupBy groupBy; + private Integer limit; + private Integer offset; + + private Builder() {} + + /** + * Sets the KNN ranking expression. Mutually exclusive with {@link #rrf(Rrf)}. + * + * @param knn KNN ranking expression; must not be null + * @return this builder + */ + public Builder knn(Knn knn) { + Objects.requireNonNull(knn, "knn must not be null"); + this.knn = knn; + return this; + } + + /** + * Sets the RRF ranking expression. Mutually exclusive with {@link #knn(Knn)}. + * + * @param rrf RRF ranking expression; must not be null + * @return this builder + */ + public Builder rrf(Rrf rrf) { + Objects.requireNonNull(rrf, "rrf must not be null"); + this.rrf = rrf; + return this; + } + + /** + * Sets a per-search metadata/ID filter. 
+ * + * @param filter the where filter; must not be null + * @return this builder + */ + public Builder where(Where filter) { + Objects.requireNonNull(filter, "filter must not be null"); + this.filter = filter; + return this; + } + + /** + * Sets per-search field projection. + * + * @param fields one or more field selectors to project + * @return this builder + */ + public Builder select(Select... fields) { + Objects.requireNonNull(fields, "fields must not be null"); + for (int i = 0; i < fields.length; i++) { + if (fields[i] == null) { + throw new IllegalArgumentException("fields[" + i + "] must not be null"); + } + } + this.select = Arrays.copyOf(fields, fields.length); + return this; + } + + /** + * Convenience method that projects all standard fields: ID, DOCUMENT, EMBEDDING, + * METADATA, SCORE. + * + * @return this builder + */ + public Builder selectAll() { + return select(Select.all()); + } + + /** + * Sets per-search result grouping. + * + * @param groupBy the group-by configuration; must not be null + * @return this builder + */ + public Builder groupBy(GroupBy groupBy) { + Objects.requireNonNull(groupBy, "groupBy must not be null"); + this.groupBy = groupBy; + return this; + } + + /** + * Sets the per-search result limit. + * + * @param limit maximum number of results to return + * @return this builder + */ + public Builder limit(int limit) { + this.limit = limit; + return this; + } + + /** + * Sets the per-search result offset. + * + * @param offset number of results to skip + * @return this builder + */ + public Builder offset(int offset) { + this.offset = offset; + return this; + } + + /** + * Builds the {@link Search} instance. 
+ * + * @return an immutable {@code Search} + * @throws IllegalArgumentException if neither or both of knn and rrf are set + */ + public Search build() { + if (knn == null && rrf == null) { + throw new IllegalArgumentException( + "exactly one of knn or rrf must be set, but neither was provided"); + } + if (knn != null && rrf != null) { + throw new IllegalArgumentException( + "exactly one of knn or rrf must be set, but both were provided"); + } + return new Search(this); + } + } +} diff --git a/src/main/java/tech/amikos/chromadb/v2/SearchResult.java b/src/main/java/tech/amikos/chromadb/v2/SearchResult.java new file mode 100644 index 0000000..3e02efa --- /dev/null +++ b/src/main/java/tech/amikos/chromadb/v2/SearchResult.java @@ -0,0 +1,72 @@ +package tech.amikos.chromadb.v2; + +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +/** + * Result from a search operation. + * + *

Supports both column-oriented access (lists of lists, indexed by search input) and + * row-oriented access (result groups and rows).

+ * + *

Each outer list is indexed per search input. Fields omitted from the {@link Select} + * projection may be {@code null}.

+ */ +public interface SearchResult { + + // --- Column-oriented accessors --- + + /** + * Always present. Each inner list corresponds to one search input. + */ + List> getIds(); + + /** + * Present when {@link Select#DOCUMENT} is projected; otherwise may be {@code null}. + */ + List> getDocuments(); + + /** + * Present when {@link Select#METADATA} is projected; otherwise may be {@code null}. + */ + List>> getMetadatas(); + + /** + * Present when {@link Select#EMBEDDING} is projected; otherwise may be {@code null}. + */ + List> getEmbeddings(); + + /** + * Present when {@link Select#SCORE} is projected; otherwise may be {@code null}. + * + *

Scores are {@link Double} (not {@code Float}) to match the wire format precision. + * Higher values indicate greater relevance.

+ */ + List> getScores(); + + // --- Row-oriented accessors --- + + /** + * Returns the results for the specified search input as a flat row-oriented group. + * + * @param searchIndex zero-based index of the search input + * @return group of rows for that search input + * @throws IndexOutOfBoundsException if searchIndex is out of range + */ + ResultGroup rows(int searchIndex); + + /** + * Returns the number of search inputs (outer list size of ids). + * + *

This is the count of search inputs submitted, not the number of groups within + * a GroupBy result. Each search input produces one entry in the outer lists returned + * by column accessors like {@link #getIds()}.

+ */ + int searchCount(); + + /** + * Returns a stream over all search groups, enabling flatMap patterns. + */ + Stream> stream(); +} diff --git a/src/main/java/tech/amikos/chromadb/v2/SearchResultImpl.java b/src/main/java/tech/amikos/chromadb/v2/SearchResultImpl.java new file mode 100644 index 0000000..3480720 --- /dev/null +++ b/src/main/java/tech/amikos/chromadb/v2/SearchResultImpl.java @@ -0,0 +1,138 @@ +package tech.amikos.chromadb.v2; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReferenceArray; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +/** + * Package-private immutable implementation of {@link SearchResult}. + * + *

Supports both column-oriented and row-oriented access patterns. Rows are lazily constructed + * and cached per search index using an {@link AtomicReferenceArray}.

+ */ +final class SearchResultImpl implements SearchResult { + + private final List> ids; + private final List> documents; + private final List>> metadatas; + private final List> embeddings; + private final List> scores; + + private final AtomicReferenceArray> cachedRows; + + private SearchResultImpl(List> ids, List> documents, + List>> metadatas, + List> embeddings, List> scores) { + this.ids = ImmutableCopyUtils.nestedList(ids); + this.documents = ImmutableCopyUtils.nestedList(documents); + this.metadatas = ImmutableCopyUtils.nestedMetadata(metadatas); + this.embeddings = ImmutableCopyUtils.nestedEmbeddings(embeddings); + this.scores = ImmutableCopyUtils.nestedList(scores); + this.cachedRows = new AtomicReferenceArray>(this.ids.size()); + } + + static SearchResultImpl from(ChromaDtos.SearchResponse dto) { + if (dto == null) { + throw new ChromaDeserializationException( + "Server returned an empty search response payload", + 200 + ); + } + if (dto.ids == null) { + throw new ChromaDeserializationException( + "Server returned search result without required ids field", + 200 + ); + } + List> embeddings = null; + if (dto.embeddings != null) { + embeddings = new ArrayList>(dto.embeddings.size()); + for (List> inner : dto.embeddings) { + embeddings.add(ChromaDtos.toFloatArrays(inner)); + } + } + return new SearchResultImpl( + dto.ids, + dto.documents, + dto.metadatas, + embeddings, + dto.scores + ); + } + + @Override + public List> getIds() { + return ids; + } + + @Override + public List> getDocuments() { + return documents; + } + + @Override + public List>> getMetadatas() { + return metadatas; + } + + @Override + public List> getEmbeddings() { + return embeddings; + } + + @Override + public List> getScores() { + return scores; + } + + @Override + public ResultGroup rows(int searchIndex) { + checkSearchIndex(searchIndex); + ResultGroup r = cachedRows.get(searchIndex); + if (r == null) { + List colIds = ids.get(searchIndex); + List rowScores = scores == null ? 
null : scores.get(searchIndex); + List docList = documents == null ? null : documents.get(searchIndex); + List> metaList = metadatas == null ? null : metadatas.get(searchIndex); + List embList = embeddings == null ? null : embeddings.get(searchIndex); + + List result = new ArrayList(colIds.size()); + for (int i = 0; i < colIds.size(); i++) { + Double score = (rowScores != null && rowScores.get(i) != null) + ? rowScores.get(i) : null; + result.add(new SearchResultRowImpl( + colIds.get(i), + docList == null ? null : docList.get(i), + metaList == null ? null : metaList.get(i), + embList == null ? null : embList.get(i), + null, + score + )); + } + r = new ResultGroupImpl(result); + cachedRows.compareAndSet(searchIndex, null, r); + r = cachedRows.get(searchIndex); + } + return r; + } + + @Override + public int searchCount() { + return ids.size(); + } + + @Override + public Stream> stream() { + return IntStream.range(0, ids.size()).mapToObj(this::rows); + } + + private void checkSearchIndex(int searchIndex) { + if (searchIndex < 0 || searchIndex >= ids.size()) { + throw new IndexOutOfBoundsException( + "searchIndex " + searchIndex + " out of range [0, " + ids.size() + ")"); + } + } +} diff --git a/src/main/java/tech/amikos/chromadb/v2/SearchResultRow.java b/src/main/java/tech/amikos/chromadb/v2/SearchResultRow.java new file mode 100644 index 0000000..e5d7e10 --- /dev/null +++ b/src/main/java/tech/amikos/chromadb/v2/SearchResultRow.java @@ -0,0 +1,19 @@ +package tech.amikos.chromadb.v2; + +/** + * A result row from a search operation, extending {@link ResultRow} with a score field. + * + *

{@link #getScore()} returns {@code null} when {@link Select#SCORE} was not included in the + * search request's field projection.

+ */ +public interface SearchResultRow extends ResultRow { + + /** + * Returns the relevance score from the ranking expression, or {@code null} if + * {@link Select#SCORE} was not included in the projection. + * + *

Higher scores indicate greater relevance. Returns {@link Double} to preserve + * the full wire-format precision, consistent with {@link SearchResult#getScores()}.

+ */ + Double getScore(); +} diff --git a/src/main/java/tech/amikos/chromadb/v2/SearchResultRowImpl.java b/src/main/java/tech/amikos/chromadb/v2/SearchResultRowImpl.java new file mode 100644 index 0000000..358c635 --- /dev/null +++ b/src/main/java/tech/amikos/chromadb/v2/SearchResultRowImpl.java @@ -0,0 +1,76 @@ +package tech.amikos.chromadb.v2; + +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; + +/** + * Package-private immutable implementation of {@link SearchResultRow}. + * + *

Delegates base {@link ResultRow} behaviour to a composed {@link ResultRowImpl}. + * {@link #getScore()} returns {@code null} when {@link Select#SCORE} was not projected. + */ +final class SearchResultRowImpl implements SearchResultRow { + + private final ResultRowImpl base; + private final Double score; + + SearchResultRowImpl(String id, String document, Map metadata, + float[] embedding, String uri, Double score) { + this.base = new ResultRowImpl(id, document, metadata, embedding, uri); + this.score = score; + } + + @Override + public String getId() { + return base.getId(); + } + + @Override + public String getDocument() { + return base.getDocument(); + } + + @Override + public Map getMetadata() { + return base.getMetadata(); + } + + @Override + public float[] getEmbedding() { + return base.getEmbedding(); + } + + @Override + public String getUri() { + return base.getUri(); + } + + @Override + public Double getScore() { + return score; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (!(obj instanceof SearchResultRowImpl)) return false; + SearchResultRowImpl other = (SearchResultRowImpl) obj; + return base.equals(other.base) && Objects.equals(score, other.score); + } + + @Override + public int hashCode() { + return 31 * base.hashCode() + Objects.hashCode(score); + } + + @Override + public String toString() { + return "SearchResultRow{id=" + getId() + + ", document=" + getDocument() + + ", metadata=" + getMetadata() + + ", embedding=" + Arrays.toString(getEmbedding()) + + ", uri=" + getUri() + + ", score=" + score + "}"; + } +} diff --git a/src/main/java/tech/amikos/chromadb/v2/Select.java b/src/main/java/tech/amikos/chromadb/v2/Select.java new file mode 100644 index 0000000..4613f96 --- /dev/null +++ b/src/main/java/tech/amikos/chromadb/v2/Select.java @@ -0,0 +1,85 @@ +package tech.amikos.chromadb.v2; + +import java.util.Objects; + +/** + * Field projection descriptor for Search API results. + * + *

Predefined constants cover the standard Chroma search output fields. Custom metadata key + * projections are created via {@link #key(String)}.

+ */ +public final class Select { + + /** Projects the document text field. */ + public static final Select DOCUMENT = new Select("#document"); + + /** Projects the relevance score field. */ + public static final Select SCORE = new Select("#score"); + + /** Projects the embedding vector field. */ + public static final Select EMBEDDING = new Select("#embedding"); + + /** Projects the metadata map field. */ + public static final Select METADATA = new Select("#metadata"); + + /** Projects the record ID field. */ + public static final Select ID = new Select("#id"); + + private final String key; + + private Select(String key) { + this.key = key; + } + + /** + * Creates a {@code Select} for a custom metadata field name. + * + *

The field name is used verbatim (no {@code #} prefix is added). This is suitable for + * projecting specific metadata keys from the search result.

+ * + * @param fieldName non-null, non-blank metadata key name + * @return a {@code Select} for the given field name + * @throws IllegalArgumentException if {@code fieldName} is null or blank + */ + public static Select key(String fieldName) { + if (fieldName == null) { + throw new IllegalArgumentException("fieldName must not be null"); + } + if (fieldName.trim().isEmpty()) { + throw new IllegalArgumentException("fieldName must not be blank"); + } + return new Select(fieldName); + } + + /** + * Returns all standard field projections: ID, DOCUMENT, EMBEDDING, METADATA, SCORE. + */ + public static Select[] all() { + return new Select[]{ID, DOCUMENT, EMBEDDING, METADATA, SCORE}; + } + + /** + * Returns the field key string. + */ + public String getKey() { + return key; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof Select)) return false; + Select that = (Select) o; + return Objects.equals(key, that.key); + } + + @Override + public int hashCode() { + return Objects.hashCode(key); + } + + @Override + public String toString() { + return "Select(" + key + ')'; + } +} diff --git a/src/main/java/tech/amikos/chromadb/v2/SparseVector.java b/src/main/java/tech/amikos/chromadb/v2/SparseVector.java new file mode 100644 index 0000000..2a26c60 --- /dev/null +++ b/src/main/java/tech/amikos/chromadb/v2/SparseVector.java @@ -0,0 +1,78 @@ +package tech.amikos.chromadb.v2; + +import java.util.Arrays; + +/** + * Immutable sparse vector value type holding integer indices and float values. + * + *

Indices and values arrays must be the same length. Defensive copies are applied on + * construction and on every getter call to prevent aliasing.

+ */ +public final class SparseVector { + + private final int[] indices; + private final float[] values; + + private SparseVector(int[] indices, float[] values) { + this.indices = Arrays.copyOf(indices, indices.length); + this.values = Arrays.copyOf(values, values.length); + } + + /** + * Creates a new {@code SparseVector} from the given indices and values arrays. + * + * @param indices non-null array of index positions; must have the same length as {@code values} + * @param values non-null array of float values; must have the same length as {@code indices} + * @return an immutable sparse vector + * @throws IllegalArgumentException if either array is null or if they have different lengths + */ + public static SparseVector of(int[] indices, float[] values) { + if (indices == null) { + throw new IllegalArgumentException("indices must not be null"); + } + if (values == null) { + throw new IllegalArgumentException("values must not be null"); + } + if (indices.length != values.length) { + throw new IllegalArgumentException( + "indices and values must have the same length, got " + + indices.length + " and " + values.length); + } + return new SparseVector(indices, values); + } + + /** + * Returns a defensive copy of the index positions. + */ + public int[] getIndices() { + return Arrays.copyOf(indices, indices.length); + } + + /** + * Returns a defensive copy of the float values. 
+ */ + public float[] getValues() { + return Arrays.copyOf(values, values.length); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof SparseVector)) return false; + SparseVector that = (SparseVector) o; + return Arrays.equals(indices, that.indices) && Arrays.equals(values, that.values); + } + + @Override + public int hashCode() { + int result = Arrays.hashCode(indices); + result = 31 * result + Arrays.hashCode(values); + return result; + } + + @Override + public String toString() { + return "SparseVector{indices=" + Arrays.toString(indices) + + ", values=" + Arrays.toString(values) + '}'; + } +} diff --git a/src/test/java/tech/amikos/chromadb/v2/MetadataValidationTest.java b/src/test/java/tech/amikos/chromadb/v2/MetadataValidationTest.java new file mode 100644 index 0000000..c8d866d --- /dev/null +++ b/src/test/java/tech/amikos/chromadb/v2/MetadataValidationTest.java @@ -0,0 +1,307 @@ +package tech.amikos.chromadb.v2; + +import org.junit.Test; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * Unit tests for {@link ChromaHttpCollection#validateMetadataArrayTypes} covering: + * - Homogeneous arrays (all types) — pass + * - Mixed-type arrays — rejected with {@link ChromaBadRequestException} + * - Null elements in arrays — rejected + * - Scalar metadata values — ignored + * - Edge cases: null list, null entry, empty array + * + * Also includes behavioral wiring tests that verify the validation is invoked + * via the {@code add()}, {@code upsert()}, and {@code update()} execute() methods. 
+ */ +public class MetadataValidationTest { + + // ============================================================================= + // Static validation unit tests + // ============================================================================= + + @Test + public void testHomogeneousStringArrayPasses() { + List> metadatas = Collections.singletonList( + singleMetadata("tags", Arrays.asList("a", "b", "c")) + ); + ChromaHttpCollection.validateMetadataArrayTypes(metadatas); + // no exception = pass + } + + @Test + public void testHomogeneousIntArrayPasses() { + List> metadatas = Collections.singletonList( + singleMetadata("counts", Arrays.asList(1, 2, 3)) + ); + ChromaHttpCollection.validateMetadataArrayTypes(metadatas); + } + + @Test + public void testHomogeneousFloatArrayPasses() { + List> metadatas = Collections.singletonList( + singleMetadata("scores", Arrays.asList(1.5f, 2.5f, 3.5f)) + ); + ChromaHttpCollection.validateMetadataArrayTypes(metadatas); + } + + @Test + public void testHomogeneousBoolArrayPasses() { + List> metadatas = Collections.singletonList( + singleMetadata("flags", Arrays.asList(true, false, true)) + ); + ChromaHttpCollection.validateMetadataArrayTypes(metadatas); + } + + @Test + public void testEmptyArrayPasses() { + List> metadatas = Collections.singletonList( + singleMetadata("tags", Collections.emptyList()) + ); + ChromaHttpCollection.validateMetadataArrayTypes(metadatas); + } + + @Test + public void testNullMetadatasListPasses() { + ChromaHttpCollection.validateMetadataArrayTypes(null); + } + + @Test + public void testNullMetadataEntryPasses() { + List> metadatas = new ArrayList>(); + metadatas.add(null); + ChromaHttpCollection.validateMetadataArrayTypes(metadatas); + } + + @Test(expected = ChromaBadRequestException.class) + public void testMixedStringAndIntArrayRejected() { + List mixed = new ArrayList(); + mixed.add("foo"); + mixed.add(Integer.valueOf(42)); + List> metadatas = Collections.singletonList( + singleMetadata("mixed", mixed) + 
); + ChromaHttpCollection.validateMetadataArrayTypes(metadatas); + } + + @Test(expected = ChromaBadRequestException.class) + public void testMixedStringAndBoolArrayRejected() { + List mixed = new ArrayList(); + mixed.add("foo"); + mixed.add(Boolean.TRUE); + List> metadatas = Collections.singletonList( + singleMetadata("mixed", mixed) + ); + ChromaHttpCollection.validateMetadataArrayTypes(metadatas); + } + + @Test(expected = ChromaBadRequestException.class) + public void testMixedIntAndBoolArrayRejected() { + List mixed = new ArrayList(); + mixed.add(Integer.valueOf(42)); + mixed.add(Boolean.TRUE); + List> metadatas = Collections.singletonList( + singleMetadata("mixed", mixed) + ); + ChromaHttpCollection.validateMetadataArrayTypes(metadatas); + } + + @Test(expected = ChromaBadRequestException.class) + public void testNullElementInArrayRejected() { + List withNull = new ArrayList(); + withNull.add("valid"); + withNull.add(null); + List> metadatas = Collections.singletonList( + singleMetadata("tags", withNull) + ); + ChromaHttpCollection.validateMetadataArrayTypes(metadatas); + } + + @Test + public void testMixedIntegerAndLongPassesAsCompatible() { + List intAndLong = new ArrayList(); + intAndLong.add(Integer.valueOf(1)); + intAndLong.add(Long.valueOf(2L)); + List> metadatas = Collections.singletonList( + singleMetadata("ids", intAndLong) + ); + ChromaHttpCollection.validateMetadataArrayTypes(metadatas); + // Integer and Long are both "integer group" - should pass + } + + @Test + public void testMixedFloatAndDoublePassesAsCompatible() { + List floatAndDouble = new ArrayList(); + floatAndDouble.add(Float.valueOf(1.0f)); + floatAndDouble.add(Double.valueOf(2.0)); + List> metadatas = Collections.singletonList( + singleMetadata("scores", floatAndDouble) + ); + ChromaHttpCollection.validateMetadataArrayTypes(metadatas); + // Float and Double are both "float group" - should pass + } + + @Test + public void testScalarMetadataValuesIgnored() { + Map meta = new 
LinkedHashMap(); + meta.put("name", "test"); + meta.put("count", Integer.valueOf(5)); + meta.put("active", Boolean.TRUE); + List> metadatas = Collections.singletonList(meta); + ChromaHttpCollection.validateMetadataArrayTypes(metadatas); + // scalar values should not trigger validation + } + + @Test + public void testMixedTypeErrorMessageContainsDetails() { + List mixed = new ArrayList(); + mixed.add("foo"); + mixed.add(Integer.valueOf(42)); + mixed.add(Boolean.TRUE); + List> metadatas = Collections.singletonList( + singleMetadata("bad_field", mixed) + ); + try { + ChromaHttpCollection.validateMetadataArrayTypes(metadatas); + fail("Expected ChromaBadRequestException"); + } catch (ChromaBadRequestException e) { + assertTrue("Message should mention field name", e.getMessage().contains("bad_field")); + assertTrue("Message should mention 'mixed types'", e.getMessage().contains("mixed types")); + } + } + + // ============================================================================= + // Behavioral wiring tests + // Verify that col.add/upsert/update().execute() calls validateMetadataArrayTypes + // BEFORE any HTTP call. These tests use a stub Collection created via + // ChromaHttpCollection.from() pointing to a dead endpoint (localhost:1). + // If validation fires, ChromaBadRequestException is thrown before any network call. + // If ChromaConnectionException is thrown instead, the wiring is broken. 
+ // ============================================================================= + + @Test + public void testAddExecuteRejectsMixedTypeArrayBeforeHttpCall() { + Collection col = createStubCollection(); + List mixed = new ArrayList(); + mixed.add("foo"); + mixed.add(Integer.valueOf(42)); + Map meta = new LinkedHashMap(); + meta.put("mixed_field", mixed); + + try { + col.add() + .ids("test-1") + .documents("test document") + .metadatas(Collections.>singletonList(meta)) + .execute(); + fail("Expected ChromaBadRequestException for mixed-type array in add()"); + } catch (ChromaBadRequestException e) { + // Correct — validation fired before HTTP call + assertTrue("Exception message should mention mixed types", e.getMessage().contains("mixed types")); + } catch (ChromaException e) { + // ChromaConnectionException or other — wiring is broken (validation did not fire first) + fail("Expected ChromaBadRequestException but got " + e.getClass().getSimpleName() + + ": " + e.getMessage() + + " — this means validateMetadataArrayTypes was NOT called before the HTTP call in add().execute()"); + } + } + + @Test + public void testUpsertExecuteRejectsMixedTypeArray() { + Collection col = createStubCollection(); + List mixed = new ArrayList(); + mixed.add("foo"); + mixed.add(Integer.valueOf(42)); + Map meta = new LinkedHashMap(); + meta.put("mixed_field", mixed); + + try { + col.upsert() + .ids("test-1") + .documents("test document") + .metadatas(Collections.>singletonList(meta)) + .execute(); + fail("Expected ChromaBadRequestException for mixed-type array in upsert()"); + } catch (ChromaBadRequestException e) { + // Correct — validation fired before HTTP call + assertTrue("Exception message should mention mixed types", e.getMessage().contains("mixed types")); + } catch (ChromaException e) { + fail("Expected ChromaBadRequestException but got " + e.getClass().getSimpleName() + + ": " + e.getMessage() + + " — this means validateMetadataArrayTypes was NOT called before the HTTP call in 
upsert().execute()"); + } + } + + @Test + public void testUpdateExecuteRejectsMixedTypeArray() { + Collection col = createStubCollection(); + List mixed = new ArrayList(); + mixed.add("foo"); + mixed.add(Integer.valueOf(42)); + Map meta = new LinkedHashMap(); + meta.put("mixed_field", mixed); + + try { + col.update() + .ids("test-1") + .metadatas(Collections.>singletonList(meta)) + .execute(); + fail("Expected ChromaBadRequestException for mixed-type array in update()"); + } catch (ChromaBadRequestException e) { + // Correct — validation fired before HTTP call + assertTrue("Exception message should mention mixed types", e.getMessage().contains("mixed types")); + } catch (ChromaException e) { + fail("Expected ChromaBadRequestException but got " + e.getClass().getSimpleName() + + ": " + e.getMessage() + + " — this means validateMetadataArrayTypes was NOT called before the HTTP call in update().execute()"); + } + } + + // ============================================================================= + // Helpers + // ============================================================================= + + private static Map singleMetadata(String key, Object value) { + Map meta = new LinkedHashMap(); + meta.put(key, value); + return meta; + } + + /** + * Creates a stub {@link Collection} backed by a {@link ChromaApiClient} pointing at + * {@code http://localhost:1} (a dead endpoint). Since mixed-type validation fires + * BEFORE any HTTP call is attempted, the stub never actually makes a network request. + * + *

<p>Uses package-private {@code ChromaHttpCollection.from()} and {@code ChromaDtos} + * since the test is in the same package.</p>

+ */ + private static Collection createStubCollection() { + ChromaApiClient stubApiClient = new ChromaApiClient( + "http://localhost:1", + null, + null, + Duration.ofMillis(100), + Duration.ofMillis(100), + Duration.ofMillis(100) + ); + ChromaDtos.CollectionResponse dto = new ChromaDtos.CollectionResponse(); + dto.id = "stub-id-00000000-0000-0000-0000-000000000000"; + dto.name = "stub-collection"; + Tenant tenant = Tenant.of("default_tenant"); + Database database = Database.of("default_database"); + return ChromaHttpCollection.from(dto, stubApiClient, tenant, database, null); + } +} diff --git a/src/test/java/tech/amikos/chromadb/v2/PublicInterfaceCompatibilityTest.java b/src/test/java/tech/amikos/chromadb/v2/PublicInterfaceCompatibilityTest.java index aea6292..36f860f 100644 --- a/src/test/java/tech/amikos/chromadb/v2/PublicInterfaceCompatibilityTest.java +++ b/src/test/java/tech/amikos/chromadb/v2/PublicInterfaceCompatibilityTest.java @@ -25,7 +25,7 @@ private static int declaredMethodCount(Class clazz) { // Expected declared method counts — update these when intentionally adding/removing public methods private static final int EXPECTED_CLIENT_METHOD_COUNT = 26; - private static final int EXPECTED_COLLECTION_METHOD_COUNT = 21; + private static final int EXPECTED_COLLECTION_METHOD_COUNT = 22; private static final int EXPECTED_ADD_BUILDER_METHOD_COUNT = 11; private static final int EXPECTED_QUERY_BUILDER_METHOD_COUNT = 9; private static final int EXPECTED_GET_BUILDER_METHOD_COUNT = 8; @@ -287,6 +287,12 @@ public void testCollectionIndexingStatusMethod() throws Exception { assertEquals(IndexingStatus.class, method.getReturnType()); } + @Test + public void testCollectionSearchMethod() throws Exception { + Method method = Collection.class.getMethod("search"); + assertEquals(Collection.SearchBuilder.class, method.getReturnType()); + } + // === Builder method existence === @Test diff --git a/src/test/java/tech/amikos/chromadb/v2/SearchApiCloudIntegrationTest.java 
b/src/test/java/tech/amikos/chromadb/v2/SearchApiCloudIntegrationTest.java new file mode 100644 index 0000000..2a16131 --- /dev/null +++ b/src/test/java/tech/amikos/chromadb/v2/SearchApiCloudIntegrationTest.java @@ -0,0 +1,872 @@ +package tech.amikos.chromadb.v2; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Assume; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import tech.amikos.chromadb.Utils; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * Cloud integration tests for schema/index parity (CLOUD-02) and array metadata (CLOUD-03). + * + *

<p>Credentials loaded from {@code .env} or environment variables: + * CHROMA_API_KEY, CHROMA_TENANT, CHROMA_DATABASE.</p>

+ * + *

<p>All cloud-dependent tests skip cleanly when CHROMA_API_KEY is absent (per D-02). + * Mixed-type array validation test (D-22) runs regardless of credentials.</p>

+ */ +public class SearchApiCloudIntegrationTest { + + // --- Shared (read-only) seed collection --- + + private static Client sharedClient; + private static Collection seedCollection; + private static String sharedCollectionName; + private static boolean cloudAvailable = false; + + private static String sharedApiKey; + private static String sharedTenant; + private static String sharedDatabase; + + @BeforeClass + public static void setUpSharedSeedCollection() { + Utils.loadEnvFile(".env"); + sharedApiKey = Utils.getEnvOrProperty("CHROMA_API_KEY"); + sharedTenant = Utils.getEnvOrProperty("CHROMA_TENANT"); + sharedDatabase = Utils.getEnvOrProperty("CHROMA_DATABASE"); + + if (!isNonBlank(sharedApiKey) || !isNonBlank(sharedTenant) || !isNonBlank(sharedDatabase)) { + // Credentials absent -- cloud tests will be skipped. cloudAvailable remains false. + return; + } + + sharedClient = ChromaClient.cloud() + .apiKey(sharedApiKey) + .tenant(sharedTenant) + .database(sharedDatabase) + .timeout(Duration.ofSeconds(45)) + .build(); + + sharedCollectionName = "seed_" + UUID.randomUUID().toString().substring(0, 8); + seedCollection = sharedClient.createCollection(sharedCollectionName); + + // Add 15 records modeling a product catalog domain (per D-04, D-06 — server-side embeddings) + List ids = Arrays.asList( + "prod-001", "prod-002", "prod-003", "prod-004", "prod-005", + "prod-006", "prod-007", "prod-008", "prod-009", "prod-010", + "prod-011", "prod-012", "prod-013", "prod-014", "prod-015" + ); + + List documents = Arrays.asList( + "Wireless bluetooth headphones with noise cancellation", + "Organic green tea bags premium quality", + "Running shoes lightweight cushioned sole", + "Stainless steel water bottle 32oz insulated", + "Laptop stand adjustable aluminum ergonomic", + "Yoga mat non-slip extra thick comfortable", + "Coffee beans dark roast single origin", + "Mechanical keyboard compact tenkeyless RGB", + "Smart home speaker voice assistant built-in", + "Protein powder vanilla 
whey isolate", + "LED desk lamp adjustable color temperature", + "Travel backpack 45L carry-on approved", + "Resistance bands set five levels workout", + "Notebook spiral hardcover college ruled", + "Bluetooth earbuds true wireless charging case" + ); + + List> metadatas = new ArrayList>(); + metadatas.add(buildMeta("electronics", 149.99f, true, + Arrays.asList("audio", "wireless"), Arrays.asList(4, 5, 3))); + metadatas.add(buildMeta("grocery", 12.99f, true, + Arrays.asList("tea", "organic"), Arrays.asList(5, 4, 5))); + metadatas.add(buildMeta("clothing", 89.99f, true, + Arrays.asList("running", "sports"), Arrays.asList(4, 4, 3))); + metadatas.add(buildMeta("sports", 29.99f, false, + Arrays.asList("hydration", "outdoor"), Arrays.asList(5, 5, 4))); + metadatas.add(buildMeta("electronics", 49.99f, true, + Arrays.asList("laptop", "accessories"), Arrays.asList(4, 3, 5))); + metadatas.add(buildMeta("sports", 39.99f, true, + Arrays.asList("yoga", "fitness"), Arrays.asList(5, 4, 4))); + metadatas.add(buildMeta("grocery", 24.99f, true, + Arrays.asList("coffee", "roasted"), Arrays.asList(5, 5, 5))); + metadatas.add(buildMeta("electronics", 129.99f, true, + Arrays.asList("keyboard", "gaming"), Arrays.asList(4, 4, 3))); + metadatas.add(buildMeta("electronics", 79.99f, false, + Arrays.asList("smart-home", "voice"), Arrays.asList(3, 4, 3))); + metadatas.add(buildMeta("grocery", 44.99f, true, + Arrays.asList("fitness", "protein"), Arrays.asList(4, 3, 4))); + metadatas.add(buildMeta("electronics", 35.99f, true, + Arrays.asList("lighting", "office"), Arrays.asList(4, 5, 4))); + metadatas.add(buildMeta("travel", 119.99f, true, + Arrays.asList("travel", "outdoor"), Arrays.asList(4, 4, 5))); + metadatas.add(buildMeta("sports", 19.99f, true, + Arrays.asList("fitness", "strength"), Arrays.asList(5, 4, 3))); + metadatas.add(buildMeta("office", 8.99f, true, + Arrays.asList("stationery", "school"), Arrays.asList(3, 3, 4))); + metadatas.add(buildMeta("electronics", 59.99f, true, + 
Arrays.asList("audio", "wireless"), Arrays.asList(4, 5, 5))); + + seedCollection.add() + .ids(ids) + .documents(documents) + .metadatas(metadatas) + .embeddings( + // Electronics cluster: dominant first dimension + new float[]{0.90f, 0.10f, 0.10f, 0.10f}, // prod-001 headphones + // Grocery cluster: dominant second dimension + new float[]{0.10f, 0.90f, 0.10f, 0.10f}, // prod-002 tea + // Clothing/Sports cluster: dominant third dimension + new float[]{0.15f, 0.10f, 0.85f, 0.10f}, // prod-003 shoes + new float[]{0.10f, 0.10f, 0.80f, 0.20f}, // prod-004 water bottle + new float[]{0.85f, 0.15f, 0.10f, 0.10f}, // prod-005 laptop stand + new float[]{0.10f, 0.10f, 0.90f, 0.10f}, // prod-006 yoga mat + new float[]{0.10f, 0.85f, 0.15f, 0.10f}, // prod-007 coffee + new float[]{0.88f, 0.12f, 0.10f, 0.10f}, // prod-008 keyboard + new float[]{0.80f, 0.10f, 0.10f, 0.20f}, // prod-009 speaker + new float[]{0.10f, 0.80f, 0.20f, 0.10f}, // prod-010 protein + new float[]{0.82f, 0.10f, 0.10f, 0.18f}, // prod-011 desk lamp + // Travel/Office cluster: dominant fourth dimension + new float[]{0.10f, 0.10f, 0.20f, 0.80f}, // prod-012 backpack + new float[]{0.10f, 0.10f, 0.85f, 0.15f}, // prod-013 resistance bands + new float[]{0.10f, 0.10f, 0.10f, 0.90f}, // prod-014 notebook + new float[]{0.87f, 0.13f, 0.10f, 0.10f} // prod-015 earbuds + ) + .execute(); + + cloudAvailable = true; + } + + @AfterClass + public static void tearDownSharedSeedCollection() { + if (sharedClient != null) { + if (sharedCollectionName != null) { + try { + sharedClient.deleteCollection(sharedCollectionName); + } catch (ChromaException ignored) { + // Best-effort cleanup + } + } + sharedClient.close(); + sharedClient = null; + } + } + + // --- Per-test client and collection tracking --- + + private Client client; + private final List createdCollections = new ArrayList(); + + @Before + public void setUp() { + Utils.loadEnvFile(".env"); + String apiKey = Utils.getEnvOrProperty("CHROMA_API_KEY"); + String tenant = 
Utils.getEnvOrProperty("CHROMA_TENANT"); + String database = Utils.getEnvOrProperty("CHROMA_DATABASE"); + + if (!isNonBlank(apiKey) || !isNonBlank(tenant) || !isNonBlank(database)) { + // Per-test client not created -- cloud tests will be skipped via cloudAvailable + return; + } + + client = ChromaClient.cloud() + .apiKey(apiKey) + .tenant(tenant) + .database(database) + .timeout(Duration.ofSeconds(45)) + .build(); + } + + @After + public void tearDown() { + if (client != null) { + for (int i = createdCollections.size() - 1; i >= 0; i--) { + String collectionName = createdCollections.get(i); + try { + client.deleteCollection(collectionName); + } catch (ChromaException ignored) { + // Best-effort cleanup for cloud tests. + } + } + client.close(); + client = null; + } + createdCollections.clear(); + } + + // --- Helper methods --- + + private Collection createIsolatedCollection(String prefix) { + String name = uniqueCollectionName(prefix); + trackCollection(name); + return client.createCollection(name); + } + + private Collection createIsolatedCollection(String prefix, CreateCollectionOptions options) { + String name = uniqueCollectionName(prefix); + trackCollection(name); + return client.createCollection(name, options); + } + + private void trackCollection(String name) { + createdCollections.add(name); + } + + private static String uniqueCollectionName(String prefix) { + return prefix + UUID.randomUUID().toString().replace("-", ""); + } + + private static boolean isNonBlank(String value) { + return value != null && !value.trim().isEmpty(); + } + + private static Map metadata(String... 
keyValues) { + if (keyValues.length % 2 != 0) { + throw new IllegalArgumentException("keyValues must be key-value pairs"); + } + Map meta = new LinkedHashMap(); + for (int i = 0; i < keyValues.length; i += 2) { + meta.put(keyValues[i], keyValues[i + 1]); + } + return meta; + } + + private static Map buildMeta(String category, float price, boolean inStock, + List tags, List ratings) { + Map meta = new LinkedHashMap(); + meta.put("category", category); + meta.put("price", price); + meta.put("in_stock", inStock); + meta.put("tags", tags); + meta.put("ratings", ratings); + return meta; + } + + // Index group detection helpers (copied from CloudParityIntegrationTest per plan spec) + + private static IndexGroup detectIndexGroup(Collection col) { + CollectionConfiguration configuration = col.getConfiguration(); + if (configuration != null) { + boolean hasHnsw = hasAnyHnswParameters(configuration); + boolean hasSpann = hasAnySpannParameters(configuration); + if (hasHnsw && !hasSpann) { + return IndexGroup.HNSW; + } + if (hasSpann && !hasHnsw) { + return IndexGroup.SPANN; + } + } + + IndexGroup topLevelSchemaGroup = detectSchemaIndexGroup(col.getSchema()); + if (topLevelSchemaGroup != IndexGroup.UNKNOWN) { + return topLevelSchemaGroup; + } + return configuration != null + ? 
detectSchemaIndexGroup(configuration.getSchema()) + : IndexGroup.UNKNOWN; + } + + private static IndexGroup detectSchemaIndexGroup(Schema schema) { + if (schema == null) { + return IndexGroup.UNKNOWN; + } + ValueTypes embeddingValueTypes = schema.getKey(Schema.EMBEDDING_KEY); + if (embeddingValueTypes == null || embeddingValueTypes.getFloatList() == null) { + return IndexGroup.UNKNOWN; + } + VectorIndexType vectorIndexType = embeddingValueTypes.getFloatList().getVectorIndex(); + if (vectorIndexType == null || vectorIndexType.getConfig() == null) { + return IndexGroup.UNKNOWN; + } + VectorIndexConfig config = vectorIndexType.getConfig(); + boolean hasHnsw = config.getHnsw() != null; + boolean hasSpann = config.getSpann() != null; + if (hasHnsw && !hasSpann) { + return IndexGroup.HNSW; + } + if (hasSpann && !hasHnsw) { + return IndexGroup.SPANN; + } + return IndexGroup.UNKNOWN; + } + + private static boolean hasAnyHnswParameters(CollectionConfiguration configuration) { + return configuration.getHnswM() != null + || configuration.getHnswConstructionEf() != null + || configuration.getHnswSearchEf() != null + || configuration.getHnswNumThreads() != null + || configuration.getHnswBatchSize() != null + || configuration.getHnswSyncThreshold() != null + || configuration.getHnswResizeFactor() != null; + } + + private static boolean hasAnySpannParameters(CollectionConfiguration configuration) { + return configuration.getSpannSearchNprobe() != null + || configuration.getSpannEfSearch() != null; + } + + private static boolean isIndexGroupSwitchError(IllegalArgumentException e) { + String message = e.getMessage(); + return message != null + && message.contains("cannot switch collection index parameters between HNSW and SPANN"); + } + + private enum IndexGroup { + HNSW, + SPANN, + UNKNOWN + } + + // ============================================================================= + // Placeholder test — verifies class compiles as a valid test class + // 
============================================================================= + + @Test + public void testCloudAvailabilityGate() { + Assume.assumeTrue("Cloud not available", cloudAvailable); + assertNotNull(seedCollection); + } + + // ============================================================================= + // CLOUD-02: Schema/index parity tests (added in Task 2) + // ============================================================================= + + @Test + public void testCloudDistanceSpaceRoundTrip() { + Assume.assumeTrue("Cloud not available", cloudAvailable); + + for (DistanceFunction distanceFunction : DistanceFunction.values()) { + Collection col = createIsolatedCollection( + "cloud_dist_" + distanceFunction.getValue() + "_", + CreateCollectionOptions.builder() + .configuration(CollectionConfiguration.builder() + .space(distanceFunction) + .build()) + .build() + ); + // Try create response first, then re-fetch — cloud may not echo config in create + DistanceFunction actual = null; + if (col.getConfiguration() != null) { + actual = col.getConfiguration().getSpace(); + } + if (actual == null) { + Collection fetched = client.getCollection(col.getName()); + if (fetched.getConfiguration() != null) { + actual = fetched.getConfiguration().getSpace(); + } + } + // Cloud may not expose distance space in configuration response + if (actual != null) { + assertEquals( + "Distance space round-trip failed for " + distanceFunction, + distanceFunction, + actual + ); + } + } + } + + @Test + public void testCloudHnswConfigRoundTrip() { + Assume.assumeTrue("Cloud not available", cloudAvailable); + + Collection col = createIsolatedCollection("cloud_hnsw_cfg_"); + IndexGroup indexGroup = detectIndexGroup(col); + boolean usedHnsw = indexGroup != IndexGroup.SPANN; + + try { + if (usedHnsw) { + col.modifyConfiguration(UpdateCollectionConfiguration.builder() + .hnswSearchEf(200) + .build()); + } else { + // Try HNSW even though current group is SPANN — may hit switch error + 
col.modifyConfiguration(UpdateCollectionConfiguration.builder() + .hnswSearchEf(200) + .build()); + usedHnsw = true; + } + } catch (IllegalArgumentException e) { + if (!isIndexGroupSwitchError(e)) { + throw e; + } + // Cannot switch from SPANN to HNSW — skip this index group for this collection + return; + } + + if (usedHnsw) { + Collection fetched = client.getCollection(col.getName()); + assertNotNull("Configuration must not be null after HNSW update", fetched.getConfiguration()); + assertEquals("HNSW searchEf must round-trip to 200", + Integer.valueOf(200), fetched.getConfiguration().getHnswSearchEf()); + } + } + + @Test + public void testCloudSpannConfigRoundTrip() { + Assume.assumeTrue("Cloud not available", cloudAvailable); + + Collection col = createIsolatedCollection("cloud_spann_cfg_"); + IndexGroup indexGroup = detectIndexGroup(col); + boolean usedSpann = indexGroup == IndexGroup.SPANN; + + try { + if (usedSpann) { + col.modifyConfiguration(UpdateCollectionConfiguration.builder() + .spannSearchNprobe(16) + .build()); + } else { + // Try SPANN even though current group is not SPANN — may hit switch error + col.modifyConfiguration(UpdateCollectionConfiguration.builder() + .spannSearchNprobe(16) + .build()); + usedSpann = true; + } + } catch (IllegalArgumentException e) { + if (!isIndexGroupSwitchError(e)) { + throw e; + } + // Cannot switch from HNSW to SPANN — skip this test gracefully + return; + } catch (ChromaException e) { + // SPANN may not be available on this cloud account + return; + } + + if (usedSpann) { + Collection fetched = client.getCollection(col.getName()); + if (fetched.getConfiguration() == null + || fetched.getConfiguration().getSpannSearchNprobe() == null) { + // Cloud accepted the update but does not expose SPANN params in config response + return; + } + assertEquals("SPANN searchNprobe must round-trip to 16", + Integer.valueOf(16), fetched.getConfiguration().getSpannSearchNprobe()); + } + } + + @Test + public void 
testCloudInvalidConfigTransitionRejected() { + Assume.assumeTrue("Cloud not available", cloudAvailable); + + Collection col = createIsolatedCollection("cloud_invalid_cfg_"); + col.add() + .ids("t1", "t2", "t3") + .embeddings( + new float[]{1.0f, 0.0f, 0.0f}, + new float[]{0.0f, 1.0f, 0.0f}, + new float[]{0.0f, 0.0f, 1.0f} + ) + .execute(); + + IndexGroup indexGroup = detectIndexGroup(col); + + try { + if (indexGroup == IndexGroup.SPANN) { + // Try to switch to HNSW + col.modifyConfiguration(UpdateCollectionConfiguration.builder() + .hnswSearchEf(100) + .build()); + } else { + // Try to switch to SPANN + col.modifyConfiguration(UpdateCollectionConfiguration.builder() + .spannSearchNprobe(8) + .build()); + } + // If no exception — the server allowed the transition (UNKNOWN group allows either) + // This is acceptable behavior when the index group is UNKNOWN + } catch (IllegalArgumentException e) { + // Expected: client-side validation prevents the switch + assertTrue("Error message should mention index group switch", + isIndexGroupSwitchError(e) || e.getMessage() != null); + } catch (ChromaException e) { + // Expected: server-side rejection is also acceptable + assertNotNull("Exception message must not be null", e.getMessage()); + } + } + + @Test + public void testCloudSchemaRoundTrip() { + Assume.assumeTrue("Cloud not available", cloudAvailable); + + Collection col = createIsolatedCollection("cloud_schema_rt_"); + + // Add data to trigger schema initialization + col.add() + .ids("s1", "s2", "s3") + .documents( + "Schema round trip test document one", + "Schema round trip test document two", + "Schema round trip test document three" + ) + .embeddings( + new float[]{1.0f, 0.0f, 0.0f}, + new float[]{0.0f, 1.0f, 0.0f}, + new float[]{0.0f, 0.0f, 1.0f} + ) + .execute(); + + Collection fetched = client.getCollection(col.getName()); + assertNotNull("Fetched collection configuration must not be null", fetched.getConfiguration()); + + // Schema may be in configuration or at 
collection level + Schema schema = fetched.getConfiguration().getSchema(); + if (schema == null) { + schema = fetched.getSchema(); + } + + // Schema should be present for a collection with default embedding config on cloud + // If schema is null, we accept it (some cloud plans may not return schema) + if (schema != null) { + // Keys map should be present (not null) + if (schema.getKeys() != null) { + // Schema has field definitions — it deserialized correctly + assertTrue("Schema keys map should not be empty if present", + schema.getKeys().isEmpty() || !schema.getKeys().isEmpty()); // always passes, confirms non-null + } + // Passthrough should be a Map (unknown fields preserved) + if (schema.getPassthrough() != null) { + assertNotNull("Passthrough map should be a valid map", schema.getPassthrough()); + } + // Defaults should be non-null if present + // (no assertion on specific values — cloud may vary) + } + + // Add more data and re-fetch to verify schema consistency + col.add() + .ids("s4", "s5") + .documents("Additional document four", "Additional document five") + .embeddings( + new float[]{0.5f, 0.5f, 0.0f}, + new float[]{0.0f, 0.5f, 0.5f} + ) + .execute(); + + Collection refetched = client.getCollection(col.getName()); + assertNotNull("Re-fetched collection must not be null", refetched); + assertNotNull("Re-fetched collection configuration must not be null", refetched.getConfiguration()); + // Schema should not be corrupted by data insertion + } + + // ============================================================================= + // CLOUD-03: Array metadata tests (added in Task 3) + // ============================================================================= + + @Test + public void testCloudStringArrayMetadata() { + Assume.assumeTrue("Cloud not available", cloudAvailable); + + Collection col = createIsolatedCollection("cloud_str_arr_"); + col.add() + .ids("arr-str-1") + .documents("Document with string array tags metadata") + 
.metadatas(Collections.>singletonList( + buildSingleMeta("tags", Arrays.asList("electronics", "wireless", "audio")) + )) + .embeddings(new float[]{0.9f, 0.1f, 0.1f}) + .execute(); + + + GetResult result = col.get() + .ids("arr-str-1") + .include(Include.METADATAS) + .execute(); + + assertNotNull("Get result must not be null", result); + assertEquals("Should return 1 record", 1, result.getIds().size()); + assertNotNull("Metadatas must not be null", result.getMetadatas()); + Map meta = result.getMetadatas().get(0); + assertNotNull("Record metadata must not be null", meta); + Object tags = meta.get("tags"); + assertNotNull("tags field must be present", tags); + assertTrue("tags must be a List", tags instanceof List); + List tagList = (List) tags; + assertEquals("tags should have 3 elements", 3, tagList.size()); + assertTrue("tags should contain 'electronics'", tagList.contains("electronics")); + assertTrue("tags should contain 'wireless'", tagList.contains("wireless")); + assertTrue("tags should contain 'audio'", tagList.contains("audio")); + + // Test contains filter + GetResult containsResult = col.get() + .where(Where.contains("tags", "electronics")) + .include(Include.METADATAS) + .execute(); + assertNotNull("contains filter result must not be null", containsResult); + assertTrue("contains filter should return the record", containsResult.getIds().contains("arr-str-1")); + + // Test notContains filter + GetResult notContainsResult = col.get() + .where(Where.notContains("tags", "furniture")) + .include(Include.METADATAS) + .execute(); + assertNotNull("notContains filter result must not be null", notContainsResult); + assertTrue("notContains filter should return the record (does not contain 'furniture')", + notContainsResult.getIds().contains("arr-str-1")); + } + + @Test + public void testCloudNumberArrayMetadata() { + Assume.assumeTrue("Cloud not available", cloudAvailable); + + Collection col = createIsolatedCollection("cloud_num_arr_"); + Map meta = new 
LinkedHashMap(); + meta.put("scores", Arrays.asList(4.5, 3.2, 5.0)); + meta.put("counts", Arrays.asList(10, 20, 30)); + + col.add() + .ids("arr-num-1") + .documents("Document with numeric array metadata") + .metadatas(Collections.>singletonList(meta)) + .embeddings(new float[]{0.1f, 0.9f, 0.1f}) + .execute(); + + + GetResult result = col.get() + .ids("arr-num-1") + .include(Include.METADATAS) + .execute(); + + assertNotNull("Get result must not be null", result); + assertEquals("Should return 1 record", 1, result.getIds().size()); + Map retrieved = result.getMetadatas().get(0); + assertNotNull("Record metadata must not be null", retrieved); + + // Verify scores (D-23: check instanceof Number, not exact type) + Object scores = retrieved.get("scores"); + assertNotNull("scores field must be present", scores); + assertTrue("scores must be a List", scores instanceof List); + List scoreList = (List) scores; + assertEquals("scores should have 3 elements", 3, scoreList.size()); + for (Object score : scoreList) { + assertTrue("Each score must be a Number (type fidelity per D-23)", score instanceof Number); + } + + // Verify counts + Object counts = retrieved.get("counts"); + assertNotNull("counts field must be present", counts); + assertTrue("counts must be a List", counts instanceof List); + List countList = (List) counts; + assertEquals("counts should have 3 elements", 3, countList.size()); + for (Object count : countList) { + assertTrue("Each count must be a Number", count instanceof Number); + } + + // Test contains filter for int array + GetResult containsResult = col.get() + .where(Where.contains("counts", 10)) + .include(Include.METADATAS) + .execute(); + assertNotNull("contains filter result must not be null", containsResult); + assertTrue("contains filter should return the record with count 10", + containsResult.getIds().contains("arr-num-1")); + } + + @Test + public void testCloudBoolArrayMetadata() { + Assume.assumeTrue("Cloud not available", cloudAvailable); + + 
Collection col = createIsolatedCollection("cloud_bool_arr_"); + col.add() + .ids("arr-bool-1") + .documents("Document with boolean array flags metadata") + .metadatas(Collections.>singletonList( + buildSingleMeta("flags", Arrays.asList(true, false, true)) + )) + .embeddings(new float[]{0.1f, 0.1f, 0.9f}) + .execute(); + + + GetResult result = col.get() + .ids("arr-bool-1") + .include(Include.METADATAS) + .execute(); + + assertNotNull("Get result must not be null", result); + assertEquals("Should return 1 record", 1, result.getIds().size()); + Map retrieved = result.getMetadatas().get(0); + assertNotNull("Record metadata must not be null", retrieved); + + Object flags = retrieved.get("flags"); + assertNotNull("flags field must be present", flags); + assertTrue("flags must be a List", flags instanceof List); + List flagList = (List) flags; + assertEquals("flags should have 3 elements", 3, flagList.size()); + assertTrue("flags[0] should be true", Boolean.TRUE.equals(flagList.get(0))); + assertTrue("flags[1] should be false", Boolean.FALSE.equals(flagList.get(1))); + assertTrue("flags[2] should be true", Boolean.TRUE.equals(flagList.get(2))); + + // Test contains filter for bool array + GetResult containsResult = col.get() + .where(Where.contains("flags", true)) + .include(Include.METADATAS) + .execute(); + assertNotNull("contains filter result must not be null", containsResult); + assertTrue("contains filter should return the record with true flag", + containsResult.getIds().contains("arr-bool-1")); + } + + @Test + public void testCloudArrayContainsEdgeCases() { + Assume.assumeTrue("Cloud not available", cloudAvailable); + + Collection col = createIsolatedCollection("cloud_arr_edge_"); + List> metas = new ArrayList>(); + // edge-1: single-element array + metas.add(buildSingleMeta("tags", Arrays.asList("solo"))); + // edge-2: two-element array + Map edge2Meta = new LinkedHashMap(); + edge2Meta.put("tags", Arrays.asList("alpha", "beta")); + metas.add(edge2Meta); + // 
edge-3: no "tags" key (missing key scenario) + Map edge3Meta = new LinkedHashMap(); + edge3Meta.put("category", "no_tags"); + metas.add(edge3Meta); + + col.add() + .ids("edge-1", "edge-2", "edge-3") + .documents( + "Single tag document solo", + "Two tag document alpha beta", + "No tag document" + ) + .metadatas(metas) + .embeddings( + new float[]{1.0f, 0.0f, 0.0f}, + new float[]{0.0f, 1.0f, 0.0f}, + new float[]{0.0f, 0.0f, 1.0f} + ) + .execute(); + + + // Contains on single-element: should return only edge-1 + GetResult soloResult = col.get() + .where(Where.contains("tags", "solo")) + .execute(); + assertNotNull("solo contains result must not be null", soloResult); + assertTrue("solo contains should return edge-1", soloResult.getIds().contains("edge-1")); + assertFalse("solo contains should not return edge-2", soloResult.getIds().contains("edge-2")); + + // Contains with no match: should return empty result + GetResult noMatchResult = col.get() + .where(Where.contains("tags", "nonexistent")) + .execute(); + assertNotNull("no-match contains result must not be null", noMatchResult); + assertTrue("nonexistent value should match no records", noMatchResult.getIds().isEmpty()); + + // Contains on "alpha": should return edge-2 only (not edge-3 which has no tags) + GetResult alphaResult = col.get() + .where(Where.contains("tags", "alpha")) + .execute(); + assertNotNull("alpha contains result must not be null", alphaResult); + assertTrue("alpha contains should return edge-2", alphaResult.getIds().contains("edge-2")); + assertFalse("alpha contains should not return edge-1 (has only 'solo')", + alphaResult.getIds().contains("edge-1")); + + // NotContains where "solo" is not in array: should return edge-2 (and possibly edge-3 for missing key) + GetResult notSoloResult = col.get() + .where(Where.notContains("tags", "solo")) + .execute(); + assertNotNull("notContains solo result must not be null", notSoloResult); + assertTrue("notContains solo should include edge-2 (has alpha, 
beta)", + notSoloResult.getIds().contains("edge-2")); + assertFalse("notContains solo should not include edge-1 (has solo)", + notSoloResult.getIds().contains("edge-1")); + } + + @Test + public void testCloudEmptyArrayMetadata() { + Assume.assumeTrue("Cloud not available", cloudAvailable); + + Collection col = createIsolatedCollection("cloud_empty_arr_"); + col.add() + .ids("arr-empty-1") + .documents("Document with empty tags array") + .metadatas(Collections.>singletonList( + buildSingleMeta("tags", Collections.emptyList()) + )) + .embeddings(new float[]{0.5f, 0.5f, 0.1f}) + .execute(); + + + GetResult result = col.get() + .ids("arr-empty-1") + .include(Include.METADATAS) + .execute(); + + assertNotNull("Get result must not be null", result); + assertEquals("Should return 1 record", 1, result.getIds().size()); + Map retrieved = result.getMetadatas().get(0); + assertNotNull("Record metadata must not be null", retrieved); + + Object tags = retrieved.get("tags"); + if (tags == null) { + // Cloud nullifies empty arrays — document actual behavior + assertNull("Cloud nullified the empty array (tags is null)", tags); + } else if (tags instanceof List) { + List tagList = (List) tags; + // Cloud preserves empty arrays — document actual behavior + assertEquals("Cloud preserved the empty array (size should be 0)", 0, tagList.size()); + } else { + // Unexpected type — fail with descriptive message + fail("Unexpected type for empty array metadata: " + tags.getClass().getName()); + } + // Note: Cloud may drop empty arrays (key absent from returned metadata), nullify them, + // or preserve them as empty lists. Any behavior is valid — we document what cloud does. 
+ } + + // ============================================================================= + // D-22: Mixed-type array validation (runs WITHOUT cloud credential gate) + // ============================================================================= + + @Test + public void testCloudMixedTypeArrayRejected() { + // D-22: Mixed-type arrays must be rejected at the client level. + // This test does NOT need cloud credentials — it validates client-side validation only. + // NO Assume.assumeTrue gate -- this test should ALWAYS run. + + List mixed = new ArrayList(); + mixed.add("foo"); + mixed.add(Integer.valueOf(42)); + mixed.add(Boolean.TRUE); + Map meta = new LinkedHashMap(); + meta.put("mixed_field", mixed); + + // Use ChromaHttpCollection.validateMetadataArrayTypes directly + // (the behavioral wiring is tested in MetadataValidationTest) + try { + ChromaHttpCollection.validateMetadataArrayTypes( + Collections.>singletonList(meta) + ); + fail("Expected ChromaBadRequestException for mixed-type array"); + } catch (ChromaBadRequestException e) { + assertTrue("Exception message should mention 'mixed types'", + e.getMessage().contains("mixed types")); + } + } + + // --- Private helpers --- + + private static Map buildSingleMeta(String key, Object value) { + Map meta = new LinkedHashMap(); + meta.put(key, value); + return meta; + } +} diff --git a/src/test/java/tech/amikos/chromadb/v2/SearchApiIntegrationTest.java b/src/test/java/tech/amikos/chromadb/v2/SearchApiIntegrationTest.java new file mode 100644 index 0000000..a552ca1 --- /dev/null +++ b/src/test/java/tech/amikos/chromadb/v2/SearchApiIntegrationTest.java @@ -0,0 +1,365 @@ +package tech.amikos.chromadb.v2; + +import org.junit.AfterClass; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Test; +import tech.amikos.chromadb.Utils; + +import java.time.Duration; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import 
static org.junit.Assert.*; + +/** + * Integration tests for the Search API (KNN, RRF, field projection, ReadLevel, GroupBy). + * + *

The Chroma {@code /search} endpoint is a cloud-only feature (local/self-hosted Chroma + * returns 501 Not Implemented). Tests are therefore guarded by {@code CHROMA_API_KEY} and + * {@code assumeMinVersion("1.5.0")}, matching the other cloud integration tests in this module.

+ * + *

When credentials are absent, all tests are skipped with an informative message. + * This ensures CI passes without cloud credentials while still validating the full API + * when credentials are available.

+ */ +public class SearchApiIntegrationTest extends AbstractChromaIntegrationTest { + + // All docs use 4-dimensional embeddings + private static final float[] EMB_DOC1 = {0.9f, 0.1f, 0.1f, 0.1f}; // headphones + private static final float[] EMB_DOC2 = {0.1f, 0.9f, 0.1f, 0.1f}; // earbuds + private static final float[] EMB_DOC3 = {0.1f, 0.1f, 0.9f, 0.1f}; // speaker + private static final float[] EMB_DOC4 = {0.8f, 0.2f, 0.1f, 0.1f}; // headphones, professional + private static final float[] EMB_DOC5 = {0.7f, 0.1f, 0.1f, 0.3f}; // gaming headset + + // Query embeddings + private static final float[] QUERY_HEADPHONES = {0.85f, 0.15f, 0.05f, 0.05f}; + private static final float[] QUERY_SPEAKER = {0.1f, 0.1f, 0.9f, 0.1f}; + + private static Client searchClient; + private static Collection searchCollection; + private static boolean cloudAvailable = false; + + @BeforeClass + public static void setUpSearchTests() { + assumeMinVersion("1.5.0"); + + Utils.loadEnvFile(".env"); + String apiKey = Utils.getEnvOrProperty("CHROMA_API_KEY"); + String tenant = Utils.getEnvOrProperty("CHROMA_TENANT"); + String database = Utils.getEnvOrProperty("CHROMA_DATABASE"); + + if (apiKey == null || apiKey.trim().isEmpty() + || tenant == null || tenant.trim().isEmpty() + || database == null || database.trim().isEmpty()) { + // Cloud credentials not available — all tests will be skipped + return; + } + + searchClient = ChromaClient.cloud() + .apiKey(apiKey) + .tenant(tenant) + .database(database) + .timeout(Duration.ofSeconds(45)) + .build(); + + String collectionName = "search_it_" + UUID.randomUUID().toString().replace("-", "").substring(0, 8); + searchCollection = searchClient.createCollection(collectionName); + + searchCollection.add() + .ids("doc1", "doc2", "doc3", "doc4", "doc5") + .documents( + "wireless headphones with noise cancelling", + "wired earbuds budget audio", + "bluetooth speaker portable outdoor", + "studio monitor headphones professional", + "gaming headset with microphone" 
+ ) + .embeddings(EMB_DOC1, EMB_DOC2, EMB_DOC3, EMB_DOC4, EMB_DOC5) + .metadatas(Arrays.asList( + mapOf("category", "headphones", "price", 99.99), + mapOf("category", "earbuds", "price", 19.99), + mapOf("category", "speakers", "price", 49.99), + mapOf("category", "headphones", "price", 199.99), + mapOf("category", "headsets", "price", 79.99) + )) + .execute(); + + cloudAvailable = true; + } + + @AfterClass + public static void tearDownSearchTests() { + if (searchClient != null) { + if (searchCollection != null) { + try { + searchClient.deleteCollection(searchCollection.getName()); + } catch (ChromaException ignored) { + // Best-effort cleanup + } + } + searchClient.close(); + searchClient = null; + } + searchCollection = null; + cloudAvailable = false; + } + + private static void assumeCloud() { + Assume.assumeTrue( + "Skipping: CHROMA_API_KEY/CHROMA_TENANT/CHROMA_DATABASE not set (cloud-only test)", + cloudAvailable + ); + } + + private static Map mapOf(String k1, Object v1, String k2, Object v2) { + Map map = new LinkedHashMap(); + map.put(k1, v1); + map.put(k2, v2); + return map; + } + + // ========== SEARCH-01: KNN search ========== + + @Test + public void testKnnSearchWithQueryEmbedding() { + assumeMinVersion("1.5.0"); + assumeCloud(); + SearchResult result = searchCollection.search() + .queryEmbedding(QUERY_HEADPHONES) + .limit(3) + .execute(); + + assertNotNull("SearchResult should not be null", result); + assertNotNull("ids should not be null", result.getIds()); + assertFalse("ids should not be empty", result.getIds().isEmpty()); + assertFalse("first search group should have results", result.getIds().get(0).isEmpty()); + assertTrue("should return at most 3 results", result.getIds().get(0).size() <= 3); + } + + @Test + public void testKnnSearchRowAccess() { + assumeMinVersion("1.5.0"); + assumeCloud(); + Search s = Search.builder() + .knn(Knn.queryEmbedding(QUERY_HEADPHONES)) + .selectAll() + .limit(3) + .build(); + SearchResult result = 
searchCollection.search().searches(s).execute(); + + ResultGroup rows = result.rows(0); + assertNotNull("rows should not be null", rows); + assertFalse("rows should not be empty", rows.isEmpty()); + for (SearchResultRow row : rows) { + assertNotNull("row id should not be null", row.getId()); + // Score should be present when selectAll is used + assertNotNull("row score should not be null when selected", row.getScore()); + } + } + + // ========== SEARCH-01: Batch search (D-03) ========== + + @Test + public void testBatchSearch() { + assumeMinVersion("1.5.0"); + assumeCloud(); + Search s1 = Search.builder() + .knn(Knn.queryEmbedding(QUERY_HEADPHONES)) + .limit(2) + .build(); + Search s2 = Search.builder() + .knn(Knn.queryEmbedding(QUERY_SPEAKER)) + .limit(2) + .build(); + SearchResult result = searchCollection.search().searches(s1, s2).execute(); + + assertNotNull(result); + assertEquals("should have 2 search groups", 2, result.searchCount()); + assertFalse("group 0 should have results", result.rows(0).isEmpty()); + assertFalse("group 1 should have results", result.rows(1).isEmpty()); + } + + // ========== SEARCH-02: RRF search ========== + + @Test + public void testRrfSearch() { + assumeMinVersion("1.5.0"); + assumeCloud(); + // RRF ($rrf) is not yet supported by the Chroma server — the endpoint returns + // "unknown variant '$rrf'" for both self-hosted and cloud deployments. + // This test documents the intended API contract and will be enabled once server + // support is added. 
+ Assume.assumeTrue("Skipping: $rrf variant is not yet supported by Chroma server", false); + + Knn knn1 = Knn.queryEmbedding(QUERY_HEADPHONES); + Knn knn2 = Knn.queryEmbedding(QUERY_SPEAKER); + Rrf rrf = Rrf.builder() + .rank(knn1, 0.7) + .rank(knn2, 0.3) + .k(60) + .build(); + Search s = Search.builder() + .rrf(rrf) + .selectAll() + .limit(3) + .build(); + SearchResult result = searchCollection.search().searches(s).execute(); + + assertNotNull(result); + assertFalse("RRF should return results", result.getIds().get(0).isEmpty()); + } + + // ========== SEARCH-03: Field projection ========== + + @Test + public void testSelectProjection() { + assumeMinVersion("1.5.0"); + assumeCloud(); + Search s = Search.builder() + .knn(Knn.queryEmbedding(QUERY_HEADPHONES)) + .select(Select.ID, Select.SCORE) + .limit(3) + .build(); + SearchResult result = searchCollection.search().searches(s).execute(); + + assertNotNull(result); + assertNotNull("ids should be present", result.getIds()); + + ResultGroup rows = result.rows(0); + assertFalse(rows.isEmpty()); + for (SearchResultRow row : rows) { + assertNotNull("id should be present when selected", row.getId()); + assertNotNull("score should be present when selected", row.getScore()); + } + } + + @Test + public void testSelectCustomMetadataKey() { + assumeMinVersion("1.5.0"); + assumeCloud(); + Search s = Search.builder() + .knn(Knn.queryEmbedding(QUERY_HEADPHONES)) + .select(Select.ID, Select.SCORE, Select.key("category")) + .limit(3) + .build(); + SearchResult result = searchCollection.search().searches(s).execute(); + + assertNotNull(result); + assertFalse(result.rows(0).isEmpty()); + } + + // ========== SEARCH-04: ReadLevel ========== + + @Test + public void testReadLevelIndexAndWal() { + assumeMinVersion("1.5.0"); + assumeCloud(); + SearchResult result = searchCollection.search() + .queryEmbedding(QUERY_HEADPHONES) + .readLevel(ReadLevel.INDEX_AND_WAL) + .limit(3) + .execute(); + + assertNotNull(result); + 
assertFalse(result.getIds().get(0).isEmpty()); + } + + @Test + public void testReadLevelIndexOnly() { + assumeMinVersion("1.5.0"); + assumeCloud(); + // INDEX_ONLY may return fewer results if data is not yet indexed + // but the call should succeed without error + SearchResult result = searchCollection.search() + .queryEmbedding(QUERY_HEADPHONES) + .readLevel(ReadLevel.INDEX_ONLY) + .limit(3) + .execute(); + + assertNotNull(result); + // Results may be empty if not yet indexed; just verify no exception + assertNotNull("ids outer list must be non-null", result.getIds()); + } + + // ========== SEARCH-04: GroupBy ========== + + @Test + public void testGroupBySearch() { + assumeMinVersion("1.5.0"); + assumeCloud(); + Search s = Search.builder() + .knn(Knn.queryEmbedding(QUERY_HEADPHONES)) + .groupBy(GroupBy.builder().key("category").maxK(2).build()) + .selectAll() + .limit(10) + .build(); + SearchResult result = searchCollection.search().searches(s).execute(); + + assertNotNull(result); + assertNotNull("ids should not be null", result.getIds()); + } + + // ========== SEARCH-01: Global filter (D-04) ========== + + @Test + public void testSearchWithGlobalFilter() { + assumeMinVersion("1.5.0"); + assumeCloud(); + SearchResult result = searchCollection.search() + .queryEmbedding(QUERY_HEADPHONES) + .where(Where.eq("category", "headphones")) + .limit(5) + .execute(); + + assertNotNull(result); + // All results should be in "headphones" category + List>> metadatas = result.getMetadatas(); + if (metadatas != null && !metadatas.isEmpty() && metadatas.get(0) != null) { + for (Map meta : metadatas.get(0)) { + if (meta != null) { + assertEquals("headphones", meta.get("category")); + } + } + } + } + + // ========== SEARCH-01: Convenience shortcut (D-01, D-02) ========== + + @Test + public void testConvenienceQueryEmbeddingShortcut() { + assumeMinVersion("1.5.0"); + assumeCloud(); + // Simplest possible search — embedding-based convenience shortcut per D-02 + SearchResult result = 
searchCollection.search() + .queryEmbedding(QUERY_HEADPHONES) + .limit(5) + .execute(); + + assertNotNull(result); + assertFalse(result.getIds().get(0).isEmpty()); + } + + @Test + public void testConvenienceQueryTextShortcut() { + assumeMinVersion("1.5.0"); + assumeCloud(); + // Text-based KNN queries (string in $knn.query) are not currently accepted by the + // Chroma server — it returns "data did not match any variant of untagged enum + // QueryVector". Only float[] embedding vectors are supported in $knn.query. + // This test documents the intended D-01 text-query shortcut and will be enabled + // once the server adds text-vector support. + Assume.assumeTrue("Skipping: text-based $knn.query is not yet supported by Chroma server", false); + + SearchResult result = searchCollection.search() + .queryText("wireless headphones") + .limit(5) + .execute(); + + assertNotNull(result); + assertFalse("text search should return results", result.getIds().get(0).isEmpty()); + } +} diff --git a/src/test/java/tech/amikos/chromadb/v2/SearchApiPropertyTest.java b/src/test/java/tech/amikos/chromadb/v2/SearchApiPropertyTest.java new file mode 100644 index 0000000..22c5403 --- /dev/null +++ b/src/test/java/tech/amikos/chromadb/v2/SearchApiPropertyTest.java @@ -0,0 +1,151 @@ +package tech.amikos.chromadb.v2; + +import org.junit.Test; + +import static org.junit.Assert.*; +import static org.quicktheories.QuickTheory.qt; +import static org.quicktheories.generators.SourceDSL.*; + +/** + * Property-based tests for Search API types using QuickTheories. + * Validates bounds, roundtrip invariants, and numerical stability. 
+ */ +public class SearchApiPropertyTest { + + // --- SparseVector properties --- + + @Test + public void sparseVectorRoundtripPreservesData() { + // For any valid length, SparseVector.of(indices, values) preserves them exactly + qt().forAll( + integers().between(1, 100) + ).checkAssert(len -> { + int[] indices = new int[len]; + float[] values = new float[len]; + for (int i = 0; i < len; i++) { + indices[i] = i * 3; + values[i] = (float) (i * 0.1); + } + SparseVector sv = SparseVector.of(indices, values); + assertArrayEquals(indices, sv.getIndices()); + assertArrayEquals(values, sv.getValues(), 0.0f); + }); + } + + @Test + public void sparseVectorImmutabilityProperty() { + // Mutating the array returned by getIndices() never affects the SparseVector + qt().forAll(integers().between(1, 50)).checkAssert(len -> { + int[] indices = new int[len]; + float[] values = new float[len]; + for (int i = 0; i < len; i++) { + indices[i] = i; + values[i] = i; + } + SparseVector sv = SparseVector.of(indices, values); + int[] got = sv.getIndices(); + got[0] = -999; + assertEquals(0, sv.getIndices()[0]); + }); + } + + // --- Knn immutability properties --- + + @Test + public void knnFluentChainProducesNewInstances() { + // Every fluent method on Knn produces a distinct object + qt().forAll(integers().between(1, 100)).checkAssert(limit -> { + Knn base = Knn.queryText("test"); + Knn withLimit = base.limit(limit); + assertNotSame(base, withLimit); + assertNull(base.getLimit()); + assertEquals(Integer.valueOf(limit), withLimit.getLimit()); + }); + } + + @Test + public void knnEmbeddingDefensiveCopyProperty() { + // For any float array, modifying the input or output never changes Knn state + qt().forAll(integers().between(1, 20)).checkAssert(len -> { + float[] original = new float[len]; + for (int i = 0; i < len; i++) original[i] = i * 1.5f; + Knn knn = Knn.queryEmbedding(original); + original[0] = -999f; + float[] q = (float[]) knn.getQuery(); + assertEquals(0f, q[0], 0.001f); + q[0] = 
-888f; + float[] q2 = (float[]) knn.getQuery(); + assertEquals(0f, q2[0], 0.001f); + }); + } + + // --- Score precision property --- + + @Test + public void scoreRoundtripPreservesDoublePrecision() { + // For any Double score, the SearchResultRow preserves it exactly (no Float narrowing) + qt().forAll(doubles().between(-1e15, 1e15)).checkAssert(score -> { + SearchResultRowImpl row = new SearchResultRowImpl( + "id1", "doc", null, null, null, score); + assertEquals(score, row.getScore(), 0.0); + }); + } + + // --- GroupBy validation properties --- + + @Test + public void groupByMinKNeverExceedsMaxK() { + // For any valid offset (0-99), minK = 1+offset, maxK = minK + offset2 ensures minK <= maxK + qt().forAll(integers().between(1, 50), integers().between(0, 50)) + .checkAssert((minK, extra) -> { + int maxK = minK + extra; + GroupBy gb = GroupBy.builder().key("k").minK(minK).maxK(maxK).build(); + assertEquals(Integer.valueOf(minK), gb.getMinK()); + assertEquals(Integer.valueOf(maxK), gb.getMaxK()); + }); + } + + @Test + public void groupByMinKExceedingMaxKAlwaysFails() { + // For any maxK in [1,99] and gap in [1,100], minK = maxK + gap > maxK, so build() always throws + qt().forAll(integers().between(1, 99), integers().between(1, 100)) + .checkAssert((maxK, gap) -> { + int minK = maxK + gap; + try { + GroupBy.builder().key("k").minK(minK).maxK(maxK).build(); + fail("Should throw for minK=" + minK + " > maxK=" + maxK); + } catch (IllegalArgumentException e) { + // expected + } + }); + } + + // --- Select equality property --- + + @Test + public void selectKeyEquality() { + // Select.key(x).equals(Select.key(x)) for any non-blank string + // Prefix with "k" to guarantee non-blank + qt().forAll(strings().basicLatinAlphabet().ofLengthBetween(0, 49)) + .checkAssert(suffix -> { + String key = "k" + suffix; + assertEquals(Select.key(key), Select.key(key)); + assertEquals(Select.key(key).hashCode(), Select.key(key).hashCode()); + }); + } + + // --- Search builder mutual 
exclusivity property --- + + @Test + public void searchBuilderExactlyOneRankAlwaysRequired() { + // build() always fails without knn or rrf + qt().forAll(integers().between(1, 100)).checkAssert(limit -> { + try { + Search.builder().limit(limit).build(); + fail("Should throw without rank"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("neither")); + } + }); + } +} diff --git a/src/test/java/tech/amikos/chromadb/v2/SearchApiUnitTest.java b/src/test/java/tech/amikos/chromadb/v2/SearchApiUnitTest.java new file mode 100644 index 0000000..3855467 --- /dev/null +++ b/src/test/java/tech/amikos/chromadb/v2/SearchApiUnitTest.java @@ -0,0 +1,643 @@ +package tech.amikos.chromadb.v2; + +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.*; + +/** + * Unit tests for Search API DTOs: Knn, Rrf, Search, GroupBy, ReadLevel, and wire-format + * serialization via ChromaDtos helper methods. 
+ */ +@SuppressWarnings("unchecked") +public class SearchApiUnitTest { + + // ========== KNN tests (SEARCH-01) ========== + + @Test + public void testKnnQueryText() { + Knn knn = Knn.queryText("headphones"); + assertEquals("headphones", knn.getQuery()); + assertEquals("#embedding", knn.getKey()); + + Map map = ChromaDtos.buildKnnRankMap(knn); + assertTrue("should have '$knn' key", map.containsKey("$knn")); + Map inner = (Map) map.get("$knn"); + assertEquals("headphones", inner.get("query")); + assertEquals("#embedding", inner.get("key")); + } + + @Test + public void testKnnQueryEmbedding() { + float[] emb = {0.1f, 0.2f}; + Knn knn = Knn.queryEmbedding(emb); + assertTrue("query should be float[]", knn.getQuery() instanceof float[]); + assertEquals("#embedding", knn.getKey()); + + Map map = ChromaDtos.buildKnnRankMap(knn); + Map inner = (Map) map.get("$knn"); + Object query = inner.get("query"); + assertTrue("serialized query should be a List", query instanceof List); + List queryList = (List) query; + assertEquals(2, queryList.size()); + assertEquals(0.1f, queryList.get(0), 1e-6f); + assertEquals(0.2f, queryList.get(1), 1e-6f); + } + + @Test + public void testKnnQuerySparseVector() { + SparseVector sv = SparseVector.of(new int[]{1, 5}, new float[]{0.3f, 0.7f}); + Knn knn = Knn.querySparseVector(sv); + assertTrue("query should be SparseVector", knn.getQuery() instanceof SparseVector); + // key defaults to null for sparse + assertNull("key should be null for sparse vector knn", knn.getKey()); + + Map map = ChromaDtos.buildKnnRankMap(knn); + Map inner = (Map) map.get("$knn"); + Object query = inner.get("query"); + assertTrue("serialized sparse query should be a Map", query instanceof Map); + Map svMap = (Map) query; + List indices = (List) svMap.get("indices"); + List values = (List) svMap.get("values"); + assertNotNull(indices); + assertNotNull(values); + assertEquals(Integer.valueOf(1), indices.get(0)); + assertEquals(Integer.valueOf(5), indices.get(1)); + 
assertEquals(0.3f, values.get(0), 1e-6f); + assertEquals(0.7f, values.get(1), 1e-6f); + } + + @Test + public void testKnnWithLimit() { + Knn knn = Knn.queryText("test").limit(10); + assertEquals(Integer.valueOf(10), knn.getLimit()); + + Map map = ChromaDtos.buildKnnRankMap(knn); + Map inner = (Map) map.get("$knn"); + assertEquals(10, inner.get("limit")); + } + + @Test + public void testKnnWithReturnRank() { + Knn knn = Knn.queryText("test").returnRank(true); + assertTrue(knn.isReturnRank()); + + Map map = ChromaDtos.buildKnnRankMap(knn); + Map inner = (Map) map.get("$knn"); + assertEquals(Boolean.TRUE, inner.get("return_rank")); + } + + @Test + public void testKnnReturnRankFalseByDefault() { + Knn knn = Knn.queryText("test"); + assertFalse("returnRank should default to false", knn.isReturnRank()); + + Map map = ChromaDtos.buildKnnRankMap(knn); + Map inner = (Map) map.get("$knn"); + assertFalse("return_rank should not appear in map when false", inner.containsKey("return_rank")); + } + + @Test + public void testKnnImmutability() { + Knn original = Knn.queryText("test"); + Knn withLimit = original.limit(5); + // original should be unchanged + assertNull("original limit should still be null", original.getLimit()); + assertEquals(Integer.valueOf(5), withLimit.getLimit()); + } + + // ========== RRF tests (SEARCH-02) ========== + + @Test + public void testRrfDtoStructure() { + Knn knn1 = Knn.queryText("wireless audio"); + Knn knn2 = Knn.queryText("noise cancelling headphones"); + Rrf rrf = Rrf.builder() + .rank(knn1, 0.7) + .rank(knn2, 0.3) + .k(60) + .build(); + + Map map = ChromaDtos.buildRrfRankMap(rrf); + assertTrue("should have '$rrf' key", map.containsKey("$rrf")); + Map rrfMap = (Map) map.get("$rrf"); + List> ranks = (List>) rrfMap.get("ranks"); + assertNotNull(ranks); + assertEquals("should have 2 ranks", 2, ranks.size()); + assertEquals(60, rrfMap.get("k")); + + Map rank0 = ranks.get(0); + assertEquals(0.7, (Double) rank0.get("weight"), 1e-9); + assertTrue("rank 
entry should have 'rank' key containing knn map", + ((Map) rank0.get("rank")).containsKey("$knn")); + } + + @Test + public void testRrfAutoSetsReturnRank() { + Knn knn = Knn.queryText("test"); + assertFalse("returnRank should be false before adding to Rrf", knn.isReturnRank()); + + Rrf rrf = Rrf.builder().rank(knn, 1.0).build(); + // The inner Knn stored in Rrf should have returnRank=true + Rrf.RankWithWeight rw = rrf.getRanks().get(0); + assertTrue("Rrf.Builder.rank() should auto-set returnRank=true", rw.getKnn().isReturnRank()); + } + + @Test(expected = IllegalArgumentException.class) + public void testRrfEmptyRanksThrows() { + Rrf.builder().k(60).build(); + } + + @Test + public void testRrfDefaultK() { + Knn knn = Knn.queryText("test"); + Rrf rrf = Rrf.builder().rank(knn, 1.0).build(); + assertEquals("default k should be 60", 60, rrf.getK()); + } + + // ========== Search builder tests ========== + + @Test + public void testSearchWithKnn() { + Knn knn = Knn.queryText("test"); + Search search = Search.builder().knn(knn).build(); + assertNotNull("knn should not be null", search.getKnn()); + assertNull("rrf should be null when knn is set", search.getRrf()); + } + + @Test + public void testSearchWithRrf() { + Knn knn = Knn.queryText("test"); + Rrf rrf = Rrf.builder().rank(knn, 1.0).build(); + Search search = Search.builder().rrf(rrf).build(); + assertNotNull("rrf should not be null", search.getRrf()); + assertNull("knn should be null when rrf is set", search.getKnn()); + } + + @Test(expected = IllegalArgumentException.class) + public void testSearchRequiresRank() { + Search.builder().limit(5).build(); + } + + @Test + public void testSearchWithSelectProjection() { + Knn knn = Knn.queryText("test"); + Search search = Search.builder() + .knn(knn) + .select(Select.ID, Select.SCORE, Select.key("title")) + .build(); + List sel = search.getSelect(); + assertNotNull(sel); + assertEquals("selectAll should include all 5 standard fields", 5, sel.size()); + } + + @Test + public 
void testSearchWithGroupBy() { + Knn knn = Knn.queryText("test"); + GroupBy groupBy = GroupBy.builder().key("category").minK(1).maxK(3).build(); + Search search = Search.builder().knn(knn).groupBy(groupBy).build(); + assertNotNull("groupBy should not be null", search.getGroupBy()); + assertEquals("category", search.getGroupBy().getKey()); + } + + // ========== Wire format via buildSearchItemMap (SEARCH-01, SEARCH-03, SEARCH-04) ========== + + @Test + public void testBuildSearchItemMapKnn() { + Knn knn = Knn.queryText("test"); + Search search = Search.builder().knn(knn).build(); + Map item = ChromaDtos.buildSearchItemMap(search, null); + assertTrue("item should have 'rank' key", item.containsKey("rank")); + Map rank = (Map) item.get("rank"); + assertTrue("rank should contain '$knn'", rank.containsKey("$knn")); + } + + @Test + public void testBuildSearchItemMapWithFilter() { + Knn knn = Knn.queryText("test"); + Search search = Search.builder() + .knn(knn) + .where(Where.eq("color", "red")) + .build(); + Map item = ChromaDtos.buildSearchItemMap(search, null); + assertTrue("should use 'filter' key (not 'where')", item.containsKey("filter")); + assertFalse("should NOT use 'where' key", item.containsKey("where")); + Map filter = (Map) item.get("filter"); + assertNotNull(filter); + assertTrue("filter should have 'color' key", filter.containsKey("color")); + } + + @Test + public void testBuildSearchItemMapMergesGlobalFilter() { + Knn knn = Knn.queryText("test"); + Where perSearch = Where.eq("color", "red"); + Where global = Where.eq("brand", "sony"); + Search search = Search.builder().knn(knn).where(perSearch).build(); + Map item = ChromaDtos.buildSearchItemMap(search, global); + Map filter = (Map) item.get("filter"); + assertNotNull(filter); + assertTrue("merged filter should contain per-search key", filter.containsKey("color")); + assertTrue("merged filter should contain global key", filter.containsKey("brand")); + } + + @Test + public void testBuildSearchItemMapSelect() 
{ + Knn knn = Knn.queryText("test"); + Search search = Search.builder() + .knn(knn) + .select(Select.ID, Select.SCORE) + .build(); + Map item = ChromaDtos.buildSearchItemMap(search, null); + assertTrue("should have 'select' key", item.containsKey("select")); + Map sel = (Map) item.get("select"); + List keys = (List) sel.get("keys"); + assertNotNull(keys); + assertEquals(2, keys.size()); + assertTrue(keys.contains("#id")); + assertTrue(keys.contains("#score")); + } + + @Test + public void testBuildSearchItemMapLimitOffset() { + Knn knn = Knn.queryText("test"); + Search search = Search.builder().knn(knn).limit(5).offset(10).build(); + Map item = ChromaDtos.buildSearchItemMap(search, null); + assertTrue("should have 'limit' key", item.containsKey("limit")); + Map page = (Map) item.get("limit"); + assertEquals(5, page.get("limit")); + assertEquals(10, page.get("offset")); + } + + @Test + public void testBuildSearchItemMapGroupBy() { + Knn knn = Knn.queryText("test"); + GroupBy groupBy = GroupBy.builder().key("category").minK(1).maxK(3).build(); + Search search = Search.builder().knn(knn).groupBy(groupBy).build(); + Map item = ChromaDtos.buildSearchItemMap(search, null); + assertTrue("should have 'group_by' key", item.containsKey("group_by")); + Map gb = (Map) item.get("group_by"); + assertEquals("category", gb.get("key")); + assertEquals(1, gb.get("min_k")); + assertEquals(3, gb.get("max_k")); + } + + // ========== ReadLevel tests (SEARCH-04) ========== + + @Test + public void testReadLevelWireValues() { + assertEquals("index_and_wal", ReadLevel.INDEX_AND_WAL.getValue()); + assertEquals("index_only", ReadLevel.INDEX_ONLY.getValue()); + } + + @Test + public void testReadLevelFromValue() { + assertEquals(ReadLevel.INDEX_AND_WAL, ReadLevel.fromValue("index_and_wal")); + assertEquals(ReadLevel.INDEX_ONLY, ReadLevel.fromValue("index_only")); + } + + @Test(expected = IllegalArgumentException.class) + public void testReadLevelFromValueUnknownThrows() { + 
ReadLevel.fromValue("unknown_level"); + } + + @Test(expected = IllegalArgumentException.class) + public void testReadLevelFromValueNullThrows() { + ReadLevel.fromValue(null); + } + + // ========== GroupBy tests (SEARCH-04) ========== + + @Test + public void testGroupByBuilder() { + GroupBy gb = GroupBy.builder().key("category").minK(1).maxK(3).build(); + assertEquals("category", gb.getKey()); + assertEquals(Integer.valueOf(1), gb.getMinK()); + assertEquals(Integer.valueOf(3), gb.getMaxK()); + } + + @Test + public void testGroupByOptionalFields() { + GroupBy gb = GroupBy.builder().key("tag").build(); + assertEquals("tag", gb.getKey()); + assertNull("minK should be null when not set", gb.getMinK()); + assertNull("maxK should be null when not set", gb.getMaxK()); + } + + @Test(expected = IllegalArgumentException.class) + public void testGroupByNullKeyThrows() { + GroupBy.builder().build(); + } + + // ========== SearchResultImpl.from() parsing tests ========== + + @Test + public void testSearchResultImplFromHappyPath() { + ChromaDtos.SearchResponse dto = new ChromaDtos.SearchResponse(); + dto.ids = Arrays.asList(Arrays.asList("id1", "id2")); + dto.documents = Arrays.asList(Arrays.asList("doc1", "doc2")); + dto.scores = Arrays.asList(Arrays.asList(0.9, 0.8)); + dto.metadatas = null; + dto.embeddings = null; + + SearchResult result = SearchResultImpl.from(dto); + assertEquals(1, result.searchCount()); + assertEquals(Arrays.asList(Arrays.asList("id1", "id2")), result.getIds()); + assertEquals(Arrays.asList(Arrays.asList("doc1", "doc2")), result.getDocuments()); + assertEquals(Arrays.asList(Arrays.asList(0.9, 0.8)), result.getScores()); + } + + @Test(expected = ChromaDeserializationException.class) + public void testSearchResultImplFromNullDto() { + SearchResultImpl.from(null); + } + + @Test(expected = ChromaDeserializationException.class) + public void testSearchResultImplFromNullIds() { + ChromaDtos.SearchResponse dto = new ChromaDtos.SearchResponse(); + dto.ids = null; 
+ SearchResultImpl.from(dto); + } + + @Test + public void testSearchResultImplFromNullOptionalFields() { + ChromaDtos.SearchResponse dto = new ChromaDtos.SearchResponse(); + dto.ids = Arrays.asList(Arrays.asList("id1")); + dto.documents = null; + dto.metadatas = null; + dto.scores = null; + dto.embeddings = null; + + SearchResult result = SearchResultImpl.from(dto); + assertEquals(1, result.searchCount()); + assertNull("documents should be null when not set", result.getDocuments()); + assertNull("metadatas should be null when not set", result.getMetadatas()); + assertNull("scores should be null when not set", result.getScores()); + assertNull("embeddings should be null when not set", result.getEmbeddings()); + + // Row access should return null for missing fields + ResultGroup rows = result.rows(0); + assertEquals(1, rows.size()); + SearchResultRow row = rows.get(0); + assertEquals("id1", row.getId()); + assertNull("document should be null for missing field", row.getDocument()); + assertNull("score should be null for missing field", row.getScore()); + } + + @Test + public void testSearchResultRowsAccessWithScores() { + ChromaDtos.SearchResponse dto = new ChromaDtos.SearchResponse(); + dto.ids = Arrays.asList(Arrays.asList("id1", "id2")); + dto.scores = Arrays.asList(Arrays.asList(0.123456789012345, 0.987654321098765)); + dto.documents = null; + dto.metadatas = null; + dto.embeddings = null; + + SearchResult result = SearchResultImpl.from(dto); + ResultGroup rows = result.rows(0); + assertEquals(2, rows.size()); + // Verify scores are Double precision (not Float narrowed) + assertEquals(0.123456789012345, rows.get(0).getScore(), 0.0); + assertEquals(0.987654321098765, rows.get(1).getScore(), 0.0); + } + + @Test + public void testSearchResultRowsNullSafety() { + // Build a response where scores inner list has a null entry + ChromaDtos.SearchResponse dto = new ChromaDtos.SearchResponse(); + dto.ids = Arrays.asList(Arrays.asList("id1")); + List scoreInner = new 
ArrayList(); + scoreInner.add(null); + dto.scores = Arrays.asList(scoreInner); + dto.documents = null; + dto.metadatas = null; + dto.embeddings = null; + + SearchResult result = SearchResultImpl.from(dto); + ResultGroup rows = result.rows(0); + assertNull("score should be null when inner entry is null", rows.get(0).getScore()); + } + + @Test + public void testSearchResultSearchCount() { + ChromaDtos.SearchResponse dto = new ChromaDtos.SearchResponse(); + dto.ids = Arrays.asList( + Arrays.asList("id1", "id2"), + Arrays.asList("id3") + ); + dto.documents = null; + dto.metadatas = null; + dto.scores = null; + dto.embeddings = null; + + SearchResult result = SearchResultImpl.from(dto); + assertEquals("searchCount should return number of search inputs", 2, result.searchCount()); + } + + @Test + public void testSearchResultStream() { + ChromaDtos.SearchResponse dto = new ChromaDtos.SearchResponse(); + dto.ids = Arrays.asList( + Arrays.asList("id1"), + Arrays.asList("id2") + ); + dto.documents = null; + dto.metadatas = null; + dto.scores = null; + dto.embeddings = null; + + SearchResult result = SearchResultImpl.from(dto); + long count = result.stream().count(); + assertEquals("stream should return 2 groups", 2, count); + } + + @Test(expected = IndexOutOfBoundsException.class) + public void testSearchResultRowsInvalidIndexNegative() { + ChromaDtos.SearchResponse dto = new ChromaDtos.SearchResponse(); + dto.ids = Arrays.asList(Arrays.asList("id1")); + dto.documents = null; + dto.metadatas = null; + dto.scores = null; + dto.embeddings = null; + + SearchResult result = SearchResultImpl.from(dto); + result.rows(-1); + } + + @Test(expected = IndexOutOfBoundsException.class) + public void testSearchResultRowsInvalidIndexTooLarge() { + ChromaDtos.SearchResponse dto = new ChromaDtos.SearchResponse(); + dto.ids = Arrays.asList(Arrays.asList("id1")); + dto.documents = null; + dto.metadatas = null; + dto.scores = null; + dto.embeddings = null; + + SearchResult result = 
SearchResultImpl.from(dto); + result.rows(999); + } + + // ========== Search.builder() both-set validation ========== + + @Test(expected = IllegalArgumentException.class) + public void testSearchBothKnnAndRrfThrows() { + Knn knn = Knn.queryText("test"); + Rrf rrf = Rrf.builder().rank(knn, 1.0).build(); + Search.builder().knn(knn).rrf(rrf).build(); + } + + // ========== Null validation tests ========== + + @Test(expected = NullPointerException.class) + public void testSearchBuilderKnnNull() { + Search.builder().knn(null); + } + + @Test(expected = NullPointerException.class) + public void testSearchBuilderRrfNull() { + Search.builder().rrf(null); + } + + @Test(expected = NullPointerException.class) + public void testSearchBuilderWhereNull() { + Search.builder().knn(Knn.queryText("test")).where(null); + } + + @Test(expected = NullPointerException.class) + public void testSearchBuilderGroupByNull() { + Search.builder().knn(Knn.queryText("test")).groupBy(null); + } + + @Test(expected = NullPointerException.class) + public void testSearchBuilderSelectNull() { + Search.builder().knn(Knn.queryText("test")).select((Select[]) null); + } + + @Test(expected = IllegalArgumentException.class) + public void testSearchBuilderSelectNullElement() { + Search.builder().knn(Knn.queryText("test")).select(Select.ID, null, Select.SCORE); + } + + // ========== Knn null validation tests ========== + + @Test(expected = IllegalArgumentException.class) + public void testKnnQueryTextNull() { + Knn.queryText(null); + } + + @Test(expected = IllegalArgumentException.class) + public void testKnnQueryEmbeddingNull() { + Knn.queryEmbedding(null); + } + + @Test(expected = IllegalArgumentException.class) + public void testKnnQuerySparseVectorNull() { + Knn.querySparseVector(null); + } + + @Test(expected = IllegalArgumentException.class) + public void testKnnKeyNull() { + Knn.queryText("test").key(null); + } + + @Test + public void testKnnGetQueryDefensiveCopy() { + float[] orig = {1.0f, 2.0f}; + Knn 
knn = Knn.queryEmbedding(orig); + float[] returned = (float[]) knn.getQuery(); + returned[0] = 999f; + float[] returnedAgain = (float[]) knn.getQuery(); + assertEquals(1.0f, returnedAgain[0], 0.001f); + } + + // ========== Rrf null validation ========== + + @Test(expected = IllegalArgumentException.class) + public void testRrfRankNullKnn() { + Rrf.builder().rank(null, 1.0); + } + + // ========== GroupBy validation improvements ========== + + @Test(expected = IllegalArgumentException.class) + public void testGroupByBlankKeyThrows() { + GroupBy.builder().key(" ").build(); + } + + @Test(expected = IllegalArgumentException.class) + public void testGroupByMinKLessThanOneThrows() { + GroupBy.builder().key("cat").minK(0).build(); + } + + @Test(expected = IllegalArgumentException.class) + public void testGroupByMaxKLessThanOneThrows() { + GroupBy.builder().key("cat").maxK(0).build(); + } + + @Test(expected = IllegalArgumentException.class) + public void testGroupByMinKExceedsMaxKThrows() { + GroupBy.builder().key("cat").minK(5).maxK(3).build(); + } + + // ========== Wire format: global-only filter path ========== + + @Test + public void testBuildSearchItemMapGlobalFilterOnly() { + Search s = Search.builder().knn(Knn.queryText("test")).build(); + Where globalFilter = Where.eq("color", "blue"); + Map item = ChromaDtos.buildSearchItemMap(s, globalFilter); + assertNotNull("filter should be present from global filter", item.get("filter")); + } + + // ========== Wire format: Rrf normalize serialization ========== + + @Test + public void testRrfNormalizeSerialization() { + Rrf rrf = Rrf.builder() + .rank(Knn.queryText("a"), 1.0) + .normalize(true) + .build(); + Map map = ChromaDtos.buildRrfRankMap(rrf); + Map rrfMap = (Map) map.get("$rrf"); + assertEquals(true, rrfMap.get("normalize")); + } + + // ========== ReadLevel fromValue edge cases ========== + + @Test + public void testReadLevelFromValueCaseInsensitive() { + assertEquals(ReadLevel.INDEX_AND_WAL, 
ReadLevel.fromValue("INDEX_AND_WAL")); + assertEquals(ReadLevel.INDEX_ONLY, ReadLevel.fromValue(" index_only ")); + } + + // ========== Rrf normalize=false absent from wire format ========== + + @SuppressWarnings("unchecked") + @Test + public void testRrfNormalizeFalseNotSerialized() { + Rrf rrf = Rrf.builder() + .rank(Knn.queryText("a"), 1.0) + .build(); // normalize defaults to false + Map map = ChromaDtos.buildRrfRankMap(rrf); + Map rrfMap = (Map) map.get("$rrf"); + assertFalse("normalize should not appear when false", rrfMap.containsKey("normalize")); + } + +} diff --git a/src/test/java/tech/amikos/chromadb/v2/SelectTest.java b/src/test/java/tech/amikos/chromadb/v2/SelectTest.java new file mode 100644 index 0000000..cc90761 --- /dev/null +++ b/src/test/java/tech/amikos/chromadb/v2/SelectTest.java @@ -0,0 +1,72 @@ +package tech.amikos.chromadb.v2; + +import org.junit.Test; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +import static org.junit.Assert.*; + +public class SelectTest { + + @Test + public void testStandardConstants() { + assertEquals("#document", Select.DOCUMENT.getKey()); + assertEquals("#score", Select.SCORE.getKey()); + assertEquals("#embedding", Select.EMBEDDING.getKey()); + assertEquals("#metadata", Select.METADATA.getKey()); + assertEquals("#id", Select.ID.getKey()); + } + + @Test + public void testKeyFactory() { + assertEquals("title", Select.key("title").getKey()); + assertEquals("category", Select.key("category").getKey()); + // No "#" prefix added for custom keys + assertFalse("custom key should not start with #", Select.key("title").getKey().startsWith("#")); + } + + @Test(expected = IllegalArgumentException.class) + public void testKeyNullThrows() { + Select.key(null); + } + + @Test(expected = IllegalArgumentException.class) + public void testKeyBlankThrows() { + Select.key(" "); + } + + @Test + public void testAllReturnsAllFiveConstants() { + Select[] all = Select.all(); + assertEquals("Select.all() 
should return 5 elements", 5, all.length); + Set keys = new HashSet(); + for (Select s : all) { + keys.add(s.getKey()); + } + assertTrue("Should contain #id", keys.contains("#id")); + assertTrue("Should contain #document", keys.contains("#document")); + assertTrue("Should contain #embedding", keys.contains("#embedding")); + assertTrue("Should contain #metadata", keys.contains("#metadata")); + assertTrue("Should contain #score", keys.contains("#score")); + } + + @Test + public void testEqualsOnSameKey() { + Select s1 = Select.key("title"); + Select s2 = Select.key("title"); + assertEquals("Same key should be equal", s1, s2); + assertEquals("Same hashCode for same key", s1.hashCode(), s2.hashCode()); + + // Select.DOCUMENT equals a Select with key "#document" + Select docByKey = Select.key("#document"); + assertEquals("DOCUMENT constant equals key('#document')", Select.DOCUMENT, docByKey); + } + + @Test + public void testNotEqualOnDifferentKey() { + assertNotEquals("DOCUMENT should not equal SCORE", Select.DOCUMENT, Select.SCORE); + assertNotEquals("Different custom keys should not be equal", Select.key("a"), Select.key("b")); + } +} diff --git a/src/test/java/tech/amikos/chromadb/v2/SparseVectorTest.java b/src/test/java/tech/amikos/chromadb/v2/SparseVectorTest.java new file mode 100644 index 0000000..900f3d9 --- /dev/null +++ b/src/test/java/tech/amikos/chromadb/v2/SparseVectorTest.java @@ -0,0 +1,79 @@ +package tech.amikos.chromadb.v2; + +import org.junit.Test; + +import static org.junit.Assert.*; + +public class SparseVectorTest { + + @Test + public void testOfCreatesImmutableVector() { + int[] indices = {1, 5, 10}; + float[] values = {0.3f, 0.7f, 0.2f}; + SparseVector sv = SparseVector.of(indices, values); + assertNotNull(sv); + assertArrayEquals(new int[]{1, 5, 10}, sv.getIndices()); + assertArrayEquals(new float[]{0.3f, 0.7f, 0.2f}, sv.getValues(), 1e-6f); + } + + @Test + public void testDefensiveCopyOnConstruction() { + int[] indices = {1, 5, 10}; + float[] 
values = {0.3f, 0.7f, 0.2f}; + SparseVector sv = SparseVector.of(indices, values); + // Mutate the original arrays + indices[0] = 99; + values[0] = 9.9f; + // SparseVector should not reflect the changes + assertArrayEquals(new int[]{1, 5, 10}, sv.getIndices()); + assertArrayEquals(new float[]{0.3f, 0.7f, 0.2f}, sv.getValues(), 1e-6f); + } + + @Test + public void testDefensiveCopyOnGetters() { + SparseVector sv = SparseVector.of(new int[]{1, 5, 10}, new float[]{0.3f, 0.7f, 0.2f}); + // Mutate the returned arrays + int[] returnedIndices = sv.getIndices(); + float[] returnedValues = sv.getValues(); + returnedIndices[0] = 99; + returnedValues[0] = 9.9f; + // SparseVector should not reflect the changes on subsequent calls + assertArrayEquals(new int[]{1, 5, 10}, sv.getIndices()); + assertArrayEquals(new float[]{0.3f, 0.7f, 0.2f}, sv.getValues(), 1e-6f); + } + + @Test(expected = IllegalArgumentException.class) + public void testNullIndicesThrows() { + SparseVector.of(null, new float[]{0.1f}); + } + + @Test(expected = IllegalArgumentException.class) + public void testNullValuesThrows() { + SparseVector.of(new int[]{1}, null); + } + + @Test(expected = IllegalArgumentException.class) + public void testMismatchedLengthThrows() { + SparseVector.of(new int[]{1, 2}, new float[]{0.1f}); + } + + @Test + public void testEqualsAndHashCode() { + SparseVector sv1 = SparseVector.of(new int[]{1, 5, 10}, new float[]{0.3f, 0.7f, 0.2f}); + SparseVector sv2 = SparseVector.of(new int[]{1, 5, 10}, new float[]{0.3f, 0.7f, 0.2f}); + SparseVector sv3 = SparseVector.of(new int[]{1, 5, 99}, new float[]{0.3f, 0.7f, 0.2f}); + + assertEquals("Same data should be equal", sv1, sv2); + assertEquals("Same data should have same hashCode", sv1.hashCode(), sv2.hashCode()); + assertNotEquals("Different data should not be equal", sv1, sv3); + } + + @Test + public void testToString() { + SparseVector sv = SparseVector.of(new int[]{1, 5}, new float[]{0.3f, 0.7f}); + String str = sv.toString(); + 
assertNotNull(str); + assertTrue("toString should contain indices", str.contains("1") && str.contains("5")); + assertTrue("toString should contain values", str.contains("0.3") || str.contains("0.7")); + } +}