From f23302ad5977b07808fd94597a40cba84152e977 Mon Sep 17 00:00:00 2001 From: WendelHime <6754291+WendelHime@users.noreply.github.com> Date: Thu, 22 Aug 2024 10:20:36 -0300 Subject: [PATCH] fix: adding SNI to peer certificate verification --- direct.go | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/direct.go b/direct.go index de9e330..f51e5c1 100644 --- a/direct.go +++ b/direct.go @@ -428,8 +428,8 @@ func (d *direct) dialServerWith(m *Masquerade) (net.Conn, error) { tlsConfig.ServerName = m.SNI tlsConfig.InsecureSkipVerify = true tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { - log.Tracef("verifying peer certificate for masquerade domain %s", m.Domain) - return verifyPeerCertificate(rawCerts, d.certPool, m.Domain) + log.Tracef("verifying peer certificate for masquerade domain [%s] and SNI [%s]", m.Domain, m.SNI) + return verifyPeerCertificate(rawCerts, d.certPool, m.Domain, m.SNI) } } @@ -447,14 +447,14 @@ func (d *direct) dialServerWith(m *Masquerade) (net.Conn, error) { ClientHelloID: d.clientHelloID, } conn, err := dialer.Dial("tcp", addr) - if err != nil && m != nil { err = fmt.Errorf("unable to dial masquerade %s: %s", m.Domain, err) + op.FailIf(err) } return conn, err } -func verifyPeerCertificate(rawCerts [][]byte, roots *x509.CertPool, domain string) error { +func verifyPeerCertificate(rawCerts [][]byte, roots *x509.CertPool, domain string, sni string) error { if len(rawCerts) == 0 { return fmt.Errorf("no certificates presented") } @@ -470,6 +470,13 @@ func verifyPeerCertificate(rawCerts [][]byte, roots *x509.CertPool, domain strin Intermediates: x509.NewCertPool(), } + sniOpts := x509.VerifyOptions{ + Roots: roots, + CurrentTime: time.Now(), + DNSName: sni, + Intermediates: x509.NewCertPool(), + } + for i := range rawCerts { if i == 0 { continue @@ -479,10 +486,12 @@ func verifyPeerCertificate(rawCerts [][]byte, roots *x509.CertPool, domain strin return fmt.Errorf("unable to parse intermediate certificate: %w", err) } masqueradeOpts.Intermediates.AddCert(crt) + sniOpts.Intermediates.AddCert(crt) } + _, sniErr := cert.Verify(sniOpts) _, masqueradeErr := cert.Verify(masqueradeOpts) - if masqueradeErr != nil { + if masqueradeErr != nil && sniErr != nil { return fmt.Errorf("certificate verification failed for masquerade: %w", masqueradeErr) }