From d16339d5a48a7886775014647332a54696417f8f Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Fri, 22 Nov 2024 13:53:40 -0700 Subject: [PATCH 01/25] Change to separately track working fronts --- connecting_fronts.go | 68 +++++++++++++++++ context.go | 60 ++++++++------- fronted.go | 178 +++++++++++++++++++++++-------------------- fronted_test.go | 155 ++++--------------------------------- test_support.go | 4 +- 5 files changed, 211 insertions(+), 254 deletions(-) create mode 100644 connecting_fronts.go diff --git a/connecting_fronts.go b/connecting_fronts.go new file mode 100644 index 0000000..ec4611d --- /dev/null +++ b/connecting_fronts.go @@ -0,0 +1,68 @@ +package fronted + +import ( + "errors" + "sort" + "sync" + "time" +) + +type connectTimeFront struct { + MasqueradeInterface + connectTime time.Duration +} + +type connectingFronts struct { + fronts []connectTimeFront + //frontsChan chan MasqueradeInterface + sync.RWMutex +} + +// Make sure that connectingFronts is a connectListener +var _ workingFronts = &connectingFronts{} + +// newConnectingFronts creates a new ConnectingFronts struct with an empty slice of Masquerade IPs and domains. +func newConnectingFronts() *connectingFronts { + return &connectingFronts{ + fronts: make([]connectTimeFront, 0), + //frontsChan: make(chan MasqueradeInterface), + } +} + +// AddFront adds a new front to the list of fronts. +func (cf *connectingFronts) onConnected(m MasqueradeInterface, connectTime time.Duration) { + cf.Lock() + defer cf.Unlock() + + cf.fronts = append(cf.fronts, connectTimeFront{ + MasqueradeInterface: m, + connectTime: connectTime, + }) + // Sort fronts by connect time. + sort.Slice(cf.fronts, func(i, j int) bool { + return cf.fronts[i].connectTime < cf.fronts[j].connectTime + }) + //cf.frontsChan <- m +} + +func (cf *connectingFronts) onError(m MasqueradeInterface) { + cf.Lock() + defer cf.Unlock() + + // Remove the front from connecting fronts. + for i, front := range cf.fronts { + if front.MasqueradeInterface == m { + cf.fronts = append(cf.fronts[:i], cf.fronts[i+1:]...) + return + } + } +} + +func (cf *connectingFronts) workingFront() (MasqueradeInterface, error) { + cf.RLock() + defer cf.RUnlock() + if len(cf.fronts) == 0 { + return nil, errors.New("no fronts available") + } + return cf.fronts[0].MasqueradeInterface, nil +} diff --git a/context.go b/context.go index eb8d78a..02853a7 100644 --- a/context.go +++ b/context.go @@ -12,53 +12,60 @@ import ( "github.com/getlantern/eventual/v2" ) +// Create an interface for the fronting context +type Fronting interface { + UpdateConfig(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string) + NewRoundTripper(timeout time.Duration) (http.RoundTripper, bool) + Close() +} + var defaultContext = newFrontingContext("default") +// Make sure that the default context is a Fronting +var _ Fronting = defaultContext + // Configure sets the masquerades to use, the trusted root CAs, and the // cache file for caching masquerades to set up direct domain fronting // in the default context. // // defaultProviderID is used when a masquerade without a provider is // encountered (eg in a cache file) -func Configure(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string, cacheFile string) { - if err := defaultContext.Configure(pool, providers, defaultProviderID, cacheFile); err != nil { - log.Errorf("Error configuring fronting %s context: %s!!", defaultContext.name, err) +func NewFronter(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string, cacheFile string) (Fronting, error) { + if err := defaultContext.configure(pool, providers, defaultProviderID, cacheFile); err != nil { + return nil, log.Errorf("Error configuring fronting %s context: %s!!", defaultContext.name, err) } -} - -// NewFronted creates a new http.RoundTripper that does direct domain fronting -// using the default context. If the default context isn't configured within -// the given timeout, this method returns nil, false. -func NewFronted(timeout time.Duration) (http.RoundTripper, bool) { - return defaultContext.NewFronted(timeout) -} - -// Close closes any existing cache file in the default context -func Close() { - defaultContext.Close() + return defaultContext, nil } func newFrontingContext(name string) *frontingContext { return &frontingContext{ - name: name, - instance: eventual.NewValue(), + name: name, + instance: eventual.NewValue(), + connectingFronts: newConnectingFronts(), } } type frontingContext struct { - name string - instance eventual.Value + name string + instance eventual.Value + fronted *fronted + connectingFronts *connectingFronts } -// Configure sets the masquerades to use, the trusted root CAs, and the +// UpdateConfig updates the configuration of the fronting context +func (fctx *frontingContext) UpdateConfig(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string) { + fctx.fronted.updateConfig(pool, providers, defaultProviderID) +} + +// configure sets the masquerades to use, the trusted root CAs, and the // cache file for caching masquerades to set up direct domain fronting. // defaultProviderID is used when a masquerade without a provider is // encountered (eg in a cache file) -func (fctx *frontingContext) Configure(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string, cacheFile string) error { - return fctx.ConfigureWithHello(pool, providers, defaultProviderID, cacheFile, tls.ClientHelloID{}) +func (fctx *frontingContext) configure(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string, cacheFile string) error { + return fctx.configureWithHello(pool, providers, defaultProviderID, cacheFile, tls.ClientHelloID{}) } -func (fctx *frontingContext) ConfigureWithHello(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string, cacheFile string, clientHelloID tls.ClientHelloID) error { +func (fctx *frontingContext) configureWithHello(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string, cacheFile string, clientHelloID tls.ClientHelloID) error { log.Debugf("Configuring fronted %s context", fctx.name) if len(providers) == 0 { @@ -73,10 +80,11 @@ func (fctx *frontingContext) ConfigureWithHello(pool *x509.CertPool, providers m existing.closeCache() } - _, err := newFronted(pool, providers, defaultProviderID, cacheFile, clientHelloID, func(f *fronted) { + var err error + fctx.fronted, err = newFronted(pool, providers, defaultProviderID, cacheFile, clientHelloID, func(f *fronted) { log.Debug("Setting fronted instance") fctx.instance.Set(f) - }) + }, fctx.connectingFronts) if err != nil { return err } @@ -86,7 +94,7 @@ func (fctx *frontingContext) ConfigureWithHello(pool *x509.CertPool, providers m // NewFronted creates a new http.RoundTripper that does direct domain fronting. // If the context isn't configured within the given timeout, this method // returns nil, false. -func (fctx *frontingContext) NewFronted(timeout time.Duration) (http.RoundTripper, bool) { +func (fctx *frontingContext) NewRoundTripper(timeout time.Duration) (http.RoundTripper, bool) { start := time.Now() ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() diff --git a/fronted.go b/fronted.go index 71ed6b9..4092788 100644 --- a/fronted.go +++ b/fronted.go @@ -2,7 +2,6 @@ package fronted import ( "bytes" - "context" "crypto/x509" "errors" "fmt" @@ -34,6 +33,12 @@ var ( log = golog.LoggerFor("fronted") ) +type workingFronts interface { + onConnected(m MasqueradeInterface, connectTime time.Duration) + onError(m MasqueradeInterface) + workingFront() (MasqueradeInterface, error) +} + // fronted identifies working IP address/domain pairings for domain fronting and is // an implementation of http.RoundTripper for the convenience of callers. type fronted struct { @@ -48,11 +53,16 @@ type fronted struct { defaultProviderID string providers map[string]*Provider clientHelloID tls.ClientHelloID + workingFronts workingFronts + providersMu sync.RWMutex + masqueradesMu sync.RWMutex + frontedMu sync.RWMutex } func newFronted(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID, cacheFile string, clientHelloID tls.ClientHelloID, - listener func(f *fronted)) (*fronted, error) { + listener func(f *fronted), + workingFronts workingFronts) (*fronted, error) { size := 0 for _, p := range providers { size += len(p.Masquerades) @@ -63,10 +73,7 @@ func newFronted(pool *x509.CertPool, providers map[string]*Provider, } // copy providers - providersCopy := make(map[string]*Provider, len(providers)) - for k, p := range providers { - providersCopy[k] = NewProvider(p.HostAliases, p.TestURL, p.Masquerades, p.Validator, p.PassthroughPatterns, p.SNIConfig, p.VerifyHostname) - } + providersCopy := copyProviders(providers) f := &fronted{ certPool: pool, @@ -79,6 +86,7 @@ func newFronted(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID: defaultProviderID, providers: providersCopy, clientHelloID: clientHelloID, + workingFronts: workingFronts, } if cacheFile != "" { @@ -89,6 +97,15 @@ func newFronted(pool *x509.CertPool, providers map[string]*Provider, return f, nil } +func copyProviders(providers map[string]*Provider) map[string]*Provider { + providersCopy := make(map[string]*Provider, len(providers)) + for key, p := range providers { + providersCopy[key] = NewProvider(p.HostAliases, p.TestURL, p.Masquerades, p.Validator, p.PassthroughPatterns, p.SNIConfig, p.VerifyHostname) + log.Debugf("Domain fronting provider is %v", providersCopy[key]) + } + return providersCopy +} + func loadMasquerades(initial map[string]*Provider, size int) sortedMasquerades { log.Debugf("Loading candidates for %d providers", len(initial)) defer log.Debug("Finished loading candidates") @@ -114,6 +131,34 @@ func loadMasquerades(initial map[string]*Provider, size int) sortedMasquerades { return masquerades } +func (f *fronted) updateConfig(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string) { + // Make copies just to avoid any concurrency issues with access that may be happening on the + // caller side. + providersCopy := copyProviders(providers) + f.frontedMu.Lock() + defer f.frontedMu.Unlock() + f.addProviders(providersCopy) + f.addMasquerades(loadMasquerades(providersCopy, len(providersCopy))) + f.defaultProviderID = defaultProviderID + f.certPool = pool +} + +func (f *fronted) addProviders(providers map[string]*Provider) { + // Add new providers to the existing providers map, overwriting any existing ones. + f.providersMu.Lock() + defer f.providersMu.Unlock() + for key, p := range providers { + f.providers[key] = p + } +} + +func (f *fronted) addMasquerades(masquerades sortedMasquerades) { + // Add new masquerades to the existing masquerades slice, but add them at the beginning. + f.masqueradesMu.Lock() + defer f.masqueradesMu.Unlock() + f.masquerades = append(masquerades, f.masquerades...) +} + func (f *fronted) providerFor(m MasqueradeInterface) *Provider { pid := m.getProviderID() if pid == "" { @@ -149,35 +194,51 @@ func (f *fronted) findWorkingMasquerades(listener func(f *fronted)) { var successful atomic.Uint32 // We loop through all of them until we have 4 successful ones. - for i := 0; i < len(f.masquerades) && successful.Load() < 4; i += batchSize { + for i := 0; i < f.masqueradeSize() && successful.Load() < 4; i += batchSize { f.vetBatch(i, batchSize, &successful, listener) } } +func (f *fronted) masqueradeSize() int { + f.masqueradesMu.Lock() + defer f.masqueradesMu.Unlock() + return len(f.masquerades) +} + +func (f *fronted) masqueradeAt(i int) MasqueradeInterface { + f.masqueradesMu.Lock() + defer f.masqueradesMu.Unlock() + return f.masquerades[i] +} + func (f *fronted) vetBatch(start, batchSize int, successful *atomic.Uint32, listener func(f *fronted)) { log.Debugf("Vetting masquerade batch %d-%d", start, start+batchSize) var wg sync.WaitGroup - masqueradeSize := len(f.masquerades) - for i := start; i < start+batchSize && i < masqueradeSize; i++ { + for i := start; i < start+batchSize && i < f.masqueradeSize(); i++ { wg.Add(1) go func(m MasqueradeInterface) { defer wg.Done() - if f.vetMasquerade(m) { + working, connectTime := f.vetMasquerade(m) + if working { successful.Add(1) + f.workingFronts.onConnected(m, connectTime) if listener != nil { go listener(f) } + } else { + f.workingFronts.onError(m) } - }(f.masquerades[i]) + }(f.masqueradeAt(i)) } wg.Wait() } -func (f *fronted) vetMasquerade(m MasqueradeInterface) bool { +func (f *fronted) vetMasquerade(m MasqueradeInterface) (bool, time.Duration) { + start := time.Now() conn, masqueradeGood, err := f.dialMasquerade(m) if err != nil { log.Debugf("unexpected error vetting masquerades: %v", err) - return false + return false, time.Since(start) } defer func() { if conn != nil { @@ -189,15 +250,15 @@ func (f *fronted) vetMasquerade(m MasqueradeInterface) bool { if provider == nil { log.Debugf("Skipping masquerade with disabled/unknown provider id '%s' not in %v", m.getProviderID(), f.providers) - return false + return false, time.Since(start) } if !masqueradeGood(m.postCheck(conn, provider.TestURL)) { log.Debugf("Unsuccessful vetting with POST request, discarding masquerade") - return false + return false, time.Since(start) } log.Debugf("Successfully vetted one masquerade %v", m) - return true + return true, time.Since(start) } // RoundTrip loops through all available masquerades, sorted by the one that most recently @@ -242,37 +303,41 @@ func (f *fronted) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, return io.NopCloser(bytes.NewReader(body)) } - tries := 1 - if isIdempotent { - tries = maxTries - } + const tries = 6 for i := 0; i < tries; i++ { if i > 0 { log.Debugf("Retrying domain-fronted request, pass %d", i) } - conn, m, masqueradeGood, err := f.dialAll(req.Context()) + m, err := f.workingFronts.workingFront() if err != nil { - // unable to find good masquerade, fail - op.FailIf(err) - return nil, nil, err + // For some reason we have no working fronts. Sleep for a bit and try again. + time.Sleep(1 * time.Second) + continue } - resp, conn, err := f.validateMasqueradeWithConn(req, conn, m, originHost, getBody, masqueradeGood) + conn, masqueradeGood, err := f.dialMasquerade(m) if err != nil { - log.Debugf("Could not complete request: %v", err) + log.Debugf("Could not dial to %v: %v", m, err) + f.workingFronts.onError(m) continue } - return resp, conn, nil + resp, conn, err := f.request(req, conn, m, originHost, getBody, masqueradeGood) + if err != nil { + log.Debugf("Could not complete request: %v", err) + f.workingFronts.onError(m) + } else { + return resp, conn, nil + } } return nil, nil, op.FailIf(errors.New("could not complete request even with retries")) } -func (f *fronted) validateMasqueradeWithConn(req *http.Request, conn net.Conn, m MasqueradeInterface, originHost string, getBody func() io.ReadCloser, masqueradeGood func(bool) bool) (*http.Response, net.Conn, error) { - op := ops.Begin("validate_masquerade_with_conn") +func (f *fronted) request(req *http.Request, conn net.Conn, m MasqueradeInterface, originHost string, getBody func() io.ReadCloser, masqueradeGood func(bool) bool) (*http.Response, net.Conn, error) { + op := ops.Begin("fronted_request") defer op.End() provider := f.providerFor(m) if provider == nil { @@ -322,61 +387,6 @@ func (f *fronted) validateMasqueradeWithConn(req *http.Request, conn net.Conn, m return resp, conn, nil } -// Dial dials out using all available masquerades until one succeeds. -func (f *fronted) dialAll(ctx context.Context) (net.Conn, MasqueradeInterface, func(bool) bool, error) { - defer func(op ops.Op) { op.End() }(ops.Begin("dial_all")) - // never take more than a minute trying to find a dialer - ctx, cancel := context.WithTimeout(ctx, 1*time.Minute) - defer cancel() - - triedMasquerades := make(map[MasqueradeInterface]bool) - masqueradesToTry := f.masquerades.sortedCopy() - totalMasquerades := len(masqueradesToTry) -dialLoop: - // Loop through up to len(masqueradesToTry) times, trying each masquerade in turn. - // If the context is done, return an error. - for i := 0; i < totalMasquerades; i++ { - select { - case <-ctx.Done(): - log.Debugf("Timed out dialing with %v total masquerades", totalMasquerades) - break dialLoop - default: - // okay - } - - m, err := f.masqueradeToTry(masqueradesToTry, triedMasquerades) - if err != nil { - log.Errorf("No masquerades left to try") - break dialLoop - } - conn, masqueradeGood, err := f.dialMasquerade(m) - triedMasquerades[m] = true - if err != nil { - log.Debugf("Could not dial to %v: %v", m, err) - // As we're looping through the masquerades, each check takes time. As that's happening, - // other goroutines may be successfully vetting new masquerades, which will change the - // sorting. We want to make sure we're always trying the best masquerades first. - masqueradesToTry = f.masquerades.sortedCopy() - totalMasquerades = len(masqueradesToTry) - continue - } - return conn, m, masqueradeGood, nil - } - - return nil, nil, nil, log.Errorf("could not dial any masquerade? tried %v", totalMasquerades) -} - -func (f *fronted) masqueradeToTry(masquerades sortedMasquerades, triedMasquerades map[MasqueradeInterface]bool) (MasqueradeInterface, error) { - for _, m := range masquerades { - if triedMasquerades[m] { - continue - } - return m, nil - } - // This should be quite rare, as it means we've tried typically thousands of masquerades. - return nil, errors.New("no masquerades left to try") -} - func (f *fronted) dialMasquerade(m MasqueradeInterface) (net.Conn, func(bool) bool, error) { log.Tracef("Dialing to %v", m) diff --git a/fronted_test.go b/fronted_test.go index 575af02..49e9e66 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -1,7 +1,6 @@ package fronted import ( - "context" "crypto/x509" "encoding/json" "errors" @@ -58,9 +57,9 @@ func TestDirectDomainFrontingWithSNIConfig(t *testing.T) { ArbitrarySNIs: []string{"mercadopago.com", "amazon.com.br", "facebook.com", "google.com", "twitter.com", "youtube.com", "instagram.com", "linkedin.com", "whatsapp.com", "netflix.com", "microsoft.com", "yahoo.com", "bing.com", "wikipedia.org", "github.com"}, }) testContext := newFrontingContext("TestDirectDomainFrontingWithSNIConfig") - testContext.Configure(certs, p, "akamai", cacheFile) + testContext.configure(certs, p, "akamai", cacheFile) - transport, ok := testContext.NewFronted(30 * time.Second) + transport, ok := testContext.NewRoundTripper(30 * time.Second) require.True(t, ok) client := &http.Client{ Transport: transport, @@ -87,9 +86,9 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE certs := trustedCACerts(t) p := testProvidersWithHosts(hosts) testContext := newFrontingContext("doTestDomainFronting") - testContext.Configure(certs, p, testProviderID, cacheFile) + testContext.configure(certs, p, testProviderID, cacheFile) - transport, ok := testContext.NewFronted(30 * time.Second) + transport, ok := testContext.NewRoundTripper(30 * time.Second) require.True(t, ok) client := &http.Client{ @@ -98,7 +97,7 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE } require.True(t, doCheck(client, http.MethodPost, http.StatusAccepted, pingURL)) - transport, ok = testContext.NewFronted(30 * time.Second) + transport, ok = testContext.NewRoundTripper(30 * time.Second) require.True(t, ok) client = &http.Client{ Transport: transport, @@ -242,9 +241,9 @@ func TestHostAliasesBasic(t *testing.T) { certs.AddCert(cloudSack.Certificate()) testContext := newFrontingContext("TestHostAliasesBasic") - testContext.Configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") + testContext.configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") - rt, ok := testContext.NewFronted(30 * time.Second) + rt, ok := testContext.NewRoundTripper(30 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } @@ -355,8 +354,8 @@ func TestHostAliasesMulti(t *testing.T) { } testContext := newFrontingContext("TestHostAliasesMulti") - testContext.Configure(certs, providers, "cloudsack", "") - rt, ok := testContext.NewFronted(30 * time.Second) + testContext.configure(certs, providers, "cloudsack", "") + rt, ok := testContext.NewRoundTripper(30 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } @@ -482,9 +481,9 @@ func TestPassthrough(t *testing.T) { certs.AddCert(cloudSack.Certificate()) testContext := newFrontingContext("TestPassthrough") - testContext.Configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") + testContext.configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") - rt, ok := testContext.NewFronted(30 * time.Second) + rt, ok := testContext.NewRoundTripper(30 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } @@ -554,7 +553,7 @@ func TestCustomValidators(t *testing.T) { "sadcloud": p, } - ctx.Configure(certs, providers, "sadcloud", "") + ctx.configure(certs, providers, "sadcloud", "") } // This error indicates that the validator has discarded all masquerades. @@ -637,7 +636,7 @@ func TestCustomValidators(t *testing.T) { t.Run(test.name, func(t *testing.T) { testContext := newFrontingContext(test.name) setup(testContext, test.validator) - direct, ok := testContext.NewFronted(30 * time.Second) + direct, ok := testContext.NewRoundTripper(30 * time.Second) require.True(t, ok) client := &http.Client{ Transport: direct, @@ -902,134 +901,6 @@ func TestFindWorkingMasquerades(t *testing.T) { } } -func TestMasqueradeToTry(t *testing.T) { - min := time.Now().Add(-time.Minute) - hour := time.Now().Add(-time.Hour) - domain1 := newMockMasqueradeWithLastSuccess("domain1.com", "1.1.1.1", 0, true, min) - domain2 := newMockMasqueradeWithLastSuccess("domain2.com", "2.2.2.2", 0, true, hour) - tests := []struct { - name string - masquerades sortedMasquerades - triedMasquerades map[MasqueradeInterface]bool - expected MasqueradeInterface - }{ - { - name: "No tried masquerades", - masquerades: sortedMasquerades{ - domain1, - domain2, - }, - triedMasquerades: map[MasqueradeInterface]bool{}, - expected: domain1, - }, - { - name: "Some tried masquerades", - masquerades: sortedMasquerades{ - domain1, - domain2, - }, - triedMasquerades: map[MasqueradeInterface]bool{ - domain1: true, - }, - expected: domain2, - }, - { - name: "All masquerades tried", - masquerades: sortedMasquerades{ - domain1, - domain2, - }, - triedMasquerades: map[MasqueradeInterface]bool{ - domain1: true, - domain2: true, - }, - expected: nil, - }, - { - name: "Empty masquerades list", - masquerades: sortedMasquerades{}, - triedMasquerades: map[MasqueradeInterface]bool{}, - expected: nil, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - f := &fronted{} - masquerades := tt.masquerades.sortedCopy() - result, _ := f.masqueradeToTry(masquerades, tt.triedMasquerades) - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestDialAll(t *testing.T) { - tests := []struct { - name string - masquerades []*mockMasquerade - expectedSuccessful bool - expectedMasquerades int - }{ - { - name: "All successful", - masquerades: []*mockMasquerade{ - newMockMasquerade("domain1.com", "1.1.1.1", 0, true), - newMockMasquerade("domain2.com", "2.2.2.2", 0, true), - newMockMasquerade("domain3.com", "3.3.3.3", 0, true), - newMockMasquerade("domain4.com", "4.4.4.4", 0, true), - }, - expectedSuccessful: true, - }, - { - name: "Some successful", - masquerades: []*mockMasquerade{ - newMockMasquerade("domain1.com", "1.1.1.1", 0, true), - newMockMasquerade("domain2.com", "2.2.2.2", 1*time.Millisecond, false), - newMockMasquerade("domain3.com", "3.3.3.3", 0, true), - newMockMasquerade("domain4.com", "4.4.4.4", 1*time.Millisecond, false), - }, - expectedSuccessful: true, - }, - { - name: "None successful", - masquerades: []*mockMasquerade{ - newMockMasquerade("domain1.com", "1.1.1.1", 1*time.Millisecond, false), - newMockMasquerade("domain2.com", "2.2.2.2", 1*time.Millisecond, false), - newMockMasquerade("domain3.com", "3.3.3.3", 1*time.Millisecond, false), - newMockMasquerade("domain4.com", "4.4.4.4", 1*time.Millisecond, false), - }, - expectedSuccessful: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - d := &fronted{} - d.providers = make(map[string]*Provider) - d.providers["testProviderId"] = NewProvider(nil, "", nil, nil, nil, nil, nil) - d.masquerades = make(sortedMasquerades, len(tt.masquerades)) - for i, m := range tt.masquerades { - d.masquerades[i] = m - } - - ctx := context.Background() - conn, m, masqueradeGood, err := d.dialAll(ctx) - - if tt.expectedSuccessful { - assert.NoError(t, err) - assert.NotNil(t, conn) - assert.NotNil(t, m) - assert.NotNil(t, masqueradeGood) - } else { - assert.Error(t, err) - assert.Nil(t, conn) - assert.Nil(t, m) - assert.Nil(t, masqueradeGood) - } - }) - } -} - // Generate a mock of a MasqueradeInterface with a Dial method that can optionally // return an error after a specified number of milliseconds. func newMockMasquerade(domain string, ipAddress string, timeout time.Duration, passesCheck bool) *mockMasquerade { diff --git a/test_support.go b/test_support.go index 6f787a0..d508a62 100644 --- a/test_support.go +++ b/test_support.go @@ -23,13 +23,13 @@ func ConfigureForTest(t *testing.T) { func ConfigureCachingForTest(t *testing.T, cacheFile string) { certs := trustedCACerts(t) p := testProviders() - Configure(certs, p, testProviderID, cacheFile) + NewFronter(certs, p, testProviderID, cacheFile) } func ConfigureHostAlaisesForTest(t *testing.T, hosts map[string]string) { certs := trustedCACerts(t) p := testProvidersWithHosts(hosts) - Configure(certs, p, testProviderID, "") + NewFronter(certs, p, testProviderID, "") } func trustedCACerts(t *testing.T) *x509.CertPool { From cc002f7b0bcdbf226b1e20e0f3e8e594819de0ef Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Mon, 25 Nov 2024 13:47:47 -0700 Subject: [PATCH 02/25] Naming tweaks --- context.go | 15 +++++++-------- fronted_test.go | 14 +++++++------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/context.go b/context.go index 02853a7..fd1b378 100644 --- a/context.go +++ b/context.go @@ -13,16 +13,16 @@ import ( ) // Create an interface for the fronting context -type Fronting interface { +type Fronted interface { UpdateConfig(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string) - NewRoundTripper(timeout time.Duration) (http.RoundTripper, bool) + NewRoundTripper(timeout time.Duration) (http.RoundTripper, error) Close() } var defaultContext = newFrontingContext("default") // Make sure that the default context is a Fronting -var _ Fronting = defaultContext +var _ Fronted = defaultContext // Configure sets the masquerades to use, the trusted root CAs, and the // cache file for caching masquerades to set up direct domain fronting @@ -30,7 +30,7 @@ var _ Fronting = defaultContext // // defaultProviderID is used when a masquerade without a provider is // encountered (eg in a cache file) -func NewFronter(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string, cacheFile string) (Fronting, error) { +func NewFronted(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string, cacheFile string) (Fronted, error) { if err := defaultContext.configure(pool, providers, defaultProviderID, cacheFile); err != nil { return nil, log.Errorf("Error configuring fronting %s context: %s!!", defaultContext.name, err) } @@ -94,18 +94,17 @@ func (fctx *frontingContext) configureWithHello(pool *x509.CertPool, providers m // NewFronted creates a new http.RoundTripper that does direct domain fronting. // If the context isn't configured within the given timeout, this method // returns nil, false. -func (fctx *frontingContext) NewRoundTripper(timeout time.Duration) (http.RoundTripper, bool) { +func (fctx *frontingContext) NewRoundTripper(timeout time.Duration) (http.RoundTripper, error) { start := time.Now() ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() instance, err := fctx.instance.Get(ctx) if err != nil { - log.Errorf("No DirectHttpClient available within %v for context %s", timeout, fctx.name) - return nil, false + return nil, log.Errorf("No DirectHttpClient available within %v for context %s with error %v", timeout, fctx.name, err) } else { log.Debugf("DirectHttpClient available for context %s after %v with duration %v", fctx.name, time.Since(start), timeout) } - return instance.(http.RoundTripper), true + return instance.(http.RoundTripper), nil } // Close closes any existing cache file in the default contexxt. diff --git a/fronted_test.go b/fronted_test.go index 49e9e66..0a83b9f 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -60,7 +60,7 @@ func TestDirectDomainFrontingWithSNIConfig(t *testing.T) { testContext.configure(certs, p, "akamai", cacheFile) transport, ok := testContext.NewRoundTripper(30 * time.Second) - require.True(t, ok) + require.NoError(t, ok) client := &http.Client{ Transport: transport, } @@ -89,7 +89,7 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE testContext.configure(certs, p, testProviderID, cacheFile) transport, ok := testContext.NewRoundTripper(30 * time.Second) - require.True(t, ok) + require.NoError(t, ok) client := &http.Client{ Transport: transport, @@ -98,7 +98,7 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE require.True(t, doCheck(client, http.MethodPost, http.StatusAccepted, pingURL)) transport, ok = testContext.NewRoundTripper(30 * time.Second) - require.True(t, ok) + require.NoError(t, ok) client = &http.Client{ Transport: transport, } @@ -244,7 +244,7 @@ func TestHostAliasesBasic(t *testing.T) { testContext.configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") rt, ok := testContext.NewRoundTripper(30 * time.Second) - if !assert.True(t, ok, "failed to obtain direct roundtripper") { + if !assert.NoError(t, ok, "failed to obtain direct roundtripper") { return } client := &http.Client{Transport: rt} @@ -356,7 +356,7 @@ func TestHostAliasesMulti(t *testing.T) { testContext := newFrontingContext("TestHostAliasesMulti") testContext.configure(certs, providers, "cloudsack", "") rt, ok := testContext.NewRoundTripper(30 * time.Second) - if !assert.True(t, ok, "failed to obtain direct roundtripper") { + if !assert.NoError(t, ok, "failed to obtain direct roundtripper") { return } client := &http.Client{Transport: rt} @@ -484,7 +484,7 @@ func TestPassthrough(t *testing.T) { testContext.configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") rt, ok := testContext.NewRoundTripper(30 * time.Second) - if !assert.True(t, ok, "failed to obtain direct roundtripper") { + if !assert.NoError(t, ok, "failed to obtain direct roundtripper") { return } client := &http.Client{Transport: rt} @@ -637,7 +637,7 @@ func TestCustomValidators(t *testing.T) { testContext := newFrontingContext(test.name) setup(testContext, test.validator) direct, ok := testContext.NewRoundTripper(30 * time.Second) - require.True(t, ok) + require.NoError(t, ok) client := &http.Client{ Transport: direct, } From 41251ede8ea40b16dc11133ca4f23d00f1f1acf8 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Tue, 26 Nov 2024 13:42:12 -0700 Subject: [PATCH 03/25] fix for test compile error --- test_support.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test_support.go b/test_support.go index d508a62..cd615f7 100644 --- a/test_support.go +++ b/test_support.go @@ -23,13 +23,13 @@ func ConfigureForTest(t *testing.T) { func ConfigureCachingForTest(t *testing.T, cacheFile string) { certs := trustedCACerts(t) p := testProviders() - NewFronter(certs, p, testProviderID, cacheFile) + NewFronted(certs, p, testProviderID, cacheFile) } func ConfigureHostAlaisesForTest(t *testing.T, hosts map[string]string) { certs := trustedCACerts(t) p := testProvidersWithHosts(hosts) - NewFronter(certs, p, testProviderID, "") + NewFronted(certs, p, testProviderID, "") } func trustedCACerts(t *testing.T) *x509.CertPool { From 008e81074d38f2a6ba0864f1132b104ef9d676c5 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Sat, 30 Nov 2024 12:25:02 -0700 Subject: [PATCH 04/25] remove log --- fronted.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fronted.go b/fronted.go index 4092788..78f8da2 100644 --- a/fronted.go +++ b/fronted.go @@ -101,7 +101,6 @@ func copyProviders(providers map[string]*Provider) map[string]*Provider { providersCopy := make(map[string]*Provider, len(providers)) for key, p := range providers { providersCopy[key] = NewProvider(p.HostAliases, p.TestURL, p.Masquerades, p.Validator, p.PassthroughPatterns, p.SNIConfig, p.VerifyHostname) - log.Debugf("Domain fronting provider is %v", providersCopy[key]) } return providersCopy } @@ -134,6 +133,7 @@ func loadMasquerades(initial map[string]*Provider, size int) sortedMasquerades { func (f *fronted) updateConfig(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string) { // Make copies just to avoid any concurrency issues with access that may be happening on the // caller side. + log.Debug("Updating fronted configuration") providersCopy := copyProviders(providers) f.frontedMu.Lock() defer f.frontedMu.Unlock() From 9ee0721dd6394b6bd0f0397ace5f4fbe993a11b4 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Sat, 30 Nov 2024 12:27:53 -0700 Subject: [PATCH 05/25] tweak else statement --- context.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/context.go b/context.go index fd1b378..6a8ae6a 100644 --- a/context.go +++ b/context.go @@ -81,11 +81,10 @@ func (fctx *frontingContext) configureWithHello(pool *x509.CertPool, providers m } var err error - fctx.fronted, err = newFronted(pool, providers, defaultProviderID, cacheFile, clientHelloID, func(f *fronted) { + if fctx.fronted, err = newFronted(pool, providers, defaultProviderID, cacheFile, clientHelloID, func(f *fronted) { log.Debug("Setting fronted instance") fctx.instance.Set(f) - }, fctx.connectingFronts) - if err != nil { + }, fctx.connectingFronts); err != nil { return err } return nil From cfa8beba29d42df9a85edd73920c29815e6899d6 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Sat, 30 Nov 2024 12:39:58 -0700 Subject: [PATCH 06/25] tone down logging --- fronted.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/fronted.go b/fronted.go index 78f8da2..8519741 100644 --- a/fronted.go +++ b/fronted.go @@ -257,7 +257,7 @@ func (f *fronted) vetMasquerade(m MasqueradeInterface) (bool, time.Duration) { return false, time.Since(start) } - log.Debugf("Successfully vetted one masquerade %v", m) + log.Debugf("Successfully vetted one masquerade %v", m.getIpAddress()) return true, time.Since(start) } @@ -393,6 +393,7 @@ func (f *fronted) dialMasquerade(m MasqueradeInterface) (net.Conn, func(bool) bo // We do the full TLS connection here because in practice the domains at a given IP // address can change frequently on CDNs, so the certificate may not match what // we expect. + start := time.Now() conn, retriable, err := f.doDial(m) masqueradeGood := func(good bool) bool { if good { @@ -404,7 +405,7 @@ func (f *fronted) dialMasquerade(m MasqueradeInterface) (net.Conn, func(bool) bo return good } if err == nil { - log.Debugf("Returning connection for masquerade: %v", m) + log.Debugf("Returning connection for masquerade %v in %v", m.getIpAddress(), time.Since(start)) return conn, masqueradeGood, err } else if !retriable { log.Debugf("Dropping masquerade: non retryable error: %v", err) @@ -423,7 +424,6 @@ func (f *fronted) doDial(m MasqueradeInterface) (net.Conn, bool, error) { var conn net.Conn var err error retriable := false - start := time.Now() conn, err = m.dial(f.certPool, f.clientHelloID) if err != nil { if !isNetworkUnreachable(err) { @@ -440,8 +440,6 @@ func (f *fronted) doDial(m MasqueradeInterface) (net.Conn, bool, error) { log.Debugf("Unexpected error dialing, keeping masquerade: %v", err) retriable = true } - } else { - log.Debugf("Got successful connection to: %+v in %v", m, time.Since(start)) } return conn, retriable, err } From bd8e9398ef5e5447b2c4346c4721138ff7e7e032 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Sat, 30 Nov 2024 12:54:39 -0700 Subject: [PATCH 07/25] use go test fmt and check for nil param --- .github/workflows/test.yaml | 22 +++++++++++++++++++--- fronted.go | 6 ++++-- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 9e29ed5..f90544b 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -13,11 +13,27 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: go-version-file: "go.mod" - - name: Run unit tests - run: go test -failfast -coverprofile=profile.cov + - name: Set up gotestfmt + uses: gotesttools/gotestfmt-action@v2 + with: + # Optional: pass GITHUB_TOKEN to avoid rate limiting. + token: ${{ secrets.GITHUB_TOKEN }} + - name: Run tests + run: | + set -euo pipefail + go test -json -race -failfast -tags="headless" -coverprofile=profile.cov -v ./... 2>&1 | tee /tmp/gotest.log | gotestfmt -nofail + + # Upload the original go test log as an artifact for later review. + - name: Upload test log + uses: actions/upload-artifact@v3 + if: always() + with: + name: test-log + path: /tmp/gotest.log + if-no-files-found: error - name: Install goveralls run: go install github.com/mattn/goveralls@latest - name: Send coverage diff --git a/fronted.go b/fronted.go index 8519741..aa8d564 100644 --- a/fronted.go +++ b/fronted.go @@ -61,8 +61,10 @@ type fronted struct { func newFronted(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID, cacheFile string, clientHelloID tls.ClientHelloID, - listener func(f *fronted), - workingFronts workingFronts) (*fronted, error) { + listener func(f *fronted), workingFronts workingFronts) (*fronted, error) { + if workingFronts == nil { + return nil, fmt.Errorf("workingFronts must not be nil") + } size := 0 for _, p := range providers { size += len(p.Masquerades) From 2a65be367112b65172c407b1e806777ca731de6d Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Sat, 30 Nov 2024 13:30:21 -0700 Subject: [PATCH 08/25] try new test format --- .github/workflows/test.yaml | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index f90544b..983233f 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -16,16 +16,15 @@ jobs: uses: actions/setup-go@v5 with: go-version-file: "go.mod" - - name: Set up gotestfmt - uses: gotesttools/gotestfmt-action@v2 - with: - # Optional: pass GITHUB_TOKEN to avoid rate limiting. - token: ${{ secrets.GITHUB_TOKEN }} + - name: Install go go-ctrf-json-reporter + run: go install github.com/ctrf-io/go-ctrf-json-reporter/cmd/go-ctrf-json-reporter@latest - name: Run tests run: | set -euo pipefail - go test -json -race -failfast -tags="headless" -coverprofile=profile.cov -v ./... 2>&1 | tee /tmp/gotest.log | gotestfmt -nofail - + go test -json -race -failfast -tags="headless" -coverprofile=profile.cov ./... | go-ctrf-json-reporter -output ctrf-report.json + - name: Run CTRF annotations + run: npx github-actions-ctrf ctrf-report.json + if: always() # Upload the original go test log as an artifact for later review. - name: Upload test log uses: actions/upload-artifact@v3 From 7f93522f89192857634c3d522d10677925165622 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Sat, 30 Nov 2024 13:46:06 -0700 Subject: [PATCH 09/25] Another try at cleaner test output --- .github/workflows/test.yaml | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 983233f..d7fe2cb 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -22,17 +22,11 @@ jobs: run: | set -euo pipefail go test -json -race -failfast -tags="headless" -coverprofile=profile.cov ./... | go-ctrf-json-reporter -output ctrf-report.json - - name: Run CTRF annotations - run: npx github-actions-ctrf ctrf-report.json - if: always() - # Upload the original go test log as an artifact for later review. - - name: Upload test log - uses: actions/upload-artifact@v3 - if: always() + - name: Upload test results + uses: actions/upload-artifact@v4 with: - name: test-log - path: /tmp/gotest.log - if-no-files-found: error + name: ctrf-report + path: ctrf-report.json - name: Install goveralls run: go install github.com/mattn/goveralls@latest - name: Send coverage From b2c702b5dc3e04db130cd328b43ebcf9a278452d Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Sat, 30 Nov 2024 13:52:44 -0700 Subject: [PATCH 10/25] do not fail fast --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index d7fe2cb..0843e94 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -21,7 +21,7 @@ jobs: - name: Run tests run: | set -euo pipefail - go test -json -race -failfast -tags="headless" -coverprofile=profile.cov ./... | go-ctrf-json-reporter -output ctrf-report.json + go test -json -race -tags="headless" -coverprofile=profile.cov ./... | go-ctrf-json-reporter -output ctrf-report.json - name: Upload test results uses: actions/upload-artifact@v4 with: From 800f367e3f5ad88923a8bc654809026651396a72 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Sat, 30 Nov 2024 13:57:17 -0700 Subject: [PATCH 11/25] no pipefail --- .github/workflows/test.yaml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 0843e94..aafc61f 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -19,9 +19,7 @@ jobs: - name: Install go go-ctrf-json-reporter run: go install github.com/ctrf-io/go-ctrf-json-reporter/cmd/go-ctrf-json-reporter@latest - name: Run tests - run: | - set -euo pipefail - go test -json -race -tags="headless" -coverprofile=profile.cov ./... | go-ctrf-json-reporter -output ctrf-report.json + run: go test -json -race -tags="headless" -coverprofile=profile.cov ./... | go-ctrf-json-reporter -output ctrf-report.json - name: Upload test results uses: actions/upload-artifact@v4 with: From 55e5821dea3d6e31aa3b0b51dacccfecd9e396a3 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Sat, 30 Nov 2024 14:09:54 -0700 Subject: [PATCH 12/25] install node --- .github/workflows/test.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index aafc61f..aeae5b5 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -12,6 +12,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + - name: Install dependencies + run: npm install - name: Set up Go uses: actions/setup-go@v5 with: @@ -25,6 +27,8 @@ jobs: with: name: ctrf-report path: ctrf-report.json + - name: Publish Test Summary Results + run: npx github-actions-ctrf ctrf-report.json - name: Install goveralls run: go install github.com/mattn/goveralls@latest - name: Send coverage From 12a44503691aca774a45a32fc168c18a91a20d18 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Sat, 30 Nov 2024 14:18:01 -0700 Subject: [PATCH 13/25] add node --- .github/workflows/test.yaml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index aeae5b5..46e76ac 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -12,8 +12,9 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - name: Install dependencies - run: npm install + - uses: actions/setup-node@v4 + with: + node-version: 18 - name: Set up Go uses: actions/setup-go@v5 with: From 6b6fc13be04cfcfa2c2e40f750cf4ed524e7e751 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Tue, 3 Dec 2024 10:56:45 -0700 Subject: [PATCH 14/25] Lots of cleanups to more cleanly handle continually finding working fronts --- cache.go | 9 +- cache_test.go | 38 +++--- connecting_fronts.go | 69 ++++------ connecting_fronts_test.go | 39 ++++++ context.go | 121 ----------------- masquerade.go => front.go | 48 ++++--- masquerade_test.go => front_test.go | 0 fronted.go | 193 +++++++++++++++------------ fronted_test.go | 198 +++++++++++++++------------- go.mod | 1 - go.sum | 2 - test_support.go | 5 +- 12 files changed, 336 insertions(+), 387 deletions(-) create mode 100644 connecting_fronts_test.go delete mode 100644 context.go rename masquerade.go => front.go (90%) rename masquerade_test.go => front_test.go (100%) diff --git a/cache.go b/cache.go index 9df34d5..4af5fa5 100644 --- a/cache.go +++ b/cache.go @@ -27,7 +27,7 @@ func (d *fronted) prepopulateMasquerades(cacheFile string) { } log.Debugf("Attempting to prepopulate masquerades from cache file: %v", cacheFile) - var cachedMasquerades []*masquerade + var cachedMasquerades []*front if err := json.Unmarshal(bytes, &cachedMasquerades); err != nil { log.Errorf("Error reading cached masquerades: %v", err) return @@ -37,7 +37,7 @@ func (d *fronted) prepopulateMasquerades(cacheFile string) { now := time.Now() // update last succeeded status of masquerades based on cached values - for _, m := range d.masquerades { + for _, m := range d.fronts { for _, cm := range cachedMasquerades { sameMasquerade := cm.ProviderID == m.getProviderID() && cm.Domain == m.getDomain() && cm.IpAddress == m.getIpAddress() cachedValueFresh := now.Sub(m.lastSucceeded()) < d.maxAllowedCachedAge @@ -75,7 +75,7 @@ func (d *fronted) maintainCache(cacheFile string) { func (d *fronted) updateCache(cacheFile string) { log.Debugf("Updating cache at %v", cacheFile) - cache := d.masquerades.sortedCopy() + cache := d.fronts.sortedCopy() sizeToSave := len(cache) if d.maxCacheSize < sizeToSave { sizeToSave = d.maxCacheSize @@ -101,8 +101,9 @@ func (d *fronted) updateCache(cacheFile string) { } } -func (d *fronted) closeCache() { +func (d *fronted) Close() { d.closeCacheOnce.Do(func() { close(d.cacheClosed) }) + d.stopCh <- nil } diff --git a/cache_test.go b/cache_test.go index 9b874d1..ba18747 100644 --- a/cache_test.go +++ b/cache_test.go @@ -26,9 +26,9 @@ func TestCaching(t *testing.T) { cloudsackID: NewProvider(nil, "", nil, nil, nil, nil, nil), } - makeDirect := func() *fronted { - d := &fronted{ - masquerades: make(sortedMasquerades, 0, 1000), + makeFronted := func() *fronted { + f := &fronted{ + fronts: make(sortedFronts, 0, 1000), maxAllowedCachedAge: 250 * time.Millisecond, maxCacheSize: 4, cacheSaveInterval: 50 * time.Millisecond, @@ -37,20 +37,20 @@ func TestCaching(t *testing.T) { providers: providers, defaultProviderID: cloudsackID, } - go d.maintainCache(cacheFile) - return d + go f.maintainCache(cacheFile) + return f } now := time.Now() - mb := &masquerade{Masquerade: Masquerade{Domain: "b", IpAddress: "2"}, LastSucceeded: now, ProviderID: testProviderID} - mc := &masquerade{Masquerade: Masquerade{Domain: "c", IpAddress: "3"}, LastSucceeded: now, ProviderID: ""} // defaulted - md := &masquerade{Masquerade: Masquerade{Domain: "d", IpAddress: "4"}, LastSucceeded: now, ProviderID: "sadcloud"} // skipped + mb := &front{Masquerade: Masquerade{Domain: "b", IpAddress: "2"}, LastSucceeded: now, ProviderID: testProviderID} + mc := &front{Masquerade: Masquerade{Domain: "c", IpAddress: "3"}, LastSucceeded: now, ProviderID: ""} // defaulted + md := &front{Masquerade: Masquerade{Domain: "d", IpAddress: "4"}, LastSucceeded: now, ProviderID: "sadcloud"} // skipped - d := makeDirect() - d.masquerades = append(d.masquerades, mb, mc, md) + f := makeFronted() + f.fronts = append(f.fronts, mb, mc, md) - readCached := func() []*masquerade { - var result []*masquerade + readCached := func() []*front { + var result []*front b, err := os.ReadFile(cacheFile) require.NoError(t, err, "Unable to read cache file") err = json.Unmarshal(b, &result) @@ -59,22 +59,22 @@ func TestCaching(t *testing.T) { } // Save the cache - d.markCacheDirty() - time.Sleep(d.cacheSaveInterval * 2) - d.closeCache() + f.markCacheDirty() + time.Sleep(f.cacheSaveInterval * 2) + f.Close() time.Sleep(50 * time.Millisecond) // Reopen cache file and make sure right data was in there - d = makeDirect() - d.prepopulateMasquerades(cacheFile) + f = makeFronted() + f.prepopulateMasquerades(cacheFile) masquerades := readCached() require.Len(t, masquerades, 3, "Wrong number of masquerades read") - for i, expected := range []*masquerade{mb, mc, md} { + for i, expected := range []*front{mb, mc, md} { require.Equal(t, expected.Domain, masquerades[i].Domain, "Wrong masquerade at position %d", i) require.Equal(t, expected.IpAddress, masquerades[i].IpAddress, "Masquerade at position %d has wrong IpAddress", 0) require.Equal(t, expected.ProviderID, masquerades[i].ProviderID, "Masquerade at position %d has wrong ProviderID", 0) require.Equal(t, now.Unix(), masquerades[i].LastSucceeded.Unix(), "Masquerade at position %d has wrong LastSucceeded", 0) } - d.closeCache() + f.Close() } diff --git a/connecting_fronts.go b/connecting_fronts.go index ec4611d..8e2147f 100644 --- a/connecting_fronts.go +++ b/connecting_fronts.go @@ -1,68 +1,53 @@ package fronted import ( - "errors" - "sort" - "sync" - "time" + "context" ) -type connectTimeFront struct { - MasqueradeInterface - connectTime time.Duration +type workingFronts interface { + onConnected(m Front) + connectingFront(context.Context) (Front, error) + size() int } type connectingFronts struct { - fronts []connectTimeFront - //frontsChan chan MasqueradeInterface - sync.RWMutex + // Create a channel of fronts that are connecting. + frontsCh chan Front } // Make sure that connectingFronts is a connectListener var _ workingFronts = &connectingFronts{} // newConnectingFronts creates a new ConnectingFronts struct with an empty slice of Masquerade IPs and domains. -func newConnectingFronts() *connectingFronts { +func newConnectingFronts(size int) *connectingFronts { return &connectingFronts{ - fronts: make([]connectTimeFront, 0), - //frontsChan: make(chan MasqueradeInterface), + // We overallocate the channel to avoid blocking. + frontsCh: make(chan Front, size), } } // AddFront adds a new front to the list of fronts. -func (cf *connectingFronts) onConnected(m MasqueradeInterface, connectTime time.Duration) { - cf.Lock() - defer cf.Unlock() - - cf.fronts = append(cf.fronts, connectTimeFront{ - MasqueradeInterface: m, - connectTime: connectTime, - }) - // Sort fronts by connect time. - sort.Slice(cf.fronts, func(i, j int) bool { - return cf.fronts[i].connectTime < cf.fronts[j].connectTime - }) - //cf.frontsChan <- m +func (cf *connectingFronts) onConnected(m Front) { + cf.frontsCh <- m } -func (cf *connectingFronts) onError(m MasqueradeInterface) { - cf.Lock() - defer cf.Unlock() - - // Remove the front from connecting fronts. - for i, front := range cf.fronts { - if front.MasqueradeInterface == m { - cf.fronts = append(cf.fronts[:i], cf.fronts[i+1:]...) - return +func (cf *connectingFronts) connectingFront(ctx context.Context) (Front, error) { + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case m := <-cf.frontsCh: + // The front may have stopped succeeding since we last checked, + // so only return it if it's still succeeding. + if m.isSucceeding() { + // Add the front back to the channel. + cf.frontsCh <- m + return m, nil + } } } } -func (cf *connectingFronts) workingFront() (MasqueradeInterface, error) { - cf.RLock() - defer cf.RUnlock() - if len(cf.fronts) == 0 { - return nil, errors.New("no fronts available") - } - return cf.fronts[0].MasqueradeInterface, nil +func (cf *connectingFronts) size() int { + return len(cf.frontsCh) } diff --git a/connecting_fronts_test.go b/connecting_fronts_test.go new file mode 100644 index 0000000..23e3fa9 --- /dev/null +++ b/connecting_fronts_test.go @@ -0,0 +1,39 @@ +package fronted + +import ( + "testing" +) + +func TestConnectingFrontsSize(t *testing.T) { + tests := []struct { + name string + setup func() *connectingFronts + expected int + }{ + { + name: "empty channel", + setup: func() *connectingFronts { + return newConnectingFronts(10) + }, + expected: 0, + }, + { + name: "non-empty channel", + setup: func() *connectingFronts { + cf := newConnectingFronts(10) + cf.onConnected(&mockFront{}) + return cf + }, + expected: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cf := tt.setup() + if got := cf.size(); got != tt.expected { + t.Errorf("size() = %d, want %d", got, tt.expected) + } + }) + } +} diff --git a/context.go b/context.go deleted file mode 100644 index 6a8ae6a..0000000 --- a/context.go +++ /dev/null @@ -1,121 +0,0 @@ -package fronted - -import ( - "context" - "crypto/x509" - "fmt" - "net/http" - "time" - - tls "github.com/refraction-networking/utls" - - "github.com/getlantern/eventual/v2" -) - -// Create an interface for the fronting context -type Fronted interface { - UpdateConfig(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string) - NewRoundTripper(timeout time.Duration) (http.RoundTripper, error) - Close() -} - -var defaultContext = newFrontingContext("default") - -// Make sure that the default context is a Fronting -var _ Fronted = defaultContext - -// Configure sets the masquerades to use, the trusted root CAs, and the -// cache file for caching masquerades to set up direct domain fronting -// in the default context. -// -// defaultProviderID is used when a masquerade without a provider is -// encountered (eg in a cache file) -func NewFronted(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string, cacheFile string) (Fronted, error) { - if err := defaultContext.configure(pool, providers, defaultProviderID, cacheFile); err != nil { - return nil, log.Errorf("Error configuring fronting %s context: %s!!", defaultContext.name, err) - } - return defaultContext, nil -} - -func newFrontingContext(name string) *frontingContext { - return &frontingContext{ - name: name, - instance: eventual.NewValue(), - connectingFronts: newConnectingFronts(), - } -} - -type frontingContext struct { - name string - instance eventual.Value - fronted *fronted - connectingFronts *connectingFronts -} - -// UpdateConfig updates the configuration of the fronting context -func (fctx *frontingContext) UpdateConfig(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string) { - fctx.fronted.updateConfig(pool, providers, defaultProviderID) -} - -// configure sets the masquerades to use, the trusted root CAs, and the -// cache file for caching masquerades to set up direct domain fronting. -// defaultProviderID is used when a masquerade without a provider is -// encountered (eg in a cache file) -func (fctx *frontingContext) configure(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string, cacheFile string) error { - return fctx.configureWithHello(pool, providers, defaultProviderID, cacheFile, tls.ClientHelloID{}) -} - -func (fctx *frontingContext) configureWithHello(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string, cacheFile string, clientHelloID tls.ClientHelloID) error { - log.Debugf("Configuring fronted %s context", fctx.name) - - if len(providers) == 0 { - return fmt.Errorf("no fronted providers for %s context", fctx.name) - } - - if _existing, err := fctx.instance.Get(eventual.DontWait); err != nil { - log.Debugf("No existing instance for %s context: %s", fctx.name, err) - } else if _existing != nil { - existing := _existing.(*fronted) - log.Debugf("Closing cache from existing instance for %s context", fctx.name) - existing.closeCache() - } - - var err error - if fctx.fronted, err = newFronted(pool, providers, defaultProviderID, cacheFile, clientHelloID, func(f *fronted) { - log.Debug("Setting fronted instance") - fctx.instance.Set(f) - }, fctx.connectingFronts); err != nil { - return err - } - return nil -} - -// NewFronted creates a new http.RoundTripper that does direct domain fronting. -// If the context isn't configured within the given timeout, this method -// returns nil, false. -func (fctx *frontingContext) NewRoundTripper(timeout time.Duration) (http.RoundTripper, error) { - start := time.Now() - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - instance, err := fctx.instance.Get(ctx) - if err != nil { - return nil, log.Errorf("No DirectHttpClient available within %v for context %s with error %v", timeout, fctx.name, err) - } else { - log.Debugf("DirectHttpClient available for context %s after %v with duration %v", fctx.name, time.Since(start), timeout) - } - return instance.(http.RoundTripper), nil -} - -// Close closes any existing cache file in the default contexxt. -func (fctx *frontingContext) Close() { - _existing, err := fctx.instance.Get(eventual.DontWait) - if err != nil { - log.Errorf("Error getting existing instance for %s context: %s", fctx.name, err) - return - } - if _existing != nil { - existing := _existing.(*fronted) - log.Debugf("Closing cache from existing instance in %s context", fctx.name) - existing.closeCache() - } -} diff --git a/masquerade.go b/front.go similarity index 90% rename from masquerade.go rename to front.go index 0f8dce7..4e50404 100644 --- a/masquerade.go +++ b/front.go @@ -53,7 +53,7 @@ type Masquerade struct { } // Create a masquerade interface for easier testing. -type MasqueradeInterface interface { +type Front interface { dial(rootCAs *x509.CertPool, clientHelloID tls.ClientHelloID) (net.Conn, error) // Accessor for the domain of the masquerade @@ -73,9 +73,11 @@ type MasqueradeInterface interface { postCheck(net.Conn, string) bool getProviderID() string + + isSucceeding() bool } -type masquerade struct { +type front struct { Masquerade // lastSucceeded: the most recent time at which this Masquerade succeeded LastSucceeded time.Time @@ -84,7 +86,7 @@ type masquerade struct { mx sync.RWMutex } -func (m *masquerade) dial(rootCAs *x509.CertPool, clientHelloID tls.ClientHelloID) (net.Conn, error) { +func (m *front) dial(rootCAs *x509.CertPool, clientHelloID tls.ClientHelloID) (net.Conn, error) { tlsConfig := &tls.Config{ ServerName: m.Domain, RootCAs: rootCAs, @@ -120,7 +122,7 @@ func (m *masquerade) dial(rootCAs *x509.CertPool, clientHelloID tls.ClientHelloI } // postCheck does a post with invalid data to verify domain-fronting works -func (m *masquerade) postCheck(conn net.Conn, testURL string) bool { +func (m *front) postCheck(conn net.Conn, testURL string) bool { client := &http.Client{ Transport: frontedHTTPTransport(conn, true), } @@ -162,55 +164,61 @@ func doCheck(client *http.Client, method string, expectedStatus int, u string) b } // getDomain implements MasqueradeInterface. -func (m *masquerade) getDomain() string { +func (m *front) getDomain() string { return m.Domain } // getIpAddress implements MasqueradeInterface. -func (m *masquerade) getIpAddress() string { +func (m *front) getIpAddress() string { return m.IpAddress } // getProviderID implements MasqueradeInterface. -func (m *masquerade) getProviderID() string { +func (m *front) getProviderID() string { return m.ProviderID } // MarshalJSON marshals masquerade into json -func (m *masquerade) MarshalJSON() ([]byte, error) { +func (m *front) MarshalJSON() ([]byte, error) { m.mx.RLock() defer m.mx.RUnlock() // Type alias for masquerade so that we don't infinitely recurse when marshaling the struct - type alias masquerade + type alias front return json.Marshal((*alias)(m)) } -func (m *masquerade) lastSucceeded() time.Time { +func (m *front) lastSucceeded() time.Time { m.mx.RLock() defer m.mx.RUnlock() return m.LastSucceeded } -func (m *masquerade) setLastSucceeded(t time.Time) { +func (m *front) setLastSucceeded(t time.Time) { m.mx.Lock() defer m.mx.Unlock() m.LastSucceeded = t } -func (m *masquerade) markSucceeded() { +func (m *front) markSucceeded() { m.mx.Lock() defer m.mx.Unlock() m.LastSucceeded = time.Now() } -func (m *masquerade) markFailed() { +func (m *front) markFailed() { m.mx.Lock() defer m.mx.Unlock() m.LastSucceeded = time.Time{} } +func (m *front) isSucceeding() bool { + m.mx.RLock() + defer m.mx.RUnlock() + return m.LastSucceeded.After(time.Time{}) +} + // Make sure that the masquerade struct implements the MasqueradeInterface -var _ MasqueradeInterface = (*masquerade)(nil) +var _ Front = (*front)(nil) // A Direct fronting provider configuration. type Provider struct { @@ -343,11 +351,11 @@ func NewStatusCodeValidator(reject []int) ResponseValidator { } // slice of masquerade sorted by last vetted time -type sortedMasquerades []MasqueradeInterface +type sortedFronts []Front -func (m sortedMasquerades) Len() int { return len(m) } -func (m sortedMasquerades) Swap(i, j int) { m[i], m[j] = m[j], m[i] } -func (m sortedMasquerades) Less(i, j int) bool { +func (m sortedFronts) Len() int { return len(m) } +func (m sortedFronts) Swap(i, j int) { m[i], m[j] = m[j], m[i] } +func (m sortedFronts) Less(i, j int) bool { if m[i].lastSucceeded().After(m[j].lastSucceeded()) { return true } else if m[j].lastSucceeded().After(m[i].lastSucceeded()) { @@ -357,8 +365,8 @@ func (m sortedMasquerades) Less(i, j int) bool { } } -func (m sortedMasquerades) sortedCopy() sortedMasquerades { - c := make(sortedMasquerades, len(m)) +func (m sortedFronts) sortedCopy() sortedFronts { + c := make(sortedFronts, len(m)) copy(c, m) sort.Sort(c) return c diff --git a/masquerade_test.go b/front_test.go similarity index 100% rename from masquerade_test.go rename to front_test.go diff --git a/fronted.go b/fronted.go index aa8d564..f8becc5 100644 --- a/fronted.go +++ b/fronted.go @@ -12,7 +12,6 @@ import ( "net/url" "strings" "sync" - "sync/atomic" "syscall" "time" @@ -33,17 +32,11 @@ var ( log = golog.LoggerFor("fronted") ) -type workingFronts interface { - onConnected(m MasqueradeInterface, connectTime time.Duration) - onError(m MasqueradeInterface) - workingFront() (MasqueradeInterface, error) -} - // fronted identifies working IP address/domain pairings for domain fronting and is // an implementation of http.RoundTripper for the convenience of callers. type fronted struct { certPool *x509.CertPool - masquerades sortedMasquerades + fronts sortedFronts maxAllowedCachedAge time.Duration maxCacheSize int cacheSaveInterval time.Duration @@ -55,31 +48,46 @@ type fronted struct { clientHelloID tls.ClientHelloID workingFronts workingFronts providersMu sync.RWMutex - masqueradesMu sync.RWMutex + frontsMu sync.RWMutex frontedMu sync.RWMutex + stopCh chan interface{} } -func newFronted(pool *x509.CertPool, providers map[string]*Provider, - defaultProviderID, cacheFile string, clientHelloID tls.ClientHelloID, - listener func(f *fronted), workingFronts workingFronts) (*fronted, error) { - if workingFronts == nil { - return nil, fmt.Errorf("workingFronts must not be nil") - } - size := 0 - for _, p := range providers { - size += len(p.Masquerades) - } +// Interface for sending HTTP traffic over domain fronting. +type Fronted interface { + http.RoundTripper - if size == 0 { - return nil, fmt.Errorf("no masquerades found in providers") + // UpdateConfig updates the set of domain fronts to try. + UpdateConfig(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string) + + // Close closes any resources, such as goroutines that are testing fronts. + Close() +} + +// NewFronted sets the domain fronts to use, the trusted root CAs, the fronting providers +// (such as Akamai, Cloudfront, etc), and the cache file for caching fronts to set up +// domain fronting. +// +// defaultProviderID is used when a front without a provider is +// encountered (eg in a cache file) +func NewFronted(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string, cacheFile string, + clientHello tls.ClientHelloID) (Fronted, error) { + log.Debug("Creating new fronted") + // Log method elapsed time + defer func(start time.Time) { + log.Debugf("Creating a new fronted took %v", time.Since(start)) + }(time.Now()) + + if len(providers) == 0 { + return nil, log.Errorf("No providers configured") } - // copy providers providersCopy := copyProviders(providers) + fronts := loadFronts(providersCopy) f := &fronted{ certPool: pool, - masquerades: loadMasquerades(providersCopy, size), + fronts: fronts, maxAllowedCachedAge: defaultMaxAllowedCachedAge, maxCacheSize: defaultMaxCacheSize, cacheSaveInterval: defaultCacheSaveInterval, @@ -87,14 +95,15 @@ func newFronted(pool *x509.CertPool, providers map[string]*Provider, cacheClosed: make(chan interface{}), defaultProviderID: defaultProviderID, providers: providersCopy, - clientHelloID: clientHelloID, - workingFronts: workingFronts, + clientHelloID: clientHello, + workingFronts: newConnectingFronts(len(fronts)), + stopCh: make(chan interface{}), } if cacheFile != "" { f.initCaching(cacheFile) } - f.findWorkingMasquerades(listener) + go f.findWorkingFronts() return f, nil } @@ -107,15 +116,24 @@ func copyProviders(providers map[string]*Provider) map[string]*Provider { return providersCopy } -func loadMasquerades(initial map[string]*Provider, size int) sortedMasquerades { - log.Debugf("Loading candidates for %d providers", len(initial)) +func loadFronts(providers map[string]*Provider) sortedFronts { + log.Debugf("Loading candidates for %d providers", len(providers)) defer log.Debug("Finished loading candidates") - masquerades := make(sortedMasquerades, 0, size) - for key, p := range initial { + // Preallocate the slice to avoid reallocation + size := 0 + for _, p := range providers { + size += len(p.Masquerades) + } + + fronts := make(sortedFronts, size) + + index := 0 + for key, p := range providers { arr := p.Masquerades size := len(arr) + // Shuffle the masquerades to avoid biasing the order in which they are tried // make a shuffled copy of arr // ('inside-out' Fisher-Yates) sh := make([]*Masquerade, size) @@ -126,13 +144,14 @@ func loadMasquerades(initial map[string]*Provider, size int) sortedMasquerades { } for _, c := range sh { - masquerades = append(masquerades, &masquerade{Masquerade: *c, ProviderID: key}) + fronts[index] = &front{Masquerade: *c, ProviderID: key} + index++ } } - return masquerades + return fronts } -func (f *fronted) updateConfig(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string) { +func (f *fronted) UpdateConfig(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string) { // Make copies just to avoid any concurrency issues with access that may be happening on the // caller side. log.Debug("Updating fronted configuration") @@ -140,7 +159,7 @@ func (f *fronted) updateConfig(pool *x509.CertPool, providers map[string]*Provid f.frontedMu.Lock() defer f.frontedMu.Unlock() f.addProviders(providersCopy) - f.addMasquerades(loadMasquerades(providersCopy, len(providersCopy))) + f.addFronts(loadFronts(providersCopy)) f.defaultProviderID = defaultProviderID f.certPool = pool } @@ -154,14 +173,14 @@ func (f *fronted) addProviders(providers map[string]*Provider) { } } -func (f *fronted) addMasquerades(masquerades sortedMasquerades) { +func (f *fronted) addFronts(fronts sortedFronts) { // Add new masquerades to the existing masquerades slice, but add them at the beginning. - f.masqueradesMu.Lock() - defer f.masqueradesMu.Unlock() - f.masquerades = append(masquerades, f.masquerades...) + f.frontsMu.Lock() + defer f.frontsMu.Unlock() + f.fronts = append(fronts, f.fronts...) } -func (f *fronted) providerFor(m MasqueradeInterface) *Provider { +func (f *fronted) providerFor(m Front) *Provider { pid := m.getProviderID() if pid == "" { pid = f.defaultProviderID @@ -177,7 +196,7 @@ func Vet(m *Masquerade, pool *x509.CertPool, testURL string) bool { maxAllowedCachedAge: defaultMaxAllowedCachedAge, maxCacheSize: defaultMaxCacheSize, } - masq := &masquerade{Masquerade: *m} + masq := &front{Masquerade: *m} conn, _, err := d.doDial(masq) if err != nil { return false @@ -186,61 +205,75 @@ func Vet(m *Masquerade, pool *x509.CertPool, testURL string) bool { return masq.postCheck(conn, testURL) } -// findWorkingMasquerades finds working masquerades by vetting them in batches and in -// parallel. Speed is of the essence here, as without working masquerades, users will +// findWorkingFronts finds working domain fronts by vetting them in batches and in +// parallel. Speed is of the essence here, as without working fronts, users will // be unable to fetch proxy configurations, particularly in the case of a first time // user who does not have proxies cached on disk. -func (f *fronted) findWorkingMasquerades(listener func(f *fronted)) { - // vet masquerades in batches +func (f *fronted) findWorkingFronts() { + // vet fronts in batches const batchSize int = 40 - var successful atomic.Uint32 - // We loop through all of them until we have 4 successful ones. - for i := 0; i < f.masqueradeSize() && successful.Load() < 4; i += batchSize { - f.vetBatch(i, batchSize, &successful, listener) + // Keep looping through all fronts making sure we have working ones. + i := 0 + for { + // Continually loop through the fronts in batches until we have 4 working ones, + // always looping around to the beginning if we reach the end. + // This is important, for example, when the user goes offline and all fronts start failing. + // We want to just keep trying in that case so that we find working fronts as soon as they + // come back online. + if f.workingFronts.size() < 4 { + f.vetBatch(i, batchSize) + i = index(i, batchSize, f.frontSize()) + } else { + select { + case <-f.stopCh: + log.Debug("Stopping parallel dialing") + return + case <-time.After(time.Duration(rand.IntN(12000)) * time.Millisecond): + } + } } } -func (f *fronted) masqueradeSize() int { - f.masqueradesMu.Lock() - defer f.masqueradesMu.Unlock() - return len(f.masquerades) +func index(i, batchSize, size int) int { + return (i + batchSize) % size } -func (f *fronted) masqueradeAt(i int) MasqueradeInterface { - f.masqueradesMu.Lock() - defer f.masqueradesMu.Unlock() - return f.masquerades[i] +func (f *fronted) frontSize() int { + f.frontsMu.Lock() + defer f.frontsMu.Unlock() + return len(f.fronts) } -func (f *fronted) vetBatch(start, batchSize int, successful *atomic.Uint32, listener func(f *fronted)) { +func (f *fronted) frontAt(i int) Front { + f.frontsMu.Lock() + defer f.frontsMu.Unlock() + return f.fronts[i] +} + +func (f *fronted) vetBatch(start, batchSize int) { log.Debugf("Vetting masquerade batch %d-%d", start, start+batchSize) var wg sync.WaitGroup - for i := start; i < start+batchSize && i < f.masqueradeSize(); i++ { + for i := start; i < start+batchSize && i < f.frontSize(); i++ { wg.Add(1) - go func(m MasqueradeInterface) { + go func(m Front) { defer wg.Done() - working, connectTime := f.vetMasquerade(m) + working := f.vetFront(m) if working { - successful.Add(1) - f.workingFronts.onConnected(m, connectTime) - if listener != nil { - go listener(f) - } + f.workingFronts.onConnected(m) } else { - f.workingFronts.onError(m) + m.markFailed() } - }(f.masqueradeAt(i)) + }(f.frontAt(i)) } wg.Wait() } -func (f *fronted) vetMasquerade(m MasqueradeInterface) (bool, time.Duration) { - start := time.Now() - conn, masqueradeGood, err := f.dialMasquerade(m) +func (f *fronted) vetFront(m Front) bool { + conn, masqueradeGood, err := f.dialFront(m) if err != nil { log.Debugf("unexpected error vetting masquerades: %v", err) - return false, time.Since(start) + return false } defer func() { if conn != nil { @@ -252,15 +285,15 @@ func (f *fronted) vetMasquerade(m MasqueradeInterface) (bool, time.Duration) { if provider == nil { log.Debugf("Skipping masquerade with disabled/unknown provider id '%s' not in %v", m.getProviderID(), f.providers) - return false, time.Since(start) + return false } if !masqueradeGood(m.postCheck(conn, provider.TestURL)) { log.Debugf("Unsuccessful vetting with POST request, discarding masquerade") - return false, time.Since(start) + return false } log.Debugf("Successfully vetted one masquerade %v", m.getIpAddress()) - return true, time.Since(start) + return true } // RoundTrip loops through all available masquerades, sorted by the one that most recently @@ -312,24 +345,22 @@ func (f *fronted) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, log.Debugf("Retrying domain-fronted request, pass %d", i) } - m, err := f.workingFronts.workingFront() + m, err := f.workingFronts.connectingFront(req.Context()) if err != nil { // For some reason we have no working fronts. Sleep for a bit and try again. time.Sleep(1 * time.Second) continue } - conn, masqueradeGood, err := f.dialMasquerade(m) + conn, masqueradeGood, err := f.dialFront(m) if err != nil { log.Debugf("Could not dial to %v: %v", m, err) - f.workingFronts.onError(m) continue } resp, conn, err := f.request(req, conn, m, originHost, getBody, masqueradeGood) if err != nil { log.Debugf("Could not complete request: %v", err) - f.workingFronts.onError(m) } else { return resp, conn, nil } @@ -338,7 +369,7 @@ func (f *fronted) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, return nil, nil, op.FailIf(errors.New("could not complete request even with retries")) } -func (f *fronted) request(req *http.Request, conn net.Conn, m MasqueradeInterface, originHost string, getBody func() io.ReadCloser, masqueradeGood func(bool) bool) (*http.Response, net.Conn, error) { +func (f *fronted) request(req *http.Request, conn net.Conn, m Front, originHost string, getBody func() io.ReadCloser, masqueradeGood func(bool) bool) (*http.Response, net.Conn, error) { op := ops.Begin("fronted_request") defer op.End() provider := f.providerFor(m) @@ -389,7 +420,7 @@ func (f *fronted) request(req *http.Request, conn net.Conn, m MasqueradeInterfac return resp, conn, nil } -func (f *fronted) dialMasquerade(m MasqueradeInterface) (net.Conn, func(bool) bool, error) { +func (f *fronted) dialFront(m Front) (net.Conn, func(bool) bool, error) { log.Tracef("Dialing to %v", m) // We do the full TLS connection here because in practice the domains at a given IP @@ -416,7 +447,7 @@ func (f *fronted) dialMasquerade(m MasqueradeInterface) (net.Conn, func(bool) bo return conn, masqueradeGood, err } -func (f *fronted) doDial(m MasqueradeInterface) (net.Conn, bool, error) { +func (f *fronted) doDial(m Front) (net.Conn, bool, error) { op := ops.Begin("dial_masquerade") defer op.End() op.Set("masquerade_domain", m.getDomain()) diff --git a/fronted_test.go b/fronted_test.go index 0a83b9f..a504ca2 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -15,11 +15,9 @@ import ( "path/filepath" "strconv" "strings" - "sync/atomic" "testing" "time" - "github.com/getlantern/eventual/v2" . "github.com/getlantern/waitforserver" tls "github.com/refraction-networking/utls" "github.com/stretchr/testify/assert" @@ -56,11 +54,8 @@ func TestDirectDomainFrontingWithSNIConfig(t *testing.T) { UseArbitrarySNIs: true, ArbitrarySNIs: []string{"mercadopago.com", "amazon.com.br", "facebook.com", "google.com", "twitter.com", "youtube.com", "instagram.com", "linkedin.com", "whatsapp.com", "netflix.com", "microsoft.com", "yahoo.com", "bing.com", "wikipedia.org", "github.com"}, }) - testContext := newFrontingContext("TestDirectDomainFrontingWithSNIConfig") - testContext.configure(certs, p, "akamai", cacheFile) + transport, err := NewFronted(certs, p, "akamai", cacheFile, tls.HelloChrome_100) - transport, ok := testContext.NewRoundTripper(30 * time.Second) - require.NoError(t, ok) client := &http.Client{ Transport: transport, } @@ -85,11 +80,8 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE } certs := trustedCACerts(t) p := testProvidersWithHosts(hosts) - testContext := newFrontingContext("doTestDomainFronting") - testContext.configure(certs, p, testProviderID, cacheFile) - - transport, ok := testContext.NewRoundTripper(30 * time.Second) - require.NoError(t, ok) + transport, err := NewFronted(certs, p, testProviderID, cacheFile, tls.HelloChrome_100) + require.NoError(t, err) client := &http.Client{ Transport: transport, @@ -97,21 +89,19 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE } require.True(t, doCheck(client, http.MethodPost, http.StatusAccepted, pingURL)) - transport, ok = testContext.NewRoundTripper(30 * time.Second) - require.NoError(t, ok) + transport, err = NewFronted(certs, p, testProviderID, cacheFile, tls.HelloChrome_100) + require.NoError(t, err) client = &http.Client{ Transport: transport, } require.True(t, doCheck(client, http.MethodGet, http.StatusOK, getURL)) - instance, err := testContext.instance.Get(eventual.DontWait) - require.NoError(t, err) - d := instance.(*fronted) + d := transport.(*fronted) // Check the number of masquerades at the end, waiting until we get the right number masqueradesAtEnd := 0 for i := 0; i < 1000; i++ { - masqueradesAtEnd = len(d.masquerades) + masqueradesAtEnd = len(d.fronts) if masqueradesAtEnd == expectedMasqueradesAtEnd { break } @@ -131,33 +121,6 @@ func TestVet(t *testing.T) { t.Fatal("None of the default masquerades vetted successfully") } -func TestLoadMasquerades(t *testing.T) { - providers := testProviders() - - expected := make(map[Masquerade]bool) - for _, p := range providers { - for _, m := range p.Masquerades { - expected[*m] = true - } - } - - newMasquerades := loadMasquerades(providers, len(expected)) - - d := &fronted{ - masquerades: newMasquerades, - } - - actual := make(map[Masquerade]bool) - count := 0 - for _, m := range d.masquerades { - actual[Masquerade{Domain: m.getDomain(), IpAddress: m.getIpAddress()}] = true - count++ - } - - assert.Equal(t, len(DefaultCloudfrontMasquerades), count, "Unexpected number of candidates") - assert.Equal(t, expected, actual, "Masquerades did not load as expected") -} - func TestHostAliasesBasic(t *testing.T) { headersIn := map[string][]string{ @@ -240,13 +203,8 @@ func TestHostAliasesBasic(t *testing.T) { certs := x509.NewCertPool() certs.AddCert(cloudSack.Certificate()) - testContext := newFrontingContext("TestHostAliasesBasic") - testContext.configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") + rt, err := NewFronted(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "", tls.HelloChrome_100) - rt, ok := testContext.NewRoundTripper(30 * time.Second) - if !assert.NoError(t, ok, "failed to obtain direct roundtripper") { - return - } client := &http.Client{Transport: rt} for _, test := range tests { req, err := http.NewRequest(http.MethodGet, test.url, nil) @@ -353,12 +311,8 @@ func TestHostAliasesMulti(t *testing.T) { "sadcloud": p2, } - testContext := newFrontingContext("TestHostAliasesMulti") - testContext.configure(certs, providers, "cloudsack", "") - rt, ok := testContext.NewRoundTripper(30 * time.Second) - if !assert.NoError(t, ok, "failed to obtain direct roundtripper") { - return - } + rt, err := NewFronted(certs, providers, "cloudsack", "", tls.HelloChrome_100) + client := &http.Client{Transport: rt} providerCounts := make(map[string]int) @@ -480,13 +434,9 @@ func TestPassthrough(t *testing.T) { certs := x509.NewCertPool() certs.AddCert(cloudSack.Certificate()) - testContext := newFrontingContext("TestPassthrough") - testContext.configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") + rt, err := NewFronted(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "", tls.HelloChrome_100) + require.NoError(t, err) - rt, ok := testContext.NewRoundTripper(30 * time.Second) - if !assert.NoError(t, ok, "failed to obtain direct roundtripper") { - return - } client := &http.Client{Transport: rt} for _, test := range tests { req, err := http.NewRequest(http.MethodGet, test.url, nil) @@ -539,7 +489,7 @@ func TestCustomValidators(t *testing.T) { sadCloudValidator := NewStatusCodeValidator(sadCloudCodes) testURL := "https://abc.forbidden.com/quux" - setup := func(ctx *frontingContext, validator ResponseValidator) { + setup := func(validator ResponseValidator) (Fronted, error) { masq := []*Masquerade{{Domain: "example.com", IpAddress: sadCloudAddr}} alias := map[string]string{ "abc.forbidden.com": "abc.sadcloud.io", @@ -553,7 +503,7 @@ func TestCustomValidators(t *testing.T) { "sadcloud": p, } - ctx.configure(certs, providers, "sadcloud", "") + return NewFronted(certs, providers, "sadcloud", "", tls.HelloChrome_100) } // This error indicates that the validator has discarded all masquerades. @@ -634,10 +584,8 @@ func TestCustomValidators(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - testContext := newFrontingContext(test.name) - setup(testContext, test.validator) - direct, ok := testContext.NewRoundTripper(30 * time.Second) - require.NoError(t, ok) + direct, err := setup(test.validator) + require.NoError(t, err) client := &http.Client{ Transport: direct, } @@ -827,13 +775,13 @@ func TestVerifyPeerCertificate(t *testing.T) { func TestFindWorkingMasquerades(t *testing.T) { tests := []struct { name string - masquerades []*mockMasquerade + masquerades []*mockFront expectedSuccessful int expectedMasquerades int }{ { name: "All successful", - masquerades: []*mockMasquerade{ + masquerades: []*mockFront{ newMockMasquerade("domain1.com", "1.1.1.1", 0, true), newMockMasquerade("domain2.com", "2.2.2.2", 0, true), newMockMasquerade("domain3.com", "3.3.3.3", 0, true), @@ -845,7 +793,7 @@ func TestFindWorkingMasquerades(t *testing.T) { }, { name: "Some successful", - masquerades: []*mockMasquerade{ + masquerades: []*mockFront{ newMockMasquerade("domain1.com", "1.1.1.1", 0, true), newMockMasquerade("domain2.com", "2.2.2.2", 0, false), newMockMasquerade("domain3.com", "3.3.3.3", 0, true), @@ -856,7 +804,7 @@ func TestFindWorkingMasquerades(t *testing.T) { }, { name: "None successful", - masquerades: []*mockMasquerade{ + masquerades: []*mockFront{ newMockMasquerade("domain1.com", "1.1.1.1", 0, false), newMockMasquerade("domain2.com", "2.2.2.2", 0, false), newMockMasquerade("domain3.com", "3.3.3.3", 0, false), @@ -866,8 +814,8 @@ func TestFindWorkingMasquerades(t *testing.T) { }, { name: "Batch processing", - masquerades: func() []*mockMasquerade { - var masquerades []*mockMasquerade + masquerades: func() []*mockFront { + var masquerades []*mockFront for i := 0; i < 50; i++ { masquerades = append(masquerades, newMockMasquerade(fmt.Sprintf("domain%d.com", i), fmt.Sprintf("1.1.1.%d", i), 0, i%2 == 0)) } @@ -879,38 +827,93 @@ func TestFindWorkingMasquerades(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - d := &fronted{} + d := &fronted{ + workingFronts: newConnectingFronts(10), + } d.providers = make(map[string]*Provider) d.providers["testProviderId"] = NewProvider(nil, "", nil, nil, nil, nil, nil) - d.masquerades = make(sortedMasquerades, len(tt.masquerades)) + d.fronts = make(sortedFronts, len(tt.masquerades)) for i, m := range tt.masquerades { - d.masquerades[i] = m + d.fronts[i] = m } - var successful atomic.Uint32 - d.vetBatch(0, 10, &successful, nil) + d.vetBatch(0, 10) tries := 0 - for successful.Load() < uint32(tt.expectedSuccessful) && tries < 100 { + for d.workingFronts.size() < tt.expectedSuccessful && tries < 100 { time.Sleep(30 * time.Millisecond) tries++ } - assert.GreaterOrEqual(t, int(successful.Load()), tt.expectedSuccessful) + assert.GreaterOrEqual(t, d.workingFronts.size(), tt.expectedSuccessful) + }) + } +} + +func TestLoadFronts(t *testing.T) { + providers := map[string]*Provider{ + "provider1": { + Masquerades: []*Masquerade{ + {Domain: "domain1.com", IpAddress: "1.1.1.1"}, + {Domain: "domain2.com", IpAddress: "2.2.2.2"}, + }, + }, + "provider2": { + Masquerades: []*Masquerade{ + {Domain: "domain3.com", IpAddress: "3.3.3.3"}, + {Domain: "domain4.com", IpAddress: "4.4.4.4"}, + }, + }, + } + + expected := map[string]bool{ + "domain1.com": true, + "domain2.com": true, + "domain3.com": true, + "domain4.com": true, + } + + masquerades := loadFronts(providers) + + assert.Equal(t, 4, len(masquerades), "Unexpected number of masquerades loaded") + + for _, m := range masquerades { + assert.True(t, expected[m.getDomain()], "Unexpected masquerade domain: %s", m.getDomain()) + } +} + +func TestIndex(t *testing.T) { + tests := []struct { + i, batchSize, size int + expected int + }{ + {i: 0, batchSize: 10, size: 100, expected: 10}, + {i: 5, batchSize: 10, size: 100, expected: 15}, + {i: 95, batchSize: 10, size: 100, expected: 5}, + {i: 99, batchSize: 10, size: 100, expected: 9}, + {i: 0, batchSize: 5, size: 20, expected: 5}, + {i: 15, batchSize: 5, size: 20, expected: 0}, + {i: 18, batchSize: 5, size: 20, expected: 3}, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("i=%d,batchSize=%d,size=%d", test.i, test.batchSize, test.size), func(t *testing.T) { + result := index(test.i, test.batchSize, test.size) + assert.Equal(t, test.expected, result) }) } } // Generate a mock of a MasqueradeInterface with a Dial method that can optionally // return an error after a specified number of milliseconds. -func newMockMasquerade(domain string, ipAddress string, timeout time.Duration, passesCheck bool) *mockMasquerade { +func newMockMasquerade(domain string, ipAddress string, timeout time.Duration, passesCheck bool) *mockFront { return newMockMasqueradeWithLastSuccess(domain, ipAddress, timeout, passesCheck, time.Time{}) } // Generate a mock of a MasqueradeInterface with a Dial method that can optionally // return an error after a specified number of milliseconds. -func newMockMasqueradeWithLastSuccess(domain string, ipAddress string, timeout time.Duration, passesCheck bool, lastSucceededTime time.Time) *mockMasquerade { - return &mockMasquerade{ +func newMockMasqueradeWithLastSuccess(domain string, ipAddress string, timeout time.Duration, passesCheck bool, lastSucceededTime time.Time) *mockFront { + return &mockFront{ Domain: domain, IpAddress: ipAddress, timeout: timeout, @@ -919,7 +922,7 @@ func newMockMasqueradeWithLastSuccess(domain string, ipAddress string, timeout t } } -type mockMasquerade struct { +type mockFront struct { Domain string IpAddress string timeout time.Duration @@ -928,22 +931,27 @@ type mockMasquerade struct { } // setLastSucceeded implements MasqueradeInterface. -func (m *mockMasquerade) setLastSucceeded(succeededTime time.Time) { +func (m *mockFront) setLastSucceeded(succeededTime time.Time) { m.lastSucceededTime = succeededTime } // lastSucceeded implements MasqueradeInterface. -func (m *mockMasquerade) lastSucceeded() time.Time { +func (m *mockFront) lastSucceeded() time.Time { return m.lastSucceededTime } +// isSucceeding implements MasqueradeInterface. +func (m *mockFront) isSucceeding() bool { + return m.lastSucceededTime.After(time.Time{}) +} + // postCheck implements MasqueradeInterface. -func (m *mockMasquerade) postCheck(net.Conn, string) bool { +func (m *mockFront) postCheck(net.Conn, string) bool { return m.passesCheck } // dial implements MasqueradeInterface. -func (m *mockMasquerade) dial(rootCAs *x509.CertPool, clientHelloID tls.ClientHelloID) (net.Conn, error) { +func (m *mockFront) dial(rootCAs *x509.CertPool, clientHelloID tls.ClientHelloID) (net.Conn, error) { if m.timeout > 0 { time.Sleep(m.timeout) return nil, errors.New("mock dial error") @@ -953,28 +961,28 @@ func (m *mockMasquerade) dial(rootCAs *x509.CertPool, clientHelloID tls.ClientHe } // getDomain implements MasqueradeInterface. -func (m *mockMasquerade) getDomain() string { +func (m *mockFront) getDomain() string { return m.Domain } // getIpAddress implements MasqueradeInterface. -func (m *mockMasquerade) getIpAddress() string { +func (m *mockFront) getIpAddress() string { return m.IpAddress } // getProviderID implements MasqueradeInterface. -func (m *mockMasquerade) getProviderID() string { +func (m *mockFront) getProviderID() string { return "testProviderId" } // markFailed implements MasqueradeInterface. -func (m *mockMasquerade) markFailed() { +func (m *mockFront) markFailed() { } // markSucceeded implements MasqueradeInterface. -func (m *mockMasquerade) markSucceeded() { +func (m *mockFront) markSucceeded() { } // Make sure that the mockMasquerade implements the MasqueradeInterface -var _ MasqueradeInterface = (*mockMasquerade)(nil) +var _ Front = (*mockFront)(nil) diff --git a/go.mod b/go.mod index d0f6eab..8a50f55 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/getlantern/fronted go 1.22.3 require ( - github.com/getlantern/eventual/v2 v2.0.2 github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7 github.com/getlantern/keyman v0.0.0-20180207174507-f55e7280e93a github.com/getlantern/netx v0.0.0-20210806160745-b824e2cad607 diff --git a/go.sum b/go.sum index f8b45c6..9450ea9 100644 --- a/go.sum +++ b/go.sum @@ -37,8 +37,6 @@ github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7/go.mod h1:l+xpFB github.com/getlantern/errors v1.0.1/go.mod h1:l+xpFBrCtDLpK9qNjxs+cHU6+BAdlBaxHqikB6Lku3A= github.com/getlantern/errors v1.0.3 h1:Ne4Ycj7NI1BtSyAfVeAT/DNoxz7/S2BUc3L2Ht1YSHE= github.com/getlantern/errors v1.0.3/go.mod h1:m8C7H1qmouvsGpwQqk/6NUpIVMpfzUPn608aBZDYV04= -github.com/getlantern/eventual/v2 v2.0.2 h1:7b3N2oQBVqSHwm/8u7C8b6W+OkkjgZSmwUc1AdIkrHc= -github.com/getlantern/eventual/v2 v2.0.2/go.mod h1:o1VZHRk8UArBra+pwPSi23WrahBG4qgg4/ew6Mmlq84= github.com/getlantern/fdcount v0.0.0-20190912142506-f89afd7367c4 h1:JdD4XSaT6/j6InM7MT1E4WRvzR8gurxfq53A3ML3B/Q= github.com/getlantern/fdcount v0.0.0-20190912142506-f89afd7367c4/go.mod h1:XZwE+iIlAgr64OFbXKFNCllBwV4wEipPx8Hlo2gZdbM= github.com/getlantern/filepersist v0.0.0-20160317154340-c5f0cd24e799 h1:FhkPUYCQYmoxS02r2GRrIV7dahUIncRl36xzs3/mnjA= diff --git a/test_support.go b/test_support.go index cd615f7..55c936c 100644 --- a/test_support.go +++ b/test_support.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/getlantern/keyman" + tls "github.com/refraction-networking/utls" ) var ( @@ -23,13 +24,13 @@ func ConfigureForTest(t *testing.T) { func ConfigureCachingForTest(t *testing.T, cacheFile string) { certs := trustedCACerts(t) p := testProviders() - NewFronted(certs, p, testProviderID, cacheFile) + NewFronted(certs, p, testProviderID, cacheFile, tls.HelloChrome_100) } func ConfigureHostAlaisesForTest(t *testing.T, hosts map[string]string) { certs := trustedCACerts(t) p := testProvidersWithHosts(hosts) - NewFronted(certs, p, testProviderID, "") + NewFronted(certs, p, testProviderID, "", tls.HelloChrome_100) } func trustedCACerts(t *testing.T) *x509.CertPool { From 9b577b9a1d7456a4dbb78c7a3e6962c1bb94ee04 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Tue, 3 Dec 2024 11:38:16 -0700 Subject: [PATCH 15/25] Change to only modify global config details via update call --- fronted.go | 44 +++++++++++++++++++++++++------------------- fronted_test.go | 28 +++++++++++++++++++++------- test_support.go | 12 ++++++++++-- 3 files changed, 56 insertions(+), 28 deletions(-) diff --git a/fronted.go b/fronted.go index f8becc5..6c7cad6 100644 --- a/fronted.go +++ b/fronted.go @@ -51,6 +51,7 @@ type fronted struct { frontsMu sync.RWMutex frontedMu sync.RWMutex stopCh chan interface{} + crawlOnce sync.Once } // Interface for sending HTTP traffic over domain fronting. @@ -58,7 +59,7 @@ type Fronted interface { http.RoundTripper // UpdateConfig updates the set of domain fronts to try. - UpdateConfig(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string) + UpdateConfig(pool *x509.CertPool, providers map[string]*Provider) // Close closes any resources, such as goroutines that are testing fronts. Close() @@ -70,40 +71,31 @@ type Fronted interface { // // defaultProviderID is used when a front without a provider is // encountered (eg in a cache file) -func NewFronted(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string, cacheFile string, - clientHello tls.ClientHelloID) (Fronted, error) { +func NewFronted(cacheFile string, clientHello tls.ClientHelloID, defaultProviderID string) (Fronted, error) { log.Debug("Creating new fronted") // Log method elapsed time defer func(start time.Time) { log.Debugf("Creating a new fronted took %v", time.Since(start)) }(time.Now()) - if len(providers) == 0 { - return nil, log.Errorf("No providers configured") - } - - providersCopy := copyProviders(providers) - fronts := loadFronts(providersCopy) - f := &fronted{ - certPool: pool, - fronts: fronts, + certPool: nil, + fronts: make(sortedFronts, 0), maxAllowedCachedAge: defaultMaxAllowedCachedAge, maxCacheSize: defaultMaxCacheSize, cacheSaveInterval: defaultCacheSaveInterval, cacheDirty: make(chan interface{}, 1), cacheClosed: make(chan interface{}), - defaultProviderID: defaultProviderID, - providers: providersCopy, + providers: make(map[string]*Provider), clientHelloID: clientHello, - workingFronts: newConnectingFronts(len(fronts)), + workingFronts: newConnectingFronts(4000), stopCh: make(chan interface{}), + defaultProviderID: defaultProviderID, } if cacheFile != "" { f.initCaching(cacheFile) } - go f.findWorkingFronts() return f, nil } @@ -151,17 +143,25 @@ func loadFronts(providers map[string]*Provider) sortedFronts { return fronts } -func (f *fronted) UpdateConfig(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string) { +func (f *fronted) UpdateConfig(pool *x509.CertPool, providers map[string]*Provider) { // Make copies just to avoid any concurrency issues with access that may be happening on the // caller side. log.Debug("Updating fronted configuration") + if len(providers) == 0 { + log.Errorf("No providers configured") + return + } providersCopy := copyProviders(providers) f.frontedMu.Lock() defer f.frontedMu.Unlock() f.addProviders(providersCopy) f.addFronts(loadFronts(providersCopy)) - f.defaultProviderID = defaultProviderID + f.certPool = pool + + f.crawlOnce.Do(func() { + go f.findWorkingFronts() + }) } func (f *fronted) addProviders(providers map[string]*Provider) { @@ -457,7 +457,7 @@ func (f *fronted) doDial(m Front) (net.Conn, bool, error) { var conn net.Conn var err error retriable := false - conn, err = m.dial(f.certPool, f.clientHelloID) + conn, err = m.dial(f.getCertPool(), f.clientHelloID) if err != nil { if !isNetworkUnreachable(err) { op.FailIf(err) @@ -477,6 +477,12 @@ func (f *fronted) doDial(m Front) (net.Conn, bool, error) { return conn, retriable, err } +func (f *fronted) getCertPool() *x509.CertPool { + f.frontedMu.RLock() + defer f.frontedMu.RUnlock() + return f.certPool +} + func isNetworkUnreachable(err error) bool { var opErr *net.OpError if errors.As(err, &opErr) { diff --git a/fronted_test.go b/fronted_test.go index a504ca2..b6a21bb 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -54,7 +54,9 @@ func TestDirectDomainFrontingWithSNIConfig(t *testing.T) { UseArbitrarySNIs: true, ArbitrarySNIs: []string{"mercadopago.com", "amazon.com.br", "facebook.com", "google.com", "twitter.com", "youtube.com", "instagram.com", "linkedin.com", "whatsapp.com", "netflix.com", "microsoft.com", "yahoo.com", "bing.com", "wikipedia.org", "github.com"}, }) - transport, err := NewFronted(certs, p, "akamai", cacheFile, tls.HelloChrome_100) + transport, err := NewFronted(cacheFile, tls.HelloChrome_100, "akamai") + require.NoError(t, err) + transport.UpdateConfig(certs, p) client := &http.Client{ Transport: transport, @@ -80,8 +82,9 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE } certs := trustedCACerts(t) p := testProvidersWithHosts(hosts) - transport, err := NewFronted(certs, p, testProviderID, cacheFile, tls.HelloChrome_100) + transport, err := NewFronted(cacheFile, tls.HelloChrome_100, testProviderID) require.NoError(t, err) + transport.UpdateConfig(certs, p) client := &http.Client{ Transport: transport, @@ -89,8 +92,9 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE } require.True(t, doCheck(client, http.MethodPost, http.StatusAccepted, pingURL)) - transport, err = NewFronted(certs, p, testProviderID, cacheFile, tls.HelloChrome_100) + transport, err = NewFronted(cacheFile, tls.HelloChrome_100, testProviderID) require.NoError(t, err) + transport.UpdateConfig(certs, p) client = &http.Client{ Transport: transport, } @@ -203,7 +207,9 @@ func TestHostAliasesBasic(t *testing.T) { certs := x509.NewCertPool() certs.AddCert(cloudSack.Certificate()) - rt, err := NewFronted(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "", tls.HelloChrome_100) + rt, err := NewFronted("", tls.HelloChrome_100, "cloudsack") + require.NoError(t, err) + rt.UpdateConfig(certs, map[string]*Provider{"cloudsack": p}) client := &http.Client{Transport: rt} for _, test := range tests { @@ -311,7 +317,9 @@ func TestHostAliasesMulti(t *testing.T) { "sadcloud": p2, } - rt, err := NewFronted(certs, providers, "cloudsack", "", tls.HelloChrome_100) + rt, err := NewFronted("", tls.HelloChrome_100, "cloudsack") + require.NoError(t, err) + rt.UpdateConfig(certs, providers) client := &http.Client{Transport: rt} @@ -434,8 +442,9 @@ func TestPassthrough(t *testing.T) { certs := x509.NewCertPool() certs.AddCert(cloudSack.Certificate()) - rt, err := NewFronted(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "", tls.HelloChrome_100) + rt, err := NewFronted("", tls.HelloChrome_100, "cloudsack") require.NoError(t, err) + rt.UpdateConfig(certs, map[string]*Provider{"cloudsack": p}) client := &http.Client{Transport: rt} for _, test := range tests { @@ -503,7 +512,12 @@ func TestCustomValidators(t *testing.T) { "sadcloud": p, } - return NewFronted(certs, providers, "sadcloud", "", tls.HelloChrome_100) + f, err := NewFronted("", tls.HelloChrome_100, "sadcloud") + if err != nil { + return nil, err + } + f.UpdateConfig(certs, providers) + return f, nil } // This error indicates that the validator has discarded all masquerades. diff --git a/test_support.go b/test_support.go index 55c936c..9d3a2b9 100644 --- a/test_support.go +++ b/test_support.go @@ -24,13 +24,21 @@ func ConfigureForTest(t *testing.T) { func ConfigureCachingForTest(t *testing.T, cacheFile string) { certs := trustedCACerts(t) p := testProviders() - NewFronted(certs, p, testProviderID, cacheFile, tls.HelloChrome_100) + f, err := NewFronted(cacheFile, tls.HelloChrome_100, testProviderID) + if err != nil { + t.Fatalf("Unable to create fronted: %v", err) + } + f.UpdateConfig(certs, p) } func ConfigureHostAlaisesForTest(t *testing.T, hosts map[string]string) { certs := trustedCACerts(t) p := testProvidersWithHosts(hosts) - NewFronted(certs, p, testProviderID, "", tls.HelloChrome_100) + f, err := NewFronted("", tls.HelloChrome_100, testProviderID) + if err != nil { + t.Fatalf("Unable to create fronted: %v", err) + } + f.UpdateConfig(certs, p) } func trustedCACerts(t *testing.T) *x509.CertPool { From 988625bd5f5ab02bfb68436717e8a50a5188a360 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Tue, 3 Dec 2024 13:06:29 -0700 Subject: [PATCH 16/25] no eventual --- go.mod | 1 - go.sum | 1 - 2 files changed, 2 deletions(-) diff --git a/go.mod b/go.mod index e392179..4951da0 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,6 @@ go 1.22.6 toolchain go1.23.3 require ( - github.com/getlantern/eventual/v2 v2.0.2 github.com/getlantern/golog v0.0.0-20230503153817-8e72de7e0a65 github.com/getlantern/keyman v0.0.0-20200819205636-76fef27c39f1 github.com/getlantern/netx v0.0.0-20240814210628-0984f52e2d18 diff --git a/go.sum b/go.sum index f69619d..53e51c6 100644 --- a/go.sum +++ b/go.sum @@ -82,7 +82,6 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= From 7eddee1a57555b3adb9037726c41522ed15223d1 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Tue, 3 Dec 2024 13:24:18 -0700 Subject: [PATCH 17/25] Fix test --- cache.go | 14 ++++++++------ cache_test.go | 9 ++++++++- fronted.go | 2 +- fronted_test.go | 17 +++++++++-------- 4 files changed, 26 insertions(+), 16 deletions(-) diff --git a/cache.go b/cache.go index 4af5fa5..076bb5a 100644 --- a/cache.go +++ b/cache.go @@ -8,11 +8,11 @@ import ( ) func (d *fronted) initCaching(cacheFile string) { - d.prepopulateMasquerades(cacheFile) + d.prepopulateFronts(cacheFile) go d.maintainCache(cacheFile) } -func (d *fronted) prepopulateMasquerades(cacheFile string) { +func (d *fronted) prepopulateFronts(cacheFile string) { bytes, err := os.ReadFile(cacheFile) if err != nil { // This is not a big deal since we'll just fill the cache later @@ -27,18 +27,18 @@ func (d *fronted) prepopulateMasquerades(cacheFile string) { } log.Debugf("Attempting to prepopulate masquerades from cache file: %v", cacheFile) - var cachedMasquerades []*front - if err := json.Unmarshal(bytes, &cachedMasquerades); err != nil { + var cachedFronts []*front + if err := json.Unmarshal(bytes, &cachedFronts); err != nil { log.Errorf("Error reading cached masquerades: %v", err) return } - log.Debugf("Cache contained %d masquerades", len(cachedMasquerades)) + log.Debugf("Cache contained %d masquerades", len(cachedFronts)) now := time.Now() // update last succeeded status of masquerades based on cached values for _, m := range d.fronts { - for _, cm := range cachedMasquerades { + for _, cm := range cachedFronts { sameMasquerade := cm.ProviderID == m.getProviderID() && cm.Domain == m.getDomain() && cm.IpAddress == m.getIpAddress() cachedValueFresh := now.Sub(m.lastSucceeded()) < d.maxAllowedCachedAge if sameMasquerade && cachedValueFresh { @@ -98,6 +98,8 @@ func (d *fronted) updateCache(cacheFile string) { // parent directory does not exist log.Debugf("Parent directory of cache file does not exist: %v", parent) } + } else { + log.Debugf("Cache saved to disk") } } diff --git a/cache_test.go b/cache_test.go index ba18747..4c39035 100644 --- a/cache_test.go +++ b/cache_test.go @@ -26,6 +26,7 @@ func TestCaching(t *testing.T) { cloudsackID: NewProvider(nil, "", nil, nil, nil, nil, nil), } + log.Debug("Creating fronted") makeFronted := func() *fronted { f := &fronted{ fronts: make(sortedFronts, 0, 1000), @@ -36,6 +37,7 @@ func TestCaching(t *testing.T) { cacheClosed: make(chan interface{}), providers: providers, defaultProviderID: cloudsackID, + stopCh: make(chan interface{}, 10), } go f.maintainCache(cacheFile) return f @@ -47,9 +49,12 @@ func TestCaching(t *testing.T) { md := &front{Masquerade: Masquerade{Domain: "d", IpAddress: "4"}, LastSucceeded: now, ProviderID: "sadcloud"} // skipped f := makeFronted() + + log.Debug("Adding fronts") f.fronts = append(f.fronts, mb, mc, md) readCached := func() []*front { + log.Debug("Reading cached fronts") var result []*front b, err := os.ReadFile(cacheFile) require.NoError(t, err, "Unable to read cache file") @@ -60,14 +65,16 @@ func TestCaching(t *testing.T) { // Save the cache f.markCacheDirty() + time.Sleep(f.cacheSaveInterval * 2) f.Close() time.Sleep(50 * time.Millisecond) + log.Debug("Reopening fronted") // Reopen cache file and make sure right data was in there f = makeFronted() - f.prepopulateMasquerades(cacheFile) + f.prepopulateFronts(cacheFile) masquerades := readCached() require.Len(t, masquerades, 3, "Wrong number of masquerades read") for i, expected := range []*front{mb, mc, md} { diff --git a/fronted.go b/fronted.go index 6c7cad6..8f21c55 100644 --- a/fronted.go +++ b/fronted.go @@ -89,7 +89,7 @@ func NewFronted(cacheFile string, clientHello tls.ClientHelloID, defaultProvider providers: make(map[string]*Provider), clientHelloID: clientHello, workingFronts: newConnectingFronts(4000), - stopCh: make(chan interface{}), + stopCh: make(chan interface{}, 10), defaultProviderID: defaultProviderID, } diff --git a/fronted_test.go b/fronted_test.go index b6a21bb..e46c9fd 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -841,25 +841,26 @@ func TestFindWorkingMasquerades(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - d := &fronted{ + f := &fronted{ workingFronts: newConnectingFronts(10), + stopCh: make(chan interface{}, 10), } - d.providers = make(map[string]*Provider) - d.providers["testProviderId"] = NewProvider(nil, "", nil, nil, nil, nil, nil) - d.fronts = make(sortedFronts, len(tt.masquerades)) + f.providers = make(map[string]*Provider) + f.providers["testProviderId"] = NewProvider(nil, "", nil, nil, nil, nil, nil) + f.fronts = make(sortedFronts, len(tt.masquerades)) for i, m := range tt.masquerades { - d.fronts[i] = m + f.fronts[i] = m } - d.vetBatch(0, 10) + f.vetBatch(0, 10) tries := 0 - for d.workingFronts.size() < tt.expectedSuccessful && tries < 100 { + for f.workingFronts.size() < tt.expectedSuccessful && tries < 100 { time.Sleep(30 * time.Millisecond) tries++ } - assert.GreaterOrEqual(t, d.workingFronts.size(), tt.expectedSuccessful) + assert.GreaterOrEqual(t, f.workingFronts.size(), tt.expectedSuccessful) }) } } From 83d725772d3ee8e62b35572e597f5203a026a2f4 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Tue, 3 Dec 2024 13:43:36 -0700 Subject: [PATCH 18/25] Improve naming and make sure requests have a context --- cache.go | 12 ++++++------ connecting_fronts.go | 19 ++++++++++--------- connecting_fronts_test.go | 6 +++--- fronted.go | 30 +++++++++++++++++++----------- fronted_test.go | 11 +++++++---- 5 files changed, 45 insertions(+), 33 deletions(-) diff --git a/cache.go b/cache.go index 076bb5a..e4ff8c6 100644 --- a/cache.go +++ b/cache.go @@ -37,12 +37,12 @@ func (d *fronted) prepopulateFronts(cacheFile string) { now := time.Now() // update last succeeded status of masquerades based on cached values - for _, m := range d.fronts { - for _, cm := range cachedFronts { - sameMasquerade := cm.ProviderID == m.getProviderID() && cm.Domain == m.getDomain() && cm.IpAddress == m.getIpAddress() - cachedValueFresh := now.Sub(m.lastSucceeded()) < d.maxAllowedCachedAge - if sameMasquerade && cachedValueFresh { - m.setLastSucceeded(cm.LastSucceeded) + for _, f := range d.fronts { + for _, cf := range cachedFronts { + sameFront := cf.ProviderID == f.getProviderID() && cf.Domain == f.getDomain() && cf.IpAddress == f.getIpAddress() + cachedValueFresh := now.Sub(f.lastSucceeded()) < d.maxAllowedCachedAge + if sameFront && cachedValueFresh { + f.setLastSucceeded(cf.LastSucceeded) } } } diff --git a/connecting_fronts.go b/connecting_fronts.go index 8e2147f..39b178e 100644 --- a/connecting_fronts.go +++ b/connecting_fronts.go @@ -4,34 +4,35 @@ import ( "context" ) -type workingFronts interface { +type connectingFronts interface { onConnected(m Front) connectingFront(context.Context) (Front, error) size() int } -type connectingFronts struct { +type connecting struct { // Create a channel of fronts that are connecting. frontsCh chan Front } // Make sure that connectingFronts is a connectListener -var _ workingFronts = &connectingFronts{} +var _ connectingFronts = &connecting{} -// newConnectingFronts creates a new ConnectingFronts struct with an empty slice of Masquerade IPs and domains. -func newConnectingFronts(size int) *connectingFronts { - return &connectingFronts{ +// newConnectingFronts creates a new ConnectingFronts struct with a channel of fronts that have +// successfully connected. +func newConnectingFronts(size int) *connecting { + return &connecting{ // We overallocate the channel to avoid blocking. frontsCh: make(chan Front, size), } } // AddFront adds a new front to the list of fronts. -func (cf *connectingFronts) onConnected(m Front) { +func (cf *connecting) onConnected(m Front) { cf.frontsCh <- m } -func (cf *connectingFronts) connectingFront(ctx context.Context) (Front, error) { +func (cf *connecting) connectingFront(ctx context.Context) (Front, error) { for { select { case <-ctx.Done(): @@ -48,6 +49,6 @@ func (cf *connectingFronts) connectingFront(ctx context.Context) (Front, error) } } -func (cf *connectingFronts) size() int { +func (cf *connecting) size() int { return len(cf.frontsCh) } diff --git a/connecting_fronts_test.go b/connecting_fronts_test.go index 23e3fa9..4ef0fc6 100644 --- a/connecting_fronts_test.go +++ b/connecting_fronts_test.go @@ -7,19 +7,19 @@ import ( func TestConnectingFrontsSize(t *testing.T) { tests := []struct { name string - setup func() *connectingFronts + setup func() *connecting expected int }{ { name: "empty channel", - setup: func() *connectingFronts { + setup: func() *connecting { return newConnectingFronts(10) }, expected: 0, }, { name: "non-empty channel", - setup: func() *connectingFronts { + setup: func() *connecting { cf := newConnectingFronts(10) cf.onConnected(&mockFront{}) return cf diff --git a/fronted.go b/fronted.go index 8f21c55..aa40fac 100644 --- a/fronted.go +++ b/fronted.go @@ -2,6 +2,7 @@ package fronted import ( "bytes" + "context" "crypto/x509" "errors" "fmt" @@ -46,7 +47,7 @@ type fronted struct { defaultProviderID string providers map[string]*Provider clientHelloID tls.ClientHelloID - workingFronts workingFronts + connectingFronts connectingFronts providersMu sync.RWMutex frontsMu sync.RWMutex frontedMu sync.RWMutex @@ -65,12 +66,9 @@ type Fronted interface { Close() } -// NewFronted sets the domain fronts to use, the trusted root CAs, the fronting providers -// (such as Akamai, Cloudfront, etc), and the cache file for caching fronts to set up -// domain fronting. -// -// defaultProviderID is used when a front without a provider is -// encountered (eg in a cache file) +// NewFronted creates a new Fronted instance with the given cache file, clientHelloID, and defaultProviderID. +// At this point it does not have the actual IPs, domains, etc of the fronts to try. +// defaultProviderID is used when a front without a provider is encountered (eg in a cache file) func NewFronted(cacheFile string, clientHello tls.ClientHelloID, defaultProviderID string) (Fronted, error) { log.Debug("Creating new fronted") // Log method elapsed time @@ -88,7 +86,7 @@ func NewFronted(cacheFile string, clientHello tls.ClientHelloID, defaultProvider cacheClosed: make(chan interface{}), providers: make(map[string]*Provider), clientHelloID: clientHello, - workingFronts: newConnectingFronts(4000), + connectingFronts: newConnectingFronts(4000), stopCh: make(chan interface{}, 10), defaultProviderID: defaultProviderID, } @@ -143,6 +141,9 @@ func loadFronts(providers map[string]*Provider) sortedFronts { return fronts } +// UpdateConfig sets the domain fronts to use, the trusted root CAs, the fronting providers +// (such as Akamai, Cloudfront, etc), and the cache file for caching fronts to set up +// domain fronting. func (f *fronted) UpdateConfig(pool *x509.CertPool, providers map[string]*Provider) { // Make copies just to avoid any concurrency issues with access that may be happening on the // caller side. @@ -159,6 +160,7 @@ func (f *fronted) UpdateConfig(pool *x509.CertPool, providers map[string]*Provid f.certPool = pool + // The goroutine for finding working fronts runs forever, so only start it once. f.crawlOnce.Do(func() { go f.findWorkingFronts() }) @@ -221,7 +223,7 @@ func (f *fronted) findWorkingFronts() { // This is important, for example, when the user goes offline and all fronts start failing. // We want to just keep trying in that case so that we find working fronts as soon as they // come back online. - if f.workingFronts.size() < 4 { + if f.connectingFronts.size() < 4 { f.vetBatch(i, batchSize) i = index(i, batchSize, f.frontSize()) } else { @@ -260,7 +262,7 @@ func (f *fronted) vetBatch(start, batchSize int) { defer wg.Done() working := f.vetFront(m) if working { - f.workingFronts.onConnected(m) + f.connectingFronts.onConnected(m) } else { m.markFailed() } @@ -308,6 +310,12 @@ func (f *fronted) RoundTrip(req *http.Request) (*http.Response, error) { func (f *fronted) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, error) { op := ops.Begin("fronted_roundtrip") defer op.End() + // If the request has a context, use it. Otherwise, create a new one that has a timeout. + if req.Context() == nil { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + req = req.WithContext(ctx) + } isIdempotent := req.Method != http.MethodPost && req.Method != http.MethodPatch op.Set("is_idempotent", isIdempotent) @@ -345,7 +353,7 @@ func (f *fronted) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, log.Debugf("Retrying domain-fronted request, pass %d", i) } - m, err := f.workingFronts.connectingFront(req.Context()) + m, err := f.connectingFronts.connectingFront(req.Context()) if err != nil { // For some reason we have no working fronts. Sleep for a bit and try again. time.Sleep(1 * time.Second) diff --git a/fronted_test.go b/fronted_test.go index e46c9fd..a82ef65 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -29,10 +29,13 @@ func TestDirectDomainFrontingWithoutSNIConfig(t *testing.T) { require.NoError(t, err, "Unable to create temp dir") defer os.RemoveAll(dir) cacheFile := filepath.Join(dir, "cachefile.2") + + log.Debug("Testing direct domain fronting without SNI config") doTestDomainFronting(t, cacheFile, 10) time.Sleep(defaultCacheSaveInterval * 2) // Then try again, this time reusing the existing cacheFile but a corrupted version corruptMasquerades(cacheFile) + log.Debug("Testing direct domain fronting without SNI config again") doTestDomainFronting(t, cacheFile, 10) } @@ -842,8 +845,8 @@ func TestFindWorkingMasquerades(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { f := &fronted{ - workingFronts: newConnectingFronts(10), - stopCh: make(chan interface{}, 10), + connectingFronts: newConnectingFronts(10), + stopCh: make(chan interface{}, 10), } f.providers = make(map[string]*Provider) f.providers["testProviderId"] = NewProvider(nil, "", nil, nil, nil, nil, nil) @@ -855,12 +858,12 @@ func TestFindWorkingMasquerades(t *testing.T) { f.vetBatch(0, 10) tries := 0 - for f.workingFronts.size() < tt.expectedSuccessful && tries < 100 { + for f.connectingFronts.size() < tt.expectedSuccessful && tries < 100 { time.Sleep(30 * time.Millisecond) tries++ } - assert.GreaterOrEqual(t, f.workingFronts.size(), tt.expectedSuccessful) + assert.GreaterOrEqual(t, f.connectingFronts.size(), tt.expectedSuccessful) }) } } From 8be677503f807580b4ea7a043bd731623aa295fb Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Tue, 3 Dec 2024 14:44:21 -0700 Subject: [PATCH 19/25] Updated to return fronted instance --- test_support.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test_support.go b/test_support.go index 9d3a2b9..d45702a 100644 --- a/test_support.go +++ b/test_support.go @@ -17,11 +17,11 @@ var ( // ConfigureForTest configures fronted for testing using default masquerades and // certificate authorities. -func ConfigureForTest(t *testing.T) { - ConfigureCachingForTest(t, "") +func ConfigureForTest(t *testing.T) Fronted { + return ConfigureCachingForTest(t, "") } -func ConfigureCachingForTest(t *testing.T, cacheFile string) { +func ConfigureCachingForTest(t *testing.T, cacheFile string) Fronted { certs := trustedCACerts(t) p := testProviders() f, err := NewFronted(cacheFile, tls.HelloChrome_100, testProviderID) @@ -29,9 +29,10 @@ func ConfigureCachingForTest(t *testing.T, cacheFile string) { t.Fatalf("Unable to create fronted: %v", err) } f.UpdateConfig(certs, p) + return f } -func ConfigureHostAlaisesForTest(t *testing.T, hosts map[string]string) { +func ConfigureHostAlaisesForTest(t *testing.T, hosts map[string]string) Fronted { certs := trustedCACerts(t) p := testProvidersWithHosts(hosts) f, err := NewFronted("", tls.HelloChrome_100, testProviderID) @@ -39,6 +40,7 @@ func ConfigureHostAlaisesForTest(t *testing.T, hosts map[string]string) { t.Fatalf("Unable to create fronted: %v", err) } f.UpdateConfig(certs, p) + return f } func trustedCACerts(t *testing.T) *x509.CertPool { From d7ef15835878a5882d4c02118b8437eed5361a37 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Wed, 4 Dec 2024 09:32:15 -0700 Subject: [PATCH 20/25] downgraded some dependencies --- go.mod | 6 +++--- go.sum | 21 +++++++++++++-------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/go.mod b/go.mod index 4951da0..16df9fc 100644 --- a/go.mod +++ b/go.mod @@ -9,16 +9,16 @@ require ( github.com/getlantern/keyman v0.0.0-20200819205636-76fef27c39f1 github.com/getlantern/netx v0.0.0-20240814210628-0984f52e2d18 github.com/getlantern/ops v0.0.0-20231025133620-f368ab734534 - github.com/getlantern/tlsdialer/v3 v3.0.4 + github.com/getlantern/tlsdialer/v3 v3.0.3 github.com/getlantern/waitforserver v1.0.1 - github.com/refraction-networking/utls v1.6.7 + github.com/refraction-networking/utls v1.3.3 github.com/stretchr/testify v1.8.4 ) require ( github.com/andybalholm/brotli v1.0.6 // indirect - github.com/cloudflare/circl v1.3.7 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gaukas/godicttls v0.0.3 // indirect github.com/getlantern/byteexec v0.0.0-20170405023437-4cfb26ec74f4 // indirect github.com/getlantern/context v0.0.0-20220418194847-3d5e7a086201 // indirect github.com/getlantern/elevate v0.0.0-20200430163644-2881a121236d // indirect diff --git a/go.sum b/go.sum index 53e51c6..444b727 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,17 @@ github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= -github.com/cloudflare/circl v1.3.7 h1:qlCDlTPz2n9fu58M0Nh1J/JzcFpfgkFHHX3O35r5vcU= -github.com/cloudflare/circl v1.3.7/go.mod h1:sRTcRWXGLrKw6yIGJ+l7amYJFfAXbZG0kBSc8r4zxgA= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gaukas/godicttls v0.0.3 h1:YNDIf0d9adcxOijiLrEzpfZGAkNwLRzPaG6OjU7EITk= +github.com/gaukas/godicttls v0.0.3/go.mod h1:l6EenT4TLWgTdwslVb4sEMOCf7Bv0JAK67deKr9/NCI= github.com/getlantern/byteexec v0.0.0-20170405023437-4cfb26ec74f4 h1:Nqmy8i81dzokjNHpyOg24gnQBeGRF7D51m8HmBRNn0Y= github.com/getlantern/byteexec v0.0.0-20170405023437-4cfb26ec74f4/go.mod h1:4WCQkaCIwta0KlF9bQZA1jYqp8bzIS2PeCqjnef8nZ8= github.com/getlantern/context v0.0.0-20190109183933-c447772a6520/go.mod h1:L+mq6/vvYHKjCX2oez0CgEAJmbq1fbb/oNJIWQkBybY= github.com/getlantern/context v0.0.0-20220418194847-3d5e7a086201 h1:oEZYEpZo28Wdx+5FZo4aU7JFXu0WG/4wJWese5reQSA= github.com/getlantern/context v0.0.0-20220418194847-3d5e7a086201/go.mod h1:Y9WZUHEb+mpra02CbQ/QczLUe6f0Dezxaw5DCJlJQGo= +github.com/getlantern/elevate v0.0.0-20180207094634-c2e2e4901072/go.mod h1:T4VB2POK13lsPLFV98WJQrL7gAXYD9TyJxBU2P8c8p4= github.com/getlantern/elevate v0.0.0-20200430163644-2881a121236d h1:o0EHYAq7u9/umRZE0PpJ00GYQvxPxVUvtoDkUca2guQ= github.com/getlantern/elevate v0.0.0-20200430163644-2881a121236d/go.mod h1:+nYKXAqGigcDHB3as7WikMzg3eIHzGUbLnBKOCBJeUE= github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7/go.mod h1:l+xpFBrCtDLpK9qNjxs+cHU6+BAdlBaxHqikB6Lku3A= @@ -31,6 +32,7 @@ github.com/getlantern/hex v0.0.0-20220104173244-ad7e4b9194dc/go.mod h1:D9RWpXy/E github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55/go.mod h1:6mmzY2kW1TOOrVy+r41Za2MxXM+hhqTtY3oBKd2AgFA= github.com/getlantern/hidden v0.0.0-20220104173330-f221c5a24770 h1:cSrD9ryDfTV2yaur9Qk3rHYD414j3Q1rl7+L0AylxrE= github.com/getlantern/hidden v0.0.0-20220104173330-f221c5a24770/go.mod h1:GOQsoDnEHl6ZmNIL+5uVo+JWRFWozMEp18Izcb++H+A= +github.com/getlantern/iptool v0.0.0-20210721034953-519bf8ce0147/go.mod h1:hfspzdRcvJ130tpTPL53/L92gG0pFtvQ6ln35ppwhHE= github.com/getlantern/iptool v0.0.0-20230112135223-c00e863b2696 h1:D7wbL2Ww6QN5SblEDMiQcFulqz2jgcvawKaNBTzHLvQ= github.com/getlantern/iptool v0.0.0-20230112135223-c00e863b2696/go.mod h1:hfspzdRcvJ130tpTPL53/L92gG0pFtvQ6ln35ppwhHE= github.com/getlantern/keyman v0.0.0-20200819205636-76fef27c39f1 h1:8qNXKWQBgmMfaXXTNfYs1D6i42etSjvwxfCSlmvHHr8= @@ -39,16 +41,16 @@ github.com/getlantern/mockconn v0.0.0-20200818071412-cb30d065a848 h1:2MhMMVBTnaH github.com/getlantern/mockconn v0.0.0-20200818071412-cb30d065a848/go.mod h1:+F5GJ7qGpQ03DBtcOEyQpM30ix4BLswdaojecFtsdy8= github.com/getlantern/mtime v0.0.0-20200417132445-23682092d1f7 h1:03J6Cb42EG06lHgpOFGm5BOax4qFqlSbSeKO2RGrj2g= github.com/getlantern/mtime v0.0.0-20200417132445-23682092d1f7/go.mod h1:GfzwugvtH7YcmNIrHHizeyImsgEdyL88YkdnK28B14c= +github.com/getlantern/netx v0.0.0-20211206143627-7ccfeb739cbd/go.mod h1:WEXF4pfIfnHBUAKwLa4DW7kcEINtG6wjUkbL2btwXZQ= github.com/getlantern/netx v0.0.0-20240814210628-0984f52e2d18 h1:I5xFq/HkvWGUPysqC8LQH9oks1WaM9BpcB+fjmvMRic= github.com/getlantern/netx v0.0.0-20240814210628-0984f52e2d18/go.mod h1:4WkWbHy7Mqri9lxpLFN6dOU5nUy3kyNCpHxSRQZocv0= github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f/go.mod h1:D5ao98qkA6pxftxoqzibIBBrLSUli+kYnJqrgBf9cIA= +github.com/getlantern/ops v0.0.0-20200403153110-8476b16edcd6/go.mod h1:D5ao98qkA6pxftxoqzibIBBrLSUli+kYnJqrgBf9cIA= github.com/getlantern/ops v0.0.0-20220713155959-1315d978fff7/go.mod h1:D5ao98qkA6pxftxoqzibIBBrLSUli+kYnJqrgBf9cIA= github.com/getlantern/ops v0.0.0-20231025133620-f368ab734534 h1:3BwvWj0JZzFEvNNiMhCu4bf60nqcIuQpTYb00Ezm1ag= github.com/getlantern/ops v0.0.0-20231025133620-f368ab734534/go.mod h1:ZsLfOY6gKQOTyEcPYNA9ws5/XHZQFroxqCOhHjGcs9Y= -github.com/getlantern/tlsdialer/v3 v3.0.4 h1:j9GHqtD2+cwGP/q+Rvr/wC43nPrRPk6YfEmWfZJ4p1I= -github.com/getlantern/tlsdialer/v3 v3.0.4/go.mod h1:G0rWRzTX9WuQ0r31c/Zg9sfwTrMs82kyQCswlhwv/Us= -github.com/getlantern/tlsresumption v0.0.0-20241203054031-f3e4eec291ad h1:RsOMhwKzMD0M7FrsqZ0fKwTblr6pNCYrzmtnbnVr3Cg= -github.com/getlantern/tlsresumption v0.0.0-20241203054031-f3e4eec291ad/go.mod h1:lAlPJK1Y5nFlydkty/imyPFpLCi5V+hzQHXOqeoeXyk= +github.com/getlantern/tlsdialer/v3 v3.0.3 h1:OXzzAqO8YojBOu2Kk8wquX2zbFmgJjji41RpaT6knLg= +github.com/getlantern/tlsdialer/v3 v3.0.3/go.mod h1:hwA0X81pnrgx7GEwddaGWSxqr6eLBm7A0rrUMK2J7KY= github.com/getlantern/waitforserver v1.0.1 h1:xBjqJ3GgEk9JMWnDgRSiNHXINi6Lv2tGNjJR0hCkHFY= github.com/getlantern/waitforserver v1.0.1/go.mod h1:K1oSA8lNKgQ9iC00OFpMfMNm4UMrsxoGCdHf0NT9LGs= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -76,8 +78,9 @@ github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c/go.mod h1:X07ZCGwU github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/refraction-networking/utls v1.6.7 h1:zVJ7sP1dJx/WtVuITug3qYUq034cDq9B2MR1K67ULZM= -github.com/refraction-networking/utls v1.6.7/go.mod h1:BC3O4vQzye5hqpmDTWUqi4P5DDhzJfkV1tdqtawQIH0= +github.com/refraction-networking/utls v0.0.0-20190909200633-43c36d3c1f57/go.mod h1:tz9gX959MEFfFN5whTIocCLUG57WiILqtdVxI8c6Wj0= +github.com/refraction-networking/utls v1.3.3 h1:f/TBLX7KBciRyFH3bwupp+CE4fzoYKCirhdRcC490sw= +github.com/refraction-networking/utls v1.3.3/go.mod h1:DlecWW1LMlMJu+9qpzzQqdHDT/C2LAe03EdpLUz/RL8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= @@ -107,6 +110,7 @@ go.uber.org/zap v1.19.1/go.mod h1:j3DNczoxDZroyBnOT1L/Q79cfUMGZxlv/9dzN7SM1rI= go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo= go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= @@ -120,6 +124,7 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190912141932-bc967efca4b8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= From b9307b77633106f430fcafdd2458830950eb3ced Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Wed, 4 Dec 2024 11:17:05 -0700 Subject: [PATCH 21/25] Improved cert pool handling --- fronted.go | 66 +++++++++++++++++++++++++----------------------------- 1 file changed, 31 insertions(+), 35 deletions(-) diff --git a/fronted.go b/fronted.go index aa40fac..bc0dd1f 100644 --- a/fronted.go +++ b/fronted.go @@ -13,6 +13,7 @@ import ( "net/url" "strings" "sync" + "sync/atomic" "syscall" "time" @@ -36,7 +37,7 @@ var ( // fronted identifies working IP address/domain pairings for domain fronting and is // an implementation of http.RoundTripper for the convenience of callers. type fronted struct { - certPool *x509.CertPool + certPool atomic.Value fronts sortedFronts maxAllowedCachedAge time.Duration maxCacheSize int @@ -77,7 +78,7 @@ func NewFronted(cacheFile string, clientHello tls.ClientHelloID, defaultProvider }(time.Now()) f := &fronted{ - certPool: nil, + certPool: atomic.Value{}, fronts: make(sortedFronts, 0), maxAllowedCachedAge: defaultMaxAllowedCachedAge, maxCacheSize: defaultMaxCacheSize, @@ -98,6 +99,31 @@ func NewFronted(cacheFile string, clientHello tls.ClientHelloID, defaultProvider return f, nil } +// UpdateConfig sets the domain fronts to use, the trusted root CAs, the fronting providers +// (such as Akamai, Cloudfront, etc), and the cache file for caching fronts to set up +// domain fronting. +func (f *fronted) UpdateConfig(pool *x509.CertPool, providers map[string]*Provider) { + // Make copies just to avoid any concurrency issues with access that may be happening on the + // caller side. + log.Debug("Updating fronted configuration") + if len(providers) == 0 { + log.Errorf("No providers configured") + return + } + providersCopy := copyProviders(providers) + f.frontedMu.Lock() + defer f.frontedMu.Unlock() + f.addProviders(providersCopy) + f.addFronts(loadFronts(providersCopy)) + + f.certPool.Store(pool) + + // The goroutine for finding working fronts runs forever, so only start it once. + f.crawlOnce.Do(func() { + go f.findWorkingFronts() + }) +} + func copyProviders(providers map[string]*Provider) map[string]*Provider { providersCopy := make(map[string]*Provider, len(providers)) for key, p := range providers { @@ -141,31 +167,6 @@ func loadFronts(providers map[string]*Provider) sortedFronts { return fronts } -// UpdateConfig sets the domain fronts to use, the trusted root CAs, the fronting providers -// (such as Akamai, Cloudfront, etc), and the cache file for caching fronts to set up -// domain fronting. -func (f *fronted) UpdateConfig(pool *x509.CertPool, providers map[string]*Provider) { - // Make copies just to avoid any concurrency issues with access that may be happening on the - // caller side. - log.Debug("Updating fronted configuration") - if len(providers) == 0 { - log.Errorf("No providers configured") - return - } - providersCopy := copyProviders(providers) - f.frontedMu.Lock() - defer f.frontedMu.Unlock() - f.addProviders(providersCopy) - f.addFronts(loadFronts(providersCopy)) - - f.certPool = pool - - // The goroutine for finding working fronts runs forever, so only start it once. - f.crawlOnce.Do(func() { - go f.findWorkingFronts() - }) -} - func (f *fronted) addProviders(providers map[string]*Provider) { // Add new providers to the existing providers map, overwriting any existing ones. f.providersMu.Lock() @@ -194,10 +195,11 @@ func (f *fronted) providerFor(m Front) *Provider { // This is used in genconfig. func Vet(m *Masquerade, pool *x509.CertPool, testURL string) bool { d := &fronted{ - certPool: pool, + certPool: atomic.Value{}, maxAllowedCachedAge: defaultMaxAllowedCachedAge, maxCacheSize: defaultMaxCacheSize, } + d.certPool.Store(pool) masq := &front{Masquerade: *m} conn, _, err := d.doDial(masq) if err != nil { @@ -465,7 +467,7 @@ func (f *fronted) doDial(m Front) (net.Conn, bool, error) { var conn net.Conn var err error retriable := false - conn, err = m.dial(f.getCertPool(), f.clientHelloID) + conn, err = m.dial(f.certPool.Load().(*x509.CertPool), f.clientHelloID) if err != nil { if !isNetworkUnreachable(err) { op.FailIf(err) @@ -485,12 +487,6 @@ func (f *fronted) doDial(m Front) (net.Conn, bool, error) { return conn, retriable, err } -func (f *fronted) getCertPool() *x509.CertPool { - f.frontedMu.RLock() - defer f.frontedMu.RUnlock() - return f.certPool -} - func isNetworkUnreachable(err error) bool { var opErr *net.OpError if errors.As(err, &opErr) { From 67fc744f6bee094810952690201bff52dd9b7378 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Fri, 6 Dec 2024 09:16:32 -0700 Subject: [PATCH 22/25] Added comments --- connecting_fronts.go | 3 ++- fronted.go | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/connecting_fronts.go b/connecting_fronts.go index 39b178e..e0456fe 100644 --- a/connecting_fronts.go +++ b/connecting_fronts.go @@ -27,7 +27,7 @@ func newConnectingFronts(size int) *connecting { } } -// AddFront adds a new front to the list of fronts. +// onConnected adds a working front to the channel of working fronts. func (cf *connecting) onConnected(m Front) { cf.frontsCh <- m } @@ -35,6 +35,7 @@ func (cf *connecting) onConnected(m Front) { func (cf *connecting) connectingFront(ctx context.Context) (Front, error) { for { select { + // This is typically the context of the HTTP request. If the context is done, return an error. case <-ctx.Done(): return nil, ctx.Err() case m := <-cf.frontsCh: diff --git a/fronted.go b/fronted.go index bc0dd1f..d3574d5 100644 --- a/fronted.go +++ b/fronted.go @@ -144,6 +144,7 @@ func loadFronts(providers map[string]*Provider) sortedFronts { fronts := make(sortedFronts, size) + // Note that map iteration order is random, so the order of the providers is automatically randomized. index := 0 for key, p := range providers { arr := p.Masquerades From 9a7741a6000a3ea98c0ba5a18b5c40d340cdaead Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Fri, 6 Dec 2024 10:43:09 -0700 Subject: [PATCH 23/25] Use worker pool instead of waitgroup --- fronted.go | 76 ++++++++++++++++++++++++++++--------------------- fronted_test.go | 62 +++++++++++++--------------------------- go.mod | 1 + go.sum | 2 ++ 4 files changed, 67 insertions(+), 74 deletions(-) diff --git a/fronted.go b/fronted.go index d3574d5..7cb6ed0 100644 --- a/fronted.go +++ b/fronted.go @@ -21,6 +21,8 @@ import ( "github.com/getlantern/golog" "github.com/getlantern/ops" + + "github.com/alitto/pond/v2" ) const ( @@ -210,38 +212,66 @@ func Vet(m *Masquerade, pool *x509.CertPool, testURL string) bool { return masq.postCheck(conn, testURL) } -// findWorkingFronts finds working domain fronts by vetting them in batches and in -// parallel. Speed is of the essence here, as without working fronts, users will +// findWorkingFronts finds working domain fronts by testing them using a worker pool. Speed +// is of the essence here, as without working fronts, users will // be unable to fetch proxy configurations, particularly in the case of a first time // user who does not have proxies cached on disk. func (f *fronted) findWorkingFronts() { - // vet fronts in batches - const batchSize int = 40 - // Keep looping through all fronts making sure we have working ones. - i := 0 for { - // Continually loop through the fronts in batches until we have 4 working ones, - // always looping around to the beginning if we reach the end. + // Continually loop through the fronts until we have 4 working ones. // This is important, for example, when the user goes offline and all fronts start failing. // We want to just keep trying in that case so that we find working fronts as soon as they // come back online. - if f.connectingFronts.size() < 4 { - f.vetBatch(i, batchSize) - i = index(i, batchSize, f.frontSize()) + if !f.hasEnoughWorkingFronts() { + // Note that trying all fronts takes awhile, as it only completes when we either + // have enough working fronts, or we've tried all of them. + log.Debug("findWorkingFronts::Trying all fronts") + f.tryAllFronts() + log.Debug("findWorkingFronts::Tried all fronts") } else { + log.Debug("findWorkingFronts::Enough working fronts...sleeping") select { case <-f.stopCh: - log.Debug("Stopping parallel dialing") + log.Debug("findWorkingFronts::Stopping parallel dialing") return case <-time.After(time.Duration(rand.IntN(12000)) * time.Millisecond): + // Run again after a random time between 0 and 12 seconds } } } } -func index(i, batchSize, size int) int { - return (i + batchSize) % size +func (f *fronted) tryAllFronts() { + // Vet fronts using a worker pool of 40 goroutines. + pool := pond.NewPool(40) + + // Submit all fronts to the worker pool. + for i := 0; i < f.frontSize(); i++ { + i := i + m := f.frontAt(i) + pool.Submit(func() { + log.Debugf("Running task #%d with front %v", i, m.getIpAddress()) + if f.hasEnoughWorkingFronts() { + // We have enough working fronts, so no need to continue. + log.Debug("Enough working fronts...ignoring task") + return + } + working := f.vetFront(m) + if working { + f.connectingFronts.onConnected(m) + } else { + m.markFailed() + } + }) + } + + // Stop the pool and wait for all submitted tasks to complete + pool.StopAndWait() +} + +func (f *fronted) hasEnoughWorkingFronts() bool { + return f.connectingFronts.size() >= 4 } func (f *fronted) frontSize() int { @@ -256,24 +286,6 @@ func (f *fronted) frontAt(i int) Front { return f.fronts[i] } -func (f *fronted) vetBatch(start, batchSize int) { - log.Debugf("Vetting masquerade batch %d-%d", start, start+batchSize) - var wg sync.WaitGroup - for i := start; i < start+batchSize && i < f.frontSize(); i++ { - wg.Add(1) - go func(m Front) { - defer wg.Done() - working := f.vetFront(m) - if working { - f.connectingFronts.onConnected(m) - } else { - m.markFailed() - } - }(f.frontAt(i)) - } - wg.Wait() -} - func (f *fronted) vetFront(m Front) bool { conn, masqueradeGood, err := f.dialFront(m) if err != nil { diff --git a/fronted_test.go b/fronted_test.go index a82ef65..e1a5b21 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -799,33 +799,33 @@ func TestFindWorkingMasquerades(t *testing.T) { { name: "All successful", masquerades: []*mockFront{ - newMockMasquerade("domain1.com", "1.1.1.1", 0, true), - newMockMasquerade("domain2.com", "2.2.2.2", 0, true), - newMockMasquerade("domain3.com", "3.3.3.3", 0, true), - newMockMasquerade("domain4.com", "4.4.4.4", 0, true), - newMockMasquerade("domain1.com", "1.1.1.1", 0, true), - newMockMasquerade("domain1.com", "1.1.1.1", 0, true), + newMockFront("domain1.com", "1.1.1.1", 0, true), + newMockFront("domain2.com", "2.2.2.2", 0, true), + newMockFront("domain3.com", "3.3.3.3", 0, true), + newMockFront("domain4.com", "4.4.4.4", 0, true), + newMockFront("domain1.com", "1.1.1.1", 0, true), + newMockFront("domain1.com", "1.1.1.1", 0, true), }, expectedSuccessful: 4, }, { name: "Some successful", masquerades: []*mockFront{ - newMockMasquerade("domain1.com", "1.1.1.1", 0, true), - newMockMasquerade("domain2.com", "2.2.2.2", 0, false), - newMockMasquerade("domain3.com", "3.3.3.3", 0, true), - newMockMasquerade("domain4.com", "4.4.4.4", 0, false), - newMockMasquerade("domain1.com", "1.1.1.1", 0, true), + newMockFront("domain1.com", "1.1.1.1", 0, true), + newMockFront("domain2.com", "2.2.2.2", 0, false), + newMockFront("domain3.com", "3.3.3.3", 0, true), + newMockFront("domain4.com", "4.4.4.4", 0, false), + newMockFront("domain1.com", "1.1.1.1", 0, true), }, expectedSuccessful: 2, }, { name: "None successful", masquerades: []*mockFront{ - newMockMasquerade("domain1.com", "1.1.1.1", 0, false), - newMockMasquerade("domain2.com", "2.2.2.2", 0, false), - newMockMasquerade("domain3.com", "3.3.3.3", 0, false), - newMockMasquerade("domain4.com", "4.4.4.4", 0, false), + newMockFront("domain1.com", "1.1.1.1", 0, false), + newMockFront("domain2.com", "2.2.2.2", 0, false), + newMockFront("domain3.com", "3.3.3.3", 0, false), + newMockFront("domain4.com", "4.4.4.4", 0, false), }, expectedSuccessful: 0, }, @@ -834,7 +834,7 @@ func TestFindWorkingMasquerades(t *testing.T) { masquerades: func() []*mockFront { var masquerades []*mockFront for i := 0; i < 50; i++ { - masquerades = append(masquerades, newMockMasquerade(fmt.Sprintf("domain%d.com", i), fmt.Sprintf("1.1.1.%d", i), 0, i%2 == 0)) + masquerades = append(masquerades, newMockFront(fmt.Sprintf("domain%d.com", i), fmt.Sprintf("1.1.1.%d", i), 0, i%2 == 0)) } return masquerades }(), @@ -855,7 +855,7 @@ func TestFindWorkingMasquerades(t *testing.T) { f.fronts[i] = m } - f.vetBatch(0, 10) + f.tryAllFronts() tries := 0 for f.connectingFronts.size() < tt.expectedSuccessful && tries < 100 { @@ -900,37 +900,15 @@ func TestLoadFronts(t *testing.T) { } } -func TestIndex(t *testing.T) { - tests := []struct { - i, batchSize, size int - expected int - }{ - {i: 0, batchSize: 10, size: 100, expected: 10}, - {i: 5, batchSize: 10, size: 100, expected: 15}, - {i: 95, batchSize: 10, size: 100, expected: 5}, - {i: 99, batchSize: 10, size: 100, expected: 9}, - {i: 0, batchSize: 5, size: 20, expected: 5}, - {i: 15, batchSize: 5, size: 20, expected: 0}, - {i: 18, batchSize: 5, size: 20, expected: 3}, - } - - for _, test := range tests { - t.Run(fmt.Sprintf("i=%d,batchSize=%d,size=%d", test.i, test.batchSize, test.size), func(t *testing.T) { - result := index(test.i, test.batchSize, test.size) - assert.Equal(t, test.expected, result) - }) - } -} - // Generate a mock of a MasqueradeInterface with a Dial method that can optionally // return an error after a specified number of milliseconds. -func newMockMasquerade(domain string, ipAddress string, timeout time.Duration, passesCheck bool) *mockFront { - return newMockMasqueradeWithLastSuccess(domain, ipAddress, timeout, passesCheck, time.Time{}) +func newMockFront(domain string, ipAddress string, timeout time.Duration, passesCheck bool) *mockFront { + return newMockFrontWithLastSuccess(domain, ipAddress, timeout, passesCheck, time.Time{}) } // Generate a mock of a MasqueradeInterface with a Dial method that can optionally // return an error after a specified number of milliseconds. -func newMockMasqueradeWithLastSuccess(domain string, ipAddress string, timeout time.Duration, passesCheck bool, lastSucceededTime time.Time) *mockFront { +func newMockFrontWithLastSuccess(domain string, ipAddress string, timeout time.Duration, passesCheck bool, lastSucceededTime time.Time) *mockFront { return &mockFront{ Domain: domain, IpAddress: ipAddress, diff --git a/go.mod b/go.mod index 16df9fc..c4523a9 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.22.6 toolchain go1.23.3 require ( + github.com/alitto/pond/v2 v2.1.5 github.com/getlantern/golog v0.0.0-20230503153817-8e72de7e0a65 github.com/getlantern/keyman v0.0.0-20200819205636-76fef27c39f1 github.com/getlantern/netx v0.0.0-20240814210628-0984f52e2d18 diff --git a/go.sum b/go.sum index 444b727..93f5e6b 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/alitto/pond/v2 v2.1.5 h1:2pp/KAPcb02NSpHsjjnxnrTDzogMLsq+vFf/L0DB84A= +github.com/alitto/pond/v2 v2.1.5/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE= github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= From 2954726ca423469fd493882c9a2d8b727ba1f77d Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Fri, 6 Dec 2024 13:24:33 -0700 Subject: [PATCH 24/25] Use a random range to avoid quick checks --- fronted.go | 6 +++++- fronted_test.go | 22 ++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/fronted.go b/fronted.go index 7cb6ed0..d717eeb 100644 --- a/fronted.go +++ b/fronted.go @@ -235,7 +235,7 @@ func (f *fronted) findWorkingFronts() { case <-f.stopCh: log.Debug("findWorkingFronts::Stopping parallel dialing") return - case <-time.After(time.Duration(rand.IntN(12000)) * time.Millisecond): + case <-time.After(time.Duration(randRange(6, 12)) * time.Second): // Run again after a random time between 0 and 12 seconds } } @@ -623,3 +623,7 @@ func cloneRequestWith(req *http.Request, frontedHost string, body io.ReadCloser) } return r, nil } + +func randRange(min, max int) int { + return rand.IntN(max-min) + min +} diff --git a/fronted_test.go b/fronted_test.go index e1a5b21..2ac5fb0 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -900,6 +900,28 @@ func TestLoadFronts(t *testing.T) { } } +func TestRandRange(t *testing.T) { + tests := []struct { + min, max int + }{ + {1, 10}, + {5, 15}, + {0, 100}, + {-10, 10}, + {50, 60}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("min=%d,max=%d", tt.min, tt.max), func(t *testing.T) { + for i := 0; i < 100; i++ { + result := randRange(tt.min, tt.max) + assert.GreaterOrEqual(t, result, tt.min) + assert.Less(t, result, tt.max) + } + }) + } +} + // Generate a mock of a MasqueradeInterface with a Dial method that can optionally // return an error after a specified number of milliseconds. func newMockFront(domain string, ipAddress string, timeout time.Duration, passesCheck bool) *mockFront { From 7bc4f00b82815f467eb6d3de85df7d263f20e09b Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Fri, 6 Dec 2024 13:51:25 -0700 Subject: [PATCH 25/25] Do not keep vetting if stopped --- cache.go | 7 ------- fronted.go | 21 ++++++++++++++++++--- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/cache.go b/cache.go index e4ff8c6..54290a4 100644 --- a/cache.go +++ b/cache.go @@ -102,10 +102,3 @@ func (d *fronted) updateCache(cacheFile string) { log.Debugf("Cache saved to disk") } } - -func (d *fronted) Close() { - d.closeCacheOnce.Do(func() { - close(d.cacheClosed) - }) - d.stopCh <- nil -} diff --git a/fronted.go b/fronted.go index d717eeb..f2d6dff 100644 --- a/fronted.go +++ b/fronted.go @@ -56,6 +56,7 @@ type fronted struct { frontedMu sync.RWMutex stopCh chan interface{} crawlOnce sync.Once + stopped atomic.Bool } // Interface for sending HTTP traffic over domain fronting. @@ -248,13 +249,15 @@ func (f *fronted) tryAllFronts() { // Submit all fronts to the worker pool. for i := 0; i < f.frontSize(); i++ { - i := i m := f.frontAt(i) pool.Submit(func() { - log.Debugf("Running task #%d with front %v", i, m.getIpAddress()) + //log.Debugf("Running task #%d with front %v", i, m.getIpAddress()) + if f.isStopped() { + return + } if f.hasEnoughWorkingFronts() { // We have enough working fronts, so no need to continue. - log.Debug("Enough working fronts...ignoring task") + //log.Debug("Enough working fronts...ignoring task") return } working := f.vetFront(m) @@ -627,3 +630,15 @@ func cloneRequestWith(req *http.Request, frontedHost string, body io.ReadCloser) func randRange(min, max int) int { return rand.IntN(max-min) + min } + +func (f *fronted) Close() { + f.stopped.Store(true) + f.closeCacheOnce.Do(func() { + close(f.cacheClosed) + }) + f.stopCh <- nil +} + +func (f *fronted) isStopped() bool { + return f.stopped.Load() +}