Skip to content

Commit

Permalink
Merge pull request #49 from getlantern/myles/resort-masquerades
Browse files Browse the repository at this point in the history
Re-sort masquerades while iterating to always dial the best
  • Loading branch information
myleshorton authored Nov 20, 2024
2 parents 3d853e4 + 3cf2858 commit b3b97df
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 30 deletions.
68 changes: 46 additions & 22 deletions fronted.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,25 +62,25 @@ func newFronted(pool *x509.CertPool, providers map[string]*Provider,
return nil, fmt.Errorf("no masquerades found in providers")
}

// copy providers
providersCopy := make(map[string]*Provider, len(providers))
for k, p := range providers {
providersCopy[k] = NewProvider(p.HostAliases, p.TestURL, p.Masquerades, p.Validator, p.PassthroughPatterns, p.SNIConfig, p.VerifyHostname)
}

f := &fronted{
certPool: pool,
masquerades: make(sortedMasquerades, 0, size),
masquerades: loadMasquerades(providersCopy, size),
maxAllowedCachedAge: defaultMaxAllowedCachedAge,
maxCacheSize: defaultMaxCacheSize,
cacheSaveInterval: defaultCacheSaveInterval,
cacheDirty: make(chan interface{}, 1),
cacheClosed: make(chan interface{}),
defaultProviderID: defaultProviderID,
providers: make(map[string]*Provider),
providers: providersCopy,
clientHelloID: clientHelloID,
}

// copy providers
for k, p := range providers {
f.providers[k] = NewProvider(p.HostAliases, p.TestURL, p.Masquerades, p.Validator, p.PassthroughPatterns, p.SNIConfig, p.VerifyHostname)
}

f.loadCandidates(f.providers)
if cacheFile != "" {
f.initCaching(cacheFile)
}
Expand All @@ -89,14 +89,14 @@ func newFronted(pool *x509.CertPool, providers map[string]*Provider,
return f, nil
}

