diff --git a/diskrsync/main.go b/diskrsync/main.go index 11129cc..8e70114 100644 --- a/diskrsync/main.go +++ b/diskrsync/main.go @@ -1,6 +1,9 @@ package main import ( + "bufio" + "context" + "errors" "flag" "fmt" "io" @@ -12,6 +15,8 @@ import ( "github.com/dop251/diskrsync" "github.com/dop251/spgz" + + "github.com/vbauerster/mpb/v7" ) const ( @@ -26,7 +31,8 @@ type options struct { } type proc interface { - Start(cmdReader io.Reader, cmdWriter io.WriteCloser, errChan chan error) error + Start(cmdReader io.Reader, cmdWriter io.WriteCloser, errChan chan error, calcPl, syncPl diskrsync.ProgressListener) error + IsLocal() bool } type localProc struct { @@ -36,7 +42,16 @@ type localProc struct { } type remoteProc struct { - cmd *exec.Cmd + p string + mode int + opts *options + host string + cmd *exec.Cmd +} + +// used to prevent output to stderr while the progress bars are active +var bufStderr = bufferedOut{ + w: os.Stderr, } func usage() { @@ -58,7 +73,7 @@ func split(arg string) (host, path string) { return } -func createProc(arg string, mode int, opts *options) (proc, error) { +func createProc(arg string, mode int, opts *options) proc { host, path := split(arg) if host != "" { return createRemoteProc(host, path, mode, opts) @@ -66,58 +81,38 @@ func createProc(arg string, mode int, opts *options) (proc, error) { return createLocalProc(path, mode, opts) } -func createRemoteProc(host, path string, mode int, opts *options) (proc, error) { - var m string - if mode == modeSource { - m = "--source" - } else { - m = "--target" - if opts.noCompress { - m += " --no-compress" - } - } - if opts.verbose { - m += " --verbose" - } - - args := make([]string, 1, 8) - args[0] = "ssh" - - if opts.sshFlags != "" { - flags := strings.Split(opts.sshFlags, " ") - args = append(args, flags...) - } - - args = append(args, host, os.Args[0], m, path) - cmd := exec.Command("ssh") - cmd.Args = args - +func createRemoteProc(host, path string, mode int, opts *options) proc { return &remoteProc{ - cmd: cmd, - }, nil + host: host, + p: path, + mode: mode, + opts: opts, + } } -func createLocalProc(p string, mode int, opts *options) (proc, error) { - pr := &localProc{ +func createLocalProc(p string, mode int, opts *options) proc { + return &localProc{ p: p, mode: mode, opts: opts, } - - return pr, nil } -func (p *localProc) Start(cmdReader io.Reader, cmdWriter io.WriteCloser, errChan chan error) error { - go p.run(cmdReader, cmdWriter, errChan) +func (p *localProc) Start(cmdReader io.Reader, cmdWriter io.WriteCloser, errChan chan error, calcPl, syncPl diskrsync.ProgressListener) error { + go p.run(cmdReader, cmdWriter, errChan, calcPl, syncPl) return nil } -func (p *localProc) run(cmdReader io.Reader, cmdWriter io.WriteCloser, errChan chan error) { +func (p *localProc) IsLocal() bool { + return true +} + +func (p *localProc) run(cmdReader io.Reader, cmdWriter io.WriteCloser, errChan chan error, calcPl, syncPl diskrsync.ProgressListener) { var err error if p.mode == modeSource { - err = doSource(p.p, cmdReader, cmdWriter, p.opts) + err = doSource(p.p, cmdReader, cmdWriter, p.opts, calcPl, syncPl) } else { - err = doTarget(p.p, cmdReader, cmdWriter, p.opts) + err = doTarget(p.p, cmdReader, cmdWriter, p.opts, calcPl, syncPl) } cerr := cmdWriter.Close() @@ -138,16 +133,102 @@ func (p *remoteProc) pipeCopy(dst io.WriteCloser, src io.Reader) { } } -func (p *remoteProc) Start(cmdReader io.Reader, cmdWriter io.WriteCloser, errChan chan error) error { - p.cmd.Stderr = os.Stderr - p.cmd.Stdin = cmdReader +func (p *remoteProc) Start(cmdReader io.Reader, cmdWriter io.WriteCloser, errChan chan error, calcPl, syncPl diskrsync.ProgressListener) error { + cmd := exec.Command("ssh") + p.cmd = cmd + args := cmd.Args + + if p.opts.sshFlags != "" { + flags := strings.Split(p.opts.sshFlags, " ") + args = append(args, flags...) + } + + args = append(args, p.host, os.Args[0]) + + if p.mode == modeSource { + args = append(args, "--source") + } else { + args = append(args, "--target") + if p.opts.noCompress { + args = append(args, " --no-compress") + } + } + if p.opts.verbose && calcPl == nil { + args = append(args, " --verbose") + } + if calcPl == nil { + cmd.Stderr = os.Stderr + } else { + stderr, err := cmd.StderrPipe() + if err != nil { + return err + } + args = append(args, "--calc-progress") + if syncPl != nil { + args = append(args, "--sync-progress") + } + + go func() { + r := bufio.NewReader(stderr) + readStart := func() (string, error) { + for { + line, err := r.ReadString('\n') + if name := strings.TrimPrefix(line, "[Start "); name != line && len(name) > 1 { + return name[:len(name)-2], nil + } + if len(line) > 0 { + _, werr := bufStderr.Write([]byte(line)) + if werr != nil { + return "", werr + } + } + if err != nil { + return "", err + } + } + } + pr := &progressReader{ + r: r, + w: &bufStderr, + pl: calcPl, + } + name, err := readStart() + if err != nil { + return + } + if name == "calc" { + err := pr.read() + if err != nil { + return + } + if syncPl != nil { + name, err = readStart() + if err != nil { + return + } + } + } + if syncPl != nil && name == "sync" { + pr.pl = syncPl + err = pr.read() + if err != nil { + return + } + } + _, _ = io.Copy(os.Stderr, r) + }() + } + + args = append(args, p.p) + cmd.Args = args + cmd.Stdin = cmdReader - r, err := p.cmd.StdoutPipe() + r, err := cmd.StdoutPipe() if err != nil { return err } - err = p.cmd.Start() + err = cmd.Start() if err != nil { return err } @@ -155,12 +236,16 @@ func (p *remoteProc) Start(cmdReader io.Reader, cmdWriter io.WriteCloser, errCha return nil } +func (p *remoteProc) IsLocal() bool { + return false +} + func (p *remoteProc) run(w io.WriteCloser, r io.Reader, errChan chan error) { p.pipeCopy(w, r) errChan <- p.cmd.Wait() } -func doSource(p string, cmdReader io.Reader, cmdWriter io.WriteCloser, opts *options) error { +func doSource(p string, cmdReader io.Reader, cmdWriter io.WriteCloser, opts *options, calcPl, syncPl diskrsync.ProgressListener) error { f, err := os.Open(p) if err != nil { return err @@ -191,7 +276,7 @@ func doSource(p string, cmdReader io.Reader, cmdWriter io.WriteCloser, opts *opt return err } - err = diskrsync.Source(src, size, cmdReader, cmdWriter, true, opts.verbose) + err = diskrsync.Source(src, size, cmdReader, cmdWriter, true, opts.verbose, calcPl, syncPl) cerr := cmdWriter.Close() if err == nil { err = cerr @@ -199,7 +284,7 @@ func doSource(p string, cmdReader io.Reader, cmdWriter io.WriteCloser, opts *opt return err } -func doTarget(p string, cmdReader io.Reader, cmdWriter io.WriteCloser, opts *options) (err error) { +func doTarget(p string, cmdReader io.Reader, cmdWriter io.WriteCloser, opts *options, calcPl, syncPl diskrsync.ProgressListener) (err error) { var w spgz.SparseFile useReadBuffer := false @@ -255,7 +340,7 @@ func doTarget(p string, cmdReader io.Reader, cmdWriter io.WriteCloser, opts *opt return err } - err = diskrsync.Target(w, size, cmdReader, cmdWriter, useReadBuffer, opts.verbose) + err = diskrsync.Target(w, size, cmdReader, cmdWriter, useReadBuffer, opts.verbose, calcPl, syncPl) cerr := cmdWriter.Close() if err == nil { err = cerr @@ -265,20 +350,14 @@ func doTarget(p string, cmdReader io.Reader, cmdWriter io.WriteCloser, opts *opt } func doCmd(opts *options) (err error) { - src, err := createProc(flag.Arg(0), modeSource, opts) - if err != nil { - return fmt.Errorf("could not create source: %w", err) - } + src := createProc(flag.Arg(0), modeSource, opts) path := flag.Arg(1) if _, p := split(path); strings.HasSuffix(p, "/") { path += filepath.Base(flag.Arg(0)) } - dst, err := createProc(path, modeTarget, opts) - if err != nil { - return fmt.Errorf("could not create target: %w", err) - } + dst := createProc(path, modeTarget, opts) srcErrChan := make(chan error, 1) dstErrChan := make(chan error, 1) @@ -289,17 +368,44 @@ func doCmd(opts *options) (err error) { sr := &diskrsync.CountingReader{Reader: srcReader} sw := &diskrsync.CountingWriteCloser{WriteCloser: srcWriter} + var ( + p *mpb.Progress + cancel func() + srcCalcPl, syncPl diskrsync.ProgressListener + ) + if opts.verbose { - err = src.Start(sr, sw, srcErrChan) - } else { - err = src.Start(srcReader, srcWriter, srcErrChan) + var ctx context.Context + ctx, cancel = context.WithCancel(context.Background()) + defer func() { + if cancel != nil { + cancel() + } + bufStderr.Release() + }() + p = mpb.NewWithContext(ctx) + if src.IsLocal() && !dst.IsLocal() { + syncPl = newSyncProgressBarListener(p) + } + srcCalcPl = newCalcProgressBarListener(p, "Source Checksums") + log.SetOutput(&bufStderr) } + err = src.Start(sr, sw, srcErrChan, srcCalcPl, syncPl) if err != nil { return fmt.Errorf("could not start source: %w", err) } - err = dst.Start(dstReader, dstWriter, dstErrChan) + var dstCalcPl, dstSyncPl diskrsync.ProgressListener + + if opts.verbose { + dstCalcPl = newCalcProgressBarListener(p, "Target Checksums") + if syncPl == nil { + dstSyncPl = newSyncProgressBarListener(p) + } + } + err = dst.Start(dstReader, dstWriter, dstErrChan, dstCalcPl, dstSyncPl) + if err != nil { return fmt.Errorf("could not start target: %w", err) } @@ -309,19 +415,31 @@ L: select { case dstErr := <-dstErrChan: if dstErr != nil { - err = fmt.Errorf("target error: %w", dstErr) - break L + if !errors.Is(dstErr, io.EOF) { + err = fmt.Errorf("target error: %w", dstErr) + break L + } } dstErrChan = nil case srcErr := <-srcErrChan: if srcErr != nil { - err = fmt.Errorf("source error: %w", srcErr) - break L + if !errors.Is(srcErr, io.EOF) { + err = fmt.Errorf("source error: %w", srcErr) + break L + } } srcErrChan = nil } } + if cancel != nil { + if err == nil { + p.Wait() + } + cancel() + cancel = nil + } + if opts.verbose { log.Printf("Read: %d, wrote: %d\n", sr.Count(), sw.Count()) } @@ -329,32 +447,47 @@ L: } func main() { + // These flags are for the remote command mode, not to be used directly. var sourceMode = flag.Bool("source", false, "Source mode") var targetMode = flag.Bool("target", false, "Target mode") + var calcProgress = flag.Bool("calc-progress", false, "Write calc progress") + var syncProgress = flag.Bool("sync-progress", false, "Write sync progress") var opts options flag.StringVar(&opts.sshFlags, "ssh-flags", "", "SSH flags") flag.BoolVar(&opts.noCompress, "no-compress", false, "Store target as a raw file") - flag.BoolVar(&opts.verbose, "verbose", false, "Print statistics and some debug info") + flag.BoolVar(&opts.verbose, "verbose", false, "Print statistics, progress, and some debug info") flag.Parse() - if *sourceMode { - if opts.verbose { - log.Println("Running source") - } - err := doSource(flag.Arg(0), os.Stdin, os.Stdout, &opts) - if err != nil { - log.Fatalf("Source failed: %s", err.Error()) + if *sourceMode || *targetMode { + var calcPl, syncPl diskrsync.ProgressListener + + if *calcProgress { + calcPl = &progressWriter{ + name: "calc", + w: os.Stderr, + } } - } else if *targetMode { - if opts.verbose { - log.Println("Running target") + + if *syncProgress { + syncPl = &progressWriter{ + name: "sync", + w: os.Stderr, + } } - err := doTarget(flag.Arg(0), os.Stdin, os.Stdout, &opts) - if err != nil { - log.Fatalf("Target failed: %s", err.Error()) + + if *sourceMode { + err := doSource(flag.Arg(0), os.Stdin, os.Stdout, &opts, calcPl, syncPl) + if err != nil { + log.Fatalf("Source failed: %s", err.Error()) + } + } else { + err := doTarget(flag.Arg(0), os.Stdin, os.Stdout, &opts, calcPl, syncPl) + if err != nil { + log.Fatalf("Target failed: %s", err.Error()) + } } } else { if flag.Arg(0) == "" || flag.Arg(1) == "" { diff --git a/diskrsync/progress.go b/diskrsync/progress.go new file mode 100644 index 0000000..187703c --- /dev/null +++ b/diskrsync/progress.go @@ -0,0 +1,202 @@ +package main + +import ( + "bufio" + "bytes" + "fmt" + "io" + "strconv" + "strings" + "sync" + "time" + + "github.com/dop251/diskrsync" + "github.com/vbauerster/mpb/v7" + "github.com/vbauerster/mpb/v7/decor" +) + +type bufferedOut struct { + sync.Mutex + buf bytes.Buffer + w io.Writer + + blocked bool +} + +func (b *bufferedOut) Write(p []byte) (int, error) { + b.Lock() + defer b.Unlock() + if !b.blocked { + return b.w.Write(p) + } + return b.buf.Write(p) +} + +func (b *bufferedOut) Block() { + b.Lock() + b.blocked = true + b.Unlock() +} + +func (b *bufferedOut) Release() { + b.Lock() + defer b.Unlock() + if b.blocked { + _, _ = b.w.Write(b.buf.Bytes()) + b.buf.Reset() + b.blocked = false + } +} + +type calcProgressBar struct { + p *mpb.Progress + bar *mpb.Bar + name string + lastUpdate time.Time +} + +func newCalcProgressBarListener(p *mpb.Progress, name string) *calcProgressBar { + return &calcProgressBar{ + p: p, + name: name, + } +} + +func (pb *calcProgressBar) Start(size int64) { + bufStderr.Block() + pb.bar = pb.p.New(size, + mpb.BarStyle().Rbound("|").Padding(" "), + mpb.BarPriority(1), + mpb.PrependDecorators( + decor.Name(pb.name, decor.WC{C: decor.DidentRight | decor.DextraSpace | decor.DSyncWidth}), + decor.CountersKibiByte("% .2f / % .2f", decor.WCSyncSpace), + ), + mpb.AppendDecorators( + decor.OnComplete( + decor.EwmaETA(decor.ET_STYLE_GO, 60, decor.WC{W: 4}), "done", + ), + decor.Name(" ] "), + decor.EwmaSpeed(decor.UnitKiB, "% .2f", 60, decor.WCSyncSpace), + ), + ) + pb.lastUpdate = time.Now() +} + +func (pb *calcProgressBar) Update(pos int64) { + pb.bar.SetCurrent(pos) + now := time.Now() + pb.bar.DecoratorEwmaUpdate(now.Sub(pb.lastUpdate)) + pb.lastUpdate = now +} + +type syncProgressBar struct { + p *mpb.Progress + bar *mpb.Bar +} + +func newSyncProgressBarListener(p *mpb.Progress) *syncProgressBar { + return &syncProgressBar{ + p: p, + } +} + +func (pb *syncProgressBar) Start(size int64) { + const name = "Sync" + bufStderr.Block() + pb.bar = pb.p.New(size, + mpb.BarStyle().Padding(" "), + mpb.BarPriority(2), + mpb.PrependDecorators( + decor.Name(name, decor.WC{C: decor.DidentRight | decor.DextraSpace | decor.DSyncWidth}), + decor.CountersKibiByte("% .2f / % .2f", decor.WCSyncSpace), + ), + mpb.AppendDecorators( + decor.Percentage(decor.WCSyncSpace), + ), + ) +} + +func (pb *syncProgressBar) Update(pos int64) { + pb.bar.SetCurrent(pos) +} + +type progressWriter struct { + name string + w io.Writer + + size int64 + lastWrite time.Time +} + +func (pw *progressWriter) Start(size int64) { + pw.size = size + pw.lastWrite = time.Now() + fmt.Fprintf(pw.w, "[Start %s]\nSize: %d\n", pw.name, size) +} + +func (pw *progressWriter) Update(pos int64) { + if pw.size <= 0 { + return + } + now := time.Now() + if pos >= pw.size || now.Sub(pw.lastWrite) >= 250*time.Millisecond { + fmt.Fprintf(pw.w, "Update: %d\n", pos) + pw.lastWrite = now + } +} + +type progressReader struct { + r *bufio.Reader + w io.Writer + pl diskrsync.ProgressListener +} + +func (pr *progressReader) read() error { + var size int64 + for { + str, err := pr.r.ReadString('\n') + if s := strings.TrimPrefix(str, "Size: "); s != str && len(s) > 0 { + sz, err := strconv.ParseInt(s[:len(s)-1], 10, 64) + if err == nil { + pr.pl.Start(sz) + size = sz + break + } + } + if len(str) > 0 { + _, werr := pr.w.Write([]byte(str)) + if werr != nil { + return werr + } + } + if err != nil { + return err + } + } + if size <= 0 { + return nil + } + for { + str, err := pr.r.ReadString('\n') + if s := strings.TrimPrefix(str, "Update: "); s != str && len(s) > 0 { + pos, err := strconv.ParseInt(s[:len(s)-1], 10, 64) + if err == nil { + pr.pl.Update(pos) + if pos >= size { + break + } + continue + } + } + if len(str) > 0 { + _, werr := pr.w.Write([]byte(str)) + if werr != nil { + return werr + } + } + if err != nil { + return err + } + } + return nil +} diff --git a/go.mod b/go.mod index 51859f4..b9dab7f 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,7 @@ go 1.16 require ( github.com/dop251/spgz v1.2.0 - golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 + github.com/vbauerster/mpb/v7 v7.4.1 + golang.org/x/crypto v0.0.0-20220312131142-6068a2e6cfdc + golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 // indirect ) diff --git a/go.sum b/go.sum index 8472d2d..7841c48 100644 --- a/go.sum +++ b/go.sum @@ -1,19 +1,32 @@ +github.com/VividCortex/ewma v1.2.0 h1:f58SaIzcDXrSy3kWaHNvuJgJ3Nmz59Zji6XoJR/q1ow= +github.com/VividCortex/ewma v1.2.0/go.mod h1:nz4BbCtbLyFDeC9SUHbtcT5644juEuWfUAUnGx7j5l4= +github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d h1:licZJFw2RwpHMqeKTCYkitsPqHNxTmd4SNR5r94FGM8= +github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d/go.mod h1:asat636LX7Bqt5lYEZ27JNDcqxfjdBQuJ/MM4CN/Lzo= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dop251/buse v1.1.0/go.mod h1:MGzyYwutwcAUZa2KlHLEhUaYBr6JpZf4sqyh1v1lEZs= github.com/dop251/nbd v0.0.0-20170916130042-b8933b281cb7/go.mod h1:/YqO/I24sucjxhCgQHgDrnffSwg5HzoYHQASayZnYl8= github.com/dop251/spgz v1.2.0 h1:/VXInlcNmrhdehE228zLnTK9jTdpnNxtxG/t6XlFn14= github.com/dop251/spgz v1.2.0/go.mod h1:TvZEdiTP+5fkWTBiO9Po3zlegP9MXzwVKw9O97IJijQ= +github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4OSgU= +github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +github.com/vbauerster/mpb/v7 v7.4.1 h1:NhLMWQ3gNg2KJR8oeA9lO8Xvq+eNPmixDmB6JEQOUdA= +github.com/vbauerster/mpb/v7 v7.4.1/go.mod h1:Ygg2mV9Vj9sQBWqsK2m2pidcf9H3s6bNKtqd3/M4gBo= +golang.org/x/crypto v0.0.0-20220312131142-6068a2e6cfdc h1:i6Z9eOQAdM7lvsbkT3fwFNtSAAC+A59TYilFj53HW+E= +golang.org/x/crypto v0.0.0-20220312131142-6068a2e6cfdc/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 h1:y/woIyUBFbpQGKS0u1aHF/40WUDnek3fPOyD08H5Vng= +golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/sync.go b/sync.go index aed770e..8afd3fc 100644 --- a/sync.go +++ b/sync.go @@ -12,7 +12,6 @@ import ( "math" "github.com/dop251/spgz" - "golang.org/x/crypto/blake2b" ) @@ -37,6 +36,11 @@ var ( ErrInvalidFormat = errors.New("invalid data format") ) +type ProgressListener interface { + Start(size int64) + Update(position int64) +} + type hashPool []hash.Hash type workCtx struct { @@ -72,6 +76,8 @@ type base struct { buf []byte cmdReader io.Reader cmdWriter io.Writer + + syncProgressListener ProgressListener } type source struct { @@ -354,7 +360,7 @@ func (t *tree) first(n *node) *node { return n } -func (t *tree) calc(verbose bool) error { +func (t *tree) calc(verbose bool, progressListener ProgressListener) error { var targetBlockSize int64 = DefTargetBlockSize @@ -451,6 +457,10 @@ func (t *tree) calc(verbose bool) error { workIdx := 0 + if progressListener != nil { + progressListener.Start(t.size) + } + for n := t.first(t.root); n != nil; n = n.next() { if n.size == 0 { panic("Leaf node size is zero") @@ -466,6 +476,9 @@ func (t *tree) calc(verbose bool) error { return fmt.Errorf("in calc at %d (expected %d, read %d): %w", rr, len(b), r, err) } rr += int64(r) + if progressListener != nil { + progressListener.Update(rr) + } wi.n = n @@ -523,7 +536,7 @@ func writeHeader(writer io.Writer, size int64) (err error) { return } -func Source(reader io.ReadSeeker, size int64, cmdReader io.Reader, cmdWriter io.Writer, useBuffer bool, verbose bool) (err error) { +func Source(reader io.ReadSeeker, size int64, cmdReader io.Reader, cmdWriter io.Writer, useBuffer bool, verbose bool, calcPl, syncPl ProgressListener) (err error) { err = writeHeader(cmdWriter, size) if err != nil { return @@ -556,22 +569,28 @@ func Source(reader io.ReadSeeker, size int64, cmdReader io.Reader, cmdWriter io. reader: reader, } - err = s.t.calc(verbose) + err = s.t.calc(verbose, calcPl) if err != nil { return } + if syncPl != nil { + s.syncProgressListener = syncPl + syncPl.Start(size) + } + err = s.subtree(s.t.root, 0, commonSize) if err != nil { return } + } else { + if syncPl != nil { + syncPl.Start(size) + } } if size > commonSize { // Write the tail - if verbose { - log.Print("Writing tail...") - } _, err = reader.Seek(commonSize, io.SeekStart) if err != nil { return @@ -629,6 +648,9 @@ func Source(reader io.ReadSeeker, size int64, cmdReader io.Reader, cmdWriter io. return } curPos += int64(r) + if syncPl != nil { + syncPl.Update(curPos) + } if stop { break } @@ -660,6 +682,9 @@ func (s *source) subtree(root *node, offset, size int64) (err error) { if bytes.Equal(root.sum, remoteHash) { err = binary.Write(s.cmdWriter, binary.LittleEndian, cmdEqual) + if s.syncProgressListener != nil { + s.syncProgressListener.Update(offset + size) + } return } @@ -691,6 +716,9 @@ func (s *source) subtree(root *node, offset, size int64) (err error) { _, err = s.cmdWriter.Write(buf) } + if s.syncProgressListener != nil { + s.syncProgressListener.Update(offset + size) + } } else { err = binary.Write(s.cmdWriter, binary.LittleEndian, cmdNotEqual) if err != nil { @@ -760,7 +788,7 @@ func (rw *FixingSpgzFileWrapper) Seek(offset int64, whence int) (int64, error) { return o, err } -func Target(writer spgz.SparseFile, size int64, cmdReader io.Reader, cmdWriter io.Writer, useReadBuffer bool, verbose bool) (err error) { +func Target(writer spgz.SparseFile, size int64, cmdReader io.Reader, cmdWriter io.Writer, useReadBuffer bool, verbose bool, calcPl, syncPl ProgressListener) (err error) { ch := make(chan error) go func() { @@ -783,10 +811,6 @@ func Target(writer spgz.SparseFile, size int64, cmdReader io.Reader, cmdWriter i commonSize = remoteSize } - if verbose { - log.Printf("Local size: %d, remote size: %d", size, remoteSize) - } - if commonSize > 0 { t := target{ base: base{ @@ -800,11 +824,16 @@ func Target(writer spgz.SparseFile, size int64, cmdReader io.Reader, cmdWriter i }, writer: &batchingWriter{writer: writer, maxSize: DefTargetBlockSize * 16}, } - err = t.t.calc(verbose) + err = t.t.calc(verbose, calcPl) if err != nil { return } + if syncPl != nil { + t.syncProgressListener = syncPl + syncPl.Start(remoteSize) + } + err = t.subtree(t.t.root, 0, commonSize) if err != nil { return @@ -813,14 +842,19 @@ func Target(writer spgz.SparseFile, size int64, cmdReader io.Reader, cmdWriter i if err != nil { return } + if syncPl != nil { + syncPl.Update(commonSize) + } + } else { + if syncPl != nil { + syncPl.Start(remoteSize) + } } if size < remoteSize { // Read the tail - if verbose { - log.Printf("Reading tail (%d bytes)...", remoteSize-size) - } - _, err = writer.Seek(commonSize, io.SeekStart) + pos := commonSize + _, err = writer.Seek(pos, io.SeekStart) if err != nil { return } @@ -840,12 +874,17 @@ func Target(writer spgz.SparseFile, size int64, cmdReader io.Reader, cmdWriter i } if cmd == cmdBlock { - _, err = io.CopyN(writer, rd, DefTargetBlockSize) + var n int64 + n, err = io.CopyN(writer, rd, DefTargetBlockSize) + pos += n hole = false if err != nil { if err == io.EOF { err = nil + if syncPl != nil { + syncPl.Update(pos) + } break } else { return fmt.Errorf("target: while copying block: %w", err) @@ -863,10 +902,14 @@ func Target(writer spgz.SparseFile, size int64, cmdReader io.Reader, cmdWriter i return } hole = true + pos += holeSize } else { return fmt.Errorf("unexpected cmd: %d", cmd) } } + if syncPl != nil { + syncPl.Update(pos) + } } if hole { diff --git a/sync_test.go b/sync_test.go index 8df5b5e..8a8b8e9 100644 --- a/sync_test.go +++ b/sync_test.go @@ -671,7 +671,7 @@ func syncAndCheckEqual1(src io.ReadSeeker, dst spgz.SparseFile, t *testing.T) (s } go func() { - err := Source(src, srcSize, srcReader, srcWriter, false, false) + err := Source(src, srcSize, srcReader, srcWriter, false, false, nil, nil) cerr := srcWriter.Close() if err == nil { err = cerr @@ -679,7 +679,7 @@ func syncAndCheckEqual1(src io.ReadSeeker, dst spgz.SparseFile, t *testing.T) (s srcErrChan <- err }() - err = Target(dst, dstSize, dstReaderC, dstWriterC, false, false) + err = Target(dst, dstSize, dstReaderC, dstWriterC, false, false, nil, nil) cerr := dstWriter.Close() if err == nil { err = cerr