Skip to content

Commit

Permalink
Consolidated the progress-bar's format into common/step_download.go. …
Browse files Browse the repository at this point in the history
…Removed DownloadClient's PercentProgress callback since cheggaaa's progress-bar already does that.
  • Loading branch information
arizvisa committed Dec 21, 2017
1 parent 0b02a49 commit 745d62f
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 66 deletions.
124 changes: 68 additions & 56 deletions common/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,19 @@ func HashForType(t string) hash.Hash {

// NewDownloadClient returns a new DownloadClient for the given
// configuration.
func NewDownloadClient(c *DownloadConfig) *DownloadClient {
func NewDownloadClient(c *DownloadConfig, bar *pb.ProgressBar) *DownloadClient {
const mtu = 1500 /* ethernet */ - 20 /* ipv4 */ - 20 /* tcp */

if c.DownloaderMap == nil {
// Create downloader map
c.DownloaderMap = map[string]Downloader{
"file": &FileDownloader{bufferSize: nil},
"ftp": &FTPDownloader{userInfo: url.UserPassword("anonymous", "anonymous@"), mtu: mtu},
"http": &HTTPDownloader{userAgent: c.UserAgent},
"https": &HTTPDownloader{userAgent: c.UserAgent},
"smb": &SMBDownloader{bufferSize: nil},
"file": &FileDownloader{progress: bar, bufferSize: nil},
"ftp": &FTPDownloader{progress: bar, userInfo: url.UserPassword("anonymous", "anonymous@"), mtu: mtu},
"http": &HTTPDownloader{progress: bar, userAgent: c.UserAgent},
"https": &HTTPDownloader{progress: bar, userAgent: c.UserAgent},
"smb": &SMBDownloader{progress: bar, bufferSize: nil},
}
}

return &DownloadClient{config: c}
}

Expand Down Expand Up @@ -209,14 +209,6 @@ func (d *DownloadClient) Get() (string, error) {
return finalPath, err
}

// PercentProgress returns the download progress as a percentage.
func (d *DownloadClient) PercentProgress() int {
if d.downloader == nil {
return -1
}
return int((float64(d.downloader.Progress()) / float64(d.downloader.Total())) * 100)
}

// VerifyChecksum tests that the path matches the checksum for the
// download.
func (d *DownloadClient) VerifyChecksum(path string) (bool, error) {
Expand All @@ -239,9 +231,11 @@ func (d *DownloadClient) VerifyChecksum(path string) (bool, error) {
// HTTPDownloader is an implementation of Downloader that downloads
// files over HTTP.
type HTTPDownloader struct {
progress uint64
current uint64
total uint64
userAgent string

progress *pb.ProgressBar
}

func (d *HTTPDownloader) Cancel() {
Expand All @@ -261,7 +255,7 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
}

// Reset our progress
d.progress = 0
d.current = 0

// Make the request. We first make a HEAD request so we can check
// if the server supports range queries. If the server/URL doesn't
Expand Down Expand Up @@ -290,7 +284,7 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
if _, err = dst.Seek(0, os.SEEK_END); err == nil {
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size()))

d.progress = uint64(fi.Size())
d.current = uint64(fi.Size())
}
}
}
Expand All @@ -304,9 +298,11 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
return err
}

d.total = d.progress + uint64(resp.ContentLength)
progressBar := pb.New64(int64(d.Total())).Start()
progressBar.Set64(int64(d.Progress()))
d.total = d.current + uint64(resp.ContentLength)

d.progress.Total = int64(d.total)
progressBar := d.progress.Start()
progressBar.Set64(int64(d.current))

var buffer [4096]byte
for {
Expand All @@ -315,8 +311,8 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
return err
}

d.progress += uint64(n)
progressBar.Set64(int64(d.Progress()))
d.current += uint64(n)
progressBar.Set64(int64(d.current))

if _, werr := dst.Write(buffer[:n]); werr != nil {
return werr
Expand All @@ -326,12 +322,12 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error {
break
}
}

progressBar.Finish()
return nil
}

func (d *HTTPDownloader) Progress() uint64 {
return d.progress
return d.current
}

func (d *HTTPDownloader) Total() uint64 {
Expand All @@ -344,13 +340,15 @@ type FTPDownloader struct {
userInfo *url.Userinfo
mtu uint

active bool
progress uint64
total uint64
active bool
current uint64
total uint64

progress *pb.ProgressBar
}

func (d *FTPDownloader) Progress() uint64 {
return d.progress
return d.current
}

func (d *FTPDownloader) Total() uint64 {
Expand Down Expand Up @@ -438,13 +436,15 @@ func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error {
}
log.Printf("Found file : %s : %v bytes\n", entry.Name, entry.Size)

d.progress = 0
d.current = 0
d.total = entry.Size
progressBar := pb.New64(int64(d.Total())).Start()

d.progress.Total = int64(d.total)
progressBar := d.progress.Start()

// download specified file
d.active = true
reader, err := cli.RetrFrom(uri.Path, d.progress)
reader, err := cli.RetrFrom(uri.Path, d.current)
if err != nil {
return nil
}
Expand All @@ -458,19 +458,21 @@ func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error {
break
}

d.progress += uint64(n)
progressBar.Set64(int64(d.Progress()))
d.current += uint64(n)
progressBar.Set64(int64(d.current))
}
d.active = false
e <- err
}(d, reader, dst, errch)

// spin until it's done
err = <-errch

progressBar.Finish()
reader.Close()

