Skip to content

Commit

Permalink
Merge pull request #376 from bcgov/ssoteam-1848
Browse files Browse the repository at this point in the history
fix: user-session remover
  • Loading branch information
jlangy authored Aug 29, 2024
2 parents ee8f1ba + dcb5932 commit 50085ea
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .tool-versions
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ k6 0.34.1
terraform 1.2.0
terraform-docs 0.12.1
tflint 0.28.1
java openjdk-14.0.1
java openjdk-17.0.1
gradle 7.3.1
8 changes: 4 additions & 4 deletions docker/keycloak/configuration/24/quarkus.properties
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ quarkus.log.file.json.exception-output-type=formatted
quarkus.log.file.json.key-overrides=timestamp=@timestamp
quarkus.log.file.json.additional-field."@version".value=1
# Quarkus will auto-compress if ending with .zip: https://quarkus.io/guides/logging.
quarkus.log.file.rotation.file-suffix=.zip
quarkus.log.file.rotation.file-suffix=${QUARKUS_LOG_FILE_ROTATION_FILE_SUFFIX:.zip}
# Optional: Disable rotation by size (adjust value as needed)
quarkus.log.file.rotation.max-file-size=200M
# The number of rotated files. From above configuration, this will keep 200M * 42 files ~= 8Gigabytes of data before replacing.
quarkus.log.file.rotation.max-backup-index=42
quarkus.log.file.rotation.max-file-size=${QUARKUS_LOG_FILE_ROTATION_MAX_FILE_SIZE:200M}
# The number of rotated files per pod. From above configuration, this will keep 200M * 14 files * 3pods ~= 8Gigabytes of data before replacing.
quarkus.log.file.rotation.max-backup-index=${QUARKUS_LOG_FILE_ROTATION_MAX_BACKUP_INDEX:14}
22 changes: 20 additions & 2 deletions docker/keycloak/extensions-24/services/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@
<target>17</target>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.22.0</version>
</plugin>
</plugins>
</build>

Expand Down Expand Up @@ -136,18 +141,31 @@
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.13.2</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-all</artifactId>
<version>1.9.5</version>
<artifactId>mockito-core</artifactId>
<version>5.3.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.hamcrest</groupId>
<artifactId>hamcrest-all</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<version>5.9.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
<version>5.9.1</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
import org.keycloak.models.RealmModel;
import org.keycloak.models.UserModel;
import org.keycloak.models.UserSessionProvider;
import org.keycloak.models.AuthenticatedClientSessionModel;
import org.keycloak.services.managers.AuthenticationManager;
import org.keycloak.authentication.AuthenticationFlowContext;
import org.keycloak.sessions.AuthenticationSessionModel;
import org.keycloak.models.UserSessionModel;

import java.util.Map;

