Skip to content

Commit

Permalink
Fix heartbeat read buffer bug (#256)
Browse files Browse the repository at this point in the history
Fix heartbeat read buffer bug, by checking if an insufficient buffer is used in hbConn.Read().
Wrap heartbeat before wrapping SCTPConn so that it also benefits from SCTPConn enhancements
  • Loading branch information
mingyech authored and jmwample committed Oct 23, 2023
1 parent 701913d commit c3aac85
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 67 deletions.
78 changes: 45 additions & 33 deletions pkg/dtls/heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,23 @@ package dtls

import (
"bytes"
"net"
"errors"
"sync/atomic"
"time"
)

var maxMessageSize = 65535
var ErrInsufficientBuffer = errors.New("buffer too small to hold the received data")

const recvChBufSize = 64

type hbConn struct {
conn net.Conn
recvCh chan errBytes
waiting uint32
hb []byte
timeout time.Duration
stream msgStream

recvCh chan errBytes
waiting uint32
hb []byte
timeout time.Duration
maxMessageSize int
}

type errBytes struct {
Expand All @@ -23,13 +27,14 @@ type errBytes struct {
}

// heartbeatServer listens for heartbeat over conn with config
func heartbeatServer(conn net.Conn, config *heartbeatConfig) (net.Conn, error) {
func heartbeatServer(stream msgStream, config *heartbeatConfig, maxMessageSize int) (*hbConn, error) {
conf := validate(config)

c := &hbConn{conn: conn,
recvCh: make(chan errBytes),
timeout: conf.Interval,
hb: conf.Heartbeat,
c := &hbConn{stream: stream,
recvCh: make(chan errBytes, recvChBufSize),
timeout: conf.Interval,
hb: conf.Heartbeat,
maxMessageSize: maxMessageSize,
}

atomic.StoreUint32(&c.waiting, 2)
Expand All @@ -43,7 +48,7 @@ func heartbeatServer(conn net.Conn, config *heartbeatConfig) (net.Conn, error) {
func (c *hbConn) hbLoop() {
for {
if atomic.LoadUint32(&c.waiting) == 0 {
c.conn.Close()
c.stream.Close()
return
}

Expand All @@ -55,58 +60,65 @@ func (c *hbConn) hbLoop() {

func (c *hbConn) recvLoop() {
for {
// create a buffer to hold your data
buffer := make([]byte, maxMessageSize)
buffer := make([]byte, c.maxMessageSize)

n, err := c.conn.Read(buffer)
n, err := c.stream.Read(buffer)

if bytes.Equal(c.hb, buffer[:n]) {
atomic.AddUint32(&c.waiting, 1)
continue
}

if err != nil {
c.recvCh <- errBytes{nil, err}
}

c.recvCh <- errBytes{buffer[:n], err}
}

}

func (c *hbConn) Close() error {
return c.conn.Close()
return c.stream.Close()
}

func (c *hbConn) Write(b []byte) (n int, err error) {
return c.conn.Write(b)
return c.stream.Write(b)
}

func (c *hbConn) Read(b []byte) (n int, err error) {
func (c *hbConn) Read(b []byte) (int, error) {
readBytes := <-c.recvCh
copy(b, readBytes.b)
if readBytes.err != nil {
return 0, readBytes.err
}

return len(readBytes.b), readBytes.err
}
if len(b) < len(readBytes.b) {
return 0, ErrInsufficientBuffer
}

n := copy(b, readBytes.b)

func (c *hbConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
return n, nil
}

func (c *hbConn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
func (c *hbConn) BufferedAmount() uint64 {
return c.stream.BufferedAmount()
}

func (c *hbConn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
func (c *hbConn) SetReadDeadline(deadline time.Time) error {
return c.stream.SetReadDeadline(deadline)
}

func (c *hbConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
func (c *hbConn) SetBufferedAmountLowThreshold(th uint64) {
c.stream.SetBufferedAmountLowThreshold(th)
}

func (c *hbConn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
func (c *hbConn) OnBufferedAmountLow(f func()) {
c.stream.OnBufferedAmountLow(f)
}

// heartbeatClient sends heartbeats over conn with config
func heartbeatClient(conn net.Conn, config *heartbeatConfig) error {
func heartbeatClient(conn msgStream, config *heartbeatConfig) error {
conf := validate(config)
go func() {
for {
Expand Down
Loading

0 comments on commit c3aac85

Please sign in to comment.