Skip to content

Commit

Permalink
cli: support for authenticating with private keys and certificates st…
Browse files Browse the repository at this point in the history
…ored in PKCS #11 backend (#771)

Signed-off-by: Daniel Weiße <[email protected]>
  • Loading branch information
daniel-weisse authored Dec 10, 2024
1 parent 981bdaa commit 62bacea
Show file tree
Hide file tree
Showing 13 changed files with 458 additions and 45 deletions.
42 changes: 36 additions & 6 deletions cli/internal/certcache/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"errors"

"github.com/edgelesssys/marblerun/cli/internal/file"
"github.com/edgelesssys/marblerun/cli/internal/pkcs11"
"github.com/edgelesssys/marblerun/util"
"github.com/spf13/afero"
"github.com/spf13/pflag"
Expand Down Expand Up @@ -43,21 +44,50 @@ func LoadCoordinatorCachedCert(flags *pflag.FlagSet, fs afero.Fs) (root, interme
}

// LoadClientCert parses the command line flags to load a TLS client certificate.
func LoadClientCert(flags *pflag.FlagSet) (*tls.Certificate, error) {
// The returned cancel function must be called only after the certificate is no longer needed.
func LoadClientCert(flags *pflag.FlagSet) (crt *tls.Certificate, cancel func() error, err error) {
certFile, err := flags.GetString("cert")
if err != nil {
return nil, err
return nil, nil, err
}
keyFile, err := flags.GetString("key")
if err != nil {
return nil, err
return nil, nil, err
}

pkcs11ConfigFile, err := flags.GetString("pkcs11-config")
if err != nil {
return nil, nil, err
}
pkcs11KeyID, err := flags.GetString("pkcs11-key-id")
if err != nil {
return nil, nil, err
}
pkcs11KeyLabel, err := flags.GetString("pkcs11-key-label")
if err != nil {
return nil, nil, err
}
clientCert, err := tls.LoadX509KeyPair(certFile, keyFile)
pkcs11CertID, err := flags.GetString("pkcs11-cert-id")
if err != nil {
return nil, err
return nil, nil, err
}
pkcs11CertLabel, err := flags.GetString("pkcs11-cert-label")
if err != nil {
return nil, nil, err
}

var clientCert tls.Certificate
switch {
case pkcs11ConfigFile != "":
clientCert, cancel, err = pkcs11.LoadX509KeyPair(pkcs11ConfigFile, pkcs11KeyID, pkcs11KeyLabel, pkcs11CertID, pkcs11CertLabel)
case certFile != "" && keyFile != "":
clientCert, err = tls.LoadX509KeyPair(certFile, keyFile)
cancel = func() error { return nil }
default:
err = errors.New("neither PKCS#11 nor file-based client certificate can be loaded with the provided flags")
}

return &clientCert, nil
return &clientCert, cancel, err
}

func saveCert(fh *file.Handler, root, intermediate *x509.Certificate) error {
Expand Down
21 changes: 21 additions & 0 deletions cli/internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/edgelesssys/marblerun/cli/internal/kube"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"k8s.io/apimachinery/pkg/util/version"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/tools/clientcmd"
Expand All @@ -32,6 +33,26 @@ func webhookDNSName(namespace string) string {
return "marble-injector." + namespace
}

func addClientAuthFlags(cmd *cobra.Command, flags *pflag.FlagSet) {
flags.StringP("cert", "c", "", "PEM encoded admin certificate file")
flags.StringP("key", "k", "", "PEM encoded admin key file")
cmd.MarkFlagsRequiredTogether("key", "cert")

flags.String("pkcs11-config", "", "Path to a PKCS#11 configuration file to load the client certificate with")
flags.String("pkcs11-key-id", "", "ID of the private key in the PKCS#11 token")
flags.String("pkcs11-key-label", "", "Label of the private key in the PKCS#11 token")
flags.String("pkcs11-cert-id", "", "ID of the certificate in the PKCS#11 token")
flags.String("pkcs11-cert-label", "", "Label of the certificate in the PKCS#11 token")
must(cobra.MarkFlagFilename(flags, "pkcs11-config", "json"))
cmd.MarkFlagsOneRequired("pkcs11-key-id", "pkcs11-key-label", "cert")
cmd.MarkFlagsOneRequired("pkcs11-cert-id", "pkcs11-cert-label", "cert")

cmd.MarkFlagsMutuallyExclusive("pkcs11-config", "cert")
cmd.MarkFlagsMutuallyExclusive("pkcs11-config", "key")
cmd.MarkFlagsOneRequired("pkcs11-config", "cert")
cmd.MarkFlagsOneRequired("pkcs11-config", "key")
}

// parseRestFlags parses the command line flags used to configure the REST client.
func parseRestFlags(cmd *cobra.Command) (api.VerifyOptions, string, error) {
eraConfig, err := cmd.Flags().GetString("era-config")
Expand Down
37 changes: 21 additions & 16 deletions cli/internal/cmd/manifestUpdate.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,7 @@ An admin certificate specified in the original manifest is needed to verify the
Args: cobra.ExactArgs(2),
RunE: runUpdateApply,
}

cmd.Flags().StringP("cert", "c", "", "PEM encoded admin certificate file (required)")
must(cmd.MarkFlagRequired("cert"))
cmd.Flags().StringP("key", "k", "", "PEM encoded admin key file (required)")
must(cmd.MarkFlagRequired("key"))
addClientAuthFlags(cmd, cmd.Flags())

return cmd
}
Expand All @@ -66,11 +62,8 @@ All participants must use the same manifest to acknowledge the pending update.
Args: cobra.ExactArgs(2),
RunE: runUpdateAcknowledge,
}
addClientAuthFlags(cmd, cmd.Flags())

