diff --git a/sdk/go/client_test.go b/sdk/go/client_test.go index 0f82cb1794..412450e36a 100644 --- a/sdk/go/client_test.go +++ b/sdk/go/client_test.go @@ -227,6 +227,98 @@ func TestSearchOmitsSearchFiltersWhenUnset(t *testing.T) { } } +func TestFindSendsIncludeProvenanceAndParsesProvenance(t *testing.T) { + client, closeServer := testClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/v1/search/find" { + t.Fatalf("path = %s", r.URL.Path) + } + body := readJSONBody(t, r) + if got := body["include_provenance"]; got != true { + t.Fatalf("include_provenance = %#v", got) + } + writeOK(t, w, map[string]any{ + "resources": []any{}, + "provenance": []map[string]any{ + {"query": "auth", "score_threshold": 0.5}, + }, + }) + })) + defer closeServer() + + res, err := client.Find(context.Background(), "auth", &FindOptions{IncludeProvenance: true}) + if err != nil { + t.Fatal(err) + } + if len(res.QueryResults) != 1 { + t.Fatalf("QueryResults = %#v", res.QueryResults) + } + if got := res.QueryResults[0]["query"]; got != "auth" { + t.Fatalf("QueryResults[0][query] = %#v", got) + } +} + +func TestFindOmitsIncludeProvenanceWhenUnset(t *testing.T) { + client, closeServer := testClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/v1/search/find" { + t.Fatalf("path = %s", r.URL.Path) + } + body := readJSONBody(t, r) + requireBodyKeysAbsent(t, body, "include_provenance") + writeOK(t, w, map[string]any{"resources": []any{}}) + })) + defer closeServer() + + if _, err := client.Find(context.Background(), "auth", &FindOptions{}); err != nil { + t.Fatal(err) + } +} + +func TestSearchSendsIncludeProvenanceAndParsesProvenance(t *testing.T) { + client, closeServer := testClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/v1/search/search" { + t.Fatalf("path = %s", r.URL.Path) + } + body := readJSONBody(t, r) + if got := body["include_provenance"]; got != true { + t.Fatalf("include_provenance = %#v", got) + } + writeOK(t, w, map[string]any{ + "resources": []any{}, + "provenance": []map[string]any{ + {"query": "auth", "score_threshold": 0.5}, + }, + }) + })) + defer closeServer() + + res, err := client.Search(context.Background(), "auth", &SearchOptions{IncludeProvenance: true}) + if err != nil { + t.Fatal(err) + } + if len(res.QueryResults) != 1 { + t.Fatalf("QueryResults = %#v", res.QueryResults) + } + if got := res.QueryResults[0]["query"]; got != "auth" { + t.Fatalf("QueryResults[0][query] = %#v", got) + } +} + +func TestSearchOmitsIncludeProvenanceWhenUnset(t *testing.T) { + client, closeServer := testClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/v1/search/search" { + t.Fatalf("path = %s", r.URL.Path) + } + body := readJSONBody(t, r) + requireBodyKeysAbsent(t, body, "include_provenance") + writeOK(t, w, map[string]any{"resources": []any{}}) + })) + defer closeServer() + + if _, err := client.Search(context.Background(), "auth", &SearchOptions{}); err != nil { + t.Fatal(err) + } +} + func TestErrorEnvelopePreservesCodeDetailsAndStatus(t *testing.T) { client, closeServer := testClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { writeAPIError(t, w, http.StatusNotFound, "NOT_FOUND", map[string]any{"resource": "viking://resources/missing"}) diff --git a/sdk/go/retrieval.go b/sdk/go/retrieval.go index 72d32e3a02..9e3ac3c362 100644 --- a/sdk/go/retrieval.go +++ b/sdk/go/retrieval.go @@ -32,6 +32,9 @@ func (c *Client) Find(ctx context.Context, queryText string, opts *FindOptions) if len(opts.Level) > 0 { payload["level"] = opts.Level } + if opts.IncludeProvenance { + payload["include_provenance"] = true + } setAny(payload, "telemetry", opts.Telemetry) var result FindResult err := c.doJSON(ctx, http.MethodPost, "/api/v1/search/find", nil, payload, &result) @@ -66,6 +69,9 @@ func (c *Client) Search(ctx context.Context, queryText string, opts *SearchOptio if len(opts.Level) > 0 { payload["level"] = opts.Level } + if opts.IncludeProvenance { + payload["include_provenance"] = true + } setAny(payload, "telemetry", opts.Telemetry) var result FindResult err := c.doJSON(ctx, http.MethodPost, "/api/v1/search/search", nil, payload, &result) diff --git a/sdk/go/types.go b/sdk/go/types.go index 7e18886bcd..64cc3b35f2 100644 --- a/sdk/go/types.go +++ b/sdk/go/types.go @@ -161,6 +161,9 @@ type FindOptions struct { Until string TimeField string Level []int + // IncludeProvenance asks the server to return per-query retrieval + // provenance (see FindResult.QueryResults). Defaults to false. + IncludeProvenance bool } // SearchOptions controls Search. @@ -177,6 +180,9 @@ type SearchOptions struct { Until string TimeField string Level []int + // IncludeProvenance asks the server to return per-query retrieval + // provenance (see FindResult.QueryResults). Defaults to false. + IncludeProvenance bool } // GrepOptions controls Grep. @@ -257,7 +263,10 @@ type FindResult struct { Resources []MatchedContext `json:"resources,omitempty"` Skills []MatchedContext `json:"skills,omitempty"` QueryPlan *QueryPlan `json:"query_plan,omitempty"` - QueryResults []map[string]any `json:"query_results,omitempty"` + // QueryResults holds the per-query retrieval provenance returned when + // the request sets IncludeProvenance. The server emits this list under + // the "provenance" key. + QueryResults []map[string]any `json:"provenance,omitempty"` Total int `json:"total,omitempty"` }