Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 28 additions & 14 deletions autowire.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,19 @@ func InvokeStructCtx[T any](ctx context.Context, c *Container) (T, error) {
return structVal.Interface().(T), nil
}

func ProvideFunc[T any](c *Container, constructor any, opts ...ProviderOption) error {
func buildFuncProvider[T any](c *Container, constructor any) (Provider[T], []ProviderOption, error) {
params, returnType, err := reflect.FuncParams(constructor)
if err != nil {
return err
return nil, nil, err
}

if returnType == nil {
return fmt.Errorf("constructor must return at least one value")
return nil, nil, fmt.Errorf("constructor must return at least one value")
}

expectedType := reflectPkg.TypeOf((*T)(nil)).Elem()
if !returnType.AssignableTo(expectedType) {
return fmt.Errorf("constructor returns %s, expected %s", returnType, expectedType)
return nil, nil, fmt.Errorf("constructor returns %s, expected %s", returnType, expectedType)
}

fnVal := reflectPkg.ValueOf(constructor)
Expand Down Expand Up @@ -128,17 +128,10 @@ func ProvideFunc[T any](c *Container, constructor any, opts ...ProviderOption) e
return results[0].Interface().(T), nil
}

opts = append([]ProviderOption{WithDependencies(deps...)}, opts...)
return Provide(c, provider, opts...)
}

func MustProvideFunc[T any](c *Container, constructor any, opts ...ProviderOption) {
if err := ProvideFunc[T](c, constructor, opts...); err != nil {
panic(err)
}
return provider, []ProviderOption{WithDependencies(deps...)}, nil
}

func ProvideStruct[T any](c *Container, opts ...ProviderOption) error {
func buildStructProvider[T any](c *Container) (Provider[T], []ProviderOption) {
provider := func(ctx context.Context, r Resolver) (T, error) {
return InvokeStructCtx[T](ctx, c)
}
Expand All @@ -155,7 +148,28 @@ func ProvideStruct[T any](c *Container, opts ...ProviderOption) error {
}
}

opts = append([]ProviderOption{WithDependencies(deps...)}, opts...)
return provider, []ProviderOption{WithDependencies(deps...)}
}

func ProvideFunc[T any](c *Container, constructor any, opts ...ProviderOption) error {
provider, depOpts, err := buildFuncProvider[T](c, constructor)
if err != nil {
return err
}

opts = append(depOpts, opts...)
return Provide(c, provider, opts...)
}

func MustProvideFunc[T any](c *Container, constructor any, opts ...ProviderOption) {
if err := ProvideFunc[T](c, constructor, opts...); err != nil {
panic(err)
}
}

func ProvideStruct[T any](c *Container, opts ...ProviderOption) error {
provider, depOpts := buildStructProvider[T](c)
opts = append(depOpts, opts...)
return Provide(c, provider, opts...)
}

Expand Down
12 changes: 4 additions & 8 deletions container.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,16 @@ func newContainer(opts ...Option) *Container {
}

for _, h := range cfg.onResolve {
hook := h
internalCfg.OnResolve = append(internalCfg.OnResolve, container.ResolveHook(hook))
internalCfg.OnResolve = append(internalCfg.OnResolve, container.ResolveHook(h))
}
for _, h := range cfg.onProvide {
hook := h
internalCfg.OnProvide = append(internalCfg.OnProvide, container.ProvideHook(hook))
internalCfg.OnProvide = append(internalCfg.OnProvide, container.ProvideHook(h))
}
for _, h := range cfg.onStart {
hook := h
internalCfg.OnStart = append(internalCfg.OnStart, container.StartHook(hook))
internalCfg.OnStart = append(internalCfg.OnStart, container.StartHook(h))
}
for _, h := range cfg.onStop {
hook := h
internalCfg.OnStop = append(internalCfg.OnStop, container.StopHook(hook))
internalCfg.OnStop = append(internalCfg.OnStop, container.StopHook(h))
}

