Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion pkg/tools/builtin/sql_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ type CallSQLAgentTool struct {
allowSelectStar bool
execSQLRequiresConfirmation bool
providerHint string
cellRedactor CellRedactor
}

// NewCallSQLAgentTool initializes a tool capable of querying databases. The
Expand Down Expand Up @@ -171,6 +172,30 @@ func (t *CallSQLAgentTool) WithLLMPreviewRows(n int) *CallSQLAgentTool {
return t
}

// CellRedactor transforms a single result-cell value before it is serialized
// into the text the sub-agent LLM reads — a privacy guard for sending SQL
// results to a third-party provider (e.g. mask an email to "j***@***.com").
//
// It is invoked once per cell of the model-facing preview with the column
// name and the raw cell value (which may be nil) and returns the value to
// expose to the model. Rows are maps, so the order in which the columns of a
// row are visited is unspecified; a redactor must not rely on it. The returned
// value is JSON-marshaled along with the rest of the preview, so it should be
// something encoding/json can encode.
//
// It runs ONLY on the model-facing preview: the host-facing rows (OnSQL /
// SQLQueryEvent and tools.Result.Structured) and the local grid keep
// full-fidelity values. Return the value unchanged to leave a cell as-is.
// Returning a mutable value (slice/map) aliases it into the model-facing
// copy — return a fresh value if the host mutates it later.
type CellRedactor func(column string, value any) any

// WithCellRedactor registers fn as the per-cell transform applied to result
// values on their way into the model's context. It exists as a privacy guard
// for deployments that forward SQL results to a third-party LLM.
//
// Redaction is applied to freshly allocated copies of the preview's row maps
// (after any row cap has been applied), never to the originals. The OnSQL /
// SQLQueryEvent hook, the tools.Result.Structured payload, and the host's grid
// therefore keep the real values: the user sees unmasked data locally while
// the model sees only what fn returned. Combine with WithLLMPreviewRows to
// bound how many rows are copied and masked.
//
// Passing nil (the default) turns redaction off and skips the copy entirely.
// The receiver is returned so calls can be chained.
func (t *CallSQLAgentTool) WithCellRedactor(fn CellRedactor) *CallSQLAgentTool {
	t.cellRedactor = fn
	return t
}

// WithQueryTimeout caps the wall-clock time of each underlying QueryContext
// call. d <= 0 disables the timeout (default). Separate from the agent's
// overall request context, which may be much longer.
Expand Down Expand Up @@ -413,6 +438,7 @@ func (t *CallSQLAgentTool) runOnce(ctx context.Context, query string, idx int) s
sessionKey: subSessionKey,
maxRows: t.maxRows,
llmPreviewRows: t.llmPreviewRows,
cellRedactor: t.cellRedactor,
queryTimeout: t.queryTimeout,
allowMutations: t.allowMutations,
allowDDL: t.allowDDL,
Expand Down Expand Up @@ -632,6 +658,7 @@ type executeSQLTool struct {
onSQL func(context.Context, SQLQueryEvent)
maxRows int
llmPreviewRows int
cellRedactor CellRedactor
queryTimeout time.Duration
allowMutations bool
allowDDL bool
Expand Down Expand Up @@ -735,7 +762,10 @@ func (t *executeSQLTool) makeEmitFunc(ctx context.Context) func(SQLResult) (tool
Truncated: res.Truncated,
})
}
b, err := json.Marshal(previewForLLM(res, t.llmPreviewRows))
// Mask the model-facing copy only; res (OnSQL, Structured, the host
// grid) keeps full-fidelity values. redactRows deep-copies, so the
// shared row maps under res are never mutated.
b, err := json.Marshal(redactRows(previewForLLM(res, t.llmPreviewRows), t.cellRedactor))
if err != nil {
return tools.Result{}, fmt.Errorf("tools: marshal result: %w", err)
}
Expand All @@ -758,6 +788,30 @@ func previewForLLM(res SQLResult, n int) SQLResult {
return preview
}

// redactRows produces the masked, model-facing variant of res: every cell of
// every row is replaced by fn(column, value).
//
// The input is never modified. Each row gets a newly allocated map, so the
// maps in res.Rows — which are shared with the OnSQL hook,
// tools.Result.Structured, and the host grid — retain their full-fidelity
// values. Only the row maps are re-created; whatever fn returns is stored
// as-is, without further copying.
//
// When fn is nil or res has no rows, res is handed back untouched and nothing
// is allocated. Because callers pass in the already row-capped preview, the
// copying cost is proportional to the preview size rather than the full
// result.
func redactRows(res SQLResult, fn CellRedactor) SQLResult {
	if len(res.Rows) == 0 || fn == nil {
		return res
	}

	redacted := make([]map[string]any, 0, len(res.Rows))
	for _, src := range res.Rows {
		dst := make(map[string]any, len(src))
		for name, cell := range src {
			dst[name] = fn(name, cell)
		}
		redacted = append(redacted, dst)
	}

	// res is a by-value copy, so swapping its Rows leaves the caller's
	// SQLResult pointing at the original maps.
	res.Rows = redacted
	return res
}

