Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions circuit.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,11 @@ func (c *Circuit) openCircuit(ctx context.Context, now time.Time) {
// Don't bother opening a circuit that is already open
return
}
if !c.isOpen.CompareAndSwap(false, true) {
// Another goroutine already opened it; don't double-emit Opened()
return
}
c.CircuitMetricsCollector.Opened(ctx, now)
c.isOpen.Set(true)
}

// Go executes `Execute`, but uses spawned goroutines to end early if the context is canceled. Use this if you don't trust
Expand Down Expand Up @@ -403,7 +406,8 @@ func (c *Circuit) fallback(ctx context.Context, err error, fallbackFunc func(con
// Throttle concurrent fallback calls
currentFallbackCount := c.concurrentFallbacks.Add(1)
defer c.concurrentFallbacks.Add(-1)
if c.threadSafeConfig.Fallback.MaxConcurrentRequests.Get() >= 0 && currentFallbackCount > c.threadSafeConfig.Fallback.MaxConcurrentRequests.Get() {
maxFallback := c.threadSafeConfig.Fallback.MaxConcurrentRequests.Get()
if maxFallback >= 0 && currentFallbackCount > maxFallback {
c.FallbackMetricCollector.ErrConcurrencyLimitReject(ctx, c.now())
return &circuitError{concurrencyLimitReached: true, msg: "throttling concurrency to fallbacks"}
}
Expand Down Expand Up @@ -441,8 +445,11 @@ func (c *Circuit) close(ctx context.Context, now time.Time, forceClosed bool) {
return
}
if forceClosed || c.OpenToClose.ShouldClose(ctx, now) {
if !c.isOpen.CompareAndSwap(true, false) {
// Another goroutine already closed it; don't double-emit Closed()
return
}
c.CircuitMetricsCollector.Closed(ctx, now)
c.isOpen.Set(false)
}
}

Expand Down
120 changes: 106 additions & 14 deletions circuit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -197,27 +196,44 @@ func TestThrottled(t *testing.T) {
MaxConcurrentRequests: 2,
},
})
// Barrier pattern: wait until ALL goroutines have either entered runFunc or been
// rejected BEFORE releasing the ones inside. This guarantees the 3rd attempts
// while 2 are still occupying slots, regardless of CI scheduling delays.
const numGoroutines = 3
const limit = 2
attempted := make(chan struct{}, numGoroutines)
release := make(chan struct{})
bc := testhelp.BehaviorCheck{
RunFunc: testhelp.SleepsForX(time.Millisecond),
RunFunc: func(ctx context.Context) error {
attempted <- struct{}{} // entered runFunc
<-release
return nil
},
}
wg := sync.WaitGroup{}
errCount := 0
for i := 0; i < 3; i++ {
var errCount faststats.AtomicInt64
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
err := c.Execute(context.Background(), bc.Run, nil)
if err != nil {
errCount++
errCount.Add(1)
attempted <- struct{}{} // rejected goroutine also signals
}
}()
}
// Wait until all have attempted (2 inside runFunc, 1 rejected) before releasing.
for i := 0; i < numGoroutines; i++ {
<-attempted
}
close(release)
wg.Wait()
if bc.MostConcurrent != 2 {
if bc.MostConcurrent != limit {
t.Errorf("Concurrent count not correct: %d", bc.MostConcurrent)
}
if errCount != 1 {
t.Errorf("did not see error return count: %d", errCount)
if errCount.Get() != 1 {
t.Errorf("did not see error return count: %d", errCount.Get())
}
}

Expand Down Expand Up @@ -429,30 +445,46 @@ func TestFallbackCircuitConcurrency(t *testing.T) {
MaxConcurrentRequests: 2,
},
})
// Barrier pattern: wait until ALL goroutines have either entered the fallback or
// been rejected BEFORE releasing the ones inside. This guarantees the 3rd attempts
// while 2 are still occupying fallback slots, regardless of CI scheduling delays.
const numGoroutines = 3
const limit = 2
attempted := make(chan struct{}, numGoroutines)
release := make(chan struct{})
wg := sync.WaitGroup{}
workingCircuitCount := int64(0)
var workingCircuitCount faststats.AtomicInt64
var fallbackExecuted faststats.AtomicInt64
var totalExecuted faststats.AtomicInt64
for i := 0; i < 3; i++ {
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
totalExecuted.Add(1)
defer wg.Done()
err := c.Execute(context.Background(), testhelp.AlwaysFails, func(ctx context.Context, err error) error {
fallbackExecuted.Add(1)
return testhelp.SleepsForX(time.Millisecond * 500)(ctx)
attempted <- struct{}{} // entered fallback
<-release
return nil
})
if err == nil {
atomic.AddInt64(&workingCircuitCount, 1)
workingCircuitCount.Add(1)
} else {
attempted <- struct{}{} // rejected goroutine also signals
}
}()
}
// Wait until all have attempted (2 inside fallback, 1 rejected) before releasing.
for i := 0; i < numGoroutines; i++ {
<-attempted
}
close(release)
wg.Wait()
if totalExecuted.Get() == fallbackExecuted.Get() {
t.Error("At least one fallback call should never happen due to concurrency")
}
if workingCircuitCount != 2 {
t.Error("Should see 2 working examples")
if workingCircuitCount.Get() != limit {
t.Errorf("Should see %d working examples, got %d", limit, workingCircuitCount.Get())
}
}

