diff --git a/go.mod b/go.mod index 23b2ec0..7c9ed6f 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/godror/godror v0.41.0 github.com/google/go-cmp v0.5.9 github.com/jackc/pgx/v4 v4.18.1 + github.com/jfcote87/sshdb v0.5.3 github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88 // indirect github.com/sourcegraph/jsonrpc2 v0.2.0 github.com/urfave/cli/v2 v2.27.0 diff --git a/go.sum b/go.sum index 3836588..180aebc 100644 --- a/go.sum +++ b/go.sum @@ -136,6 +136,8 @@ github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0f github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= +github.com/jfcote87/sshdb v0.5.3 h1:c0I3+ScEbT0mjvpoY8qbVNfR4Y9Q5JWh52WnmjfsuV0= +github.com/jfcote87/sshdb v0.5.3/go.mod h1:YIGPRF3vtRG1Cvpwa1LaQvmrsIEPKC9WqF+ZU5rInUw= github.com/joeshaw/multierror v0.0.0-20140124173710-69b34d4ec901 h1:rp+c0RAYOWj8l6qbCUTSiRLG/iKnW3K3/QfPPuSsBt4= github.com/joeshaw/multierror v0.0.0-20140124173710-69b34d4ec901/go.mod h1:Z86h9688Y0wesXCyonoVr47MasHilkuLMqGhRZ4Hpak= github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88 h1:uC1QfSlInpQF+M0ao65imhwqKnz3Q2z/d8PWZRMQvDM= diff --git a/internal/database/mssql.go b/internal/database/mssql.go index 60611f2..8667603 100644 --- a/internal/database/mssql.go +++ b/internal/database/mssql.go @@ -1,6 +1,7 @@ package database import ( + "os" "context" "database/sql" "fmt" @@ -10,6 +11,8 @@ import ( _ "github.com/denisenkom/go-mssqldb" "github.com/sqls-server/sqls/dialect" + "github.com/jfcote87/sshdb" + "github.com/jfcote87/sshdb/mssql" "golang.org/x/crypto/ssh" ) @@ -21,7 +24,6 @@ func init() { func mssqlOpen(dbConnCfg *DBConfig) (*DBConnection, error) { var ( conn *sql.DB - sshConn *ssh.Client ) dsn, err := genMssqlConfig(dbConnCfg) if err != nil { @@ -29,13 +31,44 @@ func mssqlOpen(dbConnCfg *DBConfig) (*DBConnection, error) { } if dbConnCfg.SSHCfg != nil { - return nil, fmt.Errorf("connect via SSH is not supported") - } - dbConn, err := sql.Open("sqlserver", dsn) - if err != nil { - return nil, err + key, err := os.ReadFile(dbConnCfg.SSHCfg.PrivateKey) + if err != nil { + return nil, fmt.Errorf("unable to open private key") + } + + signer, err := ssh.ParsePrivateKeyWithPassphrase(key, []byte(dbConnCfg.SSHCfg.PassPhrase)) + if err != nil { + return nil, fmt.Errorf("unable to decrypt private key") + } + + cfg := &ssh.ClientConfig { + User: dbConnCfg.SSHCfg.User, + Auth: []ssh.AuthMethod { + ssh.PublicKeys(signer), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + remoteAddr := fmt.Sprintf("%s:%d", dbConnCfg.SSHCfg.Host, dbConnCfg.SSHCfg.Port) + + tunnel, err := sshdb.New(cfg, remoteAddr) + if err != nil { + return nil, fmt.Errorf("%v", err) + } + + connector, err := tunnel.OpenConnector(mssql.TunnelDriver, dsn) + if err != nil { + return nil, err + } + + conn = sql.OpenDB(connector) + } else { + conn, err = sql.Open("mssql", dsn) + if err != nil { + return nil, err + } } - conn = dbConn + if err = conn.Ping(); err != nil { return nil, err } @@ -45,7 +78,6 @@ func mssqlOpen(dbConnCfg *DBConfig) (*DBConnection, error) { return &DBConnection{ Conn: conn, - SSHConn: sshConn, }, nil }