Skip to content

Commit

Permalink
WPB-1334 extend list of OAuth apps with active refresh token ids (#4211)
Browse files Browse the repository at this point in the history
  • Loading branch information
battermann authored Aug 16, 2024
1 parent fba266f commit 9481a88
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 16 deletions.
1 change: 1 addition & 0 deletions changelog.d/2-features/WPB-1334
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Adds a field which contains a list of all active sessions to each OAuth application in the response of `GET /oauth/applications`
1 change: 1 addition & 0 deletions integration/integration.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ library
Test.MLS.SubConversation
Test.MLS.Unreachable
Test.Notifications
Test.OAuth
Test.Presence
Test.Property
Test.Provider
Expand Down
38 changes: 38 additions & 0 deletions integration/test/API/Brig.hs
Original file line number Diff line number Diff line change
Expand Up @@ -694,3 +694,41 @@ clearProperties :: (MakesValue user) => user -> App Response
clearProperties user = do
req <- baseRequest user Brig Versioned $ joinHttpPath ["properties"]
submit "DELETE" req

-- | https://staging-nginz-https.zinfra.io/v6/api/swagger-ui/#/default/post_oauth_authorization_codes
generateOAuthAuthorizationCode :: (HasCallStack, MakesValue user, MakesValue cid) => user -> cid -> [String] -> String -> App Response
generateOAuthAuthorizationCode user cid scopes redirectUrl = do
cidStr <- asString cid
req <- baseRequest user Brig Versioned "/oauth/authorization/codes"
submit "POST" $
req
& addJSONObject
[ "client_id" .= cidStr,
"scope" .= unwords scopes,
"redirect_uri" .= redirectUrl,
"code_challenge" .= "G7CWLBqYDT8doT_oEIN3un_QwZWYKHmOqG91nwNzITc",
"code_challenge_method" .= "S256",
"response_type" .= "code",
"state" .= "abc"
]

-- | https://staging-nginz-https.zinfra.io/v6/api/swagger-ui/#/default/post_oauth_token
createOAuthAccessToken :: (HasCallStack, MakesValue user, MakesValue cid) => user -> cid -> String -> String -> App Response
createOAuthAccessToken user cid code redirectUrl = do
cidStr <- asString cid
req <- baseRequest user Brig Versioned "/oauth/token"
submit "POST" $
req
& addUrlEncodedForm
[ ("grant_type", "authorization_code"),
("client_id", cidStr),
("code_verifier", "nE3k3zykOmYki~kriKzAmeFiGT7cWugcuToFwo1YPgrZ1cFvaQqLa.dXY9MnDj3umAmG-8lSNIYIl31Cs_.fV5r2psa4WWZcB.Nlc3A-t3p67NDZaOJjIiH~8PvUH_hR"),
("code", code),
("redirect_uri", redirectUrl)
]

-- | https://staging-nginz-https.zinfra.io/v6/api/swagger-ui/#/default/get_oauth_applications
getOAuthApplications :: (HasCallStack, MakesValue user) => user -> App Response
getOAuthApplications user = do
req <- baseRequest user Brig Versioned "/oauth/applications"
submit "GET" req
6 changes: 6 additions & 0 deletions integration/test/API/BrigInternal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,9 @@ deleteFeatureForUser user featureName = do
uid <- objId user
req <- baseRequest user Brig Unversioned $ joinHttpPath ["i", "users", uid, "features", featureName]
submit "DELETE" req

-- | https://staging-nginz-https.zinfra.io/api-internal/swagger-ui/brig/#/brig/post_i_oauth_clients
createOAuthClient :: (HasCallStack, MakesValue user) => user -> String -> String -> App Response
createOAuthClient user name url = do
req <- baseRequest user Brig Unversioned "i/oauth/clients"
submit "POST" $ req & addJSONObject ["application_name" .= name, "redirect_url" .= url]
25 changes: 25 additions & 0 deletions integration/test/Test/OAuth.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
module Test.OAuth where

import API.Brig
import API.BrigInternal
import Data.String.Conversions
import Network.HTTP.Types
import Network.URI
import SetupHelpers
import Testlib.Prelude

testListApplicationsWithActiveSessions :: (HasCallStack) => App ()
testListApplicationsWithActiveSessions = do
user <- randomUser OwnDomain def
oauthClient <- createOAuthClient user "foobar" "https://example.com" >>= getJSON 200
cid <- oauthClient %. "client_id"
let scopes = ["write:conversations"]
let generateAccessToken = do
authCodeResponse <- generateOAuthAuthorizationCode user cid scopes "https://example.com"
let location = fromMaybe (error "no location header") $ parseURI . cs . snd =<< locationHeader authCodeResponse
let code = maybe "no code query param" cs $ join $ lookup (cs "code") $ parseQuery $ cs location.uriQuery
void $ createOAuthAccessToken user cid code "https://example.com" >>= getJSON 200
replicateM_ 2 generateAccessToken
[app] <- getOAuthApplications user >>= getJSON 200 >>= asList
sessions <- app %. "sessions" >>= asList
length sessions `shouldMatchInt` 2
10 changes: 10 additions & 0 deletions integration/test/Testlib/HTTP.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import Data.String
import Data.String.Conversions (cs)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Data.Tuple.Extra
import GHC.Generics
import GHC.Stack
import qualified Network.HTTP.Client as HTTP
Expand All @@ -41,6 +42,15 @@ addJSONObject = addJSON . Aeson.object
addJSON :: (Aeson.ToJSON a) => a -> HTTP.Request -> HTTP.Request
addJSON obj = addBody (HTTP.RequestBodyLBS (Aeson.encode obj)) "application/json"

addUrlEncodedForm :: [(String, String)] -> HTTP.Request -> HTTP.Request
addUrlEncodedForm form req =
req
{ HTTP.requestBody = HTTP.RequestBodyLBS (L.fromStrict (HTTP.renderSimpleQuery False (both C8.pack <$> form))),
HTTP.requestHeaders =
(fromString "Content-Type", fromString "application/x-www-form-urlencoded")
: HTTP.requestHeaders req
}

addBody :: HTTP.RequestBody -> String -> HTTP.Request -> HTTP.Request
addBody body contentType req =
req
Expand Down
30 changes: 25 additions & 5 deletions libs/wire-api/src/Wire/API/OAuth.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import Data.ByteString.Lazy (fromStrict, toStrict)
import Data.Either.Combinators (mapLeft)
import Data.HashMap.Strict qualified as HM
import Data.Id as Id
import Data.Json.Util
import Data.OpenApi (ToParamSchema (..))
import Data.OpenApi qualified as S
import Data.Range
Expand Down Expand Up @@ -650,9 +651,28 @@ instance ToSchema OAuthRevokeRefreshTokenRequest where
clientIdDescription = description ?~ "The OAuth client's ID"
refreshTokenDescription = description ?~ "The refresh token"

data OAuthSession = OAuthSession
{ refreshTokenId :: OAuthRefreshTokenId,
createdAt :: UTCTimeMillis
}
deriving (Eq, Show, Ord, Generic)
deriving (Arbitrary) via (GenericUniform OAuthSession)
deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthSession)