cmd.Flags().StringP("cert", "c", "", "PEM encoded admin certificate file (required)")
must(cmd.MarkFlagRequired("cert"))
cmd.Flags().StringP("key", "k", "", "PEM encoded admin key file (required)")
must(cmd.MarkFlagRequired("key"))
return cmd
}

Expand All @@ -83,11 +76,8 @@ func newUpdateCancel() *cobra.Command {
Args: cobra.ExactArgs(1),
RunE: runUpdateCancel,
}
addClientAuthFlags(cmd, cmd.Flags())

cmd.Flags().StringP("cert", "c", "", "PEM encoded admin certificate file (required)")
must(cmd.MarkFlagRequired("cert"))
cmd.Flags().StringP("key", "k", "", "PEM encoded admin key file (required)")
must(cmd.MarkFlagRequired("key"))
return cmd
}

Expand Down Expand Up @@ -116,10 +106,15 @@ func runUpdateApply(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
keyPair, err := certcache.LoadClientCert(cmd.Flags())
keyPair, cancel, err := certcache.LoadClientCert(cmd.Flags())
if err != nil {
return err
}
defer func() {
if err := cancel(); err != nil {
cmd.PrintErrf("Failed to close PKCS #11 session: %s\n", err)
}
}()

manifest, err := loadManifestFile(file.New(manifestFile, fs))
if err != nil {
Expand All @@ -142,10 +137,15 @@ func runUpdateAcknowledge(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
keyPair, err := certcache.LoadClientCert(cmd.Flags())
keyPair, cancel, err := certcache.LoadClientCert(cmd.Flags())
if err != nil {
return err
}
defer func() {
if err := cancel(); err != nil {
cmd.PrintErrf("Failed to close PKCS #11 session: %s\n", err)
}
}()

manifest, err := loadManifestFile(file.New(manifestFile, fs))
if err != nil {
Expand Down Expand Up @@ -177,10 +177,15 @@ func runUpdateCancel(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
keyPair, err := certcache.LoadClientCert(cmd.Flags())
keyPair, cancel, err := certcache.LoadClientCert(cmd.Flags())
if err != nil {
return err
}
defer func() {
if err := cancel(); err != nil {
cmd.PrintErrf("Failed to close PKCS #11 session: %s\n", err)
}
}()

if err := api.ManifestUpdateCancel(cmd.Context(), hostname, root, keyPair); err != nil {
return fmt.Errorf("canceling update: %w", err)
Expand Down
6 changes: 1 addition & 5 deletions cli/internal/cmd/secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,7 @@ func NewSecretCmd() *cobra.Command {
Manage secrets for the MarbleRun Coordinator.
Set or retrieve a secret defined in the manifest.`,
}

cmd.PersistentFlags().StringP("cert", "c", "", "PEM encoded MarbleRun user certificate file (required)")
cmd.PersistentFlags().StringP("key", "k", "", "PEM encoded MarbleRun user key file (required)")
must(cmd.MarkPersistentFlagRequired("key"))
must(cmd.MarkPersistentFlagRequired("cert"))
addClientAuthFlags(cmd, cmd.PersistentFlags())

cmd.AddCommand(newSecretSet())
cmd.AddCommand(newSecretGet())
Expand Down
7 changes: 6 additions & 1 deletion cli/internal/cmd/secretGet.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,15 @@ func runSecretGet(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
keyPair, err := certcache.LoadClientCert(cmd.Flags())
keyPair, cancel, err := certcache.LoadClientCert(cmd.Flags())
if err != nil {
return err
}
defer func() {
if err := cancel(); err != nil {
cmd.PrintErrf("Failed to close PKCS #11 session: %s\n", err)
}
}()

getSecrets := func(ctx context.Context) (map[string]manifest.Secret, error) {
return api.SecretGet(ctx, hostname, root, keyPair, secretIDs)
Expand Down
7 changes: 6 additions & 1 deletion cli/internal/cmd/secretSet.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,15 @@ func runSecretSet(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
keyPair, err := certcache.LoadClientCert(cmd.Flags())
keyPair, cancel, err := certcache.LoadClientCert(cmd.Flags())
if err != nil {
return err
}
defer func() {
if err := cancel(); err != nil {
cmd.PrintErrf("Failed to close PKCS #11 session: %s\n", err)
}
}()

if err := api.SecretSet(cmd.Context(), hostname, root, keyPair, newSecrets); err != nil {
return err
Expand Down
82 changes: 82 additions & 0 deletions cli/internal/pkcs11/pkcs11.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: BUSL-1.1
*/

package pkcs11

import (
"crypto"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"

"github.com/ThalesGroup/crypto11"
)

// LoadX509KeyPair loads a [tls.Certificate] using the provided PKCS#11 configuration file.
// The returned cancel function must be called to release the PKCS#11 resources only after the certificate is no longer needed.
func LoadX509KeyPair(pkcs11ConfigPath string, keyID, keyLabel, certID, certLabel string) (crt tls.Certificate, cancel func() error, err error) {
pkcs11, err := crypto11.ConfigureFromFile(pkcs11ConfigPath)
if err != nil {
return crt, nil, err
}
defer func() {
if err != nil {
err = errors.Join(err, pkcs11.Close())
}
}()

var keyIDBytes, keyLabelBytes, certIDBytes, certLabelBytes []byte
if keyID != "" {
keyIDBytes = []byte(keyID)
}
if keyLabel != "" {
keyLabelBytes = []byte(keyLabel)
}
if certID != "" {
certIDBytes = []byte(certID)
}
if certLabel != "" {
certLabelBytes = []byte(certLabel)
}

privateKey, err := loadPrivateKey(pkcs11, keyIDBytes, keyLabelBytes)
if err != nil {
return crt, nil, err
}
cert, err := loadCertificate(pkcs11, certIDBytes, certLabelBytes)
if err != nil {
return crt, nil, err
}

return tls.Certificate{
Certificate: [][]byte{cert.Raw},
PrivateKey: privateKey,
Leaf: cert,
}, pkcs11.Close, nil
}

func loadPrivateKey(pkcs11 *crypto11.Context, id, label []byte) (crypto.Signer, error) {
priv, err := pkcs11.FindKeyPair(id, label)
if err != nil {
return nil, err
}
if priv == nil {
return nil, fmt.Errorf("no key pair found for id \"%s\" and label \"%s\"", id, label)
}
return priv, nil
}

func loadCertificate(pkcs11 *crypto11.Context, id, label []byte) (*x509.Certificate, error) {
cert, err := pkcs11.FindCertificate(id, label, nil)
if err != nil {
return nil, err
}
if cert == nil {
return nil, fmt.Errorf("no certificate found for id \"%s\" and label \"%s\"", id, label)
}
return cert, nil
}
Loading

0 comments on commit 62bacea

Please sign in to comment.