Skip to content

Commit

Permalink
feat: implement TLS record fragmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
Lanius-collaris committed Oct 22, 2023
1 parent 7bc0cae commit 751117a
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 0 deletions.
32 changes: 32 additions & 0 deletions transport/tls-record-frag/stream_dialer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package tlsrecordfrag

import (
"context"
"errors"

"github.com/Jigsaw-Code/outline-sdk/transport"
)

type tlsRecordFragDialer struct {
dialer transport.StreamDialer
splitPoint uint32
}

var _ transport.StreamDialer = (*tlsRecordFragDialer)(nil)

// NewStreamDialer creates a [transport.StreamDialer] that splits the Client Hello Message
func NewStreamDialer(dialer transport.StreamDialer, prefixBytes uint32) (transport.StreamDialer, error) {
if dialer == nil {
return nil, errors.New("argument dialer must not be nil")
}
return &tlsRecordFragDialer{dialer: dialer, splitPoint: prefixBytes}, nil
}

// Dial implements [transport.StreamDialer].Dial.
func (d *tlsRecordFragDialer) Dial(ctx context.Context, remoteAddr string) (transport.StreamConn, error) {
innerConn, err := d.dialer.Dial(ctx, remoteAddr)
if err != nil {
return nil, err
}
return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.splitPoint)), nil
}
105 changes: 105 additions & 0 deletions transport/tls-record-frag/writer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package tlsrecordfrag

import (
"errors"
"io"
)

type tlsRecordFragWriter struct {
writer io.Writer
prefixBytes uint32
}

const maxRecordLength = 16384

func NewWriter(writer io.Writer, prefixBytes uint32) *tlsRecordFragWriter {
return &tlsRecordFragWriter{writer, prefixBytes}
}

func (w *tlsRecordFragWriter) dontFrag(first []byte, source io.Reader) (written int64, err error) {
tmp, err := w.writer.Write(first)
written = int64(tmp)
w.prefixBytes = 0
if err != nil {
return written, err
}
n, err := io.Copy(w.writer, source)
written += n
return written, err
}

func (w *tlsRecordFragWriter) ReadFrom(source io.Reader) (written int64, err error) {
if 0 < w.prefixBytes {
var first [5]byte
_, err := io.ReadFull(source, first[:])
if err != nil {
return 0, err
}
recordLength := uint32(first[3]) << 8 | uint32(first[4])
if w.prefixBytes >= recordLength {
return w.dontFrag(first[:], source)
}
if recordLength > maxRecordLength {
return 0, errors.New("Broken handshake message")
}
buf := make([]byte, recordLength+10)
n2, err := io.ReadFull(source, buf[5:5+w.prefixBytes])
if err != nil {
w.prefixBytes = 0
return 0, err
}
n3, err := io.ReadFull(source, buf[10+w.prefixBytes:])
if err != nil {
w.prefixBytes = 0
return 0, err
}

header := first[:3]

copy(buf, header)
buf[3] = byte(uint32(n2) >> 8)
buf[4] = byte(uint32(n2) & 0xff)

copy(buf[5+n2:], header)
buf[5+n2+3] = byte(uint32(n3) >> 8)
buf[5+n2+4] = byte(uint32(n3) & 0xff)

tmp, err := w.writer.Write(buf)
w.prefixBytes = 0
written = int64(tmp)
if err != nil {
return written, err
}
}
n, err := io.Copy(w.writer, source)
written += n
return written, err
}

func (w *tlsRecordFragWriter) Write(data []byte) (written int, err error) {
length := len(data)
if length > 5+maxRecordLength {
return 0, errors.New("Broken handshake message")
}
if 0 < w.prefixBytes && w.prefixBytes < uint32(length -5) {
buf := make([]byte, length+5)
header := data[:3]
record1 := data[5 : 5+w.prefixBytes]
record2 := data[5+w.prefixBytes:]

copy(buf, header)
buf[3] = byte(w.prefixBytes >> 8)
buf[4] = byte(w.prefixBytes & 0xff)
copy(buf[5:], record1)

copy(buf[5+w.prefixBytes:], header)
buf[5+3+w.prefixBytes] = byte(len(record2) >> 8)
buf[5+4+w.prefixBytes] = byte(len(record2) & 0xff)
copy(buf[5+5+w.prefixBytes:], record2)

w.prefixBytes = 0
return w.writer.Write(buf)
}
w.prefixBytes = 0
return w.writer.Write(data)
}
42 changes: 42 additions & 0 deletions transport/tls-record-frag/writer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package tlsrecordfrag

import (
"bytes"
"io"
"testing"

"github.com/stretchr/testify/require"
)

type collectWrites struct {
writes [][]byte
}

var _ io.Writer = (*collectWrites)(nil)

func (w *collectWrites) Write(data []byte) (int, error) {
dataCopy := make([]byte, len(data))
copy(dataCopy, data)
w.writes = append(w.writes, dataCopy)
return len(data), nil
}

func TestWrite(t *testing.T) {
data := []byte{0x16, 0x03, 0x01, 0, 10, 0x01, 0, 0, 6, 0x03, 0x03, 1, 2, 3, 4}
var innerWriter collectWrites
trfWriter := NewWriter(&innerWriter, 1)
n, err := trfWriter.Write(data)
require.NoError(t, err)
require.Equal(t, n, len(data)+5)
require.Equal(t, [][]byte{[]byte{0x16, 0x03, 0x01, 0, 1, 0x1, 0x16, 0x03, 0x01, 0, 9, 0, 0, 6, 0x03, 0x03, 1, 2, 3, 4}}, innerWriter.writes)
}

func TestReadFrom(t *testing.T) {
data := []byte{0x16, 0x03, 0x01, 0, 10, 0x01, 0, 0, 6, 0x03, 0x03, 1, 2, 3, 4, 0xff}
var innerWriter collectWrites
trfWriter := NewWriter(&innerWriter, 2)
n, err := trfWriter.ReadFrom(bytes.NewReader(data))
require.NoError(t, err)
require.Equal(t, n, int64(len(data))+5)
require.Equal(t, [][]byte{[]byte{0x16, 0x03, 0x01, 0, 2, 0x1, 0, 0x16, 0x03, 0x01, 0, 8, 0, 6, 0x03, 0x03, 1, 2, 3, 4}, []byte{0xff}}, innerWriter.writes)
}

0 comments on commit 751117a

Please sign in to comment.