diff --git a/transport/tls-record-frag/stream_dialer.go b/transport/tls-record-frag/stream_dialer.go new file mode 100644 index 00000000..a3c6f28b --- /dev/null +++ b/transport/tls-record-frag/stream_dialer.go @@ -0,0 +1,32 @@ +package tlsrecordfrag + +import ( + "context" + "errors" + + "github.com/Jigsaw-Code/outline-sdk/transport" +) + +type tlsRecordFragDialer struct { + dialer transport.StreamDialer + splitPoint uint32 +} + +var _ transport.StreamDialer = (*tlsRecordFragDialer)(nil) + +// NewStreamDialer creates a [transport.StreamDialer] that splits the Client Hello Message +func NewStreamDialer(dialer transport.StreamDialer, prefixBytes uint32) (transport.StreamDialer, error) { + if dialer == nil { + return nil, errors.New("argument dialer must not be nil") + } + return &tlsRecordFragDialer{dialer: dialer, splitPoint: prefixBytes}, nil +} + +// Dial implements [transport.StreamDialer].Dial. +func (d *tlsRecordFragDialer) Dial(ctx context.Context, remoteAddr string) (transport.StreamConn, error) { + innerConn, err := d.dialer.Dial(ctx, remoteAddr) + if err != nil { + return nil, err + } + return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.splitPoint)), nil +} \ No newline at end of file diff --git a/transport/tls-record-frag/writer.go b/transport/tls-record-frag/writer.go new file mode 100644 index 00000000..60ab2116 --- /dev/null +++ b/transport/tls-record-frag/writer.go @@ -0,0 +1,105 @@ +package tlsrecordfrag + +import ( + "errors" + "io" +) + +type tlsRecordFragWriter struct { + writer io.Writer + prefixBytes uint32 +} + +const maxRecordLength = 16384 + +func NewWriter(writer io.Writer, prefixBytes uint32) *tlsRecordFragWriter { + return &tlsRecordFragWriter{writer, prefixBytes} +} + +func (w *tlsRecordFragWriter) dontFrag(first []byte, source io.Reader) (written int64, err error) { + tmp, err := w.writer.Write(first) + written = int64(tmp) + w.prefixBytes = 0 + if err != nil { + return written, err + } + n, err := io.Copy(w.writer, source) + written += n + return written, err +} + +func (w *tlsRecordFragWriter) ReadFrom(source io.Reader) (written int64, err error) { + if 0 < w.prefixBytes { + var first [5]byte + _, err := io.ReadFull(source, first[:]) + if err != nil { + return 0, err + } + recordLength := uint32(first[3]) << 8 | uint32(first[4]) + if w.prefixBytes >= recordLength { + return w.dontFrag(first[:], source) + } + if recordLength > maxRecordLength { + return 0, errors.New("Broken handshake message") + } + buf := make([]byte, recordLength+10) + n2, err := io.ReadFull(source, buf[5:5+w.prefixBytes]) + if err != nil { + w.prefixBytes = 0 + return 0, err + } + n3, err := io.ReadFull(source, buf[10+w.prefixBytes:]) + if err != nil { + w.prefixBytes = 0 + return 0, err + } + + header := first[:3] + + copy(buf, header) + buf[3] = byte(uint32(n2) >> 8) + buf[4] = byte(uint32(n2) & 0xff) + + copy(buf[5+n2:], header) + buf[5+n2+3] = byte(uint32(n3) >> 8) + buf[5+n2+4] = byte(uint32(n3) & 0xff) + + tmp, err := w.writer.Write(buf) + w.prefixBytes = 0 + written = int64(tmp) + if err != nil { + return written, err + } + } + n, err := io.Copy(w.writer, source) + written += n + return written, err +} + +func (w *tlsRecordFragWriter) Write(data []byte) (written int, err error) { + length := len(data) + if length > 5+maxRecordLength { + return 0, errors.New("Broken handshake message") + } + if 0 < w.prefixBytes && w.prefixBytes < uint32(length -5) { + buf := make([]byte, length+5) + header := data[:3] + record1 := data[5 : 5+w.prefixBytes] + record2 := data[5+w.prefixBytes:] + + copy(buf, header) + buf[3] = byte(w.prefixBytes >> 8) + buf[4] = byte(w.prefixBytes & 0xff) + copy(buf[5:], record1) + + copy(buf[5+w.prefixBytes:], header) + buf[5+3+w.prefixBytes] = byte(len(record2) >> 8) + buf[5+4+w.prefixBytes] = byte(len(record2) & 0xff) + copy(buf[5+5+w.prefixBytes:], record2) + + w.prefixBytes = 0 + return w.writer.Write(buf) + } + w.prefixBytes = 0 + return w.writer.Write(data) +} diff --git a/transport/tls-record-frag/writer_test.go b/transport/tls-record-frag/writer_test.go new file mode 100644 index 00000000..24e276a7 --- /dev/null +++ b/transport/tls-record-frag/writer_test.go @@ -0,0 +1,42 @@ +package tlsrecordfrag + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/require" +) + +type collectWrites struct { + writes [][]byte +} + +var _ io.Writer = (*collectWrites)(nil) + +func (w *collectWrites) Write(data []byte) (int, error) { + dataCopy := make([]byte, len(data)) + copy(dataCopy, data) + w.writes = append(w.writes, dataCopy) + return len(data), nil +} + +func TestWrite(t *testing.T) { + data := []byte{0x16, 0x03, 0x01, 0, 10, 0x01, 0, 0, 6, 0x03, 0x03, 1, 2, 3, 4} + var innerWriter collectWrites + trfWriter := NewWriter(&innerWriter, 1) + n, err := trfWriter.Write(data) + require.NoError(t, err) + require.Equal(t, n, len(data)+5) + require.Equal(t, [][]byte{[]byte{0x16, 0x03, 0x01, 0, 1, 0x1, 0x16, 0x03, 0x01, 0, 9, 0, 0, 6, 0x03, 0x03, 1, 2, 3, 4}}, innerWriter.writes) +} + +func TestReadFrom(t *testing.T) { + data := []byte{0x16, 0x03, 0x01, 0, 10, 0x01, 0, 0, 6, 0x03, 0x03, 1, 2, 3, 4, 0xff} + var innerWriter collectWrites + trfWriter := NewWriter(&innerWriter, 2) + n, err := trfWriter.ReadFrom(bytes.NewReader(data)) + require.NoError(t, err) + require.Equal(t, n, int64(len(data))+5) + require.Equal(t, [][]byte{[]byte{0x16, 0x03, 0x01, 0, 2, 0x1, 0, 0x16, 0x03, 0x01, 0, 8, 0, 6, 0x03, 0x03, 1, 2, 3, 4}, []byte{0xff}}, innerWriter.writes) +}