Skip to content

Commit

Permalink
Several fixes in the tlsrecordfrag package
Browse files Browse the repository at this point in the history
tlsRecordFragWriter:
 check if input is a handshake record
 return the number of input bytes consumed
 add an RFC link for reference
tlsRecordFragDialer:
 rename field
  • Loading branch information
Lanius-collaris committed Oct 25, 2023
1 parent 12e8acf commit 12e1c28
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 10 deletions.
8 changes: 4 additions & 4 deletions transport/tls-record-frag/stream_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ import (
)

type tlsRecordFragDialer struct {
dialer transport.StreamDialer
splitPoint int32
dialer transport.StreamDialer
prefixBytes int32
}

var _ transport.StreamDialer = (*tlsRecordFragDialer)(nil)
Expand All @@ -33,7 +33,7 @@ func NewStreamDialer(dialer transport.StreamDialer, prefixBytes int32) (transpor
if dialer == nil {
return nil, errors.New("argument dialer must not be nil")
}
return &tlsRecordFragDialer{dialer: dialer, splitPoint: prefixBytes}, nil
return &tlsRecordFragDialer{dialer: dialer, prefixBytes: prefixBytes}, nil
}

// Dial implements [transport.StreamDialer].Dial.
Expand All @@ -42,5 +42,5 @@ func (d *tlsRecordFragDialer) Dial(ctx context.Context, remoteAddr string) (tran
if err != nil {
return nil, err
}
return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.splitPoint)), nil
return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.prefixBytes)), nil
}
17 changes: 13 additions & 4 deletions transport/tls-record-frag/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ type tlsRecordFragWriter struct {
prefixBytes int32
}

const maxRecordLength = 16384 //For the fragments, not for the reassembled record
// Record Layer: https://datatracker.ietf.org/doc/html/rfc8446#section-5.1
const maxRecordLength = 1 << 14 //For the fragments, not for the reassembled record
const typeHandshake = 22

func NewWriter(writer io.Writer, prefixBytes int32) *tlsRecordFragWriter {
return &tlsRecordFragWriter{writer, prefixBytes}
Expand All @@ -50,7 +52,7 @@ func (w *tlsRecordFragWriter) ReadFrom(source io.Reader) (written int64, err err
return 0, err
}
recordLength := int32(binary.BigEndian.Uint16(recordHeader[3:]))
if w.prefixBytes >= recordLength {
if recordHeader[0] != typeHandshake || w.prefixBytes >= recordLength {
return w.dontFrag(recordHeader[:], source)
}
if recordLength > maxRecordLength {
Expand Down Expand Up @@ -78,6 +80,9 @@ func (w *tlsRecordFragWriter) ReadFrom(source io.Reader) (written int64, err err

tmp, err := w.writer.Write(buf)
w.prefixBytes = 0
if tmp >= 5 { //subtract bytes of added header
tmp -= 5
}
written = int64(tmp)
if err != nil {
return written, err
Expand All @@ -103,7 +108,7 @@ func (w *tlsRecordFragWriter) Write(data []byte) (written int, err error) {
hasMultipleRecords := recordLength < remainderLength
isRecordOverflow := recordLength > maxRecordLength

if hasPartialRecord || w.prefixBytes == recordLength || isRecordOverflow {
if data[0] != typeHandshake || hasPartialRecord || w.prefixBytes == recordLength || isRecordOverflow {
w.prefixBytes = 0
return w.writer.Write(data)
}
Expand All @@ -126,7 +131,11 @@ func (w *tlsRecordFragWriter) Write(data []byte) (written int, err error) {
}

w.prefixBytes = 0
return w.writer.Write(buf)
written, err = w.writer.Write(buf)
if written >= 5 {
written -= 5
}
return written, err
}
return w.writer.Write(data)
}
4 changes: 2 additions & 2 deletions transport/tls-record-frag/writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestWrite(t *testing.T) {
trfWriter := NewWriter(&innerWriter, 1)
n, err := trfWriter.Write(data)
require.NoError(t, err)
require.Equal(t, n, len(data)+5)
require.Equal(t, n, len(data))
require.Equal(t, [][]byte{[]byte{0x16, 0x03, 0x01, 0, 1, 0x1, 0x16, 0x03, 0x01, 0, 9, 0, 0, 6, 0x03, 0x03, 1, 2, 3, 4}}, innerWriter.writes)
}

Expand All @@ -54,7 +54,7 @@ func TestReadFrom(t *testing.T) {
trfWriter := NewWriter(&innerWriter, 2)
n, err := trfWriter.ReadFrom(bytes.NewReader(data))
require.NoError(t, err)
require.Equal(t, n, int64(len(data))+5)
require.Equal(t, n, int64(len(data)))
require.Equal(t, [][]byte{[]byte{0x16, 0x03, 0x01, 0, 2, 0x1, 0, 0x16, 0x03, 0x01, 0, 8, 0, 6, 0x03, 0x03, 1, 2, 3, 4}, []byte{0xff}}, innerWriter.writes)
}

Expand Down

0 comments on commit 12e1c28

Please sign in to comment.