Skip to content

Commit

Permalink
Merge pull request #23 from castai/support-custom-ca
Browse files Browse the repository at this point in the history
support custom certificate authority
  • Loading branch information
furkhat authored Aug 7, 2024
2 parents 82cddd5 + 5da7fc6 commit 974206b
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 13 deletions.
54 changes: 50 additions & 4 deletions castai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ package castai

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"net/http"
"time"

"github.com/go-resty/resty/v2"
Expand All @@ -26,9 +30,15 @@ func NewClient(log *logrus.Logger, rest *resty.Client, clusterID string) Client
}
}

// NewDefaultClient configures a default instance of the resty.Client used to do HTTP requests.
func NewDefaultClient(url, key string, level logrus.Level, timeout time.Duration, version string) *resty.Client {
client := resty.New()
// NewRestyClient configures a default instance of the resty.Client used to do HTTP requests.
func NewRestyClient(url, key, ca string, level logrus.Level, timeout time.Duration, version string) (*resty.Client, error) {
clientTransport, err := createHTTPTransport(ca)
if err != nil {
return nil, err
}
client := resty.NewWithClient(&http.Client{
Transport: clientTransport,
})
client.SetBaseURL(url)
client.SetTimeout(timeout)
client.Header.Set(headerAPIKey, key)
Expand All @@ -37,7 +47,43 @@ func NewDefaultClient(url, key string, level logrus.Level, timeout time.Duration
client.SetDebug(true)
}

return client
return client, nil
}

func createHTTPTransport(ca string) (*http.Transport, error) {
tlsConfig, err := createTLSConfig(ca)
if err != nil {
return nil, fmt.Errorf("creating TLS config: %v", err)
}
// Mostly copied from http.DefaultTransport.
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: tlsConfig,
}, nil
}

func createTLSConfig(ca string) (*tls.Config, error) {
if len(ca) == 0 {
return nil, nil
}

certPool := x509.NewCertPool()
if !certPool.AppendCertsFromPEM([]byte(ca)) {
return nil, fmt.Errorf("failed to add root certificate to CA pool")
}

return &tls.Config{
RootCAs: certPool,
}, nil
}

