Skip to content

Commit

Permalink
simplify stream_dialer by extracting out TLS dispatching
Browse files Browse the repository at this point in the history
  • Loading branch information
jyyi1 committed Nov 16, 2023
1 parent 653b855 commit 2ed8304
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 331 deletions.
125 changes: 5 additions & 120 deletions transport/tlsfrag/stream_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@ package tlsfrag
import (
"context"
"errors"
"fmt"
"net"
"strconv"
"strings"

"github.com/Jigsaw-Code/outline-sdk/transport"
)
Expand All @@ -30,7 +26,6 @@ import (
type tlsFragDialer struct {
dialer transport.StreamDialer
frag FragFunc
config *DialerConfiguration
}

// Compilation guard against interface implementation
Expand All @@ -42,98 +37,34 @@ var _ transport.StreamDialer = (*tlsFragDialer)(nil)
// If the returned index is either ≤ 0 or ≥ len(record), no fragmentation will occur.
//
// [handshake record]: https://datatracker.ietf.org/doc/html/rfc8446#section-5.1
type FragFunc func(record []byte) int

// DialerConfiguration is an internal type used to configure the [transport.StreamDialer] created by
// [NewStreamDialerFunc]. You don't need to work with it directly. Instead, use the provided configuration functions
// like [WithTLSHostPortList].
type DialerConfiguration struct {
addrs []*tlsAddrEntry
}

// DialerConfigurer updates the settings in the internal DialerConfiguration object. You can use the configuration
// functions such as [WithTLSHostPortList] to create configurers and then pass them to NewStreamDialerFunc to create a
// [transport.StreamDialer] with your desired configuration.
type DialerConfigurer func(*DialerConfiguration) error
type FragFunc func(record []byte) (n int)

// NewStreamDialerFunc creates a [transport.StreamDialer] that intercepts the initial [TLS Client Hello]
// [handshake record] and splits it into two separate records before sending them. The split point is determined by the
// callback function frag. The dialer then adds appropriate headers to each record and transmits them sequentially
// using the base dialer. Following the fragmented Client Hello, all subsequent data is passed through directly without
// modification.
//
// NewStreamDialerFunc allows specifying additional options to customize its behavior. By default, if no options are
// specified, the fragmentation only affects TLS Client Hello messages targeting port 443. All other network traffic,
// including non-TLS or non-Client Hello messages, or those targeting other ports, are passed through without any
// modification.
//
// [TLS Client Hello]: https://datatracker.ietf.org/doc/html/rfc8446#section-4.1.2
// [handshake record]: https://datatracker.ietf.org/doc/html/rfc8446#section-5.1
func NewStreamDialerFunc(base transport.StreamDialer, frag FragFunc, options ...DialerConfigurer) (transport.StreamDialer, error) {
func NewStreamDialerFunc(base transport.StreamDialer, frag FragFunc) (transport.StreamDialer, error) {
if base == nil {
return nil, errors.New("base dialer must not be nil")
}
if frag == nil {
return nil, errors.New("frag function must not be nil")
}
config := &DialerConfiguration{
addrs: []*tlsAddrEntry{{"", 443}},
}
for _, opt := range options {
if opt != nil {
if err := opt(config); err != nil {
return nil, err
}
}
}
return &tlsFragDialer{base, frag, config}, nil
}

// WithTLSHostPortList tells the [transport.StreamDialer] which connections to treat as TLS. Only connections matching
// entries in the tlsAddrs list will be treated as TLS traffic and fragmented accordingly.
//
// Each entry in the tlsAddrs list should be in the format "host:port", where "host" can be an IP address or a domain
// name, and "port" must be a valid port number. You can use empty string "" as the "host" to only match based on the
// port, and "0" as the "port" to match any port.
//
// The default list only includes ":443", meaning all traffic on port 443 is treated as TLS. This function overrides
// the entire list. So if you want to add entries, you need to include ":443" along with your additional entries.
//
// Matching for "host" is case-insensitive and strict. For example, "google.com:123" will only match "google.com" and
// not "www.google.com". Subdomain wildcards are not supported.
func WithTLSHostPortList(tlsAddrs []string) DialerConfigurer {
return func(c *DialerConfiguration) error {
addrs := make([]*tlsAddrEntry, 0, len(tlsAddrs))
for _, hostport := range tlsAddrs {
addr, err := parseTLSAddrEntry(hostport)
if err != nil {
return err
}
addrs = append(addrs, addr)
}
c.addrs = addrs
return nil
}
return &tlsFragDialer{base, frag}, nil
}

// Dial implements [transport.StreamConn].Dial. It establishes a connection to raddr in the format "host-or-ip:port".
//
// If raddr matches an entry in the valid TLS address list (which can be configured using [WithTLSHostPortList]), the
// initial TLS Client Hello record sent through the connection will be fragmented.
//
// If raddr is not listed in the valid TLS address list, the function simply utilizes the underlying base dialer's Dial
// function to establish the connection without any fragmentation.
// The initial TLS Client Hello record sent through the connection will be fragmented.
func (d *tlsFragDialer) Dial(ctx context.Context, raddr string) (conn transport.StreamConn, err error) {
conn, err = d.dialer.Dial(ctx, raddr)
if err != nil {
return
}
for _, addr := range d.config.addrs {
if addr.matches(raddr) {
return WrapConnFunc(conn, d.frag)
}
}
return
return WrapConnFunc(conn, d.frag)
}

// WrapConnFunc wraps the base [transport.StreamConn] and splits the first TLS Client Hello packet into two records
Expand All @@ -146,49 +77,3 @@ func WrapConnFunc(base transport.StreamConn, frag FragFunc) (transport.StreamCon
}
return transport.WrapConn(base, base, w), nil
}

// tlsAddrEntry reprsents an entry of the TLS traffic list. See [WithTLSHostPortList].
type tlsAddrEntry struct {
host string
port int
}

// parseTLSAddrEntry parses hostport in format "host:port" and returns the corresponding tlsAddrEntry.
func parseTLSAddrEntry(hostport string) (*tlsAddrEntry, error) {
host, portStr, err := net.SplitHostPort(hostport)
if err != nil {
return nil, err
}
port, err := strconv.Atoi(portStr)
if err != nil {
return nil, err
}
if port < 0 || port > 65535 {
return nil, fmt.Errorf("port must be within 0-65535: %w", strconv.ErrRange)
}
return &tlsAddrEntry{host, port}, nil
}

// matches returns whether raddr matches this entry.
func (e *tlsAddrEntry) matches(raddr string) bool {
if len(e.host) == 0 && e.port == 0 {
return true
}
host, portStr, err := net.SplitHostPort(raddr)
if err != nil {
return false
}
if len(e.host) > 0 && !strings.EqualFold(e.host, host) {
return false
}
if e.port > 0 {
port, err := strconv.Atoi(portStr)
if err != nil {
return false
}
if e.port != port {
return false
}
}
return true
}
187 changes: 0 additions & 187 deletions transport/tlsfrag/stream_dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,190 +13,3 @@
// limitations under the License.

package tlsfrag

import (
"context"
"strconv"
"testing"

"github.com/Jigsaw-Code/outline-sdk/transport"
"github.com/stretchr/testify/require"
)

// this is the local conn that can be shared across tests
var theLocalConn = &localConn{}

// Make sure NewStreamDialer returns error on invalid WithTLSHostPortList calls.
func TestNewStreamDialerWithInvalidTLSAddr(t *testing.T) {
cases := []struct {
addr string
errType error // nil indicates general error
}{
{"1.2.3.4", nil},
{":::::", nil},
{"1.2.3.4:654-321", strconv.ErrSyntax},
{"1.2.3.4:--8080", strconv.ErrSyntax},
{"[::]:10000000000", strconv.ErrRange},
{"1.2.3.4:-1234", strconv.ErrRange},
{":654321", strconv.ErrRange},
}
for _, tc := range cases {
d, err := NewStreamDialerFunc(localConnDialer{}, func([]byte) int { return 0 }, WithTLSHostPortList([]string{tc.addr}))
require.Error(t, err, tc.addr)
if tc.errType != nil {
require.ErrorIs(t, err, tc.errType, tc.addr)
}
require.Nil(t, d)
}
}

// Make sure no fragmentation connection is created if raddr is not in the allowed list.
func TestDialFragmentOnTLSAddrOnly(t *testing.T) {
tlsAddrs := []string{
":443", // default entry
":990", // additional FTPS port
":853", // additional DNS-over-TLS port
"pop.gmail.com:995", // Gmail pop3
}
cases := []struct {
msg string
raddrs []string
shouldFrag bool
shouldFragWithList bool
}{
{
msg: "*:443 should be fragmented, raddr = %s",
raddrs: []string{"example.com:443", "66.77.88.99:443", "[2001:db8::1]:443"},
shouldFrag: true,
shouldFragWithList: true,
},
{
msg: "*:990 should be fragmented by allowlist, raddr = %s",
raddrs: []string{"my-test.org:990", "192.168.1.10:990", "[2001:db8:3333:4444:5555:6666:7777:8888]:990"},
shouldFrag: false,
shouldFragWithList: true,
},
{
msg: "*:8080 should not be fragmented, raddr = %s",
raddrs: []string{"google.com:8080", "64.233.191.255:8080", "[2001:db8:3333:4444:5555:6666:7777:8888]:8080"},
shouldFrag: false,
shouldFragWithList: false,
},
{
msg: "DNS ports should not be fragmented, raddr = %s",
raddrs: []string{"8.8.8.8:53", "8.8.4.4:53", "2001:4860:4860::8888", "2001:4860:4860::8844"},
shouldFrag: false,
shouldFragWithList: false,
},
{
msg: "DNS over TLS ports should be fragmented by allowlist, raddr = %s",
raddrs: []string{"9.9.9.9:853", "8.8.4.4:853", "[2001:4860:4860::8844]:853", "[2620:fe::fe]:853"},
shouldFrag: false,
shouldFragWithList: true,
},
{
msg: "only gmail POP3 should be fragmented by allowlist, raddr = %s",
raddrs: []string{"pop.GMail.com:995"},
shouldFrag: false,
shouldFragWithList: true,
},
{
msg: "non-gmail POP3 should not be fragmented, raddr = %s",
raddrs: []string{"8.8.8.8:995", "outlook.office365.com:995", "outlook.office365.com:993", "pop.gmail.com:993"},
shouldFrag: false,
shouldFragWithList: false,
},
}

base := localConnDialer{}
assertShouldFrag := func(conn transport.StreamConn, msg, addr string) {
prevWrCnt := theLocalConn.writeCount
// this Write should not be pushed to theLocalConn yet because it's a valid TLS handshake
conn.Write([]byte{22})

nonFragConn, ok := conn.(*localConn)
require.False(t, ok, msg, addr)
require.Nil(t, nonFragConn, msg)
require.Equal(t, prevWrCnt, theLocalConn.writeCount, msg, addr)
}
assertShouldNotFrag := func(conn transport.StreamConn, msg, addr string) {
prevWrCnt := theLocalConn.writeCount
// this Write should be pushed to theLocalConn because it's a direct Write call
conn.Write([]byte{22})

nonFragConn, ok := conn.(*localConn)
require.True(t, ok, msg, addr)
require.NotNil(t, nonFragConn, msg, addr)
require.Equal(t, theLocalConn, nonFragConn)
require.Equal(t, prevWrCnt+1, theLocalConn.writeCount, msg, addr)
}

// default dialer
d1, err := NewStreamDialerFunc(base, func([]byte) int { return 0 })
require.NoError(t, err)
require.NotNil(t, d1)

// with additional tls addrs
d2, err := NewStreamDialerFunc(base, func([]byte) int { return 0 }, WithTLSHostPortList(tlsAddrs))
require.NoError(t, err)
require.NotNil(t, d2)

// with no tls addrs
d3, err := NewStreamDialerFunc(base, func([]byte) int { return 0 }, WithTLSHostPortList([]string{}))
require.NoError(t, err)
require.NotNil(t, d3)

// all traffic
d4, err := NewStreamDialerFunc(base, func([]byte) int { return 0 }, WithTLSHostPortList([]string{":0"}))
require.NoError(t, err)
require.NotNil(t, d4)

for _, tc := range cases {
for _, addr := range tc.raddrs {
conn, err := d1.Dial(context.Background(), addr)
require.NoError(t, err, tc.msg, addr)
require.NotNil(t, conn, tc.msg, addr)
if tc.shouldFrag {
assertShouldFrag(conn, tc.msg, addr)
} else {
assertShouldNotFrag(conn, tc.msg, addr)
}

conn, err = d2.Dial(context.Background(), addr)
require.NoError(t, err, tc.msg, addr)
require.NotNil(t, conn, tc.msg, addr)
if tc.shouldFragWithList {
assertShouldFrag(conn, tc.msg, addr)
} else {
assertShouldNotFrag(conn, tc.msg, addr)
}

conn, err = d3.Dial(context.Background(), addr)
require.NoError(t, err, tc.msg, addr)
require.NotNil(t, conn, tc.msg, addr)
assertShouldNotFrag(conn, tc.msg, addr)

conn, err = d4.Dial(context.Background(), addr)
require.NoError(t, err, tc.msg, addr)
require.NotNil(t, conn, tc.msg, addr)
assertShouldFrag(conn, tc.msg, addr)
}
}
}

// testing utilitites

type localConnDialer struct{}
type localConn struct {
transport.StreamConn
writeCount int
}

func (localConnDialer) Dial(ctx context.Context, raddr string) (transport.StreamConn, error) {
return theLocalConn, nil
}

func (lc *localConn) Write(b []byte) (n int, err error) {
lc.writeCount++
return len(b), nil
}
Loading

0 comments on commit 2ed8304

Please sign in to comment.