diff --git a/internal/handler/download_test.go b/internal/handler/download_test.go new file mode 100644 index 0000000..a51f908 --- /dev/null +++ b/internal/handler/download_test.go @@ -0,0 +1,804 @@ +package handler + +import ( + "database/sql" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/git-pkgs/proxy/internal/database" + "github.com/git-pkgs/proxy/internal/storage" + "github.com/git-pkgs/purl" + "github.com/git-pkgs/registries/fetch" +) + +// seedPackageWithPURL seeds a package using purl.MakePURLString for PURL generation, +// matching how the handlers construct PURLs internally. +func seedPackageWithPURL(t *testing.T, db *database.DB, store *mockStorage, ecosystem, name, version, filename, content string) { + t.Helper() + + pkgPURL := purl.MakePURLString(ecosystem, name, "") + versionPURL := purl.MakePURLString(ecosystem, name, version) + + pkg := &database.Package{ + PURL: pkgPURL, + Ecosystem: ecosystem, + Name: name, + } + if err := db.UpsertPackage(pkg); err != nil { + t.Fatalf("failed to upsert package: %v", err) + } + + ver := &database.Version{ + PURL: versionPURL, + PackagePURL: pkgPURL, + } + if err := db.UpsertVersion(ver); err != nil { + t.Fatalf("failed to upsert version: %v", err) + } + + storagePath := storage.ArtifactPath(ecosystem, "", name, version, filename) + store.files[storagePath] = []byte(content) + + art := &database.Artifact{ + VersionPURL: versionPURL, + Filename: filename, + UpstreamURL: "https://example.com/" + filename, + StoragePath: sql.NullString{String: storagePath, Valid: true}, + ContentHash: sql.NullString{String: "abc123", Valid: true}, + Size: sql.NullInt64{Int64: int64(len(content)), Valid: true}, + ContentType: sql.NullString{String: "application/octet-stream", Valid: true}, + FetchedAt: sql.NullTime{Time: time.Now(), Valid: true}, + } + if err := db.UpsertArtifact(art); err != nil { + t.Fatalf("failed to upsert artifact: %v", err) + } +} + +func TestGemHandler_DownloadCacheHit(t *testing.T) { + proxy, db, store, _ := setupTestProxy(t) + seedPackage(t, db, store, "gem", "rails", "7.1.0", "rails-7.1.0.gem", "gem binary data") + + h := NewGemHandler(proxy, "http://localhost") + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/gems/rails-7.1.0.gem") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + body, _ := io.ReadAll(resp.Body) + if string(body) != "gem binary data" { + t.Errorf("body = %q, want %q", body, "gem binary data") + } +} + +func TestGemHandler_DownloadCacheHitMultiHyphen(t *testing.T) { + proxy, db, store, _ := setupTestProxy(t) + seedPackage(t, db, store, "gem", "aws-sdk-s3", "1.142.0", "aws-sdk-s3-1.142.0.gem", "aws gem") + + h := NewGemHandler(proxy, "http://localhost") + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/gems/aws-sdk-s3-1.142.0.gem") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + body, _ := io.ReadAll(resp.Body) + if string(body) != "aws gem" { + t.Errorf("body = %q, want %q", body, "aws gem") + } +} + +func TestGemHandler_InvalidFilename(t *testing.T) { + proxy, _, _, _ := setupTestProxy(t) + h := NewGemHandler(proxy, "http://localhost") + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + tests := []struct { + path string + code int + }{ + {"/gems/notagem.tar.gz", http.StatusBadRequest}, + {"/gems/noversion.gem", http.StatusBadRequest}, + {"/gems/.gem", http.StatusBadRequest}, + } + + for _, tt := range tests { + resp, err := http.Get(srv.URL + tt.path) + if err != nil { + t.Fatalf("request to %s failed: %v", tt.path, err) + } + resp.Body.Close() + + if resp.StatusCode != tt.code { + t.Errorf("GET %s: status = %d, want %d", tt.path, resp.StatusCode, tt.code) + } + } +} + +func TestGemHandler_UpstreamProxy(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Test", "upstream") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "upstream specs data") + })) + defer upstream.Close() + + proxy, _, _, _ := setupTestProxy(t) + h := &GemHandler{ + proxy: proxy, + upstreamURL: upstream.URL, + proxyURL: "http://localhost", + } + proxy.HTTPClient = upstream.Client() + + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/versions") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + body, _ := io.ReadAll(resp.Body) + if string(body) != "upstream specs data" { + t.Errorf("body = %q, want %q", body, "upstream specs data") + } + if resp.Header.Get("X-Test") != "upstream" { + t.Errorf("missing upstream header") + } +} + +func TestGemHandler_CacheMiss(t *testing.T) { + proxy, _, _, fetcher := setupTestProxy(t) + fetcher.artifact = &fetch.Artifact{ + Body: io.NopCloser(strings.NewReader("fetched gem")), + ContentType: "application/octet-stream", + } + + h := NewGemHandler(proxy, "http://localhost") + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/gems/sinatra-3.0.0.gem") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if !fetcher.fetchCalled { + t.Error("expected fetcher to be called on cache miss") + } +} + +func TestGoHandler_DownloadCacheHit(t *testing.T) { + proxy, db, store, _ := setupTestProxy(t) + seedPackage(t, db, store, "golang", "golang.org/x/text", "v0.14.0", "text@v0.14.0.zip", "go module zip") + + h := NewGoHandler(proxy, "http://localhost") + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/golang.org/x/text/@v/v0.14.0.zip") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + body, _ := io.ReadAll(resp.Body) + if string(body) != "go module zip" { + t.Errorf("body = %q, want %q", body, "go module zip") + } +} + +func TestGoHandler_MethodNotAllowed(t *testing.T) { + proxy, _, _, _ := setupTestProxy(t) + h := NewGoHandler(proxy, "http://localhost") + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Post(srv.URL+"/golang.org/x/text/@v/v0.14.0.zip", "", nil) + if err != nil { + t.Fatalf("request failed: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed) + } +} + +func TestGoHandler_NotFound(t *testing.T) { + proxy, _, _, _ := setupTestProxy(t) + h := NewGoHandler(proxy, "http://localhost") + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/some/unknown/path") + if err != nil { + t.Fatalf("request failed: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusNotFound { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusNotFound) + } +} + +func TestGoHandler_UnknownAtVSuffix(t *testing.T) { + proxy, _, _, _ := setupTestProxy(t) + h := NewGoHandler(proxy, "http://localhost") + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/golang.org/x/text/@v/v0.14.0.unknown") + if err != nil { + t.Fatalf("request failed: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusNotFound { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusNotFound) + } +} + +func TestGoHandler_UpstreamProxy(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "v0.14.0\nv0.13.0\n") + })) + defer upstream.Close() + + proxy, _, _, _ := setupTestProxy(t) + h := &GoHandler{ + proxy: proxy, + upstreamURL: upstream.URL, + proxyURL: "http://localhost", + } + proxy.HTTPClient = upstream.Client() + + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + tests := []string{ + "/golang.org/x/text/@v/list", + "/golang.org/x/text/@v/v0.14.0.info", + "/golang.org/x/text/@v/v0.14.0.mod", + "/golang.org/x/text/@latest", + "/sumdb/sum.golang.org/lookup/golang.org/x/text@v0.14.0", + } + + for _, path := range tests { + resp, err := http.Get(srv.URL + path) + if err != nil { + t.Fatalf("GET %s failed: %v", path, err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("GET %s: status = %d, want %d", path, resp.StatusCode, http.StatusOK) + } + } +} + +func TestGoHandler_CacheMiss(t *testing.T) { + proxy, _, _, fetcher := setupTestProxy(t) + fetcher.artifact = &fetch.Artifact{ + Body: io.NopCloser(strings.NewReader("module zip data")), + ContentType: "application/zip", + } + + h := NewGoHandler(proxy, "http://localhost") + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/example.com/mod/@v/v1.0.0.zip") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if !fetcher.fetchCalled { + t.Error("expected fetcher to be called on cache miss") + } +} + +func TestHexHandler_DownloadCacheHit(t *testing.T) { + proxy, db, store, _ := setupTestProxy(t) + seedPackage(t, db, store, "hex", "phoenix", "1.7.10", "phoenix-1.7.10.tar", "hex tarball") + + h := NewHexHandler(proxy, "http://localhost") + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/tarballs/phoenix-1.7.10.tar") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + body, _ := io.ReadAll(resp.Body) + if string(body) != "hex tarball" { + t.Errorf("body = %q, want %q", body, "hex tarball") + } +} + +func TestHexHandler_InvalidFilename(t *testing.T) { + proxy, _, _, _ := setupTestProxy(t) + h := NewHexHandler(proxy, "http://localhost") + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + tests := []struct { + path string + code int + }{ + {"/tarballs/notatar.zip", http.StatusBadRequest}, + {"/tarballs/noversion.tar", http.StatusBadRequest}, + } + + for _, tt := range tests { + resp, err := http.Get(srv.URL + tt.path) + if err != nil { + t.Fatalf("request to %s failed: %v", tt.path, err) + } + resp.Body.Close() + + if resp.StatusCode != tt.code { + t.Errorf("GET %s: status = %d, want %d", tt.path, resp.StatusCode, tt.code) + } + } +} + +func TestHexHandler_UpstreamProxy(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "hex registry data") + })) + defer upstream.Close() + + proxy, _, _, _ := setupTestProxy(t) + h := &HexHandler{ + proxy: proxy, + upstreamURL: upstream.URL, + proxyURL: "http://localhost", + } + proxy.HTTPClient = upstream.Client() + + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/packages/phoenix") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + body, _ := io.ReadAll(resp.Body) + if string(body) != "hex registry data" { + t.Errorf("body = %q, want %q", body, "hex registry data") + } +} + +func TestHexHandler_CacheMiss(t *testing.T) { + proxy, _, _, fetcher := setupTestProxy(t) + fetcher.artifact = &fetch.Artifact{ + Body: io.NopCloser(strings.NewReader("fetched hex")), + ContentType: "application/x-tar", + } + + h := NewHexHandler(proxy, "http://localhost") + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/tarballs/plug-1.15.0.tar") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if !fetcher.fetchCalled { + t.Error("expected fetcher to be called on cache miss") + } +} + +func TestCondaHandler_DownloadCacheHit(t *testing.T) { + proxy, db, store, _ := setupTestProxy(t) + seedPackageWithPURL(t, db, store, "conda", "main/numpy", "1.24.0", "numpy-1.24.0-py311h64a7726_0.conda", "conda pkg") + + h := NewCondaHandler(proxy, "http://localhost") + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/main/linux-64/numpy-1.24.0-py311h64a7726_0.conda") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + body, _ := io.ReadAll(resp.Body) + if string(body) != "conda pkg" { + t.Errorf("body = %q, want %q", body, "conda pkg") + } +} + +func TestCondaHandler_DownloadTarBz2CacheHit(t *testing.T) { + proxy, db, store, _ := setupTestProxy(t) + seedPackageWithPURL(t, db, store, "conda", "main/scipy", "1.11.0", "scipy-1.11.0-py311hb2e3ea1_0.tar.bz2", "tar bz2 data") + + h := NewCondaHandler(proxy, "http://localhost") + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/main/linux-64/scipy-1.11.0-py311hb2e3ea1_0.tar.bz2") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + body, _ := io.ReadAll(resp.Body) + if string(body) != "tar bz2 data" { + t.Errorf("body = %q, want %q", body, "tar bz2 data") + } +} + +func TestCondaHandler_NonPackageFileProxied(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "repodata json") + })) + defer upstream.Close() + + proxy, _, _, _ := setupTestProxy(t) + h := &CondaHandler{ + proxy: proxy, + upstreamURL: upstream.URL, + proxyURL: "http://localhost", + } + proxy.HTTPClient = upstream.Client() + + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/main/linux-64/repodata.json") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + body, _ := io.ReadAll(resp.Body) + if string(body) != "repodata json" { + t.Errorf("body = %q, want %q", body, "repodata json") + } +} + +func TestCondaHandler_CacheMiss(t *testing.T) { + proxy, _, _, fetcher := setupTestProxy(t) + fetcher.artifact = &fetch.Artifact{ + Body: io.NopCloser(strings.NewReader("fetched conda")), + ContentType: "application/octet-stream", + } + + h := NewCondaHandler(proxy, "http://localhost") + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("should not hit upstream for .conda files when fetcher is set") + })) + defer upstream.Close() + h.upstreamURL = upstream.URL + proxy.HTTPClient = upstream.Client() + + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/conda-forge/linux-64/pandas-2.0.0-py311h320fe9a_0.conda") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if !fetcher.fetchCalled { + t.Error("expected fetcher to be called on cache miss") + } +} + +func TestCRANHandler_SourceDownloadCacheHit(t *testing.T) { + proxy, db, store, _ := setupTestProxy(t) + seedPackageWithPURL(t, db, store, "cran", "ggplot2", "3.4.0", "ggplot2_3.4.0.tar.gz", "cran source") + + h := NewCRANHandler(proxy, "http://localhost") + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/src/contrib/ggplot2_3.4.0.tar.gz") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + body, _ := io.ReadAll(resp.Body) + if string(body) != "cran source" { + t.Errorf("body = %q, want %q", body, "cran source") + } +} + +func TestCRANHandler_BinaryDownloadCacheHit(t *testing.T) { + proxy, db, store, _ := setupTestProxy(t) + seedPackageWithPURL(t, db, store, "cran", "dplyr", "1.1.0_windows_4.3", "dplyr_1.1.0.zip", "cran binary") + + h := NewCRANHandler(proxy, "http://localhost") + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/bin/windows/contrib/4.3/dplyr_1.1.0.zip") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + body, _ := io.ReadAll(resp.Body) + if string(body) != "cran binary" { + t.Errorf("body = %q, want %q", body, "cran binary") + } +} + +func TestCRANHandler_NonPackageFileProxied(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "PACKAGES index") + })) + defer upstream.Close() + + proxy, _, _, _ := setupTestProxy(t) + h := &CRANHandler{ + proxy: proxy, + upstreamURL: upstream.URL, + proxyURL: "http://localhost", + } + proxy.HTTPClient = upstream.Client() + + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/src/contrib/PACKAGES") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + body, _ := io.ReadAll(resp.Body) + if string(body) != "PACKAGES index" { + t.Errorf("body = %q, want %q", body, "PACKAGES index") + } +} + +func TestCRANHandler_SourceNonTarGzProxied(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "some other file") + })) + defer upstream.Close() + + proxy, _, _, _ := setupTestProxy(t) + h := &CRANHandler{ + proxy: proxy, + upstreamURL: upstream.URL, + proxyURL: "http://localhost", + } + proxy.HTTPClient = upstream.Client() + + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/src/contrib/somefile.txt") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } +} + +func TestCRANHandler_CacheMiss(t *testing.T) { + proxy, _, _, fetcher := setupTestProxy(t) + fetcher.artifact = &fetch.Artifact{ + Body: io.NopCloser(strings.NewReader("fetched cran")), + ContentType: "application/x-gzip", + } + + h := NewCRANHandler(proxy, "http://localhost") + h.upstreamURL = "http://should-not-be-reached" + + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/src/contrib/tidyr_1.3.0.tar.gz") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if !fetcher.fetchCalled { + t.Error("expected fetcher to be called on cache miss") + } +} + +func TestMavenHandler_DownloadCacheHit(t *testing.T) { + proxy, db, store, _ := setupTestProxy(t) + seedPackageWithPURL(t, db, store, "maven", "com.google.guava:guava", "32.1.3-jre", "guava-32.1.3-jre.jar", "jar content") + + h := NewMavenHandler(proxy, "http://localhost") + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/com/google/guava/guava/32.1.3-jre/guava-32.1.3-jre.jar") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + body, _ := io.ReadAll(resp.Body) + if string(body) != "jar content" { + t.Errorf("body = %q, want %q", body, "jar content") + } +} + +func TestMavenHandler_MetadataProxied(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "") + })) + defer upstream.Close() + + proxy, _, _, _ := setupTestProxy(t) + h := &MavenHandler{ + proxy: proxy, + upstreamURL: upstream.URL, + proxyURL: "http://localhost", + } + proxy.HTTPClient = upstream.Client() + + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + paths := []string{ + "/com/google/guava/guava/maven-metadata.xml", + "/com/google/guava/guava/32.1.3-jre/guava-32.1.3-jre.jar.sha1", + "/com/google/guava/guava/32.1.3-jre/guava-32.1.3-jre.jar.md5", + } + + for _, path := range paths { + resp, err := http.Get(srv.URL + path) + if err != nil { + t.Fatalf("GET %s failed: %v", path, err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("GET %s: status = %d, want %d", path, resp.StatusCode, http.StatusOK) + } + } +} + +func TestMavenHandler_EmptyPathNotFound(t *testing.T) { + proxy, _, _, _ := setupTestProxy(t) + h := NewMavenHandler(proxy, "http://localhost") + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/") + if err != nil { + t.Fatalf("request failed: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusNotFound { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusNotFound) + } +} + +func TestMavenHandler_ArtifactExtensions(t *testing.T) { + proxy, _, _, fetcher := setupTestProxy(t) + + extensions := []string{".jar", ".war", ".ear", ".pom", ".aar", ".klib"} + for _, ext := range extensions { + fetcher.artifact = &fetch.Artifact{ + Body: io.NopCloser(strings.NewReader("artifact")), + ContentType: "application/java-archive", + } + fetcher.fetchCalled = false + + h := NewMavenHandler(proxy, "http://localhost") + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Errorf("should not proxy artifact file %s to upstream", ext) + })) + h.upstreamURL = upstream.URL + proxy.HTTPClient = upstream.Client() + + srv := httptest.NewServer(h.Routes()) + + path := fmt.Sprintf("/com/example/lib/1.0/lib-1.0%s", ext) + resp, err := http.Get(srv.URL + path) + if err != nil { + t.Fatalf("GET %s failed: %v", path, err) + } + resp.Body.Close() + + if !fetcher.fetchCalled { + t.Errorf("fetcher not called for %s", ext) + } + + srv.Close() + upstream.Close() + } +} + +func TestMavenHandler_CacheMiss(t *testing.T) { + proxy, _, _, fetcher := setupTestProxy(t) + fetcher.artifact = &fetch.Artifact{ + Body: io.NopCloser(strings.NewReader("fetched jar")), + ContentType: "application/java-archive", + } + + h := NewMavenHandler(proxy, "http://localhost") + srv := httptest.NewServer(h.Routes()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/org/apache/commons/commons-lang3/3.14.0/commons-lang3-3.14.0.jar") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if !fetcher.fetchCalled { + t.Error("expected fetcher to be called on cache miss") + } +} diff --git a/internal/server/api_test.go b/internal/server/api_test.go index b11983e..88705a1 100644 --- a/internal/server/api_test.go +++ b/internal/server/api_test.go @@ -13,6 +13,7 @@ import ( "github.com/git-pkgs/proxy/internal/database" "github.com/git-pkgs/proxy/internal/enrichment" + "github.com/go-chi/chi/v5" ) func TestNewAPIHandler(t *testing.T) { @@ -368,3 +369,109 @@ func TestHandleSearch_WithNullValues(t *testing.T) { t.Errorf("expected 3 hits, got %d", result.Hits) } } + +func TestHandlePackagesListAPI(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + svc := enrichment.New(logger) + + dir := t.TempDir() + dbPath := filepath.Join(dir, "test.db") + + db, err := database.Create(dbPath) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + defer func() { _ = db.Close() }() + + // Seed two packages + for _, name := range []string{"api-list-one", "api-list-two"} { + pkg := &database.Package{ + PURL: "pkg:npm/" + name, + Ecosystem: "npm", + Name: name, + } + if err := db.UpsertPackage(pkg); err != nil { + t.Fatalf("UpsertPackage failed: %v", err) + } + ver := &database.Version{ + PURL: "pkg:npm/" + name + "@1.0.0", + PackagePURL: pkg.PURL, + } + if err := db.UpsertVersion(ver); err != nil { + t.Fatalf("UpsertVersion failed: %v", err) + } + art := &database.Artifact{ + VersionPURL: ver.PURL, + Filename: name + "-1.0.0.tgz", + UpstreamURL: "https://registry.npmjs.org/" + name + "/-/" + name + "-1.0.0.tgz", + StoragePath: sql.NullString{String: "/tmp/test.tgz", Valid: true}, + } + if err := db.UpsertArtifact(art); err != nil { + t.Fatalf("UpsertArtifact failed: %v", err) + } + } + + h := NewAPIHandler(svc, db) + + r := chi.NewRouter() + r.Get("/api/packages", h.HandlePackagesList) + + req := httptest.NewRequest("GET", "/api/packages", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + var resp PackagesListResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if len(resp.Results) < 2 { + t.Fatalf("expected at least 2 results, got %d", len(resp.Results)) + } + + if resp.SortBy != "hits" { + t.Errorf("expected default sort by hits, got %q", resp.SortBy) + } + + found := false + for _, pkg := range resp.Results { + if pkg.Name == "api-list-one" || pkg.Name == "api-list-two" { + found = true + break + } + } + if !found { + t.Error("expected seeded packages in results") + } +} + +func TestHandlePackagesListAPI_InvalidSort(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + svc := enrichment.New(logger) + + dir := t.TempDir() + dbPath := filepath.Join(dir, "test.db") + + db, err := database.Create(dbPath) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + defer func() { _ = db.Close() }() + + h := NewAPIHandler(svc, db) + + r := chi.NewRouter() + r.Get("/api/packages", h.HandlePackagesList) + + req := httptest.NewRequest("GET", "/api/packages?sort=invalid", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400 for invalid sort, got %d", w.Code) + } +} diff --git a/internal/server/middleware_test.go b/internal/server/middleware_test.go index 22a1167..75c6ccd 100644 --- a/internal/server/middleware_test.go +++ b/internal/server/middleware_test.go @@ -2,6 +2,8 @@ package server import ( "context" + "io" + "log/slog" "net/http" "net/http/httptest" "testing" @@ -92,3 +94,56 @@ func TestActiveRequestsMiddleware_SkipsMetricsEndpoint(t *testing.T) { t.Errorf("expected status 200, got %d", rec.Code) } } + +func TestLoggerMiddleware(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + s := &Server{logger: logger} + + called := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusCreated) + }) + + handler := s.LoggerMiddleware(next) + + req := httptest.NewRequest(http.MethodGet, "/test-path", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if !called { + t.Error("expected next handler to be called") + } + + if rec.Code != http.StatusCreated { + t.Errorf("expected status 201, got %d", rec.Code) + } +} + +func TestResponseWriter_WriteHeader(t *testing.T) { + tests := []struct { + name string + status int + }{ + {"ok", http.StatusOK}, + {"not found", http.StatusNotFound}, + {"internal error", http.StatusInternalServerError}, + {"created", http.StatusCreated}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + rec := httptest.NewRecorder() + rw := &responseWriter{ResponseWriter: rec, status: http.StatusOK} + rw.WriteHeader(tc.status) + + if rw.status != tc.status { + t.Errorf("expected status %d, got %d", tc.status, rw.status) + } + if rec.Code != tc.status { + t.Errorf("expected underlying recorder status %d, got %d", tc.status, rec.Code) + } + }) + } +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 5a704a1..555e80a 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -3,6 +3,7 @@ package server import ( "database/sql" "encoding/json" + "fmt" "io" "log/slog" "net/http" @@ -11,6 +12,7 @@ import ( "path/filepath" "strings" "testing" + "time" "github.com/git-pkgs/proxy/internal/config" "github.com/git-pkgs/proxy/internal/database" @@ -508,3 +510,377 @@ func TestSearchWithNullValues(t *testing.T) { t.Error("expected search results to contain package name") } } + +func TestFormatTimeAgo_AllRanges(t *testing.T) { + tests := []struct { + name string + input time.Time + expected string + }{ + {"zero time", time.Time{}, ""}, + {"now", time.Now(), "just now"}, + {"30 seconds ago", time.Now().Add(-30 * time.Second), "just now"}, + {"1 minute ago", time.Now().Add(-1 * time.Minute), "1 min ago"}, + {"5 minutes ago", time.Now().Add(-5 * time.Minute), "5 mins ago"}, + {"1 hour ago", time.Now().Add(-1 * time.Hour), "1 hour ago"}, + {"3 hours ago", time.Now().Add(-3 * time.Hour), "3 hours ago"}, + {"1 day ago", time.Now().Add(-24 * time.Hour), "1 day ago"}, + {"3 days ago", time.Now().Add(-3 * 24 * time.Hour), "3 days ago"}, + {"10 days ago", time.Now().Add(-10 * 24 * time.Hour), time.Now().Add(-10 * 24 * time.Hour).Format("Jan 2")}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := formatTimeAgo(tc.input) + if got != tc.expected { + t.Errorf("formatTimeAgo() = %q, want %q", got, tc.expected) + } + }) + } +} + +func TestFormatSize_AllUnits(t *testing.T) { + tests := []struct { + bytes int64 + expected string + }{ + {0, "0 B"}, + {500, "500 B"}, + {1024, "1.0 KB"}, + {1536, "1.5 KB"}, + {1048576, "1.0 MB"}, + {1073741824, "1.0 GB"}, + } + + for _, tc := range tests { + t.Run(tc.expected, func(t *testing.T) { + got := formatSize(tc.bytes) + if got != tc.expected { + t.Errorf("formatSize(%d) = %q, want %q", tc.bytes, got, tc.expected) + } + }) + } +} + +func TestCategorizeLicense_NullString(t *testing.T) { + tests := []struct { + name string + license sql.NullString + expected string + }{ + {"invalid null string", sql.NullString{Valid: false}, "unknown"}, + {"MIT", sql.NullString{String: "MIT", Valid: true}, "permissive"}, + {"GPL-3.0", sql.NullString{String: "GPL-3.0", Valid: true}, "copyleft"}, + {"empty string", sql.NullString{String: "", Valid: true}, "unknown"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := categorizeLicense(tc.license) + if got != tc.expected { + t.Errorf("categorizeLicense(%v) = %q, want %q", tc.license, got, tc.expected) + } + }) + } +} + +func TestSearchRedirectsWhenEmpty(t *testing.T) { + ts := newTestServer(t) + defer ts.close() + + req := httptest.NewRequest("GET", "/search", nil) + w := httptest.NewRecorder() + ts.handler.ServeHTTP(w, req) + + if w.Code != http.StatusSeeOther { + t.Errorf("expected status 303, got %d", w.Code) + } + + loc := w.Header().Get("Location") + if loc != "/" { + t.Errorf("expected redirect to /, got %q", loc) + } +} + +func TestPackageShowPage_NotFoundServer(t *testing.T) { + ts := newTestServer(t) + defer ts.close() + + req := httptest.NewRequest("GET", "/package/npm/nonexistent-srv", nil) + w := httptest.NewRecorder() + ts.handler.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected status 404, got %d", w.Code) + } +} + +func TestVersionShowPage_NotFoundServer(t *testing.T) { + ts := newTestServer(t) + defer ts.close() + + req := httptest.NewRequest("GET", "/package/npm/nonexistent-srv/1.0.0", nil) + w := httptest.NewRecorder() + ts.handler.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected status 404, got %d", w.Code) + } +} + +func TestPackageShowPage_WithLicense(t *testing.T) { + ts := newTestServer(t) + defer ts.close() + + pkg := &database.Package{ + PURL: "pkg:npm/show-test-lic", + Ecosystem: "npm", + Name: "show-test-lic", + License: sql.NullString{String: "MIT", Valid: true}, + } + if err := ts.db.UpsertPackage(pkg); err != nil { + t.Fatalf("failed to upsert package: %v", err) + } + + ver := &database.Version{ + PURL: "pkg:npm/show-test-lic@1.0.0", + PackagePURL: pkg.PURL, + } + if err := ts.db.UpsertVersion(ver); err != nil { + t.Fatalf("failed to upsert version: %v", err) + } + + req := httptest.NewRequest("GET", "/package/npm/show-test-lic", nil) + w := httptest.NewRecorder() + ts.handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + body := w.Body.String() + if !strings.Contains(body, "show-test-lic") { + t.Error("expected page to contain the package name") + } +} + +func TestSearchPage_WithSeededResults(t *testing.T) { + ts := newTestServer(t) + defer ts.close() + + pkg := &database.Package{ + PURL: "pkg:npm/searchable-pkg", + Ecosystem: "npm", + Name: "searchable-pkg", + } + if err := ts.db.UpsertPackage(pkg); err != nil { + t.Fatalf("failed to upsert package: %v", err) + } + + ver := &database.Version{ + PURL: "pkg:npm/searchable-pkg@1.0.0", + PackagePURL: pkg.PURL, + } + if err := ts.db.UpsertVersion(ver); err != nil { + t.Fatalf("failed to upsert version: %v", err) + } + + artifact := &database.Artifact{ + VersionPURL: ver.PURL, + Filename: "searchable-pkg-1.0.0.tgz", + UpstreamURL: "https://registry.npmjs.org/searchable-pkg/-/searchable-pkg-1.0.0.tgz", + StoragePath: sql.NullString{String: "/tmp/test.tgz", Valid: true}, + } + if err := ts.db.UpsertArtifact(artifact); err != nil { + t.Fatalf("failed to upsert artifact: %v", err) + } + + req := httptest.NewRequest("GET", "/search?q=searchable", nil) + w := httptest.NewRecorder() + ts.handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + body := w.Body.String() + if !strings.Contains(body, "searchable-pkg") { + t.Error("expected search results to contain package name") + } +} + +func TestSearchPage_PaginationMultiPage(t *testing.T) { + ts := newTestServer(t) + defer ts.close() + + // Seed 55 packages to exceed one page (limit=50) + for i := 0; i < 55; i++ { + name := fmt.Sprintf("page-test-%03d", i) + pkg := &database.Package{ + PURL: fmt.Sprintf("pkg:npm/%s", name), + Ecosystem: "npm", + Name: name, + } + if err := ts.db.UpsertPackage(pkg); err != nil { + t.Fatalf("failed to upsert package %d: %v", i, err) + } + ver := &database.Version{ + PURL: fmt.Sprintf("pkg:npm/%s@1.0.0", name), + PackagePURL: pkg.PURL, + } + if err := ts.db.UpsertVersion(ver); err != nil { + t.Fatalf("failed to upsert version %d: %v", i, err) + } + artifact := &database.Artifact{ + VersionPURL: ver.PURL, + Filename: fmt.Sprintf("%s-1.0.0.tgz", name), + UpstreamURL: fmt.Sprintf("https://registry.npmjs.org/%s/-/%s-1.0.0.tgz", name, name), + StoragePath: sql.NullString{String: "/tmp/test.tgz", Valid: true}, + } + if err := ts.db.UpsertArtifact(artifact); err != nil { + t.Fatalf("failed to upsert artifact %d: %v", i, err) + } + } + + // First page + req := httptest.NewRequest("GET", "/search?q=page-test", nil) + w := httptest.NewRecorder() + ts.handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + body := w.Body.String() + if !strings.Contains(body, "page-test-") { + t.Error("expected first page to contain results") + } + + // Second page + req = httptest.NewRequest("GET", "/search?q=page-test&page=2", nil) + w = httptest.NewRecorder() + ts.handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200 for page 2, got %d", w.Code) + } +} + +func TestSearchPage_EcosystemFilterWithSeededData(t *testing.T) { + ts := newTestServer(t) + defer ts.close() + + // Seed npm package + npmPkg := &database.Package{ + PURL: "pkg:npm/eco-filter-npm", + Ecosystem: "npm", + Name: "eco-filter-npm", + } + if err := ts.db.UpsertPackage(npmPkg); err != nil { + t.Fatalf("failed to upsert npm package: %v", err) + } + npmVer := &database.Version{ + PURL: "pkg:npm/eco-filter-npm@1.0.0", + PackagePURL: npmPkg.PURL, + } + if err := ts.db.UpsertVersion(npmVer); err != nil { + t.Fatalf("failed to upsert npm version: %v", err) + } + npmArt := &database.Artifact{ + VersionPURL: npmVer.PURL, + Filename: "eco-filter-npm-1.0.0.tgz", + UpstreamURL: "https://registry.npmjs.org/eco-filter-npm/-/eco-filter-npm-1.0.0.tgz", + StoragePath: sql.NullString{String: "/tmp/test.tgz", Valid: true}, + } + if err := ts.db.UpsertArtifact(npmArt); err != nil { + t.Fatalf("failed to upsert npm artifact: %v", err) + } + + // Seed pypi package + pypiPkg := &database.Package{ + PURL: "pkg:pypi/eco-filter-pypi", + Ecosystem: "pypi", + Name: "eco-filter-pypi", + } + if err := ts.db.UpsertPackage(pypiPkg); err != nil { + t.Fatalf("failed to upsert pypi package: %v", err) + } + pypiVer := &database.Version{ + PURL: "pkg:pypi/eco-filter-pypi@1.0.0", + PackagePURL: pypiPkg.PURL, + } + if err := ts.db.UpsertVersion(pypiVer); err != nil { + t.Fatalf("failed to upsert pypi version: %v", err) + } + pypiArt := &database.Artifact{ + VersionPURL: pypiVer.PURL, + Filename: "eco-filter-pypi-1.0.0.tar.gz", + UpstreamURL: "https://files.pythonhosted.org/eco-filter-pypi-1.0.0.tar.gz", + StoragePath: sql.NullString{String: "/tmp/test.tar.gz", Valid: true}, + } + if err := ts.db.UpsertArtifact(pypiArt); err != nil { + t.Fatalf("failed to upsert pypi artifact: %v", err) + } + + // Search with ecosystem filter for npm only + req := httptest.NewRequest("GET", "/search?q=eco-filter&ecosystem=npm", nil) + w := httptest.NewRecorder() + ts.handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + body := w.Body.String() + if !strings.Contains(body, "eco-filter-npm") { + t.Error("expected npm package in filtered results") + } + if strings.Contains(body, "eco-filter-pypi") { + t.Error("did not expect pypi package in npm-filtered results") + } +} + +func TestHandlePackagesListPage(t *testing.T) { + ts := newTestServer(t) + defer ts.close() + + pkg := &database.Package{ + PURL: "pkg:npm/list-test", + Ecosystem: "npm", + Name: "list-test", + } + if err := ts.db.UpsertPackage(pkg); err != nil { + t.Fatalf("failed to upsert package: %v", err) + } + + ver := &database.Version{ + PURL: "pkg:npm/list-test@1.0.0", + PackagePURL: pkg.PURL, + } + if err := ts.db.UpsertVersion(ver); err != nil { + t.Fatalf("failed to upsert version: %v", err) + } + + artifact := &database.Artifact{ + VersionPURL: ver.PURL, + Filename: "list-test-1.0.0.tgz", + UpstreamURL: "https://registry.npmjs.org/list-test/-/list-test-1.0.0.tgz", + StoragePath: sql.NullString{String: "/tmp/test.tgz", Valid: true}, + } + if err := ts.db.UpsertArtifact(artifact); err != nil { + t.Fatalf("failed to upsert artifact: %v", err) + } + + req := httptest.NewRequest("GET", "/packages", nil) + w := httptest.NewRecorder() + ts.handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + body := w.Body.String() + if !strings.Contains(body, "list-test") { + t.Error("expected packages list to contain seeded package") + } +}