diff --git a/.github/workflows/ci-go-unit-tests.yaml b/.github/workflows/ci-go-unit-tests.yaml index 31878ab..3623538 100644 --- a/.github/workflows/ci-go-unit-tests.yaml +++ b/.github/workflows/ci-go-unit-tests.yaml @@ -39,5 +39,22 @@ jobs: go-version: 1.25.0 cache-dependency-path: go.sum + - name: Set up Dgraph + if: matrix.os == 'linux' + run: | + docker run -d --name dgraph-standalone -p 9080:9080 -p 8080:8080 dgraph/standalone:latest + echo "Waiting for Dgraph to be ready..." + for i in {1..30}; do + if curl -s http://localhost:8080/health > /dev/null; then + echo "Dgraph is ready!" + break + fi + echo "Attempt $i: Dgraph not ready, waiting..." + sleep 2 + done + sleep 5 + - name: Run Unit Tests + env: + MODUSGRAPH_TEST_ADDR: ${{ matrix.os == 'linux' && 'localhost:9080' || '' }} run: go test -short -race -v . diff --git a/client.go b/client.go index ec6f6f0..e37f84b 100644 --- a/client.go +++ b/client.go @@ -113,14 +113,16 @@ type StructValidator interface { // namespace: the namespace for the client. // logger: the logger for the client. // validator: the validator instance for struct validation. +// embeddingProvider: optional provider for automatic SimString vector embeddings. type clientOptions struct { - autoSchema bool - poolSize int - maxEdgeTraversal int - cacheSizeMB int - namespace string - logger logr.Logger - validator StructValidator + autoSchema bool + poolSize int + maxEdgeTraversal int + cacheSizeMB int + namespace string + logger logr.Logger + validator StructValidator + embeddingProvider EmbeddingProvider } // ClientOpt is a function that configures a client @@ -182,6 +184,17 @@ func WithValidator(v StructValidator) ClientOpt { } } +// WithEmbeddingProvider sets the EmbeddingProvider used to automatically generate +// and maintain shadow float32vector predicates for SimString fields tagged with +// `dgraph:"embedding"`. 
When set, Insert, Upsert, and Update operations will +// call the provider to embed any SimString values and persist the resulting +// vectors alongside the primary string predicates. +func WithEmbeddingProvider(p EmbeddingProvider) ClientOpt { + return func(o *clientOptions) { + o.embeddingProvider = p + } +} + // NewValidator creates a new validator instance with default settings. // This is a convenience function for creating a validator to use with WithValidator. // It returns a *validator.Validate from github.com/go-playground/validator/v10. @@ -308,8 +321,18 @@ func (c client) key() string { if c.options.validator != nil { validatorKey = fmt.Sprintf("%p", c.options.validator) } - return fmt.Sprintf("%s:%t:%d:%d:%d:%s:%s", c.uri, c.options.autoSchema, c.options.poolSize, - c.options.maxEdgeTraversal, c.options.cacheSizeMB, c.options.namespace, validatorKey) + embeddingKey := "nil" + if c.options.embeddingProvider != nil { + embeddingKey = fmt.Sprintf("%p", c.options.embeddingProvider) + } + return fmt.Sprintf("%s:%t:%d:%d:%d:%s:%s:%s", c.uri, c.options.autoSchema, c.options.poolSize, + c.options.maxEdgeTraversal, c.options.cacheSizeMB, c.options.namespace, validatorKey, embeddingKey) +} + +// embeddingProvider implements the embeddingClient interface, exposing the +// configured EmbeddingProvider to package-level helpers like SimilarToText. +func (c client) embeddingProvider() EmbeddingProvider { + return c.options.embeddingProvider } func checkPointer(obj any) error { @@ -458,16 +481,33 @@ func (c client) Query(ctx context.Context, model any) *dg.Query { // UpdateSchema implements updating the Dgraph schema. Pass one or more // objects that will be used to generate the schema. +// If any object contains SimString fields tagged `dgraph:"embedding"`, the +// corresponding shadow float32vector predicates (__vec) are also registered. 
func (c client) UpdateSchema(ctx context.Context, obj ...any) error { - client, err := c.pool.get() + dgClient, err := c.pool.get() if err != nil { c.logger.Error(err, "Failed to get client from pool") return err } - defer c.pool.put(client) + defer c.pool.put(dgClient) + + if _, err = dg.CreateSchema(dgClient, obj...); err != nil { + return err + } + + // Collect shadow vector schema lines for SimString fields across all objects. + var vecSchema strings.Builder + for _, o := range obj { + for _, info := range collectSimFields(o) { + vecSchema.WriteString(buildVecSchemaStatement(info)) + vecSchema.WriteString("\n") + } + } + if vecSchema.Len() == 0 { + return nil + } - _, err = dg.CreateSchema(client, obj...) - return err + return dgClient.Alter(ctx, &api.Operation{Schema: vecSchema.String()}) } // GetSchema implements retrieving the Dgraph schema. diff --git a/embedding.go b/embedding.go new file mode 100644 index 0000000..e579de8 --- /dev/null +++ b/embedding.go @@ -0,0 +1,494 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph + +import ( + "bytes" + "context" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "net/http" + "reflect" + "strconv" + "strings" + + "github.com/dgraph-io/dgo/v250/protos/api" + dg "github.com/dolan-in/dgman/v2" +) + +// SimString is a string type that participates in automatic vector similarity search. +// When a struct field of this type is tagged with `dgraph:"embedding"`, modusGraph +// will automatically generate and maintain a shadow float32vector predicate +// (__vec) backed by the configured EmbeddingProvider. +// +// Example: +// +// type Product struct { +// Description SimString `json:"description,omitempty" dgraph:"embedding,index=term"` +// UID string `json:"uid,omitempty"` +// DType []string `json:"dgraph.type,omitempty"` +// } +type SimString string + +// MarshalJSON implements json.Marshaler. 
SimString serializes as a plain JSON string. +func (s SimString) MarshalJSON() ([]byte, error) { + return json.Marshal(string(s)) +} + +// UnmarshalJSON implements json.Unmarshaler. SimString deserializes from a plain JSON string. +func (s *SimString) UnmarshalJSON(data []byte) error { + var str string + if err := json.Unmarshal(data, &str); err != nil { + return err + } + *s = SimString(str) + return nil +} + +// SchemaType implements the dgman SchemaType interface so that dgman emits +// "string" as the Dgraph predicate type for SimString fields. +func (s SimString) SchemaType() string { + return "string" +} + +// EmbeddingProvider is the interface for generating float32 vector embeddings from text. +// Implement this interface to integrate any embedding service (OpenAI, Ollama, local models, etc.). +type EmbeddingProvider interface { + // Embed returns a float32 embedding vector for the given text. + Embed(ctx context.Context, text string) ([]float32, error) + + // Dims returns the fixed number of dimensions produced by this provider. + Dims() int +} + +// OpenAICompatibleConfig configures an OpenAI-compatible embedding provider. +// This works with both OpenAI and Ollama (which exposes /v1/embeddings since v0.13.3). +type OpenAICompatibleConfig struct { + // BaseURL is the base URL of the embedding service. + // For OpenAI: "https://api.openai.com" + // For Ollama: "http://localhost:11434" + BaseURL string + + // Model is the embedding model to use. + // For OpenAI: e.g. "text-embedding-3-small" + // For Ollama: e.g. "nomic-embed-text" + Model string + + // APIKey is the API key for authentication. Leave empty for Ollama. + APIKey string + + // Dims is the expected number of dimensions for embeddings from this model. + // This must match the model's actual output dimension. + Dims int +} + +// OpenAICompatibleProvider is an EmbeddingProvider that calls any OpenAI-compatible +// /v1/embeddings endpoint. Works with OpenAI, Ollama, and other compatible services. 
+type OpenAICompatibleProvider struct { + config OpenAICompatibleConfig + httpClient *http.Client +} + +// NewOpenAICompatibleProvider creates a new OpenAICompatibleProvider with the given config. +func NewOpenAICompatibleProvider(config OpenAICompatibleConfig) *OpenAICompatibleProvider { + return &OpenAICompatibleProvider{ + config: config, + httpClient: &http.Client{}, + } +} + +type embeddingRequest struct { + Input string `json:"input"` + Model string `json:"model"` +} + +type embeddingResponse struct { + Data []struct { + Embedding []float32 `json:"embedding"` + } `json:"data"` +} + +// Embed implements EmbeddingProvider. +func (p *OpenAICompatibleProvider) Embed(ctx context.Context, text string) ([]float32, error) { + reqBody := embeddingRequest{ + Input: text, + Model: p.config.Model, + } + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("embedding: marshal request: %w", err) + } + + url := strings.TrimRight(p.config.BaseURL, "/") + "/v1/embeddings" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("embedding: create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + if p.config.APIKey != "" { + req.Header.Set("Authorization", "Bearer "+p.config.APIKey) + } + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("embedding: http request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("embedding: server returned %d: %s", resp.StatusCode, string(body)) + } + + var embResp embeddingResponse + if err := json.NewDecoder(resp.Body).Decode(&embResp); err != nil { + return nil, fmt.Errorf("embedding: decode response: %w", err) + } + if len(embResp.Data) == 0 { + return nil, fmt.Errorf("embedding: empty response data") + } + return embResp.Data[0].Embedding, nil +} + +// Dims implements 
EmbeddingProvider. +func (p *OpenAICompatibleProvider) Dims() int { + return p.config.Dims +} + +// simFieldInfo holds metadata about a detected SimString field. +type simFieldInfo struct { + jsonPredicate string // e.g. "description" + vecPredicate string // e.g. "description__vec" + metric string // default "cosine" + exponent string // default "4" + threshold int // min rune count to embed; 0 = always embed +} + +// vecShadowPredicate returns the shadow vector predicate name for a given field predicate. +func vecShadowPredicate(predicate string) string { + return predicate + "__vec" +} + +// parseEmbeddingTag parses embedding-specific options from a dgraph struct tag. +// It extracts metric, exponent, and threshold values, returning defaults if not set. +func parseEmbeddingTag(tag string) (metric, exponent string, threshold int) { + metric = "cosine" + exponent = "4" + for _, part := range strings.Split(tag, ",") { + part = strings.TrimSpace(part) + if strings.HasPrefix(part, "metric=") { + metric = strings.TrimPrefix(part, "metric=") + } else if strings.HasPrefix(part, "exponent=") { + exponent = strings.TrimPrefix(part, "exponent=") + } else if strings.HasPrefix(part, "threshold=") { + if n, err := strconv.Atoi(strings.TrimPrefix(part, "threshold=")); err == nil && n >= 0 { + threshold = n + } + } + } + return metric, exponent, threshold +} + +// hasEmbeddingTag reports whether a dgraph struct tag contains the "embedding" directive. +func hasEmbeddingTag(tag string) bool { + for _, part := range strings.Split(tag, ",") { + if strings.TrimSpace(part) == "embedding" { + return true + } + } + return false +} + +// hasSimStringFields reports whether obj contains any SimString field tagged dgraph:"embedding". +// Used as a fast check before allocating a two-phase transaction. 
+func hasSimStringFields(obj any) bool { + val := reflect.ValueOf(obj) + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + + simStringType := reflect.TypeOf(SimString("")) + + checkStruct := func(sv reflect.Value) bool { + if sv.Kind() != reflect.Struct { + return false + } + st := sv.Type() + for i := 0; i < st.NumField(); i++ { + field := st.Field(i) + if field.Type == simStringType && hasEmbeddingTag(field.Tag.Get("dgraph")) { + return true + } + } + return false + } + + switch val.Kind() { + case reflect.Slice: + for i := 0; i < val.Len(); i++ { + elem := val.Index(i) + if elem.Kind() == reflect.Ptr { + elem = elem.Elem() + } + if checkStruct(elem) { + return true + } + } + case reflect.Struct: + return checkStruct(val) + } + return false +} + +// collectSimFieldInfoFromType inspects a struct type (not values) and returns +// metadata about all SimString fields tagged dgraph:"embedding". +// Used by UpdateSchema to emit shadow vector predicates without needing actual data. +func collectSimFieldInfoFromType(t reflect.Type) []simFieldInfo { + for t.Kind() == reflect.Ptr || t.Kind() == reflect.Slice { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return nil + } + + simStringType := reflect.TypeOf(SimString("")) + var results []simFieldInfo + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if field.Type != simStringType { + continue + } + dgraphTag := field.Tag.Get("dgraph") + if !hasEmbeddingTag(dgraphTag) { + continue + } + jsonTag := field.Tag.Get("json") + predicate := strings.Split(jsonTag, ",")[0] + if predicate == "" { + predicate = field.Name + } + metric, exponent, threshold := parseEmbeddingTag(dgraphTag) + results = append(results, simFieldInfo{ + jsonPredicate: predicate, + vecPredicate: vecShadowPredicate(predicate), + metric: metric, + exponent: exponent, + threshold: threshold, + }) + } + return results +} + +// collectSimFields inspects obj (pointer to struct, or slice of pointer to struct) +// and returns metadata about 
all SimString fields tagged dgraph:"embedding", +// including the current text value of each field. +func collectSimFields(obj any) []simFieldInfo { + return collectSimFieldInfoFromType(reflect.TypeOf(obj)) +} + +// buildVecSchemaStatement produces a Dgraph schema line for a shadow vector predicate. +func buildVecSchemaStatement(info simFieldInfo) string { + return fmt.Sprintf(`%s: float32vector @index(hnsw(exponent: "%s", metric: "%s")) .`, + info.vecPredicate, info.exponent, info.metric) +} + +// vectorToQueryString converts a []float32 to the string format used in Dgraph +// similar_to query variables: "[v1, v2, ...]" +func vectorToQueryString(vec []float32) string { + parts := make([]string, len(vec)) + for i, v := range vec { + parts[i] = fmt.Sprintf("%v", v) + } + return "[" + strings.Join(parts, ", ") + "]" +} + +// vectorToBytes converts a []float32 to little-endian binary bytes suitable +// for the api.Value_Vfloat32Val field in Dgraph NQuad mutations. +func vectorToBytes(vec []float32) []byte { + buf := new(bytes.Buffer) + for _, v := range vec { + // binary.Write with float32 is not directly supported; use uint32 bit representation + _ = binary.Write(buf, binary.LittleEndian, v) + } + return buf.Bytes() +} + +// injectShadowVectors calls the embedding provider for any SimString fields in obj, +// then writes the resulting vectors as NQuad mutations against the already-assigned UIDs. +// It uses the raw dgo.Txn to issue a Set mutation without CommitNow so the +// caller can commit the whole transaction atomically. 
+func injectShadowVectors(ctx context.Context, + provider EmbeddingProvider, + tx *dg.TxnContext, + obj any, + uids []string) error { + val := reflect.ValueOf(obj) + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + + var structs []reflect.Value + switch val.Kind() { + case reflect.Slice: + for i := 0; i < val.Len(); i++ { + elem := val.Index(i) + if elem.Kind() == reflect.Ptr { + elem = elem.Elem() + } + if elem.Kind() == reflect.Struct { + structs = append(structs, elem) + } + } + case reflect.Struct: + structs = append(structs, val) + } + + if len(structs) == 0 { + return nil + } + + simStringType := reflect.TypeOf(SimString("")) + var setNquads []*api.NQuad + var delNquads []*api.NQuad + + for _, sv := range structs { + // Always read UID from the struct field — dgman's setUIDHook writes it back + // after mutation, and the returned uids slice has non-deterministic order + // (it comes from a map[string]string iteration). + uidField := sv.FieldByName("UID") + uid := "" + if uidField.IsValid() && uidField.Kind() == reflect.String { + uid = uidField.String() + } + if uid == "" { + continue + } + + st := sv.Type() + for i := 0; i < st.NumField(); i++ { + field := st.Field(i) + if field.Type != simStringType { + continue + } + dgraphTag := field.Tag.Get("dgraph") + if !hasEmbeddingTag(dgraphTag) { + continue + } + + jsonTag := field.Tag.Get("json") + predicate := strings.Split(jsonTag, ",")[0] + if predicate == "" { + predicate = field.Name + } + vecPred := vecShadowPredicate(predicate) + + textVal := string(sv.Field(i).Interface().(SimString)) + _, _, threshold := parseEmbeddingTag(dgraphTag) + + // Below threshold (or empty): delete the shadow vector to avoid stale data. 
+ if textVal == "" || (threshold > 0 && len([]rune(textVal)) < threshold) { + delNquads = append(delNquads, &api.NQuad{ + Subject: uid, + Predicate: vecPred, + ObjectValue: &api.Value{Val: &api.Value_DefaultVal{DefaultVal: "_STAR_ALL"}}, + }) + continue + } + + vec, err := provider.Embed(ctx, textVal) + if err != nil { + return fmt.Errorf("embedding field %q: %w", predicate, err) + } + + setNquads = append(setNquads, &api.NQuad{ + Subject: uid, + Predicate: vecPred, + ObjectValue: &api.Value{ + Val: &api.Value_Vfloat32Val{ + Vfloat32Val: vectorToBytes(vec), + }, + }, + }) + } + } + + if len(setNquads) == 0 && len(delNquads) == 0 { + return nil + } + + _, err := tx.Txn().Mutate(ctx, &api.Mutation{ + Set: setNquads, + Del: delNquads, + CommitNow: false, + }) + return err +} + +// SimilarTo returns a QueryBlock ready to Scan() for the k nearest neighbours +// of the given pre-computed vector. It uses $vec as a query variable so Dgraph +// can parse the query with the standard variable substitution path. +// +// The QueryBlock already has the vector variable bound; call Scan() directly: +// +// vec := []float32{0.1, 0.2, 0.3} +// dgoClient, cleanup, _ := client.DgraphClient(); defer cleanup() +// tx := dg.NewReadOnlyTxn(dgoClient) +// err := SimilarTo(tx, &result, "description", vec, 5).Scan() +func SimilarTo(tx *dg.TxnContext, model any, field string, vec []float32, k int) *dg.QueryBlock { + vecStr := vectorToQueryString(vec) + rootFunc := fmt.Sprintf("similar_to(%s, %d, $vec)", vecShadowPredicate(field), k) + q := dg.NewQuery().Model(model).RootFunc(rootFunc) + return tx.Query(q).Vars("similar_to($vec: string)", map[string]string{"$vec": vecStr}) +} + +// SimilarToText embeds text on-the-fly using the client's configured EmbeddingProvider, +// executes a similar_to(__vec, k, $vec) query, and scans the nearest-neighbour +// results into model. The connection lifecycle is managed internally. 
+// +// Returns an error if no EmbeddingProvider is configured on the client, embedding fails, +// or the query fails. +// +// Example: +// +// err := SimilarToText(client, ctx, &result, "description", "fast red sports car", 5) +// if err != nil { ... } +func SimilarToText(c Client, ctx context.Context, model any, field string, text string, k int) error { + ec, ok := c.(embeddingClient) + if !ok { + return fmt.Errorf("client does not expose embeddingProvider; ensure it is a modusgraph client") + } + provider := ec.embeddingProvider() + if provider == nil { + return fmt.Errorf("no EmbeddingProvider configured on client; use WithEmbeddingProvider") + } + + vec, err := provider.Embed(ctx, text) + if err != nil { + return fmt.Errorf("SimilarToText: embed text: %w", err) + } + + vecStr := vectorToQueryString(vec) + rootFunc := fmt.Sprintf("similar_to(%s, %d, $vec)", vecShadowPredicate(field), k) + + dgoClient, cleanup, err := c.DgraphClient() + if err != nil { + return err + } + defer cleanup() + + tx := dg.NewReadOnlyTxn(dgoClient) + q := dg.NewQuery().Model(model).RootFunc(rootFunc) + return tx.Query(q).Vars("similar_to($vec: string)", map[string]string{"$vec": vecStr}).Scan() +} + +// embeddingClient is an internal interface implemented by client to expose +// the embedding provider to top-level helper functions. +type embeddingClient interface { + embeddingProvider() EmbeddingProvider +} diff --git a/embedding_test.go b/embedding_test.go new file mode 100644 index 0000000..15779a9 --- /dev/null +++ b/embedding_test.go @@ -0,0 +1,596 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph_test + +import ( + "context" + "fmt" + "net/http" + "strings" + "testing" + "time" + + dg "github.com/dolan-in/dgman/v2" + mg "github.com/matthewmcneely/modusgraph" + "github.com/stretchr/testify/require" +) + +// mockEmbeddingProvider is a deterministic EmbeddingProvider for testing. 
// Texts registered via register return exactly the registered vector.
// Unregistered texts get a vector derived only from len(text) and the component
// index, so different texts of equal length share the same embedding; register
// explicit vectors when a test needs nearest-neighbour results to be distinct.
// Identical texts always get identical embeddings.
type mockEmbeddingProvider struct {
	dims    int
	callLog []string // tracks texts that were embedded
	vectors map[string][]float32
}

