Skip to content

Commit

Permalink
refactor buffer to get rid of b.len and use cap instead
Browse files Browse the repository at this point in the history
  • Loading branch information
jyyi1 committed Nov 21, 2023
1 parent edb97bb commit 87c3592
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 118 deletions.
62 changes: 23 additions & 39 deletions transport/tlsfrag/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package tlsfrag

import (
"bytes"
"errors"
"fmt"
"io"
Expand All @@ -32,10 +33,10 @@ var (

// clientHelloBuffer is a byte buffer used to receive and buffer a TLS Client Hello packet.
type clientHelloBuffer struct {
data []byte // the buffer that hosts both header and content, len: 5 -> 5+len(content)
len int // the actual bytes that have been read into data
valid bool // indicate whether the content in data is a valid TLS Client Hello record
toRead int // the number of bytes to read next, e.g. 5 -> len(content)
data []byte // The buffer that hosts both header and content, cap: 5 -> 5+len(content)+padding
padding int // The unused additional padding allocated at the end of data, 0 -> 5
valid bool // Indicates whether the content in data is a valid TLS Client Hello record
bufrd *bytes.Reader // A reader used to read from the slice passed to Write
}

var _ io.Writer = (*clientHelloBuffer)(nil)
Expand All @@ -45,9 +46,10 @@ var _ io.ReaderFrom = (*clientHelloBuffer)(nil)
func newClientHelloBuffer() *clientHelloBuffer {
// Allocate the 5 bytes header first, and then reallocate it to contain the entire packet later
return &clientHelloBuffer{
data: make([]byte, 0, recordHeaderLen),
valid: true,
toRead: recordHeaderLen,
data: make([]byte, 0, recordHeaderLen),
padding: 0,
valid: true,
bufrd: bytes.NewReader(nil), // It will be Reset in Write
}
}

Expand All @@ -61,31 +63,13 @@ func (b *clientHelloBuffer) Bytes() []byte {
// If an invalid TLS Client Hello message is detected, it returns the error errInvalidTLSClientHello.
// If all bytes in p have been used and the buffer still requires more data to build a complete TLS Client Hello
// message, it returns (len(p), nil).
func (b *clientHelloBuffer) Write(p []byte) (n int, err error) {
if !b.valid {
return 0, errInvalidTLSClientHello
func (b *clientHelloBuffer) Write(p []byte) (int, error) {
b.bufrd.Reset(p)
n, err := b.ReadFrom(b.bufrd)
if err == nil && int(n) != len(p) {
err = io.ErrShortWrite
}

for b.len < len(b.data) && len(p) > 0 {
m := copy(b.data[b.len:], p)
n += m
b.len += m
p = p[m:]

if b.len == recordHeaderLen {
if err = b.validateTLSClientHello(); err != nil {
return
}
buf := make([]byte, recordHeaderLen+getMsgLen(b.data))
copy(buf, b.data)
b.data = buf
}
}

if b.len == len(b.data) {
err = errTLSClientHelloFullyReceived
}
return
return int(n), err
}

// ReadFrom reads all the data from r and appends it to this buffer until a complete Client Hello packet has been
Expand All @@ -102,30 +86,30 @@ func (b *clientHelloBuffer) ReadFrom(r io.Reader) (n int64, err error) {
return 0, errInvalidTLSClientHello
}

for b.len < len(b.data) && err == nil {
m, e := r.Read(b.data[b.len:])
for len(b.data) < cap(b.data)-b.padding && err == nil {
m, e := r.Read(b.data[len(b.data) : cap(b.data)-b.padding])
b.data = b.data[:len(b.data)+m]
n += int64(m)
b.len += m
err = e

if b.len == recordHeaderLen {
if len(b.data) == recordHeaderLen {
if e := b.validateTLSClientHello(); e != nil {
if err == io.EOF {
err = nil
}
err = errors.Join(err, e)
return
}
buf := make([]byte, recordHeaderLen+getMsgLen(b.data))
copy(buf, b.data)
b.data = buf
buf := make([]byte, 0, recordHeaderLen*2+getMsgLen(b.data))
b.data = append(buf, b.data...)
b.padding = recordHeaderLen
}
}

if err == io.EOF {
err = nil
}
if b.len == len(b.data) {
if len(b.data) == cap(b.data)-b.padding {
err = errors.Join(err, errTLSClientHelloFullyReceived)
}
return
Expand Down
6 changes: 2 additions & 4 deletions transport/tlsfrag/buffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,10 @@ func TestWriteValidClientHello(t *testing.T) {
require.Equal(t, tc.expectRemaining[k], pkt[n:], tc.msg+": pkt-%d", k)

totalExpectedBytes = append(totalExpectedBytes, pkt[:n]...)
require.Equal(t, len(totalExpectedBytes), buf.Len(), tc.msg+": pkt-%d", k)
require.Equal(t, totalExpectedBytes, buf.Bytes(), tc.msg+": pkt-%d", k)
}
require.Equal(t, len(tc.expectTotalPkt), buf.Len(), tc.msg)
require.Equal(t, tc.expectTotalPkt, buf.Bytes(), tc.msg)
require.Equal(t, len(tc.expectTotalPkt)+5, cap(buf.Bytes()), tc.msg)
}
}

Expand All @@ -67,13 +66,12 @@ func TestReadFromValidClientHello(t *testing.T) {
require.Equal(t, tc.expectRemaining[k], pkt[n:], tc.msg+": pkt-%d", k)

totalExpectedBytes = append(totalExpectedBytes, pkt[:n]...)
require.Equal(t, len(totalExpectedBytes), buf.Len(), tc.msg+": pkt-%d", k)
require.Equal(t, totalExpectedBytes, buf.Bytes(), tc.msg+": pkt-%d", k)
require.Equal(t, len(tc.expectRemaining[k]), r.Len(), tc.msg+": pkt-%d", k)
require.Equal(t, tc.expectRemaining[k], r.Bytes(), tc.msg+": pkt-%d", k)
}
require.Equal(t, len(tc.expectTotalPkt), buf.Len(), tc.msg)
require.Equal(t, tc.expectTotalPkt, buf.Bytes(), tc.msg)
require.Equal(t, len(tc.expectTotalPkt)+5, cap(buf.Bytes()), tc.msg)
}
}

Expand Down
148 changes: 73 additions & 75 deletions transport/tlsfrag/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ import (
// are not modified and are directly transmitted through the base [io.Writer].
type clientHelloFragWriter struct {
base io.Writer
done bool // indicates all splitted rcds have been already written to base
done bool // Indicates all splitted rcds have been already written to base
frag FragFunc

buf *clientHelloBuffer // the buffer containing and parsing a TLS Client Hello record
rcds *bytes.Buffer // the buffer containing splitted records what will be written to base
helloBuf *clientHelloBuffer // The buffer containing and parsing a TLS Client Hello record
record *bytes.Buffer // The buffer containing splitted records what will be written to base
}

// clientHelloFragReaderFrom serves as an optimized version of clientHelloFragWriter when the base [io.Writer] also
Expand Down Expand Up @@ -58,9 +58,9 @@ func newClientHelloFragWriter(base io.Writer, frag FragFunc) (io.Writer, error)
return nil, errors.New("frag callback function must not be nil")
}
fw := &clientHelloFragWriter{
base: base,
frag: frag,
buf: newClientHelloBuffer(),
base: base,
frag: frag,
helloBuf: newClientHelloBuffer(),
}
if rf, ok := base.(io.ReaderFrom); ok {
return &clientHelloFragReaderFrom{fw, rf}, nil
Expand All @@ -71,37 +71,30 @@ func newClientHelloFragWriter(base io.Writer, frag FragFunc) (io.Writer, error)
// Write implements io.Writer.Write. It attempts to split the data received in the first one or more Write call(s)
// into two TLS records if the data corresponds to a TLS Client Hello record.
func (w *clientHelloFragWriter) Write(p []byte) (n int, err error) {
if w.done {
return w.base.Write(p)
}
if w.rcds != nil {
if _, err = w.flushRecords(); err != nil {
return
}
return w.base.Write(p)
}

if n, err = w.buf.Write(p); err != nil {
if errors.Is(err, errTLSClientHelloFullyReceived) {
w.splitBufToRecords()
} else {
w.copyBufToRecords()
if !w.done {
// not yet splitted, append to the buffer
if w.record == nil {
if n, err = w.helloBuf.Write(p); err == nil {
// all written, but Client Hello is not fully received yet
return
}
p = p[n:]
if errors.Is(err, errTLSClientHelloFullyReceived) {
w.splitHelloBufToRecord()
} else {
w.copyHelloBufToRecord()
}
}
// We did not call w.Write(p[n:]) here because p[n:] might be empty, and we don't want to
// Write an empty buffer to w.base if it's not initiated by the upstream caller.
if _, err = w.flushRecords(); err != nil {
// already splitted (but previous Writes might fail), try to flush all remaining w.record to w.base
if _, err = w.flushRecord(); err != nil {
return
}
if p = p[n:]; len(p) > 0 {
m, e := w.base.Write(p)
n += m
err = e
}
return
}

if n < len(p) {
return n, io.ErrShortWrite
if len(p) > 0 {
m, e := w.base.Write(p)
n += m
err = e
}
return
}
Expand All @@ -114,65 +107,70 @@ func (w *clientHelloFragWriter) Write(p []byte) (n int, err error) {
//
// It returns the number of bytes read. Any error except EOF encountered during the read is also returned.
func (w *clientHelloFragReaderFrom) ReadFrom(r io.Reader) (n int64, err error) {
if w.done {
return w.baseRF.ReadFrom(r)
}
if w.rcds != nil {
if _, err = w.flushRecords(); err != nil {
if !w.done {
// not yet splitted, append to the buffer
if w.record == nil {
if n, err = w.helloBuf.ReadFrom(r); err == nil {
// EOF, but Client Hello is not fully received yet
return
}
if errors.Is(err, errTLSClientHelloFullyReceived) {
w.splitHelloBufToRecord()
} else {
w.copyHelloBufToRecord()
}
}
// already splitted (but previous Writes might fail), try to flush all remaining w.record to w.base
if _, err = w.flushRecord(); err != nil {
return
}
return w.baseRF.ReadFrom(r)
}

if n, err = w.buf.ReadFrom(r); err != nil {
if errors.Is(err, errTLSClientHelloFullyReceived) {
w.splitBufToRecords()
} else {
w.copyBufToRecords()
}
// recursively flush w.rcds and read the remaining content from r
m, e := w.ReadFrom(r)
return n + m, e
}
m, e := w.baseRF.ReadFrom(r)
n += m
err = e
return
}

// copyBuf copies w.buf into w.rcds.
func (w *clientHelloFragWriter) copyBufToRecords() {
w.rcds = bytes.NewBuffer(w.buf.Bytes())
w.buf = nil // allows the GC to recycle the memory
// copyHelloBufToRecord copies w.helloBuf into w.record without allocations.
func (w *clientHelloFragWriter) copyHelloBufToRecord() {
w.record = bytes.NewBuffer(w.helloBuf.Bytes())
w.helloBuf = nil // allows the GC to recycle the memory
}

// splitBuf splits w.buf into two records and put them into w.rcds.
func (w *clientHelloFragWriter) splitBufToRecords() {
content := w.buf.Bytes()[recordHeaderLen:]
// splitHelloBufToRecord splits w.helloBuf into two records and put them into w.record without allocations.
func (w *clientHelloFragWriter) splitHelloBufToRecord() {
received := w.helloBuf.Bytes()
content := received[recordHeaderLen:]
split := w.frag(content)
if split <= 0 || split >= len(content) {
w.copyBufToRecords()
w.copyHelloBufToRecord()
return
}

header := make([]byte, recordHeaderLen)
copy(header, w.buf.Bytes())

w.rcds = bytes.NewBuffer(make([]byte, 0, w.buf.Len()+recordHeaderLen))

putMsgLen(header, uint16(split))
w.rcds.Write(header)
w.rcds.Write(content[:split])

putMsgLen(header, uint16(len(content)-split))
w.rcds.Write(header)
w.rcds.Write(content[split:])

w.buf = nil // allows the GC to recycle the memory
// received: | <== header (5) ==> | <== split ==> | <== len(content)-split ==> | ... cap with padding (5) ... |
// \ \
// +-----------------+ +-----------------+
// \ \
// splitted: | <== header (5) ==> | <== split ==> | <== header2 (5) ==> | <== len(content)-split ==> |
splitted := received[:len(received)+recordHeaderLen]
hdr1 := splitted[:recordHeaderLen]
hdr2 := splitted[recordHeaderLen+split : recordHeaderLen*2+split]
recvContent2 := splitted[recordHeaderLen+split : len(received)]
content2 := splitted[recordHeaderLen*2+split:]
copy(content2, recvContent2)
copy(hdr2, hdr1)
putMsgLen(hdr1, uint16(split))
putMsgLen(hdr2, uint16(len(content)-split))
w.record = bytes.NewBuffer(splitted)
w.helloBuf = nil // allows the GC to recycle the memory
}

// flushRecords writes all bytes from w.rcds to base.
func (w *clientHelloFragWriter) flushRecords() (int, error) {
n, err := io.Copy(w.base, w.rcds)
if w.rcds.Len() == 0 {
w.rcds = nil // allows the GC to recycle the memory
// flushRecord writes all bytes from w.record to base.
func (w *clientHelloFragWriter) flushRecord() (int, error) {
n, err := io.Copy(w.base, w.record)
if w.record.Len() == 0 {
w.record = nil // allows the GC to recycle the memory
w.done = true
}
return int(n), err
Expand Down

0 comments on commit 87c3592

Please sign in to comment.