Skip to content

Commit

Permalink
add libscan
Browse files Browse the repository at this point in the history
  • Loading branch information
ninedraft committed Nov 8, 2023
1 parent 1675303 commit cc239d9
Show file tree
Hide file tree
Showing 4 changed files with 261 additions and 0 deletions.
7 changes: 7 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@ module github.com/ninedraft/substream
go 1.21.3

require (
github.com/stretchr/testify v1.8.4
github.com/tcolgate/mp3 v0.0.0-20170426193717-e79c5a46d300
golang.org/x/time v0.3.0
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
10 changes: 10 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tcolgate/mp3 v0.0.0-20170426193717-e79c5a46d300 h1:XQdibLKagjdevRB6vAjVY4qbSr8rQ610YzTkWcxzxSI=
github.com/tcolgate/mp3 v0.0.0-20170426193717-e79c5a46d300/go.mod h1:FNa/dfN95vAYCNFrIKRrlRo+MBLbwmR9Asa5f2ljmBI=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
175 changes: 175 additions & 0 deletions internal/libscan/libscan.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
package libscan

import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"log"
"math/rand"
"path"

"golang.org/x/time/rate"
)

// Scanner searches for files in the filesystem matching the glob pattern.
// Then it streams the content of the files to the io.Writer.
//
// The list of files is collected once, at the beginning of Stream, and is
// then replayed in a loop.
type Scanner struct {
	// FS is the filesystem that is walked for candidates and from which
	// the matched files are opened.
	FS fs.FS
	// Dir is the directory inside FS where the walk starts.
	// An empty value means ".".
	Dir string
	// Glob is a path.Match pattern. It is applied to the base name of each
	// non-directory entry, not to its full path.
	Glob string
	// BufSize is the size in bytes of the copy buffer.
	// Values <= 0 select a default of 16 KiB.
	BufSize int
	// Shuffle, when true, reorders the list of found files at the start of
	// every pass over it.
	Shuffle bool
	// Limiter, when non-nil, throttles the stream: Wait is called per
	// chunk of at most BufSize bytes (one token per chunk, not per byte).
	// When nil, files are copied with io.CopyBuffer without throttling.
	Limiter *rate.Limiter

	// OnNext, when non-nil, is called with the path of each file right
	// before that file is opened and copied. A non-nil error stops Stream.
	OnNext func(ctx context.Context, filename string) error
}

// Stream writes the content of every file matching the glob pattern to dst,
// one file after another, and starts over once the list is exhausted.
//
// The file list is built once, before streaming starts. When no file matches,
// Stream returns nil immediately. Otherwise it keeps going until ctx is done
// (the context's error is returned) or until the OnNext callback, opening a
// file or copying it fails. Files that no longer exist are skipped, and so
// are copies that end with io.EOF.
func (scanner *Scanner) Stream(ctx context.Context, dst io.Writer) error {
	playlist, err := scanner.findAll(ctx)
	if err != nil {
		return fmt.Errorf("finding files: %w", err)
	}

	size := scanner.BufSize
	if size <= 0 {
		size = 16 * 1024
	}
	// A single scratch buffer is shared by all files.
	chunk := make([]byte, size)

	if len(playlist) == 0 {
		return nil
	}

	// streamOne opens one file and pipes it into dst. It is a closure so
	// that the deferred Close fires after every file, not at the end of Stream.
	streamOne := func(name string) error {
		src, err := scanner.FS.Open(name)
		if err != nil {
			return fmt.Errorf("opening file: %w", err)
		}
		defer src.Close()

		if err := scanner.copy(ctx, dst, src, chunk); err != nil {
			return fmt.Errorf("copying file: %w", err)
		}

		return nil
	}

	pos := 0
	for {
		if errCtx := ctx.Err(); errCtx != nil {
			return errCtx
		}

		// A new pass over the list begins: pick a fresh order if requested.
		if pos == 0 && scanner.Shuffle {
			shuffle(playlist)
		}

		name := playlist[pos]
		pos = (pos + 1) % len(playlist)

		if scanner.OnNext != nil {
			if err := scanner.OnNext(ctx, name); err != nil {
				return fmt.Errorf("next callback: %w", err)
			}
		}

		err := streamOne(name)
		if err == nil || errors.Is(err, fs.ErrNotExist) || errors.Is(err, io.EOF) {
			continue
		}

		return fmt.Errorf("file %q: %w", name, err)
	}
}

// findAll walks scanner.FS starting at scanner.Dir ("." when Dir is empty)
// and returns the paths of all non-directory entries whose base name matches
// scanner.Glob, in the order fs.WalkDir visits them.
//
// A malformed Glob is rejected before the walk starts with an error wrapping
// path.ErrBadPattern. The walk is aborted with the context's error once ctx
// is done, and with the first error reported by fs.WalkDir otherwise.
func (scanner *Scanner) findAll(ctx context.Context) ([]string, error) {
	// path.Match takes the pattern first. Matching it against an empty name
	// is enough to surface ErrBadPattern for a malformed pattern.
	if _, err := path.Match(scanner.Glob, ""); err != nil {
		return nil, fmt.Errorf("glob: %w", err)
	}

	dir := scanner.Dir
	if dir == "" {
		dir = "."
	}

	var files []string

	errWalk := fs.WalkDir(scanner.FS, dir, func(fpath string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}

		// Checked for directories as well, so that a deep tree without
		// matching files still reacts to cancellation.
		if errCtx := ctx.Err(); errCtx != nil {
			return errCtx
		}

		if d.IsDir() {
			return nil
		}

		ok, errMatch := path.Match(scanner.Glob, d.Name())
		if errMatch != nil {
			return fmt.Errorf("glob: %w", errMatch)
		}
		if ok {
			files = append(files, fpath)
		}

		return nil
	})

	if errWalk != nil {
		return nil, fmt.Errorf("walking dir %q: %w", dir, errWalk)
	}

	return files, nil
}

