Skip to content

Commit

Permalink
Refactor to make everything more testable
Browse files Browse the repository at this point in the history
  • Loading branch information
myleshorton committed Oct 23, 2024
1 parent aaeb3c1 commit 5985882
Show file tree
Hide file tree
Showing 5 changed files with 334 additions and 120 deletions.
6 changes: 3 additions & 3 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ func (d *direct) prepopulateMasquerades(cacheFile string) {
// update last succeeded status of masquerades based on cached values
for _, m := range d.masquerades {
for _, cm := range cachedMasquerades {
sameMasquerade := cm.ProviderID == m.ProviderID && cm.Domain == m.Domain && cm.IpAddress == m.IpAddress
cachedValueFresh := now.Sub(m.LastSucceeded) < d.maxAllowedCachedAge
sameMasquerade := cm.ProviderID == m.getProviderID() && cm.Domain == m.getDomain() && cm.IpAddress == m.getIpAddress()
cachedValueFresh := now.Sub(m.lastSucceeded()) < d.maxAllowedCachedAge
if sameMasquerade && cachedValueFresh {
m.LastSucceeded = cm.LastSucceeded
m.setLastSucceeded(cm.LastSucceeded)
}
}
}
Expand Down
153 changes: 41 additions & 112 deletions direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ import (

"github.com/getlantern/golog"
"github.com/getlantern/idletiming"
"github.com/getlantern/netx"
"github.com/getlantern/ops"
"github.com/getlantern/tlsdialer/v3"
)

const (
Expand Down Expand Up @@ -76,8 +74,8 @@ func (d *direct) loadCandidates(initial map[string]*Provider) {
}
}

func (d *direct) providerFor(m *masquerade) *Provider {
pid := m.ProviderID
func (d *direct) providerFor(m MasqueradeInterface) *Provider {
pid := m.getProviderID()
if pid == "" {
pid = d.defaultProviderID
}
Expand All @@ -92,34 +90,39 @@ func Vet(m *Masquerade, pool *x509.CertPool, testURL string) bool {
maxAllowedCachedAge: defaultMaxAllowedCachedAge,
maxCacheSize: defaultMaxCacheSize,
}
conn, _, err := d.doDial(m)
masq := &masquerade{Masquerade: *m}
conn, _, err := d.doDial(masq)
if err != nil {
return false
}
defer conn.Close()
return postCheck(conn, testURL)
return masq.postCheck(conn, testURL)
}

func (d *direct) findWorkingMasquerades() {
// vet masquerades in batches
const batchSize int = 25
var successful atomic.Uint32
for i := 0; i < len(d.masquerades) && successful.Load() < 4; i += batchSize {
var wg sync.WaitGroup
for j := i; j < i+batchSize && j < len(d.masquerades); j++ {
wg.Add(1)
go func(m *masquerade) {
defer wg.Done()
if d.vetMasquerade(m) {
successful.Add(1)
}
}(d.masquerades[j])
}
wg.Wait()
d.vetGroup(i, batchSize, &successful)
}
}

func (d *direct) vetGroup(start, batchSize int, successful *atomic.Uint32) {
var wg sync.WaitGroup
for j := start; j < start+batchSize && j < len(d.masquerades); j++ {
wg.Add(1)
go func(m MasqueradeInterface) {
defer wg.Done()
if d.vetMasquerade(m) {
successful.Add(1)
}
}(d.masquerades[j])
}
wg.Wait()
}

func (d *direct) vetMasquerade(m *masquerade) bool {
func (d *direct) vetMasquerade(m MasqueradeInterface) bool {
conn, masqueradeGood, err := d.dialMasquerade(m)
if err != nil {
log.Errorf("unexpected error vetting masquerades: %v", err)
Expand All @@ -133,10 +136,11 @@ func (d *direct) vetMasquerade(m *masquerade) bool {

provider := d.providerFor(m)
if provider == nil {
log.Debugf("Skipping masquerade with disabled/unknown provider id '%s'", m.ProviderID)
log.Debugf("Skipping masquerade with disabled/unknown provider id '%s' not in %v",
m.getProviderID(), d.providers)
return false
}
if !masqueradeGood(postCheck(conn, provider.TestURL)) {
if !masqueradeGood(m.postCheck(conn, provider.TestURL)) {
log.Debugf("Unsuccessful vetting with POST request, discarding masquerade")
return false
}
Expand All @@ -145,48 +149,6 @@ func (d *direct) vetMasquerade(m *masquerade) bool {
return true
}

// postCheck does a post with invalid data to verify domain-fronting works
func postCheck(conn net.Conn, testURL string) bool {
client := &http.Client{
Transport: frontedHTTPTransport(conn, true),
}
return doCheck(client, http.MethodPost, http.StatusAccepted, testURL)
}

func doCheck(client *http.Client, method string, expectedStatus int, u string) bool {
op := ops.Begin("check_masquerade")
defer op.End()

isPost := method == http.MethodPost
var requestBody io.Reader
if isPost {
requestBody = strings.NewReader("a")
}
req, _ := http.NewRequest(method, u, requestBody)
if isPost {
req.Header.Set("Content-Type", "application/json")
}
resp, err := client.Do(req)
if err != nil {
op.FailIf(err)
log.Debugf("Unsuccessful vetting with %v request, discarding masquerade: %v", method, err)
return false
}
if resp.Body != nil {
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
}
if resp.StatusCode != expectedStatus {
op.Set("response_status", resp.StatusCode)
op.Set("expected_status", expectedStatus)
msg := fmt.Sprintf("Unexpected response status vetting masquerade, expected %d got %d: %v", expectedStatus, resp.StatusCode, resp.Status)
op.FailIf(errors.New(msg))
log.Debug(msg)
return false
}
return true
}

// Do continually retries a given request until it succeeds because some
// fronting providers will return a 403 for some domains.
func (d *direct) RoundTrip(req *http.Request) (*http.Response, error) {
Expand Down Expand Up @@ -248,7 +210,7 @@ func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, e
}
provider := d.providerFor(m)
if provider == nil {
log.Debugf("Skipping masquerade with disabled/unknown provider '%s'", m.ProviderID)
log.Debugf("Skipping masquerade with disabled/unknown provider '%s'", m.getProviderID())
masqueradeGood(false)
continue
}
Expand All @@ -258,11 +220,12 @@ func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, e
// so it is returned as good.
conn.Close()
masqueradeGood(true)
err := fmt.Errorf("no domain fronting mapping for '%s'. Please add it to provider_map.yaml or equivalent for %s", m.ProviderID, originHost)
err := fmt.Errorf("no domain fronting mapping for '%s'. Please add it to provider_map.yaml or equivalent for %s",
m.getProviderID(), originHost)
op.FailIf(err)
return nil, nil, err
}
log.Debugf("Translated origin %s -> %s for provider %s...", originHost, frontedHost, m.ProviderID)
log.Debugf("Translated origin %s -> %s for provider %s...", originHost, frontedHost, m.getProviderID())

reqi, err := cloneRequestWith(req, frontedHost, getBody())
if err != nil {
Expand Down Expand Up @@ -299,12 +262,12 @@ func (d *direct) RoundTripHijack(req *http.Request) (*http.Response, net.Conn, e
}

// Dial dials out using all available masquerades until one succeeds.
func (d *direct) dialAll(ctx context.Context) (net.Conn, *masquerade, func(bool) bool, error) {
func (d *direct) dialAll(ctx context.Context) (net.Conn, MasqueradeInterface, func(bool) bool, error) {
conn, m, masqueradeGood, err := d.dialAllWith(ctx, d.masquerades)
return conn, m, masqueradeGood, err
}

func (d *direct) dialAllWith(ctx context.Context, masquerades sortedMasquerades) (net.Conn, *masquerade, func(bool) bool, error) {
func (d *direct) dialAllWith(ctx context.Context, masquerades sortedMasquerades) (net.Conn, MasqueradeInterface, func(bool) bool, error) {
// never take more than a minute trying to find a dialer
ctx, cancel := context.WithTimeout(ctx, 1*time.Minute)
defer cancel()
Expand All @@ -329,15 +292,15 @@ dialLoop:
return nil, nil, nil, log.Errorf("could not dial any masquerade? tried %v", totalMasquerades)
}

func (d *direct) dialMasquerade(m *masquerade) (net.Conn, func(bool) bool, error) {
func (d *direct) dialMasquerade(m MasqueradeInterface) (net.Conn, func(bool) bool, error) {
// check to see if we've timed out

log.Tracef("Dialing to %v", m)

// We do the full TLS connection here because in practice the domains at a given IP
// address can change frequently on CDNs, so the certificate may not match what
// we expect.
conn, retriable, err := d.doDial(&m.Masquerade)
conn, retriable, err := d.doDial(m)
masqueradeGood := func(good bool) bool {
if good {
m.markSucceeded()
Expand All @@ -357,16 +320,16 @@ func (d *direct) dialMasquerade(m *masquerade) (net.Conn, func(bool) bool, error
return conn, masqueradeGood, err
}

func (d *direct) doDial(m *Masquerade) (conn net.Conn, retriable bool, err error) {
func (d *direct) doDial(m MasqueradeInterface) (conn net.Conn, retriable bool, err error) {
op := ops.Begin("dial_masquerade")
defer op.End()
op.Set("masquerade_domain", m.Domain)
op.Set("masquerade_ip", m.IpAddress)
op.Set("masquerade_domain", m.getDomain())
op.Set("masquerade_ip", m.getIpAddress())

conn, err = d.dialServerWith(m)
if err != nil {
op.FailIf(err)
log.Debugf("Could not dial to %v, %v", m.IpAddress, err)
log.Debugf("Could not dial to %v, %v", m.getIpAddress(), err)
// Don't re-add this candidate if it's any certificate error, as that
// will just keep failing and will waste connections. We can't access the underlying
// error at this point so just look for "certificate" and "handshake".
Expand All @@ -389,50 +352,16 @@ func (d *direct) doDial(m *Masquerade) (conn net.Conn, retriable bool, err error
return
}

func (d *direct) dialServerWith(m *Masquerade) (net.Conn, error) {
func (d *direct) dialServerWith(m MasqueradeInterface) (net.Conn, error) {
op := ops.Begin("dial_server_with")
defer op.End()

op.Set("masquerade_domain", m.Domain)
op.Set("masquerade_ip", m.IpAddress)

tlsConfig := d.frontingTLSConfig(m)
dialTimeout := 10 * time.Second
addr := m.IpAddress
var sendServerNameExtension bool

if m.SNI != "" {
sendServerNameExtension = true

op.Set("arbitrary_sni", m.SNI)
tlsConfig.ServerName = m.SNI
tlsConfig.InsecureSkipVerify = true
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
var verifyHostname string
if m.VerifyHostname != nil {
verifyHostname = *m.VerifyHostname
op.Set("verify_hostname", verifyHostname)
}
return verifyPeerCertificate(rawCerts, d.certPool, verifyHostname)
}

}
op.Set("masquerade_domain", m.getDomain())
op.Set("masquerade_ip", m.getIpAddress())

_, _, err := net.SplitHostPort(addr)
if err != nil {
addr = net.JoinHostPort(addr, "443")
}

dialer := &tlsdialer.Dialer{
DoDial: netx.DialTimeout,
Timeout: dialTimeout,
SendServerName: sendServerNameExtension,
Config: tlsConfig,
ClientHelloID: d.clientHelloID,
}
conn, err := dialer.Dial("tcp", addr)
conn, err := m.dial(d.certPool, d.clientHelloID)
if err != nil && m != nil {
err = fmt.Errorf("unable to dial masquerade %s: %s", m.Domain, err)
err = fmt.Errorf("unable to dial masquerade %s: %s", m.getDomain(), err)
op.FailIf(err)
}
return conn, err
Expand Down
Loading

0 comments on commit 5985882

Please sign in to comment.