Skip to content

Commit

Permalink
Sync from server repo (947a0b40484)
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt Spilchen committed Nov 28, 2023
1 parent 8802a09 commit 6f31867
Show file tree
Hide file tree
Showing 86 changed files with 1,350 additions and 1,201 deletions.
6 changes: 3 additions & 3 deletions commands/cmd_restart_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
*/
type CmdRestartNodes struct {
CmdBase
restartNodesOptions *vclusterops.VRestartNodesOptions
restartNodesOptions *vclusterops.VStartNodesOptions

// Comma-separated list of vnode=host
vnodeListStr *string
Expand All @@ -28,7 +28,7 @@ func makeCmdRestartNodes() *CmdRestartNodes {

// parser, used to parse command-line flags
newCmd.parser = flag.NewFlagSet("restart_node", flag.ExitOnError)
restartNodesOptions := vclusterops.VRestartNodesOptionsFactory()
restartNodesOptions := vclusterops.VStartNodesOptionsFactory()

// require flags
restartNodesOptions.DBName = newCmd.parser.String("db-name", "", "The name of the database to restart nodes")
Expand Down Expand Up @@ -114,7 +114,7 @@ func (c *CmdRestartNodes) Run(vcc vclusterops.VClusterCommands) error {
options.Config = config

// this is the instruction that will be used by both CLI and operator
err = vcc.VRestartNodes(options)
err = vcc.VStartNodes(options)
if err != nil {
return err
}
Expand Down
102 changes: 88 additions & 14 deletions commands/cmd_scrutinize.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,14 @@ import (
)

const (
// Environment variable names storing paths to PEM text.
// Environment variable names storing name of k8s secret that has NMA cert
secretNameSpaceEnvVar = "NMA_SECRET_NAMESPACE"
secretNameEnvVar = "NMA_SECRET_NAME"

// Environment variable names for locating the NMA certs located in the file system
nmaRootCAPathEnvVar = "NMA_ROOTCA_PATH"
nmaCertPathEnvVar = "NMA_CERT_PATH"
nmaKeyPathEnvVar = "NMA_KEY_PATH"
)

const (
Expand Down Expand Up @@ -137,8 +142,8 @@ func (c *CmdScrutinize) validateParse(logger vlog.Printer) error {
func (c *CmdScrutinize) Analyze(logger vlog.Printer) error {
logger.Info("Called method Analyze()")

// set cert/key values from env k8s
err := c.updateCertTextsFromk8s(logger)
// Read the NMA certs into the options struct
err := c.readNMACerts(logger)
if err != nil {
return err
}
Expand Down Expand Up @@ -224,45 +229,114 @@ func (k8sSecretRetrieverStruct) RetrieveSecret(namespace, secretName string) (ca
return caCertVal, tlsCertVal, tlsKeyVal, nil
}

// updateCertTextsFromk8s retrieves PEM-encoded text of CA certs, the server cert, and
// the server key from kubernetes.
func (c *CmdScrutinize) updateCertTextsFromk8s(logger vlog.Printer) error {
func (c *CmdScrutinize) readNMACerts(logger vlog.Printer) error {
loaderFuncs := []func(vlog.Printer) (bool, error){
c.nmaCertLookupFromK8sSecret,
c.nmaCertLookupFromEnv,
}
for _, fnc := range loaderFuncs {
certsLoaded, err := fnc(logger)
if err != nil || certsLoaded {
return err
}
}
logger.Info("failed to retrieve the NMA certs from any source")
return nil
}

// nmaCertLookupFromK8sSecret retrieves PEM-encoded text of CA certs, the server cert, and
// the server key directly from kubernetes secrets.
func (c *CmdScrutinize) nmaCertLookupFromK8sSecret(logger vlog.Printer) (bool, error) {
_, portSet := os.LookupEnv(kubernetesPort)
if !portSet {
return nil
return false, nil
}
logger.Info("K8s environment")
secretNameSpace, nameSpaceSet := os.LookupEnv(secretNameSpaceEnvVar)
secretName, nameSet := os.LookupEnv(secretNameEnvVar)

// either secret namespace/name must be set, or none at all
if !((nameSpaceSet && nameSet) || (!nameSpaceSet && !nameSet)) {
missingParamError := constructMissingParamsMsg([]bool{nameSpaceSet, nameSet}, []string{kubernetesPort,
secretNameSpaceEnvVar, secretNameEnvVar})
return fmt.Errorf("all or none of the environment variables %s and %s must be set. %s",
missingParamError := constructMissingParamsMsg([]bool{nameSpaceSet, nameSet},
[]string{secretNameSpaceEnvVar, secretNameEnvVar})
return false, fmt.Errorf("all or none of the environment variables %s and %s must be set. %s",
secretNameSpaceEnvVar, secretNameEnvVar, missingParamError)
}

if !nameSpaceSet {
logger.Info("Secret name not set in env. Failback to other cert retieval methods.")
return nil
return false, nil
}

caCert, cert, key, err := c.k8secretRetreiver.RetrieveSecret(secretNameSpace, secretName)
if err != nil {
return fmt.Errorf("failed to read certs from k8s secret %s in namespace %s: %w", secretName, secretNameSpace, err)
return false, fmt.Errorf("failed to read certs from k8s secret %s in namespace %s: %w", secretName, secretNameSpace, err)
}
if len(caCert) != 0 && len(cert) != 0 && len(key) != 0 {
logger.Info("Successfully read cert from k8s secret ", "secretName", secretName, "secretNameSpace", secretNameSpace)
} else {
return fmt.Errorf("failed to read CA, cert or key (sizes = %d/%d/%d)",
return false, fmt.Errorf("failed to read CA, cert or key (sizes = %d/%d/%d)",
len(caCert), len(cert), len(key))
}
c.sOptions.CaCert = string(caCert)
c.sOptions.Cert = string(cert)
c.sOptions.Key = string(key)

return nil
return true, nil
}

// nmaCertLookupFromEnv retrieves the NMA certs from plaintext file identified
// by an environment variable.
func (c *CmdScrutinize) nmaCertLookupFromEnv(logger vlog.Printer) (bool, error) {
rootCAPath, rootCAPathSet := os.LookupEnv(nmaRootCAPathEnvVar)
certPath, certPathSet := os.LookupEnv(nmaCertPathEnvVar)
keyPath, keyPathSet := os.LookupEnv(nmaKeyPathEnvVar)

// either all env vars are set or none at all
if !((rootCAPathSet && certPathSet && keyPathSet) || (!rootCAPathSet && !certPathSet && !keyPathSet)) {
missingParamError := constructMissingParamsMsg([]bool{rootCAPathSet, certPathSet, keyPathSet},
[]string{nmaRootCAPathEnvVar, nmaCertPathEnvVar, nmaKeyPathEnvVar})
return false, fmt.Errorf("all or none of the environment variables %s, %s and %s must be set. %s",
nmaRootCAPathEnvVar, nmaCertPathEnvVar, nmaKeyPathEnvVar, missingParamError)
}

if !rootCAPathSet {
logger.Info("NMA cert location paths not set in env")
return false, nil
}

var err error

c.sOptions.CaCert, err = readNonEmptyFile(rootCAPath)
if err != nil {
return false, fmt.Errorf("failed to read root CA from %s: %w", rootCAPath, err)
}

c.sOptions.Cert, err = readNonEmptyFile(certPath)
if err != nil {
return false, fmt.Errorf("failed to read cert from %s: %w", certPath, err)
}

c.sOptions.Key, err = readNonEmptyFile(keyPath)
if err != nil {
return false, fmt.Errorf("failed to read key from %s: %w", keyPath, err)
}

logger.Info("Successfully read certs from file", "rootCAPath", rootCAPath, "certPath", certPath, "keyPath", keyPath)
return true, nil
}

// readNonEmptyFile is a helper that reads the contents of a file into a string.
// It returns an error if the file is empty.
func readNonEmptyFile(filename string) (string, error) {
contents, err := os.ReadFile(filename)
if err != nil {
return "", fmt.Errorf("failed to read from %s: %w", filename, err)
}
if len(contents) == 0 {
return "", fmt.Errorf("%s is empty", filename)
}
return string(contents), nil
}

// constructMissingParamsMsg builds a warning string listing each
Expand Down
85 changes: 80 additions & 5 deletions commands/scrutinize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func TestScrutinCmd(t *testing.T) {
assert.ErrorContains(t, err, "unable to get database name from environment variable")
}

func TestUpdateCertTextsFromK8s(t *testing.T) {
func TestNMACertLookupFromK8sSecret(t *testing.T) {
const randomBytes = "123"
c := makeCmdScrutinize()
c.k8secretRetreiver = TestK8sSecretRetriever{
Expand All @@ -92,8 +92,9 @@ func TestUpdateCertTextsFromK8s(t *testing.T) {

// Case 2: when the certs are configured correctly

err := c.updateCertTextsFromk8s(vlog.Printer{})
ok, err := c.nmaCertLookupFromK8sSecret(vlog.Printer{})
assert.NoError(t, err)
assert.True(t, ok)
assert.Equal(t, "test cert 1", c.sOptions.CaCert)
assert.Equal(t, "test cert 2", c.sOptions.Cert)
assert.Equal(t, "test cert 3", c.sOptions.Key)
Expand All @@ -106,19 +107,93 @@ func TestUpdateCertTextsFromK8s(t *testing.T) {
cert: "test cert 2",
key: "", // Missing
}
err = c.updateCertTextsFromk8s(vlog.Printer{})
ok, err = c.nmaCertLookupFromK8sSecret(vlog.Printer{})
assert.Error(t, err)
assert.False(t, ok)

// Failure to retrieve the secret should fail the request
c = makeCmdScrutinize()
c.k8secretRetreiver = TestK8sSecretRetriever{success: false}
err = c.updateCertTextsFromk8s(vlog.Printer{})
ok, err = c.nmaCertLookupFromK8sSecret(vlog.Printer{})
assert.Error(t, err)
assert.False(t, ok)

// If the nma env vars aren't set, then we go onto the next retrieval method
os.Clearenv()
os.Setenv("KUBERNETES_PORT", randomBytes)
c = makeCmdScrutinize()
err = c.updateCertTextsFromk8s(vlog.Printer{})
ok, err = c.nmaCertLookupFromK8sSecret(vlog.Printer{})
assert.NoError(t, err)
assert.False(t, ok)
}

func TestNMACertLookupFromEnv(t *testing.T) {
sampleRootCA := "== sample root CA =="
sampleCert := "== sample cert =="
sampleKey := "== sample key =="

frootCA, err := os.CreateTemp("", "root-ca-")
assert.NoError(t, err)
defer frootCA.Close()
defer os.Remove(frootCA.Name())
_, err = frootCA.WriteString(sampleRootCA)
assert.NoError(t, err)
frootCA.Close()

var fcert *os.File
fcert, err = os.CreateTemp("", "cert-")
assert.NoError(t, err)
defer fcert.Close()
defer os.Remove(fcert.Name())
_, err = fcert.WriteString(sampleCert)
assert.NoError(t, err)
fcert.Close()

var fkeyEmpty *os.File
fkeyEmpty, err = os.CreateTemp("", "key-")
assert.NoError(t, err)
// Omit writing any data to test code path
fkeyEmpty.Close()
defer os.Remove(fkeyEmpty.Name())

os.Setenv(nmaRootCAPathEnvVar, frootCA.Name())
os.Setenv(nmaCertPathEnvVar, fcert.Name())
// intentionally omit key path env var to test error path

// Should fail because only 2 of 3 env vars are set
c := makeCmdScrutinize()
ok, err := c.nmaCertLookupFromEnv(vlog.Printer{})
assert.Error(t, err)
assert.False(t, ok)

// Set 3rd env var
os.Setenv(nmaKeyPathEnvVar, fkeyEmpty.Name())

// Should fail because one of the files is empty
c = makeCmdScrutinize()
ok, err = c.nmaCertLookupFromEnv(vlog.Printer{})
assert.Error(t, err)
assert.False(t, ok)

// Populate empty file with contents
var fkey *os.File
fkey, err = os.CreateTemp("", "key-")
assert.NoError(t, err)
defer fkey.Close()
defer os.Remove(fkey.Name())
_, err = fkey.WriteString(sampleKey)
assert.NoError(t, err)
fkey.Close()

// Point to key that is non-empty
os.Setenv(nmaKeyPathEnvVar, fkey.Name())

// Should succeed now as everything is setup properly
c = makeCmdScrutinize()
ok, err = c.nmaCertLookupFromEnv(vlog.Printer{})
assert.NoError(t, err)
assert.True(t, ok)
assert.Equal(t, sampleRootCA, c.sOptions.CaCert)
assert.Equal(t, sampleCert, c.sOptions.Cert)
assert.Equal(t, sampleKey, c.sOptions.Key)
}
36 changes: 18 additions & 18 deletions vclusterops/adapter_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@ import (
"github.com/vertica/vcluster/vclusterops/vlog"
)

type AdapterPool struct {
type adapterPool struct {
logger vlog.Printer
// map from host to HTTPAdapter
connections map[string]Adapter
connections map[string]adapter
}

var (
poolInstance AdapterPool
poolInstance adapterPool
once sync.Once
)

// return a singleton instance of the AdapterPool
func getPoolInstance(logger vlog.Printer) AdapterPool {
func getPoolInstance(logger vlog.Printer) adapterPool {
/* if once.Do(f) is called multiple times,
* only the first call will invoke f,
* even if f has a different value in each invocation.
Expand All @@ -49,43 +49,43 @@ func getPoolInstance(logger vlog.Printer) AdapterPool {
return poolInstance
}

func makeAdapterPool(logger vlog.Printer) AdapterPool {
newAdapterPool := AdapterPool{}
newAdapterPool.connections = make(map[string]Adapter)
func makeAdapterPool(logger vlog.Printer) adapterPool {
newAdapterPool := adapterPool{}
newAdapterPool.connections = make(map[string]adapter)
newAdapterPool.logger = logger.WithName("AdapterPool")
return newAdapterPool
}

type adapterToRequest struct {
adapter Adapter
request HostHTTPRequest
adapter adapter
request hostHTTPRequest
}

func (pool *AdapterPool) sendRequest(clusterHTTPRequest *ClusterHTTPRequest) error {
func (pool *adapterPool) sendRequest(httpRequest *clusterHTTPRequest) error {
// build a collection of adapter to request
// we need this step as a host may not be in the pool
// in that case, we should not proceed
var adapterToRequestCollection []adapterToRequest
for host := range clusterHTTPRequest.RequestCollection {
request := clusterHTTPRequest.RequestCollection[host]
adapter, ok := pool.connections[host]
for host := range httpRequest.RequestCollection {
request := httpRequest.RequestCollection[host]
adpt, ok := pool.connections[host]
if !ok {
return fmt.Errorf("host %s is not found in the adapter pool", host)
}
ar := adapterToRequest{adapter: adapter, request: request}
ar := adapterToRequest{adapter: adpt, request: request}
adapterToRequestCollection = append(adapterToRequestCollection, ar)
}

hostCount := len(adapterToRequestCollection)

// result channel to collect result from each host
resultChannel := make(chan HostHTTPResult, hostCount)
resultChannel := make(chan hostHTTPResult, hostCount)

// only track the progress of HTTP requests for vcluster CLI
if pool.logger.ForCli {
// use context to check whether a step has completed
ctx, cancelCtx := context.WithCancel(context.Background())
go progressCheck(ctx, clusterHTTPRequest.Name)
go progressCheck(ctx, httpRequest.Name)
// cancel the progress check context when the result channel is closed
defer cancelCtx()
}
Expand All @@ -101,11 +101,11 @@ func (pool *AdapterPool) sendRequest(clusterHTTPRequest *ClusterHTTPRequest) err
// handle results
// we expect to receive the same number of results from the channel as the number of hosts
// before proceeding to the next steps
clusterHTTPRequest.ResultCollection = make(map[string]HostHTTPResult)
httpRequest.ResultCollection = make(map[string]hostHTTPResult)
for i := 0; i < hostCount; i++ {
result, ok := <-resultChannel
if ok {
clusterHTTPRequest.ResultCollection[result.host] = result
httpRequest.ResultCollection[result.host] = result
}
}
close(resultChannel)
Expand Down
Loading

0 comments on commit 6f31867

Please sign in to comment.