Expand Down Expand Up @@ -694,3 +726,63 @@ func (tc *timeoutChecker) ErrBadRequest(_ context.Context, _ time.Time, _ time.D
func (tc *timeoutChecker) ErrInterrupt(_ context.Context, _ time.Time, _ time.Duration) {}
func (tc *timeoutChecker) ErrConcurrencyLimitReject(_ context.Context, _ time.Time) {}
func (tc *timeoutChecker) ErrShortCircuit(_ context.Context, _ time.Time) {}

// transitionCounter counts Opened()/Closed() calls to verify exactly-once semantics.
type transitionCounter struct {
opened faststats.AtomicInt64
closed faststats.AtomicInt64
}

func (tc *transitionCounter) Opened(_ context.Context, _ time.Time) { tc.opened.Add(1) }
func (tc *transitionCounter) Closed(_ context.Context, _ time.Time) { tc.closed.Add(1) }

// TestOpenCircuit_EmitsOpenedExactlyOnce ensures concurrent openCircuit calls
// emit the Opened() metric exactly once per actual state transition.
func TestOpenCircuit_EmitsOpenedExactlyOnce(t *testing.T) {
tc := &transitionCounter{}
c := NewCircuitFromConfig("TestOpenOnce", Config{
Metrics: MetricsCollectors{
Circuit: []Metrics{tc},
},
})
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
c.OpenCircuit(context.Background())
}()
}
wg.Wait()
if got := tc.opened.Get(); got != 1 {
t.Errorf("Opened() called %d times, want exactly 1", got)
}
}

// TestCloseCircuit_EmitsClosedExactlyOnce ensures concurrent CloseCircuit calls
// emit the Closed() metric exactly once per actual state transition.
func TestCloseCircuit_EmitsClosedExactlyOnce(t *testing.T) {
tc := &transitionCounter{}
c := NewCircuitFromConfig("TestCloseOnce", Config{
Metrics: MetricsCollectors{
Circuit: []Metrics{tc},
},
})
c.OpenCircuit(context.Background())
// Reset the opened counter; we only care about closed here
tc.opened.Set(0)
tc.closed.Set(0)

var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
c.CloseCircuit(context.Background())
}()
}
wg.Wait()
if got := tc.closed.Get(); got != 1 {
t.Errorf("Closed() called %d times, want exactly 1", got)
}
}
22 changes: 20 additions & 2 deletions closers/hystrix/circuit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/cep21/circuit/v4"
"github.com/cep21/circuit/v4/internal/clock"
"github.com/cep21/circuit/v4/internal/testhelp"
)

Expand Down Expand Up @@ -113,13 +114,26 @@ func TestCircuitAttemptsToReopen(t *testing.T) {
}

