Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: move App Sync subscription headers to protocol #5301

Merged
merged 3 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,18 @@ const _requiredHeaders = {
AWSHeaders.contentType: 'application/json; charset=utf-8',
};

// AppSync expects "{}" encoded in the URI as the payload during handshake.
const _emptyBody = <String, dynamic>{};
/// The default payload to include to AppSync.
///
/// AppSync expects "{}" encoded in the URI as the payload during handshake.
@internal
const appSyncDefaultPayload = <String, dynamic>{};

/// Generate a URI for the connection and all subscriptions.
///
/// See https://docs.aws.amazon.com/appsync/latest/devguide/real-time-websocket-client.html#handshake-details-to-establish-the-websocket-connection=
Future<Uri> generateConnectionUri(
ApiOutputs config,
AmplifyAuthProviderRepository authRepo,
) async {
// First, generate auth query parameters.
final authorizationHeaders = await _generateAuthorizationHeaders(
config,
isConnectionInit: true,
authRepo: authRepo,
body: _emptyBody,
);
final encodedAuthHeaders =
base64.encode(json.encode(authorizationHeaders).codeUnits);
Future<Uri> generateConnectionUri(ApiOutputs config) async {
final authQueryParameters = {
'header': encodedAuthHeaders,
'payload': base64.encode(utf8.encode(json.encode(_emptyBody))),
'payload': base64.encode(utf8.encode(json.encode(appSyncDefaultPayload))),
};
// Conditionally format the URI for a) AppSync domain b) custom domain.
var endpointUriHost = Uri.parse(config.url).host;
Expand Down Expand Up @@ -86,7 +76,7 @@ Future<WebSocketSubscriptionRegistrationMessage>
required GraphQLRequest<T> request,
}) async {
final body = {'variables': request.variables, 'query': request.document};
final authorizationHeaders = await _generateAuthorizationHeaders(
final authorizationHeaders = await generateAuthorizationHeaders(
config,
isConnectionInit: false,
authRepo: authRepo,
Expand Down Expand Up @@ -114,7 +104,8 @@ Future<WebSocketSubscriptionRegistrationMessage>
/// a canonical HTTP request that is authorized but never sent. The headers from
/// the HTTP request are reformatted and returned. This logic applies for all auth
/// modes as determined by [authRepo] parameter.
Future<Map<String, String>> _generateAuthorizationHeaders(
@internal
Future<Map<String, String>> generateAuthorizationHeaders(
ApiOutputs config, {
required bool isConnectionInit,
required AmplifyAuthProviderRepository authRepo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import 'package:amplify_api_dart/src/graphql/web_socket/types/subscriptions_even
import 'package:amplify_api_dart/src/graphql/web_socket/types/web_socket_message_stream_transformer.dart';
import 'package:amplify_api_dart/src/graphql/web_socket/types/web_socket_types.dart';
import 'package:amplify_core/amplify_core.dart';
// ignore: implementation_imports
import 'package:amplify_core/src/config/amplify_outputs/api_outputs.dart';
import 'package:async/async.dart';
import 'package:meta/meta.dart';
import 'package:stream_transform/stream_transform.dart';
Expand Down Expand Up @@ -72,15 +74,14 @@ class AmplifyWebSocketService
);

try {
const webSocketProtocols = ['graphql-ws'];
final connectionUri = await generateConnectionUri(
final protocols = await generateProtocols(
state.config,
state.authProviderRepo,
);

final connectionUri = await generateConnectionUri(state.config);
final channel = WebSocketChannel.connect(
connectionUri,
protocols: webSocketProtocols,
protocols: protocols,
);
sink = channel.sink;

Expand All @@ -95,6 +96,28 @@ class AmplifyWebSocketService
}
}

/// Generates a list of protocols from a [WebSocketState].
Jordan-Nelson marked this conversation as resolved.
Show resolved Hide resolved
@visibleForTesting
Future<List<String>> generateProtocols(
ApiOutputs outputs,
AmplifyAuthProviderRepository authRepo,
) async {
final authorizationHeaders = await generateAuthorizationHeaders(
outputs,
isConnectionInit: true,
authRepo: authRepo,
body: appSyncDefaultPayload,
);
final encodedAuthHeaders = base64Url
.encode(json.encode(authorizationHeaders).codeUnits)
// remove padding char ("=") as it is optional in base64Url encoding and
// is not permitted in protocol names.
// Base64Url Spec: https://datatracker.ietf.org/doc/html/rfc4648#section-5
// Protocol name separators: https://www.rfc-editor.org/rfc/rfc2616 (see "separators")
.replaceAll('=', '');
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: Is it safe to remove all = or should we target only the last one?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is safe to remove all. base64url uses = as padding only. Padding is optional according to the base64url spec, but dart includes it. We want to remove it since it is not allowed in a protocol name. I pushed a comment to explain that.

return ['graphql-ws', 'header-$encodedAuthHeaders'];
}

@override
Future<void> register(
ConnectedState state,
Expand Down
25 changes: 23 additions & 2 deletions packages/api/amplify_api_dart/test/util.dart
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ const testApiKeyConfigCustomDomain = DataOutputs(
);

const expectedApiKeyWebSocketConnectionUrl =
'wss://abc123.appsync-realtime-api.us-east-1.amazonaws.com/graphql?header=eyJBY2NlcHQiOiJhcHBsaWNhdGlvbi9qc29uLCB0ZXh0L2phdmFzY3JpcHQiLCJDb250ZW50LUVuY29kaW5nIjoiYW16LTEuMCIsIkNvbnRlbnQtVHlwZSI6ImFwcGxpY2F0aW9uL2pzb247IGNoYXJzZXQ9dXRmLTgiLCJYLUFwaS1LZXkiOiJhYmMtMTIzIiwiSG9zdCI6ImFiYzEyMy5hcHBzeW5jLWFwaS51cy1lYXN0LTEuYW1hem9uYXdzLmNvbSJ9&payload=e30%3D';
'wss://abc123.appsync-realtime-api.us-east-1.amazonaws.com/graphql?payload=e30%3D';
const expectedApiKeyWebSocketConnectionUrlCustomDomain =
'wss://foo.bar.aws.dev/graphql/realtime?header=eyJBY2NlcHQiOiJhcHBsaWNhdGlvbi9qc29uLCB0ZXh0L2phdmFzY3JpcHQiLCJDb250ZW50LUVuY29kaW5nIjoiYW16LTEuMCIsIkNvbnRlbnQtVHlwZSI6ImFwcGxpY2F0aW9uL2pzb247IGNoYXJzZXQ9dXRmLTgiLCJYLUFwaS1LZXkiOiJhYmMtMTIzIiwiSG9zdCI6ImZvby5iYXIuYXdzLmRldiJ9&payload=e30%3D';
'wss://foo.bar.aws.dev/graphql/realtime?payload=e30%3D';

AmplifyAuthProviderRepository getTestAuthProviderRepo() {
final testAuthProviderRepo = AmplifyAuthProviderRepository()
Expand Down Expand Up @@ -341,3 +341,24 @@ void testQueryPredicateTranslation(
}

final deepEquals = const DeepCollectionEquality().equals;

/// Creates [DataOutputs] and [AmplifyAuthProviderRepository] for use in tests.
(DataOutputs, AmplifyAuthProviderRepository) createOutputsAndRepo(
AmplifyAuthProvider authProvider,
APIAuthorizationType type, [
String? apiKey,
]) {
final repo = AmplifyAuthProviderRepository()
..registerAuthProvider(
type.authProviderToken,
authProvider,
);
final outputs = DataOutputs(
awsRegion: 'us-east-1',
url: 'https://example.com/',
defaultAuthorizationType: type,
authorizationTypes: [type],
apiKey: type == APIAuthorizationType.apiKey ? apiKey : null,
);
return (outputs, repo);
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,17 @@ void main() {
}

group('generateConnectionUri', () {
test('should generate authorized connection URI', () async {
final actualConnectionUri =
await generateConnectionUri(testApiKeyConfig, authProviderRepo);
test('should generate connection URI', () async {
final actualConnectionUri = await generateConnectionUri(testApiKeyConfig);
expect(
actualConnectionUri.toString(),
expectedApiKeyWebSocketConnectionUrl,
);
});

test('should generate authorized connection URI with a custom domain',
() async {
test('should generate connection URI with a custom domain', () async {
final actualConnectionUri = await generateConnectionUri(
testApiKeyConfigCustomDomain,
authProviderRepo,
);
expect(
actualConnectionUri.toString(),
Expand Down Expand Up @@ -141,4 +138,68 @@ void main() {
);
});
});

group('generateAuthorizationHeaders', () {
const apiKey = 'fake-key';

test('should generate headers for API key Authorization', () async {
final (outputs, repo) = createOutputsAndRepo(
AppSyncApiKeyAuthProvider(),
APIAuthorizationType.apiKey,
apiKey,
);
final headers = await generateAuthorizationHeaders(
outputs,
isConnectionInit: true,
authRepo: repo,
body: {},
);
expect(headers[xApiKey], apiKey);
expect(headers.containsKey(AWSHeaders.accept), true);
expect(headers.containsKey(AWSHeaders.contentEncoding), true);
expect(headers.containsKey(AWSHeaders.contentType), true);
expect(headers.containsKey(AWSHeaders.host), true);
});

test('should generate headers for IAM Authorization', () async {
final (outputs, repo) = createOutputsAndRepo(
TestIamAuthProvider(),
APIAuthorizationType.iam,
);
final headers = await generateAuthorizationHeaders(
outputs,
isConnectionInit: true,
authRepo: repo,
body: {},
);
expect(
headers['Authorization']!.contains('Credential=fake-access-key-123'),
true,
);
expect(headers.containsKey(AWSHeaders.date), true);
expect(headers.containsKey(AWSHeaders.contentSHA256), true);
expect(headers.containsKey(AWSHeaders.accept), true);
expect(headers.containsKey(AWSHeaders.contentEncoding), true);
expect(headers.containsKey(AWSHeaders.contentType), true);
expect(headers.containsKey(AWSHeaders.host), true);
});

test('should generate headers for user pool Authorization', () async {
final (outputs, repo) = createOutputsAndRepo(
TestTokenAuthProvider(),
APIAuthorizationType.userPools,
);
final headers = await generateAuthorizationHeaders(
outputs,
isConnectionInit: true,
authRepo: repo,
body: {},
);
expect(headers[AWSHeaders.authorization], 'test-access-token-123');
expect(headers.containsKey(AWSHeaders.accept), true);
expect(headers.containsKey(AWSHeaders.contentEncoding), true);
expect(headers.containsKey(AWSHeaders.contentType), true);
expect(headers.containsKey(AWSHeaders.host), true);
});
});
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

import 'dart:convert';

import 'package:amplify_api_dart/src/graphql/providers/app_sync_api_key_auth_provider.dart';
import 'package:amplify_api_dart/src/graphql/web_socket/services/web_socket_service.dart';
import 'package:amplify_core/amplify_core.dart';
import 'package:test/test.dart';

import '../util.dart';

void main() {
group('AmplifyWebSocketService', () {
group('generateProtocols', () {});
const apiKey = 'fake-key';
test('should generate a protocol that includes the appropriate headers',
() async {
final (outputs, repo) = createOutputsAndRepo(
AppSyncApiKeyAuthProvider(),
APIAuthorizationType.apiKey,
apiKey,
);
final service = AmplifyWebSocketService();
final protocols = await service.generateProtocols(outputs, repo);
final encodedHeaders = protocols[1].replaceFirst('header-', '');
final headers = json.decode(
String.fromCharCodes(base64Url.decode(encodedHeaders)),
) as Map<String, dynamic>;
expect(headers[xApiKey], apiKey);
expect(headers.containsKey(AWSHeaders.accept), true);
expect(headers.containsKey(AWSHeaders.contentEncoding), true);
expect(headers.containsKey(AWSHeaders.contentType), true);
expect(headers.containsKey(AWSHeaders.host), true);
});
});
}
Loading