diff --git a/modules/connections.go b/modules/connections.go index f2c45ce..16c661b 100644 --- a/modules/connections.go +++ b/modules/connections.go @@ -1,19 +1,23 @@ package modules import ( + "errors" "fmt" "net" + "sync" "time" "golang.org/x/net/proxy" ) type ConnectionManager struct { - Socks5 string - Timeout time.Duration - Iface string - Dialer proxy.Dialer - DialFunc func(network, address string) (net.Conn, error) + Socks5 string + Timeout time.Duration + Iface string + Dialer proxy.Dialer + DialFunc func(network, address string) (net.Conn, error) + ConnPool map[string]chan net.Conn + PoolMutex sync.Mutex } func NewConnectionManager(socks5 string, timeout time.Duration, iface ...string) (*ConnectionManager, error) { @@ -30,9 +34,10 @@ func NewConnectionManager(socks5 string, timeout time.Duration, iface ...string) } cm := &ConnectionManager{ - Socks5: socks5, - Timeout: timeout, - Iface: ifaceName, + Socks5: socks5, + Timeout: timeout, + Iface: ifaceName, + ConnPool: make(map[string]chan net.Conn), } ipAddr, err := GetIPv4Address(ifaceName) @@ -66,7 +71,45 @@ func NewConnectionManager(socks5 string, timeout time.Duration, iface ...string) } func (cm *ConnectionManager) Dial(network, address string) (net.Conn, error) { - return cm.DialFunc(network, address) + key := fmt.Sprintf("%s:%s", network, address) + + cm.PoolMutex.Lock() + if _, ok := cm.ConnPool[key]; !ok { + cm.ConnPool[key] = make(chan net.Conn, 10) + } + cm.PoolMutex.Unlock() + + select { + case conn := <-cm.ConnPool[key]: + return conn, nil + default: + conn, err := cm.DialFunc(network, address) + if err != nil { + return nil, err + } + return conn, nil + } +} + +func (cm *ConnectionManager) Release(conn net.Conn) { + if conn == nil { + return + } + + cm.PoolMutex.Lock() + defer cm.PoolMutex.Unlock() + + key := fmt.Sprintf("%s:%s", conn.RemoteAddr().Network(), conn.RemoteAddr().String()) + + if _, ok := cm.ConnPool[key]; !ok { + cm.ConnPool[key] = make(chan net.Conn, 10) + } + + select { + case cm.ConnPool[key] <- conn: + default: + conn.Close() + } } func (cm *ConnectionManager) DialUDP(network, address string) (*net.UDPConn, error) {