From b6e887a86731814a9c5bbec652caa17bc51bee71 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Mon, 4 Nov 2024 19:00:06 -0500 Subject: [PATCH] Advanced split --- transport/split/stream_dialer.go | 9 +++-- transport/split/writer.go | 69 ++++++++++++++++++++++++++------ transport/split/writer_test.go | 24 +++++++++++ 3 files changed, 86 insertions(+), 16 deletions(-) diff --git a/transport/split/stream_dialer.go b/transport/split/stream_dialer.go index c6b59d1a..2c123060 100644 --- a/transport/split/stream_dialer.go +++ b/transport/split/stream_dialer.go @@ -21,16 +21,19 @@ import ( "github.com/Jigsaw-Code/outline-sdk/transport" ) +// splitDialer is a [transport.StreamDialer] that implements the split strategy. +// Use [NewStreamDialer] to create new instances. type splitDialer struct { dialer transport.StreamDialer splitPoint int64 + options []Option } var _ transport.StreamDialer = (*splitDialer)(nil) // NewStreamDialer creates a [transport.StreamDialer] that splits the outgoing stream after writing "prefixBytes" bytes -// using [SplitWriter]. -func NewStreamDialer(dialer transport.StreamDialer, prefixBytes int64) (transport.StreamDialer, error) { +// using [splitWriter]. If "repeatsNumber" is not 0, will split that many times, skipping "skipBytes" in between packets. +func NewStreamDialer(dialer transport.StreamDialer, prefixBytes int64, options ...Option) (transport.StreamDialer, error) { if dialer == nil { return nil, errors.New("argument dialer must not be nil") } @@ -43,5 +46,5 @@ func (d *splitDialer) DialStream(ctx context.Context, remoteAddr string) (transp if err != nil { return nil, err } - return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.splitPoint)), nil + return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.splitPoint, d.options...)), nil } diff --git a/transport/split/writer.go b/transport/split/writer.go index b1b3e140..bbe4ee81 100644 --- a/transport/split/writer.go +++ b/transport/split/writer.go @@ -18,9 +18,15 @@ import ( "io" ) +type repeatedSplit struct { + count int + bytes int64 +} + type splitWriter struct { - writer io.Writer - prefixBytes int64 + writer io.Writer + nextSplitBytes int64 + remainingSplits []repeatedSplit } var _ io.Writer = (*splitWriter)(nil) @@ -36,32 +42,69 @@ var _ io.ReaderFrom = (*splitWriterReaderFrom)(nil) // A write will end right after byte index prefixBytes - 1, before a write starting at byte index prefixBytes. // For example, if you have a write of [0123456789] and prefixBytes = 3, you will get writes [012] and [3456789]. // If the input writer is a [io.ReaderFrom], the output writer will be too. -func NewWriter(writer io.Writer, prefixBytes int64) io.Writer { - sw := &splitWriter{writer, prefixBytes} - if rf, ok := writer.(io.ReaderFrom); ok { - return &splitWriterReaderFrom{sw, rf} +// It's possible to enable multiple splits with the [EnableRepeatSplit] option. +// In that cases, splits will happen at positions prefixBytes + i * skipBytes, for 0 <= i < count. +// This means that after the initial split, count splits will happen every skipBytes bytes. +// Example: +// prefixBytes = 1 +// count = 2 +// skipBytes = 6 +// Array of [0 1 3 2 4 5 6 7 8 9 10 11 12 13 14 15 16 ...] will become +// [0] [1 2 3 4 5 6] [7 8 9 10 11 12] [13 14 15 16 ...] +func NewWriter(writer io.Writer, prefixBytes int64, options ...Option) io.Writer { + sw := &splitWriter{writer: writer, nextSplitBytes: prefixBytes, remainingSplits: []repeatedSplit{}} + for _, option := range options { + option(sw) + } + if len(sw.remainingSplits) == 0 { + // TODO(fortuna): Support ReaderFrom for repeat split. + if rf, ok := writer.(io.ReaderFrom); ok { + return &splitWriterReaderFrom{sw, rf} + } } return sw } +type Option func(w *splitWriter) + +// AddSplitSequence will add count splits, each of skipBytes length. +func AddSplitSequence(count int, skipBytes int64) Option { + return func(w *splitWriter) { + if count > 0 { + w.remainingSplits = append(w.remainingSplits, repeatedSplit{count: count, bytes: skipBytes}) + } + } +} + func (w *splitWriterReaderFrom) ReadFrom(source io.Reader) (int64, error) { - reader := io.MultiReader(io.LimitReader(source, w.prefixBytes), source) + reader := io.MultiReader(io.LimitReader(source, w.nextSplitBytes), source) written, err := w.rf.ReadFrom(reader) - w.prefixBytes -= written + w.nextSplitBytes -= written return written, err } func (w *splitWriter) Write(data []byte) (written int, err error) { - if 0 < w.prefixBytes && w.prefixBytes < int64(len(data)) { - written, err = w.writer.Write(data[:w.prefixBytes]) - w.prefixBytes -= int64(written) + for 0 < w.nextSplitBytes && w.nextSplitBytes < int64(len(data)) { + dataToSend := data[:w.nextSplitBytes] + n, err := w.writer.Write(dataToSend) + written += n + w.nextSplitBytes -= int64(n) if err != nil { return written, err } - data = data[written:] + data = data[n:] + + // Split done. Update nextSplitBytes. + if len(w.remainingSplits) > 0 { + w.nextSplitBytes = w.remainingSplits[0].bytes + w.remainingSplits[0].count -= 1 + if w.remainingSplits[0].count == 0 { + w.remainingSplits = w.remainingSplits[1:] + } + } } n, err := w.writer.Write(data) written += n - w.prefixBytes -= int64(n) + w.nextSplitBytes -= int64(n) return written, err } diff --git a/transport/split/writer_test.go b/transport/split/writer_test.go index 025402fa..276fe7bf 100644 --- a/transport/split/writer_test.go +++ b/transport/split/writer_test.go @@ -84,6 +84,30 @@ func TestWrite_Compound(t *testing.T) { require.Equal(t, [][]byte{[]byte("R"), []byte("equ"), []byte("est")}, innerWriter.writes) } +func TestWrite_RepeatNumber3_SkipBytes5(t *testing.T) { + var innerWriter collectWrites + splitWriter := NewWriter(&innerWriter, 1, AddSplitSequence(3, 5)) + n, err := splitWriter.Write([]byte("RequestRequestRequest.")) + require.NoError(t, err) + require.Equal(t, 7*3+1, n) + require.Equal(t, [][]byte{ + []byte("R"), // prefix + []byte("eques"), // split 1 + []byte("tRequ"), // split 2 + []byte("estRe"), // split 3 + []byte("quest."), // tail + }, innerWriter.writes) +} + +func TestWrite_RepeatNumber3_SkipBytes0(t *testing.T) { + var innerWriter collectWrites + splitWriter := NewWriter(&innerWriter, 1, AddSplitSequence(0, 3)) + n, err := splitWriter.Write([]byte("Request")) + require.NoError(t, err) + require.Equal(t, 7, n) + require.Equal(t, [][]byte{[]byte("R"), []byte("equest")}, innerWriter.writes) +} + // collectReader is a [io.Reader] that appends each Read from the Reader to the reads slice. type collectReader struct { io.Reader