diff --git a/go/pkg/client/BUILD.bazel b/go/pkg/client/BUILD.bazel
index 3d7819256..c3864deb8 100644
--- a/go/pkg/client/BUILD.bazel
+++ b/go/pkg/client/BUILD.bazel
@@ -24,6 +24,7 @@ go_library(
         "//go/pkg/command",
         "//go/pkg/contextmd",
         "//go/pkg/digest",
+        "//go/pkg/diskcache",
         "//go/pkg/filemetadata",
         "//go/pkg/retry",
         "//go/pkg/uploadinfo",
diff --git a/go/pkg/client/cas_download.go b/go/pkg/client/cas_download.go
index 9fe5e1fb0..8c2405247 100644
--- a/go/pkg/client/cas_download.go
+++ b/go/pkg/client/cas_download.go
@@ -101,6 +101,17 @@ func (c *Client) DownloadOutputs(ctx context.Context, outs map[string]*TreeOutput,
 			symlinks = append(symlinks, out)
 			continue
 		}
+		if c.diskCache != nil {
+			absPath := out.Path
+			if !filepath.IsAbs(absPath) {
+				absPath = filepath.Join(outDir, absPath)
+			}
+			if c.diskCache.LoadCas(out.Digest, absPath) {
+				fullStats.Requested += out.Digest.Size
+				fullStats.Cached += out.Digest.Size
+				continue
+			}
+		}
 		if _, ok := downloads[out.Digest]; ok {
 			copies = append(copies, out)
 			// All copies are effectively cached
@@ -129,6 +140,11 @@ func (c *Client) DownloadOutputs(ctx context.Context, outs map[string]*TreeOutput,
 		if err := cache.Update(absPath, md); err != nil {
 			return fullStats, err
 		}
+		if c.diskCache != nil {
+			if err := c.diskCache.StoreCas(output.Digest, absPath); err != nil {
+				return fullStats, err
+			}
+		}
 	}
 	for _, out := range copies {
 		perm := c.RegularMode
diff --git a/go/pkg/client/client.go b/go/pkg/client/client.go
index f99f25858..2bb9093fb 100644
--- a/go/pkg/client/client.go
+++ b/go/pkg/client/client.go
@@ -15,10 +15,12 @@ import (
 	"time"
 
 	"errors"
+
 	"github.com/bazelbuild/remote-apis-sdks/go/pkg/actas"
 	"github.com/bazelbuild/remote-apis-sdks/go/pkg/balancer"
 	"github.com/bazelbuild/remote-apis-sdks/go/pkg/chunker"
 	"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
+	"github.com/bazelbuild/remote-apis-sdks/go/pkg/diskcache"
 	"github.com/bazelbuild/remote-apis-sdks/go/pkg/retry"
 	"github.com/bazelbuild/remote-apis-sdks/go/pkg/uploadinfo"
 	"golang.org/x/oauth2"
@@ -194,6 +196,7 @@ type Client struct {
 	uploadOnce          sync.Once
 	downloadOnce        sync.Once
 	useBatchCompression UseBatchCompression
+	diskCache           *diskcache.DiskCache
 }
 
 const (
@@ -242,6 +245,10 @@ func (c *Client) Close() error {
+	if c.diskCache != nil {
+		// Waits for local disk GC to complete; this must run even when the CAS connection is closed separately.
+		c.diskCache.Shutdown()
+	}
 	if c.casConnection != c.connection {
 		return c.casConnection.Close()
 	}
 	return nil
 }
@@ -351,6 +358,16 @@ func (o *TreeSymlinkOpts) Apply(c *Client) {
 	c.TreeSymlinkOpts = o
 }
 
+// DiskCacheOpts attaches a local disk cache for downloaded outputs and action cache results.
+type DiskCacheOpts struct {
+	DiskCache *diskcache.DiskCache
+}
+
+// Apply sets the client's disk cache.
+func (o *DiskCacheOpts) Apply(c *Client) {
+	c.diskCache = o.DiskCache
+}
+
 // MaxBatchDigests is maximum amount of digests to batch in upload and download operations.
 type MaxBatchDigests int
diff --git a/go/pkg/client/exec.go b/go/pkg/client/exec.go
index ce0f51fa4..f6b1fbce6 100644
--- a/go/pkg/client/exec.go
+++ b/go/pkg/client/exec.go
@@ -98,13 +98,23 @@ func (c *Client) ExecuteAction(ctx context.Context, ac *Action) (*repb.ActionResult, error) {
 }
 
 // CheckActionCache queries remote action cache, returning an ActionResult or nil if it doesn't exist.
-func (c *Client) CheckActionCache(ctx context.Context, acDg *repb.Digest) (*repb.ActionResult, error) { +func (c *Client) CheckActionCache(ctx context.Context, dg digest.Digest) (*repb.ActionResult, error) { + if c.diskCache != nil { + if res, loaded := c.diskCache.LoadActionCache(dg); loaded { + return res, nil + } + } res, err := c.GetActionResult(ctx, &repb.GetActionResultRequest{ InstanceName: c.InstanceName, - ActionDigest: acDg, + ActionDigest: dg.ToProto(), }) switch st, _ := status.FromError(err); st.Code() { case codes.OK: + if c.diskCache != nil { + if err := c.diskCache.StoreActionCache(dg, res); err != nil { + log.Errorf("error storing ActionResult of %s to disk cache: %v", dg, err) + } + } return res, nil case codes.NotFound: return nil, nil @@ -166,12 +176,13 @@ func (c *Client) PrepAction(ctx context.Context, ac *Action) (*repb.Digest, *rep if err != nil { return nil, nil, fmt.Errorf("marshalling Action proto: %w", err) } - acDg := digest.NewFromBlob(acBlob).ToProto() + dg := digest.NewFromBlob(acBlob) + acDg := dg.ToProto() // If the result is cacheable, check if it's already in the cache. if !ac.DoNotCache || !ac.SkipCache { log.V(1).Info("Checking cache") - res, err := c.CheckActionCache(ctx, acDg) + res, err := c.CheckActionCache(ctx, dg) if err != nil { return nil, nil, err } diff --git a/go/pkg/diskcache/BUILD.bazel b/go/pkg/diskcache/BUILD.bazel new file mode 100644 index 000000000..5cd18a9c2 --- /dev/null +++ b/go/pkg/diskcache/BUILD.bazel @@ -0,0 +1,49 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "diskcache", + srcs = [ + "diskcache.go", + "sys_darwin.go", + "sys_linux.go", + "sys_windows.go", + ], + importpath = "github.com/bazelbuild/remote-apis-sdks/go/pkg/diskcache", + visibility = ["//visibility:public"], + deps = [ + "//go/pkg/digest", + "@com_github_bazelbuild_remote_apis//build/bazel/remote/execution/v2:remote_execution_go_proto", + "@com_github_golang_glog//:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", + "@org_golang_x_sync//errgroup:go_default_library", + ], +) + +go_test( + name = "diskcache_test", + srcs = ["diskcache_test.go"], + embed = [":diskcache"], + deps = [ + "//go/pkg/digest", + "//go/pkg/testutil", + "@com_github_bazelbuild_remote_apis//build/bazel/remote/execution/v2:remote_execution_go_proto", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_uuid//:uuid", + "@org_golang_x_sync//errgroup:go_default_library", + ], +) + +go_test( + name = "diskcache_benchmark_test", + srcs = ["diskcache_benchmark_test.go"], + embed = [":diskcache"], + deps = [ + "//go/pkg/digest", + "//go/pkg/testutil", + "@com_github_bazelbuild_remote_apis//build/bazel/remote/execution/v2:remote_execution_go_proto", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_uuid//:uuid", + "@org_golang_x_sync//errgroup:go_default_library", + ], + tags = ["manual"], +) diff --git a/go/pkg/diskcache/diskcache.go b/go/pkg/diskcache/diskcache.go new file mode 100644 index 000000000..2458ac073 --- /dev/null +++ b/go/pkg/diskcache/diskcache.go @@ -0,0 +1,486 @@ +// Package diskcache implements a local disk LRU CAS cache. 
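+// In addition to CAS blobs, it stores serialized ActionResult messages keyed
+// by action digest.
+//
+// A minimal usage sketch (illustrative only; the path, capacity, and digest
+// are examples):
+//
+//	dc, err := diskcache.New(ctx, "/tmp/cas_cache", 10*1024*1024*1024)
+//	if err != nil { /* handle */ }
+//	defer dc.Shutdown() // waits for GC to finish
+//	if !dc.LoadCas(dg, outPath) {
+//		// Cache miss: fetch the blob remotely to outPath, then:
+//		err = dc.StoreCas(dg, outPath)
+//	}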
+package diskcache
+
+import (
+	"container/heap"
+	"context"
+	"fmt"
+	"io"
+	"io/fs"
+	"os"
+	"path/filepath"
+	"strconv"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
+	"golang.org/x/sync/errgroup"
+	"google.golang.org/protobuf/proto"
+
+	repb "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
+	log "github.com/golang/glog"
+)
+
+type key struct {
+	digest digest.Digest
+	isCas  bool
+}
+
+// A qitem is something we manage in a priority queue.
+type qitem struct {
+	key   key
+	lat   time.Time    // The last accessed time of the file.
+	index int          // The index of the item in the heap.
+	mu    sync.RWMutex // Protects the data-structure consistency for the given digest.
+}
+
+// A priorityQueue implements heap.Interface and holds qitems.
+type priorityQueue struct {
+	items []*qitem
+	n     int
+}
+
+func (q *priorityQueue) Len() int {
+	return q.n
+}
+
+func (q *priorityQueue) Less(i, j int) bool {
+	// We want Pop to give us the oldest item.
+	return q.items[i].lat.Before(q.items[j].lat)
+}
+
+func (q *priorityQueue) Swap(i, j int) {
+	q.items[i], q.items[j] = q.items[j], q.items[i]
+	q.items[i].index = i
+	q.items[j].index = j
+}
+
+func (q *priorityQueue) Push(x any) {
+	if q.n == cap(q.items) {
+		// Resize the queue.
+		old := q.items
+		q.items = make([]*qitem, 2*cap(old)) // Initial capacity needs to be > 0.
+		copy(q.items, old)
+	}
+	item := x.(*qitem)
+	item.index = q.n
+	q.items[item.index] = item
+	q.n++
+}
+
+func (q *priorityQueue) Pop() any {
+	item := q.items[q.n-1]
+	q.items[q.n-1] = nil // avoid memory leak
+	item.index = -1      // for safety
+	q.n--
+	return item
+}
+
+// Bump refreshes the item's last-access time to now, making it the most
+// recently used entry (the last in line for eviction).
+func (q *priorityQueue) Bump(item *qitem) {
+	// Sanity check, necessary because of possible racing between Bump and GC:
+	if item.index < 0 || item.index >= q.n || q.items[item.index].key != item.key {
+		return
+	}
+	item.lat = time.Now()
+	heap.Fix(q, item.index)
+}
+
+// DiskCache is a local disk LRU cache for CAS blobs and action cache results.
+type DiskCache struct {
+	root             string         // path to the root directory of the disk cache.
+	maxCapacityBytes uint64         // if disk size exceeds this, old items will be evicted as needed.
+	mu               sync.Mutex     // protects the queue.
+	store            sync.Map       // map of keys to qitems.
+	queue            *priorityQueue // keys by last accessed time.
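+	// The eviction scheme: store maps each key to a qitem, and queue orders
+	// those same qitems by last-access time. A cache hit Bumps the item to the
+	// newest position; gc() Pops the oldest items until the total size drops
+	// back under maxCapacityBytes.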
+	ctx              context.Context
+	shutdown         chan bool
+	shutdownOnce     sync.Once
+	gcReq            chan bool
+	gcDone           chan bool
+	statMu           sync.Mutex
+	stats            *DiskCacheStats
+}
+
+// DiskCacheStats are cumulative counters of cache activity.
+type DiskCacheStats struct {
+	TotalSizeBytes         int64
+	TotalNumFiles          int64
+	NumFilesStored         int64
+	TotalStoredBytes       int64
+	NumFilesGCed           int64
+	TotalGCedSizeBytes     int64
+	NumCacheHits           int64
+	NumCacheMisses         int64
+	TotalCacheHitSizeBytes int64
+	InitTime               time.Duration
+	TotalGCTime            time.Duration
+	TotalGCDiskOpsTime     time.Duration
+}
+
+// New creates a DiskCache rooted at root, indexing any files already present.
+// Once the total size exceeds maxCapacityBytes, least recently used items are
+// garbage-collected in the background.
+func New(ctx context.Context, root string, maxCapacityBytes uint64) (*DiskCache, error) {
+	res := &DiskCache{
+		root:             root,
+		maxCapacityBytes: maxCapacityBytes,
+		ctx:              ctx,
+		queue: &priorityQueue{
+			items: make([]*qitem, 1000),
+		},
+		gcReq:    make(chan bool, 1),
+		shutdown: make(chan bool),
+		gcDone:   make(chan bool),
+		stats:    &DiskCacheStats{},
+	}
+	start := time.Now()
+	defer func() { atomic.AddInt64((*int64)(&res.stats.InitTime), int64(time.Since(start))) }()
+	heap.Init(res.queue)
+	if err := os.MkdirAll(root, os.ModePerm); err != nil {
+		return nil, err
+	}
+	// We use Git's directory/file naming structure as inspiration:
+	// https://git-scm.com/book/en/v2/Git-Internals-Git-Objects#:~:text=The%20subdirectory%20is%20named%20with%20the%20first%202%20characters%20of%20the%20SHA%2D1%2C%20and%20the%20filename%20is%20the%20remaining%2038%20characters.
+	eg, eCtx := errgroup.WithContext(ctx)
+	for i := 0; i < 256; i++ {
+		prefixDir := filepath.Join(root, fmt.Sprintf("%02x", i))
+		eg.Go(func() error {
+			if eCtx.Err() != nil {
+				return eCtx.Err()
+			}
+			if err := os.MkdirAll(prefixDir, os.ModePerm); err != nil {
+				return err
+			}
+			return filepath.WalkDir(prefixDir, func(path string, d fs.DirEntry, err error) error {
+				// Any error aborts initialization: a cache directory we cannot fully read is not safe to use.
+				if err != nil {
+					return fmt.Errorf("error reading cache directory: %v", err)
+				}
+				if d.IsDir() {
+					return nil
+				}
+				subdir := filepath.Base(filepath.Dir(path))
+				k, err := getKeyFromFileName(subdir + d.Name())
+				if err != nil {
+					return fmt.Errorf("error parsing cached file name %s: %v", path, err)
+				}
+				info, err := d.Info()
+				if err != nil {
+					return fmt.Errorf("error getting file info of %s: %v", path, err)
+				}
+				it := &qitem{
+					key: k,
+					lat: fileInfoToAccessTime(info),
+				}
+				size, err := res.getItemSize(k)
+				if err != nil {
+					return fmt.Errorf("error getting file size of %s: %v", path, err)
+				}
+				res.store.Store(k, it)
+				atomic.AddInt64(&res.stats.TotalSizeBytes, size)
+				atomic.AddInt64(&res.stats.TotalNumFiles, 1)
+				res.mu.Lock()
+				heap.Push(res.queue, it)
+				res.mu.Unlock()
+				return nil
+			})
+		})
+	}
+	if err := eg.Wait(); err != nil {
+		return nil, err
+	}
+	go res.daemon()
+	return res, nil
+}
+
+func (d *DiskCache) getItemSize(k key) (int64, error) {
+	if k.isCas {
+		return k.digest.Size, nil
+	}
+	fname := d.getPath(k)
+	info, err := os.Stat(fname)
+	if err != nil {
+		return 0, fmt.Errorf("error getting info for %s: %v", fname, err)
+	}
+	return info.Size(), nil
+}
+
+// Shutdown terminates the GC daemon, waiting for it to complete. No further Store* calls to the DiskCache should be made.
+func (d *DiskCache) Shutdown() {
+	d.shutdownOnce.Do(func() {
+		d.shutdown <- true
+		<-d.gcDone
+		log.Infof("DiskCacheStats: %+v", d.stats)
+	})
+}
+
+// GetStats returns a snapshot copy of the current stats, read atomically for safety.
+func (d *DiskCache) GetStats() *DiskCacheStats {
+	return &DiskCacheStats{
+		TotalSizeBytes:         atomic.LoadInt64(&d.stats.TotalSizeBytes),
+		TotalNumFiles:          atomic.LoadInt64(&d.stats.TotalNumFiles),
+		NumFilesStored:         atomic.LoadInt64(&d.stats.NumFilesStored),
+		TotalStoredBytes:       atomic.LoadInt64(&d.stats.TotalStoredBytes),
+		NumFilesGCed:           atomic.LoadInt64(&d.stats.NumFilesGCed),
+		TotalGCedSizeBytes:     atomic.LoadInt64(&d.stats.TotalGCedSizeBytes),
+		NumCacheHits:           atomic.LoadInt64(&d.stats.NumCacheHits),
+		NumCacheMisses:         atomic.LoadInt64(&d.stats.NumCacheMisses),
+		TotalCacheHitSizeBytes: atomic.LoadInt64(&d.stats.TotalCacheHitSizeBytes),
+		InitTime:               time.Duration(atomic.LoadInt64((*int64)(&d.stats.InitTime))),
+		TotalGCTime:            time.Duration(atomic.LoadInt64((*int64)(&d.stats.TotalGCTime))),
+		TotalGCDiskOpsTime:     time.Duration(atomic.LoadInt64((*int64)(&d.stats.TotalGCDiskOpsTime))),
+	}
+}
+
+func getKeyFromFileName(fname string) (key, error) {
+	pair := strings.Split(fname, ".")
+	if len(pair) != 2 {
+		return key{}, fmt.Errorf("expected file name in the form hash[_ac].size, got %s", fname)
+	}
+	size, err := strconv.ParseInt(pair[1], 10, 64)
+	if err != nil {
+		return key{}, fmt.Errorf("invalid size in digest %s: %s", fname, err)
+	}
+	hash, isAc := strings.CutSuffix(pair[0], "_ac")
+	dg, err := digest.New(hash, size)
+	if err != nil {
+		return key{}, fmt.Errorf("invalid digest from file name %s: %v", fname, err)
+	}
+	return key{digest: dg, isCas: !isAc}, nil
+}
+
+func (d *DiskCache) getPath(k key) string {
+	suffix := ""
+	if !k.isCas {
+		suffix = "_ac"
+	}
+	return filepath.Join(d.root, k.digest.Hash[:2], fmt.Sprintf("%s%s.%d", k.digest.Hash[2:], suffix, k.digest.Size))
+}
+
+// StoreCas copies the file under path into the cache, keyed by dg.
+func (d *DiskCache) StoreCas(dg digest.Digest, path string) error {
+	if dg.Size > int64(d.maxCapacityBytes) {
+		return fmt.Errorf("blob size %d exceeds DiskCache capacity %d", dg.Size, d.maxCapacityBytes)
+	}
+	it := &qitem{
+		key: key{digest: dg, isCas: true},
+		lat: time.Now(),
+	}
+	it.mu.Lock()
+	defer it.mu.Unlock()
+	_, exists := d.store.LoadOrStore(it.key, it)
+	if exists {
+		return nil
+	}
+	d.mu.Lock()
+	heap.Push(d.queue, it)
+	d.mu.Unlock()
+	if err := copyFile(path, d.getPath(it.key), dg.Size); err != nil {
+		return err
+	}
+	d.statMu.Lock()
+	d.stats.TotalSizeBytes += dg.Size
+	d.stats.TotalStoredBytes += dg.Size
+	newSize := uint64(d.stats.TotalSizeBytes)
+	d.stats.TotalNumFiles++
+	d.stats.NumFilesStored++
+	d.statMu.Unlock()
+	if newSize > d.maxCapacityBytes {
+		select {
+		case d.gcReq <- true:
+		default:
+		}
+	}
+	return nil
+}
+
+// StoreActionCache stores the serialized ActionResult keyed by dg, overwriting any existing entry.
+func (d *DiskCache) StoreActionCache(dg digest.Digest, ar *repb.ActionResult) error {
+	bytes, err := proto.Marshal(ar)
+	if err != nil {
+		return err
+	}
+	size := uint64(len(bytes))
+	if size > d.maxCapacityBytes {
+		return fmt.Errorf("message size %d exceeds DiskCache capacity %d", size, d.maxCapacityBytes)
+	}
+	it := &qitem{
+		key: key{digest: dg, isCas: false},
+		lat: time.Now(),
+	}
+	it.mu.Lock()
+	defer it.mu.Unlock()
+	d.store.Store(it.key, it)
+
+	d.mu.Lock()
+	heap.Push(d.queue, it)
+	d.mu.Unlock()
+	if err := os.WriteFile(d.getPath(it.key), bytes, 0644); err != nil {
+		return err
+	}
+	d.statMu.Lock()
+	d.stats.TotalSizeBytes += int64(size)
+	d.stats.TotalStoredBytes += int64(size)
+	newSize := uint64(d.stats.TotalSizeBytes)
+	d.stats.TotalNumFiles++
+	d.stats.NumFilesStored++
+	d.statMu.Unlock()
+	if newSize > d.maxCapacityBytes {
+		select {
+		case d.gcReq <- true:
+		default:
+		}
+	}
+	return nil
+}
+
+func (d *DiskCache) gc() {
+	start := time.Now()
+	// Evict old entries until total size is below cap.
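+	// Eviction takes each item's write lock before removing its file, so a
+	// concurrent LoadCas either finishes its copy first or fails the copy and
+	// is counted as a miss; it never sees a half-deleted file as a hit.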
+	var numFilesGCed, totalGCedSizeBytes int64
+	var totalGCDiskOpsTime time.Duration
+	for uint64(atomic.LoadInt64(&d.stats.TotalSizeBytes)) > d.maxCapacityBytes {
+		d.mu.Lock()
+		it := heap.Pop(d.queue).(*qitem)
+		d.mu.Unlock()
+		size, err := d.getItemSize(it.key)
+		if err != nil {
+			log.Errorf("error getting item size for %v: %v", it.key, err)
+			size = 0
+		}
+		atomic.AddInt64(&d.stats.TotalSizeBytes, -size)
+		numFilesGCed++
+		totalGCedSizeBytes += size
+		it.mu.Lock()
+		diskOpsStart := time.Now()
+		// We only delete the files, and not the prefix directories, because the prefixes are not worth worrying about.
+		if err := os.Remove(d.getPath(it.key)); err != nil {
+			log.Errorf("Error removing file: %v", err)
+		}
+		totalGCDiskOpsTime += time.Since(diskOpsStart)
+		d.store.Delete(it.key)
+		it.mu.Unlock()
+	}
+	d.statMu.Lock()
+	d.stats.NumFilesGCed += numFilesGCed
+	d.stats.TotalNumFiles -= numFilesGCed
+	d.stats.TotalGCedSizeBytes += totalGCedSizeBytes
+	d.stats.TotalGCDiskOpsTime += totalGCDiskOpsTime
+	d.stats.TotalGCTime += time.Since(start)
+	d.statMu.Unlock()
+}
+
+func (d *DiskCache) daemon() {
+	defer func() { d.gcDone <- true }()
+	for {
+		select {
+		case <-d.shutdown:
+			return
+		case <-d.ctx.Done():
+			return
+		case <-d.gcReq:
+			d.gc()
+		}
+	}
+}
+
+// copyFile copies the file contents from src to dst, retaining the source permissions.
+func copyFile(src, dst string, size int64) error {
+	srcInfo, err := os.Stat(src)
+	if err != nil {
+		return err
+	}
+	in, err := os.Open(src)
+	if err != nil {
+		return err
+	}
+	defer in.Close()
+	out, err := os.Create(dst)
+	if err != nil {
+		return err
+	}
+	if err := out.Chmod(srcInfo.Mode()); err != nil {
+		return err
+	}
+	defer out.Close()
+	n, err := io.Copy(out, in)
+	if err != nil {
+		return err
+	}
+	// Required sanity check: if the file is being concurrently deleted, we may not always copy everything.
+	if n != size {
+		return fmt.Errorf("copy of %s to %s failed: src/dst size mismatch: wanted %d, got %d", src, dst, size, n)
+	}
+	return nil
+}
+
+// LoadCas copies the cached file contents for dg to the given path if the
+// digest exists in the disk cache, returning whether it was found.
+func (d *DiskCache) LoadCas(dg digest.Digest, path string) bool {
+	k := key{digest: dg, isCas: true}
+	iUntyped, loaded := d.store.Load(k)
+	if !loaded {
+		atomic.AddInt64(&d.stats.NumCacheMisses, 1)
+		return false
+	}
+	it := iUntyped.(*qitem)
+	it.mu.RLock()
+	err := copyFile(d.getPath(k), path, dg.Size)
+	it.mu.RUnlock()
+	if err != nil {
+		// It is not possible to prevent a race with GC; hence, we return false on copy errors.
+		atomic.AddInt64(&d.stats.NumCacheMisses, 1)
+		return false
+	}
+
+	d.mu.Lock()
+	d.queue.Bump(it)
+	d.mu.Unlock()
+	d.statMu.Lock()
+	d.stats.NumCacheHits++
+	d.stats.TotalCacheHitSizeBytes += dg.Size
+	d.statMu.Unlock()
+	return true
+}
+
+// LoadActionCache returns the cached ActionResult for dg, if present.
+func (d *DiskCache) LoadActionCache(dg digest.Digest) (ar *repb.ActionResult, loaded bool) {
+	k := key{digest: dg, isCas: false}
+	iUntyped, loaded := d.store.Load(k)
+	if !loaded {
+		atomic.AddInt64(&d.stats.NumCacheMisses, 1)
+		return nil, false
+	}
+	it := iUntyped.(*qitem)
+	it.mu.RLock()
+	ar = &repb.ActionResult{}
+	size, err := d.loadActionResult(k, ar)
+	if err != nil {
+		// It is not possible to prevent a race with GC; hence, we return false on load errors.
+		it.mu.RUnlock()
+		atomic.AddInt64(&d.stats.NumCacheMisses, 1)
+		return nil, false
+	}
+	it.mu.RUnlock()
+
+	d.mu.Lock()
+	d.queue.Bump(it)
+	d.mu.Unlock()
+	d.statMu.Lock()
+	d.stats.NumCacheHits++
+	d.stats.TotalCacheHitSizeBytes += int64(size)
+	d.statMu.Unlock()
+	return ar, true
+}
+
+func (d *DiskCache) loadActionResult(k key, ar *repb.ActionResult) (int, error) {
+	bytes, err := os.ReadFile(d.getPath(k))
+	if err != nil {
+		return 0, err
+	}
+	n := len(bytes)
+	// Required sanity check: a read may appear to succeed yet return empty data if
+	// the file is being concurrently deleted. An empty ActionResult is advised against in
+	// the RE-API: https://github.com/bazelbuild/remote-apis/blob/main/build/bazel/remote/execution/v2/remote_execution.proto#L1052
+	if n == 0 {
+		return n, fmt.Errorf("read empty ActionResult for %v", k.digest)
+	}
+	if err := proto.Unmarshal(bytes, ar); err != nil {
+		return n, fmt.Errorf("error unmarshalling %v as ActionResult: %v", bytes, err)
+	}
+	return n, nil
+}
diff --git a/go/pkg/diskcache/diskcache_benchmark_test.go b/go/pkg/diskcache/diskcache_benchmark_test.go
new file mode 100644
index 000000000..141fb1235
--- /dev/null
+++ b/go/pkg/diskcache/diskcache_benchmark_test.go
@@ -0,0 +1,174 @@
+package diskcache
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"path/filepath"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
+	"golang.org/x/sync/errgroup"
+)
+
+type BenchmarkParams struct {
+	Name             string
+	MaxConcurrency   int           // Number of concurrent threads for requests.
+	CacheMissCost    time.Duration // Simulates a remote execution / fetch.
+	CacheSizeBytes   uint64
+	FileSizeBytes    int // All files have this size.
+	NumRequests      int // Total number of cache load/store requests.
+	TotalNumFiles    int // NumRequests will repeat over this set.
+	NumExistingFiles int // Affects initialization time.
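+
+	// Whether a run exercises GC is governed by CacheSizeBytes relative to the
+	// working set (TotalNumFiles * FileSizeBytes): the AllGC_* cases below size
+	// the cache well below the working set, while the AllCacheHits_* cases fit
+	// it entirely.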
+} + +var kBenchmarks = []BenchmarkParams{ + { + Name: "AllGC_Small", + MaxConcurrency: 100, + CacheMissCost: 100 * time.Millisecond, + CacheSizeBytes: 20000, + FileSizeBytes: 100, + NumRequests: 10000, + TotalNumFiles: 500, + NumExistingFiles: 0, + }, + { + Name: "AllGC_Medium", + MaxConcurrency: 100, + CacheMissCost: 100 * time.Millisecond, + CacheSizeBytes: 2000000, + FileSizeBytes: 10000, + NumRequests: 10000, + TotalNumFiles: 500, + NumExistingFiles: 0, + }, + { + Name: "AllGC_Large", + MaxConcurrency: 100, + CacheMissCost: 100 * time.Millisecond, + CacheSizeBytes: 2000000000, + FileSizeBytes: 10000000, + NumRequests: 5000, + TotalNumFiles: 500, + NumExistingFiles: 0, + }, + { + Name: "AllCacheHits_Small", + MaxConcurrency: 100, + CacheMissCost: 100 * time.Millisecond, + CacheSizeBytes: 50000, + FileSizeBytes: 100, + NumRequests: 20000, + TotalNumFiles: 500, + NumExistingFiles: 500, + }, + { + Name: "AllCacheHits_Medium", + MaxConcurrency: 100, + CacheMissCost: 100 * time.Millisecond, + CacheSizeBytes: 50000000, + FileSizeBytes: 100000, + NumRequests: 20000, + TotalNumFiles: 500, + NumExistingFiles: 500, + }, + { + Name: "AllCacheHits_Large", + MaxConcurrency: 100, + CacheMissCost: 100 * time.Millisecond, + CacheSizeBytes: 5000000000, + FileSizeBytes: 10000000, + NumRequests: 5000, + TotalNumFiles: 500, + NumExistingFiles: 500, + }, +} + +func getFilename(i int) string { + return fmt.Sprintf("f_%05d", i) +} + +func TestRunAllBenchmarks(t *testing.T) { + for _, b := range kBenchmarks { + t.Run(b.Name, func(t *testing.T) { + root := t.TempDir() + fmt.Printf("Initializing source files for benchmark %s...\n", b.Name) + start := time.Now() + source := filepath.Join(root, "source") + if err := os.MkdirAll(source, 0777); err != nil { + t.Fatalf("%v", err) + } + dgs := make([]digest.Digest, b.TotalNumFiles) + for i := 0; i < b.TotalNumFiles; i++ { + filename := filepath.Join(source, getFilename(i)) + var s strings.Builder + fmt.Fprintf(&s, "%d\n", i) + for k := s.Len(); k < b.FileSizeBytes; k++ { + s.WriteByte(0) + } + blob := []byte(s.String()) + dgs[i] = digest.NewFromBlob(blob) + if err := os.WriteFile(filename, blob, 00666); err != nil { + t.Fatalf("os.WriteFile(%s): %v", filename, err) + } + } + + cacheDir := filepath.Join(root, "cache") + d, err := New(context.Background(), cacheDir, b.CacheSizeBytes) + if err != nil { + t.Fatalf("New: %v", err) + } + // Pre-populate the cache if requested: + if b.NumExistingFiles > 0 { + fmt.Printf("Pre-warming cache for benchmark %s...\n", b.Name) + for i := 0; i < b.NumExistingFiles; i++ { + fname := filepath.Join(source, getFilename(i)) + if err := d.StoreCas(dgs[i], fname); err != nil { + t.Fatalf("StoreCas(%s, %s) failed: %v", dgs[i], fname, err) + } + } + d.Shutdown() + d, err = New(context.Background(), cacheDir, b.CacheSizeBytes) + if err != nil { + t.Fatalf("New: %v", err) + } + } + // Run the simulation: store on every cache miss. 
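+			// Requests cycle over the digests round-robin (k % TotalNumFiles);
+			// each miss sleeps for CacheMissCost to model the avoided remote
+			// fetch, then re-stores the blob from the source directory.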
+ new := filepath.Join(root, "new") + if err := os.MkdirAll(new, 0777); err != nil { + t.Fatalf("%v", err) + } + fmt.Printf("Finished initialization for benchmark %s, duration %v\n", b.Name, time.Since(start)) + eg := errgroup.Group{} + eg.SetLimit(b.MaxConcurrency) + fmt.Printf("Starting benchmark %s...\n", b.Name) + start = time.Now() + for k := 0; k < b.NumRequests; k++ { + k := k + eg.Go(func() error { + i := k % b.TotalNumFiles + newName := filepath.Join(new, getFilename(k)) + if d.LoadCas(dgs[i], newName) { + if dg, err := digest.NewFromFile(newName); dg != dgs[i] || err != nil { + return fmt.Errorf("%d: err %v or digest mismatch %v vs %v", k, err, dg, dgs[i]) + } + } else { + time.Sleep(b.CacheMissCost) + if err := d.StoreCas(dgs[i], filepath.Join(source, getFilename(i))); err != nil { + return fmt.Errorf("StoreCas: %v", err) + } + } + return nil + }) + } + if err := eg.Wait(); err != nil { + t.Fatalf("%v", err) + } + d.Shutdown() + fmt.Printf("Finished benchmark %s, total duration %v, stats:\n%+v\n", b.Name, time.Since(start), d.GetStats()) + }) + } +} diff --git a/go/pkg/diskcache/diskcache_test.go b/go/pkg/diskcache/diskcache_test.go new file mode 100644 index 000000000..9c6a6b8f6 --- /dev/null +++ b/go/pkg/diskcache/diskcache_test.go @@ -0,0 +1,573 @@ +package diskcache + +import ( + "context" + "fmt" + "math/rand" + "os" + "path/filepath" + "sync/atomic" + "testing" + "time" + + "github.com/bazelbuild/remote-apis-sdks/go/pkg/digest" + "github.com/bazelbuild/remote-apis-sdks/go/pkg/testutil" + "github.com/google/go-cmp/cmp" + "github.com/google/uuid" + "golang.org/x/sync/errgroup" + "google.golang.org/protobuf/proto" + + repb "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2" +) + +func TestStoreLoadCasPerm(t *testing.T) { + tests := []struct { + name string + executable bool + }{ + { + name: "+X", + executable: true, + }, + { + name: "-X", + executable: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + root := t.TempDir() + d, err := New(context.Background(), filepath.Join(root, "cache"), 20) + if err != nil { + t.Errorf("New: %v", err) + } + fname, _ := testutil.CreateFile(t, tc.executable, "12345") + srcInfo, err := os.Stat(fname) + if err != nil { + t.Fatalf("os.Stat() failed: %v", err) + } + dg, err := digest.NewFromFile(fname) + if err != nil { + t.Fatalf("digest.NewFromFile failed: %v", err) + } + if err := d.StoreCas(dg, fname); err != nil { + t.Errorf("StoreCas(%s, %s) failed: %v", dg, fname, err) + } + newName := filepath.Join(root, "new") + if !d.LoadCas(dg, newName) { + t.Errorf("expected to load %s from the cache to %s", dg, newName) + } + fileInfo, err := os.Stat(newName) + if err != nil { + t.Fatalf("os.Stat(%s) failed: %v", newName, err) + } + if fileInfo.Mode() != srcInfo.Mode() { + t.Errorf("expected %s to have %v permissions, got: %v", newName, srcInfo.Mode(), fileInfo.Mode()) + } + contents, err := os.ReadFile(newName) + if err != nil { + t.Errorf("error reading from %s: %v", newName, err) + } + if string(contents) != "12345" { + t.Errorf("Cached result did not match: want %q, got %q", "12345", string(contents)) + } + d.Shutdown() + stats := d.GetStats() + if stats.TotalNumFiles != 1 { + t.Errorf("expected TotalNumFiles to be 1, got %d", stats.TotalNumFiles) + } + if stats.NumFilesStored != 1 { + t.Errorf("expected NumFilesStored to be 1, got %d", stats.NumFilesStored) + } + if stats.TotalStoredBytes != 5 { + t.Errorf("expected TotalStoredBytes to be 5, got %d", stats.TotalStoredBytes) + } + if 
stats.NumCacheHits != 1 { + t.Errorf("expected NumCacheHits to be 1, got %d", stats.NumCacheHits) + } + if stats.TotalCacheHitSizeBytes != 5 { + t.Errorf("expected TotalCacheHitSizeBytes to be 5, got %d", stats.TotalCacheHitSizeBytes) + } + if stats.NumCacheMisses != 0 { + t.Errorf("expected NumCacheMisses to be 0, got %d", stats.NumCacheMisses) + } + if stats.TotalSizeBytes != 5 { + t.Errorf("expected TotalSizeBytes to be 5, got %d", stats.TotalSizeBytes) + } + if stats.NumFilesGCed != 0 { + t.Errorf("expected NumFilesGCed to be 0, got %d", stats.NumFilesGCed) + } + if stats.TotalGCedSizeBytes != 0 { + t.Errorf("expected TotalGCedSizeBytes to be 0, got %d", stats.TotalGCedSizeBytes) + } + if stats.InitTime == 0 { + t.Errorf("expected InitTime to be > 0") + } + if stats.TotalGCTime != 0 { + t.Errorf("expected TotalGCTime to be 0") + } + }) + } +} + +func TestLoadCasNotFound(t *testing.T) { + root := t.TempDir() + d, err := New(context.Background(), filepath.Join(root, "cache"), 20) + if err != nil { + t.Errorf("New: %v", err) + } + newName := filepath.Join(root, "new") + dg := digest.NewFromBlob([]byte("bla")) + if d.LoadCas(dg, newName) { + t.Errorf("expected to not load %s from the cache to %s", dg, newName) + } + d.Shutdown() + stats := d.GetStats() + if stats.TotalNumFiles != 0 { + t.Errorf("expected TotalNumFiles to be 0, got %d", stats.TotalNumFiles) + } + if stats.TotalSizeBytes != 0 { + t.Errorf("expected TotalSizeBytes to be 0, got %d", stats.TotalSizeBytes) + } + if stats.NumFilesStored != 0 { + t.Errorf("expected NumFilesStored to be 0, got %d", stats.NumFilesStored) + } + if stats.TotalStoredBytes != 0 { + t.Errorf("expected TotalStoredBytes to be 0, got %d", stats.TotalStoredBytes) + } + if stats.NumCacheHits != 0 { + t.Errorf("expected NumCacheHits to be 0, got %d", stats.NumCacheHits) + } + if stats.TotalCacheHitSizeBytes != 0 { + t.Errorf("expected TotalCacheHitSizeBytes to be 0, got %d", stats.TotalCacheHitSizeBytes) + } + if stats.NumCacheMisses != 1 { + t.Errorf("expected NumCacheMisses to be 1, got %d", stats.NumCacheMisses) + } + if stats.NumFilesGCed != 0 { + t.Errorf("expected NumFilesGCed to be 0, got %d", stats.NumFilesGCed) + } + if stats.TotalGCedSizeBytes != 0 { + t.Errorf("expected TotalGCedSizeBytes to be 0, got %d", stats.TotalGCedSizeBytes) + } + if stats.InitTime == 0 { + t.Errorf("expected InitTime to be > 0") + } + if stats.TotalGCTime != 0 { + t.Errorf("expected TotalGCTime to be 0") + } +} + +func TestStoreLoadActionCache(t *testing.T) { + root := t.TempDir() + d, err := New(context.Background(), filepath.Join(root, "cache"), 100) + if err != nil { + t.Errorf("New: %v", err) + } + ar := &repb.ActionResult{ + OutputFiles: []*repb.OutputFile{ + {Path: "bla", Digest: digest.Empty.ToProto()}, + }, + } + dg := digest.NewFromBlob([]byte("foo")) + if err := d.StoreActionCache(dg, ar); err != nil { + t.Errorf("StoreActionCache(%s) failed: %v", dg, err) + } + got, loaded := d.LoadActionCache(dg) + if !loaded { + t.Errorf("expected to load %s from the cache", dg) + } + if diff := cmp.Diff(ar, got, cmp.Comparer(proto.Equal)); diff != "" { + t.Errorf("LoadActionCache(...) 
gave diff on action result (-want +got):\n%s", diff) + } + d.Shutdown() + stats := d.GetStats() + if stats.TotalNumFiles != 1 { + t.Errorf("expected TotalNumFiles to be 1, got %d", stats.TotalNumFiles) + } + bytes, err := proto.Marshal(ar) + if err != nil { + t.Fatalf("error marshalling proto: %v", err) + } + size := int64(len(bytes)) + if stats.TotalSizeBytes != size { + t.Errorf("expected TotalSizeBytes to be %d, got %d", size, stats.TotalSizeBytes) + } + if stats.NumFilesStored != 1 { + t.Errorf("expected NumFilesStored to be 1, got %d", stats.NumFilesStored) + } + if stats.TotalStoredBytes != size { + t.Errorf("expected TotalStoredBytes to be %d, got %d", size, stats.TotalStoredBytes) + } + if stats.NumCacheHits != 1 { + t.Errorf("expected NumCacheHits to be 1, got %d", stats.NumCacheHits) + } + if stats.TotalCacheHitSizeBytes != size { + t.Errorf("expected TotalCacheHitSizeBytes to be %d, got %d", size, stats.TotalCacheHitSizeBytes) + } + if stats.NumCacheMisses != 0 { + t.Errorf("expected NumCacheMisses to be 0, got %d", stats.NumCacheMisses) + } + if stats.NumFilesGCed != 0 { + t.Errorf("expected NumFilesGCed to be 0, got %d", stats.NumFilesGCed) + } + if stats.TotalGCedSizeBytes != 0 { + t.Errorf("expected TotalGCedSizeBytes to be 0, got %d", stats.TotalGCedSizeBytes) + } + if stats.InitTime == 0 { + t.Errorf("expected InitTime to be > 0") + } + if stats.TotalGCTime != 0 { + t.Errorf("expected TotalGCTime to be 0") + } +} + +func TestGcOldestCas(t *testing.T) { + root := t.TempDir() + d, err := New(context.Background(), filepath.Join(root, "cache"), 20) + if err != nil { + t.Errorf("New: %v", err) + } + for i := 0; i < 5; i++ { + fname, _ := testutil.CreateFile(t, false, fmt.Sprintf("aaa %d", i)) + dg, err := digest.NewFromFile(fname) + if err != nil { + t.Fatalf("digest.NewFromFile failed: %v", err) + } + if err := d.StoreCas(dg, fname); err != nil { + t.Errorf("StoreCas(%s, %s) failed: %v", dg, fname, err) + } + } + d.Shutdown() + newName := filepath.Join(root, "new") + for i := 0; i < 5; i++ { + dg := digest.NewFromBlob([]byte(fmt.Sprintf("aaa %d", i))) + if d.LoadCas(dg, newName) != (i > 0) { + t.Errorf("expected loaded to be %v for %s from the cache to %s", i > 0, dg, newName) + } + } + stats := d.GetStats() + if stats.TotalNumFiles != 4 { + t.Errorf("expected TotalNumFiles to be 4, got %d", stats.TotalNumFiles) + } + if stats.NumFilesStored != 5 { + t.Errorf("expected NumFilesStored to be 5, got %d", stats.NumFilesStored) + } + if stats.TotalStoredBytes != 25 { + t.Errorf("expected TotalStoredBytes to be 25, got %d", stats.TotalStoredBytes) + } + if stats.NumCacheHits != 4 { + t.Errorf("expected NumCacheHits to be 4, got %d", stats.NumCacheHits) + } + if stats.TotalCacheHitSizeBytes != 20 { + t.Errorf("expected TotalCacheHitSizeBytes to be 20, got %d", stats.TotalCacheHitSizeBytes) + } + if stats.NumCacheMisses != 1 { + t.Errorf("expected NumCacheMisses to be 1, got %d", stats.NumCacheMisses) + } + if stats.NumFilesGCed != 1 { + t.Errorf("expected NumFilesGCed to be 1, got %d", stats.NumFilesGCed) + } + if stats.TotalGCedSizeBytes != 5 { + t.Errorf("expected TotalGCedSizeBytes to be 5, got %d", stats.TotalGCedSizeBytes) + } + if uint64(stats.TotalSizeBytes) != d.maxCapacityBytes { + t.Errorf("expected total size bytes to be %d, got %d", d.maxCapacityBytes, stats.TotalSizeBytes) + } + if stats.InitTime <= 0 { + t.Errorf("expected InitTime to be > 0") + } + if stats.TotalGCTime <= 0 { + t.Errorf("expected TotalGCTime to be > 0") + } +} + +func TestGcOldestActionCache(t 
*testing.T) { + ar := &repb.ActionResult{ + OutputFiles: []*repb.OutputFile{ + {Path: "12345", Digest: digest.Empty.ToProto()}, + }, + } + bytes, err := proto.Marshal(ar) + if err != nil { + t.Fatalf("error marshalling proto: %v", err) + } + size := len(bytes) + root := t.TempDir() + d, err := New(context.Background(), filepath.Join(root, "cache"), uint64(size)*4) + if err != nil { + t.Errorf("New: %v", err) + } + for i := 0; i < 5; i++ { + si := fmt.Sprintf("aaa %d", i) + dg := digest.NewFromBlob([]byte(si)) + ar.OutputFiles[0].Path = si + if err := d.StoreActionCache(dg, ar); err != nil { + t.Errorf("StoreActionCache(%s) failed: %v", dg, err) + } + } + d.Shutdown() + for i := 0; i < 5; i++ { + si := fmt.Sprintf("aaa %d", i) + dg := digest.NewFromBlob([]byte(si)) + got, loaded := d.LoadActionCache(dg) + if loaded { + ar.OutputFiles[0].Path = si + if diff := cmp.Diff(ar, got, cmp.Comparer(proto.Equal)); diff != "" { + t.Errorf("LoadActionCache(...) gave diff on action result (-want +got):\n%s", diff) + } + } + if loaded != (i > 0) { + t.Errorf("expected loaded to be %v for %s from the cache", i > 0, dg) + } + } + stats := d.GetStats() + if stats.TotalNumFiles != 4 { + t.Errorf("expected TotalNumFiles to be 4, got %d", stats.TotalNumFiles) + } + if stats.NumFilesStored != 5 { + t.Errorf("expected NumFilesStored to be 5, got %d", stats.NumFilesStored) + } + if stats.TotalStoredBytes != int64(size)*5 { + t.Errorf("expected TotalStoredBytes to be %d, got %d", size*5, stats.TotalStoredBytes) + } + if stats.NumCacheHits != 4 { + t.Errorf("expected NumCacheHits to be 4, got %d", stats.NumCacheHits) + } + if stats.TotalCacheHitSizeBytes != int64(size)*4 { + t.Errorf("expected TotalCacheHitSizeBytes to be %d, got %d", size*4, stats.TotalCacheHitSizeBytes) + } + if stats.NumCacheMisses != 1 { + t.Errorf("expected NumCacheMisses to be 1, got %d", stats.NumCacheMisses) + } + if stats.NumFilesGCed != 1 { + t.Errorf("expected NumFilesGCed to be 1, got %d", stats.NumFilesGCed) + } + if stats.TotalGCedSizeBytes != int64(size) { + t.Errorf("expected TotalGCedSizeBytes to be %d, got %d", size, stats.TotalGCedSizeBytes) + } + if uint64(stats.TotalSizeBytes) != d.maxCapacityBytes { + t.Errorf("expected total size bytes to be %d, got %d", d.maxCapacityBytes, stats.TotalSizeBytes) + } + if stats.InitTime <= 0 { + t.Errorf("expected InitTime to be > 0") + } + if stats.TotalGCTime <= 0 { + t.Errorf("expected TotalGCTime to be > 0") + } +} + +func getLastAccessTime(path string) (time.Time, error) { + info, err := os.Stat(path) + if err != nil { + return time.Time{}, err + } + return fileInfoToAccessTime(info), nil +} + +// We say that Last Access Time is behaving accurately on a system if reading from the file +// bumps the LAT time forward. From experience, Mac and Linux Debian are accurate. Ubuntu -- not. +// From experience, even when the LAT gets modified on access on Ubuntu, it can be imprecise to +// an order of seconds (!). +func isSystemLastAccessTimeAccurate(t *testing.T) bool { + t.Helper() + fname, _ := testutil.CreateFile(t, false, "foo") + lat, _ := getLastAccessTime(fname) + if _, err := os.ReadFile(fname); err != nil { + t.Fatalf("%v", err) + } + newLat, _ := getLastAccessTime(fname) + return lat.Before(newLat) +} + +func TestInitFromExistingCas(t *testing.T) { + if !isSystemLastAccessTimeAccurate(t) { + // This effectively skips the test on Ubuntu, because to make the test work there, + // we would need to inject too many / too long time.Sleep statements to beat the system's + // inaccuracy. 
+ t.Logf("Skipping TestInitFromExisting, because system Last Access Time is unreliable.") + return + } + root := t.TempDir() + d, err := New(context.Background(), filepath.Join(root, "cache"), 20) + if err != nil { + t.Errorf("New: %v", err) + } + for i := 0; i < 4; i++ { + fname, _ := testutil.CreateFile(t, false, fmt.Sprintf("aaa %d", i)) + dg, err := digest.NewFromFile(fname) + if err != nil { + t.Fatalf("digest.NewFromFile failed: %v", err) + } + if err := d.StoreCas(dg, fname); err != nil { + t.Errorf("StoreCas(%s, %s) failed: %v", dg, fname, err) + } + } + newName := filepath.Join(root, "new") + dg := digest.NewFromBlob([]byte("aaa 0")) + if !d.LoadCas(dg, newName) { // Now 0 has been accessed, 1 is the oldest file. + t.Errorf("expected %s to be cached", dg) + } + d.Shutdown() + + // Re-initialize from existing files. + d, err = New(context.Background(), filepath.Join(root, "cache"), 20) + if err != nil { + t.Errorf("New: %v", err) + } + + // Check old files are cached: + dg = digest.NewFromBlob([]byte("aaa 1")) + if !d.LoadCas(dg, newName) { // Now 1 has been accessed, 2 is the oldest file. + t.Errorf("expected %s to be cached", dg) + } + fname, _ := testutil.CreateFile(t, false, "aaa 4") + dg, err = digest.NewFromFile(fname) + if err != nil { + t.Fatalf("digest.NewFromFile failed: %v", err) + } + // Trigger a GC by adding a new file. + if err := d.StoreCas(dg, fname); err != nil { + t.Errorf("StoreCas(%s, %s) failed: %v", dg, fname, err) + } + d.Shutdown() + dg = digest.NewFromBlob([]byte("aaa 2")) + if d.LoadCas(dg, newName) { + t.Errorf("expected to not load %s from the cache to %s", dg, newName) + } + stats := d.GetStats() + if stats.TotalNumFiles != 4 { + t.Errorf("expected TotalNumFiles to be 4, got %d", stats.TotalNumFiles) + } + if stats.NumFilesStored != 1 { + t.Errorf("expected NumFilesStored to be 1, got %d", stats.NumFilesStored) + } + if stats.TotalStoredBytes != 5 { + t.Errorf("expected TotalStoredBytes to be 5, got %d", stats.TotalStoredBytes) + } + if stats.NumCacheHits != 1 { + t.Errorf("expected NumCacheHits to be 1, got %d", stats.NumCacheHits) + } + if stats.TotalCacheHitSizeBytes != 5 { + t.Errorf("expected TotalCacheHitSizeBytes to be 5, got %d", stats.TotalCacheHitSizeBytes) + } + if stats.NumCacheMisses != 1 { + t.Errorf("expected NumCacheMisses to be 1, got %d", stats.NumCacheMisses) + } + if stats.NumFilesGCed != 1 { + t.Errorf("expected NumFilesGCed to be 1, got %d", stats.NumFilesGCed) + } + if stats.TotalGCedSizeBytes != 5 { + t.Errorf("expected TotalGCedSizeBytes to be 5, got %d", stats.TotalGCedSizeBytes) + } + if uint64(stats.TotalSizeBytes) != d.maxCapacityBytes { + t.Errorf("expected total size bytes to be %d, got %d", d.maxCapacityBytes, stats.TotalSizeBytes) + } + if stats.InitTime <= 0 { + t.Errorf("expected InitTime to be > 0") + } + if stats.TotalGCTime <= 0 { + t.Errorf("expected TotalGCTime to be > 0") + } +} + +func TestThreadSafetyCas(t *testing.T) { + root := t.TempDir() + if err := os.MkdirAll(filepath.Join(root, "orig"), os.ModePerm); err != nil { + t.Fatalf("%v", err) + } + if err := os.MkdirAll(filepath.Join(root, "new"), os.ModePerm); err != nil { + t.Fatalf("%v", err) + } + nFiles := 10 + attempts := 5000 + // All blobs are size 5 exactly. We will have half the byte capacity we need. 
+	d, err := New(context.Background(), filepath.Join(root, "cache"), uint64(nFiles*5)/2)
+	if err != nil {
+		t.Errorf("New: %v", err)
+	}
+	var files []string
+	var dgs []digest.Digest
+	for i := 0; i < nFiles; i++ {
+		fname := filepath.Join(root, "orig", fmt.Sprintf("%d", i))
+		if err := os.WriteFile(fname, []byte(fmt.Sprintf("aa %02d", i)), 0644); err != nil {
+			t.Fatalf("os.WriteFile: %v", err)
+		}
+		files = append(files, fname)
+		dg, err := digest.NewFromFile(fname)
+		if err != nil {
+			t.Fatalf("digest.NewFromFile failed: %v", err)
+		}
+		dgs = append(dgs, dg)
+		if err := d.StoreCas(dg, fname); err != nil {
+			t.Errorf("StoreCas(%s, %s) failed: %v", dg, fname, err)
+		}
+	}
+	// Randomly access and store files from different threads.
+	eg := errgroup.Group{}
+	var hits uint64
+	for k := 0; k < attempts; k++ {
+		eg.Go(func() error {
+			i := rand.Intn(nFiles)
+			newName := filepath.Join(root, "new", uuid.New().String())
+			if d.LoadCas(dgs[i], newName) {
+				atomic.AddUint64(&hits, 1)
+				contents, err := os.ReadFile(newName)
+				if err != nil {
+					return fmt.Errorf("os.ReadFile: %v", err)
+				}
+				want := fmt.Sprintf("aa %02d", i)
+				if string(contents) != want {
+					return fmt.Errorf("Cached result did not match: want %q, got %q for digest %v", want, string(contents), dgs[i])
+				}
+			} else if err := d.StoreCas(dgs[i], files[i]); err != nil {
+				return fmt.Errorf("StoreCas: %v", err)
+			}
+			return nil
+		})
+	}
+	if err := eg.Wait(); err != nil {
+		t.Error(err)
+	}
+	d.Shutdown()
+	if int(hits) < attempts/2 {
+		t.Errorf("Unexpectedly low cache hits %d out of %d attempts", hits, attempts)
+	}
+	stats := d.GetStats()
+	if stats.TotalNumFiles != 5 {
+		t.Errorf("expected TotalNumFiles to be 5, got %d", stats.TotalNumFiles)
+	}
+	if uint64(stats.NumCacheHits) != hits {
+		t.Errorf("expected NumCacheHits to be %d, got %d", hits, stats.NumCacheHits)
+	}
+	if uint64(stats.TotalCacheHitSizeBytes) != hits*5 {
+		t.Errorf("expected TotalCacheHitSizeBytes to be %d, got %d", hits*5, stats.TotalCacheHitSizeBytes)
+	}
+	if stats.NumCacheMisses+stats.NumCacheHits != int64(attempts) {
+		t.Errorf("expected NumCacheHits+NumCacheMisses to be %d, got %d", attempts, stats.NumCacheMisses+stats.NumCacheHits)
+	}
+	// This is less or equal because of multiple concurrent Stores.
+	if stats.NumFilesStored > int64(nFiles)+stats.NumCacheMisses {
+		t.Errorf("expected NumFilesStored to be <= %d, got %d", int64(nFiles)+stats.NumCacheMisses, stats.NumFilesStored)
+	}
+	if stats.TotalStoredBytes != 5*stats.NumFilesStored {
+		t.Errorf("expected TotalStoredBytes to be %d, got %d", 5*stats.NumFilesStored, stats.TotalStoredBytes)
+	}
+	if stats.NumFilesGCed <= 0 {
+		t.Errorf("expected NumFilesGCed to be > 0")
+	}
+	if stats.TotalGCedSizeBytes <= 0 {
+		t.Errorf("expected TotalGCedSizeBytes to be > 0")
+	}
+	if uint64(stats.TotalSizeBytes) != d.maxCapacityBytes {
+		t.Errorf("expected total size bytes to be %d, got %d", d.maxCapacityBytes, stats.TotalSizeBytes)
+	}
+	if stats.InitTime <= 0 {
+		t.Errorf("expected InitTime to be > 0")
+	}
+	if stats.TotalGCTime <= 0 {
+		t.Errorf("expected TotalGCTime to be > 0")
+	}
+}
diff --git a/go/pkg/diskcache/sys_darwin.go b/go/pkg/diskcache/sys_darwin.go
new file mode 100644
index 000000000..0c98dceb5
--- /dev/null
+++ b/go/pkg/diskcache/sys_darwin.go
@@ -0,0 +1,13 @@
+// Utility to get the last accessed time on Darwin.
+// System utilities that differ between OS implementations.
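+// Each platform reads the access timestamp from a different stat field:
+// Stat_t.Atimespec here on Darwin, Stat_t.Atim on Linux, and
+// Win32FileAttributeData.LastAccessTime on Windows.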
+package diskcache + +import ( + "io/fs" + "syscall" + "time" +) + +func fileInfoToAccessTime(info fs.FileInfo) time.Time { + return time.Unix(info.Sys().(*syscall.Stat_t).Atimespec.Unix()) +} diff --git a/go/pkg/diskcache/sys_linux.go b/go/pkg/diskcache/sys_linux.go new file mode 100644 index 000000000..27f33d588 --- /dev/null +++ b/go/pkg/diskcache/sys_linux.go @@ -0,0 +1,12 @@ +// System utilities that differ between OS implementations. +package diskcache + +import ( + "io/fs" + "syscall" + "time" +) + +func fileInfoToAccessTime(info fs.FileInfo) time.Time { + return time.Unix(info.Sys().(*syscall.Stat_t).Atim.Unix()) +} diff --git a/go/pkg/diskcache/sys_windows.go b/go/pkg/diskcache/sys_windows.go new file mode 100644 index 000000000..b0621f64a --- /dev/null +++ b/go/pkg/diskcache/sys_windows.go @@ -0,0 +1,14 @@ +// System utilities that differ between OS implementations. +package diskcache + +import ( + "io/fs" + "syscall" + "time" +) + +// This will return correct values only if `fsutil behavior set disablelastaccess 0` is set. +// Tracking of last access time is disabled by default on Windows. +func fileInfoToAccessTime(info fs.FileInfo) time.Time { + return time.Unix(0, info.Sys().(*syscall.Win32FileAttributeData).LastAccessTime.Nanoseconds()) +} diff --git a/go/pkg/flags/BUILD.bazel b/go/pkg/flags/BUILD.bazel index 0572a941a..92bb117b4 100644 --- a/go/pkg/flags/BUILD.bazel +++ b/go/pkg/flags/BUILD.bazel @@ -8,6 +8,7 @@ go_library( deps = [ "//go/pkg/balancer", "//go/pkg/client", + "//go/pkg/diskcache", "//go/pkg/moreflag", "@com_github_golang_glog//:go_default_library", "@org_golang_google_grpc//:go_default_library", diff --git a/go/pkg/flags/flags.go b/go/pkg/flags/flags.go index 75a64b7c4..a6b2377b9 100644 --- a/go/pkg/flags/flags.go +++ b/go/pkg/flags/flags.go @@ -8,6 +8,7 @@ import ( "github.com/bazelbuild/remote-apis-sdks/go/pkg/balancer" "github.com/bazelbuild/remote-apis-sdks/go/pkg/client" + "github.com/bazelbuild/remote-apis-sdks/go/pkg/diskcache" "github.com/bazelbuild/remote-apis-sdks/go/pkg/moreflag" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" @@ -79,6 +80,10 @@ var ( UseRoundRobinBalancer = flag.Bool("use_round_robin_balancer", false, "If true, a round-robin connection bool is used for gRPC. Otherwise, the existing load balancer is used.") // RoundRobinBalancerPoolSize specifies the pool size for the round robin balancer. RoundRobinBalancerPoolSize = flag.Int("round_robin_balancer_pool_size", client.DefaultMaxConcurrentRequests, "pool size for round robin grpc balacner") + // DiskCachePath, if set, adds a local disk cache for downloaded outputs and action cache results. + DiskCachePath = flag.String("disk_cache_path", "", "If set, will use a local disk cache for downloaded outputs.") + // DiskCacheCapacityGb specifies a maximum limit in GB to store in the local disk cache. A no-op if --disk_cache_path is not set. + DiskCacheCapacityGb = flag.Float64("disk_cache_max_gb", 1.0, "Maximum GB to store in the local disk cache. 
A no-op if --disk_cache_path is not set.") ) func init() { @@ -132,6 +137,14 @@ func NewClientFromFlags(ctx context.Context, opts ...client.Opt) (*client.Client log.V(1).Infof("KeepAlive params = %v", params) dialOpts = append(dialOpts, grpc.WithKeepaliveParams(params)) } + if *DiskCachePath != "" { + capBytes := uint64(*DiskCacheCapacityGb * 1024 * 1024 * 1024) + diskCache, err := diskcache.New(ctx, *DiskCachePath, capBytes) + if err != nil { + return nil, err + } + opts = append(opts, &client.DiskCacheOpts{DiskCache: diskCache}) + } return client.NewClient(ctx, *Instance, client.DialParams{ Service: *Service, NoSecurity: *ServiceNoSecurity, diff --git a/go/pkg/rexec/rexec.go b/go/pkg/rexec/rexec.go index 18ecf6dcf..87b732340 100644 --- a/go/pkg/rexec/rexec.go +++ b/go/pkg/rexec/rexec.go @@ -274,7 +274,7 @@ func (ec *Context) GetCachedResult() { } if ec.opt.AcceptCached && !ec.opt.DoNotCache { ec.Metadata.EventTimes[command.EventCheckActionCache] = &command.TimeInterval{From: time.Now()} - resPb, err := ec.client.GrpcClient.CheckActionCache(ec.ctx, ec.Metadata.ActionDigest.ToProto()) + resPb, err := ec.client.GrpcClient.CheckActionCache(ec.ctx, ec.Metadata.ActionDigest) ec.Metadata.EventTimes[command.EventCheckActionCache].To = time.Now() if err != nil { ec.Result = command.NewRemoteErrorResult(err) diff --git a/go/pkg/tool/tool.go b/go/pkg/tool/tool.go index 604a0c71b..41089b9a3 100644 --- a/go/pkg/tool/tool.go +++ b/go/pkg/tool/tool.go @@ -855,11 +855,7 @@ func (c *Client) getActionResult(ctx context.Context, actionDigest string) (*rep if err != nil { return nil, err } - d := &repb.Digest{ - Hash: acDg.Hash, - SizeBytes: acDg.Size, - } - resPb, err := c.GrpcClient.CheckActionCache(ctx, d) + resPb, err := c.GrpcClient.CheckActionCache(ctx, acDg) if err != nil { return nil, err }
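For context, a minimal sketch of how the pieces above wire together; the binary scaffolding, flag values, and error handling are illustrative and not part of this change:

package main

import (
	"context"
	"flag"
	"log"

	"github.com/bazelbuild/remote-apis-sdks/go/pkg/flags"
)

func main() {
	// Passing --disk_cache_path=/tmp/rbe_cache --disk_cache_max_gb=2 makes
	// NewClientFromFlags build a diskcache.New with the capacity converted to
	// bytes and attach it via client.DiskCacheOpts; DownloadOutputs and
	// CheckActionCache then consult the local cache before the remote backend.
	flag.Parse()
	c, err := flags.NewClientFromFlags(context.Background())
	if err != nil {
		log.Fatalf("NewClientFromFlags: %v", err)
	}
	defer c.Close() // Close also shuts down the disk cache, waiting for GC to finish.
}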