if err == nil && d.progress != d.total {
err = fmt.Errorf("FTP total transfer size was %d when %d was expected", d.progress, d.total)
if err == nil && d.current != d.total {
err = fmt.Errorf("FTP total transfer size was %d when %d was expected", d.current, d.total)
}

// log out
Expand All @@ -483,13 +485,15 @@ func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error {
type FileDownloader struct {
bufferSize *uint

active bool
progress uint64
total uint64
active bool
current uint64
total uint64

progress *pb.ProgressBar
}

func (d *FileDownloader) Progress() uint64 {
return d.progress
return d.current
}

func (d *FileDownloader) Total() uint64 {
Expand Down Expand Up @@ -549,7 +553,7 @@ func (d *FileDownloader) Download(dst *os.File, src *url.URL) error {
}

/* download the file using the operating system's facilities */
d.progress = 0
d.current = 0
d.active = true

f, err := os.Open(realpath)
Expand All @@ -564,16 +568,18 @@ func (d *FileDownloader) Download(dst *os.File, src *url.URL) error {
return err
}
d.total = uint64(fi.Size())
progressBar := pb.New64(int64(d.Total())).Start()

d.progress.Total = int64(d.total)
progressBar := d.progress.Start()

// no bufferSize specified, so copy synchronously.
if d.bufferSize == nil {
var n int64
n, err = io.Copy(dst, f)
d.active = false

d.progress += uint64(n)
progressBar.Set64(int64(d.Progress()))
d.current += uint64(n)
progressBar.Set64(int64(d.current))

// use a goro in case someone else wants to enable cancel/resume
} else {
Expand All @@ -585,8 +591,8 @@ func (d *FileDownloader) Download(dst *os.File, src *url.URL) error {
break
}

d.progress += uint64(n)
progressBar.Set64(int64(d.Progress()))
d.current += uint64(n)
progressBar.Set64(int64(d.current))
}
d.active = false
e <- err
Expand All @@ -595,6 +601,7 @@ func (d *FileDownloader) Download(dst *os.File, src *url.URL) error {
// ...and we spin until it's done
err = <-errch
}
progressBar.Finish()
f.Close()
return err
}
Expand All @@ -604,13 +611,15 @@ func (d *FileDownloader) Download(dst *os.File, src *url.URL) error {
type SMBDownloader struct {
bufferSize *uint

active bool
progress uint64
total uint64
active bool
current uint64
total uint64

progress *pb.ProgressBar
}

func (d *SMBDownloader) Progress() uint64 {
return d.progress
return d.current
}

func (d *SMBDownloader) Total() uint64 {
Expand Down Expand Up @@ -663,7 +672,7 @@ func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error {
}

/* Open up the "\\"-prefixed path using the Windows filesystem */
d.progress = 0
d.current = 0
d.active = true

f, err := os.Open(realpath)
Expand All @@ -678,16 +687,18 @@ func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error {
return err
}
d.total = uint64(fi.Size())
progressBar := pb.New64(int64(d.Total())).Start()

d.progress.Total = int64(d.total)
progressBar := d.progress.Start()

// no bufferSize specified, so copy synchronously.
if d.bufferSize == nil {
var n int64
n, err = io.Copy(dst, f)
d.active = false

d.progress += uint64(n)
progressBar.Set64(int64(d.Progress()))
d.current += uint64(n)
progressBar.Set64(int64(d.current))

// use a goro in case someone else wants to enable cancel/resume
} else {
Expand All @@ -699,8 +710,8 @@ func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error {
break
}

d.progress += uint64(n)
progressBar.Set64(int64(d.Progress()))
d.current += uint64(n)
progressBar.Set64(int64(d.current))
}
d.active = false
e <- err
Expand All @@ -709,6 +720,7 @@ func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error {
// ...and as usual we spin until it's done
err = <-errch
}
progressBar.Finish()
f.Close()
return err
}
33 changes: 23 additions & 10 deletions common/step_download.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (

"github.com/hashicorp/packer/packer"
"github.com/mitchellh/multistep"

"gopkg.in/cheggaaa/pb.v1"
)

// StepDownload downloads a remote file using the download client within
Expand Down Expand Up @@ -139,7 +141,23 @@ func (s *StepDownload) Cleanup(multistep.StateBag) {}
func (s *StepDownload) download(config *DownloadConfig, state multistep.StateBag) (string, error, bool) {
var path string
ui := state.Get("ui").(packer.Ui)
download := NewDownloadClient(config)

// design the appearance of the progress bar
bar := pb.New64(0)
bar.ShowPercent = true
bar.ShowCounters = true
bar.ShowSpeed = false
bar.ShowBar = true
bar.ShowTimeLeft = false
bar.ShowFinalTime = false
bar.SetUnits(pb.U_BYTES)
bar.Format("[=>-]")
bar.SetRefreshRate(1 * time.Second)
bar.SetWidth(25)
bar.Callback = ui.Message

// create download client with config and progress bar
download := NewDownloadClient(config, bar)

downloadCompleteCh := make(chan error, 1)
go func() {
Expand All @@ -148,24 +166,19 @@ func (s *StepDownload) download(config *DownloadConfig, state multistep.StateBag
downloadCompleteCh <- err
}()

progressTicker := time.NewTicker(5 * time.Second)
defer progressTicker.Stop()

for {
select {
case err := <-downloadCompleteCh:
bar.Finish()

if err != nil {
return "", err, true
}

return path, nil, true
case <-progressTicker.C:
progress := download.PercentProgress()
if progress >= 0 {
ui.Message(fmt.Sprintf("Download progress: %d%%", progress))
}

case <-time.After(1 * time.Second):
if _, ok := state.GetOk(multistep.StateCancelled); ok {
bar.Finish()
ui.Say("Interrupt received. Cancelling download...")
return "", nil, false
}
Expand Down

0 comments on commit 745d62f

Please sign in to comment.