diff --git a/cache.go b/cache.go index efe0a57..a1457c5 100644 --- a/cache.go +++ b/cache.go @@ -6,120 +6,92 @@ import ( "time" ) -type cacheOp struct { - m masquerade - remove bool - close bool +func (d *direct) initCaching(cacheFile string) { + d.prepopulateMasquerades(cacheFile) + go d.maintainCache(cacheFile) } -func (d *direct) initCaching(cacheFile string) int { - cache := d.prepopulateMasquerades(cacheFile) - prevetted := len(cache) - go d.fillCache(cache, cacheFile) - return prevetted -} - -func (d *direct) prepopulateMasquerades(cacheFile string) []masquerade { - var cache []masquerade +func (d *direct) prepopulateMasquerades(cacheFile string) { bytes, err := ioutil.ReadFile(cacheFile) if err != nil { // This is not a big deal since we'll just fill the cache later - log.Debugf("ignorable error: Unable to read cache file for prepoulation.: %v", err) - return nil + log.Debugf("ignorable error: Unable to read cache file for prepopulation: %v", err) + return } if len(bytes) == 0 { // This can happen if the file is empty or just not there log.Debug("ignorable error: Cache file is empty") - return nil + return } log.Debugf("Attempting to prepopulate masquerades from cache file: %v", cacheFile) - var masquerades []masquerade - if err := json.Unmarshal(bytes, &masquerades); err != nil { - log.Errorf("Error prepopulating cached masquerades: %v", err) - return cache + var cachedMasquerades []*masquerade + if err := json.Unmarshal(bytes, &cachedMasquerades); err != nil { + log.Errorf("Error reading cached masquerades: %v", err) + return } - log.Debugf("Cache contained %d masquerades", len(masquerades)) + log.Debugf("Cache contained %d masquerades", len(cachedMasquerades)) now := time.Now() - for _, m := range masquerades { - if now.Sub(m.LastVetted) < d.maxAllowedCachedAge { - // fill in default for masquerades lacking provider id - if m.ProviderID == "" { - m.ProviderID = d.defaultProviderID - } - // Skip entries for providers that are not configured. 
- _, ok := d.providers[m.ProviderID] - if !ok { - log.Debugf("Skipping cached entry for unknown/disabled provider %s", m.ProviderID) - continue - } - select { - case d.cached <- m: - // submitted - cache = append(cache, m) - default: - // channel full, that's okay + + // update last succeeded status of masquerades based on cached values + for _, m := range d.masquerades { + for _, cm := range cachedMasquerades { + sameMasquerade := cm.ProviderID == m.ProviderID && cm.Domain == m.Domain && cm.IpAddress == m.IpAddress + cachedValueFresh := now.Sub(m.LastSucceeded) < d.maxAllowedCachedAge + if sameMasquerade && cachedValueFresh { + m.LastSucceeded = cm.LastSucceeded } } } +} - return cache +func (d *direct) markCacheDirty() { + select { + case d.cacheDirty <- nil: + // okay + default: + // already dirty + } } -func (d *direct) fillCache(cache []masquerade, cacheFile string) { - saveTicker := time.NewTicker(d.cacheSaveInterval) - defer saveTicker.Stop() - cacheChanged := false +func (d *direct) maintainCache(cacheFile string) { for { select { - case op := <-d.toCache: - if op.close { - log.Debug("Cache closed, stop filling") + case <-d.cacheClosed: + return + case <-time.After(d.cacheSaveInterval): + select { + case <-d.cacheClosed: return + case <-d.cacheDirty: + d.updateCache(cacheFile) } - m := op.m - if op.remove { - newCache := make([]masquerade, len(cache)) - for _, existing := range cache { - if existing.Domain == m.Domain && existing.IpAddress == m.IpAddress { - log.Debugf("Removing masquerade for %v (%v)", m.Domain, m.IpAddress) - } else { - newCache = append(newCache, existing) - } - } - cache = newCache - } else { - log.Debugf("Caching vetted masquerade for %v (%v)", m.Domain, m.IpAddress) - cache = append(cache, m) - } - cacheChanged = true - case <-saveTicker.C: - if !cacheChanged { - continue - } - log.Debug("Saving updated masquerade cache") - // Truncate cache to max length if necessary - if len(cache) > d.maxCacheSize { - truncated := 
make([]masquerade, d.maxCacheSize) - copy(truncated, cache[len(cache)-d.maxCacheSize:]) - cache = truncated - } - b, err := json.Marshal(cache) - if err != nil { - log.Errorf("Unable to marshal cache to JSON: %v", err) - break - } - err = ioutil.WriteFile(cacheFile, b, 0644) - if err != nil { - log.Errorf("Unable to save cache to disk: %v", err) - } - cacheChanged = false } } } +func (d *direct) updateCache(cacheFile string) { + log.Debugf("Updating cache at %v", cacheFile) + cache := d.masquerades.sortedCopy() + sizeToSave := len(cache) + if d.maxCacheSize < sizeToSave { + sizeToSave = d.maxCacheSize + } + b, err := json.Marshal(cache[:sizeToSave]) + if err != nil { + log.Errorf("Unable to marshal cache to JSON: %v", err) + return + } + err = ioutil.WriteFile(cacheFile, b, 0644) + if err != nil { + log.Errorf("Unable to save cache to disk: %v", err) + } +} + func (d *direct) closeCache() { - d.toCache <- &cacheOp{close: true} + d.closeCacheOnce.Do(func() { + close(d.cacheClosed) + }) } diff --git a/cache_test.go b/cache_test.go index c1c43b5..22f2e45 100644 --- a/cache_test.go +++ b/cache_test.go @@ -1,6 +1,7 @@ package fronted import ( + "encoding/json" "io/ioutil" "os" "path/filepath" @@ -28,46 +29,38 @@ func TestCaching(t *testing.T) { makeDirect := func() *direct { d := &direct{ - candidates: make(chan masquerade, 1000), - masquerades: make(chan masquerade, 1000), - cached: make(chan masquerade, 1000), + masquerades: make(sortedMasquerades, 0, 1000), maxAllowedCachedAge: 250 * time.Millisecond, maxCacheSize: 4, cacheSaveInterval: 50 * time.Millisecond, - toCache: make(chan *cacheOp, 1000), + cacheDirty: make(chan interface{}, 1), + cacheClosed: make(chan interface{}), providers: providers, defaultProviderID: cloudsackID, } - go d.fillCache(make([]masquerade, 0), cacheFile) + go d.maintainCache(cacheFile) return d } now := time.Now() - ma := masquerade{Masquerade{Domain: "a", IpAddress: "1"}, now, testProviderID} - mb := masquerade{Masquerade{Domain: "b", 
IpAddress: "2"}, now, testProviderID} - mc := masquerade{Masquerade{Domain: "c", IpAddress: "3"}, now, ""} // defaulted - md := masquerade{Masquerade{Domain: "d", IpAddress: "4"}, now, "sadcloud"} // skipped + mb := &masquerade{Masquerade: Masquerade{Domain: "b", IpAddress: "2"}, LastSucceeded: now, ProviderID: testProviderID} + mc := &masquerade{Masquerade: Masquerade{Domain: "c", IpAddress: "3"}, LastSucceeded: now, ProviderID: ""} // defaulted + md := &masquerade{Masquerade: Masquerade{Domain: "d", IpAddress: "4"}, LastSucceeded: now, ProviderID: "sadcloud"} // skipped d := makeDirect() - d.toCache <- &cacheOp{m: ma} - d.toCache <- &cacheOp{m: mb} - d.toCache <- &cacheOp{m: mc} - d.toCache <- &cacheOp{m: md} - d.toCache <- &cacheOp{m: ma, remove: true} + d.masquerades = append(d.masquerades, mb, mc, md) - readCached := func() []masquerade { - var result []masquerade - for { - select { - case m := <-d.cached: - result = append(result, m) - default: - return result - } - } + readCached := func() []*masquerade { + var result []*masquerade + b, err := ioutil.ReadFile(cacheFile) + require.NoError(t, err, "Unable to read cache file") + err = json.Unmarshal(b, &result) + require.NoError(t, err, "Unable to unmarshal cache file") + return result } - // Fill the cache + // Save the cache + d.markCacheDirty() time.Sleep(d.cacheSaveInterval * 2) d.closeCache() @@ -77,18 +70,12 @@ func TestCaching(t *testing.T) { d = makeDirect() d.prepopulateMasquerades(cacheFile) masquerades := readCached() - require.Len(t, masquerades, 2, "Wrong number of masquerades read") - require.Equal(t, "b", masquerades[0].Domain, "Wrong masquerade at position 0") - require.Equal(t, "2", masquerades[0].IpAddress, "Masquerade at position 0 has wrong IpAddress") - require.Equal(t, testProviderID, masquerades[0].ProviderID, "Masquerade at position 0 has wrong ProviderID") - require.Equal(t, "c", masquerades[1].Domain, "Wrong masquerade at position 0") - require.Equal(t, "3", masquerades[1].IpAddress, 
"Masquerade at position 1 has wrong IpAddress") - require.Equal(t, cloudsackID, masquerades[1].ProviderID, "Masquerade at position 1 has wrong ProviderID") - d.closeCache() - - time.Sleep(d.maxAllowedCachedAge) - d = makeDirect() - d.prepopulateMasquerades(cacheFile) - require.Empty(t, readCached(), "Cache should be empty after masquerades expire") + require.Len(t, masquerades, 3, "Wrong number of masquerades read") + for i, expected := range []*masquerade{mb, mc, md} { + require.Equal(t, expected.Domain, masquerades[i].Domain, "Wrong masquerade at position %d", i) + require.Equal(t, expected.IpAddress, masquerades[i].IpAddress, "Masquerade at position %d has wrong IpAddress", 0) + require.Equal(t, expected.ProviderID, masquerades[i].ProviderID, "Masquerade at position %d has wrong ProviderID", 0) + require.Equal(t, now.Unix(), masquerades[i].LastSucceeded.Unix(), "Masquerade at position %d has wrong LastSucceeded", 0) + } d.closeCache() } diff --git a/context.go b/context.go index d26f335..3af051b 100644 --- a/context.go +++ b/context.go @@ -34,9 +34,9 @@ func NewDirect(timeout time.Duration) (http.RoundTripper, bool) { return DefaultContext.NewDirect(timeout) } -// CloseCache closes any existing cache file in the default context -func CloseCache() { - DefaultContext.CloseCache() +// Close closes any existing cache file in the default context +func Close() { + DefaultContext.Close() } func NewFrontingContext(name string) *FrontingContext { @@ -84,13 +84,12 @@ func (fctx *FrontingContext) ConfigureWithHello(pool *x509.CertPool, providers m d := &direct{ certPool: pool, - candidates: make(chan masquerade, size), - masquerades: make(chan masquerade, size), - cached: make(chan masquerade, size), + masquerades: make(sortedMasquerades, 0, size), maxAllowedCachedAge: defaultMaxAllowedCachedAge, maxCacheSize: defaultMaxCacheSize, cacheSaveInterval: defaultCacheSaveInterval, - toCache: make(chan *cacheOp, defaultMaxCacheSize), + cacheDirty: make(chan interface{}, 1), + 
cacheClosed: make(chan interface{}), defaultProviderID: defaultProviderID, providers: make(map[string]*Provider), clientHelloID: clientHelloID, @@ -122,8 +121,8 @@ func (fctx *FrontingContext) NewDirect(timeout time.Duration) (http.RoundTripper return instance.(http.RoundTripper), true } -// CloseCache closes any existing cache file in the default contexxt. -func (fctx *FrontingContext) CloseCache() { +// Close closes any existing cache file in the default context. +func (fctx *FrontingContext) Close() { _existing, ok := fctx.instance.Get(0) if ok && _existing != nil { existing := _existing.(*direct) diff --git a/direct.go b/direct.go index ab86def..8c80666 100644 --- a/direct.go +++ b/direct.go @@ -2,6 +2,7 @@ package fronted import ( "bytes" + "context" "crypto/x509" "errors" "fmt" @@ -12,6 +13,7 @@ import ( "net/http" "net/url" "strings" + "sync" "time" tls "github.com/refraction-networking/utls" @@ -38,13 +40,13 @@ var ( // direct is an implementation of http.RoundTripper type direct struct { certPool *x509.CertPool - candidates chan masquerade - masquerades chan masquerade - cached chan masquerade + masquerades sortedMasquerades maxAllowedCachedAge time.Duration maxCacheSize int cacheSaveInterval time.Duration - toCache chan *cacheOp + cacheDirty chan interface{} + cacheClosed chan interface{} + closeCacheOnce sync.Once defaultProviderID string providers map[string]*Provider clientHelloID tls.ClientHelloID @@ -70,7 +72,7 @@ func (d *direct) loadCandidates(initial map[string]*Provider) { for _, c := range sh { log.Trace("Adding candidate") - d.candidates <- masquerade{Masquerade: *c, ProviderID: key} + d.masquerades = append(d.masquerades, &masquerade{Masquerade: *c, ProviderID: key}) } } } @@ -111,43 +113,32 @@ func vet(m *Masquerade, pool *x509.CertPool, testURL string) bool { func (d *direct) vet(numberToVet int) { log.Debugf("Vetting %d initial candidates in parallel", numberToVet) for i := 0; i < numberToVet; i++ { - go d.vetOneUntilGood() + go d.vetOne() 
} } -func (d *direct) vetOneUntilGood() { - for { - if !d.vetOne() { - return - } - } -} - -func (d *direct) vetOne() bool { +func (d *direct) vetOne() { // We're just testing the ability to connect here, destination site doesn't // really matter log.Trace("Vetting one") - // don't vet a new masquerade if encountering an error since vetOne will keep looping until we get a successful connection - vetNewOnError := false - conn, m, masqueradeGood, masqueradesRemain, err := d.dialWith(d.candidates, d.candidates, vetNewOnError) + conn, m, masqueradeGood, err := d.dialWith(context.Background(), d.masquerades) if err != nil { - return masqueradesRemain + log.Errorf("unexpected error vetting masquerades: %v", err) } defer conn.Close() provider := d.providerFor(m) if provider == nil { log.Tracef("Skipping masquerade with disabled/unknown provider id '%s'", m.ProviderID) - return masqueradesRemain + return } if !masqueradeGood(postCheck(conn, provider.TestURL)) { log.Tracef("Unsuccessful vetting with POST request, discarding masquerade") - return masqueradesRemain + return } log.Trace("Finished vetting one") - return false } // postCheck does a post with invalid data to verify domain-fronting works @@ -245,7 +236,7 @@ func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, e log.Debugf("Retrying domain-fronted request, pass %d", i) } - conn, m, masqueradeGood, err := d.dial() + conn, m, masqueradeGood, err := d.dial(req.Context()) if err != nil { // unable to find good masquerade, fail op.FailIf(err) @@ -326,86 +317,46 @@ func cloneRequestWith(req *http.Request, frontedHost string, body io.ReadCloser) // masquerades. If successful, it returns a connection to the masquerade, // the selected masquerade, and a function that the caller can use to // tell us whether the masquerade is good or not (i.e. if masquerade was good, -// keep it, else vet a new one). 
-func (d *direct) dial() (net.Conn, *masquerade, func(bool) bool, error) { - // if dialing fails, eagerly vet a new masquerade - vetNewOnError := true - conn, m, masqueradeGood, _, err := d.dialWith(d.masquerades, d.cached, vetNewOnError) +// keep it). +func (d *direct) dial(ctx context.Context) (net.Conn, *masquerade, func(bool) bool, error) { + conn, m, masqueradeGood, err := d.dialWith(ctx, d.masquerades) return conn, m, masqueradeGood, err } -func (d *direct) dialWith(vetted chan masquerade, cached chan masquerade, vetNewOnError bool) (net.Conn, *masquerade, func(bool) bool, bool, error) { - retryLater := make([]masquerade, 0) - defer func() { - for _, m := range retryLater { - // when network just recovered from offline, retryLater has more - // elements than the capacity of the channel. - select { - case vetted <- m: - default: - log.Debug("Dropping masquerade: retry channel full") - } - } - }() - +func (d *direct) dialWith(ctx context.Context, masquerades sortedMasquerades) (net.Conn, *masquerade, func(bool) bool, error) { for { - // order of preference vetted -> cached -> unvetted - var m masquerade - select { - case m = <-vetted: - log.Trace("Got vetted masquerade") - default: + masqueradesToTry := masquerades.sortedCopy() + for _, m := range masqueradesToTry { + // check to see if we've timed out select { - case m = <-cached: - log.Trace("Got cached masquerade") + case <-ctx.Done(): + return nil, nil, nil, errors.New("could not dial any masquerade?") default: - log.Trace("No vetted or cached masquerade found, falling back to unvetted candidate") - select { - case m = <-d.candidates: - log.Trace("Got unvetted masquerade") - default: - return nil, nil, nil, false, errors.New("could not dial any masquerade?") - } + // okay } - } - log.Tracef("Dialing to %v", m) - - // We do the full TLS connection here because in practice the domains at a given IP - // address can change frequently on CDNs, so the certificate may not match what - // we expect. 
- conn, retriable, err := d.doDial(&m.Masquerade) - masqueradeGood := func(good bool) bool { - if good { - m.LastVetted = time.Now() - // Requeue the working connection to masquerades - d.masquerades <- m - select { - case d.toCache <- &cacheOp{m: m}: - // ok - default: - // cache writing has fallen behind, drop masquerade - log.Debug("Dropping masquerade: cache writing is behind") - } - } else { - go func() { - d.toCache <- &cacheOp{m: m, remove: true} - }() - if vetNewOnError { - go d.vetOneUntilGood() + log.Tracef("Dialing to %v", m) + + // We do the full TLS connection here because in practice the domains at a given IP + // address can change frequently on CDNs, so the certificate may not match what + // we expect. + conn, retriable, err := d.doDial(&m.Masquerade) + masqueradeGood := func(good bool) bool { + if good { + m.markSucceeded() + } else { + m.markFailed() } + d.markCacheDirty() + return good + } + if err == nil { + log.Trace("Returning connection") + return conn, m, masqueradeGood, err + } else if !retriable { + log.Debugf("Dropping masquerade: non retryable error: %v", err) + masqueradeGood(false) } - - return good - } - if err == nil { - log.Trace("Returning connection") - return conn, &m, masqueradeGood, true, err - } else if retriable { - retryLater = append(retryLater, m) - } else { - log.Debugf("Dropping masquerade: non retryable error: %v", err) - masqueradeGood(false) } } } diff --git a/direct_test.go b/direct_test.go index 7a982d9..06f2d42 100644 --- a/direct_test.go +++ b/direct_test.go @@ -60,6 +60,7 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE client := &http.Client{ Transport: transport, + Timeout: 5 * time.Second, } require.True(t, doCheck(client, http.MethodPost, http.StatusAccepted, pingURL)) @@ -108,15 +109,14 @@ func TestLoadCandidates(t *testing.T) { } d := &direct{ - candidates: make(chan masquerade, len(expected)), + masquerades: make(sortedMasquerades, 0, len(expected)), } 
d.loadCandidates(providers) - close(d.candidates) actual := make(map[Masquerade]bool) count := 0 - for m := range d.candidates { + for _, m := range d.masquerades { actual[Masquerade{m.Domain, m.IpAddress}] = true count++ } @@ -356,8 +356,7 @@ func TestHostAliasesMulti(t *testing.T) { } } - assert.True(t, providerCounts["cloudsack"] > 1) - assert.True(t, providerCounts["sadcloud"] > 1) + assert.True(t, providerCounts["cloudsack"]+providerCounts["sadcloud"] > 2) } func TestPassthrough(t *testing.T) { @@ -517,15 +516,12 @@ func TestCustomValidators(t *testing.T) { } Configure(certs, providers, "sadcloud", "") - - // Wait a while for vetting to finish - time.Sleep(5 * time.Second) } // This error indicates that the validator has discarded all masquerades. // Each test starts with one masquerade, which is vetted during the // call to NewDirect. - masqueradesExhausted := fmt.Sprintf(`Get "%v": could not dial any masquerade?`, testURL) + masqueradesExhausted := fmt.Sprintf(`Get "%v": could not complete request even with retries`, testURL) tests := []struct { responseCode int @@ -609,7 +605,7 @@ func TestCustomValidators(t *testing.T) { res, err := client.Do(req) if test.expectedError == "" { - if !assert.Nil(t, err) { + if !assert.NoError(t, err) { continue } assert.Equal(t, test.responseCode, res.StatusCode, "Failed to force response status code") diff --git a/masquerade.go b/masquerade.go index 167a44a..041b370 100644 --- a/masquerade.go +++ b/masquerade.go @@ -4,7 +4,9 @@ import ( "fmt" "net" "net/http" + "sort" "strings" + "sync" "time" ) @@ -35,10 +37,29 @@ type Masquerade struct { type masquerade struct { Masquerade - // lastVetted: the most recent time at which this Masquerade was vetted - LastVetted time.Time + // lastSucceeded: the most recent time at which this Masquerade succeeded + LastSucceeded time.Time // id of DirectProvider that this masquerade is provided by ProviderID string + mx sync.RWMutex +} + +func (m *masquerade) lastSucceeded() time.Time { + 
m.mx.RLock() + defer m.mx.RUnlock() + return m.LastSucceeded +} + +func (m *masquerade) markSucceeded() { + m.mx.Lock() + defer m.mx.Unlock() + m.LastSucceeded = time.Now() +} + +func (m *masquerade) markFailed() { + m.mx.Lock() + defer m.mx.Unlock() + m.LastSucceeded = time.Time{} } // A Direct fronting provider configuration. @@ -137,3 +158,25 @@ func NewStatusCodeValidator(reject []int) ResponseValidator { return nil } } + +// slice of masquerade sorted by last succeeded time, most recent first +type sortedMasquerades []*masquerade + +func (m sortedMasquerades) Len() int { return len(m) } +func (m sortedMasquerades) Swap(i, j int) { m[i], m[j] = m[j], m[i] } +func (m sortedMasquerades) Less(i, j int) bool { + if m[i].lastSucceeded().After(m[j].lastSucceeded()) { + return true + } else if m[j].lastSucceeded().After(m[i].lastSucceeded()) { + return false + } else { + return m[i].IpAddress < m[j].IpAddress + } +} + +func (m sortedMasquerades) sortedCopy() sortedMasquerades { + c := make(sortedMasquerades, len(m)) + copy(c, m) + sort.Sort(c) + return c +}