diff --git a/pkg/attestation/extractor.go b/pkg/attestation/extractor.go index dfe8824..199c661 100644 --- a/pkg/attestation/extractor.go +++ b/pkg/attestation/extractor.go @@ -1,5 +1,7 @@ package attestation +import "strings" + // Extractor is the interface that all attestation type extractors must implement type Extractor interface { // Name returns the name/type of attestation this extractor handles (e.g., "material", "command-run", "product") @@ -43,7 +45,7 @@ func (c *ExtractorChain) ExtractAll(attestations []TypedAttestation, typeFilter filterSet := make(map[string]struct{}) if len(typeFilter) > 0 { for _, t := range typeFilter { - filterSet[t] = struct{}{} + filterSet[strings.ToLower(strings.TrimSpace(t))] = struct{}{} } } diff --git a/pkg/attestation/extractor_test.go b/pkg/attestation/extractor_test.go new file mode 100644 index 0000000..3661f3e --- /dev/null +++ b/pkg/attestation/extractor_test.go @@ -0,0 +1,45 @@ +package attestation + +import ( + "testing" +) + +func TestExtractAll_TypeFilterNormalization(t *testing.T) { + attestations := []TypedAttestation{ + { + Type: "command-run", + Data: map[string]interface{}{ + "processes": []interface{}{ + map[string]interface{}{ + "openedfiles": map[string]interface{}{ + "/usr/lib/python3/site-packages/requests-2.28.0.dist-info/METADATA": map[string]interface{}{ + "sha256": "abc123", + }, + }, + }, + }, + }, + }, + } + + chain := NewExtractorChain() + + tests := []struct { + name string + typeFilter []string + }{ + {"exact lowercase", []string{"command-run"}}, + {"mixed case", []string{"Command-Run"}}, + {"uppercase", []string{"COMMAND-RUN"}}, + {"with whitespace", []string{" command-run "}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + files := chain.ExtractAll(attestations, tt.typeFilter) + if len(files) == 0 { + t.Errorf("ExtractAll with typeFilter %v returned no files; want at least 1", tt.typeFilter) + } + }) + } +} diff --git a/pkg/attestation/parser.go b/pkg/attestation/parser.go index 3756a58..74fc9cc 100644 --- a/pkg/attestation/parser.go +++ b/pkg/attestation/parser.go @@ -76,12 +76,20 @@ func decodeEnvelopePayload(rawPayload json.RawMessage) ([]byte, error) { return nil, fmt.Errorf("unsupported payload format") } +var supportedEncodings = []*base64.Encoding{ + base64.RawURLEncoding, + base64.URLEncoding, + base64.StdEncoding, + base64.RawStdEncoding, +} + func decodeBase64Any(s string) ([]byte, error) { - decoded, err := base64.StdEncoding.DecodeString(s) - if err != nil { - return nil, fmt.Errorf("base64 decoding failed: %w", err) + for _, enc := range supportedEncodings { + if b, err := enc.DecodeString(s); err == nil { + return b, nil + } } - return decoded, nil + return nil, fmt.Errorf("base64 decoding failed: not a recognized base64 variant") } func extractAttestations(predicate map[string]interface{}, typeFilter []string) ([]TypedAttestation, error) { diff --git a/pkg/attestation/parser_test.go b/pkg/attestation/parser_test.go new file mode 100644 index 0000000..5e93826 --- /dev/null +++ b/pkg/attestation/parser_test.go @@ -0,0 +1,49 @@ +package attestation + +import ( + "bytes" + "encoding/base64" + "testing" +) + +func TestDecodeBase64Any(t *testing.T) { + tests := []struct { + name string + encoding *base64.Encoding + }{ + {"RawURLEncoding (DSSE standard)", base64.RawURLEncoding}, + {"URLEncoding", base64.URLEncoding}, + {"StdEncoding", base64.StdEncoding}, + {"RawStdEncoding", base64.RawStdEncoding}, + } + + // 0xFF bytes encode to "////" in StdEncoding and "____" in RawURLEncoding,ensuring the variants are non-interchangeable + original := []byte{0xff, 0xff, 0xff} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + encoded := tt.encoding.EncodeToString(original) + got, err := decodeBase64Any(encoded) + if err != nil { + t.Fatalf("decodeBase64Any(%q) failed: %v", encoded, err) + } + if !bytes.Equal(got, original) { + t.Errorf("decodeBase64Any(%q) = %v; want %v", encoded, got, original) + } + }) + } +} + +func TestDecodeBase64Any_NoPadding(t *testing.T) { + // DSSE payloads use base64url without padding; StdEncoding rejects these. + original := []byte(`{"a":1}`) + + encoded := base64.RawURLEncoding.EncodeToString(original) + got, err := decodeBase64Any(encoded) + if err != nil { + t.Fatalf("decodeBase64Any(%q) failed: %v", encoded, err) + } + if !bytes.Equal(got, original) { + t.Errorf("decodeBase64Any(%q) = %q; want %q", encoded, got, original) + } +}