From 1ecc608f334a18b90955e3fab70d75662cfb3274 Mon Sep 17 00:00:00 2001
From: Ankur Goyal
Date: Thu, 30 Apr 2026 16:58:47 -0400
Subject: [PATCH 1/7] clean up code

---
 README.md                          |  23 +
 scripts/dataset-pipeline-runner.ts | 483 +++++++++++++++
 src/auth.rs                        |  12 +
 src/datasets/mod.rs                |  24 +
 src/datasets/pipeline.rs           | 904 +++++++++++++++++++++++++++++
 src/eval.rs                        | 166 +-----
 src/js_runner.rs                   |  39 +-
 src/main.rs                        |   7 +
 src/sync.rs                        | 164 +++++-
 src/sync/discovery.rs              | 223 +++++++
 10 files changed, 1875 insertions(+), 170 deletions(-)
 create mode 100644 scripts/dataset-pipeline-runner.ts
 create mode 100644 src/datasets/mod.rs
 create mode 100644 src/datasets/pipeline.rs
 create mode 100644 src/sync/discovery.rs

diff --git a/README.md b/README.md
index 78221b0d..9ed65295 100644
--- a/README.md
+++ b/README.md
@@ -114,6 +114,7 @@ Remove-Item -Recurse -Force (Join-Path $env:APPDATA "bt") -ErrorAction SilentlyC
 | `bt auth` | Authenticate with Braintrust |
 | `bt switch` | Switch org and project context |
 | `bt status` | Show current org and project context |
+| `bt datasets` | Manage datasets and dataset pipelines |
 | `bt eval` | Run eval files (Unix only) |
 | `bt sql` | Run SQL queries against Braintrust |
 | `bt view` | View logs, traces, and spans |
@@ -157,6 +158,28 @@ bt eval foo.eval.ts -- --description "Prod" --shard=1/4
 - `bt eval --sample 20 --sample-seed 7 qa.eval.ts` — run a deterministic random sample and clearly label the summary as a non-final smoke run.
 - If you do not pass a sampling flag, `bt eval` runs the full dataset and marks the summary as final.
 
+## `bt datasets pipeline`
+
+Run TypeScript dataset pipelines declared with `DatasetPipeline(...)` from the `braintrust` SDK.
+
+```bash
+# One-shot execution: discover refs, transform, and insert up to 100 new rows.
+bt datasets pipeline run ./pipeline.ts --target 100
+
+# Staged execution for human or agent review.
+bt datasets pipeline fetch ./pipeline.ts --target 500 --out refs.jsonl
+bt datasets pipeline transform ./pipeline.ts --in refs.jsonl --out proposed.jsonl
+bt datasets pipeline review ./pipeline.ts --in proposed.jsonl --out approved.jsonl
+bt datasets pipeline commit ./pipeline.ts --in approved.jsonl
+```
+
+Useful flags:
+
+- `--root-span-id <ID>` restricts fetching to one or more specific root spans.
+- `--extra-where-sql <SQL>` appends a source SQL predicate.
+- `--max-concurrency <N>` controls transform concurrency.
+- `--name <NAME>` selects a pipeline when the file defines more than one.
+
 ## `bt sql`
 
 - Runs interactively on TTY by default.
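For reference, a minimal `pipeline.ts` consumed by the `bt datasets pipeline` commands above might look like the sketch below. It follows the `DatasetPipelineDefinition` shape that `scripts/dataset-pipeline-runner.ts` expects (`source`, `target`, and a `transform` function); the project, dataset, and filter values are placeholders, and the exact `braintrust` SDK surface may differ.

```ts
import { DatasetPipeline } from "braintrust";

export default DatasetPipeline({
  name: "escalated-traces", // only needed when one file defines several pipelines
  source: {
    projectName: "my-logs-project", // placeholder source project
    filter: "metadata.escalated = true", // placeholder BTQL predicate applied during fetch
    scope: "trace", // "span" (default) or "trace"
    limit: 500, // optional cap on discovered refs
  },
  target: {
    projectName: "my-logs-project", // placeholder target project
    datasetName: "escalation-dataset", // created on commit if it does not exist
  },
  // Called once per discovered candidate during the transform stage.
  // `candidate.trace` is a handle to the matched root span's trace; returning
  // null/undefined skips the candidate, and returning an array emits several rows.
  async transform(candidate) {
    return {
      input: { note: "derive the dataset input from candidate.trace here" },
      metadata: { pipeline: "escalated-traces" },
    };
  },
});
```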
diff --git a/scripts/dataset-pipeline-runner.ts b/scripts/dataset-pipeline-runner.ts new file mode 100644 index 00000000..8ff7c801 --- /dev/null +++ b/scripts/dataset-pipeline-runner.ts @@ -0,0 +1,483 @@ +import { createRequire } from "node:module"; +import fs from "node:fs"; +import { pathToFileURL } from "node:url"; +import path from "node:path"; + +type PipelineSource = { + projectName?: string; + projectId?: string; + orgName?: string; + filter?: string; + scope?: "span" | "trace"; + limit?: number; +}; + +type PipelineTarget = { + projectName?: string; + projectId?: string; + orgName?: string; + datasetName?: string; + description?: string; + metadata?: Record; +}; + +type DatasetPipelineDefinition = { + name?: string; + source?: PipelineSource; + target?: PipelineTarget; + transform?: ( + candidate: HydratedCandidate, + context: { pipeline: DatasetPipelineDefinition }, + ) => unknown | Promise; +}; + +type BraintrustModule = { + DatasetPipeline?: ( + definition: DatasetPipelineDefinition, + ) => DatasetPipelineDefinition; + getRegisteredDatasetPipelines?: () => DatasetPipelineDefinition[]; + isDatasetPipelineDefinition?: ( + value: unknown, + ) => value is DatasetPipelineDefinition; + LocalTrace?: new (options: { + objectType: "project_logs"; + objectId: string; + rootSpanId: string; + state: unknown; + }) => unknown; + _internalGetGlobalState?: () => BraintrustState; + loginToState?: (options: { orgName: string }) => Promise; + default?: BraintrustModule; +}; + +type BraintrustState = { + loggedIn?: boolean; + orgName?: string; + login: (options: Record) => Promise; +}; + +type DiscoveryRef = { + root_span_id?: unknown; + id?: unknown; +}; + +type HydratedCandidate = { + trace: unknown; + id?: string; + origin?: { + object_type: "project_logs"; + object_id: string; + id: string; + }; +}; + +type Stage = "inspect" | "transform"; + +function isObject(value: unknown): value is Record { + return typeof value === "object" && value !== null; +} + +function normalizeBraintrustModule(value: unknown): BraintrustModule { + if (isObject(value) && "default" in value && isObject(value.default)) { + return value.default as BraintrustModule; + } + if (isObject(value)) { + return value as BraintrustModule; + } + throw new Error("Unable to load braintrust module."); +} + +function resolveBraintrustPath(pipelineFile: string): string { + const file = path.resolve(process.cwd(), pipelineFile); + try { + const require = createRequire(pathToFileURL(file).href); + return require.resolve("braintrust"); + } catch {} + + try { + const require = createRequire(process.cwd() + "/"); + return require.resolve("braintrust"); + } catch { + throw new Error( + "Unable to resolve the `braintrust` package. 
Please install it in your project.", + ); + } +} + +async function loadBraintrust(pipelineFile: string): Promise { + const cjsPath = resolveBraintrustPath(pipelineFile); + const cjsUrl = pathToFileURL(cjsPath).href; + + try { + return normalizeBraintrustModule(await import(cjsUrl)); + } catch {} + + const esmPath = cjsPath.replace(/\.js$/, ".mjs"); + if (esmPath !== cjsPath && fs.existsSync(esmPath)) { + try { + return normalizeBraintrustModule( + await import(pathToFileURL(esmPath).href), + ); + } catch {} + } + + const require = createRequire(cjsUrl); + return normalizeBraintrustModule(require(cjsPath)); +} + +async function loadPipelineFile(file: string): Promise { + const absolute = path.resolve(process.cwd(), file); + const fileUrl = pathToFileURL(absolute).href; + try { + return await import(fileUrl); + } catch (importErr) { + try { + const require = createRequire(fileUrl); + return require(absolute); + } catch (requireErr) { + throw new Error( + `Failed to load ${file}: import failed with ${formatError(importErr)}; require failed with ${formatError(requireErr)}`, + ); + } + } +} + +function formatError(err: unknown): string { + return err instanceof Error ? err.message : String(err); +} + +function collectPipelines( + braintrust: BraintrustModule, + loadedModule: unknown, +): DatasetPipelineDefinition[] { + const pipelines = new Set(); + const isPipeline = (value: unknown): value is DatasetPipelineDefinition => + (braintrust.isDatasetPipelineDefinition?.(value) ?? false) || + (isObject(value) && + isObject(value.source) && + isObject(value.target) && + typeof value.transform === "function"); + + for (const pipeline of braintrust.getRegisteredDatasetPipelines?.() ?? []) { + pipelines.add(pipeline); + } + + if (isObject(loadedModule)) { + for (const value of Object.values(loadedModule)) { + if (isPipeline(value)) { + pipelines.add(value); + } + } + } + + if (isPipeline(loadedModule)) { + pipelines.add(loadedModule); + } + + return [...pipelines]; +} + +function selectPipeline( + pipelines: DatasetPipelineDefinition[], + name: string | undefined, +): DatasetPipelineDefinition { + if (name) { + const matches = pipelines.filter((pipeline) => pipeline.name === name); + if (matches.length === 0) { + throw new Error( + `No dataset pipeline named ${JSON.stringify(name)} found.`, + ); + } + if (matches.length > 1) { + throw new Error( + `Multiple dataset pipelines named ${JSON.stringify(name)} found.`, + ); + } + return matches[0]; + } + + if (pipelines.length === 0) { + throw new Error( + "No dataset pipelines found. Did you call DatasetPipeline()?", + ); + } + if (pipelines.length > 1) { + const names = pipelines + .map((pipeline) => pipeline.name ?? "") + .join(", "); + throw new Error( + `Multiple dataset pipelines found (${names}). Pass --name.`, + ); + } + return pipelines[0]; +} + +function parseStage(): Stage { + const value = process.env.BT_DATASET_PIPELINE_STAGE; + if (value === "inspect" || value === "transform") { + return value; + } + throw new Error("BT_DATASET_PIPELINE_STAGE must be inspect or transform."); +} + +async function readRequest(): Promise { + const chunks: Buffer[] = []; + for await (const chunk of process.stdin) { + chunks.push(Buffer.isBuffer(chunk) ? chunk : Buffer.from(String(chunk))); + } + const text = Buffer.concat(chunks).toString("utf8").trim(); + return text.length > 0 ? 
JSON.parse(text) : {}; +} + +function writeResponse(value: unknown): void { + process.stdout.write(`${JSON.stringify(value)}\n`); +} + +function requireArrayField(request: unknown, field: string): unknown[] { + if (!isObject(request) || !Array.isArray(request[field])) { + throw new Error(`Request field ${field} must be an array.`); + } + return request[field] as unknown[]; +} + +function requireStringField(request: unknown, field: string): string { + if (!isObject(request) || typeof request[field] !== "string") { + throw new Error(`Request field ${field} must be a string.`); + } + return request[field] as string; +} + +function optionalPositiveIntegerField( + request: unknown, + field: string, +): number | undefined { + if (!isObject(request) || request[field] === undefined) { + return undefined; + } + const value = request[field]; + if (!Number.isInteger(value) || (value as number) <= 0) { + throw new Error(`Request field ${field} must be a positive integer.`); + } + return value as number; +} + +function requirePipelineSource( + pipeline: DatasetPipelineDefinition, +): PipelineSource { + if (!isObject(pipeline.source)) { + throw new Error("Dataset pipeline source is required."); + } + return pipeline.source; +} + +function requireBraintrustRuntime(braintrust: BraintrustModule) { + if ( + !braintrust.LocalTrace || + !braintrust._internalGetGlobalState || + !braintrust.loginToState + ) { + throw new Error( + "The installed braintrust package does not include dataset pipeline runtime support.", + ); + } +} + +async function stateForOrg( + braintrust: BraintrustModule, + orgName: string | undefined, +): Promise { + if (!braintrust._internalGetGlobalState || !braintrust.loginToState) { + throw new Error("The installed braintrust package cannot authenticate."); + } + const state = braintrust._internalGetGlobalState(); + if (!orgName) { + await state.login({}); + return state; + } + if (!state.loggedIn) { + await state.login({ orgName }); + return state; + } + if (state.orgName === orgName) { + return state; + } + return braintrust.loginToState({ orgName }); +} + +function refRootSpanId(ref: unknown): string { + if (!isObject(ref) || typeof ref.root_span_id !== "string") { + throw new Error("Discovery ref is missing root_span_id."); + } + return ref.root_span_id; +} + +function refSpanRowId(ref: DiscoveryRef): string | undefined { + return typeof ref.id === "string" ? ref.id : undefined; +} + +async function hydrateDiscoveryRefs( + braintrust: BraintrustModule, + pipeline: DatasetPipelineDefinition, + sourceProjectId: string, + refs: unknown[], +): Promise { + requireBraintrustRuntime(braintrust); + const source = requirePipelineSource(pipeline); + const state = await stateForOrg(braintrust, source.orgName); + return refs.map((ref) => { + const rootSpanId = refRootSpanId(ref); + const id = refSpanRowId(ref as DiscoveryRef); + return { + trace: new braintrust.LocalTrace!({ + objectType: "project_logs", + objectId: sourceProjectId, + rootSpanId, + state, + }), + ...(id ? { id } : {}), + ...(id + ? { + origin: { + object_type: "project_logs" as const, + object_id: sourceProjectId, + id, + }, + } + : {}), + }; + }); +} + +function normalizeTransformResult(result: unknown): unknown[] { + if (result == null) { + return []; + } + return Array.isArray(result) ? 
result : [result]; +} + +function candidateFallbackId(candidate: HydratedCandidate): string | undefined { + if (candidate.id) { + return candidate.id; + } + const trace = candidate.trace; + if ( + isObject(trace) && + typeof trace.getConfiguration === "function" && + isObject(trace.getConfiguration()) + ) { + const config = trace.getConfiguration() as Record; + return typeof config.root_span_id === "string" + ? config.root_span_id + : undefined; + } + return undefined; +} + +function withPipelineDefaults( + row: unknown, + candidate: HydratedCandidate, + rowIndex: number | undefined, +): unknown { + if (!isObject(row)) { + throw new Error("Dataset pipeline transform must return an object row."); + } + const fallbackId = candidateFallbackId(candidate); + return { + ...row, + ...(row.id === undefined && fallbackId + ? { + id: rowIndex === undefined ? fallbackId : `${fallbackId}:${rowIndex}`, + } + : {}), + ...(row.origin === undefined && candidate.origin + ? { origin: candidate.origin } + : {}), + }; +} + +async function transformRefs( + braintrust: BraintrustModule, + pipeline: DatasetPipelineDefinition, + sourceProjectId: string, + refs: unknown[], + maxConcurrency = 16, +): Promise { + if (!Number.isInteger(maxConcurrency) || maxConcurrency <= 0) { + throw new Error("maxConcurrency must be a positive integer."); + } + if (typeof pipeline.transform !== "function") { + throw new Error("Dataset pipeline transform must be a function."); + } + const candidates = await hydrateDiscoveryRefs( + braintrust, + pipeline, + sourceProjectId, + refs, + ); + const transformedRows: unknown[][] = new Array(candidates.length); + let nextIndex = 0; + + async function worker() { + while (nextIndex < candidates.length) { + const index = nextIndex++; + const candidate = candidates[index]; + const result = await pipeline.transform!(candidate, { pipeline }); + const rows = normalizeTransformResult(result); + transformedRows[index] = rows.map((row, rowIndex) => + withPipelineDefaults( + row, + candidate, + rows.length > 1 ? rowIndex : undefined, + ), + ); + } + } + + const workerCount = Math.min(maxConcurrency, Math.max(candidates.length, 1)); + await Promise.all(Array.from({ length: workerCount }, () => worker())); + return transformedRows.flat(); +} + +async function main() { + const pipelineFile = process.argv[2]; + if (!pipelineFile) { + throw new Error("Pipeline file is required."); + } + + const [braintrust, loadedModule] = await Promise.all([ + loadBraintrust(pipelineFile), + loadPipelineFile(pipelineFile), + ]); + const pipeline = selectPipeline( + collectPipelines(braintrust, loadedModule), + process.env.BT_DATASET_PIPELINE_NAME || undefined, + ); + const stage = parseStage(); + + if (stage === "inspect") { + writeResponse({ + name: pipeline.name, + source: pipeline.source, + target: pipeline.target, + }); + } else if (stage === "transform") { + const request = await readRequest(); + const refs = requireArrayField(request, "refs"); + const sourceProjectId = requireStringField(request, "sourceProjectId"); + const rows = await transformRefs( + braintrust, + pipeline, + sourceProjectId, + refs, + optionalPositiveIntegerField(request, "maxConcurrency"), + ); + writeResponse({ candidates: refs.length, rowCount: rows.length, rows }); + } else { + throw new Error(`Unsupported dataset pipeline stage: ${stage}`); + } +} + +main().catch((err) => { + console.error(err instanceof Error ? 
err.stack || err.message : String(err)); + process.exit(1); +}); diff --git a/src/auth.rs b/src/auth.rs index ec13df11..303f4f22 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -688,6 +688,18 @@ pub async fn resolved_auth_env(base: &BaseArgs) -> Result> Ok(envs) } +pub async fn resolved_runner_env(base: &BaseArgs) -> Result> { + let mut envs = resolved_auth_env(base).await?; + let project = base + .project + .clone() + .or_else(|| crate::config::load().ok().and_then(|c| c.project)); + if let Some(project) = project { + envs.push(("BRAINTRUST_DEFAULT_PROJECT".to_string(), project)); + } + Ok(envs) +} + fn resolve_profile_for_org<'a>(org: &str, store: &'a AuthStore) -> Option<&'a str> { if store.profiles.contains_key(org) { return Some( diff --git a/src/datasets/mod.rs b/src/datasets/mod.rs new file mode 100644 index 00000000..76ce2e97 --- /dev/null +++ b/src/datasets/mod.rs @@ -0,0 +1,24 @@ +use anyhow::Result; +use clap::{Args, Subcommand}; + +use crate::args::BaseArgs; + +mod pipeline; + +#[derive(Debug, Clone, Args)] +pub struct DatasetsArgs { + #[command(subcommand)] + command: DatasetsCommands, +} + +#[derive(Debug, Clone, Subcommand)] +enum DatasetsCommands { + /// Run dataset pipeline workflows + Pipeline(pipeline::PipelineArgs), +} + +pub async fn run(base: BaseArgs, args: DatasetsArgs) -> Result<()> { + match args.command { + DatasetsCommands::Pipeline(args) => pipeline::run(base, args).await, + } +} diff --git a/src/datasets/pipeline.rs b/src/datasets/pipeline.rs new file mode 100644 index 00000000..6c678106 --- /dev/null +++ b/src/datasets/pipeline.rs @@ -0,0 +1,904 @@ +use std::fs::File; +use std::io::{self, BufWriter, Write}; +use std::path::{Path, PathBuf}; +use std::process::{Command, Stdio}; + +use anyhow::{bail, Context, Result}; +use braintrust_sdk_rust::Logs3BatchUploader; +use clap::{Args, Subcommand}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Map, Value}; +use urlencoding::encode; + +use crate::args::BaseArgs; +use crate::auth::{login, resolved_runner_env, LoginContext}; +use crate::http::ApiClient; +use crate::js_runner::{build_js_runner_command, materialize_runner_script_in_cwd}; +use crate::projects::api::{create_project, get_project_by_name, Project}; +use crate::sync::discovery::{discover_project_log_refs, ProjectLogRefScope}; +use crate::sync::{read_jsonl_values, write_jsonl_value, write_jsonl_values}; + +const RUNNER_FILE: &str = "dataset-pipeline-runner.ts"; +const RUNNER_SOURCE: &str = include_str!("../../scripts/dataset-pipeline-runner.ts"); + +#[derive(Debug, Clone, Args)] +pub struct PipelineArgs { + #[command(subcommand)] + command: PipelineCommands, +} + +#[derive(Debug, Clone, Subcommand)] +enum PipelineCommands { + /// Fetch, transform, and insert dataset rows + Run(PipelineRunArgs), + /// Discover source trace/span refs to JSONL + Fetch(PipelineFetchArgs), + /// Transform candidate JSONL into proposed dataset row JSONL + Transform(PipelineTransformArgs), + /// Copy proposed row JSONL for human or agent review + Review(PipelineReviewArgs), + /// Insert approved row JSONL into the target dataset + Commit(PipelineCommitArgs), +} + +#[derive(Debug, Clone, Args)] +struct PipelineRunnerArgs { + /// Dataset pipeline file to execute + #[arg(value_name = "PIPELINE")] + pipeline: PathBuf, + + /// Pipeline name, required when the file defines multiple pipelines + #[arg(long)] + name: Option, + + /// JavaScript/TypeScript runner binary (e.g. 
tsx, vite-node, ts-node) + #[arg( + long, + short = 'r', + env = "BT_DATASET_PIPELINE_RUNNER", + value_name = "RUNNER" + )] + runner: Option, +} + +#[derive(Debug, Clone, Args)] +struct PipelineFetchOptions { + /// Maximum number of source refs to discover + #[arg(long, default_value_t = 100, value_parser = parse_positive_usize)] + target: usize, + + /// Restrict the source query to one or more root span ids + #[arg(long = "root-span-id")] + root_span_ids: Vec, + + /// Additional SQL predicate appended to the source WHERE clause + #[arg(long)] + extra_where_sql: Option, + + /// Page size for discovery BTQL pagination + #[arg(long, default_value_t = 1000, value_parser = parse_positive_usize)] + page_size: usize, +} + +#[derive(Debug, Clone, Args)] +struct PipelineTransformOptions { + /// Maximum concurrent transform calls + #[arg(long, default_value_t = 16, value_parser = parse_positive_usize)] + max_concurrency: usize, +} + +#[derive(Debug, Clone, Args)] +struct PipelineRunArgs { + #[command(flatten)] + runner: PipelineRunnerArgs, + + #[command(flatten)] + fetch: PipelineFetchOptions, + + #[command(flatten)] + transform: PipelineTransformOptions, +} + +#[derive(Debug, Clone, Args)] +struct PipelineFetchArgs { + #[command(flatten)] + runner: PipelineRunnerArgs, + + #[command(flatten)] + fetch: PipelineFetchOptions, + + /// Output JSONL file. Defaults to stdout. + #[arg(long)] + out: Option, +} + +#[derive(Debug, Clone, Args)] +struct PipelineTransformArgs { + #[command(flatten)] + runner: PipelineRunnerArgs, + + #[command(flatten)] + transform: PipelineTransformOptions, + + /// Input candidate JSONL file + #[arg(long = "in")] + input: PathBuf, + + /// Output proposed dataset row JSONL file. Defaults to stdout. + #[arg(long)] + out: Option, +} + +#[derive(Debug, Clone, Args)] +struct PipelineReviewArgs { + #[command(flatten)] + runner: PipelineRunnerArgs, + + /// Input proposed dataset row JSONL file + #[arg(long = "in")] + input: PathBuf, + + /// Output approved dataset row JSONL file. Defaults to stdout. 
+ #[arg(long)] + out: Option, +} + +#[derive(Debug, Clone, Args)] +struct PipelineCommitArgs { + #[command(flatten)] + runner: PipelineRunnerArgs, + + /// Input approved dataset row JSONL file + #[arg(long = "in")] + input: PathBuf, +} + +pub async fn run(base: BaseArgs, args: PipelineArgs) -> Result<()> { + match args.command { + PipelineCommands::Run(args) => { + let inspect = inspect_pipeline(&base, &args.runner).await?; + let tempdir = + tempfile::tempdir().context("failed to create dataset pipeline temp dir")?; + let refs_path = tempdir.path().join("discovered.jsonl"); + discover_refs(&base, &inspect, &args.fetch, Some(&refs_path), false).await?; + + let refs = read_jsonl_values(&refs_path)?; + let source_project = resolve_pipeline_source_project(&base, &inspect.source).await?; + let transform_response: PipelineTransformResponse = run_runner_json( + &base, + "transform", + &args.runner, + &json!({ + "sourceProjectId": source_project.id, + "refs": refs, + "maxConcurrency": args.transform.max_concurrency, + }), + ) + .await?; + validate_transform_response(&transform_response)?; + let row_count = transform_response.rows.len(); + let inserted = + upload_dataset_rows(&base, &inspect.target, transform_response.rows).await?; + print_summary( + &base, + json!({ + "refs": transform_response.candidates, + "rows": row_count, + "inserted": inserted, + }), + false, + ) + } + PipelineCommands::Fetch(args) => { + let inspect = inspect_pipeline(&base, &args.runner).await?; + discover_refs(&base, &inspect, &args.fetch, args.out.as_deref(), true).await + } + PipelineCommands::Transform(args) => transform_refs(&base, args).await, + PipelineCommands::Review(args) => review_rows(&base, args), + PipelineCommands::Commit(args) => commit_rows(&base, args).await, + } +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +struct PipelineInspect { + source: PipelineSourceInspect, + target: PipelineTargetInspect, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +struct PipelineSourceInspect { + project_id: Option, + project_name: Option, + org_name: Option, + filter: Option, + scope: Option, + limit: Option, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +struct PipelineTargetInspect { + project_id: Option, + project_name: Option, + org_name: Option, + dataset_name: String, + description: Option, + metadata: Option, +} + +#[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +enum PipelineScope { + Span, + Trace, +} + +impl PipelineScope { + fn from_source(source: &PipelineSourceInspect) -> Self { + source.scope.unwrap_or(PipelineScope::Span) + } +} + +#[derive(Debug, Clone, Deserialize)] +struct NamedObject { + id: String, + name: String, +} + +#[derive(Debug, Clone, Deserialize)] +struct CreatedDataset { + id: String, +} + +#[derive(Debug, Deserialize)] +struct NamedObjectListResponse { + objects: Vec, +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +struct DatasetPipelineRow { + id: Option, + input: Option, + expected: Option, + output: Option, + tags: Option>, + metadata: Option>, + origin: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] +struct DatasetPipelineObjectReference { + object_type: String, + object_id: String, + id: String, + #[serde(rename = "_xact_id", skip_serializing_if = "Option::is_none")] + xact_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + created: Option, +} + +#[derive(Debug, Serialize)] 
+struct DatasetPipelineUploadRow { + project_id: String, + dataset_id: String, + id: String, + span_id: String, + root_span_id: String, + span_parents: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + input: Option, + #[serde(skip_serializing_if = "Option::is_none")] + expected: Option, + #[serde(skip_serializing_if = "Option::is_none")] + output: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tags: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + metadata: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + origin: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct PipelineTransformResponse { + candidates: usize, + row_count: usize, + rows: Vec, +} + +async fn inspect_pipeline(base: &BaseArgs, runner: &PipelineRunnerArgs) -> Result { + let output = build_runner_command(base, "inspect", runner, |_, _| Ok(())) + .await? + .stdout(Stdio::piped()) + .stderr(Stdio::inherit()) + .output() + .context("failed to start dataset pipeline inspect runner")?; + if !output.status.success() { + bail!( + "dataset pipeline inspect runner failed with status {}", + output.status + ); + } + serde_json::from_slice(&output.stdout) + .context("failed to parse dataset pipeline inspect output") +} + +async fn build_runner_command( + base: &BaseArgs, + stage: &'static str, + runner: &PipelineRunnerArgs, + configure: F, +) -> Result +where + F: FnOnce(&mut Command, &'static str) -> Result<()>, +{ + let runner_script = + materialize_runner_script_in_cwd("dataset-pipeline-runners", RUNNER_FILE, RUNNER_SOURCE)?; + let pipeline_file = runner.pipeline.clone(); + let files = vec![pipeline_file.clone()]; + let mut command = build_js_runner_command(runner.runner.as_deref(), &runner_script, &files); + + command.envs(resolved_runner_env(base).await?); + command.env("BT_DATASET_PIPELINE_STAGE", stage); + if let Some(name) = runner.name.as_deref() { + command.env("BT_DATASET_PIPELINE_NAME", name); + } + configure(&mut command, stage)?; + Ok(command) +} + +async fn run_runner_json( + base: &BaseArgs, + stage: &'static str, + runner: &PipelineRunnerArgs, + request: &Value, +) -> Result +where + T: DeserializeOwned, +{ + let mut command = build_runner_command(base, stage, runner, |_, _| Ok(())).await?; + command.stdin(Stdio::piped()); + command.stdout(Stdio::piped()); + command.stderr(Stdio::inherit()); + + let mut child = command + .spawn() + .context("failed to start dataset pipeline runner")?; + { + let mut stdin = child + .stdin + .take() + .context("dataset pipeline runner stdin was not available")?; + serde_json::to_writer(&mut stdin, request) + .context("failed to write dataset pipeline runner request")?; + stdin + .write_all(b"\n") + .context("failed to finish dataset pipeline runner request")?; + } + + let output = child + .wait_with_output() + .context("failed to wait for dataset pipeline runner")?; + if !output.status.success() { + bail!( + "dataset pipeline runner failed with status {}", + output.status + ); + } + serde_json::from_slice(&output.stdout) + .context("failed to parse dataset pipeline runner response") +} + +async fn transform_refs(base: &BaseArgs, args: PipelineTransformArgs) -> Result<()> { + let inspect = inspect_pipeline(base, &args.runner).await?; + let source_project = resolve_pipeline_source_project(base, &inspect.source).await?; + let refs = read_jsonl_values(&args.input)?; + let response: PipelineTransformResponse = run_runner_json( + base, + "transform", + &args.runner, + &json!({ + "sourceProjectId": 
source_project.id, + "refs": refs, + "maxConcurrency": args.transform.max_concurrency, + }), + ) + .await?; + validate_transform_response(&response)?; + let row_count = response.rows.len(); + write_jsonl_values(args.out.as_deref(), &response.rows)?; + print_summary( + base, + json!({ + "candidates": response.candidates, + "rows": row_count, + "out": args + .out + .as_ref() + .map(|path| path.display().to_string()) + .unwrap_or_else(|| "stdout".to_string()), + }), + args.out.is_none(), + ) +} + +fn validate_transform_response(response: &PipelineTransformResponse) -> Result<()> { + if response.row_count != response.rows.len() { + bail!( + "dataset pipeline runner response rowCount {} did not match rows length {}", + response.row_count, + response.rows.len() + ); + } + Ok(()) +} + +fn review_rows(base: &BaseArgs, args: PipelineReviewArgs) -> Result<()> { + let rows = read_jsonl_values(&args.input)?; + write_jsonl_values(args.out.as_deref(), &rows)?; + print_summary( + base, + json!({ + "rows": rows.len(), + "out": args + .out + .as_ref() + .map(|path| path.display().to_string()) + .unwrap_or_else(|| "stdout".to_string()), + }), + args.out.is_none(), + ) +} + +async fn commit_rows(base: &BaseArgs, args: PipelineCommitArgs) -> Result<()> { + let inspect = inspect_pipeline(base, &args.runner).await?; + let rows = read_jsonl_values(&args.input)?; + let row_count = rows.len(); + let inserted = upload_dataset_rows(base, &inspect.target, rows).await?; + print_summary( + base, + json!({ + "rows": row_count, + "inserted": inserted, + }), + false, + ) +} + +async fn upload_dataset_rows( + base: &BaseArgs, + target: &PipelineTargetInspect, + rows: Vec, +) -> Result { + let mut target_base = base.clone(); + if let Some(org_name) = target.org_name.as_deref() { + target_base.org_name = Some(org_name.to_string()); + } + let ctx = login(&target_base).await?; + let client = ApiClient::new(&ctx)?; + let project = resolve_target_project(&client, target).await?; + let dataset = resolve_target_dataset(&client, target, &project).await?; + let upload_run_id = chrono::Utc::now().timestamp_millis().to_string(); + + let mut prepared_rows = Vec::with_capacity(rows.len()); + for (index, row) in rows.into_iter().enumerate() { + let row: DatasetPipelineRow = serde_json::from_value(row).with_context(|| { + format!("dataset pipeline row {index} does not match the expected dataset row schema") + })?; + let row = + prepare_dataset_row_for_upload(row, &project.id, &dataset.id, &upload_run_id, index); + prepared_rows.push(upload_row_to_map(row)?); + } + + let mut uploader = Logs3BatchUploader::new( + ctx.api_url.clone(), + ctx.login + .api_key() + .context("login state missing API key for dataset pipeline upload")?, + ctx.login + .org_name() + .filter(|org_name| !org_name.trim().is_empty()), + ) + .context("failed to initialize dataset pipeline uploader")?; + for chunk in prepared_rows.chunks(1000) { + uploader + .upload_rows(chunk, 1000) + .await + .map_err(|err| anyhow::anyhow!("dataset pipeline upload failed: {err}"))?; + } + Ok(prepared_rows.len()) +} + +fn prepare_dataset_row_for_upload( + row: DatasetPipelineRow, + project_id: &str, + dataset_id: &str, + upload_run_id: &str, + row_index: usize, +) -> DatasetPipelineUploadRow { + let id = row + .id + .clone() + .unwrap_or_else(|| format!("dataset-pipeline-{upload_run_id}-{row_index}")); + + DatasetPipelineUploadRow { + project_id: project_id.to_string(), + dataset_id: dataset_id.to_string(), + span_id: id.clone(), + root_span_id: id.clone(), + id, + span_parents: 
Vec::new(), + input: row.input, + expected: row.expected, + output: row.output, + tags: row.tags, + metadata: row.metadata, + origin: row.origin, + } +} + +fn upload_row_to_map(row: DatasetPipelineUploadRow) -> Result> { + match serde_json::to_value(row).context("failed to serialize dataset pipeline upload row")? { + Value::Object(row) => Ok(row), + _ => bail!("serialized dataset pipeline upload row was not an object"), + } +} + +async fn resolve_target_project( + client: &ApiClient, + target: &PipelineTargetInspect, +) -> Result { + if let Some(project_id) = target.project_id.as_deref() { + return Ok(Project { + id: project_id.to_string(), + name: target + .project_name + .clone() + .unwrap_or_else(|| project_id.to_string()), + org_id: String::new(), + description: None, + }); + } + let project_name = target + .project_name + .as_deref() + .context("dataset pipeline target requires projectName or projectId")?; + if let Some(project) = get_project_by_name(client, project_name).await? { + Ok(project) + } else { + create_project(client, project_name) + .await + .with_context(|| format!("project '{project_name}' not found, and creating it failed")) + } +} + +async fn resolve_target_dataset( + client: &ApiClient, + target: &PipelineTargetInspect, + project: &Project, +) -> Result { + let dataset_name = target.dataset_name.trim(); + if dataset_name.is_empty() { + bail!("dataset pipeline target.datasetName cannot be empty"); + } + + let objects = list_project_datasets(client, &project.id).await?; + if let Some(dataset) = objects + .iter() + .find(|object| object.id == dataset_name || object.name == dataset_name) + { + return Ok(CreatedDataset { + id: dataset.id.clone(), + }); + } + + if is_uuid_like(dataset_name) { + bail!( + "dataset id '{}' not found in project '{}'", + dataset_name, + project.name + ); + } + + create_dataset(client, &project.id, target) + .await + .with_context(|| format!("dataset '{dataset_name}' not found, and creating it failed")) +} + +async fn list_project_datasets(client: &ApiClient, project_id: &str) -> Result> { + let path = format!( + "/v1/dataset?org_name={}&project_id={}", + encode(client.org_name()), + encode(project_id) + ); + let response: NamedObjectListResponse = client.get(&path).await?; + Ok(response.objects) +} + +async fn create_dataset( + client: &ApiClient, + project_id: &str, + target: &PipelineTargetInspect, +) -> Result { + let mut body = json!({ + "name": target.dataset_name.clone(), + "project_id": project_id, + "org_name": client.org_name(), + }); + if let (Value::Object(body), Some(description)) = (&mut body, target.description.as_deref()) { + body.insert( + "description".to_string(), + Value::String(description.to_string()), + ); + } + if let (Value::Object(body), Some(metadata)) = (&mut body, target.metadata.as_ref()) { + body.insert("metadata".to_string(), metadata.clone()); + } + client.post("/v1/dataset", &body).await +} + +async fn discover_refs( + base: &BaseArgs, + inspect: &PipelineInspect, + options: &PipelineFetchOptions, + out: Option<&Path>, + emit_summary: bool, +) -> Result<()> { + let (ctx, client, project) = resolve_pipeline_source_context(base, &inspect.source).await?; + let scope = PipelineScope::from_source(&inspect.source); + let target = inspect.source.limit.unwrap_or(options.target); + let filter = discovery_filter(&inspect.source, options); + + let mut writer: Box = if let Some(path) = out { + if let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) { + std::fs::create_dir_all(parent) + .with_context(|| 
format!("failed to create {}", parent.display()))?; + } + Box::new(BufWriter::new(File::create(path).with_context(|| { + format!("failed to create {}", path.display()) + })?)) + } else { + Box::new(BufWriter::new(io::stdout())) + }; + + let result = discover_project_log_refs( + &client, + &ctx, + &project.id, + filter.as_ref(), + project_log_ref_scope(scope), + target, + options.page_size, + |reference| write_jsonl_value(writer.as_mut(), &reference.to_value()).map(|_| ()), + ) + .await?; + writer.flush().context("failed to flush discovery output")?; + + let out_label = out + .map(|path| path.display().to_string()) + .unwrap_or_else(|| "stdout".to_string()); + if emit_summary { + print_summary( + base, + json!({ + "refs": result.refs, + "pages": result.pages, + "scope": match scope { PipelineScope::Trace => "trace", PipelineScope::Span => "span" }, + "out": out_label, + }), + out.is_none(), + )?; + } + Ok(()) +} + +fn project_log_ref_scope(scope: PipelineScope) -> ProjectLogRefScope { + match scope { + PipelineScope::Trace => ProjectLogRefScope::Trace, + PipelineScope::Span => ProjectLogRefScope::Span, + } +} + +async fn resolve_pipeline_source_project( + base: &BaseArgs, + source: &PipelineSourceInspect, +) -> Result { + let (_, _, project) = resolve_pipeline_source_context(base, source).await?; + Ok(project) +} + +async fn resolve_pipeline_source_context( + base: &BaseArgs, + source: &PipelineSourceInspect, +) -> Result<(LoginContext, ApiClient, Project)> { + let mut source_base = base.clone(); + if let Some(org_name) = source.org_name.as_deref() { + source_base.org_name = Some(org_name.to_string()); + } + let ctx = login(&source_base).await?; + let client = ApiClient::new(&ctx)?; + let project = resolve_source_project(&client, source).await?; + Ok((ctx, client, project)) +} + +fn discovery_filter( + source: &PipelineSourceInspect, + options: &PipelineFetchOptions, +) -> Option { + let mut filters = Vec::new(); + if let Some(filter) = source + .filter + .as_deref() + .map(str::trim) + .filter(|s| !s.is_empty()) + { + filters.push(json!({ "btql": filter })); + } + if !options.root_span_ids.is_empty() { + filters.push(root_span_id_filter(&options.root_span_ids)); + } + if let Some(filter) = options + .extra_where_sql + .as_deref() + .map(str::trim) + .filter(|s| !s.is_empty()) + { + filters.push(json!({ "btql": filter })); + } + match filters.len() { + 0 => None, + 1 => filters.into_iter().next(), + _ => Some(json!({ "op": "and", "children": filters })), + } +} + +fn root_span_id_filter(root_span_ids: &[String]) -> Value { + json!({ + "op": "in", + "left": { "op": "ident", "name": ["root_span_id"] }, + "right": { "op": "literal", "value": root_span_ids } + }) +} + +async fn resolve_source_project( + client: &ApiClient, + source: &PipelineSourceInspect, +) -> Result { + if let Some(project_id) = source.project_id.as_deref() { + return Ok(Project { + id: project_id.to_string(), + name: source + .project_name + .clone() + .unwrap_or_else(|| project_id.to_string()), + org_id: String::new(), + description: None, + }); + } + let project_name = source + .project_name + .as_deref() + .context("dataset pipeline source requires projectName or projectId")?; + get_project_by_name(client, project_name) + .await? 
+ .with_context(|| format!("project '{project_name}' not found")) +} + +fn print_summary(base: &BaseArgs, summary: Value, force_stderr: bool) -> Result<()> { + let object = summary + .as_object() + .context("dataset pipeline summary must be an object")?; + if base.json && !force_stderr { + println!("{}", serde_json::to_string(&summary)?); + return Ok(()); + } + let parts = object + .iter() + .map(|(key, value)| format!("{key}: {}", summary_value(value))) + .collect::>(); + eprintln!("{}", parts.join(", ")); + Ok(()) +} + +fn summary_value(value: &Value) -> String { + match value { + Value::String(value) => value.clone(), + Value::Number(value) => value.to_string(), + Value::Bool(value) => value.to_string(), + Value::Null => "null".to_string(), + Value::Array(_) | Value::Object(_) => value.to_string(), + } +} + +fn is_uuid_like(value: &str) -> bool { + let bytes = value.as_bytes(); + if bytes.len() != 36 { + return false; + } + for (index, byte) in bytes.iter().enumerate() { + match index { + 8 | 13 | 18 | 23 => { + if *byte != b'-' { + return false; + } + } + _ if !byte.is_ascii_hexdigit() => return false, + _ => {} + } + } + true +} + +fn parse_positive_usize(value: &str) -> std::result::Result { + let parsed = value + .parse::() + .map_err(|_| format!("invalid positive integer '{value}'"))?; + if parsed == 0 { + return Err("value must be greater than 0".to_string()); + } + Ok(parsed) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn dataset_pipeline_row_rejects_unknown_fields() { + let err = serde_json::from_value::(json!({ + "input": "hello", + "span_attributes": { "type": "llm" }, + })) + .expect_err("unexpected dataset row fields should be rejected"); + + assert!(err.to_string().contains("unknown field")); + } + + #[test] + fn prepare_dataset_row_for_upload_uses_typed_schema() { + let row = serde_json::from_value::(json!({ + "id": "row-1", + "input": { "question": "hello" }, + "expected": "world", + "tags": ["smoke"], + "metadata": { "source": "test" }, + "origin": { + "object_type": "project_logs", + "object_id": "source-project", + "id": "source-span" + } + })) + .expect("valid dataset pipeline row should deserialize"); + + let upload = + prepare_dataset_row_for_upload(row, "target-project", "target-dataset", "run", 0); + let upload = upload_row_to_map(upload).expect("upload row should serialize"); + + assert_eq!(upload.get("id"), Some(&json!("row-1"))); + assert_eq!(upload.get("span_id"), Some(&json!("row-1"))); + assert_eq!(upload.get("root_span_id"), Some(&json!("row-1"))); + assert_eq!(upload.get("project_id"), Some(&json!("target-project"))); + assert_eq!(upload.get("dataset_id"), Some(&json!("target-dataset"))); + assert!(!upload.contains_key("log_id")); + assert!(!upload.contains_key("experiment_id")); + } + + #[test] + fn transform_response_validation_rejects_row_count_mismatch() { + let response = PipelineTransformResponse { + candidates: 1, + row_count: 2, + rows: vec![json!({ "input": "one" })], + }; + + let err = + validate_transform_response(&response).expect_err("rowCount should match rows length"); + assert!(err.to_string().contains("rowCount 2")); + } +} diff --git a/src/eval.rs b/src/eval.rs index c01262ae..411dd6ed 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -41,7 +41,8 @@ use ratatui::widgets::{Cell, Row, Table}; use ratatui::Terminal; use crate::args::BaseArgs; -use crate::auth::resolved_auth_env; +use crate::auth::resolved_runner_env; +use crate::js_runner; use crate::python_runner; use crate::ui::{animations_enabled, is_quiet}; @@ -795,7 +796,7 
@@ async fn spawn_eval_runner( set_node_heap_size_env(&mut cmd); } - cmd.envs(build_env(base).await?); + cmd.envs(resolved_runner_env(base).await?); for (key, value) in extra_env { cmd.env(key, value); } @@ -2049,18 +2050,6 @@ fn format_watch_paths(paths: &[PathBuf]) -> String { } } -async fn build_env(base: &BaseArgs) -> Result> { - let mut envs = resolved_auth_env(base).await?; - let project = base - .project - .clone() - .or_else(|| crate::config::load().ok().and_then(|c| c.project)); - if let Some(project) = &project { - envs.push(("BRAINTRUST_DEFAULT_PROJECT".to_string(), project.clone())); - } - Ok(envs) -} - fn detect_eval_language( files: &[String], language_override: Option, @@ -2261,7 +2250,7 @@ fn build_js_plan( ) -> Result { if let Some(explicit) = runner_override { let resolved_runner = resolve_js_runner_command(explicit, files); - if is_deno_runner(explicit) || is_deno_runner_path(resolved_runner.as_ref()) { + if is_deno_runner(explicit) || js_runner::is_deno_runner_path(resolved_runner.as_ref()) { let runner_script = prepare_js_runner_in_cwd()?; return Ok(JsRunnerPlan { cmd: build_deno_js_command(resolved_runner.as_os_str(), &runner_script, files), @@ -2276,7 +2265,7 @@ fn build_js_plan( } if let Some(auto_runner) = find_js_runner_binary(files) { - if is_deno_runner_path(&auto_runner) { + if js_runner::is_deno_runner_path(&auto_runner) { let runner_script = prepare_js_runner_in_cwd()?; return Ok(JsRunnerPlan { cmd: build_deno_js_command(auto_runner.as_os_str(), &runner_script, files), @@ -2300,7 +2289,7 @@ fn build_js_plan( fn build_vite_node_fallback_command(runner: &Path, files: &[String]) -> Result { if let Some(path) = find_node_module_bin_for_files("vite-node", files) - .or_else(|| find_binary_in_path(&["vite-node"])) + .or_else(|| js_runner::find_binary_in_path(&["vite-node"])) { let mut command = Command::new(path); command.arg(runner).args(files); @@ -2327,15 +2316,7 @@ fn build_deno_js_command( } fn deno_js_command_args(runner: &Path, files: &[String]) -> Vec { - let mut args = vec![ - OsString::from("run"), - OsString::from("-A"), - OsString::from("--node-modules-dir=auto"), - OsString::from("--unstable-detect-cjs"), - runner.as_os_str().to_os_string(), - ]; - args.extend(files.iter().map(OsString::from)); - args + js_runner::deno_runner_args(runner, &files_as_paths(files)) } fn build_python_command( @@ -2379,104 +2360,38 @@ fn python_runner_search_roots(files: &[String]) -> Vec { } fn find_js_runner_binary(files: &[String]) -> Option { - // Prefer local project bins first, then PATH. `tsx` remains the preferred - // default, with other common TS runners as fallback. 
- const RUNNER_CANDIDATES: &[&str] = &["tsx", "vite-node", "ts-node", "ts-node-esm", "deno"]; - - for candidate in RUNNER_CANDIDATES { - if let Some(path) = find_node_module_bin_for_files(candidate, files) { - return Some(path); - } - } - - find_binary_in_path(RUNNER_CANDIDATES) + js_runner::find_js_runner_binary(&files_as_paths(files)) } fn resolve_js_runner_command(runner: &str, files: &[String]) -> PathBuf { - if is_path_like_runner(runner) { - return PathBuf::from(runner); - } - - find_node_module_bin_for_files(runner, files) - .or_else(|| find_binary_in_path(&[runner])) - .unwrap_or_else(|| PathBuf::from(runner)) -} - -fn is_path_like_runner(runner: &str) -> bool { - let path = Path::new(runner); - path.is_absolute() || runner.contains('/') || runner.contains('\\') || runner.starts_with('.') + js_runner::resolve_js_runner_command(runner, &files_as_paths(files)) } fn find_node_module_bin_for_files(binary: &str, files: &[String]) -> Option { - let search_roots = js_runner_search_roots(files); - for root in &search_roots { - if let Some(path) = find_node_module_bin(binary, root) { - return Some(path); - } - } - None + js_runner::find_node_module_bin_for_files(binary, &files_as_paths(files)) } -fn js_runner_search_roots(files: &[String]) -> Vec { - let mut search_roots = Vec::new(); - if let Ok(cwd) = std::env::current_dir() { - search_roots.push(cwd.clone()); - for file in files { - let path = PathBuf::from(file); - let absolute = if path.is_absolute() { - path - } else { - cwd.join(path) - }; - if let Some(parent) = absolute.parent() { - search_roots.push(parent.to_path_buf()); - } - } - } - search_roots +fn files_as_paths(files: &[String]) -> Vec { + files.iter().map(PathBuf::from).collect() } fn is_deno_runner(runner: &str) -> bool { - let file_name = Path::new(runner) - .file_name() - .and_then(|value| value.to_str()) - .unwrap_or(runner); - file_name.eq_ignore_ascii_case("deno") || file_name.eq_ignore_ascii_case("deno.exe") -} - -fn is_deno_runner_path(runner: &Path) -> bool { - runner - .file_name() - .and_then(|value| value.to_str()) - .map(|name| name.eq_ignore_ascii_case("deno") || name.eq_ignore_ascii_case("deno.exe")) - .unwrap_or(false) + js_runner::is_deno_runner_path(Path::new(runner)) } fn select_js_runner_entrypoint(default_runner: &Path, runner_command: &Path) -> Result { - if is_ts_node_runner(runner_command) { + if js_runner::is_ts_node_runner_path(runner_command) { return prepare_js_runner_in_cwd(); } Ok(default_runner.to_path_buf()) } fn prepare_js_runner_in_cwd() -> Result { - let cwd = std::env::current_dir().context("failed to resolve current working directory")?; - let cache_dir = cwd - .join(".bt") - .join("eval-runners") - .join(env!("CARGO_PKG_VERSION")); - std::fs::create_dir_all(&cache_dir).with_context(|| { - format!( - "failed to create eval runner cache dir {}", - cache_dir.display() - ) - })?; - materialize_runner_script(&cache_dir, JS_RUNNER_FILE, JS_RUNNER_SOURCE) + js_runner::materialize_runner_script_in_cwd("eval-runners", JS_RUNNER_FILE, JS_RUNNER_SOURCE) } fn runner_bin_name(runner_command: &Path) -> Option { - let name = runner_command.file_name()?.to_str()?.to_ascii_lowercase(); - Some(name.strip_suffix(".cmd").unwrap_or(&name).to_string()) + js_runner::runner_bin_name(runner_command) } fn runner_kind_for_bin(runner_command: &Path) -> RunnerKind { @@ -2508,47 +2423,6 @@ fn set_node_heap_size_env(command: &mut Command) { command.env("NODE_OPTIONS", merged); } -fn is_ts_node_runner(runner_command: &Path) -> bool { - 
runner_bin_name(runner_command).is_some_and(|n| n == "ts-node" || n == "ts-node-esm") -} - -fn find_node_module_bin(binary: &str, start: &Path) -> Option { - let mut current = Some(start); - while let Some(dir) = current { - let base = dir.join("node_modules").join(".bin").join(binary); - if base.is_file() { - return Some(base); - } - if cfg!(windows) { - let cmd = base.with_extension("cmd"); - if cmd.is_file() { - return Some(cmd); - } - } - current = dir.parent(); - } - None -} - -fn find_binary_in_path(candidates: &[&str]) -> Option { - let paths = std::env::var_os("PATH")?; - for dir in std::env::split_paths(&paths) { - for candidate in candidates { - let path = dir.join(candidate); - if path.is_file() { - return Some(path); - } - if cfg!(windows) { - let cmd = path.with_extension("cmd"); - if cmd.is_file() { - return Some(cmd); - } - } - } - } - None -} - fn build_sse_socket_path() -> Result { let pid = std::process::id(); let serial = SSE_SOCKET_COUNTER.fetch_add(1, Ordering::Relaxed); @@ -2621,13 +2495,7 @@ fn prepare_eval_runners_in_dir(cache_dir: &Path) -> Result<(PathBuf, PathBuf)> { } fn materialize_runner_script(cache_dir: &Path, file_name: &str, source: &str) -> Result { - let path = cache_dir.join(file_name); - let current = std::fs::read_to_string(&path).ok(); - if current.as_deref() != Some(source) { - std::fs::write(&path, source) - .with_context(|| format!("failed to write eval runner script {}", path.display()))?; - } - Ok(path) + js_runner::materialize_runner_script(cache_dir, file_name, source) } #[derive(Debug)] diff --git a/src/js_runner.rs b/src/js_runner.rs index cd4c94dc..31bba482 100644 --- a/src/js_runner.rs +++ b/src/js_runner.rs @@ -1,4 +1,4 @@ -use std::ffi::OsStr; +use std::ffi::{OsStr, OsString}; use std::path::{Path, PathBuf}; use std::process::Command; @@ -105,16 +105,20 @@ pub fn resolve_js_runner_command(runner: &str, files: &[PathBuf]) -> PathBuf { fn build_deno_command(deno_runner: &OsStr, runner_script: &Path, files: &[PathBuf]) -> Command { let mut command = Command::new(deno_runner); + command.args(deno_runner_args(runner_script, files)); command - .arg("run") - .arg("-A") - .arg("--node-modules-dir=auto") - .arg("--unstable-detect-cjs") - .arg(runner_script); - for file in files { - command.arg(file); - } - command +} + +pub fn deno_runner_args(runner_script: &Path, files: &[PathBuf]) -> Vec { + let mut args = vec![ + OsString::from("run"), + OsString::from("-A"), + OsString::from("--node-modules-dir=auto"), + OsString::from("--unstable-detect-cjs"), + runner_script.as_os_str().to_os_string(), + ]; + args.extend(files.iter().map(|file| file.as_os_str().to_os_string())); + args } fn is_path_like_runner(runner: &str) -> bool { @@ -122,7 +126,7 @@ fn is_path_like_runner(runner: &str) -> bool { path.is_absolute() || runner.contains('/') || runner.contains('\\') || runner.starts_with('.') } -fn is_deno_runner_path(runner: &Path) -> bool { +pub fn is_deno_runner_path(runner: &Path) -> bool { runner .file_name() .and_then(|value| value.to_str()) @@ -130,7 +134,7 @@ fn is_deno_runner_path(runner: &Path) -> bool { .unwrap_or(false) } -fn find_node_module_bin_for_files(binary: &str, files: &[PathBuf]) -> Option { +pub fn find_node_module_bin_for_files(binary: &str, files: &[PathBuf]) -> Option { for root in js_runner_search_roots(files) { if let Some(path) = find_node_module_bin(binary, &root) { return Some(path); @@ -176,7 +180,7 @@ fn find_node_module_bin(binary: &str, start: &Path) -> Option { None } -fn find_binary_in_path(candidates: &[&str]) -> Option 
{ +pub fn find_binary_in_path(candidates: &[&str]) -> Option { let paths = std::env::var_os("PATH")?; for dir in std::env::split_paths(&paths) { for candidate in candidates { @@ -196,6 +200,15 @@ fn find_binary_in_path(candidates: &[&str]) -> Option { None } +pub fn runner_bin_name(runner_command: &Path) -> Option { + let name = runner_command.file_name()?.to_str()?.to_ascii_lowercase(); + Some(name.strip_suffix(".cmd").unwrap_or(&name).to_string()) +} + +pub fn is_ts_node_runner_path(runner_command: &Path) -> bool { + runner_bin_name(runner_command).is_some_and(|name| name == "ts-node" || name == "ts-node-esm") +} + #[cfg(windows)] fn with_windows_extensions(path: &Path) -> [PathBuf; 2] { [path.with_extension("exe"), path.with_extension("cmd")] diff --git a/src/main.rs b/src/main.rs index 1943d69f..2ef531b8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,7 @@ mod args; mod auth; #[allow(dead_code)] mod config; +mod datasets; mod env; #[cfg(unix)] mod eval; @@ -68,6 +69,7 @@ Projects & resources experiments Manage experiments Data & evaluation + datasets Manage datasets eval Run eval files sql Run SQL queries against Braintrust sync Synchronize project logs between Braintrust and local NDJSON files @@ -148,6 +150,8 @@ enum Commands { Functions(CLIArgs), /// Manage experiments Experiments(CLIArgs), + /// Manage datasets + Datasets(CLIArgs), /// Synchronize project logs between Braintrust and local NDJSON files Sync(CLIArgs), /// Local utility commands @@ -179,6 +183,7 @@ impl Commands { Commands::Scorers(cmd) => &cmd.base, Commands::Functions(cmd) => &cmd.base, Commands::Experiments(cmd) => &cmd.base, + Commands::Datasets(cmd) => &cmd.base, Commands::Sync(cmd) => &cmd.base, Commands::Util(cmd) => &cmd.base, Commands::Switch(cmd) => &cmd.base, @@ -204,6 +209,7 @@ impl Commands { Commands::Scorers(cmd) => &mut cmd.base, Commands::Functions(cmd) => &mut cmd.base, Commands::Experiments(cmd) => &mut cmd.base, + Commands::Datasets(cmd) => &mut cmd.base, Commands::Sync(cmd) => &mut cmd.base, Commands::Util(cmd) => &mut cmd.base, Commands::Switch(cmd) => &mut cmd.base, @@ -281,6 +287,7 @@ fn try_main() -> Result<()> { Commands::Scorers(cmd) => scorers::run(cmd.base, cmd.args).await?, Commands::Functions(cmd) => functions::run(cmd.base, cmd.args).await?, Commands::Experiments(cmd) => experiments::run(cmd.base, cmd.args).await?, + Commands::Datasets(cmd) => datasets::run(cmd.base, cmd.args).await?, Commands::Sync(cmd) => sync::run(cmd.base, cmd.args).await?, Commands::Util(cmd) => util_cmd::run(cmd.base, cmd.args).await?, Commands::SelfCommand(cmd) => self_update::run(cmd.base, cmd.args).await?, diff --git a/src/sync.rs b/src/sync.rs index 72711854..43945a00 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -27,6 +27,8 @@ use crate::http::ApiClient; use crate::projects::api::{create_project, list_projects, Project}; use crate::ui::{animations_enabled, fuzzy_select, is_quiet}; +pub(crate) mod discovery; + const STATE_SCHEMA_VERSION: u32 = 1; const DEFAULT_PULL_LIMIT: usize = 100; const DEFAULT_PAGE_SIZE: usize = 1000; @@ -1476,8 +1478,7 @@ async fn process_trace_chunk( let serialized = rows .iter() .map(|row| { - let line = - serde_json::to_string(row).context("failed to serialize trace row")?; + let line = serialize_jsonl_value(row)?; bytes_written += (line.len() + 1) as u64; Result::::Ok(line) }) @@ -2057,16 +2058,21 @@ fn run_status(json_output: bool, args: StatusArgs) -> Result<()> { Ok(()) } -async fn execute_btql_query( +async fn execute_btql_request( client: &ApiClient, ctx: &LoginContext, - 
query: &str, + query: &Q, + query_source: &str, btql_retry_tracker: Option>, -) -> Result { +) -> Result +where + Q: Serialize + ?Sized, + T: DeserializeOwned, +{ let body = json!({ "query": query, "fmt": "json", - "query_source": "bt_sync_9f4b1e6d7c2a4a7b8d4f9a6c2b1e7f3d", + "query_source": query_source, }); let org_name = ctx.login.org_name().unwrap_or_default(); let client = client.clone(); @@ -2101,7 +2107,7 @@ async fn execute_btql_query( Ok(response) => { let status = response.status(); if status.is_success() { - return response.json::().await.map_err(|err| { + return response.json::().await.map_err(|err| { BackoffError::permanent(anyhow!("failed to parse BTQL response: {err}")) }); } @@ -2148,6 +2154,31 @@ async fn execute_btql_query( }) } +async fn execute_btql_query( + client: &ApiClient, + ctx: &LoginContext, + query: &str, + btql_retry_tracker: Option>, +) -> Result { + execute_btql_request( + client, + ctx, + query, + "bt_sync_9f4b1e6d7c2a4a7b8d4f9a6c2b1e7f3d", + btql_retry_tracker, + ) + .await +} + +async fn execute_btql_json_query( + client: &ApiClient, + ctx: &LoginContext, + query: &Value, + query_source: &str, +) -> Result { + execute_btql_request(client, ctx, query, query_source, None).await +} + async fn execute_btql_query_timed( client: &ApiClient, ctx: &LoginContext, @@ -3728,8 +3759,68 @@ fn open_jsonl_part_writer(base_dir: &Path, append: bool) -> Result Result> { + let file = File::open(path).with_context(|| format!("failed to open {}", path.display()))?; + let reader = BufReader::new(file); + let mut values = Vec::new(); + for (index, line) in reader.lines().enumerate() { + let line = line.with_context(|| { + format!("failed to read line {} from {}", index + 1, path.display()) + })?; + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + values.push(serde_json::from_str(trimmed).with_context(|| { + format!( + "failed to parse JSON on line {} from {}", + index + 1, + path.display() + ) + })?); + } + Ok(values) +} + +fn serialize_jsonl_value(value: &T) -> Result { + serde_json::to_string(value).context("failed to serialize row to JSONL") +} + +pub(crate) fn write_jsonl_value( + writer: &mut dyn Write, + value: &T, +) -> Result { + let encoded = serialize_jsonl_value(value)?; + writer + .write_all(encoded.as_bytes()) + .context("failed to write JSONL row")?; + writer + .write_all(b"\n") + .context("failed to write JSONL newline")?; + Ok(encoded.len() + 1) +} + +pub(crate) fn write_jsonl_values(out: Option<&Path>, values: &[T]) -> Result<()> { + let mut writer: Box = if let Some(path) = out { + if let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) { + fs::create_dir_all(parent) + .with_context(|| format!("failed to create {}", parent.display()))?; + } + Box::new(BufWriter::new(File::create(path).with_context(|| { + format!("failed to create {}", path.display()) + })?)) + } else { + Box::new(BufWriter::new(std::io::stdout())) + }; + + for value in values { + write_jsonl_value(writer.as_mut(), value)?; + } + writer.flush().context("failed to flush JSONL output") +} + fn write_jsonl_row(writer: &mut JsonlPartWriter, row: &Map) -> Result { - let encoded = serde_json::to_string(row).context("failed to serialize row to JSONL")?; + let encoded = serialize_jsonl_value(row)?; writer .write_line(&encoded) .context("failed to write JSONL row") @@ -4288,6 +4379,63 @@ fn spinner_bar(message: &str) -> ProgressBar { mod tests { use super::*; + #[test] + fn write_jsonl_value_serializes_one_line_and_reports_bytes() -> Result<()> { + let mut output 
= Vec::new(); + + let bytes = write_jsonl_value(&mut output, &json!({ "id": "row-1" }))?; + + assert_eq!(bytes, output.len()); + assert_eq!(String::from_utf8(output)?, "{\"id\":\"row-1\"}\n"); + Ok(()) + } + + #[test] + fn read_jsonl_values_skips_blank_lines() -> Result<()> { + let unique = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_nanos()) + .unwrap_or_default(); + let path = std::env::temp_dir().join(format!( + "bt-sync-read-jsonl-values-{}-{}.jsonl", + std::process::id(), + unique + )); + + fs::write(&path, "{\"id\":\"row-1\"}\n\n{\"id\":\"row-2\"}\n")?; + let values = read_jsonl_values(&path)?; + + assert_eq!( + values, + vec![json!({ "id": "row-1" }), json!({ "id": "row-2" })] + ); + let _ = fs::remove_file(&path); + Ok(()) + } + + #[test] + fn write_jsonl_values_writes_file() -> Result<()> { + let unique = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_nanos()) + .unwrap_or_default(); + let path = std::env::temp_dir().join(format!( + "bt-sync-write-jsonl-values-{}-{}.jsonl", + std::process::id(), + unique + )); + + write_jsonl_values( + Some(&path), + &[json!({ "id": "row-1" }), json!({ "id": "row-2" })], + )?; + + let content = fs::read_to_string(&path)?; + assert_eq!(content, "{\"id\":\"row-1\"}\n{\"id\":\"row-2\"}\n"); + let _ = fs::remove_file(&path); + Ok(()) + } + #[test] fn push_checkpoint_line_offset_advances_only_after_commit() { let mut state = diff --git a/src/sync/discovery.rs b/src/sync/discovery.rs new file mode 100644 index 00000000..8be0b378 --- /dev/null +++ b/src/sync/discovery.rs @@ -0,0 +1,223 @@ +use std::collections::HashSet; + +use anyhow::Result; +use serde::Deserialize; +use serde_json::{json, Map, Value}; + +use crate::auth::LoginContext; +use crate::http::ApiClient; +use crate::sync::execute_btql_json_query; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum ProjectLogRefScope { + Trace, + Span, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) struct ProjectLogRef { + pub(crate) root_span_id: String, + pub(crate) id: Option, +} + +impl ProjectLogRef { + pub(crate) fn to_value(&self) -> Value { + match self.id.as_deref() { + Some(id) => json!({ "root_span_id": self.root_span_id, "id": id }), + None => json!({ "root_span_id": self.root_span_id }), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct ProjectLogRefDiscoveryResult { + pub(crate) refs: usize, + pub(crate) pages: usize, +} + +#[derive(Debug, Deserialize)] +struct DiscoveryBtqlResponse { + data: Vec>, + #[serde(default)] + cursor: Option, +} + +pub(crate) async fn discover_project_log_refs( + client: &ApiClient, + ctx: &LoginContext, + project_id: &str, + filter: Option<&Value>, + scope: ProjectLogRefScope, + target: usize, + page_size: usize, + mut on_ref: F, +) -> Result +where + F: FnMut(ProjectLogRef) -> Result<()>, +{ + let mut seen = HashSet::new(); + let mut cursor: Option = None; + let mut pages = 0usize; + while seen.len() < target { + let limit = discovery_page_limit(scope, target - seen.len(), page_size); + let query = + build_project_log_ref_query(project_id, filter, limit, cursor.as_deref(), scope); + let response = execute_discovery_btql(client, ctx, &query).await?; + let row_count = response.data.len(); + + for row in response.data { + if seen.len() >= target { + break; + } + let Some(reference) = project_log_ref_from_row(&row, scope) else { + continue; + }; + if seen.insert(reference.clone()) { + on_ref(reference)?; + } + } + + pages += 1; + cursor = response.cursor.filter(|c| !c.is_empty()); 
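+        // Stop paging once the server returns an empty page or omits a continuation
+        // cursor; otherwise an exhausted result set would loop forever.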
+ if row_count == 0 || cursor.is_none() { + break; + } + } + Ok(ProjectLogRefDiscoveryResult { + refs: seen.len(), + pages, + }) +} + +fn discovery_page_limit(scope: ProjectLogRefScope, remaining: usize, page_size: usize) -> usize { + match scope { + ProjectLogRefScope::Trace => page_size.min(1000), + ProjectLogRefScope::Span => remaining.min(page_size).min(1000), + } +} + +async fn execute_discovery_btql( + client: &ApiClient, + ctx: &LoginContext, + query: &Value, +) -> Result { + execute_btql_json_query(client, ctx, query, "bt_sync_discovery").await +} + +fn build_project_log_ref_query( + project_id: &str, + filter: Option<&Value>, + page_size: usize, + cursor: Option<&str>, + scope: ProjectLogRefScope, +) -> Value { + let select = match scope { + ProjectLogRefScope::Trace => vec![btql_select_field("root_span_id")], + ProjectLogRefScope::Span => { + vec![btql_select_field("root_span_id"), btql_select_field("id")] + } + }; + + let mut query = json!({ + "select": select, + "from": { + "op": "function", + "name": { "op": "ident", "name": ["project_logs"] }, + "args": [{ "op": "literal", "value": project_id }], + "shape": "spans" + }, + "limit": page_size, + "sort": [{ + "expr": { "op": "ident", "name": ["_pagination_key"] }, + "dir": "desc" + }] + }); + + if let Some(filter_expr) = filter { + query["filter"] = filter_expr.clone(); + } + if let Some(c) = cursor { + query["cursor"] = Value::String(c.to_string()); + } + query +} + +fn project_log_ref_from_row( + row: &Map, + scope: ProjectLogRefScope, +) -> Option { + let root_span_id = row_string(row, "root_span_id")?; + match scope { + ProjectLogRefScope::Trace => Some(ProjectLogRef { + root_span_id, + id: None, + }), + ProjectLogRefScope::Span => Some(ProjectLogRef { + root_span_id, + id: Some(row_string(row, "id")?), + }), + } +} + +fn btql_select_field(field: &str) -> Value { + json!({ + "alias": field, + "expr": { "op": "ident", "name": [field] } + }) +} + +fn row_string(row: &Map, key: &str) -> Option { + row.get(key) + .and_then(Value::as_str) + .map(ToString::to_string) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn project_log_ref_from_row_uses_trace_scope() { + let row = Map::from_iter([ + ("root_span_id".to_string(), json!("root-1")), + ("id".to_string(), json!("span-1")), + ]); + + assert_eq!( + project_log_ref_from_row(&row, ProjectLogRefScope::Trace), + Some(ProjectLogRef { + root_span_id: "root-1".to_string(), + id: None, + }) + ); + } + + #[test] + fn project_log_ref_from_row_uses_span_scope() { + let row = Map::from_iter([ + ("root_span_id".to_string(), json!("root-1")), + ("id".to_string(), json!("span-1")), + ]); + + assert_eq!( + project_log_ref_from_row(&row, ProjectLogRefScope::Span), + Some(ProjectLogRef { + root_span_id: "root-1".to_string(), + id: Some("span-1".to_string()), + }) + ); + } + + #[test] + fn span_scope_page_limit_uses_remaining_target() { + assert_eq!(discovery_page_limit(ProjectLogRefScope::Span, 3, 1000), 3); + } + + #[test] + fn trace_scope_page_limit_keeps_full_page_for_dedupe() { + assert_eq!( + discovery_page_limit(ProjectLogRefScope::Trace, 3, 1000), + 1000 + ); + } +} From 52ebd4b31372f8e437988d97c4e479a3610e9558 Mon Sep 17 00:00:00 2001 From: Ankur Goyal Date: Fri, 1 May 2026 12:03:20 -0400 Subject: [PATCH 2/7] consolidate --- src/datasets/api.rs | 13 ++ src/datasets/pipeline.rs | 263 +++++++++++---------------------------- src/datasets/records.rs | 6 + 3 files changed, 93 insertions(+), 189 deletions(-) diff --git a/src/datasets/api.rs b/src/datasets/api.rs index 
fbb4d6e1..dd4e5125 100644 --- a/src/datasets/api.rs +++ b/src/datasets/api.rs @@ -186,6 +186,16 @@ pub async fn create_dataset( project_id: &str, name: &str, description: Option<&str>, +) -> Result { + create_dataset_with_metadata(client, project_id, name, description, None).await +} + +pub async fn create_dataset_with_metadata( + client: &ApiClient, + project_id: &str, + name: &str, + description: Option<&str>, + metadata: Option<&Value>, ) -> Result { let mut body = serde_json::json!({ "name": name, @@ -195,6 +205,9 @@ pub async fn create_dataset( if let Some(description) = description.filter(|description| !description.is_empty()) { body["description"] = serde_json::Value::String(description.to_string()); } + if let Some(metadata) = metadata { + body["metadata"] = metadata.clone(); + } client.post("/v1/dataset", &body).await } diff --git a/src/datasets/pipeline.rs b/src/datasets/pipeline.rs index 6c678106..5bfb369b 100644 --- a/src/datasets/pipeline.rs +++ b/src/datasets/pipeline.rs @@ -4,12 +4,10 @@ use std::path::{Path, PathBuf}; use std::process::{Command, Stdio}; use anyhow::{bail, Context, Result}; -use braintrust_sdk_rust::Logs3BatchUploader; use clap::{Args, Subcommand}; use serde::de::DeserializeOwned; -use serde::{Deserialize, Serialize}; -use serde_json::{json, Map, Value}; -use urlencoding::encode; +use serde::Deserialize; +use serde_json::{json, Value}; use crate::args::BaseArgs; use crate::auth::{login, resolved_runner_env, LoginContext}; @@ -19,6 +17,8 @@ use crate::projects::api::{create_project, get_project_by_name, Project}; use crate::sync::discovery::{discover_project_log_refs, ProjectLogRefScope}; use crate::sync::{read_jsonl_values, write_jsonl_value, write_jsonl_values}; +use super::{api as datasets_api, records, utils, ResolvedContext}; + const RUNNER_FILE: &str = "dataset-pipeline-runner.ts"; const RUNNER_SOURCE: &str = include_str!("../../scripts/dataset-pipeline-runner.ts"); @@ -242,68 +242,6 @@ impl PipelineScope { } } -#[derive(Debug, Clone, Deserialize)] -struct NamedObject { - id: String, - name: String, -} - -#[derive(Debug, Clone, Deserialize)] -struct CreatedDataset { - id: String, -} - -#[derive(Debug, Deserialize)] -struct NamedObjectListResponse { - objects: Vec, -} - -#[derive(Debug, Deserialize)] -#[serde(deny_unknown_fields)] -struct DatasetPipelineRow { - id: Option, - input: Option, - expected: Option, - output: Option, - tags: Option>, - metadata: Option>, - origin: Option, -} - -#[derive(Debug, Deserialize, Serialize)] -#[serde(deny_unknown_fields)] -struct DatasetPipelineObjectReference { - object_type: String, - object_id: String, - id: String, - #[serde(rename = "_xact_id", skip_serializing_if = "Option::is_none")] - xact_id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - created: Option, -} - -#[derive(Debug, Serialize)] -struct DatasetPipelineUploadRow { - project_id: String, - dataset_id: String, - id: String, - span_id: String, - root_span_id: String, - span_parents: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - input: Option, - #[serde(skip_serializing_if = "Option::is_none")] - expected: Option, - #[serde(skip_serializing_if = "Option::is_none")] - output: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tags: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - metadata: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - origin: Option, -} - #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] struct PipelineTransformResponse { @@ -476,78 +414,53 @@ async fn 
upload_dataset_rows( target: &PipelineTargetInspect, rows: Vec, ) -> Result { - let mut target_base = base.clone(); - if let Some(org_name) = target.org_name.as_deref() { - target_base.org_name = Some(org_name.to_string()); - } - let ctx = login(&target_base).await?; - let client = ApiClient::new(&ctx)?; - let project = resolve_target_project(&client, target).await?; - let dataset = resolve_target_dataset(&client, target, &project).await?; - let upload_run_id = chrono::Utc::now().timestamp_millis().to_string(); + let target_ctx = resolve_target_context(base, target).await?; + let dataset = resolve_target_dataset(&target_ctx.client, target, &target_ctx.project).await?; + let records = prepare_pipeline_records(rows)?; + let inserted = records.len(); + + utils::submit_prepared_records( + &target_ctx, + &dataset.id, + &records, + false, + "Uploading dataset rows...", + "dataset pipeline upload failed", + ) + .await?; - let mut prepared_rows = Vec::with_capacity(rows.len()); + Ok(inserted) +} + +fn prepare_pipeline_records(rows: Vec) -> Result> { + let mut objects = Vec::with_capacity(rows.len()); for (index, row) in rows.into_iter().enumerate() { - let row: DatasetPipelineRow = serde_json::from_value(row).with_context(|| { - format!("dataset pipeline row {index} does not match the expected dataset row schema") - })?; - let row = - prepare_dataset_row_for_upload(row, &project.id, &dataset.id, &upload_run_id, index); - prepared_rows.push(upload_row_to_map(row)?); + match row { + Value::Object(row) => objects.push(row), + _ => bail!("dataset pipeline row {} must be a JSON object", index + 1), + } } - let mut uploader = Logs3BatchUploader::new( - ctx.api_url.clone(), - ctx.login - .api_key() - .context("login state missing API key for dataset pipeline upload")?, - ctx.login - .org_name() - .filter(|org_name| !org_name.trim().is_empty()), - ) - .context("failed to initialize dataset pipeline uploader")?; - for chunk in prepared_rows.chunks(1000) { - uploader - .upload_rows(chunk, 1000) - .await - .map_err(|err| anyhow::anyhow!("dataset pipeline upload failed: {err}"))?; - } - Ok(prepared_rows.len()) -} - -fn prepare_dataset_row_for_upload( - row: DatasetPipelineRow, - project_id: &str, - dataset_id: &str, - upload_run_id: &str, - row_index: usize, -) -> DatasetPipelineUploadRow { - let id = row - .id - .clone() - .unwrap_or_else(|| format!("dataset-pipeline-{upload_run_id}-{row_index}")); - - DatasetPipelineUploadRow { - project_id: project_id.to_string(), - dataset_id: dataset_id.to_string(), - span_id: id.clone(), - root_span_id: id.clone(), - id, - span_parents: Vec::new(), - input: row.input, - expected: row.expected, - output: row.output, - tags: row.tags, - metadata: row.metadata, - origin: row.origin, - } + records::prepare_upload_records(objects) + .context("dataset pipeline transform produced invalid dataset rows") } -fn upload_row_to_map(row: DatasetPipelineUploadRow) -> Result> { - match serde_json::to_value(row).context("failed to serialize dataset pipeline upload row")? 
{ - Value::Object(row) => Ok(row), - _ => bail!("serialized dataset pipeline upload row was not an object"), +async fn resolve_target_context( + base: &BaseArgs, + target: &PipelineTargetInspect, +) -> Result { + let mut target_base = base.clone(); + if let Some(org_name) = target.org_name.as_deref() { + target_base.org_name = Some(org_name.to_string()); } + let ctx = login(&target_base).await?; + let client = ApiClient::new(&ctx)?; + let project = resolve_target_project(&client, target).await?; + Ok(ResolvedContext { + client, + app_url: ctx.app_url, + project, + }) } async fn resolve_target_project( @@ -582,20 +495,18 @@ async fn resolve_target_dataset( client: &ApiClient, target: &PipelineTargetInspect, project: &Project, -) -> Result { +) -> Result { let dataset_name = target.dataset_name.trim(); if dataset_name.is_empty() { bail!("dataset pipeline target.datasetName cannot be empty"); } - let objects = list_project_datasets(client, &project.id).await?; - if let Some(dataset) = objects + let datasets = datasets_api::list_datasets(client, &project.id).await?; + if let Some(dataset) = datasets .iter() - .find(|object| object.id == dataset_name || object.name == dataset_name) + .find(|dataset| dataset.id == dataset_name || dataset.name == dataset_name) { - return Ok(CreatedDataset { - id: dataset.id.clone(), - }); + return Ok(dataset.clone()); } if is_uuid_like(dataset_name) { @@ -606,41 +517,15 @@ async fn resolve_target_dataset( ); } - create_dataset(client, &project.id, target) - .await - .with_context(|| format!("dataset '{dataset_name}' not found, and creating it failed")) -} - -async fn list_project_datasets(client: &ApiClient, project_id: &str) -> Result> { - let path = format!( - "/v1/dataset?org_name={}&project_id={}", - encode(client.org_name()), - encode(project_id) - ); - let response: NamedObjectListResponse = client.get(&path).await?; - Ok(response.objects) -} - -async fn create_dataset( - client: &ApiClient, - project_id: &str, - target: &PipelineTargetInspect, -) -> Result { - let mut body = json!({ - "name": target.dataset_name.clone(), - "project_id": project_id, - "org_name": client.org_name(), - }); - if let (Value::Object(body), Some(description)) = (&mut body, target.description.as_deref()) { - body.insert( - "description".to_string(), - Value::String(description.to_string()), - ); - } - if let (Value::Object(body), Some(metadata)) = (&mut body, target.metadata.as_ref()) { - body.insert("metadata".to_string(), metadata.clone()); - } - client.post("/v1/dataset", &body).await + datasets_api::create_dataset_with_metadata( + client, + &project.id, + dataset_name, + target.description.as_deref(), + target.metadata.as_ref(), + ) + .await + .with_context(|| format!("dataset '{dataset_name}' not found, and creating it failed")) } async fn discover_refs( @@ -850,19 +735,21 @@ mod tests { use super::*; #[test] - fn dataset_pipeline_row_rejects_unknown_fields() { - let err = serde_json::from_value::(json!({ + fn prepare_pipeline_records_reuses_dataset_record_validation() { + let err = prepare_pipeline_records(vec![json!({ "input": "hello", "span_attributes": { "type": "llm" }, - })) + })]) .expect_err("unexpected dataset row fields should be rejected"); - assert!(err.to_string().contains("unknown field")); + assert!(err + .to_string() + .contains("dataset pipeline transform produced invalid dataset rows")); } #[test] - fn prepare_dataset_row_for_upload_uses_typed_schema() { - let row = serde_json::from_value::(json!({ + fn prepare_pipeline_records_uses_dataset_record_schema() { 
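+        // A well-formed pipeline row should round-trip through prepare_pipeline_records
+        // and yield a dataset upload row without span bookkeeping fields.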
+ let records = prepare_pipeline_records(vec![json!({ "id": "row-1", "input": { "question": "hello" }, "expected": "world", @@ -873,20 +760,18 @@ mod tests { "object_id": "source-project", "id": "source-span" } - })) + })]) .expect("valid dataset pipeline row should deserialize"); - let upload = - prepare_dataset_row_for_upload(row, "target-project", "target-dataset", "run", 0); - let upload = upload_row_to_map(upload).expect("upload row should serialize"); - + assert_eq!(records.len(), 1); + assert_eq!(records[0].id, "row-1"); + let upload = records[0].to_upload_row("target-dataset", false); assert_eq!(upload.get("id"), Some(&json!("row-1"))); - assert_eq!(upload.get("span_id"), Some(&json!("row-1"))); - assert_eq!(upload.get("root_span_id"), Some(&json!("row-1"))); - assert_eq!(upload.get("project_id"), Some(&json!("target-project"))); assert_eq!(upload.get("dataset_id"), Some(&json!("target-dataset"))); - assert!(!upload.contains_key("log_id")); - assert!(!upload.contains_key("experiment_id")); + assert_eq!(upload.get("expected"), Some(&json!("world"))); + assert!(!upload.contains_key("span_id")); + assert!(!upload.contains_key("root_span_id")); + assert!(!upload.contains_key("project_id")); } #[test] diff --git a/src/datasets/records.rs b/src/datasets/records.rs index de84999a..3d9daa09 100644 --- a/src/datasets/records.rs +++ b/src/datasets/records.rs @@ -95,6 +95,12 @@ pub(crate) fn load_refresh_records( prepare_records(raw, id_field, true) } +pub(crate) fn prepare_upload_records( + raw_records: Vec>, +) -> Result> { + prepare_records(raw_records, "id", false) +} + fn load_required_record_objects( input_path: Option<&Path>, inline_rows: Option<&str>, From 5c39933b5b12b4d5f7f3fdc0a49bd0bfdd1fa202 Mon Sep 17 00:00:00 2001 From: Ankur Goyal Date: Fri, 1 May 2026 14:02:23 -0400 Subject: [PATCH 3/7] rm --- src/datasets/pipeline.rs | 2 +- src/datasets/records.rs | 8 +------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/datasets/pipeline.rs b/src/datasets/pipeline.rs index 5bfb369b..73529b49 100644 --- a/src/datasets/pipeline.rs +++ b/src/datasets/pipeline.rs @@ -441,7 +441,7 @@ fn prepare_pipeline_records(rows: Vec) -> Result>, -) -> Result> { - prepare_records(raw_records, "id", false) -} - fn load_required_record_objects( input_path: Option<&Path>, inline_rows: Option<&str>, @@ -212,7 +206,7 @@ fn expect_record_object(value: Value, record_number: Option) -> Result>, id_field: &str, require_ids: bool, From 3cbce325a8551fb941df445ac8d8236acd6e781d Mon Sep 17 00:00:00 2001 From: Ankur Goyal Date: Sun, 3 May 2026 11:43:08 -0400 Subject: [PATCH 4/7] more progress --- README.md | 26 +- scripts/dataset-pipeline-runner.py | 451 +++++++++ scripts/dataset-pipeline-runner.ts | 169 +++- src/datasets/mod.rs | 2 +- src/datasets/pipeline.rs | 1423 ++++++++++++++++++++++++---- src/eval.rs | 244 ++--- src/main.rs | 1 + src/runner_sse.rs | 173 ++++ src/sync.rs | 90 +- 9 files changed, 2197 insertions(+), 382 deletions(-) create mode 100644 scripts/dataset-pipeline-runner.py create mode 100644 src/runner_sse.rs diff --git a/README.md b/README.md index 832a1152..341a2ad5 100644 --- a/README.md +++ b/README.md @@ -180,23 +180,33 @@ bt eval foo.eval.ts -- --description "Prod" --shard=1/4 ### `bt datasets pipeline` -Run TypeScript dataset pipelines declared with `DatasetPipeline(...)` from the `braintrust` SDK. +Run full dataset pipelines declared with `DatasetPipeline(...)`, or stage fetch/transform/push. 
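+
+A pipeline file exports a `DatasetPipeline(...)` definition. The sketch below is illustrative rather than canonical: the field names mirror what the bundled runner script reads (`source.projectName`/`filter`/`scope`, `target.projectName`/`datasetName`, and a `transform` that receives `{ input, output, expected, metadata, trace }`), and the project, dataset, and filter values are placeholders.
+
+```ts
+import { DatasetPipeline } from "braintrust";
+
+export const pipeline = DatasetPipeline({
+  name: "prod-errors-to-dataset",
+  source: {
+    projectName: "my-app",           // project whose logs are scanned
+    filter: "metadata.env = 'prod'", // illustrative filter expression
+    scope: "span",                   // "span" or "trace"
+  },
+  target: {
+    projectName: "my-app",
+    datasetName: "prod-errors",
+  },
+  // Each discovered candidate is passed to transform; return one row, an
+  // array of rows, or null/undefined to skip it.
+  transform: async ({ input, output, expected, metadata }) => ({
+    input,
+    expected: expected ?? output,
+    metadata,
+  }),
+});
+```
+
+Run or stage a pipeline like this with the commands below.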
```bash # One-shot execution: discover refs, transform, and insert up to 100 new rows. -bt datasets pipeline run ./pipeline.ts --target 100 +bt datasets pipeline run ./pipeline.ts --limit 100 -# Staged execution for human or agent review. -bt datasets pipeline fetch ./pipeline.ts --target 500 --out refs.jsonl -bt datasets pipeline transform ./pipeline.ts --in refs.jsonl --out proposed.jsonl -bt datasets pipeline review ./pipeline.ts --in proposed.jsonl --out approved.jsonl -bt datasets pipeline commit ./pipeline.ts --in approved.jsonl +# Staged execution for inspection or agent editing. +bt datasets pipeline fetch ./pipeline.ts --limit 500 +bt datasets pipeline transform ./pipeline.ts +# Inspect or edit the transformed JSONL, then push to the pipeline target. +bt datasets pipeline push ./pipeline.ts + +# Python pipelines are supported too. +bt datasets pipeline run ./pipeline.py --project "" --limit 100 ``` Useful flags: +- `--limit ` controls how many source refs to discover. - `--root-span-id ` restricts fetching to one or more specific root spans. -- `--extra-where-sql ` appends a source SQL predicate. +- `--root ` controls where staged artifacts are written; it defaults to `bt-sync`. A staged run writes `fetched.jsonl` and `transformed.jsonl` in the same managed directory. +- `--out` can override the managed output path for `fetch` and `transform`. +- `--in` can override the latest fetch artifact for `transform`, or the latest transform artifact for `push`. +- `push` reads the target from the pipeline and delegates to `bt sync push`; pass `--fresh` to restart an already completed push spec. +- `--project ` supplies the active source project when the pipeline source omits a project. +- `--source-project`, `--source-project-id`, `--source-org`, and `--source-filter` explicitly override source fields on `fetch`, `transform`, and `run`. +- `--target-project`, `--target-project-id`, `--target-org`, and `--target-dataset` override target fields on `run` and `push`. - `--max-concurrency ` controls transform concurrency. - `--name ` selects a pipeline when the file defines more than one. diff --git a/scripts/dataset-pipeline-runner.py b/scripts/dataset-pipeline-runner.py new file mode 100644 index 00000000..bfefd218 --- /dev/null +++ b/scripts/dataset-pipeline-runner.py @@ -0,0 +1,451 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import asyncio +import importlib.util +import json +import os +import socket +import sys +import traceback +from typing import Any + +try: + import braintrust + from braintrust.framework import call_user_fn + from braintrust.logger import _internal_get_global_state, login_to_state + from braintrust.trace import LocalTrace +except Exception as exc: # pragma: no cover - runtime guard + print( + "Unable to import the braintrust package. 
Please install it in your Python environment.", + file=sys.stderr, + ) + print(str(exc), file=sys.stderr) + sys.exit(1) + + +SOURCE_KEY_MAP = { + "project_id": "projectId", + "project_name": "projectName", + "org_name": "orgName", +} +TARGET_KEY_MAP = { + "project_id": "projectId", + "project_name": "projectName", + "org_name": "orgName", + "dataset_name": "datasetName", +} + + +class SseWriter: + def __init__(self, sock_path: str): + self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self._socket.connect(sock_path) + + def send(self, event: str, payload: Any) -> None: + data = payload if isinstance(payload, str) else json.dumps(payload, separators=(",", ":")) + frame = f"event: {event}\ndata: {data}\n\n".encode("utf-8") + self._socket.sendall(frame) + + def close(self) -> None: + self._socket.close() + + +def create_sse_writer() -> SseWriter | None: + sock_path = os.getenv("BT_DATASET_PIPELINE_SSE_SOCK") + if not sock_path: + return None + try: + return SseWriter(sock_path) + except Exception as exc: + print(f"Failed to connect to dataset pipeline socket: {exc}", file=sys.stderr) + return None + + +def camelize_mapping(value: Any, key_map: dict[str, str]) -> Any: + if not isinstance(value, dict): + return value + return { + key_map.get(key, key): camelize_mapping(item, key_map) + for key, item in value.items() + } + + +def object_get(value: Any, name: str) -> Any: + if isinstance(value, dict): + return value.get(name) + return getattr(value, name, None) + + +def pipeline_source(pipeline: Any) -> dict[str, Any]: + source = object_get(pipeline, "source") + if not isinstance(source, dict): + raise RuntimeError("Dataset pipeline source is required.") + return source + + +def pipeline_transform(pipeline: Any) -> Any: + transform = object_get(pipeline, "transform") + if not callable(transform): + raise RuntimeError("Dataset pipeline transform must be callable.") + return transform + + +def load_pipeline_file(file: str) -> Any: + absolute = os.path.abspath(file) + cwd = os.getcwd() + file_dir = os.path.dirname(absolute) + for path in (file_dir, cwd): + if path and path not in sys.path: + sys.path.insert(0, path) + + module_name = f"_bt_dataset_pipeline_{abs(hash(absolute))}" + spec = importlib.util.spec_from_file_location(module_name, absolute) + if spec is None or spec.loader is None: + raise RuntimeError(f"Unable to load {file}.") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def is_pipeline(value: Any) -> bool: + checker = getattr(braintrust, "is_dataset_pipeline_definition", None) + if callable(checker) and checker(value): + return True + return ( + object_get(value, "source") is not None + and object_get(value, "target") is not None + and callable(object_get(value, "transform")) + ) + + +def collect_pipelines(module: Any) -> list[Any]: + pipelines: list[Any] = [] + seen: set[int] = set() + + registered = getattr(braintrust, "get_registered_dataset_pipelines", None) + if callable(registered): + for pipeline in registered(): + if id(pipeline) not in seen: + seen.add(id(pipeline)) + pipelines.append(pipeline) + + for value in vars(module).values(): + if is_pipeline(value) and id(value) not in seen: + seen.add(id(value)) + pipelines.append(value) + + if is_pipeline(module) and id(module) not in seen: + pipelines.append(module) + + return pipelines + + +def select_pipeline(pipelines: list[Any], name: str | None) -> Any: + if name: + matches = [ + pipeline + for pipeline in pipelines + if 
object_get(pipeline, "name") == name + ] + if not matches: + raise RuntimeError(f"No dataset pipeline named {json.dumps(name)} found.") + if len(matches) > 1: + raise RuntimeError( + f"Multiple dataset pipelines named {json.dumps(name)} found." + ) + return matches[0] + + if not pipelines: + raise RuntimeError("No dataset pipelines found. Did you call DatasetPipeline()?") + if len(pipelines) > 1: + names = ", ".join( + object_get(pipeline, "name") or "" for pipeline in pipelines + ) + raise RuntimeError(f"Multiple dataset pipelines found ({names}). Pass --name.") + return pipelines[0] + + +def parse_stage() -> str: + stage = os.getenv("BT_DATASET_PIPELINE_STAGE") + if stage in {"inspect", "transform"}: + return stage + raise RuntimeError("BT_DATASET_PIPELINE_STAGE must be inspect or transform.") + + +def read_request() -> dict[str, Any]: + text = sys.stdin.read().strip() + if not text: + return {} + value = json.loads(text) + if not isinstance(value, dict): + raise RuntimeError("Dataset pipeline runner request must be an object.") + return value + + +def write_response(value: Any, sse: SseWriter | None) -> None: + if sse is not None: + sse.send("response", value) + sse.close() + else: + print(json.dumps(value, separators=(",", ":"))) + + +def write_progress(sse: SseWriter | None, rows: int) -> None: + if sse is None: + return + sse.send( + "progress", + { + "type": "dataset_pipeline_progress", + "kind": "candidate", + "rows": rows, + }, + ) + + +def require_array_field(request: dict[str, Any], field: str) -> list[Any]: + value = request.get(field) + if not isinstance(value, list): + raise RuntimeError(f"Request field {field} must be an array.") + return value + + +def require_string_field(request: dict[str, Any], field: str) -> str: + value = request.get(field) + if not isinstance(value, str): + raise RuntimeError(f"Request field {field} must be a string.") + return value + + +def optional_positive_integer_field(request: dict[str, Any], field: str) -> int | None: + value = request.get(field) + if value is None: + return None + if not isinstance(value, int) or value <= 0: + raise RuntimeError(f"Request field {field} must be a positive integer.") + return value + + +def merged_source(pipeline: Any, source_override: Any) -> dict[str, Any]: + source = camelize_mapping(pipeline_source(pipeline), SOURCE_KEY_MAP) + if isinstance(source_override, dict): + return {**source, **source_override} + return source + + +def state_for_org(org_name: str | None) -> Any: + state = _internal_get_global_state() + if not org_name: + state.login() + return state + if not getattr(state, "logged_in", False): + state.login(org_name=org_name) + return state + if getattr(state, "org_name", None) == org_name: + return state + return login_to_state(org_name=org_name) + + +def ref_root_span_id(ref: Any) -> str: + if not isinstance(ref, dict) or not isinstance(ref.get("root_span_id"), str): + raise RuntimeError("Discovery ref is missing root_span_id.") + return ref["root_span_id"] + + +def ref_span_row_id(ref: Any) -> str | None: + if isinstance(ref, dict) and isinstance(ref.get("id"), str): + return ref["id"] + return None + + +def hydrate_discovery_refs( + pipeline: Any, + source_override: Any, + source_project_id: str, + refs: list[Any], +) -> list[dict[str, Any]]: + source = merged_source(pipeline, source_override) + state = state_for_org(source.get("orgName")) + candidates: list[dict[str, Any]] = [] + traces_by_root_span_id: dict[str, LocalTrace] = {} + for ref in refs: + root_span_id = ref_root_span_id(ref) + row_id 
= ref_span_row_id(ref) + trace = traces_by_root_span_id.get(root_span_id) + if trace is None: + trace = LocalTrace( + object_type="project_logs", + object_id=source_project_id, + root_span_id=root_span_id, + ensure_spans_flushed=None, + state=state, + ) + traces_by_root_span_id[root_span_id] = trace + candidate: dict[str, Any] = { + "trace": trace, + } + if row_id: + candidate["id"] = row_id + candidate["origin"] = { + "object_type": "project_logs", + "object_id": source_project_id, + "id": row_id, + } + candidates.append(candidate) + return candidates + + +def span_attr(span: Any, name: str) -> Any: + if isinstance(span, dict): + return span.get(name) + return getattr(span, name, None) + + +async def source_row_for_candidate(candidate: dict[str, Any]) -> Any | None: + row_id = candidate.get("id") + if not isinstance(row_id, str): + return None + + trace = candidate["trace"] + spans = await trace.get_spans(include_scorers=True) + for span in spans: + if row_id in {span_attr(span, "id"), span_attr(span, "span_id")}: + return span + raise RuntimeError(f"Source span row {row_id!r} was not found in hydrated trace.") + + +async def transform_args_for_candidate(candidate: dict[str, Any]) -> dict[str, Any]: + row = await source_row_for_candidate(candidate) + return { + "input": span_attr(row, "input"), + "output": span_attr(row, "output"), + "expected": span_attr(row, "expected"), + "metadata": span_attr(row, "metadata"), + "trace": candidate["trace"], + } + + +def normalize_transform_result(result: Any) -> list[Any]: + if result is None: + return [] + if isinstance(result, list): + return result + return [result] + + +def candidate_fallback_id(candidate: dict[str, Any]) -> str | None: + row_id = candidate.get("id") + if isinstance(row_id, str): + return row_id + trace = candidate.get("trace") + config = trace.get_configuration() if hasattr(trace, "get_configuration") else None + if isinstance(config, dict) and isinstance(config.get("root_span_id"), str): + return config["root_span_id"] + return None + + +def with_pipeline_defaults( + row: Any, + candidate: dict[str, Any], + row_index: int | None, +) -> dict[str, Any]: + if not isinstance(row, dict): + raise RuntimeError("Dataset pipeline transform must return an object row.") + output = dict(row) + fallback_id = candidate_fallback_id(candidate) + if "id" not in output and fallback_id: + output["id"] = fallback_id if row_index is None else f"{fallback_id}:{row_index}" + if "origin" not in output and "origin" in candidate: + output["origin"] = candidate["origin"] + return output + + +async def transform_refs( + pipeline: Any, + source_override: Any, + source_project_id: str, + refs: list[Any], + max_concurrency: int = 16, + sse: SseWriter | None = None, +) -> list[dict[str, Any]]: + if max_concurrency <= 0: + raise RuntimeError("maxConcurrency must be a positive integer.") + transform = pipeline_transform(pipeline) + candidates = hydrate_discovery_refs(pipeline, source_override, source_project_id, refs) + transformed_rows: list[list[dict[str, Any]]] = [[] for _ in candidates] + semaphore = asyncio.Semaphore(max_concurrency) + + async def run_one(index: int, candidate: dict[str, Any]) -> None: + async with semaphore: + transform_args = await transform_args_for_candidate(candidate) + result = await call_user_fn( + asyncio.get_running_loop(), + transform, + **transform_args, + ) + rows = normalize_transform_result(result) + transformed_rows[index] = [ + with_pipeline_defaults( + row, + candidate, + row_index if len(rows) > 1 else None, + ) + for 
row_index, row in enumerate(rows) + ] + write_progress(sse, len(transformed_rows[index])) + + await asyncio.gather( + *(run_one(index, candidate) for index, candidate in enumerate(candidates)) + ) + return [row for rows in transformed_rows for row in rows] + + +async def main() -> None: + if len(sys.argv) < 2: + raise RuntimeError("Pipeline file is required.") + + module = load_pipeline_file(sys.argv[1]) + pipeline = select_pipeline( + collect_pipelines(module), + os.getenv("BT_DATASET_PIPELINE_NAME") or None, + ) + stage = parse_stage() + sse = create_sse_writer() + + if stage == "inspect": + write_response( + { + "name": object_get(pipeline, "name"), + "source": camelize_mapping(object_get(pipeline, "source"), SOURCE_KEY_MAP), + "target": camelize_mapping(object_get(pipeline, "target"), TARGET_KEY_MAP), + }, + sse, + ) + elif stage == "transform": + request = read_request() + refs = require_array_field(request, "refs") + source_project_id = require_string_field(request, "sourceProjectId") + source_override = ( + request.get("source") if isinstance(request.get("source"), dict) else None + ) + rows = await transform_refs( + pipeline, + source_override, + source_project_id, + refs, + optional_positive_integer_field(request, "maxConcurrency") or 16, + sse, + ) + write_response({"candidates": len(refs), "rowCount": len(rows), "rows": rows}, sse) + else: + raise RuntimeError(f"Unsupported dataset pipeline stage: {stage}") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except Exception: + traceback.print_exc(file=sys.stderr) + sys.exit(1) diff --git a/scripts/dataset-pipeline-runner.ts b/scripts/dataset-pipeline-runner.ts index 8ff7c801..4892938d 100644 --- a/scripts/dataset-pipeline-runner.ts +++ b/scripts/dataset-pipeline-runner.ts @@ -1,5 +1,6 @@ import { createRequire } from "node:module"; import fs from "node:fs"; +import net from "node:net"; import { pathToFileURL } from "node:url"; import path from "node:path"; @@ -9,7 +10,6 @@ type PipelineSource = { orgName?: string; filter?: string; scope?: "span" | "trace"; - limit?: number; }; type PipelineTarget = { @@ -26,11 +26,18 @@ type DatasetPipelineDefinition = { source?: PipelineSource; target?: PipelineTarget; transform?: ( - candidate: HydratedCandidate, - context: { pipeline: DatasetPipelineDefinition }, + args: DatasetPipelineTransformArgs, ) => unknown | Promise; }; +type DatasetPipelineTransformArgs = { + input?: unknown; + output?: unknown; + expected?: unknown; + metadata?: unknown; + trace: unknown; +}; + type BraintrustModule = { DatasetPipeline?: ( definition: DatasetPipelineDefinition, @@ -73,6 +80,11 @@ type HydratedCandidate = { type Stage = "inspect" | "transform"; +type SseWriter = { + send: (event: string, payload: unknown) => void; + close: () => void; +}; + function isObject(value: unknown): value is Record { return typeof value === "object" && value !== null; } @@ -229,8 +241,61 @@ async function readRequest(): Promise { return text.length > 0 ? 
JSON.parse(text) : {}; } -function writeResponse(value: unknown): void { - process.stdout.write(`${JSON.stringify(value)}\n`); +function writeResponse(value: unknown, sse: SseWriter | null): void { + if (sse) { + sse.send("response", value); + sse.close(); + } else { + process.stdout.write(`${JSON.stringify(value)}\n`); + } +} + +function serializeSseEvent(event: { event?: string; data: string }): string { + return ( + Object.entries(event) + .filter(([_key, value]) => value !== undefined) + .map(([key, value]) => `${key}: ${value}`) + .join("\n") + "\n\n" + ); +} + +function createSseWriter(): SseWriter | null { + const sock = process.env.BT_DATASET_PIPELINE_SSE_SOCK; + if (!sock) { + return null; + } + const socket = net.createConnection({ path: sock }); + socket.on("error", (err) => { + console.error( + `Failed to connect to dataset pipeline socket: ${ + err instanceof Error ? err.message : String(err) + }`, + ); + }); + return { + send: (event: string, payload: unknown) => { + if (!socket.writable) { + return; + } + const data = + typeof payload === "string" ? payload : JSON.stringify(payload); + socket.write(serializeSseEvent({ event, data })); + }, + close: () => { + socket.end(); + }, + }; +} + +function writeProgress(sse: SseWriter | null, rows: number): void { + if (!sse) { + return; + } + sse.send("progress", { + type: "dataset_pipeline_progress", + kind: "candidate", + rows, + }); } function requireArrayField(request: unknown, field: string): unknown[] { @@ -263,11 +328,12 @@ function optionalPositiveIntegerField( function requirePipelineSource( pipeline: DatasetPipelineDefinition, + sourceOverride?: PipelineSource, ): PipelineSource { if (!isObject(pipeline.source)) { throw new Error("Dataset pipeline source is required."); } - return pipeline.source; + return { ...pipeline.source, ...(sourceOverride ?? {}) }; } function requireBraintrustRuntime(braintrust: BraintrustModule) { @@ -318,22 +384,29 @@ function refSpanRowId(ref: DiscoveryRef): string | undefined { async function hydrateDiscoveryRefs( braintrust: BraintrustModule, pipeline: DatasetPipelineDefinition, + sourceOverride: PipelineSource | undefined, sourceProjectId: string, refs: unknown[], ): Promise { requireBraintrustRuntime(braintrust); - const source = requirePipelineSource(pipeline); + const source = requirePipelineSource(pipeline, sourceOverride); const state = await stateForOrg(braintrust, source.orgName); + const tracesByRootSpanId = new Map(); return refs.map((ref) => { const rootSpanId = refRootSpanId(ref); const id = refSpanRowId(ref as DiscoveryRef); - return { - trace: new braintrust.LocalTrace!({ + let trace = tracesByRootSpanId.get(rootSpanId); + if (!trace) { + trace = new braintrust.LocalTrace!({ objectType: "project_logs", objectId: sourceProjectId, rootSpanId, state, - }), + }); + tracesByRootSpanId.set(rootSpanId, trace); + } + return { + trace, ...(id ? { id } : {}), ...(id ? { @@ -348,6 +421,50 @@ async function hydrateDiscoveryRefs( }); } +function spanAttr(row: unknown, name: string): unknown { + return isObject(row) ? 
row[name] : undefined; +} + +async function sourceRowForCandidate( + candidate: HydratedCandidate, +): Promise { + if (!candidate.id) { + return undefined; + } + const trace = candidate.trace; + if (!isObject(trace) || typeof trace.getSpans !== "function") { + throw new Error("Hydrated trace does not support getSpans()."); + } + const spans = await trace.getSpans({ includeScorers: true }); + if (!Array.isArray(spans)) { + throw new Error("Hydrated trace getSpans() did not return an array."); + } + const row = spans.find( + (span) => + spanAttr(span, "id") === candidate.id || + spanAttr(span, "span_id") === candidate.id, + ); + if (!row) { + throw new Error( + `Source span row ${JSON.stringify(candidate.id)} was not found in hydrated trace.`, + ); + } + return row; +} + +async function transformArgsForCandidate( + candidate: HydratedCandidate, +): Promise { + const row = await sourceRowForCandidate(candidate); + return { + input: spanAttr(row, "input"), + output: spanAttr(row, "output"), + expected: spanAttr(row, "expected"), + metadata: spanAttr(row, "metadata"), + trace: candidate.trace, + }; +} + function normalizeTransformResult(result: unknown): unknown[] { if (result == null) { return []; @@ -398,9 +515,11 @@ function withPipelineDefaults( async function transformRefs( braintrust: BraintrustModule, pipeline: DatasetPipelineDefinition, + sourceOverride: PipelineSource | undefined, sourceProjectId: string, refs: unknown[], maxConcurrency = 16, + sse: SseWriter | null = null, ): Promise { if (!Number.isInteger(maxConcurrency) || maxConcurrency <= 0) { throw new Error("maxConcurrency must be a positive integer."); @@ -411,6 +530,7 @@ async function transformRefs( const candidates = await hydrateDiscoveryRefs( braintrust, pipeline, + sourceOverride, sourceProjectId, refs, ); @@ -421,7 +541,8 @@ async function transformRefs( while (nextIndex < candidates.length) { const index = nextIndex++; const candidate = candidates[index]; - const result = await pipeline.transform!(candidate, { pipeline }); + const args = await transformArgsForCandidate(candidate); + const result = await pipeline.transform!(args); const rows = normalizeTransformResult(result); transformedRows[index] = rows.map((row, rowIndex) => withPipelineDefaults( @@ -430,6 +551,7 @@ async function transformRefs( rows.length > 1 ? rowIndex : undefined, ), ); + writeProgress(sse, transformedRows[index].length); } } @@ -453,25 +575,38 @@ async function main() { process.env.BT_DATASET_PIPELINE_NAME || undefined, ); const stage = parseStage(); + const sse = createSseWriter(); if (stage === "inspect") { - writeResponse({ - name: pipeline.name, - source: pipeline.source, - target: pipeline.target, - }); + writeResponse( + { + name: pipeline.name, + source: pipeline.source, + target: pipeline.target, + }, + sse, + ); } else if (stage === "transform") { const request = await readRequest(); const refs = requireArrayField(request, "refs"); const sourceProjectId = requireStringField(request, "sourceProjectId"); + const sourceOverride = + isObject(request) && isObject(request.source) + ? 
(request.source as PipelineSource) + : undefined; const rows = await transformRefs( braintrust, pipeline, + sourceOverride, sourceProjectId, refs, optionalPositiveIntegerField(request, "maxConcurrency"), + sse, + ); + writeResponse( + { candidates: refs.length, rowCount: rows.length, rows }, + sse, ); - writeResponse({ candidates: refs.length, rowCount: rows.length, rows }); } else { throw new Error(`Unsupported dataset pipeline stage: ${stage}`); } diff --git a/src/datasets/mod.rs b/src/datasets/mod.rs index 0ae86fd1..e866353e 100644 --- a/src/datasets/mod.rs +++ b/src/datasets/mod.rs @@ -108,7 +108,7 @@ enum DatasetsCommands { View(ViewArgs), /// Delete a dataset Delete(DeleteArgs), - /// Run dataset pipeline workflows + /// Run full dataset pipelines, or stage fetch/transform/push Pipeline(pipeline::PipelineArgs), } diff --git a/src/datasets/pipeline.rs b/src/datasets/pipeline.rs index 73529b49..4f98e61e 100644 --- a/src/datasets/pipeline.rs +++ b/src/datasets/pipeline.rs @@ -1,28 +1,53 @@ -use std::fs::File; -use std::io::{self, BufWriter, Write}; +use std::fs; +use std::io::{self, BufRead, BufReader, IsTerminal, Read, Write}; use std::path::{Path, PathBuf}; use std::process::{Command, Stdio}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::Duration; use anyhow::{bail, Context, Result}; use clap::{Args, Subcommand}; +use indicatif::{ProgressBar, ProgressStyle}; use serde::de::DeserializeOwned; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use crate::args::BaseArgs; use crate::auth::{login, resolved_runner_env, LoginContext}; use crate::http::ApiClient; -use crate::js_runner::{build_js_runner_command, materialize_runner_script_in_cwd}; +use crate::js_runner::{build_js_runner_command, materialize_runner_script}; use crate::projects::api::{create_project, get_project_by_name, Project}; -use crate::sync::discovery::{discover_project_log_refs, ProjectLogRefScope}; -use crate::sync::{read_jsonl_values, write_jsonl_value, write_jsonl_values}; +use crate::python_runner; +use crate::runner_sse; +use crate::source_language::{classify_runtime_extension, SourceLanguage}; +use crate::sync::discovery::{ + discover_project_log_refs, ProjectLogRefDiscoveryResult, ProjectLogRefScope, +}; +use crate::sync::{ + artifact_base_dir, artifact_spec_dir, create_jsonl_file_writer, epoch_seconds, read_json_file, + read_jsonl_values, stable_spec_hash, write_json_atomic, write_jsonl_value, SyncPushFileArgs, +}; +use tokio::sync::mpsc; use super::{api as datasets_api, records, utils, ResolvedContext}; const RUNNER_FILE: &str = "dataset-pipeline-runner.ts"; const RUNNER_SOURCE: &str = include_str!("../../scripts/dataset-pipeline-runner.ts"); +const PY_RUNNER_FILE: &str = "dataset-pipeline-runner.py"; +const PY_RUNNER_SOURCE: &str = include_str!("../../scripts/dataset-pipeline-runner.py"); +const PIPELINE_ARTIFACT_OBJECT_TYPE: &str = "dataset_pipeline"; +const PIPELINE_ARTIFACT_SCHEMA_VERSION: u32 = 1; #[derive(Debug, Clone, Args)] +#[command(after_help = "\ +Use `run` to run the whole pipeline. + +For staged workflows, run `fetch`, then `transform`, inspect or edit the transformed JSONL, then upload it with: + bt datasets pipeline push ./pipeline.ts + +`push` reads the pipeline target and delegates to `bt sync push`. 
+")] pub struct PipelineArgs { #[command(subcommand)] command: PipelineCommands, @@ -36,10 +61,8 @@ enum PipelineCommands { Fetch(PipelineFetchArgs), /// Transform candidate JSONL into proposed dataset row JSONL Transform(PipelineTransformArgs), - /// Copy proposed row JSONL for human or agent review - Review(PipelineReviewArgs), - /// Insert approved row JSONL into the target dataset - Commit(PipelineCommitArgs), + /// Push transformed dataset rows to the pipeline target + Push(PipelinePushArgs), } #[derive(Debug, Clone, Args)] @@ -52,7 +75,7 @@ struct PipelineRunnerArgs { #[arg(long)] name: Option, - /// JavaScript/TypeScript runner binary (e.g. tsx, vite-node, ts-node) + /// Runner binary (e.g. tsx, vite-node, ts-node, python) #[arg( long, short = 'r', @@ -62,20 +85,59 @@ struct PipelineRunnerArgs { runner: Option, } +#[derive(Debug, Clone, Args)] +struct PipelineSourceArgs { + /// Override the source project name from the pipeline file + #[arg(long = "source-project")] + source_project: Option, + + /// Override the source project id from the pipeline file + #[arg(long = "source-project-id")] + source_project_id: Option, + + /// Override the source org name from the pipeline file + #[arg(long = "source-org")] + source_org: Option, + + /// Override the source filter from the pipeline file + #[arg(long = "source-filter")] + source_filter: Option, +} + +#[derive(Debug, Clone, Args)] +struct PipelineTargetArgs { + /// Override the target project name from the pipeline file + #[arg(long = "target-project")] + target_project: Option, + + /// Override the target project id from the pipeline file + #[arg(long = "target-project-id")] + target_project_id: Option, + + /// Override the target org name from the pipeline file + #[arg(long = "target-org")] + target_org: Option, + + /// Override the target dataset name from the pipeline file + #[arg(long = "target-dataset")] + target_dataset: Option, +} + #[derive(Debug, Clone, Args)] struct PipelineFetchOptions { /// Maximum number of source refs to discover - #[arg(long, default_value_t = 100, value_parser = parse_positive_usize)] - target: usize, + #[arg( + long, + alias = "target", + default_value_t = 100, + value_parser = parse_positive_usize + )] + limit: usize, /// Restrict the source query to one or more root span ids #[arg(long = "root-span-id")] root_span_ids: Vec, - /// Additional SQL predicate appended to the source WHERE clause - #[arg(long)] - extra_where_sql: Option, - /// Page size for discovery BTQL pagination #[arg(long, default_value_t = 1000, value_parser = parse_positive_usize)] page_size: usize, @@ -88,11 +150,24 @@ struct PipelineTransformOptions { max_concurrency: usize, } +#[derive(Debug, Clone, Args)] +struct PipelineArtifactArgs { + /// Root directory for pipeline artifacts. + #[arg(long, default_value = "bt-sync")] + root: PathBuf, +} + #[derive(Debug, Clone, Args)] struct PipelineRunArgs { #[command(flatten)] runner: PipelineRunnerArgs, + #[command(flatten)] + source: PipelineSourceArgs, + + #[command(flatten)] + target: PipelineTargetArgs, + #[command(flatten)] fetch: PipelineFetchOptions, @@ -105,10 +180,16 @@ struct PipelineFetchArgs { #[command(flatten)] runner: PipelineRunnerArgs, + #[command(flatten)] + artifacts: PipelineArtifactArgs, + + #[command(flatten)] + source: PipelineSourceArgs, + #[command(flatten)] fetch: PipelineFetchOptions, - /// Output JSONL file. Defaults to stdout. + /// Output JSONL file. Defaults to a managed path under --root. 
#[arg(long)] out: Option, } @@ -118,65 +199,69 @@ struct PipelineTransformArgs { #[command(flatten)] runner: PipelineRunnerArgs, + #[command(flatten)] + artifacts: PipelineArtifactArgs, + + #[command(flatten)] + source: PipelineSourceArgs, + #[command(flatten)] transform: PipelineTransformOptions, - /// Input candidate JSONL file + /// Input candidate JSONL file. Defaults to the latest fetch output under --root. #[arg(long = "in")] - input: PathBuf, + input: Option, - /// Output proposed dataset row JSONL file. Defaults to stdout. + /// Output proposed dataset row JSONL file. Defaults to a managed path under --root. #[arg(long)] out: Option, } #[derive(Debug, Clone, Args)] -struct PipelineReviewArgs { +struct PipelinePushArgs { #[command(flatten)] runner: PipelineRunnerArgs, - /// Input proposed dataset row JSONL file - #[arg(long = "in")] - input: PathBuf, - - /// Output approved dataset row JSONL file. Defaults to stdout. - #[arg(long)] - out: Option, -} + #[command(flatten)] + artifacts: PipelineArtifactArgs, -#[derive(Debug, Clone, Args)] -struct PipelineCommitArgs { #[command(flatten)] - runner: PipelineRunnerArgs, + target: PipelineTargetArgs, - /// Input approved dataset row JSONL file + /// Input transformed dataset row JSONL file. Defaults to the latest transform output under --root. #[arg(long = "in")] - input: PathBuf, + input: Option, + + /// Ignore previous sync push state and upload from the beginning. + #[arg(long)] + fresh: bool, } pub async fn run(base: BaseArgs, args: PipelineArgs) -> Result<()> { match args.command { PipelineCommands::Run(args) => { - let inspect = inspect_pipeline(&base, &args.runner).await?; + let inspect = inspect_with_overrides( + inspect_pipeline(&base, &args.runner).await?, + Some(&args.source), + Some(&args.target), + ); let tempdir = tempfile::tempdir().context("failed to create dataset pipeline temp dir")?; let refs_path = tempdir.path().join("discovered.jsonl"); - discover_refs(&base, &inspect, &args.fetch, Some(&refs_path), false).await?; + discover_refs(&base, &inspect, &args.fetch, &refs_path).await?; let refs = read_jsonl_values(&refs_path)?; let source_project = resolve_pipeline_source_project(&base, &inspect.source).await?; - let transform_response: PipelineTransformResponse = run_runner_json( + let transform_response = transform_source_refs( &base, - "transform", &args.runner, - &json!({ - "sourceProjectId": source_project.id, - "refs": refs, - "maxConcurrency": args.transform.max_concurrency, - }), + &source_project.id, + &inspect.source, + refs, + args.transform.max_concurrency, + None, ) .await?; - validate_transform_response(&transform_response)?; let row_count = transform_response.rows.len(); let inserted = upload_dataset_rows(&base, &inspect.target, transform_response.rows).await?; @@ -191,12 +276,15 @@ pub async fn run(base: BaseArgs, args: PipelineArgs) -> Result<()> { ) } PipelineCommands::Fetch(args) => { - let inspect = inspect_pipeline(&base, &args.runner).await?; - discover_refs(&base, &inspect, &args.fetch, args.out.as_deref(), true).await + let inspect = inspect_with_overrides( + inspect_pipeline(&base, &args.runner).await?, + Some(&args.source), + None, + ); + fetch_refs(&base, args, inspect).await } PipelineCommands::Transform(args) => transform_refs(&base, args).await, - PipelineCommands::Review(args) => review_rows(&base, args), - PipelineCommands::Commit(args) => commit_rows(&base, args).await, + PipelineCommands::Push(args) => push_rows(&base, args).await, } } @@ -207,18 +295,22 @@ struct PipelineInspect { target: 
PipelineTargetInspect, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] struct PipelineSourceInspect { + #[serde(skip_serializing_if = "Option::is_none")] project_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] project_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] org_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] filter: Option, + #[serde(skip_serializing_if = "Option::is_none")] scope: Option, - limit: Option, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] struct PipelineTargetInspect { project_id: Option, @@ -229,7 +321,7 @@ struct PipelineTargetInspect { metadata: Option, } -#[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)] #[serde(rename_all = "lowercase")] enum PipelineScope { Span, @@ -250,21 +342,225 @@ struct PipelineTransformResponse { rows: Vec, } +#[derive(Debug)] +enum PipelineRunnerEvent { + Response(Value), + Progress(PipelineProgressEvent), + Error { + message: String, + stack: Option, + status: Option, + }, + Console { + stream: String, + message: String, + }, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct PipelineProgressEvent { + #[serde(rename = "type")] + kind_type: String, + kind: String, + #[serde(default)] + rows: Option, +} + +#[derive(Debug, Deserialize)] +struct PipelineRunnerErrorPayload { + message: String, + #[serde(default)] + stack: Option, + #[serde(default)] + status: Option, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +enum PipelineArtifactStage { + Fetch, + Transform, +} + +impl PipelineArtifactStage { + fn command(self) -> &'static str { + match self { + PipelineArtifactStage::Fetch => "fetch", + PipelineArtifactStage::Transform => "transform", + } + } + + fn output_file(self) -> &'static str { + match self { + PipelineArtifactStage::Fetch => "fetched.jsonl", + PipelineArtifactStage::Transform => "transformed.jsonl", + } + } + + fn spec_file(self) -> &'static str { + match self { + PipelineArtifactStage::Fetch => "fetch.spec.json", + PipelineArtifactStage::Transform => "transform.spec.json", + } + } + + fn manifest_file(self) -> &'static str { + match self { + PipelineArtifactStage::Fetch => "fetch.manifest.json", + PipelineArtifactStage::Transform => "transform.manifest.json", + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct PipelineFetchArtifactOptions { + limit: usize, + root_span_ids: Vec, + page_size: usize, +} + +impl From<&PipelineFetchOptions> for PipelineFetchArtifactOptions { + fn from(options: &PipelineFetchOptions) -> Self { + Self { + limit: options.limit, + root_span_ids: options.root_span_ids.clone(), + page_size: options.page_size, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct PipelineTransformArtifactOptions { + max_concurrency: usize, +} + +impl From<&PipelineTransformOptions> for PipelineTransformArtifactOptions { + fn from(options: &PipelineTransformOptions) -> Self { + Self { + max_concurrency: options.max_concurrency, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct PipelineArtifactSpec { + schema_version: u32, + kind: String, + pipeline: String, + #[serde(skip_serializing_if = 
"Option::is_none")] + name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + cli_project: Option, + #[serde(skip_serializing_if = "Option::is_none")] + cli_org: Option, + stage: PipelineArtifactStage, + #[serde(skip_serializing_if = "Option::is_none")] + source: Option, + #[serde(skip_serializing_if = "Option::is_none")] + target: Option, + #[serde(skip_serializing_if = "Option::is_none")] + fetch: Option, + #[serde(skip_serializing_if = "Option::is_none")] + transform: Option, + #[serde(skip_serializing_if = "Option::is_none")] + input_path: Option, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +enum PipelineArtifactStatus { + Completed, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct PipelineArtifactManifest { + schema_version: u32, + spec_hash: String, + spec: PipelineArtifactSpec, + status: PipelineArtifactStatus, + stage: PipelineArtifactStage, + #[serde(skip_serializing_if = "Option::is_none")] + input_path: Option, + #[serde(skip_serializing_if = "Option::is_none")] + output_path: Option, + #[serde(skip_serializing_if = "Option::is_none")] + refs: Option, + #[serde(skip_serializing_if = "Option::is_none")] + candidates: Option, + #[serde(skip_serializing_if = "Option::is_none")] + rows: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pages: Option, + started_at: u64, + updated_at: u64, + completed_at: Option, +} + +#[derive(Debug, Clone)] +struct PipelineOutputArtifact { + spec_hash: String, + spec: PipelineArtifactSpec, + stage: PipelineArtifactStage, + spec_dir: PathBuf, + output_path: PathBuf, +} + async fn inspect_pipeline(base: &BaseArgs, runner: &PipelineRunnerArgs) -> Result { - let output = build_runner_command(base, "inspect", runner, |_, _| Ok(())) - .await? 
- .stdout(Stdio::piped()) - .stderr(Stdio::inherit()) - .output() - .context("failed to start dataset pipeline inspect runner")?; - if !output.status.success() { - bail!( - "dataset pipeline inspect runner failed with status {}", - output.status - ); + run_runner_json(base, "inspect", runner, None, |event| { + handle_pipeline_runner_event(None, event); + }) + .await +} + +fn inspect_with_overrides( + mut inspect: PipelineInspect, + source: Option<&PipelineSourceArgs>, + target: Option<&PipelineTargetArgs>, +) -> PipelineInspect { + if let Some(source) = source { + apply_source_overrides(&mut inspect.source, source); + } + if let Some(target) = target { + apply_target_overrides(&mut inspect.target, target); + } + inspect +} + +fn apply_source_overrides(source: &mut PipelineSourceInspect, args: &PipelineSourceArgs) { + if let Some(project_name) = args.source_project.as_deref() { + source.project_name = Some(project_name.to_string()); + source.project_id = None; + } + if let Some(project_id) = args.source_project_id.as_deref() { + source.project_id = Some(project_id.to_string()); + } + if let Some(org_name) = args.source_org.as_deref() { + source.org_name = Some(org_name.to_string()); + } + if let Some(filter) = args.source_filter.as_deref() { + source.filter = Some(filter.to_string()); + } +} + +fn apply_target_overrides(target: &mut PipelineTargetInspect, args: &PipelineTargetArgs) { + if let Some(project_name) = args.target_project.as_deref() { + target.project_name = Some(project_name.to_string()); + target.project_id = None; + } + if let Some(project_id) = args.target_project_id.as_deref() { + target.project_id = Some(project_id.to_string()); + } + if let Some(org_name) = args.target_org.as_deref() { + target.org_name = Some(org_name.to_string()); + } + if let Some(dataset_name) = args.target_dataset.as_deref() { + target.dataset_name = dataset_name.to_string(); } - serde_json::from_slice(&output.stdout) - .context("failed to parse dataset pipeline inspect output") } async fn build_runner_command( @@ -276,11 +572,9 @@ async fn build_runner_command( where F: FnOnce(&mut Command, &'static str) -> Result<()>, { - let runner_script = - materialize_runner_script_in_cwd("dataset-pipeline-runners", RUNNER_FILE, RUNNER_SOURCE)?; let pipeline_file = runner.pipeline.clone(); let files = vec![pipeline_file.clone()]; - let mut command = build_js_runner_command(runner.runner.as_deref(), &runner_script, &files); + let mut command = build_pipeline_runner_command(runner, &pipeline_file, &files)?; command.envs(resolved_runner_env(base).await?); command.env("BT_DATASET_PIPELINE_STAGE", stage); @@ -291,24 +585,365 @@ where Ok(command) } -async fn run_runner_json( +fn build_pipeline_runner_command( + runner: &PipelineRunnerArgs, + pipeline_file: &Path, + files: &[PathBuf], +) -> Result { + match pipeline_language(pipeline_file)? { + SourceLanguage::JsLike => { + let runner_script = materialize_dataset_pipeline_runner(RUNNER_FILE, RUNNER_SOURCE)?; + Ok(build_js_runner_command( + runner.runner.as_deref(), + &runner_script, + files, + )) + } + SourceLanguage::Python => { + let runner_script = + materialize_dataset_pipeline_runner(PY_RUNNER_FILE, PY_RUNNER_SOURCE)?; + let python = python_runner::resolve_python_interpreter_for_roots( + runner.runner.as_deref(), + &["BT_DATASET_PIPELINE_PYTHON"], + files, + ) + .context("No Python interpreter found. 
Install python, create a virtualenv, or pass --runner.")?; + let mut command = Command::new(python); + command.arg(runner_script).arg(pipeline_file); + Ok(command) + } + } +} + +fn materialize_dataset_pipeline_runner(file_name: &str, source: &str) -> Result { + materialize_runner_script(&dataset_pipeline_runner_cache_dir(), file_name, source) +} + +fn dataset_pipeline_runner_cache_dir() -> PathBuf { + let root = std::env::var_os("XDG_CACHE_HOME") + .map(PathBuf::from) + .or_else(|| std::env::var_os("HOME").map(|home| PathBuf::from(home).join(".cache"))) + .unwrap_or_else(std::env::temp_dir); + + root.join("bt") + .join("dataset-pipeline-runners") + .join(env!("CARGO_PKG_VERSION")) +} + +fn pipeline_language(pipeline_file: &Path) -> Result { + let extension = pipeline_file + .extension() + .and_then(|extension| extension.to_str()) + .with_context(|| { + format!( + "dataset pipeline file '{}' has no extension", + pipeline_file.display() + ) + })?; + classify_runtime_extension(extension).with_context(|| { + format!( + "unsupported dataset pipeline file extension '.{extension}'; expected .ts, .tsx, .js, .jsx, or .py" + ) + }) +} + +async fn fetch_refs( + base: &BaseArgs, + args: PipelineFetchArgs, + inspect: PipelineInspect, +) -> Result<()> { + let spec = pipeline_fetch_artifact_spec(base, &args.runner, &inspect.source, &args.fetch); + let artifact = resolve_pipeline_output_artifact( + &args.artifacts.root, + &args.runner, + spec, + args.out.as_deref(), + None, + )?; + artifact.write_spec()?; + let started_at = epoch_seconds(); + let result = discover_refs(base, &inspect, &args.fetch, &artifact.output_path).await?; + artifact.write_manifest(PipelineArtifactManifest { + schema_version: PIPELINE_ARTIFACT_SCHEMA_VERSION, + spec_hash: artifact.spec_hash.clone(), + spec: artifact.spec.clone(), + status: PipelineArtifactStatus::Completed, + stage: PipelineArtifactStage::Fetch, + input_path: None, + output_path: Some(artifact.output_path.display().to_string()), + refs: Some(result.refs), + candidates: None, + rows: None, + pages: Some(result.pages), + started_at, + updated_at: epoch_seconds(), + completed_at: Some(epoch_seconds()), + })?; + print_summary( + base, + json!({ + "refs": result.refs, + "pages": result.pages, + "scope": match PipelineScope::from_source(&inspect.source) { PipelineScope::Trace => "trace", PipelineScope::Span => "span" }, + "out": artifact.output_path.display().to_string(), + }), + false, + ) +} + +fn pipeline_fetch_artifact_spec( + base: &BaseArgs, + runner: &PipelineRunnerArgs, + source: &PipelineSourceInspect, + options: &PipelineFetchOptions, +) -> PipelineArtifactSpec { + base_pipeline_artifact_spec(base, runner, PipelineArtifactStage::Fetch) + .with_source(source.clone()) + .with_fetch(options.into()) +} + +fn pipeline_transform_artifact_spec( + base: &BaseArgs, + runner: &PipelineRunnerArgs, + source: &PipelineSourceInspect, + options: &PipelineTransformOptions, + input_path: &Path, +) -> PipelineArtifactSpec { + base_pipeline_artifact_spec(base, runner, PipelineArtifactStage::Transform) + .with_source(source.clone()) + .with_transform(options.into()) + .with_input_path(input_path) +} + +fn base_pipeline_artifact_spec( + base: &BaseArgs, + runner: &PipelineRunnerArgs, + stage: PipelineArtifactStage, +) -> PipelineArtifactSpec { + PipelineArtifactSpec { + schema_version: PIPELINE_ARTIFACT_SCHEMA_VERSION, + kind: PIPELINE_ARTIFACT_OBJECT_TYPE.to_string(), + pipeline: runner.pipeline.display().to_string(), + name: runner.name.clone(), + cli_project: base.project.clone(), 
+ cli_org: base.org_name.clone(), + stage, + source: None, + target: None, + fetch: None, + transform: None, + input_path: None, + } +} + +impl PipelineArtifactSpec { + fn with_source(mut self, source: PipelineSourceInspect) -> Self { + self.source = Some(source); + self + } + + fn with_fetch(mut self, fetch: PipelineFetchArtifactOptions) -> Self { + self.fetch = Some(fetch); + self + } + + fn with_transform(mut self, transform: PipelineTransformArtifactOptions) -> Self { + self.transform = Some(transform); + self + } + + fn with_input_path(mut self, input_path: &Path) -> Self { + self.input_path = Some(input_path.display().to_string()); + self + } +} + +fn resolve_pipeline_output_artifact( + root: &Path, + runner: &PipelineRunnerArgs, + spec: PipelineArtifactSpec, + explicit_out: Option<&Path>, + input_path: Option<&Path>, +) -> Result { + let spec_hash = stable_spec_hash(&spec)?; + let stage = spec.stage; + let hashed_spec_dir = artifact_spec_dir( + root, + PIPELINE_ARTIFACT_OBJECT_TYPE, + &pipeline_artifact_name(runner), + &spec_hash, + ); + let spec_dir = if matches!(stage, PipelineArtifactStage::Fetch) { + hashed_spec_dir + } else { + input_path + .and_then(Path::parent) + .filter(|parent| !parent.as_os_str().is_empty()) + .map(Path::to_path_buf) + .unwrap_or(hashed_spec_dir) + }; + let output_path = explicit_out + .map(Path::to_path_buf) + .unwrap_or_else(|| spec_dir.join(stage.output_file())); + Ok(PipelineOutputArtifact { + spec_hash, + spec, + stage, + spec_dir, + output_path, + }) +} + +fn resolve_pipeline_input_path( + explicit_input: &Option, + root: &Path, + runner: &PipelineRunnerArgs, + stage: PipelineArtifactStage, +) -> Result { + if let Some(input) = explicit_input { + Ok(input.clone()) + } else { + resolve_latest_pipeline_stage_output(root, runner, stage) + } +} + +fn resolve_latest_pipeline_stage_output( + root: &Path, + runner: &PipelineRunnerArgs, + stage: PipelineArtifactStage, +) -> Result { + let base = artifact_base_dir( + root, + PIPELINE_ARTIFACT_OBJECT_TYPE, + &pipeline_artifact_name(runner), + ); + let mut best: Option<(u64, PathBuf)> = None; + if base.is_dir() { + for entry in + fs::read_dir(&base).with_context(|| format!("failed to read {}", base.display()))? + { + let entry = entry?; + if !entry.file_type()?.is_dir() { + continue; + } + let manifest_path = entry.path().join(stage.manifest_file()); + if !manifest_path.exists() { + continue; + } + let manifest = read_json_file::(&manifest_path)?; + if manifest.stage != stage || manifest.status != PipelineArtifactStatus::Completed { + continue; + } + let Some(output_path) = manifest + .output_path + .as_ref() + .map(PathBuf::from) + .filter(|path| path.exists()) + else { + continue; + }; + if best + .as_ref() + .map(|(best_time, _)| manifest.updated_at > *best_time) + .unwrap_or(true) + { + best = Some((manifest.updated_at, output_path)); + } + } + } + + best.map(|(_, path)| path).ok_or_else(|| { + anyhow::anyhow!( + "no completed dataset pipeline {} output found for '{}'. 
run `bt datasets pipeline {} {}` first or pass --in", + stage.command(), + pipeline_artifact_name(runner), + stage.command(), + runner.pipeline.display() + ) + }) +} + +fn pipeline_artifact_name(runner: &PipelineRunnerArgs) -> String { + runner + .name + .clone() + .or_else(|| { + runner + .pipeline + .file_stem() + .and_then(|stem| stem.to_str()) + .map(ToString::to_string) + }) + .unwrap_or_else(|| "pipeline".to_string()) +} + +impl PipelineOutputArtifact { + fn write_spec(&self) -> Result<()> { + write_json_atomic(&self.spec_dir.join(self.stage.spec_file()), &self.spec) + } + + fn write_manifest(&self, manifest: PipelineArtifactManifest) -> Result<()> { + write_json_atomic(&self.spec_dir.join(self.stage.manifest_file()), &manifest) + } +} + +async fn run_runner_json( base: &BaseArgs, stage: &'static str, runner: &PipelineRunnerArgs, - request: &Value, + request: Option<&Value>, + mut on_event: F, ) -> Result where T: DeserializeOwned, + F: FnMut(PipelineRunnerEvent), { let mut command = build_runner_command(base, stage, runner, |_, _| Ok(())).await?; + let (listener, socket_path, socket_cleanup_guard) = + runner_sse::bind_sse_listener("bt-dataset-pipeline")?; + let (tx, rx) = mpsc::unbounded_channel::(); + let sse_connected = Arc::new(AtomicBool::new(false)); + + let tx_sse = tx.clone(); + let sse_connected_for_task = Arc::clone(&sse_connected); + let mut sse_task = tokio::spawn(async move { + match listener.accept().await { + Ok((stream, _)) => { + sse_connected_for_task.store(true, Ordering::Relaxed); + if let Err(err) = runner_sse::read_sse_stream(stream, |event, data| { + handle_pipeline_sse_event(event, data, &tx_sse); + }) + .await + { + let _ = tx_sse.send(PipelineRunnerEvent::Error { + message: format!("SSE stream error: {err}"), + stack: None, + status: None, + }); + } + } + Err(err) => { + let _ = tx_sse.send(PipelineRunnerEvent::Error { + message: format!("Failed to accept SSE connection: {err}"), + stack: None, + status: None, + }); + } + } + }); + + command.env( + "BT_DATASET_PIPELINE_SSE_SOCK", + socket_path.to_string_lossy().to_string(), + ); command.stdin(Stdio::piped()); command.stdout(Stdio::piped()); - command.stderr(Stdio::inherit()); + command.stderr(Stdio::piped()); let mut child = command .spawn() .context("failed to start dataset pipeline runner")?; - { + if let Some(request) = request { let mut stdin = child .stdin .take() @@ -320,37 +955,197 @@ where .context("failed to finish dataset pipeline runner request")?; } - let output = child - .wait_with_output() - .context("failed to wait for dataset pipeline runner")?; - if !output.status.success() { + if let Some(stdout) = child.stdout.take() { + forward_blocking_stream(stdout, "stdout", tx.clone()); + } + if let Some(stderr) = child.stderr.take() { + forward_blocking_stream(stderr, "stderr", tx.clone()); + } + drop(tx); + + let wait_task = tokio::task::spawn_blocking(move || child.wait()); + let mut response: Option = None; + let mut errors = Vec::::new(); + let wait = Box::pin(async move { + wait_task + .await + .context("dataset pipeline runner wait task failed")? 
+ .context("dataset pipeline runner process failed") + }); + let status = runner_sse::drive_runner_events( + rx, + wait, + &mut sse_task, + &sse_connected, + "dataset pipeline runner exited without a status", + |event| match event { + PipelineRunnerEvent::Response(value) => { + response = Some(value); + } + PipelineRunnerEvent::Error { + message, + stack, + status: _, + } => { + errors.push(message.clone()); + if let Some(stack) = stack { + errors.push(stack); + } + on_event(PipelineRunnerEvent::Error { + message, + stack: None, + status: None, + }); + } + event => on_event(event), + }, + ) + .await?; + + let _socket_cleanup_guard = socket_cleanup_guard; + if !status.success() { + let detail = if errors.is_empty() { + String::new() + } else { + format!(": {}", errors.join("\n")) + }; bail!( - "dataset pipeline runner failed with status {}", - output.status + "dataset pipeline runner failed with status {}{}", + status, + detail ); } - serde_json::from_slice(&output.stdout) - .context("failed to parse dataset pipeline runner response") + + let response = response.context("dataset pipeline runner did not send a response")?; + serde_json::from_value(response).context("failed to parse dataset pipeline runner response") +} + +fn handle_pipeline_sse_event( + event: Option, + data: String, + tx: &mpsc::UnboundedSender, +) { + match event.unwrap_or_default().as_str() { + "response" => { + if let Ok(value) = serde_json::from_str::(&data) { + let _ = tx.send(PipelineRunnerEvent::Response(value)); + } + } + "progress" => { + if let Ok(progress) = serde_json::from_str::(&data) { + if progress.kind_type == "dataset_pipeline_progress" { + let _ = tx.send(PipelineRunnerEvent::Progress(progress)); + } + } + } + "error" => { + if let Ok(payload) = serde_json::from_str::(&data) { + let _ = tx.send(PipelineRunnerEvent::Error { + message: payload.message, + stack: payload.stack, + status: payload.status, + }); + } else { + let _ = tx.send(PipelineRunnerEvent::Error { + message: data, + stack: None, + status: None, + }); + } + } + _ => {} + } +} + +fn forward_blocking_stream( + stream: T, + name: &'static str, + tx: mpsc::UnboundedSender, +) where + T: Read + Send + 'static, +{ + std::thread::spawn(move || { + let lines = BufReader::new(stream).lines(); + for line in lines { + match line { + Ok(message) => { + let _ = tx.send(PipelineRunnerEvent::Console { + stream: name.to_string(), + message, + }); + } + Err(err) => { + let _ = tx.send(PipelineRunnerEvent::Error { + message: format!("failed to read dataset pipeline runner {name}: {err}"), + stack: None, + status: None, + }); + break; + } + } + } + }); } async fn transform_refs(base: &BaseArgs, args: PipelineTransformArgs) -> Result<()> { - let inspect = inspect_pipeline(base, &args.runner).await?; + let inspect = inspect_with_overrides( + inspect_pipeline(base, &args.runner).await?, + Some(&args.source), + None, + ); let source_project = resolve_pipeline_source_project(base, &inspect.source).await?; - let refs = read_jsonl_values(&args.input)?; - let response: PipelineTransformResponse = run_runner_json( + let input_path = resolve_pipeline_input_path( + &args.input, + &args.artifacts.root, + &args.runner, + PipelineArtifactStage::Fetch, + )?; + let refs = read_jsonl_values(&input_path)?; + let spec = pipeline_transform_artifact_spec( base, - "transform", &args.runner, - &json!({ - "sourceProjectId": source_project.id, - "refs": refs, - "maxConcurrency": args.transform.max_concurrency, - }), + &inspect.source, + &args.transform, + &input_path, + ); + let 
artifact = resolve_pipeline_output_artifact( + &args.artifacts.root, + &args.runner, + spec, + args.out.as_deref(), + Some(&input_path), + )?; + artifact.write_spec()?; + let started_at = epoch_seconds(); + let mut writer = create_jsonl_file_writer(&artifact.output_path)?; + let response = transform_source_refs( + base, + &args.runner, + &source_project.id, + &inspect.source, + refs, + args.transform.max_concurrency, + Some(&mut writer as &mut dyn Write), ) .await?; - validate_transform_response(&response)?; - let row_count = response.rows.len(); - write_jsonl_values(args.out.as_deref(), &response.rows)?; + writer.flush().context("failed to flush transform output")?; + artifact.write_manifest(PipelineArtifactManifest { + schema_version: PIPELINE_ARTIFACT_SCHEMA_VERSION, + spec_hash: artifact.spec_hash.clone(), + spec: artifact.spec.clone(), + status: PipelineArtifactStatus::Completed, + stage: PipelineArtifactStage::Transform, + input_path: Some(input_path.display().to_string()), + output_path: Some(artifact.output_path.display().to_string()), + refs: None, + candidates: Some(response.candidates), + rows: Some(response.row_count), + pages: None, + started_at, + updated_at: epoch_seconds(), + completed_at: Some(epoch_seconds()), + })?; + let row_count = response.row_count; print_summary( base, json!({ @@ -358,14 +1153,143 @@ async fn transform_refs(base: &BaseArgs, args: PipelineTransformArgs) -> Result< "rows": row_count, "out": args .out - .as_ref() - .map(|path| path.display().to_string()) - .unwrap_or_else(|| "stdout".to_string()), + .as_deref() + .unwrap_or(&artifact.output_path) + .display() + .to_string(), }), - args.out.is_none(), + false, ) } +async fn transform_source_refs( + base: &BaseArgs, + runner: &PipelineRunnerArgs, + source_project_id: &str, + source: &PipelineSourceInspect, + refs: Vec, + max_concurrency: usize, + mut row_writer: Option<&mut dyn Write>, +) -> Result { + let progress = pipeline_progress_bar(base, refs.len() as u64, "Transforming candidates"); + progress.set_message("output rows: 0"); + let mut combined = PipelineTransformResponse { + candidates: 0, + row_count: 0, + rows: Vec::new(), + }; + let batch_size = max_concurrency.max(1); + let mut completed_candidates = 0usize; + for batch in refs.chunks(batch_size) { + let request = json!({ + "sourceProjectId": source_project_id, + "source": source, + "refs": batch, + "maxConcurrency": max_concurrency, + }); + let mut completed_in_batch = 0usize; + let mut batch_rows = 0usize; + let base_row_count = combined.row_count; + let response: PipelineTransformResponse = run_runner_json( + base, + "transform", + runner, + Some(&request), + |event| match event { + PipelineRunnerEvent::Progress(progress_event) + if progress_event.kind == "candidate" => + { + if completed_in_batch < batch.len() { + completed_in_batch += 1; + completed_candidates += 1; + batch_rows += progress_event.rows.unwrap_or(0); + progress.set_position(completed_candidates.min(refs.len()) as u64); + } + progress.set_message(format!("output rows: {}", base_row_count + batch_rows)); + } + event => handle_pipeline_runner_event(Some(&progress), event), + }, + ) + .await?; + validate_transform_response(&response)?; + if completed_in_batch < batch.len() { + completed_candidates += batch.len() - completed_in_batch; + progress.set_position(completed_candidates.min(refs.len()) as u64); + } + combined.candidates += response.candidates; + combined.row_count += response.row_count; + if let Some(writer) = row_writer.as_deref_mut() { + for row in response.rows { + 
write_jsonl_value(writer, &row).context("failed to write transform output row")?; + } + writer.flush().context("failed to flush transform output")?; + } else { + combined.rows.extend(response.rows); + } + progress.set_message(format!("output rows: {}", combined.row_count)); + } + progress.finish_and_clear(); + Ok(combined) +} + +fn pipeline_progress_bar(base: &BaseArgs, total: u64, label: &str) -> ProgressBar { + if base.json || base.quiet || !io::stderr().is_terminal() { + return ProgressBar::hidden(); + } + let pb = ProgressBar::new(total); + pb.set_style( + ProgressStyle::with_template( + "{spinner:.cyan} {prefix} [{bar:40.cyan/blue}] {pos}/{len} candidates ({percent:>3}%) | {msg}", + ) + .unwrap(), + ); + pb.set_prefix(label.to_string()); + pb.enable_steady_tick(Duration::from_millis(80)); + pb +} + +fn handle_pipeline_runner_event(progress: Option<&ProgressBar>, event: PipelineRunnerEvent) { + match event { + PipelineRunnerEvent::Console { stream, message } => { + let line = if stream == "stdout" { + format!("[pipeline stdout] {message}") + } else { + message + }; + if let Some(progress) = progress { + progress.suspend(|| eprintln!("{line}")); + } else { + eprintln!("{line}"); + } + } + PipelineRunnerEvent::Error { + message, + stack, + status, + } => { + let line = if let Some(status) = status { + format!("dataset pipeline runner error ({status}): {message}") + } else { + format!("dataset pipeline runner error: {message}") + }; + if let Some(progress) = progress { + progress.suspend(|| { + eprintln!("{line}"); + if let Some(stack) = stack { + eprintln!("{stack}"); + } + }); + } else { + eprintln!("{line}"); + if let Some(stack) = stack { + eprintln!("{stack}"); + } + } + } + PipelineRunnerEvent::Progress(_) | PipelineRunnerEvent::Response(_) => {} + } +} + fn validate_transform_response(response: &PipelineTransformResponse) -> Result<()> { if response.row_count != response.rows.len() { bail!( @@ -377,36 +1301,51 @@ fn validate_transform_response(response: &PipelineTransformResponse) -> Result<( Ok(()) } -fn review_rows(base: &BaseArgs, args: PipelineReviewArgs) -> Result<()> { - let rows = read_jsonl_values(&args.input)?; - write_jsonl_values(args.out.as_deref(), &rows)?; - print_summary( - base, - json!({ - "rows": rows.len(), - "out": args - .out - .as_ref() - .map(|path| path.display().to_string()) - .unwrap_or_else(|| "stdout".to_string()), - }), - args.out.is_none(), +async fn push_rows(base: &BaseArgs, args: PipelinePushArgs) -> Result<()> { + let inspect = inspect_with_overrides( + inspect_pipeline(base, &args.runner).await?, + None, + Some(&args.target), + ); + let input_path = resolve_pipeline_input_path( + &args.input, + &args.artifacts.root, + &args.runner, + PipelineArtifactStage::Transform, + )?; + let target_base = base_with_pipeline_target(base, &inspect.target); + + crate::sync::push_jsonl_file( + target_base, + SyncPushFileArgs { + object_ref: pipeline_target_dataset_ref(&inspect.target)?, + input: input_path, + root: args.artifacts.root, + fresh: args.fresh, + }, ) + .await } -async fn commit_rows(base: &BaseArgs, args: PipelineCommitArgs) -> Result<()> { - let inspect = inspect_pipeline(base, &args.runner).await?; - let rows = read_jsonl_values(&args.input)?; - let row_count = rows.len(); - let inserted = upload_dataset_rows(base, &inspect.target, rows).await?; - print_summary( - base, - json!({ - "rows": row_count, - "inserted": inserted, - }), - false, - ) +fn base_with_pipeline_target(base: &BaseArgs, target: &PipelineTargetInspect) -> BaseArgs { + let mut 
target_base = base.clone(); + if let Some(org_name) = target.org_name.as_deref() { + target_base.org_name = Some(org_name.to_string()); + } + if let Some(project_id) = target.project_id.as_deref() { + target_base.project = Some(project_id.to_string()); + } else if let Some(project_name) = target.project_name.as_deref() { + target_base.project = Some(project_name.to_string()); + } + target_base +} + +fn pipeline_target_dataset_ref(target: &PipelineTargetInspect) -> Result { + let dataset_name = target.dataset_name.trim(); + if dataset_name.is_empty() { + bail!("dataset pipeline target.datasetName cannot be empty"); + } + Ok(format!("dataset:{dataset_name}")) } async fn upload_dataset_rows( @@ -532,25 +1471,14 @@ async fn discover_refs( base: &BaseArgs, inspect: &PipelineInspect, options: &PipelineFetchOptions, - out: Option<&Path>, - emit_summary: bool, -) -> Result<()> { + out: &Path, +) -> Result { let (ctx, client, project) = resolve_pipeline_source_context(base, &inspect.source).await?; let scope = PipelineScope::from_source(&inspect.source); - let target = inspect.source.limit.unwrap_or(options.target); + let limit = options.limit; let filter = discovery_filter(&inspect.source, options); - let mut writer: Box = if let Some(path) = out { - if let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) { - std::fs::create_dir_all(parent) - .with_context(|| format!("failed to create {}", parent.display()))?; - } - Box::new(BufWriter::new(File::create(path).with_context(|| { - format!("failed to create {}", path.display()) - })?)) - } else { - Box::new(BufWriter::new(io::stdout())) - }; + let mut writer = create_jsonl_file_writer(out)?; let result = discover_project_log_refs( &client, @@ -558,29 +1486,18 @@ async fn discover_refs( &project.id, filter.as_ref(), project_log_ref_scope(scope), - target, + limit, options.page_size, - |reference| write_jsonl_value(writer.as_mut(), &reference.to_value()).map(|_| ()), + |reference| { + write_jsonl_value(&mut writer, &reference.to_value())?; + writer.flush().context("failed to flush discovery output")?; + Ok(()) + }, ) .await?; writer.flush().context("failed to flush discovery output")?; - let out_label = out - .map(|path| path.display().to_string()) - .unwrap_or_else(|| "stdout".to_string()); - if emit_summary { - print_summary( - base, - json!({ - "refs": result.refs, - "pages": result.pages, - "scope": match scope { PipelineScope::Trace => "trace", PipelineScope::Span => "span" }, - "out": out_label, - }), - out.is_none(), - )?; - } - Ok(()) + Ok(result) } fn project_log_ref_scope(scope: PipelineScope) -> ProjectLogRefScope { @@ -608,7 +1525,7 @@ async fn resolve_pipeline_source_context( } let ctx = login(&source_base).await?; let client = ApiClient::new(&ctx)?; - let project = resolve_source_project(&client, source).await?; + let project = resolve_source_project(base, &client, source).await?; Ok((ctx, client, project)) } @@ -628,14 +1545,6 @@ fn discovery_filter( if !options.root_span_ids.is_empty() { filters.push(root_span_id_filter(&options.root_span_ids)); } - if let Some(filter) = options - .extra_where_sql - .as_deref() - .map(str::trim) - .filter(|s| !s.is_empty()) - { - filters.push(json!({ "btql": filter })); - } match filters.len() { 0 => None, 1 => filters.into_iter().next(), @@ -652,6 +1561,7 @@ fn root_span_id_filter(root_span_ids: &[String]) -> Value { } async fn resolve_source_project( + base: &BaseArgs, client: &ApiClient, source: &PipelineSourceInspect, ) -> Result { @@ -666,10 +1576,16 @@ async fn 
resolve_source_project( description: None, }); } + let configured_project = + crate::config::configured_project_for_context(base, Some(client.org_name())); let project_name = source .project_name .as_deref() - .context("dataset pipeline source requires projectName or projectId")?; + .or(base.project.as_deref()) + .or(configured_project.as_deref()) + .context( + "dataset pipeline source requires projectName or projectId; pass --source-project or set an active project", + )?; get_project_by_name(client, project_name) .await? .with_context(|| format!("project '{project_name}' not found")) @@ -734,6 +1650,28 @@ fn parse_positive_usize(value: &str) -> std::result::Result { mod tests { use super::*; + fn test_base_args() -> BaseArgs { + BaseArgs { + json: false, + verbose: false, + quiet: false, + quiet_source: None, + no_color: false, + no_input: false, + profile: None, + profile_explicit: false, + org_name: None, + project: None, + api_key: None, + api_key_source: None, + prefer_profile: false, + api_url: None, + app_url: None, + ca_cert: None, + env_file: None, + } + } + #[test] fn prepare_pipeline_records_reuses_dataset_record_validation() { let err = prepare_pipeline_records(vec![json!({ @@ -786,4 +1724,155 @@ mod tests { validate_transform_response(&response).expect_err("rowCount should match rows length"); assert!(err.to_string().contains("rowCount 2")); } + + #[test] + fn pipeline_target_dataset_ref_validates_dataset_name() { + let target = PipelineTargetInspect { + project_id: None, + project_name: Some("Target Project".to_string()), + org_name: None, + dataset_name: " Ground Truth ".to_string(), + description: None, + metadata: None, + }; + assert_eq!( + pipeline_target_dataset_ref(&target).expect("dataset ref"), + "dataset:Ground Truth" + ); + + let err = pipeline_target_dataset_ref(&PipelineTargetInspect { + dataset_name: " ".to_string(), + ..target + }) + .expect_err("empty dataset names should fail"); + assert!(err.to_string().contains("target.datasetName")); + } + + #[test] + fn pipeline_push_base_uses_target_org_and_project() { + let base = test_base_args(); + let target = PipelineTargetInspect { + project_id: Some("project-id".to_string()), + project_name: Some("Project Name".to_string()), + org_name: Some("target-org".to_string()), + dataset_name: "Dataset".to_string(), + description: None, + metadata: None, + }; + + let target_base = base_with_pipeline_target(&base, &target); + assert_eq!(target_base.org_name.as_deref(), Some("target-org")); + assert_eq!(target_base.project.as_deref(), Some("project-id")); + } + + #[test] + fn pipeline_artifacts_default_to_sync_root_shape() { + let root = tempfile::tempdir().expect("tempdir"); + let runner = PipelineRunnerArgs { + pipeline: PathBuf::from("facet_pipeline.py"), + name: None, + runner: None, + }; + let spec = + base_pipeline_artifact_spec(&test_base_args(), &runner, PipelineArtifactStage::Fetch); + + let artifact = resolve_pipeline_output_artifact(root.path(), &runner, spec, None, None) + .expect("artifact path"); + + assert!(artifact + .output_path + .starts_with(root.path().join("dataset_pipeline_facet_pipeline"))); + assert_eq!(artifact.output_path.file_name().unwrap(), "fetched.jsonl"); + } + + #[test] + fn pipeline_input_defaults_to_latest_completed_stage_output() { + let root = tempfile::tempdir().expect("tempdir"); + let runner = PipelineRunnerArgs { + pipeline: PathBuf::from("facet_pipeline.py"), + name: None, + runner: None, + }; + let spec = + base_pipeline_artifact_spec(&test_base_args(), &runner, 
PipelineArtifactStage::Fetch); + let artifact = resolve_pipeline_output_artifact(root.path(), &runner, spec, None, None) + .expect("artifact path"); + + artifact.write_spec().expect("write spec"); + crate::sync::write_jsonl_values( + Some(&artifact.output_path), + &[json!({ "root_span_id": "root-1" })], + ) + .expect("write output"); + artifact + .write_manifest(PipelineArtifactManifest { + schema_version: PIPELINE_ARTIFACT_SCHEMA_VERSION, + spec_hash: artifact.spec_hash.clone(), + spec: artifact.spec.clone(), + status: PipelineArtifactStatus::Completed, + stage: PipelineArtifactStage::Fetch, + input_path: None, + output_path: Some(artifact.output_path.display().to_string()), + refs: Some(1), + candidates: None, + rows: None, + pages: Some(1), + started_at: 1, + updated_at: 2, + completed_at: Some(2), + }) + .expect("write manifest"); + + let resolved = + resolve_pipeline_input_path(&None, root.path(), &runner, PipelineArtifactStage::Fetch) + .expect("default input"); + + assert_eq!(resolved, artifact.output_path); + } + + #[test] + fn pipeline_transform_output_defaults_to_fetch_artifact_dir() { + let root = tempfile::tempdir().expect("tempdir"); + let runner = PipelineRunnerArgs { + pipeline: PathBuf::from("facet_pipeline.py"), + name: None, + runner: None, + }; + let source = PipelineSourceInspect { + project_id: None, + project_name: Some("Loop".to_string()), + org_name: None, + filter: None, + scope: Some(PipelineScope::Span), + }; + let fetch_spec = + base_pipeline_artifact_spec(&test_base_args(), &runner, PipelineArtifactStage::Fetch); + let fetch_artifact = + resolve_pipeline_output_artifact(root.path(), &runner, fetch_spec, None, None) + .expect("fetch artifact"); + let transform_spec = pipeline_transform_artifact_spec( + &test_base_args(), + &runner, + &source, + &PipelineTransformOptions { + max_concurrency: 16, + }, + &fetch_artifact.output_path, + ); + + let transform_artifact = resolve_pipeline_output_artifact( + root.path(), + &runner, + transform_spec, + None, + Some(&fetch_artifact.output_path), + ) + .expect("transform artifact"); + + assert_eq!(transform_artifact.spec_dir, fetch_artifact.spec_dir); + assert_eq!( + transform_artifact.output_path, + fetch_artifact.spec_dir.join("transformed.jsonl") + ); + } } diff --git a/src/eval.rs b/src/eval.rs index 87c5cd09..475fdeee 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -3,9 +3,9 @@ use std::ffi::{OsStr, OsString}; use std::io::IsTerminal; use std::path::{Path, PathBuf}; use std::process::{ExitStatus, Stdio}; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use std::time::{Duration, SystemTime}; use actix_web::dev::Service; use actix_web::http::header::{ @@ -27,8 +27,6 @@ use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use strip_ansi_escapes::strip; -use tokio::io::{AsyncBufReadExt, BufReader}; -use tokio::net::UnixListener; use tokio::process::Command; use tokio::sync::mpsc; use unicode_width::UnicodeWidthStr; @@ -44,6 +42,7 @@ use crate::args::BaseArgs; use crate::auth::resolved_runner_env; use crate::js_runner; use crate::python_runner; +use crate::runner_sse; use crate::ui::{animations_enabled, is_quiet}; const MAX_NAME_LENGTH: usize = 40; @@ -58,7 +57,6 @@ const HEADER_BT_AUTH_TOKEN: &str = "x-bt-auth-token"; const HEADER_BT_ORG_NAME: &str = "x-bt-org-name"; const HEADER_CORS_REQ_PRIVATE_NETWORK: &str = "access-control-request-private-network"; const 
HEADER_CORS_ALLOW_PRIVATE_NETWORK: &str = "access-control-allow-private-network"; -const SSE_SOCKET_BIND_MAX_ATTEMPTS: u8 = 16; const EVAL_NODE_MAX_OLD_SPACE_SIZE_MB: usize = 8192; const MAX_DEFERRED_EVAL_ERRORS: usize = 8; const DEFAULT_EVAL_SAMPLE_SEED: u64 = 0; @@ -83,8 +81,6 @@ fn parse_positive_usize(value: &str) -> std::result::Result { } Ok(parsed) } -static SSE_SOCKET_COUNTER: AtomicU64 = AtomicU64::new(0); - struct EvalRunOutput { status: ExitStatus, dependencies: Vec, @@ -95,7 +91,7 @@ struct EvalRunnerProcess { rx: mpsc::UnboundedReceiver, sse_task: tokio::task::JoinHandle<()>, sse_connected: Arc, - _socket_cleanup_guard: SocketCleanupGuard, + _socket_cleanup_guard: runner_sse::SocketCleanupGuard, } struct EvalProcessOutput { @@ -220,22 +216,6 @@ const JS_RUNNER_SOURCE: &str = include_str!("../scripts/eval-runner.ts"); const PY_RUNNER_SOURCE: &str = include_str!("../scripts/eval-runner.py"); const PYTHON_INTERPRETER_ENV_OVERRIDES: &[&str] = &["BT_EVAL_PYTHON_RUNNER", "BT_EVAL_PYTHON"]; -struct SocketCleanupGuard { - path: PathBuf, -} - -impl SocketCleanupGuard { - fn new(path: PathBuf) -> Self { - Self { path } - } -} - -impl Drop for SocketCleanupGuard { - fn drop(&mut self) { - let _ = std::fs::remove_file(&self.path); - } -} - #[derive(Debug, Copy, Clone, Eq, PartialEq, ValueEnum)] pub enum EvalLanguage { #[value(alias = "js")] @@ -772,7 +752,7 @@ async fn spawn_eval_runner( let (js_runner, py_runner) = prepare_eval_runners()?; let force_esm = matches!(js_mode, JsMode::ForceEsm); - let (listener, socket_path, socket_cleanup_guard) = bind_sse_listener()?; + let (listener, socket_path, socket_cleanup_guard) = runner_sse::bind_sse_listener("bt-eval")?; let (tx, rx) = mpsc::unbounded_channel(); let sse_connected = Arc::new(AtomicBool::new(false)); @@ -782,7 +762,11 @@ async fn spawn_eval_runner( match listener.accept().await { Ok((stream, _)) => { sse_connected_for_task.store(true, Ordering::Relaxed); - if let Err(err) = read_sse_stream(stream, tx_sse.clone()).await { + if let Err(err) = runner_sse::read_sse_stream(stream, |event, data| { + handle_sse_event(event, data, &tx_sse); + }) + .await + { let _ = tx_sse.send(EvalEvent::Error { message: format!("SSE stream error: {err}"), stack: None, @@ -911,7 +895,14 @@ async fn spawn_eval_runner( if let Some(stdout) = stdout { let tx_stdout = tx.clone(); tokio::spawn(async move { - if let Err(err) = forward_stream(stdout, "stdout", tx_stdout).await { + if let Err(err) = runner_sse::forward_stream(stdout, "stdout", |stream, message| { + let _ = tx_stdout.send(EvalEvent::Console { + stream: stream.to_string(), + message, + }); + }) + .await + { eprintln!("Failed to read eval stdout: {err}"); } }); @@ -920,7 +911,14 @@ async fn spawn_eval_runner( if let Some(stderr) = stderr { let tx_stderr = tx.clone(); tokio::spawn(async move { - if let Err(err) = forward_stream(stderr, "stderr", tx_stderr).await { + if let Err(err) = runner_sse::forward_stream(stderr, "stderr", |stream, message| { + let _ = tx_stderr.send(EvalEvent::Console { + stream: stream.to_string(), + message, + }); + }) + .await + { eprintln!("Failed to read eval stderr: {err}"); } }); @@ -941,70 +939,64 @@ async fn spawn_eval_runner( } async fn drive_eval_runner( - mut process: EvalRunnerProcess, + process: EvalRunnerProcess, console_policy: ConsolePolicy, mut on_event: F, ) -> Result where F: FnMut(EvalEvent), { - let mut status = None; + let EvalRunnerProcess { + mut child, + rx, + mut sse_task, + sse_connected, + _socket_cleanup_guard, + } = process; let mut 
dependency_files: Vec = Vec::new(); let mut error_messages: Vec = Vec::new(); let mut stderr_lines: Vec = Vec::new(); - - loop { - tokio::select! { - event = process.rx.recv() => { - match event { - Some(EvalEvent::Dependencies { files }) => { - dependency_files.extend(files.clone()); - on_event(EvalEvent::Dependencies { files }); - } - Some(EvalEvent::Error { message, stack, status }) => { - error_messages.push(message.clone()); - if let Some(stack) = stack.as_ref() { - error_messages.push(stack.clone()); - } - on_event(EvalEvent::Error { message, stack, status }); - } - Some(EvalEvent::Console { stream, message }) => { - if stream == "stderr" && matches!(console_policy, ConsolePolicy::BufferStderr) - { - stderr_lines.push(message); - } else { - on_event(EvalEvent::Console { stream, message }); - } - } - Some(event) => on_event(event), - None => { - if status.is_none() { - status = Some(process.child.wait().await.context("eval runner process failed")?); - if !process.sse_connected.load(Ordering::Relaxed) { - process.sse_task.abort(); - } - } - break; - } + let wait = Box::pin(async { child.wait().await.context("eval runner process failed") }); + let status = runner_sse::drive_runner_events( + rx, + wait, + &mut sse_task, + &sse_connected, + "eval runner process exited without a status", + |event| match event { + EvalEvent::Dependencies { files } => { + dependency_files.extend(files.clone()); + on_event(EvalEvent::Dependencies { files }); + } + EvalEvent::Error { + message, + stack, + status, + } => { + error_messages.push(message.clone()); + if let Some(stack) = stack.as_ref() { + error_messages.push(stack.clone()); } + on_event(EvalEvent::Error { + message, + stack, + status, + }); } - exit_status = process.child.wait(), if status.is_none() => { - status = Some(exit_status.context("eval runner process failed")?); - if !process.sse_connected.load(Ordering::Relaxed) { - process.sse_task.abort(); + EvalEvent::Console { stream, message } => { + if stream == "stderr" && matches!(console_policy, ConsolePolicy::BufferStderr) { + stderr_lines.push(message); + } else { + on_event(EvalEvent::Console { stream, message }); } } - } - - if status.is_some() && process.rx.is_closed() { - break; - } - } - - let _ = process.sse_task.await; + event => on_event(event), + }, + ) + .await?; Ok(EvalProcessOutput { - status: status.context("eval runner process exited without a status")?, + status, dependency_files, error_messages, stderr_lines, @@ -2507,49 +2499,6 @@ fn set_node_heap_size_env(command: &mut Command) { command.env("NODE_OPTIONS", merged); } -fn build_sse_socket_path() -> Result { - let pid = std::process::id(); - let serial = SSE_SOCKET_COUNTER.fetch_add(1, Ordering::Relaxed); - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .context("failed to read system time")? 
- .as_nanos(); - Ok(std::env::temp_dir().join(format!("bt-eval-{pid}-{now}-{serial}.sock"))) -} - -fn bind_sse_listener() -> Result<(UnixListener, PathBuf, SocketCleanupGuard)> { - let mut last_bind_err: Option = None; - for _ in 0..SSE_SOCKET_BIND_MAX_ATTEMPTS { - let socket_path = build_sse_socket_path()?; - let socket_cleanup_guard = SocketCleanupGuard::new(socket_path.clone()); - let _ = std::fs::remove_file(&socket_path); - match UnixListener::bind(&socket_path) { - Ok(listener) => return Ok((listener, socket_path, socket_cleanup_guard)), - Err(err) - if matches!( - err.kind(), - std::io::ErrorKind::AlreadyExists | std::io::ErrorKind::AddrInUse - ) => - { - last_bind_err = Some(err); - continue; - } - Err(err) => { - return Err(err).context("failed to bind SSE unix socket"); - } - } - } - let err = last_bind_err.unwrap_or_else(|| { - std::io::Error::new( - std::io::ErrorKind::AddrInUse, - "failed to allocate a unique SSE socket path", - ) - }); - Err(err).context(format!( - "failed to bind SSE unix socket after {SSE_SOCKET_BIND_MAX_ATTEMPTS} attempts" - )) -} - fn eval_runner_cache_dir() -> PathBuf { let root = std::env::var_os("XDG_CACHE_HOME") .map(PathBuf::from) @@ -2714,57 +2663,6 @@ struct SseDependenciesEventData { files: Vec, } -async fn forward_stream( - stream: T, - name: &'static str, - tx: mpsc::UnboundedSender, -) -> Result<()> -where - T: tokio::io::AsyncRead + Unpin, -{ - let mut lines = BufReader::new(stream).lines(); - while let Some(line) = lines.next_line().await? { - let _ = tx.send(EvalEvent::Console { - stream: name.to_string(), - message: line, - }); - } - Ok(()) -} - -async fn read_sse_stream(stream: T, tx: mpsc::UnboundedSender) -> Result<()> -where - T: tokio::io::AsyncRead + Unpin, -{ - let mut lines = BufReader::new(stream).lines(); - let mut event: Option = None; - let mut data_lines: Vec = Vec::new(); - - while let Some(line) = lines.next_line().await? 
{ - if line.is_empty() { - if event.is_some() || !data_lines.is_empty() { - let data = data_lines.join("\n"); - handle_sse_event(event.take(), data, &tx); - data_lines.clear(); - } - continue; - } - - if let Some(value) = line.strip_prefix("event:") { - event = Some(value.trim().to_string()); - } else if let Some(value) = line.strip_prefix("data:") { - data_lines.push(value.trim_start().to_string()); - } - } - - if event.is_some() || !data_lines.is_empty() { - let data = data_lines.join("\n"); - handle_sse_event(event.take(), data, &tx); - } - - Ok(()) -} - fn handle_sse_event(event: Option, data: String, tx: &mpsc::UnboundedSender) { let event_name = event.unwrap_or_default(); match event_name.as_str() { @@ -4313,8 +4211,8 @@ mod tests { #[test] fn build_sse_socket_path_is_unique_for_consecutive_calls() { - let first = build_sse_socket_path().expect("first socket path"); - let second = build_sse_socket_path().expect("second socket path"); + let first = runner_sse::build_sse_socket_path("bt-eval").expect("first socket path"); + let second = runner_sse::build_sse_socket_path("bt-eval").expect("second socket path"); assert_ne!(first, second); } diff --git a/src/main.rs b/src/main.rs index e59f0d26..23d12afb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -19,6 +19,7 @@ mod project_context; mod projects; mod prompts; mod python_runner; +mod runner_sse; mod scorers; mod self_update; mod setup; diff --git a/src/runner_sse.rs b/src/runner_sse.rs new file mode 100644 index 00000000..abfbbad6 --- /dev/null +++ b/src/runner_sse.rs @@ -0,0 +1,173 @@ +use std::path::PathBuf; +use std::pin::Pin; +use std::process::ExitStatus; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::time::{SystemTime, UNIX_EPOCH}; + +use anyhow::{Context, Result}; +use std::future::Future; +use tokio::io::{AsyncBufReadExt, BufReader}; +use tokio::net::UnixListener; +use tokio::sync::mpsc; + +const SOCKET_BIND_MAX_ATTEMPTS: u8 = 16; +static SOCKET_COUNTER: AtomicU64 = AtomicU64::new(0); + +pub(crate) struct SocketCleanupGuard { + path: PathBuf, +} + +impl SocketCleanupGuard { + fn new(path: PathBuf) -> Self { + Self { path } + } +} + +impl Drop for SocketCleanupGuard { + fn drop(&mut self) { + let _ = std::fs::remove_file(&self.path); + } +} + +pub(crate) fn bind_sse_listener( + prefix: &str, +) -> Result<(UnixListener, PathBuf, SocketCleanupGuard)> { + let mut last_bind_err: Option = None; + for _ in 0..SOCKET_BIND_MAX_ATTEMPTS { + let socket_path = build_sse_socket_path(prefix)?; + let socket_cleanup_guard = SocketCleanupGuard::new(socket_path.clone()); + let _ = std::fs::remove_file(&socket_path); + match UnixListener::bind(&socket_path) { + Ok(listener) => return Ok((listener, socket_path, socket_cleanup_guard)), + Err(err) + if matches!( + err.kind(), + std::io::ErrorKind::AlreadyExists | std::io::ErrorKind::AddrInUse + ) => + { + last_bind_err = Some(err); + continue; + } + Err(err) => { + return Err(err).context("failed to bind SSE unix socket"); + } + } + } + let err = last_bind_err.unwrap_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::AddrInUse, + "failed to allocate a unique SSE socket path", + ) + }); + Err(err).context(format!( + "failed to bind SSE unix socket after {SOCKET_BIND_MAX_ATTEMPTS} attempts" + )) +} + +pub(crate) fn build_sse_socket_path(prefix: &str) -> Result { + let pid = std::process::id(); + let serial = SOCKET_COUNTER.fetch_add(1, Ordering::Relaxed); + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .context("failed to read system time")? 
+        .as_nanos();
+    Ok(std::env::temp_dir().join(format!("{prefix}-{pid}-{now}-{serial}.sock")))
+}
+
+pub(crate) async fn forward_stream<T, F>(
+    stream: T,
+    name: &'static str,
+    mut on_line: F,
+) -> Result<()>
+where
+    T: tokio::io::AsyncRead + Unpin,
+    F: FnMut(&'static str, String),
+{
+    let mut lines = BufReader::new(stream).lines();
+    while let Some(line) = lines.next_line().await? {
+        on_line(name, line);
+    }
+    Ok(())
+}
+
+pub(crate) async fn read_sse_stream<T, F>(stream: T, mut on_event: F) -> Result<()>
+where
+    T: tokio::io::AsyncRead + Unpin,
+    F: FnMut(Option<String>, String),
+{
+    let mut lines = BufReader::new(stream).lines();
+    let mut event: Option<String> = None;
+    let mut data_lines: Vec<String> = Vec::new();
+
+    while let Some(line) = lines.next_line().await? {
+        if line.is_empty() {
+            if event.is_some() || !data_lines.is_empty() {
+                let data = data_lines.join("\n");
+                on_event(event.take(), data);
+                data_lines.clear();
+            }
+            continue;
+        }
+
+        if let Some(value) = line.strip_prefix("event:") {
+            event = Some(value.trim().to_string());
+        } else if let Some(value) = line.strip_prefix("data:") {
+            data_lines.push(value.trim_start().to_string());
+        }
+    }
+
+    if event.is_some() || !data_lines.is_empty() {
+        let data = data_lines.join("\n");
+        on_event(event.take(), data);
+    }
+
+    Ok(())
+}
+
+pub(crate) async fn drive_runner_events<E, F>(
+    mut rx: mpsc::UnboundedReceiver<E>,
+    mut wait: Pin<Box<dyn Future<Output = Result<ExitStatus>> + Send + '_>>,
+    sse_task: &mut tokio::task::JoinHandle<()>,
+    sse_connected: &AtomicBool,
+    missing_status_message: &'static str,
+    mut on_event: F,
+) -> Result<ExitStatus>
+where
+    F: FnMut(E),
+{
+    let mut status: Option<ExitStatus> = None;
+
+    loop {
+        tokio::select! {
+            event = rx.recv() => {
+                match event {
+                    Some(event) => on_event(event),
+                    None => {
+                        if status.is_none() {
+                            status = Some(wait.as_mut().await?);
+                            abort_unconnected_sse(sse_task, sse_connected);
+                        }
+                        break;
+                    }
+                }
+            }
+            wait_result = wait.as_mut(), if status.is_none() => {
+                status = Some(wait_result?);
+                abort_unconnected_sse(sse_task, sse_connected);
+            }
+        }
+
+        if status.is_some() && rx.is_closed() {
+            break;
+        }
+    }
+
+    let _ = sse_task.await;
+    status.context(missing_status_message)
+}
+
+fn abort_unconnected_sse(sse_task: &mut tokio::task::JoinHandle<()>, sse_connected: &AtomicBool) {
+    if !sse_connected.load(Ordering::Relaxed) {
+        sse_task.abort();
+    }
+}
diff --git a/src/sync.rs b/src/sync.rs
index a01f7a5d..6d1b9e90 100644
--- a/src/sync.rs
+++ b/src/sync.rs
@@ -47,6 +47,14 @@ pub(crate) fn default_workers() -> usize {
         .unwrap_or(DEFAULT_WORKERS_FALLBACK)
 }
 
+#[derive(Debug, Clone)]
+pub(crate) struct SyncPushFileArgs {
+    pub object_ref: String,
+    pub input: PathBuf,
+    pub root: PathBuf,
+    pub fresh: bool,
+}
+
 #[derive(Debug, Clone, Args)]
 #[command(after_help = "\
 Examples:
@@ -564,6 +572,34 @@ pub async fn run(base: BaseArgs, args: SyncArgs) -> Result<()> {
     }
 }
 
+pub(crate) async fn push_jsonl_file(base: BaseArgs, args: SyncPushFileArgs) -> Result<()> {
+    let json_output = base.json;
+    let ctx = login(&base).await?;
+    let client = ApiClient::new(&ctx)?;
+    let project = base.project.clone().or_else(|| {
+        crate::config::configured_project_for_context(&base, ctx.login.org_name().as_deref())
+    });
+
+    run_push(
+        json_output,
+        &ctx,
+        &client,
+        project.as_deref(),
+        PushArgs {
+            object_ref: args.object_ref,
+            input: Some(args.input),
+            filter: None,
+            traces: None,
+            spans: None,
+            page_size: DEFAULT_PAGE_SIZE,
+            fresh: args.fresh,
+            root: args.root,
+            workers: default_workers(),
+        },
+    )
+    .await
+}
+
 async fn run_pull(
     json_output: bool,
     verbose: bool,
@@ -3449,13 
+3485,26 @@ fn sanitize_segment(value: &str) -> String { } } -fn spec_dir(root: &Path, object: &ObjectRef, hash: &str) -> PathBuf { +pub(crate) fn artifact_base_dir(root: &Path, object_type: &str, object_name: &str) -> PathBuf { let object_key = format!( "{}_{}", - sanitize_segment(object.object_type.as_str()), - sanitize_segment(&object.object_name) + sanitize_segment(object_type), + sanitize_segment(object_name) ); - root.join(object_key).join(&hash[..12]) + root.join(object_key) +} + +pub(crate) fn artifact_spec_dir( + root: &Path, + object_type: &str, + object_name: &str, + hash: &str, +) -> PathBuf { + artifact_base_dir(root, object_type, object_name).join(&hash[..12]) +} + +fn spec_dir(root: &Path, object: &ObjectRef, hash: &str) -> PathBuf { + artifact_spec_dir(root, object.object_type.as_str(), &object.object_name, hash) } fn legacy_spec_dir( @@ -3504,8 +3553,8 @@ fn resolve_spec_dir( } } -fn spec_hash(spec: &SyncSpec) -> Result { - let canonical = serde_json::to_vec(spec).context("failed to serialize sync spec")?; +pub(crate) fn stable_spec_hash(spec: &T) -> Result { + let canonical = serde_json::to_vec(spec).context("failed to serialize spec")?; let mut hasher = Sha256::new(); hasher.update(&canonical); let digest = hasher.finalize(); @@ -3516,7 +3565,11 @@ fn spec_hash(spec: &SyncSpec) -> Result { Ok(out) } -fn epoch_seconds() -> u64 { +fn spec_hash(spec: &SyncSpec) -> Result { + stable_spec_hash(spec) +} + +pub(crate) fn epoch_seconds() -> u64 { SystemTime::now() .duration_since(UNIX_EPOCH) .map(|d| d.as_secs()) @@ -3803,15 +3856,20 @@ pub(crate) fn write_jsonl_value( Ok(encoded.len() + 1) } +pub(crate) fn create_jsonl_file_writer(path: &Path) -> Result> { + if let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) { + fs::create_dir_all(parent) + .with_context(|| format!("failed to create {}", parent.display()))?; + } + Ok(BufWriter::new(File::create(path).with_context(|| { + format!("failed to create {}", path.display()) + })?)) +} + +#[cfg(test)] pub(crate) fn write_jsonl_values(out: Option<&Path>, values: &[T]) -> Result<()> { let mut writer: Box = if let Some(path) = out { - if let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) { - fs::create_dir_all(parent) - .with_context(|| format!("failed to create {}", parent.display()))?; - } - Box::new(BufWriter::new(File::create(path).with_context(|| { - format!("failed to create {}", path.display()) - })?)) + Box::new(create_jsonl_file_writer(path)?) 
} else { Box::new(BufWriter::new(std::io::stdout())) }; @@ -4238,7 +4296,7 @@ fn value_as_string(value: Option<&Value>) -> Option { } } -fn write_json_atomic(path: &Path, value: &T) -> Result<()> { +pub(crate) fn write_json_atomic(path: &Path, value: &T) -> Result<()> { let parent = path .parent() .ok_or_else(|| anyhow!("path has no parent: {}", path.display()))?; @@ -4257,7 +4315,7 @@ fn write_json_atomic(path: &Path, value: &T) -> Result<()> { Ok(()) } -fn read_json_file(path: &Path) -> Result { +pub(crate) fn read_json_file(path: &Path) -> Result { let bytes = fs::read(path).with_context(|| format!("failed to read {}", path.display()))?; serde_json::from_slice(&bytes).with_context(|| format!("failed to parse {}", path.display())) } From bbb36898119b8b912a781557f4d7af244f4ad388 Mon Sep 17 00:00:00 2001 From: Ankur Goyal Date: Sun, 3 May 2026 12:28:35 -0400 Subject: [PATCH 5/7] print destination URL --- src/sync.rs | 92 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 89 insertions(+), 3 deletions(-) diff --git a/src/sync.rs b/src/sync.rs index 6d1b9e90..e831c8d7 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -410,6 +410,8 @@ struct ResolvedPushDestination { object: ObjectRef, object_ref: String, project_id: String, + project_name: String, + object_name: String, run_id: Option, } @@ -418,13 +420,17 @@ struct ResolvedDestination { object: ObjectRef, object_ref: String, project_id: String, + project_name: String, + object_name: String, run_id: Option, } #[derive(Debug, Clone)] struct ResolvedNamedObjectTarget { object_id: String, + object_name: String, project_id: String, + project_name: String, } #[derive(Debug, Clone, Copy)] @@ -1622,6 +1628,7 @@ async fn run_push( args: PushArgs, ) -> Result<()> { let destination = resolve_push_destination(client, &args.object_ref, project_selector).await?; + let object_url = push_destination_url(&ctx.app_url, client.org_name(), &destination); let object = destination.object.clone(); let (scope, limit) = resolve_push_scope_and_limit(args.traces, args.spans)?; @@ -1696,6 +1703,7 @@ async fn run_push( "status": "completed", "message": "already completed for this spec", "source_path": state.source_path, + "object_url": object_url, "items_done": state.items_done, "pages_done": state.pages_done }))? @@ -1705,6 +1713,7 @@ async fn run_push( "Sync already completed for this spec. 
input={} items={} pages={}", state.source_path, state.items_done, state.pages_done ); + println!(" URL: {object_url}"); } return Ok(()); } @@ -1949,6 +1958,7 @@ async fn run_push( "status": "interrupted", "spec_dir": spec_dir, "input_path": input_path, + "object_url": object_url, "rows_uploaded": state.items_done, "pages_done": state.pages_done, "bytes_sent": state.bytes_sent, @@ -1964,6 +1974,7 @@ async fn run_push( format_u64_commas(state.bytes_sent) ); println!(" Resume: rerun the same command (use --fresh to restart)"); + println!(" URL: {object_url}"); } return Ok(()); } @@ -1988,6 +1999,7 @@ async fn run_push( "status": "completed", "spec_dir": spec_dir, "input_path": input_path, + "object_url": object_url, "rows_uploaded": state.items_done, "pages_done": state.pages_done, "bytes_sent": state.bytes_sent @@ -2001,6 +2013,7 @@ async fn run_push( let spans_per_sec = spans_done as f64 / elapsed_secs as f64; let bytes_per_sec = state.bytes_sent as f64 / elapsed_secs as f64; println!("Push complete"); + println!(" URL: {object_url}"); println!(" Input: {}", input_path.display()); println!(" Time: {}", format_duration(elapsed_secs)); println!(" Traces: {}", format_usize_commas(traces_done)); @@ -2021,6 +2034,35 @@ async fn run_push( Ok(()) } +fn push_destination_url( + app_url: &str, + org_name: &str, + destination: &ResolvedPushDestination, +) -> String { + let project_name = if destination.project_name.trim().is_empty() { + destination.project_id.as_str() + } else { + destination.project_name.as_str() + }; + let object_name = if destination.object_name.trim().is_empty() { + destination.object.object_name.as_str() + } else { + destination.object_name.as_str() + }; + let path = match destination.object.object_type { + ObjectType::ProjectLogs => "logs".to_string(), + ObjectType::Experiment => format!("experiments/{}", encode(object_name)), + ObjectType::Dataset => format!("datasets/{}", encode(object_name)), + }; + format!( + "{}/app/{}/p/{}/{}", + app_url.trim_end_matches('/'), + encode(org_name), + encode(project_name), + path + ) +} + fn run_status(json_output: bool, args: StatusArgs) -> Result<()> { let object = parse_object_ref(&args.object_ref)?; let (scope, limit) = resolve_status_scope_and_limit(args.traces, args.spans)?; @@ -3003,7 +3045,9 @@ async fn resolve_destination( Ok(ResolvedDestination { object_ref: format!("project_logs:{}", project.id), object, - project_id: project.id, + project_id: project.id.clone(), + project_name: project.name.clone(), + object_name: project.name, run_id: None, }) } @@ -3037,6 +3081,8 @@ async fn resolve_destination( object_ref: format!("experiment:{}", resolved.object_id), object, project_id: resolved.project_id, + project_name: resolved.project_name, + object_name: resolved.object_name, run_id, }) } @@ -3061,6 +3107,8 @@ async fn resolve_destination( object_ref: format!("dataset:{}", resolved.object_id), object, project_id: resolved.project_id, + project_name: resolved.project_name, + object_name: resolved.object_name, run_id: None, }) } @@ -3117,7 +3165,9 @@ async fn resolve_named_object_target( )?; return Ok(ResolvedNamedObjectTarget { object_id: object.id.clone(), + object_name: object.name.clone(), project_id: project.id.clone(), + project_name: project.name.clone(), }); } @@ -3133,7 +3183,9 @@ async fn resolve_named_object_target( if let Some(object) = objects.iter().find(|value| value.id == object_selector) { return Ok(ResolvedNamedObjectTarget { object_id: object.id.clone(), + object_name: object.name.clone(), project_id: project.id.clone(), 
+ project_name: project.name.clone(), }); } } @@ -3183,7 +3235,9 @@ async fn resolve_push_experiment_target( ) { return Ok(ResolvedNamedObjectTarget { object_id: object.id.clone(), - project_id: project.id, + object_name: object.name.clone(), + project_id: project.id.clone(), + project_name: project.name.clone(), }); } @@ -3210,7 +3264,9 @@ async fn resolve_push_experiment_target( Ok(ResolvedNamedObjectTarget { object_id: created.id, + object_name: created.name, project_id: project.id, + project_name: project.name, }) } @@ -3246,7 +3302,9 @@ async fn resolve_push_dataset_target( ) { return Ok(ResolvedNamedObjectTarget { object_id: object.id.clone(), - project_id: project.id, + object_name: object.name.clone(), + project_id: project.id.clone(), + project_name: project.name.clone(), }); } @@ -3273,7 +3331,9 @@ async fn resolve_push_dataset_target( Ok(ResolvedNamedObjectTarget { object_id: created.id, + object_name: created.name, project_id: project.id, + project_name: project.name, }) } @@ -3321,6 +3381,8 @@ async fn resolve_push_destination( object: resolved.object, object_ref: resolved.object_ref, project_id: resolved.project_id, + project_name: resolved.project_name, + object_name: resolved.object_name, run_id: resolved.run_id, }) } @@ -4497,6 +4559,30 @@ mod tests { Ok(()) } + #[test] + fn push_destination_url_links_to_dataset_object() { + let destination = ResolvedPushDestination { + object: ObjectRef { + object_type: ObjectType::Dataset, + object_name: "dataset-id".to_string(), + }, + object_ref: "dataset:dataset-id".to_string(), + project_id: "project-id".to_string(), + project_name: "Facet Optimizer".to_string(), + object_name: "Loop Facet Ground Truth".to_string(), + run_id: None, + }; + + assert_eq!( + push_destination_url( + "https://www.braintrust.dev/", + "braintrustdata.com", + &destination + ), + "https://www.braintrust.dev/app/braintrustdata.com/p/Facet%20Optimizer/datasets/Loop%20Facet%20Ground%20Truth" + ); + } + #[test] fn push_checkpoint_line_offset_advances_only_after_commit() { let mut state = From ff6d0aa8ea99954159da2725dad0697e874ccdac Mon Sep 17 00:00:00 2001 From: Ankur Goyal Date: Sun, 3 May 2026 15:30:24 -0400 Subject: [PATCH 6/7] a few more fixes --- src/datasets/pipeline.rs | 138 +++++++++++++++++++++++++++++++++++---- 1 file changed, 127 insertions(+), 11 deletions(-) diff --git a/src/datasets/pipeline.rs b/src/datasets/pipeline.rs index 4f98e61e..cdb47299 100644 --- a/src/datasets/pipeline.rs +++ b/src/datasets/pipeline.rs @@ -547,6 +547,16 @@ fn apply_source_overrides(source: &mut PipelineSourceInspect, args: &PipelineSou } } +fn source_with_resolved_project( + source: &PipelineSourceInspect, + project: &Project, +) -> PipelineSourceInspect { + let mut source = source.clone(); + source.project_id = Some(project.id.clone()); + source.project_name = Some(project.name.clone()); + source +} + fn apply_target_overrides(target: &mut PipelineTargetInspect, args: &PipelineTargetArgs) { if let Some(project_name) = args.target_project.as_deref() { target.project_name = Some(project_name.to_string()); @@ -650,8 +660,10 @@ fn pipeline_language(pipeline_file: &Path) -> Result { async fn fetch_refs( base: &BaseArgs, args: PipelineFetchArgs, - inspect: PipelineInspect, + mut inspect: PipelineInspect, ) -> Result<()> { + let source_project = resolve_pipeline_source_project(base, &inspect.source).await?; + inspect.source = source_with_resolved_project(&inspect.source, &source_project); let spec = pipeline_fetch_artifact_spec(base, &args.runner, &inspect.source, &args.fetch); 
let artifact = resolve_pipeline_output_artifact( &args.artifacts.root, @@ -685,6 +697,8 @@ async fn fetch_refs( "refs": result.refs, "pages": result.pages, "scope": match PipelineScope::from_source(&inspect.source) { PipelineScope::Trace => "trace", PipelineScope::Span => "span" }, + "source_project": source_project.name, + "source_project_id": source_project.id, "out": artifact.output_path.display().to_string(), }), false, @@ -807,6 +821,41 @@ fn resolve_pipeline_input_path( } } +fn read_pipeline_stage_manifest_for_output( + output_path: &Path, + stage: PipelineArtifactStage, +) -> Result<Option<PipelineArtifactManifest>> { + let Some(parent) = output_path.parent() else { + return Ok(None); + }; + let manifest_path = parent.join(stage.manifest_file()); + if !manifest_path.exists() { + return Ok(None); + } + let manifest = read_json_file::<PipelineArtifactManifest>(&manifest_path) + .with_context(|| format!("failed to read {}", manifest_path.display()))?; + if manifest.stage != stage || manifest.status != PipelineArtifactStatus::Completed { + return Ok(None); + } + Ok(Some(manifest)) +} + +fn base_with_pipeline_artifact_context( + base: &BaseArgs, + manifest: Option<&PipelineArtifactManifest>, +) -> BaseArgs { + let mut base = base.clone(); + if let Some(spec) = manifest.map(|manifest| &manifest.spec) { + if base.project.is_none() { + base.project = spec.cli_project.clone(); + } + if base.org_name.is_none() { + base.org_name = spec.cli_org.clone(); + } + } + base +} + fn resolve_latest_pipeline_stage_output( root: &Path, runner: &PipelineRunnerArgs, @@ -1088,23 +1137,27 @@ fn forward_blocking_stream( } async fn transform_refs(base: &BaseArgs, args: PipelineTransformArgs) -> Result<()> { - let inspect = inspect_with_overrides( - inspect_pipeline(base, &args.runner).await?, - Some(&args.source), - None, - ); - let source_project = resolve_pipeline_source_project(base, &inspect.source).await?; let input_path = resolve_pipeline_input_path( &args.input, &args.artifacts.root, &args.runner, PipelineArtifactStage::Fetch, )?; + let fetch_manifest = + read_pipeline_stage_manifest_for_output(&input_path, PipelineArtifactStage::Fetch)?; + let inspect = inspect_pipeline(base, &args.runner).await?; + let mut source = fetch_manifest + .as_ref() + .and_then(|manifest| manifest.spec.source.clone()) + .unwrap_or(inspect.source); + apply_source_overrides(&mut source, &args.source); + let source_base = base_with_pipeline_artifact_context(base, fetch_manifest.as_ref()); + let source_project = resolve_pipeline_source_project(&source_base, &source).await?; let refs = read_jsonl_values(&input_path)?; let spec = pipeline_transform_artifact_spec( - base, + &source_base, &args.runner, - &inspect.source, + &source, &args.transform, &input_path, ); @@ -1119,10 +1172,10 @@ async fn transform_refs(base: &BaseArgs, args: PipelineTransformArgs) -> Result< let started_at = epoch_seconds(); let mut writer = create_jsonl_file_writer(&artifact.output_path)?; let response = transform_source_refs( - base, + &source_base, &args.runner, &source_project.id, - &inspect.source, + &source, refs, args.transform.max_concurrency, Some(&mut writer as &mut dyn Write), ) .await?; @@ -1765,6 +1818,69 @@ mod tests { assert_eq!(target_base.project.as_deref(), Some("project-id")); } + #[test] + fn pipeline_source_artifact_records_resolved_project() { + let source = PipelineSourceInspect { + project_id: None, + project_name: None, + org_name: None, + filter: Some("span_attributes.type = 'llm'".to_string()), + scope: Some(PipelineScope::Span), + }; + let project = Project { + id: "project-id".to_string(), + name:
"Loop".to_string(), + org_id: "org-id".to_string(), + description: None, + }; + + let resolved = source_with_resolved_project(&source, &project); + + assert_eq!(resolved.project_id.as_deref(), Some("project-id")); + assert_eq!(resolved.project_name.as_deref(), Some("Loop")); + assert_eq!(resolved.filter, source.filter); + assert_eq!(resolved.scope, source.scope); + } + + #[test] + fn pipeline_transform_base_inherits_fetch_artifact_context() { + let base = test_base_args(); + let manifest = PipelineArtifactManifest { + schema_version: PIPELINE_ARTIFACT_SCHEMA_VERSION, + spec_hash: "hash".to_string(), + spec: PipelineArtifactSpec { + schema_version: PIPELINE_ARTIFACT_SCHEMA_VERSION, + kind: PIPELINE_ARTIFACT_OBJECT_TYPE.to_string(), + pipeline: "facet_pipeline.py".to_string(), + name: None, + cli_project: Some("Loop".to_string()), + cli_org: Some("braintrustdata.com".to_string()), + stage: PipelineArtifactStage::Fetch, + source: None, + target: None, + fetch: None, + transform: None, + input_path: None, + }, + status: PipelineArtifactStatus::Completed, + stage: PipelineArtifactStage::Fetch, + input_path: None, + output_path: None, + refs: Some(1), + candidates: None, + rows: None, + pages: Some(1), + started_at: 1, + updated_at: 2, + completed_at: Some(2), + }; + + let inherited = base_with_pipeline_artifact_context(&base, Some(&manifest)); + + assert_eq!(inherited.project.as_deref(), Some("Loop")); + assert_eq!(inherited.org_name.as_deref(), Some("braintrustdata.com")); + } + #[test] fn pipeline_artifacts_default_to_sync_root_shape() { let root = tempfile::tempdir().expect("tempdir"); From 15db7da50e9207b67fb1a6186f0879f7cf2dd675 Mon Sep 17 00:00:00 2001 From: Ankur Goyal Date: Sun, 3 May 2026 16:48:39 -0400 Subject: [PATCH 7/7] fix non unix --- scripts/dataset-pipeline-runner.py | 18 ++-- scripts/dataset-pipeline-runner.ts | 20 ++++- src/datasets/mod.rs | 2 +- src/datasets/pipeline.rs | 103 ++++++++++++--------- src/eval.rs | 53 +++++------ src/runner_sse.rs | 138 +++++++++++++++++++++++++++-- 6 files changed, 245 insertions(+), 89 deletions(-) diff --git a/scripts/dataset-pipeline-runner.py b/scripts/dataset-pipeline-runner.py index bfefd218..a15d453c 100644 --- a/scripts/dataset-pipeline-runner.py +++ b/scripts/dataset-pipeline-runner.py @@ -38,9 +38,8 @@ class SseWriter: - def __init__(self, sock_path: str): - self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self._socket.connect(sock_path) + def __init__(self, sock: socket.socket): + self._socket = sock def send(self, event: str, payload: Any) -> None: data = payload if isinstance(payload, str) else json.dumps(payload, separators=(",", ":")) @@ -54,9 +53,18 @@ def close(self) -> None: def create_sse_writer() -> SseWriter | None: sock_path = os.getenv("BT_DATASET_PIPELINE_SSE_SOCK") if not sock_path: - return None + addr = os.getenv("BT_DATASET_PIPELINE_SSE_ADDR") + if not addr: + return None + if ":" not in addr: + raise ValueError("BT_DATASET_PIPELINE_SSE_ADDR must be in host:port format") + host, port_str = addr.rsplit(":", 1) + sock = socket.create_connection((host, int(port_str))) + return SseWriter(sock) try: - return SseWriter(sock_path) + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.connect(sock_path) + return SseWriter(sock) except Exception as exc: print(f"Failed to connect to dataset pipeline socket: {exc}", file=sys.stderr) return None diff --git a/scripts/dataset-pipeline-runner.ts b/scripts/dataset-pipeline-runner.ts index 4892938d..d5f8b4db 100644 --- 
a/scripts/dataset-pipeline-runner.ts +++ b/scripts/dataset-pipeline-runner.ts @@ -261,13 +261,27 @@ function serializeSseEvent(event: { event?: string; data: string }): string { function createSseWriter(): SseWriter | null { const sock = process.env.BT_DATASET_PIPELINE_SSE_SOCK; - if (!sock) { + const addr = process.env.BT_DATASET_PIPELINE_SSE_ADDR; + if (!sock && !addr) { + return null; + } + let socket: net.Socket; + if (sock) { + socket = net.createConnection({ path: sock }); + } else if (addr) { + const [host, portStr] = addr.split(":"); + const port = Number(portStr); + if (!host || !Number.isFinite(port)) { + throw new Error(`Invalid BT_DATASET_PIPELINE_SSE_ADDR: ${addr}`); + } + socket = net.createConnection({ host, port }); + socket.setNoDelay(true); + } else { return null; } - const socket = net.createConnection({ path: sock }); socket.on("error", (err) => { console.error( - `Failed to connect to dataset pipeline socket: ${ + `Failed to connect to dataset pipeline SSE endpoint: ${ err instanceof Error ? err.message : String(err) }`, ); diff --git a/src/datasets/mod.rs b/src/datasets/mod.rs index e866353e..39f6075e 100644 --- a/src/datasets/mod.rs +++ b/src/datasets/mod.rs @@ -108,7 +108,7 @@ enum DatasetsCommands { View(ViewArgs), /// Delete a dataset Delete(DeleteArgs), - /// Run full dataset pipelines, or stage fetch/transform/push + /// Run full dataset pipelines, or stage pull/transform/push Pipeline(pipeline::PipelineArgs), } diff --git a/src/datasets/pipeline.rs b/src/datasets/pipeline.rs index cdb47299..957532b1 100644 --- a/src/datasets/pipeline.rs +++ b/src/datasets/pipeline.rs @@ -43,7 +43,7 @@ const PIPELINE_ARTIFACT_SCHEMA_VERSION: u32 = 1; #[command(after_help = "\ Use `run` to run the whole pipeline. -For staged workflows, run `fetch`, then `transform`, inspect or edit the transformed JSONL, then upload it with: +For staged workflows, run `pull`, then `transform`, inspect or edit the transformed JSONL, then upload it with: bt datasets pipeline push ./pipeline.ts `push` reads the pipeline target and delegates to `bt sync push`. @@ -55,10 +55,10 @@ pub struct PipelineArgs { #[derive(Debug, Clone, Subcommand)] enum PipelineCommands { - /// Fetch, transform, and insert dataset rows + /// Pull, transform, and insert dataset rows Run(PipelineRunArgs), - /// Discover source trace/span refs to JSONL - Fetch(PipelineFetchArgs), + /// Pull source trace/span refs to JSONL + Pull(PipelineFetchArgs), /// Transform candidate JSONL into proposed dataset row JSONL Transform(PipelineTransformArgs), /// Push transformed dataset rows to the pipeline target @@ -145,9 +145,16 @@ struct PipelineFetchOptions { #[derive(Debug, Clone, Args)] struct PipelineTransformOptions { - /// Maximum concurrent transform calls - #[arg(long, default_value_t = 16, value_parser = parse_positive_usize)] - max_concurrency: usize, + /// Maximum concurrent transform calls. Defaults to the logical CPU count. + #[arg(long, value_parser = parse_positive_usize)] + max_concurrency: Option, +} + +impl PipelineTransformOptions { + fn max_concurrency(&self) -> usize { + self.max_concurrency + .unwrap_or_else(default_transform_concurrency) + } } #[derive(Debug, Clone, Args)] @@ -208,7 +215,7 @@ struct PipelineTransformArgs { #[command(flatten)] transform: PipelineTransformOptions, - /// Input candidate JSONL file. Defaults to the latest fetch output under --root. + /// Input candidate JSONL file. Defaults to the latest pull output under --root. 
#[arg(long = "in")] input: Option, @@ -248,7 +255,15 @@ pub async fn run(base: BaseArgs, args: PipelineArgs) -> Result<()> { let tempdir = tempfile::tempdir().context("failed to create dataset pipeline temp dir")?; let refs_path = tempdir.path().join("discovered.jsonl"); - discover_refs(&base, &inspect, &args.fetch, &refs_path).await?; + print_pipeline_status(&base, "Fetching source refs..."); + let fetch_result = discover_refs(&base, &inspect, &args.fetch, &refs_path).await?; + print_pipeline_status( + &base, + format!( + "Fetched {} source ref(s) across {} page(s).", + fetch_result.refs, fetch_result.pages + ), + ); let refs = read_jsonl_values(&refs_path)?; let source_project = resolve_pipeline_source_project(&base, &inspect.source).await?; @@ -258,7 +273,7 @@ pub async fn run(base: BaseArgs, args: PipelineArgs) -> Result<()> { &source_project.id, &inspect.source, refs, - args.transform.max_concurrency, + args.transform.max_concurrency(), None, ) .await?; @@ -275,7 +290,7 @@ pub async fn run(base: BaseArgs, args: PipelineArgs) -> Result<()> { false, ) } - PipelineCommands::Fetch(args) => { + PipelineCommands::Pull(args) => { let inspect = inspect_with_overrides( inspect_pipeline(&base, &args.runner).await?, Some(&args.source), @@ -386,7 +401,7 @@ enum PipelineArtifactStage { impl PipelineArtifactStage { fn command(self) -> &'static str { match self { - PipelineArtifactStage::Fetch => "fetch", + PipelineArtifactStage::Fetch => "pull", PipelineArtifactStage::Transform => "transform", } } @@ -440,7 +455,7 @@ struct PipelineTransformArtifactOptions { impl From<&PipelineTransformOptions> for PipelineTransformArtifactOptions { fn from(options: &PipelineTransformOptions) -> Self { Self { - max_concurrency: options.max_concurrency, + max_concurrency: options.max_concurrency(), } } } @@ -948,43 +963,37 @@ where F: FnMut(PipelineRunnerEvent), { let mut command = build_runner_command(base, stage, runner, |_, _| Ok(())).await?; - let (listener, socket_path, socket_cleanup_guard) = - runner_sse::bind_sse_listener("bt-dataset-pipeline")?; + let (listener, sse_guard) = runner_sse::bind_sse_listener("bt-dataset-pipeline")?; let (tx, rx) = mpsc::unbounded_channel::(); let sse_connected = Arc::new(AtomicBool::new(false)); let tx_sse = tx.clone(); let sse_connected_for_task = Arc::clone(&sse_connected); let mut sse_task = tokio::spawn(async move { - match listener.accept().await { - Ok((stream, _)) => { + if let Err(err) = runner_sse::accept_and_read_sse_stream( + listener, + || { sse_connected_for_task.store(true, Ordering::Relaxed); - if let Err(err) = runner_sse::read_sse_stream(stream, |event, data| { - handle_pipeline_sse_event(event, data, &tx_sse); - }) - .await - { - let _ = tx_sse.send(PipelineRunnerEvent::Error { - message: format!("SSE stream error: {err}"), - stack: None, - status: None, - }); - } - } - Err(err) => { - let _ = tx_sse.send(PipelineRunnerEvent::Error { - message: format!("Failed to accept SSE connection: {err}"), - stack: None, - status: None, - }); - } + }, + |event, data| { + handle_pipeline_sse_event(event, data, &tx_sse); + }, + ) + .await + { + let _ = tx_sse.send(PipelineRunnerEvent::Error { + message: format!("SSE stream error: {err}"), + stack: None, + status: None, + }); } }); - command.env( + let (sse_env_name, sse_env_value) = sse_guard.env( "BT_DATASET_PIPELINE_SSE_SOCK", - socket_path.to_string_lossy().to_string(), + "BT_DATASET_PIPELINE_SSE_ADDR", ); + command.env(sse_env_name, sse_env_value); command.stdin(Stdio::piped()); command.stdout(Stdio::piped()); 
command.stderr(Stdio::piped()); @@ -1051,7 +1060,7 @@ where ) .await?; - let _socket_cleanup_guard = socket_cleanup_guard; + let _sse_guard = sse_guard; if !status.success() { let detail = if errors.is_empty() { String::new() @@ -1177,7 +1186,7 @@ async fn transform_refs(base: &BaseArgs, args: PipelineTransformArgs) -> Result< &source_project.id, &source, refs, - args.transform.max_concurrency, + args.transform.max_concurrency(), Some(&mut writer as &mut dyn Write), ) .await?; @@ -1613,6 +1622,12 @@ fn root_span_id_filter(root_span_ids: &[String]) -> Value { }) } +fn default_transform_concurrency() -> usize { + std::thread::available_parallelism() + .map(|parallelism| parallelism.get()) + .unwrap_or(16) +} + async fn resolve_source_project( base: &BaseArgs, client: &ApiClient, @@ -1660,6 +1675,12 @@ fn print_summary(base: &BaseArgs, summary: Value, force_stderr: bool) -> Result< Ok(()) } +fn print_pipeline_status(base: &BaseArgs, message: impl AsRef<str>) { + if !base.json && !base.quiet { + eprintln!("{}", message.as_ref()); + } +} + fn summary_value(value: &Value) -> String { match value { Value::String(value) => value.clone(), @@ -1971,7 +1992,7 @@ mod tests { &runner, &source, &PipelineTransformOptions { - max_concurrency: 16, + max_concurrency: Some(16), }, &fetch_artifact.output_path, ); diff --git a/src/eval.rs b/src/eval.rs index 475fdeee..6402e57f 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -91,7 +91,7 @@ struct EvalRunnerProcess { rx: mpsc::UnboundedReceiver<EvalEvent>, sse_task: tokio::task::JoinHandle<()>, sse_connected: Arc<AtomicBool>, - _socket_cleanup_guard: runner_sse::SocketCleanupGuard, + _sse_guard: runner_sse::SseListenerGuard, } struct EvalProcessOutput { @@ -752,36 +752,30 @@ async fn spawn_eval_runner( let (js_runner, py_runner) = prepare_eval_runners()?; let force_esm = matches!(js_mode, JsMode::ForceEsm); - let (listener, socket_path, socket_cleanup_guard) = runner_sse::bind_sse_listener("bt-eval")?; + let (listener, sse_guard) = runner_sse::bind_sse_listener("bt-eval")?; let (tx, rx) = mpsc::unbounded_channel(); let sse_connected = Arc::new(AtomicBool::new(false)); let tx_sse = tx.clone(); let sse_connected_for_task = Arc::clone(&sse_connected); let sse_task = tokio::spawn(async move { - match listener.accept().await { - Ok((stream, _)) => { + if let Err(err) = runner_sse::accept_and_read_sse_stream( + listener, + || { sse_connected_for_task.store(true, Ordering::Relaxed); - if let Err(err) = runner_sse::read_sse_stream(stream, |event, data| { - handle_sse_event(event, data, &tx_sse); - }) - .await - { - let _ = tx_sse.send(EvalEvent::Error { - message: format!("SSE stream error: {err}"), - stack: None, - status: None, - }); - } - } - Err(err) => { - let _ = tx_sse.send(EvalEvent::Error { - message: format!("Failed to accept SSE connection: {err}"), - stack: None, - status: None, - }); - } - }; + }, + |event, data| { + handle_sse_event(event, data, &tx_sse); + }, + ) + .await + { + let _ = tx_sse.send(EvalEvent::Error { + message: format!("SSE stream error: {err}"), + stack: None, + status: None, + }); + } }); let (mut cmd, runner_kind) = match language { @@ -880,10 +874,8 @@ async fn spawn_eval_runner( serde_json::to_string(&payload).context("failed to serialize matrix axes")?; cmd.env("BT_EVAL_MATRIX_JSON", serialized); } - cmd.env( - "BT_EVAL_SSE_SOCK", - socket_path.to_string_lossy().to_string(), - ); + let (sse_env_name, sse_env_value) = sse_guard.env("BT_EVAL_SSE_SOCK", "BT_EVAL_SSE_ADDR"); + cmd.env(sse_env_name, sse_env_value); cmd.stdout(Stdio::piped());
cmd.stderr(Stdio::piped()); @@ -932,7 +924,7 @@ async fn spawn_eval_runner( rx, sse_task, sse_connected, - _socket_cleanup_guard: socket_cleanup_guard, + _sse_guard: sse_guard, }, runner_kind, }) @@ -951,7 +943,7 @@ where rx, mut sse_task, sse_connected, - _socket_cleanup_guard, + _sse_guard, } = process; let mut dependency_files: Vec<PathBuf> = Vec::new(); let mut error_messages: Vec<String> = Vec::new(); @@ -4209,6 +4201,7 @@ mod tests { assert!(message.contains("pnpm add -D vite-node")); } + #[cfg(unix)] #[test] fn build_sse_socket_path_is_unique_for_consecutive_calls() { let first = runner_sse::build_sse_socket_path("bt-eval").expect("first socket path"); diff --git a/src/runner_sse.rs b/src/runner_sse.rs index abfbbad6..b0b2b593 100644 --- a/src/runner_sse.rs +++ b/src/runner_sse.rs @@ -1,44 +1,112 @@ -use std::path::PathBuf; use std::pin::Pin; use std::process::ExitStatus; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::time::{SystemTime, UNIX_EPOCH}; +use std::sync::atomic::{AtomicBool, Ordering}; use anyhow::{Context, Result}; use std::future::Future; use tokio::io::{AsyncBufReadExt, BufReader}; -use tokio::net::UnixListener; use tokio::sync::mpsc; +#[cfg(unix)] +use std::path::PathBuf; +#[cfg(unix)] +use std::sync::atomic::AtomicU64; +#[cfg(unix)] +use std::time::{SystemTime, UNIX_EPOCH}; +#[cfg(not(unix))] +use tokio::net::TcpListener; +#[cfg(unix)] +use tokio::net::UnixListener; + +#[cfg(unix)] const SOCKET_BIND_MAX_ATTEMPTS: u8 = 16; +#[cfg(unix)] static SOCKET_COUNTER: AtomicU64 = AtomicU64::new(0); -pub(crate) struct SocketCleanupGuard { +pub(crate) enum SseListener { + #[cfg(unix)] + Unix(UnixListener), + #[cfg(not(unix))] + Tcp(TcpListener), +} + +pub(crate) struct SseListenerGuard { + endpoint: SseEndpoint, + #[cfg(unix)] + _socket_cleanup_guard: SocketCleanupGuard, +} + +enum SseEndpoint { + #[cfg(unix)] + Unix(PathBuf), + #[cfg(not(unix))] + Tcp(std::net::SocketAddr), +} + +#[cfg(unix)] +struct SocketCleanupGuard { path: PathBuf, } +#[cfg(unix)] impl SocketCleanupGuard { fn new(path: PathBuf) -> Self { Self { path } } } +#[cfg(unix)] impl Drop for SocketCleanupGuard { fn drop(&mut self) { let _ = std::fs::remove_file(&self.path); } } -pub(crate) fn bind_sse_listener( - prefix: &str, -) -> Result<(UnixListener, PathBuf, SocketCleanupGuard)> { +impl SseListenerGuard { + pub(crate) fn env<'a>(&self, socket_var: &'a str, addr_var: &'a str) -> (&'a str, String) { + #[cfg(unix)] + let _ = addr_var; + #[cfg(not(unix))] + let _ = socket_var; + match &self.endpoint { + #[cfg(unix)] + SseEndpoint::Unix(path) => (socket_var, path.to_string_lossy().to_string()), + #[cfg(not(unix))] + SseEndpoint::Tcp(addr) => (addr_var, addr.to_string()), + } + } +} + +pub(crate) fn bind_sse_listener(prefix: &str) -> Result<(SseListener, SseListenerGuard)> { + #[cfg(unix)] + { + bind_unix_sse_listener(prefix) + } + + #[cfg(not(unix))] + { + let _ = prefix; + bind_tcp_sse_listener() + } +} + +#[cfg(unix)] +fn bind_unix_sse_listener(prefix: &str) -> Result<(SseListener, SseListenerGuard)> { let mut last_bind_err: Option<std::io::Error> = None; for _ in 0..SOCKET_BIND_MAX_ATTEMPTS { let socket_path = build_sse_socket_path(prefix)?; let socket_cleanup_guard = SocketCleanupGuard::new(socket_path.clone()); let _ = std::fs::remove_file(&socket_path); match UnixListener::bind(&socket_path) { - Ok(listener) => return Ok((listener, socket_path, socket_cleanup_guard)), + Ok(listener) => { + return Ok(( + SseListener::Unix(listener), + SseListenerGuard { + endpoint: SseEndpoint::Unix(socket_path), + _socket_cleanup_guard:
socket_cleanup_guard, + }, + )) + } Err(err) if matches!( err.kind(), @@ -64,6 +132,27 @@ pub(crate) fn bind_sse_listener( )) } +#[cfg(not(unix))] +fn bind_tcp_sse_listener() -> Result<(SseListener, SseListenerGuard)> { + let std_listener = + std::net::TcpListener::bind(("127.0.0.1", 0)).context("failed to bind SSE TCP listener")?; + std_listener + .set_nonblocking(true) + .context("failed to configure SSE TCP listener")?; + let addr = std_listener + .local_addr() + .context("failed to read SSE TCP listener address")?; + let listener = + TcpListener::from_std(std_listener).context("failed to create SSE TCP listener")?; + Ok(( + SseListener::Tcp(listener), + SseListenerGuard { + endpoint: SseEndpoint::Tcp(addr), + }, + )) +} + +#[cfg(unix)] pub(crate) fn build_sse_socket_path(prefix: &str) -> Result<PathBuf> { let pid = std::process::id(); let serial = SOCKET_COUNTER.fetch_add(1, Ordering::Relaxed); @@ -74,6 +163,37 @@ pub(crate) fn build_sse_socket_path(prefix: &str) -> Result<PathBuf> { Ok(std::env::temp_dir().join(format!("{prefix}-{pid}-{now}-{serial}.sock"))) } +pub(crate) async fn accept_and_read_sse_stream<C, F>( + listener: SseListener, + on_connected: C, + on_event: F, +) -> Result<()> +where + C: FnOnce(), + F: FnMut(Option<String>, String), +{ + match listener { + #[cfg(unix)] + SseListener::Unix(listener) => { + let (stream, _) = listener + .accept() + .await + .context("failed to accept SSE unix socket connection")?; + on_connected(); + read_sse_stream(stream, on_event).await + } + #[cfg(not(unix))] + SseListener::Tcp(listener) => { + let (stream, _) = listener + .accept() + .await + .context("failed to accept SSE TCP connection")?; + on_connected(); + read_sse_stream(stream, on_event).await + } + } +} + pub(crate) async fn forward_stream<T>( stream: T, name: &'static str,