Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding SNI to peer certificate verification #43

Merged
merged 11 commits into from
Aug 22, 2024
6 changes: 3 additions & 3 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package fronted

import (
"encoding/json"
"io/ioutil"
"os"
"time"
)

Expand All @@ -12,7 +12,7 @@ func (d *direct) initCaching(cacheFile string) {
}

func (d *direct) prepopulateMasquerades(cacheFile string) {
bytes, err := ioutil.ReadFile(cacheFile)
bytes, err := os.ReadFile(cacheFile)
if err != nil {
// This is not a big deal since we'll just fill the cache later
log.Debugf("ignorable error: Unable to read cache file for prepopulation: %v", err)
Expand Down Expand Up @@ -84,7 +84,7 @@ func (d *direct) updateCache(cacheFile string) {
log.Errorf("Unable to marshal cache to JSON: %v", err)
return
}
err = ioutil.WriteFile(cacheFile, b, 0644)
err = os.WriteFile(cacheFile, b, 0644)
if err != nil {
log.Errorf("Unable to save cache to disk: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ func TestCaching(t *testing.T) {
cloudsackID := "cloudsack"

providers := map[string]*Provider{
testProviderID: NewProvider(nil, "", nil, nil, nil, nil),
cloudsackID: NewProvider(nil, "", nil, nil, nil, nil),
testProviderID: NewProvider(nil, "", nil, nil, nil, nil, nil),
cloudsackID: NewProvider(nil, "", nil, nil, nil, nil, nil),
}

makeDirect := func() *direct {
Expand Down
2 changes: 1 addition & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func (fctx *FrontingContext) ConfigureWithHello(pool *x509.CertPool, providers m

// copy providers
for k, p := range providers {
d.providers[k] = NewProvider(p.HostAliases, p.TestURL, p.Masquerades, p.Validator, p.PassthroughPatterns, p.SNIConfig)
d.providers[k] = NewProvider(p.HostAliases, p.TestURL, p.Masquerades, p.Validator, p.PassthroughPatterns, p.SNIConfig, p.VerifyHostname)
}

d.loadCandidates(d.providers)
Expand Down
48 changes: 28 additions & 20 deletions direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,12 @@ func (d *direct) dialServerWith(m *Masquerade) (net.Conn, error) {
tlsConfig.ServerName = m.SNI
tlsConfig.InsecureSkipVerify = true
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
log.Tracef("verifying peer certificate for masquerade domain %s", m.Domain)
return verifyPeerCertificate(rawCerts, d.certPool, m.Domain)
var verifyHostname string
if m.VerifyHostname != nil {
verifyHostname = *m.VerifyHostname
op.Set("verify_hostname", verifyHostname)
}
return verifyPeerCertificate(rawCerts, d.certPool, verifyHostname)
}

}
Expand All @@ -447,9 +451,9 @@ func (d *direct) dialServerWith(m *Masquerade) (net.Conn, error) {
ClientHelloID: d.clientHelloID,
}
conn, err := dialer.Dial("tcp", addr)

if err != nil && m != nil {
err = fmt.Errorf("unable to dial masquerade %s: %s", m.Domain, err)
op.FailIf(err)
}
return conn, err
}
Expand All @@ -463,13 +467,7 @@ func verifyPeerCertificate(rawCerts [][]byte, roots *x509.CertPool, domain strin
return fmt.Errorf("unable to parse certificate: %w", err)
}

masqueradeOpts := x509.VerifyOptions{
Roots: roots,
CurrentTime: time.Now(),
DNSName: domain,
Intermediates: x509.NewCertPool(),
}

opts := []x509.VerifyOptions{generateVerifyOptions(roots, domain)}
for i := range rawCerts {
if i == 0 {
continue
Expand All @@ -478,24 +476,34 @@ func verifyPeerCertificate(rawCerts [][]byte, roots *x509.CertPool, domain strin
if err != nil {
return fmt.Errorf("unable to parse intermediate certificate: %w", err)
}
masqueradeOpts.Intermediates.AddCert(crt)

for _, opt := range opts {
opt.Intermediates.AddCert(crt)
}
}

_, masqueradeErr := cert.Verify(masqueradeOpts)
if masqueradeErr != nil {
return fmt.Errorf("certificate verification failed for masquerade: %w", masqueradeErr)
var verificationErrors error
for _, opt := range opts {
_, err := cert.Verify(opt)
if err != nil {
verificationErrors = errors.Join(verificationErrors, err)
}
}

if verificationErrors != nil {
return fmt.Errorf("certificate verification failed: %w", verificationErrors)
}

return nil
}

func (d *direct) findProviderFromMasquerade(m *Masquerade) *Provider {
for _, masquerade := range d.masquerades {
if masquerade.Domain == m.Domain && masquerade.IpAddress == m.IpAddress {
return d.providers[masquerade.ProviderID]
}
func generateVerifyOptions(roots *x509.CertPool, domain string) x509.VerifyOptions {
return x509.VerifyOptions{
Roots: roots,
CurrentTime: time.Now(),
DNSName: domain,
Intermediates: x509.NewCertPool(),
}
return nil
}

// frontingTLSConfig builds a tls.Config for dialing the fronting domain. This is to establish the
Expand Down
86 changes: 81 additions & 5 deletions direct_test.go

Large diffs are not rendered by default.

12 changes: 10 additions & 2 deletions masquerade.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ type Masquerade struct {

// SNI: the SNI to use for this masquerade
SNI string

// VerifyHostname is used for checking if the certificate for a given hostname is valid.
// This is used for verifying if the peer certificate for the hostnames that are being fronted are valid.
VerifyHostname *string
}

type masquerade struct {
Expand Down Expand Up @@ -103,6 +107,10 @@ type Provider struct {
// detects a failure for a given masquerade, it is discarded.
// The default validator is used if nil.
Validator ResponseValidator

// VerifyHostname is used for checking if the certificate for a given hostname is valid.
// This attribute is only being defined here so it can be sent to the masquerade struct later.
VerifyHostname *string
}

type SNIConfig struct {
Expand All @@ -111,7 +119,7 @@ type SNIConfig struct {
}

// Create a Provider with the given details
func NewProvider(hosts map[string]string, testURL string, masquerades []*Masquerade, validator ResponseValidator, passthrough []string, sniConfig *SNIConfig) *Provider {
func NewProvider(hosts map[string]string, testURL string, masquerades []*Masquerade, validator ResponseValidator, passthrough []string, sniConfig *SNIConfig, verifyHostname *string) *Provider {
d := &Provider{
HostAliases: make(map[string]string),
TestURL: testURL,
Expand All @@ -126,7 +134,7 @@ func NewProvider(hosts map[string]string, testURL string, masquerades []*Masquer

for _, m := range masquerades {
sni := generateSNI(d.SNIConfig, m)
d.Masquerades = append(d.Masquerades, &Masquerade{Domain: m.Domain, IpAddress: m.IpAddress, SNI: sni})
d.Masquerades = append(d.Masquerades, &Masquerade{Domain: m.Domain, IpAddress: m.IpAddress, SNI: sni, VerifyHostname: verifyHostname})
}
d.PassthroughPatterns = append(d.PassthroughPatterns, passthrough...)
return d
Expand Down
61 changes: 61 additions & 0 deletions masquerade_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,72 @@
package fronted

import (
"net/http"
"testing"

"github.com/stretchr/testify/assert"
)

func TestNewProvider(t *testing.T) {
verifyHostname := "verifyHostname.com"
var tests = []struct {
name string
givenHosts map[string]string
givenTestURL string
givenMasquerades []*Masquerade
givenValidator ResponseValidator
givenPassthrough []string
givenSNIConfig *SNIConfig
givenVerifyHostname *string
assert func(t *testing.T, actual *Provider)
}{
{
name: "should return a new provider without host aliases, masquerades and passthrough",
givenHosts: map[string]string{},
givenTestURL: "http://test.com",
assert: func(t *testing.T, actual *Provider) {
assert.Empty(t, actual.HostAliases)
assert.Empty(t, actual.Masquerades)
assert.Empty(t, actual.PassthroughPatterns)
assert.Equal(t, "http://test.com", actual.TestURL)
assert.Nil(t, actual.Validator)
assert.Nil(t, actual.SNIConfig)
},
},
{
name: "should return a new provider with host aliases, masquerades and passthrough",
givenHosts: map[string]string{"host1": "alias1", "host2": "alias2"},
givenTestURL: "http://test.com",
givenMasquerades: []*Masquerade{{Domain: "domain1", IpAddress: "127.0.0.1"}},
givenValidator: func(*http.Response) error { return nil },
givenPassthrough: []string{"passthrough1", "passthrough2"},
givenSNIConfig: &SNIConfig{
UseArbitrarySNIs: true,
ArbitrarySNIs: []string{"sni1.com", "sni2.com"},
},
givenVerifyHostname: &verifyHostname,
assert: func(t *testing.T, actual *Provider) {
assert.Equal(t, "http://test.com", actual.TestURL)
assert.Equal(t, "alias1", actual.HostAliases["host1"])
assert.Equal(t, "alias2", actual.HostAliases["host2"])
assert.Equal(t, 1, len(actual.Masquerades))
assert.Equal(t, "domain1", actual.Masquerades[0].Domain)
assert.Equal(t, "127.0.0.1", actual.Masquerades[0].IpAddress)
assert.Equal(t, "sni1.com", actual.Masquerades[0].SNI)
assert.Equal(t, verifyHostname, *actual.Masquerades[0].VerifyHostname)
assert.Equal(t, 2, len(actual.PassthroughPatterns))
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
actual := NewProvider(tt.givenHosts, tt.givenTestURL, tt.givenMasquerades, tt.givenValidator, tt.givenPassthrough, tt.givenSNIConfig, tt.givenVerifyHostname)
tt.assert(t, actual)
})
}
}

func TestGenerateSNI(t *testing.T) {
emptyMasquerade := new(Masquerade)
var tests = []struct {
Expand Down
6 changes: 3 additions & 3 deletions test_support.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,17 @@ func trustedCACerts(t *testing.T) *x509.CertPool {

func testProviders() map[string]*Provider {
return map[string]*Provider{
testProviderID: NewProvider(testHosts, pingTestURL, testMasquerades, nil, nil, nil),
testProviderID: NewProvider(testHosts, pingTestURL, testMasquerades, nil, nil, nil, nil),
}
}

func testProvidersWithHosts(hosts map[string]string) map[string]*Provider {
return map[string]*Provider{
testProviderID: NewProvider(hosts, pingTestURL, testMasquerades, nil, nil, nil),
testProviderID: NewProvider(hosts, pingTestURL, testMasquerades, nil, nil, nil, nil),
}
}
func testAkamaiProvidersWithHosts(hosts map[string]string, sniConfig *SNIConfig) map[string]*Provider {
return map[string]*Provider{
testProviderID: NewProvider(hosts, pingTestURL, DefaultAkamaiMasquerades, nil, nil, sniConfig),
testProviderID: NewProvider(hosts, pingTestURL, DefaultAkamaiMasquerades, nil, nil, sniConfig, nil),
}
}
Loading