diff --git a/.changeset/yellow-crabs-attend.md b/.changeset/yellow-crabs-attend.md new file mode 100644 index 000000000..c36539820 --- /dev/null +++ b/.changeset/yellow-crabs-attend.md @@ -0,0 +1,5 @@ +--- +"braintrust": minor +--- + +(feat) Add experiment dataset filters to experiment metadata diff --git a/js/src/logger.test.ts b/js/src/logger.test.ts index c4536bb78..21231d8e0 100644 --- a/js/src/logger.test.ts +++ b/js/src/logger.test.ts @@ -484,6 +484,156 @@ function mockInitGitMetadata() { ).mockResolvedValue([]); } +test("init forwards dataset _internal_btql to experiment register", async () => { + const state = await _exportsForTestingOnly.simulateLoginForTests(); + + try { + vi.spyOn(state, "login").mockResolvedValue(state); + mockInitGitMetadata(); + + const datasetFilter = { + filter: [ + { + op: "eq", + left: { op: "ident", name: ["metadata", "model"] }, + right: { op: "literal", value: "gpt-5-mini" }, + }, + { + op: "isnotnull", + expr: { op: "ident", name: ["expected"] }, + }, + ], + }; + + let experimentRegisterBody: unknown; + vi.spyOn(state.appConn(), "post_json") + .mockResolvedValueOnce({ + project: { + id: "00000000-0000-0000-0000-000000000001", + name: "test-project", + }, + dataset: { + id: "00000000-0000-0000-0000-000000000002", + name: "test-dataset", + }, + }) + .mockImplementationOnce(async (_path, body) => { + experimentRegisterBody = body; + return { + project: { + id: "00000000-0000-0000-0000-000000000001", + name: "test-project", + }, + experiment: { + id: "00000000-0000-0000-0000-000000000003", + project_id: "00000000-0000-0000-0000-000000000001", + name: "test-experiment", + public: false, + }, + }; + }); + + const dataset = initDataset({ + project: "test-project", + dataset: "test-dataset", + version: "123", + _internal_btql: datasetFilter, + state, + }); + + const experiment = init({ + project: "test-project", + experiment: "test-experiment", + dataset, + setCurrent: false, + state, + }); + + await experiment.id; + + expect(experimentRegisterBody).toEqual( + expect.objectContaining({ + internal_metadata: { + dataset_filter: datasetFilter, + }, + }), + ); + } finally { + _exportsForTestingOnly.simulateLogoutForTests(); + vi.restoreAllMocks(); + } +}); + +test("dataset fetch forwards _internal_btql filter arrays to btql", async () => { + const state = await _exportsForTestingOnly.simulateLoginForTests(); + + try { + vi.spyOn(state, "login").mockResolvedValue(state); + + const datasetFilter = { + filter: [ + { + op: "eq", + left: { op: "ident", name: ["metadata", "model"] }, + right: { op: "literal", value: "gpt-5-mini" }, + }, + { + op: "isnotnull", + expr: { op: "ident", name: ["expected"] }, + }, + ], + limit: 5, + }; + + vi.spyOn(state.appConn(), "post_json").mockResolvedValue({ + project: { + id: "00000000-0000-0000-0000-000000000001", + name: "test-project", + }, + dataset: { + id: "00000000-0000-0000-0000-000000000002", + name: "test-dataset", + }, + }); + + let btqlBody: unknown; + vi.spyOn(state.apiConn(), "post").mockImplementation( + async (_path, body) => { + btqlBody = body; + return new Response(JSON.stringify({ data: [] }), { + status: 200, + headers: { "Content-Type": "application/json" }, + }); + }, + ); + + const dataset = initDataset({ + project: "test-project", + dataset: "test-dataset", + _internal_btql: datasetFilter, + state, + }); + + const rows: unknown[] = []; + for await (const row of dataset) { + rows.push(row); + } + + expect(rows).toEqual([]); + expect(btqlBody).toEqual( + expect.objectContaining({ + query: expect.objectContaining({ + filter: datasetFilter.filter, + limit: 5, + }), + }), + ); + } finally { + _exportsForTestingOnly.simulateLogoutForTests(); + vi.restoreAllMocks(); + } +}); + test("initDataset prefers version over environment in eval data", async () => { const state = await _exportsForTestingOnly.simulateLoginForTests(); vi.spyOn(state, "login").mockResolvedValue(state); @@ -948,7 +1098,6 @@ test("init keeps plain dataset refs attached to the experiment", async () => { }); await experiment.id; - expect(experiment.dataset).toMatchObject({ id: "00000000-0000-0000-0000-000000000002", }); diff --git a/js/src/logger.ts b/js/src/logger.ts index ef8ed3a65..81da807fe 100644 --- a/js/src/logger.ts +++ b/js/src/logger.ts @@ -3472,6 +3472,7 @@ export type InitOptions = FullLoginOptions & { experiment?: string; description?: string; dataset?: AnyDataset | DatasetRef; + _internal_btql?: Record; parameters?: ParametersRef | RemoteEvalParameters; update?: boolean; baseExperiment?: string; @@ -3490,6 +3491,32 @@ export type FullInitOptions = { project?: string; } & InitOptions; +function getExperimentDatasetFilter({ + dataset, + _internal_btql, +}: { + dataset?: AnyDataset | DatasetRef; + _internal_btql?: Record; +}): Record | undefined { + if (_internal_btql !== undefined) { + return _internal_btql; + } + + if (!(dataset instanceof Dataset)) { + return undefined; + } + + const datasetFilter = Reflect.get(dataset, "_internal_btql"); + return isObject(datasetFilter) ? datasetFilter : undefined; +} + +function getInternalBtqlLimit( + internalBtql?: Record, +): number | undefined { + const limit = internalBtql?.["limit"]; + return typeof limit === "number" ? limit : undefined; +} + type InitializedExperiment = IsOpen extends true ? ReadonlyExperiment : Experiment; @@ -3556,6 +3583,7 @@ export function init( experiment, description, dataset, + _internal_btql, parameters, baseExperiment, isPublic, @@ -3697,6 +3725,16 @@ export function init( } } + const datasetFilter = getExperimentDatasetFilter({ + dataset, + _internal_btql, + }); + if (datasetFilter !== undefined) { + args["internal_metadata"] = { + dataset_filter: datasetFilter, + }; + } + if (parameters !== undefined) { if (RemoteEvalParameters.isParameters(parameters)) { args["parameters_id"] = parameters.id; @@ -6046,9 +6084,7 @@ export class ObjectFetcher implements AsyncIterable< const state = await this.getState(); const objectId = await this.id; const batchLimit = batchSize ?? DEFAULT_FETCH_BATCH_SIZE; - const internalLimit = ( - this._internal_btql as { limit?: number } | undefined - )?.limit; + const internalLimit = getInternalBtqlLimit(this._internal_btql); const limit = batchSize !== undefined ? batchSize : (internalLimit ?? batchLimit); const internalBtqlWithoutReservedQueryKeys = Object.fromEntries(