Skip to content
Merged
3 changes: 3 additions & 0 deletions cmd/api/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ func BuildApplication(ctx context.Context, cfg appconf.Config, gtfsCfg gtfs.Conf
var directionCalculator *gtfs.AdvancedDirectionCalculator
if gtfsManager != nil {
directionCalculator = gtfs.NewAdvancedDirectionCalculator(gtfsManager.GtfsDB.Queries)
// Register the calculator on the manager so ForceUpdate can refresh its
// queries pointer (and evict the direction cache) after every DB hot-swap.
gtfsManager.DirectionCalculator = directionCalculator
}

// Select clock implementation based on environment
Expand Down
85 changes: 67 additions & 18 deletions internal/gtfs/advanced_direction_calculator.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,17 @@ const (
// AdvancedDirectionCalculator implements the OneBusAway Java algorithm for stop direction calculation
type AdvancedDirectionCalculator struct {
queries *gtfsdb.Queries
queriesMu sync.RWMutex // Protects queries pointer
standardDeviationThreshold float64
shapeCache map[string][]gtfsdb.GetShapePointsWithDistanceRow // Cache of all shape data for bulk operations
initialized atomic.Bool // Tracks whether concurrent operations have started
cacheMutex sync.RWMutex // Protects map access
cacheMutex sync.RWMutex // Protects shapeCache map access
// directionResults caches computed stop directions.
// Lifecycle note: This map grows indefinitely for the lifetime of the application.
// Unbounded growth is acceptable here because it is strictly bounded by the finite
// number of valid real-world stops, and computed directions remain stable across GTFS reloads.
// Only non-error results are cached; transient DB errors are never stored so that
// a recovered database will be retried on the next request.
// Lifecycle note: This map caches computed directions to reduce database load.
// It is explicitly cleared during GTFS reloads (via UpdateQueries) to prevent
// stale directions from persisting across dataset updates.
directionResults sync.Map // Cached direction results (stopID -> string), includes negative cache
requestGroup singleflight.Group // Prevents duplicate concurrent computations for the same stop
}
Expand All @@ -45,6 +48,19 @@ func NewAdvancedDirectionCalculator(queries *gtfsdb.Queries) *AdvancedDirectionC
}
}

// UpdateQueries points the calculator at a different query handle and drops
// every cached direction result.
//
// It is meant to be called after a GTFS hot-swap. The new handle is stored
// while holding queriesMu for writing; once the lock is released,
// directionResults is cleared so directions cached before the swap are not
// served afterwards.
func (adc *AdvancedDirectionCalculator) UpdateQueries(queries *gtfsdb.Queries) {
	swap := func() {
		adc.queriesMu.Lock()
		defer adc.queriesMu.Unlock()
		adc.queries = queries
	}
	swap()

	// Cached directions were computed with the previous handle; discard them
	// so they are recomputed on demand.
	adc.directionResults.Clear()
}

// SetStandardDeviationThreshold sets the standard deviation threshold for direction variance checking.
// IMPORTANT: This must be called before any concurrent operations begin.
// Returns an error if called after CalculateStopDirection has been invoked.
Expand Down Expand Up @@ -100,11 +116,17 @@ func (adc *AdvancedDirectionCalculator) CalculateStopDirection(ctx context.Conte
}

// Actually compute it (Hits the DB)
computedDir := adc.computeFromShapes(context.WithoutCancel(ctx), stopID)
computedDir, err := adc.computeFromShapes(context.WithoutCancel(ctx), stopID)

// Store in sync.Map for all future requests
adc.directionResults.Store(stopID, computedDir)
// Only cache when there was no transient error. A transient error (e.g. DB
// connection lost) must not permanently poison the cache; omitting it here
// means the next request will retry the DB.
if err == nil {
adc.directionResults.Store(stopID, computedDir)
}

// Intentionally return nil so singleflight shares the empty fallback result with concurrent callers.
// Since we skip caching on error, future requests will safely retry the DB.
return computedDir, nil
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, even if computeFromShapes returned an error, you're returning computedDir and a nil error. If the intention is to ignore errors, you should leave a comment here explaining why it's safe to do so.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is intentional, so that we gracefully fall back to an empty direction ("") without failing the API. Returning nil also ensures singleflight correctly shares this fallback with concurrent callers. I've added a comment to clarify this. Thanks!

})

Expand Down Expand Up @@ -156,15 +178,22 @@ func (adc *AdvancedDirectionCalculator) translateGtfsDirection(direction string)
return ""
}

