diff --git a/config/config.go b/config/config.go index f3a9e61516f8..d603b2c4313c 100644 --- a/config/config.go +++ b/config/config.go @@ -4,6 +4,7 @@ package config import ( + "context" "crypto/tls" "encoding/base64" "encoding/json" @@ -38,6 +39,7 @@ import ( "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/crypto/bls" "github.com/ava-labs/avalanchego/utils/crypto/bls/signer/localsigner" + "github.com/ava-labs/avalanchego/utils/crypto/bls/signer/rpcsigner" "github.com/ava-labs/avalanchego/utils/ips" "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/perms" @@ -84,6 +86,7 @@ var ( errCannotReadDirectory = errors.New("cannot read directory") errUnmarshalling = errors.New("unmarshalling failed") errFileDoesNotExist = errors.New("file does not exist") + errInvalidSignerConfig = fmt.Errorf("only one of the following flags can be set: %s, %s, %s, %s", StakingEphemeralSignerEnabledKey, StakingSignerKeyContentKey, StakingSignerKeyPathKey, StakingRPCSignerKey) ) func getConsensusConfig(v *viper.Viper) snowball.Parameters { @@ -639,66 +642,95 @@ func getStakingTLSCert(v *viper.Viper) (tls.Certificate, error) { } } -func getStakingSigner(v *viper.Viper) (bls.Signer, error) { - if v.GetBool(StakingEphemeralSignerEnabledKey) { - key, err := localsigner.New() +func getStakingSigner(ctx context.Context, v *viper.Viper) (bls.Signer, error) { + ephemeralSignerEnabled := v.GetBool(StakingEphemeralSignerEnabledKey) + contentKeyIsSet := v.IsSet(StakingSignerKeyContentKey) + keyPathIsSet := v.IsSet(StakingSignerKeyPathKey) + rpcSignerURLIsSet := v.IsSet(StakingRPCSignerKey) + + signingKeyPath := getExpandedArg(v, StakingSignerKeyPathKey) + + switch { + case ephemeralSignerEnabled && !contentKeyIsSet && !keyPathIsSet && !rpcSignerURLIsSet: + signer, err := localsigner.New() if err != nil { - return nil, fmt.Errorf("couldn't generate ephemeral signing key: %w", err) + return nil, fmt.Errorf("couldn't generate ephemeral signing signer: %w", err) } - return key, nil - } - if v.IsSet(StakingSignerKeyContentKey) { + return signer, nil + + case !ephemeralSignerEnabled && contentKeyIsSet && !keyPathIsSet && !rpcSignerURLIsSet: signerKeyRawContent := v.GetString(StakingSignerKeyContentKey) signerKeyContent, err := base64.StdEncoding.DecodeString(signerKeyRawContent) if err != nil { return nil, fmt.Errorf("unable to decode base64 content: %w", err) } - key, err := localsigner.FromBytes(signerKeyContent) + + signer, err := localsigner.FromBytes(signerKeyContent) if err != nil { return nil, fmt.Errorf("couldn't parse signing key: %w", err) } - return key, nil - } - signingKeyPath := getExpandedArg(v, StakingSignerKeyPathKey) - _, err := os.Stat(signingKeyPath) - if !errors.Is(err, fs.ErrNotExist) { + return signer, nil + + case !ephemeralSignerEnabled && !contentKeyIsSet && keyPathIsSet && !rpcSignerURLIsSet: + // If the key is set, but a user-file isn't provided, we don't create one. + // The siging key is only stored to the default file-location if it's created + // and saved by the current application run. + _, err := os.Stat(signingKeyPath) + + if errors.Is(err, fs.ErrNotExist) { + return nil, errMissingStakingSigningKeyFile + } + signingKeyBytes, err := os.ReadFile(signingKeyPath) if err != nil { return nil, err } - key, err := localsigner.FromBytes(signingKeyBytes) + + signer, err := localsigner.FromBytes(signingKeyBytes) if err != nil { return nil, fmt.Errorf("couldn't parse signing key: %w", err) } - return key, nil - } - if v.IsSet(StakingSignerKeyPathKey) { - return nil, errMissingStakingSigningKeyFile - } + return signer, nil - key, err := localsigner.New() - if err != nil { - return nil, fmt.Errorf("couldn't generate new signing key: %w", err) - } + case !ephemeralSignerEnabled && !contentKeyIsSet && !keyPathIsSet && rpcSignerURLIsSet: + rpcSignerURL := v.GetString(StakingRPCSignerKey) - if err := os.MkdirAll(filepath.Dir(signingKeyPath), perms.ReadWriteExecute); err != nil { - return nil, fmt.Errorf("couldn't create path for signing key at %s: %w", signingKeyPath, err) - } + signer, err := rpcsigner.NewClient(ctx, rpcSignerURL) + if err != nil { + return nil, fmt.Errorf("couldn't create rpc signer client: %w", err) + } - keyBytes := key.ToBytes() - if err := os.WriteFile(signingKeyPath, keyBytes, perms.ReadWrite); err != nil { - return nil, fmt.Errorf("couldn't write new signing key to %s: %w", signingKeyPath, err) - } - if err := os.Chmod(signingKeyPath, perms.ReadOnly); err != nil { - return nil, fmt.Errorf("couldn't restrict permissions on new signing key at %s: %w", signingKeyPath, err) + return signer, nil + + case ephemeralSignerEnabled || contentKeyIsSet || keyPathIsSet || rpcSignerURLIsSet: + return nil, errInvalidSignerConfig + default: + signer, err := localsigner.New() + if err != nil { + return nil, fmt.Errorf("couldn't generate new signing key: %w", err) + } + + if err := os.MkdirAll(filepath.Dir(signingKeyPath), perms.ReadWriteExecute); err != nil { + return nil, fmt.Errorf("couldn't create path for signing key at %s: %w", signingKeyPath, err) + } + + keyBytes := signer.ToBytes() + if err := os.WriteFile(signingKeyPath, keyBytes, perms.ReadWrite); err != nil { + return nil, fmt.Errorf("couldn't write new signing key to %s: %w", signingKeyPath, err) + } + + if err := os.Chmod(signingKeyPath, perms.ReadOnly); err != nil { + return nil, fmt.Errorf("couldn't restrict permissions on new signing key at %s: %w", signingKeyPath, err) + } + + return signer, nil } - return key, nil } -func getStakingConfig(v *viper.Viper, networkID uint32) (node.StakingConfig, error) { +func getStakingConfig(ctx context.Context, v *viper.Viper, networkID uint32) (node.StakingConfig, error) { config := node.StakingConfig{ SybilProtectionEnabled: v.GetBool(SybilProtectionEnabledKey), SybilProtectionDisabledWeight: v.GetUint64(SybilProtectionDisabledWeightKey), @@ -706,6 +738,7 @@ func getStakingConfig(v *viper.Viper, networkID uint32) (node.StakingConfig, err StakingKeyPath: getExpandedArg(v, StakingTLSKeyPathKey), StakingCertPath: getExpandedArg(v, StakingCertPathKey), StakingSignerPath: getExpandedArg(v, StakingSignerKeyPathKey), + StakingSignerRPC: getExpandedArg(v, StakingRPCSignerKey), } if !config.SybilProtectionEnabled && config.SybilProtectionDisabledWeight == 0 { return node.StakingConfig{}, errSybilProtectionDisabledStakerWeights @@ -720,7 +753,7 @@ func getStakingConfig(v *viper.Viper, networkID uint32) (node.StakingConfig, err if err != nil { return node.StakingConfig{}, err } - config.StakingSigningKey, err = getStakingSigner(v) + config.StakingSigningKey, err = getStakingSigner(ctx, v) if err != nil { return node.StakingConfig{}, err } @@ -1244,7 +1277,7 @@ func getPluginDir(v *viper.Viper) (string, error) { return pluginDir, nil } -func GetNodeConfig(v *viper.Viper) (node.Config, error) { +func GetNodeConfig(ctx context.Context, v *viper.Viper) (node.Config, error) { var ( nodeConfig node.Config err error @@ -1299,7 +1332,7 @@ func GetNodeConfig(v *viper.Viper) (node.Config, error) { } // Staking - nodeConfig.StakingConfig, err = getStakingConfig(v, nodeConfig.NetworkID) + nodeConfig.StakingConfig, err = getStakingConfig(ctx, v, nodeConfig.NetworkID) if err != nil { return node.Config{}, err } diff --git a/config/config_test.go b/config/config_test.go index 68847ca4f6d5..53beec8214ce 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -4,22 +4,30 @@ package config import ( + "context" "encoding/base64" "encoding/json" "fmt" "log" + "net" "os" "path/filepath" + "reflect" "testing" "github.com/spf13/pflag" "github.com/spf13/viper" "github.com/stretchr/testify/require" + "google.golang.org/grpc" "github.com/ava-labs/avalanchego/chains" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/proto/pb/signer" "github.com/ava-labs/avalanchego/snow/consensus/snowball" "github.com/ava-labs/avalanchego/subnets" + "github.com/ava-labs/avalanchego/utils/crypto/bls/signer/localsigner" + "github.com/ava-labs/avalanchego/utils/crypto/bls/signer/rpcsigner" + "github.com/ava-labs/avalanchego/utils/perms" ) const chainConfigFilenameExtention = ".ex" @@ -541,6 +549,108 @@ func TestGetSubnetConfigsFromFlags(t *testing.T) { } } +type signerServer struct { + signer.UnimplementedSignerServer +} + +func (*signerServer) PublicKey(context.Context, *signer.PublicKeyRequest) (*signer.PublicKeyResponse, error) { + // for tests to pass, this must be the base64 encoding of a 32 byte public key + // but it does not need to be associated with any private key + bytes, err := base64.StdEncoding.DecodeString("j8Ndzc1I6EYWYUWAdhcwpQ1I2xX/i4fdwgJIaxbHlf9yQKMT0jlReiiLYsydgaS1") + if err != nil { + return nil, err + } + + return &signer.PublicKeyResponse{ + PublicKey: bytes, + }, nil +} + +func TestGetStakingSigner(t *testing.T) { + testKey := "HLimS3vRibTMk9lZD4b+Z+GLuSBShvgbsu0WTLt2Kd4=" + rpcServer := grpc.NewServer() + defer rpcServer.GracefulStop() + + signer.RegisterSignerServer(rpcServer, &signerServer{}) + + listener, err := net.Listen("tcp", "[::1]:0") + require.NoError(t, err) + + go func() { + require.NoError(t, rpcServer.Serve(listener)) + }() + + type config map[string]any + + tests := []struct { + name string + viperKeys string + config config + expectedSignerType reflect.Type + expectedErr error + }{ + { + name: "default-signer", + expectedSignerType: reflect.TypeOf(&localsigner.LocalSigner{}), + }, + { + name: "ephemeral-signer", + config: config{StakingEphemeralSignerEnabledKey: true}, + expectedSignerType: reflect.TypeOf(&localsigner.LocalSigner{}), + }, + { + name: "content-key", + config: config{StakingSignerKeyContentKey: testKey}, + expectedSignerType: reflect.TypeOf(&localsigner.LocalSigner{}), + }, + { + name: "file-key", + config: config{ + StakingSignerKeyPathKey: func() string { + filePath := filepath.Join(t.TempDir(), "signer.key") + bytes, err := base64.StdEncoding.DecodeString(testKey) + require.NoError(t, err) + require.NoError(t, os.WriteFile(filePath, bytes, perms.ReadWrite)) + return filePath + }(), + }, + expectedSignerType: reflect.TypeOf(&localsigner.LocalSigner{}), + }, + { + name: "rpc-signer", + config: config{StakingRPCSignerKey: listener.Addr().String()}, + expectedSignerType: reflect.TypeOf(&rpcsigner.Client{}), + }, + { + name: "multiple-configurations-set", + config: config{ + StakingEphemeralSignerEnabledKey: true, + StakingSignerKeyContentKey: testKey, + }, + expectedErr: errInvalidSignerConfig, + }, + } + + // required for proper write permissions for the default signer-key location + t.Setenv("HOME", t.TempDir()) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + v := setupViperFlags() + + for key, value := range tt.config { + v.Set(key, value) + } + + signer, err := getStakingSigner(context.Background(), v) + + require.ErrorIs(err, tt.expectedErr) + require.Equal(tt.expectedSignerType, reflect.TypeOf(signer)) + }) + } +} + // setups config json file and writes content func setupConfigJSON(t *testing.T, rootPath string, value string) string { configFilePath := filepath.Join(rootPath, "config.json") diff --git a/config/flags.go b/config/flags.go index 193d1343d559..e08ccd488444 100644 --- a/config/flags.go +++ b/config/flags.go @@ -271,6 +271,7 @@ func addNodeFlags(fs *pflag.FlagSet) { fs.Bool(StakingEphemeralSignerEnabledKey, false, "If true, the node uses an ephemeral staking signer key") fs.String(StakingSignerKeyPathKey, defaultStakingSignerKeyPath, fmt.Sprintf("Path to the signer private key for staking. Ignored if %s is specified", StakingSignerKeyContentKey)) fs.String(StakingSignerKeyContentKey, "", "Specifies base64 encoded signer private key for staking") + fs.String(StakingRPCSignerKey, "", "Specifies the RPC endpoint of the staking signer") fs.Bool(SybilProtectionEnabledKey, true, "Enables sybil protection. If enabled, Network TLS is required") fs.Uint64(SybilProtectionDisabledWeightKey, 100, "Weight to provide to each peer when sybil protection is disabled") fs.Bool(PartialSyncPrimaryNetworkKey, false, "Only sync the P-chain on the Primary Network. If the node is a Primary Network validator, it will report unhealthy") diff --git a/config/keys.go b/config/keys.go index 760bee97fd77..32dc6edad831 100644 --- a/config/keys.go +++ b/config/keys.go @@ -85,6 +85,7 @@ const ( StakingEphemeralSignerEnabledKey = "staking-ephemeral-signer-enabled" StakingSignerKeyPathKey = "staking-signer-key-file" StakingSignerKeyContentKey = "staking-signer-key-file-content" + StakingRPCSignerKey = "staking-rpc-signer" SybilProtectionEnabledKey = "sybil-protection-enabled" SybilProtectionDisabledWeightKey = "sybil-protection-disabled-weight" NetworkInitialTimeoutKey = "network-initial-timeout" diff --git a/config/node/config.go b/config/node/config.go index a97c3b84df2e..8abef8328988 100644 --- a/config/node/config.go +++ b/config/node/config.go @@ -82,6 +82,7 @@ type StakingConfig struct { StakingKeyPath string `json:"stakingKeyPath"` StakingCertPath string `json:"stakingCertPath"` StakingSignerPath string `json:"stakingSignerPath"` + StakingSignerRPC string `json:"stakingSignerRpc"` } type StateSyncConfig struct { diff --git a/main/main.go b/main/main.go index bdd4a5d83b23..ced58a9854f2 100644 --- a/main/main.go +++ b/main/main.go @@ -4,10 +4,12 @@ package main import ( + "context" "encoding/json" "errors" "fmt" "os" + "time" "github.com/spf13/pflag" "golang.org/x/term" @@ -52,7 +54,9 @@ func main() { os.Exit(0) } - nodeConfig, err := config.GetNodeConfig(v) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + nodeConfig, err := config.GetNodeConfig(ctx, v) + cancel() if err != nil { fmt.Printf("couldn't load node config: %s\n", err) os.Exit(1) diff --git a/utils/crypto/bls/signer/rpcsigner/client.go b/utils/crypto/bls/signer/rpcsigner/client.go index 294bb1829d81..af0a5eda150c 100644 --- a/utils/crypto/bls/signer/rpcsigner/client.go +++ b/utils/crypto/bls/signer/rpcsigner/client.go @@ -5,8 +5,11 @@ package rpcsigner import ( "context" + "fmt" "google.golang.org/grpc" + "google.golang.org/grpc/backoff" + "google.golang.org/grpc/credentials/insecure" "github.com/ava-labs/avalanchego/utils/crypto/bls" @@ -21,7 +24,24 @@ type Client struct { pk *bls.PublicKey } -func NewClient(ctx context.Context, conn *grpc.ClientConn) (*Client, error) { +func NewClient(ctx context.Context, rpcSignerURL string) (*Client, error) { + // TODO: figure out the best parameters here given the target block-time + opts := grpc.WithConnectParams(grpc.ConnectParams{ + Backoff: backoff.DefaultConfig, + }) + + // the rpc-signer client should call a proxy server (on the same machine) that forwards + // the request to the actual signer instead of relying on tls-credentials + conn, err := grpc.NewClient(rpcSignerURL, opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, fmt.Errorf("couldn't create rpc signer client: %w", err) + } + defer func() { + if err != nil { + conn.Close() + } + }() + client := pb.NewSignerClient(conn) pubkeyResponse, err := client.PublicKey(ctx, &pb.PublicKeyRequest{}) @@ -46,22 +66,58 @@ func (c *Client) PublicKey() *bls.PublicKey { return c.pk } +// Sign a message. The [Client] already handles transient connection errors. If this method fails, it will +// render the client in an unusable state and the client should be discarded. func (c *Client) Sign(message []byte) (*bls.Signature, error) { + var err error + defer func() { + if err != nil { + c.Close() + } + }() + resp, err := c.client.Sign(context.TODO(), &pb.SignRequest{Message: message}) if err != nil { return nil, err } - signature := resp.GetSignature() - return bls.SignatureFromBytes(signature) + sigBytes := resp.GetSignature() + sig, err := bls.SignatureFromBytes(sigBytes) + if err != nil { + return nil, err + } + + return sig, nil } +// [SignProofOfPossession] has the same behavior as [Sign] but will product a different signature. +// See BLS spec for more details. func (c *Client) SignProofOfPossession(message []byte) (*bls.Signature, error) { + var err error + defer func() { + if err != nil { + c.Close() + } + }() + resp, err := c.client.SignProofOfPossession(context.TODO(), &pb.SignProofOfPossessionRequest{Message: message}) if err != nil { return nil, err } - signature := resp.GetSignature() - return bls.SignatureFromBytes(signature) + sigBytes := resp.GetSignature() + sig, err := bls.SignatureFromBytes(sigBytes) + if err != nil { + return nil, err + } + + return sig, nil +} + +func (c *Client) Close() error { + if c.conn == nil { + return nil + } + + return c.conn.Close() }