diff --git a/pkg/sshclient/bastion.go b/pkg/sshclient/bastion.go index c879bcc33..f10bddda0 100644 --- a/pkg/sshclient/bastion.go +++ b/pkg/sshclient/bastion.go @@ -84,13 +84,13 @@ func HostKey(host string) ssh.PublicKey { return nil } -func CreateBastion(_url *url.URL, passPhrase string, identity string, initial net.Conn, connect ConnectCallback) (Bastion, error) { +func CreateBastion(_url *url.URL, passPhrase string, identity string, initial net.Conn, connect ConnectCallback) (*Bastion, error) { var authMethods []ssh.AuthMethod if len(identity) > 0 { s, err := PublicKey(identity, []byte(passPhrase)) if err != nil { - return Bastion{}, errors.Wrapf(err, "failed to parse identity %q", identity) + return nil, errors.Wrapf(err, "failed to parse identity %q", identity) } authMethods = append(authMethods, ssh.PublicKeys(s)) } @@ -100,7 +100,7 @@ func CreateBastion(_url *url.URL, passPhrase string, identity string, initial ne } if len(authMethods) == 0 { - return Bastion{}, errors.New("No available auth methods") + return nil, errors.New("No available auth methods") } port := _url.Port() @@ -149,7 +149,7 @@ func CreateBastion(_url *url.URL, passPhrase string, identity string, initial ne } bastion := Bastion{nil, config, _url.Hostname(), port, _url.Path, connect} - return bastion, bastion.reconnect(context.Background(), initial) + return &bastion, bastion.reconnect(context.Background(), initial) } func (bastion *Bastion) Reconnect(ctx context.Context) error { diff --git a/pkg/sshclient/ssh_forwarder.go b/pkg/sshclient/ssh_forwarder.go index 2cd7ab235..75c43991d 100644 --- a/pkg/sshclient/ssh_forwarder.go +++ b/pkg/sshclient/ssh_forwarder.go @@ -171,14 +171,17 @@ func setupProxy(ctx context.Context, socketURI *url.URL, dest *url.URL, identity return &SSHForward{}, err } - bastion, err := CreateBastion(dest, passphrase, identity, conn, connectFunc) + createBastion := func() (*Bastion, error) { + return CreateBastion(dest, passphrase, identity, conn, connectFunc) + } + bastion, err := retry(ctx, createBastion, "Waiting for sshd") if err != nil { - return &SSHForward{}, err + return &SSHForward{}, fmt.Errorf("setupProxy failed: %w", err) } logrus.Debugf("Socket forward established: %s -> %s\n", socketURI.Path, dest.Path) - return &SSHForward{listener, &bastion, socketURI}, nil + return &SSHForward{listener, bastion, socketURI}, nil } const maxRetries = 60