diff --git a/acknowledge.go b/acknowledge.go index e264704..fd6ca25 100644 --- a/acknowledge.go +++ b/acknowledge.go @@ -97,9 +97,8 @@ func (ack *acknowledgement) read(b []byte) error { if len(b) < 2 { return io.ErrUnexpectedEOF } - n := binary.BigEndian.Uint16(b) offset := 2 - for i := uint16(0); i < n; i++ { + for range binary.BigEndian.Uint16(b) { if len(b)-offset < 4 { return io.ErrUnexpectedEOF } diff --git a/conn.go b/conn.go index a43322b..82b21d9 100644 --- a/conn.go +++ b/conn.go @@ -162,7 +162,6 @@ func (conn *Conn) startTicking() { conn.mu.Lock() acksLeft = len(conn.retransmission.unacknowledged) conn.mu.Unlock() - if before != 0 && acksLeft == 0 { conn.closeImmediately() } @@ -383,13 +382,8 @@ func (conn *Conn) send(pk encoding.BinaryMarshaler) error { return err } -// packetPool is a sync.Pool used to pool packets that encapsulate their -// content. -var packetPool = sync.Pool{ - New: func() interface{} { - return &packet{reliability: reliabilityReliableOrdered} - }, -} +// packetPool is used to pool packets that encapsulate their content. +var packetPool = sync.Pool{New: func() any { return &packet{reliability: reliabilityReliableOrdered} }} // receive receives a packet from the connection, handling it as appropriate. // If not successful, an error is returned. @@ -422,13 +416,13 @@ func (conn *Conn) receiveDatagram(b []byte) error { conn.ackSlice = append(conn.ackSlice, seq) conn.ackMu.Unlock() - if !conn.win.new(seq) { - // Datagram was already received, this might happen if a packet took a long time to arrive, and we already sent - // a NACK for it. This is expected to happen sometimes under normal circumstances, so no reason to return an - // error. + if !conn.win.add(seq) { + // Datagram was already received, this might happen if a packet took a + // long time to arrive, and we already sent a NACK for it. This is + // expected to happen sometimes under normal circumstances, so no reason + // to return an error. return nil } - conn.win.add(seq) if conn.win.shift() == 0 { // Datagram window couldn't be shifted up, so we're still missing // packets. diff --git a/datagram_window.go b/datagram_window.go index 4c95b9c..ec1b1c8 100644 --- a/datagram_window.go +++ b/datagram_window.go @@ -15,21 +15,23 @@ func newDatagramWindow() *datagramWindow { return &datagramWindow{queue: make(map[uint24]time.Time)} } -// new checks if the index passed is new to the datagramWindow. -func (win *datagramWindow) new(index uint24) bool { - if index < win.lowest { - return true +// add puts an index in the window. +func (win *datagramWindow) add(index uint24) bool { + if win.seen(index) { + return false } - _, ok := win.queue[index] - return !ok + win.highest = max(win.highest, index+1) + win.queue[index] = time.Now() + return true } -// add puts an index in the window. -func (win *datagramWindow) add(index uint24) { - if index >= win.highest { - win.highest = index + 1 +// seen checks if the index passed is known to the datagramWindow. +func (win *datagramWindow) seen(index uint24) bool { + if index < win.lowest { + return true } - win.queue[index] = time.Now() + _, ok := win.queue[index] + return ok } // shift attempts to delete as many indices from the queue as possible, @@ -51,9 +53,7 @@ func (win *datagramWindow) shift() (n int) { // set using add while within the window of lowest and highest index. The queue // is shifted after this call. func (win *datagramWindow) missing(since time.Duration) (indices []uint24) { - var ( - missing = false - ) + missing := false for index := int(win.highest) - 1; index >= int(win.lowest); index-- { i := uint24(index) t, ok := win.queue[i] diff --git a/dial.go b/dial.go index 404383e..ea18c9e 100644 --- a/dial.go +++ b/dial.go @@ -230,7 +230,7 @@ func (dialer Dialer) DialContext(ctx context.Context, address string) (*Conn, er } func (dialer Dialer) connect(ctx context.Context, state *connState) (*Conn, error) { - conn := newConn(internal.ConnToPacketConn(state.conn), state.raddr, state.mtu, &dialerConnectionHandler{}) + conn := newConn(internal.ConnToPacketConn(state.conn), state.raddr, state.mtu, dialerConnectionHandler{}) if err := conn.send((&message.ConnectionRequest{ClientGUID: state.id, RequestTimestamp: timestamp()})); err != nil { return nil, dialer.error("dial", fmt.Errorf("send connection request: %w", err)) } @@ -313,8 +313,7 @@ func (state *connState) discoverMTU(ctx context.Context) error { if err := response.UnmarshalBinary(b[1:n]); err != nil { return fmt.Errorf("read open connection reply 1: %w", err) } - state.serverSecurity = response.Secure - state.cookie = response.Cookie + state.serverSecurity, state.cookie = response.Secure, response.Cookie if response.ServerGUID == 0 || response.ServerPreferredMTUSize < 400 || response.ServerPreferredMTUSize > 1500 { // This is an awful hack we cooked up to deal with OVH 'DDoS' // protection. For some reason they send a broken MTU size @@ -342,7 +341,7 @@ func (state *connState) request1(ctx context.Context, sizes []uint16) { defer ticker.Stop() for _, size := range sizes { - for attempt := 0; attempt < 3; attempt++ { + for range 3 { state.openConnectionRequest1(size) select { case <-ticker.C: diff --git a/go.mod b/go.mod index ea3aa2d..123ea9e 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/sandertv/go-raknet -go 1.21.0 +go 1.22 diff --git a/handler.go b/handler.go index cdc0658..4500281 100644 --- a/handler.go +++ b/handler.go @@ -21,14 +21,22 @@ var ( errUnexpectedAdditionalNIC = errors.New("unexpected additional NEW_INCOMING_CONNECTION packet") ) -func (h *listenerConnectionHandler) handleUnconnected(b []byte, addr net.Addr) error { +func (h listenerConnectionHandler) limitsEnabled() bool { + return true +} + +func (h listenerConnectionHandler) close(conn *Conn) { + h.l.connections.Delete(resolve(conn.raddr)) +} + +func (h listenerConnectionHandler) handleUnconnected(b []byte, addr net.Addr) error { switch b[0] { case message.IDUnconnectedPing, message.IDUnconnectedPingOpenConnections: - return handleUnconnectedPing(h.l, b[1:], addr) + return h.handleUnconnectedPing(b[1:], addr) case message.IDOpenConnectionRequest1: - return handleOpenConnectionRequest1(h.l, b[1:], addr) + return h.handleOpenConnectionRequest1(b[1:], addr) case message.IDOpenConnectionRequest2: - return handleOpenConnectionRequest2(h.l, b[1:], addr) + return h.handleOpenConnectionRequest2(b[1:], addr) } if b[0]&bitFlagDatagram != 0 { // In some cases, the client will keep trying to send datagrams @@ -39,14 +47,14 @@ func (h *listenerConnectionHandler) handleUnconnected(b []byte, addr net.Addr) e return fmt.Errorf("unknown packet received (len=%v): %x", len(b), b) } -func (h *listenerConnectionHandler) handle(conn *Conn, b []byte) (handled bool, err error) { +func (h listenerConnectionHandler) handle(conn *Conn, b []byte) (handled bool, err error) { switch b[0] { case message.IDConnectionRequest: - return true, handleConnectionRequest(conn, b[1:]) + return true, h.handleConnectionRequest(conn, b[1:]) case message.IDConnectionRequestAccepted: return true, errUnexpectedCRA case message.IDNewIncomingConnection: - return true, handleNewIncomingConnection(conn) + return true, h.handleNewIncomingConnection(conn) case message.IDConnectedPing: return true, handleConnectedPing(conn, b[1:]) case message.IDConnectedPong: @@ -62,29 +70,21 @@ func (h *listenerConnectionHandler) handle(conn *Conn, b []byte) (handled bool, } } -func (h *listenerConnectionHandler) limitsEnabled() bool { - return true -} - -func (h *listenerConnectionHandler) close(conn *Conn) { - h.l.connections.Delete(resolve(conn.raddr)) -} - // handleUnconnectedPing handles an unconnected ping packet stored in buffer b, // coming from an address. -func handleUnconnectedPing(listener *Listener, b []byte, addr net.Addr) error { +func (h listenerConnectionHandler) handleUnconnectedPing(b []byte, addr net.Addr) error { pk := &message.UnconnectedPing{} if err := pk.UnmarshalBinary(b); err != nil { return fmt.Errorf("read UNCONNECTED_PING: %w", err) } - data, _ := (&message.UnconnectedPong{ServerGUID: listener.id, SendTimestamp: pk.SendTimestamp, Data: *listener.pongData.Load()}).MarshalBinary() - _, err := listener.conn.WriteTo(data, addr) + data, _ := (&message.UnconnectedPong{ServerGUID: h.l.id, SendTimestamp: pk.SendTimestamp, Data: *h.l.pongData.Load()}).MarshalBinary() + _, err := h.l.conn.WriteTo(data, addr) return err } // handleOpenConnectionRequest1 handles an open connection request 1 packet // stored in buffer b, coming from an address. -func handleOpenConnectionRequest1(listener *Listener, b []byte, addr net.Addr) error { +func (h listenerConnectionHandler) handleOpenConnectionRequest1(b []byte, addr net.Addr) error { pk := &message.OpenConnectionRequest1{} if err := pk.UnmarshalBinary(b); err != nil { return fmt.Errorf("read OPEN_CONNECTION_REQUEST_1: %w", err) @@ -92,32 +92,32 @@ func handleOpenConnectionRequest1(listener *Listener, b []byte, addr net.Addr) e mtuSize := min(pk.MaximumSizeNotDropped, maxMTUSize) if pk.Protocol != protocolVersion { - data, _ := (&message.IncompatibleProtocolVersion{ServerGUID: listener.id, ServerProtocol: protocolVersion}).MarshalBinary() - _, _ = listener.conn.WriteTo(data, addr) + data, _ := (&message.IncompatibleProtocolVersion{ServerGUID: h.l.id, ServerProtocol: protocolVersion}).MarshalBinary() + _, _ = h.l.conn.WriteTo(data, addr) return fmt.Errorf("handle OPEN_CONNECTION_REQUEST_1: incompatible protocol version %v (listener protocol = %v)", pk.Protocol, protocolVersion) } - data, _ := (&message.OpenConnectionReply1{ServerGUID: listener.id, Secure: false, ServerPreferredMTUSize: mtuSize}).MarshalBinary() - _, err := listener.conn.WriteTo(data, addr) + data, _ := (&message.OpenConnectionReply1{ServerGUID: h.l.id, Secure: false, ServerPreferredMTUSize: mtuSize}).MarshalBinary() + _, err := h.l.conn.WriteTo(data, addr) return err } // handleOpenConnectionRequest2 handles an open connection request 2 packet // stored in buffer b, coming from an address. -func handleOpenConnectionRequest2(listener *Listener, b []byte, addr net.Addr) error { +func (h listenerConnectionHandler) handleOpenConnectionRequest2(b []byte, addr net.Addr) error { pk := &message.OpenConnectionRequest2{} if err := pk.UnmarshalBinary(b); err != nil { return fmt.Errorf("read OPEN_CONNECTION_REQUEST_2: %w", err) } mtuSize := min(pk.ClientPreferredMTUSize, maxMTUSize) - data, _ := (&message.OpenConnectionReply2{ServerGUID: listener.id, ClientAddress: resolve(addr), MTUSize: mtuSize}).MarshalBinary() - if _, err := listener.conn.WriteTo(data, addr); err != nil { + data, _ := (&message.OpenConnectionReply2{ServerGUID: h.l.id, ClientAddress: resolve(addr), MTUSize: mtuSize}).MarshalBinary() + if _, err := h.l.conn.WriteTo(data, addr); err != nil { return fmt.Errorf("send OPEN_CONNECTION_REPLY_2: %w", err) } - conn := newConn(listener.conn, addr, mtuSize, listener.h) - listener.connections.Store(resolve(addr), conn) + conn := newConn(h.l.conn, addr, mtuSize, h) + h.l.connections.Store(resolve(addr), conn) go func() { t := time.NewTimer(time.Second * 10) @@ -126,8 +126,8 @@ func handleOpenConnectionRequest2(listener *Listener, b []byte, addr net.Addr) e case <-conn.connected: // Add the connection to the incoming channel so that a caller of // Accept() can receive it. - listener.incoming <- conn - case <-listener.closed: + h.l.incoming <- conn + case <-h.l.closed: _ = conn.Close() case <-t.C: // It took too long to complete this connection. We closed it and go @@ -139,6 +139,28 @@ func handleOpenConnectionRequest2(listener *Listener, b []byte, addr net.Addr) e return nil } +// handleConnectionRequest handles a connection request packet inside of buffer +// b. An error is returned if the packet was invalid. +func (h listenerConnectionHandler) handleConnectionRequest(conn *Conn, b []byte) error { + pk := &message.ConnectionRequest{} + if err := pk.UnmarshalBinary(b); err != nil { + return fmt.Errorf("read CONNECTION_REQUEST: %w", err) + } + return conn.send(&message.ConnectionRequestAccepted{ClientAddress: resolve(conn.raddr), RequestTimestamp: pk.RequestTimestamp, AcceptedTimestamp: timestamp()}) +} + +// handleNewIncomingConnection handles an incoming connection packet from the +// client, finalising the Conn. +func (h listenerConnectionHandler) handleNewIncomingConnection(conn *Conn) error { + select { + case <-conn.connected: + return errUnexpectedAdditionalNIC + default: + close(conn.connected) + } + return nil +} + type dialerConnectionHandler struct{} var ( @@ -147,12 +169,20 @@ var ( errUnexpectedNIC = errors.New("unexpected NEW_INCOMING_CONNECTION packet") ) -func (h *dialerConnectionHandler) handle(conn *Conn, b []byte) (handled bool, err error) { +func (h dialerConnectionHandler) close(conn *Conn) { + _ = conn.conn.Close() +} + +func (h dialerConnectionHandler) limitsEnabled() bool { + return false +} + +func (h dialerConnectionHandler) handle(conn *Conn, b []byte) (handled bool, err error) { switch b[0] { case message.IDConnectionRequest: return true, errUnexpectedCR case message.IDConnectionRequestAccepted: - return true, handleConnectionRequestAccepted(conn, b[1:]) + return true, h.handleConnectionRequestAccepted(conn, b[1:]) case message.IDNewIncomingConnection: return true, errUnexpectedNIC case message.IDConnectedPing: @@ -170,12 +200,22 @@ func (h *dialerConnectionHandler) handle(conn *Conn, b []byte) (handled bool, er } } -func (h *dialerConnectionHandler) close(conn *Conn) { - _ = conn.conn.Close() -} - -func (h *dialerConnectionHandler) limitsEnabled() bool { - return false +// handleConnectionRequestAccepted handles a serialised connection request +// accepted packet in b, and returns an error if not successful. +func (h dialerConnectionHandler) handleConnectionRequestAccepted(conn *Conn, b []byte) error { + pk := &message.ConnectionRequestAccepted{} + if err := pk.UnmarshalBinary(b); err != nil { + return fmt.Errorf("read CONNECTION_REQUEST_ACCEPTED: %w", err) + } + select { + case <-conn.connected: + return errUnexpectedAdditionalCRA + default: + // Make sure to send NewIncomingConnection before closing conn.connected. + err := conn.send(&message.NewIncomingConnection{ServerAddress: resolve(conn.raddr), RequestTimestamp: pk.AcceptedTimestamp, AcceptedTimestamp: timestamp()}) + close(conn.connected) + return err + } } // handleConnectedPing handles a connected ping packet inside of buffer b. An @@ -204,43 +244,3 @@ func handleConnectedPong(b []byte) error { // unreliable and doesn't give a good idea of the connection quality. return nil } - -// handleConnectionRequest handles a connection request packet inside of buffer -// b. An error is returned if the packet was invalid. -func handleConnectionRequest(conn *Conn, b []byte) error { - pk := &message.ConnectionRequest{} - if err := pk.UnmarshalBinary(b); err != nil { - return fmt.Errorf("read CONNECTION_REQUEST: %w", err) - } - return conn.send(&message.ConnectionRequestAccepted{ClientAddress: resolve(conn.raddr), RequestTimestamp: pk.RequestTimestamp, AcceptedTimestamp: timestamp()}) -} - -// handleConnectionRequestAccepted handles a serialised connection request -// accepted packet in b, and returns an error if not successful. -func handleConnectionRequestAccepted(conn *Conn, b []byte) error { - pk := &message.ConnectionRequestAccepted{} - if err := pk.UnmarshalBinary(b); err != nil { - return fmt.Errorf("read CONNECTION_REQUEST_ACCEPTED: %w", err) - } - select { - case <-conn.connected: - return errUnexpectedAdditionalCRA - default: - // Make sure to send NewIncomingConnection before closing conn.connected. - err := conn.send(&message.NewIncomingConnection{ServerAddress: resolve(conn.raddr), RequestTimestamp: pk.AcceptedTimestamp, AcceptedTimestamp: timestamp()}) - close(conn.connected) - return err - } -} - -// handleNewIncomingConnection handles an incoming connection packet from the -// client, finalising the Conn. -func handleNewIncomingConnection(conn *Conn) error { - select { - case <-conn.connected: - return errUnexpectedAdditionalNIC - default: - close(conn.connected) - } - return nil -} diff --git a/internal/message/connection_request_accepted.go b/internal/message/connection_request_accepted.go index 7b76b63..c7c6c65 100644 --- a/internal/message/connection_request_accepted.go +++ b/internal/message/connection_request_accepted.go @@ -20,7 +20,7 @@ func (pk *ConnectionRequestAccepted) UnmarshalBinary(data []byte) error { var offset int pk.ClientAddress, offset = addr(data) offset += 2 // Zero int16. - for i := 0; i < 20; i++ { + for i := range 20 { if len(data) < addrSize(data[offset:]) { return io.ErrUnexpectedEOF } diff --git a/internal/message/new_incoming_connection.go b/internal/message/new_incoming_connection.go index ba1b81f..1a62bce 100644 --- a/internal/message/new_incoming_connection.go +++ b/internal/message/new_incoming_connection.go @@ -19,7 +19,7 @@ func (pk *NewIncomingConnection) UnmarshalBinary(data []byte) error { } var offset int pk.ServerAddress, offset = addr(data) - for i := 0; i < 20; i++ { + for i := range 20 { if len(data) < addrSize(data[offset:]) { return io.ErrUnexpectedEOF } diff --git a/listener.go b/listener.go index 80247f9..66fb539 100644 --- a/listener.go +++ b/listener.go @@ -30,7 +30,7 @@ type ListenConfig struct { // methods as those implemented by the TCPListener in the net package. Listener // implements the net.Listener interface. type Listener struct { - h *listenerConnectionHandler + h listenerConnectionHandler once sync.Once closed chan struct{} @@ -84,7 +84,7 @@ func (l ListenConfig) Listen(address string) (*Listener, error) { log: l.ErrorLog, id: atomic.AddInt64(&listenerID, 1), } - listener.h = &listenerConnectionHandler{l: listener} + listener.h.l = listener if l.ErrorLog == nil { listener.log = slog.Default() } diff --git a/packet.go b/packet.go index 42f5f99..250a3e0 100644 --- a/packet.go +++ b/packet.go @@ -156,8 +156,9 @@ func (pk *packet) reliable() bool { reliabilityReliableOrdered, reliabilityReliableSequenced: return true + default: + return false } - return false } func (pk *packet) sequencedOrOrdered() bool { @@ -166,8 +167,9 @@ func (pk *packet) sequencedOrOrdered() bool { reliabilityReliableOrdered, reliabilityReliableSequenced: return true + default: + return false } - return false } func (pk *packet) sequenced() bool { @@ -175,8 +177,9 @@ func (pk *packet) sequenced() bool { case reliabilityUnreliableSequenced, reliabilityReliableSequenced: return true + default: + return false } - return false } const ( @@ -210,7 +213,7 @@ func split(b []byte, mtu uint16) [][]byte { // to reserve another fragment for the last bit of the packet. fragmentCount := n/maxSize + min(n%maxSize, 1) fragments := make([][]byte, fragmentCount) - for i := 0; i < fragmentCount-1; i++ { + for i := range fragmentCount - 1 { fragments[i] = b[:maxSize] b = b[maxSize:] }