diff --git a/pkg/tools/builtin/sql_agent.go b/pkg/tools/builtin/sql_agent.go index f9c6821..a35af6d 100644 --- a/pkg/tools/builtin/sql_agent.go +++ b/pkg/tools/builtin/sql_agent.go @@ -89,6 +89,7 @@ type CallSQLAgentTool struct { allowSelectStar bool execSQLRequiresConfirmation bool providerHint string + cellRedactor CellRedactor } // NewCallSQLAgentTool initializes a tool capable of querying databases. The @@ -171,6 +172,30 @@ func (t *CallSQLAgentTool) WithLLMPreviewRows(n int) *CallSQLAgentTool { return t } +// CellRedactor transforms a single result-cell value before it is serialized +// into the text the sub-agent LLM reads — a privacy guard for sending SQL +// results to a third-party provider (e.g. mask an email to "j***@***.com"). It +// receives the column name and the raw cell value and returns the value to +// expose to the model. It runs ONLY on the model-facing preview: the +// host-facing rows (OnSQL / SQLQueryEvent and tools.Result.Structured) and the +// local grid keep full-fidelity values. Return the value unchanged to leave a +// cell as-is. Returning a mutable value (slice/map) aliases it into the +// model-facing copy — return a fresh value if the host mutates it later. +type CellRedactor func(column string, value any) any + +// WithCellRedactor installs a per-cell transform applied to result values +// before they are serialized into the model's context — a privacy guard for +// sending SQL results to a third-party LLM. The redactor runs on a deep copy of +// the (already row-capped) preview only: the OnSQL / SQLQueryEvent hook, the +// tools.Result.Structured payload, and the host's grid all keep the real +// values, so the user still sees unmasked data locally while the model sees +// masked values. Pairs naturally with WithLLMPreviewRows, which bounds how many +// rows are copied and masked. nil (default) disables redaction with no copying. 
+func (t *CallSQLAgentTool) WithCellRedactor(fn CellRedactor) *CallSQLAgentTool { + t.cellRedactor = fn + return t +} + // WithQueryTimeout caps the wall-clock time of each underlying QueryContext // call. d <= 0 disables the timeout (default). Separate from the agent's // overall request context, which may be much longer. @@ -413,6 +438,7 @@ func (t *CallSQLAgentTool) runOnce(ctx context.Context, query string, idx int) s sessionKey: subSessionKey, maxRows: t.maxRows, llmPreviewRows: t.llmPreviewRows, + cellRedactor: t.cellRedactor, queryTimeout: t.queryTimeout, allowMutations: t.allowMutations, allowDDL: t.allowDDL, @@ -632,6 +658,7 @@ type executeSQLTool struct { onSQL func(context.Context, SQLQueryEvent) maxRows int llmPreviewRows int + cellRedactor CellRedactor queryTimeout time.Duration allowMutations bool allowDDL bool @@ -735,7 +762,10 @@ func (t *executeSQLTool) makeEmitFunc(ctx context.Context) func(SQLResult) (tool Truncated: res.Truncated, }) } - b, err := json.Marshal(previewForLLM(res, t.llmPreviewRows)) + // Mask the model-facing copy only; res (OnSQL, Structured, the host + // grid) keeps full-fidelity values. redactRows deep-copies, so the + // shared row maps under res are never mutated. + b, err := json.Marshal(redactRows(previewForLLM(res, t.llmPreviewRows), t.cellRedactor)) if err != nil { return tools.Result{}, fmt.Errorf("tools: marshal result: %w", err) } @@ -758,6 +788,30 @@ func previewForLLM(res SQLResult, n int) SQLResult { return preview } +// redactRows returns a copy of res whose every cell value is passed through fn, +// for masking sensitive values before they reach the LLM. It deep-copies the +// row maps, so res — shared with the OnSQL hook, tools.Result.Structured, and +// the host grid — keeps its full-fidelity values (the maps in res.Rows are +// never mutated). fn == nil or no rows is a no-op returning res unchanged, so +// the default path allocates nothing. 
Only the rows already selected for the +// preview are copied, so the cost scales with the LLM preview size. +func redactRows(res SQLResult, fn CellRedactor) SQLResult { + if fn == nil || len(res.Rows) == 0 { + return res + } + rows := make([]map[string]any, len(res.Rows)) + for i, row := range res.Rows { + masked := make(map[string]any, len(row)) + for col, val := range row { + masked[col] = fn(col, val) + } + rows[i] = masked + } + out := res + out.Rows = rows + return out +} + // executeRead runs the read-only path: optional LIMIT injection, then // QueryContext + row iteration into a structured SQLResult. func (t *executeSQLTool) executeRead(ctx context.Context, sqlStr string, emit func(SQLResult) (tools.Result, error)) (tools.Result, error) { diff --git a/pkg/tools/builtin/sql_agent_redact_test.go b/pkg/tools/builtin/sql_agent_redact_test.go new file mode 100644 index 0000000..7d3a429 --- /dev/null +++ b/pkg/tools/builtin/sql_agent_redact_test.go @@ -0,0 +1,124 @@ +package builtin + +import ( + "context" + "encoding/json" + "testing" +) + +func maskEmail(col string, val any) any { + if col == "email" { + return "***" + } + return val +} + +// TestEmit_CellRedactor_MasksTextNotHookOrStructured asserts WithCellRedactor +// masks ONLY the LLM-visible result text; the OnSQL hook (host grid), the +// tools.Result.Structured payload, and the caller's source row maps all keep +// full-fidelity values — i.e. the redaction deep-copies and never mutates the +// shared row maps. +func TestEmit_CellRedactor_MasksTextNotHookOrStructured(t *testing.T) { + var hookEv SQLQueryEvent + exec := &executeSQLTool{ + sessionKey: "s", + onSQL: func(_ context.Context, ev SQLQueryEvent) { hookEv = ev }, + cellRedactor: maskEmail, + } + // "phone": nil exercises fn(col, nil) and the marshal of a redacted nil. 
+ row := map[string]any{"id": 1, "email": "alice@example.com", "phone": nil} + res := SQLResult{SQL: "SELECT id, email, phone FROM users", Columns: []string{"id", "email", "phone"}, Rows: []map[string]any{row}, RowCount: 1} + + out, err := exec.makeEmitFunc(context.Background())(res) + if err != nil { + t.Fatalf("emit: %v", err) + } + + // Model-facing Text is masked. + var seen SQLResult + if err := json.Unmarshal([]byte(out.Text), &seen); err != nil { + t.Fatalf("unmarshal text: %v", err) + } + if got := seen.Rows[0]["email"]; got != "***" { + t.Fatalf("model-facing email should be masked, got %v", got) + } + if _, ok := seen.Rows[0]["id"]; !ok { + t.Fatalf("non-sensitive column dropped from preview: %v", seen.Rows[0]) + } + // A nil cell flows through the redactor and marshals without panicking. + if got, ok := seen.Rows[0]["phone"]; !ok || got != nil { + t.Fatalf("nil cell should pass through redactor as null, got %v (present=%v)", got, ok) + } + + // Host hook keeps the real value (aliasing guard). + if got := hookEv.Rows[0]["email"]; got != "alice@example.com" { + t.Fatalf("OnSQL hook email must stay unmasked, got %v", got) + } + // Structured (host via OnToolResult) keeps the real value. + if sr, ok := out.Structured.(SQLResult); !ok || sr.Rows[0]["email"] != "alice@example.com" { + t.Fatalf("Structured email must stay unmasked, got %#v", out.Structured) + } + // The caller's source row map must not be mutated in place. + if got := row["email"]; got != "alice@example.com" { + t.Fatalf("source row map mutated by redactor: %v", got) + } +} + +// TestEmit_CellRedactor_ComposesWithPreview asserts redaction runs on the +// row-capped preview and still leaves the shared source maps untouched even on +// the truncating branch (where the preview shares the underlying maps). 
+func TestEmit_CellRedactor_ComposesWithPreview(t *testing.T) {
+	r0 := map[string]any{"id": 0, "email": "a@x.com"}
+	r1 := map[string]any{"id": 1, "email": "b@x.com"}
+	res := SQLResult{Columns: []string{"id", "email"}, Rows: []map[string]any{r0, r1}, RowCount: 2}
+	exec := &executeSQLTool{sessionKey: "s", llmPreviewRows: 1, cellRedactor: maskEmail} // preview cap (1) is below the 2 source rows
+
+	out, err := exec.makeEmitFunc(context.Background())(res)
+	if err != nil {
+		t.Fatalf("emit: %v", err)
+	}
+	var seen SQLResult // the model-facing text, decoded back for inspection
+	if err := json.Unmarshal([]byte(out.Text), &seen); err != nil {
+		t.Fatalf("unmarshal text: %v", err)
+	}
+	if len(seen.Rows) != 1 {
+		t.Fatalf("preview should cap to 1 row, got %d", len(seen.Rows))
+	}
+	if seen.Rows[0]["email"] != "***" {
+		t.Fatalf("preview email not masked: %v", seen.Rows[0]["email"])
+	}
+	if !seen.Truncated {
+		t.Fatal("preview should be flagged Truncated")
+	}
+	// Source maps (shared with the host grid / Structured) untouched.
+	if r0["email"] != "a@x.com" || r1["email"] != "b@x.com" {
+		t.Fatalf("source rows mutated: r0=%v r1=%v", r0["email"], r1["email"])
+	}
+}
+
+// TestEmit_CellRedactor_NilIsNoOp asserts that with no redactor installed (the
+// default) cell values reach the model-facing text unchanged.
+func TestEmit_CellRedactor_NilIsNoOp(t *testing.T) {
+	exec := &executeSQLTool{sessionKey: "s"} // cellRedactor nil
+	res := SQLResult{Columns: []string{"email"}, Rows: []map[string]any{{"email": "real@x.com"}}, RowCount: 1}
+
+	out, err := exec.makeEmitFunc(context.Background())(res)
+	if err != nil {
+		t.Fatalf("emit: %v", err)
+	}
+	var seen SQLResult // the model-facing text, decoded back for inspection
+	if err := json.Unmarshal([]byte(out.Text), &seen); err != nil {
+		t.Fatalf("unmarshal: %v", err)
+	}
+	if len(seen.Rows) != 1 || seen.Rows[0]["email"] != "real@x.com" { // length first: an empty preview fails here instead of panicking
+		t.Fatalf("nil redactor must be a no-op, got rows %v", seen.Rows)
+	}
+}
+
+// TestWithCellRedactor_SetsField asserts the builder wires the redactor through.
+func TestWithCellRedactor_SetsField(t *testing.T) {
+	// Only the builder wiring is under test, so the constructor gets nil/zero arguments.
+	if built := NewCallSQLAgentTool(nil, "", nil, nil).WithCellRedactor(maskEmail); built.cellRedactor == nil {
+		t.Fatal("WithCellRedactor did not set the redactor")
+	}
+}