Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(transport): implement TLS record fragmentation #114

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions transport/tls-record-frag/stream_dialer.go
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's call the package tlssplit, so it can be more easily used. Or perhaps tlsfrag is better.
@jyyi1 thoughts?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like tlsfrag. And for the stream dialer's name, I prefer tlsClientHelloFragStreamDialer to indicate that it will only be applied to client hello packets.

Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package tlsrecordfrag
jyyi1 marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
package tlsrecordfrag
package tlsfrag

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make the directory match the package name.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tlssplit or tlsfrag?


import (
"context"
"errors"

"github.com/Jigsaw-Code/outline-sdk/transport"
)

type tlsRecordFragDialer struct {
dialer transport.StreamDialer
splitPoint uint32
}

var _ transport.StreamDialer = (*tlsRecordFragDialer)(nil)

// NewStreamDialer creates a [transport.StreamDialer] that splits the Client Hello Message
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to clarify the behavior.

Suggested change
// NewStreamDialer creates a [transport.StreamDialer] that splits the Client Hello Message
// NewStreamDialer creates a [transport.StreamDialer] that splits the first TLS record if it's a handshake message. If the record has length N, the new records will have length prefixBytes and N - prefixBytes, plus their headers.

@jyyi1 let's make sure we agree on this behavior.

func NewStreamDialer(dialer transport.StreamDialer, prefixBytes uint32) (transport.StreamDialer, error) {
if dialer == nil {
return nil, errors.New("argument dialer must not be nil")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return nil, errors.New("argument dialer must not be nil")
return nil, errors.New("an inner dialer is required")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original code is consistent with the split Dialer:

return nil, errors.New("argument dialer must not be nil")

I think we can keep it as is for now, and change them all later.

}
return &tlsRecordFragDialer{dialer: dialer, splitPoint: prefixBytes}, nil
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check prefixBytes as well?

Suggested change
return &tlsRecordFragDialer{dialer: dialer, splitPoint: prefixBytes}, nil
if prefixBytes <= 0 {
return nil, errors.New("prefixBytes must be positive")
}
return &tlsRecordFragDialer{dialer: dialer, splitPoint: prefixBytes}, nil

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current code is consistent with split, and that's helpful for the user.
Also, returning an error makes the use harder. We can just assume <=0 means disabled (perhaps document that).

Lanius-collaris marked this conversation as resolved.
Show resolved Hide resolved
}

// Dial implements [transport.StreamDialer].Dial.
func (d *tlsRecordFragDialer) Dial(ctx context.Context, remoteAddr string) (transport.StreamConn, error) {
innerConn, err := d.dialer.Dial(ctx, remoteAddr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might need to inspect remoteAddr, and make sure that we are only splitting TLS traffic (not regular TCP traffic). Here are some common ports serving TLS traffic (feel free to add more):

  • 443, HTTPS
  • 990, FTPS
  • 563, NNTPS
  • 636, LDAPS
  • 993, IMAPS
  • 995, POP3S

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point. We should be able to restrict to specific ports.
But I would like the user to be able to specify that list.

@jyyi1 what are your thoughts on the questions below?

How to specify the ports?
Perhaps we pass a function that takes a host and port, and outputs a binary decision.
It could be passed in the New function with variadic options, like we do here:
https://github.com/Jigsaw-Code/outline-sdk/blob/375a66df04ef48659c73f0a7ed1156b04b2008a4/network/packet_listener_proxy.go#L58C68-L58C75

Or as a setter.

What's the default behavior?
I think the default should be either all or nothing. But perhaps 443, since that's the most common scenario.

Or perhaps we can force the user to make a decision by making the function mandatory in the constructor. That's probably my favorite. We can provide constant functions for "all" or "443 only". Also a EnableForPorts(portList) that returns a function that enables for that list. That way someone can do:

dialer, err := tlsfrag.NewStreamDialer(inner, 10, tlsfrag.EnablePorts(int[]{443}))

or

dialer, err := tlsfrag.NewStreamDialer(inner, 10, tlsfrag.Enable443)

Reference
Please link to https://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.xhtml?search=TLS

if err != nil {
return nil, err
}
return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.splitPoint)), nil
}
105 changes: 105 additions & 0 deletions transport/tls-record-frag/writer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package tlsrecordfrag

import (
"errors"
"io"
)

type tlsRecordFragWriter struct {
writer io.Writer
prefixBytes uint32
Lanius-collaris marked this conversation as resolved.
Show resolved Hide resolved
}

const maxRecordLength = 16384

func NewWriter(writer io.Writer, prefixBytes uint32) *tlsRecordFragWriter {
return &tlsRecordFragWriter{writer, prefixBytes}
}

func (w *tlsRecordFragWriter) dontFrag(first []byte, source io.Reader) (written int64, err error) {
tmp, err := w.writer.Write(first)
written = int64(tmp)
w.prefixBytes = 0
if err != nil {
return written, err
}
n, err := io.Copy(w.writer, source)
written += n
return written, err
}

