Skip to content

Commit

Permalink
[WIP] bundle: Parallel download and decompression
Browse files Browse the repository at this point in the history
This commit does the following:
- Return a reader from the bundle Download function.
- Use the reader to stream the bytes to Extract function.

This commit replaces the grab client with the net/http client to ensure
that the streamed bytes arrive in the correct order at the Extract func.
Currently, only zst decompression is being used in the
UncompressWithReader function as it is the primary compression algorithm
being used in crc.
  • Loading branch information
vyasgun committed Jan 10, 2025
1 parent 8a1d173 commit bb33dee
Show file tree
Hide file tree
Showing 10 changed files with 112 additions and 79 deletions.
2 changes: 1 addition & 1 deletion cmd/crc-embedder/cmd/embed.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func downloadDataFiles(goos string, components []string, destDir string) ([]stri
if !shouldDownload(components, componentName) {
continue
}
filename, err := download.Download(context.TODO(), dl.url, destDir, dl.permissions, nil)
_, filename, err := download.Download(context.TODO(), dl.url, destDir, dl.permissions, nil)
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/crc/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ func (c *Cache) getExecutable(destDir string) (string, error) {
destPath := filepath.Join(destDir, archiveName)
err := embed.Extract(archiveName, destPath)
if err != nil {
return download.Download(context.TODO(), c.archiveURL, destDir, 0600, nil)
_, filename, err := download.Download(context.TODO(), c.archiveURL, destDir, 0600, nil)
return filename, err
}

return destPath, err
Expand Down
1 change: 1 addition & 0 deletions pkg/crc/image/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ func (img *imageHandler) copyImage(ctx context.Context, destPath string, reportW
if ctx == nil {
panic("ctx is nil, this should not happen")
}

manifestData, err := copy.Image(ctx, policyContext,
destRef, srcRef, &copy.Options{
ReportWriter: reportWriter,
Expand Down
24 changes: 15 additions & 9 deletions pkg/crc/machine/bundle/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,43 +344,49 @@ func getVerifiedHash(url string, file string) (string, error) {
return "", fmt.Errorf("%s hash is missing or shasums are malformed", file)
}

func downloadDefault(ctx context.Context, preset crcPreset.Preset) (string, error) {
func downloadDefault(ctx context.Context, preset crcPreset.Preset) (io.Reader, string, error) {
downloadInfo, err := getBundleDownloadInfo(preset)
if err != nil {
return "", err
return nil, "", err
}
return downloadInfo.Download(ctx, constants.GetDefaultBundlePath(preset), 0664)
}

func Download(ctx context.Context, preset crcPreset.Preset, bundleURI string, enableBundleQuayFallback bool) (string, error) {
func Download(ctx context.Context, preset crcPreset.Preset, bundleURI string, enableBundleQuayFallback bool) (io.Reader, string, error) {
// If we are asked to download
// ~/.crc/cache/crc_podman_libvirt_4.1.1.crcbundle, this means we
// are downloading the default bundle for this release. This uses a
// different codepath from user-specified URIs as for the default
// bundles, their sha256sums are known and can be checked.
var reader io.Reader
if bundleURI == constants.GetDefaultBundlePath(preset) {
switch preset {
case crcPreset.OpenShift, crcPreset.Microshift:
downloadedBundlePath, err := downloadDefault(ctx, preset)
var err error
var downloadedBundlePath string
reader, downloadedBundlePath, err = downloadDefault(ctx, preset)
if err != nil && enableBundleQuayFallback {
logging.Info("Unable to download bundle from mirror, falling back to quay")
return image.PullBundle(ctx, constants.GetDefaultBundleImageRegistry(preset))
bundle, err := image.PullBundle(ctx, constants.GetDefaultBundleImageRegistry(preset))
return nil, bundle, err
}
return downloadedBundlePath, err
return reader, downloadedBundlePath, err
case crcPreset.OKD:
fallthrough
default:
return image.PullBundle(ctx, constants.GetDefaultBundleImageRegistry(preset))
bundle, err := image.PullBundle(ctx, constants.GetDefaultBundleImageRegistry(preset))
return nil, bundle, err
}
}
switch {
case strings.HasPrefix(bundleURI, "http://"), strings.HasPrefix(bundleURI, "https://"):
return download.Download(ctx, bundleURI, constants.MachineCacheDir, 0644, nil)
case strings.HasPrefix(bundleURI, "docker://"):
return image.PullBundle(ctx, bundleURI)
bundle, err := image.PullBundle(ctx, bundleURI)
return nil, bundle, err
}
// the `bundleURI` parameter turned out to be a local path
return bundleURI, nil
return reader, bundleURI, nil
}

type Version struct {
Expand Down
42 changes: 40 additions & 2 deletions pkg/crc/machine/bundle/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"runtime"
Expand Down Expand Up @@ -124,6 +125,36 @@ func (bundle *CrcBundleInfo) createSymlinkOrCopyPodmanRemote(binDir string) erro
return bundle.copyExecutableFromBundle(binDir, PodmanExecutable, constants.PodmanRemoteExecutableName)
}

// ExtractWithReader unpacks a bundle read from reader into the repository's
// cache directory.
//
// The stream is first uncompressed into a "tmp-extract" scratch directory
// under repo.CacheDir. The directory named after the bundle (path's base name
// with its extension stripped) is then moved from the scratch directory into
// repo.CacheDir, replacing any directory already present there, and its mode
// is set to 0755. path is only used to derive the bundle name; the bundle
// contents come from reader.
//
// The scratch directory is removed both before it is used and when the
// method returns. Errors from those removals are deliberately ignored.
func (repo *Repository) ExtractWithReader(ctx context.Context, reader io.Reader, path string) error {
	logging.Debugf("Extracting bundle from reader")

	scratchDir := filepath.Join(repo.CacheDir, "tmp-extract")
	_ = os.RemoveAll(scratchDir) // start from an empty scratch directory
	defer func() {
		_ = os.RemoveAll(scratchDir) // leave nothing behind on any return path
	}()

	_, err := extract.UncompressWithReader(ctx, reader, scratchDir)
	if err != nil {
		return err
	}

	baseName := GetBundleNameWithoutExtension(filepath.Base(path))
	extractedDir := filepath.Join(scratchDir, baseName)
	finalDir := filepath.Join(repo.CacheDir, baseName)

	// Drop any previous extraction so the rename below can take its place.
	_ = os.RemoveAll(finalDir)

	// A failed rename is reported as retriable so crcerrors.Retry tries again.
	moveIntoPlace := func() error {
		if renameErr := os.Rename(extractedDir, finalDir); renameErr != nil {
			return &crcerrors.RetriableError{Err: renameErr}
		}
		return nil
	}
	if retryErr := crcerrors.Retry(context.Background(), time.Minute, moveIntoPlace, 5*time.Second); retryErr != nil {
		return retryErr
	}

	return os.Chmod(finalDir, 0755)
}

func (repo *Repository) Extract(ctx context.Context, path string) error {
bundleName := filepath.Base(path)

Expand Down Expand Up @@ -198,8 +229,15 @@ func Use(bundleName string) (*CrcBundleInfo, error) {
return defaultRepo.Use(bundleName)
}

func Extract(ctx context.Context, path string) (*CrcBundleInfo, error) {
if err := defaultRepo.Extract(ctx, path); err != nil {
func Extract(ctx context.Context, reader io.Reader, path string) (*CrcBundleInfo, error) {
var err error
if reader == nil {
err = defaultRepo.Extract(ctx, path)
} else {
err = defaultRepo.ExtractWithReader(ctx, reader, path)
}

if err != nil {
return nil, err
}
return defaultRepo.Get(filepath.Base(path))
Expand Down
6 changes: 4 additions & 2 deletions pkg/crc/machine/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,15 @@ func getCrcBundleInfo(ctx context.Context, preset crcPreset.Preset, bundleName,
return bundleInfo, nil
}
logging.Debugf("Failed to load bundle %s: %v", bundleName, err)

logging.Infof("Downloading bundle: %s...", bundleName)
bundlePath, err = bundle.Download(ctx, preset, bundlePath, enableBundleQuayFallback)
reader, bundlePath, err := bundle.Download(ctx, preset, bundlePath, enableBundleQuayFallback)
if err != nil {
return nil, err
}

logging.Infof("Extracting bundle: %s...", bundleName)
if _, err := bundle.Extract(ctx, bundlePath); err != nil {
if _, err := bundle.Extract(ctx, reader, bundlePath); err != nil {
return nil, err
}
return bundle.Use(bundleName)
Expand Down
6 changes: 4 additions & 2 deletions pkg/crc/preflight/preflight_checks_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package preflight
import (
"context"
"fmt"
"io"
"os"
"path/filepath"

Expand Down Expand Up @@ -116,13 +117,14 @@ func fixBundleExtracted(bundlePath string, preset crcpreset.Preset, enableBundle
return fmt.Errorf("Cannot create directory %s: %v", bundleDir, err)
}
var err error
var reader io.Reader
logging.Infof("Downloading bundle: %s...", bundlePath)
if bundlePath, err = bundle.Download(context.TODO(), preset, bundlePath, enableBundleQuayFallback); err != nil {
if reader, bundlePath, err = bundle.Download(context.TODO(), preset, bundlePath, enableBundleQuayFallback); err != nil {
return err
}

logging.Infof("Uncompressing %s", bundlePath)
if _, err := bundle.Extract(context.TODO(), bundlePath); err != nil {
if _, err := bundle.Extract(context.TODO(), reader, bundlePath); err != nil {
if errors.Is(err, os.ErrNotExist) {
return errors.Wrap(err, "Use `crc setup -b <bundle-path>`")
}
Expand Down
91 changes: 30 additions & 61 deletions pkg/download/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,97 +2,66 @@ package download

import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"mime"
"net/http"
"net/url"
"os"
"path/filepath"
"time"

"github.com/cavaliergopher/grab/v3"
"github.com/crc-org/crc/v2/pkg/crc/logging"
"github.com/crc-org/crc/v2/pkg/crc/network/httpproxy"
"github.com/crc-org/crc/v2/pkg/crc/version"
"github.com/crc-org/crc/v2/pkg/os/terminal"

"github.com/cavaliergopher/grab/v3"
"github.com/cheggaaa/pb/v3"
"github.com/pkg/errors"
)

func doRequest(client *grab.Client, req *grab.Request) (string, error) {
const minSizeForProgressBar = 100_000_000

resp := client.Do(req)
if resp.Size() < minSizeForProgressBar {
<-resp.Done
return resp.Filename, resp.Err()
}

t := time.NewTicker(500 * time.Millisecond)
defer t.Stop()
var bar *pb.ProgressBar
if terminal.IsShowTerminalOutput() {
bar = pb.Start64(resp.Size())
bar.Set(pb.Bytes, true)
// This is the same as the 'Default' template https://github.com/cheggaaa/pb/blob/224e0746e1e7b9c5309d6e2637264bfeb746d043/v3/preset.go#L8-L10
// except that the 'per second' suffix is changed to '/s' (by default it is ' p/s' which is unexpected)
progressBarTemplate := `{{with string . "prefix"}}{{.}} {{end}}{{counters . }} {{bar . }} {{percent . }} {{speed . "%s/s" "??/s"}}{{with string . "suffix"}} {{.}}{{end}}`
bar.SetTemplateString(progressBarTemplate)
defer bar.Finish()
}

loop:
for {
select {
case <-t.C:
if terminal.IsShowTerminalOutput() {
bar.SetCurrent(resp.BytesComplete())
}
case <-resp.Done:
break loop
}
}

return resp.Filename, resp.Err()
}

// Download function takes sha256sum as hex decoded byte
// something like hex.DecodeString("33daf4c03f86120fdfdc66bddf6bfff4661c7ca11c5d")
func Download(ctx context.Context, uri, destination string, mode os.FileMode, sha256sum []byte) (string, error) {
func Download(ctx context.Context, uri, destination string, mode os.FileMode, _ []byte) (io.Reader, string, error) {
logging.Debugf("Downloading %s to %s", uri, destination)

client := grab.NewClient()
client.UserAgent = version.UserAgent()
client.HTTPClient = &http.Client{Transport: httpproxy.HTTPTransport()}
req, err := grab.NewRequest(destination, uri)
if err != nil {
return "", errors.Wrapf(err, "unable to get request from %s", uri)
}

if ctx == nil {
panic("ctx is nil, this should not happen")
}
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)

if err != nil {
return nil, "", errors.Wrapf(err, "unable to get request from %s", uri)
}
client := http.Client{Transport: &http.Transport{}}

req = req.WithContext(ctx)

if sha256sum != nil {
req.SetChecksum(sha256.New(), sha256sum, true)
resp, err := client.Do(req)
if err != nil {
return nil, "", err
}

filename, err := doRequest(client, req)
var filename, dir string
if filepath.Ext(destination) == ".crcbundle" {
dir = filepath.Dir(destination)
} else {
dir = destination
}
if disposition, params, _ := mime.ParseMediaType(resp.Header.Get("Content-Disposition")); disposition == "attachment" {
filename = filepath.Join(dir, params["filename"])
} else {
filename = filepath.Join(dir, filepath.Base(resp.Request.URL.Path))
}
file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
if err != nil {
return "", err
return nil, "", err
}

if err := os.Chmod(filename, mode); err != nil {
_ = os.Remove(filename)
return "", err
return nil, "", err
}

logging.Debugf("Download saved to %v", filename)
return filename, nil
return io.TeeReader(resp.Body, file), filename, nil
}

// InMemory takes a URL and returns a ReadCloser object to the downloaded file
Expand Down Expand Up @@ -138,10 +107,10 @@ func NewRemoteFile(uri, sha256sum string) *RemoteFile {

}

func (r *RemoteFile) Download(ctx context.Context, bundlePath string, mode os.FileMode) (string, error) {
func (r *RemoteFile) Download(ctx context.Context, bundlePath string, mode os.FileMode) (io.Reader, string, error) {
sha256bytes, err := hex.DecodeString(r.sha256sum)
if err != nil {
return "", err
return nil, "", err
}
return Download(ctx, r.URI, bundlePath, mode, sha256bytes)
}
Expand Down
14 changes: 14 additions & 0 deletions pkg/extract/extract.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,20 @@ func Uncompress(ctx context.Context, tarball, targetDir string) ([]string, error
return uncompress(ctx, tarball, targetDir, nil, terminal.IsShowTerminalOutput())
}

func UncompressWithReader(ctx context.Context, reader io.Reader, targetDir string) ([]string, error) {
return uncompressWithReader(ctx, reader, targetDir, nil, terminal.IsShowTerminalOutput())
}

// uncompressWithReader decodes reader as a zstd stream and unpacks the
// resulting tar archive into targetDir via untar.
//
// fileFilter and showProgress are forwarded to untar unchanged. An error
// creating the zstd decoder is returned as-is with a nil slice; otherwise the
// result of untar is returned.
func uncompressWithReader(ctx context.Context, reader io.Reader, targetDir string, fileFilter func(string) bool, showProgress bool) ([]string, error) {
	logging.Debugf("Uncompressing from reader to %s", targetDir)

	decoder, err := zstd.NewReader(reader)
	if err != nil {
		return nil, err
	}
	// The decoder holds resources until it is closed; release them once untar
	// has finished consuming the stream, on success and on error alike.
	defer decoder.Close()

	return untar(ctx, decoder, targetDir, fileFilter, showProgress)
}

func uncompress(ctx context.Context, tarball, targetDir string, fileFilter func(string) bool, showProgress bool) ([]string, error) {
logging.Debugf("Uncompressing %s to %s", tarball, targetDir)

Expand Down
2 changes: 1 addition & 1 deletion test/extended/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func DownloadBundle(bundleLocation string, bundleDestination string, bundleName
return bundleDestination, err
}

filename, err := download.Download(context.TODO(), bundleLocation, bundleDestination, 0644, nil)
_, filename, err := download.Download(context.TODO(), bundleLocation, bundleDestination, 0644, nil)
fmt.Printf("Downloading bundle from %s to %s.\n", bundleLocation, bundleDestination)
if err != nil {
return "", err
Expand Down

0 comments on commit bb33dee

Please sign in to comment.