Skip to content

Commit

Permalink
feat: no-buffer split (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna authored Nov 2, 2023
1 parent 85ede5b commit 242ca55
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 30 deletions.
41 changes: 21 additions & 20 deletions transport/split/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,42 +15,43 @@
package split

import (
"errors"
"io"
)

type SplitWriter struct {
type splitWriter struct {
writer io.Writer
prefixBytes int64
}

var _ io.Writer = (*SplitWriter)(nil)
var _ io.ReaderFrom = (*SplitWriter)(nil)
var _ io.Writer = (*splitWriter)(nil)

type splitWriterReaderFrom struct {
*splitWriter
rf io.ReaderFrom
}

var _ io.ReaderFrom = (*splitWriterReaderFrom)(nil)

// NewWriter creates a [io.Writer] that ensures the byte sequence is split at prefixBytes.
// A write will end right after byte index prefixBytes - 1, before a write starting at byte index prefixBytes.
// For example, if you have a write of [0123456789] and prefixBytes = 3, you will get writes [012] and [3456789].
func NewWriter(writer io.Writer, prefixBytes int64) *SplitWriter {
return &SplitWriter{writer, prefixBytes}
// If the input writer is a [io.ReaderFrom], the output writer will be too.
func NewWriter(writer io.Writer, prefixBytes int64) io.Writer {
sw := &splitWriter{writer, prefixBytes}
if rf, ok := writer.(io.ReaderFrom); ok {
return &splitWriterReaderFrom{sw, rf}
}
return sw
}

func (w *SplitWriter) ReadFrom(source io.Reader) (written int64, err error) {
if w.prefixBytes > 0 {
written, err = io.CopyN(w.writer, source, w.prefixBytes)
w.prefixBytes -= written
if errors.Is(err, io.EOF) {
return written, nil
}
if err != nil {
return written, err
}
}
n, err := io.Copy(w.writer, source)
written += n
func (w *splitWriterReaderFrom) ReadFrom(source io.Reader) (int64, error) {
reader := io.MultiReader(io.LimitReader(source, w.prefixBytes), source)
written, err := w.rf.ReadFrom(reader)
w.prefixBytes -= written
return written, err
}

func (w *SplitWriter) Write(data []byte) (written int, err error) {
func (w *splitWriter) Write(data []byte) (written int, err error) {
if 0 < w.prefixBytes && w.prefixBytes < int64(len(data)) {
written, err = w.writer.Write(data[:w.prefixBytes])
w.prefixBytes -= int64(written)
Expand Down
65 changes: 55 additions & 10 deletions transport/split/writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/stretchr/testify/require"
)

// collectWrites is a [io.Writer] that appends each write to the writes slice.
type collectWrites struct {
writes [][]byte
}
Expand Down Expand Up @@ -83,20 +84,64 @@ func TestWrite_Compound(t *testing.T) {
require.Equal(t, [][]byte{[]byte("R"), []byte("equ"), []byte("est")}, innerWriter.writes)
}

// collectReader is a [io.Reader] that appends each Read from the Reader to the reads slice.
type collectReader struct {
io.Reader
reads [][]byte
}

func (r *collectReader) Read(buf []byte) (int, error) {
n, err := r.Reader.Read(buf)
if n > 0 {
read := make([]byte, n)
copy(read, buf[:n])
r.reads = append(r.reads, read)
}
return n, err
}

func TestReadFrom(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 3)
n, err := splitWriter.ReadFrom(bytes.NewReader([]byte("Request")))
splitWriter := NewWriter(&bytes.Buffer{}, 3)
rf, ok := splitWriter.(io.ReaderFrom)
require.True(t, ok)

cr := &collectReader{Reader: bytes.NewReader([]byte("Request1"))}
n, err := rf.ReadFrom(cr)
require.NoError(t, err)
require.Equal(t, int64(7), n)
require.Equal(t, [][]byte{[]byte("Req"), []byte("uest")}, innerWriter.writes)
require.Equal(t, int64(8), n)
require.Equal(t, [][]byte{[]byte("Req"), []byte("uest1")}, cr.reads)

cr = &collectReader{Reader: bytes.NewReader([]byte("Request2"))}
n, err = rf.ReadFrom(cr)
require.NoError(t, err)
require.Equal(t, int64(8), n)
require.Equal(t, [][]byte{[]byte("Request2")}, cr.reads)
}

func TestReadFrom_ShortRead(t *testing.T) {
var innerWriter collectWrites
splitWriter := NewWriter(&innerWriter, 10)
n, err := splitWriter.ReadFrom(bytes.NewReader([]byte("Request")))
splitWriter := NewWriter(&bytes.Buffer{}, 10)
rf, ok := splitWriter.(io.ReaderFrom)
require.True(t, ok)
cr := &collectReader{Reader: bytes.NewReader([]byte("Request1"))}
n, err := rf.ReadFrom(cr)
require.NoError(t, err)
require.Equal(t, int64(7), n)
require.Equal(t, [][]byte{[]byte("Request")}, innerWriter.writes)
require.Equal(t, int64(8), n)
require.Equal(t, [][]byte{[]byte("Request1")}, cr.reads)

cr = &collectReader{Reader: bytes.NewReader([]byte("Request2"))}
n, err = rf.ReadFrom(cr)
require.NoError(t, err)
require.Equal(t, int64(8), n)
require.Equal(t, [][]byte{[]byte("Re"), []byte("quest2")}, cr.reads)
}

func BenchmarkReadFrom(b *testing.B) {
for n := 0; n < b.N; n++ {
reader := bytes.NewReader(make([]byte, n))
writer := NewWriter(io.Discard, 10)
rf, ok := writer.(io.ReaderFrom)
require.True(b, ok)
_, err := rf.ReadFrom(reader)
require.NoError(b, err)
}
}

0 comments on commit 242ca55

Please sign in to comment.