// newMockProvider returns a mock provider that reports dims dimensions and
// starts with no registered vectors and an empty call log.
func newMockProvider(dims int) *mockEmbeddingProvider {
	return &mockEmbeddingProvider{
		dims:    dims,
		vectors: make(map[string][]float32),
	}
}

// register pre-registers a specific vector for a text so tests can control
// exactly what vector will be stored.
func (m *mockEmbeddingProvider) register(text string, vec []float32) {
	m.vectors[text] = vec
}

// Embed records text in callLog and returns its vector: the registered (or
// previously generated) one if present, otherwise a newly generated vector
// that is cached for later calls. It never returns an error.
func (m *mockEmbeddingProvider) Embed(_ context.Context, text string) ([]float32, error) {
	m.callLog = append(m.callLog, text)
	if v, ok := m.vectors[text]; ok {
		return v, nil
	}
	// Generate a deterministic vector from the text length and the component
	// index (not a content hash): component i is (len(text)+i) * 0.01.
	vec := make([]float32, m.dims)
	for i := range vec {
		vec[i] = float32(len(text)+i) * 0.01
	}
	m.vectors[text] = vec
	return vec, nil
}

// Dims returns the dimension count the mock was created with.
func (m *mockEmbeddingProvider) Dims() int { return m.dims }

