From 4f27dcfaf5fc0a5b40f78b324333c6c1aed41bcd Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Pin Date: Mon, 30 Mar 2026 16:04:16 +0200 Subject: [PATCH 1/8] add EST provisioner Implementation of [RFC 7030] (https://datatracker.ietf.org/doc/html/rfc7030). Support TLS client certificate authentication and basic auth. Support webhook for authentication, notification and data. Not covered : * full CMC * server-side key generation --- .gitignore | 1 + api/api.go | 22 +- authority/authority.go | 82 +++++- authority/options.go | 12 + authority/provisioner/est.go | 195 ++++++++++++++ authority/provisioner/est_auth.go | 378 +++++++++++++++++++++++++++ authority/provisioner/provisioner.go | 6 + ca/bootstrap_test.go | 2 +- ca/ca.go | 19 +- ca/tls_test.go | 2 +- est/api/api.go | 304 +++++++++++++++++++++ est/api/api_test.go | 71 +++++ est/auth_context.go | 50 ++++ est/authority.go | 238 +++++++++++++++++ est/client_cert.go | 37 +++ est/options.go | 40 +++ est/provisioner.go | 46 ++++ 17 files changed, 1492 insertions(+), 13 deletions(-) create mode 100644 authority/provisioner/est.go create mode 100644 authority/provisioner/est_auth.go create mode 100644 est/api/api.go create mode 100644 est/api/api_test.go create mode 100644 est/auth_context.go create mode 100644 est/authority.go create mode 100644 est/client_cert.go create mode 100644 est/options.go create mode 100644 est/provisioner.go diff --git a/.gitignore b/.gitignore index c17ed53a2..313f51d5b 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ go.work.sum # Output of the go coverage tool, specifically when used with LiteIDE *.out +.gocache # Others *.swp diff --git a/api/api.go b/api/api.go index 09d2c83fb..c3e128fe9 100644 --- a/api/api.go +++ b/api/api.go @@ -257,22 +257,32 @@ func scepFromProvisioner(p *provisioner.SCEP) *models.SCEP { } } +func estFromProvisioner(p *provisioner.EST) *provisioner.EST { + prov := *p + prov.ClientCertificateRoots = []byte(redacted) + prov.BasicAuthUsername = redacted + prov.BasicAuthPassword = redacted + return &prov +} + // MarshalJSON implements json.Marshaler. It marshals the ProvisionersResponse // into a byte slice. // // Special treatment is given to the SCEP provisioner, as it contains a // challenge secret that MUST NOT be leaked in (public) HTTP responses. The -// challenge value is thus redacted in HTTP responses. +// challenge value is thus redacted in HTTP responses. EST provisioners also +// contain a shared secret and are redacted in responses. func (p ProvisionersResponse) MarshalJSON() ([]byte, error) { var responseProvisioners provisioner.List for _, item := range p.Provisioners { - scepProv, ok := item.(*provisioner.SCEP) - if !ok { + switch prov := item.(type) { + case *provisioner.SCEP: + responseProvisioners = append(responseProvisioners, scepFromProvisioner(prov)) + case *provisioner.EST: + responseProvisioners = append(responseProvisioners, estFromProvisioner(prov)) + default: responseProvisioners = append(responseProvisioners, item) - continue } - - responseProvisioners = append(responseProvisioners, scepFromProvisioner(scepProv)) } var list = struct { diff --git a/authority/authority.go b/authority/authority.go index 98dd68968..d7e52bf21 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -33,6 +33,7 @@ import ( "github.com/smallstep/certificates/cas" casapi "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/est" "github.com/smallstep/certificates/internal/httptransport" "github.com/smallstep/certificates/scep" "github.com/smallstep/certificates/templates" @@ -70,6 +71,11 @@ type Authority struct { scepAuthority *scep.Authority scepKeyManager provisioner.SCEPKeyManager + // EST CA + estOptions *est.Options + validateEST bool + estAuthority *est.Authority + // SSH CA sshHostPassword []byte sshUserPassword []byte @@ -133,6 +139,7 @@ func New(cfg *config.Config, opts ...Option) (*Authority, error) { config: cfg, certificates: new(sync.Map), validateSCEP: true, + validateEST: true, meter: noopMeter{}, wrapTransport: httptransport.NoopWrapper(), } @@ -170,6 +177,7 @@ func NewEmbedded(opts ...Option) (*Authority, error) { certificates: new(sync.Map), meter: noopMeter{}, wrapTransport: httptransport.NoopWrapper(), + validateEST: true, } // Apply options. @@ -314,6 +322,11 @@ func (a *Authority) ReloadAdminResources(ctx context.Context) error { // TODO(hs): don't remove the authority if we can't also // reload it. //a.scepAuthority = nil + case a.requiresEST() && a.GetEST() != nil: + a.estAuthority.UpdateProvisioners(a.getESTProvisionerNames()) + if err := a.estAuthority.Validate(); err != nil { + log.Printf("failed validating EST authority: %v\n", err) + } } return nil @@ -814,6 +827,51 @@ func (a *Authority) init() error { } } + // EST functionality is provided through an instance of est.Authority. + switch { + case a.requiresEST() && a.GetEST() == nil: + if a.estOptions == nil { + options := &est.Options{ + Roots: a.rootX509Certs, + Intermediates: a.intermediateX509Certs, + } + if len(a.intermediateX509Certs) > 0 { + options.SignerCert = a.intermediateX509Certs[0] + } + if a.config.IntermediateKey != "" { + if signer, err := a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{ + SigningKey: a.config.IntermediateKey, + Password: a.password, + }); err == nil { + options.Signer = signer + } + } + a.estOptions = options + } + + a.estOptions.ESTProvisionerNames = a.getESTProvisionerNames() + + estAuthority, err := est.New(a, *a.estOptions) + if err != nil { + return err + } + + if a.validateEST { + if err := estAuthority.Validate(); err != nil { + a.initLogf("failed validating EST authority: %v", err) + } + } + + a.estAuthority = estAuthority + case !a.requiresEST() && a.GetEST() != nil: + a.estAuthority = nil + case a.requiresEST() && a.GetEST() != nil: + a.estAuthority.UpdateProvisioners(a.getESTProvisionerNames()) + if err := a.estAuthority.Validate(); err != nil { + log.Printf("failed validating EST authority: %v\n", err) + } + } + // Load X509 constraints engine. // // This is currently only available in CA mode. @@ -1002,16 +1060,34 @@ func (a *Authority) GetSCEP() *scep.Authority { return a.scepAuthority } -// HasACMEProvisioner returns true if at least one ACME provisioner is configured. -func (a *Authority) HasACMEProvisioner() bool { +// requiresEST iterates over the configured provisioners +// and determines if at least one of them is an EST provisioner. +func (a *Authority) requiresEST() bool { for _, p := range a.config.AuthorityConfig.Provisioners { - if p.GetType() == provisioner.TypeACME { + if p.GetType() == provisioner.TypeEST { return true } } return false } +// getESTProvisionerNames returns the names of the EST provisioners +// that are currently available in the CA. +func (a *Authority) getESTProvisionerNames() (names []string) { + for _, p := range a.config.AuthorityConfig.Provisioners { + if p.GetType() == provisioner.TypeEST { + names = append(names, p.GetName()) + } + } + + return +} + +// GetEST returns the configured EST Authority +func (a *Authority) GetEST() *est.Authority { + return a.estAuthority +} + func (a *Authority) startCRLGenerator() error { if !a.config.CRL.IsEnabled() { return nil diff --git a/authority/options.go b/authority/options.go index d85a611dc..e840ae5b6 100644 --- a/authority/options.go +++ b/authority/options.go @@ -18,6 +18,7 @@ import ( casapi "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/internal/httptransport" + "github.com/smallstep/certificates/est" "github.com/smallstep/certificates/scep" ) @@ -242,6 +243,17 @@ func WithFullSCEPOptions(options *scep.Options) Option { } } +// WithFullESTOptions defines the options used for EST support. +// +// This feature is EXPERIMENTAL and might change at any time. +func WithFullESTOptions(options *est.Options) Option { + return func(a *Authority) error { + a.estOptions = options + a.validateEST = false + return nil + } +} + // WithSCEPKeyManager defines the key manager used on SCEP provisioners. // // This feature is EXPERIMENTAL and might change at any time. diff --git a/authority/provisioner/est.go b/authority/provisioner/est.go new file mode 100644 index 000000000..43479d24e --- /dev/null +++ b/authority/provisioner/est.go @@ -0,0 +1,195 @@ +package provisioner + +import ( + "context" + "crypto" + "crypto/x509" + "fmt" + "time" + + "github.com/pkg/errors" + + "github.com/smallstep/certificates/internal/httptransport" + "github.com/smallstep/linkedca" +) + +// EST is the EST provisioner type, an entity that can authorize the EST flow. +type EST struct { + *base + ID string `json:"-"` + Type string `json:"type"` + Name string `json:"name"` + EnableTLSClientCertificate *bool `json:"enableTLSClientCertificate,omitempty"` + EnableHTTPBasicAuth *bool `json:"enableHTTPBasicAuth,omitempty"` + BasicAuthUsername string `json:"basicAuthUsername,omitempty"` + BasicAuthPassword string `json:"basicAuthPassword,omitempty"` + ClientCertificateRoots []byte `json:"clientCertificateRoots,omitempty"` + ForceCN bool `json:"forceCN,omitempty"` + Capabilities []string `json:"capabilities,omitempty"` + IncludeRoot bool `json:"includeRoot,omitempty"` + ExcludeIntermediate bool `json:"excludeIntermediate,omitempty"` + MinimumPublicKeyLength int `json:"minimumPublicKeyLength,omitempty"` + CSRAttrs []byte `json:"csrAttrs,omitempty"` + Options *Options `json:"options,omitempty"` + Claims *Claims `json:"claims,omitempty"` + ctl *Controller + signer crypto.Signer + signerCertificate *x509.Certificate + challengeValidationController *challengeValidationController + notificationController *notificationController + clientCertificateRootPool *x509.CertPool +} + +// GetID returns the provisioner unique identifier. +func (s *EST) GetID() string { + if s.ID != "" { + return s.ID + } + return s.GetIDForToken() +} + +// GetIDForToken returns an identifier that will be used to load the provisioner from a token. +func (s *EST) GetIDForToken() string { + return "est/" + s.Name +} + +// GetName returns the name of the provisioner. +func (s *EST) GetName() string { + return s.Name +} + +// GetType returns the type of provisioner. +func (s *EST) GetType() Type { + return TypeEST +} + +// GetEncryptedKey returns the base provisioner encrypted key if it's defined. +func (s *EST) GetEncryptedKey() (string, string, bool) { + return "", "", false +} + +// GetTokenID returns the identifier of the token. This provisioner does not support tokens. +func (s *EST) GetTokenID(string) (string, error) { + return "", ErrTokenFlowNotSupported +} + +// GetOptions returns the configured provisioner options. +func (s *EST) GetOptions() *Options { + return s.Options +} + +// DefaultTLSCertDuration returns the default TLS cert duration enforced by the provisioner. +func (s *EST) DefaultTLSCertDuration() time.Duration { + return s.ctl.Claimer.DefaultTLSCertDuration() +} + +// newChallengeValidationController creates a new challengeValidationController +// that performs challenge validation through webhooks. +func newESTChallengeValidationController(client HTTPClient, tw httptransport.Wrapper, webhooks []*Webhook) *challengeValidationController { + scepHooks := []*Webhook{} + for _, wh := range webhooks { + // if wh.Kind != linkedca.Webhook_ESTCHALLENGE.String() { + if wh.Kind != "ESTCHALLENGE" { + continue + } + if !isCertTypeOK(wh) { + continue + } + scepHooks = append(scepHooks, wh) + } + return &challengeValidationController{ + client: client, + wrapTransport: tw, + webhooks: scepHooks, + } +} + +// Init initializes and validates the fields of an EST type. +func (s *EST) Init(config Config) (err error) { + switch { + case s.Type == "": + return errors.New("provisioner type cannot be empty") + case s.Name == "": + return errors.New("provisioner name cannot be empty") + } + + if s.MinimumPublicKeyLength == 0 { + s.MinimumPublicKeyLength = 2048 + } + if s.MinimumPublicKeyLength%8 != 0 { + return errors.Errorf("%d bits is not exactly divisible by 8", s.MinimumPublicKeyLength) + } + + // Prepare the EST challenge validator + s.challengeValidationController = newESTChallengeValidationController( + config.WebhookClient, + config.WrapTransport, + s.GetOptions().GetWebhooks(), + ) + + // Prepare the EST notification controller + s.notificationController = newNotificationController( + config.WebhookClient, + config.WrapTransport, + s.GetOptions().GetWebhooks(), + ) + + if err := s.parseClientCertificateRoots(); err != nil { + return err + } + + if err := s.normalizeAuthConfig(); err != nil { + return err + } + + s.ctl, err = NewController(s, s.Claims, config, s.Options) + return err +} + +// AuthorizeSign does not do any verification; main validation is in the EST protocol. +func (s *EST) AuthorizeSign(context.Context, string) ([]SignOption, error) { + return []SignOption{ + s, + newProvisionerExtensionOption(TypeEST, s.Name, "").WithControllerOptions(s.ctl), + newForceCNOption(s.ForceCN), + profileDefaultDuration(s.ctl.Claimer.DefaultTLSCertDuration()), + newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength), + newValidityValidator(s.ctl.Claimer.MinTLSCertDuration(), s.ctl.Claimer.MaxTLSCertDuration()), + newX509NamePolicyValidator(s.ctl.getPolicy().getX509()), + s.ctl.newWebhookController(nil, linkedca.Webhook_X509), + }, nil +} + +// ShouldIncludeRootInChain indicates if the CA should return its root in the chain. +func (s *EST) ShouldIncludeRootInChain() bool { + return s.IncludeRoot +} + +// ShouldIncludeIntermediateInChain indicates if the CA should include the intermediate CA certificate. +func (s *EST) ShouldIncludeIntermediateInChain() bool { + return !s.ExcludeIntermediate +} + +// GetSigner returns the provisioner specific signer, used to sign EST responses. +func (s *EST) GetSigner() (*x509.Certificate, crypto.Signer) { + return s.signerCertificate, s.signer +} + +// GetCSRAttributes returns the CSR attributes to signal to clients. +func (s *EST) GetCSRAttributes(context.Context) ([]byte, error) { + return s.CSRAttrs, nil +} + +func (s *EST) NotifySuccess(ctx context.Context, csr *x509.CertificateRequest, cert *x509.Certificate, transactionID string) error { + if s.notificationController == nil { + return fmt.Errorf("provisioner %q wasn't initialized", s.Name) + } + return s.notificationController.Success(ctx, csr, cert, transactionID) +} + +func (s *EST) NotifyFailure(ctx context.Context, csr *x509.CertificateRequest, transactionID string, errorCode int, errorDescription string) error { + if s.notificationController == nil { + return fmt.Errorf("provisioner %q wasn't initialized", s.Name) + } + return s.notificationController.Failure(ctx, csr, transactionID, errorCode, errorDescription) +} diff --git a/authority/provisioner/est_auth.go b/authority/provisioner/est_auth.go new file mode 100644 index 000000000..897279879 --- /dev/null +++ b/authority/provisioner/est_auth.go @@ -0,0 +1,378 @@ +package provisioner + +import ( + "context" + "crypto/subtle" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + + "github.com/smallstep/certificates/webhook" +) + +var ( + ErrESTAuthMethodDisabled = errors.New("est authentication method disabled") + ErrESTAuthDenied = errors.New("est authentication denied") +) + +// ESTAuthMethod identifies the EST authentication method used. +type ESTAuthMethod string + +const ( + ESTAuthMethodTLSClientCertificate ESTAuthMethod = "tls-client-certificate" + ESTAuthMethodTLSExternalClientCertificate ESTAuthMethod = "tls-external-client-certificate" + ESTAuthMethodHTTPBasicAuth ESTAuthMethod = "http-basic-auth" +) + +// ESTAuthRequest contains authentication material extracted from the request. +type ESTAuthRequest struct { + CSR *x509.CertificateRequest + ClientCertificate *x509.Certificate + ClientCertificateChain []*x509.Certificate + CARoots []*x509.Certificate + CAIntermediates []*x509.Certificate + BasicAuthUsername string + BasicAuthPassword string +} + +// AuthorizeRequest validates the request against configured EST auth methods. +func (s *EST) AuthorizeRequest(ctx context.Context, req ESTAuthRequest) (ESTAuthMethod, error) { + if s.hasAuthWebhooks() { + return s.authorizeRequestWithWebhook(ctx, req) + } + return s.authorizeRequestLocal(ctx, req) +} + +// AuthorizeTLSClientCertificate validates a CA-issued client certificate. +func (s *EST) AuthorizeTLSClientCertificate(ctx context.Context, cert *x509.Certificate, chain, roots, intermediates []*x509.Certificate) error { + method, err := s.AuthorizeRequest(ctx, ESTAuthRequest{ + ClientCertificate: cert, + ClientCertificateChain: chain, + CARoots: roots, + CAIntermediates: intermediates, + }) + if err != nil { + return err + } + if method != ESTAuthMethodTLSClientCertificate { + return ErrESTAuthDenied + } + return nil +} + +// AuthorizeTLSExternalClientCertificate validates a client certificate against external roots. +func (s *EST) AuthorizeTLSExternalClientCertificate(ctx context.Context, cert *x509.Certificate, chain []*x509.Certificate) error { + method, err := s.AuthorizeRequest(ctx, ESTAuthRequest{ + ClientCertificate: cert, + ClientCertificateChain: chain, + }) + if err != nil { + return err + } + if method != ESTAuthMethodTLSExternalClientCertificate { + return ErrESTAuthDenied + } + return nil +} + +// AuthorizeHTTPBasicAuth validates a username/password pair for EST. +func (s *EST) AuthorizeHTTPBasicAuth(ctx context.Context, csr *x509.CertificateRequest, username, password string) error { + method, err := s.AuthorizeRequest(ctx, ESTAuthRequest{ + CSR: csr, + BasicAuthUsername: username, + BasicAuthPassword: password, + }) + if err != nil { + return err + } + if method != ESTAuthMethodHTTPBasicAuth { + return ErrESTAuthDenied + } + return nil +} + +// authorizeRequestWithWebhook delegates authentication to EST webhooks. +func (s *EST) authorizeRequestWithWebhook(ctx context.Context, req ESTAuthRequest) (ESTAuthMethod, error) { + if req.ClientCertificate != nil { + method, err := s.preferredCertAuthMethod() + if err != nil { + return "", err + } + if err := s.authorizeWithWebhook(ctx, req.ClientCertificate, req.CSR, ""); err != nil { + return "", err + } + return method, nil + } + + if req.hasBasicAuth() { + method, err := s.preferredBasicAuthMethod() + if err != nil { + return "", err + } + if method == ESTAuthMethodHTTPBasicAuth { + if req.BasicAuthPassword == "" { + return "", errors.New("missing basic auth credentials") + } + } + if req.CSR == nil { + return "", errors.New("missing CSR for basic auth validation") + } + opts := []webhook.RequestBodyOption{} + if req.BasicAuthUsername != "" { + opts = append(opts, webhook.WithAuthorizationPrincipal(req.BasicAuthUsername)) + } + if err := s.authorizeWithWebhook(ctx, nil, req.CSR, req.BasicAuthPassword, opts...); err != nil { + return "", err + } + return method, nil + } + + return "", errors.New("missing client certificate or basic auth") +} + +// authorizeRequestLocal validates the request using provisioner configuration. +func (s *EST) authorizeRequestLocal(ctx context.Context, req ESTAuthRequest) (ESTAuthMethod, error) { + if req.ClientCertificate != nil { + var lastErr error + if boolValue(s.EnableTLSClientCertificate, false) { + if err := verifyCertificate(req.ClientCertificate, req.ClientCertificateChain, req.CARoots, req.CAIntermediates); err == nil { + return ESTAuthMethodTLSClientCertificate, nil + } else { + lastErr = err + } + } + if s.hasClientCertificateRoots() { + if s.clientCertificateRootPool == nil { + lastErr = ErrESTAuthMethodDisabled + } else if err := verifyCertificateWithPool(req.ClientCertificate, req.ClientCertificateChain, s.clientCertificateRootPool, nil); err == nil { + return ESTAuthMethodTLSExternalClientCertificate, nil + } else { + lastErr = err + } + } + if lastErr != nil { + return "", lastErr + } + return "", ErrESTAuthMethodDisabled + } + + if req.hasBasicAuth() { + if boolValue(s.EnableHTTPBasicAuth, false) { + if req.BasicAuthPassword == "" { + return "", errors.New("missing basic auth credentials") + } + if s.BasicAuthUsername != "" && req.BasicAuthUsername != s.BasicAuthUsername { + return "", errors.New("invalid basic auth username") + } + if err := s.validateBasicAuthPassword(req.BasicAuthPassword); err != nil { + return "", err + } + return ESTAuthMethodHTTPBasicAuth, nil + } + return "", ErrESTAuthMethodDisabled + } + + return "", errors.New("missing client certificate or basic auth") +} + +// preferredCertAuthMethod selects the enabled certificate-based auth method. +func (s *EST) preferredCertAuthMethod() (ESTAuthMethod, error) { + switch { + case boolValue(s.EnableTLSClientCertificate, false): + return ESTAuthMethodTLSClientCertificate, nil + case s.hasClientCertificateRoots(): + return ESTAuthMethodTLSExternalClientCertificate, nil + default: + return "", ErrESTAuthMethodDisabled + } +} + +// preferredBasicAuthMethod selects the enabled basic-auth-based method. +func (s *EST) preferredBasicAuthMethod() (ESTAuthMethod, error) { + switch { + case boolValue(s.EnableHTTPBasicAuth, false): + return ESTAuthMethodHTTPBasicAuth, nil + default: + return "", ErrESTAuthMethodDisabled + } +} + +// validateBasicAuthPassword verifies the configured basic auth password. +func (s *EST) validateBasicAuthPassword(password string) error { + if s.BasicAuthPassword == "" { + return errors.New("basic auth password is not configured") + } + if subtleCompare(s.BasicAuthPassword, password) { + return nil + } + return errors.New("invalid basic auth password") +} + +// authorizeWithWebhook executes configured webhooks for auth decisions. +func (s *EST) authorizeWithWebhook(ctx context.Context, cert *x509.Certificate, csr *x509.CertificateRequest, secret string, opts ...webhook.RequestBodyOption) error { + if !s.hasAuthWebhooks() { + return nil + } + + var ( + req *webhook.RequestBody + err error + ) + switch { + case cert != nil: + req, err = webhook.NewRequestBody(append(opts, webhook.WithX509Certificate(nil, cert))...) + if err != nil { + return fmt.Errorf("failed creating webhook request: %w", err) + } + if req.X509Certificate != nil { + req.X509Certificate.Raw = cert.Raw + } + case csr != nil: + req, err = webhook.NewRequestBody(append(opts, webhook.WithX509CertificateRequest(csr))...) + if err != nil { + return fmt.Errorf("failed creating webhook request: %w", err) + } + default: + return errors.New("missing certificate or CSR for webhook validation") + } + + req.ProvisionerName = s.Name + if secret != "" { + // TODO: change this to add a dedicated field in the webhook request body (or rename it but can broken existing webhooks) + req.SCEPChallenge = secret + } + + for _, wh := range s.challengeValidationController.webhooks { + resp, err := wh.DoWithContext(ctx, s.challengeValidationController.client, s.challengeValidationController.wrapTransport, req, nil) + if err != nil { + return fmt.Errorf("failed executing webhook request: %w", err) + } + if resp.Allow { + return nil + } + } + + return ErrESTAuthDenied +} + +// hasBasicAuth reports whether any basic auth data is present. +func (r ESTAuthRequest) hasBasicAuth() bool { + return r.BasicAuthUsername != "" || r.BasicAuthPassword != "" +} + +// hasAuthWebhooks reports whether auth webhooks are configured. +func (s *EST) hasAuthWebhooks() bool { + return s.challengeValidationController != nil && len(s.challengeValidationController.webhooks) > 0 +} + +// normalizeAuthConfig applies defaults and validates auth configuration. +func (s *EST) normalizeAuthConfig() error { + if !s.authMethodsConfigured() { + enable := true + s.EnableHTTPBasicAuth = &enable + } + if s.EnableHTTPBasicAuth == nil && (s.BasicAuthUsername != "" || s.BasicAuthPassword != "") { + enable := true + s.EnableHTTPBasicAuth = &enable + } + if boolValue(s.EnableHTTPBasicAuth, false) && s.BasicAuthPassword == "" && !s.hasAuthWebhooks() { + return errors.New("basic auth password cannot be empty") + } + return nil +} + +// authMethodsConfigured reports whether any auth method is explicitly configured. +func (s *EST) authMethodsConfigured() bool { + return s.EnableTLSClientCertificate != nil || + s.hasClientCertificateRoots() || + s.EnableHTTPBasicAuth != nil +} + +// parseClientCertificateRoots loads external client certificate roots. +func (s *EST) parseClientCertificateRoots() error { + if len(s.ClientCertificateRoots) == 0 { + return nil + } + var ( + block *pem.Block + hasCert bool + rest = s.ClientCertificateRoots + ) + s.clientCertificateRootPool = x509.NewCertPool() + for rest != nil { + block, rest = pem.Decode(rest) + if block == nil { + break + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return errors.New("error parsing clientCertificateRoots: malformed certificate") + } + s.clientCertificateRootPool.AddCert(cert) + hasCert = true + } + if !hasCert { + return errors.New("error parsing clientCertificateRoots: no certificates found") + } + return nil +} + +func (s *EST) hasClientCertificateRoots() bool { + return len(s.ClientCertificateRoots) > 0 +} + +// verifyCertificate validates the client certificate against CA roots. +func verifyCertificate(cert *x509.Certificate, chain, roots, intermediates []*x509.Certificate) error { + rootPool := x509.NewCertPool() + for _, root := range roots { + if root != nil { + rootPool.AddCert(root) + } + } + intermediatePool := x509.NewCertPool() + for _, intermediate := range intermediates { + if intermediate != nil { + intermediatePool.AddCert(intermediate) + } + } + return verifyCertificateWithPool(cert, chain, rootPool, intermediatePool) +} + +// verifyCertificateWithPool validates the client certificate using explicit pools. +func verifyCertificateWithPool(cert *x509.Certificate, chain []*x509.Certificate, roots, intermediates *x509.CertPool) error { + if intermediates == nil { + intermediates = x509.NewCertPool() + } + for i, intermediate := range chain { + if i == 0 || intermediate == nil { + continue + } + intermediates.AddCert(intermediate) + } + _, err := cert.Verify(x509.VerifyOptions{ + Roots: roots, + Intermediates: intermediates, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }) + if err != nil { + return fmt.Errorf("invalid client certificate: %w", err) + } + return nil +} + +// boolValue returns the dereferenced value or a default. +func boolValue(value *bool, defaultValue bool) bool { + if value == nil { + return defaultValue + } + return *value +} + +// subtleCompare compares secrets in constant time. +func subtleCompare(expected, actual string) bool { + if len(expected) != len(actual) { + return false + } + return subtle.ConstantTimeCompare([]byte(expected), []byte(actual)) == 1 +} diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index 33d75fe9a..95c916436 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -208,6 +208,8 @@ const ( TypeSCEP Type = 10 // TypeNebula is used to indicate the Nebula provisioners TypeNebula Type = 11 + // TypeEST is used to indicate the EST provisioners + TypeEST Type = 12 ) // String returns the string representation of the type. @@ -235,6 +237,8 @@ func (t Type) String() string { return "SCEP" case TypeNebula: return "Nebula" + case TypeEST: + return "EST" default: return "" } @@ -328,6 +332,8 @@ func (l *List) UnmarshalJSON(data []byte) error { p = &SCEP{} case "nebula": p = &Nebula{} + case "est": + p = &EST{} default: // Skip unsupported provisioners. A client using this method may be // compiled with a version of smallstep/certificates that does not diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index da37eee58..056a44968 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -54,7 +54,7 @@ func startCABootstrapServer() *httptest.Server { if err != nil { panic(err) } - baseContext := buildContext(ca.auth, nil, nil, nil) + baseContext := buildContext(ca.auth, nil, nil, nil, nil) srv.Config.Handler = ca.srv.Handler srv.Config.BaseContext = func(net.Listener) context.Context { return baseContext diff --git a/ca/ca.go b/ca/ca.go index 3f0704a0a..ed62eeebf 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -39,6 +39,8 @@ import ( "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/middleware/requestid" "github.com/smallstep/certificates/monitoring" + "github.com/smallstep/certificates/est" + estAPI "github.com/smallstep/certificates/est/api" "github.com/smallstep/certificates/scep" scepAPI "github.com/smallstep/certificates/scep/api" "github.com/smallstep/certificates/server" @@ -300,6 +302,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { } var scepAuthority *scep.Authority + var estAuthority *est.Authority if ca.shouldServeSCEPEndpoints() { // get the SCEP authority configuration. Validation is // performed within the authority instantiation process. @@ -328,6 +331,15 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { //dumpRoutes(mux) //dumpRoutes(insecureMux) + // EST endpoints (HTTPS only) + estPrefix := ".well-known/est" + if estAuth := auth.GetEST(); estAuth != nil { + estAuthority = estAuth + mux.Route("/"+estPrefix, func(r chi.Router) { + estAPI.Route(r) + }) + } + // Add monitoring if configured if len(cfg.Monitoring) > 0 { m, err := monitoring.New(cfg.Monitoring) @@ -355,7 +367,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { insecureHandler = requestid.New(legacyTraceHeader).Middleware(insecureHandler) // Create context with all the necessary values. - baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker) + baseContext := buildContext(auth, scepAuthority, estAuthority, acmeDB, acmeLinker) ca.srv = server.New(cfg.Address, handler, tlsConfig) ca.srv.BaseContext = func(net.Listener) context.Context { @@ -403,7 +415,7 @@ func (ca *CA) shouldServeInsecureServer() bool { } // buildContext builds the server base context. -func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker) context.Context { +func buildContext(a *authority.Authority, scepAuthority *scep.Authority, estAuthority *est.Authority, acmeDB acme.DB, acmeLinker acme.Linker) context.Context { ctx := authority.NewContext(context.Background(), a) if authDB := a.GetDatabase(); authDB != nil { ctx = db.NewContext(ctx, authDB) @@ -414,6 +426,9 @@ func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB if scepAuthority != nil { ctx = scep.NewContext(ctx, scepAuthority) } + if estAuthority != nil { + ctx = est.NewContext(ctx, estAuthority) + } if acmeDB != nil { ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil) } diff --git a/ca/tls_test.go b/ca/tls_test.go index 465f1ede2..885b895ae 100644 --- a/ca/tls_test.go +++ b/ca/tls_test.go @@ -78,7 +78,7 @@ func startCATestServer(t *testing.T) *httptest.Server { ca, err := New(config) require.NoError(t, err) // Use a httptest.Server instead - baseContext := buildContext(ca.auth, nil, nil, nil) + baseContext := buildContext(ca.auth, nil, nil, nil, nil) srv := startTestServer(baseContext, ca.srv.TLSConfig, ca.srv.Handler) return srv } diff --git a/est/api/api.go b/est/api/api.go new file mode 100644 index 000000000..805385f76 --- /dev/null +++ b/est/api/api.go @@ -0,0 +1,304 @@ +// Package api implements an EST HTTP server. +package api + +import ( + "context" + "crypto/x509" + "encoding/base64" + "errors" + "fmt" + "io" + "mime" + "net/http" + "net/url" + "strings" + + "github.com/go-chi/chi/v5" + + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/log" + "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/est" +) + +const ( + maxPayloadSize = 2 << 20 +) + +// Route configures the EST routes under the provided router. +func Route(r api.Router) { + r.MethodFunc(http.MethodGet, "/{provisionerName}/cacerts", getCACerts) + r.MethodFunc(http.MethodGet, "/{provisionerName}/csrattrs", getCSRAttrs) + r.MethodFunc(http.MethodPost, "/{provisionerName}/simpleenroll", enroll) + r.MethodFunc(http.MethodPost, "/{provisionerName}/simplereenroll", enroll) +} + +func lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + name := chi.URLParam(r, "provisionerName") + if name == "" || name == "/" { + name = r.URL.Query().Get("provisioner") + } + if name == "" { + fail(w, r, errors.New("missing provisioner name")) + return + } + provisionerName, err := url.PathUnescape(name) + if err != nil { + fail(w, r, fmt.Errorf("error url unescaping provisioner name '%s'", name)) + return + } + + ctx := r.Context() + auth := authority.MustFromContext(ctx) + p, err := auth.LoadProvisionerByName(provisionerName) + if err != nil { + fail(w, r, err) + return + } + + prov, ok := p.(*provisioner.EST) + if !ok { + fail(w, r, errors.New("provisioner must be of type EST")) + return + } + + ctx = est.NewProvisionerContext(ctx, est.Provisioner(prov)) + next(w, r.WithContext(ctx)) + } +} + +func getCACerts(w http.ResponseWriter, r *http.Request) { + lookupProvisioner(getCACertsHandler)(w, r) +} + +func getCACertsHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + auth := est.MustFromContext(ctx) + + certs, err := auth.GetCACertificates(ctx) + if err != nil { + fail(w, r, fmt.Errorf("failed to get CA certificates: %w", err)) + return + } + + data, err := auth.BuildResponse(ctx, certs) + if err != nil { + fail(w, r, fmt.Errorf("failed to encode CA certificates: %w", err)) + return + } + + writeResponse(w, r, data, "application/pkcs7-mime; smime-type=certs-only", http.StatusOK) +} + +func getCSRAttrs(w http.ResponseWriter, r *http.Request) { + lookupProvisioner(getCSRAttrsHandler)(w, r) +} + +func getCSRAttrsHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + prov := est.ProvisionerFromContext(ctx) + + attrs, err := prov.GetCSRAttributes(ctx) + if err != nil { + fail(w, r, fmt.Errorf("failed to get CSR attributes: %w", err)) + return + } + if attrs == nil { + attrs = []byte{} + } + // Minimal implementation: allow provisioner to return nil/empty for "no attributes". + writeResponse(w, r, attrs, "application/csrattrs", http.StatusOK) +} + +func enroll(w http.ResponseWriter, r *http.Request) { + lookupProvisioner(enrollHandler)(w, r) +} + +func enrollHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ctx, err := authContextFromRequest(ctx, r) + if err != nil { + if errors.Is(err, errMissingClientCertificateOrBasicAuth) { + w.Header().Set("WWW-Authenticate", `Basic realm="EST"`) + } + failWithStatus(w, r, http.StatusUnauthorized, err) + return + } + + body, err := io.ReadAll(io.LimitReader(r.Body, maxPayloadSize)) + if err != nil { + failWithStatus(w, r, http.StatusBadRequest, fmt.Errorf("failed reading request body: %w", err)) + return + } + + if err := requireContentType(r, "application/pkcs10"); err != nil { + failWithStatus(w, r, http.StatusUnsupportedMediaType, err) + return + } + + der, err := decodeBase64Payload(body) + if err != nil { + failWithStatus(w, r, http.StatusBadRequest, err) + return + } + + csr, err := parseCSR(der) + if err != nil { + failWithStatus(w, r, http.StatusBadRequest, fmt.Errorf("failed parsing CSR: %w", err)) + return + } + if err := csr.CheckSignature(); err != nil { + failWithStatus(w, r, http.StatusBadRequest, fmt.Errorf("invalid CSR signature: %w", err)) + return + } + + ctx, err = authorizeEnrollRequest(ctx, csr) + if err != nil { + failWithStatus(w, r, http.StatusUnauthorized, err) + return + } + + r = r.WithContext(ctx) + auth := est.MustFromContext(ctx) + + issued, err := auth.SignCSR(ctx, csr) + if err != nil { + failWithStatus(w, r, http.StatusInternalServerError, fmt.Errorf("failed issuing certificate: %w", err)) + return + } + + signed, err := auth.BuildResponse(ctx, []*x509.Certificate{issued}) + if err != nil { + failWithStatus(w, r, http.StatusInternalServerError, fmt.Errorf("failed encoding issued certificate: %w", err)) + return + } + + writeResponse(w, r, signed, "application/pkcs7-mime; smime-type=certs-only", http.StatusOK) +} + +var errMissingClientCertificateOrBasicAuth = errors.New("missing client certificate or basic auth") + +// authContextFromRequest extracts auth material from the request into the context. +func authContextFromRequest(ctx context.Context, r *http.Request) (context.Context, error) { + if r.TLS == nil { + return ctx, errors.New("missing TLS connection") + } + + if len(r.TLS.PeerCertificates) > 0 { + ctx = est.NewClientCertificateContext(ctx, r.TLS.PeerCertificates[0]) + ctx = est.NewClientCertificateChainContext(ctx, r.TLS.PeerCertificates) + } + + if username, password, ok := r.BasicAuth(); ok { + ctx = est.NewBasicAuthContext(ctx, est.BasicAuth{ + Username: username, + Password: password, + }) + } + + if _, ok := est.ClientCertificateFromContext(ctx); !ok { + if _, ok := est.BasicAuthFromContext(ctx); !ok { + return ctx, errMissingClientCertificateOrBasicAuth + } + } + return ctx, nil +} + +// authorizeEnrollRequest validates the request against provisioner-configured auth methods. +func authorizeEnrollRequest(ctx context.Context, csr *x509.CertificateRequest) (context.Context, error) { + prov := est.ProvisionerFromContext(ctx) + ca := authority.MustFromContext(ctx) + + req := provisioner.ESTAuthRequest{ + CSR: csr, + CARoots: ca.GetRootCertificates(), + CAIntermediates: ca.GetIntermediateCertificates(), + } + if cert, ok := est.ClientCertificateFromContext(ctx); ok { + req.ClientCertificate = cert + req.ClientCertificateChain, _ = est.ClientCertificateChainFromContext(ctx) + } + if auth, ok := est.BasicAuthFromContext(ctx); ok { + req.BasicAuthUsername = auth.Username + req.BasicAuthPassword = auth.Password + } + + method, err := prov.AuthorizeRequest(ctx, req) + if err != nil { + return ctx, err + } + ctx = est.NewAuthMethodContext(ctx, est.AuthMethod(method)) + return ctx, nil +} + +func parseCSR(body []byte) (*x509.CertificateRequest, error) { + if len(body) == 0 { + return nil, errors.New("empty body") + } + + return x509.ParseCertificateRequest(body) +} + +func decodeBase64Payload(body []byte) ([]byte, error) { + if len(body) == 0 { + return nil, errors.New("empty body") + } + + trimmed := strings.Map(func(r rune) rune { + switch r { + case ' ', '\n', '\r', '\t': + return -1 + default: + return r + } + }, string(body)) + + if trimmed == "" { + return nil, errors.New("empty base64 payload") + } + + decoded := make([]byte, base64.StdEncoding.DecodedLen(len(trimmed))) + n, err := base64.StdEncoding.Decode(decoded, []byte(trimmed)) + if err != nil { + return nil, fmt.Errorf("invalid base64 payload: %w", err) + } + + return decoded[:n], nil +} + +func requireContentType(r *http.Request, want string) error { + ct := r.Header.Get("Content-Type") + if ct == "" { + return errors.New("missing Content-Type header") + } + mt, _, err := mime.ParseMediaType(ct) + if err != nil { + return fmt.Errorf("invalid Content-Type header: %w", err) + } + if mt != want { + return fmt.Errorf("unsupported Content-Type %q", mt) + } + return nil +} + +func writeResponse(w http.ResponseWriter, r *http.Request, data []byte, contentType string, status int) { + w.Header().Set("Content-Type", contentType) + w.Header().Set("Content-Transfer-Encoding", "base64") + w.WriteHeader(status) + + encoder := base64.NewEncoder(base64.StdEncoding, w) + _, _ = encoder.Write(data) + _ = encoder.Close() +} + +func fail(w http.ResponseWriter, r *http.Request, err error) { + log.Error(w, r, err) + http.Error(w, err.Error(), http.StatusInternalServerError) +} + +func failWithStatus(w http.ResponseWriter, r *http.Request, status int, err error) { + log.Error(w, r, err) + http.Error(w, err.Error(), status) +} diff --git a/est/api/api_test.go b/est/api/api_test.go new file mode 100644 index 000000000..22cbc5f41 --- /dev/null +++ b/est/api/api_test.go @@ -0,0 +1,71 @@ +package api + +import ( + "encoding/base64" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_writeResponse(t *testing.T) { + type args struct { + w http.ResponseWriter + r *http.Request + data []byte + contentType string + status int + } + tests := []struct { + name string + args args + wantBody string + wantHeaders map[string]string + }{ + { + name: "ok", + args: args{ + w: httptest.NewRecorder(), + r: httptest.NewRequest("GET", "/", nil), + data: []byte("hello world"), + contentType: "application/pkcs7-mime; smime-type=certs-only", + status: http.StatusOK, + }, + wantBody: base64.StdEncoding.EncodeToString([]byte("hello world")), + wantHeaders: map[string]string{ + "Content-Type": "application/pkcs7-mime; smime-type=certs-only", + "Content-Transfer-Encoding": "base64", + }, + }, + { + name: "ok/csrattrs", + args: args{ + w: httptest.NewRecorder(), + r: httptest.NewRequest("GET", "/", nil), + data: []byte("attribute data"), + contentType: "application/csrattrs", + status: http.StatusOK, + }, + wantBody: base64.StdEncoding.EncodeToString([]byte("attribute data")), + wantHeaders: map[string]string{ + "Content-Type": "application/csrattrs", + "Content-Transfer-Encoding": "base64", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + writeResponse(tt.args.w, tt.args.r, tt.args.data, tt.args.contentType, tt.args.status) + resp := tt.args.w.(*httptest.ResponseRecorder) + + assert.Equal(t, tt.args.status, resp.Code) + assert.Equal(t, tt.wantBody, resp.Body.String()) + + for k, v := range tt.wantHeaders { + assert.Equal(t, v, resp.Header().Get(k)) + } + }) + } +} diff --git a/est/auth_context.go b/est/auth_context.go new file mode 100644 index 000000000..8f1f6be39 --- /dev/null +++ b/est/auth_context.go @@ -0,0 +1,50 @@ +package est + +import "context" + +// AuthMethod describes the authentication method used for an EST request. +type AuthMethod string + +const ( + AuthMethodTLSClientCertificate AuthMethod = "tls-client-certificate" + AuthMethodTLSExternalClientCertificate AuthMethod = "tls-external-client-certificate" + AuthMethodHTTPBasicAuth AuthMethod = "http-basic-auth" +) + +type authMethodKey struct{} + +// NewAuthMethodContext stores the EST authentication method in the context. +func NewAuthMethodContext(ctx context.Context, method AuthMethod) context.Context { + if method == "" { + return ctx + } + return context.WithValue(ctx, authMethodKey{}, method) +} + +// AuthMethodFromContext returns the EST authentication method stored in the context. +func AuthMethodFromContext(ctx context.Context) (AuthMethod, bool) { + method, ok := ctx.Value(authMethodKey{}).(AuthMethod) + return method, ok +} + +// BasicAuth holds the HTTP basic auth credentials for an EST request. +type BasicAuth struct { + Username string + Password string +} + +type basicAuthKey struct{} + +// NewBasicAuthContext stores the HTTP basic auth credentials in the context. +func NewBasicAuthContext(ctx context.Context, auth BasicAuth) context.Context { + if auth.Username == "" && auth.Password == "" { + return ctx + } + return context.WithValue(ctx, basicAuthKey{}, auth) +} + +// BasicAuthFromContext returns the HTTP basic auth credentials stored in the context. +func BasicAuthFromContext(ctx context.Context) (BasicAuth, bool) { + auth, ok := ctx.Value(basicAuthKey{}).(BasicAuth) + return auth, ok +} diff --git a/est/authority.go b/est/authority.go new file mode 100644 index 000000000..e7a6da41d --- /dev/null +++ b/est/authority.go @@ -0,0 +1,238 @@ +package est + +import ( + "bytes" + "context" + "crypto" + "crypto/x509" + "fmt" + "sync" + + "github.com/smallstep/pkcs7" + + "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/authority/provisioner" +) + +// Authority handles EST interactions. +type Authority struct { + signAuth SignAuthority + roots []*x509.Certificate + intermediates []*x509.Certificate + defaultSigner crypto.Signer + signerCertificate *x509.Certificate + estProvisionerNames []string + provisionersMutex sync.RWMutex +} + +type authorityKey struct{} + +// NewContext adds the given authority to the context. +func NewContext(ctx context.Context, a *Authority) context.Context { + return context.WithValue(ctx, authorityKey{}, a) +} + +// FromContext returns the current authority from the given context. +func FromContext(ctx context.Context) (a *Authority, ok bool) { + a, ok = ctx.Value(authorityKey{}).(*Authority) + return +} + +// MustFromContext returns the current authority from the given context. It will +// panic if the authority is not in the context. +func MustFromContext(ctx context.Context) *Authority { + var ( + a *Authority + ok bool + ) + if a, ok = FromContext(ctx); !ok { + panic("est authority is not in the context") + } + return a +} + +// SignAuthority is the interface for a signing authority. +type SignAuthority interface { + SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + LoadProvisionerByName(string) (provisioner.Interface, error) +} + +// New returns a new Authority that implements the EST interface. +func New(signAuth SignAuthority, opts Options) (*Authority, error) { + if err := opts.Validate(); err != nil { + return nil, err + } + + return &Authority{ + signAuth: signAuth, + roots: opts.Roots, + intermediates: opts.Intermediates, + defaultSigner: opts.Signer, + signerCertificate: opts.SignerCert, + estProvisionerNames: opts.ESTProvisionerNames, + }, nil +} + +// validates if the EST Authority has a valid configuration. +func (a *Authority) Validate() error { + if a == nil { + return nil + } + + a.provisionersMutex.RLock() + defer a.provisionersMutex.RUnlock() + + noDefaultSignerAvailable := a.defaultSigner == nil || a.signerCertificate == nil + for _, name := range a.estProvisionerNames { + p, err := a.LoadProvisionerByName(name) + if err != nil { + return fmt.Errorf("failed loading provisioner %q: %w", name, err) + } + if estProv, ok := p.(*provisioner.EST); ok { + cert, signer := estProv.GetSigner() + if cert == nil && noDefaultSignerAvailable { + return fmt.Errorf("EST provisioner %q does not have a signer certificate", name) + } + if signer == nil && noDefaultSignerAvailable { + return fmt.Errorf("EST provisioner %q does not have a signer", name) + } + } + } + + return nil +} + +// UpdateProvisioners updates the EST Authority with the new, and hopefully +// current EST provisioners configured. This allows the Authority to be +// validated with the latest data. +func (a *Authority) UpdateProvisioners(estProvisionerNames []string) { + if a == nil { + return + } + + a.provisionersMutex.Lock() + defer a.provisionersMutex.Unlock() + + a.estProvisionerNames = estProvisionerNames +} + +// LoadProvisionerByName calls out to the SignAuthority interface to load a +// provisioner by name. +func (a *Authority) LoadProvisionerByName(name string) (provisioner.Interface, error) { + return a.signAuth.LoadProvisionerByName(name) +} + +// GetCACertificates returns the certificate chain for the CA. +func (a *Authority) GetCACertificates(ctx context.Context) (certs []*x509.Certificate, err error) { + p := provisionerFromContext(ctx) + + if signerCert, _ := p.GetSigner(); signerCert != nil { + certs = append(certs, signerCert) + } + + if p.ShouldIncludeIntermediateInChain() || len(certs) == 0 { + certs = append(certs, a.intermediates...) + } + + if p.ShouldIncludeRootInChain() { + certs = append(certs, a.roots...) + } + + return certs, nil +} + +// SignCSR signs the CSR using the provisioner and returns the issued chain. +func (a *Authority) SignCSR(ctx context.Context, csr *x509.CertificateRequest, signCSROpts ...provisioner.SignCSROption) (*x509.Certificate, error) { + // TODO: intermediate storage of the request? In EST it's possible to request a csr/certificate + // to be signed, which can be performed asynchronously / out-of-band. In that case a client can + // poll for the status. It seems to be similar as what can happen in ACME and SCEP, so might want to model + // the implementation after the one in the ACME authority. Requires storage, etc. + // ref: https://datatracker.ietf.org/doc/html/rfc7030#section-4.2.3 + p := provisionerFromContext(ctx) + + // Template data + sans := []string{} + sans = append(sans, csr.DNSNames...) + sans = append(sans, csr.EmailAddresses...) + for _, v := range csr.IPAddresses { + sans = append(sans, v.String()) + } + for _, v := range csr.URIs { + sans = append(sans, v.String()) + } + if len(sans) == 0 { + sans = append(sans, csr.Subject.CommonName) + } + data := x509util.CreateTemplateData(csr.Subject.CommonName, sans) + data.SetCertificateRequest(csr) + data.SetSubject(x509util.Subject{ + Country: csr.Subject.Country, + Organization: csr.Subject.Organization, + OrganizationalUnit: csr.Subject.OrganizationalUnit, + Locality: csr.Subject.Locality, + Province: csr.Subject.Province, + StreetAddress: csr.Subject.StreetAddress, + PostalCode: csr.Subject.PostalCode, + SerialNumber: csr.Subject.SerialNumber, + CommonName: csr.Subject.CommonName, + }) + + for _, o := range signCSROpts { + if m, ok := o.(provisioner.TemplateDataModifier); ok { + m.Modify(data) + } + } + + ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) + signOps, err := p.AuthorizeSign(ctx, "") + if err != nil { + return nil, fmt.Errorf("error retrieving authorization options from EST provisioner: %w", err) + } + for _, signOp := range signOps { + if wc, ok := signOp.(*provisioner.WebhookController); ok { + wc.TemplateData = data + } + } + + opts := provisioner.SignOptions{} + templateOptions, err := provisioner.TemplateOptions(p.GetOptions(), data) + if err != nil { + return nil, fmt.Errorf("error creating template options from EST provisioner: %w", err) + } + signOps = append(signOps, templateOptions) + + certChain, err := a.signAuth.SignWithContext(ctx, csr, opts, signOps...) + if err != nil { + return nil, fmt.Errorf("error generating certificate: %w", err) + } + // return leaf certificate (only): https://datatracker.ietf.org/doc/html/rfc7030#section-4.2.3 + return certChain[0], nil +} + +// BuildResponse returns a certs-only PKCS7 SignedData for the given certs. +func (a *Authority) BuildResponse(ctx context.Context, certs []*x509.Certificate) ([]byte, error) { + if len(certs) == 0 { + return nil, fmt.Errorf("no certificates to encode") + } + // Build degenerate PKCS7: SignedData with no encapsulated content or signer infos. + var buf bytes.Buffer + for _, cert := range certs { + buf.Write(cert.Raw) + } + degenerate, err := pkcs7.DegenerateCertificate(buf.Bytes()) + if err != nil { + return nil, err + } + return degenerate, nil +} + +func (a *Authority) NotifySuccess(ctx context.Context, csr *x509.CertificateRequest, cert *x509.Certificate, transactionID string) error { + p := provisionerFromContext(ctx) + return p.NotifySuccess(ctx, csr, cert, transactionID) +} + +func (a *Authority) NotifyFailure(ctx context.Context, csr *x509.CertificateRequest, transactionID string, errorCode int, errorDescription string) error { + p := provisionerFromContext(ctx) + return p.NotifyFailure(ctx, csr, transactionID, errorCode, errorDescription) +} diff --git a/est/client_cert.go b/est/client_cert.go new file mode 100644 index 000000000..14c26f291 --- /dev/null +++ b/est/client_cert.go @@ -0,0 +1,37 @@ +package est + +import ( + "context" + "crypto/x509" +) + +type clientCertificateKey struct{} +type clientCertificateChainKey struct{} + +// NewClientCertificateContext stores the TLS client certificate in the context. +func NewClientCertificateContext(ctx context.Context, cert *x509.Certificate) context.Context { + if cert == nil { + return ctx + } + return context.WithValue(ctx, clientCertificateKey{}, cert) +} + +// ClientCertificateFromContext returns the TLS client certificate stored in the context. +func ClientCertificateFromContext(ctx context.Context) (*x509.Certificate, bool) { + cert, ok := ctx.Value(clientCertificateKey{}).(*x509.Certificate) + return cert, ok +} + +// NewClientCertificateChainContext stores the TLS client certificate chain in the context. +func NewClientCertificateChainContext(ctx context.Context, chain []*x509.Certificate) context.Context { + if len(chain) == 0 { + return ctx + } + return context.WithValue(ctx, clientCertificateChainKey{}, chain) +} + +// ClientCertificateChainFromContext returns the TLS client certificate chain stored in the context. +func ClientCertificateChainFromContext(ctx context.Context) ([]*x509.Certificate, bool) { + chain, ok := ctx.Value(clientCertificateChainKey{}).([]*x509.Certificate) + return chain, ok +} diff --git a/est/options.go b/est/options.go new file mode 100644 index 000000000..df2a48b26 --- /dev/null +++ b/est/options.go @@ -0,0 +1,40 @@ +package est + +import ( + "crypto" + "crypto/x509" + "errors" +) + +// Options configures the EST authority instance. +type Options struct { + Roots []*x509.Certificate `json:"-"` + Intermediates []*x509.Certificate `json:"-"` + SignerCert *x509.Certificate `json:"-"` + Signer crypto.Signer `json:"-"` + + ESTProvisionerNames []string +} + +type comparablePublicKey interface { + Equal(crypto.PublicKey) bool +} + +// Validate checks the fields in Options. +func (o *Options) Validate() error { + switch { + case len(o.Intermediates) == 0: + return errors.New("no intermediate certificate available for EST authority") + case o.SignerCert == nil: + return errors.New("no signer certificate available for EST authority") + } + + if o.Signer != nil { + signerPublicKey := o.Signer.Public().(comparablePublicKey) + if !signerPublicKey.Equal(o.SignerCert.PublicKey) { + return errors.New("mismatch between signer certificate and public key") + } + } + + return nil +} diff --git a/est/provisioner.go b/est/provisioner.go new file mode 100644 index 000000000..a1e9b6c3f --- /dev/null +++ b/est/provisioner.go @@ -0,0 +1,46 @@ +package est + +import ( + "context" + "crypto" + "crypto/x509" + + "github.com/smallstep/certificates/authority/provisioner" +) + +// Provisioner is an interface that embeds the generic provisioner.Interface and +// adds EST-specific helpers. +type Provisioner interface { + provisioner.Interface + GetOptions() *provisioner.Options + ShouldIncludeRootInChain() bool + ShouldIncludeIntermediateInChain() bool + GetSigner() (*x509.Certificate, crypto.Signer) + AuthorizeRequest(ctx context.Context, req provisioner.ESTAuthRequest) (provisioner.ESTAuthMethod, error) + NotifySuccess(ctx context.Context, csr *x509.CertificateRequest, cert *x509.Certificate, transactionID string) error + NotifyFailure(ctx context.Context, csr *x509.CertificateRequest, transactionID string, errorCode int, errorDescription string) error + GetCSRAttributes(ctx context.Context) ([]byte, error) +} + +// provisionerKey is the key type for storing and searching an EST provisioner in the context. +type provisionerKey struct{} + +// provisionerFromContext searches the context for an EST provisioner. +// Returns the provisioner or panics if no EST provisioner is found. +func provisionerFromContext(ctx context.Context) Provisioner { + p, ok := ctx.Value(provisionerKey{}).(Provisioner) + if !ok { + panic("EST provisioner expected in request context") + } + return p +} + +// NewProvisionerContext returns a new context with the EST provisioner set. +func NewProvisionerContext(ctx context.Context, p Provisioner) context.Context { + return context.WithValue(ctx, provisionerKey{}, p) +} + +// ProvisionerFromContext returns the EST provisioner stored in the context. +func ProvisionerFromContext(ctx context.Context) Provisioner { + return provisionerFromContext(ctx) +} From 093fbeb0d7e9d4aba7e9854a220c293c486b29e2 Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Pin Date: Mon, 30 Mar 2026 16:06:23 +0200 Subject: [PATCH 2/8] add linkedca support for EST need for db migration and authority startup use replace in go.mod to use a linkedca version compatible (need for the build) --- authority/admin/db.go | 2 ++ authority/provisioner/est.go | 43 ++++++++++++++++----------------- authority/provisioners.go | 47 ++++++++++++++++++++++++++++++++++++ go.mod | 4 ++- go.sum | 8 +++--- 5 files changed, 77 insertions(+), 27 deletions(-) diff --git a/authority/admin/db.go b/authority/admin/db.go index 63940a8a3..3a5961c07 100644 --- a/authority/admin/db.go +++ b/authority/admin/db.go @@ -46,6 +46,8 @@ func UnmarshalProvisionerDetails(typ linkedca.Provisioner_Type, data []byte) (*l v.Data = new(linkedca.ProvisionerDetails_SCEP) case linkedca.Provisioner_NEBULA: v.Data = new(linkedca.ProvisionerDetails_Nebula) + case linkedca.Provisioner_EST: + v.Data = new(linkedca.ProvisionerDetails_EST) default: return nil, fmt.Errorf("unsupported provisioner type %s", typ) } diff --git a/authority/provisioner/est.go b/authority/provisioner/est.go index 43479d24e..cd3efaf1a 100644 --- a/authority/provisioner/est.go +++ b/authority/provisioner/est.go @@ -16,28 +16,27 @@ import ( // EST is the EST provisioner type, an entity that can authorize the EST flow. type EST struct { *base - ID string `json:"-"` - Type string `json:"type"` - Name string `json:"name"` - EnableTLSClientCertificate *bool `json:"enableTLSClientCertificate,omitempty"` - EnableHTTPBasicAuth *bool `json:"enableHTTPBasicAuth,omitempty"` - BasicAuthUsername string `json:"basicAuthUsername,omitempty"` - BasicAuthPassword string `json:"basicAuthPassword,omitempty"` - ClientCertificateRoots []byte `json:"clientCertificateRoots,omitempty"` - ForceCN bool `json:"forceCN,omitempty"` - Capabilities []string `json:"capabilities,omitempty"` - IncludeRoot bool `json:"includeRoot,omitempty"` - ExcludeIntermediate bool `json:"excludeIntermediate,omitempty"` - MinimumPublicKeyLength int `json:"minimumPublicKeyLength,omitempty"` - CSRAttrs []byte `json:"csrAttrs,omitempty"` - Options *Options `json:"options,omitempty"` - Claims *Claims `json:"claims,omitempty"` - ctl *Controller - signer crypto.Signer - signerCertificate *x509.Certificate - challengeValidationController *challengeValidationController - notificationController *notificationController - clientCertificateRootPool *x509.CertPool + ID string `json:"-"` + Type string `json:"type"` + Name string `json:"name"` + EnableTLSClientCertificate *bool `json:"enableTLSClientCertificate,omitempty"` + EnableHTTPBasicAuth *bool `json:"enableHTTPBasicAuth,omitempty"` + BasicAuthUsername string `json:"basicAuthUsername,omitempty"` + BasicAuthPassword string `json:"basicAuthPassword,omitempty"` + ClientCertificateRoots []byte `json:"clientCertificateRoots,omitempty"` + ForceCN bool `json:"forceCN,omitempty"` + IncludeRoot bool `json:"includeRoot,omitempty"` + ExcludeIntermediate bool `json:"excludeIntermediate,omitempty"` + MinimumPublicKeyLength int `json:"minimumPublicKeyLength,omitempty"` + CSRAttrs []byte `json:"csrAttrs,omitempty"` + Options *Options `json:"options,omitempty"` + Claims *Claims `json:"claims,omitempty"` + ctl *Controller + signer crypto.Signer + signerCertificate *x509.Certificate + challengeValidationController *challengeValidationController + notificationController *notificationController + clientCertificateRootPool *x509.CertPool } // GetID returns the provisioner unique identifier. diff --git a/authority/provisioners.go b/authority/provisioners.go index d01048564..4bc42cd00 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -1003,6 +1003,25 @@ func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface, s.DecrypterKeyPassword = string(decrypter.KeyPassword) } return s, nil + case *linkedca.ProvisionerDetails_EST: + cfg := d.EST + enableTLSClientCertificate := cfg.EnableTlsClientCertificate + enableHTTPBasicAuth := cfg.EnableHttpBasicAuth + return &provisioner.EST{ + ID: p.Id, + Type: p.Type.String(), + Name: p.Name, + EnableTLSClientCertificate: &enableTLSClientCertificate, + EnableHTTPBasicAuth: &enableHTTPBasicAuth, + BasicAuthUsername: cfg.BasicAuthUsername, + BasicAuthPassword: cfg.BasicAuthPassword, + ClientCertificateRoots: provisionerPEMToCertificates(cfg.ClientCertificateRoots), + ForceCN: cfg.ForceCn, + IncludeRoot: cfg.IncludeRoot, + MinimumPublicKeyLength: int(cfg.MinimumPublicKeyLength), + Claims: claims, + Options: options, + }, nil case *linkedca.ProvisionerDetails_Nebula: var roots []byte for i, root := range d.Nebula.GetRoots() { @@ -1278,6 +1297,34 @@ func ProvisionerToLinkedca(p provisioner.Interface) (*linkedca.Provisioner, erro SshTemplate: sshTemplate, Webhooks: webhooks, }, nil + case *provisioner.EST: + x509Template, sshTemplate, webhooks, err := provisionerOptionsToLinkedca(p.Options) + if err != nil { + return nil, err + } + return &linkedca.Provisioner{ + Id: p.ID, + Type: linkedca.Provisioner_EST, + Name: p.GetName(), + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_EST{ + EST: &linkedca.ESTProvisioner{ + ForceCn: p.ForceCN, + EnableTlsClientCertificate: p.EnableTLSClientCertificate != nil && *p.EnableTLSClientCertificate, + EnableHttpBasicAuth: p.EnableHTTPBasicAuth != nil && *p.EnableHTTPBasicAuth, + MinimumPublicKeyLength: cast.Int32(p.MinimumPublicKeyLength), + IncludeRoot: p.IncludeRoot, + BasicAuthUsername: p.BasicAuthUsername, + BasicAuthPassword: p.BasicAuthPassword, + ClientCertificateRoots: provisionerPEMToLinkedca(p.ClientCertificateRoots), + }, + }, + }, + Claims: claimsToLinkedca(p.Claims), + X509Template: x509Template, + SshTemplate: sshTemplate, + Webhooks: webhooks, + }, nil case *provisioner.Nebula: x509Template, sshTemplate, webhooks, err := provisionerOptionsToLinkedca(p.Options) if err != nil { diff --git a/go.mod b/go.mod index b7f8b1855..8455bff66 100644 --- a/go.mod +++ b/go.mod @@ -172,6 +172,8 @@ require ( google.golang.org/genproto v0.0.0-20260316180232-0b37fe3546d5 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260316180232-0b37fe3546d5 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260316180232-0b37fe3546d5 // indirect - google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1 // indirect + google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.6.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/smallstep/linkedca => github.com/jbpin/linkedca v0.0.0-20251224103807-5e7deb3d4d49 diff --git a/go.sum b/go.sum index 200fd5604..3aafd6eb4 100644 --- a/go.sum +++ b/go.sum @@ -256,6 +256,8 @@ github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jbpin/linkedca v0.0.0-20251224103807-5e7deb3d4d49 h1:rFx5O4hmVJ3fQ90n1PP8fjBWIQoSuBSqbdts6UfnaVk= +github.com/jbpin/linkedca v0.0.0-20251224103807-5e7deb3d4d49/go.mod h1:bUm0HgkgtOjbnwShtNKu3XPuQ6AA/o680sxTfy9vdYk= github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= @@ -352,8 +354,6 @@ github.com/smallstep/cli-utils v0.12.2 h1:lGzM9PJrH/qawbzMC/s2SvgLdJPKDWKwKzx9do github.com/smallstep/cli-utils v0.12.2/go.mod h1:uCPqefO29goHLGqFnwk0i8W7XJu18X3WHQFRtOm/00Y= github.com/smallstep/go-attestation v0.4.4-0.20241119153605-2306d5b464ca h1:VX8L0r8vybH0bPeaIxh4NQzafKQiqvlOn8pmOXbFLO4= github.com/smallstep/go-attestation v0.4.4-0.20241119153605-2306d5b464ca/go.mod h1:vNAduivU014fubg6ewygkAvQC0IQVXqdc8vaGl/0er4= -github.com/smallstep/linkedca v0.25.0 h1:txT9QHGbCsJq0MhAghBq7qhurGY727tQuqUi+n4BVBo= -github.com/smallstep/linkedca v0.25.0/go.mod h1:Q3jVAauFKNlF86W5/RFtgQeyDKz98GL/KN3KG4mJOvc= github.com/smallstep/nosql v0.8.0 h1:FBTCUfKPmWYbrozW+RBKu+fnvbn+zr5rVli/XB4Jp4A= github.com/smallstep/nosql v0.8.0/go.mod h1:5dUpNotHLHhOUapP0PLBVVfp3tG1DFC31VRccg+Cqwo= github.com/smallstep/pkcs7 v0.2.1 h1:6Kfzr/QizdIuB6LSv8y1LJdZ3aPSfTNhTLqAx9CTLfA= @@ -533,8 +533,8 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20260316180232-0b37fe3546d5 h1: google.golang.org/genproto/googleapis/rpc v0.0.0-20260316180232-0b37fe3546d5/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE= google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= -google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1 h1:F29+wU6Ee6qgu9TddPgooOdaqsxTMunOoj8KA5yuS5A= -google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1/go.mod h1:5KF+wpkbTSbGcR9zteSqZV6fqFOWBl4Yde8En8MryZA= +google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.6.0 h1:6Al3kEFFP9VJhRz3DID6quisgPnTeZVr4lep9kkxdPA= +google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.6.0/go.mod h1:QLvsjh0OIR0TYBeiu2bkWGTJBUNQ64st52iWj/yA93I= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From 08f3322dc7d9b91fdc70ef2b7911db655ba5ad41 Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Pin Date: Mon, 30 Mar 2026 16:07:00 +0200 Subject: [PATCH 3/8] refactor EST auth and add support for signOpts --- authority/provisioner/est.go | 9 +- authority/provisioner/est_auth.go | 217 ++++++++++-------------------- est/api/api.go | 13 +- est/provisioner.go | 2 +- 4 files changed, 78 insertions(+), 163 deletions(-) diff --git a/authority/provisioner/est.go b/authority/provisioner/est.go index cd3efaf1a..2dc1c4b81 100644 --- a/authority/provisioner/est.go +++ b/authority/provisioner/est.go @@ -85,21 +85,18 @@ func (s *EST) DefaultTLSCertDuration() time.Duration { // newChallengeValidationController creates a new challengeValidationController // that performs challenge validation through webhooks. func newESTChallengeValidationController(client HTTPClient, tw httptransport.Wrapper, webhooks []*Webhook) *challengeValidationController { - scepHooks := []*Webhook{} + estHooks := []*Webhook{} for _, wh := range webhooks { // if wh.Kind != linkedca.Webhook_ESTCHALLENGE.String() { if wh.Kind != "ESTCHALLENGE" { continue } - if !isCertTypeOK(wh) { - continue - } - scepHooks = append(scepHooks, wh) + estHooks = append(estHooks, wh) } return &challengeValidationController{ client: client, wrapTransport: tw, - webhooks: scepHooks, + webhooks: estHooks, } } diff --git a/authority/provisioner/est_auth.go b/authority/provisioner/est_auth.go index 897279879..b7737e994 100644 --- a/authority/provisioner/est_auth.go +++ b/authority/provisioner/est_auth.go @@ -9,11 +9,14 @@ import ( "fmt" "github.com/smallstep/certificates/webhook" + "go.step.sm/crypto/x509util" ) var ( - ErrESTAuthMethodDisabled = errors.New("est authentication method disabled") - ErrESTAuthDenied = errors.New("est authentication denied") + ErrESTAuthMethodDisabled = errors.New("est authentication method disabled") + ErrESTAuthMethodNotFound = errors.New("no valid est authentication method found") + ErrESTAuthMethodMisconfigured = errors.New("est authentication method misconfigured") + ErrESTAuthDenied = errors.New("est authentication denied") ) // ESTAuthMethod identifies the EST authentication method used. @@ -37,223 +40,139 @@ type ESTAuthRequest struct { } // AuthorizeRequest validates the request against configured EST auth methods. -func (s *EST) AuthorizeRequest(ctx context.Context, req ESTAuthRequest) (ESTAuthMethod, error) { +func (s *EST) AuthorizeRequest(ctx context.Context, req ESTAuthRequest) ([]SignCSROption, error) { if s.hasAuthWebhooks() { - return s.authorizeRequestWithWebhook(ctx, req) + return s.authorizeWithWebhook(ctx, &req) } return s.authorizeRequestLocal(ctx, req) } // AuthorizeTLSClientCertificate validates a CA-issued client certificate. -func (s *EST) AuthorizeTLSClientCertificate(ctx context.Context, cert *x509.Certificate, chain, roots, intermediates []*x509.Certificate) error { - method, err := s.AuthorizeRequest(ctx, ESTAuthRequest{ +func (s *EST) AuthorizeTLSClientCertificate(ctx context.Context, csr *x509.CertificateRequest, cert *x509.Certificate, chain, roots, intermediates []*x509.Certificate) error { + _, err := s.AuthorizeRequest(ctx, ESTAuthRequest{ + CSR: csr, ClientCertificate: cert, ClientCertificateChain: chain, CARoots: roots, CAIntermediates: intermediates, }) - if err != nil { - return err - } - if method != ESTAuthMethodTLSClientCertificate { - return ErrESTAuthDenied - } - return nil + return err } // AuthorizeTLSExternalClientCertificate validates a client certificate against external roots. -func (s *EST) AuthorizeTLSExternalClientCertificate(ctx context.Context, cert *x509.Certificate, chain []*x509.Certificate) error { - method, err := s.AuthorizeRequest(ctx, ESTAuthRequest{ +func (s *EST) AuthorizeTLSExternalClientCertificate(ctx context.Context, csr *x509.CertificateRequest, cert *x509.Certificate, chain []*x509.Certificate) error { + _, err := s.AuthorizeRequest(ctx, ESTAuthRequest{ + CSR: csr, ClientCertificate: cert, ClientCertificateChain: chain, }) - if err != nil { - return err - } - if method != ESTAuthMethodTLSExternalClientCertificate { - return ErrESTAuthDenied - } - return nil + return err } // AuthorizeHTTPBasicAuth validates a username/password pair for EST. func (s *EST) AuthorizeHTTPBasicAuth(ctx context.Context, csr *x509.CertificateRequest, username, password string) error { - method, err := s.AuthorizeRequest(ctx, ESTAuthRequest{ + _, err := s.AuthorizeRequest(ctx, ESTAuthRequest{ CSR: csr, BasicAuthUsername: username, BasicAuthPassword: password, }) - if err != nil { - return err - } - if method != ESTAuthMethodHTTPBasicAuth { - return ErrESTAuthDenied - } - return nil -} - -// authorizeRequestWithWebhook delegates authentication to EST webhooks. -func (s *EST) authorizeRequestWithWebhook(ctx context.Context, req ESTAuthRequest) (ESTAuthMethod, error) { - if req.ClientCertificate != nil { - method, err := s.preferredCertAuthMethod() - if err != nil { - return "", err - } - if err := s.authorizeWithWebhook(ctx, req.ClientCertificate, req.CSR, ""); err != nil { - return "", err - } - return method, nil - } - - if req.hasBasicAuth() { - method, err := s.preferredBasicAuthMethod() - if err != nil { - return "", err - } - if method == ESTAuthMethodHTTPBasicAuth { - if req.BasicAuthPassword == "" { - return "", errors.New("missing basic auth credentials") - } - } - if req.CSR == nil { - return "", errors.New("missing CSR for basic auth validation") - } - opts := []webhook.RequestBodyOption{} - if req.BasicAuthUsername != "" { - opts = append(opts, webhook.WithAuthorizationPrincipal(req.BasicAuthUsername)) - } - if err := s.authorizeWithWebhook(ctx, nil, req.CSR, req.BasicAuthPassword, opts...); err != nil { - return "", err - } - return method, nil - } - - return "", errors.New("missing client certificate or basic auth") + return err } // authorizeRequestLocal validates the request using provisioner configuration. -func (s *EST) authorizeRequestLocal(ctx context.Context, req ESTAuthRequest) (ESTAuthMethod, error) { +func (s *EST) authorizeRequestLocal(ctx context.Context, req ESTAuthRequest) ([]SignCSROption, error) { + var lastErr error = ErrESTAuthMethodNotFound if req.ClientCertificate != nil { - var lastErr error if boolValue(s.EnableTLSClientCertificate, false) { - if err := verifyCertificate(req.ClientCertificate, req.ClientCertificateChain, req.CARoots, req.CAIntermediates); err == nil { - return ESTAuthMethodTLSClientCertificate, nil + if s.hasClientCertificateRoots() { + if err := verifyCertificateWithPool(req.ClientCertificate, req.ClientCertificateChain, s.clientCertificateRootPool, nil); err == nil { + return []SignCSROption{}, nil + } else { + lastErr = err + } } else { - lastErr = err - } - } - if s.hasClientCertificateRoots() { - if s.clientCertificateRootPool == nil { - lastErr = ErrESTAuthMethodDisabled - } else if err := verifyCertificateWithPool(req.ClientCertificate, req.ClientCertificateChain, s.clientCertificateRootPool, nil); err == nil { - return ESTAuthMethodTLSExternalClientCertificate, nil - } else { - lastErr = err + if err := verifyCertificate(req.ClientCertificate, req.ClientCertificateChain, req.CARoots, req.CAIntermediates); err == nil { + return []SignCSROption{}, nil + } else { + lastErr = err + } } + } else { + lastErr = ErrESTAuthMethodDisabled } - if lastErr != nil { - return "", lastErr - } - return "", ErrESTAuthMethodDisabled } if req.hasBasicAuth() { - if boolValue(s.EnableHTTPBasicAuth, false) { - if req.BasicAuthPassword == "" { - return "", errors.New("missing basic auth credentials") - } - if s.BasicAuthUsername != "" && req.BasicAuthUsername != s.BasicAuthUsername { - return "", errors.New("invalid basic auth username") - } - if err := s.validateBasicAuthPassword(req.BasicAuthPassword); err != nil { - return "", err + if boolValue(s.EnableHTTPBasicAuth, false) && s.BasicAuthPassword != "" { + if err := s.validateBasicAuthPassword(req.BasicAuthUsername, req.BasicAuthPassword); err == nil { + return []SignCSROption{}, nil + } else { + lastErr = err } - return ESTAuthMethodHTTPBasicAuth, nil + } else { + lastErr = ErrESTAuthMethodDisabled } - return "", ErrESTAuthMethodDisabled - } - - return "", errors.New("missing client certificate or basic auth") -} - -// preferredCertAuthMethod selects the enabled certificate-based auth method. -func (s *EST) preferredCertAuthMethod() (ESTAuthMethod, error) { - switch { - case boolValue(s.EnableTLSClientCertificate, false): - return ESTAuthMethodTLSClientCertificate, nil - case s.hasClientCertificateRoots(): - return ESTAuthMethodTLSExternalClientCertificate, nil - default: - return "", ErrESTAuthMethodDisabled } -} -// preferredBasicAuthMethod selects the enabled basic-auth-based method. -func (s *EST) preferredBasicAuthMethod() (ESTAuthMethod, error) { - switch { - case boolValue(s.EnableHTTPBasicAuth, false): - return ESTAuthMethodHTTPBasicAuth, nil - default: - return "", ErrESTAuthMethodDisabled - } + return nil, lastErr } // validateBasicAuthPassword verifies the configured basic auth password. -func (s *EST) validateBasicAuthPassword(password string) error { - if s.BasicAuthPassword == "" { - return errors.New("basic auth password is not configured") +func (s *EST) validateBasicAuthPassword(username, password string) error { + if s.BasicAuthUsername != "" && username != s.BasicAuthUsername { + return errors.New("invalid basic auth") } if subtleCompare(s.BasicAuthPassword, password) { return nil } - return errors.New("invalid basic auth password") + return errors.New("invalid basic auth") } // authorizeWithWebhook executes configured webhooks for auth decisions. -func (s *EST) authorizeWithWebhook(ctx context.Context, cert *x509.Certificate, csr *x509.CertificateRequest, secret string, opts ...webhook.RequestBodyOption) error { +func (s *EST) authorizeWithWebhook(ctx context.Context, req *ESTAuthRequest) ([]SignCSROption, error) { if !s.hasAuthWebhooks() { - return nil + return nil, ErrESTAuthMethodMisconfigured } var ( - req *webhook.RequestBody - err error + whreq *webhook.RequestBody + err error ) switch { - case cert != nil: - req, err = webhook.NewRequestBody(append(opts, webhook.WithX509Certificate(nil, cert))...) + case req.ClientCertificate != nil: + whreq, err = webhook.NewRequestBody(webhook.WithX509CertificateRequest(req.CSR), webhook.WithX509Certificate(nil, req.ClientCertificate)) if err != nil { - return fmt.Errorf("failed creating webhook request: %w", err) + return nil, fmt.Errorf("failed creating webhook request: %w", err) } - if req.X509Certificate != nil { - req.X509Certificate.Raw = cert.Raw - } - case csr != nil: - req, err = webhook.NewRequestBody(append(opts, webhook.WithX509CertificateRequest(csr))...) + case req.hasBasicAuth(): + whreq, err = webhook.NewRequestBody(webhook.WithX509CertificateRequest(req.CSR), webhook.WithAuthorizationPrincipal(req.BasicAuthUsername)) if err != nil { - return fmt.Errorf("failed creating webhook request: %w", err) + return nil, fmt.Errorf("failed creating webhook request: %w", err) } + whreq.SCEPChallenge = req.BasicAuthPassword default: - return errors.New("missing certificate or CSR for webhook validation") - } - - req.ProvisionerName = s.Name - if secret != "" { - // TODO: change this to add a dedicated field in the webhook request body (or rename it but can broken existing webhooks) - req.SCEPChallenge = secret + return nil, errors.New("missing certificate or basic auth for webhook validation") } + whreq.ProvisionerName = s.Name + var opts []SignCSROption for _, wh := range s.challengeValidationController.webhooks { - resp, err := wh.DoWithContext(ctx, s.challengeValidationController.client, s.challengeValidationController.wrapTransport, req, nil) + resp, err := wh.DoWithContext(ctx, s.challengeValidationController.client, s.challengeValidationController.wrapTransport, whreq, nil) if err != nil { - return fmt.Errorf("failed executing webhook request: %w", err) + return nil, fmt.Errorf("failed executing webhook request: %w", err) } if resp.Allow { - return nil + opts = append(opts, TemplateDataModifierFunc(func(data x509util.TemplateData) { + data.SetWebhook(wh.Name, resp.Data) + })) } } - return ErrESTAuthDenied + if len(opts) == 0 { + return nil, ErrESTAuthDenied + } + + return opts, nil } // hasBasicAuth reports whether any basic auth data is present. diff --git a/est/api/api.go b/est/api/api.go index 805385f76..27ccffed3 100644 --- a/est/api/api.go +++ b/est/api/api.go @@ -154,7 +154,7 @@ func enrollHandler(w http.ResponseWriter, r *http.Request) { return } - ctx, err = authorizeEnrollRequest(ctx, csr) + opts, err := authorizeEnrollRequest(ctx, csr) if err != nil { failWithStatus(w, r, http.StatusUnauthorized, err) return @@ -163,7 +163,7 @@ func enrollHandler(w http.ResponseWriter, r *http.Request) { r = r.WithContext(ctx) auth := est.MustFromContext(ctx) - issued, err := auth.SignCSR(ctx, csr) + issued, err := auth.SignCSR(ctx, csr, opts...) if err != nil { failWithStatus(w, r, http.StatusInternalServerError, fmt.Errorf("failed issuing certificate: %w", err)) return @@ -207,7 +207,7 @@ func authContextFromRequest(ctx context.Context, r *http.Request) (context.Conte } // authorizeEnrollRequest validates the request against provisioner-configured auth methods. -func authorizeEnrollRequest(ctx context.Context, csr *x509.CertificateRequest) (context.Context, error) { +func authorizeEnrollRequest(ctx context.Context, csr *x509.CertificateRequest) ([]provisioner.SignCSROption, error) { prov := est.ProvisionerFromContext(ctx) ca := authority.MustFromContext(ctx) @@ -225,12 +225,11 @@ func authorizeEnrollRequest(ctx context.Context, csr *x509.CertificateRequest) ( req.BasicAuthPassword = auth.Password } - method, err := prov.AuthorizeRequest(ctx, req) + opts, err := prov.AuthorizeRequest(ctx, req) if err != nil { - return ctx, err + return nil, err } - ctx = est.NewAuthMethodContext(ctx, est.AuthMethod(method)) - return ctx, nil + return opts, nil } func parseCSR(body []byte) (*x509.CertificateRequest, error) { diff --git a/est/provisioner.go b/est/provisioner.go index a1e9b6c3f..512cbf3f8 100644 --- a/est/provisioner.go +++ b/est/provisioner.go @@ -16,7 +16,7 @@ type Provisioner interface { ShouldIncludeRootInChain() bool ShouldIncludeIntermediateInChain() bool GetSigner() (*x509.Certificate, crypto.Signer) - AuthorizeRequest(ctx context.Context, req provisioner.ESTAuthRequest) (provisioner.ESTAuthMethod, error) + AuthorizeRequest(ctx context.Context, req provisioner.ESTAuthRequest) ([]provisioner.SignCSROption, error) NotifySuccess(ctx context.Context, csr *x509.CertificateRequest, cert *x509.Certificate, transactionID string) error NotifyFailure(ctx context.Context, csr *x509.CertificateRequest, transactionID string, errorCode int, errorDescription string) error GetCSRAttributes(ctx context.Context) ([]byte, error) From 7db07a41ae079a0d1604cafda6d38866cd685d48 Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Pin Date: Mon, 30 Mar 2026 16:07:00 +0200 Subject: [PATCH 4/8] Add support for Bearer token and Authentication header in webhook requests Clean up code implementation --- authority/provisioner/est_auth.go | 58 ++++++------------------------- est/api/api.go | 47 +++++++++++++++++-------- est/auth_context.go | 44 +++++++++++++---------- webhook/options.go | 13 +++++++ webhook/types.go | 3 ++ 5 files changed, 85 insertions(+), 80 deletions(-) diff --git a/authority/provisioner/est_auth.go b/authority/provisioner/est_auth.go index b7737e994..e1699b0aa 100644 --- a/authority/provisioner/est_auth.go +++ b/authority/provisioner/est_auth.go @@ -19,15 +19,6 @@ var ( ErrESTAuthDenied = errors.New("est authentication denied") ) -// ESTAuthMethod identifies the EST authentication method used. -type ESTAuthMethod string - -const ( - ESTAuthMethodTLSClientCertificate ESTAuthMethod = "tls-client-certificate" - ESTAuthMethodTLSExternalClientCertificate ESTAuthMethod = "tls-external-client-certificate" - ESTAuthMethodHTTPBasicAuth ESTAuthMethod = "http-basic-auth" -) - // ESTAuthRequest contains authentication material extracted from the request. type ESTAuthRequest struct { CSR *x509.CertificateRequest @@ -35,8 +26,10 @@ type ESTAuthRequest struct { ClientCertificateChain []*x509.Certificate CARoots []*x509.Certificate CAIntermediates []*x509.Certificate + AuthenticationHeader string BasicAuthUsername string BasicAuthPassword string + BearerToken string } // AuthorizeRequest validates the request against configured EST auth methods. @@ -44,43 +37,11 @@ func (s *EST) AuthorizeRequest(ctx context.Context, req ESTAuthRequest) ([]SignC if s.hasAuthWebhooks() { return s.authorizeWithWebhook(ctx, &req) } - return s.authorizeRequestLocal(ctx, req) -} - -// AuthorizeTLSClientCertificate validates a CA-issued client certificate. -func (s *EST) AuthorizeTLSClientCertificate(ctx context.Context, csr *x509.CertificateRequest, cert *x509.Certificate, chain, roots, intermediates []*x509.Certificate) error { - _, err := s.AuthorizeRequest(ctx, ESTAuthRequest{ - CSR: csr, - ClientCertificate: cert, - ClientCertificateChain: chain, - CARoots: roots, - CAIntermediates: intermediates, - }) - return err -} - -// AuthorizeTLSExternalClientCertificate validates a client certificate against external roots. -func (s *EST) AuthorizeTLSExternalClientCertificate(ctx context.Context, csr *x509.CertificateRequest, cert *x509.Certificate, chain []*x509.Certificate) error { - _, err := s.AuthorizeRequest(ctx, ESTAuthRequest{ - CSR: csr, - ClientCertificate: cert, - ClientCertificateChain: chain, - }) - return err -} - -// AuthorizeHTTPBasicAuth validates a username/password pair for EST. -func (s *EST) AuthorizeHTTPBasicAuth(ctx context.Context, csr *x509.CertificateRequest, username, password string) error { - _, err := s.AuthorizeRequest(ctx, ESTAuthRequest{ - CSR: csr, - BasicAuthUsername: username, - BasicAuthPassword: password, - }) - return err + return s.authorizeRequestLocal(req) } // authorizeRequestLocal validates the request using provisioner configuration. -func (s *EST) authorizeRequestLocal(ctx context.Context, req ESTAuthRequest) ([]SignCSROption, error) { +func (s *EST) authorizeRequestLocal(req ESTAuthRequest) ([]SignCSROption, error) { var lastErr error = ErrESTAuthMethodNotFound if req.ClientCertificate != nil { if boolValue(s.EnableTLSClientCertificate, false) { @@ -144,12 +105,14 @@ func (s *EST) authorizeWithWebhook(ctx context.Context, req *ESTAuthRequest) ([] if err != nil { return nil, fmt.Errorf("failed creating webhook request: %w", err) } - case req.hasBasicAuth(): - whreq, err = webhook.NewRequestBody(webhook.WithX509CertificateRequest(req.CSR), webhook.WithAuthorizationPrincipal(req.BasicAuthUsername)) + case req.AuthenticationHeader != "": + whreq, err = webhook.NewRequestBody(webhook.WithX509CertificateRequest(req.CSR), webhook.WithAuthenticationHeader(req.AuthenticationHeader)) if err != nil { return nil, fmt.Errorf("failed creating webhook request: %w", err) } - whreq.SCEPChallenge = req.BasicAuthPassword + if req.BearerToken != "" { + whreq.BearerToken = req.BearerToken + } default: return nil, errors.New("missing certificate or basic auth for webhook validation") } @@ -187,12 +150,11 @@ func (s *EST) hasAuthWebhooks() bool { // normalizeAuthConfig applies defaults and validates auth configuration. func (s *EST) normalizeAuthConfig() error { + enable := true if !s.authMethodsConfigured() { - enable := true s.EnableHTTPBasicAuth = &enable } if s.EnableHTTPBasicAuth == nil && (s.BasicAuthUsername != "" || s.BasicAuthPassword != "") { - enable := true s.EnableHTTPBasicAuth = &enable } if boolValue(s.EnableHTTPBasicAuth, false) && s.BasicAuthPassword == "" && !s.hasAuthWebhooks() { diff --git a/est/api/api.go b/est/api/api.go index 27ccffed3..562100f7f 100644 --- a/est/api/api.go +++ b/est/api/api.go @@ -26,6 +26,17 @@ const ( maxPayloadSize = 2 << 20 ) +// Util to extract bearer token from request +func BearerToken(r *http.Request) (string, bool) { + auth := r.Header.Get("Authorization") + const prefix = "Bearer " + // Case insensitive prefix match. See Issue 22736. + if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) { + return "", false + } + return auth[len(prefix):], true +} + // Route configures the EST routes under the provided router. func Route(r api.Router) { r.MethodFunc(http.MethodGet, "/{provisionerName}/cacerts", getCACerts) @@ -120,9 +131,6 @@ func enrollHandler(w http.ResponseWriter, r *http.Request) { ctx := r.Context() ctx, err := authContextFromRequest(ctx, r) if err != nil { - if errors.Is(err, errMissingClientCertificateOrBasicAuth) { - w.Header().Set("WWW-Authenticate", `Basic realm="EST"`) - } failWithStatus(w, r, http.StatusUnauthorized, err) return } @@ -178,7 +186,7 @@ func enrollHandler(w http.ResponseWriter, r *http.Request) { writeResponse(w, r, signed, "application/pkcs7-mime; smime-type=certs-only", http.StatusOK) } -var errMissingClientCertificateOrBasicAuth = errors.New("missing client certificate or basic auth") +var errMissingAuth = errors.New("missing authentication material") // authContextFromRequest extracts auth material from the request into the context. func authContextFromRequest(ctx context.Context, r *http.Request) (context.Context, error) { @@ -191,16 +199,21 @@ func authContextFromRequest(ctx context.Context, r *http.Request) (context.Conte ctx = est.NewClientCertificateChainContext(ctx, r.TLS.PeerCertificates) } - if username, password, ok := r.BasicAuth(); ok { - ctx = est.NewBasicAuthContext(ctx, est.BasicAuth{ - Username: username, - Password: password, - }) + if authHeader := r.Header.Get("Authorization"); authHeader != "" { + ctx = est.NewAuthenticationHeaderContext(ctx, authHeader) + if token, ok := BearerToken(r); ok { + ctx = est.NewBearerTokenContext(ctx, token) + } else if username, password, ok := r.BasicAuth(); ok { + ctx = est.NewBasicAuthContext(ctx, est.BasicAuth{ + Username: username, + Password: password, + }) + } } if _, ok := est.ClientCertificateFromContext(ctx); !ok { - if _, ok := est.BasicAuthFromContext(ctx); !ok { - return ctx, errMissingClientCertificateOrBasicAuth + if _, ok := est.AuthenticationHeaderFromContext(ctx); !ok { + return ctx, errMissingAuth } } return ctx, nil @@ -220,9 +233,15 @@ func authorizeEnrollRequest(ctx context.Context, csr *x509.CertificateRequest) ( req.ClientCertificate = cert req.ClientCertificateChain, _ = est.ClientCertificateChainFromContext(ctx) } - if auth, ok := est.BasicAuthFromContext(ctx); ok { - req.BasicAuthUsername = auth.Username - req.BasicAuthPassword = auth.Password + if authHeader, ok := est.AuthenticationHeaderFromContext(ctx); ok { + req.AuthenticationHeader = authHeader + if auth, ok := est.BasicAuthFromContext(ctx); ok { + req.BasicAuthUsername = auth.Username + req.BasicAuthPassword = auth.Password + } + if token, ok := est.BearerTokenFromContext(ctx); ok { + req.BearerToken = token + } } opts, err := prov.AuthorizeRequest(ctx, req) diff --git a/est/auth_context.go b/est/auth_context.go index 8f1f6be39..687e17d1e 100644 --- a/est/auth_context.go +++ b/est/auth_context.go @@ -2,29 +2,21 @@ package est import "context" -// AuthMethod describes the authentication method used for an EST request. -type AuthMethod string +// AuthenticationHeaderKey is the context key used to store the EST authentication header. +type AuthenticationHeaderKey struct{} -const ( - AuthMethodTLSClientCertificate AuthMethod = "tls-client-certificate" - AuthMethodTLSExternalClientCertificate AuthMethod = "tls-external-client-certificate" - AuthMethodHTTPBasicAuth AuthMethod = "http-basic-auth" -) - -type authMethodKey struct{} - -// NewAuthMethodContext stores the EST authentication method in the context. -func NewAuthMethodContext(ctx context.Context, method AuthMethod) context.Context { - if method == "" { +// NewAuthenticationHeaderContext stores the EST authentication header in the context. +func NewAuthenticationHeaderContext(ctx context.Context, header string) context.Context { + if header == "" { return ctx } - return context.WithValue(ctx, authMethodKey{}, method) + return context.WithValue(ctx, AuthenticationHeaderKey{}, header) } -// AuthMethodFromContext returns the EST authentication method stored in the context. -func AuthMethodFromContext(ctx context.Context) (AuthMethod, bool) { - method, ok := ctx.Value(authMethodKey{}).(AuthMethod) - return method, ok +// AuthenticationHeaderFromContext returns the EST authentication header stored in the context. +func AuthenticationHeaderFromContext(ctx context.Context) (string, bool) { + header, ok := ctx.Value(AuthenticationHeaderKey{}).(string) + return header, ok } // BasicAuth holds the HTTP basic auth credentials for an EST request. @@ -48,3 +40,19 @@ func BasicAuthFromContext(ctx context.Context) (BasicAuth, bool) { auth, ok := ctx.Value(basicAuthKey{}).(BasicAuth) return auth, ok } + +type BearerTokenKey struct{} + +// NewBearerTokenContext stores the HTTP bearer token in the context. +func NewBearerTokenContext(ctx context.Context, token string) context.Context { + if token == "" { + return ctx + } + return context.WithValue(ctx, BearerTokenKey{}, token) +} + +// BearerTokenFromContext returns the HTTP bearer token stored in the context. +func BearerTokenFromContext(ctx context.Context) (string, bool) { + token, ok := ctx.Value(BearerTokenKey{}).(string) + return token, ok +} diff --git a/webhook/options.go b/webhook/options.go index 869237097..a9d6c970b 100644 --- a/webhook/options.go +++ b/webhook/options.go @@ -118,7 +118,20 @@ func WithX5CCertificate(leaf *x509.Certificate) RequestBodyOption { } rb.X5CCertificate.PublicKey = key } + return nil + } +} +func WithAuthenticationHeader(header string) RequestBodyOption { + return func(rb *RequestBody) error { + rb.AuthenticationHeader = header + return nil + } +} + +func WithBearerToken(token string) RequestBodyOption { + return func(rb *RequestBody) error { + rb.BearerToken = token return nil } } diff --git a/webhook/types.go b/webhook/types.go index c60de7099..e45ef54e3 100644 --- a/webhook/types.go +++ b/webhook/types.go @@ -102,4 +102,7 @@ type RequestBody struct { X5CCertificate *X5CCertificate `json:"x5cCertificate,omitempty"` // Set for X5C, AWS, GCP, and Azure provisioners AuthorizationPrincipal string `json:"authorizationPrincipal,omitempty"` + // Set for EST webhook requests + AuthenticationHeader string `json:"authenticationHeader,omitempty"` + BearerToken string `json:"bearerToken,omitempty"` } From 7a25141ed728fccf0850e805414911537d2dc2e2 Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Pin Date: Mon, 30 Mar 2026 16:07:23 +0200 Subject: [PATCH 5/8] add support for tls client certificate forwarded by proxy --- authority/provisioner/est.go | 3 ++- authority/provisioner/est_auth.go | 16 ++++++++++- authority/provisioners.go | 44 ++++++++++++++++--------------- est/api/api.go | 36 +++++++++++++++++++------ est/api/api_test.go | 2 +- est/provisioner.go | 1 + go.mod | 2 +- go.sum | 4 +-- webhook/options.go | 24 +++++++++++++++++ webhook/types.go | 5 ++-- 10 files changed, 100 insertions(+), 37 deletions(-) diff --git a/authority/provisioner/est.go b/authority/provisioner/est.go index 2dc1c4b81..152a31b1c 100644 --- a/authority/provisioner/est.go +++ b/authority/provisioner/est.go @@ -19,7 +19,8 @@ type EST struct { ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` - EnableTLSClientCertificate *bool `json:"enableTLSClientCertificate,omitempty"` + EnableTLSClientCertificate *bool `json:"enableTlsClientCertificate,omitempty"` + ForwardedTLSClientCertHeader string `json:"forwardedTlsClientCertHeader,omitempty"` EnableHTTPBasicAuth *bool `json:"enableHTTPBasicAuth,omitempty"` BasicAuthUsername string `json:"basicAuthUsername,omitempty"` BasicAuthPassword string `json:"basicAuthPassword,omitempty"` diff --git a/authority/provisioner/est_auth.go b/authority/provisioner/est_auth.go index e1699b0aa..29904822a 100644 --- a/authority/provisioner/est_auth.go +++ b/authority/provisioner/est_auth.go @@ -19,6 +19,12 @@ var ( ErrESTAuthDenied = errors.New("est authentication denied") ) +// ClientCertificateConfig holds the EST client certificate authentication configuration. +type ClientCertificateConfig struct { + Enable bool + ForwardedTLSClientCertHeader string +} + // ESTAuthRequest contains authentication material extracted from the request. type ESTAuthRequest struct { CSR *x509.CertificateRequest @@ -32,6 +38,14 @@ type ESTAuthRequest struct { BearerToken string } +func (s *EST) GetClientCertificateConfig() *ClientCertificateConfig { + fmt.Printf("EST provisioner: %#v\n", s) + return &ClientCertificateConfig{ + Enable: boolValue(s.EnableTLSClientCertificate, false), + ForwardedTLSClientCertHeader: s.ForwardedTLSClientCertHeader, + } +} + // AuthorizeRequest validates the request against configured EST auth methods. func (s *EST) AuthorizeRequest(ctx context.Context, req ESTAuthRequest) ([]SignCSROption, error) { if s.hasAuthWebhooks() { @@ -101,7 +115,7 @@ func (s *EST) authorizeWithWebhook(ctx context.Context, req *ESTAuthRequest) ([] ) switch { case req.ClientCertificate != nil: - whreq, err = webhook.NewRequestBody(webhook.WithX509CertificateRequest(req.CSR), webhook.WithX509Certificate(nil, req.ClientCertificate)) + whreq, err = webhook.NewRequestBody(webhook.WithX509CertificateRequest(req.CSR), webhook.WithClientCertificate(req.ClientCertificate)) if err != nil { return nil, fmt.Errorf("failed creating webhook request: %w", err) } diff --git a/authority/provisioners.go b/authority/provisioners.go index 4bc42cd00..59cc65752 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -1008,19 +1008,20 @@ func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface, enableTLSClientCertificate := cfg.EnableTlsClientCertificate enableHTTPBasicAuth := cfg.EnableHttpBasicAuth return &provisioner.EST{ - ID: p.Id, - Type: p.Type.String(), - Name: p.Name, - EnableTLSClientCertificate: &enableTLSClientCertificate, - EnableHTTPBasicAuth: &enableHTTPBasicAuth, - BasicAuthUsername: cfg.BasicAuthUsername, - BasicAuthPassword: cfg.BasicAuthPassword, - ClientCertificateRoots: provisionerPEMToCertificates(cfg.ClientCertificateRoots), - ForceCN: cfg.ForceCn, - IncludeRoot: cfg.IncludeRoot, - MinimumPublicKeyLength: int(cfg.MinimumPublicKeyLength), - Claims: claims, - Options: options, + ID: p.Id, + Type: p.Type.String(), + Name: p.Name, + EnableTLSClientCertificate: &enableTLSClientCertificate, + EnableHTTPBasicAuth: &enableHTTPBasicAuth, + ForwardedTLSClientCertHeader: cfg.ForwardedTlsClientCertHeader, + BasicAuthUsername: cfg.BasicAuthUsername, + BasicAuthPassword: cfg.BasicAuthPassword, + ClientCertificateRoots: provisionerPEMToCertificates(cfg.ClientCertificateRoots), + ForceCN: cfg.ForceCn, + IncludeRoot: cfg.IncludeRoot, + MinimumPublicKeyLength: int(cfg.MinimumPublicKeyLength), + Claims: claims, + Options: options, }, nil case *linkedca.ProvisionerDetails_Nebula: var roots []byte @@ -1309,14 +1310,15 @@ func ProvisionerToLinkedca(p provisioner.Interface) (*linkedca.Provisioner, erro Details: &linkedca.ProvisionerDetails{ Data: &linkedca.ProvisionerDetails_EST{ EST: &linkedca.ESTProvisioner{ - ForceCn: p.ForceCN, - EnableTlsClientCertificate: p.EnableTLSClientCertificate != nil && *p.EnableTLSClientCertificate, - EnableHttpBasicAuth: p.EnableHTTPBasicAuth != nil && *p.EnableHTTPBasicAuth, - MinimumPublicKeyLength: cast.Int32(p.MinimumPublicKeyLength), - IncludeRoot: p.IncludeRoot, - BasicAuthUsername: p.BasicAuthUsername, - BasicAuthPassword: p.BasicAuthPassword, - ClientCertificateRoots: provisionerPEMToLinkedca(p.ClientCertificateRoots), + ForceCn: p.ForceCN, + EnableTlsClientCertificate: p.EnableTLSClientCertificate != nil && *p.EnableTLSClientCertificate, + EnableHttpBasicAuth: p.EnableHTTPBasicAuth != nil && *p.EnableHTTPBasicAuth, + MinimumPublicKeyLength: cast.Int32(p.MinimumPublicKeyLength), + IncludeRoot: p.IncludeRoot, + BasicAuthUsername: p.BasicAuthUsername, + BasicAuthPassword: p.BasicAuthPassword, + ClientCertificateRoots: provisionerPEMToLinkedca(p.ClientCertificateRoots), + ForwardedTlsClientCertHeader: p.ForwardedTLSClientCertHeader, }, }, }, diff --git a/est/api/api.go b/est/api/api.go index 562100f7f..f36ed8d60 100644 --- a/est/api/api.go +++ b/est/api/api.go @@ -100,7 +100,7 @@ func getCACertsHandler(w http.ResponseWriter, r *http.Request) { return } - writeResponse(w, r, data, "application/pkcs7-mime; smime-type=certs-only", http.StatusOK) + writeResponse(w, data, "application/pkcs7-mime; smime-type=certs-only", http.StatusOK) } func getCSRAttrs(w http.ResponseWriter, r *http.Request) { @@ -120,7 +120,7 @@ func getCSRAttrsHandler(w http.ResponseWriter, r *http.Request) { attrs = []byte{} } // Minimal implementation: allow provisioner to return nil/empty for "no attributes". - writeResponse(w, r, attrs, "application/csrattrs", http.StatusOK) + writeResponse(w, attrs, "application/csrattrs", http.StatusOK) } func enroll(w http.ResponseWriter, r *http.Request) { @@ -183,7 +183,7 @@ func enrollHandler(w http.ResponseWriter, r *http.Request) { return } - writeResponse(w, r, signed, "application/pkcs7-mime; smime-type=certs-only", http.StatusOK) + writeResponse(w, signed, "application/pkcs7-mime; smime-type=certs-only", http.StatusOK) } var errMissingAuth = errors.New("missing authentication material") @@ -193,10 +193,30 @@ func authContextFromRequest(ctx context.Context, r *http.Request) (context.Conte if r.TLS == nil { return ctx, errors.New("missing TLS connection") } - - if len(r.TLS.PeerCertificates) > 0 { - ctx = est.NewClientCertificateContext(ctx, r.TLS.PeerCertificates[0]) - ctx = est.NewClientCertificateChainContext(ctx, r.TLS.PeerCertificates) + prov := est.ProvisionerFromContext(ctx) + cfg := prov.GetClientCertificateConfig() + + if cfg.Enable { + if cfg.ForwardedTLSClientCertHeader != "" { + if forwardedtlsClientCert := r.Header.Get(cfg.ForwardedTLSClientCertHeader); forwardedtlsClientCert != "" { + certPEM, err := base64.StdEncoding.DecodeString(forwardedtlsClientCert) + if err != nil { + // fmt.Printf("failed to decode client cert in forwarded header: %w", err) + } + certs, err := x509.ParseCertificates(certPEM) + if err != nil { + // return ctx, fmt.Errorf("failed to parse certificate from header: %w", err) + } + if len(certs) == 0 { + // return ctx, errors.New("no certificates found in header") + } + ctx = est.NewClientCertificateContext(ctx, certs[0]) + ctx = est.NewClientCertificateChainContext(ctx, certs) + } + } else if len(r.TLS.PeerCertificates) > 0 { + ctx = est.NewClientCertificateContext(ctx, r.TLS.PeerCertificates[0]) + ctx = est.NewClientCertificateChainContext(ctx, r.TLS.PeerCertificates) + } } if authHeader := r.Header.Get("Authorization"); authHeader != "" { @@ -301,7 +321,7 @@ func requireContentType(r *http.Request, want string) error { return nil } -func writeResponse(w http.ResponseWriter, r *http.Request, data []byte, contentType string, status int) { +func writeResponse(w http.ResponseWriter, data []byte, contentType string, status int) { w.Header().Set("Content-Type", contentType) w.Header().Set("Content-Transfer-Encoding", "base64") w.WriteHeader(status) diff --git a/est/api/api_test.go b/est/api/api_test.go index 22cbc5f41..ecaebe564 100644 --- a/est/api/api_test.go +++ b/est/api/api_test.go @@ -57,7 +57,7 @@ func Test_writeResponse(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - writeResponse(tt.args.w, tt.args.r, tt.args.data, tt.args.contentType, tt.args.status) + writeResponse(tt.args.w, tt.args.data, tt.args.contentType, tt.args.status) resp := tt.args.w.(*httptest.ResponseRecorder) assert.Equal(t, tt.args.status, resp.Code) diff --git a/est/provisioner.go b/est/provisioner.go index 512cbf3f8..69bce1713 100644 --- a/est/provisioner.go +++ b/est/provisioner.go @@ -13,6 +13,7 @@ import ( type Provisioner interface { provisioner.Interface GetOptions() *provisioner.Options + GetClientCertificateConfig() *provisioner.ClientCertificateConfig ShouldIncludeRootInChain() bool ShouldIncludeIntermediateInChain() bool GetSigner() (*x509.Certificate, crypto.Signer) diff --git a/go.mod b/go.mod index 8455bff66..30cc34d4e 100644 --- a/go.mod +++ b/go.mod @@ -176,4 +176,4 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -replace github.com/smallstep/linkedca => github.com/jbpin/linkedca v0.0.0-20251224103807-5e7deb3d4d49 +replace github.com/smallstep/linkedca => github.com/jbpin/linkedca v0.0.0-20260119192234-bf5917d1c897 diff --git a/go.sum b/go.sum index 3aafd6eb4..6ef544f7d 100644 --- a/go.sum +++ b/go.sum @@ -256,8 +256,8 @@ github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= -github.com/jbpin/linkedca v0.0.0-20251224103807-5e7deb3d4d49 h1:rFx5O4hmVJ3fQ90n1PP8fjBWIQoSuBSqbdts6UfnaVk= -github.com/jbpin/linkedca v0.0.0-20251224103807-5e7deb3d4d49/go.mod h1:bUm0HgkgtOjbnwShtNKu3XPuQ6AA/o680sxTfy9vdYk= +github.com/jbpin/linkedca v0.0.0-20260119192234-bf5917d1c897 h1:DTvDVMjQ4IGEn+l7g1MqCZpqKFDU5ZTlRzTDTIMFUZs= +github.com/jbpin/linkedca v0.0.0-20260119192234-bf5917d1c897/go.mod h1:bUm0HgkgtOjbnwShtNKu3XPuQ6AA/o680sxTfy9vdYk= github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= diff --git a/webhook/options.go b/webhook/options.go index a9d6c970b..745e06592 100644 --- a/webhook/options.go +++ b/webhook/options.go @@ -122,6 +122,30 @@ func WithX5CCertificate(leaf *x509.Certificate) RequestBodyOption { } } +func WithClientCertificate(cert *x509.Certificate) RequestBodyOption { + return func(rb *RequestBody) error { + certificate, err := x509util.NewCertificateFromX509(cert) + if err != nil { + return err + } + rb.ClientCertificate = &X509Certificate{ + Raw: cert.Raw, + PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(), + NotBefore: cert.NotBefore, + NotAfter: cert.NotAfter, + Certificate: certificate, + } + if cert.PublicKey != nil { + key, err := x509.MarshalPKIXPublicKey(cert.PublicKey) + if err != nil { + return err + } + rb.ClientCertificate.PublicKey = key + } + return nil + } +} + func WithAuthenticationHeader(header string) RequestBodyOption { return func(rb *RequestBody) error { rb.AuthenticationHeader = header diff --git a/webhook/types.go b/webhook/types.go index e45ef54e3..a78356c11 100644 --- a/webhook/types.go +++ b/webhook/types.go @@ -103,6 +103,7 @@ type RequestBody struct { // Set for X5C, AWS, GCP, and Azure provisioners AuthorizationPrincipal string `json:"authorizationPrincipal,omitempty"` // Set for EST webhook requests - AuthenticationHeader string `json:"authenticationHeader,omitempty"` - BearerToken string `json:"bearerToken,omitempty"` + AuthenticationHeader string `json:"authenticationHeader,omitempty"` + BearerToken string `json:"bearerToken,omitempty"` + ClientCertificate *X509Certificate `json:"clientCertificate,omitempty"` } From 074ee21b5266f77a697c0e3dd3a48633fb961af0 Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Pin Date: Mon, 30 Mar 2026 16:07:23 +0200 Subject: [PATCH 6/8] remove log --- authority/provisioner/est_auth.go | 1 - 1 file changed, 1 deletion(-) diff --git a/authority/provisioner/est_auth.go b/authority/provisioner/est_auth.go index 29904822a..9543e7813 100644 --- a/authority/provisioner/est_auth.go +++ b/authority/provisioner/est_auth.go @@ -39,7 +39,6 @@ type ESTAuthRequest struct { } func (s *EST) GetClientCertificateConfig() *ClientCertificateConfig { - fmt.Printf("EST provisioner: %#v\n", s) return &ClientCertificateConfig{ Enable: boolValue(s.EnableTLSClientCertificate, false), ForwardedTLSClientCertHeader: s.ForwardedTLSClientCertHeader, From b27292dd1fac944eb15469970f9734644f6dde92 Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Pin Date: Thu, 16 Apr 2026 11:23:08 +0200 Subject: [PATCH 7/8] fix: EST accept proxy certificate as valid authz --- est/api/api.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/est/api/api.go b/est/api/api.go index f36ed8d60..d2589f723 100644 --- a/est/api/api.go +++ b/est/api/api.go @@ -198,17 +198,20 @@ func authContextFromRequest(ctx context.Context, r *http.Request) (context.Conte if cfg.Enable { if cfg.ForwardedTLSClientCertHeader != "" { + // When a forwarded header is configured, only use it — never + // fall back to r.TLS.PeerCertificates, which would be the + // proxy's own certificate, not the actual client's. if forwardedtlsClientCert := r.Header.Get(cfg.ForwardedTLSClientCertHeader); forwardedtlsClientCert != "" { - certPEM, err := base64.StdEncoding.DecodeString(forwardedtlsClientCert) + certDER, err := base64.StdEncoding.DecodeString(forwardedtlsClientCert) if err != nil { - // fmt.Printf("failed to decode client cert in forwarded header: %w", err) + return ctx, fmt.Errorf("failed to decode client certificate from forwarded header: %w", err) } - certs, err := x509.ParseCertificates(certPEM) + certs, err := x509.ParseCertificates(certDER) if err != nil { - // return ctx, fmt.Errorf("failed to parse certificate from header: %w", err) + return ctx, fmt.Errorf("failed to parse client certificate from forwarded header: %w", err) } if len(certs) == 0 { - // return ctx, errors.New("no certificates found in header") + return ctx, errors.New("no certificates found in forwarded header") } ctx = est.NewClientCertificateContext(ctx, certs[0]) ctx = est.NewClientCertificateChainContext(ctx, certs) From 28048705cf5efa3f8b37d92148596b417f60b17b Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Pin Date: Thu, 16 Apr 2026 12:31:55 +0200 Subject: [PATCH 8/8] fix: build --- authority/authority.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/authority/authority.go b/authority/authority.go index d7e52bf21..aa2240ea2 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -1071,6 +1071,17 @@ func (a *Authority) requiresEST() bool { return false } +// HasACMEProvisioner iterates over the configured provisioners +// and determines if at least one of them is an ACME provisioner. +func (a *Authority) HasACMEProvisioner() bool { + for _, p := range a.config.AuthorityConfig.Provisioners { + if p.GetType() == provisioner.TypeACME { + return true + } + } + return false +} + // getESTProvisionerNames returns the names of the EST provisioners // that are currently available in the CA. func (a *Authority) getESTProvisionerNames() (names []string) {