Skip to content

Commit

Permalink
fix(oauth): OAuth clients not to inherit DefaultClient config
Browse files Browse the repository at this point in the history
  • Loading branch information
spbsoluble committed Oct 31, 2024
1 parent 54983b9 commit 1fb370e
Show file tree
Hide file tree
Showing 4 changed files with 287 additions and 143 deletions.
98 changes: 56 additions & 42 deletions auth_providers/auth_core.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ type CommandAuthConfig struct {

// HttpClient is the http Client to be used for authentication to Keyfactor Command API
HttpClient *http.Client
//DefaultHttpClient *http.Client
}

// cleanHostName cleans the hostname for authentication to Keyfactor Command API.
Expand Down Expand Up @@ -275,66 +276,53 @@ func (c *CommandAuthConfig) ValidateAuthConfig() error {
// check if CommandCACert is set in environment
if caCert, ok := os.LookupEnv(EnvKeyfactorCACert); ok {
c.CommandCACert = caCert
} else {
return nil
}
}

// check for skip verify in environment
if skipVerify, ok := os.LookupEnv(EnvKeyfactorSkipVerify); ok {
c.SkipVerify = skipVerify == "true" || skipVerify == "1"
}

//TODO: This should be part of BuildTransport
//if c.SkipVerify {
// c.HttpClient.Transport = &http.Transport{
// TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
// }
// //return nil
//}
//
//caErr := c.updateCACerts()
//if caErr != nil {
// return caErr
//}

return nil
}

// BuildTransport creates a custom http Transport for authentication to Keyfactor Command API.
func (c *CommandAuthConfig) BuildTransport() (*http.Transport, error) {
var output *http.Transport
if c.HttpClient == nil {
c.SetClient(nil)
}
// check if c already has a transport and if it does, assign it to output else create a new transport
if c.HttpClient.Transport != nil {
if transport, ok := c.HttpClient.Transport.(*http.Transport); ok {
output = transport
} else {
output = &http.Transport{
TLSClientConfig: &tls.Config{},
}
}
} else {
output = &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{
Renegotiation: tls.RenegotiateOnceAsClient,
},
TLSHandshakeTimeout: 10 * time.Second,
}
output := http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{
Renegotiation: tls.RenegotiateOnceAsClient,
},
TLSHandshakeTimeout: 10 * time.Second,
}

if c.SkipVerify {
output.TLSClientConfig.InsecureSkipVerify = true
}

if c.CommandCACert != "" {
_ = c.updateCACerts()
if _, err := os.Stat(c.CommandCACert); err == nil {
cert, ioErr := os.ReadFile(c.CommandCACert)
if ioErr != nil {
return &output, ioErr
}
// check if output.TLSClientConfig.RootCAs is nil
if output.TLSClientConfig.RootCAs == nil {
output.TLSClientConfig.RootCAs = x509.NewCertPool()
}
// Append your custom cert to the pool
if ok := output.TLSClientConfig.RootCAs.AppendCertsFromPEM(cert); !ok {
return &output, fmt.Errorf("failed to append custom CA cert to pool")
}
} else {
// Append your custom cert to the pool
if ok := output.TLSClientConfig.RootCAs.AppendCertsFromPEM([]byte(c.CommandCACert)); !ok {
return &output, fmt.Errorf("failed to append custom CA cert to pool")
}
}
}

return output, nil
return &output, nil
}

// SetClient sets the http Client for authentication to Keyfactor Command API.
Expand All @@ -343,8 +331,34 @@ func (c *CommandAuthConfig) SetClient(client *http.Client) *http.Client {
c.HttpClient = client
}
if c.HttpClient == nil {
c.HttpClient = http.DefaultClient
//// Copy the default transport and apply the custom TLS config
//defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
////defaultTransport.TLSClientConfig = tlsConfig
//c.HttpClient = &http.Client{Transport: defaultTransport}
defaultTimeout := time.Duration(c.HttpClientTimeout) * time.Second
c.HttpClient = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{
Renegotiation: tls.RenegotiateOnceAsClient,
},
TLSHandshakeTimeout: defaultTimeout,
DisableKeepAlives: false,
DisableCompression: false,
MaxIdleConns: 10,
MaxIdleConnsPerHost: 10,
MaxConnsPerHost: 10,
IdleConnTimeout: defaultTimeout,
ResponseHeaderTimeout: defaultTimeout,
ExpectContinueTimeout: defaultTimeout,
MaxResponseHeaderBytes: 0,
WriteBufferSize: 0,
ReadBufferSize: 0,
ForceAttemptHTTP2: false,
},
}
}

return c.HttpClient
}

Expand All @@ -356,6 +370,7 @@ func (c *CommandAuthConfig) updateCACerts() error {
if caCert, ok := os.LookupEnv(EnvKeyfactorCACert); ok {
c.CommandCACert = caCert
} else {
// nothing to do
return nil
}
}
Expand Down Expand Up @@ -452,7 +467,6 @@ func (c *CommandAuthConfig) Authenticate() error {
}

c.HttpClient.Timeout = time.Duration(c.HttpClientTimeout) * time.Second

