From cede6403e12276eaaf4e6e4ef43cd8f8f68bf31d Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Mon, 25 Sep 2023 15:41:16 -0700 Subject: [PATCH] Same error messages shown in CLI's callback web page and in terminal --- pkg/oidcclient/login.go | 48 +++++++++++++++++++++--------------- pkg/oidcclient/login_test.go | 34 ++++++++++++++++--------- 2 files changed, 51 insertions(+), 31 deletions(-) diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index fac2b99c4..f31592c25 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -499,16 +499,9 @@ func (h *handlerState) cliBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) ( // validations on the returned ID token. tokenCtx, tokenCtxCancelFunc := context.WithTimeout(h.ctx, httpRequestTimeout) defer tokenCtxCancelFunc() - token, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient). - ExchangeAuthcodeAndValidateTokens( - tokenCtx, - authCode, - h.pkce, - h.nonce, - h.oauth2Config.RedirectURL, - ) + token, err := h.redeemAuthCode(tokenCtx, authCode) if err != nil { - return nil, fmt.Errorf("error during authorization code exchange: %w", err) + return nil, fmt.Errorf("could not complete authorization code exchange: %w", err) } return token, nil @@ -642,7 +635,7 @@ func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL strin return } - // When a code is pasted, redeem it for a token and return that result on the callbacks channel. + // When a code is pasted, redeem it for a token and return the results on the callback channel. token, err := h.redeemAuthCode(ctx, code) h.callbacks <- callbackResult{token: token, err: err} }() @@ -849,11 +842,23 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype return upstreamOIDCIdentityProvider.ValidateTokenAndMergeWithUserInfo(ctx, refreshed, "", true, false) } -func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) { - // If we return an error, also report it back over the channel to the main CLI thread. +// handleAuthCodeCallback is used as an http handler, so it does not run in the CLI's main goroutine. +// Upon a callback redirect request from an identity provider, it uses a callback channel to communicate +// its results back to the main thread of the CLI. The result can contain either some tokens from the +// identity provider's token endpoint, or the result can contain an error. When the result is an error, +// the CLI's main goroutine is responsible for printing that error to the terminal. At the same time, +// this function serves a web response, and that web response is rendered in the user's browser. So the +// user has two places to look for error messages: in their browser and in the CLI's terminal. Ideally, +// these messages would be the same. Note that using httperr.Wrap will cause the details of the wrapped +// err to be printed by the CLI, but not printed in the browser due to the way that the httperr package +// works, so avoid using httperr.Wrap in this function. +func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (returnedErr error) { defer func() { - if err != nil { - h.callbacks <- callbackResult{err: err} + // If we returned an error, then also report it back over the channel to the main CLI goroutine. + // Because returnedErr is the named return value, inside this defer returnedErr will hold the value + // returned by any explicit return statement. + if returnedErr != nil { + h.callbacks <- callbackResult{err: returnedErr} } }() @@ -867,9 +872,10 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req } // For POST and OPTIONS requests, calculate the allowed origin for CORS. - issuerURL, parseErr := url.Parse(h.issuer) - if parseErr != nil { - return httperr.Wrap(http.StatusInternalServerError, "invalid issuer url", parseErr) + issuerURL, err := url.Parse(h.issuer) + if err != nil { + // Avoid using httperr.Wrap because that would hide the details of err from the browser output. + return httperr.Newf(http.StatusInternalServerError, "invalid issuer url: %s", err.Error()) } allowOrigin := issuerURL.Scheme + "://" + issuerURL.Host @@ -902,8 +908,9 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req } // Otherwise, this is a POST request... // Parse and pull the response parameters from an application/x-www-form-urlencoded request body. - if err := r.ParseForm(); err != nil { - return httperr.Wrap(http.StatusBadRequest, "invalid form", err) + if err = r.ParseForm(); err != nil { + // Avoid using httperr.Wrap because that would hide the details of err from the browser output. + return httperr.Newf(http.StatusBadRequest, "invalid form: %s", err.Error()) } params = r.Form @@ -943,7 +950,8 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req // validations on the returned ID token. token, err := h.redeemAuthCode(r.Context(), params.Get("code")) if err != nil { - return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err) + // Avoid using httperr.Wrap because that would hide the details of err from the browser output. + return httperr.Newf(http.StatusBadRequest, "could not complete authorization code exchange: %s", err.Error()) } h.callbacks <- callbackResult{token: token} diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index b041dbfcf..a7268d090 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -1174,7 +1174,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo }, issuer: successServer.URL, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, - wantErr: "error during authorization code exchange: some authcode exchange or token validation error", + wantErr: "could not complete authorization code exchange: some authcode exchange or token validation error", }, { name: "successful ldap login with prompts for username and password", @@ -2236,7 +2236,7 @@ func TestHandleAuthCodeCallback(t *testing.T) { { name: "invalid code", query: "state=test-state&code=invalid", - wantErr: "could not complete code exchange: some exchange error", + wantErr: "could not complete authorization code exchange: some exchange error", wantHeaders: map[string][]string{}, wantHTTPStatus: http.StatusBadRequest, opt: func(t *testing.T) Option { @@ -2362,14 +2362,25 @@ func TestHandleAuthCodeCallback(t *testing.T) { err = h.handleAuthCodeCallback(resp, req) if tt.wantErr != "" { require.EqualError(t, err, tt.wantErr) - if tt.wantHTTPStatus != 0 { - rec := httptest.NewRecorder() - err.(httperr.Responder).Respond(rec) - require.Equal(t, tt.wantHTTPStatus, rec.Code) - } + rec := httptest.NewRecorder() + err.(httperr.Responder).Respond(rec) + require.Equal(t, tt.wantHTTPStatus, rec.Code) + // The error message returned (to be shown by the CLI) and the error message shown in the resulting + // web page should always be the same. + require.Equal(t, http.StatusText(tt.wantHTTPStatus)+": "+tt.wantErr+"\n", rec.Body.String()) } else { require.NoError(t, err) require.Equal(t, tt.wantHTTPStatus, resp.Code) + switch { + case tt.wantNoCallbacks: + // When we return an error but keep listening, then we don't need a response body. + require.Empty(t, resp.Body) + case tt.wantHTTPStatus == http.StatusOK: + // When the login succeeds, the response body should show the success message. + require.Equal(t, "you have been logged in and may now close this tab", resp.Body.String()) + default: + t.Fatal("test author made a mistake by expecting a non-200 response code without a wantErr") + } } if tt.wantHeaders != nil { @@ -2385,11 +2396,12 @@ func TestHandleAuthCodeCallback(t *testing.T) { case result := <-h.callbacks: if tt.wantErr != "" { require.EqualError(t, result.err, tt.wantErr) - return + require.Nil(t, result.token) + } else { + require.NoError(t, result.err) + require.NotNil(t, result.token) + require.Equal(t, result.token.IDToken.Token, "test-id-token") } - require.NoError(t, result.err) - require.NotNil(t, result.token) - require.Equal(t, result.token.IDToken.Token, "test-id-token") gotCallback = true } require.Equal(t, tt.wantNoCallbacks, !gotCallback)