From 49236781fe4b3703f8a57a1bf59111037b9e7c13 Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Thu, 28 May 2026 11:46:46 +0200 Subject: [PATCH 1/3] feat: add pagination for oauth client list endpoint --- internal/api/oauthserver/handlers.go | 18 ++++-- internal/api/oauthserver/handlers_test.go | 76 +++++++++++++++++++++++ internal/api/pagination.go | 54 +--------------- internal/api/shared/pagination.go | 64 +++++++++++++++++++ 4 files changed, 155 insertions(+), 57 deletions(-) create mode 100644 internal/api/shared/pagination.go diff --git a/internal/api/oauthserver/handlers.go b/internal/api/oauthserver/handlers.go index d219e56530..8790617154 100644 --- a/internal/api/oauthserver/handlers.go +++ b/internal/api/oauthserver/handlers.go @@ -231,22 +231,28 @@ func (s *Server) OAuthServerClientList(w http.ResponseWriter, r *http.Request) e ctx := r.Context() db := s.db.WithContext(ctx) - // TODO(cemal) :: Add pagination, check the `/admin/users` endpoint for reference + pageParams, err := shared.Paginate(r) + if err != nil { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err) + } + var clients []models.OAuthServerClient - if err := db.Q().Where("deleted_at is null").Order("created_at desc").All(&clients); err != nil { + q := db.Q().Where("deleted_at is null").Order("created_at desc") + if err := q.Paginate(int(pageParams.Page), int(pageParams.PerPage)).All(&clients); err != nil { // #nosec G115 return apierrors.NewInternalServerError("Error listing OAuth clients").WithInternalError(err) } + pageParams.Count = uint64(q.Paginator.TotalEntriesSize) // #nosec G115 + + shared.AddPaginationHeaders(w, r, pageParams) responses := make([]OAuthServerClientResponse, len(clients)) for i, client := range clients { responses[i] = *oauthServerClientToResponse(&client) } - response := OAuthServerClientListResponse{ + return shared.SendJSON(w, http.StatusOK, OAuthServerClientListResponse{ Clients: responses, - } - - return shared.SendJSON(w, http.StatusOK, response) + }) } // OAuthTokenParams represents the parameters for the OAuth token endpoint diff --git a/internal/api/oauthserver/handlers_test.go b/internal/api/oauthserver/handlers_test.go index 3aef32f543..ecb789b448 100644 --- a/internal/api/oauthserver/handlers_test.go +++ b/internal/api/oauthserver/handlers_test.go @@ -253,6 +253,82 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientListHandler() { } } +func (ts *OAuthClientTestSuite) TestOAuthServerClientListPagination() { + client1, _ := ts.createTestOAuthClient() + client2, _ := ts.createTestOAuthClient() + client3, _ := ts.createTestOAuthClient() + allIDs := []string{client1.ID.String(), client2.ID.String(), client3.ID.String()} + + // page=1, per_page=1: returns 1 item, has next + last links, total count = 3 + req := httptest.NewRequest(http.MethodGet, "/admin/oauth/clients?page=1&per_page=1", nil) + w := httptest.NewRecorder() + require.NoError(ts.T(), ts.Server.OAuthServerClientList(w, req)) + assert.Equal(ts.T(), http.StatusOK, w.Code) + assert.Equal(ts.T(), "3", w.Header().Get("X-Total-Count")) + assert.Contains(ts.T(), w.Header().Get("Link"), `rel="next"`) + assert.Contains(ts.T(), w.Header().Get("Link"), `rel="last"`) + var page1 OAuthServerClientListResponse + require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &page1)) + assert.Len(ts.T(), page1.Clients, 1) + + // page=2, per_page=1: returns 1 item, still has next link + req = httptest.NewRequest(http.MethodGet, "/admin/oauth/clients?page=2&per_page=1", nil) + w = httptest.NewRecorder() + require.NoError(ts.T(), ts.Server.OAuthServerClientList(w, req)) + assert.Equal(ts.T(), "3", w.Header().Get("X-Total-Count")) + assert.Contains(ts.T(), w.Header().Get("Link"), `rel="next"`) + var page2 OAuthServerClientListResponse + require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &page2)) + assert.Len(ts.T(), page2.Clients, 1) + + // page=3, per_page=1: last page — no next link + req = httptest.NewRequest(http.MethodGet, "/admin/oauth/clients?page=3&per_page=1", nil) + w = httptest.NewRecorder() + require.NoError(ts.T(), ts.Server.OAuthServerClientList(w, req)) + assert.Equal(ts.T(), "3", w.Header().Get("X-Total-Count")) + assert.NotContains(ts.T(), w.Header().Get("Link"), `rel="next"`) + var page3 OAuthServerClientListResponse + require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &page3)) + assert.Len(ts.T(), page3.Clients, 1) + + // all three pages together cover all clients with no duplicates + pagedIDs := []string{page1.Clients[0].ClientID, page2.Clients[0].ClientID, page3.Clients[0].ClientID} + for _, id := range allIDs { + assert.Contains(ts.T(), pagedIDs, id) + } + + // per_page=2: page 1 returns 2, page 2 returns 1 + req = httptest.NewRequest(http.MethodGet, "/admin/oauth/clients?page=1&per_page=2", nil) + w = httptest.NewRecorder() + require.NoError(ts.T(), ts.Server.OAuthServerClientList(w, req)) + var halfPage1 OAuthServerClientListResponse + require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &halfPage1)) + assert.Len(ts.T(), halfPage1.Clients, 2) + assert.Contains(ts.T(), w.Header().Get("Link"), `rel="next"`) + + req = httptest.NewRequest(http.MethodGet, "/admin/oauth/clients?page=2&per_page=2", nil) + w = httptest.NewRecorder() + require.NoError(ts.T(), ts.Server.OAuthServerClientList(w, req)) + var halfPage2 OAuthServerClientListResponse + require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &halfPage2)) + assert.Len(ts.T(), halfPage2.Clients, 1) + assert.NotContains(ts.T(), w.Header().Get("Link"), `rel="next"`) + + // no params: returns all 3 with default page size + req = httptest.NewRequest(http.MethodGet, "/admin/oauth/clients", nil) + w = httptest.NewRecorder() + require.NoError(ts.T(), ts.Server.OAuthServerClientList(w, req)) + assert.Equal(ts.T(), "3", w.Header().Get("X-Total-Count")) + var all OAuthServerClientListResponse + require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &all)) + assert.Len(ts.T(), all.Clients, 3) + + // invalid page param returns an error + req = httptest.NewRequest(http.MethodGet, "/admin/oauth/clients?page=abc", nil) + w = httptest.NewRecorder() + assert.Error(ts.T(), ts.Server.OAuthServerClientList(w, req)) +} + func (ts *OAuthClientTestSuite) TestOAuthServerClientUpdateHandler() { // Create a test client first client, _ := ts.createTestOAuthClient() diff --git a/internal/api/pagination.go b/internal/api/pagination.go index 386f40310f..63d9f2743c 100644 --- a/internal/api/pagination.go +++ b/internal/api/pagination.go @@ -1,64 +1,16 @@ package api import ( - "fmt" "net/http" - "net/url" - "strconv" + "github.com/supabase/auth/internal/api/shared" "github.com/supabase/auth/internal/models" ) -const defaultPerPage = 50 - -func calculateTotalPages(perPage, total uint64) uint64 { - pages := total / perPage - if total%perPage > 0 { - return pages + 1 - } - return pages -} - func addPaginationHeaders(w http.ResponseWriter, r *http.Request, p *models.Pagination) { - totalPages := calculateTotalPages(p.PerPage, p.Count) - url, _ := url.ParseRequestURI(r.URL.String()) - query := url.Query() - header := "" - if totalPages > p.Page { - query.Set("page", fmt.Sprintf("%v", p.Page+1)) - url.RawQuery = query.Encode() - header += "<" + url.String() + ">; rel=\"next\", " - } - query.Set("page", fmt.Sprintf("%v", totalPages)) - url.RawQuery = query.Encode() - header += "<" + url.String() + ">; rel=\"last\"" - - w.Header().Add("Link", header) - w.Header().Add("X-Total-Count", fmt.Sprintf("%v", p.Count)) + shared.AddPaginationHeaders(w, r, p) } func paginate(r *http.Request) (*models.Pagination, error) { - params := r.URL.Query() - queryPage := params.Get("page") - queryPerPage := params.Get("per_page") - var page uint64 = 1 - var perPage uint64 = defaultPerPage - var err error - if queryPage != "" { - page, err = strconv.ParseUint(queryPage, 10, 64) - if err != nil { - return nil, err - } - } - if queryPerPage != "" { - perPage, err = strconv.ParseUint(queryPerPage, 10, 64) - if err != nil { - return nil, err - } - } - - return &models.Pagination{ - Page: page, - PerPage: perPage, - }, nil + return shared.Paginate(r) } diff --git a/internal/api/shared/pagination.go b/internal/api/shared/pagination.go new file mode 100644 index 0000000000..c95eea57c3 --- /dev/null +++ b/internal/api/shared/pagination.go @@ -0,0 +1,64 @@ +package shared + +import ( + "fmt" + "net/http" + "net/url" + "strconv" + + "github.com/supabase/auth/internal/models" +) + +const DefaultPerPage = 50 + +func calculateTotalPages(perPage, total uint64) uint64 { + pages := total / perPage + if total%perPage > 0 { + return pages + 1 + } + return pages +} + +func AddPaginationHeaders(w http.ResponseWriter, r *http.Request, p *models.Pagination) { + totalPages := calculateTotalPages(p.PerPage, p.Count) + u, _ := url.ParseRequestURI(r.URL.String()) + query := u.Query() + header := "" + if totalPages > p.Page { + query.Set("page", fmt.Sprintf("%v", p.Page+1)) + u.RawQuery = query.Encode() + header += "<" + u.String() + ">; rel=\"next\", " + } + query.Set("page", fmt.Sprintf("%v", totalPages)) + u.RawQuery = query.Encode() + header += "<" + u.String() + ">; rel=\"last\"" + + w.Header().Add("Link", header) + w.Header().Add("X-Total-Count", fmt.Sprintf("%v", p.Count)) +} + +func Paginate(r *http.Request) (*models.Pagination, error) { + params := r.URL.Query() + queryPage := params.Get("page") + queryPerPage := params.Get("per_page") + var page uint64 = 1 + var perPage uint64 = DefaultPerPage + var err error + if queryPage != "" { + page, err = strconv.ParseUint(queryPage, 10, 64) + if err != nil { + return nil, err + } + } + if queryPerPage != "" { + perPage, err = strconv.ParseUint(queryPerPage, 10, 64) + if err != nil { + return nil, err + } + } + + return &models.Pagination{ + Page: page, + PerPage: perPage, + }, nil +} From d6b6b5f9a9d8f9a3bbe67d6f3bd29269f33ed9a1 Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Fri, 29 May 2026 15:45:37 +0200 Subject: [PATCH 2/3] feat: improve validation of params and dont return count header --- internal/api/admin.go | 5 +- internal/api/admin_test.go | 2 +- internal/api/audit.go | 5 +- internal/api/oauthserver/handlers.go | 4 +- internal/api/oauthserver/handlers_test.go | 10 +-- internal/api/pagination.go | 16 ---- internal/api/shared/pagination.go | 20 ++++- internal/api/shared/pagination_test.go | 100 ++++++++++++++++++++++ internal/models/audit_log_entry.go | 3 +- internal/models/connection.go | 7 +- internal/models/user.go | 3 +- 11 files changed, 140 insertions(+), 35 deletions(-) delete mode 100644 internal/api/pagination.go create mode 100644 internal/api/shared/pagination_test.go diff --git a/internal/api/admin.go b/internal/api/admin.go index 181ea735e4..11fbe8edb6 100644 --- a/internal/api/admin.go +++ b/internal/api/admin.go @@ -12,6 +12,7 @@ import ( "github.com/sethvargo/go-password/password" "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/api/shared" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" "github.com/supabase/auth/internal/storage" @@ -107,7 +108,7 @@ func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error { db := a.db.WithContext(ctx) aud := a.requestAud(ctx, r) - pageParams, err := paginate(r) + pageParams, err := shared.Paginate(r) if err != nil { return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err) } @@ -123,7 +124,7 @@ func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error { if err != nil { return apierrors.NewInternalServerError("Database error finding users").WithInternalError(err) } - addPaginationHeaders(w, r, pageParams) + shared.AddPaginationHeaders(w, r, pageParams) return sendJSON(w, http.StatusOK, AdminListUsersResponse{ Users: users, diff --git a/internal/api/admin_test.go b/internal/api/admin_test.go index dbd3842002..2cdf4f3a0e 100644 --- a/internal/api/admin_test.go +++ b/internal/api/admin_test.go @@ -73,7 +73,7 @@ func (ts *AdminTestSuite) TestAdminUsers() { ts.API.handler.ServeHTTP(w, req) require.Equal(ts.T(), http.StatusOK, w.Code) - assert.Equal(ts.T(), "; rel=\"last\"", w.Header().Get("Link")) + assert.Equal(ts.T(), "; rel=\"last\"", w.Header().Get("Link")) assert.Equal(ts.T(), "0", w.Header().Get("X-Total-Count")) } diff --git a/internal/api/audit.go b/internal/api/audit.go index 665af1dd77..0f29ac9575 100644 --- a/internal/api/audit.go +++ b/internal/api/audit.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/shared" "github.com/supabase/auth/internal/models" ) @@ -19,7 +20,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error { db := a.db.WithContext(ctx) // aud := a.requestAud(ctx, r) - pageParams, err := paginate(r) + pageParams, err := shared.Paginate(r) if err != nil { return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err) } @@ -42,7 +43,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error { return apierrors.NewInternalServerError("Error searching for audit logs").WithInternalError(err) } - addPaginationHeaders(w, r, pageParams) + shared.AddPaginationHeaders(w, r, pageParams) return sendJSON(w, http.StatusOK, logs) } diff --git a/internal/api/oauthserver/handlers.go b/internal/api/oauthserver/handlers.go index 8790617154..8c27ed3496 100644 --- a/internal/api/oauthserver/handlers.go +++ b/internal/api/oauthserver/handlers.go @@ -238,10 +238,10 @@ func (s *Server) OAuthServerClientList(w http.ResponseWriter, r *http.Request) e var clients []models.OAuthServerClient q := db.Q().Where("deleted_at is null").Order("created_at desc") - if err := q.Paginate(int(pageParams.Page), int(pageParams.PerPage)).All(&clients); err != nil { // #nosec G115 + if err := q.Paginate(int(pageParams.Page), int(pageParams.PerPage)).All(&clients); err != nil { return apierrors.NewInternalServerError("Error listing OAuth clients").WithInternalError(err) } - pageParams.Count = uint64(q.Paginator.TotalEntriesSize) // #nosec G115 + pageParams.Count = uint64(q.Paginator.TotalEntriesSize) shared.AddPaginationHeaders(w, r, pageParams) diff --git a/internal/api/oauthserver/handlers_test.go b/internal/api/oauthserver/handlers_test.go index ecb789b448..a9ba812b1f 100644 --- a/internal/api/oauthserver/handlers_test.go +++ b/internal/api/oauthserver/handlers_test.go @@ -259,12 +259,12 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientListPagination() { client3, _ := ts.createTestOAuthClient() allIDs := []string{client1.ID.String(), client2.ID.String(), client3.ID.String()} - // page=1, per_page=1: returns 1 item, has next + last links, total count = 3 + // page=1, per_page=1: returns 1 item, has next + last links req := httptest.NewRequest(http.MethodGet, "/admin/oauth/clients?page=1&per_page=1", nil) w := httptest.NewRecorder() require.NoError(ts.T(), ts.Server.OAuthServerClientList(w, req)) assert.Equal(ts.T(), http.StatusOK, w.Code) - assert.Equal(ts.T(), "3", w.Header().Get("X-Total-Count")) + assert.Empty(ts.T(), w.Header().Get("X-Total-Count")) assert.Contains(ts.T(), w.Header().Get("Link"), `rel="next"`) assert.Contains(ts.T(), w.Header().Get("Link"), `rel="last"`) var page1 OAuthServerClientListResponse @@ -275,7 +275,7 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientListPagination() { req = httptest.NewRequest(http.MethodGet, "/admin/oauth/clients?page=2&per_page=1", nil) w = httptest.NewRecorder() require.NoError(ts.T(), ts.Server.OAuthServerClientList(w, req)) - assert.Equal(ts.T(), "3", w.Header().Get("X-Total-Count")) + assert.Empty(ts.T(), w.Header().Get("X-Total-Count")) assert.Contains(ts.T(), w.Header().Get("Link"), `rel="next"`) var page2 OAuthServerClientListResponse require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &page2)) @@ -285,7 +285,7 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientListPagination() { req = httptest.NewRequest(http.MethodGet, "/admin/oauth/clients?page=3&per_page=1", nil) w = httptest.NewRecorder() require.NoError(ts.T(), ts.Server.OAuthServerClientList(w, req)) - assert.Equal(ts.T(), "3", w.Header().Get("X-Total-Count")) + assert.Empty(ts.T(), w.Header().Get("X-Total-Count")) assert.NotContains(ts.T(), w.Header().Get("Link"), `rel="next"`) var page3 OAuthServerClientListResponse require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &page3)) @@ -318,7 +318,7 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientListPagination() { req = httptest.NewRequest(http.MethodGet, "/admin/oauth/clients", nil) w = httptest.NewRecorder() require.NoError(ts.T(), ts.Server.OAuthServerClientList(w, req)) - assert.Equal(ts.T(), "3", w.Header().Get("X-Total-Count")) + assert.Empty(ts.T(), w.Header().Get("X-Total-Count")) var all OAuthServerClientListResponse require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &all)) assert.Len(ts.T(), all.Clients, 3) diff --git a/internal/api/pagination.go b/internal/api/pagination.go deleted file mode 100644 index 63d9f2743c..0000000000 --- a/internal/api/pagination.go +++ /dev/null @@ -1,16 +0,0 @@ -package api - -import ( - "net/http" - - "github.com/supabase/auth/internal/api/shared" - "github.com/supabase/auth/internal/models" -) - -func addPaginationHeaders(w http.ResponseWriter, r *http.Request, p *models.Pagination) { - shared.AddPaginationHeaders(w, r, p) -} - -func paginate(r *http.Request) (*models.Pagination, error) { - return shared.Paginate(r) -} diff --git a/internal/api/shared/pagination.go b/internal/api/shared/pagination.go index c95eea57c3..df79904231 100644 --- a/internal/api/shared/pagination.go +++ b/internal/api/shared/pagination.go @@ -2,6 +2,7 @@ package shared import ( "fmt" + "math" "net/http" "net/url" "strconv" @@ -10,6 +11,7 @@ import ( ) const DefaultPerPage = 50 +const MaxPerPage = 1000 func calculateTotalPages(perPage, total uint64) uint64 { pages := total / perPage @@ -20,7 +22,7 @@ func calculateTotalPages(perPage, total uint64) uint64 { } func AddPaginationHeaders(w http.ResponseWriter, r *http.Request, p *models.Pagination) { - totalPages := calculateTotalPages(p.PerPage, p.Count) + totalPages := max(calculateTotalPages(p.PerPage, p.Count), 1) u, _ := url.ParseRequestURI(r.URL.String()) query := u.Query() header := "" @@ -34,7 +36,9 @@ func AddPaginationHeaders(w http.ResponseWriter, r *http.Request, p *models.Pagi header += "<" + u.String() + ">; rel=\"last\"" w.Header().Add("Link", header) - w.Header().Add("X-Total-Count", fmt.Sprintf("%v", p.Count)) + if p.ShowTotalCount { + w.Header().Add("X-Total-Count", fmt.Sprintf("%v", p.Count)) + } } func Paginate(r *http.Request) (*models.Pagination, error) { @@ -49,12 +53,24 @@ func Paginate(r *http.Request) (*models.Pagination, error) { if err != nil { return nil, err } + if page == 0 { + return nil, fmt.Errorf("page must be greater than 0") + } + if page > math.MaxInt32 { + return nil, fmt.Errorf("page exceeds maximum allowed value") + } } if queryPerPage != "" { perPage, err = strconv.ParseUint(queryPerPage, 10, 64) if err != nil { return nil, err } + if perPage == 0 { + return nil, fmt.Errorf("per_page must be greater than 0") + } + if perPage > MaxPerPage { + return nil, fmt.Errorf("per_page must not exceed %d", MaxPerPage) + } } return &models.Pagination{ diff --git a/internal/api/shared/pagination_test.go b/internal/api/shared/pagination_test.go new file mode 100644 index 0000000000..733cbc30c9 --- /dev/null +++ b/internal/api/shared/pagination_test.go @@ -0,0 +1,100 @@ +package shared + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/models" +) + +func TestPaginate_Defaults(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/", nil) + p, err := Paginate(r) + require.NoError(t, err) + assert.Equal(t, uint64(1), p.Page) + assert.Equal(t, uint64(DefaultPerPage), p.PerPage) +} + +func TestPaginate_CustomValues(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/?page=3&per_page=10", nil) + p, err := Paginate(r) + require.NoError(t, err) + assert.Equal(t, uint64(3), p.Page) + assert.Equal(t, uint64(10), p.PerPage) +} + +func TestPaginate_PageZeroIsInvalid(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/?page=0", nil) + _, err := Paginate(r) + assert.Error(t, err) +} + +func TestPaginate_PerPageZeroIsInvalid(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/?per_page=0", nil) + _, err := Paginate(r) + assert.Error(t, err) +} + +func TestPaginate_PerPageExceedsMaxIsInvalid(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/?per_page=10000", nil) + _, err := Paginate(r) + assert.Error(t, err) +} + +func TestPaginate_MaxPerPageIsValid(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/?per_page=1000", nil) + p, err := Paginate(r) + require.NoError(t, err) + assert.Equal(t, uint64(MaxPerPage), p.PerPage) +} + +func TestAddPaginationHeaders_EmptyResultSet(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + p := &models.Pagination{Page: 1, PerPage: 50, Count: 0, ShowTotalCount: true} + AddPaginationHeaders(w, r, p) + + link := w.Header().Get("Link") + assert.Contains(t, link, `rel="last"`) + // last page should be page=1 for empty results, not page=0 + assert.NotContains(t, link, "page=0") + assert.Equal(t, "0", w.Header().Get("X-Total-Count")) +} + +func TestAddPaginationHeaders_SinglePage(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + p := &models.Pagination{Page: 1, PerPage: 50, Count: 10, ShowTotalCount: true} + AddPaginationHeaders(w, r, p) + + link := w.Header().Get("Link") + assert.NotContains(t, link, `rel="next"`) + assert.Contains(t, link, `rel="last"`) +} + +func TestAddPaginationHeaders_MultiPage(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + p := &models.Pagination{Page: 1, PerPage: 10, Count: 25, ShowTotalCount: true} + AddPaginationHeaders(w, r, p) + + link := w.Header().Get("Link") + assert.Contains(t, link, `rel="next"`) + assert.Contains(t, link, `rel="last"`) + assert.Equal(t, "25", w.Header().Get("X-Total-Count")) +} + +func TestAddPaginationHeaders_ShowTotalCountFalseOmitsHeader(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + p := &models.Pagination{Page: 1, PerPage: 10, Count: 25, ShowTotalCount: false} + AddPaginationHeaders(w, r, p) + + assert.Empty(t, w.Header().Get("X-Total-Count")) + // Link headers should still be computed correctly from Count + assert.Contains(t, w.Header().Get("Link"), `rel="next"`) + assert.Contains(t, w.Header().Get("Link"), `rel="last"`) +} diff --git a/internal/models/audit_log_entry.go b/internal/models/audit_log_entry.go index 5481851b7a..79d6c718a0 100644 --- a/internal/models/audit_log_entry.go +++ b/internal/models/audit_log_entry.go @@ -203,7 +203,8 @@ func FindAuditLogEntries(tx *storage.Connection, filterColumns []string, filterV var err error if pageParams != nil { err = q.Paginate(int(pageParams.Page), int(pageParams.PerPage)).All(&logs) // #nosec G115 - pageParams.Count = uint64(q.Paginator.TotalEntriesSize) // #nosec G115 + pageParams.Count = uint64(q.Paginator.TotalEntriesSize) // #nosec G115 + pageParams.ShowTotalCount = true } else { err = q.All(&logs) } diff --git a/internal/models/connection.go b/internal/models/connection.go index dcc6b84cf0..f8c1255d78 100644 --- a/internal/models/connection.go +++ b/internal/models/connection.go @@ -6,9 +6,10 @@ import ( ) type Pagination struct { - Page uint64 - PerPage uint64 - Count uint64 + Page uint64 + PerPage uint64 + Count uint64 + ShowTotalCount bool } func (p *Pagination) Offset() uint64 { diff --git a/internal/models/user.go b/internal/models/user.go index 3c706b80e7..941b65d069 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -767,7 +767,8 @@ func FindUsersInAudience(tx *storage.Connection, aud string, pageParams *Paginat var err error if pageParams != nil { err = q.Paginate(int(pageParams.Page), int(pageParams.PerPage)).All(&users) // #nosec G115 - pageParams.Count = uint64(q.Paginator.TotalEntriesSize) // #nosec G115 + pageParams.Count = uint64(q.Paginator.TotalEntriesSize) // #nosec G115 + pageParams.ShowTotalCount = true } else { err = q.All(&users) } From e82f12da20c82b7de1fab1642c491a5f0e7e85fb Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Fri, 29 May 2026 16:06:39 +0200 Subject: [PATCH 3/3] fix: add nosec g115 the bounds added in Paginate() function already tracks them --- internal/api/oauthserver/handlers.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/api/oauthserver/handlers.go b/internal/api/oauthserver/handlers.go index 8c27ed3496..8790617154 100644 --- a/internal/api/oauthserver/handlers.go +++ b/internal/api/oauthserver/handlers.go @@ -238,10 +238,10 @@ func (s *Server) OAuthServerClientList(w http.ResponseWriter, r *http.Request) e var clients []models.OAuthServerClient q := db.Q().Where("deleted_at is null").Order("created_at desc") - if err := q.Paginate(int(pageParams.Page), int(pageParams.PerPage)).All(&clients); err != nil { + if err := q.Paginate(int(pageParams.Page), int(pageParams.PerPage)).All(&clients); err != nil { // #nosec G115 return apierrors.NewInternalServerError("Error listing OAuth clients").WithInternalError(err) } - pageParams.Count = uint64(q.Paginator.TotalEntriesSize) + pageParams.Count = uint64(q.Paginator.TotalEntriesSize) // #nosec G115 shared.AddPaginationHeaders(w, r, pageParams)