Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-sort masquerades while iterating to always dial the best #49

Merged
merged 5 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 46 additions & 22 deletions fronted.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,25 +62,25 @@ func newFronted(pool *x509.CertPool, providers map[string]*Provider,
return nil, fmt.Errorf("no masquerades found in providers")
}

// copy providers
providersCopy := make(map[string]*Provider, len(providers))
for k, p := range providers {
providersCopy[k] = NewProvider(p.HostAliases, p.TestURL, p.Masquerades, p.Validator, p.PassthroughPatterns, p.SNIConfig, p.VerifyHostname)
}

f := &fronted{
certPool: pool,
masquerades: make(sortedMasquerades, 0, size),
masquerades: loadMasquerades(providersCopy, size),
maxAllowedCachedAge: defaultMaxAllowedCachedAge,
maxCacheSize: defaultMaxCacheSize,
cacheSaveInterval: defaultCacheSaveInterval,
cacheDirty: make(chan interface{}, 1),
cacheClosed: make(chan interface{}),
defaultProviderID: defaultProviderID,
providers: make(map[string]*Provider),
providers: providersCopy,
clientHelloID: clientHelloID,
}

// copy providers
for k, p := range providers {
f.providers[k] = NewProvider(p.HostAliases, p.TestURL, p.Masquerades, p.Validator, p.PassthroughPatterns, p.SNIConfig, p.VerifyHostname)
}

f.loadCandidates(f.providers)
if cacheFile != "" {
f.initCaching(cacheFile)
}
Expand All @@ -89,14 +89,14 @@ func newFronted(pool *x509.CertPool, providers map[string]*Provider,
return f, nil
}

