Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into amir-tc-udp
Browse files Browse the repository at this point in the history
  • Loading branch information
amircybersec committed Nov 11, 2024
2 parents d5d398b + 8f91506 commit cd78174
Show file tree
Hide file tree
Showing 35 changed files with 1,131 additions and 637 deletions.
18 changes: 11 additions & 7 deletions transport/split/stream_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,24 @@ import (
"github.com/Jigsaw-Code/outline-sdk/transport"
)

// splitDialer is a [transport.StreamDialer] that implements the split strategy.
// Use [NewStreamDialer] to create new instances.
type splitDialer struct {
dialer transport.StreamDialer
splitPoint int64
dialer transport.StreamDialer
nextSplit SplitIterator
}

var _ transport.StreamDialer = (*splitDialer)(nil)

// NewStreamDialer creates a [transport.StreamDialer] that splits the outgoing stream after writing "prefixBytes" bytes
// using [SplitWriter].
func NewStreamDialer(dialer transport.StreamDialer, prefixBytes int64) (transport.StreamDialer, error) {
// NewStreamDialer creates a [transport.StreamDialer] that splits the outgoing stream according to nextSplit.
func NewStreamDialer(dialer transport.StreamDialer, nextSplit SplitIterator) (transport.StreamDialer, error) {
if dialer == nil {
return nil, errors.New("argument dialer must not be nil")
}
return &splitDialer{dialer: dialer, splitPoint: prefixBytes}, nil
if nextSplit == nil {
return nil, errors.New("argument nextSplit must not be nil")
}
return &splitDialer{dialer: dialer, nextSplit: nextSplit}, nil
}

// DialStream implements [transport.StreamDialer].DialStream.
Expand All @@ -43,5 +47,5 @@ func (d *splitDialer) DialStream(ctx context.Context, remoteAddr string) (transp
if err != nil {
return nil, err
}
return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.splitPoint)), nil
return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.nextSplit)), nil
}
108 changes: 92 additions & 16 deletions transport/split/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ import (
)

type splitWriter struct {
writer io.Writer
prefixBytes int64
writer io.Writer
// Bytes until the next split. This must always be > 0, unless splits are done.
nextSplitBytes int64
nextSegmentLength func() int64
}

