diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go
index 5756c6d..4f13ea2 100644
--- a/cmd/proxy/main.go
+++ b/cmd/proxy/main.go
@@ -103,6 +103,8 @@ import (
"github.com/git-pkgs/proxy/internal/server"
)
+const defaultTopN = 10
+
var (
// Version is set at build time.
Version = "dev"
@@ -211,7 +213,7 @@ func runServe() {
cfg.Storage.URL = *storageURL
}
if *storagePath != "" {
- cfg.Storage.Path = *storagePath
+ cfg.Storage.Path = *storagePath //nolint:staticcheck // backwards compat
}
if *databaseDriver != "" {
cfg.Database.Driver = *databaseDriver
@@ -247,7 +249,6 @@ func runServe() {
// Handle shutdown signals
ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
go func() {
sigCh := make(chan os.Signal, 1)
@@ -266,10 +267,12 @@ func runServe() {
// Wait for shutdown or error
select {
case <-ctx.Done():
+ cancel()
if err := srv.Shutdown(context.Background()); err != nil {
logger.Error("shutdown error", "error", err)
}
case err := <-errCh:
+ cancel()
if err != nil {
logger.Error("server error", "error", err)
os.Exit(1)
@@ -283,8 +286,8 @@ func runStats() {
databasePath := fs.String("database-path", "./cache/proxy.db", "Path to SQLite database file")
databaseURL := fs.String("database-url", "", "PostgreSQL connection URL")
asJSON := fs.Bool("json", false, "Output as JSON")
- popular := fs.Int("popular", 10, "Show top N most popular packages")
- recent := fs.Int("recent", 10, "Show N recently cached packages")
+ popular := fs.Int("popular", defaultTopN, "Show top N most popular packages")
+ recent := fs.Int("recent", defaultTopN, "Show N recently cached packages")
fs.Usage = func() {
fmt.Fprintf(os.Stderr, "git-pkgs proxy - Show cache statistics\n\n")
@@ -330,32 +333,37 @@ func runStats() {
fmt.Fprintf(os.Stderr, "error opening database: %v\n", err)
os.Exit(1)
}
+
+ if err := printStats(db, *popular, *recent, *asJSON); err != nil {
+ fmt.Fprintf(os.Stderr, "%v\n", err)
+ os.Exit(1)
+ }
+}
+
+func printStats(db *database.DB, popular, recent int, asJSON bool) error {
defer func() { _ = db.Close() }()
- // Get stats
stats, err := db.GetCacheStats()
if err != nil {
- fmt.Fprintf(os.Stderr, "error getting stats: %v\n", err)
- os.Exit(1)
+ return fmt.Errorf("error getting stats: %w", err)
}
- popularPkgs, err := db.GetMostPopularPackages(*popular)
+ popularPkgs, err := db.GetMostPopularPackages(popular)
if err != nil {
- fmt.Fprintf(os.Stderr, "error getting popular packages: %v\n", err)
- os.Exit(1)
+ return fmt.Errorf("error getting popular packages: %w", err)
}
- recentPkgs, err := db.GetRecentlyCachedPackages(*recent)
+ recentPkgs, err := db.GetRecentlyCachedPackages(recent)
if err != nil {
- fmt.Fprintf(os.Stderr, "error getting recent packages: %v\n", err)
- os.Exit(1)
+ return fmt.Errorf("error getting recent packages: %w", err)
}
- if *asJSON {
+ if asJSON {
outputJSON(stats, popularPkgs, recentPkgs)
} else {
outputText(stats, popularPkgs, recentPkgs)
}
+ return nil
}
type jsonOutput struct {
diff --git a/internal/config/config.go b/internal/config/config.go
index 76067ae..3bc45af 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -109,8 +109,9 @@ type StorageConfig struct {
URL string `json:"url" yaml:"url"`
// Path is the directory where cached artifacts are stored.
- // Deprecated: Use URL with file:// scheme instead.
// If URL is empty, this is used as file://{Path}.
+ //
+ // Deprecated: Use URL with file:// scheme instead.
Path string `json:"path" yaml:"path"`
// MaxSize is the maximum cache size (e.g., "10GB", "500MB").
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
index 47bacf0..4dd1c17 100644
--- a/internal/config/config_test.go
+++ b/internal/config/config_test.go
@@ -6,6 +6,12 @@ import (
"testing"
)
+const (
+ testDriverPostgres = "postgres"
+ testInvalid = "invalid"
+ testLevelDebug = "debug"
+)
+
func TestDefault(t *testing.T) {
cfg := Default()
@@ -63,27 +69,27 @@ func TestValidate(t *testing.T) {
},
{
name: "postgres without url",
- modify: func(c *Config) { c.Database.Driver = "postgres"; c.Database.URL = "" },
+ modify: func(c *Config) { c.Database.Driver = testDriverPostgres; c.Database.URL = "" },
wantErr: true,
},
{
name: "postgres with url",
- modify: func(c *Config) { c.Database.Driver = "postgres"; c.Database.URL = "postgres://localhost/test" },
+ modify: func(c *Config) { c.Database.Driver = testDriverPostgres; c.Database.URL = "postgres://localhost/test" },
wantErr: false,
},
{
name: "invalid log level",
- modify: func(c *Config) { c.Log.Level = "invalid" },
+ modify: func(c *Config) { c.Log.Level = testInvalid },
wantErr: true,
},
{
name: "invalid log format",
- modify: func(c *Config) { c.Log.Format = "invalid" },
+ modify: func(c *Config) { c.Log.Format = testInvalid },
wantErr: true,
},
{
name: "invalid max size",
- modify: func(c *Config) { c.Storage.MaxSize = "invalid" },
+ modify: func(c *Config) { c.Storage.MaxSize = testInvalid },
wantErr: true,
},
{
@@ -176,8 +182,8 @@ log:
if cfg.Storage.MaxSize != "5GB" {
t.Errorf("Storage.MaxSize = %q, want %q", cfg.Storage.MaxSize, "5GB")
}
- if cfg.Log.Level != "debug" {
- t.Errorf("Log.Level = %q, want %q", cfg.Log.Level, "debug")
+ if cfg.Log.Level != testLevelDebug {
+ t.Errorf("Log.Level = %q, want %q", cfg.Log.Level, testLevelDebug)
}
if cfg.Log.Format != "json" {
t.Errorf("Log.Format = %q, want %q", cfg.Log.Format, "json")
@@ -215,7 +221,7 @@ func TestLoadFromEnv(t *testing.T) {
t.Setenv("PROXY_LISTEN", ":9000")
t.Setenv("PROXY_BASE_URL", "https://env.example.com")
t.Setenv("PROXY_STORAGE_PATH", "/env/cache")
- t.Setenv("PROXY_LOG_LEVEL", "debug")
+ t.Setenv("PROXY_LOG_LEVEL", testLevelDebug)
cfg.LoadFromEnv()
@@ -228,8 +234,8 @@ func TestLoadFromEnv(t *testing.T) {
if cfg.Storage.Path != "/env/cache" {
t.Errorf("Storage.Path = %q, want %q", cfg.Storage.Path, "/env/cache")
}
- if cfg.Log.Level != "debug" {
- t.Errorf("Log.Level = %q, want %q", cfg.Log.Level, "debug")
+ if cfg.Log.Level != testLevelDebug {
+ t.Errorf("Log.Level = %q, want %q", cfg.Log.Level, testLevelDebug)
}
}
diff --git a/internal/cooldown/cooldown.go b/internal/cooldown/cooldown.go
index 2bb4fc6..f37a2b9 100644
--- a/internal/cooldown/cooldown.go
+++ b/internal/cooldown/cooldown.go
@@ -7,6 +7,8 @@ import (
"time"
)
+const hoursPerDay = 24
+
// Config holds cooldown settings for version filtering.
// Cooldown hides package versions published too recently, giving the community
// time to spot malicious releases before they're pulled into projects.
@@ -112,7 +114,7 @@ func ParseDuration(s string) (time.Duration, error) {
if err != nil {
return 0, fmt.Errorf("invalid duration %q: %w", s, err)
}
- return time.Duration(days * float64(24*time.Hour)), nil
+ return time.Duration(days * float64(hoursPerDay*time.Hour)), nil
}
d, err := time.ParseDuration(s)
diff --git a/internal/database/database.go b/internal/database/database.go
index c04c122..eded6d2 100644
--- a/internal/database/database.go
+++ b/internal/database/database.go
@@ -12,6 +12,8 @@ import (
const SchemaVersion = 1
+const dirPermissions = 0755
+
type Dialect string
const (
@@ -56,7 +58,7 @@ func Create(path string) (*DB, error) {
func Open(path string) (*DB, error) {
if dir := filepath.Dir(path); dir != "." && dir != "/" {
- if err := os.MkdirAll(dir, 0755); err != nil {
+ if err := os.MkdirAll(dir, dirPermissions); err != nil {
return nil, fmt.Errorf("creating database directory: %w", err)
}
}
diff --git a/internal/database/queries_packages_list_test.go b/internal/database/queries_packages_list_test.go
index b77d34d..fc9356c 100644
--- a/internal/database/queries_packages_list_test.go
+++ b/internal/database/queries_packages_list_test.go
@@ -6,118 +6,121 @@ import (
"time"
)
-func TestListCachedPackages(t *testing.T) {
+const testEcosystemNPM = "npm"
+
+func setupListCachedPackagesDB(t *testing.T) *DB {
+ t.Helper()
+
db, err := Create(t.TempDir() + "/test.db")
if err != nil {
t.Fatal(err)
}
- defer func() { _ = db.Close() }()
- // Create test packages
- pkg1 := &Package{
- PURL: "pkg:npm/lodash",
- Ecosystem: "npm",
- Name: "lodash",
- LatestVersion: sql.NullString{String: "4.17.21", Valid: true},
- License: sql.NullString{String: "MIT", Valid: true},
- }
- pkg2 := &Package{
- PURL: "pkg:cargo/serde",
- Ecosystem: "cargo",
- Name: "serde",
- LatestVersion: sql.NullString{String: "1.0.0", Valid: true},
- License: sql.NullString{String: "MIT OR Apache-2.0", Valid: true},
- }
- pkg3 := &Package{
- PURL: "pkg:npm/react",
- Ecosystem: "npm",
- Name: "react",
- LatestVersion: sql.NullString{String: "18.0.0", Valid: true},
- License: sql.NullString{String: "MIT", Valid: true},
- }
+ seedListCachedPackagesData(t, db)
- if err := db.UpsertPackage(pkg1); err != nil {
- t.Fatal(err)
- }
- if err := db.UpsertPackage(pkg2); err != nil {
- t.Fatal(err)
- }
- if err := db.UpsertPackage(pkg3); err != nil {
- t.Fatal(err)
- }
+ return db
+}
- // Create versions
- ver1 := &Version{
- PURL: "pkg:npm/lodash@4.17.21",
- PackagePURL: pkg1.PURL,
- }
- ver2 := &Version{
- PURL: "pkg:cargo/serde@1.0.0",
- PackagePURL: pkg2.PURL,
- }
- ver3 := &Version{
- PURL: "pkg:npm/react@18.0.0",
- PackagePURL: pkg3.PURL,
- }
+func seedListCachedPackagesData(t *testing.T, db *DB) {
+ t.Helper()
- if err := db.UpsertVersion(ver1); err != nil {
- t.Fatal(err)
- }
- if err := db.UpsertVersion(ver2); err != nil {
- t.Fatal(err)
- }
- if err := db.UpsertVersion(ver3); err != nil {
- t.Fatal(err)
+ packages := []*Package{
+ {
+ PURL: "pkg:npm/lodash",
+ Ecosystem: testEcosystemNPM,
+ Name: "lodash",
+ LatestVersion: sql.NullString{String: "4.17.21", Valid: true},
+ License: sql.NullString{String: "MIT", Valid: true},
+ },
+ {
+ PURL: "pkg:cargo/serde",
+ Ecosystem: "cargo",
+ Name: "serde",
+ LatestVersion: sql.NullString{String: "1.0.0", Valid: true},
+ License: sql.NullString{String: "MIT OR Apache-2.0", Valid: true},
+ },
+ {
+ PURL: "pkg:npm/react",
+ Ecosystem: testEcosystemNPM,
+ Name: "react",
+ LatestVersion: sql.NullString{String: "18.0.0", Valid: true},
+ License: sql.NullString{String: "MIT", Valid: true},
+ },
}
- // Create artifacts
- art1 := &Artifact{
- VersionPURL: ver1.PURL,
- Filename: "lodash.tgz",
- UpstreamURL: "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz",
- StoragePath: sql.NullString{String: "npm/lodash/4.17.21/lodash.tgz", Valid: true},
- Size: sql.NullInt64{Int64: 1024, Valid: true},
- HitCount: 100,
- FetchedAt: sql.NullTime{Time: time.Now(), Valid: true},
- }
- art2 := &Artifact{
- VersionPURL: ver2.PURL,
- Filename: "serde.crate",
- UpstreamURL: "https://crates.io/api/v1/crates/serde/1.0.0/download",
- StoragePath: sql.NullString{String: "cargo/serde/1.0.0/serde.crate", Valid: true},
- Size: sql.NullInt64{Int64: 2048, Valid: true},
- HitCount: 50,
- FetchedAt: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true},
+ for _, pkg := range packages {
+ if err := db.UpsertPackage(pkg); err != nil {
+ t.Fatal(err)
+ }
}
- art3 := &Artifact{
- VersionPURL: ver3.PURL,
- Filename: "react.tgz",
- UpstreamURL: "https://registry.npmjs.org/react/-/react-18.0.0.tgz",
- StoragePath: sql.NullString{String: "npm/react/18.0.0/react.tgz", Valid: true},
- Size: sql.NullInt64{Int64: 512, Valid: true},
- HitCount: 200,
- FetchedAt: sql.NullTime{Time: time.Now().Add(-2 * time.Hour), Valid: true},
+
+ versions := []*Version{
+ {PURL: "pkg:npm/lodash@4.17.21", PackagePURL: packages[0].PURL},
+ {PURL: "pkg:cargo/serde@1.0.0", PackagePURL: packages[1].PURL},
+ {PURL: "pkg:npm/react@18.0.0", PackagePURL: packages[2].PURL},
}
- if err := db.UpsertArtifact(art1); err != nil {
- t.Fatal(err)
+ for _, ver := range versions {
+ if err := db.UpsertVersion(ver); err != nil {
+ t.Fatal(err)
+ }
}
- if err := db.UpsertArtifact(art2); err != nil {
- t.Fatal(err)
+
+ artifacts := []*Artifact{
+ {
+ VersionPURL: versions[0].PURL,
+ Filename: "lodash.tgz",
+ UpstreamURL: "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz",
+ StoragePath: sql.NullString{String: "npm/lodash/4.17.21/lodash.tgz", Valid: true},
+ Size: sql.NullInt64{Int64: 1024, Valid: true},
+ HitCount: 100,
+ FetchedAt: sql.NullTime{Time: time.Now(), Valid: true},
+ },
+ {
+ VersionPURL: versions[1].PURL,
+ Filename: "serde.crate",
+ UpstreamURL: "https://crates.io/api/v1/crates/serde/1.0.0/download",
+ StoragePath: sql.NullString{String: "cargo/serde/1.0.0/serde.crate", Valid: true},
+ Size: sql.NullInt64{Int64: 2048, Valid: true},
+ HitCount: 50,
+ FetchedAt: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true},
+ },
+ {
+ VersionPURL: versions[2].PURL,
+ Filename: "react.tgz",
+ UpstreamURL: "https://registry.npmjs.org/react/-/react-18.0.0.tgz",
+ StoragePath: sql.NullString{String: "npm/react/18.0.0/react.tgz", Valid: true},
+ Size: sql.NullInt64{Int64: 512, Valid: true},
+ HitCount: 200,
+ FetchedAt: sql.NullTime{Time: time.Now().Add(-2 * time.Hour), Valid: true},
+ },
}
- if err := db.UpsertArtifact(art3); err != nil {
- t.Fatal(err)
+
+ for _, art := range artifacts {
+ if err := db.UpsertArtifact(art); err != nil {
+ t.Fatal(err)
+ }
}
+}
- t.Run("list all packages", func(t *testing.T) {
- packages, err := db.ListCachedPackages("", "hits", 10, 0)
+func TestListCachedPackages(t *testing.T) {
+ db := setupListCachedPackagesDB(t)
+ defer func() { _ = db.Close() }()
+
+ listAll := func(ecosystem, sortBy string) []PackageListItem {
+ t.Helper()
+ packages, err := db.ListCachedPackages(ecosystem, sortBy, 10, 0)
if err != nil {
t.Fatal(err)
}
+ return packages
+ }
+
+ t.Run("list all packages", func(t *testing.T) {
+ packages := listAll("", "hits")
if len(packages) != 3 {
t.Errorf("expected 3 packages, got %d", len(packages))
}
- // Should be sorted by hits DESC
if packages[0].Name != "react" {
t.Errorf("expected first package to be react, got %s", packages[0].Name)
}
@@ -127,35 +130,26 @@ func TestListCachedPackages(t *testing.T) {
})
t.Run("filter by ecosystem", func(t *testing.T) {
- packages, err := db.ListCachedPackages("npm", "hits", 10, 0)
- if err != nil {
- t.Fatal(err)
- }
+ packages := listAll(testEcosystemNPM, "hits")
if len(packages) != 2 {
t.Errorf("expected 2 npm packages, got %d", len(packages))
}
for _, pkg := range packages {
- if pkg.Ecosystem != "npm" {
+ if pkg.Ecosystem != testEcosystemNPM {
t.Errorf("expected npm ecosystem, got %s", pkg.Ecosystem)
}
}
})
t.Run("sort by name", func(t *testing.T) {
- packages, err := db.ListCachedPackages("", "name", 10, 0)
- if err != nil {
- t.Fatal(err)
- }
+ packages := listAll("", "name")
if packages[0].Name != "lodash" {
t.Errorf("expected first package to be lodash, got %s", packages[0].Name)
}
})
t.Run("sort by size", func(t *testing.T) {
- packages, err := db.ListCachedPackages("", "size", 10, 0)
- if err != nil {
- t.Fatal(err)
- }
+ packages := listAll("", "size")
if packages[0].Name != "serde" {
t.Errorf("expected first package to be serde (largest), got %s", packages[0].Name)
}
@@ -170,7 +164,7 @@ func TestListCachedPackages(t *testing.T) {
t.Errorf("expected count 3, got %d", count)
}
- count, err = db.CountCachedPackages("npm")
+ count, err = db.CountCachedPackages(testEcosystemNPM)
if err != nil {
t.Fatal(err)
}
diff --git a/internal/database/schema.go b/internal/database/schema.go
index 13363e7..496a129 100644
--- a/internal/database/schema.go
+++ b/internal/database/schema.go
@@ -2,6 +2,8 @@ package database
import "fmt"
+const postgresTimestamp = "TIMESTAMP"
+
// Schema for proxy-specific tables. The packages and versions tables
// are compatible with git-pkgs, allowing the proxy to use an existing
// git-pkgs database as a starting point.
@@ -303,8 +305,8 @@ func (db *DB) MigrateSchema() error {
}
if db.dialect == DialectPostgres {
- packagesColumns["enriched_at"] = "TIMESTAMP"
- packagesColumns["vulns_synced_at"] = "TIMESTAMP"
+ packagesColumns["enriched_at"] = postgresTimestamp
+ packagesColumns["vulns_synced_at"] = postgresTimestamp
}
for column, colType := range packagesColumns {
@@ -313,12 +315,7 @@ func (db *DB) MigrateSchema() error {
return fmt.Errorf("checking column %s: %w", column, err)
}
if !hasCol {
- var alterQuery string
- if db.dialect == DialectPostgres {
- alterQuery = fmt.Sprintf("ALTER TABLE packages ADD COLUMN %s %s", column, colType)
- } else {
- alterQuery = fmt.Sprintf("ALTER TABLE packages ADD COLUMN %s %s", column, colType)
- }
+ alterQuery := fmt.Sprintf("ALTER TABLE packages ADD COLUMN %s %s", column, colType)
if _, err := db.Exec(alterQuery); err != nil {
return fmt.Errorf("adding column %s to packages: %w", column, err)
}
@@ -335,7 +332,7 @@ func (db *DB) MigrateSchema() error {
if db.dialect == DialectPostgres {
versionsColumns["yanked"] = "BOOLEAN DEFAULT FALSE"
- versionsColumns["enriched_at"] = "TIMESTAMP"
+ versionsColumns["enriched_at"] = postgresTimestamp
}
for column, colType := range versionsColumns {
@@ -344,12 +341,7 @@ func (db *DB) MigrateSchema() error {
return fmt.Errorf("checking column %s: %w", column, err)
}
if !hasCol {
- var alterQuery string
- if db.dialect == DialectPostgres {
- alterQuery = fmt.Sprintf("ALTER TABLE versions ADD COLUMN %s %s", column, colType)
- } else {
- alterQuery = fmt.Sprintf("ALTER TABLE versions ADD COLUMN %s %s", column, colType)
- }
+ alterQuery := fmt.Sprintf("ALTER TABLE versions ADD COLUMN %s %s", column, colType)
if _, err := db.Exec(alterQuery); err != nil {
return fmt.Errorf("adding column %s to versions: %w", column, err)
}
diff --git a/internal/handler/cargo.go b/internal/handler/cargo.go
index 8aab5f0..b5a3fb0 100644
--- a/internal/handler/cargo.go
+++ b/internal/handler/cargo.go
@@ -11,6 +11,10 @@ import (
const (
cargoUpstream = "https://index.crates.io"
cargoDownloadBase = "https://static.crates.io/crates"
+
+ cargoIndexLen1 = 1
+ cargoIndexLen2 = 2
+ cargoIndexLen3 = 3
)
// CargoHandler handles cargo registry protocol requests.
@@ -125,11 +129,11 @@ func (h *CargoHandler) buildIndexPath(name string) string {
name = strings.ToLower(name)
switch len(name) {
- case 1:
+ case cargoIndexLen1:
return fmt.Sprintf("1/%s", name)
- case 2:
+ case cargoIndexLen2:
return fmt.Sprintf("2/%s", name)
- case 3:
+ case cargoIndexLen3:
return fmt.Sprintf("3/%c/%s", name[0], name)
default:
return fmt.Sprintf("%s/%s/%s", name[0:2], name[2:4], name)
diff --git a/internal/handler/composer.go b/internal/handler/composer.go
index 2753e4a..3b76a9c 100644
--- a/internal/handler/composer.go
+++ b/internal/handler/composer.go
@@ -12,8 +12,9 @@ import (
)
const (
- composerUpstream = "https://packagist.org"
- composerRepo = "https://repo.packagist.org"
+ composerUpstream = "https://packagist.org"
+ composerRepo = "https://repo.packagist.org"
+ vendorPackageParts = 2
)
// ComposerHandler handles Composer/Packagist registry protocol requests.
@@ -74,8 +75,8 @@ func (h *ComposerHandler) handlePackageMetadata(w http.ResponseWriter, r *http.R
// Parse path: /p2/{vendor}/{package}.json
path := strings.TrimPrefix(r.URL.Path, "/p2/")
path = strings.TrimSuffix(path, ".json")
- parts := strings.SplitN(path, "/", 2)
- if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
+ parts := strings.SplitN(path, "/", vendorPackageParts)
+ if len(parts) != vendorPackageParts || parts[0] == "" || parts[1] == "" {
http.Error(w, "invalid package path", http.StatusBadRequest)
return
}
@@ -145,58 +146,85 @@ func (h *ComposerHandler) rewriteMetadata(body []byte) ([]byte, error) {
continue
}
- packagePURL := purl.MakePURLString("composer", packageName, "")
-
- filtered := versionList[:0]
- for _, v := range versionList {
- vmap, ok := v.(map[string]any)
- if !ok {
- continue
- }
+ packages[packageName] = h.filterAndRewriteVersions(packageName, versionList)
+ }
- version, _ := vmap["version"].(string)
-
- // Apply cooldown filtering
- if h.proxy.Cooldown != nil && h.proxy.Cooldown.Enabled() {
- if timeStr, ok := vmap["time"].(string); ok {
- if publishedAt, err := time.Parse(time.RFC3339, timeStr); err == nil {
- if !h.proxy.Cooldown.IsAllowed("composer", packagePURL, publishedAt) {
- h.proxy.Logger.Info("cooldown: filtering composer version",
- "package", packageName, "version", version)
- continue
- }
- }
- }
- }
+ return json.Marshal(metadata)
+}
- dist, ok := vmap["dist"].(map[string]any)
- if !ok {
- filtered = append(filtered, v)
- continue
- }
+// filterAndRewriteVersions applies cooldown filtering and rewrites dist URLs
+// for a single package's version list.
+func (h *ComposerHandler) filterAndRewriteVersions(packageName string, versionList []any) []any {
+ packagePURL := purl.MakePURLString("composer", packageName, "")
- // Rewrite the dist URL
- if url, ok := dist["url"].(string); ok && url != "" {
- filename := "package.zip"
- if idx := strings.LastIndex(url, "/"); idx >= 0 {
- filename = url[idx+1:]
- }
+ filtered := versionList[:0]
+ for _, v := range versionList {
+ vmap, ok := v.(map[string]any)
+ if !ok {
+ continue
+ }
- parts := strings.SplitN(packageName, "/", 2)
- if len(parts) == 2 {
- newURL := fmt.Sprintf("%s/composer/files/%s/%s/%s/%s",
- h.proxyURL, parts[0], parts[1], version, filename)
- dist["url"] = newURL
- }
- }
+ version, _ := vmap["version"].(string)
- filtered = append(filtered, v)
+ if h.shouldFilterVersion(packagePURL, packageName, version, vmap) {
+ continue
}
- packages[packageName] = filtered
+ h.rewriteDistURL(vmap, packageName, version)
+ filtered = append(filtered, v)
}
- return json.Marshal(metadata)
+ return filtered
+}
+
+// shouldFilterVersion returns true if the version should be excluded due to cooldown.
+func (h *ComposerHandler) shouldFilterVersion(packagePURL, packageName, version string, vmap map[string]any) bool {
+ if h.proxy.Cooldown == nil || !h.proxy.Cooldown.Enabled() {
+ return false
+ }
+
+ timeStr, ok := vmap["time"].(string)
+ if !ok {
+ return false
+ }
+
+ publishedAt, err := time.Parse(time.RFC3339, timeStr)
+ if err != nil {
+ return false
+ }
+
+ if !h.proxy.Cooldown.IsAllowed("composer", packagePURL, publishedAt) {
+ h.proxy.Logger.Info("cooldown: filtering composer version",
+ "package", packageName, "version", version)
+ return true
+ }
+
+ return false
+}
+
+// rewriteDistURL rewrites the dist URL in a version entry to point at this proxy.
+func (h *ComposerHandler) rewriteDistURL(vmap map[string]any, packageName, version string) {
+ dist, ok := vmap["dist"].(map[string]any)
+ if !ok {
+ return
+ }
+
+ url, ok := dist["url"].(string)
+ if !ok || url == "" {
+ return
+ }
+
+ filename := "package.zip"
+ if idx := strings.LastIndex(url, "/"); idx >= 0 {
+ filename = url[idx+1:]
+ }
+
+ parts := strings.SplitN(packageName, "/", vendorPackageParts)
+ if len(parts) == vendorPackageParts {
+ newURL := fmt.Sprintf("%s/composer/files/%s/%s/%s/%s",
+ h.proxyURL, parts[0], parts[1], version, filename)
+ dist["url"] = newURL
+ }
}
// handleDownload serves a package file, fetching and caching from upstream if needed.
diff --git a/internal/handler/conan_test.go b/internal/handler/conan_test.go
index a0bf1da..a7bd362 100644
--- a/internal/handler/conan_test.go
+++ b/internal/handler/conan_test.go
@@ -9,6 +9,8 @@ import (
"testing"
)
+const testProxyURL = "http://localhost:8080"
+
func conanTestProxy() *Proxy {
return &Proxy{
Logger: slog.Default(),
@@ -44,7 +46,7 @@ func TestConanShouldCacheFile(t *testing.T) {
func TestConanPingV1(t *testing.T) {
h := &ConanHandler{
proxy: conanTestProxy(),
- proxyURL: "http://localhost:8080",
+ proxyURL: testProxyURL,
}
req := httptest.NewRequest(http.MethodGet, "/v1/ping", nil)
@@ -65,7 +67,7 @@ func TestConanPingV1(t *testing.T) {
func TestConanPingV2(t *testing.T) {
h := &ConanHandler{
proxy: conanTestProxy(),
- proxyURL: "http://localhost:8080",
+ proxyURL: testProxyURL,
}
req := httptest.NewRequest(http.MethodGet, "/v2/ping", nil)
@@ -372,17 +374,17 @@ func TestNewConanHandler(t *testing.T) {
if h.upstreamURL != conanUpstream {
t.Errorf("upstreamURL = %q, want %q", h.upstreamURL, conanUpstream)
}
- if h.proxyURL != "http://localhost:8080" {
- t.Errorf("proxyURL = %q, want %q (trailing slash should be trimmed)", h.proxyURL, "http://localhost:8080")
+ if h.proxyURL != testProxyURL {
+ t.Errorf("proxyURL = %q, want %q (trailing slash should be trimmed)", h.proxyURL, testProxyURL)
}
}
func TestNewConanHandlerNoTrailingSlash(t *testing.T) {
proxy := conanTestProxy()
- h := NewConanHandler(proxy, "http://localhost:8080")
+ h := NewConanHandler(proxy, testProxyURL)
- if h.proxyURL != "http://localhost:8080" {
- t.Errorf("proxyURL = %q, want %q", h.proxyURL, "http://localhost:8080")
+ if h.proxyURL != testProxyURL {
+ t.Errorf("proxyURL = %q, want %q", h.proxyURL, testProxyURL)
}
}
diff --git a/internal/handler/conda.go b/internal/handler/conda.go
index 303ec0d..f929194 100644
--- a/internal/handler/conda.go
+++ b/internal/handler/conda.go
@@ -1,13 +1,13 @@
package handler
import (
- "io"
"net/http"
"strings"
)
const (
condaUpstream = "https://conda.anaconda.org"
+ minCondaParts = 3 // name-version-build requires at least 3 hyphen-separated parts
)
// CondaHandler handles Conda/Anaconda registry protocol requests.
@@ -98,7 +98,7 @@ func (h *CondaHandler) parseFilename(filename string) (name, version string) {
// Split by hyphens, the format is name-version-build
// The name can contain hyphens, so we need to find version-build at the end
parts := strings.Split(base, "-")
- if len(parts) < 3 {
+ if len(parts) < minCondaParts {
return "", ""
}
@@ -121,35 +121,5 @@ func (h *CondaHandler) parseFilename(filename string) (name, version string) {
// proxyUpstream forwards a request to Anaconda without caching.
func (h *CondaHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) {
- upstreamURL := h.upstreamURL + r.URL.Path
-
- h.proxy.Logger.Debug("proxying to upstream", "url", upstreamURL)
-
- req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil)
- if err != nil {
- http.Error(w, "failed to create request", http.StatusInternalServerError)
- return
- }
-
- // Copy accept-encoding for compression
- if ae := r.Header.Get("Accept-Encoding"); ae != "" {
- req.Header.Set("Accept-Encoding", ae)
- }
-
- resp, err := h.proxy.HTTPClient.Do(req)
- if err != nil {
- h.proxy.Logger.Error("upstream request failed", "error", err)
- http.Error(w, "upstream request failed", http.StatusBadGateway)
- return
- }
- defer func() { _ = resp.Body.Close() }()
-
- for k, vv := range resp.Header {
- for _, v := range vv {
- w.Header().Add(k, v)
- }
- }
-
- w.WriteHeader(resp.StatusCode)
- _, _ = io.Copy(w, resp.Body)
+ h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, []string{"Accept-Encoding"})
}
diff --git a/internal/handler/container.go b/internal/handler/container.go
index 0669b12..fc5f98c 100644
--- a/internal/handler/container.go
+++ b/internal/handler/container.go
@@ -10,8 +10,11 @@ import (
)
const (
- dockerHubRegistry = "https://registry-1.docker.io"
- dockerHubAuth = "https://auth.docker.io"
+ dockerHubRegistry = "https://registry-1.docker.io"
+ dockerHubAuth = "https://auth.docker.io"
+ blobMatchCount = 3 // full match + name + digest
+ manifestMatchCount = 3 // full match + name + reference
+ tagsListMatchCount = 2 // full match + name
)
// ContainerHandler handles OCI/Docker container registry protocol requests.
@@ -347,7 +350,7 @@ var blobPathPattern = regexp.MustCompile(`^(.+)/blobs/(sha256:[a-f0-9]+)$`)
// parseBlobPath extracts repository name and digest from a blob path.
func (h *ContainerHandler) parseBlobPath(path string) (name, digest string) {
matches := blobPathPattern.FindStringSubmatch(path)
- if len(matches) != 3 {
+ if len(matches) != blobMatchCount {
return "", ""
}
return matches[1], matches[2]
@@ -359,7 +362,7 @@ var manifestPathPattern = regexp.MustCompile(`^(.+)/manifests/(.+)$`)
// parseManifestPath extracts repository name and reference from a manifest path.
func (h *ContainerHandler) parseManifestPath(path string) (name, reference string) {
matches := manifestPathPattern.FindStringSubmatch(path)
- if len(matches) != 3 {
+ if len(matches) != manifestMatchCount {
return "", ""
}
return matches[1], matches[2]
@@ -371,7 +374,7 @@ var tagsListPathPattern = regexp.MustCompile(`^(.+)/tags/list$`)
// parseTagsListPath extracts repository name from a tags list path.
func (h *ContainerHandler) parseTagsListPath(path string) string {
matches := tagsListPathPattern.FindStringSubmatch(path)
- if len(matches) != 2 {
+ if len(matches) != tagsListMatchCount {
return ""
}
return matches[1]
diff --git a/internal/handler/cran.go b/internal/handler/cran.go
index 4702473..b8c3c3a 100644
--- a/internal/handler/cran.go
+++ b/internal/handler/cran.go
@@ -1,7 +1,6 @@
package handler
import (
- "io"
"net/http"
"strings"
)
@@ -153,34 +152,5 @@ func (h *CRANHandler) isBinaryPackage(filename string) bool {
// proxyUpstream forwards a request to CRAN without caching.
func (h *CRANHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) {
- upstreamURL := h.upstreamURL + r.URL.Path
-
- h.proxy.Logger.Debug("proxying to upstream", "url", upstreamURL)
-
- req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil)
- if err != nil {
- http.Error(w, "failed to create request", http.StatusInternalServerError)
- return
- }
-
- if ae := r.Header.Get("Accept-Encoding"); ae != "" {
- req.Header.Set("Accept-Encoding", ae)
- }
-
- resp, err := h.proxy.HTTPClient.Do(req)
- if err != nil {
- h.proxy.Logger.Error("upstream request failed", "error", err)
- http.Error(w, "upstream request failed", http.StatusBadGateway)
- return
- }
- defer func() { _ = resp.Body.Close() }()
-
- for k, vv := range resp.Header {
- for _, v := range vv {
- w.Header().Add(k, v)
- }
- }
-
- w.WriteHeader(resp.StatusCode)
- _, _ = io.Copy(w, resp.Body)
+ h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, []string{"Accept-Encoding"})
}
diff --git a/internal/handler/debian.go b/internal/handler/debian.go
index bada1af..11a1979 100644
--- a/internal/handler/debian.go
+++ b/internal/handler/debian.go
@@ -2,7 +2,6 @@ package handler
import (
"fmt"
- "io"
"net/http"
"regexp"
"strings"
@@ -10,6 +9,7 @@ import (
const (
debianUpstream = "http://deb.debian.org/debian"
+ debMatchCount = 4 // full match + name + version + arch
)
// DebianHandler handles APT/Debian repository protocol requests.
@@ -93,67 +93,12 @@ func (h *DebianHandler) handlePackageDownload(w http.ResponseWriter, r *http.Req
// handleMetadata proxies repository metadata files.
// These change frequently so we don't cache them.
func (h *DebianHandler) handleMetadata(w http.ResponseWriter, r *http.Request, path string) {
- upstreamURL := fmt.Sprintf("%s/%s", h.upstreamURL, path)
-
- h.proxy.Logger.Debug("debian metadata request", "path", path)
-
- req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, nil)
- if err != nil {
- http.Error(w, "failed to create request", http.StatusInternalServerError)
- return
- }
-
- // Forward relevant headers
- for _, header := range []string{"Accept", "Accept-Encoding", "If-Modified-Since", "If-None-Match"} {
- if v := r.Header.Get(header); v != "" {
- req.Header.Set(header, v)
- }
- }
-
- resp, err := h.proxy.HTTPClient.Do(req)
- if err != nil {
- h.proxy.Logger.Error("failed to fetch upstream metadata", "error", err)
- http.Error(w, "failed to fetch from upstream", http.StatusBadGateway)
- return
- }
- defer func() { _ = resp.Body.Close() }()
-
- // Copy response headers
- for _, header := range []string{"Content-Type", "Content-Length", "Last-Modified", "ETag"} {
- if v := resp.Header.Get(header); v != "" {
- w.Header().Set(header, v)
- }
- }
-
- w.WriteHeader(resp.StatusCode)
- _, _ = io.Copy(w, resp.Body)
+ h.proxy.ProxyMetadata(w, r, fmt.Sprintf("%s/%s", h.upstreamURL, path), "debian")
}
// proxyFile proxies any file directly without caching.
func (h *DebianHandler) proxyFile(w http.ResponseWriter, r *http.Request, path string) {
- upstreamURL := fmt.Sprintf("%s/%s", h.upstreamURL, path)
-
- req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, nil)
- if err != nil {
- http.Error(w, "failed to create request", http.StatusInternalServerError)
- return
- }
-
- resp, err := h.proxy.HTTPClient.Do(req)
- if err != nil {
- http.Error(w, "failed to fetch from upstream", http.StatusBadGateway)
- return
- }
- defer func() { _ = resp.Body.Close() }()
-
- for key, values := range resp.Header {
- for _, v := range values {
- w.Header().Add(key, v)
- }
- }
-
- w.WriteHeader(resp.StatusCode)
- _, _ = io.Copy(w, resp.Body)
+ h.proxy.ProxyFile(w, r, fmt.Sprintf("%s/%s", h.upstreamURL, path))
}
// debPackagePattern matches .deb filenames to extract name, version, and arch.
@@ -172,7 +117,7 @@ func (h *DebianHandler) parsePoolPath(path string) (name, version, arch string)
// Parse the filename
matches := debPackagePattern.FindStringSubmatch(filename)
- if len(matches) != 4 {
+ if len(matches) != debMatchCount {
return "", "", ""
}
diff --git a/internal/handler/debian_test.go b/internal/handler/debian_test.go
index 77d6843..dfdd326 100644
--- a/internal/handler/debian_test.go
+++ b/internal/handler/debian_test.go
@@ -1,98 +1,23 @@
package handler
import (
- "net/http"
- "net/http/httptest"
"testing"
)
func TestDebianHandler_parsePoolPath(t *testing.T) {
h := &DebianHandler{}
- tests := []struct {
- path string
- wantName string
- wantVersion string
- wantArch string
- }{
- {
- path: "pool/main/n/nginx/nginx_1.18.0-6_amd64.deb",
- wantName: "nginx",
- wantVersion: "1.18.0-6",
- wantArch: "amd64",
- },
- {
- path: "pool/main/libn/libncurses/libncurses6_6.2-1_amd64.deb",
- wantName: "libncurses6",
- wantVersion: "6.2-1",
- wantArch: "amd64",
- },
- {
- path: "pool/contrib/v/virtualbox/virtualbox_6.1.38-1_amd64.deb",
- wantName: "virtualbox",
- wantVersion: "6.1.38-1",
- wantArch: "amd64",
- },
- {
- path: "pool/main/g/git/git_2.39.2-1_arm64.deb",
- wantName: "git",
- wantVersion: "2.39.2-1",
- wantArch: "arm64",
- },
- {
- path: "invalid/path",
- wantName: "",
- wantVersion: "",
- wantArch: "",
- },
- {
- path: "pool/main/n/nginx/nginx.deb",
- wantName: "",
- wantVersion: "",
- wantArch: "",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.path, func(t *testing.T) {
- name, version, arch := h.parsePoolPath(tt.path)
- if name != tt.wantName {
- t.Errorf("parsePoolPath() name = %q, want %q", name, tt.wantName)
- }
- if version != tt.wantVersion {
- t.Errorf("parsePoolPath() version = %q, want %q", version, tt.wantVersion)
- }
- if arch != tt.wantArch {
- t.Errorf("parsePoolPath() arch = %q, want %q", arch, tt.wantArch)
- }
- })
- }
+ assertPathParser(t, "parsePoolPath", h.parsePoolPath, []pathParseCase{
+ {"pool/main/n/nginx/nginx_1.18.0-6_amd64.deb", "nginx", "1.18.0-6", "amd64"},
+ {"pool/main/libn/libncurses/libncurses6_6.2-1_amd64.deb", "libncurses6", "6.2-1", "amd64"},
+ {"pool/contrib/v/virtualbox/virtualbox_6.1.38-1_amd64.deb", "virtualbox", "6.1.38-1", "amd64"},
+ {"pool/main/g/git/git_2.39.2-1_arm64.deb", "git", "2.39.2-1", "arm64"},
+ {"invalid/path", "", "", ""},
+ {"pool/main/n/nginx/nginx.deb", "", "", ""},
+ })
}
func TestDebianHandler_Routes(t *testing.T) {
h := NewDebianHandler(nil, "http://localhost:8080")
-
- // Test that handler doesn't panic on initialization
- handler := h.Routes()
- if handler == nil {
- t.Fatal("Routes() returned nil")
- }
-
- // Test method not allowed
- req := httptest.NewRequest(http.MethodPost, "/dists/stable/Release", nil)
- w := httptest.NewRecorder()
- handler.ServeHTTP(w, req)
-
- if w.Code != http.StatusMethodNotAllowed {
- t.Errorf("POST request: got status %d, want %d", w.Code, http.StatusMethodNotAllowed)
- }
-
- // Test path traversal rejection
- req = httptest.NewRequest(http.MethodGet, "/pool/../../../etc/passwd", nil)
- w = httptest.NewRecorder()
- handler.ServeHTTP(w, req)
-
- if w.Code != http.StatusBadRequest {
- t.Errorf("path traversal: got status %d, want %d", w.Code, http.StatusBadRequest)
- }
+ assertRoutesBasics(t, h.Routes(), "/dists/stable/Release", "/pool/../../../etc/passwd")
}
diff --git a/internal/handler/download_test.go b/internal/handler/download_test.go
index a51f908..a6e0cb3 100644
--- a/internal/handler/download_test.go
+++ b/internal/handler/download_test.go
@@ -59,6 +59,38 @@ func seedPackageWithPURL(t *testing.T, db *database.DB, store *mockStorage, ecos
}
}
+// assertUpstreamProxied verifies that a handler proxies a request to the upstream
+// server and returns the expected response body. The makeHandler function receives
+// a configured Proxy and the upstream URL, and returns the handler to test.
+func assertUpstreamProxied(t *testing.T, wantBody, path string, makeHandler func(*Proxy, string) http.Handler) {
+ t.Helper()
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = fmt.Fprint(w, wantBody)
+ }))
+ defer upstream.Close()
+
+ proxy, _, _, _ := setupTestProxy(t)
+ proxy.HTTPClient = upstream.Client()
+
+ srv := httptest.NewServer(makeHandler(proxy, upstream.URL))
+ defer srv.Close()
+
+ resp, err := http.Get(srv.URL + path)
+ if err != nil {
+ t.Fatalf("request failed: %v", err)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ if resp.StatusCode != http.StatusOK {
+ t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
+ }
+ body, _ := io.ReadAll(resp.Body)
+ if string(body) != wantBody {
+ t.Errorf("body = %q, want %q", body, wantBody)
+ }
+}
+
func TestGemHandler_DownloadCacheHit(t *testing.T) {
proxy, db, store, _ := setupTestProxy(t)
seedPackage(t, db, store, "gem", "rails", "7.1.0", "rails-7.1.0.gem", "gem binary data")
@@ -71,7 +103,7 @@ func TestGemHandler_DownloadCacheHit(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- defer resp.Body.Close()
+ defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
@@ -94,7 +126,7 @@ func TestGemHandler_DownloadCacheHitMultiHyphen(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- defer resp.Body.Close()
+ defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
@@ -125,7 +157,7 @@ func TestGemHandler_InvalidFilename(t *testing.T) {
if err != nil {
t.Fatalf("request to %s failed: %v", tt.path, err)
}
- resp.Body.Close()
+ _ = resp.Body.Close()
if resp.StatusCode != tt.code {
t.Errorf("GET %s: status = %d, want %d", tt.path, resp.StatusCode, tt.code)
@@ -137,7 +169,7 @@ func TestGemHandler_UpstreamProxy(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Test", "upstream")
w.WriteHeader(http.StatusOK)
- fmt.Fprint(w, "upstream specs data")
+ _, _ = fmt.Fprint(w, "upstream specs data")
}))
defer upstream.Close()
@@ -156,7 +188,7 @@ func TestGemHandler_UpstreamProxy(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- defer resp.Body.Close()
+ defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
@@ -185,7 +217,7 @@ func TestGemHandler_CacheMiss(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- defer resp.Body.Close()
+ defer func() { _ = resp.Body.Close() }()
if !fetcher.fetchCalled {
t.Error("expected fetcher to be called on cache miss")
@@ -204,7 +236,7 @@ func TestGoHandler_DownloadCacheHit(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- defer resp.Body.Close()
+ defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
@@ -225,7 +257,7 @@ func TestGoHandler_MethodNotAllowed(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- resp.Body.Close()
+ _ = resp.Body.Close()
if resp.StatusCode != http.StatusMethodNotAllowed {
t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed)
@@ -242,7 +274,7 @@ func TestGoHandler_NotFound(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- resp.Body.Close()
+ _ = resp.Body.Close()
if resp.StatusCode != http.StatusNotFound {
t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusNotFound)
@@ -259,7 +291,7 @@ func TestGoHandler_UnknownAtVSuffix(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- resp.Body.Close()
+ _ = resp.Body.Close()
if resp.StatusCode != http.StatusNotFound {
t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusNotFound)
@@ -268,7 +300,7 @@ func TestGoHandler_UnknownAtVSuffix(t *testing.T) {
func TestGoHandler_UpstreamProxy(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- fmt.Fprint(w, "v0.14.0\nv0.13.0\n")
+ _, _ = fmt.Fprint(w, "v0.14.0\nv0.13.0\n")
}))
defer upstream.Close()
@@ -296,7 +328,7 @@ func TestGoHandler_UpstreamProxy(t *testing.T) {
if err != nil {
t.Fatalf("GET %s failed: %v", path, err)
}
- resp.Body.Close()
+ _ = resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("GET %s: status = %d, want %d", path, resp.StatusCode, http.StatusOK)
@@ -319,7 +351,7 @@ func TestGoHandler_CacheMiss(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- defer resp.Body.Close()
+ defer func() { _ = resp.Body.Close() }()
if !fetcher.fetchCalled {
t.Error("expected fetcher to be called on cache miss")
@@ -338,7 +370,7 @@ func TestHexHandler_DownloadCacheHit(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- defer resp.Body.Close()
+ defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
@@ -368,7 +400,7 @@ func TestHexHandler_InvalidFilename(t *testing.T) {
if err != nil {
t.Fatalf("request to %s failed: %v", tt.path, err)
}
- resp.Body.Close()
+ _ = resp.Body.Close()
if resp.StatusCode != tt.code {
t.Errorf("GET %s: status = %d, want %d", tt.path, resp.StatusCode, tt.code)
@@ -377,35 +409,12 @@ func TestHexHandler_InvalidFilename(t *testing.T) {
}
func TestHexHandler_UpstreamProxy(t *testing.T) {
- upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- fmt.Fprint(w, "hex registry data")
- }))
- defer upstream.Close()
-
- proxy, _, _, _ := setupTestProxy(t)
- h := &HexHandler{
- proxy: proxy,
- upstreamURL: upstream.URL,
- proxyURL: "http://localhost",
- }
- proxy.HTTPClient = upstream.Client()
-
- srv := httptest.NewServer(h.Routes())
- defer srv.Close()
-
- resp, err := http.Get(srv.URL + "/packages/phoenix")
- if err != nil {
- t.Fatalf("request failed: %v", err)
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != http.StatusOK {
- t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
- }
- body, _ := io.ReadAll(resp.Body)
- if string(body) != "hex registry data" {
- t.Errorf("body = %q, want %q", body, "hex registry data")
- }
+ assertUpstreamProxied(t, "hex registry data", "/packages/phoenix",
+ func(proxy *Proxy, upstreamURL string) http.Handler {
+ h := &HexHandler{proxy: proxy, upstreamURL: upstreamURL, proxyURL: "http://localhost"}
+ return h.Routes()
+ },
+ )
}
func TestHexHandler_CacheMiss(t *testing.T) {
@@ -423,7 +432,7 @@ func TestHexHandler_CacheMiss(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- defer resp.Body.Close()
+ defer func() { _ = resp.Body.Close() }()
if !fetcher.fetchCalled {
t.Error("expected fetcher to be called on cache miss")
@@ -442,7 +451,7 @@ func TestCondaHandler_DownloadCacheHit(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- defer resp.Body.Close()
+ defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
@@ -465,7 +474,7 @@ func TestCondaHandler_DownloadTarBz2CacheHit(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- defer resp.Body.Close()
+ defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
@@ -477,35 +486,12 @@ func TestCondaHandler_DownloadTarBz2CacheHit(t *testing.T) {
}
func TestCondaHandler_NonPackageFileProxied(t *testing.T) {
- upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- fmt.Fprint(w, "repodata json")
- }))
- defer upstream.Close()
-
- proxy, _, _, _ := setupTestProxy(t)
- h := &CondaHandler{
- proxy: proxy,
- upstreamURL: upstream.URL,
- proxyURL: "http://localhost",
- }
- proxy.HTTPClient = upstream.Client()
-
- srv := httptest.NewServer(h.Routes())
- defer srv.Close()
-
- resp, err := http.Get(srv.URL + "/main/linux-64/repodata.json")
- if err != nil {
- t.Fatalf("request failed: %v", err)
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != http.StatusOK {
- t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
- }
- body, _ := io.ReadAll(resp.Body)
- if string(body) != "repodata json" {
- t.Errorf("body = %q, want %q", body, "repodata json")
- }
+ assertUpstreamProxied(t, "repodata json", "/main/linux-64/repodata.json",
+ func(proxy *Proxy, upstreamURL string) http.Handler {
+ h := &CondaHandler{proxy: proxy, upstreamURL: upstreamURL, proxyURL: "http://localhost"}
+ return h.Routes()
+ },
+ )
}
func TestCondaHandler_CacheMiss(t *testing.T) {
@@ -531,7 +517,7 @@ func TestCondaHandler_CacheMiss(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- defer resp.Body.Close()
+ defer func() { _ = resp.Body.Close() }()
if !fetcher.fetchCalled {
t.Error("expected fetcher to be called on cache miss")
@@ -550,7 +536,7 @@ func TestCRANHandler_SourceDownloadCacheHit(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- defer resp.Body.Close()
+ defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
@@ -573,7 +559,7 @@ func TestCRANHandler_BinaryDownloadCacheHit(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- defer resp.Body.Close()
+ defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
@@ -585,40 +571,17 @@ func TestCRANHandler_BinaryDownloadCacheHit(t *testing.T) {
}
func TestCRANHandler_NonPackageFileProxied(t *testing.T) {
- upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- fmt.Fprint(w, "PACKAGES index")
- }))
- defer upstream.Close()
-
- proxy, _, _, _ := setupTestProxy(t)
- h := &CRANHandler{
- proxy: proxy,
- upstreamURL: upstream.URL,
- proxyURL: "http://localhost",
- }
- proxy.HTTPClient = upstream.Client()
-
- srv := httptest.NewServer(h.Routes())
- defer srv.Close()
-
- resp, err := http.Get(srv.URL + "/src/contrib/PACKAGES")
- if err != nil {
- t.Fatalf("request failed: %v", err)
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != http.StatusOK {
- t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
- }
- body, _ := io.ReadAll(resp.Body)
- if string(body) != "PACKAGES index" {
- t.Errorf("body = %q, want %q", body, "PACKAGES index")
- }
+ assertUpstreamProxied(t, "PACKAGES index", "/src/contrib/PACKAGES",
+ func(proxy *Proxy, upstreamURL string) http.Handler {
+ h := &CRANHandler{proxy: proxy, upstreamURL: upstreamURL, proxyURL: "http://localhost"}
+ return h.Routes()
+ },
+ )
}
func TestCRANHandler_SourceNonTarGzProxied(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- fmt.Fprint(w, "some other file")
+ _, _ = fmt.Fprint(w, "some other file")
}))
defer upstream.Close()
@@ -637,7 +600,7 @@ func TestCRANHandler_SourceNonTarGzProxied(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- defer resp.Body.Close()
+ defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
@@ -661,7 +624,7 @@ func TestCRANHandler_CacheMiss(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- defer resp.Body.Close()
+ defer func() { _ = resp.Body.Close() }()
if !fetcher.fetchCalled {
t.Error("expected fetcher to be called on cache miss")
@@ -680,7 +643,7 @@ func TestMavenHandler_DownloadCacheHit(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- defer resp.Body.Close()
+ defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK)
@@ -693,7 +656,7 @@ func TestMavenHandler_DownloadCacheHit(t *testing.T) {
func TestMavenHandler_MetadataProxied(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- fmt.Fprint(w, "")
+ _, _ = fmt.Fprint(w, "")
}))
defer upstream.Close()
@@ -719,7 +682,7 @@ func TestMavenHandler_MetadataProxied(t *testing.T) {
if err != nil {
t.Fatalf("GET %s failed: %v", path, err)
}
- resp.Body.Close()
+ _ = resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("GET %s: status = %d, want %d", path, resp.StatusCode, http.StatusOK)
@@ -737,7 +700,7 @@ func TestMavenHandler_EmptyPathNotFound(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- resp.Body.Close()
+ _ = resp.Body.Close()
if resp.StatusCode != http.StatusNotFound {
t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusNotFound)
@@ -770,7 +733,7 @@ func TestMavenHandler_ArtifactExtensions(t *testing.T) {
if err != nil {
t.Fatalf("GET %s failed: %v", path, err)
}
- resp.Body.Close()
+ _ = resp.Body.Close()
if !fetcher.fetchCalled {
t.Errorf("fetcher not called for %s", ext)
@@ -796,7 +759,7 @@ func TestMavenHandler_CacheMiss(t *testing.T) {
if err != nil {
t.Fatalf("request failed: %v", err)
}
- defer resp.Body.Close()
+ defer func() { _ = resp.Body.Close() }()
if !fetcher.fetchCalled {
t.Error("expected fetcher to be called on cache miss")
diff --git a/internal/handler/go.go b/internal/handler/go.go
index d64ded3..dd4e17d 100644
--- a/internal/handler/go.go
+++ b/internal/handler/go.go
@@ -2,13 +2,13 @@ package handler
import (
"fmt"
- "io"
"net/http"
"strings"
)
const (
- goUpstream = "https://proxy.golang.org"
+ goUpstream = "https://proxy.golang.org"
+ asciiCaseOffset = 32 // difference between lowercase and uppercase ASCII letters
)
// GoHandler handles Go module proxy protocol requests.
@@ -108,33 +108,7 @@ func (h *GoHandler) handleDownload(w http.ResponseWriter, r *http.Request, modul
// proxyUpstream forwards a request to proxy.golang.org without caching.
func (h *GoHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) {
- upstreamURL := h.upstreamURL + r.URL.Path
-
- h.proxy.Logger.Debug("proxying to upstream", "url", upstreamURL)
-
- req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil)
- if err != nil {
- http.Error(w, "failed to create request", http.StatusInternalServerError)
- return
- }
-
- resp, err := h.proxy.HTTPClient.Do(req)
- if err != nil {
- h.proxy.Logger.Error("upstream request failed", "error", err)
- http.Error(w, "upstream request failed", http.StatusBadGateway)
- return
- }
- defer func() { _ = resp.Body.Close() }()
-
- // Copy response headers
- for k, vv := range resp.Header {
- for _, v := range vv {
- w.Header().Add(k, v)
- }
- }
-
- w.WriteHeader(resp.StatusCode)
- _, _ = io.Copy(w, resp.Body)
+ h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, nil)
}
// decodeGoModule decodes an encoded module path.
@@ -143,7 +117,7 @@ func decodeGoModule(encoded string) string {
var b strings.Builder
for i := 0; i < len(encoded); i++ {
if encoded[i] == '!' && i+1 < len(encoded) {
- b.WriteByte(encoded[i+1] - 32) // lowercase to uppercase
+ b.WriteByte(encoded[i+1] - asciiCaseOffset) // lowercase to uppercase
i++
} else {
b.WriteByte(encoded[i])
diff --git a/internal/handler/handler.go b/internal/handler/handler.go
index 9205008..91d8960 100644
--- a/internal/handler/handler.go
+++ b/internal/handler/handler.go
@@ -30,6 +30,8 @@ func containsPathTraversal(path string) bool {
return false
}
+const defaultHTTPTimeout = 30 * time.Second
+
// maxMetadataSize is the maximum size of upstream metadata responses (50 MB).
// Package metadata (e.g. npm with many versions) can be large, but unbounded
// reads risk OOM if an upstream misbehaves.
@@ -64,7 +66,7 @@ func NewProxy(db *database.DB, store storage.Storage, fetcher fetch.FetcherInter
Resolver: resolver,
Logger: logger,
HTTPClient: &http.Client{
- Timeout: 30 * time.Second,
+ Timeout: defaultHTTPTimeout,
},
}
}
@@ -187,7 +189,7 @@ func (p *Proxy) fetchAndCache(ctx context.Context, ecosystem, name, version, fil
}
// Update database
- if err := p.updateCacheDB(ctx, ecosystem, name, version, filename, pkgPURL, versionPURL, info.URL, storagePath, hash, size, artifact.ContentType); err != nil {
+ if err := p.updateCacheDB(ecosystem, name, filename, pkgPURL, versionPURL, info.URL, storagePath, hash, size, artifact.ContentType); err != nil {
p.Logger.Warn("failed to update cache database", "error", err)
// Continue anyway - we have the file
}
@@ -211,7 +213,7 @@ func (p *Proxy) fetchAndCache(ctx context.Context, ecosystem, name, version, fil
}, nil
}
-func (p *Proxy) updateCacheDB(ctx context.Context, ecosystem, name, version, filename, pkgPURL, versionPURL, upstreamURL, storagePath, hash string, size int64, contentType string) error {
+func (p *Proxy) updateCacheDB(ecosystem, name, filename, pkgPURL, versionPURL, upstreamURL, storagePath, hash string, size int64, contentType string) error {
now := time.Now()
// Upsert package
@@ -272,6 +274,102 @@ func ServeArtifact(w http.ResponseWriter, result *CacheResult) {
_, _ = io.Copy(w, result.Reader)
}
+// ProxyUpstream forwards a request to an upstream URL without caching.
+// It copies the request, forwards specified headers, and streams the response back.
+// If forwardHeaders is nil, all response headers are copied.
+func (p *Proxy) ProxyUpstream(w http.ResponseWriter, r *http.Request, upstreamURL string, forwardHeaders []string) {
+ p.Logger.Debug("proxying to upstream", "url", upstreamURL)
+
+ req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, nil)
+ if err != nil {
+ http.Error(w, "failed to create request", http.StatusInternalServerError)
+ return
+ }
+
+ // Copy request headers that affect content negotiation / caching
+ for _, header := range forwardHeaders {
+ if v := r.Header.Get(header); v != "" {
+ req.Header.Set(header, v)
+ }
+ }
+
+ resp, err := p.HTTPClient.Do(req)
+ if err != nil {
+ p.Logger.Error("upstream request failed", "error", err)
+ http.Error(w, "upstream request failed", http.StatusBadGateway)
+ return
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ for k, vv := range resp.Header {
+ for _, v := range vv {
+ w.Header().Add(k, v)
+ }
+ }
+
+ w.WriteHeader(resp.StatusCode)
+ _, _ = io.Copy(w, resp.Body)
+}
+
+// ProxyMetadata forwards a metadata request to upstream, copying only specific response headers.
+func (p *Proxy) ProxyMetadata(w http.ResponseWriter, r *http.Request, upstreamURL string, logLabel string) {
+ p.Logger.Debug(logLabel+" metadata request", "url", upstreamURL)
+
+ req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, nil)
+ if err != nil {
+ http.Error(w, "failed to create request", http.StatusInternalServerError)
+ return
+ }
+
+ for _, header := range []string{"Accept", "Accept-Encoding", "If-Modified-Since", "If-None-Match"} {
+ if v := r.Header.Get(header); v != "" {
+ req.Header.Set(header, v)
+ }
+ }
+
+ resp, err := p.HTTPClient.Do(req)
+ if err != nil {
+ p.Logger.Error("failed to fetch upstream metadata", "error", err)
+ http.Error(w, "failed to fetch from upstream", http.StatusBadGateway)
+ return
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ for _, header := range []string{"Content-Type", "Content-Length", "Last-Modified", "ETag"} {
+ if v := resp.Header.Get(header); v != "" {
+ w.Header().Set(header, v)
+ }
+ }
+
+ w.WriteHeader(resp.StatusCode)
+ _, _ = io.Copy(w, resp.Body)
+}
+
+// ProxyFile forwards a file request to upstream, copying all response headers.
+func (p *Proxy) ProxyFile(w http.ResponseWriter, r *http.Request, upstreamURL string) {
+ req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, nil)
+ if err != nil {
+ http.Error(w, "failed to create request", http.StatusInternalServerError)
+ return
+ }
+
+ resp, err := p.HTTPClient.Do(req)
+ if err != nil {
+ http.Error(w, "failed to fetch from upstream", http.StatusBadGateway)
+ return
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ for key, values := range resp.Header {
+ for _, v := range values {
+ w.Header().Add(key, v)
+ }
+ }
+
+ w.WriteHeader(resp.StatusCode)
+ _, _ = io.Copy(w, resp.Body)
+}
+
// JSONError writes a JSON error response.
func JSONError(w http.ResponseWriter, status int, message string) {
w.Header().Set("Content-Type", "application/json")
@@ -310,7 +408,7 @@ func (p *Proxy) fetchAndCacheFromURL(ctx context.Context, ecosystem, name, versi
return nil, fmt.Errorf("storing artifact: %w", err)
}
- if err := p.updateCacheDB(ctx, ecosystem, name, version, filename, pkgPURL, versionPURL, downloadURL, storagePath, hash, size, artifact.ContentType); err != nil {
+ if err := p.updateCacheDB(ecosystem, name, filename, pkgPURL, versionPURL, downloadURL, storagePath, hash, size, artifact.ContentType); err != nil {
p.Logger.Warn("failed to update cache database", "error", err)
}
diff --git a/internal/handler/handler_test.go b/internal/handler/handler_test.go
index 5ba8ae3..dd85a17 100644
--- a/internal/handler/handler_test.go
+++ b/internal/handler/handler_test.go
@@ -160,6 +160,61 @@ func seedPackage(t *testing.T, db *database.DB, store *mockStorage, ecosystem, n
}
}
+// pathParseCase holds a single test case for path parsing functions that return
+// (name, version, arch).
+type pathParseCase struct {
+ path string
+ wantName string
+ wantVersion string
+ wantArch string
+}
+
+// assertPathParser runs table-driven tests for a path parser function that returns
+// three strings (name, version, arch).
+func assertPathParser(t *testing.T, funcName string, parse func(string) (string, string, string), cases []pathParseCase) {
+ t.Helper()
+ for _, tt := range cases {
+ t.Run(tt.path, func(t *testing.T) {
+ name, version, arch := parse(tt.path)
+ if name != tt.wantName {
+ t.Errorf("%s() name = %q, want %q", funcName, name, tt.wantName)
+ }
+ if version != tt.wantVersion {
+ t.Errorf("%s() version = %q, want %q", funcName, version, tt.wantVersion)
+ }
+ if arch != tt.wantArch {
+ t.Errorf("%s() arch = %q, want %q", funcName, arch, tt.wantArch)
+ }
+ })
+ }
+}
+
+// assertRoutesBasics checks that a handler's Routes() returns a non-nil handler,
+// rejects POST requests with 405, and rejects path traversal with 400.
+func assertRoutesBasics(t *testing.T, handler http.Handler, postPath, traversalPath string) {
+ t.Helper()
+
+ if handler == nil {
+ t.Fatal("Routes() returned nil")
+ }
+
+ req := httptest.NewRequest(http.MethodPost, postPath, nil)
+ w := httptest.NewRecorder()
+ handler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusMethodNotAllowed {
+ t.Errorf("POST request: got status %d, want %d", w.Code, http.StatusMethodNotAllowed)
+ }
+
+ req = httptest.NewRequest(http.MethodGet, traversalPath, nil)
+ w = httptest.NewRecorder()
+ handler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("path traversal: got status %d, want %d", w.Code, http.StatusBadRequest)
+ }
+}
+
func TestGetOrFetchArtifact_CacheHit(t *testing.T) {
proxy, db, store, fetcher := setupTestProxy(t)
seedPackage(t, db, store, "npm", "lodash", "4.17.21", "lodash-4.17.21.tgz", "cached content")
diff --git a/internal/handler/hex.go b/internal/handler/hex.go
index 2e972f0..4e0f2a2 100644
--- a/internal/handler/hex.go
+++ b/internal/handler/hex.go
@@ -1,8 +1,6 @@
package handler
import (
- "fmt"
- "io"
"net/http"
"strings"
)
@@ -89,40 +87,5 @@ func (h *HexHandler) parseTarballFilename(filename string) (name, version string
// proxyUpstream forwards a request to hex.pm without caching.
func (h *HexHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) {
- upstreamURL := h.upstreamURL + r.URL.Path
-
- h.proxy.Logger.Debug("proxying to upstream", "url", upstreamURL)
-
- req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil)
- if err != nil {
- http.Error(w, "failed to create request", http.StatusInternalServerError)
- return
- }
-
- // Copy accept header for content negotiation
- if accept := r.Header.Get("Accept"); accept != "" {
- req.Header.Set("Accept", accept)
- }
-
- resp, err := h.proxy.HTTPClient.Do(req)
- if err != nil {
- h.proxy.Logger.Error("upstream request failed", "error", err)
- http.Error(w, "upstream request failed", http.StatusBadGateway)
- return
- }
- defer func() { _ = resp.Body.Close() }()
-
- // Copy response headers
- for k, vv := range resp.Header {
- for _, v := range vv {
- w.Header().Add(k, v)
- }
- }
-
- w.WriteHeader(resp.StatusCode)
- _, _ = io.Copy(w, resp.Body)
-}
-
-func init() {
- _ = fmt.Sprintf // silence import if unused
+ h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, []string{"Accept"})
}
diff --git a/internal/handler/maven.go b/internal/handler/maven.go
index 2c7214f..79da0c0 100644
--- a/internal/handler/maven.go
+++ b/internal/handler/maven.go
@@ -2,7 +2,6 @@ package handler
import (
"fmt"
- "io"
"net/http"
"path"
"strings"
@@ -10,6 +9,7 @@ import (
const (
mavenUpstream = "https://repo1.maven.org/maven2"
+ minMavenParts = 4 // group path segments + artifact + version + filename
)
// MavenHandler handles Maven repository protocol requests.
@@ -99,7 +99,7 @@ func (h *MavenHandler) handleDownload(w http.ResponseWriter, r *http.Request, ur
// -> ("com.google.guava", "guava", "32.1.3-jre", "guava-32.1.3-jre.jar")
func (h *MavenHandler) parsePath(urlPath string) (group, artifact, version, filename string) {
parts := strings.Split(urlPath, "/")
- if len(parts) < 4 {
+ if len(parts) < minMavenParts {
return "", "", "", ""
}
@@ -136,30 +136,5 @@ func (h *MavenHandler) isMetadataFile(filename string) bool {
// proxyUpstream forwards a request to Maven Central without caching.
func (h *MavenHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) {
- upstreamURL := h.upstreamURL + r.URL.Path
-
- h.proxy.Logger.Debug("proxying to upstream", "url", upstreamURL)
-
- req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil)
- if err != nil {
- http.Error(w, "failed to create request", http.StatusInternalServerError)
- return
- }
-
- resp, err := h.proxy.HTTPClient.Do(req)
- if err != nil {
- h.proxy.Logger.Error("upstream request failed", "error", err)
- http.Error(w, "upstream request failed", http.StatusBadGateway)
- return
- }
- defer func() { _ = resp.Body.Close() }()
-
- for k, vv := range resp.Header {
- for _, v := range vv {
- w.Header().Add(k, v)
- }
- }
-
- w.WriteHeader(resp.StatusCode)
- _, _ = io.Copy(w, resp.Body)
+ h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, nil)
}
diff --git a/internal/handler/npm.go b/internal/handler/npm.go
index 99a2bb9..e0b0566 100644
--- a/internal/handler/npm.go
+++ b/internal/handler/npm.go
@@ -14,6 +14,7 @@ import (
const (
npmUpstream = "https://registry.npmjs.org"
+ scopedParts = 2 // scope + name in scoped packages
)
// NPMHandler handles npm registry protocol requests.
@@ -127,46 +128,71 @@ func (h *NPMHandler) rewriteMetadata(packageName string, body []byte) ([]byte, e
return body, nil // No versions to rewrite
}
- // Apply cooldown filtering
- if h.proxy.Cooldown != nil && h.proxy.Cooldown.Enabled() {
- timeMap, _ := metadata["time"].(map[string]any)
- packagePURL := purl.MakePURLString("npm", packageName, "")
+ h.applyCooldownFiltering(metadata, versions, packageName)
+ h.rewriteTarballURLs(versions, packageName)
- for version := range versions {
- if timeMap == nil {
- continue
- }
- publishedStr, ok := timeMap[version].(string)
- if !ok {
- continue
- }
- publishedAt, err := time.Parse(time.RFC3339, publishedStr)
- if err != nil {
- continue
- }
- if !h.proxy.Cooldown.IsAllowed("npm", packagePURL, publishedAt) {
- h.proxy.Logger.Info("cooldown: filtering npm version",
- "package", packageName, "version", version,
- "published", publishedStr)
- delete(versions, version)
- delete(timeMap, version)
- }
- }
+ return json.Marshal(metadata)
+}
- // Update dist-tags.latest if it was filtered
- if distTags, ok := metadata["dist-tags"].(map[string]any); ok {
- if latest, ok := distTags["latest"].(string); ok {
- if _, exists := versions[latest]; !exists {
- // Find newest remaining version from the time map
- newLatest := h.findNewestVersion(versions, timeMap)
- if newLatest != "" {
- distTags["latest"] = newLatest
- }
- }
- }
+// applyCooldownFiltering removes versions that are too recently published,
+// and updates dist-tags.latest if the current latest was filtered out.
+func (h *NPMHandler) applyCooldownFiltering(metadata map[string]any, versions map[string]any, packageName string) {
+ if h.proxy.Cooldown == nil || !h.proxy.Cooldown.Enabled() {
+ return
+ }
+
+ timeMap, _ := metadata["time"].(map[string]any)
+ if timeMap == nil {
+ return
+ }
+
+ packagePURL := purl.MakePURLString("npm", packageName, "")
+
+ for version := range versions {
+ publishedStr, ok := timeMap[version].(string)
+ if !ok {
+ continue
}
+ publishedAt, err := time.Parse(time.RFC3339, publishedStr)
+ if err != nil {
+ continue
+ }
+ if !h.proxy.Cooldown.IsAllowed("npm", packagePURL, publishedAt) {
+ h.proxy.Logger.Info("cooldown: filtering npm version",
+ "package", packageName, "version", version,
+ "published", publishedStr)
+ delete(versions, version)
+ delete(timeMap, version)
+ }
+ }
+
+ h.updateDistTagsLatest(metadata, versions, timeMap)
+}
+
+// updateDistTagsLatest updates the dist-tags.latest field if the current latest
+// version was removed by cooldown filtering.
+func (h *NPMHandler) updateDistTagsLatest(metadata, versions, timeMap map[string]any) {
+ distTags, ok := metadata["dist-tags"].(map[string]any)
+ if !ok {
+ return
+ }
+
+ latest, ok := distTags["latest"].(string)
+ if !ok {
+ return
+ }
+
+ if _, exists := versions[latest]; exists {
+ return
+ }
+
+ if newLatest := h.findNewestVersion(versions, timeMap); newLatest != "" {
+ distTags["latest"] = newLatest
}
+}
+// rewriteTarballURLs rewrites all tarball URLs in version entries to point at this proxy.
+func (h *NPMHandler) rewriteTarballURLs(versions map[string]any, packageName string) {
for version, vdata := range versions {
vmap, ok := vdata.(map[string]any)
if !ok {
@@ -178,25 +204,24 @@ func (h *NPMHandler) rewriteMetadata(packageName string, body []byte) ([]byte, e
continue
}
- if tarball, ok := dist["tarball"].(string); ok {
- // Extract filename from tarball URL
- filename := tarball
- if idx := strings.LastIndex(tarball, "/"); idx >= 0 {
- filename = tarball[idx+1:]
- }
-
- // Rewrite to our proxy URL
- escapedName := url.PathEscape(packageName)
- newTarball := fmt.Sprintf("%s/npm/%s/-/%s", h.proxyURL, escapedName, filename)
- dist["tarball"] = newTarball
+ tarball, ok := dist["tarball"].(string)
+ if !ok {
+ continue
+ }
- h.proxy.Logger.Debug("rewrote tarball URL",
- "package", packageName, "version", version,
- "old", tarball, "new", newTarball)
+ filename := tarball
+ if idx := strings.LastIndex(tarball, "/"); idx >= 0 {
+ filename = tarball[idx+1:]
}
- }
- return json.Marshal(metadata)
+ escapedName := url.PathEscape(packageName)
+ newTarball := fmt.Sprintf("%s/npm/%s/-/%s", h.proxyURL, escapedName, filename)
+ dist["tarball"] = newTarball
+
+ h.proxy.Logger.Debug("rewrote tarball URL",
+ "package", packageName, "version", version,
+ "old", tarball, "new", newTarball)
+ }
}
// findNewestVersion returns the version string with the most recent timestamp
@@ -313,7 +338,7 @@ func (h *NPMHandler) extractVersionFromFilename(packageName, filename string) st
// For scoped packages, the filename uses the short name
shortName := packageName
if strings.Contains(packageName, "/") {
- parts := strings.SplitN(packageName, "/", 2)
+ parts := strings.SplitN(packageName, "/", scopedParts)
shortName = parts[1]
}
diff --git a/internal/handler/npm_test.go b/internal/handler/npm_test.go
index 90616d7..3c3dc7d 100644
--- a/internal/handler/npm_test.go
+++ b/internal/handler/npm_test.go
@@ -11,6 +11,8 @@ import (
"github.com/git-pkgs/proxy/internal/cooldown"
)
+const testVersion100 = "1.0.0"
+
func testProxy() *Proxy {
return &Proxy{
Logger: slog.Default(),
@@ -172,7 +174,7 @@ func TestNPMHandlerMetadataProxy(t *testing.T) {
// Check that tarball URL was rewritten
versions := result["versions"].(map[string]any)
- v := versions["1.0.0"].(map[string]any)
+ v := versions[testVersion100].(map[string]any)
dist := v["dist"].(map[string]any)
tarball := dist["tarball"].(string)
@@ -232,7 +234,7 @@ func TestNPMRewriteMetadataCooldown(t *testing.T) {
versions := result["versions"].(map[string]any)
// Old version should remain
- if _, ok := versions["1.0.0"]; !ok {
+ if _, ok := versions[testVersion100]; !ok {
t.Error("version 1.0.0 should not be filtered")
}
@@ -243,8 +245,8 @@ func TestNPMRewriteMetadataCooldown(t *testing.T) {
// dist-tags.latest should be updated to 1.0.0
distTags := result["dist-tags"].(map[string]any)
- if distTags["latest"] != "1.0.0" {
- t.Errorf("dist-tags.latest = %q, want %q", distTags["latest"], "1.0.0")
+ if distTags["latest"] != testVersion100 {
+ t.Errorf("dist-tags.latest = %q, want %q", distTags["latest"], testVersion100)
}
}
@@ -286,7 +288,7 @@ func TestNPMRewriteMetadataCooldownExemptPackage(t *testing.T) {
}
versions := result["versions"].(map[string]any)
- if _, ok := versions["1.0.0"]; !ok {
+ if _, ok := versions[testVersion100]; !ok {
t.Error("exempt package version should not be filtered")
}
}
diff --git a/internal/handler/pub.go b/internal/handler/pub.go
index 53c3a4f..b8f6207 100644
--- a/internal/handler/pub.go
+++ b/internal/handler/pub.go
@@ -12,7 +12,8 @@ import (
)
const (
- pubUpstream = "https://pub.dev"
+ pubUpstream = "https://pub.dev"
+ pubPathParts = 2 // name + version in path split by /versions/
)
// PubHandler handles pub.dev registry protocol requests.
@@ -49,7 +50,7 @@ func (h *PubHandler) handleDownload(w http.ResponseWriter, r *http.Request) {
// Parse path: /packages/{name}/versions/{version}.tar.gz
path := strings.TrimPrefix(r.URL.Path, "/packages/")
parts := strings.Split(path, "/versions/")
- if len(parts) != 2 {
+ if len(parts) != pubPathParts {
http.Error(w, "invalid request", http.StatusBadRequest)
return
}
@@ -137,15 +138,23 @@ func (h *PubHandler) rewriteMetadata(name string, body []byte) ([]byte, error) {
return nil, err
}
- // Rewrite archive URLs in versions
versions, ok := metadata["versions"].([]any)
if !ok {
return body, nil
}
packagePURL := purl.MakePURLString("pub", name, "")
+ filtered := h.filterAndRewriteVersions(name, packagePURL, versions)
+ metadata["versions"] = filtered
+
+ h.updateLatestVersion(metadata, filtered)
+
+ return json.Marshal(metadata)
+}
- // Filter and rewrite versions
+// filterAndRewriteVersions applies cooldown filtering and rewrites archive URLs
+// for a package's version list.
+func (h *PubHandler) filterAndRewriteVersions(name, packagePURL string, versions []any) []any {
filtered := versions[:0]
for _, vdata := range versions {
vmap, ok := vdata.(map[string]any)
@@ -158,20 +167,10 @@ func (h *PubHandler) rewriteMetadata(name string, body []byte) ([]byte, error) {
continue
}
- // Apply cooldown filtering
- if h.proxy.Cooldown != nil && h.proxy.Cooldown.Enabled() {
- if publishedStr, ok := vmap["published"].(string); ok {
- if publishedAt, err := time.Parse(time.RFC3339, publishedStr); err == nil {
- if !h.proxy.Cooldown.IsAllowed("pub", packagePURL, publishedAt) {
- h.proxy.Logger.Info("cooldown: filtering pub version",
- "package", name, "version", version)
- continue
- }
- }
- }
+ if h.shouldFilterVersion(packagePURL, name, version, vmap) {
+ continue
}
- // Rewrite archive_url
newURL := fmt.Sprintf("%s/pub/packages/%s/versions/%s.tar.gz", h.proxyURL, name, version)
vmap["archive_url"] = newURL
filtered = append(filtered, vdata)
@@ -180,28 +179,60 @@ func (h *PubHandler) rewriteMetadata(name string, body []byte) ([]byte, error) {
"package", name, "version", version, "new", newURL)
}
- metadata["versions"] = filtered
+ return filtered
+}
+
+// shouldFilterVersion returns true if the version should be excluded due to cooldown.
+func (h *PubHandler) shouldFilterVersion(packagePURL, name, version string, vmap map[string]any) bool {
+ if h.proxy.Cooldown == nil || !h.proxy.Cooldown.Enabled() {
+ return false
+ }
+
+ publishedStr, ok := vmap["published"].(string)
+ if !ok {
+ return false
+ }
+
+ publishedAt, err := time.Parse(time.RFC3339, publishedStr)
+ if err != nil {
+ return false
+ }
- // Update latest if it points to a filtered version
- if h.proxy.Cooldown != nil && h.proxy.Cooldown.Enabled() {
- if latest, ok := metadata["latest"].(map[string]any); ok {
- if latestVer, ok := latest["version"].(string); ok {
- found := false
- for _, vdata := range filtered {
- if vmap, ok := vdata.(map[string]any); ok {
- if vmap["version"] == latestVer {
- found = true
- break
- }
- }
- }
- if !found && len(filtered) > 0 {
- // Use the last entry (most recent remaining)
- metadata["latest"] = filtered[len(filtered)-1]
- }
+ if !h.proxy.Cooldown.IsAllowed("pub", packagePURL, publishedAt) {
+ h.proxy.Logger.Info("cooldown: filtering pub version",
+ "package", name, "version", version)
+ return true
+ }
+
+ return false
+}
+
+// updateLatestVersion updates the latest field if the current latest version
+// was removed by cooldown filtering.
+func (h *PubHandler) updateLatestVersion(metadata map[string]any, filtered []any) {
+ if h.proxy.Cooldown == nil || !h.proxy.Cooldown.Enabled() {
+ return
+ }
+
+ latest, ok := metadata["latest"].(map[string]any)
+ if !ok {
+ return
+ }
+
+ latestVer, ok := latest["version"].(string)
+ if !ok {
+ return
+ }
+
+ for _, vdata := range filtered {
+ if vmap, ok := vdata.(map[string]any); ok {
+ if vmap["version"] == latestVer {
+ return
}
}
}
- return json.Marshal(metadata)
+ if len(filtered) > 0 {
+ metadata["latest"] = filtered[len(filtered)-1]
+ }
}
diff --git a/internal/handler/pypi.go b/internal/handler/pypi.go
index 4a4232c..51b2871 100644
--- a/internal/handler/pypi.go
+++ b/internal/handler/pypi.go
@@ -16,7 +16,11 @@ import (
)
const (
- pypiUpstream = "https://pypi.org"
+ pypiUpstream = "https://pypi.org"
+ minWheelParts = 5 // name + version + python + abi + platform
+ minSubmatchParts = 2 // full match + first capture group
+ minPyPIPathParts = 3 // hash_prefix + hash + filename
+ minPythonTagLen = 2 // minimum length for a python tag (e.g., "py")
)
// PyPIHandler handles PyPI registry protocol requests.
@@ -172,7 +176,7 @@ func (h *PyPIHandler) rewriteSimpleHTML(body []byte, filteredVersions map[string
// Extract filename from between tags
innerRe := regexp.MustCompile(`>([^<]+)`)
innerMatch := innerRe.FindSubmatch(match)
- if len(innerMatch) < 2 {
+ if len(innerMatch) < minSubmatchParts {
return match
}
filename := string(innerMatch[1])
@@ -190,7 +194,7 @@ func (h *PyPIHandler) rewriteSimpleHTML(body []byte, filteredVersions map[string
return re.ReplaceAllFunc(body, func(match []byte) []byte {
submatch := re.FindSubmatch(match)
- if len(submatch) < 2 {
+ if len(submatch) < minSubmatchParts {
return match
}
@@ -285,59 +289,86 @@ func (h *PyPIHandler) rewriteJSONMetadata(body []byte) ([]byte, error) {
return nil, err
}
- // Determine package name for cooldown lookup
packageName, _ := extractPyPIName(metadata)
packagePURL := ""
if packageName != "" {
packagePURL = purl.MakePURLString("pypi", packageName, "")
}
- // Filter and rewrite URLs in releases map
- if releases, ok := metadata["releases"].(map[string]any); ok {
- for version, files := range releases {
- if h.proxy.Cooldown != nil && h.proxy.Cooldown.Enabled() && packagePURL != "" {
- if filesArr, ok := files.([]any); ok {
- if publishedAt := h.newestUploadTime(filesArr); !publishedAt.IsZero() {
- if !h.proxy.Cooldown.IsAllowed("pypi", packagePURL, publishedAt) {
- h.proxy.Logger.Info("cooldown: filtering pypi version",
- "package", packageName, "version", version)
- delete(releases, version)
- continue
- }
- }
- }
- }
+ h.filterAndRewriteReleases(metadata, packageName, packagePURL)
+ h.filterAndRewriteURLs(metadata, packagePURL)
- if filesArr, ok := files.([]any); ok {
- for _, f := range filesArr {
- if fmap, ok := f.(map[string]any); ok {
- h.rewriteURLEntry(fmap)
- }
- }
- }
+ return json.Marshal(metadata)
+}
+
+// filterAndRewriteReleases applies cooldown filtering and URL rewriting to the
+// releases map in PyPI metadata.
+func (h *PyPIHandler) filterAndRewriteReleases(metadata map[string]any, packageName, packagePURL string) {
+ releases, ok := metadata["releases"].(map[string]any)
+ if !ok {
+ return
+ }
+
+ for version, files := range releases {
+ if h.shouldFilterRelease(packagePURL, files) {
+ h.proxy.Logger.Info("cooldown: filtering pypi version",
+ "package", packageName, "version", version)
+ delete(releases, version)
+ continue
}
+
+ h.rewriteFileEntries(files)
}
+}
- // Filter and rewrite URLs in urls array (current version files)
- if urls, ok := metadata["urls"].([]any); ok {
- if h.proxy.Cooldown != nil && h.proxy.Cooldown.Enabled() && packagePURL != "" {
- if publishedAt := h.newestUploadTime(urls); !publishedAt.IsZero() {
- if !h.proxy.Cooldown.IsAllowed("pypi", packagePURL, publishedAt) {
- metadata["urls"] = []any{}
- }
- }
+// shouldFilterRelease returns true if a release should be excluded due to cooldown.
+func (h *PyPIHandler) shouldFilterRelease(packagePURL string, files any) bool {
+ if h.proxy.Cooldown == nil || !h.proxy.Cooldown.Enabled() || packagePURL == "" {
+ return false
+ }
+
+ filesArr, ok := files.([]any)
+ if !ok {
+ return false
+ }
+
+ publishedAt := h.newestUploadTime(filesArr)
+ return !publishedAt.IsZero() && !h.proxy.Cooldown.IsAllowed("pypi", packagePURL, publishedAt)
+}
+
+// rewriteFileEntries rewrites URLs in a list of file entries.
+func (h *PyPIHandler) rewriteFileEntries(files any) {
+ filesArr, ok := files.([]any)
+ if !ok {
+ return
+ }
+
+ for _, f := range filesArr {
+ if fmap, ok := f.(map[string]any); ok {
+ h.rewriteURLEntry(fmap)
}
+ }
+}
- if urls, ok := metadata["urls"].([]any); ok {
- for _, u := range urls {
- if umap, ok := u.(map[string]any); ok {
- h.rewriteURLEntry(umap)
- }
+// filterAndRewriteURLs applies cooldown filtering and URL rewriting to the
+// urls array (current version files) in PyPI metadata.
+func (h *PyPIHandler) filterAndRewriteURLs(metadata map[string]any, packagePURL string) {
+ urls, ok := metadata["urls"].([]any)
+ if !ok {
+ return
+ }
+
+ if h.shouldFilterRelease(packagePURL, urls) {
+ metadata["urls"] = []any{}
+ }
+
+ if urls, ok := metadata["urls"].([]any); ok {
+ for _, u := range urls {
+ if umap, ok := u.(map[string]any); ok {
+ h.rewriteURLEntry(umap)
}
}
}
-
- return json.Marshal(metadata)
}
// extractPyPIName extracts the package name from PyPI JSON metadata.
@@ -403,7 +434,7 @@ func (h *PyPIHandler) handleDownload(w http.ResponseWriter, r *http.Request) {
// Path format: /packages/{hash_prefix}/{hash}/{filename}
// e.g., /packages/ab/cd/abc123.../requests-2.31.0.tar.gz
parts := strings.Split(path, "/")
- if len(parts) < 3 {
+ if len(parts) < minPyPIPathParts {
http.Error(w, "invalid path", http.StatusBadRequest)
return
}
@@ -442,7 +473,7 @@ func (h *PyPIHandler) parseFilename(filename string) (name, version string) {
if strings.HasSuffix(filename, ".whl") {
base := strings.TrimSuffix(filename, ".whl")
parts := strings.Split(base, "-")
- if len(parts) >= 5 {
+ if len(parts) >= minWheelParts {
// Find where version ends (version followed by python tag)
for i := 1; i < len(parts)-2; i++ {
// Check if this looks like a python tag (py2, py3, cp39, etc)
@@ -472,7 +503,7 @@ func (h *PyPIHandler) parseFilename(filename string) (name, version string) {
}
func isPythonTag(s string) bool {
- if len(s) < 2 {
+ if len(s) < minPythonTagLen {
return false
}
// Python tags start with py, cp, pp, ip, jy
diff --git a/internal/handler/rpm.go b/internal/handler/rpm.go
index de4295a..92da8b6 100644
--- a/internal/handler/rpm.go
+++ b/internal/handler/rpm.go
@@ -2,7 +2,6 @@ package handler
import (
"fmt"
- "io"
"net/http"
"regexp"
"strings"
@@ -11,6 +10,7 @@ import (
const (
// Default upstream for Fedora packages
defaultRPMUpstream = "https://dl.fedoraproject.org/pub/fedora/linux"
+ rpmMatchCount = 5 // full match + name + version + release + arch
)
// RPMHandler handles RPM/Yum repository protocol requests.
@@ -95,67 +95,12 @@ func (h *RPMHandler) handlePackageDownload(w http.ResponseWriter, r *http.Reques
// handleMetadata proxies repository metadata files (repomd.xml, primary.xml.gz, etc.).
// These change frequently so we don't cache them.
func (h *RPMHandler) handleMetadata(w http.ResponseWriter, r *http.Request, path string) {
- upstreamURL := fmt.Sprintf("%s/%s", h.upstreamURL, path)
-
- h.proxy.Logger.Debug("rpm metadata request", "path", path)
-
- req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, nil)
- if err != nil {
- http.Error(w, "failed to create request", http.StatusInternalServerError)
- return
- }
-
- // Forward relevant headers
- for _, header := range []string{"Accept", "Accept-Encoding", "If-Modified-Since", "If-None-Match"} {
- if v := r.Header.Get(header); v != "" {
- req.Header.Set(header, v)
- }
- }
-
- resp, err := h.proxy.HTTPClient.Do(req)
- if err != nil {
- h.proxy.Logger.Error("failed to fetch upstream metadata", "error", err)
- http.Error(w, "failed to fetch from upstream", http.StatusBadGateway)
- return
- }
- defer func() { _ = resp.Body.Close() }()
-
- // Copy response headers
- for _, header := range []string{"Content-Type", "Content-Length", "Last-Modified", "ETag"} {
- if v := resp.Header.Get(header); v != "" {
- w.Header().Set(header, v)
- }
- }
-
- w.WriteHeader(resp.StatusCode)
- _, _ = io.Copy(w, resp.Body)
+ h.proxy.ProxyMetadata(w, r, fmt.Sprintf("%s/%s", h.upstreamURL, path), "rpm")
}
// proxyFile proxies any file directly without caching.
func (h *RPMHandler) proxyFile(w http.ResponseWriter, r *http.Request, path string) {
- upstreamURL := fmt.Sprintf("%s/%s", h.upstreamURL, path)
-
- req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, nil)
- if err != nil {
- http.Error(w, "failed to create request", http.StatusInternalServerError)
- return
- }
-
- resp, err := h.proxy.HTTPClient.Do(req)
- if err != nil {
- http.Error(w, "failed to fetch from upstream", http.StatusBadGateway)
- return
- }
- defer func() { _ = resp.Body.Close() }()
-
- for key, values := range resp.Header {
- for _, v := range values {
- w.Header().Add(key, v)
- }
- }
-
- w.WriteHeader(resp.StatusCode)
- _, _ = io.Copy(w, resp.Body)
+ h.proxy.ProxyFile(w, r, fmt.Sprintf("%s/%s", h.upstreamURL, path))
}
// rpmPackagePattern matches .rpm filenames to extract name, version, release, and arch.
@@ -176,7 +121,7 @@ func (h *RPMHandler) parseRPMPath(path string) (name, version, arch string) {
// Parse the filename
matches := rpmPackagePattern.FindStringSubmatch(filename)
- if len(matches) != 5 {
+ if len(matches) != rpmMatchCount {
return "", "", ""
}
diff --git a/internal/handler/rpm_test.go b/internal/handler/rpm_test.go
index b63f44d..851a304 100644
--- a/internal/handler/rpm_test.go
+++ b/internal/handler/rpm_test.go
@@ -1,98 +1,23 @@
package handler
import (
- "net/http"
- "net/http/httptest"
"testing"
)
func TestRPMHandler_parseRPMPath(t *testing.T) {
h := &RPMHandler{}
- tests := []struct {
- path string
- wantName string
- wantVersion string
- wantArch string
- }{
- {
- path: "releases/39/Everything/x86_64/os/Packages/n/nginx-1.24.0-1.fc39.x86_64.rpm",
- wantName: "nginx",
- wantVersion: "1.24.0-1.fc39",
- wantArch: "x86_64",
- },
- {
- path: "Packages/kernel-core-6.5.5-200.fc38.x86_64.rpm",
- wantName: "kernel-core",
- wantVersion: "6.5.5-200.fc38",
- wantArch: "x86_64",
- },
- {
- path: "updates/39/Everything/aarch64/Packages/g/git-2.42.0-1.fc39.aarch64.rpm",
- wantName: "git",
- wantVersion: "2.42.0-1.fc39",
- wantArch: "aarch64",
- },
- {
- path: "vim-enhanced-9.0.1000-1.fc38.noarch.rpm",
- wantName: "vim-enhanced",
- wantVersion: "9.0.1000-1.fc38",
- wantArch: "noarch",
- },
- {
- path: "invalid.rpm",
- wantName: "",
- wantVersion: "",
- wantArch: "",
- },
- {
- path: "not-an-rpm-file",
- wantName: "",
- wantVersion: "",
- wantArch: "",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.path, func(t *testing.T) {
- name, version, arch := h.parseRPMPath(tt.path)
- if name != tt.wantName {
- t.Errorf("parseRPMPath() name = %q, want %q", name, tt.wantName)
- }
- if version != tt.wantVersion {
- t.Errorf("parseRPMPath() version = %q, want %q", version, tt.wantVersion)
- }
- if arch != tt.wantArch {
- t.Errorf("parseRPMPath() arch = %q, want %q", arch, tt.wantArch)
- }
- })
- }
+ assertPathParser(t, "parseRPMPath", h.parseRPMPath, []pathParseCase{
+ {"releases/39/Everything/x86_64/os/Packages/n/nginx-1.24.0-1.fc39.x86_64.rpm", "nginx", "1.24.0-1.fc39", "x86_64"},
+ {"Packages/kernel-core-6.5.5-200.fc38.x86_64.rpm", "kernel-core", "6.5.5-200.fc38", "x86_64"},
+ {"updates/39/Everything/aarch64/Packages/g/git-2.42.0-1.fc39.aarch64.rpm", "git", "2.42.0-1.fc39", "aarch64"},
+ {"vim-enhanced-9.0.1000-1.fc38.noarch.rpm", "vim-enhanced", "9.0.1000-1.fc38", "noarch"},
+ {"invalid.rpm", "", "", ""},
+ {"not-an-rpm-file", "", "", ""},
+ })
}
func TestRPMHandler_Routes(t *testing.T) {
h := NewRPMHandler(nil, "http://localhost:8080")
-
- // Test that handler doesn't panic on initialization
- handler := h.Routes()
- if handler == nil {
- t.Fatal("Routes() returned nil")
- }
-
- // Test method not allowed
- req := httptest.NewRequest(http.MethodPost, "/repodata/repomd.xml", nil)
- w := httptest.NewRecorder()
- handler.ServeHTTP(w, req)
-
- if w.Code != http.StatusMethodNotAllowed {
- t.Errorf("POST request: got status %d, want %d", w.Code, http.StatusMethodNotAllowed)
- }
-
- // Test path traversal rejection
- req = httptest.NewRequest(http.MethodGet, "/releases/../../../etc/passwd", nil)
- w = httptest.NewRecorder()
- handler.ServeHTTP(w, req)
-
- if w.Code != http.StatusBadRequest {
- t.Errorf("path traversal: got status %d, want %d", w.Code, http.StatusBadRequest)
- }
+ assertRoutesBasics(t, h.Routes(), "/repodata/repomd.xml", "/releases/../../../etc/passwd")
}
diff --git a/internal/server/api.go b/internal/server/api.go
index f756cba..e9b91be 100644
--- a/internal/server/api.go
+++ b/internal/server/api.go
@@ -12,6 +12,12 @@ import (
"github.com/go-chi/chi/v5"
)
+const (
+ maxBodySize = 1 << 20 // 1 MB
+ licenseCategoryUnknown = "unknown"
+ defaultSortBy = "hits"
+)
+
// APIHandler provides REST endpoints for package enrichment data.
type APIHandler struct {
enrichment *enrichment.Service
@@ -327,7 +333,7 @@ func (h *APIHandler) HandleGetVulns(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {string} string
// @Router /api/outdated [post]
func (h *APIHandler) HandleOutdated(w http.ResponseWriter, r *http.Request) {
- r.Body = http.MaxBytesReader(w, r.Body, 1<<20) // 1 MB
+ r.Body = http.MaxBytesReader(w, r.Body, maxBodySize)
var req OutdatedRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "invalid request body", http.StatusBadRequest)
@@ -373,7 +379,7 @@ func (h *APIHandler) HandleOutdated(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {string} string
// @Router /api/bulk [post]
func (h *APIHandler) HandleBulkLookup(w http.ResponseWriter, r *http.Request) {
- r.Body = http.MaxBytesReader(w, r.Body, 1<<20) // 1 MB
+ r.Body = http.MaxBytesReader(w, r.Body, maxBodySize)
var req BulkRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "invalid request body", http.StatusBadRequest)
@@ -573,11 +579,11 @@ func (h *APIHandler) HandlePackagesList(w http.ResponseWriter, r *http.Request)
ecosystem := r.URL.Query().Get("ecosystem")
sortBy := r.URL.Query().Get("sort")
if sortBy == "" {
- sortBy = "hits"
+ sortBy = defaultSortBy
}
validSorts := map[string]bool{
- "hits": true,
+ defaultSortBy: true,
"name": true,
"size": true,
"cached_at": true,
@@ -619,7 +625,7 @@ func (h *APIHandler) HandlePackagesList(w http.ResponseWriter, r *http.Request)
latestVersion = pkg.LatestVersion.String
}
license := ""
- licenseCategory := "unknown"
+ licenseCategory := licenseCategoryUnknown
if pkg.License.Valid {
license = pkg.License.String
if h.enrichment != nil {
diff --git a/internal/server/api_test.go b/internal/server/api_test.go
index 88705a1..96cce9e 100644
--- a/internal/server/api_test.go
+++ b/internal/server/api_test.go
@@ -16,6 +16,8 @@ import (
"github.com/go-chi/chi/v5"
)
+const testEcosystemNPM = "npm"
+
func TestNewAPIHandler(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
svc := enrichment.New(logger)
@@ -178,7 +180,7 @@ func TestWriteJSON(t *testing.T) {
func TestPackageResponseJSON(t *testing.T) {
resp := &PackageResponse{
- Ecosystem: "npm",
+ Ecosystem: testEcosystemNPM,
Name: "lodash",
LatestVersion: "4.17.21",
License: "MIT",
@@ -199,7 +201,7 @@ func TestPackageResponseJSON(t *testing.T) {
t.Fatalf("failed to unmarshal: %v", err)
}
- if decoded.Ecosystem != "npm" {
+ if decoded.Ecosystem != testEcosystemNPM {
t.Errorf("expected ecosystem npm, got %s", decoded.Ecosystem)
}
if decoded.Name != "lodash" {
@@ -310,7 +312,7 @@ func TestHandleSearch_WithNullValues(t *testing.T) {
pkg := &database.Package{
PURL: "pkg:npm/api-test",
- Ecosystem: "npm",
+ Ecosystem: testEcosystemNPM,
Name: "api-test",
}
if err := db.UpsertPackage(pkg); err != nil {
@@ -387,7 +389,7 @@ func TestHandlePackagesListAPI(t *testing.T) {
for _, name := range []string{"api-list-one", "api-list-two"} {
pkg := &database.Package{
PURL: "pkg:npm/" + name,
- Ecosystem: "npm",
+ Ecosystem: testEcosystemNPM,
Name: name,
}
if err := db.UpsertPackage(pkg); err != nil {
@@ -433,7 +435,7 @@ func TestHandlePackagesListAPI(t *testing.T) {
t.Fatalf("expected at least 2 results, got %d", len(resp.Results))
}
- if resp.SortBy != "hits" {
+ if resp.SortBy != defaultSortBy {
t.Errorf("expected default sort by hits, got %q", resp.SortBy)
}
diff --git a/internal/server/browse.go b/internal/server/browse.go
index f19ea67..372df50 100644
--- a/internal/server/browse.go
+++ b/internal/server/browse.go
@@ -15,6 +15,8 @@ import (
"github.com/go-chi/chi/v5"
)
+const contentTypePlainText = "text/plain; charset=utf-8"
+
// getStripPrefix returns the path prefix to strip for a given ecosystem.
// npm packages wrap content in a "package/" directory.
func getStripPrefix(ecosystem string) string {
@@ -240,7 +242,7 @@ func detectContentType(filename string) string {
switch ext {
// Text formats
case ".txt", ".md", ".markdown":
- return "text/plain; charset=utf-8"
+ return contentTypePlainText
case ".html", ".htm":
return "text/html; charset=utf-8"
case ".css":
@@ -282,7 +284,7 @@ func detectContentType(filename string) string {
// Config files
case ".conf", ".config", ".ini":
- return "text/plain; charset=utf-8"
+ return contentTypePlainText
case ".sh", ".bash":
return "text/x-shellscript; charset=utf-8"
case ".dockerfile":
@@ -307,7 +309,7 @@ func detectContentType(filename string) string {
default:
// Try to detect if it looks like text
if isLikelyText(filename) {
- return "text/plain; charset=utf-8"
+ return contentTypePlainText
}
return "application/octet-stream"
}
@@ -482,8 +484,9 @@ func (s *Server) handleComparePage(w http.ResponseWriter, r *http.Request) {
versions := chi.URLParam(r, "versions")
// Parse versions (format: "1.0.0...2.0.0")
+ const compareVersionParts = 2
parts := strings.Split(versions, "...")
- if len(parts) != 2 {
+ if len(parts) != compareVersionParts {
http.Error(w, "invalid version format, use: version1...version2", http.StatusBadRequest)
return
}
diff --git a/internal/server/browse_test.go b/internal/server/browse_test.go
index b85116d..13680a5 100644
--- a/internal/server/browse_test.go
+++ b/internal/server/browse_test.go
@@ -16,6 +16,8 @@ import (
"github.com/git-pkgs/proxy/internal/database"
)
+const testArchiveName = "test.tar.gz"
+
func TestHandleBrowseList(t *testing.T) {
ts := newTestServer(t)
defer ts.close()
@@ -26,12 +28,12 @@ func TestHandleBrowseList(t *testing.T) {
if err := os.MkdirAll(artifactsDir, 0755); err != nil {
t.Fatalf("failed to create artifacts dir: %v", err)
}
- storagePath := filepath.Join(artifactsDir, "test.tar.gz")
+ storagePath := filepath.Join(artifactsDir, testArchiveName)
if err := os.WriteFile(storagePath, archiveData, 0644); err != nil {
t.Fatalf("failed to write test archive: %v", err)
}
// Storage path relative to artifacts directory
- relPath := "test.tar.gz"
+ relPath := testArchiveName
// Setup test package and artifact
pkg := &database.Package{
@@ -99,12 +101,12 @@ func TestHandleBrowseFile(t *testing.T) {
if err := os.MkdirAll(artifactsDir, 0755); err != nil {
t.Fatalf("failed to create artifacts dir: %v", err)
}
- storagePath := filepath.Join(artifactsDir, "test.tar.gz")
+ storagePath := filepath.Join(artifactsDir, testArchiveName)
if err := os.WriteFile(storagePath, archiveData, 0644); err != nil {
t.Fatalf("failed to write test archive: %v", err)
}
// Storage path relative to artifacts directory
- relPath := "test.tar.gz"
+ relPath := testArchiveName
// Setup test package and artifact
pkg := &database.Package{
@@ -150,7 +152,7 @@ func TestHandleBrowseFile(t *testing.T) {
// Check content type
contentType := w.Header().Get("Content-Type")
- if contentType != "text/plain; charset=utf-8" {
+ if contentType != contentTypePlainText {
t.Errorf("expected text/plain content type, got %q", contentType)
}
@@ -169,8 +171,8 @@ func TestDetectContentType(t *testing.T) {
filename string
expectedCT string
}{
- {"file.txt", "text/plain; charset=utf-8"},
- {"file.md", "text/plain; charset=utf-8"},
+ {"file.txt", contentTypePlainText},
+ {"file.md", contentTypePlainText},
{"file.json", "application/json; charset=utf-8"},
{"file.js", "application/javascript; charset=utf-8"},
{"file.go", "text/x-go; charset=utf-8"},
@@ -178,10 +180,10 @@ func TestDetectContentType(t *testing.T) {
{"file.rs", "text/x-rust; charset=utf-8"},
{"file.png", "image/png"},
{"file.jpg", "image/jpeg"},
- {"README", "text/plain; charset=utf-8"},
- {"LICENSE", "text/plain; charset=utf-8"},
- {"Makefile", "text/plain; charset=utf-8"},
- {".gitignore", "text/plain; charset=utf-8"},
+ {"README", contentTypePlainText},
+ {"LICENSE", contentTypePlainText},
+ {"Makefile", contentTypePlainText},
+ {".gitignore", contentTypePlainText},
{"file.bin", "application/octet-stream"},
}
@@ -313,11 +315,11 @@ func TestHandleBrowseSourcePage(t *testing.T) {
if err := os.MkdirAll(artifactsDir, 0755); err != nil {
t.Fatalf("failed to create artifacts dir: %v", err)
}
- storagePath := filepath.Join(artifactsDir, "test.tar.gz")
+ storagePath := filepath.Join(artifactsDir, testArchiveName)
if err := os.WriteFile(storagePath, archiveData, 0644); err != nil {
t.Fatalf("failed to write test archive: %v", err)
}
- relPath := "test.tar.gz"
+ relPath := testArchiveName
// Setup test package and artifact
pkg := &database.Package{
diff --git a/internal/server/packages_list_test.go b/internal/server/packages_list_test.go
index 1b915ad..ac57d74 100644
--- a/internal/server/packages_list_test.go
+++ b/internal/server/packages_list_test.go
@@ -28,7 +28,7 @@ func TestHandlePackagesList(t *testing.T) {
// Create test data
pkg1 := &database.Package{
PURL: "pkg:npm/lodash",
- Ecosystem: "npm",
+ Ecosystem: testEcosystemNPM,
Name: "lodash",
LatestVersion: sql.NullString{String: "4.17.21", Valid: true},
License: sql.NullString{String: "MIT", Valid: true},
@@ -103,7 +103,7 @@ func TestHandlePackagesList(t *testing.T) {
if len(resp.Results) != 2 {
t.Errorf("expected 2 results, got %d", len(resp.Results))
}
- if resp.SortBy != "hits" {
+ if resp.SortBy != defaultSortBy {
t.Errorf("expected default sort to be hits, got %s", resp.SortBy)
}
})
@@ -123,7 +123,7 @@ func TestHandlePackagesList(t *testing.T) {
t.Fatal(err)
}
- if resp.Ecosystem != "npm" {
+ if resp.Ecosystem != testEcosystemNPM {
t.Errorf("expected ecosystem npm, got %s", resp.Ecosystem)
}
if resp.Count != 1 {
diff --git a/internal/server/server.go b/internal/server/server.go
index 19eb468..8e6b588 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -60,6 +60,14 @@ import (
"github.com/go-chi/chi/v5/middleware"
)
+const (
+ serverReadTimeout = 30 * time.Second
+ serverWriteTimeout = 5 * time.Minute
+ serverIdleTimeout = 60 * time.Second
+ dashboardTopN = 10
+ hoursPerDay = 24
+)
+
// Server is the main proxy server.
type Server struct {
cfg *config.Config
@@ -96,7 +104,7 @@ func New(cfg *config.Config, logger *slog.Logger) (*Server, error) {
storageURL := cfg.Storage.URL
if storageURL == "" {
// Fall back to file:// with Path
- storageURL = "file://" + cfg.Storage.Path
+ storageURL = "file://" + cfg.Storage.Path //nolint:staticcheck // backwards compat
}
store, err := storage.OpenBucket(context.Background(), storageURL)
if err != nil {
@@ -228,15 +236,15 @@ func (s *Server) Start() error {
s.http = &http.Server{
Addr: s.cfg.Listen,
Handler: r,
- ReadTimeout: 30 * time.Second,
- WriteTimeout: 5 * time.Minute, // Large artifacts need time
- IdleTimeout: 60 * time.Second,
+ ReadTimeout: serverReadTimeout,
+ WriteTimeout: serverWriteTimeout, // Large artifacts need time
+ IdleTimeout: serverIdleTimeout,
}
s.logger.Info("starting server",
"listen", s.cfg.Listen,
"base_url", s.cfg.BaseURL,
- "storage", s.cfg.Storage.Path,
+ "storage", s.cfg.Storage.Path, //nolint:staticcheck // backwards compat
"database", s.cfg.Database.Path)
// Start background goroutine to update cache stats metrics
@@ -316,13 +324,13 @@ func (s *Server) handleRoot(w http.ResponseWriter, r *http.Request) {
}
// Get popular packages
- popular, err := s.db.GetMostPopularPackages(10)
+ popular, err := s.db.GetMostPopularPackages(dashboardTopN)
if err != nil {
s.logger.Error("failed to get popular packages", "error", err)
}
// Get recent packages
- recent, err := s.db.GetRecentlyCachedPackages(10)
+ recent, err := s.db.GetRecentlyCachedPackages(dashboardTopN)
if err != nil {
s.logger.Error("failed to get recent packages", "error", err)
}
@@ -497,7 +505,7 @@ func (s *Server) handlePackagesList(w http.ResponseWriter, r *http.Request) {
ecosystem := r.URL.Query().Get("ecosystem")
sortBy := r.URL.Query().Get("sort")
if sortBy == "" {
- sortBy = "hits"
+ sortBy = defaultSortBy
}
page := 1
@@ -731,7 +739,7 @@ func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) {
CachedArtifacts: count,
TotalSize: size,
TotalSizeHuman: formatSize(size),
- StoragePath: s.cfg.Storage.Path,
+ StoragePath: s.cfg.Storage.Path, //nolint:staticcheck // backwards compat
DatabasePath: s.cfg.Database.Path,
}
@@ -766,14 +774,14 @@ func formatTimeAgo(t time.Time) string {
return "1 min ago"
}
return fmt.Sprintf("%d mins ago", m)
- case d < 24*time.Hour:
+ case d < hoursPerDay*time.Hour:
h := int(d.Hours())
if h == 1 {
return "1 hour ago"
}
return fmt.Sprintf("%d hours ago", h)
- case d < 7*24*time.Hour:
- days := int(d.Hours() / 24)
+ case d < 7*hoursPerDay*time.Hour:
+ days := int(d.Hours() / hoursPerDay)
if days == 1 {
return "1 day ago"
}
@@ -786,7 +794,7 @@ func formatTimeAgo(t time.Time) string {
// categorizeLicenseCSS returns the CSS class suffix for a license category using the spdx module.
func categorizeLicenseCSS(license string) string {
if license == "" {
- return "unknown"
+ return licenseCategoryUnknown
}
if spdx.HasCopyleft(license) {
@@ -797,13 +805,13 @@ func categorizeLicenseCSS(license string) string {
return "permissive"
}
- return "unknown"
+ return licenseCategoryUnknown
}
// categorizeLicense is a helper that handles sql.NullString.
func categorizeLicense(license sql.NullString) string {
if !license.Valid {
- return "unknown"
+ return licenseCategoryUnknown
}
return categorizeLicenseCSS(license.String)
}
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
index 555e80a..69f36e8 100644
--- a/internal/server/server_test.go
+++ b/internal/server/server_test.go
@@ -125,6 +125,39 @@ func (ts *testServer) close() {
_ = os.RemoveAll(ts.tempDir)
}
+// seedTestPackage creates a package, version, and artifact in the database for testing
+// page rendering. The package is created under the npm ecosystem with version 1.0.0.
+func seedTestPackage(t *testing.T, db *database.DB, name string) {
+ t.Helper()
+
+ pkg := &database.Package{
+ PURL: "pkg:npm/" + name,
+ Ecosystem: "npm",
+ Name: name,
+ }
+ if err := db.UpsertPackage(pkg); err != nil {
+ t.Fatalf("failed to upsert package: %v", err)
+ }
+
+ ver := &database.Version{
+ PURL: "pkg:npm/" + name + "@1.0.0",
+ PackagePURL: pkg.PURL,
+ }
+ if err := db.UpsertVersion(ver); err != nil {
+ t.Fatalf("failed to upsert version: %v", err)
+ }
+
+ artifact := &database.Artifact{
+ VersionPURL: ver.PURL,
+ Filename: name + "-1.0.0.tgz",
+ UpstreamURL: "https://registry.npmjs.org/" + name + "/-/" + name + "-1.0.0.tgz",
+ StoragePath: sql.NullString{String: "/tmp/test.tgz", Valid: true},
+ }
+ if err := db.UpsertArtifact(artifact); err != nil {
+ t.Fatalf("failed to upsert artifact: %v", err)
+ }
+}
+
func TestHandleOpenAPIJSON(t *testing.T) {
ts := newTestServer(t)
defer ts.close()
@@ -668,32 +701,7 @@ func TestSearchPage_WithSeededResults(t *testing.T) {
ts := newTestServer(t)
defer ts.close()
- pkg := &database.Package{
- PURL: "pkg:npm/searchable-pkg",
- Ecosystem: "npm",
- Name: "searchable-pkg",
- }
- if err := ts.db.UpsertPackage(pkg); err != nil {
- t.Fatalf("failed to upsert package: %v", err)
- }
-
- ver := &database.Version{
- PURL: "pkg:npm/searchable-pkg@1.0.0",
- PackagePURL: pkg.PURL,
- }
- if err := ts.db.UpsertVersion(ver); err != nil {
- t.Fatalf("failed to upsert version: %v", err)
- }
-
- artifact := &database.Artifact{
- VersionPURL: ver.PURL,
- Filename: "searchable-pkg-1.0.0.tgz",
- UpstreamURL: "https://registry.npmjs.org/searchable-pkg/-/searchable-pkg-1.0.0.tgz",
- StoragePath: sql.NullString{String: "/tmp/test.tgz", Valid: true},
- }
- if err := ts.db.UpsertArtifact(artifact); err != nil {
- t.Fatalf("failed to upsert artifact: %v", err)
- }
+ seedTestPackage(t, ts.db, "searchable-pkg")
req := httptest.NewRequest("GET", "/search?q=searchable", nil)
w := httptest.NewRecorder()
@@ -844,32 +852,7 @@ func TestHandlePackagesListPage(t *testing.T) {
ts := newTestServer(t)
defer ts.close()
- pkg := &database.Package{
- PURL: "pkg:npm/list-test",
- Ecosystem: "npm",
- Name: "list-test",
- }
- if err := ts.db.UpsertPackage(pkg); err != nil {
- t.Fatalf("failed to upsert package: %v", err)
- }
-
- ver := &database.Version{
- PURL: "pkg:npm/list-test@1.0.0",
- PackagePURL: pkg.PURL,
- }
- if err := ts.db.UpsertVersion(ver); err != nil {
- t.Fatalf("failed to upsert version: %v", err)
- }
-
- artifact := &database.Artifact{
- VersionPURL: ver.PURL,
- Filename: "list-test-1.0.0.tgz",
- UpstreamURL: "https://registry.npmjs.org/list-test/-/list-test-1.0.0.tgz",
- StoragePath: sql.NullString{String: "/tmp/test.tgz", Valid: true},
- }
- if err := ts.db.UpsertArtifact(artifact); err != nil {
- t.Fatalf("failed to upsert artifact: %v", err)
- }
+ seedTestPackage(t, ts.db, "list-test")
req := httptest.NewRequest("GET", "/packages", nil)
w := httptest.NewRecorder()
diff --git a/internal/server/templates_test.go b/internal/server/templates_test.go
index 3a67880..9d26269 100644
--- a/internal/server/templates_test.go
+++ b/internal/server/templates_test.go
@@ -62,7 +62,7 @@ func TestTemplatesRenderAllPages(t *testing.T) {
}},
{"packages_list", PackagesListPageData{
Ecosystem: "",
- SortBy: "hits",
+ SortBy: defaultSortBy,
Results: []SearchResultItem{{Ecosystem: "npm", Name: "express", Hits: 200, SizeFormatted: "2 MB"}},
Count: 1,
Page: 1,
diff --git a/internal/storage/blob.go b/internal/storage/blob.go
index 22643fa..11da357 100644
--- a/internal/storage/blob.go
+++ b/internal/storage/blob.go
@@ -17,6 +17,8 @@ import (
"gocloud.dev/gcerrors"
)
+const osWindows = "windows"
+
// Blob implements Storage using gocloud.dev/blob.
// Supports local filesystem (file://) and S3 (s3://) URLs.
type Blob struct {
@@ -38,10 +40,10 @@ func OpenBucket(ctx context.Context, urlStr string) (*Blob, error) {
path := strings.TrimPrefix(urlStr, "file://")
// Handle file:/// (three slashes) for absolute paths
- if strings.HasPrefix(path, "/") && runtime.GOOS != "windows" {
+ if strings.HasPrefix(path, "/") && runtime.GOOS != osWindows {
// Unix: file:///path -> /path
// path is already correct
- } else if strings.HasPrefix(path, "/") && runtime.GOOS == "windows" {
+ } else if strings.HasPrefix(path, "/") && runtime.GOOS == osWindows {
// Windows: file:///C:/path -> C:/path
path = strings.TrimPrefix(path, "/")
}
@@ -50,7 +52,7 @@ func OpenBucket(ctx context.Context, urlStr string) (*Blob, error) {
nativePath := filepath.FromSlash(path)
// Ensure directory exists
- if err := os.MkdirAll(nativePath, 0755); err != nil {
+ if err := os.MkdirAll(nativePath, dirPermissions); err != nil {
return nil, fmt.Errorf("creating directory: %w", err)
}
@@ -62,7 +64,7 @@ func OpenBucket(ctx context.Context, urlStr string) (*Blob, error) {
// Convert back to URL format with forward slashes
urlPath := filepath.ToSlash(absPath)
- if runtime.GOOS == "windows" {
+ if runtime.GOOS == osWindows {
// Windows needs file:///C:/path format
urlStr = "file:///" + urlPath
} else {
diff --git a/internal/storage/blob_test.go b/internal/storage/blob_test.go
index 3add0a7..cfb8281 100644
--- a/internal/storage/blob_test.go
+++ b/internal/storage/blob_test.go
@@ -1,7 +1,6 @@
package storage
import (
- "bytes"
"context"
"crypto/sha256"
"encoding/hex"
@@ -186,33 +185,7 @@ func TestBlobUsedSpace(t *testing.T) {
}
func TestBlobLargeFile(t *testing.T) {
- b := createTestBlob(t)
- ctx := context.Background()
-
- // 1MB of data
- data := bytes.Repeat([]byte("x"), 1024*1024)
-
- size, hash, err := b.Store(ctx, "large/file.bin", bytes.NewReader(data))
- if err != nil {
- t.Fatalf("Store large file failed: %v", err)
- }
- if size != int64(len(data)) {
- t.Errorf("size = %d, want %d", size, len(data))
- }
-
- h := sha256.Sum256(data)
- wantHash := hex.EncodeToString(h[:])
- if hash != wantHash {
- t.Errorf("hash mismatch for large file")
- }
-
- // Read it back
- r, _ := b.Open(ctx, "large/file.bin")
- defer func() { _ = r.Close() }()
- readBack, _ := io.ReadAll(r)
- if !bytes.Equal(readBack, data) {
- t.Error("large file content mismatch")
- }
+ assertLargeFileRoundTrip(t, createTestBlob(t))
}
func TestBlobOverwrite(t *testing.T) {
@@ -258,7 +231,7 @@ func createTestBlob(t *testing.T) *Blob {
}
func fileURLFromPath(path string) string {
- if runtime.GOOS == "windows" {
+ if runtime.GOOS == osWindows {
// Windows paths need file:///C:/path format
path = filepath.ToSlash(path)
return "file:///" + path
diff --git a/internal/storage/filesystem.go b/internal/storage/filesystem.go
index 01be90d..8dec48b 100644
--- a/internal/storage/filesystem.go
+++ b/internal/storage/filesystem.go
@@ -23,7 +23,7 @@ func NewFilesystem(root string) (*Filesystem, error) {
return nil, fmt.Errorf("resolving root path: %w", err)
}
- if err := os.MkdirAll(absRoot, 0755); err != nil {
+ if err := os.MkdirAll(absRoot, dirPermissions); err != nil {
return nil, fmt.Errorf("creating root directory: %w", err)
}
@@ -38,7 +38,7 @@ func (fs *Filesystem) Store(ctx context.Context, path string, r io.Reader) (int6
fullPath := fs.fullPath(path)
dir := filepath.Dir(fullPath)
- if err := os.MkdirAll(dir, 0755); err != nil {
+ if err := os.MkdirAll(dir, dirPermissions); err != nil {
return 0, "", fmt.Errorf("creating directory: %w", err)
}
diff --git a/internal/storage/filesystem_test.go b/internal/storage/filesystem_test.go
index 16c7447..7fbba10 100644
--- a/internal/storage/filesystem_test.go
+++ b/internal/storage/filesystem_test.go
@@ -1,7 +1,6 @@
package storage
import (
- "bytes"
"context"
"crypto/sha256"
"encoding/hex"
@@ -234,33 +233,7 @@ func TestFilesystemUsedSpace(t *testing.T) {
}
func TestFilesystemLargeFile(t *testing.T) {
- fs := createTestFilesystem(t)
- ctx := context.Background()
-
- // 1MB of data
- data := bytes.Repeat([]byte("x"), 1024*1024)
-
- size, hash, err := fs.Store(ctx, "large/file.bin", bytes.NewReader(data))
- if err != nil {
- t.Fatalf("Store large file failed: %v", err)
- }
- if size != int64(len(data)) {
- t.Errorf("size = %d, want %d", size, len(data))
- }
-
- h := sha256.Sum256(data)
- wantHash := hex.EncodeToString(h[:])
- if hash != wantHash {
- t.Errorf("hash mismatch for large file")
- }
-
- // Read it back
- r, _ := fs.Open(ctx, "large/file.bin")
- defer func() { _ = r.Close() }()
- readBack, _ := io.ReadAll(r)
- if !bytes.Equal(readBack, data) {
- t.Error("large file content mismatch")
- }
+ assertLargeFileRoundTrip(t, createTestFilesystem(t))
}
func createTestFilesystem(t *testing.T) *Filesystem {
diff --git a/internal/storage/storage.go b/internal/storage/storage.go
index 1efbcba..93053ca 100644
--- a/internal/storage/storage.go
+++ b/internal/storage/storage.go
@@ -17,6 +17,8 @@ import (
"io"
)
+const dirPermissions = 0755
+
var (
ErrNotFound = errors.New("artifact not found")
)
diff --git a/internal/storage/storage_test.go b/internal/storage/storage_test.go
index c71df26..97800f0 100644
--- a/internal/storage/storage_test.go
+++ b/internal/storage/storage_test.go
@@ -1,6 +1,8 @@
package storage
import (
+ "bytes"
+ "context"
"crypto/sha256"
"encoding/hex"
"io"
@@ -56,3 +58,33 @@ func TestHashingReader(t *testing.T) {
t.Errorf("got hash %s, want %s", r.Sum(), wantHash)
}
}
+
+// assertLargeFileRoundTrip stores a 1MB file in the given storage, verifies size and
+// hash, then reads it back and confirms the content matches.
+func assertLargeFileRoundTrip(t *testing.T, s Storage) {
+ t.Helper()
+ ctx := context.Background()
+
+ data := bytes.Repeat([]byte("x"), 1024*1024)
+
+ size, hash, err := s.Store(ctx, "large/file.bin", bytes.NewReader(data))
+ if err != nil {
+ t.Fatalf("Store large file failed: %v", err)
+ }
+ if size != int64(len(data)) {
+ t.Errorf("size = %d, want %d", size, len(data))
+ }
+
+ h := sha256.Sum256(data)
+ wantHash := hex.EncodeToString(h[:])
+ if hash != wantHash {
+ t.Errorf("hash mismatch for large file")
+ }
+
+ r, _ := s.Open(ctx, "large/file.bin")
+ defer func() { _ = r.Close() }()
+ readBack, _ := io.ReadAll(r)
+ if !bytes.Equal(readBack, data) {
+ t.Error("large file content mismatch")
+ }
+}