diff --git a/internal/bootstraptoken/bootstraptoken.go b/internal/bootstraptoken/bootstraptoken.go index 1dfefbfd..abf4c5d1 100644 --- a/internal/bootstraptoken/bootstraptoken.go +++ b/internal/bootstraptoken/bootstraptoken.go @@ -88,11 +88,18 @@ func NewFromString(rawBootstrapToken string) (*BootstrapToken, error) { return nil, fmt.Errorf("%w: failed to decode service account name: %v", ErrInvalidBootstrapTokenFormat, err) } + if len(serviceAccountName) == 0 { + return nil, fmt.Errorf("%w: empty service account name", ErrInvalidBootstrapTokenFormat) + } + serviceAccountToken, err := encoding.DecodeString(splits[2]) if err != nil { return nil, fmt.Errorf("%w: failed to decode service account token: %v", ErrInvalidBootstrapTokenFormat, err) } + if len(serviceAccountToken) == 0 { + return nil, fmt.Errorf("%w: empty service account token", ErrInvalidBootstrapTokenFormat) + } // Optionally parse the certificate var certificate *x509.Certificate @@ -106,11 +113,15 @@ func NewFromString(rawBootstrapToken string) (*BootstrapToken, error) { } block, _ := pem.Decode(rawCertificate) + if block == nil { + return nil, fmt.Errorf("%w: failed to parse certificate: expected a PEM format", + ErrInvalidBootstrapTokenFormat) + } certificate, err = x509.ParseCertificate(block.Bytes) if err != nil { return nil, fmt.Errorf("%w: failed to parse certificate: %v", - ErrFailedToCreateBootstrapToken, err) + ErrInvalidBootstrapTokenFormat, err) } } diff --git a/internal/bootstraptoken/bootstraptoken_test.go b/internal/bootstraptoken/bootstraptoken_test.go index 9458e3e4..bdc747b5 100644 --- a/internal/bootstraptoken/bootstraptoken_test.go +++ b/internal/bootstraptoken/bootstraptoken_test.go @@ -1,12 +1,14 @@ package bootstraptoken_test import ( + "encoding/base64" "encoding/pem" + "testing" + "github.com/cirruslabs/orchard/internal/bootstraptoken" controllercmd "github.com/cirruslabs/orchard/internal/command/controller" "github.com/google/uuid" "github.com/stretchr/testify/require" - "testing" ) func TestBootstrapTokenTwoWay(t *testing.T) { @@ -28,6 +30,7 @@ func TestBootstrapTokenTwoWay(t *testing.T) { require.Equal(t, bootstrapTokenOld.ServiceAccountName(), bootstrapTokenNew.ServiceAccountName()) require.Equal(t, bootstrapTokenOld.ServiceAccountToken(), bootstrapTokenNew.ServiceAccountToken()) require.Equal(t, bootstrapTokenOld.Certificate(), bootstrapTokenNew.Certificate()) + require.Equal(t, bootstrapTokenOld.String(), bootstrapTokenNew.String()) } func TestBootstrapTokenTwoWayEmptyCertificate(t *testing.T) { @@ -41,3 +44,49 @@ func TestBootstrapTokenTwoWayEmptyCertificate(t *testing.T) { require.Equal(t, bootstrapTokenOld.ServiceAccountToken(), bootstrapTokenNew.ServiceAccountToken()) require.Equal(t, bootstrapTokenOld.Certificate(), bootstrapTokenNew.Certificate()) } + +func TestNewFromStringNonPEMCertificate(t *testing.T) { + rawBootstrapToken := "orchard-bootstrap-token-v0." + + encodeTokenPart("name") + "." + + encodeTokenPart("token") + "." + + encodeTokenPart("not pem") + + _, err := bootstraptoken.NewFromString(rawBootstrapToken) + + require.ErrorIs(t, err, bootstraptoken.ErrInvalidBootstrapTokenFormat) +} + +func TestNewFromStringInvalidCertificate(t *testing.T) { + block := &pem.Block{ + Type: "CERTIFICATE", + Bytes: []byte("not der"), + } + rawBootstrapToken := "orchard-bootstrap-token-v0." + + encodeTokenPart("name") + "." + + encodeTokenPart("token") + "." + + encodeTokenPart(string(pem.EncodeToMemory(block))) + + _, err := bootstraptoken.NewFromString(rawBootstrapToken) + + require.ErrorIs(t, err, bootstraptoken.ErrInvalidBootstrapTokenFormat) +} + +func TestNewFromStringEmptyServiceAccountName(t *testing.T) { + rawBootstrapToken := "orchard-bootstrap-token-v0.." + encodeTokenPart("token") + + _, err := bootstraptoken.NewFromString(rawBootstrapToken) + + require.ErrorIs(t, err, bootstraptoken.ErrInvalidBootstrapTokenFormat) +} + +func TestNewFromStringEmptyServiceAccountToken(t *testing.T) { + rawBootstrapToken := "orchard-bootstrap-token-v0." + encodeTokenPart("name") + "." + + _, err := bootstraptoken.NewFromString(rawBootstrapToken) + + require.ErrorIs(t, err, bootstraptoken.ErrInvalidBootstrapTokenFormat) +} + +func encodeTokenPart(s string) string { + return base64.RawURLEncoding.EncodeToString([]byte(s)) +}