Skip to content

Commit

Permalink
fix: prevent sending expired tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
dshukertjr committed Dec 10, 2024
1 parent ccfcbf5 commit 17bf76f
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 8 deletions.
40 changes: 35 additions & 5 deletions packages/realtime_client/lib/src/realtime_client.dart
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class RealtimeCloseEvent {
}

class RealtimeClient {
// This is named `accessTokenValue` in supabase-js
String? accessToken;
List<RealtimeChannel> channels = [];
final String endPoint;
Expand Down Expand Up @@ -89,6 +90,8 @@ class RealtimeClient {
};
int longpollerTimeout = 20000;
SocketStates? connState;
// This is called `accessToken` in realtime-js
Future<String> Function()? customAccessToken;

/// Initializes the Socket
///
Expand Down Expand Up @@ -403,15 +406,42 @@ class RealtimeClient {
/// Sets the JWT access token used for channel subscription authorization and Realtime RLS.
///
/// `token` A JWT strings.
void setAuth(String? token) {
accessToken = token;
Future<void> setAuth(String? token) async {
final tokenToSend =
token ?? (await customAccessToken?.call()) ?? accessToken;

if (tokenToSend != null) {
Map<String, dynamic>? parsed;
try {
final decoded =
utf8.decode(base64Url.decode(tokenToSend.split('.')[1]));
parsed = json.decode(decoded);
} catch (e) {
// ignore parsing errors
}
if (parsed != null && parsed['exp'] != null) {
final now = (DateTime.now().millisecondsSinceEpoch / 1000).floor();
final valid = now - parsed['exp'] < 0;
if (!valid) {
log(
'auth',
'InvalidJWTToken: Invalid value for JWT claim "exp" with value ${parsed['exp']}',
null,
Level.FINE,
);
throw 'InvalidJWTToken: Invalid value for JWT claim "exp" with value ${parsed['exp']}';
}
}
}

accessToken = tokenToSend;

for (final channel in channels) {
if (token != null) {
channel.updateJoinPayload({'access_token': token});
if (tokenToSend != null) {
channel.updateJoinPayload({'access_token': tokenToSend});
}
if (channel.joinedOnce && channel.isJoined) {
channel.push(ChannelEvents.accessToken, {'access_token': token});
channel.push(ChannelEvents.accessToken, {'access_token': tokenToSend});
}
}
}
Expand Down
1 change: 1 addition & 0 deletions packages/realtime_client/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ dev_dependencies:
lints: ^3.0.0
mocktail: ^1.0.0
test: ^1.16.5
crypto: ^3.0.6
91 changes: 88 additions & 3 deletions packages/realtime_client/test/socket_test.dart
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import 'dart:convert';
import 'dart:io';

import 'package:crypto/crypto.dart';
import 'package:mocktail/mocktail.dart';
import 'package:realtime_client/realtime_client.dart';
import 'package:realtime_client/src/constants.dart';
Expand All @@ -16,6 +17,31 @@ typedef WebSocketChannelClosure = WebSocketChannel Function(
Map<String, String> headers,
);

/// Generate a JWT token for testing purposes
///
/// [exp] in seconds since Epoch
String generateJwt([int? exp]) {
final header = {'alg': 'HS256', 'typ': 'JWT'};

final now = DateTime.now();
final expiry = exp ??
(now.add(Duration(hours: 1)).millisecondsSinceEpoch / 1000).floor();

final payload = {'exp': expiry};

final key = 'your-256-bit-secret';

final encodedHeader = base64Url.encode(utf8.encode(json.encode(header)));
final encodedPayload = base64Url.encode(utf8.encode(json.encode(payload)));

final signatureInput = '$encodedHeader.$encodedPayload';
final hmac = Hmac(sha256, utf8.encode(key));
final digest = hmac.convert(utf8.encode(signatureInput));
final signature = base64Url.encode(digest.bytes);

return '$encodedHeader.$encodedPayload.$signature';
}

void main() {
const int int64MaxValue = 9223372036854775807;

Expand Down Expand Up @@ -427,8 +453,9 @@ void main() {
});

group('setAuth', () {
final updateJoinPayload = {'access_token': 'token123'};
final pushPayload = {'access_token': 'token123'};
final token = generateJwt();
final updateJoinPayload = {'access_token': token};
final pushPayload = {'access_token': token};

test(
"sets access token, updates channels' join payload, and pushes token to channels",
Expand Down Expand Up @@ -457,7 +484,9 @@ void main() {
final channel1 = mockedSocket.channel(tTopic1);
final channel2 = mockedSocket.channel(tTopic2);

mockedSocket.setAuth('token123');
mockedSocket.setAuth(token);

expect(mockedSocket.accessToken, token);

verify(() => channel1.updateJoinPayload(updateJoinPayload)).called(1);
verify(() => channel2.updateJoinPayload(updateJoinPayload)).called(1);
Expand All @@ -466,6 +495,62 @@ void main() {
verify(() => channel2.push(ChannelEvents.accessToken, pushPayload))
.called(1);
});

test(
"sets access token, updates channels' join payload, and pushes token to channels if is not a jwt",
() {
final mockedChannel1 = MockChannel();
final mockedChannel2 = MockChannel();
final mockedChannel3 = MockChannel();

when(() => mockedChannel1.joinedOnce).thenReturn(true);
when(() => mockedChannel1.isJoined).thenReturn(true);
when(() => mockedChannel1.push(ChannelEvents.accessToken, any()))
.thenReturn(MockPush());

when(() => mockedChannel2.joinedOnce).thenReturn(false);
when(() => mockedChannel2.isJoined).thenReturn(false);
when(() => mockedChannel2.push(ChannelEvents.accessToken, any()))
.thenReturn(MockPush());

when(() => mockedChannel3.joinedOnce).thenReturn(true);
when(() => mockedChannel3.isJoined).thenReturn(true);
when(() => mockedChannel3.push(ChannelEvents.accessToken, any()))
.thenReturn(MockPush());

const tTopic1 = 'test-topic1';
const tTopic2 = 'test-topic2';
const tTopic3 = 'test-topic3';

final mockedSocket = SocketWithMockedChannel(socketEndpoint);
mockedSocket.mockedChannelLooker.addAll(<String, RealtimeChannel>{
tTopic1: mockedChannel1,
tTopic2: mockedChannel2,
tTopic3: mockedChannel3,
});

final channel1 = mockedSocket.channel(tTopic1);
final channel2 = mockedSocket.channel(tTopic2);
final channel3 = mockedSocket.channel(tTopic3);

const token = 'sb-key';
final pushPayload = {'access_token': token};
final updateJoinPayload = {'access_token': token};

mockedSocket.setAuth(token);

expect(mockedSocket.accessToken, token);

verify(() => channel1.updateJoinPayload(updateJoinPayload)).called(1);
verify(() => channel2.updateJoinPayload(updateJoinPayload)).called(1);
verify(() => channel3.updateJoinPayload(updateJoinPayload)).called(1);

verify(() => channel1.push(ChannelEvents.accessToken, pushPayload))
.called(1);
verifyNever(() => channel2.push(ChannelEvents.accessToken, pushPayload));
verify(() => channel3.push(ChannelEvents.accessToken, pushPayload))
.called(1);
});
});

group('sendHeartbeat', () {
Expand Down

0 comments on commit 17bf76f

Please sign in to comment.