-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing 4 changed files with 261 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,14 @@ | ||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= | ||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | ||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | ||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | ||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= | ||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= | ||
github.com/tcolgate/mp3 v0.0.0-20170426193717-e79c5a46d300 h1:XQdibLKagjdevRB6vAjVY4qbSr8rQ610YzTkWcxzxSI= | ||
github.com/tcolgate/mp3 v0.0.0-20170426193717-e79c5a46d300/go.mod h1:FNa/dfN95vAYCNFrIKRrlRo+MBLbwmR9Asa5f2ljmBI= | ||
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= | ||
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= | ||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= | ||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= | ||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
package libscan | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"io/fs" | ||
"log" | ||
"math/rand" | ||
"path" | ||
|
||
"golang.org/x/time/rate" | ||
) | ||
|
||
// Scanner searches for files in the filesystem matching the glob pattern.
// Then it streams the content of the files to the io.Writer.
type Scanner struct {
	// FS is the filesystem that is walked for matches and from which the
	// matched files are opened.
	FS fs.FS
	// Dir is the directory inside FS where the walk starts; an empty value
	// is treated as ".".
	Dir string
	// Glob is a path.Match pattern applied to the base name of every
	// non-directory entry found during the walk.
	Glob string
	// BufSize is the size in bytes of the copy buffer; values <= 0 select
	// a 16 KiB default.
	BufSize int
	// Shuffle, when set, reorders the list of matched files at the start of
	// every pass over it.
	Shuffle bool
	// Limiter, when non-nil, paces the copy: data is moved chunk by chunk
	// with calls to Limiter.Wait in between. When nil, files are copied
	// with io.CopyBuffer.
	Limiter *rate.Limiter

	// OnNext, when non-nil, is called with each filename before that file
	// is opened; a non-nil error makes Stream return.
	OnNext func(ctx context.Context, filename string) error
}
|
||
// Stream copies content of the files matching the glob pattern to the dst. | ||
// It loops infinitely until the context is canceled or an error occurs. | ||
func (scanner *Scanner) Stream(ctx context.Context, dst io.Writer) error { | ||
files, errFiles := scanner.findAll(ctx) | ||
if errFiles != nil { | ||
return fmt.Errorf("finding files: %w", errFiles) | ||
} | ||
|
||
bufSize := scanner.BufSize | ||
if bufSize <= 0 { | ||
bufSize = 16 * 1024 | ||
} | ||
buf := make([]byte, bufSize) | ||
|
||
copyFile := func(filename string) error { | ||
file, errFile := scanner.FS.Open(filename) | ||
if errFile != nil { | ||
return fmt.Errorf("opening file: %w", errFile) | ||
} | ||
defer file.Close() | ||
|
||
errCopy := scanner.copy(ctx, dst, file, buf) | ||
if errCopy != nil { | ||
return fmt.Errorf("copying file: %w", errCopy) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
n := len(files) | ||
if n == 0 { | ||
return nil | ||
} | ||
|
||
for i := 0; ctx.Err() == nil; i = (i + 1) % n { | ||
if i == 0 && scanner.Shuffle { | ||
shuffle(files) | ||
} | ||
|
||
filename := files[i] | ||
|
||
if scanner.OnNext != nil { | ||
errNext := scanner.OnNext(ctx, filename) | ||
if errNext != nil { | ||
return fmt.Errorf("next callback: %w", errNext) | ||
} | ||
} | ||
|
||
errCopy := copyFile(filename) | ||
|
||
switch { | ||
case errors.Is(errCopy, fs.ErrNotExist): | ||
continue | ||
case errors.Is(errCopy, io.EOF): | ||
continue | ||
case errCopy != nil: | ||
return fmt.Errorf("file %q: %w", filename, errCopy) | ||
} | ||
} | ||
|
||
return ctx.Err() | ||
} | ||
|
||
func (scanner *Scanner) findAll(ctx context.Context) ([]string, error) { | ||
if _, err := path.Match("", scanner.Glob); err != nil { | ||
return nil, fmt.Errorf("glob: %w", err) | ||
} | ||
|
||
dir := scanner.Dir | ||
if dir == "" { | ||
dir = "." | ||
} | ||
|
||
var files []string | ||
|
||
errWalk := fs.WalkDir(scanner.FS, dir, func(fpath string, d fs.DirEntry, err error) error { | ||
if err != nil || d.IsDir() { | ||
return err | ||
} | ||
|
||
if ctx.Err() != nil { | ||
return ctx.Err() | ||
} | ||
|
||
ok, _ := path.Match(scanner.Glob, d.Name()) | ||
if ok { | ||
files = append(files, fpath) | ||
} | ||
|
||
return nil | ||
}) | ||
|
||
if errWalk != nil { | ||
return nil, fmt.Errorf("walking dir %q: %w", dir, errWalk) | ||
} | ||
|
||
return files, nil | ||
} | ||
|
||
func (scanner *Scanner) copy(ctx context.Context, dst io.Writer, src io.Reader, buf []byte) error { | ||
if scanner.Limiter == nil { | ||
_, err := io.CopyBuffer(dst, src, buf) | ||
return err | ||
} | ||
|
||
streamChunk := func(buf []byte) error { | ||
n, errRead := src.Read(buf) | ||
|
||
log.Printf("read %d bytes, err=%v", n, errRead) | ||
|
||
switch { | ||
case errors.Is(errRead, io.EOF): | ||
_, errWrite := dst.Write(buf[:n]) | ||
if errWrite != nil { | ||
return fmt.Errorf("writing: %w", errWrite) | ||
} | ||
return nil | ||
case errRead != nil: | ||
return fmt.Errorf("reading: %w", errRead) | ||
} | ||
|
||
if err := scanner.Limiter.Wait(ctx); err != nil { | ||
return fmt.Errorf("rate limiter: %w", err) | ||
} | ||
|
||
_, errWrite := dst.Write(buf[:n]) | ||
if errWrite != nil { | ||
return fmt.Errorf("writing: %w", errWrite) | ||
} | ||
|
||
log.Printf("wrote %d bytes", n) | ||
|
||
return nil | ||
} | ||
|
||
for { | ||
if err := streamChunk(buf); err != nil { | ||
return fmt.Errorf("streaming chunk: %w", err) | ||
} | ||
} | ||
} | ||
|
||
// shuffle randomly permutes the elements of slice in place, drawing
// randomness from the math/rand package-level source.
func shuffle[E any](slice []E) {
	swap := func(a, b int) {
		tmp := slice[a]
		slice[a] = slice[b]
		slice[b] = tmp
	}

	rand.Shuffle(len(slice), swap)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
package libscan_test | ||
|
||
import ( | ||
"context" | ||
"io" | ||
"testing" | ||
"testing/fstest" | ||
|
||
"github.com/ninedraft/substream/internal/libscan" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestScanner_Stream(t *testing.T) { | ||
t.Parallel() | ||
|
||
fsys := fstest.MapFS{ | ||
"foo.txt": {Data: []byte("foo\n")}, | ||
"sub/bar.txt": {Data: []byte("bar\n")}, | ||
"baz.dat": {Data: []byte("baz\n")}, | ||
} | ||
|
||
scanner := &libscan.Scanner{ | ||
FS: fsys, | ||
Glob: "*.txt", | ||
Shuffle: false, | ||
} | ||
|
||
t.Run("context stops looping", func(t *testing.T) { | ||
t.Parallel() | ||
|
||
ctx, cancel := context.WithCancel(context.Background()) | ||
defer cancel() | ||
|
||
got, input := io.Pipe() | ||
|
||
buf := make([]byte, 1) | ||
go func() { | ||
defer cancel() | ||
for { | ||
_, err := got.Read(buf) | ||
cancel() | ||
if err != nil { | ||
return | ||
} | ||
} | ||
}() | ||
|
||
err := scanner.Stream(ctx, input) | ||
require.ErrorIs(t, err, context.Canceled) | ||
}) | ||
|
||
t.Run("file contents are copied to dst", func(t *testing.T) { | ||
t.Parallel() | ||
|
||
gotStream, input := io.Pipe() | ||
|
||
go func() { | ||
defer input.Close() | ||
_ = scanner.Stream(context.Background(), input) | ||
}() | ||
|
||
got := make([]byte, 1024) | ||
_, errGot := io.ReadFull(gotStream, got) | ||
|
||
require.NoError(t, errGot) | ||
require.Containsf(t, string(got), "foo", "got %q", got) | ||
require.Containsf(t, string(got), "bar", "got %q", got) | ||
}) | ||
} |