Skip to content

Commit

Permalink
Display IDP name when prompting for username and password
Browse files Browse the repository at this point in the history
[#181927293]
  • Loading branch information
joshuatcasey committed Sep 29, 2023
1 parent 78cb862 commit 936645a
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 52 deletions.
1 change: 1 addition & 0 deletions cmd/pinniped/cmd/login_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ func runOIDCLogin(cmd *cobra.Command, deps oidcLoginCommandDeps, flags oidcLogin
oidcclient.WithLogger(plog.Logr()), //nolint:staticcheck // old code with lots of log statements
oidcclient.WithScopes(flags.scopes),
oidcclient.WithSessionCache(sessionCache),
oidcclient.WithOutWriter(os.Stderr),
}

if flags.listenPort != 0 {
Expand Down
47 changes: 30 additions & 17 deletions pkg/oidcclient/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ const (
// we set this to be relatively long.
overallTimeout = 90 * time.Minute

defaultLDAPUsernamePrompt = "Username: "
defaultLDAPPasswordPrompt = "Password: "
usernamePrompt = "Username: "
passwordPrompt = "Password: "

// For CLI-based auth, such as with LDAP upstream identity providers, the user may use these environment variables
// to avoid getting interactively prompted for username and password.
Expand All @@ -78,6 +78,7 @@ type handlerState struct {
clientID string
scopes []string
cache SessionCache
out io.Writer

upstreamIdentityProviderName string
upstreamIdentityProviderType string
Expand Down Expand Up @@ -109,8 +110,8 @@ type handlerState struct {
isTTY func(int) bool
getProvider func(*oauth2.Config, *coreosoidc.Provider, *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI
validateIDToken func(ctx context.Context, provider *coreosoidc.Provider, audience string, token string) (*coreosoidc.IDToken, error)
promptForValue func(ctx context.Context, promptLabel string) (string, error)
promptForSecret func(promptLabel string) (string, error)
promptForValue func(ctx context.Context, promptLabel string, out io.Writer) (string, error)
promptForSecret func(promptLabel string, out io.Writer) (string, error)

callbacks chan callbackResult
}
Expand Down Expand Up @@ -216,6 +217,14 @@ func WithSessionCache(cache SessionCache) Option {
}
}

// WithOutWriter sets stderr io.Writer.
func WithOutWriter(out io.Writer) Option {
return func(h *handlerState) error {
h.out = out
return nil
}
}

// WithClient sets the HTTP client used to make CLI-to-provider requests.
func WithClient(httpClient *http.Client) Option {
return func(h *handlerState) error {
Expand Down Expand Up @@ -420,7 +429,7 @@ func (h *handlerState) baseLogin() (*oidctypes.Token, error) {
// and parse the authcode from the response. Exchange the authcode for tokens. Return the tokens or an error.
func (h *handlerState) cliBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (*oidctypes.Token, error) {
// Ask the user for their username and password, or get them from env vars.
username, password, err := h.getUsernameAndPassword()
username, password, err := h.getUsernameAndPassword(h.out)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -508,12 +517,16 @@ func (h *handlerState) cliBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (
}

// Prompt for the user's username and password, or read them from env vars if they are available.
func (h *handlerState) getUsernameAndPassword() (string, string, error) {
func (h *handlerState) getUsernameAndPassword(out io.Writer) (string, string, error) {
var err error

if h.upstreamIdentityProviderName != "" {
_, _ = fmt.Fprintf(out, "\nLog in to %s\n\n", h.upstreamIdentityProviderName)
}

username := h.getEnv(defaultUsernameEnvVarName)
if username == "" {
username, err = h.promptForValue(h.ctx, defaultLDAPUsernamePrompt)
username, err = h.promptForValue(h.ctx, usernamePrompt, h.out)
if err != nil {
return "", "", fmt.Errorf("error prompting for username: %w", err)
}
Expand All @@ -523,7 +536,7 @@ func (h *handlerState) getUsernameAndPassword() (string, string, error) {

password := h.getEnv(defaultPasswordEnvVarName)
if password == "" {
password, err = h.promptForSecret(defaultLDAPPasswordPrompt)
password, err = h.promptForSecret(passwordPrompt, h.out)
if err != nil {
return "", "", fmt.Errorf("error prompting for password: %w", err)
}
Expand Down Expand Up @@ -581,7 +594,7 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp

// Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible).
ctx, cancel := context.WithCancel(h.ctx)
cleanupPrompt := h.promptForWebLogin(ctx, authorizeURL, os.Stderr)
cleanupPrompt := h.promptForWebLogin(ctx, authorizeURL, h.out)
defer func() {
cancel()
cleanupPrompt()
Expand Down Expand Up @@ -621,15 +634,15 @@ func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL strin
go func() {
defer func() {
// Always emit a newline so the kubectl output is visually separated from the login prompts.
_, _ = fmt.Fprintln(os.Stderr)
_, _ = fmt.Fprintln(h.out)

wg.Done()
}()
code, err := h.promptForValue(ctx, " Optionally, paste your authorization code: ")
code, err := h.promptForValue(ctx, " Optionally, paste your authorization code: ", h.out)
if err != nil {
// Print a visual marker to show the the prompt is no longer waiting for user input, plus a trailing
// newline that simulates the user having pressed "enter".
_, _ = fmt.Fprint(os.Stderr, "[...]\n")
_, _ = fmt.Fprint(h.out, "[...]\n")

h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)}
return
Expand All @@ -642,11 +655,11 @@ func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL strin
return wg.Wait
}

func promptForValue(ctx context.Context, promptLabel string) (string, error) {
func promptForValue(ctx context.Context, promptLabel string, out io.Writer) (string, error) {
if !term.IsTerminal(stdin()) {
return "", errors.New("stdin is not connected to a terminal")
}
_, err := fmt.Fprint(os.Stderr, promptLabel)
_, err := fmt.Fprint(out, promptLabel)
if err != nil {
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
}
Expand Down Expand Up @@ -674,11 +687,11 @@ func promptForValue(ctx context.Context, promptLabel string) (string, error) {
}
}

func promptForSecret(promptLabel string) (string, error) {
func promptForSecret(promptLabel string, out io.Writer) (string, error) {
if !term.IsTerminal(stdin()) {
return "", errors.New("stdin is not connected to a terminal")
}
_, err := fmt.Fprint(os.Stderr, promptLabel)
_, err := fmt.Fprint(out, promptLabel)
if err != nil {
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
}
Expand All @@ -689,7 +702,7 @@ func promptForSecret(promptLabel string) (string, error) {
// term.ReadPassword swallows the newline that was typed by the user, so to
// avoid the next line of output from happening on same line as the password
// prompt, we need to print a newline.
_, err = fmt.Fprint(os.Stderr, "\n")
_, err = fmt.Fprint(out, "\n")
if err != nil {
return "", fmt.Errorf("could not print newline to stderr: %w", err)
}
Expand Down
77 changes: 42 additions & 35 deletions pkg/oidcclient/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"syscall"
"testing"
Expand Down Expand Up @@ -316,8 +317,10 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil }
h.promptForSecret = func(_ string) (string, error) { return "some-upstream-password", nil }
h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
return "some-upstream-username", nil
}
h.promptForSecret = func(_ string, _ io.Writer) (string, error) { return "some-upstream-password", nil }

cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{
Expand Down Expand Up @@ -352,13 +355,14 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}

tests := []struct {
name string
opt func(t *testing.T) Option
issuer string
clientID string
wantErr string
wantToken *oidctypes.Token
wantLogs []string
name string
opt func(t *testing.T) Option
issuer string
clientID string
wantErr string
wantToken *oidctypes.Token
wantLogs []string
wantStdErr string
}{
{
name: "option error",
Expand Down Expand Up @@ -709,7 +713,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
return fmt.Errorf("some browser open error")
}
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
return "", fmt.Errorf("some prompt error")
}
return nil
Expand All @@ -720,7 +724,8 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
`"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`,
`"msg"="could not open browser" "error"="some browser open error"`,
},
wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error",
wantStdErr: "Log in by visiting this link.*",
wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error",
},
{
name: "listen success and manual prompt succeeds",
Expand All @@ -736,7 +741,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
return nil
}
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
return "", fmt.Errorf("some prompt error")
}
return nil
Expand Down Expand Up @@ -990,7 +995,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
_ = defaultLDAPTestOpts(t, h, nil, nil)
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
require.Equal(t, "Username: ", promptLabel)
return "", errors.New("some prompt error")
}
Expand All @@ -1007,7 +1012,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
_ = defaultLDAPTestOpts(t, h, nil, nil)
h.promptForSecret = func(_ string) (string, error) { return "", errors.New("some prompt error") }
h.promptForSecret = func(_ string, _ io.Writer) (string, error) { return "", errors.New("some prompt error") }
return nil
}
},
Expand Down Expand Up @@ -1198,11 +1203,11 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
h.getEnv = func(_ string) string {
return "" // asking for any env var returns empty as if it were unset
}
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
require.Equal(t, "Username: ", promptLabel)
return "some-upstream-username", nil
}
h.promptForSecret = func(promptLabel string) (string, error) {
h.promptForSecret = func(promptLabel string, _ io.Writer) (string, error) {
require.Equal(t, "Password: ", promptLabel)
return "some-upstream-password", nil
}
Expand Down Expand Up @@ -1273,9 +1278,10 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
return nil
}
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantToken: &testToken,
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantStdErr: "\nLog in to some-upstream-name\n\n",
wantToken: &testToken,
},
{
name: "successful ldap login with env vars for username and password",
Expand Down Expand Up @@ -1306,22 +1312,21 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
return "" // all other env vars are treated as if they are unset
}
}
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil
}
h.promptForSecret = func(promptLabel string) (string, error) {
h.promptForSecret = func(promptLabel string, _ io.Writer) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil
}

cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
UpstreamProviderName: "some-upstream-name",
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
Expand All @@ -1330,7 +1335,6 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithCLISendingCredentials()(h))
require.NoError(t, WithUpstreamIdentityProvider("some-upstream-name", "ldap")(h))

discoveryRequestWasMade := false
authorizeRequestWasMade := false
Expand Down Expand Up @@ -1362,8 +1366,6 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
"access_type": []string{"offline"},
"client_id": []string{"test-client-id"},
"redirect_uri": []string{"http://127.0.0.1:0/callback"},
"pinniped_idp_name": []string{"some-upstream-name"},
"pinniped_idp_type": []string{"ldap"},
}, req.URL.Query())
return &http.Response{
StatusCode: http.StatusFound,
Expand All @@ -1387,7 +1389,8 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
"\"level\"=4 \"msg\"=\"Pinniped: Read username from environment variable\" \"name\"=\"PINNIPED_USERNAME\"",
"\"level\"=4 \"msg\"=\"Pinniped: Read password from environment variable\" \"name\"=\"PINNIPED_PASSWORD\"",
},
wantToken: &testToken,
wantStdErr: `$^`, // should only match an empty string
wantToken: &testToken,
},
{
name: "successful ldap login with env vars for username and password, http.StatusSeeOther redirect",
Expand Down Expand Up @@ -1418,11 +1421,11 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
return "" // all other env vars are treated as if they are unset
}
}
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil
}
h.promptForSecret = func(promptLabel string) (string, error) {
h.promptForSecret = func(promptLabel string, _ io.Writer) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil
}
Expand Down Expand Up @@ -1898,15 +1901,18 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
testLogger := testlogger.NewLegacy(t) //nolint:staticcheck // old test with lots of log statements
klog.SetLogger(testLogger.Logger)

