diff --git a/transport/split/writer.go b/transport/split/writer.go index 671b9763..a4db713c 100644 --- a/transport/split/writer.go +++ b/transport/split/writer.go @@ -56,11 +56,8 @@ func NewWriter(writer io.Writer, prefixBytes int64, options ...Option) io.Writer for _, option := range options { option(sw) } - if len(sw.remainingSplits) == 0 { - // TODO(fortuna): Support ReaderFrom for repeat split. - if rf, ok := writer.(io.ReaderFrom); ok { - return &splitWriterReaderFrom{sw, rf} - } + if rf, ok := writer.(io.ReaderFrom); ok { + return &splitWriterReaderFrom{sw, rf} } return sw } @@ -88,13 +85,27 @@ func AddSplitSequence(count int, skipBytes int64) Option { } func (w *splitWriterReaderFrom) ReadFrom(source io.Reader) (int64, error) { - reader := io.MultiReader(io.LimitReader(source, w.nextSplitBytes), source) - written, err := w.rf.ReadFrom(reader) - w.nextSplitBytes -= written + var written int64 + for w.nextSplitBytes > 0 { + expectedBytes := w.nextSplitBytes + n, err := w.rf.ReadFrom(io.LimitReader(source, expectedBytes)) + written += n + w.advance(n) + if err != nil { + return written, err + } + if n < expectedBytes { + // Source is done before the split happened. Return. + return written, err + } + } + n, err := w.rf.ReadFrom(source) + written += n + w.advance(n) return written, err } -func (w *splitWriter) advance(n int) { +func (w *splitWriter) advance(n int64) { if w.nextSplitBytes == 0 { // Done with splits: return. return @@ -119,7 +130,7 @@ func (w *splitWriter) Write(data []byte) (written int, err error) { dataToSend := data[:w.nextSplitBytes] n, err := w.writer.Write(dataToSend) written += n - w.advance(n) + w.advance(int64(n)) if err != nil { return written, err } @@ -127,6 +138,6 @@ func (w *splitWriter) Write(data []byte) (written int, err error) { } n, err := w.writer.Write(data) written += n - w.advance(n) + w.advance(int64(n)) return written, err } diff --git a/transport/split/writer_test.go b/transport/split/writer_test.go index b20f8222..d995417b 100644 --- a/transport/split/writer_test.go +++ b/transport/split/writer_test.go @@ -178,6 +178,18 @@ func TestReadFrom(t *testing.T) { require.Equal(t, [][]byte{[]byte("Request2")}, cr.reads) } +func TestReadFrom_Multi(t *testing.T) { + splitWriter := NewWriter(&bytes.Buffer{}, 1, AddSplitSequence(3, 2), AddSplitSequence(2, 3)) + rf, ok := splitWriter.(io.ReaderFrom) + require.True(t, ok) + + cr := &collectReader{Reader: bytes.NewReader([]byte("RequestRequestRequest"))} + n, err := rf.ReadFrom(cr) + require.NoError(t, err) + require.Equal(t, int64(21), n) + require.Equal(t, [][]byte{[]byte("R"), []byte("eq"), []byte("ue"), []byte("st"), []byte("Req"), []byte("ues"), []byte("tRequest")}, cr.reads) +} + func TestReadFrom_ShortRead(t *testing.T) { splitWriter := NewWriter(&bytes.Buffer{}, 10) rf, ok := splitWriter.(io.ReaderFrom)