diff --git a/transport/socks5/packet_listener.go b/transport/socks5/packet_listener.go new file mode 100644 index 00000000..fbbbf4ed --- /dev/null +++ b/transport/socks5/packet_listener.go @@ -0,0 +1,183 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package socks5 + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "net/netip" + "time" + + "github.com/Jigsaw-Code/outline-sdk/internal/slicepool" + "github.com/Jigsaw-Code/outline-sdk/transport" +) + +// clientUDPBufferSize is the maximum supported UDP packet size in bytes. +const clientUDPBufferSize = 16 * 1024 + +// udpPool stores the byte slices used for storing packets. +var udpPool = slicepool.MakePool(clientUDPBufferSize) + +type packetConn struct { + pc net.Conn + sc io.Closer +} + +var _ net.PacketConn = (*packetConn)(nil) + +func (p *packetConn) LocalAddr() net.Addr { + return p.pc.LocalAddr() +} + +func (p *packetConn) SetDeadline(t time.Time) error { + return p.pc.SetDeadline(t) +} + +func (p *packetConn) SetReadDeadline(t time.Time) error { + return p.pc.SetReadDeadline(t) +} + +func (c *packetConn) SetWriteDeadline(t time.Time) error { + return c.pc.SetWriteDeadline(t) +} + +// ReadFrom reads the packet from the SOCKS5 server and extract the payload +// The packet format is specified in https://datatracker.ietf.org/doc/html/rfc1928#section-7 +func (p *packetConn) ReadFrom(b []byte) (int, net.Addr, error) { + lazySlice := udpPool.LazySlice() + buffer := lazySlice.Acquire() + defer lazySlice.Release() + + n, err := p.pc.Read(buffer) + if err != nil { + return 0, nil, err + } + // Minimum packet size + if n < 10 { + return 0, nil, errors.New("invalid SOCKS5 UDP packet: too short") + } + + // Using bytes.Buffer to handle data + buf := bytes.NewBuffer(buffer[:n]) + + // Read and check reserved bytes + rsv := make([]byte, 2) + if _, err := buf.Read(rsv); err != nil { + return 0, nil, err + } + if rsv[0] != 0x00 || rsv[1] != 0x00 { + return 0, nil, fmt.Errorf("invalid reserved bytes: expected 0x0000, got %#x%#x", rsv[0], rsv[1]) + } + + // Read fragment byte + frag, err := buf.ReadByte() + if err != nil { + return 0, nil, err + } + if frag != 0 { + return 0, nil, errors.New("fragmentation is not supported") + } + + // Read address using socks.ReadAddr which must now accept a bytes.Buffer directly + address, err := readAddr(buf) + if err != nil { + return 0, nil, fmt.Errorf("failed to read address: %w", err) + } + + // Convert the address to a net.Addr + addr, err := transport.MakeNetAddr("udp", addrToString(address)) + if err != nil { + return 0, nil, fmt.Errorf("failed to convert address: %w", err) + } + + // Payload handling: remaining bytes in the buffer are the payload + payload := buf.Bytes() + payloadLength := len(payload) + if payloadLength > len(b) { + return 0, nil, io.ErrShortBuffer + } + copy(b, payload) + + return payloadLength, addr, nil +} + +// WriteTo encapsulates the payload in a SOCKS5 UDP packet as specified in +// https://datatracker.ietf.org/doc/html/rfc1928#section-7 +// and write it to the SOCKS5 server via the underlying connection. +func (p *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) { + + // The minimum preallocated header size (10 bytes) + lazySlice := udpPool.LazySlice() + buffer := lazySlice.Acquire() + defer lazySlice.Release() + buffer = append(buffer[:0], + 0x00, 0x00, // Reserved + 0x00, // Fragment number + // To be appended below: + // ATYP, IPv4, IPv6, Domain Name, Port + ) + buffer, err := appendSOCKS5Address(buffer, addr.String()) + if err != nil { + return 0, fmt.Errorf("failed to append SOCKS5 address: %w", err) + } + // Combine the header and the payload + return p.pc.Write(append(buffer, b...)) +} + +// Close closes both the underlying stream and packet connections. +func (p *packetConn) Close() error { + return errors.Join(p.sc.Close(), p.pc.Close()) +} + +// ListenPacket creates a [net.PacketConn] for dialing to SOCKS5 server. +func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) { + // Connect to the SOCKS5 server and perform UDP association + // Since local address is not known in advance, we use unspecified address + // which means the server is going to accept incoming packets from any address + // on the bind port on the server. The bind address is determined and returned by + // the server. + // https://datatracker.ietf.org/doc/html/rfc1928#section-6 + // Whoile binding address to specific client address has its advantages, it also creates some + // challenges such as NAT traveral if client is behind NAT. + sc, bindAddr, err := c.connectAndRequest(ctx, CmdUDPAssociate, "0.0.0.0:0") + if err != nil { + return nil, err + } + + // If the returned bind IP address is unspecified (i.e. "0.0.0.0" or "::"), + // then use the IP address of the SOCKS5 server + if ipAddr := bindAddr.IP; ipAddr.IsValid() && ipAddr.IsUnspecified() { + schost, _, err := net.SplitHostPort(sc.RemoteAddr().String()) + if err != nil { + return nil, fmt.Errorf("failed to parse tcp address: %w", err) + } + + bindAddr.IP, err = netip.ParseAddr(schost) + if err != nil { + return nil, fmt.Errorf("failed to parse bind address: %w", err) + } + } + + proxyConn, err := c.pd.DialPacket(ctx, addrToString(bindAddr)) + if err != nil { + sc.Close() + return nil, fmt.Errorf("could not connect to packet endpoint: %w", err) + } + return &packetConn{pc: proxyConn, sc: sc}, nil +} diff --git a/transport/socks5/packet_listener_test.go b/transport/socks5/packet_listener_test.go new file mode 100644 index 00000000..0d2ac053 --- /dev/null +++ b/transport/socks5/packet_listener_test.go @@ -0,0 +1,114 @@ +package socks5 + +import ( + "bytes" + "context" + "errors" + "net" + "testing" + "time" + + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/things-go/go-socks5" +) + +func TestSOCKS5Associate(t *testing.T) { + // Create a local listener. + // This creates a UDP server that responded to "ping" + // message with "pong" response. + locIP := net.ParseIP("127.0.0.1") + // Create a local listener + echoServerAddr := &net.UDPAddr{IP: locIP, Port: 0} + echoServer := setupUDPEchoServer(t, echoServerAddr) + defer echoServer.Close() + + // Create a socks server to proxy "ping" message. + cator := socks5.UserPassAuthenticator{Credentials: socks5.StaticCredentials{ + "testusername": "testpassword", + }} + proxySrv := socks5.NewServer( + socks5.WithAuthMethods([]socks5.Authenticator{cator}), + ) + + // Create SOCKS5 proxy on localhost with a random port. + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + proxyServerAddress := listener.Addr().String() + + go func() { + err := proxySrv.Serve(listener) + if !errors.Is(err, net.ErrClosed) && err != nil { + require.NoError(t, err) // Assert no error if it's not the expected close error + } + }() + + // Connect to local proxy, auth and start the PacketConn. + client, err := NewClient(&transport.TCPEndpoint{Address: proxyServerAddress}) + require.NotNil(t, client) + require.NoError(t, err) + err = client.SetCredentials([]byte("testusername"), []byte("testpassword")) + require.NoError(t, err) + client.EnablePacket(&transport.UDPDialer{}) + conn, err := client.ListenPacket(context.Background()) + require.NoError(t, err) + defer conn.Close() + + // Send "ping" message. + _, err = conn.WriteTo([]byte("ping"), echoServer.LocalAddr()) + require.NoError(t, err) + // Max wait time for response. + err = conn.SetDeadline(time.Now().Add(time.Second)) + require.NoError(t, err) + response := make([]byte, 1024) + n, addr, err := conn.ReadFrom(response) + require.Equal(t, echoServer.LocalAddr().String(), addr.String()) + require.NoError(t, err) + require.Equal(t, []byte("pong"), response[:n]) +} + +func TestUDPLoopBack(t *testing.T) { + // Create a local listener. + locIP := net.ParseIP("127.0.0.1") + echoServerAddr := &net.UDPAddr{IP: locIP, Port: 0} + echoServer := setupUDPEchoServer(t, echoServerAddr) + defer echoServer.Close() + + packDialer := transport.UDPDialer{} + conn, err := packDialer.DialPacket(context.Background(), echoServer.LocalAddr().String()) + require.NoError(t, err) + _, err = conn.Write([]byte("ping")) + require.NoError(t, err) + response := make([]byte, 1024) + n, err := conn.Read(response) + require.NoError(t, err) + assert.Equal(t, []byte("pong"), response[:n]) +} + +func setupUDPEchoServer(t *testing.T, serverAddr *net.UDPAddr) *net.UDPConn { + server, err := net.ListenUDP("udp", serverAddr) + require.NoError(t, err) + go func() { + buf := make([]byte, 2048) + for { + n, remote, err := server.ReadFrom(buf) + if err != nil { + return + } + if bytes.Equal(buf[:n], []byte("ping")) { + _, err := server.WriteTo([]byte("pong"), remote) + if err != nil { + return + } + } + } + }() + + t.Cleanup(func() { + server.Close() + }) + + return server +} diff --git a/transport/socks5/socks5.go b/transport/socks5/socks5.go index 59e2bf06..ebebcf71 100644 --- a/transport/socks5/socks5.go +++ b/transport/socks5/socks5.go @@ -18,7 +18,9 @@ import ( "encoding/binary" "errors" "fmt" + "io" "net" + "net/netip" "strconv" ) @@ -37,6 +39,13 @@ const ( ErrAddressTypeNotSupported = ReplyCode(0x08) ) +// SOCKS5 commands, from https://datatracker.ietf.org/doc/html/rfc1928#section-4. +const ( + CmdConnect = byte(1) + CmdBind = byte(2) + CmdUDPAssociate = byte(3) +) + // SOCKS5 authentication methods, as specified in https://datatracker.ietf.org/doc/html/rfc1928#section-3 const ( authMethodNoAuth = 0x00 @@ -79,6 +88,27 @@ const ( addrTypeIPv6 = 0x04 ) +// address is a SOCKS-specific address. +// Either Name or IP is used exclusively. +type address struct { + Name string // fully-qualified domain name + IP netip.Addr + Port uint16 +} + +// Address returns a string suitable to dial; prefer returning IP-based +// address, fallback to Name +func addrToString(a *address) string { + if a == nil { + return "" + } + port := strconv.Itoa(int(a.Port)) + if a.IP.IsValid() { + return net.JoinHostPort(a.IP.String(), port) + } + return net.JoinHostPort(a.Name, port) +} + // appendSOCKS5Address adds the address to buffer b in SOCKS5 format, // as specified in https://datatracker.ietf.org/doc/html/rfc1928#section-4 func appendSOCKS5Address(b []byte, address string) ([]byte, error) { @@ -119,3 +149,47 @@ func appendSOCKS5Address(b []byte, address string) ([]byte, error) { b = binary.BigEndian.AppendUint16(b, uint16(portNum)) return b, nil } + +func readAddr(r io.Reader) (*address, error) { + address := &address{} + + var addrType [1]byte + if _, err := r.Read(addrType[:]); err != nil { + return nil, err + } + + switch addrType[0] { + case addrTypeIPv4: + var addr [4]byte + if _, err := io.ReadFull(r, addr[:]); err != nil { + return nil, err + } + address.IP = netip.AddrFrom4(addr) + case addrTypeIPv6: + var addr [16]byte + if _, err := io.ReadFull(r, addr[:]); err != nil { + return nil, err + } + address.IP = netip.AddrFrom16(addr) + case addrTypeDomainName: + if _, err := r.Read(addrType[:]); err != nil { + return nil, err + } + addrLen := addrType[0] + // addrLen btye type maximum value is 255 which + // prevents passing larger then 255 values for domain names. + fqdn := make([]byte, addrLen) + if _, err := io.ReadFull(r, fqdn); err != nil { + return nil, err + } + address.Name = string(fqdn) + default: + return nil, errors.New("unrecognized address type") + } + var port [2]byte + if _, err := io.ReadFull(r, port[:]); err != nil { + return nil, err + } + address.Port = binary.BigEndian.Uint16(port[:]) + return address, nil +} diff --git a/transport/socks5/socks5_test.go b/transport/socks5/socks5_test.go index 8dbfd26f..f5696251 100644 --- a/transport/socks5/socks5_test.go +++ b/transport/socks5/socks5_test.go @@ -15,12 +15,161 @@ package socks5 import ( + "bytes" + "io" + "net/netip" "strings" "testing" "github.com/stretchr/testify/require" ) +func TestReadAddr(t *testing.T) { + tests := []struct { + name string + input []byte + want *address + wantErr bool + }{ + + { + name: "IPv4 Example", + input: []byte{addrTypeIPv4, 192, 168, 1, 1, 0x01, 0xF4}, + want: &address{IP: netip.MustParseAddr("192.168.1.1"), Port: 500}, + wantErr: false, + }, + { + name: "IPv6 Full", + input: []byte{ + addrTypeIPv6, + 0x20, 0x01, 0x0d, 0xb8, // first 4 bytes of the IPv6 address + 0x00, 0x00, 0x00, 0x00, // middle zeroes are often omitted in shorthand notation + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, // last segment with the "1" + 0x04, 0xD2, // port number 1234 + }, + want: &address{IP: netip.MustParseAddr("2001:db8::1"), Port: 1234}, + wantErr: false, + }, + { + name: "IPv6 Compressed", + input: []byte{ + addrTypeIPv6, + 0xfe, 0x80, 0x00, 0x00, // first 4 bytes with "fe80", and then three zeroed segments + 0x00, 0x00, 0x00, 0x00, + 0x02, 0x04, 0x61, 0xff, // "0204:61ff" + 0xfe, 0x9d, 0xf1, 0x56, // "fe9d:f156" + 0x00, 0x50, // port number 80 in hexadecimal + }, + want: &address{IP: netip.MustParseAddr("fe80::204:61ff:fe9d:f156"), Port: 80}, + wantErr: false, + }, + { + name: "IPv6 Loopback", + input: []byte{ + addrTypeIPv6, + 0x00, 0x00, 0x00, 0x00, // eight zeroed-out segments + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, // last segment is "0001" + 0x1F, 0x90, // port number 8080 in hexadecimal + }, + want: &address{IP: netip.IPv6Loopback(), Port: 8080}, + wantErr: false, + }, + { + name: "Domain Short", + input: []byte{ + addrTypeDomainName, // Address type for domain name + 0x0b, // Length of the domain name "example.com" which is 11 characters + 'e', 'x', 'a', 'm', 'p', 'l', 'e', '.', 'c', 'o', 'm', // The domain name "example.com" + 0x23, 0x28, // Port number 9000 in hexadecimal + }, + want: &address{Name: "example.com", Port: 9000}, + wantErr: false, + }, + { + name: "Domain Long", + input: append([]byte{addrTypeDomainName, 0x3B}, append([]byte("very-long-domain-name-used-for-testing-purposes.example.com"), 0x00, 0x50)...), + want: &address{Name: "very-long-domain-name-used-for-testing-purposes.example.com", Port: 80}, + wantErr: false, + }, + { + name: "Unrecognized Address Type", + input: []byte{0x00}, + want: nil, + wantErr: true, + }, + { + name: "Short Input", + input: []byte{addrTypeIPv4}, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := bytes.NewReader(tt.input) + got, err := readAddr(r) + if tt.wantErr { + require.Error(t, err, "Expected an error but got none") + } else { + require.NoError(t, err, "Did not expect an error but got one") + } + + if !tt.wantErr && !compareAddresses(got, tt.want) { + t.Errorf("readAddr() got = %v, want %v", got, tt.want) + } + }) + } +} + +func BenchmarkReadAddr(b *testing.B) { + tests := []struct { + name string + input []byte + }{ + { + name: "IPv4", + input: append([]byte{addrTypeIPv4}, append(netip.AddrFrom4([4]byte{192, 168, 1, 1}).AsSlice(), []byte{0x00, 0x50}...)...), + }, + { + name: "IPv6", + input: append([]byte{addrTypeIPv6}, append(netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}).AsSlice(), []byte{0x1F, 0x90}...)...), + }, + { + name: "Domain", + input: append([]byte{addrTypeDomainName, 0x0b}, append([]byte("example.com"), []byte{0x23, 0x28}...)...), + }, + } + + for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { + reader := bytes.NewReader(tt.input) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := reader.Seek(0, io.SeekStart); err != nil { + b.Error("Seek failed:", err) + } + if _, err := readAddr(reader); err != nil { + b.Error("readAddr failed:", err) + } + } + }) + } +} + +func compareAddresses(a1, a2 *address) bool { + if a1 == nil || a2 == nil { + return a1 == a2 + } + if (a1.IP != netip.Addr{}) && a1.IP != a2.IP || a1.Name != a2.Name || a1.Port != a2.Port { + return false + } + return true +} + func TestAppendSOCKS5Address_IPv4(t *testing.T) { b := []byte{} b, err := appendSOCKS5Address(b, "8.8.8.8:853") diff --git a/transport/socks5/stream_dialer.go b/transport/socks5/stream_dialer.go index b1839a7c..63fc68d3 100644 --- a/transport/socks5/stream_dialer.go +++ b/transport/socks5/stream_dialer.go @@ -30,23 +30,25 @@ type credentials struct { password []byte } -// NewStreamDialer creates a [transport.StreamDialer] that routes connections to a SOCKS5 +// NewClient creates a SOCKS5 client that routes connections to a SOCKS5 // proxy listening at the given [transport.StreamEndpoint]. -func NewStreamDialer(endpoint transport.StreamEndpoint) (*StreamDialer, error) { - if endpoint == nil { +func NewClient(streamEndpoint transport.StreamEndpoint) (*Client, error) { + if streamEndpoint == nil { return nil, errors.New("argument endpoint must not be nil") } - return &StreamDialer{proxyEndpoint: endpoint, cred: nil}, nil + return &Client{se: streamEndpoint, cred: nil}, nil } -type StreamDialer struct { - proxyEndpoint transport.StreamEndpoint - cred *credentials +type Client struct { + se transport.StreamEndpoint + pd transport.PacketDialer + cred *credentials } -var _ transport.StreamDialer = (*StreamDialer)(nil) +var _ transport.StreamDialer = (*Client)(nil) +var _ transport.PacketListener = (*Client)(nil) -func (c *StreamDialer) SetCredentials(username, password []byte) error { +func (c *Client) SetCredentials(username, password []byte) error { if len(username) > 255 { return errors.New("username exceeds 255 bytes") } @@ -65,23 +67,14 @@ func (c *StreamDialer) SetCredentials(username, password []byte) error { return nil } -// DialStream implements [transport.StreamDialer].DialStream using SOCKS5. -// It will send the auth method, auth credentials (if auth is chosen), and -// the connect requests in one packet, to avoid an additional roundtrip. -// The returned [error] will be of type [ReplyCode] if the server sends a SOCKS error reply code, which -// you can check against the error constants in this package using [errors.Is]. -func (c *StreamDialer) DialStream(ctx context.Context, remoteAddr string) (transport.StreamConn, error) { - proxyConn, err := c.proxyEndpoint.ConnectStream(ctx) - if err != nil { - return nil, fmt.Errorf("could not connect to SOCKS5 proxy: %w", err) - } - dialSuccess := false - defer func() { - if !dialSuccess { - proxyConn.Close() - } - }() +// EnablePacket enables the use of the [Client] as a [transport.PacketListener]. It takes the [transport.PacketDialer] used to connect to the SOCKS5 packet endpoint. +func (c *Client) EnablePacket(packetDialer transport.PacketDialer) { + c.pd = packetDialer +} +// request sends a SOCKS5 request to the server to perform a command (e.g., connect, udp associate), +// performs authentication (if provided), returns the bound address. +func (c *Client) request(conn io.ReadWriter, cmd byte, dstAddr string) (*address, error) { // For protocol details, see https://datatracker.ietf.org/doc/html/rfc1928#section-3 // Creating a single buffer for method selection, authentication, and connection request // Buffer large enough for method, auth, and connect requests with a domain name address. @@ -118,24 +111,24 @@ func (c *StreamDialer) DialStream(ctx context.Context, remoteAddr string) (trans b = append(b, c.cred.password...) } - // Connect request: - // VER = 5, CMD = 1 (connect), RSV = 0, DST.ADDR, DST.PORT + // CMD Request: + // VER = 5, CMD = cmd, RSV = 0, DST.ADDR, DST.PORT // +----+-----+-------+------+----------+----------+ // |VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT | // +----+-----+-------+------+----------+----------+ // | 1 | 1 | X'00' | 1 | Variable | 2 | // +----+-----+-------+------+----------+----------+ - b = append(b, 5, 1, 0) + b = append(b, 5, cmd, 0) // TODO: Probably more memory efficient if remoteAddr is added to the buffer directly. - b, err = appendSOCKS5Address(b, remoteAddr) + b, err := appendSOCKS5Address(b, dstAddr) if err != nil { return nil, fmt.Errorf("failed to create SOCKS5 address: %w", err) } - // We merge the method and connect requests and only perform one write + // We merge the method and CMD requests and only perform one write // because we send a single authentication method, so there's no point // in waiting for the response. This eliminates a roundtrip. - _, err = proxyConn.Write(b) + _, err = conn.Write(b) if err != nil { return nil, fmt.Errorf("failed to write combined SOCKS5 request: %w", err) } @@ -149,7 +142,7 @@ func (c *StreamDialer) DialStream(ctx context.Context, remoteAddr string) (trans // +----+--------+ // buffer[0]: VER, buffer[1]: METHOD // Reuse buffer for better performance. - if _, err = io.ReadFull(proxyConn, buffer[:2]); err != nil { + if _, err = io.ReadFull(conn, buffer[:2]); err != nil { return nil, fmt.Errorf("failed to read method server response: %w", err) } if buffer[0] != 5 { @@ -169,7 +162,7 @@ func (c *StreamDialer) DialStream(ctx context.Context, remoteAddr string) (trans // +----+--------+ // VER = 1 means the server should be expecting username/password authentication. // buffer[2]: VER, buffer[3]: STATUS - if _, err = io.ReadFull(proxyConn, buffer[2:4]); err != nil { + if _, err = io.ReadFull(conn, buffer[2:4]); err != nil { return nil, fmt.Errorf("failed to read authentication version and status: %w", err) } if buffer[2] != 1 { @@ -193,7 +186,7 @@ func (c *StreamDialer) DialStream(ctx context.Context, remoteAddr string) (trans // buffer[1]: REP // buffer[2]: RSV // buffer[3]: ATYP - if _, err = io.ReadFull(proxyConn, buffer[:4]); err != nil { + if _, err = io.ReadFull(conn, buffer[:3]); err != nil { return nil, fmt.Errorf("failed to read connect server response: %w", err) } @@ -206,32 +199,40 @@ func (c *StreamDialer) DialStream(ctx context.Context, remoteAddr string) (trans return nil, ReplyCode(buffer[1]) } - // 4. Read address and length - var bndAddrLen int - switch buffer[3] { - case addrTypeIPv4: - bndAddrLen = 4 - case addrTypeIPv6: - bndAddrLen = 16 - case addrTypeDomainName: - // buffer[8]: length of the domain name - _, err := io.ReadFull(proxyConn, buffer[:1]) - if err != nil { - return nil, fmt.Errorf("failed to read address length in connect response: %w", err) - } - bndAddrLen = int(buffer[0]) - default: - return nil, fmt.Errorf("invalid address type %v", buffer[3]) - } - // 5. Reads the bound address and port, but we currently ignore them. - // TODO(fortuna): Should we expose the remote bound address as the net.Conn.LocalAddr()? - if _, err := io.ReadFull(proxyConn, buffer[:bndAddrLen]); err != nil { + // 4. Read BND.ADDR. + bindAddr, err := readAddr(conn) + if err != nil { return nil, fmt.Errorf("failed to read bound address: %w", err) } - // We read but ignore the remote bound port number: BND.PORT - if _, err = io.ReadFull(proxyConn, buffer[:2]); err != nil { - return nil, fmt.Errorf("failed to read bound port: %w", err) + + return bindAddr, nil +} + +// connectAndRequest manages the connection lifecycle and delegates the SOCKS5 communication to the request function. +func (c *Client) connectAndRequest(ctx context.Context, cmd byte, dstAddr string) (transport.StreamConn, *address, error) { + proxyConn, err := c.se.ConnectStream(ctx) + if err != nil { + return nil, nil, fmt.Errorf("could not connect to SOCKS5 proxy: %w", err) + } + + bindAddr, err := c.request(proxyConn, cmd, dstAddr) + if err != nil { + proxyConn.Close() + return nil, nil, err + } + + return proxyConn, bindAddr, nil +} + +// DialStream implements [transport.StreamDialer].DialStream using SOCKS5. +// It will send the auth method, auth credentials (if auth is chosen), and +// the connect requests in one packet, to avoid an additional roundtrip. +// The returned [error] will be of type [ReplyCode] if the server sends a SOCKS error reply code, which +// you can check against the error constants in this package using [errors.Is]. +func (c *Client) DialStream(ctx context.Context, dstAddr string) (transport.StreamConn, error) { + proxyConn, _, err := c.connectAndRequest(ctx, CmdConnect, dstAddr) + if err != nil { + return nil, err } - dialSuccess = true return proxyConn, nil } diff --git a/transport/socks5/stream_dialer_test.go b/transport/socks5/stream_dialer_test.go index 6352bd3b..6ce0b597 100644 --- a/transport/socks5/stream_dialer_test.go +++ b/transport/socks5/stream_dialer_test.go @@ -23,7 +23,6 @@ import ( "sync" "testing" "testing/iotest" - "time" "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/stretchr/testify/assert" @@ -32,16 +31,16 @@ import ( ) func TestSOCKS5Dialer_NewStreamDialerNil(t *testing.T) { - dialer, err := NewStreamDialer(nil) - require.Nil(t, dialer) + client, err := NewClient(nil) + require.Nil(t, client) require.Error(t, err) } func TestSOCKS5Dialer_BadConnection(t *testing.T) { - dialer, err := NewStreamDialer(&transport.TCPEndpoint{Address: "127.0.0.0:0"}) - require.NotNil(t, dialer) + client, err := NewClient(&transport.TCPEndpoint{Address: "127.0.0.0:0"}) + require.NotNil(t, client) require.NoError(t, err) - _, err = dialer.DialStream(context.Background(), "example.com:443") + _, err = client.DialStream(context.Background(), "example.com:443") require.Error(t, err) } @@ -50,7 +49,7 @@ func TestSOCKS5Dialer_BadAddress(t *testing.T) { require.NoError(t, err, "Failed to create TCP listener: %v", err) defer listener.Close() - dialer, err := NewStreamDialer(&transport.TCPEndpoint{Address: listener.Addr().String()}) + dialer, err := NewClient(&transport.TCPEndpoint{Address: listener.Addr().String()}) require.NotNil(t, dialer) require.NoError(t, err) @@ -97,9 +96,9 @@ func testExchange(tb testing.TB, listener *net.TCPListener, destAddr string, req // Client go func() { defer running.Done() - dialer, err := NewStreamDialer(&transport.TCPEndpoint{Address: listener.Addr().String()}) + client, err := NewClient(&transport.TCPEndpoint{Address: listener.Addr().String()}) require.NoError(tb, err) - serverConn, err := dialer.DialStream(context.Background(), destAddr) + serverConn, err := client.DialStream(context.Background(), destAddr) if replyCode != 0 { require.ErrorIs(tb, err, replyCode) var extractedReplyCode ReplyCode @@ -168,31 +167,30 @@ func testExchange(tb testing.TB, listener *net.TCPListener, destAddr string, req } func TestConnectWithoutAuth(t *testing.T) { - // Create a SOCKS5 server + // Create a SOCKS5 server. server := socks5.NewServer() - // Create SOCKS5 proxy on localhost with a random port + // Create SOCKS5 proxy on localhost with a random port. listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) + defer listener.Close() go func() { err := server.Serve(listener) - defer listener.Close() t.Log("server is listening...") - require.NoError(t, err) + if !errors.Is(err, net.ErrClosed) && err != nil { + require.NoError(t, err) // Assert no error if it's not the expected close error + } }() - // wait for server to start - time.Sleep(10 * time.Millisecond) - address := listener.Addr().String() - // Create a SOCKS5 client - dialer, err := NewStreamDialer(&transport.TCPEndpoint{Address: address}) - require.NotNil(t, dialer) + // Create a SOCKS5 client. + client, err := NewClient(&transport.TCPEndpoint{Address: address}) + require.NotNil(t, client) require.NoError(t, err) - _, err = dialer.DialStream(context.Background(), address) + _, err = client.DialStream(context.Background(), address) require.NoError(t, err) } @@ -207,21 +205,21 @@ func TestConnectWithAuth(t *testing.T) { socks5.WithAuthMethods([]socks5.Authenticator{cator}), ) - // Create SOCKS5 proxy on localhost with a random port + // Create SOCKS5 proxy on localhost with a random port. listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) + defer listener.Close() address := listener.Addr().String() // Create SOCKS5 proxy on localhost port 8001 go func() { err := server.Serve(listener) - defer listener.Close() - require.NoError(t, err) + if !errors.Is(err, net.ErrClosed) && err != nil { + require.NoError(t, err) // Assert no error if it's not the expected close error + } }() - // wait for server to start - time.Sleep(10 * time.Millisecond) - dialer, err := NewStreamDialer(&transport.TCPEndpoint{Address: address}) + dialer, err := NewClient(&transport.TCPEndpoint{Address: address}) require.NotNil(t, dialer) require.NoError(t, err) err = dialer.SetCredentials([]byte("testusername"), []byte("testpassword")) @@ -229,7 +227,7 @@ func TestConnectWithAuth(t *testing.T) { _, err = dialer.DialStream(context.Background(), address) require.NoError(t, err) - // Try to connect with incorrect credentials + // Try to connect with incorrect credentials. err = dialer.SetCredentials([]byte("testusername"), []byte("wrongpassword")) require.NoError(t, err) _, err = dialer.DialStream(context.Background(), address) diff --git a/x/config/config.go b/x/config/config.go index aa6d7888..a5a73e24 100644 --- a/x/config/config.go +++ b/x/config/config.go @@ -57,6 +57,7 @@ func NewDefaultConfigToDialer() *ConfigToDialer { p.RegisterPacketDialerType("override", wrapPacketDialerWithOverride) p.RegisterStreamDialerType("socks5", wrapStreamDialerWithSOCKS5) + p.RegisterPacketDialerType("socks5", wrapPacketDialerWithSOCKS5) p.RegisterStreamDialerType("split", wrapStreamDialerWithSplit) diff --git a/x/config/config_test.go b/x/config/config_test.go index 21725596..a849ac34 100644 --- a/x/config/config_test.go +++ b/x/config/config_test.go @@ -26,11 +26,11 @@ func TestSanitizeConfig(t *testing.T) { require.NoError(t, err) // Test that a invalid cypher is rejected. - sanitizedConfig, err := SanitizeConfig("split:5|ss://jhvdsjkfhvkhsadvf@example.com:1234?prefix=HTTP%2F1.1%20") + _, err = SanitizeConfig("split:5|ss://jhvdsjkfhvkhsadvf@example.com:1234?prefix=HTTP%2F1.1%20") require.Error(t, err) // Test that a valid config is accepted and user info is redacted. - sanitizedConfig, err = SanitizeConfig("split:5|ss://Y2hhY2hhMjAtaWV0Zi1wb2x5MTMwNTpLeTUyN2duU3FEVFB3R0JpQ1RxUnlT@example.com:1234?prefix=HTTP%2F1.1%20") + sanitizedConfig, err := SanitizeConfig("split:5|ss://Y2hhY2hhMjAtaWV0Zi1wb2x5MTMwNTpLeTUyN2duU3FEVFB3R0JpQ1RxUnlT@example.com:1234?prefix=HTTP%2F1.1%20") require.NoError(t, err) require.Equal(t, "split:5|ss://REDACTED@example.com:1234?prefix=HTTP%2F1.1+", sanitizedConfig) diff --git a/x/config/doc.go b/x/config/doc.go index 539fcf29..a34e3fa6 100644 --- a/x/config/doc.go +++ b/x/config/doc.go @@ -40,7 +40,7 @@ Shadowsocks proxy (compatible with Outline's access keys, package [github.com/Ji ss://[USERINFO]@[HOST]:[PORT]?prefix=[PREFIX] -SOCKS5 proxy (currently streams only, package [github.com/Jigsaw-Code/outline-sdk/transport/socks5]) +SOCKS5 proxy (works with both stream and packet dialers, package [github.com/Jigsaw-Code/outline-sdk/transport/socks5]) socks5://[USERINFO]@[HOST]:[PORT] diff --git a/x/config/socks5.go b/x/config/socks5.go index 9820a148..105f6817 100644 --- a/x/config/socks5.go +++ b/x/config/socks5.go @@ -27,7 +27,7 @@ func wrapStreamDialerWithSOCKS5(innerSD func() (transport.StreamDialer, error), return nil, err } endpoint := transport.StreamDialerEndpoint{Dialer: sd, Address: configURL.Host} - dialer, err := socks5.NewStreamDialer(&endpoint) + client, err := socks5.NewClient(&endpoint) if err != nil { return nil, err } @@ -35,10 +35,40 @@ func wrapStreamDialerWithSOCKS5(innerSD func() (transport.StreamDialer, error), if userInfo != nil { username := userInfo.Username() password, _ := userInfo.Password() - err := dialer.SetCredentials([]byte(username), []byte(password)) + err := client.SetCredentials([]byte(username), []byte(password)) if err != nil { return nil, err } } - return dialer, nil + + return client, nil +} + +func wrapPacketDialerWithSOCKS5(innerSD func() (transport.StreamDialer, error), innerPD func() (transport.PacketDialer, error), configURL *url.URL) (transport.PacketDialer, error) { + sd, err := innerSD() + if err != nil { + return nil, err + } + streamEndpoint := transport.StreamDialerEndpoint{Dialer: sd, Address: configURL.Host} + client, err := socks5.NewClient(&streamEndpoint) + if err != nil { + return nil, err + } + userInfo := configURL.User + if userInfo != nil { + username := userInfo.Username() + password, _ := userInfo.Password() + err := client.SetCredentials([]byte(username), []byte(password)) + if err != nil { + return nil, err + } + } + + pd, err := innerPD() + if err != nil { + return nil, err + } + client.EnablePacket(pd) + packetDialer := transport.PacketListenerDialer{Listener: client} + return packetDialer, nil } diff --git a/x/go.mod b/x/go.mod index e9c14043..0b322d15 100644 --- a/x/go.mod +++ b/x/go.mod @@ -3,7 +3,7 @@ module github.com/Jigsaw-Code/outline-sdk/x go 1.21 require ( - github.com/Jigsaw-Code/outline-sdk v0.0.16 + github.com/Jigsaw-Code/outline-sdk v0.0.17-0.20240726212635-470a9290ec57 // Use github.com/Psiphon-Labs/psiphon-tunnel-core@staging-client as per // https://github.com/Psiphon-Labs/psiphon-tunnel-core/?tab=readme-ov-file#using-psiphon-with-go-modules github.com/Psiphon-Labs/psiphon-tunnel-core v1.0.11-0.20240619172145-03cade11f647 diff --git a/x/go.sum b/x/go.sum index a8dcfab0..7c6f5f80 100644 --- a/x/go.sum +++ b/x/go.sum @@ -6,8 +6,8 @@ github.com/AndreasBriese/bbloom v0.0.0-20170702084017-28f7e881ca57 h1:CVuXDbdzPW github.com/AndreasBriese/bbloom v0.0.0-20170702084017-28f7e881ca57/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8= github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= -github.com/Jigsaw-Code/outline-sdk v0.0.16 h1:WbHmv80FKDIpzEmR3GehTbq5CibYTLvcxIIpMMILiEs= -github.com/Jigsaw-Code/outline-sdk v0.0.16/go.mod h1:e1oQZbSdLJBBuHgfeQsgEkvkuyIePPwstUeZRGq0KO8= +github.com/Jigsaw-Code/outline-sdk v0.0.17-0.20240726212635-470a9290ec57 h1:XNSV0dGW48J8DmdmCnk/txGHf9glAPqa6Xme/rFWn7c= +github.com/Jigsaw-Code/outline-sdk v0.0.17-0.20240726212635-470a9290ec57/go.mod h1:e1oQZbSdLJBBuHgfeQsgEkvkuyIePPwstUeZRGq0KO8= github.com/Psiphon-Inc/rotate-safe-writer v0.0.0-20210303140923-464a7a37606e h1:NPfqIbzmijrl0VclX2t8eO5EPBhqe47LLGKpRrcVjXk= github.com/Psiphon-Inc/rotate-safe-writer v0.0.0-20210303140923-464a7a37606e/go.mod h1:ZdY5pBfat/WVzw3eXbIf7N1nZN0XD5H5+X8ZMDWbCs4= github.com/Psiphon-Labs/bolt v0.0.0-20200624191537-23cedaef7ad7 h1:Hx/NCZTnvoKZuIBwSmxE58KKoNLXIGG6hBJYN7pj9Ag=