var _ io.Writer = (*splitWriter)(nil)
Expand All @@ -32,36 +34,110 @@ type splitWriterReaderFrom struct {

var _ io.ReaderFrom = (*splitWriterReaderFrom)(nil)

// NewWriter creates a [io.Writer] that ensures the byte sequence is split at prefixBytes.
// A write will end right after byte index prefixBytes - 1, before a write starting at byte index prefixBytes.
// For example, if you have a write of [0123456789] and prefixBytes = 3, you will get writes [012] and [3456789].
// If the input writer is a [io.ReaderFrom], the output writer will be too.
func NewWriter(writer io.Writer, prefixBytes int64) io.Writer {
sw := &splitWriter{writer, prefixBytes}
// SplitIterator is a function that returns how many bytes until the next split point, or zero if there are no more splits to do.
type SplitIterator func() int64

// NewFixedSplitIterator is a helper function that returns a [SplitIterator] that returns the input number once, followed by zero.
// This is helpful for when you want to split the stream once in a fixed position.
func NewFixedSplitIterator(n int64) SplitIterator {
return func() int64 {
next := n
n = 0
return next
}
}

// RepeatedSplit represents a split sequence of count segments with bytes length.
type RepeatedSplit struct {
Count int
Bytes int64
}

// NewRepeatedSplitIterator is a helper function that returns a [SplitIterator] that returns split points according to splits.
// The splits input represents pairs of (count, bytes), meaning a sequence of count splits with bytes length.
// This is helpful for when you want to split the stream repeatedly at different positions and lengths.
func NewRepeatedSplitIterator(splits ...RepeatedSplit) SplitIterator {
// Make sure we don't edit the original slice.
cleanSplits := make([]RepeatedSplit, 0, len(splits))
// Remove no-op splits.
for _, split := range splits {
if split.Count > 0 && split.Bytes > 0 {
cleanSplits = append(cleanSplits, split)
}
}
return func() int64 {
if len(cleanSplits) == 0 {
return 0
}
next := cleanSplits[0].Bytes
cleanSplits[0].Count -= 1
if cleanSplits[0].Count == 0 {
cleanSplits = cleanSplits[1:]
}
return next
}
}

// NewWriter creates a split Writer that calls the nextSegmentLength [SplitIterator] to determine the number bytes until the next split
// point until it returns zero.
func NewWriter(writer io.Writer, nextSegmentLength SplitIterator) io.Writer {
sw := &splitWriter{writer: writer, nextSegmentLength: nextSegmentLength}
sw.nextSplitBytes = nextSegmentLength()
if rf, ok := writer.(io.ReaderFrom); ok {
return &splitWriterReaderFrom{sw, rf}
}
return sw
}

// ReadFrom implements io.ReaderFrom.
func (w *splitWriterReaderFrom) ReadFrom(source io.Reader) (int64, error) {
reader := io.MultiReader(io.LimitReader(source, w.prefixBytes), source)
written, err := w.rf.ReadFrom(reader)
w.prefixBytes -= written
var written int64
for w.nextSplitBytes > 0 {
expectedBytes := w.nextSplitBytes
n, err := w.rf.ReadFrom(io.LimitReader(source, expectedBytes))
written += n
w.advance(n)
if err != nil {
return written, err
}
if n < expectedBytes {
// Source is done before the split happened. Return.
return written, err
}
}
n, err := w.rf.ReadFrom(source)
written += n
w.advance(n)
return written, err
}

func (w *splitWriter) advance(n int64) {
if w.nextSplitBytes == 0 {
// Done with splits: return.
return
}
w.nextSplitBytes -= int64(n)
if w.nextSplitBytes > 0 {
return
}
// Split done, set up the next split.
w.nextSplitBytes = w.nextSegmentLength()
}

// Write implements io.Writer.
func (w *splitWriter) Write(data []byte) (written int, err error) {
if 0 < w.prefixBytes && w.prefixBytes < int64(len(data)) {
written, err = w.writer.Write(data[:w.prefixBytes])
w.prefixBytes -= int64(written)
for 0 < w.nextSplitBytes && w.nextSplitBytes < int64(len(data)) {
dataToSend := data[:w.nextSplitBytes]
n, err := w.writer.Write(dataToSend)
written += n
w.advance(int64(n))
if err != nil {
return written, err
}
data = data[written:]
data = data[n:]
}
n, err := w.writer.Write(data)
written += n
w.prefixBytes -= int64(n)
w.advance(int64(n))
return written, err
}
88 changes: 80 additions & 8 deletions transport/split/writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,52 @@ func (w *collectWrites) Write(data []byte) (int, error) {

func TestWrite_Split(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 3)
splitWriter := NewWriter(&innerWriter, NewFixedSplitIterator(3))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
require.Equal(t, [][]byte{[]byte("Req"), []byte("uest")}, innerWriter.writes)
}

func TestWrite_SplitZero(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 0}, RepeatedSplit{0, 1}, RepeatedSplit{10, 0}, RepeatedSplit{0, 2}))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
require.Equal(t, [][]byte{[]byte("Request")}, innerWriter.writes)
}

func TestWrite_SplitZeroLong(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 0}, RepeatedSplit{1_000_000_000_000_000_000, 0}))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
require.Equal(t, [][]byte{[]byte("Request")}, innerWriter.writes)
}

func TestWrite_SplitZeroPrefix(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 0}, RepeatedSplit{3, 2}))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
require.Equal(t, [][]byte{[]byte("Re"), []byte("qu"), []byte("es"), []byte("t")}, innerWriter.writes)
}

func TestWrite_SplitMulti(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 1}, RepeatedSplit{3, 2}, RepeatedSplit{2, 3}))
n, err := splitWriter.Write([]byte("RequestRequestRequest"))
require.NoError(t, err)
require.Equal(t, 21, n)
require.Equal(t, [][]byte{[]byte("R"), []byte("eq"), []byte("ue"), []byte("st"), []byte("Req"), []byte("ues"), []byte("tRequest")}, innerWriter.writes)
}

