From 9ea510d6612c3d1e34b06a7e817adc4daa6d7368 Mon Sep 17 00:00:00 2001 From: atakavci Date: Tue, 10 Dec 2024 12:13:24 +0300 Subject: [PATCH] -reviews from @sazzad16 --- .../java/redis/clients/jedis/Connection.java | 9 +++---- .../clients/jedis/ConnectionFactory.java | 26 ++++++++++--------- .../redis/clients/jedis/ConnectionPool.java | 9 ++++--- .../authentication/TokenCredentials.java | 2 +- .../RedisEntraIDIntegrationTests.java | 2 +- ...AuthenticationClusterIntegrationTests.java | 2 +- .../TokenBasedAuthenticationUnitTests.java | 9 +------ 7 files changed, 27 insertions(+), 32 deletions(-) diff --git a/src/main/java/redis/clients/jedis/Connection.java b/src/main/java/redis/clients/jedis/Connection.java index fe57fb80fb..a59d359e34 100644 --- a/src/main/java/redis/clients/jedis/Connection.java +++ b/src/main/java/redis/clients/jedis/Connection.java @@ -560,7 +560,7 @@ public void setCredentials(RedisCredentials credentials) { currentCredentials.set(credentials); } - public void authenticate(RedisCredentials credentials) { + private void authenticate(RedisCredentials credentials) { if (credentials == null || credentials.getPassword() == null) { return; } @@ -577,11 +577,8 @@ public void authenticate(RedisCredentials credentials) { getStatusCodeReply(); } - public void reAuth() { - RedisCredentials temp = currentCredentials.getAndSet(null); - if (temp != null) { - authenticate(temp); - } + public void reAuthenticate() { + authenticate(currentCredentials.getAndSet(null)); } protected Map hello(byte[]... args) { diff --git a/src/main/java/redis/clients/jedis/ConnectionFactory.java b/src/main/java/redis/clients/jedis/ConnectionFactory.java index 45e89fc2da..6ce7c3663e 100644 --- a/src/main/java/redis/clients/jedis/ConnectionFactory.java +++ b/src/main/java/redis/clients/jedis/ConnectionFactory.java @@ -104,12 +104,7 @@ public PooledObject makeObject() throws Exception { public void passivateObject(PooledObject pooledConnection) throws Exception { // TODO maybe should select db 0? Not sure right now. Connection jedis = pooledConnection.getObject(); - try { - jedis.reAuth(); - } catch (Exception e) { - authXEventListener.onConnectionAuthenticationError(e); - throw e; - } + reAuthenticate(jedis); } @Override @@ -117,16 +112,23 @@ public boolean validateObject(PooledObject pooledConnection) { final Connection jedis = pooledConnection.getObject(); try { // check HostAndPort ?? - try { - jedis.reAuth(); - } catch (Exception e) { - authXEventListener.onConnectionAuthenticationError(e); - throw e; + if (!jedis.isConnected()) { + return false; } - return jedis.isConnected() && jedis.ping(); + reAuthenticate(jedis); + return jedis.ping(); } catch (final Exception e) { logger.warn("Error while validating pooled Connection object.", e); return false; } } + + private void reAuthenticate(Connection jedis) throws Exception { + try { + jedis.reAuthenticate(); + } catch (Exception e) { + authXEventListener.onConnectionAuthenticationError(e); + throw e; + } + } } diff --git a/src/main/java/redis/clients/jedis/ConnectionPool.java b/src/main/java/redis/clients/jedis/ConnectionPool.java index 536b3a6484..2ae1401081 100644 --- a/src/main/java/redis/clients/jedis/ConnectionPool.java +++ b/src/main/java/redis/clients/jedis/ConnectionPool.java @@ -56,10 +56,13 @@ public Connection getResource() { @Override public void close() { - if (authXManager != null) { - authXManager.stop(); + try { + if (authXManager != null) { + authXManager.stop(); + } + } finally { + super.close(); } - super.close(); } private void attachAuthenticationListener(AuthXManager authXManager) { diff --git a/src/main/java/redis/clients/jedis/authentication/TokenCredentials.java b/src/main/java/redis/clients/jedis/authentication/TokenCredentials.java index 9c5a54f135..471c34bc40 100644 --- a/src/main/java/redis/clients/jedis/authentication/TokenCredentials.java +++ b/src/main/java/redis/clients/jedis/authentication/TokenCredentials.java @@ -3,7 +3,7 @@ import redis.clients.authentication.core.Token; import redis.clients.jedis.RedisCredentials; -public class TokenCredentials implements RedisCredentials { +class TokenCredentials implements RedisCredentials { private final String user; private final char[] password; diff --git a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java index 5d1a6d289b..b6010ca28f 100644 --- a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java +++ b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java @@ -305,7 +305,7 @@ public void allConnectionsReauthTest() throws InterruptedException, ExecutionExc } connections.forEach(conn -> { - verify(conn, atLeast(1)).reAuth(); + verify(conn, atLeast(1)).reAuthenticate(); }); executor.shutdown(); } diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationClusterIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationClusterIntegrationTests.java index d711804335..2b6e4e3256 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationClusterIntegrationTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationClusterIntegrationTests.java @@ -118,7 +118,7 @@ public Token requestToken() { connections.forEach(conn -> { await().pollInterval(ONE_HUNDRED_MILLISECONDS).atMost(ONE_SECOND) - .untilAsserted(() -> verify(conn, atLeast(2)).reAuth()); + .untilAsserted(() -> verify(conn, atLeast(2)).reAuthenticate()); }); latch.countDown(); task1.get(); diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java index a70fec0704..699dc47f31 100644 --- a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java @@ -3,14 +3,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockConstruction; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import static org.awaitility.Awaitility.await; import static org.awaitility.Durations.*; import static org.hamcrest.CoreMatchers.either;