Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions cli/cli.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
package cli

import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"net/url"
"os"
"strings"

"github.com/square/certigo/cli/terminal"
"github.com/square/certigo/lib"
"github.com/square/certigo/proxy"
"github.com/square/certigo/starttls"
"gopkg.in/alecthomas/kingpin.v2"
)
Expand Down Expand Up @@ -49,6 +52,15 @@ var (
verifyName = verify.Flag("name", "Server name to verify certificate against.").Short('n').Required().String()
verifyCaPath = verify.Flag("ca", "Path to CA bundle (system default if unspecified).").ExistingFile()
verifyJSON = verify.Flag("json", "Write output as machine-readable JSON format.").Short('j').Bool()

tlsProxy = app.Command("proxy", "Proxy mTLS to a server and print its certificate(s).")
proxyPort = tlsProxy.Arg("port", "Port to listen on.").Required().Int()
proxyTo = tlsProxy.Arg("target", "URL to proxy to.").Required().String()
proxyName = tlsProxy.Flag("name", "Override the server name used for Server Name Inidication (SNI).").String()
proxyCaPath = tlsProxy.Flag("ca", "Path to CA bundle (system default if unspecified).").ExistingFile()
proxyCert = tlsProxy.Flag("cert", "Client certificate chain for connecting to server (PEM).").ExistingFile()
proxyKey = tlsProxy.Flag("key", "Private key for client certificate, if not in same file (PEM).").ExistingFile()
proxyVerifyExpectedName = tlsProxy.Flag("expected-name", "Name expected in the server TLS certificate. Defaults to name from SNI or, if SNI not overidden, the hostname to connect to.").String()
)

const (
Expand Down Expand Up @@ -237,6 +249,45 @@ func Run(args []string, tty terminal.Terminal) int {
if verifyResult.Error != "" {
return 1
}
case tlsProxy.FullCommand():
printer := func(verification *lib.SimpleVerification, conn *tls.ConnectionState) {
fmt.Fprintln(stdout, lib.EncodeTLSInfoToText(conn, nil))
for i, cert := range conn.PeerCertificates {
fmt.Fprintf(stdout, "** CERTIFICATE %d **\n", i+1)
fmt.Fprintf(stdout, "%s\n\n", lib.EncodeX509ToText(cert, terminalWidth, *verbose))
}
lib.PrintVerifyResult(stdout, *verification)
}

proxyOptions := &proxy.Options{
Inspect: printer,
Port: *proxyPort,
Target: *proxyTo,
ServerName: *proxyName,
CAPath: *proxyCaPath,
CertPath: *proxyCert,
KeyPath: *proxyKey,
}

switch {
case *proxyVerifyExpectedName != "":
// Use the explicitly provided name
proxyOptions.ExpectedName = *proxyVerifyExpectedName
case *proxyName != "":
// Use the provided SNI
proxyOptions.ExpectedName = *proxyName
default:
// Use the hostname/IP from the target URL
url, err := url.Parse(*proxyTo)
if err != nil {
return printErr("error parsing URL %q: %v\n", *proxyTo, err)
}
proxyOptions.ExpectedName = strings.Split(url.Host, ":")[0]
}

if err := proxy.ListenAndServe(proxyOptions); err != nil {
return printErr("%s\n", err)
}
}
return 0
}
Expand Down
15 changes: 9 additions & 6 deletions lib/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func caBundle(caPath string) (*x509.CertPool, error) {
return bundle, nil
}

func VerifyChain(certs []*x509.Certificate, ocspStaple []byte, expectedName, caPath string) SimpleVerification {
func VerifyChainWithPool(certs []*x509.Certificate, ocspStaple []byte, expectedName string, roots *x509.CertPool) SimpleVerification {
result := SimpleVerification{
Chains: [][]simpleVerifyCert{},
OCSPWasStapled: ocspStaple != nil,
Expand All @@ -127,11 +127,6 @@ func VerifyChain(certs []*x509.Certificate, ocspStaple []byte, expectedName, caP
intermediates.AddCert(certs[i])
}

roots, err := caBundle(caPath)
if err != nil {
result.Error = fmt.Sprintf("%s", err)
return result
}
// expectedName could be a hostname or could be a SPIFFE ID (spiffe://...)
// x509 package doesn't support verifying SPIFFE IDs. When we're expecting a SPIFFE ID, we tell
// Certificate.Verify below to skip name matching, and then we perform our own matching later
Expand Down Expand Up @@ -188,6 +183,14 @@ func VerifyChain(certs []*x509.Certificate, ocspStaple []byte, expectedName, caP
return result
}

func VerifyChain(certs []*x509.Certificate, ocspStaple []byte, expectedName, caPath string) SimpleVerification {
roots, err := caBundle(caPath)
if err != nil {
return SimpleVerification{Error: err.Error()}
}
return VerifyChainWithPool(certs, ocspStaple, expectedName, roots)
}

func fmtCert(cert simpleVerifyCert) string {
name := cert.Name
if cert.IsSelfSigned {
Expand Down
103 changes: 103 additions & 0 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package proxy

import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net/http"
"net/http/httputil"
"net/url"
"os"

"github.com/square/certigo/lib"
)

type Options struct {
// Called after each TLS connection is verified (in a separate goroutine). Required.
Inspect func(verification *lib.SimpleVerification, conn *tls.ConnectionState)
// Port to listen on. Required.
Port int
// URL to proxy to (should be https://...)
Target string
// SNI override. Optional.
ServerName string
// Roots to verify server certificate against. If empty, the system cert store is used.
CAPath string
// Client certificate. Optional.
CertPath string
// Client key. Optional.
KeyPath string
// Expected name in the server certificate.
ExpectedName string
}

func ListenAndServe(opts *Options) error {
// Load the TLS roots and client cert.
var roots *x509.CertPool
if opts.CAPath != "" {
rootPEM, err := os.ReadFile(opts.CAPath)
if err != nil {
return err
}
roots = x509.NewCertPool()
roots.AppendCertsFromPEM(rootPEM)
}

var clientCert []tls.Certificate
if opts.CertPath != "" {
keyPath := opts.KeyPath
if keyPath == "" {
keyPath = opts.CertPath
}
cert, err := tls.LoadX509KeyPair(opts.CertPath, opts.KeyPath)
if err != nil {
return err
}
clientCert = append(clientCert, cert)
}

// Start a goroutine to print verification results.
type result struct {
verification *lib.SimpleVerification
state *tls.ConnectionState
}
results := make(chan result)
go func() {
for result := range results {
opts.Inspect(result.verification, result.state)
}
}()

verify := func(conn tls.ConnectionState) error {
verification := lib.VerifyChainWithPool(conn.PeerCertificates, conn.OCSPResponse, opts.ExpectedName, roots)
results <- result{&verification, &conn}
if verification.Error != "" {
return errors.New(verification.Error)
}
return nil
}

// Create a reverse proxy to the target.
url, err := url.Parse(opts.Target)
if err != nil {
return err
}
proxy := httputil.NewSingleHostReverseProxy(url)
director := proxy.Director
proxy.Director = func(r *http.Request) {
// NewSingleHostReverseProxy doesn't overwrite Host. Do so now and
// then forward to the original director...
r.Host = ""
director(r)
}
proxy.Transport = &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: clientCert,
ServerName: opts.ServerName,
InsecureSkipVerify: true,
VerifyConnection: verify,
},
}
return http.ListenAndServe(fmt.Sprintf(":%d", opts.Port), proxy)
}