Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 32 additions & 27 deletions cmd/siftd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,17 @@ const (
)

type config struct {
ListenAddr string
PostgresDSN string
RegistryPath string
OutputDir string
SyncInterval time.Duration
SyncTimeout time.Duration
RetentionWindow time.Duration
SyncOnStart bool
ZitadelIssuer string
ZitadelAudience string
WSAllowedOrigins []string
ListenAddr string
PostgresDSN string
RegistryPath string
OutputDir string
SyncInterval time.Duration
SyncTimeout time.Duration
RetentionWindow time.Duration
SyncOnStart bool
ZitadelIssuer string
ZitadelAudience string
AllowedOrigins []string
}

func main() {
Expand Down Expand Up @@ -68,11 +68,11 @@ func run(ctx context.Context) error {
}

api, err := hosted.New(hosted.Options{
Store: store,
Validator: validator,
OutputDir: cfg.OutputDir,
AllowedWebSocketOrigins: cfg.WSAllowedOrigins,
Now: func() time.Time { return time.Now().UTC() },
Store: store,
Validator: validator,
OutputDir: cfg.OutputDir,
AllowedBrowserOrigins: cfg.AllowedOrigins,
Now: func() time.Time { return time.Now().UTC() },
})
if err != nil {
return err
Expand Down Expand Up @@ -190,18 +190,23 @@ func loadConfigFromEnv() (config, error) {
return config{}, err
}

allowedOrigins := parseCSVEnv("SIFTD_ALLOWED_ORIGINS")
if len(allowedOrigins) == 0 {
allowedOrigins = parseCSVEnv("SIFTD_WS_ALLOWED_ORIGINS")
}

cfg := config{
ListenAddr: envOrDefault("SIFTD_ADDR", defaultListenAddr),
PostgresDSN: strings.TrimSpace(os.Getenv("SIFTD_POSTGRES_DSN")),
RegistryPath: envOrDefault("SIFTD_REGISTRY", defaultRegistryPath),
OutputDir: envOrDefault("SIFTD_OUTPUT_DIR", defaultOutputDir),
SyncInterval: syncInterval,
SyncTimeout: syncTimeout,
RetentionWindow: retentionWindow,
SyncOnStart: syncOnStart,
ZitadelIssuer: strings.TrimSpace(os.Getenv("SIFTD_ZITADEL_ISSUER")),
ZitadelAudience: strings.TrimSpace(os.Getenv("SIFTD_ZITADEL_AUDIENCE")),
WSAllowedOrigins: parseCSVEnv("SIFTD_WS_ALLOWED_ORIGINS"),
ListenAddr: envOrDefault("SIFTD_ADDR", defaultListenAddr),
PostgresDSN: strings.TrimSpace(os.Getenv("SIFTD_POSTGRES_DSN")),
RegistryPath: envOrDefault("SIFTD_REGISTRY", defaultRegistryPath),
OutputDir: envOrDefault("SIFTD_OUTPUT_DIR", defaultOutputDir),
SyncInterval: syncInterval,
SyncTimeout: syncTimeout,
RetentionWindow: retentionWindow,
SyncOnStart: syncOnStart,
ZitadelIssuer: strings.TrimSpace(os.Getenv("SIFTD_ZITADEL_ISSUER")),
ZitadelAudience: strings.TrimSpace(os.Getenv("SIFTD_ZITADEL_AUDIENCE")),
AllowedOrigins: allowedOrigins,
}

if cfg.PostgresDSN == "" {
Expand Down
27 changes: 22 additions & 5 deletions cmd/siftd/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ func TestLoadConfigFromEnv(t *testing.T) {
t.Setenv("SIFTD_SYNC_TIMEOUT", "30s")
t.Setenv("SIFTD_RETENTION", "720h")
t.Setenv("SIFTD_SYNC_ON_START", "false")
t.Setenv("SIFTD_WS_ALLOWED_ORIGINS", "https://sift.local, https://console.sift.local")
t.Setenv("SIFTD_ALLOWED_ORIGINS", "https://sift.local, https://console.sift.local")

cfg, err := loadConfigFromEnv()
if err != nil {
Expand All @@ -29,11 +29,28 @@ func TestLoadConfigFromEnv(t *testing.T) {
if cfg.RetentionWindow.String() != "720h0m0s" {
t.Fatalf("unexpected retention window: %s", cfg.RetentionWindow)
}
if len(cfg.WSAllowedOrigins) != 2 {
t.Fatalf("unexpected ws allowed origins count: %d", len(cfg.WSAllowedOrigins))
if len(cfg.AllowedOrigins) != 2 {
t.Fatalf("unexpected allowed origins count: %d", len(cfg.AllowedOrigins))
}
if cfg.WSAllowedOrigins[0] != "https://sift.local" {
t.Fatalf("unexpected first ws allowed origin: %s", cfg.WSAllowedOrigins[0])
if cfg.AllowedOrigins[0] != "https://sift.local" {
t.Fatalf("unexpected first allowed origin: %s", cfg.AllowedOrigins[0])
}
}

func TestLoadConfigFromEnvFallsBackToLegacyWSOrigins(t *testing.T) {
t.Setenv("SIFTD_POSTGRES_DSN", "postgres://user:pass@localhost:5432/sift?sslmode=disable")
t.Setenv("SIFTD_ZITADEL_ISSUER", "https://auth.example.com")
t.Setenv("SIFTD_ZITADEL_AUDIENCE", "audience")
t.Setenv("SIFTD_ALLOWED_ORIGINS", "")
t.Setenv("SIFTD_WS_ALLOWED_ORIGINS", "https://legacy.sift.local")

cfg, err := loadConfigFromEnv()
if err != nil {
t.Fatalf("loadConfigFromEnv returned error: %v", err)
}

if len(cfg.AllowedOrigins) != 1 || cfg.AllowedOrigins[0] != "https://legacy.sift.local" {
t.Fatalf("unexpected legacy fallback origins: %#v", cfg.AllowedOrigins)
}
}

Expand Down
4 changes: 4 additions & 0 deletions docs/contracts/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ paths:

Stream messages are advisory. Clients should re-fetch canonical records
over REST when full event truth is required.

Non-browser clients may authenticate with `Authorization: Bearer <token>`.
Browser clients should use `Sec-WebSocket-Protocol: sift.v1, bearer.<token>`
so the server can validate the bearer token without exposing it in the URL.
responses:
"101":
description: Switching protocols to WebSocket
Expand Down
19 changes: 14 additions & 5 deletions docs/runbooks/siftd.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Before deployment, have all of the following:
- reachable Postgres DSN for the shared Sift database;
- reachable Zitadel issuer URL;
- Zitadel audience for the API application;
- one allowed browser origin if WebSocket is used from a browser UI;
- one allowed browser origin if REST or WebSocket is used from a browser UI;
- repo checkout path `/srv/sift/current`.

## Expected Host Layout
Expand Down Expand Up @@ -107,13 +107,14 @@ Recommended variables:
- `SIFTD_SYNC_TIMEOUT=4m`
- `SIFTD_RETENTION=720h`
- `SIFTD_SYNC_ON_START=true`
- `SIFTD_WS_ALLOWED_ORIGINS=https://console.example.com`
- `SIFTD_ALLOWED_ORIGINS=https://skill7.dev`

Operational notes:

- `SIFTD_RETENTION=720h` means `30d`.
- `SIFTD_SYNC_TIMEOUT` must stay below `SIFTD_SYNC_INTERVAL`.
- if `SIFTD_WS_ALLOWED_ORIGINS` is set, browser clients must send one of those origins;
- if `SIFTD_ALLOWED_ORIGINS` is set, browser REST requests and browser WebSocket clients must send one of those origins;
- `SIFTD_WS_ALLOWED_ORIGINS` is still accepted as a legacy fallback alias for older deployments;
- keep `SIFTD_ADDR` bound to localhost unless TLS termination is handled directly in-process.

## Start and Verify
Expand Down Expand Up @@ -205,6 +206,14 @@ If `/readyz` stays degraded:

If browser WebSocket connection fails:

- confirm the client uses `Authorization: Bearer <token>`;
- confirm request `Origin` matches `SIFTD_WS_ALLOWED_ORIGINS`;
- confirm browser clients send `Sec-WebSocket-Protocol: sift.v1, bearer.<token>`;
- confirm non-browser clients use `Authorization: Bearer <token>` if they do not support subprotocol auth;
- confirm request `Origin` matches `SIFTD_ALLOWED_ORIGINS`;
- confirm the reverse proxy forwards the WebSocket upgrade headers unchanged.
- scrub or disable logging of `Sec-WebSocket-Protocol` in the reverse proxy, because `bearer.<token>` can otherwise land in access/proxy logs.

If browser REST requests fail with CORS errors:

- confirm `SIFTD_ALLOWED_ORIGINS` includes the browser origin exactly;
- confirm the browser is calling the public HTTPS hostname rather than the pod or cluster IP;
- confirm the reverse proxy preserves the `Origin` header unchanged.
127 changes: 106 additions & 21 deletions internal/hosted/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ import (
)

const (
defaultListLimit = 20
maxListLimit = 100
defaultListLimit = 20
maxListLimit = 100
webSocketProtocol = "sift.v1"
webSocketBearerPrefix = "bearer."
)

type EventStore interface {
Expand All @@ -35,6 +37,7 @@ type Options struct {
Store EventStore
Validator zitadel.Validator
OutputDir string
AllowedBrowserOrigins []string
AllowedWebSocketOrigins []string
Now func() time.Time
}
Expand All @@ -55,7 +58,7 @@ type Server struct {
clients map[*wsClient]struct{}
upgrader websocket.Upgrader

allowedWebSocketOrigins map[string]struct{}
allowedBrowserOrigins map[string]struct{}
}

type wsClient struct {
Expand Down Expand Up @@ -101,25 +104,28 @@ func New(options Options) (*Server, error) {
now = func() time.Time { return time.Now().UTC() }
}

allowedOrigins := make(map[string]struct{}, len(options.AllowedWebSocketOrigins))
for _, rawOrigin := range options.AllowedWebSocketOrigins {
allowedOrigins := make(map[string]struct{}, len(options.AllowedBrowserOrigins)+len(options.AllowedWebSocketOrigins))
for _, rawOrigin := range append(options.AllowedBrowserOrigins, options.AllowedWebSocketOrigins...) {
origin, err := normalizeOrigin(rawOrigin)
if err != nil {
return nil, fmt.Errorf("invalid websocket allowed origin %q: %w", rawOrigin, err)
return nil, fmt.Errorf("invalid allowed origin %q: %w", rawOrigin, err)
}
allowedOrigins[origin] = struct{}{}
}

server := &Server{
store: options.Store,
validator: options.Validator,
outputDir: outputDir,
now: now,
clients: make(map[*wsClient]struct{}),
allowedWebSocketOrigins: allowedOrigins,
store: options.Store,
validator: options.Validator,
outputDir: outputDir,
now: now,
clients: make(map[*wsClient]struct{}),
allowedBrowserOrigins: allowedOrigins,
}
server.upgrader = websocket.Upgrader{
CheckOrigin: server.checkWebSocketOrigin,
Subprotocols: []string{
webSocketProtocol,
},
}

return server, nil
Expand All @@ -146,28 +152,33 @@ func normalizeOrigin(raw string) (string, error) {
}

func (s *Server) checkWebSocketOrigin(r *http.Request) bool {
originHeader := strings.TrimSpace(r.Header.Get("Origin"))
_, ok := s.allowedOrigin(r.Header.Get("Origin"), r.Host)
return ok
}

func (s *Server) allowedOrigin(rawOrigin, requestHost string) (string, bool) {
originHeader := strings.TrimSpace(rawOrigin)
if originHeader == "" {
// If an explicit allowlist is configured, require Origin header presence.
return len(s.allowedWebSocketOrigins) == 0
return "", len(s.allowedBrowserOrigins) == 0
}

origin, err := normalizeOrigin(originHeader)
if err != nil {
return false
return "", false
}

if len(s.allowedWebSocketOrigins) > 0 {
_, ok := s.allowedWebSocketOrigins[origin]
return ok
if len(s.allowedBrowserOrigins) > 0 {
_, ok := s.allowedBrowserOrigins[origin]
return origin, ok
}

originURL, err := url.Parse(origin)
if err != nil {
return false
return "", false
}

return strings.EqualFold(originURL.Host, r.Host)
return origin, strings.EqualFold(originURL.Host, requestHost)
}

func (s *Server) Handler() http.Handler {
Expand All @@ -178,7 +189,7 @@ func (s *Server) Handler() http.Handler {
mux.HandleFunc("/v1/events/", s.requireAuth(s.handleGetEvent))
mux.HandleFunc("/v1/digests/", s.requireAuth(s.handleGetDigest))
mux.HandleFunc("/v1/ws", s.requireAuth(s.handleWebSocket))
return mux
return s.withCORS(mux)
}

func (s *Server) MarkSyncSuccess(runID string, at time.Time) {
Expand Down Expand Up @@ -411,6 +422,63 @@ func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
}
}

func (s *Server) withCORS(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !isRESTAPIPath(r.URL.Path) {
next.ServeHTTP(w, r)
return
}

applyCORSVaryHeaders(w.Header())

origin, ok := s.allowedOrigin(r.Header.Get("Origin"), r.Host)
if strings.TrimSpace(r.Header.Get("Origin")) != "" && !ok {
writeJSONError(w, http.StatusForbidden, "origin not allowed")
return
}

if ok && origin != "" {
applyCORSHeaders(w.Header(), origin)
}

if isPreflightRequest(r) {
if !ok || origin == "" {
writeJSONError(w, http.StatusForbidden, "origin not allowed")
return
}
if !strings.EqualFold(strings.TrimSpace(r.Header.Get("Access-Control-Request-Method")), http.MethodGet) {
writeJSONError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
w.WriteHeader(http.StatusNoContent)
return
}

next.ServeHTTP(w, r)
})
}

func isRESTAPIPath(path string) bool {
return path == "/v1/events" || strings.HasPrefix(path, "/v1/events/") || strings.HasPrefix(path, "/v1/digests/")
}

func isPreflightRequest(r *http.Request) bool {
return r.Method == http.MethodOptions && strings.TrimSpace(r.Header.Get("Access-Control-Request-Method")) != ""
}

func applyCORSHeaders(header http.Header, origin string) {
header.Set("Access-Control-Allow-Origin", origin)
header.Set("Access-Control-Allow-Methods", "GET, OPTIONS")
header.Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
header.Set("Access-Control-Max-Age", "600")
}

func applyCORSVaryHeaders(header http.Header) {
header.Add("Vary", "Origin")
header.Add("Vary", "Access-Control-Request-Method")
header.Add("Vary", "Access-Control-Request-Headers")
}

func (s *Server) requireAuth(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
token, err := tokenFromRequest(r)
Expand All @@ -432,10 +500,27 @@ func tokenFromRequest(r *http.Request) (string, error) {
if token, err := zitadel.ExtractBearerToken(r.Header.Get("Authorization")); err == nil {
return token, nil
}
if websocket.IsWebSocketUpgrade(r) {
if token, ok := tokenFromWebSocketSubprotocols(r); ok {
return token, nil
}
}

return "", fmt.Errorf("missing bearer token")
}

func tokenFromWebSocketSubprotocols(r *http.Request) (string, bool) {
for _, protocol := range websocket.Subprotocols(r) {
if strings.HasPrefix(protocol, webSocketBearerPrefix) {
token := strings.TrimPrefix(protocol, webSocketBearerPrefix)
if token != "" {
return token, true
}
}
}
return "", false
}

func (s *Server) registerClient(client *wsClient) {
s.wsMu.Lock()
defer s.wsMu.Unlock()
Expand Down
Loading