Skip to content

Commit

Permalink
moved ciphers to internal
Browse files Browse the repository at this point in the history
  • Loading branch information
kofoworola committed Jan 30, 2025
1 parent bd06228 commit a0aa344
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 5 deletions.
11 changes: 6 additions & 5 deletions gateway/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -582,12 +582,13 @@ func (gw *Gateway) certHandler(w http.ResponseWriter, r *http.Request) {
}

func getCipherAliases(ciphers []string) (cipherCodes []uint16) {
for _, v := range tls.CipherSuites() {
for _, str := range ciphers {
if str == v.Name {
cipherCodes = append(cipherCodes, v.ID)
}
for _, v := range ciphers {
id, err := crypto.ResolveCipher(v)
if err != nil {
log.Debugf("cipher %s not found; skipped", v)
continue
}
cipherCodes = append(cipherCodes, id)
}
return cipherCodes
}
75 changes: 75 additions & 0 deletions internal/crypto/ciphers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package crypto

import (
"crypto/tls"
"fmt"
"strings"
)

// CipherSuite stores information about a cipher suite.
// It shadows tls.CipherSuite but translates TLS versions to strings.
type CipherSuite struct {
ID uint16 `json:"id"`
Name string `json:"name"`
Insecure bool `json:"insecure"`
TLS []string `json:"tls"`
}

// NewCipher translates tls.CipherSuite to our local type.
func NewCipher(in *tls.CipherSuite) *CipherSuite {
return &CipherSuite{
ID: in.ID,
Name: in.Name,
Insecure: in.Insecure,
TLS: TLSVersions(in.SupportedVersions),
}
}

// String returns a human-readable string for the cipher.
func (c *CipherSuite) String() string {
return fmt.Sprintf("Cipher ID: %d, Name: %s, Insecure: %t, TLS: %v", c.ID, c.Name, c.Insecure, c.TLS)
}

// TLSVersions will return a list of TLS versions as a string.
func TLSVersions(in []uint16) []string {
versions := make([]string, len(in))
for i, v := range in {
switch v {
case tls.VersionTLS10:
versions[i] = "1.0"
case tls.VersionTLS11:
versions[i] = "1.1"
case tls.VersionTLS12:
versions[i] = "1.2"
case tls.VersionTLS13:
versions[i] = "1.3"
default:
versions[i] = ""
}
}
return versions
}

// GetCiphers generates a list of CipherSuite from the available ciphers.
func GetCiphers() []*CipherSuite {
ciphers := tls.CipherSuites()
result := make([]*CipherSuite, 0, len(ciphers))

for _, cipher := range ciphers {
result = append(result, NewCipher(cipher))
}

return result
}

// ResolveCipher translates a string representation of a cipher to its uint16 ID.
// It's case-insensitive when matching the cipher by name.
func ResolveCipher(cipherName string) (uint16, error) {
ciphers := GetCiphers()
for _, cipher := range ciphers {
if strings.EqualFold(cipher.Name, cipherName) {
return cipher.ID, nil
}
}
return 0, fmt.Errorf("cipher %s not found", cipherName)
}
99 changes: 99 additions & 0 deletions internal/crypto/ciphers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package crypto

import (
"crypto/tls"
"testing"
)

func TestNewCipher(t *testing.T) {
mockCipher := &tls.CipherSuite{
ID: uint16(0x0001),
Name: "TLS_MOCK_CIPHER",
Insecure: false,
SupportedVersions: []uint16{tls.VersionTLS12, tls.VersionTLS13},
}

cipher := NewCipher(mockCipher)

if cipher.ID != mockCipher.ID {
t.Errorf("Expected ID %d, got %d", mockCipher.ID, cipher.ID)
}
if cipher.Name != mockCipher.Name {
t.Errorf("Expected Name %s, got %s", mockCipher.Name, cipher.Name)
}
if cipher.Insecure != mockCipher.Insecure {
t.Errorf("Expected Insecure %t, got %t", mockCipher.Insecure, cipher.Insecure)
}
if len(cipher.TLS) != 2 || cipher.TLS[0] != "1.2" || cipher.TLS[1] != "1.3" {
t.Errorf("Expected TLS versions [1.2, 1.3], got %v", cipher.TLS)
}
}

func TestGetCiphers(t *testing.T) {
ciphers := GetCiphers()
if len(ciphers) == 0 {
t.Error("Expected non-empty cipher list")
}

for _, cipher := range ciphers {
if cipher.ID == 0 || cipher.Name == "" {
t.Errorf("Invalid cipher: %v", cipher)
}
}
}

func TestResolveCipher(t *testing.T) {
testCases := []struct {
name string
input string
expected uint16
hasError bool
}{
{"Valid cipher", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", 0xc02f, false},
{"Invalid cipher", "INVALID_CIPHER", 0, true},
{"Case insensitive", "tls_ecdhe_rsa_with_aes_128_gcm_sha256", 0xc02f, false},
{"Empty input", "", 0, true},
{"Partial match", "TLS_ECDHE", 0, true},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result, err := ResolveCipher(tc.input)
if tc.hasError && err == nil {
t.Error("Expected error, got nil")
}
if !tc.hasError && err != nil {
t.Errorf("Unexpected error: %v", err)
}
if result != tc.expected {
t.Errorf("Expected %d, got %d", tc.expected, result)
}
})
}
}

func TestTLSVersions(t *testing.T) {
testCases := []struct {
name string
input []uint16
expected []string
}{
{"All versions", []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13}, []string{"1.0", "1.1", "1.2", "1.3"}},
{"Unknown version", []uint16{0x0000}, []string{""}},
{"Mixed versions", []uint16{tls.VersionTLS12, 0x0000, tls.VersionTLS13}, []string{"1.2", "", "1.3"}},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := TLSVersions(tc.input)
if len(result) != len(tc.expected) {
t.Errorf("Expected %d versions, got %d", len(tc.expected), len(result))
}
for i, v := range result {
if v != tc.expected[i] {
t.Errorf("Expected version %s at index %d, got %s", tc.expected[i], i, v)
}
}
})
}
}

0 comments on commit a0aa344

Please sign in to comment.