Skip to content

Commit

Permalink
Use a sorted list instead of a bunch of channels for tracking masquer…
Browse files Browse the repository at this point in the history
…ades
  • Loading branch information
oxtoacart committed May 31, 2023
1 parent 071d192 commit 3cf7ac1
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 237 deletions.
142 changes: 57 additions & 85 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,120 +6,92 @@ import (
"time"
)

type cacheOp struct {
m masquerade
remove bool
close bool
func (d *direct) initCaching(cacheFile string) {
d.prepopulateMasquerades(cacheFile)
go d.maintainCache(cacheFile)
}

func (d *direct) initCaching(cacheFile string) int {
cache := d.prepopulateMasquerades(cacheFile)
prevetted := len(cache)
go d.fillCache(cache, cacheFile)
return prevetted
}

func (d *direct) prepopulateMasquerades(cacheFile string) []masquerade {
var cache []masquerade
func (d *direct) prepopulateMasquerades(cacheFile string) {
bytes, err := ioutil.ReadFile(cacheFile)
if err != nil {
// This is not a big deal since we'll just fill the cache later
log.Debugf("ignorable error: Unable to read cache file for prepoulation.: %v", err)
return nil
log.Debugf("ignorable error: Unable to read cache file for prepopulation: %v", err)
return
}

if len(bytes) == 0 {
// This can happen if the file is empty or just not there
log.Debug("ignorable error: Cache file is empty")
return nil
return
}

log.Debugf("Attempting to prepopulate masquerades from cache file: %v", cacheFile)
var masquerades []masquerade
if err := json.Unmarshal(bytes, &masquerades); err != nil {
log.Errorf("Error prepopulating cached masquerades: %v", err)
return cache
var cachedMasquerades []*masquerade
if err := json.Unmarshal(bytes, &cachedMasquerades); err != nil {
log.Errorf("Error reading cached masquerades: %v", err)
return
}

log.Debugf("Cache contained %d masquerades", len(masquerades))
log.Debugf("Cache contained %d masquerades", len(cachedMasquerades))
now := time.Now()
for _, m := range masquerades {
if now.Sub(m.LastVetted) < d.maxAllowedCachedAge {
// fill in default for masquerades lacking provider id
if m.ProviderID == "" {
m.ProviderID = d.defaultProviderID
}
// Skip entries for providers that are not configured.
_, ok := d.providers[m.ProviderID]
if !ok {
log.Debugf("Skipping cached entry for unknown/disabled provider %s", m.ProviderID)
continue
}
select {
case d.cached <- m:
// submitted
cache = append(cache, m)
default:
// channel full, that's okay

// update last succeeded status of masquerades based on cached values
for _, m := range d.masquerades {
for _, cm := range cachedMasquerades {
sameMasquerade := cm.ProviderID == m.ProviderID && cm.Domain == m.Domain && cm.IpAddress == m.IpAddress
cachedValueFresh := now.Sub(m.LastSucceeded) < d.maxAllowedCachedAge
if sameMasquerade && cachedValueFresh {
m.LastSucceeded = cm.LastSucceeded
}
}
}
}

return cache
func (d *direct) markCacheDirty() {
select {
case d.cacheDirty <- nil:
// okay
default:
// already dirty
}
}

func (d *direct) fillCache(cache []masquerade, cacheFile string) {
saveTicker := time.NewTicker(d.cacheSaveInterval)
defer saveTicker.Stop()
cacheChanged := false
func (d *direct) maintainCache(cacheFile string) {
for {
select {
case op := <-d.toCache:
if op.close {
log.Debug("Cache closed, stop filling")
case <-d.cacheClosed:
return
case <-time.After(d.cacheSaveInterval):
select {
case <-d.cacheClosed:
return
case <-d.cacheDirty:
d.updateCache(cacheFile)
}
m := op.m
if op.remove {
newCache := make([]masquerade, len(cache))
for _, existing := range cache {
if existing.Domain == m.Domain && existing.IpAddress == m.IpAddress {
log.Debugf("Removing masquerade for %v (%v)", m.Domain, m.IpAddress)
} else {
newCache = append(newCache, existing)
}
}
cache = newCache
} else {
log.Debugf("Caching vetted masquerade for %v (%v)", m.Domain, m.IpAddress)
cache = append(cache, m)
}
cacheChanged = true
case <-saveTicker.C:
if !cacheChanged {
continue
}
log.Debug("Saving updated masquerade cache")
// Truncate cache to max length if necessary
if len(cache) > d.maxCacheSize {
truncated := make([]masquerade, d.maxCacheSize)
copy(truncated, cache[len(cache)-d.maxCacheSize:])
cache = truncated
}
b, err := json.Marshal(cache)
if err != nil {
log.Errorf("Unable to marshal cache to JSON: %v", err)
break
}
err = ioutil.WriteFile(cacheFile, b, 0644)
if err != nil {
log.Errorf("Unable to save cache to disk: %v", err)
}
cacheChanged = false
}
}
}

func (d *direct) updateCache(cacheFile string) {
log.Debugf("Updating cache at %v", cacheFile)
cache := d.masquerades.sortedCopy()
sizeToSave := len(cache)
if d.maxCacheSize < sizeToSave {
sizeToSave = d.maxCacheSize
}
b, err := json.Marshal(cache[:sizeToSave])
if err != nil {
log.Errorf("Unable to marshal cache to JSON: %v", err)
return
}
err = ioutil.WriteFile(cacheFile, b, 0644)
if err != nil {
log.Errorf("Unable to save cache to disk: %v", err)
}
}

func (d *direct) closeCache() {
d.toCache <- &cacheOp{close: true}
d.closeCacheOnce.Do(func() {
close(d.cacheClosed)
})
}
63 changes: 25 additions & 38 deletions cache_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package fronted

import (
"encoding/json"
"io/ioutil"
"os"
"path/filepath"
Expand Down Expand Up @@ -28,46 +29,38 @@ func TestCaching(t *testing.T) {

makeDirect := func() *direct {
d := &direct{
candidates: make(chan masquerade, 1000),
masquerades: make(chan masquerade, 1000),
cached: make(chan masquerade, 1000),
masquerades: make(sortedMasquerades, 0, 1000),
maxAllowedCachedAge: 250 * time.Millisecond,
maxCacheSize: 4,
cacheSaveInterval: 50 * time.Millisecond,
toCache: make(chan *cacheOp, 1000),
cacheDirty: make(chan interface{}, 1),
cacheClosed: make(chan interface{}),
providers: providers,
defaultProviderID: cloudsackID,
}
go d.fillCache(make([]masquerade, 0), cacheFile)
go d.maintainCache(cacheFile)
return d
}

now := time.Now()
ma := masquerade{Masquerade{Domain: "a", IpAddress: "1"}, now, testProviderID}
mb := masquerade{Masquerade{Domain: "b", IpAddress: "2"}, now, testProviderID}
mc := masquerade{Masquerade{Domain: "c", IpAddress: "3"}, now, ""} // defaulted
md := masquerade{Masquerade{Domain: "d", IpAddress: "4"}, now, "sadcloud"} // skipped
mb := &masquerade{Masquerade: Masquerade{Domain: "b", IpAddress: "2"}, LastSucceeded: now, ProviderID: testProviderID}
mc := &masquerade{Masquerade: Masquerade{Domain: "c", IpAddress: "3"}, LastSucceeded: now, ProviderID: ""} // defaulted
md := &masquerade{Masquerade: Masquerade{Domain: "d", IpAddress: "4"}, LastSucceeded: now, ProviderID: "sadcloud"} // skipped

d := makeDirect()
d.toCache <- &cacheOp{m: ma}
d.toCache <- &cacheOp{m: mb}
d.toCache <- &cacheOp{m: mc}
d.toCache <- &cacheOp{m: md}
d.toCache <- &cacheOp{m: ma, remove: true}
d.masquerades = append(d.masquerades, mb, mc, md)

readCached := func() []masquerade {
var result []masquerade
for {
select {
case m := <-d.cached:
result = append(result, m)
default:
return result
}
}
readCached := func() []*masquerade {
var result []*masquerade
b, err := ioutil.ReadFile(cacheFile)
require.NoError(t, err, "Unable to read cache file")
err = json.Unmarshal(b, &result)
require.NoError(t, err, "Unable to unmarshal cache file")
return result
}

// Fill the cache
// Save the cache
d.markCacheDirty()
time.Sleep(d.cacheSaveInterval * 2)
d.closeCache()

Expand All @@ -77,18 +70,12 @@ func TestCaching(t *testing.T) {
d = makeDirect()
d.prepopulateMasquerades(cacheFile)
masquerades := readCached()
require.Len(t, masquerades, 2, "Wrong number of masquerades read")
require.Equal(t, "b", masquerades[0].Domain, "Wrong masquerade at position 0")
require.Equal(t, "2", masquerades[0].IpAddress, "Masquerade at position 0 has wrong IpAddress")
require.Equal(t, testProviderID, masquerades[0].ProviderID, "Masquerade at position 0 has wrong ProviderID")
require.Equal(t, "c", masquerades[1].Domain, "Wrong masquerade at position 0")
require.Equal(t, "3", masquerades[1].IpAddress, "Masquerade at position 1 has wrong IpAddress")
require.Equal(t, cloudsackID, masquerades[1].ProviderID, "Masquerade at position 1 has wrong ProviderID")
d.closeCache()

time.Sleep(d.maxAllowedCachedAge)
d = makeDirect()
d.prepopulateMasquerades(cacheFile)
require.Empty(t, readCached(), "Cache should be empty after masquerades expire")
require.Len(t, masquerades, 3, "Wrong number of masquerades read")
for i, expected := range []*masquerade{mb, mc, md} {
require.Equal(t, expected.Domain, masquerades[i].Domain, "Wrong masquerade at position %d", i)
require.Equal(t, expected.IpAddress, masquerades[i].IpAddress, "Masquerade at position %d has wrong IpAddress", 0)
require.Equal(t, expected.ProviderID, masquerades[i].ProviderID, "Masquerade at position %d has wrong ProviderID", 0)
require.Equal(t, now.Unix(), masquerades[i].LastSucceeded.Unix(), "Masquerade at position %d has wrong LastSucceeded", 0)
}
d.closeCache()
}
17 changes: 8 additions & 9 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ func NewDirect(timeout time.Duration) (http.RoundTripper, bool) {
return DefaultContext.NewDirect(timeout)
}

// CloseCache closes any existing cache file in the default context
func CloseCache() {
DefaultContext.CloseCache()
// Close closes any existing cache file in the default context
func Close() {
DefaultContext.Close()
}

func NewFrontingContext(name string) *FrontingContext {
Expand Down Expand Up @@ -84,13 +84,12 @@ func (fctx *FrontingContext) ConfigureWithHello(pool *x509.CertPool, providers m

d := &direct{
certPool: pool,
candidates: make(chan masquerade, size),
masquerades: make(chan masquerade, size),
cached: make(chan masquerade, size),
masquerades: make(sortedMasquerades, 0, size),
maxAllowedCachedAge: defaultMaxAllowedCachedAge,
maxCacheSize: defaultMaxCacheSize,
cacheSaveInterval: defaultCacheSaveInterval,
toCache: make(chan *cacheOp, defaultMaxCacheSize),
cacheDirty: make(chan interface{}, 1),
cacheClosed: make(chan interface{}),
defaultProviderID: defaultProviderID,
providers: make(map[string]*Provider),
clientHelloID: clientHelloID,
Expand Down Expand Up @@ -122,8 +121,8 @@ func (fctx *FrontingContext) NewDirect(timeout time.Duration) (http.RoundTripper
return instance.(http.RoundTripper), true
}

// CloseCache closes any existing cache file in the default contexxt.
func (fctx *FrontingContext) CloseCache() {
// Close closes any existing cache file in the default contexxt.
func (fctx *FrontingContext) Close() {
_existing, ok := fctx.instance.Get(0)
if ok && _existing != nil {
existing := _existing.(*direct)
Expand Down
Loading

0 comments on commit 3cf7ac1

Please sign in to comment.