func (f *fronted) loadCandidates(initial map[string]*Provider) {
func loadMasquerades(initial map[string]*Provider, size int) sortedMasquerades {
log.Debugf("Loading candidates for %d providers", len(initial))
defer log.Debug("Finished loading candidates")

masquerades := make(sortedMasquerades, 0, size)
for key, p := range initial {
arr := p.Masquerades
size := len(arr)
log.Debugf("Adding %d candidates for %v", size, key)

// make a shuffled copy of arr
// ('inside-out' Fisher-Yates)
Expand All @@ -108,9 +108,10 @@ func (f *fronted) loadCandidates(initial map[string]*Provider) {
}

for _, c := range sh {
f.masquerades = append(f.masquerades, &masquerade{Masquerade: *c, ProviderID: key})
masquerades = append(masquerades, &masquerade{Masquerade: *c, ProviderID: key})
}
}
return masquerades
}

func (f *fronted) providerFor(m MasqueradeInterface) *Provider {
Expand Down Expand Up @@ -323,36 +324,59 @@ func (f *fronted) validateMasqueradeWithConn(req *http.Request, conn net.Conn, m

// Dial dials out using all available masquerades until one succeeds.
func (f *fronted) dialAll(ctx context.Context) (net.Conn, MasqueradeInterface, func(bool) bool, error) {
conn, m, masqueradeGood, err := f.dialAllWith(ctx, f.masquerades)
return conn, m, masqueradeGood, err
}

func (f *fronted) dialAllWith(ctx context.Context, masquerades sortedMasquerades) (net.Conn, MasqueradeInterface, func(bool) bool, error) {
defer func(op ops.Op) { op.End() }(ops.Begin("dial_all_with"))
defer func(op ops.Op) { op.End() }(ops.Begin("dial_all"))
// never take more than a minute trying to find a dialer
ctx, cancel := context.WithTimeout(ctx, 1*time.Minute)
defer cancel()

masqueradesToTry := masquerades.sortedCopy()
triedMasquerades := make(map[MasqueradeInterface]bool)
masqueradesToTry := f.masquerades.sortedCopy()
totalMasquerades := len(masqueradesToTry)
dialLoop:
for _, m := range masqueradesToTry {
// Loop through up to len(masqueradesToTry) times, trying each masquerade in turn.
// If the context is done, return an error.
for i := 0; i < totalMasquerades; i++ {
select {
case <-ctx.Done():
log.Debugf("Timed out dialing to %v with %v total masquerades", m, totalMasquerades)
log.Debugf("Timed out dialing with %v total masquerades", totalMasquerades)
break dialLoop
default:
// okay
}

m, err := f.masqueradeToTry(masqueradesToTry, triedMasquerades)
if err != nil {
log.Errorf("No masquerades left to try")
break dialLoop
}
conn, masqueradeGood, err := f.dialMasquerade(m)
if err == nil {
return conn, m, masqueradeGood, nil
triedMasquerades[m] = true
if err != nil {
log.Debugf("Could not dial to %v: %v", m, err)
// As we're looping through the masquerades, each check takes time. As that's happening,
// other goroutines may be successfully vetting new masquerades, which will change the
// sorting. We want to make sure we're always trying the best masquerades first.
masqueradesToTry = f.masquerades.sortedCopy()
totalMasquerades = len(masqueradesToTry)
continue
}
return conn, m, masqueradeGood, nil
}

return nil, nil, nil, log.Errorf("could not dial any masquerade? tried %v", totalMasquerades)
}

func (f *fronted) masqueradeToTry(masquerades sortedMasquerades, triedMasquerades map[MasqueradeInterface]bool) (MasqueradeInterface, error) {
for _, m := range masquerades {
if triedMasquerades[m] {
continue
}
return m, nil
}
// This should be quite rare, as it means we've tried typically thousands of masquerades.
return nil, errors.New("no masquerades left to try")
}

func (f *fronted) dialMasquerade(m MasqueradeInterface) (net.Conn, func(bool) bool, error) {
log.Tracef("Dialing to %v", m)

Expand Down
152 changes: 144 additions & 8 deletions fronted_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package fronted

import (
"context"
"crypto/x509"
"encoding/json"
"errors"
Expand Down Expand Up @@ -131,7 +132,7 @@ func TestVet(t *testing.T) {
t.Fatal("None of the default masquerades vetted successfully")
}

func TestLoadCandidates(t *testing.T) {
func TestLoadMasquerades(t *testing.T) {
providers := testProviders()

expected := make(map[Masquerade]bool)
Expand All @@ -141,12 +142,12 @@ func TestLoadCandidates(t *testing.T) {
}
}

newMasquerades := loadMasquerades(providers, len(expected))

d := &fronted{
masquerades: make(sortedMasquerades, 0, len(expected)),
masquerades: newMasquerades,
}

d.loadCandidates(providers)

actual := make(map[Masquerade]bool)
count := 0
for _, m := range d.masquerades {
Expand Down Expand Up @@ -901,14 +902,149 @@ func TestFindWorkingMasquerades(t *testing.T) {
}
}

func TestMasqueradeToTry(t *testing.T) {
min := time.Now().Add(-time.Minute)
hour := time.Now().Add(-time.Hour)
domain1 := newMockMasqueradeWithLastSuccess("domain1.com", "1.1.1.1", 0, true, min)
domain2 := newMockMasqueradeWithLastSuccess("domain2.com", "2.2.2.2", 0, true, hour)
tests := []struct {
name string
masquerades sortedMasquerades
triedMasquerades map[MasqueradeInterface]bool
expected MasqueradeInterface
}{
{
name: "No tried masquerades",
masquerades: sortedMasquerades{
domain1,
domain2,
},
triedMasquerades: map[MasqueradeInterface]bool{},
expected: domain1,
},
{
name: "Some tried masquerades",
masquerades: sortedMasquerades{
domain1,
domain2,
},
triedMasquerades: map[MasqueradeInterface]bool{
domain1: true,
},
expected: domain2,
},
{
name: "All masquerades tried",
masquerades: sortedMasquerades{
domain1,
domain2,
},
triedMasquerades: map[MasqueradeInterface]bool{
domain1: true,
domain2: true,
},
expected: nil,
},
{
name: "Empty masquerades list",
masquerades: sortedMasquerades{},
triedMasquerades: map[MasqueradeInterface]bool{},
expected: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := &fronted{}
masquerades := tt.masquerades.sortedCopy()
result, _ := f.masqueradeToTry(masquerades, tt.triedMasquerades)
assert.Equal(t, tt.expected, result)
})
}
}

func TestDialAll(t *testing.T) {
tests := []struct {
name string
masquerades []*mockMasquerade
expectedSuccessful bool
expectedMasquerades int
}{
{
name: "All successful",
masquerades: []*mockMasquerade{
newMockMasquerade("domain1.com", "1.1.1.1", 0, true),
newMockMasquerade("domain2.com", "2.2.2.2", 0, true),
newMockMasquerade("domain3.com", "3.3.3.3", 0, true),
newMockMasquerade("domain4.com", "4.4.4.4", 0, true),
},
expectedSuccessful: true,
},
{
name: "Some successful",
masquerades: []*mockMasquerade{
newMockMasquerade("domain1.com", "1.1.1.1", 0, true),
newMockMasquerade("domain2.com", "2.2.2.2", 1*time.Millisecond, false),
newMockMasquerade("domain3.com", "3.3.3.3", 0, true),
newMockMasquerade("domain4.com", "4.4.4.4", 1*time.Millisecond, false),
},
expectedSuccessful: true,
},
{
name: "None successful",
masquerades: []*mockMasquerade{
newMockMasquerade("domain1.com", "1.1.1.1", 1*time.Millisecond, false),
newMockMasquerade("domain2.com", "2.2.2.2", 1*time.Millisecond, false),
newMockMasquerade("domain3.com", "3.3.3.3", 1*time.Millisecond, false),
newMockMasquerade("domain4.com", "4.4.4.4", 1*time.Millisecond, false),
},
expectedSuccessful: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
d := &fronted{}
d.providers = make(map[string]*Provider)
d.providers["testProviderId"] = NewProvider(nil, "", nil, nil, nil, nil, nil)
d.masquerades = make(sortedMasquerades, len(tt.masquerades))
for i, m := range tt.masquerades {
d.masquerades[i] = m
}

ctx := context.Background()
conn, m, masqueradeGood, err := d.dialAll(ctx)

if tt.expectedSuccessful {
assert.NoError(t, err)
assert.NotNil(t, conn)
assert.NotNil(t, m)
assert.NotNil(t, masqueradeGood)
} else {
assert.Error(t, err)
assert.Nil(t, conn)
assert.Nil(t, m)
assert.Nil(t, masqueradeGood)
}
})
}
}

// Generate a mock of a MasqueradeInterface with a Dial method that can optionally
// return an error after a specified number of milliseconds.
func newMockMasquerade(domain string, ipAddress string, timeout time.Duration, passesCheck bool) *mockMasquerade {
return newMockMasqueradeWithLastSuccess(domain, ipAddress, timeout, passesCheck, time.Time{})
}

// Generate a mock of a MasqueradeInterface with a Dial method that can optionally
// return an error after a specified number of milliseconds.
func newMockMasqueradeWithLastSuccess(domain string, ipAddress string, timeout time.Duration, passesCheck bool, lastSucceededTime time.Time) *mockMasquerade {
return &mockMasquerade{
Domain: domain,
IpAddress: ipAddress,
timeout: timeout,
passesCheck: passesCheck,
Domain: domain,
IpAddress: ipAddress,
timeout: timeout,
passesCheck: passesCheck,
lastSucceededTime: lastSucceededTime,
}
}

Expand Down

0 comments on commit b3b97df

Please sign in to comment.