Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions auth/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ func CsrfCheck(next echo.HandlerFunc) echo.HandlerFunc {
return next(c)
}

if c.Path() == "/oauth2/device/code" || c.Path() == "/oauth2/device/token" {
return next(c)
}

cookie, err := c.Cookie(CsrfCookieName)
if err != nil || cookie.Value == "" {
return c.String(http.StatusForbidden, "Missing CSRF cookie")
Expand Down
8 changes: 8 additions & 0 deletions auth/noauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@ func (p noauthProvider) GetSession(c echo.Context) (*Session, error) {
func (noauthProvider) DropSession(echo.Context, *Session) {
}

func (noauthProvider) GetRateLimiterMiddleware() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
return next(c)
}
}
}

func init() {
RegisterProvider(&noauthProvider{})
}
2 changes: 2 additions & 0 deletions auth/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ type Provider interface {
// GetSession retrieves the session associated with the given context.
GetSession(c echo.Context) (*Session, error)
DropSession(c echo.Context, session *Session)

GetRateLimiterMiddleware() echo.MiddlewareFunc
}

const AuthCookieName = "fioserver-session"
Expand Down
4 changes: 4 additions & 0 deletions auth/provider_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,7 @@ func (p *commonProvider) GetSession(c echo.Context) (*Session, error) {
}
return nil, p.renderer.renderLoginPage(c, "")
}

func (p *commonProvider) GetRateLimiterMiddleware() echo.MiddlewareFunc {
return p.rateLimiter.Middleware
}
5 changes: 5 additions & 0 deletions auth/provider_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/foundriesio/update-server/server/ui/web/templates"
"github.com/foundriesio/update-server/storage"
"github.com/foundriesio/update-server/storage/users"
"github.com/foundriesio/update-server/version"
)

const localLoginTemplate = "local-login.html"
Expand Down Expand Up @@ -198,10 +199,12 @@ func (p localProvider) renderLoginPage(c echo.Context, reason string) error {
User *users.User
NavItems []string
CsrfToken string
Version string
}{
Title: "Login",
Reason: reason,
CsrfToken: csrfToken,
Version: version.Version,
}
return templates.Templates.ExecuteTemplate(c.Response(), localLoginTemplate, context)
}
Expand Down Expand Up @@ -246,11 +249,13 @@ func (p *localProvider) handlePasswordPage(c echo.Context, session *Session) err
User *users.User
NavItems []string
CsrfToken string
Version string
}{
Title: "Change Password",
Message: "Your password has expired. Please choose a new password.",
User: session.User,
CsrfToken: csrfToken,
Version: version.Version,
}
return templates.Templates.ExecuteTemplate(c.Response(), localPasswordChangeTemplate, context)
}
Expand Down
156 changes: 150 additions & 6 deletions cli/subcommands/login/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@
package login

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"time"

"github.com/spf13/cobra"

"github.com/foundriesio/update-server/cli/config"
models "github.com/foundriesio/update-server/server/ui/api"
)

