Skip to content

Commit

Permalink
Merge commit from fork
Browse files Browse the repository at this point in the history
Fixed web ssh security bug
  • Loading branch information
tobychui authored Nov 10, 2024
2 parents e79a70b + ed178d8 commit 2e9bc77
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 23 deletions.
26 changes: 11 additions & 15 deletions src/mod/sshprox/sshprox.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,6 @@ func NewSSHProxyManager() *Manager {
}
}

// Get the next free port in the list
func (m *Manager) GetNextPort() int {
nextPort := m.StartingPort
occupiedPort := make(map[int]bool)
for _, instance := range m.Instances {
occupiedPort[instance.AssignedPort] = true
}
for {
if !occupiedPort[nextPort] {
return nextPort
}
nextPort++
}
}

func (m *Manager) HandleHttpByInstanceId(instanceId string, w http.ResponseWriter, r *http.Request) {
targetInstance, err := m.GetInstanceById(instanceId)
if err != nil {
Expand Down Expand Up @@ -168,6 +153,17 @@ func (i *Instance) CreateNewConnection(listenPort int, username string, remoteIp
if username != "" {
connAddr = username + "@" + remoteIpAddr
}

//Trim the space in the username and remote address
username = strings.TrimSpace(username)
remoteIpAddr = strings.TrimSpace(remoteIpAddr)

//Validate the username and remote address
err := ValidateUsernameAndRemoteAddr(username, remoteIpAddr)
if err != nil {
return err
}

configPath := filepath.Join(filepath.Dir(i.ExecPath), ".gotty")
title := username + "@" + remoteIpAddr
if remotePort != 22 {
Expand Down
66 changes: 66 additions & 0 deletions src/mod/sshprox/sshprox_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package sshprox

import (
"testing"
)

func TestInstance_Destroy(t *testing.T) {
manager := NewSSHProxyManager()
instance, err := manager.NewSSHProxy("/tmp")
if err != nil {
t.Fatalf("Failed to create new SSH proxy: %v", err)
}

instance.Destroy()

if len(manager.Instances) != 0 {
t.Errorf("Expected Instances to be empty, got %d", len(manager.Instances))
}
}

func TestInstance_ValidateUsernameAndRemoteAddr(t *testing.T) {
tests := []struct {
username string
remoteAddr string
expectError bool
}{
{"validuser", "127.0.0.1", false},
{"valid.user", "example.com", false},
{"; bash ;", "example.com", true},
{"valid-user", "example.com", false},
{"invalid user", "127.0.0.1", true},
{"validuser", "invalid address", true},
{"invalid@user", "127.0.0.1", true},
{"validuser", "invalid@address", true},
{"injection; rm -rf /", "127.0.0.1", true},
{"validuser", "127.0.0.1; rm -rf /", true},
{"$(reboot)", "127.0.0.1", true},
{"validuser", "$(reboot)", true},
{"validuser", "127.0.0.1; $(reboot)", true},
{"validuser", "127.0.0.1 | ls", true},
{"validuser", "127.0.0.1 & ls", true},
{"validuser", "127.0.0.1 && ls", true},
{"validuser", "127.0.0.1 |& ls", true},
{"validuser", "127.0.0.1 ; ls", true},
{"validuser", "2001:0db8:85a3:0000:0000:8a2e:0370:7334", false},
{"validuser", "2001:db8::ff00:42:8329", false},
{"validuser", "2001:db8:0:1234:0:567:8:1", false},
{"validuser", "2001:db8::1234:0:567:8:1", false},
{"validuser", "2001:db8:0:0:0:0:2:1", false},
{"validuser", "2001:db8::2:1", false},
{"validuser", "2001:db8:0:0:8:800:200c:417a", false},
{"validuser", "2001:db8::8:800:200c:417a", false},
{"validuser", "2001:db8:0:0:8:800:200c:417a; rm -rf /", true},
{"validuser", "2001:db8::8:800:200c:417a; rm -rf /", true},
}

for _, test := range tests {
err := ValidateUsernameAndRemoteAddr(test.username, test.remoteAddr)
if test.expectError && err == nil {
t.Errorf("Expected error for username %s and remoteAddr %s, but got none", test.username, test.remoteAddr)
}
if !test.expectError && err != nil {
t.Errorf("Did not expect error for username %s and remoteAddr %s, but got %v", test.username, test.remoteAddr, err)
}
}
}
63 changes: 57 additions & 6 deletions src/mod/sshprox/utils.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package sshprox

import (
"errors"
"fmt"
"net"
"net/url"
"regexp"
"runtime"
"strings"
"time"
Expand Down Expand Up @@ -34,6 +36,21 @@ func IsWebSSHSupported() bool {
return true
}

// Get the next free port in the list
func (m *Manager) GetNextPort() int {
nextPort := m.StartingPort
occupiedPort := make(map[int]bool)
for _, instance := range m.Instances {
occupiedPort[instance.AssignedPort] = true
}
for {
if !occupiedPort[nextPort] {
return nextPort
}
nextPort++
}
}

// Check if a given domain and port is a valid ssh server
func IsSSHConnectable(ipOrDomain string, port int) bool {
timeout := time.Second * 3
Expand All @@ -60,13 +77,47 @@ func IsSSHConnectable(ipOrDomain string, port int) bool {
return string(buf[:7]) == "SSH-2.0"
}

// Check if the port is used by other process or application
func isPortInUse(port int) bool {
address := fmt.Sprintf(":%d", port)
listener, err := net.Listen("tcp", address)
if err != nil {
// Validate the username and remote address to prevent injection
func ValidateUsernameAndRemoteAddr(username string, remoteIpAddr string) error {
// Validate and sanitize the username to prevent ssh injection
validUsername := regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)
if !validUsername.MatchString(username) {
return errors.New("invalid username, only alphanumeric characters, dots, underscores and dashes are allowed")
}

//Check if the remoteIpAddr is a valid ipv4 or ipv6 address
if net.ParseIP(remoteIpAddr) != nil {
//A valid IP address do not need further validation
return nil
}

// Validate and sanitize the remote domain to prevent injection
validRemoteAddr := regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)
if !validRemoteAddr.MatchString(remoteIpAddr) {
return errors.New("invalid remote address, only alphanumeric characters, dots, underscores and dashes are allowed")
}

return nil
}

// Check if the given ip or domain is a loopback address
// or resolves to a loopback address
func IsLoopbackIPOrDomain(ipOrDomain string) bool {
if strings.EqualFold(strings.TrimSpace(ipOrDomain), "localhost") || strings.TrimSpace(ipOrDomain) == "127.0.0.1" {
return true
}
listener.Close()

//Check if the ipOrDomain resolves to a loopback address
ips, err := net.LookupIP(ipOrDomain)
if err != nil {
return false
}

for _, ip := range ips {
if ip.IsLoopback() {
return true
}
}

return false
}
4 changes: 2 additions & 2 deletions src/webssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func HandleCreateProxySession(w http.ResponseWriter, r *http.Request) {

if !*allowSshLoopback {
//Not allow loopback connections
if strings.EqualFold(strings.TrimSpace(ipaddr), "localhost") || strings.TrimSpace(ipaddr) == "127.0.0.1" {
if sshprox.IsLoopbackIPOrDomain(ipaddr) {
//Request target is loopback
utils.SendErrorResponse(w, "loopback web ssh connection is not enabled on this host")
return
Expand Down Expand Up @@ -74,7 +74,7 @@ func HandleCreateProxySession(w http.ResponseWriter, r *http.Request) {
utils.SendJSONResponse(w, string(js))
}

//Check if the host support ssh, or if the target domain (and port, optional) support ssh
// Check if the host support ssh, or if the target domain (and port, optional) support ssh
func HandleWebSshSupportCheck(w http.ResponseWriter, r *http.Request) {
domain, err := utils.PostPara(r, "domain")
if err != nil {
Expand Down

0 comments on commit 2e9bc77

Please sign in to comment.