-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
bd06228
commit a0aa344
Showing
3 changed files
with
180 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
package crypto | ||
|
||
import ( | ||
"crypto/tls" | ||
"fmt" | ||
"strings" | ||
) | ||
|
||
// CipherSuite stores information about a cipher suite. | ||
// It shadows tls.CipherSuite but translates TLS versions to strings. | ||
type CipherSuite struct { | ||
ID uint16 `json:"id"` | ||
Name string `json:"name"` | ||
Insecure bool `json:"insecure"` | ||
TLS []string `json:"tls"` | ||
} | ||
|
||
// NewCipher translates tls.CipherSuite to our local type. | ||
func NewCipher(in *tls.CipherSuite) *CipherSuite { | ||
return &CipherSuite{ | ||
ID: in.ID, | ||
Name: in.Name, | ||
Insecure: in.Insecure, | ||
TLS: TLSVersions(in.SupportedVersions), | ||
} | ||
} | ||
|
||
// String returns a human-readable string for the cipher. | ||
func (c *CipherSuite) String() string { | ||
return fmt.Sprintf("Cipher ID: %d, Name: %s, Insecure: %t, TLS: %v", c.ID, c.Name, c.Insecure, c.TLS) | ||
} | ||
|
||
// TLSVersions will return a list of TLS versions as a string. | ||
func TLSVersions(in []uint16) []string { | ||
versions := make([]string, len(in)) | ||
for i, v := range in { | ||
switch v { | ||
case tls.VersionTLS10: | ||
versions[i] = "1.0" | ||
case tls.VersionTLS11: | ||
versions[i] = "1.1" | ||
case tls.VersionTLS12: | ||
versions[i] = "1.2" | ||
case tls.VersionTLS13: | ||
versions[i] = "1.3" | ||
default: | ||
versions[i] = "" | ||
} | ||
} | ||
return versions | ||
} | ||
|
||
// GetCiphers generates a list of CipherSuite from the available ciphers. | ||
func GetCiphers() []*CipherSuite { | ||
ciphers := tls.CipherSuites() | ||
result := make([]*CipherSuite, 0, len(ciphers)) | ||
|
||
for _, cipher := range ciphers { | ||
result = append(result, NewCipher(cipher)) | ||
} | ||
|
||
return result | ||
} | ||
|
||
// ResolveCipher translates a string representation of a cipher to its uint16 ID. | ||
// It's case-insensitive when matching the cipher by name. | ||
func ResolveCipher(cipherName string) (uint16, error) { | ||
ciphers := GetCiphers() | ||
for _, cipher := range ciphers { | ||
if strings.EqualFold(cipher.Name, cipherName) { | ||
return cipher.ID, nil | ||
} | ||
} | ||
return 0, fmt.Errorf("cipher %s not found", cipherName) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
package crypto | ||
|
||
import ( | ||
"crypto/tls" | ||
"testing" | ||
) | ||
|
||
func TestNewCipher(t *testing.T) { | ||
mockCipher := &tls.CipherSuite{ | ||
ID: uint16(0x0001), | ||
Name: "TLS_MOCK_CIPHER", | ||
Insecure: false, | ||
SupportedVersions: []uint16{tls.VersionTLS12, tls.VersionTLS13}, | ||
} | ||
|
||
cipher := NewCipher(mockCipher) | ||
|
||
if cipher.ID != mockCipher.ID { | ||
t.Errorf("Expected ID %d, got %d", mockCipher.ID, cipher.ID) | ||
} | ||
if cipher.Name != mockCipher.Name { | ||
t.Errorf("Expected Name %s, got %s", mockCipher.Name, cipher.Name) | ||
} | ||
if cipher.Insecure != mockCipher.Insecure { | ||
t.Errorf("Expected Insecure %t, got %t", mockCipher.Insecure, cipher.Insecure) | ||
} | ||
if len(cipher.TLS) != 2 || cipher.TLS[0] != "1.2" || cipher.TLS[1] != "1.3" { | ||
t.Errorf("Expected TLS versions [1.2, 1.3], got %v", cipher.TLS) | ||
} | ||
} | ||
|
||
func TestGetCiphers(t *testing.T) { | ||
ciphers := GetCiphers() | ||
if len(ciphers) == 0 { | ||
t.Error("Expected non-empty cipher list") | ||
} | ||
|
||
for _, cipher := range ciphers { | ||
if cipher.ID == 0 || cipher.Name == "" { | ||
t.Errorf("Invalid cipher: %v", cipher) | ||
} | ||
} | ||
} | ||
|
||
func TestResolveCipher(t *testing.T) { | ||
testCases := []struct { | ||
name string | ||
input string | ||
expected uint16 | ||
hasError bool | ||
}{ | ||
{"Valid cipher", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", 0xc02f, false}, | ||
{"Invalid cipher", "INVALID_CIPHER", 0, true}, | ||
{"Case insensitive", "tls_ecdhe_rsa_with_aes_128_gcm_sha256", 0xc02f, false}, | ||
{"Empty input", "", 0, true}, | ||
{"Partial match", "TLS_ECDHE", 0, true}, | ||
} | ||
|
||
for _, tc := range testCases { | ||
t.Run(tc.name, func(t *testing.T) { | ||
result, err := ResolveCipher(tc.input) | ||
if tc.hasError && err == nil { | ||
t.Error("Expected error, got nil") | ||
} | ||
if !tc.hasError && err != nil { | ||
t.Errorf("Unexpected error: %v", err) | ||
} | ||
if result != tc.expected { | ||
t.Errorf("Expected %d, got %d", tc.expected, result) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestTLSVersions(t *testing.T) { | ||
testCases := []struct { | ||
name string | ||
input []uint16 | ||
expected []string | ||
}{ | ||
{"All versions", []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13}, []string{"1.0", "1.1", "1.2", "1.3"}}, | ||
{"Unknown version", []uint16{0x0000}, []string{""}}, | ||
{"Mixed versions", []uint16{tls.VersionTLS12, 0x0000, tls.VersionTLS13}, []string{"1.2", "", "1.3"}}, | ||
} | ||
|
||
for _, tc := range testCases { | ||
t.Run(tc.name, func(t *testing.T) { | ||
result := TLSVersions(tc.input) | ||
if len(result) != len(tc.expected) { | ||
t.Errorf("Expected %d versions, got %d", len(tc.expected), len(result)) | ||
} | ||
for i, v := range result { | ||
if v != tc.expected[i] { | ||
t.Errorf("Expected version %s at index %d, got %s", tc.expected[i], i, v) | ||
} | ||
} | ||
}) | ||
} | ||
} |