diff --git a/get/filestorage.go b/get/filestorage.go index 44773a8..f70ab03 100644 --- a/get/filestorage.go +++ b/get/filestorage.go @@ -2,10 +2,6 @@ package get import ( "crypto" - "crypto/sha1" - "crypto/sha256" - "encoding/hex" - "errors" "io" "log" "os" @@ -25,8 +21,9 @@ func NewFileStorage(directory string) Storage { return &FileStorage{directory, make([]byte, 4*1024*1024)} } -// Checksum returns the checksum value of a file in the permanent location, according to the checksumType algorithm -func (s *FileStorage) Checksum(filename string, hash crypto.Hash) (checksum string, err error) { +// NewReader returns a Reader for a file in the permanent location, returns ErrFileNotFound +// if the requested path was not found at all +func (s *FileStorage) NewReader(filename string) (reader io.ReadCloser, err error) { fullPath := path.Join(s.directory, filename) stat, err := os.Stat(fullPath) if os.IsNotExist(err) || stat == nil { @@ -38,25 +35,8 @@ func (s *FileStorage) Checksum(filename string, hash crypto.Hash) (checksum stri if err != nil { log.Fatal(err) } - defer f.Close() - switch hash { - case crypto.SHA1: - h := sha1.New() - if _, err = io.CopyBuffer(h, f, s.checksumBuffer); err != nil { - log.Fatal(err) - } - checksum = hex.EncodeToString(h.Sum(nil)) - case crypto.SHA256: - h := sha256.New() - if _, err = io.CopyBuffer(h, f, s.checksumBuffer); err != nil { - log.Fatal(err) - } - checksum = hex.EncodeToString(h.Sum(nil)) - default: - err = errors.New("Unknown ChecksumType") - } - return + return f, err } // StoringMapper returns a mapper that will store read data to a temporary location specified by filename @@ -74,7 +54,7 @@ func (s *FileStorage) StoringMapper(filename string, checksum string, hash crypt return } - result = util.NewTeeReadCloser(reader, file) + result = util.NewTeeReadCloser(reader, util.NewChecksummingWriter(file, checksum, hash)) return } } diff --git a/get/s3storage.go b/get/s3storage.go index 60b62b6..53eef81 100644 --- a/get/s3storage.go +++ b/get/s3storage.go @@ -141,14 +141,15 @@ func (s *S3Storage) newPrefix() string { return "a/" } -// Checksum returns the checksum value of a file in the permanent location, according to the checksumType algorithm -func (s *S3Storage) Checksum(filename string, hash crypto.Hash) (checksum string, err error) { - input := &s3.HeadObjectInput{ +// NewReader returns a Reader for a file in the permanent location, returns ErrFileNotFound +// if the requested path was not found at all +func (s *S3Storage) NewReader(filename string) (reader io.ReadCloser, err error) { + input := &s3.GetObjectInput{ Bucket: aws.String(s.bucket), Key: aws.String(s.prefix + filename), } - info, err := s.svc.HeadObject(input) + info, err := s.svc.GetObject(input) if err != nil { if aerr, ok := err.(awserr.Error); ok { switch aerr.Code() { @@ -159,11 +160,7 @@ func (s *S3Storage) Checksum(filename string, hash crypto.Hash) (checksum string return } - if val, ok := info.Metadata["Checksum"]; ok { - checksum = *val - } - - return + return info.Body, err } // StoringMapper returns a mapper that will store read data to a temporary location specified by filename @@ -179,9 +176,6 @@ func (s *S3Storage) StoringMapper(filename string, checksum string, hash crypto. Bucket: aws.String(s.bucket), Key: aws.String(s.newPrefix() + filename), Body: pipeReader, - Metadata: map[string]*string{ - "Checksum": &checksum, - }, }) errs <- err }() diff --git a/get/storage.go b/get/storage.go index c4069a4..16ca98b 100644 --- a/get/storage.go +++ b/get/storage.go @@ -3,6 +3,7 @@ package get import ( "crypto" "errors" + "io" "github.com/moio/minima/util" ) @@ -15,9 +16,9 @@ type Storage interface { StoringMapper(filename string, checksum string, hash crypto.Hash) util.ReaderMapper // Commit moves any temporary file accumulated so far to the permanent location Commit() (err error) - // Checksum returns the checksum value of a file in the permanent location, according to the checksumType algorithm - // returns ErrFileNotFound if the requested path was not found at all - Checksum(filename string, hash crypto.Hash) (checksum string, err error) + // NewReader returns a Reader for a file in the permanent location, returns ErrFileNotFound + // if the requested path was not found at all + NewReader(filename string) (reader io.ReadCloser, err error) // Recycle will copy a file from the permanent to the temporary location Recycle(filename string) (err error) } diff --git a/get/syncer.go b/get/syncer.go index bf5399b..b72740b 100644 --- a/get/syncer.go +++ b/get/syncer.go @@ -73,15 +73,25 @@ func NewSyncer(url string, archs map[string]bool, storage Storage) *Syncer { // StoreRepo stores an HTTP repo in a Storage, automatically retrying in case of recoverable errors func (r *Syncer) StoreRepo() (err error) { + checksumMap := r.readChecksumMap() for i := 0; i < 10; i++ { - err = r.storeRepo() + err = r.storeRepo(checksumMap) if err == nil { return } uerr, unexpectedStatusCode := err.(*UnexpectedStatusCodeError) - if unexpectedStatusCode && uerr.StatusCode == 404 { - log.Printf("Got 404, presumably temporarily, retrying...\n") + if unexpectedStatusCode { + if uerr.StatusCode == 404 { + log.Printf("Got 404, presumably temporarily, retrying...\n") + } else { + return err + } + } + + _, checksumError := err.(*util.ChecksumError) + if checksumError { + log.Printf("Checksum did not match, presumably the repo was published while syncing, retrying...\n") } else { return err } @@ -92,8 +102,8 @@ func (r *Syncer) StoreRepo() (err error) { } // StoreRepo stores an HTTP repo in a Storage -func (r *Syncer) storeRepo() (err error) { - packagesToDownload, packagesToRecycle, err := r.processMetadata() +func (r *Syncer) storeRepo(checksumMap map[string]XMLChecksum) (err error) { + packagesToDownload, packagesToRecycle, err := r.processMetadata(checksumMap) if err != nil { return } @@ -136,12 +146,13 @@ func (r *Syncer) downloadStoreApply(path string, checksum string, hash crypto.Ha if err != nil { return err } + return util.Compose(r.storage.StoringMapper(path, checksum, hash), f)(body) } // processMetadata stores the repo metadata and returns a list of package file // paths to download -func (r *Syncer) processMetadata() (packagesToDownload []XMLPackage, packagesToRecycle []XMLPackage, err error) { +func (r *Syncer) processMetadata(checksumMap map[string]XMLChecksum) (packagesToDownload []XMLPackage, packagesToRecycle []XMLPackage, err error) { err = r.downloadStoreApply(repomdPath, "", 0, func(reader io.ReadCloser) (err error) { decoder := xml.NewDecoder(reader) var repomd XMLRepomd @@ -154,7 +165,7 @@ func (r *Syncer) processMetadata() (packagesToDownload []XMLPackage, packagesToR for i := 0; i < len(data); i++ { metadataPath := data[i].Location.Href if data[i].Type == "primary" { - packagesToDownload, packagesToRecycle, err = r.processPrimary(metadataPath) + packagesToDownload, packagesToRecycle, err = r.processPrimary(metadataPath, checksumMap) } else { err = r.downloadStore(metadataPath) } @@ -192,19 +203,67 @@ func (r *Syncer) processMetadata() (packagesToDownload []XMLPackage, packagesToR return } +func (r *Syncer) readMetaData(reader io.Reader) (primary XMLMetaData, err error) { + gzReader, err := gzip.NewReader(reader) + if err != nil { + return + } + defer gzReader.Close() + + decoder := xml.NewDecoder(gzReader) + err = decoder.Decode(&primary) + + return +} + +func (r *Syncer) readChecksumMap() (checksumMap map[string]XMLChecksum) { + checksumMap = make(map[string]XMLChecksum) + repomdReader, err := r.storage.NewReader(repomdPath) + if err != nil { + if err == ErrFileNotFound { + log.Println("First-time sync started") + } else { + log.Println(err.Error()) + log.Println("Error while reading previously-downloaded metadata. Starting sync from scratch") + } + return + } + defer repomdReader.Close() + + decoder := xml.NewDecoder(repomdReader) + var repomd XMLRepomd + err = decoder.Decode(&repomd) + if err != nil { + log.Println(err.Error()) + log.Println("Error while parsing previously-downloaded metadata. Starting sync from scratch") + return + } + + data := repomd.Data + for i := 0; i < len(data); i++ { + metadataPath := data[i].Location.Href + if data[i].Type == "primary" { + primaryReader, err := r.storage.NewReader(metadataPath) + if err != nil { + return + } + primary, err := r.readMetaData(primaryReader) + if err != nil { + return + } + for _, pack := range primary.Packages { + checksumMap[pack.Location.Href] = pack.Checksum + } + } + } + return +} + // processPrimary stores the primary XML metadata file and returns a list of // package file paths to download -func (r *Syncer) processPrimary(path string) (packagesToDownload []XMLPackage, packagesToRecycle []XMLPackage, err error) { +func (r *Syncer) processPrimary(path string, checksumMap map[string]XMLChecksum) (packagesToDownload []XMLPackage, packagesToRecycle []XMLPackage, err error) { err = r.downloadStoreApply(path, "", 0, func(reader io.ReadCloser) (err error) { - gzReader, err := gzip.NewReader(reader) - if err != nil { - return - } - defer gzReader.Close() - - decoder := xml.NewDecoder(gzReader) - var primary XMLMetaData - err = decoder.Decode(&primary) + primary, err := r.readMetaData(reader) if err != nil { return } @@ -212,22 +271,17 @@ func (r *Syncer) processPrimary(path string) (packagesToDownload []XMLPackage, p allArchs := len(r.archs) == 0 for _, pack := range primary.Packages { if allArchs || pack.Arch == "noarch" || r.archs[pack.Arch] { - storageChecksum, err := r.storage.Checksum(pack.Location.Href, hashMap[pack.Checksum.Type]) + previousChecksum, ok := checksumMap[pack.Location.Href] switch { - case err == ErrFileNotFound: + case !ok: log.Printf("...package '%v' not found, will be downloaded\n", pack.Location.Href) packagesToDownload = append(packagesToDownload, pack) - case err != nil: - log.Printf("Checksum evaluation of the package '%v' returned the following error:\n", pack.Location.Href) - log.Printf("Error message: %v\n", err) - log.Println("...package skipped") - case pack.Checksum.Checksum != storageChecksum: - log.Printf("...package '%v' has a checksum error, will be redownloaded\n", pack.Location.Href) - log.Printf("[repo vs local] = ['%v' VS '%v']\n", pack.Checksum.Checksum, storageChecksum) - packagesToDownload = append(packagesToDownload, pack) - default: + case previousChecksum.Type == pack.Checksum.Type && previousChecksum.Checksum == pack.Checksum.Checksum: log.Printf("...package '%v' is up-to-date already, will be recycled\n", pack.Location.Href) packagesToRecycle = append(packagesToRecycle, pack) + default: + log.Printf("...package '%v' does not have the expected checksum, will be redownloaded\n", pack.Location.Href) + packagesToDownload = append(packagesToDownload, pack) } } } diff --git a/get/syncer_test.go b/get/syncer_test.go index 0eef9c0..5847902 100644 --- a/get/syncer_test.go +++ b/get/syncer_test.go @@ -45,11 +45,11 @@ func TestStoreRepo(t *testing.T) { for _, file := range expectedFiles { originalInfo, serr := os.Stat(filepath.Join("testdata", "repo", file)) if err != nil { - t.Error(serr) + t.Fatal(serr) } syncedInfo, serr := os.Stat(filepath.Join(directory, file)) if serr != nil { - t.Error(serr) + t.Fatal(serr) } if originalInfo.Size() != syncedInfo.Size() { t.Error("original and synced versions of", file, "differ:", originalInfo.Size(), "vs", syncedInfo.Size()) @@ -61,5 +61,4 @@ func TestStoreRepo(t *testing.T) { if err != nil { t.Error(err) } - } diff --git a/util/io.go b/util/io.go index 29c0875..eba7e26 100644 --- a/util/io.go +++ b/util/io.go @@ -1,6 +1,10 @@ package util import ( + "crypto" + "encoding/hex" + "fmt" + "hash" "io" "io/ioutil" ) @@ -18,7 +22,9 @@ func Compose(mapper ReaderMapper, f ReaderConsumer) ReaderConsumer { if err != nil { return } - defer mappedReader.Close() + defer func() { + err = mappedReader.Close() + }() return f(mappedReader) } @@ -74,3 +80,52 @@ func (t *TeeReadCloser) Close() (err error) { err = t.writer.Close() return } + +// ChecksummingWriter is a WriteCloser that checks on close that the checksum matches +type ChecksummingWriter struct { + writer io.WriteCloser + expectedSum string + hashFunction crypto.Hash + hash hash.Hash +} + +// NewChecksummingWriter returns a new ChecksummingWriter +func NewChecksummingWriter(writer io.WriteCloser, expectedSum string, hashFunction crypto.Hash) *ChecksummingWriter { + if hashFunction != 0 { + return &ChecksummingWriter{writer, expectedSum, hashFunction, hashFunction.New()} + } + return &ChecksummingWriter{writer, expectedSum, hashFunction, nil} +} + +// Write delegates to the writer and hash +func (w *ChecksummingWriter) Write(p []byte) (n int, err error) { + if w.hashFunction != 0 { + w.hash.Write(p) + } + return w.writer.Write(p) +} + +// Close delegates to the writer and checks the hash sum +func (w *ChecksummingWriter) Close() (err error) { + err = w.writer.Close() + if err != nil { + return + } + if w.hashFunction != 0 { + actualSum := hex.EncodeToString(w.hash.Sum(nil)) + if w.expectedSum != actualSum { + err = &ChecksumError{w.expectedSum, actualSum} + } + } + return +} + +// ChecksumError is returned if the expected and actual checksums do not match +type ChecksumError struct { + expected string + actual string +} + +func (e *ChecksumError) Error() string { + return fmt.Sprintf("Checksum mismatch: expected %s, actual %s", e.expected, e.actual) +}