From 9b577b9a1d7456a4dbb78c7a3e6962c1bb94ee04 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Tue, 3 Dec 2024 11:38:16 -0700 Subject: [PATCH] Change to only modify global config details via update call --- fronted.go | 44 +++++++++++++++++++++++++------------------- fronted_test.go | 28 +++++++++++++++++++++------- test_support.go | 12 ++++++++++-- 3 files changed, 56 insertions(+), 28 deletions(-) diff --git a/fronted.go b/fronted.go index f8becc5..6c7cad6 100644 --- a/fronted.go +++ b/fronted.go @@ -51,6 +51,7 @@ type fronted struct { frontsMu sync.RWMutex frontedMu sync.RWMutex stopCh chan interface{} + crawlOnce sync.Once } // Interface for sending HTTP traffic over domain fronting. @@ -58,7 +59,7 @@ type Fronted interface { http.RoundTripper // UpdateConfig updates the set of domain fronts to try. - UpdateConfig(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string) + UpdateConfig(pool *x509.CertPool, providers map[string]*Provider) // Close closes any resources, such as goroutines that are testing fronts. Close() @@ -70,40 +71,31 @@ type Fronted interface { // // defaultProviderID is used when a front without a provider is // encountered (eg in a cache file) -func NewFronted(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string, cacheFile string, - clientHello tls.ClientHelloID) (Fronted, error) { +func NewFronted(cacheFile string, clientHello tls.ClientHelloID, defaultProviderID string) (Fronted, error) { log.Debug("Creating new fronted") // Log method elapsed time defer func(start time.Time) { log.Debugf("Creating a new fronted took %v", time.Since(start)) }(time.Now()) - if len(providers) == 0 { - return nil, log.Errorf("No providers configured") - } - - providersCopy := copyProviders(providers) - fronts := loadFronts(providersCopy) - f := &fronted{ - certPool: pool, - fronts: fronts, + certPool: nil, + fronts: make(sortedFronts, 0), maxAllowedCachedAge: defaultMaxAllowedCachedAge, maxCacheSize: defaultMaxCacheSize, cacheSaveInterval: defaultCacheSaveInterval, cacheDirty: make(chan interface{}, 1), cacheClosed: make(chan interface{}), - defaultProviderID: defaultProviderID, - providers: providersCopy, + providers: make(map[string]*Provider), clientHelloID: clientHello, - workingFronts: newConnectingFronts(len(fronts)), + workingFronts: newConnectingFronts(4000), stopCh: make(chan interface{}), + defaultProviderID: defaultProviderID, } if cacheFile != "" { f.initCaching(cacheFile) } - go f.findWorkingFronts() return f, nil } @@ -151,17 +143,25 @@ func loadFronts(providers map[string]*Provider) sortedFronts { return fronts } -func (f *fronted) UpdateConfig(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string) { +func (f *fronted) UpdateConfig(pool *x509.CertPool, providers map[string]*Provider) { // Make copies just to avoid any concurrency issues with access that may be happening on the // caller side. log.Debug("Updating fronted configuration") + if len(providers) == 0 { + log.Errorf("No providers configured") + return + } providersCopy := copyProviders(providers) f.frontedMu.Lock() defer f.frontedMu.Unlock() f.addProviders(providersCopy) f.addFronts(loadFronts(providersCopy)) - f.defaultProviderID = defaultProviderID + f.certPool = pool + + f.crawlOnce.Do(func() { + go f.findWorkingFronts() + }) } func (f *fronted) addProviders(providers map[string]*Provider) { @@ -457,7 +457,7 @@ func (f *fronted) doDial(m Front) (net.Conn, bool, error) { var conn net.Conn var err error retriable := false - conn, err = m.dial(f.certPool, f.clientHelloID) + conn, err = m.dial(f.getCertPool(), f.clientHelloID) if err != nil { if !isNetworkUnreachable(err) { op.FailIf(err) @@ -477,6 +477,12 @@ func (f *fronted) doDial(m Front) (net.Conn, bool, error) { return conn, retriable, err } +func (f *fronted) getCertPool() *x509.CertPool { + f.frontedMu.RLock() + defer f.frontedMu.RUnlock() + return f.certPool +} + func isNetworkUnreachable(err error) bool { var opErr *net.OpError if errors.As(err, &opErr) { diff --git a/fronted_test.go b/fronted_test.go index a504ca2..b6a21bb 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -54,7 +54,9 @@ func TestDirectDomainFrontingWithSNIConfig(t *testing.T) { UseArbitrarySNIs: true, ArbitrarySNIs: []string{"mercadopago.com", "amazon.com.br", "facebook.com", "google.com", "twitter.com", "youtube.com", "instagram.com", "linkedin.com", "whatsapp.com", "netflix.com", "microsoft.com", "yahoo.com", "bing.com", "wikipedia.org", "github.com"}, }) - transport, err := NewFronted(certs, p, "akamai", cacheFile, tls.HelloChrome_100) + transport, err := NewFronted(cacheFile, tls.HelloChrome_100, "akamai") + require.NoError(t, err) + transport.UpdateConfig(certs, p) client := &http.Client{ Transport: transport, @@ -80,8 +82,9 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE } certs := trustedCACerts(t) p := testProvidersWithHosts(hosts) - transport, err := NewFronted(certs, p, testProviderID, cacheFile, tls.HelloChrome_100) + transport, err := NewFronted(cacheFile, tls.HelloChrome_100, testProviderID) require.NoError(t, err) + transport.UpdateConfig(certs, p) client := &http.Client{ Transport: transport, @@ -89,8 +92,9 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE } require.True(t, doCheck(client, http.MethodPost, http.StatusAccepted, pingURL)) - transport, err = NewFronted(certs, p, testProviderID, cacheFile, tls.HelloChrome_100) + transport, err = NewFronted(cacheFile, tls.HelloChrome_100, testProviderID) require.NoError(t, err) + transport.UpdateConfig(certs, p) client = &http.Client{ Transport: transport, } @@ -203,7 +207,9 @@ func TestHostAliasesBasic(t *testing.T) { certs := x509.NewCertPool() certs.AddCert(cloudSack.Certificate()) - rt, err := NewFronted(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "", tls.HelloChrome_100) + rt, err := NewFronted("", tls.HelloChrome_100, "cloudsack") + require.NoError(t, err) + rt.UpdateConfig(certs, map[string]*Provider{"cloudsack": p}) client := &http.Client{Transport: rt} for _, test := range tests { @@ -311,7 +317,9 @@ func TestHostAliasesMulti(t *testing.T) { "sadcloud": p2, } - rt, err := NewFronted(certs, providers, "cloudsack", "", tls.HelloChrome_100) + rt, err := NewFronted("", tls.HelloChrome_100, "cloudsack") + require.NoError(t, err) + rt.UpdateConfig(certs, providers) client := &http.Client{Transport: rt} @@ -434,8 +442,9 @@ func TestPassthrough(t *testing.T) { certs := x509.NewCertPool() certs.AddCert(cloudSack.Certificate()) - rt, err := NewFronted(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "", tls.HelloChrome_100) + rt, err := NewFronted("", tls.HelloChrome_100, "cloudsack") require.NoError(t, err) + rt.UpdateConfig(certs, map[string]*Provider{"cloudsack": p}) client := &http.Client{Transport: rt} for _, test := range tests { @@ -503,7 +512,12 @@ func TestCustomValidators(t *testing.T) { "sadcloud": p, } - return NewFronted(certs, providers, "sadcloud", "", tls.HelloChrome_100) + f, err := NewFronted("", tls.HelloChrome_100, "sadcloud") + if err != nil { + return nil, err + } + f.UpdateConfig(certs, providers) + return f, nil } // This error indicates that the validator has discarded all masquerades. diff --git a/test_support.go b/test_support.go index 55c936c..9d3a2b9 100644 --- a/test_support.go +++ b/test_support.go @@ -24,13 +24,21 @@ func ConfigureForTest(t *testing.T) { func ConfigureCachingForTest(t *testing.T, cacheFile string) { certs := trustedCACerts(t) p := testProviders() - NewFronted(certs, p, testProviderID, cacheFile, tls.HelloChrome_100) + f, err := NewFronted(cacheFile, tls.HelloChrome_100, testProviderID) + if err != nil { + t.Fatalf("Unable to create fronted: %v", err) + } + f.UpdateConfig(certs, p) } func ConfigureHostAlaisesForTest(t *testing.T, hosts map[string]string) { certs := trustedCACerts(t) p := testProvidersWithHosts(hosts) - NewFronted(certs, p, testProviderID, "", tls.HelloChrome_100) + f, err := NewFronted("", tls.HelloChrome_100, testProviderID) + if err != nil { + t.Fatalf("Unable to create fronted: %v", err) + } + f.UpdateConfig(certs, p) } func trustedCACerts(t *testing.T) *x509.CertPool {