From aaeb3c13ca5c2182f1f48f01449c4fd1b6ffca1e Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Tue, 22 Oct 2024 15:31:15 -0600 Subject: [PATCH] Parallelize masquerade lookup --- context.go | 2 +- direct.go | 174 ++++++++++++++++++++++++------------------------- direct_test.go | 4 +- 3 files changed, 88 insertions(+), 92 deletions(-) diff --git a/context.go b/context.go index 37e7ac9..c9c7730 100644 --- a/context.go +++ b/context.go @@ -104,7 +104,7 @@ func (fctx *FrontingContext) ConfigureWithHello(pool *x509.CertPool, providers m if cacheFile != "" { d.initCaching(cacheFile) } - go d.vet(numberToVetInitially) + d.findWorkingMasquerades() fctx.instance.Set(d) return nil } diff --git a/direct.go b/direct.go index 3a70035..3c6da45 100644 --- a/direct.go +++ b/direct.go @@ -13,6 +13,7 @@ import ( "net/url" "strings" "sync" + "sync/atomic" "time" tls "github.com/refraction-networking/utls" @@ -25,7 +26,6 @@ import ( ) const ( - numberToVetInitially = 10 defaultMaxAllowedCachedAge = 24 * time.Hour defaultMaxCacheSize = 1000 defaultCacheSaveInterval = 5 * time.Second @@ -84,17 +84,9 @@ func (d *direct) providerFor(m *masquerade) *Provider { return d.providers[pid] } -// Vet vets the specified Masquerade, verifying certificate using the given CertPool +// Vet vets the specified Masquerade, verifying certificate using the given CertPool. +// This is used in genconfig. func Vet(m *Masquerade, pool *x509.CertPool, testURL string) bool { - return vet(m, pool, testURL) -} - -func vet(m *Masquerade, pool *x509.CertPool, testURL string) bool { - op := ops.Begin("vet_masquerade") - defer op.End() - op.Set("masquerade_domain", m.Domain) - op.Set("masquerade_ip", m.IpAddress) - d := &direct{ certPool: pool, maxAllowedCachedAge: defaultMaxAllowedCachedAge, @@ -102,54 +94,55 @@ func vet(m *Masquerade, pool *x509.CertPool, testURL string) bool { } conn, _, err := d.doDial(m) if err != nil { - op.FailIf(err) return false } defer conn.Close() return postCheck(conn, testURL) } -func (d *direct) vet(numberToVet int) { - log.Debugf("Vetting %d initial candidates in series", numberToVet) - for i := 0; i < numberToVet; i++ { - d.vetOne() - } -} - -func (d *direct) vetOne() { - // We're just testing the ability to connect here, destination site doesn't - // really matter - log.Debug("Vetting one") - unvettedMasquerades := make([]*masquerade, 0, len(d.masquerades)) - for _, m := range d.masquerades { - if m.lastSucceeded().IsZero() { - unvettedMasquerades = append(unvettedMasquerades, m) +func (d *direct) findWorkingMasquerades() { + // vet masquerades in batches + const batchSize int = 25 + var successful atomic.Uint32 + for i := 0; i < len(d.masquerades) && successful.Load() < 4; i += batchSize { + var wg sync.WaitGroup + for j := i; j < i+batchSize && j < len(d.masquerades); j++ { + wg.Add(1) + go func(m *masquerade) { + defer wg.Done() + if d.vetMasquerade(m) { + successful.Add(1) + } + }(d.masquerades[j]) } + wg.Wait() } +} - // Don't take more than 10 seconds to dial a masquerade for vetting - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - conn, m, masqueradeGood, err := d.dialWith(ctx, unvettedMasquerades) +func (d *direct) vetMasquerade(m *masquerade) bool { + conn, masqueradeGood, err := d.dialMasquerade(m) if err != nil { log.Errorf("unexpected error vetting masquerades: %v", err) - return + return false } - defer conn.Close() + defer func() { + if conn != nil { + conn.Close() + } + }() provider := d.providerFor(m) if provider == nil { log.Debugf("Skipping masquerade with disabled/unknown provider id '%s'", m.ProviderID) - return + return false } - if !masqueradeGood(postCheck(conn, provider.TestURL)) { log.Debugf("Unsuccessful vetting with POST request, discarding masquerade") - return + return false } - log.Debug("Finished vetting one") + log.Debugf("Finished vetting one masquerade %v", m) + return true } // postCheck does a post with invalid data to verify domain-fronting works @@ -187,7 +180,7 @@ func doCheck(client *http.Client, method string, expectedStatus int, u string) b op.Set("response_status", resp.StatusCode) op.Set("expected_status", expectedStatus) msg := fmt.Sprintf("Unexpected response status vetting masquerade, expected %d got %d: %v", expectedStatus, resp.StatusCode, resp.Status) - op.FailIf(fmt.Errorf(msg)) + op.FailIf(errors.New(msg)) log.Debug(msg) return false } @@ -247,7 +240,7 @@ func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, e log.Debugf("Retrying domain-fronted request, pass %d", i) } - conn, m, masqueradeGood, err := d.dial(req.Context()) + conn, m, masqueradeGood, err := d.dialAll(req.Context()) if err != nil { // unable to find good masquerade, fail op.FailIf(err) @@ -305,36 +298,13 @@ func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, e return nil, nil, op.FailIf(errors.New("could not complete request even with retries")) } -func cloneRequestWith(req *http.Request, frontedHost string, body io.ReadCloser) (*http.Request, error) { - url := *req.URL - url.Host = frontedHost - r, err := http.NewRequest(req.Method, url.String(), body) - if err != nil { - return nil, err - } - - for k, vs := range req.Header { - if !strings.EqualFold(k, "Host") { - v := make([]string, len(vs)) - copy(v, vs) - r.Header[k] = v - } - } - return r, nil -} - -// Dial dials out using a masquerade. If the available masquerade fails, it -// retries with others until it either succeeds or exhausts the available -// masquerades. If successful, it returns a connection to the masquerade, -// the selected masquerade, and a function that the caller can use to -// tell us whether the masquerade is good or not (i.e. if masquerade was good, -// keep it). -func (d *direct) dial(ctx context.Context) (net.Conn, *masquerade, func(bool) bool, error) { - conn, m, masqueradeGood, err := d.dialWith(ctx, d.masquerades) +// Dial dials out using all available masquerades until one succeeds. +func (d *direct) dialAll(ctx context.Context) (net.Conn, *masquerade, func(bool) bool, error) { + conn, m, masqueradeGood, err := d.dialAllWith(ctx, d.masquerades) return conn, m, masqueradeGood, err } -func (d *direct) dialWith(ctx context.Context, masquerades sortedMasquerades) (net.Conn, *masquerade, func(bool) bool, error) { +func (d *direct) dialAllWith(ctx context.Context, masquerades sortedMasquerades) (net.Conn, *masquerade, func(bool) bool, error) { // never take more than a minute trying to find a dialer ctx, cancel := context.WithTimeout(ctx, 1*time.Minute) defer cancel() @@ -343,7 +313,6 @@ func (d *direct) dialWith(ctx context.Context, masquerades sortedMasquerades) (n totalMasquerades := len(masqueradesToTry) dialLoop: for _, m := range masqueradesToTry { - // check to see if we've timed out select { case <-ctx.Done(): log.Debugf("Timed out dialing to %v with %v total masquerades", m, totalMasquerades) @@ -351,34 +320,43 @@ dialLoop: default: // okay } - - log.Tracef("Dialing to %v", m) - - // We do the full TLS connection here because in practice the domains at a given IP - // address can change frequently on CDNs, so the certificate may not match what - // we expect. - conn, retriable, err := d.doDial(&m.Masquerade) - masqueradeGood := func(good bool) bool { - if good { - m.markSucceeded() - } else { - m.markFailed() - } - d.markCacheDirty() - return good - } + conn, masqueradeGood, err := d.dialMasquerade(m) if err == nil { - log.Debug("Returning connection") - return conn, m, masqueradeGood, err - } else if !retriable { - log.Debugf("Dropping masquerade: non retryable error: %v", err) - masqueradeGood(false) + return conn, m, masqueradeGood, nil } } return nil, nil, nil, log.Errorf("could not dial any masquerade? tried %v", totalMasquerades) } +func (d *direct) dialMasquerade(m *masquerade) (net.Conn, func(bool) bool, error) { + // check to see if we've timed out + + log.Tracef("Dialing to %v", m) + + // We do the full TLS connection here because in practice the domains at a given IP + // address can change frequently on CDNs, so the certificate may not match what + // we expect. + conn, retriable, err := d.doDial(&m.Masquerade) + masqueradeGood := func(good bool) bool { + if good { + m.markSucceeded() + } else { + m.markFailed() + } + d.markCacheDirty() + return good + } + if err == nil { + log.Debug("Returning connection") + return conn, masqueradeGood, err + } else if !retriable { + log.Debugf("Dropping masquerade: non retryable error: %v", err) + masqueradeGood(false) + } + return conn, masqueradeGood, err +} + func (d *direct) doDial(m *Masquerade) (conn net.Conn, retriable bool, err error) { op := ops.Begin("dial_masquerade") defer op.End() @@ -551,3 +529,21 @@ func (ddf *directTransport) RoundTrip(req *http.Request) (resp *http.Response, e norm.URL.Scheme = "http" return ddf.Transport.RoundTrip(norm) } + +func cloneRequestWith(req *http.Request, frontedHost string, body io.ReadCloser) (*http.Request, error) { + url := *req.URL + url.Host = frontedHost + r, err := http.NewRequest(req.Method, url.String(), body) + if err != nil { + return nil, err + } + + for k, vs := range req.Header { + if !strings.EqualFold(k, "Host") { + v := make([]string, len(vs)) + copy(v, vs) + r.Header[k] = v + } + } + return r, nil +} diff --git a/direct_test.go b/direct_test.go index 89a3a05..2e9096f 100644 --- a/direct_test.go +++ b/direct_test.go @@ -27,11 +27,11 @@ func TestDirectDomainFronting(t *testing.T) { require.NoError(t, err, "Unable to create temp dir") defer os.RemoveAll(dir) cacheFile := filepath.Join(dir, "cachefile.2") - doTestDomainFronting(t, cacheFile, numberToVetInitially) + doTestDomainFronting(t, cacheFile, 10) time.Sleep(defaultCacheSaveInterval * 2) // Then try again, this time reusing the existing cacheFile but a corrupted version corruptMasquerades(cacheFile) - doTestDomainFronting(t, cacheFile, numberToVetInitially) + doTestDomainFronting(t, cacheFile, 10) } func TestDirectDomainFrontingWithSNIConfig(t *testing.T) {