Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ The first setting in `connections` is the default connection.
| host | ssh host. Required. |
| port | ssh port. Required. |
| user | ssh user. Optional. |
| privateKey | private key path. Required. |
| privateKey | private key path. Required. Also supports ssh-agent via `agent://`, `agent://<name>` or `agent://<hash>`. |
| passPhrase | passPhrase. Optional. |

#### DSN (Data Source Name)
Expand Down
34 changes: 0 additions & 34 deletions internal/database/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@ package database

import (
"errors"
"fmt"
"os"

"github.com/sqls-server/sqls/dialect"
"golang.org/x/crypto/ssh"
)

type Proto string
Expand Down Expand Up @@ -163,34 +160,3 @@ func (s *SSHConfig) Validate() error {
}
return nil
}

func (s *SSHConfig) Endpoint() string {
return fmt.Sprintf("%s:%d", s.Host, s.Port)
}

func (s *SSHConfig) ClientConfig() (*ssh.ClientConfig, error) {
buffer, err := os.ReadFile(s.PrivateKey)
if err != nil {
return nil, fmt.Errorf("cannot read SSH private key file, PrivateKey=%s, %w", s.PrivateKey, err)
}

var key ssh.Signer
if s.PassPhrase != "" {
key, err = ssh.ParsePrivateKeyWithPassphrase(buffer, []byte(s.PassPhrase))
if err != nil {
return nil, fmt.Errorf("cannot parse SSH private key file with passphrase, PrivateKey=%s, %w", s.PrivateKey, err)
}
} else {
key, err = ssh.ParsePrivateKey(buffer)
if err != nil {
return nil, fmt.Errorf("cannot parse SSH private key file, PrivateKey=%s, %w", s.PrivateKey, err)
}
}

sshConfig := &ssh.ClientConfig{
User: s.User,
Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
return sshConfig, nil
}
19 changes: 2 additions & 17 deletions internal/database/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@ import (
"fmt"
"log"
"net/url"
"os"
"strconv"

_ "github.com/denisenkom/go-mssqldb"
"github.com/jfcote87/sshdb"
"github.com/jfcote87/sshdb/mssql"
"github.com/sqls-server/sqls/dialect"
"golang.org/x/crypto/ssh"
)

func init() {
Expand All @@ -31,22 +29,9 @@ func mssqlOpen(dbConnCfg *DBConfig) (*DBConnection, error) {
}

if dbConnCfg.SSHCfg != nil {
key, err := os.ReadFile(dbConnCfg.SSHCfg.PrivateKey)
cfg, err := dbConnCfg.SSHCfg.ClientConfig()
if err != nil {
return nil, fmt.Errorf("unable to open private key")
}

signer, err := ssh.ParsePrivateKeyWithPassphrase(key, []byte(dbConnCfg.SSHCfg.PassPhrase))
if err != nil {
return nil, fmt.Errorf("unable to decrypt private key")
}

cfg := &ssh.ClientConfig{
User: dbConnCfg.SSHCfg.User,
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
return nil, err
}

remoteAddr := fmt.Sprintf("%s:%d", dbConnCfg.SSHCfg.Host, dbConnCfg.SSHCfg.Port)
Expand Down
170 changes: 170 additions & 0 deletions internal/database/ssh.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package database

import (
"errors"
"fmt"
"io"
"net"
"os"
"strings"

"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)

func (s *SSHConfig) Endpoint() string {
return fmt.Sprintf("%s:%d", s.Host, s.Port)
}

func (s *SSHConfig) ClientConfig() (*ssh.ClientConfig, error) {
var (
auth ssh.AuthMethod
err error
)

if strings.HasPrefix(s.PrivateKey, "agent://") {
auth, err = s.sshAgentAuthMethod(strings.TrimPrefix(s.PrivateKey, "agent://"))
if err != nil {
return nil, err
}
} else {
buffer, err := os.ReadFile(s.PrivateKey)
if err != nil {
return nil, fmt.Errorf("cannot read SSH private key file, PrivateKey=%s, %w", s.PrivateKey, err)
}

var key ssh.Signer
if s.PassPhrase != "" {
key, err = ssh.ParsePrivateKeyWithPassphrase(buffer, []byte(s.PassPhrase))
if err != nil {
return nil, fmt.Errorf("cannot parse SSH private key file with passphrase, PrivateKey=%s, %w", s.PrivateKey, err)
}
} else {
key, err = ssh.ParsePrivateKey(buffer)
if err != nil {
return nil, fmt.Errorf("cannot parse SSH private key file, PrivateKey=%s, %w", s.PrivateKey, err)
}
}
auth = ssh.PublicKeys(key)
}

sshConfig := &ssh.ClientConfig{
User: s.User,
Auth: []ssh.AuthMethod{auth},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
return sshConfig, nil
}

func (s *SSHConfig) sshAgentAuthMethod(selector string) (ssh.AuthMethod, error) {
if sock := os.Getenv("SSH_AUTH_SOCK"); sock == "" {
return nil, errors.New("SSH_AUTH_SOCK is not set (ssh-agent is not available)")
}

selector = strings.TrimSpace(selector)
selector = strings.TrimPrefix(selector, "/")

return ssh.PublicKeysCallback(func() ([]ssh.Signer, error) {
sock := os.Getenv("SSH_AUTH_SOCK")
if sock == "" {
return nil, errors.New("SSH_AUTH_SOCK is not set (ssh-agent is not available)")
}
conn, err := net.Dial("unix", sock)
if err != nil {
return nil, fmt.Errorf("cannot connect to SSH agent, SSH_AUTH_SOCK=%s, %w", sock, err)
}
ag := agent.NewClient(conn)
keys, err := ag.List()
_ = conn.Close()
if err != nil {
return nil, fmt.Errorf("cannot list SSH agent keys, %w", err)
}
if len(keys) == 0 {
return nil, errors.New("no keys available in SSH agent")
}

matchesHash := func(sel string, pk ssh.PublicKey) bool {
if sel == "" {
return true
}
sha := ssh.FingerprintSHA256(pk)
shaNoPrefix := strings.TrimPrefix(sha, "SHA256:")
md5 := ssh.FingerprintLegacyMD5(pk)
return sel == sha || sel == shaNoPrefix || sel == md5
}

matchesName := func(sel string, comment string) bool {
if sel == "" {
return true
}
if comment == sel {
return true
}
return comment != "" && strings.Contains(comment, sel)
}

matched := make([]ssh.Signer, 0, len(keys))
for _, k := range keys {
pk, err := ssh.ParsePublicKey(k.Blob)
if err != nil {
continue
}
comment := k.Comment

if selector == "" || matchesName(selector, comment) || matchesHash(selector, pk) {
matched = append(matched, &sshAgentSigner{pub: pk})
}
}

if len(matched) == 0 {
return nil, fmt.Errorf("no matching SSH agent key for selector %q", selector)
}
return matched, nil
}), nil
}

type sshAgentSigner struct {
pub ssh.PublicKey
}

var _ ssh.Signer = (*sshAgentSigner)(nil)
var _ ssh.AlgorithmSigner = (*sshAgentSigner)(nil)

func (s *sshAgentSigner) PublicKey() ssh.PublicKey { return s.pub }

func (s *sshAgentSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
c, conn, err := dialSSHAgent()
if err != nil {
return nil, err
}
defer conn.Close()
return c.Sign(s.pub, data)
}

func (s *sshAgentSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
c, conn, err := dialSSHAgent()
if err != nil {
return nil, err
}
defer conn.Close()
switch algorithm {
case ssh.KeyAlgoRSASHA256:
return c.SignWithFlags(s.pub, data, agent.SignatureFlagRsaSha256)
case ssh.KeyAlgoRSASHA512:
return c.SignWithFlags(s.pub, data, agent.SignatureFlagRsaSha512)
default:
return c.Sign(s.pub, data)
}
}

func dialSSHAgent() (agent.ExtendedAgent, net.Conn, error) {
sock := os.Getenv("SSH_AUTH_SOCK")
if sock == "" {
return nil, nil, errors.New("SSH_AUTH_SOCK is not set (ssh-agent is not available)")
}
conn, err := net.Dial("unix", sock)
if err != nil {
return nil, nil, fmt.Errorf("cannot connect to SSH agent, SSH_AUTH_SOCK=%s, %w", sock, err)
}
return agent.NewClient(conn), conn, nil
}
2 changes: 1 addition & 1 deletion schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
"type": "string"
},
"privateKey": {
"description": "private key path. Required",
"description": "private key path. Required. Also supports ssh-agent via 'agent://', 'agent://<name>' or 'agent://<hash>'",
"type": "string"
},
"passPhrase": {
Expand Down