diff --git a/pkg/config/optnames.go b/pkg/config/optnames.go index 565ebc1..831cfbf 100644 --- a/pkg/config/optnames.go +++ b/pkg/config/optnames.go @@ -12,6 +12,7 @@ const ( OptHostIP = "host-ip" OptMetricsEndpoint = "metrics-endpoint" OptHeaders = "headers" + OptProxyAuthHeader = "proxy-auth-header" // Normal options with CLI arguments OptConcurrency = "concurrency" diff --git a/pkg/download/buffer.go b/pkg/download/buffer.go index 9344066..817cf25 100644 --- a/pkg/download/buffer.go +++ b/pkg/download/buffer.go @@ -10,8 +10,10 @@ import ( "strings" "github.com/rs/zerolog" + "github.com/spf13/viper" "github.com/replicate/pget/pkg/client" + "github.com/replicate/pget/pkg/config" "github.com/replicate/pget/pkg/logging" ) @@ -19,14 +21,16 @@ type BufferMode struct { Client client.HTTPClient Options - queue *priorityWorkQueue + queue *priorityWorkQueue + redirected bool } func GetBufferMode(opts Options) *BufferMode { client := client.NewHTTPClient(opts.Client) m := &BufferMode{ - Client: client, - Options: opts, + Client: client, + Options: opts, + redirected: false, } m.queue = newWorkQueue(opts.maxConcurrency(), m.chunkSize()) m.queue.start() @@ -100,6 +104,7 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e trueURL := firstChunkResp.Request.URL.String() if trueURL != url { logger.Info().Str("url", url).Str("redirect_url", trueURL).Msg("Redirect") + m.redirected = true } fileSize, err := m.getFileSizeFromResponse(firstChunkResp) @@ -200,6 +205,10 @@ func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL st return nil, fmt.Errorf("failed to download %s: %w", trueURL, err) } req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) + proxyAuthHeader := viper.GetString(config.OptProxyAuthHeader) + if proxyAuthHeader != "" && !m.redirected { + req.Header.Set("Authorization", proxyAuthHeader) + } resp, err := m.Client.Do(req) if err != nil { return nil, fmt.Errorf("error executing request for %s: %w", req.URL.String(), err)