Skip to content

Commit

Permalink
Introduce SplitIterator
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna committed Nov 5, 2024
1 parent d2a4cf9 commit 5adbe71
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 72 deletions.
11 changes: 5 additions & 6 deletions transport/split/stream_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,19 @@ import (
// splitDialer is a [transport.StreamDialer] that implements the split strategy.
// Use [NewStreamDialer] to create new instances.
type splitDialer struct {
dialer transport.StreamDialer
splitPoint int64
options []Option
dialer transport.StreamDialer
nextSplit func() int64
}

var _ transport.StreamDialer = (*splitDialer)(nil)

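// NewStreamDialer creates a [transport.StreamDialer] that splits the outgoing stream using the split writer.
// Split points are obtained by calling nextSplit, which returns the number of bytes until the next split, or zero
// when there are no more splits. The same nextSplit function is handed to the writer of every dialed connection.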
func NewStreamDialer(dialer transport.StreamDialer, prefixBytes int64, options ...Option) (transport.StreamDialer, error) {
func NewStreamDialer(dialer transport.StreamDialer, prefixBytes int64, nextSplit func() int64) (transport.StreamDialer, error) {
if dialer == nil {
return nil, errors.New("argument dialer must not be nil")
}
return &splitDialer{dialer: dialer, splitPoint: prefixBytes}, nil
return &splitDialer{dialer: dialer, nextSplit: nextSplit}, nil
}

// DialStream implements [transport.StreamDialer].DialStream.
Expand All @@ -46,5 +45,5 @@ func (d *splitDialer) DialStream(ctx context.Context, remoteAddr string) (transp
if err != nil {
return nil, err
}
return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.splitPoint, d.options...)), nil
return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.nextSplit)), nil
}
102 changes: 51 additions & 51 deletions transport/split/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,11 @@ import (
"io"
)

// repeatedSplit represents a split sequence of count blocks with bytes length.
type repeatedSplit struct {
count int
bytes int64
}

type splitWriter struct {
writer io.Writer
// Bytes until the next split. This must always be > 0, unless splits are done.
nextSplitBytes int64
// Remaining split sequences. All entries here must have count > 0 && bytes > 0.
remainingSplits []repeatedSplit
nextSplitBytes int64
nextSegmentLength func() int64
}

var _ io.Writer = (*splitWriter)(nil)
Expand All @@ -41,49 +34,62 @@ type splitWriterReaderFrom struct {

var _ io.ReaderFrom = (*splitWriterReaderFrom)(nil)

// NewWriter creates a [io.Writer] that ensures the byte sequence is split at prefixBytes.
// A write will end right after byte index prefixBytes - 1, before a write starting at byte index prefixBytes.
// For example, if you have a write of [0123456789] and prefixBytes = 3, you will get writes [012] and [3456789].
// If the input writer is a [io.ReaderFrom], the output writer will be too.
// It's possible to enable multiple splits with the [AddSplitSequence] option, which adds count splits every skipBytes bytes.
// Example:
// prefixBytes = 1, AddSplitSequence(count=2, bytes=6)
// Array of [0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 ...] will become
// [0] [1 2 3 4 5 6] [7 8 9 10 11 12] [13 14 15 16 ...]
func NewWriter(writer io.Writer, prefixBytes int64, options ...Option) io.Writer {
sw := &splitWriter{writer: writer, remainingSplits: []repeatedSplit{}}
sw.addSplitSequence(1, prefixBytes)
for _, option := range options {
option(sw)
}
if rf, ok := writer.(io.ReaderFrom); ok {
return &splitWriterReaderFrom{sw, rf}
// SplitIterator is a function that returns the number of bytes until the next split point, or zero if there are no more splits to do.
type SplitIterator func() int64

// NewFixedSplitIterator is a helper function that returns a [SplitIterator] that returns the input number once, followed by zero.
// This is helpful for when you want to split the stream once in a fixed position.
func NewFixedSplitIterator(n int64) SplitIterator {
return func() int64 {
next := n
n = 0
return next
}
return sw
}

func (w *splitWriter) addSplitSequence(count int, skipBytes int64) {
if count == 0 || skipBytes == 0 {
return
}
if w.nextSplitBytes == 0 {
w.nextSplitBytes = skipBytes
count--
// RepeatedSplit represents a split sequence of Count segments, each Bytes bytes long.
type RepeatedSplit struct {
// Count is the number of consecutive segments in this sequence.
Count int
// Bytes is the length, in bytes, of each segment in this sequence.
Bytes int64
}

// NewRepeatedSplitIterator is a helper function that returns a [SplitIterator] that returns split points according to splits.
// Each entry in splits is a (Count, Bytes) pair describing Count consecutive segments of Bytes bytes each.
// Entries whose Count or Bytes is not positive are ignored.
// This is helpful for when you want to split the stream repeatedly at different positions and lengths.
func NewRepeatedSplitIterator(splits ...RepeatedSplit) SplitIterator {
// Make sure we don't edit the original slice.
cleanSplits := make([]RepeatedSplit, 0, len(splits))
// Remove no-op splits.
for _, split := range splits {
if split.Count > 0 && split.Bytes > 0 {
cleanSplits = append(cleanSplits, split)
}
}
if count > 0 {
w.remainingSplits = append(w.remainingSplits, repeatedSplit{count: count, bytes: skipBytes})
return func() int64 {
if len(cleanSplits) == 0 {
return 0
}
next := cleanSplits[0].Bytes
cleanSplits[0].Count -= 1
if cleanSplits[0].Count == 0 {
cleanSplits = cleanSplits[1:]
}
return next
}
}

type Option func(w *splitWriter)

// AddSplitSequence will add count splits, each of skipBytes length.
func AddSplitSequence(count int, skipBytes int64) Option {
return func(w *splitWriter) {
w.addSplitSequence(count, skipBytes)
// NewWriter creates a split [io.Writer] that calls the nextSegmentLength [SplitIterator] to determine the number of
// bytes until the next split point, and keeps doing so until the iterator returns zero.
// A nil nextSegmentLength is treated as an iterator with no splits.
// If writer also implements [io.ReaderFrom], the returned writer does too.
func NewWriter(writer io.Writer, nextSegmentLength SplitIterator) io.Writer {
	if nextSegmentLength == nil {
		// Calling a nil function value panics. Substitute an iterator that is
		// already exhausted so both this constructor and any later call made
		// through the writer are safe.
		nextSegmentLength = func() int64 { return 0 }
	}
	sw := &splitWriter{writer: writer, nextSegmentLength: nextSegmentLength}
	// Prime the first segment length so the writer knows where the first split falls.
	sw.nextSplitBytes = nextSegmentLength()
	if rf, ok := writer.(io.ReaderFrom); ok {
		// Preserve the ReaderFrom capability of the wrapped writer.
		return &splitWriterReaderFrom{sw, rf}
	}
	return sw
}

// ReadFrom implements io.ReaderFrom.
func (w *splitWriterReaderFrom) ReadFrom(source io.Reader) (int64, error) {
var written int64
for w.nextSplitBytes > 0 {
Expand Down Expand Up @@ -114,17 +120,11 @@ func (w *splitWriter) advance(n int64) {
if w.nextSplitBytes > 0 {
return
}
// Split done, set next split.
if len(w.remainingSplits) == 0 {
return
}
w.nextSplitBytes = w.remainingSplits[0].bytes
w.remainingSplits[0].count -= 1
if w.remainingSplits[0].count == 0 {
w.remainingSplits = w.remainingSplits[1:]
}
// Split done, set up the next split.
w.nextSplitBytes = w.nextSegmentLength()
}

// Write implements io.Writer.
func (w *splitWriter) Write(data []byte) (written int, err error) {
for 0 < w.nextSplitBytes && w.nextSplitBytes < int64(len(data)) {
dataToSend := data[:w.nextSplitBytes]
Expand Down
30 changes: 15 additions & 15 deletions transport/split/writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (w *collectWrites) Write(data []byte) (int, error) {

func TestWrite_Split(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 3)
splitWriter := NewWriter(&innerWriter, NewFixedSplitIterator(3))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
Expand All @@ -47,7 +47,7 @@ func TestWrite_Split(t *testing.T) {

func TestWrite_SplitZero(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 0, AddSplitSequence(0, 1), AddSplitSequence(10, 0), AddSplitSequence(0, 2))
splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 0}, RepeatedSplit{0, 1}, RepeatedSplit{10, 0}, RepeatedSplit{0, 2}))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
Expand All @@ -56,7 +56,7 @@ func TestWrite_SplitZero(t *testing.T) {

func TestWrite_SplitZeroLong(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 0, AddSplitSequence(1_000_000_000_000_000_000, 0))
splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 0}, RepeatedSplit{1_000_000_000_000_000_000, 0}))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
Expand All @@ -65,7 +65,7 @@ func TestWrite_SplitZeroLong(t *testing.T) {

func TestWrite_SplitZeroPrefix(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 0, AddSplitSequence(3, 2))
splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 0}, RepeatedSplit{3, 2}))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
Expand All @@ -74,7 +74,7 @@ func TestWrite_SplitZeroPrefix(t *testing.T) {

func TestWrite_SplitMulti(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 1, AddSplitSequence(3, 2), AddSplitSequence(2, 3))
splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 1}, RepeatedSplit{3, 2}, RepeatedSplit{2, 3}))
n, err := splitWriter.Write([]byte("RequestRequestRequest"))
require.NoError(t, err)
require.Equal(t, 21, n)
Expand All @@ -83,7 +83,7 @@ func TestWrite_SplitMulti(t *testing.T) {

func TestWrite_ShortWrite(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 10)
splitWriter := NewWriter(&innerWriter, NewFixedSplitIterator(10))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
Expand All @@ -92,7 +92,7 @@ func TestWrite_ShortWrite(t *testing.T) {

func TestWrite_Zero(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 0)
splitWriter := NewWriter(&innerWriter, NewFixedSplitIterator(0))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
Expand All @@ -101,7 +101,7 @@ func TestWrite_Zero(t *testing.T) {

func TestWrite_NeedsTwoWrites(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 5)
splitWriter := NewWriter(&innerWriter, NewFixedSplitIterator(5))
n, err := splitWriter.Write([]byte("Re"))
require.NoError(t, err)
require.Equal(t, 2, n)
Expand All @@ -113,7 +113,7 @@ func TestWrite_NeedsTwoWrites(t *testing.T) {

func TestWrite_Compound(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(NewWriter(&innerWriter, 4), 1)
splitWriter := NewWriter(NewWriter(&innerWriter, NewFixedSplitIterator(4)), NewFixedSplitIterator(1))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
Expand All @@ -122,7 +122,7 @@ func TestWrite_Compound(t *testing.T) {

func TestWrite_RepeatNumber3_SkipBytes5(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 1, AddSplitSequence(3, 5))
splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 1}, RepeatedSplit{3, 5}))
n, err := splitWriter.Write([]byte("RequestRequestRequest."))
require.NoError(t, err)
require.Equal(t, 7*3+1, n)
Expand All @@ -137,7 +137,7 @@ func TestWrite_RepeatNumber3_SkipBytes5(t *testing.T) {

func TestWrite_RepeatNumber3_SkipBytes0(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 1, AddSplitSequence(0, 3))
splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 1}, RepeatedSplit{0, 3}))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
Expand All @@ -161,7 +161,7 @@ func (r *collectReader) Read(buf []byte) (int, error) {
}

func TestReadFrom(t *testing.T) {
splitWriter := NewWriter(&bytes.Buffer{}, 3)
splitWriter := NewWriter(&bytes.Buffer{}, NewFixedSplitIterator(3))
rf, ok := splitWriter.(io.ReaderFrom)
require.True(t, ok)

Expand All @@ -179,7 +179,7 @@ func TestReadFrom(t *testing.T) {
}

func TestReadFrom_Multi(t *testing.T) {
splitWriter := NewWriter(&bytes.Buffer{}, 1, AddSplitSequence(3, 2), AddSplitSequence(2, 3))
splitWriter := NewWriter(&bytes.Buffer{}, NewRepeatedSplitIterator(RepeatedSplit{1, 1}, RepeatedSplit{3, 2}, RepeatedSplit{2, 3}))
rf, ok := splitWriter.(io.ReaderFrom)
require.True(t, ok)

Expand All @@ -191,7 +191,7 @@ func TestReadFrom_Multi(t *testing.T) {
}

func TestReadFrom_ShortRead(t *testing.T) {
splitWriter := NewWriter(&bytes.Buffer{}, 10)
splitWriter := NewWriter(&bytes.Buffer{}, NewFixedSplitIterator(10))
rf, ok := splitWriter.(io.ReaderFrom)
require.True(t, ok)
cr := &collectReader{Reader: bytes.NewReader([]byte("Request1"))}
Expand All @@ -210,7 +210,7 @@ func TestReadFrom_ShortRead(t *testing.T) {
func BenchmarkReadFrom(b *testing.B) {
for n := 0; n < b.N; n++ {
reader := bytes.NewReader(make([]byte, n))
writer := NewWriter(io.Discard, 10)
writer := NewWriter(io.Discard, NewFixedSplitIterator(10))
rf, ok := writer.(io.ReaderFrom)
require.True(b, ok)
_, err := rf.ReadFrom(reader)
Expand Down

0 comments on commit 5adbe71

Please sign in to comment.