Skip to content

Commit

Permalink
Simplify auth with api token against uaa
Browse files Browse the repository at this point in the history
Refactored login handler so that we can leverage it to perform the token
refresh using provided API token.

Signed-off-by: Vui Lam <[email protected]>
  • Loading branch information
vuil committed Oct 8, 2024
1 parent 31d6637 commit 63170d0
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 163 deletions.
20 changes: 17 additions & 3 deletions pkg/auth/common/login_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ type TanzuLoginHandler struct {
callbackHandlerMutex sync.Mutex
tlsSkipVerify bool
caCertData string
suppressInteractive bool
}

// LoginOption is an optional configuration for Login().
Expand Down Expand Up @@ -133,6 +134,15 @@ func WithClientID(clientID string) LoginOption {
}
}

// WithSuppressInteractive specifies whether to fall back to interactive login if
// an access token cannot be obtained.
func WithSuppressInteractive(suppress bool) LoginOption {
return func(h *TanzuLoginHandler) error {
h.suppressInteractive = suppress
return nil
}
}

// WithListenerPort specifies a TCP listener port on localhost, which will be used for the redirect_uri and to handle the
// authorization code callback. By default, a random high port will be chosen which requires the authorization server
// to support wildcard port numbers as described by https://tools.ietf.org/html/rfc8252#section-7.3:
Expand Down Expand Up @@ -166,13 +176,17 @@ func WithListenerPortFromEnv(envVarName string) LoginOption {
}

func (h *TanzuLoginHandler) DoLogin() (*Token, error) {
var err error
var token *Token

if h.refreshToken != "" {
ctx := contextWithCustomTLSConfig(context.TODO(), h.getTLSConfig())
token, err := h.getTokenWithRefreshToken(ctx)
if err == nil {
return token, nil
token, err = h.getTokenWithRefreshToken(ctx)
if err == nil || h.suppressInteractive {
return token, err
}
}

// If refresh token fails, proceed with login flow through the browser
return h.browserLogin()
}
Expand Down
69 changes: 58 additions & 11 deletions pkg/auth/common/login_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ const (
)

func TestHandleTokenRefresh(t *testing.T) {
assert := assert.New(t)

// Mock HTTP server for token refresh
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
Expand All @@ -45,19 +47,64 @@ func TestHandleTokenRefresh(t *testing.T) {
}

token, err := lh.getTokenWithRefreshToken(context.TODO())
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if token == nil {
t.Error("Expected a non-nil token, got nil")
assert.Nil(err)
assert.NotNil(token)
assert.Equal(token.AccessToken, "fake-access-token")
assert.Equal(token.RefreshToken, "fake-refresh-token")
assert.Equal(token.TokenType, "id-token")
assert.Equal(token.IDToken, "fake-id-token")
assert.Equal(token.ExpiresIn, int64(3599))
}

// test that login with refresh token completes without triggering browser
// login regardless of whether refresh succeeded or not
func TestLoginWithAPIToken(t *testing.T) {
assert := assert.New(t)

// Mock HTTP server for token refresh
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
body, _ := io.ReadAll(r.Body)
if strings.Contains(string(body), "refresh_token=valid-api-token") {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token": "fake-access-token", "refresh_token": "fake-refresh-token", "expires_in": 3600, "id_token": "fake-id-token"}`))
return
}
http.Error(w, "refresh_error", http.StatusBadRequest)
}))
defer server.Close()

lh := &TanzuLoginHandler{
oauthConfig: &oauth2.Config{
Endpoint: oauth2.Endpoint{
TokenURL: server.URL,
},
},
refreshToken: "valid-api-token",
suppressInteractive: true,
}
if token != nil {
assert.Equal(t, token.AccessToken, "fake-access-token")
assert.Equal(t, token.RefreshToken, "fake-refresh-token")
assert.Equal(t, token.TokenType, "id-token")
assert.Equal(t, token.IDToken, "fake-id-token")
assert.Equal(t, token.ExpiresIn, int64(3599))
token, err := lh.DoLogin()

assert.Nil(err)
assert.NotNil(token)
assert.Equal(token.AccessToken, "fake-access-token")
assert.Equal(token.RefreshToken, "fake-refresh-token")
assert.Equal(token.TokenType, "id-token")
assert.Equal(token.IDToken, "fake-id-token")
assert.Equal(token.ExpiresIn, int64(3599))

lh = &TanzuLoginHandler{
oauthConfig: &oauth2.Config{
Endpoint: oauth2.Endpoint{
TokenURL: server.URL,
},
},
refreshToken: "bad-refresh-token",
suppressInteractive: true,
}
token, err = lh.DoLogin()
assert.NotNil(err)
assert.Nil(token)
}