type client struct {
Expand Down
57 changes: 57 additions & 0 deletions castai/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package castai

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestNewRestryClient_TLS(t *testing.T) {
t.Run("should populate tls.Config RootCAs when valid certificate presented", func(t *testing.T) {
r := require.New(t)

ca := `
-----BEGIN CERTIFICATE-----
MIIDATCCAemgAwIBAgIUPUS4krHP49SF+yYMLHe4nCllKmEwDQYJKoZIhvcNAQEL
BQAwDzENMAsGA1UECgwEVGVzdDAgFw0yMzA5MTMwODM5MzhaGA8yMjE1MDUxMDA4
MzkzOFowDzENMAsGA1UECgwEVGVzdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC
AQoCggEBAOVZbDa4/tf3N3VP4Ezvt18d++xrQ+bzjhuE7MWX36NWZ4wUzgmqQXd0
OQWoxYqRGKyI847v29j2BWG17ZmbqarwZHjR98rn9gNtRJgeURlEyAh1pAprhFwb
IBS9vyyCNJtfFFF+lvWvJcU+VKIqWH/9413xDx+OE8tRWNRkS/1CVJg1Nnm3H/IF
lhWAKOYbeKY9q8RtIhb4xNqIc8nmUjDFIjRTarIuf+jDwfFQAPK5pNci+o9KCDgd
Y4lvnGfvPp9XAHnWzTRWNGJQyefZb/SdJjXlic10njfttzKBXi0x8IuV2x98AEPE
2jLXIvC+UBpvMhscdzPfahp5xkYJWx0CAwEAAaNTMFEwHQYDVR0OBBYEFFE48b+V
4E5PWqjpLcUnqWvDDgsuMB8GA1UdIwQYMBaAFFE48b+V4E5PWqjpLcUnqWvDDgsu
MA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAIe82ddHX61WHmyp
zeSiF25aXBqeOUA0ScArTL0fBGi9xZ/8gVU79BvJMyfkaeBKvV06ka6g9OnleWYB
zhBmHBvCL6PsgwLxgzt/dj5ES0K3Ml+7jGmhCKKryzYj/ZvhSMyLlxZqP/nRccBG
y6G3KK4bjzqY4TcEPNs8H4Akc+0SGcPl+AAe65mXPIQhtMkANFLoRuWxMf5JmJke
dYT1GoOjRJpEWCATM+KCXa3UEpRBcXNLeOHZivuqf7n0e1CUD6+0oK4TLxVsTqti
q276VYI/vYmMLRI/iE7Qjn9uGEeR1LWpVngE9jSzSdzByvzw3DwO4sL5B+rv7O1T
9Qgi/No=
-----END CERTIFICATE-----
`

got, err := createTLSConfig(ca)
r.NoError(err)
r.NotNil(got)
r.NotEmpty(got.RootCAs)
})

t.Run("should return error and nil for tls.Config when invalid certificate is given", func(t *testing.T) {
r := require.New(t)

ca := "certificate"
got, err := createTLSConfig(ca)
r.Error(err)
r.Nil(got)
})

t.Run("should return nil if no certificate is set", func(t *testing.T) {
r := require.New(t)

got, err := createTLSConfig("")
r.NoError(err)
r.Nil(got)
})
}
2 changes: 2 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ type Config struct {
NodeName string
APIUrl string
APIKey string
TLSCACert string
ClusterID string
Provider string
LogLevel int
Expand All @@ -29,6 +30,7 @@ func Get() Config {

_ = viper.BindEnv("apikey", "API_KEY")
_ = viper.BindEnv("apiurl", "API_URL")
_ = viper.BindEnv("tlscacert", "TLS_CA_CERT_FILE")
_ = viper.BindEnv("nodename", "NODE_NAME")
_ = viper.BindEnv("clusterid", "CLUSTER_ID")
_ = viper.BindEnv("provider", "PROVIDER")
Expand Down
20 changes: 12 additions & 8 deletions handler/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ func TestRunLoop(t *testing.T) {
defer castS.Close()

fakeApi := fake.NewSimpleClientset(node)
castHttp := castai.NewDefaultClient(castS.URL, "test", log.Level, 100*time.Millisecond, "0.0.0")
castHttp, err := castai.NewRestyClient(castS.URL, "test", "", log.Level, 100*time.Millisecond, "0.0.0")
r.NoError(err)
mockCastClient := castai.NewClient(log, castHttp, "test1")

mockInterrupt := &mockInterruptChecker{interrupted: true}
Expand All @@ -67,7 +68,7 @@ func TestRunLoop(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

err := handler.Run(ctx)
err = handler.Run(ctx)
require.NoError(t, err)
r.Equal(1, mothershipCalls)

Expand All @@ -93,7 +94,8 @@ func TestRunLoop(t *testing.T) {
defer castS.Close()

fakeApi := fake.NewSimpleClientset(node)
castHttp := castai.NewDefaultClient(castS.URL, "test", log.Level, 100*time.Millisecond, "0.0.0")
castHttp, err := castai.NewRestyClient(castS.URL, "test", "", log.Level, 100*time.Millisecond, "0.0.0")
r.NoError(err)
mockCastClient := castai.NewClient(log, castHttp, "test1")

mockInterrupt := &mockInterruptChecker{interrupted: true}
Expand All @@ -110,7 +112,7 @@ func TestRunLoop(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()

err := handler.Run(ctx)
err = handler.Run(ctx)
require.NoError(t, err)
r.Equal(1, mothershipCalls)

Expand Down Expand Up @@ -150,7 +152,8 @@ func TestRunLoop(t *testing.T) {
defer castS.Close()

fakeApi := fake.NewSimpleClientset(node)
castHttp := castai.NewDefaultClient(castS.URL, "test", log.Level, time.Millisecond*100, "0.0.0")
castHttp, err := castai.NewRestyClient(castS.URL, "test", "", log.Level, time.Millisecond*100, "0.0.0")
r.NoError(err)
mockCastClient := castai.NewClient(log, castHttp, "test1")

mockInterrupt := &mockInterruptChecker{interrupted: true}
Expand All @@ -164,7 +167,7 @@ func TestRunLoop(t *testing.T) {
}

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
err := handler.Run(ctx)
err = handler.Run(ctx)
require.NoError(t, err)

defer func() {
Expand All @@ -185,7 +188,8 @@ func TestRunLoop(t *testing.T) {
defer castS.Close()

fakeApi := fake.NewSimpleClientset(node)
castHttp := castai.NewDefaultClient(castS.URL, "test", log.Level, 100*time.Millisecond, "0.0.0")
castHttp, err := castai.NewRestyClient(castS.URL, "test", "", log.Level, 100*time.Millisecond, "0.0.0")
r.NoError(err)
mockCastClient := castai.NewClient(log, castHttp, "test1")

mockRecommendation := &mockInterruptChecker{rebalanceRecommendation: true}
Expand All @@ -201,7 +205,7 @@ func TestRunLoop(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

err := handler.Run(ctx)
err = handler.Run(ctx)
require.NoError(t, err)
r.Equal(1, mothershipCalls)
})
Expand Down
6 changes: 5 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,17 @@ func main() {
}

// Set 5 seconds until we timeout calling mothership and retry.
castHttpClient := castai.NewDefaultClient(
castHttpClient, err := castai.NewRestyClient(
cfg.APIUrl,
cfg.APIKey,
cfg.TLSCACert,
logrus.Level(cfg.LogLevel),
5*time.Second,
Version,
)
if err != nil {
log.Fatalf("failed to create http client: %v", err)
}
castClient := castai.NewClient(logger, castHttpClient, cfg.ClusterID)

spotHandler := handler.NewSpotHandler(
Expand Down

0 comments on commit 974206b

Please sign in to comment.