func (w *tlsRecordFragWriter) ReadFrom(source io.Reader) (written int64, err error) {
if 0 < w.prefixBytes {
var first [5]byte
Lanius-collaris marked this conversation as resolved.
Show resolved Hide resolved
_, err := io.ReadFull(source, first[:])
if err != nil {
return 0, err
}
recordLength := uint32(first[3]) << 8 | uint32(first[4])
Lanius-collaris marked this conversation as resolved.
Show resolved Hide resolved
if w.prefixBytes >= recordLength {
Lanius-collaris marked this conversation as resolved.
Show resolved Hide resolved
return w.dontFrag(first[:], source)
}
if recordLength > maxRecordLength {
return 0, errors.New("Broken handshake message")
Lanius-collaris marked this conversation as resolved.
Show resolved Hide resolved
}
buf := make([]byte, recordLength+10)
Lanius-collaris marked this conversation as resolved.
Show resolved Hide resolved
n2, err := io.ReadFull(source, buf[5:5+w.prefixBytes])
if err != nil {
w.prefixBytes = 0
return 0, err
}
n3, err := io.ReadFull(source, buf[10+w.prefixBytes:])
if err != nil {
w.prefixBytes = 0
return 0, err
}

header := first[:3]

copy(buf, header)
Lanius-collaris marked this conversation as resolved.
Show resolved Hide resolved
buf[3] = byte(uint32(n2) >> 8)
buf[4] = byte(uint32(n2) & 0xff)

copy(buf[5+n2:], header)
buf[5+n2+3] = byte(uint32(n3) >> 8)
buf[5+n2+4] = byte(uint32(n3) & 0xff)

tmp, err := w.writer.Write(buf)
Lanius-collaris marked this conversation as resolved.
Show resolved Hide resolved
w.prefixBytes = 0
written = int64(tmp)
if err != nil {
return written, err
}
}
n, err := io.Copy(w.writer, source)
written += n
return written, err
}

func (w *tlsRecordFragWriter) Write(data []byte) (written int, err error) {
length := len(data)
if length > 5+maxRecordLength {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not the correct behavior, even if the connections is guaranteed TLS. A given write may have a single record, a partial record or multiple records. We need to fix this.

  • remove the check
  • read the record length from the data
  • fix the split calculation

We should assume the first write has the record length.

return 0, errors.New("Broken handshake message")
Lanius-collaris marked this conversation as resolved.
Show resolved Hide resolved
}
if 0 < w.prefixBytes && w.prefixBytes < uint32(length -5) {
Lanius-collaris marked this conversation as resolved.
Show resolved Hide resolved
buf := make([]byte, length+5)
header := data[:3]
record1 := data[5 : 5+w.prefixBytes]
record2 := data[5+w.prefixBytes:]

copy(buf, header)
buf[3] = byte(w.prefixBytes >> 8)
buf[4] = byte(w.prefixBytes & 0xff)
copy(buf[5:], record1)

copy(buf[5+w.prefixBytes:], header)
buf[5+3+w.prefixBytes] = byte(len(record2) >> 8)
buf[5+4+w.prefixBytes] = byte(len(record2) & 0xff)
copy(buf[5+5+w.prefixBytes:], record2)

w.prefixBytes = 0
return w.writer.Write(buf)
Lanius-collaris marked this conversation as resolved.
Show resolved Hide resolved
}
w.prefixBytes = 0
return w.writer.Write(data)
}
42 changes: 42 additions & 0 deletions transport/tls-record-frag/writer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package tlsrecordfrag

import (
"bytes"
"io"
"testing"

"github.com/stretchr/testify/require"
)

type collectWrites struct {
writes [][]byte
}

var _ io.Writer = (*collectWrites)(nil)

func (w *collectWrites) Write(data []byte) (int, error) {
dataCopy := make([]byte, len(data))
copy(dataCopy, data)
w.writes = append(w.writes, dataCopy)
return len(data), nil
}

func TestWrite(t *testing.T) {
data := []byte{0x16, 0x03, 0x01, 0, 10, 0x01, 0, 0, 6, 0x03, 0x03, 1, 2, 3, 4}
var innerWriter collectWrites
trfWriter := NewWriter(&innerWriter, 1)
n, err := trfWriter.Write(data)
require.NoError(t, err)
require.Equal(t, n, len(data)+5)
require.Equal(t, [][]byte{[]byte{0x16, 0x03, 0x01, 0, 1, 0x1, 0x16, 0x03, 0x01, 0, 9, 0, 0, 6, 0x03, 0x03, 1, 2, 3, 4}}, innerWriter.writes)
}

func TestReadFrom(t *testing.T) {
data := []byte{0x16, 0x03, 0x01, 0, 10, 0x01, 0, 0, 6, 0x03, 0x03, 1, 2, 3, 4, 0xff}
var innerWriter collectWrites
trfWriter := NewWriter(&innerWriter, 2)
n, err := trfWriter.ReadFrom(bytes.NewReader(data))
require.NoError(t, err)
require.Equal(t, n, int64(len(data))+5)
require.Equal(t, [][]byte{[]byte{0x16, 0x03, 0x01, 0, 2, 0x1, 0, 0x16, 0x03, 0x01, 0, 8, 0, 6, 0x03, 0x03, 1, 2, 3, 4}, []byte{0xff}}, innerWriter.writes)
}