Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RPC-signer configuration #3725

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 70 additions & 37 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package config

import (
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
Expand Down Expand Up @@ -38,6 +39,7 @@ import (
"github.com/ava-labs/avalanchego/utils/constants"
"github.com/ava-labs/avalanchego/utils/crypto/bls"
"github.com/ava-labs/avalanchego/utils/crypto/bls/signer/localsigner"
"github.com/ava-labs/avalanchego/utils/crypto/bls/signer/rpcsigner"
"github.com/ava-labs/avalanchego/utils/ips"
"github.com/ava-labs/avalanchego/utils/logging"
"github.com/ava-labs/avalanchego/utils/perms"
Expand Down Expand Up @@ -84,6 +86,7 @@ var (
errCannotReadDirectory = errors.New("cannot read directory")
errUnmarshalling = errors.New("unmarshalling failed")
errFileDoesNotExist = errors.New("file does not exist")
errInvalidSignerConfig = fmt.Errorf("only one of the following flags can be set: %s, %s, %s, %s", StakingEphemeralSignerEnabledKey, StakingSignerKeyContentKey, StakingSignerKeyPathKey, StakingRPCSignerKey)
)

func getConsensusConfig(v *viper.Viper) snowball.Parameters {
Expand Down Expand Up @@ -639,73 +642,103 @@ func getStakingTLSCert(v *viper.Viper) (tls.Certificate, error) {
}
}

func getStakingSigner(v *viper.Viper) (bls.Signer, error) {
if v.GetBool(StakingEphemeralSignerEnabledKey) {
key, err := localsigner.New()
func getStakingSigner(ctx context.Context, v *viper.Viper) (bls.Signer, error) {
ephemeralSignerEnabled := v.GetBool(StakingEphemeralSignerEnabledKey)
contentKeyIsSet := v.IsSet(StakingSignerKeyContentKey)
keyPathIsSet := v.IsSet(StakingSignerKeyPathKey)
rpcSignerURLIsSet := v.IsSet(StakingRPCSignerKey)

signingKeyPath := getExpandedArg(v, StakingSignerKeyPathKey)

switch {
case ephemeralSignerEnabled && !contentKeyIsSet && !keyPathIsSet && !rpcSignerURLIsSet:
signer, err := localsigner.New()
if err != nil {
return nil, fmt.Errorf("couldn't generate ephemeral signing key: %w", err)
return nil, fmt.Errorf("couldn't generate ephemeral signing signer: %w", err)
}
return key, nil
}

if v.IsSet(StakingSignerKeyContentKey) {
return signer, nil

case !ephemeralSignerEnabled && contentKeyIsSet && !keyPathIsSet && !rpcSignerURLIsSet:
signerKeyRawContent := v.GetString(StakingSignerKeyContentKey)
signerKeyContent, err := base64.StdEncoding.DecodeString(signerKeyRawContent)
if err != nil {
return nil, fmt.Errorf("unable to decode base64 content: %w", err)
}
key, err := localsigner.FromBytes(signerKeyContent)

signer, err := localsigner.FromBytes(signerKeyContent)
if err != nil {
return nil, fmt.Errorf("couldn't parse signing key: %w", err)
}
return key, nil
}

signingKeyPath := getExpandedArg(v, StakingSignerKeyPathKey)
_, err := os.Stat(signingKeyPath)
if !errors.Is(err, fs.ErrNotExist) {
return signer, nil

case !ephemeralSignerEnabled && !contentKeyIsSet && keyPathIsSet && !rpcSignerURLIsSet:
// If the key is set, but a user-file isn't provided, we don't create one.
// The siging key is only stored to the default file-location if it's created
// and saved by the current application run.
_, err := os.Stat(signingKeyPath)

if errors.Is(err, fs.ErrNotExist) {
return nil, errMissingStakingSigningKeyFile
}

signingKeyBytes, err := os.ReadFile(signingKeyPath)
if err != nil {
return nil, err
}
key, err := localsigner.FromBytes(signingKeyBytes)

signer, err := localsigner.FromBytes(signingKeyBytes)
if err != nil {
return nil, fmt.Errorf("couldn't parse signing key: %w", err)
}
return key, nil
}

if v.IsSet(StakingSignerKeyPathKey) {
return nil, errMissingStakingSigningKeyFile
}
return signer, nil

key, err := localsigner.New()
if err != nil {
return nil, fmt.Errorf("couldn't generate new signing key: %w", err)
}
case !ephemeralSignerEnabled && !contentKeyIsSet && !keyPathIsSet && rpcSignerURLIsSet:
rpcSignerURL := v.GetString(StakingRPCSignerKey)

if err := os.MkdirAll(filepath.Dir(signingKeyPath), perms.ReadWriteExecute); err != nil {
return nil, fmt.Errorf("couldn't create path for signing key at %s: %w", signingKeyPath, err)
}
signer, err := rpcsigner.NewClient(ctx, rpcSignerURL)
if err != nil {
return nil, fmt.Errorf("couldn't create rpc signer client: %w", err)
}

keyBytes := key.ToBytes()
if err := os.WriteFile(signingKeyPath, keyBytes, perms.ReadWrite); err != nil {
return nil, fmt.Errorf("couldn't write new signing key to %s: %w", signingKeyPath, err)
}
if err := os.Chmod(signingKeyPath, perms.ReadOnly); err != nil {
return nil, fmt.Errorf("couldn't restrict permissions on new signing key at %s: %w", signingKeyPath, err)
return signer, nil

case ephemeralSignerEnabled || contentKeyIsSet || keyPathIsSet || rpcSignerURLIsSet:
return nil, errInvalidSignerConfig
default:
signer, err := localsigner.New()
if err != nil {
return nil, fmt.Errorf("couldn't generate new signing key: %w", err)
}

if err := os.MkdirAll(filepath.Dir(signingKeyPath), perms.ReadWriteExecute); err != nil {
return nil, fmt.Errorf("couldn't create path for signing key at %s: %w", signingKeyPath, err)
}

keyBytes := signer.ToBytes()
if err := os.WriteFile(signingKeyPath, keyBytes, perms.ReadWrite); err != nil {
return nil, fmt.Errorf("couldn't write new signing key to %s: %w", signingKeyPath, err)
}

if err := os.Chmod(signingKeyPath, perms.ReadOnly); err != nil {
return nil, fmt.Errorf("couldn't restrict permissions on new signing key at %s: %w", signingKeyPath, err)
}

return signer, nil
}
return key, nil
}

func getStakingConfig(v *viper.Viper, networkID uint32) (node.StakingConfig, error) {
func getStakingConfig(ctx context.Context, v *viper.Viper, networkID uint32) (node.StakingConfig, error) {
config := node.StakingConfig{
SybilProtectionEnabled: v.GetBool(SybilProtectionEnabledKey),
SybilProtectionDisabledWeight: v.GetUint64(SybilProtectionDisabledWeightKey),
PartialSyncPrimaryNetwork: v.GetBool(PartialSyncPrimaryNetworkKey),
StakingKeyPath: getExpandedArg(v, StakingTLSKeyPathKey),
StakingCertPath: getExpandedArg(v, StakingCertPathKey),
StakingSignerPath: getExpandedArg(v, StakingSignerKeyPathKey),
StakingSignerRPC: getExpandedArg(v, StakingRPCSignerKey),
}
if !config.SybilProtectionEnabled && config.SybilProtectionDisabledWeight == 0 {
return node.StakingConfig{}, errSybilProtectionDisabledStakerWeights
Expand All @@ -720,7 +753,7 @@ func getStakingConfig(v *viper.Viper, networkID uint32) (node.StakingConfig, err
if err != nil {
return node.StakingConfig{}, err
}
config.StakingSigningKey, err = getStakingSigner(v)
config.StakingSigningKey, err = getStakingSigner(ctx, v)
if err != nil {
return node.StakingConfig{}, err
}
Expand Down Expand Up @@ -1244,7 +1277,7 @@ func getPluginDir(v *viper.Viper) (string, error) {
return pluginDir, nil
}

func GetNodeConfig(v *viper.Viper) (node.Config, error) {
func GetNodeConfig(ctx context.Context, v *viper.Viper) (node.Config, error) {
var (
nodeConfig node.Config
err error
Expand Down Expand Up @@ -1299,7 +1332,7 @@ func GetNodeConfig(v *viper.Viper) (node.Config, error) {
}

// Staking
nodeConfig.StakingConfig, err = getStakingConfig(v, nodeConfig.NetworkID)
nodeConfig.StakingConfig, err = getStakingConfig(ctx, v, nodeConfig.NetworkID)
if err != nil {
return node.Config{}, err
}
Expand Down
110 changes: 110 additions & 0 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,30 @@
package config

import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"log"
"net"
"os"
"path/filepath"
"reflect"
"testing"

"github.com/spf13/pflag"
"github.com/spf13/viper"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"

"github.com/ava-labs/avalanchego/chains"
"github.com/ava-labs/avalanchego/ids"
"github.com/ava-labs/avalanchego/proto/pb/signer"
"github.com/ava-labs/avalanchego/snow/consensus/snowball"
"github.com/ava-labs/avalanchego/subnets"
"github.com/ava-labs/avalanchego/utils/crypto/bls/signer/localsigner"
"github.com/ava-labs/avalanchego/utils/crypto/bls/signer/rpcsigner"
"github.com/ava-labs/avalanchego/utils/perms"
)

const chainConfigFilenameExtention = ".ex"
Expand Down Expand Up @@ -541,6 +549,108 @@ func TestGetSubnetConfigsFromFlags(t *testing.T) {
}
}

type signerServer struct {
signer.UnimplementedSignerServer
}

func (*signerServer) PublicKey(context.Context, *signer.PublicKeyRequest) (*signer.PublicKeyResponse, error) {
// for tests to pass, this must be the base64 encoding of a 32 byte public key
// but it does not need to be associated with any private key
bytes, err := base64.StdEncoding.DecodeString("j8Ndzc1I6EYWYUWAdhcwpQ1I2xX/i4fdwgJIaxbHlf9yQKMT0jlReiiLYsydgaS1")
if err != nil {
return nil, err
}

return &signer.PublicKeyResponse{
PublicKey: bytes,
}, nil
}

func TestGetStakingSigner(t *testing.T) {
testKey := "HLimS3vRibTMk9lZD4b+Z+GLuSBShvgbsu0WTLt2Kd4="
rpcServer := grpc.NewServer()
defer rpcServer.GracefulStop()

signer.RegisterSignerServer(rpcServer, &signerServer{})

listener, err := net.Listen("tcp", "[::1]:0")
require.NoError(t, err)

go func() {
require.NoError(t, rpcServer.Serve(listener))
}()

type config map[string]any

tests := []struct {
name string
viperKeys string
config config
expectedSignerType reflect.Type
expectedErr error
}{
{
name: "default-signer",
expectedSignerType: reflect.TypeOf(&localsigner.LocalSigner{}),
},
{
name: "ephemeral-signer",
config: config{StakingEphemeralSignerEnabledKey: true},
expectedSignerType: reflect.TypeOf(&localsigner.LocalSigner{}),
},
{
name: "content-key",
config: config{StakingSignerKeyContentKey: testKey},
expectedSignerType: reflect.TypeOf(&localsigner.LocalSigner{}),
},
{
name: "file-key",
config: config{
StakingSignerKeyPathKey: func() string {
filePath := filepath.Join(t.TempDir(), "signer.key")
bytes, err := base64.StdEncoding.DecodeString(testKey)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filePath, bytes, perms.ReadWrite))
return filePath
}(),
},
expectedSignerType: reflect.TypeOf(&localsigner.LocalSigner{}),
},
{
name: "rpc-signer",
config: config{StakingRPCSignerKey: listener.Addr().String()},
expectedSignerType: reflect.TypeOf(&rpcsigner.Client{}),
},
{
name: "multiple-configurations-set",
config: config{
StakingEphemeralSignerEnabledKey: true,
StakingSignerKeyContentKey: testKey,
},
expectedErr: errInvalidSignerConfig,
},
}

// required for proper write permissions for the default signer-key location
t.Setenv("HOME", t.TempDir())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have a defer step to clean up stuff? Or is this overkill?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe t.SetEnv takes care of this for you

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI:

// Setenv calls os.Setenv(key, value) and uses Cleanup to
// restore the environment variable to its original value
// after the test.
//
// Because Setenv affects the whole process, it cannot be used
// in parallel tests or tests with parallel ancestors.
func (t *T) Setenv(key, value string) {
	// Non-parallel subtests that have parallel ancestors may still
	// run in parallel with other tests: they are only non-parallel
	// with respect to the other subtests of the same parent.
	// Since SetEnv affects the whole process, we need to disallow it
	// if the current test or any parent is parallel.
	isParallel := false
	for c := &t.common; c != nil; c = c.parent {
		if c.isParallel {
			isParallel = true
			break
		}
	}
	if isParallel {
		panic("testing: t.Setenv called after t.Parallel; cannot set environment variables in parallel tests")
	}

	t.isEnvSet = true

	t.common.Setenv(key, value)
}


for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require := require.New(t)
v := setupViperFlags()

for key, value := range tt.config {
v.Set(key, value)
}

signer, err := getStakingSigner(context.Background(), v)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we test this function instead of GetNodeConfig?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't change GetNodeConfig aside from taking a context.Context as an argument. These are unit tests; they are testing the most granular unit possible.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is testing code not exported by the actual package though (i.e, getStakingSigner is an implementation detail of the package and not part of its public API) ... shouldn't we test this through GetNodeConfig instead?


require.ErrorIs(err, tt.expectedErr)
require.Equal(tt.expectedSignerType, reflect.TypeOf(signer))
})
}
}

// setups config json file and writes content
func setupConfigJSON(t *testing.T, rootPath string, value string) string {
configFilePath := filepath.Join(rootPath, "config.json")
Expand Down
1 change: 1 addition & 0 deletions config/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ func addNodeFlags(fs *pflag.FlagSet) {
fs.Bool(StakingEphemeralSignerEnabledKey, false, "If true, the node uses an ephemeral staking signer key")
fs.String(StakingSignerKeyPathKey, defaultStakingSignerKeyPath, fmt.Sprintf("Path to the signer private key for staking. Ignored if %s is specified", StakingSignerKeyContentKey))
fs.String(StakingSignerKeyContentKey, "", "Specifies base64 encoded signer private key for staking")
fs.String(StakingRPCSignerKey, "", "Specifies the RPC endpoint of the staking signer")
fs.Bool(SybilProtectionEnabledKey, true, "Enables sybil protection. If enabled, Network TLS is required")
fs.Uint64(SybilProtectionDisabledWeightKey, 100, "Weight to provide to each peer when sybil protection is disabled")
fs.Bool(PartialSyncPrimaryNetworkKey, false, "Only sync the P-chain on the Primary Network. If the node is a Primary Network validator, it will report unhealthy")
Expand Down
1 change: 1 addition & 0 deletions config/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ const (
StakingEphemeralSignerEnabledKey = "staking-ephemeral-signer-enabled"
StakingSignerKeyPathKey = "staking-signer-key-file"
StakingSignerKeyContentKey = "staking-signer-key-file-content"
StakingRPCSignerKey = "staking-rpc-signer"
SybilProtectionEnabledKey = "sybil-protection-enabled"
SybilProtectionDisabledWeightKey = "sybil-protection-disabled-weight"
NetworkInitialTimeoutKey = "network-initial-timeout"
Expand Down
1 change: 1 addition & 0 deletions config/node/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ type StakingConfig struct {
StakingKeyPath string `json:"stakingKeyPath"`
StakingCertPath string `json:"stakingCertPath"`
StakingSignerPath string `json:"stakingSignerPath"`
StakingSignerRPC string `json:"stakingSignerRpc"`
}

type StateSyncConfig struct {
Expand Down
6 changes: 5 additions & 1 deletion main/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
package main

import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"time"

"github.com/spf13/pflag"
"golang.org/x/term"
Expand Down Expand Up @@ -52,7 +54,9 @@ func main() {
os.Exit(0)
}

nodeConfig, err := config.GetNodeConfig(v)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
nodeConfig, err := config.GetNodeConfig(ctx, v)
cancel()
if err != nil {
fmt.Printf("couldn't load node config: %s\n", err)
os.Exit(1)
Expand Down
Loading
Loading