From 31d663769d11ccfe930d260097ca70e4a5809d3b Mon Sep 17 00:00:00 2001 From: Vui Lam Date: Mon, 7 Oct 2024 20:44:27 -0700 Subject: [PATCH 1/3] Refresh token with custom TLS config Token refresh needs to account for custom CA cert and cert validation flag. Factors out the TLSConfig customization using the context.Context to be used for that well. Signed-off-by: Vui Lam --- pkg/auth/common/login_handler.go | 27 ++++++++++++++++----------- pkg/auth/common/login_handler_test.go | 2 +- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/pkg/auth/common/login_handler.go b/pkg/auth/common/login_handler.go index e6d55e24d..0294ea728 100644 --- a/pkg/auth/common/login_handler.go +++ b/pkg/auth/common/login_handler.go @@ -167,7 +167,8 @@ func WithListenerPortFromEnv(envVarName string) LoginOption { func (h *TanzuLoginHandler) DoLogin() (*Token, error) { if h.refreshToken != "" { - token, err := h.getTokenWithRefreshToken() + ctx := contextWithCustomTLSConfig(context.TODO(), h.getTLSConfig()) + token, err := h.getTokenWithRefreshToken(ctx) if err == nil { return token, nil } @@ -176,8 +177,8 @@ func (h *TanzuLoginHandler) DoLogin() (*Token, error) { return h.browserLogin() } -func (h *TanzuLoginHandler) getTokenWithRefreshToken() (*Token, error) { - refreshedToken, err := h.oauthConfig.TokenSource(context.TODO(), &oauth2.Token{RefreshToken: h.refreshToken}).Token() +func (h *TanzuLoginHandler) getTokenWithRefreshToken(ctx context.Context) (*Token, error) { + refreshedToken, err := h.oauthConfig.TokenSource(ctx, &oauth2.Token{RefreshToken: h.refreshToken}).Token() if err != nil { return nil, err } @@ -492,16 +493,20 @@ func GetTLSConfig(endpoint, certData string, skipVerify bool) *tls.Config { return nil } +func contextWithCustomTLSConfig(ctx context.Context, tlsConfig *tls.Config) context.Context { + if tlsConfig != nil { + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.TLSClientConfig = tlsConfig + + sslcli := &http.Client{Transport: tr} + ctx = context.WithValue(ctx, oauth2.HTTPClient, sslcli) + } + return ctx +} + func (h *TanzuLoginHandler) getTokenUsingAuthCode(ctx context.Context, code string) (*oauth2.Token, error) { if h.idpType == config.UAAIdpType { - tlsConfig := h.getTLSConfig() - if tlsConfig != nil { - tr := http.DefaultTransport.(*http.Transport).Clone() - tr.TLSClientConfig = tlsConfig - - sslcli := &http.Client{Transport: tr} - ctx = context.WithValue(ctx, oauth2.HTTPClient, sslcli) - } + ctx = contextWithCustomTLSConfig(ctx, h.getTLSConfig()) } token, err := h.oauthConfig.Exchange(ctx, code, h.pkceCodePair.Verifier()) diff --git a/pkg/auth/common/login_handler_test.go b/pkg/auth/common/login_handler_test.go index 2b15638da..6c96df264 100644 --- a/pkg/auth/common/login_handler_test.go +++ b/pkg/auth/common/login_handler_test.go @@ -44,7 +44,7 @@ func TestHandleTokenRefresh(t *testing.T) { refreshToken: "fake-refresh-token", } - token, err := lh.getTokenWithRefreshToken() + token, err := lh.getTokenWithRefreshToken(context.TODO()) if err != nil { t.Errorf("Expected no error, got %v", err) } From 6baaf005ce6cd58f8915167bc87f0511e8a9bc58 Mon Sep 17 00:00:00 2001 From: Vui Lam Date: Mon, 7 Oct 2024 23:43:18 -0700 Subject: [PATCH 2/3] Simplify auth with api token against uaa Refactored login handler so that we can leverage it to perform the token refresh using provided API token. Also: - ensures that the API token based login updates CLI Context with the refresh token obtained from UAA. - do not interactively login on expired refresh token for api-token type tokens. Signed-off-by: Vui Lam --- pkg/auth/common/login_handler.go | 24 +++++-- pkg/auth/common/login_handler_test.go | 71 +++++++++++++++++---- pkg/auth/uaa/uaa.go | 59 +---------------- pkg/auth/uaa/uaa_test.go | 92 --------------------------- pkg/command/context.go | 23 ++++++- 5 files changed, 102 insertions(+), 167 deletions(-) delete mode 100644 pkg/auth/uaa/uaa_test.go diff --git a/pkg/auth/common/login_handler.go b/pkg/auth/common/login_handler.go index 0294ea728..9cf1f34f9 100644 --- a/pkg/auth/common/login_handler.go +++ b/pkg/auth/common/login_handler.go @@ -60,6 +60,7 @@ type TanzuLoginHandler struct { callbackHandlerMutex sync.Mutex tlsSkipVerify bool caCertData string + suppressInteractive bool } // LoginOption is an optional configuration for Login(). @@ -133,6 +134,15 @@ func WithClientID(clientID string) LoginOption { } } +// WithSuppressInteractive specifies whether to fall back to interactive login if +// an access token cannot be obtained. +func WithSuppressInteractive(suppress bool) LoginOption { + return func(h *TanzuLoginHandler) error { + h.suppressInteractive = suppress + return nil + } +} + // WithListenerPort specifies a TCP listener port on localhost, which will be used for the redirect_uri and to handle the // authorization code callback. By default, a random high port will be chosen which requires the authorization server // to support wildcard port numbers as described by https://tools.ietf.org/html/rfc8252#section-7.3: @@ -166,18 +176,22 @@ func WithListenerPortFromEnv(envVarName string) LoginOption { } func (h *TanzuLoginHandler) DoLogin() (*Token, error) { + var err error + var token *Token + if h.refreshToken != "" { - ctx := contextWithCustomTLSConfig(context.TODO(), h.getTLSConfig()) - token, err := h.getTokenWithRefreshToken(ctx) - if err == nil { - return token, nil + token, err = h.getTokenWithRefreshToken() + if err == nil || h.suppressInteractive { + return token, err } } + // If refresh token fails, proceed with login flow through the browser return h.browserLogin() } -func (h *TanzuLoginHandler) getTokenWithRefreshToken(ctx context.Context) (*Token, error) { +func (h *TanzuLoginHandler) getTokenWithRefreshToken() (*Token, error) { + ctx := contextWithCustomTLSConfig(context.TODO(), h.getTLSConfig()) refreshedToken, err := h.oauthConfig.TokenSource(ctx, &oauth2.Token{RefreshToken: h.refreshToken}).Token() if err != nil { return nil, err diff --git a/pkg/auth/common/login_handler_test.go b/pkg/auth/common/login_handler_test.go index 6c96df264..5ad6d0ecb 100644 --- a/pkg/auth/common/login_handler_test.go +++ b/pkg/auth/common/login_handler_test.go @@ -27,6 +27,8 @@ const ( ) func TestHandleTokenRefresh(t *testing.T) { + assert := assert.New(t) + // Mock HTTP server for token refresh server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -44,20 +46,65 @@ func TestHandleTokenRefresh(t *testing.T) { refreshToken: "fake-refresh-token", } - token, err := lh.getTokenWithRefreshToken(context.TODO()) - if err != nil { - t.Errorf("Expected no error, got %v", err) - } - if token == nil { - t.Error("Expected a non-nil token, got nil") + token, err := lh.getTokenWithRefreshToken() + assert.Nil(err) + assert.NotNil(token) + assert.Equal(token.AccessToken, "fake-access-token") + assert.Equal(token.RefreshToken, "fake-refresh-token") + assert.Equal(token.TokenType, "id-token") + assert.Equal(token.IDToken, "fake-id-token") + assert.Equal(token.ExpiresIn, int64(3599)) +} + +// test that login with refresh token completes without triggering browser +// login regardless of whether refresh succeeded or not +func TestLoginWithAPIToken(t *testing.T) { + assert := assert.New(t) + + // Mock HTTP server for token refresh + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + body, _ := io.ReadAll(r.Body) + if strings.Contains(string(body), "refresh_token=valid-api-token") { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token": "fake-access-token", "refresh_token": "fake-refresh-token", "expires_in": 3600, "id_token": "fake-id-token"}`)) + return + } + http.Error(w, "refresh_error", http.StatusBadRequest) + })) + defer server.Close() + + lh := &TanzuLoginHandler{ + oauthConfig: &oauth2.Config{ + Endpoint: oauth2.Endpoint{ + TokenURL: server.URL, + }, + }, + refreshToken: "valid-api-token", + suppressInteractive: true, } - if token != nil { - assert.Equal(t, token.AccessToken, "fake-access-token") - assert.Equal(t, token.RefreshToken, "fake-refresh-token") - assert.Equal(t, token.TokenType, "id-token") - assert.Equal(t, token.IDToken, "fake-id-token") - assert.Equal(t, token.ExpiresIn, int64(3599)) + token, err := lh.DoLogin() + + assert.Nil(err) + assert.NotNil(token) + assert.Equal(token.AccessToken, "fake-access-token") + assert.Equal(token.RefreshToken, "fake-refresh-token") + assert.Equal(token.TokenType, "id-token") + assert.Equal(token.IDToken, "fake-id-token") + assert.Equal(token.ExpiresIn, int64(3599)) + + lh = &TanzuLoginHandler{ + oauthConfig: &oauth2.Config{ + Endpoint: oauth2.Endpoint{ + TokenURL: server.URL, + }, + }, + refreshToken: "bad-refresh-token", + suppressInteractive: true, } + token, err = lh.DoLogin() + assert.NotNil(err) + assert.Nil(token) } func TestGetAuthCodeURL_validResponse(t *testing.T) { diff --git a/pkg/auth/uaa/uaa.go b/pkg/auth/uaa/uaa.go index 500c14fd5..eba4b91e2 100644 --- a/pkg/auth/uaa/uaa.go +++ b/pkg/auth/uaa/uaa.go @@ -5,66 +5,10 @@ package uaa import ( - "bytes" - "context" - "encoding/json" - "io" - "net/http" - "net/url" - - "github.com/pkg/errors" - "github.com/vmware-tanzu/tanzu-cli/pkg/auth/common" "github.com/vmware-tanzu/tanzu-cli/pkg/constants" - "github.com/vmware-tanzu/tanzu-cli/pkg/interfaces" ) -var ( - httpRestClient interfaces.HTTPClient -) - -// GetAccessTokenFromAPIToken fetches access token using the API-token. -func GetAccessTokenFromAPIToken(apiToken, uaaEndpoint, endpointCACertPath string, skipTLSVerify bool) (*common.Token, error) { - tokenURL := getIssuerEndpoints(uaaEndpoint).TokenURL - data := url.Values{} - data.Set("refresh_token", apiToken) - data.Set("client_id", GetAlternateClientID()) - data.Set("grant_type", "refresh_token") - - req, _ := http.NewRequestWithContext(context.Background(), "POST", tokenURL, bytes.NewBufferString(data.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - if httpRestClient == nil { - tlsConfig := common.GetTLSConfig(uaaEndpoint, endpointCACertPath, skipTLSVerify) - if tlsConfig == nil { - return nil, errors.New("unable to set up tls config") - } - - tr := http.DefaultTransport.(*http.Transport).Clone() - tr.TLSClientConfig = tlsConfig - httpRestClient = &http.Client{Transport: tr} - } - - resp, err := httpRestClient.Do(req) - if err != nil { - return nil, errors.WithMessage(err, "Failed to obtain access token. Please provide valid API token") - } - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, errors.Errorf("Failed to obtain access token. Please provide valid API token -- %s", string(body)) - } - - defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) - token := common.Token{} - - if err = json.Unmarshal(body, &token); err != nil { - return nil, errors.Wrap(err, "could not unmarshal auth token") - } - - return &token, nil -} - // GetTokens fetches the UAA access token func GetTokens(refreshOrAPIToken, _, issuer, tokenType string) (*common.Token, error) { clientID := tanzuCLIClientID @@ -72,6 +16,9 @@ func GetTokens(refreshOrAPIToken, _, issuer, tokenType string) (*common.Token, e clientID = GetAlternateClientID() } loginOptions := []common.LoginOption{common.WithRefreshToken(refreshOrAPIToken), common.WithListenerPortFromEnv(constants.TanzuCLIOAuthLocalListenerPort), common.WithClientID(clientID)} + if tokenType == common.APITokenType { + loginOptions = append(loginOptions, common.WithSuppressInteractive(true)) + } token, err := TanzuLogin(issuer, loginOptions...) if err != nil { diff --git a/pkg/auth/uaa/uaa_test.go b/pkg/auth/uaa/uaa_test.go deleted file mode 100644 index 1a37edd56..000000000 --- a/pkg/auth/uaa/uaa_test.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2023 VMware, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package uaa - -import ( - "bytes" - "fmt" - "io" - "net/http" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/vmware-tanzu/tanzu-cli/pkg/fakes" -) - -const ( - fakeIssuerURL = "https://auth0.com/" - fakeAPIToken = "fake_api_token" - fakeCACrtPath = "/fake/ca.crt" - fakeSkipVerify = false -) - -func TestGetAccessTokenFromAPIToken(t *testing.T) { - assert := assert.New(t) - fakeHTTPClient := &fakes.FakeHTTPClient{} - responseBody := io.NopCloser(bytes.NewReader([]byte(`{ - "id_token": "abc", - "token_type": "Test", - "expires_in": 86400, - "scope": "Test", - "access_token": "LetMeIn", - "refresh_token": "LetMeInAgain"}`))) - fakeHTTPClient.DoReturns(&http.Response{ - StatusCode: 200, - Body: responseBody, - }, nil) - httpRestClient = fakeHTTPClient - token, err := GetAccessTokenFromAPIToken(fakeAPIToken, fakeIssuerURL, fakeCACrtPath, fakeSkipVerify) - if err != nil { - fmt.Println(err) - fmt.Println("Error...................................") - } - assert.Nil(err) - assert.Equal("LetMeIn", token.AccessToken) - - req := fakeHTTPClient.DoArgsForCall(0) - bodyBytes, _ := io.ReadAll(req.Body) - body := string(bodyBytes) - - assert.Contains(body, "refresh_token="+fakeAPIToken) - assert.Contains(body, "client_id="+GetAlternateClientID()) - assert.Contains(body, "grant_type=refresh_token") -} - -func TestGetAccessTokenFromAPIToken_FailStatus(t *testing.T) { - assert := assert.New(t) - fakeHTTPClient := &fakes.FakeHTTPClient{} - responseBody := io.NopCloser(bytes.NewReader([]byte(``))) - fakeHTTPClient.DoReturns(&http.Response{ - StatusCode: 403, - Body: responseBody, - }, nil) - httpRestClient = fakeHTTPClient - token, err := GetAccessTokenFromAPIToken(fakeAPIToken, fakeIssuerURL, fakeCACrtPath, fakeSkipVerify) - assert.NotNil(err) - assert.Contains(err.Error(), "Failed to obtain access token. Please provide valid API token") - assert.Nil(token) -} - -func TestGetAccessTokenFromAPIToken_InvalidResponse(t *testing.T) { - assert := assert.New(t) - fakeHTTPClient := &fakes.FakeHTTPClient{} - responseBody := io.NopCloser(bytes.NewReader([]byte(`[{ - "id_token": "abc", - "token_type": "Test", - "expires_in": 86400, - "scope": "Test", - "access_token": "LetMeIn", - "refresh_token": "LetMeInAgain"}]`))) - fakeHTTPClient.DoReturns(&http.Response{ - StatusCode: 200, - Body: responseBody, - }, nil) - httpRestClient = fakeHTTPClient - - token, err := GetAccessTokenFromAPIToken(fakeAPIToken, fakeIssuerURL, fakeCACrtPath, fakeSkipVerify) - assert.NotNil(err) - assert.Contains(err.Error(), "could not unmarshal") - assert.Nil(token) -} diff --git a/pkg/command/context.go b/pkg/command/context.go index 9cf39882b..6bae0224f 100644 --- a/pkg/command/context.go +++ b/pkg/command/context.go @@ -733,7 +733,26 @@ func getSelfManagedOrg(c *configtypes.Context) (string, string) { } func doUAAAPITokenAuthAndUpdateContext(c *configtypes.Context, uaaEndpoint, apiTokenValue string) (claims *commonauth.Claims, err error) { - token, err := uaa.GetAccessTokenFromAPIToken(apiTokenValue, uaaEndpoint, endpointCACertPath, skipTLSVerify) + loginOptions := []commonauth.LoginOption{ + commonauth.WithSuppressInteractive(true), // fail instead of falling back to interactive login + commonauth.WithRefreshToken(apiTokenValue), + commonauth.WithClientID(uaa.GetAlternateClientID()), + } + + var endpointCACertData string + if endpointCACertPath != "" { + fileBytes, err := os.ReadFile(endpointCACertPath) + if err != nil { + return nil, errors.Wrapf(err, "error reading certificate file %s", endpointCACertPath) + } + endpointCACertData = base64.StdEncoding.EncodeToString(fileBytes) + } + if skipTLSVerify || endpointCACertData != "" { + loginOptions = append(loginOptions, commonauth.WithCertInfo(skipTLSVerify, endpointCACertData)) + } + + // Invoke TanzuLogin to obtain access token via API token + token, err := uaa.TanzuLogin(uaaEndpoint, loginOptions...) if err != nil { return nil, errors.Wrap(err, "failed to get token from UAA") } @@ -748,7 +767,7 @@ func doUAAAPITokenAuthAndUpdateContext(c *configtypes.Context, uaaEndpoint, apiT a.Permissions = claims.Permissions a.AccessToken = token.AccessToken a.IDToken = token.IDToken - a.RefreshToken = apiTokenValue + a.RefreshToken = token.RefreshToken a.Type = commonauth.APITokenType expiresAt := time.Now().Local().Add(time.Second * time.Duration(token.ExpiresIn)) a.Expiration = expiresAt From ca10d433e0735765632fdf32db2a191b47c75554 Mon Sep 17 00:00:00 2001 From: Vui Lam Date: Tue, 8 Oct 2024 17:14:08 -0700 Subject: [PATCH 3/3] Update the cert map on successful non interactive login Signed-off-by: Vui Lam --- pkg/auth/common/login_handler.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pkg/auth/common/login_handler.go b/pkg/auth/common/login_handler.go index 9cf1f34f9..2331ac15d 100644 --- a/pkg/auth/common/login_handler.go +++ b/pkg/auth/common/login_handler.go @@ -182,6 +182,11 @@ func (h *TanzuLoginHandler) DoLogin() (*Token, error) { if h.refreshToken != "" { token, err = h.getTokenWithRefreshToken() if err == nil || h.suppressInteractive { + // non interactive login mode should update the cert map as well + // before returning. + if err == nil && h.suppressInteractive { + h.updateCertMap() + } return token, err } }