diff --git a/dns/resolver.go b/dns/resolver.go index 1c23629a..92b83a97 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -263,7 +263,7 @@ func ensurePort(address string, defaultPort string) string { func NewUDPResolver(pd transport.PacketDialer, resolverAddr string) Resolver { resolverAddr = ensurePort(resolverAddr, "53") return FuncResolver(func(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) { - conn, err := pd.Dial(ctx, resolverAddr) + conn, err := pd.DialPacket(ctx, resolverAddr) if err != nil { return nil, &nestedError{ErrDial, err} } @@ -301,7 +301,7 @@ func NewTCPResolver(sd transport.StreamDialer, resolverAddr string) Resolver { resolverAddr = ensurePort(resolverAddr, "53") return &streamResolver{ NewConn: func(ctx context.Context) (transport.StreamConn, error) { - return sd.Dial(ctx, resolverAddr) + return sd.DialStream(ctx, resolverAddr) }, } } @@ -315,7 +315,7 @@ func NewTLSResolver(sd transport.StreamDialer, resolverAddr string, resolverName resolverAddr = ensurePort(resolverAddr, "853") return &streamResolver{ NewConn: func(ctx context.Context) (transport.StreamConn, error) { - baseConn, err := sd.Dial(ctx, resolverAddr) + baseConn, err := sd.DialStream(ctx, resolverAddr) if err != nil { return nil, err } @@ -336,7 +336,7 @@ func NewHTTPSResolver(sd transport.StreamDialer, resolverAddr string, url string // TODO: Support UDP for QUIC. return nil, fmt.Errorf("protocol not supported: %v", network) } - conn, err := sd.Dial(ctx, resolverAddr) + conn, err := sd.DialStream(ctx, resolverAddr) if err != nil { return nil, &nestedError{ErrDial, err} } diff --git a/dns/resolver_net_test.go b/dns/resolver_net_test.go index 5ef23ed2..bc452aef 100644 --- a/dns/resolver_net_test.go +++ b/dns/resolver_net_test.go @@ -37,7 +37,7 @@ func newTestContext(t *testing.T) context.Context { func TestNewUDPResolver(t *testing.T) { ctx := newTestContext(t) - resolver := NewUDPResolver(&transport.UDPPacketDialer{}, "8.8.8.8") + resolver := NewUDPResolver(&transport.UDPDialer{}, "8.8.8.8") q, err := NewQuestion("getoutline.org.", dnsmessage.TypeAAAA) require.NoError(t, err) resp, err := resolver.Query(ctx, *q) @@ -47,7 +47,7 @@ func TestNewUDPResolver(t *testing.T) { func TestNewTCPResolver(t *testing.T) { ctx := newTestContext(t) - resolver := NewTCPResolver(&transport.TCPStreamDialer{}, "8.8.8.8") + resolver := NewTCPResolver(&transport.TCPDialer{}, "8.8.8.8") q, err := NewQuestion("getoutline.org.", dnsmessage.TypeAAAA) require.NoError(t, err) resp, err := resolver.Query(ctx, *q) @@ -57,7 +57,7 @@ func TestNewTCPResolver(t *testing.T) { func TestNewTLSResolver(t *testing.T) { ctx := newTestContext(t) - resolver := NewTLSResolver(&transport.TCPStreamDialer{}, "8.8.8.8", "8.8.8.8") + resolver := NewTLSResolver(&transport.TCPDialer{}, "8.8.8.8", "8.8.8.8") q, err := NewQuestion("getoutline.org.", dnsmessage.TypeAAAA) require.NoError(t, err) resp, err := resolver.Query(ctx, *q) @@ -67,7 +67,7 @@ func TestNewTLSResolver(t *testing.T) { func TestNewHTTPSResolver(t *testing.T) { ctx := newTestContext(t) - resolver := NewHTTPSResolver(&transport.TCPStreamDialer{}, "8.8.8.8", "https://8.8.8.8/dns-query") + resolver := NewHTTPSResolver(&transport.TCPDialer{}, "8.8.8.8", "https://8.8.8.8/dns-query") q, err := NewQuestion("getoutline.org.", dnsmessage.TypeAAAA) require.NoError(t, err) resp, err := resolver.Query(ctx, *q) diff --git a/go.mod b/go.mod index 558fad92..5ff0f143 100644 --- a/go.mod +++ b/go.mod @@ -8,12 +8,13 @@ require ( github.com/shadowsocks/go-shadowsocks2 v0.1.5 github.com/stretchr/testify v1.8.2 golang.org/x/crypto v0.17.0 - golang.org/x/net v0.17.0 + golang.org/x/net v0.19.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/kr/pretty v0.1.0 // indirect + github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect golang.org/x/sys v0.15.0 // indirect diff --git a/go.sum b/go.sum index 7c9c5b13..f1f3f92a 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,4 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -8,8 +9,9 @@ github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 h1:f/FNXud6gA3MNr8meMVVGxhp+QBTqY91tM8HjEuMjGg= @@ -34,8 +36,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191021144547-ec77196f6094/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= -golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= +golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/network/lwip2transport/device_test.go b/network/lwip2transport/device_test.go index 39a73fa3..155ed8f2 100644 --- a/network/lwip2transport/device_test.go +++ b/network/lwip2transport/device_test.go @@ -54,7 +54,7 @@ type errTcpUdpHandler struct { err error } -func (h *errTcpUdpHandler) Dial(context.Context, string) (transport.StreamConn, error) { +func (h *errTcpUdpHandler) DialStream(context.Context, string) (transport.StreamConn, error) { return nil, h.err } diff --git a/network/lwip2transport/tcp.go b/network/lwip2transport/tcp.go index 0eb30043..ee856c04 100644 --- a/network/lwip2transport/tcp.go +++ b/network/lwip2transport/tcp.go @@ -36,7 +36,7 @@ func newTCPHandler(client transport.StreamDialer) *tcpHandler { } func (h *tcpHandler) Handle(conn net.Conn, target *net.TCPAddr) error { - proxyConn, err := h.dialer.Dial(context.Background(), target.String()) + proxyConn, err := h.dialer.DialStream(context.Background(), target.String()) if err != nil { return err } diff --git a/network/packet_listener_proxy_test.go b/network/packet_listener_proxy_test.go index cd7e0d78..b27e8f5e 100644 --- a/network/packet_listener_proxy_test.go +++ b/network/packet_listener_proxy_test.go @@ -23,7 +23,7 @@ import ( ) func TestWithWriteTimeoutOptionWorks(t *testing.T) { - pl := &transport.UDPPacketListener{} + pl := &transport.UDPListener{} defProxy, err := NewPacketProxyFromPacketListener(pl) require.NoError(t, err) diff --git a/transport/doc.go b/transport/doc.go new file mode 100644 index 00000000..f6dc74a8 --- /dev/null +++ b/transport/doc.go @@ -0,0 +1,47 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* +Package transport has the core types to work with transport layer connections. + +# Connections + +Connections enable communication between two endpoints over an abstract transport. There are two types of connections: + + - Stream connections, like TCP and the SOCK_STREAM Posix socket type. They are represented by [StreamConn] objects. + - Datagram connections, like UDP and the SOCK_DGRAM Posix socket type. They are represented by [net.Conn] objects. + +We use "Packet" instead of "Datagram" in the method and type names related to datagrams because that is the convention in the Go standard library. + +Each write and read on datagram connections represent a single datagram, while reads and writes on stream connections operate on byte sequences +that may be independent of how those bytes are packaged. + +Stream connections offer CloseRead and CloseWrite methods, which allows for a half-closed state (like TCP). +In general, you communicate end of data ("EOF") to the other side of the connection by calling CloseWrite (TCP will send a FIN). +CloseRead doesn't generate packets, but it allows for releasing resources (e.g. a read loop) and to signal errors to the peer +if more data does arrive (TCP will usually send a RST). + +Connections can be wrapped to create nested connections over a new transport. For example, a StreamConn could be over TCP, +over TLS over TCP, over HTTP over TLS over TCP, over QUIC, among other options. + +# Dialers + +Dialers enable the creation of connections given a host:port address while encapsulating the underlying transport or proxy protocol. +The [StreamDialer] and [PacketDialer] types create stream ([StreamConn]) and datagram ([net.Conn]) connections, respectively, given an address. + +Dialers can also be nested. For example, a TLS Stream Dialer can use a TCP dialer to create a StreamConn backed by a TCP connection, +then create a TLS StreamConn backed by the TCP StreamConn. A SOCKS5-over-TLS Dialer could use the TLS Dialer to create the TLS StreamConn +to the proxy before doing the SOCKS5 connection to the target address. +*/ +package transport diff --git a/transport/packet.go b/transport/packet.go index d4c59ae2..bac37b3d 100644 --- a/transport/packet.go +++ b/transport/packet.go @@ -22,8 +22,8 @@ import ( // PacketEndpoint represents an endpoint that can be used to establish packet connections (like UDP) to a fixed destination. type PacketEndpoint interface { - // Connect creates a connection bound to an endpoint, returning the connection. - Connect(ctx context.Context) (net.Conn, error) + // ConnectPacket creates a connection bound to an endpoint, returning the connection. + ConnectPacket(ctx context.Context) (net.Conn, error) } // UDPEndpoint is a [PacketEndpoint] that connects to the specified address using UDP. @@ -37,8 +37,8 @@ type UDPEndpoint struct { var _ PacketEndpoint = (*UDPEndpoint)(nil) -// Connect implements [PacketEndpoint].Connect. -func (e UDPEndpoint) Connect(ctx context.Context) (net.Conn, error) { +// ConnectPacket implements [PacketEndpoint].ConnectPacket. +func (e UDPEndpoint) ConnectPacket(ctx context.Context) (net.Conn, error) { return e.Dialer.DialContext(ctx, "udp", e.Address) } @@ -47,8 +47,8 @@ type FuncPacketEndpoint func(ctx context.Context) (net.Conn, error) var _ PacketEndpoint = (*FuncPacketEndpoint)(nil) -// Connect implements the [PacketEndpoint] interface. -func (f FuncPacketEndpoint) Connect(ctx context.Context) (net.Conn, error) { +// ConnectPacket implements the [PacketEndpoint] interface. +func (f FuncPacketEndpoint) ConnectPacket(ctx context.Context) (net.Conn, error) { return f(ctx) } @@ -60,28 +60,28 @@ type PacketDialerEndpoint struct { var _ PacketEndpoint = (*PacketDialerEndpoint)(nil) -// Connect implements [PacketEndpoint].Connect. -func (e *PacketDialerEndpoint) Connect(ctx context.Context) (net.Conn, error) { - return e.Dialer.Dial(ctx, e.Address) +// ConnectPacket implements [PacketEndpoint].ConnectPacket. +func (e *PacketDialerEndpoint) ConnectPacket(ctx context.Context) (net.Conn, error) { + return e.Dialer.DialPacket(ctx, e.Address) } // PacketDialer provides a way to dial a destination and establish datagram connections. type PacketDialer interface { - // Dial connects to `addr`. + // DialPacket connects to `addr`. // `addr` has the form "host:port", where "host" can be a domain name or IP address. - Dial(ctx context.Context, addr string) (net.Conn, error) + DialPacket(ctx context.Context, addr string) (net.Conn, error) } -// UDPPacketDialer is a [PacketDialer] that uses the standard [net.Dialer] to dial. +// UDPDialer is a [PacketDialer] that uses the standard [net.Dialer] to dial. // It provides a convenient way to use a [net.Dialer] when you need a [PacketDialer]. -type UDPPacketDialer struct { +type UDPDialer struct { Dialer net.Dialer } -var _ PacketDialer = (*UDPPacketDialer)(nil) +var _ PacketDialer = (*UDPDialer)(nil) -// Dial implements [PacketDialer].Dial. -func (d *UDPPacketDialer) Dial(ctx context.Context, addr string) (net.Conn, error) { +// DialPacket implements [PacketDialer].DialPacket. +func (d *UDPDialer) DialPacket(ctx context.Context, addr string) (net.Conn, error) { return d.Dialer.DialContext(ctx, "udp", addr) } @@ -100,12 +100,12 @@ type boundPacketConn struct { var _ net.Conn = (*boundPacketConn)(nil) -// Dial implements [PacketDialer].Dial. +// DialPacket implements [PacketDialer].DialPacket. // The address is in "host:port" format and the host must be either a full IP address (not "[::]") or a domain. // The address must be supported by the WriteTo call of the [net.PacketConn] returned by the [PacketListener]. // For example, a [net.UDPConn] only supports IP addresses, not domain names. // If the host is a domain name, consider pre-resolving it to avoid resolution calls. -func (e PacketListenerDialer) Dial(ctx context.Context, address string) (net.Conn, error) { +func (e PacketListenerDialer) DialPacket(ctx context.Context, address string) (net.Conn, error) { packetConn, err := e.Listener.ListenPacket(ctx) if err != nil { return nil, fmt.Errorf("could not create PacketConn: %w", err) @@ -152,17 +152,17 @@ type PacketListener interface { ListenPacket(ctx context.Context) (net.PacketConn, error) } -// UDPPacketListener is a [PacketListener] that uses the standard [net.ListenConfig].ListenPacket to listen. -type UDPPacketListener struct { +// UDPListener is a [PacketListener] that uses the standard [net.ListenConfig].ListenPacket to listen. +type UDPListener struct { net.ListenConfig // The local address to bind to, as specified in [net.ListenPacket]. Address string } -var _ PacketListener = (*UDPPacketListener)(nil) +var _ PacketListener = (*UDPListener)(nil) // ListenPacket implements [PacketListener].ListenPacket -func (l UDPPacketListener) ListenPacket(ctx context.Context) (net.PacketConn, error) { +func (l UDPListener) ListenPacket(ctx context.Context) (net.PacketConn, error) { return l.ListenConfig.ListenPacket(ctx, "udp", l.Address) } @@ -171,7 +171,7 @@ type FuncPacketDialer func(ctx context.Context, addr string) (net.Conn, error) var _ PacketDialer = (*FuncPacketDialer)(nil) -// Dial implements the [PacketDialer] interface. -func (f FuncPacketDialer) Dial(ctx context.Context, addr string) (net.Conn, error) { +// DialPacket implements the [PacketDialer] interface. +func (f FuncPacketDialer) DialPacket(ctx context.Context, addr string) (net.Conn, error) { return f(ctx, addr) } diff --git a/transport/packet_test.go b/transport/packet_test.go index 1c813077..aeb65a4e 100644 --- a/transport/packet_test.go +++ b/transport/packet_test.go @@ -36,7 +36,7 @@ func TestUDPEndpointIPv4(t *testing.T) { require.Equal(t, serverAddr, address) return nil } - conn, err := ep.Connect(context.Background()) + conn, err := ep.ConnectPacket(context.Background()) require.NoError(t, err) assert.Equal(t, "udp", conn.RemoteAddr().Network()) assert.Equal(t, serverAddr, conn.RemoteAddr().String()) @@ -50,7 +50,7 @@ func TestUDPEndpointIPv6(t *testing.T) { require.Equal(t, serverAddr, address) return nil } - conn, err := ep.Connect(context.Background()) + conn, err := ep.ConnectPacket(context.Background()) require.NoError(t, err) assert.Equal(t, "udp", conn.RemoteAddr().Network()) assert.Equal(t, serverAddr, conn.RemoteAddr().String()) @@ -64,7 +64,7 @@ func TestUDPEndpointDomain(t *testing.T) { resolvedAddr = address return nil } - conn, err := ep.Connect(context.Background()) + conn, err := ep.ConnectPacket(context.Background()) require.NoError(t, err) assert.Equal(t, "udp", conn.RemoteAddr().Network()) assert.Equal(t, resolvedAddr, conn.RemoteAddr().String()) @@ -76,7 +76,7 @@ func TestFuncPacketEndpoint(t *testing.T) { endpoint := FuncPacketEndpoint(func(ctx context.Context) (net.Conn, error) { return expectedConn, expectedErr }) - conn, err := endpoint.Connect(context.Background()) + conn, err := endpoint.ConnectPacket(context.Background()) require.Equal(t, expectedConn, conn) require.Equal(t, expectedErr, err) } @@ -88,7 +88,7 @@ func TestFuncPacketDialer(t *testing.T) { require.Equal(t, "unused", addr) return expectedConn, expectedErr }) - conn, err := dialer.Dial(context.Background(), "unused") + conn, err := dialer.DialPacket(context.Background(), "unused") require.Equal(t, expectedConn, conn) require.Equal(t, expectedErr, err) } @@ -96,7 +96,7 @@ func TestFuncPacketDialer(t *testing.T) { // UDPPacketListener func TestUDPPacketListenerLocalIPv4Addr(t *testing.T) { - listener := &UDPPacketListener{Address: "127.0.0.1:0"} + listener := &UDPListener{Address: "127.0.0.1:0"} pc, err := listener.ListenPacket(context.Background()) require.NoError(t, err) require.Equal(t, "udp", pc.LocalAddr().Network()) @@ -106,7 +106,7 @@ func TestUDPPacketListenerLocalIPv4Addr(t *testing.T) { } func TestUDPPacketListenerLocalIPv6Addr(t *testing.T) { - listener := &UDPPacketListener{Address: "[::1]:0"} + listener := &UDPListener{Address: "[::1]:0"} pc, err := listener.ListenPacket(context.Background()) require.NoError(t, err) require.Equal(t, "udp", pc.LocalAddr().Network()) @@ -116,7 +116,7 @@ func TestUDPPacketListenerLocalIPv6Addr(t *testing.T) { } func TestUDPPacketListenerLocalhost(t *testing.T) { - listener := &UDPPacketListener{Address: "localhost:0"} + listener := &UDPListener{Address: "localhost:0"} pc, err := listener.ListenPacket(context.Background()) require.NoError(t, err) require.Equal(t, "udp", pc.LocalAddr().Network()) @@ -126,7 +126,7 @@ func TestUDPPacketListenerLocalhost(t *testing.T) { } func TestUDPPacketListenerDefaulAddr(t *testing.T) { - listener := &UDPPacketListener{} + listener := &UDPListener{} pc, err := listener.ListenPacket(context.Background()) require.Equal(t, "udp", pc.LocalAddr().Network()) require.NoError(t, err) @@ -142,8 +142,8 @@ func TestUDPPacketDialer(t *testing.T) { require.NoError(t, err) require.Equal(t, "udp", server.LocalAddr().Network()) - dialer := &UDPPacketDialer{} - conn, err := dialer.Dial(context.Background(), server.LocalAddr().String()) + dialer := &UDPDialer{} + conn, err := dialer.DialPacket(context.Background(), server.LocalAddr().String()) require.NoError(t, err) request := []byte("PING") @@ -169,7 +169,7 @@ func TestPacketListenerDialer(t *testing.T) { request := []byte("Request") response := []byte("Response") - serverListener := UDPPacketListener{Address: "127.0.0.1:0"} + serverListener := UDPListener{Address: "127.0.0.1:0"} serverPacketConn, err := serverListener.ListenPacket(context.Background()) require.NoError(t, err, "Failed to create UDP listener: %v", err) t.Logf("Listening on %v", serverPacketConn.LocalAddr()) @@ -202,9 +202,9 @@ func TestPacketListenerDialer(t *testing.T) { }() serverEndpoint := &PacketListenerDialer{ - Listener: UDPPacketListener{Address: "127.0.0.1:0"}, + Listener: UDPListener{Address: "127.0.0.1:0"}, } - conn, err := serverEndpoint.Dial(context.Background(), serverPacketConn.LocalAddr().String()) + conn, err := serverEndpoint.DialPacket(context.Background(), serverPacketConn.LocalAddr().String()) require.NoError(t, err) t.Logf("Connected to %v from %v", conn.RemoteAddr(), conn.LocalAddr()) defer func() { diff --git a/transport/shadowsocks/packet_listener.go b/transport/shadowsocks/packet_listener.go index aa0c9efd..548138dc 100644 --- a/transport/shadowsocks/packet_listener.go +++ b/transport/shadowsocks/packet_listener.go @@ -50,7 +50,7 @@ func NewPacketListener(endpoint transport.PacketEndpoint, key *EncryptionKey) (t } func (c *packetListener) ListenPacket(ctx context.Context) (net.PacketConn, error) { - proxyConn, err := c.endpoint.Connect(ctx) + proxyConn, err := c.endpoint.ConnectPacket(ctx) if err != nil { return nil, fmt.Errorf("could not connect to endpoint: %w", err) } diff --git a/transport/shadowsocks/stream_dialer.go b/transport/shadowsocks/stream_dialer.go index 8749e73f..e1fc8bb3 100644 --- a/transport/shadowsocks/stream_dialer.go +++ b/transport/shadowsocks/stream_dialer.go @@ -63,7 +63,7 @@ type StreamDialer struct { var _ transport.StreamDialer = (*StreamDialer)(nil) -// Dial implements StreamDialer.Dial using a Shadowsocks server. +// DialStream implements StreamDialer.DialStream using a Shadowsocks server. // // The Shadowsocks StreamDialer returns a connection after the connection to the proxy is established, // but before the connection to the target is established. That means we cannot signal "connection refused" @@ -78,12 +78,12 @@ var _ transport.StreamDialer = (*StreamDialer)(nil) // initial data from the application in order to send the Shadowsocks salt, SOCKS address and initial data // all in one packet. This makes the size of the initial packet hard to predict, avoiding packet size // fingerprinting. We can only get the application initial data if we return a connection first. -func (c *StreamDialer) Dial(ctx context.Context, remoteAddr string) (transport.StreamConn, error) { +func (c *StreamDialer) DialStream(ctx context.Context, remoteAddr string) (transport.StreamConn, error) { socksTargetAddr := socks.ParseAddr(remoteAddr) if socksTargetAddr == nil { return nil, errors.New("failed to parse target address") } - proxyConn, err := c.endpoint.Connect(ctx) + proxyConn, err := c.endpoint.ConnectStream(ctx) if err != nil { return nil, err } diff --git a/transport/shadowsocks/stream_dialer_test.go b/transport/shadowsocks/stream_dialer_test.go index 8c223138..e361fcaa 100644 --- a/transport/shadowsocks/stream_dialer_test.go +++ b/transport/shadowsocks/stream_dialer_test.go @@ -34,7 +34,7 @@ func TestStreamDialer_Dial(t *testing.T) { if err != nil { t.Fatalf("Failed to create StreamDialer: %v", err) } - conn, err := d.Dial(context.Background(), testTargetAddr) + conn, err := d.DialStream(context.Background(), testTargetAddr) if err != nil { t.Fatalf("StreamDialer.Dial failed: %v", err) } @@ -56,7 +56,7 @@ func TestStreamDialer_DialNoPayload(t *testing.T) { // Extend the wait to be safer. d.ClientDataWait = 0 * time.Millisecond - conn, err := d.Dial(context.Background(), testTargetAddr) + conn, err := d.DialStream(context.Background(), testTargetAddr) if err != nil { t.Fatalf("StreamDialer.Dial failed: %v", err) } @@ -102,7 +102,7 @@ func TestStreamDialer_DialFastClose(t *testing.T) { // Extend the wait to be safer. d.ClientDataWait = 100 * time.Millisecond - conn, err := d.Dial(context.Background(), testTargetAddr) + conn, err := d.DialStream(context.Background(), testTargetAddr) require.NoError(t, err, "StreamDialer.Dial failed: %v", err) // Wait for less than 100 milliseconds to ensure that the target @@ -151,7 +151,7 @@ func TestStreamDialer_TCPPrefix(t *testing.T) { t.Fatalf("Failed to create StreamDialer: %v", err) } d.SaltGenerator = NewPrefixSaltGenerator(prefix) - conn, err := d.Dial(context.Background(), testTargetAddr) + conn, err := d.DialStream(context.Background(), testTargetAddr) if err != nil { t.Fatalf("StreamDialer.Dial failed: %v", err) } @@ -170,7 +170,7 @@ func BenchmarkStreamDialer_Dial(b *testing.B) { if err != nil { b.Fatalf("Failed to create StreamDialer: %v", err) } - conn, err := d.Dial(context.Background(), testTargetAddr) + conn, err := d.DialStream(context.Background(), testTargetAddr) if err != nil { b.Fatalf("StreamDialer.Dial failed: %v", err) } diff --git a/transport/socks5/stream_dialer.go b/transport/socks5/stream_dialer.go index 14743d34..9d4febef 100644 --- a/transport/socks5/stream_dialer.go +++ b/transport/socks5/stream_dialer.go @@ -38,12 +38,12 @@ type streamDialer struct { var _ transport.StreamDialer = (*streamDialer)(nil) -// Dial implements [transport.StreamDialer].Dial using SOCKS5. +// DialStream implements [transport.StreamDialer].DialStream using SOCKS5. // It will send the method and the connect requests in one packet, to avoid an unnecessary roundtrip. // The returned [error] will be of type [ReplyCode] if the server sends a SOCKS error reply code, which // you can check against the error constants in this package using [errors.Is]. -func (c *streamDialer) Dial(ctx context.Context, remoteAddr string) (transport.StreamConn, error) { - proxyConn, err := c.proxyEndpoint.Connect(ctx) +func (c *streamDialer) DialStream(ctx context.Context, remoteAddr string) (transport.StreamConn, error) { + proxyConn, err := c.proxyEndpoint.ConnectStream(ctx) if err != nil { return nil, fmt.Errorf("could not connect to SOCKS5 proxy: %w", err) } diff --git a/transport/socks5/stream_dialer_test.go b/transport/socks5/stream_dialer_test.go index cdf997e0..e79fdcac 100644 --- a/transport/socks5/stream_dialer_test.go +++ b/transport/socks5/stream_dialer_test.go @@ -39,7 +39,7 @@ func TestSOCKS5Dialer_BadConnection(t *testing.T) { dialer, err := NewStreamDialer(&transport.TCPEndpoint{Address: "127.0.0.0:0"}) require.NotNil(t, dialer) require.NoError(t, err) - _, err = dialer.Dial(context.Background(), "example.com:443") + _, err = dialer.DialStream(context.Background(), "example.com:443") require.Error(t, err) } @@ -52,7 +52,7 @@ func TestSOCKS5Dialer_BadAddress(t *testing.T) { require.NotNil(t, dialer) require.NoError(t, err) - _, err = dialer.Dial(context.Background(), "noport") + _, err = dialer.DialStream(context.Background(), "noport") require.Error(t, err) } @@ -97,7 +97,7 @@ func testExchange(tb testing.TB, listener *net.TCPListener, destAddr string, req defer running.Done() dialer, err := NewStreamDialer(&transport.TCPEndpoint{Address: listener.Addr().String()}) require.NoError(tb, err) - serverConn, err := dialer.Dial(context.Background(), destAddr) + serverConn, err := dialer.DialStream(context.Background(), destAddr) if replyCode != 0 { require.ErrorIs(tb, err, replyCode) var extractedReplyCode ReplyCode diff --git a/transport/split/stream_dialer.go b/transport/split/stream_dialer.go index 4a1275a6..fac0a6d9 100644 --- a/transport/split/stream_dialer.go +++ b/transport/split/stream_dialer.go @@ -37,9 +37,9 @@ func NewStreamDialer(dialer transport.StreamDialer, prefixBytes int64) (transpor return &splitDialer{dialer: dialer, splitPoint: prefixBytes}, nil } -// Dial implements [transport.StreamDialer].Dial. -func (d *splitDialer) Dial(ctx context.Context, remoteAddr string) (transport.StreamConn, error) { - innerConn, err := d.dialer.Dial(ctx, remoteAddr) +// DialStream implements [transport.StreamDialer].DialStream. +func (d *splitDialer) DialStream(ctx context.Context, remoteAddr string) (transport.StreamConn, error) { + innerConn, err := d.dialer.DialStream(ctx, remoteAddr) if err != nil { return nil, err } diff --git a/transport/stream.go b/transport/stream.go index 9dbdcdac..60241354 100644 --- a/transport/stream.go +++ b/transport/stream.go @@ -72,8 +72,8 @@ func WrapConn(c StreamConn, r io.Reader, w io.Writer) StreamConn { // StreamEndpoint represents an endpoint that can be used to establish stream connections (like TCP) to a fixed // destination. type StreamEndpoint interface { - // Connect establishes a connection with the endpoint, returning the connection. - Connect(ctx context.Context) (StreamConn, error) + // ConnectStream establishes a connection with the endpoint, returning the connection. + ConnectStream(ctx context.Context) (StreamConn, error) } // TCPEndpoint is a [StreamEndpoint] that connects to the specified address using the specified [StreamDialer]. @@ -87,8 +87,8 @@ type TCPEndpoint struct { var _ StreamEndpoint = (*TCPEndpoint)(nil) -// Connect implements [StreamEndpoint].Connect. -func (e *TCPEndpoint) Connect(ctx context.Context) (StreamConn, error) { +// ConnectStream implements [StreamEndpoint].ConnectStream. +func (e *TCPEndpoint) ConnectStream(ctx context.Context) (StreamConn, error) { conn, err := e.Dialer.DialContext(ctx, "tcp", e.Address) if err != nil { return nil, err @@ -101,8 +101,8 @@ type FuncStreamEndpoint func(ctx context.Context) (StreamConn, error) var _ StreamEndpoint = (*FuncStreamEndpoint)(nil) -// Connect implements the [StreamEndpoint] interface. -func (f FuncStreamEndpoint) Connect(ctx context.Context) (StreamConn, error) { +// ConnectStream implements the [StreamEndpoint] interface. +func (f FuncStreamEndpoint) ConnectStream(ctx context.Context) (StreamConn, error) { return f(ctx) } @@ -115,27 +115,27 @@ type StreamDialerEndpoint struct { var _ StreamEndpoint = (*StreamDialerEndpoint)(nil) -// Connect implements [StreamEndpoint].Connect. -func (e *StreamDialerEndpoint) Connect(ctx context.Context) (StreamConn, error) { - return e.Dialer.Dial(ctx, e.Address) +// ConnectStream implements [StreamEndpoint].ConnectStream. +func (e *StreamDialerEndpoint) ConnectStream(ctx context.Context) (StreamConn, error) { + return e.Dialer.DialStream(ctx, e.Address) } // StreamDialer provides a way to dial a destination and establish stream connections. type StreamDialer interface { - // Dial connects to `raddr`. + // DialStream connects to `raddr`. // `raddr` has the form "host:port", where "host" can be a domain name or IP address. - Dial(ctx context.Context, raddr string) (StreamConn, error) + DialStream(ctx context.Context, raddr string) (StreamConn, error) } -// TCPStreamDialer is a [StreamDialer] that uses the standard [net.Dialer] to dial. +// TCPDialer is a [StreamDialer] that uses the standard [net.Dialer] to dial. // It provides a convenient way to use a [net.Dialer] when you need a [StreamDialer]. -type TCPStreamDialer struct { +type TCPDialer struct { Dialer net.Dialer } -var _ StreamDialer = (*TCPStreamDialer)(nil) +var _ StreamDialer = (*TCPDialer)(nil) -func (d *TCPStreamDialer) Dial(ctx context.Context, addr string) (StreamConn, error) { +func (d *TCPDialer) DialStream(ctx context.Context, addr string) (StreamConn, error) { conn, err := d.Dialer.DialContext(ctx, "tcp", addr) if err != nil { return nil, err @@ -148,7 +148,7 @@ type FuncStreamDialer func(ctx context.Context, addr string) (StreamConn, error) var _ StreamDialer = (*FuncStreamDialer)(nil) -// Dial implements the [StreamDialer] interface. -func (f FuncStreamDialer) Dial(ctx context.Context, addr string) (StreamConn, error) { +// DialStream implements the [StreamDialer] interface. +func (f FuncStreamDialer) DialStream(ctx context.Context, addr string) (StreamConn, error) { return f(ctx, addr) } diff --git a/transport/stream_test.go b/transport/stream_test.go index 98de5d02..27f51c7e 100644 --- a/transport/stream_test.go +++ b/transport/stream_test.go @@ -37,7 +37,7 @@ func TestFuncStreamEndpoint(t *testing.T) { endpoint := FuncStreamEndpoint(func(ctx context.Context) (StreamConn, error) { return expectedConn, expectedErr }) - conn, err := endpoint.Connect(context.Background()) + conn, err := endpoint.ConnectStream(context.Background()) require.Equal(t, expectedConn, conn) require.Equal(t, expectedErr, err) } @@ -49,7 +49,7 @@ func TestFuncStreamDialer(t *testing.T) { require.Equal(t, "unused", addr) return expectedConn, expectedErr }) - conn, err := dialer.Dial(context.Background(), "unused") + conn, err := dialer.DialStream(context.Background(), "unused") require.Equal(t, expectedConn, conn) require.Equal(t, expectedErr, err) } @@ -92,13 +92,13 @@ func TestNewTCPStreamDialerIPv4(t *testing.T) { // Client go func() { defer running.Done() - dialer := &TCPStreamDialer{} + dialer := &TCPDialer{} dialer.Dialer.Control = func(network, address string, c syscall.RawConn) error { require.Equal(t, "tcp4", network) require.Equal(t, listener.Addr().String(), address) return nil } - serverConn, err := dialer.Dial(context.Background(), listener.Addr().String()) + serverConn, err := dialer.DialStream(context.Background(), listener.Addr().String()) require.NoError(t, err, "Dial failed") require.Equal(t, listener.Addr().String(), serverConn.RemoteAddr().String()) defer serverConn.Close() @@ -120,14 +120,14 @@ func TestNewTCPStreamDialerIPv4(t *testing.T) { func TestNewTCPStreamDialerAddress(t *testing.T) { errCancel := errors.New("cancelled") - dialer := &TCPStreamDialer{} + dialer := &TCPDialer{} dialer.Dialer.Control = func(network, address string, c syscall.RawConn) error { require.Equal(t, "tcp4", network) require.Equal(t, "8.8.8.8:53", address) return errCancel } - _, err := dialer.Dial(context.Background(), "8.8.8.8:53") + _, err := dialer.DialStream(context.Background(), "8.8.8.8:53") require.ErrorIs(t, err, errCancel) dialer.Dialer.Control = func(network, address string, c syscall.RawConn) error { @@ -135,7 +135,7 @@ func TestNewTCPStreamDialerAddress(t *testing.T) { require.Equal(t, "[2001:4860:4860::8888]:53", address) return errCancel } - _, err = dialer.Dial(context.Background(), "[2001:4860:4860::8888]:53") + _, err = dialer.DialStream(context.Background(), "[2001:4860:4860::8888]:53") require.ErrorIs(t, err, errCancel) } @@ -150,7 +150,7 @@ func TestDialStreamEndpointAddr(t *testing.T) { require.Equal(t, listener.Addr().String(), address) return nil } - conn, err := endpoint.Connect(context.Background()) + conn, err := endpoint.ConnectStream(context.Background()) require.NoError(t, err) require.Equal(t, listener.Addr().String(), conn.RemoteAddr().String()) require.Nil(t, conn.Close()) diff --git a/transport/tls/stream_dialer.go b/transport/tls/stream_dialer.go index 383341f2..7e6bc125 100644 --- a/transport/tls/stream_dialer.go +++ b/transport/tls/stream_dialer.go @@ -62,9 +62,9 @@ func (c streamConn) CloseRead() error { return c.innerConn.CloseRead() } -// Dial implements [transport.StreamDialer].Dial. -func (d *StreamDialer) Dial(ctx context.Context, remoteAddr string) (transport.StreamConn, error) { - innerConn, err := d.dialer.Dial(ctx, remoteAddr) +// DialStream implements [transport.StreamDialer].DialStream. +func (d *StreamDialer) DialStream(ctx context.Context, remoteAddr string) (transport.StreamConn, error) { + innerConn, err := d.dialer.DialStream(ctx, remoteAddr) if err != nil { return nil, err } diff --git a/transport/tls/stream_dialer_test.go b/transport/tls/stream_dialer_test.go index 33fc1708..7927c099 100644 --- a/transport/tls/stream_dialer_test.go +++ b/transport/tls/stream_dialer_test.go @@ -24,9 +24,9 @@ import ( ) func TestDomain(t *testing.T) { - sd, err := NewStreamDialer(&transport.TCPStreamDialer{}) + sd, err := NewStreamDialer(&transport.TCPDialer{}) require.NoError(t, err) - conn, err := sd.Dial(context.Background(), "dns.google:443") + conn, err := sd.DialStream(context.Background(), "dns.google:443") require.NoError(t, err) tlsConn, ok := conn.(streamConn) require.True(t, ok) @@ -37,76 +37,76 @@ func TestDomain(t *testing.T) { } func TestUntrustedRoot(t *testing.T) { - sd, err := NewStreamDialer(&transport.TCPStreamDialer{}) + sd, err := NewStreamDialer(&transport.TCPDialer{}) require.NoError(t, err) - _, err = sd.Dial(context.Background(), "untrusted-root.badssl.com:443") + _, err = sd.DialStream(context.Background(), "untrusted-root.badssl.com:443") var certErr x509.UnknownAuthorityError require.ErrorAs(t, err, &certErr) } func TestRevoked(t *testing.T) { - sd, err := NewStreamDialer(&transport.TCPStreamDialer{}) + sd, err := NewStreamDialer(&transport.TCPDialer{}) require.NoError(t, err) - _, err = sd.Dial(context.Background(), "revoked.badssl.com:443") + _, err = sd.DialStream(context.Background(), "revoked.badssl.com:443") var certErr x509.CertificateInvalidError require.ErrorAs(t, err, &certErr) require.Equal(t, x509.Expired, certErr.Reason) } func TestIP(t *testing.T) { - sd, err := NewStreamDialer(&transport.TCPStreamDialer{}) + sd, err := NewStreamDialer(&transport.TCPDialer{}) require.NoError(t, err) - conn, err := sd.Dial(context.Background(), "8.8.8.8:443") + conn, err := sd.DialStream(context.Background(), "8.8.8.8:443") require.NoError(t, err) conn.Close() } func TestIPOverride(t *testing.T) { - sd, err := NewStreamDialer(&transport.TCPStreamDialer{}, WithCertificateName("8.8.8.8")) + sd, err := NewStreamDialer(&transport.TCPDialer{}, WithCertificateName("8.8.8.8")) require.NoError(t, err) - conn, err := sd.Dial(context.Background(), "dns.google:443") + conn, err := sd.DialStream(context.Background(), "dns.google:443") require.NoError(t, err) conn.Close() } func TestFakeSNI(t *testing.T) { - sd, err := NewStreamDialer(&transport.TCPStreamDialer{}, WithSNI("decoy.example.com")) + sd, err := NewStreamDialer(&transport.TCPDialer{}, WithSNI("decoy.example.com")) require.NoError(t, err) - conn, err := sd.Dial(context.Background(), "www.youtube.com:443") + conn, err := sd.DialStream(context.Background(), "www.youtube.com:443") require.NoError(t, err) conn.Close() } func TestNoSNI(t *testing.T) { - sd, err := NewStreamDialer(&transport.TCPStreamDialer{}, WithSNI("")) + sd, err := NewStreamDialer(&transport.TCPDialer{}, WithSNI("")) require.NoError(t, err) - conn, err := sd.Dial(context.Background(), "dns.google:443") + conn, err := sd.DialStream(context.Background(), "dns.google:443") require.NoError(t, err) conn.Close() } func TestAllCustom(t *testing.T) { - sd, err := NewStreamDialer(&transport.TCPStreamDialer{}, WithSNI("decoy.android.com"), WithCertificateName("www.youtube.com")) + sd, err := NewStreamDialer(&transport.TCPDialer{}, WithSNI("decoy.android.com"), WithCertificateName("www.youtube.com")) require.NoError(t, err) - conn, err := sd.Dial(context.Background(), "www.google.com:443") + conn, err := sd.DialStream(context.Background(), "www.google.com:443") require.NoError(t, err) conn.Close() } func TestHostSelector(t *testing.T) { - sd, err := NewStreamDialer(&transport.TCPStreamDialer{}, + sd, err := NewStreamDialer(&transport.TCPDialer{}, IfHost("dns.google", WithSNI("decoy.example.com")), IfHost("www.youtube.com", WithSNI("notyoutube.com")), ) require.NoError(t, err) - conn, err := sd.Dial(context.Background(), "dns.google:443") + conn, err := sd.DialStream(context.Background(), "dns.google:443") require.NoError(t, err) tlsConn := conn.(streamConn) require.Equal(t, "decoy.example.com", tlsConn.ConnectionState().ServerName) conn.Close() - conn, err = sd.Dial(context.Background(), "www.youtube.com:443") + conn, err = sd.DialStream(context.Background(), "www.youtube.com:443") require.NoError(t, err) tlsConn = conn.(streamConn) require.Equal(t, "notyoutube.com", tlsConn.ConnectionState().ServerName) diff --git a/transport/tlsfrag/stream_dialer.go b/transport/tlsfrag/stream_dialer.go index 550ae131..1cf38bec 100644 --- a/transport/tlsfrag/stream_dialer.go +++ b/transport/tlsfrag/stream_dialer.go @@ -57,10 +57,10 @@ func NewStreamDialerFunc(base transport.StreamDialer, frag FragFunc) (transport. return &tlsFragDialer{base, frag}, nil } -// Dial implements [transport.StreamConn].Dial. It establishes a connection to raddr in the format "host-or-ip:port". +// DialStream implements [transport.StreamConn].DialStream. It establishes a connection to raddr in the format "host-or-ip:port". // The initial TLS Client Hello record sent through the connection will be fragmented. -func (d *tlsFragDialer) Dial(ctx context.Context, raddr string) (conn transport.StreamConn, err error) { - conn, err = d.dialer.Dial(ctx, raddr) +func (d *tlsFragDialer) DialStream(ctx context.Context, raddr string) (conn transport.StreamConn, err error) { + conn, err = d.dialer.DialStream(ctx, raddr) if err != nil { return } diff --git a/transport/tlsfrag/stream_dialer_test.go b/transport/tlsfrag/stream_dialer_test.go index c88754a8..3d127395 100644 --- a/transport/tlsfrag/stream_dialer_test.go +++ b/transport/tlsfrag/stream_dialer_test.go @@ -182,7 +182,7 @@ func assertCanDialFragFunc(t *testing.T, inner transport.StreamDialer, raddr str d, err := NewStreamDialerFunc(inner, frag) require.NoError(t, err) require.NotNil(t, d) - conn, err := d.Dial(context.Background(), raddr) + conn, err := d.DialStream(context.Background(), raddr) require.NoError(t, err) require.NotNil(t, conn) return conn @@ -192,7 +192,7 @@ func assertCanDialFixedLenFrag(t *testing.T, inner transport.StreamDialer, raddr d, err := NewFixedLenStreamDialer(inner, splitLen) require.NoError(t, err) require.NotNil(t, d) - conn, err := d.Dial(context.Background(), raddr) + conn, err := d.DialStream(context.Background(), raddr) require.NoError(t, err) require.NotNil(t, conn) return conn @@ -231,7 +231,7 @@ type collectStreamDialer struct { bufs net.Buffers } -func (d *collectStreamDialer) Dial(ctx context.Context, raddr string) (transport.StreamConn, error) { +func (d *collectStreamDialer) DialStream(ctx context.Context, raddr string) (transport.StreamConn, error) { return d, nil } diff --git a/x/config/config.go b/x/config/config.go index ac73f286..57f04186 100644 --- a/x/config/config.go +++ b/x/config/config.go @@ -45,7 +45,7 @@ func parseConfigPart(oneDialerConfig string) (*url.URL, error) { // NewStreamDialer creates a new [transport.StreamDialer] according to the given config. func NewStreamDialer(transportConfig string) (transport.StreamDialer, error) { - return WrapStreamDialer(&transport.TCPStreamDialer{}, transportConfig) + return WrapStreamDialer(&transport.TCPDialer{}, transportConfig) } // WrapStreamDialer created a [transport.StreamDialer] according to transportConfig, using dialer as the @@ -112,7 +112,7 @@ func newStreamDialerFromPart(innerDialer transport.StreamDialer, oneDialerConfig // NewPacketDialer creates a new [transport.PacketDialer] according to the given config. func NewPacketDialer(transportConfig string) (dialer transport.PacketDialer, err error) { - dialer = &transport.UDPPacketDialer{} + dialer = &transport.UDPDialer{} transportConfig = strings.TrimSpace(transportConfig) if transportConfig == "" { return dialer, nil diff --git a/x/config/override.go b/x/config/override.go index d47a2d7f..885bc87e 100644 --- a/x/config/override.go +++ b/x/config/override.go @@ -76,7 +76,7 @@ func newOverrideStreamDialerFromURL(innerDialer transport.StreamDialer, configUR if err != nil { return nil, err } - return innerDialer.Dial(ctx, addr) + return innerDialer.DialStream(ctx, addr) }), nil } @@ -90,6 +90,6 @@ func newOverridePacketDialerFromURL(innerDialer transport.PacketDialer, configUR if err != nil { return nil, err } - return innerDialer.Dial(ctx, addr) + return innerDialer.DialPacket(ctx, addr) }), nil } diff --git a/x/config/tls_test.go b/x/config/tls_test.go index 1e77313c..47e20ddc 100644 --- a/x/config/tls_test.go +++ b/x/config/tls_test.go @@ -26,7 +26,7 @@ import ( func TestTLS(t *testing.T) { tlsURL, err := url.Parse("tls") require.NoError(t, err) - _, err = newTlsStreamDialerFromURL(&transport.TCPStreamDialer{}, tlsURL) + _, err = newTlsStreamDialerFromURL(&transport.TCPDialer{}, tlsURL) require.NoError(t, err) } diff --git a/x/connectivity/connectivity_test.go b/x/connectivity/connectivity_test.go index 920950a0..ee607633 100644 --- a/x/connectivity/connectivity_test.go +++ b/x/connectivity/connectivity_test.go @@ -36,7 +36,7 @@ import ( // StreamDialer Tests func TestTestResolverStreamConnectivityOk(t *testing.T) { // TODO(fortuna): Run a local resolver and make test not depend on an external server. - resolver := dns.NewTCPResolver(&transport.TCPStreamDialer{}, "8.8.8.8:53") + resolver := dns.NewTCPResolver(&transport.TCPDialer{}, "8.8.8.8:53") result, err := TestConnectivityWithResolver(context.Background(), resolver, "example.com") require.NoError(t, err) require.Nil(t, result) @@ -71,7 +71,7 @@ func TestTestResolverStreamConnectivityRefused(t *testing.T) { // Close right away to ensure the port is closed. The OS will likely not reuse it soon enough. require.Nil(t, listener.Close()) - resolver := dns.NewTCPResolver(&transport.TCPStreamDialer{}, listener.Addr().String()) + resolver := dns.NewTCPResolver(&transport.TCPDialer{}, listener.Addr().String()) result, err := TestConnectivityWithResolver(context.Background(), resolver, "anything") require.NoError(t, err) require.NotNil(t, result) @@ -107,7 +107,7 @@ func TestTestResolverStreamConnectivityReset(t *testing.T) { }, &running) defer listener.Close() - resolver := dns.NewTCPResolver(&transport.TCPStreamDialer{}, listener.Addr().String()) + resolver := dns.NewTCPResolver(&transport.TCPDialer{}, listener.Addr().String()) result, err := TestConnectivityWithResolver(context.Background(), resolver, "anything") require.NoError(t, err) require.NotNil(t, result) @@ -139,7 +139,7 @@ func TestTestStreamDialerEarlyClose(t *testing.T) { }, &running) defer listener.Close() - resolver := dns.NewTCPResolver(&transport.TCPStreamDialer{}, listener.Addr().String()) + resolver := dns.NewTCPResolver(&transport.TCPDialer{}, listener.Addr().String()) result, err := TestConnectivityWithResolver(context.Background(), resolver, "anything") require.NoError(t, err) require.NotNil(t, result) @@ -164,7 +164,7 @@ func TestTestResolverStreamConnectivityTimeout(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - resolver := dns.NewTCPResolver(&transport.TCPStreamDialer{}, listener.Addr().String()) + resolver := dns.NewTCPResolver(&transport.TCPDialer{}, listener.Addr().String()) result, err := TestConnectivityWithResolver(ctx, resolver, "anything") require.NoError(t, err) require.NotNil(t, result) @@ -204,7 +204,7 @@ func TestTestPacketPacketConnectivityOk(t *testing.T) { require.NoError(t, err) }() - resolver := dns.NewUDPResolver(&transport.UDPPacketDialer{}, server.LocalAddr().String()) + resolver := dns.NewUDPResolver(&transport.UDPDialer{}, server.LocalAddr().String()) result, err := TestConnectivityWithResolver(context.Background(), resolver, "anything") require.NoError(t, err) require.Nil(t, result) diff --git a/x/examples/fetch/main.go b/x/examples/fetch/main.go index 83436995..7c1e08e6 100644 --- a/x/examples/fetch/main.go +++ b/x/examples/fetch/main.go @@ -86,7 +86,7 @@ func main() { if !strings.HasPrefix(network, "tcp") { return nil, fmt.Errorf("protocol not supported: %v", network) } - return dialer.Dial(ctx, net.JoinHostPort(host, port)) + return dialer.DialStream(ctx, net.JoinHostPort(host, port)) } httpClient := &http.Client{Transport: &http.Transport{DialContext: dialContext}, Timeout: 5 * time.Second} diff --git a/x/examples/fyne-proxy/go.mod b/x/examples/fyne-proxy/go.mod index 37618103..72cf65f1 100644 --- a/x/examples/fyne-proxy/go.mod +++ b/x/examples/fyne-proxy/go.mod @@ -4,13 +4,13 @@ go 1.20 require ( fyne.io/fyne/v2 v2.4.3 - github.com/Jigsaw-Code/outline-sdk/x v0.0.0-20240112224558-7294484cf816 + github.com/Jigsaw-Code/outline-sdk v0.0.12-0.20240117212231-233d1898e1db + github.com/Jigsaw-Code/outline-sdk/x v0.0.0-20240117212231-233d1898e1db ) require ( fyne.io/systray v1.10.1-0.20231115130155-104f5ef7839e // indirect github.com/BurntSushi/toml v1.3.2 // indirect - github.com/Jigsaw-Code/outline-sdk v0.0.11 // indirect github.com/akavel/rsrc v0.10.2 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/x/examples/fyne-proxy/go.sum b/x/examples/fyne-proxy/go.sum index da49719d..7074a9e3 100644 --- a/x/examples/fyne-proxy/go.sum +++ b/x/examples/fyne-proxy/go.sum @@ -45,10 +45,10 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8= github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= -github.com/Jigsaw-Code/outline-sdk v0.0.11 h1:dJ2QMQJJmQ1J4/XvJ9lWpdhg40SKNBmmKffV5wyL00I= -github.com/Jigsaw-Code/outline-sdk v0.0.11/go.mod h1:m+KaWzU05VOIdAC0MVnV0HwyzdzH4WIJ8w7eaMrPp70= -github.com/Jigsaw-Code/outline-sdk/x v0.0.0-20240112224558-7294484cf816 h1:NOoDy92LN2qGLu5s3imiAbYCxH/k6ZlE4N8eXqfZbgw= -github.com/Jigsaw-Code/outline-sdk/x v0.0.0-20240112224558-7294484cf816/go.mod h1:MeTl41RMo9izytmQUWkvulEOMEsZVbuMutAoLwAlshE= +github.com/Jigsaw-Code/outline-sdk v0.0.12-0.20240117212231-233d1898e1db h1:I1guGrgFXY/w+YWt7QNcb3nrTNts5opuPO44XlRF0xI= +github.com/Jigsaw-Code/outline-sdk v0.0.12-0.20240117212231-233d1898e1db/go.mod h1:FtzQwsbvAT55lpc4kmOaHyvfX8MFW8y7yOHL81wHOVQ= +github.com/Jigsaw-Code/outline-sdk/x v0.0.0-20240117212231-233d1898e1db h1:pVN7fcEihOgzIUVBn92y84HsXrZlaHRkUgnm7Rn7lUo= +github.com/Jigsaw-Code/outline-sdk/x v0.0.0-20240117212231-233d1898e1db/go.mod h1:VVVqAev7l5HwkQVZfs79UrXWtmU3rq76+PLAhkSvFRs= github.com/akavel/rsrc v0.10.2 h1:Zxm8V5eI1hW4gGaYsJQUhxpjkENuG91ki8B4zCrvEsw= github.com/akavel/rsrc v0.10.2/go.mod h1:uLoCtb9J+EyAqh+26kdrTgmzRBFPGOolLWKpdxkKq+c= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= diff --git a/x/examples/fyne-proxy/main.go b/x/examples/fyne-proxy/main.go index ea15a1df..9de6c345 100644 --- a/x/examples/fyne-proxy/main.go +++ b/x/examples/fyne-proxy/main.go @@ -62,7 +62,7 @@ func newFilteredStreamDialer() transport.StreamDialer { } return nil } - return &transport.TCPStreamDialer{Dialer: dialer} + return &transport.TCPDialer{Dialer: dialer} } func runServer(address, transport string) (*runningProxy, error) { diff --git a/x/go.mod b/x/go.mod index b70fd5dc..e21d972f 100644 --- a/x/go.mod +++ b/x/go.mod @@ -3,7 +3,7 @@ module github.com/Jigsaw-Code/outline-sdk/x go 1.20 require ( - github.com/Jigsaw-Code/outline-sdk v0.0.11 + github.com/Jigsaw-Code/outline-sdk v0.0.12-0.20240117212550-6cd87709dc1e github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b github.com/stretchr/testify v1.8.2 github.com/vishvananda/netlink v1.1.0 @@ -15,7 +15,6 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/eycorsican/go-tun2socks v1.16.11 // indirect - github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/shadowsocks/go-shadowsocks2 v0.1.5 // indirect github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect diff --git a/x/go.sum b/x/go.sum index 592528fa..0806aa8a 100644 --- a/x/go.sum +++ b/x/go.sum @@ -1,6 +1,5 @@ -github.com/Jigsaw-Code/outline-sdk v0.0.11 h1:dJ2QMQJJmQ1J4/XvJ9lWpdhg40SKNBmmKffV5wyL00I= -github.com/Jigsaw-Code/outline-sdk v0.0.11/go.mod h1:m+KaWzU05VOIdAC0MVnV0HwyzdzH4WIJ8w7eaMrPp70= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/Jigsaw-Code/outline-sdk v0.0.12-0.20240117212550-6cd87709dc1e h1:56ZI48e68EYYb3m2slu3YJ6C+gWqh8v9bIWk+Bl9dfY= +github.com/Jigsaw-Code/outline-sdk v0.0.12-0.20240117212550-6cd87709dc1e/go.mod h1:9cEaF6sWWMzY8orcUI9pV5D0oFp2FZArTSyJiYtMQQs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -9,7 +8,6 @@ github.com/eycorsican/go-tun2socks v1.16.11/go.mod h1:wgB2BFT8ZaPKyKOQ/5dljMG/YI github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 h1:f/FNXud6gA3MNr8meMVVGxhp+QBTqY91tM8HjEuMjGg= diff --git a/x/httpproxy/connect_handler.go b/x/httpproxy/connect_handler.go index 16ddba97..c88bc56d 100644 --- a/x/httpproxy/connect_handler.go +++ b/x/httpproxy/connect_handler.go @@ -39,8 +39,8 @@ func isCancelledError(err error) bool { return errors.Is(err, context.Canceled) || strings.HasSuffix(err.Error(), "operation was canceled") } -func (d *sanitizeErrorDialer) Dial(ctx context.Context, addr string) (transport.StreamConn, error) { - conn, err := d.StreamDialer.Dial(ctx, addr) +func (d *sanitizeErrorDialer) DialStream(ctx context.Context, addr string) (transport.StreamConn, error) { + conn, err := d.StreamDialer.DialStream(ctx, addr) if isCancelledError(err) { return nil, context.Canceled } @@ -82,7 +82,7 @@ func (h *connectHandler) ServeHTTP(proxyResp http.ResponseWriter, proxyReq *http http.Error(proxyResp, fmt.Sprintf("Invalid config in Transport header: %v", err), http.StatusBadRequest) return } - targetConn, err := dialer.Dial(proxyReq.Context(), proxyReq.Host) + targetConn, err := dialer.DialStream(proxyReq.Context(), proxyReq.Host) if err != nil { http.Error(proxyResp, fmt.Sprintf("Failed to connect to %v: %v", proxyReq.Host, err), http.StatusServiceUnavailable) return diff --git a/x/httpproxy/forward_handler.go b/x/httpproxy/forward_handler.go index c9b964dc..114a5b38 100644 --- a/x/httpproxy/forward_handler.go +++ b/x/httpproxy/forward_handler.go @@ -71,7 +71,7 @@ func NewForwardHandler(dialer transport.StreamDialer) http.Handler { if !strings.HasPrefix(network, "tcp") { return nil, fmt.Errorf("protocol not supported: %v", network) } - return dialer.Dial(ctx, addr) + return dialer.DialStream(ctx, addr) } return &forwardHandler{http.Client{Transport: &http.Transport{DialContext: dialContext}}} }