Skip to content

Commit

Permalink
feat(ee): saml idp initiated sso
Browse files Browse the repository at this point in the history
  • Loading branch information
lfleischmann authored Feb 13, 2025
1 parent 55d6efb commit 983000d
Show file tree
Hide file tree
Showing 17 changed files with 469 additions and 49 deletions.
235 changes: 198 additions & 37 deletions backend/ee/saml/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ import (
auditlog "github.com/teamhanko/hanko/backend/audit_log"
"github.com/teamhanko/hanko/backend/ee/saml/dto"
"github.com/teamhanko/hanko/backend/ee/saml/provider"
samlUtils "github.com/teamhanko/hanko/backend/ee/saml/utils"
"github.com/teamhanko/hanko/backend/persistence/models"
"github.com/teamhanko/hanko/backend/session"
"github.com/teamhanko/hanko/backend/thirdparty"
"github.com/teamhanko/hanko/backend/utils"
"net/http"
"net/url"
"strings"
"time"
)

type Handler struct {
Expand Down Expand Up @@ -97,48 +99,132 @@ func (handler *Handler) Auth(c echo.Context) error {
return c.Redirect(http.StatusTemporaryRedirect, redirectUrl)
}

func (handler *Handler) CallbackPost(c echo.Context) error {
state, samlError := VerifyState(handler.samlService.Config(), handler.samlService.Persister().GetSamlStatePersister(), c.FormValue("RelayState"))
if samlError != nil {
func (handler *Handler) callbackPostIdPInitiated(c echo.Context, samlResponse string) error {
// ignore URL parse error because config validation already ensures it is a parseable URL
redirectTo, _ := url.Parse(handler.samlService.Config().Saml.DefaultRedirectUrl)

// We need to already parse the response to be able to extract information (a response's ID, Issuer, InResponseTo
// nodes/values) to ensure protection against replaying IDP initiated responses as well as using service provider
// issued responses as IDP initiated responses, even though we later also use the gosaml2 library to parse (and then
// also validate) the response _again_. The reason is that the gosaml2 library does not make this information
// easily/publicly accessible through its API.
parsedSamlResponseDocument, _, err := samlUtils.ParseSamlResponse(samlResponse)
if err != nil {
return handler.redirectError(
c,
thirdparty.ErrorInvalidRequest(samlError.Error()).WithCause(samlError),
handler.samlService.Config().Saml.DefaultRedirectUrl,
thirdparty.ErrorInvalidRequest("could not parse saml response").WithCause(err),
redirectTo.String(),
)
}

if strings.TrimSpace(state.RedirectTo) == "" {
state.RedirectTo = handler.samlService.Config().Saml.DefaultRedirectUrl
responseElement := parsedSamlResponseDocument.FindElement("/Response")
if responseElement == nil {
return handler.redirectError(
c,
thirdparty.ErrorInvalidRequest("invalid saml response: no response node present"),
redirectTo.String(),
)
}

redirectTo, samlError := url.Parse(state.RedirectTo)
if samlError != nil {
issuerElement := parsedSamlResponseDocument.FindElement("/Response/Issuer")
if issuerElement == nil || issuerElement.Text() == "" {
return handler.redirectError(
c,
thirdparty.ErrorInvalidRequest("invalid saml response: no issuer node present"),
redirectTo.String(),
)
}

issuer := issuerElement.Text()

serviceProvider, err := handler.samlService.GetProviderByIssuer(issuer)
if err != nil {
return handler.redirectError(
c,
thirdparty.ErrorServer("unable to parse redirect url").WithCause(samlError),
handler.samlService.Config().Saml.DefaultRedirectUrl,
thirdparty.ErrorInvalidRequest(
fmt.Sprintf("could not get provider for issuer %s", issuer)).
WithCause(err),
redirectTo.String(),
)
}

foundProvider, samlError := handler.samlService.GetProviderByDomain(state.Provider)
if samlError != nil {
// We need to check whether this is an unsolicited request, otherwise SP initiated responses could
// be used as IDP initiated responses.
if responseElement.SelectAttr("InResponseTo") != nil {
return handler.redirectError(
c,
thirdparty.ErrorServer("unable to find provider by domain").WithCause(samlError),
thirdparty.ErrorInvalidRequest("saml request is not unsolicited"),
redirectTo.String(),
)
}

assertionInfo, samlError := handler.parseSamlResponse(foundProvider, c.FormValue("SAMLResponse"))
if samlError != nil {
assertionInfo, err := handler.getAssertionInfo(serviceProvider, samlResponse)
if err != nil {
return handler.redirectError(
c,
thirdparty.ErrorInvalidRequest("could not get assertion info").WithCause(err),
redirectTo.String(),
)
}

samlResponseIDAttr := responseElement.SelectAttr("ID")
if samlResponseIDAttr == nil {
return handler.redirectError(
c,
thirdparty.ErrorInvalidRequest("invalid saml response: no ID for response present"),
redirectTo.String(),
)
}

samlResponseID := samlResponseIDAttr.Value

samlIDPInitiatedRequestPersister := handler.samlService.Persister().GetSamlIDPInitiatedRequestPersister()

// We use the SAML response's ID to prevent replay attacks by persisting every IDP initiated request and
// checking whether an IDP initiated request already exists for this request.
existingSamlIDPInitiatedRequest, err := samlIDPInitiatedRequestPersister.GetByResponseIDAndIssuer(samlResponseID, issuer)
if existingSamlIDPInitiatedRequest != nil {
return handler.redirectError(
c,
thirdparty.ErrorInvalidRequest("attempting to replay unsolicited saml request"),
redirectTo.String(),
)
}

// We assume only one assertion, and we assume it is present because we already validated it using the gosaml2
// library (which also consumes only one/the first assertion). We also assume assertion conditions are present
// because validation assures it is not nil (or else it returns an error).
expiresAtString := assertionInfo.Assertions[0].Conditions.NotOnOrAfter

expiresAt, err := time.Parse(time.RFC3339, expiresAtString)
if err != nil {
return handler.redirectError(
c,
thirdparty.ErrorServer("unable to parse saml response").WithCause(samlError),
thirdparty.ErrorServer("could not parse saml assertion conditions' NotOnOrAfter value").WithCause(err),
redirectTo.String(),
)
}

redirectUrl, samlError := handler.linkAccount(c, redirectTo, state, foundProvider, assertionInfo)
// If no request exists we create a new IDP initiated request model and persist it.
samlIDPInitiatedRequest, err := models.NewSamlIDPInitiatedRequest(samlResponseID, issuer, expiresAt)
if err != nil {
return handler.redirectError(
c,
thirdparty.ErrorServer("could not instantiate saml idp initiated request model").WithCause(err),
redirectTo.String(),
)
}

err = samlIDPInitiatedRequestPersister.Create(*samlIDPInitiatedRequest)
if err != nil {
return handler.redirectError(
c,
thirdparty.ErrorServer("could not persist saml idp initiated request"),
redirectTo.String(),
)
}

redirectUrl, samlError := handler.linkAccount(c, redirectTo, true, serviceProvider, assertionInfo)
if samlError != nil {
return handler.redirectError(
c,
Expand All @@ -147,38 +233,113 @@ func (handler *Handler) CallbackPost(c echo.Context) error {
)
}

// Add hint to the redirect URL that this is an IDP initiated request so that a token exchange can
// eventually be performed through the dedicated flow API handler.
values := redirectUrl.Query()
values.Add("saml_hint", "idp_initiated")
redirectUrl.RawQuery = values.Encode()

return c.Redirect(http.StatusFound, redirectUrl.String())
}

func (handler *Handler) linkAccount(c echo.Context, redirectTo *url.URL, state *State, provider provider.ServiceProvider, assertionInfo *saml2.AssertionInfo) (*url.URL, error) {
func (handler *Handler) CallbackPost(c echo.Context) error {
relayState := c.FormValue("RelayState")
samlResponse := c.FormValue("SAMLResponse")

if handler.isIDPInitiated(relayState) {
return handler.callbackPostIdPInitiated(c, samlResponse)
} else {
state, err := VerifyState(
handler.samlService.Config(),
handler.samlService.Persister().GetSamlStatePersister(),
strings.TrimPrefix(relayState, statePrefixServiceProviderInitiated),
)

if err != nil {
return handler.redirectError(
c,
thirdparty.ErrorInvalidRequest(err.Error()).WithCause(err),
handler.samlService.Config().Saml.DefaultRedirectUrl,
)
}

if strings.TrimSpace(state.RedirectTo) == "" {
state.RedirectTo = handler.samlService.Config().Saml.DefaultRedirectUrl
}

redirectTo, err := url.Parse(state.RedirectTo)
if err != nil {
return handler.redirectError(
c,
thirdparty.ErrorServer("unable to parse redirect url").WithCause(err),
handler.samlService.Config().Saml.DefaultRedirectUrl,
)
}

foundProvider, err := handler.samlService.GetProviderByDomain(state.Provider)
if err != nil {
return handler.redirectError(
c,
thirdparty.ErrorServer("unable to find provider by domain").WithCause(err),
redirectTo.String(),
)
}

assertionInfo, err := handler.getAssertionInfo(foundProvider, samlResponse)
if err != nil {
return handler.redirectError(
c,
thirdparty.ErrorServer("unable to parse saml response").WithCause(err),
redirectTo.String(),
)
}

redirectUrl, err := handler.linkAccount(c, redirectTo, state.IsFlow, foundProvider, assertionInfo)
if err != nil {
return handler.redirectError(
c,
err,
redirectTo.String(),
)
}

return c.Redirect(http.StatusFound, redirectUrl.String())
}
}

func (handler *Handler) isIDPInitiated(relayState string) bool {
return !strings.HasPrefix(relayState, statePrefixServiceProviderInitiated)
}

func (handler *Handler) linkAccount(c echo.Context, redirectTo *url.URL, isFlow bool, provider provider.ServiceProvider, assertionInfo *saml2.AssertionInfo) (*url.URL, error) {
var accountLinkingResult *thirdparty.AccountLinkingResult
var samlError error
samlError = handler.samlService.Persister().Transaction(func(tx *pop.Connection) error {
var err error
err = handler.samlService.Persister().Transaction(func(tx *pop.Connection) error {
userdata := provider.GetUserData(assertionInfo)
identityProviderIssuer := assertionInfo.Assertions[0].Issuer
samlDomain := provider.GetDomain()
linkResult, samlErrorTx := thirdparty.LinkAccount(tx, handler.samlService.Config(), handler.samlService.Persister(), userdata, identityProviderIssuer.Value, true, &samlDomain, state.IsFlow)
if samlErrorTx != nil {
return samlErrorTx
linkResult, errTx := thirdparty.LinkAccount(tx, handler.samlService.Config(), handler.samlService.Persister(), userdata, identityProviderIssuer.Value, true, &samlDomain, isFlow)
if errTx != nil {
return errTx
}

accountLinkingResult = linkResult

emailModel := linkResult.User.Emails.GetEmailByAddress(userdata.Metadata.Email)
identityModel := emailModel.Identities.GetIdentity(identityProviderIssuer.Value, userdata.Metadata.Subject)

token, tokenError := models.NewToken(
token, errTx := models.NewToken(
linkResult.User.ID,
models.TokenWithIdentityID(identityModel.ID),
models.TokenForFlowAPI(state.IsFlow),
models.TokenForFlowAPI(isFlow),
models.TokenUserCreated(linkResult.UserCreated))
if tokenError != nil {
return thirdparty.ErrorServer("could not create token").WithCause(tokenError)
if errTx != nil {
return thirdparty.ErrorServer("could not create token").WithCause(errTx)
}

tokenError = handler.samlService.Persister().GetTokenPersisterWithConnection(tx).Create(*token)
if tokenError != nil {
return thirdparty.ErrorServer("could not save token to db").WithCause(tokenError)
errTx = handler.samlService.Persister().GetTokenPersisterWithConnection(tx).Create(*token)
if errTx != nil {
return thirdparty.ErrorServer("could not save token to db").WithCause(errTx)
}

query := redirectTo.Query()
Expand All @@ -188,20 +349,20 @@ func (handler *Handler) linkAccount(c echo.Context, redirectTo *url.URL, state *
return nil
})

if samlError != nil {
return nil, samlError
if err != nil {
return nil, err
}

samlError = handler.auditLogger.Create(c, accountLinkingResult.Type, accountLinkingResult.User, nil)
err = handler.auditLogger.Create(c, accountLinkingResult.Type, accountLinkingResult.User, nil)

if samlError != nil {
return nil, samlError
if err != nil {
return nil, err
}

return redirectTo, nil
}

func (handler *Handler) parseSamlResponse(provider provider.ServiceProvider, samlResponse string) (*saml2.AssertionInfo, error) {
func (handler *Handler) getAssertionInfo(provider provider.ServiceProvider, samlResponse string) (*saml2.AssertionInfo, error) {
assertionInfo, err := provider.GetService().RetrieveAssertionInfo(samlResponse)
if err != nil {
return nil, thirdparty.ErrorServer("unable to parse SAML response").WithCause(err)
Expand Down
11 changes: 11 additions & 0 deletions backend/ee/saml/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type Service interface {
Persister() persistence.Persister
Providers() []provider.ServiceProvider
GetProviderByDomain(domain string) (provider.ServiceProvider, error)
GetProviderByIssuer(issuer string) (provider.ServiceProvider, error)
GetAuthUrl(provider provider.ServiceProvider, redirectTo string, isFlow bool) (string, error)
}

Expand Down Expand Up @@ -83,6 +84,16 @@ func (s *defaultService) GetProviderByDomain(domain string) (provider.ServicePro
return nil, fmt.Errorf("unknown provider for domain %s", domain)
}

func (s *defaultService) GetProviderByIssuer(issuer string) (provider.ServiceProvider, error) {
for _, availableProvider := range s.providers {
if availableProvider.GetService().IdentityProviderIssuer == issuer {
return availableProvider, nil
}
}

return nil, fmt.Errorf("unknown provider for issuer %s", issuer)
}

func (s *defaultService) GetAuthUrl(provider provider.ServiceProvider, redirectTo string, isFlow bool) (string, error) {
if ok := samlUtils.IsAllowedRedirect(s.config.Saml, redirectTo); !ok {
return "", thirdparty.ErrorInvalidRequest(fmt.Sprintf("redirect to '%s' not allowed", redirectTo))
Expand Down
6 changes: 5 additions & 1 deletion backend/ee/saml/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ type State struct {
IsFlow bool `json:"is_flow"`
}

const statePrefixServiceProviderInitiated = "hanko_spi_"

func GenerateStateForFlowAPI(isFlow bool) func(*State) {
return func(state *State) {
state.IsFlow = isFlow
Expand Down Expand Up @@ -77,7 +79,9 @@ func GenerateState(config *config.Config, persister persistence.SamlStatePersist
return nil, fmt.Errorf("could not save state to db: %w", err)
}

return []byte(encryptedState), nil
// Add prefix to distinguish between SP initiated and IDP initiated requests in callback handler.
result := fmt.Sprintf("%s%s", statePrefixServiceProviderInitiated, encryptedState)
return []byte(result), nil
}

func VerifyState(config *config.Config, persister persistence.SamlStatePersister, state string) (*State, error) {
Expand Down
Loading

0 comments on commit 983000d

Please sign in to comment.