instance ToSchema OAuthSession where
schema =
object "OAuthSession" $
OAuthSession
<$> (.refreshTokenId) .= fieldWithDocModifier "refresh_token_id" refreshTokenIdDescription schema
<*> (.createdAt) .= fieldWithDocModifier "created_at" createdAtDescription schema
where
refreshTokenIdDescription = description ?~ "The ID of the refresh token"
createdAtDescription = description ?~ "The time when the session was created"

data OAuthApplication = OAuthApplication
{ applicationId :: OAuthClientId,
name :: OAuthApplicationName
name :: OAuthApplicationName,
sessions :: [OAuthSession]
}
deriving (Eq, Show, Ord, Generic)
deriving (Arbitrary) via (GenericUniform OAuthApplication)
Expand All @@ -662,13 +682,13 @@ instance ToSchema OAuthApplication where
schema =
object "OAuthApplication" $
OAuthApplication
<$> applicationId
.= fieldWithDocModifier "id" idDescription schema
<*> (.name)
.= fieldWithDocModifier "name" nameDescription schema
<$> applicationId .= fieldWithDocModifier "id" idDescription schema
<*> (.name) .= fieldWithDocModifier "name" nameDescription schema
<*> sessions .= fieldWithDocModifier "sessions" sessionsDescription (array schema)
where
idDescription = description ?~ "The OAuth client's ID"
nameDescription = description ?~ "The OAuth client's name"
sessionsDescription = description ?~ "The OAuth client's sessions"

--------------------------------------------------------------------------------
-- Errors
Expand Down
15 changes: 11 additions & 4 deletions services/brig/src/Brig/API/OAuth.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ import Crypto.JWT hiding (params, uri)
import Data.ByteString.Conversion
import Data.Domain
import Data.Id
import Data.Json.Util (toUTCTimeMillis)
import Data.Map qualified as Map
import Data.Misc
import Data.Set qualified as Set
import Data.Text.Ascii
Expand Down Expand Up @@ -320,10 +322,15 @@ lookupAndVerifyToken key =
getOAuthApplications :: UserId -> (Handler r) [OAuthApplication]
getOAuthApplications uid = do
activeRefreshTokens <- lift $ wrapClient $ lookupOAuthRefreshTokens uid
nub . catMaybes <$> for activeRefreshTokens oauthApp
toApplications activeRefreshTokens
where
oauthApp :: OAuthRefreshTokenInfo -> (Handler r) (Maybe OAuthApplication)
oauthApp info = (OAuthApplication info.clientId . (.name)) <$$> getOAuthClient info.userId info.clientId
toApplications :: [OAuthRefreshTokenInfo] -> (Handler r) [OAuthApplication]
toApplications infos = do
let grouped = Map.fromListWith (<>) $ (\info -> (info.clientId, [info])) <$> infos
mApps <- for (Map.toList grouped) $ \(cid, tokens) -> do
mClient <- getOAuthClient uid cid
pure $ (\client -> OAuthApplication cid client.name ((\i -> OAuthSession i.refreshTokenId (toUTCTimeMillis i.createdAt)) <$> tokens)) <$> mClient
pure $ catMaybes mApps

