Skip to content

Commit

Permalink
Support ReaderFrom
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna committed Nov 5, 2024
1 parent e283a16 commit d2a4cf9
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 11 deletions.
33 changes: 22 additions & 11 deletions transport/split/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,8 @@ func NewWriter(writer io.Writer, prefixBytes int64, options ...Option) io.Writer
for _, option := range options {
option(sw)
}
if len(sw.remainingSplits) == 0 {
// TODO(fortuna): Support ReaderFrom for repeat split.
if rf, ok := writer.(io.ReaderFrom); ok {
return &splitWriterReaderFrom{sw, rf}
}
if rf, ok := writer.(io.ReaderFrom); ok {
return &splitWriterReaderFrom{sw, rf}
}
return sw
}
Expand Down Expand Up @@ -88,13 +85,27 @@ func AddSplitSequence(count int, skipBytes int64) Option {
}

func (w *splitWriterReaderFrom) ReadFrom(source io.Reader) (int64, error) {
reader := io.MultiReader(io.LimitReader(source, w.nextSplitBytes), source)
written, err := w.rf.ReadFrom(reader)
w.nextSplitBytes -= written
var written int64
for w.nextSplitBytes > 0 {
expectedBytes := w.nextSplitBytes
n, err := w.rf.ReadFrom(io.LimitReader(source, expectedBytes))
written += n
w.advance(n)
if err != nil {
return written, err
}
if n < expectedBytes {
// Source is done before the split happened. Return.
return written, err
}
}
n, err := w.rf.ReadFrom(source)
written += n
w.advance(n)
return written, err
}

func (w *splitWriter) advance(n int) {
func (w *splitWriter) advance(n int64) {
if w.nextSplitBytes == 0 {
// Done with splits: return.
return
Expand All @@ -119,14 +130,14 @@ func (w *splitWriter) Write(data []byte) (written int, err error) {
dataToSend := data[:w.nextSplitBytes]
n, err := w.writer.Write(dataToSend)
written += n
w.advance(n)
w.advance(int64(n))
if err != nil {
return written, err
}
data = data[n:]
}
n, err := w.writer.Write(data)
written += n
w.advance(n)
w.advance(int64(n))
return written, err
}
12 changes: 12 additions & 0 deletions transport/split/writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,18 @@ func TestReadFrom(t *testing.T) {
require.Equal(t, [][]byte{[]byte("Request2")}, cr.reads)
}

func TestReadFrom_Multi(t *testing.T) {
splitWriter := NewWriter(&bytes.Buffer{}, 1, AddSplitSequence(3, 2), AddSplitSequence(2, 3))
rf, ok := splitWriter.(io.ReaderFrom)
require.True(t, ok)

cr := &collectReader{Reader: bytes.NewReader([]byte("RequestRequestRequest"))}
n, err := rf.ReadFrom(cr)
require.NoError(t, err)
require.Equal(t, int64(21), n)
require.Equal(t, [][]byte{[]byte("R"), []byte("eq"), []byte("ue"), []byte("st"), []byte("Req"), []byte("ues"), []byte("tRequest")}, cr.reads)
}

func TestReadFrom_ShortRead(t *testing.T) {
splitWriter := NewWriter(&bytes.Buffer{}, 10)
rf, ok := splitWriter.(io.ReaderFrom)
Expand Down

0 comments on commit d2a4cf9

Please sign in to comment.