diff --git a/cli/cli.go b/cli/cli.go index a8b79ef0..135cf72c 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -1,15 +1,18 @@ package cli import ( + "crypto/tls" "crypto/x509" "encoding/json" "encoding/pem" "fmt" + "net/url" "os" "strings" "github.com/square/certigo/cli/terminal" "github.com/square/certigo/lib" + "github.com/square/certigo/proxy" "github.com/square/certigo/starttls" "gopkg.in/alecthomas/kingpin.v2" ) @@ -49,6 +52,15 @@ var ( verifyName = verify.Flag("name", "Server name to verify certificate against.").Short('n').Required().String() verifyCaPath = verify.Flag("ca", "Path to CA bundle (system default if unspecified).").ExistingFile() verifyJSON = verify.Flag("json", "Write output as machine-readable JSON format.").Short('j').Bool() + + tlsProxy = app.Command("proxy", "Proxy mTLS to a server and print its certificate(s).") + proxyPort = tlsProxy.Arg("port", "Port to listen on.").Required().Int() + proxyTo = tlsProxy.Arg("target", "URL to proxy to.").Required().String() + proxyName = tlsProxy.Flag("name", "Override the server name used for Server Name Inidication (SNI).").String() + proxyCaPath = tlsProxy.Flag("ca", "Path to CA bundle (system default if unspecified).").ExistingFile() + proxyCert = tlsProxy.Flag("cert", "Client certificate chain for connecting to server (PEM).").ExistingFile() + proxyKey = tlsProxy.Flag("key", "Private key for client certificate, if not in same file (PEM).").ExistingFile() + proxyVerifyExpectedName = tlsProxy.Flag("expected-name", "Name expected in the server TLS certificate. Defaults to name from SNI or, if SNI not overidden, the hostname to connect to.").String() ) const ( @@ -237,6 +249,45 @@ func Run(args []string, tty terminal.Terminal) int { if verifyResult.Error != "" { return 1 } + case tlsProxy.FullCommand(): + printer := func(verification *lib.SimpleVerification, conn *tls.ConnectionState) { + fmt.Fprintln(stdout, lib.EncodeTLSInfoToText(conn, nil)) + for i, cert := range conn.PeerCertificates { + fmt.Fprintf(stdout, "** CERTIFICATE %d **\n", i+1) + fmt.Fprintf(stdout, "%s\n\n", lib.EncodeX509ToText(cert, terminalWidth, *verbose)) + } + lib.PrintVerifyResult(stdout, *verification) + } + + proxyOptions := &proxy.Options{ + Inspect: printer, + Port: *proxyPort, + Target: *proxyTo, + ServerName: *proxyName, + CAPath: *proxyCaPath, + CertPath: *proxyCert, + KeyPath: *proxyKey, + } + + switch { + case *proxyVerifyExpectedName != "": + // Use the explicitly provided name + proxyOptions.ExpectedName = *proxyVerifyExpectedName + case *proxyName != "": + // Use the provided SNI + proxyOptions.ExpectedName = *proxyName + default: + // Use the hostname/IP from the target URL + url, err := url.Parse(*proxyTo) + if err != nil { + return printErr("error parsing URL %q: %v\n", *proxyTo, err) + } + proxyOptions.ExpectedName = strings.Split(url.Host, ":")[0] + } + + if err := proxy.ListenAndServe(proxyOptions); err != nil { + return printErr("%s\n", err) + } } return 0 } diff --git a/lib/verify.go b/lib/verify.go index 549f7284..15d413a5 100644 --- a/lib/verify.go +++ b/lib/verify.go @@ -111,7 +111,7 @@ func caBundle(caPath string) (*x509.CertPool, error) { return bundle, nil } -func VerifyChain(certs []*x509.Certificate, ocspStaple []byte, expectedName, caPath string) SimpleVerification { +func VerifyChainWithPool(certs []*x509.Certificate, ocspStaple []byte, expectedName string, roots *x509.CertPool) SimpleVerification { result := SimpleVerification{ Chains: [][]simpleVerifyCert{}, OCSPWasStapled: ocspStaple != nil, @@ -127,11 +127,6 @@ func VerifyChain(certs []*x509.Certificate, ocspStaple []byte, expectedName, caP intermediates.AddCert(certs[i]) } - roots, err := caBundle(caPath) - if err != nil { - result.Error = fmt.Sprintf("%s", err) - return result - } // expectedName could be a hostname or could be a SPIFFE ID (spiffe://...) // x509 package doesn't support verifying SPIFFE IDs. When we're expecting a SPIFFE ID, we tell // Certificate.Verify below to skip name matching, and then we perform our own matching later @@ -188,6 +183,14 @@ func VerifyChain(certs []*x509.Certificate, ocspStaple []byte, expectedName, caP return result } +func VerifyChain(certs []*x509.Certificate, ocspStaple []byte, expectedName, caPath string) SimpleVerification { + roots, err := caBundle(caPath) + if err != nil { + return SimpleVerification{Error: err.Error()} + } + return VerifyChainWithPool(certs, ocspStaple, expectedName, roots) +} + func fmtCert(cert simpleVerifyCert) string { name := cert.Name if cert.IsSelfSigned { diff --git a/proxy/proxy.go b/proxy/proxy.go new file mode 100644 index 00000000..6d0091e9 --- /dev/null +++ b/proxy/proxy.go @@ -0,0 +1,103 @@ +package proxy + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "net/http" + "net/http/httputil" + "net/url" + "os" + + "github.com/square/certigo/lib" +) + +type Options struct { + // Called after each TLS connection is verified (in a separate goroutine). Required. + Inspect func(verification *lib.SimpleVerification, conn *tls.ConnectionState) + // Port to listen on. Required. + Port int + // URL to proxy to (should be https://...) + Target string + // SNI override. Optional. + ServerName string + // Roots to verify server certificate against. If empty, the system cert store is used. + CAPath string + // Client certificate. Optional. + CertPath string + // Client key. Optional. + KeyPath string + // Expected name in the server certificate. + ExpectedName string +} + +func ListenAndServe(opts *Options) error { + // Load the TLS roots and client cert. + var roots *x509.CertPool + if opts.CAPath != "" { + rootPEM, err := os.ReadFile(opts.CAPath) + if err != nil { + return err + } + roots = x509.NewCertPool() + roots.AppendCertsFromPEM(rootPEM) + } + + var clientCert []tls.Certificate + if opts.CertPath != "" { + keyPath := opts.KeyPath + if keyPath == "" { + keyPath = opts.CertPath + } + cert, err := tls.LoadX509KeyPair(opts.CertPath, opts.KeyPath) + if err != nil { + return err + } + clientCert = append(clientCert, cert) + } + + // Start a goroutine to print verification results. + type result struct { + verification *lib.SimpleVerification + state *tls.ConnectionState + } + results := make(chan result) + go func() { + for result := range results { + opts.Inspect(result.verification, result.state) + } + }() + + verify := func(conn tls.ConnectionState) error { + verification := lib.VerifyChainWithPool(conn.PeerCertificates, conn.OCSPResponse, opts.ExpectedName, roots) + results <- result{&verification, &conn} + if verification.Error != "" { + return errors.New(verification.Error) + } + return nil + } + + // Create a reverse proxy to the target. + url, err := url.Parse(opts.Target) + if err != nil { + return err + } + proxy := httputil.NewSingleHostReverseProxy(url) + director := proxy.Director + proxy.Director = func(r *http.Request) { + // NewSingleHostReverseProxy doesn't overwrite Host. Do so now and + // then forward to the original director... + r.Host = "" + director(r) + } + proxy.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{ + Certificates: clientCert, + ServerName: opts.ServerName, + InsecureSkipVerify: true, + VerifyConnection: verify, + }, + } + return http.ListenAndServe(fmt.Sprintf(":%d", opts.Port), proxy) +}