From 1fb370e1fae37f56057379ad98d79b91c6bd00f7 Mon Sep 17 00:00:00 2001 From: spbsoluble <1661003+spbsoluble@users.noreply.github.com> Date: Thu, 31 Oct 2024 11:46:35 -0700 Subject: [PATCH] fix(oauth): OAuth clients not to inherit DefaultClient config --- auth_providers/auth_core.go | 98 +++++++++++--------- auth_providers/auth_oauth.go | 117 ++++++------------------ auth_providers/auth_oauth_test.go | 72 +++++++++++++-- main.go | 143 +++++++++++++++++++++++++++++- 4 files changed, 287 insertions(+), 143 deletions(-) diff --git a/auth_providers/auth_core.go b/auth_providers/auth_core.go index e79a8c2..390646b 100644 --- a/auth_providers/auth_core.go +++ b/auth_providers/auth_core.go @@ -141,6 +141,7 @@ type CommandAuthConfig struct { // HttpClient is the http Client to be used for authentication to Keyfactor Command API HttpClient *http.Client + //DefaultHttpClient *http.Client } // cleanHostName cleans the hostname for authentication to Keyfactor Command API. @@ -275,8 +276,6 @@ func (c *CommandAuthConfig) ValidateAuthConfig() error { // check if CommandCACert is set in environment if caCert, ok := os.LookupEnv(EnvKeyfactorCACert); ok { c.CommandCACert = caCert - } else { - return nil } } @@ -284,46 +283,17 @@ func (c *CommandAuthConfig) ValidateAuthConfig() error { if skipVerify, ok := os.LookupEnv(EnvKeyfactorSkipVerify); ok { c.SkipVerify = skipVerify == "true" || skipVerify == "1" } - - //TODO: This should be part of BuildTransport - //if c.SkipVerify { - // c.HttpClient.Transport = &http.Transport{ - // TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - // } - // //return nil - //} - // - //caErr := c.updateCACerts() - //if caErr != nil { - // return caErr - //} - return nil } // BuildTransport creates a custom http Transport for authentication to Keyfactor Command API. func (c *CommandAuthConfig) BuildTransport() (*http.Transport, error) { - var output *http.Transport - if c.HttpClient == nil { - c.SetClient(nil) - } - // check if c already has a transport and if it does, assign it to output else create a new transport - if c.HttpClient.Transport != nil { - if transport, ok := c.HttpClient.Transport.(*http.Transport); ok { - output = transport - } else { - output = &http.Transport{ - TLSClientConfig: &tls.Config{}, - } - } - } else { - output = &http.Transport{ - Proxy: http.ProxyFromEnvironment, - TLSClientConfig: &tls.Config{ - Renegotiation: tls.RenegotiateOnceAsClient, - }, - TLSHandshakeTimeout: 10 * time.Second, - } + output := http.Transport{ + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: &tls.Config{ + Renegotiation: tls.RenegotiateOnceAsClient, + }, + TLSHandshakeTimeout: 10 * time.Second, } if c.SkipVerify { @@ -331,10 +301,28 @@ func (c *CommandAuthConfig) BuildTransport() (*http.Transport, error) { } if c.CommandCACert != "" { - _ = c.updateCACerts() + if _, err := os.Stat(c.CommandCACert); err == nil { + cert, ioErr := os.ReadFile(c.CommandCACert) + if ioErr != nil { + return &output, ioErr + } + // check if output.TLSClientConfig.RootCAs is nil + if output.TLSClientConfig.RootCAs == nil { + output.TLSClientConfig.RootCAs = x509.NewCertPool() + } + // Append your custom cert to the pool + if ok := output.TLSClientConfig.RootCAs.AppendCertsFromPEM(cert); !ok { + return &output, fmt.Errorf("failed to append custom CA cert to pool") + } + } else { + // Append your custom cert to the pool + if ok := output.TLSClientConfig.RootCAs.AppendCertsFromPEM([]byte(c.CommandCACert)); !ok { + return &output, fmt.Errorf("failed to append custom CA cert to pool") + } + } } - return output, nil + return &output, nil } // SetClient sets the http Client for authentication to Keyfactor Command API. @@ -343,8 +331,34 @@ func (c *CommandAuthConfig) SetClient(client *http.Client) *http.Client { c.HttpClient = client } if c.HttpClient == nil { - c.HttpClient = http.DefaultClient + //// Copy the default transport and apply the custom TLS config + //defaultTransport := http.DefaultTransport.(*http.Transport).Clone() + ////defaultTransport.TLSClientConfig = tlsConfig + //c.HttpClient = &http.Client{Transport: defaultTransport} + defaultTimeout := time.Duration(c.HttpClientTimeout) * time.Second + c.HttpClient = &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: &tls.Config{ + Renegotiation: tls.RenegotiateOnceAsClient, + }, + TLSHandshakeTimeout: defaultTimeout, + DisableKeepAlives: false, + DisableCompression: false, + MaxIdleConns: 10, + MaxIdleConnsPerHost: 10, + MaxConnsPerHost: 10, + IdleConnTimeout: defaultTimeout, + ResponseHeaderTimeout: defaultTimeout, + ExpectContinueTimeout: defaultTimeout, + MaxResponseHeaderBytes: 0, + WriteBufferSize: 0, + ReadBufferSize: 0, + ForceAttemptHTTP2: false, + }, + } } + return c.HttpClient } @@ -356,6 +370,7 @@ func (c *CommandAuthConfig) updateCACerts() error { if caCert, ok := os.LookupEnv(EnvKeyfactorCACert); ok { c.CommandCACert = caCert } else { + // nothing to do return nil } } @@ -452,7 +467,6 @@ func (c *CommandAuthConfig) Authenticate() error { } c.HttpClient.Timeout = time.Duration(c.HttpClientTimeout) * time.Second - cResp, cErr := c.HttpClient.Do(req) if cErr != nil { return cErr @@ -645,7 +659,7 @@ func (c *CommandAuthConfig) LoadConfig(profile string, configFilePath string, si if c.CommandCACert == "" { c.CommandCACert = server.CACertPath } - if c.SkipVerify { + if !c.SkipVerify { c.SkipVerify = server.SkipTLSVerify } diff --git a/auth_providers/auth_oauth.go b/auth_providers/auth_oauth.go index 308efe0..8a8c940 100644 --- a/auth_providers/auth_oauth.go +++ b/auth_providers/auth_oauth.go @@ -2,7 +2,6 @@ package auth_providers import ( "context" - "crypto/tls" "crypto/x509" "fmt" "net/http" @@ -51,6 +50,11 @@ type OAuthAuthenticator struct { Client *http.Client } +type oauth2Transport struct { + base http.RoundTripper + src oauth2.TokenSource +} + // GetHttpClient returns the http client func (a *OAuthAuthenticator) GetHttpClient() (*http.Client, error) { return a.Client, nil @@ -162,24 +166,17 @@ func (b *CommandConfigOauth) WithHttpClient(httpClient *http.Client) *CommandCon // GetHttpClient returns an HTTP client for oAuth authentication. func (b *CommandConfigOauth) GetHttpClient() (*http.Client, error) { cErr := b.ValidateAuthConfig() - var client http.Client - if b.CommandAuthConfig.HttpClient != nil { - client = *b.CommandAuthConfig.HttpClient - } if cErr != nil { return nil, cErr } - if client.Transport == nil { - transport, tErr := b.BuildTransport() - if tErr != nil { - return nil, tErr - } - client.Transport = transport + var client http.Client + baseTransport, tErr := b.BuildTransport() + if tErr != nil { + return nil, tErr } if b.AccessToken != "" { - baseTransport := cloneHTTPTransport(client.Transport.(*http.Transport)) client.Transport = &oauth2.Transport{ Base: baseTransport, Source: oauth2.StaticTokenSource( @@ -209,15 +206,15 @@ func (b *CommandConfigOauth) GetHttpClient() (*http.Client, error) { } } - ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client) - + ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{Transport: baseTransport}) tokenSource := config.TokenSource(ctx) - baseTransport := cloneHTTPTransport(client.Transport.(*http.Transport)) - oauthTransport := oauth2.Transport{ - Base: baseTransport, - Source: tokenSource, + + client = http.Client{ + Transport: &oauth2Transport{ + base: baseTransport, + src: tokenSource, + }, } - client.Transport = &oauthTransport return &client, nil } @@ -375,6 +372,7 @@ func (b *CommandConfigOauth) Authenticate() error { } b.SetClient(oauthy) + //b.DefaultHttpClient = oauthy aErr := b.CommandAuthConfig.Authenticate() if aErr != nil { @@ -401,79 +399,16 @@ func (b *CommandConfigOauth) GetServerConfig() *Server { return &server } -// Example usage of CommandConfigOauth -// -// This example demonstrates how to use CommandConfigOauth to authenticate to the Keyfactor Command API using OAuth2. -// -// func ExampleCommandConfigOauth_Authenticate() { -// authConfig := &CommandConfigOauth{ -// CommandAuthConfig: CommandAuthConfig{ -// ConfigFilePath: "/path/to/config.json", -// ConfigProfile: "default", -// CommandHostName: "exampleHost", -// CommandPort: 443, -// CommandAPIPath: "/api/v1", -// CommandCACert: "/path/to/ca-cert.pem", -// SkipVerify: true, -// HttpClientTimeout: 60, -// }, -// ClientID: "exampleClientID", -// ClientSecret: "exampleClientSecret", -// TokenURL: "https://example.com/oauth/token", -// Scopes: []string{"openid", "profile", "email"}, -// Audience: "exampleAudience", -// CACertificatePath: "/path/to/ca-cert.pem", -// AccessToken: "exampleAccessToken", -// } -// -// err := authConfig.Authenticate() -// if err != nil { -// fmt.Println("Authentication failed:", err) -// } else { -// fmt.Println("Authentication successful") -// } -// } - -func cloneHTTPTransport(original *http.Transport) *http.Transport { - if original == nil { - return nil +// RoundTrip executes a single HTTP transaction, adding the OAuth2 token to the request +func (t *oauth2Transport) RoundTrip(req *http.Request) (*http.Response, error) { + token, err := t.src.Token() + if err != nil { + return nil, fmt.Errorf("failed to retrieve OAuth token: %w", err) } - return &http.Transport{ - Proxy: original.Proxy, - DialContext: original.DialContext, - ForceAttemptHTTP2: original.ForceAttemptHTTP2, - MaxIdleConns: original.MaxIdleConns, - IdleConnTimeout: original.IdleConnTimeout, - TLSHandshakeTimeout: original.TLSHandshakeTimeout, - ExpectContinueTimeout: original.ExpectContinueTimeout, - ResponseHeaderTimeout: original.ResponseHeaderTimeout, - TLSClientConfig: cloneTLSConfig(original.TLSClientConfig), - DialTLSContext: original.DialTLSContext, - DisableKeepAlives: original.DisableKeepAlives, - DisableCompression: original.DisableCompression, - MaxIdleConnsPerHost: original.MaxIdleConnsPerHost, - MaxConnsPerHost: original.MaxConnsPerHost, - WriteBufferSize: original.WriteBufferSize, - ReadBufferSize: original.ReadBufferSize, - } -} + // Clone the request to avoid mutating the original + reqCopy := req.Clone(req.Context()) + token.SetAuthHeader(reqCopy) -func cloneTLSConfig(original *tls.Config) *tls.Config { - if original == nil { - return nil - } - - return &tls.Config{ - InsecureSkipVerify: original.InsecureSkipVerify, - MinVersion: original.MinVersion, - MaxVersion: original.MaxVersion, - CipherSuites: original.CipherSuites, - PreferServerCipherSuites: original.PreferServerCipherSuites, - NextProtos: original.NextProtos, - ServerName: original.ServerName, - ClientAuth: original.ClientAuth, - RootCAs: original.RootCAs, - // Deep copy the rest of the TLS fields as needed - } + return t.base.RoundTrip(reqCopy) } diff --git a/auth_providers/auth_oauth_test.go b/auth_providers/auth_oauth_test.go index 5e5fb87..c2c68ad 100644 --- a/auth_providers/auth_oauth_test.go +++ b/auth_providers/auth_oauth_test.go @@ -120,14 +120,62 @@ func TestCommandConfigOauth_Authenticate(t *testing.T) { }() //os.Setenv(auth_providers.EnvKeyfactorConfigFile, configFilePath) //os.Setenv(auth_providers.EnvKeyfactorAuthProfile, "oauth") - //os.Setenv(auth_providers.EnvKeyfactorSkipVerify, "true") + os.Setenv(auth_providers.EnvKeyfactorSkipVerify, "true") + os.Setenv(auth_providers.EnvKeyfactorCACert, "lib/certs/int-oidc-lab.eastus2.cloudapp.azure.com.pem") - t.Log("Testing oAuth with Environmental variables") + // Begin test case + noParamsTestName := fmt.Sprintf( + "w/ complete ENV variables & %s,%s", auth_providers.EnvKeyfactorCACert, + auth_providers.EnvKeyfactorSkipVerify, + ) + t.Log(fmt.Sprintf("Testing %s", noParamsTestName)) noParamsConfig := &auth_providers.CommandConfigOauth{} - noParamsConfig. - WithSkipVerify(true). - WithCommandCACert("lib/certs/int-oidc-lab.eastus2.cloudapp.azure.com.pem") - authOauthTest(t, "with complete Environmental variables", false, noParamsConfig) + authOauthTest( + t, noParamsTestName, false, noParamsConfig, + ) + t.Logf("Unsetting environment variable %s", auth_providers.EnvKeyfactorCACert) + os.Unsetenv(auth_providers.EnvKeyfactorCACert) + t.Logf("Unsetting environment variable %s", auth_providers.EnvKeyfactorSkipVerify) + os.Unsetenv(auth_providers.EnvKeyfactorSkipVerify) + // end test case + + // Begin test case + noParamsTestName = fmt.Sprintf( + "w/ complete ENV variables & %s", auth_providers.EnvKeyfactorCACert, + ) + t.Log(fmt.Sprintf("Testing %s", noParamsTestName)) + t.Logf("Setting environment variable %s", auth_providers.EnvKeyfactorCACert) + os.Setenv(auth_providers.EnvKeyfactorCACert, "lib/certs/int-oidc-lab.eastus2.cloudapp.azure.com.pem") + noParamsConfig = &auth_providers.CommandConfigOauth{} + authOauthTest(t, noParamsTestName, false, noParamsConfig) + t.Logf("Unsetting environment variable %s", auth_providers.EnvKeyfactorCACert) + os.Unsetenv(auth_providers.EnvKeyfactorCACert) + // end test case + + // Begin test case + noParamsTestName = fmt.Sprintf( + "w/ complete ENV variables & %s", auth_providers.EnvKeyfactorSkipVerify, + ) + t.Log(fmt.Sprintf("Testing %s", noParamsTestName)) + t.Logf("Setting environment variable %s", auth_providers.EnvKeyfactorSkipVerify) + os.Setenv(auth_providers.EnvKeyfactorSkipVerify, "true") + noParamsConfig = &auth_providers.CommandConfigOauth{} + authOauthTest(t, noParamsTestName, false, noParamsConfig) + t.Logf("Unsetting environment variable %s", auth_providers.EnvKeyfactorSkipVerify) + os.Unsetenv(auth_providers.EnvKeyfactorSkipVerify) + // end test case + + // Begin test case + noParamsConfig = &auth_providers.CommandConfigOauth{} + httpsFailEnvExpected := []string{"tls: failed to verify certificate", "certificate is not trusted"} + authOauthTest( + t, + fmt.Sprintf("w/o env %s", auth_providers.EnvKeyfactorCACert), + true, + noParamsConfig, + httpsFailEnvExpected..., + ) + // end test case t.Log("Testing oAuth with invalid config file path") invFilePath := &auth_providers.CommandConfigOauth{} @@ -184,6 +232,7 @@ func TestCommandConfigOauth_Authenticate(t *testing.T) { ClientSecret: "invalid-client-secret", TokenURL: tokenURL, } + fullParamsInvalidPassConfig.WithSkipVerify(true) invalidCredsExpectedError := []string{ "oauth2", "unauthorized_client", "Invalid client or Invalid client credentials", } @@ -234,11 +283,13 @@ func TestCommandConfigOauth_Authenticate(t *testing.T) { authOauthTest(t, "with oAuth with valid implicit config file skiptls config param", false, skipTLSConfigFileC) t.Log("Testing oAuth with valid implicit config file skiptls env") + t.Logf("Setting environment variable %s", auth_providers.EnvKeyfactorSkipVerify) os.Setenv(auth_providers.EnvKeyfactorSkipVerify, "true") skipTLSConfigFileE := &auth_providers.CommandConfigOauth{} skipTLSConfigFileE. WithConfigProfile("oauth") authOauthTest(t, "oAuth with valid implicit config file skiptls env", false, skipTLSConfigFileE) + t.Logf("Unsetting environment variable %s", auth_providers.EnvKeyfactorSkipVerify) os.Unsetenv(auth_providers.EnvKeyfactorSkipVerify) t.Log("Testing oAuth with valid implicit config file https fail") @@ -251,7 +302,6 @@ func TestCommandConfigOauth_Authenticate(t *testing.T) { httpsFailConfigFileExpected..., ) - os.Setenv(auth_providers.EnvKeyfactorSkipVerify, "true") t.Log("Testing oAuth with invalid profile implicit config file") invProfile := &auth_providers.CommandConfigOauth{} invProfile.WithConfigProfile("invalid-profile") @@ -260,12 +310,16 @@ func TestCommandConfigOauth_Authenticate(t *testing.T) { t.Log("Testing oAuth with invalid creds implicit config file") invProfileCreds := &auth_providers.CommandConfigOauth{} - invProfileCreds.WithConfigProfile("oauth_invalid_creds") + invProfileCreds. + WithConfigProfile("oauth_invalid_creds"). + WithSkipVerify(true) authOauthTest(t, "with invalid creds implicit config file", true, invProfileCreds, invalidCredsExpectedError...) t.Log("Testing oAuth with invalid Command host implicit config file") invCmdHost := &auth_providers.CommandConfigOauth{} - invCmdHost.WithConfigProfile("oauth_invalid_host") + invCmdHost. + WithConfigProfile("oauth_invalid_host"). + WithSkipVerify(true) invHostExpectedError := []string{"no such host"} authOauthTest(t, "with invalid creds implicit config file", true, invCmdHost, invHostExpectedError...) } diff --git a/main.go b/main.go index f1615e2..96f5009 100644 --- a/main.go +++ b/main.go @@ -17,11 +17,152 @@ package main import ( "fmt" - "github.com/Keyfactor/keyfactor-auth-client-go/pkg" // Correct + "github.com/Keyfactor/keyfactor-auth-client-go/pkg" ) func main() { fmt.Println("Version:", pkg.Version) // print the package version fmt.Println("Build:", pkg.BuildTime) // print the package build fmt.Println("Commit:", pkg.CommitHash) // print the package commit + //testClients() } + +//func testClients() { +// // URL to test against +// url := os.Getenv("KEYFACTOR_AUTH_TOKEN_URL") +// caCertPath := os.Getenv("KEYFACTOR_CA_CERT") +// +// // Load the custom root CA certificate +// caCert, err := os.ReadFile(caCertPath) +// if err != nil { +// log.Fatalf("Failed to read root CA certificate: %v", err) +// } +// +// // Create a certificate pool and add the custom root CA +// caCertPool := x509.NewCertPool() +// if !caCertPool.AppendCertsFromPEM(caCert) { +// log.Fatalf("Failed to append root CA certificate to pool") +// } +// +// // OAuth2 client credentials configuration +// clientId := os.Getenv("KEYFACTOR_AUTH_CLIENT_ID") +// clientSecret := os.Getenv("KEYFACTOR_AUTH_CLIENT_SECRET") +// oauthConfig := &clientcredentials.Config{ +// ClientID: clientId, +// ClientSecret: clientSecret, +// TokenURL: url, +// } +// +// // Transport with default TLS verification (InsecureSkipVerify = false) +// transportDefaultTLS := &http.Transport{ +// TLSClientConfig: &tls.Config{InsecureSkipVerify: false}, +// } +// +// // Transport with TLS verification skipped (InsecureSkipVerify = true) +// transportInsecureTLS := &http.Transport{ +// TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, +// } +// +// // Transport with custom CA verification +// transportCustomRootCA := &http.Transport{ +// TLSClientConfig: &tls.Config{ +// RootCAs: caCertPool, // Custom root CA pool +// InsecureSkipVerify: false, // Enforce TLS verification +// }, +// } +// +// // OAuth2 Token Sources +// tokenSourceDefaultTLS := oauthConfig.TokenSource(context.Background()) +// +// ctxInsecure := context.WithValue( +// context.Background(), +// oauth2.HTTPClient, +// &http.Client{Transport: transportInsecureTLS}, +// ) +// tokenSourceInsecureTLS := oauthConfig.TokenSource(ctxInsecure) +// +// ctxCustomCA := context.WithValue( +// context.Background(), +// oauth2.HTTPClient, +// &http.Client{Transport: transportCustomRootCA}, +// ) +// tokenSourceCustomRootCA := oauthConfig.TokenSource(ctxCustomCA) +// +// // OAuth2 clients with different transports +// oauthClientDefaultTLS := &http.Client{ +// Transport: &oauth2Transport{ +// base: transportDefaultTLS, +// src: tokenSourceDefaultTLS, +// }, +// } +// +// oauthClientInsecureTLS := &http.Client{ +// Transport: &oauth2Transport{ +// base: transportInsecureTLS, +// src: tokenSourceInsecureTLS, +// }, +// } +// +// oauthClientCustomRootCA := &http.Client{ +// Transport: &oauth2Transport{ +// base: transportCustomRootCA, +// src: tokenSourceCustomRootCA, +// }, +// } +// +// // Prepare the GET request +// req, err := http.NewRequest("GET", url, nil) +// if err != nil { +// log.Fatalf("Failed to create request: %v", err) +// } +// +// // Test 1: OAuth2 client with default TLS verification (expected to fail if certificate is invalid) +// fmt.Println("Testing OAuth2 client with default TLS verification...") +// resp1, err1 := oauthClientDefaultTLS.Do(req) +// if err1 != nil { +// log.Printf("OAuth2 client with default TLS failed as expected: %v\n", err1) +// } else { +// fmt.Printf("OAuth2 client with default TLS succeeded: %s\n", resp1.Status) +// resp1.Body.Close() +// } +// +// // Test 2: OAuth2 client with skipped TLS verification (should succeed) +// fmt.Println("\nTesting OAuth2 client with skipped TLS verification...") +// resp2, err2 := oauthClientInsecureTLS.Do(req) +// if err2 != nil { +// log.Fatalf("OAuth2 client with skipped TLS failed: %v\n", err2) +// } else { +// fmt.Printf("OAuth2 client with skipped TLS succeeded: %s\n", resp2.Status) +// resp2.Body.Close() +// } +// +// // Test 3: OAuth2 client with custom root CA (should succeed if the CA is valid) +// fmt.Println("\nTesting OAuth2 client with custom root CA verification...") +// resp3, err3 := oauthClientCustomRootCA.Do(req) +// if err3 != nil { +// log.Fatalf("OAuth2 client with custom root CA failed: %v\n", err3) +// } else { +// fmt.Printf("OAuth2 client with custom root CA succeeded: %s\n", resp3.Status) +// resp3.Body.Close() +// } +//} +// +//// oauth2Transport is a custom RoundTripper that injects the OAuth2 token into requests +//type oauth2Transport struct { +// base http.RoundTripper +// src oauth2.TokenSource +//} +// +//// RoundTrip executes a single HTTP transaction, adding the OAuth2 token to the request +//func (t *oauth2Transport) RoundTrip(req *http.Request) (*http.Response, error) { +// token, err := t.src.Token() +// if err != nil { +// return nil, fmt.Errorf("failed to retrieve OAuth token: %w", err) +// } +// +// // Clone the request to avoid mutating the original +// reqCopy := req.Clone(req.Context()) +// token.SetAuthHeader(reqCopy) +// +// return t.base.RoundTrip(reqCopy) +//}