From d83d08cd93860d548b233d7933e45652144f024e Mon Sep 17 00:00:00 2001 From: Juan Antonio Osorio Date: Mon, 15 Jun 2026 09:34:57 +0300 Subject: [PATCH] oci: extract artifact-agnostic primitives into oci/artifact MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 0 of plugin support (THV-0077). Move the artifact-agnostic OCI primitives out of oci/skills into a new oci/artifact package so a future oci/plugins can reuse them: - tar (CreateTar/ExtractTar*/FileEntry/TarOptions, MaxTarFileSize) - gzip (Compress/Decompress*/CompressTar/DecompressTar, MaxDecompressedSize) - platform helpers (PlatformString/ParsePlatform/DefaultPlatforms, OS/Arch) - pull-hardening, exported as ValidatingTarget/NewValidatingTarget plus the manifest/blob size caps and manifest-count limits oci/skills re-exports every moved symbol via type aliases, var-forwarding, and const re-declaration (oci/skills/artifact_aliases.go), so the package's public surface is unchanged and downstream consumers (toolhive) are unaffected — no exported signature changed. Behavior-preserving move: function bodies are identical and the oci/skills determinism tests still assert byte-stable artifact digests. 
Co-Authored-By: Claude Opus 4.8 --- CLAUDE.md | 1 + oci/artifact/doc.go | 36 ++++ oci/{skills => artifact}/gzip.go | 2 +- oci/artifact/gzip_test.go | 190 +++++++++++++++++++ oci/artifact/platform.go | 59 ++++++ oci/artifact/platform_test.go | 128 +++++++++++++ oci/{skills => artifact}/tar.go | 2 +- oci/artifact/tar_test.go | 280 ++++++++++++++++++++++++++++ oci/artifact/testconsts_test.go | 9 + oci/artifact/validate.go | 155 +++++++++++++++ oci/artifact/validate_test.go | 210 +++++++++++++++++++++ oci/skills/artifact_aliases.go | 86 +++++++++ oci/skills/artifact_aliases_test.go | 31 +++ oci/skills/mediatypes.go | 49 ----- oci/skills/packager.go | 24 +-- oci/skills/registry.go | 144 +------------- 16 files changed, 1205 insertions(+), 201 deletions(-) create mode 100644 oci/artifact/doc.go rename oci/{skills => artifact}/gzip.go (99%) create mode 100644 oci/artifact/gzip_test.go create mode 100644 oci/artifact/platform.go create mode 100644 oci/artifact/platform_test.go rename oci/{skills => artifact}/tar.go (99%) create mode 100644 oci/artifact/tar_test.go create mode 100644 oci/artifact/testconsts_test.go create mode 100644 oci/artifact/validate.go create mode 100644 oci/artifact/validate_test.go create mode 100644 oci/skills/artifact_aliases.go create mode 100644 oci/skills/artifact_aliases_test.go diff --git a/CLAUDE.md b/CLAUDE.md index d0deedf..7552334 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -72,6 +72,7 @@ task license-fix # Add missing license headers | `env` | Environment variable abstraction with `Reader` interface for testable code | | `httperr` | Wrap errors with HTTP status codes; use `WithCode()`, `Code()`, `New()` | | `logging` | Pre-configured `*slog.Logger` factory with consistent ToolHive defaults (Alpha) | +| `oci/artifact` | Artifact-agnostic OCI tar/gzip/extraction/platform/pull-hardening primitives used by oci/skills and intended for a future oci/plugins (Alpha) | | `oci/skills` | OCI artifact types, media types, and registry operations for ToolHive skills (Alpha) | | 
`postgres` | PostgreSQL connection pool with optional AWS RDS IAM dynamic auth (Alpha) | | `recovery` | HTTP panic recovery middleware (Beta) | diff --git a/oci/artifact/doc.go b/oci/artifact/doc.go new file mode 100644 index 0000000..949578b --- /dev/null +++ b/oci/artifact/doc.go @@ -0,0 +1,36 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +/* +Package artifact provides artifact-agnostic OCI primitives shared by the +ToolHive ecosystem: reproducible tar archive creation and extraction, +reproducible gzip compression, OCI platform helpers, and pull-hardening +(size/count/digest validation) for registry operations. + +These primitives are independent of any particular artifact type (skills, +plugins, etc.). Artifact-specific media types, labels, and annotations live in +the packages that define those artifacts (for example oci/skills). + +# Reproducible Archives + +CreateTar and Compress produce byte-stable output for identical input, which is +what makes artifact digests deterministic: + + data, err := artifact.CompressTar(files, artifact.DefaultTarOptions(), artifact.DefaultGzipOptions()) + +# Platform Helpers + +PlatformString and ParsePlatform convert between OCI platform values and their +"os/arch" or "os/arch/variant" string form. + +# Pull Hardening + +ValidatingTarget wraps an oras.Target and, on every Push, enforces size caps, +digest integrity, and manifest/index count limits, defending against OOM and +resource exhaustion from malicious registries during pull operations. + +# Stability + +This package is Alpha. Breaking changes are possible between minor versions. +*/ +package artifact diff --git a/oci/skills/gzip.go b/oci/artifact/gzip.go similarity index 99% rename from oci/skills/gzip.go rename to oci/artifact/gzip.go index 1959158..3d45d4c 100644 --- a/oci/skills/gzip.go +++ b/oci/artifact/gzip.go @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. 
// SPDX-License-Identifier: Apache-2.0 -package skills +package artifact import ( "bytes" diff --git a/oci/artifact/gzip_test.go b/oci/artifact/gzip_test.go new file mode 100644 index 0000000..9e69f15 --- /dev/null +++ b/oci/artifact/gzip_test.go @@ -0,0 +1,190 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package artifact + +import ( + "bytes" + "compress/gzip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCompress_Reproducible(t *testing.T) { + t.Parallel() + + data := []byte("test data for compression") + opts := DefaultGzipOptions() + + gz1, err := Compress(data, opts) + require.NoError(t, err) + + gz2, err := Compress(data, opts) + require.NoError(t, err) + + assert.Equal(t, gz1, gz2, "Compress should produce identical output for same input") +} + +func TestCompress_HeaderFieldsForReproducibility(t *testing.T) { + t.Parallel() + + data := []byte("test data") + epoch := time.Unix(1234567890, 0).UTC() + opts := GzipOptions{ + Level: gzip.BestCompression, + Epoch: epoch, + } + + compressed, err := Compress(data, opts) + require.NoError(t, err) + + gr, err := gzip.NewReader(bytes.NewReader(compressed)) + require.NoError(t, err) + defer gr.Close() + + assert.True(t, gr.ModTime.Equal(epoch), "ModTime should match epoch") + assert.Empty(t, gr.Name, "Name should be empty") + assert.Empty(t, gr.Comment, "Comment should be empty") + assert.Equal(t, byte(gzipOSUnknown), gr.OS, "OS should be 255 (unknown)") +} + +func TestCompress_DifferentEpochs(t *testing.T) { + t.Parallel() + + data := []byte("test data") + + tests := []struct { + name string + epoch1 time.Time + epoch2 time.Time + wantEqual bool + }{ + { + name: "same epoch produces same output", + epoch1: time.Unix(1609459200, 0).UTC(), + epoch2: time.Unix(1609459200, 0).UTC(), + wantEqual: true, + }, + { + name: "different epochs produce different output", + epoch1: time.Unix(0, 0).UTC(), 
+ epoch2: time.Unix(1000000, 0).UTC(), + wantEqual: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + opts1 := GzipOptions{Level: gzip.BestCompression, Epoch: tt.epoch1} + opts2 := GzipOptions{Level: gzip.BestCompression, Epoch: tt.epoch2} + + gz1, err := Compress(data, opts1) + require.NoError(t, err) + + gz2, err := Compress(data, opts2) + require.NoError(t, err) + + if tt.wantEqual { + assert.Equal(t, gz1, gz2) + } else { + assert.NotEqual(t, gz1, gz2) + } + }) + } +} + +func TestCompress_SameEpochAlwaysReproducible(t *testing.T) { + t.Parallel() + + data := []byte("test data for reproducibility check") + epoch := time.Unix(1609459200, 0).UTC() + opts := GzipOptions{Level: gzip.BestCompression, Epoch: epoch} + + results := make([][]byte, 5) + for i := range results { + var err error + results[i], err = Compress(data, opts) + require.NoError(t, err) + } + + for i := 1; i < len(results); i++ { + assert.Equal(t, results[0], results[i], "iteration %d should match", i) + } +} + +func TestCompressDecompress_RoundTrip(t *testing.T) { + t.Parallel() + + original := []byte("test data for round trip") + opts := DefaultGzipOptions() + + compressed, err := Compress(original, opts) + require.NoError(t, err) + + decompressed, err := Decompress(compressed) + require.NoError(t, err) + + assert.Equal(t, original, decompressed) +} + +func TestDecompressWithLimit_RejectsOversized(t *testing.T) { + t.Parallel() + + // Create compressed data that exceeds the limit when decompressed + data := bytes.Repeat([]byte("x"), 1024) + compressed, err := Compress(data, DefaultGzipOptions()) + require.NoError(t, err) + + _, err = DecompressWithLimit(compressed, 100) + assert.Error(t, err) + assert.Contains(t, err.Error(), "exceeds maximum size") +} + +func TestCompressTar_Reproducible(t *testing.T) { + t.Parallel() + + files := []FileEntry{ + {Path: testFileB, Content: []byte("content b")}, + {Path: testFileA, Content: []byte("content a")}, 
+ } + + tarOpts := DefaultTarOptions() + gzipOpts := DefaultGzipOptions() + + gz1, err := CompressTar(files, tarOpts, gzipOpts) + require.NoError(t, err) + + gz2, err := CompressTar(files, tarOpts, gzipOpts) + require.NoError(t, err) + + assert.Equal(t, gz1, gz2, "CompressTar should produce identical output") +} + +func TestCompressTar_RoundTrip(t *testing.T) { + t.Parallel() + + originalFiles := []FileEntry{ + {Path: testFileA, Content: []byte("content a")}, + {Path: "dir/" + testFileB, Content: []byte("content b")}, + } + + tarOpts := DefaultTarOptions() + gzipOpts := DefaultGzipOptions() + + compressed, err := CompressTar(originalFiles, tarOpts, gzipOpts) + require.NoError(t, err) + + extractedFiles, err := DecompressTar(compressed) + require.NoError(t, err) + + require.Len(t, extractedFiles, len(originalFiles)) + for i, f := range extractedFiles { + assert.Equal(t, originalFiles[i].Path, f.Path) + assert.Equal(t, originalFiles[i].Content, f.Content) + } +} diff --git a/oci/artifact/platform.go b/oci/artifact/platform.go new file mode 100644 index 0000000..6e1cf68 --- /dev/null +++ b/oci/artifact/platform.go @@ -0,0 +1,59 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package artifact + +import ( + "fmt" + "strings" + + ocispec "github.com/opencontainers/image-spec/specs-go/v1" +) + +// PlatformString returns the platform in "os/arch" or "os/arch/variant" format. +func PlatformString(p ocispec.Platform) string { + s := p.OS + "/" + p.Architecture + if p.Variant != "" { + s += "/" + p.Variant + } + return s +} + +// ParsePlatform parses a platform string in "os/arch" or "os/arch/variant" format. 
+func ParsePlatform(s string) (ocispec.Platform, error) { + parts := strings.Split(s, "/") + if len(parts) < 2 || len(parts) > 3 { + return ocispec.Platform{}, fmt.Errorf("invalid platform format: %q (expected os/arch or os/arch/variant)", s) + } + osName := strings.TrimSpace(parts[0]) + arch := strings.TrimSpace(parts[1]) + if osName == "" || arch == "" { + return ocispec.Platform{}, fmt.Errorf("invalid platform format: %q (os and arch cannot be empty)", s) + } + p := ocispec.Platform{OS: osName, Architecture: arch} + if len(parts) == 3 { + variant := strings.TrimSpace(parts[2]) + if variant == "" { + return ocispec.Platform{}, fmt.Errorf("invalid platform format: %q (variant cannot be empty)", s) + } + p.Variant = variant + } + return p, nil +} + +// OS and architecture constants for OCI platform specifications. +const ( + // OSLinux is the Linux OS identifier used in OCI platform specs. + OSLinux = "linux" + // ArchAMD64 is the x86-64 architecture identifier used in OCI platform specs. + ArchAMD64 = "amd64" + // ArchARM64 is the 64-bit ARM architecture identifier used in OCI platform specs. + ArchARM64 = "arm64" +) + +// DefaultPlatforms are the default platforms for artifacts. +// These cover most Kubernetes clusters. +var DefaultPlatforms = []ocispec.Platform{ + {OS: OSLinux, Architecture: ArchAMD64}, + {OS: OSLinux, Architecture: ArchARM64}, +} diff --git a/oci/artifact/platform_test.go b/oci/artifact/platform_test.go new file mode 100644 index 0000000..398f92e --- /dev/null +++ b/oci/artifact/platform_test.go @@ -0,0 +1,128 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package artifact + +import ( + "testing" + + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testArchARM is the 32-bit ARM architecture identifier used in test platform specs. 
+const testArchARM = "arm" + +func TestParsePlatform(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want ocispec.Platform + wantErr bool + }{ + { + name: "os/arch", + input: "linux/amd64", + want: ocispec.Platform{OS: OSLinux, Architecture: ArchAMD64}, + }, + { + name: "os/arch/variant", + input: "linux/arm/v7", + want: ocispec.Platform{OS: OSLinux, Architecture: testArchARM, Variant: "v7"}, + }, + { + name: "fewer than 2 parts (no slash)", + input: "linuxamd64", + wantErr: true, + }, + { + name: "more than 3 parts", + input: "linux/amd64/v8/extra", + wantErr: true, + }, + { + name: "empty os", + input: "/amd64", + wantErr: true, + }, + { + name: "empty arch", + input: "linux/", + wantErr: true, + }, + { + name: "empty variant", + input: "linux/arm/", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := ParsePlatform(tt.input) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestPlatformString(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + platform ocispec.Platform + want string + }{ + { + name: "os/arch", + platform: ocispec.Platform{OS: OSLinux, Architecture: ArchAMD64}, + want: "linux/amd64", + }, + { + name: "os/arch/variant", + platform: ocispec.Platform{OS: OSLinux, Architecture: testArchARM, Variant: "v7"}, + want: "linux/arm/v7", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, PlatformString(tt.platform)) + }) + } +} + +func TestParsePlatform_PlatformString_Roundtrip(t *testing.T) { + t.Parallel() + + platforms := []ocispec.Platform{ + {OS: OSLinux, Architecture: ArchAMD64}, + {OS: OSLinux, Architecture: ArchARM64}, + {OS: OSLinux, Architecture: testArchARM, Variant: "v7"}, + } + + for _, p := range platforms { + parsed, err := ParsePlatform(PlatformString(p)) + 
require.NoError(t, err) + assert.Equal(t, p, parsed) + } +} + +func TestDefaultPlatforms(t *testing.T) { + t.Parallel() + + require.Len(t, DefaultPlatforms, 2) + assert.Equal(t, ocispec.Platform{OS: OSLinux, Architecture: ArchAMD64}, DefaultPlatforms[0]) + assert.Equal(t, ocispec.Platform{OS: OSLinux, Architecture: ArchARM64}, DefaultPlatforms[1]) +} diff --git a/oci/skills/tar.go b/oci/artifact/tar.go similarity index 99% rename from oci/skills/tar.go rename to oci/artifact/tar.go index 1533374..77c0fec 100644 --- a/oci/skills/tar.go +++ b/oci/artifact/tar.go @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. // SPDX-License-Identifier: Apache-2.0 -package skills +package artifact import ( "archive/tar" diff --git a/oci/artifact/tar_test.go b/oci/artifact/tar_test.go new file mode 100644 index 0000000..6c42d9f --- /dev/null +++ b/oci/artifact/tar_test.go @@ -0,0 +1,280 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package artifact + +import ( + "archive/tar" + "bytes" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCreateTar_Reproducible(t *testing.T) { + t.Parallel() + + files := []FileEntry{ + {Path: testFileB, Content: []byte("content b")}, + {Path: testFileA, Content: []byte("content a")}, + {Path: "c/d.txt", Content: []byte("content d")}, + } + + opts := DefaultTarOptions() + + tar1, err := CreateTar(files, opts) + require.NoError(t, err) + + tar2, err := CreateTar(files, opts) + require.NoError(t, err) + + assert.Equal(t, tar1, tar2, "CreateTar should produce identical output for same input") +} + +func TestCreateTar_DifferentOrder(t *testing.T) { + t.Parallel() + + files1 := []FileEntry{ + {Path: testFileB, Content: []byte("b")}, + {Path: testFileA, Content: []byte("a")}, + } + + files2 := []FileEntry{ + {Path: testFileA, Content: []byte("a")}, + {Path: testFileB, Content: []byte("b")}, + } + + opts := 
DefaultTarOptions() + + tar1, err := CreateTar(files1, opts) + require.NoError(t, err) + + tar2, err := CreateTar(files2, opts) + require.NoError(t, err) + + assert.Equal(t, tar1, tar2, "CreateTar should sort files internally") +} + +func TestCreateTar_DifferentTimestamps(t *testing.T) { + t.Parallel() + + files := []FileEntry{ + {Path: "test.txt", Content: []byte("test")}, + } + + tests := []struct { + name string + opts1 TarOptions + opts2 TarOptions + wantEqual bool + }{ + { + name: "same epoch produces same output", + opts1: TarOptions{Epoch: time.Unix(0, 0).UTC()}, + opts2: TarOptions{Epoch: time.Unix(0, 0).UTC()}, + wantEqual: true, + }, + { + name: "different epochs produce different output", + opts1: TarOptions{Epoch: time.Unix(0, 0).UTC()}, + opts2: TarOptions{Epoch: time.Unix(1000000, 0).UTC()}, + wantEqual: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + tar1, err := CreateTar(files, tt.opts1) + require.NoError(t, err) + + tar2, err := CreateTar(files, tt.opts2) + require.NoError(t, err) + + if tt.wantEqual { + assert.Equal(t, tar1, tar2) + } else { + assert.NotEqual(t, tar1, tar2) + } + }) + } +} + +func TestExtractTar_RoundTrip(t *testing.T) { + t.Parallel() + + originalFiles := []FileEntry{ + {Path: testFileA, Content: []byte("content a")}, + {Path: "b/c.txt", Content: []byte("content c")}, + } + + tarData, err := CreateTar(originalFiles, DefaultTarOptions()) + require.NoError(t, err) + + extractedFiles, err := ExtractTar(tarData) + require.NoError(t, err) + + require.Len(t, extractedFiles, len(originalFiles)) + + for i, f := range extractedFiles { + assert.Equal(t, originalFiles[i].Path, f.Path) + assert.Equal(t, originalFiles[i].Content, f.Content) + } +} + +func TestCreateTar_EmptyFiles(t *testing.T) { + t.Parallel() + + tarData, err := CreateTar(nil, DefaultTarOptions()) + require.NoError(t, err) + + extractedFiles, err := ExtractTar(tarData) + require.NoError(t, err) + + assert.Empty(t, 
extractedFiles) +} + +func TestExtractTar_RejectsSymlinks(t *testing.T) { + t.Parallel() + + // Create a tar with a symlink entry + var buf bytes.Buffer + tw := tar.NewWriter(&buf) + + hdr := &tar.Header{ + Name: "malicious_link", + Typeflag: tar.TypeSymlink, + Linkname: "/etc/passwd", + } + require.NoError(t, tw.WriteHeader(hdr)) + require.NoError(t, tw.Close()) + + _, err := ExtractTar(buf.Bytes()) + assert.Error(t, err) + assert.Contains(t, err.Error(), "disallowed link type") +} + +func TestExtractTar_RejectsHardlinks(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + tw := tar.NewWriter(&buf) + + hdr := &tar.Header{ + Name: "malicious_link", + Typeflag: tar.TypeLink, + Linkname: "other_file", + } + require.NoError(t, tw.WriteHeader(hdr)) + require.NoError(t, tw.Close()) + + _, err := ExtractTar(buf.Bytes()) + assert.Error(t, err) + assert.Contains(t, err.Error(), "disallowed link type") +} + +func TestExtractTar_RejectsDeviceEntries(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + tw := tar.NewWriter(&buf) + + hdr := &tar.Header{ + Name: "malicious_device", + Typeflag: tar.TypeChar, + Mode: 0666, + } + require.NoError(t, tw.WriteHeader(hdr)) + require.NoError(t, tw.Close()) + + _, err := ExtractTar(buf.Bytes()) + assert.Error(t, err) + assert.Contains(t, err.Error(), "disallowed entry type") +} + +func TestExtractTar_RejectsPathTraversal(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + path string + }{ + {name: "dotdot prefix", path: "../etc/passwd"}, + {name: "dotdot in middle", path: "foo/../../etc/passwd"}, + {name: "absolute path", path: "/etc/passwd"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + tw := tar.NewWriter(&buf) + + hdr := &tar.Header{ + Name: tt.path, + Size: 4, + Typeflag: tar.TypeReg, + Mode: 0644, + } + require.NoError(t, tw.WriteHeader(hdr)) + _, err := tw.Write([]byte("test")) + require.NoError(t, err) + require.NoError(t, 
tw.Close()) + + _, err = ExtractTar(buf.Bytes()) + assert.Error(t, err) + }) + } +} + +func TestExtractTarWithLimit_RejectsOversized(t *testing.T) { + t.Parallel() + + files := []FileEntry{ + {Path: "big.txt", Content: bytes.Repeat([]byte("x"), 1024)}, + } + + tarData, err := CreateTar(files, DefaultTarOptions()) + require.NoError(t, err) + + _, err = ExtractTarWithLimit(tarData, 100) + assert.Error(t, err) + assert.Contains(t, err.Error(), "exceeds maximum size") +} + +func TestExtractTar_SkipsDirectories(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + tw := tar.NewWriter(&buf) + + // Write a directory entry + require.NoError(t, tw.WriteHeader(&tar.Header{ + Name: "mydir/", + Typeflag: tar.TypeDir, + Mode: 0755, + })) + + // Write a file inside it + content := []byte("hello") + require.NoError(t, tw.WriteHeader(&tar.Header{ + Name: "mydir/file.txt", + Size: int64(len(content)), + Typeflag: tar.TypeReg, + Mode: 0644, + })) + _, err := tw.Write(content) + require.NoError(t, err) + require.NoError(t, tw.Close()) + + files, err := ExtractTar(buf.Bytes()) + require.NoError(t, err) + + require.Len(t, files, 1) + assert.Equal(t, "mydir/file.txt", files[0].Path) + assert.Equal(t, content, files[0].Content) +} diff --git a/oci/artifact/testconsts_test.go b/oci/artifact/testconsts_test.go new file mode 100644 index 0000000..0903334 --- /dev/null +++ b/oci/artifact/testconsts_test.go @@ -0,0 +1,9 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package artifact + +const ( + testFileA = "a.txt" + testFileB = "b.txt" +) diff --git a/oci/artifact/validate.go b/oci/artifact/validate.go new file mode 100644 index 0000000..926c12c --- /dev/null +++ b/oci/artifact/validate.go @@ -0,0 +1,155 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package artifact + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + + "github.com/opencontainers/go-digest" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "oras.land/oras-go/v2" +) + +// MaxManifestSize is the maximum size of a manifest (1MB). +const MaxManifestSize int64 = 1 * 1024 * 1024 + +// MaxBlobSize is the maximum size of a blob (100MB). +const MaxBlobSize int64 = 100 * 1024 * 1024 + +// maxIndexManifests is the maximum number of manifests in an image index. +const maxIndexManifests = 32 + +// maxManifestLayers is the maximum number of layers in a manifest. +const maxManifestLayers = 64 + +// Compile-time interface check. +var _ oras.Target = (*ValidatingTarget)(nil) + +// ValidatingTarget wraps an oras.Target to enforce size and count limits +// on pushed content. This prevents OOM and resource exhaustion from +// malicious registries during pull operations. +type ValidatingTarget struct { + inner oras.Target +} + +// NewValidatingTarget wraps an oras.Target with size and structure validation +// applied on every Push. +func NewValidatingTarget(inner oras.Target) *ValidatingTarget { + return &ValidatingTarget{inner: inner} +} + +// Fetch delegates to the inner target. +func (v *ValidatingTarget) Fetch(ctx context.Context, target ocispec.Descriptor) (io.ReadCloser, error) { + return v.inner.Fetch(ctx, target) +} + +// Exists delegates to the inner target. +func (v *ValidatingTarget) Exists(ctx context.Context, target ocispec.Descriptor) (bool, error) { + return v.inner.Exists(ctx, target) +} + +// Resolve delegates to the inner target. +func (v *ValidatingTarget) Resolve(ctx context.Context, reference string) (ocispec.Descriptor, error) { + return v.inner.Resolve(ctx, reference) +} + +// Tag delegates to the inner target. 
+func (v *ValidatingTarget) Tag(ctx context.Context, desc ocispec.Descriptor, reference string) error { + return v.inner.Tag(ctx, desc, reference) +} + +// Push validates size and structure limits before delegating to the inner target. +func (v *ValidatingTarget) Push(ctx context.Context, desc ocispec.Descriptor, content io.Reader) error { + maxSize := MaxBlobSize + if IsManifestMediaType(desc.MediaType) { + maxSize = MaxManifestSize + } + + if desc.Size < 0 { + return fmt.Errorf("invalid negative content size %d", desc.Size) + } + if desc.Size > maxSize { + return fmt.Errorf( + "content size %d exceeds maximum allowed size %d for media type %q", + desc.Size, maxSize, desc.MediaType, + ) + } + + // Read with a limit to defend against lying descriptors + limitedReader := io.LimitReader(content, maxSize+1) + data, err := io.ReadAll(limitedReader) + if err != nil { + return fmt.Errorf("reading content: %w", err) + } + + if int64(len(data)) > maxSize { + return fmt.Errorf( + "actual content size exceeds maximum allowed size %d for media type %q", + maxSize, desc.MediaType, + ) + } + + // Verify digest integrity — defense-in-depth against a lying registry + actual := digest.FromBytes(data) + if actual != desc.Digest { + return fmt.Errorf("digest mismatch: expected %s, got %s", desc.Digest, actual) + } + + // Validate manifest/index structure limits + if err := ValidateManifestCounts(desc.MediaType, data); err != nil { + return err + } + + return v.inner.Push(ctx, desc, bytes.NewReader(data)) +} + +// ValidateManifestCounts checks layer/manifest counts for resource exhaustion prevention. +// +// It only inspects media types it recognizes (image index and image manifest). +// For any other media type it returns nil without examining the data. A nil +// return therefore means "no count violation was detected", not "this manifest +// is safe" — callers must not treat a nil return as a general safety guarantee. 
+func ValidateManifestCounts(mediaType string, data []byte) error { + switch mediaType { + case ocispec.MediaTypeImageIndex: + var index ocispec.Index + if err := json.Unmarshal(data, &index); err != nil { + return fmt.Errorf("parsing index: %w", err) + } + if len(index.Manifests) > maxIndexManifests { + return fmt.Errorf( + "index has %d manifests, exceeds maximum of %d", + len(index.Manifests), maxIndexManifests, + ) + } + case ocispec.MediaTypeImageManifest: + var manifest ocispec.Manifest + if err := json.Unmarshal(data, &manifest); err != nil { + return fmt.Errorf("parsing manifest: %w", err) + } + if len(manifest.Layers) > maxManifestLayers { + return fmt.Errorf( + "manifest has %d layers, exceeds maximum of %d", + len(manifest.Layers), maxManifestLayers, + ) + } + } + return nil +} + +// IsManifestMediaType returns true if the media type is a manifest or index type. +func IsManifestMediaType(mediaType string) bool { + switch mediaType { + case ocispec.MediaTypeImageManifest, ocispec.MediaTypeImageIndex, + "application/vnd.docker.distribution.manifest.v2+json", + "application/vnd.docker.distribution.manifest.list.v2+json": + return true + default: + return false + } +} diff --git a/oci/artifact/validate_test.go b/oci/artifact/validate_test.go new file mode 100644 index 0000000..f989a8b --- /dev/null +++ b/oci/artifact/validate_test.go @@ -0,0 +1,210 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package artifact + +import ( + "bytes" + "encoding/json" + "testing" + + "github.com/opencontainers/go-digest" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "oras.land/oras-go/v2/content/memory" +) + +func TestIsManifestMediaType(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mediaType string + want bool + }{ + {name: "oci image manifest", mediaType: ocispec.MediaTypeImageManifest, want: true}, + {name: "oci image index", mediaType: ocispec.MediaTypeImageIndex, want: true}, + {name: "docker manifest v2", mediaType: "application/vnd.docker.distribution.manifest.v2+json", want: true}, + {name: "docker manifest list v2", mediaType: "application/vnd.docker.distribution.manifest.list.v2+json", want: true}, + {name: "oci image layer", mediaType: ocispec.MediaTypeImageLayerGzip, want: false}, + {name: "oci image config", mediaType: ocispec.MediaTypeImageConfig, want: false}, + {name: "empty", mediaType: "", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, IsManifestMediaType(tt.mediaType)) + }) + } +} + +func TestValidatingTarget_Push(t *testing.T) { + t.Parallel() + + validManifest := []byte(`{"schemaVersion": 2}`) + oversized := make([]byte, MaxManifestSize+1) + + tests := []struct { + name string + desc ocispec.Descriptor + content []byte + wantErr bool + errSubstr string + }{ + { + name: "accepts valid content", + desc: ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageManifest, + Digest: digest.FromBytes(validManifest), + Size: int64(len(validManifest)), + }, + content: validManifest, + wantErr: false, + }, + { + name: "rejects oversized declared size", + desc: ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageManifest, + Digest: digest.FromBytes(oversized), + Size: int64(len(oversized)), + }, + content: oversized, + 
wantErr: true, + errSubstr: "exceeds maximum allowed size", + }, + { + name: "rejects lying (too-small) descriptor size", + desc: ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageManifest, + Digest: digest.FromBytes(oversized), + Size: 10, // lying + }, + content: oversized, + wantErr: true, + errSubstr: "exceeds maximum allowed size", + }, + { + name: "rejects negative size", + desc: ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageManifest, + Digest: digest.FromString("test"), + Size: -1, + }, + content: []byte("test"), + wantErr: true, + errSubstr: "invalid negative content size", + }, + { + name: "rejects digest mismatch", + desc: ocispec.Descriptor{ + MediaType: ocispec.MediaTypeImageManifest, + Digest: digest.FromString("something-else"), + Size: int64(len(validManifest)), + }, + content: validManifest, + wantErr: true, + errSubstr: "digest mismatch", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + inner := memory.New() + vt := NewValidatingTarget(inner) + + err := vt.Push(ctx, tt.desc, bytes.NewReader(tt.content)) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errSubstr) + return + } + + require.NoError(t, err) + exists, err := inner.Exists(ctx, tt.desc) + require.NoError(t, err) + assert.True(t, exists) + }) + } +} + +func TestValidateManifestCounts(t *testing.T) { + t.Parallel() + + tooManyManifests := func() []byte { + index := ocispec.Index{MediaType: ocispec.MediaTypeImageIndex} + index.SchemaVersion = 2 + index.Manifests = make([]ocispec.Descriptor, maxIndexManifests+1) + data, err := json.Marshal(index) + require.NoError(t, err) + return data + }() + + tooManyLayers := func() []byte { + m := ocispec.Manifest{MediaType: ocispec.MediaTypeImageManifest} + m.Layers = make([]ocispec.Descriptor, maxManifestLayers+1) + data, err := json.Marshal(m) + require.NoError(t, err) + return data + }() + + validManifest := func() []byte { + m := 
ocispec.Manifest{MediaType: ocispec.MediaTypeImageManifest} + m.Layers = make([]ocispec.Descriptor, 2) + data, err := json.Marshal(m) + require.NoError(t, err) + return data + }() + + tests := []struct { + name string + mediaType string + data []byte + wantErr bool + errSubstr string + }{ + { + name: "too many manifests in index", + mediaType: ocispec.MediaTypeImageIndex, + data: tooManyManifests, + wantErr: true, + errSubstr: "exceeds maximum", + }, + { + name: "too many layers in manifest", + mediaType: ocispec.MediaTypeImageManifest, + data: tooManyLayers, + wantErr: true, + errSubstr: "exceeds maximum", + }, + { + name: "valid counts", + mediaType: ocispec.MediaTypeImageManifest, + data: validManifest, + wantErr: false, + }, + { + name: "non-manifest media type is ignored", + mediaType: ocispec.MediaTypeImageLayerGzip, + data: []byte("not json"), + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := ValidateManifestCounts(tt.mediaType, tt.data) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errSubstr) + return + } + require.NoError(t, err) + }) + } +} diff --git a/oci/skills/artifact_aliases.go b/oci/skills/artifact_aliases.go new file mode 100644 index 0000000..1fe5caf --- /dev/null +++ b/oci/skills/artifact_aliases.go @@ -0,0 +1,86 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package skills + +import ( + artifact "github.com/stacklok/toolhive-core/oci/artifact" +) + +// This file re-exports the artifact-agnostic OCI primitives that were extracted +// into the oci/artifact package. The aliases preserve the public surface of +// oci/skills so existing importers keep working unchanged. +// +// These are backward-compatibility re-exports of github.com/stacklok/toolhive-core/oci/artifact. 
+// NEW code should import github.com/stacklok/toolhive-core/oci/artifact directly +// rather than depending on these aliases. + +// Type aliases for tar/gzip primitives. +type ( + // FileEntry represents a file to include in a tar archive. + FileEntry = artifact.FileEntry + // TarOptions configures reproducible tar archive creation. + TarOptions = artifact.TarOptions + // GzipOptions configures reproducible gzip compression. + GzipOptions = artifact.GzipOptions +) + +// Function forwarding for tar primitives. +var ( + // DefaultTarOptions returns default options for reproducible tar archives. + DefaultTarOptions = artifact.DefaultTarOptions + // CreateTar creates a reproducible tar archive from the given files. + CreateTar = artifact.CreateTar + // ExtractTar extracts files from a tar archive. + ExtractTar = artifact.ExtractTar + // ExtractTarWithLimit extracts files from a tar archive with a per-file size limit. + ExtractTarWithLimit = artifact.ExtractTarWithLimit +) + +// Function forwarding for gzip primitives. +var ( + // DefaultGzipOptions returns default options for reproducible gzip compression. + DefaultGzipOptions = artifact.DefaultGzipOptions + // Compress creates a reproducible gzip compressed byte slice. + Compress = artifact.Compress + // Decompress decompresses gzip data. + Decompress = artifact.Decompress + // DecompressWithLimit decompresses gzip data with a size limit. + DecompressWithLimit = artifact.DecompressWithLimit + // CompressTar creates a reproducible .tar.gz from the given files. + CompressTar = artifact.CompressTar + // DecompressTar extracts files from a .tar.gz archive. + DecompressTar = artifact.DecompressTar +) + +// Function forwarding for platform helpers. +var ( + // PlatformString returns the platform in "os/arch" or "os/arch/variant" format. + PlatformString = artifact.PlatformString + // ParsePlatform parses a platform string in "os/arch" or "os/arch/variant" format. 
+ ParsePlatform = artifact.ParsePlatform + // DefaultPlatforms are the default platforms for artifacts (a slice copied at init, not a function). + DefaultPlatforms = artifact.DefaultPlatforms +) + +// Size limit constants re-exported from the artifact package. +const ( + // MaxTarFileSize is the maximum size of a single file in a tar archive (100MB). + MaxTarFileSize = artifact.MaxTarFileSize + // MaxDecompressedSize is the maximum size of decompressed data (100MB). + MaxDecompressedSize = artifact.MaxDecompressedSize + // MaxManifestSize is the maximum size of a manifest (1MB). + MaxManifestSize = artifact.MaxManifestSize + // MaxBlobSize is the maximum size of a blob (100MB). + MaxBlobSize = artifact.MaxBlobSize +) + +// OS and architecture constants for OCI platform specifications. +const ( + // OSLinux is the Linux OS identifier used in OCI platform specs. + OSLinux = artifact.OSLinux + // ArchAMD64 is the x86-64 architecture identifier used in OCI platform specs. + ArchAMD64 = artifact.ArchAMD64 + // ArchARM64 is the 64-bit ARM architecture identifier used in OCI platform specs. + ArchARM64 = artifact.ArchARM64 +) diff --git a/oci/skills/artifact_aliases_test.go b/oci/skills/artifact_aliases_test.go new file mode 100644 index 0000000..0971783 --- /dev/null +++ b/oci/skills/artifact_aliases_test.go @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package skills + +import ( + artifact "github.com/stacklok/toolhive-core/oci/artifact" +) + +// Test-only stand-ins for the formerly package-private names that the existing +// oci/skills test suite (registry_test.go, gzip_test.go) references directly. +// Keeping these as test aliases lets the regression-gate tests stay unchanged +// after the primitives were moved into the oci/artifact package. + +// gzipOSUnknown mirrors the gzip "unknown" OS header value used in tests. +const gzipOSUnknown = 255 + +// maxIndexManifests mirrors the artifact cap on manifests in an image index. 
+const maxIndexManifests = 32 + +// maxManifestLayers mirrors the artifact cap on layers in a manifest. +const maxManifestLayers = 64 + +// newValidatingTarget forwards to artifact.NewValidatingTarget. +var newValidatingTarget = artifact.NewValidatingTarget + +// validateManifestCounts forwards to artifact.ValidateManifestCounts. +var validateManifestCounts = artifact.ValidateManifestCounts + +// isManifestMediaType forwards to artifact.IsManifestMediaType. +var isManifestMediaType = artifact.IsManifestMediaType diff --git a/oci/skills/mediatypes.go b/oci/skills/mediatypes.go index e71ccc9..8282690 100644 --- a/oci/skills/mediatypes.go +++ b/oci/skills/mediatypes.go @@ -6,7 +6,6 @@ package skills import ( "encoding/json" "fmt" - "strings" ocispec "github.com/opencontainers/image-spec/specs-go/v1" ) @@ -103,54 +102,6 @@ func SkillConfigFromImageConfig(imgConfig *ocispec.Image) (*SkillConfig, error) return config, nil } -// PlatformString returns the platform in "os/arch" or "os/arch/variant" format. -func PlatformString(p ocispec.Platform) string { - s := p.OS + "/" + p.Architecture - if p.Variant != "" { - s += "/" + p.Variant - } - return s -} - -// ParsePlatform parses a platform string in "os/arch" or "os/arch/variant" format. 
-func ParsePlatform(s string) (ocispec.Platform, error) { - parts := strings.Split(s, "/") - if len(parts) < 2 || len(parts) > 3 { - return ocispec.Platform{}, fmt.Errorf("invalid platform format: %q (expected os/arch or os/arch/variant)", s) - } - osName := strings.TrimSpace(parts[0]) - arch := strings.TrimSpace(parts[1]) - if osName == "" || arch == "" { - return ocispec.Platform{}, fmt.Errorf("invalid platform format: %q (os and arch cannot be empty)", s) - } - p := ocispec.Platform{OS: osName, Architecture: arch} - if len(parts) == 3 { - variant := strings.TrimSpace(parts[2]) - if variant == "" { - return ocispec.Platform{}, fmt.Errorf("invalid platform format: %q (variant cannot be empty)", s) - } - p.Variant = variant - } - return p, nil -} - -// OS and architecture constants for OCI platform specifications. -const ( - // OSLinux is the Linux OS identifier used in OCI platform specs. - OSLinux = "linux" - // ArchAMD64 is the x86-64 architecture identifier used in OCI platform specs. - ArchAMD64 = "amd64" - // ArchARM64 is the 64-bit ARM architecture identifier used in OCI platform specs. - ArchARM64 = "arm64" -) - -// DefaultPlatforms are the default platforms for skill artifacts. -// These cover most Kubernetes clusters. -var DefaultPlatforms = []ocispec.Platform{ - {OS: OSLinux, Architecture: ArchAMD64}, - {OS: OSLinux, Architecture: ArchARM64}, -} - // ParseRequiresAnnotation parses skill dependency references from manifest annotations. // Returns nil if the annotation is missing or invalid. 
func ParseRequiresAnnotation(annotations map[string]string) []string { diff --git a/oci/skills/packager.go b/oci/skills/packager.go index 642f316..baeb1b8 100644 --- a/oci/skills/packager.go +++ b/oci/skills/packager.go @@ -20,6 +20,8 @@ import ( specs "github.com/opencontainers/image-spec/specs-go" ocispec "github.com/opencontainers/image-spec/specs-go/v1" "gopkg.in/yaml.v3" + + "github.com/stacklok/toolhive-core/oci/artifact" ) // Packager creates reproducible OCI artifacts from skill directories. @@ -133,14 +135,14 @@ func DefaultPackageOptions() PackageOptions { return PackageOptions{ Epoch: epoch, - Platforms: DefaultPlatforms, + Platforms: artifact.DefaultPlatforms, } } // Package packages a skill directory into an OCI artifact in the local store. func (p *Packager) Package(ctx context.Context, skillDir string, opts PackageOptions) (*PackageResult, error) { if len(opts.Platforms) == 0 { - opts.Platforms = DefaultPlatforms + opts.Platforms = artifact.DefaultPlatforms } // Read and validate skill directory @@ -167,7 +169,7 @@ func (p *Packager) Package(ctx context.Context, skillDir string, opts PackageOpt var manifestAnnotations map[string]string for i, platform := range opts.Platforms { - platformStr := PlatformString(platform) + platformStr := artifact.PlatformString(platform) ociConfig, cfg := createOCIConfig(content, uncompressedTar, platform, opts) configBytes, err := json.Marshal(ociConfig) @@ -396,10 +398,10 @@ func parseFrontmatter(content []byte) (*frontmatter, error) { // createContentLayer creates a reproducible tar.gz of the skill content. // Returns both compressed and uncompressed bytes (uncompressed needed for diff_id). 
func createContentLayer(content *skillDirContent, opts PackageOptions) (compressed, uncompressed []byte, err error) { - var files []FileEntry + var files []artifact.FileEntry // Add SKILL.md first - files = append(files, FileEntry{ + files = append(files, artifact.FileEntry{ Path: SkillFileName, Content: content.skillMD, }) @@ -412,21 +414,21 @@ func createContentLayer(content *skillDirContent, opts PackageOptions) (compress slices.Sort(sortedPaths) for _, p := range sortedPaths { - files = append(files, FileEntry{ + files = append(files, artifact.FileEntry{ Path: p, Content: content.files[p], }) } - tarOpts := TarOptions{Epoch: opts.Epoch} - gzipOpts := DefaultGzipOptions() + tarOpts := artifact.TarOptions{Epoch: opts.Epoch} + gzipOpts := artifact.DefaultGzipOptions() - uncompressed, err = CreateTar(files, tarOpts) + uncompressed, err = artifact.CreateTar(files, tarOpts) if err != nil { return nil, nil, fmt.Errorf("creating tar: %w", err) } - compressed, err = Compress(uncompressed, gzipOpts) + compressed, err = artifact.Compress(uncompressed, gzipOpts) if err != nil { return nil, nil, fmt.Errorf("compressing tar: %w", err) } @@ -558,7 +560,7 @@ func (p *Packager) createIndex( ) (digest.Digest, error) { manifests := make([]ocispec.Descriptor, 0, len(opts.Platforms)) for _, platform := range opts.Platforms { - platformStr := PlatformString(platform) + platformStr := artifact.PlatformString(platform) info, ok := platformManifests[platformStr] if !ok { return "", fmt.Errorf("missing manifest for platform %s", platformStr) diff --git a/oci/skills/registry.go b/oci/skills/registry.go index d4bb1c4..f9864e9 100644 --- a/oci/skills/registry.go +++ b/oci/skills/registry.go @@ -4,39 +4,22 @@ package skills import ( - "bytes" "context" - "encoding/json" "fmt" - "io" "github.com/opencontainers/go-digest" - ocispec "github.com/opencontainers/image-spec/specs-go/v1" "oras.land/oras-go/v2" "oras.land/oras-go/v2/registry" "oras.land/oras-go/v2/registry/remote" 
"oras.land/oras-go/v2/registry/remote/auth" "oras.land/oras-go/v2/registry/remote/credentials" -) - -// MaxManifestSize is the maximum size of a manifest (1MB). -const MaxManifestSize int64 = 1 * 1024 * 1024 - -// MaxBlobSize is the maximum size of a blob (100MB). -const MaxBlobSize int64 = 100 * 1024 * 1024 -// maxIndexManifests is the maximum number of manifests in an image index. -const maxIndexManifests = 32 - -// maxManifestLayers is the maximum number of layers in a manifest. -const maxManifestLayers = 64 - -// Compile-time interface checks. -var ( - _ RegistryClient = (*Registry)(nil) - _ oras.Target = (*validatingTarget)(nil) + "github.com/stacklok/toolhive-core/oci/artifact" ) +// Compile-time interface check. +var _ RegistryClient = (*Registry)(nil) + // Registry provides operations for pushing and pulling skills from OCI registries. type Registry struct { credStore credentials.Store @@ -135,7 +118,7 @@ func (r *Registry) Pull(ctx context.Context, store *Store, ref string) (digest.D return "", fmt.Errorf("getting repository: %w", err) } - validated := newValidatingTarget(store.Target()) + validated := artifact.NewValidatingTarget(store.Target()) // Copy from remote to the validated local store desc, err := oras.Copy( @@ -154,123 +137,6 @@ func (r *Registry) Pull(ctx context.Context, store *Store, ref string) (digest.D return desc.Digest, nil } -// validatingTarget wraps an oras.Target to enforce size and count limits -// on pushed content. This prevents OOM and resource exhaustion from -// malicious registries during pull operations. -type validatingTarget struct { - inner oras.Target -} - -func newValidatingTarget(inner oras.Target) *validatingTarget { - return &validatingTarget{inner: inner} -} - -// Fetch delegates to the inner target. -func (v *validatingTarget) Fetch(ctx context.Context, target ocispec.Descriptor) (io.ReadCloser, error) { - return v.inner.Fetch(ctx, target) -} - -// Exists delegates to the inner target. 
-func (v *validatingTarget) Exists(ctx context.Context, target ocispec.Descriptor) (bool, error) { - return v.inner.Exists(ctx, target) -} - -// Resolve delegates to the inner target. -func (v *validatingTarget) Resolve(ctx context.Context, reference string) (ocispec.Descriptor, error) { - return v.inner.Resolve(ctx, reference) -} - -// Tag delegates to the inner target. -func (v *validatingTarget) Tag(ctx context.Context, desc ocispec.Descriptor, reference string) error { - return v.inner.Tag(ctx, desc, reference) -} - -// Push validates size and structure limits before delegating to the inner target. -func (v *validatingTarget) Push(ctx context.Context, desc ocispec.Descriptor, content io.Reader) error { - maxSize := MaxBlobSize - if isManifestMediaType(desc.MediaType) { - maxSize = MaxManifestSize - } - - if desc.Size < 0 { - return fmt.Errorf("invalid negative content size %d", desc.Size) - } - if desc.Size > maxSize { - return fmt.Errorf( - "content size %d exceeds maximum allowed size %d for media type %q", - desc.Size, maxSize, desc.MediaType, - ) - } - - // Read with a limit to defend against lying descriptors - limitedReader := io.LimitReader(content, maxSize+1) - data, err := io.ReadAll(limitedReader) - if err != nil { - return fmt.Errorf("reading content: %w", err) - } - - if int64(len(data)) > maxSize { - return fmt.Errorf( - "actual content size exceeds maximum allowed size %d for media type %q", - maxSize, desc.MediaType, - ) - } - - // Verify digest integrity — defense-in-depth against a lying registry - actual := digest.FromBytes(data) - if actual != desc.Digest { - return fmt.Errorf("digest mismatch: expected %s, got %s", desc.Digest, actual) - } - - // Validate manifest/index structure limits - if err := validateManifestCounts(desc.MediaType, data); err != nil { - return err - } - - return v.inner.Push(ctx, desc, bytes.NewReader(data)) -} - -// validateManifestCounts checks layer/manifest counts for resource exhaustion prevention. 
-func validateManifestCounts(mediaType string, data []byte) error { - switch mediaType { - case ocispec.MediaTypeImageIndex: - var index ocispec.Index - if err := json.Unmarshal(data, &index); err != nil { - return fmt.Errorf("parsing index: %w", err) - } - if len(index.Manifests) > maxIndexManifests { - return fmt.Errorf( - "index has %d manifests, exceeds maximum of %d", - len(index.Manifests), maxIndexManifests, - ) - } - case ocispec.MediaTypeImageManifest: - var manifest ocispec.Manifest - if err := json.Unmarshal(data, &manifest); err != nil { - return fmt.Errorf("parsing manifest: %w", err) - } - if len(manifest.Layers) > maxManifestLayers { - return fmt.Errorf( - "manifest has %d layers, exceeds maximum of %d", - len(manifest.Layers), maxManifestLayers, - ) - } - } - return nil -} - -// isManifestMediaType returns true if the media type is a manifest or index type. -func isManifestMediaType(mediaType string) bool { - switch mediaType { - case ocispec.MediaTypeImageManifest, ocispec.MediaTypeImageIndex, - "application/vnd.docker.distribution.manifest.v2+json", - "application/vnd.docker.distribution.manifest.list.v2+json": - return true - default: - return false - } -} - // parseReference parses an OCI reference and validates it has a tag or digest. func parseReference(ref string) (registry.Reference, error) { parsedRef, err := registry.ParseReference(ref)