From 35086b69510ea853a39b4a71ff066dfd8f848183 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Tue, 17 Jun 2025 17:24:02 -0700 Subject: [PATCH 1/5] Add authorization header --- cmd/root/root.go | 8 +++++--- pkg/config/optnames.go | 1 + pkg/download/buffer.go | 3 +++ pkg/download/consistent_hashing.go | 4 +++- pkg/download/options.go | 5 +++++ 5 files changed, 17 insertions(+), 4 deletions(-) diff --git a/cmd/root/root.go b/cmd/root/root.go index 4853f91..bc98d84 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -179,6 +179,7 @@ func persistentFlags(cmd *cobra.Command) error { cmd.PersistentFlags().Int(config.OptMaxConnPerHost, 40, "Maximum number of (global) concurrent connections per host") cmd.PersistentFlags().StringP(config.OptOutputConsumer, "o", "file", "Output Consumer (file, tar, null)") cmd.PersistentFlags().String(config.OptPIDFile, defaultPidFilePath(), "PID file path") + cmd.PersistentFlags().String(config.OptHTTPAuthHeader, "", "HTTP Authorization header") if err := hideAndDeprecateFlags(cmd); err != nil { return err @@ -258,9 +259,10 @@ func rootExecute(ctx context.Context, urlString, dest string) error { } downloadOpts := download.Options{ - MaxConcurrency: viper.GetInt(config.OptConcurrency), - ChunkSize: int64(chunkSize), - Client: clientOpts, + MaxConcurrency: viper.GetInt(config.OptConcurrency), + ChunkSize: int64(chunkSize), + Client: clientOpts, + HTTPAuthorizationHeader: viper.GetString(config.OptHTTPAuthHeader), } consumer, err := config.GetConsumer() diff --git a/pkg/config/optnames.go b/pkg/config/optnames.go index 565ebc1..9c1637d 100644 --- a/pkg/config/optnames.go +++ b/pkg/config/optnames.go @@ -30,4 +30,5 @@ const ( OptResolve = "resolve" OptRetries = "retries" OptVerbose = "verbose" + OptHTTPAuthHeader = "http-auth-header" ) diff --git a/pkg/download/buffer.go b/pkg/download/buffer.go index ddcf94b..61b2925 100644 --- a/pkg/download/buffer.go +++ b/pkg/download/buffer.go @@ -200,6 +200,9 @@ func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL st return nil, fmt.Errorf("failed to download %s: %w", trueURL, err) } req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) + if m.HTTPAuthorizationHeader != "" { + req.Header.Set("Authorization", m.HTTPAuthorizationHeader) + } resp, err := m.Client.Do(req) if err != nil { return nil, fmt.Errorf("error executing request for %s: %w", req.URL.String(), err) diff --git a/pkg/download/consistent_hashing.go b/pkg/download/consistent_hashing.go index 2115fe3..a42850b 100644 --- a/pkg/download/consistent_hashing.go +++ b/pkg/download/consistent_hashing.go @@ -301,7 +301,9 @@ func (m *ConsistentHashingMode) doRequestToCacheHost(req *http.Request, urlStrin return nil, cachePodIndex, err } req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) - + if m.HTTPAuthorizationHeader != "" { + req.Header.Set("Authorization", m.HTTPAuthorizationHeader) + } logger.Debug().Str("url", urlString).Str("munged_url", req.URL.String()).Str("host", req.Host).Int64("start", start).Int64("end", end).Msg("request") resp, err := m.Client.Do(req) diff --git a/pkg/download/options.go b/pkg/download/options.go index 84a8998..122f2dd 100644 --- a/pkg/download/options.go +++ b/pkg/download/options.go @@ -41,6 +41,11 @@ type Options struct { // pget requests to the first item in the CacheHosts list. This ignores // anything in the CacheableURIPrefixes and rewrites all requests. ForceCachePrefixRewrite bool + + // HTTPAuthorizationHeader sets the HTTP Authoriation header in requests to + // the upstream. Notably, following the HTTP protocol, this header will not + // persist on redirects. + HTTPAuthorizationHeader string } func (o *Options) maxConcurrency() int { From a8bc6680e0e2db1e85732ac470382c41da6907d2 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Fri, 20 Jun 2025 16:38:49 -0700 Subject: [PATCH 2/5] Should be behind env var --- cmd/root/root.go | 8 +++----- pkg/config/optnames.go | 2 +- pkg/download/buffer.go | 7 +++++-- pkg/download/consistent_hashing.go | 6 ++++-- pkg/download/options.go | 5 ----- 5 files changed, 13 insertions(+), 15 deletions(-) diff --git a/cmd/root/root.go b/cmd/root/root.go index bc98d84..4853f91 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -179,7 +179,6 @@ func persistentFlags(cmd *cobra.Command) error { cmd.PersistentFlags().Int(config.OptMaxConnPerHost, 40, "Maximum number of (global) concurrent connections per host") cmd.PersistentFlags().StringP(config.OptOutputConsumer, "o", "file", "Output Consumer (file, tar, null)") cmd.PersistentFlags().String(config.OptPIDFile, defaultPidFilePath(), "PID file path") - cmd.PersistentFlags().String(config.OptHTTPAuthHeader, "", "HTTP Authorization header") if err := hideAndDeprecateFlags(cmd); err != nil { return err @@ -259,10 +258,9 @@ func rootExecute(ctx context.Context, urlString, dest string) error { } downloadOpts := download.Options{ - MaxConcurrency: viper.GetInt(config.OptConcurrency), - ChunkSize: int64(chunkSize), - Client: clientOpts, - HTTPAuthorizationHeader: viper.GetString(config.OptHTTPAuthHeader), + MaxConcurrency: viper.GetInt(config.OptConcurrency), + ChunkSize: int64(chunkSize), + Client: clientOpts, } consumer, err := config.GetConsumer() diff --git a/pkg/config/optnames.go b/pkg/config/optnames.go index 9c1637d..831cfbf 100644 --- a/pkg/config/optnames.go +++ b/pkg/config/optnames.go @@ -12,6 +12,7 @@ const ( OptHostIP = "host-ip" OptMetricsEndpoint = "metrics-endpoint" OptHeaders = "headers" + OptProxyAuthHeader = "proxy-auth-header" // Normal options with CLI arguments OptConcurrency = "concurrency" @@ -30,5 +31,4 @@ const ( OptResolve = "resolve" OptRetries = "retries" OptVerbose = "verbose" - OptHTTPAuthHeader = "http-auth-header" ) diff --git a/pkg/download/buffer.go b/pkg/download/buffer.go index 13ea9c3..7411469 100644 --- a/pkg/download/buffer.go +++ b/pkg/download/buffer.go @@ -10,8 +10,10 @@ import ( "strings" "github.com/rs/zerolog" + "github.com/spf13/viper" "github.com/replicate/pget/pkg/client" + "github.com/replicate/pget/pkg/config" "github.com/replicate/pget/pkg/logging" ) @@ -200,8 +202,9 @@ func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL st return nil, fmt.Errorf("failed to download %s: %w", trueURL, err) } req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) - if m.HTTPAuthorizationHeader != "" { - req.Header.Set("Authorization", m.HTTPAuthorizationHeader) + proxyAuthHeader := viper.GetString(config.OptProxyAuthHeader) + if proxyAuthHeader != "" { + req.Header.Set("Authorization", proxyAuthHeader) } resp, err := m.Client.Do(req) if err != nil { diff --git a/pkg/download/consistent_hashing.go b/pkg/download/consistent_hashing.go index a42850b..e67430e 100644 --- a/pkg/download/consistent_hashing.go +++ b/pkg/download/consistent_hashing.go @@ -14,6 +14,7 @@ import ( "github.com/replicate/pget/pkg/config" "github.com/replicate/pget/pkg/consistent" "github.com/replicate/pget/pkg/logging" + "github.com/spf13/viper" ) type ConsistentHashingMode struct { @@ -301,8 +302,9 @@ func (m *ConsistentHashingMode) doRequestToCacheHost(req *http.Request, urlStrin return nil, cachePodIndex, err } req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) - if m.HTTPAuthorizationHeader != "" { - req.Header.Set("Authorization", m.HTTPAuthorizationHeader) + proxyAuthHeader := viper.GetString(config.OptProxyAuthHeader) + if proxyAuthHeader != "" { + req.Header.Set("Authorization", proxyAuthHeader) } logger.Debug().Str("url", urlString).Str("munged_url", req.URL.String()).Str("host", req.Host).Int64("start", start).Int64("end", end).Msg("request") diff --git a/pkg/download/options.go b/pkg/download/options.go index 122f2dd..84a8998 100644 --- a/pkg/download/options.go +++ b/pkg/download/options.go @@ -41,11 +41,6 @@ type Options struct { // pget requests to the first item in the CacheHosts list. This ignores // anything in the CacheableURIPrefixes and rewrites all requests. ForceCachePrefixRewrite bool - - // HTTPAuthorizationHeader sets the HTTP Authoriation header in requests to - // the upstream. Notably, following the HTTP protocol, this header will not - // persist on redirects. - HTTPAuthorizationHeader string } func (o *Options) maxConcurrency() int { From 2bd6b120634f539745209c1de7205cef444a533b Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Tue, 24 Jun 2025 16:23:52 -0700 Subject: [PATCH 3/5] Linting --- pkg/download/consistent_hashing.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/download/consistent_hashing.go b/pkg/download/consistent_hashing.go index e67430e..4ac934c 100644 --- a/pkg/download/consistent_hashing.go +++ b/pkg/download/consistent_hashing.go @@ -10,11 +10,12 @@ import ( "strconv" "strings" + "github.com/spf13/viper" + "github.com/replicate/pget/pkg/client" "github.com/replicate/pget/pkg/config" "github.com/replicate/pget/pkg/consistent" "github.com/replicate/pget/pkg/logging" - "github.com/spf13/viper" ) type ConsistentHashingMode struct { From 1f64f6fc08daeb1bac8eb52b0521d296182a3cd6 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Tue, 24 Jun 2025 18:20:19 -0700 Subject: [PATCH 4/5] Eat auth header on redirect --- pkg/download/buffer.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pkg/download/buffer.go b/pkg/download/buffer.go index 7411469..817cf25 100644 --- a/pkg/download/buffer.go +++ b/pkg/download/buffer.go @@ -21,14 +21,16 @@ type BufferMode struct { Client client.HTTPClient Options - queue *priorityWorkQueue + queue *priorityWorkQueue + redirected bool } func GetBufferMode(opts Options) *BufferMode { client := client.NewHTTPClient(opts.Client) m := &BufferMode{ - Client: client, - Options: opts, + Client: client, + Options: opts, + redirected: false, } m.queue = newWorkQueue(opts.maxConcurrency(), m.chunkSize()) m.queue.start() @@ -102,6 +104,7 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e trueURL := firstChunkResp.Request.URL.String() if trueURL != url { logger.Info().Str("url", url).Str("redirect_url", trueURL).Msg("Redirect") + m.redirected = true } fileSize, err := m.getFileSizeFromResponse(firstChunkResp) @@ -203,7 +206,7 @@ func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL st } req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) proxyAuthHeader := viper.GetString(config.OptProxyAuthHeader) - if proxyAuthHeader != "" { + if proxyAuthHeader != "" && !m.redirected { req.Header.Set("Authorization", proxyAuthHeader) } resp, err := m.Client.Do(req) From 7f49bd70b7095481c7c118ad7dda03deb437cfdf Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Wed, 25 Jun 2025 10:42:12 -0700 Subject: [PATCH 5/5] Ignore consistent hashing for now --- pkg/download/consistent_hashing.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pkg/download/consistent_hashing.go b/pkg/download/consistent_hashing.go index 4ac934c..2115fe3 100644 --- a/pkg/download/consistent_hashing.go +++ b/pkg/download/consistent_hashing.go @@ -10,8 +10,6 @@ import ( "strconv" "strings" - "github.com/spf13/viper" - "github.com/replicate/pget/pkg/client" "github.com/replicate/pget/pkg/config" "github.com/replicate/pget/pkg/consistent" @@ -303,10 +301,7 @@ func (m *ConsistentHashingMode) doRequestToCacheHost(req *http.Request, urlStrin return nil, cachePodIndex, err } req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) - proxyAuthHeader := viper.GetString(config.OptProxyAuthHeader) - if proxyAuthHeader != "" { - req.Header.Set("Authorization", proxyAuthHeader) - } + logger.Debug().Str("url", urlString).Str("munged_url", req.URL.String()).Str("host", req.Host).Int64("start", start).Int64("end", end).Msg("request") resp, err := m.Client.Do(req)