diff --git a/internal/server/cert.go b/internal/server/cert.go index afaf740..90e8b48 100644 --- a/internal/server/cert.go +++ b/internal/server/cert.go @@ -3,6 +3,7 @@ package server import ( "crypto/tls" "log/slog" + "sync" ) type CertManager interface { @@ -14,6 +15,7 @@ type StaticCertManager struct { tlsCertificateFilePath string tlsPrivateKeyFilePath string cert *tls.Certificate + lock sync.RWMutex } func NewStaticCertManager(tlsCertificateFilePath, tlsPrivateKeyFilePath string) *StaticCertManager { @@ -24,7 +26,16 @@ func NewStaticCertManager(tlsCertificateFilePath, tlsPrivateKeyFilePath string) } func (m *StaticCertManager) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { + m.lock.RLock() if m.cert != nil { + defer m.lock.RUnlock() + return m.cert, nil + } + m.lock.RUnlock() + + m.lock.Lock() + defer m.lock.Unlock() + if m.cert != nil { // Double-check locking return m.cert, nil } diff --git a/internal/server/cert_test.go b/internal/server/cert_test.go index 6599d11..9ce422a 100644 --- a/internal/server/cert_test.go +++ b/internal/server/cert_test.go @@ -38,6 +38,21 @@ func TestCertificateLoading(t *testing.T) { require.NotNil(t, cert) } +func TestCertificateLoadingRaceCondition(t *testing.T) { + certPath, keyPath, err := prepareTestCertificateFiles() + require.NoError(t, err) + defer os.Remove(certPath) + defer os.Remove(keyPath) + + manager := NewStaticCertManager(certPath, keyPath) + go func() { + manager.GetCertificate(&tls.ClientHelloInfo{}) + }() + cert, err := manager.GetCertificate(&tls.ClientHelloInfo{}) + require.NoError(t, err) + require.NotNil(t, cert) +} + func TestCachesLoadedCertificate(t *testing.T) { certPath, keyPath, err := prepareTestCertificateFiles() require.NoError(t, err)