diff --git a/transport/split/stream_dialer.go b/transport/split/stream_dialer.go index a8f9ada6..59326a15 100644 --- a/transport/split/stream_dialer.go +++ b/transport/split/stream_dialer.go @@ -24,20 +24,19 @@ import ( // splitDialer is a [transport.StreamDialer] that implements the split strategy. // Use [NewStreamDialer] to create new instances. type splitDialer struct { - dialer transport.StreamDialer - splitPoint int64 - options []Option + dialer transport.StreamDialer + nextSplit func() int64 } var _ transport.StreamDialer = (*splitDialer)(nil) // NewStreamDialer creates a [transport.StreamDialer] that splits the outgoing stream after writing "prefixBytes" bytes // using the split writer. You can specify multiple sequences with the [AddSplitSequence] option. -func NewStreamDialer(dialer transport.StreamDialer, prefixBytes int64, options ...Option) (transport.StreamDialer, error) { +func NewStreamDialer(dialer transport.StreamDialer, prefixBytes int64, nextSplit func() int64) (transport.StreamDialer, error) { if dialer == nil { return nil, errors.New("argument dialer must not be nil") } - return &splitDialer{dialer: dialer, splitPoint: prefixBytes}, nil + return &splitDialer{dialer: dialer, nextSplit: nextSplit}, nil } // DialStream implements [transport.StreamDialer].DialStream. @@ -46,5 +45,5 @@ func (d *splitDialer) DialStream(ctx context.Context, remoteAddr string) (transp if err != nil { return nil, err } - return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.splitPoint, d.options...)), nil + return transport.WrapConn(innerConn, innerConn, NewWriter(innerConn, d.nextSplit)), nil } diff --git a/transport/split/writer.go b/transport/split/writer.go index a4db713c..21d2bf66 100644 --- a/transport/split/writer.go +++ b/transport/split/writer.go @@ -18,18 +18,11 @@ import ( "io" ) -// repeatedSplit represents a split sequence of count blocks with bytes length. -type repeatedSplit struct { - count int - bytes int64 -} - type splitWriter struct { writer io.Writer // Bytes until the next split. This must always be > 0, unless splits are done. - nextSplitBytes int64 - // Remaining split sequences. All entries here must have count > 0 && bytes > 0. - remainingSplits []repeatedSplit + nextSplitBytes int64 + nextSegmentLength func() int64 } var _ io.Writer = (*splitWriter)(nil) @@ -41,49 +34,62 @@ type splitWriterReaderFrom struct { var _ io.ReaderFrom = (*splitWriterReaderFrom)(nil) -// NewWriter creates a [io.Writer] that ensures the byte sequence is split at prefixBytes. -// A write will end right after byte index prefixBytes - 1, before a write starting at byte index prefixBytes. -// For example, if you have a write of [0123456789] and prefixBytes = 3, you will get writes [012] and [3456789]. -// If the input writer is a [io.ReaderFrom], the output writer will be too. -// It's possible to enable multiple splits with the [AddSplitSequence] option, which adds count splits every skipBytes bytes. -// Example: -// prefixBytes = 1, AddSplitSequence(count=2, bytes=6) -// Array of [0 1 3 2 4 5 6 7 8 9 10 11 12 13 14 15 16 ...] will become -// [0] [1 2 3 4 5 6] [7 8 9 10 11 12] [13 14 15 16 ...] -func NewWriter(writer io.Writer, prefixBytes int64, options ...Option) io.Writer { - sw := &splitWriter{writer: writer, remainingSplits: []repeatedSplit{}} - sw.addSplitSequence(1, prefixBytes) - for _, option := range options { - option(sw) - } - if rf, ok := writer.(io.ReaderFrom); ok { - return &splitWriterReaderFrom{sw, rf} +// Split Iterator is a function that returns how many bytes until the next split point, or zero if there are no more splits to do. +type SplitIterator func() int64 + +// NewFixedSplitIterator is a helper function that returns a [SplitIterator] that returns the input number once, followed by zero. +// This is helpful for when you want to split the stream once in a fixed position. +func NewFixedSplitIterator(n int64) SplitIterator { + return func() int64 { + next := n + n = 0 + return next } - return sw } -func (w *splitWriter) addSplitSequence(count int, skipBytes int64) { - if count == 0 || skipBytes == 0 { - return - } - if w.nextSplitBytes == 0 { - w.nextSplitBytes = skipBytes - count-- +// RepeatedSplit represents a split sequence of count segments with bytes length. +type RepeatedSplit struct { + Count int + Bytes int64 +} + +// NewRepeatedSplitIterator is a helper function that returns a [SplitIterator] that returns split points according to splits. +// The splits input represents pairs of (count, bytes), meaning a sequence of count splits with bytes length. +// This is helpful for when you want to split the stream repeatedly at different positions and lengths. +func NewRepeatedSplitIterator(splits ...RepeatedSplit) SplitIterator { + // Make sure we don't edit the original slice. + cleanSplits := make([]RepeatedSplit, 0, len(splits)) + // Remove no-op splits. + for _, split := range splits { + if split.Count > 0 && split.Bytes > 0 { + cleanSplits = append(cleanSplits, split) + } } - if count > 0 { - w.remainingSplits = append(w.remainingSplits, repeatedSplit{count: count, bytes: skipBytes}) + return func() int64 { + if len(cleanSplits) == 0 { + return 0 + } + next := cleanSplits[0].Bytes + cleanSplits[0].Count -= 1 + if cleanSplits[0].Count == 0 { + cleanSplits = cleanSplits[1:] + } + return next } } -type Option func(w *splitWriter) - -// AddSplitSequence will add count splits, each of skipBytes length. -func AddSplitSequence(count int, skipBytes int64) Option { - return func(w *splitWriter) { - w.addSplitSequence(count, skipBytes) +// NewWriter creates a split Writer that calls the nextSegmentLength [SplitIterator] to determine the number bytes until the next split +// point until it returns zero. +func NewWriter(writer io.Writer, nextSegmentLength SplitIterator) io.Writer { + sw := &splitWriter{writer: writer, nextSegmentLength: nextSegmentLength} + sw.nextSplitBytes = nextSegmentLength() + if rf, ok := writer.(io.ReaderFrom); ok { + return &splitWriterReaderFrom{sw, rf} } + return sw } +// ReadFrom implements io.ReaderFrom. func (w *splitWriterReaderFrom) ReadFrom(source io.Reader) (int64, error) { var written int64 for w.nextSplitBytes > 0 { @@ -114,17 +120,11 @@ func (w *splitWriter) advance(n int64) { if w.nextSplitBytes > 0 { return } - // Split done, set next split. - if len(w.remainingSplits) == 0 { - return - } - w.nextSplitBytes = w.remainingSplits[0].bytes - w.remainingSplits[0].count -= 1 - if w.remainingSplits[0].count == 0 { - w.remainingSplits = w.remainingSplits[1:] - } + // Split done, set up the next split. + w.nextSplitBytes = w.nextSegmentLength() } +// Write implements io.Writer. func (w *splitWriter) Write(data []byte) (written int, err error) { for 0 < w.nextSplitBytes && w.nextSplitBytes < int64(len(data)) { dataToSend := data[:w.nextSplitBytes] diff --git a/transport/split/writer_test.go b/transport/split/writer_test.go index d995417b..fc936891 100644 --- a/transport/split/writer_test.go +++ b/transport/split/writer_test.go @@ -38,7 +38,7 @@ func (w *collectWrites) Write(data []byte) (int, error) { func TestWrite_Split(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(&innerWriter, 3) + splitWriter := NewWriter(&innerWriter, NewFixedSplitIterator(3)) n, err := splitWriter.Write([]byte("Request")) require.NoError(t, err) require.Equal(t, 7, n) @@ -47,7 +47,7 @@ func TestWrite_Split(t *testing.T) { func TestWrite_SplitZero(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(&innerWriter, 0, AddSplitSequence(0, 1), AddSplitSequence(10, 0), AddSplitSequence(0, 2)) + splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 0}, RepeatedSplit{0, 1}, RepeatedSplit{10, 0}, RepeatedSplit{0, 2})) n, err := splitWriter.Write([]byte("Request")) require.NoError(t, err) require.Equal(t, 7, n) @@ -56,7 +56,7 @@ func TestWrite_SplitZero(t *testing.T) { func TestWrite_SplitZeroLong(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(&innerWriter, 0, AddSplitSequence(1_000_000_000_000_000_000, 0)) + splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 0}, RepeatedSplit{1_000_000_000_000_000_000, 0})) n, err := splitWriter.Write([]byte("Request")) require.NoError(t, err) require.Equal(t, 7, n) @@ -65,7 +65,7 @@ func TestWrite_SplitZeroLong(t *testing.T) { func TestWrite_SplitZeroPrefix(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(&innerWriter, 0, AddSplitSequence(3, 2)) + splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 0}, RepeatedSplit{3, 2})) n, err := splitWriter.Write([]byte("Request")) require.NoError(t, err) require.Equal(t, 7, n) @@ -74,7 +74,7 @@ func TestWrite_SplitZeroPrefix(t *testing.T) { func TestWrite_SplitMulti(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(&innerWriter, 1, AddSplitSequence(3, 2), AddSplitSequence(2, 3)) + splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 1}, RepeatedSplit{3, 2}, RepeatedSplit{2, 3})) n, err := splitWriter.Write([]byte("RequestRequestRequest")) require.NoError(t, err) require.Equal(t, 21, n) @@ -83,7 +83,7 @@ func TestWrite_SplitMulti(t *testing.T) { func TestWrite_ShortWrite(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(&innerWriter, 10) + splitWriter := NewWriter(&innerWriter, NewFixedSplitIterator(10)) n, err := splitWriter.Write([]byte("Request")) require.NoError(t, err) require.Equal(t, 7, n) @@ -92,7 +92,7 @@ func TestWrite_ShortWrite(t *testing.T) { func TestWrite_Zero(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(&innerWriter, 0) + splitWriter := NewWriter(&innerWriter, NewFixedSplitIterator(0)) n, err := splitWriter.Write([]byte("Request")) require.NoError(t, err) require.Equal(t, 7, n) @@ -101,7 +101,7 @@ func TestWrite_Zero(t *testing.T) { func TestWrite_NeedsTwoWrites(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(&innerWriter, 5) + splitWriter := NewWriter(&innerWriter, NewFixedSplitIterator(5)) n, err := splitWriter.Write([]byte("Re")) require.NoError(t, err) require.Equal(t, 2, n) @@ -113,7 +113,7 @@ func TestWrite_NeedsTwoWrites(t *testing.T) { func TestWrite_Compound(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(NewWriter(&innerWriter, 4), 1) + splitWriter := NewWriter(NewWriter(&innerWriter, NewFixedSplitIterator(4)), NewFixedSplitIterator(1)) n, err := splitWriter.Write([]byte("Request")) require.NoError(t, err) require.Equal(t, 7, n) @@ -122,7 +122,7 @@ func TestWrite_Compound(t *testing.T) { func TestWrite_RepeatNumber3_SkipBytes5(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(&innerWriter, 1, AddSplitSequence(3, 5)) + splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 1}, RepeatedSplit{3, 5})) n, err := splitWriter.Write([]byte("RequestRequestRequest.")) require.NoError(t, err) require.Equal(t, 7*3+1, n) @@ -137,7 +137,7 @@ func TestWrite_RepeatNumber3_SkipBytes5(t *testing.T) { func TestWrite_RepeatNumber3_SkipBytes0(t *testing.T) { var innerWriter collectWrites - splitWriter := NewWriter(&innerWriter, 1, AddSplitSequence(0, 3)) + splitWriter := NewWriter(&innerWriter, NewRepeatedSplitIterator(RepeatedSplit{1, 1}, RepeatedSplit{0, 3})) n, err := splitWriter.Write([]byte("Request")) require.NoError(t, err) require.Equal(t, 7, n) @@ -161,7 +161,7 @@ func (r *collectReader) Read(buf []byte) (int, error) { } func TestReadFrom(t *testing.T) { - splitWriter := NewWriter(&bytes.Buffer{}, 3) + splitWriter := NewWriter(&bytes.Buffer{}, NewFixedSplitIterator(3)) rf, ok := splitWriter.(io.ReaderFrom) require.True(t, ok) @@ -179,7 +179,7 @@ func TestReadFrom(t *testing.T) { } func TestReadFrom_Multi(t *testing.T) { - splitWriter := NewWriter(&bytes.Buffer{}, 1, AddSplitSequence(3, 2), AddSplitSequence(2, 3)) + splitWriter := NewWriter(&bytes.Buffer{}, NewRepeatedSplitIterator(RepeatedSplit{1, 1}, RepeatedSplit{3, 2}, RepeatedSplit{2, 3})) rf, ok := splitWriter.(io.ReaderFrom) require.True(t, ok) @@ -191,7 +191,7 @@ func TestReadFrom_Multi(t *testing.T) { } func TestReadFrom_ShortRead(t *testing.T) { - splitWriter := NewWriter(&bytes.Buffer{}, 10) + splitWriter := NewWriter(&bytes.Buffer{}, NewFixedSplitIterator(10)) rf, ok := splitWriter.(io.ReaderFrom) require.True(t, ok) cr := &collectReader{Reader: bytes.NewReader([]byte("Request1"))} @@ -210,7 +210,7 @@ func TestReadFrom_ShortRead(t *testing.T) { func BenchmarkReadFrom(b *testing.B) { for n := 0; n < b.N; n++ { reader := bytes.NewReader(make([]byte, n)) - writer := NewWriter(io.Discard, 10) + writer := NewWriter(io.Discard, NewFixedSplitIterator(10)) rf, ok := writer.(io.ReaderFrom) require.True(b, ok) _, err := rf.ReadFrom(reader)