func TestCircuitAttemptsToReopenOnlyOnce(t *testing.T) {
// Use a mock clock for deterministic timing — the previous real-clock version
// with SleepWindow: 1ms flaked when CI scheduling delays exceeded 1ms.
mockClock := &clock.MockClock{}
now := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
mockClock.Set(now)
sleepWindow := time.Second

c := circuit.NewCircuitFromConfig("TestCircuitAttemptsToReopenOnlyOnce", circuit.Config{
General: circuit.GeneralConfig{
TimeKeeper: circuit.TimeKeeper{
Now: mockClock.Now,
AfterFunc: mockClock.AfterFunc,
},
OpenToClosedFactory: CloserFactory(ConfigureCloser{
SleepWindow: time.Millisecond,
SleepWindow: sleepWindow,
AfterFunc: mockClock.AfterFunc,
}),
ClosedToOpenFactory: OpenerFactory(ConfigureOpener{
RequestVolumeThreshold: 1,
Now: mockClock.Now,
}),
},
})
Expand All @@ -133,16 +147,20 @@ func TestCircuitAttemptsToReopenOnlyOnce(t *testing.T) {
if !c.IsOpen() {
t.Fatal("Circuit should be open after having failed once")
}
// No time has advanced — sleep window is still active, circuit must reject
err = c.Execute(context.Background(), testhelp.AlwaysPasses, nil)
if err == nil {
t.Fatal("Circuit should be open")
}

time.Sleep(time.Millisecond * 3)
// Advance mock time past the sleep window
mockClock.Add(sleepWindow + time.Millisecond)
// Half-open probe: allowed once, fails → circuit stays open, sleep window resets
err = c.Execute(context.Background(), testhelp.AlwaysFails, nil)
if err == nil {
t.Fatal("Circuit should try to reopen, but fail")
}
// Second attempt in the same (new) sleep window must be rejected
err = c.Execute(context.Background(), testhelp.AlwaysPasses, nil)
if err == nil {
t.Fatal("Circuit should only try to reopen once")
Expand Down
2 changes: 1 addition & 1 deletion closers/hystrix/opener.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (e *Opener) SetConfigThreadSafe(props ConfigureOpener) {
// SetConfigNotThreadSafe recreates the buckets. It is not safe to call while the circuit is active.
func (e *Opener) SetConfigNotThreadSafe(props ConfigureOpener) {
e.SetConfigThreadSafe(props)
now := props.Now()
now := props.now()
rollingCounterBucketWidth := time.Duration(props.RollingDuration.Nanoseconds() / int64(props.NumBuckets))
e.errorsCount = faststats.NewRollingCounter(rollingCounterBucketWidth, props.NumBuckets, now)
e.legitimateAttemptsCount = faststats.NewRollingCounter(rollingCounterBucketWidth, props.NumBuckets, now)
Expand Down
16 changes: 16 additions & 0 deletions closers/hystrix/opener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,22 @@ func TestOpener(t *testing.T) {
}
}

func TestOpener_SetConfigNotThreadSafe_NilNow(t *testing.T) {
// SetConfigNotThreadSafe should not panic when Now is nil; it should fall
// back to time.Now via the nil-safe now() helper.
o := &Opener{}
o.SetConfigNotThreadSafe(ConfigureOpener{
RequestVolumeThreshold: 20,
ErrorThresholdPercentage: 50,
NumBuckets: 10,
RollingDuration: 10 * time.Second,
// Now intentionally left nil
})
if o.Config().RequestVolumeThreshold != 20 {
t.Errorf("config not applied: RequestVolumeThreshold = %d", o.Config().RequestVolumeThreshold)
}
}

func TestOpenerFactory_ConcurrentCreation(t *testing.T) {
factory := OpenerFactory(ConfigureOpener{
RequestVolumeThreshold: 10,
Expand Down
5 changes: 3 additions & 2 deletions closers/simplelogic/closers.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ type ConsecutiveErrOpener struct {
func ConsecutiveErrOpenerFactory(config ConfigConsecutiveErrOpener) func() circuit.ClosedToOpen {
return func() circuit.ClosedToOpen {
ret := &ConsecutiveErrOpener{}
config.Merge(defaultConfigConsecutiveErrOpener)
ret.SetConfigThreadSafe(config)
cfg := config
cfg.Merge(defaultConfigConsecutiveErrOpener)
ret.SetConfigThreadSafe(cfg)
return ret
}
}
Expand Down
19 changes: 19 additions & 0 deletions closers/simplelogic/closers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package simplelogic
import (
"context"
"fmt"
"sync"
"testing"
"time"

Expand All @@ -26,6 +27,24 @@ func TestConsecutiveErrOpenerFactory(t *testing.T) {
}
}

func TestConsecutiveErrOpenerFactory_ConcurrentCreation(t *testing.T) {
// Concurrent factory calls must not race on the captured config.
// Run with -race to verify no data race on config.ErrorThreshold.
factory := ConsecutiveErrOpenerFactory(ConfigConsecutiveErrOpener{})
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
o := factory().(*ConsecutiveErrOpener)
if o.closeThreshold.Get() != 10 {
t.Errorf("closeThreshold = %d, want 10 (default)", o.closeThreshold.Get())
}
}()
}
wg.Wait()
}

func TestConsecutiveErrOpener_Merge(t *testing.T) {
c := &ConfigConsecutiveErrOpener{}
c.Merge(ConfigConsecutiveErrOpener{
Expand Down
8 changes: 8 additions & 0 deletions faststats/rolling_counter.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,16 @@ func (r *RollingCounter) MarshalJSON() ([]byte, error) {
}

// UnmarshalJSON stores the previous JSON encoding. Note, this is *NOT* thread safe.
// Returns an error if the JSON is missing required fields (i.e., was not produced
// by MarshalJSON or was truncated); the receiver is left unmodified in that case.
func (r *RollingCounter) UnmarshalJSON(b []byte) error {
var into jsonCounter
if err := json.Unmarshal(b, &into); err != nil {
return err
}
if into.RollingSum == nil || into.TotalSum == nil || into.RollingBucket == nil {
return fmt.Errorf("RollingCounter.UnmarshalJSON: incomplete JSON (missing required fields)")
}
r.buckets = into.Buckets
r.rollingSum.Store(into.RollingSum.Get())
r.totalSum.Store(into.TotalSum.Get())
Expand Down Expand Up @@ -115,6 +120,9 @@ func (r *RollingCounter) TotalSum() int64 {

// GetBuckets returns a copy of the buckets in order backwards in time
func (r *RollingCounter) GetBuckets(now time.Time) []int64 {
if r.rollingBucket.NumBuckets == 0 {
return nil
}
r.rollingBucket.Advance(now, r.clearBucket)
startIdx := int(r.rollingBucket.LastAbsIndex.Get() % int64(r.rollingBucket.NumBuckets))
ret := make([]int64, r.rollingBucket.NumBuckets)
Expand Down
Loading
Loading