// executeRead runs the read-only path: optional LIMIT injection, then
// QueryContext + row iteration into a structured SQLResult.
func (t *executeSQLTool) executeRead(ctx context.Context, sqlStr string, emit func(SQLResult) (tools.Result, error)) (tools.Result, error) {
Expand Down
124 changes: 124 additions & 0 deletions pkg/tools/builtin/sql_agent_redact_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package builtin

import (
"context"
"encoding/json"
"testing"
)

// maskEmail is the redactor used by the tests in this file: any cell in the
// "email" column is replaced by the fixed string "***"; every other cell is
// returned unchanged.
func maskEmail(col string, val any) any {
	switch col {
	case "email":
		return "***"
	default:
		return val
	}
}

// TestEmit_CellRedactor_MasksTextNotHookOrStructured asserts WithCellRedactor
// masks ONLY the LLM-visible result text; the OnSQL hook (host grid), the
// tools.Result.Structured payload, and the caller's source row maps all keep
// full-fidelity values — i.e. the redaction deep-copies and never mutates the
// shared row maps.
func TestEmit_CellRedactor_MasksTextNotHookOrStructured(t *testing.T) {
var hookEv SQLQueryEvent
exec := &executeSQLTool{
sessionKey: "s",
onSQL: func(_ context.Context, ev SQLQueryEvent) { hookEv = ev },
cellRedactor: maskEmail,
}
// "phone": nil exercises fn(col, nil) and the marshal of a redacted nil.
row := map[string]any{"id": 1, "email": "alice@example.com", "phone": nil}
res := SQLResult{SQL: "SELECT id, email, phone FROM users", Columns: []string{"id", "email", "phone"}, Rows: []map[string]any{row}, RowCount: 1}

out, err := exec.makeEmitFunc(context.Background())(res)
if err != nil {
t.Fatalf("emit: %v", err)
}

// Model-facing Text is masked.
var seen SQLResult
if err := json.Unmarshal([]byte(out.Text), &seen); err != nil {
t.Fatalf("unmarshal text: %v", err)
}
if got := seen.Rows[0]["email"]; got != "***" {
t.Fatalf("model-facing email should be masked, got %v", got)
}
if _, ok := seen.Rows[0]["id"]; !ok {
t.Fatalf("non-sensitive column dropped from preview: %v", seen.Rows[0])
}
// A nil cell flows through the redactor and marshals without panicking.
if got, ok := seen.Rows[0]["phone"]; !ok || got != nil {
t.Fatalf("nil cell should pass through redactor as null, got %v (present=%v)", got, ok)
}

// Host hook keeps the real value (aliasing guard).
if got := hookEv.Rows[0]["email"]; got != "alice@example.com" {
t.Fatalf("OnSQL hook email must stay unmasked, got %v", got)
}
// Structured (host via OnToolResult) keeps the real value.
if sr, ok := out.Structured.(SQLResult); !ok || sr.Rows[0]["email"] != "alice@example.com" {
t.Fatalf("Structured email must stay unmasked, got %#v", out.Structured)
}
// The caller's source row map must not be mutated in place.
if got := row["email"]; got != "alice@example.com" {
t.Fatalf("source row map mutated by redactor: %v", got)
}
}

// TestEmit_CellRedactor_ComposesWithPreview checks that redaction is applied
// to the row-capped preview (llmPreviewRows = 1 over a two-row result) and
// that the caller's source maps are left untouched on that truncating path,
// where the preview shares the underlying row maps with the full result.
func TestEmit_CellRedactor_ComposesWithPreview(t *testing.T) {
	first := map[string]any{"id": 0, "email": "a@x.com"}
	second := map[string]any{"id": 1, "email": "b@x.com"}
	input := SQLResult{
		Columns:  []string{"id", "email"},
		Rows:     []map[string]any{first, second},
		RowCount: 2,
	}
	tool := &executeSQLTool{
		sessionKey:     "s",
		llmPreviewRows: 1,
		cellRedactor:   maskEmail,
	}

	emit := tool.makeEmitFunc(context.Background())
	result, err := emit(input)
	if err != nil {
		t.Fatalf("emit: %v", err)
	}

	var llmView SQLResult
	if err := json.Unmarshal([]byte(result.Text), &llmView); err != nil {
		t.Fatalf("unmarshal text: %v", err)
	}
	if n := len(llmView.Rows); n != 1 {
		t.Fatalf("preview should cap to 1 row, got %d", n)
	}
	if got := llmView.Rows[0]["email"]; got != "***" {
		t.Fatalf("preview email not masked: %v", got)
	}
	if !llmView.Truncated {
		t.Fatal("preview should be flagged Truncated")
	}

	// The source maps (shared with the host grid / Structured) keep their
	// original values.
	if first["email"] != "a@x.com" || second["email"] != "b@x.com" {
		t.Fatalf("source rows mutated: r0=%v r1=%v", first["email"], second["email"])
	}
}

// TestEmit_CellRedactor_NilIsNoOp asserts the default (no redactor) passes
// values through unchanged, byte-identical to pre-feature behavior.
func TestEmit_CellRedactor_NilIsNoOp(t *testing.T) {
exec := &executeSQLTool{sessionKey: "s"} // cellRedactor nil
res := SQLResult{Columns: []string{"email"}, Rows: []map[string]any{{"email": "real@x.com"}}, RowCount: 1}

out, err := exec.makeEmitFunc(context.Background())(res)
if err != nil {
t.Fatalf("emit: %v", err)
}
var seen SQLResult
if err := json.Unmarshal([]byte(out.Text), &seen); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if seen.Rows[0]["email"] != "real@x.com" {
t.Fatalf("nil redactor must be a no-op, got %v", seen.Rows[0]["email"])
}
}

// TestWithCellRedactor_SetsField checks that the WithCellRedactor builder
// stores the supplied redactor on the tool it returns.
func TestWithCellRedactor_SetsField(t *testing.T) {
	base := NewCallSQLAgentTool(nil, "", nil, nil)
	configured := base.WithCellRedactor(maskEmail)

	if configured.cellRedactor == nil {
		t.Fatal("WithCellRedactor did not set the redactor")
	}
}
Loading