diff --git a/direct.go b/direct.go index 890876a..d9e167b 100644 --- a/direct.go +++ b/direct.go @@ -65,7 +65,7 @@ func (d *direct) loadCandidates(initial map[string]*Provider) { // ('inside-out' Fisher-Yates) sh := make([]*Masquerade, size) for i := 0; i < size; i++ { - j := rand.Intn(i + 1) // 0 <= j <= i + j := rand.IntN(i + 1) // 0 <= j <= i sh[i] = sh[j] sh[j] = arr[i] } @@ -356,7 +356,7 @@ dialLoop: // We do the full TLS connection here because in practice the domains at a given IP // address can change frequently on CDNs, so the certificate may not match what // we expect. - conn, retriable, err := d.doDial(m) + conn, retriable, err := d.doDial(&m.Masquerade) masqueradeGood := func(good bool) bool { if good { m.markSucceeded() @@ -378,7 +378,7 @@ dialLoop: return nil, nil, nil, errors.New("could not dial any masquerade?") } -func (d *direct) doDial(m *masquerade) (conn net.Conn, retriable bool, err error) { +func (d *direct) doDial(m *Masquerade) (conn net.Conn, retriable bool, err error) { op := ops.Begin("dial_masquerade") defer op.End() op.Set("masquerade_domain", m.Domain) @@ -410,7 +410,7 @@ func (d *direct) doDial(m *masquerade) (conn net.Conn, retriable bool, err error return } -func (d *direct) dialServerWith(m *masquerade) (net.Conn, error) { +func (d *direct) dialServerWith(m *Masquerade) (net.Conn, error) { tlsConfig := d.frontingTLSConfig(m) dialTimeout := 10 * time.Second sendServerNameExtension := false @@ -465,11 +465,20 @@ func (d *direct) verifyPeerCertificate(domain string, rawCerts [][]byte, _ [][]* return nil } +func (d *direct) findProviderFromMasquerade(m *Masquerade) *Provider { + for _, masquerade := range d.masquerades { + if masquerade.Domain == m.Domain && masquerade.IpAddress == m.IpAddress { + return d.providers[masquerade.ProviderID] + } + } + return nil +} + // frontingTLSConfig builds a tls.Config for dialing the fronting domain. This is to establish the // initial TCP connection to the CDN. -func (d *direct) frontingTLSConfig(m *masquerade) *tls.Config { - provider := d.providers[m.ProviderID] - if provider.SNIConfig.UseArbitrarySNIs { +func (d *direct) frontingTLSConfig(m *Masquerade) *tls.Config { + provider := d.findProviderFromMasquerade(m) + if provider != nil && provider.SNIConfig != nil && provider.SNIConfig.UseArbitrarySNIs { randomSNIIndex := rand.IntN(len(provider.SNIConfig.ArbitrarySNIs)) sniDomain := provider.SNIConfig.ArbitrarySNIs[randomSNIIndex] return &tls.Config{