// embeddableProduct is the test struct using SimString.
type embeddableProduct struct {
	Name        string       `json:"name,omitempty" dgraph:"index=term"`
	Description mg.SimString `json:"description,omitempty" dgraph:"embedding,index=term"`

	UID   string   `json:"uid,omitempty"`
	DType []string `json:"dgraph.type,omitempty"`
}

// embeddableCustomMetric tests overriding metric and exponent.
type embeddableCustomMetric struct {
	Name        string       `json:"name,omitempty" dgraph:"index=term"`
	Description mg.SimString `json:"description,omitempty" dgraph:"embedding,metric=euclidean,exponent=5"`

	UID   string   `json:"uid,omitempty"`
	DType []string `json:"dgraph.type,omitempty"`
}

// embeddableWithThreshold tests the threshold=N tag option.
// Descriptions shorter than 20 runes should not be embedded.
+type embeddableWithThreshold struct { + Name string `json:"name,omitempty" dgraph:"index=term"` + Description mg.SimString `json:"description,omitempty" dgraph:"embedding,threshold=20"` + + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` +} + +// createEmbeddingClient creates a test client with the given mock embedding provider. +func createEmbeddingClient(t *testing.T, provider mg.EmbeddingProvider) (mg.Client, func()) { + t.Helper() + uri := "file://" + GetTempDir(t) + client, err := mg.NewClient(uri, + mg.WithAutoSchema(true), + mg.WithEmbeddingProvider(provider), + ) + require.NoError(t, err) + + cleanup := func() { + _ = client.DropAll(context.Background()) + client.Close() + mg.Shutdown() + } + return client, cleanup +} + +// --- Unit tests --- + +func TestSimStringMarshal(t *testing.T) { + s := mg.SimString("hello world") + b, err := s.MarshalJSON() + require.NoError(t, err) + require.Equal(t, `"hello world"`, string(b)) +} + +func TestSimStringUnmarshal(t *testing.T) { + var s mg.SimString + require.NoError(t, s.UnmarshalJSON([]byte(`"hello world"`))) + require.Equal(t, mg.SimString("hello world"), s) +} + +func TestHasEmbeddingTagDetection(t *testing.T) { + provider := newMockProvider(4) + client, cleanup := createEmbeddingClient(t, provider) + defer cleanup() + + ctx := context.Background() + + product := &embeddableProduct{ + Name: "Widget", + Description: "A small gadget", + } + err := client.Insert(ctx, product) + require.NoError(t, err, "Insert with SimString field should succeed") + require.NotEmpty(t, product.UID, "UID should be populated after insert") + + // Provider should have been called exactly once for the description field + require.Len(t, provider.callLog, 1) + require.Equal(t, "A small gadget", provider.callLog[0]) +} + +func TestInsertWithEmbedding(t *testing.T) { + const dims = 5 + provider := newMockProvider(dims) + + // Register controlled vectors so similarity search is deterministic + 
provider.register("apple fruit sweet", []float32{0.9, 0.1, 0.1, 0.1, 0.1}) + provider.register("banana yellow tropical", []float32{0.1, 0.9, 0.1, 0.1, 0.1}) + provider.register("carrot orange vegetable", []float32{0.1, 0.1, 0.9, 0.1, 0.1}) + + client, cleanup := createEmbeddingClient(t, provider) + defer cleanup() + + ctx := context.Background() + + products := []*embeddableProduct{ + {Name: "Apple", Description: "apple fruit sweet"}, + {Name: "Banana", Description: "banana yellow tropical"}, + {Name: "Carrot", Description: "carrot orange vegetable"}, + } + + err := client.Insert(ctx, products) + require.NoError(t, err, "Batch insert with SimString should succeed") + + for _, p := range products { + require.NotEmpty(t, p.UID, "Each product should have a UID after insert") + } + + // Provider called once per product + require.Len(t, provider.callLog, 3) + + // Query back the apple and verify text is intact + var fetched embeddableProduct + err = client.Get(ctx, &fetched, products[0].UID) + require.NoError(t, err) + require.Equal(t, "apple fruit sweet", string(fetched.Description)) + require.Equal(t, "Apple", fetched.Name) +} + +func TestUpdateWithEmbedding(t *testing.T) { + const dims = 5 + provider := newMockProvider(dims) + provider.register("original text", []float32{0.5, 0.5, 0.0, 0.0, 0.0}) + provider.register("updated text", []float32{0.0, 0.0, 0.5, 0.5, 0.0}) + + client, cleanup := createEmbeddingClient(t, provider) + defer cleanup() + + ctx := context.Background() + + product := &embeddableProduct{ + Name: "Thing", + Description: "original text", + } + require.NoError(t, client.Insert(ctx, product)) + require.Len(t, provider.callLog, 1) + + // Update description + product.Description = "updated text" + require.NoError(t, client.Update(ctx, product)) + + // Provider should have been called again for the updated text + require.Len(t, provider.callLog, 2) + require.Equal(t, "updated text", provider.callLog[1]) +} + +func TestSimilarToQuery(t *testing.T) { + 
const dims = 5 + provider := newMockProvider(dims) + + // Use well-separated, non-zero vectors to ensure stable cosine similarity results. + // Group 1 (high first component): items 1-4 + // Group 2 (high last component): items 5-8 + // We query with a Group 1 vector and assert we don't get a Group 2 item back as top-1. + group1 := [][]float32{ + {0.95, 0.20, 0.10, 0.10, 0.05}, + {0.90, 0.25, 0.12, 0.08, 0.06}, + {0.92, 0.22, 0.11, 0.09, 0.07}, + {0.88, 0.28, 0.13, 0.07, 0.08}, + } + group2 := [][]float32{ + {0.05, 0.10, 0.10, 0.20, 0.95}, + {0.06, 0.08, 0.12, 0.25, 0.90}, + {0.07, 0.09, 0.11, 0.22, 0.92}, + {0.08, 0.07, 0.13, 0.28, 0.88}, + } + + /* trunk-ignore(golangci-lint/prealloc) */ + var products []*embeddableProduct + for i, v := range group1 { + name := fmt.Sprintf("Group1-%d", i+1) + provider.register(name, v) + products = append(products, &embeddableProduct{Name: name, Description: mg.SimString(name)}) + } + for i, v := range group2 { + name := fmt.Sprintf("Group2-%d", i+1) + provider.register(name, v) + products = append(products, &embeddableProduct{Name: name, Description: mg.SimString(name)}) + } + + client, cleanup := createEmbeddingClient(t, provider) + defer cleanup() + + ctx := context.Background() + require.NoError(t, client.Insert(ctx, products)) + + // Query vector is clearly in Group 1 (high first component) + queryVec := []float32{0.93, 0.21, 0.11, 0.09, 0.06} + + dgoClient, cleanupDgo, err := client.DgraphClient() + require.NoError(t, err) + defer cleanupDgo() + + var result embeddableProduct + tx := dg.NewReadOnlyTxn(dgoClient) + err = mg.SimilarTo(tx, &result, "description", queryVec, 1).Scan() + require.NoError(t, err) + + require.NotEmpty(t, result.Name, "Should find a matching product") + require.True(t, strings.HasPrefix(result.Name, "Group1-"), + "Expected a Group1 result but got: %s", result.Name) +} + +func TestSimilarToTextQuery(t *testing.T) { + const dims = 5 + provider := newMockProvider(dims) + + vecApple := []float32{1.0, 
0.0, 0.0, 0.0, 0.0} + vecBanana := []float32{0.0, 1.0, 0.0, 0.0, 0.0} + vecQueryFruit := []float32{0.99, 0.01, 0.0, 0.0, 0.0} // clearly close to apple + + provider.register("apple fruit sweet", vecApple) + provider.register("banana yellow tropical", vecBanana) + provider.register("fruit like apple", vecQueryFruit) // the query text + + client, cleanup := createEmbeddingClient(t, provider) + defer cleanup() + + ctx := context.Background() + + products := []*embeddableProduct{ + {Name: "Apple Product", Description: "apple fruit sweet"}, + {Name: "Banana Product", Description: "banana yellow tropical"}, + } + require.NoError(t, client.Insert(ctx, products)) + + // SimilarToText should embed "fruit like apple" → vecQueryFruit → nearest is Apple Product + var result embeddableProduct + err := mg.SimilarToText(client, ctx, &result, "description", "fruit like apple", 1) + require.NoError(t, err, "SimilarToText should not error") + require.Equal(t, "Apple Product", result.Name) +} + +func TestUpdateSchemaRegistersVecPredicate(t *testing.T) { + provider := newMockProvider(4) + client, cleanup := createEmbeddingClient(t, provider) + defer cleanup() + + ctx := context.Background() + + // Trigger explicit schema update + err := client.UpdateSchema(ctx, &embeddableProduct{}) + require.NoError(t, err) + + // QueryRaw against schema introspection to verify the vector predicate was registered + raw, err := client.QueryRaw(ctx, `schema(pred: [description__vec]) { type }`, nil) + require.NoError(t, err) + rawStr := string(raw) + require.Contains(t, rawStr, "description__vec", "Schema should contain the shadow vector predicate") + require.Contains(t, rawStr, "float32vector", "Shadow predicate should be of type float32vector") +} + +func TestCustomMetricEmbedding(t *testing.T) { + provider := newMockProvider(4) + client, cleanup := createEmbeddingClient(t, provider) + defer cleanup() + + ctx := context.Background() + + err := client.UpdateSchema(ctx, &embeddableCustomMetric{}) + 
require.NoError(t, err) + + // QueryRaw to verify the vector predicate schema + raw, err := client.QueryRaw(ctx, `schema(pred: [description__vec]) { type }`, nil) + require.NoError(t, err) + require.Contains(t, string(raw), "description__vec", "Shadow predicate should exist") + // Euclidean metric is embedded in the index definition; verify the predicate type at minimum + require.Contains(t, string(raw), "float32vector", "Shadow predicate should be float32vector type") +} + +func TestNoProviderNoEmbedding(t *testing.T) { + // Client without embedding provider: Insert should still work normally for SimString fields + uri := "file://" + GetTempDir(t) + client, err := mg.NewClient(uri, mg.WithAutoSchema(true)) + require.NoError(t, err) + defer func() { + _ = client.DropAll(context.Background()) + client.Close() + mg.Shutdown() + }() + + ctx := context.Background() + + product := &embeddableProduct{ + Name: "NoVec", + Description: "plain text no embedding", + } + err = client.Insert(ctx, product) + require.NoError(t, err, "Insert should succeed even without an EmbeddingProvider") + require.NotEmpty(t, product.UID) +} + +func TestSimStringTermSearch(t *testing.T) { + provider := newMockProvider(4) + client, cleanup := createEmbeddingClient(t, provider) + defer cleanup() + + ctx := context.Background() + + products := []*embeddableProduct{ + {Name: "Kettle", Description: "stainless steel electric kettle for boiling water"}, + {Name: "Toaster", Description: "four slice toaster with browning control"}, + {Name: "Blender", Description: "high speed blender for smoothies and soups"}, + } + require.NoError(t, client.Insert(ctx, products)) + + // Term search on the description predicate of a SimString field. + // allofterms matches nodes where the predicate contains all listed terms. 
+ var result embeddableProduct + q := client.Query(ctx, &result).Filter("allofterms(description, \"electric kettle\")") + err := q.Node() + require.NoError(t, err) + require.Equal(t, "Kettle", result.Name, + "Term search on SimString description should return the matching product") + + // anyofterms: should match both Kettle and Blender + var results []embeddableProduct + q2 := client.Query(ctx, &results).Filter("anyofterms(description, \"kettle blender\")") + err = q2.Nodes() + require.NoError(t, err) + require.Len(t, results, 2, "anyofterms should match two products") +} + +func TestThresholdEmbedding(t *testing.T) { + const dims = 4 + provider := newMockProvider(dims) + provider.register("long enough text to embed", []float32{1.0, 0.0, 0.0, 0.0}) + + client, cleanup := createEmbeddingClient(t, provider) + defer cleanup() + + ctx := context.Background() + + // Insert with a description below the 20-rune threshold — provider should NOT be called. + short := &embeddableWithThreshold{Name: "Short", Description: "too short"} + require.NoError(t, client.Insert(ctx, short)) + require.Empty(t, provider.callLog, "Provider should not be called for below-threshold text") + + // Insert with a description above the threshold — provider SHOULD be called. + long := &embeddableWithThreshold{ + Name: "Long", + Description: "long enough text to embed", + } + require.NoError(t, client.Insert(ctx, long)) + require.Len(t, provider.callLog, 1, "Provider should be called for above-threshold text") + + // Update the long item to a short description — shadow vector should be cleared. + // After clearing, a similarity query for the original text should not return it. + long.Description = "short" + require.NoError(t, client.Update(ctx, long)) + // Provider call count should not increase (below threshold on update too). 
+ require.Len(t, provider.callLog, 1, "Provider should not be called when updated text is below threshold") + + // The shadow vec for the long item should now be absent — verify via raw schema query + // that the predicate exists but the node won't appear in similar_to results. + dgoClient, cleanupDgo, err := client.DgraphClient() + require.NoError(t, err) + defer cleanupDgo() + + queryVec := []float32{1.0, 0.0, 0.0, 0.0} + tx := dg.NewReadOnlyTxn(dgoClient) + var result embeddableWithThreshold + err = mg.SimilarTo(tx, &result, "description", queryVec, 1).Scan() + // Either no results (empty UID) or the short item (which was never embedded) — + // the long item's cleared vector should not be the top match. + require.NoError(t, err) + require.NotEqual(t, long.UID, result.UID, + "Cleared shadow vector should not appear in similarity results") +} + +// ── Ollama live integration ──────────────────────────────────────────── + +const ( + ollamaBaseURL = "http://localhost:11434" + ollamaModel = "bge-m3:latest" + ollamaDims = 1024 +) + +// ollamaRunning probes Ollama's /api/tags endpoint with a short timeout. +// Returns true only when Ollama is reachable and responds 200. +func ollamaRunning() bool { + c := &http.Client{Timeout: 2 * time.Second} + resp, err := c.Get(ollamaBaseURL + "/api/tags") + if err != nil { + return false + } + resp.Body.Close() + return resp.StatusCode == http.StatusOK +} + +// skipUnlessOllama calls t.Skip if Ollama is not reachable. +func skipUnlessOllama(t *testing.T) { + t.Helper() + if !ollamaRunning() { + t.Skipf("Ollama not reachable at %s — skipping live integration test", ollamaBaseURL) + } +} + +// sportingGoodsProduct is the test struct used in the live embedding integration tests. 
+type sportingGoodsProduct struct { + Name string `json:"name,omitempty" dgraph:"index=term"` + Category string `json:"category,omitempty" dgraph:"index=term"` + Description mg.SimString `json:"description,omitempty" dgraph:"embedding,index=term"` + + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` +} + +func newOllamaProvider() *mg.OpenAICompatibleProvider { + return mg.NewOpenAICompatibleProvider(mg.OpenAICompatibleConfig{ + BaseURL: ollamaBaseURL, + Model: ollamaModel, + Dims: ollamaDims, + }) +} + +// TestOllamaIntegration exercises insert, upsert, update, and SimilarToText +// against a real Ollama instance running bge-m3:latest. +func TestOllamaIntegration(t *testing.T) { + skipUnlessOllama(t) + + provider := newOllamaProvider() + uri := "file://" + GetTempDir(t) + client, err := mg.NewClient(uri, + mg.WithAutoSchema(true), + mg.WithEmbeddingProvider(provider), + ) + require.NoError(t, err) + defer func() { + _ = client.DropAll(context.Background()) + client.Close() + mg.Shutdown() + }() + + ctx := context.Background() + + // 1. 
Insert a corpus of semantically varied products + products := []*sportingGoodsProduct{ + { + Name: "Trail Runner X", + Category: "footwear", + Description: "Lightweight trail running shoe with aggressive grip for mountain terrain", + }, + { + Name: "Road Racer Pro", + Category: "footwear", + Description: "Carbon-plated road running shoe for fast road races and marathons", + }, + { + Name: "Summit Hardshell", + Category: "outerwear", + Description: "Waterproof hardshell jacket for alpine climbing and severe weather", + }, + { + Name: "Base Layer Merino", + Category: "clothing", + Description: "Soft merino wool thermal base layer for cold weather activities", + }, + { + Name: "Carbon Fibre Kayak", + Category: "watersports", + Description: "Ultra-light carbon fibre sea kayak for ocean touring and expedition paddling", + }, + { + Name: "Rock Climbing Harness", + Category: "climbing", + Description: "Comfortable sit harness for sport climbing and indoor bouldering gym use", + }, + { + Name: "Trail Mix Nutrition Bar", + Category: "nutrition", + Description: "High-energy snack bar for long hikes, runs, and endurance activities", + }, + { + Name: "GPS Watch Ultra", + Category: "electronics", + Description: "Multi-sport GPS watch with heart rate monitoring and route navigation", + }, + } + + err = client.Insert(ctx, products) + require.NoError(t, err, "Insert corpus should succeed") + for _, p := range products { + require.NotEmpty(t, p.UID, "Each product should have a UID: %s", p.Name) + } + t.Logf("Inserted %d products", len(products)) + + // 2. Verify the shadow schema was registered + raw, err := client.QueryRaw(ctx, `schema(pred: [description__vec]) { type }`, nil) + require.NoError(t, err) + require.Contains(t, string(raw), "float32vector", "Shadow predicate should be registered") + + // 3. 
SimilarToText: query for running shoes + var shoeResult sportingGoodsProduct + err = mg.SimilarToText(client, ctx, &shoeResult, "description", "running shoes for trails", 1) + require.NoError(t, err) + t.Logf("Running shoe query → %q (%s)", shoeResult.Name, shoeResult.Description) + require.NotEmpty(t, shoeResult.Name, "Should find a product") + require.Contains(t, strings.ToLower(shoeResult.Category), "footwear", + "Top result for 'running shoes for trails' should be footwear, got %q", shoeResult.Name) + + // 4. SimilarToText: query for waterproof outerwear + var jacketResult sportingGoodsProduct + err = mg.SimilarToText(client, ctx, &jacketResult, "description", "waterproof jacket for bad weather", 1) + require.NoError(t, err) + t.Logf("Jacket query → %q (%s)", jacketResult.Name, jacketResult.Description) + require.NotEmpty(t, jacketResult.Name) + require.Equal(t, "Summit Hardshell", jacketResult.Name, + "Top result for waterproof jacket should be Summit Hardshell") + + // ── 5. Update: change Trail Runner X description and re-query ───────────── + trailRunner := products[0] + trailRunner.Description = "Rugged trail running shoe with rock plate and waterproof membrane for muddy conditions" + err = client.Update(ctx, trailRunner) + require.NoError(t, err, "Update should succeed") + + // Re-query with the updated semantics — still expects a trail running shoe + var updatedResult sportingGoodsProduct + err = mg.SimilarToText(client, ctx, &updatedResult, "description", "waterproof trail shoe for mud", 1) + require.NoError(t, err) + t.Logf("After update query → %q", updatedResult.Name) + require.NotEmpty(t, updatedResult.Name) + + // 6. Upsert: update Road Racer Pro by predicate + roadRacer := products[1] + roadRacer.Description = "Featherlight carbon road shoe for sub-3-hour marathon performance" + err = client.Upsert(ctx, roadRacer, "name") + require.NoError(t, err, "Upsert should succeed") + + // ── 7. 
SimilarToText: confirm marathon query still maps to road shoe ─────── + var marathonResult sportingGoodsProduct + err = mg.SimilarToText(client, ctx, &marathonResult, "description", "shoe for running a marathon", 1) + require.NoError(t, err) + t.Logf("Marathon query → %q", marathonResult.Name) + require.NotEmpty(t, marathonResult.Name) + require.Equal(t, "footwear", marathonResult.Category, + "Marathon shoe query should return a footwear product") +} diff --git a/go.mod b/go.mod index 5ec767e..c665bf7 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/matthewmcneely/modusgraph -go 1.25.6 +go 1.25.7 require ( github.com/cavaliergopher/grab/v3 v3.0.1 diff --git a/mutate.go b/mutate.go index e10aa8b..ead6dd9 100644 --- a/mutate.go +++ b/mutate.go @@ -97,7 +97,20 @@ func (c client) process(ctx context.Context, } defer c.pool.put(client) - tx := dg.NewTxnContext(ctx, client).SetCommitNow() + provider := c.options.embeddingProvider + hasEmbedding := provider != nil && hasSimStringFields(obj) + + var tx *dg.TxnContext + if hasEmbedding { + // Do not use SetCommitNow: we need to inject shadow vectors before committing. + tx = dg.NewTxnContext(ctx, client) + // Discard is a no-op after a successful Commit but ensures resources are + // cleaned up on all paths (error returns, panics, etc.). + defer func() { _ = tx.Txn().Discard(ctx) }() + } else { + tx = dg.NewTxnContext(ctx, client).SetCommitNow() + } + uids, err := txFunc(tx, obj) if err != nil { // Check if this is a unique constraint violation error from Dgraph @@ -106,6 +119,16 @@ func (c client) process(ctx context.Context, } return err } + + if hasEmbedding { + if err := injectShadowVectors(ctx, provider, tx, obj, uids); err != nil { + return fmt.Errorf("injecting shadow vectors: %w", err) + } + if err := tx.Txn().Commit(ctx); err != nil { + return fmt.Errorf("committing transaction with shadow vectors: %w", err) + } + } + c.logger.V(2).Info(operation+" successful", "uidCount", len(uids)) return nil }