diff --git a/internal/oauth/middleware.go b/internal/oauth/middleware.go index b444dec..cbd776d 100644 --- a/internal/oauth/middleware.go +++ b/internal/oauth/middleware.go @@ -1,7 +1,9 @@ package oauth import ( + "context" "database/sql" + "net/http" "time" "github.com/labstack/echo/v4" @@ -12,9 +14,15 @@ type User struct { DID string } +// SessionStore defines the session operations needed by the middleware. +type SessionStore interface { + GetSessionByID(ctx context.Context, id string) (*OAuthSession, error) + DeleteSession(ctx context.Context, id string) error +} + // SessionMiddleware creates middleware that reads the session cookie // and adds the user to the context if the session is valid -func SessionMiddleware(storage *Storage) echo.MiddlewareFunc { +func SessionMiddleware(storage SessionStore) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { // Try to get session cookie @@ -38,8 +46,18 @@ func SessionMiddleware(storage *Storage) echo.MiddlewareFunc { // Check if session is expired if session.ExpiresAt.Before(time.Now()) { - // Expired session - continue without user - // TODO: Consider cleaning up expired session here + // Clean up expired session from database + if err := storage.DeleteSession(c.Request().Context(), cookie.Value); err != nil { + c.Logger().Errorf("Failed to delete expired session: %v", err) + } + // Clear the session cookie from browser + c.SetCookie(&http.Cookie{ + Name: "session", + Value: "", + Path: "/", + MaxAge: -1, + HttpOnly: true, + }) return next(c) } diff --git a/internal/oauth/middleware_test.go b/internal/oauth/middleware_test.go index a11d007..30a8c38 100644 --- a/internal/oauth/middleware_test.go +++ b/internal/oauth/middleware_test.go @@ -95,7 +95,7 @@ func TestSessionMiddleware(t *testing.T) { assert.Equal(t, "did:plc:test123", capturedUser.DID) }) - t.Run("expired session - sets nil user in context", func(t *testing.T) { + t.Run("expired session - sets nil user in context and deletes session", func(t *testing.T) { // Create an expired session e := echo.New() setupReq := httptest.NewRequest(http.MethodGet, "/", nil) @@ -103,13 +103,17 @@ func TestSessionMiddleware(t *testing.T) { setupCtx := e.NewContext(setupReq, setupRec) session := OAuthSession{ - ID: "expired-session", + ID: "expired-session-cleanup", DID: "did:plc:expired", ExpiresAt: time.Now().Add(-1 * time.Hour), } err := storage.CreateSession(setupCtx.Request().Context(), session) require.NoError(t, err) + // Verify session exists before middleware call + _, err = storage.GetSessionByID(setupCtx.Request().Context(), session.ID) + require.NoError(t, err, "Session should exist before middleware call") + // Make request with expired session cookie req := httptest.NewRequest(http.MethodGet, "/", nil) req.AddCookie(&http.Cookie{ @@ -128,6 +132,22 @@ func TestSessionMiddleware(t *testing.T) { err = handler(c) require.NoError(t, err) assert.Nil(t, capturedUser) + + // Verify session was deleted from database + _, err = storage.GetSessionByID(setupCtx.Request().Context(), session.ID) + assert.Error(t, err, "Expired session should be deleted from database") + + // Verify session cookie was cleared in response + cookies := rec.Result().Cookies() + var sessionCookie *http.Cookie + for _, cookie := range cookies { + if cookie.Name == "session" { + sessionCookie = cookie + break + } + } + require.NotNil(t, sessionCookie, "Session cookie should be set in response") + assert.Equal(t, -1, sessionCookie.MaxAge, "Session cookie MaxAge should be -1 to clear it") }) } diff --git a/internal/oauth/middleware_unit_test.go b/internal/oauth/middleware_unit_test.go new file mode 100644 index 0000000..c21c6a5 --- /dev/null +++ b/internal/oauth/middleware_unit_test.go @@ -0,0 +1,137 @@ +package oauth + +import ( + "context" + "database/sql" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type stubSessionStore struct { + sessions map[string]*OAuthSession + deleteErr error + deleteCalls []string +} + +func (s *stubSessionStore) GetSessionByID(ctx context.Context, id string) (*OAuthSession, error) { + session, ok := s.sessions[id] + if !ok { + return nil, sql.ErrNoRows + } + return session, nil +} + +func (s *stubSessionStore) DeleteSession(ctx context.Context, id string) error { + s.deleteCalls = append(s.deleteCalls, id) + if s.deleteErr != nil { + return s.deleteErr + } + delete(s.sessions, id) + return nil +} + +func TestSessionMiddlewareExpiredSessionDeletes(t *testing.T) { + store := &stubSessionStore{ + sessions: map[string]*OAuthSession{ + "expired-session": { + ID: "expired-session", + DID: "did:plc:expired", + ExpiresAt: time.Now().Add(-1 * time.Minute), + }, + }, + } + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "session", Value: "expired-session"}) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + var capturedUser *User + nextCalled := false + handler := SessionMiddleware(store)(func(c echo.Context) error { + nextCalled = true + capturedUser = GetUser(c) + return c.String(http.StatusOK, "ok") + }) + + err := handler(c) + require.NoError(t, err) + assert.True(t, nextCalled) + assert.Nil(t, capturedUser) + assert.Equal(t, []string{"expired-session"}, store.deleteCalls) + _, exists := store.sessions["expired-session"] + assert.False(t, exists) +} + +func TestSessionMiddlewareDeleteErrorDoesNotBlock(t *testing.T) { + store := &stubSessionStore{ + sessions: map[string]*OAuthSession{ + "expired-session": { + ID: "expired-session", + DID: "did:plc:expired", + ExpiresAt: time.Now().Add(-1 * time.Minute), + }, + }, + deleteErr: errors.New("delete failed"), + } + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "session", Value: "expired-session"}) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + var capturedUser *User + nextCalled := false + handler := SessionMiddleware(store)(func(c echo.Context) error { + nextCalled = true + capturedUser = GetUser(c) + return c.String(http.StatusOK, "ok") + }) + + err := handler(c) + require.NoError(t, err) + assert.True(t, nextCalled) + assert.Nil(t, capturedUser) + assert.Equal(t, []string{"expired-session"}, store.deleteCalls) + _, exists := store.sessions["expired-session"] + assert.True(t, exists) +} + +func TestSessionMiddlewareValidSessionDoesNotDelete(t *testing.T) { + store := &stubSessionStore{ + sessions: map[string]*OAuthSession{ + "valid-session": { + ID: "valid-session", + DID: "did:plc:valid", + ExpiresAt: time.Now().Add(1 * time.Hour), + }, + }, + } + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "session", Value: "valid-session"}) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + var capturedUser *User + handler := SessionMiddleware(store)(func(c echo.Context) error { + capturedUser = GetUser(c) + return c.String(http.StatusOK, "ok") + }) + + err := handler(c) + require.NoError(t, err) + require.NotNil(t, capturedUser) + assert.Equal(t, "did:plc:valid", capturedUser.DID) + assert.Empty(t, store.deleteCalls) +}