From aaeb3c13ca5c2182f1f48f01449c4fd1b6ffca1e Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Tue, 22 Oct 2024 15:31:15 -0600 Subject: [PATCH 01/22] Parallelize masquerade lookup --- context.go | 2 +- direct.go | 174 ++++++++++++++++++++++++------------------------- direct_test.go | 4 +- 3 files changed, 88 insertions(+), 92 deletions(-) diff --git a/context.go b/context.go index 37e7ac9..c9c7730 100644 --- a/context.go +++ b/context.go @@ -104,7 +104,7 @@ func (fctx *FrontingContext) ConfigureWithHello(pool *x509.CertPool, providers m if cacheFile != "" { d.initCaching(cacheFile) } - go d.vet(numberToVetInitially) + d.findWorkingMasquerades() fctx.instance.Set(d) return nil } diff --git a/direct.go b/direct.go index 3a70035..3c6da45 100644 --- a/direct.go +++ b/direct.go @@ -13,6 +13,7 @@ import ( "net/url" "strings" "sync" + "sync/atomic" "time" tls "github.com/refraction-networking/utls" @@ -25,7 +26,6 @@ import ( ) const ( - numberToVetInitially = 10 defaultMaxAllowedCachedAge = 24 * time.Hour defaultMaxCacheSize = 1000 defaultCacheSaveInterval = 5 * time.Second @@ -84,17 +84,9 @@ func (d *direct) providerFor(m *masquerade) *Provider { return d.providers[pid] } -// Vet vets the specified Masquerade, verifying certificate using the given CertPool +// Vet vets the specified Masquerade, verifying certificate using the given CertPool. +// This is used in genconfig. func Vet(m *Masquerade, pool *x509.CertPool, testURL string) bool { - return vet(m, pool, testURL) -} - -func vet(m *Masquerade, pool *x509.CertPool, testURL string) bool { - op := ops.Begin("vet_masquerade") - defer op.End() - op.Set("masquerade_domain", m.Domain) - op.Set("masquerade_ip", m.IpAddress) - d := &direct{ certPool: pool, maxAllowedCachedAge: defaultMaxAllowedCachedAge, @@ -102,54 +94,55 @@ func vet(m *Masquerade, pool *x509.CertPool, testURL string) bool { } conn, _, err := d.doDial(m) if err != nil { - op.FailIf(err) return false } defer conn.Close() return postCheck(conn, testURL) } -func (d *direct) vet(numberToVet int) { - log.Debugf("Vetting %d initial candidates in series", numberToVet) - for i := 0; i < numberToVet; i++ { - d.vetOne() - } -} - -func (d *direct) vetOne() { - // We're just testing the ability to connect here, destination site doesn't - // really matter - log.Debug("Vetting one") - unvettedMasquerades := make([]*masquerade, 0, len(d.masquerades)) - for _, m := range d.masquerades { - if m.lastSucceeded().IsZero() { - unvettedMasquerades = append(unvettedMasquerades, m) +func (d *direct) findWorkingMasquerades() { + // vet masquerades in batches + const batchSize int = 25 + var successful atomic.Uint32 + for i := 0; i < len(d.masquerades) && successful.Load() < 4; i += batchSize { + var wg sync.WaitGroup + for j := i; j < i+batchSize && j < len(d.masquerades); j++ { + wg.Add(1) + go func(m *masquerade) { + defer wg.Done() + if d.vetMasquerade(m) { + successful.Add(1) + } + }(d.masquerades[j]) } + wg.Wait() } +} - // Don't take more than 10 seconds to dial a masquerade for vetting - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - conn, m, masqueradeGood, err := d.dialWith(ctx, unvettedMasquerades) +func (d *direct) vetMasquerade(m *masquerade) bool { + conn, masqueradeGood, err := d.dialMasquerade(m) if err != nil { log.Errorf("unexpected error vetting masquerades: %v", err) - return + return false } - defer conn.Close() + defer func() { + if conn != nil { + conn.Close() + } + }() provider := d.providerFor(m) if provider == nil { log.Debugf("Skipping masquerade with disabled/unknown provider id '%s'", m.ProviderID) - return + return false } - if !masqueradeGood(postCheck(conn, provider.TestURL)) { log.Debugf("Unsuccessful vetting with POST request, discarding masquerade") - return + return false } - log.Debug("Finished vetting one") + log.Debugf("Finished vetting one masquerade %v", m) + return true } // postCheck does a post with invalid data to verify domain-fronting works @@ -187,7 +180,7 @@ func doCheck(client *http.Client, method string, expectedStatus int, u string) b op.Set("response_status", resp.StatusCode) op.Set("expected_status", expectedStatus) msg := fmt.Sprintf("Unexpected response status vetting masquerade, expected %d got %d: %v", expectedStatus, resp.StatusCode, resp.Status) - op.FailIf(fmt.Errorf(msg)) + op.FailIf(errors.New(msg)) log.Debug(msg) return false } @@ -247,7 +240,7 @@ func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, e log.Debugf("Retrying domain-fronted request, pass %d", i) } - conn, m, masqueradeGood, err := d.dial(req.Context()) + conn, m, masqueradeGood, err := d.dialAll(req.Context()) if err != nil { // unable to find good masquerade, fail op.FailIf(err) @@ -305,36 +298,13 @@ func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, e return nil, nil, op.FailIf(errors.New("could not complete request even with retries")) } -func cloneRequestWith(req *http.Request, frontedHost string, body io.ReadCloser) (*http.Request, error) { - url := *req.URL - url.Host = frontedHost - r, err := http.NewRequest(req.Method, url.String(), body) - if err != nil { - return nil, err - } - - for k, vs := range req.Header { - if !strings.EqualFold(k, "Host") { - v := make([]string, len(vs)) - copy(v, vs) - r.Header[k] = v - } - } - return r, nil -} - -// Dial dials out using a masquerade. If the available masquerade fails, it -// retries with others until it either succeeds or exhausts the available -// masquerades. If successful, it returns a connection to the masquerade, -// the selected masquerade, and a function that the caller can use to -// tell us whether the masquerade is good or not (i.e. if masquerade was good, -// keep it). -func (d *direct) dial(ctx context.Context) (net.Conn, *masquerade, func(bool) bool, error) { - conn, m, masqueradeGood, err := d.dialWith(ctx, d.masquerades) +// Dial dials out using all available masquerades until one succeeds. +func (d *direct) dialAll(ctx context.Context) (net.Conn, *masquerade, func(bool) bool, error) { + conn, m, masqueradeGood, err := d.dialAllWith(ctx, d.masquerades) return conn, m, masqueradeGood, err } -func (d *direct) dialWith(ctx context.Context, masquerades sortedMasquerades) (net.Conn, *masquerade, func(bool) bool, error) { +func (d *direct) dialAllWith(ctx context.Context, masquerades sortedMasquerades) (net.Conn, *masquerade, func(bool) bool, error) { // never take more than a minute trying to find a dialer ctx, cancel := context.WithTimeout(ctx, 1*time.Minute) defer cancel() @@ -343,7 +313,6 @@ func (d *direct) dialWith(ctx context.Context, masquerades sortedMasquerades) (n totalMasquerades := len(masqueradesToTry) dialLoop: for _, m := range masqueradesToTry { - // check to see if we've timed out select { case <-ctx.Done(): log.Debugf("Timed out dialing to %v with %v total masquerades", m, totalMasquerades) @@ -351,34 +320,43 @@ dialLoop: default: // okay } - - log.Tracef("Dialing to %v", m) - - // We do the full TLS connection here because in practice the domains at a given IP - // address can change frequently on CDNs, so the certificate may not match what - // we expect. - conn, retriable, err := d.doDial(&m.Masquerade) - masqueradeGood := func(good bool) bool { - if good { - m.markSucceeded() - } else { - m.markFailed() - } - d.markCacheDirty() - return good - } + conn, masqueradeGood, err := d.dialMasquerade(m) if err == nil { - log.Debug("Returning connection") - return conn, m, masqueradeGood, err - } else if !retriable { - log.Debugf("Dropping masquerade: non retryable error: %v", err) - masqueradeGood(false) + return conn, m, masqueradeGood, nil } } return nil, nil, nil, log.Errorf("could not dial any masquerade? tried %v", totalMasquerades) } +func (d *direct) dialMasquerade(m *masquerade) (net.Conn, func(bool) bool, error) { + // check to see if we've timed out + + log.Tracef("Dialing to %v", m) + + // We do the full TLS connection here because in practice the domains at a given IP + // address can change frequently on CDNs, so the certificate may not match what + // we expect. + conn, retriable, err := d.doDial(&m.Masquerade) + masqueradeGood := func(good bool) bool { + if good { + m.markSucceeded() + } else { + m.markFailed() + } + d.markCacheDirty() + return good + } + if err == nil { + log.Debug("Returning connection") + return conn, masqueradeGood, err + } else if !retriable { + log.Debugf("Dropping masquerade: non retryable error: %v", err) + masqueradeGood(false) + } + return conn, masqueradeGood, err +} + func (d *direct) doDial(m *Masquerade) (conn net.Conn, retriable bool, err error) { op := ops.Begin("dial_masquerade") defer op.End() @@ -551,3 +529,21 @@ func (ddf *directTransport) RoundTrip(req *http.Request) (resp *http.Response, e norm.URL.Scheme = "http" return ddf.Transport.RoundTrip(norm) } + +func cloneRequestWith(req *http.Request, frontedHost string, body io.ReadCloser) (*http.Request, error) { + url := *req.URL + url.Host = frontedHost + r, err := http.NewRequest(req.Method, url.String(), body) + if err != nil { + return nil, err + } + + for k, vs := range req.Header { + if !strings.EqualFold(k, "Host") { + v := make([]string, len(vs)) + copy(v, vs) + r.Header[k] = v + } + } + return r, nil +} diff --git a/direct_test.go b/direct_test.go index 89a3a05..2e9096f 100644 --- a/direct_test.go +++ b/direct_test.go @@ -27,11 +27,11 @@ func TestDirectDomainFronting(t *testing.T) { require.NoError(t, err, "Unable to create temp dir") defer os.RemoveAll(dir) cacheFile := filepath.Join(dir, "cachefile.2") - doTestDomainFronting(t, cacheFile, numberToVetInitially) + doTestDomainFronting(t, cacheFile, 10) time.Sleep(defaultCacheSaveInterval * 2) // Then try again, this time reusing the existing cacheFile but a corrupted version corruptMasquerades(cacheFile) - doTestDomainFronting(t, cacheFile, numberToVetInitially) + doTestDomainFronting(t, cacheFile, 10) } func TestDirectDomainFrontingWithSNIConfig(t *testing.T) { From 5985882e5b6184cae19869c18313d4d7a1a4233d Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Wed, 23 Oct 2024 15:33:17 -0600 Subject: [PATCH 02/22] Refactor to make everything more testable --- cache.go | 6 +- direct.go | 153 +++++++++++++---------------------------------- direct_test.go | 158 ++++++++++++++++++++++++++++++++++++++++++++++++- go.mod | 1 + masquerade.go | 136 +++++++++++++++++++++++++++++++++++++++++- 5 files changed, 334 insertions(+), 120 deletions(-) diff --git a/cache.go b/cache.go index 43391bb..0007afe 100644 --- a/cache.go +++ b/cache.go @@ -38,10 +38,10 @@ func (d *direct) prepopulateMasquerades(cacheFile string) { // update last succeeded status of masquerades based on cached values for _, m := range d.masquerades { for _, cm := range cachedMasquerades { - sameMasquerade := cm.ProviderID == m.ProviderID && cm.Domain == m.Domain && cm.IpAddress == m.IpAddress - cachedValueFresh := now.Sub(m.LastSucceeded) < d.maxAllowedCachedAge + sameMasquerade := cm.ProviderID == m.getProviderID() && cm.Domain == m.getDomain() && cm.IpAddress == m.getIpAddress() + cachedValueFresh := now.Sub(m.lastSucceeded()) < d.maxAllowedCachedAge if sameMasquerade && cachedValueFresh { - m.LastSucceeded = cm.LastSucceeded + m.setLastSucceeded(cm.LastSucceeded) } } } diff --git a/direct.go b/direct.go index 3c6da45..c79d8a4 100644 --- a/direct.go +++ b/direct.go @@ -20,9 +20,7 @@ import ( "github.com/getlantern/golog" "github.com/getlantern/idletiming" - "github.com/getlantern/netx" "github.com/getlantern/ops" - "github.com/getlantern/tlsdialer/v3" ) const ( @@ -76,8 +74,8 @@ func (d *direct) loadCandidates(initial map[string]*Provider) { } } -func (d *direct) providerFor(m *masquerade) *Provider { - pid := m.ProviderID +func (d *direct) providerFor(m MasqueradeInterface) *Provider { + pid := m.getProviderID() if pid == "" { pid = d.defaultProviderID } @@ -92,12 +90,13 @@ func Vet(m *Masquerade, pool *x509.CertPool, testURL string) bool { maxAllowedCachedAge: defaultMaxAllowedCachedAge, maxCacheSize: defaultMaxCacheSize, } - conn, _, err := d.doDial(m) + masq := &masquerade{Masquerade: *m} + conn, _, err := d.doDial(masq) if err != nil { return false } defer conn.Close() - return postCheck(conn, testURL) + return masq.postCheck(conn, testURL) } func (d *direct) findWorkingMasquerades() { @@ -105,21 +104,25 @@ func (d *direct) findWorkingMasquerades() { const batchSize int = 25 var successful atomic.Uint32 for i := 0; i < len(d.masquerades) && successful.Load() < 4; i += batchSize { - var wg sync.WaitGroup - for j := i; j < i+batchSize && j < len(d.masquerades); j++ { - wg.Add(1) - go func(m *masquerade) { - defer wg.Done() - if d.vetMasquerade(m) { - successful.Add(1) - } - }(d.masquerades[j]) - } - wg.Wait() + d.vetGroup(i, batchSize, &successful) + } +} + +func (d *direct) vetGroup(start, batchSize int, successful *atomic.Uint32) { + var wg sync.WaitGroup + for j := start; j < start+batchSize && j < len(d.masquerades); j++ { + wg.Add(1) + go func(m MasqueradeInterface) { + defer wg.Done() + if d.vetMasquerade(m) { + successful.Add(1) + } + }(d.masquerades[j]) } + wg.Wait() } -func (d *direct) vetMasquerade(m *masquerade) bool { +func (d *direct) vetMasquerade(m MasqueradeInterface) bool { conn, masqueradeGood, err := d.dialMasquerade(m) if err != nil { log.Errorf("unexpected error vetting masquerades: %v", err) @@ -133,10 +136,11 @@ func (d *direct) vetMasquerade(m *masquerade) bool { provider := d.providerFor(m) if provider == nil { - log.Debugf("Skipping masquerade with disabled/unknown provider id '%s'", m.ProviderID) + log.Debugf("Skipping masquerade with disabled/unknown provider id '%s' not in %v", + m.getProviderID(), d.providers) return false } - if !masqueradeGood(postCheck(conn, provider.TestURL)) { + if !masqueradeGood(m.postCheck(conn, provider.TestURL)) { log.Debugf("Unsuccessful vetting with POST request, discarding masquerade") return false } @@ -145,48 +149,6 @@ func (d *direct) vetMasquerade(m *masquerade) bool { return true } -// postCheck does a post with invalid data to verify domain-fronting works -func postCheck(conn net.Conn, testURL string) bool { - client := &http.Client{ - Transport: frontedHTTPTransport(conn, true), - } - return doCheck(client, http.MethodPost, http.StatusAccepted, testURL) -} - -func doCheck(client *http.Client, method string, expectedStatus int, u string) bool { - op := ops.Begin("check_masquerade") - defer op.End() - - isPost := method == http.MethodPost - var requestBody io.Reader - if isPost { - requestBody = strings.NewReader("a") - } - req, _ := http.NewRequest(method, u, requestBody) - if isPost { - req.Header.Set("Content-Type", "application/json") - } - resp, err := client.Do(req) - if err != nil { - op.FailIf(err) - log.Debugf("Unsuccessful vetting with %v request, discarding masquerade: %v", method, err) - return false - } - if resp.Body != nil { - io.Copy(io.Discard, resp.Body) - resp.Body.Close() - } - if resp.StatusCode != expectedStatus { - op.Set("response_status", resp.StatusCode) - op.Set("expected_status", expectedStatus) - msg := fmt.Sprintf("Unexpected response status vetting masquerade, expected %d got %d: %v", expectedStatus, resp.StatusCode, resp.Status) - op.FailIf(errors.New(msg)) - log.Debug(msg) - return false - } - return true -} - // Do continually retries a given request until it succeeds because some // fronting providers will return a 403 for some domains. func (d *direct) RoundTrip(req *http.Request) (*http.Response, error) { @@ -248,7 +210,7 @@ func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, e } provider := d.providerFor(m) if provider == nil { - log.Debugf("Skipping masquerade with disabled/unknown provider '%s'", m.ProviderID) + log.Debugf("Skipping masquerade with disabled/unknown provider '%s'", m.getProviderID()) masqueradeGood(false) continue } @@ -258,11 +220,12 @@ func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, e // so it is returned as good. conn.Close() masqueradeGood(true) - err := fmt.Errorf("no domain fronting mapping for '%s'. Please add it to provider_map.yaml or equivalent for %s", m.ProviderID, originHost) + err := fmt.Errorf("no domain fronting mapping for '%s'. Please add it to provider_map.yaml or equivalent for %s", + m.getProviderID(), originHost) op.FailIf(err) return nil, nil, err } - log.Debugf("Translated origin %s -> %s for provider %s...", originHost, frontedHost, m.ProviderID) + log.Debugf("Translated origin %s -> %s for provider %s...", originHost, frontedHost, m.getProviderID()) reqi, err := cloneRequestWith(req, frontedHost, getBody()) if err != nil { @@ -299,12 +262,12 @@ func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, e } // Dial dials out using all available masquerades until one succeeds. -func (d *direct) dialAll(ctx context.Context) (net.Conn, *masquerade, func(bool) bool, error) { +func (d *direct) dialAll(ctx context.Context) (net.Conn, MasqueradeInterface, func(bool) bool, error) { conn, m, masqueradeGood, err := d.dialAllWith(ctx, d.masquerades) return conn, m, masqueradeGood, err } -func (d *direct) dialAllWith(ctx context.Context, masquerades sortedMasquerades) (net.Conn, *masquerade, func(bool) bool, error) { +func (d *direct) dialAllWith(ctx context.Context, masquerades sortedMasquerades) (net.Conn, MasqueradeInterface, func(bool) bool, error) { // never take more than a minute trying to find a dialer ctx, cancel := context.WithTimeout(ctx, 1*time.Minute) defer cancel() @@ -329,7 +292,7 @@ dialLoop: return nil, nil, nil, log.Errorf("could not dial any masquerade? tried %v", totalMasquerades) } -func (d *direct) dialMasquerade(m *masquerade) (net.Conn, func(bool) bool, error) { +func (d *direct) dialMasquerade(m MasqueradeInterface) (net.Conn, func(bool) bool, error) { // check to see if we've timed out log.Tracef("Dialing to %v", m) @@ -337,7 +300,7 @@ func (d *direct) dialMasquerade(m *masquerade) (net.Conn, func(bool) bool, error // We do the full TLS connection here because in practice the domains at a given IP // address can change frequently on CDNs, so the certificate may not match what // we expect. - conn, retriable, err := d.doDial(&m.Masquerade) + conn, retriable, err := d.doDial(m) masqueradeGood := func(good bool) bool { if good { m.markSucceeded() @@ -357,16 +320,16 @@ func (d *direct) dialMasquerade(m *masquerade) (net.Conn, func(bool) bool, error return conn, masqueradeGood, err } -func (d *direct) doDial(m *Masquerade) (conn net.Conn, retriable bool, err error) { +func (d *direct) doDial(m MasqueradeInterface) (conn net.Conn, retriable bool, err error) { op := ops.Begin("dial_masquerade") defer op.End() - op.Set("masquerade_domain", m.Domain) - op.Set("masquerade_ip", m.IpAddress) + op.Set("masquerade_domain", m.getDomain()) + op.Set("masquerade_ip", m.getIpAddress()) conn, err = d.dialServerWith(m) if err != nil { op.FailIf(err) - log.Debugf("Could not dial to %v, %v", m.IpAddress, err) + log.Debugf("Could not dial to %v, %v", m.getIpAddress(), err) // Don't re-add this candidate if it's any certificate error, as that // will just keep failing and will waste connections. We can't access the underlying // error at this point so just look for "certificate" and "handshake". @@ -389,50 +352,16 @@ func (d *direct) doDial(m *Masquerade) (conn net.Conn, retriable bool, err error return } -func (d *direct) dialServerWith(m *Masquerade) (net.Conn, error) { +func (d *direct) dialServerWith(m MasqueradeInterface) (net.Conn, error) { op := ops.Begin("dial_server_with") defer op.End() - op.Set("masquerade_domain", m.Domain) - op.Set("masquerade_ip", m.IpAddress) - - tlsConfig := d.frontingTLSConfig(m) - dialTimeout := 10 * time.Second - addr := m.IpAddress - var sendServerNameExtension bool - - if m.SNI != "" { - sendServerNameExtension = true - - op.Set("arbitrary_sni", m.SNI) - tlsConfig.ServerName = m.SNI - tlsConfig.InsecureSkipVerify = true - tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { - var verifyHostname string - if m.VerifyHostname != nil { - verifyHostname = *m.VerifyHostname - op.Set("verify_hostname", verifyHostname) - } - return verifyPeerCertificate(rawCerts, d.certPool, verifyHostname) - } - - } + op.Set("masquerade_domain", m.getDomain()) + op.Set("masquerade_ip", m.getIpAddress()) - _, _, err := net.SplitHostPort(addr) - if err != nil { - addr = net.JoinHostPort(addr, "443") - } - - dialer := &tlsdialer.Dialer{ - DoDial: netx.DialTimeout, - Timeout: dialTimeout, - SendServerName: sendServerNameExtension, - Config: tlsConfig, - ClientHelloID: d.clientHelloID, - } - conn, err := dialer.Dial("tcp", addr) + conn, err := m.dial(d.certPool, d.clientHelloID) if err != nil && m != nil { - err = fmt.Errorf("unable to dial masquerade %s: %s", m.Domain, err) + err = fmt.Errorf("unable to dial masquerade %s: %s", m.getDomain(), err) op.FailIf(err) } return conn, err diff --git a/direct_test.go b/direct_test.go index 2e9096f..365ae33 100644 --- a/direct_test.go +++ b/direct_test.go @@ -3,8 +3,10 @@ package fronted import ( "crypto/x509" "encoding/json" + "errors" "fmt" "io" + "net" "net/http" "net/http/httptest" "net/http/httputil" @@ -16,10 +18,10 @@ import ( "testing" "time" + . "github.com/getlantern/waitforserver" + tls "github.com/refraction-networking/utls" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - . "github.com/getlantern/waitforserver" ) func TestDirectDomainFronting(t *testing.T) { @@ -144,7 +146,7 @@ func TestLoadCandidates(t *testing.T) { actual := make(map[Masquerade]bool) count := 0 for _, m := range d.masquerades { - actual[Masquerade{Domain: m.Domain, IpAddress: m.IpAddress}] = true + actual[Masquerade{Domain: m.getDomain(), IpAddress: m.getIpAddress()}] = true count++ } @@ -804,3 +806,153 @@ func TestVerifyPeerCertificate(t *testing.T) { }) } } + +func TestFindWorkingMasquerades(t *testing.T) { + tests := []struct { + name string + masquerades []*mockMasquerade + expectedSuccessful int + expectedMasquerades int + }{ + { + name: "All successful", + masquerades: []*mockMasquerade{ + newMockMasquerade("domain1.com", "1.1.1.1", 0, true), + newMockMasquerade("domain2.com", "2.2.2.2", 0, true), + newMockMasquerade("domain3.com", "3.3.3.3", 0, true), + newMockMasquerade("domain4.com", "4.4.4.4", 0, true), + newMockMasquerade("domain1.com", "1.1.1.1", 0, true), + newMockMasquerade("domain1.com", "1.1.1.1", 0, true), + }, + expectedSuccessful: 4, + }, + { + name: "Some successful", + masquerades: []*mockMasquerade{ + newMockMasquerade("domain1.com", "1.1.1.1", 0, true), + newMockMasquerade("domain2.com", "2.2.2.2", 0, false), + newMockMasquerade("domain3.com", "3.3.3.3", 0, true), + newMockMasquerade("domain4.com", "4.4.4.4", 0, false), + newMockMasquerade("domain1.com", "1.1.1.1", 0, true), + }, + expectedSuccessful: 2, + }, + { + name: "None successful", + masquerades: []*mockMasquerade{ + newMockMasquerade("domain1.com", "1.1.1.1", 0, false), + newMockMasquerade("domain2.com", "2.2.2.2", 0, false), + newMockMasquerade("domain3.com", "3.3.3.3", 0, false), + newMockMasquerade("domain4.com", "4.4.4.4", 0, false), + }, + expectedSuccessful: 0, + }, + { + name: "Batch processing", + masquerades: func() []*mockMasquerade { + var masquerades []*mockMasquerade + for i := 0; i < 50; i++ { + masquerades = append(masquerades, newMockMasquerade(fmt.Sprintf("domain%d.com", i), fmt.Sprintf("1.1.1.%d", i), 0, i%2 == 0)) + } + return masquerades + }(), + expectedSuccessful: 10, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &direct{} + d.providers = make(map[string]*Provider) + d.providers["testProviderId"] = NewProvider(nil, "", nil, nil, nil, nil, nil) + d.masquerades = make(sortedMasquerades, len(tt.masquerades)) + for i, m := range tt.masquerades { + d.masquerades[i] = m + } + + d.findWorkingMasquerades() + + time.Sleep(1000 * time.Millisecond) + var successful int + for _, m := range tt.masquerades { + // If it has a last succeeded time, it was successful + if !m.lastSucceededTime.IsZero() { + successful++ + } + } + + assert.GreaterOrEqual(t, successful, tt.expectedSuccessful) + }) + } +} + +// Generate a mock of a MasqueradeInterface with a Dial method that can optionally +// return an error after a specified number of milliseconds. +func newMockMasquerade(domain string, ipAddress string, timeout time.Duration, passesCheck bool) *mockMasquerade { + return &mockMasquerade{ + Domain: domain, + IpAddress: ipAddress, + timeout: timeout, + passesCheck: passesCheck, + } +} + +type mockMasquerade struct { + Domain string + IpAddress string + timeout time.Duration + passesCheck bool + lastSucceededTime time.Time +} + +// setLastSucceeded implements MasqueradeInterface. +func (m *mockMasquerade) setLastSucceeded(succeededTime time.Time) { + m.lastSucceededTime = succeededTime +} + +// lastSucceeded implements MasqueradeInterface. +func (m *mockMasquerade) lastSucceeded() time.Time { + return m.lastSucceededTime +} + +// postCheck implements MasqueradeInterface. +func (m *mockMasquerade) postCheck(net.Conn, string) bool { + return m.passesCheck +} + +// dial implements MasqueradeInterface. +func (m *mockMasquerade) dial(rootCAs *x509.CertPool, clientHelloID tls.ClientHelloID) (net.Conn, error) { + if m.timeout > 0 { + time.Sleep(m.timeout) + return nil, errors.New("mock dial error") + } + m.lastSucceededTime = time.Now() + return &net.TCPConn{}, nil +} + +// getDomain implements MasqueradeInterface. +func (m *mockMasquerade) getDomain() string { + return m.Domain +} + +// getIpAddress implements MasqueradeInterface. +func (m *mockMasquerade) getIpAddress() string { + return m.IpAddress +} + +// getProviderID implements MasqueradeInterface. +func (m *mockMasquerade) getProviderID() string { + return "testProviderId" +} + +// markFailed implements MasqueradeInterface. +func (m *mockMasquerade) markFailed() { + +} + +// markSucceeded implements MasqueradeInterface. +func (m *mockMasquerade) markSucceeded() { +} + +// Make sure that the mockMasquerade implements the MasqueradeInterface +var _ MasqueradeInterface = (*mockMasquerade)(nil) diff --git a/go.mod b/go.mod index fc503af..fcd0b1c 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( github.com/go-stack/stack v1.8.1 // indirect github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect go.opentelemetry.io/otel v1.14.0 // indirect go.opentelemetry.io/otel/trace v1.14.0 // indirect golang.org/x/crypto v0.1.0 // indirect diff --git a/masquerade.go b/masquerade.go index b57a6fb..de10223 100644 --- a/masquerade.go +++ b/masquerade.go @@ -2,14 +2,22 @@ package fronted import ( "crypto/sha256" + "crypto/x509" "encoding/json" + "errors" "fmt" + "io" "net" "net/http" "sort" "strings" "sync" "time" + + "github.com/getlantern/netx" + "github.com/getlantern/ops" + "github.com/getlantern/tlsdialer/v3" + tls "github.com/refraction-networking/utls" ) const ( @@ -44,6 +52,29 @@ type Masquerade struct { VerifyHostname *string } +// Create a masquerade interface for easier testing. +type MasqueradeInterface interface { + dial(rootCAs *x509.CertPool, clientHelloID tls.ClientHelloID) (net.Conn, error) + + // Accessor for the domain of the masquerade + getDomain() string + + //Accessor for the IP address of the masquerade + getIpAddress() string + + markSucceeded() + + markFailed() + + lastSucceeded() time.Time + + setLastSucceeded(time.Time) + + postCheck(net.Conn, string) bool + + getProviderID() string +} + type masquerade struct { Masquerade // lastSucceeded: the most recent time at which this Masquerade succeeded @@ -53,6 +84,98 @@ type masquerade struct { mx sync.RWMutex } +func (m *masquerade) dial(rootCAs *x509.CertPool, clientHelloID tls.ClientHelloID) (net.Conn, error) { + tlsConfig := &tls.Config{ + ServerName: m.Domain, + RootCAs: rootCAs, + } + dialTimeout := 10 * time.Second + addr := m.IpAddress + var sendServerNameExtension bool + if m.SNI != "" { + sendServerNameExtension = true + tlsConfig.ServerName = m.SNI + tlsConfig.InsecureSkipVerify = true + tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + var verifyHostname string + if m.VerifyHostname != nil { + verifyHostname = *m.VerifyHostname + } + return verifyPeerCertificate(rawCerts, rootCAs, verifyHostname) + } + } + dialer := &tlsdialer.Dialer{ + DoDial: netx.DialTimeout, + Timeout: dialTimeout, + SendServerName: sendServerNameExtension, + Config: tlsConfig, + ClientHelloID: clientHelloID, + } + _, _, err := net.SplitHostPort(addr) + if err != nil { + // If there is no port, we default to 443 + addr = net.JoinHostPort(addr, "443") + } + return dialer.Dial("tcp", addr) +} + +// postCheck does a post with invalid data to verify domain-fronting works +func (m *masquerade) postCheck(conn net.Conn, testURL string) bool { + client := &http.Client{ + Transport: frontedHTTPTransport(conn, true), + } + return doCheck(client, http.MethodPost, http.StatusAccepted, testURL) +} + +func doCheck(client *http.Client, method string, expectedStatus int, u string) bool { + op := ops.Begin("check_masquerade") + defer op.End() + + isPost := method == http.MethodPost + var requestBody io.Reader + if isPost { + requestBody = strings.NewReader("a") + } + req, _ := http.NewRequest(method, u, requestBody) + if isPost { + req.Header.Set("Content-Type", "application/json") + } + resp, err := client.Do(req) + if err != nil { + op.FailIf(err) + log.Debugf("Unsuccessful vetting with %v request, discarding masquerade: %v", method, err) + return false + } + if resp.Body != nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + if resp.StatusCode != expectedStatus { + op.Set("response_status", resp.StatusCode) + op.Set("expected_status", expectedStatus) + msg := fmt.Sprintf("Unexpected response status vetting masquerade, expected %d got %d: %v", expectedStatus, resp.StatusCode, resp.Status) + op.FailIf(errors.New(msg)) + log.Debug(msg) + return false + } + return true +} + +// getDomain implements MasqueradeInterface. +func (m *masquerade) getDomain() string { + return m.Domain +} + +// getIpAddress implements MasqueradeInterface. +func (m *masquerade) getIpAddress() string { + return m.IpAddress +} + +// getProviderID implements MasqueradeInterface. +func (m *masquerade) getProviderID() string { + return m.ProviderID +} + // MarshalJSON marshals masquerade into json func (m *masquerade) MarshalJSON() ([]byte, error) { m.mx.RLock() @@ -68,6 +191,12 @@ func (m *masquerade) lastSucceeded() time.Time { return m.LastSucceeded } +func (m *masquerade) setLastSucceeded(t time.Time) { + m.mx.Lock() + defer m.mx.Unlock() + m.LastSucceeded = t +} + func (m *masquerade) markSucceeded() { m.mx.Lock() defer m.mx.Unlock() @@ -80,6 +209,9 @@ func (m *masquerade) markFailed() { m.LastSucceeded = time.Time{} } +// Make sure that the mockMasquerade implements the MasqueradeInterface +var _ MasqueradeInterface = (*masquerade)(nil) + // A Direct fronting provider configuration. type Provider struct { // Specific hostname mappings used for this provider. @@ -211,7 +343,7 @@ func NewStatusCodeValidator(reject []int) ResponseValidator { } // slice of masquerade sorted by last vetted time -type sortedMasquerades []*masquerade +type sortedMasquerades []MasqueradeInterface func (m sortedMasquerades) Len() int { return len(m) } func (m sortedMasquerades) Swap(i, j int) { m[i], m[j] = m[j], m[i] } @@ -221,7 +353,7 @@ func (m sortedMasquerades) Less(i, j int) bool { } else if m[j].lastSucceeded().After(m[i].lastSucceeded()) { return false } else { - return m[i].IpAddress < m[j].IpAddress + return m[i].getIpAddress() < m[j].getIpAddress() } } From 52b3c2531d958e6f1f391a4dc5713ebb57b037e8 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Wed, 23 Oct 2024 15:39:27 -0600 Subject: [PATCH 03/22] Improved test --- direct_test.go | 19 +++++++++---------- go.mod | 1 - 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/direct_test.go b/direct_test.go index 365ae33..88f84ad 100644 --- a/direct_test.go +++ b/direct_test.go @@ -15,6 +15,7 @@ import ( "path/filepath" "strconv" "strings" + "sync/atomic" "testing" "time" @@ -856,7 +857,7 @@ func TestFindWorkingMasquerades(t *testing.T) { } return masquerades }(), - expectedSuccessful: 10, + expectedSuccessful: 4, }, } @@ -870,18 +871,16 @@ func TestFindWorkingMasquerades(t *testing.T) { d.masquerades[i] = m } - d.findWorkingMasquerades() + var successful atomic.Uint32 + d.vetGroup(0, 10, &successful) - time.Sleep(1000 * time.Millisecond) - var successful int - for _, m := range tt.masquerades { - // If it has a last succeeded time, it was successful - if !m.lastSucceededTime.IsZero() { - successful++ - } + tries := 0 + for successful.Load() < uint32(tt.expectedSuccessful) && tries < 100 { + time.Sleep(30 * time.Millisecond) + tries++ } - assert.GreaterOrEqual(t, successful, tt.expectedSuccessful) + assert.GreaterOrEqual(t, int(successful.Load()), tt.expectedSuccessful) }) } } diff --git a/go.mod b/go.mod index fcd0b1c..fc503af 100644 --- a/go.mod +++ b/go.mod @@ -30,7 +30,6 @@ require ( github.com/go-stack/stack v1.8.1 // indirect github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/stretchr/objx v0.5.0 // indirect go.opentelemetry.io/otel v1.14.0 // indirect go.opentelemetry.io/otel/trace v1.14.0 // indirect golang.org/x/crypto v0.1.0 // indirect From 18e39262b4b87a17880dbeb2419ced287ce6492b Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Wed, 23 Oct 2024 15:40:46 -0600 Subject: [PATCH 04/22] Added constant --- direct.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/direct.go b/direct.go index c79d8a4..957f981 100644 --- a/direct.go +++ b/direct.go @@ -110,7 +110,8 @@ func (d *direct) findWorkingMasquerades() { func (d *direct) vetGroup(start, batchSize int, successful *atomic.Uint32) { var wg sync.WaitGroup - for j := start; j < start+batchSize && j < len(d.masquerades); j++ { + masqueradeSize := len(d.masquerades) + for j := start; j < start+batchSize && j < masqueradeSize; j++ { wg.Add(1) go func(m MasqueradeInterface) { defer wg.Done() From a348358ea0f622daae04afe7a68abfcad86fd6d6 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Thu, 24 Oct 2024 04:59:35 -0600 Subject: [PATCH 05/22] Comment and naming --- direct.go | 10 ++++++++-- direct_test.go | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/direct.go b/direct.go index 957f981..747f756 100644 --- a/direct.go +++ b/direct.go @@ -99,16 +99,22 @@ func Vet(m *Masquerade, pool *x509.CertPool, testURL string) bool { return masq.postCheck(conn, testURL) } +// findWorkingMasquerades finds working masquerades by vetting them in batches and in +// parallel. Speed is of the essence here, as without working masquerades, users will +// be unable to fetch proxy configurations, particularly in the case of a first time +// user who does not have proxies cached on disk. func (d *direct) findWorkingMasquerades() { // vet masquerades in batches const batchSize int = 25 var successful atomic.Uint32 + + // We loop through all of them until we have 4 successful ones. for i := 0; i < len(d.masquerades) && successful.Load() < 4; i += batchSize { - d.vetGroup(i, batchSize, &successful) + d.vetBatch(i, batchSize, &successful) } } -func (d *direct) vetGroup(start, batchSize int, successful *atomic.Uint32) { +func (d *direct) vetBatch(start, batchSize int, successful *atomic.Uint32) { var wg sync.WaitGroup masqueradeSize := len(d.masquerades) for j := start; j < start+batchSize && j < masqueradeSize; j++ { diff --git a/direct_test.go b/direct_test.go index 88f84ad..0fbc8ad 100644 --- a/direct_test.go +++ b/direct_test.go @@ -872,7 +872,7 @@ func TestFindWorkingMasquerades(t *testing.T) { } var successful atomic.Uint32 - d.vetGroup(0, 10, &successful) + d.vetBatch(0, 10, &successful) tries := 0 for successful.Load() < uint32(tt.expectedSuccessful) && tries < 100 { From 849c965ea86bf173952a46b93d70a917aa98d5ea Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Thu, 24 Oct 2024 09:50:54 -0600 Subject: [PATCH 06/22] Log details of the cache's parent directory --- cache.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/cache.go b/cache.go index 0007afe..cbb79d3 100644 --- a/cache.go +++ b/cache.go @@ -3,6 +3,7 @@ package fronted import ( "encoding/json" "os" + "path/filepath" "time" ) @@ -87,6 +88,16 @@ func (d *direct) updateCache(cacheFile string) { err = os.WriteFile(cacheFile, b, 0644) if err != nil { log.Errorf("Unable to save cache to disk: %v", err) + // Log the directory of the cache file and if it exists for debugging purposes + parent := filepath.Dir(cacheFile) + // check if the parent directory exists + if _, err := os.Stat(parent); err == nil { + // parent directory exists + log.Debugf("Parent directory of cache file exists: %v", parent) + } else { + // parent directory does not exist + log.Debugf("Parent directory of cache file does not exist: %v", parent) + } } } From c7a73e667b0d2ce721672b34a68e7af7c30b2267 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Thu, 24 Oct 2024 10:08:34 -0600 Subject: [PATCH 07/22] log tweaks --- direct.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/direct.go b/direct.go index 747f756..02e0578 100644 --- a/direct.go +++ b/direct.go @@ -68,7 +68,6 @@ func (d *direct) loadCandidates(initial map[string]*Provider) { } for _, c := range sh { - log.Trace("Adding candidate") d.masquerades = append(d.masquerades, &masquerade{Masquerade: *c, ProviderID: key}) } } @@ -115,6 +114,7 @@ func (d *direct) findWorkingMasquerades() { } func (d *direct) vetBatch(start, batchSize int, successful *atomic.Uint32) { + log.Debugf("Vetting masquerade batch %d-%d", start, start+batchSize) var wg sync.WaitGroup masqueradeSize := len(d.masquerades) for j := start; j < start+batchSize && j < masqueradeSize; j++ { From 9ce208752d141c502d24d3c6a09148e94df23293 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Thu, 24 Oct 2024 14:43:58 -0600 Subject: [PATCH 08/22] Naming and initialization cleanup --- cache.go | 12 +-- cache_test.go | 4 +- context.go | 50 +++------- direct.go => fronted.go | 151 +++++++++++++++--------------- direct_test.go => fronted_test.go | 20 ++-- 5 files changed, 107 insertions(+), 130 deletions(-) rename direct.go => fronted.go (75%) rename direct_test.go => fronted_test.go (99%) diff --git a/cache.go b/cache.go index cbb79d3..9df34d5 100644 --- a/cache.go +++ b/cache.go @@ -7,12 +7,12 @@ import ( "time" ) -func (d *direct) initCaching(cacheFile string) { +func (d *fronted) initCaching(cacheFile string) { d.prepopulateMasquerades(cacheFile) go d.maintainCache(cacheFile) } -func (d *direct) prepopulateMasquerades(cacheFile string) { +func (d *fronted) prepopulateMasquerades(cacheFile string) { bytes, err := os.ReadFile(cacheFile) if err != nil { // This is not a big deal since we'll just fill the cache later @@ -48,7 +48,7 @@ func (d *direct) prepopulateMasquerades(cacheFile string) { } } -func (d *direct) markCacheDirty() { +func (d *fronted) markCacheDirty() { select { case d.cacheDirty <- nil: // okay @@ -57,7 +57,7 @@ func (d *direct) markCacheDirty() { } } -func (d *direct) maintainCache(cacheFile string) { +func (d *fronted) maintainCache(cacheFile string) { for { select { case <-d.cacheClosed: @@ -73,7 +73,7 @@ func (d *direct) maintainCache(cacheFile string) { } } -func (d *direct) updateCache(cacheFile string) { +func (d *fronted) updateCache(cacheFile string) { log.Debugf("Updating cache at %v", cacheFile) cache := d.masquerades.sortedCopy() sizeToSave := len(cache) @@ -101,7 +101,7 @@ func (d *direct) updateCache(cacheFile string) { } } -func (d *direct) closeCache() { +func (d *fronted) closeCache() { d.closeCacheOnce.Do(func() { close(d.cacheClosed) }) diff --git a/cache_test.go b/cache_test.go index 06e1c42..9b874d1 100644 --- a/cache_test.go +++ b/cache_test.go @@ -26,8 +26,8 @@ func TestCaching(t *testing.T) { cloudsackID: NewProvider(nil, "", nil, nil, nil, nil, nil), } - makeDirect := func() *direct { - d := &direct{ + makeDirect := func() *fronted { + d := &fronted{ masquerades: make(sortedMasquerades, 0, 1000), maxAllowedCachedAge: 250 * time.Millisecond, maxCacheSize: 4, diff --git a/context.go b/context.go index c9c7730..68027f3 100644 --- a/context.go +++ b/context.go @@ -27,11 +27,11 @@ func Configure(pool *x509.CertPool, providers map[string]*Provider, defaultProvi } } -// NewDirect creates a new http.RoundTripper that does direct domain fronting +// NewFronted creates a new http.RoundTripper that does direct domain fronting // using the default context. If the default context isn't configured within // the given timeout, this method returns nil, false. -func NewDirect(timeout time.Duration) (http.RoundTripper, bool) { - return DefaultContext.NewDirect(timeout) +func NewFronted(timeout time.Duration) (http.RoundTripper, bool) { + return DefaultContext.NewFronted(timeout) } // Close closes any existing cache file in the default context @@ -68,51 +68,23 @@ func (fctx *FrontingContext) ConfigureWithHello(pool *x509.CertPool, providers m _existing, ok := fctx.instance.Get(0) if ok && _existing != nil { - existing := _existing.(*direct) + existing := _existing.(*fronted) log.Debugf("Closing cache from existing instance for %s context", fctx.name) existing.closeCache() } - size := 0 - for _, p := range providers { - size += len(p.Masquerades) + f, err := newFronted(pool, providers, defaultProviderID, cacheFile, clientHelloID) + if err != nil { + return err } - - if size == 0 { - return fmt.Errorf("no masquerades for %s context", fctx.name) - } - - d := &direct{ - certPool: pool, - masquerades: make(sortedMasquerades, 0, size), - maxAllowedCachedAge: defaultMaxAllowedCachedAge, - maxCacheSize: defaultMaxCacheSize, - cacheSaveInterval: defaultCacheSaveInterval, - cacheDirty: make(chan interface{}, 1), - cacheClosed: make(chan interface{}), - defaultProviderID: defaultProviderID, - providers: make(map[string]*Provider), - clientHelloID: clientHelloID, - } - - // copy providers - for k, p := range providers { - d.providers[k] = NewProvider(p.HostAliases, p.TestURL, p.Masquerades, p.Validator, p.PassthroughPatterns, p.SNIConfig, p.VerifyHostname) - } - - d.loadCandidates(d.providers) - if cacheFile != "" { - d.initCaching(cacheFile) - } - d.findWorkingMasquerades() - fctx.instance.Set(d) + fctx.instance.Set(f) return nil } -// NewDirect creates a new http.RoundTripper that does direct domain fronting. +// NewFronted creates a new http.RoundTripper that does direct domain fronting. // If the context isn't configured within the given timeout, this method // returns nil, false. -func (fctx *FrontingContext) NewDirect(timeout time.Duration) (http.RoundTripper, bool) { +func (fctx *FrontingContext) NewFronted(timeout time.Duration) (http.RoundTripper, bool) { instance, ok := fctx.instance.Get(timeout) if !ok { log.Errorf("No DirectHttpClient available within %v for context %s", timeout, fctx.name) @@ -125,7 +97,7 @@ func (fctx *FrontingContext) NewDirect(timeout time.Duration) (http.RoundTripper func (fctx *FrontingContext) Close() { _existing, ok := fctx.instance.Get(0) if ok && _existing != nil { - existing := _existing.(*direct) + existing := _existing.(*fronted) log.Debugf("Closing cache from existing instance in %s context", fctx.name) existing.closeCache() } diff --git a/direct.go b/fronted.go similarity index 75% rename from direct.go rename to fronted.go index 02e0578..53b1a9f 100644 --- a/direct.go +++ b/fronted.go @@ -19,7 +19,6 @@ import ( tls "github.com/refraction-networking/utls" "github.com/getlantern/golog" - "github.com/getlantern/idletiming" "github.com/getlantern/ops" ) @@ -34,8 +33,9 @@ var ( log = golog.LoggerFor("fronted") ) -// direct is an implementation of http.RoundTripper -type direct struct { +// fronted identifies working IP address/domain pairings for domain fronting and is +// an implementation of http.RoundTripper for the convenience of callers. +type fronted struct { certPool *x509.CertPool masquerades sortedMasquerades maxAllowedCachedAge time.Duration @@ -49,7 +49,44 @@ type direct struct { clientHelloID tls.ClientHelloID } -func (d *direct) loadCandidates(initial map[string]*Provider) { +func newFronted(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID, cacheFile string, clientHelloID tls.ClientHelloID) (*fronted, error) { + size := 0 + for _, p := range providers { + size += len(p.Masquerades) + } + + if size == 0 { + return nil, fmt.Errorf("no masquerades found in providers") + } + + f := &fronted{ + certPool: pool, + masquerades: make(sortedMasquerades, 0, size), + maxAllowedCachedAge: defaultMaxAllowedCachedAge, + maxCacheSize: defaultMaxCacheSize, + cacheSaveInterval: defaultCacheSaveInterval, + cacheDirty: make(chan interface{}, 1), + cacheClosed: make(chan interface{}), + defaultProviderID: defaultProviderID, + providers: make(map[string]*Provider), + clientHelloID: clientHelloID, + } + + // copy providers + for k, p := range providers { + f.providers[k] = NewProvider(p.HostAliases, p.TestURL, p.Masquerades, p.Validator, p.PassthroughPatterns, p.SNIConfig, p.VerifyHostname) + } + + f.loadCandidates(f.providers) + if cacheFile != "" { + f.initCaching(cacheFile) + } + f.findWorkingMasquerades() + + return f, nil +} + +func (f *fronted) loadCandidates(initial map[string]*Provider) { log.Debugf("Loading candidates for %d providers", len(initial)) defer log.Debug("Finished loading candidates") @@ -68,23 +105,23 @@ func (d *direct) loadCandidates(initial map[string]*Provider) { } for _, c := range sh { - d.masquerades = append(d.masquerades, &masquerade{Masquerade: *c, ProviderID: key}) + f.masquerades = append(f.masquerades, &masquerade{Masquerade: *c, ProviderID: key}) } } } -func (d *direct) providerFor(m MasqueradeInterface) *Provider { +func (f *fronted) providerFor(m MasqueradeInterface) *Provider { pid := m.getProviderID() if pid == "" { - pid = d.defaultProviderID + pid = f.defaultProviderID } - return d.providers[pid] + return f.providers[pid] } // Vet vets the specified Masquerade, verifying certificate using the given CertPool. // This is used in genconfig. func Vet(m *Masquerade, pool *x509.CertPool, testURL string) bool { - d := &direct{ + d := &fronted{ certPool: pool, maxAllowedCachedAge: defaultMaxAllowedCachedAge, maxCacheSize: defaultMaxCacheSize, @@ -102,35 +139,35 @@ func Vet(m *Masquerade, pool *x509.CertPool, testURL string) bool { // parallel. Speed is of the essence here, as without working masquerades, users will // be unable to fetch proxy configurations, particularly in the case of a first time // user who does not have proxies cached on disk. -func (d *direct) findWorkingMasquerades() { +func (f *fronted) findWorkingMasquerades() { // vet masquerades in batches const batchSize int = 25 var successful atomic.Uint32 // We loop through all of them until we have 4 successful ones. - for i := 0; i < len(d.masquerades) && successful.Load() < 4; i += batchSize { - d.vetBatch(i, batchSize, &successful) + for i := 0; i < len(f.masquerades) && successful.Load() < 4; i += batchSize { + f.vetBatch(i, batchSize, &successful) } } -func (d *direct) vetBatch(start, batchSize int, successful *atomic.Uint32) { +func (f *fronted) vetBatch(start, batchSize int, successful *atomic.Uint32) { log.Debugf("Vetting masquerade batch %d-%d", start, start+batchSize) var wg sync.WaitGroup - masqueradeSize := len(d.masquerades) + masqueradeSize := len(f.masquerades) for j := start; j < start+batchSize && j < masqueradeSize; j++ { wg.Add(1) go func(m MasqueradeInterface) { defer wg.Done() - if d.vetMasquerade(m) { + if f.vetMasquerade(m) { successful.Add(1) } - }(d.masquerades[j]) + }(f.masquerades[j]) } wg.Wait() } -func (d *direct) vetMasquerade(m MasqueradeInterface) bool { - conn, masqueradeGood, err := d.dialMasquerade(m) +func (f *fronted) vetMasquerade(m MasqueradeInterface) bool { + conn, masqueradeGood, err := f.dialMasquerade(m) if err != nil { log.Errorf("unexpected error vetting masquerades: %v", err) return false @@ -141,10 +178,10 @@ func (d *direct) vetMasquerade(m MasqueradeInterface) bool { } }() - provider := d.providerFor(m) + provider := f.providerFor(m) if provider == nil { log.Debugf("Skipping masquerade with disabled/unknown provider id '%s' not in %v", - m.getProviderID(), d.providers) + m.getProviderID(), f.providers) return false } if !masqueradeGood(m.postCheck(conn, provider.TestURL)) { @@ -152,21 +189,20 @@ func (d *direct) vetMasquerade(m MasqueradeInterface) bool { return false } - log.Debugf("Finished vetting one masquerade %v", m) + log.Debugf("Successfully vetted one masquerade %v", m) return true } -// Do continually retries a given request until it succeeds because some -// fronting providers will return a 403 for some domains. -func (d *direct) RoundTrip(req *http.Request) (*http.Response, error) { - res, _, err := d.RoundTripHijack(req) +// RoundTrip loops through all available masquerades, sorted by the one that most recently +// connected, retrying several times on failures. +func (f *fronted) RoundTrip(req *http.Request) (*http.Response, error) { + res, _, err := f.RoundTripHijack(req) return res, err } -// Do continually retries a given request until it succeeds because some -// fronting providers will return a 403 for some domains. Also return the -// underlying net.Conn established. -func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, error) { +// RoundTripHijack loops through all available masquerades, sorted by the one that most +// recently connected, retrying several times on failures. +func (f *fronted) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, error) { op := ops.Begin("fronted_roundtrip") defer op.End() @@ -209,13 +245,13 @@ func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, e log.Debugf("Retrying domain-fronted request, pass %d", i) } - conn, m, masqueradeGood, err := d.dialAll(req.Context()) + conn, m, masqueradeGood, err := f.dialAll(req.Context()) if err != nil { // unable to find good masquerade, fail op.FailIf(err) return nil, nil, err } - provider := d.providerFor(m) + provider := f.providerFor(m) if provider == nil { log.Debugf("Skipping masquerade with disabled/unknown provider '%s'", m.getProviderID()) masqueradeGood(false) @@ -269,12 +305,12 @@ func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, e } // Dial dials out using all available masquerades until one succeeds. -func (d *direct) dialAll(ctx context.Context) (net.Conn, MasqueradeInterface, func(bool) bool, error) { - conn, m, masqueradeGood, err := d.dialAllWith(ctx, d.masquerades) +func (f *fronted) dialAll(ctx context.Context) (net.Conn, MasqueradeInterface, func(bool) bool, error) { + conn, m, masqueradeGood, err := f.dialAllWith(ctx, f.masquerades) return conn, m, masqueradeGood, err } -func (d *direct) dialAllWith(ctx context.Context, masquerades sortedMasquerades) (net.Conn, MasqueradeInterface, func(bool) bool, error) { +func (f *fronted) dialAllWith(ctx context.Context, masquerades sortedMasquerades) (net.Conn, MasqueradeInterface, func(bool) bool, error) { // never take more than a minute trying to find a dialer ctx, cancel := context.WithTimeout(ctx, 1*time.Minute) defer cancel() @@ -290,7 +326,7 @@ dialLoop: default: // okay } - conn, masqueradeGood, err := d.dialMasquerade(m) + conn, masqueradeGood, err := f.dialMasquerade(m) if err == nil { return conn, m, masqueradeGood, nil } @@ -299,26 +335,24 @@ dialLoop: return nil, nil, nil, log.Errorf("could not dial any masquerade? tried %v", totalMasquerades) } -func (d *direct) dialMasquerade(m MasqueradeInterface) (net.Conn, func(bool) bool, error) { - // check to see if we've timed out - +func (f *fronted) dialMasquerade(m MasqueradeInterface) (net.Conn, func(bool) bool, error) { log.Tracef("Dialing to %v", m) // We do the full TLS connection here because in practice the domains at a given IP // address can change frequently on CDNs, so the certificate may not match what // we expect. - conn, retriable, err := d.doDial(m) + conn, retriable, err := f.doDial(m) masqueradeGood := func(good bool) bool { if good { m.markSucceeded() } else { m.markFailed() } - d.markCacheDirty() + f.markCacheDirty() return good } if err == nil { - log.Debug("Returning connection") + log.Debugf("Returning connection for masquerade: %v", m) return conn, masqueradeGood, err } else if !retriable { log.Debugf("Dropping masquerade: non retryable error: %v", err) @@ -327,13 +361,13 @@ func (d *direct) dialMasquerade(m MasqueradeInterface) (net.Conn, func(bool) boo return conn, masqueradeGood, err } -func (d *direct) doDial(m MasqueradeInterface) (conn net.Conn, retriable bool, err error) { +func (f *fronted) doDial(m MasqueradeInterface) (conn net.Conn, retriable bool, err error) { op := ops.Begin("dial_masquerade") defer op.End() op.Set("masquerade_domain", m.getDomain()) op.Set("masquerade_ip", m.getIpAddress()) - conn, err = d.dialServerWith(m) + conn, err = m.dial(f.certPool, f.clientHelloID) if err != nil { op.FailIf(err) log.Debugf("Could not dial to %v, %v", m.getIpAddress(), err) @@ -349,31 +383,10 @@ func (d *direct) doDial(m MasqueradeInterface) (conn net.Conn, retriable bool, e } } else { log.Debugf("Got successful connection to: %v", m) - idleTimeout := 70 * time.Second - - log.Debugf("Wrapping connection in idletiming connection: %v", m) - conn = idletiming.Conn(conn, idleTimeout, func() { - log.Debugf("Connection to %v idle for %v, closed", conn.RemoteAddr(), idleTimeout) - }) } return } -func (d *direct) dialServerWith(m MasqueradeInterface) (net.Conn, error) { - op := ops.Begin("dial_server_with") - defer op.End() - - op.Set("masquerade_domain", m.getDomain()) - op.Set("masquerade_ip", m.getIpAddress()) - - conn, err := m.dial(d.certPool, d.clientHelloID) - if err != nil && m != nil { - err = fmt.Errorf("unable to dial masquerade %s: %s", m.getDomain(), err) - op.FailIf(err) - } - return conn, err -} - func verifyPeerCertificate(rawCerts [][]byte, roots *x509.CertPool, domain string) error { if len(rawCerts) == 0 { return fmt.Errorf("no certificates presented") @@ -422,15 +435,6 @@ func generateVerifyOptions(roots *x509.CertPool, domain string) x509.VerifyOptio } } -// frontingTLSConfig builds a tls.Config for dialing the fronting domain. This is to establish the -// initial TCP connection to the CDN. -func (d *direct) frontingTLSConfig(m *Masquerade) *tls.Config { - return &tls.Config{ - ServerName: m.Domain, - RootCAs: d.certPool, - } -} - // frontedHTTPTransport is the transport to use to route to the actual fronted destination domain. // This uses the pre-established connection to the CDN on the fronting domain. func frontedHTTPTransport(conn net.Conn, disableKeepAlives bool) http.RoundTripper { @@ -441,6 +445,7 @@ func frontedHTTPTransport(conn net.Conn, disableKeepAlives bool) http.RoundTripp }, TLSHandshakeTimeout: 40 * time.Second, DisableKeepAlives: disableKeepAlives, + IdleConnTimeout: 70 * time.Second, }, } } diff --git a/direct_test.go b/fronted_test.go similarity index 99% rename from direct_test.go rename to fronted_test.go index 0fbc8ad..23fc1a1 100644 --- a/direct_test.go +++ b/fronted_test.go @@ -57,7 +57,7 @@ func TestDirectDomainFrontingWithSNIConfig(t *testing.T) { }) Configure(certs, p, testProviderID, cacheFile) - transport, ok := NewDirect(0) + transport, ok := NewFronted(0) require.True(t, ok) client := &http.Client{ Transport: transport, @@ -85,7 +85,7 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE p := testProvidersWithHosts(hosts) Configure(certs, p, testProviderID, cacheFile) - transport, ok := NewDirect(30 * time.Second) + transport, ok := NewFronted(30 * time.Second) require.True(t, ok) client := &http.Client{ @@ -94,7 +94,7 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE } require.True(t, doCheck(client, http.MethodPost, http.StatusAccepted, pingURL)) - transport, ok = NewDirect(0) + transport, ok = NewFronted(0) require.True(t, ok) client = &http.Client{ Transport: transport, @@ -103,7 +103,7 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE instance, ok := DefaultContext.instance.Get(0) require.True(t, ok) - d := instance.(*direct) + d := instance.(*fronted) // Check the number of masquerades at the end, waiting up to 30 seconds until we get the right number masqueradesAtEnd := 0 @@ -138,7 +138,7 @@ func TestLoadCandidates(t *testing.T) { } } - d := &direct{ + d := &fronted{ masquerades: make(sortedMasquerades, 0, len(expected)), } @@ -238,7 +238,7 @@ func TestHostAliasesBasic(t *testing.T) { certs.AddCert(cloudSack.Certificate()) Configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") - rt, ok := NewDirect(10 * time.Second) + rt, ok := NewFronted(10 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } @@ -349,7 +349,7 @@ func TestHostAliasesMulti(t *testing.T) { } Configure(certs, providers, "cloudsack", "") - rt, ok := NewDirect(10 * time.Second) + rt, ok := NewFronted(10 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } @@ -475,7 +475,7 @@ func TestPassthrough(t *testing.T) { certs.AddCert(cloudSack.Certificate()) Configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") - rt, ok := NewDirect(10 * time.Second) + rt, ok := NewFronted(10 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } @@ -615,7 +615,7 @@ func TestCustomValidators(t *testing.T) { for _, test := range tests { setup(test.validator) - direct, ok := NewDirect(1 * time.Second) + direct, ok := NewFronted(1 * time.Second) if !assert.True(t, ok) { return } @@ -863,7 +863,7 @@ func TestFindWorkingMasquerades(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - d := &direct{} + d := &fronted{} d.providers = make(map[string]*Provider) d.providers["testProviderId"] = NewProvider(nil, "", nil, nil, nil, nil, nil) d.masquerades = make(sortedMasquerades, len(tt.masquerades)) From f04931919184e028b18336f77b8c64bad3bc1f10 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Thu, 24 Oct 2024 14:53:19 -0600 Subject: [PATCH 09/22] Fixed comment --- masquerade.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/masquerade.go b/masquerade.go index de10223..9dcdd52 100644 --- a/masquerade.go +++ b/masquerade.go @@ -209,7 +209,7 @@ func (m *masquerade) markFailed() { m.LastSucceeded = time.Time{} } -// Make sure that the mockMasquerade implements the MasqueradeInterface +// Make sure that the masquerade struct implements the MasqueradeInterface var _ MasqueradeInterface = (*masquerade)(nil) // A Direct fronting provider configuration. From a602b0cdcbfa8eaedda3a40f43f90b115150b7a8 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Thu, 24 Oct 2024 16:02:32 -0600 Subject: [PATCH 10/22] Only set the context instance when we've successfully connected to a working masquerade --- context.go | 5 +++-- fronted.go | 15 ++++++++++----- fronted_test.go | 2 +- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/context.go b/context.go index 68027f3..86432a1 100644 --- a/context.go +++ b/context.go @@ -73,11 +73,12 @@ func (fctx *FrontingContext) ConfigureWithHello(pool *x509.CertPool, providers m existing.closeCache() } - f, err := newFronted(pool, providers, defaultProviderID, cacheFile, clientHelloID) + _, err := newFronted(pool, providers, defaultProviderID, cacheFile, clientHelloID, func(f *fronted) { + fctx.instance.Set(f) + }) if err != nil { return err } - fctx.instance.Set(f) return nil } diff --git a/fronted.go b/fronted.go index 53b1a9f..9407585 100644 --- a/fronted.go +++ b/fronted.go @@ -49,7 +49,9 @@ type fronted struct { clientHelloID tls.ClientHelloID } -func newFronted(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID, cacheFile string, clientHelloID tls.ClientHelloID) (*fronted, error) { +func newFronted(pool *x509.CertPool, providers map[string]*Provider, + defaultProviderID, cacheFile string, clientHelloID tls.ClientHelloID, + listener func(f *fronted)) (*fronted, error) { size := 0 for _, p := range providers { size += len(p.Masquerades) @@ -81,7 +83,7 @@ func newFronted(pool *x509.CertPool, providers map[string]*Provider, defaultProv if cacheFile != "" { f.initCaching(cacheFile) } - f.findWorkingMasquerades() + f.findWorkingMasquerades(listener) return f, nil } @@ -139,18 +141,18 @@ func Vet(m *Masquerade, pool *x509.CertPool, testURL string) bool { // parallel. Speed is of the essence here, as without working masquerades, users will // be unable to fetch proxy configurations, particularly in the case of a first time // user who does not have proxies cached on disk. -func (f *fronted) findWorkingMasquerades() { +func (f *fronted) findWorkingMasquerades(listener func(f *fronted)) { // vet masquerades in batches const batchSize int = 25 var successful atomic.Uint32 // We loop through all of them until we have 4 successful ones. for i := 0; i < len(f.masquerades) && successful.Load() < 4; i += batchSize { - f.vetBatch(i, batchSize, &successful) + f.vetBatch(i, batchSize, &successful, listener) } } -func (f *fronted) vetBatch(start, batchSize int, successful *atomic.Uint32) { +func (f *fronted) vetBatch(start, batchSize int, successful *atomic.Uint32, listener func(f *fronted)) { log.Debugf("Vetting masquerade batch %d-%d", start, start+batchSize) var wg sync.WaitGroup masqueradeSize := len(f.masquerades) @@ -160,6 +162,9 @@ func (f *fronted) vetBatch(start, batchSize int, successful *atomic.Uint32) { defer wg.Done() if f.vetMasquerade(m) { successful.Add(1) + if listener != nil { + go listener(f) + } } }(f.masquerades[j]) } diff --git a/fronted_test.go b/fronted_test.go index 23fc1a1..7d200f0 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -872,7 +872,7 @@ func TestFindWorkingMasquerades(t *testing.T) { } var successful atomic.Uint32 - d.vetBatch(0, 10, &successful) + d.vetBatch(0, 10, &successful, nil) tries := 0 for successful.Load() < uint32(tt.expectedSuccessful) && tries < 100 { From b9c4b0d3378d9aeca5abc352912f189e71412aec Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Fri, 25 Oct 2024 05:31:02 -0600 Subject: [PATCH 11/22] Cleanup CI --- .github/workflows/test.yaml | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index e9fd457..9e29ed5 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -1,22 +1,21 @@ name: Build and Test on: - - push - - pull_request + push: + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Go uses: actions/setup-go@v4 with: go-version-file: "go.mod" - - name: Granting private modules access - run: | - git config --global url."https://${{ secrets.GH_TOKEN }}:x-oauth-basic@github.com/".insteadOf "https://github.com/" - name: Run unit tests run: go test -failfast -coverprofile=profile.cov - name: Install goveralls From 86e11cccb12ebf799c2f7bab5cf8d5cd868a3f3e Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Fri, 25 Oct 2024 05:37:28 -0600 Subject: [PATCH 12/22] trying to get CI to pass --- context.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/context.go b/context.go index 86432a1..46559c0 100644 --- a/context.go +++ b/context.go @@ -73,12 +73,13 @@ func (fctx *FrontingContext) ConfigureWithHello(pool *x509.CertPool, providers m existing.closeCache() } - _, err := newFronted(pool, providers, defaultProviderID, cacheFile, clientHelloID, func(f *fronted) { - fctx.instance.Set(f) + f, err := newFronted(pool, providers, defaultProviderID, cacheFile, clientHelloID, func(f *fronted) { + }) if err != nil { return err } + fctx.instance.Set(f) return nil } From eda8ab2dfdaf4209ea445a47a9a11a3d9a3610ae Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Fri, 25 Oct 2024 12:23:09 -0600 Subject: [PATCH 13/22] set waittime on tests --- context.go | 5 ++--- fronted_test.go | 12 ++++++------ 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/context.go b/context.go index 46559c0..86432a1 100644 --- a/context.go +++ b/context.go @@ -73,13 +73,12 @@ func (fctx *FrontingContext) ConfigureWithHello(pool *x509.CertPool, providers m existing.closeCache() } - f, err := newFronted(pool, providers, defaultProviderID, cacheFile, clientHelloID, func(f *fronted) { - + _, err := newFronted(pool, providers, defaultProviderID, cacheFile, clientHelloID, func(f *fronted) { + fctx.instance.Set(f) }) if err != nil { return err } - fctx.instance.Set(f) return nil } diff --git a/fronted_test.go b/fronted_test.go index 7d200f0..ab88d3f 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -57,7 +57,7 @@ func TestDirectDomainFrontingWithSNIConfig(t *testing.T) { }) Configure(certs, p, testProviderID, cacheFile) - transport, ok := NewFronted(0) + transport, ok := NewFronted(30 * time.Second) require.True(t, ok) client := &http.Client{ Transport: transport, @@ -94,7 +94,7 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE } require.True(t, doCheck(client, http.MethodPost, http.StatusAccepted, pingURL)) - transport, ok = NewFronted(0) + transport, ok = NewFronted(30 * time.Second) require.True(t, ok) client = &http.Client{ Transport: transport, @@ -238,7 +238,7 @@ func TestHostAliasesBasic(t *testing.T) { certs.AddCert(cloudSack.Certificate()) Configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") - rt, ok := NewFronted(10 * time.Second) + rt, ok := NewFronted(30 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } @@ -349,7 +349,7 @@ func TestHostAliasesMulti(t *testing.T) { } Configure(certs, providers, "cloudsack", "") - rt, ok := NewFronted(10 * time.Second) + rt, ok := NewFronted(30 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } @@ -475,7 +475,7 @@ func TestPassthrough(t *testing.T) { certs.AddCert(cloudSack.Certificate()) Configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") - rt, ok := NewFronted(10 * time.Second) + rt, ok := NewFronted(30 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } @@ -615,7 +615,7 @@ func TestCustomValidators(t *testing.T) { for _, test := range tests { setup(test.validator) - direct, ok := NewFronted(1 * time.Second) + direct, ok := NewFronted(30 * time.Second) if !assert.True(t, ok) { return } From 33cb0a2fb761b45398c4e9bd8a9066421d40867c Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Fri, 25 Oct 2024 12:46:24 -0600 Subject: [PATCH 14/22] use test context --- fronted_test.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fronted_test.go b/fronted_test.go index ab88d3f..575c2f4 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -236,9 +236,11 @@ func TestHostAliasesBasic(t *testing.T) { certs := x509.NewCertPool() certs.AddCert(cloudSack.Certificate()) - Configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") - rt, ok := NewFronted(30 * time.Second) + testContext := NewFrontingContext("TestHostAliasesBasic") + testContext.Configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") + + rt, ok := testContext.NewFronted(30 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } From 1d58562d93c0896c120aa05df143a8c5c79660ba Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Fri, 25 Oct 2024 13:02:03 -0600 Subject: [PATCH 15/22] always use a separate context --- fronted_test.go | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/fronted_test.go b/fronted_test.go index 575c2f4..3139fe0 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -55,9 +55,10 @@ func TestDirectDomainFrontingWithSNIConfig(t *testing.T) { UseArbitrarySNIs: true, ArbitrarySNIs: []string{"mercadopago.com", "amazon.com.br", "facebook.com", "google.com", "twitter.com", "youtube.com", "instagram.com", "linkedin.com", "whatsapp.com", "netflix.com", "microsoft.com", "yahoo.com", "bing.com", "wikipedia.org", "github.com"}, }) - Configure(certs, p, testProviderID, cacheFile) + testContext := NewFrontingContext("TestDirectDomainFrontingWithSNIConfig") + testContext.Configure(certs, p, testProviderID, cacheFile) - transport, ok := NewFronted(30 * time.Second) + transport, ok := testContext.NewFronted(30 * time.Second) require.True(t, ok) client := &http.Client{ Transport: transport, @@ -83,9 +84,10 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE } certs := trustedCACerts(t) p := testProvidersWithHosts(hosts) - Configure(certs, p, testProviderID, cacheFile) + testContext := NewFrontingContext("doTestDomainFronting") + testContext.Configure(certs, p, testProviderID, cacheFile) - transport, ok := NewFronted(30 * time.Second) + transport, ok := testContext.NewFronted(30 * time.Second) require.True(t, ok) client := &http.Client{ @@ -94,14 +96,14 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE } require.True(t, doCheck(client, http.MethodPost, http.StatusAccepted, pingURL)) - transport, ok = NewFronted(30 * time.Second) + transport, ok = testContext.NewFronted(30 * time.Second) require.True(t, ok) client = &http.Client{ Transport: transport, } require.True(t, doCheck(client, http.MethodGet, http.StatusOK, getURL)) - instance, ok := DefaultContext.instance.Get(0) + instance, ok := testContext.instance.Get(0) require.True(t, ok) d := instance.(*fronted) @@ -350,8 +352,9 @@ func TestHostAliasesMulti(t *testing.T) { "sadcloud": p2, } - Configure(certs, providers, "cloudsack", "") - rt, ok := NewFronted(30 * time.Second) + testContext := NewFrontingContext("TestHostAliasesMulti") + testContext.Configure(certs, providers, "cloudsack", "") + rt, ok := testContext.NewFronted(30 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } @@ -475,9 +478,11 @@ func TestPassthrough(t *testing.T) { certs := x509.NewCertPool() certs.AddCert(cloudSack.Certificate()) - Configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") - rt, ok := NewFronted(30 * time.Second) + testContext := NewFrontingContext("TestPassthrough") + testContext.Configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") + + rt, ok := testContext.NewFronted(30 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } @@ -532,6 +537,7 @@ func TestCustomValidators(t *testing.T) { sadCloudCodes := []int{http.StatusPaymentRequired, http.StatusTeapot, http.StatusBadGateway} sadCloudValidator := NewStatusCodeValidator(sadCloudCodes) testURL := "https://abc.forbidden.com/quux" + testContext := NewFrontingContext("TestCustomValidators") setup := func(validator ResponseValidator) { masq := []*Masquerade{{Domain: "example.com", IpAddress: sadCloudAddr}} @@ -547,7 +553,7 @@ func TestCustomValidators(t *testing.T) { "sadcloud": p, } - Configure(certs, providers, "sadcloud", "") + testContext.Configure(certs, providers, "sadcloud", "") } // This error indicates that the validator has discarded all masquerades. @@ -617,7 +623,7 @@ func TestCustomValidators(t *testing.T) { for _, test := range tests { setup(test.validator) - direct, ok := NewFronted(30 * time.Second) + direct, ok := testContext.NewFronted(30 * time.Second) if !assert.True(t, ok) { return } From 2fd978c3c13b466e07b47dae3c2a9a4993ffb72a Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Sat, 26 Oct 2024 08:17:07 -0600 Subject: [PATCH 16/22] Minor tweaks --- context.go | 2 ++ fronted_test.go | 8 ++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/context.go b/context.go index 86432a1..556ea43 100644 --- a/context.go +++ b/context.go @@ -90,6 +90,8 @@ func (fctx *FrontingContext) NewFronted(timeout time.Duration) (http.RoundTrippe if !ok { log.Errorf("No DirectHttpClient available within %v for context %s", timeout, fctx.name) return nil, false + } else { + log.Debugf("DirectHttpClient available for context %s", fctx.name) } return instance.(http.RoundTripper), true } diff --git a/fronted_test.go b/fronted_test.go index 3139fe0..f80cdf6 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -25,7 +25,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestDirectDomainFronting(t *testing.T) { +func TestDirectDomainFrontingWithoutSNIConfig(t *testing.T) { dir, err := os.MkdirTemp("", "direct_test") require.NoError(t, err, "Unable to create temp dir") defer os.RemoveAll(dir) @@ -107,14 +107,14 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE require.True(t, ok) d := instance.(*fronted) - // Check the number of masquerades at the end, waiting up to 30 seconds until we get the right number + // Check the number of masquerades at the end, waiting until we get the right number masqueradesAtEnd := 0 - for i := 0; i < 100; i++ { + for i := 0; i < 1000; i++ { masqueradesAtEnd = len(d.masquerades) if masqueradesAtEnd == expectedMasqueradesAtEnd { break } - time.Sleep(300 * time.Millisecond) + time.Sleep(30 * time.Millisecond) } require.GreaterOrEqual(t, masqueradesAtEnd, expectedMasqueradesAtEnd) return masqueradesAtEnd From 3ea53efa2fa6c9a59fdd70c10936e32b8d530a85 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Sun, 27 Oct 2024 11:23:55 -0600 Subject: [PATCH 17/22] Minor tweaks --- fronted.go | 4 ++-- fronted_test.go | 2 +- test_support.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fronted.go b/fronted.go index 9407585..9ec9ca9 100644 --- a/fronted.go +++ b/fronted.go @@ -156,7 +156,7 @@ func (f *fronted) vetBatch(start, batchSize int, successful *atomic.Uint32, list log.Debugf("Vetting masquerade batch %d-%d", start, start+batchSize) var wg sync.WaitGroup masqueradeSize := len(f.masquerades) - for j := start; j < start+batchSize && j < masqueradeSize; j++ { + for i := start; i < start+batchSize && i < masqueradeSize; i++ { wg.Add(1) go func(m MasqueradeInterface) { defer wg.Done() @@ -166,7 +166,7 @@ func (f *fronted) vetBatch(start, batchSize int, successful *atomic.Uint32, list go listener(f) } } - }(f.masquerades[j]) + }(f.masquerades[i]) } wg.Wait() } diff --git a/fronted_test.go b/fronted_test.go index f80cdf6..910a5ee 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -56,7 +56,7 @@ func TestDirectDomainFrontingWithSNIConfig(t *testing.T) { ArbitrarySNIs: []string{"mercadopago.com", "amazon.com.br", "facebook.com", "google.com", "twitter.com", "youtube.com", "instagram.com", "linkedin.com", "whatsapp.com", "netflix.com", "microsoft.com", "yahoo.com", "bing.com", "wikipedia.org", "github.com"}, }) testContext := NewFrontingContext("TestDirectDomainFrontingWithSNIConfig") - testContext.Configure(certs, p, testProviderID, cacheFile) + testContext.Configure(certs, p, "akamai", cacheFile) transport, ok := testContext.NewFronted(30 * time.Second) require.True(t, ok) diff --git a/test_support.go b/test_support.go index 54da8a7..a3fc561 100644 --- a/test_support.go +++ b/test_support.go @@ -58,6 +58,6 @@ func testProvidersWithHosts(hosts map[string]string) map[string]*Provider { } func testAkamaiProvidersWithHosts(hosts map[string]string, sniConfig *SNIConfig) map[string]*Provider { return map[string]*Provider{ - testProviderID: NewProvider(hosts, pingTestURL, DefaultAkamaiMasquerades, nil, nil, sniConfig, nil), + "akamai": NewProvider(hosts, pingTestURL, DefaultAkamaiMasquerades, nil, nil, sniConfig, nil), } } From 0a8fe3ea29024a54d248abc19ac74c72a42c4262 Mon Sep 17 00:00:00 2001 From: WendelHime <6754291+WendelHime@users.noreply.github.com> Date: Mon, 28 Oct 2024 10:13:48 -0300 Subject: [PATCH 18/22] fix: update masquerades --- default_masquerades.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/default_masquerades.go b/default_masquerades.go index 94611fc..5b05c9c 100644 --- a/default_masquerades.go +++ b/default_masquerades.go @@ -34,11 +34,11 @@ var DefaultTrustedCAs = []*CA{ var DefaultAkamaiMasquerades = []*Masquerade{ { Domain: "a248.e.akamai.net", - IpAddress: "23.53.122.84", + IpAddress: "104.117.247.143", }, { Domain: "a248.e.akamai.net", - IpAddress: "2.19.198.29", + IpAddress: "23.47.194.73", }, } From 9892583ea5737378b10e1f336bdae1bed99765db Mon Sep 17 00:00:00 2001 From: WendelHime <6754291+WendelHime@users.noreply.github.com> Date: Mon, 28 Oct 2024 10:14:04 -0300 Subject: [PATCH 19/22] fix: update akamai ping URL --- test_support.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_support.go b/test_support.go index a3fc561..6f787a0 100644 --- a/test_support.go +++ b/test_support.go @@ -58,6 +58,6 @@ func testProvidersWithHosts(hosts map[string]string) map[string]*Provider { } func testAkamaiProvidersWithHosts(hosts map[string]string, sniConfig *SNIConfig) map[string]*Provider { return map[string]*Provider{ - "akamai": NewProvider(hosts, pingTestURL, DefaultAkamaiMasquerades, nil, nil, sniConfig, nil), + "akamai": NewProvider(hosts, "https://fronted-ping.dsa.akamai.getiantem.org/ping", DefaultAkamaiMasquerades, nil, nil, sniConfig, nil), } } From 935999c40715026b5c38555ac0050d6b3dd62518 Mon Sep 17 00:00:00 2001 From: WendelHime <6754291+WendelHime@users.noreply.github.com> Date: Mon, 28 Oct 2024 10:54:16 -0300 Subject: [PATCH 20/22] fix: adding test names to TestCustomValidators --- fronted_test.go | 57 +++++++++++++++++++++++++++---------------------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/fronted_test.go b/fronted_test.go index 910a5ee..d8d9986 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -562,32 +562,38 @@ func TestCustomValidators(t *testing.T) { masqueradesExhausted := fmt.Sprintf(`Get "%v": could not complete request even with retries`, testURL) tests := []struct { + name string responseCode int validator ResponseValidator expectedError string }{ // with the default validator, only 403s are rejected { + name: "it should return masquerades exhausted error when providing nil validator and returning 403", responseCode: http.StatusForbidden, validator: nil, expectedError: masqueradesExhausted, }, { + name: "it should return no errors when providing nil validator and receiving a 202", responseCode: http.StatusAccepted, validator: nil, expectedError: "", }, { + name: "it should return no errors when providing nil validator and receiving 402", responseCode: http.StatusPaymentRequired, validator: nil, expectedError: "", }, { + name: "it should return no errors when providing nil validator and receiving 418", responseCode: http.StatusTeapot, validator: nil, expectedError: "", }, { + name: "it should return no errors when providing nil validator and receiving 502", responseCode: http.StatusBadGateway, validator: nil, expectedError: "", @@ -595,26 +601,31 @@ func TestCustomValidators(t *testing.T) { // with the custom validator, 403 is allowed, listed codes are rejected { + name: "it should return no errors when providing validator that accepts 403", responseCode: http.StatusForbidden, validator: sadCloudValidator, expectedError: "", }, { + name: "it should return no errors when providing validator that accepts 202", responseCode: http.StatusAccepted, validator: sadCloudValidator, expectedError: "", }, { + name: "it should return masquerades exhausted when validator receives a 402", responseCode: http.StatusPaymentRequired, validator: sadCloudValidator, expectedError: masqueradesExhausted, }, { + name: "it should return masquerades exhausted when validator receives a 418", responseCode: http.StatusTeapot, validator: sadCloudValidator, expectedError: masqueradesExhausted, }, { + name: "it should return masquerades exhausted when validator receives a 502", responseCode: http.StatusBadGateway, validator: sadCloudValidator, expectedError: masqueradesExhausted, @@ -622,34 +633,30 @@ func TestCustomValidators(t *testing.T) { } for _, test := range tests { - setup(test.validator) - direct, ok := testContext.NewFronted(30 * time.Second) - if !assert.True(t, ok) { - return - } - client := &http.Client{ - Transport: direct, - } + t.Run(test.name, func(t *testing.T) { + setup(test.validator) + direct, ok := testContext.NewFronted(30 * time.Second) + require.True(t, ok) + client := &http.Client{ + Transport: direct, + } - req, err := http.NewRequest(http.MethodGet, testURL, nil) - if !assert.NoError(t, err) { - return - } - if test.responseCode != http.StatusAccepted { - val := strconv.Itoa(test.responseCode) - log.Debugf("requesting forced response code %s", val) - req.Header.Set(CDNForceFail, val) - } + req, err := http.NewRequest(http.MethodGet, testURL, nil) + require.NoError(t, err) + if test.responseCode != http.StatusAccepted { + val := strconv.Itoa(test.responseCode) + log.Debugf("requesting forced response code %s", val) + req.Header.Set(CDNForceFail, val) + } - res, err := client.Do(req) - if test.expectedError == "" { - if !assert.NoError(t, err) { - continue + res, err := client.Do(req) + if test.expectedError == "" { + require.NoError(t, err) + assert.Equal(t, test.responseCode, res.StatusCode, "Failed to force response status code") + } else { + assert.EqualError(t, err, test.expectedError) } - assert.Equal(t, test.responseCode, res.StatusCode, "Failed to force response status code") - } else { - assert.EqualError(t, err, test.expectedError) - } + }) } } From 22c7e1d7bb2e49f0a8e7955ed11a6b4b6b81bf0e Mon Sep 17 00:00:00 2001 From: WendelHime <6754291+WendelHime@users.noreply.github.com> Date: Mon, 28 Oct 2024 11:02:08 -0300 Subject: [PATCH 21/22] fix: renaming tests --- fronted_test.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/fronted_test.go b/fronted_test.go index d8d9986..bf81682 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -569,31 +569,31 @@ func TestCustomValidators(t *testing.T) { }{ // with the default validator, only 403s are rejected { - name: "it should return masquerades exhausted error when providing nil validator and returning 403", + name: "with default validator, it should reject 403", responseCode: http.StatusForbidden, validator: nil, expectedError: masqueradesExhausted, }, { - name: "it should return no errors when providing nil validator and receiving a 202", + name: "with default validator, it should accept 202", responseCode: http.StatusAccepted, validator: nil, expectedError: "", }, { - name: "it should return no errors when providing nil validator and receiving 402", + name: "with default validator, it should accept 402", responseCode: http.StatusPaymentRequired, validator: nil, expectedError: "", }, { - name: "it should return no errors when providing nil validator and receiving 418", + name: "with default validator, it should accept 418", responseCode: http.StatusTeapot, validator: nil, expectedError: "", }, { - name: "it should return no errors when providing nil validator and receiving 502", + name: "with default validator, it should accept 502", responseCode: http.StatusBadGateway, validator: nil, expectedError: "", @@ -601,31 +601,31 @@ func TestCustomValidators(t *testing.T) { // with the custom validator, 403 is allowed, listed codes are rejected { - name: "it should return no errors when providing validator that accepts 403", + name: "with custom validator, it should accept 403", responseCode: http.StatusForbidden, validator: sadCloudValidator, expectedError: "", }, { - name: "it should return no errors when providing validator that accepts 202", + name: "with custom validator, it should accept 402", responseCode: http.StatusAccepted, validator: sadCloudValidator, expectedError: "", }, { - name: "it should return masquerades exhausted when validator receives a 402", + name: "with custom validator, it should reject and return error for 402", responseCode: http.StatusPaymentRequired, validator: sadCloudValidator, expectedError: masqueradesExhausted, }, { - name: "it should return masquerades exhausted when validator receives a 418", + name: "with custom validator, it should reject and return error for 418", responseCode: http.StatusTeapot, validator: sadCloudValidator, expectedError: masqueradesExhausted, }, { - name: "it should return masquerades exhausted when validator receives a 502", + name: "with custom validator, it should reject and return error for 502", responseCode: http.StatusBadGateway, validator: sadCloudValidator, expectedError: masqueradesExhausted, From 59aaab8a7d2848208911fa4abc57925869f25853 Mon Sep 17 00:00:00 2001 From: WendelHime <6754291+WendelHime@users.noreply.github.com> Date: Mon, 28 Oct 2024 11:12:20 -0300 Subject: [PATCH 22/22] fix: creating new FrontingContext for each test --- fronted_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fronted_test.go b/fronted_test.go index bf81682..a1a1f3f 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -537,9 +537,8 @@ func TestCustomValidators(t *testing.T) { sadCloudCodes := []int{http.StatusPaymentRequired, http.StatusTeapot, http.StatusBadGateway} sadCloudValidator := NewStatusCodeValidator(sadCloudCodes) testURL := "https://abc.forbidden.com/quux" - testContext := NewFrontingContext("TestCustomValidators") - setup := func(validator ResponseValidator) { + setup := func(ctx *FrontingContext, validator ResponseValidator) { masq := []*Masquerade{{Domain: "example.com", IpAddress: sadCloudAddr}} alias := map[string]string{ "abc.forbidden.com": "abc.sadcloud.io", @@ -553,7 +552,7 @@ func TestCustomValidators(t *testing.T) { "sadcloud": p, } - testContext.Configure(certs, providers, "sadcloud", "") + ctx.Configure(certs, providers, "sadcloud", "") } // This error indicates that the validator has discarded all masquerades. @@ -634,7 +633,8 @@ func TestCustomValidators(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - setup(test.validator) + testContext := NewFrontingContext(test.name) + setup(testContext, test.validator) direct, ok := testContext.NewFronted(30 * time.Second) require.True(t, ok) client := &http.Client{