From 78aab8fc1f61b0e7610a955725b76bbd12c8dc33 Mon Sep 17 00:00:00 2001 From: Mingye Chen Date: Sun, 5 Nov 2023 23:24:57 -0500 Subject: [PATCH] Fix issues with cert generation --- pkg/dtls/listener.go | 7 +------ pkg/dtls/seedtocert.go | 35 +++++++++++++++++------------------ 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/pkg/dtls/listener.go b/pkg/dtls/listener.go index 31753f4c..8112205b 100644 --- a/pkg/dtls/listener.go +++ b/pkg/dtls/listener.go @@ -256,10 +256,5 @@ func (l *Listener) getCertificateFromClientHello(clientHello *dtls.ClientHelloIn } func randomCertificate() (*tls.Certificate, error) { - seedBytes := []byte{} - _, err := rand.Read(seedBytes) - if err != nil { - return nil, err - } - return newCertificate(seedBytes) + return newCertificate(rand.Reader) } diff --git a/pkg/dtls/seedtocert.go b/pkg/dtls/seedtocert.go index eac3bbbc..0b995c75 100644 --- a/pkg/dtls/seedtocert.go +++ b/pkg/dtls/seedtocert.go @@ -22,7 +22,7 @@ import ( ) func clientHelloRandomFromSeed(seed []byte) ([handshake.RandomBytesLength]byte, error) { - randSource := hkdf.New(sha256.New, seed, nil, nil) + randSource := hkdf.New(sha256.New, seed, []byte("clientHelloRandomFromSeed"), nil) randomBytes := [handshake.RandomBytesLength]byte{} _, err := io.ReadFull(randSource, randomBytes[:]) @@ -34,19 +34,17 @@ func clientHelloRandomFromSeed(seed []byte) ([handshake.RandomBytesLength]byte, } // getPrivkey creates ECDSA private key used in DTLS Certificates -func getPrivkey(seed []byte) (*ecdsa.PrivateKey, error) { - randSource := hkdf.New(sha256.New, seed, nil, nil) - +func getPrivkey(randSource io.Reader) (*ecdsa.PrivateKey, error) { privkey, err := keygen.ECDSALegacy(elliptic.P256(), randSource) if err != nil { return &ecdsa.PrivateKey{}, err } + return privkey, nil } // getX509Tpl creates x509 template for x509 Certificates generation used in DTLS Certificates. -func getX509Tpl(seed []byte) (*x509.Certificate, error) { - randSource := hkdf.New(sha256.New, seed, nil, nil) +func getX509Tpl(randSource io.Reader) (*x509.Certificate, error) { maxBigInt := new(big.Int) maxBigInt.Exp(big.NewInt(2), big.NewInt(130), nil).Sub(maxBigInt, big.NewInt(1)) @@ -86,20 +84,19 @@ func getX509Tpl(seed []byte) (*x509.Certificate, error) { }, nil } -func newCertificate(seed []byte) (*tls.Certificate, error) { - privkey, err := getPrivkey(seed) +func newCertificate(randSource io.Reader) (*tls.Certificate, error) { + + privkey, err := getPrivkey(randSource) if err != nil { return &tls.Certificate{}, err } - tpl, err := getX509Tpl(seed) + tpl, err := getX509Tpl(randSource) if err != nil { return &tls.Certificate{}, err } - randSource := hkdf.New(sha256.New, seed, nil, nil) - - certDER, err := x509.CreateCertificate(randSource, tpl, tpl, privkey.Public(), privkey) + certDER, err := x509.CreateCertificate(rand.Reader, tpl, tpl, privkey.Public(), privkey) if err != nil { return &tls.Certificate{}, err } @@ -111,15 +108,17 @@ func newCertificate(seed []byte) (*tls.Certificate, error) { } func certsFromSeed(seed []byte) (*tls.Certificate, *tls.Certificate, error) { - clientCert, err := newCertificate(seed) + randSource := hkdf.New(sha256.New, seed, []byte("certsFromSeed"), nil) + + clientCert, err := newCertificate(randSource) if err != nil { return &tls.Certificate{}, &tls.Certificate{}, fmt.Errorf("error generate cert: %v", err) } - // serverCert, err := newCertificate(seed) - // if err != nil { - // return &tls.Certificate{}, &tls.Certificate{}, fmt.Errorf("error generate cert: %v", err) - // } + serverCert, err := newCertificate(randSource) + if err != nil { + return &tls.Certificate{}, &tls.Certificate{}, fmt.Errorf("error generate cert: %v", err) + } - return clientCert, clientCert, nil + return clientCert, serverCert, nil }