diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 5756c6d..4f13ea2 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -103,6 +103,8 @@ import ( "github.com/git-pkgs/proxy/internal/server" ) +const defaultTopN = 10 + var ( // Version is set at build time. Version = "dev" @@ -211,7 +213,7 @@ func runServe() { cfg.Storage.URL = *storageURL } if *storagePath != "" { - cfg.Storage.Path = *storagePath + cfg.Storage.Path = *storagePath //nolint:staticcheck // backwards compat } if *databaseDriver != "" { cfg.Database.Driver = *databaseDriver @@ -247,7 +249,6 @@ func runServe() { // Handle shutdown signals ctx, cancel := context.WithCancel(context.Background()) - defer cancel() go func() { sigCh := make(chan os.Signal, 1) @@ -266,10 +267,12 @@ func runServe() { // Wait for shutdown or error select { case <-ctx.Done(): + cancel() if err := srv.Shutdown(context.Background()); err != nil { logger.Error("shutdown error", "error", err) } case err := <-errCh: + cancel() if err != nil { logger.Error("server error", "error", err) os.Exit(1) @@ -283,8 +286,8 @@ func runStats() { databasePath := fs.String("database-path", "./cache/proxy.db", "Path to SQLite database file") databaseURL := fs.String("database-url", "", "PostgreSQL connection URL") asJSON := fs.Bool("json", false, "Output as JSON") - popular := fs.Int("popular", 10, "Show top N most popular packages") - recent := fs.Int("recent", 10, "Show N recently cached packages") + popular := fs.Int("popular", defaultTopN, "Show top N most popular packages") + recent := fs.Int("recent", defaultTopN, "Show N recently cached packages") fs.Usage = func() { fmt.Fprintf(os.Stderr, "git-pkgs proxy - Show cache statistics\n\n") @@ -330,32 +333,37 @@ func runStats() { fmt.Fprintf(os.Stderr, "error opening database: %v\n", err) os.Exit(1) } + + if err := printStats(db, *popular, *recent, *asJSON); err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } +} + +func printStats(db *database.DB, popular, recent int, asJSON bool) error { defer func() { _ = db.Close() }() - // Get stats stats, err := db.GetCacheStats() if err != nil { - fmt.Fprintf(os.Stderr, "error getting stats: %v\n", err) - os.Exit(1) + return fmt.Errorf("error getting stats: %w", err) } - popularPkgs, err := db.GetMostPopularPackages(*popular) + popularPkgs, err := db.GetMostPopularPackages(popular) if err != nil { - fmt.Fprintf(os.Stderr, "error getting popular packages: %v\n", err) - os.Exit(1) + return fmt.Errorf("error getting popular packages: %w", err) } - recentPkgs, err := db.GetRecentlyCachedPackages(*recent) + recentPkgs, err := db.GetRecentlyCachedPackages(recent) if err != nil { - fmt.Fprintf(os.Stderr, "error getting recent packages: %v\n", err) - os.Exit(1) + return fmt.Errorf("error getting recent packages: %w", err) } - if *asJSON { + if asJSON { outputJSON(stats, popularPkgs, recentPkgs) } else { outputText(stats, popularPkgs, recentPkgs) } + return nil } type jsonOutput struct { diff --git a/internal/config/config.go b/internal/config/config.go index 76067ae..3bc45af 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -109,8 +109,9 @@ type StorageConfig struct { URL string `json:"url" yaml:"url"` // Path is the directory where cached artifacts are stored. - // Deprecated: Use URL with file:// scheme instead. // If URL is empty, this is used as file://{Path}. + // + // Deprecated: Use URL with file:// scheme instead. Path string `json:"path" yaml:"path"` // MaxSize is the maximum cache size (e.g., "10GB", "500MB"). diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 47bacf0..4dd1c17 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -6,6 +6,12 @@ import ( "testing" ) +const ( + testDriverPostgres = "postgres" + testInvalid = "invalid" + testLevelDebug = "debug" +) + func TestDefault(t *testing.T) { cfg := Default() @@ -63,27 +69,27 @@ func TestValidate(t *testing.T) { }, { name: "postgres without url", - modify: func(c *Config) { c.Database.Driver = "postgres"; c.Database.URL = "" }, + modify: func(c *Config) { c.Database.Driver = testDriverPostgres; c.Database.URL = "" }, wantErr: true, }, { name: "postgres with url", - modify: func(c *Config) { c.Database.Driver = "postgres"; c.Database.URL = "postgres://localhost/test" }, + modify: func(c *Config) { c.Database.Driver = testDriverPostgres; c.Database.URL = "postgres://localhost/test" }, wantErr: false, }, { name: "invalid log level", - modify: func(c *Config) { c.Log.Level = "invalid" }, + modify: func(c *Config) { c.Log.Level = testInvalid }, wantErr: true, }, { name: "invalid log format", - modify: func(c *Config) { c.Log.Format = "invalid" }, + modify: func(c *Config) { c.Log.Format = testInvalid }, wantErr: true, }, { name: "invalid max size", - modify: func(c *Config) { c.Storage.MaxSize = "invalid" }, + modify: func(c *Config) { c.Storage.MaxSize = testInvalid }, wantErr: true, }, { @@ -176,8 +182,8 @@ log: if cfg.Storage.MaxSize != "5GB" { t.Errorf("Storage.MaxSize = %q, want %q", cfg.Storage.MaxSize, "5GB") } - if cfg.Log.Level != "debug" { - t.Errorf("Log.Level = %q, want %q", cfg.Log.Level, "debug") + if cfg.Log.Level != testLevelDebug { + t.Errorf("Log.Level = %q, want %q", cfg.Log.Level, testLevelDebug) } if cfg.Log.Format != "json" { t.Errorf("Log.Format = %q, want %q", cfg.Log.Format, "json") @@ -215,7 +221,7 @@ func TestLoadFromEnv(t *testing.T) { t.Setenv("PROXY_LISTEN", ":9000") t.Setenv("PROXY_BASE_URL", "https://env.example.com") t.Setenv("PROXY_STORAGE_PATH", "/env/cache") - t.Setenv("PROXY_LOG_LEVEL", "debug") + t.Setenv("PROXY_LOG_LEVEL", testLevelDebug) cfg.LoadFromEnv() @@ -228,8 +234,8 @@ func TestLoadFromEnv(t *testing.T) { if cfg.Storage.Path != "/env/cache" { t.Errorf("Storage.Path = %q, want %q", cfg.Storage.Path, "/env/cache") } - if cfg.Log.Level != "debug" { - t.Errorf("Log.Level = %q, want %q", cfg.Log.Level, "debug") + if cfg.Log.Level != testLevelDebug { + t.Errorf("Log.Level = %q, want %q", cfg.Log.Level, testLevelDebug) } } diff --git a/internal/cooldown/cooldown.go b/internal/cooldown/cooldown.go index 2bb4fc6..f37a2b9 100644 --- a/internal/cooldown/cooldown.go +++ b/internal/cooldown/cooldown.go @@ -7,6 +7,8 @@ import ( "time" ) +const hoursPerDay = 24 + // Config holds cooldown settings for version filtering. // Cooldown hides package versions published too recently, giving the community // time to spot malicious releases before they're pulled into projects. @@ -112,7 +114,7 @@ func ParseDuration(s string) (time.Duration, error) { if err != nil { return 0, fmt.Errorf("invalid duration %q: %w", s, err) } - return time.Duration(days * float64(24*time.Hour)), nil + return time.Duration(days * float64(hoursPerDay*time.Hour)), nil } d, err := time.ParseDuration(s) diff --git a/internal/database/database.go b/internal/database/database.go index c04c122..eded6d2 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -12,6 +12,8 @@ import ( const SchemaVersion = 1 +const dirPermissions = 0755 + type Dialect string const ( @@ -56,7 +58,7 @@ func Create(path string) (*DB, error) { func Open(path string) (*DB, error) { if dir := filepath.Dir(path); dir != "." && dir != "/" { - if err := os.MkdirAll(dir, 0755); err != nil { + if err := os.MkdirAll(dir, dirPermissions); err != nil { return nil, fmt.Errorf("creating database directory: %w", err) } } diff --git a/internal/database/queries_packages_list_test.go b/internal/database/queries_packages_list_test.go index b77d34d..fc9356c 100644 --- a/internal/database/queries_packages_list_test.go +++ b/internal/database/queries_packages_list_test.go @@ -6,118 +6,121 @@ import ( "time" ) -func TestListCachedPackages(t *testing.T) { +const testEcosystemNPM = "npm" + +func setupListCachedPackagesDB(t *testing.T) *DB { + t.Helper() + db, err := Create(t.TempDir() + "/test.db") if err != nil { t.Fatal(err) } - defer func() { _ = db.Close() }() - // Create test packages - pkg1 := &Package{ - PURL: "pkg:npm/lodash", - Ecosystem: "npm", - Name: "lodash", - LatestVersion: sql.NullString{String: "4.17.21", Valid: true}, - License: sql.NullString{String: "MIT", Valid: true}, - } - pkg2 := &Package{ - PURL: "pkg:cargo/serde", - Ecosystem: "cargo", - Name: "serde", - LatestVersion: sql.NullString{String: "1.0.0", Valid: true}, - License: sql.NullString{String: "MIT OR Apache-2.0", Valid: true}, - } - pkg3 := &Package{ - PURL: "pkg:npm/react", - Ecosystem: "npm", - Name: "react", - LatestVersion: sql.NullString{String: "18.0.0", Valid: true}, - License: sql.NullString{String: "MIT", Valid: true}, - } + seedListCachedPackagesData(t, db) - if err := db.UpsertPackage(pkg1); err != nil { - t.Fatal(err) - } - if err := db.UpsertPackage(pkg2); err != nil { - t.Fatal(err) - } - if err := db.UpsertPackage(pkg3); err != nil { - t.Fatal(err) - } + return db +} - // Create versions - ver1 := &Version{ - PURL: "pkg:npm/lodash@4.17.21", - PackagePURL: pkg1.PURL, - } - ver2 := &Version{ - PURL: "pkg:cargo/serde@1.0.0", - PackagePURL: pkg2.PURL, - } - ver3 := &Version{ - PURL: "pkg:npm/react@18.0.0", - PackagePURL: pkg3.PURL, - } +func seedListCachedPackagesData(t *testing.T, db *DB) { + t.Helper() - if err := db.UpsertVersion(ver1); err != nil { - t.Fatal(err) - } - if err := db.UpsertVersion(ver2); err != nil { - t.Fatal(err) - } - if err := db.UpsertVersion(ver3); err != nil { - t.Fatal(err) + packages := []*Package{ + { + PURL: "pkg:npm/lodash", + Ecosystem: testEcosystemNPM, + Name: "lodash", + LatestVersion: sql.NullString{String: "4.17.21", Valid: true}, + License: sql.NullString{String: "MIT", Valid: true}, + }, + { + PURL: "pkg:cargo/serde", + Ecosystem: "cargo", + Name: "serde", + LatestVersion: sql.NullString{String: "1.0.0", Valid: true}, + License: sql.NullString{String: "MIT OR Apache-2.0", Valid: true}, + }, + { + PURL: "pkg:npm/react", + Ecosystem: testEcosystemNPM, + Name: "react", + LatestVersion: sql.NullString{String: "18.0.0", Valid: true}, + License: sql.NullString{String: "MIT", Valid: true}, + }, } - // Create artifacts - art1 := &Artifact{ - VersionPURL: ver1.PURL, - Filename: "lodash.tgz", - UpstreamURL: "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", - StoragePath: sql.NullString{String: "npm/lodash/4.17.21/lodash.tgz", Valid: true}, - Size: sql.NullInt64{Int64: 1024, Valid: true}, - HitCount: 100, - FetchedAt: sql.NullTime{Time: time.Now(), Valid: true}, - } - art2 := &Artifact{ - VersionPURL: ver2.PURL, - Filename: "serde.crate", - UpstreamURL: "https://crates.io/api/v1/crates/serde/1.0.0/download", - StoragePath: sql.NullString{String: "cargo/serde/1.0.0/serde.crate", Valid: true}, - Size: sql.NullInt64{Int64: 2048, Valid: true}, - HitCount: 50, - FetchedAt: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true}, + for _, pkg := range packages { + if err := db.UpsertPackage(pkg); err != nil { + t.Fatal(err) + } } - art3 := &Artifact{ - VersionPURL: ver3.PURL, - Filename: "react.tgz", - UpstreamURL: "https://registry.npmjs.org/react/-/react-18.0.0.tgz", - StoragePath: sql.NullString{String: "npm/react/18.0.0/react.tgz", Valid: true}, - Size: sql.NullInt64{Int64: 512, Valid: true}, - HitCount: 200, - FetchedAt: sql.NullTime{Time: time.Now().Add(-2 * time.Hour), Valid: true}, + + versions := []*Version{ + {PURL: "pkg:npm/lodash@4.17.21", PackagePURL: packages[0].PURL}, + {PURL: "pkg:cargo/serde@1.0.0", PackagePURL: packages[1].PURL}, + {PURL: "pkg:npm/react@18.0.0", PackagePURL: packages[2].PURL}, } - if err := db.UpsertArtifact(art1); err != nil { - t.Fatal(err) + for _, ver := range versions { + if err := db.UpsertVersion(ver); err != nil { + t.Fatal(err) + } } - if err := db.UpsertArtifact(art2); err != nil { - t.Fatal(err) + + artifacts := []*Artifact{ + { + VersionPURL: versions[0].PURL, + Filename: "lodash.tgz", + UpstreamURL: "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", + StoragePath: sql.NullString{String: "npm/lodash/4.17.21/lodash.tgz", Valid: true}, + Size: sql.NullInt64{Int64: 1024, Valid: true}, + HitCount: 100, + FetchedAt: sql.NullTime{Time: time.Now(), Valid: true}, + }, + { + VersionPURL: versions[1].PURL, + Filename: "serde.crate", + UpstreamURL: "https://crates.io/api/v1/crates/serde/1.0.0/download", + StoragePath: sql.NullString{String: "cargo/serde/1.0.0/serde.crate", Valid: true}, + Size: sql.NullInt64{Int64: 2048, Valid: true}, + HitCount: 50, + FetchedAt: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true}, + }, + { + VersionPURL: versions[2].PURL, + Filename: "react.tgz", + UpstreamURL: "https://registry.npmjs.org/react/-/react-18.0.0.tgz", + StoragePath: sql.NullString{String: "npm/react/18.0.0/react.tgz", Valid: true}, + Size: sql.NullInt64{Int64: 512, Valid: true}, + HitCount: 200, + FetchedAt: sql.NullTime{Time: time.Now().Add(-2 * time.Hour), Valid: true}, + }, } - if err := db.UpsertArtifact(art3); err != nil { - t.Fatal(err) + + for _, art := range artifacts { + if err := db.UpsertArtifact(art); err != nil { + t.Fatal(err) + } } +} - t.Run("list all packages", func(t *testing.T) { - packages, err := db.ListCachedPackages("", "hits", 10, 0) +func TestListCachedPackages(t *testing.T) { + db := setupListCachedPackagesDB(t) + defer func() { _ = db.Close() }() + + listAll := func(ecosystem, sortBy string) []PackageListItem { + t.Helper() + packages, err := db.ListCachedPackages(ecosystem, sortBy, 10, 0) if err != nil { t.Fatal(err) } + return packages + } + + t.Run("list all packages", func(t *testing.T) { + packages := listAll("", "hits") if len(packages) != 3 { t.Errorf("expected 3 packages, got %d", len(packages)) } - // Should be sorted by hits DESC if packages[0].Name != "react" { t.Errorf("expected first package to be react, got %s", packages[0].Name) } @@ -127,35 +130,26 @@ func TestListCachedPackages(t *testing.T) { }) t.Run("filter by ecosystem", func(t *testing.T) { - packages, err := db.ListCachedPackages("npm", "hits", 10, 0) - if err != nil { - t.Fatal(err) - } + packages := listAll(testEcosystemNPM, "hits") if len(packages) != 2 { t.Errorf("expected 2 npm packages, got %d", len(packages)) } for _, pkg := range packages { - if pkg.Ecosystem != "npm" { + if pkg.Ecosystem != testEcosystemNPM { t.Errorf("expected npm ecosystem, got %s", pkg.Ecosystem) } } }) t.Run("sort by name", func(t *testing.T) { - packages, err := db.ListCachedPackages("", "name", 10, 0) - if err != nil { - t.Fatal(err) - } + packages := listAll("", "name") if packages[0].Name != "lodash" { t.Errorf("expected first package to be lodash, got %s", packages[0].Name) } }) t.Run("sort by size", func(t *testing.T) { - packages, err := db.ListCachedPackages("", "size", 10, 0) - if err != nil { - t.Fatal(err) - } + packages := listAll("", "size") if packages[0].Name != "serde" { t.Errorf("expected first package to be serde (largest), got %s", packages[0].Name) } @@ -170,7 +164,7 @@ func TestListCachedPackages(t *testing.T) { t.Errorf("expected count 3, got %d", count) } - count, err = db.CountCachedPackages("npm") + count, err = db.CountCachedPackages(testEcosystemNPM) if err != nil { t.Fatal(err) } diff --git a/internal/database/schema.go b/internal/database/schema.go index 13363e7..496a129 100644 --- a/internal/database/schema.go +++ b/internal/database/schema.go @@ -2,6 +2,8 @@ package database import "fmt" +const postgresTimestamp = "TIMESTAMP" + // Schema for proxy-specific tables. The packages and versions tables // are compatible with git-pkgs, allowing the proxy to use an existing // git-pkgs database as a starting point. @@ -303,8 +305,8 @@ func (db *DB) MigrateSchema() error { } if db.dialect == DialectPostgres { - packagesColumns["enriched_at"] = "TIMESTAMP" - packagesColumns["vulns_synced_at"] = "TIMESTAMP" + packagesColumns["enriched_at"] = postgresTimestamp + packagesColumns["vulns_synced_at"] = postgresTimestamp } for column, colType := range packagesColumns { @@ -313,12 +315,7 @@ func (db *DB) MigrateSchema() error { return fmt.Errorf("checking column %s: %w", column, err) } if !hasCol { - var alterQuery string - if db.dialect == DialectPostgres { - alterQuery = fmt.Sprintf("ALTER TABLE packages ADD COLUMN %s %s", column, colType) - } else { - alterQuery = fmt.Sprintf("ALTER TABLE packages ADD COLUMN %s %s", column, colType) - } + alterQuery := fmt.Sprintf("ALTER TABLE packages ADD COLUMN %s %s", column, colType) if _, err := db.Exec(alterQuery); err != nil { return fmt.Errorf("adding column %s to packages: %w", column, err) } @@ -335,7 +332,7 @@ func (db *DB) MigrateSchema() error { if db.dialect == DialectPostgres { versionsColumns["yanked"] = "BOOLEAN DEFAULT FALSE" - versionsColumns["enriched_at"] = "TIMESTAMP" + versionsColumns["enriched_at"] = postgresTimestamp } for column, colType := range versionsColumns { @@ -344,12 +341,7 @@ func (db *DB) MigrateSchema() error { return fmt.Errorf("checking column %s: %w", column, err) } if !hasCol { - var alterQuery string - if db.dialect == DialectPostgres { - alterQuery = fmt.Sprintf("ALTER TABLE versions ADD COLUMN %s %s", column, colType) - } else { - alterQuery = fmt.Sprintf("ALTER TABLE versions ADD COLUMN %s %s", column, colType) - } + alterQuery := fmt.Sprintf("ALTER TABLE versions ADD COLUMN %s %s", column, colType) if _, err := db.Exec(alterQuery); err != nil { return fmt.Errorf("adding column %s to versions: %w", column, err) } diff --git a/internal/handler/cargo.go b/internal/handler/cargo.go index 8aab5f0..b5a3fb0 100644 --- a/internal/handler/cargo.go +++ b/internal/handler/cargo.go @@ -11,6 +11,10 @@ import ( const ( cargoUpstream = "https://index.crates.io" cargoDownloadBase = "https://static.crates.io/crates" + + cargoIndexLen1 = 1 + cargoIndexLen2 = 2 + cargoIndexLen3 = 3 ) // CargoHandler handles cargo registry protocol requests. @@ -125,11 +129,11 @@ func (h *CargoHandler) buildIndexPath(name string) string { name = strings.ToLower(name) switch len(name) { - case 1: + case cargoIndexLen1: return fmt.Sprintf("1/%s", name) - case 2: + case cargoIndexLen2: return fmt.Sprintf("2/%s", name) - case 3: + case cargoIndexLen3: return fmt.Sprintf("3/%c/%s", name[0], name) default: return fmt.Sprintf("%s/%s/%s", name[0:2], name[2:4], name) diff --git a/internal/handler/composer.go b/internal/handler/composer.go index 2753e4a..3b76a9c 100644 --- a/internal/handler/composer.go +++ b/internal/handler/composer.go @@ -12,8 +12,9 @@ import ( ) const ( - composerUpstream = "https://packagist.org" - composerRepo = "https://repo.packagist.org" + composerUpstream = "https://packagist.org" + composerRepo = "https://repo.packagist.org" + vendorPackageParts = 2 ) // ComposerHandler handles Composer/Packagist registry protocol requests. @@ -74,8 +75,8 @@ func (h *ComposerHandler) handlePackageMetadata(w http.ResponseWriter, r *http.R // Parse path: /p2/{vendor}/{package}.json path := strings.TrimPrefix(r.URL.Path, "/p2/") path = strings.TrimSuffix(path, ".json") - parts := strings.SplitN(path, "/", 2) - if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + parts := strings.SplitN(path, "/", vendorPackageParts) + if len(parts) != vendorPackageParts || parts[0] == "" || parts[1] == "" { http.Error(w, "invalid package path", http.StatusBadRequest) return } @@ -145,58 +146,85 @@ func (h *ComposerHandler) rewriteMetadata(body []byte) ([]byte, error) { continue } - packagePURL := purl.MakePURLString("composer", packageName, "") - - filtered := versionList[:0] - for _, v := range versionList { - vmap, ok := v.(map[string]any) - if !ok { - continue - } + packages[packageName] = h.filterAndRewriteVersions(packageName, versionList) + } - version, _ := vmap["version"].(string) - - // Apply cooldown filtering - if h.proxy.Cooldown != nil && h.proxy.Cooldown.Enabled() { - if timeStr, ok := vmap["time"].(string); ok { - if publishedAt, err := time.Parse(time.RFC3339, timeStr); err == nil { - if !h.proxy.Cooldown.IsAllowed("composer", packagePURL, publishedAt) { - h.proxy.Logger.Info("cooldown: filtering composer version", - "package", packageName, "version", version) - continue - } - } - } - } + return json.Marshal(metadata) +} - dist, ok := vmap["dist"].(map[string]any) - if !ok { - filtered = append(filtered, v) - continue - } +// filterAndRewriteVersions applies cooldown filtering and rewrites dist URLs +// for a single package's version list. +func (h *ComposerHandler) filterAndRewriteVersions(packageName string, versionList []any) []any { + packagePURL := purl.MakePURLString("composer", packageName, "") - // Rewrite the dist URL - if url, ok := dist["url"].(string); ok && url != "" { - filename := "package.zip" - if idx := strings.LastIndex(url, "/"); idx >= 0 { - filename = url[idx+1:] - } + filtered := versionList[:0] + for _, v := range versionList { + vmap, ok := v.(map[string]any) + if !ok { + continue + } - parts := strings.SplitN(packageName, "/", 2) - if len(parts) == 2 { - newURL := fmt.Sprintf("%s/composer/files/%s/%s/%s/%s", - h.proxyURL, parts[0], parts[1], version, filename) - dist["url"] = newURL - } - } + version, _ := vmap["version"].(string) - filtered = append(filtered, v) + if h.shouldFilterVersion(packagePURL, packageName, version, vmap) { + continue } - packages[packageName] = filtered + h.rewriteDistURL(vmap, packageName, version) + filtered = append(filtered, v) } - return json.Marshal(metadata) + return filtered +} + +// shouldFilterVersion returns true if the version should be excluded due to cooldown. +func (h *ComposerHandler) shouldFilterVersion(packagePURL, packageName, version string, vmap map[string]any) bool { + if h.proxy.Cooldown == nil || !h.proxy.Cooldown.Enabled() { + return false + } + + timeStr, ok := vmap["time"].(string) + if !ok { + return false + } + + publishedAt, err := time.Parse(time.RFC3339, timeStr) + if err != nil { + return false + } + + if !h.proxy.Cooldown.IsAllowed("composer", packagePURL, publishedAt) { + h.proxy.Logger.Info("cooldown: filtering composer version", + "package", packageName, "version", version) + return true + } + + return false +} + +// rewriteDistURL rewrites the dist URL in a version entry to point at this proxy. +func (h *ComposerHandler) rewriteDistURL(vmap map[string]any, packageName, version string) { + dist, ok := vmap["dist"].(map[string]any) + if !ok { + return + } + + url, ok := dist["url"].(string) + if !ok || url == "" { + return + } + + filename := "package.zip" + if idx := strings.LastIndex(url, "/"); idx >= 0 { + filename = url[idx+1:] + } + + parts := strings.SplitN(packageName, "/", vendorPackageParts) + if len(parts) == vendorPackageParts { + newURL := fmt.Sprintf("%s/composer/files/%s/%s/%s/%s", + h.proxyURL, parts[0], parts[1], version, filename) + dist["url"] = newURL + } } // handleDownload serves a package file, fetching and caching from upstream if needed. diff --git a/internal/handler/conan_test.go b/internal/handler/conan_test.go index a0bf1da..a7bd362 100644 --- a/internal/handler/conan_test.go +++ b/internal/handler/conan_test.go @@ -9,6 +9,8 @@ import ( "testing" ) +const testProxyURL = "http://localhost:8080" + func conanTestProxy() *Proxy { return &Proxy{ Logger: slog.Default(), @@ -44,7 +46,7 @@ func TestConanShouldCacheFile(t *testing.T) { func TestConanPingV1(t *testing.T) { h := &ConanHandler{ proxy: conanTestProxy(), - proxyURL: "http://localhost:8080", + proxyURL: testProxyURL, } req := httptest.NewRequest(http.MethodGet, "/v1/ping", nil) @@ -65,7 +67,7 @@ func TestConanPingV1(t *testing.T) { func TestConanPingV2(t *testing.T) { h := &ConanHandler{ proxy: conanTestProxy(), - proxyURL: "http://localhost:8080", + proxyURL: testProxyURL, } req := httptest.NewRequest(http.MethodGet, "/v2/ping", nil) @@ -372,17 +374,17 @@ func TestNewConanHandler(t *testing.T) { if h.upstreamURL != conanUpstream { t.Errorf("upstreamURL = %q, want %q", h.upstreamURL, conanUpstream) } - if h.proxyURL != "http://localhost:8080" { - t.Errorf("proxyURL = %q, want %q (trailing slash should be trimmed)", h.proxyURL, "http://localhost:8080") + if h.proxyURL != testProxyURL { + t.Errorf("proxyURL = %q, want %q (trailing slash should be trimmed)", h.proxyURL, testProxyURL) } } func TestNewConanHandlerNoTrailingSlash(t *testing.T) { proxy := conanTestProxy() - h := NewConanHandler(proxy, "http://localhost:8080") + h := NewConanHandler(proxy, testProxyURL) - if h.proxyURL != "http://localhost:8080" { - t.Errorf("proxyURL = %q, want %q", h.proxyURL, "http://localhost:8080") + if h.proxyURL != testProxyURL { + t.Errorf("proxyURL = %q, want %q", h.proxyURL, testProxyURL) } } diff --git a/internal/handler/conda.go b/internal/handler/conda.go index 303ec0d..f929194 100644 --- a/internal/handler/conda.go +++ b/internal/handler/conda.go @@ -1,13 +1,13 @@ package handler import ( - "io" "net/http" "strings" ) const ( condaUpstream = "https://conda.anaconda.org" + minCondaParts = 3 // name-version-build requires at least 3 hyphen-separated parts ) // CondaHandler handles Conda/Anaconda registry protocol requests. @@ -98,7 +98,7 @@ func (h *CondaHandler) parseFilename(filename string) (name, version string) { // Split by hyphens, the format is name-version-build // The name can contain hyphens, so we need to find version-build at the end parts := strings.Split(base, "-") - if len(parts) < 3 { + if len(parts) < minCondaParts { return "", "" } @@ -121,35 +121,5 @@ func (h *CondaHandler) parseFilename(filename string) (name, version string) { // proxyUpstream forwards a request to Anaconda without caching. func (h *CondaHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { - upstreamURL := h.upstreamURL + r.URL.Path - - h.proxy.Logger.Debug("proxying to upstream", "url", upstreamURL) - - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) - if err != nil { - http.Error(w, "failed to create request", http.StatusInternalServerError) - return - } - - // Copy accept-encoding for compression - if ae := r.Header.Get("Accept-Encoding"); ae != "" { - req.Header.Set("Accept-Encoding", ae) - } - - resp, err := h.proxy.HTTPClient.Do(req) - if err != nil { - h.proxy.Logger.Error("upstream request failed", "error", err) - http.Error(w, "upstream request failed", http.StatusBadGateway) - return - } - defer func() { _ = resp.Body.Close() }() - - for k, vv := range resp.Header { - for _, v := range vv { - w.Header().Add(k, v) - } - } - - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) + h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, []string{"Accept-Encoding"}) } diff --git a/internal/handler/container.go b/internal/handler/container.go index 0669b12..fc5f98c 100644 --- a/internal/handler/container.go +++ b/internal/handler/container.go @@ -10,8 +10,11 @@ import ( ) const ( - dockerHubRegistry = "https://registry-1.docker.io" - dockerHubAuth = "https://auth.docker.io" + dockerHubRegistry = "https://registry-1.docker.io" + dockerHubAuth = "https://auth.docker.io" + blobMatchCount = 3 // full match + name + digest + manifestMatchCount = 3 // full match + name + reference + tagsListMatchCount = 2 // full match + name ) // ContainerHandler handles OCI/Docker container registry protocol requests. @@ -347,7 +350,7 @@ var blobPathPattern = regexp.MustCompile(`^(.+)/blobs/(sha256:[a-f0-9]+)$`) // parseBlobPath extracts repository name and digest from a blob path. func (h *ContainerHandler) parseBlobPath(path string) (name, digest string) { matches := blobPathPattern.FindStringSubmatch(path) - if len(matches) != 3 { + if len(matches) != blobMatchCount { return "", "" } return matches[1], matches[2] @@ -359,7 +362,7 @@ var manifestPathPattern = regexp.MustCompile(`^(.+)/manifests/(.+)$`) // parseManifestPath extracts repository name and reference from a manifest path. func (h *ContainerHandler) parseManifestPath(path string) (name, reference string) { matches := manifestPathPattern.FindStringSubmatch(path) - if len(matches) != 3 { + if len(matches) != manifestMatchCount { return "", "" } return matches[1], matches[2] @@ -371,7 +374,7 @@ var tagsListPathPattern = regexp.MustCompile(`^(.+)/tags/list$`) // parseTagsListPath extracts repository name from a tags list path. func (h *ContainerHandler) parseTagsListPath(path string) string { matches := tagsListPathPattern.FindStringSubmatch(path) - if len(matches) != 2 { + if len(matches) != tagsListMatchCount { return "" } return matches[1] diff --git a/internal/handler/cran.go b/internal/handler/cran.go index 4702473..b8c3c3a 100644 --- a/internal/handler/cran.go +++ b/internal/handler/cran.go @@ -1,7 +1,6 @@ package handler import ( - "io" "net/http" "strings" ) @@ -153,34 +152,5 @@ func (h *CRANHandler) isBinaryPackage(filename string) bool { // proxyUpstream forwards a request to CRAN without caching. func (h *CRANHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { - upstreamURL := h.upstreamURL + r.URL.Path - - h.proxy.Logger.Debug("proxying to upstream", "url", upstreamURL) - - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) - if err != nil { - http.Error(w, "failed to create request", http.StatusInternalServerError) - return - } - - if ae := r.Header.Get("Accept-Encoding"); ae != "" { - req.Header.Set("Accept-Encoding", ae) - } - - resp, err := h.proxy.HTTPClient.Do(req) - if err != nil { - h.proxy.Logger.Error("upstream request failed", "error", err) - http.Error(w, "upstream request failed", http.StatusBadGateway) - return - } - defer func() { _ = resp.Body.Close() }() - - for k, vv := range resp.Header { - for _, v := range vv { - w.Header().Add(k, v) - } - } - - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) + h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, []string{"Accept-Encoding"}) } diff --git a/internal/handler/debian.go b/internal/handler/debian.go index bada1af..11a1979 100644 --- a/internal/handler/debian.go +++ b/internal/handler/debian.go @@ -2,7 +2,6 @@ package handler import ( "fmt" - "io" "net/http" "regexp" "strings" @@ -10,6 +9,7 @@ import ( const ( debianUpstream = "http://deb.debian.org/debian" + debMatchCount = 4 // full match + name + version + arch ) // DebianHandler handles APT/Debian repository protocol requests. @@ -93,67 +93,12 @@ func (h *DebianHandler) handlePackageDownload(w http.ResponseWriter, r *http.Req // handleMetadata proxies repository metadata files. // These change frequently so we don't cache them. func (h *DebianHandler) handleMetadata(w http.ResponseWriter, r *http.Request, path string) { - upstreamURL := fmt.Sprintf("%s/%s", h.upstreamURL, path) - - h.proxy.Logger.Debug("debian metadata request", "path", path) - - req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, nil) - if err != nil { - http.Error(w, "failed to create request", http.StatusInternalServerError) - return - } - - // Forward relevant headers - for _, header := range []string{"Accept", "Accept-Encoding", "If-Modified-Since", "If-None-Match"} { - if v := r.Header.Get(header); v != "" { - req.Header.Set(header, v) - } - } - - resp, err := h.proxy.HTTPClient.Do(req) - if err != nil { - h.proxy.Logger.Error("failed to fetch upstream metadata", "error", err) - http.Error(w, "failed to fetch from upstream", http.StatusBadGateway) - return - } - defer func() { _ = resp.Body.Close() }() - - // Copy response headers - for _, header := range []string{"Content-Type", "Content-Length", "Last-Modified", "ETag"} { - if v := resp.Header.Get(header); v != "" { - w.Header().Set(header, v) - } - } - - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) + h.proxy.ProxyMetadata(w, r, fmt.Sprintf("%s/%s", h.upstreamURL, path), "debian") } // proxyFile proxies any file directly without caching. func (h *DebianHandler) proxyFile(w http.ResponseWriter, r *http.Request, path string) { - upstreamURL := fmt.Sprintf("%s/%s", h.upstreamURL, path) - - req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, nil) - if err != nil { - http.Error(w, "failed to create request", http.StatusInternalServerError) - return - } - - resp, err := h.proxy.HTTPClient.Do(req) - if err != nil { - http.Error(w, "failed to fetch from upstream", http.StatusBadGateway) - return - } - defer func() { _ = resp.Body.Close() }() - - for key, values := range resp.Header { - for _, v := range values { - w.Header().Add(key, v) - } - } - - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) + h.proxy.ProxyFile(w, r, fmt.Sprintf("%s/%s", h.upstreamURL, path)) } // debPackagePattern matches .deb filenames to extract name, version, and arch. @@ -172,7 +117,7 @@ func (h *DebianHandler) parsePoolPath(path string) (name, version, arch string) // Parse the filename matches := debPackagePattern.FindStringSubmatch(filename) - if len(matches) != 4 { + if len(matches) != debMatchCount { return "", "", "" } diff --git a/internal/handler/debian_test.go b/internal/handler/debian_test.go index 77d6843..dfdd326 100644 --- a/internal/handler/debian_test.go +++ b/internal/handler/debian_test.go @@ -1,98 +1,23 @@ package handler import ( - "net/http" - "net/http/httptest" "testing" ) func TestDebianHandler_parsePoolPath(t *testing.T) { h := &DebianHandler{} - tests := []struct { - path string - wantName string - wantVersion string - wantArch string - }{ - { - path: "pool/main/n/nginx/nginx_1.18.0-6_amd64.deb", - wantName: "nginx", - wantVersion: "1.18.0-6", - wantArch: "amd64", - }, - { - path: "pool/main/libn/libncurses/libncurses6_6.2-1_amd64.deb", - wantName: "libncurses6", - wantVersion: "6.2-1", - wantArch: "amd64", - }, - { - path: "pool/contrib/v/virtualbox/virtualbox_6.1.38-1_amd64.deb", - wantName: "virtualbox", - wantVersion: "6.1.38-1", - wantArch: "amd64", - }, - { - path: "pool/main/g/git/git_2.39.2-1_arm64.deb", - wantName: "git", - wantVersion: "2.39.2-1", - wantArch: "arm64", - }, - { - path: "invalid/path", - wantName: "", - wantVersion: "", - wantArch: "", - }, - { - path: "pool/main/n/nginx/nginx.deb", - wantName: "", - wantVersion: "", - wantArch: "", - }, - } - - for _, tt := range tests { - t.Run(tt.path, func(t *testing.T) { - name, version, arch := h.parsePoolPath(tt.path) - if name != tt.wantName { - t.Errorf("parsePoolPath() name = %q, want %q", name, tt.wantName) - } - if version != tt.wantVersion { - t.Errorf("parsePoolPath() version = %q, want %q", version, tt.wantVersion) - } - if arch != tt.wantArch { - t.Errorf("parsePoolPath() arch = %q, want %q", arch, tt.wantArch) - } - }) - } + assertPathParser(t, "parsePoolPath", h.parsePoolPath, []pathParseCase{ + {"pool/main/n/nginx/nginx_1.18.0-6_amd64.deb", "nginx", "1.18.0-6", "amd64"}, + {"pool/main/libn/libncurses/libncurses6_6.2-1_amd64.deb", "libncurses6", "6.2-1", "amd64"}, + {"pool/contrib/v/virtualbox/virtualbox_6.1.38-1_amd64.deb", "virtualbox", "6.1.38-1", "amd64"}, + {"pool/main/g/git/git_2.39.2-1_arm64.deb", "git", "2.39.2-1", "arm64"}, + {"invalid/path", "", "", ""}, + {"pool/main/n/nginx/nginx.deb", "", "", ""}, + }) } func TestDebianHandler_Routes(t *testing.T) { h := NewDebianHandler(nil, "http://localhost:8080") - - // Test that handler doesn't panic on initialization - handler := h.Routes() - if handler == nil { - t.Fatal("Routes() returned nil") - } - - // Test method not allowed - req := httptest.NewRequest(http.MethodPost, "/dists/stable/Release", nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, req) - - if w.Code != http.StatusMethodNotAllowed { - t.Errorf("POST request: got status %d, want %d", w.Code, http.StatusMethodNotAllowed) - } - - // Test path traversal rejection - req = httptest.NewRequest(http.MethodGet, "/pool/../../../etc/passwd", nil) - w = httptest.NewRecorder() - handler.ServeHTTP(w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("path traversal: got status %d, want %d", w.Code, http.StatusBadRequest) - } + assertRoutesBasics(t, h.Routes(), "/dists/stable/Release", "/pool/../../../etc/passwd") } diff --git a/internal/handler/download_test.go b/internal/handler/download_test.go index a51f908..a6e0cb3 100644 --- a/internal/handler/download_test.go +++ b/internal/handler/download_test.go @@ -59,6 +59,38 @@ func seedPackageWithPURL(t *testing.T, db *database.DB, store *mockStorage, ecos } } +// assertUpstreamProxied verifies that a handler proxies a request to the upstream +// server and returns the expected response body. The makeHandler function receives +// a configured Proxy and the upstream URL, and returns the handler to test. +func assertUpstreamProxied(t *testing.T, wantBody, path string, makeHandler func(*Proxy, string) http.Handler) { + t.Helper() + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprint(w, wantBody) + })) + defer upstream.Close() + + proxy, _, _, _ := setupTestProxy(t) + proxy.HTTPClient = upstream.Client() + + srv := httptest.NewServer(makeHandler(proxy, upstream.URL)) + defer srv.Close() + + resp, err := http.Get(srv.URL + path) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + body, _ := io.ReadAll(resp.Body) + if string(body) != wantBody { + t.Errorf("body = %q, want %q", body, wantBody) + } +} + func TestGemHandler_DownloadCacheHit(t *testing.T) { proxy, db, store, _ := setupTestProxy(t) seedPackage(t, db, store, "gem", "rails", "7.1.0", "rails-7.1.0.gem", "gem binary data") @@ -71,7 +103,7 @@ func TestGemHandler_DownloadCacheHit(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) @@ -94,7 +126,7 @@ func TestGemHandler_DownloadCacheHitMultiHyphen(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) @@ -125,7 +157,7 @@ func TestGemHandler_InvalidFilename(t *testing.T) { if err != nil { t.Fatalf("request to %s failed: %v", tt.path, err) } - resp.Body.Close() + _ = resp.Body.Close() if resp.StatusCode != tt.code { t.Errorf("GET %s: status = %d, want %d", tt.path, resp.StatusCode, tt.code) @@ -137,7 +169,7 @@ func TestGemHandler_UpstreamProxy(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Test", "upstream") w.WriteHeader(http.StatusOK) - fmt.Fprint(w, "upstream specs data") + _, _ = fmt.Fprint(w, "upstream specs data") })) defer upstream.Close() @@ -156,7 +188,7 @@ func TestGemHandler_UpstreamProxy(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) @@ -185,7 +217,7 @@ func TestGemHandler_CacheMiss(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if !fetcher.fetchCalled { t.Error("expected fetcher to be called on cache miss") @@ -204,7 +236,7 @@ func TestGoHandler_DownloadCacheHit(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) @@ -225,7 +257,7 @@ func TestGoHandler_MethodNotAllowed(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - resp.Body.Close() + _ = resp.Body.Close() if resp.StatusCode != http.StatusMethodNotAllowed { t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed) @@ -242,7 +274,7 @@ func TestGoHandler_NotFound(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - resp.Body.Close() + _ = resp.Body.Close() if resp.StatusCode != http.StatusNotFound { t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusNotFound) @@ -259,7 +291,7 @@ func TestGoHandler_UnknownAtVSuffix(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - resp.Body.Close() + _ = resp.Body.Close() if resp.StatusCode != http.StatusNotFound { t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusNotFound) @@ -268,7 +300,7 @@ func TestGoHandler_UnknownAtVSuffix(t *testing.T) { func TestGoHandler_UpstreamProxy(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "v0.14.0\nv0.13.0\n") + _, _ = fmt.Fprint(w, "v0.14.0\nv0.13.0\n") })) defer upstream.Close() @@ -296,7 +328,7 @@ func TestGoHandler_UpstreamProxy(t *testing.T) { if err != nil { t.Fatalf("GET %s failed: %v", path, err) } - resp.Body.Close() + _ = resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("GET %s: status = %d, want %d", path, resp.StatusCode, http.StatusOK) @@ -319,7 +351,7 @@ func TestGoHandler_CacheMiss(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if !fetcher.fetchCalled { t.Error("expected fetcher to be called on cache miss") @@ -338,7 +370,7 @@ func TestHexHandler_DownloadCacheHit(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) @@ -368,7 +400,7 @@ func TestHexHandler_InvalidFilename(t *testing.T) { if err != nil { t.Fatalf("request to %s failed: %v", tt.path, err) } - resp.Body.Close() + _ = resp.Body.Close() if resp.StatusCode != tt.code { t.Errorf("GET %s: status = %d, want %d", tt.path, resp.StatusCode, tt.code) @@ -377,35 +409,12 @@ func TestHexHandler_InvalidFilename(t *testing.T) { } func TestHexHandler_UpstreamProxy(t *testing.T) { - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "hex registry data") - })) - defer upstream.Close() - - proxy, _, _, _ := setupTestProxy(t) - h := &HexHandler{ - proxy: proxy, - upstreamURL: upstream.URL, - proxyURL: "http://localhost", - } - proxy.HTTPClient = upstream.Client() - - srv := httptest.NewServer(h.Routes()) - defer srv.Close() - - resp, err := http.Get(srv.URL + "/packages/phoenix") - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) - } - body, _ := io.ReadAll(resp.Body) - if string(body) != "hex registry data" { - t.Errorf("body = %q, want %q", body, "hex registry data") - } + assertUpstreamProxied(t, "hex registry data", "/packages/phoenix", + func(proxy *Proxy, upstreamURL string) http.Handler { + h := &HexHandler{proxy: proxy, upstreamURL: upstreamURL, proxyURL: "http://localhost"} + return h.Routes() + }, + ) } func TestHexHandler_CacheMiss(t *testing.T) { @@ -423,7 +432,7 @@ func TestHexHandler_CacheMiss(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if !fetcher.fetchCalled { t.Error("expected fetcher to be called on cache miss") @@ -442,7 +451,7 @@ func TestCondaHandler_DownloadCacheHit(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) @@ -465,7 +474,7 @@ func TestCondaHandler_DownloadTarBz2CacheHit(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) @@ -477,35 +486,12 @@ func TestCondaHandler_DownloadTarBz2CacheHit(t *testing.T) { } func TestCondaHandler_NonPackageFileProxied(t *testing.T) { - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "repodata json") - })) - defer upstream.Close() - - proxy, _, _, _ := setupTestProxy(t) - h := &CondaHandler{ - proxy: proxy, - upstreamURL: upstream.URL, - proxyURL: "http://localhost", - } - proxy.HTTPClient = upstream.Client() - - srv := httptest.NewServer(h.Routes()) - defer srv.Close() - - resp, err := http.Get(srv.URL + "/main/linux-64/repodata.json") - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) - } - body, _ := io.ReadAll(resp.Body) - if string(body) != "repodata json" { - t.Errorf("body = %q, want %q", body, "repodata json") - } + assertUpstreamProxied(t, "repodata json", "/main/linux-64/repodata.json", + func(proxy *Proxy, upstreamURL string) http.Handler { + h := &CondaHandler{proxy: proxy, upstreamURL: upstreamURL, proxyURL: "http://localhost"} + return h.Routes() + }, + ) } func TestCondaHandler_CacheMiss(t *testing.T) { @@ -531,7 +517,7 @@ func TestCondaHandler_CacheMiss(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if !fetcher.fetchCalled { t.Error("expected fetcher to be called on cache miss") @@ -550,7 +536,7 @@ func TestCRANHandler_SourceDownloadCacheHit(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) @@ -573,7 +559,7 @@ func TestCRANHandler_BinaryDownloadCacheHit(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) @@ -585,40 +571,17 @@ func TestCRANHandler_BinaryDownloadCacheHit(t *testing.T) { } func TestCRANHandler_NonPackageFileProxied(t *testing.T) { - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "PACKAGES index") - })) - defer upstream.Close() - - proxy, _, _, _ := setupTestProxy(t) - h := &CRANHandler{ - proxy: proxy, - upstreamURL: upstream.URL, - proxyURL: "http://localhost", - } - proxy.HTTPClient = upstream.Client() - - srv := httptest.NewServer(h.Routes()) - defer srv.Close() - - resp, err := http.Get(srv.URL + "/src/contrib/PACKAGES") - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) - } - body, _ := io.ReadAll(resp.Body) - if string(body) != "PACKAGES index" { - t.Errorf("body = %q, want %q", body, "PACKAGES index") - } + assertUpstreamProxied(t, "PACKAGES index", "/src/contrib/PACKAGES", + func(proxy *Proxy, upstreamURL string) http.Handler { + h := &CRANHandler{proxy: proxy, upstreamURL: upstreamURL, proxyURL: "http://localhost"} + return h.Routes() + }, + ) } func TestCRANHandler_SourceNonTarGzProxied(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "some other file") + _, _ = fmt.Fprint(w, "some other file") })) defer upstream.Close() @@ -637,7 +600,7 @@ func TestCRANHandler_SourceNonTarGzProxied(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) @@ -661,7 +624,7 @@ func TestCRANHandler_CacheMiss(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if !fetcher.fetchCalled { t.Error("expected fetcher to be called on cache miss") @@ -680,7 +643,7 @@ func TestMavenHandler_DownloadCacheHit(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) @@ -693,7 +656,7 @@ func TestMavenHandler_DownloadCacheHit(t *testing.T) { func TestMavenHandler_MetadataProxied(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "") + _, _ = fmt.Fprint(w, "") })) defer upstream.Close() @@ -719,7 +682,7 @@ func TestMavenHandler_MetadataProxied(t *testing.T) { if err != nil { t.Fatalf("GET %s failed: %v", path, err) } - resp.Body.Close() + _ = resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("GET %s: status = %d, want %d", path, resp.StatusCode, http.StatusOK) @@ -737,7 +700,7 @@ func TestMavenHandler_EmptyPathNotFound(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - resp.Body.Close() + _ = resp.Body.Close() if resp.StatusCode != http.StatusNotFound { t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusNotFound) @@ -770,7 +733,7 @@ func TestMavenHandler_ArtifactExtensions(t *testing.T) { if err != nil { t.Fatalf("GET %s failed: %v", path, err) } - resp.Body.Close() + _ = resp.Body.Close() if !fetcher.fetchCalled { t.Errorf("fetcher not called for %s", ext) @@ -796,7 +759,7 @@ func TestMavenHandler_CacheMiss(t *testing.T) { if err != nil { t.Fatalf("request failed: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if !fetcher.fetchCalled { t.Error("expected fetcher to be called on cache miss") diff --git a/internal/handler/go.go b/internal/handler/go.go index d64ded3..dd4e17d 100644 --- a/internal/handler/go.go +++ b/internal/handler/go.go @@ -2,13 +2,13 @@ package handler import ( "fmt" - "io" "net/http" "strings" ) const ( - goUpstream = "https://proxy.golang.org" + goUpstream = "https://proxy.golang.org" + asciiCaseOffset = 32 // difference between lowercase and uppercase ASCII letters ) // GoHandler handles Go module proxy protocol requests. @@ -108,33 +108,7 @@ func (h *GoHandler) handleDownload(w http.ResponseWriter, r *http.Request, modul // proxyUpstream forwards a request to proxy.golang.org without caching. func (h *GoHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { - upstreamURL := h.upstreamURL + r.URL.Path - - h.proxy.Logger.Debug("proxying to upstream", "url", upstreamURL) - - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) - if err != nil { - http.Error(w, "failed to create request", http.StatusInternalServerError) - return - } - - resp, err := h.proxy.HTTPClient.Do(req) - if err != nil { - h.proxy.Logger.Error("upstream request failed", "error", err) - http.Error(w, "upstream request failed", http.StatusBadGateway) - return - } - defer func() { _ = resp.Body.Close() }() - - // Copy response headers - for k, vv := range resp.Header { - for _, v := range vv { - w.Header().Add(k, v) - } - } - - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) + h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, nil) } // decodeGoModule decodes an encoded module path. @@ -143,7 +117,7 @@ func decodeGoModule(encoded string) string { var b strings.Builder for i := 0; i < len(encoded); i++ { if encoded[i] == '!' && i+1 < len(encoded) { - b.WriteByte(encoded[i+1] - 32) // lowercase to uppercase + b.WriteByte(encoded[i+1] - asciiCaseOffset) // lowercase to uppercase i++ } else { b.WriteByte(encoded[i]) diff --git a/internal/handler/handler.go b/internal/handler/handler.go index 9205008..91d8960 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -30,6 +30,8 @@ func containsPathTraversal(path string) bool { return false } +const defaultHTTPTimeout = 30 * time.Second + // maxMetadataSize is the maximum size of upstream metadata responses (50 MB). // Package metadata (e.g. npm with many versions) can be large, but unbounded // reads risk OOM if an upstream misbehaves. @@ -64,7 +66,7 @@ func NewProxy(db *database.DB, store storage.Storage, fetcher fetch.FetcherInter Resolver: resolver, Logger: logger, HTTPClient: &http.Client{ - Timeout: 30 * time.Second, + Timeout: defaultHTTPTimeout, }, } } @@ -187,7 +189,7 @@ func (p *Proxy) fetchAndCache(ctx context.Context, ecosystem, name, version, fil } // Update database - if err := p.updateCacheDB(ctx, ecosystem, name, version, filename, pkgPURL, versionPURL, info.URL, storagePath, hash, size, artifact.ContentType); err != nil { + if err := p.updateCacheDB(ecosystem, name, filename, pkgPURL, versionPURL, info.URL, storagePath, hash, size, artifact.ContentType); err != nil { p.Logger.Warn("failed to update cache database", "error", err) // Continue anyway - we have the file } @@ -211,7 +213,7 @@ func (p *Proxy) fetchAndCache(ctx context.Context, ecosystem, name, version, fil }, nil } -func (p *Proxy) updateCacheDB(ctx context.Context, ecosystem, name, version, filename, pkgPURL, versionPURL, upstreamURL, storagePath, hash string, size int64, contentType string) error { +func (p *Proxy) updateCacheDB(ecosystem, name, filename, pkgPURL, versionPURL, upstreamURL, storagePath, hash string, size int64, contentType string) error { now := time.Now() // Upsert package @@ -272,6 +274,102 @@ func ServeArtifact(w http.ResponseWriter, result *CacheResult) { _, _ = io.Copy(w, result.Reader) } +// ProxyUpstream forwards a request to an upstream URL without caching. +// It copies the request, forwards specified headers, and streams the response back. +// If forwardHeaders is nil, all response headers are copied. +func (p *Proxy) ProxyUpstream(w http.ResponseWriter, r *http.Request, upstreamURL string, forwardHeaders []string) { + p.Logger.Debug("proxying to upstream", "url", upstreamURL) + + req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, nil) + if err != nil { + http.Error(w, "failed to create request", http.StatusInternalServerError) + return + } + + // Copy request headers that affect content negotiation / caching + for _, header := range forwardHeaders { + if v := r.Header.Get(header); v != "" { + req.Header.Set(header, v) + } + } + + resp, err := p.HTTPClient.Do(req) + if err != nil { + p.Logger.Error("upstream request failed", "error", err) + http.Error(w, "upstream request failed", http.StatusBadGateway) + return + } + defer func() { _ = resp.Body.Close() }() + + for k, vv := range resp.Header { + for _, v := range vv { + w.Header().Add(k, v) + } + } + + w.WriteHeader(resp.StatusCode) + _, _ = io.Copy(w, resp.Body) +} + +// ProxyMetadata forwards a metadata request to upstream, copying only specific response headers. +func (p *Proxy) ProxyMetadata(w http.ResponseWriter, r *http.Request, upstreamURL string, logLabel string) { + p.Logger.Debug(logLabel+" metadata request", "url", upstreamURL) + + req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, nil) + if err != nil { + http.Error(w, "failed to create request", http.StatusInternalServerError) + return + } + + for _, header := range []string{"Accept", "Accept-Encoding", "If-Modified-Since", "If-None-Match"} { + if v := r.Header.Get(header); v != "" { + req.Header.Set(header, v) + } + } + + resp, err := p.HTTPClient.Do(req) + if err != nil { + p.Logger.Error("failed to fetch upstream metadata", "error", err) + http.Error(w, "failed to fetch from upstream", http.StatusBadGateway) + return + } + defer func() { _ = resp.Body.Close() }() + + for _, header := range []string{"Content-Type", "Content-Length", "Last-Modified", "ETag"} { + if v := resp.Header.Get(header); v != "" { + w.Header().Set(header, v) + } + } + + w.WriteHeader(resp.StatusCode) + _, _ = io.Copy(w, resp.Body) +} + +// ProxyFile forwards a file request to upstream, copying all response headers. +func (p *Proxy) ProxyFile(w http.ResponseWriter, r *http.Request, upstreamURL string) { + req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, nil) + if err != nil { + http.Error(w, "failed to create request", http.StatusInternalServerError) + return + } + + resp, err := p.HTTPClient.Do(req) + if err != nil { + http.Error(w, "failed to fetch from upstream", http.StatusBadGateway) + return + } + defer func() { _ = resp.Body.Close() }() + + for key, values := range resp.Header { + for _, v := range values { + w.Header().Add(key, v) + } + } + + w.WriteHeader(resp.StatusCode) + _, _ = io.Copy(w, resp.Body) +} + // JSONError writes a JSON error response. func JSONError(w http.ResponseWriter, status int, message string) { w.Header().Set("Content-Type", "application/json") @@ -310,7 +408,7 @@ func (p *Proxy) fetchAndCacheFromURL(ctx context.Context, ecosystem, name, versi return nil, fmt.Errorf("storing artifact: %w", err) } - if err := p.updateCacheDB(ctx, ecosystem, name, version, filename, pkgPURL, versionPURL, downloadURL, storagePath, hash, size, artifact.ContentType); err != nil { + if err := p.updateCacheDB(ecosystem, name, filename, pkgPURL, versionPURL, downloadURL, storagePath, hash, size, artifact.ContentType); err != nil { p.Logger.Warn("failed to update cache database", "error", err) } diff --git a/internal/handler/handler_test.go b/internal/handler/handler_test.go index 5ba8ae3..dd85a17 100644 --- a/internal/handler/handler_test.go +++ b/internal/handler/handler_test.go @@ -160,6 +160,61 @@ func seedPackage(t *testing.T, db *database.DB, store *mockStorage, ecosystem, n } } +// pathParseCase holds a single test case for path parsing functions that return +// (name, version, arch). +type pathParseCase struct { + path string + wantName string + wantVersion string + wantArch string +} + +// assertPathParser runs table-driven tests for a path parser function that returns +// three strings (name, version, arch). +func assertPathParser(t *testing.T, funcName string, parse func(string) (string, string, string), cases []pathParseCase) { + t.Helper() + for _, tt := range cases { + t.Run(tt.path, func(t *testing.T) { + name, version, arch := parse(tt.path) + if name != tt.wantName { + t.Errorf("%s() name = %q, want %q", funcName, name, tt.wantName) + } + if version != tt.wantVersion { + t.Errorf("%s() version = %q, want %q", funcName, version, tt.wantVersion) + } + if arch != tt.wantArch { + t.Errorf("%s() arch = %q, want %q", funcName, arch, tt.wantArch) + } + }) + } +} + +// assertRoutesBasics checks that a handler's Routes() returns a non-nil handler, +// rejects POST requests with 405, and rejects path traversal with 400. +func assertRoutesBasics(t *testing.T, handler http.Handler, postPath, traversalPath string) { + t.Helper() + + if handler == nil { + t.Fatal("Routes() returned nil") + } + + req := httptest.NewRequest(http.MethodPost, postPath, nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("POST request: got status %d, want %d", w.Code, http.StatusMethodNotAllowed) + } + + req = httptest.NewRequest(http.MethodGet, traversalPath, nil) + w = httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("path traversal: got status %d, want %d", w.Code, http.StatusBadRequest) + } +} + func TestGetOrFetchArtifact_CacheHit(t *testing.T) { proxy, db, store, fetcher := setupTestProxy(t) seedPackage(t, db, store, "npm", "lodash", "4.17.21", "lodash-4.17.21.tgz", "cached content") diff --git a/internal/handler/hex.go b/internal/handler/hex.go index 2e972f0..4e0f2a2 100644 --- a/internal/handler/hex.go +++ b/internal/handler/hex.go @@ -1,8 +1,6 @@ package handler import ( - "fmt" - "io" "net/http" "strings" ) @@ -89,40 +87,5 @@ func (h *HexHandler) parseTarballFilename(filename string) (name, version string // proxyUpstream forwards a request to hex.pm without caching. func (h *HexHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { - upstreamURL := h.upstreamURL + r.URL.Path - - h.proxy.Logger.Debug("proxying to upstream", "url", upstreamURL) - - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) - if err != nil { - http.Error(w, "failed to create request", http.StatusInternalServerError) - return - } - - // Copy accept header for content negotiation - if accept := r.Header.Get("Accept"); accept != "" { - req.Header.Set("Accept", accept) - } - - resp, err := h.proxy.HTTPClient.Do(req) - if err != nil { - h.proxy.Logger.Error("upstream request failed", "error", err) - http.Error(w, "upstream request failed", http.StatusBadGateway) - return - } - defer func() { _ = resp.Body.Close() }() - - // Copy response headers - for k, vv := range resp.Header { - for _, v := range vv { - w.Header().Add(k, v) - } - } - - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) -} - -func init() { - _ = fmt.Sprintf // silence import if unused + h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, []string{"Accept"}) } diff --git a/internal/handler/maven.go b/internal/handler/maven.go index 2c7214f..79da0c0 100644 --- a/internal/handler/maven.go +++ b/internal/handler/maven.go @@ -2,7 +2,6 @@ package handler import ( "fmt" - "io" "net/http" "path" "strings" @@ -10,6 +9,7 @@ import ( const ( mavenUpstream = "https://repo1.maven.org/maven2" + minMavenParts = 4 // group path segments + artifact + version + filename ) // MavenHandler handles Maven repository protocol requests. @@ -99,7 +99,7 @@ func (h *MavenHandler) handleDownload(w http.ResponseWriter, r *http.Request, ur // -> ("com.google.guava", "guava", "32.1.3-jre", "guava-32.1.3-jre.jar") func (h *MavenHandler) parsePath(urlPath string) (group, artifact, version, filename string) { parts := strings.Split(urlPath, "/") - if len(parts) < 4 { + if len(parts) < minMavenParts { return "", "", "", "" } @@ -136,30 +136,5 @@ func (h *MavenHandler) isMetadataFile(filename string) bool { // proxyUpstream forwards a request to Maven Central without caching. func (h *MavenHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { - upstreamURL := h.upstreamURL + r.URL.Path - - h.proxy.Logger.Debug("proxying to upstream", "url", upstreamURL) - - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) - if err != nil { - http.Error(w, "failed to create request", http.StatusInternalServerError) - return - } - - resp, err := h.proxy.HTTPClient.Do(req) - if err != nil { - h.proxy.Logger.Error("upstream request failed", "error", err) - http.Error(w, "upstream request failed", http.StatusBadGateway) - return - } - defer func() { _ = resp.Body.Close() }() - - for k, vv := range resp.Header { - for _, v := range vv { - w.Header().Add(k, v) - } - } - - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) + h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, nil) } diff --git a/internal/handler/npm.go b/internal/handler/npm.go index 99a2bb9..e0b0566 100644 --- a/internal/handler/npm.go +++ b/internal/handler/npm.go @@ -14,6 +14,7 @@ import ( const ( npmUpstream = "https://registry.npmjs.org" + scopedParts = 2 // scope + name in scoped packages ) // NPMHandler handles npm registry protocol requests. @@ -127,46 +128,71 @@ func (h *NPMHandler) rewriteMetadata(packageName string, body []byte) ([]byte, e return body, nil // No versions to rewrite } - // Apply cooldown filtering - if h.proxy.Cooldown != nil && h.proxy.Cooldown.Enabled() { - timeMap, _ := metadata["time"].(map[string]any) - packagePURL := purl.MakePURLString("npm", packageName, "") + h.applyCooldownFiltering(metadata, versions, packageName) + h.rewriteTarballURLs(versions, packageName) - for version := range versions { - if timeMap == nil { - continue - } - publishedStr, ok := timeMap[version].(string) - if !ok { - continue - } - publishedAt, err := time.Parse(time.RFC3339, publishedStr) - if err != nil { - continue - } - if !h.proxy.Cooldown.IsAllowed("npm", packagePURL, publishedAt) { - h.proxy.Logger.Info("cooldown: filtering npm version", - "package", packageName, "version", version, - "published", publishedStr) - delete(versions, version) - delete(timeMap, version) - } - } + return json.Marshal(metadata) +} - // Update dist-tags.latest if it was filtered - if distTags, ok := metadata["dist-tags"].(map[string]any); ok { - if latest, ok := distTags["latest"].(string); ok { - if _, exists := versions[latest]; !exists { - // Find newest remaining version from the time map - newLatest := h.findNewestVersion(versions, timeMap) - if newLatest != "" { - distTags["latest"] = newLatest - } - } - } +// applyCooldownFiltering removes versions that are too recently published, +// and updates dist-tags.latest if the current latest was filtered out. +func (h *NPMHandler) applyCooldownFiltering(metadata map[string]any, versions map[string]any, packageName string) { + if h.proxy.Cooldown == nil || !h.proxy.Cooldown.Enabled() { + return + } + + timeMap, _ := metadata["time"].(map[string]any) + if timeMap == nil { + return + } + + packagePURL := purl.MakePURLString("npm", packageName, "") + + for version := range versions { + publishedStr, ok := timeMap[version].(string) + if !ok { + continue } + publishedAt, err := time.Parse(time.RFC3339, publishedStr) + if err != nil { + continue + } + if !h.proxy.Cooldown.IsAllowed("npm", packagePURL, publishedAt) { + h.proxy.Logger.Info("cooldown: filtering npm version", + "package", packageName, "version", version, + "published", publishedStr) + delete(versions, version) + delete(timeMap, version) + } + } + + h.updateDistTagsLatest(metadata, versions, timeMap) +} + +// updateDistTagsLatest updates the dist-tags.latest field if the current latest +// version was removed by cooldown filtering. +func (h *NPMHandler) updateDistTagsLatest(metadata, versions, timeMap map[string]any) { + distTags, ok := metadata["dist-tags"].(map[string]any) + if !ok { + return + } + + latest, ok := distTags["latest"].(string) + if !ok { + return + } + + if _, exists := versions[latest]; exists { + return + } + + if newLatest := h.findNewestVersion(versions, timeMap); newLatest != "" { + distTags["latest"] = newLatest } +} +// rewriteTarballURLs rewrites all tarball URLs in version entries to point at this proxy. +func (h *NPMHandler) rewriteTarballURLs(versions map[string]any, packageName string) { for version, vdata := range versions { vmap, ok := vdata.(map[string]any) if !ok { @@ -178,25 +204,24 @@ func (h *NPMHandler) rewriteMetadata(packageName string, body []byte) ([]byte, e continue } - if tarball, ok := dist["tarball"].(string); ok { - // Extract filename from tarball URL - filename := tarball - if idx := strings.LastIndex(tarball, "/"); idx >= 0 { - filename = tarball[idx+1:] - } - - // Rewrite to our proxy URL - escapedName := url.PathEscape(packageName) - newTarball := fmt.Sprintf("%s/npm/%s/-/%s", h.proxyURL, escapedName, filename) - dist["tarball"] = newTarball + tarball, ok := dist["tarball"].(string) + if !ok { + continue + } - h.proxy.Logger.Debug("rewrote tarball URL", - "package", packageName, "version", version, - "old", tarball, "new", newTarball) + filename := tarball + if idx := strings.LastIndex(tarball, "/"); idx >= 0 { + filename = tarball[idx+1:] } - } - return json.Marshal(metadata) + escapedName := url.PathEscape(packageName) + newTarball := fmt.Sprintf("%s/npm/%s/-/%s", h.proxyURL, escapedName, filename) + dist["tarball"] = newTarball + + h.proxy.Logger.Debug("rewrote tarball URL", + "package", packageName, "version", version, + "old", tarball, "new", newTarball) + } } // findNewestVersion returns the version string with the most recent timestamp @@ -313,7 +338,7 @@ func (h *NPMHandler) extractVersionFromFilename(packageName, filename string) st // For scoped packages, the filename uses the short name shortName := packageName if strings.Contains(packageName, "/") { - parts := strings.SplitN(packageName, "/", 2) + parts := strings.SplitN(packageName, "/", scopedParts) shortName = parts[1] } diff --git a/internal/handler/npm_test.go b/internal/handler/npm_test.go index 90616d7..3c3dc7d 100644 --- a/internal/handler/npm_test.go +++ b/internal/handler/npm_test.go @@ -11,6 +11,8 @@ import ( "github.com/git-pkgs/proxy/internal/cooldown" ) +const testVersion100 = "1.0.0" + func testProxy() *Proxy { return &Proxy{ Logger: slog.Default(), @@ -172,7 +174,7 @@ func TestNPMHandlerMetadataProxy(t *testing.T) { // Check that tarball URL was rewritten versions := result["versions"].(map[string]any) - v := versions["1.0.0"].(map[string]any) + v := versions[testVersion100].(map[string]any) dist := v["dist"].(map[string]any) tarball := dist["tarball"].(string) @@ -232,7 +234,7 @@ func TestNPMRewriteMetadataCooldown(t *testing.T) { versions := result["versions"].(map[string]any) // Old version should remain - if _, ok := versions["1.0.0"]; !ok { + if _, ok := versions[testVersion100]; !ok { t.Error("version 1.0.0 should not be filtered") } @@ -243,8 +245,8 @@ func TestNPMRewriteMetadataCooldown(t *testing.T) { // dist-tags.latest should be updated to 1.0.0 distTags := result["dist-tags"].(map[string]any) - if distTags["latest"] != "1.0.0" { - t.Errorf("dist-tags.latest = %q, want %q", distTags["latest"], "1.0.0") + if distTags["latest"] != testVersion100 { + t.Errorf("dist-tags.latest = %q, want %q", distTags["latest"], testVersion100) } } @@ -286,7 +288,7 @@ func TestNPMRewriteMetadataCooldownExemptPackage(t *testing.T) { } versions := result["versions"].(map[string]any) - if _, ok := versions["1.0.0"]; !ok { + if _, ok := versions[testVersion100]; !ok { t.Error("exempt package version should not be filtered") } } diff --git a/internal/handler/pub.go b/internal/handler/pub.go index 53c3a4f..b8f6207 100644 --- a/internal/handler/pub.go +++ b/internal/handler/pub.go @@ -12,7 +12,8 @@ import ( ) const ( - pubUpstream = "https://pub.dev" + pubUpstream = "https://pub.dev" + pubPathParts = 2 // name + version in path split by /versions/ ) // PubHandler handles pub.dev registry protocol requests. @@ -49,7 +50,7 @@ func (h *PubHandler) handleDownload(w http.ResponseWriter, r *http.Request) { // Parse path: /packages/{name}/versions/{version}.tar.gz path := strings.TrimPrefix(r.URL.Path, "/packages/") parts := strings.Split(path, "/versions/") - if len(parts) != 2 { + if len(parts) != pubPathParts { http.Error(w, "invalid request", http.StatusBadRequest) return } @@ -137,15 +138,23 @@ func (h *PubHandler) rewriteMetadata(name string, body []byte) ([]byte, error) { return nil, err } - // Rewrite archive URLs in versions versions, ok := metadata["versions"].([]any) if !ok { return body, nil } packagePURL := purl.MakePURLString("pub", name, "") + filtered := h.filterAndRewriteVersions(name, packagePURL, versions) + metadata["versions"] = filtered + + h.updateLatestVersion(metadata, filtered) + + return json.Marshal(metadata) +} - // Filter and rewrite versions +// filterAndRewriteVersions applies cooldown filtering and rewrites archive URLs +// for a package's version list. +func (h *PubHandler) filterAndRewriteVersions(name, packagePURL string, versions []any) []any { filtered := versions[:0] for _, vdata := range versions { vmap, ok := vdata.(map[string]any) @@ -158,20 +167,10 @@ func (h *PubHandler) rewriteMetadata(name string, body []byte) ([]byte, error) { continue } - // Apply cooldown filtering - if h.proxy.Cooldown != nil && h.proxy.Cooldown.Enabled() { - if publishedStr, ok := vmap["published"].(string); ok { - if publishedAt, err := time.Parse(time.RFC3339, publishedStr); err == nil { - if !h.proxy.Cooldown.IsAllowed("pub", packagePURL, publishedAt) { - h.proxy.Logger.Info("cooldown: filtering pub version", - "package", name, "version", version) - continue - } - } - } + if h.shouldFilterVersion(packagePURL, name, version, vmap) { + continue } - // Rewrite archive_url newURL := fmt.Sprintf("%s/pub/packages/%s/versions/%s.tar.gz", h.proxyURL, name, version) vmap["archive_url"] = newURL filtered = append(filtered, vdata) @@ -180,28 +179,60 @@ func (h *PubHandler) rewriteMetadata(name string, body []byte) ([]byte, error) { "package", name, "version", version, "new", newURL) } - metadata["versions"] = filtered + return filtered +} + +// shouldFilterVersion returns true if the version should be excluded due to cooldown. +func (h *PubHandler) shouldFilterVersion(packagePURL, name, version string, vmap map[string]any) bool { + if h.proxy.Cooldown == nil || !h.proxy.Cooldown.Enabled() { + return false + } + + publishedStr, ok := vmap["published"].(string) + if !ok { + return false + } + + publishedAt, err := time.Parse(time.RFC3339, publishedStr) + if err != nil { + return false + } - // Update latest if it points to a filtered version - if h.proxy.Cooldown != nil && h.proxy.Cooldown.Enabled() { - if latest, ok := metadata["latest"].(map[string]any); ok { - if latestVer, ok := latest["version"].(string); ok { - found := false - for _, vdata := range filtered { - if vmap, ok := vdata.(map[string]any); ok { - if vmap["version"] == latestVer { - found = true - break - } - } - } - if !found && len(filtered) > 0 { - // Use the last entry (most recent remaining) - metadata["latest"] = filtered[len(filtered)-1] - } + if !h.proxy.Cooldown.IsAllowed("pub", packagePURL, publishedAt) { + h.proxy.Logger.Info("cooldown: filtering pub version", + "package", name, "version", version) + return true + } + + return false +} + +// updateLatestVersion updates the latest field if the current latest version +// was removed by cooldown filtering. +func (h *PubHandler) updateLatestVersion(metadata map[string]any, filtered []any) { + if h.proxy.Cooldown == nil || !h.proxy.Cooldown.Enabled() { + return + } + + latest, ok := metadata["latest"].(map[string]any) + if !ok { + return + } + + latestVer, ok := latest["version"].(string) + if !ok { + return + } + + for _, vdata := range filtered { + if vmap, ok := vdata.(map[string]any); ok { + if vmap["version"] == latestVer { + return } } } - return json.Marshal(metadata) + if len(filtered) > 0 { + metadata["latest"] = filtered[len(filtered)-1] + } } diff --git a/internal/handler/pypi.go b/internal/handler/pypi.go index 4a4232c..51b2871 100644 --- a/internal/handler/pypi.go +++ b/internal/handler/pypi.go @@ -16,7 +16,11 @@ import ( ) const ( - pypiUpstream = "https://pypi.org" + pypiUpstream = "https://pypi.org" + minWheelParts = 5 // name + version + python + abi + platform + minSubmatchParts = 2 // full match + first capture group + minPyPIPathParts = 3 // hash_prefix + hash + filename + minPythonTagLen = 2 // minimum length for a python tag (e.g., "py") ) // PyPIHandler handles PyPI registry protocol requests. @@ -172,7 +176,7 @@ func (h *PyPIHandler) rewriteSimpleHTML(body []byte, filteredVersions map[string // Extract filename from between tags innerRe := regexp.MustCompile(`>([^<]+)`) innerMatch := innerRe.FindSubmatch(match) - if len(innerMatch) < 2 { + if len(innerMatch) < minSubmatchParts { return match } filename := string(innerMatch[1]) @@ -190,7 +194,7 @@ func (h *PyPIHandler) rewriteSimpleHTML(body []byte, filteredVersions map[string return re.ReplaceAllFunc(body, func(match []byte) []byte { submatch := re.FindSubmatch(match) - if len(submatch) < 2 { + if len(submatch) < minSubmatchParts { return match } @@ -285,59 +289,86 @@ func (h *PyPIHandler) rewriteJSONMetadata(body []byte) ([]byte, error) { return nil, err } - // Determine package name for cooldown lookup packageName, _ := extractPyPIName(metadata) packagePURL := "" if packageName != "" { packagePURL = purl.MakePURLString("pypi", packageName, "") } - // Filter and rewrite URLs in releases map - if releases, ok := metadata["releases"].(map[string]any); ok { - for version, files := range releases { - if h.proxy.Cooldown != nil && h.proxy.Cooldown.Enabled() && packagePURL != "" { - if filesArr, ok := files.([]any); ok { - if publishedAt := h.newestUploadTime(filesArr); !publishedAt.IsZero() { - if !h.proxy.Cooldown.IsAllowed("pypi", packagePURL, publishedAt) { - h.proxy.Logger.Info("cooldown: filtering pypi version", - "package", packageName, "version", version) - delete(releases, version) - continue - } - } - } - } + h.filterAndRewriteReleases(metadata, packageName, packagePURL) + h.filterAndRewriteURLs(metadata, packagePURL) - if filesArr, ok := files.([]any); ok { - for _, f := range filesArr { - if fmap, ok := f.(map[string]any); ok { - h.rewriteURLEntry(fmap) - } - } - } + return json.Marshal(metadata) +} + +// filterAndRewriteReleases applies cooldown filtering and URL rewriting to the +// releases map in PyPI metadata. +func (h *PyPIHandler) filterAndRewriteReleases(metadata map[string]any, packageName, packagePURL string) { + releases, ok := metadata["releases"].(map[string]any) + if !ok { + return + } + + for version, files := range releases { + if h.shouldFilterRelease(packagePURL, files) { + h.proxy.Logger.Info("cooldown: filtering pypi version", + "package", packageName, "version", version) + delete(releases, version) + continue } + + h.rewriteFileEntries(files) } +} - // Filter and rewrite URLs in urls array (current version files) - if urls, ok := metadata["urls"].([]any); ok { - if h.proxy.Cooldown != nil && h.proxy.Cooldown.Enabled() && packagePURL != "" { - if publishedAt := h.newestUploadTime(urls); !publishedAt.IsZero() { - if !h.proxy.Cooldown.IsAllowed("pypi", packagePURL, publishedAt) { - metadata["urls"] = []any{} - } - } +// shouldFilterRelease returns true if a release should be excluded due to cooldown. +func (h *PyPIHandler) shouldFilterRelease(packagePURL string, files any) bool { + if h.proxy.Cooldown == nil || !h.proxy.Cooldown.Enabled() || packagePURL == "" { + return false + } + + filesArr, ok := files.([]any) + if !ok { + return false + } + + publishedAt := h.newestUploadTime(filesArr) + return !publishedAt.IsZero() && !h.proxy.Cooldown.IsAllowed("pypi", packagePURL, publishedAt) +} + +// rewriteFileEntries rewrites URLs in a list of file entries. +func (h *PyPIHandler) rewriteFileEntries(files any) { + filesArr, ok := files.([]any) + if !ok { + return + } + + for _, f := range filesArr { + if fmap, ok := f.(map[string]any); ok { + h.rewriteURLEntry(fmap) } + } +} - if urls, ok := metadata["urls"].([]any); ok { - for _, u := range urls { - if umap, ok := u.(map[string]any); ok { - h.rewriteURLEntry(umap) - } +// filterAndRewriteURLs applies cooldown filtering and URL rewriting to the +// urls array (current version files) in PyPI metadata. +func (h *PyPIHandler) filterAndRewriteURLs(metadata map[string]any, packagePURL string) { + urls, ok := metadata["urls"].([]any) + if !ok { + return + } + + if h.shouldFilterRelease(packagePURL, urls) { + metadata["urls"] = []any{} + } + + if urls, ok := metadata["urls"].([]any); ok { + for _, u := range urls { + if umap, ok := u.(map[string]any); ok { + h.rewriteURLEntry(umap) } } } - - return json.Marshal(metadata) } // extractPyPIName extracts the package name from PyPI JSON metadata. @@ -403,7 +434,7 @@ func (h *PyPIHandler) handleDownload(w http.ResponseWriter, r *http.Request) { // Path format: /packages/{hash_prefix}/{hash}/{filename} // e.g., /packages/ab/cd/abc123.../requests-2.31.0.tar.gz parts := strings.Split(path, "/") - if len(parts) < 3 { + if len(parts) < minPyPIPathParts { http.Error(w, "invalid path", http.StatusBadRequest) return } @@ -442,7 +473,7 @@ func (h *PyPIHandler) parseFilename(filename string) (name, version string) { if strings.HasSuffix(filename, ".whl") { base := strings.TrimSuffix(filename, ".whl") parts := strings.Split(base, "-") - if len(parts) >= 5 { + if len(parts) >= minWheelParts { // Find where version ends (version followed by python tag) for i := 1; i < len(parts)-2; i++ { // Check if this looks like a python tag (py2, py3, cp39, etc) @@ -472,7 +503,7 @@ func (h *PyPIHandler) parseFilename(filename string) (name, version string) { } func isPythonTag(s string) bool { - if len(s) < 2 { + if len(s) < minPythonTagLen { return false } // Python tags start with py, cp, pp, ip, jy diff --git a/internal/handler/rpm.go b/internal/handler/rpm.go index de4295a..92da8b6 100644 --- a/internal/handler/rpm.go +++ b/internal/handler/rpm.go @@ -2,7 +2,6 @@ package handler import ( "fmt" - "io" "net/http" "regexp" "strings" @@ -11,6 +10,7 @@ import ( const ( // Default upstream for Fedora packages defaultRPMUpstream = "https://dl.fedoraproject.org/pub/fedora/linux" + rpmMatchCount = 5 // full match + name + version + release + arch ) // RPMHandler handles RPM/Yum repository protocol requests. @@ -95,67 +95,12 @@ func (h *RPMHandler) handlePackageDownload(w http.ResponseWriter, r *http.Reques // handleMetadata proxies repository metadata files (repomd.xml, primary.xml.gz, etc.). // These change frequently so we don't cache them. func (h *RPMHandler) handleMetadata(w http.ResponseWriter, r *http.Request, path string) { - upstreamURL := fmt.Sprintf("%s/%s", h.upstreamURL, path) - - h.proxy.Logger.Debug("rpm metadata request", "path", path) - - req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, nil) - if err != nil { - http.Error(w, "failed to create request", http.StatusInternalServerError) - return - } - - // Forward relevant headers - for _, header := range []string{"Accept", "Accept-Encoding", "If-Modified-Since", "If-None-Match"} { - if v := r.Header.Get(header); v != "" { - req.Header.Set(header, v) - } - } - - resp, err := h.proxy.HTTPClient.Do(req) - if err != nil { - h.proxy.Logger.Error("failed to fetch upstream metadata", "error", err) - http.Error(w, "failed to fetch from upstream", http.StatusBadGateway) - return - } - defer func() { _ = resp.Body.Close() }() - - // Copy response headers - for _, header := range []string{"Content-Type", "Content-Length", "Last-Modified", "ETag"} { - if v := resp.Header.Get(header); v != "" { - w.Header().Set(header, v) - } - } - - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) + h.proxy.ProxyMetadata(w, r, fmt.Sprintf("%s/%s", h.upstreamURL, path), "rpm") } // proxyFile proxies any file directly without caching. func (h *RPMHandler) proxyFile(w http.ResponseWriter, r *http.Request, path string) { - upstreamURL := fmt.Sprintf("%s/%s", h.upstreamURL, path) - - req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, nil) - if err != nil { - http.Error(w, "failed to create request", http.StatusInternalServerError) - return - } - - resp, err := h.proxy.HTTPClient.Do(req) - if err != nil { - http.Error(w, "failed to fetch from upstream", http.StatusBadGateway) - return - } - defer func() { _ = resp.Body.Close() }() - - for key, values := range resp.Header { - for _, v := range values { - w.Header().Add(key, v) - } - } - - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) + h.proxy.ProxyFile(w, r, fmt.Sprintf("%s/%s", h.upstreamURL, path)) } // rpmPackagePattern matches .rpm filenames to extract name, version, release, and arch. @@ -176,7 +121,7 @@ func (h *RPMHandler) parseRPMPath(path string) (name, version, arch string) { // Parse the filename matches := rpmPackagePattern.FindStringSubmatch(filename) - if len(matches) != 5 { + if len(matches) != rpmMatchCount { return "", "", "" } diff --git a/internal/handler/rpm_test.go b/internal/handler/rpm_test.go index b63f44d..851a304 100644 --- a/internal/handler/rpm_test.go +++ b/internal/handler/rpm_test.go @@ -1,98 +1,23 @@ package handler import ( - "net/http" - "net/http/httptest" "testing" ) func TestRPMHandler_parseRPMPath(t *testing.T) { h := &RPMHandler{} - tests := []struct { - path string - wantName string - wantVersion string - wantArch string - }{ - { - path: "releases/39/Everything/x86_64/os/Packages/n/nginx-1.24.0-1.fc39.x86_64.rpm", - wantName: "nginx", - wantVersion: "1.24.0-1.fc39", - wantArch: "x86_64", - }, - { - path: "Packages/kernel-core-6.5.5-200.fc38.x86_64.rpm", - wantName: "kernel-core", - wantVersion: "6.5.5-200.fc38", - wantArch: "x86_64", - }, - { - path: "updates/39/Everything/aarch64/Packages/g/git-2.42.0-1.fc39.aarch64.rpm", - wantName: "git", - wantVersion: "2.42.0-1.fc39", - wantArch: "aarch64", - }, - { - path: "vim-enhanced-9.0.1000-1.fc38.noarch.rpm", - wantName: "vim-enhanced", - wantVersion: "9.0.1000-1.fc38", - wantArch: "noarch", - }, - { - path: "invalid.rpm", - wantName: "", - wantVersion: "", - wantArch: "", - }, - { - path: "not-an-rpm-file", - wantName: "", - wantVersion: "", - wantArch: "", - }, - } - - for _, tt := range tests { - t.Run(tt.path, func(t *testing.T) { - name, version, arch := h.parseRPMPath(tt.path) - if name != tt.wantName { - t.Errorf("parseRPMPath() name = %q, want %q", name, tt.wantName) - } - if version != tt.wantVersion { - t.Errorf("parseRPMPath() version = %q, want %q", version, tt.wantVersion) - } - if arch != tt.wantArch { - t.Errorf("parseRPMPath() arch = %q, want %q", arch, tt.wantArch) - } - }) - } + assertPathParser(t, "parseRPMPath", h.parseRPMPath, []pathParseCase{ + {"releases/39/Everything/x86_64/os/Packages/n/nginx-1.24.0-1.fc39.x86_64.rpm", "nginx", "1.24.0-1.fc39", "x86_64"}, + {"Packages/kernel-core-6.5.5-200.fc38.x86_64.rpm", "kernel-core", "6.5.5-200.fc38", "x86_64"}, + {"updates/39/Everything/aarch64/Packages/g/git-2.42.0-1.fc39.aarch64.rpm", "git", "2.42.0-1.fc39", "aarch64"}, + {"vim-enhanced-9.0.1000-1.fc38.noarch.rpm", "vim-enhanced", "9.0.1000-1.fc38", "noarch"}, + {"invalid.rpm", "", "", ""}, + {"not-an-rpm-file", "", "", ""}, + }) } func TestRPMHandler_Routes(t *testing.T) { h := NewRPMHandler(nil, "http://localhost:8080") - - // Test that handler doesn't panic on initialization - handler := h.Routes() - if handler == nil { - t.Fatal("Routes() returned nil") - } - - // Test method not allowed - req := httptest.NewRequest(http.MethodPost, "/repodata/repomd.xml", nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, req) - - if w.Code != http.StatusMethodNotAllowed { - t.Errorf("POST request: got status %d, want %d", w.Code, http.StatusMethodNotAllowed) - } - - // Test path traversal rejection - req = httptest.NewRequest(http.MethodGet, "/releases/../../../etc/passwd", nil) - w = httptest.NewRecorder() - handler.ServeHTTP(w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("path traversal: got status %d, want %d", w.Code, http.StatusBadRequest) - } + assertRoutesBasics(t, h.Routes(), "/repodata/repomd.xml", "/releases/../../../etc/passwd") } diff --git a/internal/server/api.go b/internal/server/api.go index f756cba..e9b91be 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -12,6 +12,12 @@ import ( "github.com/go-chi/chi/v5" ) +const ( + maxBodySize = 1 << 20 // 1 MB + licenseCategoryUnknown = "unknown" + defaultSortBy = "hits" +) + // APIHandler provides REST endpoints for package enrichment data. type APIHandler struct { enrichment *enrichment.Service @@ -327,7 +333,7 @@ func (h *APIHandler) HandleGetVulns(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string // @Router /api/outdated [post] func (h *APIHandler) HandleOutdated(w http.ResponseWriter, r *http.Request) { - r.Body = http.MaxBytesReader(w, r.Body, 1<<20) // 1 MB + r.Body = http.MaxBytesReader(w, r.Body, maxBodySize) var req OutdatedRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "invalid request body", http.StatusBadRequest) @@ -373,7 +379,7 @@ func (h *APIHandler) HandleOutdated(w http.ResponseWriter, r *http.Request) { // @Failure 500 {string} string // @Router /api/bulk [post] func (h *APIHandler) HandleBulkLookup(w http.ResponseWriter, r *http.Request) { - r.Body = http.MaxBytesReader(w, r.Body, 1<<20) // 1 MB + r.Body = http.MaxBytesReader(w, r.Body, maxBodySize) var req BulkRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "invalid request body", http.StatusBadRequest) @@ -573,11 +579,11 @@ func (h *APIHandler) HandlePackagesList(w http.ResponseWriter, r *http.Request) ecosystem := r.URL.Query().Get("ecosystem") sortBy := r.URL.Query().Get("sort") if sortBy == "" { - sortBy = "hits" + sortBy = defaultSortBy } validSorts := map[string]bool{ - "hits": true, + defaultSortBy: true, "name": true, "size": true, "cached_at": true, @@ -619,7 +625,7 @@ func (h *APIHandler) HandlePackagesList(w http.ResponseWriter, r *http.Request) latestVersion = pkg.LatestVersion.String } license := "" - licenseCategory := "unknown" + licenseCategory := licenseCategoryUnknown if pkg.License.Valid { license = pkg.License.String if h.enrichment != nil { diff --git a/internal/server/api_test.go b/internal/server/api_test.go index 88705a1..96cce9e 100644 --- a/internal/server/api_test.go +++ b/internal/server/api_test.go @@ -16,6 +16,8 @@ import ( "github.com/go-chi/chi/v5" ) +const testEcosystemNPM = "npm" + func TestNewAPIHandler(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) svc := enrichment.New(logger) @@ -178,7 +180,7 @@ func TestWriteJSON(t *testing.T) { func TestPackageResponseJSON(t *testing.T) { resp := &PackageResponse{ - Ecosystem: "npm", + Ecosystem: testEcosystemNPM, Name: "lodash", LatestVersion: "4.17.21", License: "MIT", @@ -199,7 +201,7 @@ func TestPackageResponseJSON(t *testing.T) { t.Fatalf("failed to unmarshal: %v", err) } - if decoded.Ecosystem != "npm" { + if decoded.Ecosystem != testEcosystemNPM { t.Errorf("expected ecosystem npm, got %s", decoded.Ecosystem) } if decoded.Name != "lodash" { @@ -310,7 +312,7 @@ func TestHandleSearch_WithNullValues(t *testing.T) { pkg := &database.Package{ PURL: "pkg:npm/api-test", - Ecosystem: "npm", + Ecosystem: testEcosystemNPM, Name: "api-test", } if err := db.UpsertPackage(pkg); err != nil { @@ -387,7 +389,7 @@ func TestHandlePackagesListAPI(t *testing.T) { for _, name := range []string{"api-list-one", "api-list-two"} { pkg := &database.Package{ PURL: "pkg:npm/" + name, - Ecosystem: "npm", + Ecosystem: testEcosystemNPM, Name: name, } if err := db.UpsertPackage(pkg); err != nil { @@ -433,7 +435,7 @@ func TestHandlePackagesListAPI(t *testing.T) { t.Fatalf("expected at least 2 results, got %d", len(resp.Results)) } - if resp.SortBy != "hits" { + if resp.SortBy != defaultSortBy { t.Errorf("expected default sort by hits, got %q", resp.SortBy) } diff --git a/internal/server/browse.go b/internal/server/browse.go index f19ea67..372df50 100644 --- a/internal/server/browse.go +++ b/internal/server/browse.go @@ -15,6 +15,8 @@ import ( "github.com/go-chi/chi/v5" ) +const contentTypePlainText = "text/plain; charset=utf-8" + // getStripPrefix returns the path prefix to strip for a given ecosystem. // npm packages wrap content in a "package/" directory. func getStripPrefix(ecosystem string) string { @@ -240,7 +242,7 @@ func detectContentType(filename string) string { switch ext { // Text formats case ".txt", ".md", ".markdown": - return "text/plain; charset=utf-8" + return contentTypePlainText case ".html", ".htm": return "text/html; charset=utf-8" case ".css": @@ -282,7 +284,7 @@ func detectContentType(filename string) string { // Config files case ".conf", ".config", ".ini": - return "text/plain; charset=utf-8" + return contentTypePlainText case ".sh", ".bash": return "text/x-shellscript; charset=utf-8" case ".dockerfile": @@ -307,7 +309,7 @@ func detectContentType(filename string) string { default: // Try to detect if it looks like text if isLikelyText(filename) { - return "text/plain; charset=utf-8" + return contentTypePlainText } return "application/octet-stream" } @@ -482,8 +484,9 @@ func (s *Server) handleComparePage(w http.ResponseWriter, r *http.Request) { versions := chi.URLParam(r, "versions") // Parse versions (format: "1.0.0...2.0.0") + const compareVersionParts = 2 parts := strings.Split(versions, "...") - if len(parts) != 2 { + if len(parts) != compareVersionParts { http.Error(w, "invalid version format, use: version1...version2", http.StatusBadRequest) return } diff --git a/internal/server/browse_test.go b/internal/server/browse_test.go index b85116d..13680a5 100644 --- a/internal/server/browse_test.go +++ b/internal/server/browse_test.go @@ -16,6 +16,8 @@ import ( "github.com/git-pkgs/proxy/internal/database" ) +const testArchiveName = "test.tar.gz" + func TestHandleBrowseList(t *testing.T) { ts := newTestServer(t) defer ts.close() @@ -26,12 +28,12 @@ func TestHandleBrowseList(t *testing.T) { if err := os.MkdirAll(artifactsDir, 0755); err != nil { t.Fatalf("failed to create artifacts dir: %v", err) } - storagePath := filepath.Join(artifactsDir, "test.tar.gz") + storagePath := filepath.Join(artifactsDir, testArchiveName) if err := os.WriteFile(storagePath, archiveData, 0644); err != nil { t.Fatalf("failed to write test archive: %v", err) } // Storage path relative to artifacts directory - relPath := "test.tar.gz" + relPath := testArchiveName // Setup test package and artifact pkg := &database.Package{ @@ -99,12 +101,12 @@ func TestHandleBrowseFile(t *testing.T) { if err := os.MkdirAll(artifactsDir, 0755); err != nil { t.Fatalf("failed to create artifacts dir: %v", err) } - storagePath := filepath.Join(artifactsDir, "test.tar.gz") + storagePath := filepath.Join(artifactsDir, testArchiveName) if err := os.WriteFile(storagePath, archiveData, 0644); err != nil { t.Fatalf("failed to write test archive: %v", err) } // Storage path relative to artifacts directory - relPath := "test.tar.gz" + relPath := testArchiveName // Setup test package and artifact pkg := &database.Package{ @@ -150,7 +152,7 @@ func TestHandleBrowseFile(t *testing.T) { // Check content type contentType := w.Header().Get("Content-Type") - if contentType != "text/plain; charset=utf-8" { + if contentType != contentTypePlainText { t.Errorf("expected text/plain content type, got %q", contentType) } @@ -169,8 +171,8 @@ func TestDetectContentType(t *testing.T) { filename string expectedCT string }{ - {"file.txt", "text/plain; charset=utf-8"}, - {"file.md", "text/plain; charset=utf-8"}, + {"file.txt", contentTypePlainText}, + {"file.md", contentTypePlainText}, {"file.json", "application/json; charset=utf-8"}, {"file.js", "application/javascript; charset=utf-8"}, {"file.go", "text/x-go; charset=utf-8"}, @@ -178,10 +180,10 @@ func TestDetectContentType(t *testing.T) { {"file.rs", "text/x-rust; charset=utf-8"}, {"file.png", "image/png"}, {"file.jpg", "image/jpeg"}, - {"README", "text/plain; charset=utf-8"}, - {"LICENSE", "text/plain; charset=utf-8"}, - {"Makefile", "text/plain; charset=utf-8"}, - {".gitignore", "text/plain; charset=utf-8"}, + {"README", contentTypePlainText}, + {"LICENSE", contentTypePlainText}, + {"Makefile", contentTypePlainText}, + {".gitignore", contentTypePlainText}, {"file.bin", "application/octet-stream"}, } @@ -313,11 +315,11 @@ func TestHandleBrowseSourcePage(t *testing.T) { if err := os.MkdirAll(artifactsDir, 0755); err != nil { t.Fatalf("failed to create artifacts dir: %v", err) } - storagePath := filepath.Join(artifactsDir, "test.tar.gz") + storagePath := filepath.Join(artifactsDir, testArchiveName) if err := os.WriteFile(storagePath, archiveData, 0644); err != nil { t.Fatalf("failed to write test archive: %v", err) } - relPath := "test.tar.gz" + relPath := testArchiveName // Setup test package and artifact pkg := &database.Package{ diff --git a/internal/server/packages_list_test.go b/internal/server/packages_list_test.go index 1b915ad..ac57d74 100644 --- a/internal/server/packages_list_test.go +++ b/internal/server/packages_list_test.go @@ -28,7 +28,7 @@ func TestHandlePackagesList(t *testing.T) { // Create test data pkg1 := &database.Package{ PURL: "pkg:npm/lodash", - Ecosystem: "npm", + Ecosystem: testEcosystemNPM, Name: "lodash", LatestVersion: sql.NullString{String: "4.17.21", Valid: true}, License: sql.NullString{String: "MIT", Valid: true}, @@ -103,7 +103,7 @@ func TestHandlePackagesList(t *testing.T) { if len(resp.Results) != 2 { t.Errorf("expected 2 results, got %d", len(resp.Results)) } - if resp.SortBy != "hits" { + if resp.SortBy != defaultSortBy { t.Errorf("expected default sort to be hits, got %s", resp.SortBy) } }) @@ -123,7 +123,7 @@ func TestHandlePackagesList(t *testing.T) { t.Fatal(err) } - if resp.Ecosystem != "npm" { + if resp.Ecosystem != testEcosystemNPM { t.Errorf("expected ecosystem npm, got %s", resp.Ecosystem) } if resp.Count != 1 { diff --git a/internal/server/server.go b/internal/server/server.go index 19eb468..8e6b588 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -60,6 +60,14 @@ import ( "github.com/go-chi/chi/v5/middleware" ) +const ( + serverReadTimeout = 30 * time.Second + serverWriteTimeout = 5 * time.Minute + serverIdleTimeout = 60 * time.Second + dashboardTopN = 10 + hoursPerDay = 24 +) + // Server is the main proxy server. type Server struct { cfg *config.Config @@ -96,7 +104,7 @@ func New(cfg *config.Config, logger *slog.Logger) (*Server, error) { storageURL := cfg.Storage.URL if storageURL == "" { // Fall back to file:// with Path - storageURL = "file://" + cfg.Storage.Path + storageURL = "file://" + cfg.Storage.Path //nolint:staticcheck // backwards compat } store, err := storage.OpenBucket(context.Background(), storageURL) if err != nil { @@ -228,15 +236,15 @@ func (s *Server) Start() error { s.http = &http.Server{ Addr: s.cfg.Listen, Handler: r, - ReadTimeout: 30 * time.Second, - WriteTimeout: 5 * time.Minute, // Large artifacts need time - IdleTimeout: 60 * time.Second, + ReadTimeout: serverReadTimeout, + WriteTimeout: serverWriteTimeout, // Large artifacts need time + IdleTimeout: serverIdleTimeout, } s.logger.Info("starting server", "listen", s.cfg.Listen, "base_url", s.cfg.BaseURL, - "storage", s.cfg.Storage.Path, + "storage", s.cfg.Storage.Path, //nolint:staticcheck // backwards compat "database", s.cfg.Database.Path) // Start background goroutine to update cache stats metrics @@ -316,13 +324,13 @@ func (s *Server) handleRoot(w http.ResponseWriter, r *http.Request) { } // Get popular packages - popular, err := s.db.GetMostPopularPackages(10) + popular, err := s.db.GetMostPopularPackages(dashboardTopN) if err != nil { s.logger.Error("failed to get popular packages", "error", err) } // Get recent packages - recent, err := s.db.GetRecentlyCachedPackages(10) + recent, err := s.db.GetRecentlyCachedPackages(dashboardTopN) if err != nil { s.logger.Error("failed to get recent packages", "error", err) } @@ -497,7 +505,7 @@ func (s *Server) handlePackagesList(w http.ResponseWriter, r *http.Request) { ecosystem := r.URL.Query().Get("ecosystem") sortBy := r.URL.Query().Get("sort") if sortBy == "" { - sortBy = "hits" + sortBy = defaultSortBy } page := 1 @@ -731,7 +739,7 @@ func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) { CachedArtifacts: count, TotalSize: size, TotalSizeHuman: formatSize(size), - StoragePath: s.cfg.Storage.Path, + StoragePath: s.cfg.Storage.Path, //nolint:staticcheck // backwards compat DatabasePath: s.cfg.Database.Path, } @@ -766,14 +774,14 @@ func formatTimeAgo(t time.Time) string { return "1 min ago" } return fmt.Sprintf("%d mins ago", m) - case d < 24*time.Hour: + case d < hoursPerDay*time.Hour: h := int(d.Hours()) if h == 1 { return "1 hour ago" } return fmt.Sprintf("%d hours ago", h) - case d < 7*24*time.Hour: - days := int(d.Hours() / 24) + case d < 7*hoursPerDay*time.Hour: + days := int(d.Hours() / hoursPerDay) if days == 1 { return "1 day ago" } @@ -786,7 +794,7 @@ func formatTimeAgo(t time.Time) string { // categorizeLicenseCSS returns the CSS class suffix for a license category using the spdx module. func categorizeLicenseCSS(license string) string { if license == "" { - return "unknown" + return licenseCategoryUnknown } if spdx.HasCopyleft(license) { @@ -797,13 +805,13 @@ func categorizeLicenseCSS(license string) string { return "permissive" } - return "unknown" + return licenseCategoryUnknown } // categorizeLicense is a helper that handles sql.NullString. func categorizeLicense(license sql.NullString) string { if !license.Valid { - return "unknown" + return licenseCategoryUnknown } return categorizeLicenseCSS(license.String) } diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 555e80a..69f36e8 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -125,6 +125,39 @@ func (ts *testServer) close() { _ = os.RemoveAll(ts.tempDir) } +// seedTestPackage creates a package, version, and artifact in the database for testing +// page rendering. The package is created under the npm ecosystem with version 1.0.0. +func seedTestPackage(t *testing.T, db *database.DB, name string) { + t.Helper() + + pkg := &database.Package{ + PURL: "pkg:npm/" + name, + Ecosystem: "npm", + Name: name, + } + if err := db.UpsertPackage(pkg); err != nil { + t.Fatalf("failed to upsert package: %v", err) + } + + ver := &database.Version{ + PURL: "pkg:npm/" + name + "@1.0.0", + PackagePURL: pkg.PURL, + } + if err := db.UpsertVersion(ver); err != nil { + t.Fatalf("failed to upsert version: %v", err) + } + + artifact := &database.Artifact{ + VersionPURL: ver.PURL, + Filename: name + "-1.0.0.tgz", + UpstreamURL: "https://registry.npmjs.org/" + name + "/-/" + name + "-1.0.0.tgz", + StoragePath: sql.NullString{String: "/tmp/test.tgz", Valid: true}, + } + if err := db.UpsertArtifact(artifact); err != nil { + t.Fatalf("failed to upsert artifact: %v", err) + } +} + func TestHandleOpenAPIJSON(t *testing.T) { ts := newTestServer(t) defer ts.close() @@ -668,32 +701,7 @@ func TestSearchPage_WithSeededResults(t *testing.T) { ts := newTestServer(t) defer ts.close() - pkg := &database.Package{ - PURL: "pkg:npm/searchable-pkg", - Ecosystem: "npm", - Name: "searchable-pkg", - } - if err := ts.db.UpsertPackage(pkg); err != nil { - t.Fatalf("failed to upsert package: %v", err) - } - - ver := &database.Version{ - PURL: "pkg:npm/searchable-pkg@1.0.0", - PackagePURL: pkg.PURL, - } - if err := ts.db.UpsertVersion(ver); err != nil { - t.Fatalf("failed to upsert version: %v", err) - } - - artifact := &database.Artifact{ - VersionPURL: ver.PURL, - Filename: "searchable-pkg-1.0.0.tgz", - UpstreamURL: "https://registry.npmjs.org/searchable-pkg/-/searchable-pkg-1.0.0.tgz", - StoragePath: sql.NullString{String: "/tmp/test.tgz", Valid: true}, - } - if err := ts.db.UpsertArtifact(artifact); err != nil { - t.Fatalf("failed to upsert artifact: %v", err) - } + seedTestPackage(t, ts.db, "searchable-pkg") req := httptest.NewRequest("GET", "/search?q=searchable", nil) w := httptest.NewRecorder() @@ -844,32 +852,7 @@ func TestHandlePackagesListPage(t *testing.T) { ts := newTestServer(t) defer ts.close() - pkg := &database.Package{ - PURL: "pkg:npm/list-test", - Ecosystem: "npm", - Name: "list-test", - } - if err := ts.db.UpsertPackage(pkg); err != nil { - t.Fatalf("failed to upsert package: %v", err) - } - - ver := &database.Version{ - PURL: "pkg:npm/list-test@1.0.0", - PackagePURL: pkg.PURL, - } - if err := ts.db.UpsertVersion(ver); err != nil { - t.Fatalf("failed to upsert version: %v", err) - } - - artifact := &database.Artifact{ - VersionPURL: ver.PURL, - Filename: "list-test-1.0.0.tgz", - UpstreamURL: "https://registry.npmjs.org/list-test/-/list-test-1.0.0.tgz", - StoragePath: sql.NullString{String: "/tmp/test.tgz", Valid: true}, - } - if err := ts.db.UpsertArtifact(artifact); err != nil { - t.Fatalf("failed to upsert artifact: %v", err) - } + seedTestPackage(t, ts.db, "list-test") req := httptest.NewRequest("GET", "/packages", nil) w := httptest.NewRecorder() diff --git a/internal/server/templates_test.go b/internal/server/templates_test.go index 3a67880..9d26269 100644 --- a/internal/server/templates_test.go +++ b/internal/server/templates_test.go @@ -62,7 +62,7 @@ func TestTemplatesRenderAllPages(t *testing.T) { }}, {"packages_list", PackagesListPageData{ Ecosystem: "", - SortBy: "hits", + SortBy: defaultSortBy, Results: []SearchResultItem{{Ecosystem: "npm", Name: "express", Hits: 200, SizeFormatted: "2 MB"}}, Count: 1, Page: 1, diff --git a/internal/storage/blob.go b/internal/storage/blob.go index 22643fa..11da357 100644 --- a/internal/storage/blob.go +++ b/internal/storage/blob.go @@ -17,6 +17,8 @@ import ( "gocloud.dev/gcerrors" ) +const osWindows = "windows" + // Blob implements Storage using gocloud.dev/blob. // Supports local filesystem (file://) and S3 (s3://) URLs. type Blob struct { @@ -38,10 +40,10 @@ func OpenBucket(ctx context.Context, urlStr string) (*Blob, error) { path := strings.TrimPrefix(urlStr, "file://") // Handle file:/// (three slashes) for absolute paths - if strings.HasPrefix(path, "/") && runtime.GOOS != "windows" { + if strings.HasPrefix(path, "/") && runtime.GOOS != osWindows { // Unix: file:///path -> /path // path is already correct - } else if strings.HasPrefix(path, "/") && runtime.GOOS == "windows" { + } else if strings.HasPrefix(path, "/") && runtime.GOOS == osWindows { // Windows: file:///C:/path -> C:/path path = strings.TrimPrefix(path, "/") } @@ -50,7 +52,7 @@ func OpenBucket(ctx context.Context, urlStr string) (*Blob, error) { nativePath := filepath.FromSlash(path) // Ensure directory exists - if err := os.MkdirAll(nativePath, 0755); err != nil { + if err := os.MkdirAll(nativePath, dirPermissions); err != nil { return nil, fmt.Errorf("creating directory: %w", err) } @@ -62,7 +64,7 @@ func OpenBucket(ctx context.Context, urlStr string) (*Blob, error) { // Convert back to URL format with forward slashes urlPath := filepath.ToSlash(absPath) - if runtime.GOOS == "windows" { + if runtime.GOOS == osWindows { // Windows needs file:///C:/path format urlStr = "file:///" + urlPath } else { diff --git a/internal/storage/blob_test.go b/internal/storage/blob_test.go index 3add0a7..cfb8281 100644 --- a/internal/storage/blob_test.go +++ b/internal/storage/blob_test.go @@ -1,7 +1,6 @@ package storage import ( - "bytes" "context" "crypto/sha256" "encoding/hex" @@ -186,33 +185,7 @@ func TestBlobUsedSpace(t *testing.T) { } func TestBlobLargeFile(t *testing.T) { - b := createTestBlob(t) - ctx := context.Background() - - // 1MB of data - data := bytes.Repeat([]byte("x"), 1024*1024) - - size, hash, err := b.Store(ctx, "large/file.bin", bytes.NewReader(data)) - if err != nil { - t.Fatalf("Store large file failed: %v", err) - } - if size != int64(len(data)) { - t.Errorf("size = %d, want %d", size, len(data)) - } - - h := sha256.Sum256(data) - wantHash := hex.EncodeToString(h[:]) - if hash != wantHash { - t.Errorf("hash mismatch for large file") - } - - // Read it back - r, _ := b.Open(ctx, "large/file.bin") - defer func() { _ = r.Close() }() - readBack, _ := io.ReadAll(r) - if !bytes.Equal(readBack, data) { - t.Error("large file content mismatch") - } + assertLargeFileRoundTrip(t, createTestBlob(t)) } func TestBlobOverwrite(t *testing.T) { @@ -258,7 +231,7 @@ func createTestBlob(t *testing.T) *Blob { } func fileURLFromPath(path string) string { - if runtime.GOOS == "windows" { + if runtime.GOOS == osWindows { // Windows paths need file:///C:/path format path = filepath.ToSlash(path) return "file:///" + path diff --git a/internal/storage/filesystem.go b/internal/storage/filesystem.go index 01be90d..8dec48b 100644 --- a/internal/storage/filesystem.go +++ b/internal/storage/filesystem.go @@ -23,7 +23,7 @@ func NewFilesystem(root string) (*Filesystem, error) { return nil, fmt.Errorf("resolving root path: %w", err) } - if err := os.MkdirAll(absRoot, 0755); err != nil { + if err := os.MkdirAll(absRoot, dirPermissions); err != nil { return nil, fmt.Errorf("creating root directory: %w", err) } @@ -38,7 +38,7 @@ func (fs *Filesystem) Store(ctx context.Context, path string, r io.Reader) (int6 fullPath := fs.fullPath(path) dir := filepath.Dir(fullPath) - if err := os.MkdirAll(dir, 0755); err != nil { + if err := os.MkdirAll(dir, dirPermissions); err != nil { return 0, "", fmt.Errorf("creating directory: %w", err) } diff --git a/internal/storage/filesystem_test.go b/internal/storage/filesystem_test.go index 16c7447..7fbba10 100644 --- a/internal/storage/filesystem_test.go +++ b/internal/storage/filesystem_test.go @@ -1,7 +1,6 @@ package storage import ( - "bytes" "context" "crypto/sha256" "encoding/hex" @@ -234,33 +233,7 @@ func TestFilesystemUsedSpace(t *testing.T) { } func TestFilesystemLargeFile(t *testing.T) { - fs := createTestFilesystem(t) - ctx := context.Background() - - // 1MB of data - data := bytes.Repeat([]byte("x"), 1024*1024) - - size, hash, err := fs.Store(ctx, "large/file.bin", bytes.NewReader(data)) - if err != nil { - t.Fatalf("Store large file failed: %v", err) - } - if size != int64(len(data)) { - t.Errorf("size = %d, want %d", size, len(data)) - } - - h := sha256.Sum256(data) - wantHash := hex.EncodeToString(h[:]) - if hash != wantHash { - t.Errorf("hash mismatch for large file") - } - - // Read it back - r, _ := fs.Open(ctx, "large/file.bin") - defer func() { _ = r.Close() }() - readBack, _ := io.ReadAll(r) - if !bytes.Equal(readBack, data) { - t.Error("large file content mismatch") - } + assertLargeFileRoundTrip(t, createTestFilesystem(t)) } func createTestFilesystem(t *testing.T) *Filesystem { diff --git a/internal/storage/storage.go b/internal/storage/storage.go index 1efbcba..93053ca 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -17,6 +17,8 @@ import ( "io" ) +const dirPermissions = 0755 + var ( ErrNotFound = errors.New("artifact not found") ) diff --git a/internal/storage/storage_test.go b/internal/storage/storage_test.go index c71df26..97800f0 100644 --- a/internal/storage/storage_test.go +++ b/internal/storage/storage_test.go @@ -1,6 +1,8 @@ package storage import ( + "bytes" + "context" "crypto/sha256" "encoding/hex" "io" @@ -56,3 +58,33 @@ func TestHashingReader(t *testing.T) { t.Errorf("got hash %s, want %s", r.Sum(), wantHash) } } + +// assertLargeFileRoundTrip stores a 1MB file in the given storage, verifies size and +// hash, then reads it back and confirms the content matches. +func assertLargeFileRoundTrip(t *testing.T, s Storage) { + t.Helper() + ctx := context.Background() + + data := bytes.Repeat([]byte("x"), 1024*1024) + + size, hash, err := s.Store(ctx, "large/file.bin", bytes.NewReader(data)) + if err != nil { + t.Fatalf("Store large file failed: %v", err) + } + if size != int64(len(data)) { + t.Errorf("size = %d, want %d", size, len(data)) + } + + h := sha256.Sum256(data) + wantHash := hex.EncodeToString(h[:]) + if hash != wantHash { + t.Errorf("hash mismatch for large file") + } + + r, _ := s.Open(ctx, "large/file.bin") + defer func() { _ = r.Close() }() + readBack, _ := io.ReadAll(r) + if !bytes.Equal(readBack, data) { + t.Error("large file content mismatch") + } +}