diff --git a/transport/tls/stream_dialer.go b/transport/tls/stream_dialer.go index 2a72d70a..23884bbd 100644 --- a/transport/tls/stream_dialer.go +++ b/transport/tls/stream_dialer.go @@ -17,6 +17,7 @@ package tls import ( "context" "crypto/tls" + "crypto/x509" "errors" "fmt" "net" @@ -71,17 +72,40 @@ func (d *StreamDialer) Dial(ctx context.Context, remoteAddr string) (transport.S // clientConfig encodes the parameters for a TLS client connection. type clientConfig struct { - ServerName string - NextProtos []string - SessionCache tls.ClientSessionCache + ServerName string + CertificateName string + NextProtos []string + SessionCache tls.ClientSessionCache } // ToStdConfig creates a [tls.Config] based on the configured parameters. func (cfg *clientConfig) ToStdConfig() *tls.Config { + certificateName := cfg.CertificateName + if certificateName == "" { + certificateName = cfg.ServerName + } return &tls.Config{ ServerName: cfg.ServerName, NextProtos: cfg.NextProtos, ClientSessionCache: cfg.SessionCache, + // Set InsecureSkipVerify to skip the default validation we are + // replacing. This will not disable VerifyConnection. + InsecureSkipVerify: true, + VerifyConnection: func(cs tls.ConnectionState) error { + // This replicates the logic in the standard library verification: + // https://cs.opensource.google/go/go/+/master:src/crypto/tls/handshake_client.go;l=982;drc=b5f87b5407916c4049a3158cc944cebfd7a883a9 + // And the documentation example: + // https://pkg.go.dev/crypto/tls#example-Config-VerifyConnection + opts := x509.VerifyOptions{ + DNSName: certificateName, + Intermediates: x509.NewCertPool(), + } + for _, cert := range cs.PeerCertificates[1:] { + opts.Intermediates.AddCert(cert) + } + _, err := cs.PeerCertificates[0].Verify(opts) + return err + }, } } @@ -131,3 +155,11 @@ func WithSessionCache(sessionCache tls.ClientSessionCache) ClientOption { config.SessionCache = sessionCache } } + +// WithCertificateName sets the hostname to be used for the certificate validation. +// If absent, defaults to SNI. +func WithCertificateName(hostname string) ClientOption { + return func(_ string, _ int, config *clientConfig) { + config.CertificateName = hostname + } +} diff --git a/transport/tls/stream_dialer_test.go b/transport/tls/stream_dialer_test.go index cbe23854..5fb4f70a 100644 --- a/transport/tls/stream_dialer_test.go +++ b/transport/tls/stream_dialer_test.go @@ -23,7 +23,7 @@ import ( ) func TestDomainFronting(t *testing.T) { - sd, err := NewStreamDialer(&transport.TCPStreamDialer{}, WithSNI("www.youtube.com")) + sd, err := NewStreamDialer(&transport.TCPStreamDialer{}, WithSNI("decoy.android.com"), WithCertificateName("www.youtube.com")) require.NoError(t, err) conn, err := sd.Dial(context.Background(), "www.google.com:443") require.NoError(t, err)