Skip to content

Commit

Permalink
fix: replacing IntN old reference, reverting expected argument type t…
Browse files Browse the repository at this point in the history
…o *Masquerade instead of masquerade and add function for finding provider from a given masquerade
  • Loading branch information
WendelHime committed Jul 17, 2024
1 parent 547471b commit 878cb7b
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (d *direct) loadCandidates(initial map[string]*Provider) {
// ('inside-out' Fisher-Yates)
sh := make([]*Masquerade, size)
for i := 0; i < size; i++ {
j := rand.Intn(i + 1) // 0 <= j <= i
j := rand.IntN(i + 1) // 0 <= j <= i
sh[i] = sh[j]
sh[j] = arr[i]
}
Expand Down Expand Up @@ -356,7 +356,7 @@ dialLoop:
// We do the full TLS connection here because in practice the domains at a given IP
// address can change frequently on CDNs, so the certificate may not match what
// we expect.
conn, retriable, err := d.doDial(m)
conn, retriable, err := d.doDial(&m.Masquerade)
masqueradeGood := func(good bool) bool {
if good {
m.markSucceeded()
Expand All @@ -378,7 +378,7 @@ dialLoop:
return nil, nil, nil, errors.New("could not dial any masquerade?")
}

func (d *direct) doDial(m *masquerade) (conn net.Conn, retriable bool, err error) {
func (d *direct) doDial(m *Masquerade) (conn net.Conn, retriable bool, err error) {
op := ops.Begin("dial_masquerade")
defer op.End()
op.Set("masquerade_domain", m.Domain)
Expand Down Expand Up @@ -410,7 +410,7 @@ func (d *direct) doDial(m *masquerade) (conn net.Conn, retriable bool, err error
return
}

func (d *direct) dialServerWith(m *masquerade) (net.Conn, error) {
func (d *direct) dialServerWith(m *Masquerade) (net.Conn, error) {
tlsConfig := d.frontingTLSConfig(m)
dialTimeout := 10 * time.Second
sendServerNameExtension := false
Expand Down Expand Up @@ -465,11 +465,20 @@ func (d *direct) verifyPeerCertificate(domain string, rawCerts [][]byte, _ [][]*
return nil
}

func (d *direct) findProviderFromMasquerade(m *Masquerade) *Provider {
for _, masquerade := range d.masquerades {
if masquerade.Domain == m.Domain && masquerade.IpAddress == m.IpAddress {
return d.providers[masquerade.ProviderID]
}
}
return nil
}

// frontingTLSConfig builds a tls.Config for dialing the fronting domain. This is to establish the
// initial TCP connection to the CDN.
func (d *direct) frontingTLSConfig(m *masquerade) *tls.Config {
provider := d.providers[m.ProviderID]
if provider.SNIConfig.UseArbitrarySNIs {
func (d *direct) frontingTLSConfig(m *Masquerade) *tls.Config {
provider := d.findProviderFromMasquerade(m)
if provider != nil && provider.SNIConfig != nil && provider.SNIConfig.UseArbitrarySNIs {
randomSNIIndex := rand.IntN(len(provider.SNIConfig.ArbitrarySNIs))
sniDomain := provider.SNIConfig.ArbitrarySNIs[randomSNIIndex]
return &tls.Config{
Expand Down

0 comments on commit 878cb7b

Please sign in to comment.