Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions tools/labctl/cmd/labctl/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package main

import (
"bytes"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -37,12 +39,18 @@ func setupTestScript(env *testscript.Env) error {
if err != nil {
return err
}
talosArchiveSHA256 := sha256.Sum256(talosArchive)
talosArchiveDigest := hex.EncodeToString(talosArchiveSHA256[:])
talosSidecar := []byte(talosArchiveDigest + " nocloud-amd64.raw.xz\n")

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/repos/GilmanLab/secrets/contents/network/keycloak.sops.yaml":
handleSecretFixture(env, w, r)
case "/image/" + defaultTalosSchematicID + "/v1.13.0/nocloud-amd64.raw.xz":
_, _ = w.Write(talosArchive)
case "/image/" + defaultTalosSchematicID + "/v1.13.0/nocloud-amd64.raw.xz.sha256":
_, _ = w.Write(talosSidecar)
default:
http.Error(w, fmt.Sprintf("unexpected path %s", r.URL.Path), http.StatusNotFound)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ exists .state/downloads/talos/376567988ad370138ad8b2698212367b8edcb69b5fd68c80be
exec labctl bootstrap talos image build --json talos-valid.yaml
stdout '"name":"talos-test"'
stdout '"bootArtifactPath":'
stdout '"bootArtifactSHA256":"[0-9a-f]{64}"'
stdout '"configArtifactPath":'
stdout '"configArtifactSHA256":"[0-9a-f]{64}"'
stdout '"sourceVersion":"v1.13.0"'
stdout '"sourceURL":'
stdout 'nocloud-amd64.raw.xz'
Expand Down
2 changes: 1 addition & 1 deletion tools/labctl/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
github.com/aws/aws-sdk-go-v2 v1.41.7
github.com/aws/aws-sdk-go-v2/config v1.32.17
github.com/aws/aws-sdk-go-v2/service/lambda v1.90.1
github.com/cenkalti/backoff/v4 v4.3.0
github.com/diskfs/go-diskfs v1.9.1
github.com/getsops/sops/v3 v3.12.2
github.com/gilmanlab/platform/schemas/lab v0.2.1
Expand Down Expand Up @@ -113,7 +114,6 @@ require (
github.com/butuzov/mirror v1.3.0 // indirect
github.com/catenacyber/perfsprint v0.10.1 // indirect
github.com/ccojocar/zxcvbn-go v1.0.4 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charithe/durationcheck v0.0.11 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
Expand Down
161 changes: 149 additions & 12 deletions tools/labctl/internal/adapters/httpupstream/client.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,63 @@
package httpupstream

import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"

"github.com/cenkalti/backoff/v4"

"github.com/gilmanlab/platform/tools/labctl/internal/app/incusosimage"
)

// Client fetches IncusOS upstream metadata and artifacts over HTTP.
const (
// DefaultResponseHeaderTimeout bounds how long the client will wait for
// upstream response headers before failing the connection.
DefaultResponseHeaderTimeout = 30 * time.Second
// DefaultIdleConnTimeout bounds how long an idle keep-alive connection
// will be held open in the transport's idle pool.
DefaultIdleConnTimeout = 90 * time.Second

// retryMaxAttempts is the total number of tries per request, counting the
// first one; newRetryBackoff therefore permits retryMaxAttempts-1 retries.
retryMaxAttempts = 5
// retryInitialInterval is assigned to the backoff's InitialInterval.
retryInitialInterval = 500 * time.Millisecond
// retryMaxInterval is assigned to the backoff's MaxInterval.
retryMaxInterval = 16 * time.Second
// retryRandomization is assigned to the backoff's RandomizationFactor.
retryRandomization = 0.2

// sha256HexLen is the exact string length isHexSHA256 requires of a digest.
sha256HexLen = 64
)

// NewHTTPClient constructs a [*http.Client] with explicit transport timeouts
// suited to upstream artifact downloads.
//
// The transport is a clone of [http.DefaultTransport] so that the standard
// library's dial and TLS-handshake timeouts stay in force; a bare
// http.Transport literal leaves those unset, which would let a stalled
// connection setup hang until the context is cancelled. On top of the clone,
// ResponseHeaderTimeout and IdleConnTimeout are set to the package defaults
// and the proxy is taken from the environment.
//
// The returned client has no top-level Timeout — long-running downloads use
// context cancellation. ResponseHeaderTimeout prevents stalled connections
// from hanging forever waiting for response headers.
func NewHTTPClient() *http.Client {
	var transport *http.Transport
	if base, ok := http.DefaultTransport.(*http.Transport); ok {
		transport = base.Clone()
	} else {
		// http.DefaultTransport has been swapped for a custom RoundTripper;
		// fall back to a fresh transport rather than panicking.
		transport = &http.Transport{}
	}

	transport.Proxy = http.ProxyFromEnvironment
	transport.ResponseHeaderTimeout = DefaultResponseHeaderTimeout
	transport.IdleConnTimeout = DefaultIdleConnTimeout

	return &http.Client{Transport: transport}
}

// Client fetches IncusOS and Talos upstream metadata and artifacts over HTTP.
type Client struct {
// httpClient is the client that get uses to execute requests.
httpClient *http.Client
}

// New constructs an HTTP upstream adapter.
// New constructs an HTTP upstream adapter. A nil httpClient selects the
// timeout-defaulted client returned by NewHTTPClient.
func New(httpClient *http.Client) Client {
if httpClient == nil {
httpClient = http.DefaultClient
httpClient = NewHTTPClient()
}

return Client{
Expand All @@ -43,6 +82,11 @@ func (c Client) FetchIndex(ctx context.Context, url string) (incusosimage.Index,
}

// Download opens a response body for an upstream artifact.
//
// The returned ReadCloser must be closed by the caller. Body reads are not
// retried; transient mid-stream failures are surfaced as errors and the
// caller is expected to verify integrity (for example, by streaming through
// a SHA256 hasher and comparing against a sidecar fetched via FetchSHA256).
func (c Client) Download(ctx context.Context, url string) (io.ReadCloser, error) {
response, err := c.get(ctx, url)
if err != nil {
Expand All @@ -52,21 +96,114 @@ func (c Client) Download(ctx context.Context, url string) (io.ReadCloser, error)
return response.Body, nil
}

func (c Client) get(ctx context.Context, url string) (*http.Response, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
// FetchSHA256 fetches a checksum sidecar and returns the lowercase hex digest.
//
// The response body is expected to be in `sha256sum -c` format: a single
// line of "<hex> <filename>". Only the first whitespace-delimited token is
// parsed; the filename and any trailing lines are ignored.
func (c Client) FetchSHA256(ctx context.Context, url string) (string, error) {
response, err := c.get(ctx, url)
if err != nil {
return nil, fmt.Errorf("create request %q: %w", url, err)
return "", err
}
defer response.Body.Close()

scanner := bufio.NewScanner(response.Body)
if !scanner.Scan() {
if err := scanner.Err(); err != nil {
return "", fmt.Errorf("read sha256 sidecar %q: %w", url, err)
}

return "", fmt.Errorf("sha256 sidecar %q is empty", url)
}

response, err := c.httpClient.Do(request)
if err != nil {
return nil, fmt.Errorf("GET %q: %w", url, err)
fields := strings.Fields(scanner.Text())
if len(fields) == 0 {
return "", fmt.Errorf("sha256 sidecar %q has no digest", url)
}
if response.StatusCode < http.StatusOK || response.StatusCode >= http.StatusMultipleChoices {
defer response.Body.Close()

return nil, fmt.Errorf("GET %q: unexpected HTTP status %s", url, response.Status)
digest := strings.ToLower(fields[0])
if !isHexSHA256(digest) {
return "", fmt.Errorf("sha256 sidecar %q digest %q is not a 64-character hex string", url, digest)
}

return digest, nil
}

// get issues a GET for url and returns the response once a 2xx status arrives.
//
// Transport errors, 429 and 5xx responses are retried on the schedule built
// by newRetryBackoff until the attempts run out or ctx is done. Any other
// non-2xx status, and a request that cannot be constructed, fail immediately.
//
// On success the caller owns response.Body and must close it. On every
// failure path the body of the rejected response is drained (up to a bound)
// and closed before returning, so no body is left open and the transport can
// reuse the connection for the next attempt.
func (c Client) get(ctx context.Context, url string) (*http.Response, error) {
	var response *http.Response

	operation := func() error {
		request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
		if err != nil {
			return backoff.Permanent(fmt.Errorf("create request %q: %w", url, err))
		}

		resp, err := c.httpClient.Do(request)
		if err != nil {
			// Transport-level failure: returned as-is so it stays retryable.
			return err
		}

		switch {
		case isRetryableStatus(resp.StatusCode):
			discardBody(resp.Body)

			return fmt.Errorf("GET %q: retryable HTTP status %s", url, resp.Status)
		case resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices:
			discardBody(resp.Body)

			return backoff.Permanent(fmt.Errorf("GET %q: unexpected HTTP status %s", url, resp.Status))
		}

		response = resp

		return nil
	}

	policy := backoff.WithContext(newRetryBackoff(), ctx)
	if err := backoff.Retry(operation, policy); err != nil {
		// Defensive unwrap: callers should see the underlying error rather
		// than the backoff.PermanentError wrapper if one is still present.
		if permanent, ok := errors.AsType[*backoff.PermanentError](err); ok {
			return nil, permanent.Err
		}

		return nil, err
	}

	return response, nil
}

// discardBody reads and throws away a bounded prefix of body, then closes it.
// Reading an error body to EOF before closing lets the HTTP transport return
// the connection to its idle pool; the bound keeps an oversized error body
// from being downloaded in full just to be discarded.
func discardBody(body io.ReadCloser) {
	const maxDrainBytes = 64 << 10

	// Best effort: a failed drain only costs connection reuse.
	_, _ = io.Copy(io.Discard, io.LimitReader(body, maxDrainBytes))
	_ = body.Close()
}

// newRetryBackoff builds the retry schedule used for upstream GETs: an
// exponential backoff configured from the package's retry constants and
// capped by retry count rather than by elapsed time.
func newRetryBackoff() backoff.BackOff {
	// The first attempt is not a retry, so one fewer retry than attempts.
	const maxRetries = retryMaxAttempts - 1

	schedule := backoff.NewExponentialBackOff()
	schedule.MaxElapsedTime = 0 // no elapsed-time cutoff; only maxRetries ends the schedule
	schedule.RandomizationFactor = retryRandomization
	schedule.MaxInterval = retryMaxInterval
	schedule.InitialInterval = retryInitialInterval

	return backoff.WithMaxRetries(schedule, maxRetries)
}

// isRetryableStatus reports whether an HTTP status code should be retried:
// 429 Too Many Requests, or anything in the 500-599 range.
func isRetryableStatus(code int) bool {
	switch {
	case code == http.StatusTooManyRequests:
		return true
	case code < http.StatusInternalServerError:
		return false
	default:
		return code <= 599
	}
}

// isHexSHA256 reports whether s is exactly sha256HexLen bytes long and made
// up solely of the characters 0-9 and lowercase a-f. Uppercase hex digits are
// rejected, so callers must lowercase the input first.
func isHexSHA256(s string) bool {
	if len(s) != sha256HexLen {
		return false
	}

	for i := 0; i < len(s); i++ {
		c := s[i]
		isDigit := '0' <= c && c <= '9'
		isLowerHex := 'a' <= c && c <= 'f'
		if !isDigit && !isLowerHex {
			return false
		}
	}

	return true
}
133 changes: 133 additions & 0 deletions tools/labctl/internal/adapters/httpupstream/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package httpupstream_test

import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gilmanlab/platform/tools/labctl/internal/adapters/httpupstream"
)

// TestClientFetchSHA256ParsesChecksumLine checks that FetchSHA256 returns the
// digest token from a "<hex> <filename>" sidecar line.
func TestClientFetchSHA256ParsesChecksumLine(t *testing.T) {
	const want = "5fa3a23e3f12cf6f33b66e2eb1cd0f8df57f53efb15c1ab8c8f6bb3fa1e02b9d"

	serveSidecar := func(rw http.ResponseWriter, _ *http.Request) {
		_, _ = io.WriteString(rw, want+" nocloud-amd64.raw.xz\n")
	}
	srv := httptest.NewServer(http.HandlerFunc(serveSidecar))
	t.Cleanup(srv.Close)

	upstream := httpupstream.New(srv.Client())
	got, err := upstream.FetchSHA256(context.Background(), srv.URL+"/x.sha256")

	require.NoError(t, err)
	assert.Equal(t, want, got)
}

// TestClientFetchSHA256RejectsMalformedDigest checks that a sidecar whose
// first token is not a 64-character hex digest is reported as an error.
func TestClientFetchSHA256RejectsMalformedDigest(t *testing.T) {
	serveGarbage := func(rw http.ResponseWriter, _ *http.Request) {
		_, _ = io.WriteString(rw, "not-a-digest\n")
	}
	srv := httptest.NewServer(http.HandlerFunc(serveGarbage))
	t.Cleanup(srv.Close)

	upstream := httpupstream.New(srv.Client())
	_, fetchErr := upstream.FetchSHA256(context.Background(), srv.URL+"/x.sha256")

	require.Error(t, fetchErr)
	assert.Contains(t, fetchErr.Error(), "not a 64-character hex string")
}

// TestClientFetchSHA256RejectsEmptyBody checks that a 200 response with no
// body is reported as an empty sidecar rather than yielding a digest.
func TestClientFetchSHA256RejectsEmptyBody(t *testing.T) {
	serveNothing := func(rw http.ResponseWriter, _ *http.Request) {
		rw.WriteHeader(http.StatusOK)
	}
	srv := httptest.NewServer(http.HandlerFunc(serveNothing))
	t.Cleanup(srv.Close)

	upstream := httpupstream.New(srv.Client())
	_, fetchErr := upstream.FetchSHA256(context.Background(), srv.URL+"/x.sha256")

	require.Error(t, fetchErr)
	assert.Contains(t, fetchErr.Error(), "empty")
}

// TestClientRetriesTransientFailures checks that Download retries past 502
// responses: the server fails the first two requests and succeeds afterwards,
// and the body of the successful response must reach the caller.
func TestClientRetriesTransientFailures(t *testing.T) {
	var attempts atomic.Int32

	flaky := func(rw http.ResponseWriter, _ *http.Request) {
		if n := attempts.Add(1); n > 2 {
			_, _ = io.WriteString(rw, "ok")

			return
		}

		http.Error(rw, "boom", http.StatusBadGateway)
	}
	srv := httptest.NewServer(http.HandlerFunc(flaky))
	t.Cleanup(srv.Close)

	upstream := httpupstream.New(srv.Client())
	stream, err := upstream.Download(context.Background(), srv.URL+"/artifact")

	require.NoError(t, err)
	t.Cleanup(func() { _ = stream.Close() })

	payload, err := io.ReadAll(stream)
	require.NoError(t, err)
	assert.Equal(t, "ok", string(payload))
	assert.GreaterOrEqual(t, attempts.Load(), int32(3), "expected at least three attempts")
}

// TestClientDoesNotRetryClientErrors checks that a 404 fails Download after
// a single request and that the status shows up in the error text.
func TestClientDoesNotRetryClientErrors(t *testing.T) {
	var hits atomic.Int32

	notFound := func(rw http.ResponseWriter, _ *http.Request) {
		hits.Add(1)
		http.Error(rw, "nope", http.StatusNotFound)
	}
	srv := httptest.NewServer(http.HandlerFunc(notFound))
	t.Cleanup(srv.Close)

	upstream := httpupstream.New(srv.Client())
	_, downloadErr := upstream.Download(context.Background(), srv.URL+"/artifact")

	require.Error(t, downloadErr)
	assert.Contains(t, downloadErr.Error(), "404")
	assert.Equal(t, int32(1), hits.Load(), "4xx must not be retried")
}

// TestClientStopsRetryingOnContextCancel checks that Download gives up soon
// after its context expires instead of working through the whole retry
// schedule against a server that always answers 503.
//
// The error predicate alone cannot prove that: it also accepts the
// retry-exhausted 503 error, which is what Download would return if it
// ignored the context. The elapsed-time bound is what pins the behaviour —
// the context expires after 10ms, whereas exhausting the retries means
// sleeping through several backoff intervals.
func TestClientStopsRetryingOnContextCancel(t *testing.T) {
	// Generous next to the 10ms deadline, yet shorter than the time needed
	// to sleep through every backoff interval.
	const promptReturn = time.Second

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		http.Error(w, "boom", http.StatusServiceUnavailable)
	}))
	t.Cleanup(server.Close)

	client := httpupstream.New(server.Client())

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	t.Cleanup(cancel)

	start := time.Now()
	_, err := client.Download(ctx, server.URL+"/artifact")
	elapsed := time.Since(start)

	require.Error(t, err)
	assert.True(
		t,
		errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) || isRetryableErr(err),
		"expected ctx error or last retryable error, got %v", err,
	)
	assert.Less(
		t,
		int64(elapsed), int64(promptReturn),
		"Download took %s; it should stop retrying once the context is done", elapsed,
	)
}

// isRetryableErr reports whether err is non-nil and its message mentions
// either "503" or "retryable".
func isRetryableErr(err error) bool {
	if err == nil {
		return false
	}

	if assertContains(err.Error(), "503") {
		return true
	}

	return assertContains(err.Error(), "retryable")
}

// assertContains reports whether sub occurs anywhere in s. Despite the name
// it asserts nothing; it is a plain substring check. An empty sub always
// matches.
func assertContains(s, sub string) bool {
	width := len(sub)

	// Slide a window of len(sub) bytes across s, stopping once it no longer fits.
	for start := 0; start+width <= len(s); start++ {
		if s[start:start+width] == sub {
			return true
		}
	}

	return false
}
Loading
Loading