diff --git a/README.markdown b/README.markdown index 198b688c..4daa62ad 100644 --- a/README.markdown +++ b/README.markdown @@ -38,6 +38,14 @@ docker run -p 4000:4000 -v $(pwd)/config.docker.json:/app/config.json:ro -v magl ```bash curl http://localhost:4000/healthz +### Windows (PowerShell) Note + +When using PowerShell, you may see a security warning when running: + +```powershell +curl http://localhost:4000/healthz + +to avoid this use curl http://localhost:4000/healthz -UseBasicParsing ``` diff --git a/internal/restapi/arrival_and_departure_for_stop_handler.go b/internal/restapi/arrival_and_departure_for_stop_handler.go index 7171bc1c..ddf51d85 100644 --- a/internal/restapi/arrival_and_departure_for_stop_handler.go +++ b/internal/restapi/arrival_and_departure_for_stop_handler.go @@ -172,7 +172,11 @@ func (api *RestAPI) arrivalAndDepartureForStopHandler(w http.ResponseWriter, r * stopAgency, err := api.GtfsManager.GtfsDB.Queries.GetAgency(ctx, stopAgencyID) if err != nil { - api.serverErrorResponse(w, r, err) + if errors.Is(err, sql.ErrNoRows) { + api.sendNotFound(w, r) + } else { + api.serverErrorResponse(w, r, err) + } return } diff --git a/internal/restapi/arrival_and_departure_for_stop_handler_test.go b/internal/restapi/arrival_and_departure_for_stop_handler_test.go index 823cb219..ba73fcfd 100644 --- a/internal/restapi/arrival_and_departure_for_stop_handler_test.go +++ b/internal/restapi/arrival_and_departure_for_stop_handler_test.go @@ -1131,3 +1131,34 @@ func TestArrivalAndDepartureForStop_VehicleWithNilID(t *testing.T) { require.True(t, ok) assert.Equal(t, "", entry["vehicleId"], "vehicleId should be empty for vehicle with nil ID") } + +// TestArrivalAndDepartureForStopHandler_AgencyNotFound verifies that the handler +// returns 404 (not 500) when the stop's agency does not exist in the database. +// This guards against regression of the sql.ErrNoRows → serverErrorResponse bug. +func TestArrivalAndDepartureForStopHandler_AgencyNotFound(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + trips := api.GtfsManager.GetTrips() + if len(trips) == 0 { + t.Skip("No trips available for testing") + } + + // Use a real agency's trip ID but a non-existent agency prefix for the stop. + realAgency := api.GtfsManager.GetAgencies()[0] + tripID := utils.FormCombinedID(realAgency.Id, trips[0].ID) + serviceDate := time.Now().Unix() * 1000 + + // "nonexistent_agency" does not exist in the DB, so GetAgency will return sql.ErrNoRows. + stopID := utils.FormCombinedID("nonexistent_agency", "some_stop_code") + + _, resp, model := serveAndRetrieveEndpoint(t, + "/api/where/arrival-and-departure-for-stop/"+stopID+".json?key=TEST&tripId="+tripID+ + "&serviceDate="+fmt.Sprintf("%d", serviceDate)) + + assert.Equal(t, http.StatusNotFound, resp.StatusCode, + "should return 404 when the agency is not found, not 500") + assert.Equal(t, http.StatusNotFound, model.Code) + assert.Equal(t, "resource not found", model.Text) +} + diff --git a/internal/restapi/search_stops_handler.go b/internal/restapi/search_stops_handler.go index fe41eeaa..b27e2bfb 100644 --- a/internal/restapi/search_stops_handler.go +++ b/internal/restapi/search_stops_handler.go @@ -52,9 +52,14 @@ func (api *RestAPI) searchStopsHandler(w http.ResponseWriter, r *http.Request) { limit := 50 if maxCountStr := r.URL.Query().Get("maxCount"); maxCountStr != "" { - if parsed, err := strconv.Atoi(maxCountStr); err == nil && parsed > 0 { - limit = parsed + parsed, err := strconv.Atoi(maxCountStr) + if err != nil || parsed <= 0 { + api.validationErrorResponse(w, r, map[string][]string{ + "maxCount": {"must be a positive integer"}, + }) + return } + limit = parsed } // 2. Sanitize and construct FTS5 query diff --git a/internal/restapi/search_stops_handler_test.go b/internal/restapi/search_stops_handler_test.go index 290a6d48..24225829 100644 --- a/internal/restapi/search_stops_handler_test.go +++ b/internal/restapi/search_stops_handler_test.go @@ -210,14 +210,15 @@ func TestSearchStopsHandlerSpecialCharactersOnly(t *testing.T) { assert.Empty(t, list) } -func TestSearchStopsHandlerMaxCountBoundaries(t *testing.T) { +// TestSearchStopsHandlerInvalidMaxCount verifies that non-positive integers and +// non-integer strings for maxCount are rejected with 400 Bad Request. +func TestSearchStopsHandlerInvalidMaxCount(t *testing.T) { api := createTestApi(t) stops := api.GtfsManager.GetStops() require.NotEmpty(t, stops) - targetStop := stops[0] - query := url.QueryEscape(targetStop.Name) + query := url.QueryEscape(stops[0].Name) tests := []struct { name string @@ -225,7 +226,8 @@ func TestSearchStopsHandlerMaxCountBoundaries(t *testing.T) { }{ {"zero", "0"}, {"negative", "-1"}, - {"tooLarge", "101"}, + {"non_integer", "abc"}, + {"float", "3.14"}, } for _, tt := range tests { @@ -238,19 +240,41 @@ func TestSearchStopsHandlerMaxCountBoundaries(t *testing.T) { resp, model := serveApiAndRetrieveEndpoint(t, api, reqUrl) - assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode, + "maxCount=%q should return 400", tt.maxCount) + assert.Equal(t, http.StatusBadRequest, model.Code) data, ok := model.Data.(map[string]interface{}) require.True(t, ok) - list, ok := data["list"].([]interface{}) - require.True(t, ok) - - assert.NotEmpty(t, list) + fieldErrors, ok := data["fieldErrors"].(map[string]interface{}) + require.True(t, ok, "response should contain fieldErrors") + assert.Contains(t, fieldErrors, "maxCount", + "fieldErrors should mention the maxCount field") }) } } +// TestSearchStopsHandlerLargeMaxCount verifies that a large but valid positive +// integer is accepted (it is a valid limit, not a validation error). +func TestSearchStopsHandlerLargeMaxCount(t *testing.T) { + api := createTestApi(t) + + stops := api.GtfsManager.GetStops() + require.NotEmpty(t, stops) + + query := url.QueryEscape(stops[0].Name) + reqUrl := fmt.Sprintf( + "/api/where/search/stop.json?key=TEST&input=%s&maxCount=1000", + query, + ) + + resp, model := serveApiAndRetrieveEndpoint(t, api, reqUrl) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, http.StatusOK, model.Code) +} + func TestSearchStopsHandlerFTSInjectionAttempt(t *testing.T) { api := createTestApi(t) diff --git a/internal/restapi/validation_middleware.go b/internal/restapi/validation_middleware.go new file mode 100644 index 00000000..aed11853 --- /dev/null +++ b/internal/restapi/validation_middleware.go @@ -0,0 +1,94 @@ +package restapi + +import ( + "fmt" + "net/http" + "strconv" +) + +// QueryParamRule defines a validation rule for a single query parameter. +// If the parameter is absent from the request, the rule is skipped (optional params). +// The Validate function receives the raw string value and returns an error message +// and a boolean indicating whether validation passed. +type QueryParamRule struct { + Param string + Validate func(value string) (errMsg string, ok bool) +} + +// PositiveIntRule returns a rule that validates a query parameter is a positive integer (> 0). +func PositiveIntRule(param string) QueryParamRule { + return QueryParamRule{ + Param: param, + Validate: func(value string) (string, bool) { + n, err := strconv.Atoi(value) + if err != nil { + return "must be a valid integer", false + } + if n <= 0 { + return "must be a positive integer", false + } + return "", true + }, + } +} + +// IntRangeRule returns a rule that validates a query parameter is an integer within [min, max]. +func IntRangeRule(param string, min, max int) QueryParamRule { + return QueryParamRule{ + Param: param, + Validate: func(value string) (string, bool) { + n, err := strconv.Atoi(value) + if err != nil { + return "must be a valid integer", false + } + if n < min || n > max { + return fmt.Sprintf("must be between %d and %d", min, max), false + } + return "", true + }, + } +} + +// NonNegativeIntRule returns a rule that validates a query parameter is a non-negative integer (>= 0). +func NonNegativeIntRule(param string) QueryParamRule { + return QueryParamRule{ + Param: param, + Validate: func(value string) (string, bool) { + n, err := strconv.Atoi(value) + if err != nil { + return "must be a valid integer", false + } + if n < 0 { + return "must be a non-negative integer", false + } + return "", true + }, + } +} + +// ValidateQueryParams applies validation rules to query parameters +// and returns 400 if validation fails before invoking the next handler. +// Parameters not present in the request are skipped (all rules are optional-param-safe). +func ValidateQueryParams(api *RestAPI, rules []QueryParamRule, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + fieldErrors := make(map[string][]string) + + for _, rule := range rules { + value := query.Get(rule.Param) + if value == "" { + continue // param not provided, skip + } + if errMsg, ok := rule.Validate(value); !ok { + fieldErrors[rule.Param] = append(fieldErrors[rule.Param], errMsg) + } + } + + if len(fieldErrors) > 0 { + api.validationErrorResponse(w, r, fieldErrors) + return + } + + next.ServeHTTP(w, r) + }) +} diff --git a/internal/restapi/validation_middleware_test.go b/internal/restapi/validation_middleware_test.go new file mode 100644 index 00000000..8fdbd55b --- /dev/null +++ b/internal/restapi/validation_middleware_test.go @@ -0,0 +1,220 @@ +package restapi + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPositiveIntRule(t *testing.T) { + rule := PositiveIntRule("maxCount") + + tests := []struct { + name string + value string + ok bool + }{ + {"valid positive", "10", true}, + {"valid one", "1", true}, + {"zero is invalid", "0", false}, + {"negative is invalid", "-5", false}, + {"non-numeric is invalid", "abc", false}, + {"float is invalid", "3.14", false}, + {"large valid", "999999", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errMsg, ok := rule.Validate(tt.value) + assert.Equal(t, tt.ok, ok) + if !ok { + assert.NotEmpty(t, errMsg) + } + }) + } +} + +func TestIntRangeRule(t *testing.T) { + rule := IntRangeRule("minutesAfter", 0, 240) + + tests := []struct { + name string + value string + ok bool + }{ + {"within range", "35", true}, + {"at minimum", "0", true}, + {"at maximum", "240", true}, + {"below minimum", "-1", false}, + {"above maximum", "241", false}, + {"non-numeric", "abc", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errMsg, ok := rule.Validate(tt.value) + assert.Equal(t, tt.ok, ok) + if !ok { + assert.NotEmpty(t, errMsg) + } + }) + } +} + +func TestNonNegativeIntRule(t *testing.T) { + rule := NonNegativeIntRule("offset") + + tests := []struct { + name string + value string + ok bool + }{ + {"positive", "5", true}, + {"zero", "0", true}, + {"negative", "-1", false}, + {"non-numeric", "xyz", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errMsg, ok := rule.Validate(tt.value) + assert.Equal(t, tt.ok, ok) + if !ok { + assert.NotEmpty(t, errMsg) + } + }) + } +} + +func TestValidateQueryParams(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + // A simple handler that returns 200 OK when reached + okHandler := func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"status":"ok"}`)) + } + + rules := []QueryParamRule{ + PositiveIntRule("maxCount"), + } + + wrapped := ValidateQueryParams(api, rules, http.HandlerFunc(okHandler)) + + t.Run("valid param passes through", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test?maxCount=10", nil) + rec := httptest.NewRecorder() + wrapped.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("absent param passes through", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + rec := httptest.NewRecorder() + wrapped.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("invalid param returns 400", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test?maxCount=-1", nil) + rec := httptest.NewRecorder() + wrapped.ServeHTTP(rec, req) + assert.Equal(t, http.StatusBadRequest, rec.Code) + + body, err := io.ReadAll(rec.Body) + require.NoError(t, err) + assert.Contains(t, string(body), "maxCount") + }) + + t.Run("non-numeric param returns 400", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test?maxCount=abc", nil) + rec := httptest.NewRecorder() + wrapped.ServeHTTP(rec, req) + assert.Equal(t, http.StatusBadRequest, rec.Code) + }) + + t.Run("zero param returns 400", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test?maxCount=0", nil) + rec := httptest.NewRecorder() + wrapped.ServeHTTP(rec, req) + assert.Equal(t, http.StatusBadRequest, rec.Code) + }) + + t.Run("response matches validationErrorResponse format", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test?maxCount=bad", nil) + rec := httptest.NewRecorder() + wrapped.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + + var resp struct { + Code int `json:"code"` + Text string `json:"text"` + Data struct { + FieldErrors map[string][]string `json:"fieldErrors"` + } `json:"data"` + } + err := json.NewDecoder(rec.Body).Decode(&resp) + require.NoError(t, err) + assert.Equal(t, http.StatusBadRequest, resp.Code) + assert.NotEmpty(t, resp.Data.FieldErrors["maxCount"]) + }) +} + +func TestValidateQueryParamsMultipleRules(t *testing.T) { + api := createTestApi(t) + defer api.Shutdown() + + okHandler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + } + + rules := []QueryParamRule{ + PositiveIntRule("maxCount"), + NonNegativeIntRule("offset"), + } + + wrapped := ValidateQueryParams(api, rules, http.HandlerFunc(okHandler)) + + t.Run("both valid", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test?maxCount=10&offset=0", nil) + rec := httptest.NewRecorder() + wrapped.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("first invalid", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test?maxCount=-1&offset=0", nil) + rec := httptest.NewRecorder() + wrapped.ServeHTTP(rec, req) + assert.Equal(t, http.StatusBadRequest, rec.Code) + }) + + t.Run("second invalid", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test?maxCount=10&offset=-5", nil) + rec := httptest.NewRecorder() + wrapped.ServeHTTP(rec, req) + assert.Equal(t, http.StatusBadRequest, rec.Code) + }) + + t.Run("both invalid collects all errors", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test?maxCount=-1&offset=-5", nil) + rec := httptest.NewRecorder() + wrapped.ServeHTTP(rec, req) + assert.Equal(t, http.StatusBadRequest, rec.Code) + + body, err := io.ReadAll(rec.Body) + require.NoError(t, err) + bodyStr := string(body) + assert.Contains(t, bodyStr, "maxCount") + assert.Contains(t, bodyStr, "offset") + }) +}