cResp, cErr := c.HttpClient.Do(req)
if cErr != nil {
return cErr
Expand Down Expand Up @@ -645,7 +659,7 @@ func (c *CommandAuthConfig) LoadConfig(profile string, configFilePath string, si
if c.CommandCACert == "" {
c.CommandCACert = server.CACertPath
}
if c.SkipVerify {
if !c.SkipVerify {
c.SkipVerify = server.SkipTLSVerify
}

Expand Down
117 changes: 26 additions & 91 deletions auth_providers/auth_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package auth_providers

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net/http"
Expand Down Expand Up @@ -51,6 +50,11 @@ type OAuthAuthenticator struct {
Client *http.Client
}

type oauth2Transport struct {
base http.RoundTripper
src oauth2.TokenSource
}

// GetHttpClient returns the http client
func (a *OAuthAuthenticator) GetHttpClient() (*http.Client, error) {
return a.Client, nil
Expand Down Expand Up @@ -162,24 +166,17 @@ func (b *CommandConfigOauth) WithHttpClient(httpClient *http.Client) *CommandCon
// GetHttpClient returns an HTTP client for oAuth authentication.
func (b *CommandConfigOauth) GetHttpClient() (*http.Client, error) {
cErr := b.ValidateAuthConfig()
var client http.Client
if b.CommandAuthConfig.HttpClient != nil {
client = *b.CommandAuthConfig.HttpClient
}
if cErr != nil {
return nil, cErr
}

if client.Transport == nil {
transport, tErr := b.BuildTransport()
if tErr != nil {
return nil, tErr
}
client.Transport = transport
var client http.Client
baseTransport, tErr := b.BuildTransport()
if tErr != nil {
return nil, tErr
}

if b.AccessToken != "" {
baseTransport := cloneHTTPTransport(client.Transport.(*http.Transport))
client.Transport = &oauth2.Transport{
Base: baseTransport,
Source: oauth2.StaticTokenSource(
Expand Down Expand Up @@ -209,15 +206,15 @@ func (b *CommandConfigOauth) GetHttpClient() (*http.Client, error) {
}
}

ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client)

ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{Transport: baseTransport})
tokenSource := config.TokenSource(ctx)
baseTransport := cloneHTTPTransport(client.Transport.(*http.Transport))
oauthTransport := oauth2.Transport{
Base: baseTransport,
Source: tokenSource,

client = http.Client{
Transport: &oauth2Transport{
base: baseTransport,
src: tokenSource,
},
}
client.Transport = &oauthTransport

return &client, nil
}
Expand Down Expand Up @@ -375,6 +372,7 @@ func (b *CommandConfigOauth) Authenticate() error {
}

b.SetClient(oauthy)
//b.DefaultHttpClient = oauthy

aErr := b.CommandAuthConfig.Authenticate()
if aErr != nil {
Expand All @@ -401,79 +399,16 @@ func (b *CommandConfigOauth) GetServerConfig() *Server {
return &server
}

// Example usage of CommandConfigOauth
//
// This example demonstrates how to use CommandConfigOauth to authenticate to the Keyfactor Command API using OAuth2.
//
// func ExampleCommandConfigOauth_Authenticate() {
// authConfig := &CommandConfigOauth{
// CommandAuthConfig: CommandAuthConfig{
// ConfigFilePath: "/path/to/config.json",
// ConfigProfile: "default",
// CommandHostName: "exampleHost",
// CommandPort: 443,
// CommandAPIPath: "/api/v1",
// CommandCACert: "/path/to/ca-cert.pem",
// SkipVerify: true,
// HttpClientTimeout: 60,
// },
// ClientID: "exampleClientID",
// ClientSecret: "exampleClientSecret",
// TokenURL: "https://example.com/oauth/token",
// Scopes: []string{"openid", "profile", "email"},
// Audience: "exampleAudience",
// CACertificatePath: "/path/to/ca-cert.pem",
// AccessToken: "exampleAccessToken",
// }
//
// err := authConfig.Authenticate()
// if err != nil {
// fmt.Println("Authentication failed:", err)
// } else {
// fmt.Println("Authentication successful")
// }
// }

func cloneHTTPTransport(original *http.Transport) *http.Transport {
if original == nil {
return nil
// RoundTrip executes a single HTTP transaction, adding the OAuth2 token to the request
func (t *oauth2Transport) RoundTrip(req *http.Request) (*http.Response, error) {
token, err := t.src.Token()
if err != nil {
return nil, fmt.Errorf("failed to retrieve OAuth token: %w", err)
}

return &http.Transport{
Proxy: original.Proxy,
DialContext: original.DialContext,
ForceAttemptHTTP2: original.ForceAttemptHTTP2,
MaxIdleConns: original.MaxIdleConns,
IdleConnTimeout: original.IdleConnTimeout,
TLSHandshakeTimeout: original.TLSHandshakeTimeout,
ExpectContinueTimeout: original.ExpectContinueTimeout,
ResponseHeaderTimeout: original.ResponseHeaderTimeout,
TLSClientConfig: cloneTLSConfig(original.TLSClientConfig),
DialTLSContext: original.DialTLSContext,
DisableKeepAlives: original.DisableKeepAlives,
DisableCompression: original.DisableCompression,
MaxIdleConnsPerHost: original.MaxIdleConnsPerHost,
MaxConnsPerHost: original.MaxConnsPerHost,
WriteBufferSize: original.WriteBufferSize,
ReadBufferSize: original.ReadBufferSize,
}
}
// Clone the request to avoid mutating the original
reqCopy := req.Clone(req.Context())
token.SetAuthHeader(reqCopy)

func cloneTLSConfig(original *tls.Config) *tls.Config {
if original == nil {
return nil
}

return &tls.Config{
InsecureSkipVerify: original.InsecureSkipVerify,
MinVersion: original.MinVersion,
MaxVersion: original.MaxVersion,
CipherSuites: original.CipherSuites,
PreferServerCipherSuites: original.PreferServerCipherSuites,
NextProtos: original.NextProtos,
ServerName: original.ServerName,
ClientAuth: original.ClientAuth,
RootCAs: original.RootCAs,
// Deep copy the rest of the TLS fields as needed
}
return t.base.RoundTrip(reqCopy)
}
Loading

0 comments on commit 1fb370e

Please sign in to comment.