diff --git a/conn.go b/conn.go index c45592c..e1d32c2 100644 --- a/conn.go +++ b/conn.go @@ -101,7 +101,7 @@ func newConn(conn net.PacketConn, raddr net.Addr, mtu uint16, h connectionHandle pk: new(packet), closed: make(chan struct{}), connected: make(chan struct{}), - packets: internal.Chan[[]byte](4), + packets: internal.Chan[[]byte](4, 4096), splits: make(map[uint16][][]byte), win: newDatagramWindow(), packetQueue: newPacketQueue(), diff --git a/internal/chan.go b/internal/chan.go index df4be2a..62fa35e 100644 --- a/internal/chan.go +++ b/internal/chan.go @@ -5,16 +5,20 @@ import ( "sync/atomic" ) -// ElasticChan is a channel that grows if its capacity is reached. +// ElasticChan is a channel that grows if its capacity is reached. ElasticChan +// is safe for concurrent use with multiple readers and 1 sender. Calling Send +// from multiple goroutines simultaneously is unsafe. type ElasticChan[T any] struct { mu sync.RWMutex len atomic.Int64 ch chan T + lim int64 } // Chan creates an ElasticChan of a size. -func Chan[T any](size int) *ElasticChan[T] { +func Chan[T any](size, max int) *ElasticChan[T] { c := new(ElasticChan[T]) + c.lim = int64(max) c.grow(size) return c } @@ -39,7 +43,7 @@ func (c *ElasticChan[T]) Recv(cancel <-chan struct{}) (val T, ok bool) { // Send sends a value to the channel. Send never blocks, because if the maximum // capacity of the underlying channel is reached, a larger one is created. func (c *ElasticChan[T]) Send(val T) { - if c.len.Load()+1 >= int64(cap(c.ch)) { + if ccap := int64(cap(c.ch)); c.len.Add(1) >= ccap && ccap < c.lim { // This check happens outside a lock, meaning in the meantime, a call to // Recv could cause the length to decrease, technically meaning growing // is then unnecessary. That isn't a major issue though, as in most @@ -47,7 +51,6 @@ func (c *ElasticChan[T]) Send(val T) { c.growSend(val) return } - c.len.Add(1) c.ch <- val } @@ -57,8 +60,7 @@ func (c *ElasticChan[T]) growSend(val T) { c.mu.Lock() defer c.mu.Unlock() - c.grow(cap(c.ch) * 2) - c.len.Add(1) + c.grow(max(cap(c.ch)*2, int(c.lim))) c.ch <- val }