diff --git a/Makefile b/Makefile index bf0962c..f4b9f6b 100644 --- a/Makefile +++ b/Makefile @@ -58,7 +58,7 @@ $(FMTSTAMP): $(GOFILES) $(GOTESTFILES) lint: $(LINTSTAMP) ## Run linters $(LINTSTAMP): $(GOFILES) $(GOTESTFILES) - golangci-lint run --verbose + golangci-lint run --disable goheader --verbose touch $@ ## Testing: diff --git a/internal/helpers.go b/internal/helpers.go index 4913d6c..9b0aa64 100644 --- a/internal/helpers.go +++ b/internal/helpers.go @@ -88,17 +88,17 @@ func removeHopByHopHeaders(resp *http.Response) { } } -// updateStoredHeaders updates the stored response headers with the -// headers from the revalidated response, excluding hop-by-hop headers -// and the Content-Length header, as per RFC 9111 §3.2. -func updateStoredHeaders(storedResp, resp *http.Response) { - omitted := hopByHopHeaders(resp.Header) - omitted["Content-Length"] = struct{}{} - for hdr, val := range resp.Header { - if _, ok := omitted[hdr]; ok { +// mergeResponseHeaders merges the headers from the revalidated response into +// the stored response, excluding hop-by-hop headers and the Content-Length +// header, as per RFC 9111 §3.2. +func mergeResponseHeaders(targetResp *http.Response, srcHdrs http.Header) { + nonCacheHdrs := hopByHopHeaders(srcHdrs) + nonCacheHdrs["Content-Length"] = struct{}{} + for hdr, val := range srcHdrs { + if _, ok := nonCacheHdrs[hdr]; ok { continue } - storedResp.Header[hdr] = val + targetResp.Header[hdr] = val } } diff --git a/internal/validationresponsehandler.go b/internal/validationresponsehandler.go index 292e705..428c370 100644 --- a/internal/validationresponsehandler.go +++ b/internal/validationresponsehandler.go @@ -80,7 +80,16 @@ func (r *validationResponseHandler) HandleValidationResponse( if err == nil && req.Method == http.MethodGet && resp.StatusCode == http.StatusNotModified { // RFC 9111 §4.3.3 Handling Validation Responses (304 Not Modified) // RFC 9111 §4.3.4 Freshening Stored Responses upon Validation - updateStoredHeaders(ctx.Stored.Data, resp) + mergeResponseHeaders(ctx.Stored.Data, resp.Header) + _ = r.rs.StoreResponse( + req, + ctx.Stored.Data, + ctx.URLKey, + ctx.Refs, + ctx.Start, + ctx.End, + ctx.RefIndex, + ) CacheStatusRevalidated.ApplyTo(ctx.Stored.Data.Header) r.l.LogCacheRevalidated(req, ctx.URLKey, ctx.ToMisc(nil)) return ctx.Stored.Data, nil diff --git a/internal/validationresponsehandler_test.go b/internal/validationresponsehandler_test.go index e2e8740..9c6aafd 100644 --- a/internal/validationresponsehandler_test.go +++ b/internal/validationresponsehandler_test.go @@ -65,6 +65,14 @@ func Test_validationResponseHandler_HandleValidationResponse(t *testing.T) { l: noopLogger, }, setup: func(tt *testing.T, handler *validationResponseHandler) args { + handler.rs = &MockResponseStorer{ + StoreResponseFunc: func(req *http.Request, resp *http.Response, key string, headers ResponseRefs, reqTime, respTime time.Time, refIndex int) error { + testutil.AssertEqual(tt, "key", key) + testutil.AssertTrue(tt, respTime.Equal(base)) + testutil.AssertTrue(tt, reqTime.Equal(base)) + return nil + }, + } return args{ req: &http.Request{Method: http.MethodGet}, resp: &http.Response{StatusCode: http.StatusNotModified, Header: http.Header{}}, diff --git a/roundtripper_test.go b/roundtripper_test.go index 69771b1..06d3820 100644 --- a/roundtripper_test.go +++ b/roundtripper_test.go @@ -22,13 +22,14 @@ import ( "net/http/httptest" "net/url" "os" + "sync/atomic" "testing" "time" "github.com/bartventer/httpcache/internal" "github.com/bartventer/httpcache/internal/testutil" "github.com/bartventer/httpcache/store" - _ "github.com/bartventer/httpcache/store/memcache" + "github.com/bartventer/httpcache/store/memcache" ) func mockTransport(fields func(rt *transport)) *transport { @@ -860,3 +861,99 @@ func Test_transport_Vary(t *testing.T) { testutil.AssertEqual(t, tc.wantBody, string(body), i) } } + +// This test verifies that when a cached response is revalidated via a 304 Not +// Modified, the cache entry is updated with any new headers from the 304 +// response, and subsequent requests can HIT the cache again until it becomes +// stale once more. +func Test_transport_RevalidationUpdatesCache(t *testing.T) { + var originCalls atomic.Int32 + + const etag = `"v1"` + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + originCalls.Add(1) + + // Revalidation path: client sends validator, server says cached body is still valid + if r.Header.Get("If-None-Match") == etag { + w.Header().Set("ETag", etag) + w.Header().Set("Cache-Control", "max-age=1") + w.Header().Set("Expires", time.Now().Add(1*time.Second).UTC().Format(http.TimeFormat)) + w.WriteHeader(http.StatusNotModified) + return + } + + // Initial fetch + w.Header().Set("ETag", etag) + w.Header().Set("Cache-Control", "max-age=1") + w.Header().Set("Expires", time.Now().Add(1*time.Second).UTC().Format(http.TimeFormat)) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("hello")) + })) + defer server.Close() + + c := memcache.Open() + tr := newTransport(c) + + req, _ := http.NewRequest(http.MethodGet, server.URL, nil) + + tests := []struct { + name string + expectedStatusCode int + expectedCacheStatus string + expectedBody string + expectedOriginCalls int32 + preReqFunc func() + }{ + { + name: "Initial request should be a MISS", + expectedStatusCode: http.StatusOK, + expectedCacheStatus: internal.CacheStatusMiss.Value, + expectedBody: "hello", + expectedOriginCalls: 1, + }, + { + name: "Second request should be a HIT", + expectedStatusCode: http.StatusOK, + expectedCacheStatus: internal.CacheStatusHit.Value, + expectedBody: "hello", + expectedOriginCalls: 1, + }, + { + name: "After becoming stale, request should be REVALIDATED via 304", + expectedStatusCode: http.StatusOK, + expectedCacheStatus: internal.CacheStatusRevalidated.Value, + expectedBody: "hello", + expectedOriginCalls: 2, + preReqFunc: func() { + time.Sleep(1100 * time.Millisecond) + }, + }, + { + name: "After revalidation, request should be HIT again", + expectedStatusCode: http.StatusOK, + expectedCacheStatus: internal.CacheStatusHit.Value, + expectedBody: "hello", + expectedOriginCalls: 2, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.preReqFunc != nil { + tc.preReqFunc() + } + resp, err := tr.RoundTrip(req) + testutil.RequireNoError(t, err) + testutil.AssertEqual(t, tc.expectedStatusCode, resp.StatusCode) + testutil.AssertEqual( + t, + tc.expectedCacheStatus, + resp.Header.Get(internal.CacheStatusHeader), + ) + body, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + testutil.AssertEqual(t, tc.expectedBody, string(body)) + testutil.AssertEqual(t, tc.expectedOriginCalls, originCalls.Load()) + }) + } +}