func TestWrite_ShortWrite(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 10)
splitWriter := NewWriter(&innerWriter, NewFixedSplitIterator(10))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
Expand All @@ -56,7 +92,7 @@ func TestWrite_ShortWrite(t *testing.T) {

func TestWrite_Zero(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 0)
splitWriter := NewWriter(&innerWriter, NewFixedSplitIterator(0))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
Expand All @@ -65,7 +101,7 @@ func TestWrite_Zero(t *testing.T) {

func TestWrite_NeedsTwoWrites(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 5)
splitWriter := NewWriter(&innerWriter, NewFixedSplitIterator(5))
n, err := splitWriter.Write([]byte("Re"))
require.NoError(t, err)
require.Equal(t, 2, n)
Expand All @@ -77,13 +113,37 @@ func TestWrite_NeedsTwoWrites(t *testing.T) {

func TestWrite_Compound(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(NewWriter(&innerWriter, 4), 1)
splitWriter := NewWriter(NewWriter(&innerWriter, NewFixedSplitIterator(4)), NewFixedSplitIterator(1))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
require.Equal(t, [][]byte{[]byte("R"), []byte("equ"), []byte("est")}, innerWriter.writes)
}

func TestWrite_RepeatNumber3_SkipBytes5(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 1}, RepeatedSplit{3, 5}))
n, err := splitWriter.Write([]byte("RequestRequestRequest."))
require.NoError(t, err)
require.Equal(t, 7*3+1, n)
require.Equal(t, [][]byte{
[]byte("R"), // prefix
[]byte("eques"), // split 1
[]byte("tRequ"), // split 2
[]byte("estRe"), // split 3
[]byte("quest."), // tail
}, innerWriter.writes)
}

func TestWrite_RepeatNumber3_SkipBytes0(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 1}, RepeatedSplit{0, 3}))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
require.Equal(t, [][]byte{[]byte("R"), []byte("equest")}, innerWriter.writes)
}

// collectReader is a [io.Reader] that appends each Read from the Reader to the reads slice.
type collectReader struct {
io.Reader
Expand All @@ -101,7 +161,7 @@ func (r *collectReader) Read(buf []byte) (int, error) {
}

func TestReadFrom(t *testing.T) {
splitWriter := NewWriter(&bytes.Buffer{}, 3)
splitWriter := NewWriter(&bytes.Buffer{}, NewFixedSplitIterator(3))
rf, ok := splitWriter.(io.ReaderFrom)
require.True(t, ok)

Expand All @@ -118,8 +178,20 @@ func TestReadFrom(t *testing.T) {
require.Equal(t, [][]byte{[]byte("Request2")}, cr.reads)
}

func TestReadFrom_Multi(t *testing.T) {
splitWriter := NewWriter(&bytes.Buffer{}, NewRepeatedSplitIterator(RepeatedSplit{1, 1}, RepeatedSplit{3, 2}, RepeatedSplit{2, 3}))
rf, ok := splitWriter.(io.ReaderFrom)
require.True(t, ok)

cr := &collectReader{Reader: bytes.NewReader([]byte("RequestRequestRequest"))}
n, err := rf.ReadFrom(cr)
require.NoError(t, err)
require.Equal(t, int64(21), n)
require.Equal(t, [][]byte{[]byte("R"), []byte("eq"), []byte("ue"), []byte("st"), []byte("Req"), []byte("ues"), []byte("tRequest")}, cr.reads)
}

func TestReadFrom_ShortRead(t *testing.T) {
splitWriter := NewWriter(&bytes.Buffer{}, 10)
splitWriter := NewWriter(&bytes.Buffer{}, NewFixedSplitIterator(10))
rf, ok := splitWriter.(io.ReaderFrom)
require.True(t, ok)
cr := &collectReader{Reader: bytes.NewReader([]byte("Request1"))}
Expand All @@ -138,7 +210,7 @@ func TestReadFrom_ShortRead(t *testing.T) {
func BenchmarkReadFrom(b *testing.B) {
for n := 0; n < b.N; n++ {
reader := bytes.NewReader(make([]byte, n))
writer := NewWriter(io.Discard, 10)
writer := NewWriter(io.Discard, NewFixedSplitIterator(10))
rf, ok := writer.(io.ReaderFrom)
require.True(b, ok)
_, err := rf.ReadFrom(reader)
Expand Down
8 changes: 4 additions & 4 deletions transport/tls/stream_dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package tls
import (
"context"
"crypto/x509"
"runtime"
"testing"

"github.com/Jigsaw-Code/outline-sdk/transport"
Expand Down Expand Up @@ -55,9 +54,10 @@ func TestExpired(t *testing.T) {
}

func TestRevoked(t *testing.T) {
if runtime.GOOS == "linux" || runtime.GOOS == "windows" {
t.Skip("Certificate revocation list is not up-to-date in Linux and Windows")
}
t.Skip("Certificate revocation list is not working")

// TODO(fortuna): implement proper revocation test.
// See https://www.cossacklabs.com/blog/tls-validation-implementing-ocsp-and-crl-in-go/

sd, err := NewStreamDialer(&transport.TCPDialer{})
require.NoError(t, err)
Expand Down
Loading

0 comments on commit cd78174

Please sign in to comment.