Skip to content

Commit

Permalink
Fix socks5 packet conn
Browse files Browse the repository at this point in the history
  • Loading branch information
dyhkwong authored and nekohasekai committed May 17, 2024
1 parent 8fb1634 commit dd0be0d
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 40 deletions.
8 changes: 8 additions & 0 deletions common/bufio/nat.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ func (c *unidirectionalNATPacketConn) UpdateDestination(destinationAddress netip
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
}

func (c *unidirectionalNATPacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr()
}

func (c *unidirectionalNATPacketConn) Upstream() any {
return c.NetPacketConn
}
Expand Down Expand Up @@ -136,6 +140,10 @@ func (c *bidirectionalNATPacketConn) Upstream() any {
return c.NetPacketConn
}

func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr()
}

func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr {
destination.Port = 0
return destination
Expand Down
3 changes: 1 addition & 2 deletions protocol/socks/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"os"
"strings"

"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
Expand Down Expand Up @@ -148,7 +147,7 @@ func (c *Client) DialContext(ctx context.Context, network string, address M.Sock
tcpConn.Close()
return nil, err
}
return NewAssociatePacketConn(bufio.NewUnbindPacketConn(udpConn), address, tcpConn), nil
return NewAssociatePacketConn(udpConn, address, tcpConn), nil
}
return nil, os.ErrInvalid
}
Expand Down
67 changes: 29 additions & 38 deletions protocol/socks/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,54 +21,41 @@ import (
var ErrInvalidPacket = E.New("socks5: invalid packet")

type AssociatePacketConn struct {
N.NetPacketConn
N.AbstractConn
conn N.ExtendedConn
remoteAddr M.Socksaddr
underlying net.Conn
}

func NewAssociatePacketConn(conn net.PacketConn, remoteAddr M.Socksaddr, underlying net.Conn) *AssociatePacketConn {
func NewAssociatePacketConn(conn net.Conn, remoteAddr M.Socksaddr, underlying net.Conn) *AssociatePacketConn {
return &AssociatePacketConn{
NetPacketConn: bufio.NewPacketConn(conn),
remoteAddr: remoteAddr,
underlying: underlying,
AbstractConn: conn,
conn: bufio.NewExtendedConn(conn),
remoteAddr: remoteAddr,
underlying: underlying,
}
}

// Deprecated: NewAssociatePacketConn(bufio.NewUnbindPacketConn(conn), remoteAddr, underlying) instead.
func NewAssociateConn(conn net.Conn, remoteAddr M.Socksaddr, underlying net.Conn) *AssociatePacketConn {
return &AssociatePacketConn{
NetPacketConn: bufio.NewUnbindPacketConn(conn),
remoteAddr: remoteAddr,
underlying: underlying,
}
}

func (c *AssociatePacketConn) RemoteAddr() net.Addr {
return c.remoteAddr.UDPAddr()
}

//warn:unsafe
func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.NetPacketConn.ReadFrom(p)
n, err = c.conn.Read(p)
if err != nil {
return
}
if n < 3 {
return 0, nil, ErrInvalidPacket
}
c.remoteAddr = M.SocksaddrFromNet(addr)
reader := bytes.NewReader(p[3:n])
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
if err != nil {
return
}
c.remoteAddr = destination
addr = destination.UDPAddr()
index := 3 + int(reader.Size()) - reader.Len()
n = copy(p, p[index:n])
return
}

//warn:unsafe
func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
destination := M.SocksaddrFromNet(addr)
buffer := buf.NewSize(3 + M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
Expand All @@ -82,32 +69,23 @@ func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error
if err != nil {
return
}
return bufio.WritePacketBuffer(c.NetPacketConn, buffer, c.remoteAddr)
}

func (c *AssociatePacketConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}

func (c *AssociatePacketConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.remoteAddr)
return c.conn.Write(buffer.Bytes())
}

func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
destination, err = c.NetPacketConn.ReadPacket(buffer)
err = c.conn.ReadBuffer(buffer)
if err != nil {
return M.Socksaddr{}, err
return
}
if buffer.Len() < 3 {
return M.Socksaddr{}, ErrInvalidPacket
}
c.remoteAddr = destination
buffer.Advance(3)
destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer)
if err != nil {
return
}
c.remoteAddr = destination
return destination.Unwrap(), nil
}

Expand All @@ -118,11 +96,24 @@ func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Sock
if err != nil {
return err
}
return common.Error(bufio.WritePacketBuffer(c.NetPacketConn, buffer, c.remoteAddr))
return c.conn.WriteBuffer(buffer)
}

func (c *AssociatePacketConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}

func (c *AssociatePacketConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.remoteAddr)
}

func (c *AssociatePacketConn) RemoteAddr() net.Addr {
return c.remoteAddr.UDPAddr()
}

func (c *AssociatePacketConn) Upstream() any {
return c.NetPacketConn
return c.conn
}

func (c *AssociatePacketConn) FrontHeadroom() int {
Expand All @@ -131,7 +122,7 @@ func (c *AssociatePacketConn) FrontHeadroom() int {

func (c *AssociatePacketConn) Close() error {
return common.Close(
c.NetPacketConn,
c.conn,
c.underlying,
)
}

0 comments on commit dd0be0d

Please sign in to comment.