buffer := bytes.Buffer{}
tok, err := Login(tt.issuer, tt.clientID,
WithContext(context.Background()),
WithListenPort(0),
WithScopes([]string{"test-scope"}),
WithSkipBrowserOpen(),
tt.opt(t),
WithLogger(testLogger.Logger),
WithOutWriter(&buffer),
)
testLogger.Expect(tt.wantLogs)
require.Regexp(t, tt.wantStdErr, buffer.String())
if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr)
require.Nil(t, tok)
Expand Down Expand Up @@ -1977,7 +1983,7 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
assert.Equal(t, " Optionally, paste your authorization code: ", promptLabel)
return "", fmt.Errorf("some prompt error")
}
Expand All @@ -1994,7 +2000,7 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
return "invalid", nil
}
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
Expand All @@ -2018,7 +2024,7 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
return "valid", nil
}
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
Expand Down Expand Up @@ -2047,6 +2053,7 @@ func TestHandlePasteCallback(t *testing.T) {
state: state.State("test-state"),
pkce: pkce.Code("test-pkce"),
nonce: nonce.Nonce("test-nonce"),
out: os.Stderr,
}
if tt.opt != nil {
require.NoError(t, tt.opt(t)(h))
Expand Down

0 comments on commit 936645a

Please sign in to comment.