Skip to content

Commit

Permalink
feat: check if there's a SNIConfig available for the provider and use…
Browse files Browse the repository at this point in the history
… it if there's a enabled SNIConfig; also adding function for verifying the certificate
  • Loading branch information
WendelHime committed Jul 17, 2024
1 parent 3062fd0 commit 0cebc16
Showing 1 changed file with 47 additions and 5 deletions.
52 changes: 47 additions & 5 deletions direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"fmt"
"io"
"io/ioutil"
"math/rand"
"math/rand/v2"

Check failure on line 11 in direct.go

View workflow job for this annotation

GitHub Actions / build

package math/rand/v2 is not in GOROOT (/opt/hostedtoolcache/go/1.18.10/x64/src/math/rand/v2)

Check failure on line 11 in direct.go

View workflow job for this annotation

GitHub Actions / build

package math/rand/v2 is not in GOROOT (/opt/hostedtoolcache/go/1.18.10/x64/src/math/rand/v2)
"net"
"net/http"
"net/url"
Expand Down Expand Up @@ -356,7 +356,7 @@ dialLoop:
// We do the full TLS connection here because in practice the domains at a given IP
// address can change frequently on CDNs, so the certificate may not match what
// we expect.
conn, retriable, err := d.doDial(&m.Masquerade)
conn, retriable, err := d.doDial(m)
masqueradeGood := func(good bool) bool {
if good {
m.markSucceeded()
Expand All @@ -378,7 +378,7 @@ dialLoop:
return nil, nil, nil, errors.New("could not dial any masquerade?")
}

func (d *direct) doDial(m *Masquerade) (conn net.Conn, retriable bool, err error) {
func (d *direct) doDial(m *masquerade) (conn net.Conn, retriable bool, err error) {
op := ops.Begin("dial_masquerade")
defer op.End()
op.Set("masquerade_domain", m.Domain)
Expand Down Expand Up @@ -410,7 +410,7 @@ func (d *direct) doDial(m *Masquerade) (conn net.Conn, retriable bool, err error
return
}

func (d *direct) dialServerWith(m *Masquerade) (net.Conn, error) {
func (d *direct) dialServerWith(m *masquerade) (net.Conn, error) {
tlsConfig := d.frontingTLSConfig(m)
dialTimeout := 10 * time.Second
sendServerNameExtension := false
Expand All @@ -436,9 +436,51 @@ func (d *direct) dialServerWith(m *Masquerade) (net.Conn, error) {
return conn, err
}

func (d *direct) verifyPeerCertificate(domain string, rawCerts [][]byte, _ [][]*x509.Certificate) error {
if len(rawCerts) == 0 {
return errors.New("no certificates provided")
}
cert, err := x509.ParseCertificate(rawCerts[0])
if err != nil {
return err
}
opts := x509.VerifyOptions{
Roots: d.certPool,
CurrentTime: time.Now(),
DNSName: domain,
Intermediates: x509.NewCertPool(),
}
for i := 1; i < len(rawCerts); i++ {
intermediate, err := x509.ParseCertificate(rawCerts[i])
if err != nil {
return err
}
opts.Intermediates.AddCert(intermediate)
}
_, err = cert.Verify(opts)
if err != nil {
return err
}

return nil
}

// frontingTLSConfig builds a tls.Config for dialing the fronting domain. This is to establish the
// initial TCP connection to the CDN.
func (d *direct) frontingTLSConfig(m *Masquerade) *tls.Config {
func (d *direct) frontingTLSConfig(m *masquerade) *tls.Config {
provider := d.providers[m.ProviderID]
if provider.SNIConfig.UseArbitrarySNIs {
randomSNIIndex := rand.IntN(len(provider.SNIConfig.ArbitrarySNIs))
sniDomain := provider.SNIConfig.ArbitrarySNIs[randomSNIIndex]
return &tls.Config{
InsecureSkipVerify: true,
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
return d.verifyPeerCertificate(m.Domain, rawCerts, verifiedChains)
},
ServerName: sniDomain,
RootCAs: d.certPool,
}
}
return &tls.Config{
ServerName: m.Domain,
RootCAs: d.certPool,
Expand Down

0 comments on commit 0cebc16

Please sign in to comment.