Skip to content

Commit

Permalink
LRU certificate cache (#33)
Browse files Browse the repository at this point in the history
* feat: LRU cache for TLS certificates
  • Loading branch information
anfragment authored Dec 15, 2023
1 parent 2b2031f commit aed5900
Show file tree
Hide file tree
Showing 7 changed files with 324 additions and 55 deletions.
38 changes: 20 additions & 18 deletions certmanager/certmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ import (
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
Expand All @@ -56,20 +55,26 @@ import (
const (
caName = "rootCA.pem"
keyName = "rootCA-key.pem"
// cacheMaxSize is the maximum number of certificates the cache will store.
//
// Considering that a single tls.Certificate is about 1.7KB, this means that the cache
// can store 5800 certificates in about 10MB of memory.
cacheMaxSize = 5800
// cacheCleanupInterval is the interval at which the cache is cleaned up.
cacheCleanupInterval = 5 * time.Minute
)

// CertManager manages the root CA certificate and key for the proxy.
type CertManager struct {
certData []byte
keyData []byte
certPath string
cert *x509.Certificate
keyPath string
key crypto.PrivateKey
certCache map[string]tls.Certificate
certCacheMu sync.RWMutex
initOnce *sync.Once
initErr error
certData []byte
keyData []byte
certPath string
cert *x509.Certificate
keyPath string
key crypto.PrivateKey
certCache *CertLRUCache
initOnce *sync.Once
initErr error
}

var (
Expand Down Expand Up @@ -102,7 +107,7 @@ func (cm *CertManager) Init() (err error) {
folderName := path.Join(config.Config.DataDir, certsFolderName())
cm.certPath = path.Join(folderName, caName)
cm.keyPath = path.Join(folderName, keyName)
cm.certCache = make(map[string]tls.Certificate)
cm.certCache = NewCertLRUCache(cacheMaxSize, cacheCleanupInterval)

if config.Config.GetCAInstalled() {
if err = cm.loadCA(); err != nil {
Expand Down Expand Up @@ -265,11 +270,8 @@ func (cm *CertManager) newCA() error {
}

// ClearCache removes all cached certificates.
func (cm *CertManager) ClearCache() {
cm.certCacheMu.Lock()
defer cm.certCacheMu.Unlock()

cm.certCache = make(map[string]tls.Certificate)
func (cm *CertManager) PurgeCache() {
cm.certCache.Purge()
}

// UninstallCA wraps platform-specific uninstallCA methods.
Expand Down Expand Up @@ -300,7 +302,7 @@ func (cm *CertManager) UninstallCA() string {
cm.certCache = nil
cm.initOnce = &sync.Once{}
cm.initErr = nil
cm.ClearCache()
cm.PurgeCache()

return ""
}
Expand Down
32 changes: 8 additions & 24 deletions certmanager/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@ import (
"time"
)

const certTTL = 2 * time.Minute
// certTTL is the time-to-live for certificates.
const certTTL = 24 * time.Hour

// GetCertificate returns a self-signed certificate for the given host.
func (cm *CertManager) GetCertificate(host string) (*tls.Certificate, error) {
cm.certCacheMu.RLock()
cert, ok := cm.certCache[host]
cm.certCacheMu.RUnlock()
if ok {
return &cert, nil
if cert := cm.certCache.Get(host); cert != nil {
return cert, nil
}

privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
Expand All @@ -35,15 +33,15 @@ func (cm *CertManager) GetCertificate(host string) (*tls.Certificate, error) {
return nil, fmt.Errorf("generate serial number: %v", err)
}

expiry := time.Now().Add(certTTL)
expiresAt := time.Now().Add(certTTL)
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Zen"},
},
DNSNames: []string{host},
NotBefore: time.Now(),
NotAfter: expiry,
NotAfter: expiresAt,

KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
Expand All @@ -68,26 +66,12 @@ func (cm *CertManager) GetCertificate(host string) (*tls.Certificate, error) {
return nil, fmt.Errorf("encode key to PEM")
}

cert, err = tls.X509KeyPair(pemCert, pemKey)
cert, err := tls.X509KeyPair(pemCert, pemKey)
if err != nil {
return nil, fmt.Errorf("load key pair: %v", err)
}

cm.certCacheMu.Lock()
cm.certCache[host] = cert
cm.certCacheMu.Unlock()

cm.ScheduleCacheCleanup(host, expiry)
cm.certCache.Put(host, expiresAt.Add(-5*time.Minute), &cert) // 5 minute buffer in case a TLS handshake takes a while, the system clock is off, etc.

return &cert, nil
}

// ScheduleCacheCleanup clears the cache for the given host.
func (cm *CertManager) ScheduleCacheCleanup(host string, expiry time.Time) {
go func() {
time.Sleep(time.Until(expiry) - time.Second) // give it a second in case of a lock contention
cm.certCacheMu.Lock()
delete(cm.certCache, host)
cm.certCacheMu.Unlock()
}()
}
112 changes: 112 additions & 0 deletions certmanager/lru.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package certmanager

import (
"container/list"
"crypto/tls"
"sync"
"time"
)

type cacheEntry struct {
cert *tls.Certificate
expiresAt int64
listElement *list.Element
}

// CertLRUCache is an LRU cache of TLS certificates.
type CertLRUCache struct {
sync.Mutex

// maxSize is the maximum number of certificates the cache can store.
maxSize int
// list is the doubly linked list used for LRU eviction.
list *list.List
// cache is the map of host to certificate.
cache map[string]cacheEntry
}

// NewCertLRUCache initializes a certificate LRU cache with given parameters.
func NewCertLRUCache(maxSize int, cleanupInterval time.Duration) *CertLRUCache {
c := CertLRUCache{
cache: make(map[string]cacheEntry),
list: list.New(),
maxSize: maxSize,
}

go func() {
// Periodically remove expired entries.
// This function never exits, which is fine since the CertManager gets accessed via a singleton. Though, be careful with spawning a lot of CertManagers or caches in tests.
ticker := time.NewTicker(cleanupInterval)
for range ticker.C {
c.Lock()
for e, entry := range c.cache {
if time.Now().Unix() > entry.expiresAt {
c.list.Remove(entry.listElement)
delete(c.cache, e)
}
}
c.Unlock()
}
}()

return &c
}

// Get returns the certificate for the given host, or nil if it is not cached.
func (c *CertLRUCache) Get(host string) *tls.Certificate {
c.Lock()
defer c.Unlock()

entry, ok := c.cache[host]
if !ok {
return nil
}
if time.Now().Unix() > entry.expiresAt {
c.list.Remove(entry.listElement)
delete(c.cache, host)
return nil
}

c.list.MoveToFront(entry.listElement)

return entry.cert
}

// Put adds the certificate for the given host to the cache.
func (c *CertLRUCache) Put(host string, expiresAt time.Time, cert *tls.Certificate) {
c.Lock()
defer c.Unlock()

if e, ok := c.cache[host]; ok {
c.list.MoveToFront(e.listElement)
c.cache[host] = cacheEntry{
cert: cert,
expiresAt: expiresAt.Unix(),
listElement: e.listElement,
}
return
}

if c.list.Len() >= c.maxSize {
// Evict the least recently used host.
e := c.list.Back()
c.list.Remove(e)
delete(c.cache, e.Value.(string))
}

listElement := c.list.PushFront(host)
c.cache[host] = cacheEntry{
cert: cert,
expiresAt: expiresAt.Unix(),
listElement: listElement,
}
}

// Purge clears the cache.
func (c *CertLRUCache) Purge() {
c.Lock()
defer c.Unlock()

c.cache = make(map[string]cacheEntry)
c.list = list.New()
}
Loading

0 comments on commit aed5900

Please sign in to comment.