Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(realtime_client): Prevent sending expired tokens #1095

Merged
merged 7 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions packages/realtime_client/lib/src/realtime_channel.dart
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,11 @@ class RealtimeChannel {

joinPush.receive(
'ok',
(response) {
(response) async {
final serverPostgresFilters = response['postgres_changes'];
if (socket.accessToken != null) socket.setAuth(socket.accessToken);
if (socket.accessToken != null) {
await socket.setAuth(socket.accessToken);
}

if (serverPostgresFilters == null) {
if (callback != null) {
Expand Down
48 changes: 40 additions & 8 deletions packages/realtime_client/lib/src/realtime_client.dart
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class RealtimeCloseEvent {
}

class RealtimeClient {
// This is named `accessTokenValue` in supabase-js
String? accessToken;
List<RealtimeChannel> channels = [];
final String endPoint;
Expand Down Expand Up @@ -89,6 +90,8 @@ class RealtimeClient {
};
int longpollerTimeout = 20000;
SocketStates? connState;
// This is called `accessToken` in realtime-js
Future<String> Function()? customAccessToken;
dshukertjr marked this conversation as resolved.
Show resolved Hide resolved

/// Initializes the Socket
///
Expand Down Expand Up @@ -129,6 +132,7 @@ class RealtimeClient {
this.longpollerTimeout = 20000,
RealtimeLogLevel? logLevel,
this.httpClient,
this.customAccessToken,
}) : endPoint = Uri.parse('$endPoint/${Transports.websocket}')
.replace(
queryParameters:
Expand Down Expand Up @@ -403,15 +407,43 @@ class RealtimeClient {
/// Sets the JWT access token used for channel subscription authorization and Realtime RLS.
///
/// `token` A JWT strings.
void setAuth(String? token) {
accessToken = token;
Future<void> setAuth(String? token) async {
dshukertjr marked this conversation as resolved.
Show resolved Hide resolved
final tokenToSend =
token ?? (await customAccessToken?.call()) ?? accessToken;

if (tokenToSend != null) {
Map<String, dynamic>? parsed;
try {
final decoded =
base64.decode(base64.normalize(tokenToSend.split('.')[1]));
parsed = json.decode(utf8.decode(decoded));
} catch (e) {
// ignore parsing errors
}
if (parsed != null && parsed['exp'] != null) {
final now = (DateTime.now().millisecondsSinceEpoch / 1000).floor();
final valid = now - parsed['exp'] < 0;
if (!valid) {
log(
'auth',
'InvalidJWTToken: Invalid value for JWT claim "exp" with value ${parsed['exp']}',
null,
Level.FINE,
);
throw FormatException(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the js client doesn't throw an exception, but logs the issue and then fails silently. I think it's better to not throw, because setAuth is used in other places, where they might re-set the same access token which is then expired, but it shouldn't throw in those cases. Or we add try catch to those cases.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great catch! I forgot to add, but the way js client handles the expired JWT changes in another PR though. supabase/realtime-js#439

'InvalidJWTToken: Invalid value for JWT claim "exp" with value ${parsed['exp']}');
}
}
}

accessToken = tokenToSend;

for (final channel in channels) {
if (token != null) {
channel.updateJoinPayload({'access_token': token});
if (tokenToSend != null) {
channel.updateJoinPayload({'access_token': tokenToSend});
}
if (channel.joinedOnce && channel.isJoined) {
channel.push(ChannelEvents.accessToken, {'access_token': token});
channel.push(ChannelEvents.accessToken, {'access_token': tokenToSend});
}
}
}
Expand All @@ -436,7 +468,7 @@ class RealtimeClient {
if (heartbeatTimer != null) heartbeatTimer!.cancel();
heartbeatTimer = Timer.periodic(
Duration(milliseconds: heartbeatIntervalMs),
(Timer t) => sendHeartbeat(),
(Timer t) async => await sendHeartbeat(),
);
for (final callback in stateChangeCallbacks['open']!) {
callback();
Expand Down Expand Up @@ -502,7 +534,7 @@ class RealtimeClient {
}

@internal
void sendHeartbeat() {
Future<void> sendHeartbeat() async {
if (!isConnected) {
return;
}
Expand All @@ -524,6 +556,6 @@ class RealtimeClient {
payload: {},
ref: pendingHeartbeatRef!,
));
setAuth(accessToken);
await setAuth(accessToken);
}
}
1 change: 1 addition & 0 deletions packages/realtime_client/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ dev_dependencies:
lints: ^3.0.0
mocktail: ^1.0.0
test: ^1.16.5
crypto: ^3.0.0
103 changes: 94 additions & 9 deletions packages/realtime_client/test/socket_test.dart
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import 'dart:convert';
import 'dart:io';

import 'package:crypto/crypto.dart';
import 'package:mocktail/mocktail.dart';
import 'package:realtime_client/realtime_client.dart';
import 'package:realtime_client/src/constants.dart';
Expand All @@ -16,6 +17,31 @@ typedef WebSocketChannelClosure = WebSocketChannel Function(
Map<String, String> headers,
);

/// Generate a JWT token for testing purposes
///
/// [exp] in seconds since Epoch
String generateJwt([int? exp]) {
final header = {'alg': 'HS256', 'typ': 'JWT'};

final now = DateTime.now();
final expiry = exp ??
(now.add(Duration(hours: 1)).millisecondsSinceEpoch / 1000).floor();

final payload = {'exp': expiry};

final key = 'your-256-bit-secret';

final encodedHeader = base64Url.encode(utf8.encode(json.encode(header)));
final encodedPayload = base64Url.encode(utf8.encode(json.encode(payload)));

final signatureInput = '$encodedHeader.$encodedPayload';
final hmac = Hmac(sha256, utf8.encode(key));
final digest = hmac.convert(utf8.encode(signatureInput));
final signature = base64Url.encode(digest.bytes);

return '$encodedHeader.$encodedPayload.$signature';
}

void main() {
const int int64MaxValue = 9223372036854775807;

Expand Down Expand Up @@ -174,7 +200,7 @@ void main() {
await Future.delayed(const Duration(milliseconds: 200));
expect(opens, 1);

socket.sendHeartbeat();
await socket.sendHeartbeat();
// need to wait for event to trigger
await Future.delayed(const Duration(seconds: 1));
expect(lastMsg['event'], 'heartbeat');
Expand Down Expand Up @@ -427,12 +453,13 @@ void main() {
});

group('setAuth', () {
final updateJoinPayload = {'access_token': 'token123'};
final pushPayload = {'access_token': 'token123'};
final token = generateJwt();
final updateJoinPayload = {'access_token': token};
final pushPayload = {'access_token': token};

test(
"sets access token, updates channels' join payload, and pushes token to channels",
() {
() async {
final mockedChannel1 = MockChannel();
when(() => mockedChannel1.joinedOnce).thenReturn(true);
when(() => mockedChannel1.isJoined).thenReturn(true);
Expand All @@ -457,7 +484,9 @@ void main() {
final channel1 = mockedSocket.channel(tTopic1);
final channel2 = mockedSocket.channel(tTopic2);

mockedSocket.setAuth('token123');
await mockedSocket.setAuth(token);

expect(mockedSocket.accessToken, token);

verify(() => channel1.updateJoinPayload(updateJoinPayload)).called(1);
verify(() => channel2.updateJoinPayload(updateJoinPayload)).called(1);
Expand All @@ -466,6 +495,62 @@ void main() {
verify(() => channel2.push(ChannelEvents.accessToken, pushPayload))
.called(1);
});

test(
"sets access token, updates channels' join payload, and pushes token to channels if is not a jwt",
() async {
final mockedChannel1 = MockChannel();
final mockedChannel2 = MockChannel();
final mockedChannel3 = MockChannel();

when(() => mockedChannel1.joinedOnce).thenReturn(true);
when(() => mockedChannel1.isJoined).thenReturn(true);
when(() => mockedChannel1.push(ChannelEvents.accessToken, any()))
.thenReturn(MockPush());

when(() => mockedChannel2.joinedOnce).thenReturn(false);
when(() => mockedChannel2.isJoined).thenReturn(false);
when(() => mockedChannel2.push(ChannelEvents.accessToken, any()))
.thenReturn(MockPush());

when(() => mockedChannel3.joinedOnce).thenReturn(true);
when(() => mockedChannel3.isJoined).thenReturn(true);
when(() => mockedChannel3.push(ChannelEvents.accessToken, any()))
.thenReturn(MockPush());

const tTopic1 = 'test-topic1';
const tTopic2 = 'test-topic2';
const tTopic3 = 'test-topic3';

final mockedSocket = SocketWithMockedChannel(socketEndpoint);
mockedSocket.mockedChannelLooker.addAll(<String, RealtimeChannel>{
tTopic1: mockedChannel1,
tTopic2: mockedChannel2,
tTopic3: mockedChannel3,
});

final channel1 = mockedSocket.channel(tTopic1);
final channel2 = mockedSocket.channel(tTopic2);
final channel3 = mockedSocket.channel(tTopic3);

const token = 'sb-key';
final pushPayload = {'access_token': token};
final updateJoinPayload = {'access_token': token};

await mockedSocket.setAuth(token);

expect(mockedSocket.accessToken, token);

verify(() => channel1.updateJoinPayload(updateJoinPayload)).called(1);
verify(() => channel2.updateJoinPayload(updateJoinPayload)).called(1);
verify(() => channel3.updateJoinPayload(updateJoinPayload)).called(1);

verify(() => channel1.push(ChannelEvents.accessToken, pushPayload))
.called(1);
verifyNever(() => channel2.push(ChannelEvents.accessToken, pushPayload));
verify(() => channel3.push(ChannelEvents.accessToken, pushPayload))
.called(1);
});
});

group('sendHeartbeat', () {
Expand Down Expand Up @@ -496,18 +581,18 @@ void main() {

//! Unimplemented Test: closes socket when heartbeat is not ack'd within heartbeat window

test('pushes heartbeat data when connected', () {
test('pushes heartbeat data when connected', () async {
mockedSocket.connState = SocketStates.open;

mockedSocket.sendHeartbeat();
await mockedSocket.sendHeartbeat();

verify(() => mockedSink.add(captureAny(that: equals(data)))).called(1);
});

test('no ops when not connected', () {
test('no ops when not connected', () async {
mockedSocket.connState = SocketStates.connecting;

mockedSocket.sendHeartbeat();
await mockedSocket.sendHeartbeat();
verifyNever(() => mockedSink.add(any()));
});
});
Expand Down
21 changes: 16 additions & 5 deletions packages/supabase/lib/src/supabase_client.dart
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ class SupabaseClient {
logLevel: options.logLevel,
httpClient: _authHttpClient,
timeout: options.timeout ?? RealtimeConstants.defaultTimeout,
customAccessToken: accessToken,
);
}

Expand All @@ -349,22 +350,32 @@ class SupabaseClient {
void _listenForAuthEvents() {
// ignore: invalid_use_of_internal_member
_authStateSubscription = auth.onAuthStateChangeSync.listen(
(data) {
_handleTokenChanged(data.event, data.session?.accessToken);
(data) async {
await _handleTokenChanged(data.event, data.session?.accessToken);
},
onError: (error, stack) {},
);
}

void _handleTokenChanged(AuthChangeEvent event, String? token) {
Future<void> _handleTokenChanged(AuthChangeEvent event, String? token) async {
if (event == AuthChangeEvent.initialSession ||
event == AuthChangeEvent.tokenRefreshed ||
event == AuthChangeEvent.signedIn) {
realtime.setAuth(token);
try {
await realtime.setAuth(token);
} on FormatException catch (e) {
if (e.message.contains('InvalidJWTToken')) {
// The exception is thrown by RealtimeClient when the token is
// expired for example on app launch after the app has been closed
// for a while.
} else {
rethrow;
}
}
} else if (event == AuthChangeEvent.signedOut) {
// Token is removed

realtime.setAuth(_supabaseKey);
await realtime.setAuth(_supabaseKey);
}
}
}
Loading