diff --git a/README.md b/README.md index c6979c1..080a037 100644 --- a/README.md +++ b/README.md @@ -256,7 +256,7 @@ The first setting in `connections` is the default connection. | host | ssh host. Required. | | port | ssh port. Required. | | user | ssh user. Optional. | -| privateKey | private key path. Required. | +| privateKey | private key path. Required. Also supports ssh-agent via `agent://`, `agent://` or `agent://`. | | passPhrase | passPhrase. Optional. | #### DSN (Data Source Name) diff --git a/internal/database/config.go b/internal/database/config.go index 7e7084e..5e27c7e 100644 --- a/internal/database/config.go +++ b/internal/database/config.go @@ -2,11 +2,8 @@ package database import ( "errors" - "fmt" - "os" "github.com/sqls-server/sqls/dialect" - "golang.org/x/crypto/ssh" ) type Proto string @@ -163,34 +160,3 @@ func (s *SSHConfig) Validate() error { } return nil } - -func (s *SSHConfig) Endpoint() string { - return fmt.Sprintf("%s:%d", s.Host, s.Port) -} - -func (s *SSHConfig) ClientConfig() (*ssh.ClientConfig, error) { - buffer, err := os.ReadFile(s.PrivateKey) - if err != nil { - return nil, fmt.Errorf("cannot read SSH private key file, PrivateKey=%s, %w", s.PrivateKey, err) - } - - var key ssh.Signer - if s.PassPhrase != "" { - key, err = ssh.ParsePrivateKeyWithPassphrase(buffer, []byte(s.PassPhrase)) - if err != nil { - return nil, fmt.Errorf("cannot parse SSH private key file with passphrase, PrivateKey=%s, %w", s.PrivateKey, err) - } - } else { - key, err = ssh.ParsePrivateKey(buffer) - if err != nil { - return nil, fmt.Errorf("cannot parse SSH private key file, PrivateKey=%s, %w", s.PrivateKey, err) - } - } - - sshConfig := &ssh.ClientConfig{ - User: s.User, - Auth: []ssh.AuthMethod{ssh.PublicKeys(key)}, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - } - return sshConfig, nil -} diff --git a/internal/database/mssql.go b/internal/database/mssql.go index 65654ad..6cf2785 100644 --- a/internal/database/mssql.go +++ b/internal/database/mssql.go @@ -6,14 +6,12 @@ import ( "fmt" "log" "net/url" - "os" "strconv" _ "github.com/denisenkom/go-mssqldb" "github.com/jfcote87/sshdb" "github.com/jfcote87/sshdb/mssql" "github.com/sqls-server/sqls/dialect" - "golang.org/x/crypto/ssh" ) func init() { @@ -31,22 +29,9 @@ func mssqlOpen(dbConnCfg *DBConfig) (*DBConnection, error) { } if dbConnCfg.SSHCfg != nil { - key, err := os.ReadFile(dbConnCfg.SSHCfg.PrivateKey) + cfg, err := dbConnCfg.SSHCfg.ClientConfig() if err != nil { - return nil, fmt.Errorf("unable to open private key") - } - - signer, err := ssh.ParsePrivateKeyWithPassphrase(key, []byte(dbConnCfg.SSHCfg.PassPhrase)) - if err != nil { - return nil, fmt.Errorf("unable to decrypt private key") - } - - cfg := &ssh.ClientConfig{ - User: dbConnCfg.SSHCfg.User, - Auth: []ssh.AuthMethod{ - ssh.PublicKeys(signer), - }, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), + return nil, err } remoteAddr := fmt.Sprintf("%s:%d", dbConnCfg.SSHCfg.Host, dbConnCfg.SSHCfg.Port) diff --git a/internal/database/ssh.go b/internal/database/ssh.go new file mode 100644 index 0000000..43274de --- /dev/null +++ b/internal/database/ssh.go @@ -0,0 +1,170 @@ +package database + +import ( + "errors" + "fmt" + "io" + "net" + "os" + "strings" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +func (s *SSHConfig) Endpoint() string { + return fmt.Sprintf("%s:%d", s.Host, s.Port) +} + +func (s *SSHConfig) ClientConfig() (*ssh.ClientConfig, error) { + var ( + auth ssh.AuthMethod + err error + ) + + if strings.HasPrefix(s.PrivateKey, "agent://") { + auth, err = s.sshAgentAuthMethod(strings.TrimPrefix(s.PrivateKey, "agent://")) + if err != nil { + return nil, err + } + } else { + buffer, err := os.ReadFile(s.PrivateKey) + if err != nil { + return nil, fmt.Errorf("cannot read SSH private key file, PrivateKey=%s, %w", s.PrivateKey, err) + } + + var key ssh.Signer + if s.PassPhrase != "" { + key, err = ssh.ParsePrivateKeyWithPassphrase(buffer, []byte(s.PassPhrase)) + if err != nil { + return nil, fmt.Errorf("cannot parse SSH private key file with passphrase, PrivateKey=%s, %w", s.PrivateKey, err) + } + } else { + key, err = ssh.ParsePrivateKey(buffer) + if err != nil { + return nil, fmt.Errorf("cannot parse SSH private key file, PrivateKey=%s, %w", s.PrivateKey, err) + } + } + auth = ssh.PublicKeys(key) + } + + sshConfig := &ssh.ClientConfig{ + User: s.User, + Auth: []ssh.AuthMethod{auth}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + return sshConfig, nil +} + +func (s *SSHConfig) sshAgentAuthMethod(selector string) (ssh.AuthMethod, error) { + if sock := os.Getenv("SSH_AUTH_SOCK"); sock == "" { + return nil, errors.New("SSH_AUTH_SOCK is not set (ssh-agent is not available)") + } + + selector = strings.TrimSpace(selector) + selector = strings.TrimPrefix(selector, "/") + + return ssh.PublicKeysCallback(func() ([]ssh.Signer, error) { + sock := os.Getenv("SSH_AUTH_SOCK") + if sock == "" { + return nil, errors.New("SSH_AUTH_SOCK is not set (ssh-agent is not available)") + } + conn, err := net.Dial("unix", sock) + if err != nil { + return nil, fmt.Errorf("cannot connect to SSH agent, SSH_AUTH_SOCK=%s, %w", sock, err) + } + ag := agent.NewClient(conn) + keys, err := ag.List() + _ = conn.Close() + if err != nil { + return nil, fmt.Errorf("cannot list SSH agent keys, %w", err) + } + if len(keys) == 0 { + return nil, errors.New("no keys available in SSH agent") + } + + matchesHash := func(sel string, pk ssh.PublicKey) bool { + if sel == "" { + return true + } + sha := ssh.FingerprintSHA256(pk) + shaNoPrefix := strings.TrimPrefix(sha, "SHA256:") + md5 := ssh.FingerprintLegacyMD5(pk) + return sel == sha || sel == shaNoPrefix || sel == md5 + } + + matchesName := func(sel string, comment string) bool { + if sel == "" { + return true + } + if comment == sel { + return true + } + return comment != "" && strings.Contains(comment, sel) + } + + matched := make([]ssh.Signer, 0, len(keys)) + for _, k := range keys { + pk, err := ssh.ParsePublicKey(k.Blob) + if err != nil { + continue + } + comment := k.Comment + + if selector == "" || matchesName(selector, comment) || matchesHash(selector, pk) { + matched = append(matched, &sshAgentSigner{pub: pk}) + } + } + + if len(matched) == 0 { + return nil, fmt.Errorf("no matching SSH agent key for selector %q", selector) + } + return matched, nil + }), nil +} + +type sshAgentSigner struct { + pub ssh.PublicKey +} + +var _ ssh.Signer = (*sshAgentSigner)(nil) +var _ ssh.AlgorithmSigner = (*sshAgentSigner)(nil) + +func (s *sshAgentSigner) PublicKey() ssh.PublicKey { return s.pub } + +func (s *sshAgentSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) { + c, conn, err := dialSSHAgent() + if err != nil { + return nil, err + } + defer conn.Close() + return c.Sign(s.pub, data) +} + +func (s *sshAgentSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) { + c, conn, err := dialSSHAgent() + if err != nil { + return nil, err + } + defer conn.Close() + switch algorithm { + case ssh.KeyAlgoRSASHA256: + return c.SignWithFlags(s.pub, data, agent.SignatureFlagRsaSha256) + case ssh.KeyAlgoRSASHA512: + return c.SignWithFlags(s.pub, data, agent.SignatureFlagRsaSha512) + default: + return c.Sign(s.pub, data) + } +} + +func dialSSHAgent() (agent.ExtendedAgent, net.Conn, error) { + sock := os.Getenv("SSH_AUTH_SOCK") + if sock == "" { + return nil, nil, errors.New("SSH_AUTH_SOCK is not set (ssh-agent is not available)") + } + conn, err := net.Dial("unix", sock) + if err != nil { + return nil, nil, fmt.Errorf("cannot connect to SSH agent, SSH_AUTH_SOCK=%s, %w", sock, err) + } + return agent.NewClient(conn), conn, nil +} diff --git a/schema.json b/schema.json index 12d621c..99155a2 100644 --- a/schema.json +++ b/schema.json @@ -83,7 +83,7 @@ "type": "string" }, "privateKey": { - "description": "private key path. Required", + "description": "private key path. Required. Also supports ssh-agent via 'agent://', 'agent://' or 'agent://'", "type": "string" }, "passPhrase": {