From 676feda7894e0245fc30af223c92c47323baf99b Mon Sep 17 00:00:00 2001
From: Joshua Gilman
Date: Mon, 4 May 2026 19:53:05 -0700
Subject: [PATCH] feat(cache): add content-addressed disk store

Adds a provider-neutral cache service interface with an XDG-compatible
disk-backed implementation for verified blob downloads. Covers cache
hits, unknown digests, corrupt entries, validation errors, cancellation,
mismatch handling, read-only blobs, and unsafe cache paths.
A WithHTTPClient option lets callers supply a custom HTTP client.
---
 internal/cache/cache.go      | 494 +++++++++++++++++++++++++++++
 internal/cache/cache_test.go | 522 +++++++++++++++++++++++++++++++++++
 2 files changed, 1016 insertions(+)
 create mode 100644 internal/cache/cache.go
 create mode 100644 internal/cache/cache_test.go

diff --git a/internal/cache/cache.go b/internal/cache/cache.go
new file mode 100644
index 0000000..a7e97c0
--- /dev/null
+++ b/internal/cache/cache.go
@@ -0,0 +1,494 @@
+// Package cache provides content-addressed local blob caching for providers.
+package cache
+
+import (
+	"context"
+	"crypto/sha256"
+	"encoding/hex"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"path/filepath"
+	"strings"
+)
+
+const (
+	cacheDirName = "imgcli"
+
+	bytesPerKiB                int64 = 1024
+	bytesPerMiB                      = bytesPerKiB * bytesPerKiB
+	bytesPerGiB                      = bytesPerMiB * bytesPerKiB
+	defaultMaxUnknownSizeGiB         = 64
+	defaultMaxUnknownSizeBytes       = defaultMaxUnknownSizeGiB * bytesPerGiB
+
+	blobPerm          = 0o400
+	hexCharsPerByte   = 2
+	digestShardLength = 2
+	dirPerm           = 0o750
+	sha256HexLength   = sha256.Size * hexCharsPerByte
+	tmpPerm           = 0o700
+	writePermMask     = 0o222
+)
+
+// Service fetches remote blobs into a verified local cache.
+type Service interface {
+	// Fetch returns a verified local cache path for the requested blob.
+	Fetch(ctx context.Context, req FetchRequest) (Blob, error)
+}
+
+// FetchRequest describes a remote blob to fetch into the cache.
+type FetchRequest struct {
+	// URL is the remote HTTP URL to download.
+	URL string
+
+	// ExpectedSHA256 is the optional expected SHA-256 digest in lowercase or uppercase hex.
+	ExpectedSHA256 string
+
+	// ExpectedSize is the optional expected blob size in bytes. Zero means unknown.
+	ExpectedSize int64
+}
+
+// Blob describes a verified blob stored in the cache.
+type Blob struct {
+	// Path is the local path to the immutable cached blob.
+	Path string
+
+	// SHA256 is the blob SHA-256 digest in lowercase hex.
+	SHA256 string
+
+	// Size is the blob size in bytes.
+	Size int64
+}
+
+// Option configures a DiskStore.
+type Option func(*DiskStore)
+
+// DiskStore is a filesystem-backed content-addressed cache.
+type DiskStore struct {
+	root                string
+	httpClient          *http.Client
+	maxUnknownSizeBytes int64
+}
+
+// WithRoot configures the cache root directory.
+func WithRoot(root string) Option {
+	return func(store *DiskStore) {
+		store.root = root
+	}
+}
+
+// WithMaxUnknownSizeBytes configures the download cap when FetchRequest.ExpectedSize is unknown.
+//
+// Zero makes unknown-size downloads explicitly unbounded.
+func WithMaxUnknownSizeBytes(size int64) Option {
+	return func(store *DiskStore) {
+		store.maxUnknownSizeBytes = size
+	}
+}
+
+// WithHTTPClient configures the HTTP client used for blob downloads.
+//
+// A nil client falls back to http.DefaultClient.
+func WithHTTPClient(client *http.Client) Option {
+	return func(store *DiskStore) {
+		store.httpClient = client
+	}
+}
+
+// NewDiskStore constructs a disk-backed cache store.
+func NewDiskStore(options ...Option) (*DiskStore, error) {
+	root, err := defaultRoot()
+	if err != nil {
+		return nil, err
+	}
+
+	store := &DiskStore{
+		root:                root,
+		httpClient:          http.DefaultClient,
+		maxUnknownSizeBytes: defaultMaxUnknownSizeBytes,
+	}
+	for _, option := range options {
+		option(store)
+	}
+
+	if strings.TrimSpace(store.root) == "" {
+		return nil, errors.New("cache root is required")
+	}
+	if store.maxUnknownSizeBytes < 0 {
+		return nil, errors.New("cache max unknown-size bytes must be non-negative")
+	}
+
+	return store, nil
+}
+
+// Fetch returns a verified local cache path for the requested blob.
+func (s *DiskStore) Fetch(ctx context.Context, req FetchRequest) (Blob, error) {
+	normalized, err := normalizeFetchRequest(req)
+	if err != nil {
+		return Blob{}, err
+	}
+
+	if normalized.ExpectedSHA256 != "" {
+		path := s.blobPath(normalized.ExpectedSHA256)
+		if pathErr := s.ensureExpectedDigestPath(path); pathErr != nil {
+			return Blob{}, pathErr
+		}
+		blob, ok, verifyErr := verifyExistingBlob(path, normalized.ExpectedSHA256, normalized.ExpectedSize)
+		if verifyErr != nil {
+			return Blob{}, verifyErr
+		}
+		if ok {
+			return blob, nil
+		}
+		if removeErr := os.Remove(path); removeErr != nil && !errors.Is(removeErr, os.ErrNotExist) {
+			return Blob{}, fmt.Errorf("remove corrupt cached blob: %w", removeErr)
+		}
+	}
+
+	tmpPath, downloaded, err := s.download(ctx, normalized)
+	if tmpPath != "" {
+		defer os.Remove(tmpPath)
+	}
+	if err != nil {
+		return Blob{}, err
+	}
+
+	blob, err := s.publish(tmpPath, downloaded)
+	if err != nil {
+		return Blob{}, err
+	}
+
+	return blob, nil
+}
+
+// ensureExpectedDigestPath creates and validates the directories leading to a known-digest blob path.
+func (s *DiskStore) ensureExpectedDigestPath(path string) error {
+	if err := s.ensureDirs(); err != nil {
+		return err
+	}
+	if err := ensureCacheDir(filepath.Dir(path), dirPerm); err != nil {
+		return fmt.Errorf("validate cache blob shard directory: %w", err)
+	}
+
+	return nil
+}
+
+// defaultRoot returns the default cache root inside the user cache directory.
+func defaultRoot() (string, error) {
+	userCacheDir, err := os.UserCacheDir()
+	if err != nil {
+		return "", fmt.Errorf("resolve user cache directory: %w", err)
+	}
+
+	return filepath.Join(userCacheDir, cacheDirName), nil
+}
+
+// normalizeFetchRequest trims and lowercases request fields and validates them.
+func normalizeFetchRequest(req FetchRequest) (FetchRequest, error) {
+	req.URL = strings.TrimSpace(req.URL)
+	req.ExpectedSHA256 = strings.ToLower(strings.TrimSpace(req.ExpectedSHA256))
+
+	if req.URL == "" {
+		return FetchRequest{}, errors.New("cache fetch URL is required")
+	}
+	if req.ExpectedSize < 0 {
+		return FetchRequest{}, errors.New("cache expected size must be non-negative")
+	}
+	if req.ExpectedSHA256 != "" {
+		if len(req.ExpectedSHA256) != sha256HexLength {
+			return FetchRequest{}, fmt.Errorf("cache expected SHA-256 must be %d hex characters", sha256HexLength)
+		}
+		if _, err := hex.DecodeString(req.ExpectedSHA256); err != nil {
+			return FetchRequest{}, fmt.Errorf("cache expected SHA-256 must be hex: %w", err)
+		}
+	}
+
+	return req, nil
+}
+
+// download streams req.URL into a read-only temp file, verifying size and digest.
+func (s *DiskStore) download(ctx context.Context, req FetchRequest) (string, Blob, error) {
+	if err := s.ensureDirs(); err != nil {
+		return "", Blob{}, err
+	}
+
+	httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, req.URL, nil)
+	if err != nil {
+		return "", Blob{}, fmt.Errorf("create cache fetch request: %w", err)
+	}
+
+	resp, err := s.httpClientOrDefault().Do(httpReq)
+	if err != nil {
+		return "", Blob{}, fmt.Errorf("fetch cache blob: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return "", Blob{}, fmt.Errorf("fetch cache blob: unexpected HTTP status %s", resp.Status)
+	}
+
+	tmp, err := os.CreateTemp(s.tmpDir(), "fetch-*.tmp")
+	if err != nil {
+		return "", Blob{}, fmt.Errorf("create cache temp file: %w", err)
+	}
+	tmpPath := tmp.Name()
+	closeTmp := true
+	defer func() {
+		if closeTmp {
+			_ = tmp.Close()
+		}
+	}()
+
+	hasher := sha256.New()
+	size, err := io.Copy(io.MultiWriter(tmp, hasher), limitedBody(resp.Body, s.downloadLimit(req)))
+	if err != nil {
+		return tmpPath, Blob{}, fmt.Errorf("write cache temp file: %w", err)
+	}
+	if err := tmp.Close(); err != nil {
+		return tmpPath, Blob{}, fmt.Errorf("close cache temp file: %w", err)
+	}
+	closeTmp = false
+
+	digest := hex.EncodeToString(hasher.Sum(nil))
+	if req.ExpectedSize != 0 && size != req.ExpectedSize {
+		return tmpPath, Blob{}, fmt.Errorf(
+			"verify cache blob size: got %d bytes, want %d bytes",
+			size,
+			req.ExpectedSize,
+		)
+	}
+	if req.ExpectedSize == 0 && s.maxUnknownSizeBytes != 0 && size > s.maxUnknownSizeBytes {
+		return tmpPath, Blob{}, fmt.Errorf(
+			"verify cache blob size: got more than %d bytes for unknown-size fetch",
+			s.maxUnknownSizeBytes,
+		)
+	}
+	if req.ExpectedSHA256 != "" && digest != req.ExpectedSHA256 {
+		return tmpPath, Blob{}, fmt.Errorf("verify cache blob SHA-256: got %s, want %s", digest, req.ExpectedSHA256)
+	}
+	if err := os.Chmod(tmpPath, blobPerm); err != nil {
+		return tmpPath, Blob{}, fmt.Errorf("make cache temp file read-only: %w", err)
+	}
+
+	return tmpPath, Blob{
+		Path:   s.blobPath(digest),
+		SHA256: digest,
+		Size:   size,
+	}, nil
+}
+
+// downloadLimit returns the byte cap used for this request's download.
+func (s *DiskStore) downloadLimit(req FetchRequest) int64 {
+	if req.ExpectedSize != 0 {
+		return req.ExpectedSize
+	}
+
+	return s.maxUnknownSizeBytes
+}
+
+// limitedBody caps reader at limit+1 bytes so overruns are detectable; zero disables the cap.
+func limitedBody(reader io.Reader, limit int64) io.Reader {
+	if limit == 0 {
+		return reader
+	}
+
+	return io.LimitReader(reader, limit+1)
+}
+
+// publish hard-links the verified temp file into its content-addressed location.
+func (s *DiskStore) publish(tmpPath string, downloaded Blob) (Blob, error) {
+	finalPath := s.blobPath(downloaded.SHA256)
+	if err := ensureCacheDir(filepath.Dir(finalPath), dirPerm); err != nil {
+		return Blob{}, fmt.Errorf("create cache blob shard directory: %w", err)
+	}
+
+	if err := os.Link(tmpPath, finalPath); err == nil {
+		return Blob{
+			Path:   finalPath,
+			SHA256: downloaded.SHA256,
+			Size:   downloaded.Size,
+		}, nil
+	} else if !errors.Is(err, os.ErrExist) {
+		return Blob{}, fmt.Errorf("publish cache blob: %w", err)
+	}
+
+	blob, ok, err := verifyExistingBlob(finalPath, downloaded.SHA256, downloaded.Size)
+	if err != nil {
+		return Blob{}, err
+	}
+	if ok {
+		return blob, nil
+	}
+
+	if err := os.Remove(finalPath); err != nil && !errors.Is(err, os.ErrNotExist) {
+		return Blob{}, fmt.Errorf("remove corrupt cached blob: %w", err)
+	}
+	if err := os.Link(tmpPath, finalPath); err != nil {
+		return handlePublishConflict(err, finalPath, downloaded)
+	}
+
+	return Blob{
+		Path:   finalPath,
+		SHA256: downloaded.SHA256,
+		Size:   downloaded.Size,
+	}, nil
+}
+
+// handlePublishConflict resolves a link race by verifying the blob a concurrent writer published.
+func handlePublishConflict(err error, finalPath string, downloaded Blob) (Blob, error) {
+	if !errors.Is(err, os.ErrExist) {
+		return Blob{}, fmt.Errorf("publish cache blob: %w", err)
+	}
+
+	blob, ok, verifyErr := verifyExistingBlob(finalPath, downloaded.SHA256, downloaded.Size)
+	if verifyErr != nil {
+		return Blob{}, verifyErr
+	}
+	if ok {
+		return blob, nil
+	}
+
+	return Blob{}, fmt.Errorf("publish cache blob: existing cache blob %q is corrupt", finalPath)
+}
+
+// ensureDirs creates and validates the cache directory layout under root.
+func (s *DiskStore) ensureDirs() error {
+	for _, dir := range []struct {
+		path string
+		perm os.FileMode
+	}{
+		{path: s.root, perm: dirPerm},
+		{path: filepath.Join(s.root, "blobs"), perm: dirPerm},
+		{path: filepath.Join(s.root, "blobs", "sha256"), perm: dirPerm},
+		{path: s.tmpDir(), perm: tmpPerm},
+	} {
+		if err := ensureCacheDir(dir.path, dir.perm); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// ensureCacheDir creates the directory and enforces its type and permissions.
+func ensureCacheDir(path string, perm os.FileMode) error {
+	if err := os.MkdirAll(path, perm); err != nil {
+		return fmt.Errorf("create cache directory %q: %w", path, err)
+	}
+	if err := validateCacheDir(path, perm); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// validateCacheDir rejects non-directories (including symlinks, via Lstat) and repairs drifted permissions.
+func validateCacheDir(path string, perm os.FileMode) error {
+	info, err := os.Lstat(path)
+	if err != nil {
+		return fmt.Errorf("inspect cache directory %q: %w", path, err)
+	}
+	if !info.Mode().IsDir() {
+		return fmt.Errorf("cache directory %q is not a directory", path)
+	}
+	if info.Mode().Perm() != perm {
+		if err := os.Chmod(path, perm); err != nil {
+			return fmt.Errorf("repair cache directory permissions %q: %w", path, err)
+		}
+	}
+
+	return nil
+}
+
+// verifyExistingBlob re-hashes a cached blob and reports whether it matches the expected digest.
+func verifyExistingBlob(path string, expectedSHA256 string, expectedSize int64) (Blob, bool, error) {
+	info, err := os.Lstat(path)
+	if err != nil {
+		if errors.Is(err, os.ErrNotExist) {
+			return Blob{}, false, nil
+		}
+		return Blob{}, false, fmt.Errorf("inspect cached blob: %w", err)
+	}
+	if !info.Mode().IsRegular() {
+		return Blob{}, false, nil
+	}
+	digest, size, err := hashFile(path, info)
+	if err != nil {
+		return Blob{}, false, fmt.Errorf("verify cached blob: %w", err)
+	}
+	if digest != expectedSHA256 {
+		return Blob{}, false, nil
+	}
+	if expectedSize != 0 && size != expectedSize {
+		return Blob{}, false, fmt.Errorf(
+			"cached blob size conflicts with request: got %d bytes, want %d bytes",
+			size,
+			expectedSize,
+		)
+	}
+	if info.Mode().Perm()&writePermMask != 0 {
+		if err := os.Chmod(path, blobPerm); err != nil {
+			return Blob{}, false, fmt.Errorf("repair cached blob permissions: %w", err)
+		}
+	}
+
+	return Blob{
+		Path:   path,
+		SHA256: digest,
+		Size:   size,
+	}, true, nil
+}
+
+// blobPath returns the sharded content-addressed path for a digest.
+func (s *DiskStore) blobPath(sha256Digest string) string {
+	return filepath.Join(s.root, "blobs", "sha256", sha256Digest[:digestShardLength], sha256Digest)
+}
+
+// tmpDir returns the staging directory for in-flight downloads.
+func (s *DiskStore) tmpDir() string {
+	return filepath.Join(s.root, "tmp")
+}
+
+// httpClientOrDefault returns the configured client, or http.DefaultClient when unset.
+func (s *DiskStore) httpClientOrDefault() *http.Client {
+	if s.httpClient == nil {
+		return http.DefaultClient
+	}
+
+	return s.httpClient
+}
+
+// hashFile hashes path, guarding via SameFile against the file being swapped between stat and read.
+func hashFile(path string, expected os.FileInfo) (string, int64, error) {
+	file, err := os.Open(path)
+	if err != nil {
+		return "", 0, err
+	}
+	defer file.Close()
+
+	current, err := os.Lstat(path)
+	if err != nil {
+		return "", 0, err
+	}
+	if !current.Mode().IsRegular() || !os.SameFile(expected, current) {
+		return "", 0, fmt.Errorf("cached blob changed while opening %q", path)
+	}
+
+	info, err := file.Stat()
+	if err != nil {
+		return "", 0, err
+	}
+	if !os.SameFile(expected, info) {
+		return "", 0, fmt.Errorf("cached blob changed while opening %q", path)
+	}
+
+	hasher := sha256.New()
+	size, err := io.Copy(hasher, file)
+	if err != nil {
+		return "", 0, err
+	}
+
+	return hex.EncodeToString(hasher.Sum(nil)), size, nil
+}
diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go
new file mode 100644
index 0000000..428c676
--- /dev/null
+++ b/internal/cache/cache_test.go
@@ -0,0 +1,522 @@
+package cache_test
+
+import (
+	"context"
+	"crypto/sha256"
+	"encoding/hex"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"path/filepath"
+	"sync/atomic"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/meigma/imgcli/internal/cache"
+)
+
+var _ cache.Service = (*cache.DiskStore)(nil)
+
+func TestDiskStoreFetchKnownDigest(t *testing.T) {
+	body := []byte("provider image bytes")
+	server, requests := newBlobServer(t, http.StatusOK, body)
+	root := t.TempDir()
+	store := newDiskStore(t, root)
+	digest := sha256Hex(body)
+
+	blob, err := store.Fetch(context.Background(), cache.FetchRequest{
+		URL:            server.URL + "/blob",
+		ExpectedSHA256: digest,
+		ExpectedSize:   int64(len(body)),
+	})
+
+	require.NoError(t, err)
+	assert.Equal(t, cache.Blob{
+		Path:   cachedBlobPath(root, digest),
+		SHA256: digest,
+		Size:   int64(len(body)),
+	}, blob)
+	assert.Equal(t, int64(1), requests.Load())
+	assertFileContent(t, blob.Path, body)
+	assertReadOnlyFile(t, blob.Path)
+}
+
+func TestDiskStoreFetchUsesVerifiedCacheHit(t *testing.T) {
+	body := []byte("already cached bytes")
+	server, requests := newBlobServer(t, http.StatusInternalServerError, []byte("should not be requested"))
+	root := t.TempDir()
+	store := newDiskStore(t, root)
+	digest := sha256Hex(body)
+	writeCachedBlob(t, root, digest, body)
+
+	blob, err := store.Fetch(context.Background(), cache.FetchRequest{
+		URL:            server.URL + "/blob",
+		ExpectedSHA256: digest,
+		ExpectedSize:   int64(len(body)),
+	})
+
+	require.NoError(t, err)
+	assert.Equal(t, cachedBlobPath(root, digest), blob.Path)
+	assert.Equal(t, digest, blob.SHA256)
+	assert.Equal(t, int64(len(body)), blob.Size)
+	assert.Equal(t, int64(0), requests.Load())
+	assertReadOnlyFile(t, blob.Path)
+}
+
+func TestDiskStoreFetchRejectsCacheHitWithConflictingSize(t *testing.T) {
+	body := []byte("already cached bytes")
+	server, requests := newBlobServer(t, http.StatusInternalServerError, []byte("should not be requested"))
+	root := t.TempDir()
+	store := newDiskStore(t, root)
+	digest := sha256Hex(body)
+	writeCachedBlob(t, root, digest, body)
+
+	blob, err := store.Fetch(context.Background(), cache.FetchRequest{
+		URL:            server.URL + "/blob",
+		ExpectedSHA256: digest,
+		ExpectedSize:   int64(len(body)) + 1,
+	})
+
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "cached blob size conflicts with request")
+	assert.Empty(t, blob)
+	assert.Equal(t, int64(0), requests.Load())
+	assertFileContent(t, cachedBlobPath(root, digest), body)
+}
+
+func TestDiskStoreFetchUnknownDigest(t *testing.T) {
+	body := []byte("unknown digest bytes")
+	server, requests := newBlobServer(t, http.StatusOK, body)
+	root := t.TempDir()
+	store := newDiskStore(t, root)
+	digest := sha256Hex(body)
+
+	blob, err := store.Fetch(context.Background(), cache.FetchRequest{
+		URL: server.URL + "/blob",
+	})
+
+	require.NoError(t, err)
+	assert.Equal(t, cache.Blob{
+		Path:   cachedBlobPath(root, digest),
+		SHA256: digest,
+		Size:   int64(len(body)),
+	}, blob)
+	assert.Equal(t, int64(1), requests.Load())
+	assertFileContent(t, blob.Path, body)
+	assertReadOnlyFile(t, blob.Path)
+}
+
+func TestDiskStoreFetchReplacesCorruptCacheEntry(t *testing.T) {
+	body := []byte("correct cached bytes")
+	server, requests := newBlobServer(t, http.StatusOK, body)
+	root := t.TempDir()
+	store := newDiskStore(t, root)
+	digest := sha256Hex(body)
+	writeCachedBlob(t, root, digest, []byte("corrupt"))
+
+	blob, err := store.Fetch(context.Background(), cache.FetchRequest{
+		URL:            server.URL + "/blob",
+		ExpectedSHA256: digest,
+		ExpectedSize:   int64(len(body)),
+	})
+
+	require.NoError(t, err)
+	assert.Equal(t, cachedBlobPath(root, digest), blob.Path)
+	assert.Equal(t, digest, blob.SHA256)
+	assert.Equal(t, int64(1), requests.Load())
+	assertFileContent(t, blob.Path, body)
+	assertReadOnlyFile(t, blob.Path)
+}
+
+func TestDiskStoreFetchReplacesSymlinkCacheEntry(t *testing.T) {
+	body := []byte("correct cached bytes")
+	server, requests := newBlobServer(t, http.StatusOK, body)
+	root := t.TempDir()
+	store := newDiskStore(t, root)
+	digest := sha256Hex(body)
+	path := cachedBlobPath(root, digest)
+	require.NoError(t, os.MkdirAll(filepath.Dir(path), 0o750))
+
+	externalPath := filepath.Join(t.TempDir(), "external")
+	require.NoError(t, os.WriteFile(externalPath, body, 0o600))
+	if err := os.Symlink(externalPath, path); err != nil {
+		t.Skipf("symlink unavailable: %v", err)
+	}
+
+	blob, err := store.Fetch(context.Background(), cache.FetchRequest{
+		URL:            server.URL + "/blob",
+		ExpectedSHA256: digest,
+		ExpectedSize:   int64(len(body)),
+	})
+
+	require.NoError(t, err)
+	assert.Equal(t, cachedBlobPath(root, digest), blob.Path)
+	assert.Equal(t, digest, blob.SHA256)
+	assert.Equal(t, int64(1), requests.Load())
+	info, err := os.Lstat(blob.Path)
+	require.NoError(t, err)
+	assert.True(t, info.Mode().IsRegular())
+	assertFileContent(t, blob.Path, body)
+	assertReadOnlyFile(t, blob.Path)
+}
+
+func TestDiskStoreFetchRejectsSymlinkShardDirectory(t *testing.T) {
+	body := []byte("correct cached bytes")
+	server, requests := newBlobServer(t, http.StatusInternalServerError, []byte("should not be requested"))
+	root := t.TempDir()
+	store := newDiskStore(t, root)
+	digest := sha256Hex(body)
+	shardPath := filepath.Join(root, "blobs", "sha256", digest[:2])
+	require.NoError(t, os.MkdirAll(filepath.Dir(shardPath), 0o750))
+
+	externalDir := t.TempDir()
+	externalBlob := filepath.Join(externalDir, digest)
+	require.NoError(t, os.WriteFile(externalBlob, body, 0o600))
+	if err := os.Symlink(externalDir, shardPath); err != nil {
+		t.Skipf("symlink unavailable: %v", err)
+	}
+
+	blob, err := store.Fetch(context.Background(), cache.FetchRequest{
+		URL:            server.URL + "/blob",
+		ExpectedSHA256: digest,
+		ExpectedSize:   int64(len(body)),
+	})
+
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "validate cache blob shard directory")
+	assert.Empty(t, blob)
+	assert.Equal(t, int64(0), requests.Load())
+	assertFileContent(t, externalBlob, body)
+}
+
+func TestDiskStoreFetchValidatesInputs(t *testing.T) {
+	tests := []struct {
+		name    string
+		req     cache.FetchRequest
+		wantErr string
+	}{
+		{
+			name:    "empty URL",
+			req:     cache.FetchRequest{ExpectedSHA256: sha256Hex([]byte("bytes"))},
+			wantErr: "cache fetch URL is required",
+		},
+		{
+			name: "invalid digest",
+			req: cache.FetchRequest{
+				URL:            "https://example.invalid/blob",
+				ExpectedSHA256: "not-a-sha256",
+			},
+			wantErr: "cache expected SHA-256 must be",
+		},
+		{
+			name: "negative size",
+			req: cache.FetchRequest{
+				URL:          "https://example.invalid/blob",
+				ExpectedSize: -1,
+			},
+			wantErr: "cache expected size must be non-negative",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			store := newDiskStore(t, t.TempDir())
+
+			blob, err := store.Fetch(context.Background(), tt.req)
+
+			require.Error(t, err)
+			assert.Contains(t, err.Error(), tt.wantErr)
+			assert.Empty(t, blob)
+		})
+	}
+}
+
+func TestDiskStoreFetchHandlesHTTPFailure(t *testing.T) {
+	body := []byte("unavailable")
+	server, _ := newBlobServer(t, http.StatusServiceUnavailable, body)
+	root := t.TempDir()
+	store := newDiskStore(t, root)
+	digest := sha256Hex(body)
+
+	blob, err := store.Fetch(context.Background(), cache.FetchRequest{
+		URL:            server.URL + "/blob",
+		ExpectedSHA256: digest,
+		ExpectedSize:   int64(len(body)),
+	})
+
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "unexpected HTTP status 503 Service Unavailable")
+	assert.Empty(t, blob)
+	assert.NoFileExists(t, cachedBlobPath(root, digest))
+	assertTmpEmpty(t, root)
+}
+
+func TestDiskStoreFetchHonorsContext(t *testing.T) {
+	server, _ := newBlobServer(t, http.StatusOK, []byte("bytes"))
+	root := t.TempDir()
+	store := newDiskStore(t, root)
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	blob, err := store.Fetch(ctx, cache.FetchRequest{
+		URL: server.URL + "/blob",
+	})
+
+	require.ErrorIs(t, err, context.Canceled)
+	assert.Empty(t, blob)
+	assertTmpEmpty(t, root)
+}
+
+func TestDiskStoreFetchVerifiesDownloadedContent(t *testing.T) {
+	tests := []struct {
+		name    string
+		req     func(url string, body []byte) cache.FetchRequest
+		wantErr string
+	}{
+		{
+			name: "digest mismatch",
+			req: func(url string, body []byte) cache.FetchRequest {
+				return cache.FetchRequest{
+					URL:            url,
+					ExpectedSHA256: sha256Hex([]byte("different")),
+					ExpectedSize:   int64(len(body)),
+				}
+			},
+			wantErr: "verify cache blob SHA-256",
+		},
+		{
+			name: "size mismatch",
+			req: func(url string, body []byte) cache.FetchRequest {
+				return cache.FetchRequest{
+					URL:            url,
+					ExpectedSHA256: sha256Hex(body),
+					ExpectedSize:   int64(len(body)) + 1,
+				}
+			},
+			wantErr: "verify cache blob size",
+		},
+		{
+			name: "oversized download",
+			req: func(url string, body []byte) cache.FetchRequest {
+				return cache.FetchRequest{
+					URL:            url,
+					ExpectedSHA256: sha256Hex(body),
+					ExpectedSize:   int64(len(body) - 1),
+				}
+			},
+			wantErr: "verify cache blob size",
+		},
+		{
+			name: "unknown-size download exceeds cache cap",
+			req: func(url string, _ []byte) cache.FetchRequest {
+				return cache.FetchRequest{
+					URL: url,
+				}
+			},
+			wantErr: "unknown-size fetch",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			body := []byte("downloaded bytes")
+			server, _ := newBlobServer(t, http.StatusOK, body)
+			root := t.TempDir()
+			store := newDiskStoreWithOptions(
+				t,
+				cache.WithRoot(root),
+				cache.WithMaxUnknownSizeBytes(int64(len(body)-1)),
+			)
+			if tt.name != "unknown-size download exceeds cache cap" {
+				store = newDiskStore(t, root)
+			}
+
+			blob, err := store.Fetch(context.Background(), tt.req(server.URL+"/blob", body))
+
+			require.Error(t, err)
+			assert.Contains(t, err.Error(), tt.wantErr)
+			assert.Empty(t, blob)
+			assertTmpEmpty(t, root)
+			assertNoCachedBlobs(t, root)
+		})
+	}
+}
+
+func TestDiskStoreRepairsUnsafeCacheDirectoryPermissions(t *testing.T) {
+	body := []byte("cached bytes")
+	server, _ := newBlobServer(t, http.StatusOK, body)
+	root := t.TempDir()
+	for _, dir := range []string{
+		root,
+		filepath.Join(root, "blobs"),
+		filepath.Join(root, "blobs", "sha256"),
+		filepath.Join(root, "tmp"),
+	} {
+		require.NoError(t, os.MkdirAll(dir, 0o750))
+		require.NoError(t, os.Chmod(dir, 0o777))
+	}
+	store := newDiskStore(t, root)
+
+	_, err := store.Fetch(context.Background(), cache.FetchRequest{
+		URL: server.URL + "/blob",
+	})
+
+	require.NoError(t, err)
+	assertDirPerm(t, root, 0o750)
+	assertDirPerm(t, filepath.Join(root, "blobs"), 0o750)
+	assertDirPerm(t, filepath.Join(root, "blobs", "sha256"), 0o750)
+	assertDirPerm(t, filepath.Join(root, "tmp"), 0o700)
+}
+
+func TestDiskStoreRejectsSymlinkCacheDirectory(t *testing.T) {
+	body := []byte("cached bytes")
+	server, requests := newBlobServer(t, http.StatusOK, body)
+	root := t.TempDir()
+	target := t.TempDir()
+	require.NoError(t, os.Symlink(target, filepath.Join(root, "tmp")))
+	store := newDiskStore(t, root)
+
+	blob, err := store.Fetch(context.Background(), cache.FetchRequest{
+		URL: server.URL + "/blob",
+	})
+
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "cache directory")
+	assert.Empty(t, blob)
+	assert.Equal(t, int64(0), requests.Load())
+}
+
+func TestDiskStoreCreatesRestrictiveCacheDirectories(t *testing.T) {
+	body := []byte("cached bytes")
+	server, _ := newBlobServer(t, http.StatusOK, body)
+	parent := t.TempDir()
+	root := filepath.Join(parent, "imgcli-cache")
+	store := newDiskStore(t, root)
+
+	_, err := store.Fetch(context.Background(), cache.FetchRequest{
+		URL: server.URL + "/blob",
+	})
+
+	require.NoError(t, err)
+	assertDirPerm(t, root, 0o750)
+	assertDirPerm(t, filepath.Join(root, "blobs"), 0o750)
+	assertDirPerm(t, filepath.Join(root, "blobs", "sha256"), 0o750)
+	assertDirPerm(t, filepath.Join(root, "tmp"), 0o700)
+}
+
+func TestNewDiskStoreValidatesOptions(t *testing.T) {
+	store, err := cache.NewDiskStore(
+		cache.WithRoot(t.TempDir()),
+		cache.WithMaxUnknownSizeBytes(-1),
+	)
+
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "cache max unknown-size bytes must be non-negative")
+	assert.Nil(t, store)
+}
+
+func newDiskStore(t *testing.T, root string) *cache.DiskStore {
+	t.Helper()
+
+	return newDiskStoreWithOptions(t, cache.WithRoot(root))
+}
+
+func newDiskStoreWithOptions(t *testing.T, options ...cache.Option) *cache.DiskStore {
+	t.Helper()
+
+	store, err := cache.NewDiskStore(options...)
+	require.NoError(t, err)
+	return store
+}
+
+func newBlobServer(t *testing.T, status int, body []byte) (*httptest.Server, *atomic.Int64) {
+	t.Helper()
+
+	requests := &atomic.Int64{}
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		requests.Add(1)
+		assert.Equal(t, "/blob", r.URL.Path)
+		w.WriteHeader(status)
+		_, err := w.Write(body)
+		assert.NoError(t, err)
+	}))
+	t.Cleanup(server.Close)
+
+	return server, requests
+}
+
+func writeCachedBlob(t *testing.T, root string, digest string, body []byte) {
+	t.Helper()
+
+	path := cachedBlobPath(root, digest)
+	require.NoError(t, os.MkdirAll(filepath.Dir(path), 0o750))
+	require.NoError(t, os.WriteFile(path, body, 0o600))
+}
+
+func cachedBlobPath(root string, digest string) string {
+	return filepath.Join(root, "blobs", "sha256", digest[:2], digest)
+}
+
+func sha256Hex(data []byte) string {
+	sum := sha256.Sum256(data)
+	return hex.EncodeToString(sum[:])
+}
+
+func assertFileContent(t *testing.T, path string, want []byte) {
+	t.Helper()
+
+	got, err := os.ReadFile(path)
+	require.NoError(t, err)
+	assert.Equal(t, want, got)
+}
+
+func assertTmpEmpty(t *testing.T, root string) {
+	t.Helper()
+
+	tmpDir := filepath.Join(root, "tmp")
+	entries, err := os.ReadDir(tmpDir)
+	if os.IsNotExist(err) {
+		return
+	}
+	require.NoError(t, err)
+	assert.Empty(t, entries)
+}
+
+func assertNoCachedBlobs(t *testing.T, root string) {
+	t.Helper()
+
+	shaDir := filepath.Join(root, "blobs", "sha256")
+	err := filepath.WalkDir(shaDir, func(path string, entry os.DirEntry, err error) error {
+		if err != nil {
+			return err
+		}
+		if !entry.IsDir() {
+			t.Fatalf("unexpected cached blob %s", path)
+		}
+		return nil
+	})
+	if os.IsNotExist(err) {
+		return
+	}
+	require.NoError(t, err)
+}
+
+func assertDirPerm(t *testing.T, path string, want os.FileMode) {
+	t.Helper()
+
+	info, err := os.Stat(path)
+	require.NoError(t, err)
+	assert.True(t, info.IsDir())
+	assert.Equal(t, want, info.Mode().Perm())
+}
+
+func assertReadOnlyFile(t *testing.T, path string) {
+	t.Helper()
+
+	info, err := os.Lstat(path)
+	require.NoError(t, err)
+	assert.True(t, info.Mode().IsRegular())
+	assert.Equal(t, os.FileMode(0o400), info.Mode().Perm())
+}