From 0c0e04dad00a4d17bc23aff36377d9f3ed666ea9 Mon Sep 17 00:00:00 2001 From: Leoyoungxh Date: Thu, 10 Aug 2023 16:12:46 +0800 Subject: [PATCH] Remove go_reuseport module dependency --- admin/admin.go | 2 +- go.mod | 9 +- go.sum | 10 +- http/restful_server_transport.go | 2 +- http/transport.go | 2 +- internal/reuseport/reuseport.go | 45 +++++ internal/reuseport/reuseport_bsd.go | 44 +++++ internal/reuseport/reuseport_linux.go | 51 +++++ internal/reuseport/reuseport_windows.go | 33 ++++ internal/reuseport/tcp.go | 167 ++++++++++++++++ internal/reuseport/tcp_linux_test.go | 57 ++++++ internal/reuseport/tcp_test.go | 226 ++++++++++++++++++++++ internal/reuseport/testdata/EmptyLine.txt | 1 + internal/reuseport/testdata/NoEof.txt | 1 + internal/reuseport/testdata/NumMax.txt | 1 + internal/reuseport/testdata/NumZero.txt | 1 + internal/reuseport/udp.go | 158 +++++++++++++++ internal/reuseport/udp_test.go | 157 +++++++++++++++ transport/server_transport.go | 2 +- transport/tnet/server_transport_tcp.go | 2 +- 20 files changed, 953 insertions(+), 18 deletions(-) create mode 100644 internal/reuseport/reuseport.go create mode 100644 internal/reuseport/reuseport_bsd.go create mode 100644 internal/reuseport/reuseport_linux.go create mode 100644 internal/reuseport/reuseport_windows.go create mode 100644 internal/reuseport/tcp.go create mode 100644 internal/reuseport/tcp_linux_test.go create mode 100644 internal/reuseport/tcp_test.go create mode 100644 internal/reuseport/testdata/EmptyLine.txt create mode 100644 internal/reuseport/testdata/NoEof.txt create mode 100644 internal/reuseport/testdata/NumMax.txt create mode 100644 internal/reuseport/testdata/NumZero.txt create mode 100644 internal/reuseport/udp.go create mode 100644 internal/reuseport/udp_test.go diff --git a/admin/admin.go b/admin/admin.go index 5d2b52d..1fc3d2a 100644 --- a/admin/admin.go +++ b/admin/admin.go @@ -15,7 +15,7 @@ import ( "strings" "sync" - reuseport "trpc.group/trpc-go/go_reuseport" + "trpc.group/trpc-go/trpc-go/internal/reuseport" trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" "trpc.group/trpc-go/trpc-go/config" diff --git a/go.mod b/go.mod index dbacfdb..0e72e3c 100644 --- a/go.mod +++ b/go.mod @@ -28,9 +28,8 @@ require ( golang.org/x/sys v0.4.0 google.golang.org/protobuf v1.30.0 gopkg.in/yaml.v3 v3.0.1 - trpc.group/trpc-go/go_reuseport v1.7.1-0.20230423021710-f5eeff5d87a3 - trpc.group/trpc-go/tnet v0.0.12-0.20230423031524-5eb1cc42f225 - trpc.group/trpc/trpc-protocol/pb/go/trpc v0.1.2-0.20230530025122-c44533fe44bd + trpc.group/trpc-go/tnet v0.0.0-20230810071536-9d05338021cf + trpc.group/trpc/trpc-protocol/pb/go/trpc v0.0.0-20230803031059-de4168eb5952 ) require ( @@ -46,7 +45,3 @@ require ( go.uber.org/multierr v1.6.0 // indirect golang.org/x/text v0.6.0 // indirect ) - -// The hash of current code of v0.11.0 does not match with -// the hash stored in sumdb. -retract v0.11.0 diff --git a/go.sum b/go.sum index 76ceda6..b5aadc8 100644 --- a/go.sum +++ b/go.sum @@ -137,9 +137,7 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -trpc.group/trpc-go/go_reuseport v1.7.1-0.20230423021710-f5eeff5d87a3 h1:FN4gLJ2xxueQ/LYPxfrAwq6hmdHJaAftFZm8qvqI63o= -trpc.group/trpc-go/go_reuseport v1.7.1-0.20230423021710-f5eeff5d87a3/go.mod h1:9yQwBnMw+3ceJd5MPqP18t6y/EAi7fEh59MiFcqeS78= -trpc.group/trpc-go/tnet v0.0.12-0.20230423031524-5eb1cc42f225 h1:sExJeNiDhhQ7dP6P0afYZ+HU1MnnMYhlnDApSmCLu9U= -trpc.group/trpc-go/tnet v0.0.12-0.20230423031524-5eb1cc42f225/go.mod h1:YUGAw+mutXH1ILO6UzoZYeT7G/9l5xmEDXmp/CvaP/E= -trpc.group/trpc/trpc-protocol/pb/go/trpc v0.1.2-0.20230530025122-c44533fe44bd h1:mFEoX5jBSSy/4a8HSGWreuGVaoj8fxc7eDgROTrMc7A= -trpc.group/trpc/trpc-protocol/pb/go/trpc v0.1.2-0.20230530025122-c44533fe44bd/go.mod h1:K+a1K/Gnlcg9BFHWx30vLBIEDhxODhl25gi1JjA54CQ= +trpc.group/trpc-go/tnet v0.0.0-20230810071536-9d05338021cf h1:Qo0p6ZJV60Qd5XajiIDidVgx1NDM9UHL7DzDKc2gqns= +trpc.group/trpc-go/tnet v0.0.0-20230810071536-9d05338021cf/go.mod h1:s/webUFYWEFBHErKyFmj7LYC7XfC2LTLCcwfSnJ04M0= +trpc.group/trpc/trpc-protocol/pb/go/trpc v0.0.0-20230803031059-de4168eb5952 h1:AhjP72IKa1YKnSIayk1X5xSzKrem0EanjZ7oMc2HYOw= +trpc.group/trpc/trpc-protocol/pb/go/trpc v0.0.0-20230803031059-de4168eb5952/go.mod h1:K+a1K/Gnlcg9BFHWx30vLBIEDhxODhl25gi1JjA54CQ= diff --git a/http/restful_server_transport.go b/http/restful_server_transport.go index 96daf10..09bcb96 100644 --- a/http/restful_server_transport.go +++ b/http/restful_server_transport.go @@ -13,7 +13,7 @@ import ( "time" "github.com/valyala/fasthttp" - reuseport "trpc.group/trpc-go/go_reuseport" + "trpc.group/trpc-go/trpc-go/internal/reuseport" trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" "trpc.group/trpc-go/trpc-go/codec" diff --git a/http/transport.go b/http/transport.go index b736d99..95ac9f4 100644 --- a/http/transport.go +++ b/http/transport.go @@ -22,7 +22,7 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" - reuseport "trpc.group/trpc-go/go_reuseport" + "trpc.group/trpc-go/trpc-go/internal/reuseport" trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" "trpc.group/trpc-go/trpc-go/codec" diff --git a/internal/reuseport/reuseport.go b/internal/reuseport/reuseport.go new file mode 100644 index 0000000..23b6e4d --- /dev/null +++ b/internal/reuseport/reuseport.go @@ -0,0 +1,45 @@ +//go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd +// +build linux darwin dragonfly freebsd netbsd openbsd + +// Package reuseport provides a function that returns a net.Listener powered +// by a net.FileListener with a SO_REUSEPORT option set to the socket. +package reuseport + +import ( + "errors" + "fmt" + "net" + "os" + "syscall" +) + +const fileNameTemplate = "reuseport.%d.%s.%s" + +var errUnsupportedProtocol = errors.New("only tcp, tcp4, tcp6, udp, udp4, udp6 are supported") + +// getSockaddr parses protocol and address and returns implementor +// of syscall.Sockaddr: syscall.SockaddrInet4 or syscall.SockaddrInet6. +func getSockaddr(proto, addr string) (sa syscall.Sockaddr, soType int, err error) { + switch proto { + case "tcp", "tcp4", "tcp6": + return getTCPSockaddr(proto, addr) + case "udp", "udp4", "udp6": + return getUDPSockaddr(proto, addr) + default: + return nil, -1, errUnsupportedProtocol + } +} + +func getSocketFileName(proto, addr string) string { + return fmt.Sprintf(fileNameTemplate, os.Getpid(), proto, addr) +} + +// Listen function is an alias for NewReusablePortListener. +func Listen(proto, addr string) (l net.Listener, err error) { + return NewReusablePortListener(proto, addr) +} + +// ListenPacket is an alias for NewReusablePortPacketConn. +func ListenPacket(proto, addr string) (l net.PacketConn, err error) { + return NewReusablePortPacketConn(proto, addr) +} diff --git a/internal/reuseport/reuseport_bsd.go b/internal/reuseport/reuseport_bsd.go new file mode 100644 index 0000000..86dc450 --- /dev/null +++ b/internal/reuseport/reuseport_bsd.go @@ -0,0 +1,44 @@ +//go:build darwin || dragonfly || freebsd || netbsd || openbsd +// +build darwin dragonfly freebsd netbsd openbsd + +package reuseport + +import ( + "runtime" + "syscall" +) + +var reusePort = syscall.SO_REUSEPORT + +func maxListenerBacklog() int { + var ( + n uint32 + err error + ) + + switch runtime.GOOS { + case "darwin", "freebsd": + n, err = syscall.SysctlUint32("kern.ipc.somaxconn") + case "netbsd": + // NOTE: NetBSD has no somaxconn-like kernel state so far + case "openbsd": + n, err = syscall.SysctlUint32("kern.somaxconn") + default: + } + + return defaultBacklog(n, err) +} + +func defaultBacklog(n uint32, err error) int { + if n == 0 || err != nil { + return syscall.SOMAXCONN + } + + // FreeBSD stores the backlog in a uint16, as does Linux. + // Assume the other BSDs do too. Truncate number to avoid wrapping. + // See issue 5030. + if n > 1<<16-1 { + n = 1<<16 - 1 + } + return int(n) +} diff --git a/internal/reuseport/reuseport_linux.go b/internal/reuseport/reuseport_linux.go new file mode 100644 index 0000000..165c66d --- /dev/null +++ b/internal/reuseport/reuseport_linux.go @@ -0,0 +1,51 @@ +//go:build linux +// +build linux + +package reuseport + +import ( + "bufio" + "os" + "strconv" + "strings" + "syscall" +) + +var reusePort = 0x0F +var maxConnFileName = "/proc/sys/net/core/somaxconn" + +func maxListenerBacklog() int { + fd, err := os.Open(maxConnFileName) + if err != nil { + return syscall.SOMAXCONN + } + defer fd.Close() + + rd := bufio.NewReader(fd) + line, err := rd.ReadString('\n') + if err != nil { + return syscall.SOMAXCONN + } + + f := strings.Fields(line) + if len(f) < 1 { + return syscall.SOMAXCONN + } + + n, err := strconv.Atoi(f[0]) + return defaultBacklog(uint32(n), err) +} + +func defaultBacklog(n uint32, err error) int { + if n == 0 || err != nil { + return syscall.SOMAXCONN + } + + // Linux stores the backlog in a uint16. + // Truncate number to avoid wrapping. + // See issue 5030. + if n > 1<<16-1 { + n = 1<<16 - 1 + } + return int(n) +} diff --git a/internal/reuseport/reuseport_windows.go b/internal/reuseport/reuseport_windows.go new file mode 100644 index 0000000..94d0cfd --- /dev/null +++ b/internal/reuseport/reuseport_windows.go @@ -0,0 +1,33 @@ +//go:build windows +// +build windows + +package reuseport + +import ( + "net" + "syscall" +) + +var ListenerBacklogMaxSize = maxListenerBacklog() + +func maxListenerBacklog() int { + return syscall.SOMAXCONN +} + +func NewReusablePortListener(proto, addr string) (net.Listener, error) { + return net.Listen(proto, addr) +} + +func NewReusablePortPacketConn(proto, addr string) (net.PacketConn, error) { + return net.ListenPacket(proto, addr) +} + +// Listen function is an alias for NewReusablePortListener. +func Listen(proto, addr string) (l net.Listener, err error) { + return NewReusablePortListener(proto, addr) +} + +// ListenPacket is an alias for NewReusablePortPacketConn. +func ListenPacket(proto, addr string) (l net.PacketConn, err error) { + return NewReusablePortPacketConn(proto, addr) +} diff --git a/internal/reuseport/tcp.go b/internal/reuseport/tcp.go new file mode 100644 index 0000000..cb57410 --- /dev/null +++ b/internal/reuseport/tcp.go @@ -0,0 +1,167 @@ +//go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd +// +build linux darwin dragonfly freebsd netbsd openbsd + +package reuseport + +import ( + "errors" + "net" + "os" + "syscall" +) + +var ( + // ListenerBacklogMaxSize setting backlog size + ListenerBacklogMaxSize = maxListenerBacklog() + errUnsupportedTCPProtocol = errors.New("only tcp, tcp4, tcp6 are supported") +) + +func getTCP4Sockaddr(tcp *net.TCPAddr) (syscall.Sockaddr, int, error) { + sa := &syscall.SockaddrInet4{Port: tcp.Port} + + if tcp.IP != nil { + if len(tcp.IP) == 16 { + copy(sa.Addr[:], tcp.IP[12:16]) // copy last 4 bytes of slice to array + } else { + copy(sa.Addr[:], tcp.IP) // copy all bytes of slice to array + } + } + + return sa, syscall.AF_INET, nil +} + +func getTCP6Sockaddr(tcp *net.TCPAddr) (syscall.Sockaddr, int, error) { + sa := &syscall.SockaddrInet6{Port: tcp.Port} + + if tcp.IP != nil { + copy(sa.Addr[:], tcp.IP) // copy all bytes of slice to array + } + + if tcp.Zone != "" { + iface, err := net.InterfaceByName(tcp.Zone) + if err != nil { + return nil, -1, err + } + + sa.ZoneId = uint32(iface.Index) + } + + return sa, syscall.AF_INET6, nil +} + +func getTCPAddr(proto, addr string) (*net.TCPAddr, string, error) { + var tcp *net.TCPAddr + + // fix bugs https://github.com/kavu/go_reuseport/pull/33 + tcp, err := net.ResolveTCPAddr(proto, addr) + if err != nil { + return nil, "", err + } + + tcpVersion, err := determineTCPProto(proto, tcp) + if err != nil { + return nil, "", err + } + return tcp, tcpVersion, nil +} + +func getTCPSockaddr(proto, addr string) (sa syscall.Sockaddr, soType int, err error) { + tcp, tcpVersion, err := getTCPAddr(proto, addr) + if err != nil { + return nil, -1, err + } + switch tcpVersion { + case "tcp": + return &syscall.SockaddrInet4{Port: tcp.Port}, syscall.AF_INET, nil + case "tcp4": + return getTCP4Sockaddr(tcp) + default: + // must be "tcp6" + return getTCP6Sockaddr(tcp) + } +} + +func determineTCPProto(proto string, ip *net.TCPAddr) (string, error) { + // If the protocol is set to "tcp", we try to determine the actual protocol + // version from the size of the resolved IP address. Otherwise, we simple use + // the protcol given to us by the caller. + + if ip.IP.To4() != nil { + return "tcp4", nil + } + + if ip.IP.To16() != nil { + return "tcp6", nil + } + + switch proto { + case "tcp", "tcp4", "tcp6": + return proto, nil + default: + return "", errUnsupportedTCPProtocol + } +} + +// NewReusablePortListener returns net.FileListener that created from +// a file discriptor for a socket with SO_REUSEPORT option. +func NewReusablePortListener(proto, addr string) (l net.Listener, err error) { + var ( + soType, fd int + sockaddr syscall.Sockaddr + ) + if sockaddr, soType, err = getSockaddr(proto, addr); err != nil { + return nil, err + } + + syscall.ForkLock.RLock() + if fd, err = syscall.Socket(soType, syscall.SOCK_STREAM, syscall.IPPROTO_TCP); err != nil { + syscall.ForkLock.RUnlock() + return nil, err + } + syscall.ForkLock.RUnlock() + + if err = createReusableFd(fd, sockaddr); err != nil { + return nil, err + } + return createReusableListener(fd, proto, addr) +} + +func createReusableListener(fd int, proto, addr string) (l net.Listener, err error) { + file := os.NewFile(uintptr(fd), getSocketFileName(proto, addr)) + if l, err = net.FileListener(file); err != nil { + file.Close() + return nil, err + } + + if err = file.Close(); err != nil { + return nil, err + } + return l, err +} + +func createReusableFd(fd int, sockaddr syscall.Sockaddr) (err error) { + defer func() { + if err != nil { + syscall.Close(fd) + } + }() + + if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil { + return err + } + + if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, reusePort, 1); err != nil { + return err + } + + if err = syscall.Bind(fd, sockaddr); err != nil { + return err + } + + // Set backlog size to the maximum + if err = syscall.Listen(fd, ListenerBacklogMaxSize); err != nil { + return err + } + + return nil +} diff --git a/internal/reuseport/tcp_linux_test.go b/internal/reuseport/tcp_linux_test.go new file mode 100644 index 0000000..6621f4b --- /dev/null +++ b/internal/reuseport/tcp_linux_test.go @@ -0,0 +1,57 @@ +//go:build linux +// +build linux + +package reuseport + +import ( + "syscall" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMaxListenerBackLog(t *testing.T) { + oldMaxConnFileName := maxConnFileName + defer func() { + maxConnFileName = oldMaxConnFileName + }() + + tests := []struct { + name string + fileName string + want int + }{ + { + name: "file not exist", + fileName: "./testdata/NotExistFile.txt", + want: syscall.SOMAXCONN, + }, + { + name: "file content invalid, no eof", + fileName: "./testdata/NoEof.txt", + want: syscall.SOMAXCONN, + }, + { + name: "empty line", + fileName: "./testdata/EmptyLine.txt", + want: syscall.SOMAXCONN, + }, + { + name: "num zero", + fileName: "./testdata/NumZero.txt", + want: syscall.SOMAXCONN, + }, + { + name: "num 65536", + fileName: "./testdata/NumMax.txt", + want: 65535, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + maxConnFileName = tt.fileName + assert.Equalf(t, tt.want, maxListenerBacklog(), "maxListenerBacklog()") + }) + } +} diff --git a/internal/reuseport/tcp_test.go b/internal/reuseport/tcp_test.go new file mode 100644 index 0000000..5a98b6d --- /dev/null +++ b/internal/reuseport/tcp_test.go @@ -0,0 +1,226 @@ +//go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd +// +build linux darwin dragonfly freebsd netbsd openbsd + +package reuseport + +import ( + "fmt" + "html" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "os" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" +) + +const ( + httpServerOneResponse = "1" + httpServerTwoResponse = "2" +) + +var ( + httpServerOne = NewHTTPServer(httpServerOneResponse) + httpServerTwo = NewHTTPServer(httpServerTwoResponse) +) + +func NewHTTPServer(resp string) *httptest.Server { + return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, resp) + })) +} + +func TestNewReusablePortListener(t *testing.T) { + listenerOne, err := NewReusablePortListener("tcp4", "localhost:10081") + assert.Nil(t, err) + defer listenerOne.Close() + + listenerTwo, err := NewReusablePortListener("tcp", "127.0.0.1:10081") + assert.Nil(t, err) + defer listenerTwo.Close() + + // devcloud ipv6地址无效 + _, err = NewReusablePortListener("tcp6", "[::x]:10081") + if err == nil { + t.Errorf("expect err, err[%v]", err) + } + + listenerFour, err := NewReusablePortListener("tcp6", ":10081") + assert.Nil(t, err) + defer listenerFour.Close() + + listenerFive, err := NewReusablePortListener("tcp4", ":10081") + assert.Nil(t, err) + defer listenerFive.Close() + + listenerSix, err := NewReusablePortListener("tcp", ":10081") + assert.Nil(t, err) + defer listenerSix.Close() + + // proto invalid 非法协议 + _, err = NewReusablePortListener("xxx", "") + if err == nil { + t.Errorf("expect err") + } +} + +func TestListen(t *testing.T) { + listenerOne, err := Listen("tcp4", "localhost:10081") + assert.Nil(t, err) + defer listenerOne.Close() + + listenerTwo, err := Listen("tcp", "127.0.0.1:10081") + assert.Nil(t, err) + defer listenerTwo.Close() + + listenerThree, err := Listen("tcp6", ":10081") + assert.Nil(t, err) + defer listenerThree.Close() + + listenerFour, err := Listen("tcp6", ":10081") + assert.Nil(t, err) + defer listenerFour.Close() + + listenerFive, err := Listen("tcp4", ":10081") + assert.Nil(t, err) + defer listenerFive.Close() + + listenerSix, err := Listen("tcp", ":10081") + assert.Nil(t, err) + defer listenerSix.Close() +} + +func TestNewReusablePortServers(t *testing.T) { + listenerOne, err := NewReusablePortListener("tcp4", "localhost:10081") + assert.Nil(t, err) + defer listenerOne.Close() + + //listenerTwo, err := NewReusablePortListener("tcp6", ":10081") + listenerTwo, err := NewReusablePortListener("tcp", "localhost:10081") + assert.Nil(t, err) + defer listenerTwo.Close() + + httpServerOne.Listener = listenerOne + httpServerTwo.Listener = listenerTwo + + httpServerOne.Start() + httpServerTwo.Start() + + // Server One — First Response + httpGet(httpServerOne.URL, httpServerOneResponse, httpServerTwoResponse, t) + + // Server Two — First Response + httpGet(httpServerTwo.URL, httpServerOneResponse, httpServerTwoResponse, t) + httpServerTwo.Close() + + // Server One — Second Response + httpGet(httpServerOne.URL, httpServerOneResponse, "", t) + + // Server One — Third Response + httpGet(httpServerOne.URL, httpServerOneResponse, "", t) + httpServerOne.Close() +} + +func httpGet(url string, expected1 string, expected2 string, t *testing.T) { + resp, err := http.Get(url) + assert.Nil(t, err) + body, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + assert.Nil(t, err) + if string(body) != expected1 && string(body) != expected2 { + t.Errorf("Expected %#v or %#v, got %#v.", expected1, expected2, string(body)) + } +} + +func BenchmarkNewReusablePortListener(b *testing.B) { + for i := 0; i < b.N; i++ { + listener, err := NewReusablePortListener("tcp", ":10081") + + if err != nil { + b.Error(err) + } else { + listener.Close() + } + } +} + +func ExampleNewReusablePortListener() { + listener, err := NewReusablePortListener("tcp", ":8881") + if err != nil { + panic(err) + } + defer listener.Close() + + server := &http.Server{} + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + fmt.Println(os.Getgid()) + fmt.Fprintf(w, "Hello, %q\n", html.EscapeString(r.URL.Path)) + }) + + panic(server.Serve(listener)) +} + +// TestBoundaryCase 一些边界条件覆盖 +func TestBoundaryCase(t *testing.T) { + proto, err := determineTCPProto("tcp", &net.TCPAddr{}) + if proto != "tcp" { + t.Errorf("proto not tcp") + } + assert.Nil(t, err) + _, err = determineTCPProto("udp", &net.TCPAddr{}) + if err == nil { + t.Errorf("expect error") + } + + //getTCPAddr 边界 + if _, _, err := getTCPAddr("udp", "localhost:8001"); err == nil { + t.Error("expect error") + } + + //ipv6 zone id,不存在的网卡 + addr := &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Zone: "ethx", + } + _, _, err = getTCP6Sockaddr(addr) + assert.NotNil(t, err) + + //udp ipv6 + udpAddr := &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Zone: "ethx", + } + _, _, err = getUDP6Sockaddr(udpAddr) + assert.NotNil(t, err) + + //ResolveUDPAddr failed + _, _, err = getUDPSockaddr("xxx", ":10086") + assert.NotNil(t, err) +} + +func TestCreateReusableFd(t *testing.T) { + fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) + assert.Nil(t, err) + assert.NotZero(t, fd) + + // set opt failed, bad fd: -1 + sa := &syscall.SockaddrInet4{} + err = createReusableFd(-1, sa) + assert.NotNil(t, err) + + // set opt failed + oldReusePort := reusePort + defer func() { + reusePort = oldReusePort + }() + reusePort = 0 + err = createReusableFd(fd, sa) + assert.NotNil(t, err) + + // file descriptor invalid + _, err = createReusableListener(10081, "tcp", "localhost:8001") + assert.NotNil(t, err) +} diff --git a/internal/reuseport/testdata/EmptyLine.txt b/internal/reuseport/testdata/EmptyLine.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/internal/reuseport/testdata/EmptyLine.txt @@ -0,0 +1 @@ + diff --git a/internal/reuseport/testdata/NoEof.txt b/internal/reuseport/testdata/NoEof.txt new file mode 100644 index 0000000..f3e53ee --- /dev/null +++ b/internal/reuseport/testdata/NoEof.txt @@ -0,0 +1 @@ +2048 \ No newline at end of file diff --git a/internal/reuseport/testdata/NumMax.txt b/internal/reuseport/testdata/NumMax.txt new file mode 100644 index 0000000..e2ed8f4 --- /dev/null +++ b/internal/reuseport/testdata/NumMax.txt @@ -0,0 +1 @@ +65536 diff --git a/internal/reuseport/testdata/NumZero.txt b/internal/reuseport/testdata/NumZero.txt new file mode 100644 index 0000000..573541a --- /dev/null +++ b/internal/reuseport/testdata/NumZero.txt @@ -0,0 +1 @@ +0 diff --git a/internal/reuseport/udp.go b/internal/reuseport/udp.go new file mode 100644 index 0000000..526c0d6 --- /dev/null +++ b/internal/reuseport/udp.go @@ -0,0 +1,158 @@ +//go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd +// +build linux darwin dragonfly freebsd netbsd openbsd + +package reuseport + +import ( + "errors" + "net" + "os" + "syscall" +) + +var errUnsupportedUDPProtocol = errors.New("only udp, udp4, udp6 are supported") + +func getUDP4Sockaddr(udp *net.UDPAddr) (syscall.Sockaddr, int, error) { + sa := &syscall.SockaddrInet4{Port: udp.Port} + + if udp.IP != nil { + if len(udp.IP) == 16 { + copy(sa.Addr[:], udp.IP[12:16]) // copy last 4 bytes of slice to array + } else { + copy(sa.Addr[:], udp.IP) // copy all bytes of slice to array + } + } + + return sa, syscall.AF_INET, nil +} + +func getUDP6Sockaddr(udp *net.UDPAddr) (syscall.Sockaddr, int, error) { + sa := &syscall.SockaddrInet6{Port: udp.Port} + + if udp.IP != nil { + copy(sa.Addr[:], udp.IP) // copy all bytes of slice to array + } + + if udp.Zone != "" { + iface, err := net.InterfaceByName(udp.Zone) + if err != nil { + return nil, -1, err + } + + sa.ZoneId = uint32(iface.Index) + } + + return sa, syscall.AF_INET6, nil +} + +func getUDPAddr(proto, addr string) (*net.UDPAddr, string, error) { + + var udp *net.UDPAddr + + udp, err := net.ResolveUDPAddr(proto, addr) + if err != nil { + return nil, "", err + } + + udpVersion, err := determineUDPProto(proto, udp) + if err != nil { + return nil, "", err + } + + return udp, udpVersion, nil +} + +func getUDPSockaddr(proto, addr string) (sa syscall.Sockaddr, soType int, err error) { + udp, udpVersion, err := getUDPAddr(proto, addr) + if err != nil { + return nil, -1, err + } + + switch udpVersion { + case "udp": + return &syscall.SockaddrInet4{Port: udp.Port}, syscall.AF_INET, nil + case "udp4": + return getUDP4Sockaddr(udp) + default: + // must be "udp6" + return getUDP6Sockaddr(udp) + } +} + +func determineUDPProto(proto string, ip *net.UDPAddr) (string, error) { + // If the protocol is set to "udp", we try to determine the actual protocol + // version from the size of the resolved IP address. Otherwise, we simple use + // the protcol given to us by the caller. + + if ip.IP.To4() != nil { + return "udp4", nil + } + + if ip.IP.To16() != nil { + return "udp6", nil + } + + switch proto { + case "udp", "udp4", "udp6": + return proto, nil + default: + return "", errUnsupportedUDPProtocol + } +} + +// NewReusablePortPacketConn returns net.FilePacketConn that created from +// a file discriptor for a socket with SO_REUSEPORT option. +func NewReusablePortPacketConn(proto, addr string) (net.PacketConn, error) { + sockaddr, soType, err := getSockaddr(proto, addr) + if err != nil { + return nil, err + } + + syscall.ForkLock.RLock() + fd, err := syscall.Socket(soType, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP) + if err == nil { + syscall.CloseOnExec(fd) + } + syscall.ForkLock.RUnlock() + if err != nil { + syscall.Close(fd) + return nil, err + } + return createPacketConn(fd, sockaddr, getSocketFileName(proto, addr)) +} + +func createPacketConn(fd int, sockaddr syscall.Sockaddr, fdName string) (net.PacketConn, error) { + if err := setPacketConnSockOpt(fd, sockaddr); err != nil { + syscall.Close(fd) + return nil, err + } + + file := os.NewFile(uintptr(fd), fdName) + l, err := net.FilePacketConn(file) + if err != nil { + syscall.Close(fd) + return nil, err + } + + if err = file.Close(); err != nil { + syscall.Close(fd) + return nil, err + } + return l, err +} + +func setPacketConnSockOpt(fd int, sockaddr syscall.Sockaddr) error { + if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil { + return err + } + + if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, reusePort, 1); err != nil { + return err + } + + if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1); err != nil { + return err + } + + return syscall.Bind(fd, sockaddr) +} diff --git a/internal/reuseport/udp_test.go b/internal/reuseport/udp_test.go new file mode 100644 index 0000000..1248441 --- /dev/null +++ b/internal/reuseport/udp_test.go @@ -0,0 +1,157 @@ +//go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd +// +build linux darwin dragonfly freebsd netbsd openbsd + +package reuseport + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func moreCaseNewReusablePortPacketConn(t *testing.T) { + listenerFour, err := NewReusablePortListener("udp6", ":10081") + assert.Nil(t, err) + defer listenerFour.Close() + + listenerFive, err := NewReusablePortListener("udp4", ":10081") + assert.Nil(t, err) + defer listenerFive.Close() + + listenerSix, err := NewReusablePortListener("udp", ":10081") + assert.Nil(t, err) + defer listenerSix.Close() +} + +func TestNewReusablePortPacketConn(t *testing.T) { + listenerOne, err := NewReusablePortPacketConn("udp4", "localhost:10082") + assert.Nil(t, err) + defer listenerOne.Close() + + listenerTwo, err := NewReusablePortPacketConn("udp", "127.0.0.1:10082") + assert.Nil(t, err) + defer listenerTwo.Close() + + listenerThree, err := NewReusablePortPacketConn("udp6", ":10082") + assert.Nil(t, err) + defer listenerThree.Close() + + moreCaseNewReusablePortPacketConn(t) +} + +func TestListenPacket(t *testing.T) { + type args struct { + proto string + addr string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "case1", + args: args{ + proto: "udp4", + addr: "localhost:10082", + }, + wantErr: false, + }, + { + name: "case2", + args: args{ + proto: "udp", + addr: "localhost:10082", + }, + wantErr: false, + }, + { + name: "case3", + args: args{ + proto: "udp6", + addr: ":10082", + }, + wantErr: false, + }, + { + name: "case4", + args: args{ + proto: "udp4", + addr: ":10081", + }, + wantErr: false, + }, + { + name: "case5", + args: args{ + proto: "udp6", + addr: ":10081", + }, + wantErr: false, + }, + { + name: "case6", + args: args{ + proto: "udp", + addr: ":10081", + }, + wantErr: false, + }, + { + name: "case7", + args: args{ + proto: "udp6_no_ipv_device", + addr: "[::1]:10081", + }, + wantErr: true, + }, + { + name: "case8_not_support_proto", + args: args{ + proto: "xxx", + addr: "[::1]:10081", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotL, err := ListenPacket(tt.args.proto, tt.args.addr) + if (err != nil) != tt.wantErr { + t.Errorf("ListenPacket() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotL != nil { + _ = gotL.Close() + } + }) + } +} + +func BenchmarkNewReusableUDPPortListener(b *testing.B) { + for i := 0; i < b.N; i++ { + listener, err := NewReusablePortPacketConn("udp4", "localhost:10082") + + if err != nil { + b.Error(err) + } else { + listener.Close() + } + } +} + +// TestNewReusablePortPacketConn2 一些边界覆盖 +func TestNewReusablePortPacketConn2(t *testing.T) { + // new socket fd failed, unsupported protocol + _, err := NewReusablePortPacketConn("udp4xx", "localhost:10082") + assert.NotNil(t, err) + + // reusePort failed + oldReusePort := reusePort + defer func() { + reusePort = oldReusePort + }() + reusePort = 0 + _, err = NewReusablePortPacketConn("udp4", "localhost:10082") + assert.NotNil(t, err) +} diff --git a/transport/server_transport.go b/transport/server_transport.go index 9e8abfa..9cb6618 100644 --- a/transport/server_transport.go +++ b/transport/server_transport.go @@ -16,7 +16,7 @@ import ( "time" "github.com/panjf2000/ants/v2" - reuseport "trpc.group/trpc-go/go_reuseport" + "trpc.group/trpc-go/trpc-go/internal/reuseport" itls "trpc.group/trpc-go/trpc-go/internal/tls" "trpc.group/trpc-go/trpc-go/log" diff --git a/transport/tnet/server_transport_tcp.go b/transport/tnet/server_transport_tcp.go index 89e439b..30ee2ad 100644 --- a/transport/tnet/server_transport_tcp.go +++ b/transport/tnet/server_transport_tcp.go @@ -15,9 +15,9 @@ import ( "time" "github.com/panjf2000/ants/v2" - reuseport "trpc.group/trpc-go/go_reuseport" "trpc.group/trpc-go/tnet" "trpc.group/trpc-go/tnet/tls" + "trpc.group/trpc-go/trpc-go/internal/reuseport" "trpc.group/trpc-go/trpc-go/codec" "trpc.group/trpc-go/trpc-go/errs"