From b063ab0b4b2c56fadcde994fc63027cce2498e77 Mon Sep 17 00:00:00 2001 From: Johan Walles Date: Sun, 12 Jan 2025 15:57:20 +0100 Subject: [PATCH] Fix handling of short streams Before this change, on streams shorter than 6 bytes, we would present up to 6 bytes of garbage at the start of the buffer. With this change in place, we now present exactly what we read. Introduced in e7ecaa20153579260881dfb46ac52a29a4e969aa. Fixes #263. --- m/zopen.go | 3 ++- m/zopen_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 m/zopen_test.go diff --git a/m/zopen.go b/m/zopen.go index f16fc38..8a4a413 100644 --- a/m/zopen.go +++ b/m/zopen.go @@ -103,7 +103,7 @@ func ZOpen(filename string) (io.ReadCloser, string, error) { func ZReader(input io.Reader) (io.Reader, error) { // Read the first 6 bytes to determine the compression type firstBytes := make([]byte, 6) - _, err := input.Read(firstBytes) + count, err := input.Read(firstBytes) if err != nil { if err == io.EOF { // Stream was empty @@ -111,6 +111,7 @@ func ZReader(input io.Reader) (io.Reader, error) { } return nil, fmt.Errorf("failed to read stream: %w", err) } + firstBytes = firstBytes[:count] // Reset input reader to start of stream input = io.MultiReader(bytes.NewReader(firstBytes), input) diff --git a/m/zopen_test.go b/m/zopen_test.go new file mode 100644 index 0000000..1a5dc83 --- /dev/null +++ b/m/zopen_test.go @@ -0,0 +1,36 @@ +package m + +import ( + "bytes" + "io" + "testing" + + "gotest.tools/v3/assert" +) + +// Test that ZReader works with an empty stream +func TestZReaderEmpty(t *testing.T) { + bytesReader := bytes.NewReader([]byte{}) + + zReader, err := ZReader(bytesReader) + assert.NilError(t, err) + + all, err := io.ReadAll(zReader) + assert.NilError(t, err) + + assert.Equal(t, 0, len(all)) +} + +// Test that ZReader works with a one-byte stream +func TestZReaderOneByte(t *testing.T) { + bytesReader := bytes.NewReader([]byte{42}) + + zReader, err := ZReader(bytesReader) + assert.NilError(t, err) + + all, err := io.ReadAll(zReader) + assert.NilError(t, err) + + assert.Equal(t, 1, len(all)) + assert.Equal(t, byte(42), all[0]) +}