diff --git a/direct.go b/direct.go index 3ab43c7..890876a 100644 --- a/direct.go +++ b/direct.go @@ -8,7 +8,7 @@ import ( "fmt" "io" "io/ioutil" - "math/rand" + "math/rand/v2" "net" "net/http" "net/url" @@ -356,7 +356,7 @@ dialLoop: // We do the full TLS connection here because in practice the domains at a given IP // address can change frequently on CDNs, so the certificate may not match what // we expect. - conn, retriable, err := d.doDial(&m.Masquerade) + conn, retriable, err := d.doDial(m) masqueradeGood := func(good bool) bool { if good { m.markSucceeded() @@ -378,7 +378,7 @@ dialLoop: return nil, nil, nil, errors.New("could not dial any masquerade?") } -func (d *direct) doDial(m *Masquerade) (conn net.Conn, retriable bool, err error) { +func (d *direct) doDial(m *masquerade) (conn net.Conn, retriable bool, err error) { op := ops.Begin("dial_masquerade") defer op.End() op.Set("masquerade_domain", m.Domain) @@ -410,7 +410,7 @@ func (d *direct) doDial(m *Masquerade) (conn net.Conn, retriable bool, err error return } -func (d *direct) dialServerWith(m *Masquerade) (net.Conn, error) { +func (d *direct) dialServerWith(m *masquerade) (net.Conn, error) { tlsConfig := d.frontingTLSConfig(m) dialTimeout := 10 * time.Second sendServerNameExtension := false @@ -436,9 +436,51 @@ func (d *direct) dialServerWith(m *Masquerade) (net.Conn, error) { return conn, err } +func (d *direct) verifyPeerCertificate(domain string, rawCerts [][]byte, _ [][]*x509.Certificate) error { + if len(rawCerts) == 0 { + return errors.New("no certificates provided") + } + cert, err := x509.ParseCertificate(rawCerts[0]) + if err != nil { + return err + } + opts := x509.VerifyOptions{ + Roots: d.certPool, + CurrentTime: time.Now(), + DNSName: domain, + Intermediates: x509.NewCertPool(), + } + for i := 1; i < len(rawCerts); i++ { + intermediate, err := x509.ParseCertificate(rawCerts[i]) + if err != nil { + return err + } + opts.Intermediates.AddCert(intermediate) + } + _, err = cert.Verify(opts) + if err != nil { + return err + } + + return nil +} + // frontingTLSConfig builds a tls.Config for dialing the fronting domain. This is to establish the // initial TCP connection to the CDN. -func (d *direct) frontingTLSConfig(m *Masquerade) *tls.Config { +func (d *direct) frontingTLSConfig(m *masquerade) *tls.Config { + provider := d.providers[m.ProviderID] + if provider.SNIConfig.UseArbitrarySNIs { + randomSNIIndex := rand.IntN(len(provider.SNIConfig.ArbitrarySNIs)) + sniDomain := provider.SNIConfig.ArbitrarySNIs[randomSNIIndex] + return &tls.Config{ + InsecureSkipVerify: true, + VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + return d.verifyPeerCertificate(m.Domain, rawCerts, verifiedChains) + }, + ServerName: sniDomain, + RootCAs: d.certPool, + } + } return &tls.Config{ ServerName: m.Domain, RootCAs: d.certPool,