// copy transfers src to dst until src is exhausted, using buf as scratch
// space.
//
// Without a Limiter the work is delegated to io.CopyBuffer. With a Limiter,
// data is moved in chunks of at most len(buf) bytes and Limiter.Wait is
// called before every non-empty chunk is written, so a done ctx interrupts
// the transfer between chunks.
//
// It returns nil once src reports io.EOF. In the rate-limited path failures
// are wrapped with the stage that produced them ("reading", "rate limiter"
// or "writing").
func (scanner *Scanner) copy(ctx context.Context, dst io.Writer, src io.Reader, buf []byte) error {
	if scanner.Limiter == nil {
		_, err := io.CopyBuffer(dst, src, buf)
		return err
	}

	// streamChunk moves a single chunk and reports whether src is exhausted.
	streamChunk := func() (done bool, err error) {
		n, errRead := src.Read(buf)

		log.Printf("read %d bytes, err=%v", n, errRead)

		// Per the io.Reader contract, bytes returned together with an error
		// (including io.EOF) are processed before the error is considered.
		// Empty reads are not forwarded: they carry no data and a zero-length
		// write may block on writers such as io.Pipe.
		if n > 0 {
			if err := scanner.Limiter.Wait(ctx); err != nil {
				return false, fmt.Errorf("rate limiter: %w", err)
			}

			if _, errWrite := dst.Write(buf[:n]); errWrite != nil {
				return false, fmt.Errorf("writing: %w", errWrite)
			}

			log.Printf("wrote %d bytes", n)
		}

		switch {
		case errors.Is(errRead, io.EOF):
			// End of input is the normal way out, not a failure.
			return true, nil
		case errRead != nil:
			return false, fmt.Errorf("reading: %w", errRead)
		}

		return false, nil
	}

	for {
		done, err := streamChunk()
		if err != nil {
			return fmt.Errorf("streaming chunk: %w", err)
		}
		if done {
			return nil
		}
	}
}

// shuffle randomly permutes the elements of slice in place using the
// package-level math/rand source.
func shuffle[E any](slice []E) {
	swap := func(a, b int) {
		slice[a], slice[b] = slice[b], slice[a]
	}

	rand.Shuffle(len(slice), swap)
}
69 changes: 69 additions & 0 deletions internal/libscan/libscan_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package libscan_test

import (
"context"
"io"
"testing"
"testing/fstest"

"github.com/ninedraft/substream/internal/libscan"
"github.com/stretchr/testify/require"
)

// TestScanner_Stream exercises Scanner.Stream against an in-memory
// filesystem: cancellation must stop the endless loop, and the content of
// every matching file must reach the destination writer.
//
// Both subtests run a helper goroutine on the other end of an io.Pipe. Each
// subtest closes its side of the pipe and waits for that goroutine before
// returning, so nothing is left blocked on the pipe after the test.
func TestScanner_Stream(t *testing.T) {
	t.Parallel()

	fsys := fstest.MapFS{
		"foo.txt":     {Data: []byte("foo\n")},
		"sub/bar.txt": {Data: []byte("bar\n")},
		"baz.dat":     {Data: []byte("baz\n")},
	}

	scanner := &libscan.Scanner{
		FS:      fsys,
		Glob:    "*.txt",
		Shuffle: false,
	}

	t.Run("context stops looping", func(t *testing.T) {
		t.Parallel()

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		got, input := io.Pipe()

		// The reader drains the pipe one byte at a time and cancels the
		// context as soon as the first read returns.
		drained := make(chan struct{})
		buf := make([]byte, 1)
		go func() {
			defer close(drained)
			for {
				_, err := got.Read(buf)
				cancel()
				if err != nil {
					return
				}
			}
		}()
		// Closing the write side makes the pending Read fail with io.EOF,
		// which lets the reader goroutine exit.
		defer func() {
			_ = input.Close()
			<-drained
		}()

		err := scanner.Stream(ctx, input)
		require.ErrorIs(t, err, context.Canceled)
	})

	t.Run("file contents are copied to dst", func(t *testing.T) {
		t.Parallel()

		gotStream, input := io.Pipe()

		streamed := make(chan struct{})
		go func() {
			defer close(streamed)
			defer input.Close()
			_ = scanner.Stream(context.Background(), input)
		}()
		// Stream loops forever on a background context. Closing the read
		// side makes its next pipe write fail, which ends the goroutine.
		defer func() {
			_ = gotStream.Close()
			<-streamed
		}()

		got := make([]byte, 1024)
		_, errGot := io.ReadFull(gotStream, got)

		require.NoError(t, errGot)
		require.Containsf(t, string(got), "foo", "got %q", got)
		require.Containsf(t, string(got), "bar", "got %q", got)
	})
}

0 comments on commit cc239d9

Please sign in to comment.