diff --git a/handler.go b/handler.go index a9ec1d5..4d8f28a 100644 --- a/handler.go +++ b/handler.go @@ -63,7 +63,7 @@ func (h listenerConnectionHandler) handleUnconnected(b []byte, addr net.Addr) er // an error. return nil } - return fmt.Errorf("unknown packet received (len=%v): %x", len(b), b) + return fmt.Errorf("unknown packet received (id=%x, len=%v)", b[0], len(b)) } // handleUnconnectedPing handles an unconnected ping packet stored in buffer b, diff --git a/listener.go b/listener.go index 0d573c6..9b0825d 100644 --- a/listener.go +++ b/listener.go @@ -3,11 +3,13 @@ package raknet import ( "fmt" "log/slog" + "maps" "math" "math/rand/v2" "net" "sync" "sync/atomic" + "time" ) // UpstreamPacketListener allows for a custom PacketListener implementation. @@ -24,21 +26,31 @@ type ListenConfig struct { // UpstreamPacketListener adds an abstraction for net.ListenPacket. UpstreamPacketListener UpstreamPacketListener + + // DisableCookies specifies if cookies should be generated and verified for + // new incoming connections. This is a security measure against IP spoofing, + // but some server providers (OVH in particular) have existing protection + // systems that interfere with this. In this case, DisableCookies should be + // set to true. + DisableCookies bool + // BlockDuration specifies how long IP addresses should be blocked if an + // error is encountered during the handling of packets from an address. + // BlockDuration defaults to 10s. If set to a negative value, IP addresses + // are never blocked on errors. + BlockDuration time.Duration } // Listener implements a RakNet connection listener. It follows the same // methods as those implemented by the TCPListener in the net package. Listener // implements the net.Listener interface. type Listener struct { - h *listenerConnectionHandler + conf ListenConfig + handler *listenerConnectionHandler + sec *security once sync.Once closed chan struct{} - // log is a logger that errors from packet decoding are logged to. It may be - // set to a logger that simply discards the messages. - log *slog.Logger - conn net.PacketConn // incoming is a channel of incoming connections. Connections that end up in // here will also end up in the connections map. @@ -65,32 +77,37 @@ var listenerID = rand.Int64() // follows the same rules as those defined in the net.TCPListen() function. // Specific features of the listener may be modified once it is returned, such // as the used log and/or the accepted protocol. -func (l ListenConfig) Listen(address string) (*Listener, error) { +func (conf ListenConfig) Listen(address string) (*Listener, error) { + if conf.ErrorLog == nil { + conf.ErrorLog = slog.Default() + } + if conf.BlockDuration == 0 { + conf.BlockDuration = time.Second * 10 + } var conn net.PacketConn var err error - if l.UpstreamPacketListener == nil { + if conf.UpstreamPacketListener == nil { conn, err = net.ListenPacket("udp", address) } else { - conn, err = l.UpstreamPacketListener.ListenPacket("udp", address) + conn, err = conf.UpstreamPacketListener.ListenPacket("udp", address) } if err != nil { return nil, &net.OpError{Op: "listen", Net: "raknet", Source: nil, Addr: nil, Err: err} } listener := &Listener{ + conf: conf, conn: conn, incoming: make(chan *Conn), closed: make(chan struct{}), - log: l.ErrorLog, id: atomic.AddInt64(&listenerID, 1), + sec: newSecurity(conf), } - listener.h = &listenerConnectionHandler{l: listener, cookieSalt: rand.Uint32()} + listener.handler = &listenerConnectionHandler{l: listener, cookieSalt: rand.Uint32()} listener.pongData.Store(new([]byte)) - if l.ErrorLog == nil { - listener.log = slog.Default() - } go listener.listen() + go listener.sec.gc(listener.closed) return listener, nil } @@ -161,11 +178,12 @@ func (listener *Listener) listen() { if err != nil { close(listener.incoming) return - } else if n == 0 { + } else if n == 0 || listener.sec.blocked(addr) { continue } - if err := listener.handle(b[:n], addr); err != nil { - listener.log.Error("listener: handle packet: "+err.Error(), "address", addr.String()) + if err = listener.handle(b[:n], addr); err != nil { + listener.conf.ErrorLog.Error("listener: handle packet: "+err.Error(), "address", addr.String(), "block-duration", max(0, listener.conf.BlockDuration)) + listener.sec.block(addr) } } } @@ -175,7 +193,7 @@ func (listener *Listener) listen() { func (listener *Listener) handle(b []byte, addr net.Addr) error { value, found := listener.connections.Load(resolve(addr)) if !found { - return listener.h.handleUnconnected(b, addr) + return listener.handler.handleUnconnected(b, addr) } conn := value.(*Conn) select { @@ -190,3 +208,78 @@ func (listener *Listener) handle(b []byte, addr net.Addr) error { return nil } } + +// security implements security measurements against DoS attacks against a +// Listener. +type security struct { + conf ListenConfig + + blockCount atomic.Uint32 + + mu sync.Mutex + blocks map[[16]byte]time.Time +} + +// newSecurity uses settings from a ListenConfig to create a security. +func newSecurity(conf ListenConfig) *security { + return &security{conf: conf, blocks: make(map[[16]byte]time.Time)} +} + +// gc clears garbage from the security layer every second until the stop channel +// passed is closed. +func (s *security) gc(stop <-chan struct{}) { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.gcBlocks() + case <-stop: + return + } + } +} + +// block stops the handling of packets originating from the IP of a net.Addr. +func (s *security) block(addr net.Addr) { + if s.conf.BlockDuration < 0 { + return + } + s.mu.Lock() + defer s.mu.Unlock() + + s.blockCount.Add(1) + s.blocks[[16]byte(addr.(*net.UDPAddr).IP.To16())] = time.Now() +} + +// blocked checks if the IP of a net.Addr is currently blocked from any packet +// handling. +func (s *security) blocked(addr net.Addr) bool { + if s.conf.BlockDuration < 0 || s.blockCount.Load() == 0 { + // Fast path optimisation: Prevents (relatively costly) map lookups. + return false + } + s.mu.Lock() + defer s.mu.Unlock() + + _, blocked := s.blocks[[16]byte(addr.(*net.UDPAddr).IP.To16())] + return blocked +} + +// gcBlocks removes blocks from the map that are no longer active. gcBlocks only +// attempts to clear outdated blocks if there are two times more blocks active +// than there were after the previous call to gcBlocks. +func (s *security) gcBlocks() { + if s.blockCount.Load() == 0 { + return + } + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + maps.DeleteFunc(s.blocks, func(ip [16]byte, t time.Time) bool { + return now.Sub(t) > s.conf.BlockDuration + }) + s.blockCount.Store(uint32(len(s.blocks))) +}