From 9ce208752d141c502d24d3c6a09148e94df23293 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Thu, 24 Oct 2024 14:43:58 -0600 Subject: [PATCH] Naming and initialization cleanup --- cache.go | 12 +-- cache_test.go | 4 +- context.go | 50 +++------- direct.go => fronted.go | 151 +++++++++++++++--------------- direct_test.go => fronted_test.go | 20 ++-- 5 files changed, 107 insertions(+), 130 deletions(-) rename direct.go => fronted.go (75%) rename direct_test.go => fronted_test.go (99%) diff --git a/cache.go b/cache.go index cbb79d3..9df34d5 100644 --- a/cache.go +++ b/cache.go @@ -7,12 +7,12 @@ import ( "time" ) -func (d *direct) initCaching(cacheFile string) { +func (d *fronted) initCaching(cacheFile string) { d.prepopulateMasquerades(cacheFile) go d.maintainCache(cacheFile) } -func (d *direct) prepopulateMasquerades(cacheFile string) { +func (d *fronted) prepopulateMasquerades(cacheFile string) { bytes, err := os.ReadFile(cacheFile) if err != nil { // This is not a big deal since we'll just fill the cache later @@ -48,7 +48,7 @@ func (d *direct) prepopulateMasquerades(cacheFile string) { } } -func (d *direct) markCacheDirty() { +func (d *fronted) markCacheDirty() { select { case d.cacheDirty <- nil: // okay @@ -57,7 +57,7 @@ func (d *direct) markCacheDirty() { } } -func (d *direct) maintainCache(cacheFile string) { +func (d *fronted) maintainCache(cacheFile string) { for { select { case <-d.cacheClosed: @@ -73,7 +73,7 @@ func (d *direct) maintainCache(cacheFile string) { } } -func (d *direct) updateCache(cacheFile string) { +func (d *fronted) updateCache(cacheFile string) { log.Debugf("Updating cache at %v", cacheFile) cache := d.masquerades.sortedCopy() sizeToSave := len(cache) @@ -101,7 +101,7 @@ func (d *direct) updateCache(cacheFile string) { } } -func (d *direct) closeCache() { +func (d *fronted) closeCache() { d.closeCacheOnce.Do(func() { close(d.cacheClosed) }) diff --git a/cache_test.go b/cache_test.go index 06e1c42..9b874d1 100644 --- a/cache_test.go +++ b/cache_test.go @@ -26,8 +26,8 @@ func TestCaching(t *testing.T) { cloudsackID: NewProvider(nil, "", nil, nil, nil, nil, nil), } - makeDirect := func() *direct { - d := &direct{ + makeDirect := func() *fronted { + d := &fronted{ masquerades: make(sortedMasquerades, 0, 1000), maxAllowedCachedAge: 250 * time.Millisecond, maxCacheSize: 4, diff --git a/context.go b/context.go index c9c7730..68027f3 100644 --- a/context.go +++ b/context.go @@ -27,11 +27,11 @@ func Configure(pool *x509.CertPool, providers map[string]*Provider, defaultProvi } } -// NewDirect creates a new http.RoundTripper that does direct domain fronting +// NewFronted creates a new http.RoundTripper that does direct domain fronting // using the default context. If the default context isn't configured within // the given timeout, this method returns nil, false. -func NewDirect(timeout time.Duration) (http.RoundTripper, bool) { - return DefaultContext.NewDirect(timeout) +func NewFronted(timeout time.Duration) (http.RoundTripper, bool) { + return DefaultContext.NewFronted(timeout) } // Close closes any existing cache file in the default context @@ -68,51 +68,23 @@ func (fctx *FrontingContext) ConfigureWithHello(pool *x509.CertPool, providers m _existing, ok := fctx.instance.Get(0) if ok && _existing != nil { - existing := _existing.(*direct) + existing := _existing.(*fronted) log.Debugf("Closing cache from existing instance for %s context", fctx.name) existing.closeCache() } - size := 0 - for _, p := range providers { - size += len(p.Masquerades) + f, err := newFronted(pool, providers, defaultProviderID, cacheFile, clientHelloID) + if err != nil { + return err } - - if size == 0 { - return fmt.Errorf("no masquerades for %s context", fctx.name) - } - - d := &direct{ - certPool: pool, - masquerades: make(sortedMasquerades, 0, size), - maxAllowedCachedAge: defaultMaxAllowedCachedAge, - maxCacheSize: defaultMaxCacheSize, - cacheSaveInterval: defaultCacheSaveInterval, - cacheDirty: make(chan interface{}, 1), - cacheClosed: make(chan interface{}), - defaultProviderID: defaultProviderID, - providers: make(map[string]*Provider), - clientHelloID: clientHelloID, - } - - // copy providers - for k, p := range providers { - d.providers[k] = NewProvider(p.HostAliases, p.TestURL, p.Masquerades, p.Validator, p.PassthroughPatterns, p.SNIConfig, p.VerifyHostname) - } - - d.loadCandidates(d.providers) - if cacheFile != "" { - d.initCaching(cacheFile) - } - d.findWorkingMasquerades() - fctx.instance.Set(d) + fctx.instance.Set(f) return nil } -// NewDirect creates a new http.RoundTripper that does direct domain fronting. +// NewFronted creates a new http.RoundTripper that does direct domain fronting. // If the context isn't configured within the given timeout, this method // returns nil, false. -func (fctx *FrontingContext) NewDirect(timeout time.Duration) (http.RoundTripper, bool) { +func (fctx *FrontingContext) NewFronted(timeout time.Duration) (http.RoundTripper, bool) { instance, ok := fctx.instance.Get(timeout) if !ok { log.Errorf("No DirectHttpClient available within %v for context %s", timeout, fctx.name) @@ -125,7 +97,7 @@ func (fctx *FrontingContext) NewDirect(timeout time.Duration) (http.RoundTripper func (fctx *FrontingContext) Close() { _existing, ok := fctx.instance.Get(0) if ok && _existing != nil { - existing := _existing.(*direct) + existing := _existing.(*fronted) log.Debugf("Closing cache from existing instance in %s context", fctx.name) existing.closeCache() } diff --git a/direct.go b/fronted.go similarity index 75% rename from direct.go rename to fronted.go index 02e0578..53b1a9f 100644 --- a/direct.go +++ b/fronted.go @@ -19,7 +19,6 @@ import ( tls "github.com/refraction-networking/utls" "github.com/getlantern/golog" - "github.com/getlantern/idletiming" "github.com/getlantern/ops" ) @@ -34,8 +33,9 @@ var ( log = golog.LoggerFor("fronted") ) -// direct is an implementation of http.RoundTripper -type direct struct { +// fronted identifies working IP address/domain pairings for domain fronting and is +// an implementation of http.RoundTripper for the convenience of callers. +type fronted struct { certPool *x509.CertPool masquerades sortedMasquerades maxAllowedCachedAge time.Duration @@ -49,7 +49,44 @@ type direct struct { clientHelloID tls.ClientHelloID } -func (d *direct) loadCandidates(initial map[string]*Provider) { +func newFronted(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID, cacheFile string, clientHelloID tls.ClientHelloID) (*fronted, error) { + size := 0 + for _, p := range providers { + size += len(p.Masquerades) + } + + if size == 0 { + return nil, fmt.Errorf("no masquerades found in providers") + } + + f := &fronted{ + certPool: pool, + masquerades: make(sortedMasquerades, 0, size), + maxAllowedCachedAge: defaultMaxAllowedCachedAge, + maxCacheSize: defaultMaxCacheSize, + cacheSaveInterval: defaultCacheSaveInterval, + cacheDirty: make(chan interface{}, 1), + cacheClosed: make(chan interface{}), + defaultProviderID: defaultProviderID, + providers: make(map[string]*Provider), + clientHelloID: clientHelloID, + } + + // copy providers + for k, p := range providers { + f.providers[k] = NewProvider(p.HostAliases, p.TestURL, p.Masquerades, p.Validator, p.PassthroughPatterns, p.SNIConfig, p.VerifyHostname) + } + + f.loadCandidates(f.providers) + if cacheFile != "" { + f.initCaching(cacheFile) + } + f.findWorkingMasquerades() + + return f, nil +} + +func (f *fronted) loadCandidates(initial map[string]*Provider) { log.Debugf("Loading candidates for %d providers", len(initial)) defer log.Debug("Finished loading candidates") @@ -68,23 +105,23 @@ func (d *direct) loadCandidates(initial map[string]*Provider) { } for _, c := range sh { - d.masquerades = append(d.masquerades, &masquerade{Masquerade: *c, ProviderID: key}) + f.masquerades = append(f.masquerades, &masquerade{Masquerade: *c, ProviderID: key}) } } } -func (d *direct) providerFor(m MasqueradeInterface) *Provider { +func (f *fronted) providerFor(m MasqueradeInterface) *Provider { pid := m.getProviderID() if pid == "" { - pid = d.defaultProviderID + pid = f.defaultProviderID } - return d.providers[pid] + return f.providers[pid] } // Vet vets the specified Masquerade, verifying certificate using the given CertPool. // This is used in genconfig. func Vet(m *Masquerade, pool *x509.CertPool, testURL string) bool { - d := &direct{ + d := &fronted{ certPool: pool, maxAllowedCachedAge: defaultMaxAllowedCachedAge, maxCacheSize: defaultMaxCacheSize, @@ -102,35 +139,35 @@ func Vet(m *Masquerade, pool *x509.CertPool, testURL string) bool { // parallel. Speed is of the essence here, as without working masquerades, users will // be unable to fetch proxy configurations, particularly in the case of a first time // user who does not have proxies cached on disk. -func (d *direct) findWorkingMasquerades() { +func (f *fronted) findWorkingMasquerades() { // vet masquerades in batches const batchSize int = 25 var successful atomic.Uint32 // We loop through all of them until we have 4 successful ones. - for i := 0; i < len(d.masquerades) && successful.Load() < 4; i += batchSize { - d.vetBatch(i, batchSize, &successful) + for i := 0; i < len(f.masquerades) && successful.Load() < 4; i += batchSize { + f.vetBatch(i, batchSize, &successful) } } -func (d *direct) vetBatch(start, batchSize int, successful *atomic.Uint32) { +func (f *fronted) vetBatch(start, batchSize int, successful *atomic.Uint32) { log.Debugf("Vetting masquerade batch %d-%d", start, start+batchSize) var wg sync.WaitGroup - masqueradeSize := len(d.masquerades) + masqueradeSize := len(f.masquerades) for j := start; j < start+batchSize && j < masqueradeSize; j++ { wg.Add(1) go func(m MasqueradeInterface) { defer wg.Done() - if d.vetMasquerade(m) { + if f.vetMasquerade(m) { successful.Add(1) } - }(d.masquerades[j]) + }(f.masquerades[j]) } wg.Wait() } -func (d *direct) vetMasquerade(m MasqueradeInterface) bool { - conn, masqueradeGood, err := d.dialMasquerade(m) +func (f *fronted) vetMasquerade(m MasqueradeInterface) bool { + conn, masqueradeGood, err := f.dialMasquerade(m) if err != nil { log.Errorf("unexpected error vetting masquerades: %v", err) return false @@ -141,10 +178,10 @@ func (d *direct) vetMasquerade(m MasqueradeInterface) bool { } }() - provider := d.providerFor(m) + provider := f.providerFor(m) if provider == nil { log.Debugf("Skipping masquerade with disabled/unknown provider id '%s' not in %v", - m.getProviderID(), d.providers) + m.getProviderID(), f.providers) return false } if !masqueradeGood(m.postCheck(conn, provider.TestURL)) { @@ -152,21 +189,20 @@ func (d *direct) vetMasquerade(m MasqueradeInterface) bool { return false } - log.Debugf("Finished vetting one masquerade %v", m) + log.Debugf("Successfully vetted one masquerade %v", m) return true } -// Do continually retries a given request until it succeeds because some -// fronting providers will return a 403 for some domains. -func (d *direct) RoundTrip(req *http.Request) (*http.Response, error) { - res, _, err := d.RoundTripHijack(req) +// RoundTrip loops through all available masquerades, sorted by the one that most recently +// connected, retrying several times on failures. +func (f *fronted) RoundTrip(req *http.Request) (*http.Response, error) { + res, _, err := f.RoundTripHijack(req) return res, err } -// Do continually retries a given request until it succeeds because some -// fronting providers will return a 403 for some domains. Also return the -// underlying net.Conn established. -func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, error) { +// RoundTripHijack loops through all available masquerades, sorted by the one that most +// recently connected, retrying several times on failures. +func (f *fronted) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, error) { op := ops.Begin("fronted_roundtrip") defer op.End() @@ -209,13 +245,13 @@ func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, e log.Debugf("Retrying domain-fronted request, pass %d", i) } - conn, m, masqueradeGood, err := d.dialAll(req.Context()) + conn, m, masqueradeGood, err := f.dialAll(req.Context()) if err != nil { // unable to find good masquerade, fail op.FailIf(err) return nil, nil, err } - provider := d.providerFor(m) + provider := f.providerFor(m) if provider == nil { log.Debugf("Skipping masquerade with disabled/unknown provider '%s'", m.getProviderID()) masqueradeGood(false) @@ -269,12 +305,12 @@ func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, e } // Dial dials out using all available masquerades until one succeeds. -func (d *direct) dialAll(ctx context.Context) (net.Conn, MasqueradeInterface, func(bool) bool, error) { - conn, m, masqueradeGood, err := d.dialAllWith(ctx, d.masquerades) +func (f *fronted) dialAll(ctx context.Context) (net.Conn, MasqueradeInterface, func(bool) bool, error) { + conn, m, masqueradeGood, err := f.dialAllWith(ctx, f.masquerades) return conn, m, masqueradeGood, err } -func (d *direct) dialAllWith(ctx context.Context, masquerades sortedMasquerades) (net.Conn, MasqueradeInterface, func(bool) bool, error) { +func (f *fronted) dialAllWith(ctx context.Context, masquerades sortedMasquerades) (net.Conn, MasqueradeInterface, func(bool) bool, error) { // never take more than a minute trying to find a dialer ctx, cancel := context.WithTimeout(ctx, 1*time.Minute) defer cancel() @@ -290,7 +326,7 @@ dialLoop: default: // okay } - conn, masqueradeGood, err := d.dialMasquerade(m) + conn, masqueradeGood, err := f.dialMasquerade(m) if err == nil { return conn, m, masqueradeGood, nil } @@ -299,26 +335,24 @@ dialLoop: return nil, nil, nil, log.Errorf("could not dial any masquerade? tried %v", totalMasquerades) } -func (d *direct) dialMasquerade(m MasqueradeInterface) (net.Conn, func(bool) bool, error) { - // check to see if we've timed out - +func (f *fronted) dialMasquerade(m MasqueradeInterface) (net.Conn, func(bool) bool, error) { log.Tracef("Dialing to %v", m) // We do the full TLS connection here because in practice the domains at a given IP // address can change frequently on CDNs, so the certificate may not match what // we expect. - conn, retriable, err := d.doDial(m) + conn, retriable, err := f.doDial(m) masqueradeGood := func(good bool) bool { if good { m.markSucceeded() } else { m.markFailed() } - d.markCacheDirty() + f.markCacheDirty() return good } if err == nil { - log.Debug("Returning connection") + log.Debugf("Returning connection for masquerade: %v", m) return conn, masqueradeGood, err } else if !retriable { log.Debugf("Dropping masquerade: non retryable error: %v", err) @@ -327,13 +361,13 @@ func (d *direct) dialMasquerade(m MasqueradeInterface) (net.Conn, func(bool) boo return conn, masqueradeGood, err } -func (d *direct) doDial(m MasqueradeInterface) (conn net.Conn, retriable bool, err error) { +func (f *fronted) doDial(m MasqueradeInterface) (conn net.Conn, retriable bool, err error) { op := ops.Begin("dial_masquerade") defer op.End() op.Set("masquerade_domain", m.getDomain()) op.Set("masquerade_ip", m.getIpAddress()) - conn, err = d.dialServerWith(m) + conn, err = m.dial(f.certPool, f.clientHelloID) if err != nil { op.FailIf(err) log.Debugf("Could not dial to %v, %v", m.getIpAddress(), err) @@ -349,31 +383,10 @@ func (d *direct) doDial(m MasqueradeInterface) (conn net.Conn, retriable bool, e } } else { log.Debugf("Got successful connection to: %v", m) - idleTimeout := 70 * time.Second - - log.Debugf("Wrapping connection in idletiming connection: %v", m) - conn = idletiming.Conn(conn, idleTimeout, func() { - log.Debugf("Connection to %v idle for %v, closed", conn.RemoteAddr(), idleTimeout) - }) } return } -func (d *direct) dialServerWith(m MasqueradeInterface) (net.Conn, error) { - op := ops.Begin("dial_server_with") - defer op.End() - - op.Set("masquerade_domain", m.getDomain()) - op.Set("masquerade_ip", m.getIpAddress()) - - conn, err := m.dial(d.certPool, d.clientHelloID) - if err != nil && m != nil { - err = fmt.Errorf("unable to dial masquerade %s: %s", m.getDomain(), err) - op.FailIf(err) - } - return conn, err -} - func verifyPeerCertificate(rawCerts [][]byte, roots *x509.CertPool, domain string) error { if len(rawCerts) == 0 { return fmt.Errorf("no certificates presented") @@ -422,15 +435,6 @@ func generateVerifyOptions(roots *x509.CertPool, domain string) x509.VerifyOptio } } -// frontingTLSConfig builds a tls.Config for dialing the fronting domain. This is to establish the -// initial TCP connection to the CDN. -func (d *direct) frontingTLSConfig(m *Masquerade) *tls.Config { - return &tls.Config{ - ServerName: m.Domain, - RootCAs: d.certPool, - } -} - // frontedHTTPTransport is the transport to use to route to the actual fronted destination domain. // This uses the pre-established connection to the CDN on the fronting domain. func frontedHTTPTransport(conn net.Conn, disableKeepAlives bool) http.RoundTripper { @@ -441,6 +445,7 @@ func frontedHTTPTransport(conn net.Conn, disableKeepAlives bool) http.RoundTripp }, TLSHandshakeTimeout: 40 * time.Second, DisableKeepAlives: disableKeepAlives, + IdleConnTimeout: 70 * time.Second, }, } } diff --git a/direct_test.go b/fronted_test.go similarity index 99% rename from direct_test.go rename to fronted_test.go index 0fbc8ad..23fc1a1 100644 --- a/direct_test.go +++ b/fronted_test.go @@ -57,7 +57,7 @@ func TestDirectDomainFrontingWithSNIConfig(t *testing.T) { }) Configure(certs, p, testProviderID, cacheFile) - transport, ok := NewDirect(0) + transport, ok := NewFronted(0) require.True(t, ok) client := &http.Client{ Transport: transport, @@ -85,7 +85,7 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE p := testProvidersWithHosts(hosts) Configure(certs, p, testProviderID, cacheFile) - transport, ok := NewDirect(30 * time.Second) + transport, ok := NewFronted(30 * time.Second) require.True(t, ok) client := &http.Client{ @@ -94,7 +94,7 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE } require.True(t, doCheck(client, http.MethodPost, http.StatusAccepted, pingURL)) - transport, ok = NewDirect(0) + transport, ok = NewFronted(0) require.True(t, ok) client = &http.Client{ Transport: transport, @@ -103,7 +103,7 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE instance, ok := DefaultContext.instance.Get(0) require.True(t, ok) - d := instance.(*direct) + d := instance.(*fronted) // Check the number of masquerades at the end, waiting up to 30 seconds until we get the right number masqueradesAtEnd := 0 @@ -138,7 +138,7 @@ func TestLoadCandidates(t *testing.T) { } } - d := &direct{ + d := &fronted{ masquerades: make(sortedMasquerades, 0, len(expected)), } @@ -238,7 +238,7 @@ func TestHostAliasesBasic(t *testing.T) { certs.AddCert(cloudSack.Certificate()) Configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") - rt, ok := NewDirect(10 * time.Second) + rt, ok := NewFronted(10 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } @@ -349,7 +349,7 @@ func TestHostAliasesMulti(t *testing.T) { } Configure(certs, providers, "cloudsack", "") - rt, ok := NewDirect(10 * time.Second) + rt, ok := NewFronted(10 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } @@ -475,7 +475,7 @@ func TestPassthrough(t *testing.T) { certs.AddCert(cloudSack.Certificate()) Configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") - rt, ok := NewDirect(10 * time.Second) + rt, ok := NewFronted(10 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } @@ -615,7 +615,7 @@ func TestCustomValidators(t *testing.T) { for _, test := range tests { setup(test.validator) - direct, ok := NewDirect(1 * time.Second) + direct, ok := NewFronted(1 * time.Second) if !assert.True(t, ok) { return } @@ -863,7 +863,7 @@ func TestFindWorkingMasquerades(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - d := &direct{} + d := &fronted{} d.providers = make(map[string]*Provider) d.providers["testProviderId"] = NewProvider(nil, "", nil, nil, nil, nil, nil) d.masquerades = make(sortedMasquerades, len(tt.masquerades))