From c5f65b47a5394cb13d8e6c3aee2fafe245da34dd Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Mon, 2 Feb 2026 18:16:41 +0100 Subject: [PATCH 01/19] fix: page respects paginator max size --- page.go | 73 +++++++++++++++++++++++++++++----------------------- page_test.go | 10 +++---- 2 files changed, 46 insertions(+), 37 deletions(-) diff --git a/page.go b/page.go index 86f9efc..da1aaae 100644 --- a/page.go +++ b/page.go @@ -123,7 +123,7 @@ func (p *Page) GetOrder(defaultSort ...string) []Sort { return sort } -func (p *Page) Offset() uint64 { +func (p *Page) Offset(o *PaginatorOptions) uint64 { n := uint64(1) if p != nil && p.Page != 0 { n = uint64(p.Page) @@ -131,11 +131,14 @@ func (p *Page) Offset() uint64 { if n < 1 { n = 1 } - return (n - 1) * p.Limit() + return (n - 1) * p.Limit(o) } -func (p *Page) Limit() uint64 { - var n = uint64(DefaultPageSize) +func (p *Page) Limit(o *PaginatorOptions) uint64 { + n := uint64(DefaultPageSize) + if o != nil && o.DefaultSize != 0 { + n = uint64(o.DefaultSize) + } if p != nil && p.Size != 0 { n = uint64(p.Size) } @@ -146,55 +149,61 @@ func (p *Page) Limit() uint64 { } // PaginatorOption is a function that sets an option on a paginator. -type PaginatorOption[T any] func(*Paginator[T]) +type PaginatorOption func(*PaginatorOptions) // WithDefaultSize sets the default page size. -func WithDefaultSize[T any](size uint32) PaginatorOption[T] { - return func(p *Paginator[T]) { p.defaultSize = size } +func WithDefaultSize(size uint32) PaginatorOption { + return func(p *PaginatorOptions) { p.DefaultSize = size } } // WithMaxSize sets the maximum page size. -func WithMaxSize[T any](size uint32) PaginatorOption[T] { - return func(p *Paginator[T]) { p.maxSize = size } +func WithMaxSize(size uint32) PaginatorOption { + return func(p *PaginatorOptions) { p.MaxSize = size } } // WithSort sets the default sort order. -func WithSort[T any](sort ...string) PaginatorOption[T] { - return func(p *Paginator[T]) { p.defaultSort = sort } +func WithSort(sort ...string) PaginatorOption { + return func(p *PaginatorOptions) { p.Sort = sort } } // WithColumnFunc sets a function to transform column names. -func WithColumnFunc[T any](f func(string) string) PaginatorOption[T] { - return func(p *Paginator[T]) { p.columnFunc = f } +func WithColumnFunc(f func(string) string) PaginatorOption { + return func(p *PaginatorOptions) { p.ColumnFunc = f } } // NewPaginator creates a new paginator with the given options. // Default page size is 10 and max size is 50. -func NewPaginator[T any](options ...PaginatorOption[T]) Paginator[T] { +func NewPaginator[T any](options ...PaginatorOption) Paginator[T] { p := Paginator[T]{ - defaultSize: DefaultPageSize, - maxSize: MaxPageSize, + PaginatorOptions: PaginatorOptions{ + DefaultSize: DefaultPageSize, + MaxSize: MaxPageSize, + }, } for _, opt := range options { - opt(&p) + opt(&p.PaginatorOptions) } return p } +type PaginatorOptions struct { + DefaultSize uint32 + MaxSize uint32 + Sort []string + ColumnFunc func(string) string +} + // Paginator is a helper to paginate results. type Paginator[T any] struct { - defaultSize uint32 - maxSize uint32 - defaultSort []string - columnFunc func(string) string + PaginatorOptions } func (p Paginator[T]) getOrder(page *Page) []string { - sort := page.GetOrder(p.defaultSort...) + sort := page.GetOrder(p.Sort...) list := make([]string, len(sort)) for i, s := range sort { - if p.columnFunc != nil { - s.Column = p.columnFunc(s.Column) + if p.ColumnFunc != nil { + s.Column = p.ColumnFunc(s.Column) } list[i] = s.String() } @@ -205,19 +214,19 @@ func (p Paginator[T]) getOrder(page *Page) []string { func (p Paginator[T]) PrepareQuery(q sq.SelectBuilder, page *Page) ([]T, sq.SelectBuilder) { if page != nil { if page.Size == 0 { - page.Size = p.defaultSize + page.Size = p.DefaultSize } - if page.Size > p.maxSize { - page.Size = p.maxSize + if page.Size > p.MaxSize { + page.Size = p.MaxSize } } - limit := page.Limit() - q = q.Limit(page.Limit() + 1).Offset(page.Offset()).OrderBy(p.getOrder(page)...) + limit := page.Limit(&p.PaginatorOptions) + q = q.Limit(page.Limit(&p.PaginatorOptions) + 1).Offset(page.Offset(&p.PaginatorOptions)).OrderBy(p.getOrder(page)...) return make([]T, 0, limit+1), q } func (p Paginator[T]) PrepareRaw(q string, args []any, page *Page) ([]T, string, []any) { - limit, offset := page.Limit(), page.Offset() + limit, offset := page.Limit(&p.PaginatorOptions), page.Offset(&p.PaginatorOptions) q = q + " ORDER BY " + strings.Join(p.getOrder(page), ", ") q = q + " LIMIT @limit OFFSET @offset" @@ -240,13 +249,13 @@ func (p Paginator[T]) PrepareRaw(q string, args []any, page *Page) ([]T, string, // - it removes the last element, returning n elements // - it sets more to true in the page object func (p Paginator[T]) PrepareResult(result []T, page *Page) []T { - limit := int(page.Limit()) + limit := int(page.Limit(&p.PaginatorOptions)) page.More = len(result) > limit if page.More { result = result[:limit] } page.Size = uint32(limit) - page.Page = 1 + uint32(page.Offset())/uint32(limit) + page.Page = 1 + uint32(page.Offset(&p.PaginatorOptions))/uint32(limit) return result } diff --git a/page_test.go b/page_test.go index bdad83f..8df13d6 100644 --- a/page_test.go +++ b/page_test.go @@ -17,11 +17,11 @@ func TestPagination(t *testing.T) { MaxSize = 5 Sort = "ID" ) - paginator := pgkit.NewPaginator( - pgkit.WithColumnFunc[T](strings.ToLower), - pgkit.WithDefaultSize[T](DefaultSize), - pgkit.WithMaxSize[T](MaxSize), - pgkit.WithSort[T](Sort), + paginator := pgkit.NewPaginator[T]( + pgkit.WithColumnFunc(strings.ToLower), + pgkit.WithDefaultSize(DefaultSize), + pgkit.WithMaxSize(MaxSize), + pgkit.WithSort(Sort), ) page := pgkit.NewPage(0, 0) result, query := paginator.PrepareQuery(sq.Select("*").From("t"), page) From 3eb43d54fd5506a64ef0a121d5cf41f776d133a3 Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Tue, 3 Feb 2026 09:28:02 +0100 Subject: [PATCH 02/19] refactor: simplify paginator options handling and improve pagination tests --- page.go | 61 ++++++++++++++++++++-------------------------------- page_test.go | 42 ++++++++++++++++-------------------- 2 files changed, 42 insertions(+), 61 deletions(-) diff --git a/page.go b/page.go index da1aaae..f052d54 100644 --- a/page.go +++ b/page.go @@ -135,54 +135,24 @@ func (p *Page) Offset(o *PaginatorOptions) uint64 { } func (p *Page) Limit(o *PaginatorOptions) uint64 { - n := uint64(DefaultPageSize) - if o != nil && o.DefaultSize != 0 { - n = uint64(o.DefaultSize) - } + n, maxSize := o.getDefaults() if p != nil && p.Size != 0 { n = uint64(p.Size) } - if n > MaxPageSize { - n = MaxPageSize + if n > uint64(maxSize) { + n = maxSize } return n } -// PaginatorOption is a function that sets an option on a paginator. -type PaginatorOption func(*PaginatorOptions) - -// WithDefaultSize sets the default page size. -func WithDefaultSize(size uint32) PaginatorOption { - return func(p *PaginatorOptions) { p.DefaultSize = size } -} - -// WithMaxSize sets the maximum page size. -func WithMaxSize(size uint32) PaginatorOption { - return func(p *PaginatorOptions) { p.MaxSize = size } -} - -// WithSort sets the default sort order. -func WithSort(sort ...string) PaginatorOption { - return func(p *PaginatorOptions) { p.Sort = sort } -} - -// WithColumnFunc sets a function to transform column names. -func WithColumnFunc(f func(string) string) PaginatorOption { - return func(p *PaginatorOptions) { p.ColumnFunc = f } -} - // NewPaginator creates a new paginator with the given options. // Default page size is 10 and max size is 50. -func NewPaginator[T any](options ...PaginatorOption) Paginator[T] { - p := Paginator[T]{ - PaginatorOptions: PaginatorOptions{ - DefaultSize: DefaultPageSize, - MaxSize: MaxPageSize, - }, - } - for _, opt := range options { - opt(&p.PaginatorOptions) +func NewPaginator[T any](options *PaginatorOptions) Paginator[T] { + p := Paginator[T]{} + if options == nil { + return p } + p.PaginatorOptions = *options return p } @@ -193,6 +163,21 @@ type PaginatorOptions struct { ColumnFunc func(string) string } +func (o *PaginatorOptions) getDefaults() (defaultSize, maxSize uint64) { + defaultSize = DefaultPageSize + maxSize = MaxPageSize + if o == nil { + return + } + if o.DefaultSize != 0 { + defaultSize = uint64(o.DefaultSize) + } + if o.MaxSize != 0 { + maxSize = uint64(o.MaxSize) + } + return +} + // Paginator is a helper to paginate results. type Paginator[T any] struct { PaginatorOptions diff --git a/page_test.go b/page_test.go index 8df13d6..077270e 100644 --- a/page_test.go +++ b/page_test.go @@ -12,21 +12,17 @@ import ( type T struct{} func TestPagination(t *testing.T) { - const ( - DefaultSize = 2 - MaxSize = 5 - Sort = "ID" - ) - paginator := pgkit.NewPaginator[T]( - pgkit.WithColumnFunc(strings.ToLower), - pgkit.WithDefaultSize(DefaultSize), - pgkit.WithMaxSize(MaxSize), - pgkit.WithSort(Sort), - ) + o := &pgkit.PaginatorOptions{ + ColumnFunc: strings.ToLower, + DefaultSize: 2, + MaxSize: 5, + Sort: []string{"ID"}, + } + paginator := pgkit.NewPaginator[T](o) page := pgkit.NewPage(0, 0) result, query := paginator.PrepareQuery(sq.Select("*").From("t"), page) require.Len(t, result, 0) - require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize}, page) + require.Equal(t, &pgkit.Page{Page: 1, Size: o.MaxSize}, page) sql, args, err := query.ToSql() require.NoError(t, err) @@ -35,19 +31,19 @@ func TestPagination(t *testing.T) { result = paginator.PrepareResult(make([]T, 0), page) require.Len(t, result, 0) - require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize}, page) + require.Equal(t, &pgkit.Page{Page: 1, Size: o.MaxSize}, page) - result = paginator.PrepareResult(make([]T, MaxSize), page) - require.Len(t, result, MaxSize) - require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize}, page) + result = paginator.PrepareResult(make([]T, o.MaxSize), page) + require.Len(t, result, int(o.MaxSize)) + require.Equal(t, &pgkit.Page{Page: 1, Size: o.MaxSize}, page) - result = paginator.PrepareResult(make([]T, MaxSize+2), page) - require.Len(t, result, MaxSize) - require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize, More: true}, page) + result = paginator.PrepareResult(make([]T, o.MaxSize+2), page) + require.Len(t, result, int(o.MaxSize)) + require.Equal(t, &pgkit.Page{Page: 1, Size: o.MaxSize, More: true}, page) } func TestInvalidSort(t *testing.T) { - paginator := pgkit.NewPaginator[T]() + paginator := pgkit.NewPaginator[T](nil) page := pgkit.NewPage(0, 0) page.Sort = []pgkit.Sort{ {Column: "ID; DROP TABLE users;", Order: pgkit.Asc}, @@ -63,7 +59,7 @@ func TestInvalidSort(t *testing.T) { } func TestPageColumnInjection(t *testing.T) { - paginator := pgkit.NewPaginator[T]() + paginator := pgkit.NewPaginator[T](nil) page := pgkit.NewPage(0, 0) page.Column = "id; DROP TABLE users;--" @@ -76,7 +72,7 @@ func TestPageColumnInjection(t *testing.T) { } func TestPageColumnSpaces(t *testing.T) { - paginator := pgkit.NewPaginator[T]() + paginator := pgkit.NewPaginator[T](nil) page := pgkit.NewPage(0, 0) page.Column = "id, name" @@ -89,7 +85,7 @@ func TestPageColumnSpaces(t *testing.T) { } func TestSortOrderInjection(t *testing.T) { - paginator := pgkit.NewPaginator[T]() + paginator := pgkit.NewPaginator[T](nil) page := pgkit.NewPage(0, 0) page.Sort = []pgkit.Sort{ {Column: "id", Order: pgkit.Order("DESC; DROP TABLE users;--")}, From d74d8726082879076d93e0be47e67b0defd9c76a Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Wed, 4 Feb 2026 14:24:26 +0100 Subject: [PATCH 03/19] fix: ensure default size does not exceed max size in paginator options --- page.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/page.go b/page.go index f052d54..a4d4d19 100644 --- a/page.go +++ b/page.go @@ -175,7 +175,7 @@ func (o *PaginatorOptions) getDefaults() (defaultSize, maxSize uint64) { if o.MaxSize != 0 { maxSize = uint64(o.MaxSize) } - return + return min(defaultSize, maxSize), maxSize } // Paginator is a helper to paginate results. From e62f09b900d0468319d8a421207f08f9a53aeb82 Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Wed, 4 Feb 2026 15:33:09 +0100 Subject: [PATCH 04/19] fix: handle default and max sizes in paginator options and improve pagination edge case tests --- page.go | 27 +++++++++++++++++++-------- page_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 8 deletions(-) diff --git a/page.go b/page.go index a4d4d19..144b3f9 100644 --- a/page.go +++ b/page.go @@ -150,7 +150,13 @@ func (p *Page) Limit(o *PaginatorOptions) uint64 { func NewPaginator[T any](options *PaginatorOptions) Paginator[T] { p := Paginator[T]{} if options == nil { - return p + options = &PaginatorOptions{} + } + if options.DefaultSize == 0 { + options.DefaultSize = DefaultPageSize + } + if options.MaxSize == 0 { + options.MaxSize = MaxPageSize } p.PaginatorOptions = *options return p @@ -197,14 +203,19 @@ func (p Paginator[T]) getOrder(page *Page) []string { // PrepareQuery adds pagination to the query. It sets the number of max rows to limit+1. func (p Paginator[T]) PrepareQuery(q sq.SelectBuilder, page *Page) ([]T, sq.SelectBuilder) { - if page != nil { - if page.Size == 0 { - page.Size = p.DefaultSize - } - if page.Size > p.MaxSize { - page.Size = p.MaxSize - } + if page == nil { + page = &Page{} } + if page.Page == 0 { + page.Page = 1 + } + if page.Size == 0 { + page.Size = p.DefaultSize + } + if p.MaxSize != 0 && page.Size > p.MaxSize { + page.Size = p.MaxSize + } + limit := page.Limit(&p.PaginatorOptions) q = q.Limit(page.Limit(&p.PaginatorOptions) + 1).Offset(page.Offset(&p.PaginatorOptions)).OrderBy(p.getOrder(page)...) return make([]T, 0, limit+1), q diff --git a/page_test.go b/page_test.go index 077270e..6762a3d 100644 --- a/page_test.go +++ b/page_test.go @@ -100,3 +100,54 @@ func TestSortOrderInjection(t *testing.T) { require.Equal(t, "SELECT * FROM t ORDER BY \"id\" ASC, \"name\" DESC, \"created_at\" ASC LIMIT 11 OFFSET 0", sql) require.Empty(t, args) } + +func TestPaginationEdgeCases(t *testing.T) { + // Test case 1: nil options, NewPage with zeros + paginator1 := pgkit.NewPaginator[T](nil) + page1 := pgkit.NewPage(0, 0) + result1, query1 := paginator1.PrepareQuery(sq.Select("*").From("t"), page1) + require.Len(t, result1, 0) + require.Equal(t, &pgkit.Page{Page: 1, Size: 10}, page1) + + sql1, _, err1 := query1.ToSql() + require.NoError(t, err1) + require.Equal(t, "SELECT * FROM t LIMIT 11 OFFSET 0", sql1) + + // Test case 2: nil options, empty struct assignment + paginator2 := pgkit.NewPaginator[T](nil) + page2 := &pgkit.Page{} + result2, query2 := paginator2.PrepareQuery(sq.Select("*").From("t"), page2) + require.Len(t, result2, 0) + require.Equal(t, &pgkit.Page{Page: 1, Size: 10}, page2) + + sql2, _, err2 := query2.ToSql() + require.NoError(t, err2) + require.Equal(t, "SELECT * FROM t LIMIT 11 OFFSET 0", sql2) + + // Test case 3: empty options, NewPage + paginator3 := pgkit.NewPaginator[T](&pgkit.PaginatorOptions{}) + page3 := pgkit.NewPage(0, 0) + result3, query3 := paginator3.PrepareQuery(sq.Select("*").From("t"), page3) + require.Len(t, result3, 0) + require.Equal(t, &pgkit.Page{Page: 1, Size: 10}, page3) + + sql3, _, err3 := query3.ToSql() + require.NoError(t, err3) + require.Equal(t, "SELECT * FROM t LIMIT 11 OFFSET 0", sql3) + + // Test case 4: options with defaults, struct assignment + paginator4 := pgkit.Paginator[T]{ + pgkit.PaginatorOptions{ + DefaultSize: 5, + MaxSize: 20, + }, + } + page4 := &pgkit.Page{Page: 0, Size: 0} + result4, query4 := paginator4.PrepareQuery(sq.Select("*").From("t"), page4) + require.Len(t, result4, 0) + require.Equal(t, &pgkit.Page{Page: 1, Size: 5}, page4) + + sql4, _, err4 := query4.ToSql() + require.NoError(t, err4) + require.Equal(t, "SELECT * FROM t LIMIT 6 OFFSET 0", sql4) +} From 5db2fa1ad0ac1d83eaf51240407583b5a630e2f0 Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Wed, 4 Feb 2026 16:03:31 +0100 Subject: [PATCH 05/19] fix: streamline paginator initialization in pagination edge cases test --- page_test.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/page_test.go b/page_test.go index 6762a3d..26d3766 100644 --- a/page_test.go +++ b/page_test.go @@ -136,13 +136,8 @@ func TestPaginationEdgeCases(t *testing.T) { require.Equal(t, "SELECT * FROM t LIMIT 11 OFFSET 0", sql3) // Test case 4: options with defaults, struct assignment - paginator4 := pgkit.Paginator[T]{ - pgkit.PaginatorOptions{ - DefaultSize: 5, - MaxSize: 20, - }, - } - page4 := &pgkit.Page{Page: 0, Size: 0} + paginator4 := pgkit.Paginator[T]{pgkit.PaginatorOptions{DefaultSize: 5, MaxSize: 20}} + page4 := &pgkit.Page{} result4, query4 := paginator4.PrepareQuery(sq.Select("*").From("t"), page4) require.Len(t, result4, 0) require.Equal(t, &pgkit.Page{Page: 1, Size: 5}, page4) From abe7931d45b3febd30346464d68e081f55d4d690 Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Wed, 4 Feb 2026 16:11:38 +0100 Subject: [PATCH 06/19] fix: ensure max size is not less than default size in paginator options --- page.go | 4 ++++ page_test.go | 11 +++++++++++ 2 files changed, 15 insertions(+) diff --git a/page.go b/page.go index 144b3f9..06b9442 100644 --- a/page.go +++ b/page.go @@ -147,6 +147,7 @@ func (p *Page) Limit(o *PaginatorOptions) uint64 { // NewPaginator creates a new paginator with the given options. // Default page size is 10 and max size is 50. +// If MaxSize is less than DefaultSize, MaxSize is set to DefaultSize. func NewPaginator[T any](options *PaginatorOptions) Paginator[T] { p := Paginator[T]{} if options == nil { @@ -158,6 +159,9 @@ func NewPaginator[T any](options *PaginatorOptions) Paginator[T] { if options.MaxSize == 0 { options.MaxSize = MaxPageSize } + if options.MaxSize < options.DefaultSize { + options.MaxSize = options.DefaultSize + } p.PaginatorOptions = *options return p } diff --git a/page_test.go b/page_test.go index 26d3766..b59845a 100644 --- a/page_test.go +++ b/page_test.go @@ -145,4 +145,15 @@ func TestPaginationEdgeCases(t *testing.T) { sql4, _, err4 := query4.ToSql() require.NoError(t, err4) require.Equal(t, "SELECT * FROM t LIMIT 6 OFFSET 0", sql4) + + // Test case 5: max size lower than default size + paginator5 := pgkit.NewPaginator[T](&pgkit.PaginatorOptions{DefaultSize: 20, MaxSize: 5}) + page5 := &pgkit.Page{} + result5, query5 := paginator5.PrepareQuery(sq.Select("*").From("t"), page5) + require.Len(t, result5, 0) + require.Equal(t, &pgkit.Page{Page: 1, Size: 20}, page5) + + sql5, _, err5 := query5.ToSql() + require.NoError(t, err5) + require.Equal(t, "SELECT * FROM t LIMIT 21 OFFSET 0", sql5) } From 6b9c2bde48f654ef97a451c5810b90c6898e021c Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Wed, 4 Feb 2026 16:23:02 +0100 Subject: [PATCH 07/19] linter pass --- tests/helpers_test.go | 26 -------------------------- tests/pgkit_test.go | 37 +++++++++++++++++++++---------------- 2 files changed, 21 insertions(+), 42 deletions(-) diff --git a/tests/helpers_test.go b/tests/helpers_test.go index dbddd16..f25ced6 100644 --- a/tests/helpers_test.go +++ b/tests/helpers_test.go @@ -4,37 +4,11 @@ import ( "context" "fmt" "testing" - "time" "github.com/stretchr/testify/assert" ) -func truncateAllTables(t *testing.T) { - truncateTable(t, "accounts") - truncateTable(t, "reviews") - truncateTable(t, "logs") - truncateTable(t, "stats") - truncateTable(t, "articles") -} - func truncateTable(t *testing.T, tableName string) { _, err := DB.Conn.Exec(context.Background(), fmt.Sprintf(`TRUNCATE TABLE %q CASCADE`, tableName)) assert.NoError(t, err) } - -func measureCall(fn func() error) (time.Duration, error) { - t0 := time.Now() - err := fn() - return time.Since(t0), err -} - -func measureCalls(n int, fn func() error) (time.Duration, error) { - t0 := time.Now() - for i := 0; i < n; i++ { - err := fn() - if err != nil { - return time.Since(t0), err - } - } - return time.Since(t0), nil -} diff --git a/tests/pgkit_test.go b/tests/pgkit_test.go index df65d4c..26ae5c0 100644 --- a/tests/pgkit_test.go +++ b/tests/pgkit_test.go @@ -4,7 +4,6 @@ import ( "bufio" "bytes" "context" - "encoding/hex" "encoding/json" "errors" "fmt" @@ -198,8 +197,8 @@ func TestInsertAndSelectRecords(t *testing.T) { // Build select query rows, err := DB.Conn.Query(context.Background(), selectq, args...) - defer rows.Close() assert.NoError(t, err) + defer rows.Close() // Scan result into *Account object a := &Account{} @@ -245,8 +244,8 @@ func TestQueryWithNoResults(t *testing.T) { // or, with more verbose method: { rows, err := DB.Conn.Query(context.Background(), selectq, args...) - defer rows.Close() assert.NoError(t, err) + defer rows.Close() err = DB.Query.Scan.ScanAll(&accounts, rows) @@ -321,7 +320,7 @@ func TestRecordsWithJSONStruct(t *testing.T) { // Assert record mapping for nested jsonb struct cols, _, err := pgkit.Map(article) assert.NoError(t, err) - sort.Sort(sort.StringSlice(cols)) + sort.Strings(cols) assert.Equal(t, []string{"alias", "author", "content"}, cols) // Insert record @@ -478,7 +477,7 @@ func TestSugarUpdateRecord(t *testing.T) { assert.Equal(t, "JUL14", accountResp.Name) assert.True(t, accountResp2.ID != 0) assert.True(t, accountResp2.ID == accountResp.ID) - assert.True(t, accountResp2.CreatedAt == accountResp.CreatedAt) + assert.True(t, accountResp2.CreatedAt.Equal(accountResp.CreatedAt)) } func TestSugarUpdateRecordColumns(t *testing.T) { @@ -517,7 +516,7 @@ func TestTransactionBasics(t *testing.T) { truncateTable(t, "accounts") // Insert some rows + commit - pgx.BeginFunc(context.Background(), DB.Conn, func(tx pgx.Tx) error { + err := pgx.BeginFunc(context.Background(), DB.Conn, func(tx pgx.Tx) error { // Insert 1 insertq, args, err := DB.SQL.Insert("accounts").Columns("name", "disabled").Values("peter", false).ToSql() require.NoError(t, err) @@ -534,6 +533,7 @@ func TestTransactionBasics(t *testing.T) { return nil }) + require.NoError(t, err) // Assert above records have been made { @@ -547,7 +547,7 @@ func TestTransactionBasics(t *testing.T) { } // Insert some rows -- but rollback - pgx.BeginFunc(context.Background(), DB.Conn, func(tx pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), DB.Conn, func(tx pgx.Tx) error { // Insert 1 insertq, args, err := DB.SQL.Insert("accounts").Columns("name", "disabled").Values("zelda", false).ToSql() require.NoError(t, err) @@ -564,6 +564,7 @@ func TestTransactionBasics(t *testing.T) { return fmt.Errorf("something bad happend") }) + require.Error(t, err) // Assert above records were rolled back { @@ -578,7 +579,7 @@ func TestTransactionBasics(t *testing.T) { func TestSugarTransaction(t *testing.T) { truncateTable(t, "accounts") - pgx.BeginFunc(context.Background(), DB.Conn, func(tx pgx.Tx) error { + err := pgx.BeginFunc(context.Background(), DB.Conn, func(tx pgx.Tx) error { rec1 := &Account{ Name: "peter", Disabled: false, @@ -599,6 +600,7 @@ func TestSugarTransaction(t *testing.T) { return nil }) + require.NoError(t, err) // Assert above records have been made { @@ -703,7 +705,8 @@ func TestBatchQuery(t *testing.T) { require.NoError(t, err) accounts = append(accounts, &account) } - br.Close() + err := br.Close() + require.NoError(t, err) require.Len(t, accounts, len(names)) for i := 0; i < len(names); i++ { @@ -727,7 +730,10 @@ func TestSugarBatchQuery(t *testing.T) { batchResults, batchLen, err := DB.Query.BatchQuery(ctx, queries) require.NoError(t, err) - defer batchResults.Close() + defer func() { + err := batchResults.Close() + require.NoError(t, err) + }() var accounts []*Account @@ -826,6 +832,7 @@ func TestSlogQueryTracerWithValuesReplaced(t *testing.T) { ConnMaxLifetime: "1h", Tracer: tracer.NewSQLTracer(slogTracer), }) + require.NoError(t, err) defer dbClient.Conn.Close() @@ -924,6 +931,7 @@ func TestSlogQueryTracerUsingContextToInit(t *testing.T) { ConnMaxLifetime: "1h", Tracer: tracer.NewSQLTracer(slogTracer), }) + require.NoError(t, err) defer dbClient.Conn.Close() @@ -981,6 +989,7 @@ func TestSlogQueryTracerWithErr(t *testing.T) { ConnMaxLifetime: "1h", Tracer: tracer.NewSQLTracer(slogTracer), }) + require.NoError(t, err) defer dbClient.Conn.Close() @@ -1029,6 +1038,7 @@ func TestSlogSlowQuery(t *testing.T) { ConnMaxLifetime: "1h", Tracer: tracer.NewSQLTracer(slogTracer), }) + require.NoError(t, err) defer dbClient.Conn.Close() @@ -1065,6 +1075,7 @@ func TestSlogTracerBatchQuery(t *testing.T) { ConnMaxLifetime: "1h", Tracer: tracer.NewSQLTracer(slogTracer), }) + require.NoError(t, err) defer dbClient.Conn.Close() @@ -1148,9 +1159,3 @@ func connectToDb(conf pgkit.Config) (*pgkit.DB, error) { } return dbClient, err } - -func hexEncode(b []byte) string { - enc := make([]byte, len(b)*2) - hex.Encode(enc[0:], b) - return string(enc) -} From 81357e9318af80b15d67220ad4bd1549fa377c7e Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Wed, 4 Feb 2026 16:26:30 +0100 Subject: [PATCH 08/19] fix: replace assert.True with assert.Equal and assert.NotZero for clarity in tests --- dbtype/bigint_test.go | 4 ++-- tests/pgkit_test.go | 50 +++++++++++++++++++++---------------------- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/dbtype/bigint_test.go b/dbtype/bigint_test.go index 198d129..e64e871 100644 --- a/dbtype/bigint_test.go +++ b/dbtype/bigint_test.go @@ -57,8 +57,8 @@ func TestBigIntScan(t *testing.T) { assert.Error(t, b.Scan("1.")) assert.NoError(t, b.Scan("100")) - assert.True(t, b.Uint64() == 100) + assert.Equal(t, uint64(100), b.Uint64()) assert.NoError(t, b.Scan("2e0")) - assert.True(t, b.Uint64() == 2) + assert.Equal(t, uint64(2), b.Uint64()) } diff --git a/tests/pgkit_test.go b/tests/pgkit_test.go index 26ae5c0..281f387 100644 --- a/tests/pgkit_test.go +++ b/tests/pgkit_test.go @@ -5,7 +5,6 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "io" "log" @@ -67,7 +66,7 @@ func TestSugarInsertAndSelectRows(t *testing.T) { require.NoError(t, err) require.Len(t, accounts, 1) - assert.True(t, accounts[0].ID != 0) + assert.NotZero(t, accounts[0].ID) assert.Equal(t, "peter", accounts[0].Name) } @@ -90,7 +89,7 @@ func TestInsertAndSelectRows(t *testing.T) { require.NoError(t, err) require.Len(t, accounts, 1) - assert.True(t, accounts[0].ID != 0) + assert.NotZero(t, accounts[0].ID) assert.Equal(t, "peter", accounts[0].Name) } @@ -113,7 +112,7 @@ func TestSugarInsertAndSelectRecords(t *testing.T) { err = DB.Query.GetAll(context.Background(), q2, &accounts) assert.NoError(t, err) assert.Len(t, accounts, 1) - assert.True(t, accounts[0].ID != 0) + assert.NotZero(t, accounts[0].ID) assert.Equal(t, "joe", accounts[0].Name) // Select one -- into object @@ -121,7 +120,7 @@ func TestSugarInsertAndSelectRecords(t *testing.T) { err = DB.Query.GetOne(context.Background(), q2, account) assert.NoError(t, err) assert.Len(t, accounts, 1) - assert.True(t, accounts[0].ID != 0) + assert.NotZero(t, accounts[0].ID) assert.Equal(t, "joe", accounts[0].Name) // Select one -- into struct value @@ -129,7 +128,7 @@ func TestSugarInsertAndSelectRecords(t *testing.T) { err = DB.Query.GetOne(context.Background(), q2, &accountv) assert.NoError(t, err) assert.Len(t, accounts, 1) - assert.True(t, accounts[0].ID != 0) + assert.NotZero(t, accounts[0].ID) assert.Equal(t, "joe", accounts[0].Name) } @@ -160,7 +159,7 @@ func TestSugarInsertAndSelectRecordsReturningID(t *testing.T) { err := DB.Query.QueryRow(context.Background(), DB.SQL.InsertRecord(rec).Suffix(`RETURNING "id"`)).Scan(&rec.ID) assert.NoError(t, err) - assert.True(t, rec.ID > 0) + assert.NotZero(t, rec.ID) // Select one -- into object account := &Account{} @@ -205,7 +204,7 @@ func TestInsertAndSelectRecords(t *testing.T) { err = DB.Query.Scan.ScanOne(a, rows) assert.NoError(t, err) - assert.True(t, a.ID != 0) + assert.NotZero(t, a.ID) assert.Equal(t, "joe", a.Name) assert.Equal(t, true, a.Disabled) @@ -224,8 +223,7 @@ func TestSugarQueryWithNoResults(t *testing.T) { var account interface{} err := DB.Query.GetOne(context.Background(), q, &account) - assert.True(t, errors.Is(err, pgkit.ErrNoRows)) - assert.True(t, errors.Is(err, pgx.ErrNoRows)) + assert.ErrorIs(t, err, pgkit.ErrNoRows) } func TestQueryWithNoResults(t *testing.T) { @@ -257,7 +255,7 @@ func TestQueryWithNoResults(t *testing.T) { { var a *Account err = DB.Query.Scan.Get(context.Background(), DB.Conn, a, selectq, args...) - assert.True(t, errors.Is(err, pgx.ErrNoRows)) + assert.ErrorIs(t, err, pgkit.ErrNoRows) } } @@ -354,7 +352,7 @@ func TestRowsWithBigInt(t *testing.T) { err = DB.Query.GetOne(context.Background(), q2, &sout) assert.NoError(t, err) assert.Equal(t, "count", sout.Key) - assert.True(t, sout.Num.Int64() == 2) + assert.Equal(t, int64(2), sout.Num.Int64()) assert.True(t, sout.Num.IsValid) assert.False(t, sout.Rating.IsValid) @@ -375,7 +373,7 @@ func TestRowsWithBigInt(t *testing.T) { err = DB.Query.GetOne(context.Background(), q2, &sout) assert.NoError(t, err) assert.Equal(t, "count2", sout.Key) - assert.True(t, sout.Num.String() == "12323942398472837489234") + assert.Equal(t, "12323942398472837489234", sout.Num.String()) assert.True(t, sout.Num.IsValid) assert.False(t, sout.Rating.IsValid) @@ -399,9 +397,9 @@ func TestRowsWithBigInt(t *testing.T) { err = DB.Query.GetOne(context.Background(), q2, &sout) assert.NoError(t, err) assert.Equal(t, "count3", sout.Key) - assert.True(t, sout.Num.String() == "44") + assert.Equal(t, "44", sout.Num.String()) assert.True(t, sout.Num.IsValid) - assert.True(t, sout.Rating.String() == "5") + assert.Equal(t, "5", sout.Rating.String()) assert.True(t, sout.Rating.IsValid) } @@ -463,7 +461,7 @@ func TestSugarUpdateRecord(t *testing.T) { err = DB.Query.GetOne(context.Background(), DB.SQL.Select("*").From("accounts"), accountResp) assert.NoError(t, err) assert.Equal(t, "julia", accountResp.Name) - assert.True(t, accountResp.ID != 0) + assert.NotZero(t, accountResp.ID) // Update accountResp.Name = "JUL14" @@ -475,9 +473,9 @@ func TestSugarUpdateRecord(t *testing.T) { err = DB.Query.GetOne(context.Background(), DB.SQL.Select("*").From("accounts"), accountResp2) assert.NoError(t, err) assert.Equal(t, "JUL14", accountResp.Name) - assert.True(t, accountResp2.ID != 0) - assert.True(t, accountResp2.ID == accountResp.ID) - assert.True(t, accountResp2.CreatedAt.Equal(accountResp.CreatedAt)) + assert.NotZero(t, accountResp2.ID) + assert.Equal(t, accountResp2.ID, accountResp.ID) + assert.Equal(t, accountResp2.CreatedAt, accountResp.CreatedAt) } func TestSugarUpdateRecordColumns(t *testing.T) { @@ -493,7 +491,7 @@ func TestSugarUpdateRecordColumns(t *testing.T) { err = DB.Query.GetOne(context.Background(), DB.SQL.Select("*").From("accounts"), accountResp) assert.NoError(t, err) assert.Equal(t, "peter", accountResp.Name) - assert.True(t, accountResp.ID != 0) + assert.NotZero(t, accountResp.ID) assert.False(t, accountResp.Disabled) // Update @@ -508,8 +506,8 @@ func TestSugarUpdateRecordColumns(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "peter", accountResp2.Name) // should not have changed, expect as previous was recorded assert.True(t, accountResp2.Disabled) - assert.True(t, accountResp2.ID != 0) - assert.True(t, accountResp2.ID == accountResp.ID) + assert.NotZero(t, accountResp2.ID) + assert.Equal(t, accountResp2.ID, accountResp.ID) } func TestTransactionBasics(t *testing.T) { @@ -542,8 +540,8 @@ func TestTransactionBasics(t *testing.T) { err := DB.Query.GetAll(context.Background(), q, &accounts) require.NoError(t, err) assert.Len(t, accounts, 2) - assert.True(t, accounts[0].Name == "mario") - assert.True(t, accounts[1].Name == "peter") + assert.Equal(t, "mario", accounts[0].Name) + assert.Equal(t, "peter", accounts[1].Name) } // Insert some rows -- but rollback @@ -609,8 +607,8 @@ func TestSugarTransaction(t *testing.T) { err := DB.Query.GetAll(context.Background(), q, &accounts) require.NoError(t, err) assert.Len(t, accounts, 2) - assert.True(t, accounts[0].Name == "mario") - assert.True(t, accounts[1].Name == "peter") + assert.Equal(t, "mario", accounts[0].Name) + assert.Equal(t, "peter", accounts[1].Name) } } From 4273b9b3669da5fb7cf8301eed56d30a81d0e3ae Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Wed, 4 Feb 2026 17:44:24 +0100 Subject: [PATCH 09/19] fix: update paginator options to clarify default and max size behavior --- page.go | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/page.go b/page.go index 06b9442..20ce5e8 100644 --- a/page.go +++ b/page.go @@ -146,7 +146,6 @@ func (p *Page) Limit(o *PaginatorOptions) uint64 { } // NewPaginator creates a new paginator with the given options. -// Default page size is 10 and max size is 50. // If MaxSize is less than DefaultSize, MaxSize is set to DefaultSize. func NewPaginator[T any](options *PaginatorOptions) Paginator[T] { p := Paginator[T]{} @@ -166,11 +165,21 @@ func NewPaginator[T any](options *PaginatorOptions) Paginator[T] { return p } +// PaginatorOptions are the options for the paginator. type PaginatorOptions struct { + // DefaultSize is the default number of rows per page. + // If zero, DefaultPageSize is used. DefaultSize uint32 - MaxSize uint32 - Sort []string - ColumnFunc func(string) string + + // MaxSize is the maximum number of rows per page. + // If zero, MaxPageSize is used. If less than DefaultSize, it is set to DefaultSize. + MaxSize uint32 + + // Sort is the default sort order. + Sort []string + + // ColumnFunc is a transformation applied to column names. + ColumnFunc func(string) string } func (o *PaginatorOptions) getDefaults() (defaultSize, maxSize uint64) { @@ -185,7 +194,7 @@ func (o *PaginatorOptions) getDefaults() (defaultSize, maxSize uint64) { if o.MaxSize != 0 { maxSize = uint64(o.MaxSize) } - return min(defaultSize, maxSize), maxSize + return min(defaultSize, maxSize), max(defaultSize, maxSize) } // Paginator is a helper to paginate results. From afdd6ea38b59f189b4e836ac9d85fed08ae8ad51 Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Wed, 4 Feb 2026 17:49:14 +0100 Subject: [PATCH 10/19] fix: simplify filter checks in UpdateRecordColumns and TraversalsByName methods --- builder.go | 6 ++--- internal/reflectx/reflect.go | 3 +-- internal/reflectx/reflect_test.go | 39 ++++++++++++++++--------------- 3 files changed, 23 insertions(+), 25 deletions(-) diff --git a/builder.go b/builder.go index 6ea3256..8e8bba6 100644 --- a/builder.go +++ b/builder.go @@ -78,9 +78,7 @@ func (s StatementBuilder) UpdateRecordColumns(record interface{}, whereExpr sq.E // when filter is empty or nil, update the entire record var filter []string - if filterCols == nil || len(filterCols) == 0 { - filter = nil - } else { + if len(filterCols) != 0 { filter = filterCols } @@ -126,7 +124,7 @@ func createMap(k []string, v []interface{}, filterK []string) (map[string]interf m := make(map[string]interface{}, len(k)) for i := 0; i < len(k); i++ { - if filterK == nil || len(filterK) == 0 { + if len(filterK) == 0 { m[k[i]] = v[i] continue } diff --git a/internal/reflectx/reflect.go b/internal/reflectx/reflect.go index 0b10994..91b5285 100644 --- a/internal/reflectx/reflect.go +++ b/internal/reflectx/reflect.go @@ -3,7 +3,6 @@ // allows for Go-compatible named attribute access, including accessing embedded // struct attributes and the ability to use functions and struct tags to // customize field names. -// package reflectx import ( @@ -167,7 +166,7 @@ func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value { // to a struct. Returns empty int slice for each name not found. func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int { r := make([][]int, 0, len(names)) - m.TraversalsByNameFunc(t, names, func(_ int, i []int) error { + _ = m.TraversalsByNameFunc(t, names, func(_ int, i []int) error { if i == nil { r = append(r, []int{}) } else { diff --git a/internal/reflectx/reflect_test.go b/internal/reflectx/reflect_test.go index e73af5b..1e3f7ae 100644 --- a/internal/reflectx/reflect_test.go +++ b/internal/reflectx/reflect_test.go @@ -57,7 +57,7 @@ func TestBasicEmbedded(t *testing.T) { z.A = 1 z.B = 2 z.C = 4 - z.Bar.Foo.A = 3 + z.Foo.A = 3 zv := reflect.ValueOf(z) fields := m.TypeMap(reflect.TypeOf(z)) @@ -75,12 +75,12 @@ func TestBasicEmbedded(t *testing.T) { t.Errorf("Expecting %d, got %d", z.A, ival(v)) } v = m.FieldByName(zv, "Bar.B") - if ival(v) != z.Bar.B { - t.Errorf("Expecting %d, got %d", z.Bar.B, ival(v)) + if ival(v) != z.B { + t.Errorf("Expecting %d, got %d", z.B, ival(v)) } v = m.FieldByName(zv, "Bar.A") - if ival(v) != z.Bar.Foo.A { - t.Errorf("Expecting %d, got %d", z.Bar.Foo.A, ival(v)) + if ival(v) != z.Foo.A { + t.Errorf("Expecting %d, got %d", z.Foo.A, ival(v)) } v = m.FieldByName(zv, "Bar.C") if _, ok := v.Interface().(int); ok { @@ -127,7 +127,7 @@ func TestBasicEmbeddedWithTags(t *testing.T) { z := Baz{} z.A = 1 z.B = 2 - z.Bar.Foo.A = 3 + z.Foo.A = 3 zv := reflect.ValueOf(z) fields := m.TypeMap(reflect.TypeOf(z)) @@ -245,12 +245,12 @@ func TestNestedStruct(t *testing.T) { t.Errorf("Expecting field to not exist") } v = m.FieldByName(pv, "asset.title") - if v.Interface().(string) != post.Asset.Title { - t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) + if v.Interface().(string) != post.Title { + t.Errorf("Expecting %s, got %s", post.Title, v.Interface().(string)) } v = m.FieldByName(pv, "asset.details.active") - if v.Interface().(bool) != post.Asset.Details.Active { - t.Errorf("Expecting %v, got %v", post.Asset.Details.Active, v.Interface().(bool)) + if v.Interface().(bool) != post.Details.Active { + t.Errorf("Expecting %v, got %v", post.Details.Active, v.Interface().(bool)) } } @@ -317,7 +317,7 @@ func TestFieldsEmbedded(t *testing.T) { pp := PP{} pp.Person.Name = "Peter" pp.Place.Name = "Toronto" - pp.Article.Title = "Best city ever" + pp.Title = "Best city ever" fields := m.TypeMap(reflect.TypeOf(pp)) // for i, f := range fields { @@ -337,8 +337,8 @@ func TestFieldsEmbedded(t *testing.T) { } v = m.FieldByName(ppv, "title") - if v.Interface().(string) != pp.Article.Title { - t.Errorf("Expecting %s, got %s", pp.Article.Title, v.Interface().(string)) + if v.Interface().(string) != pp.Title { + t.Errorf("Expecting %s, got %s", pp.Title, v.Interface().(string)) } fi := fields.GetByPath("person") @@ -355,6 +355,7 @@ func TestFieldsEmbedded(t *testing.T) { fi = fields.GetByPath("person.name") if fi == nil { t.Errorf("Expecting person.name to exist") + return } if fi.Path != "person.name" { t.Errorf("Expecting %s, got %s", "person.name", fi.Path) @@ -366,6 +367,7 @@ func TestFieldsEmbedded(t *testing.T) { fi = fields.GetByTraversal([]int{1, 0}) if fi == nil { t.Errorf("Expecting traveral to exist") + return } if fi.Path != "name" { t.Errorf("Expecting %s, got %s", "name", fi.Path) @@ -374,6 +376,7 @@ func TestFieldsEmbedded(t *testing.T) { fi = fields.GetByTraversal([]int{2}) if fi == nil { t.Errorf("Expecting traversal to exist") + return } if _, ok := fi.Options["required"]; !ok { t.Errorf("Expecting required option to be set") @@ -404,8 +407,8 @@ func TestPtrFields(t *testing.T) { } v := m.FieldByName(pv, "asset.title") - if v.Interface().(string) != post.Asset.Title { - t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) + if v.Interface().(string) != post.Title { + t.Errorf("Expecting %s, got %s", post.Title, v.Interface().(string)) } v = m.FieldByName(pv, "author") if v.Interface().(string) != post.Author { @@ -642,7 +645,7 @@ func TestMapperMethodsByName(t *testing.T) { A0 *B `db:"A0"` B `db:"A1"` A2 int - a3 int + a3 int //nolint } val := &A{ @@ -846,12 +849,10 @@ func TestMustBe(t *testing.T) { if r := recover(); r != nil { valueErr, ok := r.(*reflect.ValueError) if !ok { - t.Errorf("unexpected Method: %s", valueErr.Method) + t.Errorf("unexpected Type: %T", r) t.Error("expected panic with *reflect.ValueError") return } - if valueErr.Method != "github.com/jmoiron/sqlx/reflectx.TestMustBe" { - } if valueErr.Kind != reflect.String { t.Errorf("unexpected Kind: %s", valueErr.Kind) } From 520d0a14a12efb94ebb3c6c16588303dd88d768f Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Fri, 6 Feb 2026 13:36:49 +0100 Subject: [PATCH 11/19] linter pass --- mapper.go | 2 +- querier.go | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/mapper.go b/mapper.go index 29c7745..ca15e08 100644 --- a/mapper.go +++ b/mapper.go @@ -80,7 +80,7 @@ func MapWithOptions(record interface{}, options *MapOptions) ([]string, []interf for _, fi := range fieldMap { // Skip any fields which do not specify the `db:".."` tag - if strings.Index(string(fi.Field.Tag), dbTagPrefix) < 0 { + if !strings.Contains(string(fi.Field.Tag), dbTagPrefix) { continue } diff --git a/querier.go b/querier.go index 294d8d2..548ced7 100644 --- a/querier.go +++ b/querier.go @@ -132,13 +132,15 @@ func (q *Querier) BatchExec(ctx context.Context, queries Queries) ([]pgconn.Comm } // Send batch - var results pgx.BatchResults + type batchSender interface { + SendBatch(context.Context, *pgx.Batch) pgx.BatchResults + } + batcher := batchSender(q.pool) if q.Tx != nil { - results = q.Tx.SendBatch(ctx, batch) - } else { - results = q.pool.SendBatch(ctx, batch) + batcher = batchSender(q.Tx) } - defer results.Close() + results := batcher.SendBatch(ctx, batch) + defer results.Close() //nolint:errcheck // Exec the number of times as we have queries in the batch so we may get the exec // result and potential error response. @@ -291,7 +293,7 @@ func (r RawSQL) ToSql() (string, []interface{}, error) { return r.Query, r.Args, nil } - if r.Args == nil || len(r.Args) == 0 { + if len(r.Args) == 0 { return r.Query, r.Args, nil // assume no params passed } From 1f0621fc33aec1033a7edaf52ec9572dff584884 Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Fri, 6 Feb 2026 13:38:09 +0100 Subject: [PATCH 12/19] fix: implement default settings for PaginatorOptions and Page --- page.go | 51 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/page.go b/page.go index 20ce5e8..4172cb0 100644 --- a/page.go +++ b/page.go @@ -87,6 +87,24 @@ func NewPage(size, page uint32, sort ...Sort) *Page { } } +func (p *Page) SetDefaults(o *PaginatorOptions) { + if o == nil { + o = &PaginatorOptions{ + DefaultSize: DefaultPageSize, + MaxSize: MaxPageSize, + } + } + if p.Size == 0 { + p.Size = o.DefaultSize + } + if p.Size > o.MaxSize { + p.Size = o.MaxSize + } + if p.Page == 0 { + p.Page = 1 + } +} + func (p *Page) GetOrder(defaultSort ...string) []Sort { // if page has sort, use it if p != nil && len(p.Sort) != 0 { @@ -152,15 +170,7 @@ func NewPaginator[T any](options *PaginatorOptions) Paginator[T] { if options == nil { options = &PaginatorOptions{} } - if options.DefaultSize == 0 { - options.DefaultSize = DefaultPageSize - } - if options.MaxSize == 0 { - options.MaxSize = MaxPageSize - } - if options.MaxSize < options.DefaultSize { - options.MaxSize = options.DefaultSize - } + options.SetDefaults() p.PaginatorOptions = *options return p } @@ -182,6 +192,18 @@ type PaginatorOptions struct { ColumnFunc func(string) string } +func (p *PaginatorOptions) SetDefaults() { + if p.DefaultSize == 0 { + p.DefaultSize = DefaultPageSize + } + if p.MaxSize == 0 { + p.MaxSize = MaxPageSize + } + if p.MaxSize < p.DefaultSize { + p.MaxSize = p.DefaultSize + } +} + func (o *PaginatorOptions) getDefaults() (defaultSize, maxSize uint64) { defaultSize = DefaultPageSize maxSize = MaxPageSize @@ -216,18 +238,11 @@ func (p Paginator[T]) getOrder(page *Page) []string { // PrepareQuery adds pagination to the query. It sets the number of max rows to limit+1. func (p Paginator[T]) PrepareQuery(q sq.SelectBuilder, page *Page) ([]T, sq.SelectBuilder) { + p.SetDefaults() if page == nil { page = &Page{} } - if page.Page == 0 { - page.Page = 1 - } - if page.Size == 0 { - page.Size = p.DefaultSize - } - if p.MaxSize != 0 && page.Size > p.MaxSize { - page.Size = p.MaxSize - } + page.SetDefaults(&p.PaginatorOptions) limit := page.Limit(&p.PaginatorOptions) q = q.Limit(page.Limit(&p.PaginatorOptions) + 1).Offset(page.Offset(&p.PaginatorOptions)).OrderBy(p.getOrder(page)...) From 8c31ac499344e37d9f1f181a9b574ba8e71a44c4 Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Fri, 6 Feb 2026 13:43:01 +0100 Subject: [PATCH 13/19] fix: refactor pagination methods to remove PaginatorOptions dependency --- page.go | 42 +++++++++++++----------------------------- 1 file changed, 13 insertions(+), 29 deletions(-) diff --git a/page.go b/page.go index 4172cb0..1cd9968 100644 --- a/page.go +++ b/page.go @@ -141,7 +141,7 @@ func (p *Page) GetOrder(defaultSort ...string) []Sort { return sort } -func (p *Page) Offset(o *PaginatorOptions) uint64 { +func (p *Page) Offset() uint64 { n := uint64(1) if p != nil && p.Page != 0 { n = uint64(p.Page) @@ -149,16 +149,16 @@ func (p *Page) Offset(o *PaginatorOptions) uint64 { if n < 1 { n = 1 } - return (n - 1) * p.Limit(o) + return (n - 1) * p.Limit() } -func (p *Page) Limit(o *PaginatorOptions) uint64 { - n, maxSize := o.getDefaults() +func (p *Page) Limit() uint64 { + n := uint64(DefaultPageSize) if p != nil && p.Size != 0 { n = uint64(p.Size) } - if n > uint64(maxSize) { - n = maxSize + if n > DefaultPageSize { + n = DefaultPageSize } return n } @@ -199,24 +199,8 @@ func (p *PaginatorOptions) SetDefaults() { if p.MaxSize == 0 { p.MaxSize = MaxPageSize } - if p.MaxSize < p.DefaultSize { - p.MaxSize = p.DefaultSize - } -} - -func (o *PaginatorOptions) getDefaults() (defaultSize, maxSize uint64) { - defaultSize = DefaultPageSize - maxSize = MaxPageSize - if o == nil { - return - } - if o.DefaultSize != 0 { - defaultSize = uint64(o.DefaultSize) - } - if o.MaxSize != 0 { - maxSize = uint64(o.MaxSize) - } - return min(defaultSize, maxSize), max(defaultSize, maxSize) + p.MaxSize = max(p.DefaultSize, p.MaxSize) + p.DefaultSize = min(p.DefaultSize, p.MaxSize) } // Paginator is a helper to paginate results. @@ -244,13 +228,13 @@ func (p Paginator[T]) PrepareQuery(q sq.SelectBuilder, page *Page) ([]T, sq.Sele } page.SetDefaults(&p.PaginatorOptions) - limit := page.Limit(&p.PaginatorOptions) - q = q.Limit(page.Limit(&p.PaginatorOptions) + 1).Offset(page.Offset(&p.PaginatorOptions)).OrderBy(p.getOrder(page)...) + limit := page.Limit() + q = q.Limit(limit + 1).Offset(page.Offset()).OrderBy(p.getOrder(page)...) return make([]T, 0, limit+1), q } func (p Paginator[T]) PrepareRaw(q string, args []any, page *Page) ([]T, string, []any) { - limit, offset := page.Limit(&p.PaginatorOptions), page.Offset(&p.PaginatorOptions) + limit, offset := page.Limit(), page.Offset() q = q + " ORDER BY " + strings.Join(p.getOrder(page), ", ") q = q + " LIMIT @limit OFFSET @offset" @@ -273,13 +257,13 @@ func (p Paginator[T]) PrepareRaw(q string, args []any, page *Page) ([]T, string, // - it removes the last element, returning n elements // - it sets more to true in the page object func (p Paginator[T]) PrepareResult(result []T, page *Page) []T { - limit := int(page.Limit(&p.PaginatorOptions)) + limit := int(page.Limit()) page.More = len(result) > limit if page.More { result = result[:limit] } page.Size = uint32(limit) - page.Page = 1 + uint32(page.Offset(&p.PaginatorOptions))/uint32(limit) + page.Page = 1 + uint32(page.Offset())/uint32(limit) return result } From 726813a2475ab8fbb5513ad7bdb92443500e741f Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Fri, 6 Feb 2026 13:48:05 +0100 Subject: [PATCH 14/19] fix test --- page.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/page.go b/page.go index 1cd9968..1832f11 100644 --- a/page.go +++ b/page.go @@ -157,9 +157,6 @@ func (p *Page) Limit() uint64 { if p != nil && p.Size != 0 { n = uint64(p.Size) } - if n > DefaultPageSize { - n = DefaultPageSize - } return n } From 3b218c7f1172484dbdc8392313183a9d004d7a6e Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Fri, 6 Feb 2026 14:05:42 +0100 Subject: [PATCH 15/19] fix: replace manual assertions with testify assertions in reflect tests --- internal/reflectx/reflect_test.go | 311 ++++++++---------------------- 1 file changed, 85 insertions(+), 226 deletions(-) diff --git a/internal/reflectx/reflect_test.go b/internal/reflectx/reflect_test.go index 1e3f7ae..1f23309 100644 --- a/internal/reflectx/reflect_test.go +++ b/internal/reflectx/reflect_test.go @@ -4,6 +4,9 @@ import ( "reflect" "strings" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func ival(v reflect.Value) int { @@ -22,17 +25,11 @@ func TestBasic(t *testing.T) { m := NewMapperFunc("", func(s string) string { return s }) v := m.FieldByName(fv, "A") - if ival(v) != f.A { - t.Errorf("Expecting %d, got %d", ival(v), f.A) - } + assert.Equal(t, f.A, ival(v)) v = m.FieldByName(fv, "B") - if ival(v) != f.B { - t.Errorf("Expecting %d, got %d", f.B, ival(v)) - } + assert.Equal(t, f.B, ival(v)) v = m.FieldByName(fv, "C") - if ival(v) != f.C { - t.Errorf("Expecting %d, got %d", f.C, ival(v)) - } + assert.Equal(t, f.C, ival(v)) } func TestBasicEmbedded(t *testing.T) { @@ -62,35 +59,24 @@ func TestBasicEmbedded(t *testing.T) { zv := reflect.ValueOf(z) fields := m.TypeMap(reflect.TypeOf(z)) - if len(fields.Index) != 5 { - t.Errorf("Expecting 5 fields") - } + assert.Len(t, fields.Index, 5) // for _, fi := range fields.Index { // log.Println(fi) // } v := m.FieldByName(zv, "A") - if ival(v) != z.A { - t.Errorf("Expecting %d, got %d", z.A, ival(v)) - } + assert.Equal(t, z.A, ival(v)) v = m.FieldByName(zv, "Bar.B") - if ival(v) != z.B { - t.Errorf("Expecting %d, got %d", z.B, ival(v)) - } + assert.Equal(t, z.B, ival(v)) v = m.FieldByName(zv, "Bar.A") - if ival(v) != z.Foo.A { - t.Errorf("Expecting %d, got %d", z.Foo.A, ival(v)) - } + assert.Equal(t, z.Foo.A, ival(v)) v = m.FieldByName(zv, "Bar.C") - if _, ok := v.Interface().(int); ok { - t.Errorf("Expecting Bar.C to not exist") - } + _, ok := v.Interface().(int) + assert.False(t, ok, "Expecting Bar.C to not exist") fi := fields.GetByPath("Bar.C") - if fi != nil { - t.Errorf("Bar.C should not exist") - } + assert.Nil(t, fi, "Bar.C should not exist") } func TestEmbeddedSimple(t *testing.T) { @@ -132,22 +118,16 @@ func TestBasicEmbeddedWithTags(t *testing.T) { zv := reflect.ValueOf(z) fields := m.TypeMap(reflect.TypeOf(z)) - if len(fields.Index) != 5 { - t.Errorf("Expecting 5 fields") - } + assert.Len(t, fields.Index, 5) // for _, fi := range fields.index { // log.Println(fi) // } v := m.FieldByName(zv, "a") - if ival(v) != z.A { // the dominant field - t.Errorf("Expecting %d, got %d", z.A, ival(v)) - } + assert.Equal(t, z.A, ival(v)) // the dominant field v = m.FieldByName(zv, "b") - if ival(v) != z.B { - t.Errorf("Expecting %d, got %d", z.B, ival(v)) - } + assert.Equal(t, z.B, ival(v)) } func TestBasicEmbeddedWithSameName(t *testing.T) { @@ -171,22 +151,14 @@ func TestBasicEmbeddedWithSameName(t *testing.T) { zv := reflect.ValueOf(z) fields := m.TypeMap(reflect.TypeOf(z)) - if len(fields.Index) != 4 { - t.Errorf("Expecting 3 fields, found %d", len(fields.Index)) - } + assert.Len(t, fields.Index, 4) v := m.FieldByName(zv, "a") - if ival(v) != z.A { // the dominant field - t.Errorf("Expecting %d, got %d", z.A, ival(v)) - } + assert.Equal(t, z.A, ival(v)) // the dominant field v = m.FieldByName(zv, "b") - if ival(v) != z.B { - t.Errorf("Expecting %d, got %d", z.B, ival(v)) - } + assert.Equal(t, z.B, ival(v)) v = m.FieldByName(zv, "Foo") - if ival(v) != z.Foo.Foo { - t.Errorf("Expecting %d, got %d", z.Foo.Foo, ival(v)) - } + assert.Equal(t, z.Foo.Foo, ival(v)) } func TestFlatTags(t *testing.T) { @@ -205,13 +177,9 @@ func TestFlatTags(t *testing.T) { pv := reflect.ValueOf(post) v := m.FieldByName(pv, "author") - if v.Interface().(string) != post.Author { - t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) - } + assert.Equal(t, post.Author, v.Interface().(string)) v = m.FieldByName(pv, "title") - if v.Interface().(string) != post.Asset.Title { - t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) - } + assert.Equal(t, post.Asset.Title, v.Interface().(string)) } func TestNestedStruct(t *testing.T) { @@ -237,21 +205,14 @@ func TestNestedStruct(t *testing.T) { pv := reflect.ValueOf(post) v := m.FieldByName(pv, "author") - if v.Interface().(string) != post.Author { - t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) - } + assert.Equal(t, post.Author, v.Interface().(string)) v = m.FieldByName(pv, "title") - if _, ok := v.Interface().(string); ok { - t.Errorf("Expecting field to not exist") - } + _, ok := v.Interface().(string) + assert.False(t, ok, "Expecting field to not exist") v = m.FieldByName(pv, "asset.title") - if v.Interface().(string) != post.Title { - t.Errorf("Expecting %s, got %s", post.Title, v.Interface().(string)) - } + assert.Equal(t, post.Title, v.Interface().(string)) v = m.FieldByName(pv, "asset.details.active") - if v.Interface().(bool) != post.Details.Active { - t.Errorf("Expecting %v, got %v", post.Details.Active, v.Interface().(bool)) - } + assert.Equal(t, post.Details.Active, v.Interface().(bool)) } func TestInlineStruct(t *testing.T) { @@ -272,18 +233,12 @@ func TestInlineStruct(t *testing.T) { ev := reflect.ValueOf(em) fields := m.TypeMap(reflect.TypeOf(em)) - if len(fields.Index) != 6 { - t.Errorf("Expecting 6 fields") - } + assert.Len(t, fields.Index, 6) v := m.FieldByName(ev, "employee.name") - if v.Interface().(string) != em.Employee.Name { - t.Errorf("Expecting %s, got %s", em.Employee.Name, v.Interface().(string)) - } + assert.Equal(t, em.Employee.Name, v.Interface().(string)) v = m.FieldByName(ev, "boss.id") - if ival(v) != em.Boss.ID { - t.Errorf("Expecting %v, got %v", em.Boss.ID, ival(v)) - } + assert.Equal(t, em.Boss.ID, ival(v)) } func TestRecursiveStruct(t *testing.T) { @@ -327,65 +282,36 @@ func TestFieldsEmbedded(t *testing.T) { ppv := reflect.ValueOf(pp) v := m.FieldByName(ppv, "person.name") - if v.Interface().(string) != pp.Person.Name { - t.Errorf("Expecting %s, got %s", pp.Person.Name, v.Interface().(string)) - } + assert.Equal(t, pp.Person.Name, v.Interface().(string)) v = m.FieldByName(ppv, "name") - if v.Interface().(string) != pp.Place.Name { - t.Errorf("Expecting %s, got %s", pp.Place.Name, v.Interface().(string)) - } + assert.Equal(t, pp.Place.Name, v.Interface().(string)) v = m.FieldByName(ppv, "title") - if v.Interface().(string) != pp.Title { - t.Errorf("Expecting %s, got %s", pp.Title, v.Interface().(string)) - } + assert.Equal(t, pp.Title, v.Interface().(string)) fi := fields.GetByPath("person") - if _, ok := fi.Options["required"]; !ok { - t.Errorf("Expecting required option to be set") - } - if !fi.Embedded { - t.Errorf("Expecting field to be embedded") - } - if len(fi.Index) != 1 || fi.Index[0] != 0 { - t.Errorf("Expecting index to be [0]") - } + _, ok := fi.Options["required"] + assert.True(t, ok, "Expecting required option to be set") + assert.True(t, fi.Embedded, "Expecting field to be embedded") + assert.Equal(t, []int{0}, fi.Index) fi = fields.GetByPath("person.name") - if fi == nil { - t.Errorf("Expecting person.name to exist") - return - } - if fi.Path != "person.name" { - t.Errorf("Expecting %s, got %s", "person.name", fi.Path) - } - if fi.Options["size"] != "64" { - t.Errorf("Expecting %s, got %s", "64", fi.Options["size"]) - } + require.NotNil(t, fi, "Expecting person.name to exist") + assert.Equal(t, "person.name", fi.Path) + assert.Equal(t, "64", fi.Options["size"]) fi = fields.GetByTraversal([]int{1, 0}) - if fi == nil { - t.Errorf("Expecting traveral to exist") - return - } - if fi.Path != "name" { - t.Errorf("Expecting %s, got %s", "name", fi.Path) - } + require.NotNil(t, fi, "Expecting traversal to exist") + assert.Equal(t, "name", fi.Path) fi = fields.GetByTraversal([]int{2}) - if fi == nil { - t.Errorf("Expecting traversal to exist") - return - } - if _, ok := fi.Options["required"]; !ok { - t.Errorf("Expecting required option to be set") - } + require.NotNil(t, fi, "Expecting traversal to exist") + _, ok = fi.Options["required"] + assert.True(t, ok, "Expecting required option to be set") trs := m.TraversalsByName(reflect.TypeOf(pp), []string{"person.name", "name", "title"}) - if !reflect.DeepEqual(trs, [][]int{{0, 0}, {1, 0}, {2, 0}}) { - t.Errorf("Expecting traversal: %v", trs) - } + assert.Equal(t, [][]int{{0, 0}, {1, 0}, {2, 0}}, trs) } func TestPtrFields(t *testing.T) { @@ -402,18 +328,12 @@ func TestPtrFields(t *testing.T) { pv := reflect.ValueOf(post) fields := m.TypeMap(reflect.TypeOf(post)) - if len(fields.Index) != 3 { - t.Errorf("Expecting 3 fields") - } + assert.Len(t, fields.Index, 3) v := m.FieldByName(pv, "asset.title") - if v.Interface().(string) != post.Title { - t.Errorf("Expecting %s, got %s", post.Title, v.Interface().(string)) - } + assert.Equal(t, post.Title, v.Interface().(string)) v = m.FieldByName(pv, "author") - if v.Interface().(string) != post.Author { - t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) - } + assert.Equal(t, post.Author, v.Interface().(string)) } func TestNamedPtrFields(t *testing.T) { @@ -439,30 +359,18 @@ func TestNamedPtrFields(t *testing.T) { pv := reflect.ValueOf(post) fields := m.TypeMap(reflect.TypeOf(post)) - if len(fields.Index) != 9 { - t.Errorf("Expecting 9 fields") - } + assert.Len(t, fields.Index, 9) v := m.FieldByName(pv, "asset1.title") - if v.Interface().(string) != post.Asset1.Title { - t.Errorf("Expecting %s, got %s", post.Asset1.Title, v.Interface().(string)) - } + assert.Equal(t, post.Asset1.Title, v.Interface().(string)) v = m.FieldByName(pv, "asset1.owner.name") - if v.Interface().(string) != post.Asset1.Owner.Name { - t.Errorf("Expecting %s, got %s", post.Asset1.Owner.Name, v.Interface().(string)) - } + assert.Equal(t, post.Asset1.Owner.Name, v.Interface().(string)) v = m.FieldByName(pv, "asset2.title") - if v.Interface().(string) != post.Asset2.Title { - t.Errorf("Expecting %s, got %s", post.Asset2.Title, v.Interface().(string)) - } + assert.Equal(t, post.Asset2.Title, v.Interface().(string)) v = m.FieldByName(pv, "asset2.owner.name") - if v.Interface().(string) != post.Asset2.Owner.Name { - t.Errorf("Expecting %s, got %s", post.Asset2.Owner.Name, v.Interface().(string)) - } + assert.Equal(t, post.Asset2.Owner.Name, v.Interface().(string)) v = m.FieldByName(pv, "author") - if v.Interface().(string) != post.Author { - t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) - } + assert.Equal(t, post.Author, v.Interface().(string)) } func TestFieldMap(t *testing.T) { @@ -477,18 +385,10 @@ func TestFieldMap(t *testing.T) { fm := m.FieldMap(reflect.ValueOf(f)) - if len(fm) != 3 { - t.Errorf("Expecting %d keys, got %d", 3, len(fm)) - } - if fm["a"].Interface().(int) != 1 { - t.Errorf("Expecting %d, got %d", 1, ival(fm["a"])) - } - if fm["b"].Interface().(int) != 2 { - t.Errorf("Expecting %d, got %d", 2, ival(fm["b"])) - } - if fm["c"].Interface().(int) != 3 { - t.Errorf("Expecting %d, got %d", 3, ival(fm["c"])) - } + assert.Len(t, fm, 3) + assert.Equal(t, 1, fm["a"].Interface().(int)) + assert.Equal(t, 2, fm["b"].Interface().(int)) + assert.Equal(t, 3, fm["c"].Interface().(int)) } func TestTagNameMapping(t *testing.T) { @@ -507,9 +407,8 @@ func TestTagNameMapping(t *testing.T) { mapping := m.TypeMap(reflect.TypeOf(strategy)) for _, key := range []string{"strategy_id", "STRATEGYNAME"} { - if fi := mapping.GetByPath(key); fi == nil { - t.Errorf("Expecting to find key %s in mapping but did not.", key) - } + fi := mapping.GetByPath(key) + assert.NotNil(t, fi, "Expecting to find key %s in mapping", key) } } @@ -525,9 +424,8 @@ func TestMapping(t *testing.T) { mapping := m.TypeMap(reflect.TypeOf(p)) for _, key := range []string{"id", "name", "wears_glasses"} { - if fi := mapping.GetByPath(key); fi == nil { - t.Errorf("Expecting to find key %s in mapping but did not.", key) - } + fi := mapping.GetByPath(key) + assert.NotNil(t, fi, "Expecting to find key %s in mapping", key) } type SportsPerson struct { @@ -538,9 +436,8 @@ func TestMapping(t *testing.T) { s := SportsPerson{Weight: 100, Age: 30, Person: p} mapping = m.TypeMap(reflect.TypeOf(s)) for _, key := range []string{"id", "name", "wears_glasses", "weight", "age"} { - if fi := mapping.GetByPath(key); fi == nil { - t.Errorf("Expecting to find key %s in mapping but did not.", key) - } + fi := mapping.GetByPath(key) + assert.NotNil(t, fi, "Expecting to find key %s in mapping", key) } type RugbyPlayer struct { @@ -552,14 +449,12 @@ func TestMapping(t *testing.T) { r := RugbyPlayer{12, true, false, s} mapping = m.TypeMap(reflect.TypeOf(r)) for _, key := range []string{"id", "name", "wears_glasses", "weight", "age", "position", "is_intense"} { - if fi := mapping.GetByPath(key); fi == nil { - t.Errorf("Expecting to find key %s in mapping but did not.", key) - } + fi := mapping.GetByPath(key) + assert.NotNil(t, fi, "Expecting to find key %s in mapping", key) } - if fi := mapping.GetByPath("isallblack"); fi != nil { - t.Errorf("Expecting to ignore `IsAllBlack` field") - } + fi := mapping.GetByPath("isallblack") + assert.Nil(t, fi, "Expecting to ignore `IsAllBlack` field") } func TestGetByTraversal(t *testing.T) { @@ -613,20 +508,12 @@ func TestGetByTraversal(t *testing.T) { for i, tc := range testCases { fi := tm.GetByTraversal(tc.Index) if tc.ExpectNil { - if fi != nil { - t.Errorf("%d: expected nil, got %v", i, fi) - } + assert.Nil(t, fi, "%d: expected nil", i) continue } - if fi == nil { - t.Errorf("%d: expected %s, got nil", i, tc.ExpectedName) - continue - } - - if fi.Name != tc.ExpectedName { - t.Errorf("%d: expected %s, got %s", i, tc.ExpectedName, fi.Name) - } + require.NotNil(t, fi, "%d: expected %s, got nil", i, tc.ExpectedName) + assert.Equal(t, tc.ExpectedName, fi.Name, "%d: name mismatch", i) } } @@ -743,37 +630,21 @@ func TestMapperMethodsByName(t *testing.T) { m := NewMapperFunc("db", func(n string) string { return n }) v := reflect.ValueOf(val) values := m.FieldsByName(v, names) - if len(values) != len(testCases) { - t.Errorf("expected %d values, got %d", len(testCases), len(values)) - t.FailNow() - } + require.Len(t, values, len(testCases)) indexes := m.TraversalsByName(v.Type(), names) - if len(indexes) != len(testCases) { - t.Errorf("expected %d traversals, got %d", len(testCases), len(indexes)) - t.FailNow() - } + require.Len(t, indexes, len(testCases)) for i, val := range values { tc := testCases[i] traversal := indexes[i] - if !reflect.DeepEqual(tc.ExpectedIndexes, traversal) { - t.Errorf("expected %v, got %v", tc.ExpectedIndexes, traversal) - t.FailNow() - } + require.Equal(t, tc.ExpectedIndexes, traversal) val = reflect.Indirect(val) if tc.ExpectInvalid { - if val.IsValid() { - t.Errorf("%d: expected zero value, got %v", i, val) - } - continue - } - if !val.IsValid() { - t.Errorf("%d: expected valid value, got %v", i, val) + assert.False(t, val.IsValid(), "%d: expected zero value, got %v", i, val) continue } + require.True(t, val.IsValid(), "%d: expected valid value", i) actualValue := reflect.Indirect(val).Interface() - if !reflect.DeepEqual(tc.ExpectedValue, actualValue) { - t.Errorf("%d: expected %v, got %v", i, tc.ExpectedValue, actualValue) - } + assert.Equal(t, tc.ExpectedValue, actualValue, "%d: value mismatch", i) } } @@ -824,13 +695,9 @@ func TestFieldByIndexes(t *testing.T) { for i, tc := range testCases { checkResults := func(v reflect.Value) { if tc.expectedValue == nil { - if !v.IsNil() { - t.Errorf("%d: expected nil, actual %v", i, v.Interface()) - } + assert.True(t, v.IsNil(), "%d: expected nil", i) } else { - if !reflect.DeepEqual(tc.expectedValue, v.Interface()) { - t.Errorf("%d: expected %v, actual %v", i, tc.expectedValue, v.Interface()) - } + assert.Equal(t, tc.expectedValue, v.Interface(), "%d: value mismatch", i) } } @@ -846,19 +713,11 @@ func TestMustBe(t *testing.T) { mustBe(typ, reflect.Struct) defer func() { - if r := recover(); r != nil { - valueErr, ok := r.(*reflect.ValueError) - if !ok { - t.Errorf("unexpected Type: %T", r) - t.Error("expected panic with *reflect.ValueError") - return - } - if valueErr.Kind != reflect.String { - t.Errorf("unexpected Kind: %s", valueErr.Kind) - } - } else { - t.Error("expected panic") - } + r := recover() + require.NotNil(t, r, "expected panic") + valueErr, ok := r.(*reflect.ValueError) + require.True(t, ok, "expected panic with *reflect.ValueError, got %T", r) + assert.Equal(t, reflect.String, valueErr.Kind) }() typ = reflect.TypeOf("string") From 515735a91d60c025f50cf71cb6b8c8e8b9eec6ba Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Fri, 6 Feb 2026 14:57:15 +0100 Subject: [PATCH 16/19] fix: refactor paginator to use PaginatorSettings and update related tests --- page.go | 78 ++++++++++++++++++++++++++++++++-------------------- page_test.go | 50 ++++++++++++++------------------- 2 files changed, 68 insertions(+), 60 deletions(-) diff --git a/page.go b/page.go index 1832f11..7419b23 100644 --- a/page.go +++ b/page.go @@ -87,9 +87,9 @@ func NewPage(size, page uint32, sort ...Sort) *Page { } } -func (p *Page) SetDefaults(o *PaginatorOptions) { +func (p *Page) SetDefaults(o *PaginatorSettings) { if o == nil { - o = &PaginatorOptions{ + o = &PaginatorSettings{ DefaultSize: DefaultPageSize, MaxSize: MaxPageSize, } @@ -128,7 +128,7 @@ func (p *Page) GetOrder(defaultSort ...string) []Sort { } // use column sort := make([]Sort, 0) - for _, part := range strings.Split(p.Column, ",") { + for part := range strings.SplitSeq(p.Column, ",") { part = strings.TrimSpace(part) if part == "" { continue @@ -160,20 +160,8 @@ func (p *Page) Limit() uint64 { return n } -// NewPaginator creates a new paginator with the given options. -// If MaxSize is less than DefaultSize, MaxSize is set to DefaultSize. -func NewPaginator[T any](options *PaginatorOptions) Paginator[T] { - p := Paginator[T]{} - if options == nil { - options = &PaginatorOptions{} - } - options.SetDefaults() - p.PaginatorOptions = *options - return p -} - -// PaginatorOptions are the options for the paginator. -type PaginatorOptions struct { +// PaginatorSettings are the settings for the paginator. +type PaginatorSettings struct { // DefaultSize is the default number of rows per page. // If zero, DefaultPageSize is used. DefaultSize uint32 @@ -189,28 +177,59 @@ type PaginatorOptions struct { ColumnFunc func(string) string } -func (p *PaginatorOptions) SetDefaults() { - if p.DefaultSize == 0 { - p.DefaultSize = DefaultPageSize +type PaginatorOption func(*PaginatorSettings) + +func WithDefaultSize(size uint32) PaginatorOption { + return func(s *PaginatorSettings) { + s.DefaultSize = size + } +} + +func WithMaxSize(size uint32) PaginatorOption { + return func(s *PaginatorSettings) { + s.MaxSize = size + } +} + +func WithDefaultSort(sort ...string) PaginatorOption { + return func(s *PaginatorSettings) { + s.Sort = sort + } +} + +func WithColumnFunc(f func(string) string) PaginatorOption { + return func(s *PaginatorSettings) { + s.ColumnFunc = f + } +} + +// NewPaginator creates a new paginator with the given options. +// If MaxSize is less than DefaultSize, MaxSize is set to DefaultSize. +func NewPaginator[T any](options ...PaginatorOption) Paginator[T] { + settings := &PaginatorSettings{ + DefaultSize: DefaultPageSize, + MaxSize: MaxPageSize, + } + for _, option := range options { + option(settings) } - if p.MaxSize == 0 { - p.MaxSize = MaxPageSize + if settings.MaxSize < settings.DefaultSize { + settings.MaxSize = settings.DefaultSize } - p.MaxSize = max(p.DefaultSize, p.MaxSize) - p.DefaultSize = min(p.DefaultSize, p.MaxSize) + return Paginator[T]{settings: *settings} } // Paginator is a helper to paginate results. type Paginator[T any] struct { - PaginatorOptions + settings PaginatorSettings } func (p Paginator[T]) getOrder(page *Page) []string { - sort := page.GetOrder(p.Sort...) + sort := page.GetOrder(p.settings.Sort...) list := make([]string, len(sort)) for i, s := range sort { - if p.ColumnFunc != nil { - s.Column = p.ColumnFunc(s.Column) + if p.settings.ColumnFunc != nil { + s.Column = p.settings.ColumnFunc(s.Column) } list[i] = s.String() } @@ -219,11 +238,10 @@ func (p Paginator[T]) getOrder(page *Page) []string { // PrepareQuery adds pagination to the query. It sets the number of max rows to limit+1. func (p Paginator[T]) PrepareQuery(q sq.SelectBuilder, page *Page) ([]T, sq.SelectBuilder) { - p.SetDefaults() if page == nil { page = &Page{} } - page.SetDefaults(&p.PaginatorOptions) + page.SetDefaults(&p.settings) limit := page.Limit() q = q.Limit(limit + 1).Offset(page.Offset()).OrderBy(p.getOrder(page)...) diff --git a/page_test.go b/page_test.go index b59845a..050356d 100644 --- a/page_test.go +++ b/page_test.go @@ -12,17 +12,18 @@ import ( type T struct{} func TestPagination(t *testing.T) { - o := &pgkit.PaginatorOptions{ - ColumnFunc: strings.ToLower, - DefaultSize: 2, - MaxSize: 5, - Sort: []string{"ID"}, + const _MaxSize = 5 + options := []pgkit.PaginatorOption{ + pgkit.WithColumnFunc(strings.ToLower), + pgkit.WithDefaultSize(2), + pgkit.WithMaxSize(_MaxSize), + pgkit.WithDefaultSort("ID"), } - paginator := pgkit.NewPaginator[T](o) + paginator := pgkit.NewPaginator[T](options...) page := pgkit.NewPage(0, 0) result, query := paginator.PrepareQuery(sq.Select("*").From("t"), page) require.Len(t, result, 0) - require.Equal(t, &pgkit.Page{Page: 1, Size: o.MaxSize}, page) + require.Equal(t, &pgkit.Page{Page: 1, Size: _MaxSize}, page) sql, args, err := query.ToSql() require.NoError(t, err) @@ -31,15 +32,15 @@ func TestPagination(t *testing.T) { result = paginator.PrepareResult(make([]T, 0), page) require.Len(t, result, 0) - require.Equal(t, &pgkit.Page{Page: 1, Size: o.MaxSize}, page) + require.Equal(t, &pgkit.Page{Page: 1, Size: _MaxSize}, page) - result = paginator.PrepareResult(make([]T, o.MaxSize), page) - require.Len(t, result, int(o.MaxSize)) - require.Equal(t, &pgkit.Page{Page: 1, Size: o.MaxSize}, page) + result = paginator.PrepareResult(make([]T, _MaxSize), page) + require.Len(t, result, int(_MaxSize)) + require.Equal(t, &pgkit.Page{Page: 1, Size: _MaxSize}, page) - result = paginator.PrepareResult(make([]T, o.MaxSize+2), page) - require.Len(t, result, int(o.MaxSize)) - require.Equal(t, &pgkit.Page{Page: 1, Size: o.MaxSize, More: true}, page) + result = paginator.PrepareResult(make([]T, _MaxSize+2), page) + require.Len(t, result, int(_MaxSize)) + require.Equal(t, &pgkit.Page{Page: 1, Size: _MaxSize, More: true}, page) } func TestInvalidSort(t *testing.T) { @@ -125,7 +126,7 @@ func TestPaginationEdgeCases(t *testing.T) { require.Equal(t, "SELECT * FROM t LIMIT 11 OFFSET 0", sql2) // Test case 3: empty options, NewPage - paginator3 := pgkit.NewPaginator[T](&pgkit.PaginatorOptions{}) + paginator3 := pgkit.NewPaginator[T]() page3 := pgkit.NewPage(0, 0) result3, query3 := paginator3.PrepareQuery(sq.Select("*").From("t"), page3) require.Len(t, result3, 0) @@ -135,25 +136,14 @@ func TestPaginationEdgeCases(t *testing.T) { require.NoError(t, err3) require.Equal(t, "SELECT * FROM t LIMIT 11 OFFSET 0", sql3) - // Test case 4: options with defaults, struct assignment - paginator4 := pgkit.Paginator[T]{pgkit.PaginatorOptions{DefaultSize: 5, MaxSize: 20}} + // Test case 4: max size lower than default size + paginator4 := pgkit.NewPaginator[T](pgkit.WithDefaultSize(20), pgkit.WithMaxSize(5)) page4 := &pgkit.Page{} result4, query4 := paginator4.PrepareQuery(sq.Select("*").From("t"), page4) require.Len(t, result4, 0) - require.Equal(t, &pgkit.Page{Page: 1, Size: 5}, page4) + require.Equal(t, &pgkit.Page{Page: 1, Size: 20}, page4) sql4, _, err4 := query4.ToSql() require.NoError(t, err4) - require.Equal(t, "SELECT * FROM t LIMIT 6 OFFSET 0", sql4) - - // Test case 5: max size lower than default size - paginator5 := pgkit.NewPaginator[T](&pgkit.PaginatorOptions{DefaultSize: 20, MaxSize: 5}) - page5 := &pgkit.Page{} - result5, query5 := paginator5.PrepareQuery(sq.Select("*").From("t"), page5) - require.Len(t, result5, 0) - require.Equal(t, &pgkit.Page{Page: 1, Size: 20}, page5) - - sql5, _, err5 := query5.ToSql() - require.NoError(t, err5) - require.Equal(t, "SELECT * FROM t LIMIT 21 OFFSET 0", sql5) + require.Equal(t, "SELECT * FROM t LIMIT 21 OFFSET 0", sql4) } From bdb585c027daaeadf310ade572ed107fb79e717e Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Fri, 6 Feb 2026 14:59:58 +0100 Subject: [PATCH 17/19] fix: update paginator option functions for clarity and consistency --- page.go | 6 +++++- page_test.go | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/page.go b/page.go index 7419b23..58a6002 100644 --- a/page.go +++ b/page.go @@ -179,24 +179,28 @@ type PaginatorSettings struct { type PaginatorOption func(*PaginatorSettings) +// WithDefaultSize sets the default page size. func WithDefaultSize(size uint32) PaginatorOption { return func(s *PaginatorSettings) { s.DefaultSize = size } } +// WithMaxSize sets the maximum page size. func WithMaxSize(size uint32) PaginatorOption { return func(s *PaginatorSettings) { s.MaxSize = size } } -func WithDefaultSort(sort ...string) PaginatorOption { +// WithSort sets the default sort order. +func WithSort(sort ...string) PaginatorOption { return func(s *PaginatorSettings) { s.Sort = sort } } +// WithColumnFunc sets a function to transform column names. func WithColumnFunc(f func(string) string) PaginatorOption { return func(s *PaginatorSettings) { s.ColumnFunc = f diff --git a/page_test.go b/page_test.go index 050356d..6898240 100644 --- a/page_test.go +++ b/page_test.go @@ -17,7 +17,7 @@ func TestPagination(t *testing.T) { pgkit.WithColumnFunc(strings.ToLower), pgkit.WithDefaultSize(2), pgkit.WithMaxSize(_MaxSize), - pgkit.WithDefaultSort("ID"), + pgkit.WithSort("ID"), } paginator := pgkit.NewPaginator[T](options...) page := pgkit.NewPage(0, 0) From 3587061da46f9f477822f3aa130a42d91fb96212 Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Fri, 6 Feb 2026 15:04:15 +0100 Subject: [PATCH 18/19] fix: refactor pagination test to use constants for size and sort options --- page_test.go | 45 ++++++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/page_test.go b/page_test.go index 6898240..0aae21a 100644 --- a/page_test.go +++ b/page_test.go @@ -12,18 +12,21 @@ import ( type T struct{} func TestPagination(t *testing.T) { - const _MaxSize = 5 - options := []pgkit.PaginatorOption{ + const ( + DefaultSize = 2 + MaxSize = 5 + Sort = "ID" + ) + paginator := pgkit.NewPaginator[T]( pgkit.WithColumnFunc(strings.ToLower), - pgkit.WithDefaultSize(2), - pgkit.WithMaxSize(_MaxSize), - pgkit.WithSort("ID"), - } - paginator := pgkit.NewPaginator[T](options...) + pgkit.WithDefaultSize(DefaultSize), + pgkit.WithMaxSize(MaxSize), + pgkit.WithSort(Sort), + ) page := pgkit.NewPage(0, 0) result, query := paginator.PrepareQuery(sq.Select("*").From("t"), page) require.Len(t, result, 0) - require.Equal(t, &pgkit.Page{Page: 1, Size: _MaxSize}, page) + require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize}, page) sql, args, err := query.ToSql() require.NoError(t, err) @@ -32,19 +35,19 @@ func TestPagination(t *testing.T) { result = paginator.PrepareResult(make([]T, 0), page) require.Len(t, result, 0) - require.Equal(t, &pgkit.Page{Page: 1, Size: _MaxSize}, page) + require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize}, page) - result = paginator.PrepareResult(make([]T, _MaxSize), page) - require.Len(t, result, int(_MaxSize)) - require.Equal(t, &pgkit.Page{Page: 1, Size: _MaxSize}, page) + result = paginator.PrepareResult(make([]T, MaxSize), page) + require.Len(t, result, int(MaxSize)) + require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize}, page) - result = paginator.PrepareResult(make([]T, _MaxSize+2), page) - require.Len(t, result, int(_MaxSize)) - require.Equal(t, &pgkit.Page{Page: 1, Size: _MaxSize, More: true}, page) + result = paginator.PrepareResult(make([]T, MaxSize+2), page) + require.Len(t, result, int(MaxSize)) + require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize, More: true}, page) } func TestInvalidSort(t *testing.T) { - paginator := pgkit.NewPaginator[T](nil) + paginator := pgkit.NewPaginator[T]() page := pgkit.NewPage(0, 0) page.Sort = []pgkit.Sort{ {Column: "ID; DROP TABLE users;", Order: pgkit.Asc}, @@ -60,7 +63,7 @@ func TestInvalidSort(t *testing.T) { } func TestPageColumnInjection(t *testing.T) { - paginator := pgkit.NewPaginator[T](nil) + paginator := pgkit.NewPaginator[T]() page := pgkit.NewPage(0, 0) page.Column = "id; DROP TABLE users;--" @@ -73,7 +76,7 @@ func TestPageColumnInjection(t *testing.T) { } func TestPageColumnSpaces(t *testing.T) { - paginator := pgkit.NewPaginator[T](nil) + paginator := pgkit.NewPaginator[T]() page := pgkit.NewPage(0, 0) page.Column = "id, name" @@ -86,7 +89,7 @@ func TestPageColumnSpaces(t *testing.T) { } func TestSortOrderInjection(t *testing.T) { - paginator := pgkit.NewPaginator[T](nil) + paginator := pgkit.NewPaginator[T]() page := pgkit.NewPage(0, 0) page.Sort = []pgkit.Sort{ {Column: "id", Order: pgkit.Order("DESC; DROP TABLE users;--")}, @@ -104,7 +107,7 @@ func TestSortOrderInjection(t *testing.T) { func TestPaginationEdgeCases(t *testing.T) { // Test case 1: nil options, NewPage with zeros - paginator1 := pgkit.NewPaginator[T](nil) + paginator1 := pgkit.NewPaginator[T]() page1 := pgkit.NewPage(0, 0) result1, query1 := paginator1.PrepareQuery(sq.Select("*").From("t"), page1) require.Len(t, result1, 0) @@ -115,7 +118,7 @@ func TestPaginationEdgeCases(t *testing.T) { require.Equal(t, "SELECT * FROM t LIMIT 11 OFFSET 0", sql1) // Test case 2: nil options, empty struct assignment - paginator2 := pgkit.NewPaginator[T](nil) + paginator2 := pgkit.NewPaginator[T]() page2 := &pgkit.Page{} result2, query2 := paginator2.PrepareQuery(sq.Select("*").From("t"), page2) require.Len(t, result2, 0) From d50a9b541a3ca30217b233a5403c1116a5f0e38c Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Fri, 6 Feb 2026 15:06:16 +0100 Subject: [PATCH 19/19] fix: simplify length assertions in pagination test --- page_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/page_test.go b/page_test.go index 0aae21a..fa2e0ff 100644 --- a/page_test.go +++ b/page_test.go @@ -38,11 +38,11 @@ func TestPagination(t *testing.T) { require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize}, page) result = paginator.PrepareResult(make([]T, MaxSize), page) - require.Len(t, result, int(MaxSize)) + require.Len(t, result, MaxSize) require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize}, page) result = paginator.PrepareResult(make([]T, MaxSize+2), page) - require.Len(t, result, int(MaxSize)) + require.Len(t, result, MaxSize) require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize, More: true}, page) }