--------------------------------------------------------------------------------

Expand Down Expand Up @@ -404,7 +411,7 @@ insertOAuthRefreshToken maxActiveTokens ttl info = do
determineOldestTokensToBeDeleted tokens =
take (length sorted - fromIntegral maxActiveTokens + 1) sorted
where
sorted = sortOn createdAt tokens
sorted = sortOn (.createdAt) tokens

lookupOAuthRefreshTokens :: (MonadClient m) => UserId -> m [OAuthRefreshTokenInfo]
lookupOAuthRefreshTokens uid = do
Expand Down
14 changes: 7 additions & 7 deletions services/brig/test/integration/API/OAuth.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
-- with this program. If not, see <https://www.gnu.org/licenses/>.
{-# OPTIONS_GHC -fno-warn-orphans #-}

module API.OAuth where
module API.OAuth (tests) where

import API.Team.Util qualified as Team
import Bilge
Expand Down Expand Up @@ -439,7 +439,7 @@ testRefreshTokenMaxActiveTokens opts db brig =
resp <- createOAuthAccessToken brig accessTokenRequest
rid <- extractRefreshTokenId jwk resp.refreshToken
tokens <- C.runClient db (lookupOAuthRefreshTokens uid)
liftIO $ assertBool testMsg $ [rid] `hasSameElems` (refreshTokenId <$> tokens)
liftIO $ assertBool testMsg $ [rid] `hasSameElems` ((.refreshTokenId) <$> tokens)
pure (rid, cid, secret)
delayOneSec
rid2 <- do
Expand All @@ -449,7 +449,7 @@ testRefreshTokenMaxActiveTokens opts db brig =
resp <- createOAuthAccessToken brig accessTokenRequest
rid <- extractRefreshTokenId jwk resp.refreshToken
tokens <- C.runClient db (lookupOAuthRefreshTokens uid)
liftIO $ assertBool testMsg $ [rid1, rid] `hasSameElems` (refreshTokenId <$> tokens)
liftIO $ assertBool testMsg $ [rid1, rid] `hasSameElems` ((.refreshTokenId) <$> tokens)
pure rid
delayOneSec
rid3 <- do
Expand All @@ -460,7 +460,7 @@ testRefreshTokenMaxActiveTokens opts db brig =
rid <- extractRefreshTokenId jwk resp.refreshToken
recoverN 3 $ do
tokens <- C.runClient db (lookupOAuthRefreshTokens uid)
liftIO $ assertBool testMsg $ [rid2, rid] `hasSameElems` (refreshTokenId <$> tokens)
liftIO $ assertBool testMsg $ [rid2, rid] `hasSameElems` ((.refreshTokenId) <$> tokens)
pure rid
delayOneSec
do
Expand All @@ -470,7 +470,7 @@ testRefreshTokenMaxActiveTokens opts db brig =
resp <- createOAuthAccessToken brig accessTokenRequest
rid <- extractRefreshTokenId jwk resp.refreshToken
tokens <- C.runClient db (lookupOAuthRefreshTokens uid)
liftIO $ assertBool testMsg $ [rid3, rid] `hasSameElems` (refreshTokenId <$> tokens)
liftIO $ assertBool testMsg $ [rid3, rid] `hasSameElems` ((.refreshTokenId) <$> tokens)
where
extractRefreshTokenId :: (MonadIO m) => JWK -> OAuthRefreshToken -> m OAuthRefreshTokenId
extractRefreshTokenId jwk rt = do
Expand Down Expand Up @@ -609,14 +609,14 @@ testListApplicationsWithAccountAccess brig = do
bob <- createUser "bob" brig
do
apps <- listOAuthApplications brig (User.userId alice)
liftIO $ assertEqual "apps" 0 (length apps)
liftIO $ apps @?= []
void $ createOAuthApplicationWithAccountAccess brig (User.userId alice)
void $ createOAuthApplicationWithAccountAccess brig (User.userId alice)
do
aliceApps <- listOAuthApplications brig (User.userId alice)
liftIO $ assertEqual "apps" 2 (length aliceApps)
bobsApps <- listOAuthApplications brig (User.userId bob)
liftIO $ assertEqual "apps" 0 (length bobsApps)
liftIO $ bobsApps @?= []
void $ createOAuthApplicationWithAccountAccess brig (User.userId alice)
void $ createOAuthApplicationWithAccountAccess brig (User.userId bob)
do
Expand Down

0 comments on commit 9481a88

Please sign in to comment.