diff --git a/RELEASE.rst b/RELEASE.rst index 81a57c60a8..61ee3173e1 100644 --- a/RELEASE.rst +++ b/RELEASE.rst @@ -1,6 +1,11 @@ Release Notes ============= +Version 0.63.3 +-------------- + +- Facet counts and aggregations for Vector search (#3188) + Version 0.63.2 (Released April 14, 2026) -------------- diff --git a/frontends/api/src/generated/v0/api.ts b/frontends/api/src/generated/v0/api.ts index 368c911b52..8e0eb0472a 100644 --- a/frontends/api/src/generated/v0/api.ts +++ b/frontends/api/src/generated/v0/api.ts @@ -11532,6 +11532,7 @@ export const VectorContentFilesSearchApiAxiosParamCreator = function ( /** * Vector Search for content * @summary Content File Vector Search + * @param {Array} [aggregations] aggregations for facet counts * `key` - Key * `course_number` - Course Number * `platform` - Platform * `offered_by` - Offered By * `file_extension` - File Extension * `content_feature_type` - Content Feature Type * `run_readable_id` - Run Readable Id * `resource_readable_id` - Resource Readable Id * `run_title` - Run Title * `edx_module_id` - Edx Module Id * `content_type` - Content Type * `description` - Description * `title` - Title * `url` - Url * `file_type` - File Type * `summary` - Summary * `flashcards` - Flashcards * `checksum` - Checksum * @param {string} [collection_name] Manually specify the name of the Qdrant collection to query * @param {Array} [file_extension] The extension of the content file. * @param {string} [group_by] The attribute to group results by @@ -11552,6 +11553,7 @@ export const VectorContentFilesSearchApiAxiosParamCreator = function ( * @throws {RequiredError} */ vectorContentFilesSearchRetrieve: async ( + aggregations?: Array, collection_name?: string, file_extension?: Array, group_by?: string, @@ -11586,6 +11588,10 @@ export const VectorContentFilesSearchApiAxiosParamCreator = function ( const localVarHeaderParameter = {} as any const localVarQueryParameter = {} as any + if (aggregations) { + localVarQueryParameter["aggregations"] = aggregations + } + if (collection_name !== undefined) { localVarQueryParameter["collection_name"] = collection_name } @@ -11680,6 +11686,7 @@ export const VectorContentFilesSearchApiFp = function ( /** * Vector Search for content * @summary Content File Vector Search + * @param {Array} [aggregations] aggregations for facet counts * `key` - Key * `course_number` - Course Number * `platform` - Platform * `offered_by` - Offered By * `file_extension` - File Extension * `content_feature_type` - Content Feature Type * `run_readable_id` - Run Readable Id * `resource_readable_id` - Resource Readable Id * `run_title` - Run Title * `edx_module_id` - Edx Module Id * `content_type` - Content Type * `description` - Description * `title` - Title * `url` - Url * `file_type` - File Type * `summary` - Summary * `flashcards` - Flashcards * `checksum` - Checksum * @param {string} [collection_name] Manually specify the name of the Qdrant collection to query * @param {Array} [file_extension] The extension of the content file. * @param {string} [group_by] The attribute to group results by @@ -11700,6 +11707,7 @@ export const VectorContentFilesSearchApiFp = function ( * @throws {RequiredError} */ async vectorContentFilesSearchRetrieve( + aggregations?: Array, collection_name?: string, file_extension?: Array, group_by?: string, @@ -11725,6 +11733,7 @@ export const VectorContentFilesSearchApiFp = function ( > { const localVarAxiosArgs = await localVarAxiosParamCreator.vectorContentFilesSearchRetrieve( + aggregations, collection_name, file_extension, group_by, @@ -11783,6 +11792,7 @@ export const VectorContentFilesSearchApiFactory = function ( ): AxiosPromise { return localVarFp .vectorContentFilesSearchRetrieve( + requestParameters.aggregations, requestParameters.collection_name, requestParameters.file_extension, requestParameters.group_by, @@ -11812,6 +11822,13 @@ export const VectorContentFilesSearchApiFactory = function ( * @interface VectorContentFilesSearchApiVectorContentFilesSearchRetrieveRequest */ export interface VectorContentFilesSearchApiVectorContentFilesSearchRetrieveRequest { + /** + * aggregations for facet counts * `key` - Key * `course_number` - Course Number * `platform` - Platform * `offered_by` - Offered By * `file_extension` - File Extension * `content_feature_type` - Content Feature Type * `run_readable_id` - Run Readable Id * `resource_readable_id` - Resource Readable Id * `run_title` - Run Title * `edx_module_id` - Edx Module Id * `content_type` - Content Type * `description` - Description * `title` - Title * `url` - Url * `file_type` - File Type * `summary` - Summary * `flashcards` - Flashcards * `checksum` - Checksum + * @type {Array<'key' | 'course_number' | 'platform' | 'offered_by' | 'file_extension' | 'content_feature_type' | 'run_readable_id' | 'resource_readable_id' | 'run_title' | 'edx_module_id' | 'content_type' | 'description' | 'title' | 'url' | 'file_type' | 'summary' | 'flashcards' | 'checksum'>} + * @memberof VectorContentFilesSearchApiVectorContentFilesSearchRetrieve + */ + readonly aggregations?: Array + /** * Manually specify the name of the Qdrant collection to query * @type {string} @@ -11946,6 +11963,7 @@ export class VectorContentFilesSearchApi extends BaseAPI { ) { return VectorContentFilesSearchApiFp(this.configuration) .vectorContentFilesSearchRetrieve( + requestParameters.aggregations, requestParameters.collection_name, requestParameters.file_extension, requestParameters.group_by, @@ -11968,6 +11986,31 @@ export class VectorContentFilesSearchApi extends BaseAPI { } } +/** + * @export + */ +export const VectorContentFilesSearchRetrieveAggregationsEnum = { + Key: "key", + CourseNumber: "course_number", + Platform: "platform", + OfferedBy: "offered_by", + FileExtension: "file_extension", + ContentFeatureType: "content_feature_type", + RunReadableId: "run_readable_id", + ResourceReadableId: "resource_readable_id", + RunTitle: "run_title", + EdxModuleId: "edx_module_id", + ContentType: "content_type", + Description: "description", + Title: "title", + Url: "url", + FileType: "file_type", + Summary: "summary", + Flashcards: "flashcards", + Checksum: "checksum", +} as const +export type VectorContentFilesSearchRetrieveAggregationsEnum = + (typeof VectorContentFilesSearchRetrieveAggregationsEnum)[keyof typeof VectorContentFilesSearchRetrieveAggregationsEnum] /** * @export */ @@ -11991,6 +12034,7 @@ export const VectorLearningResourcesSearchApiAxiosParamCreator = function ( /** * Vector Search for learning resources * @summary Vector Search + * @param {Array} [aggregations] aggregations for facet counts * `readable_id` - Readable Id * `resource_type` - Resource Type * `certification` - Certification * `certification_type` - Certification Type * `professional` - Professional * `free` - Free * `course_feature` - Course Feature * `topic` - Topic * `ocw_topic` - Ocw Topic * `level` - Level * `department` - Department * `platform` - Platform * `offered_by` - Offered By * `delivery` - Delivery * `title` - Title * `url` - Url * `resource_type_group` - Resource Type Group * `resource_category` - Resource Category * `published` - Published * @param {boolean | null} [certification] True if the learning resource offers a certificate * @param {Array} [certification_type] The type of certificate * `micromasters` - MicroMasters Credential * `professional` - Professional Certificate * `completion` - Certificate of Completion * `none` - No Certificate * @param {Array} [course_feature] The course feature. Possible options are at api/v1/course_features/ @@ -12005,6 +12049,7 @@ export const VectorLearningResourcesSearchApiAxiosParamCreator = function ( * @param {number} [offset] The initial index from which to return the results * @param {Array} [platform] The platform on which the learning resource is offered * `edx` - edX * `ocw` - MIT OpenCourseWare * `oll` - Open Learning Library * `mitxonline` - MITx Online * `bootcamps` - Bootcamps * `xpro` - MIT xPRO * `csail` - CSAIL * `mitpe` - MIT Professional Education * `see` - MIT Sloan Executive Education * `scc` - Schwarzman College of Computing * `ctl` - Center for Transportation & Logistics * `whu` - WHU * `susskind` - Susskind * `globalalumni` - Global Alumni * `simplilearn` - Simplilearn * `emeritus` - Emeritus * `podcast` - Podcast * `youtube` - YouTube * `canvas` - Canvas * `climate` - MIT Climate * `ovs` - ODL Video Service * @param {boolean | null} [professional] + * @param {boolean} [published] If the resource is published. We default to True unless passed in * @param {string} [q] The search text * @param {string} [readable_id] The readable id of the resource * @param {Array} [resource_type] The type of learning resource * `course` - course * `program` - program * `learning_path` - learning path * `podcast` - podcast * `podcast_episode` - podcast episode * `video` - video * `video_playlist` - video playlist * `document` - document @@ -12016,6 +12061,7 @@ export const VectorLearningResourcesSearchApiAxiosParamCreator = function ( * @throws {RequiredError} */ vectorLearningResourcesSearchRetrieve: async ( + aggregations?: Array, certification?: boolean | null, certification_type?: Array, course_feature?: Array, @@ -12030,6 +12076,7 @@ export const VectorLearningResourcesSearchApiAxiosParamCreator = function ( offset?: number, platform?: Array, professional?: boolean | null, + published?: boolean, q?: string, readable_id?: string, resource_type?: Array, @@ -12055,6 +12102,10 @@ export const VectorLearningResourcesSearchApiAxiosParamCreator = function ( const localVarHeaderParameter = {} as any const localVarQueryParameter = {} as any + if (aggregations) { + localVarQueryParameter["aggregations"] = aggregations + } + if (certification !== undefined) { localVarQueryParameter["certification"] = certification } @@ -12111,6 +12162,10 @@ export const VectorLearningResourcesSearchApiAxiosParamCreator = function ( localVarQueryParameter["professional"] = professional } + if (published !== undefined) { + localVarQueryParameter["published"] = published + } + if (q !== undefined) { localVarQueryParameter["q"] = q } @@ -12169,6 +12224,7 @@ export const VectorLearningResourcesSearchApiFp = function ( /** * Vector Search for learning resources * @summary Vector Search + * @param {Array} [aggregations] aggregations for facet counts * `readable_id` - Readable Id * `resource_type` - Resource Type * `certification` - Certification * `certification_type` - Certification Type * `professional` - Professional * `free` - Free * `course_feature` - Course Feature * `topic` - Topic * `ocw_topic` - Ocw Topic * `level` - Level * `department` - Department * `platform` - Platform * `offered_by` - Offered By * `delivery` - Delivery * `title` - Title * `url` - Url * `resource_type_group` - Resource Type Group * `resource_category` - Resource Category * `published` - Published * @param {boolean | null} [certification] True if the learning resource offers a certificate * @param {Array} [certification_type] The type of certificate * `micromasters` - MicroMasters Credential * `professional` - Professional Certificate * `completion` - Certificate of Completion * `none` - No Certificate * @param {Array} [course_feature] The course feature. Possible options are at api/v1/course_features/ @@ -12183,6 +12239,7 @@ export const VectorLearningResourcesSearchApiFp = function ( * @param {number} [offset] The initial index from which to return the results * @param {Array} [platform] The platform on which the learning resource is offered * `edx` - edX * `ocw` - MIT OpenCourseWare * `oll` - Open Learning Library * `mitxonline` - MITx Online * `bootcamps` - Bootcamps * `xpro` - MIT xPRO * `csail` - CSAIL * `mitpe` - MIT Professional Education * `see` - MIT Sloan Executive Education * `scc` - Schwarzman College of Computing * `ctl` - Center for Transportation & Logistics * `whu` - WHU * `susskind` - Susskind * `globalalumni` - Global Alumni * `simplilearn` - Simplilearn * `emeritus` - Emeritus * `podcast` - Podcast * `youtube` - YouTube * `canvas` - Canvas * `climate` - MIT Climate * `ovs` - ODL Video Service * @param {boolean | null} [professional] + * @param {boolean} [published] If the resource is published. We default to True unless passed in * @param {string} [q] The search text * @param {string} [readable_id] The readable id of the resource * @param {Array} [resource_type] The type of learning resource * `course` - course * `program` - program * `learning_path` - learning path * `podcast` - podcast * `podcast_episode` - podcast episode * `video` - video * `video_playlist` - video playlist * `document` - document @@ -12194,6 +12251,7 @@ export const VectorLearningResourcesSearchApiFp = function ( * @throws {RequiredError} */ async vectorLearningResourcesSearchRetrieve( + aggregations?: Array, certification?: boolean | null, certification_type?: Array, course_feature?: Array, @@ -12208,6 +12266,7 @@ export const VectorLearningResourcesSearchApiFp = function ( offset?: number, platform?: Array, professional?: boolean | null, + published?: boolean, q?: string, readable_id?: string, resource_type?: Array, @@ -12224,6 +12283,7 @@ export const VectorLearningResourcesSearchApiFp = function ( > { const localVarAxiosArgs = await localVarAxiosParamCreator.vectorLearningResourcesSearchRetrieve( + aggregations, certification, certification_type, course_feature, @@ -12238,6 +12298,7 @@ export const VectorLearningResourcesSearchApiFp = function ( offset, platform, professional, + published, q, readable_id, resource_type, @@ -12287,6 +12348,7 @@ export const VectorLearningResourcesSearchApiFactory = function ( ): AxiosPromise { return localVarFp .vectorLearningResourcesSearchRetrieve( + requestParameters.aggregations, requestParameters.certification, requestParameters.certification_type, requestParameters.course_feature, @@ -12301,6 +12363,7 @@ export const VectorLearningResourcesSearchApiFactory = function ( requestParameters.offset, requestParameters.platform, requestParameters.professional, + requestParameters.published, requestParameters.q, requestParameters.readable_id, requestParameters.resource_type, @@ -12321,6 +12384,13 @@ export const VectorLearningResourcesSearchApiFactory = function ( * @interface VectorLearningResourcesSearchApiVectorLearningResourcesSearchRetrieveRequest */ export interface VectorLearningResourcesSearchApiVectorLearningResourcesSearchRetrieveRequest { + /** + * aggregations for facet counts * `readable_id` - Readable Id * `resource_type` - Resource Type * `certification` - Certification * `certification_type` - Certification Type * `professional` - Professional * `free` - Free * `course_feature` - Course Feature * `topic` - Topic * `ocw_topic` - Ocw Topic * `level` - Level * `department` - Department * `platform` - Platform * `offered_by` - Offered By * `delivery` - Delivery * `title` - Title * `url` - Url * `resource_type_group` - Resource Type Group * `resource_category` - Resource Category * `published` - Published + * @type {Array<'readable_id' | 'resource_type' | 'certification' | 'certification_type' | 'professional' | 'free' | 'course_feature' | 'topic' | 'ocw_topic' | 'level' | 'department' | 'platform' | 'offered_by' | 'delivery' | 'title' | 'url' | 'resource_type_group' | 'resource_category' | 'published'>} + * @memberof VectorLearningResourcesSearchApiVectorLearningResourcesSearchRetrieve + */ + readonly aggregations?: Array + /** * True if the learning resource offers a certificate * @type {boolean} @@ -12419,6 +12489,13 @@ export interface VectorLearningResourcesSearchApiVectorLearningResourcesSearchRe */ readonly professional?: boolean | null + /** + * If the resource is published. We default to True unless passed in + * @type {boolean} + * @memberof VectorLearningResourcesSearchApiVectorLearningResourcesSearchRetrieve + */ + readonly published?: boolean + /** * The search text * @type {string} @@ -12490,6 +12567,7 @@ export class VectorLearningResourcesSearchApi extends BaseAPI { ) { return VectorLearningResourcesSearchApiFp(this.configuration) .vectorLearningResourcesSearchRetrieve( + requestParameters.aggregations, requestParameters.certification, requestParameters.certification_type, requestParameters.course_feature, @@ -12504,6 +12582,7 @@ export class VectorLearningResourcesSearchApi extends BaseAPI { requestParameters.offset, requestParameters.platform, requestParameters.professional, + requestParameters.published, requestParameters.q, requestParameters.readable_id, requestParameters.resource_type, @@ -12517,6 +12596,32 @@ export class VectorLearningResourcesSearchApi extends BaseAPI { } } +/** + * @export + */ +export const VectorLearningResourcesSearchRetrieveAggregationsEnum = { + ReadableId: "readable_id", + ResourceType: "resource_type", + Certification: "certification", + CertificationType: "certification_type", + Professional: "professional", + Free: "free", + CourseFeature: "course_feature", + Topic: "topic", + OcwTopic: "ocw_topic", + Level: "level", + Department: "department", + Platform: "platform", + OfferedBy: "offered_by", + Delivery: "delivery", + Title: "title", + Url: "url", + ResourceTypeGroup: "resource_type_group", + ResourceCategory: "resource_category", + Published: "published", +} as const +export type VectorLearningResourcesSearchRetrieveAggregationsEnum = + (typeof VectorLearningResourcesSearchRetrieveAggregationsEnum)[keyof typeof VectorLearningResourcesSearchRetrieveAggregationsEnum] /** * @export */ diff --git a/frontends/main/src/app-pages/SearchPage/SearchPage.test.tsx b/frontends/main/src/app-pages/SearchPage/SearchPage.test.tsx index fe14c1c100..2d9cb44807 100644 --- a/frontends/main/src/app-pages/SearchPage/SearchPage.test.tsx +++ b/frontends/main/src/app-pages/SearchPage/SearchPage.test.tsx @@ -152,50 +152,6 @@ describe("SearchPage", () => { }, ) - test("Vector Hybrid Search passes correct params and hides count", async () => { - setMockApiResponses({ - search: { - count: 700, - metadata: { - aggregations: { - resource_type_group: [{ key: "course", doc_count: 100 }], - }, - suggestions: [], - }, - results: factories.learningResources.resources({ count: 5 }).results, - }, - }) - - // Authenticate as path editor (admin) - setMockResponse.get(urls.userMe.get(), { - is_learning_path_editor: true, - is_authenticated: true, - }) - - renderWithProviders(, { url: "?vector_search=true&q=test" }) - - await waitFor(() => { - const call = makeRequest.mock.calls.find(([_method, url]) => { - return url.includes(urls.search.vectorResources()) - }) - expect(call).toBeDefined() - }) - - const call = makeRequest.mock.calls.find(([_method, url]) => - url.includes(urls.search.vectorResources()), - ) - invariant(call) - const fullUrl = new URL(call[1], "http://mit.edu") - const apiSearchParams = fullUrl.searchParams - - expect(apiSearchParams.get("hybrid_search")).toBe("true") - expect(apiSearchParams.get("q")).toBe("test") - - // Ensure count is hidden - const hideCountText = screen.queryByText("700 results") - expect(hideCountText).toBeNull() - }) - test("Toggling facets", async () => { setMockApiResponses({ search: { diff --git a/frontends/main/src/page-components/SearchDisplay/SearchDisplay.tsx b/frontends/main/src/page-components/SearchDisplay/SearchDisplay.tsx index 2cf28c753e..e0b53a8ee8 100644 --- a/frontends/main/src/page-components/SearchDisplay/SearchDisplay.tsx +++ b/frontends/main/src/page-components/SearchDisplay/SearchDisplay.tsx @@ -516,8 +516,8 @@ const searchModeDropdownOptions = Object.entries( /** * Extracts only the fields supported by the vector search API from a broader - * search params object, dropping admin-only params (e.g., aggregations, - * content_file_score_weight) that the vector endpoint does not accept. + * search params object, dropping admin-only params (e.g., content_file_score_weight) + * that the vector endpoint does not accept. * * The `as` casts for enum arrays are safe because the v0 and v1 generated * clients define separate (but structurally identical) enum types for the same @@ -526,6 +526,7 @@ const searchModeDropdownOptions = Object.entries( const toVectorSearchParams = ( params: ReturnType, ): VectorSearchRequest => ({ + aggregations: params.aggregations as VectorSearchRequest["aggregations"], certification: params.certification, certification_type: params.certification_type as VectorSearchRequest["certification_type"], @@ -625,10 +626,13 @@ const SearchDisplay: React.FC = ({ const wantsVectorSearch = searchParams.get("vector_search") === "true" const isVectorSearch = wantsVectorSearch && user?.is_learning_path_editor + const queryOptions = isVectorSearch + ? learningResourceQueries.vectorSearch(toVectorSearchParams(allParams)) + : learningResourceQueries.search(allParams as LRSearchRequest) + + // @ts-expect-error Typescript has trouble unifying the different query key types const { data, isLoading, isFetching } = useQuery({ - ...(isVectorSearch - ? learningResourceQueries.vectorSearch(toVectorSearchParams(allParams)) - : learningResourceQueries.search(allParams as LRSearchRequest)), + ...queryOptions, enabled: !wantsVectorSearch || !isUserLoading, placeholderData: keepPreviousData, select: (timedData: { @@ -985,9 +989,7 @@ const SearchDisplay: React.FC = ({ * the count when data is loaded even if count is same as previous * count. */} - {isFetching || isLoading || isVectorSearch - ? "" - : `${data?.count} results`} + {isFetching || isLoading ? "" : `${data?.count} results`} diff --git a/main/settings.py b/main/settings.py index 2f148300d7..f658f33bac 100644 --- a/main/settings.py +++ b/main/settings.py @@ -34,7 +34,7 @@ from main.settings_pluggy import * # noqa: F403 from openapi.settings_spectacular import open_spectacular_settings -VERSION = "0.63.2" +VERSION = "0.63.3" log = logging.getLogger() @@ -822,10 +822,10 @@ def get_all_config_keys(): QDRANT_CLIENT_TIMEOUT = get_int(name="QDRANT_CLIENT_TIMEOUT", default=10) VECTOR_HYBRID_SEARCH_PREFETCH_MULTIPLIER = get_int( - name="VECTOR_HYBRID_SEARCH_PREFETCH_MULTIPLIER", default=20 + name="VECTOR_HYBRID_SEARCH_PREFETCH_MULTIPLIER", default=5 ) VECTOR_HYBRID_SEARCH_PREFETCH_MAX_LIMIT = get_int( - name="VECTOR_HYBRID_SEARCH_PREFETCH_MAX_LIMIT", default=10000 + name="VECTOR_HYBRID_SEARCH_PREFETCH_MAX_LIMIT", default=500 ) # toggle to use requests (default for local) or webdriver which renders js elements EMBEDDINGS_EXTERNAL_FETCH_USE_WEBDRIVER = get_bool( diff --git a/openapi/specs/v0.yaml b/openapi/specs/v0.yaml index 4bb97c19c2..85bfabe127 100644 --- a/openapi/specs/v0.yaml +++ b/openapi/specs/v0.yaml @@ -827,6 +827,58 @@ paths: description: Vector Search for content summary: Content File Vector Search parameters: + - in: query + name: aggregations + schema: + type: array + items: + enum: + - key + - course_number + - platform + - offered_by + - file_extension + - content_feature_type + - run_readable_id + - resource_readable_id + - run_title + - edx_module_id + - content_type + - description + - title + - url + - file_type + - summary + - flashcards + - checksum + type: string + description: |- + * `key` - Key + * `course_number` - Course Number + * `platform` - Platform + * `offered_by` - Offered By + * `file_extension` - File Extension + * `content_feature_type` - Content Feature Type + * `run_readable_id` - Run Readable Id + * `resource_readable_id` - Resource Readable Id + * `run_title` - Run Title + * `edx_module_id` - Edx Module Id + * `content_type` - Content Type + * `description` - Description + * `title` - Title + * `url` - Url + * `file_type` - File Type + * `summary` - Summary + * `flashcards` - Flashcards + * `checksum` - Checksum + description: "aggregations for facet counts \n\n* `key` - Key\n\ + * `course_number` - Course Number\n* `platform` - Platform\n* `offered_by`\ + \ - Offered By\n* `file_extension` - File Extension\n* `content_feature_type`\ + \ - Content Feature Type\n* `run_readable_id` - Run Readable Id\n* `resource_readable_id`\ + \ - Resource Readable Id\n* `run_title` - Run Title\n* `edx_module_id` -\ + \ Edx Module Id\n* `content_type` - Content Type\n* `description` - Description\n\ + * `title` - Title\n* `url` - Url\n* `file_type` - File Type\n* `summary`\ + \ - Summary\n* `flashcards` - Flashcards\n* `checksum` - Checksum" - in: query name: collection_name schema: @@ -961,6 +1013,61 @@ paths: description: Vector Search for learning resources summary: Vector Search parameters: + - in: query + name: aggregations + schema: + type: array + items: + enum: + - readable_id + - resource_type + - certification + - certification_type + - professional + - free + - course_feature + - topic + - ocw_topic + - level + - department + - platform + - offered_by + - delivery + - title + - url + - resource_type_group + - resource_category + - published + type: string + description: |- + * `readable_id` - Readable Id + * `resource_type` - Resource Type + * `certification` - Certification + * `certification_type` - Certification Type + * `professional` - Professional + * `free` - Free + * `course_feature` - Course Feature + * `topic` - Topic + * `ocw_topic` - Ocw Topic + * `level` - Level + * `department` - Department + * `platform` - Platform + * `offered_by` - Offered By + * `delivery` - Delivery + * `title` - Title + * `url` - Url + * `resource_type_group` - Resource Type Group + * `resource_category` - Resource Category + * `published` - Published + description: "aggregations for facet counts \n\n* `readable_id`\ + \ - Readable Id\n* `resource_type` - Resource Type\n* `certification` -\ + \ Certification\n* `certification_type` - Certification Type\n* `professional`\ + \ - Professional\n* `free` - Free\n* `course_feature` - Course Feature\n\ + * `topic` - Topic\n* `ocw_topic` - Ocw Topic\n* `level` - Level\n* `department`\ + \ - Department\n* `platform` - Platform\n* `offered_by` - Offered By\n*\ + \ `delivery` - Delivery\n* `title` - Title\n* `url` - Url\n* `resource_type_group`\ + \ - Resource Type Group\n* `resource_category` - Resource Category\n* `published`\ + \ - Published" - in: query name: certification schema: @@ -1255,6 +1362,13 @@ paths: schema: type: boolean nullable: true + - in: query + name: published + schema: + type: boolean + default: true + description: If the resource is published. We default to True unless passed + in - in: query name: q schema: diff --git a/vector_search/constants.py b/vector_search/constants.py index 0adc4f31a4..cadd81622a 100644 --- a/vector_search/constants.py +++ b/vector_search/constants.py @@ -45,6 +45,8 @@ "title": "title", "url": "url", "resource_type_group": "resource_type_group", + "resource_category": "resource_category", + "published": "published", } @@ -71,6 +73,7 @@ "url": models.PayloadSchemaType.KEYWORD, "title": models.PayloadSchemaType.KEYWORD, "resource_type_group": models.PayloadSchemaType.KEYWORD, + "resource_category": models.PayloadSchemaType.KEYWORD, } """ @@ -92,3 +95,14 @@ QDRANT_TOPIC_INDEXES = { "name": models.PayloadSchemaType.KEYWORD, } + + +CONTENT_FILES_RETRIEVE_PAYLOAD = ["key", "run_readable_id"] +RESOURCES_RETRIEVE_PAYLOAD = ["readable_id"] + + +COLLECTION_PARAM_MAP = { + RESOURCES_COLLECTION_NAME: QDRANT_RESOURCE_PARAM_MAP, + TOPICS_COLLECTION_NAME: QDRANT_TOPICS_PARAM_MAP, + CONTENT_FILES_COLLECTION_NAME: QDRANT_CONTENT_FILE_PARAM_MAP, +} diff --git a/vector_search/serializers.py b/vector_search/serializers.py index c4c2e7a56a..0ef7dfcc42 100644 --- a/vector_search/serializers.py +++ b/vector_search/serializers.py @@ -20,6 +20,10 @@ SearchResponseMetadata, SearchResponseSerializer, ) +from vector_search.constants import ( + QDRANT_CONTENT_FILE_PARAM_MAP, + QDRANT_RESOURCE_PARAM_MAP, +) class LearningResourcesVectorSearchRequestSerializer(serializers.Serializer): @@ -35,6 +39,22 @@ class LearningResourcesVectorSearchRequestSerializer(serializers.Serializer): limit = serializers.IntegerField( required=False, help_text="Number of results to return per page" ) + aggregation_choices = [ + (key, key.replace("_", " ").title()) for key in QDRANT_RESOURCE_PARAM_MAP + ] + aggregations = serializers.ListField( + required=False, + child=serializers.ChoiceField(choices=aggregation_choices), + help_text=( + f"aggregations for facet counts \ + \n\n{build_choice_description_list(aggregation_choices)}" + ), + ) + published = serializers.BooleanField( + required=False, + default=True, + help_text="If the resource is published. We default to True unless passed in", + ) readable_id = serializers.CharField( required=False, help_text="The readable id of the resource" ) @@ -177,11 +197,11 @@ def get_results(self, instance): return instance.get("hits", {}) def get_count(self, instance) -> int: - return instance.get("total", {}).get("value") + return instance.get("total", {}).get("value", 0) - def get_metadata(self, _) -> SearchResponseMetadata: + def get_metadata(self, instance) -> SearchResponseMetadata: return { - "aggregations": [], + "aggregations": instance.get("aggregations", {}), "suggest": [], } @@ -198,6 +218,17 @@ class ContentFileVectorSearchRequestSerializer(serializers.Serializer): limit = serializers.IntegerField( required=False, help_text="Number of results to return per page" ) + aggregation_choices = [ + (key, key.replace("_", " ").title()) for key in QDRANT_CONTENT_FILE_PARAM_MAP + ] + aggregations = serializers.ListField( + required=False, + child=serializers.ChoiceField(choices=aggregation_choices), + help_text=( + f"aggregations for facet counts \ + \n\n{build_choice_description_list(aggregation_choices)}" + ), + ) sortby = serializers.ChoiceField( required=False, choices=CONTENT_FILE_SORTBY_OPTIONS, @@ -275,14 +306,14 @@ class ContentFileVectorSearchResponseSerializer(SearchResponseSerializer): """ def get_count(self, instance) -> int: - return instance["total"]["value"] + return instance.get("total", {}).get("value", 0) @extend_schema_field(ContentFileSerializer(many=True)) def get_results(self, instance): - return instance["hits"] + return instance.get("hits", {}) - def get_metadata(self, *_) -> SearchResponseMetadata: + def get_metadata(self, instance) -> SearchResponseMetadata: return { - "aggregations": [], + "aggregations": instance.get("aggregations", {}), "suggest": [], } diff --git a/vector_search/utils.py b/vector_search/utils.py index d5a50e05f0..0027445da4 100644 --- a/vector_search/utils.py +++ b/vector_search/utils.py @@ -1,3 +1,4 @@ +import asyncio import gc import logging import uuid @@ -32,13 +33,13 @@ ) from main.utils import checksum_for_content from vector_search.constants import ( + COLLECTION_PARAM_MAP, CONTENT_FILES_COLLECTION_NAME, QDRANT_CONTENT_FILE_INDEXES, QDRANT_CONTENT_FILE_PARAM_MAP, QDRANT_LEARNING_RESOURCE_INDEXES, QDRANT_RESOURCE_PARAM_MAP, QDRANT_TOPIC_INDEXES, - QDRANT_TOPICS_PARAM_MAP, RESOURCES_COLLECTION_NAME, TOPICS_COLLECTION_NAME, ) @@ -871,7 +872,7 @@ def process_batch(docs_batch, summaries_list): def _resource_vector_hits(search_result): - hits = [hit.payload["readable_id"] for hit in search_result] + hits = [hit.payload.get("readable_id") for hit in search_result] """ Always lookup learning resources by readable_id for portability in case we load points from external systems @@ -981,17 +982,74 @@ def document_exists(document, collection_name=RESOURCES_COLLECTION_NAME): return count_result.count > 0 +async def async_qdrant_aggregations( + aggregation_keys: list, + params: dict, + collection_name: str = RESOURCES_COLLECTION_NAME, +) -> dict: + """ + Compute facet aggregations from Qdrant for each requested field. + Issues one concurrent facet query per aggregation key and returns results + in the same shape used by the OpenSearch aggregation API: + ``{"delivery": [{"key": "online", "doc_count": 24}, ...], ...}`` + Args: + aggregation_keys: list of aggregation parameter names. + Must be valid keys in the collection's param map + (e.g. ``QDRANT_RESOURCE_PARAM_MAP``). + params: dict of all search parameters, which are used to construct + a Qdrant ``models.Filter`` for each facet query. + collection_name: name of the Qdrant collection to query. + Returns: + dict mapping each requested aggregation name to a list of + ``{"key": str, "doc_count": int}`` dicts sorted by + ``doc_count`` descending. + """ + if not aggregation_keys: + return {} + + param_map = COLLECTION_PARAM_MAP.get(collection_name, QDRANT_RESOURCE_PARAM_MAP) + client = async_qdrant_client() + + async def _get_facet(agg_key: str): + qdrant_field = param_map.get(agg_key) + if not qdrant_field: + return agg_key, [] + + filtered_params = { + k: v for k, v in params.items() if k.partition("__")[0] != agg_key + } + facet_filter = qdrant_query_conditions( + filtered_params, collection_name=collection_name + ) + + result = await client.facet( + collection_name=collection_name, + key=qdrant_field, + facet_filter=facet_filter, + limit=100, + ) + hits = [ + { + "key": str(hit.value).lower() + if isinstance(hit.value, bool) + else str(hit.value), + "doc_count": hit.count, + } + for hit in result.hits + ] + hits.sort(key=lambda x: x["doc_count"], reverse=True) + return agg_key, hits + + results = await asyncio.gather(*[_get_facet(key) for key in aggregation_keys]) + return dict(results) + + def qdrant_query_conditions(params, collection_name=RESOURCES_COLLECTION_NAME): """ Return a list of Qdrant FieldCondition objects based on params """ - collection_param_map = { - RESOURCES_COLLECTION_NAME: QDRANT_RESOURCE_PARAM_MAP, - TOPICS_COLLECTION_NAME: QDRANT_TOPICS_PARAM_MAP, - CONTENT_FILES_COLLECTION_NAME: QDRANT_CONTENT_FILE_PARAM_MAP, - } - qdrant_param_map = collection_param_map.get(collection_name) + qdrant_param_map = COLLECTION_PARAM_MAP.get(collection_name) if not params or not qdrant_param_map: return None must = [] diff --git a/vector_search/utils_test.py b/vector_search/utils_test.py index 8917834c5b..c9d2458a44 100644 --- a/vector_search/utils_test.py +++ b/vector_search/utils_test.py @@ -1,3 +1,4 @@ +import asyncio import random from decimal import Decimal from unittest.mock import MagicMock @@ -44,6 +45,7 @@ _get_text_splitter, _is_markdown_content, _resource_vector_hits, + async_qdrant_aggregations, create_qdrant_collections, embed_learning_resources, embed_topics, @@ -1519,3 +1521,247 @@ def test_resource_vector_hits_preserves_qdrant_score_order(): expected_readable_ids = [r.readable_id for r in shuffled] actual_readable_ids = [r["readable_id"] for r in result] assert actual_readable_ids == expected_readable_ids + + +def _make_facet_hit(count=0, value="test"): + """Build a minimal mock that looks like a Qdrant FacetHit.""" + hit = MagicMock() + hit.value = value + hit.count = count + return hit + + +def _make_facet_response(hits): + """Build a minimal mock that looks like a Qdrant FacetResponse.""" + resp = MagicMock() + resp.hits = hits + return resp + + +def test_async_qdrant_aggregations_empty_keys(mocker): + """Should return {} immediately without calling Qdrant when aggregation_keys is empty.""" + mock_client = mocker.AsyncMock() + mocker.patch( + "vector_search.utils.async_qdrant_client", + return_value=mock_client, + ) + result = asyncio.run(async_qdrant_aggregations([], {})) + assert result == {} + mock_client.facet.assert_not_called() + + +def test_async_qdrant_aggregations_unknown_key(mocker): + """An aggregation key not present in the param map should return an empty list.""" + mock_client = mocker.AsyncMock() + mocker.patch( + "vector_search.utils.async_qdrant_client", + return_value=mock_client, + ) + result = asyncio.run( + async_qdrant_aggregations( + ["nonexistent_field"], + {}, + collection_name=RESOURCES_COLLECTION_NAME, + ) + ) + assert result == {"nonexistent_field": []} + mock_client.facet.assert_not_called() + + +def test_async_qdrant_aggregations_single_key(mocker): + """A valid single aggregation key should query Qdrant and return correctly shaped data.""" + mock_client = mocker.AsyncMock() + mocker.patch( + "vector_search.utils.async_qdrant_client", + return_value=mock_client, + ) + + mock_client.facet.return_value = _make_facet_response( + [ + _make_facet_hit(42, value="course"), + _make_facet_hit(7, value="podcast"), + ] + ) + + result = asyncio.run( + async_qdrant_aggregations( + ["resource_type"], + {}, + collection_name=RESOURCES_COLLECTION_NAME, + ) + ) + + assert "resource_type" in result + hits = result["resource_type"] + # Should be sorted descending by doc_count + assert hits[0] == {"key": "course", "doc_count": 42} + assert hits[1] == {"key": "podcast", "doc_count": 7} + + mock_client.facet.assert_awaited_once() + call_kwargs = mock_client.facet.call_args.kwargs + assert call_kwargs["collection_name"] == RESOURCES_COLLECTION_NAME + assert call_kwargs["key"] == QDRANT_RESOURCE_PARAM_MAP["resource_type"] + assert call_kwargs["limit"] == 100 + + +def test_async_qdrant_aggregations_multiple_keys(mocker): + """Multiple valid keys should each issue a concurrent Qdrant facet call.""" + mock_client = mocker.AsyncMock() + mocker.patch( + "vector_search.utils.async_qdrant_client", + return_value=mock_client, + ) + + # Return different data depending on the 'key' kwarg + def _facet_side_effect(**kwargs): + if kwargs["key"] == QDRANT_RESOURCE_PARAM_MAP["resource_type"]: + return _make_facet_response([_make_facet_hit(10, value="course")]) + if kwargs["key"] == QDRANT_RESOURCE_PARAM_MAP["platform"]: + return _make_facet_response( + [_make_facet_hit(30, value="ocw"), _make_facet_hit(20, value="edx")] + ) + return _make_facet_response([]) + + mock_client.facet.side_effect = _facet_side_effect + + result = asyncio.run( + async_qdrant_aggregations( + ["resource_type", "platform"], + {}, + collection_name=RESOURCES_COLLECTION_NAME, + ) + ) + + assert set(result.keys()) == {"resource_type", "platform"} + assert result["resource_type"] == [{"key": "course", "doc_count": 10}] + # Descending sort + assert result["platform"][0] == {"key": "ocw", "doc_count": 30} + assert result["platform"][1] == {"key": "edx", "doc_count": 20} + assert mock_client.facet.await_count == 2 + + +def test_async_qdrant_aggregations_excludes_own_param_from_filter(mocker): + """ + When building the per-facet filter, the aggregation key's own param + must be excluded so that all values for that facet are counted. + """ + mock_client = mocker.AsyncMock() + mocker.patch( + "vector_search.utils.async_qdrant_client", + return_value=mock_client, + ) + mock_client.facet.return_value = _make_facet_response([]) + + params = { + "resource_type": ["course"], + "platform": ["ocw"], + } + + asyncio.run( + async_qdrant_aggregations( + ["resource_type"], + params, + collection_name=RESOURCES_COLLECTION_NAME, + ) + ) + + mock_client.facet.assert_awaited_once() + call_kwargs = mock_client.facet.call_args.kwargs + + # The facet_filter should NOT contain a condition for resource_type + # (it was stripped out so we get all resource_type facet values), + # but it SHOULD still filter by platform. + facet_filter = call_kwargs.get("facet_filter") + # facet_filter is a qdrant models.Filter with must conditions + assert facet_filter is not None + condition_keys = [c.key for c in facet_filter.must if hasattr(c, "key")] + assert QDRANT_RESOURCE_PARAM_MAP["platform"] in condition_keys + assert QDRANT_RESOURCE_PARAM_MAP["resource_type"] not in condition_keys + + +def test_async_qdrant_aggregations_bool_values_lowercased(mocker): + """Boolean hit values must be returned as lowercase strings ('true'/'false').""" + mock_client = mocker.AsyncMock() + mocker.patch( + "vector_search.utils.async_qdrant_client", + return_value=mock_client, + ) + + mock_client.facet.return_value = _make_facet_response( + [ + _make_facet_hit(5, value=True), + _make_facet_hit(3, value=False), + ] + ) + + result = asyncio.run( + async_qdrant_aggregations( + ["free"], + {}, + collection_name=RESOURCES_COLLECTION_NAME, + ) + ) + + keys = {hit["key"] for hit in result["free"]} + assert "true" in keys + assert "false" in keys + # Verify no raw booleans slipped through + assert True not in keys + assert False not in keys + + +def test_async_qdrant_aggregations_sorted_by_doc_count_desc(mocker): + """Results must be sorted by doc_count in descending order.""" + mock_client = mocker.AsyncMock() + mocker.patch( + "vector_search.utils.async_qdrant_client", + return_value=mock_client, + ) + + mock_client.facet.return_value = _make_facet_response( + [ + _make_facet_hit(5, value="edx"), + _make_facet_hit(100, value="ocw"), + _make_facet_hit(20, value="xpro"), + ] + ) + + result = asyncio.run( + async_qdrant_aggregations( + ["platform"], + {}, + collection_name=RESOURCES_COLLECTION_NAME, + ) + ) + + counts = [hit["doc_count"] for hit in result["platform"]] + assert counts == sorted(counts, reverse=True) + + +def test_async_qdrant_aggregations_uses_content_file_param_map(mocker): + """ + When collection_name is CONTENT_FILES_COLLECTION_NAME the function must + use QDRANT_CONTENT_FILE_PARAM_MAP to resolve the Qdrant field name. + """ + mock_client = mocker.AsyncMock() + mocker.patch( + "vector_search.utils.async_qdrant_client", + return_value=mock_client, + ) + mock_client.facet.return_value = _make_facet_response( + [_make_facet_hit(8, value=".pdf")] + ) + + result = asyncio.run( + async_qdrant_aggregations( + ["file_extension"], + {}, + collection_name=CONTENT_FILES_COLLECTION_NAME, + ) + ) + + assert "file_extension" in result + call_kwargs = mock_client.facet.call_args.kwargs + assert call_kwargs["collection_name"] == CONTENT_FILES_COLLECTION_NAME + # The Qdrant field for 'file_extension' should come from the content-file map + assert call_kwargs["key"] == QDRANT_CONTENT_FILE_PARAM_MAP["file_extension"] diff --git a/vector_search/views.py b/vector_search/views.py index 4466edd129..5335351c46 100644 --- a/vector_search/views.py +++ b/vector_search/views.py @@ -17,6 +17,12 @@ from authentication.decorators import blocked_ip_exempt from learning_resources.constants import GROUP_CONTENT_FILE_CONTENT_VIEWERS from main.utils import cache_page_for_anonymous_users +from vector_search.constants import ( + CONTENT_FILES_COLLECTION_NAME, + CONTENT_FILES_RETRIEVE_PAYLOAD, + RESOURCES_COLLECTION_NAME, + RESOURCES_RETRIEVE_PAYLOAD, +) from vector_search.serializers import ( ContentFileVectorSearchRequestSerializer, ContentFileVectorSearchResponseSerializer, @@ -24,11 +30,10 @@ LearningResourcesVectorSearchResponseSerializer, ) from vector_search.utils import ( - CONTENT_FILES_COLLECTION_NAME, - RESOURCES_COLLECTION_NAME, _content_file_vector_hits, _merge_dicts, _resource_vector_hits, + async_qdrant_aggregations, async_qdrant_client, dense_encoder, qdrant_query_conditions, @@ -82,12 +87,12 @@ async def dispatch(self, request, *args, **kwargs): self.response = self.finalize_response(request, response, *args, **kwargs) return self.response - async def async_vector_search( # noqa: PLR0913 + async def async_vector_search( # noqa: PLR0913, PLR0915 self, query_string: str, params: dict, limit: int = 10, - offset: int = 10, + offset: int = 0, search_collection=RESOURCES_COLLECTION_NAME, *, hybrid_search: bool = False, @@ -113,8 +118,19 @@ async def async_vector_search( # noqa: PLR0913 "collection_name": search_collection, "query_filter": search_filter, "with_vectors": False, - "with_payload": True, - "search_params": models.SearchParams(indexed_only=True, exact=False), + "with_payload": RESOURCES_RETRIEVE_PAYLOAD + if search_collection == RESOURCES_COLLECTION_NAME + else CONTENT_FILES_RETRIEVE_PAYLOAD, + "search_params": models.SearchParams( + quantization=models.QuantizationSearchParams( + ignore=False, + rescore=True, + oversampling=1, + ), + hnsw_ef=64, + indexed_only=True, + exact=False, + ), "limit": limit, } @@ -151,6 +167,7 @@ async def async_vector_search( # noqa: PLR0913 search_params.pop("search_params", None) search_params["group_by"] = params.get("group_by") search_params["group_size"] = params.get("group_size", 1) + search_params["with_payload"] = True group_result = await client.query_points_groups(**search_params) search_result = [] for group in group_result.groups: @@ -171,29 +188,56 @@ async def async_vector_search( # noqa: PLR0913 result_obj = await client.query_points(**search_params) search_result = result_obj.points else: - scroll_res = await client.scroll( - collection_name=search_collection, - scroll_filter=search_filter, - limit=limit, - offset=offset, - with_vectors=False, - ) - search_result = scroll_res[0] - - if search_collection == RESOURCES_COLLECTION_NAME: - hits = await sync_to_async(_resource_vector_hits)(search_result) - else: - hits = await sync_to_async(_content_file_vector_hits)(search_result) + # Qdrant's scroll API uses a point-ID cursor for `offset`, not a + # numeric skip count. We implement integer offset by consuming + # scroll pages until the desired number of records are skipped. + remaining_to_skip = offset + next_page_offset = None + search_result = [] + while True: + fetch_size = min(max(remaining_to_skip, limit), 1000) + scroll_res = await client.scroll( + collection_name=search_collection, + scroll_filter=search_filter, + limit=fetch_size, + offset=next_page_offset, + with_vectors=False, + ) + page_points, next_page_offset = scroll_res + if remaining_to_skip > 0: + skipped = min(remaining_to_skip, len(page_points)) + page_points = page_points[skipped:] + remaining_to_skip -= skipped + search_result.extend(page_points) + if len(search_result) >= limit or not next_page_offset: + break + search_result = search_result[:limit] + + hits_coroutine = ( + sync_to_async(_resource_vector_hits)(search_result) + if search_collection == RESOURCES_COLLECTION_NAME + else sync_to_async(_content_file_vector_hits)(search_result) + ) - count_result = await client.count( - collection_name=search_collection, - count_filter=search_filter, - exact=False, + aggregation_keys = params.get("aggregations") or [] + hits, count_result, aggregations = await asyncio.gather( + hits_coroutine, + client.count( + collection_name=search_collection, + count_filter=search_filter, + exact=False, + ), + async_qdrant_aggregations( + aggregation_keys, + params, + collection_name=search_collection, + ), ) return { "hits": hits, "total": {"value": count_result.count}, + "aggregations": aggregations or {}, } def handle_exception(self, exc): diff --git a/vector_search/views_test.py b/vector_search/views_test.py index 774ca25436..981fc4371b 100644 --- a/vector_search/views_test.py +++ b/vector_search/views_test.py @@ -14,7 +14,7 @@ def test_vector_search_filters(mocker, client): mock_qdrant = mocker.patch( "qdrant_client.AsyncQdrantClient", return_value=mocker.AsyncMock() )() - mock_qdrant.scroll = mocker.AsyncMock(return_value=[[]]) + mock_qdrant.scroll = mocker.AsyncMock(return_value=([], None)) mock_qdrant.query_points = mocker.AsyncMock() mock_qdrant.query_points_groups = mocker.AsyncMock() mocker.patch( @@ -63,7 +63,7 @@ def test_vector_search_filters_empty_query(mocker, client): mock_qdrant = mocker.patch( "qdrant_client.AsyncQdrantClient", return_value=mocker.AsyncMock() )() - mock_qdrant.scroll = mocker.AsyncMock(return_value=[[]]) + mock_qdrant.scroll = mocker.AsyncMock(return_value=([], None)) mock_qdrant.query_points = mocker.AsyncMock() mock_qdrant.query_points_groups = mocker.AsyncMock() mock_qdrant.count = mocker.AsyncMock(return_value=CountResult(count=10)) @@ -124,7 +124,7 @@ def test_content_file_vector_search_filters( mock_qdrant = mocker.patch( "qdrant_client.AsyncQdrantClient", return_value=mocker.AsyncMock() )() - mock_qdrant.scroll = mocker.AsyncMock(return_value=[[]]) + mock_qdrant.scroll = mocker.AsyncMock(return_value=([], None)) mock_qdrant.query_points = mocker.AsyncMock() mock_qdrant.query_points_groups = mocker.AsyncMock() mocker.patch( @@ -201,7 +201,7 @@ def test_content_file_vector_search_filters_empty_query( mock_qdrant = mocker.patch( "qdrant_client.AsyncQdrantClient", return_value=mocker.AsyncMock() )() - mock_qdrant.scroll = mocker.AsyncMock(return_value=[[]]) + mock_qdrant.scroll = mocker.AsyncMock(return_value=([], None)) mock_qdrant.query_points = mocker.AsyncMock() mock_qdrant.query_points_groups = mocker.AsyncMock() mocker.patch( @@ -255,7 +255,7 @@ def test_content_file_vector_search_filters_custom_collection( "qdrant_client.AsyncQdrantClient", return_value=mocker.AsyncMock() )() custom_collection_name = "foo_bar_collection" - mock_qdrant.scroll = mocker.AsyncMock(return_value=[[]]) + mock_qdrant.scroll = mocker.AsyncMock(return_value=([], None)) mock_qdrant.query_points = mocker.AsyncMock() mock_qdrant.query_points_groups = mocker.AsyncMock() mocker.patch( @@ -300,7 +300,7 @@ def test_content_file_vector_search_group_parameters(mocker, client, django_user )() custom_collection_name = "foo_bar_collection" - mock_qdrant.scroll = mocker.AsyncMock(return_value=[[]]) + mock_qdrant.scroll = mocker.AsyncMock(return_value=([], None)) mock_qdrant.query_points = mocker.AsyncMock() mock_qdrant.query_points_groups = mocker.AsyncMock() mocker.patch(