c := &Container{
Expand Down
7 changes: 4 additions & 3 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,11 @@ func (e *Error) Unwrap() error {
}

func (e *Error) Is(target error) bool {
if t, ok := errors.AsType[*Error](target); ok {
return e.Code == t.Code
t, ok := target.(*Error)
if !ok {
return false
}
return false
return e.Code == t.Code
}

func (e *Error) WithService(service string) *Error {
Expand Down
53 changes: 53 additions & 0 deletions errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package needle

import (
"errors"
"fmt"
"testing"
)

func TestError_Is_SameCode(t *testing.T) {
t.Parallel()

err1 := newError(ErrCodeServiceNotFound, "service A not found", nil)
err2 := newError(ErrCodeServiceNotFound, "service B not found", nil)

if !errors.Is(err1, err2) {
t.Error("errors with same code should match via Is")
}
}

func TestError_Is_DifferentCode(t *testing.T) {
t.Parallel()

err1 := newError(ErrCodeServiceNotFound, "not found", nil)
err2 := newError(ErrCodeCircularDependency, "cycle", nil)

if errors.Is(err1, err2) {
t.Error("errors with different codes should not match via Is")
}
}

func TestError_Is_DoesNotTraverseTargetChain(t *testing.T) {
t.Parallel()

inner := newError(ErrCodeServiceNotFound, "inner", nil)
wrapper := fmt.Errorf("wrapped: %w", inner)
check := newError(ErrCodeServiceNotFound, "check", nil)

if errors.Is(check, wrapper) {
t.Error("Is should not traverse target's chain, only direct type assertion on target")
}
}

func TestError_Is_WrappedSource(t *testing.T) {
t.Parallel()

inner := newError(ErrCodeServiceNotFound, "inner", nil)
wrapper := fmt.Errorf("wrapped: %w", inner)
target := newError(ErrCodeServiceNotFound, "target", nil)

if !errors.Is(wrapper, target) {
t.Error("errors.Is should find inner *Error via Unwrap chain of source")
}
}
6 changes: 1 addition & 5 deletions internal/container/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ type Container struct {
logger *slog.Logger
state State

resolving map[string]bool
resolvingMu sync.Mutex

decorators map[string][]DecoratorFunc
decoratorsMu sync.RWMutex

Expand Down Expand Up @@ -68,7 +65,6 @@ func New(cfg *Config) *Container {
registry: NewRegistry(),
graph: graph.New(),
logger: logger,
resolving: make(map[string]bool),
decorators: make(map[string][]DecoratorFunc),
onResolve: cfg.OnResolve,
onProvide: cfg.OnProvide,
Expand All @@ -89,7 +85,7 @@ func (c *Container) Register(key string, provider ProviderFunc, dependencies []s
c.registry.RegisterUnsafe(key, provider, dependencies)
c.graph.AddNodeUnsafe(key, dependencies)

if len(dependencies) > 0 && c.graph.HasCycle() {
if len(dependencies) > 0 && c.graph.HasCycleUnsafe() {
c.registry.RemoveUnsafe(key)
c.graph.RemoveNodeUnsafe(key)
c.mu.Unlock()
Expand Down
62 changes: 62 additions & 0 deletions internal/container/container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package container
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
)

Expand Down Expand Up @@ -297,6 +299,66 @@ func TestContainer_ContextCancellation(t *testing.T) {
}
}

func TestContainer_ConcurrentResolve_NoFalseCycle(t *testing.T) {
t.Parallel()

c := New(&Config{})

_ = c.RegisterValue("dep", "dependency")
_ = c.Register("svc", func(ctx context.Context, r Resolver) (any, error) {
_, _ = r.Resolve(ctx, "dep")
return "service", nil
}, []string{"dep"})

var wg sync.WaitGroup
errs := make(chan error, 50)

for range 50 {
wg.Add(1)
go func() {
defer wg.Done()
_, err := c.Resolve(context.Background(), "svc")
if err != nil {
errs <- err
}
}()
}

wg.Wait()
close(errs)

for err := range errs {
t.Errorf("unexpected error during concurrent resolve: %v", err)
}
}

func TestContainer_SingletonCalledOnce(t *testing.T) {
t.Parallel()

c := New(&Config{})

var callCount atomic.Int64
_ = c.Register("singleton", func(ctx context.Context, r Resolver) (any, error) {
callCount.Add(1)
return "instance", nil
}, nil)

var wg sync.WaitGroup
for range 50 {
wg.Add(1)
go func() {
defer wg.Done()
_, _ = c.Resolve(context.Background(), "singleton")
}()
}

wg.Wait()

if count := callCount.Load(); count != 1 {
t.Errorf("singleton provider called %d times, expected 1", count)
}
}

func BenchmarkContainer_Resolve(b *testing.B) {
c := New(&Config{})

Expand Down
19 changes: 9 additions & 10 deletions internal/container/lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package container

import (
"context"
"errors"
"fmt"
"sync"
"time"
Expand Down Expand Up @@ -110,13 +111,9 @@ func (c *Container) startService(ctx context.Context, key string) error {
return fmt.Errorf("failed to resolve %s during startup: %w", key, err)
}

entry, exists := c.registry.GetEntry(key)
if !exists {
return nil
}

var startErr error
for _, hook := range entry.OnStart {
hooks := c.registry.GetOnStartHooks(key)
for _, hook := range hooks {
c.logger.Debug("running OnStart hook", "service", key)
if err := hook(ctx); err != nil {
startErr = fmt.Errorf("OnStart hook failed for %s: %w", key, err)
Expand Down Expand Up @@ -240,15 +237,17 @@ func (c *Container) stopService(ctx context.Context, key string) error {
}

start := time.Now()
var stopErr error
var errs []error

for i := len(entry.OnStop) - 1; i >= 0; i-- {
hooks := c.registry.GetOnStopHooks(key)
for i := len(hooks) - 1; i >= 0; i-- {
c.logger.Debug("running OnStop hook", "service", key)
if err := entry.OnStop[i](ctx); err != nil {
stopErr = fmt.Errorf("OnStop hook failed for %s: %w", key, err)
if err := hooks[i](ctx); err != nil {
errs = append(errs, fmt.Errorf("OnStop hook failed for %s: %w", key, err))
}
}

stopErr := errors.Join(errs...)
c.callStopHooks(key, time.Since(start), stopErr)
return stopErr
}
Expand Down
Loading
Loading