Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ docker run -p 4000:4000 -v $(pwd)/config.docker.json:/app/config.json:ro -v magl

```bash
curl http://localhost:4000/healthz
### Windows (PowerShell) Note

When using PowerShell, you may see a security warning when running:

```powershell
curl http://localhost:4000/healthz

to avoid this use curl http://localhost:4000/healthz -UseBasicParsing

```

Expand Down
6 changes: 5 additions & 1 deletion internal/restapi/arrival_and_departure_for_stop_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,11 @@ func (api *RestAPI) arrivalAndDepartureForStopHandler(w http.ResponseWriter, r *

stopAgency, err := api.GtfsManager.GtfsDB.Queries.GetAgency(ctx, stopAgencyID)
if err != nil {
api.serverErrorResponse(w, r, err)
if errors.Is(err, sql.ErrNoRows) {
api.sendNotFound(w, r)
} else {
api.serverErrorResponse(w, r, err)
}
return
}

Expand Down
31 changes: 31 additions & 0 deletions internal/restapi/arrival_and_departure_for_stop_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1131,3 +1131,34 @@ func TestArrivalAndDepartureForStop_VehicleWithNilID(t *testing.T) {
require.True(t, ok)
assert.Equal(t, "", entry["vehicleId"], "vehicleId should be empty for vehicle with nil ID")
}

// TestArrivalAndDepartureForStopHandler_AgencyNotFound verifies that the handler
// returns 404 (not 500) when the stop's agency does not exist in the database.
// This guards against regression of the sql.ErrNoRows → serverErrorResponse bug.
func TestArrivalAndDepartureForStopHandler_AgencyNotFound(t *testing.T) {
api := createTestApi(t)
defer api.Shutdown()

trips := api.GtfsManager.GetTrips()
if len(trips) == 0 {
t.Skip("No trips available for testing")
}

// Use a real agency's trip ID but a non-existent agency prefix for the stop.
realAgency := api.GtfsManager.GetAgencies()[0]
tripID := utils.FormCombinedID(realAgency.Id, trips[0].ID)
serviceDate := time.Now().Unix() * 1000

// "nonexistent_agency" does not exist in the DB, so GetAgency will return sql.ErrNoRows.
stopID := utils.FormCombinedID("nonexistent_agency", "some_stop_code")

_, resp, model := serveAndRetrieveEndpoint(t,
"/api/where/arrival-and-departure-for-stop/"+stopID+".json?key=TEST&tripId="+tripID+
"&serviceDate="+fmt.Sprintf("%d", serviceDate))

assert.Equal(t, http.StatusNotFound, resp.StatusCode,
"should return 404 when the agency is not found, not 500")
assert.Equal(t, http.StatusNotFound, model.Code)
assert.Equal(t, "resource not found", model.Text)
}

9 changes: 7 additions & 2 deletions internal/restapi/search_stops_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,14 @@ func (api *RestAPI) searchStopsHandler(w http.ResponseWriter, r *http.Request) {

limit := 50
if maxCountStr := r.URL.Query().Get("maxCount"); maxCountStr != "" {
if parsed, err := strconv.Atoi(maxCountStr); err == nil && parsed > 0 {
limit = parsed
parsed, err := strconv.Atoi(maxCountStr)
if err != nil || parsed <= 0 {
api.validationErrorResponse(w, r, map[string][]string{
"maxCount": {"must be a positive integer"},
})
return
}
limit = parsed
}

// 2. Sanitize and construct FTS5 query
Expand Down
42 changes: 33 additions & 9 deletions internal/restapi/search_stops_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,22 +210,24 @@ func TestSearchStopsHandlerSpecialCharactersOnly(t *testing.T) {
assert.Empty(t, list)
}

func TestSearchStopsHandlerMaxCountBoundaries(t *testing.T) {
// TestSearchStopsHandlerInvalidMaxCount verifies that non-positive integers and
// non-integer strings for maxCount are rejected with 400 Bad Request.
func TestSearchStopsHandlerInvalidMaxCount(t *testing.T) {
api := createTestApi(t)

stops := api.GtfsManager.GetStops()
require.NotEmpty(t, stops)

targetStop := stops[0]
query := url.QueryEscape(targetStop.Name)
query := url.QueryEscape(stops[0].Name)

tests := []struct {
name string
maxCount string
}{
{"zero", "0"},
{"negative", "-1"},
{"tooLarge", "101"},
{"non_integer", "abc"},
{"float", "3.14"},
}

for _, tt := range tests {
Expand All @@ -238,19 +240,41 @@ func TestSearchStopsHandlerMaxCountBoundaries(t *testing.T) {

resp, model := serveApiAndRetrieveEndpoint(t, api, reqUrl)

assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode,
"maxCount=%q should return 400", tt.maxCount)
assert.Equal(t, http.StatusBadRequest, model.Code)

data, ok := model.Data.(map[string]interface{})
require.True(t, ok)

list, ok := data["list"].([]interface{})
require.True(t, ok)

assert.NotEmpty(t, list)
fieldErrors, ok := data["fieldErrors"].(map[string]interface{})
require.True(t, ok, "response should contain fieldErrors")
assert.Contains(t, fieldErrors, "maxCount",
"fieldErrors should mention the maxCount field")
})
}
}

// TestSearchStopsHandlerLargeMaxCount verifies that a large but valid positive
// integer is accepted (it is a valid limit, not a validation error).
func TestSearchStopsHandlerLargeMaxCount(t *testing.T) {
api := createTestApi(t)

stops := api.GtfsManager.GetStops()
require.NotEmpty(t, stops)

query := url.QueryEscape(stops[0].Name)
reqUrl := fmt.Sprintf(
"/api/where/search/stop.json?key=TEST&input=%s&maxCount=1000",
query,
)

resp, model := serveApiAndRetrieveEndpoint(t, api, reqUrl)

assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, http.StatusOK, model.Code)
}

func TestSearchStopsHandlerFTSInjectionAttempt(t *testing.T) {
api := createTestApi(t)

Expand Down
94 changes: 94 additions & 0 deletions internal/restapi/validation_middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package restapi

import (
"fmt"
"net/http"
"strconv"
)

// QueryParamRule defines a validation rule for a single query parameter.
// If the parameter is absent from the request, the rule is skipped (optional params).
// The Validate function receives the raw string value and returns an error message
// and a boolean indicating whether validation passed.
type QueryParamRule struct {
Param string
Validate func(value string) (errMsg string, ok bool)
}

// PositiveIntRule returns a rule that validates a query parameter is a positive integer (> 0).
func PositiveIntRule(param string) QueryParamRule {
return QueryParamRule{
Param: param,
Validate: func(value string) (string, bool) {
n, err := strconv.Atoi(value)
if err != nil {
return "must be a valid integer", false
}
if n <= 0 {
return "must be a positive integer", false
}
return "", true
},
}
}

// IntRangeRule returns a rule that validates a query parameter is an integer within [min, max].
func IntRangeRule(param string, min, max int) QueryParamRule {
return QueryParamRule{
Param: param,
Validate: func(value string) (string, bool) {
n, err := strconv.Atoi(value)
if err != nil {
return "must be a valid integer", false
}
if n < min || n > max {
return fmt.Sprintf("must be between %d and %d", min, max), false
}
return "", true
},
}
}

// NonNegativeIntRule returns a rule that validates a query parameter is a non-negative integer (>= 0).
func NonNegativeIntRule(param string) QueryParamRule {
return QueryParamRule{
Param: param,
Validate: func(value string) (string, bool) {
n, err := strconv.Atoi(value)
if err != nil {
return "must be a valid integer", false
}
if n < 0 {
return "must be a non-negative integer", false
}
return "", true
},
}
}

// ValidateQueryParams applies validation rules to query parameters
// and returns 400 if validation fails before invoking the next handler.
// Parameters not present in the request are skipped (all rules are optional-param-safe).
func ValidateQueryParams(api *RestAPI, rules []QueryParamRule, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
fieldErrors := make(map[string][]string)

for _, rule := range rules {
value := query.Get(rule.Param)
if value == "" {
continue // param not provided, skip
}
if errMsg, ok := rule.Validate(value); !ok {
fieldErrors[rule.Param] = append(fieldErrors[rule.Param], errMsg)
}
}

if len(fieldErrors) > 0 {
api.validationErrorResponse(w, r, fieldErrors)
return
}

next.ServeHTTP(w, r)
})
}
Loading
Loading