Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna committed Nov 5, 2024
1 parent 81f3f2a commit e283a16
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 17 deletions.
60 changes: 43 additions & 17 deletions transport/split/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@ import (
"io"
)

// repeatedSplit represents a split sequence of count blocks with bytes length.
type repeatedSplit struct {
count int
bytes int64
}

type splitWriter struct {
writer io.Writer
nextSplitBytes int64
writer io.Writer
// Bytes until the next split. This must always be > 0, unless splits are done.
nextSplitBytes int64
// Remaining split sequences. All entries here must have count > 0 && bytes > 0.
remainingSplits []repeatedSplit
}

Expand All @@ -48,7 +51,8 @@ var _ io.ReaderFrom = (*splitWriterReaderFrom)(nil)
// Array of [0 1 3 2 4 5 6 7 8 9 10 11 12 13 14 15 16 ...] will become
// [0] [1 2 3 4 5 6] [7 8 9 10 11 12] [13 14 15 16 ...]
func NewWriter(writer io.Writer, prefixBytes int64, options ...Option) io.Writer {
sw := &splitWriter{writer: writer, nextSplitBytes: prefixBytes, remainingSplits: []repeatedSplit{}}
sw := &splitWriter{writer: writer, remainingSplits: []repeatedSplit{}}
sw.addSplitSequence(1, prefixBytes)
for _, option := range options {
option(sw)
}
Expand All @@ -61,14 +65,25 @@ func NewWriter(writer io.Writer, prefixBytes int64, options ...Option) io.Writer
return sw
}

func (w *splitWriter) addSplitSequence(count int, skipBytes int64) {
if count == 0 || skipBytes == 0 {
return
}
if w.nextSplitBytes == 0 {
w.nextSplitBytes = skipBytes
count--
}
if count > 0 {
w.remainingSplits = append(w.remainingSplits, repeatedSplit{count: count, bytes: skipBytes})
}
}

type Option func(w *splitWriter)

// AddSplitSequence will add count splits, each of skipBytes length.
func AddSplitSequence(count int, skipBytes int64) Option {
return func(w *splitWriter) {
if count > 0 {
w.remainingSplits = append(w.remainingSplits, repeatedSplit{count: count, bytes: skipBytes})
}
w.addSplitSequence(count, skipBytes)
}
}

Expand All @@ -79,28 +94,39 @@ func (w *splitWriterReaderFrom) ReadFrom(source io.Reader) (int64, error) {
return written, err
}

func (w *splitWriter) advance(n int) {
if w.nextSplitBytes == 0 {
// Done with splits: return.
return
}
w.nextSplitBytes -= int64(n)
if w.nextSplitBytes > 0 {
return
}
// Split done, set next split.
if len(w.remainingSplits) == 0 {
return
}
w.nextSplitBytes = w.remainingSplits[0].bytes
w.remainingSplits[0].count -= 1
if w.remainingSplits[0].count == 0 {
w.remainingSplits = w.remainingSplits[1:]
}
}

func (w *splitWriter) Write(data []byte) (written int, err error) {
for 0 < w.nextSplitBytes && w.nextSplitBytes < int64(len(data)) {
dataToSend := data[:w.nextSplitBytes]
n, err := w.writer.Write(dataToSend)
written += n
w.nextSplitBytes -= int64(n)
w.advance(n)
if err != nil {
return written, err
}
data = data[n:]

// Split done. Update nextSplitBytes.
if len(w.remainingSplits) > 0 {
w.nextSplitBytes = w.remainingSplits[0].bytes
w.remainingSplits[0].count -= 1
if w.remainingSplits[0].count == 0 {
w.remainingSplits = w.remainingSplits[1:]
}
}
}
n, err := w.writer.Write(data)
written += n
w.nextSplitBytes -= int64(n)
w.advance(n)
return written, err
}
36 changes: 36 additions & 0 deletions transport/split/writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,42 @@ func TestWrite_Split(t *testing.T) {
require.Equal(t, [][]byte{[]byte("Req"), []byte("uest")}, innerWriter.writes)
}

func TestWrite_SplitZero(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 0, AddSplitSequence(0, 1), AddSplitSequence(10, 0), AddSplitSequence(0, 2))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
require.Equal(t, [][]byte{[]byte("Request")}, innerWriter.writes)
}

func TestWrite_SplitZeroLong(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 0, AddSplitSequence(1_000_000_000_000_000_000, 0))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
require.Equal(t, [][]byte{[]byte("Request")}, innerWriter.writes)
}

func TestWrite_SplitZeroPrefix(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 0, AddSplitSequence(3, 2))
n, err := splitWriter.Write([]byte("Request"))
require.NoError(t, err)
require.Equal(t, 7, n)
require.Equal(t, [][]byte{[]byte("Re"), []byte("qu"), []byte("es"), []byte("t")}, innerWriter.writes)
}

func TestWrite_SplitMulti(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 1, AddSplitSequence(3, 2), AddSplitSequence(2, 3))
n, err := splitWriter.Write([]byte("RequestRequestRequest"))
require.NoError(t, err)
require.Equal(t, 21, n)
require.Equal(t, [][]byte{[]byte("R"), []byte("eq"), []byte("ue"), []byte("st"), []byte("Req"), []byte("ues"), []byte("tRequest")}, innerWriter.writes)
}

func TestWrite_ShortWrite(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 10)
Expand Down

0 comments on commit e283a16

Please sign in to comment.