diff --git a/builder.go b/builder.go index 6ea3256..8e8bba6 100644 --- a/builder.go +++ b/builder.go @@ -78,9 +78,7 @@ func (s StatementBuilder) UpdateRecordColumns(record interface{}, whereExpr sq.E // when filter is empty or nil, update the entire record var filter []string - if filterCols == nil || len(filterCols) == 0 { - filter = nil - } else { + if len(filterCols) != 0 { filter = filterCols } @@ -126,7 +124,7 @@ func createMap(k []string, v []interface{}, filterK []string) (map[string]interf m := make(map[string]interface{}, len(k)) for i := 0; i < len(k); i++ { - if filterK == nil || len(filterK) == 0 { + if len(filterK) == 0 { m[k[i]] = v[i] continue } diff --git a/dbtype/bigint_test.go b/dbtype/bigint_test.go index 198d129..e64e871 100644 --- a/dbtype/bigint_test.go +++ b/dbtype/bigint_test.go @@ -57,8 +57,8 @@ func TestBigIntScan(t *testing.T) { assert.Error(t, b.Scan("1.")) assert.NoError(t, b.Scan("100")) - assert.True(t, b.Uint64() == 100) + assert.Equal(t, uint64(100), b.Uint64()) assert.NoError(t, b.Scan("2e0")) - assert.True(t, b.Uint64() == 2) + assert.Equal(t, uint64(2), b.Uint64()) } diff --git a/internal/reflectx/reflect.go b/internal/reflectx/reflect.go index 0b10994..91b5285 100644 --- a/internal/reflectx/reflect.go +++ b/internal/reflectx/reflect.go @@ -3,7 +3,6 @@ // allows for Go-compatible named attribute access, including accessing embedded // struct attributes and the ability to use functions and struct tags to // customize field names. -// package reflectx import ( @@ -167,7 +166,7 @@ func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value { // to a struct. Returns empty int slice for each name not found. func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int { r := make([][]int, 0, len(names)) - m.TraversalsByNameFunc(t, names, func(_ int, i []int) error { + _ = m.TraversalsByNameFunc(t, names, func(_ int, i []int) error { if i == nil { r = append(r, []int{}) } else { diff --git a/internal/reflectx/reflect_test.go b/internal/reflectx/reflect_test.go index e73af5b..1f23309 100644 --- a/internal/reflectx/reflect_test.go +++ b/internal/reflectx/reflect_test.go @@ -4,6 +4,9 @@ import ( "reflect" "strings" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func ival(v reflect.Value) int { @@ -22,17 +25,11 @@ func TestBasic(t *testing.T) { m := NewMapperFunc("", func(s string) string { return s }) v := m.FieldByName(fv, "A") - if ival(v) != f.A { - t.Errorf("Expecting %d, got %d", ival(v), f.A) - } + assert.Equal(t, f.A, ival(v)) v = m.FieldByName(fv, "B") - if ival(v) != f.B { - t.Errorf("Expecting %d, got %d", f.B, ival(v)) - } + assert.Equal(t, f.B, ival(v)) v = m.FieldByName(fv, "C") - if ival(v) != f.C { - t.Errorf("Expecting %d, got %d", f.C, ival(v)) - } + assert.Equal(t, f.C, ival(v)) } func TestBasicEmbedded(t *testing.T) { @@ -57,40 +54,29 @@ func TestBasicEmbedded(t *testing.T) { z.A = 1 z.B = 2 z.C = 4 - z.Bar.Foo.A = 3 + z.Foo.A = 3 zv := reflect.ValueOf(z) fields := m.TypeMap(reflect.TypeOf(z)) - if len(fields.Index) != 5 { - t.Errorf("Expecting 5 fields") - } + assert.Len(t, fields.Index, 5) // for _, fi := range fields.Index { // log.Println(fi) // } v := m.FieldByName(zv, "A") - if ival(v) != z.A { - t.Errorf("Expecting %d, got %d", z.A, ival(v)) - } + assert.Equal(t, z.A, ival(v)) v = m.FieldByName(zv, "Bar.B") - if ival(v) != z.Bar.B { - t.Errorf("Expecting %d, got %d", z.Bar.B, ival(v)) - } + assert.Equal(t, z.B, ival(v)) v = m.FieldByName(zv, "Bar.A") - if ival(v) != z.Bar.Foo.A { - t.Errorf("Expecting %d, got %d", z.Bar.Foo.A, ival(v)) - } + assert.Equal(t, z.Foo.A, ival(v)) v = m.FieldByName(zv, "Bar.C") - if _, ok := v.Interface().(int); ok { - t.Errorf("Expecting Bar.C to not exist") - } + _, ok := v.Interface().(int) + assert.False(t, ok, "Expecting Bar.C to not exist") fi := fields.GetByPath("Bar.C") - if fi != nil { - t.Errorf("Bar.C should not exist") - } + assert.Nil(t, fi, "Bar.C should not exist") } func TestEmbeddedSimple(t *testing.T) { @@ -127,27 +113,21 @@ func TestBasicEmbeddedWithTags(t *testing.T) { z := Baz{} z.A = 1 z.B = 2 - z.Bar.Foo.A = 3 + z.Foo.A = 3 zv := reflect.ValueOf(z) fields := m.TypeMap(reflect.TypeOf(z)) - if len(fields.Index) != 5 { - t.Errorf("Expecting 5 fields") - } + assert.Len(t, fields.Index, 5) // for _, fi := range fields.index { // log.Println(fi) // } v := m.FieldByName(zv, "a") - if ival(v) != z.A { // the dominant field - t.Errorf("Expecting %d, got %d", z.A, ival(v)) - } + assert.Equal(t, z.A, ival(v)) // the dominant field v = m.FieldByName(zv, "b") - if ival(v) != z.B { - t.Errorf("Expecting %d, got %d", z.B, ival(v)) - } + assert.Equal(t, z.B, ival(v)) } func TestBasicEmbeddedWithSameName(t *testing.T) { @@ -171,22 +151,14 @@ func TestBasicEmbeddedWithSameName(t *testing.T) { zv := reflect.ValueOf(z) fields := m.TypeMap(reflect.TypeOf(z)) - if len(fields.Index) != 4 { - t.Errorf("Expecting 3 fields, found %d", len(fields.Index)) - } + assert.Len(t, fields.Index, 4) v := m.FieldByName(zv, "a") - if ival(v) != z.A { // the dominant field - t.Errorf("Expecting %d, got %d", z.A, ival(v)) - } + assert.Equal(t, z.A, ival(v)) // the dominant field v = m.FieldByName(zv, "b") - if ival(v) != z.B { - t.Errorf("Expecting %d, got %d", z.B, ival(v)) - } + assert.Equal(t, z.B, ival(v)) v = m.FieldByName(zv, "Foo") - if ival(v) != z.Foo.Foo { - t.Errorf("Expecting %d, got %d", z.Foo.Foo, ival(v)) - } + assert.Equal(t, z.Foo.Foo, ival(v)) } func TestFlatTags(t *testing.T) { @@ -205,13 +177,9 @@ func TestFlatTags(t *testing.T) { pv := reflect.ValueOf(post) v := m.FieldByName(pv, "author") - if v.Interface().(string) != post.Author { - t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) - } + assert.Equal(t, post.Author, v.Interface().(string)) v = m.FieldByName(pv, "title") - if v.Interface().(string) != post.Asset.Title { - t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) - } + assert.Equal(t, post.Asset.Title, v.Interface().(string)) } func TestNestedStruct(t *testing.T) { @@ -237,21 +205,14 @@ func TestNestedStruct(t *testing.T) { pv := reflect.ValueOf(post) v := m.FieldByName(pv, "author") - if v.Interface().(string) != post.Author { - t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) - } + assert.Equal(t, post.Author, v.Interface().(string)) v = m.FieldByName(pv, "title") - if _, ok := v.Interface().(string); ok { - t.Errorf("Expecting field to not exist") - } + _, ok := v.Interface().(string) + assert.False(t, ok, "Expecting field to not exist") v = m.FieldByName(pv, "asset.title") - if v.Interface().(string) != post.Asset.Title { - t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) - } + assert.Equal(t, post.Title, v.Interface().(string)) v = m.FieldByName(pv, "asset.details.active") - if v.Interface().(bool) != post.Asset.Details.Active { - t.Errorf("Expecting %v, got %v", post.Asset.Details.Active, v.Interface().(bool)) - } + assert.Equal(t, post.Details.Active, v.Interface().(bool)) } func TestInlineStruct(t *testing.T) { @@ -272,18 +233,12 @@ func TestInlineStruct(t *testing.T) { ev := reflect.ValueOf(em) fields := m.TypeMap(reflect.TypeOf(em)) - if len(fields.Index) != 6 { - t.Errorf("Expecting 6 fields") - } + assert.Len(t, fields.Index, 6) v := m.FieldByName(ev, "employee.name") - if v.Interface().(string) != em.Employee.Name { - t.Errorf("Expecting %s, got %s", em.Employee.Name, v.Interface().(string)) - } + assert.Equal(t, em.Employee.Name, v.Interface().(string)) v = m.FieldByName(ev, "boss.id") - if ival(v) != em.Boss.ID { - t.Errorf("Expecting %v, got %v", em.Boss.ID, ival(v)) - } + assert.Equal(t, em.Boss.ID, ival(v)) } func TestRecursiveStruct(t *testing.T) { @@ -317,7 +272,7 @@ func TestFieldsEmbedded(t *testing.T) { pp := PP{} pp.Person.Name = "Peter" pp.Place.Name = "Toronto" - pp.Article.Title = "Best city ever" + pp.Title = "Best city ever" fields := m.TypeMap(reflect.TypeOf(pp)) // for i, f := range fields { @@ -327,62 +282,36 @@ func TestFieldsEmbedded(t *testing.T) { ppv := reflect.ValueOf(pp) v := m.FieldByName(ppv, "person.name") - if v.Interface().(string) != pp.Person.Name { - t.Errorf("Expecting %s, got %s", pp.Person.Name, v.Interface().(string)) - } + assert.Equal(t, pp.Person.Name, v.Interface().(string)) v = m.FieldByName(ppv, "name") - if v.Interface().(string) != pp.Place.Name { - t.Errorf("Expecting %s, got %s", pp.Place.Name, v.Interface().(string)) - } + assert.Equal(t, pp.Place.Name, v.Interface().(string)) v = m.FieldByName(ppv, "title") - if v.Interface().(string) != pp.Article.Title { - t.Errorf("Expecting %s, got %s", pp.Article.Title, v.Interface().(string)) - } + assert.Equal(t, pp.Title, v.Interface().(string)) fi := fields.GetByPath("person") - if _, ok := fi.Options["required"]; !ok { - t.Errorf("Expecting required option to be set") - } - if !fi.Embedded { - t.Errorf("Expecting field to be embedded") - } - if len(fi.Index) != 1 || fi.Index[0] != 0 { - t.Errorf("Expecting index to be [0]") - } + _, ok := fi.Options["required"] + assert.True(t, ok, "Expecting required option to be set") + assert.True(t, fi.Embedded, "Expecting field to be embedded") + assert.Equal(t, []int{0}, fi.Index) fi = fields.GetByPath("person.name") - if fi == nil { - t.Errorf("Expecting person.name to exist") - } - if fi.Path != "person.name" { - t.Errorf("Expecting %s, got %s", "person.name", fi.Path) - } - if fi.Options["size"] != "64" { - t.Errorf("Expecting %s, got %s", "64", fi.Options["size"]) - } + require.NotNil(t, fi, "Expecting person.name to exist") + assert.Equal(t, "person.name", fi.Path) + assert.Equal(t, "64", fi.Options["size"]) fi = fields.GetByTraversal([]int{1, 0}) - if fi == nil { - t.Errorf("Expecting traveral to exist") - } - if fi.Path != "name" { - t.Errorf("Expecting %s, got %s", "name", fi.Path) - } + require.NotNil(t, fi, "Expecting traversal to exist") + assert.Equal(t, "name", fi.Path) fi = fields.GetByTraversal([]int{2}) - if fi == nil { - t.Errorf("Expecting traversal to exist") - } - if _, ok := fi.Options["required"]; !ok { - t.Errorf("Expecting required option to be set") - } + require.NotNil(t, fi, "Expecting traversal to exist") + _, ok = fi.Options["required"] + assert.True(t, ok, "Expecting required option to be set") trs := m.TraversalsByName(reflect.TypeOf(pp), []string{"person.name", "name", "title"}) - if !reflect.DeepEqual(trs, [][]int{{0, 0}, {1, 0}, {2, 0}}) { - t.Errorf("Expecting traversal: %v", trs) - } + assert.Equal(t, [][]int{{0, 0}, {1, 0}, {2, 0}}, trs) } func TestPtrFields(t *testing.T) { @@ -399,18 +328,12 @@ func TestPtrFields(t *testing.T) { pv := reflect.ValueOf(post) fields := m.TypeMap(reflect.TypeOf(post)) - if len(fields.Index) != 3 { - t.Errorf("Expecting 3 fields") - } + assert.Len(t, fields.Index, 3) v := m.FieldByName(pv, "asset.title") - if v.Interface().(string) != post.Asset.Title { - t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) - } + assert.Equal(t, post.Title, v.Interface().(string)) v = m.FieldByName(pv, "author") - if v.Interface().(string) != post.Author { - t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) - } + assert.Equal(t, post.Author, v.Interface().(string)) } func TestNamedPtrFields(t *testing.T) { @@ -436,30 +359,18 @@ func TestNamedPtrFields(t *testing.T) { pv := reflect.ValueOf(post) fields := m.TypeMap(reflect.TypeOf(post)) - if len(fields.Index) != 9 { - t.Errorf("Expecting 9 fields") - } + assert.Len(t, fields.Index, 9) v := m.FieldByName(pv, "asset1.title") - if v.Interface().(string) != post.Asset1.Title { - t.Errorf("Expecting %s, got %s", post.Asset1.Title, v.Interface().(string)) - } + assert.Equal(t, post.Asset1.Title, v.Interface().(string)) v = m.FieldByName(pv, "asset1.owner.name") - if v.Interface().(string) != post.Asset1.Owner.Name { - t.Errorf("Expecting %s, got %s", post.Asset1.Owner.Name, v.Interface().(string)) - } + assert.Equal(t, post.Asset1.Owner.Name, v.Interface().(string)) v = m.FieldByName(pv, "asset2.title") - if v.Interface().(string) != post.Asset2.Title { - t.Errorf("Expecting %s, got %s", post.Asset2.Title, v.Interface().(string)) - } + assert.Equal(t, post.Asset2.Title, v.Interface().(string)) v = m.FieldByName(pv, "asset2.owner.name") - if v.Interface().(string) != post.Asset2.Owner.Name { - t.Errorf("Expecting %s, got %s", post.Asset2.Owner.Name, v.Interface().(string)) - } + assert.Equal(t, post.Asset2.Owner.Name, v.Interface().(string)) v = m.FieldByName(pv, "author") - if v.Interface().(string) != post.Author { - t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) - } + assert.Equal(t, post.Author, v.Interface().(string)) } func TestFieldMap(t *testing.T) { @@ -474,18 +385,10 @@ func TestFieldMap(t *testing.T) { fm := m.FieldMap(reflect.ValueOf(f)) - if len(fm) != 3 { - t.Errorf("Expecting %d keys, got %d", 3, len(fm)) - } - if fm["a"].Interface().(int) != 1 { - t.Errorf("Expecting %d, got %d", 1, ival(fm["a"])) - } - if fm["b"].Interface().(int) != 2 { - t.Errorf("Expecting %d, got %d", 2, ival(fm["b"])) - } - if fm["c"].Interface().(int) != 3 { - t.Errorf("Expecting %d, got %d", 3, ival(fm["c"])) - } + assert.Len(t, fm, 3) + assert.Equal(t, 1, fm["a"].Interface().(int)) + assert.Equal(t, 2, fm["b"].Interface().(int)) + assert.Equal(t, 3, fm["c"].Interface().(int)) } func TestTagNameMapping(t *testing.T) { @@ -504,9 +407,8 @@ func TestTagNameMapping(t *testing.T) { mapping := m.TypeMap(reflect.TypeOf(strategy)) for _, key := range []string{"strategy_id", "STRATEGYNAME"} { - if fi := mapping.GetByPath(key); fi == nil { - t.Errorf("Expecting to find key %s in mapping but did not.", key) - } + fi := mapping.GetByPath(key) + assert.NotNil(t, fi, "Expecting to find key %s in mapping", key) } } @@ -522,9 +424,8 @@ func TestMapping(t *testing.T) { mapping := m.TypeMap(reflect.TypeOf(p)) for _, key := range []string{"id", "name", "wears_glasses"} { - if fi := mapping.GetByPath(key); fi == nil { - t.Errorf("Expecting to find key %s in mapping but did not.", key) - } + fi := mapping.GetByPath(key) + assert.NotNil(t, fi, "Expecting to find key %s in mapping", key) } type SportsPerson struct { @@ -535,9 +436,8 @@ func TestMapping(t *testing.T) { s := SportsPerson{Weight: 100, Age: 30, Person: p} mapping = m.TypeMap(reflect.TypeOf(s)) for _, key := range []string{"id", "name", "wears_glasses", "weight", "age"} { - if fi := mapping.GetByPath(key); fi == nil { - t.Errorf("Expecting to find key %s in mapping but did not.", key) - } + fi := mapping.GetByPath(key) + assert.NotNil(t, fi, "Expecting to find key %s in mapping", key) } type RugbyPlayer struct { @@ -549,14 +449,12 @@ func TestMapping(t *testing.T) { r := RugbyPlayer{12, true, false, s} mapping = m.TypeMap(reflect.TypeOf(r)) for _, key := range []string{"id", "name", "wears_glasses", "weight", "age", "position", "is_intense"} { - if fi := mapping.GetByPath(key); fi == nil { - t.Errorf("Expecting to find key %s in mapping but did not.", key) - } + fi := mapping.GetByPath(key) + assert.NotNil(t, fi, "Expecting to find key %s in mapping", key) } - if fi := mapping.GetByPath("isallblack"); fi != nil { - t.Errorf("Expecting to ignore `IsAllBlack` field") - } + fi := mapping.GetByPath("isallblack") + assert.Nil(t, fi, "Expecting to ignore `IsAllBlack` field") } func TestGetByTraversal(t *testing.T) { @@ -610,20 +508,12 @@ func TestGetByTraversal(t *testing.T) { for i, tc := range testCases { fi := tm.GetByTraversal(tc.Index) if tc.ExpectNil { - if fi != nil { - t.Errorf("%d: expected nil, got %v", i, fi) - } + assert.Nil(t, fi, "%d: expected nil", i) continue } - if fi == nil { - t.Errorf("%d: expected %s, got nil", i, tc.ExpectedName) - continue - } - - if fi.Name != tc.ExpectedName { - t.Errorf("%d: expected %s, got %s", i, tc.ExpectedName, fi.Name) - } + require.NotNil(t, fi, "%d: expected %s, got nil", i, tc.ExpectedName) + assert.Equal(t, tc.ExpectedName, fi.Name, "%d: name mismatch", i) } } @@ -642,7 +532,7 @@ func TestMapperMethodsByName(t *testing.T) { A0 *B `db:"A0"` B `db:"A1"` A2 int - a3 int + a3 int //nolint } val := &A{ @@ -740,37 +630,21 @@ func TestMapperMethodsByName(t *testing.T) { m := NewMapperFunc("db", func(n string) string { return n }) v := reflect.ValueOf(val) values := m.FieldsByName(v, names) - if len(values) != len(testCases) { - t.Errorf("expected %d values, got %d", len(testCases), len(values)) - t.FailNow() - } + require.Len(t, values, len(testCases)) indexes := m.TraversalsByName(v.Type(), names) - if len(indexes) != len(testCases) { - t.Errorf("expected %d traversals, got %d", len(testCases), len(indexes)) - t.FailNow() - } + require.Len(t, indexes, len(testCases)) for i, val := range values { tc := testCases[i] traversal := indexes[i] - if !reflect.DeepEqual(tc.ExpectedIndexes, traversal) { - t.Errorf("expected %v, got %v", tc.ExpectedIndexes, traversal) - t.FailNow() - } + require.Equal(t, tc.ExpectedIndexes, traversal) val = reflect.Indirect(val) if tc.ExpectInvalid { - if val.IsValid() { - t.Errorf("%d: expected zero value, got %v", i, val) - } - continue - } - if !val.IsValid() { - t.Errorf("%d: expected valid value, got %v", i, val) + assert.False(t, val.IsValid(), "%d: expected zero value, got %v", i, val) continue } + require.True(t, val.IsValid(), "%d: expected valid value", i) actualValue := reflect.Indirect(val).Interface() - if !reflect.DeepEqual(tc.ExpectedValue, actualValue) { - t.Errorf("%d: expected %v, got %v", i, tc.ExpectedValue, actualValue) - } + assert.Equal(t, tc.ExpectedValue, actualValue, "%d: value mismatch", i) } } @@ -821,13 +695,9 @@ func TestFieldByIndexes(t *testing.T) { for i, tc := range testCases { checkResults := func(v reflect.Value) { if tc.expectedValue == nil { - if !v.IsNil() { - t.Errorf("%d: expected nil, actual %v", i, v.Interface()) - } + assert.True(t, v.IsNil(), "%d: expected nil", i) } else { - if !reflect.DeepEqual(tc.expectedValue, v.Interface()) { - t.Errorf("%d: expected %v, actual %v", i, tc.expectedValue, v.Interface()) - } + assert.Equal(t, tc.expectedValue, v.Interface(), "%d: value mismatch", i) } } @@ -843,21 +713,11 @@ func TestMustBe(t *testing.T) { mustBe(typ, reflect.Struct) defer func() { - if r := recover(); r != nil { - valueErr, ok := r.(*reflect.ValueError) - if !ok { - t.Errorf("unexpected Method: %s", valueErr.Method) - t.Error("expected panic with *reflect.ValueError") - return - } - if valueErr.Method != "github.com/jmoiron/sqlx/reflectx.TestMustBe" { - } - if valueErr.Kind != reflect.String { - t.Errorf("unexpected Kind: %s", valueErr.Kind) - } - } else { - t.Error("expected panic") - } + r := recover() + require.NotNil(t, r, "expected panic") + valueErr, ok := r.(*reflect.ValueError) + require.True(t, ok, "expected panic with *reflect.ValueError, got %T", r) + assert.Equal(t, reflect.String, valueErr.Kind) }() typ = reflect.TypeOf("string") diff --git a/mapper.go b/mapper.go index 29c7745..ca15e08 100644 --- a/mapper.go +++ b/mapper.go @@ -80,7 +80,7 @@ func MapWithOptions(record interface{}, options *MapOptions) ([]string, []interf for _, fi := range fieldMap { // Skip any fields which do not specify the `db:".."` tag - if strings.Index(string(fi.Field.Tag), dbTagPrefix) < 0 { + if !strings.Contains(string(fi.Field.Tag), dbTagPrefix) { continue } diff --git a/page.go b/page.go index 86f9efc..58a6002 100644 --- a/page.go +++ b/page.go @@ -87,6 +87,24 @@ func NewPage(size, page uint32, sort ...Sort) *Page { } } +func (p *Page) SetDefaults(o *PaginatorSettings) { + if o == nil { + o = &PaginatorSettings{ + DefaultSize: DefaultPageSize, + MaxSize: MaxPageSize, + } + } + if p.Size == 0 { + p.Size = o.DefaultSize + } + if p.Size > o.MaxSize { + p.Size = o.MaxSize + } + if p.Page == 0 { + p.Page = 1 + } +} + func (p *Page) GetOrder(defaultSort ...string) []Sort { // if page has sort, use it if p != nil && len(p.Sort) != 0 { @@ -110,7 +128,7 @@ func (p *Page) GetOrder(defaultSort ...string) []Sort { } // use column sort := make([]Sort, 0) - for _, part := range strings.Split(p.Column, ",") { + for part := range strings.SplitSeq(p.Column, ",") { part = strings.TrimSpace(part) if part == "" { continue @@ -135,66 +153,87 @@ func (p *Page) Offset() uint64 { } func (p *Page) Limit() uint64 { - var n = uint64(DefaultPageSize) + n := uint64(DefaultPageSize) if p != nil && p.Size != 0 { n = uint64(p.Size) } - if n > MaxPageSize { - n = MaxPageSize - } return n } -// PaginatorOption is a function that sets an option on a paginator. -type PaginatorOption[T any] func(*Paginator[T]) +// PaginatorSettings are the settings for the paginator. +type PaginatorSettings struct { + // DefaultSize is the default number of rows per page. + // If zero, DefaultPageSize is used. + DefaultSize uint32 + + // MaxSize is the maximum number of rows per page. + // If zero, MaxPageSize is used. If less than DefaultSize, it is set to DefaultSize. + MaxSize uint32 + + // Sort is the default sort order. + Sort []string + + // ColumnFunc is a transformation applied to column names. + ColumnFunc func(string) string +} + +type PaginatorOption func(*PaginatorSettings) // WithDefaultSize sets the default page size. -func WithDefaultSize[T any](size uint32) PaginatorOption[T] { - return func(p *Paginator[T]) { p.defaultSize = size } +func WithDefaultSize(size uint32) PaginatorOption { + return func(s *PaginatorSettings) { + s.DefaultSize = size + } } // WithMaxSize sets the maximum page size. -func WithMaxSize[T any](size uint32) PaginatorOption[T] { - return func(p *Paginator[T]) { p.maxSize = size } +func WithMaxSize(size uint32) PaginatorOption { + return func(s *PaginatorSettings) { + s.MaxSize = size + } } // WithSort sets the default sort order. -func WithSort[T any](sort ...string) PaginatorOption[T] { - return func(p *Paginator[T]) { p.defaultSort = sort } +func WithSort(sort ...string) PaginatorOption { + return func(s *PaginatorSettings) { + s.Sort = sort + } } // WithColumnFunc sets a function to transform column names. -func WithColumnFunc[T any](f func(string) string) PaginatorOption[T] { - return func(p *Paginator[T]) { p.columnFunc = f } +func WithColumnFunc(f func(string) string) PaginatorOption { + return func(s *PaginatorSettings) { + s.ColumnFunc = f + } } // NewPaginator creates a new paginator with the given options. -// Default page size is 10 and max size is 50. -func NewPaginator[T any](options ...PaginatorOption[T]) Paginator[T] { - p := Paginator[T]{ - defaultSize: DefaultPageSize, - maxSize: MaxPageSize, +// If MaxSize is less than DefaultSize, MaxSize is set to DefaultSize. +func NewPaginator[T any](options ...PaginatorOption) Paginator[T] { + settings := &PaginatorSettings{ + DefaultSize: DefaultPageSize, + MaxSize: MaxPageSize, + } + for _, option := range options { + option(settings) } - for _, opt := range options { - opt(&p) + if settings.MaxSize < settings.DefaultSize { + settings.MaxSize = settings.DefaultSize } - return p + return Paginator[T]{settings: *settings} } // Paginator is a helper to paginate results. type Paginator[T any] struct { - defaultSize uint32 - maxSize uint32 - defaultSort []string - columnFunc func(string) string + settings PaginatorSettings } func (p Paginator[T]) getOrder(page *Page) []string { - sort := page.GetOrder(p.defaultSort...) + sort := page.GetOrder(p.settings.Sort...) list := make([]string, len(sort)) for i, s := range sort { - if p.columnFunc != nil { - s.Column = p.columnFunc(s.Column) + if p.settings.ColumnFunc != nil { + s.Column = p.settings.ColumnFunc(s.Column) } list[i] = s.String() } @@ -203,16 +242,13 @@ func (p Paginator[T]) getOrder(page *Page) []string { // PrepareQuery adds pagination to the query. It sets the number of max rows to limit+1. func (p Paginator[T]) PrepareQuery(q sq.SelectBuilder, page *Page) ([]T, sq.SelectBuilder) { - if page != nil { - if page.Size == 0 { - page.Size = p.defaultSize - } - if page.Size > p.maxSize { - page.Size = p.maxSize - } + if page == nil { + page = &Page{} } + page.SetDefaults(&p.settings) + limit := page.Limit() - q = q.Limit(page.Limit() + 1).Offset(page.Offset()).OrderBy(p.getOrder(page)...) + q = q.Limit(limit + 1).Offset(page.Offset()).OrderBy(p.getOrder(page)...) return make([]T, 0, limit+1), q } diff --git a/page_test.go b/page_test.go index bdad83f..fa2e0ff 100644 --- a/page_test.go +++ b/page_test.go @@ -17,11 +17,11 @@ func TestPagination(t *testing.T) { MaxSize = 5 Sort = "ID" ) - paginator := pgkit.NewPaginator( - pgkit.WithColumnFunc[T](strings.ToLower), - pgkit.WithDefaultSize[T](DefaultSize), - pgkit.WithMaxSize[T](MaxSize), - pgkit.WithSort[T](Sort), + paginator := pgkit.NewPaginator[T]( + pgkit.WithColumnFunc(strings.ToLower), + pgkit.WithDefaultSize(DefaultSize), + pgkit.WithMaxSize(MaxSize), + pgkit.WithSort(Sort), ) page := pgkit.NewPage(0, 0) result, query := paginator.PrepareQuery(sq.Select("*").From("t"), page) @@ -104,3 +104,49 @@ func TestSortOrderInjection(t *testing.T) { require.Equal(t, "SELECT * FROM t ORDER BY \"id\" ASC, \"name\" DESC, \"created_at\" ASC LIMIT 11 OFFSET 0", sql) require.Empty(t, args) } + +func TestPaginationEdgeCases(t *testing.T) { + // Test case 1: nil options, NewPage with zeros + paginator1 := pgkit.NewPaginator[T]() + page1 := pgkit.NewPage(0, 0) + result1, query1 := paginator1.PrepareQuery(sq.Select("*").From("t"), page1) + require.Len(t, result1, 0) + require.Equal(t, &pgkit.Page{Page: 1, Size: 10}, page1) + + sql1, _, err1 := query1.ToSql() + require.NoError(t, err1) + require.Equal(t, "SELECT * FROM t LIMIT 11 OFFSET 0", sql1) + + // Test case 2: nil options, empty struct assignment + paginator2 := pgkit.NewPaginator[T]() + page2 := &pgkit.Page{} + result2, query2 := paginator2.PrepareQuery(sq.Select("*").From("t"), page2) + require.Len(t, result2, 0) + require.Equal(t, &pgkit.Page{Page: 1, Size: 10}, page2) + + sql2, _, err2 := query2.ToSql() + require.NoError(t, err2) + require.Equal(t, "SELECT * FROM t LIMIT 11 OFFSET 0", sql2) + + // Test case 3: empty options, NewPage + paginator3 := pgkit.NewPaginator[T]() + page3 := pgkit.NewPage(0, 0) + result3, query3 := paginator3.PrepareQuery(sq.Select("*").From("t"), page3) + require.Len(t, result3, 0) + require.Equal(t, &pgkit.Page{Page: 1, Size: 10}, page3) + + sql3, _, err3 := query3.ToSql() + require.NoError(t, err3) + require.Equal(t, "SELECT * FROM t LIMIT 11 OFFSET 0", sql3) + + // Test case 4: max size lower than default size + paginator4 := pgkit.NewPaginator[T](pgkit.WithDefaultSize(20), pgkit.WithMaxSize(5)) + page4 := &pgkit.Page{} + result4, query4 := paginator4.PrepareQuery(sq.Select("*").From("t"), page4) + require.Len(t, result4, 0) + require.Equal(t, &pgkit.Page{Page: 1, Size: 20}, page4) + + sql4, _, err4 := query4.ToSql() + require.NoError(t, err4) + require.Equal(t, "SELECT * FROM t LIMIT 21 OFFSET 0", sql4) +} diff --git a/querier.go b/querier.go index 294d8d2..548ced7 100644 --- a/querier.go +++ b/querier.go @@ -132,13 +132,15 @@ func (q *Querier) BatchExec(ctx context.Context, queries Queries) ([]pgconn.Comm } // Send batch - var results pgx.BatchResults + type batchSender interface { + SendBatch(context.Context, *pgx.Batch) pgx.BatchResults + } + batcher := batchSender(q.pool) if q.Tx != nil { - results = q.Tx.SendBatch(ctx, batch) - } else { - results = q.pool.SendBatch(ctx, batch) + batcher = batchSender(q.Tx) } - defer results.Close() + results := batcher.SendBatch(ctx, batch) + defer results.Close() //nolint:errcheck // Exec the number of times as we have queries in the batch so we may get the exec // result and potential error response. @@ -291,7 +293,7 @@ func (r RawSQL) ToSql() (string, []interface{}, error) { return r.Query, r.Args, nil } - if r.Args == nil || len(r.Args) == 0 { + if len(r.Args) == 0 { return r.Query, r.Args, nil // assume no params passed } diff --git a/tests/helpers_test.go b/tests/helpers_test.go index dbddd16..f25ced6 100644 --- a/tests/helpers_test.go +++ b/tests/helpers_test.go @@ -4,37 +4,11 @@ import ( "context" "fmt" "testing" - "time" "github.com/stretchr/testify/assert" ) -func truncateAllTables(t *testing.T) { - truncateTable(t, "accounts") - truncateTable(t, "reviews") - truncateTable(t, "logs") - truncateTable(t, "stats") - truncateTable(t, "articles") -} - func truncateTable(t *testing.T, tableName string) { _, err := DB.Conn.Exec(context.Background(), fmt.Sprintf(`TRUNCATE TABLE %q CASCADE`, tableName)) assert.NoError(t, err) } - -func measureCall(fn func() error) (time.Duration, error) { - t0 := time.Now() - err := fn() - return time.Since(t0), err -} - -func measureCalls(n int, fn func() error) (time.Duration, error) { - t0 := time.Now() - for i := 0; i < n; i++ { - err := fn() - if err != nil { - return time.Since(t0), err - } - } - return time.Since(t0), nil -} diff --git a/tests/pgkit_test.go b/tests/pgkit_test.go index df65d4c..281f387 100644 --- a/tests/pgkit_test.go +++ b/tests/pgkit_test.go @@ -4,9 +4,7 @@ import ( "bufio" "bytes" "context" - "encoding/hex" "encoding/json" - "errors" "fmt" "io" "log" @@ -68,7 +66,7 @@ func TestSugarInsertAndSelectRows(t *testing.T) { require.NoError(t, err) require.Len(t, accounts, 1) - assert.True(t, accounts[0].ID != 0) + assert.NotZero(t, accounts[0].ID) assert.Equal(t, "peter", accounts[0].Name) } @@ -91,7 +89,7 @@ func TestInsertAndSelectRows(t *testing.T) { require.NoError(t, err) require.Len(t, accounts, 1) - assert.True(t, accounts[0].ID != 0) + assert.NotZero(t, accounts[0].ID) assert.Equal(t, "peter", accounts[0].Name) } @@ -114,7 +112,7 @@ func TestSugarInsertAndSelectRecords(t *testing.T) { err = DB.Query.GetAll(context.Background(), q2, &accounts) assert.NoError(t, err) assert.Len(t, accounts, 1) - assert.True(t, accounts[0].ID != 0) + assert.NotZero(t, accounts[0].ID) assert.Equal(t, "joe", accounts[0].Name) // Select one -- into object @@ -122,7 +120,7 @@ func TestSugarInsertAndSelectRecords(t *testing.T) { err = DB.Query.GetOne(context.Background(), q2, account) assert.NoError(t, err) assert.Len(t, accounts, 1) - assert.True(t, accounts[0].ID != 0) + assert.NotZero(t, accounts[0].ID) assert.Equal(t, "joe", accounts[0].Name) // Select one -- into struct value @@ -130,7 +128,7 @@ func TestSugarInsertAndSelectRecords(t *testing.T) { err = DB.Query.GetOne(context.Background(), q2, &accountv) assert.NoError(t, err) assert.Len(t, accounts, 1) - assert.True(t, accounts[0].ID != 0) + assert.NotZero(t, accounts[0].ID) assert.Equal(t, "joe", accounts[0].Name) } @@ -161,7 +159,7 @@ func TestSugarInsertAndSelectRecordsReturningID(t *testing.T) { err := DB.Query.QueryRow(context.Background(), DB.SQL.InsertRecord(rec).Suffix(`RETURNING "id"`)).Scan(&rec.ID) assert.NoError(t, err) - assert.True(t, rec.ID > 0) + assert.NotZero(t, rec.ID) // Select one -- into object account := &Account{} @@ -198,15 +196,15 @@ func TestInsertAndSelectRecords(t *testing.T) { // Build select query rows, err := DB.Conn.Query(context.Background(), selectq, args...) - defer rows.Close() assert.NoError(t, err) + defer rows.Close() // Scan result into *Account object a := &Account{} err = DB.Query.Scan.ScanOne(a, rows) assert.NoError(t, err) - assert.True(t, a.ID != 0) + assert.NotZero(t, a.ID) assert.Equal(t, "joe", a.Name) assert.Equal(t, true, a.Disabled) @@ -225,8 +223,7 @@ func TestSugarQueryWithNoResults(t *testing.T) { var account interface{} err := DB.Query.GetOne(context.Background(), q, &account) - assert.True(t, errors.Is(err, pgkit.ErrNoRows)) - assert.True(t, errors.Is(err, pgx.ErrNoRows)) + assert.ErrorIs(t, err, pgkit.ErrNoRows) } func TestQueryWithNoResults(t *testing.T) { @@ -245,8 +242,8 @@ func TestQueryWithNoResults(t *testing.T) { // or, with more verbose method: { rows, err := DB.Conn.Query(context.Background(), selectq, args...) - defer rows.Close() assert.NoError(t, err) + defer rows.Close() err = DB.Query.Scan.ScanAll(&accounts, rows) @@ -258,7 +255,7 @@ func TestQueryWithNoResults(t *testing.T) { { var a *Account err = DB.Query.Scan.Get(context.Background(), DB.Conn, a, selectq, args...) - assert.True(t, errors.Is(err, pgx.ErrNoRows)) + assert.ErrorIs(t, err, pgkit.ErrNoRows) } } @@ -321,7 +318,7 @@ func TestRecordsWithJSONStruct(t *testing.T) { // Assert record mapping for nested jsonb struct cols, _, err := pgkit.Map(article) assert.NoError(t, err) - sort.Sort(sort.StringSlice(cols)) + sort.Strings(cols) assert.Equal(t, []string{"alias", "author", "content"}, cols) // Insert record @@ -355,7 +352,7 @@ func TestRowsWithBigInt(t *testing.T) { err = DB.Query.GetOne(context.Background(), q2, &sout) assert.NoError(t, err) assert.Equal(t, "count", sout.Key) - assert.True(t, sout.Num.Int64() == 2) + assert.Equal(t, int64(2), sout.Num.Int64()) assert.True(t, sout.Num.IsValid) assert.False(t, sout.Rating.IsValid) @@ -376,7 +373,7 @@ func TestRowsWithBigInt(t *testing.T) { err = DB.Query.GetOne(context.Background(), q2, &sout) assert.NoError(t, err) assert.Equal(t, "count2", sout.Key) - assert.True(t, sout.Num.String() == "12323942398472837489234") + assert.Equal(t, "12323942398472837489234", sout.Num.String()) assert.True(t, sout.Num.IsValid) assert.False(t, sout.Rating.IsValid) @@ -400,9 +397,9 @@ func TestRowsWithBigInt(t *testing.T) { err = DB.Query.GetOne(context.Background(), q2, &sout) assert.NoError(t, err) assert.Equal(t, "count3", sout.Key) - assert.True(t, sout.Num.String() == "44") + assert.Equal(t, "44", sout.Num.String()) assert.True(t, sout.Num.IsValid) - assert.True(t, sout.Rating.String() == "5") + assert.Equal(t, "5", sout.Rating.String()) assert.True(t, sout.Rating.IsValid) } @@ -464,7 +461,7 @@ func TestSugarUpdateRecord(t *testing.T) { err = DB.Query.GetOne(context.Background(), DB.SQL.Select("*").From("accounts"), accountResp) assert.NoError(t, err) assert.Equal(t, "julia", accountResp.Name) - assert.True(t, accountResp.ID != 0) + assert.NotZero(t, accountResp.ID) // Update accountResp.Name = "JUL14" @@ -476,9 +473,9 @@ func TestSugarUpdateRecord(t *testing.T) { err = DB.Query.GetOne(context.Background(), DB.SQL.Select("*").From("accounts"), accountResp2) assert.NoError(t, err) assert.Equal(t, "JUL14", accountResp.Name) - assert.True(t, accountResp2.ID != 0) - assert.True(t, accountResp2.ID == accountResp.ID) - assert.True(t, accountResp2.CreatedAt == accountResp.CreatedAt) + assert.NotZero(t, accountResp2.ID) + assert.Equal(t, accountResp2.ID, accountResp.ID) + assert.Equal(t, accountResp2.CreatedAt, accountResp.CreatedAt) } func TestSugarUpdateRecordColumns(t *testing.T) { @@ -494,7 +491,7 @@ func TestSugarUpdateRecordColumns(t *testing.T) { err = DB.Query.GetOne(context.Background(), DB.SQL.Select("*").From("accounts"), accountResp) assert.NoError(t, err) assert.Equal(t, "peter", accountResp.Name) - assert.True(t, accountResp.ID != 0) + assert.NotZero(t, accountResp.ID) assert.False(t, accountResp.Disabled) // Update @@ -509,15 +506,15 @@ func TestSugarUpdateRecordColumns(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "peter", accountResp2.Name) // should not have changed, expect as previous was recorded assert.True(t, accountResp2.Disabled) - assert.True(t, accountResp2.ID != 0) - assert.True(t, accountResp2.ID == accountResp.ID) + assert.NotZero(t, accountResp2.ID) + assert.Equal(t, accountResp2.ID, accountResp.ID) } func TestTransactionBasics(t *testing.T) { truncateTable(t, "accounts") // Insert some rows + commit - pgx.BeginFunc(context.Background(), DB.Conn, func(tx pgx.Tx) error { + err := pgx.BeginFunc(context.Background(), DB.Conn, func(tx pgx.Tx) error { // Insert 1 insertq, args, err := DB.SQL.Insert("accounts").Columns("name", "disabled").Values("peter", false).ToSql() require.NoError(t, err) @@ -534,6 +531,7 @@ func TestTransactionBasics(t *testing.T) { return nil }) + require.NoError(t, err) // Assert above records have been made { @@ -542,12 +540,12 @@ func TestTransactionBasics(t *testing.T) { err := DB.Query.GetAll(context.Background(), q, &accounts) require.NoError(t, err) assert.Len(t, accounts, 2) - assert.True(t, accounts[0].Name == "mario") - assert.True(t, accounts[1].Name == "peter") + assert.Equal(t, "mario", accounts[0].Name) + assert.Equal(t, "peter", accounts[1].Name) } // Insert some rows -- but rollback - pgx.BeginFunc(context.Background(), DB.Conn, func(tx pgx.Tx) error { + err = pgx.BeginFunc(context.Background(), DB.Conn, func(tx pgx.Tx) error { // Insert 1 insertq, args, err := DB.SQL.Insert("accounts").Columns("name", "disabled").Values("zelda", false).ToSql() require.NoError(t, err) @@ -564,6 +562,7 @@ func TestTransactionBasics(t *testing.T) { return fmt.Errorf("something bad happend") }) + require.Error(t, err) // Assert above records were rolled back { @@ -578,7 +577,7 @@ func TestTransactionBasics(t *testing.T) { func TestSugarTransaction(t *testing.T) { truncateTable(t, "accounts") - pgx.BeginFunc(context.Background(), DB.Conn, func(tx pgx.Tx) error { + err := pgx.BeginFunc(context.Background(), DB.Conn, func(tx pgx.Tx) error { rec1 := &Account{ Name: "peter", Disabled: false, @@ -599,6 +598,7 @@ func TestSugarTransaction(t *testing.T) { return nil }) + require.NoError(t, err) // Assert above records have been made { @@ -607,8 +607,8 @@ func TestSugarTransaction(t *testing.T) { err := DB.Query.GetAll(context.Background(), q, &accounts) require.NoError(t, err) assert.Len(t, accounts, 2) - assert.True(t, accounts[0].Name == "mario") - assert.True(t, accounts[1].Name == "peter") + assert.Equal(t, "mario", accounts[0].Name) + assert.Equal(t, "peter", accounts[1].Name) } } @@ -703,7 +703,8 @@ func TestBatchQuery(t *testing.T) { require.NoError(t, err) accounts = append(accounts, &account) } - br.Close() + err := br.Close() + require.NoError(t, err) require.Len(t, accounts, len(names)) for i := 0; i < len(names); i++ { @@ -727,7 +728,10 @@ func TestSugarBatchQuery(t *testing.T) { batchResults, batchLen, err := DB.Query.BatchQuery(ctx, queries) require.NoError(t, err) - defer batchResults.Close() + defer func() { + err := batchResults.Close() + require.NoError(t, err) + }() var accounts []*Account @@ -826,6 +830,7 @@ func TestSlogQueryTracerWithValuesReplaced(t *testing.T) { ConnMaxLifetime: "1h", Tracer: tracer.NewSQLTracer(slogTracer), }) + require.NoError(t, err) defer dbClient.Conn.Close() @@ -924,6 +929,7 @@ func TestSlogQueryTracerUsingContextToInit(t *testing.T) { ConnMaxLifetime: "1h", Tracer: tracer.NewSQLTracer(slogTracer), }) + require.NoError(t, err) defer dbClient.Conn.Close() @@ -981,6 +987,7 @@ func TestSlogQueryTracerWithErr(t *testing.T) { ConnMaxLifetime: "1h", Tracer: tracer.NewSQLTracer(slogTracer), }) + require.NoError(t, err) defer dbClient.Conn.Close() @@ -1029,6 +1036,7 @@ func TestSlogSlowQuery(t *testing.T) { ConnMaxLifetime: "1h", Tracer: tracer.NewSQLTracer(slogTracer), }) + require.NoError(t, err) defer dbClient.Conn.Close() @@ -1065,6 +1073,7 @@ func TestSlogTracerBatchQuery(t *testing.T) { ConnMaxLifetime: "1h", Tracer: tracer.NewSQLTracer(slogTracer), }) + require.NoError(t, err) defer dbClient.Conn.Close() @@ -1148,9 +1157,3 @@ func connectToDb(conf pgkit.Config) (*pgkit.DB, error) { } return dbClient, err } - -func hexEncode(b []byte) string { - enc := make([]byte, len(b)*2) - hex.Encode(enc[0:], b) - return string(enc) -}