Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions .github/workflows/ci-go-unit-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,22 @@ jobs:
go-version: 1.25.0
cache-dependency-path: go.sum

- name: Set up Dgraph
if: matrix.os == 'linux'
run: |
docker run -d --name dgraph-standalone -p 9080:9080 -p 8080:8080 dgraph/standalone:latest
echo "Waiting for Dgraph to be ready..."
for i in {1..30}; do
if curl -s http://localhost:8080/health > /dev/null; then
echo "Dgraph is ready!"
break
fi
echo "Attempt $i: Dgraph not ready, waiting..."
sleep 2
done
sleep 5

- name: Run Unit Tests
env:
MODUSGRAPH_TEST_ADDR: ${{ matrix.os == 'linux' && 'localhost:9080' || '' }}
run: go test -short -race -v .
66 changes: 53 additions & 13 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,16 @@ type StructValidator interface {
// namespace: the namespace for the client.
// logger: the logger for the client.
// validator: the validator instance for struct validation.
// embeddingProvider: optional provider for automatic SimString vector embeddings.
type clientOptions struct {
autoSchema bool
poolSize int
maxEdgeTraversal int
cacheSizeMB int
namespace string
logger logr.Logger
validator StructValidator
autoSchema bool
poolSize int
maxEdgeTraversal int
cacheSizeMB int
namespace string
logger logr.Logger
validator StructValidator
embeddingProvider EmbeddingProvider
}

// ClientOpt is a function that configures a client
Expand Down Expand Up @@ -182,6 +184,17 @@ func WithValidator(v StructValidator) ClientOpt {
}
}

// WithEmbeddingProvider sets the EmbeddingProvider used to automatically generate
// and maintain shadow float32vector predicates for SimString fields tagged with
// `dgraph:"embedding"`. When set, Insert, Upsert, and Update operations will
// call the provider to embed any SimString values and persist the resulting
// vectors alongside the primary string predicates.
func WithEmbeddingProvider(p EmbeddingProvider) ClientOpt {
return func(o *clientOptions) {
o.embeddingProvider = p
}
}

// NewValidator creates a new validator instance with default settings.
// This is a convenience function for creating a validator to use with WithValidator.
// It returns a *validator.Validate from github.com/go-playground/validator/v10.
Expand Down Expand Up @@ -308,8 +321,18 @@ func (c client) key() string {
if c.options.validator != nil {
validatorKey = fmt.Sprintf("%p", c.options.validator)
}
return fmt.Sprintf("%s:%t:%d:%d:%d:%s:%s", c.uri, c.options.autoSchema, c.options.poolSize,
c.options.maxEdgeTraversal, c.options.cacheSizeMB, c.options.namespace, validatorKey)
embeddingKey := "nil"
if c.options.embeddingProvider != nil {
embeddingKey = fmt.Sprintf("%p", c.options.embeddingProvider)
}
return fmt.Sprintf("%s:%t:%d:%d:%d:%s:%s:%s", c.uri, c.options.autoSchema, c.options.poolSize,
c.options.maxEdgeTraversal, c.options.cacheSizeMB, c.options.namespace, validatorKey, embeddingKey)
}

// embeddingProvider implements the embeddingClient interface, exposing the
// configured EmbeddingProvider to package-level helpers like SimilarToText.
func (c client) embeddingProvider() EmbeddingProvider {
return c.options.embeddingProvider
}

func checkPointer(obj any) error {
Expand Down Expand Up @@ -458,16 +481,33 @@ func (c client) Query(ctx context.Context, model any) *dg.Query {

// UpdateSchema implements updating the Dgraph schema. Pass one or more
// objects that will be used to generate the schema.
// If any object contains SimString fields tagged `dgraph:"embedding"`, the
// corresponding shadow float32vector predicates (<field>__vec) are also registered.
func (c client) UpdateSchema(ctx context.Context, obj ...any) error {
client, err := c.pool.get()
dgClient, err := c.pool.get()
if err != nil {
c.logger.Error(err, "Failed to get client from pool")
return err
}
defer c.pool.put(client)
defer c.pool.put(dgClient)

if _, err = dg.CreateSchema(dgClient, obj...); err != nil {
return err
}

// Collect shadow vector schema lines for SimString fields across all objects.
var vecSchema strings.Builder
for _, o := range obj {
for _, info := range collectSimFields(o) {
vecSchema.WriteString(buildVecSchemaStatement(info))
vecSchema.WriteString("\n")
}
}
if vecSchema.Len() == 0 {
return nil
}

_, err = dg.CreateSchema(client, obj...)
return err
return dgClient.Alter(ctx, &api.Operation{Schema: vecSchema.String()})
}

// GetSchema implements retrieving the Dgraph schema.
Expand Down
Loading
Loading