Skip to content

Commit

Permalink
more changes
Browse files Browse the repository at this point in the history
  • Loading branch information
mikew committed Oct 22, 2024
1 parent 04a828d commit 12493b8
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 37 deletions.
6 changes: 3 additions & 3 deletions src/client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ import (
"github.com/urfave/cli/v2"

"nvrh/src/context"
"nvrh/src/go_ssh_ext"
"nvrh/src/logger"
"nvrh/src/nvim_helpers"
"nvrh/src/nvrh_base_ssh"
"nvrh/src/nvrh_binary_ssh"
"nvrh/src/nvrh_internal_ssh"
"nvrh/src/nvrh_ssh"
"nvrh/src/ssh_endpoint"
"nvrh/src/ssh_tunnel_info"
)
Expand Down Expand Up @@ -115,12 +115,12 @@ var CliClientOpenCommand = cli.Command{

BrowserScriptPath: fmt.Sprintf("/tmp/nvrh-browser-%s", sessionId),

SshPath: c.String("ssh-path"),
SshPath: sshPath,
Debug: isDebug,
}

if sshPath == "internal" {
sshClient, err := nvrh_ssh.GetSshClientForEndpoint(endpoint)
sshClient, err := go_ssh_ext.GetSshClientForEndpoint(endpoint)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion src/nvrh_ssh/internal_ssh.go → src/go_ssh_ext/main.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package nvrh_ssh
package go_ssh_ext

import (
"fmt"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package nvrh_ssh
package go_ssh_ext

import (
"log/slog"
Expand Down
31 changes: 29 additions & 2 deletions src/nvrh_internal_ssh/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (c *NvrhInternalSshClient) TunnelSocket(tunnelInfo *ssh_tunnel_info.SshTunn
}

// Listen on the local Unix socket
localListener, err := tunnelInfo.LocalListener(tunnelInfo.Public)
localListener, err := LocalListenerFromTunnelInfo(tunnelInfo)
if err != nil {
slog.Error("Failed to listen on local socket", "err", err)
return
Expand All @@ -87,7 +87,7 @@ func (c *NvrhInternalSshClient) TunnelSocket(tunnelInfo *ssh_tunnel_info.SshTunn
}

// Establish a connection to the remote socket via SSH
remoteConn, err := tunnelInfo.RemoteListener(c.SshClient)
remoteConn, err := RemoteListenerFromTunnelInfo(tunnelInfo, c.SshClient)
if err != nil {
slog.Error("Failed to dial remote socket", "err", err)
localConn.Close()
Expand All @@ -110,3 +110,30 @@ func handleConnection(localConn net.Conn, remoteConn net.Conn) {
// Copy data from remote to local
io.Copy(localConn, remoteConn)
}

func LocalListenerFromTunnelInfo(ti *ssh_tunnel_info.SshTunnelInfo) (net.Listener, error) {
switch ti.Mode {
case "unix":
return net.Listen("unix", ti.LocalSocket)
case "port":
ip := "localhost"
if ti.Public {
ip = "0.0.0.0"
}

return net.Listen("tcp", fmt.Sprintf("%s:%s", ip, ti.LocalSocket))
}

return nil, fmt.Errorf("Invalid mode: %s", ti.Mode)
}

func RemoteListenerFromTunnelInfo(ti *ssh_tunnel_info.SshTunnelInfo, sshClient *ssh.Client) (net.Conn, error) {
switch ti.Mode {
case "unix":
return sshClient.Dial("unix", ti.RemoteSocket)
case "port":
return sshClient.Dial("tcp", fmt.Sprintf("localhost:%s", ti.RemoteSocket))
}

return nil, fmt.Errorf("Invalid mode: %s", ti.Mode)
}
30 changes: 0 additions & 30 deletions src/ssh_tunnel_info/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ package ssh_tunnel_info

import (
"fmt"
"net"

"golang.org/x/crypto/ssh"
)

type SshTunnelInfo struct {
Expand All @@ -14,33 +11,6 @@ type SshTunnelInfo struct {
Public bool
}

func (ti *SshTunnelInfo) LocalListener(public bool) (net.Listener, error) {
switch ti.Mode {
case "unix":
return net.Listen("unix", ti.LocalSocket)
case "port":
ip := "localhost"
if public {
ip = "0.0.0.0"
}

return net.Listen("tcp", fmt.Sprintf("%s:%s", ip, ti.LocalSocket))
}

return nil, fmt.Errorf("Invalid mode: %s", ti.Mode)
}

func (ti *SshTunnelInfo) RemoteListener(sshClient *ssh.Client) (net.Conn, error) {
switch ti.Mode {
case "unix":
return sshClient.Dial("unix", ti.RemoteSocket)
case "port":
return sshClient.Dial("tcp", fmt.Sprintf("localhost:%s", ti.RemoteSocket))
}

return nil, fmt.Errorf("Invalid mode: %s", ti.Mode)
}

func (ti *SshTunnelInfo) BoundToIp() string {
if ti.Mode == "unix" {
return fmt.Sprintf("%s:%s", ti.LocalSocket, ti.RemoteSocket)
Expand Down

0 comments on commit 12493b8

Please sign in to comment.