diff --git a/transport/split/writer.go b/transport/split/writer.go index 38f991c8..671b9763 100644 --- a/transport/split/writer.go +++ b/transport/split/writer.go @@ -18,14 +18,17 @@ import ( "io" ) +// repeatedSplit represents a split sequence of count blocks with bytes length. type repeatedSplit struct { count int bytes int64 } type splitWriter struct { - writer io.Writer - nextSplitBytes int64 + writer io.Writer + // Bytes until the next split. This must always be > 0, unless splits are done. + nextSplitBytes int64 + // Remaining split sequences. All entries here must have count > 0 && bytes > 0. remainingSplits []repeatedSplit } @@ -48,7 +51,8 @@ var _ io.ReaderFrom = (*splitWriterReaderFrom)(nil) // Array of [0 1 3 2 4 5 6 7 8 9 10 11 12 13 14 15 16 ...] will become // [0] [1 2 3 4 5 6] [7 8 9 10 11 12] [13 14 15 16 ...] func NewWriter(writer io.Writer, prefixBytes int64, options ...Option) io.Writer { - sw := &splitWriter{writer: writer, nextSplitBytes: prefixBytes, remainingSplits: []repeatedSplit{}} + sw := &splitWriter{writer: writer, remainingSplits: []repeatedSplit{}} + sw.addSplitSequence(1, prefixBytes) for _, option := range options { option(sw) } @@ -61,14 +65,25 @@ func NewWriter(writer io.Writer, prefixBytes int64, options ...Option) io.Writer return sw } +func (w *splitWriter) addSplitSequence(count int, skipBytes int64) { + if count == 0 || skipBytes == 0 { + return + } + if w.nextSplitBytes == 0 { + w.nextSplitBytes = skipBytes + count-- + } + if count > 0 { + w.remainingSplits = append(w.remainingSplits, repeatedSplit{count: count, bytes: skipBytes}) + } +} + type Option func(w *splitWriter) // AddSplitSequence will add count splits, each of skipBytes length. func AddSplitSequence(count int, skipBytes int64) Option { return func(w *splitWriter) { - if count > 0 { - w.remainingSplits = append(w.remainingSplits, repeatedSplit{count: count, bytes: skipBytes}) - } + w.addSplitSequence(count, skipBytes) } } @@ -79,28 +94,39 @@ func (w *splitWriterReaderFrom) ReadFrom(source io.Reader) (int64, error) { return written, err } +func (w *splitWriter) advance(n int) { + if w.nextSplitBytes == 0 { + // Done with splits: return. + return + } + w.nextSplitBytes -= int64(n) + if w.nextSplitBytes > 0 { + return + } + // Split done, set next split. + if len(w.remainingSplits) == 0 { + return + } + w.nextSplitBytes = w.remainingSplits[0].bytes + w.remainingSplits[0].count -= 1 + if w.remainingSplits[0].count == 0 { + w.remainingSplits = w.remainingSplits[1:] + } +} + func (w *splitWriter) Write(data []byte) (written int, err error) { for 0 < w.nextSplitBytes && w.nextSplitBytes < int64(len(data)) { dataToSend := data[:w.nextSplitBytes] n, err := w.writer.Write(dataToSend) written += n - w.nextSplitBytes -= int64(n) + w.advance(n) if err != nil { return written, err } data = data[n:] - - // Split done. Update nextSplitBytes. - if len(w.remainingSplits) > 0 { - w.nextSplitBytes = w.remainingSplits[0].bytes - w.remainingSplits[0].count -= 1 - if w.remainingSplits[0].count == 0 { - w.remainingSplits = w.remainingSplits[1:] - } - } } n, err := w.writer.Write(data) written += n - w.nextSplitBytes -= int64(n) + w.advance(n) return written, err } diff --git a/transport/split/writer_test.go b/transport/split/writer_test.go index 276fe7bf..b20f8222 100644 --- a/transport/split/writer_test.go +++ b/transport/split/writer_test.go @@ -45,6 +45,42 @@ func TestWrite_Split(t *testing.T) { require.Equal(t, [][]byte{[]byte("Req"), []byte("uest")}, innerWriter.writes) } +func TestWrite_SplitZero(t *testing.T) { + var innerWriter collectWrites + splitWriter := NewWriter(&innerWriter, 0, AddSplitSequence(0, 1), AddSplitSequence(10, 0), AddSplitSequence(0, 2)) + n, err := splitWriter.Write([]byte("Request")) + require.NoError(t, err) + require.Equal(t, 7, n) + require.Equal(t, [][]byte{[]byte("Request")}, innerWriter.writes) +} + +func TestWrite_SplitZeroLong(t *testing.T) { + var innerWriter collectWrites + splitWriter := NewWriter(&innerWriter, 0, AddSplitSequence(1_000_000_000_000_000_000, 0)) + n, err := splitWriter.Write([]byte("Request")) + require.NoError(t, err) + require.Equal(t, 7, n) + require.Equal(t, [][]byte{[]byte("Request")}, innerWriter.writes) +} + +func TestWrite_SplitZeroPrefix(t *testing.T) { + var innerWriter collectWrites + splitWriter := NewWriter(&innerWriter, 0, AddSplitSequence(3, 2)) + n, err := splitWriter.Write([]byte("Request")) + require.NoError(t, err) + require.Equal(t, 7, n) + require.Equal(t, [][]byte{[]byte("Re"), []byte("qu"), []byte("es"), []byte("t")}, innerWriter.writes) +} + +func TestWrite_SplitMulti(t *testing.T) { + var innerWriter collectWrites + splitWriter := NewWriter(&innerWriter, 1, AddSplitSequence(3, 2), AddSplitSequence(2, 3)) + n, err := splitWriter.Write([]byte("RequestRequestRequest")) + require.NoError(t, err) + require.Equal(t, 21, n) + require.Equal(t, [][]byte{[]byte("R"), []byte("eq"), []byte("ue"), []byte("st"), []byte("Req"), []byte("ues"), []byte("tRequest")}, innerWriter.writes) +} + func TestWrite_ShortWrite(t *testing.T) { var innerWriter collectWrites splitWriter := NewWriter(&innerWriter, 10)