From 9cd4e87073724f76777d20d9498022dd99d3b865 Mon Sep 17 00:00:00 2001 From: dorsha Date: Sun, 8 Dec 2024 16:24:21 +0200 Subject: [PATCH] Multi SSO - allow to pass sso ID for sso start method --- descope/internal/auth/sso.go | 6 ++++- descope/internal/auth/sso_test.go | 27 ++++++++++++++++--- descope/sdk/auth.go | 3 ++- .../tests/mocks/auth/authenticationmock.go | 6 ++--- 4 files changed, 33 insertions(+), 9 deletions(-) diff --git a/descope/internal/auth/sso.go b/descope/internal/auth/sso.go index ed4a65f3..c58fa3d7 100644 --- a/descope/internal/auth/sso.go +++ b/descope/internal/auth/sso.go @@ -18,7 +18,7 @@ type ssoStartResponse struct { URL string `json:"url"` } -func (auth *sso) Start(ctx context.Context, tenant string, redirectURL string, prompt string, r *http.Request, loginOptions *descope.LoginOptions, w http.ResponseWriter) (url string, err error) { +func (auth *sso) Start(ctx context.Context, tenant string, redirectURL string, prompt string, ssoID string, r *http.Request, loginOptions *descope.LoginOptions, w http.ResponseWriter) (url string, err error) { if tenant == "" { return "", utils.NewInvalidArgumentError("tenant") } @@ -31,6 +31,10 @@ func (auth *sso) Start(ctx context.Context, tenant string, redirectURL string, p if len(prompt) > 0 { m["prompt"] = prompt } + if len(ssoID) > 0 { + m["ssoId"] = ssoID + } + var pswd string if loginOptions.IsJWTRequired() { pswd, err = getValidRefreshToken(r) diff --git a/descope/internal/auth/sso_test.go b/descope/internal/auth/sso_test.go index c39e18bd..c32b27ec 100644 --- a/descope/internal/auth/sso_test.go +++ b/descope/internal/auth/sso_test.go @@ -29,7 +29,26 @@ func TestSSOStart(t *testing.T) { })) require.NoError(t, err) w := httptest.NewRecorder() - urlStr, err := a.SSO().Start(context.Background(), tenant, landingURL, prompt, nil, nil, w) + urlStr, err := a.SSO().Start(context.Background(), tenant, landingURL, prompt, "", nil, nil, w) + require.NoError(t, err) + assert.EqualValues(t, uri, urlStr) + assert.EqualValues(t, urlStr, w.Result().Header.Get(descope.RedirectLocationCookieName)) + assert.EqualValues(t, http.StatusTemporaryRedirect, w.Result().StatusCode) +} + +func TestSSOStartWithSSOID(t *testing.T) { + uri := "http://test.me" + tenant := "tenantID" + prompt := "none" + landingURL := "https://test.com" + ssoID := "lu lu" + a, err := newTestAuth(nil, DoRedirect(uri, func(r *http.Request) { + assert.EqualValues(t, fmt.Sprintf("%s?prompt=%s&redirectURL=%s&ssoId=%s&tenant=%s", composeSSOStartURL(), prompt, url.QueryEscape(landingURL), url.QueryEscape(ssoID), tenant), r.URL.RequestURI()) + assert.Nil(t, r.Body) + })) + require.NoError(t, err) + w := httptest.NewRecorder() + urlStr, err := a.SSO().Start(context.Background(), tenant, landingURL, prompt, ssoID, nil, nil, w) require.NoError(t, err) assert.EqualValues(t, uri, urlStr) assert.EqualValues(t, urlStr, w.Result().Header.Get(descope.RedirectLocationCookieName)) @@ -44,7 +63,7 @@ func TestSSOStartFailureNoTenant(t *testing.T) { landingURL := "https://test.com" prompt := "none" tenant := "" - _, err = a.SSO().Start(context.Background(), tenant, landingURL, prompt, nil, nil, w) + _, err = a.SSO().Start(context.Background(), tenant, landingURL, prompt, "", nil, nil, w) require.ErrorIs(t, err, utils.NewInvalidArgumentError("tenant")) } @@ -68,7 +87,7 @@ func TestSSOStartStepup(t *testing.T) { })) require.NoError(t, err) w := httptest.NewRecorder() - urlStr, err := a.SSO().Start(context.Background(), tenant, landingURL, prompt, &http.Request{Header: http.Header{"Cookie": []string{"DSR=test"}}}, &descope.LoginOptions{Stepup: true, CustomClaims: map[string]interface{}{"k1": "v1"}}, w) + urlStr, err := a.SSO().Start(context.Background(), tenant, landingURL, prompt, "", &http.Request{Header: http.Header{"Cookie": []string{"DSR=test"}}}, &descope.LoginOptions{Stepup: true, CustomClaims: map[string]interface{}{"k1": "v1"}}, w) require.NoError(t, err) assert.EqualValues(t, uri, urlStr) assert.EqualValues(t, urlStr, w.Result().Header.Get(descope.RedirectLocationCookieName)) @@ -82,7 +101,7 @@ func TestSSOStartInvalidForwardResponse(t *testing.T) { _, err = a.SAML().Start(context.Background(), "", "", nil, nil, w) require.Error(t, err) - _, err = a.SSO().Start(context.Background(), "test", "", "", nil, &descope.LoginOptions{Stepup: true}, w) + _, err = a.SSO().Start(context.Background(), "test", "", "", "", nil, &descope.LoginOptions{Stepup: true}, w) assert.ErrorIs(t, err, descope.ErrInvalidStepUpJWT) } diff --git a/descope/sdk/auth.go b/descope/sdk/auth.go index 356a0e88..c5c2a44b 100644 --- a/descope/sdk/auth.go +++ b/descope/sdk/auth.go @@ -261,7 +261,8 @@ type SSOServiceProvider interface { // return will be the redirect URL that needs to return to client // and finalize with the ExchangeToken call // prompt argument relevant only in case tenant configured with AuthType OIDC - Start(ctx context.Context, tenant string, returnURL string, prompt string, r *http.Request, loginOptions *descope.LoginOptions, w http.ResponseWriter) (redirectURL string, err error) + // ssoID can be used for providing the relevant SSO configuration (when having multiple SSO configurations per tenant) + Start(ctx context.Context, tenant string, returnURL string, prompt string, ssoID string, r *http.Request, loginOptions *descope.LoginOptions, w http.ResponseWriter) (redirectURL string, err error) // ExchangeToken - Finalize tenant login authentication // code should be extracted from the redirect URL of SAML/OIDC authentication flow diff --git a/descope/tests/mocks/auth/authenticationmock.go b/descope/tests/mocks/auth/authenticationmock.go index 281ec258..5feb660d 100644 --- a/descope/tests/mocks/auth/authenticationmock.go +++ b/descope/tests/mocks/auth/authenticationmock.go @@ -509,7 +509,7 @@ func (m *MockSAML) ExchangeToken(_ context.Context, code string, w http.Response // Mock SSO type MockSSO struct { - StartAssert func(tenant string, redirectURL string, prompt string, r *http.Request, loginOptions *descope.LoginOptions, w http.ResponseWriter) + StartAssert func(tenant string, redirectURL string, prompt string, ssoID string, r *http.Request, loginOptions *descope.LoginOptions, w http.ResponseWriter) StartError error StartResponse string @@ -518,9 +518,9 @@ type MockSSO struct { ExchangeTokenResponse *descope.AuthenticationInfo } -func (m *MockSSO) Start(_ context.Context, tenant string, redirectURL string, prompt string, r *http.Request, loginOptions *descope.LoginOptions, w http.ResponseWriter) (string, error) { +func (m *MockSSO) Start(_ context.Context, tenant string, redirectURL string, prompt string, ssoID string, r *http.Request, loginOptions *descope.LoginOptions, w http.ResponseWriter) (string, error) { if m.StartAssert != nil { - m.StartAssert(tenant, redirectURL, prompt, r, loginOptions, w) + m.StartAssert(tenant, redirectURL, prompt, ssoID, r, loginOptions, w) } return m.StartResponse, m.StartError }