diff --git a/cache.go b/cache.go index 43391bb..9df34d5 100644 --- a/cache.go +++ b/cache.go @@ -3,15 +3,16 @@ package fronted import ( "encoding/json" "os" + "path/filepath" "time" ) -func (d *direct) initCaching(cacheFile string) { +func (d *fronted) initCaching(cacheFile string) { d.prepopulateMasquerades(cacheFile) go d.maintainCache(cacheFile) } -func (d *direct) prepopulateMasquerades(cacheFile string) { +func (d *fronted) prepopulateMasquerades(cacheFile string) { bytes, err := os.ReadFile(cacheFile) if err != nil { // This is not a big deal since we'll just fill the cache later @@ -38,16 +39,16 @@ func (d *direct) prepopulateMasquerades(cacheFile string) { // update last succeeded status of masquerades based on cached values for _, m := range d.masquerades { for _, cm := range cachedMasquerades { - sameMasquerade := cm.ProviderID == m.ProviderID && cm.Domain == m.Domain && cm.IpAddress == m.IpAddress - cachedValueFresh := now.Sub(m.LastSucceeded) < d.maxAllowedCachedAge + sameMasquerade := cm.ProviderID == m.getProviderID() && cm.Domain == m.getDomain() && cm.IpAddress == m.getIpAddress() + cachedValueFresh := now.Sub(m.lastSucceeded()) < d.maxAllowedCachedAge if sameMasquerade && cachedValueFresh { - m.LastSucceeded = cm.LastSucceeded + m.setLastSucceeded(cm.LastSucceeded) } } } } -func (d *direct) markCacheDirty() { +func (d *fronted) markCacheDirty() { select { case d.cacheDirty <- nil: // okay @@ -56,7 +57,7 @@ func (d *direct) markCacheDirty() { } } -func (d *direct) maintainCache(cacheFile string) { +func (d *fronted) maintainCache(cacheFile string) { for { select { case <-d.cacheClosed: @@ -72,7 +73,7 @@ func (d *direct) maintainCache(cacheFile string) { } } -func (d *direct) updateCache(cacheFile string) { +func (d *fronted) updateCache(cacheFile string) { log.Debugf("Updating cache at %v", cacheFile) cache := d.masquerades.sortedCopy() sizeToSave := len(cache) @@ -87,10 +88,20 @@ func (d *direct) updateCache(cacheFile string) { err = os.WriteFile(cacheFile, b, 0644) if err != nil { log.Errorf("Unable to save cache to disk: %v", err) + // Log the directory of the cache file and if it exists for debugging purposes + parent := filepath.Dir(cacheFile) + // check if the parent directory exists + if _, err := os.Stat(parent); err == nil { + // parent directory exists + log.Debugf("Parent directory of cache file exists: %v", parent) + } else { + // parent directory does not exist + log.Debugf("Parent directory of cache file does not exist: %v", parent) + } } } -func (d *direct) closeCache() { +func (d *fronted) closeCache() { d.closeCacheOnce.Do(func() { close(d.cacheClosed) }) diff --git a/cache_test.go b/cache_test.go index 06e1c42..9b874d1 100644 --- a/cache_test.go +++ b/cache_test.go @@ -26,8 +26,8 @@ func TestCaching(t *testing.T) { cloudsackID: NewProvider(nil, "", nil, nil, nil, nil, nil), } - makeDirect := func() *direct { - d := &direct{ + makeDirect := func() *fronted { + d := &fronted{ masquerades: make(sortedMasquerades, 0, 1000), maxAllowedCachedAge: 250 * time.Millisecond, maxCacheSize: 4, diff --git a/context.go b/context.go index 37e7ac9..556ea43 100644 --- a/context.go +++ b/context.go @@ -27,11 +27,11 @@ func Configure(pool *x509.CertPool, providers map[string]*Provider, defaultProvi } } -// NewDirect creates a new http.RoundTripper that does direct domain fronting +// NewFronted creates a new http.RoundTripper that does direct domain fronting // using the default context. If the default context isn't configured within // the given timeout, this method returns nil, false. -func NewDirect(timeout time.Duration) (http.RoundTripper, bool) { - return DefaultContext.NewDirect(timeout) +func NewFronted(timeout time.Duration) (http.RoundTripper, bool) { + return DefaultContext.NewFronted(timeout) } // Close closes any existing cache file in the default context @@ -68,55 +68,30 @@ func (fctx *FrontingContext) ConfigureWithHello(pool *x509.CertPool, providers m _existing, ok := fctx.instance.Get(0) if ok && _existing != nil { - existing := _existing.(*direct) + existing := _existing.(*fronted) log.Debugf("Closing cache from existing instance for %s context", fctx.name) existing.closeCache() } - size := 0 - for _, p := range providers { - size += len(p.Masquerades) + _, err := newFronted(pool, providers, defaultProviderID, cacheFile, clientHelloID, func(f *fronted) { + fctx.instance.Set(f) + }) + if err != nil { + return err } - - if size == 0 { - return fmt.Errorf("no masquerades for %s context", fctx.name) - } - - d := &direct{ - certPool: pool, - masquerades: make(sortedMasquerades, 0, size), - maxAllowedCachedAge: defaultMaxAllowedCachedAge, - maxCacheSize: defaultMaxCacheSize, - cacheSaveInterval: defaultCacheSaveInterval, - cacheDirty: make(chan interface{}, 1), - cacheClosed: make(chan interface{}), - defaultProviderID: defaultProviderID, - providers: make(map[string]*Provider), - clientHelloID: clientHelloID, - } - - // copy providers - for k, p := range providers { - d.providers[k] = NewProvider(p.HostAliases, p.TestURL, p.Masquerades, p.Validator, p.PassthroughPatterns, p.SNIConfig, p.VerifyHostname) - } - - d.loadCandidates(d.providers) - if cacheFile != "" { - d.initCaching(cacheFile) - } - go d.vet(numberToVetInitially) - fctx.instance.Set(d) return nil } -// NewDirect creates a new http.RoundTripper that does direct domain fronting. +// NewFronted creates a new http.RoundTripper that does direct domain fronting. // If the context isn't configured within the given timeout, this method // returns nil, false. -func (fctx *FrontingContext) NewDirect(timeout time.Duration) (http.RoundTripper, bool) { +func (fctx *FrontingContext) NewFronted(timeout time.Duration) (http.RoundTripper, bool) { instance, ok := fctx.instance.Get(timeout) if !ok { log.Errorf("No DirectHttpClient available within %v for context %s", timeout, fctx.name) return nil, false + } else { + log.Debugf("DirectHttpClient available for context %s", fctx.name) } return instance.(http.RoundTripper), true } @@ -125,7 +100,7 @@ func (fctx *FrontingContext) NewDirect(timeout time.Duration) (http.RoundTripper func (fctx *FrontingContext) Close() { _existing, ok := fctx.instance.Get(0) if ok && _existing != nil { - existing := _existing.(*direct) + existing := _existing.(*fronted) log.Debugf("Closing cache from existing instance in %s context", fctx.name) existing.closeCache() } diff --git a/default_masquerades.go b/default_masquerades.go index 94611fc..5b05c9c 100644 --- a/default_masquerades.go +++ b/default_masquerades.go @@ -34,11 +34,11 @@ var DefaultTrustedCAs = []*CA{ var DefaultAkamaiMasquerades = []*Masquerade{ { Domain: "a248.e.akamai.net", - IpAddress: "23.53.122.84", + IpAddress: "104.117.247.143", }, { Domain: "a248.e.akamai.net", - IpAddress: "2.19.198.29", + IpAddress: "23.47.194.73", }, } diff --git a/direct.go b/fronted.go similarity index 55% rename from direct.go rename to fronted.go index 3a70035..9ec9ca9 100644 --- a/direct.go +++ b/fronted.go @@ -13,19 +13,16 @@ import ( "net/url" "strings" "sync" + "sync/atomic" "time" tls "github.com/refraction-networking/utls" "github.com/getlantern/golog" - "github.com/getlantern/idletiming" - "github.com/getlantern/netx" "github.com/getlantern/ops" - "github.com/getlantern/tlsdialer/v3" ) const ( - numberToVetInitially = 10 defaultMaxAllowedCachedAge = 24 * time.Hour defaultMaxCacheSize = 1000 defaultCacheSaveInterval = 5 * time.Second @@ -36,8 +33,9 @@ var ( log = golog.LoggerFor("fronted") ) -// direct is an implementation of http.RoundTripper -type direct struct { +// fronted identifies working IP address/domain pairings for domain fronting and is +// an implementation of http.RoundTripper for the convenience of callers. +type fronted struct { certPool *x509.CertPool masquerades sortedMasquerades maxAllowedCachedAge time.Duration @@ -51,7 +49,46 @@ type direct struct { clientHelloID tls.ClientHelloID } -func (d *direct) loadCandidates(initial map[string]*Provider) { +func newFronted(pool *x509.CertPool, providers map[string]*Provider, + defaultProviderID, cacheFile string, clientHelloID tls.ClientHelloID, + listener func(f *fronted)) (*fronted, error) { + size := 0 + for _, p := range providers { + size += len(p.Masquerades) + } + + if size == 0 { + return nil, fmt.Errorf("no masquerades found in providers") + } + + f := &fronted{ + certPool: pool, + masquerades: make(sortedMasquerades, 0, size), + maxAllowedCachedAge: defaultMaxAllowedCachedAge, + maxCacheSize: defaultMaxCacheSize, + cacheSaveInterval: defaultCacheSaveInterval, + cacheDirty: make(chan interface{}, 1), + cacheClosed: make(chan interface{}), + defaultProviderID: defaultProviderID, + providers: make(map[string]*Provider), + clientHelloID: clientHelloID, + } + + // copy providers + for k, p := range providers { + f.providers[k] = NewProvider(p.HostAliases, p.TestURL, p.Masquerades, p.Validator, p.PassthroughPatterns, p.SNIConfig, p.VerifyHostname) + } + + f.loadCandidates(f.providers) + if cacheFile != "" { + f.initCaching(cacheFile) + } + f.findWorkingMasquerades(listener) + + return f, nil +} + +func (f *fronted) loadCandidates(initial map[string]*Provider) { log.Debugf("Loading candidates for %d providers", len(initial)) defer log.Debug("Finished loading candidates") @@ -70,141 +107,107 @@ func (d *direct) loadCandidates(initial map[string]*Provider) { } for _, c := range sh { - log.Trace("Adding candidate") - d.masquerades = append(d.masquerades, &masquerade{Masquerade: *c, ProviderID: key}) + f.masquerades = append(f.masquerades, &masquerade{Masquerade: *c, ProviderID: key}) } } } -func (d *direct) providerFor(m *masquerade) *Provider { - pid := m.ProviderID +func (f *fronted) providerFor(m MasqueradeInterface) *Provider { + pid := m.getProviderID() if pid == "" { - pid = d.defaultProviderID + pid = f.defaultProviderID } - return d.providers[pid] + return f.providers[pid] } -// Vet vets the specified Masquerade, verifying certificate using the given CertPool +// Vet vets the specified Masquerade, verifying certificate using the given CertPool. +// This is used in genconfig. func Vet(m *Masquerade, pool *x509.CertPool, testURL string) bool { - return vet(m, pool, testURL) -} - -func vet(m *Masquerade, pool *x509.CertPool, testURL string) bool { - op := ops.Begin("vet_masquerade") - defer op.End() - op.Set("masquerade_domain", m.Domain) - op.Set("masquerade_ip", m.IpAddress) - - d := &direct{ + d := &fronted{ certPool: pool, maxAllowedCachedAge: defaultMaxAllowedCachedAge, maxCacheSize: defaultMaxCacheSize, } - conn, _, err := d.doDial(m) + masq := &masquerade{Masquerade: *m} + conn, _, err := d.doDial(masq) if err != nil { - op.FailIf(err) return false } defer conn.Close() - return postCheck(conn, testURL) + return masq.postCheck(conn, testURL) } -func (d *direct) vet(numberToVet int) { - log.Debugf("Vetting %d initial candidates in series", numberToVet) - for i := 0; i < numberToVet; i++ { - d.vetOne() +// findWorkingMasquerades finds working masquerades by vetting them in batches and in +// parallel. Speed is of the essence here, as without working masquerades, users will +// be unable to fetch proxy configurations, particularly in the case of a first time +// user who does not have proxies cached on disk. +func (f *fronted) findWorkingMasquerades(listener func(f *fronted)) { + // vet masquerades in batches + const batchSize int = 25 + var successful atomic.Uint32 + + // We loop through all of them until we have 4 successful ones. + for i := 0; i < len(f.masquerades) && successful.Load() < 4; i += batchSize { + f.vetBatch(i, batchSize, &successful, listener) } } -func (d *direct) vetOne() { - // We're just testing the ability to connect here, destination site doesn't - // really matter - log.Debug("Vetting one") - unvettedMasquerades := make([]*masquerade, 0, len(d.masquerades)) - for _, m := range d.masquerades { - if m.lastSucceeded().IsZero() { - unvettedMasquerades = append(unvettedMasquerades, m) - } +func (f *fronted) vetBatch(start, batchSize int, successful *atomic.Uint32, listener func(f *fronted)) { + log.Debugf("Vetting masquerade batch %d-%d", start, start+batchSize) + var wg sync.WaitGroup + masqueradeSize := len(f.masquerades) + for i := start; i < start+batchSize && i < masqueradeSize; i++ { + wg.Add(1) + go func(m MasqueradeInterface) { + defer wg.Done() + if f.vetMasquerade(m) { + successful.Add(1) + if listener != nil { + go listener(f) + } + } + }(f.masquerades[i]) } + wg.Wait() +} - // Don't take more than 10 seconds to dial a masquerade for vetting - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - conn, m, masqueradeGood, err := d.dialWith(ctx, unvettedMasquerades) +func (f *fronted) vetMasquerade(m MasqueradeInterface) bool { + conn, masqueradeGood, err := f.dialMasquerade(m) if err != nil { log.Errorf("unexpected error vetting masquerades: %v", err) - return + return false } - defer conn.Close() + defer func() { + if conn != nil { + conn.Close() + } + }() - provider := d.providerFor(m) + provider := f.providerFor(m) if provider == nil { - log.Debugf("Skipping masquerade with disabled/unknown provider id '%s'", m.ProviderID) - return - } - - if !masqueradeGood(postCheck(conn, provider.TestURL)) { - log.Debugf("Unsuccessful vetting with POST request, discarding masquerade") - return - } - - log.Debug("Finished vetting one") -} - -// postCheck does a post with invalid data to verify domain-fronting works -func postCheck(conn net.Conn, testURL string) bool { - client := &http.Client{ - Transport: frontedHTTPTransport(conn, true), - } - return doCheck(client, http.MethodPost, http.StatusAccepted, testURL) -} - -func doCheck(client *http.Client, method string, expectedStatus int, u string) bool { - op := ops.Begin("check_masquerade") - defer op.End() - - isPost := method == http.MethodPost - var requestBody io.Reader - if isPost { - requestBody = strings.NewReader("a") - } - req, _ := http.NewRequest(method, u, requestBody) - if isPost { - req.Header.Set("Content-Type", "application/json") - } - resp, err := client.Do(req) - if err != nil { - op.FailIf(err) - log.Debugf("Unsuccessful vetting with %v request, discarding masquerade: %v", method, err) + log.Debugf("Skipping masquerade with disabled/unknown provider id '%s' not in %v", + m.getProviderID(), f.providers) return false } - if resp.Body != nil { - io.Copy(io.Discard, resp.Body) - resp.Body.Close() - } - if resp.StatusCode != expectedStatus { - op.Set("response_status", resp.StatusCode) - op.Set("expected_status", expectedStatus) - msg := fmt.Sprintf("Unexpected response status vetting masquerade, expected %d got %d: %v", expectedStatus, resp.StatusCode, resp.Status) - op.FailIf(fmt.Errorf(msg)) - log.Debug(msg) + if !masqueradeGood(m.postCheck(conn, provider.TestURL)) { + log.Debugf("Unsuccessful vetting with POST request, discarding masquerade") return false } + + log.Debugf("Successfully vetted one masquerade %v", m) return true } -// Do continually retries a given request until it succeeds because some -// fronting providers will return a 403 for some domains. -func (d *direct) RoundTrip(req *http.Request) (*http.Response, error) { - res, _, err := d.RoundTripHijack(req) +// RoundTrip loops through all available masquerades, sorted by the one that most recently +// connected, retrying several times on failures. +func (f *fronted) RoundTrip(req *http.Request) (*http.Response, error) { + res, _, err := f.RoundTripHijack(req) return res, err } -// Do continually retries a given request until it succeeds because some -// fronting providers will return a 403 for some domains. Also return the -// underlying net.Conn established. -func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, error) { +// RoundTripHijack loops through all available masquerades, sorted by the one that most +// recently connected, retrying several times on failures. +func (f *fronted) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, error) { op := ops.Begin("fronted_roundtrip") defer op.End() @@ -247,15 +250,15 @@ func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, e log.Debugf("Retrying domain-fronted request, pass %d", i) } - conn, m, masqueradeGood, err := d.dial(req.Context()) + conn, m, masqueradeGood, err := f.dialAll(req.Context()) if err != nil { // unable to find good masquerade, fail op.FailIf(err) return nil, nil, err } - provider := d.providerFor(m) + provider := f.providerFor(m) if provider == nil { - log.Debugf("Skipping masquerade with disabled/unknown provider '%s'", m.ProviderID) + log.Debugf("Skipping masquerade with disabled/unknown provider '%s'", m.getProviderID()) masqueradeGood(false) continue } @@ -265,11 +268,12 @@ func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, e // so it is returned as good. conn.Close() masqueradeGood(true) - err := fmt.Errorf("no domain fronting mapping for '%s'. Please add it to provider_map.yaml or equivalent for %s", m.ProviderID, originHost) + err := fmt.Errorf("no domain fronting mapping for '%s'. Please add it to provider_map.yaml or equivalent for %s", + m.getProviderID(), originHost) op.FailIf(err) return nil, nil, err } - log.Debugf("Translated origin %s -> %s for provider %s...", originHost, frontedHost, m.ProviderID) + log.Debugf("Translated origin %s -> %s for provider %s...", originHost, frontedHost, m.getProviderID()) reqi, err := cloneRequestWith(req, frontedHost, getBody()) if err != nil { @@ -305,36 +309,13 @@ func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, e return nil, nil, op.FailIf(errors.New("could not complete request even with retries")) } -func cloneRequestWith(req *http.Request, frontedHost string, body io.ReadCloser) (*http.Request, error) { - url := *req.URL - url.Host = frontedHost - r, err := http.NewRequest(req.Method, url.String(), body) - if err != nil { - return nil, err - } - - for k, vs := range req.Header { - if !strings.EqualFold(k, "Host") { - v := make([]string, len(vs)) - copy(v, vs) - r.Header[k] = v - } - } - return r, nil -} - -// Dial dials out using a masquerade. If the available masquerade fails, it -// retries with others until it either succeeds or exhausts the available -// masquerades. If successful, it returns a connection to the masquerade, -// the selected masquerade, and a function that the caller can use to -// tell us whether the masquerade is good or not (i.e. if masquerade was good, -// keep it). -func (d *direct) dial(ctx context.Context) (net.Conn, *masquerade, func(bool) bool, error) { - conn, m, masqueradeGood, err := d.dialWith(ctx, d.masquerades) +// Dial dials out using all available masquerades until one succeeds. +func (f *fronted) dialAll(ctx context.Context) (net.Conn, MasqueradeInterface, func(bool) bool, error) { + conn, m, masqueradeGood, err := f.dialAllWith(ctx, f.masquerades) return conn, m, masqueradeGood, err } -func (d *direct) dialWith(ctx context.Context, masquerades sortedMasquerades) (net.Conn, *masquerade, func(bool) bool, error) { +func (f *fronted) dialAllWith(ctx context.Context, masquerades sortedMasquerades) (net.Conn, MasqueradeInterface, func(bool) bool, error) { // never take more than a minute trying to find a dialer ctx, cancel := context.WithTimeout(ctx, 1*time.Minute) defer cancel() @@ -343,7 +324,6 @@ func (d *direct) dialWith(ctx context.Context, masquerades sortedMasquerades) (n totalMasquerades := len(masqueradesToTry) dialLoop: for _, m := range masqueradesToTry { - // check to see if we've timed out select { case <-ctx.Done(): log.Debugf("Timed out dialing to %v with %v total masquerades", m, totalMasquerades) @@ -351,44 +331,51 @@ dialLoop: default: // okay } - - log.Tracef("Dialing to %v", m) - - // We do the full TLS connection here because in practice the domains at a given IP - // address can change frequently on CDNs, so the certificate may not match what - // we expect. - conn, retriable, err := d.doDial(&m.Masquerade) - masqueradeGood := func(good bool) bool { - if good { - m.markSucceeded() - } else { - m.markFailed() - } - d.markCacheDirty() - return good - } + conn, masqueradeGood, err := f.dialMasquerade(m) if err == nil { - log.Debug("Returning connection") - return conn, m, masqueradeGood, err - } else if !retriable { - log.Debugf("Dropping masquerade: non retryable error: %v", err) - masqueradeGood(false) + return conn, m, masqueradeGood, nil } } return nil, nil, nil, log.Errorf("could not dial any masquerade? tried %v", totalMasquerades) } -func (d *direct) doDial(m *Masquerade) (conn net.Conn, retriable bool, err error) { +func (f *fronted) dialMasquerade(m MasqueradeInterface) (net.Conn, func(bool) bool, error) { + log.Tracef("Dialing to %v", m) + + // We do the full TLS connection here because in practice the domains at a given IP + // address can change frequently on CDNs, so the certificate may not match what + // we expect. + conn, retriable, err := f.doDial(m) + masqueradeGood := func(good bool) bool { + if good { + m.markSucceeded() + } else { + m.markFailed() + } + f.markCacheDirty() + return good + } + if err == nil { + log.Debugf("Returning connection for masquerade: %v", m) + return conn, masqueradeGood, err + } else if !retriable { + log.Debugf("Dropping masquerade: non retryable error: %v", err) + masqueradeGood(false) + } + return conn, masqueradeGood, err +} + +func (f *fronted) doDial(m MasqueradeInterface) (conn net.Conn, retriable bool, err error) { op := ops.Begin("dial_masquerade") defer op.End() - op.Set("masquerade_domain", m.Domain) - op.Set("masquerade_ip", m.IpAddress) + op.Set("masquerade_domain", m.getDomain()) + op.Set("masquerade_ip", m.getIpAddress()) - conn, err = d.dialServerWith(m) + conn, err = m.dial(f.certPool, f.clientHelloID) if err != nil { op.FailIf(err) - log.Debugf("Could not dial to %v, %v", m.IpAddress, err) + log.Debugf("Could not dial to %v, %v", m.getIpAddress(), err) // Don't re-add this candidate if it's any certificate error, as that // will just keep failing and will waste connections. We can't access the underlying // error at this point so just look for "certificate" and "handshake". @@ -401,65 +388,10 @@ func (d *direct) doDial(m *Masquerade) (conn net.Conn, retriable bool, err error } } else { log.Debugf("Got successful connection to: %v", m) - idleTimeout := 70 * time.Second - - log.Debugf("Wrapping connection in idletiming connection: %v", m) - conn = idletiming.Conn(conn, idleTimeout, func() { - log.Debugf("Connection to %v idle for %v, closed", conn.RemoteAddr(), idleTimeout) - }) } return } -func (d *direct) dialServerWith(m *Masquerade) (net.Conn, error) { - op := ops.Begin("dial_server_with") - defer op.End() - - op.Set("masquerade_domain", m.Domain) - op.Set("masquerade_ip", m.IpAddress) - - tlsConfig := d.frontingTLSConfig(m) - dialTimeout := 10 * time.Second - addr := m.IpAddress - var sendServerNameExtension bool - - if m.SNI != "" { - sendServerNameExtension = true - - op.Set("arbitrary_sni", m.SNI) - tlsConfig.ServerName = m.SNI - tlsConfig.InsecureSkipVerify = true - tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { - var verifyHostname string - if m.VerifyHostname != nil { - verifyHostname = *m.VerifyHostname - op.Set("verify_hostname", verifyHostname) - } - return verifyPeerCertificate(rawCerts, d.certPool, verifyHostname) - } - - } - - _, _, err := net.SplitHostPort(addr) - if err != nil { - addr = net.JoinHostPort(addr, "443") - } - - dialer := &tlsdialer.Dialer{ - DoDial: netx.DialTimeout, - Timeout: dialTimeout, - SendServerName: sendServerNameExtension, - Config: tlsConfig, - ClientHelloID: d.clientHelloID, - } - conn, err := dialer.Dial("tcp", addr) - if err != nil && m != nil { - err = fmt.Errorf("unable to dial masquerade %s: %s", m.Domain, err) - op.FailIf(err) - } - return conn, err -} - func verifyPeerCertificate(rawCerts [][]byte, roots *x509.CertPool, domain string) error { if len(rawCerts) == 0 { return fmt.Errorf("no certificates presented") @@ -508,15 +440,6 @@ func generateVerifyOptions(roots *x509.CertPool, domain string) x509.VerifyOptio } } -// frontingTLSConfig builds a tls.Config for dialing the fronting domain. This is to establish the -// initial TCP connection to the CDN. -func (d *direct) frontingTLSConfig(m *Masquerade) *tls.Config { - return &tls.Config{ - ServerName: m.Domain, - RootCAs: d.certPool, - } -} - // frontedHTTPTransport is the transport to use to route to the actual fronted destination domain. // This uses the pre-established connection to the CDN on the fronting domain. func frontedHTTPTransport(conn net.Conn, disableKeepAlives bool) http.RoundTripper { @@ -527,6 +450,7 @@ func frontedHTTPTransport(conn net.Conn, disableKeepAlives bool) http.RoundTripp }, TLSHandshakeTimeout: 40 * time.Second, DisableKeepAlives: disableKeepAlives, + IdleConnTimeout: 70 * time.Second, }, } } @@ -551,3 +475,21 @@ func (ddf *directTransport) RoundTrip(req *http.Request) (resp *http.Response, e norm.URL.Scheme = "http" return ddf.Transport.RoundTrip(norm) } + +func cloneRequestWith(req *http.Request, frontedHost string, body io.ReadCloser) (*http.Request, error) { + url := *req.URL + url.Host = frontedHost + r, err := http.NewRequest(req.Method, url.String(), body) + if err != nil { + return nil, err + } + + for k, vs := range req.Header { + if !strings.EqualFold(k, "Host") { + v := make([]string, len(vs)) + copy(v, vs) + r.Header[k] = v + } + } + return r, nil +} diff --git a/direct_test.go b/fronted_test.go similarity index 81% rename from direct_test.go rename to fronted_test.go index 89a3a05..a1a1f3f 100644 --- a/direct_test.go +++ b/fronted_test.go @@ -3,8 +3,10 @@ package fronted import ( "crypto/x509" "encoding/json" + "errors" "fmt" "io" + "net" "net/http" "net/http/httptest" "net/http/httputil" @@ -13,25 +15,26 @@ import ( "path/filepath" "strconv" "strings" + "sync/atomic" "testing" "time" + . "github.com/getlantern/waitforserver" + tls "github.com/refraction-networking/utls" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - . "github.com/getlantern/waitforserver" ) -func TestDirectDomainFronting(t *testing.T) { +func TestDirectDomainFrontingWithoutSNIConfig(t *testing.T) { dir, err := os.MkdirTemp("", "direct_test") require.NoError(t, err, "Unable to create temp dir") defer os.RemoveAll(dir) cacheFile := filepath.Join(dir, "cachefile.2") - doTestDomainFronting(t, cacheFile, numberToVetInitially) + doTestDomainFronting(t, cacheFile, 10) time.Sleep(defaultCacheSaveInterval * 2) // Then try again, this time reusing the existing cacheFile but a corrupted version corruptMasquerades(cacheFile) - doTestDomainFronting(t, cacheFile, numberToVetInitially) + doTestDomainFronting(t, cacheFile, 10) } func TestDirectDomainFrontingWithSNIConfig(t *testing.T) { @@ -52,9 +55,10 @@ func TestDirectDomainFrontingWithSNIConfig(t *testing.T) { UseArbitrarySNIs: true, ArbitrarySNIs: []string{"mercadopago.com", "amazon.com.br", "facebook.com", "google.com", "twitter.com", "youtube.com", "instagram.com", "linkedin.com", "whatsapp.com", "netflix.com", "microsoft.com", "yahoo.com", "bing.com", "wikipedia.org", "github.com"}, }) - Configure(certs, p, testProviderID, cacheFile) + testContext := NewFrontingContext("TestDirectDomainFrontingWithSNIConfig") + testContext.Configure(certs, p, "akamai", cacheFile) - transport, ok := NewDirect(0) + transport, ok := testContext.NewFronted(30 * time.Second) require.True(t, ok) client := &http.Client{ Transport: transport, @@ -80,9 +84,10 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE } certs := trustedCACerts(t) p := testProvidersWithHosts(hosts) - Configure(certs, p, testProviderID, cacheFile) + testContext := NewFrontingContext("doTestDomainFronting") + testContext.Configure(certs, p, testProviderID, cacheFile) - transport, ok := NewDirect(30 * time.Second) + transport, ok := testContext.NewFronted(30 * time.Second) require.True(t, ok) client := &http.Client{ @@ -91,25 +96,25 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE } require.True(t, doCheck(client, http.MethodPost, http.StatusAccepted, pingURL)) - transport, ok = NewDirect(0) + transport, ok = testContext.NewFronted(30 * time.Second) require.True(t, ok) client = &http.Client{ Transport: transport, } require.True(t, doCheck(client, http.MethodGet, http.StatusOK, getURL)) - instance, ok := DefaultContext.instance.Get(0) + instance, ok := testContext.instance.Get(0) require.True(t, ok) - d := instance.(*direct) + d := instance.(*fronted) - // Check the number of masquerades at the end, waiting up to 30 seconds until we get the right number + // Check the number of masquerades at the end, waiting until we get the right number masqueradesAtEnd := 0 - for i := 0; i < 100; i++ { + for i := 0; i < 1000; i++ { masqueradesAtEnd = len(d.masquerades) if masqueradesAtEnd == expectedMasqueradesAtEnd { break } - time.Sleep(300 * time.Millisecond) + time.Sleep(30 * time.Millisecond) } require.GreaterOrEqual(t, masqueradesAtEnd, expectedMasqueradesAtEnd) return masqueradesAtEnd @@ -135,7 +140,7 @@ func TestLoadCandidates(t *testing.T) { } } - d := &direct{ + d := &fronted{ masquerades: make(sortedMasquerades, 0, len(expected)), } @@ -144,7 +149,7 @@ func TestLoadCandidates(t *testing.T) { actual := make(map[Masquerade]bool) count := 0 for _, m := range d.masquerades { - actual[Masquerade{Domain: m.Domain, IpAddress: m.IpAddress}] = true + actual[Masquerade{Domain: m.getDomain(), IpAddress: m.getIpAddress()}] = true count++ } @@ -233,9 +238,11 @@ func TestHostAliasesBasic(t *testing.T) { certs := x509.NewCertPool() certs.AddCert(cloudSack.Certificate()) - Configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") - rt, ok := NewDirect(10 * time.Second) + testContext := NewFrontingContext("TestHostAliasesBasic") + testContext.Configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") + + rt, ok := testContext.NewFronted(30 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } @@ -345,8 +352,9 @@ func TestHostAliasesMulti(t *testing.T) { "sadcloud": p2, } - Configure(certs, providers, "cloudsack", "") - rt, ok := NewDirect(10 * time.Second) + testContext := NewFrontingContext("TestHostAliasesMulti") + testContext.Configure(certs, providers, "cloudsack", "") + rt, ok := testContext.NewFronted(30 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } @@ -470,9 +478,11 @@ func TestPassthrough(t *testing.T) { certs := x509.NewCertPool() certs.AddCert(cloudSack.Certificate()) - Configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") - rt, ok := NewDirect(10 * time.Second) + testContext := NewFrontingContext("TestPassthrough") + testContext.Configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") + + rt, ok := testContext.NewFronted(30 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } @@ -528,7 +538,7 @@ func TestCustomValidators(t *testing.T) { sadCloudValidator := NewStatusCodeValidator(sadCloudCodes) testURL := "https://abc.forbidden.com/quux" - setup := func(validator ResponseValidator) { + setup := func(ctx *FrontingContext, validator ResponseValidator) { masq := []*Masquerade{{Domain: "example.com", IpAddress: sadCloudAddr}} alias := map[string]string{ "abc.forbidden.com": "abc.sadcloud.io", @@ -542,7 +552,7 @@ func TestCustomValidators(t *testing.T) { "sadcloud": p, } - Configure(certs, providers, "sadcloud", "") + ctx.Configure(certs, providers, "sadcloud", "") } // This error indicates that the validator has discarded all masquerades. @@ -551,32 +561,38 @@ func TestCustomValidators(t *testing.T) { masqueradesExhausted := fmt.Sprintf(`Get "%v": could not complete request even with retries`, testURL) tests := []struct { + name string responseCode int validator ResponseValidator expectedError string }{ // with the default validator, only 403s are rejected { + name: "with default validator, it should reject 403", responseCode: http.StatusForbidden, validator: nil, expectedError: masqueradesExhausted, }, { + name: "with default validator, it should accept 202", responseCode: http.StatusAccepted, validator: nil, expectedError: "", }, { + name: "with default validator, it should accept 402", responseCode: http.StatusPaymentRequired, validator: nil, expectedError: "", }, { + name: "with default validator, it should accept 418", responseCode: http.StatusTeapot, validator: nil, expectedError: "", }, { + name: "with default validator, it should accept 502", responseCode: http.StatusBadGateway, validator: nil, expectedError: "", @@ -584,26 +600,31 @@ func TestCustomValidators(t *testing.T) { // with the custom validator, 403 is allowed, listed codes are rejected { + name: "with custom validator, it should accept 403", responseCode: http.StatusForbidden, validator: sadCloudValidator, expectedError: "", }, { + name: "with custom validator, it should accept 402", responseCode: http.StatusAccepted, validator: sadCloudValidator, expectedError: "", }, { + name: "with custom validator, it should reject and return error for 402", responseCode: http.StatusPaymentRequired, validator: sadCloudValidator, expectedError: masqueradesExhausted, }, { + name: "with custom validator, it should reject and return error for 418", responseCode: http.StatusTeapot, validator: sadCloudValidator, expectedError: masqueradesExhausted, }, { + name: "with custom validator, it should reject and return error for 502", responseCode: http.StatusBadGateway, validator: sadCloudValidator, expectedError: masqueradesExhausted, @@ -611,34 +632,31 @@ func TestCustomValidators(t *testing.T) { } for _, test := range tests { - setup(test.validator) - direct, ok := NewDirect(1 * time.Second) - if !assert.True(t, ok) { - return - } - client := &http.Client{ - Transport: direct, - } + t.Run(test.name, func(t *testing.T) { + testContext := NewFrontingContext(test.name) + setup(testContext, test.validator) + direct, ok := testContext.NewFronted(30 * time.Second) + require.True(t, ok) + client := &http.Client{ + Transport: direct, + } - req, err := http.NewRequest(http.MethodGet, testURL, nil) - if !assert.NoError(t, err) { - return - } - if test.responseCode != http.StatusAccepted { - val := strconv.Itoa(test.responseCode) - log.Debugf("requesting forced response code %s", val) - req.Header.Set(CDNForceFail, val) - } + req, err := http.NewRequest(http.MethodGet, testURL, nil) + require.NoError(t, err) + if test.responseCode != http.StatusAccepted { + val := strconv.Itoa(test.responseCode) + log.Debugf("requesting forced response code %s", val) + req.Header.Set(CDNForceFail, val) + } - res, err := client.Do(req) - if test.expectedError == "" { - if !assert.NoError(t, err) { - continue + res, err := client.Do(req) + if test.expectedError == "" { + require.NoError(t, err) + assert.Equal(t, test.responseCode, res.StatusCode, "Failed to force response status code") + } else { + assert.EqualError(t, err, test.expectedError) } - assert.Equal(t, test.responseCode, res.StatusCode, "Failed to force response status code") - } else { - assert.EqualError(t, err, test.expectedError) - } + }) } } @@ -804,3 +822,151 @@ func TestVerifyPeerCertificate(t *testing.T) { }) } } + +func TestFindWorkingMasquerades(t *testing.T) { + tests := []struct { + name string + masquerades []*mockMasquerade + expectedSuccessful int + expectedMasquerades int + }{ + { + name: "All successful", + masquerades: []*mockMasquerade{ + newMockMasquerade("domain1.com", "1.1.1.1", 0, true), + newMockMasquerade("domain2.com", "2.2.2.2", 0, true), + newMockMasquerade("domain3.com", "3.3.3.3", 0, true), + newMockMasquerade("domain4.com", "4.4.4.4", 0, true), + newMockMasquerade("domain1.com", "1.1.1.1", 0, true), + newMockMasquerade("domain1.com", "1.1.1.1", 0, true), + }, + expectedSuccessful: 4, + }, + { + name: "Some successful", + masquerades: []*mockMasquerade{ + newMockMasquerade("domain1.com", "1.1.1.1", 0, true), + newMockMasquerade("domain2.com", "2.2.2.2", 0, false), + newMockMasquerade("domain3.com", "3.3.3.3", 0, true), + newMockMasquerade("domain4.com", "4.4.4.4", 0, false), + newMockMasquerade("domain1.com", "1.1.1.1", 0, true), + }, + expectedSuccessful: 2, + }, + { + name: "None successful", + masquerades: []*mockMasquerade{ + newMockMasquerade("domain1.com", "1.1.1.1", 0, false), + newMockMasquerade("domain2.com", "2.2.2.2", 0, false), + newMockMasquerade("domain3.com", "3.3.3.3", 0, false), + newMockMasquerade("domain4.com", "4.4.4.4", 0, false), + }, + expectedSuccessful: 0, + }, + { + name: "Batch processing", + masquerades: func() []*mockMasquerade { + var masquerades []*mockMasquerade + for i := 0; i < 50; i++ { + masquerades = append(masquerades, newMockMasquerade(fmt.Sprintf("domain%d.com", i), fmt.Sprintf("1.1.1.%d", i), 0, i%2 == 0)) + } + return masquerades + }(), + expectedSuccessful: 4, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &fronted{} + d.providers = make(map[string]*Provider) + d.providers["testProviderId"] = NewProvider(nil, "", nil, nil, nil, nil, nil) + d.masquerades = make(sortedMasquerades, len(tt.masquerades)) + for i, m := range tt.masquerades { + d.masquerades[i] = m + } + + var successful atomic.Uint32 + d.vetBatch(0, 10, &successful, nil) + + tries := 0 + for successful.Load() < uint32(tt.expectedSuccessful) && tries < 100 { + time.Sleep(30 * time.Millisecond) + tries++ + } + + assert.GreaterOrEqual(t, int(successful.Load()), tt.expectedSuccessful) + }) + } +} + +// Generate a mock of a MasqueradeInterface with a Dial method that can optionally +// return an error after a specified number of milliseconds. +func newMockMasquerade(domain string, ipAddress string, timeout time.Duration, passesCheck bool) *mockMasquerade { + return &mockMasquerade{ + Domain: domain, + IpAddress: ipAddress, + timeout: timeout, + passesCheck: passesCheck, + } +} + +type mockMasquerade struct { + Domain string + IpAddress string + timeout time.Duration + passesCheck bool + lastSucceededTime time.Time +} + +// setLastSucceeded implements MasqueradeInterface. +func (m *mockMasquerade) setLastSucceeded(succeededTime time.Time) { + m.lastSucceededTime = succeededTime +} + +// lastSucceeded implements MasqueradeInterface. +func (m *mockMasquerade) lastSucceeded() time.Time { + return m.lastSucceededTime +} + +// postCheck implements MasqueradeInterface. +func (m *mockMasquerade) postCheck(net.Conn, string) bool { + return m.passesCheck +} + +// dial implements MasqueradeInterface. +func (m *mockMasquerade) dial(rootCAs *x509.CertPool, clientHelloID tls.ClientHelloID) (net.Conn, error) { + if m.timeout > 0 { + time.Sleep(m.timeout) + return nil, errors.New("mock dial error") + } + m.lastSucceededTime = time.Now() + return &net.TCPConn{}, nil +} + +// getDomain implements MasqueradeInterface. +func (m *mockMasquerade) getDomain() string { + return m.Domain +} + +// getIpAddress implements MasqueradeInterface. +func (m *mockMasquerade) getIpAddress() string { + return m.IpAddress +} + +// getProviderID implements MasqueradeInterface. +func (m *mockMasquerade) getProviderID() string { + return "testProviderId" +} + +// markFailed implements MasqueradeInterface. +func (m *mockMasquerade) markFailed() { + +} + +// markSucceeded implements MasqueradeInterface. +func (m *mockMasquerade) markSucceeded() { +} + +// Make sure that the mockMasquerade implements the MasqueradeInterface +var _ MasqueradeInterface = (*mockMasquerade)(nil) diff --git a/masquerade.go b/masquerade.go index b57a6fb..9dcdd52 100644 --- a/masquerade.go +++ b/masquerade.go @@ -2,14 +2,22 @@ package fronted import ( "crypto/sha256" + "crypto/x509" "encoding/json" + "errors" "fmt" + "io" "net" "net/http" "sort" "strings" "sync" "time" + + "github.com/getlantern/netx" + "github.com/getlantern/ops" + "github.com/getlantern/tlsdialer/v3" + tls "github.com/refraction-networking/utls" ) const ( @@ -44,6 +52,29 @@ type Masquerade struct { VerifyHostname *string } +// Create a masquerade interface for easier testing. +type MasqueradeInterface interface { + dial(rootCAs *x509.CertPool, clientHelloID tls.ClientHelloID) (net.Conn, error) + + // Accessor for the domain of the masquerade + getDomain() string + + //Accessor for the IP address of the masquerade + getIpAddress() string + + markSucceeded() + + markFailed() + + lastSucceeded() time.Time + + setLastSucceeded(time.Time) + + postCheck(net.Conn, string) bool + + getProviderID() string +} + type masquerade struct { Masquerade // lastSucceeded: the most recent time at which this Masquerade succeeded @@ -53,6 +84,98 @@ type masquerade struct { mx sync.RWMutex } +func (m *masquerade) dial(rootCAs *x509.CertPool, clientHelloID tls.ClientHelloID) (net.Conn, error) { + tlsConfig := &tls.Config{ + ServerName: m.Domain, + RootCAs: rootCAs, + } + dialTimeout := 10 * time.Second + addr := m.IpAddress + var sendServerNameExtension bool + if m.SNI != "" { + sendServerNameExtension = true + tlsConfig.ServerName = m.SNI + tlsConfig.InsecureSkipVerify = true + tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + var verifyHostname string + if m.VerifyHostname != nil { + verifyHostname = *m.VerifyHostname + } + return verifyPeerCertificate(rawCerts, rootCAs, verifyHostname) + } + } + dialer := &tlsdialer.Dialer{ + DoDial: netx.DialTimeout, + Timeout: dialTimeout, + SendServerName: sendServerNameExtension, + Config: tlsConfig, + ClientHelloID: clientHelloID, + } + _, _, err := net.SplitHostPort(addr) + if err != nil { + // If there is no port, we default to 443 + addr = net.JoinHostPort(addr, "443") + } + return dialer.Dial("tcp", addr) +} + +// postCheck does a post with invalid data to verify domain-fronting works +func (m *masquerade) postCheck(conn net.Conn, testURL string) bool { + client := &http.Client{ + Transport: frontedHTTPTransport(conn, true), + } + return doCheck(client, http.MethodPost, http.StatusAccepted, testURL) +} + +func doCheck(client *http.Client, method string, expectedStatus int, u string) bool { + op := ops.Begin("check_masquerade") + defer op.End() + + isPost := method == http.MethodPost + var requestBody io.Reader + if isPost { + requestBody = strings.NewReader("a") + } + req, _ := http.NewRequest(method, u, requestBody) + if isPost { + req.Header.Set("Content-Type", "application/json") + } + resp, err := client.Do(req) + if err != nil { + op.FailIf(err) + log.Debugf("Unsuccessful vetting with %v request, discarding masquerade: %v", method, err) + return false + } + if resp.Body != nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + if resp.StatusCode != expectedStatus { + op.Set("response_status", resp.StatusCode) + op.Set("expected_status", expectedStatus) + msg := fmt.Sprintf("Unexpected response status vetting masquerade, expected %d got %d: %v", expectedStatus, resp.StatusCode, resp.Status) + op.FailIf(errors.New(msg)) + log.Debug(msg) + return false + } + return true +} + +// getDomain implements MasqueradeInterface. +func (m *masquerade) getDomain() string { + return m.Domain +} + +// getIpAddress implements MasqueradeInterface. +func (m *masquerade) getIpAddress() string { + return m.IpAddress +} + +// getProviderID implements MasqueradeInterface. +func (m *masquerade) getProviderID() string { + return m.ProviderID +} + // MarshalJSON marshals masquerade into json func (m *masquerade) MarshalJSON() ([]byte, error) { m.mx.RLock() @@ -68,6 +191,12 @@ func (m *masquerade) lastSucceeded() time.Time { return m.LastSucceeded } +func (m *masquerade) setLastSucceeded(t time.Time) { + m.mx.Lock() + defer m.mx.Unlock() + m.LastSucceeded = t +} + func (m *masquerade) markSucceeded() { m.mx.Lock() defer m.mx.Unlock() @@ -80,6 +209,9 @@ func (m *masquerade) markFailed() { m.LastSucceeded = time.Time{} } +// Make sure that the masquerade struct implements the MasqueradeInterface +var _ MasqueradeInterface = (*masquerade)(nil) + // A Direct fronting provider configuration. type Provider struct { // Specific hostname mappings used for this provider. @@ -211,7 +343,7 @@ func NewStatusCodeValidator(reject []int) ResponseValidator { } // slice of masquerade sorted by last vetted time -type sortedMasquerades []*masquerade +type sortedMasquerades []MasqueradeInterface func (m sortedMasquerades) Len() int { return len(m) } func (m sortedMasquerades) Swap(i, j int) { m[i], m[j] = m[j], m[i] } @@ -221,7 +353,7 @@ func (m sortedMasquerades) Less(i, j int) bool { } else if m[j].lastSucceeded().After(m[i].lastSucceeded()) { return false } else { - return m[i].IpAddress < m[j].IpAddress + return m[i].getIpAddress() < m[j].getIpAddress() } } diff --git a/test_support.go b/test_support.go index 54da8a7..6f787a0 100644 --- a/test_support.go +++ b/test_support.go @@ -58,6 +58,6 @@ func testProvidersWithHosts(hosts map[string]string) map[string]*Provider { } func testAkamaiProvidersWithHosts(hosts map[string]string, sniConfig *SNIConfig) map[string]*Provider { return map[string]*Provider{ - testProviderID: NewProvider(hosts, pingTestURL, DefaultAkamaiMasquerades, nil, nil, sniConfig, nil), + "akamai": NewProvider(hosts, "https://fronted-ping.dsa.akamai.getiantem.org/ping", DefaultAkamaiMasquerades, nil, nil, sniConfig, nil), } }