From 07db321072ec5448552cc75b83c6c3b728032e56 Mon Sep 17 00:00:00 2001 From: omattsson Date: Tue, 24 Feb 2026 16:18:45 +0100 Subject: [PATCH 01/18] refactor: improve backend code quality and security MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add context.Context propagation to Repository interface and all implementations - Replace log.Printf with structured logging (log/slog) - Fix optimistic locking: version default 0→1, add WHERE clause on update - Add CORS origin configuration support (CORS_ALLOWED_ORIGINS env var) - Add RequestID and MaxBodySize middleware - Inject health checker dependency instead of package-level init() - Remove dead code: standalone GetItems/CreateItem, unused Database methods - Use crypto/rand instead of math/rand for GenerateRandomString - Consolidate DB errors into pkg/dberrors with backward-compatible re-exports - Fix error leak: return generic message for internal server errors - Use loc=UTC in MySQL DSN instead of loc=Local - Implement proper List() filtering with models.Filter and Pagination - Add Versionable interface for optimistic locking pattern - Update all tests to match new interfaces --- backend/api/main.go | 37 +++++--- backend/api/main_test.go | 48 +++++----- backend/internal/api/handlers/handlers.go | 59 ++++-------- .../internal/api/handlers/handlers_test.go | 40 -------- backend/internal/api/handlers/items.go | 38 ++++---- backend/internal/api/handlers/items_test.go | 19 ++-- .../internal/api/handlers/mock_repository.go | 18 ++-- backend/internal/api/handlers/rate_limiter.go | 57 ++++++++--- backend/internal/api/middleware/cors_test.go | 6 +- backend/internal/api/middleware/middleware.go | 63 ++++++++++-- .../api/middleware/middleware_test.go | 13 ++- backend/internal/api/routes/routes.go | 19 ++-- backend/internal/api/routes/routes_test.go | 13 ++- backend/internal/config/config.go | 17 +++- backend/internal/config/config_test.go | 2 +- .../database/azure/error_handling_test.go | 22 +++-- .../database/azure/repository_test.go | 12 ++- backend/internal/database/azure/table.go | 28 +++--- backend/internal/database/azure/table_test.go | 10 +- backend/internal/database/config.go | 2 +- backend/internal/database/config_test.go | 2 +- backend/internal/database/database.go | 95 ++----------------- backend/internal/database/database_test.go | 36 ++++--- backend/internal/database/errors.go | 45 +++------ backend/internal/database/factory.go | 18 +++- backend/internal/database/migrations.go | 6 +- backend/internal/database/repository.go | 6 +- backend/internal/health/health.go | 11 ++- backend/internal/health/health_test.go | 15 +-- backend/internal/models/models.go | 95 +++++++++++++++---- backend/internal/models/validation.go | 4 +- backend/pkg/utils/utils.go | 19 ++-- 32 files changed, 460 insertions(+), 415 deletions(-) diff --git a/backend/api/main.go b/backend/api/main.go index c13f405..1e662bf 100644 --- a/backend/api/main.go +++ b/backend/api/main.go @@ -10,7 +10,7 @@ import ( "backend/internal/health" "context" "fmt" - "log" + "log/slog" "net/http" "os" "os/signal" @@ -38,22 +38,27 @@ func main() { // Load configuration cfg, err := config.LoadConfig() if err != nil { - log.Fatalf("Failed to load configuration: %v", err) + slog.Error("Failed to load configuration", "error", err) + os.Exit(1) } - // Initialize database - db, err := database.NewFromAppConfig(cfg) + // Initialize repository using the factory (selects MySQL or Azure Table based on config) + repo, err := database.NewRepository(cfg) if err != nil { - log.Fatalf("Failed to initialize database: %v", err) + slog.Error("Failed to initialize repository", "error", err) + os.Exit(1) } - // Initialize health checker + // Initialize health checker with actual database dependency healthChecker := health.New() - healthChecker.AddCheck("database", db.Ping) + healthChecker.AddCheck("database", func(ctx context.Context) error { + return repo.Ping(ctx) + }) + healthChecker.SetReady(true) // Setup router router := gin.Default() - routes.SetupRoutes(router, db) + routes.SetupRoutes(router, repo, healthChecker, cfg) router.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler)) // Create server with timeouts @@ -68,28 +73,30 @@ func main() { // Start server in a goroutine go func() { if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Fatalf("Failed to start server: %v", err) + slog.Error("Failed to start server", "error", err) + os.Exit(1) } }() + slog.Info("Server started", "addr", srv.Addr) + // Wait for interrupt signal quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) <-quit - log.Println("Shutting down server...") + slog.Info("Shutting down server...") - // Give outstanding requests 5 seconds to complete + // Give outstanding requests time to complete ctx, cancel := context.WithTimeout(context.Background(), gracefulShutdownTimeout) defer cancel() err = srv.Shutdown(ctx) - // Always execute cleanup cancel() if err != nil { - log.Printf("Server forced to shutdown: %v", err) - return // Return with error status from main + slog.Error("Server forced to shutdown", "error", err) + return } - log.Println("Server exiting") + slog.Info("Server exited gracefully") } diff --git a/backend/api/main_test.go b/backend/api/main_test.go index 2d571cd..28740ae 100644 --- a/backend/api/main_test.go +++ b/backend/api/main_test.go @@ -38,33 +38,33 @@ type MockRepository struct { mock.Mock } -func (m *MockRepository) Create(entity interface{}) error { - args := m.Called(entity) +func (m *MockRepository) Create(ctx context.Context, entity interface{}) error { + args := m.Called(ctx, entity) return args.Error(0) } -func (m *MockRepository) FindByID(id uint, dest interface{}) error { - args := m.Called(id, dest) +func (m *MockRepository) FindByID(ctx context.Context, id uint, dest interface{}) error { + args := m.Called(ctx, id, dest) return args.Error(0) } -func (m *MockRepository) Update(entity interface{}) error { - args := m.Called(entity) +func (m *MockRepository) Update(ctx context.Context, entity interface{}) error { + args := m.Called(ctx, entity) return args.Error(0) } -func (m *MockRepository) Delete(entity interface{}) error { - args := m.Called(entity) +func (m *MockRepository) Delete(ctx context.Context, entity interface{}) error { + args := m.Called(ctx, entity) return args.Error(0) } -func (m *MockRepository) List(dest interface{}, conditions ...interface{}) error { - args := m.Called(dest, conditions) +func (m *MockRepository) List(ctx context.Context, dest interface{}, conditions ...interface{}) error { + args := m.Called(ctx, dest, conditions) return args.Error(0) } -func (m *MockRepository) Ping() error { - args := m.Called() +func (m *MockRepository) Ping(ctx context.Context) error { + args := m.Called(ctx) return args.Error(0) } @@ -181,12 +181,12 @@ func TestHealthEndpoints(t *testing.T) { // Register health endpoints r.GET("/health/live", func(c *gin.Context) { - status := healthChecker.CheckLiveness() + status := healthChecker.CheckLiveness(c.Request.Context()) c.JSON(http.StatusOK, status) }) r.GET("/health/ready", func(c *gin.Context) { - status := healthChecker.CheckReadiness() + status := healthChecker.CheckReadiness(c.Request.Context()) if status.Status == "DOWN" { c.JSON(http.StatusServiceUnavailable, status) return @@ -211,7 +211,7 @@ func TestHealthEndpoints(t *testing.T) { assert.Contains(t, w.Body.String(), "UP") // Test with failing health check - healthChecker.AddCheck("test", func() error { + healthChecker.AddCheck("test", func(_ context.Context) error { return errors.New("test error") }) @@ -230,36 +230,36 @@ func TestDatabaseHealthCheck(t *testing.T) { // Test with SQLite database mockRepo := new(MockRepository) - mockRepo.On("Ping").Return(nil) + mockRepo.On("Ping", mock.Anything).Return(nil) // Add database health check - healthChecker.AddCheck("database", func() error { - return mockRepo.Ping() + healthChecker.AddCheck("database", func(_ context.Context) error { + return mockRepo.Ping(context.Background()) }) // Check readiness - status := healthChecker.CheckReadiness() + status := healthChecker.CheckReadiness(context.Background()) assert.Equal(t, "DOWN", status.Status) // Initially DOWN because we haven't set ready // Mark as ready healthChecker.SetReady(true) // Check again - status = healthChecker.CheckReadiness() + status = healthChecker.CheckReadiness(context.Background()) assert.Equal(t, "UP", status.Status) assert.Equal(t, "UP", status.Checks["database"].Status) // Test with failing database connection mockRepo = new(MockRepository) - mockRepo.On("Ping").Return(errors.New("connection failed")) + mockRepo.On("Ping", mock.Anything).Return(errors.New("connection failed")) healthChecker = health.New() healthChecker.SetReady(true) - healthChecker.AddCheck("database", func() error { - return mockRepo.Ping() + healthChecker.AddCheck("database", func(_ context.Context) error { + return mockRepo.Ping(context.Background()) }) - status = healthChecker.CheckReadiness() + status = healthChecker.CheckReadiness(context.Background()) assert.Equal(t, "DOWN", status.Status) assert.Equal(t, "DOWN", status.Checks["database"].Status) assert.Contains(t, status.Checks["database"].Message, "connection failed") diff --git a/backend/internal/api/handlers/handlers.go b/backend/internal/api/handlers/handlers.go index f916182..543c4ff 100644 --- a/backend/internal/api/handlers/handlers.go +++ b/backend/internal/api/handlers/handlers.go @@ -7,14 +7,6 @@ import ( "github.com/gin-gonic/gin" ) -var healthChecker *health.HealthChecker - -func init() { - healthChecker = health.New() - // Set the service as ready after initialization - healthChecker.SetReady(true) -} - // @Summary Health Check // @Description Get API health status // @Tags health @@ -35,40 +27,25 @@ func Ping(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "pong"}) } -// @Summary Get items -// @Description Get all items -// @Tags items -// @Produce json -// @Success 200 {object} map[string]string -// @Router /api/v1/items [get] -func GetItems(c *gin.Context) { - // Logic to retrieve items goes here - c.JSON(http.StatusOK, gin.H{"message": "GetItems called"}) -} - -// @Summary Create item -// @Description Create a new item -// @Tags items -// @Accept json -// @Produce json -// @Success 201 {object} map[string]string -// @Router /api/v1/items [post] -func CreateItem(c *gin.Context) { - // Logic to create an item goes here - c.JSON(http.StatusCreated, gin.H{"message": "CreateItem called"}) -} - +// LivenessHandler returns a handler for liveness checks. +// The health checker is injected so the same instance used in main is checked. +// // @Summary Liveness Check // @Description Get API liveness status // @Tags health // @Produce json // @Success 200 {object} health.HealthStatus // @Router /health/live [get] -func LivenessCheck(c *gin.Context) { - status := healthChecker.CheckLiveness() - c.JSON(http.StatusOK, status) +func LivenessHandler(hc *health.HealthChecker) gin.HandlerFunc { + return func(c *gin.Context) { + status := hc.CheckLiveness(c.Request.Context()) + c.JSON(http.StatusOK, status) + } } +// ReadinessHandler returns a handler for readiness checks. +// The health checker is injected so the same instance used in main is checked. +// // @Summary Readiness Check // @Description Get API readiness status // @Tags health @@ -76,11 +53,13 @@ func LivenessCheck(c *gin.Context) { // @Success 200 {object} health.HealthStatus // @Failure 503 {object} health.HealthStatus // @Router /health/ready [get] -func ReadinessCheck(c *gin.Context) { - status := healthChecker.CheckReadiness() - if status.Status == "DOWN" { - c.JSON(http.StatusServiceUnavailable, status) - return +func ReadinessHandler(hc *health.HealthChecker) gin.HandlerFunc { + return func(c *gin.Context) { + status := hc.CheckReadiness(c.Request.Context()) + if status.Status == "DOWN" { + c.JSON(http.StatusServiceUnavailable, status) + return + } + c.JSON(http.StatusOK, status) } - c.JSON(http.StatusOK, status) } diff --git a/backend/internal/api/handlers/handlers_test.go b/backend/internal/api/handlers/handlers_test.go index 372f732..587bf01 100644 --- a/backend/internal/api/handlers/handlers_test.go +++ b/backend/internal/api/handlers/handlers_test.go @@ -61,43 +61,3 @@ func TestPingHandler(t *testing.T) { assert.Nil(t, err) assert.Equal(t, "pong", response["message"]) } -func TestGetItemsHandler(t *testing.T) { - t.Parallel() - gin.SetMode(gin.TestMode) - - r := gin.Default() - r.GET("/items", GetItems) - - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/items", nil) - - r.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - - var response map[string]string - err := json.Unmarshal(w.Body.Bytes(), &response) - - assert.Nil(t, err) - assert.Equal(t, "GetItems called", response["message"]) -} -func TestCreateItemHandler(t *testing.T) { - t.Parallel() - gin.SetMode(gin.TestMode) - - r := gin.Default() - r.POST("/items", CreateItem) - - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/items", nil) - - r.ServeHTTP(w, req) - - assert.Equal(t, http.StatusCreated, w.Code) - - var response map[string]string - err := json.Unmarshal(w.Body.Bytes(), &response) - - assert.Nil(t, err) - assert.Equal(t, "CreateItem called", response["message"]) -} diff --git a/backend/internal/api/handlers/items.go b/backend/internal/api/handlers/items.go index ff29a15..51dda28 100644 --- a/backend/internal/api/handlers/items.go +++ b/backend/internal/api/handlers/items.go @@ -42,7 +42,8 @@ func handleDBError(err error) (int, string) { if strings.Contains(err.Error(), "not found") { return http.StatusNotFound, "Item not found" } - return http.StatusInternalServerError, err.Error() + // Never leak raw error messages to clients + return http.StatusInternalServerError, "Internal server error" } // CreateItem godoc @@ -67,7 +68,7 @@ func (h *Handler) CreateItem(c *gin.Context) { return } - if err := h.repository.Create(&item); err != nil { + if err := h.repository.Create(c.Request.Context(), &item); err != nil { status, message := handleDBError(err) c.JSON(status, gin.H{"error": message}) return @@ -126,7 +127,7 @@ func (h *Handler) GetItems(c *gin.Context) { conditions = append(conditions, models.Pagination{Limit: limit, Offset: offset}) } - if err := h.repository.List(&items, conditions...); err != nil { + if err := h.repository.List(c.Request.Context(), &items, conditions...); err != nil { status, message := handleDBError(err) c.JSON(status, gin.H{"error": message}) return @@ -152,7 +153,7 @@ func (h *Handler) GetItem(c *gin.Context) { } var item models.Item - if err := h.repository.FindByID(uint(id), &item); err != nil { + if err := h.repository.FindByID(c.Request.Context(), uint(id), &item); err != nil { status, message := handleDBError(err) c.JSON(status, gin.H{"error": message}) return @@ -181,7 +182,7 @@ func (h *Handler) UpdateItem(c *gin.Context) { // Get the current version from the database var currentItem models.Item - if err := h.repository.FindByID(uint(id), ¤tItem); err != nil { + if err := h.repository.FindByID(c.Request.Context(), uint(id), ¤tItem); err != nil { status, message := handleDBError(err) c.JSON(status, gin.H{"error": message}) return @@ -193,23 +194,18 @@ func (h *Handler) UpdateItem(c *gin.Context) { return } - // Keep all existing fields but update name and price from the request - currentItem.Name = updateItem.Name - currentItem.Price = updateItem.Price - // Version check for optimistic locking (if version was provided in request) if updateItem.Version > 0 && updateItem.Version != currentItem.Version { c.JSON(http.StatusConflict, gin.H{"error": "Item has been modified by another request"}) return } - // Make sure we're using the version from the request for optimistic locking - currentItem.Version = updateItem.Version - - // We don't increment the version here, the repository will handle that + // Update fields from request but keep DB version for repository-level optimistic locking + currentItem.Name = updateItem.Name + currentItem.Price = updateItem.Price - if err := h.repository.Update(¤tItem); err != nil { - if err.Error() == "version mismatch" { + if err := h.repository.Update(c.Request.Context(), ¤tItem); err != nil { + if strings.Contains(err.Error(), "version mismatch") { c.JSON(http.StatusConflict, gin.H{"error": "Item has been modified by another request"}) return } @@ -237,9 +233,15 @@ func (h *Handler) DeleteItem(c *gin.Context) { return } - item := &models.Item{} - item.ID = uint(id) - if err := h.repository.Delete(item); err != nil { + // Verify item exists before attempting delete + var item models.Item + if err := h.repository.FindByID(c.Request.Context(), uint(id), &item); err != nil { + status, message := handleDBError(err) + c.JSON(status, gin.H{"error": message}) + return + } + + if err := h.repository.Delete(c.Request.Context(), &item); err != nil { status, message := handleDBError(err) c.JSON(status, gin.H{"error": message}) return diff --git a/backend/internal/api/handlers/items_test.go b/backend/internal/api/handlers/items_test.go index 5bc49db..28e01d6 100644 --- a/backend/internal/api/handlers/items_test.go +++ b/backend/internal/api/handlers/items_test.go @@ -2,6 +2,7 @@ package handlers import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -124,7 +125,7 @@ func TestGetItem(t *testing.T) { // Create a test item testItem := &models.Item{Name: "Test Item", Price: 99.99} - mockRepo.Create(testItem) + mockRepo.Create(context.Background(), testItem) tests := []struct { wantItem *models.Item // 8 bytes (pointer) @@ -179,7 +180,7 @@ func TestUpdateItem(t *testing.T) { // Create a test item testItem := &models.Item{Name: "Test Item", Price: 99.99} - mockRepo.Create(testItem) + mockRepo.Create(context.Background(), testItem) tests := []struct { name string @@ -259,7 +260,7 @@ func TestDeleteItem(t *testing.T) { // Create a test item testItem := &models.Item{Name: "Test Item", Price: 99.99} - mockRepo.Create(testItem) + mockRepo.Create(context.Background(), testItem) tests := []struct { name string @@ -310,7 +311,7 @@ func TestListItems(t *testing.T) { } for _, item := range items { - mockRepo.Create(&item) + mockRepo.Create(context.Background(), &item) } tests := []struct { @@ -415,7 +416,7 @@ func TestListItemsErrors(t *testing.T) { var response map[string]string err := json.Unmarshal(w.Body.Bytes(), &response) assert.NoError(t, err) - assert.Contains(t, response["error"], "database error") + assert.Contains(t, response["error"], "Internal server error") } func TestConcurrentItemOperations(t *testing.T) { @@ -460,7 +461,7 @@ func TestConcurrentItemOperations(t *testing.T) { t.Run("concurrent item updates with version validation", func(t *testing.T) { // Create an item to update item := &models.Item{Name: "Test Item", Price: 99.99} - mockRepo.Create(item) + mockRepo.Create(context.Background(), item) itemID := fmt.Sprint(item.ID) // Channel to collect successful updates @@ -660,7 +661,7 @@ func BenchmarkItemOperations(b *testing.B) { method: "GET", pathGen: func(mockRepo *MockRepository) string { testItem := &models.Item{Name: "Test Item", Price: 99.99} - mockRepo.Create(testItem) + mockRepo.Create(context.Background(), testItem) return "/api/v1/items/" + fmt.Sprint(testItem.ID) }, bodyGen: func(_ *MockRepository) []byte { return nil }, @@ -678,7 +679,7 @@ func BenchmarkItemOperations(b *testing.B) { method: "PUT", pathGen: func(mockRepo *MockRepository) string { testItem := &models.Item{Name: "Test Item", Price: 99.99} - mockRepo.Create(testItem) + mockRepo.Create(context.Background(), testItem) return "/api/v1/items/" + fmt.Sprint(testItem.ID) }, bodyGen: func(mockRepo *MockRepository) []byte { @@ -918,7 +919,7 @@ func TestHandleDBError(t *testing.T) { {err: duplicateErr, wantCode: http.StatusConflict, wantMsg: "Item already exists"}, {err: otherDBErr, wantCode: http.StatusInternalServerError, wantMsg: "Internal server error"}, {err: plainNotFound, wantCode: http.StatusNotFound, wantMsg: "Item not found"}, - {err: plainOther, wantCode: http.StatusInternalServerError, wantMsg: plainOther.Error()}, + {err: plainOther, wantCode: http.StatusInternalServerError, wantMsg: "Internal server error"}, } for _, tt := range tests { diff --git a/backend/internal/api/handlers/mock_repository.go b/backend/internal/api/handlers/mock_repository.go index 38a15ff..a16aebf 100644 --- a/backend/internal/api/handlers/mock_repository.go +++ b/backend/internal/api/handlers/mock_repository.go @@ -1,12 +1,14 @@ package handlers import ( - "backend/internal/models" + "context" "errors" "fmt" "sort" "strings" "sync" + + "backend/internal/models" ) // MockRepository is a mock implementation of the Repository interface for testing @@ -24,7 +26,7 @@ func NewMockRepository() *MockRepository { } } -func (m *MockRepository) Create(entity interface{}) error { +func (m *MockRepository) Create(_ context.Context, entity interface{}) error { m.Lock() defer m.Unlock() @@ -38,13 +40,13 @@ func (m *MockRepository) Create(entity interface{}) error { } item.ID = m.nextID - item.Version = 0 // Initialize version + item.Version = 1 // Initialize version (1 = first version; 0 = "not provided" sentinel) m.nextID++ m.items[item.ID] = item return nil } -func (m *MockRepository) FindByID(id uint, dest interface{}) error { +func (m *MockRepository) FindByID(_ context.Context, id uint, dest interface{}) error { m.RLock() defer m.RUnlock() @@ -66,7 +68,7 @@ func (m *MockRepository) FindByID(id uint, dest interface{}) error { return nil } -func (m *MockRepository) Update(entity interface{}) error { +func (m *MockRepository) Update(_ context.Context, entity interface{}) error { m.Lock() defer m.Unlock() @@ -102,7 +104,7 @@ func (m *MockRepository) Update(entity interface{}) error { return nil } -func (m *MockRepository) Delete(entity interface{}) error { +func (m *MockRepository) Delete(_ context.Context, entity interface{}) error { m.Lock() defer m.Unlock() @@ -123,7 +125,7 @@ func (m *MockRepository) Delete(entity interface{}) error { return nil } -func (m *MockRepository) List(dest interface{}, conditions ...interface{}) error { +func (m *MockRepository) List(_ context.Context, dest interface{}, conditions ...interface{}) error { m.RLock() defer m.RUnlock() @@ -218,7 +220,7 @@ func (m *MockRepository) List(dest interface{}, conditions ...interface{}) error } // Ping implements the Repository interface -func (m *MockRepository) Ping() error { +func (m *MockRepository) Ping(_ context.Context) error { return nil } diff --git a/backend/internal/api/handlers/rate_limiter.go b/backend/internal/api/handlers/rate_limiter.go index fd61757..25cb558 100644 --- a/backend/internal/api/handlers/rate_limiter.go +++ b/backend/internal/api/handlers/rate_limiter.go @@ -13,15 +13,24 @@ type RateLimiter struct { sync.RWMutex // size: 8 window time.Duration // size: 8 requests map[string][]time.Time // size: 8 (pointer) + done chan struct{} // size: 8 limit int // size: 4 } func NewRateLimiter(limit int, window time.Duration) *RateLimiter { - return &RateLimiter{ + rl := &RateLimiter{ requests: make(map[string][]time.Time), limit: limit, window: window, + done: make(chan struct{}), } + go rl.cleanup() + return rl +} + +// Stop terminates the background cleanup goroutine. +func (rl *RateLimiter) Stop() { + close(rl.done) } func (rl *RateLimiter) RateLimit() gin.HandlerFunc { @@ -48,17 +57,8 @@ func (rl *RateLimiter) RateLimit() gin.HandlerFunc { } rl.requests[ip] = valid - // For tests - use a slightly lower limit to ensure some requests get rate limited - // This helps identify rate limiting in tests that send requests concurrently - effectiveLimit := rl.limit - if len(rl.requests[ip]) >= 30 && now.Nanosecond()%4 == 0 { - // Artificially lower the limit sometimes to ensure rate limiting occurs in tests - effectiveLimit = len(rl.requests[ip]) - } - - // Check if limit exceeded - this must be evaluated AFTER cleaning up old requests - // and BEFORE adding the current request to ensure accurate rate limiting - if len(rl.requests[ip]) >= effectiveLimit { + // Check if limit exceeded + if len(rl.requests[ip]) >= rl.limit { c.JSON(http.StatusTooManyRequests, gin.H{"error": "rate limit exceeded"}) c.Abort() return @@ -69,3 +69,36 @@ func (rl *RateLimiter) RateLimit() gin.HandlerFunc { c.Next() } } + +// cleanup periodically removes expired entries to prevent memory leaks. +func (rl *RateLimiter) cleanup() { + ticker := time.NewTicker(rl.window * 2) + defer ticker.Stop() + for { + select { + case <-ticker.C: + rl.cleanupExpired() + case <-rl.done: + return + } + } +} + +func (rl *RateLimiter) cleanupExpired() { + rl.Lock() + defer rl.Unlock() + now := time.Now() + for ip, times := range rl.requests { + var valid []time.Time + for _, t := range times { + if now.Sub(t) < rl.window { + valid = append(valid, t) + } + } + if len(valid) == 0 { + delete(rl.requests, ip) + } else { + rl.requests[ip] = valid + } + } +} diff --git a/backend/internal/api/middleware/cors_test.go b/backend/internal/api/middleware/cors_test.go index e9f7547..1bfd741 100644 --- a/backend/internal/api/middleware/cors_test.go +++ b/backend/internal/api/middleware/cors_test.go @@ -16,7 +16,7 @@ func TestCORSMiddleware(t *testing.T) { // Setup router with middleware r := gin.New() - r.Use(CORS()) + r.Use(CORS("*")) r.Any("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) @@ -30,7 +30,7 @@ func TestCORSMiddleware(t *testing.T) { // Check CORS headers assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin")) assert.Equal(t, "GET, POST, PUT, DELETE, OPTIONS", w.Header().Get("Access-Control-Allow-Methods")) - assert.Equal(t, "Content-Type, Content-Length, Accept-Encoding, Authorization", w.Header().Get("Access-Control-Allow-Headers")) + assert.Equal(t, "Content-Type, Content-Length, Accept-Encoding, Authorization, X-Request-ID", w.Header().Get("Access-Control-Allow-Headers")) assert.Equal(t, http.StatusOK, w.Code) }) @@ -43,7 +43,7 @@ func TestCORSMiddleware(t *testing.T) { // Check CORS headers assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin")) assert.Equal(t, "GET, POST, PUT, DELETE, OPTIONS", w.Header().Get("Access-Control-Allow-Methods")) - assert.Equal(t, "Content-Type, Content-Length, Accept-Encoding, Authorization", w.Header().Get("Access-Control-Allow-Headers")) + assert.Equal(t, "Content-Type, Content-Length, Accept-Encoding, Authorization, X-Request-ID", w.Header().Get("Access-Control-Allow-Headers")) assert.Equal(t, http.StatusNoContent, w.Code) // OPTIONS request should return 204 No Content }) } diff --git a/backend/internal/api/middleware/middleware.go b/backend/internal/api/middleware/middleware.go index 5c7fb71..f509ac5 100644 --- a/backend/internal/api/middleware/middleware.go +++ b/backend/internal/api/middleware/middleware.go @@ -1,18 +1,34 @@ package middleware import ( - "log" + "crypto/rand" + "fmt" + "log/slog" "net/http" + "strings" "github.com/gin-gonic/gin" ) -// CORS middleware -func CORS() gin.HandlerFunc { +// CORS middleware with configurable allowed origins. +// Pass "*" or "" to allow all origins (development only). +// For production, pass a comma-separated list of allowed origins. +func CORS(allowedOrigins string) gin.HandlerFunc { return func(c *gin.Context) { - c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + if allowedOrigins == "" || allowedOrigins == "*" { + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + } else { + requestOrigin := c.Request.Header.Get("Origin") + for _, origin := range strings.Split(allowedOrigins, ",") { + if strings.TrimSpace(origin) == requestOrigin { + c.Writer.Header().Set("Access-Control-Allow-Origin", requestOrigin) + c.Writer.Header().Set("Vary", "Origin") + break + } + } + } c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, Authorization") + c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, Authorization, X-Request-ID") if c.Request.Method == "OPTIONS" { c.AbortWithStatus(http.StatusNoContent) @@ -23,10 +39,13 @@ func CORS() gin.HandlerFunc { } } -// Logger is a middleware that logs the incoming requests. +// Logger is a middleware that logs incoming requests using structured logging. func Logger() gin.HandlerFunc { return func(c *gin.Context) { - log.Printf("Request: %s %s", c.Request.Method, c.Request.URL) + slog.Info("incoming request", + "method", c.Request.Method, + "path", c.Request.URL.Path, + ) c.Next() } } @@ -36,7 +55,7 @@ func Recovery() gin.HandlerFunc { return func(c *gin.Context) { defer func() { if err := recover(); err != nil { - log.Printf("Recovered from panic: %v", err) + slog.Error("recovered from panic", "error", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal Server Error"}) c.Abort() } @@ -44,3 +63,31 @@ func Recovery() gin.HandlerFunc { c.Next() } } + +// RequestID adds a unique request ID to each request. +// If the client sends an X-Request-ID header, it is reused; otherwise a new one is generated. +func RequestID() gin.HandlerFunc { + return func(c *gin.Context) { + requestID := c.GetHeader("X-Request-ID") + if requestID == "" { + requestID = generateRequestID() + } + c.Set("request_id", requestID) + c.Writer.Header().Set("X-Request-ID", requestID) + c.Next() + } +} + +// MaxBodySize limits the size of the request body to prevent memory exhaustion. +func MaxBodySize(maxBytes int64) gin.HandlerFunc { + return func(c *gin.Context) { + c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxBytes) + c.Next() + } +} + +func generateRequestID() string { + b := make([]byte, 16) + _, _ = rand.Read(b) + return fmt.Sprintf("%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:]) +} diff --git a/backend/internal/api/middleware/middleware_test.go b/backend/internal/api/middleware/middleware_test.go index 6a05b29..3584200 100644 --- a/backend/internal/api/middleware/middleware_test.go +++ b/backend/internal/api/middleware/middleware_test.go @@ -3,10 +3,9 @@ package middleware import ( "bytes" "encoding/json" - "log" + "log/slog" "net/http" "net/http/httptest" - "os" "testing" "github.com/gin-gonic/gin" @@ -18,12 +17,12 @@ func TestLoggerMiddleware(t *testing.T) { // Set Gin to Test Mode gin.SetMode(gin.TestMode) - // Create a buffer to capture log output + // Create a buffer to capture slog output var buf bytes.Buffer - log.SetOutput(&buf) - defer func() { - log.SetOutput(os.Stdout) - }() + handler := slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelInfo}) + origLogger := slog.Default() + slog.SetDefault(slog.New(handler)) + defer slog.SetDefault(origLogger) // Setup router with middleware r := gin.New() diff --git a/backend/internal/api/routes/routes.go b/backend/internal/api/routes/routes.go index 9d5ba6d..4aa9788 100644 --- a/backend/internal/api/routes/routes.go +++ b/backend/internal/api/routes/routes.go @@ -3,24 +3,29 @@ package routes import ( "backend/internal/api/handlers" "backend/internal/api/middleware" + "backend/internal/config" + "backend/internal/health" "backend/internal/models" "github.com/gin-gonic/gin" ) -// SetupRoutes configures all the routes for our application -func SetupRoutes(router *gin.Engine, repository models.Repository) { +// SetupRoutes configures all the routes for our application. +// healthChecker is injected from main so the readiness endpoint reflects real dependency health. +func SetupRoutes(router *gin.Engine, repository models.Repository, healthChecker *health.HealthChecker, cfg *config.Config) { // Add middleware + router.Use(middleware.RequestID()) router.Use(middleware.Logger()) router.Use(middleware.Recovery()) - router.Use(middleware.CORS()) + router.Use(middleware.CORS(cfg.CORS.AllowedOrigins)) + router.Use(middleware.MaxBodySize(1 << 20)) // 1 MB default // Health check endpoints - health := router.Group("/health") + healthGroup := router.Group("/health") { - health.GET("/live", handlers.LivenessCheck) - health.GET("/ready", handlers.ReadinessCheck) - health.GET("", handlers.HealthCheck) // Keep the original health check for backward compatibility + healthGroup.GET("/live", handlers.LivenessHandler(healthChecker)) + healthGroup.GET("/ready", handlers.ReadinessHandler(healthChecker)) + healthGroup.GET("", handlers.HealthCheck) // Keep the original health check for backward compatibility } // API v1 routes diff --git a/backend/internal/api/routes/routes_test.go b/backend/internal/api/routes/routes_test.go index e6200ed..ebfe65e 100644 --- a/backend/internal/api/routes/routes_test.go +++ b/backend/internal/api/routes/routes_test.go @@ -7,6 +7,7 @@ import ( "testing" "backend/internal/api/handlers" + "backend/internal/config" "backend/internal/health" "github.com/gin-gonic/gin" @@ -22,10 +23,18 @@ func TestSetupRoutes(t *testing.T) { mockRepo := handlers.NewMockRepository() // Initialize health checker and set it as ready - health.New().SetReady(true) + healthChecker := health.New() + healthChecker.SetReady(true) + + // Create a minimal config for testing + cfg := &config.Config{ + CORS: config.CORSConfig{ + AllowedOrigins: "*", + }, + } // Setup routes - SetupRoutes(router, mockRepo) + SetupRoutes(router, mockRepo, healthChecker, cfg) // Test cases tests := []struct { diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index bd44476..22bcac1 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -23,6 +23,11 @@ const ( defaultShutdownTimeout = 30 * time.Second ) +// CORSConfig holds CORS configuration +type CORSConfig struct { + AllowedOrigins string +} + // Config holds all configuration for the application // //nolint:govet // Struct field alignment has been optimized for better memory usage @@ -33,6 +38,7 @@ type Config struct { // Then string and simple field structs App AppConfig AzureTable AzureTableConfig + CORS CORSConfig Logging LogConfig } @@ -98,6 +104,12 @@ func (c *Config) Validate() error { return fmt.Errorf("app config: %w", err) } + if c.AzureTable.UseAzureTable { + if err := c.AzureTable.Validate(); err != nil { + return fmt.Errorf("azure table config: %w", err) + } + } + if err := c.Database.Validate(); err != nil { return fmt.Errorf("database config: %w", err) } @@ -212,7 +224,7 @@ func (c *DatabaseConfig) DSN() string { b.WriteByte(')') b.WriteByte('/') b.WriteString(c.DBName) - b.WriteString("?charset=utf8mb4&parseTime=True&loc=Local") + b.WriteString("?charset=utf8mb4&parseTime=True&loc=UTC") if c.MaxOpenConns > 0 { b.WriteString("&maxAllowedPacket=0") // Let server control packet size @@ -267,6 +279,9 @@ func LoadConfig() (*Config, error) { IdleTimeout: getEnvDuration("SERVER_IDLE_TIMEOUT", defaultIdleTimeout), ShutdownTimeout: getEnvDuration("SERVER_SHUTDOWN_TIMEOUT", defaultShutdownTimeout), }, + CORS: CORSConfig{ + AllowedOrigins: getEnv("CORS_ALLOWED_ORIGINS", "*"), + }, Logging: LogConfig{ Level: getEnv("LOG_LEVEL", "info"), File: getEnv("LOG_FILE", ""), diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 84d39ad..b6ebf70 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -149,7 +149,7 @@ func TestDatabaseDSN(t *testing.T) { DBName: "testdb", } - expected := "testuser:testpass@tcp(testhost:3306)/testdb?charset=utf8mb4&parseTime=True&loc=Local" + expected := "testuser:testpass@tcp(testhost:3306)/testdb?charset=utf8mb4&parseTime=True&loc=UTC" assert.Equal(t, expected, dbConfig.DSN()) } func TestConfigValidate(t *testing.T) { diff --git a/backend/internal/database/azure/error_handling_test.go b/backend/internal/database/azure/error_handling_test.go index 6e4c129..383e76b 100644 --- a/backend/internal/database/azure/error_handling_test.go +++ b/backend/internal/database/azure/error_handling_test.go @@ -1,6 +1,7 @@ package azure_test import ( + "context" "testing" "backend/internal/database/azure" @@ -121,17 +122,18 @@ func TestTableRepository_InvalidData(t *testing.T) { // Create a minimal repository for testing repo := azure.NewTestTableRepository("testtable") + ctx := context.Background() // Test invalid inputs for different operations t.Run("Invalid inputs for Create", func(t *testing.T) { t.Parallel() // Test with nil - err := repo.Create(nil) + err := repo.Create(ctx, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "type_assertion") // Test with wrong type - err = repo.Create("string") + err = repo.Create(ctx, "string") assert.Error(t, err) assert.Contains(t, err.Error(), "entity must be *models.Item") }) @@ -139,12 +141,12 @@ func TestTableRepository_InvalidData(t *testing.T) { t.Run("Invalid inputs for FindByID", func(t *testing.T) { t.Parallel() // Test with nil - err := repo.FindByID(1, nil) + err := repo.FindByID(ctx, 1, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "type_assertion") // Test with wrong type - err = repo.FindByID(1, "string") + err = repo.FindByID(ctx, 1, "string") assert.Error(t, err) assert.Contains(t, err.Error(), "dest must be *models.Item") }) @@ -152,12 +154,12 @@ func TestTableRepository_InvalidData(t *testing.T) { t.Run("Invalid inputs for Update", func(t *testing.T) { t.Parallel() // Test with nil - err := repo.Update(nil) + err := repo.Update(ctx, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "type_assertion") // Test with wrong type - err = repo.Update("string") + err = repo.Update(ctx, "string") assert.Error(t, err) assert.Contains(t, err.Error(), "entity must be *models.Item") }) @@ -165,12 +167,12 @@ func TestTableRepository_InvalidData(t *testing.T) { t.Run("Invalid inputs for Delete", func(t *testing.T) { t.Parallel() // Test with nil - err := repo.Delete(nil) + err := repo.Delete(ctx, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "type_assertion") // Test with wrong type - err = repo.Delete("string") + err = repo.Delete(ctx, "string") assert.Error(t, err) assert.Contains(t, err.Error(), "entity must be *models.Item") }) @@ -178,12 +180,12 @@ func TestTableRepository_InvalidData(t *testing.T) { t.Run("Invalid inputs for List", func(t *testing.T) { t.Parallel() // Test with nil - err := repo.List(nil) + err := repo.List(ctx, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "type_assertion") // Test with wrong type - err = repo.List("string") + err = repo.List(ctx, "string") assert.Error(t, err) assert.Contains(t, err.Error(), "dest must be *[]models.Item") }) diff --git a/backend/internal/database/azure/repository_test.go b/backend/internal/database/azure/repository_test.go index 4c4065d..5944814 100644 --- a/backend/internal/database/azure/repository_test.go +++ b/backend/internal/database/azure/repository_test.go @@ -194,8 +194,10 @@ func TestTableRepository_DatabaseErrors(t *testing.T) { if repo2 == nil { t.Skip("Could not create test repository (Azurite not available)") } + ctx := context.Background() + // Test Create - err := repo2.Create(&models.Item{}) + err := repo2.Create(ctx, &models.Item{}) assert.Error(t, err) var dbErr *dberrors.DatabaseError assert.True(t, errors.As(err, &dbErr)) @@ -203,21 +205,21 @@ func TestTableRepository_DatabaseErrors(t *testing.T) { assert.Equal(t, mockErr, errors.Unwrap(err)) // Test FindByID - err = repo2.FindByID(1, &models.Item{}) + err = repo2.FindByID(ctx, 1, &models.Item{}) assert.Error(t, err) assert.True(t, errors.As(err, &dbErr)) assert.Equal(t, "find", dbErr.Op) assert.Equal(t, mockErr, errors.Unwrap(err)) // Test Update - err = repo2.Update(&models.Item{}) + err = repo2.Update(ctx, &models.Item{}) assert.Error(t, err) assert.True(t, errors.As(err, &dbErr)) assert.Equal(t, "find", dbErr.Op) // First tries to find the item assert.Equal(t, mockErr, errors.Unwrap(err)) // Test Delete - err = repo2.Delete(&models.Item{}) + err = repo2.Delete(ctx, &models.Item{}) assert.Error(t, err) assert.True(t, errors.As(err, &dbErr)) assert.Equal(t, "delete", dbErr.Op) @@ -225,7 +227,7 @@ func TestTableRepository_DatabaseErrors(t *testing.T) { // Test List var items []models.Item - err = repo2.List(&items) + err = repo2.List(ctx, &items) assert.Error(t, err) assert.True(t, errors.As(err, &dbErr)) assert.Equal(t, "list", dbErr.Op) diff --git a/backend/internal/database/azure/table.go b/backend/internal/database/azure/table.go index 54572a2..b8202eb 100644 --- a/backend/internal/database/azure/table.go +++ b/backend/internal/database/azure/table.go @@ -90,7 +90,7 @@ func NewTestTableRepository(tableName string) *TableRepository { } // Create implements the Repository interface -func (r *TableRepository) Create(entity interface{}) error { +func (r *TableRepository) Create(ctx context.Context, entity interface{}) error { item, ok := entity.(*models.Item) if !ok { return dberrors.NewDatabaseError("type_assertion", errors.New("entity must be *models.Item")) @@ -113,8 +113,7 @@ func (r *TableRepository) Create(entity interface{}) error { return dberrors.NewDatabaseError("marshal", err) } - // Create the entity - _, err = r.client.AddEntity(context.Background(), entityBytes, nil) + _, err = r.client.AddEntity(ctx, entityBytes, nil) if err != nil { var respErr *azcore.ResponseError if errors.As(err, &respErr) && respErr.ErrorCode == "EntityAlreadyExists" { @@ -129,14 +128,14 @@ func (r *TableRepository) Create(entity interface{}) error { } // FindByID implements the Repository interface -func (r *TableRepository) FindByID(id uint, dest interface{}) error { +func (r *TableRepository) FindByID(ctx context.Context, id uint, dest interface{}) error { item, ok := dest.(*models.Item) if !ok { return dberrors.NewDatabaseError("type_assertion", fmt.Errorf("dest must be *models.Item")) } // Get the entity - result, err := r.client.GetEntity(context.Background(), "items", strconv.FormatUint(uint64(id), 10), nil) + result, err := r.client.GetEntity(ctx, "items", strconv.FormatUint(uint64(id), 10), nil) if err != nil { var respErr *azcore.ResponseError if errors.As(err, &respErr) && respErr.StatusCode == 404 { @@ -172,14 +171,14 @@ func (r *TableRepository) FindByID(id uint, dest interface{}) error { } // Update implements the Repository interface -func (r *TableRepository) Update(entity interface{}) error { +func (r *TableRepository) Update(ctx context.Context, entity interface{}) error { item, ok := entity.(*models.Item) if !ok { return dberrors.NewDatabaseError("type_assertion", fmt.Errorf("entity must be *models.Item")) } // Check if entity exists first - _, err := r.client.GetEntity(context.Background(), "items", strconv.FormatUint(uint64(item.ID), 10), nil) + _, err := r.client.GetEntity(ctx, "items", strconv.FormatUint(uint64(item.ID), 10), nil) if err != nil { var respErr *azcore.ResponseError if errors.As(err, &respErr) && respErr.StatusCode == 404 { @@ -204,8 +203,7 @@ func (r *TableRepository) Update(entity interface{}) error { return dberrors.NewDatabaseError("marshal", err) } - // Update the entity - _, err = r.client.UpdateEntity(context.Background(), entityBytes, nil) + _, err = r.client.UpdateEntity(ctx, entityBytes, nil) if err != nil { return dberrors.NewDatabaseError("update", err) } @@ -215,13 +213,13 @@ func (r *TableRepository) Update(entity interface{}) error { } // Delete implements the Repository interface -func (r *TableRepository) Delete(entity interface{}) error { +func (r *TableRepository) Delete(ctx context.Context, entity interface{}) error { item, ok := entity.(*models.Item) if !ok { return dberrors.NewDatabaseError("type_assertion", fmt.Errorf("entity must be *models.Item")) } - _, err := r.client.DeleteEntity(context.Background(), "items", strconv.FormatUint(uint64(item.ID), 10), nil) + _, err := r.client.DeleteEntity(ctx, "items", strconv.FormatUint(uint64(item.ID), 10), nil) if err != nil { var respErr *azcore.ResponseError if errors.As(err, &respErr) && respErr.StatusCode == 404 { @@ -234,7 +232,7 @@ func (r *TableRepository) Delete(entity interface{}) error { } // List implements the Repository interface -func (r *TableRepository) List(dest interface{}, conditions ...interface{}) error { +func (r *TableRepository) List(ctx context.Context, dest interface{}, conditions ...interface{}) error { items, ok := dest.(*[]models.Item) if !ok { return dberrors.NewDatabaseError("type_assertion", fmt.Errorf("dest must be *[]models.Item")) @@ -285,7 +283,7 @@ func (r *TableRepository) List(dest interface{}, conditions ...interface{}) erro // Fetch and process all entities for pager.More() { - response, err := pager.NextPage(context.Background()) + response, err := pager.NextPage(ctx) if err != nil { return dberrors.NewDatabaseError("list", err) } @@ -341,10 +339,10 @@ func (r *TableRepository) List(dest interface{}, conditions ...interface{}) erro } // Ping implements the Repository interface -func (r *TableRepository) Ping() error { +func (r *TableRepository) Ping(ctx context.Context) error { // List tables to check connectivity pager := r.client.NewListEntitiesPager(nil) - _, err := pager.NextPage(context.Background()) + _, err := pager.NextPage(ctx) if err != nil { return dberrors.NewDatabaseError("ping", err) } diff --git a/backend/internal/database/azure/table_test.go b/backend/internal/database/azure/table_test.go index 885d54a..f712200 100644 --- a/backend/internal/database/azure/table_test.go +++ b/backend/internal/database/azure/table_test.go @@ -92,11 +92,11 @@ func TestTableClientOperations(t *testing.T) { Name: "test", Price: 10.5, } - err := repo.Create(item) + err := repo.Create(context.Background(), item) assert.NoError(t, err) var retrieved models.Item - err = repo.FindByID(1, &retrieved) + err = repo.FindByID(context.Background(), 1, &retrieved) assert.NoError(t, err) assert.Equal(t, item.Name, retrieved.Name) assert.InDelta(t, item.Price, retrieved.Price, 0.001) @@ -121,7 +121,7 @@ func TestTableClientOperations(t *testing.T) { repo.SetTestClient(mockClient) var items []models.Item - err := repo.List(&items) + err := repo.List(context.Background(), &items) assert.NoError(t, err) assert.Len(t, items, 3) assert.Equal(t, "item1", items[0].Name) @@ -150,7 +150,7 @@ func TestTableClientOperations(t *testing.T) { Name: "test", Price: 10.5, } - err := repo.Update(item) + err := repo.Update(context.Background(), item) assert.NoError(t, err) }) @@ -170,7 +170,7 @@ func TestTableClientOperations(t *testing.T) { Name: "test", Price: 10.5, } - err := repo.Delete(item) + err := repo.Delete(context.Background(), item) assert.NoError(t, err) }) } diff --git a/backend/internal/database/config.go b/backend/internal/database/config.go index 35a48d7..a73756e 100644 --- a/backend/internal/database/config.go +++ b/backend/internal/database/config.go @@ -24,7 +24,7 @@ func NewConfig() *Config { // DSN returns the database connection string func (c *Config) DSN() string { - return c.User + ":" + c.Password + "@tcp(" + c.Host + ":" + c.Port + ")/" + c.DBName + "?charset=utf8mb4&parseTime=True&loc=Local" + return c.User + ":" + c.Password + "@tcp(" + c.Host + ":" + c.Port + ")/" + c.DBName + "?charset=utf8mb4&parseTime=True&loc=UTC" } // Helper function to get environment variables with default values diff --git a/backend/internal/database/config_test.go b/backend/internal/database/config_test.go index fd93211..ad54d82 100644 --- a/backend/internal/database/config_test.go +++ b/backend/internal/database/config_test.go @@ -59,6 +59,6 @@ func TestConfigDSN(t *testing.T) { DBName: "testdb", } - expected := "testuser:testpass@tcp(testhost:3306)/testdb?charset=utf8mb4&parseTime=True&loc=Local" + expected := "testuser:testpass@tcp(testhost:3306)/testdb?charset=utf8mb4&parseTime=True&loc=UTC" assert.Equal(t, expected, config.DSN()) } diff --git a/backend/internal/database/database.go b/backend/internal/database/database.go index 9817e57..386ef84 100644 --- a/backend/internal/database/database.go +++ b/backend/internal/database/database.go @@ -1,7 +1,7 @@ package database import ( - "log" + "log/slog" "strings" "gorm.io/driver/mysql" @@ -9,13 +9,16 @@ import ( "gorm.io/gorm/logger" ) +// Database wraps a gorm.DB instance with additional utilities. type Database struct { *gorm.DB } -func NewDatabase(dsn string, logger logger.Interface) (*Database, error) { +// NewDatabase creates a new database connection. Callers should configure +// connection pool settings via the returned *sql.DB if needed. +func NewDatabase(dsn string, logCfg logger.Interface) (*Database, error) { db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ - Logger: logger, + Logger: logCfg, }) if err != nil { if strings.Contains(err.Error(), "connection refused") { @@ -24,21 +27,11 @@ func NewDatabase(dsn string, logger logger.Interface) (*Database, error) { return nil, NewDatabaseError("connect", err) } - // Configure connection pool - sqlDB, err := db.DB() - if err != nil { - return nil, NewDatabaseError("configure", err) - } - - // Set reasonable defaults for the connection pool - sqlDB.SetMaxIdleConns(5) - sqlDB.SetMaxOpenConns(20) - - log.Println("Connected to database successfully") + slog.Info("Connected to database successfully") return &Database{DB: db}, nil } -// Transaction executes operations within a database transaction +// Transaction executes operations within a database transaction. func (d *Database) Transaction(fn func(tx *gorm.DB) error) error { err := d.DB.Transaction(fn) if err != nil { @@ -47,79 +40,11 @@ func (d *Database) Transaction(fn func(tx *gorm.DB) error) error { return nil } -// HandleError translates GORM/MySQL errors into our custom error types -func (d *Database) HandleError(op string, err error) error { - if err == nil { - return nil - } - - if err == gorm.ErrRecordNotFound { - return NewDatabaseError(op, ErrNotFound) - } - - // Check for duplicate key violations - if strings.Contains(err.Error(), "Duplicate entry") { - return NewDatabaseError(op, ErrDuplicateKey) - } - - // Handle validation errors - if strings.Contains(err.Error(), "validation failed") { - return NewDatabaseError(op, ErrValidation) - } - - return NewDatabaseError(op, err) -} - -// Create implements models.Repository -func (d *Database) Create(value interface{}) error { - result := d.DB.Create(value) - if result.Error != nil { - return d.HandleError("create", result.Error) - } - return nil -} - -// FindByID implements models.Repository -func (d *Database) FindByID(id uint, dest interface{}) error { - result := d.DB.First(dest, id) - if result.Error != nil { - return d.HandleError("find", result.Error) - } - return nil -} - -// Update implements models.Repository -func (d *Database) Update(value interface{}) error { - result := d.DB.Save(value) - if result.Error != nil { - return d.HandleError("update", result.Error) - } - return nil -} - -// Delete implements models.Repository -func (d *Database) Delete(value interface{}) error { - result := d.DB.Delete(value) - if result.Error != nil { - return d.HandleError("delete", result.Error) - } - return nil -} - -// List implements models.Repository -func (d *Database) List(dest interface{}, conditions ...interface{}) error { - result := d.DB.Find(dest) - if result.Error != nil { - return d.HandleError("list", result.Error) - } - return nil -} - -// Ping implements models.Repository +// Ping checks if the database connection is alive. func (d *Database) Ping() error { sqlDB, err := d.DB.DB() if err != nil { - return d.HandleError("ping", err) + return err } return sqlDB.Ping() } diff --git a/backend/internal/database/database_test.go b/backend/internal/database/database_test.go index 0d45860..4e6e3ae 100644 --- a/backend/internal/database/database_test.go +++ b/backend/internal/database/database_test.go @@ -1,6 +1,7 @@ package database import ( + "context" "testing" "backend/internal/models" @@ -110,16 +111,20 @@ func TestItemCRUD(t *testing.T) { t.Parallel() db := setupTestDB(t) require.NoError(t, db.AutoMigrate()) + ctx := context.Background() t.Run("Create Item", func(t *testing.T) { t.Parallel() db := setupTestDB(t) require.NoError(t, db.AutoMigrate()) + repo := models.NewRepository(db.DB) + ctx := context.Background() + item := &models.Item{ Name: "Test Item", Price: 99.99, } - err := db.Create(item) + err := repo.Create(ctx, item) assert.NoError(t, err) assert.NotZero(t, item.ID) }) @@ -128,16 +133,18 @@ func TestItemCRUD(t *testing.T) { t.Parallel() db := setupTestDB(t) require.NoError(t, db.AutoMigrate()) + repo := models.NewRepository(db.DB) + ctx := context.Background() // Create item first initialItem := &models.Item{ Name: "Test Item", Price: 99.99, } - require.NoError(t, db.Create(initialItem)) + require.NoError(t, repo.Create(ctx, initialItem)) var item models.Item - err := db.FindByID(1, &item) + err := repo.FindByID(ctx, initialItem.ID, &item) assert.NoError(t, err) assert.Equal(t, "Test Item", item.Name) assert.Equal(t, 99.99, item.Price) @@ -147,23 +154,26 @@ func TestItemCRUD(t *testing.T) { t.Parallel() db := setupTestDB(t) require.NoError(t, db.AutoMigrate()) + repo := models.NewRepository(db.DB) + ctx := context.Background() // Create item first initialItem := &models.Item{ Name: "Test Item", Price: 99.99, } - require.NoError(t, db.Create(initialItem)) + require.NoError(t, repo.Create(ctx, initialItem)) var item models.Item - err := db.FindByID(1, &item) + err := repo.FindByID(ctx, initialItem.ID, &item) require.NoError(t, err) item.Price = 199.99 - err = db.Update(&item) + err = repo.Update(ctx, &item) assert.NoError(t, err) var updatedItem models.Item - err = db.FindByID(1, &updatedItem) + err = repo.FindByID(ctx, initialItem.ID, &updatedItem) + assert.NoError(t, err) assert.Equal(t, 199.99, updatedItem.Price) }) @@ -171,20 +181,24 @@ func TestItemCRUD(t *testing.T) { t.Parallel() db := setupTestDB(t) require.NoError(t, db.AutoMigrate()) + repo := models.NewRepository(db.DB) + ctx := context.Background() // Create item first initialItem := &models.Item{ Name: "Test Item", Price: 99.99, } - require.NoError(t, db.Create(initialItem)) + require.NoError(t, repo.Create(ctx, initialItem)) - item := &models.Item{Base: models.Base{ID: 1}} - err := db.Delete(item) + item := &models.Item{Base: models.Base{ID: initialItem.ID}} + err := repo.Delete(ctx, item) assert.NoError(t, err) var deleted models.Item - err = db.FindByID(1, &deleted) + err = repo.FindByID(ctx, initialItem.ID, &deleted) assert.Error(t, err, "Should not find deleted item") }) + + _ = ctx // suppress unused variable (used only in subtests) } diff --git a/backend/internal/database/errors.go b/backend/internal/database/errors.go index bb3e2e4..18b0268 100644 --- a/backend/internal/database/errors.go +++ b/backend/internal/database/errors.go @@ -1,39 +1,16 @@ +// Package database provides database connectivity, migrations, and repository factories. +// Error types are re-exported from pkg/dberrors to maintain a single source of truth. package database -import ( - "errors" - "fmt" -) +import "backend/pkg/dberrors" + +// Re-export error types from pkg/dberrors for backward compatibility. +type DatabaseError = dberrors.DatabaseError -// Common database errors var ( - ErrNotFound = errors.New("record not found") - ErrDuplicateKey = errors.New("duplicate key violation") - ErrValidation = errors.New("validation error") - ErrConnectionFailed = errors.New("database connection failed") + ErrNotFound = dberrors.ErrNotFound + ErrDuplicateKey = dberrors.ErrDuplicateKey + ErrValidation = dberrors.ErrValidation + ErrConnectionFailed = dberrors.ErrConnectionFailed + NewDatabaseError = dberrors.NewDatabaseError ) - -// DatabaseError wraps database-specific errors with additional context -type DatabaseError struct { - Op string - Err error -} - -func (e *DatabaseError) Error() string { - if e.Op != "" { - return fmt.Sprintf("%s: %v", e.Op, e.Err) - } - return e.Err.Error() -} - -func (e *DatabaseError) Unwrap() error { - return e.Err -} - -// NewDatabaseError creates a new database error with operation context -func NewDatabaseError(op string, err error) error { - if err == nil { - return nil - } - return &DatabaseError{Op: op, Err: err} -} diff --git a/backend/internal/database/factory.go b/backend/internal/database/factory.go index 399e3e3..9a6047d 100644 --- a/backend/internal/database/factory.go +++ b/backend/internal/database/factory.go @@ -3,7 +3,7 @@ package database import ( "context" "fmt" - "log" + "log/slog" "time" "backend/internal/config" @@ -43,8 +43,12 @@ func NewFromAppConfig(cfg *config.Config) (*Database, error) { retryCount++ if retryCount < maxRetries { - log.Printf("Failed to connect to database (attempt %d/%d): %v. Retrying in %v...", - retryCount, maxRetries, err, retryDelay) + slog.Warn("Failed to connect to database, retrying...", + "attempt", retryCount, + "maxRetries", maxRetries, + "error", err, + "retryDelay", retryDelay, + ) time.Sleep(retryDelay) } } @@ -121,8 +125,12 @@ func NewFromDBConfig(cfg *Config) (*Database, error) { retryCount++ if retryCount < maxRetries { - log.Printf("Failed to connect to database (attempt %d/%d): %v. Retrying in %v...", - retryCount, maxRetries, err, retryDelay) + slog.Warn("Failed to connect to database, retrying...", + "attempt", retryCount, + "maxRetries", maxRetries, + "error", err, + "retryDelay", retryDelay, + ) time.Sleep(retryDelay) } } diff --git a/backend/internal/database/migrations.go b/backend/internal/database/migrations.go index 4a86297..64b3e6b 100644 --- a/backend/internal/database/migrations.go +++ b/backend/internal/database/migrations.go @@ -1,7 +1,7 @@ package database import ( - "log" + "log/slog" "backend/internal/database/schema" "backend/internal/models" @@ -11,7 +11,7 @@ import ( // AutoMigrate runs database migrations for all models func (d *Database) AutoMigrate() error { - log.Println("Running database migrations...") + slog.Info("Running database migrations...") // Initialize migrator migrator := schema.NewMigrator(d.DB) @@ -66,6 +66,6 @@ func (d *Database) AutoMigrate() error { return err } - log.Println("Database migrations completed successfully") + slog.Info("Database migrations completed successfully") return nil } diff --git a/backend/internal/database/repository.go b/backend/internal/database/repository.go index e1b4bbb..e74bd2f 100644 --- a/backend/internal/database/repository.go +++ b/backend/internal/database/repository.go @@ -2,7 +2,7 @@ package database import ( "fmt" - "log" + "log/slog" "backend/internal/config" "backend/internal/database/azure" @@ -12,7 +12,7 @@ import ( // NewRepository creates a new repository based on the configuration func NewRepository(cfg *config.Config) (models.Repository, error) { if cfg.AzureTable.UseAzureTable { - log.Println("Using Azure Table Storage as repository") + slog.Info("Using Azure Table Storage as repository") return azure.NewTableRepository( cfg.AzureTable.AccountName, cfg.AzureTable.AccountKey, @@ -22,7 +22,7 @@ func NewRepository(cfg *config.Config) (models.Repository, error) { ) } - log.Println("Using MySQL as repository") + slog.Info("Using MySQL as repository") db, err := NewFromAppConfig(cfg) if err != nil { return nil, fmt.Errorf("failed to initialize MySQL database: %w", err) diff --git a/backend/internal/health/health.go b/backend/internal/health/health.go index bc72c46..32ac0a3 100644 --- a/backend/internal/health/health.go +++ b/backend/internal/health/health.go @@ -1,6 +1,7 @@ package health import ( + "context" "sync" "time" ) @@ -12,7 +13,9 @@ type HealthChecker struct { dependencies map[string]HealthCheck } -type HealthCheck func() error +// HealthCheck is a function that checks a dependency's health. +// It receives a context for cancellation and timeout support. +type HealthCheck func(ctx context.Context) error type HealthStatus struct { Status string `json:"status"` @@ -46,7 +49,7 @@ func (h *HealthChecker) SetReady(ready bool) { h.isReady = ready } -func (h *HealthChecker) CheckLiveness() HealthStatus { +func (h *HealthChecker) CheckLiveness(_ context.Context) HealthStatus { uptime := time.Since(h.startTime).String() return HealthStatus{ Status: "UP", @@ -54,7 +57,7 @@ func (h *HealthChecker) CheckLiveness() HealthStatus { } } -func (h *HealthChecker) CheckReadiness() HealthStatus { +func (h *HealthChecker) CheckReadiness(ctx context.Context) HealthStatus { h.mu.RLock() defer h.mu.RUnlock() @@ -73,7 +76,7 @@ func (h *HealthChecker) CheckReadiness() HealthStatus { } for name, check := range h.dependencies { - if err := check(); err != nil { + if err := check(ctx); err != nil { status.Status = "DOWN" status.Checks[name] = CheckStatus{ Status: "DOWN", diff --git a/backend/internal/health/health_test.go b/backend/internal/health/health_test.go index a4ee4c2..138c652 100644 --- a/backend/internal/health/health_test.go +++ b/backend/internal/health/health_test.go @@ -1,6 +1,7 @@ package health import ( + "context" "errors" "testing" ) @@ -16,7 +17,7 @@ func TestNew(t *testing.T) { func TestLivenessCheck(t *testing.T) { t.Parallel() h := New() - status := h.CheckLiveness() + status := h.CheckLiveness(context.Background()) if status.Status != "UP" { t.Errorf("Expected status UP, got %s", status.Status) @@ -33,7 +34,7 @@ func TestReadinessCheck(t *testing.T) { t.Run("Service not ready", func(t *testing.T) { t.Parallel() h := New() // Create a fresh instance for this subtest - status := h.CheckReadiness() + status := h.CheckReadiness(context.Background()) if status.Status != "DOWN" { t.Errorf("Expected status DOWN, got %s", status.Status) } @@ -43,7 +44,7 @@ func TestReadinessCheck(t *testing.T) { t.Parallel() h := New() // Create a fresh instance for this subtest h.SetReady(true) - status := h.CheckReadiness() + status := h.CheckReadiness(context.Background()) if status.Status != "UP" { t.Errorf("Expected status UP, got %s", status.Status) } @@ -57,8 +58,8 @@ func TestHealthChecks(t *testing.T) { t.Parallel() h := New() // Create a fresh instance for this subtest h.SetReady(true) - h.AddCheck("test", func() error { return nil }) - status := h.CheckReadiness() + h.AddCheck("test", func(_ context.Context) error { return nil }) + status := h.CheckReadiness(context.Background()) if status.Status != "UP" { t.Errorf("Expected status UP, got %s", status.Status) } @@ -71,8 +72,8 @@ func TestHealthChecks(t *testing.T) { t.Parallel() h := New() // Create a fresh instance for this subtest h.SetReady(true) - h.AddCheck("failing", func() error { return errors.New("test error") }) - status := h.CheckReadiness() + h.AddCheck("failing", func(_ context.Context) error { return errors.New("test error") }) + status := h.CheckReadiness(context.Background()) if status.Status != "DOWN" { t.Errorf("Expected status DOWN, got %s", status.Status) } diff --git a/backend/internal/models/models.go b/backend/internal/models/models.go index 9ee28e3..395e58d 100644 --- a/backend/internal/models/models.go +++ b/backend/internal/models/models.go @@ -1,6 +1,9 @@ package models import ( + "context" + "errors" + "fmt" "strings" "time" @@ -22,7 +25,7 @@ type Item struct { Base Name string `gorm:"size:255;not null" json:"name"` Price float64 `json:"price"` - Version uint `gorm:"not null;default:0" json:"version"` // For optimistic locking + Version uint `gorm:"not null;default:1" json:"version"` // For optimistic locking (1 = initial; 0 = not provided) } // User represents a user in the system @@ -38,14 +41,26 @@ type Validator interface { Validate() error } -// Repository defines the interface for database operations +// Versionable is an interface for models that support optimistic locking. +type Versionable interface { + GetVersion() uint + SetVersion(v uint) +} + +// GetVersion implements Versionable for Item. +func (i *Item) GetVersion() uint { return i.Version } + +// SetVersion implements Versionable for Item. +func (i *Item) SetVersion(v uint) { i.Version = v } + +// Repository defines the interface for database operations. type Repository interface { - Create(interface{}) error - FindByID(id uint, dest interface{}) error - Update(interface{}) error - Delete(interface{}) error - List(dest interface{}, conditions ...interface{}) error - Ping() error + Create(ctx context.Context, entity interface{}) error + FindByID(ctx context.Context, id uint, dest interface{}) error + Update(ctx context.Context, entity interface{}) error + Delete(ctx context.Context, entity interface{}) error + List(ctx context.Context, dest interface{}, conditions ...interface{}) error + Ping(ctx context.Context) error } // GenericRepository implements the Repository interface @@ -59,56 +74,96 @@ func NewRepository(db *gorm.DB) Repository { } // Ping checks if the database is reachable -func (r *GenericRepository) Ping() error { +func (r *GenericRepository) Ping(ctx context.Context) error { sqlDB, err := r.db.DB() if err != nil { return err } - return sqlDB.Ping() + return sqlDB.PingContext(ctx) } -func (r *GenericRepository) Create(entity interface{}) error { +func (r *GenericRepository) Create(ctx context.Context, entity interface{}) error { if v, ok := entity.(Validator); ok { if err := v.Validate(); err != nil { return dberrors.NewDatabaseError("validate", err) } } - if err := r.db.Create(entity).Error; err != nil { + if err := r.db.WithContext(ctx).Create(entity).Error; err != nil { return r.handleError("create", err) } return nil } -func (r *GenericRepository) FindByID(id uint, dest interface{}) error { - if err := r.db.First(dest, id).Error; err != nil { +func (r *GenericRepository) FindByID(ctx context.Context, id uint, dest interface{}) error { + if err := r.db.WithContext(ctx).First(dest, id).Error; err != nil { return r.handleError("find", err) } return nil } -func (r *GenericRepository) Update(entity interface{}) error { +func (r *GenericRepository) Update(ctx context.Context, entity interface{}) error { if v, ok := entity.(Validator); ok { if err := v.Validate(); err != nil { return dberrors.NewDatabaseError("validate", err) } } - if err := r.db.Save(entity).Error; err != nil { + // Optimistic locking for Versionable entities + if ver, ok := entity.(Versionable); ok { + currentVersion := ver.GetVersion() + ver.SetVersion(currentVersion + 1) + result := r.db.WithContext(ctx).Where("version = ?", currentVersion).Save(entity) + if result.Error != nil { + ver.SetVersion(currentVersion) // Roll back version on error + return r.handleError("update", result.Error) + } + if result.RowsAffected == 0 { + ver.SetVersion(currentVersion) // Roll back version on mismatch + return dberrors.NewDatabaseError("update", errors.New("version mismatch")) + } + return nil + } + + if err := r.db.WithContext(ctx).Save(entity).Error; err != nil { return r.handleError("update", err) } return nil } -func (r *GenericRepository) Delete(entity interface{}) error { - if err := r.db.Delete(entity).Error; err != nil { +func (r *GenericRepository) Delete(ctx context.Context, entity interface{}) error { + if err := r.db.WithContext(ctx).Delete(entity).Error; err != nil { return r.handleError("delete", err) } return nil } -func (r *GenericRepository) List(dest interface{}, conditions ...interface{}) error { - if err := r.db.Find(dest, conditions...).Error; err != nil { +func (r *GenericRepository) List(ctx context.Context, dest interface{}, conditions ...interface{}) error { + query := r.db.WithContext(ctx) + for _, cond := range conditions { + switch c := cond.(type) { + case Filter: + switch c.Op { + case "exact": + query = query.Where(fmt.Sprintf("%s = ?", c.Field), c.Value) + case ">=": + query = query.Where(fmt.Sprintf("%s >= ?", c.Field), c.Value) + case "<=": + query = query.Where(fmt.Sprintf("%s <= ?", c.Field), c.Value) + default: + // Default to LIKE for substring matching + query = query.Where(fmt.Sprintf("%s LIKE ?", c.Field), fmt.Sprintf("%%%v%%", c.Value)) + } + case Pagination: + if c.Limit > 0 { + query = query.Limit(c.Limit) + } + if c.Offset > 0 { + query = query.Offset(c.Offset) + } + } + } + if err := query.Find(dest).Error; err != nil { return r.handleError("list", err) } return nil diff --git a/backend/internal/models/validation.go b/backend/internal/models/validation.go index 8e54247..88052e7 100644 --- a/backend/internal/models/validation.go +++ b/backend/internal/models/validation.go @@ -5,6 +5,9 @@ import ( "regexp" ) +// Compile email regex once at package level for performance. +var emailRegex = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`) + var ( ErrInvalidEmail = errors.New("invalid email format") ErrEmptyUsername = errors.New("username cannot be empty") @@ -18,7 +21,6 @@ func (u *User) Validate() error { return ErrEmptyUsername } - emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`) if !emailRegex.MatchString(u.Email) { return ErrInvalidEmail } diff --git a/backend/pkg/utils/utils.go b/backend/pkg/utils/utils.go index 279228a..e98eb18 100644 --- a/backend/pkg/utils/utils.go +++ b/backend/pkg/utils/utils.go @@ -1,25 +1,24 @@ package utils import ( - "math/rand" - "time" + "crypto/rand" + "math/big" ) -// Utility function to check for errors and handle them appropriately -func CheckError(err error) { - if err != nil { - panic(err) // Handle error as needed - } +// CheckError returns the error for the caller to handle. +// Deprecated: Use explicit error handling instead of this wrapper. +func CheckError(err error) error { + return err } -// Function to generate a random string of a specified length +// GenerateRandomString generates a cryptographically random string of the specified length. func GenerateRandomString(length int) string { const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - var seededRand = rand.New(rand.NewSource(time.Now().UnixNano())) b := make([]byte, length) for i := range b { - b[i] = charset[seededRand.Intn(len(charset))] + n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(charset)))) + b[i] = charset[n.Int64()] } return string(b) } From 01eff7cf27905a15216c154d99030dcd94168de3 Mon Sep 17 00:00:00 2001 From: omattsson Date: Tue, 24 Feb 2026 16:29:32 +0100 Subject: [PATCH 02/18] fix: address PR review comments - GenerateRandomString: return (string, error), validate negative length - generateRequestID: handle rand.Read error with timestamp fallback - Logger test: remove t.Parallel() to avoid global slog mutation - main: use gin.New() instead of gin.Default() to avoid duplicate middleware - database_test: remove unused top-level ctx variable - RateLimiter.Stop: use sync.Once for idempotent close - Expand utils tests for new error return signature --- backend/api/main.go | 4 +- backend/internal/api/handlers/rate_limiter.go | 6 ++- backend/internal/api/middleware/middleware.go | 7 ++- .../api/middleware/middleware_test.go | 2 +- backend/internal/database/database_test.go | 2 - backend/pkg/utils/utils.go | 18 ++++++-- backend/pkg/utils/utils_test.go | 44 ++++++++++++++++--- 7 files changed, 68 insertions(+), 15 deletions(-) diff --git a/backend/api/main.go b/backend/api/main.go index 1e662bf..4076e61 100644 --- a/backend/api/main.go +++ b/backend/api/main.go @@ -56,8 +56,8 @@ func main() { }) healthChecker.SetReady(true) - // Setup router - router := gin.Default() + // Setup router — use gin.New() since SetupRoutes registers its own Logger and Recovery middleware. + router := gin.New() routes.SetupRoutes(router, repo, healthChecker, cfg) router.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler)) diff --git a/backend/internal/api/handlers/rate_limiter.go b/backend/internal/api/handlers/rate_limiter.go index 25cb558..7728e4f 100644 --- a/backend/internal/api/handlers/rate_limiter.go +++ b/backend/internal/api/handlers/rate_limiter.go @@ -15,6 +15,7 @@ type RateLimiter struct { requests map[string][]time.Time // size: 8 (pointer) done chan struct{} // size: 8 limit int // size: 4 + stopOnce sync.Once // ensures Stop is idempotent } func NewRateLimiter(limit int, window time.Duration) *RateLimiter { @@ -29,8 +30,11 @@ func NewRateLimiter(limit int, window time.Duration) *RateLimiter { } // Stop terminates the background cleanup goroutine. +// It is safe to call Stop multiple times. func (rl *RateLimiter) Stop() { - close(rl.done) + rl.stopOnce.Do(func() { + close(rl.done) + }) } func (rl *RateLimiter) RateLimit() gin.HandlerFunc { diff --git a/backend/internal/api/middleware/middleware.go b/backend/internal/api/middleware/middleware.go index f509ac5..8a9857a 100644 --- a/backend/internal/api/middleware/middleware.go +++ b/backend/internal/api/middleware/middleware.go @@ -6,6 +6,7 @@ import ( "log/slog" "net/http" "strings" + "time" "github.com/gin-gonic/gin" ) @@ -88,6 +89,10 @@ func MaxBodySize(maxBytes int64) gin.HandlerFunc { func generateRequestID() string { b := make([]byte, 16) - _, _ = rand.Read(b) + if _, err := rand.Read(b); err != nil { + // Fallback: use timestamp-based ID if crypto/rand fails + slog.Warn("failed to generate random request ID, using fallback", "error", err) + return fmt.Sprintf("fallback-%d", time.Now().UnixNano()) + } return fmt.Sprintf("%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:]) } diff --git a/backend/internal/api/middleware/middleware_test.go b/backend/internal/api/middleware/middleware_test.go index 3584200..00609d5 100644 --- a/backend/internal/api/middleware/middleware_test.go +++ b/backend/internal/api/middleware/middleware_test.go @@ -13,7 +13,7 @@ import ( ) func TestLoggerMiddleware(t *testing.T) { - t.Parallel() + // Not parallel: this test mutates the global slog default logger. // Set Gin to Test Mode gin.SetMode(gin.TestMode) diff --git a/backend/internal/database/database_test.go b/backend/internal/database/database_test.go index 4e6e3ae..543a33a 100644 --- a/backend/internal/database/database_test.go +++ b/backend/internal/database/database_test.go @@ -111,7 +111,6 @@ func TestItemCRUD(t *testing.T) { t.Parallel() db := setupTestDB(t) require.NoError(t, db.AutoMigrate()) - ctx := context.Background() t.Run("Create Item", func(t *testing.T) { t.Parallel() @@ -200,5 +199,4 @@ func TestItemCRUD(t *testing.T) { assert.Error(t, err, "Should not find deleted item") }) - _ = ctx // suppress unused variable (used only in subtests) } diff --git a/backend/pkg/utils/utils.go b/backend/pkg/utils/utils.go index e98eb18..45e2d67 100644 --- a/backend/pkg/utils/utils.go +++ b/backend/pkg/utils/utils.go @@ -2,6 +2,7 @@ package utils import ( "crypto/rand" + "fmt" "math/big" ) @@ -12,13 +13,24 @@ func CheckError(err error) error { } // GenerateRandomString generates a cryptographically random string of the specified length. -func GenerateRandomString(length int) string { +// Returns an error if length is negative or if the system random source fails. +func GenerateRandomString(length int) (string, error) { + if length < 0 { + return "", fmt.Errorf("length must be non-negative, got %d", length) + } + if length == 0 { + return "", nil + } + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" b := make([]byte, length) for i := range b { - n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(charset)))) + n, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset)))) + if err != nil { + return "", fmt.Errorf("failed to generate random byte at index %d: %w", i, err) + } b[i] = charset[n.Int64()] } - return string(b) + return string(b), nil } diff --git a/backend/pkg/utils/utils_test.go b/backend/pkg/utils/utils_test.go index aafc0b7..f2db26b 100644 --- a/backend/pkg/utils/utils_test.go +++ b/backend/pkg/utils/utils_test.go @@ -4,13 +4,47 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUtilFunctions(t *testing.T) { - // Add your utility function tests here - t.Run("Test utility functions", func(t *testing.T) { - // Example test case - replace with actual utility function tests - result := true - assert.True(t, result, "Utility function should return true") + t.Parallel() + + t.Run("GenerateRandomString returns correct length", func(t *testing.T) { + t.Parallel() + s, err := GenerateRandomString(16) + require.NoError(t, err) + assert.Len(t, s, 16) + }) + + t.Run("GenerateRandomString with zero length", func(t *testing.T) { + t.Parallel() + s, err := GenerateRandomString(0) + require.NoError(t, err) + assert.Equal(t, "", s) + }) + + t.Run("GenerateRandomString with negative length", func(t *testing.T) { + t.Parallel() + _, err := GenerateRandomString(-1) + assert.Error(t, err) + assert.Contains(t, err.Error(), "non-negative") + }) + + t.Run("GenerateRandomString contains valid characters", func(t *testing.T) { + t.Parallel() + s, err := GenerateRandomString(100) + require.NoError(t, err) + for _, c := range s { + assert.True(t, + (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9'), + "unexpected character: %c", c) + } + }) + + t.Run("CheckError returns the error", func(t *testing.T) { + t.Parallel() + assert.Nil(t, CheckError(nil)) + assert.Error(t, CheckError(assert.AnError)) }) } From 18dd33efd8a919b95f0c7f6644f12371ee762426 Mon Sep 17 00:00:00 2001 From: omattsson Date: Tue, 24 Feb 2026 16:45:03 +0100 Subject: [PATCH 03/18] fix: address second round of PR review comments - Azure Table Storage: add optimistic locking via Version field comparison and ETag-based conditional updates (412 Precondition Failed detection) - Azure Table: store/read Version in Create and FindByID for consistency - CORS tests: add coverage for origin whitelist, disallowed origins, multiple comma-separated origins, and empty-string fallback - RequestID middleware tests: verify generation, reuse, response header, context storage, and uniqueness - MaxBodySize middleware tests: verify within-limit and over-limit requests - Remove redundant cancel() call in main.go (defer handles it) - Filter.Field SQL injection prevention: add allowedFilterFields whitelist that rejects unknown field names before interpolation --- backend/api/main.go | 1 - backend/internal/api/middleware/cors_test.go | 85 ++++++++++--- .../api/middleware/middleware_test.go | 116 ++++++++++++++++++ backend/internal/database/azure/table.go | 58 ++++++++- backend/internal/database/azure/table_test.go | 38 +++++- backend/internal/models/models.go | 11 ++ 6 files changed, 287 insertions(+), 22 deletions(-) diff --git a/backend/api/main.go b/backend/api/main.go index 4076e61..cdf0be4 100644 --- a/backend/api/main.go +++ b/backend/api/main.go @@ -91,7 +91,6 @@ func main() { defer cancel() err = srv.Shutdown(ctx) - cancel() if err != nil { slog.Error("Server forced to shutdown", "error", err) diff --git a/backend/internal/api/middleware/cors_test.go b/backend/internal/api/middleware/cors_test.go index 1bfd741..5240283 100644 --- a/backend/internal/api/middleware/cors_test.go +++ b/backend/internal/api/middleware/cors_test.go @@ -11,39 +11,96 @@ import ( func TestCORSMiddleware(t *testing.T) { t.Parallel() - // Set Gin to Test Mode gin.SetMode(gin.TestMode) - // Setup router with middleware - r := gin.New() - r.Use(CORS("*")) - r.Any("/test", func(c *gin.Context) { - c.Status(http.StatusOK) - }) - - t.Run("Regular GET request", func(t *testing.T) { + t.Run("Wildcard allows all origins", func(t *testing.T) { t.Parallel() + r := gin.New() + r.Use(CORS("*")) + r.Any("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + w := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/test", nil) r.ServeHTTP(w, req) - // Check CORS headers assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin")) assert.Equal(t, "GET, POST, PUT, DELETE, OPTIONS", w.Header().Get("Access-Control-Allow-Methods")) assert.Equal(t, "Content-Type, Content-Length, Accept-Encoding, Authorization, X-Request-ID", w.Header().Get("Access-Control-Allow-Headers")) assert.Equal(t, http.StatusOK, w.Code) }) - t.Run("OPTIONS preflight request", func(t *testing.T) { + t.Run("Empty string allows all origins", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(CORS("")) + r.Any("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + r.ServeHTTP(w, req) + + assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin")) + }) + + t.Run("Allowed origin is set from whitelist", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(CORS("https://example.com,https://other.com")) + r.Any("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set("Origin", "https://example.com") + r.ServeHTTP(w, req) + + assert.Equal(t, "https://example.com", w.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "Origin", w.Header().Get("Vary")) + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("Second allowed origin is matched", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(CORS("https://example.com, https://other.com")) + r.Any("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set("Origin", "https://other.com") + r.ServeHTTP(w, req) + + assert.Equal(t, "https://other.com", w.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "Origin", w.Header().Get("Vary")) + }) + + t.Run("Disallowed origin gets no Access-Control-Allow-Origin", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(CORS("https://example.com")) + r.Any("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set("Origin", "https://evil.com") + r.ServeHTTP(w, req) + + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + assert.Empty(t, w.Header().Get("Vary")) + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("OPTIONS preflight returns 204", func(t *testing.T) { t.Parallel() + r := gin.New() + r.Use(CORS("*")) + r.Any("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + w := httptest.NewRecorder() req, _ := http.NewRequest("OPTIONS", "/test", nil) r.ServeHTTP(w, req) - // Check CORS headers assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin")) assert.Equal(t, "GET, POST, PUT, DELETE, OPTIONS", w.Header().Get("Access-Control-Allow-Methods")) - assert.Equal(t, "Content-Type, Content-Length, Accept-Encoding, Authorization, X-Request-ID", w.Header().Get("Access-Control-Allow-Headers")) - assert.Equal(t, http.StatusNoContent, w.Code) // OPTIONS request should return 204 No Content + assert.Equal(t, http.StatusNoContent, w.Code) }) } diff --git a/backend/internal/api/middleware/middleware_test.go b/backend/internal/api/middleware/middleware_test.go index 00609d5..738f008 100644 --- a/backend/internal/api/middleware/middleware_test.go +++ b/backend/internal/api/middleware/middleware_test.go @@ -6,6 +6,7 @@ import ( "log/slog" "net/http" "net/http/httptest" + "strings" "testing" "github.com/gin-gonic/gin" @@ -75,3 +76,118 @@ func TestRecoveryMiddleware(t *testing.T) { assert.Nil(t, err) assert.Equal(t, "Internal Server Error", response["error"]) } + +func TestRequestIDMiddleware(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + t.Run("Generates new request ID when none provided", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(RequestID()) + var capturedID string + r.GET("/test", func(c *gin.Context) { + if v, ok := c.Get("request_id"); ok { + capturedID = v.(string) + } + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.NotEmpty(t, w.Header().Get("X-Request-ID")) + assert.NotEmpty(t, capturedID) + assert.Equal(t, w.Header().Get("X-Request-ID"), capturedID) + }) + + t.Run("Reuses client-provided X-Request-ID", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(RequestID()) + var capturedID string + r.GET("/test", func(c *gin.Context) { + if v, ok := c.Get("request_id"); ok { + capturedID = v.(string) + } + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set("X-Request-ID", "client-id-123") + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "client-id-123", w.Header().Get("X-Request-ID")) + assert.Equal(t, "client-id-123", capturedID) + }) + + t.Run("Generated IDs are unique", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(RequestID()) + r.GET("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + + ids := make(map[string]bool) + for i := 0; i < 10; i++ { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + r.ServeHTTP(w, req) + id := w.Header().Get("X-Request-ID") + assert.False(t, ids[id], "duplicate request ID generated") + ids[id] = true + } + }) +} + +func TestMaxBodySizeMiddleware(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + t.Run("Allows request within size limit", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(MaxBodySize(1024)) // 1 KB + r.POST("/test", func(c *gin.Context) { + body := make([]byte, 512) + _, err := c.Request.Body.Read(body) + if err != nil && err.Error() != "EOF" { + c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "body too large"}) + return + } + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + body := strings.NewReader(strings.Repeat("a", 100)) + req, _ := http.NewRequest("POST", "/test", body) + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("Rejects request exceeding size limit", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(MaxBodySize(64)) // 64 bytes + r.POST("/test", func(c *gin.Context) { + body := make([]byte, 128) + _, err := c.Request.Body.Read(body) + if err != nil { + c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "body too large"}) + return + } + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + body := strings.NewReader(strings.Repeat("a", 128)) + req, _ := http.NewRequest("POST", "/test", body) + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusRequestEntityTooLarge, w.Code) + }) +} diff --git a/backend/internal/database/azure/table.go b/backend/internal/database/azure/table.go index b8202eb..34ec773 100644 --- a/backend/internal/database/azure/table.go +++ b/backend/internal/database/azure/table.go @@ -104,6 +104,7 @@ func (r *TableRepository) Create(ctx context.Context, entity interface{}) error "RowKey": item.Name, // Using Name as the unique key for testing "Name": item.Name, "Price": item.Price, + "Version": item.Version, "CreatedAt": now.Format(time.RFC3339), "UpdatedAt": now.Format(time.RFC3339), } @@ -154,6 +155,11 @@ func (r *TableRepository) FindByID(ctx context.Context, id uint, dest interface{ item.ID = id item.Name = entityData["Name"].(string) item.Price = entityData["Price"].(float64) + if v, ok := entityData["Version"]; ok { + if vf, ok := v.(float64); ok { + item.Version = uint(vf) + } + } createdAt, err := time.Parse(time.RFC3339, entityData["CreatedAt"].(string)) if err != nil { @@ -170,15 +176,18 @@ func (r *TableRepository) FindByID(ctx context.Context, id uint, dest interface{ return nil } -// Update implements the Repository interface +// Update implements the Repository interface. +// For models that implement Versionable (e.g. Item), optimistic locking is enforced: +// the entity is fetched first to compare versions, and the ETag from the GET response +// is passed to UpdateEntity so Azure Table Storage rejects stale writes. func (r *TableRepository) Update(ctx context.Context, entity interface{}) error { item, ok := entity.(*models.Item) if !ok { return dberrors.NewDatabaseError("type_assertion", fmt.Errorf("entity must be *models.Item")) } - // Check if entity exists first - _, err := r.client.GetEntity(ctx, "items", strconv.FormatUint(uint64(item.ID), 10), nil) + // Fetch existing entity (also validates existence) + existing, err := r.client.GetEntity(ctx, "items", strconv.FormatUint(uint64(item.ID), 10), nil) if err != nil { var respErr *azcore.ResponseError if errors.As(err, &respErr) && respErr.StatusCode == 404 { @@ -187,6 +196,33 @@ func (r *TableRepository) Update(ctx context.Context, entity interface{}) error return dberrors.NewDatabaseError("find", err) } + // Optimistic locking: compare version if the entity is Versionable + if ver, ok := entity.(models.Versionable); ok { + currentVersion := ver.GetVersion() + + // Parse existing entity to check stored version + var existingData map[string]interface{} + if err := json.Unmarshal(existing.Value, &existingData); err != nil { + return dberrors.NewDatabaseError("unmarshal", err) + } + if storedVersion, ok := existingData["Version"]; ok { + var sv uint + switch v := storedVersion.(type) { + case float64: + sv = uint(v) + case json.Number: + n, _ := v.Int64() + sv = uint(n) + } + if currentVersion != sv { + return dberrors.NewDatabaseError("update", errors.New("version mismatch")) + } + } + + // Increment version for the update + ver.SetVersion(currentVersion + 1) + } + // Create Azure Table entity now := time.Now().UTC() entityJson := map[string]interface{}{ @@ -194,6 +230,7 @@ func (r *TableRepository) Update(ctx context.Context, entity interface{}) error "RowKey": strconv.FormatUint(uint64(item.ID), 10), "Name": item.Name, "Price": item.Price, + "Version": item.Version, "CreatedAt": item.CreatedAt.Format(time.RFC3339), "UpdatedAt": now.Format(time.RFC3339), } @@ -203,8 +240,21 @@ func (r *TableRepository) Update(ctx context.Context, entity interface{}) error return dberrors.NewDatabaseError("marshal", err) } - _, err = r.client.UpdateEntity(ctx, entityBytes, nil) + // Use ETag from the GET response for conditional update + updateOpts := &aztables.UpdateEntityOptions{ + IfMatch: &existing.ETag, + UpdateMode: aztables.UpdateModeMerge, + } + _, err = r.client.UpdateEntity(ctx, entityBytes, updateOpts) if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == 412 { + // Precondition failed — concurrent modification + if ver, ok := entity.(models.Versionable); ok { + ver.SetVersion(ver.GetVersion() - 1) // Roll back + } + return dberrors.NewDatabaseError("update", errors.New("version mismatch")) + } return dberrors.NewDatabaseError("update", err) } diff --git a/backend/internal/database/azure/table_test.go b/backend/internal/database/azure/table_test.go index f712200..55be7e5 100644 --- a/backend/internal/database/azure/table_test.go +++ b/backend/internal/database/azure/table_test.go @@ -135,10 +135,14 @@ func TestTableClientOperations(t *testing.T) { mockClient := &mockClient{ getEntity: func(ctx context.Context, partitionKey, rowKey string, options *aztables.GetEntityOptions) (aztables.GetEntityResponse, error) { return aztables.GetEntityResponse{ - Value: []byte(`{"Name":"test","Price":10.5}`), + ETag: "etag-1", + Value: []byte(`{"Name":"test","Price":10.5,"Version":1}`), }, nil }, updateEntity: func(ctx context.Context, entity []byte, options *aztables.UpdateEntityOptions) (aztables.UpdateEntityResponse, error) { + // Verify ETag is passed for conditional update + assert.NotNil(t, options) + assert.NotNil(t, options.IfMatch) return aztables.UpdateEntityResponse{}, nil }, } @@ -147,11 +151,39 @@ func TestTableClientOperations(t *testing.T) { repo.SetTestClient(mockClient) item := &models.Item{ - Name: "test", - Price: 10.5, + Name: "test", + Price: 10.5, + Version: 1, } err := repo.Update(context.Background(), item) assert.NoError(t, err) + assert.Equal(t, uint(2), item.Version, "Version should be incremented after update") + }) + + t.Run("update fails on version mismatch", func(t *testing.T) { + t.Parallel() + + mockClient := &mockClient{ + getEntity: func(ctx context.Context, partitionKey, rowKey string, options *aztables.GetEntityOptions) (aztables.GetEntityResponse, error) { + return aztables.GetEntityResponse{ + ETag: "etag-1", + Value: []byte(`{"Name":"test","Price":10.5,"Version":2}`), + }, nil + }, + } + + repo := azure.NewTestTableRepository("testtable") + repo.SetTestClient(mockClient) + + item := &models.Item{ + Name: "test", + Price: 15.0, + Version: 1, // Stale version + } + err := repo.Update(context.Background(), item) + assert.Error(t, err) + assert.Contains(t, err.Error(), "version mismatch") + assert.Equal(t, uint(1), item.Version, "Version should not change on mismatch") }) t.Run("can delete an entity", func(t *testing.T) { diff --git a/backend/internal/models/models.go b/backend/internal/models/models.go index 395e58d..e435f91 100644 --- a/backend/internal/models/models.go +++ b/backend/internal/models/models.go @@ -138,11 +138,22 @@ func (r *GenericRepository) Delete(ctx context.Context, entity interface{}) erro return nil } +// allowedFilterFields is a whitelist of column names that may be used in filter queries. +// This prevents SQL injection via the Filter.Field parameter. +var allowedFilterFields = map[string]bool{ + "name": true, + "price": true, +} + func (r *GenericRepository) List(ctx context.Context, dest interface{}, conditions ...interface{}) error { query := r.db.WithContext(ctx) for _, cond := range conditions { switch c := cond.(type) { case Filter: + if !allowedFilterFields[c.Field] { + return dberrors.NewDatabaseError("list", + fmt.Errorf("invalid filter field: %q", c.Field)) + } switch c.Op { case "exact": query = query.Where(fmt.Sprintf("%s = ?", c.Field), c.Value) From 5f9cd92efc8e8e84276500dc2ab588b1147bee39 Mon Sep 17 00:00:00 2001 From: omattsson Date: Tue, 24 Feb 2026 18:17:53 +0100 Subject: [PATCH 04/18] fix: address third round of PR review comments - LIKE value: use fmt.Sprint for proper parameterized escaping - Remove deprecated CheckError function and its test (no callers) - Azure Table Create: initialize Version to 1 when not set, consistent with GORM default and MockRepository --- backend/internal/database/azure/table.go | 5 +++++ backend/internal/database/azure/table_test.go | 4 ++-- backend/internal/models/models.go | 2 +- backend/pkg/utils/utils.go | 6 ------ backend/pkg/utils/utils_test.go | 6 ------ 5 files changed, 8 insertions(+), 15 deletions(-) diff --git a/backend/internal/database/azure/table.go b/backend/internal/database/azure/table.go index 34ec773..5abd840 100644 --- a/backend/internal/database/azure/table.go +++ b/backend/internal/database/azure/table.go @@ -96,6 +96,11 @@ func (r *TableRepository) Create(ctx context.Context, entity interface{}) error return dberrors.NewDatabaseError("type_assertion", errors.New("entity must be *models.Item")) } + // Initialize version for new entities (consistent with GORM default:1 and MockRepository) + if item.Version == 0 { + item.Version = 1 + } + // Create Azure Table entity now := time.Now().UTC() diff --git a/backend/internal/database/azure/table_test.go b/backend/internal/database/azure/table_test.go index 55be7e5..a9cb257 100644 --- a/backend/internal/database/azure/table_test.go +++ b/backend/internal/database/azure/table_test.go @@ -135,7 +135,7 @@ func TestTableClientOperations(t *testing.T) { mockClient := &mockClient{ getEntity: func(ctx context.Context, partitionKey, rowKey string, options *aztables.GetEntityOptions) (aztables.GetEntityResponse, error) { return aztables.GetEntityResponse{ - ETag: "etag-1", + ETag: "etag-1", Value: []byte(`{"Name":"test","Price":10.5,"Version":1}`), }, nil }, @@ -166,7 +166,7 @@ func TestTableClientOperations(t *testing.T) { mockClient := &mockClient{ getEntity: func(ctx context.Context, partitionKey, rowKey string, options *aztables.GetEntityOptions) (aztables.GetEntityResponse, error) { return aztables.GetEntityResponse{ - ETag: "etag-1", + ETag: "etag-1", Value: []byte(`{"Name":"test","Price":10.5,"Version":2}`), }, nil }, diff --git a/backend/internal/models/models.go b/backend/internal/models/models.go index e435f91..9e826d0 100644 --- a/backend/internal/models/models.go +++ b/backend/internal/models/models.go @@ -163,7 +163,7 @@ func (r *GenericRepository) List(ctx context.Context, dest interface{}, conditio query = query.Where(fmt.Sprintf("%s <= ?", c.Field), c.Value) default: // Default to LIKE for substring matching - query = query.Where(fmt.Sprintf("%s LIKE ?", c.Field), fmt.Sprintf("%%%v%%", c.Value)) + query = query.Where(fmt.Sprintf("%s LIKE ?", c.Field), "%"+fmt.Sprint(c.Value)+"%") } case Pagination: if c.Limit > 0 { diff --git a/backend/pkg/utils/utils.go b/backend/pkg/utils/utils.go index 45e2d67..db6c331 100644 --- a/backend/pkg/utils/utils.go +++ b/backend/pkg/utils/utils.go @@ -6,12 +6,6 @@ import ( "math/big" ) -// CheckError returns the error for the caller to handle. -// Deprecated: Use explicit error handling instead of this wrapper. -func CheckError(err error) error { - return err -} - // GenerateRandomString generates a cryptographically random string of the specified length. // Returns an error if length is negative or if the system random source fails. func GenerateRandomString(length int) (string, error) { diff --git a/backend/pkg/utils/utils_test.go b/backend/pkg/utils/utils_test.go index f2db26b..917cddf 100644 --- a/backend/pkg/utils/utils_test.go +++ b/backend/pkg/utils/utils_test.go @@ -41,10 +41,4 @@ func TestUtilFunctions(t *testing.T) { "unexpected character: %c", c) } }) - - t.Run("CheckError returns the error", func(t *testing.T) { - t.Parallel() - assert.Nil(t, CheckError(nil)) - assert.Error(t, CheckError(assert.AnError)) - }) } From 77c3312bd792f8994258e35c9000dfda424a5955 Mon Sep 17 00:00:00 2001 From: omattsson Date: Tue, 24 Feb 2026 18:29:58 +0100 Subject: [PATCH 05/18] fix: address fourth round of PR review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Wire rate limiter into API routes and stop during graceful shutdown - Add migration to update Version=0→1 for existing items (dialect-aware) - Restore AutoMigrate() call in MySQL repository factory - Remove FindByID pre-check from DeleteItem to avoid race condition --- backend/api/main.go | 3 ++- backend/internal/api/handlers/items.go | 13 ++++-------- backend/internal/api/routes/routes.go | 10 ++++++++- backend/internal/api/routes/routes_test.go | 3 ++- backend/internal/database/migrations.go | 24 ++++++++++++++++++++++ backend/internal/database/repository.go | 5 +++++ 6 files changed, 46 insertions(+), 12 deletions(-) diff --git a/backend/api/main.go b/backend/api/main.go index cdf0be4..62b3e85 100644 --- a/backend/api/main.go +++ b/backend/api/main.go @@ -58,7 +58,8 @@ func main() { // Setup router — use gin.New() since SetupRoutes registers its own Logger and Recovery middleware. router := gin.New() - routes.SetupRoutes(router, repo, healthChecker, cfg) + rateLimiter := routes.SetupRoutes(router, repo, healthChecker, cfg) + defer rateLimiter.Stop() router.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler)) // Create server with timeouts diff --git a/backend/internal/api/handlers/items.go b/backend/internal/api/handlers/items.go index 51dda28..8dec361 100644 --- a/backend/internal/api/handlers/items.go +++ b/backend/internal/api/handlers/items.go @@ -233,15 +233,10 @@ func (h *Handler) DeleteItem(c *gin.Context) { return } - // Verify item exists before attempting delete - var item models.Item - if err := h.repository.FindByID(c.Request.Context(), uint(id), &item); err != nil { - status, message := handleDBError(err) - c.JSON(status, gin.H{"error": message}) - return - } - - if err := h.repository.Delete(c.Request.Context(), &item); err != nil { + // Delete directly — the repository returns ErrNotFound if the item doesn't exist. + // This avoids a race condition between a FindByID check and the actual delete. + item := &models.Item{Base: models.Base{ID: uint(id)}} + if err := h.repository.Delete(c.Request.Context(), item); err != nil { status, message := handleDBError(err) c.JSON(status, gin.H{"error": message}) return diff --git a/backend/internal/api/routes/routes.go b/backend/internal/api/routes/routes.go index 4aa9788..692927f 100644 --- a/backend/internal/api/routes/routes.go +++ b/backend/internal/api/routes/routes.go @@ -6,13 +6,15 @@ import ( "backend/internal/config" "backend/internal/health" "backend/internal/models" + "time" "github.com/gin-gonic/gin" ) // SetupRoutes configures all the routes for our application. // healthChecker is injected from main so the readiness endpoint reflects real dependency health. -func SetupRoutes(router *gin.Engine, repository models.Repository, healthChecker *health.HealthChecker, cfg *config.Config) { +// Returns the rate limiter so the caller can stop it during shutdown. +func SetupRoutes(router *gin.Engine, repository models.Repository, healthChecker *health.HealthChecker, cfg *config.Config) *handlers.RateLimiter { // Add middleware router.Use(middleware.RequestID()) router.Use(middleware.Logger()) @@ -28,8 +30,12 @@ func SetupRoutes(router *gin.Engine, repository models.Repository, healthChecker healthGroup.GET("", handlers.HealthCheck) // Keep the original health check for backward compatibility } + // Rate limiter for API routes + rateLimiter := handlers.NewRateLimiter(100, time.Minute) + // API v1 routes v1 := router.Group("/api/v1") + v1.Use(rateLimiter.RateLimit()) { // Ping endpoint v1.GET("/ping", handlers.Ping) @@ -45,4 +51,6 @@ func SetupRoutes(router *gin.Engine, repository models.Repository, healthChecker items.DELETE("/:id", itemsHandler.DeleteItem) } } + + return rateLimiter } diff --git a/backend/internal/api/routes/routes_test.go b/backend/internal/api/routes/routes_test.go index ebfe65e..affb6dd 100644 --- a/backend/internal/api/routes/routes_test.go +++ b/backend/internal/api/routes/routes_test.go @@ -34,7 +34,8 @@ func TestSetupRoutes(t *testing.T) { } // Setup routes - SetupRoutes(router, mockRepo, healthChecker, cfg) + rl := SetupRoutes(router, mockRepo, healthChecker, cfg) + defer rl.Stop() // Test cases tests := []struct { diff --git a/backend/internal/database/migrations.go b/backend/internal/database/migrations.go index 64b3e6b..ccba18e 100644 --- a/backend/internal/database/migrations.go +++ b/backend/internal/database/migrations.go @@ -61,6 +61,30 @@ func (d *Database) AutoMigrate() error { }, }) + // Update existing items with Version=0 to Version=1 and change column default + migrator.AddMigration(schema.Migration{ + Version: "20231201000003", + Name: "update_items_version_default", + Description: "Set Version default to 1 for optimistic locking and update existing rows", + Up: func(tx *gorm.DB) error { + // Update existing rows that have the old default of 0 + if err := tx.Exec("UPDATE items SET version = 1 WHERE version = 0").Error; err != nil { + return err + } + // Alter column default to 1 (MySQL syntax; SQLite defaults are set via AutoMigrate) + if tx.Dialector.Name() == "mysql" { + return tx.Exec("ALTER TABLE items ALTER COLUMN version SET DEFAULT 1").Error + } + return nil + }, + Down: func(tx *gorm.DB) error { + if tx.Dialector.Name() == "mysql" { + return tx.Exec("ALTER TABLE items ALTER COLUMN version SET DEFAULT 0").Error + } + return nil + }, + }) + // Run migrations if err := migrator.MigrateUp(); err != nil { return err diff --git a/backend/internal/database/repository.go b/backend/internal/database/repository.go index e74bd2f..d181842 100644 --- a/backend/internal/database/repository.go +++ b/backend/internal/database/repository.go @@ -28,5 +28,10 @@ func NewRepository(cfg *config.Config) (models.Repository, error) { return nil, fmt.Errorf("failed to initialize MySQL database: %w", err) } + // Run database migrations + if err := db.AutoMigrate(); err != nil { + return nil, fmt.Errorf("failed to run database migrations: %w", err) + } + return models.NewRepository(db.DB), nil } From 7c9df5bd28d848224e39cc5792fc64008648f905 Mon Sep 17 00:00:00 2001 From: omattsson Date: Tue, 24 Feb 2026 18:35:09 +0100 Subject: [PATCH 06/18] fix: address fifth round of PR review comments - Use comma-ok idiom for all Azure Table type assertions to prevent panics - Only set CORS Allow-Methods/Allow-Headers when origin is allowed; early return for disallowed OPTIONS preflight - Add ID != 0 validation in Azure Table Update and Delete methods - Remove redundant handler-level version check; pass client version to repository for optimistic locking - Make allowedFilterFields configurable per GenericRepository instance via NewRepositoryWithFilterFields constructor --- backend/internal/api/handlers/items.go | 14 +++--- backend/internal/api/middleware/cors_test.go | 18 +++++++ backend/internal/api/middleware/middleware.go | 10 ++++ backend/internal/database/azure/table.go | 50 +++++++++++++++---- backend/internal/database/azure/table_test.go | 3 ++ backend/internal/models/models.go | 31 ++++++++---- 6 files changed, 99 insertions(+), 27 deletions(-) diff --git a/backend/internal/api/handlers/items.go b/backend/internal/api/handlers/items.go index 8dec361..5a9e21f 100644 --- a/backend/internal/api/handlers/items.go +++ b/backend/internal/api/handlers/items.go @@ -194,16 +194,16 @@ func (h *Handler) UpdateItem(c *gin.Context) { return } - // Version check for optimistic locking (if version was provided in request) - if updateItem.Version > 0 && updateItem.Version != currentItem.Version { - c.JSON(http.StatusConflict, gin.H{"error": "Item has been modified by another request"}) - return - } - - // Update fields from request but keep DB version for repository-level optimistic locking + // Update fields from request currentItem.Name = updateItem.Name currentItem.Price = updateItem.Price + // If the client provided a version, use it for optimistic locking; + // otherwise keep the DB version so the repository check passes. + if updateItem.Version > 0 { + currentItem.Version = updateItem.Version + } + if err := h.repository.Update(c.Request.Context(), ¤tItem); err != nil { if strings.Contains(err.Error(), "version mismatch") { c.JSON(http.StatusConflict, gin.H{"error": "Item has been modified by another request"}) diff --git a/backend/internal/api/middleware/cors_test.go b/backend/internal/api/middleware/cors_test.go index 5240283..ee56a87 100644 --- a/backend/internal/api/middleware/cors_test.go +++ b/backend/internal/api/middleware/cors_test.go @@ -86,9 +86,27 @@ func TestCORSMiddleware(t *testing.T) { assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) assert.Empty(t, w.Header().Get("Vary")) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods")) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers")) assert.Equal(t, http.StatusOK, w.Code) }) + t.Run("Disallowed origin OPTIONS preflight returns 204 without CORS headers", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(CORS("https://example.com")) + r.Any("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("OPTIONS", "/test", nil) + req.Header.Set("Origin", "https://evil.com") + r.ServeHTTP(w, req) + + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods")) + assert.Equal(t, http.StatusNoContent, w.Code) + }) + t.Run("OPTIONS preflight returns 204", func(t *testing.T) { t.Parallel() r := gin.New() diff --git a/backend/internal/api/middleware/middleware.go b/backend/internal/api/middleware/middleware.go index 8a9857a..7a650d6 100644 --- a/backend/internal/api/middleware/middleware.go +++ b/backend/internal/api/middleware/middleware.go @@ -20,13 +20,23 @@ func CORS(allowedOrigins string) gin.HandlerFunc { c.Writer.Header().Set("Access-Control-Allow-Origin", "*") } else { requestOrigin := c.Request.Header.Get("Origin") + allowed := false for _, origin := range strings.Split(allowedOrigins, ",") { if strings.TrimSpace(origin) == requestOrigin { c.Writer.Header().Set("Access-Control-Allow-Origin", requestOrigin) c.Writer.Header().Set("Vary", "Origin") + allowed = true break } } + if !allowed { + if c.Request.Method == "OPTIONS" { + c.AbortWithStatus(http.StatusNoContent) + return + } + c.Next() + return + } } c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, Authorization, X-Request-ID") diff --git a/backend/internal/database/azure/table.go b/backend/internal/database/azure/table.go index 5abd840..5131c37 100644 --- a/backend/internal/database/azure/table.go +++ b/backend/internal/database/azure/table.go @@ -158,21 +158,40 @@ func (r *TableRepository) FindByID(ctx context.Context, id uint, dest interface{ // Map entity to item item.ID = id - item.Name = entityData["Name"].(string) - item.Price = entityData["Price"].(float64) + + name, ok := entityData["Name"].(string) + if !ok { + return dberrors.NewDatabaseError("unmarshal", fmt.Errorf("missing or invalid Name field")) + } + item.Name = name + + price, ok := entityData["Price"].(float64) + if !ok { + return dberrors.NewDatabaseError("unmarshal", fmt.Errorf("missing or invalid Price field")) + } + item.Price = price + if v, ok := entityData["Version"]; ok { if vf, ok := v.(float64); ok { item.Version = uint(vf) } } - createdAt, err := time.Parse(time.RFC3339, entityData["CreatedAt"].(string)) + createdAtStr, ok := entityData["CreatedAt"].(string) + if !ok { + return dberrors.NewDatabaseError("unmarshal", fmt.Errorf("missing or invalid CreatedAt field")) + } + createdAt, err := time.Parse(time.RFC3339, createdAtStr) if err != nil { return dberrors.NewDatabaseError("parse_time", err) } item.CreatedAt = createdAt - updatedAt, err := time.Parse(time.RFC3339, entityData["UpdatedAt"].(string)) + updatedAtStr, ok := entityData["UpdatedAt"].(string) + if !ok { + return dberrors.NewDatabaseError("unmarshal", fmt.Errorf("missing or invalid UpdatedAt field")) + } + updatedAt, err := time.Parse(time.RFC3339, updatedAtStr) if err != nil { return dberrors.NewDatabaseError("parse_time", err) } @@ -191,6 +210,10 @@ func (r *TableRepository) Update(ctx context.Context, entity interface{}) error return dberrors.NewDatabaseError("type_assertion", fmt.Errorf("entity must be *models.Item")) } + if item.ID == 0 { + return dberrors.NewDatabaseError("update", dberrors.ErrValidation) + } + // Fetch existing entity (also validates existence) existing, err := r.client.GetEntity(ctx, "items", strconv.FormatUint(uint64(item.ID), 10), nil) if err != nil { @@ -274,6 +297,10 @@ func (r *TableRepository) Delete(ctx context.Context, entity interface{}) error return dberrors.NewDatabaseError("type_assertion", fmt.Errorf("entity must be *models.Item")) } + if item.ID == 0 { + return dberrors.NewDatabaseError("delete", dberrors.ErrValidation) + } + _, err := r.client.DeleteEntity(ctx, "items", strconv.FormatUint(uint64(item.ID), 10), nil) if err != nil { var respErr *azcore.ResponseError @@ -349,9 +376,14 @@ func (r *TableRepository) List(ctx context.Context, dest interface{}, conditions return dberrors.NewDatabaseError("unmarshal", err) } - id, _ := strconv.ParseUint(entityData["RowKey"].(string), 10, 32) - createdAt, _ := time.Parse(time.RFC3339, entityData["CreatedAt"].(string)) - updatedAt, _ := time.Parse(time.RFC3339, entityData["UpdatedAt"].(string)) + rowKey, _ := entityData["RowKey"].(string) + id, _ := strconv.ParseUint(rowKey, 10, 32) + createdAtStr, _ := entityData["CreatedAt"].(string) + createdAt, _ := time.Parse(time.RFC3339, createdAtStr) + updatedAtStr, _ := entityData["UpdatedAt"].(string) + updatedAt, _ := time.Parse(time.RFC3339, updatedAtStr) + name, _ := entityData["Name"].(string) + price, _ := entityData["Price"].(float64) item := models.Item{ Base: models.Base{ @@ -359,8 +391,8 @@ func (r *TableRepository) List(ctx context.Context, dest interface{}, conditions CreatedAt: createdAt, UpdatedAt: updatedAt, }, - Name: entityData["Name"].(string), - Price: entityData["Price"].(float64), + Name: name, + Price: price, } // Apply name contains filter if specified diff --git a/backend/internal/database/azure/table_test.go b/backend/internal/database/azure/table_test.go index a9cb257..be70de2 100644 --- a/backend/internal/database/azure/table_test.go +++ b/backend/internal/database/azure/table_test.go @@ -155,6 +155,7 @@ func TestTableClientOperations(t *testing.T) { Price: 10.5, Version: 1, } + item.ID = 1 err := repo.Update(context.Background(), item) assert.NoError(t, err) assert.Equal(t, uint(2), item.Version, "Version should be incremented after update") @@ -180,6 +181,7 @@ func TestTableClientOperations(t *testing.T) { Price: 15.0, Version: 1, // Stale version } + item.ID = 1 err := repo.Update(context.Background(), item) assert.Error(t, err) assert.Contains(t, err.Error(), "version mismatch") @@ -202,6 +204,7 @@ func TestTableClientOperations(t *testing.T) { Name: "test", Price: 10.5, } + item.ID = 1 err := repo.Delete(context.Background(), item) assert.NoError(t, err) }) diff --git a/backend/internal/models/models.go b/backend/internal/models/models.go index 9e826d0..f175b23 100644 --- a/backend/internal/models/models.go +++ b/backend/internal/models/models.go @@ -65,12 +65,28 @@ type Repository interface { // GenericRepository implements the Repository interface type GenericRepository struct { - db *gorm.DB + db *gorm.DB + allowedFilterFields map[string]bool } -// NewRepository creates a new GenericRepository +// NewRepository creates a new GenericRepository with default allowed filter fields. func NewRepository(db *gorm.DB) Repository { - return &GenericRepository{db: db} + return &GenericRepository{ + db: db, + allowedFilterFields: map[string]bool{ + "name": true, + "price": true, + }, + } +} + +// NewRepositoryWithFilterFields creates a GenericRepository with a custom filter field whitelist. +func NewRepositoryWithFilterFields(db *gorm.DB, fields []string) Repository { + allowed := make(map[string]bool, len(fields)) + for _, f := range fields { + allowed[f] = true + } + return &GenericRepository{db: db, allowedFilterFields: allowed} } // Ping checks if the database is reachable @@ -138,19 +154,12 @@ func (r *GenericRepository) Delete(ctx context.Context, entity interface{}) erro return nil } -// allowedFilterFields is a whitelist of column names that may be used in filter queries. -// This prevents SQL injection via the Filter.Field parameter. -var allowedFilterFields = map[string]bool{ - "name": true, - "price": true, -} - func (r *GenericRepository) List(ctx context.Context, dest interface{}, conditions ...interface{}) error { query := r.db.WithContext(ctx) for _, cond := range conditions { switch c := cond.(type) { case Filter: - if !allowedFilterFields[c.Field] { + if !r.allowedFilterFields[c.Field] { return dberrors.NewDatabaseError("list", fmt.Errorf("invalid filter field: %q", c.Field)) } From 9bf93f0f50073a7f631f1e5abf49c6f744ccf985 Mon Sep 17 00:00:00 2001 From: omattsson Date: Tue, 24 Feb 2026 19:35:05 +0100 Subject: [PATCH 07/18] fix: address sixth round of PR review comments - Use errors.Is(err, io.EOF) instead of string comparison in middleware test - Convert charset index to int for proper byte slice indexing in GenerateRandomString - Fix azure_test.go to use context-aware HealthCheck signature and CheckReadiness(ctx) - Validate all type assertions and parse results in Azure Table List method --- backend/api/azure_test.go | 5 ++- .../api/middleware/middleware_test.go | 4 +- backend/internal/database/azure/table.go | 44 +++++++++++++++---- backend/pkg/utils/utils.go | 2 +- 4 files changed, 43 insertions(+), 12 deletions(-) diff --git a/backend/api/azure_test.go b/backend/api/azure_test.go index c40312b..a43a4e8 100644 --- a/backend/api/azure_test.go +++ b/backend/api/azure_test.go @@ -3,6 +3,7 @@ package main import ( + "context" "os" "testing" @@ -68,13 +69,13 @@ func TestAzureTableIntegration(t *testing.T) { healthChecker := health.New() // For Azure Table, we add a mock health check since the repo is nil - healthChecker.AddCheck("database", func() error { + healthChecker.AddCheck("database", func(_ context.Context) error { return nil // Mock check since we expect repo to be nil in this test }) // The health check should return UP since we're using a mock check healthChecker.SetReady(true) - status := healthChecker.CheckReadiness() + status := healthChecker.CheckReadiness(context.Background()) assert.Equal(t, "UP", status.Status) }) diff --git a/backend/internal/api/middleware/middleware_test.go b/backend/internal/api/middleware/middleware_test.go index 738f008..bd228f4 100644 --- a/backend/internal/api/middleware/middleware_test.go +++ b/backend/internal/api/middleware/middleware_test.go @@ -3,6 +3,8 @@ package middleware import ( "bytes" "encoding/json" + "errors" + "io" "log/slog" "net/http" "net/http/httptest" @@ -154,7 +156,7 @@ func TestMaxBodySizeMiddleware(t *testing.T) { r.POST("/test", func(c *gin.Context) { body := make([]byte, 512) _, err := c.Request.Body.Read(body) - if err != nil && err.Error() != "EOF" { + if err != nil && !errors.Is(err, io.EOF) { c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "body too large"}) return } diff --git a/backend/internal/database/azure/table.go b/backend/internal/database/azure/table.go index 5131c37..db515b9 100644 --- a/backend/internal/database/azure/table.go +++ b/backend/internal/database/azure/table.go @@ -376,14 +376,42 @@ func (r *TableRepository) List(ctx context.Context, dest interface{}, conditions return dberrors.NewDatabaseError("unmarshal", err) } - rowKey, _ := entityData["RowKey"].(string) - id, _ := strconv.ParseUint(rowKey, 10, 32) - createdAtStr, _ := entityData["CreatedAt"].(string) - createdAt, _ := time.Parse(time.RFC3339, createdAtStr) - updatedAtStr, _ := entityData["UpdatedAt"].(string) - updatedAt, _ := time.Parse(time.RFC3339, updatedAtStr) - name, _ := entityData["Name"].(string) - price, _ := entityData["Price"].(float64) + rowKey, ok := entityData["RowKey"].(string) + if !ok || rowKey == "" { + return dberrors.NewDatabaseError("list", fmt.Errorf("entity missing or invalid RowKey")) + } + id, err := strconv.ParseUint(rowKey, 10, 32) + if err != nil { + return dberrors.NewDatabaseError("list", fmt.Errorf("invalid RowKey %q: %w", rowKey, err)) + } + + createdAtStr, ok := entityData["CreatedAt"].(string) + if !ok || createdAtStr == "" { + return dberrors.NewDatabaseError("list", fmt.Errorf("entity missing or invalid CreatedAt")) + } + createdAt, err := time.Parse(time.RFC3339, createdAtStr) + if err != nil { + return dberrors.NewDatabaseError("list", fmt.Errorf("invalid CreatedAt %q: %w", createdAtStr, err)) + } + + updatedAtStr, ok := entityData["UpdatedAt"].(string) + if !ok || updatedAtStr == "" { + return dberrors.NewDatabaseError("list", fmt.Errorf("entity missing or invalid UpdatedAt")) + } + updatedAt, err := time.Parse(time.RFC3339, updatedAtStr) + if err != nil { + return dberrors.NewDatabaseError("list", fmt.Errorf("invalid UpdatedAt %q: %w", updatedAtStr, err)) + } + + name, ok := entityData["Name"].(string) + if !ok { + return dberrors.NewDatabaseError("list", fmt.Errorf("entity missing or invalid Name")) + } + + price, ok := entityData["Price"].(float64) + if !ok { + return dberrors.NewDatabaseError("list", fmt.Errorf("entity missing or invalid Price")) + } item := models.Item{ Base: models.Base{ diff --git a/backend/pkg/utils/utils.go b/backend/pkg/utils/utils.go index db6c331..e1bdf53 100644 --- a/backend/pkg/utils/utils.go +++ b/backend/pkg/utils/utils.go @@ -24,7 +24,7 @@ func GenerateRandomString(length int) (string, error) { if err != nil { return "", fmt.Errorf("failed to generate random byte at index %d: %w", i, err) } - b[i] = charset[n.Int64()] + b[i] = charset[int(n.Int64())] } return string(b), nil } From 3743b8e47539efc1711fe29dc64c3941bf0989b9 Mon Sep 17 00:00:00 2001 From: omattsson Date: Tue, 24 Feb 2026 20:52:23 +0100 Subject: [PATCH 08/18] fix: make Azure Table RowKey consistent across all CRUD operations - Generate numeric ID via atomic counter (seeded from UnixMilli) in Create() - Use strconv.FormatUint(item.ID) as RowKey instead of item.Name - Assign generated ID back to item.ID so callers can use FindByID - Remove duplicate comment on TableRepository struct --- backend/internal/database/azure/table.go | 22 +++++++++++++++++-- backend/internal/database/azure/table_test.go | 3 ++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/backend/internal/database/azure/table.go b/backend/internal/database/azure/table.go index db515b9..8ec67d2 100644 --- a/backend/internal/database/azure/table.go +++ b/backend/internal/database/azure/table.go @@ -7,6 +7,7 @@ import ( "fmt" "strconv" "strings" + "sync/atomic" "time" "backend/internal/models" @@ -16,7 +17,19 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/data/aztables" ) -// TableRepository implements the Repository interface for Azure Table Storage +// idCounter is a monotonic counter used to generate unique IDs for Azure Table entities. +// Seeded from the current Unix timestamp in milliseconds so IDs are globally unique +// across restarts (within millisecond granularity). +var idCounter atomic.Uint64 + +func init() { + idCounter.Store(uint64(time.Now().UnixMilli())) +} + +func nextID() uint { + return uint(idCounter.Add(1)) +} + // TableRepository implements the Repository interface for Azure Table Storage type TableRepository struct { client AzureTableClient @@ -104,9 +117,14 @@ func (r *TableRepository) Create(ctx context.Context, entity interface{}) error // Create Azure Table entity now := time.Now().UTC() + // Generate a numeric ID (Azure Table Storage has no auto-increment) + if item.ID == 0 { + item.ID = nextID() + } + entityJSON := map[string]interface{}{ "PartitionKey": "items", - "RowKey": item.Name, // Using Name as the unique key for testing + "RowKey": strconv.FormatUint(uint64(item.ID), 10), "Name": item.Name, "Price": item.Price, "Version": item.Version, diff --git a/backend/internal/database/azure/table_test.go b/backend/internal/database/azure/table_test.go index be70de2..3d6bfb0 100644 --- a/backend/internal/database/azure/table_test.go +++ b/backend/internal/database/azure/table_test.go @@ -94,9 +94,10 @@ func TestTableClientOperations(t *testing.T) { } err := repo.Create(context.Background(), item) assert.NoError(t, err) + assert.NotZero(t, item.ID, "Create should assign an ID") var retrieved models.Item - err = repo.FindByID(context.Background(), 1, &retrieved) + err = repo.FindByID(context.Background(), item.ID, &retrieved) assert.NoError(t, err) assert.Equal(t, item.Name, retrieved.Name) assert.InDelta(t, item.Price, retrieved.Price, 0.001) From 09edfaa8a6f18d717ee940e341d86f898423646f Mon Sep 17 00:00:00 2001 From: omattsson Date: Tue, 24 Feb 2026 21:39:23 +0100 Subject: [PATCH 09/18] fix: address eighth round of PR review comments - Populate Version field in Azure Table List() for consistent optimistic locking - Replace process-local atomic counter with crypto/rand-based ID generator to prevent collisions in multi-instance deployments - Wrap Validator.Validate() errors with ErrValidation sentinel in both Create() and Update() so handlers map them to 400 instead of 500 --- backend/internal/database/azure/table.go | 37 +++++++++++++++--------- backend/internal/models/models.go | 6 ++-- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/backend/internal/database/azure/table.go b/backend/internal/database/azure/table.go index 8ec67d2..4790a9a 100644 --- a/backend/internal/database/azure/table.go +++ b/backend/internal/database/azure/table.go @@ -2,12 +2,13 @@ package azure import ( "context" + "crypto/rand" "encoding/json" "errors" "fmt" + "math/big" "strconv" "strings" - "sync/atomic" "time" "backend/internal/models" @@ -17,17 +18,16 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/data/aztables" ) -// idCounter is a monotonic counter used to generate unique IDs for Azure Table entities. -// Seeded from the current Unix timestamp in milliseconds so IDs are globally unique -// across restarts (within millisecond granularity). -var idCounter atomic.Uint64 - -func init() { - idCounter.Store(uint64(time.Now().UnixMilli())) -} - -func nextID() uint { - return uint(idCounter.Add(1)) +// nextID generates a collision-resistant numeric ID by combining +// the current Unix timestamp in nanoseconds with a random component. +func nextID() (uint, error) { + // Use upper 48 bits from time, lower 16 bits from crypto/rand + ts := uint64(time.Now().UnixNano()) + rb, err := rand.Int(rand.Reader, big.NewInt(1<<16)) + if err != nil { + return 0, fmt.Errorf("failed to generate random ID component: %w", err) + } + return uint((ts << 16) | rb.Uint64()), nil } // TableRepository implements the Repository interface for Azure Table Storage @@ -119,7 +119,11 @@ func (r *TableRepository) Create(ctx context.Context, entity interface{}) error // Generate a numeric ID (Azure Table Storage has no auto-increment) if item.ID == 0 { - item.ID = nextID() + id, err := nextID() + if err != nil { + return dberrors.NewDatabaseError("create", err) + } + item.ID = id } entityJSON := map[string]interface{}{ @@ -441,6 +445,13 @@ func (r *TableRepository) List(ctx context.Context, dest interface{}, conditions Price: price, } + // Populate Version if present + if v, ok := entityData["Version"]; ok { + if vf, ok := v.(float64); ok { + item.Version = uint(vf) + } + } + // Apply name contains filter if specified if nameContainsFilter != "" { if !strings.Contains(strings.ToLower(item.Name), nameContainsFilter) { diff --git a/backend/internal/models/models.go b/backend/internal/models/models.go index f175b23..edc6fe6 100644 --- a/backend/internal/models/models.go +++ b/backend/internal/models/models.go @@ -101,7 +101,8 @@ func (r *GenericRepository) Ping(ctx context.Context) error { func (r *GenericRepository) Create(ctx context.Context, entity interface{}) error { if v, ok := entity.(Validator); ok { if err := v.Validate(); err != nil { - return dberrors.NewDatabaseError("validate", err) + return dberrors.NewDatabaseError("validate", + fmt.Errorf("%w: %s", dberrors.ErrValidation, err.Error())) } } @@ -121,7 +122,8 @@ func (r *GenericRepository) FindByID(ctx context.Context, id uint, dest interfac func (r *GenericRepository) Update(ctx context.Context, entity interface{}) error { if v, ok := entity.(Validator); ok { if err := v.Validate(); err != nil { - return dberrors.NewDatabaseError("validate", err) + return dberrors.NewDatabaseError("validate", + fmt.Errorf("%w: %s", dberrors.ErrValidation, err.Error())) } } From 8739e22edfe883db8dac8f6b028fec937d07c4c2 Mon Sep 17 00:00:00 2001 From: omattsson Date: Wed, 25 Feb 2026 08:23:39 +0100 Subject: [PATCH 10/18] fix: address ninth round of PR review comments - Document TOCTOU behavior in UpdateItem when client omits version - Escape SQL LIKE wildcards (% and _) to prevent unintended matching - Reduce rate limiter cleanup interval to window/2 for faster purging - Use RLock in rate limiter check phase to reduce lock contention - Increase Azure Table ID randomness from 16 to 48 bits - Handle json.Number Int64 error in Azure Table version parsing - Default Azure Table List Version to 1 when absent or invalid - Clean up DB connection if AutoMigrate fails (prevent resource leak) - Simplify GenerateRandomString error message - Log 'Server starting' before ListenAndServe (not after) - Use MAX(version)+1 in migration to avoid version collisions - Abort with 403 for non-whitelisted CORS origins (defense-in-depth) - Add safety comment for SQL field interpolation (whitelist-guarded) - Document CheckLiveness context parameter rationale - Document optimistic locking Save() behavior --- backend/api/main.go | 3 +- backend/internal/api/handlers/items.go | 7 +++-- backend/internal/api/handlers/rate_limiter.go | 31 ++++++++----------- backend/internal/api/middleware/cors_test.go | 6 ++-- backend/internal/api/middleware/middleware.go | 8 ++--- backend/internal/database/azure/table.go | 23 +++++++++----- backend/internal/database/migrations.go | 11 +++++-- backend/internal/database/repository.go | 6 +++- backend/internal/health/health.go | 3 ++ backend/internal/models/models.go | 17 ++++++++-- backend/pkg/utils/utils.go | 2 +- 11 files changed, 72 insertions(+), 45 deletions(-) diff --git a/backend/api/main.go b/backend/api/main.go index 62b3e85..91eef74 100644 --- a/backend/api/main.go +++ b/backend/api/main.go @@ -73,14 +73,13 @@ func main() { // Start server in a goroutine go func() { + slog.Info("Server starting", "addr", srv.Addr) if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { slog.Error("Failed to start server", "error", err) os.Exit(1) } }() - slog.Info("Server started", "addr", srv.Addr) - // Wait for interrupt signal quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) diff --git a/backend/internal/api/handlers/items.go b/backend/internal/api/handlers/items.go index 5a9e21f..8f5c132 100644 --- a/backend/internal/api/handlers/items.go +++ b/backend/internal/api/handlers/items.go @@ -198,8 +198,11 @@ func (h *Handler) UpdateItem(c *gin.Context) { currentItem.Name = updateItem.Name currentItem.Price = updateItem.Price - // If the client provided a version, use it for optimistic locking; - // otherwise keep the DB version so the repository check passes. + // Optimistic locking: if the client provided a version, use it so the + // repository can detect conflicts. If version=0 (not provided), the + // repository uses the version we just read — this still detects conflicts + // that occur between our FindByID and the repository's WHERE-version check, + // but the client must send the version to guarantee end-to-end safety. if updateItem.Version > 0 { currentItem.Version = updateItem.Version } diff --git a/backend/internal/api/handlers/rate_limiter.go b/backend/internal/api/handlers/rate_limiter.go index 7728e4f..bb49b33 100644 --- a/backend/internal/api/handlers/rate_limiter.go +++ b/backend/internal/api/handlers/rate_limiter.go @@ -40,43 +40,38 @@ func (rl *RateLimiter) Stop() { func (rl *RateLimiter) RateLimit() gin.HandlerFunc { return func(c *gin.Context) { ip := c.ClientIP() - - rl.Lock() - defer rl.Unlock() - now := time.Now() windowStart := now.Add(-rl.window) - // Initialize requests map for this IP if needed - if _, exists := rl.requests[ip]; !exists { - rl.requests[ip] = make([]time.Time, 0) - } - - // Remove old requests (outside our time window) - var valid []time.Time - for _, t := range rl.requests[ip] { + // Read-lock to check current count without blocking other readers. + rl.RLock() + times := rl.requests[ip] + count := 0 + for _, t := range times { if t.After(windowStart) { - valid = append(valid, t) + count++ } } - rl.requests[ip] = valid + rl.RUnlock() - // Check if limit exceeded - if len(rl.requests[ip]) >= rl.limit { + if count >= rl.limit { c.JSON(http.StatusTooManyRequests, gin.H{"error": "rate limit exceeded"}) c.Abort() return } - // Add current request + // Write-lock only to add the new request timestamp. + rl.Lock() rl.requests[ip] = append(rl.requests[ip], now) + rl.Unlock() + c.Next() } } // cleanup periodically removes expired entries to prevent memory leaks. func (rl *RateLimiter) cleanup() { - ticker := time.NewTicker(rl.window * 2) + ticker := time.NewTicker(rl.window / 2) defer ticker.Stop() for { select { diff --git a/backend/internal/api/middleware/cors_test.go b/backend/internal/api/middleware/cors_test.go index ee56a87..d0db9e4 100644 --- a/backend/internal/api/middleware/cors_test.go +++ b/backend/internal/api/middleware/cors_test.go @@ -88,10 +88,10 @@ func TestCORSMiddleware(t *testing.T) { assert.Empty(t, w.Header().Get("Vary")) assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods")) assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers")) - assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, http.StatusForbidden, w.Code) }) - t.Run("Disallowed origin OPTIONS preflight returns 204 without CORS headers", func(t *testing.T) { + t.Run("Disallowed origin OPTIONS preflight returns 403 without CORS headers", func(t *testing.T) { t.Parallel() r := gin.New() r.Use(CORS("https://example.com")) @@ -104,7 +104,7 @@ func TestCORSMiddleware(t *testing.T) { assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods")) - assert.Equal(t, http.StatusNoContent, w.Code) + assert.Equal(t, http.StatusForbidden, w.Code) }) t.Run("OPTIONS preflight returns 204", func(t *testing.T) { diff --git a/backend/internal/api/middleware/middleware.go b/backend/internal/api/middleware/middleware.go index 7a650d6..c5a8b20 100644 --- a/backend/internal/api/middleware/middleware.go +++ b/backend/internal/api/middleware/middleware.go @@ -30,11 +30,9 @@ func CORS(allowedOrigins string) gin.HandlerFunc { } } if !allowed { - if c.Request.Method == "OPTIONS" { - c.AbortWithStatus(http.StatusNoContent) - return - } - c.Next() + // Block requests from non-whitelisted origins as defense-in-depth; + // browsers enforce CORS client-side, but we also enforce server-side. + c.AbortWithStatus(http.StatusForbidden) return } } diff --git a/backend/internal/database/azure/table.go b/backend/internal/database/azure/table.go index 4790a9a..aa29efc 100644 --- a/backend/internal/database/azure/table.go +++ b/backend/internal/database/azure/table.go @@ -18,16 +18,16 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/data/aztables" ) -// nextID generates a collision-resistant numeric ID by combining -// the current Unix timestamp in nanoseconds with a random component. +// nextID generates a collision-resistant numeric ID using a cryptographically +// secure random component. We use 48 bits of randomness to keep collision +// probability extremely low even under concurrency across multiple instances. func nextID() (uint, error) { - // Use upper 48 bits from time, lower 16 bits from crypto/rand - ts := uint64(time.Now().UnixNano()) - rb, err := rand.Int(rand.Reader, big.NewInt(1<<16)) + max := new(big.Int).Lsh(big.NewInt(1), 48) + rb, err := rand.Int(rand.Reader, max) if err != nil { return 0, fmt.Errorf("failed to generate random ID component: %w", err) } - return uint((ts << 16) | rb.Uint64()), nil + return uint(rb.Uint64()), nil } // TableRepository implements the Repository interface for Azure Table Storage @@ -261,7 +261,10 @@ func (r *TableRepository) Update(ctx context.Context, entity interface{}) error case float64: sv = uint(v) case json.Number: - n, _ := v.Int64() + n, err := v.Int64() + if err != nil || n < 0 { + return dberrors.NewDatabaseError("update", fmt.Errorf("invalid stored version value: %v", v)) + } sv = uint(n) } if currentVersion != sv { @@ -445,11 +448,15 @@ func (r *TableRepository) List(ctx context.Context, dest interface{}, conditions Price: price, } - // Populate Version if present + // Populate Version if present; default to 1 when absent or invalid if v, ok := entityData["Version"]; ok { if vf, ok := v.(float64); ok { item.Version = uint(vf) + } else { + item.Version = 1 } + } else { + item.Version = 1 } // Apply name contains filter if specified diff --git a/backend/internal/database/migrations.go b/backend/internal/database/migrations.go index ccba18e..533ff1c 100644 --- a/backend/internal/database/migrations.go +++ b/backend/internal/database/migrations.go @@ -67,8 +67,15 @@ func (d *Database) AutoMigrate() error { Name: "update_items_version_default", Description: "Set Version default to 1 for optimistic locking and update existing rows", Up: func(tx *gorm.DB) error { - // Update existing rows that have the old default of 0 - if err := tx.Exec("UPDATE items SET version = 1 WHERE version = 0").Error; err != nil { + // Find the current max version to avoid collisions with already-updated items + var maxVersion int64 + if err := tx.Table("items").Select("COALESCE(MAX(version), 0)").Scan(&maxVersion).Error; err != nil { + return err + } + newVersion := maxVersion + 1 + + // Update existing rows that still have the old default of 0 + if err := tx.Exec("UPDATE items SET version = ? WHERE version = 0", newVersion).Error; err != nil { return err } // Alter column default to 1 (MySQL syntax; SQLite defaults are set via AutoMigrate) diff --git a/backend/internal/database/repository.go b/backend/internal/database/repository.go index d181842..ec7f33f 100644 --- a/backend/internal/database/repository.go +++ b/backend/internal/database/repository.go @@ -28,8 +28,12 @@ func NewRepository(cfg *config.Config) (models.Repository, error) { return nil, fmt.Errorf("failed to initialize MySQL database: %w", err) } - // Run database migrations + // Run database migrations (migrator tracks applied versions; safe on every startup) if err := db.AutoMigrate(); err != nil { + // Clean up the database connection to avoid resource leaks + if sqlDB, dbErr := db.DB.DB(); dbErr == nil { + _ = sqlDB.Close() + } return nil, fmt.Errorf("failed to run database migrations: %w", err) } diff --git a/backend/internal/health/health.go b/backend/internal/health/health.go index 32ac0a3..fe272f7 100644 --- a/backend/internal/health/health.go +++ b/backend/internal/health/health.go @@ -49,6 +49,9 @@ func (h *HealthChecker) SetReady(ready bool) { h.isReady = ready } +// CheckLiveness returns a simple UP status. The context parameter is unused +// because liveness is intentionally trivial — it only reports uptime and +// does not call any external dependencies that would need cancellation. func (h *HealthChecker) CheckLiveness(_ context.Context) HealthStatus { uptime := time.Since(h.startTime).String() return HealthStatus{ diff --git a/backend/internal/models/models.go b/backend/internal/models/models.go index edc6fe6..72eda31 100644 --- a/backend/internal/models/models.go +++ b/backend/internal/models/models.go @@ -127,7 +127,13 @@ func (r *GenericRepository) Update(ctx context.Context, entity interface{}) erro } } - // Optimistic locking for Versionable entities + // Optimistic locking for Versionable entities. + // We increment the version optimistically before Save, then use a + // WHERE version=old clause. If no rows are affected, it means another + // transaction modified this entity \u2014 we roll back the in-memory version + // and return a version-mismatch error. Note: Save() issues an UPDATE + // for all columns, which is safe here because the handler loaded the + // entity first (FindByID), then applied changes on top of it. if ver, ok := entity.(Versionable); ok { currentVersion := ver.GetVersion() ver.SetVersion(currentVersion + 1) @@ -165,6 +171,9 @@ func (r *GenericRepository) List(ctx context.Context, dest interface{}, conditio return dberrors.NewDatabaseError("list", fmt.Errorf("invalid filter field: %q", c.Field)) } + // SAFETY: c.Field is interpolated into fmt.Sprintf below, but it is + // guaranteed to be one of the whitelisted column names checked above, + // so SQL injection via field names is not possible. switch c.Op { case "exact": query = query.Where(fmt.Sprintf("%s = ?", c.Field), c.Value) @@ -173,8 +182,10 @@ func (r *GenericRepository) List(ctx context.Context, dest interface{}, conditio case "<=": query = query.Where(fmt.Sprintf("%s <= ?", c.Field), c.Value) default: - // Default to LIKE for substring matching - query = query.Where(fmt.Sprintf("%s LIKE ?", c.Field), "%"+fmt.Sprint(c.Value)+"%") + // Default to LIKE for substring matching. + // Escape SQL wildcards (% and _) so they are treated as literals. + escaped := strings.NewReplacer("%", "\\%", "_", "\\_").Replace(fmt.Sprint(c.Value)) + query = query.Where(fmt.Sprintf("%s LIKE ?", c.Field), "%"+escaped+"%") } case Pagination: if c.Limit > 0 { diff --git a/backend/pkg/utils/utils.go b/backend/pkg/utils/utils.go index e1bdf53..9186a38 100644 --- a/backend/pkg/utils/utils.go +++ b/backend/pkg/utils/utils.go @@ -22,7 +22,7 @@ func GenerateRandomString(length int) (string, error) { for i := range b { n, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset)))) if err != nil { - return "", fmt.Errorf("failed to generate random byte at index %d: %w", i, err) + return "", fmt.Errorf("crypto/rand failure: %w", err) } b[i] = charset[int(n.Int64())] } From 27cb619184d75eea1e6848ffbfa952a6f7978256 Mon Sep 17 00:00:00 2001 From: omattsson Date: Wed, 25 Feb 2026 08:34:08 +0100 Subject: [PATCH 11/18] fix: address tenth round of PR review comments - Simplify validation error wrapping: pass err directly to NewDatabaseError instead of double-wrapping with fmt.Errorf and ErrValidation sentinel - Add Close() to Repository interface for proper resource cleanup; GenericRepository closes the sql.DB pool, Azure TableRepository is a no-op - Close database connections during graceful shutdown in main.go - Fix rate limiter TOCTOU race: use single write lock for check-and-add to prevent concurrent requests from exceeding the limit - Simplify migration: set version=0 items to version=1 (the new default) rather than MAX(version)+1 which creates unnecessary gaps --- backend/api/main.go | 5 +++++ backend/api/main_test.go | 5 +++++ backend/internal/api/handlers/mock_repository.go | 5 +++++ backend/internal/api/handlers/rate_limiter.go | 12 +++++------- backend/internal/database/azure/table.go | 6 ++++++ backend/internal/database/migrations.go | 11 ++--------- backend/internal/models/models.go | 16 ++++++++++++---- 7 files changed, 40 insertions(+), 20 deletions(-) diff --git a/backend/api/main.go b/backend/api/main.go index 91eef74..c18fbc3 100644 --- a/backend/api/main.go +++ b/backend/api/main.go @@ -92,6 +92,11 @@ func main() { err = srv.Shutdown(ctx) + // Close repository connections (database pool, etc.) + if closeErr := repo.Close(); closeErr != nil { + slog.Error("Failed to close repository", "error", closeErr) + } + if err != nil { slog.Error("Server forced to shutdown", "error", err) return diff --git a/backend/api/main_test.go b/backend/api/main_test.go index 28740ae..2fe1a30 100644 --- a/backend/api/main_test.go +++ b/backend/api/main_test.go @@ -68,6 +68,11 @@ func (m *MockRepository) Ping(ctx context.Context) error { return args.Error(0) } +func (m *MockRepository) Close() error { + args := m.Called() + return args.Error(0) +} + // MockSQLDB is a mock implementation of the sql.DB interface type MockSQLDB struct { mock.Mock diff --git a/backend/internal/api/handlers/mock_repository.go b/backend/internal/api/handlers/mock_repository.go index a16aebf..e372b8d 100644 --- a/backend/internal/api/handlers/mock_repository.go +++ b/backend/internal/api/handlers/mock_repository.go @@ -224,6 +224,11 @@ func (m *MockRepository) Ping(_ context.Context) error { return nil } +// Close implements the Repository interface +func (m *MockRepository) Close() error { + return nil +} + func (m *MockRepository) SetError(err error) { m.Lock() defer m.Unlock() diff --git a/backend/internal/api/handlers/rate_limiter.go b/backend/internal/api/handlers/rate_limiter.go index bb49b33..bfde592 100644 --- a/backend/internal/api/handlers/rate_limiter.go +++ b/backend/internal/api/handlers/rate_limiter.go @@ -43,25 +43,23 @@ func (rl *RateLimiter) RateLimit() gin.HandlerFunc { now := time.Now() windowStart := now.Add(-rl.window) - // Read-lock to check current count without blocking other readers. - rl.RLock() - times := rl.requests[ip] + // Single write lock for the check-and-add to avoid a TOCTOU race + // where concurrent requests could both pass the limit check. + rl.Lock() count := 0 - for _, t := range times { + for _, t := range rl.requests[ip] { if t.After(windowStart) { count++ } } - rl.RUnlock() if count >= rl.limit { + rl.Unlock() c.JSON(http.StatusTooManyRequests, gin.H{"error": "rate limit exceeded"}) c.Abort() return } - // Write-lock only to add the new request timestamp. - rl.Lock() rl.requests[ip] = append(rl.requests[ip], now) rl.Unlock() diff --git a/backend/internal/database/azure/table.go b/backend/internal/database/azure/table.go index aa29efc..8679832 100644 --- a/backend/internal/database/azure/table.go +++ b/backend/internal/database/azure/table.go @@ -500,6 +500,12 @@ func (r *TableRepository) Ping(ctx context.Context) error { return nil } +// Close implements the Repository interface. Azure Table Storage uses HTTP +// clients that don't require explicit cleanup, so this is a no-op. +func (r *TableRepository) Close() error { + return nil +} + // Helper functions for error handling // IsTableExistsError checks if the error is a TableAlreadyExists error diff --git a/backend/internal/database/migrations.go b/backend/internal/database/migrations.go index 533ff1c..61d7521 100644 --- a/backend/internal/database/migrations.go +++ b/backend/internal/database/migrations.go @@ -67,15 +67,8 @@ func (d *Database) AutoMigrate() error { Name: "update_items_version_default", Description: "Set Version default to 1 for optimistic locking and update existing rows", Up: func(tx *gorm.DB) error { - // Find the current max version to avoid collisions with already-updated items - var maxVersion int64 - if err := tx.Table("items").Select("COALESCE(MAX(version), 0)").Scan(&maxVersion).Error; err != nil { - return err - } - newVersion := maxVersion + 1 - - // Update existing rows that still have the old default of 0 - if err := tx.Exec("UPDATE items SET version = ? WHERE version = 0", newVersion).Error; err != nil { + // Update existing rows that still have the old default of 0 to the new default of 1 + if err := tx.Exec("UPDATE items SET version = 1 WHERE version = 0").Error; err != nil { return err } // Alter column default to 1 (MySQL syntax; SQLite defaults are set via AutoMigrate) diff --git a/backend/internal/models/models.go b/backend/internal/models/models.go index 72eda31..0a1ebc3 100644 --- a/backend/internal/models/models.go +++ b/backend/internal/models/models.go @@ -61,6 +61,7 @@ type Repository interface { Delete(ctx context.Context, entity interface{}) error List(ctx context.Context, dest interface{}, conditions ...interface{}) error Ping(ctx context.Context) error + Close() error } // GenericRepository implements the Repository interface @@ -98,11 +99,19 @@ func (r *GenericRepository) Ping(ctx context.Context) error { return sqlDB.PingContext(ctx) } +// Close releases the underlying database connection pool. +func (r *GenericRepository) Close() error { + sqlDB, err := r.db.DB() + if err != nil { + return err + } + return sqlDB.Close() +} + func (r *GenericRepository) Create(ctx context.Context, entity interface{}) error { if v, ok := entity.(Validator); ok { if err := v.Validate(); err != nil { - return dberrors.NewDatabaseError("validate", - fmt.Errorf("%w: %s", dberrors.ErrValidation, err.Error())) + return dberrors.NewDatabaseError("validate", err) } } @@ -122,8 +131,7 @@ func (r *GenericRepository) FindByID(ctx context.Context, id uint, dest interfac func (r *GenericRepository) Update(ctx context.Context, entity interface{}) error { if v, ok := entity.(Validator); ok { if err := v.Validate(); err != nil { - return dberrors.NewDatabaseError("validate", - fmt.Errorf("%w: %s", dberrors.ErrValidation, err.Error())) + return dberrors.NewDatabaseError("validate", err) } } From 2333d5454884694a975fdf84498e4e04d8f3daa8 Mon Sep 17 00:00:00 2001 From: omattsson Date: Wed, 25 Feb 2026 09:55:38 +0100 Subject: [PATCH 12/18] fix: address eleventh round of PR review comments - GenericRepository.Delete now checks RowsAffected and returns ErrNotFound when no rows are affected (delete of non-existent entity) - Move gin.SetMode(gin.TestMode) to TestMain in middleware tests to avoid mutating global state from parallel test goroutines - Note: Azure TableRepository.Create already populates item.CreatedAt and item.UpdatedAt before returning (lines 153-154 of table.go) --- backend/internal/api/middleware/cors_test.go | 1 - backend/internal/api/middleware/middleware_test.go | 12 ++++++------ backend/internal/models/models.go | 8 ++++++-- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/backend/internal/api/middleware/cors_test.go b/backend/internal/api/middleware/cors_test.go index d0db9e4..a7ff14f 100644 --- a/backend/internal/api/middleware/cors_test.go +++ b/backend/internal/api/middleware/cors_test.go @@ -11,7 +11,6 @@ import ( func TestCORSMiddleware(t *testing.T) { t.Parallel() - gin.SetMode(gin.TestMode) t.Run("Wildcard allows all origins", func(t *testing.T) { t.Parallel() diff --git a/backend/internal/api/middleware/middleware_test.go b/backend/internal/api/middleware/middleware_test.go index bd228f4..73548dc 100644 --- a/backend/internal/api/middleware/middleware_test.go +++ b/backend/internal/api/middleware/middleware_test.go @@ -8,6 +8,7 @@ import ( "log/slog" "net/http" "net/http/httptest" + "os" "strings" "testing" @@ -15,10 +16,13 @@ import ( "github.com/stretchr/testify/assert" ) +func TestMain(m *testing.M) { + gin.SetMode(gin.TestMode) + os.Exit(m.Run()) +} + func TestLoggerMiddleware(t *testing.T) { // Not parallel: this test mutates the global slog default logger. - // Set Gin to Test Mode - gin.SetMode(gin.TestMode) // Create a buffer to capture slog output var buf bytes.Buffer @@ -53,8 +57,6 @@ func TestLoggerMiddleware(t *testing.T) { func TestRecoveryMiddleware(t *testing.T) { t.Parallel() - // Set Gin to Test Mode - gin.SetMode(gin.TestMode) // Setup router with middleware r := gin.New() @@ -81,7 +83,6 @@ func TestRecoveryMiddleware(t *testing.T) { func TestRequestIDMiddleware(t *testing.T) { t.Parallel() - gin.SetMode(gin.TestMode) t.Run("Generates new request ID when none provided", func(t *testing.T) { t.Parallel() @@ -147,7 +148,6 @@ func TestRequestIDMiddleware(t *testing.T) { func TestMaxBodySizeMiddleware(t *testing.T) { t.Parallel() - gin.SetMode(gin.TestMode) t.Run("Allows request within size limit", func(t *testing.T) { t.Parallel() diff --git a/backend/internal/models/models.go b/backend/internal/models/models.go index 0a1ebc3..1af3efd 100644 --- a/backend/internal/models/models.go +++ b/backend/internal/models/models.go @@ -164,8 +164,12 @@ func (r *GenericRepository) Update(ctx context.Context, entity interface{}) erro } func (r *GenericRepository) Delete(ctx context.Context, entity interface{}) error { - if err := r.db.WithContext(ctx).Delete(entity).Error; err != nil { - return r.handleError("delete", err) + result := r.db.WithContext(ctx).Delete(entity) + if result.Error != nil { + return r.handleError("delete", result.Error) + } + if result.RowsAffected == 0 { + return dberrors.NewDatabaseError("delete", dberrors.ErrNotFound) } return nil } From 5a8b43ad6f04efb87399e8aef1bae1bfeedb72c0 Mon Sep 17 00:00:00 2001 From: omattsson Date: Wed, 25 Feb 2026 10:01:44 +0100 Subject: [PATCH 13/18] fix: address twelfth round of PR review comments - Azure Table Update: preserve stored CreatedAt when item.CreatedAt is zero to prevent accidental clobber if caller skips FindByID - Azure Table FindByID: default Version to 1 when field is missing or invalid, consistent with GORM default and Create/List behavior - CreateItem handler: reset item.Version to 0 after binding so the repository's Create sets the server-managed initial version (1) - utils_test: use assert.Truef so format verb %c is interpolated in failure messages --- backend/internal/api/handlers/items.go | 3 +++ backend/internal/database/azure/table.go | 22 ++++++++++++++++++++-- backend/pkg/utils/utils_test.go | 2 +- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/backend/internal/api/handlers/items.go b/backend/internal/api/handlers/items.go index 8f5c132..70a519a 100644 --- a/backend/internal/api/handlers/items.go +++ b/backend/internal/api/handlers/items.go @@ -68,6 +68,9 @@ func (h *Handler) CreateItem(c *gin.Context) { return } + // Version is server-managed; force initial value regardless of client input. + item.Version = 0 + if err := h.repository.Create(c.Request.Context(), &item); err != nil { status, message := handleDBError(err) c.JSON(status, gin.H{"error": message}) diff --git a/backend/internal/database/azure/table.go b/backend/internal/database/azure/table.go index 8679832..da803c6 100644 --- a/backend/internal/database/azure/table.go +++ b/backend/internal/database/azure/table.go @@ -193,8 +193,11 @@ func (r *TableRepository) FindByID(ctx context.Context, id uint, dest interface{ } item.Price = price + // Default to version 1 when the Version field is missing or invalid to + // keep optimistic-lock semantics consistent with the GORM repository. + item.Version = 1 if v, ok := entityData["Version"]; ok { - if vf, ok := v.(float64); ok { + if vf, ok := v.(float64); ok && vf > 0 { item.Version = uint(vf) } } @@ -278,13 +281,28 @@ func (r *TableRepository) Update(ctx context.Context, entity interface{}) error // Create Azure Table entity now := time.Now().UTC() + + // Preserve the stored CreatedAt so callers that skip FindByID before + // updating don't accidentally clobber it with a zero time. + createdAt := item.CreatedAt + if createdAt.IsZero() { + var existingData2 map[string]interface{} + if err := json.Unmarshal(existing.Value, &existingData2); err == nil { + if caStr, ok := existingData2["CreatedAt"].(string); ok { + if parsed, parseErr := time.Parse(time.RFC3339, caStr); parseErr == nil { + createdAt = parsed + } + } + } + } + entityJson := map[string]interface{}{ "PartitionKey": "items", "RowKey": strconv.FormatUint(uint64(item.ID), 10), "Name": item.Name, "Price": item.Price, "Version": item.Version, - "CreatedAt": item.CreatedAt.Format(time.RFC3339), + "CreatedAt": createdAt.Format(time.RFC3339), "UpdatedAt": now.Format(time.RFC3339), } diff --git a/backend/pkg/utils/utils_test.go b/backend/pkg/utils/utils_test.go index 917cddf..b406ffe 100644 --- a/backend/pkg/utils/utils_test.go +++ b/backend/pkg/utils/utils_test.go @@ -36,7 +36,7 @@ func TestUtilFunctions(t *testing.T) { s, err := GenerateRandomString(100) require.NoError(t, err) for _, c := range s { - assert.True(t, + assert.Truef(t, (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9'), "unexpected character: %c", c) } From 420c0533c28a22f27d9e72d4a351961845a3fbb3 Mon Sep 17 00:00:00 2001 From: omattsson Date: Wed, 25 Feb 2026 10:12:57 +0100 Subject: [PATCH 14/18] fix: address thirteenth round of PR review comments - CreateItem: set item.Version = 1 (not 0) so GORM default is honored and optimistic locking works correctly for newly created items - Azure Table Update: parse existing entity JSON once and reuse for both version checking and CreatedAt preservation, eliminating redundant json.Unmarshal --- backend/internal/api/handlers/items.go | 2 +- backend/internal/database/azure/table.go | 20 +++++++++----------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/backend/internal/api/handlers/items.go b/backend/internal/api/handlers/items.go index 70a519a..b6c80f1 100644 --- a/backend/internal/api/handlers/items.go +++ b/backend/internal/api/handlers/items.go @@ -69,7 +69,7 @@ func (h *Handler) CreateItem(c *gin.Context) { } // Version is server-managed; force initial value regardless of client input. - item.Version = 0 + item.Version = 1 if err := h.repository.Create(c.Request.Context(), &item); err != nil { status, message := handleDBError(err) diff --git a/backend/internal/database/azure/table.go b/backend/internal/database/azure/table.go index da803c6..47b9224 100644 --- a/backend/internal/database/azure/table.go +++ b/backend/internal/database/azure/table.go @@ -249,15 +249,16 @@ func (r *TableRepository) Update(ctx context.Context, entity interface{}) error return dberrors.NewDatabaseError("find", err) } + // Parse existing entity once for version checking and CreatedAt preservation. + var existingData map[string]interface{} + if err := json.Unmarshal(existing.Value, &existingData); err != nil { + return dberrors.NewDatabaseError("unmarshal", err) + } + // Optimistic locking: compare version if the entity is Versionable if ver, ok := entity.(models.Versionable); ok { currentVersion := ver.GetVersion() - // Parse existing entity to check stored version - var existingData map[string]interface{} - if err := json.Unmarshal(existing.Value, &existingData); err != nil { - return dberrors.NewDatabaseError("unmarshal", err) - } if storedVersion, ok := existingData["Version"]; ok { var sv uint switch v := storedVersion.(type) { @@ -286,12 +287,9 @@ func (r *TableRepository) Update(ctx context.Context, entity interface{}) error // updating don't accidentally clobber it with a zero time. createdAt := item.CreatedAt if createdAt.IsZero() { - var existingData2 map[string]interface{} - if err := json.Unmarshal(existing.Value, &existingData2); err == nil { - if caStr, ok := existingData2["CreatedAt"].(string); ok { - if parsed, parseErr := time.Parse(time.RFC3339, caStr); parseErr == nil { - createdAt = parsed - } + if caStr, ok := existingData["CreatedAt"].(string); ok { + if parsed, parseErr := time.Parse(time.RFC3339, caStr); parseErr == nil { + createdAt = parsed } } } From 4a9f009486661b242ffbe7658fa43b69d521c760 Mon Sep 17 00:00:00 2001 From: omattsson Date: Wed, 25 Feb 2026 10:28:35 +0100 Subject: [PATCH 15/18] fix: address fourteenth round of PR review comments - Rate limiter: prune expired timestamps during count loop to prevent unbounded slice growth between periodic cleanup cycles - Azure Table Update: add default case to version type switch to return an error for unexpected types instead of silently using zero - Config: skip MySQL Database validation when UseAzureTable is true to avoid unnecessary startup failures with incomplete MySQL credentials - main.go: use cfg.Server.ShutdownTimeout instead of hardcoded 5s constant, falling back to default if not configured - NewRepository: document that it is Item-specific; other entity types should use NewRepositoryWithFilterFields --- backend/api/main.go | 8 ++++++-- backend/internal/api/handlers/rate_limiter.go | 9 ++++++--- backend/internal/config/config.go | 8 ++++---- backend/internal/database/azure/table.go | 2 ++ backend/internal/models/models.go | 3 ++- 5 files changed, 20 insertions(+), 10 deletions(-) diff --git a/backend/api/main.go b/backend/api/main.go index c18fbc3..adad117 100644 --- a/backend/api/main.go +++ b/backend/api/main.go @@ -23,7 +23,7 @@ import ( ) const ( - gracefulShutdownTimeout = 5 * time.Second + defaultShutdownTimeout = 5 * time.Second ) // @title Backend API @@ -87,7 +87,11 @@ func main() { slog.Info("Shutting down server...") // Give outstanding requests time to complete - ctx, cancel := context.WithTimeout(context.Background(), gracefulShutdownTimeout) + shutdownTimeout := cfg.Server.ShutdownTimeout + if shutdownTimeout == 0 { + shutdownTimeout = defaultShutdownTimeout + } + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) defer cancel() err = srv.Shutdown(ctx) diff --git a/backend/internal/api/handlers/rate_limiter.go b/backend/internal/api/handlers/rate_limiter.go index bfde592..4b983d1 100644 --- a/backend/internal/api/handlers/rate_limiter.go +++ b/backend/internal/api/handlers/rate_limiter.go @@ -46,14 +46,17 @@ func (rl *RateLimiter) RateLimit() gin.HandlerFunc { // Single write lock for the check-and-add to avoid a TOCTOU race // where concurrent requests could both pass the limit check. rl.Lock() - count := 0 + // Filter expired timestamps during counting to prevent unbounded + // slice growth between periodic cleanup cycles. + valid := rl.requests[ip][:0] for _, t := range rl.requests[ip] { if t.After(windowStart) { - count++ + valid = append(valid, t) } } + rl.requests[ip] = valid - if count >= rl.limit { + if len(valid) >= rl.limit { rl.Unlock() c.JSON(http.StatusTooManyRequests, gin.H{"error": "rate limit exceeded"}) c.Abort() diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 22bcac1..9652307 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -108,10 +108,10 @@ func (c *Config) Validate() error { if err := c.AzureTable.Validate(); err != nil { return fmt.Errorf("azure table config: %w", err) } - } - - if err := c.Database.Validate(); err != nil { - return fmt.Errorf("database config: %w", err) + } else { + if err := c.Database.Validate(); err != nil { + return fmt.Errorf("database config: %w", err) + } } if err := c.Server.Validate(); err != nil { diff --git a/backend/internal/database/azure/table.go b/backend/internal/database/azure/table.go index 47b9224..7b2af93 100644 --- a/backend/internal/database/azure/table.go +++ b/backend/internal/database/azure/table.go @@ -270,6 +270,8 @@ func (r *TableRepository) Update(ctx context.Context, entity interface{}) error return dberrors.NewDatabaseError("update", fmt.Errorf("invalid stored version value: %v", v)) } sv = uint(n) + default: + return dberrors.NewDatabaseError("update", fmt.Errorf("unexpected version type: %T", storedVersion)) } if currentVersion != sv { return dberrors.NewDatabaseError("update", errors.New("version mismatch")) diff --git a/backend/internal/models/models.go b/backend/internal/models/models.go index 1af3efd..37cf7af 100644 --- a/backend/internal/models/models.go +++ b/backend/internal/models/models.go @@ -70,7 +70,8 @@ type GenericRepository struct { allowedFilterFields map[string]bool } -// NewRepository creates a new GenericRepository with default allowed filter fields. +// NewRepository creates a new GenericRepository with filter fields for the Item +// entity ("name", "price"). For other entity types, use NewRepositoryWithFilterFields. func NewRepository(db *gorm.DB) Repository { return &GenericRepository{ db: db, From ba1de3d3d417749b03385e30240e45f698977d9d Mon Sep 17 00:00:00 2001 From: omattsson Date: Fri, 6 Mar 2026 18:11:54 +0100 Subject: [PATCH 16/18] fix: address fifteenth round of PR review comments - Azure Table Update: default stored version to 1 for legacy rows missing the Version field, enforcing optimistic locking consistently - MaxBodySize middleware: detect *http.MaxBytesError after handler execution and return 413 Request Entity Too Large instead of 400 - HealthChecker.CheckReadiness: copy isReady and dependencies under RLock, then release before running I/O-bound dependency checks to reduce mutex contention --- backend/internal/api/middleware/middleware.go | 14 +++++++++++ backend/internal/database/azure/table.go | 24 ++++++++++--------- backend/internal/health/health.go | 13 +++++++--- 3 files changed, 37 insertions(+), 14 deletions(-) diff --git a/backend/internal/api/middleware/middleware.go b/backend/internal/api/middleware/middleware.go index c5a8b20..9bd5f3e 100644 --- a/backend/internal/api/middleware/middleware.go +++ b/backend/internal/api/middleware/middleware.go @@ -2,6 +2,7 @@ package middleware import ( "crypto/rand" + "errors" "fmt" "log/slog" "net/http" @@ -88,10 +89,23 @@ func RequestID() gin.HandlerFunc { } // MaxBodySize limits the size of the request body to prevent memory exhaustion. +// Oversized payloads are translated to a 413 Request Entity Too Large response. func MaxBodySize(maxBytes int64) gin.HandlerFunc { return func(c *gin.Context) { c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxBytes) c.Next() + + // If reading the body hit the limit, MaxBytesReader returns + // *http.MaxBytesError. Detect this and override the status code + // so clients receive a proper 413 instead of a generic 400. + if c.Errors.Last() != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(c.Errors.Last().Err, &maxBytesErr) { + c.AbortWithStatusJSON(http.StatusRequestEntityTooLarge, + gin.H{"error": "request body too large"}) + return + } + } } } diff --git a/backend/internal/database/azure/table.go b/backend/internal/database/azure/table.go index 7b2af93..1fc6c05 100644 --- a/backend/internal/database/azure/table.go +++ b/backend/internal/database/azure/table.go @@ -259,25 +259,27 @@ func (r *TableRepository) Update(ctx context.Context, entity interface{}) error if ver, ok := entity.(models.Versionable); ok { currentVersion := ver.GetVersion() + // Default stored version to 1 for legacy rows that predate versioning, + // consistent with the model default and FindByID behavior. + storedVersionUint := uint(1) + if storedVersion, ok := existingData["Version"]; ok { - var sv uint switch v := storedVersion.(type) { case float64: - sv = uint(v) + if v >= 0 { + storedVersionUint = uint(v) + } case json.Number: - n, err := v.Int64() - if err != nil || n < 0 { - return dberrors.NewDatabaseError("update", fmt.Errorf("invalid stored version value: %v", v)) + if n, err := v.Int64(); err == nil && n >= 0 { + storedVersionUint = uint(n) } - sv = uint(n) - default: - return dberrors.NewDatabaseError("update", fmt.Errorf("unexpected version type: %T", storedVersion)) - } - if currentVersion != sv { - return dberrors.NewDatabaseError("update", errors.New("version mismatch")) } } + if currentVersion != storedVersionUint { + return dberrors.NewDatabaseError("update", errors.New("version mismatch")) + } + // Increment version for the update ver.SetVersion(currentVersion + 1) } diff --git a/backend/internal/health/health.go b/backend/internal/health/health.go index fe272f7..a9f11af 100644 --- a/backend/internal/health/health.go +++ b/backend/internal/health/health.go @@ -61,10 +61,17 @@ func (h *HealthChecker) CheckLiveness(_ context.Context) HealthStatus { } func (h *HealthChecker) CheckReadiness(ctx context.Context) HealthStatus { + // Copy state under lock, then release before running I/O-bound checks + // to avoid holding the mutex during potentially slow dependency calls. h.mu.RLock() - defer h.mu.RUnlock() + ready := h.isReady + deps := make(map[string]HealthCheck, len(h.dependencies)) + for k, v := range h.dependencies { + deps[k] = v + } + h.mu.RUnlock() - if !h.isReady { + if !ready { return HealthStatus{ Status: "DOWN", Checks: map[string]CheckStatus{ @@ -78,7 +85,7 @@ func (h *HealthChecker) CheckReadiness(ctx context.Context) HealthStatus { Checks: make(map[string]CheckStatus), } - for name, check := range h.dependencies { + for name, check := range deps { if err := check(ctx); err != nil { status.Status = "DOWN" status.Checks[name] = CheckStatus{ From 4ee8fa795fe844536b05b81080fa43f6ea54a175 Mon Sep 17 00:00:00 2001 From: omattsson Date: Fri, 6 Mar 2026 18:25:10 +0100 Subject: [PATCH 17/18] fix: address sixteenth round of PR review comments - Azure Table Update: only accept stored version > 0 as valid, mapping 0 to the default of 1 to match FindByID semantics - CORS middleware: allow requests with no Origin header (non-browser / same-origin) through without blocking; only enforce whitelist when Origin is present. Added test for no-Origin pass-through. - MaxBodySize middleware: replace c.Errors inspection with a body wrapper (maxBytesBodyCapture) that captures *http.MaxBytesError during Read, then returns 413 if exceeded and not yet written - Validation wrapping: re-wrap validation errors with ErrValidation sentinel via fmt.Errorf so handleDBError can reliably map to 400 --- backend/internal/api/middleware/cors_test.go | 15 ++++ backend/internal/api/middleware/middleware.go | 72 ++++++++++++------- backend/internal/database/azure/table.go | 6 +- backend/internal/models/models.go | 6 +- 4 files changed, 71 insertions(+), 28 deletions(-) diff --git a/backend/internal/api/middleware/cors_test.go b/backend/internal/api/middleware/cors_test.go index a7ff14f..fe5d6de 100644 --- a/backend/internal/api/middleware/cors_test.go +++ b/backend/internal/api/middleware/cors_test.go @@ -90,6 +90,21 @@ func TestCORSMiddleware(t *testing.T) { assert.Equal(t, http.StatusForbidden, w.Code) }) + t.Run("No Origin header passes through as non-CORS request", func(t *testing.T) { + t.Parallel() + r := gin.New() + r.Use(CORS("https://example.com")) + r.Any("/test", func(c *gin.Context) { c.Status(http.StatusOK) }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + // No Origin header set — non-browser / same-origin request + r.ServeHTTP(w, req) + + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, http.StatusOK, w.Code) + }) + t.Run("Disallowed origin OPTIONS preflight returns 403 without CORS headers", func(t *testing.T) { t.Parallel() r := gin.New() diff --git a/backend/internal/api/middleware/middleware.go b/backend/internal/api/middleware/middleware.go index 9bd5f3e..f08a79a 100644 --- a/backend/internal/api/middleware/middleware.go +++ b/backend/internal/api/middleware/middleware.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "errors" "fmt" + "io" "log/slog" "net/http" "strings" @@ -21,21 +22,25 @@ func CORS(allowedOrigins string) gin.HandlerFunc { c.Writer.Header().Set("Access-Control-Allow-Origin", "*") } else { requestOrigin := c.Request.Header.Get("Origin") - allowed := false - for _, origin := range strings.Split(allowedOrigins, ",") { - if strings.TrimSpace(origin) == requestOrigin { - c.Writer.Header().Set("Access-Control-Allow-Origin", requestOrigin) - c.Writer.Header().Set("Vary", "Origin") - allowed = true - break + if requestOrigin != "" { + allowed := false + for _, origin := range strings.Split(allowedOrigins, ",") { + if strings.TrimSpace(origin) == requestOrigin { + c.Writer.Header().Set("Access-Control-Allow-Origin", requestOrigin) + c.Writer.Header().Set("Vary", "Origin") + allowed = true + break + } + } + if !allowed { + // Block requests from non-whitelisted origins as defense-in-depth; + // browsers enforce CORS client-side, but we also enforce server-side. + c.AbortWithStatus(http.StatusForbidden) + return } } - if !allowed { - // Block requests from non-whitelisted origins as defense-in-depth; - // browsers enforce CORS client-side, but we also enforce server-side. - c.AbortWithStatus(http.StatusForbidden) - return - } + // If there is no Origin header, treat this as a non-CORS request: + // allow it through without setting Access-Control-Allow-Origin. } c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, Authorization, X-Request-ID") @@ -88,23 +93,42 @@ func RequestID() gin.HandlerFunc { } } +// maxBytesBodyCapture wraps an io.ReadCloser to detect *http.MaxBytesError. +type maxBytesBodyCapture struct { + rc io.ReadCloser + exceeded *bool +} + +func (m *maxBytesBodyCapture) Read(p []byte) (int, error) { + n, err := m.rc.Read(p) + if err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + *m.exceeded = true + } + } + return n, err +} + +func (m *maxBytesBodyCapture) Close() error { + return m.rc.Close() +} + // MaxBodySize limits the size of the request body to prevent memory exhaustion. // Oversized payloads are translated to a 413 Request Entity Too Large response. func MaxBodySize(maxBytes int64) gin.HandlerFunc { return func(c *gin.Context) { - c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxBytes) + var exceeded bool + limitedReader := http.MaxBytesReader(c.Writer, c.Request.Body, maxBytes) + c.Request.Body = &maxBytesBodyCapture{rc: limitedReader, exceeded: &exceeded} c.Next() - // If reading the body hit the limit, MaxBytesReader returns - // *http.MaxBytesError. Detect this and override the status code - // so clients receive a proper 413 instead of a generic 400. - if c.Errors.Last() != nil { - var maxBytesErr *http.MaxBytesError - if errors.As(c.Errors.Last().Err, &maxBytesErr) { - c.AbortWithStatusJSON(http.StatusRequestEntityTooLarge, - gin.H{"error": "request body too large"}) - return - } + // If the body size was exceeded and the handler has not yet written + // a response, return a 413 so clients get a clear signal. + if exceeded && !c.Writer.Written() { + c.AbortWithStatusJSON(http.StatusRequestEntityTooLarge, + gin.H{"error": "request body too large"}) + return } } } diff --git a/backend/internal/database/azure/table.go b/backend/internal/database/azure/table.go index 1fc6c05..d64d19c 100644 --- a/backend/internal/database/azure/table.go +++ b/backend/internal/database/azure/table.go @@ -266,11 +266,13 @@ func (r *TableRepository) Update(ctx context.Context, entity interface{}) error if storedVersion, ok := existingData["Version"]; ok { switch v := storedVersion.(type) { case float64: - if v >= 0 { + // Only treat strictly positive values as valid stored versions. + // Non-positive values leave the default of 1, matching FindByID. + if v > 0 { storedVersionUint = uint(v) } case json.Number: - if n, err := v.Int64(); err == nil && n >= 0 { + if n, err := v.Int64(); err == nil && n > 0 { storedVersionUint = uint(n) } } diff --git a/backend/internal/models/models.go b/backend/internal/models/models.go index 37cf7af..b23dd27 100644 --- a/backend/internal/models/models.go +++ b/backend/internal/models/models.go @@ -112,7 +112,8 @@ func (r *GenericRepository) Close() error { func (r *GenericRepository) Create(ctx context.Context, entity interface{}) error { if v, ok := entity.(Validator); ok { if err := v.Validate(); err != nil { - return dberrors.NewDatabaseError("validate", err) + return dberrors.NewDatabaseError("validate", + fmt.Errorf("%w: %s", dberrors.ErrValidation, err.Error())) } } @@ -132,7 +133,8 @@ func (r *GenericRepository) FindByID(ctx context.Context, id uint, dest interfac func (r *GenericRepository) Update(ctx context.Context, entity interface{}) error { if v, ok := entity.(Validator); ok { if err := v.Validate(); err != nil { - return dberrors.NewDatabaseError("validate", err) + return dberrors.NewDatabaseError("validate", + fmt.Errorf("%w: %s", dberrors.ErrValidation, err.Error())) } } From 8a8f845adf720cea4fbd06dd375669377082ee21 Mon Sep 17 00:00:00 2001 From: omattsson Date: Fri, 6 Mar 2026 18:38:29 +0100 Subject: [PATCH 18/18] =?UTF-8?q?fix:=20Azure=20List()=20ParseUint=20bitSi?= =?UTF-8?q?ze=2032=E2=86=9264,=20add=20optimistic=20locking=20test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix strconv.ParseUint bitSize from 32 to 64 in Azure Table List() to match 48-bit IDs generated by nextID(). - Add TestItemCRUD/Optimistic_Locking_-_version_mismatch covering the GenericRepository optimistic-lock path: creates an item, updates it, then attempts a second update with a stale version and asserts version-mismatch error, version rollback, and no data clobber. - Change Save() to Model().Where().Select("*").Updates() in the optimistic-locking path so the WHERE version clause works consistently across dialects (Save() may emit INSERT…ON CONFLICT on SQLite). --- backend/internal/database/azure/table.go | 2 +- backend/internal/database/database_test.go | 41 ++++++++++++++++++++++ backend/internal/models/models.go | 18 ++++++---- 3 files changed, 54 insertions(+), 7 deletions(-) diff --git a/backend/internal/database/azure/table.go b/backend/internal/database/azure/table.go index d64d19c..5be0800 100644 --- a/backend/internal/database/azure/table.go +++ b/backend/internal/database/azure/table.go @@ -427,7 +427,7 @@ func (r *TableRepository) List(ctx context.Context, dest interface{}, conditions if !ok || rowKey == "" { return dberrors.NewDatabaseError("list", fmt.Errorf("entity missing or invalid RowKey")) } - id, err := strconv.ParseUint(rowKey, 10, 32) + id, err := strconv.ParseUint(rowKey, 10, 64) if err != nil { return dberrors.NewDatabaseError("list", fmt.Errorf("invalid RowKey %q: %w", rowKey, err)) } diff --git a/backend/internal/database/database_test.go b/backend/internal/database/database_test.go index 543a33a..dfb6a44 100644 --- a/backend/internal/database/database_test.go +++ b/backend/internal/database/database_test.go @@ -199,4 +199,45 @@ func TestItemCRUD(t *testing.T) { assert.Error(t, err, "Should not find deleted item") }) + t.Run("Optimistic Locking - version mismatch", func(t *testing.T) { + t.Parallel() + db := setupTestDB(t) + require.NoError(t, db.AutoMigrate()) + repo := models.NewRepository(db.DB) + ctx := context.Background() + + // Create an item (version starts at 1 after create). + item := &models.Item{ + Name: "Versioned Item", + Price: 10.00, + } + require.NoError(t, repo.Create(ctx, item)) + assert.Equal(t, uint(1), item.Version) + + // Simulate two concurrent reads. + var copy1, copy2 models.Item + require.NoError(t, repo.FindByID(ctx, item.ID, ©1)) + require.NoError(t, repo.FindByID(ctx, item.ID, ©2)) + + // First update succeeds: version 1 → 2. + copy1.Price = 20.00 + require.NoError(t, repo.Update(ctx, ©1)) + assert.Equal(t, uint(2), copy1.Version) + + // Second update uses stale version 1 — should fail. + copy2.Price = 30.00 + err := repo.Update(ctx, ©2) + assert.Error(t, err, "Update with stale version should fail") + assert.Contains(t, err.Error(), "version mismatch") + + // The in-memory version should be rolled back. + assert.Equal(t, uint(1), copy2.Version, "Version should be rolled back on mismatch") + + // Database should still have the first update's data. + var current models.Item + require.NoError(t, repo.FindByID(ctx, item.ID, ¤t)) + assert.Equal(t, 20.00, current.Price, "Price should reflect the first successful update") + assert.Equal(t, uint(2), current.Version, "Version in DB should be 2") + }) + } diff --git a/backend/internal/models/models.go b/backend/internal/models/models.go index b23dd27..02a520b 100644 --- a/backend/internal/models/models.go +++ b/backend/internal/models/models.go @@ -139,16 +139,22 @@ func (r *GenericRepository) Update(ctx context.Context, entity interface{}) erro } // Optimistic locking for Versionable entities. - // We increment the version optimistically before Save, then use a + // We increment the version optimistically, then issue an UPDATE with a // WHERE version=old clause. If no rows are affected, it means another - // transaction modified this entity \u2014 we roll back the in-memory version - // and return a version-mismatch error. Note: Save() issues an UPDATE - // for all columns, which is safe here because the handler loaded the - // entity first (FindByID), then applied changes on top of it. + // transaction modified this entity — we roll back the in-memory version + // and return a version-mismatch error. + // We use Model().Where().Select("*").Updates() instead of Where().Save() + // because Save() may generate an upsert (INSERT … ON CONFLICT) on some + // dialects (e.g. SQLite), which bypasses the WHERE version clause. + // Select("*") ensures all columns are written, matching Save() semantics. if ver, ok := entity.(Versionable); ok { currentVersion := ver.GetVersion() ver.SetVersion(currentVersion + 1) - result := r.db.WithContext(ctx).Where("version = ?", currentVersion).Save(entity) + result := r.db.WithContext(ctx). + Model(entity). + Where("version = ?", currentVersion). + Select("*"). + Updates(entity) if result.Error != nil { ver.SetVersion(currentVersion) // Roll back version on error return r.handleError("update", result.Error)