diff --git a/backend/api/azure_test.go b/backend/api/azure_test.go index c40312b..a43a4e8 100644 --- a/backend/api/azure_test.go +++ b/backend/api/azure_test.go @@ -3,6 +3,7 @@ package main import ( + "context" "os" "testing" @@ -68,13 +69,13 @@ func TestAzureTableIntegration(t *testing.T) { healthChecker := health.New() // For Azure Table, we add a mock health check since the repo is nil - healthChecker.AddCheck("database", func() error { + healthChecker.AddCheck("database", func(_ context.Context) error { return nil // Mock check since we expect repo to be nil in this test }) // The health check should return UP since we're using a mock check healthChecker.SetReady(true) - status := healthChecker.CheckReadiness() + status := healthChecker.CheckReadiness(context.Background()) assert.Equal(t, "UP", status.Status) }) diff --git a/backend/api/main.go b/backend/api/main.go index c13f405..adad117 100644 --- a/backend/api/main.go +++ b/backend/api/main.go @@ -10,7 +10,7 @@ import ( "backend/internal/health" "context" "fmt" - "log" + "log/slog" "net/http" "os" "os/signal" @@ -23,7 +23,7 @@ import ( ) const ( - gracefulShutdownTimeout = 5 * time.Second + defaultShutdownTimeout = 5 * time.Second ) // @title Backend API @@ -38,22 +38,28 @@ func main() { // Load configuration cfg, err := config.LoadConfig() if err != nil { - log.Fatalf("Failed to load configuration: %v", err) + slog.Error("Failed to load configuration", "error", err) + os.Exit(1) } - // Initialize database - db, err := database.NewFromAppConfig(cfg) + // Initialize repository using the factory (selects MySQL or Azure Table based on config) + repo, err := database.NewRepository(cfg) if err != nil { - log.Fatalf("Failed to initialize database: %v", err) + slog.Error("Failed to initialize repository", "error", err) + os.Exit(1) } - // Initialize health checker + // Initialize health checker with actual database dependency healthChecker := health.New() - healthChecker.AddCheck("database", db.Ping) + healthChecker.AddCheck("database", func(ctx context.Context) error { + return repo.Ping(ctx) + }) + healthChecker.SetReady(true) - // Setup router - router := gin.Default() - routes.SetupRoutes(router, db) + // Setup router — use gin.New() since SetupRoutes registers its own Logger and Recovery middleware. + router := gin.New() + rateLimiter := routes.SetupRoutes(router, repo, healthChecker, cfg) + defer rateLimiter.Stop() router.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler)) // Create server with timeouts @@ -67,8 +73,10 @@ func main() { // Start server in a goroutine go func() { + slog.Info("Server starting", "addr", srv.Addr) if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Fatalf("Failed to start server: %v", err) + slog.Error("Failed to start server", "error", err) + os.Exit(1) } }() @@ -76,20 +84,27 @@ func main() { quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) <-quit - log.Println("Shutting down server...") + slog.Info("Shutting down server...") - // Give outstanding requests 5 seconds to complete - ctx, cancel := context.WithTimeout(context.Background(), gracefulShutdownTimeout) + // Give outstanding requests time to complete + shutdownTimeout := cfg.Server.ShutdownTimeout + if shutdownTimeout == 0 { + shutdownTimeout = defaultShutdownTimeout + } + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) defer cancel() err = srv.Shutdown(ctx) - // Always execute cleanup - cancel() + + // Close repository connections (database pool, etc.) + if closeErr := repo.Close(); closeErr != nil { + slog.Error("Failed to close repository", "error", closeErr) + } if err != nil { - log.Printf("Server forced to shutdown: %v", err) - return // Return with error status from main + slog.Error("Server forced to shutdown", "error", err) + return } - log.Println("Server exiting") + slog.Info("Server exited gracefully") } diff --git a/backend/api/main_test.go b/backend/api/main_test.go index 2d571cd..2fe1a30 100644 --- a/backend/api/main_test.go +++ b/backend/api/main_test.go @@ -38,32 +38,37 @@ type MockRepository struct { mock.Mock } -func (m *MockRepository) Create(entity interface{}) error { - args := m.Called(entity) +func (m *MockRepository) Create(ctx context.Context, entity interface{}) error { + args := m.Called(ctx, entity) return args.Error(0) } -func (m *MockRepository) FindByID(id uint, dest interface{}) error { - args := m.Called(id, dest) +func (m *MockRepository) FindByID(ctx context.Context, id uint, dest interface{}) error { + args := m.Called(ctx, id, dest) return args.Error(0) } -func (m *MockRepository) Update(entity interface{}) error { - args := m.Called(entity) +func (m *MockRepository) Update(ctx context.Context, entity interface{}) error { + args := m.Called(ctx, entity) return args.Error(0) } -func (m *MockRepository) Delete(entity interface{}) error { - args := m.Called(entity) +func (m *MockRepository) Delete(ctx context.Context, entity interface{}) error { + args := m.Called(ctx, entity) return args.Error(0) } -func (m *MockRepository) List(dest interface{}, conditions ...interface{}) error { - args := m.Called(dest, conditions) +func (m *MockRepository) List(ctx context.Context, dest interface{}, conditions ...interface{}) error { + args := m.Called(ctx, dest, conditions) return args.Error(0) } -func (m *MockRepository) Ping() error { +func (m *MockRepository) Ping(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +func (m *MockRepository) Close() error { args := m.Called() return args.Error(0) } @@ -181,12 +186,12 @@ func TestHealthEndpoints(t *testing.T) { // Register health endpoints r.GET("/health/live", func(c *gin.Context) { - status := healthChecker.CheckLiveness() + status := healthChecker.CheckLiveness(c.Request.Context()) c.JSON(http.StatusOK, status) }) r.GET("/health/ready", func(c *gin.Context) { - status := healthChecker.CheckReadiness() + status := healthChecker.CheckReadiness(c.Request.Context()) if status.Status == "DOWN" { c.JSON(http.StatusServiceUnavailable, status) return @@ -211,7 +216,7 @@ func TestHealthEndpoints(t *testing.T) { assert.Contains(t, w.Body.String(), "UP") // Test with failing health check - healthChecker.AddCheck("test", func() error { + healthChecker.AddCheck("test", func(_ context.Context) error { return errors.New("test error") }) @@ -230,36 +235,36 @@ func TestDatabaseHealthCheck(t *testing.T) { // Test with SQLite database mockRepo := new(MockRepository) - mockRepo.On("Ping").Return(nil) + mockRepo.On("Ping", mock.Anything).Return(nil) // Add database health check - healthChecker.AddCheck("database", func() error { - return mockRepo.Ping() + healthChecker.AddCheck("database", func(_ context.Context) error { + return mockRepo.Ping(context.Background()) }) // Check readiness - status := healthChecker.CheckReadiness() + status := healthChecker.CheckReadiness(context.Background()) assert.Equal(t, "DOWN", status.Status) // Initially DOWN because we haven't set ready // Mark as ready healthChecker.SetReady(true) // Check again - status = healthChecker.CheckReadiness() + status = healthChecker.CheckReadiness(context.Background()) assert.Equal(t, "UP", status.Status) assert.Equal(t, "UP", status.Checks["database"].Status) // Test with failing database connection mockRepo = new(MockRepository) - mockRepo.On("Ping").Return(errors.New("connection failed")) + mockRepo.On("Ping", mock.Anything).Return(errors.New("connection failed")) healthChecker = health.New() healthChecker.SetReady(true) - healthChecker.AddCheck("database", func() error { - return mockRepo.Ping() + healthChecker.AddCheck("database", func(_ context.Context) error { + return mockRepo.Ping(context.Background()) }) - status = healthChecker.CheckReadiness() + status = healthChecker.CheckReadiness(context.Background()) assert.Equal(t, "DOWN", status.Status) assert.Equal(t, "DOWN", status.Checks["database"].Status) assert.Contains(t, status.Checks["database"].Message, "connection failed") diff --git a/backend/internal/api/handlers/handlers.go b/backend/internal/api/handlers/handlers.go index f916182..543c4ff 100644 --- a/backend/internal/api/handlers/handlers.go +++ b/backend/internal/api/handlers/handlers.go @@ -7,14 +7,6 @@ import ( "github.com/gin-gonic/gin" ) -var healthChecker *health.HealthChecker - -func init() { - healthChecker = health.New() - // Set the service as ready after initialization - healthChecker.SetReady(true) -} - // @Summary Health Check // @Description Get API health status // @Tags health @@ -35,40 +27,25 @@ func Ping(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "pong"}) } -// @Summary Get items -// @Description Get all items -// @Tags items -// @Produce json -// @Success 200 {object} map[string]string -// @Router /api/v1/items [get] -func GetItems(c *gin.Context) { - // Logic to retrieve items goes here - c.JSON(http.StatusOK, gin.H{"message": "GetItems called"}) -} - -// @Summary Create item -// @Description Create a new item -// @Tags items -// @Accept json -// @Produce json -// @Success 201 {object} map[string]string -// @Router /api/v1/items [post] -func CreateItem(c *gin.Context) { - // Logic to create an item goes here - c.JSON(http.StatusCreated, gin.H{"message": "CreateItem called"}) -} - +// LivenessHandler returns a handler for liveness checks. +// The health checker is injected so the same instance used in main is checked. +// // @Summary Liveness Check // @Description Get API liveness status // @Tags health // @Produce json // @Success 200 {object} health.HealthStatus // @Router /health/live [get] -func LivenessCheck(c *gin.Context) { - status := healthChecker.CheckLiveness() - c.JSON(http.StatusOK, status) +func LivenessHandler(hc *health.HealthChecker) gin.HandlerFunc { + return func(c *gin.Context) { + status := hc.CheckLiveness(c.Request.Context()) + c.JSON(http.StatusOK, status) + } } +// ReadinessHandler returns a handler for readiness checks. +// The health checker is injected so the same instance used in main is checked. +// // @Summary Readiness Check // @Description Get API readiness status // @Tags health @@ -76,11 +53,13 @@ func LivenessCheck(c *gin.Context) { // @Success 200 {object} health.HealthStatus // @Failure 503 {object} health.HealthStatus // @Router /health/ready [get] -func ReadinessCheck(c *gin.Context) { - status := healthChecker.CheckReadiness() - if status.Status == "DOWN" { - c.JSON(http.StatusServiceUnavailable, status) - return +func ReadinessHandler(hc *health.HealthChecker) gin.HandlerFunc { + return func(c *gin.Context) { + status := hc.CheckReadiness(c.Request.Context()) + if status.Status == "DOWN" { + c.JSON(http.StatusServiceUnavailable, status) + return + } + c.JSON(http.StatusOK, status) } - c.JSON(http.StatusOK, status) } diff --git a/backend/internal/api/handlers/handlers_test.go b/backend/internal/api/handlers/handlers_test.go index 372f732..587bf01 100644 --- a/backend/internal/api/handlers/handlers_test.go +++ b/backend/internal/api/handlers/handlers_test.go @@ -61,43 +61,3 @@ func TestPingHandler(t *testing.T) { assert.Nil(t, err) assert.Equal(t, "pong", response["message"]) } -func TestGetItemsHandler(t *testing.T) { - t.Parallel() - gin.SetMode(gin.TestMode) - - r := gin.Default() - r.GET("/items", GetItems) - - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/items", nil) - - r.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - - var response map[string]string - err := json.Unmarshal(w.Body.Bytes(), &response) - - assert.Nil(t, err) - assert.Equal(t, "GetItems called", response["message"]) -} -func TestCreateItemHandler(t *testing.T) { - t.Parallel() - gin.SetMode(gin.TestMode) - - r := gin.Default() - r.POST("/items", CreateItem) - - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/items", nil) - - r.ServeHTTP(w, req) - - assert.Equal(t, http.StatusCreated, w.Code) - - var response map[string]string - err := json.Unmarshal(w.Body.Bytes(), &response) - - assert.Nil(t, err) - assert.Equal(t, "CreateItem called", response["message"]) -} diff --git a/backend/internal/api/handlers/items.go b/backend/internal/api/handlers/items.go index ff29a15..b6c80f1 100644 --- a/backend/internal/api/handlers/items.go +++ b/backend/internal/api/handlers/items.go @@ -42,7 +42,8 @@ func handleDBError(err error) (int, string) { if strings.Contains(err.Error(), "not found") { return http.StatusNotFound, "Item not found" } - return http.StatusInternalServerError, err.Error() + // Never leak raw error messages to clients + return http.StatusInternalServerError, "Internal server error" } // CreateItem godoc @@ -67,7 +68,10 @@ func (h *Handler) CreateItem(c *gin.Context) { return } - if err := h.repository.Create(&item); err != nil { + // Version is server-managed; force initial value regardless of client input. + item.Version = 1 + + if err := h.repository.Create(c.Request.Context(), &item); err != nil { status, message := handleDBError(err) c.JSON(status, gin.H{"error": message}) return @@ -126,7 +130,7 @@ func (h *Handler) GetItems(c *gin.Context) { conditions = append(conditions, models.Pagination{Limit: limit, Offset: offset}) } - if err := h.repository.List(&items, conditions...); err != nil { + if err := h.repository.List(c.Request.Context(), &items, conditions...); err != nil { status, message := handleDBError(err) c.JSON(status, gin.H{"error": message}) return @@ -152,7 +156,7 @@ func (h *Handler) GetItem(c *gin.Context) { } var item models.Item - if err := h.repository.FindByID(uint(id), &item); err != nil { + if err := h.repository.FindByID(c.Request.Context(), uint(id), &item); err != nil { status, message := handleDBError(err) c.JSON(status, gin.H{"error": message}) return @@ -181,7 +185,7 @@ func (h *Handler) UpdateItem(c *gin.Context) { // Get the current version from the database var currentItem models.Item - if err := h.repository.FindByID(uint(id), ¤tItem); err != nil { + if err := h.repository.FindByID(c.Request.Context(), uint(id), ¤tItem); err != nil { status, message := handleDBError(err) c.JSON(status, gin.H{"error": message}) return @@ -193,23 +197,21 @@ func (h *Handler) UpdateItem(c *gin.Context) { return } - // Keep all existing fields but update name and price from the request + // Update fields from request currentItem.Name = updateItem.Name currentItem.Price = updateItem.Price - // Version check for optimistic locking (if version was provided in request) - if updateItem.Version > 0 && updateItem.Version != currentItem.Version { - c.JSON(http.StatusConflict, gin.H{"error": "Item has been modified by another request"}) - return + // Optimistic locking: if the client provided a version, use it so the + // repository can detect conflicts. If version=0 (not provided), the + // repository uses the version we just read — this still detects conflicts + // that occur between our FindByID and the repository's WHERE-version check, + // but the client must send the version to guarantee end-to-end safety. + if updateItem.Version > 0 { + currentItem.Version = updateItem.Version } - // Make sure we're using the version from the request for optimistic locking - currentItem.Version = updateItem.Version - - // We don't increment the version here, the repository will handle that - - if err := h.repository.Update(¤tItem); err != nil { - if err.Error() == "version mismatch" { + if err := h.repository.Update(c.Request.Context(), ¤tItem); err != nil { + if strings.Contains(err.Error(), "version mismatch") { c.JSON(http.StatusConflict, gin.H{"error": "Item has been modified by another request"}) return } @@ -237,9 +239,10 @@ func (h *Handler) DeleteItem(c *gin.Context) { return } - item := &models.Item{} - item.ID = uint(id) - if err := h.repository.Delete(item); err != nil { + // Delete directly — the repository returns ErrNotFound if the item doesn't exist. + // This avoids a race condition between a FindByID check and the actual delete. + item := &models.Item{Base: models.Base{ID: uint(id)}} + if err := h.repository.Delete(c.Request.Context(), item); err != nil { status, message := handleDBError(err) c.JSON(status, gin.H{"error": message}) return diff --git a/backend/internal/api/handlers/items_test.go b/backend/internal/api/handlers/items_test.go index 5bc49db..28e01d6 100644 --- a/backend/internal/api/handlers/items_test.go +++ b/backend/internal/api/handlers/items_test.go @@ -2,6 +2,7 @@ package handlers import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -124,7 +125,7 @@ func TestGetItem(t *testing.T) { // Create a test item testItem := &models.Item{Name: "Test Item", Price: 99.99} - mockRepo.Create(testItem) + mockRepo.Create(context.Background(), testItem) tests := []struct { wantItem *models.Item // 8 bytes (pointer) @@ -179,7 +180,7 @@ func TestUpdateItem(t *testing.T) { // Create a test item testItem := &models.Item{Name: "Test Item", Price: 99.99} - mockRepo.Create(testItem) + mockRepo.Create(context.Background(), testItem) tests := []struct { name string @@ -259,7 +260,7 @@ func TestDeleteItem(t *testing.T) { // Create a test item testItem := &models.Item{Name: "Test Item", Price: 99.99} - mockRepo.Create(testItem) + mockRepo.Create(context.Background(), testItem) tests := []struct { name string @@ -310,7 +311,7 @@ func TestListItems(t *testing.T) { } for _, item := range items { - mockRepo.Create(&item) + mockRepo.Create(context.Background(), &item) } tests := []struct { @@ -415,7 +416,7 @@ func TestListItemsErrors(t *testing.T) { var response map[string]string err := json.Unmarshal(w.Body.Bytes(), &response) assert.NoError(t, err) - assert.Contains(t, response["error"], "database error") + assert.Contains(t, response["error"], "Internal server error") } func TestConcurrentItemOperations(t *testing.T) { @@ -460,7 +461,7 @@ func TestConcurrentItemOperations(t *testing.T) { t.Run("concurrent item updates with version validation", func(t *testing.T) { // Create an item to update item := &models.Item{Name: "Test Item", Price: 99.99} - mockRepo.Create(item) + mockRepo.Create(context.Background(), item) itemID := fmt.Sprint(item.ID) // Channel to collect successful updates @@ -660,7 +661,7 @@ func BenchmarkItemOperations(b *testing.B) { method: "GET", pathGen: func(mockRepo *MockRepository) string { testItem := &models.Item{Name: "Test Item", Price: 99.99} - mockRepo.Create(testItem) + mockRepo.Create(context.Background(), testItem) return "/api/v1/items/" + fmt.Sprint(testItem.ID) }, bodyGen: func(_ *MockRepository) []byte { return nil }, @@ -678,7 +679,7 @@ func BenchmarkItemOperations(b *testing.B) { method: "PUT", pathGen: func(mockRepo *MockRepository) string { testItem := &models.Item{Name: "Test Item", Price: 99.99} - mockRepo.Create(testItem) + mockRepo.Create(context.Background(), testItem) return "/api/v1/items/" + fmt.Sprint(testItem.ID) }, bodyGen: func(mockRepo *MockRepository) []byte { @@ -918,7 +919,7 @@ func TestHandleDBError(t *testing.T) { {err: duplicateErr, wantCode: http.StatusConflict, wantMsg: "Item already exists"}, {err: otherDBErr, wantCode: http.StatusInternalServerError, wantMsg: "Internal server error"}, {err: plainNotFound, wantCode: http.StatusNotFound, wantMsg: "Item not found"}, - {err: plainOther, wantCode: http.StatusInternalServerError, wantMsg: plainOther.Error()}, + {err: plainOther, wantCode: http.StatusInternalServerError, wantMsg: "Internal server error"}, } for _, tt := range tests { diff --git a/backend/internal/api/handlers/mock_repository.go b/backend/internal/api/handlers/mock_repository.go index 38a15ff..e372b8d 100644 --- a/backend/internal/api/handlers/mock_repository.go +++ b/backend/internal/api/handlers/mock_repository.go @@ -1,12 +1,14 @@ package handlers import ( - "backend/internal/models" + "context" "errors" "fmt" "sort" "strings" "sync" + + "backend/internal/models" ) // MockRepository is a mock implementation of the Repository interface for testing @@ -24,7 +26,7 @@ func NewMockRepository() *MockRepository { } } -func (m *MockRepository) Create(entity interface{}) error { +func (m *MockRepository) Create(_ context.Context, entity interface{}) error { m.Lock() defer m.Unlock() @@ -38,13 +40,13 @@ func (m *MockRepository) Create(entity interface{}) error { } item.ID = m.nextID - item.Version = 0 // Initialize version + item.Version = 1 // Initialize version (1 = first version; 0 = "not provided" sentinel) m.nextID++ m.items[item.ID] = item return nil } -func (m *MockRepository) FindByID(id uint, dest interface{}) error { +func (m *MockRepository) FindByID(_ context.Context, id uint, dest interface{}) error { m.RLock() defer m.RUnlock() @@ -66,7 +68,7 @@ func (m *MockRepository) FindByID(id uint, dest interface{}) error { return nil } -func (m *MockRepository) Update(entity interface{}) error { +func (m *MockRepository) Update(_ context.Context, entity interface{}) error { m.Lock() defer m.Unlock() @@ -102,7 +104,7 @@ func (m *MockRepository) Update(entity interface{}) error { return nil } -func (m *MockRepository) Delete(entity interface{}) error { +func (m *MockRepository) Delete(_ context.Context, entity interface{}) error { m.Lock() defer m.Unlock() @@ -123,7 +125,7 @@ func (m *MockRepository) Delete(entity interface{}) error { return nil } -func (m *MockRepository) List(dest interface{}, conditions ...interface{}) error { +func (m *MockRepository) List(_ context.Context, dest interface{}, conditions ...interface{}) error { m.RLock() defer m.RUnlock() @@ -218,7 +220,12 @@ func (m *MockRepository) List(dest interface{}, conditions ...interface{}) error } // Ping implements the Repository interface -func (m *MockRepository) Ping() error { +func (m *MockRepository) Ping(_ context.Context) error { + return nil +} + +// Close implements the Repository interface +func (m *MockRepository) Close() error { return nil } diff --git a/backend/internal/api/handlers/rate_limiter.go b/backend/internal/api/handlers/rate_limiter.go index fd61757..4b983d1 100644 --- a/backend/internal/api/handlers/rate_limiter.go +++ b/backend/internal/api/handlers/rate_limiter.go @@ -13,34 +13,42 @@ type RateLimiter struct { sync.RWMutex // size: 8 window time.Duration // size: 8 requests map[string][]time.Time // size: 8 (pointer) + done chan struct{} // size: 8 limit int // size: 4 + stopOnce sync.Once // ensures Stop is idempotent } func NewRateLimiter(limit int, window time.Duration) *RateLimiter { - return &RateLimiter{ + rl := &RateLimiter{ requests: make(map[string][]time.Time), limit: limit, window: window, + done: make(chan struct{}), } + go rl.cleanup() + return rl +} + +// Stop terminates the background cleanup goroutine. +// It is safe to call Stop multiple times. +func (rl *RateLimiter) Stop() { + rl.stopOnce.Do(func() { + close(rl.done) + }) } func (rl *RateLimiter) RateLimit() gin.HandlerFunc { return func(c *gin.Context) { ip := c.ClientIP() - - rl.Lock() - defer rl.Unlock() - now := time.Now() windowStart := now.Add(-rl.window) - // Initialize requests map for this IP if needed - if _, exists := rl.requests[ip]; !exists { - rl.requests[ip] = make([]time.Time, 0) - } - - // Remove old requests (outside our time window) - var valid []time.Time + // Single write lock for the check-and-add to avoid a TOCTOU race + // where concurrent requests could both pass the limit check. + rl.Lock() + // Filter expired timestamps during counting to prevent unbounded + // slice growth between periodic cleanup cycles. + valid := rl.requests[ip][:0] for _, t := range rl.requests[ip] { if t.After(windowStart) { valid = append(valid, t) @@ -48,24 +56,49 @@ func (rl *RateLimiter) RateLimit() gin.HandlerFunc { } rl.requests[ip] = valid - // For tests - use a slightly lower limit to ensure some requests get rate limited - // This helps identify rate limiting in tests that send requests concurrently - effectiveLimit := rl.limit - if len(rl.requests[ip]) >= 30 && now.Nanosecond()%4 == 0 { - // Artificially lower the limit sometimes to ensure rate limiting occurs in tests - effectiveLimit = len(rl.requests[ip]) - } - - // Check if limit exceeded - this must be evaluated AFTER cleaning up old requests - // and BEFORE adding the current request to ensure accurate rate limiting - if len(rl.requests[ip]) >= effectiveLimit { + if len(valid) >= rl.limit { + rl.Unlock() c.JSON(http.StatusTooManyRequests, gin.H{"error": "rate limit exceeded"}) c.Abort() return } - // Add current request rl.requests[ip] = append(rl.requests[ip], now) + rl.Unlock() + c.Next() } } + +// cleanup periodically removes expired entries to prevent memory leaks. +func (rl *RateLimiter) cleanup() { + ticker := time.NewTicker(rl.window / 2) + defer ticker.Stop() + for { + select { + case <-ticker.C: + rl.cleanupExpired() + case <-rl.done: + return + } + } +} + +func (rl *RateLimiter) cleanupExpired() { + rl.Lock() + defer rl.Unlock() + now := time.Now() + for ip, times := range rl.requests { + var valid []time.Time + for _, t := range times { + if now.Sub(t) < rl.window { + valid = append(valid, t) + } + } + if len(valid) == 0 { + delete(rl.requests, ip) + } else { + rl.requests[ip] = valid + } + } +} diff --git a/backend/internal/api/middleware/cors_test.go b/backend/internal/api/middleware/cors_test.go index e9f7547..fe5d6de 100644 --- a/backend/internal/api/middleware/cors_test.go +++ b/backend/internal/api/middleware/cors_test.go @@ -11,39 +11,128 @@ import ( func TestCORSMiddleware(t *testing.T) { t.Parallel() - // Set Gin to Test Mode - gin.SetMode(gin.TestMode) - // Setup router with middleware - r := gin.New() - r.Use(CORS()) - r.Any("/test", func(c *gin.Context) { - c.Status(http.StatusOK) + t.Run("Wildcard allows all origins", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(CORS("*")) + r.Any("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + r.ServeHTTP(w, req) + + assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "GET, POST, PUT, DELETE, OPTIONS", w.Header().Get("Access-Control-Allow-Methods")) + assert.Equal(t, "Content-Type, Content-Length, Accept-Encoding, Authorization, X-Request-ID", w.Header().Get("Access-Control-Allow-Headers")) + assert.Equal(t, http.StatusOK, w.Code) }) - t.Run("Regular GET request", func(t *testing.T) { + t.Run("Empty string allows all origins", func(t *testing.T) { t.Parallel() + r := gin.New() + r.Use(CORS("")) + r.Any("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + w := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/test", nil) r.ServeHTTP(w, req) - // Check CORS headers assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin")) - assert.Equal(t, "GET, POST, PUT, DELETE, OPTIONS", w.Header().Get("Access-Control-Allow-Methods")) - assert.Equal(t, "Content-Type, Content-Length, Accept-Encoding, Authorization", w.Header().Get("Access-Control-Allow-Headers")) + }) + + t.Run("Allowed origin is set from whitelist", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(CORS("https://example.com,https://other.com")) + r.Any("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set("Origin", "https://example.com") + r.ServeHTTP(w, req) + + assert.Equal(t, "https://example.com", w.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "Origin", w.Header().Get("Vary")) assert.Equal(t, http.StatusOK, w.Code) }) - t.Run("OPTIONS preflight request", func(t *testing.T) { + t.Run("Second allowed origin is matched", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(CORS("https://example.com, https://other.com")) + r.Any("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set("Origin", "https://other.com") + r.ServeHTTP(w, req) + + assert.Equal(t, "https://other.com", w.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "Origin", w.Header().Get("Vary")) + }) + + t.Run("Disallowed origin gets no Access-Control-Allow-Origin", func(t *testing.T) { t.Parallel() + r := gin.New() + r.Use(CORS("https://example.com")) + r.Any("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set("Origin", "https://evil.com") + r.ServeHTTP(w, req) + + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + assert.Empty(t, w.Header().Get("Vary")) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods")) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers")) + assert.Equal(t, http.StatusForbidden, w.Code) + }) + + t.Run("No Origin header passes through as non-CORS request", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(CORS("https://example.com")) + r.Any("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + // No Origin header set — non-browser / same-origin request + r.ServeHTTP(w, req) + + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("Disallowed origin OPTIONS preflight returns 403 without CORS headers", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(CORS("https://example.com")) + r.Any("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("OPTIONS", "/test", nil) + req.Header.Set("Origin", "https://evil.com") + r.ServeHTTP(w, req) + + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods")) + assert.Equal(t, http.StatusForbidden, w.Code) + }) + + t.Run("OPTIONS preflight returns 204", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(CORS("*")) + r.Any("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + w := httptest.NewRecorder() req, _ := http.NewRequest("OPTIONS", "/test", nil) r.ServeHTTP(w, req) - // Check CORS headers assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin")) assert.Equal(t, "GET, POST, PUT, DELETE, OPTIONS", w.Header().Get("Access-Control-Allow-Methods")) - assert.Equal(t, "Content-Type, Content-Length, Accept-Encoding, Authorization", w.Header().Get("Access-Control-Allow-Headers")) - assert.Equal(t, http.StatusNoContent, w.Code) // OPTIONS request should return 204 No Content + assert.Equal(t, http.StatusNoContent, w.Code) }) } diff --git a/backend/internal/api/middleware/middleware.go b/backend/internal/api/middleware/middleware.go index 5c7fb71..f08a79a 100644 --- a/backend/internal/api/middleware/middleware.go +++ b/backend/internal/api/middleware/middleware.go @@ -1,18 +1,49 @@ package middleware import ( - "log" + "crypto/rand" + "errors" + "fmt" + "io" + "log/slog" "net/http" + "strings" + "time" "github.com/gin-gonic/gin" ) -// CORS middleware -func CORS() gin.HandlerFunc { +// CORS middleware with configurable allowed origins. +// Pass "*" or "" to allow all origins (development only). +// For production, pass a comma-separated list of allowed origins. +func CORS(allowedOrigins string) gin.HandlerFunc { return func(c *gin.Context) { - c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + if allowedOrigins == "" || allowedOrigins == "*" { + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + } else { + requestOrigin := c.Request.Header.Get("Origin") + if requestOrigin != "" { + allowed := false + for _, origin := range strings.Split(allowedOrigins, ",") { + if strings.TrimSpace(origin) == requestOrigin { + c.Writer.Header().Set("Access-Control-Allow-Origin", requestOrigin) + c.Writer.Header().Set("Vary", "Origin") + allowed = true + break + } + } + if !allowed { + // Block requests from non-whitelisted origins as defense-in-depth; + // browsers enforce CORS client-side, but we also enforce server-side. + c.AbortWithStatus(http.StatusForbidden) + return + } + } + // If there is no Origin header, treat this as a non-CORS request: + // allow it through without setting Access-Control-Allow-Origin. + } c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, Authorization") + c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, Authorization, X-Request-ID") if c.Request.Method == "OPTIONS" { c.AbortWithStatus(http.StatusNoContent) @@ -23,10 +54,13 @@ func CORS() gin.HandlerFunc { } } -// Logger is a middleware that logs the incoming requests. +// Logger is a middleware that logs incoming requests using structured logging. func Logger() gin.HandlerFunc { return func(c *gin.Context) { - log.Printf("Request: %s %s", c.Request.Method, c.Request.URL) + slog.Info("incoming request", + "method", c.Request.Method, + "path", c.Request.URL.Path, + ) c.Next() } } @@ -36,7 +70,7 @@ func Recovery() gin.HandlerFunc { return func(c *gin.Context) { defer func() { if err := recover(); err != nil { - log.Printf("Recovered from panic: %v", err) + slog.Error("recovered from panic", "error", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal Server Error"}) c.Abort() } @@ -44,3 +78,67 @@ func Recovery() gin.HandlerFunc { c.Next() } } + +// RequestID adds a unique request ID to each request. +// If the client sends an X-Request-ID header, it is reused; otherwise a new one is generated. +func RequestID() gin.HandlerFunc { + return func(c *gin.Context) { + requestID := c.GetHeader("X-Request-ID") + if requestID == "" { + requestID = generateRequestID() + } + c.Set("request_id", requestID) + c.Writer.Header().Set("X-Request-ID", requestID) + c.Next() + } +} + +// maxBytesBodyCapture wraps an io.ReadCloser to detect *http.MaxBytesError. +type maxBytesBodyCapture struct { + rc io.ReadCloser + exceeded *bool +} + +func (m *maxBytesBodyCapture) Read(p []byte) (int, error) { + n, err := m.rc.Read(p) + if err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + *m.exceeded = true + } + } + return n, err +} + +func (m *maxBytesBodyCapture) Close() error { + return m.rc.Close() +} + +// MaxBodySize limits the size of the request body to prevent memory exhaustion. +// Oversized payloads are translated to a 413 Request Entity Too Large response. +func MaxBodySize(maxBytes int64) gin.HandlerFunc { + return func(c *gin.Context) { + var exceeded bool + limitedReader := http.MaxBytesReader(c.Writer, c.Request.Body, maxBytes) + c.Request.Body = &maxBytesBodyCapture{rc: limitedReader, exceeded: &exceeded} + c.Next() + + // If the body size was exceeded and the handler has not yet written + // a response, return a 413 so clients get a clear signal. + if exceeded && !c.Writer.Written() { + c.AbortWithStatusJSON(http.StatusRequestEntityTooLarge, + gin.H{"error": "request body too large"}) + return + } + } +} + +func generateRequestID() string { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + // Fallback: use timestamp-based ID if crypto/rand fails + slog.Warn("failed to generate random request ID, using fallback", "error", err) + return fmt.Sprintf("fallback-%d", time.Now().UnixNano()) + } + return fmt.Sprintf("%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:]) +} diff --git a/backend/internal/api/middleware/middleware_test.go b/backend/internal/api/middleware/middleware_test.go index 6a05b29..73548dc 100644 --- a/backend/internal/api/middleware/middleware_test.go +++ b/backend/internal/api/middleware/middleware_test.go @@ -3,27 +3,33 @@ package middleware import ( "bytes" "encoding/json" - "log" + "errors" + "io" + "log/slog" "net/http" "net/http/httptest" "os" + "strings" "testing" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" ) -func TestLoggerMiddleware(t *testing.T) { - t.Parallel() - // Set Gin to Test Mode +func TestMain(m *testing.M) { gin.SetMode(gin.TestMode) + os.Exit(m.Run()) +} + +func TestLoggerMiddleware(t *testing.T) { + // Not parallel: this test mutates the global slog default logger. - // Create a buffer to capture log output + // Create a buffer to capture slog output var buf bytes.Buffer - log.SetOutput(&buf) - defer func() { - log.SetOutput(os.Stdout) - }() + handler := slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelInfo}) + origLogger := slog.Default() + slog.SetDefault(slog.New(handler)) + defer slog.SetDefault(origLogger) // Setup router with middleware r := gin.New() @@ -51,8 +57,6 @@ func TestLoggerMiddleware(t *testing.T) { func TestRecoveryMiddleware(t *testing.T) { t.Parallel() - // Set Gin to Test Mode - gin.SetMode(gin.TestMode) // Setup router with middleware r := gin.New() @@ -76,3 +80,116 @@ func TestRecoveryMiddleware(t *testing.T) { assert.Nil(t, err) assert.Equal(t, "Internal Server Error", response["error"]) } + +func TestRequestIDMiddleware(t *testing.T) { + t.Parallel() + + t.Run("Generates new request ID when none provided", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(RequestID()) + var capturedID string + r.GET("/test", func(c *gin.Context) { + if v, ok := c.Get("request_id"); ok { + capturedID = v.(string) + } + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.NotEmpty(t, w.Header().Get("X-Request-ID")) + assert.NotEmpty(t, capturedID) + assert.Equal(t, w.Header().Get("X-Request-ID"), capturedID) + }) + + t.Run("Reuses client-provided X-Request-ID", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(RequestID()) + var capturedID string + r.GET("/test", func(c *gin.Context) { + if v, ok := c.Get("request_id"); ok { + capturedID = v.(string) + } + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set("X-Request-ID", "client-id-123") + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "client-id-123", w.Header().Get("X-Request-ID")) + assert.Equal(t, "client-id-123", capturedID) + }) + + t.Run("Generated IDs are unique", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(RequestID()) + r.GET("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + + ids := make(map[string]bool) + for i := 0; i < 10; i++ { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + r.ServeHTTP(w, req) + id := w.Header().Get("X-Request-ID") + assert.False(t, ids[id], "duplicate request ID generated") + ids[id] = true + } + }) +} + +func TestMaxBodySizeMiddleware(t *testing.T) { + t.Parallel() + + t.Run("Allows request within size limit", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(MaxBodySize(1024)) // 1 KB + r.POST("/test", func(c *gin.Context) { + body := make([]byte, 512) + _, err := c.Request.Body.Read(body) + if err != nil && !errors.Is(err, io.EOF) { + c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "body too large"}) + return + } + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + body := strings.NewReader(strings.Repeat("a", 100)) + req, _ := http.NewRequest("POST", "/test", body) + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("Rejects request exceeding size limit", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(MaxBodySize(64)) // 64 bytes + r.POST("/test", func(c *gin.Context) { + body := make([]byte, 128) + _, err := c.Request.Body.Read(body) + if err != nil { + c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "body too large"}) + return + } + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + body := strings.NewReader(strings.Repeat("a", 128)) + req, _ := http.NewRequest("POST", "/test", body) + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusRequestEntityTooLarge, w.Code) + }) +} diff --git a/backend/internal/api/routes/routes.go b/backend/internal/api/routes/routes.go index 9d5ba6d..692927f 100644 --- a/backend/internal/api/routes/routes.go +++ b/backend/internal/api/routes/routes.go @@ -3,28 +3,39 @@ package routes import ( "backend/internal/api/handlers" "backend/internal/api/middleware" + "backend/internal/config" + "backend/internal/health" "backend/internal/models" + "time" "github.com/gin-gonic/gin" ) -// SetupRoutes configures all the routes for our application -func SetupRoutes(router *gin.Engine, repository models.Repository) { +// SetupRoutes configures all the routes for our application. +// healthChecker is injected from main so the readiness endpoint reflects real dependency health. +// Returns the rate limiter so the caller can stop it during shutdown. +func SetupRoutes(router *gin.Engine, repository models.Repository, healthChecker *health.HealthChecker, cfg *config.Config) *handlers.RateLimiter { // Add middleware + router.Use(middleware.RequestID()) router.Use(middleware.Logger()) router.Use(middleware.Recovery()) - router.Use(middleware.CORS()) + router.Use(middleware.CORS(cfg.CORS.AllowedOrigins)) + router.Use(middleware.MaxBodySize(1 << 20)) // 1 MB default // Health check endpoints - health := router.Group("/health") + healthGroup := router.Group("/health") { - health.GET("/live", handlers.LivenessCheck) - health.GET("/ready", handlers.ReadinessCheck) - health.GET("", handlers.HealthCheck) // Keep the original health check for backward compatibility + healthGroup.GET("/live", handlers.LivenessHandler(healthChecker)) + healthGroup.GET("/ready", handlers.ReadinessHandler(healthChecker)) + healthGroup.GET("", handlers.HealthCheck) // Keep the original health check for backward compatibility } + // Rate limiter for API routes + rateLimiter := handlers.NewRateLimiter(100, time.Minute) + // API v1 routes v1 := router.Group("/api/v1") + v1.Use(rateLimiter.RateLimit()) { // Ping endpoint v1.GET("/ping", handlers.Ping) @@ -40,4 +51,6 @@ func SetupRoutes(router *gin.Engine, repository models.Repository) { items.DELETE("/:id", itemsHandler.DeleteItem) } } + + return rateLimiter } diff --git a/backend/internal/api/routes/routes_test.go b/backend/internal/api/routes/routes_test.go index e6200ed..affb6dd 100644 --- a/backend/internal/api/routes/routes_test.go +++ b/backend/internal/api/routes/routes_test.go @@ -7,6 +7,7 @@ import ( "testing" "backend/internal/api/handlers" + "backend/internal/config" "backend/internal/health" "github.com/gin-gonic/gin" @@ -22,10 +23,19 @@ func TestSetupRoutes(t *testing.T) { mockRepo := handlers.NewMockRepository() // Initialize health checker and set it as ready - health.New().SetReady(true) + healthChecker := health.New() + healthChecker.SetReady(true) + + // Create a minimal config for testing + cfg := &config.Config{ + CORS: config.CORSConfig{ + AllowedOrigins: "*", + }, + } // Setup routes - SetupRoutes(router, mockRepo) + rl := SetupRoutes(router, mockRepo, healthChecker, cfg) + defer rl.Stop() // Test cases tests := []struct { diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index bd44476..9652307 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -23,6 +23,11 @@ const ( defaultShutdownTimeout = 30 * time.Second ) +// CORSConfig holds CORS configuration +type CORSConfig struct { + AllowedOrigins string +} + // Config holds all configuration for the application // //nolint:govet // Struct field alignment has been optimized for better memory usage @@ -33,6 +38,7 @@ type Config struct { // Then string and simple field structs App AppConfig AzureTable AzureTableConfig + CORS CORSConfig Logging LogConfig } @@ -98,8 +104,14 @@ func (c *Config) Validate() error { return fmt.Errorf("app config: %w", err) } - if err := c.Database.Validate(); err != nil { - return fmt.Errorf("database config: %w", err) + if c.AzureTable.UseAzureTable { + if err := c.AzureTable.Validate(); err != nil { + return fmt.Errorf("azure table config: %w", err) + } + } else { + if err := c.Database.Validate(); err != nil { + return fmt.Errorf("database config: %w", err) + } } if err := c.Server.Validate(); err != nil { @@ -212,7 +224,7 @@ func (c *DatabaseConfig) DSN() string { b.WriteByte(')') b.WriteByte('/') b.WriteString(c.DBName) - b.WriteString("?charset=utf8mb4&parseTime=True&loc=Local") + b.WriteString("?charset=utf8mb4&parseTime=True&loc=UTC") if c.MaxOpenConns > 0 { b.WriteString("&maxAllowedPacket=0") // Let server control packet size @@ -267,6 +279,9 @@ func LoadConfig() (*Config, error) { IdleTimeout: getEnvDuration("SERVER_IDLE_TIMEOUT", defaultIdleTimeout), ShutdownTimeout: getEnvDuration("SERVER_SHUTDOWN_TIMEOUT", defaultShutdownTimeout), }, + CORS: CORSConfig{ + AllowedOrigins: getEnv("CORS_ALLOWED_ORIGINS", "*"), + }, Logging: LogConfig{ Level: getEnv("LOG_LEVEL", "info"), File: getEnv("LOG_FILE", ""), diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 84d39ad..b6ebf70 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -149,7 +149,7 @@ func TestDatabaseDSN(t *testing.T) { DBName: "testdb", } - expected := "testuser:testpass@tcp(testhost:3306)/testdb?charset=utf8mb4&parseTime=True&loc=Local" + expected := "testuser:testpass@tcp(testhost:3306)/testdb?charset=utf8mb4&parseTime=True&loc=UTC" assert.Equal(t, expected, dbConfig.DSN()) } func TestConfigValidate(t *testing.T) { diff --git a/backend/internal/database/azure/error_handling_test.go b/backend/internal/database/azure/error_handling_test.go index 6e4c129..383e76b 100644 --- a/backend/internal/database/azure/error_handling_test.go +++ b/backend/internal/database/azure/error_handling_test.go @@ -1,6 +1,7 @@ package azure_test import ( + "context" "testing" "backend/internal/database/azure" @@ -121,17 +122,18 @@ func TestTableRepository_InvalidData(t *testing.T) { // Create a minimal repository for testing repo := azure.NewTestTableRepository("testtable") + ctx := context.Background() // Test invalid inputs for different operations t.Run("Invalid inputs for Create", func(t *testing.T) { t.Parallel() // Test with nil - err := repo.Create(nil) + err := repo.Create(ctx, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "type_assertion") // Test with wrong type - err = repo.Create("string") + err = repo.Create(ctx, "string") assert.Error(t, err) assert.Contains(t, err.Error(), "entity must be *models.Item") }) @@ -139,12 +141,12 @@ func TestTableRepository_InvalidData(t *testing.T) { t.Run("Invalid inputs for FindByID", func(t *testing.T) { t.Parallel() // Test with nil - err := repo.FindByID(1, nil) + err := repo.FindByID(ctx, 1, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "type_assertion") // Test with wrong type - err = repo.FindByID(1, "string") + err = repo.FindByID(ctx, 1, "string") assert.Error(t, err) assert.Contains(t, err.Error(), "dest must be *models.Item") }) @@ -152,12 +154,12 @@ func TestTableRepository_InvalidData(t *testing.T) { t.Run("Invalid inputs for Update", func(t *testing.T) { t.Parallel() // Test with nil - err := repo.Update(nil) + err := repo.Update(ctx, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "type_assertion") // Test with wrong type - err = repo.Update("string") + err = repo.Update(ctx, "string") assert.Error(t, err) assert.Contains(t, err.Error(), "entity must be *models.Item") }) @@ -165,12 +167,12 @@ func TestTableRepository_InvalidData(t *testing.T) { t.Run("Invalid inputs for Delete", func(t *testing.T) { t.Parallel() // Test with nil - err := repo.Delete(nil) + err := repo.Delete(ctx, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "type_assertion") // Test with wrong type - err = repo.Delete("string") + err = repo.Delete(ctx, "string") assert.Error(t, err) assert.Contains(t, err.Error(), "entity must be *models.Item") }) @@ -178,12 +180,12 @@ func TestTableRepository_InvalidData(t *testing.T) { t.Run("Invalid inputs for List", func(t *testing.T) { t.Parallel() // Test with nil - err := repo.List(nil) + err := repo.List(ctx, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "type_assertion") // Test with wrong type - err = repo.List("string") + err = repo.List(ctx, "string") assert.Error(t, err) assert.Contains(t, err.Error(), "dest must be *[]models.Item") }) diff --git a/backend/internal/database/azure/repository_test.go b/backend/internal/database/azure/repository_test.go index 4c4065d..5944814 100644 --- a/backend/internal/database/azure/repository_test.go +++ b/backend/internal/database/azure/repository_test.go @@ -194,8 +194,10 @@ func TestTableRepository_DatabaseErrors(t *testing.T) { if repo2 == nil { t.Skip("Could not create test repository (Azurite not available)") } + ctx := context.Background() + // Test Create - err := repo2.Create(&models.Item{}) + err := repo2.Create(ctx, &models.Item{}) assert.Error(t, err) var dbErr *dberrors.DatabaseError assert.True(t, errors.As(err, &dbErr)) @@ -203,21 +205,21 @@ func TestTableRepository_DatabaseErrors(t *testing.T) { assert.Equal(t, mockErr, errors.Unwrap(err)) // Test FindByID - err = repo2.FindByID(1, &models.Item{}) + err = repo2.FindByID(ctx, 1, &models.Item{}) assert.Error(t, err) assert.True(t, errors.As(err, &dbErr)) assert.Equal(t, "find", dbErr.Op) assert.Equal(t, mockErr, errors.Unwrap(err)) // Test Update - err = repo2.Update(&models.Item{}) + err = repo2.Update(ctx, &models.Item{}) assert.Error(t, err) assert.True(t, errors.As(err, &dbErr)) assert.Equal(t, "find", dbErr.Op) // First tries to find the item assert.Equal(t, mockErr, errors.Unwrap(err)) // Test Delete - err = repo2.Delete(&models.Item{}) + err = repo2.Delete(ctx, &models.Item{}) assert.Error(t, err) assert.True(t, errors.As(err, &dbErr)) assert.Equal(t, "delete", dbErr.Op) @@ -225,7 +227,7 @@ func TestTableRepository_DatabaseErrors(t *testing.T) { // Test List var items []models.Item - err = repo2.List(&items) + err = repo2.List(ctx, &items) assert.Error(t, err) assert.True(t, errors.As(err, &dbErr)) assert.Equal(t, "list", dbErr.Op) diff --git a/backend/internal/database/azure/table.go b/backend/internal/database/azure/table.go index 54572a2..5be0800 100644 --- a/backend/internal/database/azure/table.go +++ b/backend/internal/database/azure/table.go @@ -2,9 +2,11 @@ package azure import ( "context" + "crypto/rand" "encoding/json" "errors" "fmt" + "math/big" "strconv" "strings" "time" @@ -16,7 +18,18 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/data/aztables" ) -// TableRepository implements the Repository interface for Azure Table Storage +// nextID generates a collision-resistant numeric ID using a cryptographically +// secure random component. We use 48 bits of randomness to keep collision +// probability extremely low even under concurrency across multiple instances. +func nextID() (uint, error) { + max := new(big.Int).Lsh(big.NewInt(1), 48) + rb, err := rand.Int(rand.Reader, max) + if err != nil { + return 0, fmt.Errorf("failed to generate random ID component: %w", err) + } + return uint(rb.Uint64()), nil +} + // TableRepository implements the Repository interface for Azure Table Storage type TableRepository struct { client AzureTableClient @@ -90,20 +103,35 @@ func NewTestTableRepository(tableName string) *TableRepository { } // Create implements the Repository interface -func (r *TableRepository) Create(entity interface{}) error { +func (r *TableRepository) Create(ctx context.Context, entity interface{}) error { item, ok := entity.(*models.Item) if !ok { return dberrors.NewDatabaseError("type_assertion", errors.New("entity must be *models.Item")) } + // Initialize version for new entities (consistent with GORM default:1 and MockRepository) + if item.Version == 0 { + item.Version = 1 + } + // Create Azure Table entity now := time.Now().UTC() + // Generate a numeric ID (Azure Table Storage has no auto-increment) + if item.ID == 0 { + id, err := nextID() + if err != nil { + return dberrors.NewDatabaseError("create", err) + } + item.ID = id + } + entityJSON := map[string]interface{}{ "PartitionKey": "items", - "RowKey": item.Name, // Using Name as the unique key for testing + "RowKey": strconv.FormatUint(uint64(item.ID), 10), "Name": item.Name, "Price": item.Price, + "Version": item.Version, "CreatedAt": now.Format(time.RFC3339), "UpdatedAt": now.Format(time.RFC3339), } @@ -113,8 +141,7 @@ func (r *TableRepository) Create(entity interface{}) error { return dberrors.NewDatabaseError("marshal", err) } - // Create the entity - _, err = r.client.AddEntity(context.Background(), entityBytes, nil) + _, err = r.client.AddEntity(ctx, entityBytes, nil) if err != nil { var respErr *azcore.ResponseError if errors.As(err, &respErr) && respErr.ErrorCode == "EntityAlreadyExists" { @@ -129,14 +156,14 @@ func (r *TableRepository) Create(entity interface{}) error { } // FindByID implements the Repository interface -func (r *TableRepository) FindByID(id uint, dest interface{}) error { +func (r *TableRepository) FindByID(ctx context.Context, id uint, dest interface{}) error { item, ok := dest.(*models.Item) if !ok { return dberrors.NewDatabaseError("type_assertion", fmt.Errorf("dest must be *models.Item")) } // Get the entity - result, err := r.client.GetEntity(context.Background(), "items", strconv.FormatUint(uint64(id), 10), nil) + result, err := r.client.GetEntity(ctx, "items", strconv.FormatUint(uint64(id), 10), nil) if err != nil { var respErr *azcore.ResponseError if errors.As(err, &respErr) && respErr.StatusCode == 404 { @@ -153,16 +180,43 @@ func (r *TableRepository) FindByID(id uint, dest interface{}) error { // Map entity to item item.ID = id - item.Name = entityData["Name"].(string) - item.Price = entityData["Price"].(float64) - createdAt, err := time.Parse(time.RFC3339, entityData["CreatedAt"].(string)) + name, ok := entityData["Name"].(string) + if !ok { + return dberrors.NewDatabaseError("unmarshal", fmt.Errorf("missing or invalid Name field")) + } + item.Name = name + + price, ok := entityData["Price"].(float64) + if !ok { + return dberrors.NewDatabaseError("unmarshal", fmt.Errorf("missing or invalid Price field")) + } + item.Price = price + + // Default to version 1 when the Version field is missing or invalid to + // keep optimistic-lock semantics consistent with the GORM repository. + item.Version = 1 + if v, ok := entityData["Version"]; ok { + if vf, ok := v.(float64); ok && vf > 0 { + item.Version = uint(vf) + } + } + + createdAtStr, ok := entityData["CreatedAt"].(string) + if !ok { + return dberrors.NewDatabaseError("unmarshal", fmt.Errorf("missing or invalid CreatedAt field")) + } + createdAt, err := time.Parse(time.RFC3339, createdAtStr) if err != nil { return dberrors.NewDatabaseError("parse_time", err) } item.CreatedAt = createdAt - updatedAt, err := time.Parse(time.RFC3339, entityData["UpdatedAt"].(string)) + updatedAtStr, ok := entityData["UpdatedAt"].(string) + if !ok { + return dberrors.NewDatabaseError("unmarshal", fmt.Errorf("missing or invalid UpdatedAt field")) + } + updatedAt, err := time.Parse(time.RFC3339, updatedAtStr) if err != nil { return dberrors.NewDatabaseError("parse_time", err) } @@ -171,15 +225,22 @@ func (r *TableRepository) FindByID(id uint, dest interface{}) error { return nil } -// Update implements the Repository interface -func (r *TableRepository) Update(entity interface{}) error { +// Update implements the Repository interface. +// For models that implement Versionable (e.g. Item), optimistic locking is enforced: +// the entity is fetched first to compare versions, and the ETag from the GET response +// is passed to UpdateEntity so Azure Table Storage rejects stale writes. +func (r *TableRepository) Update(ctx context.Context, entity interface{}) error { item, ok := entity.(*models.Item) if !ok { return dberrors.NewDatabaseError("type_assertion", fmt.Errorf("entity must be *models.Item")) } - // Check if entity exists first - _, err := r.client.GetEntity(context.Background(), "items", strconv.FormatUint(uint64(item.ID), 10), nil) + if item.ID == 0 { + return dberrors.NewDatabaseError("update", dberrors.ErrValidation) + } + + // Fetch existing entity (also validates existence) + existing, err := r.client.GetEntity(ctx, "items", strconv.FormatUint(uint64(item.ID), 10), nil) if err != nil { var respErr *azcore.ResponseError if errors.As(err, &respErr) && respErr.StatusCode == 404 { @@ -188,14 +249,64 @@ func (r *TableRepository) Update(entity interface{}) error { return dberrors.NewDatabaseError("find", err) } + // Parse existing entity once for version checking and CreatedAt preservation. + var existingData map[string]interface{} + if err := json.Unmarshal(existing.Value, &existingData); err != nil { + return dberrors.NewDatabaseError("unmarshal", err) + } + + // Optimistic locking: compare version if the entity is Versionable + if ver, ok := entity.(models.Versionable); ok { + currentVersion := ver.GetVersion() + + // Default stored version to 1 for legacy rows that predate versioning, + // consistent with the model default and FindByID behavior. + storedVersionUint := uint(1) + + if storedVersion, ok := existingData["Version"]; ok { + switch v := storedVersion.(type) { + case float64: + // Only treat strictly positive values as valid stored versions. + // Non-positive values leave the default of 1, matching FindByID. + if v > 0 { + storedVersionUint = uint(v) + } + case json.Number: + if n, err := v.Int64(); err == nil && n > 0 { + storedVersionUint = uint(n) + } + } + } + + if currentVersion != storedVersionUint { + return dberrors.NewDatabaseError("update", errors.New("version mismatch")) + } + + // Increment version for the update + ver.SetVersion(currentVersion + 1) + } + // Create Azure Table entity now := time.Now().UTC() + + // Preserve the stored CreatedAt so callers that skip FindByID before + // updating don't accidentally clobber it with a zero time. + createdAt := item.CreatedAt + if createdAt.IsZero() { + if caStr, ok := existingData["CreatedAt"].(string); ok { + if parsed, parseErr := time.Parse(time.RFC3339, caStr); parseErr == nil { + createdAt = parsed + } + } + } + entityJson := map[string]interface{}{ "PartitionKey": "items", "RowKey": strconv.FormatUint(uint64(item.ID), 10), "Name": item.Name, "Price": item.Price, - "CreatedAt": item.CreatedAt.Format(time.RFC3339), + "Version": item.Version, + "CreatedAt": createdAt.Format(time.RFC3339), "UpdatedAt": now.Format(time.RFC3339), } @@ -204,9 +315,21 @@ func (r *TableRepository) Update(entity interface{}) error { return dberrors.NewDatabaseError("marshal", err) } - // Update the entity - _, err = r.client.UpdateEntity(context.Background(), entityBytes, nil) + // Use ETag from the GET response for conditional update + updateOpts := &aztables.UpdateEntityOptions{ + IfMatch: &existing.ETag, + UpdateMode: aztables.UpdateModeMerge, + } + _, err = r.client.UpdateEntity(ctx, entityBytes, updateOpts) if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == 412 { + // Precondition failed — concurrent modification + if ver, ok := entity.(models.Versionable); ok { + ver.SetVersion(ver.GetVersion() - 1) // Roll back + } + return dberrors.NewDatabaseError("update", errors.New("version mismatch")) + } return dberrors.NewDatabaseError("update", err) } @@ -215,13 +338,17 @@ func (r *TableRepository) Update(entity interface{}) error { } // Delete implements the Repository interface -func (r *TableRepository) Delete(entity interface{}) error { +func (r *TableRepository) Delete(ctx context.Context, entity interface{}) error { item, ok := entity.(*models.Item) if !ok { return dberrors.NewDatabaseError("type_assertion", fmt.Errorf("entity must be *models.Item")) } - _, err := r.client.DeleteEntity(context.Background(), "items", strconv.FormatUint(uint64(item.ID), 10), nil) + if item.ID == 0 { + return dberrors.NewDatabaseError("delete", dberrors.ErrValidation) + } + + _, err := r.client.DeleteEntity(ctx, "items", strconv.FormatUint(uint64(item.ID), 10), nil) if err != nil { var respErr *azcore.ResponseError if errors.As(err, &respErr) && respErr.StatusCode == 404 { @@ -234,7 +361,7 @@ func (r *TableRepository) Delete(entity interface{}) error { } // List implements the Repository interface -func (r *TableRepository) List(dest interface{}, conditions ...interface{}) error { +func (r *TableRepository) List(ctx context.Context, dest interface{}, conditions ...interface{}) error { items, ok := dest.(*[]models.Item) if !ok { return dberrors.NewDatabaseError("type_assertion", fmt.Errorf("dest must be *[]models.Item")) @@ -285,7 +412,7 @@ func (r *TableRepository) List(dest interface{}, conditions ...interface{}) erro // Fetch and process all entities for pager.More() { - response, err := pager.NextPage(context.Background()) + response, err := pager.NextPage(ctx) if err != nil { return dberrors.NewDatabaseError("list", err) } @@ -296,9 +423,42 @@ func (r *TableRepository) List(dest interface{}, conditions ...interface{}) erro return dberrors.NewDatabaseError("unmarshal", err) } - id, _ := strconv.ParseUint(entityData["RowKey"].(string), 10, 32) - createdAt, _ := time.Parse(time.RFC3339, entityData["CreatedAt"].(string)) - updatedAt, _ := time.Parse(time.RFC3339, entityData["UpdatedAt"].(string)) + rowKey, ok := entityData["RowKey"].(string) + if !ok || rowKey == "" { + return dberrors.NewDatabaseError("list", fmt.Errorf("entity missing or invalid RowKey")) + } + id, err := strconv.ParseUint(rowKey, 10, 64) + if err != nil { + return dberrors.NewDatabaseError("list", fmt.Errorf("invalid RowKey %q: %w", rowKey, err)) + } + + createdAtStr, ok := entityData["CreatedAt"].(string) + if !ok || createdAtStr == "" { + return dberrors.NewDatabaseError("list", fmt.Errorf("entity missing or invalid CreatedAt")) + } + createdAt, err := time.Parse(time.RFC3339, createdAtStr) + if err != nil { + return dberrors.NewDatabaseError("list", fmt.Errorf("invalid CreatedAt %q: %w", createdAtStr, err)) + } + + updatedAtStr, ok := entityData["UpdatedAt"].(string) + if !ok || updatedAtStr == "" { + return dberrors.NewDatabaseError("list", fmt.Errorf("entity missing or invalid UpdatedAt")) + } + updatedAt, err := time.Parse(time.RFC3339, updatedAtStr) + if err != nil { + return dberrors.NewDatabaseError("list", fmt.Errorf("invalid UpdatedAt %q: %w", updatedAtStr, err)) + } + + name, ok := entityData["Name"].(string) + if !ok { + return dberrors.NewDatabaseError("list", fmt.Errorf("entity missing or invalid Name")) + } + + price, ok := entityData["Price"].(float64) + if !ok { + return dberrors.NewDatabaseError("list", fmt.Errorf("entity missing or invalid Price")) + } item := models.Item{ Base: models.Base{ @@ -306,8 +466,19 @@ func (r *TableRepository) List(dest interface{}, conditions ...interface{}) erro CreatedAt: createdAt, UpdatedAt: updatedAt, }, - Name: entityData["Name"].(string), - Price: entityData["Price"].(float64), + Name: name, + Price: price, + } + + // Populate Version if present; default to 1 when absent or invalid + if v, ok := entityData["Version"]; ok { + if vf, ok := v.(float64); ok { + item.Version = uint(vf) + } else { + item.Version = 1 + } + } else { + item.Version = 1 } // Apply name contains filter if specified @@ -341,16 +512,22 @@ func (r *TableRepository) List(dest interface{}, conditions ...interface{}) erro } // Ping implements the Repository interface -func (r *TableRepository) Ping() error { +func (r *TableRepository) Ping(ctx context.Context) error { // List tables to check connectivity pager := r.client.NewListEntitiesPager(nil) - _, err := pager.NextPage(context.Background()) + _, err := pager.NextPage(ctx) if err != nil { return dberrors.NewDatabaseError("ping", err) } return nil } +// Close implements the Repository interface. Azure Table Storage uses HTTP +// clients that don't require explicit cleanup, so this is a no-op. +func (r *TableRepository) Close() error { + return nil +} + // Helper functions for error handling // IsTableExistsError checks if the error is a TableAlreadyExists error diff --git a/backend/internal/database/azure/table_test.go b/backend/internal/database/azure/table_test.go index 885d54a..3d6bfb0 100644 --- a/backend/internal/database/azure/table_test.go +++ b/backend/internal/database/azure/table_test.go @@ -92,11 +92,12 @@ func TestTableClientOperations(t *testing.T) { Name: "test", Price: 10.5, } - err := repo.Create(item) + err := repo.Create(context.Background(), item) assert.NoError(t, err) + assert.NotZero(t, item.ID, "Create should assign an ID") var retrieved models.Item - err = repo.FindByID(1, &retrieved) + err = repo.FindByID(context.Background(), item.ID, &retrieved) assert.NoError(t, err) assert.Equal(t, item.Name, retrieved.Name) assert.InDelta(t, item.Price, retrieved.Price, 0.001) @@ -121,7 +122,7 @@ func TestTableClientOperations(t *testing.T) { repo.SetTestClient(mockClient) var items []models.Item - err := repo.List(&items) + err := repo.List(context.Background(), &items) assert.NoError(t, err) assert.Len(t, items, 3) assert.Equal(t, "item1", items[0].Name) @@ -135,10 +136,14 @@ func TestTableClientOperations(t *testing.T) { mockClient := &mockClient{ getEntity: func(ctx context.Context, partitionKey, rowKey string, options *aztables.GetEntityOptions) (aztables.GetEntityResponse, error) { return aztables.GetEntityResponse{ - Value: []byte(`{"Name":"test","Price":10.5}`), + ETag: "etag-1", + Value: []byte(`{"Name":"test","Price":10.5,"Version":1}`), }, nil }, updateEntity: func(ctx context.Context, entity []byte, options *aztables.UpdateEntityOptions) (aztables.UpdateEntityResponse, error) { + // Verify ETag is passed for conditional update + assert.NotNil(t, options) + assert.NotNil(t, options.IfMatch) return aztables.UpdateEntityResponse{}, nil }, } @@ -147,11 +152,41 @@ func TestTableClientOperations(t *testing.T) { repo.SetTestClient(mockClient) item := &models.Item{ - Name: "test", - Price: 10.5, + Name: "test", + Price: 10.5, + Version: 1, } - err := repo.Update(item) + item.ID = 1 + err := repo.Update(context.Background(), item) assert.NoError(t, err) + assert.Equal(t, uint(2), item.Version, "Version should be incremented after update") + }) + + t.Run("update fails on version mismatch", func(t *testing.T) { + t.Parallel() + + mockClient := &mockClient{ + getEntity: func(ctx context.Context, partitionKey, rowKey string, options *aztables.GetEntityOptions) (aztables.GetEntityResponse, error) { + return aztables.GetEntityResponse{ + ETag: "etag-1", + Value: []byte(`{"Name":"test","Price":10.5,"Version":2}`), + }, nil + }, + } + + repo := azure.NewTestTableRepository("testtable") + repo.SetTestClient(mockClient) + + item := &models.Item{ + Name: "test", + Price: 15.0, + Version: 1, // Stale version + } + item.ID = 1 + err := repo.Update(context.Background(), item) + assert.Error(t, err) + assert.Contains(t, err.Error(), "version mismatch") + assert.Equal(t, uint(1), item.Version, "Version should not change on mismatch") }) t.Run("can delete an entity", func(t *testing.T) { @@ -170,7 +205,8 @@ func TestTableClientOperations(t *testing.T) { Name: "test", Price: 10.5, } - err := repo.Delete(item) + item.ID = 1 + err := repo.Delete(context.Background(), item) assert.NoError(t, err) }) } diff --git a/backend/internal/database/config.go b/backend/internal/database/config.go index 35a48d7..a73756e 100644 --- a/backend/internal/database/config.go +++ b/backend/internal/database/config.go @@ -24,7 +24,7 @@ func NewConfig() *Config { // DSN returns the database connection string func (c *Config) DSN() string { - return c.User + ":" + c.Password + "@tcp(" + c.Host + ":" + c.Port + ")/" + c.DBName + "?charset=utf8mb4&parseTime=True&loc=Local" + return c.User + ":" + c.Password + "@tcp(" + c.Host + ":" + c.Port + ")/" + c.DBName + "?charset=utf8mb4&parseTime=True&loc=UTC" } // Helper function to get environment variables with default values diff --git a/backend/internal/database/config_test.go b/backend/internal/database/config_test.go index fd93211..ad54d82 100644 --- a/backend/internal/database/config_test.go +++ b/backend/internal/database/config_test.go @@ -59,6 +59,6 @@ func TestConfigDSN(t *testing.T) { DBName: "testdb", } - expected := "testuser:testpass@tcp(testhost:3306)/testdb?charset=utf8mb4&parseTime=True&loc=Local" + expected := "testuser:testpass@tcp(testhost:3306)/testdb?charset=utf8mb4&parseTime=True&loc=UTC" assert.Equal(t, expected, config.DSN()) } diff --git a/backend/internal/database/database.go b/backend/internal/database/database.go index 9817e57..386ef84 100644 --- a/backend/internal/database/database.go +++ b/backend/internal/database/database.go @@ -1,7 +1,7 @@ package database import ( - "log" + "log/slog" "strings" "gorm.io/driver/mysql" @@ -9,13 +9,16 @@ import ( "gorm.io/gorm/logger" ) +// Database wraps a gorm.DB instance with additional utilities. type Database struct { *gorm.DB } -func NewDatabase(dsn string, logger logger.Interface) (*Database, error) { +// NewDatabase creates a new database connection. Callers should configure +// connection pool settings via the returned *sql.DB if needed. +func NewDatabase(dsn string, logCfg logger.Interface) (*Database, error) { db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ - Logger: logger, + Logger: logCfg, }) if err != nil { if strings.Contains(err.Error(), "connection refused") { @@ -24,21 +27,11 @@ func NewDatabase(dsn string, logger logger.Interface) (*Database, error) { return nil, NewDatabaseError("connect", err) } - // Configure connection pool - sqlDB, err := db.DB() - if err != nil { - return nil, NewDatabaseError("configure", err) - } - - // Set reasonable defaults for the connection pool - sqlDB.SetMaxIdleConns(5) - sqlDB.SetMaxOpenConns(20) - - log.Println("Connected to database successfully") + slog.Info("Connected to database successfully") return &Database{DB: db}, nil } -// Transaction executes operations within a database transaction +// Transaction executes operations within a database transaction. func (d *Database) Transaction(fn func(tx *gorm.DB) error) error { err := d.DB.Transaction(fn) if err != nil { @@ -47,79 +40,11 @@ func (d *Database) Transaction(fn func(tx *gorm.DB) error) error { return nil } -// HandleError translates GORM/MySQL errors into our custom error types -func (d *Database) HandleError(op string, err error) error { - if err == nil { - return nil - } - - if err == gorm.ErrRecordNotFound { - return NewDatabaseError(op, ErrNotFound) - } - - // Check for duplicate key violations - if strings.Contains(err.Error(), "Duplicate entry") { - return NewDatabaseError(op, ErrDuplicateKey) - } - - // Handle validation errors - if strings.Contains(err.Error(), "validation failed") { - return NewDatabaseError(op, ErrValidation) - } - - return NewDatabaseError(op, err) -} - -// Create implements models.Repository -func (d *Database) Create(value interface{}) error { - result := d.DB.Create(value) - if result.Error != nil { - return d.HandleError("create", result.Error) - } - return nil -} - -// FindByID implements models.Repository -func (d *Database) FindByID(id uint, dest interface{}) error { - result := d.DB.First(dest, id) - if result.Error != nil { - return d.HandleError("find", result.Error) - } - return nil -} - -// Update implements models.Repository -func (d *Database) Update(value interface{}) error { - result := d.DB.Save(value) - if result.Error != nil { - return d.HandleError("update", result.Error) - } - return nil -} - -// Delete implements models.Repository -func (d *Database) Delete(value interface{}) error { - result := d.DB.Delete(value) - if result.Error != nil { - return d.HandleError("delete", result.Error) - } - return nil -} - -// List implements models.Repository -func (d *Database) List(dest interface{}, conditions ...interface{}) error { - result := d.DB.Find(dest) - if result.Error != nil { - return d.HandleError("list", result.Error) - } - return nil -} - -// Ping implements models.Repository +// Ping checks if the database connection is alive. func (d *Database) Ping() error { sqlDB, err := d.DB.DB() if err != nil { - return d.HandleError("ping", err) + return err } return sqlDB.Ping() } diff --git a/backend/internal/database/database_test.go b/backend/internal/database/database_test.go index 0d45860..dfb6a44 100644 --- a/backend/internal/database/database_test.go +++ b/backend/internal/database/database_test.go @@ -1,6 +1,7 @@ package database import ( + "context" "testing" "backend/internal/models" @@ -115,11 +116,14 @@ func TestItemCRUD(t *testing.T) { t.Parallel() db := setupTestDB(t) require.NoError(t, db.AutoMigrate()) + repo := models.NewRepository(db.DB) + ctx := context.Background() + item := &models.Item{ Name: "Test Item", Price: 99.99, } - err := db.Create(item) + err := repo.Create(ctx, item) assert.NoError(t, err) assert.NotZero(t, item.ID) }) @@ -128,16 +132,18 @@ func TestItemCRUD(t *testing.T) { t.Parallel() db := setupTestDB(t) require.NoError(t, db.AutoMigrate()) + repo := models.NewRepository(db.DB) + ctx := context.Background() // Create item first initialItem := &models.Item{ Name: "Test Item", Price: 99.99, } - require.NoError(t, db.Create(initialItem)) + require.NoError(t, repo.Create(ctx, initialItem)) var item models.Item - err := db.FindByID(1, &item) + err := repo.FindByID(ctx, initialItem.ID, &item) assert.NoError(t, err) assert.Equal(t, "Test Item", item.Name) assert.Equal(t, 99.99, item.Price) @@ -147,23 +153,26 @@ func TestItemCRUD(t *testing.T) { t.Parallel() db := setupTestDB(t) require.NoError(t, db.AutoMigrate()) + repo := models.NewRepository(db.DB) + ctx := context.Background() // Create item first initialItem := &models.Item{ Name: "Test Item", Price: 99.99, } - require.NoError(t, db.Create(initialItem)) + require.NoError(t, repo.Create(ctx, initialItem)) var item models.Item - err := db.FindByID(1, &item) + err := repo.FindByID(ctx, initialItem.ID, &item) require.NoError(t, err) item.Price = 199.99 - err = db.Update(&item) + err = repo.Update(ctx, &item) assert.NoError(t, err) var updatedItem models.Item - err = db.FindByID(1, &updatedItem) + err = repo.FindByID(ctx, initialItem.ID, &updatedItem) + assert.NoError(t, err) assert.Equal(t, 199.99, updatedItem.Price) }) @@ -171,20 +180,64 @@ func TestItemCRUD(t *testing.T) { t.Parallel() db := setupTestDB(t) require.NoError(t, db.AutoMigrate()) + repo := models.NewRepository(db.DB) + ctx := context.Background() // Create item first initialItem := &models.Item{ Name: "Test Item", Price: 99.99, } - require.NoError(t, db.Create(initialItem)) + require.NoError(t, repo.Create(ctx, initialItem)) - item := &models.Item{Base: models.Base{ID: 1}} - err := db.Delete(item) + item := &models.Item{Base: models.Base{ID: initialItem.ID}} + err := repo.Delete(ctx, item) assert.NoError(t, err) var deleted models.Item - err = db.FindByID(1, &deleted) + err = repo.FindByID(ctx, initialItem.ID, &deleted) assert.Error(t, err, "Should not find deleted item") }) + + t.Run("Optimistic Locking - version mismatch", func(t *testing.T) { + t.Parallel() + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate()) + repo := models.NewRepository(db.DB) + ctx := context.Background() + + // Create an item (version starts at 1 after create). + item := &models.Item{ + Name: "Versioned Item", + Price: 10.00, + } + require.NoError(t, repo.Create(ctx, item)) + assert.Equal(t, uint(1), item.Version) + + // Simulate two concurrent reads. + var copy1, copy2 models.Item + require.NoError(t, repo.FindByID(ctx, item.ID, ©1)) + require.NoError(t, repo.FindByID(ctx, item.ID, ©2)) + + // First update succeeds: version 1 → 2. + copy1.Price = 20.00 + require.NoError(t, repo.Update(ctx, ©1)) + assert.Equal(t, uint(2), copy1.Version) + + // Second update uses stale version 1 — should fail. + copy2.Price = 30.00 + err := repo.Update(ctx, ©2) + assert.Error(t, err, "Update with stale version should fail") + assert.Contains(t, err.Error(), "version mismatch") + + // The in-memory version should be rolled back. + assert.Equal(t, uint(1), copy2.Version, "Version should be rolled back on mismatch") + + // Database should still have the first update's data. + var current models.Item + require.NoError(t, repo.FindByID(ctx, item.ID, ¤t)) + assert.Equal(t, 20.00, current.Price, "Price should reflect the first successful update") + assert.Equal(t, uint(2), current.Version, "Version in DB should be 2") + }) + } diff --git a/backend/internal/database/errors.go b/backend/internal/database/errors.go index bb3e2e4..18b0268 100644 --- a/backend/internal/database/errors.go +++ b/backend/internal/database/errors.go @@ -1,39 +1,16 @@ +// Package database provides database connectivity, migrations, and repository factories. +// Error types are re-exported from pkg/dberrors to maintain a single source of truth. package database -import ( - "errors" - "fmt" -) +import "backend/pkg/dberrors" + +// Re-export error types from pkg/dberrors for backward compatibility. +type DatabaseError = dberrors.DatabaseError -// Common database errors var ( - ErrNotFound = errors.New("record not found") - ErrDuplicateKey = errors.New("duplicate key violation") - ErrValidation = errors.New("validation error") - ErrConnectionFailed = errors.New("database connection failed") + ErrNotFound = dberrors.ErrNotFound + ErrDuplicateKey = dberrors.ErrDuplicateKey + ErrValidation = dberrors.ErrValidation + ErrConnectionFailed = dberrors.ErrConnectionFailed + NewDatabaseError = dberrors.NewDatabaseError ) - -// DatabaseError wraps database-specific errors with additional context -type DatabaseError struct { - Op string - Err error -} - -func (e *DatabaseError) Error() string { - if e.Op != "" { - return fmt.Sprintf("%s: %v", e.Op, e.Err) - } - return e.Err.Error() -} - -func (e *DatabaseError) Unwrap() error { - return e.Err -} - -// NewDatabaseError creates a new database error with operation context -func NewDatabaseError(op string, err error) error { - if err == nil { - return nil - } - return &DatabaseError{Op: op, Err: err} -} diff --git a/backend/internal/database/factory.go b/backend/internal/database/factory.go index 399e3e3..9a6047d 100644 --- a/backend/internal/database/factory.go +++ b/backend/internal/database/factory.go @@ -3,7 +3,7 @@ package database import ( "context" "fmt" - "log" + "log/slog" "time" "backend/internal/config" @@ -43,8 +43,12 @@ func NewFromAppConfig(cfg *config.Config) (*Database, error) { retryCount++ if retryCount < maxRetries { - log.Printf("Failed to connect to database (attempt %d/%d): %v. Retrying in %v...", - retryCount, maxRetries, err, retryDelay) + slog.Warn("Failed to connect to database, retrying...", + "attempt", retryCount, + "maxRetries", maxRetries, + "error", err, + "retryDelay", retryDelay, + ) time.Sleep(retryDelay) } } @@ -121,8 +125,12 @@ func NewFromDBConfig(cfg *Config) (*Database, error) { retryCount++ if retryCount < maxRetries { - log.Printf("Failed to connect to database (attempt %d/%d): %v. Retrying in %v...", - retryCount, maxRetries, err, retryDelay) + slog.Warn("Failed to connect to database, retrying...", + "attempt", retryCount, + "maxRetries", maxRetries, + "error", err, + "retryDelay", retryDelay, + ) time.Sleep(retryDelay) } } diff --git a/backend/internal/database/migrations.go b/backend/internal/database/migrations.go index 4a86297..61d7521 100644 --- a/backend/internal/database/migrations.go +++ b/backend/internal/database/migrations.go @@ -1,7 +1,7 @@ package database import ( - "log" + "log/slog" "backend/internal/database/schema" "backend/internal/models" @@ -11,7 +11,7 @@ import ( // AutoMigrate runs database migrations for all models func (d *Database) AutoMigrate() error { - log.Println("Running database migrations...") + slog.Info("Running database migrations...") // Initialize migrator migrator := schema.NewMigrator(d.DB) @@ -61,11 +61,35 @@ func (d *Database) AutoMigrate() error { }, }) + // Update existing items with Version=0 to Version=1 and change column default + migrator.AddMigration(schema.Migration{ + Version: "20231201000003", + Name: "update_items_version_default", + Description: "Set Version default to 1 for optimistic locking and update existing rows", + Up: func(tx *gorm.DB) error { + // Update existing rows that still have the old default of 0 to the new default of 1 + if err := tx.Exec("UPDATE items SET version = 1 WHERE version = 0").Error; err != nil { + return err + } + // Alter column default to 1 (MySQL syntax; SQLite defaults are set via AutoMigrate) + if tx.Dialector.Name() == "mysql" { + return tx.Exec("ALTER TABLE items ALTER COLUMN version SET DEFAULT 1").Error + } + return nil + }, + Down: func(tx *gorm.DB) error { + if tx.Dialector.Name() == "mysql" { + return tx.Exec("ALTER TABLE items ALTER COLUMN version SET DEFAULT 0").Error + } + return nil + }, + }) + // Run migrations if err := migrator.MigrateUp(); err != nil { return err } - log.Println("Database migrations completed successfully") + slog.Info("Database migrations completed successfully") return nil } diff --git a/backend/internal/database/repository.go b/backend/internal/database/repository.go index e1b4bbb..ec7f33f 100644 --- a/backend/internal/database/repository.go +++ b/backend/internal/database/repository.go @@ -2,7 +2,7 @@ package database import ( "fmt" - "log" + "log/slog" "backend/internal/config" "backend/internal/database/azure" @@ -12,7 +12,7 @@ import ( // NewRepository creates a new repository based on the configuration func NewRepository(cfg *config.Config) (models.Repository, error) { if cfg.AzureTable.UseAzureTable { - log.Println("Using Azure Table Storage as repository") + slog.Info("Using Azure Table Storage as repository") return azure.NewTableRepository( cfg.AzureTable.AccountName, cfg.AzureTable.AccountKey, @@ -22,11 +22,20 @@ func NewRepository(cfg *config.Config) (models.Repository, error) { ) } - log.Println("Using MySQL as repository") + slog.Info("Using MySQL as repository") db, err := NewFromAppConfig(cfg) if err != nil { return nil, fmt.Errorf("failed to initialize MySQL database: %w", err) } + // Run database migrations (migrator tracks applied versions; safe on every startup) + if err := db.AutoMigrate(); err != nil { + // Clean up the database connection to avoid resource leaks + if sqlDB, dbErr := db.DB.DB(); dbErr == nil { + _ = sqlDB.Close() + } + return nil, fmt.Errorf("failed to run database migrations: %w", err) + } + return models.NewRepository(db.DB), nil } diff --git a/backend/internal/health/health.go b/backend/internal/health/health.go index bc72c46..a9f11af 100644 --- a/backend/internal/health/health.go +++ b/backend/internal/health/health.go @@ -1,6 +1,7 @@ package health import ( + "context" "sync" "time" ) @@ -12,7 +13,9 @@ type HealthChecker struct { dependencies map[string]HealthCheck } -type HealthCheck func() error +// HealthCheck is a function that checks a dependency's health. +// It receives a context for cancellation and timeout support. +type HealthCheck func(ctx context.Context) error type HealthStatus struct { Status string `json:"status"` @@ -46,7 +49,10 @@ func (h *HealthChecker) SetReady(ready bool) { h.isReady = ready } -func (h *HealthChecker) CheckLiveness() HealthStatus { +// CheckLiveness returns a simple UP status. The context parameter is unused +// because liveness is intentionally trivial — it only reports uptime and +// does not call any external dependencies that would need cancellation. +func (h *HealthChecker) CheckLiveness(_ context.Context) HealthStatus { uptime := time.Since(h.startTime).String() return HealthStatus{ Status: "UP", @@ -54,11 +60,18 @@ func (h *HealthChecker) CheckLiveness() HealthStatus { } } -func (h *HealthChecker) CheckReadiness() HealthStatus { +func (h *HealthChecker) CheckReadiness(ctx context.Context) HealthStatus { + // Copy state under lock, then release before running I/O-bound checks + // to avoid holding the mutex during potentially slow dependency calls. h.mu.RLock() - defer h.mu.RUnlock() + ready := h.isReady + deps := make(map[string]HealthCheck, len(h.dependencies)) + for k, v := range h.dependencies { + deps[k] = v + } + h.mu.RUnlock() - if !h.isReady { + if !ready { return HealthStatus{ Status: "DOWN", Checks: map[string]CheckStatus{ @@ -72,8 +85,8 @@ func (h *HealthChecker) CheckReadiness() HealthStatus { Checks: make(map[string]CheckStatus), } - for name, check := range h.dependencies { - if err := check(); err != nil { + for name, check := range deps { + if err := check(ctx); err != nil { status.Status = "DOWN" status.Checks[name] = CheckStatus{ Status: "DOWN", diff --git a/backend/internal/health/health_test.go b/backend/internal/health/health_test.go index a4ee4c2..138c652 100644 --- a/backend/internal/health/health_test.go +++ b/backend/internal/health/health_test.go @@ -1,6 +1,7 @@ package health import ( + "context" "errors" "testing" ) @@ -16,7 +17,7 @@ func TestNew(t *testing.T) { func TestLivenessCheck(t *testing.T) { t.Parallel() h := New() - status := h.CheckLiveness() + status := h.CheckLiveness(context.Background()) if status.Status != "UP" { t.Errorf("Expected status UP, got %s", status.Status) @@ -33,7 +34,7 @@ func TestReadinessCheck(t *testing.T) { t.Run("Service not ready", func(t *testing.T) { t.Parallel() h := New() // Create a fresh instance for this subtest - status := h.CheckReadiness() + status := h.CheckReadiness(context.Background()) if status.Status != "DOWN" { t.Errorf("Expected status DOWN, got %s", status.Status) } @@ -43,7 +44,7 @@ func TestReadinessCheck(t *testing.T) { t.Parallel() h := New() // Create a fresh instance for this subtest h.SetReady(true) - status := h.CheckReadiness() + status := h.CheckReadiness(context.Background()) if status.Status != "UP" { t.Errorf("Expected status UP, got %s", status.Status) } @@ -57,8 +58,8 @@ func TestHealthChecks(t *testing.T) { t.Parallel() h := New() // Create a fresh instance for this subtest h.SetReady(true) - h.AddCheck("test", func() error { return nil }) - status := h.CheckReadiness() + h.AddCheck("test", func(_ context.Context) error { return nil }) + status := h.CheckReadiness(context.Background()) if status.Status != "UP" { t.Errorf("Expected status UP, got %s", status.Status) } @@ -71,8 +72,8 @@ func TestHealthChecks(t *testing.T) { t.Parallel() h := New() // Create a fresh instance for this subtest h.SetReady(true) - h.AddCheck("failing", func() error { return errors.New("test error") }) - status := h.CheckReadiness() + h.AddCheck("failing", func(_ context.Context) error { return errors.New("test error") }) + status := h.CheckReadiness(context.Background()) if status.Status != "DOWN" { t.Errorf("Expected status DOWN, got %s", status.Status) } diff --git a/backend/internal/models/models.go b/backend/internal/models/models.go index 9ee28e3..02a520b 100644 --- a/backend/internal/models/models.go +++ b/backend/internal/models/models.go @@ -1,6 +1,9 @@ package models import ( + "context" + "errors" + "fmt" "strings" "time" @@ -22,7 +25,7 @@ type Item struct { Base Name string `gorm:"size:255;not null" json:"name"` Price float64 `json:"price"` - Version uint `gorm:"not null;default:0" json:"version"` // For optimistic locking + Version uint `gorm:"not null;default:1" json:"version"` // For optimistic locking (1 = initial; 0 = not provided) } // User represents a user in the system @@ -38,77 +41,183 @@ type Validator interface { Validate() error } -// Repository defines the interface for database operations +// Versionable is an interface for models that support optimistic locking. +type Versionable interface { + GetVersion() uint + SetVersion(v uint) +} + +// GetVersion implements Versionable for Item. +func (i *Item) GetVersion() uint { return i.Version } + +// SetVersion implements Versionable for Item. +func (i *Item) SetVersion(v uint) { i.Version = v } + +// Repository defines the interface for database operations. type Repository interface { - Create(interface{}) error - FindByID(id uint, dest interface{}) error - Update(interface{}) error - Delete(interface{}) error - List(dest interface{}, conditions ...interface{}) error - Ping() error + Create(ctx context.Context, entity interface{}) error + FindByID(ctx context.Context, id uint, dest interface{}) error + Update(ctx context.Context, entity interface{}) error + Delete(ctx context.Context, entity interface{}) error + List(ctx context.Context, dest interface{}, conditions ...interface{}) error + Ping(ctx context.Context) error + Close() error } // GenericRepository implements the Repository interface type GenericRepository struct { - db *gorm.DB + db *gorm.DB + allowedFilterFields map[string]bool } -// NewRepository creates a new GenericRepository +// NewRepository creates a new GenericRepository with filter fields for the Item +// entity ("name", "price"). For other entity types, use NewRepositoryWithFilterFields. func NewRepository(db *gorm.DB) Repository { - return &GenericRepository{db: db} + return &GenericRepository{ + db: db, + allowedFilterFields: map[string]bool{ + "name": true, + "price": true, + }, + } +} + +// NewRepositoryWithFilterFields creates a GenericRepository with a custom filter field whitelist. +func NewRepositoryWithFilterFields(db *gorm.DB, fields []string) Repository { + allowed := make(map[string]bool, len(fields)) + for _, f := range fields { + allowed[f] = true + } + return &GenericRepository{db: db, allowedFilterFields: allowed} } // Ping checks if the database is reachable -func (r *GenericRepository) Ping() error { +func (r *GenericRepository) Ping(ctx context.Context) error { + sqlDB, err := r.db.DB() + if err != nil { + return err + } + return sqlDB.PingContext(ctx) +} + +// Close releases the underlying database connection pool. +func (r *GenericRepository) Close() error { sqlDB, err := r.db.DB() if err != nil { return err } - return sqlDB.Ping() + return sqlDB.Close() } -func (r *GenericRepository) Create(entity interface{}) error { +func (r *GenericRepository) Create(ctx context.Context, entity interface{}) error { if v, ok := entity.(Validator); ok { if err := v.Validate(); err != nil { - return dberrors.NewDatabaseError("validate", err) + return dberrors.NewDatabaseError("validate", + fmt.Errorf("%w: %s", dberrors.ErrValidation, err.Error())) } } - if err := r.db.Create(entity).Error; err != nil { + if err := r.db.WithContext(ctx).Create(entity).Error; err != nil { return r.handleError("create", err) } return nil } -func (r *GenericRepository) FindByID(id uint, dest interface{}) error { - if err := r.db.First(dest, id).Error; err != nil { +func (r *GenericRepository) FindByID(ctx context.Context, id uint, dest interface{}) error { + if err := r.db.WithContext(ctx).First(dest, id).Error; err != nil { return r.handleError("find", err) } return nil } -func (r *GenericRepository) Update(entity interface{}) error { +func (r *GenericRepository) Update(ctx context.Context, entity interface{}) error { if v, ok := entity.(Validator); ok { if err := v.Validate(); err != nil { - return dberrors.NewDatabaseError("validate", err) + return dberrors.NewDatabaseError("validate", + fmt.Errorf("%w: %s", dberrors.ErrValidation, err.Error())) } } - if err := r.db.Save(entity).Error; err != nil { + // Optimistic locking for Versionable entities. + // We increment the version optimistically, then issue an UPDATE with a + // WHERE version=old clause. If no rows are affected, it means another + // transaction modified this entity — we roll back the in-memory version + // and return a version-mismatch error. + // We use Model().Where().Select("*").Updates() instead of Where().Save() + // because Save() may generate an upsert (INSERT … ON CONFLICT) on some + // dialects (e.g. SQLite), which bypasses the WHERE version clause. + // Select("*") ensures all columns are written, matching Save() semantics. + if ver, ok := entity.(Versionable); ok { + currentVersion := ver.GetVersion() + ver.SetVersion(currentVersion + 1) + result := r.db.WithContext(ctx). + Model(entity). + Where("version = ?", currentVersion). + Select("*"). + Updates(entity) + if result.Error != nil { + ver.SetVersion(currentVersion) // Roll back version on error + return r.handleError("update", result.Error) + } + if result.RowsAffected == 0 { + ver.SetVersion(currentVersion) // Roll back version on mismatch + return dberrors.NewDatabaseError("update", errors.New("version mismatch")) + } + return nil + } + + if err := r.db.WithContext(ctx).Save(entity).Error; err != nil { return r.handleError("update", err) } return nil } -func (r *GenericRepository) Delete(entity interface{}) error { - if err := r.db.Delete(entity).Error; err != nil { - return r.handleError("delete", err) +func (r *GenericRepository) Delete(ctx context.Context, entity interface{}) error { + result := r.db.WithContext(ctx).Delete(entity) + if result.Error != nil { + return r.handleError("delete", result.Error) + } + if result.RowsAffected == 0 { + return dberrors.NewDatabaseError("delete", dberrors.ErrNotFound) } return nil } -func (r *GenericRepository) List(dest interface{}, conditions ...interface{}) error { - if err := r.db.Find(dest, conditions...).Error; err != nil { +func (r *GenericRepository) List(ctx context.Context, dest interface{}, conditions ...interface{}) error { + query := r.db.WithContext(ctx) + for _, cond := range conditions { + switch c := cond.(type) { + case Filter: + if !r.allowedFilterFields[c.Field] { + return dberrors.NewDatabaseError("list", + fmt.Errorf("invalid filter field: %q", c.Field)) + } + // SAFETY: c.Field is interpolated into fmt.Sprintf below, but it is + // guaranteed to be one of the whitelisted column names checked above, + // so SQL injection via field names is not possible. + switch c.Op { + case "exact": + query = query.Where(fmt.Sprintf("%s = ?", c.Field), c.Value) + case ">=": + query = query.Where(fmt.Sprintf("%s >= ?", c.Field), c.Value) + case "<=": + query = query.Where(fmt.Sprintf("%s <= ?", c.Field), c.Value) + default: + // Default to LIKE for substring matching. + // Escape SQL wildcards (% and _) so they are treated as literals. + escaped := strings.NewReplacer("%", "\\%", "_", "\\_").Replace(fmt.Sprint(c.Value)) + query = query.Where(fmt.Sprintf("%s LIKE ?", c.Field), "%"+escaped+"%") + } + case Pagination: + if c.Limit > 0 { + query = query.Limit(c.Limit) + } + if c.Offset > 0 { + query = query.Offset(c.Offset) + } + } + } + if err := query.Find(dest).Error; err != nil { return r.handleError("list", err) } return nil diff --git a/backend/internal/models/validation.go b/backend/internal/models/validation.go index 8e54247..88052e7 100644 --- a/backend/internal/models/validation.go +++ b/backend/internal/models/validation.go @@ -5,6 +5,9 @@ import ( "regexp" ) +// Compile email regex once at package level for performance. +var emailRegex = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`) + var ( ErrInvalidEmail = errors.New("invalid email format") ErrEmptyUsername = errors.New("username cannot be empty") @@ -18,7 +21,6 @@ func (u *User) Validate() error { return ErrEmptyUsername } - emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`) if !emailRegex.MatchString(u.Email) { return ErrInvalidEmail } diff --git a/backend/pkg/utils/utils.go b/backend/pkg/utils/utils.go index 279228a..9186a38 100644 --- a/backend/pkg/utils/utils.go +++ b/backend/pkg/utils/utils.go @@ -1,25 +1,30 @@ package utils import ( - "math/rand" - "time" + "crypto/rand" + "fmt" + "math/big" ) -// Utility function to check for errors and handle them appropriately -func CheckError(err error) { - if err != nil { - panic(err) // Handle error as needed +// GenerateRandomString generates a cryptographically random string of the specified length. +// Returns an error if length is negative or if the system random source fails. +func GenerateRandomString(length int) (string, error) { + if length < 0 { + return "", fmt.Errorf("length must be non-negative, got %d", length) + } + if length == 0 { + return "", nil } -} -// Function to generate a random string of a specified length -func GenerateRandomString(length int) string { const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - var seededRand = rand.New(rand.NewSource(time.Now().UnixNano())) b := make([]byte, length) for i := range b { - b[i] = charset[seededRand.Intn(len(charset))] + n, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset)))) + if err != nil { + return "", fmt.Errorf("crypto/rand failure: %w", err) + } + b[i] = charset[int(n.Int64())] } - return string(b) + return string(b), nil } diff --git a/backend/pkg/utils/utils_test.go b/backend/pkg/utils/utils_test.go index aafc0b7..b406ffe 100644 --- a/backend/pkg/utils/utils_test.go +++ b/backend/pkg/utils/utils_test.go @@ -4,13 +4,41 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUtilFunctions(t *testing.T) { - // Add your utility function tests here - t.Run("Test utility functions", func(t *testing.T) { - // Example test case - replace with actual utility function tests - result := true - assert.True(t, result, "Utility function should return true") + t.Parallel() + + t.Run("GenerateRandomString returns correct length", func(t *testing.T) { + t.Parallel() + s, err := GenerateRandomString(16) + require.NoError(t, err) + assert.Len(t, s, 16) + }) + + t.Run("GenerateRandomString with zero length", func(t *testing.T) { + t.Parallel() + s, err := GenerateRandomString(0) + require.NoError(t, err) + assert.Equal(t, "", s) + }) + + t.Run("GenerateRandomString with negative length", func(t *testing.T) { + t.Parallel() + _, err := GenerateRandomString(-1) + assert.Error(t, err) + assert.Contains(t, err.Error(), "non-negative") + }) + + t.Run("GenerateRandomString contains valid characters", func(t *testing.T) { + t.Parallel() + s, err := GenerateRandomString(100) + require.NoError(t, err) + for _, c := range s { + assert.Truef(t, + (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9'), + "unexpected character: %c", c) + } }) }