Expand All @@ -23,34 +25,31 @@ public boolean requiresUser() {

@Override
public void authenticate(AuthenticationFlowContext context) {
AuthenticationSessionModel session = context.getAuthenticationSession();
AuthenticationManager.AuthResult authResult = AuthenticationManager.authenticateIdentityCookie(
context.getSession(),
context.getRealm(),
true
);
UserSessionModel userSessionModel;
AuthenticationManager.AuthResult authResult = AuthenticationManager.authenticateIdentityCookie(context.getSession(), context.getRealm(), true);

// 1. If no Cookie session, proceed to next step
if (authResult == null) {
context.attempted();
return;
}

// Need to use the KeycloakSession context to get the authenticating client ID. Not available on the AuthenticationFlowContext.
KeycloakSession keycloakSession = context.getSession();
String authenticatingClientUUID = keycloakSession.getContext().getClient().getId();
userSessionModel = authResult.getSession();

// Get all existing sessions. If any session is associated with a different client, clear all user sessions.
UserSessionProvider userSessionProvider = keycloakSession.sessions();
Map<String, Long> activeClientSessionStats = userSessionProvider.getActiveClientSessionStats(context.getRealm(), false);
String authenticatingClientUUID = context.getSession().getContext().getClient().getId();
UserSessionProvider userSessionProvider = context.getSession().sessions();

for (String activeSessionClientUUID : activeClientSessionStats.keySet()) {
// Must fetch sessions from the user session model, user session provider has all session in the realm
Map<String, AuthenticatedClientSessionModel> authenticatedClientSessions = userSessionModel.getAuthenticatedClientSessions();

for (String activeSessionClientUUID : authenticatedClientSessions.keySet()) {
if (!activeSessionClientUUID.equals(authenticatingClientUUID)) {
userSessionProvider.removeUserSession(context.getRealm(), authResult.getSession());
userSessionProvider.removeUserSession(context.getRealm(), userSessionModel);
}
}

context.attempted();
return;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package com.github.bcgov.keycloak.testsuite.authenticators;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import org.mockito.Mockito;
import org.mockito.MockedStatic;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.BeforeEach;

import com.github.bcgov.keycloak.authenticators.UserSessionRemover;
import org.keycloak.authentication.AuthenticationFlowContext;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.AuthenticatedClientSessionModel;
import org.keycloak.models.RealmModel;
import org.keycloak.services.managers.AuthenticationManager;
import org.keycloak.models.ClientModel;
import org.keycloak.sessions.AuthenticationSessionModel;
import org.keycloak.models.UserSessionProvider;
import org.keycloak.models.AuthenticatedClientSessionModel;
import org.keycloak.models.UserSessionModel;
import org.keycloak.models.KeycloakContext;
import java.util.HashMap;
import java.util.Map;

public class UserSessionRemoverTest {
private static final UserSessionRemover userSessionRemover = new UserSessionRemover();

private AuthenticationFlowContext context;
private KeycloakSession session;
private RealmModel realm;
private AuthenticationSessionModel authSession;
private UserSessionProvider userSessionProvider;
private KeycloakSession keycloakSession;
private ClientModel client;
private KeycloakContext keycloakContext;
private AuthenticationManager.AuthResult authResult;
private UserSessionModel userSessionModel;
private AuthenticatedClientSessionModel authenticatedClientSessionModel;

@BeforeEach
public void setup() {
// Initialize mocks for necessary objects
context = mock(AuthenticationFlowContext.class);
realm = mock(RealmModel.class);
authSession = mock(AuthenticationSessionModel.class);
userSessionProvider = mock(UserSessionProvider.class);
keycloakSession = mock(KeycloakSession.class);
keycloakContext = mock(KeycloakContext.class);
client = mock(ClientModel.class);
authResult = mock(AuthenticationManager.AuthResult.class);
userSessionModel = mock(UserSessionModel.class);
authenticatedClientSessionModel = mock(AuthenticatedClientSessionModel.class);


// Set up common behavior of the mocks
when(context.getSession()).thenReturn(keycloakSession);
when(context.getRealm()).thenReturn(realm);
when(context.getAuthenticationSession()).thenReturn(authSession);
when(keycloakSession.sessions()).thenReturn(userSessionProvider);
when(context.getSession()).thenReturn(keycloakSession);
when(keycloakSession.getContext()).thenReturn(keycloakContext);
when(keycloakContext.getClient()).thenReturn(client);
when(authResult.getSession()).thenReturn(userSessionModel);
}

@Test
public void testSkipClientSessionCheckWhenNullAuthResult() throws Exception {
try (MockedStatic<AuthenticationManager> authenticationManager = Mockito.mockStatic(AuthenticationManager.class)) {
authenticationManager.when(() -> AuthenticationManager.authenticateIdentityCookie(
any(KeycloakSession.class), any(RealmModel.class), any(Boolean.class)
)).thenReturn(null);
userSessionRemover.authenticate(context);

// Keycloak Session Context check skipped if no Auth Session
verify(keycloakSession, times(0)).getContext();
verify(userSessionProvider, times(0)).removeUserSession(any(RealmModel.class), any(UserSessionModel.class));
}
}

@Test
public void testRemovesUserSessionsWhenMultipleClientSessionsExist() throws Exception {
when(client.getId()).thenReturn("client1");
Map<String, AuthenticatedClientSessionModel> authenticatedClientSessions = new HashMap<>();
authenticatedClientSessions.put("client1", authenticatedClientSessionModel);
authenticatedClientSessions.put("client2", authenticatedClientSessionModel);

when(userSessionModel.getAuthenticatedClientSessions()).thenReturn(authenticatedClientSessions);

try (MockedStatic<AuthenticationManager> authenticationManager = Mockito.mockStatic(AuthenticationManager.class)) {
authenticationManager.when(() -> AuthenticationManager.authenticateIdentityCookie(
any(KeycloakSession.class), any(RealmModel.class), any(Boolean.class)
)).thenReturn(authResult);

userSessionRemover.authenticate(context);

verify(keycloakSession, times(1)).getContext();
verify(userSessionProvider, times(1)).removeUserSession(any(RealmModel.class), any(UserSessionModel.class));
}
}

@Test
public void testRemovesUserSessionsWhenSingleDifferentClientSessionFound() throws Exception {
when(client.getId()).thenReturn("client1");
Map<String, AuthenticatedClientSessionModel> authenticatedClientSessions = new HashMap<>();
authenticatedClientSessions.put("client2", authenticatedClientSessionModel);

when(userSessionModel.getAuthenticatedClientSessions()).thenReturn(authenticatedClientSessions);

try (MockedStatic<AuthenticationManager> authenticationManager = Mockito.mockStatic(AuthenticationManager.class)) {
authenticationManager.when(() -> AuthenticationManager.authenticateIdentityCookie(
any(KeycloakSession.class), any(RealmModel.class), any(Boolean.class)
)).thenReturn(authResult);
userSessionRemover.authenticate(context);

verify(keycloakSession, times(1)).getContext();
verify(userSessionProvider, times(1)).removeUserSession(any(RealmModel.class), any(UserSessionModel.class));
}
}

@Test
public void testLeavesExistingSessionWhenOnlyAssociatedToAuthenticatingClient() throws Exception {
when(client.getId()).thenReturn("client1");
Map<String, AuthenticatedClientSessionModel> authenticatedClientSessions = new HashMap<>();
authenticatedClientSessions.put("client1", authenticatedClientSessionModel);

when(userSessionModel.getAuthenticatedClientSessions()).thenReturn(authenticatedClientSessions);

try (MockedStatic<AuthenticationManager> authenticationManager = Mockito.mockStatic(AuthenticationManager.class)) {
authenticationManager.when(() -> AuthenticationManager.authenticateIdentityCookie(
any(KeycloakSession.class), any(RealmModel.class), any(Boolean.class)
)).thenReturn(authResult);
userSessionRemover.authenticate(context);

// Verify the keycloak session context is invoked to check client sessions
verify(keycloakSession, times(1)).getContext();

// Remove user session should be skipped
verify(userSessionProvider, times(0)).removeUserSession(any(RealmModel.class), any(UserSessionModel.class));
}
}
}

0 comments on commit 50085ea

Please sign in to comment.