Skip to content

Commit

Permalink
Adding verifyHost to peer certificate verification (#43)
Browse files Browse the repository at this point in the history
* fix: replacing crc32 by sha256 hash sum and making sure it doesn't return a negative value

* fix: adding SNI to peer certificate verification

* fix: update error message

* fix: add test for verifying peer certificate

* chore: removing unused function

* feat: adding verify host attribute

* fix: add missing verify host to context.go

* fix: add missing verify host to cache_test.go and remove deprecated references from cache.go

* feat: use verify host and update tests

* chore: removing log trace and add field to operation so we can see the received verifyHost
  • Loading branch information
WendelHime authored Aug 22, 2024
1 parent 6976132 commit 6e97652
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 36 deletions.
6 changes: 3 additions & 3 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package fronted

import (
"encoding/json"
"io/ioutil"
"os"
"time"
)

Expand All @@ -12,7 +12,7 @@ func (d *direct) initCaching(cacheFile string) {
}

func (d *direct) prepopulateMasquerades(cacheFile string) {
bytes, err := ioutil.ReadFile(cacheFile)
bytes, err := os.ReadFile(cacheFile)
if err != nil {
// This is not a big deal since we'll just fill the cache later
log.Debugf("ignorable error: Unable to read cache file for prepopulation: %v", err)
Expand Down Expand Up @@ -84,7 +84,7 @@ func (d *direct) updateCache(cacheFile string) {
log.Errorf("Unable to marshal cache to JSON: %v", err)
return
}
err = ioutil.WriteFile(cacheFile, b, 0644)
err = os.WriteFile(cacheFile, b, 0644)
if err != nil {
log.Errorf("Unable to save cache to disk: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ func TestCaching(t *testing.T) {
cloudsackID := "cloudsack"

providers := map[string]*Provider{
testProviderID: NewProvider(nil, "", nil, nil, nil, nil),
cloudsackID: NewProvider(nil, "", nil, nil, nil, nil),
testProviderID: NewProvider(nil, "", nil, nil, nil, nil, nil),
cloudsackID: NewProvider(nil, "", nil, nil, nil, nil, nil),
}

makeDirect := func() *direct {
Expand Down
2 changes: 1 addition & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func (fctx *FrontingContext) ConfigureWithHello(pool *x509.CertPool, providers m

// copy providers
for k, p := range providers {
d.providers[k] = NewProvider(p.HostAliases, p.TestURL, p.Masquerades, p.Validator, p.PassthroughPatterns, p.SNIConfig)
d.providers[k] = NewProvider(p.HostAliases, p.TestURL, p.Masquerades, p.Validator, p.PassthroughPatterns, p.SNIConfig, p.VerifyHostname)
}

d.loadCandidates(d.providers)
Expand Down
48 changes: 28 additions & 20 deletions direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,12 @@ func (d *direct) dialServerWith(m *Masquerade) (net.Conn, error) {
tlsConfig.ServerName = m.SNI
tlsConfig.InsecureSkipVerify = true
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
log.Tracef("verifying peer certificate for masquerade domain %s", m.Domain)
return verifyPeerCertificate(rawCerts, d.certPool, m.Domain)
var verifyHostname string
if m.VerifyHostname != nil {
verifyHostname = *m.VerifyHostname
op.Set("verify_hostname", verifyHostname)
}
return verifyPeerCertificate(rawCerts, d.certPool, verifyHostname)
}

}
Expand All @@ -447,9 +451,9 @@ func (d *direct) dialServerWith(m *Masquerade) (net.Conn, error) {
ClientHelloID: d.clientHelloID,
}
conn, err := dialer.Dial("tcp", addr)

if err != nil && m != nil {
err = fmt.Errorf("unable to dial masquerade %s: %s", m.Domain, err)
op.FailIf(err)
}
return conn, err
}
Expand All @@ -463,13 +467,7 @@ func verifyPeerCertificate(rawCerts [][]byte, roots *x509.CertPool, domain strin
return fmt.Errorf("unable to parse certificate: %w", err)
}

masqueradeOpts := x509.VerifyOptions{
Roots: roots,
CurrentTime: time.Now(),
DNSName: domain,
Intermediates: x509.NewCertPool(),
}

opts := []x509.VerifyOptions{generateVerifyOptions(roots, domain)}
for i := range rawCerts {
if i == 0 {
continue
Expand All @@ -478,24 +476,34 @@ func verifyPeerCertificate(rawCerts [][]byte, roots *x509.CertPool, domain strin
if err != nil {
return fmt.Errorf("unable to parse intermediate certificate: %w", err)
}
masqueradeOpts.Intermediates.AddCert(crt)

for _, opt := range opts {
opt.Intermediates.AddCert(crt)
}
}

_, masqueradeErr := cert.Verify(masqueradeOpts)
if masqueradeErr != nil {
return fmt.Errorf("certificate verification failed for masquerade: %w", masqueradeErr)
var verificationErrors error
for _, opt := range opts {
_, err := cert.Verify(opt)
if err != nil {
verificationErrors = errors.Join(verificationErrors, err)
}
}

if verificationErrors != nil {
return fmt.Errorf("certificate verification failed: %w", verificationErrors)
}

return nil
}

func (d *direct) findProviderFromMasquerade(m *Masquerade) *Provider {
for _, masquerade := range d.masquerades {
if masquerade.Domain == m.Domain && masquerade.IpAddress == m.IpAddress {
return d.providers[masquerade.ProviderID]
}
func generateVerifyOptions(roots *x509.CertPool, domain string) x509.VerifyOptions {
return x509.VerifyOptions{
Roots: roots,
CurrentTime: time.Now(),
DNSName: domain,
Intermediates: x509.NewCertPool(),
}
return nil
}

// frontingTLSConfig builds a tls.Config for dialing the fronting domain. This is to establish the
Expand Down
86 changes: 81 additions & 5 deletions direct_test.go

Large diffs are not rendered by default.

12 changes: 10 additions & 2 deletions masquerade.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ type Masquerade struct {

// SNI: the SNI to use for this masquerade
SNI string

// VerifyHostname is used for checking if the certificate for a given hostname is valid.
// This is used for verifying if the peer certificate for the hostnames that are being fronted are valid.
VerifyHostname *string
}

type masquerade struct {
Expand Down Expand Up @@ -103,6 +107,10 @@ type Provider struct {
// detects a failure for a given masquerade, it is discarded.
// The default validator is used if nil.
Validator ResponseValidator

// VerifyHostname is used for checking if the certificate for a given hostname is valid.
// This attribute is only being defined here so it can be sent to the masquerade struct later.
VerifyHostname *string
}

type SNIConfig struct {
Expand All @@ -111,7 +119,7 @@ type SNIConfig struct {
}

// Create a Provider with the given details
func NewProvider(hosts map[string]string, testURL string, masquerades []*Masquerade, validator ResponseValidator, passthrough []string, sniConfig *SNIConfig) *Provider {
func NewProvider(hosts map[string]string, testURL string, masquerades []*Masquerade, validator ResponseValidator, passthrough []string, sniConfig *SNIConfig, verifyHostname *string) *Provider {
d := &Provider{
HostAliases: make(map[string]string),
TestURL: testURL,
Expand All @@ -126,7 +134,7 @@ func NewProvider(hosts map[string]string, testURL string, masquerades []*Masquer

for _, m := range masquerades {
sni := generateSNI(d.SNIConfig, m)
d.Masquerades = append(d.Masquerades, &Masquerade{Domain: m.Domain, IpAddress: m.IpAddress, SNI: sni})
d.Masquerades = append(d.Masquerades, &Masquerade{Domain: m.Domain, IpAddress: m.IpAddress, SNI: sni, VerifyHostname: verifyHostname})
}
d.PassthroughPatterns = append(d.PassthroughPatterns, passthrough...)
return d
Expand Down
61 changes: 61 additions & 0 deletions masquerade_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,72 @@
package fronted

import (
"net/http"
"testing"

"github.com/stretchr/testify/assert"
)

func TestNewProvider(t *testing.T) {
verifyHostname := "verifyHostname.com"
var tests = []struct {
name string
givenHosts map[string]string
givenTestURL string
givenMasquerades []*Masquerade
givenValidator ResponseValidator
givenPassthrough []string
givenSNIConfig *SNIConfig
givenVerifyHostname *string
assert func(t *testing.T, actual *Provider)
}{
{
name: "should return a new provider without host aliases, masquerades and passthrough",
givenHosts: map[string]string{},
givenTestURL: "http://test.com",
assert: func(t *testing.T, actual *Provider) {
assert.Empty(t, actual.HostAliases)
assert.Empty(t, actual.Masquerades)
assert.Empty(t, actual.PassthroughPatterns)
assert.Equal(t, "http://test.com", actual.TestURL)
assert.Nil(t, actual.Validator)
assert.Nil(t, actual.SNIConfig)
},
},
{
name: "should return a new provider with host aliases, masquerades and passthrough",
givenHosts: map[string]string{"host1": "alias1", "host2": "alias2"},
givenTestURL: "http://test.com",
givenMasquerades: []*Masquerade{{Domain: "domain1", IpAddress: "127.0.0.1"}},
givenValidator: func(*http.Response) error { return nil },
givenPassthrough: []string{"passthrough1", "passthrough2"},
givenSNIConfig: &SNIConfig{
UseArbitrarySNIs: true,
ArbitrarySNIs: []string{"sni1.com", "sni2.com"},
},
givenVerifyHostname: &verifyHostname,
assert: func(t *testing.T, actual *Provider) {
assert.Equal(t, "http://test.com", actual.TestURL)
assert.Equal(t, "alias1", actual.HostAliases["host1"])
assert.Equal(t, "alias2", actual.HostAliases["host2"])
assert.Equal(t, 1, len(actual.Masquerades))
assert.Equal(t, "domain1", actual.Masquerades[0].Domain)
assert.Equal(t, "127.0.0.1", actual.Masquerades[0].IpAddress)
assert.Equal(t, "sni1.com", actual.Masquerades[0].SNI)
assert.Equal(t, verifyHostname, *actual.Masquerades[0].VerifyHostname)
assert.Equal(t, 2, len(actual.PassthroughPatterns))
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
actual := NewProvider(tt.givenHosts, tt.givenTestURL, tt.givenMasquerades, tt.givenValidator, tt.givenPassthrough, tt.givenSNIConfig, tt.givenVerifyHostname)
tt.assert(t, actual)
})
}
}

func TestGenerateSNI(t *testing.T) {
emptyMasquerade := new(Masquerade)
var tests = []struct {
Expand Down
6 changes: 3 additions & 3 deletions test_support.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,17 @@ func trustedCACerts(t *testing.T) *x509.CertPool {

func testProviders() map[string]*Provider {
return map[string]*Provider{
testProviderID: NewProvider(testHosts, pingTestURL, testMasquerades, nil, nil, nil),
testProviderID: NewProvider(testHosts, pingTestURL, testMasquerades, nil, nil, nil, nil),
}
}

func testProvidersWithHosts(hosts map[string]string) map[string]*Provider {
return map[string]*Provider{
testProviderID: NewProvider(hosts, pingTestURL, testMasquerades, nil, nil, nil),
testProviderID: NewProvider(hosts, pingTestURL, testMasquerades, nil, nil, nil, nil),
}
}
func testAkamaiProvidersWithHosts(hosts map[string]string, sniConfig *SNIConfig) map[string]*Provider {
return map[string]*Provider{
testProviderID: NewProvider(hosts, pingTestURL, DefaultAkamaiMasquerades, nil, nil, sniConfig),
testProviderID: NewProvider(hosts, pingTestURL, DefaultAkamaiMasquerades, nil, nil, sniConfig, nil),
}
}

0 comments on commit 6e97652

Please sign in to comment.