diff --git a/internal/bootstrap/bootstrap.go b/internal/bootstrap/bootstrap.go index 59cbc6f5..8587cfed 100644 --- a/internal/bootstrap/bootstrap.go +++ b/internal/bootstrap/bootstrap.go @@ -13,6 +13,7 @@ import ( "github.com/appleboy/graceful" "github.com/gin-gonic/gin" + "github.com/prometheus/client_golang/prometheus" "github.com/redis/go-redis/v9" ) @@ -142,6 +143,9 @@ func (app *Application) initializeInfrastructure(ctx context.Context) error { // initializeBusinessLayer sets up services func (app *Application) initializeBusinessLayer() { // Audit service (required by other services) + if app.Config.MetricsEnabled { + services.SetAuditMetricsRegisterer(prometheus.DefaultRegisterer) + } if app.Config.EnableAuditLogging { app.AuditService = services.NewAuditService( app.DB, diff --git a/internal/bootstrap/router.go b/internal/bootstrap/router.go index 9de6de51..0c0a3064 100644 --- a/internal/bootstrap/router.go +++ b/internal/bootstrap/router.go @@ -44,7 +44,7 @@ func setupRouter( // Setup middleware r.Use(metrics.HTTPMetricsMiddleware(prometheusMetrics)) r.Use(gin.Logger(), gin.Recovery()) - r.Use(middleware.IPMiddleware()) + r.Use(middleware.RequestContextMiddleware()) r.Use(middleware.SecurityHeaders(strings.HasPrefix(cfg.BaseURL, "https://"))) // Setup session middleware diff --git a/internal/handlers/auth.go b/internal/handlers/auth.go index cb98b6f5..e0316344 100644 --- a/internal/handlers/auth.go +++ b/internal/handlers/auth.go @@ -170,7 +170,7 @@ func (h *AuthHandler) Login(c *gin.Context, // Set session fingerprint if enabled if h.cfg.SessionFingerprint { - clientIP := c.GetString(middleware.ContextKeyClientIP) // Set by IPMiddleware + clientIP := c.GetString(middleware.ContextKeyClientIP) // Set by RequestContextMiddleware userAgent := c.Request.UserAgent() fingerprint := middleware.GenerateFingerprint( clientIP, diff --git a/internal/handlers/oauth_handler.go b/internal/handlers/oauth_handler.go index bc3aef4c..9ff9d8bf 100644 --- a/internal/handlers/oauth_handler.go +++ b/internal/handlers/oauth_handler.go @@ -230,7 +230,7 @@ func (h *OAuthHandler) OAuthCallback(c *gin.Context) { // Set session fingerprint if enabled if h.sessionFingerprintEnabled { - clientIP := c.GetString(middleware.ContextKeyClientIP) // Set by IPMiddleware + clientIP := c.GetString(middleware.ContextKeyClientIP) // Set by RequestContextMiddleware userAgent := c.Request.UserAgent() fingerprint := middleware.GenerateFingerprint( clientIP, diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index 42f7946c..5383387f 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -146,7 +146,7 @@ func SessionFingerprintMiddleware(enabled, includeIP bool) gin.HandlerFunc { if storedFingerprint != nil { // Get current fingerprint - clientIP := c.GetString(ContextKeyClientIP) // Set by IPMiddleware + clientIP := c.GetString(ContextKeyClientIP) // Set by RequestContextMiddleware userAgent := c.Request.UserAgent() currentFingerprint := GenerateFingerprint(clientIP, userAgent, includeIP) diff --git a/internal/middleware/context.go b/internal/middleware/context.go index 9d63b17d..3d1a0f97 100644 --- a/internal/middleware/context.go +++ b/internal/middleware/context.go @@ -8,15 +8,24 @@ import ( // ContextKeyClientIP is the gin context key for the client IP address. const ContextKeyClientIP = "client_ip" -// IPMiddleware extracts client IP and stores it in the context -func IPMiddleware() gin.HandlerFunc { +// RequestContextMiddleware extracts client IP and HTTP request metadata +// (User-Agent, path, method) and stores them in the request context for +// downstream services (e.g. audit logging). +func RequestContextMiddleware() gin.HandlerFunc { return func(c *gin.Context) { clientIP := c.ClientIP() // Gin's ClientIP() handles X-Forwarded-For and other headers c.Set(ContextKeyClientIP, clientIP) - // Also store in request context for services layer - c.Request = c.Request.WithContext(util.SetIPContext(c.Request.Context(), clientIP)) + // Store IP and request metadata in request context for services layer + ctx := util.SetIPContext(c.Request.Context(), clientIP) + ctx = util.SetRequestMetadataContext( + ctx, + c.Request.UserAgent(), + c.Request.URL.Path, + c.Request.Method, + ) + c.Request = c.Request.WithContext(ctx) c.Next() } diff --git a/internal/services/audit.go b/internal/services/audit.go index 68d73e20..f8c940c4 100644 --- a/internal/services/audit.go +++ b/internal/services/audit.go @@ -6,6 +6,7 @@ import ( "log" "strings" "sync" + "sync/atomic" "time" "github.com/go-authgate/authgate/internal/core" @@ -14,11 +15,78 @@ import ( "github.com/go-authgate/authgate/internal/util" "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" ) // Compile-time interface check. var _ core.AuditLogger = (*AuditService)(nil) +// auditEventsDropped is a singleton counter protected by a mutex so it can +// be created before metrics are configured and registered later once a +// registerer becomes available. +// +// The counter is only registered with Prometheus when a registerer is +// explicitly provided via SetAuditMetricsRegisterer, so deployments with +// metrics disabled do not leak collectors from the services layer. +var ( + auditEventsDropped prometheus.Counter + auditEventsDroppedMu sync.Mutex + auditEventsDroppedRegistered bool + auditEventsDroppedRegisterer prometheus.Registerer +) + +// registerAuditDroppedCounterLocked attempts to register the existing counter +// with the configured registerer. The caller must hold auditEventsDroppedMu. +func registerAuditDroppedCounterLocked() { + if auditEventsDropped == nil || + auditEventsDroppedRegisterer == nil || + auditEventsDroppedRegistered { + return + } + if err := auditEventsDroppedRegisterer.Register(auditEventsDropped); err != nil { + if existing, ok := err.(prometheus.AlreadyRegisteredError); ok { + if c, ok := existing.ExistingCollector.(prometheus.Counter); ok { + auditEventsDropped = c + auditEventsDroppedRegistered = true + return + } + } + log.Printf("failed to register audit dropped-events counter: %v", err) + return + } + auditEventsDroppedRegistered = true +} + +// SetAuditMetricsRegisterer configures the Prometheus registerer used by the +// audit service. If the dropped-events counter was created before metrics +// were configured, setting a non-nil registerer will register the existing +// counter, ensuring late initialization (e.g. in tests) is not silently lost. +func SetAuditMetricsRegisterer(registerer prometheus.Registerer) { + auditEventsDroppedMu.Lock() + defer auditEventsDroppedMu.Unlock() + + auditEventsDroppedRegisterer = registerer + registerAuditDroppedCounterLocked() +} + +func getAuditEventsDroppedCounter() prometheus.Counter { + auditEventsDroppedMu.Lock() + defer auditEventsDroppedMu.Unlock() + + if auditEventsDropped == nil { + // Use a single fully-prefixed Name (no Namespace/Subsystem) to + // match the convention used by metrics in internal/metrics. + auditEventsDropped = prometheus.NewCounter(prometheus.CounterOpts{ + Name: "audit_events_dropped_total", + Help: "Total number of audit log events dropped due to a full buffer.", + }) + } + registerAuditDroppedCounterLocked() + // When no registerer is set, the counter still works in-memory but + // is not exposed via the Prometheus /metrics endpoint. + return auditEventsDropped +} + // AuditService handles audit logging operations type AuditService struct { store core.Store @@ -33,8 +101,12 @@ type AuditService struct { batchTicker *time.Ticker // Graceful shutdown - wg sync.WaitGroup - shutdownCh chan struct{} + wg sync.WaitGroup + sendMu sync.RWMutex // coordinates Log() senders with Shutdown() + stopped atomic.Bool + + // Prometheus counter for dropped events + eventsDropped prometheus.Counter } // NewAuditService creates a new audit service @@ -44,11 +116,11 @@ func NewAuditService(s core.Store, bufferSize int) *AuditService { } service := &AuditService{ - store: s, - bufferSize: bufferSize, - logChan: make(chan *models.AuditLog, bufferSize), - batchBuffer: make([]*models.AuditLog, 0, 100), - shutdownCh: make(chan struct{}), + store: s, + bufferSize: bufferSize, + logChan: make(chan *models.AuditLog, bufferSize), + batchBuffer: make([]*models.AuditLog, 0, 100), + eventsDropped: getAuditEventsDroppedCounter(), } service.batchTicker = time.NewTicker(1 * time.Second) @@ -59,23 +131,25 @@ func NewAuditService(s core.Store, bufferSize int) *AuditService { return service } -// worker is the background goroutine that processes audit logs +// worker is the background goroutine that processes audit logs. +// It drains logChan until the channel is closed by Shutdown, then +// flushes any remaining batch and exits. func (s *AuditService) worker() { defer s.wg.Done() for { select { - case log := <-s.logChan: - s.addToBatch(log) + case entry, ok := <-s.logChan: + if !ok { + // Channel closed by Shutdown — flush remaining batch. + s.flushBatch() + return + } + s.addToBatch(entry) case <-s.batchTicker.C: // Flush batch every second s.flushBatch() - - case <-s.shutdownCh: - // Flush remaining logs before shutdown - s.flushBatch() - return } } } @@ -126,14 +200,52 @@ func (s *AuditService) buildAuditLog( if entry.ActorIP == "" { entry.ActorIP = util.GetIPFromContext(ctx) } + // Fill ActorUsername from context only when the entry's ActorUserID is + // empty or matches the context user — otherwise the username could be + // misattributed to a different principal than ActorUserID identifies. if entry.ActorUsername == "" { - entry.ActorUsername = models.GetUsernameFromContext(ctx) + ctxUserID := models.GetUserIDFromContext(ctx) + if entry.ActorUserID == "" || entry.ActorUserID == ctxUserID { + entry.ActorUsername = models.GetUsernameFromContext(ctx) + } + } + // Fall back to a DB lookup when context did not provide a username and + // the actor is a real user (not a synthetic machine identity from the + // client_credentials grant, which uses a "client:" format and + // has no corresponding user row). + if entry.ActorUsername == "" && entry.ActorUserID != "" && + !strings.HasPrefix(entry.ActorUserID, "client:") { + if user, err := s.store.GetUserByID(entry.ActorUserID); err == nil { + entry.ActorUsername = user.Username + } } if entry.ActorUserID == "" { entry.ActorUserID = models.GetUserIDFromContext(ctx) } + if entry.UserAgent == "" { + entry.UserAgent = util.GetUserAgentFromContext(ctx) + } + if entry.RequestPath == "" { + entry.RequestPath = util.GetRequestPathFromContext(ctx) + } + if entry.RequestMethod == "" { + entry.RequestMethod = util.GetRequestMethodFromContext(ctx) + } entry.Details = maskSensitiveDetails(entry.Details) + // Truncate fields to match database column size limits. + // TruncateString appends "..." (3 chars) when truncating, so subtract 3 + // from the varchar limit to guarantee the final length fits the column. + entry.UserAgent = util.TruncateString(entry.UserAgent, 497) + entry.RequestPath = util.TruncateString(entry.RequestPath, 497) + + // RequestMethod is stored in a varchar(10) column. Preserve values up to + // the full column width and hard-truncate anything longer without adding + // an ellipsis. + if len(entry.RequestMethod) > 10 { + entry.RequestMethod = entry.RequestMethod[:10] + } + now := time.Now() return &models.AuditLog{ ID: uuid.New().String(), @@ -157,13 +269,25 @@ func (s *AuditService) buildAuditLog( } } -// Log records an audit log entry asynchronously +// Log records an audit log entry asynchronously. +// Events submitted after Shutdown has been called are dropped. +// The RWMutex ensures all in-flight sends complete before Shutdown +// closes logChan, eliminating the send-on-closed-channel race. func (s *AuditService) Log(ctx context.Context, entry core.AuditLogEntry) { + s.sendMu.RLock() + defer s.sendMu.RUnlock() + + if s.stopped.Load() { + log.Printf("WARNING: Audit service stopped, dropping event: %s", entry.Action) + s.eventsDropped.Inc() + return + } auditLog := s.buildAuditLog(ctx, entry) select { case s.logChan <- auditLog: default: log.Printf("WARNING: Audit log buffer full, dropping event: %s", entry.Action) + s.eventsDropped.Inc() } } @@ -193,12 +317,19 @@ func (s *AuditService) GetAuditLogStats(startTime, endTime time.Time) (store.Aud // Shutdown gracefully shuts down the audit service func (s *AuditService) Shutdown(ctx context.Context) error { + // 1. Reject new events so future Log() calls return immediately. + s.stopped.Store(true) + + // 2. Wait for all in-flight Log() calls to finish, then close + // logChan. The exclusive lock ensures no sender is mid-send + // when the channel is closed. + s.sendMu.Lock() + close(s.logChan) + s.sendMu.Unlock() + // Stop ticker s.batchTicker.Stop() - // Signal worker to stop - close(s.shutdownCh) - // Wait for worker to finish with timeout done := make(chan struct{}) go func() { diff --git a/internal/services/audit_test.go b/internal/services/audit_test.go index a8bc5a8e..9f14f568 100644 --- a/internal/services/audit_test.go +++ b/internal/services/audit_test.go @@ -1,10 +1,17 @@ package services import ( + "context" + "fmt" "testing" + "time" + "github.com/go-authgate/authgate/internal/core" "github.com/go-authgate/authgate/internal/models" + storetypes "github.com/go-authgate/authgate/internal/store/types" + "github.com/go-authgate/authgate/internal/util" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMaskSensitiveDetails_FullRedaction(t *testing.T) { @@ -70,3 +77,163 @@ func TestMaskSensitiveDetails_PlainField(t *testing.T) { func TestMaskSensitiveDetails_Nil(t *testing.T) { assert.Nil(t, maskSensitiveDetails(nil)) } + +func TestBuildAuditLog_EnrichesRequestMetadataFromContext(t *testing.T) { + svc := &AuditService{} + + ctx := context.Background() + ctx = util.SetIPContext(ctx, "10.0.0.1") + ctx = util.SetRequestMetadataContext(ctx, "Mozilla/5.0", "/oauth/token", "POST") + + entry := core.AuditLogEntry{ + EventType: models.EventAccessTokenIssued, + Severity: models.SeverityInfo, + Action: "test", + Success: true, + } + + result := svc.buildAuditLog(ctx, entry) + + assert.Equal(t, "10.0.0.1", result.ActorIP) + assert.Equal(t, "Mozilla/5.0", result.UserAgent) + assert.Equal(t, "/oauth/token", result.RequestPath) + assert.Equal(t, "POST", result.RequestMethod) +} + +func TestBuildAuditLog_DoesNotOverrideExplicitValues(t *testing.T) { + svc := &AuditService{} + + ctx := context.Background() + ctx = util.SetIPContext(ctx, "10.0.0.1") + ctx = util.SetRequestMetadataContext(ctx, "Mozilla/5.0", "/oauth/token", "POST") + + entry := core.AuditLogEntry{ + EventType: models.EventAccessTokenIssued, + Severity: models.SeverityInfo, + ActorIP: "192.168.1.1", + UserAgent: "custom-agent", + RequestPath: "/custom/path", + RequestMethod: "GET", + Action: "test", + Success: true, + } + + result := svc.buildAuditLog(ctx, entry) + + // Explicit values should be preserved, not overwritten by context + assert.Equal(t, "192.168.1.1", result.ActorIP) + assert.Equal(t, "custom-agent", result.UserAgent) + assert.Equal(t, "/custom/path", result.RequestPath) + assert.Equal(t, "GET", result.RequestMethod) +} + +func TestBuildAuditLog_EnrichesUserFromContext(t *testing.T) { + svc := &AuditService{} + + user := &models.User{ + ID: "user-123", + Username: "testuser", + } + ctx := models.SetUserContext(context.Background(), user) + + entry := core.AuditLogEntry{ + EventType: models.EventAccessTokenIssued, + Severity: models.SeverityInfo, + Action: "test", + Success: true, + } + + result := svc.buildAuditLog(ctx, entry) + + assert.Equal(t, "user-123", result.ActorUserID) + assert.Equal(t, "testuser", result.ActorUsername) +} + +func TestBuildAuditLog_FillsActorUsernameFromDBFallback(t *testing.T) { + // When context has no user but the entry's ActorUserID points to a real + // user, buildAuditLog should resolve the username via a DB lookup. + s := setupTestStore(t) + user := &models.User{ + ID: "fallback-user-id", + Username: "fallback-user", + Email: "fallback@example.com", + PasswordHash: "x", + AuthSource: models.AuthSourceLocal, + } + require.NoError(t, s.CreateUser(user)) + + svc := &AuditService{store: s} + + entry := core.AuditLogEntry{ + EventType: models.EventAccessTokenIssued, + Severity: models.SeverityInfo, + ActorUserID: user.ID, + Action: "test", + Success: true, + } + + result := svc.buildAuditLog(context.Background(), entry) + + assert.Equal(t, user.ID, result.ActorUserID) + assert.Equal(t, "fallback-user", result.ActorUsername) +} + +func TestBuildAuditLog_SkipsDBLookupForMachineIdentity(t *testing.T) { + // Synthetic machine identities (client_credentials grant) use the + // "client:" format and have no user row, so buildAuditLog + // must not attempt a DB lookup. A nil store would panic if a query ran. + svc := &AuditService{} + + entry := core.AuditLogEntry{ + EventType: models.EventClientCredentialsTokenIssued, + Severity: models.SeverityInfo, + ActorUserID: "client:test-client-id", + Action: "test", + Success: true, + } + + result := svc.buildAuditLog(context.Background(), entry) + + assert.Equal(t, "client:test-client-id", result.ActorUserID) + assert.Empty(t, result.ActorUsername) +} + +func TestShutdown_DrainsLogChan(t *testing.T) { + // Construct the service struct directly (without starting the worker) + // so we can populate the channel deterministically before the drain runs. + s := setupTestStore(t) + svc := &AuditService{ + store: s, + bufferSize: 100, + logChan: make(chan *models.AuditLog, 100), + batchBuffer: make([]*models.AuditLog, 0, 100), + eventsDropped: getAuditEventsDroppedCounter(), + } + + // Populate the channel before the worker starts + const numEntries = 5 + for i := range numEntries { + svc.logChan <- &models.AuditLog{ + ID: fmt.Sprintf("drain-test-%d", i), + EventType: models.EventAccessTokenIssued, + Severity: models.SeverityInfo, + Action: "drain-test", + } + } + + // Now start the worker and immediately shut down + svc.batchTicker = time.NewTicker(1 * time.Second) + svc.wg.Add(1) + go svc.worker() + + err := svc.Shutdown(context.Background()) + require.NoError(t, err) + + // Verify entries were persisted to the store + logs, _, err := s.GetAuditLogsPaginated( + storetypes.PaginationParams{Page: 1, PageSize: 10}, + storetypes.AuditLogFilters{}, + ) + require.NoError(t, err) + assert.GreaterOrEqual(t, len(logs), numEntries, "all drain-test entries should be persisted") +} diff --git a/internal/services/token_exchange.go b/internal/services/token_exchange.go index 89180553..4e5c08e4 100644 --- a/internal/services/token_exchange.go +++ b/internal/services/token_exchange.go @@ -77,7 +77,7 @@ func (s *TokenService) ExchangeDeviceCode( // Delete the used device code _ = s.store.DeleteDeviceCodeByID(dc.ID) - // Log token issuance + // Log token issuance — ActorUsername is auto-resolved by buildAuditLog. s.auditService.Log(ctx, core.AuditLogEntry{ EventType: models.EventAccessTokenIssued, Severity: models.SeverityInfo, @@ -153,25 +153,30 @@ func (s *TokenService) ExchangeAuthorizationCode( AtHash: token.ComputeAtHash(accessToken.RawToken), } - // Fetch user profile for scope-gated claims - if user, err := s.store.GetUserByID(authCode.UserID); err == nil { - if scopeSet["profile"] { - params.Name = user.FullName - params.PreferredUsername = user.Username - params.Picture = user.AvatarURL - updatedAt := user.UpdatedAt - params.UpdatedAt = &updatedAt + // Fetch user profile only when scope-gated claims are needed + if scopeSet["profile"] || scopeSet["email"] { + if user, err := s.store.GetUserByID(authCode.UserID); err == nil { + // Cache the user in context so the audit service's + // ActorUsername enrichment hits context (no extra DB call). + ctx = models.SetUserContext(ctx, user) + if scopeSet["profile"] { + params.Name = user.FullName + params.PreferredUsername = user.Username + params.Picture = user.AvatarURL + updatedAt := user.UpdatedAt + params.UpdatedAt = &updatedAt + } + if scopeSet["email"] { + params.Email = user.Email + params.EmailVerified = false // AuthGate does not verify email addresses + } + } else { + log.Printf( + "[Token] ID token: failed to fetch user profile for user_id=%s, profile/email claims will be omitted: %v", + authCode.UserID, + err, + ) } - if scopeSet["email"] { - params.Email = user.Email - params.EmailVerified = false // AuthGate does not verify email addresses - } - } else if scopeSet["profile"] || scopeSet["email"] { - log.Printf( - "[Token] ID token: failed to fetch user profile for user_id=%s, profile/email claims will be omitted: %v", - authCode.UserID, - err, - ) } if generated, err := idp.GenerateIDToken(params); err == nil { @@ -202,7 +207,9 @@ func (s *TokenService) ExchangeAuthorizationCode( s.metrics.RecordTokenIssued("access", "authorization_code", duration, providerName) s.metrics.RecordTokenIssued("refresh", "authorization_code", duration, providerName) - // Audit + // Audit — ActorUsername is auto-resolved by buildAuditLog (from the + // context user cached above when openid+profile/email was requested, + // or via DB fallback otherwise). s.auditService.Log(ctx, core.AuditLogEntry{ EventType: models.EventAccessTokenIssued, Severity: models.SeverityInfo, diff --git a/internal/services/token_management.go b/internal/services/token_management.go index 811cf17f..0b67159c 100644 --- a/internal/services/token_management.go +++ b/internal/services/token_management.go @@ -42,7 +42,7 @@ func (s *TokenService) RevokeTokenByID(ctx context.Context, tokenID, actorUserID err = s.store.RevokeToken(tokenID) if err != nil { - // Log revocation failure + // Log revocation failure — ActorUsername is auto-resolved by buildAuditLog. s.auditService.Log(ctx, core.AuditLogEntry{ EventType: models.EventTokenRevoked, Severity: models.SeverityError, @@ -62,7 +62,7 @@ func (s *TokenService) RevokeTokenByID(ctx context.Context, tokenID, actorUserID // Record revocation s.metrics.RecordTokenRevoked(tok.TokenCategory, "user_request") - // Log token revocation + // Log token revocation — ActorUsername is auto-resolved by buildAuditLog. s.auditService.Log(ctx, core.AuditLogEntry{ EventType: models.EventTokenRevoked, Severity: models.SeverityInfo, @@ -133,7 +133,7 @@ func (s *TokenService) updateTokenStatusWithAudit( err = s.store.UpdateTokenStatus(tokenID, newStatus) if err != nil { - // Log failure + // Log failure — ActorUsername is auto-resolved by buildAuditLog. s.auditService.Log(ctx, core.AuditLogEntry{ EventType: eventType, Severity: models.SeverityError, @@ -150,7 +150,7 @@ func (s *TokenService) updateTokenStatusWithAudit( s.invalidateTokenCache(ctx, tok.TokenHash) - // Log success + // Log success — ActorUsername is auto-resolved by buildAuditLog. s.auditService.Log(ctx, core.AuditLogEntry{ EventType: eventType, Severity: models.SeverityInfo, diff --git a/internal/services/token_refresh.go b/internal/services/token_refresh.go index 42c5fa96..4a164059 100644 --- a/internal/services/token_refresh.go +++ b/internal/services/token_refresh.go @@ -52,7 +52,8 @@ func (s *TokenService) revokeTokenFamilyWithAudit( s.metrics.RecordTokenRevoked("family", "replay_detection") } - // Audit log — CRITICAL severity because this indicates potential token theft + // Audit log — CRITICAL severity because this indicates potential token theft. + // ActorUsername is auto-resolved by buildAuditLog. _ = s.auditService.LogSync(ctx, core.AuditLogEntry{ EventType: models.EventSuspiciousActivity, Severity: models.SeverityCritical, @@ -204,7 +205,7 @@ func (s *TokenService) RefreshAccessToken( // Record successful refresh s.metrics.RecordTokenRefresh(true) - // Log token refresh + // Log token refresh — ActorUsername is auto-resolved by buildAuditLog. providerName := s.tokenProvider.Name() details := models.AuditDetails{ "client_id": newAccessToken.ClientID, diff --git a/internal/util/context.go b/internal/util/context.go index d60dc27c..5332423b 100644 --- a/internal/util/context.go +++ b/internal/util/context.go @@ -9,6 +9,9 @@ type contextKey int const ( contextKeyClientIP contextKey = iota + contextKeyUserAgent + contextKeyRequestPath + contextKeyRequestMethod ) // SetIPContext embeds client IP into a standard context @@ -27,3 +30,38 @@ func GetIPFromContext(ctx context.Context) string { } return "" } + +// SetRequestMetadataContext embeds HTTP request metadata into a standard context. +func SetRequestMetadataContext( + ctx context.Context, + userAgent, path, method string, +) context.Context { + ctx = context.WithValue(ctx, contextKeyUserAgent, userAgent) + ctx = context.WithValue(ctx, contextKeyRequestPath, path) + ctx = context.WithValue(ctx, contextKeyRequestMethod, method) + return ctx +} + +// GetUserAgentFromContext extracts the User-Agent from the context. +func GetUserAgentFromContext(ctx context.Context) string { + if v, ok := ctx.Value(contextKeyUserAgent).(string); ok { + return v + } + return "" +} + +// GetRequestPathFromContext extracts the request path from the context. +func GetRequestPathFromContext(ctx context.Context) string { + if v, ok := ctx.Value(contextKeyRequestPath).(string); ok { + return v + } + return "" +} + +// GetRequestMethodFromContext extracts the HTTP method from the context. +func GetRequestMethodFromContext(ctx context.Context) string { + if v, ok := ctx.Value(contextKeyRequestMethod).(string); ok { + return v + } + return "" +} diff --git a/internal/util/context_test.go b/internal/util/context_test.go index f147608c..17ac3bdc 100644 --- a/internal/util/context_test.go +++ b/internal/util/context_test.go @@ -81,6 +81,35 @@ func TestGetIPFromContext(t *testing.T) { } } +func TestSetRequestMetadataContext(t *testing.T) { + ctx := context.Background() + ctx = SetRequestMetadataContext(ctx, "Mozilla/5.0", "/oauth/token", "POST") + + assert.Equal(t, "Mozilla/5.0", GetUserAgentFromContext(ctx)) + assert.Equal(t, "/oauth/token", GetRequestPathFromContext(ctx)) + assert.Equal(t, "POST", GetRequestMethodFromContext(ctx)) +} + +func TestRequestMetadataContext_Empty(t *testing.T) { + ctx := context.Background() + + assert.Empty(t, GetUserAgentFromContext(ctx)) + assert.Empty(t, GetRequestPathFromContext(ctx)) + assert.Empty(t, GetRequestMethodFromContext(ctx)) +} + +func TestRequestMetadataContextChaining(t *testing.T) { + ctx := context.Background() + ctx = SetIPContext(ctx, "10.0.0.1") + ctx = SetRequestMetadataContext(ctx, "curl/7.68", "/api/v1", "GET") + + // All values should coexist + assert.Equal(t, "10.0.0.1", GetIPFromContext(ctx)) + assert.Equal(t, "curl/7.68", GetUserAgentFromContext(ctx)) + assert.Equal(t, "/api/v1", GetRequestPathFromContext(ctx)) + assert.Equal(t, "GET", GetRequestMethodFromContext(ctx)) +} + func TestIPContextChaining(t *testing.T) { type testKey int const testKeyOther testKey = 0