Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- Added new `HookQueueStateCount` hook which is run by a River leader to generate queue count statistics. [PR #1203](https://github.com/riverqueue/river/pull/1203).
- Middleware that implements `rivertype.Hook` can be looked up as hooks even if passed into `Config.Middleware`. Similarly, hooks that implement `rivertype.Middleware` can be looked up as middleware even if passed into `Config.Hooks`. [PR #1203](https://github.com/riverqueue/river/pull/1203).

## [0.34.0] - 2026-04-08

### Added
Expand Down
55 changes: 49 additions & 6 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"log/slog"
"os"
"regexp"
"slices"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -657,6 +658,7 @@ type clientTestSignals struct {
periodicJobEnqueuer *maintenance.PeriodicJobEnqueuerTestSignals
queueCleaner *maintenance.QueueCleanerTestSignals
queueMaintainerLeader *maintenance.QueueMaintainerLeaderTestSignals
queueStateCounter *maintenance.QueueStateCounterTestSignals
reindexer *maintenance.ReindexerTestSignals
}

Expand All @@ -679,6 +681,9 @@ func (ts *clientTestSignals) Init(tb testutil.TestingTB) {
if ts.queueMaintainerLeader != nil {
ts.queueMaintainerLeader.Init(tb)
}
if ts.queueStateCounter != nil {
ts.queueStateCounter.Init(tb)
}
if ts.reindexer != nil {
ts.reindexer.Init(tb)
}
Expand Down Expand Up @@ -759,7 +764,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
config: config,
driver: driver,
hookLookupByJob: hooklookup.NewJobHookLookup(),
hookLookupGlobal: hooklookup.NewHookLookup(config.Hooks),
hookLookupGlobal: nil, // initialized below after cross-referencing with middleware
producersByQueueName: make(map[string]*producer),
testSignals: clientTestSignals{},
workCancel: func(cause error) {}, // replaced on start, but here in case StopAndCancel is called before start up
Expand All @@ -780,9 +785,9 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
// the more abstract config.Middleware for middleware are set, but not both,
// so in practice we never append all three of these to each other.
{
middleware := config.Middleware
middlewares := config.Middleware
for _, jobInsertMiddleware := range config.JobInsertMiddleware {
middleware = append(middleware, jobInsertMiddleware)
middlewares = append(middlewares, jobInsertMiddleware)
}
outerLoop:
for _, workerMiddleware := range config.WorkerMiddleware {
Expand All @@ -798,16 +803,44 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
}
}

middleware = append(middleware, workerMiddleware)
middlewares = append(middlewares, workerMiddleware)
}

for _, middleware := range middleware {
for _, middleware := range middlewares {
if withBaseService, ok := middleware.(baseservice.WithBaseService); ok {
baseservice.Init(archetype, withBaseService)
}
}

client.middlewareLookupGlobal = middlewarelookup.NewMiddlewareLookup(middleware)
// Cross-reference hooks and middleware: any middleware that also
// implements Hook is added to hooks, and any hook that also implements
// Middleware is added to middleware. Deduplication prevents double
// registration when the same struct is passed to both Config.Hooks and
// Config.Middleware.
hooks := config.Hooks

for _, middleware := range middlewares {
if hook, ok := middleware.(rivertype.Hook); ok {
// Only add if this middleware isn't already in hooks (it may
// have been passed to both config properties).
alreadyInHooks := slices.Contains(hooks, hook)
if !alreadyInHooks {
hooks = append(hooks, hook)
}
}
}

for _, hook := range config.Hooks {
if middleware, ok := hook.(rivertype.Middleware); ok {
alreadyInMiddleware := slices.Contains(middlewares, middleware)
if !alreadyInMiddleware {
middlewares = append(middlewares, middleware)
}
}
}

client.hookLookupGlobal = hooklookup.NewHookLookup(hooks)
client.middlewareLookupGlobal = middlewarelookup.NewMiddlewareLookup(middlewares)
}

pluginDriver, _ := driver.(driverPlugin[TTx])
Expand Down Expand Up @@ -961,6 +994,16 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
client.testSignals.queueCleaner = &queueCleaner.TestSignals
}

{
queueStateCounter := maintenance.NewQueueStateCounter(archetype, &maintenance.QueueStateCounterConfig{
HookLookupGlobal: client.hookLookupGlobal,
QueueNames: maputil.Keys(config.Queues),
Schema: config.Schema,
}, driver.GetExecutor())
maintenanceServices = append(maintenanceServices, queueStateCounter)
client.testSignals.queueStateCounter = &queueStateCounter.TestSignals
}

{
var scheduleFunc func(time.Time) time.Time
if config.ReindexerSchedule != nil {
Expand Down
69 changes: 69 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/tidwall/sjson"

"github.com/riverqueue/river/internal/dbunique"
"github.com/riverqueue/river/internal/hooklookup"
"github.com/riverqueue/river/internal/jobexecutor"
"github.com/riverqueue/river/internal/maintenance"
"github.com/riverqueue/river/internal/middlewarelookup"
Expand Down Expand Up @@ -7626,6 +7627,46 @@ func Test_NewClient_Validations(t *testing.T) {
require.Len(t, client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindWorker), 1)
},
},
{
name: "Middleware implementing Hook is available in hook lookup",
configFunc: func(config *Config) {
config.Middleware = []rivertype.Middleware{&middlewareWithHook{}}
},
validateResult: func(t *testing.T, client *Client[pgx.Tx]) { //nolint:thelper
require.Len(t, client.hookLookupGlobal.ByHookKind(hooklookup.HookKindWorkBegin), 1)
},
},
{
name: "Hook implementing Middleware is available in middleware lookup",
configFunc: func(config *Config) {
config.Hooks = []rivertype.Hook{&hookWithMiddleware{}}
},
validateResult: func(t *testing.T, client *Client[pgx.Tx]) { //nolint:thelper
require.Len(t, client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindWorker), 1)
},
},
{
name: "Hook implementing Middleware in both configs is deduplicated in middleware lookup",
configFunc: func(config *Config) {
hm := &hookWithMiddleware{}
config.Hooks = []rivertype.Hook{hm}
config.Middleware = []rivertype.Middleware{hm}
},
validateResult: func(t *testing.T, client *Client[pgx.Tx]) { //nolint:thelper
require.Len(t, client.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindWorker), 1)
},
},
{
name: "Middleware implementing Hook in both configs is deduplicated in hook lookup",
configFunc: func(config *Config) {
mh := &middlewareWithHook{}
config.Hooks = []rivertype.Hook{mh}
config.Middleware = []rivertype.Middleware{mh}
},
validateResult: func(t *testing.T, client *Client[pgx.Tx]) { //nolint:thelper
require.Len(t, client.hookLookupGlobal.ByHookKind(hooklookup.HookKindWorkBegin), 1)
},
},
{
name: "Middleware not allowed with JobInsertMiddleware",
configFunc: func(config *Config) {
Expand Down Expand Up @@ -8625,3 +8666,31 @@ func (f JobArgsWithHooksFunc) Hooks() []rivertype.Hook {
func (JobArgsWithHooksFunc) MarshalJSON() ([]byte, error) { return []byte("{}"), nil }

func (JobArgsWithHooksFunc) UnmarshalJSON([]byte) error { return nil }

// middlewareWithHook is a middleware that also implements HookWorkBegin,
// used to test cross-pollination from middleware to hooks.
type middlewareWithHook struct {
MiddlewareDefaults
}

func (m *middlewareWithHook) Work(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error {
return doInner(ctx)
}

func (m *middlewareWithHook) IsHook() bool { return true }

func (m *middlewareWithHook) WorkBegin(ctx context.Context, job *rivertype.JobRow) error {
return nil
}

// hookWithMiddleware is a hook that also implements WorkerMiddleware,
// used to test cross-pollination from hooks to middleware.
type hookWithMiddleware struct {
HookDefaults
}

func (h *hookWithMiddleware) IsMiddleware() bool { return true }

func (h *hookWithMiddleware) Work(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error {
return doInner(ctx)
}
10 changes: 10 additions & 0 deletions hook_defaults_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ func (f HookPeriodicJobsStartFunc) Start(ctx context.Context, params *rivertype.
return f(ctx, params)
}

// HookQueueStateCountFunc is a convenience helper for implementing
// rivertype.HookQueueStateCount using a simple function instead of a struct.
type HookQueueStateCountFunc func(ctx context.Context, params *rivertype.HookQueueStateCountParams)

func (f HookQueueStateCountFunc) IsHook() bool { return true }

func (f HookQueueStateCountFunc) QueueStateCount(ctx context.Context, params *rivertype.HookQueueStateCountParams) {
f(ctx, params)
}

// HookWorkBeginFunc is a convenience helper for implementing
// rivertype.HookWorkBegin using a simple function instead of a struct.
type HookWorkBeginFunc func(ctx context.Context, job *rivertype.JobRow) error
Expand Down
7 changes: 7 additions & 0 deletions internal/hooklookup/hook_lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type HookKind string
const (
HookKindInsertBegin HookKind = "insert_begin"
HookKindPeriodicJobsStart HookKind = "periodic_job_start"
HookKindQueueStateCount HookKind = "queue_state_count"
HookKindWorkBegin HookKind = "work_begin"
HookKindWorkEnd HookKind = "work_end"
)
Expand Down Expand Up @@ -90,6 +91,12 @@ func (c *hookLookup) ByHookKind(kind HookKind) []rivertype.Hook {
c.hooksByKind[kind] = append(c.hooksByKind[kind], typedHook)
}
}
case HookKindQueueStateCount:
for _, hook := range c.hooks {
if typedHook, ok := hook.(rivertype.HookQueueStateCount); ok {
c.hooksByKind[kind] = append(c.hooksByKind[kind], typedHook)
}
}
case HookKindWorkBegin:
for _, hook := range c.hooks {
if typedHook, ok := hook.(rivertype.HookWorkBegin); ok {
Expand Down
Loading
Loading