var LoginCmd = &cobra.Command{
Expand All @@ -27,23 +33,32 @@ the configuration to ~/.config/satcli.yaml.`,
token, _ := cmd.Flags().GetString("token")
setDefault, _ := cmd.Flags().GetBool("set-default")
configPath, _ := cmd.Flags().GetString("config")
scopes, _ := cmd.Flags().GetString("scopes")
expiresInDays, _ := cmd.Flags().GetInt("expires-in-days")

cobra.CheckErr(login(contextName, serverURL, token, configPath, setDefault))
cobra.CheckErr(login(configPath, contextName, serverURL, token, scopes, expiresInDays, setDefault))
},
}

func init() {
LoginCmd.Flags().String("token", "", "API token for authentication (required for now)")
LoginCmd.Flags().String("token", "", "API token for authentication (skips OAuth2 device flow)")
LoginCmd.Flags().Bool("set-default", true, "Set this context as the default")
LoginCmd.Flags().String("config", "", "Specify the configuration file to use")
cobra.CheckErr(LoginCmd.MarkFlagRequired("token"))
LoginCmd.Flags().String("scopes", "devices:read-update,updates:read-update", "Comma-separated list of OAuth2 scopes to request (optional)")
LoginCmd.Flags().Int("expires-in-days", 90, "Number of days until the access token expires")
}

func login(contextName, serverURL, token, configPath string, setDefault bool) error {
if token == "" {
return fmt.Errorf("--token is required")
func login(configPath, contextName, serverURL, token, scopes string, expiresInDays int, setDefault bool) error {
if token != "" {
return saveToken(configPath, contextName, serverURL, token, setDefault)
}

fmt.Println("Initiating OAuth2 device authorization flow...")
expires := time.Now().Add(time.Duration(expiresInDays) * 24 * time.Hour).Unix()
return oauth2DeviceFlow(configPath, contextName, serverURL, scopes, expires, setDefault)
}

func saveToken(configPath, contextName, serverURL, token string, setDefault bool) error {
// Load existing config or create new one
cfg, err := config.LoadConfig(configPath)
if err != nil {
Expand Down Expand Up @@ -80,3 +95,132 @@ func login(contextName, serverURL, token, configPath string, setDefault bool) er

return nil
}

type oauth2Error struct {
ErrorCode string `json:"error"`
ErrorDescription string `json:"error_description,omitempty"`
}

func (e *oauth2Error) Error() string {
if e.ErrorDescription != "" {
return fmt.Sprintf("%s: %s", e.ErrorCode, e.ErrorDescription)
}
return e.ErrorCode
}

func oauth2DeviceFlow(configPath, contextName, serverURL, scopes string, expires int64, setDefault bool) error {
// Step 1: Request device code
codeReq := models.DeviceCodeRequest{
Scopes: scopes,
TokenExpires: expires,
}
jsonData, err := json.Marshal(codeReq)
if err != nil {
return fmt.Errorf("failed to marshal request: %w", err)
}

resp, err := http.Post(serverURL+"/oauth2/device/code", "application/json", bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to request device code: %w", err)
}
defer resp.Body.Close() // nolint:errcheck

if resp.StatusCode != 200 {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("failed to get device code (status %d): %s", resp.StatusCode, string(body))
}

var codeResp models.DeviceCodeResponse
if err := json.NewDecoder(resp.Body).Decode(&codeResp); err != nil {
return fmt.Errorf("failed to decode device code response: %w", err)
}

// Step 2: Display user code and verification URI
fmt.Println()
fmt.Println("------------------------------------------------")
fmt.Printf(" Visit: %s\n", codeResp.VerificationURI)
fmt.Println()
fmt.Printf(" Enter code: %s\n", codeResp.UserCode)
fmt.Println("------------------------------------------------")
fmt.Println()
fmt.Println("Waiting for authorization...")

// Step 3: Poll for token
pollInterval := time.Duration(codeResp.Interval) * time.Second
expiresAt := time.Now().Add(time.Duration(codeResp.ExpiresIn) * time.Second)

for time.Now().Before(expiresAt) {
time.Sleep(pollInterval)

token, err := pollForToken(serverURL, codeResp.DeviceCode)
if err == nil {
// Success! Save the token
fmt.Println()
fmt.Println("✓ Authorization successful!")
return saveToken(configPath, contextName, serverURL, token, setDefault)
}

// Check if we should continue polling
if oauth2Err, ok := err.(*oauth2Error); ok {
switch oauth2Err.ErrorCode {
case "authorization_pending":
continue
case "slow_down":
pollInterval *= 2
continue
case "access_denied":
return fmt.Errorf("authorization was denied")
case "expired_token":
return fmt.Errorf("authorization code expired")
default:
return fmt.Errorf("OAuth2 error: %s - %s", oauth2Err.ErrorCode, oauth2Err.ErrorDescription)
}
}

return fmt.Errorf("failed to get token: %w", err)
}

return fmt.Errorf("authorization timed out")
}

func pollForToken(serverURL, deviceCode string) (string, error) {
tokenReq := models.DeviceTokenRequest{
DeviceCode: deviceCode,
GrantType: "urn:ietf:params:oauth:grant-type:device_code",
}

jsonData, err := json.Marshal(tokenReq)
if err != nil {
return "", fmt.Errorf("failed to marshal request: %w", err)
}

resp, err := http.Post(serverURL+"/oauth2/device/token", "application/json", bytes.NewBuffer(jsonData))
if err != nil {
return "", fmt.Errorf("failed to request token: %w", err)
}
defer func() {
if err := resp.Body.Close(); err != nil {
fmt.Fprintf(os.Stderr, "Warning: failed to close response body: %v\n", err)
}
}()

body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
}

if resp.StatusCode == 200 {
var tokenResp models.DeviceTokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return "", fmt.Errorf("failed to decode token response: %w", err)
}
return tokenResp.AccessToken, nil
}

var errResp oauth2Error
if err := json.Unmarshal(body, &errResp); err != nil {
return "", fmt.Errorf("request failed (status %d): %s", resp.StatusCode, string(body))
}

return "", &errResp
}
14 changes: 12 additions & 2 deletions server/ui/api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,22 @@ import (

type handlers struct {
storage *storage.Storage
users *users.Storage
}

var EchoError = server.EchoError

func RegisterHandlers(e *echo.Echo, storage *storage.Storage, a auth.Provider) {
h := handlers{storage: storage}
func RegisterHandlers(e *echo.Echo, storage *storage.Storage, userStorage *users.Storage, a auth.Provider) {
h := handlers{
storage: storage,
users: userStorage,
}

// OAuth2 endpoints (no authentication required)
oauth2 := oauth2Handlers{users: userStorage}
e.POST("/oauth2/device/code", oauth2.oauth2DeviceCode, a.GetRateLimiterMiddleware())
e.POST("/oauth2/device/token", oauth2.oauth2DeviceToken, a.GetRateLimiterMiddleware())

g := e.Group("/v1")
g.Use(authUser(a))

Expand Down
Loading
Loading