Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 40 additions & 23 deletions internal/api/custom_oauth_admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"io"
"net/http"
"net/url"
"slices"
"strings"
"time"
Expand Down Expand Up @@ -109,18 +110,43 @@ func (a *API) adminCustomOAuthProvidersList(w http.ResponseWriter, r *http.Reque
})
}

// adminCustomOAuthProviderGet returns a single custom OAuth/OIDC provider
func (a *API) adminCustomOAuthProviderGet(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
db := a.db.WithContext(ctx)

// providerIdentifierFromPath reads the custom OAuth/OIDC provider identifier
// from the request path, percent-decodes it, and validates the "custom:" prefix.
//
// chi routes on (and exposes via chi.URLParam) the raw, still percent-encoded
// path segment. Browsers encode the ':' in the required "custom:" prefix as
// "%3A" (encodeURIComponent("custom:line") == "custom%3Aline"), so the value
// returned by chi.URLParam must be decoded before the prefix check and the
// database lookup. Without decoding, a perfectly valid request is rejected with
// a misleading "must start with 'custom:' prefix" error and the provider can
// never be fetched, updated, or deleted from the dashboard.
func providerIdentifierFromPath(r *http.Request) (string, error) {
identifier := chi.URLParam(r, "identifier")
if identifier == "" {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier is required")
return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier is required")
}

decoded, err := url.PathUnescape(identifier)
if err != nil {
return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier is not a valid URL-encoded value")
}
identifier = decoded

if !strings.HasPrefix(identifier, "custom:") {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier must start with 'custom:' prefix, e.g. 'custom:%s'", identifier)
return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier must start with 'custom:' prefix, e.g. 'custom:%s'", identifier)
}

return identifier, nil
}

// adminCustomOAuthProviderGet returns a single custom OAuth/OIDC provider
func (a *API) adminCustomOAuthProviderGet(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
db := a.db.WithContext(ctx)

identifier, err := providerIdentifierFromPath(r)
if err != nil {
return err
}

observability.LogEntrySetField(r, "identifier", identifier)
Expand Down Expand Up @@ -250,13 +276,9 @@ func (a *API) adminCustomOAuthProviderUpdate(w http.ResponseWriter, r *http.Requ
db := a.db.WithContext(ctx)
config := a.config

identifier := chi.URLParam(r, "identifier")
if identifier == "" {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier is required")
}

if !strings.HasPrefix(identifier, "custom:") {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier must start with 'custom:' prefix, e.g. 'custom:%s'", identifier)
identifier, err := providerIdentifierFromPath(r)
if err != nil {
return err
}

observability.LogEntrySetField(r, "identifier", identifier)
Expand Down Expand Up @@ -353,19 +375,15 @@ func (a *API) adminCustomOAuthProviderDelete(w http.ResponseWriter, r *http.Requ
ctx := r.Context()
db := a.db.WithContext(ctx)

identifier := chi.URLParam(r, "identifier")
if identifier == "" {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier is required")
}

if !strings.HasPrefix(identifier, "custom:") {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier must start with 'custom:' prefix, e.g. 'custom:%s'", identifier)
identifier, err := providerIdentifierFromPath(r)
if err != nil {
return err
}

observability.LogEntrySetField(r, "identifier", identifier)

var issuerToInvalidate string
err := db.Transaction(func(tx *storage.Connection) error {
err = db.Transaction(func(tx *storage.Connection) error {
provider, terr := models.FindCustomOAuthProviderByIdentifier(tx, identifier)
if terr != nil {
if models.IsNotFoundError(terr) {
Expand Down Expand Up @@ -777,4 +795,3 @@ func validateAttributeMapping(mapping map[string]interface{}) error {

return nil
}

132 changes: 131 additions & 1 deletion internal/api/custom_oauth_admin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ import (
"strings"
"testing"

"github.com/go-chi/chi/v5"
popslices "github.com/gobuffalo/pop/v6/slices"
jwt "github.com/golang-jwt/jwt/v5"
"github.com/gofrs/uuid"
jwt "github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
Expand Down Expand Up @@ -625,6 +626,71 @@ func (ts *CustomOAuthAdminTestSuite) TestDeleteProvider() {
require.Equal(ts.T(), http.StatusNotFound, w.Code)
}

// The following tests use a percent-encoded identifier in the URL path, exactly
// as a browser (and the Supabase dashboard) sends it: the ':' in the required
// "custom:" prefix is encoded as "%3A". Before the path parameter was decoded,
// these requests failed with "identifier must start with 'custom:' prefix".

func (ts *CustomOAuthAdminTestSuite) TestGetProviderWithEncodedIdentifier() {
ts.createProvider(ts.createTestOAuth2Payload("encoded-get"), http.StatusCreated)

req := httptest.NewRequest(http.MethodGet, "/admin/custom-providers/custom%3Aencoded-get", nil)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token))

w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)

require.Equal(ts.T(), http.StatusOK, w.Code)

var got models.CustomOAuthProvider
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&got))
assert.Equal(ts.T(), "custom:encoded-get", got.Identifier)
}

func (ts *CustomOAuthAdminTestSuite) TestUpdateProviderWithEncodedIdentifier() {
ts.createProvider(ts.createTestOAuth2Payload("encoded-update"), http.StatusCreated)

var body bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&body).Encode(map[string]interface{}{"name": "Renamed Provider"}))

req := httptest.NewRequest(http.MethodPut, "/admin/custom-providers/custom%3Aencoded-update", &body)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token))

w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)

require.Equal(ts.T(), http.StatusOK, w.Code)

var got models.CustomOAuthProvider
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&got))
assert.Equal(ts.T(), "custom:encoded-update", got.Identifier)
assert.Equal(ts.T(), "Renamed Provider", got.Name)
}

func (ts *CustomOAuthAdminTestSuite) TestDeleteProviderWithEncodedIdentifier() {
// Reproduces the dashboard bug from the report: deleting a custom provider
// failed with "identifier must start with 'custom:' prefix" because the
// browser sends the ':' percent-encoded (custom%3A...).
ts.createProvider(ts.createTestOAuth2Payload("encoded-delete"), http.StatusCreated)

req := httptest.NewRequest(http.MethodDelete, "/admin/custom-providers/custom%3Aencoded-delete", nil)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token))