// computeFromShapes calculates direction from shape data using the Java algorithm
func (adc *AdvancedDirectionCalculator) computeFromShapes(ctx context.Context, stopID string) string {
// computeFromShapes calculates direction from shape data using the Java algorithm.
// Returns (direction, nil) on success, ("", nil) when there is legitimately no shape
// data for the stop (safe to cache), or ("", err) on a transient database error
// (must NOT be cached so the next request retries the DB).
func (adc *AdvancedDirectionCalculator) computeFromShapes(ctx context.Context, stopID string) (string, error) {

stopTrips, err := adc.queries.GetStopsWithShapeContext(ctx, stopID)
adc.queriesMu.RLock()
q := adc.queries
adc.queriesMu.RUnlock()

stopTrips, err := q.GetStopsWithShapeContext(ctx, stopID)
if err != nil {
slog.Warn("failed to get stop shape context",
slog.String("stopID", stopID),
slog.String("error", err.Error()))
return ""
return "", err
}

// Collect orientations from all trips, using cache to avoid duplicates
Expand All @@ -183,6 +212,8 @@ func (adc *AdvancedDirectionCalculator) computeFromShapes(ctx context.Context, s
stopLon = stopTrips[0].Lon
}

var lastTransientErr error

for _, stopTrip := range stopTrips {
if !stopTrip.ShapeID.Valid {
continue
Expand All @@ -209,6 +240,13 @@ func (adc *AdvancedDirectionCalculator) computeFromShapes(ctx context.Context, s
// Calculate orientation at this stop location using shape point window
orientation, err := adc.calculateOrientationAtStop(ctx, shapeID, distTraveled, stopLat, stopLon)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
slog.Warn("failed to calculate orientation at stop",
slog.String("stopID", stopID),
slog.String("shapeID", shapeID),
slog.String("error", err.Error()))
lastTransientErr = err
}
continue
}

Expand All @@ -218,12 +256,15 @@ func (adc *AdvancedDirectionCalculator) computeFromShapes(ctx context.Context, s
}

if len(orientations) == 0 {
return ""
if lastTransientErr != nil {
return "", lastTransientErr
}
return "", nil
}

// Single orientation - return it directly
if len(orientations) == 1 {
return adc.getAngleAsDirection(orientations[0])
return adc.getAngleAsDirection(orientations[0]), nil
}

// Calculate mean orientation vector
Expand All @@ -239,7 +280,7 @@ func (adc *AdvancedDirectionCalculator) computeFromShapes(ctx context.Context, s
// Intentional improvement over Java's exact == 0.0 comparison;
// floating-point mean of cos/sin values is unlikely to be exactly zero.
if math.Abs(xMu) < 1e-6 && math.Abs(yMu) < 1e-6 {
return ""
return "", nil
}

// Calculate standard deviation and compare against threshold
Expand All @@ -249,7 +290,7 @@ func (adc *AdvancedDirectionCalculator) computeFromShapes(ctx context.Context, s
xStdDev := math.Sqrt(xVariance)
yStdDev := math.Sqrt(yVariance)
if xStdDev > adc.standardDeviationThreshold || yStdDev > adc.standardDeviationThreshold {
return "" // Too much variance
return "", nil // Too much variance
}

// Calculate median orientation
Expand All @@ -273,7 +314,7 @@ func (adc *AdvancedDirectionCalculator) computeFromShapes(ctx context.Context, s
sort.Float64s(normalizedThetas)
thetaMedian := median(normalizedThetas)

return adc.getAngleAsDirection(thetaMedian)
return adc.getAngleAsDirection(thetaMedian), nil
}

// calculateOrientationAtStop calculates the orientation at a stop using a window of shape points
Expand All @@ -297,10 +338,18 @@ func (adc *AdvancedDirectionCalculator) calculateOrientationAtStop(ctx context.C
}
} else {
// Fall back to database query if no cache
shapePoints, err = adc.queries.GetShapePointsWithDistance(ctx, shapeID)
if err != nil || len(shapePoints) < 2 {
adc.queriesMu.RLock()
q := adc.queries
adc.queriesMu.RUnlock()
shapePoints, err = q.GetShapePointsWithDistance(ctx, shapeID)
if err != nil {
return 0, err
}
if len(shapePoints) < 2 {
// Insufficient points is a data condition, not a transient error.
// Return ErrNoRows so the caller treats this the same as "no shape data".
return 0, sql.ErrNoRows
}
}

closestIdx := 0
Expand Down
57 changes: 41 additions & 16 deletions internal/gtfs/advanced_direction_calculator_test.go
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an easy way to add a test for the new behavior you've added, verifying that a returned error value isn't cached?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! I've added TestTransientDBError_NotCached. It simulates a DB failure by closing an in-memory database and verifies that the resulting empty string is not cached.

Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
"sync"
"testing"

_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"maglev.onebusaway.org/gtfsdb"
"maglev.onebusaway.org/internal/models"
)

Expand Down Expand Up @@ -125,6 +127,30 @@ func TestCalculateStopDirectionResultCache(t *testing.T) {
assert.Equal(t, "", result, "second call should return cached empty result")
}

// TestTransientDBError_NotCached verifies that if the DB fails (simulated by closing it),
// the resulting empty direction is NOT cached, allowing future requests to retry.
func TestTransientDBError_NotCached(t *testing.T) {
// Simulate a transient DB failure by opening and immediately closing an in-memory DB
rawDB, err := sql.Open("sqlite3", ":memory:")
assert.NoError(t, err)
err = rawDB.Close()
assert.NoError(t, err)

brokenQueries := gtfsdb.New(rawDB)

calc := NewAdvancedDirectionCalculator(brokenQueries)

stopID := "transient-error-stop"

// The query will fail, gracefully returning an empty direction
result := calc.CalculateStopDirection(context.Background(), stopID)
assert.Equal(t, "", result, "should return empty string on DB error")

// Critical check: ensure the failure was NOT permanently cached
_, cached := calc.directionResults.Load(stopID)
assert.False(t, cached, "transient DB error result must not be cached in directionResults")
}

func TestCalculateStopDirectionPrecomputedAbbreviations(t *testing.T) {
// Verify all compass abbreviations that DirectionPrecomputer writes to SQLite
// are correctly recognized by CalculateStopDirection via translateGtfsDirection.
Expand Down Expand Up @@ -286,7 +312,8 @@ func TestComputeFromShapes_NoShapeData(t *testing.T) {
_, calc := getSharedTestComponents(t)

// Test with a non-existent stop
direction := calc.computeFromShapes(ctx, "nonexistent")
direction, err := calc.computeFromShapes(ctx, "nonexistent")
assert.NoError(t, err)
assert.Equal(t, "", direction)
}

Expand All @@ -296,7 +323,8 @@ func TestComputeFromShapes_SingleOrientation(t *testing.T) {
_, calc := getSharedTestComponents(t)

// Test with actual stop data - single orientation path will be taken if only one trip
direction := calc.computeFromShapes(ctx, "7000")
direction, err := calc.computeFromShapes(ctx, "7000")
assert.NoError(t, err)
// Direction should be valid or empty
assert.True(t, direction == "" || len(direction) <= 2)
}
Expand All @@ -314,7 +342,8 @@ func TestComputeFromShapes_StandardDeviationThreshold(t *testing.T) {
assert.NoError(t, err)

// Test with a stop that might have multiple trips
direction := calc.computeFromShapes(ctx, "7000")
direction, err := calc.computeFromShapes(ctx, "7000")
assert.NoError(t, err)
// With low threshold, high variance might return empty
assert.True(t, direction == "" || len(direction) <= 2)
}
Expand All @@ -331,10 +360,9 @@ func TestCalculateOrientationAtStop_WithDistanceTraveled(t *testing.T) {

// Test with distance traveled
orientation, err := calc.calculateOrientationAtStop(ctx, "19_0_1", 100.0, 0, 0)
if err == nil {
assert.NoError(t, err)
assert.GreaterOrEqual(t, orientation, -math.Pi)
assert.LessOrEqual(t, orientation, math.Pi)
}
}

func TestCalculateOrientationAtStop_GeographicMatching(t *testing.T) {
Expand All @@ -351,20 +379,19 @@ func TestCalculateOrientationAtStop_GeographicMatching(t *testing.T) {
stopLat := shapes[0].Lat
stopLon := shapes[0].Lon
orientation, err := calc.calculateOrientationAtStop(ctx, "19_0_1", -1.0, stopLat, stopLon)
if err == nil {
assert.NoError(t, err)
assert.GreaterOrEqual(t, orientation, -math.Pi)
assert.LessOrEqual(t, orientation, math.Pi)
}
}

// TestCalculateOrientationAtStop_NoShapePoints checks that requesting the
// orientation for a shape ID that does not exist yields an error together with
// a zero orientation value.
func TestCalculateOrientationAtStop_NoShapePoints(t *testing.T) {
	ctx := context.Background()
	_, calc := getSharedTestComponents(t)

	// Test with non-existent shape - should return error
	orientation, err := calc.calculateOrientationAtStop(ctx, "nonexistent", 0, 0, 0)

	// Both conditions are required; the earlier "error OR zero orientation"
	// check was strictly weaker than these two and has been dropped.
	assert.Error(t, err)
	assert.Equal(t, float64(0), orientation)
}

func TestCalculateOrientationAtStop_EdgeCases(t *testing.T) {
Expand All @@ -379,21 +406,19 @@ func TestCalculateOrientationAtStop_EdgeCases(t *testing.T) {
// Test at the very beginning of the shape
if len(shapes) > 0 && shapes[0].ShapeDistTraveled.Valid {
orientation, err := calc.calculateOrientationAtStop(ctx, "19_0_1", shapes[0].ShapeDistTraveled.Float64, 0, 0)
if err == nil {
assert.NoError(t, err)
assert.GreaterOrEqual(t, orientation, -math.Pi)
assert.LessOrEqual(t, orientation, math.Pi)
}
}

// Test at the very end of the shape
if len(shapes) > 1 && shapes[len(shapes)-1].ShapeDistTraveled.Valid {
orientation, err := calc.calculateOrientationAtStop(ctx, "19_0_1", shapes[len(shapes)-1].ShapeDistTraveled.Float64, 0, 0)
if err == nil {
assert.NoError(t, err)
assert.GreaterOrEqual(t, orientation, -math.Pi)
assert.LessOrEqual(t, orientation, math.Pi)
}
}
}

func TestGetAngleAsDirection_EdgeCases(t *testing.T) {
calc := &AdvancedDirectionCalculator{}
Expand Down Expand Up @@ -489,7 +514,7 @@ func TestBulkQuery_GetStopsWithShapeContextByIDs(t *testing.T) {
results, err := manager.GtfsDB.Queries.GetStopsWithShapeContextByIDs(ctx, stopIDs)

// Verify Results
assert.Nil(t, err)
assert.NoError(t, err)
assert.NotEmpty(t, results)

// We expect AT LEAST as many rows as IDs we asked for.
Expand Down Expand Up @@ -522,7 +547,7 @@ func TestBulkQuery_GetShapePointsByIDs(t *testing.T) {
points, err := manager.GtfsDB.Queries.GetShapePointsByIDs(ctx, shapeIDs)

// Verify
assert.Nil(t, err)
assert.NoError(t, err)
assert.NotEmpty(t, points)

// Verify sorting
Expand Down
5 changes: 5 additions & 0 deletions internal/gtfs/gtfs_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ type Manager struct {
// Exported metrics client dependency
Metrics *metrics.Metrics

// DirectionCalculator is set by the application layer after construction so that
// ForceUpdate can refresh its queries pointer whenever the DB is hot-swapped.
// May be nil when running without direction computation (e.g. in tests).
DirectionCalculator *AdvancedDirectionCalculator

// Tracks the last successful update time per feed
feedLastUpdate map[string]time.Time

Expand Down
2 changes: 1 addition & 1 deletion internal/gtfs/spatial_index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestQueryStopsInBounds(t *testing.T) {
name string
bounds utils.CoordinateBounds
expectedCount int
expectedIDs []string
expectedIDs []string
}{
{
name: "NormalBoundingBox_SomeStops",
Expand Down
13 changes: 13 additions & 0 deletions internal/gtfs/static.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,11 +318,17 @@ func (manager *Manager) ForceUpdate(ctx context.Context) error {
dbConfig := newGTFSDBConfig(finalDBPath, manager.config)
if reopenedClient, reopenErr := gtfsdb.NewClient(dbConfig); reopenErr == nil {
manager.GtfsDB = reopenedClient
if manager.DirectionCalculator != nil {
manager.DirectionCalculator.UpdateQueries(reopenedClient.Queries)
}
logging.LogOperation(logger, "recovery_successful_old_db_reopened")
} else {
logging.LogError(logger, "CRITICAL: Failed to recover old DB after rename failure", reopenErr)
logging.LogOperation(logger, "setting manager.gtfsDB to nil")
manager.GtfsDB = nil
if manager.DirectionCalculator != nil {
manager.DirectionCalculator.UpdateQueries(nil)
}

manager.isHealthy = false
}
Expand Down Expand Up @@ -361,6 +367,13 @@ func (manager *Manager) ForceUpdate(ctx context.Context) error {
manager.stopSpatialIndex = newStopSpatialIndex
manager.regionBounds = newRegionBounds

// Refresh the direction calculator's queries pointer so on-demand lookups
// use the new database. This also clears the direction result cache so stale
// entries from the old database are not served.
if manager.DirectionCalculator != nil {
manager.DirectionCalculator.UpdateQueries(client.Queries)
}

// Note: the epoch is incremented after GtfsDB is assigned. A narrow race exists where
// a reader snapshots epochBefore, then reads the new DB pointer (already live), queries
// the new DB, and writes to the cache before the epoch advances — only for this clear
Expand Down
1 change: 0 additions & 1 deletion internal/restapi/trips_for_location_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,6 @@ func (rb *referenceBuilder) createRoute(route gtfsdb.Route) models.Route {

}


func (rb *referenceBuilder) buildTripReferences() error {
rb.tripsRefList = make([]models.Trip, 0, len(rb.presentTrips))

Expand Down
Loading
Loading