diff --git a/server.go b/server.go index e67593c..e1adf09 100644 --- a/server.go +++ b/server.go @@ -1,14 +1,14 @@ package main import ( - "crypto/x509" "encoding/pem" "fmt" "io/ioutil" "log" "net" "os" - "regexp" + + "github.com/ScaleFT/sshkeys" ) func server(config Config) { @@ -49,7 +49,6 @@ func server(config Config) { } func readKeyData(config *Config) *[]byte { - // var keyData []byte var err error keyData := []byte(os.Getenv(KEY_DATA_ENV_VAR)) if len(keyData) == 0 { @@ -59,23 +58,25 @@ func readKeyData(config *Config) *[]byte { log.Fatalf("ERROR reading keyfile %s: %s!\n", config.KeyPath, err) } } - pemBlock, _ := pem.Decode(keyData) - if pemBlock != nil { - if x509.IsEncryptedPEMBlock(pemBlock) { - fmt.Println("Decrypting private key with passphrase...") - decoded, err := x509.DecryptPEMBlock(pemBlock, []byte(config.Pwd)) - if err == nil { - header := `PRIVATE KEY` // default key type in header - matcher := regexp.MustCompile("-----BEGIN (.*)-----") - if matches := matcher.FindSubmatch(keyData); len(matches) > 1 { - header = string(matches[1]) - } - keyData = pem.EncodeToMemory( - &pem.Block{Type: header, Bytes: decoded}) - } else { - fmt.Printf("Error decrypting PEM-encoded secret: %s\n", err) - } - } + + passphrase := []byte(config.Pwd) + var privateKey interface{} + fmt.Println("Decrypting private key with passphrase...") + privateKey, err = sshkeys.ParseEncryptedRawPrivateKey(keyData, passphrase) + if err != nil { + log.Fatalf("ERROR parsing encrypted key %s!\n", err) } + + fmt.Println("Converting decrypted key to RSA key...") + opts := sshkeys.MarshalOptions{Format: sshkeys.FormatClassicPEM} + var privateKeyAsPem []byte + privateKeyAsPem, err = sshkeys.Marshal(privateKey, &opts) + if err != nil { + log.Fatalf("ERROR converting private key to unencrypted PEM format %s!\n", err) + } + + pemBlock, _ := pem.Decode(privateKeyAsPem) + keyData = pem.EncodeToMemory( + &pem.Block{Type: "RSA PRIVATE KEY", Bytes: pemBlock.Bytes}) return &keyData }