diff --git a/cmd/multifile/multifile.go b/cmd/multifile/multifile.go index 98f3320..3305707 100644 --- a/cmd/multifile/multifile.go +++ b/cmd/multifile/multifile.go @@ -141,7 +141,7 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error { if srvName := config.GetCacheSRV(); srvName != "" { downloadOpts.SliceSize = 500 * humanize.MiByte downloadOpts.CacheableURIPrefixes = config.CacheableURIPrefixes() - downloadOpts.CacheUsePathProxy = viper.GetBool(config.OptCacheUsePathProxy) + downloadOpts.ForceCachePrefixRewrite = viper.GetBool(config.OptForceCachePrefixRewrite) if downloadOpts.CacheHosts, err = cli.LookupCacheHosts(srvName); err != nil { return err } @@ -152,7 +152,7 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error { } else if cacheHostname := config.CacheServiceHostname(); cacheHostname != "" { downloadOpts.CacheHosts = []string{cacheHostname} downloadOpts.CacheableURIPrefixes = config.CacheableURIPrefixes() - downloadOpts.CacheUsePathProxy = viper.GetBool(config.OptCacheUsePathProxy) + downloadOpts.ForceCachePrefixRewrite = viper.GetBool(config.OptForceCachePrefixRewrite) } if getter.Downloader == nil { diff --git a/cmd/root/root.go b/cmd/root/root.go index 7c63977..4853f91 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -282,6 +282,7 @@ func rootExecute(ctx context.Context, urlString, dest string) error { downloadOpts.SliceSize = 500 * humanize.MiByte downloadOpts.CacheableURIPrefixes = config.CacheableURIPrefixes() downloadOpts.CacheUsePathProxy = viper.GetBool(config.OptCacheUsePathProxy) + downloadOpts.ForceCachePrefixRewrite = viper.GetBool(config.OptForceCachePrefixRewrite) if downloadOpts.CacheHosts, err = cli.LookupCacheHosts(srvName); err != nil { return err } @@ -293,6 +294,7 @@ func rootExecute(ctx context.Context, urlString, dest string) error { downloadOpts.CacheHosts = []string{cacheHostname} downloadOpts.CacheableURIPrefixes = config.CacheableURIPrefixes() downloadOpts.CacheUsePathProxy = viper.GetBool(config.OptCacheUsePathProxy) + downloadOpts.ForceCachePrefixRewrite = viper.GetBool(config.OptForceCachePrefixRewrite) } if getter.Downloader == nil { diff --git a/pkg/config/optnames.go b/pkg/config/optnames.go index 826a1b1..565ebc1 100644 --- a/pkg/config/optnames.go +++ b/pkg/config/optnames.go @@ -8,6 +8,7 @@ const ( OptCacheServiceHostname = "cache-service-hostname" OptCacheURIPrefixes = "cache-uri-prefixes" OptCacheUsePathProxy = "cache-use-path-proxy" + OptForceCachePrefixRewrite = "force-cache-prefix-rewrite" OptHostIP = "host-ip" OptMetricsEndpoint = "metrics-endpoint" OptHeaders = "headers" diff --git a/pkg/download/buffer.go b/pkg/download/buffer.go index ce9542e..874f9b3 100644 --- a/pkg/download/buffer.go +++ b/pkg/download/buffer.go @@ -9,6 +9,8 @@ import ( "strconv" "strings" + "github.com/rs/zerolog" + "github.com/replicate/pget/pkg/client" "github.com/replicate/pget/pkg/logging" ) @@ -229,48 +231,66 @@ func (m *BufferMode) rewriteUrlForCache(urlString string) string { Msg("Cache URL Rewrite") return urlString } - if prefixes, ok := m.CacheableURIPrefixes[parsed.Host]; ok { - for _, pfx := range prefixes { - if pfx.Path == "/" || strings.HasPrefix(parsed.Path, pfx.Path) { - newUrl := m.CacheHosts[0] - if m.CacheUsePathProxy { - newUrl, err = url.JoinPath(newUrl, pfx.Host) - if err != nil { - break - } - logger.Debug(). - Bool("path_based_proxy", true). - Str("host_prefix", pfx.Host). - Str("intermediate_target_url", newUrl). - Str("url", urlString). - Msg("Cache URL Rewrite") - } - newUrl, err = url.JoinPath(newUrl, parsed.Path) - if err != nil { - break + if m.ForceCachePrefixRewrite { + // Forcefully rewrite the URL prefix + return m.rewritePrefix(m.CacheHosts[0], urlString, parsed, logger) + } else { + if prefixes, ok := m.CacheableURIPrefixes[parsed.Host]; ok { + for _, pfx := range prefixes { + if pfx.Path == "/" || strings.HasPrefix(parsed.Path, pfx.Path) { + // Found a matching prefix, rewrite the URL prefix + return m.rewritePrefix(m.CacheHosts[0], urlString, parsed, logger) } - logger.Info(). - Str("url", urlString). - Str("target_url", newUrl). - Bool("enabled", true). - Msg("Cache URL Rewrite") - return newUrl } } } + + // If we got here, we weren't forcefully rewriting the cache prefix and we didn't + // find any matching prefixes, so we just return the original URL + logger.Debug(). + Str("url", urlString). + Bool("enabled", false). + Str("disabled_reason", "no matching prefix"). + Str("disabled_reason", "failed to join host URL to path"). + Msg("Cache URL Rewrite") + return urlString +} + +func (m *BufferMode) rewritePrefix(cacheHost, urlString string, parsed *url.URL, logger zerolog.Logger) string { + newUrl := cacheHost + var err error + if m.CacheUsePathProxy { + newUrl, err = url.JoinPath(newUrl, parsed.Host) + if err != nil { + logger.Error(). + Err(err). + Str("url", urlString). + Bool("enabled", false). + Str("disabled_reason", "failed to join cache URL to host"). + Msg("Cache URL Rewrite") + return urlString + } + logger.Debug(). + Bool("path_based_proxy", true). + Str("host_prefix", parsed.Host). + Str("intermediate_target_url", newUrl). + Str("url", urlString). + Msg("Cache URL Rewrite") + } + newUrl, err = url.JoinPath(newUrl, parsed.Path) if err != nil { logger.Error(). Err(err). Str("url", urlString). Bool("enabled", false). - Str("disabled_reason", "failed to generate target url"). - Msg("Cache URL Rewrite") - } else { - logger.Debug(). - Str("url", urlString). - Bool("enabled", false). - Str("disabled_reason", "no matching prefix"). + Str("disabled_reason", "failed to join host URL to path"). Msg("Cache URL Rewrite") + return urlString } - return urlString + logger.Info(). + Str("url", urlString). + Str("target_url", newUrl). + Bool("enabled", true). + Msg("Cache URL Rewrite") + return newUrl } diff --git a/pkg/download/options.go b/pkg/download/options.go index fe45651..84a8998 100644 --- a/pkg/download/options.go +++ b/pkg/download/options.go @@ -36,6 +36,11 @@ type Options struct { // hashing algorithm. The slice may contain empty entries which // correspond to a cache host which is currently unavailable. CacheHosts []string + + // ForceCachePrefixRewrite will forcefully rewrite the prefix for all + // pget requests to the first item in the CacheHosts list. This ignores + // anything in the CacheableURIPrefixes and rewrites all requests. + ForceCachePrefixRewrite bool } func (o *Options) maxConcurrency() int {