Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions sdk/go/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,98 @@ func TestSearchOmitsSearchFiltersWhenUnset(t *testing.T) {
}
}

func TestFindSendsIncludeProvenanceAndParsesProvenance(t *testing.T) {
client, closeServer := testClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/v1/search/find" {
t.Fatalf("path = %s", r.URL.Path)
}
body := readJSONBody(t, r)
if got := body["include_provenance"]; got != true {
t.Fatalf("include_provenance = %#v", got)
}
writeOK(t, w, map[string]any{
"resources": []any{},
"provenance": []map[string]any{
{"query": "auth", "score_threshold": 0.5},
},
})
}))
defer closeServer()

res, err := client.Find(context.Background(), "auth", &FindOptions{IncludeProvenance: true})
if err != nil {
t.Fatal(err)
}
if len(res.QueryResults) != 1 {
t.Fatalf("QueryResults = %#v", res.QueryResults)
}
if got := res.QueryResults[0]["query"]; got != "auth" {
t.Fatalf("QueryResults[0][query] = %#v", got)
}
}

func TestFindOmitsIncludeProvenanceWhenUnset(t *testing.T) {
client, closeServer := testClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/v1/search/find" {
t.Fatalf("path = %s", r.URL.Path)
}
body := readJSONBody(t, r)
requireBodyKeysAbsent(t, body, "include_provenance")
writeOK(t, w, map[string]any{"resources": []any{}})
}))
defer closeServer()

if _, err := client.Find(context.Background(), "auth", &FindOptions{}); err != nil {
t.Fatal(err)
}
}

func TestSearchSendsIncludeProvenanceAndParsesProvenance(t *testing.T) {
client, closeServer := testClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/v1/search/search" {
t.Fatalf("path = %s", r.URL.Path)
}
body := readJSONBody(t, r)
if got := body["include_provenance"]; got != true {
t.Fatalf("include_provenance = %#v", got)
}
writeOK(t, w, map[string]any{
"resources": []any{},
"provenance": []map[string]any{
{"query": "auth", "score_threshold": 0.5},
},
})
}))
defer closeServer()

res, err := client.Search(context.Background(), "auth", &SearchOptions{IncludeProvenance: true})
if err != nil {
t.Fatal(err)
}
if len(res.QueryResults) != 1 {
t.Fatalf("QueryResults = %#v", res.QueryResults)
}
if got := res.QueryResults[0]["query"]; got != "auth" {
t.Fatalf("QueryResults[0][query] = %#v", got)
}
}

func TestSearchOmitsIncludeProvenanceWhenUnset(t *testing.T) {
client, closeServer := testClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/v1/search/search" {
t.Fatalf("path = %s", r.URL.Path)
}
body := readJSONBody(t, r)
requireBodyKeysAbsent(t, body, "include_provenance")
writeOK(t, w, map[string]any{"resources": []any{}})
}))
defer closeServer()

if _, err := client.Search(context.Background(), "auth", &SearchOptions{}); err != nil {
t.Fatal(err)
}
}

func TestErrorEnvelopePreservesCodeDetailsAndStatus(t *testing.T) {
client, closeServer := testClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
writeAPIError(t, w, http.StatusNotFound, "NOT_FOUND", map[string]any{"resource": "viking://resources/missing"})
Expand Down
6 changes: 6 additions & 0 deletions sdk/go/retrieval.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ func (c *Client) Find(ctx context.Context, queryText string, opts *FindOptions)
if len(opts.Level) > 0 {
payload["level"] = opts.Level
}
if opts.IncludeProvenance {
payload["include_provenance"] = true
}
setAny(payload, "telemetry", opts.Telemetry)
var result FindResult
err := c.doJSON(ctx, http.MethodPost, "/api/v1/search/find", nil, payload, &result)
Expand Down Expand Up @@ -66,6 +69,9 @@ func (c *Client) Search(ctx context.Context, queryText string, opts *SearchOptio
if len(opts.Level) > 0 {
payload["level"] = opts.Level
}
if opts.IncludeProvenance {
payload["include_provenance"] = true
}
setAny(payload, "telemetry", opts.Telemetry)
var result FindResult
err := c.doJSON(ctx, http.MethodPost, "/api/v1/search/search", nil, payload, &result)
Expand Down
11 changes: 10 additions & 1 deletion sdk/go/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ type FindOptions struct {
Until string
TimeField string
Level []int
// IncludeProvenance asks the server to return per-query retrieval
// provenance (see FindResult.QueryResults). Defaults to false.
IncludeProvenance bool
}

// SearchOptions controls Search.
Expand All @@ -177,6 +180,9 @@ type SearchOptions struct {
Until string
TimeField string
Level []int
// IncludeProvenance asks the server to return per-query retrieval
// provenance (see FindResult.QueryResults). Defaults to false.
IncludeProvenance bool
}

// GrepOptions controls Grep.
Expand Down Expand Up @@ -257,7 +263,10 @@ type FindResult struct {
Resources []MatchedContext `json:"resources,omitempty"`
Skills []MatchedContext `json:"skills,omitempty"`
QueryPlan *QueryPlan `json:"query_plan,omitempty"`
QueryResults []map[string]any `json:"query_results,omitempty"`
// QueryResults holds the per-query retrieval provenance returned when
// the request sets IncludeProvenance. The server emits this list under
// the "provenance" key.
QueryResults []map[string]any `json:"provenance,omitempty"`
Total int `json:"total,omitempty"`
}

Expand Down
Loading