From a4af3db638a329d2e040b4d6b337fa0398af9ae4 Mon Sep 17 00:00:00 2001 From: Gourab Singha Date: Fri, 26 Jun 2026 18:54:58 +0530 Subject: [PATCH] fix: preserve double slashes in custom scheme redirect URIs (fixes #2423) --- internal/api/external.go | 14 +++++++++----- internal/api/external_test.go | 2 +- internal/api/helpers.go | 5 +++++ internal/api/samlacs.go | 6 +++--- internal/api/verify.go | 16 ++++++++++++---- internal/api/verify_test.go | 20 +++++++++++++++++--- 6 files changed, 47 insertions(+), 16 deletions(-) diff --git a/internal/api/external.go b/internal/api/external.go index 0f88a64b8f..b7b06d8fc1 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -28,7 +28,7 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e if err != nil { return err } - http.Redirect(w, r, rurl, http.StatusFound) // #nosec G710 + a.redirect(w, r, rurl, http.StatusFound) // #nosec G710 return nil } @@ -134,7 +134,7 @@ func (a *API) ExternalProviderCallback(w http.ResponseWriter, r *http.Request) e if err != nil { return err } - redirectErrors(a.internalExternalProviderCallback, w, r, u) + a.redirectErrors(a.internalExternalProviderCallback, w, r, u) return nil } @@ -285,7 +285,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re } - http.Redirect(w, r, rurl, http.StatusFound) // #nosec G710 + a.redirect(w, r, rurl, http.StatusFound) // #nosec G710 return nil } @@ -797,7 +797,7 @@ func (a *API) loadCustomProvider(ctx context.Context, db *storage.Connection, id return p, pConfig, nil } -func redirectErrors(handler apiHandler, w http.ResponseWriter, r *http.Request, u *url.URL) { +func (a *API) redirectErrors(handler apiHandler, w http.ResponseWriter, r *http.Request, u *url.URL) { ctx := r.Context() log := observability.GetLogEntry(r).Entry errorID := utilities.GetRequestID(ctx) @@ -820,7 +820,11 @@ func redirectErrors(handler apiHandler, w http.ResponseWriter, r *http.Request, // Add Supabase Auth identifier to help clients distinguish Supabase Auth redirects hq.Set("sb", "") u.Fragment = hq.Encode() - http.Redirect(w, r, u.String(), http.StatusFound) // #nosec G710 + res := u.String() + if u.Scheme != "" && strings.HasPrefix(res, u.Scheme+":") && !strings.HasPrefix(res, u.Scheme+"://") { + res = u.Scheme + "://" + strings.TrimPrefix(res, u.Scheme+":") + } + a.redirect(w, r, res, http.StatusFound) } } diff --git a/internal/api/external_test.go b/internal/api/external_test.go index cbaceebc13..2c5f3300c9 100644 --- a/internal/api/external_test.go +++ b/internal/api/external_test.go @@ -309,7 +309,7 @@ func (ts *ExternalTestSuite) TestRedirectErrorsShouldPreserveParams() { parsedURL, err := url.Parse(c.RedirectURL) require.Equal(ts.T(), err, nil) - redirectErrors(ts.API.internalExternalProviderCallback, w, req, parsedURL) + ts.API.redirectErrors(ts.API.internalExternalProviderCallback, w, req, parsedURL) parsedParams, err := url.ParseQuery(parsedURL.RawQuery) require.Equal(ts.T(), err, nil) diff --git a/internal/api/helpers.go b/internal/api/helpers.go index a4a8458402..dfd0bfaa91 100644 --- a/internal/api/helpers.go +++ b/internal/api/helpers.go @@ -94,3 +94,8 @@ func retrieveRequestParams[A RequestParams](r *http.Request, params *A) error { } return nil } + +func (a *API) redirect(w http.ResponseWriter, r *http.Request, url string, status int) { + w.Header().Set("Location", url) + w.WriteHeader(status) +} diff --git a/internal/api/samlacs.go b/internal/api/samlacs.go index e0a35df85d..d7bfec0f39 100644 --- a/internal/api/samlacs.go +++ b/internal/api/samlacs.go @@ -54,7 +54,7 @@ func (a *API) SamlAcs(w http.ResponseWriter, r *http.Request) error { q := getErrorQueryString(err, utilities.GetRequestID(r.Context()), observability.GetLogEntry(r).Entry, u.Query()) u.RawQuery = q.Encode() - http.Redirect(w, r, u.String(), http.StatusSeeOther) + a.redirect(w, r, u.String(), http.StatusSeeOther) } return nil } @@ -370,7 +370,7 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { return err } - http.Redirect(w, r, redirectTo, http.StatusFound) // #nosec G710 + a.redirect(w, r, redirectTo, http.StatusFound) // #nosec G710 return nil } @@ -382,7 +382,7 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { } // #nosec G710 - http.Redirect(w, r, token.AsRedirectURL(redirectTo, url.Values{}), http.StatusFound) + a.redirect(w, r, token.AsRedirectURL(redirectTo, url.Values{}), http.StatusFound) return nil } diff --git a/internal/api/verify.go b/internal/api/verify.go index 212d7388eb..27a7456429 100644 --- a/internal/api/verify.go +++ b/internal/api/verify.go @@ -205,7 +205,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa } } if rurl != "" { - http.Redirect(w, r, rurl, http.StatusSeeOther) + a.redirect(w, r, rurl, http.StatusSeeOther) return nil } rurl = params.RedirectTo @@ -222,7 +222,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa return err } } - http.Redirect(w, r, rurl, http.StatusSeeOther) + a.redirect(w, r, rurl, http.StatusSeeOther) return nil } @@ -528,7 +528,11 @@ func (a *API) prepRedirectURL(message string, rurl string, flowType models.FlowT // Add Supabase Auth identifier to help clients distinguish Supabase Auth redirects hq.Set("sb", "") u.Fragment = hq.Encode() - return u.String(), nil + res := u.String() + if u.Scheme != "" && strings.HasPrefix(rurl, u.Scheme+"://") && !strings.HasPrefix(res, u.Scheme+"://") { + res = u.Scheme + "://" + strings.TrimPrefix(res, u.Scheme+":") + } + return res, nil } func (a *API) prepPKCERedirectURL(rurl, code string) (string, error) { @@ -539,7 +543,11 @@ func (a *API) prepPKCERedirectURL(rurl, code string) (string, error) { q := u.Query() q.Set("code", code) u.RawQuery = q.Encode() - return u.String(), nil + res := u.String() + if u.Scheme != "" && strings.HasPrefix(rurl, u.Scheme+"://") && !strings.HasPrefix(res, u.Scheme+"://") { + res = u.Scheme + "://" + strings.TrimPrefix(res, u.Scheme+":") + } + return res, nil } func (a *API) emailChangeVerify(r *http.Request, conn *storage.Connection, params *VerifyParams, user *models.User) (*models.User, error) { diff --git a/internal/api/verify_test.go b/internal/api/verify_test.go index a75df9c15d..a8454c7700 100644 --- a/internal/api/verify_test.go +++ b/internal/api/verify_test.go @@ -676,7 +676,7 @@ func (ts *VerifyTestSuite) TestVerifySignupWithRedirectURLContainedPath() { siteURL: "http://localhost:3000", uriAllowList: []string{"com.myapp://**"}, requestredirectURL: "com.myapp://", - expectedredirectURL: "com.myapp:", + expectedredirectURL: "com.myapp://", }, } @@ -703,8 +703,8 @@ func (ts *VerifyTestSuite) TestVerifySignupWithRedirectURLContainedPath() { w := httptest.NewRecorder() ts.API.handler.ServeHTTP(w, req) assert.Equal(ts.T(), http.StatusSeeOther, w.Code) - rURL, _ := w.Result().Location() - assert.Contains(ts.T(), rURL.String(), tC.expectedredirectURL) // redirected url starts with per test value + loc := w.Header().Get("Location") + assert.Contains(ts.T(), loc, tC.expectedredirectURL) // redirected url starts with per test value u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) @@ -1178,6 +1178,20 @@ func (ts *VerifyTestSuite) TestPrepRedirectURL() { flowType: models.ImplicitFlow, expected: fmt.Sprintf("https://example.com/?first=another#message=%s&sb=", escapedMessage), }, + { + desc: "(Implicit): custom scheme preserved", + message: singleConfirmationAccepted, + rurl: "com.myapp://", + flowType: models.ImplicitFlow, + expected: fmt.Sprintf("com.myapp://#message=%s&sb=", escapedMessage), + }, + { + desc: "(PKCE): custom scheme preserved", + message: singleConfirmationAccepted, + rurl: "com.myapp://", + flowType: models.PKCEFlow, + expected: fmt.Sprintf("com.myapp://?message=%s#message=%s&sb=", escapedMessage, escapedMessage), + }, } for _, c := range cases { ts.Run(c.desc, func() {