diff --git a/masquerade.go b/masquerade.go index 686ff93..914495a 100644 --- a/masquerade.go +++ b/masquerade.go @@ -1,9 +1,9 @@ package fronted import ( + "crypto/sha256" "encoding/json" "fmt" - "hash/crc32" "net" "net/http" "sort" @@ -125,18 +125,29 @@ func NewProvider(hosts map[string]string, testURL string, masquerades []*Masquer } for _, m := range masquerades { - var sni string - if d.SNIConfig != nil && d.SNIConfig.UseArbitrarySNIs { - // Ensure that we use a consistent SNI for a given combination of IP address and SNI set - crc32Hash := int(crc32.ChecksumIEEE([]byte(m.IpAddress))) - sni = d.SNIConfig.ArbitrarySNIs[crc32Hash%len(d.SNIConfig.ArbitrarySNIs)] - } + sni := generateSNI(d.SNIConfig, m) d.Masquerades = append(d.Masquerades, &Masquerade{Domain: m.Domain, IpAddress: m.IpAddress, SNI: sni}) } d.PassthroughPatterns = append(d.PassthroughPatterns, passthrough...) return d } +// generateSNI generates a SNI for the given domain and ip address +func generateSNI(config *SNIConfig, m *Masquerade) string { + if config != nil && m != nil && config.UseArbitrarySNIs && len(config.ArbitrarySNIs) > 0 { + // Ensure that we use a consistent SNI for a given combination of IP address and SNI set + hash := sha256.New() + hash.Write([]byte(m.IpAddress)) + checksum := int(hash.Sum(nil)[0]) + // making sure checksum is positive + if checksum < 0 { + checksum = -checksum + } + return config.ArbitrarySNIs[checksum%len(config.ArbitrarySNIs)] + } + return "" +} + // Lookup the host alias for the given hostname for this provider func (p *Provider) Lookup(hostname string) string { // only consider the host porition if given a port as well. diff --git a/masquerade_test.go b/masquerade_test.go new file mode 100644 index 0000000..29a9475 --- /dev/null +++ b/masquerade_test.go @@ -0,0 +1,70 @@ +package fronted + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGenerateSNI(t *testing.T) { + emptyMasquerade := new(Masquerade) + var tests = []struct { + name string + assert func(t *testing.T, actual string) + givenConfig *SNIConfig + givenMasquerade *Masquerade + }{ + { + name: "should return a empty string when given SNI config is nil", + givenConfig: nil, + givenMasquerade: emptyMasquerade, + assert: func(t *testing.T, actual string) { + assert.Empty(t, actual) + }, + }, + { + name: "should return a empty string when given SNI config is not nil and UseArbitrarySNIs is false", + givenConfig: &SNIConfig{ + UseArbitrarySNIs: false, + }, + givenMasquerade: emptyMasquerade, + assert: func(t *testing.T, actual string) { + assert.Empty(t, actual) + }, + }, + { + name: "should return a empty SNI when the list of arbitrary SNIs is empty", + givenConfig: &SNIConfig{ + UseArbitrarySNIs: true, + ArbitrarySNIs: []string{}, + }, + givenMasquerade: &Masquerade{ + IpAddress: "1.1.1.1", + Domain: "randomdomain.net", + }, + assert: func(t *testing.T, actual string) { + assert.Empty(t, actual) + }, + }, + { + name: "should return a SNI when given SNI config is not nil and UseArbitrarySNIs is true", + givenConfig: &SNIConfig{ + UseArbitrarySNIs: true, + ArbitrarySNIs: []string{"sni1.com", "sni2.com"}, + }, + givenMasquerade: &Masquerade{ + IpAddress: "1.1.1.1", + Domain: "randomdomain.net", + }, + assert: func(t *testing.T, actual string) { + assert.NotEmpty(t, actual) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := generateSNI(tt.givenConfig, tt.givenMasquerade) + tt.assert(t, actual) + }) + } +}