diff --git a/connecting_fronts.go b/connecting_fronts.go new file mode 100644 index 0000000..ec4611d --- /dev/null +++ b/connecting_fronts.go @@ -0,0 +1,68 @@ +package fronted + +import ( + "errors" + "sort" + "sync" + "time" +) + +type connectTimeFront struct { + MasqueradeInterface + connectTime time.Duration +} + +type connectingFronts struct { + fronts []connectTimeFront + //frontsChan chan MasqueradeInterface + sync.RWMutex +} + +// Make sure that connectingFronts is a connectListener +var _ workingFronts = &connectingFronts{} + +// newConnectingFronts creates a new ConnectingFronts struct with an empty slice of Masquerade IPs and domains. +func newConnectingFronts() *connectingFronts { + return &connectingFronts{ + fronts: make([]connectTimeFront, 0), + //frontsChan: make(chan MasqueradeInterface), + } +} + +// AddFront adds a new front to the list of fronts. +func (cf *connectingFronts) onConnected(m MasqueradeInterface, connectTime time.Duration) { + cf.Lock() + defer cf.Unlock() + + cf.fronts = append(cf.fronts, connectTimeFront{ + MasqueradeInterface: m, + connectTime: connectTime, + }) + // Sort fronts by connect time. + sort.Slice(cf.fronts, func(i, j int) bool { + return cf.fronts[i].connectTime < cf.fronts[j].connectTime + }) + //cf.frontsChan <- m +} + +func (cf *connectingFronts) onError(m MasqueradeInterface) { + cf.Lock() + defer cf.Unlock() + + // Remove the front from connecting fronts. + for i, front := range cf.fronts { + if front.MasqueradeInterface == m { + cf.fronts = append(cf.fronts[:i], cf.fronts[i+1:]...) + return + } + } +} + +func (cf *connectingFronts) workingFront() (MasqueradeInterface, error) { + cf.RLock() + defer cf.RUnlock() + if len(cf.fronts) == 0 { + return nil, errors.New("no fronts available") + } + return cf.fronts[0].MasqueradeInterface, nil +} diff --git a/context.go b/context.go index eb8d78a..02853a7 100644 --- a/context.go +++ b/context.go @@ -12,53 +12,60 @@ import ( "github.com/getlantern/eventual/v2" ) +// Create an interface for the fronting context +type Fronting interface { + UpdateConfig(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string) + NewRoundTripper(timeout time.Duration) (http.RoundTripper, bool) + Close() +} + var defaultContext = newFrontingContext("default") +// Make sure that the default context is a Fronting +var _ Fronting = defaultContext + // Configure sets the masquerades to use, the trusted root CAs, and the // cache file for caching masquerades to set up direct domain fronting // in the default context. // // defaultProviderID is used when a masquerade without a provider is // encountered (eg in a cache file) -func Configure(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string, cacheFile string) { - if err := defaultContext.Configure(pool, providers, defaultProviderID, cacheFile); err != nil { - log.Errorf("Error configuring fronting %s context: %s!!", defaultContext.name, err) +func NewFronter(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string, cacheFile string) (Fronting, error) { + if err := defaultContext.configure(pool, providers, defaultProviderID, cacheFile); err != nil { + return nil, log.Errorf("Error configuring fronting %s context: %s!!", defaultContext.name, err) } -} - -// NewFronted creates a new http.RoundTripper that does direct domain fronting -// using the default context. If the default context isn't configured within -// the given timeout, this method returns nil, false. -func NewFronted(timeout time.Duration) (http.RoundTripper, bool) { - return defaultContext.NewFronted(timeout) -} - -// Close closes any existing cache file in the default context -func Close() { - defaultContext.Close() + return defaultContext, nil } func newFrontingContext(name string) *frontingContext { return &frontingContext{ - name: name, - instance: eventual.NewValue(), + name: name, + instance: eventual.NewValue(), + connectingFronts: newConnectingFronts(), } } type frontingContext struct { - name string - instance eventual.Value + name string + instance eventual.Value + fronted *fronted + connectingFronts *connectingFronts } -// Configure sets the masquerades to use, the trusted root CAs, and the +// UpdateConfig updates the configuration of the fronting context +func (fctx *frontingContext) UpdateConfig(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string) { + fctx.fronted.updateConfig(pool, providers, defaultProviderID) +} + +// configure sets the masquerades to use, the trusted root CAs, and the // cache file for caching masquerades to set up direct domain fronting. // defaultProviderID is used when a masquerade without a provider is // encountered (eg in a cache file) -func (fctx *frontingContext) Configure(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string, cacheFile string) error { - return fctx.ConfigureWithHello(pool, providers, defaultProviderID, cacheFile, tls.ClientHelloID{}) +func (fctx *frontingContext) configure(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string, cacheFile string) error { + return fctx.configureWithHello(pool, providers, defaultProviderID, cacheFile, tls.ClientHelloID{}) } -func (fctx *frontingContext) ConfigureWithHello(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string, cacheFile string, clientHelloID tls.ClientHelloID) error { +func (fctx *frontingContext) configureWithHello(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string, cacheFile string, clientHelloID tls.ClientHelloID) error { log.Debugf("Configuring fronted %s context", fctx.name) if len(providers) == 0 { @@ -73,10 +80,11 @@ func (fctx *frontingContext) ConfigureWithHello(pool *x509.CertPool, providers m existing.closeCache() } - _, err := newFronted(pool, providers, defaultProviderID, cacheFile, clientHelloID, func(f *fronted) { + var err error + fctx.fronted, err = newFronted(pool, providers, defaultProviderID, cacheFile, clientHelloID, func(f *fronted) { log.Debug("Setting fronted instance") fctx.instance.Set(f) - }) + }, fctx.connectingFronts) if err != nil { return err } @@ -86,7 +94,7 @@ func (fctx *frontingContext) ConfigureWithHello(pool *x509.CertPool, providers m // NewFronted creates a new http.RoundTripper that does direct domain fronting. // If the context isn't configured within the given timeout, this method // returns nil, false. -func (fctx *frontingContext) NewFronted(timeout time.Duration) (http.RoundTripper, bool) { +func (fctx *frontingContext) NewRoundTripper(timeout time.Duration) (http.RoundTripper, bool) { start := time.Now() ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() diff --git a/fronted.go b/fronted.go index 71ed6b9..4092788 100644 --- a/fronted.go +++ b/fronted.go @@ -2,7 +2,6 @@ package fronted import ( "bytes" - "context" "crypto/x509" "errors" "fmt" @@ -34,6 +33,12 @@ var ( log = golog.LoggerFor("fronted") ) +type workingFronts interface { + onConnected(m MasqueradeInterface, connectTime time.Duration) + onError(m MasqueradeInterface) + workingFront() (MasqueradeInterface, error) +} + // fronted identifies working IP address/domain pairings for domain fronting and is // an implementation of http.RoundTripper for the convenience of callers. type fronted struct { @@ -48,11 +53,16 @@ type fronted struct { defaultProviderID string providers map[string]*Provider clientHelloID tls.ClientHelloID + workingFronts workingFronts + providersMu sync.RWMutex + masqueradesMu sync.RWMutex + frontedMu sync.RWMutex } func newFronted(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID, cacheFile string, clientHelloID tls.ClientHelloID, - listener func(f *fronted)) (*fronted, error) { + listener func(f *fronted), + workingFronts workingFronts) (*fronted, error) { size := 0 for _, p := range providers { size += len(p.Masquerades) @@ -63,10 +73,7 @@ func newFronted(pool *x509.CertPool, providers map[string]*Provider, } // copy providers - providersCopy := make(map[string]*Provider, len(providers)) - for k, p := range providers { - providersCopy[k] = NewProvider(p.HostAliases, p.TestURL, p.Masquerades, p.Validator, p.PassthroughPatterns, p.SNIConfig, p.VerifyHostname) - } + providersCopy := copyProviders(providers) f := &fronted{ certPool: pool, @@ -79,6 +86,7 @@ func newFronted(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID: defaultProviderID, providers: providersCopy, clientHelloID: clientHelloID, + workingFronts: workingFronts, } if cacheFile != "" { @@ -89,6 +97,15 @@ func newFronted(pool *x509.CertPool, providers map[string]*Provider, return f, nil } +func copyProviders(providers map[string]*Provider) map[string]*Provider { + providersCopy := make(map[string]*Provider, len(providers)) + for key, p := range providers { + providersCopy[key] = NewProvider(p.HostAliases, p.TestURL, p.Masquerades, p.Validator, p.PassthroughPatterns, p.SNIConfig, p.VerifyHostname) + log.Debugf("Domain fronting provider is %v", providersCopy[key]) + } + return providersCopy +} + func loadMasquerades(initial map[string]*Provider, size int) sortedMasquerades { log.Debugf("Loading candidates for %d providers", len(initial)) defer log.Debug("Finished loading candidates") @@ -114,6 +131,34 @@ func loadMasquerades(initial map[string]*Provider, size int) sortedMasquerades { return masquerades } +func (f *fronted) updateConfig(pool *x509.CertPool, providers map[string]*Provider, defaultProviderID string) { + // Make copies just to avoid any concurrency issues with access that may be happening on the + // caller side. + providersCopy := copyProviders(providers) + f.frontedMu.Lock() + defer f.frontedMu.Unlock() + f.addProviders(providersCopy) + f.addMasquerades(loadMasquerades(providersCopy, len(providersCopy))) + f.defaultProviderID = defaultProviderID + f.certPool = pool +} + +func (f *fronted) addProviders(providers map[string]*Provider) { + // Add new providers to the existing providers map, overwriting any existing ones. + f.providersMu.Lock() + defer f.providersMu.Unlock() + for key, p := range providers { + f.providers[key] = p + } +} + +func (f *fronted) addMasquerades(masquerades sortedMasquerades) { + // Add new masquerades to the existing masquerades slice, but add them at the beginning. + f.masqueradesMu.Lock() + defer f.masqueradesMu.Unlock() + f.masquerades = append(masquerades, f.masquerades...) +} + func (f *fronted) providerFor(m MasqueradeInterface) *Provider { pid := m.getProviderID() if pid == "" { @@ -149,35 +194,51 @@ func (f *fronted) findWorkingMasquerades(listener func(f *fronted)) { var successful atomic.Uint32 // We loop through all of them until we have 4 successful ones. - for i := 0; i < len(f.masquerades) && successful.Load() < 4; i += batchSize { + for i := 0; i < f.masqueradeSize() && successful.Load() < 4; i += batchSize { f.vetBatch(i, batchSize, &successful, listener) } } +func (f *fronted) masqueradeSize() int { + f.masqueradesMu.Lock() + defer f.masqueradesMu.Unlock() + return len(f.masquerades) +} + +func (f *fronted) masqueradeAt(i int) MasqueradeInterface { + f.masqueradesMu.Lock() + defer f.masqueradesMu.Unlock() + return f.masquerades[i] +} + func (f *fronted) vetBatch(start, batchSize int, successful *atomic.Uint32, listener func(f *fronted)) { log.Debugf("Vetting masquerade batch %d-%d", start, start+batchSize) var wg sync.WaitGroup - masqueradeSize := len(f.masquerades) - for i := start; i < start+batchSize && i < masqueradeSize; i++ { + for i := start; i < start+batchSize && i < f.masqueradeSize(); i++ { wg.Add(1) go func(m MasqueradeInterface) { defer wg.Done() - if f.vetMasquerade(m) { + working, connectTime := f.vetMasquerade(m) + if working { successful.Add(1) + f.workingFronts.onConnected(m, connectTime) if listener != nil { go listener(f) } + } else { + f.workingFronts.onError(m) } - }(f.masquerades[i]) + }(f.masqueradeAt(i)) } wg.Wait() } -func (f *fronted) vetMasquerade(m MasqueradeInterface) bool { +func (f *fronted) vetMasquerade(m MasqueradeInterface) (bool, time.Duration) { + start := time.Now() conn, masqueradeGood, err := f.dialMasquerade(m) if err != nil { log.Debugf("unexpected error vetting masquerades: %v", err) - return false + return false, time.Since(start) } defer func() { if conn != nil { @@ -189,15 +250,15 @@ func (f *fronted) vetMasquerade(m MasqueradeInterface) bool { if provider == nil { log.Debugf("Skipping masquerade with disabled/unknown provider id '%s' not in %v", m.getProviderID(), f.providers) - return false + return false, time.Since(start) } if !masqueradeGood(m.postCheck(conn, provider.TestURL)) { log.Debugf("Unsuccessful vetting with POST request, discarding masquerade") - return false + return false, time.Since(start) } log.Debugf("Successfully vetted one masquerade %v", m) - return true + return true, time.Since(start) } // RoundTrip loops through all available masquerades, sorted by the one that most recently @@ -242,37 +303,41 @@ func (f *fronted) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, return io.NopCloser(bytes.NewReader(body)) } - tries := 1 - if isIdempotent { - tries = maxTries - } + const tries = 6 for i := 0; i < tries; i++ { if i > 0 { log.Debugf("Retrying domain-fronted request, pass %d", i) } - conn, m, masqueradeGood, err := f.dialAll(req.Context()) + m, err := f.workingFronts.workingFront() if err != nil { - // unable to find good masquerade, fail - op.FailIf(err) - return nil, nil, err + // For some reason we have no working fronts. Sleep for a bit and try again. + time.Sleep(1 * time.Second) + continue } - resp, conn, err := f.validateMasqueradeWithConn(req, conn, m, originHost, getBody, masqueradeGood) + conn, masqueradeGood, err := f.dialMasquerade(m) if err != nil { - log.Debugf("Could not complete request: %v", err) + log.Debugf("Could not dial to %v: %v", m, err) + f.workingFronts.onError(m) continue } - return resp, conn, nil + resp, conn, err := f.request(req, conn, m, originHost, getBody, masqueradeGood) + if err != nil { + log.Debugf("Could not complete request: %v", err) + f.workingFronts.onError(m) + } else { + return resp, conn, nil + } } return nil, nil, op.FailIf(errors.New("could not complete request even with retries")) } -func (f *fronted) validateMasqueradeWithConn(req *http.Request, conn net.Conn, m MasqueradeInterface, originHost string, getBody func() io.ReadCloser, masqueradeGood func(bool) bool) (*http.Response, net.Conn, error) { - op := ops.Begin("validate_masquerade_with_conn") +func (f *fronted) request(req *http.Request, conn net.Conn, m MasqueradeInterface, originHost string, getBody func() io.ReadCloser, masqueradeGood func(bool) bool) (*http.Response, net.Conn, error) { + op := ops.Begin("fronted_request") defer op.End() provider := f.providerFor(m) if provider == nil { @@ -322,61 +387,6 @@ func (f *fronted) validateMasqueradeWithConn(req *http.Request, conn net.Conn, m return resp, conn, nil } -// Dial dials out using all available masquerades until one succeeds. -func (f *fronted) dialAll(ctx context.Context) (net.Conn, MasqueradeInterface, func(bool) bool, error) { - defer func(op ops.Op) { op.End() }(ops.Begin("dial_all")) - // never take more than a minute trying to find a dialer - ctx, cancel := context.WithTimeout(ctx, 1*time.Minute) - defer cancel() - - triedMasquerades := make(map[MasqueradeInterface]bool) - masqueradesToTry := f.masquerades.sortedCopy() - totalMasquerades := len(masqueradesToTry) -dialLoop: - // Loop through up to len(masqueradesToTry) times, trying each masquerade in turn. - // If the context is done, return an error. - for i := 0; i < totalMasquerades; i++ { - select { - case <-ctx.Done(): - log.Debugf("Timed out dialing with %v total masquerades", totalMasquerades) - break dialLoop - default: - // okay - } - - m, err := f.masqueradeToTry(masqueradesToTry, triedMasquerades) - if err != nil { - log.Errorf("No masquerades left to try") - break dialLoop - } - conn, masqueradeGood, err := f.dialMasquerade(m) - triedMasquerades[m] = true - if err != nil { - log.Debugf("Could not dial to %v: %v", m, err) - // As we're looping through the masquerades, each check takes time. As that's happening, - // other goroutines may be successfully vetting new masquerades, which will change the - // sorting. We want to make sure we're always trying the best masquerades first. - masqueradesToTry = f.masquerades.sortedCopy() - totalMasquerades = len(masqueradesToTry) - continue - } - return conn, m, masqueradeGood, nil - } - - return nil, nil, nil, log.Errorf("could not dial any masquerade? tried %v", totalMasquerades) -} - -func (f *fronted) masqueradeToTry(masquerades sortedMasquerades, triedMasquerades map[MasqueradeInterface]bool) (MasqueradeInterface, error) { - for _, m := range masquerades { - if triedMasquerades[m] { - continue - } - return m, nil - } - // This should be quite rare, as it means we've tried typically thousands of masquerades. - return nil, errors.New("no masquerades left to try") -} - func (f *fronted) dialMasquerade(m MasqueradeInterface) (net.Conn, func(bool) bool, error) { log.Tracef("Dialing to %v", m) diff --git a/fronted_test.go b/fronted_test.go index 575af02..49e9e66 100644 --- a/fronted_test.go +++ b/fronted_test.go @@ -1,7 +1,6 @@ package fronted import ( - "context" "crypto/x509" "encoding/json" "errors" @@ -58,9 +57,9 @@ func TestDirectDomainFrontingWithSNIConfig(t *testing.T) { ArbitrarySNIs: []string{"mercadopago.com", "amazon.com.br", "facebook.com", "google.com", "twitter.com", "youtube.com", "instagram.com", "linkedin.com", "whatsapp.com", "netflix.com", "microsoft.com", "yahoo.com", "bing.com", "wikipedia.org", "github.com"}, }) testContext := newFrontingContext("TestDirectDomainFrontingWithSNIConfig") - testContext.Configure(certs, p, "akamai", cacheFile) + testContext.configure(certs, p, "akamai", cacheFile) - transport, ok := testContext.NewFronted(30 * time.Second) + transport, ok := testContext.NewRoundTripper(30 * time.Second) require.True(t, ok) client := &http.Client{ Transport: transport, @@ -87,9 +86,9 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE certs := trustedCACerts(t) p := testProvidersWithHosts(hosts) testContext := newFrontingContext("doTestDomainFronting") - testContext.Configure(certs, p, testProviderID, cacheFile) + testContext.configure(certs, p, testProviderID, cacheFile) - transport, ok := testContext.NewFronted(30 * time.Second) + transport, ok := testContext.NewRoundTripper(30 * time.Second) require.True(t, ok) client := &http.Client{ @@ -98,7 +97,7 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE } require.True(t, doCheck(client, http.MethodPost, http.StatusAccepted, pingURL)) - transport, ok = testContext.NewFronted(30 * time.Second) + transport, ok = testContext.NewRoundTripper(30 * time.Second) require.True(t, ok) client = &http.Client{ Transport: transport, @@ -242,9 +241,9 @@ func TestHostAliasesBasic(t *testing.T) { certs.AddCert(cloudSack.Certificate()) testContext := newFrontingContext("TestHostAliasesBasic") - testContext.Configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") + testContext.configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") - rt, ok := testContext.NewFronted(30 * time.Second) + rt, ok := testContext.NewRoundTripper(30 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } @@ -355,8 +354,8 @@ func TestHostAliasesMulti(t *testing.T) { } testContext := newFrontingContext("TestHostAliasesMulti") - testContext.Configure(certs, providers, "cloudsack", "") - rt, ok := testContext.NewFronted(30 * time.Second) + testContext.configure(certs, providers, "cloudsack", "") + rt, ok := testContext.NewRoundTripper(30 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } @@ -482,9 +481,9 @@ func TestPassthrough(t *testing.T) { certs.AddCert(cloudSack.Certificate()) testContext := newFrontingContext("TestPassthrough") - testContext.Configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") + testContext.configure(certs, map[string]*Provider{"cloudsack": p}, "cloudsack", "") - rt, ok := testContext.NewFronted(30 * time.Second) + rt, ok := testContext.NewRoundTripper(30 * time.Second) if !assert.True(t, ok, "failed to obtain direct roundtripper") { return } @@ -554,7 +553,7 @@ func TestCustomValidators(t *testing.T) { "sadcloud": p, } - ctx.Configure(certs, providers, "sadcloud", "") + ctx.configure(certs, providers, "sadcloud", "") } // This error indicates that the validator has discarded all masquerades. @@ -637,7 +636,7 @@ func TestCustomValidators(t *testing.T) { t.Run(test.name, func(t *testing.T) { testContext := newFrontingContext(test.name) setup(testContext, test.validator) - direct, ok := testContext.NewFronted(30 * time.Second) + direct, ok := testContext.NewRoundTripper(30 * time.Second) require.True(t, ok) client := &http.Client{ Transport: direct, @@ -902,134 +901,6 @@ func TestFindWorkingMasquerades(t *testing.T) { } } -func TestMasqueradeToTry(t *testing.T) { - min := time.Now().Add(-time.Minute) - hour := time.Now().Add(-time.Hour) - domain1 := newMockMasqueradeWithLastSuccess("domain1.com", "1.1.1.1", 0, true, min) - domain2 := newMockMasqueradeWithLastSuccess("domain2.com", "2.2.2.2", 0, true, hour) - tests := []struct { - name string - masquerades sortedMasquerades - triedMasquerades map[MasqueradeInterface]bool - expected MasqueradeInterface - }{ - { - name: "No tried masquerades", - masquerades: sortedMasquerades{ - domain1, - domain2, - }, - triedMasquerades: map[MasqueradeInterface]bool{}, - expected: domain1, - }, - { - name: "Some tried masquerades", - masquerades: sortedMasquerades{ - domain1, - domain2, - }, - triedMasquerades: map[MasqueradeInterface]bool{ - domain1: true, - }, - expected: domain2, - }, - { - name: "All masquerades tried", - masquerades: sortedMasquerades{ - domain1, - domain2, - }, - triedMasquerades: map[MasqueradeInterface]bool{ - domain1: true, - domain2: true, - }, - expected: nil, - }, - { - name: "Empty masquerades list", - masquerades: sortedMasquerades{}, - triedMasquerades: map[MasqueradeInterface]bool{}, - expected: nil, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - f := &fronted{} - masquerades := tt.masquerades.sortedCopy() - result, _ := f.masqueradeToTry(masquerades, tt.triedMasquerades) - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestDialAll(t *testing.T) { - tests := []struct { - name string - masquerades []*mockMasquerade - expectedSuccessful bool - expectedMasquerades int - }{ - { - name: "All successful", - masquerades: []*mockMasquerade{ - newMockMasquerade("domain1.com", "1.1.1.1", 0, true), - newMockMasquerade("domain2.com", "2.2.2.2", 0, true), - newMockMasquerade("domain3.com", "3.3.3.3", 0, true), - newMockMasquerade("domain4.com", "4.4.4.4", 0, true), - }, - expectedSuccessful: true, - }, - { - name: "Some successful", - masquerades: []*mockMasquerade{ - newMockMasquerade("domain1.com", "1.1.1.1", 0, true), - newMockMasquerade("domain2.com", "2.2.2.2", 1*time.Millisecond, false), - newMockMasquerade("domain3.com", "3.3.3.3", 0, true), - newMockMasquerade("domain4.com", "4.4.4.4", 1*time.Millisecond, false), - }, - expectedSuccessful: true, - }, - { - name: "None successful", - masquerades: []*mockMasquerade{ - newMockMasquerade("domain1.com", "1.1.1.1", 1*time.Millisecond, false), - newMockMasquerade("domain2.com", "2.2.2.2", 1*time.Millisecond, false), - newMockMasquerade("domain3.com", "3.3.3.3", 1*time.Millisecond, false), - newMockMasquerade("domain4.com", "4.4.4.4", 1*time.Millisecond, false), - }, - expectedSuccessful: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - d := &fronted{} - d.providers = make(map[string]*Provider) - d.providers["testProviderId"] = NewProvider(nil, "", nil, nil, nil, nil, nil) - d.masquerades = make(sortedMasquerades, len(tt.masquerades)) - for i, m := range tt.masquerades { - d.masquerades[i] = m - } - - ctx := context.Background() - conn, m, masqueradeGood, err := d.dialAll(ctx) - - if tt.expectedSuccessful { - assert.NoError(t, err) - assert.NotNil(t, conn) - assert.NotNil(t, m) - assert.NotNil(t, masqueradeGood) - } else { - assert.Error(t, err) - assert.Nil(t, conn) - assert.Nil(t, m) - assert.Nil(t, masqueradeGood) - } - }) - } -} - // Generate a mock of a MasqueradeInterface with a Dial method that can optionally // return an error after a specified number of milliseconds. func newMockMasquerade(domain string, ipAddress string, timeout time.Duration, passesCheck bool) *mockMasquerade { diff --git a/test_support.go b/test_support.go index 6f787a0..d508a62 100644 --- a/test_support.go +++ b/test_support.go @@ -23,13 +23,13 @@ func ConfigureForTest(t *testing.T) { func ConfigureCachingForTest(t *testing.T, cacheFile string) { certs := trustedCACerts(t) p := testProviders() - Configure(certs, p, testProviderID, cacheFile) + NewFronter(certs, p, testProviderID, cacheFile) } func ConfigureHostAlaisesForTest(t *testing.T, hosts map[string]string) { certs := trustedCACerts(t) p := testProvidersWithHosts(hosts) - Configure(certs, p, testProviderID, "") + NewFronter(certs, p, testProviderID, "") } func trustedCACerts(t *testing.T) *x509.CertPool {