Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pkg/attestation/extractor.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package attestation

import "strings"

// Extractor is the interface that all attestation type extractors must implement
type Extractor interface {
// Name returns the name/type of attestation this extractor handles (e.g., "material", "command-run", "product")
Expand Down Expand Up @@ -43,7 +45,7 @@ func (c *ExtractorChain) ExtractAll(attestations []TypedAttestation, typeFilter
filterSet := make(map[string]struct{})
if len(typeFilter) > 0 {
for _, t := range typeFilter {
filterSet[t] = struct{}{}
filterSet[strings.ToLower(strings.TrimSpace(t))] = struct{}{}
}
}

Expand Down
45 changes: 45 additions & 0 deletions pkg/attestation/extractor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package attestation

import (
"testing"
)

func TestExtractAll_TypeFilterNormalization(t *testing.T) {
attestations := []TypedAttestation{
{
Type: "command-run",
Data: map[string]interface{}{
"processes": []interface{}{
map[string]interface{}{
"openedfiles": map[string]interface{}{
"/usr/lib/python3/site-packages/requests-2.28.0.dist-info/METADATA": map[string]interface{}{
"sha256": "abc123",
},
},
},
},
},
},
}

chain := NewExtractorChain()

tests := []struct {
name string
typeFilter []string
}{
{"exact lowercase", []string{"command-run"}},
{"mixed case", []string{"Command-Run"}},
{"uppercase", []string{"COMMAND-RUN"}},
{"with whitespace", []string{" command-run "}},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
files := chain.ExtractAll(attestations, tt.typeFilter)
if len(files) == 0 {
t.Errorf("ExtractAll with typeFilter %v returned no files; want at least 1", tt.typeFilter)
}
})
}
}
16 changes: 12 additions & 4 deletions pkg/attestation/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,20 @@ func decodeEnvelopePayload(rawPayload json.RawMessage) ([]byte, error) {
return nil, fmt.Errorf("unsupported payload format")
}

var supportedEncodings = []*base64.Encoding{
base64.RawURLEncoding,
base64.URLEncoding,
base64.StdEncoding,
base64.RawStdEncoding,
}

func decodeBase64Any(s string) ([]byte, error) {
decoded, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return nil, fmt.Errorf("base64 decoding failed: %w", err)
for _, enc := range supportedEncodings {
if b, err := enc.DecodeString(s); err == nil {
return b, nil
}
}
return decoded, nil
return nil, fmt.Errorf("base64 decoding failed: not a recognized base64 variant")
}

func extractAttestations(predicate map[string]interface{}, typeFilter []string) ([]TypedAttestation, error) {
Expand Down
49 changes: 49 additions & 0 deletions pkg/attestation/parser_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package attestation

import (
"bytes"
"encoding/base64"
"testing"
)

func TestDecodeBase64Any(t *testing.T) {
tests := []struct {
name string
encoding *base64.Encoding
}{
{"RawURLEncoding (DSSE standard)", base64.RawURLEncoding},
{"URLEncoding", base64.URLEncoding},
{"StdEncoding", base64.StdEncoding},
{"RawStdEncoding", base64.RawStdEncoding},
}

// 0xFF bytes encode to "////" in StdEncoding and "____" in RawURLEncoding,ensuring the variants are non-interchangeable
original := []byte{0xff, 0xff, 0xff}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
encoded := tt.encoding.EncodeToString(original)
got, err := decodeBase64Any(encoded)
if err != nil {
t.Fatalf("decodeBase64Any(%q) failed: %v", encoded, err)
}
if !bytes.Equal(got, original) {
t.Errorf("decodeBase64Any(%q) = %v; want %v", encoded, got, original)
}
})
}
}

func TestDecodeBase64Any_NoPadding(t *testing.T) {
// DSSE payloads use base64url without padding; StdEncoding rejects these.
original := []byte(`{"a":1}`)

encoded := base64.RawURLEncoding.EncodeToString(original)
got, err := decodeBase64Any(encoded)
if err != nil {
t.Fatalf("decodeBase64Any(%q) failed: %v", encoded, err)
}
if !bytes.Equal(got, original) {
t.Errorf("decodeBase64Any(%q) = %q; want %q", encoded, got, original)
}
}