diff --git a/packages/realtime_client/lib/src/realtime_channel.dart b/packages/realtime_client/lib/src/realtime_channel.dart index 7c37d800..02c7aa86 100644 --- a/packages/realtime_client/lib/src/realtime_channel.dart +++ b/packages/realtime_client/lib/src/realtime_channel.dart @@ -150,9 +150,11 @@ class RealtimeChannel { joinPush.receive( 'ok', - (response) { + (response) async { final serverPostgresFilters = response['postgres_changes']; - if (socket.accessToken != null) socket.setAuth(socket.accessToken); + if (socket.accessToken != null) { + await socket.setAuth(socket.accessToken); + } if (serverPostgresFilters == null) { if (callback != null) { diff --git a/packages/realtime_client/lib/src/realtime_client.dart b/packages/realtime_client/lib/src/realtime_client.dart index f5e6f5fa..3ecfb612 100644 --- a/packages/realtime_client/lib/src/realtime_client.dart +++ b/packages/realtime_client/lib/src/realtime_client.dart @@ -54,6 +54,7 @@ class RealtimeCloseEvent { } class RealtimeClient { + // This is named `accessTokenValue` in supabase-js String? accessToken; List channels = []; final String endPoint; @@ -89,6 +90,8 @@ class RealtimeClient { }; int longpollerTimeout = 20000; SocketStates? connState; + // This is called `accessToken` in realtime-js + Future Function()? customAccessToken; /// Initializes the Socket /// @@ -129,6 +132,7 @@ class RealtimeClient { this.longpollerTimeout = 20000, RealtimeLogLevel? logLevel, this.httpClient, + this.customAccessToken, }) : endPoint = Uri.parse('$endPoint/${Transports.websocket}') .replace( queryParameters: @@ -403,15 +407,43 @@ class RealtimeClient { /// Sets the JWT access token used for channel subscription authorization and Realtime RLS. /// /// `token` A JWT strings. - void setAuth(String? token) { - accessToken = token; + Future setAuth(String? token) async { + final tokenToSend = + token ?? (await customAccessToken?.call()) ?? accessToken; + + if (tokenToSend != null) { + Map? parsed; + try { + final decoded = + base64.decode(base64.normalize(tokenToSend.split('.')[1])); + parsed = json.decode(utf8.decode(decoded)); + } catch (e) { + // ignore parsing errors + } + if (parsed != null && parsed['exp'] != null) { + final now = (DateTime.now().millisecondsSinceEpoch / 1000).floor(); + final valid = now - parsed['exp'] < 0; + if (!valid) { + log( + 'auth', + 'InvalidJWTToken: Invalid value for JWT claim "exp" with value ${parsed['exp']}', + null, + Level.FINE, + ); + throw FormatException( + 'InvalidJWTToken: Invalid value for JWT claim "exp" with value ${parsed['exp']}'); + } + } + } + + accessToken = tokenToSend; for (final channel in channels) { - if (token != null) { - channel.updateJoinPayload({'access_token': token}); + if (tokenToSend != null) { + channel.updateJoinPayload({'access_token': tokenToSend}); } if (channel.joinedOnce && channel.isJoined) { - channel.push(ChannelEvents.accessToken, {'access_token': token}); + channel.push(ChannelEvents.accessToken, {'access_token': tokenToSend}); } } } @@ -436,7 +468,7 @@ class RealtimeClient { if (heartbeatTimer != null) heartbeatTimer!.cancel(); heartbeatTimer = Timer.periodic( Duration(milliseconds: heartbeatIntervalMs), - (Timer t) => sendHeartbeat(), + (Timer t) async => await sendHeartbeat(), ); for (final callback in stateChangeCallbacks['open']!) { callback(); @@ -502,7 +534,7 @@ class RealtimeClient { } @internal - void sendHeartbeat() { + Future sendHeartbeat() async { if (!isConnected) { return; } @@ -524,6 +556,6 @@ class RealtimeClient { payload: {}, ref: pendingHeartbeatRef!, )); - setAuth(accessToken); + await setAuth(accessToken); } } diff --git a/packages/realtime_client/pubspec.yaml b/packages/realtime_client/pubspec.yaml index 8c7ef8cd..6ef77e14 100644 --- a/packages/realtime_client/pubspec.yaml +++ b/packages/realtime_client/pubspec.yaml @@ -19,3 +19,4 @@ dev_dependencies: lints: ^3.0.0 mocktail: ^1.0.0 test: ^1.16.5 + crypto: ^3.0.0 diff --git a/packages/realtime_client/test/socket_test.dart b/packages/realtime_client/test/socket_test.dart index 79fe1306..55c50914 100644 --- a/packages/realtime_client/test/socket_test.dart +++ b/packages/realtime_client/test/socket_test.dart @@ -1,6 +1,7 @@ import 'dart:convert'; import 'dart:io'; +import 'package:crypto/crypto.dart'; import 'package:mocktail/mocktail.dart'; import 'package:realtime_client/realtime_client.dart'; import 'package:realtime_client/src/constants.dart'; @@ -16,6 +17,31 @@ typedef WebSocketChannelClosure = WebSocketChannel Function( Map headers, ); +/// Generate a JWT token for testing purposes +/// +/// [exp] in seconds since Epoch +String generateJwt([int? exp]) { + final header = {'alg': 'HS256', 'typ': 'JWT'}; + + final now = DateTime.now(); + final expiry = exp ?? + (now.add(Duration(hours: 1)).millisecondsSinceEpoch / 1000).floor(); + + final payload = {'exp': expiry}; + + final key = 'your-256-bit-secret'; + + final encodedHeader = base64Url.encode(utf8.encode(json.encode(header))); + final encodedPayload = base64Url.encode(utf8.encode(json.encode(payload))); + + final signatureInput = '$encodedHeader.$encodedPayload'; + final hmac = Hmac(sha256, utf8.encode(key)); + final digest = hmac.convert(utf8.encode(signatureInput)); + final signature = base64Url.encode(digest.bytes); + + return '$encodedHeader.$encodedPayload.$signature'; +} + void main() { const int int64MaxValue = 9223372036854775807; @@ -174,7 +200,7 @@ void main() { await Future.delayed(const Duration(milliseconds: 200)); expect(opens, 1); - socket.sendHeartbeat(); + await socket.sendHeartbeat(); // need to wait for event to trigger await Future.delayed(const Duration(seconds: 1)); expect(lastMsg['event'], 'heartbeat'); @@ -427,12 +453,13 @@ void main() { }); group('setAuth', () { - final updateJoinPayload = {'access_token': 'token123'}; - final pushPayload = {'access_token': 'token123'}; + final token = generateJwt(); + final updateJoinPayload = {'access_token': token}; + final pushPayload = {'access_token': token}; test( "sets access token, updates channels' join payload, and pushes token to channels", - () { + () async { final mockedChannel1 = MockChannel(); when(() => mockedChannel1.joinedOnce).thenReturn(true); when(() => mockedChannel1.isJoined).thenReturn(true); @@ -457,7 +484,9 @@ void main() { final channel1 = mockedSocket.channel(tTopic1); final channel2 = mockedSocket.channel(tTopic2); - mockedSocket.setAuth('token123'); + await mockedSocket.setAuth(token); + + expect(mockedSocket.accessToken, token); verify(() => channel1.updateJoinPayload(updateJoinPayload)).called(1); verify(() => channel2.updateJoinPayload(updateJoinPayload)).called(1); @@ -466,6 +495,62 @@ void main() { verify(() => channel2.push(ChannelEvents.accessToken, pushPayload)) .called(1); }); + + test( + "sets access token, updates channels' join payload, and pushes token to channels if is not a jwt", + () async { + final mockedChannel1 = MockChannel(); + final mockedChannel2 = MockChannel(); + final mockedChannel3 = MockChannel(); + + when(() => mockedChannel1.joinedOnce).thenReturn(true); + when(() => mockedChannel1.isJoined).thenReturn(true); + when(() => mockedChannel1.push(ChannelEvents.accessToken, any())) + .thenReturn(MockPush()); + + when(() => mockedChannel2.joinedOnce).thenReturn(false); + when(() => mockedChannel2.isJoined).thenReturn(false); + when(() => mockedChannel2.push(ChannelEvents.accessToken, any())) + .thenReturn(MockPush()); + + when(() => mockedChannel3.joinedOnce).thenReturn(true); + when(() => mockedChannel3.isJoined).thenReturn(true); + when(() => mockedChannel3.push(ChannelEvents.accessToken, any())) + .thenReturn(MockPush()); + + const tTopic1 = 'test-topic1'; + const tTopic2 = 'test-topic2'; + const tTopic3 = 'test-topic3'; + + final mockedSocket = SocketWithMockedChannel(socketEndpoint); + mockedSocket.mockedChannelLooker.addAll({ + tTopic1: mockedChannel1, + tTopic2: mockedChannel2, + tTopic3: mockedChannel3, + }); + + final channel1 = mockedSocket.channel(tTopic1); + final channel2 = mockedSocket.channel(tTopic2); + final channel3 = mockedSocket.channel(tTopic3); + + const token = 'sb-key'; + final pushPayload = {'access_token': token}; + final updateJoinPayload = {'access_token': token}; + + await mockedSocket.setAuth(token); + + expect(mockedSocket.accessToken, token); + + verify(() => channel1.updateJoinPayload(updateJoinPayload)).called(1); + verify(() => channel2.updateJoinPayload(updateJoinPayload)).called(1); + verify(() => channel3.updateJoinPayload(updateJoinPayload)).called(1); + + verify(() => channel1.push(ChannelEvents.accessToken, pushPayload)) + .called(1); + verifyNever(() => channel2.push(ChannelEvents.accessToken, pushPayload)); + verify(() => channel3.push(ChannelEvents.accessToken, pushPayload)) + .called(1); + }); }); group('sendHeartbeat', () { @@ -496,18 +581,18 @@ void main() { //! Unimplemented Test: closes socket when heartbeat is not ack'd within heartbeat window - test('pushes heartbeat data when connected', () { + test('pushes heartbeat data when connected', () async { mockedSocket.connState = SocketStates.open; - mockedSocket.sendHeartbeat(); + await mockedSocket.sendHeartbeat(); verify(() => mockedSink.add(captureAny(that: equals(data)))).called(1); }); - test('no ops when not connected', () { + test('no ops when not connected', () async { mockedSocket.connState = SocketStates.connecting; - mockedSocket.sendHeartbeat(); + await mockedSocket.sendHeartbeat(); verifyNever(() => mockedSink.add(any())); }); }); diff --git a/packages/supabase/lib/src/supabase_client.dart b/packages/supabase/lib/src/supabase_client.dart index 4f500b8b..4a7d52de 100644 --- a/packages/supabase/lib/src/supabase_client.dart +++ b/packages/supabase/lib/src/supabase_client.dart @@ -332,6 +332,7 @@ class SupabaseClient { logLevel: options.logLevel, httpClient: _authHttpClient, timeout: options.timeout ?? RealtimeConstants.defaultTimeout, + customAccessToken: accessToken, ); } @@ -349,22 +350,32 @@ class SupabaseClient { void _listenForAuthEvents() { // ignore: invalid_use_of_internal_member _authStateSubscription = auth.onAuthStateChangeSync.listen( - (data) { - _handleTokenChanged(data.event, data.session?.accessToken); + (data) async { + await _handleTokenChanged(data.event, data.session?.accessToken); }, onError: (error, stack) {}, ); } - void _handleTokenChanged(AuthChangeEvent event, String? token) { + Future _handleTokenChanged(AuthChangeEvent event, String? token) async { if (event == AuthChangeEvent.initialSession || event == AuthChangeEvent.tokenRefreshed || event == AuthChangeEvent.signedIn) { - realtime.setAuth(token); + try { + await realtime.setAuth(token); + } on FormatException catch (e) { + if (e.message.contains('InvalidJWTToken')) { + // The exception is thrown by RealtimeClient when the token is + // expired for example on app launch after the app has been closed + // for a while. + } else { + rethrow; + } + } } else if (event == AuthChangeEvent.signedOut) { // Token is removed - realtime.setAuth(_supabaseKey); + await realtime.setAuth(_supabaseKey); } } }