diff --git a/api/src/api/request/pipeline.js b/api/src/api/request/pipeline.js
index 290eb3b2..66fa8a0e 100644
--- a/api/src/api/request/pipeline.js
+++ b/api/src/api/request/pipeline.js
@@ -49,7 +49,59 @@ module.exports = class RequestPipeline {
     return this;
   }
 
+  /**
+   * Fills in the default model on the neural clauses of the search query.
+   *
+   * Walks `this.searchContext.query` and, for each `neural` clause found,
+   * sets `model_id` from the OPENSEARCH_MODEL_ID environment variable unless
+   * the clause already has a truthy `model_id`. The query is modified in
+   * place. Does nothing when OPENSEARCH_MODEL_ID is unset or empty.
+   *
+   * @returns {RequestPipeline} this pipeline, for chaining
+   */
+  addNeuralModelId() {
+    const neuralModelId = process.env.OPENSEARCH_MODEL_ID;
+    if (!neuralModelId) return this;
+
+    const isObject = (value) => typeof value === "object" && value !== null;
+
+    const recursivelyAddNeuralModelId = (query) => {
+      if (Array.isArray(query)) {
+        for (const subQuery of query) {
+          recursivelyAddNeuralModelId(subQuery);
+        }
+        // Arrays are objects too: without this return the key loop below
+        // would walk every element a second time.
+        return;
+      }
+
+      if (!isObject(query)) return;
+
+      for (const [key, value] of Object.entries(query)) {
+        if (key !== "neural") {
+          recursivelyAddNeuralModelId(value);
+          continue;
+        }
+
+        // `neural` may also be an ordinary field name (e.g. in a term
+        // query) or an empty clause; only touch it when it holds a query
+        // object, rather than throwing a TypeError.
+        if (!isObject(value)) continue;
+        const [field] = Object.keys(value);
+        if (isObject(value[field])) {
+          value[field].model_id ||= neuralModelId;
+        }
+      }
+    };
+
+    recursivelyAddNeuralModelId(this.searchContext.query);
+
+    return this;
+  }
+
   toJson() {
+    // Resolve the default neural model just before serializing.
+    this.addNeuralModelId();
     return JSON.stringify(sortJson(this.searchContext));
   }
 };
diff --git a/api/test/unit/api/request/pipeline.test.js b/api/test/unit/api/request/pipeline.test.js
index 470678d1..bd15c917 100644
--- a/api/test/unit/api/request/pipeline.test.js
+++ b/api/test/unit/api/request/pipeline.test.js
@@ -151,4 +151,108 @@ describe("RequestPipeline", () => {
       }
     });
   });
+
+  describe("addNeuralModelId", () => {
+    let oldModelId;
+    beforeEach(() => {
+      oldModelId = process.env.OPENSEARCH_MODEL_ID;
+      process.env.OPENSEARCH_MODEL_ID = "MODEL_ID";
+      requestBody.query = {
+        neural: {
+          embedding: {
+            query_text:
+              "Do you have any materials related to testing the request pipeline?",
+            k: 5,
+          },
+        },
+      };
+      pipeline = new RequestPipeline(requestBody);
+    });
+
+    afterEach(() => {
+      if (oldModelId) {
+        process.env.OPENSEARCH_MODEL_ID = oldModelId;
+      } else {
+        delete process.env.OPENSEARCH_MODEL_ID;
+      }
+      oldModelId = null;
+    });
+
+    it("does not modify the query if OPENSEARCH_MODEL_ID is not set", () => {
+      delete process.env.OPENSEARCH_MODEL_ID;
+      pipeline.addNeuralModelId();
+      expect(pipeline.searchContext.query).to.deep.equal(requestBody.query);
+    });
+
+    it("does not modify the query if there are no neural queries", () => {
+      requestBody.query = {
+        term: {
+          all_titles: "request pipeline testing",
+        },
+      };
+      pipeline = new RequestPipeline(requestBody);
+      pipeline.addNeuralModelId();
+      expect(pipeline.searchContext.query).to.deep.equal(requestBody.query);
+    });
+
+    it("does not modify the query if there is already a model_id", () => {
+      requestBody.query.neural.embedding.model_id = "EXISTING_MODEL_ID";
+      pipeline = new RequestPipeline(requestBody);
+      pipeline.addNeuralModelId();
+      expect(pipeline.searchContext.query.neural.embedding.model_id).to.eq(
+        "EXISTING_MODEL_ID"
+      );
+    });
+
+    it("automatically adds the model_id to a neural query", () => {
+      pipeline.addNeuralModelId();
+      expect(pipeline.searchContext.query.neural.embedding.model_id).to.eq(
+        "MODEL_ID"
+      );
+    });
+
+    it("recursively adds the model_id to all neural queries in a hybrid query", () => {
+      event.userToken = new ApiToken();
+      requestBody.query = {
+        hybrid: {
+          queries: [
+            {
+              neural: {
+                embedding: {
+                  query_text:
+                    "Do you have any materials related to testing the request pipeline?",
+                  k: 5,
+                },
+              },
+            },
+            {
+              term: {
+                all_titles: "request pipeline testing",
+              },
+            },
+          ],
+        },
+      };
+      pipeline = new RequestPipeline(requestBody);
+      pipeline.addNeuralModelId();
+      expect(
+        pipeline.searchContext.query.hybrid.queries[0].neural.embedding.model_id
+      ).to.eq("MODEL_ID");
+    });
+
+    it("leaves neural keys that are not neural clauses untouched", () => {
+      requestBody.query = {
+        bool: {
+          should: [{ neural: {} }, { term: { neural: "networks" } }],
+        },
+      };
+      pipeline = new RequestPipeline(requestBody);
+      expect(() => pipeline.addNeuralModelId()).not.to.throw();
+      expect(pipeline.searchContext.query).to.deep.equal({
+        bool: {
+          should: [{ neural: {} }, { term: { neural: "networks" } }],
+        },
+      });
+    });
+  });
 });