Skip to content

Commit

Permalink
certs
Browse files Browse the repository at this point in the history
  • Loading branch information
kdudkov committed Oct 28, 2024
1 parent 4f3004a commit 77d06d4
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 57 deletions.
36 changes: 21 additions & 15 deletions cmd/goatak_server/cert_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"log/slog"
"math/big"
Expand Down Expand Up @@ -75,31 +76,35 @@ func getTLSConfigHandler(app *App) fiber.Handler {
}

func signClientCert(uid string, clientCSR *x509.CertificateRequest, serverCert *x509.Certificate, privateKey crypto.PrivateKey, days int) (*x509.Certificate, error) {
tpl := getCertTemplate(&serverCert.Subject, clientCSR, uid, days)

certBytes, err := x509.CreateCertificate(rand.Reader, tpl, serverCert, clientCSR.PublicKey, privateKey)
if err != nil {
return nil, fmt.Errorf("failed to generate certificate, error: %w", err)
}

return x509.ParseCertificate(certBytes)
}

func getCertTemplate(issuer *pkix.Name, csr *x509.CertificateRequest, uid string, days int) *x509.Certificate {
serialNumber, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))

template := x509.Certificate{
Signature: clientCSR.Signature,
SignatureAlgorithm: clientCSR.SignatureAlgorithm,
return &x509.Certificate{
Signature: csr.Signature,
SignatureAlgorithm: csr.SignatureAlgorithm,

PublicKeyAlgorithm: clientCSR.PublicKeyAlgorithm,
PublicKey: clientCSR.PublicKey,
PublicKeyAlgorithm: csr.PublicKeyAlgorithm,
PublicKey: csr.PublicKey,

SerialNumber: serialNumber,
Issuer: serverCert.Subject,
Subject: clientCSR.Subject,
Issuer: *issuer,
Subject: csr.Subject,
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Duration(days*24) * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
EmailAddresses: []string{uid},
}

certBytes, err := x509.CreateCertificate(rand.Reader, &template, serverCert, clientCSR.PublicKey, privateKey)
if err != nil {
return nil, fmt.Errorf("failed to generate certificate, error: %w", err)
}

return x509.ParseCertificate(certBytes)
}

func (app *App) processSignRequest(ctx *fiber.Ctx) (*x509.Certificate, error) {
Expand Down Expand Up @@ -150,8 +155,9 @@ func getSignHandler(app *App) fiber.Handler {
}

certs := map[string]*x509.Certificate{"signedCert": signedCert}
certs["ca0"] = app.config.serverCert
for i, c := range app.config.ca {
certs[fmt.Sprintf("ca%d", i)] = c
certs[fmt.Sprintf("ca%d", i+1)] = c
}

p12Bytes, err := tlsutil.MakeP12TrustStore(certs, p12Password)
Expand Down
12 changes: 3 additions & 9 deletions cmd/goatak_server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,12 @@ func (c *AppConfig) processCerts() error {
}
}

roots := x509.NewCertPool()
c.certPool = roots

ca, err := loadPem(c.k.String("ssl.ca"))
if err != nil {
return err
}

for _, c := range ca {
roots.AddCert(c)
}

c.certPool = tlsutil.MakeCertPool(ca...)
c.ca = ca

cert, err := loadPem(c.k.String("ssl.cert"))
Expand All @@ -138,8 +132,8 @@ func (c *AppConfig) processCerts() error {
c.serverCert = cert[0]
}

for _, c := range cert {
roots.AddCert(c)
for _, crt := range cert {
c.certPool.AddCert(crt)
}

tlsCert, err := tls.LoadX509KeyPair(c.k.String("ssl.cert"), c.k.String("ssl.key"))
Expand Down
62 changes: 29 additions & 33 deletions pkg/tlsutil/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (

const cr = 10

func ParseBlock(b []byte, typ string) *pem.Block {
func ParseBlock(typ string, b []byte) *pem.Block {
bb := bytes.Buffer{}
bb.WriteString(fmt.Sprintf("-----BEGIN %s-----\n", typ))
bb.Write(b)
Expand All @@ -24,13 +24,13 @@ func ParseBlock(b []byte, typ string) *pem.Block {
}

func ParseCert(s string) (*x509.Certificate, error) {
csrBlock := ParseBlock([]byte(s), "CERTIFICATE")
csrBlock := ParseBlock("CERTIFICATE", []byte(s))

return x509.ParseCertificate(csrBlock.Bytes)
}

func ParseCsr(b []byte) (*x509.CertificateRequest, error) {
csrBlock := ParseBlock(b, "REQUEST")
csrBlock := ParseBlock("REQUEST", b)

return x509.ParseCertificateRequest(csrBlock.Bytes)
}
Expand Down Expand Up @@ -78,32 +78,6 @@ func MakeCertPool(certs ...*x509.Certificate) *x509.CertPool {
return cp
}

func LogCert(logger *slog.Logger, name string, cert *x509.Certificate) {
if cert == nil {
logger.Error("no cert for " + name)

return
}

logger.Info(fmt.Sprintf("%s sn: %x", name, cert.SerialNumber))
logger.Info(fmt.Sprintf("%s subject: %s", name, cert.Subject.String()))
logger.Info(fmt.Sprintf("%s issuer: %s", name, cert.Issuer.String()))
logger.Info(fmt.Sprintf("%s valid till %s", name, cert.NotAfter))

if len(cert.DNSNames) > 0 {
logger.Info(fmt.Sprintf("%s dns_names: %s", name, strings.Join(cert.DNSNames, ",")))
}

if len(cert.IPAddresses) > 0 {
ip1 := make([]string, len(cert.IPAddresses))
for i, ip := range cert.IPAddresses {
ip1[i] = ip.String()
}

logger.Info(fmt.Sprintf("%s ip_addresses: %s", name, strings.Join(ip1, ",")))
}
}

func DecodeAllCerts(bytes []byte) ([]*x509.Certificate, error) {
return DecodeAllByType("CERTIFICATE", bytes)
}
Expand All @@ -129,10 +103,6 @@ func DecodeAllByType(typ string, bytes []byte) ([]*x509.Certificate, error) {
}
}

if len(certs) == 0 {
return nil, fmt.Errorf("no %s in found", typ)
}

return certs, nil
}

Expand All @@ -141,3 +111,29 @@ func LogCerts(logger *slog.Logger, certs ...*x509.Certificate) {
LogCert(logger, fmt.Sprintf("cert #%d", i), c)
}
}

func LogCert(logger *slog.Logger, name string, cert *x509.Certificate) {
if cert == nil {
logger.Error("no cert for " + name)

return
}

logger.Info(fmt.Sprintf("%s sn: %x", name, cert.SerialNumber))
logger.Info(fmt.Sprintf("%s subject: %s", name, cert.Subject.String()))
logger.Info(fmt.Sprintf("%s issuer: %s", name, cert.Issuer.String()))
logger.Info(fmt.Sprintf("%s valid till %s", name, cert.NotAfter))

if len(cert.DNSNames) > 0 {
logger.Info(fmt.Sprintf("%s dns_names: %s", name, strings.Join(cert.DNSNames, ",")))
}

if len(cert.IPAddresses) > 0 {
ip1 := make([]string, len(cert.IPAddresses))
for i, ip := range cert.IPAddresses {
ip1[i] = ip.String()
}

logger.Info(fmt.Sprintf("%s ip_addresses: %s", name, strings.Join(ip1, ",")))
}
}

0 comments on commit 77d06d4

Please sign in to comment.