func TestGetAuthCodeURL_validResponse(t *testing.T) {
Expand Down
56 changes: 0 additions & 56 deletions pkg/auth/uaa/uaa.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,66 +5,10 @@
package uaa

import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/url"

"github.com/pkg/errors"

"github.com/vmware-tanzu/tanzu-cli/pkg/auth/common"
"github.com/vmware-tanzu/tanzu-cli/pkg/constants"
"github.com/vmware-tanzu/tanzu-cli/pkg/interfaces"
)

var (
httpRestClient interfaces.HTTPClient
)

// GetAccessTokenFromAPIToken fetches access token using the API-token.
func GetAccessTokenFromAPIToken(apiToken, uaaEndpoint, endpointCACertPath string, skipTLSVerify bool) (*common.Token, error) {
tokenURL := getIssuerEndpoints(uaaEndpoint).TokenURL
data := url.Values{}
data.Set("refresh_token", apiToken)
data.Set("client_id", GetAlternateClientID())
data.Set("grant_type", "refresh_token")

req, _ := http.NewRequestWithContext(context.Background(), "POST", tokenURL, bytes.NewBufferString(data.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

if httpRestClient == nil {
tlsConfig := common.GetTLSConfig(uaaEndpoint, endpointCACertPath, skipTLSVerify)
if tlsConfig == nil {
return nil, errors.New("unable to set up tls config")
}

tr := http.DefaultTransport.(*http.Transport).Clone()
tr.TLSClientConfig = tlsConfig
httpRestClient = &http.Client{Transport: tr}
}

resp, err := httpRestClient.Do(req)
if err != nil {
return nil, errors.WithMessage(err, "Failed to obtain access token. Please provide valid API token")
}
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, errors.Errorf("Failed to obtain access token. Please provide valid API token -- %s", string(body))
}

defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
token := common.Token{}

if err = json.Unmarshal(body, &token); err != nil {
return nil, errors.Wrap(err, "could not unmarshal auth token")
}

return &token, nil
}

// GetTokens fetches the UAA access token
func GetTokens(refreshOrAPIToken, _, issuer, tokenType string) (*common.Token, error) {
clientID := tanzuCLIClientID
Expand Down
92 changes: 0 additions & 92 deletions pkg/auth/uaa/uaa_test.go

This file was deleted.

21 changes: 20 additions & 1 deletion pkg/command/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,26 @@ func getSelfManagedOrg(c *configtypes.Context) (string, string) {
}

func doUAAAPITokenAuthAndUpdateContext(c *configtypes.Context, uaaEndpoint, apiTokenValue string) (claims *commonauth.Claims, err error) {
token, err := uaa.GetAccessTokenFromAPIToken(apiTokenValue, uaaEndpoint, endpointCACertPath, skipTLSVerify)
loginOptions := []commonauth.LoginOption{
commonauth.WithSuppressInteractive(true), // fail instead of falling back to interactive login
commonauth.WithRefreshToken(apiTokenValue),
commonauth.WithClientID(uaa.GetAlternateClientID()),
}

var endpointCACertData string
if endpointCACertPath != "" {
fileBytes, err := os.ReadFile(endpointCACertPath)
if err != nil {
return nil, errors.Wrapf(err, "error reading certificate file %s", endpointCACertPath)
}
endpointCACertData = base64.StdEncoding.EncodeToString(fileBytes)
}
if skipTLSVerify || endpointCACertData != "" {
loginOptions = append(loginOptions, commonauth.WithCertInfo(skipTLSVerify, endpointCACertData))
}

// Invoke TanzuLogin to obtain access token via API token
token, err := uaa.TanzuLogin(uaaEndpoint, loginOptions...)
if err != nil {
return nil, errors.Wrap(err, "failed to get token from UAA")
}
Expand Down

0 comments on commit 63170d0

Please sign in to comment.