diff --git a/object/token.go b/object/token.go index 986b40e75743..6fbe840e4320 100644 --- a/object/token.go +++ b/object/token.go @@ -144,6 +144,48 @@ func getTokenByCode(code string) (*Token, error) { return nil, nil } +func GetTokenByAccessToken(accessToken string) (*Token, error) { + token := Token{AccessTokenHash: getTokenHash(accessToken)} + existed, err := ormer.Engine.Get(&token) + if err != nil { + return nil, err + } + + if !existed { + token = Token{AccessToken: accessToken} + existed, err = ormer.Engine.Get(&token) + if err != nil { + return nil, err + } + } + + if !existed { + return nil, nil + } + return &token, nil +} + +func GetTokenByRefreshToken(refreshToken string) (*Token, error) { + token := Token{RefreshTokenHash: getTokenHash(refreshToken)} + existed, err := ormer.Engine.Get(&token) + if err != nil { + return nil, err + } + + if !existed { + token = Token{RefreshToken: refreshToken} + existed, err = ormer.Engine.Get(&token) + if err != nil { + return nil, err + } + } + + if !existed { + return nil, nil + } + return &token, nil +} + func updateUsedByCode(token *Token) bool { affected, err := ormer.Engine.Where("code=?", token.Code).Cols("code_is_used").Update(token) if err != nil { @@ -219,18 +261,16 @@ func DeleteToken(token *Token) (bool, error) { } func ExpireTokenByAccessToken(accessToken string) (bool, *Application, *Token, error) { - token := Token{AccessTokenHash: getTokenHash(accessToken)} - existed, err := ormer.Engine.Get(&token) + token, err := GetTokenByAccessToken(accessToken) if err != nil { return false, nil, nil, err } - - if !existed { + if token == nil { return false, nil, nil, nil } token.ExpiresIn = 0 - affected, err := ormer.Engine.ID(core.PK{token.Owner, token.Name}).Cols("expires_in").Update(&token) + affected, err := ormer.Engine.ID(core.PK{token.Owner, token.Name}).Cols("expires_in").Update(token) if err != nil { return false, nil, nil, err } @@ -240,22 +280,7 @@ func ExpireTokenByAccessToken(accessToken string) (bool, *Application, *Token, e return false, nil, nil, err } - return affected != 0, application, &token, nil -} - -func GetTokenByAccessToken(accessToken string) (*Token, error) { - // Check if the accessToken is in the database - token := Token{AccessTokenHash: getTokenHash(accessToken)} - existed, err := ormer.Engine.Get(&token) - if err != nil { - return nil, err - } - - if !existed { - return nil, nil - } - - return &token, nil + return affected != 0, application, token, nil } func GetTokenByTokenAndApplication(token string, application string) (*Token, error) { @@ -457,16 +482,17 @@ func RefreshToken(grantType string, refreshToken string, scope string, clientId ErrorDescription: "client_id is invalid", }, nil } + if clientSecret != "" && application.ClientSecret != clientSecret { return &TokenError{ Error: InvalidClient, ErrorDescription: "client_secret is invalid", }, nil } + // check whether the refresh token is valid, and has not expired. - token := Token{RefreshTokenHash: getTokenHash(refreshToken)} - existed, err := ormer.Engine.Get(&token) - if err != nil || !existed { + token, err := GetTokenByRefreshToken(refreshToken) + if err != nil || token == nil { return &TokenError{ Error: InvalidGrant, ErrorDescription: "refresh token is invalid, expired or revoked", @@ -477,6 +503,12 @@ func RefreshToken(grantType string, refreshToken string, scope string, clientId if err != nil { return nil, err } + if cert == nil { + return &TokenError{ + Error: InvalidGrant, + ErrorDescription: fmt.Sprintf("cert: %s cannot be found", application.Cert), + }, nil + } _, err = ParseJwtToken(refreshToken, cert) if err != nil { @@ -485,6 +517,7 @@ func RefreshToken(grantType string, refreshToken string, scope string, clientId ErrorDescription: fmt.Sprintf("parse refresh token error: %s", err.Error()), }, nil } + // generate a new token user, err := getUser(application.Organization, token.User) if err != nil { @@ -502,6 +535,7 @@ func RefreshToken(grantType string, refreshToken string, scope string, clientId if err != nil { return nil, err } + newAccessToken, newRefreshToken, tokenName, err := generateJwtToken(application, user, "", scope, host) if err != nil { return &TokenError{ @@ -529,7 +563,7 @@ func RefreshToken(grantType string, refreshToken string, scope string, clientId return nil, err } - _, err = DeleteToken(&token) + _, err = DeleteToken(token) if err != nil { return nil, err } @@ -542,7 +576,6 @@ func RefreshToken(grantType string, refreshToken string, scope string, clientId ExpiresIn: newToken.ExpiresIn, Scope: newToken.Scope, } - return tokenWrapper, nil }