diff --git a/page.go b/page.go index b3e5d4c..412da48 100644 --- a/page.go +++ b/page.go @@ -6,6 +6,7 @@ import ( "strings" sq "github.com/Masterminds/squirrel" + "github.com/jackc/pgx/v5" ) const ( @@ -79,6 +80,10 @@ func NewPage(size, page uint32, sort ...Sort) *Page { func (p *Page) GetOrder(defaultSort ...string) []Sort { // if page has sort, use it if p != nil && len(p.Sort) != 0 { + for i, s := range p.Sort { + s.Column = pgx.Identifier(strings.Split(s.Column, ".")).Sanitize() + p.Sort[i] = s + } return p.Sort } // if page has column, use default sort @@ -194,6 +199,26 @@ func (p Paginator[T]) PrepareQuery(q sq.SelectBuilder, page *Page) ([]T, sq.Sele return make([]T, 0, limit+1), q } +func (p Paginator[T]) PrepareRaw(q string, args []any, page *Page) ([]T, string, []any) { + limit, offset := page.Limit(), page.Offset() + + q = q + " ORDER BY " + strings.Join(p.getOrder(page), ", ") + q = q + " LIMIT @limit OFFSET @offset" + + for i, arg := range args { + if existing, ok := arg.(pgx.NamedArgs); ok { + existing["limit"] = limit + 1 + existing["offset"] = offset + break + } + if i == len(args)-1 { + args = append(args, pgx.NamedArgs{"limit": limit + 1, "offset": offset}) + } + } + + return make([]T, 0, limit+1), q, args +} + // PrepareResult prepares the paginated result. If the number of rows is n+1: // - it removes the last element, returning n elements // - it sets more to true in the page object diff --git a/page_test.go b/page_test.go index 5b7ac37..4ef175b 100644 --- a/page_test.go +++ b/page_test.go @@ -17,7 +17,7 @@ func TestPagination(t *testing.T) { MaxSize = 5 Sort = "ID" ) - paginator := pgkit.NewPaginator[T]( + paginator := pgkit.NewPaginator( pgkit.WithColumnFunc[T](strings.ToLower), pgkit.WithDefaultSize[T](DefaultSize), pgkit.WithMaxSize[T](MaxSize), @@ -45,3 +45,19 @@ func TestPagination(t *testing.T) { require.Len(t, result, MaxSize) require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize, More: true}, page) } + +func TestInvalidSort(t *testing.T) { + paginator := pgkit.NewPaginator[T]() + page := pgkit.NewPage(0, 0) + page.Sort = []pgkit.Sort{ + {Column: "ID; DROP TABLE users;", Order: pgkit.Asc}, + {Column: "name", Order: pgkit.Desc}, + } + + _, query := paginator.PrepareQuery(sq.Select("*").From("t"), page) + + sql, args, err := query.ToSql() + require.NoError(t, err) + require.Equal(t, "SELECT * FROM t ORDER BY \"ID; DROP TABLE users;\" ASC, \"name\" DESC LIMIT 11 OFFSET 0", sql) + require.Empty(t, args) +}