diff --git a/castai/client.go b/castai/client.go index 2718ee5..317b790 100644 --- a/castai/client.go +++ b/castai/client.go @@ -2,7 +2,11 @@ package castai import ( "context" + "crypto/tls" + "crypto/x509" "fmt" + "net" + "net/http" "time" "github.com/go-resty/resty/v2" @@ -26,9 +30,15 @@ func NewClient(log *logrus.Logger, rest *resty.Client, clusterID string) Client } } -// NewDefaultClient configures a default instance of the resty.Client used to do HTTP requests. -func NewDefaultClient(url, key string, level logrus.Level, timeout time.Duration, version string) *resty.Client { - client := resty.New() +// NewRestyClient configures a default instance of the resty.Client used to do HTTP requests. +func NewRestyClient(url, key, ca string, level logrus.Level, timeout time.Duration, version string) (*resty.Client, error) { + clientTransport, err := createHTTPTransport(ca) + if err != nil { + return nil, err + } + client := resty.NewWithClient(&http.Client{ + Transport: clientTransport, + }) client.SetBaseURL(url) client.SetTimeout(timeout) client.Header.Set(headerAPIKey, key) @@ -37,7 +47,43 @@ func NewDefaultClient(url, key string, level logrus.Level, timeout time.Duration client.SetDebug(true) } - return client + return client, nil +} + +func createHTTPTransport(ca string) (*http.Transport, error) { + tlsConfig, err := createTLSConfig(ca) + if err != nil { + return nil, fmt.Errorf("creating TLS config: %v", err) + } + // Mostly copied from http.DefaultTransport. + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + TLSClientConfig: tlsConfig, + }, nil +} + +func createTLSConfig(ca string) (*tls.Config, error) { + if len(ca) == 0 { + return nil, nil + } + + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM([]byte(ca)) { + return nil, fmt.Errorf("failed to add root certificate to CA pool") + } + + return &tls.Config{ + RootCAs: certPool, + }, nil } type client struct { diff --git a/castai/client_test.go b/castai/client_test.go new file mode 100644 index 0000000..559012c --- /dev/null +++ b/castai/client_test.go @@ -0,0 +1,57 @@ +package castai + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewRestryClient_TLS(t *testing.T) { + t.Run("should populate tls.Config RootCAs when valid certificate presented", func(t *testing.T) { + r := require.New(t) + + ca := []byte(` +-----BEGIN CERTIFICATE----- +MIIDATCCAemgAwIBAgIUPUS4krHP49SF+yYMLHe4nCllKmEwDQYJKoZIhvcNAQEL +BQAwDzENMAsGA1UECgwEVGVzdDAgFw0yMzA5MTMwODM5MzhaGA8yMjE1MDUxMDA4 +MzkzOFowDzENMAsGA1UECgwEVGVzdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC +AQoCggEBAOVZbDa4/tf3N3VP4Ezvt18d++xrQ+bzjhuE7MWX36NWZ4wUzgmqQXd0 +OQWoxYqRGKyI847v29j2BWG17ZmbqarwZHjR98rn9gNtRJgeURlEyAh1pAprhFwb +IBS9vyyCNJtfFFF+lvWvJcU+VKIqWH/9413xDx+OE8tRWNRkS/1CVJg1Nnm3H/IF +lhWAKOYbeKY9q8RtIhb4xNqIc8nmUjDFIjRTarIuf+jDwfFQAPK5pNci+o9KCDgd +Y4lvnGfvPp9XAHnWzTRWNGJQyefZb/SdJjXlic10njfttzKBXi0x8IuV2x98AEPE +2jLXIvC+UBpvMhscdzPfahp5xkYJWx0CAwEAAaNTMFEwHQYDVR0OBBYEFFE48b+V +4E5PWqjpLcUnqWvDDgsuMB8GA1UdIwQYMBaAFFE48b+V4E5PWqjpLcUnqWvDDgsu +MA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAIe82ddHX61WHmyp +zeSiF25aXBqeOUA0ScArTL0fBGi9xZ/8gVU79BvJMyfkaeBKvV06ka6g9OnleWYB +zhBmHBvCL6PsgwLxgzt/dj5ES0K3Ml+7jGmhCKKryzYj/ZvhSMyLlxZqP/nRccBG +y6G3KK4bjzqY4TcEPNs8H4Akc+0SGcPl+AAe65mXPIQhtMkANFLoRuWxMf5JmJke +dYT1GoOjRJpEWCATM+KCXa3UEpRBcXNLeOHZivuqf7n0e1CUD6+0oK4TLxVsTqti +q276VYI/vYmMLRI/iE7Qjn9uGEeR1LWpVngE9jSzSdzByvzw3DwO4sL5B+rv7O1T +9Qgi/No= +-----END CERTIFICATE----- + `) + + got, err := createTLSConfig(ca) + r.NoError(err) + r.NotNil(got) + r.NotEmpty(got.RootCAs) + }) + + t.Run("should return error and nil for tls.Config when invalid certificate is given", func(t *testing.T) { + r := require.New(t) + + ca := []byte("certificate") + got, err := createTLSConfig(ca) + r.Error(err) + r.Nil(got) + }) + + t.Run("should return nil if no certificate is set", func(t *testing.T) { + r := require.New(t) + + got, err := createTLSConfig(nil) + r.NoError(err) + r.Nil(got) + }) +} diff --git a/config/config.go b/config/config.go index 79f1ccc..254b315 100644 --- a/config/config.go +++ b/config/config.go @@ -10,6 +10,7 @@ type Config struct { NodeName string APIUrl string APIKey string + TLS TLS ClusterID string Provider string LogLevel int @@ -17,6 +18,10 @@ type Config struct { PollIntervalSeconds int } +type TLS struct { + CACert string +} + var cfg *Config // Get configuration bound to environment variables. @@ -29,6 +34,7 @@ func Get() Config { _ = viper.BindEnv("apikey", "API_KEY") _ = viper.BindEnv("apiurl", "API_URL") + _ = viper.BindEnv("tls.cacert", "TLS_CA_CERT_FILE") _ = viper.BindEnv("nodename", "NODE_NAME") _ = viper.BindEnv("clusterid", "CLUSTER_ID") _ = viper.BindEnv("provider", "PROVIDER") diff --git a/handler/handler_test.go b/handler/handler_test.go index 7baea86..43c215e 100644 --- a/handler/handler_test.go +++ b/handler/handler_test.go @@ -51,7 +51,8 @@ func TestRunLoop(t *testing.T) { defer castS.Close() fakeApi := fake.NewSimpleClientset(node) - castHttp := castai.NewDefaultClient(castS.URL, "test", log.Level, 100*time.Millisecond, "0.0.0") + castHttp, err := castai.NewRestyClient(castS.URL, "test", "", log.Level, 100*time.Millisecond, "0.0.0") + r.NoError(err) mockCastClient := castai.NewClient(log, castHttp, "test1") mockInterrupt := &mockInterruptChecker{interrupted: true} @@ -67,7 +68,7 @@ func TestRunLoop(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - err := handler.Run(ctx) + err = handler.Run(ctx) require.NoError(t, err) r.Equal(1, mothershipCalls) @@ -93,7 +94,8 @@ func TestRunLoop(t *testing.T) { defer castS.Close() fakeApi := fake.NewSimpleClientset(node) - castHttp := castai.NewDefaultClient(castS.URL, "test", log.Level, 100*time.Millisecond, "0.0.0") + castHttp, err := castai.NewRestyClient(castS.URL, "test", "", log.Level, 100*time.Millisecond, "0.0.0") + r.NoError(err) mockCastClient := castai.NewClient(log, castHttp, "test1") mockInterrupt := &mockInterruptChecker{interrupted: true} @@ -110,7 +112,7 @@ func TestRunLoop(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() - err := handler.Run(ctx) + err = handler.Run(ctx) require.NoError(t, err) r.Equal(1, mothershipCalls) @@ -150,7 +152,8 @@ func TestRunLoop(t *testing.T) { defer castS.Close() fakeApi := fake.NewSimpleClientset(node) - castHttp := castai.NewDefaultClient(castS.URL, "test", log.Level, time.Millisecond*100, "0.0.0") + castHttp, err := castai.NewRestyClient(castS.URL, "test", "", log.Level, time.Millisecond*100, "0.0.0") + r.NoError(err) mockCastClient := castai.NewClient(log, castHttp, "test1") mockInterrupt := &mockInterruptChecker{interrupted: true} @@ -164,7 +167,7 @@ func TestRunLoop(t *testing.T) { } ctx, cancel := context.WithTimeout(context.Background(), time.Second) - err := handler.Run(ctx) + err = handler.Run(ctx) require.NoError(t, err) defer func() { @@ -185,7 +188,8 @@ func TestRunLoop(t *testing.T) { defer castS.Close() fakeApi := fake.NewSimpleClientset(node) - castHttp := castai.NewDefaultClient(castS.URL, "test", log.Level, 100*time.Millisecond, "0.0.0") + castHttp, err := castai.NewRestyClient(castS.URL, "test", "", log.Level, 100*time.Millisecond, "0.0.0") + r.NoError(err) mockCastClient := castai.NewClient(log, castHttp, "test1") mockRecommendation := &mockInterruptChecker{rebalanceRecommendation: true} @@ -201,7 +205,7 @@ func TestRunLoop(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - err := handler.Run(ctx) + err = handler.Run(ctx) require.NoError(t, err) r.Equal(1, mothershipCalls) }) diff --git a/main.go b/main.go index fa49a3b..8800614 100644 --- a/main.go +++ b/main.go @@ -62,13 +62,17 @@ func main() { } // Set 5 seconds until we timeout calling mothership and retry. - castHttpClient := castai.NewDefaultClient( + castHttpClient, err := castai.NewRestyClient( cfg.APIUrl, cfg.APIKey, + cfg.TLS.CACert, logrus.Level(cfg.LogLevel), 5*time.Second, Version, ) + if err != nil { + log.Fatalf("interrupt checker: %v", err) + } castClient := castai.NewClient(logger, castHttpClient, cfg.ClusterID) spotHandler := handler.NewSpotHandler(