diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 57f4860..471bd82 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -28,11 +28,11 @@ jobs: - name: Check out code uses: actions/checkout@v6 - name: Build - run: go build -mod=readonly ./... + run: make build - name: Verify run: go mod verify - name: Test - run: env "GORACE=halt_on_error=1" go test -v -race -count 10 ./... + run: make test-race - name: golangci-lint uses: golangci/golangci-lint-action@v9 - name: Output coverage diff --git a/.github/workflows/fuzz.yml b/.github/workflows/fuzz.yml new file mode 100644 index 0000000..22f7185 --- /dev/null +++ b/.github/workflows/fuzz.yml @@ -0,0 +1,30 @@ +name: Fuzz + +permissions: + contents: read + +on: + schedule: + - cron: '0 6 * * *' # daily, 06:00 UTC + workflow_dispatch: # manual trigger from Actions tab + +jobs: + fuzz: + name: Fuzz + runs-on: ubuntu-latest + steps: + - name: Check out code + uses: actions/checkout@v6 + - name: Install Go + uses: actions/setup-go@v6 + with: + go-version: 1.25.x + - name: Fuzz + run: make fuzz FUZZTIME=5m + - name: Upload crash inputs + if: failure() + uses: actions/upload-artifact@v4 + with: + name: fuzz-crashers + path: '**/testdata/fuzz/**' + if-no-files-found: ignore diff --git a/.gitignore b/.gitignore index b10d64f..22e4b6c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /.idea /coverage.out +/coverage.html diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..c99fb3d --- /dev/null +++ b/Makefile @@ -0,0 +1,59 @@ +.PHONY: all ci build test test-race lint fuzz bench fix cover help + +# Default: the fast loop (build + test + lint). Use during active development. +all: build test lint + +# Everything CI runs. Use before submitting a PR. +ci: build test-race lint + +build: + go build -mod=readonly ./... + +# Fast unit tests — includes property tests and fuzz seed corpora +test: + go test ./... 
+ +# What CI runs: race detector, 10 iterations, halt on first race +test-race: + env "GORACE=halt_on_error=1" go test -race -count 10 ./... + +lint: + golangci-lint run + +# Coverage HTML report (opens in browser on most systems; else see coverage.html) +cover: + go test -covermode=count -coverprofile=coverage.out ./... + go tool cover -html=coverage.out -o coverage.html + @echo "wrote coverage.html" + +# Active fuzzing. One target at a time, FUZZTIME per target (default 30s). +# Example: make fuzz FUZZTIME=2m +FUZZTIME ?= 30s +FUZZ_TARGETS := FuzzRollingCounterOps FuzzRollingCounterJSON \ + FuzzSortedDurationsPercentile FuzzRollingBucketAdvance FuzzTimedCheckJSON +fuzz: + @for t in $(FUZZ_TARGETS); do \ + echo "=== fuzz $$t ($(FUZZTIME)) ==="; \ + go test -fuzz="^$$t$$" -fuzztime=$(FUZZTIME) ./faststats/ || exit 1; \ + done + +bench: + go test -benchmem -run=^$$ -bench=. ./... + +# Auto-format: gofmt + goimports on all source +fix: + gofmt -s -w . + @command -v goimports >/dev/null 2>&1 && goimports -w . || echo "goimports not installed, skipped" + +help: + @echo "Targets:" + @echo " make fast dev loop: build + test + lint (~5s)" + @echo " make ci everything CI runs — use before submitting a PR (~1min)" + @echo " make build compile all packages" + @echo " make test unit tests (includes property tests, fuzz seeds)" + @echo " make test-race race detector, -count 10" + @echo " make lint golangci-lint run" + @echo " make fuzz active fuzzing, FUZZTIME per target (default 30s)" + @echo " make bench all benchmarks" + @echo " make cover generate coverage.html" + @echo " make fix gofmt + goimports" diff --git a/README.md b/README.md index e0dec8f..af67575 100644 --- a/README.md +++ b/README.md @@ -451,7 +451,14 @@ BenchmarkCiruits/iand_circuit/Default/passing/75-8 5000000 349 ns # [Development](https://github.com/cep21/circuit/blob/master/Makefile) -Make sure your tests pass with `go test` and your lints pass with `golangci-lint run`. 
+```bash +make # fast dev loop: build + test + lint (~5s) +make ci # everything CI runs — use before submitting a PR (~1min) +make fuzz # active fuzzing (FUZZTIME per target, default 30s) +make help # full target list +``` + +`make ci` mirrors the GitHub Actions workflow: build, `go test -race -count 10`, and `golangci-lint run`. If it passes locally, CI should pass. # [Example](https://github.com/cep21/circuit/blob/master/example/main.go) diff --git a/circuit_property_test.go b/circuit_property_test.go new file mode 100644 index 0000000..c8fed50 --- /dev/null +++ b/circuit_property_test.go @@ -0,0 +1,697 @@ +package circuit + +import ( + "context" + "errors" + "fmt" + "math/rand" + "sync" + "testing" + "testing/quick" + "time" + + "github.com/cep21/circuit/v4/faststats" + "github.com/cep21/circuit/v4/internal/clock" +) + +// countingRunMetrics records every RunMetrics callback. Used to verify the +// "exactly one callback per Execute" contract stated at metrics.go (RunMetrics +// interface doc). 
+type countingRunMetrics struct { + success faststats.AtomicInt64 + errFailure faststats.AtomicInt64 + errTimeout faststats.AtomicInt64 + errBadRequest faststats.AtomicInt64 + errInterrupt faststats.AtomicInt64 + errConcurrencyLimitReject faststats.AtomicInt64 + errShortCircuit faststats.AtomicInt64 +} + +var _ RunMetrics = &countingRunMetrics{} + +func (m *countingRunMetrics) Success(_ context.Context, _ time.Time, _ time.Duration) { + m.success.Add(1) +} +func (m *countingRunMetrics) ErrFailure(_ context.Context, _ time.Time, _ time.Duration) { + m.errFailure.Add(1) +} +func (m *countingRunMetrics) ErrTimeout(_ context.Context, _ time.Time, _ time.Duration) { + m.errTimeout.Add(1) +} +func (m *countingRunMetrics) ErrBadRequest(_ context.Context, _ time.Time, _ time.Duration) { + m.errBadRequest.Add(1) +} +func (m *countingRunMetrics) ErrInterrupt(_ context.Context, _ time.Time, _ time.Duration) { + m.errInterrupt.Add(1) +} +func (m *countingRunMetrics) ErrConcurrencyLimitReject(_ context.Context, _ time.Time) { + m.errConcurrencyLimitReject.Add(1) +} +func (m *countingRunMetrics) ErrShortCircuit(_ context.Context, _ time.Time) { + m.errShortCircuit.Add(1) +} + +// total returns the sum of all callback counts. The RunMetrics contract +// guarantees exactly one callback fires per Execute call, so total() should +// equal the number of Execute calls. +func (m *countingRunMetrics) total() int64 { + return m.success.Get() + + m.errFailure.Get() + + m.errTimeout.Get() + + m.errBadRequest.Get() + + m.errInterrupt.Get() + + m.errConcurrencyLimitReject.Get() + + m.errShortCircuit.Get() +} + +// countingFallbackMetrics records every FallbackMetrics callback. 
+type countingFallbackMetrics struct { + success faststats.AtomicInt64 + errFailure faststats.AtomicInt64 + errConcurrencyLimitReject faststats.AtomicInt64 +} + +var _ FallbackMetrics = &countingFallbackMetrics{} + +func (m *countingFallbackMetrics) Success(_ context.Context, _ time.Time, _ time.Duration) { + m.success.Add(1) +} +func (m *countingFallbackMetrics) ErrFailure(_ context.Context, _ time.Time, _ time.Duration) { + m.errFailure.Add(1) +} +func (m *countingFallbackMetrics) ErrConcurrencyLimitReject(_ context.Context, _ time.Time) { + m.errConcurrencyLimitReject.Add(1) +} + +func (m *countingFallbackMetrics) total() int64 { + return m.success.Get() + m.errFailure.Get() + m.errConcurrencyLimitReject.Get() +} + +// ============================================================================ +// Section 1: Invariant-asserting stress tests +// These run the circuit hot under concurrency and verify global invariants +// that the existing stress tests miss. +// ============================================================================ + +// TestCircuit_RunMetricsExactlyOnce_Concurrent verifies the RunMetrics contract +// ("exactly one callback per Execute") holds under high concurrency with a mix +// of success/failure/bad-request/interrupt outcomes. This is the class of bug +// fixed by the CAS changes in openCircuit/close — protocol violations that +// -race does not catch. 
+func TestCircuit_RunMetricsExactlyOnce_Concurrent(t *testing.T) { + rm := &countingRunMetrics{} + tc := &transitionCounter{} + c := NewCircuitFromConfig("RunMetricsExactlyOnce", Config{ + Execution: ExecutionConfig{ + MaxConcurrentRequests: -1, + Timeout: time.Hour, + }, + Metrics: MetricsCollectors{ + Run: []RunMetrics{rm}, + Circuit: []Metrics{tc}, + }, + }) + + const goroutines = 50 + const iterations = 2000 + ctx := context.Background() + + errFail := errors.New("fail") + var wg sync.WaitGroup + for g := 0; g < goroutines; g++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for i := 0; i < iterations; i++ { + switch (id + i) % 3 { + case 0: + _ = c.Execute(ctx, func(context.Context) error { return nil }, nil) + case 1: + _ = c.Execute(ctx, func(context.Context) error { return errFail }, nil) + case 2: + _ = c.Execute(ctx, func(context.Context) error { + return SimpleBadRequest{Err: errFail} + }, nil) + } + } + }(g) + } + wg.Wait() + + const want = int64(goroutines * iterations) + if got := rm.total(); got != want { + t.Errorf("RunMetrics exactly-once violated: total callbacks = %d, Execute calls = %d\n"+ + " success=%d failure=%d timeout=%d badReq=%d interrupt=%d concReject=%d shortCircuit=%d", + got, want, + rm.success.Get(), rm.errFailure.Get(), rm.errTimeout.Get(), + rm.errBadRequest.Get(), rm.errInterrupt.Get(), + rm.errConcurrencyLimitReject.Get(), rm.errShortCircuit.Get()) + } + + // Default circuit uses neverOpens/neverCloses — no transitions expected. 
+ if tc.opened.Get() != 0 || tc.closed.Get() != 0 { + t.Errorf("unexpected transitions with neverOpens/neverCloses: opened=%d closed=%d", + tc.opened.Get(), tc.closed.Get()) + } + + if cc := c.ConcurrentCommands(); cc != 0 { + t.Errorf("concurrentCommands counter unbalanced after all Execute returned: %d", cc) + } +} + +// TestCircuit_TransitionAlternation_Concurrent hammers OpenCircuit/CloseCircuit +// concurrently with Execute and verifies the Opened/Closed callbacks alternate +// correctly: at any point |opened - closed| ≤ 1, and opened ≥ closed (you can +// only close what was opened). This is the exact invariant bugs #4/#5 broke. +func TestCircuit_TransitionAlternation_Concurrent(t *testing.T) { + rm := &countingRunMetrics{} + tc := &transitionCounter{} + c := NewCircuitFromConfig("TransitionAlternation", Config{ + Execution: ExecutionConfig{ + MaxConcurrentRequests: -1, + Timeout: time.Hour, + }, + Metrics: MetricsCollectors{ + Run: []RunMetrics{rm}, + Circuit: []Metrics{tc}, + }, + }) + + ctx := context.Background() + var running faststats.AtomicBoolean + running.Set(true) + var wg sync.WaitGroup + + // Goroutines that bang on Execute + for g := 0; g < 8; g++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for running.Get() { + var runFn func(context.Context) error + if id%2 == 0 { + runFn = func(context.Context) error { return nil } + } else { + runFn = func(context.Context) error { return errors.New("x") } + } + _ = c.Execute(ctx, runFn, nil) + } + }(g) + } + + // Goroutines that toggle the circuit. They race with each other and with + // Execute on the CAS in openCircuit/close. 
+ for g := 0; g < 4; g++ { + wg.Add(1) + go func() { + defer wg.Done() + for running.Get() { + c.OpenCircuit(ctx) + c.CloseCircuit(ctx) + } + }() + } + + time.Sleep(50 * time.Millisecond) + running.Set(false) + wg.Wait() + + opened := tc.opened.Get() + closed := tc.closed.Get() + + // Circuit starts closed, so Opened must fire first; and each Closed must + // have a matching Opened before it. Hence: closed ≤ opened ≤ closed+1. + if closed > opened { + t.Errorf("transition ordering violated: closed=%d > opened=%d", closed, opened) + } + if opened > closed+1 { + t.Errorf("Opened() emitted without matching Closed(): opened=%d closed=%d (diff=%d, max diff is 1)", + opened, closed, opened-closed) + } + + if cc := c.ConcurrentCommands(); cc != 0 { + t.Errorf("concurrentCommands unbalanced: %d", cc) + } + + // Sanity: we expect many transitions (not 0) given 50ms of hammering; + // if 0 the test setup is wrong. + if opened == 0 { + t.Logf("warning: no transitions observed in 50ms — test may be ineffective on this machine") + } +} + +// TestCircuit_FallbackMetricsExactlyOnce_Concurrent verifies the FallbackMetrics +// contract under concurrency: exactly one fallback callback per fallback invocation. +func TestCircuit_FallbackMetricsExactlyOnce_Concurrent(t *testing.T) { + rm := &countingRunMetrics{} + fm := &countingFallbackMetrics{} + c := NewCircuitFromConfig("FallbackExactlyOnce", Config{ + Execution: ExecutionConfig{ + MaxConcurrentRequests: -1, + Timeout: time.Hour, + }, + Fallback: FallbackConfig{ + MaxConcurrentRequests: -1, + }, + Metrics: MetricsCollectors{ + Run: []RunMetrics{rm}, + Fallback: []FallbackMetrics{fm}, + }, + }) + + const goroutines = 30 + const iterations = 1000 + ctx := context.Background() + errFail := errors.New("fail") + + var wg sync.WaitGroup + for g := 0; g < goroutines; g++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for i := 0; i < iterations; i++ { + // Always fail runFunc to trigger fallback; vary fallback outcome. 
+ fallbackFails := (id+i)%3 == 0 + _ = c.Execute(ctx, + func(context.Context) error { return errFail }, + func(context.Context, error) error { + if fallbackFails { + return errFail + } + return nil + }) + } + }(g) + } + wg.Wait() + + const want = int64(goroutines * iterations) + + if got := rm.total(); got != want { + t.Errorf("RunMetrics total = %d, want %d", got, want) + } + + // Every Execute failed the runFunc (no bad requests, no short circuits + // since neverOpens). So every one should hit the fallback path and emit + // exactly one FallbackMetrics callback. + if got := fm.total(); got != want { + t.Errorf("FallbackMetrics exactly-once violated: total = %d, want %d\n"+ + " success=%d failure=%d concReject=%d", + got, want, fm.success.Get(), fm.errFailure.Get(), fm.errConcurrencyLimitReject.Get()) + } + + if cf := c.ConcurrentFallbacks(); cf != 0 { + t.Errorf("concurrentFallbacks unbalanced: %d", cf) + } +} + +// ============================================================================ +// Section 2: Deterministic state-machine property test +// Uses MockClock and a programmable opener/closer so the circuit is fully +// deterministic. Replays random operation sequences and checks invariants +// after every step. +// ============================================================================ + +// commandableOpener is a test ClosedToOpen whose ShouldOpen result is set +// externally — lets the property test drive state transitions deterministically. +type commandableOpener struct { + neverOpens + wantOpen faststats.AtomicBoolean +} + +func (o *commandableOpener) ShouldOpen(_ context.Context, _ time.Time) bool { + return o.wantOpen.Get() +} + +// commandableCloser mirrors commandableOpener for OpenToClosed. 
+type commandableCloser struct { + neverCloses + wantClose faststats.AtomicBoolean + allow faststats.AtomicBoolean +} + +func (c *commandableCloser) ShouldClose(_ context.Context, _ time.Time) bool { + return c.wantClose.Get() +} +func (c *commandableCloser) Allow(_ context.Context, _ time.Time) bool { + return c.allow.Get() +} + +// circuitOp is an operation we can perform against the circuit in the property +// test. Using a dense int8 enum plays well with testing/quick's shrinking. +type circuitOp int8 + +const ( + opExecSuccess circuitOp = iota + opExecFailure + opExecBadRequest + opExecFallbackSuccess + opExecFallbackFailure + opOpenCircuit + opCloseCircuit + opSetWantOpenTrue + opSetWantOpenFalse + opSetAllowTrue + opSetAllowFalse + opSetWantCloseTrue + opSetWantCloseFalse + opAdvanceClock + numOps +) + +// circuitModel is a parallel, trivially-correct model of the circuit state. +// The property test keeps the real circuit and this model in lockstep; any +// divergence is a bug. +type circuitModel struct { + isOpen bool + allow bool + wantOpen bool + wantClose bool + executeN int64 + fallbackN int64 + openedN int64 + closedN int64 +} + +func (m *circuitModel) apply(op circuitOp) { + switch op { + case opExecSuccess: + m.executeN++ + if m.isOpen && !m.allow { + return // short-circuited, run not attempted + } + // Success while open can close the circuit if wantClose. + if m.isOpen && m.wantClose { + m.isOpen = false + m.closedN++ + } + case opExecFailure: + m.executeN++ + if m.isOpen && !m.allow { + return + } + // Failure while closed can open the circuit if wantOpen. + if !m.isOpen && m.wantOpen { + m.isOpen = true + m.openedN++ + } + case opExecBadRequest: + m.executeN++ + // Bad requests never change circuit state and never hit fallback. + // They DO still get short-circuited if the circuit is open though. 
+ case opExecFallbackSuccess, opExecFallbackFailure: + m.executeN++ + if m.isOpen && !m.allow { + // Short-circuited: fallback IS called on short-circuit. + m.fallbackN++ + return + } + // runFunc fails → fallback fires. + m.fallbackN++ + // Failure may also open the circuit. + if !m.isOpen && m.wantOpen { + m.isOpen = true + m.openedN++ + } + case opOpenCircuit: + if !m.isOpen { + m.isOpen = true + m.openedN++ + } + case opCloseCircuit: + if m.isOpen { + m.isOpen = false + m.closedN++ + } + case opSetWantOpenTrue: + m.wantOpen = true + case opSetWantOpenFalse: + m.wantOpen = false + case opSetAllowTrue: + m.allow = true + case opSetAllowFalse: + m.allow = false + case opSetWantCloseTrue: + m.wantClose = true + case opSetWantCloseFalse: + m.wantClose = false + case opAdvanceClock: + // no state change in model + } +} + +// checkInvariants returns a non-empty error string if any invariant is +// violated. These are the properties that should hold at every step. +func checkInvariants( + rm *countingRunMetrics, fm *countingFallbackMetrics, tc *transitionCounter, + c *Circuit, m *circuitModel, step int, op circuitOp, +) string { + // 1. Exactly one RunMetrics callback per Execute. + if got, want := rm.total(), m.executeN; got != want { + return reportf(step, op, + "RunMetrics total=%d ≠ executeN=%d (success=%d fail=%d timeout=%d badReq=%d int=%d concRej=%d short=%d)", + got, want, + rm.success.Get(), rm.errFailure.Get(), rm.errTimeout.Get(), + rm.errBadRequest.Get(), rm.errInterrupt.Get(), + rm.errConcurrencyLimitReject.Get(), rm.errShortCircuit.Get()) + } + + // 2. At most one FallbackMetrics callback per fallback invocation. + if got, want := fm.total(), m.fallbackN; got != want { + return reportf(step, op, "FallbackMetrics total=%d ≠ model fallbackN=%d", got, want) + } + + // 3. Opened/Closed counts match the model exactly (deterministic). 
+ if tc.opened.Get() != m.openedN { + return reportf(step, op, "Opened count=%d ≠ model=%d", tc.opened.Get(), m.openedN) + } + if tc.closed.Get() != m.closedN { + return reportf(step, op, "Closed count=%d ≠ model=%d", tc.closed.Get(), m.closedN) + } + + // 4. Transition alternation: never more Closed than Opened, and the gap + // is at most 1. + if m.closedN > m.openedN { + return reportf(step, op, "alternation violated: closed=%d > opened=%d", m.closedN, m.openedN) + } + if m.openedN > m.closedN+1 { + return reportf(step, op, "alternation violated: opened=%d > closed+1=%d", m.openedN, m.closedN+1) + } + + // 5. IsOpen matches model. + if c.IsOpen() != m.isOpen { + return reportf(step, op, "IsOpen()=%v ≠ model.isOpen=%v", c.IsOpen(), m.isOpen) + } + + // 6. Concurrent counters return to 0 after serial execution. + if cc := c.ConcurrentCommands(); cc != 0 { + return reportf(step, op, "ConcurrentCommands=%d after serial Execute", cc) + } + if cf := c.ConcurrentFallbacks(); cf != 0 { + return reportf(step, op, "ConcurrentFallbacks=%d after serial Execute", cf) + } + + return "" +} + +func reportf(step int, op circuitOp, format string, args ...interface{}) string { + all := make([]interface{}, 0, len(args)+2) + all = append(all, step, opName(op)) + all = append(all, args...) + return fmt.Sprintf("[step %d op=%s] "+format, all...) +} + +func opName(op circuitOp) string { + names := [...]string{ + "ExecSuccess", "ExecFailure", "ExecBadRequest", + "ExecFallbackSuccess", "ExecFallbackFailure", + "OpenCircuit", "CloseCircuit", + "WantOpen=true", "WantOpen=false", + "Allow=true", "Allow=false", + "WantClose=true", "WantClose=false", + "AdvanceClock", + } + if int(op) >= 0 && int(op) < len(names) { + return names[op] + } + return "?" +} + +// replayOps builds a fresh circuit + model and replays the op sequence, +// checking invariants after every step. Returns (failure msg, false) on +// violation, ("", true) on success. 
+func replayOps(ops []circuitOp) (string, bool) { + mc := &clock.MockClock{} + mc.Set(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)) + + opener := &commandableOpener{} + closer := &commandableCloser{} + + rm := &countingRunMetrics{} + fm := &countingFallbackMetrics{} + tc := &transitionCounter{} + + c := NewCircuitFromConfig("prop", Config{ + General: GeneralConfig{ + TimeKeeper: TimeKeeper{ + Now: mc.Now, + AfterFunc: mc.AfterFunc, + }, + ClosedToOpenFactory: func() ClosedToOpen { return opener }, + OpenToClosedFactory: func() OpenToClosed { return closer }, + }, + Execution: ExecutionConfig{ + MaxConcurrentRequests: -1, + Timeout: time.Hour, + }, + Fallback: FallbackConfig{ + MaxConcurrentRequests: -1, + }, + Metrics: MetricsCollectors{ + Run: []RunMetrics{rm}, + Fallback: []FallbackMetrics{fm}, + Circuit: []Metrics{tc}, + }, + }) + + model := &circuitModel{} + ctx := context.Background() + errFail := errors.New("fail") + + for i, op := range ops { + switch op { + case opExecSuccess: + _ = c.Execute(ctx, func(context.Context) error { return nil }, nil) + case opExecFailure: + _ = c.Execute(ctx, func(context.Context) error { return errFail }, nil) + case opExecBadRequest: + _ = c.Execute(ctx, func(context.Context) error { + return SimpleBadRequest{Err: errFail} + }, nil) + case opExecFallbackSuccess: + _ = c.Execute(ctx, + func(context.Context) error { return errFail }, + func(context.Context, error) error { return nil }) + case opExecFallbackFailure: + _ = c.Execute(ctx, + func(context.Context) error { return errFail }, + func(context.Context, error) error { return errFail }) + case opOpenCircuit: + c.OpenCircuit(ctx) + case opCloseCircuit: + c.CloseCircuit(ctx) + case opSetWantOpenTrue: + opener.wantOpen.Set(true) + case opSetWantOpenFalse: + opener.wantOpen.Set(false) + case opSetAllowTrue: + closer.allow.Set(true) + case opSetAllowFalse: + closer.allow.Set(false) + case opSetWantCloseTrue: + closer.wantClose.Set(true) + case opSetWantCloseFalse: + 
closer.wantClose.Set(false) + case opAdvanceClock: + mc.Add(time.Millisecond) + } + + model.apply(op) + + if msg := checkInvariants(rm, fm, tc, c, model, i, op); msg != "" { + return msg, false + } + } + return "", true +} + +// TestCircuit_StateMachineProperty replays random sequences of operations +// against both the real Circuit and a trivial reference model, asserting they +// agree at every step. Fully deterministic (MockClock, single goroutine) — any +// failure is reproducible from the printed op list. +func TestCircuit_StateMachineProperty(t *testing.T) { + prop := func(raw []uint8) bool { + if len(raw) == 0 { + return true + } + ops := make([]circuitOp, len(raw)) + for i, b := range raw { + // b%numOps is 0..13, safely fits int8. + ops[i] = circuitOp(b % uint8(numOps)) //nolint:gosec // bounded modulo result + } + msg, ok := replayOps(ops) + if !ok { + t.Logf("FAIL: %s\nops (len=%d): %v", msg, len(ops), opNames(ops)) + } + return ok + } + + cfg := &quick.Config{ + MaxCount: 500, + Rand: rand.New(rand.NewSource(0x5EED)), + } + if err := quick.Check(prop, cfg); err != nil { + t.Fatal(err) + } +} + +// TestCircuit_StateMachineProperty_Handwritten runs a few curated sequences +// that encode past bug reproductions. Keeps them green without relying solely +// on random search to rediscover them. 
+func TestCircuit_StateMachineProperty_Handwritten(t *testing.T) { + cases := []struct { + name string + ops []circuitOp + }{ + { + name: "open-then-execute-shortcircuits", + ops: []circuitOp{opOpenCircuit, opExecSuccess, opExecFailure}, + }, + { + name: "wantOpen-failure-opens", + ops: []circuitOp{opSetWantOpenTrue, opExecFailure, opExecSuccess}, + }, + { + name: "half-open-success-closes", + ops: []circuitOp{ + opOpenCircuit, opSetAllowTrue, opSetWantCloseTrue, + opExecSuccess, opExecSuccess, + }, + }, + { + name: "repeated-OpenCircuit-is-idempotent", + ops: []circuitOp{ + opOpenCircuit, opOpenCircuit, opOpenCircuit, + opCloseCircuit, opCloseCircuit, + }, + }, + { + name: "fallback-fires-on-short-circuit", + ops: []circuitOp{ + opOpenCircuit, opExecFallbackSuccess, opExecFallbackFailure, + }, + }, + { + name: "bad-request-never-opens", + ops: []circuitOp{ + opSetWantOpenTrue, + opExecBadRequest, opExecBadRequest, opExecBadRequest, + }, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if msg, ok := replayOps(tc.ops); !ok { + t.Fatalf("%s\nops: %v", msg, opNames(tc.ops)) + } + }) + } +} + +func opNames(ops []circuitOp) []string { + names := make([]string, len(ops)) + for i, op := range ops { + names[i] = opName(op) + } + return names +} diff --git a/circuit_stress_test.go b/circuit_stress_test.go index 78b8a24..3f90d6e 100644 --- a/circuit_stress_test.go +++ b/circuit_stress_test.go @@ -104,10 +104,17 @@ func TestRaceOnConfigChange(t *testing.T) { wg.Wait() } -// TestCircuitStateTransitionRace tests for race conditions during circuit state transitions +// TestCircuitStateTransitionRace tests for race conditions during circuit +// state transitions. Previously this test had NO assertions (just t.Logf). +// Now it verifies the Opened/Closed alternation invariant under contention. 
func TestCircuitStateTransitionRace(t *testing.T) { + tc := &transitionCounter{} // Create a circuit that will open after 20 consecutive failures - c := NewCircuitFromConfig("state-transition-race", Config{}) + c := NewCircuitFromConfig("state-transition-race", Config{ + Metrics: MetricsCollectors{ + Circuit: []Metrics{tc}, + }, + }) var wg sync.WaitGroup goroutines := 100 @@ -160,6 +167,21 @@ func TestCircuitStateTransitionRace(t *testing.T) { t.Logf("Circuit open observed: %d, Circuit closed observed: %d", circuitOpenObserved, circuitClosedObserved) + + // Transition alternation: starting from closed, each Closed() must be + // preceded by an Opened(). Hence: closed ≤ opened ≤ closed+1. + opened, closed := tc.opened.Get(), tc.closed.Get() + if closed > opened { + t.Errorf("Closed()=%d > Opened()=%d — Closed emitted without matching Opened", closed, opened) + } + if opened > closed+1 { + t.Errorf("Opened()=%d > Closed()+1=%d — duplicate Opened() emission (TOCTOU regression)", + opened, closed+1) + } + + if cc := c.ConcurrentCommands(); cc != 0 { + t.Errorf("concurrentCommands counter unbalanced: %d", cc) + } } // TestContextCancellationStress tests how the circuit handles many context cancellations diff --git a/faststats/fuzz_test.go b/faststats/fuzz_test.go new file mode 100644 index 0000000..022de58 --- /dev/null +++ b/faststats/fuzz_test.go @@ -0,0 +1,429 @@ +package faststats + +import ( + "encoding/binary" + "encoding/json" + "math" + "sort" + "testing" + "time" +) + +// ============================================================================ +// Fuzz targets for faststats. These encode structural invariants as executable +// specifications — the kind of edge-case bugs (zero-values, empty JSON, +// nil-derefs, div-by-zero) that unit tests miss and fuzzing finds fast. 
+// +// Run with: go test -fuzz=FuzzRollingCounterOps -fuzztime=30s ./faststats +// ============================================================================ + +// naiveRollingCounter is a trivially-correct, lock-free-unfriendly reference +// implementation used as a differential oracle. It records every Inc time and +// answers queries by linear scan — obviously correct, obviously slow. +type naiveRollingCounter struct { + bucketWidth time.Duration + numBuckets int + startTime time.Time + incs []time.Time +} + +func newNaive(bucketWidth time.Duration, numBuckets int, startTime time.Time) *naiveRollingCounter { + return &naiveRollingCounter{ + bucketWidth: bucketWidth, + numBuckets: numBuckets, + startTime: startTime, + } +} + +func (n *naiveRollingCounter) Inc(now time.Time) { + // Mirror real RollingCounter semantics: ignore times before startTime. + if now.Before(n.startTime) { + return + } + n.incs = append(n.incs, now) +} + +// RollingSumAt returns events whose bucket is still within the window at 'now'. +// A past event at time t is in-window iff the current bucket index minus t's +// bucket index < numBuckets. This matches RollingBuckets.Advance semantics. +func (n *naiveRollingCounter) RollingSumAt(now time.Time) int64 { + if n.numBuckets == 0 || n.bucketWidth == 0 { + return 0 + } + if now.Before(n.startTime) { + return 0 + } + nowIdx := now.Sub(n.startTime).Nanoseconds() / n.bucketWidth.Nanoseconds() + var sum int64 + for _, t := range n.incs { + if t.After(now) { + // The real RollingCounter can retain future Incs if the window + // later advances past them. But for a monotone time sequence + // (which we use in the fuzz test) this never happens. 
+ continue + } + tIdx := t.Sub(n.startTime).Nanoseconds() / n.bucketWidth.Nanoseconds() + if nowIdx-tIdx < int64(n.numBuckets) { + sum++ + } + } + return sum +} + +func (n *naiveRollingCounter) TotalSum() int64 { + return int64(len(n.incs)) +} + +// FuzzRollingCounterOps drives RollingCounter with a fuzzed sequence of +// monotone-time Inc/RollingSumAt/GetBuckets ops and checks invariants: +// - TotalSum == number of Inc calls (the obvious conservation law) +// - RollingSumAt ≥ 0 always +// - RollingSumAt ≤ TotalSum always +// - sum(GetBuckets) == RollingSumAt (buckets and rolling sum agree) +// - RollingSumAt matches the naive oracle +// +// Time is monotone here because RollingCounter's behaviour on backward time is +// intentionally lossy (see rolling_bucket.go: events landing in buckets that +// have since been cleared are dropped). Fuzzing backward time would produce +// spurious oracle mismatches. +func FuzzRollingCounterOps(f *testing.F) { + f.Add([]byte{3, 1, 0, 1, 0, 2, 0, 1, 2}) + f.Add([]byte{10, 5, 0, 0, 0, 0, 2}) + f.Add([]byte{1, 100}) + f.Add([]byte{}) + + f.Fuzz(func(t *testing.T, data []byte) { + if len(data) < 2 { + return + } + // Derive config from first two bytes. Keep ranges small so the fuzzer + // explores op sequences rather than giant bucket arrays. + numBuckets := int(data[0])%16 + 1 // 1..16 + bucketWidth := time.Duration(int(data[1])%100+1) * time.Millisecond // 1..100ms + ops := data[2:] + + start := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + rc := NewRollingCounter(bucketWidth, numBuckets, start) + oracle := newNaive(bucketWidth, numBuckets, start) + + now := start + var incCount int64 + + for _, b := range ops { + op := b % 3 + // Always advance time by a small fuzz-derived amount. + // Monotone non-decreasing; may advance by 0. 
+ step := time.Duration(b>>2) * bucketWidth / 8 + now = now.Add(step) + + switch op { + case 0: // Inc + rc.Inc(now) + oracle.Inc(now) + incCount++ + case 1: // RollingSumAt + invariants + got := rc.RollingSumAt(now) + if got < 0 { + t.Fatalf("RollingSumAt < 0: %d (now=%s)", got, now.Sub(start)) + } + if got > rc.TotalSum() { + t.Fatalf("RollingSumAt=%d > TotalSum=%d", got, rc.TotalSum()) + } + want := oracle.RollingSumAt(now) + if got != want { + t.Fatalf("RollingSumAt mismatch: real=%d oracle=%d "+ + "(numBuckets=%d bucketWidth=%s now=+%s incs=%d)", + got, want, numBuckets, bucketWidth, now.Sub(start), incCount) + } + case 2: // GetBuckets + sum-consistency + buckets := rc.GetBuckets(now) + if len(buckets) != numBuckets { + t.Fatalf("GetBuckets len=%d, want %d", len(buckets), numBuckets) + } + var bsum int64 + for _, v := range buckets { + if v < 0 { + t.Fatalf("negative bucket value: %d", v) + } + bsum += v + } + rsum := rc.RollingSumAt(now) + if bsum != rsum { + t.Fatalf("sum(GetBuckets)=%d ≠ RollingSumAt=%d (buckets=%v)", bsum, rsum, buckets) + } + } + } + + // Final conservation check. + if rc.TotalSum() != incCount { + t.Fatalf("TotalSum=%d ≠ Inc count=%d", rc.TotalSum(), incCount) + } + if oracle.TotalSum() != incCount { + t.Fatalf("oracle TotalSum=%d ≠ Inc count=%d", oracle.TotalSum(), incCount) + } + }) +} + +// FuzzRollingCounterJSON verifies Marshal→Unmarshal round-trips and that +// arbitrary bytes fed to Unmarshal never panic (bug #8 in past fixes was a +// nil-deref on `{}`). The guard at rolling_counter.go:65 should make this safe. 
+func FuzzRollingCounterJSON(f *testing.F) { + f.Add([]byte(`{}`)) + f.Add([]byte(`{"TotalSum":5}`)) + f.Add([]byte(`{"Buckets":[1,2],"RollingSum":3,"TotalSum":3,"RollingBucket":{"NumBuckets":2,"StartTime":"2020-01-01T00:00:00Z","BucketWidth":1000000,"LastAbsIndex":0}}`)) + f.Add([]byte(`null`)) + f.Add([]byte(`[]`)) + f.Add([]byte(`"`)) + + f.Fuzz(func(t *testing.T, data []byte) { + // Part 1: arbitrary bytes must not panic UnmarshalJSON. On error the + // receiver is unmodified; on success the state is validated-consistent. + // Either way, subsequent method calls must be safe. + var sink RollingCounter + _ = sink.UnmarshalJSON(data) // error is fine; panic is not + now := time.Now() + _ = sink.GetBuckets(now) + _ = sink.RollingSumAt(now) + _ = sink.TotalSum() + _ = sink.String() + + // Part 2: seed a counter from the same data bytes, then verify + // Marshal→Unmarshal is a clean round trip. + if len(data) < 4 { + return + } + numBuckets := int(data[0])%8 + 1 + bucketWidth := time.Duration(int(data[1])%50+1) * time.Millisecond + start := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + orig := NewRollingCounter(bucketWidth, numBuckets, start) + for _, b := range data[2:] { + orig.Inc(start.Add(time.Duration(b) * time.Millisecond)) + } + + marshaled, err := json.Marshal(&orig) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + var restored RollingCounter + if err := json.Unmarshal(marshaled, &restored); err != nil { + t.Fatalf("Unmarshal of valid Marshal output: %v (json=%s)", err, marshaled) + } + if restored.TotalSum() != orig.TotalSum() { + t.Fatalf("round-trip TotalSum: got=%d want=%d", restored.TotalSum(), orig.TotalSum()) + } + // GetBuckets must not panic on restored value. 
+ _ = restored.GetBuckets(start) + }) +} + +// FuzzSortedDurationsPercentile checks the Percentile function: +// - Percentile(p) ∈ [Min, Max] for any p (including NaN, ±Inf, out-of-range) +// - Monotone non-decreasing in p +// - Never panics on any input +func FuzzSortedDurationsPercentile(f *testing.F) { + // Seeds: edge cases from past bug hunting. + seed := func(durs []uint64, p float64) []byte { + buf := make([]byte, 8+8*len(durs)) + binary.LittleEndian.PutUint64(buf, math.Float64bits(p)) + for i, d := range durs { + binary.LittleEndian.PutUint64(buf[8+8*i:], d) + } + return buf + } + f.Add(seed([]uint64{1, 2, 3}, 50.0)) + f.Add(seed([]uint64{100}, 0.0)) + f.Add(seed([]uint64{}, 99.9)) + f.Add(seed([]uint64{1, 2}, math.Inf(1))) + f.Add(seed([]uint64{1, 2}, math.Inf(-1))) + f.Add(seed([]uint64{5, 5, 5}, -1.0)) + + f.Fuzz(func(t *testing.T, data []byte) { + if len(data) < 8 { + return + } + p := math.Float64frombits(binary.LittleEndian.Uint64(data)) + + // Build a sorted duration list from remaining bytes. Keep values + // non-negative and bounded to avoid int64 overflow in Mean()'s sum. + raw := data[8:] + n := len(raw) / 8 + if n > 1000 { + n = 1000 + } + durs := make([]time.Duration, n) + const maxDur = uint64(time.Hour) + for i := 0; i < n; i++ { + v := binary.LittleEndian.Uint64(raw[8*i:]) % maxDur + // Bound to [0, 1h] so sum of 1000 durations stays well under int64 max. + durs[i] = time.Duration(int64(v)) //nolint:gosec // v < 2^62, fits int64 + } + sort.Slice(durs, func(i, j int) bool { return durs[i] < durs[j] }) + sd := SortedDurations(durs) + + // Any input, including NaN/Inf p, must not panic. + got := sd.Percentile(p) + mn := sd.Min() + mx := sd.Max() + _ = sd.Mean() + _ = sd.String() + + if len(sd) == 0 { + if got != -1 || mn != -1 || mx != -1 { + t.Fatalf("empty list should return -1, got pct=%d min=%d max=%d", got, mn, mx) + } + return + } + + // Bounds: Percentile(p) ∈ [Min, Max] for any non-NaN p; NaN returns -1. 
+ if math.IsNaN(p) { + if got != -1 { + t.Fatalf("Percentile(NaN)=%v, want -1", got) + } + return + } + if got < mn || got > mx { + t.Fatalf("Percentile(%g)=%v out of [Min=%v, Max=%v] (n=%d)", p, got, mn, mx, len(sd)) + } + + // Monotonicity: for any p1 ≤ p2, Percentile(p1) ≤ Percentile(p2). + // Sample a second percentile from the same fuzz input. + if len(data) >= 16 { + p2 := math.Float64frombits(binary.LittleEndian.Uint64(data[len(data)-8:])) + if !math.IsNaN(p2) && !math.IsInf(p2, 0) && !math.IsInf(p, 0) { + lo, hi := p, p2 + if lo > hi { + lo, hi = hi, lo + } + if sd.Percentile(lo) > sd.Percentile(hi) { + t.Fatalf("Percentile not monotone: P(%g)=%v > P(%g)=%v", + lo, sd.Percentile(lo), hi, sd.Percentile(hi)) + } + } + } + }) +} + +// TestRollingCounter_UnmarshalJSON_InconsistentState is a regression test for +// a panic found by FuzzRollingCounterJSON: hostile JSON with NumBuckets > 0 +// but Buckets == nil previously passed validation and caused GetBuckets to +// index out of range. Now rejected at unmarshal time. +func TestRollingCounter_UnmarshalJSON_InconsistentState(t *testing.T) { + var x RollingCounter + // Buckets omitted; NumBuckets=1 via RollingBucket. + hostile := []byte(`{"RollingSum":0,"TotalSum":0,"RollingBucket":{"NumBuckets":1,"StartTime":"2020-01-01T00:00:00Z","BucketWidth":1000000,"LastAbsIndex":0}}`) + if err := x.UnmarshalJSON(hostile); err == nil { + t.Fatal("expected error for inconsistent JSON (NumBuckets=1, Buckets=nil)") + } + // Receiver must be unmodified (zero-value) on error; methods must be safe. 
+	if b := x.GetBuckets(time.Now()); b != nil {
+		t.Errorf("GetBuckets after rejected unmarshal = %v, want nil", b)
+	}
+}
+
+// TestSortedDurations_Percentile_NaN is a regression test for a panic found
+// by FuzzSortedDurationsPercentile: Percentile(NaN) falls through both the
+// p <= 0 and p >= 100 guards (NaN comparisons are always false), then did
+// int(math.Floor(NaN)) which is platform-undefined — INT64_MIN on amd64,
+// causing index-out-of-range. Now returns -1.
+func TestSortedDurations_Percentile_NaN(t *testing.T) {
+	sd := SortedDurations{time.Millisecond, time.Millisecond * 2}
+	if got := sd.Percentile(math.NaN()); got != -1 {
+		t.Errorf("Percentile(NaN) = %v, want -1", got)
+	}
+}
+
+// FuzzRollingBucketAdvance checks the lock-free Advance loop at
+// rolling_bucket.go:28-74 — the hairiest code in the package. Invariants:
+//   - Returned index is either -1 or in [0, NumBuckets)
+//   - clearBucket is never called with an out-of-range index
+//   - LastAbsIndex is monotone non-decreasing across calls
+func FuzzRollingBucketAdvance(f *testing.F) {
+	f.Add([]byte{5, 10, 0, 1, 2, 3, 4, 5, 100, 200})
+	f.Add([]byte{1, 1, 0, 0, 0})
+	f.Add([]byte{3, 50, 255, 0, 255})
+
+	f.Fuzz(func(t *testing.T, data []byte) {
+		if len(data) < 2 {
+			return
+		}
+		numBuckets := int(data[0])%16 + 1
+		bucketWidth := time.Duration(int(data[1])%100+1) * time.Millisecond
+		timeBytes := data[2:]
+
+		start := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
+		rb := RollingBuckets{
+			NumBuckets:  numBuckets,
+			StartTime:   start,
+			BucketWidth: bucketWidth,
+		}
+
+		clearFn := func(idx int) {
+			if idx < 0 || idx >= numBuckets {
+				t.Fatalf("clearBucket called with out-of-range idx=%d (numBuckets=%d)", idx, numBuckets)
+			}
+		}
+
+		prevLastAbs := rb.LastAbsIndex.Get()
+		for _, b := range timeBytes {
+			// Time can jump forward by up to ~63 bucket widths, or backward
+			// by up to 1 bucket width ((int(b)-4)/4 bottoms out at -1). Both are realistic.
+ delta := time.Duration(int(b)-4) * bucketWidth / 4 + now := start.Add(delta) + // Allow start to creep forward too so we explore large absolute + // indices over many ops. + if delta > 0 { + start = start.Add(delta) + } + + idx := rb.Advance(now, clearFn) + + if idx != -1 && (idx < 0 || idx >= numBuckets) { + t.Fatalf("Advance returned out-of-range index: %d (numBuckets=%d)", idx, numBuckets) + } + + lastAbs := rb.LastAbsIndex.Get() + if lastAbs < prevLastAbs { + t.Fatalf("LastAbsIndex went backward: %d -> %d", prevLastAbs, lastAbs) + } + prevLastAbs = lastAbs + } + }) +} + +// FuzzTimedCheckJSON verifies Marshal→Unmarshal round-trip and that arbitrary +// bytes never panic Unmarshal (or subsequent method calls). +func FuzzTimedCheckJSON(f *testing.F) { + f.Add([]byte(`{}`)) + f.Add([]byte(`{"SleepDuration":1000000000,"EventCountToAllow":5}`)) + f.Add([]byte(`null`)) + f.Add([]byte(`{"NextOpenTime":"invalid"}`)) + + f.Fuzz(func(t *testing.T, data []byte) { + // Part 1: arbitrary bytes never panic. + var tc TimedCheck + _ = tc.UnmarshalJSON(data) + now := time.Now() + _ = tc.Check(now) + _ = tc.String() + + // Part 2: round-trip a configured TimedCheck. + if len(data) < 2 { + return + } + var orig TimedCheck + orig.SetSleepDuration(time.Duration(data[0]) * time.Millisecond) + orig.SetEventCountToAllow(int64(data[1])) + + marshaled, err := json.Marshal(&orig) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + var restored TimedCheck + if err := json.Unmarshal(marshaled, &restored); err != nil { + t.Fatalf("Unmarshal valid Marshal output: %v (json=%s)", err, marshaled) + } + // Can't directly compare because fields are unexported, but Check must + // not panic. 
+ _ = restored.Check(now) + }) +} diff --git a/faststats/rolling_counter.go b/faststats/rolling_counter.go index ca606d0..4590450 100644 --- a/faststats/rolling_counter.go +++ b/faststats/rolling_counter.go @@ -65,6 +65,10 @@ func (r *RollingCounter) UnmarshalJSON(b []byte) error { if into.RollingSum == nil || into.TotalSum == nil || into.RollingBucket == nil { return fmt.Errorf("RollingCounter.UnmarshalJSON: incomplete JSON (missing required fields)") } + if len(into.Buckets) != into.RollingBucket.NumBuckets { + return fmt.Errorf("RollingCounter.UnmarshalJSON: inconsistent JSON (Buckets len=%d, NumBuckets=%d)", + len(into.Buckets), into.RollingBucket.NumBuckets) + } r.buckets = into.Buckets r.rollingSum.Store(into.RollingSum.Get()) r.totalSum.Store(into.TotalSum.Get()) diff --git a/faststats/rolling_percentile.go b/faststats/rolling_percentile.go index d030d23..eb0efab 100644 --- a/faststats/rolling_percentile.go +++ b/faststats/rolling_percentile.go @@ -78,8 +78,8 @@ func (s SortedDurations) Var() expvar.Var { // Percentile returns a p [0 - 100] percentile of the list func (s SortedDurations) Percentile(p float64) time.Duration { - if len(s) == 0 { - // A meaningless value for a meaningless list + if len(s) == 0 || math.IsNaN(p) { + // A meaningless value for a meaningless list or meaningless percentile return -1 } if len(s) == 1 { diff --git a/faststats/rolling_stress_test.go b/faststats/rolling_stress_test.go index 53fa7ab..7d03833 100644 --- a/faststats/rolling_stress_test.go +++ b/faststats/rolling_stress_test.go @@ -54,11 +54,8 @@ func TestRollingCounterConcurrency(t *testing.T) { t.Logf("Successfully processed %d concurrent increments", totalIncrements) } -// TestRollingBucketConcurrency tests that RollingBuckets are thread-safe -func TestRollingBucketConcurrency(t *testing.T) { - // Skip this test since we don't have direct access to bucket functionality - t.Skip("RollingBucket implementation not directly accessible") -} +// TestRollingBucketConcurrency 
was permanently skipped; coverage provided by +// TestRollingBuckets_ConcurrentAdvance in rolling_bucket_test.go. // TestRollingPercentileConcurrency tests that RollingPercentile is thread-safe func TestRollingPercentileConcurrency(t *testing.T) { @@ -141,10 +138,60 @@ func TestRollingPercentileConcurrency(t *testing.T) { } } -// TestTimedCheckConcurrency tests that TimedCheck is thread-safe +// TestTimedCheckConcurrency tests that TimedCheck is thread-safe under +// concurrent Check + SleepStart + SetSleepDuration/SetEventCountToAllow. +// Previously skipped on the mistaken belief TimedCheck was not exported. +// This verifies no -race failures and that Check never returns true more +// than eventCountToAllow times per sleep window (the TimedCheck contract). func TestTimedCheckConcurrency(t *testing.T) { - // Skip since we don't have direct access to TimedCheck - t.Skip("TimedCheck not directly accessible for testing") + var x TimedCheck + x.SetSleepDuration(time.Millisecond) + x.SetEventCountToAllow(1) + x.SleepStart(time.Now()) + + var wg sync.WaitGroup + var running atomic.Bool + running.Store(true) + var checkTrueCount atomic.Int64 + + // Checkers hammer Check(). + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for running.Load() { + if x.Check(time.Now()) { + checkTrueCount.Add(1) + } + } + }() + } + + // Writer changes config concurrently. + wg.Add(1) + go func() { + defer wg.Done() + for running.Load() { + x.SetSleepDuration(time.Millisecond) + x.SetEventCountToAllow(1) + } + }() + + time.Sleep(time.Millisecond * 50) + running.Store(false) + wg.Wait() + + // With sleepDuration=1ms, eventCount=1, over 50ms we expect ~50 true + // returns. We don't assert an exact bound (timer jitter, test load) but + // the count should be plausible — not 0, and not tens of thousands + // (which would indicate the allow-once-per-window gate is broken). 
+ ct := checkTrueCount.Load() + if ct == 0 { + t.Logf("warning: Check never returned true — may indicate test ineffective on slow CI") + } + if ct > 500 { + t.Errorf("Check returned true %d times in ~50ms with 1ms sleep window — allow-once gate may be broken", ct) + } } // TestRollingCounterBucketRolloverRace tests for race conditions during bucket rollover @@ -180,7 +227,12 @@ func TestRollingCounterBucketRolloverRace(t *testing.T) { }() } - // Start more threads that read percentiles during rollover + // Start more threads that read during rollover and check invariants. + // NOTE: RollingSumAt can briefly return a negative value during concurrent + // rollover — the lock-free design (see rolling_bucket.go doc) accepts a + // transient window where clearBucket has decremented rollingSum for a + // bucket value that an in-flight Inc hasn't yet added to rollingSum. The + // value is correct once all writers quiesce; we check that post-wg.Wait. readThreads := 10 for g := 0; g < readThreads; g++ { wg.Add(1) @@ -190,14 +242,15 @@ func TestRollingCounterBucketRolloverRace(t *testing.T) { for atomic.LoadInt32(&running) == 1 { sum := counter.TotalSum() if sum < 0 { - t.Errorf("Counter sum went negative: %d", sum) + t.Errorf("TotalSum went negative: %d", sum) } - // Also check rolling sum now := time.Now() rollingSum := counter.RollingSumAt(now) - if rollingSum < 0 { - t.Errorf("Rolling sum went negative: %d", rollingSum) + // Rolling sum can never exceed total sum (both monotone from + // this observation point: rollingSum read, THEN totalSum read). + if ts := counter.TotalSum(); rollingSum > ts { + t.Errorf("RollingSum=%d > TotalSum=%d", rollingSum, ts) } } }() @@ -209,5 +262,15 @@ func TestRollingCounterBucketRolloverRace(t *testing.T) { wg.Wait() + // At quiescence (all writers stopped) the transient-negative window is + // closed and rollingSum must be non-negative and consistent. 
+ finalRoll := counter.RollingSumAt(time.Now()) + if finalRoll < 0 { + t.Errorf("RollingSumAt negative after quiescence: %d", finalRoll) + } + if finalTS := counter.TotalSum(); finalRoll > finalTS { + t.Errorf("RollingSumAt=%d > TotalSum=%d after quiescence", finalRoll, finalTS) + } + t.Logf("Added %d items during high-frequency rollover test", totalAdded) }