diff --git a/pkg/auth/common/login_handler.go b/pkg/auth/common/login_handler.go index e6d55e24d..2331ac15d 100644 --- a/pkg/auth/common/login_handler.go +++ b/pkg/auth/common/login_handler.go @@ -60,6 +60,7 @@ type TanzuLoginHandler struct { callbackHandlerMutex sync.Mutex tlsSkipVerify bool caCertData string + suppressInteractive bool } // LoginOption is an optional configuration for Login(). @@ -133,6 +134,15 @@ func WithClientID(clientID string) LoginOption { } } +// WithSuppressInteractive specifies whether to fall back to interactive login if +// an access token cannot be obtained. +func WithSuppressInteractive(suppress bool) LoginOption { + return func(h *TanzuLoginHandler) error { + h.suppressInteractive = suppress + return nil + } +} + // WithListenerPort specifies a TCP listener port on localhost, which will be used for the redirect_uri and to handle the // authorization code callback. By default, a random high port will be chosen which requires the authorization server // to support wildcard port numbers as described by https://tools.ietf.org/html/rfc8252#section-7.3: @@ -166,18 +176,28 @@ func WithListenerPortFromEnv(envVarName string) LoginOption { } func (h *TanzuLoginHandler) DoLogin() (*Token, error) { + var err error + var token *Token + if h.refreshToken != "" { - token, err := h.getTokenWithRefreshToken() - if err == nil { - return token, nil + token, err = h.getTokenWithRefreshToken() + if err == nil || h.suppressInteractive { + // non interactive login mode should update the cert map as well + // before returning. + if err == nil && h.suppressInteractive { + h.updateCertMap() + } + return token, err } } + // If refresh token fails, proceed with login flow through the browser return h.browserLogin() } func (h *TanzuLoginHandler) getTokenWithRefreshToken() (*Token, error) { - refreshedToken, err := h.oauthConfig.TokenSource(context.TODO(), &oauth2.Token{RefreshToken: h.refreshToken}).Token() + ctx := contextWithCustomTLSConfig(context.TODO(), h.getTLSConfig()) + refreshedToken, err := h.oauthConfig.TokenSource(ctx, &oauth2.Token{RefreshToken: h.refreshToken}).Token() if err != nil { return nil, err } @@ -492,16 +512,20 @@ func GetTLSConfig(endpoint, certData string, skipVerify bool) *tls.Config { return nil } +func contextWithCustomTLSConfig(ctx context.Context, tlsConfig *tls.Config) context.Context { + if tlsConfig != nil { + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.TLSClientConfig = tlsConfig + + sslcli := &http.Client{Transport: tr} + ctx = context.WithValue(ctx, oauth2.HTTPClient, sslcli) + } + return ctx +} + func (h *TanzuLoginHandler) getTokenUsingAuthCode(ctx context.Context, code string) (*oauth2.Token, error) { if h.idpType == config.UAAIdpType { - tlsConfig := h.getTLSConfig() - if tlsConfig != nil { - tr := http.DefaultTransport.(*http.Transport).Clone() - tr.TLSClientConfig = tlsConfig - - sslcli := &http.Client{Transport: tr} - ctx = context.WithValue(ctx, oauth2.HTTPClient, sslcli) - } + ctx = contextWithCustomTLSConfig(ctx, h.getTLSConfig()) } token, err := h.oauthConfig.Exchange(ctx, code, h.pkceCodePair.Verifier()) diff --git a/pkg/auth/common/login_handler_test.go b/pkg/auth/common/login_handler_test.go index 2b15638da..5ad6d0ecb 100644 --- a/pkg/auth/common/login_handler_test.go +++ b/pkg/auth/common/login_handler_test.go @@ -27,6 +27,8 @@ const ( ) func TestHandleTokenRefresh(t *testing.T) { + assert := assert.New(t) + // Mock HTTP server for token refresh server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -45,19 +47,64 @@ func TestHandleTokenRefresh(t *testing.T) { } token, err := lh.getTokenWithRefreshToken() - if err != nil { - t.Errorf("Expected no error, got %v", err) - } - if token == nil { - t.Error("Expected a non-nil token, got nil") + assert.Nil(err) + assert.NotNil(token) + assert.Equal(token.AccessToken, "fake-access-token") + assert.Equal(token.RefreshToken, "fake-refresh-token") + assert.Equal(token.TokenType, "id-token") + assert.Equal(token.IDToken, "fake-id-token") + assert.Equal(token.ExpiresIn, int64(3599)) +} + +// test that login with refresh token completes without triggering browser +// login regardless of whether refresh succeeded or not +func TestLoginWithAPIToken(t *testing.T) { + assert := assert.New(t) + + // Mock HTTP server for token refresh + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + body, _ := io.ReadAll(r.Body) + if strings.Contains(string(body), "refresh_token=valid-api-token") { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token": "fake-access-token", "refresh_token": "fake-refresh-token", "expires_in": 3600, "id_token": "fake-id-token"}`)) + return + } + http.Error(w, "refresh_error", http.StatusBadRequest) + })) + defer server.Close() + + lh := &TanzuLoginHandler{ + oauthConfig: &oauth2.Config{ + Endpoint: oauth2.Endpoint{ + TokenURL: server.URL, + }, + }, + refreshToken: "valid-api-token", + suppressInteractive: true, } - if token != nil { - assert.Equal(t, token.AccessToken, "fake-access-token") - assert.Equal(t, token.RefreshToken, "fake-refresh-token") - assert.Equal(t, token.TokenType, "id-token") - assert.Equal(t, token.IDToken, "fake-id-token") - assert.Equal(t, token.ExpiresIn, int64(3599)) + token, err := lh.DoLogin() + + assert.Nil(err) + assert.NotNil(token) + assert.Equal(token.AccessToken, "fake-access-token") + assert.Equal(token.RefreshToken, "fake-refresh-token") + assert.Equal(token.TokenType, "id-token") + assert.Equal(token.IDToken, "fake-id-token") + assert.Equal(token.ExpiresIn, int64(3599)) + + lh = &TanzuLoginHandler{ + oauthConfig: &oauth2.Config{ + Endpoint: oauth2.Endpoint{ + TokenURL: server.URL, + }, + }, + refreshToken: "bad-refresh-token", + suppressInteractive: true, } + token, err = lh.DoLogin() + assert.NotNil(err) + assert.Nil(token) } func TestGetAuthCodeURL_validResponse(t *testing.T) { diff --git a/pkg/auth/uaa/uaa.go b/pkg/auth/uaa/uaa.go index 500c14fd5..eba4b91e2 100644 --- a/pkg/auth/uaa/uaa.go +++ b/pkg/auth/uaa/uaa.go @@ -5,66 +5,10 @@ package uaa import ( - "bytes" - "context" - "encoding/json" - "io" - "net/http" - "net/url" - - "github.com/pkg/errors" - "github.com/vmware-tanzu/tanzu-cli/pkg/auth/common" "github.com/vmware-tanzu/tanzu-cli/pkg/constants" - "github.com/vmware-tanzu/tanzu-cli/pkg/interfaces" ) -var ( - httpRestClient interfaces.HTTPClient -) - -// GetAccessTokenFromAPIToken fetches access token using the API-token. -func GetAccessTokenFromAPIToken(apiToken, uaaEndpoint, endpointCACertPath string, skipTLSVerify bool) (*common.Token, error) { - tokenURL := getIssuerEndpoints(uaaEndpoint).TokenURL - data := url.Values{} - data.Set("refresh_token", apiToken) - data.Set("client_id", GetAlternateClientID()) - data.Set("grant_type", "refresh_token") - - req, _ := http.NewRequestWithContext(context.Background(), "POST", tokenURL, bytes.NewBufferString(data.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - if httpRestClient == nil { - tlsConfig := common.GetTLSConfig(uaaEndpoint, endpointCACertPath, skipTLSVerify) - if tlsConfig == nil { - return nil, errors.New("unable to set up tls config") - } - - tr := http.DefaultTransport.(*http.Transport).Clone() - tr.TLSClientConfig = tlsConfig - httpRestClient = &http.Client{Transport: tr} - } - - resp, err := httpRestClient.Do(req) - if err != nil { - return nil, errors.WithMessage(err, "Failed to obtain access token. Please provide valid API token") - } - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, errors.Errorf("Failed to obtain access token. Please provide valid API token -- %s", string(body)) - } - - defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) - token := common.Token{} - - if err = json.Unmarshal(body, &token); err != nil { - return nil, errors.Wrap(err, "could not unmarshal auth token") - } - - return &token, nil -} - // GetTokens fetches the UAA access token func GetTokens(refreshOrAPIToken, _, issuer, tokenType string) (*common.Token, error) { clientID := tanzuCLIClientID @@ -72,6 +16,9 @@ func GetTokens(refreshOrAPIToken, _, issuer, tokenType string) (*common.Token, e clientID = GetAlternateClientID() } loginOptions := []common.LoginOption{common.WithRefreshToken(refreshOrAPIToken), common.WithListenerPortFromEnv(constants.TanzuCLIOAuthLocalListenerPort), common.WithClientID(clientID)} + if tokenType == common.APITokenType { + loginOptions = append(loginOptions, common.WithSuppressInteractive(true)) + } token, err := TanzuLogin(issuer, loginOptions...) if err != nil { diff --git a/pkg/auth/uaa/uaa_test.go b/pkg/auth/uaa/uaa_test.go deleted file mode 100644 index 1a37edd56..000000000 --- a/pkg/auth/uaa/uaa_test.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2023 VMware, Inc. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package uaa - -import ( - "bytes" - "fmt" - "io" - "net/http" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/vmware-tanzu/tanzu-cli/pkg/fakes" -) - -const ( - fakeIssuerURL = "https://auth0.com/" - fakeAPIToken = "fake_api_token" - fakeCACrtPath = "/fake/ca.crt" - fakeSkipVerify = false -) - -func TestGetAccessTokenFromAPIToken(t *testing.T) { - assert := assert.New(t) - fakeHTTPClient := &fakes.FakeHTTPClient{} - responseBody := io.NopCloser(bytes.NewReader([]byte(`{ - "id_token": "abc", - "token_type": "Test", - "expires_in": 86400, - "scope": "Test", - "access_token": "LetMeIn", - "refresh_token": "LetMeInAgain"}`))) - fakeHTTPClient.DoReturns(&http.Response{ - StatusCode: 200, - Body: responseBody, - }, nil) - httpRestClient = fakeHTTPClient - token, err := GetAccessTokenFromAPIToken(fakeAPIToken, fakeIssuerURL, fakeCACrtPath, fakeSkipVerify) - if err != nil { - fmt.Println(err) - fmt.Println("Error...................................") - } - assert.Nil(err) - assert.Equal("LetMeIn", token.AccessToken) - - req := fakeHTTPClient.DoArgsForCall(0) - bodyBytes, _ := io.ReadAll(req.Body) - body := string(bodyBytes) - - assert.Contains(body, "refresh_token="+fakeAPIToken) - assert.Contains(body, "client_id="+GetAlternateClientID()) - assert.Contains(body, "grant_type=refresh_token") -} - -func TestGetAccessTokenFromAPIToken_FailStatus(t *testing.T) { - assert := assert.New(t) - fakeHTTPClient := &fakes.FakeHTTPClient{} - responseBody := io.NopCloser(bytes.NewReader([]byte(``))) - fakeHTTPClient.DoReturns(&http.Response{ - StatusCode: 403, - Body: responseBody, - }, nil) - httpRestClient = fakeHTTPClient - token, err := GetAccessTokenFromAPIToken(fakeAPIToken, fakeIssuerURL, fakeCACrtPath, fakeSkipVerify) - assert.NotNil(err) - assert.Contains(err.Error(), "Failed to obtain access token. Please provide valid API token") - assert.Nil(token) -} - -func TestGetAccessTokenFromAPIToken_InvalidResponse(t *testing.T) { - assert := assert.New(t) - fakeHTTPClient := &fakes.FakeHTTPClient{} - responseBody := io.NopCloser(bytes.NewReader([]byte(`[{ - "id_token": "abc", - "token_type": "Test", - "expires_in": 86400, - "scope": "Test", - "access_token": "LetMeIn", - "refresh_token": "LetMeInAgain"}]`))) - fakeHTTPClient.DoReturns(&http.Response{ - StatusCode: 200, - Body: responseBody, - }, nil) - httpRestClient = fakeHTTPClient - - token, err := GetAccessTokenFromAPIToken(fakeAPIToken, fakeIssuerURL, fakeCACrtPath, fakeSkipVerify) - assert.NotNil(err) - assert.Contains(err.Error(), "could not unmarshal") - assert.Nil(token) -} diff --git a/pkg/command/context.go b/pkg/command/context.go index 9cf39882b..6bae0224f 100644 --- a/pkg/command/context.go +++ b/pkg/command/context.go @@ -733,7 +733,26 @@ func getSelfManagedOrg(c *configtypes.Context) (string, string) { } func doUAAAPITokenAuthAndUpdateContext(c *configtypes.Context, uaaEndpoint, apiTokenValue string) (claims *commonauth.Claims, err error) { - token, err := uaa.GetAccessTokenFromAPIToken(apiTokenValue, uaaEndpoint, endpointCACertPath, skipTLSVerify) + loginOptions := []commonauth.LoginOption{ + commonauth.WithSuppressInteractive(true), // fail instead of falling back to interactive login + commonauth.WithRefreshToken(apiTokenValue), + commonauth.WithClientID(uaa.GetAlternateClientID()), + } + + var endpointCACertData string + if endpointCACertPath != "" { + fileBytes, err := os.ReadFile(endpointCACertPath) + if err != nil { + return nil, errors.Wrapf(err, "error reading certificate file %s", endpointCACertPath) + } + endpointCACertData = base64.StdEncoding.EncodeToString(fileBytes) + } + if skipTLSVerify || endpointCACertData != "" { + loginOptions = append(loginOptions, commonauth.WithCertInfo(skipTLSVerify, endpointCACertData)) + } + + // Invoke TanzuLogin to obtain access token via API token + token, err := uaa.TanzuLogin(uaaEndpoint, loginOptions...) if err != nil { return nil, errors.Wrap(err, "failed to get token from UAA") } @@ -748,7 +767,7 @@ func doUAAAPITokenAuthAndUpdateContext(c *configtypes.Context, uaaEndpoint, apiT a.Permissions = claims.Permissions a.AccessToken = token.AccessToken a.IDToken = token.IDToken - a.RefreshToken = apiTokenValue + a.RefreshToken = token.RefreshToken a.Type = commonauth.APITokenType expiresAt := time.Now().Local().Add(time.Second * time.Duration(token.ExpiresIn)) a.Expiration = expiresAt