w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)

require.Equal(ts.T(), http.StatusNoContent, w.Code)

// Confirm it is actually gone.
req = httptest.NewRequest(http.MethodGet, "/admin/custom-providers/custom%3Aencoded-delete", nil)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token))

w = httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)

require.Equal(ts.T(), http.StatusNotFound, w.Code)
}

// Helper methods

func (ts *CustomOAuthAdminTestSuite) createTestOAuth2Payload(identifier string) map[string]interface{} {
Expand Down Expand Up @@ -687,3 +753,67 @@ func (ts *CustomOAuthAdminTestSuite) createProvider(payload map[string]interface

return w
}

// TestProviderIdentifierFromPathDecoding verifies that a percent-encoded
// provider identifier in the URL path is decoded before the "custom:" prefix
// check. Browsers and the Supabase dashboard send the ':' percent-encoded
// (encodeURIComponent("custom:line") == "custom%3Aline"), and chi exposes the
// raw, still-encoded path segment via chi.URLParam.
//
// This test exercises the real chi routing the admin endpoints use and needs
// no database, guarding against a regression where a valid request such as
// DELETE /custom-providers/custom%3Aline is rejected with a misleading
// "must start with 'custom:' prefix" error (and the provider can never be
// fetched, updated, or deleted from the dashboard).
func TestProviderIdentifierFromPathDecoding(t *testing.T) {
r := chi.NewRouter()
r.Route("/custom-providers/{identifier}", func(r chi.Router) {
r.Get("/", func(w http.ResponseWriter, req *http.Request) {
identifier, err := providerIdentifierFromPath(req)
if err != nil {
herr := err.(*apierrors.HTTPError)
w.WriteHeader(herr.HTTPStatus)
_, _ = w.Write([]byte(herr.Message))
return
}
_, _ = w.Write([]byte(identifier))
})
})

cases := []struct {
name string
path string
wantStatus int
wantBody string
}{
{
name: "percent-encoded colon is decoded (browser behaviour)",
path: "/custom-providers/custom%3Aline",
wantStatus: http.StatusOK,
wantBody: "custom:line",
},
{
name: "raw colon keeps working",
path: "/custom-providers/custom:line",
wantStatus: http.StatusOK,
wantBody: "custom:line",
},
{
name: "identifier without prefix is still rejected",
path: "/custom-providers/line",
wantStatus: http.StatusBadRequest,
wantBody: "must start with 'custom:' prefix",
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, c.path, nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)

require.Equal(t, c.wantStatus, w.Code)
require.Contains(t, w.Body.String(), c.wantBody)
})
}
}