diff --git a/internal/api/custom_oauth_admin.go b/internal/api/custom_oauth_admin.go index 5a37e705c..9e66a98fb 100644 --- a/internal/api/custom_oauth_admin.go +++ b/internal/api/custom_oauth_admin.go @@ -5,6 +5,7 @@ import ( "encoding/json" "io" "net/http" + "net/url" "slices" "strings" "time" @@ -109,18 +110,43 @@ func (a *API) adminCustomOAuthProvidersList(w http.ResponseWriter, r *http.Reque }) } -// adminCustomOAuthProviderGet returns a single custom OAuth/OIDC provider -func (a *API) adminCustomOAuthProviderGet(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - db := a.db.WithContext(ctx) - +// providerIdentifierFromPath reads the custom OAuth/OIDC provider identifier +// from the request path, percent-decodes it, and validates the "custom:" prefix. +// +// chi routes on (and exposes via chi.URLParam) the raw, still percent-encoded +// path segment. Browsers encode the ':' in the required "custom:" prefix as +// "%3A" (encodeURIComponent("custom:line") == "custom%3Aline"), so the value +// returned by chi.URLParam must be decoded before the prefix check and the +// database lookup. Without decoding, a perfectly valid request is rejected with +// a misleading "must start with 'custom:' prefix" error and the provider can +// never be fetched, updated, or deleted from the dashboard. +func providerIdentifierFromPath(r *http.Request) (string, error) { identifier := chi.URLParam(r, "identifier") if identifier == "" { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier is required") + return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier is required") + } + + decoded, err := url.PathUnescape(identifier) + if err != nil { + return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier is not a valid URL-encoded value") } + identifier = decoded if !strings.HasPrefix(identifier, "custom:") { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier must start with 'custom:' prefix, e.g. 'custom:%s'", identifier) + return "", apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier must start with 'custom:' prefix, e.g. 'custom:%s'", identifier) + } + + return identifier, nil +} + +// adminCustomOAuthProviderGet returns a single custom OAuth/OIDC provider +func (a *API) adminCustomOAuthProviderGet(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + identifier, err := providerIdentifierFromPath(r) + if err != nil { + return err } observability.LogEntrySetField(r, "identifier", identifier) @@ -250,13 +276,9 @@ func (a *API) adminCustomOAuthProviderUpdate(w http.ResponseWriter, r *http.Requ db := a.db.WithContext(ctx) config := a.config - identifier := chi.URLParam(r, "identifier") - if identifier == "" { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier is required") - } - - if !strings.HasPrefix(identifier, "custom:") { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier must start with 'custom:' prefix, e.g. 'custom:%s'", identifier) + identifier, err := providerIdentifierFromPath(r) + if err != nil { + return err } observability.LogEntrySetField(r, "identifier", identifier) @@ -353,19 +375,15 @@ func (a *API) adminCustomOAuthProviderDelete(w http.ResponseWriter, r *http.Requ ctx := r.Context() db := a.db.WithContext(ctx) - identifier := chi.URLParam(r, "identifier") - if identifier == "" { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier is required") - } - - if !strings.HasPrefix(identifier, "custom:") { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "identifier must start with 'custom:' prefix, e.g. 'custom:%s'", identifier) + identifier, err := providerIdentifierFromPath(r) + if err != nil { + return err } observability.LogEntrySetField(r, "identifier", identifier) var issuerToInvalidate string - err := db.Transaction(func(tx *storage.Connection) error { + err = db.Transaction(func(tx *storage.Connection) error { provider, terr := models.FindCustomOAuthProviderByIdentifier(tx, identifier) if terr != nil { if models.IsNotFoundError(terr) { @@ -777,4 +795,3 @@ func validateAttributeMapping(mapping map[string]interface{}) error { return nil } - diff --git a/internal/api/custom_oauth_admin_test.go b/internal/api/custom_oauth_admin_test.go index e8927290d..5bdad8a51 100644 --- a/internal/api/custom_oauth_admin_test.go +++ b/internal/api/custom_oauth_admin_test.go @@ -9,9 +9,10 @@ import ( "strings" "testing" + "github.com/go-chi/chi/v5" popslices "github.com/gobuffalo/pop/v6/slices" - jwt "github.com/golang-jwt/jwt/v5" "github.com/gofrs/uuid" + jwt "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -625,6 +626,71 @@ func (ts *CustomOAuthAdminTestSuite) TestDeleteProvider() { require.Equal(ts.T(), http.StatusNotFound, w.Code) } +// The following tests use a percent-encoded identifier in the URL path, exactly +// as a browser (and the Supabase dashboard) sends it: the ':' in the required +// "custom:" prefix is encoded as "%3A". Before the path parameter was decoded, +// these requests failed with "identifier must start with 'custom:' prefix". + +func (ts *CustomOAuthAdminTestSuite) TestGetProviderWithEncodedIdentifier() { + ts.createProvider(ts.createTestOAuth2Payload("encoded-get"), http.StatusCreated) + + req := httptest.NewRequest(http.MethodGet, "/admin/custom-providers/custom%3Aencoded-get", nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusOK, w.Code) + + var got models.CustomOAuthProvider + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&got)) + assert.Equal(ts.T(), "custom:encoded-get", got.Identifier) +} + +func (ts *CustomOAuthAdminTestSuite) TestUpdateProviderWithEncodedIdentifier() { + ts.createProvider(ts.createTestOAuth2Payload("encoded-update"), http.StatusCreated) + + var body bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&body).Encode(map[string]interface{}{"name": "Renamed Provider"})) + + req := httptest.NewRequest(http.MethodPut, "/admin/custom-providers/custom%3Aencoded-update", &body) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusOK, w.Code) + + var got models.CustomOAuthProvider + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&got)) + assert.Equal(ts.T(), "custom:encoded-update", got.Identifier) + assert.Equal(ts.T(), "Renamed Provider", got.Name) +} + +func (ts *CustomOAuthAdminTestSuite) TestDeleteProviderWithEncodedIdentifier() { + // Reproduces the dashboard bug from the report: deleting a custom provider + // failed with "identifier must start with 'custom:' prefix" because the + // browser sends the ':' percent-encoded (custom%3A...). + ts.createProvider(ts.createTestOAuth2Payload("encoded-delete"), http.StatusCreated) + + req := httptest.NewRequest(http.MethodDelete, "/admin/custom-providers/custom%3Aencoded-delete", nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusNoContent, w.Code) + + // Confirm it is actually gone. + req = httptest.NewRequest(http.MethodGet, "/admin/custom-providers/custom%3Aencoded-delete", nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusNotFound, w.Code) +} + // Helper methods func (ts *CustomOAuthAdminTestSuite) createTestOAuth2Payload(identifier string) map[string]interface{} { @@ -687,3 +753,67 @@ func (ts *CustomOAuthAdminTestSuite) createProvider(payload map[string]interface return w } + +// TestProviderIdentifierFromPathDecoding verifies that a percent-encoded +// provider identifier in the URL path is decoded before the "custom:" prefix +// check. Browsers and the Supabase dashboard send the ':' percent-encoded +// (encodeURIComponent("custom:line") == "custom%3Aline"), and chi exposes the +// raw, still-encoded path segment via chi.URLParam. +// +// This test exercises the real chi routing the admin endpoints use and needs +// no database, guarding against a regression where a valid request such as +// DELETE /custom-providers/custom%3Aline is rejected with a misleading +// "must start with 'custom:' prefix" error (and the provider can never be +// fetched, updated, or deleted from the dashboard). +func TestProviderIdentifierFromPathDecoding(t *testing.T) { + r := chi.NewRouter() + r.Route("/custom-providers/{identifier}", func(r chi.Router) { + r.Get("/", func(w http.ResponseWriter, req *http.Request) { + identifier, err := providerIdentifierFromPath(req) + if err != nil { + herr := err.(*apierrors.HTTPError) + w.WriteHeader(herr.HTTPStatus) + _, _ = w.Write([]byte(herr.Message)) + return + } + _, _ = w.Write([]byte(identifier)) + }) + }) + + cases := []struct { + name string + path string + wantStatus int + wantBody string + }{ + { + name: "percent-encoded colon is decoded (browser behaviour)", + path: "/custom-providers/custom%3Aline", + wantStatus: http.StatusOK, + wantBody: "custom:line", + }, + { + name: "raw colon keeps working", + path: "/custom-providers/custom:line", + wantStatus: http.StatusOK, + wantBody: "custom:line", + }, + { + name: "identifier without prefix is still rejected", + path: "/custom-providers/line", + wantStatus: http.StatusBadRequest, + wantBody: "must start with 'custom:' prefix", + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, c.path, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + require.Equal(t, c.wantStatus, w.Code) + require.Contains(t, w.Body.String(), c.wantBody) + }) + } +}