func (f *fronted) loadCandidates(initial map[string]*Provider) {
func loadMasquerades(initial map[string]*Provider, size int) sortedMasquerades {
log.Debugf("Loading candidates for %d providers", len(initial))
defer log.Debug("Finished loading candidates")

masquerades := make(sortedMasquerades, 0, size)
for key, p := range initial {
arr := p.Masquerades
size := len(arr)
log.Debugf("Adding %d candidates for %v", size, key)

// make a shuffled copy of arr
// ('inside-out' Fisher-Yates)
Expand All @@ -108,9 +108,10 @@ func (f *fronted) loadCandidates(initial map[string]*Provider) {
}

for _, c := range sh {
f.masquerades = append(f.masquerades, &masquerade{Masquerade: *c, ProviderID: key})
masquerades = append(masquerades, &masquerade{Masquerade: *c, ProviderID: key})
}
}
return masquerades
}

func (f *fronted) providerFor(m MasqueradeInterface) *Provider {
Expand Down Expand Up @@ -323,36 +324,59 @@ func (f *fronted) validateMasqueradeWithConn(req *http.Request, conn net.Conn, m

// Dial dials out using all available masquerades until one succeeds.
func (f *fronted) dialAll(ctx context.Context) (net.Conn, MasqueradeInterface, func(bool) bool, error) {
conn, m, masqueradeGood, err := f.dialAllWith(ctx, f.masquerades)
return conn, m, masqueradeGood, err
}

func (f *fronted) dialAllWith(ctx context.Context, masquerades sortedMasquerades) (net.Conn, MasqueradeInterface, func(bool) bool, error) {
defer func(op ops.Op) { op.End() }(ops.Begin("dial_all_with"))
defer func(op ops.Op) { op.End() }(ops.Begin("dial_all"))
// never take more than a minute trying to find a dialer
ctx, cancel := context.WithTimeout(ctx, 1*time.Minute)
defer cancel()

masqueradesToTry := masquerades.sortedCopy()
triedMasquerades := make(map[MasqueradeInterface]bool)
masqueradesToTry := f.masquerades.sortedCopy()
totalMasquerades := len(masqueradesToTry)
dialLoop:
for _, m := range masqueradesToTry {
// Loop through up to len(masqueradesToTry) times, trying each masquerade in turn.
// If the context is done, return an error.
for i := 0; i < totalMasquerades; i++ {
select {
case <-ctx.Done():
log.Debugf("Timed out dialing to %v with %v total masquerades", m, totalMasquerades)
log.Debugf("Timed out dialing with %v total masquerades", totalMasquerades)
break dialLoop
default:
// okay
}

m, err := f.masqueradeToTry(masqueradesToTry, triedMasquerades)
if err != nil {
log.Errorf("No masquerades left to try")
break dialLoop
}
conn, masqueradeGood, err := f.dialMasquerade(m)
if err == nil {
return conn, m, masqueradeGood, nil
triedMasquerades[m] = true
if err != nil {
log.Debugf("Could not dial to %v: %v", m, err)
// As we're looping through the masquerades, each check takes time. As that's happening,
// other goroutines may be successfully vetting new masquerades, which will change the
// sorting. We want to make sure we're always trying the best masquerades first.
masqueradesToTry = f.masquerades.sortedCopy()
totalMasquerades = len(masqueradesToTry)
continue
}
return conn, m, masqueradeGood, nil
}

return nil, nil, nil, log.Errorf("could not dial any masquerade? tried %v", totalMasquerades)
}

func (f *fronted) masqueradeToTry(masquerades sortedMasquerades, triedMasquerades map[MasqueradeInterface]bool) (MasqueradeInterface, error) {
for _, m := range masquerades {
if triedMasquerades[m] {
continue
}
return m, nil
}
// This should be quite rare, as it means we've tried typically thousands of masquerades.
return nil, errors.New("no masquerades left to try")
}

func (f *fronted) dialMasquerade(m MasqueradeInterface) (net.Conn, func(bool) bool, error) {
log.Tracef("Dialing to %v", m)

Expand Down
152 changes: 144 additions & 8 deletions fronted_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package fronted

import (
"context"
"crypto/x509"
"encoding/json"
"errors"
Expand Down Expand Up @@ -131,7 +132,7 @@ func TestVet(t *testing.T) {
t.Fatal("None of the default masquerades vetted successfully")
}

func TestLoadCandidates(t *testing.T) {
func TestLoadMasquerades(t *testing.T) {
providers := testProviders()

expected := make(map[Masquerade]bool)
Expand All @@ -141,12 +142,12 @@ func TestLoadCandidates(t *testing.T) {
}
}

newMasquerades := loadMasquerades(providers, len(expected))

d := &fronted{
masquerades: make(sortedMasquerades, 0, len(expected)),
masquerades: newMasquerades,
}

d.loadCandidates(providers)

actual := make(map[Masquerade]bool)
count := 0
for _, m := range d.masquerades {
Expand Down Expand Up @@ -901,14 +902,149 @@ func TestFindWorkingMasquerades(t *testing.T) {
}
}

func TestMasqueradeToTry(t *testing.T) {
min := time.Now().Add(-time.Minute)
hour := time.Now().Add(-time.Hour)
domain1 := newMockMasqueradeWithLastSuccess("domain1.com", "1.1.1.1", 0, true, min)
domain2 := newMockMasqueradeWithLastSuccess("domain2.com", "2.2.2.2", 0, true, hour)
tests := []struct {
name string
masquerades sortedMasquerades
triedMasquerades map[MasqueradeInterface]bool
expected MasqueradeInterface
}{
{
name: "No tried masquerades",
masquerades: sortedMasquerades{
domain1,
domain2,
},
triedMasquerades: map[MasqueradeInterface]bool{},
expected: domain1,
},
{
name: "Some tried masquerades",
masquerades: sortedMasquerades{
domain1,
domain2,
},
triedMasquerades: map[MasqueradeInterface]bool{
domain1: true,
},
expected: domain2,
},
{
name: "All masquerades tried",
masquerades: sortedMasquerades{
domain1,
domain2,
},
triedMasquerades: map[MasqueradeInterface]bool{
domain1: true,
domain2: true,
},
expected: nil,
},
{
name: "Empty masquerades list",
masquerades: sortedMasquerades{},
triedMasquerades: map[MasqueradeInterface]bool{},
expected: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := &fronted{}
masquerades := tt.masquerades.sortedCopy()
result, _ := f.masqueradeToTry(masquerades, tt.triedMasquerades)
assert.Equal(t, tt.expected, result)
})
}
}

func TestDialAll(t *testing.T) {
tests := []struct {
name string
masquerades []*mockMasquerade
expectedSuccessful bool
expectedMasquerades int
}{
{
name: "All successful",
masquerades: []*mockMasquerade{
newMockMasquerade("domain1.com", "1.1.1.1", 0, true),
newMockMasquerade("domain2.com", "2.2.2.2", 0, true),
newMockMasquerade("domain3.com", "3.3.3.3", 0, true),
newMockMasquerade("domain4.com", "4.4.4.4", 0, true),
},
expectedSuccessful: true,
},
{
name: "Some successful",
masquerades: []*mockMasquerade{
newMockMasquerade("domain1.com", "1.1.1.1", 0, true),
newMockMasquerade("domain2.com", "2.2.2.2", 1*time.Millisecond, false),
newMockMasquerade("domain3.com", "3.3.3.3", 0, true),
newMockMasquerade("domain4.com", "4.4.4.4", 1*time.Millisecond, false),
},
expectedSuccessful: true,
},
{
name: "None successful",
masquerades: []*mockMasquerade{
newMockMasquerade("domain1.com", "1.1.1.1", 1*time.Millisecond, false),
newMockMasquerade("domain2.com", "2.2.2.2", 1*time.Millisecond, false),
newMockMasquerade("domain3.com", "3.3.3.3", 1*time.Millisecond, false),
newMockMasquerade("domain4.com", "4.4.4.4", 1*time.Millisecond, false),
},
expectedSuccessful: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
d := &fronted{}
d.providers = make(map[string]*Provider)
d.providers["testProviderId"] = NewProvider(nil, "", nil, nil, nil, nil, nil)
d.masquerades = make(sortedMasquerades, len(tt.masquerades))
for i, m := range tt.masquerades {
d.masquerades[i] = m
}

ctx := context.Background()
conn, m, masqueradeGood, err := d.dialAll(ctx)

if tt.expectedSuccessful {
assert.NoError(t, err)
assert.NotNil(t, conn)
assert.NotNil(t, m)
assert.NotNil(t, masqueradeGood)
} else {
assert.Error(t, err)
assert.Nil(t, conn)
assert.Nil(t, m)
assert.Nil(t, masqueradeGood)
}
})
}
}

// Generate a mock of a MasqueradeInterface with a Dial method that can optionally
// return an error after a specified number of milliseconds.
func newMockMasquerade(domain string, ipAddress string, timeout time.Duration, passesCheck bool) *mockMasquerade {
return newMockMasqueradeWithLastSuccess(domain, ipAddress, timeout, passesCheck, time.Time{})
}

// Generate a mock of a MasqueradeInterface with a Dial method that can optionally
// return an error after a specified number of milliseconds.
func newMockMasqueradeWithLastSuccess(domain string, ipAddress string, timeout time.Duration, passesCheck bool, lastSucceededTime time.Time) *mockMasquerade {
return &mockMasquerade{
Domain: domain,
IpAddress: ipAddress,
timeout: timeout,
passesCheck: passesCheck,
Domain: domain,
IpAddress: ipAddress,
timeout: timeout,
passesCheck: passesCheck,
lastSucceededTime: lastSucceededTime,
}
}

Expand Down
Loading