Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions internal/bootstrap/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/appleboy/graceful"
"github.com/gin-gonic/gin"
"github.com/prometheus/client_golang/prometheus"
"github.com/redis/go-redis/v9"
)

Expand Down Expand Up @@ -142,6 +143,9 @@ func (app *Application) initializeInfrastructure(ctx context.Context) error {
// initializeBusinessLayer sets up services
func (app *Application) initializeBusinessLayer() {
// Audit service (required by other services)
if app.Config.MetricsEnabled {
services.SetAuditMetricsRegisterer(prometheus.DefaultRegisterer)
}
if app.Config.EnableAuditLogging {
app.AuditService = services.NewAuditService(
app.DB,
Expand Down
2 changes: 1 addition & 1 deletion internal/bootstrap/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func setupRouter(
// Setup middleware
r.Use(metrics.HTTPMetricsMiddleware(prometheusMetrics))
r.Use(gin.Logger(), gin.Recovery())
r.Use(middleware.IPMiddleware())
r.Use(middleware.RequestContextMiddleware())
r.Use(middleware.SecurityHeaders(strings.HasPrefix(cfg.BaseURL, "https://")))

// Setup session middleware
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func (h *AuthHandler) Login(c *gin.Context,

// Set session fingerprint if enabled
if h.cfg.SessionFingerprint {
clientIP := c.GetString(middleware.ContextKeyClientIP) // Set by IPMiddleware
clientIP := c.GetString(middleware.ContextKeyClientIP) // Set by RequestContextMiddleware
userAgent := c.Request.UserAgent()
fingerprint := middleware.GenerateFingerprint(
clientIP,
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/oauth_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func (h *OAuthHandler) OAuthCallback(c *gin.Context) {

// Set session fingerprint if enabled
if h.sessionFingerprintEnabled {
clientIP := c.GetString(middleware.ContextKeyClientIP) // Set by IPMiddleware
clientIP := c.GetString(middleware.ContextKeyClientIP) // Set by RequestContextMiddleware
userAgent := c.Request.UserAgent()
fingerprint := middleware.GenerateFingerprint(
clientIP,
Expand Down
2 changes: 1 addition & 1 deletion internal/middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func SessionFingerprintMiddleware(enabled, includeIP bool) gin.HandlerFunc {

if storedFingerprint != nil {
// Get current fingerprint
clientIP := c.GetString(ContextKeyClientIP) // Set by IPMiddleware
clientIP := c.GetString(ContextKeyClientIP) // Set by RequestContextMiddleware
userAgent := c.Request.UserAgent()
currentFingerprint := GenerateFingerprint(clientIP, userAgent, includeIP)

Expand Down
17 changes: 13 additions & 4 deletions internal/middleware/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,24 @@ import (
// ContextKeyClientIP is the gin context key for the client IP address.
const ContextKeyClientIP = "client_ip"

// IPMiddleware extracts client IP and stores it in the context
func IPMiddleware() gin.HandlerFunc {
// RequestContextMiddleware extracts client IP and HTTP request metadata
// (User-Agent, path, method) and stores them in the request context for
// downstream services (e.g. audit logging).
func RequestContextMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
clientIP := c.ClientIP()
// Gin's ClientIP() handles X-Forwarded-For and other headers
c.Set(ContextKeyClientIP, clientIP)

// Also store in request context for services layer
c.Request = c.Request.WithContext(util.SetIPContext(c.Request.Context(), clientIP))
// Store IP and request metadata in request context for services layer
ctx := util.SetIPContext(c.Request.Context(), clientIP)
ctx = util.SetRequestMetadataContext(
ctx,
c.Request.UserAgent(),
c.Request.URL.Path,
c.Request.Method,
)
c.Request = c.Request.WithContext(ctx)

c.Next()
}
Expand Down
171 changes: 151 additions & 20 deletions internal/services/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"log"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/go-authgate/authgate/internal/core"
Expand All @@ -14,11 +15,78 @@ import (
"github.com/go-authgate/authgate/internal/util"

"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
)

// Compile-time interface check.
var _ core.AuditLogger = (*AuditService)(nil)

// auditEventsDropped is a singleton counter protected by a mutex so it can
// be created before metrics are configured and registered later once a
// registerer becomes available.
//
// The counter is only registered with Prometheus when a registerer is
// explicitly provided via SetAuditMetricsRegisterer, so deployments with
// metrics disabled do not leak collectors from the services layer.
//
// All four variables below are read and written only while
// auditEventsDroppedMu is held (see SetAuditMetricsRegisterer,
// getAuditEventsDroppedCounter and registerAuditDroppedCounterLocked).
var (
// auditEventsDropped is the shared counter instance. It is nil until
// getAuditEventsDroppedCounter creates it on first use.
auditEventsDropped prometheus.Counter
// auditEventsDroppedMu guards every variable in this block.
auditEventsDroppedMu sync.Mutex
// auditEventsDroppedRegistered is set once the counter has been accepted by
// (or found already present in) the configured registerer.
auditEventsDroppedRegistered bool
// auditEventsDroppedRegisterer is the registerer supplied through
// SetAuditMetricsRegisterer; nil means no registration is attempted.
auditEventsDroppedRegisterer prometheus.Registerer
)

// registerAuditDroppedCounterLocked registers the dropped-events counter with
// the configured registerer if that has not happened yet. It does nothing
// when the counter has not been created, when no registerer is configured,
// or when the counter is already marked as registered.
//
// If the registerer reports an AlreadyRegisteredError whose existing
// collector is a prometheus.Counter, that collector replaces the
// package-level counter and registration is treated as successful. Any other
// failure is logged and the counter stays unregistered.
//
// The caller must hold auditEventsDroppedMu.
func registerAuditDroppedCounterLocked() {
	pending := auditEventsDropped != nil &&
		auditEventsDroppedRegisterer != nil &&
		!auditEventsDroppedRegistered
	if !pending {
		return
	}

	err := auditEventsDroppedRegisterer.Register(auditEventsDropped)
	if err == nil {
		auditEventsDroppedRegistered = true
		return
	}

	// A collector with the same identity is already present: adopt it as the
	// package-level counter instead of reporting a failure.
	if dup, isDup := err.(prometheus.AlreadyRegisteredError); isDup {
		if prior, isCounter := dup.ExistingCollector.(prometheus.Counter); isCounter {
			auditEventsDropped = prior
			auditEventsDroppedRegistered = true
			return
		}
	}

	log.Printf("failed to register audit dropped-events counter: %v", err)
}

// SetAuditMetricsRegisterer configures the Prometheus registerer used by the
// audit service's dropped-events counter.
//
// If the counter was created before metrics were configured, setting a
// non-nil registerer registers the existing counter immediately, ensuring
// late initialization (e.g. in tests) is not silently lost. If the counter
// does not exist yet, registration happens when it is first created.
//
// The function may be called repeatedly. Each call targets the registerer it
// is given, so switching to a different registry (for example a fresh
// registry per test) registers the counter there as well. Passing nil stops
// further registration attempts; it does not unregister the counter from a
// previously configured registerer.
func SetAuditMetricsRegisterer(registerer prometheus.Registerer) {
	auditEventsDroppedMu.Lock()
	defer auditEventsDroppedMu.Unlock()

	auditEventsDroppedRegisterer = registerer

	// The "registered" flag describes the previous registerer. Clear it so
	// the counter is registered against the one just supplied; otherwise a
	// second call with a different registry would be a silent no-op.
	// Re-registering with the same registerer is harmless: Register returns
	// AlreadyRegisteredError, which registerAuditDroppedCounterLocked treats
	// as success.
	auditEventsDroppedRegistered = false
	registerAuditDroppedCounterLocked()
}

// getAuditEventsDroppedCounter returns the shared dropped-events counter,
// creating it on first use. Every call also attempts registration with the
// configured registerer (a no-op when none is set or when the counter is
// already registered). Access is serialized by auditEventsDroppedMu.
func getAuditEventsDroppedCounter() prometheus.Counter {
	auditEventsDroppedMu.Lock()
	defer auditEventsDroppedMu.Unlock()

	if auditEventsDropped != nil {
		registerAuditDroppedCounterLocked()
		return auditEventsDropped
	}

	// A single fully-prefixed Name is used (no Namespace/Subsystem) to match
	// the convention used by metrics in internal/metrics.
	opts := prometheus.CounterOpts{
		Name: "audit_events_dropped_total",
		Help: "Total number of audit log events dropped due to a full buffer.",
	}
	auditEventsDropped = prometheus.NewCounter(opts)
	registerAuditDroppedCounterLocked()

	// Without a registerer the counter still counts in memory; it is simply
	// not exposed via the Prometheus /metrics endpoint.
	return auditEventsDropped
}

// AuditService handles audit logging operations
type AuditService struct {
store core.Store
Expand All @@ -33,8 +101,12 @@ type AuditService struct {
batchTicker *time.Ticker

// Graceful shutdown
wg sync.WaitGroup
shutdownCh chan struct{}
wg sync.WaitGroup
sendMu sync.RWMutex // coordinates Log() senders with Shutdown()
stopped atomic.Bool

// Prometheus counter for dropped events
eventsDropped prometheus.Counter
}

// NewAuditService creates a new audit service
Expand All @@ -44,11 +116,11 @@ func NewAuditService(s core.Store, bufferSize int) *AuditService {
}

service := &AuditService{
store: s,
bufferSize: bufferSize,
logChan: make(chan *models.AuditLog, bufferSize),
batchBuffer: make([]*models.AuditLog, 0, 100),
shutdownCh: make(chan struct{}),
store: s,
bufferSize: bufferSize,
logChan: make(chan *models.AuditLog, bufferSize),
batchBuffer: make([]*models.AuditLog, 0, 100),
eventsDropped: getAuditEventsDroppedCounter(),
}

service.batchTicker = time.NewTicker(1 * time.Second)
Expand All @@ -59,23 +131,25 @@ func NewAuditService(s core.Store, bufferSize int) *AuditService {
return service
}

// worker is the background goroutine that processes audit logs
// worker is the background goroutine that processes audit logs.
// It drains logChan until the channel is closed by Shutdown, then
// flushes any remaining batch and exits.
func (s *AuditService) worker() {
defer s.wg.Done()

for {
select {
case log := <-s.logChan:
s.addToBatch(log)
case entry, ok := <-s.logChan:
if !ok {
// Channel closed by Shutdown — flush remaining batch.
s.flushBatch()
return
}
s.addToBatch(entry)

case <-s.batchTicker.C:
// Flush batch every second
s.flushBatch()

case <-s.shutdownCh:
// Flush remaining logs before shutdown
s.flushBatch()
return
}
}
}
Expand Down Expand Up @@ -126,14 +200,52 @@ func (s *AuditService) buildAuditLog(
if entry.ActorIP == "" {
entry.ActorIP = util.GetIPFromContext(ctx)
}
// Fill ActorUsername from context only when the entry's ActorUserID is
// empty or matches the context user — otherwise the username could be
// misattributed to a different principal than ActorUserID identifies.
if entry.ActorUsername == "" {
entry.ActorUsername = models.GetUsernameFromContext(ctx)
ctxUserID := models.GetUserIDFromContext(ctx)
if entry.ActorUserID == "" || entry.ActorUserID == ctxUserID {
entry.ActorUsername = models.GetUsernameFromContext(ctx)
}
}
// Fall back to a DB lookup when context did not provide a username and
// the actor is a real user (not a synthetic machine identity from the
// client_credentials grant, which uses a "client:<clientID>" format and
// has no corresponding user row).
if entry.ActorUsername == "" && entry.ActorUserID != "" &&
!strings.HasPrefix(entry.ActorUserID, "client:") {
if user, err := s.store.GetUserByID(entry.ActorUserID); err == nil {
entry.ActorUsername = user.Username
}
}
if entry.ActorUserID == "" {
entry.ActorUserID = models.GetUserIDFromContext(ctx)
}
if entry.UserAgent == "" {
entry.UserAgent = util.GetUserAgentFromContext(ctx)
}
if entry.RequestPath == "" {
entry.RequestPath = util.GetRequestPathFromContext(ctx)
}
if entry.RequestMethod == "" {
entry.RequestMethod = util.GetRequestMethodFromContext(ctx)
}
entry.Details = maskSensitiveDetails(entry.Details)

// Truncate fields to match database column size limits.
// TruncateString appends "..." (3 chars) when truncating, so subtract 3
// from the varchar limit to guarantee the final length fits the column.
entry.UserAgent = util.TruncateString(entry.UserAgent, 497)
entry.RequestPath = util.TruncateString(entry.RequestPath, 497)

// RequestMethod is stored in a varchar(10) column. Preserve values up to
// the full column width and hard-truncate anything longer without adding
// an ellipsis.
if len(entry.RequestMethod) > 10 {
entry.RequestMethod = entry.RequestMethod[:10]
}

now := time.Now()
return &models.AuditLog{
ID: uuid.New().String(),
Expand All @@ -157,13 +269,25 @@ func (s *AuditService) buildAuditLog(
}
}

// Log records an audit log entry asynchronously by handing it to the
// background worker through logChan.
//
// The send is non-blocking: if the buffer is full the event is dropped, a
// warning is logged and the dropped-events counter is incremented. Events
// submitted after Shutdown has marked the service as stopped are dropped the
// same way.
//
// The read lock is held for the whole call. Shutdown takes the write lock
// before closing logChan, so no sender can be mid-send when the channel is
// closed, which eliminates the send-on-closed-channel race.
func (s *AuditService) Log(ctx context.Context, entry core.AuditLogEntry) {
	s.sendMu.RLock()
	defer s.sendMu.RUnlock()

	if !s.stopped.Load() {
		record := s.buildAuditLog(ctx, entry)
		select {
		case s.logChan <- record:
			return
		default:
			// Buffer is full; fall through to the drop accounting below.
		}
		log.Printf("WARNING: Audit log buffer full, dropping event: %s", entry.Action)
		s.eventsDropped.Inc()
		return
	}

	log.Printf("WARNING: Audit service stopped, dropping event: %s", entry.Action)
	s.eventsDropped.Inc()
}

Expand Down Expand Up @@ -193,12 +317,19 @@ func (s *AuditService) GetAuditLogStats(startTime, endTime time.Time) (store.Aud

// Shutdown gracefully shuts down the audit service
func (s *AuditService) Shutdown(ctx context.Context) error {
// 1. Reject new events so future Log() calls return immediately.
s.stopped.Store(true)

// 2. Wait for all in-flight Log() calls to finish, then close
// logChan. The exclusive lock ensures no sender is mid-send
// when the channel is closed.
s.sendMu.Lock()
close(s.logChan)
s.sendMu.Unlock()

// Stop ticker
s.batchTicker.Stop()

// Signal worker to stop
close(s.shutdownCh)

// Wait for worker to finish with timeout
done := make(chan struct{})
go func() {
Expand Down
Loading
Loading