Skip to content

Commit

Permalink
adding reauth support for both pubsub and shardedpubsub
Browse files Browse the repository at this point in the history
  • Loading branch information
atakavci committed Dec 14, 2024
1 parent 9717c9a commit 9185f44
Show file tree
Hide file tree
Showing 6 changed files with 346 additions and 75 deletions.
17 changes: 10 additions & 7 deletions src/main/java/redis/clients/jedis/Connection.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,8 @@ public class Connection implements Closeable {
private String strVal;
protected String server;
protected String version;
protected AtomicReference<RedisCredentials> currentCredentials = new AtomicReference<RedisCredentials>(
null);
private boolean isTokenBasedAuthenticationEnabled = false;
private AtomicReference<RedisCredentials> currentCredentials = new AtomicReference<>(null);
private AuthXManager authXManager;

public Connection() {
this(Protocol.DEFAULT_HOST, Protocol.DEFAULT_PORT);
Expand All @@ -68,6 +67,7 @@ public Connection(final HostAndPort hostAndPort, final JedisClientConfig clientC

public Connection(final JedisSocketFactory socketFactory) {
this.socketFactory = socketFactory;
this.authXManager = null;
}

public Connection(final JedisSocketFactory socketFactory, JedisClientConfig clientConfig) {
Expand Down Expand Up @@ -458,9 +458,8 @@ protected void initializeFromClientConfig(final JedisClientConfig config) {

Supplier<RedisCredentials> credentialsProvider = config.getCredentialsProvider();

AuthXManager authXManager = config.getAuthXManager();
authXManager = config.getAuthXManager();
if (authXManager != null) {
isTokenBasedAuthenticationEnabled = true;
credentialsProvider = authXManager;
}

Expand Down Expand Up @@ -608,7 +607,11 @@ public boolean ping() {
return true;
}

public boolean isTokenBasedAuthenticationEnabled() {
return isTokenBasedAuthenticationEnabled;
protected boolean isTokenBasedAuthenticationEnabled() {
return authXManager != null;
}

protected AuthXManager getAuthXManager() {
return authXManager;
}
}
53 changes: 31 additions & 22 deletions src/main/java/redis/clients/jedis/JedisPubSubBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;

import redis.clients.jedis.Protocol.Command;
import redis.clients.jedis.exceptions.JedisException;
Expand All @@ -12,7 +13,8 @@
public abstract class JedisPubSubBase<T> {

private int subscribedChannels = 0;
private volatile Connection client;
private final JedisSafeAuthenticator authenticator = new JedisSafeAuthenticator();
private final Consumer<Object> pingResultHandler = this::processPingReply;

public void onMessage(T channel, T message) {
}
Expand All @@ -36,12 +38,7 @@ public void onPong(T pattern) {
}

private void sendAndFlushCommand(Command command, T... args) {
if (client == null) {
throw new JedisException(getClass() + " is not connected to a Connection.");
}
CommandArguments cargs = new CommandArguments(command).addObjects(args);
client.sendCommand(cargs);
client.flush();
authenticator.sendAndFlushCommand(command, args);
}

public final void unsubscribe() {
Expand All @@ -63,7 +60,8 @@ public final void psubscribe(T... patterns) {
}

private void checkConnectionSuitableForPubSub() {
if (client.protocol == RedisProtocol.RESP2 && client.isTokenBasedAuthenticationEnabled()) {
if (authenticator.client.protocol != RedisProtocol.RESP3
&& authenticator.client.isTokenBasedAuthenticationEnabled()) {
throw new JedisException(
"Blocking pub/sub operations are not supported on token-based authentication enabled connections with RESP2 protocol!");
}
Expand All @@ -78,7 +76,13 @@ public final void punsubscribe(T... patterns) {
}

public final void ping() {
sendAndFlushCommand(Command.PING);
authenticator.commandSync.lock();
try {
sendAndFlushCommand(Command.PING);
authenticator.resultHandler.add(pingResultHandler);
} finally {
authenticator.commandSync.unlock();
}
}

public final void ping(T argument) {
Expand All @@ -94,24 +98,24 @@ public final int getSubscribedChannels() {
}

public final void proceed(Connection client, T... channels) {
this.client = client;
this.client.setTimeoutInfinite();
authenticator.registerForAuthentication(client);
authenticator.client.setTimeoutInfinite();
try {
subscribe(channels);
process();
} finally {
this.client.rollbackTimeout();
authenticator.client.rollbackTimeout();
}
}

public final void proceedWithPatterns(Connection client, T... patterns) {
this.client = client;
this.client.setTimeoutInfinite();
authenticator.registerForAuthentication(client);
authenticator.client.setTimeoutInfinite();
try {
psubscribe(patterns);
process();
} finally {
this.client.rollbackTimeout();
authenticator.client.rollbackTimeout();
}
}

Expand All @@ -121,7 +125,7 @@ public final void proceedWithPatterns(Connection client, T... patterns) {
private void process() {

do {
Object reply = client.getUnflushedObject();
Object reply = authenticator.client.getUnflushedObject();

if (reply instanceof List) {
List<Object> listReply = (List<Object>) reply;
Expand Down Expand Up @@ -175,12 +179,8 @@ private void process() {
throw new JedisException("Unknown message type: " + firstObj);
}
} else if (reply instanceof byte[]) {
byte[] resp = (byte[]) reply;
if ("PONG".equals(SafeEncoder.encode(resp))) {
onPong(null);
} else {
onPong(encode(resp));
}
Consumer<Object> resultHandler = authenticator.resultHandler.remove();
resultHandler.accept(reply);
} else {
throw new JedisException("Unknown message type: " + reply);
}
Expand All @@ -189,4 +189,13 @@ private void process() {
// /* Invalidate instance since this thread is no longer listening */
// this.client = null;
}

private void processPingReply(Object reply) {
byte[] resp = (byte[]) reply;
if ("PONG".equals(SafeEncoder.encode(resp))) {
onPong(null);
} else {
onPong(encode(resp));
}
}
}
104 changes: 104 additions & 0 deletions src/main/java/redis/clients/jedis/JedisSafeAuthenticator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package redis.clients.jedis;

import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import redis.clients.authentication.core.SimpleToken;
import redis.clients.authentication.core.Token;
import redis.clients.jedis.Protocol.Command;
import redis.clients.jedis.authentication.JedisAuthenticationException;
import redis.clients.jedis.exceptions.JedisException;
import redis.clients.jedis.util.SafeEncoder;

public class JedisSafeAuthenticator {

private static final Token PLACEHOLDER_TOKEN = new SimpleToken(null, null, 0, 0, null);
private static final Logger logger = LoggerFactory.getLogger(JedisSafeAuthenticator.class);

protected volatile Connection client;
protected final Consumer<Object> authResultHandler = this::processAuthReply;
protected final Consumer<Token> authenticationHandler = this::safeReAuthenticate;

protected final AtomicReference<Token> pendingTokenRef = new AtomicReference<Token>(null);
protected final ReentrantLock commandSync = new ReentrantLock();
protected final Queue<Consumer<Object>> resultHandler = new ConcurrentLinkedQueue<Consumer<Object>>();

protected void sendAndFlushCommand(Command command, Object... args) {
if (client == null) {
throw new JedisException(getClass() + " is not connected to a Connection.");
}
CommandArguments cargs = new CommandArguments(command).addObjects(args);

Token newToken = pendingTokenRef.getAndSet(PLACEHOLDER_TOKEN);

// lets send the command without locking !!IF!! we know that pendingTokenRef is null replaced with PLACEHOLDER_TOKEN and no re-auth will go into action
// !!ELSE!! we are locking since we already know a re-auth is still in progress in another thread and we need to wait for it to complete, we do nothing but wait on it!
if (newToken != null) {
commandSync.lock();
}
try {
client.sendCommand(cargs);
client.flush();
} finally {
Token newerToken = pendingTokenRef.getAndSet(null);
// lets check if a newer token received since the beginning of this sendAndFlushCommand call
if (newerToken != null && newerToken != PLACEHOLDER_TOKEN) {
safeReAuthenticate(newerToken);
}
if (newToken != null) {
commandSync.unlock();
}
}
}

protected void registerForAuthentication(Connection newClient) {
Connection oldClient = this.client;
if (oldClient == newClient) return;
if (oldClient != null && oldClient.getAuthXManager() != null) {
oldClient.getAuthXManager().removePostAuthenticationHook(authenticationHandler);
}
if (newClient != null && newClient.getAuthXManager() != null) {
newClient.getAuthXManager().addPostAuthenticationHook(authenticationHandler);
}
this.client = newClient;
}

private void safeReAuthenticate(Token token) {
try {
byte[] rawPass = client.encodeToBytes(token.getValue().toCharArray());
byte[] rawUser = client.encodeToBytes(token.getUser().toCharArray());

Token newToken = pendingTokenRef.getAndSet(token);
if (newToken == null) {
commandSync.lock();
try {
sendAndFlushCommand(Command.AUTH, rawUser, rawPass);
resultHandler.add(this.authResultHandler);
} finally {
pendingTokenRef.set(null);
commandSync.unlock();
}
}
} catch (Exception e) {
logger.error("Error while re-authenticating connection", e);
client.getAuthXManager().getListener().onConnectionAuthenticationError(e);
}
}

protected void processAuthReply(Object reply) {
byte[] resp = (byte[]) reply;
String response = SafeEncoder.encode(resp);
if (!"OK".equals(response)) {
String msg = "Re-authentication failed with server response: " + response;
Exception failedAuth = new JedisAuthenticationException(msg);
logger.error(failedAuth.getMessage(), failedAuth);
client.getAuthXManager().getListener().onConnectionAuthenticationError(failedAuth);
}
}
}
31 changes: 19 additions & 12 deletions src/main/java/redis/clients/jedis/JedisShardedPubSubBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@

import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;

import redis.clients.jedis.Protocol.Command;
import redis.clients.jedis.exceptions.JedisException;

public abstract class JedisShardedPubSubBase<T> {

private int subscribedChannels = 0;
private volatile Connection client;
private final JedisSafeAuthenticator authenticator = new JedisSafeAuthenticator();

public void onSMessage(T channel, T message) {
}
Expand All @@ -23,12 +24,7 @@ public void onSUnsubscribe(T channel, int subscribedChannels) {
}

private void sendAndFlushCommand(Command command, T... args) {
if (client == null) {
throw new JedisException(getClass() + " is not connected to a Connection.");
}
CommandArguments cargs = new CommandArguments(command).addObjects(args);
client.sendCommand(cargs);
client.flush();
authenticator.sendAndFlushCommand(command, args);
}

public final void sunsubscribe() {
Expand All @@ -40,9 +36,18 @@ public final void sunsubscribe(T... channels) {
}

public final void ssubscribe(T... channels) {
checkConnectionSuitableForPubSub();
sendAndFlushCommand(Command.SSUBSCRIBE, channels);
}

private void checkConnectionSuitableForPubSub() {
if (authenticator.client.protocol != RedisProtocol.RESP3
&& authenticator.client.isTokenBasedAuthenticationEnabled()) {
throw new JedisException(
"Blocking pub/sub operations are not supported on token-based authentication enabled connections with RESP2 protocol!");
}
}

public final boolean isSubscribed() {
return subscribedChannels > 0;
}
Expand All @@ -52,23 +57,22 @@ public final int getSubscribedChannels() {
}

public final void proceed(Connection client, T... channels) {
this.client = client;
this.client.setTimeoutInfinite();
authenticator.registerForAuthentication(client);
authenticator.client.setTimeoutInfinite();
try {
ssubscribe(channels);
process();
} finally {
this.client.rollbackTimeout();
authenticator.client.rollbackTimeout();
}
}

protected abstract T encode(byte[] raw);

// private void process(Client client) {
private void process() {

do {
Object reply = client.getUnflushedObject();
Object reply = authenticator.client.getUnflushedObject();

if (reply instanceof List) {
List<Object> listReply = (List<Object>) reply;
Expand Down Expand Up @@ -96,6 +100,9 @@ private void process() {
} else {
throw new JedisException("Unknown message type: " + firstObj);
}
} else if (reply instanceof byte[]) {
Consumer<Object> resultHandler = authenticator.resultHandler.remove();
resultHandler.accept(reply);
} else {
throw new JedisException("Unknown message type: " + reply);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,13 +270,6 @@ public void allConnectionsReauthTest() throws InterruptedException, ExecutionExc
}
}

// T.3.2
// Test system behavior when some connections fail to re-authenticate during bulk authentication. e.g when a network partition occurs for 1 or more of them
@Test
public void partialReauthFailureTest() {

}

// T.3.3
// Verify behavior when attempting to authenticate a single connection with an expired token.
@Test
Expand Down
Loading

0 comments on commit 9185f44

Please sign in to comment.