From 6a25451eaa64f771bbec20ce8910a6aa95d4d986 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Wed, 4 Jun 2025 16:49:01 -0700 Subject: [PATCH 1/3] Always rewrite prefix --- pkg/download/buffer.go | 62 ++++++++++++++++++------------------------ 1 file changed, 26 insertions(+), 36 deletions(-) diff --git a/pkg/download/buffer.go b/pkg/download/buffer.go index ce9542e..28d9f63 100644 --- a/pkg/download/buffer.go +++ b/pkg/download/buffer.go @@ -7,7 +7,6 @@ import ( "net/http" "net/url" "strconv" - "strings" "github.com/replicate/pget/pkg/client" "github.com/replicate/pget/pkg/logging" @@ -229,48 +228,39 @@ func (m *BufferMode) rewriteUrlForCache(urlString string) string { Msg("Cache URL Rewrite") return urlString } - if prefixes, ok := m.CacheableURIPrefixes[parsed.Host]; ok { - for _, pfx := range prefixes { - if pfx.Path == "/" || strings.HasPrefix(parsed.Path, pfx.Path) { - newUrl := m.CacheHosts[0] - if m.CacheUsePathProxy { - newUrl, err = url.JoinPath(newUrl, pfx.Host) - if err != nil { - break - } - logger.Debug(). - Bool("path_based_proxy", true). - Str("host_prefix", pfx.Host). - Str("intermediate_target_url", newUrl). - Str("url", urlString). - Msg("Cache URL Rewrite") - } - newUrl, err = url.JoinPath(newUrl, parsed.Path) - if err != nil { - break - } - logger.Info(). - Str("url", urlString). - Str("target_url", newUrl). - Bool("enabled", true). - Msg("Cache URL Rewrite") - return newUrl - } + newUrl := m.CacheHosts[0] + if m.CacheUsePathProxy { + newUrl, err = url.JoinPath(newUrl, parsed.Host) + if err != nil { + logger.Error(). + Err(err). + Str("url", urlString). + Bool("enabled", false). + Str("disabled_reason", "failed to join cache URL to host"). + Msg("Cache URL Rewrite") + return urlString } + logger.Debug(). + Bool("path_based_proxy", true). + Str("host_prefix", parsed.Host). + Str("intermediate_target_url", newUrl). + Str("url", urlString). + Msg("Cache URL Rewrite") } + newUrl, err = url.JoinPath(newUrl, parsed.Path) if err != nil { logger.Error(). Err(err). Str("url", urlString). Bool("enabled", false). - Str("disabled_reason", "failed to generate target url"). - Msg("Cache URL Rewrite") - } else { - logger.Debug(). - Str("url", urlString). - Bool("enabled", false). - Str("disabled_reason", "no matching prefix"). + Str("disabled_reason", "failed to join host URL to path"). Msg("Cache URL Rewrite") + return urlString } - return urlString + logger.Info(). + Str("url", urlString). + Str("target_url", newUrl). + Bool("enabled", true). + Msg("Cache URL Rewrite") + return newUrl } From 82414110b01a42151dd27c614b4ed3eb10d39374 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Wed, 4 Jun 2025 17:15:21 -0700 Subject: [PATCH 2/3] Nvm put it behind a flag --- cmd/multifile/multifile.go | 4 ++-- cmd/root/root.go | 2 ++ pkg/config/optnames.go | 1 + pkg/download/buffer.go | 31 ++++++++++++++++++++++++++++++- pkg/download/options.go | 5 +++++ 5 files changed, 40 insertions(+), 3 deletions(-) diff --git a/cmd/multifile/multifile.go b/cmd/multifile/multifile.go index 98f3320..3305707 100644 --- a/cmd/multifile/multifile.go +++ b/cmd/multifile/multifile.go @@ -141,7 +141,7 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error { if srvName := config.GetCacheSRV(); srvName != "" { downloadOpts.SliceSize = 500 * humanize.MiByte downloadOpts.CacheableURIPrefixes = config.CacheableURIPrefixes() - downloadOpts.CacheUsePathProxy = viper.GetBool(config.OptCacheUsePathProxy) + downloadOpts.ForceCachePrefixRewrite = viper.GetBool(config.OptForceCachePrefixRewrite) if downloadOpts.CacheHosts, err = cli.LookupCacheHosts(srvName); err != nil { return err } @@ -152,7 +152,7 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error { } else if cacheHostname := config.CacheServiceHostname(); cacheHostname != "" { downloadOpts.CacheHosts = []string{cacheHostname} downloadOpts.CacheableURIPrefixes = config.CacheableURIPrefixes() - downloadOpts.CacheUsePathProxy = viper.GetBool(config.OptCacheUsePathProxy) + downloadOpts.ForceCachePrefixRewrite = viper.GetBool(config.OptForceCachePrefixRewrite) } if getter.Downloader == nil { diff --git a/cmd/root/root.go b/cmd/root/root.go index 7c63977..4853f91 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -282,6 +282,7 @@ func rootExecute(ctx context.Context, urlString, dest string) error { downloadOpts.SliceSize = 500 * humanize.MiByte downloadOpts.CacheableURIPrefixes = config.CacheableURIPrefixes() downloadOpts.CacheUsePathProxy = viper.GetBool(config.OptCacheUsePathProxy) + downloadOpts.ForceCachePrefixRewrite = viper.GetBool(config.OptForceCachePrefixRewrite) if downloadOpts.CacheHosts, err = cli.LookupCacheHosts(srvName); err != nil { return err } @@ -293,6 +294,7 @@ func rootExecute(ctx context.Context, urlString, dest string) error { downloadOpts.CacheHosts = []string{cacheHostname} downloadOpts.CacheableURIPrefixes = config.CacheableURIPrefixes() downloadOpts.CacheUsePathProxy = viper.GetBool(config.OptCacheUsePathProxy) + downloadOpts.ForceCachePrefixRewrite = viper.GetBool(config.OptForceCachePrefixRewrite) } if getter.Downloader == nil { diff --git a/pkg/config/optnames.go b/pkg/config/optnames.go index 826a1b1..565ebc1 100644 --- a/pkg/config/optnames.go +++ b/pkg/config/optnames.go @@ -8,6 +8,7 @@ const ( OptCacheServiceHostname = "cache-service-hostname" OptCacheURIPrefixes = "cache-uri-prefixes" OptCacheUsePathProxy = "cache-use-path-proxy" + OptForceCachePrefixRewrite = "force-cache-prefix-rewrite" OptHostIP = "host-ip" OptMetricsEndpoint = "metrics-endpoint" OptHeaders = "headers" diff --git a/pkg/download/buffer.go b/pkg/download/buffer.go index 28d9f63..5591662 100644 --- a/pkg/download/buffer.go +++ b/pkg/download/buffer.go @@ -7,9 +7,11 @@ import ( "net/http" "net/url" "strconv" + "strings" "github.com/replicate/pget/pkg/client" "github.com/replicate/pget/pkg/logging" + "github.com/rs/zerolog" ) type BufferMode struct { @@ -228,7 +230,34 @@ func (m *BufferMode) rewriteUrlForCache(urlString string) string { Msg("Cache URL Rewrite") return urlString } - newUrl := m.CacheHosts[0] + if m.ForceCachePrefixRewrite { + // Forcefully rewrite the URL prefix + return m.rewritePrefix(m.CacheHosts[0], urlString, parsed, logger) + } else { + if prefixes, ok := m.CacheableURIPrefixes[parsed.Host]; ok { + for _, pfx := range prefixes { + if pfx.Path == "/" || strings.HasPrefix(parsed.Path, pfx.Path) { + // Found a matching prefix, rewrite the URL prefix + return m.rewritePrefix(m.CacheHosts[0], urlString, parsed, logger) + } + } + } + } + + // If we got here, we weren't forcefully rewriting the cache prefix and we didn't + // find any matching prefixes, so we just return the original URL + logger.Debug(). + Str("url", urlString). + Bool("enabled", false). + Str("disabled_reason", "no matching prefix"). + Str("disabled_reason", "failed to join host URL to path"). + Msg("Cache URL Rewrite") + return urlString +} + +func (m *BufferMode) rewritePrefix(cacheHost, urlString string, parsed *url.URL, logger zerolog.Logger) string { + newUrl := cacheHost + var err error if m.CacheUsePathProxy { newUrl, err = url.JoinPath(newUrl, parsed.Host) if err != nil { diff --git a/pkg/download/options.go b/pkg/download/options.go index fe45651..84a8998 100644 --- a/pkg/download/options.go +++ b/pkg/download/options.go @@ -36,6 +36,11 @@ type Options struct { // hashing algorithm. The slice may contain empty entries which // correspond to a cache host which is currently unavailable. CacheHosts []string + + // ForceCachePrefixRewrite will forcefully rewrite the prefix for all + // pget requests to the first item in the CacheHosts list. This ignores + // anything in the CacheableURIPrefixes and rewrites all requests. + ForceCachePrefixRewrite bool } func (o *Options) maxConcurrency() int { From 38036509bc805a98a12c1589d469a31dc2377416 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Wed, 4 Jun 2025 17:20:36 -0700 Subject: [PATCH 3/3] Lint --- pkg/download/buffer.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/download/buffer.go b/pkg/download/buffer.go index 5591662..874f9b3 100644 --- a/pkg/download/buffer.go +++ b/pkg/download/buffer.go @@ -9,9 +9,10 @@ import ( "strconv" "strings" + "github.com/rs/zerolog" + "github.com/replicate/pget/pkg/client" "github.com/replicate/pget/pkg/logging" - "github.com/rs/zerolog" ) type BufferMode struct {