Skip to content

Commit

Permalink
Parallelize masquerade lookup
Browse files Browse the repository at this point in the history
  • Loading branch information
myleshorton committed Oct 22, 2024
1 parent 8501716 commit aaeb3c1
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 92 deletions.
2 changes: 1 addition & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (fctx *FrontingContext) ConfigureWithHello(pool *x509.CertPool, providers m
if cacheFile != "" {
d.initCaching(cacheFile)
}
go d.vet(numberToVetInitially)
d.findWorkingMasquerades()
fctx.instance.Set(d)
return nil
}
Expand Down
174 changes: 85 additions & 89 deletions direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"net/url"
"strings"
"sync"
"sync/atomic"
"time"

tls "github.com/refraction-networking/utls"
Expand All @@ -25,7 +26,6 @@ import (
)

const (
numberToVetInitially = 10
defaultMaxAllowedCachedAge = 24 * time.Hour
defaultMaxCacheSize = 1000
defaultCacheSaveInterval = 5 * time.Second
Expand Down Expand Up @@ -84,72 +84,65 @@ func (d *direct) providerFor(m *masquerade) *Provider {
return d.providers[pid]
}

// Vet vets the specified Masquerade, verifying certificate using the given CertPool
// Vet vets the specified Masquerade, verifying certificate using the given CertPool.
// This is used in genconfig.
func Vet(m *Masquerade, pool *x509.CertPool, testURL string) bool {
return vet(m, pool, testURL)
}

func vet(m *Masquerade, pool *x509.CertPool, testURL string) bool {
op := ops.Begin("vet_masquerade")
defer op.End()
op.Set("masquerade_domain", m.Domain)
op.Set("masquerade_ip", m.IpAddress)

d := &direct{
certPool: pool,
maxAllowedCachedAge: defaultMaxAllowedCachedAge,
maxCacheSize: defaultMaxCacheSize,
}
conn, _, err := d.doDial(m)
if err != nil {
op.FailIf(err)
return false
}
defer conn.Close()
return postCheck(conn, testURL)
}

func (d *direct) vet(numberToVet int) {
log.Debugf("Vetting %d initial candidates in series", numberToVet)
for i := 0; i < numberToVet; i++ {
d.vetOne()
}
}

func (d *direct) vetOne() {
// We're just testing the ability to connect here, destination site doesn't
// really matter
log.Debug("Vetting one")
unvettedMasquerades := make([]*masquerade, 0, len(d.masquerades))
for _, m := range d.masquerades {
if m.lastSucceeded().IsZero() {
unvettedMasquerades = append(unvettedMasquerades, m)
func (d *direct) findWorkingMasquerades() {
// vet masquerades in batches
const batchSize int = 25
var successful atomic.Uint32
for i := 0; i < len(d.masquerades) && successful.Load() < 4; i += batchSize {
var wg sync.WaitGroup
for j := i; j < i+batchSize && j < len(d.masquerades); j++ {
wg.Add(1)
go func(m *masquerade) {
defer wg.Done()
if d.vetMasquerade(m) {
successful.Add(1)
}
}(d.masquerades[j])
}
wg.Wait()
}
}

// Don't take more than 10 seconds to dial a masquerade for vetting
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

conn, m, masqueradeGood, err := d.dialWith(ctx, unvettedMasquerades)
func (d *direct) vetMasquerade(m *masquerade) bool {
conn, masqueradeGood, err := d.dialMasquerade(m)
if err != nil {
log.Errorf("unexpected error vetting masquerades: %v", err)
return
return false
}
defer conn.Close()
defer func() {
if conn != nil {
conn.Close()
}
}()

provider := d.providerFor(m)
if provider == nil {
log.Debugf("Skipping masquerade with disabled/unknown provider id '%s'", m.ProviderID)
return
return false
}

if !masqueradeGood(postCheck(conn, provider.TestURL)) {
log.Debugf("Unsuccessful vetting with POST request, discarding masquerade")
return
return false
}

log.Debug("Finished vetting one")
log.Debugf("Finished vetting one masquerade %v", m)
return true
}

// postCheck does a post with invalid data to verify domain-fronting works
Expand Down Expand Up @@ -187,7 +180,7 @@ func doCheck(client *http.Client, method string, expectedStatus int, u string) b
op.Set("response_status", resp.StatusCode)
op.Set("expected_status", expectedStatus)
msg := fmt.Sprintf("Unexpected response status vetting masquerade, expected %d got %d: %v", expectedStatus, resp.StatusCode, resp.Status)
op.FailIf(fmt.Errorf(msg))
op.FailIf(errors.New(msg))
log.Debug(msg)
return false
}
Expand Down Expand Up @@ -247,7 +240,7 @@ func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, e
log.Debugf("Retrying domain-fronted request, pass %d", i)
}

conn, m, masqueradeGood, err := d.dial(req.Context())
conn, m, masqueradeGood, err := d.dialAll(req.Context())
if err != nil {
// unable to find good masquerade, fail
op.FailIf(err)
Expand Down Expand Up @@ -305,36 +298,13 @@ func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, e
return nil, nil, op.FailIf(errors.New("could not complete request even with retries"))
}

func cloneRequestWith(req *http.Request, frontedHost string, body io.ReadCloser) (*http.Request, error) {
url := *req.URL
url.Host = frontedHost
r, err := http.NewRequest(req.Method, url.String(), body)
if err != nil {
return nil, err
}

for k, vs := range req.Header {
if !strings.EqualFold(k, "Host") {
v := make([]string, len(vs))
copy(v, vs)
r.Header[k] = v
}
}
return r, nil
}

// Dial dials out using a masquerade. If the available masquerade fails, it
// retries with others until it either succeeds or exhausts the available
// masquerades. If successful, it returns a connection to the masquerade,
// the selected masquerade, and a function that the caller can use to
// tell us whether the masquerade is good or not (i.e. if masquerade was good,
// keep it).
func (d *direct) dial(ctx context.Context) (net.Conn, *masquerade, func(bool) bool, error) {
conn, m, masqueradeGood, err := d.dialWith(ctx, d.masquerades)
// Dial dials out using all available masquerades until one succeeds.
func (d *direct) dialAll(ctx context.Context) (net.Conn, *masquerade, func(bool) bool, error) {
conn, m, masqueradeGood, err := d.dialAllWith(ctx, d.masquerades)
return conn, m, masqueradeGood, err
}

func (d *direct) dialWith(ctx context.Context, masquerades sortedMasquerades) (net.Conn, *masquerade, func(bool) bool, error) {
func (d *direct) dialAllWith(ctx context.Context, masquerades sortedMasquerades) (net.Conn, *masquerade, func(bool) bool, error) {
// never take more than a minute trying to find a dialer
ctx, cancel := context.WithTimeout(ctx, 1*time.Minute)
defer cancel()
Expand All @@ -343,42 +313,50 @@ func (d *direct) dialWith(ctx context.Context, masquerades sortedMasquerades) (n
totalMasquerades := len(masqueradesToTry)
dialLoop:
for _, m := range masqueradesToTry {
// check to see if we've timed out
select {
case <-ctx.Done():
log.Debugf("Timed out dialing to %v with %v total masquerades", m, totalMasquerades)
break dialLoop
default:
// okay
}

log.Tracef("Dialing to %v", m)

// We do the full TLS connection here because in practice the domains at a given IP
// address can change frequently on CDNs, so the certificate may not match what
// we expect.
conn, retriable, err := d.doDial(&m.Masquerade)
masqueradeGood := func(good bool) bool {
if good {
m.markSucceeded()
} else {
m.markFailed()
}
d.markCacheDirty()
return good
}
conn, masqueradeGood, err := d.dialMasquerade(m)
if err == nil {
log.Debug("Returning connection")
return conn, m, masqueradeGood, err
} else if !retriable {
log.Debugf("Dropping masquerade: non retryable error: %v", err)
masqueradeGood(false)
return conn, m, masqueradeGood, nil
}
}

return nil, nil, nil, log.Errorf("could not dial any masquerade? tried %v", totalMasquerades)
}

func (d *direct) dialMasquerade(m *masquerade) (net.Conn, func(bool) bool, error) {
// check to see if we've timed out

log.Tracef("Dialing to %v", m)

// We do the full TLS connection here because in practice the domains at a given IP
// address can change frequently on CDNs, so the certificate may not match what
// we expect.
conn, retriable, err := d.doDial(&m.Masquerade)
masqueradeGood := func(good bool) bool {
if good {
m.markSucceeded()
} else {
m.markFailed()
}
d.markCacheDirty()
return good
}
if err == nil {
log.Debug("Returning connection")
return conn, masqueradeGood, err
} else if !retriable {
log.Debugf("Dropping masquerade: non retryable error: %v", err)
masqueradeGood(false)
}
return conn, masqueradeGood, err
}

func (d *direct) doDial(m *Masquerade) (conn net.Conn, retriable bool, err error) {
op := ops.Begin("dial_masquerade")
defer op.End()
Expand Down Expand Up @@ -551,3 +529,21 @@ func (ddf *directTransport) RoundTrip(req *http.Request) (resp *http.Response, e
norm.URL.Scheme = "http"
return ddf.Transport.RoundTrip(norm)
}

func cloneRequestWith(req *http.Request, frontedHost string, body io.ReadCloser) (*http.Request, error) {
url := *req.URL
url.Host = frontedHost
r, err := http.NewRequest(req.Method, url.String(), body)
if err != nil {
return nil, err
}

for k, vs := range req.Header {
if !strings.EqualFold(k, "Host") {
v := make([]string, len(vs))
copy(v, vs)
r.Header[k] = v
}
}
return r, nil
}
4 changes: 2 additions & 2 deletions direct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ func TestDirectDomainFronting(t *testing.T) {
require.NoError(t, err, "Unable to create temp dir")
defer os.RemoveAll(dir)
cacheFile := filepath.Join(dir, "cachefile.2")
doTestDomainFronting(t, cacheFile, numberToVetInitially)
doTestDomainFronting(t, cacheFile, 10)
time.Sleep(defaultCacheSaveInterval * 2)
// Then try again, this time reusing the existing cacheFile but a corrupted version
corruptMasquerades(cacheFile)
doTestDomainFronting(t, cacheFile, numberToVetInitially)
doTestDomainFronting(t, cacheFile, 10)
}

func TestDirectDomainFrontingWithSNIConfig(t *testing.T) {
Expand Down

0 comments on commit aaeb3c1

Please sign in to comment.