Skip to content

Commit

Permalink
Rename Dial to DialStream|Packet
Browse files Browse the repository at this point in the history
Rename Connect to Stream|Packet

Rename direct Dialers

Rename UDPPacketListener

Update x

Fix mod
  • Loading branch information
fortuna committed Jan 17, 2024
1 parent 51171d8 commit f09c820
Show file tree
Hide file tree
Showing 34 changed files with 184 additions and 153 deletions.
8 changes: 4 additions & 4 deletions dns/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ func ensurePort(address string, defaultPort string) string {
func NewUDPResolver(pd transport.PacketDialer, resolverAddr string) Resolver {
resolverAddr = ensurePort(resolverAddr, "53")
return FuncResolver(func(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) {
conn, err := pd.Dial(ctx, resolverAddr)
conn, err := pd.DialPacket(ctx, resolverAddr)
if err != nil {
return nil, &nestedError{ErrDial, err}
}
Expand Down Expand Up @@ -301,7 +301,7 @@ func NewTCPResolver(sd transport.StreamDialer, resolverAddr string) Resolver {
resolverAddr = ensurePort(resolverAddr, "53")
return &streamResolver{
NewConn: func(ctx context.Context) (transport.StreamConn, error) {
return sd.Dial(ctx, resolverAddr)
return sd.DialStream(ctx, resolverAddr)
},
}
}
Expand All @@ -315,7 +315,7 @@ func NewTLSResolver(sd transport.StreamDialer, resolverAddr string, resolverName
resolverAddr = ensurePort(resolverAddr, "853")
return &streamResolver{
NewConn: func(ctx context.Context) (transport.StreamConn, error) {
baseConn, err := sd.Dial(ctx, resolverAddr)
baseConn, err := sd.DialStream(ctx, resolverAddr)
if err != nil {
return nil, err
}
Expand All @@ -336,7 +336,7 @@ func NewHTTPSResolver(sd transport.StreamDialer, resolverAddr string, url string
// TODO: Support UDP for QUIC.
return nil, fmt.Errorf("protocol not supported: %v", network)
}
conn, err := sd.Dial(ctx, resolverAddr)
conn, err := sd.DialStream(ctx, resolverAddr)
if err != nil {
return nil, &nestedError{ErrDial, err}
}
Expand Down
8 changes: 4 additions & 4 deletions dns/resolver_net_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func newTestContext(t *testing.T) context.Context {

func TestNewUDPResolver(t *testing.T) {
ctx := newTestContext(t)
resolver := NewUDPResolver(&transport.UDPPacketDialer{}, "8.8.8.8")
resolver := NewUDPResolver(&transport.UDPDialer{}, "8.8.8.8")
q, err := NewQuestion("getoutline.org.", dnsmessage.TypeAAAA)
require.NoError(t, err)
resp, err := resolver.Query(ctx, *q)
Expand All @@ -47,7 +47,7 @@ func TestNewUDPResolver(t *testing.T) {

func TestNewTCPResolver(t *testing.T) {
ctx := newTestContext(t)
resolver := NewTCPResolver(&transport.TCPStreamDialer{}, "8.8.8.8")
resolver := NewTCPResolver(&transport.TCPDialer{}, "8.8.8.8")
q, err := NewQuestion("getoutline.org.", dnsmessage.TypeAAAA)
require.NoError(t, err)
resp, err := resolver.Query(ctx, *q)
Expand All @@ -57,7 +57,7 @@ func TestNewTCPResolver(t *testing.T) {

func TestNewTLSResolver(t *testing.T) {
ctx := newTestContext(t)
resolver := NewTLSResolver(&transport.TCPStreamDialer{}, "8.8.8.8", "8.8.8.8")
resolver := NewTLSResolver(&transport.TCPDialer{}, "8.8.8.8", "8.8.8.8")
q, err := NewQuestion("getoutline.org.", dnsmessage.TypeAAAA)
require.NoError(t, err)
resp, err := resolver.Query(ctx, *q)
Expand All @@ -67,7 +67,7 @@ func TestNewTLSResolver(t *testing.T) {

func TestNewHTTPSResolver(t *testing.T) {
ctx := newTestContext(t)
resolver := NewHTTPSResolver(&transport.TCPStreamDialer{}, "8.8.8.8", "https://8.8.8.8/dns-query")
resolver := NewHTTPSResolver(&transport.TCPDialer{}, "8.8.8.8", "https://8.8.8.8/dns-query")
q, err := NewQuestion("getoutline.org.", dnsmessage.TypeAAAA)
require.NoError(t, err)
resp, err := resolver.Query(ctx, *q)
Expand Down
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ require (
github.com/shadowsocks/go-shadowsocks2 v0.1.5
github.com/stretchr/testify v1.8.2
golang.org/x/crypto v0.17.0
golang.org/x/net v0.17.0
golang.org/x/net v0.19.0
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/kr/pretty v0.1.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
golang.org/x/sys v0.15.0 // indirect
Expand Down
8 changes: 5 additions & 3 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand All @@ -8,8 +9,9 @@ github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 h1:f/FNXud6gA3MNr8meMVVGxhp+QBTqY91tM8HjEuMjGg=
Expand All @@ -34,8 +36,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191021144547-ec77196f6094/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
Expand Down
2 changes: 1 addition & 1 deletion network/lwip2transport/device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ type errTcpUdpHandler struct {
err error
}

func (h *errTcpUdpHandler) Dial(context.Context, string) (transport.StreamConn, error) {
func (h *errTcpUdpHandler) DialStream(context.Context, string) (transport.StreamConn, error) {
return nil, h.err
}

Expand Down
2 changes: 1 addition & 1 deletion network/lwip2transport/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func newTCPHandler(client transport.StreamDialer) *tcpHandler {
}

func (h *tcpHandler) Handle(conn net.Conn, target *net.TCPAddr) error {
proxyConn, err := h.dialer.Dial(context.Background(), target.String())
proxyConn, err := h.dialer.DialStream(context.Background(), target.String())
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion network/packet_listener_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
)

func TestWithWriteTimeoutOptionWorks(t *testing.T) {
pl := &transport.UDPPacketListener{}
pl := &transport.UDPListener{}

defProxy, err := NewPacketProxyFromPacketListener(pl)
require.NoError(t, err)
Expand Down
48 changes: 24 additions & 24 deletions transport/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ import (

// PacketEndpoint represents an endpoint that can be used to establish packet connections (like UDP) to a fixed destination.
type PacketEndpoint interface {
// Connect creates a connection bound to an endpoint, returning the connection.
Connect(ctx context.Context) (net.Conn, error)
// ConnectPacket creates a connection bound to an endpoint, returning the connection.
ConnectPacket(ctx context.Context) (net.Conn, error)
}

// UDPEndpoint is a [PacketEndpoint] that connects to the specified address using UDP.
Expand All @@ -37,8 +37,8 @@ type UDPEndpoint struct {

var _ PacketEndpoint = (*UDPEndpoint)(nil)

// Connect implements [PacketEndpoint].Connect.
func (e UDPEndpoint) Connect(ctx context.Context) (net.Conn, error) {
// ConnectPacket implements [PacketEndpoint].ConnectPacket.
func (e UDPEndpoint) ConnectPacket(ctx context.Context) (net.Conn, error) {
return e.Dialer.DialContext(ctx, "udp", e.Address)
}

Expand All @@ -47,8 +47,8 @@ type FuncPacketEndpoint func(ctx context.Context) (net.Conn, error)

var _ PacketEndpoint = (*FuncPacketEndpoint)(nil)

// Connect implements the [PacketEndpoint] interface.
func (f FuncPacketEndpoint) Connect(ctx context.Context) (net.Conn, error) {
// ConnectPacket implements the [PacketEndpoint] interface.
func (f FuncPacketEndpoint) ConnectPacket(ctx context.Context) (net.Conn, error) {
return f(ctx)
}

Expand All @@ -60,28 +60,28 @@ type PacketDialerEndpoint struct {

var _ PacketEndpoint = (*PacketDialerEndpoint)(nil)

// Connect implements [PacketEndpoint].Connect.
func (e *PacketDialerEndpoint) Connect(ctx context.Context) (net.Conn, error) {
return e.Dialer.Dial(ctx, e.Address)
// ConnectPacket implements [PacketEndpoint].ConnectPacket.
func (e *PacketDialerEndpoint) ConnectPacket(ctx context.Context) (net.Conn, error) {
return e.Dialer.DialPacket(ctx, e.Address)
}

// PacketDialer provides a way to dial a destination and establish datagram connections.
type PacketDialer interface {
// Dial connects to `addr`.
// DialPacket connects to `addr`.
// `addr` has the form "host:port", where "host" can be a domain name or IP address.
Dial(ctx context.Context, addr string) (net.Conn, error)
DialPacket(ctx context.Context, addr string) (net.Conn, error)
}

// UDPPacketDialer is a [PacketDialer] that uses the standard [net.Dialer] to dial.
// UDPDialer is a [PacketDialer] that uses the standard [net.Dialer] to dial.
// It provides a convenient way to use a [net.Dialer] when you need a [PacketDialer].
type UDPPacketDialer struct {
type UDPDialer struct {
Dialer net.Dialer
}

var _ PacketDialer = (*UDPPacketDialer)(nil)
var _ PacketDialer = (*UDPDialer)(nil)

// Dial implements [PacketDialer].Dial.
func (d *UDPPacketDialer) Dial(ctx context.Context, addr string) (net.Conn, error) {
// DialPacket implements [PacketDialer].DialPacket.
func (d *UDPDialer) DialPacket(ctx context.Context, addr string) (net.Conn, error) {
return d.Dialer.DialContext(ctx, "udp", addr)
}

Expand All @@ -100,12 +100,12 @@ type boundPacketConn struct {

var _ net.Conn = (*boundPacketConn)(nil)

// Dial implements [PacketDialer].Dial.
// DialPacket implements [PacketDialer].DialPacket.
// The address is in "host:port" format and the host must be either a full IP address (not "[::]") or a domain.
// The address must be supported by the WriteTo call of the [net.PacketConn] returned by the [PacketListener].
// For example, a [net.UDPConn] only supports IP addresses, not domain names.
// If the host is a domain name, consider pre-resolving it to avoid resolution calls.
func (e PacketListenerDialer) Dial(ctx context.Context, address string) (net.Conn, error) {
func (e PacketListenerDialer) DialPacket(ctx context.Context, address string) (net.Conn, error) {
packetConn, err := e.Listener.ListenPacket(ctx)
if err != nil {
return nil, fmt.Errorf("could not create PacketConn: %w", err)
Expand Down Expand Up @@ -152,17 +152,17 @@ type PacketListener interface {
ListenPacket(ctx context.Context) (net.PacketConn, error)
}

// UDPPacketListener is a [PacketListener] that uses the standard [net.ListenConfig].ListenPacket to listen.
type UDPPacketListener struct {
// UDPListener is a [PacketListener] that uses the standard [net.ListenConfig].ListenPacket to listen.
type UDPListener struct {
net.ListenConfig
// The local address to bind to, as specified in [net.ListenPacket].
Address string
}

var _ PacketListener = (*UDPPacketListener)(nil)
var _ PacketListener = (*UDPListener)(nil)

// ListenPacket implements [PacketListener].ListenPacket
func (l UDPPacketListener) ListenPacket(ctx context.Context) (net.PacketConn, error) {
func (l UDPListener) ListenPacket(ctx context.Context) (net.PacketConn, error) {
return l.ListenConfig.ListenPacket(ctx, "udp", l.Address)
}

Expand All @@ -171,7 +171,7 @@ type FuncPacketDialer func(ctx context.Context, addr string) (net.Conn, error)

var _ PacketDialer = (*FuncPacketDialer)(nil)

// Dial implements the [PacketDialer] interface.
func (f FuncPacketDialer) Dial(ctx context.Context, addr string) (net.Conn, error) {
// DialPacket implements the [PacketDialer] interface.
func (f FuncPacketDialer) DialPacket(ctx context.Context, addr string) (net.Conn, error) {
return f(ctx, addr)
}
28 changes: 14 additions & 14 deletions transport/packet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestUDPEndpointIPv4(t *testing.T) {
require.Equal(t, serverAddr, address)
return nil
}
conn, err := ep.Connect(context.Background())
conn, err := ep.ConnectPacket(context.Background())
require.NoError(t, err)
assert.Equal(t, "udp", conn.RemoteAddr().Network())
assert.Equal(t, serverAddr, conn.RemoteAddr().String())
Expand All @@ -50,7 +50,7 @@ func TestUDPEndpointIPv6(t *testing.T) {
require.Equal(t, serverAddr, address)
return nil
}
conn, err := ep.Connect(context.Background())
conn, err := ep.ConnectPacket(context.Background())
require.NoError(t, err)
assert.Equal(t, "udp", conn.RemoteAddr().Network())
assert.Equal(t, serverAddr, conn.RemoteAddr().String())
Expand All @@ -64,7 +64,7 @@ func TestUDPEndpointDomain(t *testing.T) {
resolvedAddr = address
return nil
}
conn, err := ep.Connect(context.Background())
conn, err := ep.ConnectPacket(context.Background())
require.NoError(t, err)
assert.Equal(t, "udp", conn.RemoteAddr().Network())
assert.Equal(t, resolvedAddr, conn.RemoteAddr().String())
Expand All @@ -76,7 +76,7 @@ func TestFuncPacketEndpoint(t *testing.T) {
endpoint := FuncPacketEndpoint(func(ctx context.Context) (net.Conn, error) {
return expectedConn, expectedErr
})
conn, err := endpoint.Connect(context.Background())
conn, err := endpoint.ConnectPacket(context.Background())
require.Equal(t, expectedConn, conn)
require.Equal(t, expectedErr, err)
}
Expand All @@ -88,15 +88,15 @@ func TestFuncPacketDialer(t *testing.T) {
require.Equal(t, "unused", addr)
return expectedConn, expectedErr
})
conn, err := dialer.Dial(context.Background(), "unused")
conn, err := dialer.DialPacket(context.Background(), "unused")
require.Equal(t, expectedConn, conn)
require.Equal(t, expectedErr, err)
}

// UDPPacketListener

func TestUDPPacketListenerLocalIPv4Addr(t *testing.T) {
listener := &UDPPacketListener{Address: "127.0.0.1:0"}
listener := &UDPListener{Address: "127.0.0.1:0"}
pc, err := listener.ListenPacket(context.Background())
require.NoError(t, err)
require.Equal(t, "udp", pc.LocalAddr().Network())
Expand All @@ -106,7 +106,7 @@ func TestUDPPacketListenerLocalIPv4Addr(t *testing.T) {
}

func TestUDPPacketListenerLocalIPv6Addr(t *testing.T) {
listener := &UDPPacketListener{Address: "[::1]:0"}
listener := &UDPListener{Address: "[::1]:0"}
pc, err := listener.ListenPacket(context.Background())
require.NoError(t, err)
require.Equal(t, "udp", pc.LocalAddr().Network())
Expand All @@ -116,7 +116,7 @@ func TestUDPPacketListenerLocalIPv6Addr(t *testing.T) {
}

func TestUDPPacketListenerLocalhost(t *testing.T) {
listener := &UDPPacketListener{Address: "localhost:0"}
listener := &UDPListener{Address: "localhost:0"}
pc, err := listener.ListenPacket(context.Background())
require.NoError(t, err)
require.Equal(t, "udp", pc.LocalAddr().Network())
Expand All @@ -126,7 +126,7 @@ func TestUDPPacketListenerLocalhost(t *testing.T) {
}

func TestUDPPacketListenerDefaulAddr(t *testing.T) {
listener := &UDPPacketListener{}
listener := &UDPListener{}
pc, err := listener.ListenPacket(context.Background())
require.Equal(t, "udp", pc.LocalAddr().Network())
require.NoError(t, err)
Expand All @@ -142,8 +142,8 @@ func TestUDPPacketDialer(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "udp", server.LocalAddr().Network())

dialer := &UDPPacketDialer{}
conn, err := dialer.Dial(context.Background(), server.LocalAddr().String())
dialer := &UDPDialer{}
conn, err := dialer.DialPacket(context.Background(), server.LocalAddr().String())
require.NoError(t, err)

request := []byte("PING")
Expand All @@ -169,7 +169,7 @@ func TestPacketListenerDialer(t *testing.T) {
request := []byte("Request")
response := []byte("Response")

serverListener := UDPPacketListener{Address: "127.0.0.1:0"}
serverListener := UDPListener{Address: "127.0.0.1:0"}
serverPacketConn, err := serverListener.ListenPacket(context.Background())
require.NoError(t, err, "Failed to create UDP listener: %v", err)
t.Logf("Listening on %v", serverPacketConn.LocalAddr())
Expand Down Expand Up @@ -202,9 +202,9 @@ func TestPacketListenerDialer(t *testing.T) {
}()

serverEndpoint := &PacketListenerDialer{
Listener: UDPPacketListener{Address: "127.0.0.1:0"},
Listener: UDPListener{Address: "127.0.0.1:0"},
}
conn, err := serverEndpoint.Dial(context.Background(), serverPacketConn.LocalAddr().String())
conn, err := serverEndpoint.DialPacket(context.Background(), serverPacketConn.LocalAddr().String())
require.NoError(t, err)
t.Logf("Connected to %v from %v", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
Expand Down
2 changes: 1 addition & 1 deletion transport/shadowsocks/packet_listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func NewPacketListener(endpoint transport.PacketEndpoint, key *EncryptionKey) (t
}

func (c *packetListener) ListenPacket(ctx context.Context) (net.PacketConn, error) {
proxyConn, err := c.endpoint.Connect(ctx)
proxyConn, err := c.endpoint.ConnectPacket(ctx)
if err != nil {
return nil, fmt.Errorf("could not connect to endpoint: %w", err)
}
Expand Down
Loading

0 comments on commit f09c820

Please sign in to comment.