Skip to content

Commit

Permalink
Advanced split
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna committed Nov 5, 2024
1 parent 82b33bc commit b6e887a
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 16 deletions.
9 changes: 6 additions & 3 deletions transport/split/stream_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,19 @@ import (
"github.com/Jigsaw-Code/outline-sdk/transport"
)

// splitDialer is a [transport.StreamDialer] that implements the split strategy.
// Use [NewStreamDialer] to create new instances.
type splitDialer struct {
dialer transport.StreamDialer
splitPoint int64
options []Option
}

var _ transport.StreamDialer = (*splitDialer)(nil)

// NewStreamDialer creates a [transport.StreamDialer] that splits the outgoing stream after writing "prefixBytes" bytes
// using [SplitWriter].
func NewStreamDialer(dialer transport.StreamDialer, prefixBytes int64) (transport.StreamDialer, error) {
// using [splitWriter]. If "repeatsNumber" is not 0, will split that many times, skipping "skipBytes" in between packets.
func NewStreamDialer(dialer transport.StreamDialer, prefixBytes int64, options ...Option) (transport.StreamDialer, error) {
if dialer == nil {
return nil, errors.New("argument dialer must not be nil")
}
Expand All @@ -43,5 +46,5 @@ func (d *splitDialer) DialStream(ctx context.Context, remoteAddr string) (transp
if err != nil {
return nil, err
}
return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.splitPoint)), nil
return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.splitPoint, d.options...)), nil
}
69 changes: 56 additions & 13 deletions transport/split/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,15 @@ import (
"io"
)

type repeatedSplit struct {
count int
bytes int64
}

type splitWriter struct {
writer io.Writer
prefixBytes int64
writer io.Writer
nextSplitBytes int64
remainingSplits []repeatedSplit
}

var _ io.Writer = (*splitWriter)(nil)
Expand All @@ -36,32 +42,69 @@ var _ io.ReaderFrom = (*splitWriterReaderFrom)(nil)
// A write will end right after byte index prefixBytes - 1, before a write starting at byte index prefixBytes.
// For example, if you have a write of [0123456789] and prefixBytes = 3, you will get writes [012] and [3456789].
// If the input writer is a [io.ReaderFrom], the output writer will be too.
func NewWriter(writer io.Writer, prefixBytes int64) io.Writer {
sw := &splitWriter{writer, prefixBytes}
if rf, ok := writer.(io.ReaderFrom); ok {
return &splitWriterReaderFrom{sw, rf}
// It's possible to enable multiple splits with the [EnableRepeatSplit] option.
// In that cases, splits will happen at positions prefixBytes + i * skipBytes, for 0 <= i < count.
// This means that after the initial split, count splits will happen every skipBytes bytes.
// Example:
// prefixBytes = 1
// count = 2
// skipBytes = 6
// Array of [0 1 3 2 4 5 6 7 8 9 10 11 12 13 14 15 16 ...] will become
// [0] [1 2 3 4 5 6] [7 8 9 10 11 12] [13 14 15 16 ...]
func NewWriter(writer io.Writer, prefixBytes int64, options ...Option) io.Writer {
sw := &splitWriter{writer: writer, nextSplitBytes: prefixBytes, remainingSplits: []repeatedSplit{}}
for _, option := range options {
option(sw)
}
if len(sw.remainingSplits) == 0 {
// TODO(fortuna): Support ReaderFrom for repeat split.
if rf, ok := writer.(io.ReaderFrom); ok {
return &splitWriterReaderFrom{sw, rf}
}
}
return sw
}

type Option func(w *splitWriter)

// AddSplitSequence will add count splits, each of skipBytes length.
func AddSplitSequence(count int, skipBytes int64) Option {
return func(w *splitWriter) {
if count > 0 {
w.remainingSplits = append(w.remainingSplits, repeatedSplit{count: count, bytes: skipBytes})
}
}
}

func (w *splitWriterReaderFrom) ReadFrom(source io.Reader) (int64, error) {
reader := io.MultiReader(io.LimitReader(source, w.prefixBytes), source)
reader := io.MultiReader(io.LimitReader(source, w.nextSplitBytes), source)
written, err := w.rf.ReadFrom(reader)
w.prefixBytes -= written
w.nextSplitBytes -= written
return written, err
}

func (w *splitWriter) Write(data []byte) (written int, err error) {
if 0 < w.prefixBytes && w.prefixBytes < int64(len(data)) {
written, err = w.writer.Write(data[:w.prefixBytes])
w.prefixBytes -= int64(written)
for 0 < w.nextSplitBytes && w.nextSplitBytes < int64(len(data)) {
dataToSend := data[:w.nextSplitBytes]
n, err := w.writer.Write(dataToSend)
written += n
w.nextSplitBytes -= int64(n)
if err != nil {
return written, err
}
data = data[written:]
data = data[n:]

// Split done. Update nextSplitBytes.
if len(w.remainingSplits) > 0 {
w.nextSplitBytes = w.remainingSplits[0].bytes
w.remainingSplits[0].count -= 1
if w.remainingSplits[0].count == 0 {
w.remainingSplits = w.remainingSplits[1:]
}
}
}
n, err := w.writer.Write(data)
written += n
w.prefixBytes -= int64(n)
w.nextSplitBytes -= int64(n)
return written, err
}
24 changes: 24 additions & 0 deletions transport/split/writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,30 @@ func TestWrite_Compound(t *testing.T) {
require.Equal(t, [][]byte{[]byte("R"), []byte("equ"), []byte("est")}, innerWriter.writes)
}

func TestWrite_RepeatNumber3_SkipBytes5(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 1, AddSplitSequence(3, 5))
n, err := splitWriter.Write([]byte("RequestRequestRequest."))
require.NoError(t, err)
require.Equal(t, 7*3+1, n)
require.Equal(t, [][]byte{
[]byte("R"), // prefix
[]byte("eques"), // split 1
[]byte("tRequ"), // split 2
[]byte("estRe"), // split 3
[]byte("quest."), // tail
}, innerWriter.writes)
}

func TestWrite_RepeatNumber3_SkipBytes0(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 1, AddSplitSequence(0, 3))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
require.Equal(t, [][]byte{[]byte("R"), []byte("equest")}, innerWriter.writes)
}

// collectReader is a [io.Reader] that appends each Read from the Reader to the reads slice.
type collectReader struct {
io.Reader
Expand Down

0 comments on commit b6e887a

Please sign in to comment.