diff --git a/src/mod/sshprox/sshprox.go b/src/mod/sshprox/sshprox.go index ed1b92c..113cb45 100644 --- a/src/mod/sshprox/sshprox.go +++ b/src/mod/sshprox/sshprox.go @@ -50,21 +50,6 @@ func NewSSHProxyManager() *Manager { } } -// Get the next free port in the list -func (m *Manager) GetNextPort() int { - nextPort := m.StartingPort - occupiedPort := make(map[int]bool) - for _, instance := range m.Instances { - occupiedPort[instance.AssignedPort] = true - } - for { - if !occupiedPort[nextPort] { - return nextPort - } - nextPort++ - } -} - func (m *Manager) HandleHttpByInstanceId(instanceId string, w http.ResponseWriter, r *http.Request) { targetInstance, err := m.GetInstanceById(instanceId) if err != nil { @@ -168,6 +153,17 @@ func (i *Instance) CreateNewConnection(listenPort int, username string, remoteIp if username != "" { connAddr = username + "@" + remoteIpAddr } + + //Trim the space in the username and remote address + username = strings.TrimSpace(username) + remoteIpAddr = strings.TrimSpace(remoteIpAddr) + + //Validate the username and remote address + err := ValidateUsernameAndRemoteAddr(username, remoteIpAddr) + if err != nil { + return err + } + configPath := filepath.Join(filepath.Dir(i.ExecPath), ".gotty") title := username + "@" + remoteIpAddr if remotePort != 22 { diff --git a/src/mod/sshprox/sshprox_test.go b/src/mod/sshprox/sshprox_test.go new file mode 100644 index 0000000..36a9ab5 --- /dev/null +++ b/src/mod/sshprox/sshprox_test.go @@ -0,0 +1,66 @@ +package sshprox + +import ( + "testing" +) + +func TestInstance_Destroy(t *testing.T) { + manager := NewSSHProxyManager() + instance, err := manager.NewSSHProxy("/tmp") + if err != nil { + t.Fatalf("Failed to create new SSH proxy: %v", err) + } + + instance.Destroy() + + if len(manager.Instances) != 0 { + t.Errorf("Expected Instances to be empty, got %d", len(manager.Instances)) + } +} + +func TestInstance_ValidateUsernameAndRemoteAddr(t *testing.T) { + tests := []struct { + username string + remoteAddr string + expectError bool + }{ + {"validuser", "127.0.0.1", false}, + {"valid.user", "example.com", false}, + {"; bash ;", "example.com", true}, + {"valid-user", "example.com", false}, + {"invalid user", "127.0.0.1", true}, + {"validuser", "invalid address", true}, + {"invalid@user", "127.0.0.1", true}, + {"validuser", "invalid@address", true}, + {"injection; rm -rf /", "127.0.0.1", true}, + {"validuser", "127.0.0.1; rm -rf /", true}, + {"$(reboot)", "127.0.0.1", true}, + {"validuser", "$(reboot)", true}, + {"validuser", "127.0.0.1; $(reboot)", true}, + {"validuser", "127.0.0.1 | ls", true}, + {"validuser", "127.0.0.1 & ls", true}, + {"validuser", "127.0.0.1 && ls", true}, + {"validuser", "127.0.0.1 |& ls", true}, + {"validuser", "127.0.0.1 ; ls", true}, + {"validuser", "2001:0db8:85a3:0000:0000:8a2e:0370:7334", false}, + {"validuser", "2001:db8::ff00:42:8329", false}, + {"validuser", "2001:db8:0:1234:0:567:8:1", false}, + {"validuser", "2001:db8::1234:0:567:8:1", false}, + {"validuser", "2001:db8:0:0:0:0:2:1", false}, + {"validuser", "2001:db8::2:1", false}, + {"validuser", "2001:db8:0:0:8:800:200c:417a", false}, + {"validuser", "2001:db8::8:800:200c:417a", false}, + {"validuser", "2001:db8:0:0:8:800:200c:417a; rm -rf /", true}, + {"validuser", "2001:db8::8:800:200c:417a; rm -rf /", true}, + } + + for _, test := range tests { + err := ValidateUsernameAndRemoteAddr(test.username, test.remoteAddr) + if test.expectError && err == nil { + t.Errorf("Expected error for username %s and remoteAddr %s, but got none", test.username, test.remoteAddr) + } + if !test.expectError && err != nil { + t.Errorf("Did not expect error for username %s and remoteAddr %s, but got %v", test.username, test.remoteAddr, err) + } + } +} diff --git a/src/mod/sshprox/utils.go b/src/mod/sshprox/utils.go index 082c8d9..0e4d271 100644 --- a/src/mod/sshprox/utils.go +++ b/src/mod/sshprox/utils.go @@ -1,9 +1,11 @@ package sshprox import ( + "errors" "fmt" "net" "net/url" + "regexp" "runtime" "strings" "time" @@ -34,6 +36,21 @@ func IsWebSSHSupported() bool { return true } +// Get the next free port in the list +func (m *Manager) GetNextPort() int { + nextPort := m.StartingPort + occupiedPort := make(map[int]bool) + for _, instance := range m.Instances { + occupiedPort[instance.AssignedPort] = true + } + for { + if !occupiedPort[nextPort] { + return nextPort + } + nextPort++ + } +} + // Check if a given domain and port is a valid ssh server func IsSSHConnectable(ipOrDomain string, port int) bool { timeout := time.Second * 3 @@ -60,13 +77,47 @@ func IsSSHConnectable(ipOrDomain string, port int) bool { return string(buf[:7]) == "SSH-2.0" } -// Check if the port is used by other process or application -func isPortInUse(port int) bool { - address := fmt.Sprintf(":%d", port) - listener, err := net.Listen("tcp", address) - if err != nil { +// Validate the username and remote address to prevent injection +func ValidateUsernameAndRemoteAddr(username string, remoteIpAddr string) error { + // Validate and sanitize the username to prevent ssh injection + validUsername := regexp.MustCompile(`^[a-zA-Z0-9._-]+$`) + if !validUsername.MatchString(username) { + return errors.New("invalid username, only alphanumeric characters, dots, underscores and dashes are allowed") + } + + //Check if the remoteIpAddr is a valid ipv4 or ipv6 address + if net.ParseIP(remoteIpAddr) != nil { + //A valid IP address do not need further validation + return nil + } + + // Validate and sanitize the remote domain to prevent injection + validRemoteAddr := regexp.MustCompile(`^[a-zA-Z0-9._-]+$`) + if !validRemoteAddr.MatchString(remoteIpAddr) { + return errors.New("invalid remote address, only alphanumeric characters, dots, underscores and dashes are allowed") + } + + return nil +} + +// Check if the given ip or domain is a loopback address +// or resolves to a loopback address +func IsLoopbackIPOrDomain(ipOrDomain string) bool { + if strings.EqualFold(strings.TrimSpace(ipOrDomain), "localhost") || strings.TrimSpace(ipOrDomain) == "127.0.0.1" { return true } - listener.Close() + + //Check if the ipOrDomain resolves to a loopback address + ips, err := net.LookupIP(ipOrDomain) + if err != nil { + return false + } + + for _, ip := range ips { + if ip.IsLoopback() { + return true + } + } + return false } diff --git a/src/webssh.go b/src/webssh.go index 0f8631f..8d870b9 100644 --- a/src/webssh.go +++ b/src/webssh.go @@ -42,7 +42,7 @@ func HandleCreateProxySession(w http.ResponseWriter, r *http.Request) { if !*allowSshLoopback { //Not allow loopback connections - if strings.EqualFold(strings.TrimSpace(ipaddr), "localhost") || strings.TrimSpace(ipaddr) == "127.0.0.1" { + if sshprox.IsLoopbackIPOrDomain(ipaddr) { //Request target is loopback utils.SendErrorResponse(w, "loopback web ssh connection is not enabled on this host") return @@ -74,7 +74,7 @@ func HandleCreateProxySession(w http.ResponseWriter, r *http.Request) { utils.SendJSONResponse(w, string(js)) } -//Check if the host support ssh, or if the target domain (and port, optional) support ssh +// Check if the host support ssh, or if the target domain (and port, optional) support ssh func HandleWebSshSupportCheck(w http.ResponseWriter, r *http.Request) { domain, err := utils.PostPara(r, "domain") if err != nil {