diff --git a/.gitignore b/.gitignore index c17ed53a2..313f51d5b 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ go.work.sum # Output of the go coverage tool, specifically when used with LiteIDE *.out +.gocache # Others *.swp diff --git a/api/api.go b/api/api.go index 09d2c83fb..c3e128fe9 100644 --- a/api/api.go +++ b/api/api.go @@ -257,22 +257,32 @@ func scepFromProvisioner(p *provisioner.SCEP) *models.SCEP { } } +func estFromProvisioner(p *provisioner.EST) *provisioner.EST { + prov := *p + prov.ClientCertificateRoots = []byte(redacted) + prov.BasicAuthUsername = redacted + prov.BasicAuthPassword = redacted + return &prov +} + // MarshalJSON implements json.Marshaler. It marshals the ProvisionersResponse // into a byte slice. // // Special treatment is given to the SCEP provisioner, as it contains a // challenge secret that MUST NOT be leaked in (public) HTTP responses. The -// challenge value is thus redacted in HTTP responses. +// challenge value is thus redacted in HTTP responses. EST provisioners also +// contain a shared secret and are redacted in responses. func (p ProvisionersResponse) MarshalJSON() ([]byte, error) { var responseProvisioners provisioner.List for _, item := range p.Provisioners { - scepProv, ok := item.(*provisioner.SCEP) - if !ok { + switch prov := item.(type) { + case *provisioner.SCEP: + responseProvisioners = append(responseProvisioners, scepFromProvisioner(prov)) + case *provisioner.EST: + responseProvisioners = append(responseProvisioners, estFromProvisioner(prov)) + default: responseProvisioners = append(responseProvisioners, item) - continue } - - responseProvisioners = append(responseProvisioners, scepFromProvisioner(scepProv)) } var list = struct { diff --git a/authority/admin/db.go b/authority/admin/db.go index 63940a8a3..3a5961c07 100644 --- a/authority/admin/db.go +++ b/authority/admin/db.go @@ -46,6 +46,8 @@ func UnmarshalProvisionerDetails(typ linkedca.Provisioner_Type, data []byte) (*l v.Data = new(linkedca.ProvisionerDetails_SCEP) case linkedca.Provisioner_NEBULA: v.Data = new(linkedca.ProvisionerDetails_Nebula) + case linkedca.Provisioner_EST: + v.Data = new(linkedca.ProvisionerDetails_EST) default: return nil, fmt.Errorf("unsupported provisioner type %s", typ) } diff --git a/authority/authority.go b/authority/authority.go index 98dd68968..aa2240ea2 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -33,6 +33,7 @@ import ( "github.com/smallstep/certificates/cas" casapi "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/est" "github.com/smallstep/certificates/internal/httptransport" "github.com/smallstep/certificates/scep" "github.com/smallstep/certificates/templates" @@ -70,6 +71,11 @@ type Authority struct { scepAuthority *scep.Authority scepKeyManager provisioner.SCEPKeyManager + // EST CA + estOptions *est.Options + validateEST bool + estAuthority *est.Authority + // SSH CA sshHostPassword []byte sshUserPassword []byte @@ -133,6 +139,7 @@ func New(cfg *config.Config, opts ...Option) (*Authority, error) { config: cfg, certificates: new(sync.Map), validateSCEP: true, + validateEST: true, meter: noopMeter{}, wrapTransport: httptransport.NoopWrapper(), } @@ -170,6 +177,7 @@ func NewEmbedded(opts ...Option) (*Authority, error) { certificates: new(sync.Map), meter: noopMeter{}, wrapTransport: httptransport.NoopWrapper(), + validateEST: true, } // Apply options. @@ -314,6 +322,11 @@ func (a *Authority) ReloadAdminResources(ctx context.Context) error { // TODO(hs): don't remove the authority if we can't also // reload it. //a.scepAuthority = nil + case a.requiresEST() && a.GetEST() != nil: + a.estAuthority.UpdateProvisioners(a.getESTProvisionerNames()) + if err := a.estAuthority.Validate(); err != nil { + log.Printf("failed validating EST authority: %v\n", err) + } } return nil @@ -814,6 +827,51 @@ func (a *Authority) init() error { } } + // EST functionality is provided through an instance of est.Authority. + switch { + case a.requiresEST() && a.GetEST() == nil: + if a.estOptions == nil { + options := &est.Options{ + Roots: a.rootX509Certs, + Intermediates: a.intermediateX509Certs, + } + if len(a.intermediateX509Certs) > 0 { + options.SignerCert = a.intermediateX509Certs[0] + } + if a.config.IntermediateKey != "" { + if signer, err := a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{ + SigningKey: a.config.IntermediateKey, + Password: a.password, + }); err == nil { + options.Signer = signer + } + } + a.estOptions = options + } + + a.estOptions.ESTProvisionerNames = a.getESTProvisionerNames() + + estAuthority, err := est.New(a, *a.estOptions) + if err != nil { + return err + } + + if a.validateEST { + if err := estAuthority.Validate(); err != nil { + a.initLogf("failed validating EST authority: %v", err) + } + } + + a.estAuthority = estAuthority + case !a.requiresEST() && a.GetEST() != nil: + a.estAuthority = nil + case a.requiresEST() && a.GetEST() != nil: + a.estAuthority.UpdateProvisioners(a.getESTProvisionerNames()) + if err := a.estAuthority.Validate(); err != nil { + log.Printf("failed validating EST authority: %v\n", err) + } + } + // Load X509 constraints engine. // // This is currently only available in CA mode. @@ -1002,7 +1060,19 @@ func (a *Authority) GetSCEP() *scep.Authority { return a.scepAuthority } -// HasACMEProvisioner returns true if at least one ACME provisioner is configured. +// requiresEST iterates over the configured provisioners +// and determines if at least one of them is an EST provisioner. +func (a *Authority) requiresEST() bool { + for _, p := range a.config.AuthorityConfig.Provisioners { + if p.GetType() == provisioner.TypeEST { + return true + } + } + return false +} + +// HasACMEProvisioner iterates over the configured provisioners +// and determines if at least one of them is an ACME provisioner. func (a *Authority) HasACMEProvisioner() bool { for _, p := range a.config.AuthorityConfig.Provisioners { if p.GetType() == provisioner.TypeACME { @@ -1012,6 +1082,23 @@ func (a *Authority) HasACMEProvisioner() bool { return false } +// getESTProvisionerNames returns the names of the EST provisioners +// that are currently available in the CA. +func (a *Authority) getESTProvisionerNames() (names []string) { + for _, p := range a.config.AuthorityConfig.Provisioners { + if p.GetType() == provisioner.TypeEST { + names = append(names, p.GetName()) + } + } + + return +} + +// GetEST returns the configured EST Authority +func (a *Authority) GetEST() *est.Authority { + return a.estAuthority +} + func (a *Authority) startCRLGenerator() error { if !a.config.CRL.IsEnabled() { return nil diff --git a/authority/options.go b/authority/options.go index d85a611dc..e840ae5b6 100644 --- a/authority/options.go +++ b/authority/options.go @@ -18,6 +18,7 @@ import ( casapi "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/internal/httptransport" + "github.com/smallstep/certificates/est" "github.com/smallstep/certificates/scep" ) @@ -242,6 +243,17 @@ func WithFullSCEPOptions(options *scep.Options) Option { } } +// WithFullESTOptions defines the options used for EST support. +// +// This feature is EXPERIMENTAL and might change at any time. +func WithFullESTOptions(options *est.Options) Option { + return func(a *Authority) error { + a.estOptions = options + a.validateEST = false + return nil + } +} + // WithSCEPKeyManager defines the key manager used on SCEP provisioners. // // This feature is EXPERIMENTAL and might change at any time. diff --git a/authority/provisioner/est.go b/authority/provisioner/est.go new file mode 100644 index 000000000..152a31b1c --- /dev/null +++ b/authority/provisioner/est.go @@ -0,0 +1,192 @@ +package provisioner + +import ( + "context" + "crypto" + "crypto/x509" + "fmt" + "time" + + "github.com/pkg/errors" + + "github.com/smallstep/certificates/internal/httptransport" + "github.com/smallstep/linkedca" +) + +// EST is the EST provisioner type, an entity that can authorize the EST flow. +type EST struct { + *base + ID string `json:"-"` + Type string `json:"type"` + Name string `json:"name"` + EnableTLSClientCertificate *bool `json:"enableTlsClientCertificate,omitempty"` + ForwardedTLSClientCertHeader string `json:"forwardedTlsClientCertHeader,omitempty"` + EnableHTTPBasicAuth *bool `json:"enableHTTPBasicAuth,omitempty"` + BasicAuthUsername string `json:"basicAuthUsername,omitempty"` + BasicAuthPassword string `json:"basicAuthPassword,omitempty"` + ClientCertificateRoots []byte `json:"clientCertificateRoots,omitempty"` + ForceCN bool `json:"forceCN,omitempty"` + IncludeRoot bool `json:"includeRoot,omitempty"` + ExcludeIntermediate bool `json:"excludeIntermediate,omitempty"` + MinimumPublicKeyLength int `json:"minimumPublicKeyLength,omitempty"` + CSRAttrs []byte `json:"csrAttrs,omitempty"` + Options *Options `json:"options,omitempty"` + Claims *Claims `json:"claims,omitempty"` + ctl *Controller + signer crypto.Signer + signerCertificate *x509.Certificate + challengeValidationController *challengeValidationController + notificationController *notificationController + clientCertificateRootPool *x509.CertPool +} + +// GetID returns the provisioner unique identifier. +func (s *EST) GetID() string { + if s.ID != "" { + return s.ID + } + return s.GetIDForToken() +} + +// GetIDForToken returns an identifier that will be used to load the provisioner from a token. +func (s *EST) GetIDForToken() string { + return "est/" + s.Name +} + +// GetName returns the name of the provisioner. +func (s *EST) GetName() string { + return s.Name +} + +// GetType returns the type of provisioner. +func (s *EST) GetType() Type { + return TypeEST +} + +// GetEncryptedKey returns the base provisioner encrypted key if it's defined. +func (s *EST) GetEncryptedKey() (string, string, bool) { + return "", "", false +} + +// GetTokenID returns the identifier of the token. This provisioner does not support tokens. +func (s *EST) GetTokenID(string) (string, error) { + return "", ErrTokenFlowNotSupported +} + +// GetOptions returns the configured provisioner options. +func (s *EST) GetOptions() *Options { + return s.Options +} + +// DefaultTLSCertDuration returns the default TLS cert duration enforced by the provisioner. +func (s *EST) DefaultTLSCertDuration() time.Duration { + return s.ctl.Claimer.DefaultTLSCertDuration() +} + +// newChallengeValidationController creates a new challengeValidationController +// that performs challenge validation through webhooks. +func newESTChallengeValidationController(client HTTPClient, tw httptransport.Wrapper, webhooks []*Webhook) *challengeValidationController { + estHooks := []*Webhook{} + for _, wh := range webhooks { + // if wh.Kind != linkedca.Webhook_ESTCHALLENGE.String() { + if wh.Kind != "ESTCHALLENGE" { + continue + } + estHooks = append(estHooks, wh) + } + return &challengeValidationController{ + client: client, + wrapTransport: tw, + webhooks: estHooks, + } +} + +// Init initializes and validates the fields of an EST type. +func (s *EST) Init(config Config) (err error) { + switch { + case s.Type == "": + return errors.New("provisioner type cannot be empty") + case s.Name == "": + return errors.New("provisioner name cannot be empty") + } + + if s.MinimumPublicKeyLength == 0 { + s.MinimumPublicKeyLength = 2048 + } + if s.MinimumPublicKeyLength%8 != 0 { + return errors.Errorf("%d bits is not exactly divisible by 8", s.MinimumPublicKeyLength) + } + + // Prepare the EST challenge validator + s.challengeValidationController = newESTChallengeValidationController( + config.WebhookClient, + config.WrapTransport, + s.GetOptions().GetWebhooks(), + ) + + // Prepare the EST notification controller + s.notificationController = newNotificationController( + config.WebhookClient, + config.WrapTransport, + s.GetOptions().GetWebhooks(), + ) + + if err := s.parseClientCertificateRoots(); err != nil { + return err + } + + if err := s.normalizeAuthConfig(); err != nil { + return err + } + + s.ctl, err = NewController(s, s.Claims, config, s.Options) + return err +} + +// AuthorizeSign does not do any verification; main validation is in the EST protocol. +func (s *EST) AuthorizeSign(context.Context, string) ([]SignOption, error) { + return []SignOption{ + s, + newProvisionerExtensionOption(TypeEST, s.Name, "").WithControllerOptions(s.ctl), + newForceCNOption(s.ForceCN), + profileDefaultDuration(s.ctl.Claimer.DefaultTLSCertDuration()), + newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength), + newValidityValidator(s.ctl.Claimer.MinTLSCertDuration(), s.ctl.Claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(s.ctl.getPolicy().getX509()), + s.ctl.newWebhookController(nil, linkedca.Webhook_X509), + }, nil +} + +// ShouldIncludeRootInChain indicates if the CA should return its root in the chain. +func (s *EST) ShouldIncludeRootInChain() bool { + return s.IncludeRoot +} + +// ShouldIncludeIntermediateInChain indicates if the CA should include the intermediate CA certificate. +func (s *EST) ShouldIncludeIntermediateInChain() bool { + return !s.ExcludeIntermediate +} + +// GetSigner returns the provisioner specific signer, used to sign EST responses. +func (s *EST) GetSigner() (*x509.Certificate, crypto.Signer) { + return s.signerCertificate, s.signer +} + +// GetCSRAttributes returns the CSR attributes to signal to clients. +func (s *EST) GetCSRAttributes(context.Context) ([]byte, error) { + return s.CSRAttrs, nil +} + +func (s *EST) NotifySuccess(ctx context.Context, csr *x509.CertificateRequest, cert *x509.Certificate, transactionID string) error { + if s.notificationController == nil { + return fmt.Errorf("provisioner %q wasn't initialized", s.Name) + } + return s.notificationController.Success(ctx, csr, cert, transactionID) +} + +func (s *EST) NotifyFailure(ctx context.Context, csr *x509.CertificateRequest, transactionID string, errorCode int, errorDescription string) error { + if s.notificationController == nil { + return fmt.Errorf("provisioner %q wasn't initialized", s.Name) + } + return s.notificationController.Failure(ctx, csr, transactionID, errorCode, errorDescription) +} diff --git a/authority/provisioner/est_auth.go b/authority/provisioner/est_auth.go new file mode 100644 index 000000000..9543e7813 --- /dev/null +++ b/authority/provisioner/est_auth.go @@ -0,0 +1,272 @@ +package provisioner + +import ( + "context" + "crypto/subtle" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + + "github.com/smallstep/certificates/webhook" + "go.step.sm/crypto/x509util" +) + +var ( + ErrESTAuthMethodDisabled = errors.New("est authentication method disabled") + ErrESTAuthMethodNotFound = errors.New("no valid est authentication method found") + ErrESTAuthMethodMisconfigured = errors.New("est authentication method misconfigured") + ErrESTAuthDenied = errors.New("est authentication denied") +) + +// ClientCertificateConfig holds the EST client certificate authentication configuration. +type ClientCertificateConfig struct { + Enable bool + ForwardedTLSClientCertHeader string +} + +// ESTAuthRequest contains authentication material extracted from the request. +type ESTAuthRequest struct { + CSR *x509.CertificateRequest + ClientCertificate *x509.Certificate + ClientCertificateChain []*x509.Certificate + CARoots []*x509.Certificate + CAIntermediates []*x509.Certificate + AuthenticationHeader string + BasicAuthUsername string + BasicAuthPassword string + BearerToken string +} + +func (s *EST) GetClientCertificateConfig() *ClientCertificateConfig { + return &ClientCertificateConfig{ + Enable: boolValue(s.EnableTLSClientCertificate, false), + ForwardedTLSClientCertHeader: s.ForwardedTLSClientCertHeader, + } +} + +// AuthorizeRequest validates the request against configured EST auth methods. +func (s *EST) AuthorizeRequest(ctx context.Context, req ESTAuthRequest) ([]SignCSROption, error) { + if s.hasAuthWebhooks() { + return s.authorizeWithWebhook(ctx, &req) + } + return s.authorizeRequestLocal(req) +} + +// authorizeRequestLocal validates the request using provisioner configuration. +func (s *EST) authorizeRequestLocal(req ESTAuthRequest) ([]SignCSROption, error) { + var lastErr error = ErrESTAuthMethodNotFound + if req.ClientCertificate != nil { + if boolValue(s.EnableTLSClientCertificate, false) { + if s.hasClientCertificateRoots() { + if err := verifyCertificateWithPool(req.ClientCertificate, req.ClientCertificateChain, s.clientCertificateRootPool, nil); err == nil { + return []SignCSROption{}, nil + } else { + lastErr = err + } + } else { + if err := verifyCertificate(req.ClientCertificate, req.ClientCertificateChain, req.CARoots, req.CAIntermediates); err == nil { + return []SignCSROption{}, nil + } else { + lastErr = err + } + } + } else { + lastErr = ErrESTAuthMethodDisabled + } + } + + if req.hasBasicAuth() { + if boolValue(s.EnableHTTPBasicAuth, false) && s.BasicAuthPassword != "" { + if err := s.validateBasicAuthPassword(req.BasicAuthUsername, req.BasicAuthPassword); err == nil { + return []SignCSROption{}, nil + } else { + lastErr = err + } + } else { + lastErr = ErrESTAuthMethodDisabled + } + } + + return nil, lastErr +} + +// validateBasicAuthPassword verifies the configured basic auth password. +func (s *EST) validateBasicAuthPassword(username, password string) error { + if s.BasicAuthUsername != "" && username != s.BasicAuthUsername { + return errors.New("invalid basic auth") + } + if subtleCompare(s.BasicAuthPassword, password) { + return nil + } + return errors.New("invalid basic auth") +} + +// authorizeWithWebhook executes configured webhooks for auth decisions. +func (s *EST) authorizeWithWebhook(ctx context.Context, req *ESTAuthRequest) ([]SignCSROption, error) { + if !s.hasAuthWebhooks() { + return nil, ErrESTAuthMethodMisconfigured + } + + var ( + whreq *webhook.RequestBody + err error + ) + switch { + case req.ClientCertificate != nil: + whreq, err = webhook.NewRequestBody(webhook.WithX509CertificateRequest(req.CSR), webhook.WithClientCertificate(req.ClientCertificate)) + if err != nil { + return nil, fmt.Errorf("failed creating webhook request: %w", err) + } + case req.AuthenticationHeader != "": + whreq, err = webhook.NewRequestBody(webhook.WithX509CertificateRequest(req.CSR), webhook.WithAuthenticationHeader(req.AuthenticationHeader)) + if err != nil { + return nil, fmt.Errorf("failed creating webhook request: %w", err) + } + if req.BearerToken != "" { + whreq.BearerToken = req.BearerToken + } + default: + return nil, errors.New("missing certificate or basic auth for webhook validation") + } + whreq.ProvisionerName = s.Name + var opts []SignCSROption + + for _, wh := range s.challengeValidationController.webhooks { + resp, err := wh.DoWithContext(ctx, s.challengeValidationController.client, s.challengeValidationController.wrapTransport, whreq, nil) + if err != nil { + return nil, fmt.Errorf("failed executing webhook request: %w", err) + } + if resp.Allow { + opts = append(opts, TemplateDataModifierFunc(func(data x509util.TemplateData) { + data.SetWebhook(wh.Name, resp.Data) + })) + } + } + + if len(opts) == 0 { + return nil, ErrESTAuthDenied + } + + return opts, nil +} + +// hasBasicAuth reports whether any basic auth data is present. +func (r ESTAuthRequest) hasBasicAuth() bool { + return r.BasicAuthUsername != "" || r.BasicAuthPassword != "" +} + +// hasAuthWebhooks reports whether auth webhooks are configured. +func (s *EST) hasAuthWebhooks() bool { + return s.challengeValidationController != nil && len(s.challengeValidationController.webhooks) > 0 +} + +// normalizeAuthConfig applies defaults and validates auth configuration. +func (s *EST) normalizeAuthConfig() error { + enable := true + if !s.authMethodsConfigured() { + s.EnableHTTPBasicAuth = &enable + } + if s.EnableHTTPBasicAuth == nil && (s.BasicAuthUsername != "" || s.BasicAuthPassword != "") { + s.EnableHTTPBasicAuth = &enable + } + if boolValue(s.EnableHTTPBasicAuth, false) && s.BasicAuthPassword == "" && !s.hasAuthWebhooks() { + return errors.New("basic auth password cannot be empty") + } + return nil +} + +// authMethodsConfigured reports whether any auth method is explicitly configured. +func (s *EST) authMethodsConfigured() bool { + return s.EnableTLSClientCertificate != nil || + s.hasClientCertificateRoots() || + s.EnableHTTPBasicAuth != nil +} + +// parseClientCertificateRoots loads external client certificate roots. +func (s *EST) parseClientCertificateRoots() error { + if len(s.ClientCertificateRoots) == 0 { + return nil + } + var ( + block *pem.Block + hasCert bool + rest = s.ClientCertificateRoots + ) + s.clientCertificateRootPool = x509.NewCertPool() + for rest != nil { + block, rest = pem.Decode(rest) + if block == nil { + break + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return errors.New("error parsing clientCertificateRoots: malformed certificate") + } + s.clientCertificateRootPool.AddCert(cert) + hasCert = true + } + if !hasCert { + return errors.New("error parsing clientCertificateRoots: no certificates found") + } + return nil +} + +func (s *EST) hasClientCertificateRoots() bool { + return len(s.ClientCertificateRoots) > 0 +} + +// verifyCertificate validates the client certificate against CA roots. +func verifyCertificate(cert *x509.Certificate, chain, roots, intermediates []*x509.Certificate) error { + rootPool := x509.NewCertPool() + for _, root := range roots { + if root != nil { + rootPool.AddCert(root) + } + } + intermediatePool := x509.NewCertPool() + for _, intermediate := range intermediates { + if intermediate != nil { + intermediatePool.AddCert(intermediate) + } + } + return verifyCertificateWithPool(cert, chain, rootPool, intermediatePool) +} + +// verifyCertificateWithPool validates the client certificate using explicit pools. +func verifyCertificateWithPool(cert *x509.Certificate, chain []*x509.Certificate, roots, intermediates *x509.CertPool) error { + if intermediates == nil { + intermediates = x509.NewCertPool() + } + for i, intermediate := range chain { + if i == 0 || intermediate == nil { + continue + } + intermediates.AddCert(intermediate) + } + _, err := cert.Verify(x509.VerifyOptions{ + Roots: roots, + Intermediates: intermediates, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + if err != nil { + return fmt.Errorf("invalid client certificate: %w", err) + } + return nil +} + +// boolValue returns the dereferenced value or a default. +func boolValue(value *bool, defaultValue bool) bool { + if value == nil { + return defaultValue + } + return *value +} + +// subtleCompare compares secrets in constant time. +func subtleCompare(expected, actual string) bool { + if len(expected) != len(actual) { + return false + } + return subtle.ConstantTimeCompare([]byte(expected), []byte(actual)) == 1 +} diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index 33d75fe9a..95c916436 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -208,6 +208,8 @@ const ( TypeSCEP Type = 10 // TypeNebula is used to indicate the Nebula provisioners TypeNebula Type = 11 + // TypeEST is used to indicate the EST provisioners + TypeEST Type = 12 ) // String returns the string representation of the type. @@ -235,6 +237,8 @@ func (t Type) String() string { return "SCEP" case TypeNebula: return "Nebula" + case TypeEST: + return "EST" default: return "" } @@ -328,6 +332,8 @@ func (l *List) UnmarshalJSON(data []byte) error { p = &SCEP{} case "nebula": p = &Nebula{} + case "est": + p = &EST{} default: // Skip unsupported provisioners. A client using this method may be // compiled with a version of smallstep/certificates that does not diff --git a/authority/provisioners.go b/authority/provisioners.go index d01048564..59cc65752 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -1003,6 +1003,26 @@ func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface, s.DecrypterKeyPassword = string(decrypter.KeyPassword) } return s, nil + case *linkedca.ProvisionerDetails_EST: + cfg := d.EST + enableTLSClientCertificate := cfg.EnableTlsClientCertificate + enableHTTPBasicAuth := cfg.EnableHttpBasicAuth + return &provisioner.EST{ + ID: p.Id, + Type: p.Type.String(), + Name: p.Name, + EnableTLSClientCertificate: &enableTLSClientCertificate, + EnableHTTPBasicAuth: &enableHTTPBasicAuth, + ForwardedTLSClientCertHeader: cfg.ForwardedTlsClientCertHeader, + BasicAuthUsername: cfg.BasicAuthUsername, + BasicAuthPassword: cfg.BasicAuthPassword, + ClientCertificateRoots: provisionerPEMToCertificates(cfg.ClientCertificateRoots), + ForceCN: cfg.ForceCn, + IncludeRoot: cfg.IncludeRoot, + MinimumPublicKeyLength: int(cfg.MinimumPublicKeyLength), + Claims: claims, + Options: options, + }, nil case *linkedca.ProvisionerDetails_Nebula: var roots []byte for i, root := range d.Nebula.GetRoots() { @@ -1278,6 +1298,35 @@ func ProvisionerToLinkedca(p provisioner.Interface) (*linkedca.Provisioner, erro SshTemplate: sshTemplate, Webhooks: webhooks, }, nil + case *provisioner.EST: + x509Template, sshTemplate, webhooks, err := provisionerOptionsToLinkedca(p.Options) + if err != nil { + return nil, err + } + return &linkedca.Provisioner{ + Id: p.ID, + Type: linkedca.Provisioner_EST, + Name: p.GetName(), + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_EST{ + EST: &linkedca.ESTProvisioner{ + ForceCn: p.ForceCN, + EnableTlsClientCertificate: p.EnableTLSClientCertificate != nil && *p.EnableTLSClientCertificate, + EnableHttpBasicAuth: p.EnableHTTPBasicAuth != nil && *p.EnableHTTPBasicAuth, + MinimumPublicKeyLength: cast.Int32(p.MinimumPublicKeyLength), + IncludeRoot: p.IncludeRoot, + BasicAuthUsername: p.BasicAuthUsername, + BasicAuthPassword: p.BasicAuthPassword, + ClientCertificateRoots: provisionerPEMToLinkedca(p.ClientCertificateRoots), + ForwardedTlsClientCertHeader: p.ForwardedTLSClientCertHeader, + }, + }, + }, + Claims: claimsToLinkedca(p.Claims), + X509Template: x509Template, + SshTemplate: sshTemplate, + Webhooks: webhooks, + }, nil case *provisioner.Nebula: x509Template, sshTemplate, webhooks, err := provisionerOptionsToLinkedca(p.Options) if err != nil { diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index da37eee58..056a44968 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -54,7 +54,7 @@ func startCABootstrapServer() *httptest.Server { if err != nil { panic(err) } - baseContext := buildContext(ca.auth, nil, nil, nil) + baseContext := buildContext(ca.auth, nil, nil, nil, nil) srv.Config.Handler = ca.srv.Handler srv.Config.BaseContext = func(net.Listener) context.Context { return baseContext diff --git a/ca/ca.go b/ca/ca.go index 3f0704a0a..ed62eeebf 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -39,6 +39,8 @@ import ( "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/middleware/requestid" "github.com/smallstep/certificates/monitoring" + "github.com/smallstep/certificates/est" + estAPI "github.com/smallstep/certificates/est/api" "github.com/smallstep/certificates/scep" scepAPI "github.com/smallstep/certificates/scep/api" "github.com/smallstep/certificates/server" @@ -300,6 +302,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { } var scepAuthority *scep.Authority + var estAuthority *est.Authority if ca.shouldServeSCEPEndpoints() { // get the SCEP authority configuration. Validation is // performed within the authority instantiation process. @@ -328,6 +331,15 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { //dumpRoutes(mux) //dumpRoutes(insecureMux) + // EST endpoints (HTTPS only) + estPrefix := ".well-known/est" + if estAuth := auth.GetEST(); estAuth != nil { + estAuthority = estAuth + mux.Route("/"+estPrefix, func(r chi.Router) { + estAPI.Route(r) + }) + } + // Add monitoring if configured if len(cfg.Monitoring) > 0 { m, err := monitoring.New(cfg.Monitoring) @@ -355,7 +367,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { insecureHandler = requestid.New(legacyTraceHeader).Middleware(insecureHandler) // Create context with all the necessary values. - baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker) + baseContext := buildContext(auth, scepAuthority, estAuthority, acmeDB, acmeLinker) ca.srv = server.New(cfg.Address, handler, tlsConfig) ca.srv.BaseContext = func(net.Listener) context.Context { @@ -403,7 +415,7 @@ func (ca *CA) shouldServeInsecureServer() bool { } // buildContext builds the server base context. -func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker) context.Context { +func buildContext(a *authority.Authority, scepAuthority *scep.Authority, estAuthority *est.Authority, acmeDB acme.DB, acmeLinker acme.Linker) context.Context { ctx := authority.NewContext(context.Background(), a) if authDB := a.GetDatabase(); authDB != nil { ctx = db.NewContext(ctx, authDB) @@ -414,6 +426,9 @@ func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB if scepAuthority != nil { ctx = scep.NewContext(ctx, scepAuthority) } + if estAuthority != nil { + ctx = est.NewContext(ctx, estAuthority) + } if acmeDB != nil { ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil) } diff --git a/ca/tls_test.go b/ca/tls_test.go index 465f1ede2..885b895ae 100644 --- a/ca/tls_test.go +++ b/ca/tls_test.go @@ -78,7 +78,7 @@ func startCATestServer(t *testing.T) *httptest.Server { ca, err := New(config) require.NoError(t, err) // Use a httptest.Server instead - baseContext := buildContext(ca.auth, nil, nil, nil) + baseContext := buildContext(ca.auth, nil, nil, nil, nil) srv := startTestServer(baseContext, ca.srv.TLSConfig, ca.srv.Handler) return srv } diff --git a/est/api/api.go b/est/api/api.go new file mode 100644 index 000000000..d2589f723 --- /dev/null +++ b/est/api/api.go @@ -0,0 +1,345 @@ +// Package api implements an EST HTTP server. +package api + +import ( + "context" + "crypto/x509" + "encoding/base64" + "errors" + "fmt" + "io" + "mime" + "net/http" + "net/url" + "strings" + + "github.com/go-chi/chi/v5" + + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/log" + "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/est" +) + +const ( + maxPayloadSize = 2 << 20 +) + +// Util to extract bearer token from request +func BearerToken(r *http.Request) (string, bool) { + auth := r.Header.Get("Authorization") + const prefix = "Bearer " + // Case insensitive prefix match. See Issue 22736. + if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) { + return "", false + } + return auth[len(prefix):], true +} + +// Route configures the EST routes under the provided router. +func Route(r api.Router) { + r.MethodFunc(http.MethodGet, "/{provisionerName}/cacerts", getCACerts) + r.MethodFunc(http.MethodGet, "/{provisionerName}/csrattrs", getCSRAttrs) + r.MethodFunc(http.MethodPost, "/{provisionerName}/simpleenroll", enroll) + r.MethodFunc(http.MethodPost, "/{provisionerName}/simplereenroll", enroll) +} + +func lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + name := chi.URLParam(r, "provisionerName") + if name == "" || name == "/" { + name = r.URL.Query().Get("provisioner") + } + if name == "" { + fail(w, r, errors.New("missing provisioner name")) + return + } + provisionerName, err := url.PathUnescape(name) + if err != nil { + fail(w, r, fmt.Errorf("error url unescaping provisioner name '%s'", name)) + return + } + + ctx := r.Context() + auth := authority.MustFromContext(ctx) + p, err := auth.LoadProvisionerByName(provisionerName) + if err != nil { + fail(w, r, err) + return + } + + prov, ok := p.(*provisioner.EST) + if !ok { + fail(w, r, errors.New("provisioner must be of type EST")) + return + } + + ctx = est.NewProvisionerContext(ctx, est.Provisioner(prov)) + next(w, r.WithContext(ctx)) + } +} + +func getCACerts(w http.ResponseWriter, r *http.Request) { + lookupProvisioner(getCACertsHandler)(w, r) +} + +func getCACertsHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + auth := est.MustFromContext(ctx) + + certs, err := auth.GetCACertificates(ctx) + if err != nil { + fail(w, r, fmt.Errorf("failed to get CA certificates: %w", err)) + return + } + + data, err := auth.BuildResponse(ctx, certs) + if err != nil { + fail(w, r, fmt.Errorf("failed to encode CA certificates: %w", err)) + return + } + + writeResponse(w, data, "application/pkcs7-mime; smime-type=certs-only", http.StatusOK) +} + +func getCSRAttrs(w http.ResponseWriter, r *http.Request) { + lookupProvisioner(getCSRAttrsHandler)(w, r) +} + +func getCSRAttrsHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + prov := est.ProvisionerFromContext(ctx) + + attrs, err := prov.GetCSRAttributes(ctx) + if err != nil { + fail(w, r, fmt.Errorf("failed to get CSR attributes: %w", err)) + return + } + if attrs == nil { + attrs = []byte{} + } + // Minimal implementation: allow provisioner to return nil/empty for "no attributes". + writeResponse(w, attrs, "application/csrattrs", http.StatusOK) +} + +func enroll(w http.ResponseWriter, r *http.Request) { + lookupProvisioner(enrollHandler)(w, r) +} + +func enrollHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ctx, err := authContextFromRequest(ctx, r) + if err != nil { + failWithStatus(w, r, http.StatusUnauthorized, err) + return + } + + body, err := io.ReadAll(io.LimitReader(r.Body, maxPayloadSize)) + if err != nil { + failWithStatus(w, r, http.StatusBadRequest, fmt.Errorf("failed reading request body: %w", err)) + return + } + + if err := requireContentType(r, "application/pkcs10"); err != nil { + failWithStatus(w, r, http.StatusUnsupportedMediaType, err) + return + } + + der, err := decodeBase64Payload(body) + if err != nil { + failWithStatus(w, r, http.StatusBadRequest, err) + return + } + + csr, err := parseCSR(der) + if err != nil { + failWithStatus(w, r, http.StatusBadRequest, fmt.Errorf("failed parsing CSR: %w", err)) + return + } + if err := csr.CheckSignature(); err != nil { + failWithStatus(w, r, http.StatusBadRequest, fmt.Errorf("invalid CSR signature: %w", err)) + return + } + + opts, err := authorizeEnrollRequest(ctx, csr) + if err != nil { + failWithStatus(w, r, http.StatusUnauthorized, err) + return + } + + r = r.WithContext(ctx) + auth := est.MustFromContext(ctx) + + issued, err := auth.SignCSR(ctx, csr, opts...) + if err != nil { + failWithStatus(w, r, http.StatusInternalServerError, fmt.Errorf("failed issuing certificate: %w", err)) + return + } + + signed, err := auth.BuildResponse(ctx, []*x509.Certificate{issued}) + if err != nil { + failWithStatus(w, r, http.StatusInternalServerError, fmt.Errorf("failed encoding issued certificate: %w", err)) + return + } + + writeResponse(w, signed, "application/pkcs7-mime; smime-type=certs-only", http.StatusOK) +} + +var errMissingAuth = errors.New("missing authentication material") + +// authContextFromRequest extracts auth material from the request into the context. +func authContextFromRequest(ctx context.Context, r *http.Request) (context.Context, error) { + if r.TLS == nil { + return ctx, errors.New("missing TLS connection") + } + prov := est.ProvisionerFromContext(ctx) + cfg := prov.GetClientCertificateConfig() + + if cfg.Enable { + if cfg.ForwardedTLSClientCertHeader != "" { + // When a forwarded header is configured, only use it — never + // fall back to r.TLS.PeerCertificates, which would be the + // proxy's own certificate, not the actual client's. + if forwardedtlsClientCert := r.Header.Get(cfg.ForwardedTLSClientCertHeader); forwardedtlsClientCert != "" { + certDER, err := base64.StdEncoding.DecodeString(forwardedtlsClientCert) + if err != nil { + return ctx, fmt.Errorf("failed to decode client certificate from forwarded header: %w", err) + } + certs, err := x509.ParseCertificates(certDER) + if err != nil { + return ctx, fmt.Errorf("failed to parse client certificate from forwarded header: %w", err) + } + if len(certs) == 0 { + return ctx, errors.New("no certificates found in forwarded header") + } + ctx = est.NewClientCertificateContext(ctx, certs[0]) + ctx = est.NewClientCertificateChainContext(ctx, certs) + } + } else if len(r.TLS.PeerCertificates) > 0 { + ctx = est.NewClientCertificateContext(ctx, r.TLS.PeerCertificates[0]) + ctx = est.NewClientCertificateChainContext(ctx, r.TLS.PeerCertificates) + } + } + + if authHeader := r.Header.Get("Authorization"); authHeader != "" { + ctx = est.NewAuthenticationHeaderContext(ctx, authHeader) + if token, ok := BearerToken(r); ok { + ctx = est.NewBearerTokenContext(ctx, token) + } else if username, password, ok := r.BasicAuth(); ok { + ctx = est.NewBasicAuthContext(ctx, est.BasicAuth{ + Username: username, + Password: password, + }) + } + } + + if _, ok := est.ClientCertificateFromContext(ctx); !ok { + if _, ok := est.AuthenticationHeaderFromContext(ctx); !ok { + return ctx, errMissingAuth + } + } + return ctx, nil +} + +// authorizeEnrollRequest validates the request against provisioner-configured auth methods. +func authorizeEnrollRequest(ctx context.Context, csr *x509.CertificateRequest) ([]provisioner.SignCSROption, error) { + prov := est.ProvisionerFromContext(ctx) + ca := authority.MustFromContext(ctx) + + req := provisioner.ESTAuthRequest{ + CSR: csr, + CARoots: ca.GetRootCertificates(), + CAIntermediates: ca.GetIntermediateCertificates(), + } + if cert, ok := est.ClientCertificateFromContext(ctx); ok { + req.ClientCertificate = cert + req.ClientCertificateChain, _ = est.ClientCertificateChainFromContext(ctx) + } + if authHeader, ok := est.AuthenticationHeaderFromContext(ctx); ok { + req.AuthenticationHeader = authHeader + if auth, ok := est.BasicAuthFromContext(ctx); ok { + req.BasicAuthUsername = auth.Username + req.BasicAuthPassword = auth.Password + } + if token, ok := est.BearerTokenFromContext(ctx); ok { + req.BearerToken = token + } + } + + opts, err := prov.AuthorizeRequest(ctx, req) + if err != nil { + return nil, err + } + return opts, nil +} + +func parseCSR(body []byte) (*x509.CertificateRequest, error) { + if len(body) == 0 { + return nil, errors.New("empty body") + } + + return x509.ParseCertificateRequest(body) +} + +func decodeBase64Payload(body []byte) ([]byte, error) { + if len(body) == 0 { + return nil, errors.New("empty body") + } + + trimmed := strings.Map(func(r rune) rune { + switch r { + case ' ', '\n', '\r', '\t': + return -1 + default: + return r + } + }, string(body)) + + if trimmed == "" { + return nil, errors.New("empty base64 payload") + } + + decoded := make([]byte, base64.StdEncoding.DecodedLen(len(trimmed))) + n, err := base64.StdEncoding.Decode(decoded, []byte(trimmed)) + if err != nil { + return nil, fmt.Errorf("invalid base64 payload: %w", err) + } + + return decoded[:n], nil +} + +func requireContentType(r *http.Request, want string) error { + ct := r.Header.Get("Content-Type") + if ct == "" { + return errors.New("missing Content-Type header") + } + mt, _, err := mime.ParseMediaType(ct) + if err != nil { + return fmt.Errorf("invalid Content-Type header: %w", err) + } + if mt != want { + return fmt.Errorf("unsupported Content-Type %q", mt) + } + return nil +} + +func writeResponse(w http.ResponseWriter, data []byte, contentType string, status int) { + w.Header().Set("Content-Type", contentType) + w.Header().Set("Content-Transfer-Encoding", "base64") + w.WriteHeader(status) + + encoder := base64.NewEncoder(base64.StdEncoding, w) + _, _ = encoder.Write(data) + _ = encoder.Close() +} + +func fail(w http.ResponseWriter, r *http.Request, err error) { + log.Error(w, r, err) + http.Error(w, err.Error(), http.StatusInternalServerError) +} + +func failWithStatus(w http.ResponseWriter, r *http.Request, status int, err error) { + log.Error(w, r, err) + http.Error(w, err.Error(), status) +} diff --git a/est/api/api_test.go b/est/api/api_test.go new file mode 100644 index 000000000..ecaebe564 --- /dev/null +++ b/est/api/api_test.go @@ -0,0 +1,71 @@ +package api + +import ( + "encoding/base64" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_writeResponse(t *testing.T) { + type args struct { + w http.ResponseWriter + r *http.Request + data []byte + contentType string + status int + } + tests := []struct { + name string + args args + wantBody string + wantHeaders map[string]string + }{ + { + name: "ok", + args: args{ + w: httptest.NewRecorder(), + r: httptest.NewRequest("GET", "/", nil), + data: []byte("hello world"), + contentType: "application/pkcs7-mime; smime-type=certs-only", + status: http.StatusOK, + }, + wantBody: base64.StdEncoding.EncodeToString([]byte("hello world")), + wantHeaders: map[string]string{ + "Content-Type": "application/pkcs7-mime; smime-type=certs-only", + "Content-Transfer-Encoding": "base64", + }, + }, + { + name: "ok/csrattrs", + args: args{ + w: httptest.NewRecorder(), + r: httptest.NewRequest("GET", "/", nil), + data: []byte("attribute data"), + contentType: "application/csrattrs", + status: http.StatusOK, + }, + wantBody: base64.StdEncoding.EncodeToString([]byte("attribute data")), + wantHeaders: map[string]string{ + "Content-Type": "application/csrattrs", + "Content-Transfer-Encoding": "base64", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + writeResponse(tt.args.w, tt.args.data, tt.args.contentType, tt.args.status) + resp := tt.args.w.(*httptest.ResponseRecorder) + + assert.Equal(t, tt.args.status, resp.Code) + assert.Equal(t, tt.wantBody, resp.Body.String()) + + for k, v := range tt.wantHeaders { + assert.Equal(t, v, resp.Header().Get(k)) + } + }) + } +} diff --git a/est/auth_context.go b/est/auth_context.go new file mode 100644 index 000000000..687e17d1e --- /dev/null +++ b/est/auth_context.go @@ -0,0 +1,58 @@ +package est + +import "context" + +// AuthenticationHeaderKey is the context key used to store the EST authentication header. +type AuthenticationHeaderKey struct{} + +// NewAuthenticationHeaderContext stores the EST authentication header in the context. +func NewAuthenticationHeaderContext(ctx context.Context, header string) context.Context { + if header == "" { + return ctx + } + return context.WithValue(ctx, AuthenticationHeaderKey{}, header) +} + +// AuthenticationHeaderFromContext returns the EST authentication header stored in the context. +func AuthenticationHeaderFromContext(ctx context.Context) (string, bool) { + header, ok := ctx.Value(AuthenticationHeaderKey{}).(string) + return header, ok +} + +// BasicAuth holds the HTTP basic auth credentials for an EST request. +type BasicAuth struct { + Username string + Password string +} + +type basicAuthKey struct{} + +// NewBasicAuthContext stores the HTTP basic auth credentials in the context. +func NewBasicAuthContext(ctx context.Context, auth BasicAuth) context.Context { + if auth.Username == "" && auth.Password == "" { + return ctx + } + return context.WithValue(ctx, basicAuthKey{}, auth) +} + +// BasicAuthFromContext returns the HTTP basic auth credentials stored in the context. +func BasicAuthFromContext(ctx context.Context) (BasicAuth, bool) { + auth, ok := ctx.Value(basicAuthKey{}).(BasicAuth) + return auth, ok +} + +type BearerTokenKey struct{} + +// NewBearerTokenContext stores the HTTP bearer token in the context. +func NewBearerTokenContext(ctx context.Context, token string) context.Context { + if token == "" { + return ctx + } + return context.WithValue(ctx, BearerTokenKey{}, token) +} + +// BearerTokenFromContext returns the HTTP bearer token stored in the context. +func BearerTokenFromContext(ctx context.Context) (string, bool) { + token, ok := ctx.Value(BearerTokenKey{}).(string) + return token, ok +} diff --git a/est/authority.go b/est/authority.go new file mode 100644 index 000000000..e7a6da41d --- /dev/null +++ b/est/authority.go @@ -0,0 +1,238 @@ +package est + +import ( + "bytes" + "context" + "crypto" + "crypto/x509" + "fmt" + "sync" + + "github.com/smallstep/pkcs7" + + "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/authority/provisioner" +) + +// Authority handles EST interactions. +type Authority struct { + signAuth SignAuthority + roots []*x509.Certificate + intermediates []*x509.Certificate + defaultSigner crypto.Signer + signerCertificate *x509.Certificate + estProvisionerNames []string + provisionersMutex sync.RWMutex +} + +type authorityKey struct{} + +// NewContext adds the given authority to the context. +func NewContext(ctx context.Context, a *Authority) context.Context { + return context.WithValue(ctx, authorityKey{}, a) +} + +// FromContext returns the current authority from the given context. +func FromContext(ctx context.Context) (a *Authority, ok bool) { + a, ok = ctx.Value(authorityKey{}).(*Authority) + return +} + +// MustFromContext returns the current authority from the given context. It will +// panic if the authority is not in the context. +func MustFromContext(ctx context.Context) *Authority { + var ( + a *Authority + ok bool + ) + if a, ok = FromContext(ctx); !ok { + panic("est authority is not in the context") + } + return a +} + +// SignAuthority is the interface for a signing authority. +type SignAuthority interface { + SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + LoadProvisionerByName(string) (provisioner.Interface, error) +} + +// New returns a new Authority that implements the EST interface. +func New(signAuth SignAuthority, opts Options) (*Authority, error) { + if err := opts.Validate(); err != nil { + return nil, err + } + + return &Authority{ + signAuth: signAuth, + roots: opts.Roots, + intermediates: opts.Intermediates, + defaultSigner: opts.Signer, + signerCertificate: opts.SignerCert, + estProvisionerNames: opts.ESTProvisionerNames, + }, nil +} + +// validates if the EST Authority has a valid configuration. +func (a *Authority) Validate() error { + if a == nil { + return nil + } + + a.provisionersMutex.RLock() + defer a.provisionersMutex.RUnlock() + + noDefaultSignerAvailable := a.defaultSigner == nil || a.signerCertificate == nil + for _, name := range a.estProvisionerNames { + p, err := a.LoadProvisionerByName(name) + if err != nil { + return fmt.Errorf("failed loading provisioner %q: %w", name, err) + } + if estProv, ok := p.(*provisioner.EST); ok { + cert, signer := estProv.GetSigner() + if cert == nil && noDefaultSignerAvailable { + return fmt.Errorf("EST provisioner %q does not have a signer certificate", name) + } + if signer == nil && noDefaultSignerAvailable { + return fmt.Errorf("EST provisioner %q does not have a signer", name) + } + } + } + + return nil +} + +// UpdateProvisioners updates the EST Authority with the new, and hopefully +// current EST provisioners configured. This allows the Authority to be +// validated with the latest data. +func (a *Authority) UpdateProvisioners(estProvisionerNames []string) { + if a == nil { + return + } + + a.provisionersMutex.Lock() + defer a.provisionersMutex.Unlock() + + a.estProvisionerNames = estProvisionerNames +} + +// LoadProvisionerByName calls out to the SignAuthority interface to load a +// provisioner by name. +func (a *Authority) LoadProvisionerByName(name string) (provisioner.Interface, error) { + return a.signAuth.LoadProvisionerByName(name) +} + +// GetCACertificates returns the certificate chain for the CA. +func (a *Authority) GetCACertificates(ctx context.Context) (certs []*x509.Certificate, err error) { + p := provisionerFromContext(ctx) + + if signerCert, _ := p.GetSigner(); signerCert != nil { + certs = append(certs, signerCert) + } + + if p.ShouldIncludeIntermediateInChain() || len(certs) == 0 { + certs = append(certs, a.intermediates...) + } + + if p.ShouldIncludeRootInChain() { + certs = append(certs, a.roots...) + } + + return certs, nil +} + +// SignCSR signs the CSR using the provisioner and returns the issued chain. +func (a *Authority) SignCSR(ctx context.Context, csr *x509.CertificateRequest, signCSROpts ...provisioner.SignCSROption) (*x509.Certificate, error) { + // TODO: intermediate storage of the request? In EST it's possible to request a csr/certificate + // to be signed, which can be performed asynchronously / out-of-band. In that case a client can + // poll for the status. It seems to be similar as what can happen in ACME and SCEP, so might want to model + // the implementation after the one in the ACME authority. Requires storage, etc. + // ref: https://datatracker.ietf.org/doc/html/rfc7030#section-4.2.3 + p := provisionerFromContext(ctx) + + // Template data + sans := []string{} + sans = append(sans, csr.DNSNames...) + sans = append(sans, csr.EmailAddresses...) + for _, v := range csr.IPAddresses { + sans = append(sans, v.String()) + } + for _, v := range csr.URIs { + sans = append(sans, v.String()) + } + if len(sans) == 0 { + sans = append(sans, csr.Subject.CommonName) + } + data := x509util.CreateTemplateData(csr.Subject.CommonName, sans) + data.SetCertificateRequest(csr) + data.SetSubject(x509util.Subject{ + Country: csr.Subject.Country, + Organization: csr.Subject.Organization, + OrganizationalUnit: csr.Subject.OrganizationalUnit, + Locality: csr.Subject.Locality, + Province: csr.Subject.Province, + StreetAddress: csr.Subject.StreetAddress, + PostalCode: csr.Subject.PostalCode, + SerialNumber: csr.Subject.SerialNumber, + CommonName: csr.Subject.CommonName, + }) + + for _, o := range signCSROpts { + if m, ok := o.(provisioner.TemplateDataModifier); ok { + m.Modify(data) + } + } + + ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) + signOps, err := p.AuthorizeSign(ctx, "") + if err != nil { + return nil, fmt.Errorf("error retrieving authorization options from EST provisioner: %w", err) + } + for _, signOp := range signOps { + if wc, ok := signOp.(*provisioner.WebhookController); ok { + wc.TemplateData = data + } + } + + opts := provisioner.SignOptions{} + templateOptions, err := provisioner.TemplateOptions(p.GetOptions(), data) + if err != nil { + return nil, fmt.Errorf("error creating template options from EST provisioner: %w", err) + } + signOps = append(signOps, templateOptions) + + certChain, err := a.signAuth.SignWithContext(ctx, csr, opts, signOps...) + if err != nil { + return nil, fmt.Errorf("error generating certificate: %w", err) + } + // return leaf certificate (only): https://datatracker.ietf.org/doc/html/rfc7030#section-4.2.3 + return certChain[0], nil +} + +// BuildResponse returns a certs-only PKCS7 SignedData for the given certs. +func (a *Authority) BuildResponse(ctx context.Context, certs []*x509.Certificate) ([]byte, error) { + if len(certs) == 0 { + return nil, fmt.Errorf("no certificates to encode") + } + // Build degenerate PKCS7: SignedData with no encapsulated content or signer infos. + var buf bytes.Buffer + for _, cert := range certs { + buf.Write(cert.Raw) + } + degenerate, err := pkcs7.DegenerateCertificate(buf.Bytes()) + if err != nil { + return nil, err + } + return degenerate, nil +} + +func (a *Authority) NotifySuccess(ctx context.Context, csr *x509.CertificateRequest, cert *x509.Certificate, transactionID string) error { + p := provisionerFromContext(ctx) + return p.NotifySuccess(ctx, csr, cert, transactionID) +} + +func (a *Authority) NotifyFailure(ctx context.Context, csr *x509.CertificateRequest, transactionID string, errorCode int, errorDescription string) error { + p := provisionerFromContext(ctx) + return p.NotifyFailure(ctx, csr, transactionID, errorCode, errorDescription) +} diff --git a/est/client_cert.go b/est/client_cert.go new file mode 100644 index 000000000..14c26f291 --- /dev/null +++ b/est/client_cert.go @@ -0,0 +1,37 @@ +package est + +import ( + "context" + "crypto/x509" +) + +type clientCertificateKey struct{} +type clientCertificateChainKey struct{} + +// NewClientCertificateContext stores the TLS client certificate in the context. +func NewClientCertificateContext(ctx context.Context, cert *x509.Certificate) context.Context { + if cert == nil { + return ctx + } + return context.WithValue(ctx, clientCertificateKey{}, cert) +} + +// ClientCertificateFromContext returns the TLS client certificate stored in the context. +func ClientCertificateFromContext(ctx context.Context) (*x509.Certificate, bool) { + cert, ok := ctx.Value(clientCertificateKey{}).(*x509.Certificate) + return cert, ok +} + +// NewClientCertificateChainContext stores the TLS client certificate chain in the context. +func NewClientCertificateChainContext(ctx context.Context, chain []*x509.Certificate) context.Context { + if len(chain) == 0 { + return ctx + } + return context.WithValue(ctx, clientCertificateChainKey{}, chain) +} + +// ClientCertificateChainFromContext returns the TLS client certificate chain stored in the context. +func ClientCertificateChainFromContext(ctx context.Context) ([]*x509.Certificate, bool) { + chain, ok := ctx.Value(clientCertificateChainKey{}).([]*x509.Certificate) + return chain, ok +} diff --git a/est/options.go b/est/options.go new file mode 100644 index 000000000..df2a48b26 --- /dev/null +++ b/est/options.go @@ -0,0 +1,40 @@ +package est + +import ( + "crypto" + "crypto/x509" + "errors" +) + +// Options configures the EST authority instance. +type Options struct { + Roots []*x509.Certificate `json:"-"` + Intermediates []*x509.Certificate `json:"-"` + SignerCert *x509.Certificate `json:"-"` + Signer crypto.Signer `json:"-"` + + ESTProvisionerNames []string +} + +type comparablePublicKey interface { + Equal(crypto.PublicKey) bool +} + +// Validate checks the fields in Options. +func (o *Options) Validate() error { + switch { + case len(o.Intermediates) == 0: + return errors.New("no intermediate certificate available for EST authority") + case o.SignerCert == nil: + return errors.New("no signer certificate available for EST authority") + } + + if o.Signer != nil { + signerPublicKey := o.Signer.Public().(comparablePublicKey) + if !signerPublicKey.Equal(o.SignerCert.PublicKey) { + return errors.New("mismatch between signer certificate and public key") + } + } + + return nil +} diff --git a/est/provisioner.go b/est/provisioner.go new file mode 100644 index 000000000..69bce1713 --- /dev/null +++ b/est/provisioner.go @@ -0,0 +1,47 @@ +package est + +import ( + "context" + "crypto" + "crypto/x509" + + "github.com/smallstep/certificates/authority/provisioner" +) + +// Provisioner is an interface that embeds the generic provisioner.Interface and +// adds EST-specific helpers. +type Provisioner interface { + provisioner.Interface + GetOptions() *provisioner.Options + GetClientCertificateConfig() *provisioner.ClientCertificateConfig + ShouldIncludeRootInChain() bool + ShouldIncludeIntermediateInChain() bool + GetSigner() (*x509.Certificate, crypto.Signer) + AuthorizeRequest(ctx context.Context, req provisioner.ESTAuthRequest) ([]provisioner.SignCSROption, error) + NotifySuccess(ctx context.Context, csr *x509.CertificateRequest, cert *x509.Certificate, transactionID string) error + NotifyFailure(ctx context.Context, csr *x509.CertificateRequest, transactionID string, errorCode int, errorDescription string) error + GetCSRAttributes(ctx context.Context) ([]byte, error) +} + +// provisionerKey is the key type for storing and searching an EST provisioner in the context. +type provisionerKey struct{} + +// provisionerFromContext searches the context for an EST provisioner. +// Returns the provisioner or panics if no EST provisioner is found. +func provisionerFromContext(ctx context.Context) Provisioner { + p, ok := ctx.Value(provisionerKey{}).(Provisioner) + if !ok { + panic("EST provisioner expected in request context") + } + return p +} + +// NewProvisionerContext returns a new context with the EST provisioner set. +func NewProvisionerContext(ctx context.Context, p Provisioner) context.Context { + return context.WithValue(ctx, provisionerKey{}, p) +} + +// ProvisionerFromContext returns the EST provisioner stored in the context. +func ProvisionerFromContext(ctx context.Context) Provisioner { + return provisionerFromContext(ctx) +} diff --git a/go.mod b/go.mod index 5c74edf1a..ce250f704 100644 --- a/go.mod +++ b/go.mod @@ -172,6 +172,8 @@ require ( google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect - google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1 // indirect + google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.6.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/smallstep/linkedca => github.com/jbpin/linkedca v0.0.0-20260119192234-bf5917d1c897 diff --git a/go.sum b/go.sum index 376b3d16f..fff8de384 100644 --- a/go.sum +++ b/go.sum @@ -256,6 +256,8 @@ github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jbpin/linkedca v0.0.0-20260119192234-bf5917d1c897 h1:DTvDVMjQ4IGEn+l7g1MqCZpqKFDU5ZTlRzTDTIMFUZs= +github.com/jbpin/linkedca v0.0.0-20260119192234-bf5917d1c897/go.mod h1:bUm0HgkgtOjbnwShtNKu3XPuQ6AA/o680sxTfy9vdYk= github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= @@ -352,8 +354,6 @@ github.com/smallstep/cli-utils v0.12.2 h1:lGzM9PJrH/qawbzMC/s2SvgLdJPKDWKwKzx9do github.com/smallstep/cli-utils v0.12.2/go.mod h1:uCPqefO29goHLGqFnwk0i8W7XJu18X3WHQFRtOm/00Y= github.com/smallstep/go-attestation v0.4.4-0.20241119153605-2306d5b464ca h1:VX8L0r8vybH0bPeaIxh4NQzafKQiqvlOn8pmOXbFLO4= github.com/smallstep/go-attestation v0.4.4-0.20241119153605-2306d5b464ca/go.mod h1:vNAduivU014fubg6ewygkAvQC0IQVXqdc8vaGl/0er4= -github.com/smallstep/linkedca v0.25.0 h1:txT9QHGbCsJq0MhAghBq7qhurGY727tQuqUi+n4BVBo= -github.com/smallstep/linkedca v0.25.0/go.mod h1:Q3jVAauFKNlF86W5/RFtgQeyDKz98GL/KN3KG4mJOvc= github.com/smallstep/nosql v0.8.0 h1:FBTCUfKPmWYbrozW+RBKu+fnvbn+zr5rVli/XB4Jp4A= github.com/smallstep/nosql v0.8.0/go.mod h1:5dUpNotHLHhOUapP0PLBVVfp3tG1DFC31VRccg+Cqwo= github.com/smallstep/pkcs7 v0.2.1 h1:6Kfzr/QizdIuB6LSv8y1LJdZ3aPSfTNhTLqAx9CTLfA= @@ -533,8 +533,8 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1: google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= -google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1 h1:F29+wU6Ee6qgu9TddPgooOdaqsxTMunOoj8KA5yuS5A= -google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1/go.mod h1:5KF+wpkbTSbGcR9zteSqZV6fqFOWBl4Yde8En8MryZA= +google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.6.0 h1:6Al3kEFFP9VJhRz3DID6quisgPnTeZVr4lep9kkxdPA= +google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.6.0/go.mod h1:QLvsjh0OIR0TYBeiu2bkWGTJBUNQ64st52iWj/yA93I= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/webhook/options.go b/webhook/options.go index 62f0170ae..388b97c9b 100644 --- a/webhook/options.go +++ b/webhook/options.go @@ -134,7 +134,44 @@ func WithX5CCertificate(leaf *x509.Certificate) RequestBodyOption { } rb.X5CCertificate.PublicKey = key } + return nil + } +} + +func WithClientCertificate(cert *x509.Certificate) RequestBodyOption { + return func(rb *RequestBody) error { + certificate, err := x509util.NewCertificateFromX509(cert) + if err != nil { + return err + } + rb.ClientCertificate = &X509Certificate{ + Raw: cert.Raw, + PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(), + NotBefore: cert.NotBefore, + NotAfter: cert.NotAfter, + Certificate: certificate, + } + if cert.PublicKey != nil { + key, err := x509.MarshalPKIXPublicKey(cert.PublicKey) + if err != nil { + return err + } + rb.ClientCertificate.PublicKey = key + } + return nil + } +} + +func WithAuthenticationHeader(header string) RequestBodyOption { + return func(rb *RequestBody) error { + rb.AuthenticationHeader = header + return nil + } +} +func WithBearerToken(token string) RequestBodyOption { + return func(rb *RequestBody) error { + rb.BearerToken = token return nil } } diff --git a/webhook/types.go b/webhook/types.go index c60de7099..a78356c11 100644 --- a/webhook/types.go +++ b/webhook/types.go @@ -102,4 +102,8 @@ type RequestBody struct { X5CCertificate *X5CCertificate `json:"x5cCertificate,omitempty"` // Set for X5C, AWS, GCP, and Azure provisioners AuthorizationPrincipal string `json:"authorizationPrincipal,omitempty"` + // Set for EST webhook requests + AuthenticationHeader string `json:"authenticationHeader,omitempty"` + BearerToken string `json:"bearerToken,omitempty"` + ClientCertificate *X509Certificate `json:"clientCertificate,omitempty"` }