From aed59005bc8828a567cba2f5dcb8838d127c863b Mon Sep 17 00:00:00 2001 From: Ansar Smagulov Date: Sat, 16 Dec 2023 01:43:46 +0600 Subject: [PATCH] LRU certificate cache (#33) * feat: LRU cache for TLS certificates --- certmanager/certmanager.go | 38 ++-- certmanager/get.go | 32 +--- certmanager/lru.go | 112 +++++++++++ certmanager/lru_test.go | 178 ++++++++++++++++++ .../wailsjs/go/certmanager/CertManager.d.ts | 5 +- .../wailsjs/go/certmanager/CertManager.js | 8 +- proxy/proxy.go | 6 +- 7 files changed, 324 insertions(+), 55 deletions(-) create mode 100644 certmanager/lru.go create mode 100644 certmanager/lru_test.go diff --git a/certmanager/certmanager.go b/certmanager/certmanager.go index 2f61f511..76ea27dd 100644 --- a/certmanager/certmanager.go +++ b/certmanager/certmanager.go @@ -35,7 +35,6 @@ import ( "crypto/rand" "crypto/rsa" "crypto/sha1" - "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" @@ -56,20 +55,26 @@ import ( const ( caName = "rootCA.pem" keyName = "rootCA-key.pem" + // cacheMaxSize is the maximum number of certificates the cache will store. + // + // Considering that a single tls.Certificate is about 1.7KB, this means that the cache + // can store 5800 certificates in about 10MB of memory. + cacheMaxSize = 5800 + // cacheCleanupInterval is the interval at which the cache is cleaned up. + cacheCleanupInterval = 5 * time.Minute ) // CertManager manages the root CA certificate and key for the proxy. 
type CertManager struct {
-	certData    []byte
-	keyData     []byte
-	certPath    string
-	cert        *x509.Certificate
-	keyPath     string
-	key         crypto.PrivateKey
-	certCache   map[string]tls.Certificate
-	certCacheMu sync.RWMutex
-	initOnce    *sync.Once
-	initErr     error
+	certData  []byte
+	keyData   []byte
+	certPath  string
+	cert      *x509.Certificate
+	keyPath   string
+	key       crypto.PrivateKey
+	certCache *CertLRUCache
+	initOnce  *sync.Once
+	initErr   error
 }
 
 var (
@@ -102,7 +107,7 @@ func (cm *CertManager) Init() (err error) {
 		folderName := path.Join(config.Config.DataDir, certsFolderName())
 		cm.certPath = path.Join(folderName, caName)
 		cm.keyPath = path.Join(folderName, keyName)
-		cm.certCache = make(map[string]tls.Certificate)
+		cm.certCache = NewCertLRUCache(cacheMaxSize, cacheCleanupInterval)
 
 		if config.Config.GetCAInstalled() {
 			if err = cm.loadCA(); err != nil {
@@ -265,11 +270,11 @@ func (cm *CertManager) newCA() error {
 }
 
-// ClearCache removes all cached certificates.
-func (cm *CertManager) ClearCache() {
-	cm.certCacheMu.Lock()
-	defer cm.certCacheMu.Unlock()
-
-	cm.certCache = make(map[string]tls.Certificate)
+// PurgeCache removes all cached certificates. It is a no-op when no cache is set, e.g. after UninstallCA has reset certCache to nil.
+func (cm *CertManager) PurgeCache() {
+	if cm.certCache == nil {
+		return
+	}
+	cm.certCache.Purge()
 }
 
 // UninstallCA wraps platform-specific uninstallCA methods.
@@ -300,7 +305,7 @@ func (cm *CertManager) UninstallCA() string {
 	cm.certCache = nil
 	cm.initOnce = &sync.Once{}
 	cm.initErr = nil
-	cm.ClearCache()
+	cm.PurgeCache()
 
 	return ""
 }
diff --git a/certmanager/get.go b/certmanager/get.go
index 46d1b2db..b2d40ebe 100644
--- a/certmanager/get.go
+++ b/certmanager/get.go
@@ -13,15 +13,13 @@ import (
 	"time"
 )
 
-const certTTL = 2 * time.Minute
+// certTTL is the time-to-live for certificates.
+const certTTL = 24 * time.Hour
 
 // GetCertificate returns a self-signed certificate for the given host.
func (cm *CertManager) GetCertificate(host string) (*tls.Certificate, error) { - cm.certCacheMu.RLock() - cert, ok := cm.certCache[host] - cm.certCacheMu.RUnlock() - if ok { - return &cert, nil + if cert := cm.certCache.Get(host); cert != nil { + return cert, nil } privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) @@ -35,7 +33,7 @@ func (cm *CertManager) GetCertificate(host string) (*tls.Certificate, error) { return nil, fmt.Errorf("generate serial number: %v", err) } - expiry := time.Now().Add(certTTL) + expiresAt := time.Now().Add(certTTL) template := x509.Certificate{ SerialNumber: serialNumber, Subject: pkix.Name{ @@ -43,7 +41,7 @@ func (cm *CertManager) GetCertificate(host string) (*tls.Certificate, error) { }, DNSNames: []string{host}, NotBefore: time.Now(), - NotAfter: expiry, + NotAfter: expiresAt, KeyUsage: x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, @@ -68,26 +66,12 @@ func (cm *CertManager) GetCertificate(host string) (*tls.Certificate, error) { return nil, fmt.Errorf("encode key to PEM") } - cert, err = tls.X509KeyPair(pemCert, pemKey) + cert, err := tls.X509KeyPair(pemCert, pemKey) if err != nil { return nil, fmt.Errorf("load key pair: %v", err) } - cm.certCacheMu.Lock() - cm.certCache[host] = cert - cm.certCacheMu.Unlock() - - cm.ScheduleCacheCleanup(host, expiry) + cm.certCache.Put(host, expiresAt.Add(-5*time.Minute), &cert) // 5 minute buffer in case a TLS handshake takes a while, the system clock is off, etc. return &cert, nil } - -// ScheduleCacheCleanup clears the cache for the given host. 
-func (cm *CertManager) ScheduleCacheCleanup(host string, expiry time.Time) {
-	go func() {
-		time.Sleep(time.Until(expiry) - time.Second) // give it a second in case of a lock contention
-		cm.certCacheMu.Lock()
-		delete(cm.certCache, host)
-		cm.certCacheMu.Unlock()
-	}()
-}
diff --git a/certmanager/lru.go b/certmanager/lru.go
new file mode 100644
index 00000000..0b0fcb6a
--- /dev/null
+++ b/certmanager/lru.go
@@ -0,0 +1,122 @@
+package certmanager
+
+import (
+	"container/list"
+	"crypto/tls"
+	"sync"
+	"time"
+)
+
+type cacheEntry struct {
+	cert        *tls.Certificate
+	expiresAt   int64
+	listElement *list.Element
+}
+
+// CertLRUCache is an LRU cache of TLS certificates.
+type CertLRUCache struct {
+	sync.Mutex
+
+	// maxSize is the maximum number of certificates the cache can store.
+	maxSize int
+	// list is the doubly linked list used for LRU eviction.
+	list *list.List
+	// cache is the map of host to certificate.
+	cache map[string]cacheEntry
+}
+
+// NewCertLRUCache initializes a certificate LRU cache with given parameters.
+//
+// A maxSize below 1 is treated as 1. cleanupInterval must be positive, since
+// time.NewTicker panics otherwise.
+func NewCertLRUCache(maxSize int, cleanupInterval time.Duration) *CertLRUCache {
+	if maxSize < 1 {
+		// Put evicts the back of the list once it holds maxSize entries. With a
+		// non-positive limit that list is empty and list.Remove(nil) would panic.
+		maxSize = 1
+	}
+
+	c := CertLRUCache{
+		cache:   make(map[string]cacheEntry),
+		list:    list.New(),
+		maxSize: maxSize,
+	}
+
+	go func() {
+		// Periodically remove expired entries.
+		// This function never exits, which is fine since the CertManager gets accessed via a singleton. Though, be careful with spawning a lot of CertManagers or caches in tests.
+		ticker := time.NewTicker(cleanupInterval)
+		for range ticker.C {
+			c.Lock()
+			now := time.Now().Unix()
+			for e, entry := range c.cache {
+				if now > entry.expiresAt {
+					c.list.Remove(entry.listElement)
+					delete(c.cache, e)
+				}
+			}
+			c.Unlock()
+		}
+	}()
+
+	return &c
+}
+
+// Get returns the certificate for the given host, or nil if it is not cached.
+func (c *CertLRUCache) Get(host string) *tls.Certificate {
+	c.Lock()
+	defer c.Unlock()
+
+	entry, ok := c.cache[host]
+	if !ok {
+		return nil
+	}
+	if time.Now().Unix() > entry.expiresAt { // expired: drop the entry and report a miss
+		c.list.Remove(entry.listElement)
+		delete(c.cache, host)
+		return nil
+	}
+
+	c.list.MoveToFront(entry.listElement) // mark host as most recently used
+
+	return entry.cert
+}
+
+// Put stores cert for host with the given expiry, replacing any existing entry for that host and evicting the least recently used host when the cache is full.
+func (c *CertLRUCache) Put(host string, expiresAt time.Time, cert *tls.Certificate) {
+	c.Lock()
+	defer c.Unlock()
+
+	if e, ok := c.cache[host]; ok { // host already cached: refresh it in place, reusing its list element
+		c.list.MoveToFront(e.listElement)
+		c.cache[host] = cacheEntry{
+			cert:        cert,
+			expiresAt:   expiresAt.Unix(),
+			listElement: e.listElement,
+		}
+		return
+	}
+
+	if c.list.Len() >= c.maxSize {
+		// Evict the least recently used host, which sits at the back of the list.
+		e := c.list.Back()
+		c.list.Remove(e)
+		delete(c.cache, e.Value.(string))
+	}
+
+	listElement := c.list.PushFront(host) // new hosts start as most recently used
+	c.cache[host] = cacheEntry{
+		cert:        cert,
+		expiresAt:   expiresAt.Unix(),
+		listElement: listElement,
+	}
+}
+
+// Purge removes every entry by replacing the map and the list with fresh, empty ones.
+func (c *CertLRUCache) Purge() {
+	c.Lock()
+	defer c.Unlock()
+
+	c.cache = make(map[string]cacheEntry)
+	c.list = list.New()
+}
diff --git a/certmanager/lru_test.go b/certmanager/lru_test.go
new file mode 100644
index 00000000..4c1114ce
--- /dev/null
+++ b/certmanager/lru_test.go
@@ -0,0 +1,178 @@
+package certmanager
+
+import (
+	"crypto/tls"
+	"fmt"
+	"math/rand"
+	"testing"
+	"time"
+)
+
+// TestPutAndGet tests that a certificate stored with Put is returned by Get as the same pointer.
+func TestPutAndGet(t *testing.T) {
+	t.Parallel()
+
+	cache := NewCertLRUCache(100, time.Hour)
+	cert := &tls.Certificate{}
+	cache.Put("example.com", time.Now().Add(24*time.Hour), cert)
+	if cache.Get("example.com") != cert {
+		t.Errorf("Expected the retrieved certificate to be the same as the one put in")
+	}
+}
+
+// TestPutMultipleTimes tests that a second Put for the same host replaces the first certificate.
+func TestPutMultipleTimes(t *testing.T) {
+	t.Parallel()
+
+	cache := NewCertLRUCache(100, time.Hour)
+	cert1 := &tls.Certificate{}
+	cert2 := &tls.Certificate{}
+	cache.Put("example.com", time.Now().Add(24*time.Hour), cert1)
+	cache.Put("example.com", time.Now().Add(24*time.Hour), cert2)
+	if cache.Get("example.com") != cert2 {
+		t.Errorf("Expected the retrieved certificate to be the same as the one put in")
+	}
+}
+
+// TestMultipleCerts tests that multiple certificates can be stored in the cache.
+func TestMultipleCerts(t *testing.T) {
+	t.Parallel()
+
+	cache := NewCertLRUCache(1000, time.Hour)
+	certs := make([]*tls.Certificate, 1000)
+
+	expiresAt := time.Now().Add(24 * time.Hour)
+	for i := 0; i < 1000; i++ {
+		certs[i] = &tls.Certificate{}
+		cache.Put(fmt.Sprintf("example%d.com", i), expiresAt, certs[i])
+	}
+
+	for _, i := range rand.Perm(1000) {
+		if cache.Get(fmt.Sprintf("example%d.com", i)) != certs[i] {
+			t.Fatalf("Expected the retrieved certificate to be the same as the one put in. Failure at index %d.", i)
+		}
+	}
+}
+
+// TestExpiration tests that certificates expire after the given TTL.
+// May introduce flakiness if the test machine is under heavy load.
+func TestExpiration(t *testing.T) {
+	t.Parallel()
+
+	cache := NewCertLRUCache(4000, time.Second)
+
+	checkTTL := func(ttl time.Duration, doneC chan<- struct{}, errC chan<- error) {
+		now := time.Now()
+		certs := make([]*tls.Certificate, 1000)
+		for i := 0; i < len(certs); i++ {
+			certs[i] = &tls.Certificate{}
+			cache.Put(fmt.Sprintf("%d.%d.example.com", i, ttl), now.Add(ttl), certs[i])
+		}
+
+		time.Sleep(ttl / 2) // halfway through the TTL every certificate must still be served
+		for _, i := range rand.Perm(1000) {
+			if cache.Get(fmt.Sprintf("%d.%d.example.com", i, ttl)) != certs[i] {
+				errC <- fmt.Errorf("Expected the retrieved certificate to be the same as the one put in. Failure at index %d.", i)
+				return
+			}
+		}
+
+		time.Sleep((ttl / 2) + time.Second) // one second past the TTL, since expiry is tracked in whole seconds
+		for _, i := range rand.Perm(1000) {
+			if cache.Get(fmt.Sprintf("%d.%d.example.com", i, ttl)) != nil {
+				errC <- fmt.Errorf("Expected the certificate to expire. Failure at index %d.", i)
+				return
+			}
+		}
+
+		doneC <- struct{}{}
+	}
+
+	doneC := make(chan struct{}, 3)
+	errC := make(chan error, 3) // one slot per checker so none blocks forever after the test has already failed
+
+	go checkTTL(3*time.Second, doneC, errC)
+	go checkTTL(5*time.Second, doneC, errC)
+	go checkTTL(10*time.Second, doneC, errC)
+
+	for i := 0; i < 3; i++ {
+		select {
+		case <-doneC:
+		case err := <-errC:
+			t.Fatal(err)
+		case <-time.After(15 * time.Second):
+			t.Fatal("Timed out")
+		}
+	}
+}
+
+// TestGetExpired tests that expired certificates are not returned.
+func TestGetExpired(t *testing.T) {
+	t.Parallel()
+
+	cache := NewCertLRUCache(1000, time.Hour)
+	cert := &tls.Certificate{}
+	cache.Put("example.com", time.Now().Add(-time.Hour), cert)
+	if cache.Get("example.com") != nil {
+		t.Errorf("Expected the retrieved certificate to be nil")
+	}
+}
+
+// TestLRU tests that the LRU eviction policy works.
+func TestLRU(t *testing.T) {
+	t.Parallel()
+
+	cache := NewCertLRUCache(2000, time.Hour)
+
+	certsToBeEvicted := make([]*tls.Certificate, 1000)
+	certsToBeKept := make([]*tls.Certificate, 1000)
+	expiresAt := time.Now().Add(24 * time.Hour)
+
+	for i := 0; i < 1000; i++ {
+		certsToBeEvicted[i] = &tls.Certificate{}
+		cache.Put(fmt.Sprintf("%d.evict.com", i), expiresAt, certsToBeEvicted[i])
+	}
+
+	for i := 0; i < 1000; i++ {
+		certsToBeKept[i] = &tls.Certificate{}
+		cache.Put(fmt.Sprintf("%d.keep.com", i), expiresAt, certsToBeKept[i])
+	}
+
+	for _, i := range rand.Perm(1000) {
+		cache.Get(fmt.Sprintf("%d.keep.com", i)) // touch the entry so it counts as recently used
+	}
+
+	for i := 0; i < 1000; i++ {
+		cache.Put(fmt.Sprintf("%d.new.com", i), expiresAt, &tls.Certificate{})
+	}
+
+	for _, i := range rand.Perm(1000) {
+		if cache.Get(fmt.Sprintf("%d.evict.com", i)) != nil {
+			t.Fatalf("Expected the certificate to be evicted. Failure at index %d.", i)
+		}
+	}
+
+	for _, i := range rand.Perm(1000) {
+		if cache.Get(fmt.Sprintf("%d.keep.com", i)) != certsToBeKept[i] {
+			t.Fatalf("Expected the certificate to be kept. Failure at index %d.", i)
+		}
+	}
+}
+
+// TestPurge tests that the cache can be purged.
+func TestPurge(t *testing.T) {
+	t.Parallel()
+
+	cache := NewCertLRUCache(1000, time.Hour)
+
+	cert := &tls.Certificate{}
+	cache.Put("example.com", time.Now().Add(24*time.Hour), cert)
+	if cache.Get("example.com") != cert {
+		t.Errorf("Expected the retrieved certificate to be the same as the one put in")
+	}
+
+	cache.Purge()
+	if cache.Get("example.com") != nil {
+		t.Errorf("Expected the retrieved certificate to be nil after purge")
+	}
+}
diff --git a/frontend/wailsjs/go/certmanager/CertManager.d.ts b/frontend/wailsjs/go/certmanager/CertManager.d.ts
index b0eff706..3bdd2936 100755
--- a/frontend/wailsjs/go/certmanager/CertManager.d.ts
+++ b/frontend/wailsjs/go/certmanager/CertManager.d.ts
@@ -1,14 +1,11 @@
 // Cynhyrchwyd y ffeil hon yn awtomatig. PEIDIWCH Â MODIWL
 // This file is automatically generated. DO NOT EDIT
 import {tls} from '../models';
-import {time} from '../models';
-
-export function ClearCache():Promise;
 
 export function GetCertificate(arg1:string):Promise;
 
 export function Init():Promise;
 
-export function ScheduleCacheCleanup(arg1:string,arg2:time.Time):Promise;
+export function PurgeCache():Promise;
 
 export function UninstallCA():Promise;
diff --git a/frontend/wailsjs/go/certmanager/CertManager.js b/frontend/wailsjs/go/certmanager/CertManager.js
index a63a11a0..5ff8310b 100755
--- a/frontend/wailsjs/go/certmanager/CertManager.js
+++ b/frontend/wailsjs/go/certmanager/CertManager.js
@@ -2,10 +2,6 @@
 // Cynhyrchwyd y ffeil hon yn awtomatig. PEIDIWCH Â MODIWL
 // This file is automatically generated.
DO NOT EDIT
 
-export function ClearCache() {
-  return window['go']['certmanager']['CertManager']['ClearCache']();
-}
-
 export function GetCertificate(arg1) {
   return window['go']['certmanager']['CertManager']['GetCertificate'](arg1);
 }
@@ -14,8 +10,8 @@ export function Init() {
   return window['go']['certmanager']['CertManager']['Init']();
 }
 
-export function ScheduleCacheCleanup(arg1, arg2) {
-  return window['go']['certmanager']['CertManager']['ScheduleCacheCleanup'](arg1, arg2);
+export function PurgeCache() {
+  return window['go']['certmanager']['CertManager']['PurgeCache']();
 }
 
 export function UninstallCA() {
diff --git a/proxy/proxy.go b/proxy/proxy.go
index cf3d25a1..e3278598 100644
--- a/proxy/proxy.go
+++ b/proxy/proxy.go
@@ -118,7 +118,7 @@ func (p *Proxy) initExclusionList() {
 
 // Stop stops the proxy.
-// If clearCaches is true, the certificate cache will be cleared.
-func (p *Proxy) Stop(clearCaches bool) error {
+// If purgeCache is true, the certificate cache will be purged.
+func (p *Proxy) Stop(purgeCache bool) error {
 	if p.server == nil {
 		return nil
 	}
@@ -137,8 +137,8 @@ func (p *Proxy) Stop(clearCaches bool) error {
 		return fmt.Errorf("unset system proxy: %v", err)
 	}
 
-	if clearCaches {
-		certmanager.GetCertManager().ClearCache()
+	if purgeCache {
+		certmanager.GetCertManager().PurgeCache()
 		p.filter = nil
 	}