From 9a288f7478c6957676be7a24f20b80949da3f2fe Mon Sep 17 00:00:00 2001 From: Petr Zhizhin Date: Sun, 10 Nov 2024 01:23:18 +0100 Subject: [PATCH] Allow for a connection to check if it's sending bytes --- x/sockopt/is_sending_bytes_linux.go | 42 +++++++++++++++ x/sockopt/is_sending_bytes_not_implemented.go | 17 ++++++ x/sockopt/sockopt.go | 53 ++++++++++++++++++- 3 files changed, 111 insertions(+), 1 deletion(-) create mode 100644 x/sockopt/is_sending_bytes_linux.go create mode 100644 x/sockopt/is_sending_bytes_not_implemented.go diff --git a/x/sockopt/is_sending_bytes_linux.go b/x/sockopt/is_sending_bytes_linux.go new file mode 100644 index 00000000..af0b2eb4 --- /dev/null +++ b/x/sockopt/is_sending_bytes_linux.go @@ -0,0 +1,42 @@ +//go:build linux + +package sockopt + +import ( + "net" + + "golang.org/x/sys/unix" +) + +func isSocketFdSendingBytes(fd int) (bool, error) { + tcpInfo, err := unix.GetsockoptTCPInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_INFO) + if err != nil { + return false, err + } + + // 1 == TCP_ESTABLISHED, but for some reason not available in the package + if tcpInfo.State != 1 { + // If the connection is not established, the socket is not sending bytes + return false, nil + } + + return tcpInfo.Notsent_bytes != 0, nil +} + +func isConnectionSendingBytesImplemented() bool { + return true +} + +func isConnectionSendingBytes(conn *net.TCPConn) (bool, error) { + syscallConn, err := conn.SyscallConn() + if err != nil { + return false, err + } + var result bool + syscallConn.Control(func(fd uintptr) { + innerResult, innerErr := isSocketFdSendingBytes(int(fd)) + result = innerResult + err = innerErr + }) + return result, err +} diff --git a/x/sockopt/is_sending_bytes_not_implemented.go b/x/sockopt/is_sending_bytes_not_implemented.go new file mode 100644 index 00000000..b76024d4 --- /dev/null +++ b/x/sockopt/is_sending_bytes_not_implemented.go @@ -0,0 +1,17 @@ +//go:build !linux + +package sockopt + +import ( + "errors" + "fmt" + "net" +) + +func isConnectionSendingBytesImplemented() bool { + return false +} + +func isConnectionSendingBytes(_ *net.TCPConn) (bool, error) { + return false, fmt.Errorf("%w: checking if socket is sending bytes is not implemented on this platform", errors.ErrUnsupported) +} diff --git a/x/sockopt/sockopt.go b/x/sockopt/sockopt.go index 44cb575e..e22744d6 100644 --- a/x/sockopt/sockopt.go +++ b/x/sockopt/sockopt.go @@ -19,11 +19,21 @@ import ( "fmt" "net" "net/netip" + "time" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) +type HasWaitUntilBytesAreSent interface { + // Wait until all bytes are sent to the socket. + // Returns ErrUnsupported if the platform doesn't support it. + // May return a different error. + WaitUntilBytesAreSent() error + // Checks if the OS supports waiting until the bytes are sent + OsSupportsWaitingUntilBytesAreSent() bool +} + // HasHopLimit enables manipulation of the hop limit option. type HasHopLimit interface { // HopLimit returns the hop limit field value for outgoing packets. @@ -50,15 +60,51 @@ var _ HasHopLimit = (*hopLimitOption)(nil) // TCPOptions represents options for TCP connections. type TCPOptions interface { + HasWaitUntilBytesAreSent HasHopLimit } type tcpOptions struct { hopLimitOption + + conn *net.TCPConn + + // Timeout after which we return an error + waitingTimeout time.Duration + // Delay between checking the socket + waitingDelay time.Duration } var _ TCPOptions = (*tcpOptions)(nil) +func (o *tcpOptions) SetWaitingTimeout(timeout time.Duration) { + o.waitingTimeout = timeout +} + +func (o *tcpOptions) SetWaitingDelay(delay time.Duration) { + o.waitingDelay = delay +} + +func (o *tcpOptions) OsSupportsWaitingUntilBytesAreSent() bool { + return isConnectionSendingBytesImplemented() +} + +func (o *tcpOptions) WaitUntilBytesAreSent() error { + startTime := time.Now() + for time.Since(startTime) < o.waitingTimeout { + isSendingBytes, err := isConnectionSendingBytes(o.conn) + if err != nil { + return err + } + if !isSendingBytes { + return nil + } + + time.Sleep(o.waitingDelay) + } + return fmt.Errorf("waiting for socket to send all bytes: timeout exceeded") +} + // newHopLimit creates a hopLimitOption from a [net.Conn]. Works for both TCP or UDP. func newHopLimit(conn net.Conn) (*hopLimitOption, error) { addr, err := netip.ParseAddrPort(conn.LocalAddr().String()) @@ -87,5 +133,10 @@ func NewTCPOptions(conn *net.TCPConn) (TCPOptions, error) { if err != nil { return nil, err } - return &tcpOptions{hopLimitOption: *hopLimit}, nil + return &tcpOptions{ + hopLimitOption: *hopLimit, + conn: conn, + waitingTimeout: 10 * time.Millisecond, + waitingDelay: 100 * time.Microsecond, + }, nil }