diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index ccc38d2..e6ec0ea 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,7 +1,7 @@ on: release: types: [created] - +name: Release jobs: releases-matrix: name: Release Go Binary diff --git a/diskrsync/main.go b/diskrsync/main.go index e595064..11129cc 100644 --- a/diskrsync/main.go +++ b/diskrsync/main.go @@ -39,13 +39,8 @@ type remoteProc struct { cmd *exec.Cmd } -type targetFile interface { - io.ReadWriteSeeker - io.Closer -} - func usage() { - fmt.Fprintf(os.Stderr, "Usage: %s [--ssh-flags=\"...\"] [--no-compress] [--verbose] \nsrc and dst is [[user@]host:]path\n", os.Args[0]) + _, _ = fmt.Fprintf(os.Stderr, "Usage: %s [--ssh-flags=\"...\"] [--no-compress] [--verbose] \nsrc and dst is [[user@]host:]path\n", os.Args[0]) os.Exit(2) } @@ -118,41 +113,51 @@ func (p *localProc) Start(cmdReader io.Reader, cmdWriter io.WriteCloser, errChan } func (p *localProc) run(cmdReader io.Reader, cmdWriter io.WriteCloser, errChan chan error) { + var err error if p.mode == modeSource { - errChan <- doSource(p.p, cmdReader, cmdWriter, p.opts) + err = doSource(p.p, cmdReader, cmdWriter, p.opts) } else { - errChan <- doTarget(p.p, cmdReader, cmdWriter, p.opts) + err = doTarget(p.p, cmdReader, cmdWriter, p.opts) + } + + cerr := cmdWriter.Close() + if err == nil { + err = cerr } + errChan <- err +} - cmdWriter.Close() +func (p *remoteProc) pipeCopy(dst io.WriteCloser, src io.Reader) { + _, err := io.Copy(dst, src) + if err != nil { + log.Printf("pipe copy failed: %v", err) + } + err = dst.Close() + if err != nil { + log.Printf("close failed after pipe copy: %v", err) + } } func (p *remoteProc) Start(cmdReader io.Reader, cmdWriter io.WriteCloser, errChan chan error) error { - p.cmd.Stdout = cmdWriter p.cmd.Stderr = os.Stderr + p.cmd.Stdin = cmdReader - w, err := p.cmd.StdinPipe() + r, err := p.cmd.StdoutPipe() if err != nil { return err } - go func() { - io.Copy(w, cmdReader) - w.Close() - }() - err = p.cmd.Start() if err != nil { return err } - go p.run(cmdWriter, errChan) + go p.run(cmdWriter, r, errChan) return nil } -func (p *remoteProc) run(writer io.Closer, errChan chan error) { - err := p.cmd.Wait() - writer.Close() - errChan <- err +func (p *remoteProc) run(w io.WriteCloser, r io.Reader, errChan chan error) { + p.pipeCopy(w, r) + errChan <- p.cmd.Wait() } func doSource(p string, cmdReader io.Reader, cmdWriter io.WriteCloser, opts *options) error { @@ -176,39 +181,42 @@ func doSource(p string, cmdReader io.Reader, cmdWriter io.WriteCloser, opts *opt src = sf } - size, err := src.Seek(0, os.SEEK_END) + size, err := src.Seek(0, io.SeekEnd) if err != nil { return err } - _, err = src.Seek(0, os.SEEK_SET) + _, err = src.Seek(0, io.SeekStart) if err != nil { return err } err = diskrsync.Source(src, size, cmdReader, cmdWriter, true, opts.verbose) - cmdWriter.Close() + cerr := cmdWriter.Close() + if err == nil { + err = cerr + } return err } -func doTarget(p string, cmdReader io.Reader, cmdWriter io.WriteCloser, opts *options) error { - var w targetFile - useBuffer := false +func doTarget(p string, cmdReader io.Reader, cmdWriter io.WriteCloser, opts *options) (err error) { + var w spgz.SparseFile + useReadBuffer := false f, err := os.OpenFile(p, os.O_RDWR|os.O_CREATE, 0666) if err != nil { - return err + return } info, err := f.Stat() if err != nil { - f.Close() - return err + _ = f.Close() + return } - if info.Mode() & (os.ModeDevice | os.ModeCharDevice) != 0 { - w = f - useBuffer = true + if info.Mode()&(os.ModeDevice|os.ModeCharDevice) != 0 { + w = spgz.NewSparseFileWithoutHolePunching(f) + useReadBuffer = true } else if !opts.noCompress { sf, err := spgz.NewFromFileSize(f, os.O_RDWR|os.O_CREATE, diskrsync.DefTargetBlockSize) if err != nil { @@ -216,43 +224,50 @@ func doTarget(p string, cmdReader io.Reader, cmdWriter io.WriteCloser, opts *opt if err == spgz.ErrPunchHoleNotSupported { err = fmt.Errorf("target does not support compression. Try with -no-compress option (error was '%v')", err) } - f.Close() + _ = f.Close() return err } } else { - w = sf + w = &diskrsync.FixingSpgzFileWrapper{SpgzFile: sf} } } if w == nil { - w = spgz.NewSparseWriter(spgz.NewSparseFileWithFallback(f)) - useBuffer = true + w = spgz.NewSparseFileWithFallback(f) + useReadBuffer = true } - defer w.Close() + defer func() { + cerr := w.Close() + if err == nil { + err = cerr + } + }() - size, err := w.Seek(0, os.SEEK_END) + size, err := w.Seek(0, io.SeekEnd) if err != nil { return err } - _, err = w.Seek(0, os.SEEK_SET) + _, err = w.Seek(0, io.SeekStart) if err != nil { return err } - err = diskrsync.Target(w, size, cmdReader, cmdWriter, useBuffer, opts.verbose) - cmdWriter.Close() + err = diskrsync.Target(w, size, cmdReader, cmdWriter, useReadBuffer, opts.verbose) + cerr := cmdWriter.Close() + if err == nil { + err = cerr + } - return err + return } -func doCmd(opts *options) bool { +func doCmd(opts *options) (err error) { src, err := createProc(flag.Arg(0), modeSource, opts) if err != nil { - log.Printf("Could not create source: %v", err) - return false + return fmt.Errorf("could not create source: %w", err) } path := flag.Arg(1) @@ -262,8 +277,7 @@ func doCmd(opts *options) bool { dst, err := createProc(path, modeTarget, opts) if err != nil { - log.Printf("Could not create target: %v", err) - return false + return fmt.Errorf("could not create target: %w", err) } srcErrChan := make(chan error, 1) @@ -276,24 +290,42 @@ func doCmd(opts *options) bool { sw := &diskrsync.CountingWriteCloser{WriteCloser: srcWriter} if opts.verbose { - src.Start(sr, sw, srcErrChan) + err = src.Start(sr, sw, srcErrChan) } else { - src.Start(srcReader, srcWriter, srcErrChan) + err = src.Start(srcReader, srcWriter, srcErrChan) + } + + if err != nil { + return fmt.Errorf("could not start source: %w", err) } - dst.Start(dstReader, dstWriter, dstErrChan) - dstErr := <-dstErrChan - if dstErr != nil { - log.Printf("Target error: %v", dstErr) + err = dst.Start(dstReader, dstWriter, dstErrChan) + if err != nil { + return fmt.Errorf("could not start target: %w", err) } - srcErr := <-srcErrChan - if srcErr != nil { - log.Printf("Source error: %v", srcErr) + +L: + for srcErrChan != nil || dstErrChan != nil { + select { + case dstErr := <-dstErrChan: + if dstErr != nil { + err = fmt.Errorf("target error: %w", dstErr) + break L + } + dstErrChan = nil + case srcErr := <-srcErrChan: + if srcErr != nil { + err = fmt.Errorf("source error: %w", srcErr) + break L + } + srcErrChan = nil + } } + if opts.verbose { log.Printf("Read: %d, wrote: %d\n", sr.Count(), sw.Count()) } - return srcErr == nil && dstErr == nil + return } func main() { @@ -328,9 +360,9 @@ func main() { if flag.Arg(0) == "" || flag.Arg(1) == "" { usage() } - ok := doCmd(&opts) - if !ok { - os.Exit(1) + err := doCmd(&opts) + if err != nil { + log.Fatal(err) } } diff --git a/go.mod b/go.mod index 214f025..08c731e 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,6 @@ module github.com/dop251/diskrsync go 1.16 require ( - github.com/dop251/spgz v0.0.0-20180204132655-b86304a2b188 + github.com/dop251/spgz v1.1.0 golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 ) diff --git a/go.sum b/go.sum index 21cc0d4..1545ada 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,16 @@ -github.com/dop251/spgz v0.0.0-20180204132655-b86304a2b188 h1:UYxfuh/hzW7AmbV9eAhr3mh8Qjr34z4fD5PlFbBaiUI= -github.com/dop251/spgz v0.0.0-20180204132655-b86304a2b188/go.mod h1:LNJUPCpuM80Fs0wTQ3+0oRp6h26KO5mAv4jOTLShJ8w= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dop251/buse v0.0.0-20170916130217-f7a5c857babd/go.mod h1:hQb8UeARubyuKpfzN+fkxK9TTytxdWRCalLjf/FQgwk= +github.com/dop251/nbd v0.0.0-20170916130042-b8933b281cb7/go.mod h1:/YqO/I24sucjxhCgQHgDrnffSwg5HzoYHQASayZnYl8= +github.com/dop251/spgz v1.1.0 h1:y49BXvoyhF+Y9No69DCJLqTCACleK27B73XWsXa2nFU= +github.com/dop251/spgz v1.1.0/go.mod h1:aXXbApWJzaK6jzPiIWWLFi3k47VmRFfunUp2ANdQFD8= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/sync.go b/sync.go index 2c219f5..aed770e 100644 --- a/sync.go +++ b/sync.go @@ -34,21 +34,21 @@ const ( ) var ( - ErrInvalidFormat = errors.New("Invalid data format") + ErrInvalidFormat = errors.New("invalid data format") ) type hashPool []hash.Hash type workCtx struct { - buf []byte - n *node + buf []byte + n *node hash hash.Hash avail, hashReady chan struct{} } type node struct { - buf [hashSize]byte + buf [hashSize]byte parent *node idx int @@ -57,7 +57,7 @@ type node struct { size int hash hash.Hash - sum []byte + sum []byte } type tree struct { @@ -81,7 +81,159 @@ type source struct { type target struct { base - writer io.ReadWriteSeeker + writer *batchingWriter +} + +// Accumulates successive writes into a large buffer so that the writes into the underlying spgz.SpgzFile +// cover compressed blocks completely, so they are not read and unpacked before writing. +type batchingWriter struct { + writer spgz.SparseFile + maxSize int + + offset int64 + holeSize int64 + buf []byte +} + +func (w *batchingWriter) Flush() error { + if w.holeSize > 0 { + err := w.writer.PunchHole(w.offset, w.holeSize) + if err == nil { + w.offset += w.holeSize + w.holeSize = 0 + } + return err + } + if len(w.buf) == 0 { + return nil + } + n, err := w.writer.WriteAt(w.buf, w.offset) + if err != nil { + return err + } + w.buf = w.buf[:0] + w.offset += int64(n) + return nil +} + +func (w *batchingWriter) prepareWrite() error { + if w.holeSize > 0 { + err := w.Flush() + if err != nil { + return err + } + } + if cap(w.buf) < w.maxSize { + buf := make([]byte, w.maxSize) + copy(buf, w.buf) + w.buf = buf[:len(w.buf)] + } + return nil +} + +func (w *batchingWriter) Write(p []byte) (int, error) { + if err := w.prepareWrite(); err != nil { + return 0, err + } + written := 0 + for len(p) > 0 { + if len(p) >= w.maxSize && len(w.buf) == 0 { + residue := len(p) % w.maxSize + n, err := w.writer.WriteAt(p[:len(p)-residue], w.offset) + written += n + w.offset += int64(n) + if err != nil { + return written, err + } + p = p[n:] + } else { + n := copy(w.buf[len(w.buf):w.maxSize], p) + w.buf = w.buf[:len(w.buf)+n] + if len(w.buf) == w.maxSize { + n1, err := w.writer.WriteAt(w.buf, w.offset) + w.offset += int64(n1) + n2 := n1 - (len(w.buf) - n) + w.buf = w.buf[:0] + if n2 < 0 { + n2 = 0 + } + written += n2 + if err != nil { + return written, err + } + } else { + written += n + } + p = p[n:] + } + } + + return written, nil +} + +func (w *batchingWriter) ReadFrom(src io.Reader) (int64, error) { + if err := w.prepareWrite(); err != nil { + return 0, err + } + + var read int64 + for { + n, err := src.Read(w.buf[len(w.buf):w.maxSize]) + read += int64(n) + w.buf = w.buf[:len(w.buf)+n] + if err == io.EOF { + return read, nil + } + if err != nil { + return read, err + } + if len(w.buf) == w.maxSize { + err = w.Flush() + if err != nil { + return read, err + } + } + } +} + +func (w *batchingWriter) WriteHole(size int64) error { + if w.holeSize == 0 { + err := w.Flush() + if err != nil { + return err + } + } + w.holeSize += size + return nil +} + +func (w *batchingWriter) Seek(offset int64, whence int) (int64, error) { + var o int64 + if w.holeSize > 0 { + o = w.offset + w.holeSize + } else { + o = w.offset + int64(len(w.buf)) + } + switch whence { + case io.SeekStart: + // no-op + case io.SeekCurrent: + offset = o + offset + case io.SeekEnd: + var err error + offset, err = w.writer.Seek(offset, whence) + if err != nil { + return offset, err + } + } + if offset != o { + err := w.Flush() + w.offset = offset + if err != nil { + return offset, err + } + } + return offset, nil } type counting struct { @@ -159,7 +311,7 @@ func (n *node) childReady(child *node, pool *hashPool, h hash.Hash) { } } n.hash.Write(child.sum) - if child.idx == len(n.children) - 1 { + if child.idx == len(n.children)-1 { n.sum = n.hash.Sum(n.buf[:0]) if n.parent != nil { n.parent.childReady(n, pool, n.hash) @@ -183,9 +335,10 @@ func (t *tree) build(offset, length int64, order, level int) *node { b := offset for i := 0; i < order; i++ { l := offset + (length * int64(i+1) / int64(order)) - b - n.children[i] = t.build(b, l, order, level) - n.children[i].parent = n - n.children[i].idx = i + child := t.build(b, l, order, level) + child.parent = n + child.idx = i + n.children[i] = child b += l } } else { @@ -242,7 +395,6 @@ func (t *tree) calc(verbose bool) error { order = 1 } - bs := int(float64(t.size) / math.Pow(float64(order), float64(levels-1))) if verbose { @@ -270,8 +422,8 @@ func (t *tree) calc(verbose bool) error { workItems := make([]*workCtx, 2) for i := range workItems { workItems[i] = &workCtx{ - buf: make([]byte, bs+1), - avail: make(chan struct{}, 1), + buf: make([]byte, bs+1), + avail: make(chan struct{}, 1), hashReady: make(chan struct{}, 1), } workItems[i].hash, _ = blake2b.New512(nil) @@ -282,7 +434,7 @@ func (t *tree) calc(verbose bool) error { idx := 0 for { wi := workItems[idx] - <- wi.hashReady + <-wi.hashReady if wi.n == nil { break } @@ -306,7 +458,7 @@ func (t *tree) calc(verbose bool) error { wi := workItems[workIdx] - <- wi.avail + <-wi.avail b := wi.buf[:n.size] r, err := io.ReadFull(reader, b) @@ -333,7 +485,7 @@ func (t *tree) calc(verbose bool) error { // wait until fully processed for i := range workItems { - <- workItems[i].avail + <-workItems[i].avail } // finish the goroutine @@ -359,17 +511,15 @@ func readHeader(reader io.Reader) (size int64, err error) { return } - br := bytes.NewBuffer(buf[len(hdrMagic):]) - err = binary.Read(br, binary.LittleEndian, &size) + size = int64(binary.LittleEndian.Uint64(buf[len(hdrMagic):])) return } func writeHeader(writer io.Writer, size int64) (err error) { - buf := make([]byte, 0, len(hdrMagic)+8) - bw := bytes.NewBuffer(buf) - bw.WriteString(hdrMagic) - binary.Write(bw, binary.LittleEndian, size) - _, err = writer.Write(bw.Bytes()) + buf := make([]byte, len(hdrMagic)+8) + copy(buf, hdrMagic) + binary.LittleEndian.PutUint64(buf[len(hdrMagic):], uint64(size)) + _, err = writer.Write(buf) return } @@ -378,7 +528,6 @@ func Source(reader io.ReadSeeker, size int64, cmdReader io.Reader, cmdWriter io. if err != nil { return } - var remoteSize int64 remoteSize, err = readHeader(cmdReader) if err != nil { @@ -563,10 +712,59 @@ func (s *source) subtree(root *node, offset, size int64) (err error) { return } -func Target(writer io.ReadWriteSeeker, size int64, cmdReader io.Reader, cmdWriter io.Writer, useBuffer bool, verbose bool) (err error) { +type TargetFile interface { + io.ReadWriteSeeker + io.WriterAt + io.Closer + spgz.Truncatable +} + +// FixingSpgzFileWrapper conceals read errors caused by compressed data corruption by re-writing the corrupt +// blocks with zeros. Such errors are usually caused by abrupt termination of the writing process. +// This wrapper is used as the sync target so the corrupt blocks will be updated during the sync process. +type FixingSpgzFileWrapper struct { + *spgz.SpgzFile +} + +func (rw *FixingSpgzFileWrapper) checkErr(err error) error { + var ce *spgz.ErrCorruptCompressedBlock + if errors.As(err, &ce) { + if ce.Size() == 0 { + return rw.SpgzFile.Truncate(ce.Offset()) + } + + buf := make([]byte, ce.Size()) + _, err = rw.SpgzFile.WriteAt(buf, ce.Offset()) + } + return err +} +func (rw *FixingSpgzFileWrapper) Read(p []byte) (n int, err error) { + for n == 0 && err == nil { // avoid returning (0, nil) after a fix + n, err = rw.SpgzFile.Read(p) + if err != nil { + err = rw.checkErr(err) + } + } + return +} + +func (rw *FixingSpgzFileWrapper) Seek(offset int64, whence int) (int64, error) { + o, err := rw.SpgzFile.Seek(offset, whence) + if err != nil { + err = rw.checkErr(err) + if err == nil { + o, err = rw.SpgzFile.Seek(offset, whence) + } + } + return o, err +} + +func Target(writer spgz.SparseFile, size int64, cmdReader io.Reader, cmdWriter io.Writer, useReadBuffer bool, verbose bool) (err error) { + + ch := make(chan error) go func() { - writeHeader(cmdWriter, size) + ch <- writeHeader(cmdWriter, size) }() var remoteSize int64 @@ -575,6 +773,11 @@ func Target(writer io.ReadWriteSeeker, size int64, cmdReader io.Reader, cmdWrite return } + err = <-ch + if err != nil { + return + } + commonSize := size if remoteSize < commonSize { commonSize = remoteSize @@ -590,14 +793,13 @@ func Target(writer io.ReadWriteSeeker, size int64, cmdReader io.Reader, cmdWrite t: tree{ reader: writer, size: commonSize, - useBuffer: useBuffer, + useBuffer: useReadBuffer, }, cmdReader: cmdReader, cmdWriter: cmdWriter, }, - writer: writer, + writer: &batchingWriter{writer: writer, maxSize: DefTargetBlockSize * 16}, } - err = t.t.calc(verbose) if err != nil { return @@ -607,6 +809,10 @@ func Target(writer io.ReadWriteSeeker, size int64, cmdReader io.Reader, cmdWrite if err != nil { return } + err = t.writer.Flush() + if err != nil { + return + } } if size < remoteSize { @@ -658,7 +864,7 @@ func Target(writer io.ReadWriteSeeker, size int64, cmdReader io.Reader, cmdWrite } hole = true } else { - return fmt.Errorf("Unexpected cmd: %d", cmd) + return fmt.Errorf("unexpected cmd: %d", cmd) } } } @@ -706,11 +912,7 @@ func (t *target) subtree(root *node, offset, size int64) (err error) { err = fmt.Errorf("while copying block data at %d: %w", offset, err) } } else { - buf := t.buffer(size) - for i := int64(0); i < size; i++ { - buf[i] = 0 - } - _, err = t.writer.Write(buf) + err = t.writer.WriteHole(size) } } else { b := offset diff --git a/sync_test.go b/sync_test.go index 4ed57c4..8df5b5e 100644 --- a/sync_test.go +++ b/sync_test.go @@ -8,17 +8,21 @@ import ( "io" "math/rand" "os" + "reflect" "testing" + "time" + + "github.com/dop251/spgz" "golang.org/x/crypto/blake2b" ) -type memFile struct { +type memSparseFile struct { data []byte offset int64 } -func (s *memFile) Read(buf []byte) (n int, err error) { +func (s *memSparseFile) Read(buf []byte) (n int, err error) { if s.offset >= int64(len(s.data)) { err = io.EOF return @@ -28,17 +32,24 @@ func (s *memFile) Read(buf []byte) (n int, err error) { return } -func (s *memFile) Write(buf []byte) (n int, err error) { - newSize := s.offset + int64(len(buf)) +func (s *memSparseFile) ensureSize(newSize int64) { if newSize > int64(len(s.data)) { if newSize <= int64(cap(s.data)) { + l := int64(len(s.data)) s.data = s.data[:newSize] + for i := l; i < s.offset; i++ { + s.data[i] = 0 + } } else { d := make([]byte, newSize) copy(d, s.data) s.data = d } } +} + +func (s *memSparseFile) Write(buf []byte) (n int, err error) { + s.ensureSize(s.offset + int64(len(buf))) n = copy(s.data[s.offset:], buf) if n < len(buf) { err = io.ErrShortWrite @@ -47,22 +58,22 @@ func (s *memFile) Write(buf []byte) (n int, err error) { return } -func (s *memFile) Seek(offset int64, whence int) (int64, error) { +func (s *memSparseFile) Seek(offset int64, whence int) (int64, error) { switch whence { - case os.SEEK_SET: + case io.SeekStart: s.offset = offset return s.offset, nil - case os.SEEK_CUR: + case io.SeekCurrent: s.offset += offset return s.offset, nil - case os.SEEK_END: + case io.SeekEnd: s.offset = int64(len(s.data)) + offset return s.offset, nil } - return s.offset, errors.New("Invalid whence") + return s.offset, errors.New("invalid whence") } -func (s *memFile) Truncate(size int64) error { +func (s *memSparseFile) Truncate(size int64) error { if size > int64(len(s.data)) { if size <= int64(cap(s.data)) { l := len(s.data) @@ -81,7 +92,44 @@ func (s *memFile) Truncate(size int64) error { return nil } -func (s *memFile) Bytes() []byte { +func (s *memSparseFile) PunchHole(offset, size int64) error { + if offset < int64(len(s.data)) { + d := offset + size - int64(len(s.data)) + if d > 0 { + size -= d + } + for i := offset; i < offset+size; i++ { + s.data[i] = 0 + } + } + return nil +} + +func (s *memSparseFile) ReadAt(p []byte, off int64) (n int, err error) { + if off < int64(len(s.data)) { + n = copy(p, s.data[off:]) + } + if n < len(p) { + err = io.EOF + } + return +} + +func (s *memSparseFile) WriteAt(p []byte, off int64) (n int, err error) { + s.ensureSize(off + int64(len(p))) + n = copy(s.data[off:], p) + return +} + +func (s *memSparseFile) Close() error { + return nil +} + +func (s *memSparseFile) Sync() error { + return nil +} + +func (s *memSparseFile) Bytes() []byte { return s.data } @@ -183,7 +231,6 @@ func TestNoChange(t *testing.T) { } } - func TestSmallFile(t *testing.T) { src := make([]byte, 128) dst := make([]byte, 128) @@ -195,30 +242,448 @@ func TestSmallFile(t *testing.T) { syncAndCheckEqual(src, dst, t) } +func TestCorruptCompressedBlock(t *testing.T) { + var f memSparseFile + sf, err := spgz.NewFromSparseFileSize(&f, os.O_RDWR|os.O_CREATE, 3*4096) + if err != nil { + t.Fatal(err) + } + src := make([]byte, 2*1024*1024) + + for i := 0; i < len(src); i++ { + src[i] = 'x' + } + + _, err = sf.WriteAt(src, 0) + if err != nil { + t.Fatal(err) + } + + err = sf.Close() + if err != nil { + t.Fatal(err) + } + + if f.data[4096] != 1 { + t.Fatalf("data: %d", f.data[4096]) + } + + f.data[4098] = ^f.data[4098] + + _, _ = f.Seek(0, io.SeekStart) + + sf, err = spgz.NewFromSparseFileSize(&f, os.O_RDWR, 4096) + if err != nil { + t.Fatal(err) + } + + syncAndCheckEqual1(&memSparseFile{data: src}, &FixingSpgzFileWrapper{SpgzFile: sf}, t) +} + +func TestCorruptLastCompressedBlock(t *testing.T) { + var f memSparseFile + sf, err := spgz.NewFromSparseFileSize(&f, os.O_RDWR|os.O_CREATE, 3*4096) + if err != nil { + t.Fatal(err) + } + src := make([]byte, 2*1024*1024) + + for i := 0; i < len(src); i++ { + src[i] = 'x' + } + + _, err = sf.WriteAt(src, 0) + if err != nil { + t.Fatal(err) + } + + err = sf.Close() + if err != nil { + t.Fatal(err) + } + + offset := 4096 + len(src)/(3*4096-1)*(3*4096) + + if f.data[offset] != 1 { + t.Fatalf("data: %d", f.data[offset]) + } + + f.data[offset+2] = ^f.data[offset+2] + + _, _ = f.Seek(0, io.SeekStart) + + sf, err = spgz.NewFromSparseFileSize(&f, os.O_RDWR, 4096) + if err != nil { + t.Fatal(err) + } + + syncAndCheckEqual1(&memSparseFile{data: src}, &FixingSpgzFileWrapper{SpgzFile: sf}, t) +} + +func TestRandomFiles(t *testing.T) { + var srcFile, dstFile memSparseFile + sf, err := spgz.NewFromSparseFile(&dstFile, os.O_RDWR|os.O_CREATE) + if err != nil { + t.Fatal(err) + } + rand.Seed(1234567890) + buf := make([]byte, 100*DefTargetBlockSize) + rand.Read(buf) + _, err = sf.WriteAt(buf, 0) + if err != nil { + t.Fatal(err) + } + + o, err := sf.Seek(0, io.SeekCurrent) + if err != nil { + t.Fatal(err) + } + if o != 0 { + t.Fatalf("o: %d", o) + } + + rand.Read(buf) + _, err = srcFile.WriteAt(buf, 0) + if err != nil { + t.Fatal(err) + } + + syncAndCheckEqual1(&srcFile, sf, t) +} + +type testOplogItem struct { + offset int64 + length int +} + +type testLoggingSparseFile struct { + spgz.SparseFile + wrlog []testOplogItem +} + +func (f *testLoggingSparseFile) WriteAt(buf []byte, offset int64) (int, error) { + f.wrlog = append(f.wrlog, testOplogItem{ + offset: offset, + length: len(buf), + }) + return f.SparseFile.WriteAt(buf, offset) +} + +func TestBatchingWriter(t *testing.T) { + var sf memSparseFile + lsf := &testLoggingSparseFile{ + SparseFile: &sf, + } + wr := &batchingWriter{ + writer: lsf, + maxSize: 100, + } + + reset := func(t *testing.T) { + _, err := wr.Seek(0, io.SeekStart) + if err != nil { + t.Fatal(err) + } + lsf.wrlog = lsf.wrlog[:0] + sf.data = nil + sf.offset = 0 + } + + t.Run("large_chunk", func(t *testing.T) { + buf := make([]byte, 502) + rand.Read(buf) + reset(t) + n, err := wr.Write(buf) + if err != nil { + t.Fatal(err) + } + if n != len(buf) { + t.Fatal(n) + } + err = wr.Flush() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(sf.Bytes(), buf) { + t.Fatal("not equal") + } + if !reflect.DeepEqual(lsf.wrlog, []testOplogItem{ + {offset: 0, length: 500}, + {offset: 500, length: 2}, + }) { + t.Fatalf("Oplog: %#v", lsf.wrlog) + } + }) + + t.Run("exact", func(t *testing.T) { + buf := make([]byte, 100) + rand.Read(buf) + reset(t) + n, err := wr.Write(buf) + if err != nil { + t.Fatal(err) + } + if n != len(buf) { + t.Fatal(n) + } + err = wr.Flush() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(sf.Bytes(), buf) { + t.Fatal("not equal") + } + if !reflect.DeepEqual(lsf.wrlog, []testOplogItem{ + {offset: 0, length: 100}, + }) { + t.Fatalf("Oplog: %#v", lsf.wrlog) + } + }) + + t.Run("two_small", func(t *testing.T) { + buf := make([]byte, 100) + rand.Read(buf) + reset(t) + n, err := wr.Write(buf[:50]) + if err != nil { + t.Fatal(err) + } + if n != 50 { + t.Fatal(n) + } + + n, err = wr.Write(buf[50:]) + if err != nil { + t.Fatal(err) + } + if n != 50 { + t.Fatal(n) + } + + err = wr.Flush() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(sf.Bytes(), buf) { + t.Fatal("not equal") + } + if !reflect.DeepEqual(lsf.wrlog, []testOplogItem{ + {offset: 0, length: 100}, + }) { + t.Fatalf("Oplog: %#v", lsf.wrlog) + } + }) + + t.Run("seek", func(t *testing.T) { + buf := make([]byte, 100) + rand.Read(buf) + reset(t) + n, err := wr.Write(buf[:50]) + if err != nil { + t.Fatal(err) + } + if n != 50 { + t.Fatal(n) + } + + o, err := wr.Seek(0, io.SeekCurrent) + if err != nil { + t.Fatal(err) + } + if o != 50 { + t.Fatal(o) + } + + o, err = wr.Seek(55, io.SeekStart) + if err != nil { + t.Fatal(err) + } + if o != 55 { + t.Fatal(o) + } + + n, err = wr.Write(buf[50:]) + if err != nil { + t.Fatal(err) + } + if n != 50 { + t.Fatal(n) + } + err = wr.Flush() + if err != nil { + t.Fatal(err) + } + + exp := make([]byte, 105) + copy(exp, buf[:50]) + copy(exp[55:], buf[50:]) + + if !bytes.Equal(sf.Bytes(), exp) { + t.Fatal("not equal") + } + if !reflect.DeepEqual(lsf.wrlog, []testOplogItem{ + {offset: 0, length: 50}, + {offset: 55, length: 50}, + }) { + t.Fatalf("Oplog: %#v", lsf.wrlog) + } + }) +} + +func TestFuzz(t *testing.T) { + const ( + fileSize = 30 * 1024 * 1024 + + numBlocks = 50 + numBlocksDelta = 128 + + blockSize = 64 * 1024 + blockSizeDelta = 32 * 1024 + ) + + if testing.Short() { + t.Skip() + } + seed := time.Now().UnixNano() + t.Logf("Seed: %d", seed) + rnd := rand.New(rand.NewSource(seed)) + roll := func(mean, delta int) int { + return mean + int(rnd.Int31n(int32(delta))) - delta/2 + } + + srcBuf := make([]byte, fileSize) + dstBuf := make([]byte, fileSize) + + blockBuf := make([]byte, 0, blockSize+blockSizeDelta/2) + zeroBlockBuf := make([]byte, 0, blockSize+blockSizeDelta/2) + + mutateBlock := func(buf []byte) { + size := roll(blockSize, blockSizeDelta) + offset := int(rnd.Int31n(int32(len(buf)))) + blk := blockBuf[:size] + typ := rnd.Int31n(16) + if typ >= 5 { + rnd.Read(blk) + } else if typ >= 3 { + for i := range blk { + blk[i] = 'x' + } + } else { + blk = zeroBlockBuf[:size] + } + copy(buf[offset:], blk) + } + + for i := 0; i < 50; i++ { + t.Logf("Running file %d", i) + dice := rnd.Int31n(16) + var srcSize, dstSize int + srcSize = int(rnd.Int31n(fileSize)) + if dice > 4 { + dstSize = int(rnd.Int31n(fileSize)) + } else { + dstSize = srcSize + } + srcBuf = srcBuf[:srcSize] + dstBuf = dstBuf[:dstSize] + rnd.Read(srcBuf) + nBlocks := roll(numBlocks, numBlocksDelta) + for i := 0; i < nBlocks; i++ { + mutateBlock(srcBuf) + } + + copy(dstBuf, srcBuf) + + nBlocks = roll(numBlocks, numBlocksDelta) + for i := 0; i < nBlocks; i++ { + mutateBlock(dstBuf) + } + + var mf memSparseFile + sf, err := spgz.NewFromSparseFile(&mf, os.O_RDWR|os.O_CREATE) + if err != nil { + t.Fatal(err) + } + _, err = sf.WriteAt(dstBuf, 0) + if err != nil { + t.Fatal(err) + } + sent, received := syncAndCheckEqual(srcBuf, dstBuf, t) + t.Logf("src size: %d, sent: %d, received: %d", len(srcBuf), sent, received) + sent1, received1 := syncAndCheckEqual1(&memSparseFile{data: srcBuf}, sf, t) + if sent != sent1 { + t.Fatalf("Sent counts did not match: %d, %d", sent, sent1) + } + if received != received1 { + t.Fatalf("Received counts did not match: %d, %d", received, received1) + } + } +} + func syncAndCheckEqual(src, dst []byte, t *testing.T) (sent, received int64) { + return syncAndCheckEqual1(&memSparseFile{data: src}, &memSparseFile{data: dst}, t) +} + +func getSize(s io.Seeker) (int64, error) { + o, err := s.Seek(0, io.SeekCurrent) + if err != nil { + return 0, err + } + size, err := s.Seek(0, io.SeekEnd) + if err != nil { + return 0, err + } + _, err = s.Seek(o, io.SeekStart) + return size, err +} + +func getBytes(r io.ReadSeeker) ([]byte, error) { + if b, ok := r.(interface { + Bytes() []byte + }); ok { + return b.Bytes(), nil + } + + _, err := r.Seek(0, io.SeekStart) + if err != nil { + return nil, err + } + + return io.ReadAll(r) +} + +func syncAndCheckEqual1(src io.ReadSeeker, dst spgz.SparseFile, t *testing.T) (sent, received int64) { srcReader, dstWriter := io.Pipe() dstReader, srcWriter := io.Pipe() - srcR := &memFile{data: src} - dstW := &memFile{data: dst} - dstReaderC := &CountingReader{Reader: dstReader} dstWriterC := &CountingWriteCloser{WriteCloser: dstWriter} srcErrChan := make(chan error, 1) + srcSize, err := getSize(src) + if err != nil { + t.Fatal(err) + } + dstSize, err := getSize(dst) + if err != nil { + t.Fatal(err) + } + go func() { - err := Source(srcR, int64(len(src)), srcReader, srcWriter, false, false) - srcWriter.Close() - if err != nil { - srcErrChan <- err - return + err := Source(src, srcSize, srcReader, srcWriter, false, false) + cerr := srcWriter.Close() + if err == nil { + err = cerr } - srcErrChan <- nil + srcErrChan <- err }() - err := Target(dstW, int64(len(dst)), dstReaderC, dstWriterC, false, false) - dstWriter.Close() + err = Target(dst, dstSize, dstReaderC, dstWriterC, false, false) + cerr := dstWriter.Close() + if err == nil { + err = cerr + } if err != nil { t.Fatal(err) @@ -228,9 +693,27 @@ func syncAndCheckEqual(src, dst []byte, t *testing.T) (sent, received int64) { t.Fatal(err) } - if !bytes.Equal(srcR.Bytes(), dstW.Bytes()) { - t.Fatal("Not equal") + srcBytes, err := getBytes(src) + if err != nil { + t.Fatal(err) + } + + dstBytes, err := getBytes(dst) + if err != nil { + t.Fatal(err) + } + + if len(srcBytes) != len(dstBytes) { + t.Fatalf("Len not equal: %d, %d", len(srcBytes), len(dstBytes)) } + for i := 0; i < len(srcBytes); i++ { + if srcBytes[i] != dstBytes[i] { + t.Fatalf("Data mismatch at %d: %d, %d", i, srcBytes[i], dstBytes[i]) + } + } + /*if !bytes.Equal(srcBytes, dstBytes) { + t.Fatal("Not equal") + }*/ return dstReaderC.Count(), dstWriterC.Count() }