From 5a687a0d72019ee9eb630d097a1b4c545179b240 Mon Sep 17 00:00:00 2001 From: jyyi1 Date: Thu, 16 Nov 2023 18:39:21 -0500 Subject: [PATCH] add test cases for stream dialer --- network/dnstruncate/packet_proxy_test.go | 18 +-- transport/tlsfrag/stream_dialer_test.go | 141 +++++++++++++++++++++++ 2 files changed, 150 insertions(+), 9 deletions(-) diff --git a/network/dnstruncate/packet_proxy_test.go b/network/dnstruncate/packet_proxy_test.go index 44d53b5d..cf2112b9 100644 --- a/network/dnstruncate/packet_proxy_test.go +++ b/network/dnstruncate/packet_proxy_test.go @@ -164,15 +164,15 @@ func constructDNSQuestionsFromDomainNames(questions []string) []layers.DNSQuesti // constructDNSRequestOrResponse creates the following DNS request/response: // -// [ `id` ]: 2 bytes -// [ Standard-Query/Response + Recursive ]: 0x01/0x81 -// [ Reserved/Response-No-Err ]: 0x00 -// [ Questions-Count ]: 2 bytes (= len(questions)) -// [ Answers Count ]: 2 bytes (= 0x00 0x00 / len(questions)) -// [ Authorities Count ]: 0x00 0x00 -// [ Resources Count ]: 0x00 0x01 -// [ `questions` ]: ? bytes -// [ Additional Resources ]: ? bytes (= OPT(payload_size=4096)) +// [ `id` ]: 2 bytes +// [ Standard-Query/Response + Recursive ]: 0x01/0x81 +// [ Reserved/Response-No-Err ]: 0x00 +// [ Questions-Count ]: 2 bytes (= len(questions)) +// [ Answers Count ]: 2 bytes (= 0x00 0x00 / len(questions)) +// [ Authorities Count ]: 0x00 0x00 +// [ Resources Count ]: 0x00 0x01 +// [ `questions` ]: ? bytes +// [ Additional Resources ]: ? bytes (= OPT(payload_size=4096)) // // https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.1 // diff --git a/transport/tlsfrag/stream_dialer_test.go b/transport/tlsfrag/stream_dialer_test.go index 8065fe06..887f8f5c 100644 --- a/transport/tlsfrag/stream_dialer_test.go +++ b/transport/tlsfrag/stream_dialer_test.go @@ -13,3 +13,144 @@ // limitations under the License. package tlsfrag + +import ( + "context" + "errors" + "io" + "net" + "testing" + "time" + + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" +) + +// Make sure only the first Client Hello is splitted. +func TestStreamDialerFuncSplitsClientHello(t *testing.T) { + hello := constructTLSRecord(t, layers.TLSHandshake, 0x0301, []byte{0x01, 0x00, 0x00, 0x03, 0xaa, 0xbb, 0xcc}) + cipher := constructTLSRecord(t, layers.TLSChangeCipherSpec, 0x0303, []byte{0x01}) + req1 := constructTLSRecord(t, layers.TLSApplicationData, 0x0303, []byte{0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88}) + + inner := &collectStreamDialer{} + conn := assertCanDialFragFunc(t, inner, "ipinfo.io:443", func(_ []byte) int { return 2 }) + defer conn.Close() + + assertCanWriteAll(t, conn, net.Buffers{hello, cipher, req1, hello, cipher, req1}) + + frag1 := constructTLSRecord(t, layers.TLSHandshake, 0x0301, []byte{0x01, 0x00}) + frag2 := constructTLSRecord(t, layers.TLSHandshake, 0x0301, []byte{0x00, 0x03, 0xaa, 0xbb, 0xcc}) + expected := net.Buffers{ + append(frag1, frag2...), // fragment 1 and fragment 2 will be merged in one single Write + cipher, req1, hello, cipher, req1, // unchanged + } + require.Equal(t, expected, inner.bufs) +} + +// Make sure we don't split if the first packet is not a Client Hello. +func TestStreamDialerFuncDontSplitNonClientHello(t *testing.T) { + cases := []struct { + msg string + pkt []byte + }{ + { + msg: "application data", + pkt: constructTLSRecord(t, layers.TLSApplicationData, 0x0303, []byte{0x01, 0x00, 0x00, 0x03, 0xdd, 0xee, 0xff}), + }, + { + msg: "cipher", + pkt: constructTLSRecord(t, layers.TLSChangeCipherSpec, 0x0303, []byte{0xff}), + }, + { + msg: "invalid version", + pkt: constructTLSRecord(t, layers.TLSHandshake, 0x0305, []byte{0x01, 0x00, 0x00, 0x03, 0xdd, 0xee, 0xff}), + }, + { + msg: "invalid length", + pkt: constructTLSRecord(t, layers.TLSHandshake, 0x0305, []byte{}), + }, + } + + cipher := constructTLSRecord(t, layers.TLSChangeCipherSpec, 0x0303, []byte{0x01}) + req := constructTLSRecord(t, layers.TLSApplicationData, 0x0303, []byte{0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88}) + + for _, tc := range cases { + inner := &collectStreamDialer{} + conn := assertCanDialFragFunc(t, inner, "ipinfo.io:443", func(_ []byte) int { return 2 }) + defer conn.Close() + + assertCanWriteAll(t, conn, net.Buffers{tc.pkt, cipher, req}) + expected := net.Buffers{tc.pkt, cipher, req} + if len(tc.pkt) > 5 { + // header and content of the first pkt might be issued by two Writes, but they are not fragmented + expected = net.Buffers{tc.pkt[:5], tc.pkt[5:], cipher, req} + } + require.Equal(t, expected, inner.bufs, tc.msg) + } +} + +// test assertions + +func assertCanDialFragFunc(t *testing.T, inner transport.StreamDialer, raddr string, frag FragFunc) transport.StreamConn { + d, err := NewStreamDialerFunc(inner, frag) + require.NoError(t, err) + require.NotNil(t, d) + conn, err := d.Dial(context.Background(), raddr) + require.NoError(t, err) + require.NotNil(t, conn) + return conn +} + +func assertCanWriteAll(t *testing.T, w io.Writer, buf net.Buffers) { + for _, p := range buf { + n, err := w.Write(p) + require.NoError(t, err) + require.Equal(t, len(p), n) + } +} + +// private test helpers + +func constructTLSRecord(t *testing.T, typ layers.TLSType, ver layers.TLSVersion, payload []byte) []byte { + pkt := layers.TLS{ + AppData: []layers.TLSAppDataRecord{{ + TLSRecordHeader: layers.TLSRecordHeader{ + ContentType: typ, + Version: ver, + Length: uint16(len(payload)), + }, + Payload: payload, + }}, + } + + buf := gopacket.NewSerializeBuffer() + err := pkt.SerializeTo(buf, gopacket.SerializeOptions{}) + require.NoError(t, err) + return buf.Bytes() +} + +// collectStreamDialer collects all writes to this stream dialer and append it to bufs +type collectStreamDialer struct { + bufs net.Buffers +} + +func (d *collectStreamDialer) Dial(ctx context.Context, raddr string) (transport.StreamConn, error) { + return d, nil +} + +func (c *collectStreamDialer) Write(p []byte) (int, error) { + c.bufs = append(c.bufs, append([]byte{}, p...)) // copy p rather than retaining it according to the principle of Write + return len(p), nil +} + +func (c *collectStreamDialer) Read(p []byte) (int, error) { return 0, errors.New("not supported") } +func (c *collectStreamDialer) Close() error { return nil } +func (c *collectStreamDialer) CloseRead() error { return nil } +func (c *collectStreamDialer) CloseWrite() error { return nil } +func (c *collectStreamDialer) LocalAddr() net.Addr { return nil } +func (c *collectStreamDialer) RemoteAddr() net.Addr { return nil } +func (c *collectStreamDialer) SetDeadline(t time.Time) error { return errors.New("not supported") } +func (c *collectStreamDialer) SetReadDeadline(t time.Time) error { return errors.New("not supported") } +func (c *collectStreamDialer) SetWriteDeadline(t time.Time) error { return errors.New("not supported") }