Skip to content

Commit

Permalink
Multi SSO - allow to pass sso ID for sso start method (#482)
Browse files Browse the repository at this point in the history
  • Loading branch information
dorsha authored Dec 10, 2024
1 parent 1acae10 commit 5e9b320
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 9 deletions.
6 changes: 5 additions & 1 deletion descope/internal/auth/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type ssoStartResponse struct {
URL string `json:"url"`
}

func (auth *sso) Start(ctx context.Context, tenant string, redirectURL string, prompt string, r *http.Request, loginOptions *descope.LoginOptions, w http.ResponseWriter) (url string, err error) {
func (auth *sso) Start(ctx context.Context, tenant string, redirectURL string, prompt string, ssoID string, r *http.Request, loginOptions *descope.LoginOptions, w http.ResponseWriter) (url string, err error) {
if tenant == "" {
return "", utils.NewInvalidArgumentError("tenant")
}
Expand All @@ -31,6 +31,10 @@ func (auth *sso) Start(ctx context.Context, tenant string, redirectURL string, p
if len(prompt) > 0 {
m["prompt"] = prompt
}
if len(ssoID) > 0 {
m["ssoId"] = ssoID
}

var pswd string
if loginOptions.IsJWTRequired() {
pswd, err = getValidRefreshToken(r)
Expand Down
27 changes: 23 additions & 4 deletions descope/internal/auth/sso_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,26 @@ func TestSSOStart(t *testing.T) {
}))
require.NoError(t, err)
w := httptest.NewRecorder()
urlStr, err := a.SSO().Start(context.Background(), tenant, landingURL, prompt, nil, nil, w)
urlStr, err := a.SSO().Start(context.Background(), tenant, landingURL, prompt, "", nil, nil, w)
require.NoError(t, err)
assert.EqualValues(t, uri, urlStr)
assert.EqualValues(t, urlStr, w.Result().Header.Get(descope.RedirectLocationCookieName))
assert.EqualValues(t, http.StatusTemporaryRedirect, w.Result().StatusCode)
}

func TestSSOStartWithSSOID(t *testing.T) {
uri := "http://test.me"
tenant := "tenantID"
prompt := "none"
landingURL := "https://test.com"
ssoID := "lu lu"
a, err := newTestAuth(nil, DoRedirect(uri, func(r *http.Request) {
assert.EqualValues(t, fmt.Sprintf("%s?prompt=%s&redirectURL=%s&ssoId=%s&tenant=%s", composeSSOStartURL(), prompt, url.QueryEscape(landingURL), url.QueryEscape(ssoID), tenant), r.URL.RequestURI())
assert.Nil(t, r.Body)
}))
require.NoError(t, err)
w := httptest.NewRecorder()
urlStr, err := a.SSO().Start(context.Background(), tenant, landingURL, prompt, ssoID, nil, nil, w)
require.NoError(t, err)
assert.EqualValues(t, uri, urlStr)
assert.EqualValues(t, urlStr, w.Result().Header.Get(descope.RedirectLocationCookieName))
Expand All @@ -44,7 +63,7 @@ func TestSSOStartFailureNoTenant(t *testing.T) {
landingURL := "https://test.com"
prompt := "none"
tenant := ""
_, err = a.SSO().Start(context.Background(), tenant, landingURL, prompt, nil, nil, w)
_, err = a.SSO().Start(context.Background(), tenant, landingURL, prompt, "", nil, nil, w)
require.ErrorIs(t, err, utils.NewInvalidArgumentError("tenant"))
}

Expand All @@ -68,7 +87,7 @@ func TestSSOStartStepup(t *testing.T) {
}))
require.NoError(t, err)
w := httptest.NewRecorder()
urlStr, err := a.SSO().Start(context.Background(), tenant, landingURL, prompt, &http.Request{Header: http.Header{"Cookie": []string{"DSR=test"}}}, &descope.LoginOptions{Stepup: true, CustomClaims: map[string]interface{}{"k1": "v1"}}, w)
urlStr, err := a.SSO().Start(context.Background(), tenant, landingURL, prompt, "", &http.Request{Header: http.Header{"Cookie": []string{"DSR=test"}}}, &descope.LoginOptions{Stepup: true, CustomClaims: map[string]interface{}{"k1": "v1"}}, w)
require.NoError(t, err)
assert.EqualValues(t, uri, urlStr)
assert.EqualValues(t, urlStr, w.Result().Header.Get(descope.RedirectLocationCookieName))
Expand All @@ -82,7 +101,7 @@ func TestSSOStartInvalidForwardResponse(t *testing.T) {
_, err = a.SAML().Start(context.Background(), "", "", nil, nil, w)
require.Error(t, err)

_, err = a.SSO().Start(context.Background(), "test", "", "", nil, &descope.LoginOptions{Stepup: true}, w)
_, err = a.SSO().Start(context.Background(), "test", "", "", "", nil, &descope.LoginOptions{Stepup: true}, w)
assert.ErrorIs(t, err, descope.ErrInvalidStepUpJWT)
}

Expand Down
3 changes: 2 additions & 1 deletion descope/sdk/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ type SSOServiceProvider interface {
// return will be the redirect URL that needs to return to client
// and finalize with the ExchangeToken call
// prompt argument relevant only in case tenant configured with AuthType OIDC
Start(ctx context.Context, tenant string, returnURL string, prompt string, r *http.Request, loginOptions *descope.LoginOptions, w http.ResponseWriter) (redirectURL string, err error)
// ssoID can be used for providing the relevant SSO configuration (when having multiple SSO configurations per tenant)
Start(ctx context.Context, tenant string, returnURL string, prompt string, ssoID string, r *http.Request, loginOptions *descope.LoginOptions, w http.ResponseWriter) (redirectURL string, err error)

// ExchangeToken - Finalize tenant login authentication
// code should be extracted from the redirect URL of SAML/OIDC authentication flow
Expand Down
6 changes: 3 additions & 3 deletions descope/tests/mocks/auth/authenticationmock.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ func (m *MockSAML) ExchangeToken(_ context.Context, code string, w http.Response
// Mock SSO

type MockSSO struct {
StartAssert func(tenant string, redirectURL string, prompt string, r *http.Request, loginOptions *descope.LoginOptions, w http.ResponseWriter)
StartAssert func(tenant string, redirectURL string, prompt string, ssoID string, r *http.Request, loginOptions *descope.LoginOptions, w http.ResponseWriter)
StartError error
StartResponse string

Expand All @@ -518,9 +518,9 @@ type MockSSO struct {
ExchangeTokenResponse *descope.AuthenticationInfo
}

func (m *MockSSO) Start(_ context.Context, tenant string, redirectURL string, prompt string, r *http.Request, loginOptions *descope.LoginOptions, w http.ResponseWriter) (string, error) {
func (m *MockSSO) Start(_ context.Context, tenant string, redirectURL string, prompt string, ssoID string, r *http.Request, loginOptions *descope.LoginOptions, w http.ResponseWriter) (string, error) {
if m.StartAssert != nil {
m.StartAssert(tenant, redirectURL, prompt, r, loginOptions, w)
m.StartAssert(tenant, redirectURL, prompt, ssoID, r, loginOptions, w)
}
return m.StartResponse, m.StartError
}
Expand Down

0 comments on commit 5e9b320

Please sign in to comment.