From e8b01bfc436260496bd754b3f871ce68275834fe Mon Sep 17 00:00:00 2001 From: Joshua Sosso Date: Fri, 29 Nov 2024 15:01:31 -0600 Subject: [PATCH] feat: add "onError" hook to generated clients (#102) * add general onError hook to ts client * add onError hook to dart client * fix issue with ts onError hook add test to ensure onError hook fires in generated ts clients * improve dart result types * add onError hooks to kotlin client * add error hooks to kotlin sse implementation * add error hooks to swift client * test onError hooks (swift and ts) * add integration tests for onError hook (dart and kotlin) * document error hooks --- languages/dart/dart-client/lib/request.dart | 101 ++++- languages/dart/dart-client/lib/ws.dart | 1 + .../dart-client/test/arri_client_test.dart | 4 +- .../lib/reference_client.dart | 30 +- languages/dart/dart-codegen/README.md | 29 +- languages/dart/dart-codegen/src/_index.ts | 8 +- languages/dart/dart-codegen/src/procedures.ts | 21 +- .../src/main/kotlin/ExampleClient.kt | 180 +++++--- languages/kotlin/kotlin-codegen/README.md | 27 +- languages/kotlin/kotlin-codegen/src/_index.ts | 53 ++- .../kotlin/kotlin-codegen/src/procedures.ts | 30 +- .../Sources/ArriClient/ArriClient.swift | 78 ++-- .../SwiftCodegenReference.swift | 22 +- languages/swift/swift-codegen/README.md | 22 +- .../swift/swift-codegen/src/procedures.ts | 11 +- languages/ts/ts-client/src/request.ts | 8 +- languages/ts/ts-client/src/sse.ts | 6 + languages/ts/ts-client/src/ws.test.ts | 2 + languages/ts/ts-client/src/ws.ts | 1 - .../src/referenceClient.test.ts | 2 + .../src/referenceClient.ts | 10 + languages/ts/ts-codegen/README.md | 4 + languages/ts/ts-codegen/src/rpc.ts | 4 +- languages/ts/ts-codegen/src/service.ts | 3 + tests/clients/dart/lib/test_client.rpc.dart | 113 ++++- tests/clients/dart/test/test_client_test.dart | 24 +- tests/clients/kotlin/src/main/kotlin/Main.kt | 16 +- .../kotlin/src/main/kotlin/TestClient.rpc.kt | 414 +++++++++++------- .../clients/swift/Sources/TestClient.g.swift | 60 ++- .../clients/swift/Tests/TestClientTests.swift | 21 +- tests/clients/ts/testClient.rpc.ts | 29 ++ tests/clients/ts/testClient.test.ts | 28 +- 32 files changed, 936 insertions(+), 426 deletions(-) diff --git a/languages/dart/dart-client/lib/request.dart b/languages/dart/dart-client/lib/request.dart index 1e2bdf1a..89d1ce42 100644 --- a/languages/dart/dart-client/lib/request.dart +++ b/languages/dart/dart-client/lib/request.dart @@ -126,26 +126,36 @@ Future parsedArriRequest( HttpMethod method = HttpMethod.post, Map? params, FutureOr> Function()? headers, + Function(Object)? onError, String? clientVersion, required T Function(String) parser, }) async { - final result = await arriRequest( - url, - httpClient: httpClient, - method: method, - params: params, - headers: headers, - clientVersion: clientVersion, - ); - if (result.statusCode >= 200 && result.statusCode <= 299) { - return parser(utf8.decode(result.bodyBytes)); + final http.Response result; + + try { + result = await arriRequest( + url, + httpClient: httpClient, + method: method, + params: params, + headers: headers, + clientVersion: clientVersion, + ); + if (result.statusCode >= 200 && result.statusCode <= 299) { + return parser(utf8.decode(result.bodyBytes)); + } + } catch (err) { + onError?.call(err); + rethrow; } - throw ArriError.fromResponse(result); + final err = ArriError.fromResponse(result); + onError?.call(err); + throw err; } /// Perform a raw HTTP request to an Arri RPC server. This function does not thrown an error. Instead it returns a request result /// in which both value and the error can be null. -Future> parsedArriRequestSafe( +Future> parsedArriRequestSafe( String url, { http.Client? httpClient, HttpMethod httpMethod = HttpMethod.get, @@ -164,22 +174,65 @@ Future> parsedArriRequestSafe( method: httpMethod, httpClient: httpClient, ); - return ArriRequestResult(value: result); + return ArriResultOk(result); } catch (err) { - return ArriRequestResult(error: err is ArriError ? err : null); + return ArriResultErr( + err is ArriError + ? err + : ArriError( + code: 0, + message: err.toString(), + data: err, + ), + ); } } -/// Container for holding a request result or a request error -class ArriRequestResult { - final T? value; - final ArriError? error; - const ArriRequestResult({this.value, this.error}); +/// Container for holding a request data or a request error +sealed class ArriResult { + bool get isOk; + bool get isErr; + T? get unwrap; + T unwrapOr(T fallback); + ArriError? get unwrapErr; } -/// Abstract endpoint to use as a base for generated client route enums -abstract class ArriEndpoint { - final String path; - final HttpMethod method; - const ArriEndpoint({required this.path, required this.method}); +class ArriResultOk implements ArriResult { + final T _data; + const ArriResultOk(this._data); + + @override + bool get isOk => true; + + @override + bool get isErr => false; + + @override + T get unwrap => _data; + + @override + T unwrapOr(T fallback) => _data; + + @override + ArriError? get unwrapErr => null; +} + +class ArriResultErr implements ArriResult { + final ArriError _err; + const ArriResultErr(this._err); + + @override + bool get isErr => true; + + @override + bool get isOk => false; + + @override + T? get unwrap => null; + + @override + ArriError get unwrapErr => _err; + + @override + T unwrapOr(T fallback) => fallback; } diff --git a/languages/dart/dart-client/lib/ws.dart b/languages/dart/dart-client/lib/ws.dart index fba9158b..1fe4d31f 100644 --- a/languages/dart/dart-client/lib/ws.dart +++ b/languages/dart/dart-client/lib/ws.dart @@ -10,6 +10,7 @@ Future> FutureOr> Function()? headers, required TServerMessage Function(String msg) parser, required String Function(TClientMessage msg) serializer, + Function(Object)? onError, String? clientVersion, }) async { var finalUrl = diff --git a/languages/dart/dart-client/test/arri_client_test.dart b/languages/dart/dart-client/test/arri_client_test.dart index 7419b957..963dee38 100644 --- a/languages/dart/dart-client/test/arri_client_test.dart +++ b/languages/dart/dart-client/test/arri_client_test.dart @@ -9,9 +9,7 @@ main() { test("invalid url", () async { final response = await parsedArriRequestSafe(nonExistentUrl, parser: (data) {}); - if (response.error != null) { - expect(response.error!.code, equals(500)); - } + expect(response.unwrapErr?.code, equals(0)); }); test('auto retry sse', () async { diff --git a/languages/dart/dart-codegen-reference/lib/reference_client.dart b/languages/dart/dart-codegen-reference/lib/reference_client.dart index 0b7c773d..57003e31 100644 --- a/languages/dart/dart-codegen-reference/lib/reference_client.dart +++ b/languages/dart/dart-codegen-reference/lib/reference_client.dart @@ -9,14 +9,17 @@ class ExampleClient { final http.Client? _httpClient; final String _baseUrl; final String _clientVersion = "20"; - late final FutureOr> Function()? _headers; + final FutureOr> Function()? _headers; + final Function(Object)? _onError; ExampleClient({ http.Client? httpClient, required String baseUrl, FutureOr> Function()? headers, + Function(Object)? onError, }) : _httpClient = httpClient, _baseUrl = baseUrl, - _headers = headers; + _headers = headers, + _onError = onError; Future sendObject(NestedObject params) async { return parsedArriRequest( @@ -27,6 +30,7 @@ class ExampleClient { clientVersion: _clientVersion, params: params.toJson(), parser: (body) => NestedObject.fromJsonString(body), + onError: _onError, ); } @@ -34,6 +38,7 @@ class ExampleClient { baseUrl: _baseUrl, headers: _headers, httpClient: _httpClient, + onError: _onError, ); } @@ -41,14 +46,17 @@ class ExampleClientBooksService { final http.Client? _httpClient; final String _baseUrl; final String _clientVersion = "20"; - late final FutureOr> Function()? _headers; + final FutureOr> Function()? _headers; + final Function(Object)? _onError; ExampleClientBooksService({ http.Client? httpClient, required String baseUrl, FutureOr> Function()? headers, + Function(Object)? onError, }) : _httpClient = httpClient, _baseUrl = baseUrl, - _headers = headers; + _headers = headers, + _onError = onError; /// Get a book Future getBook(BookParams params) async { @@ -60,6 +68,7 @@ class ExampleClientBooksService { clientVersion: _clientVersion, params: params.toJson(), parser: (body) => Book.fromJsonString(body), + onError: _onError, ); } @@ -74,6 +83,7 @@ class ExampleClientBooksService { clientVersion: _clientVersion, params: params.toJson(), parser: (body) => Book.fromJsonString(body), + onError: _onError, ); } @@ -103,7 +113,16 @@ class ExampleClientBooksService { onMessage: onMessage, onOpen: onOpen, onClose: onClose, - onError: onError, + onError: onError != null && _onError != null + ? (err, es) { + _onError?.call(onError); + return onError(err, es); + } + : onError != null + ? onError + : _onError != null + ? (err, _) => _onError?.call(err) + : null, ); } @@ -114,6 +133,7 @@ class ExampleClientBooksService { clientVersion: _clientVersion, parser: (msg) => Book.fromJsonString(msg), serializer: (msg) => msg.toJsonString(), + onError: _onError, ); } } diff --git a/languages/dart/dart-codegen/README.md b/languages/dart/dart-codegen/README.md index 74d32e68..2d3ef8c7 100644 --- a/languages/dart/dart-codegen/README.md +++ b/languages/dart/dart-codegen/README.md @@ -69,28 +69,37 @@ final service = MyClientUsersService( ); ``` +#### Client / Service Options + +| name | Type | description | +| ---------- | ------------------------------------------- | ------------------------------------------------------------- | +| httpClient | `http.Client` | Use this to pass in a custom `http.Client` instance | +| baseUrl | `String` | The base url for the backend server | +| headers | `FutureOr> Function()?` | A function that returns a Map of headers | +| onError | `Function(Object)?` | A hook that fires whenever any error is thrown by the client. | + ### Using Arri Models All generated models will be immutable. They will have access to the following features: **Methods**: -- `Map toJson()` -- `String toJsonString()` -- `String toUrlQueryParams()` -- `copyWith()` +- `Map toJson()` +- `String toJsonString()` +- `String toUrlQueryParams()` +- `copyWith()` **Factory Methods**: -- `empty()` -- `fromJson(Map input)` -- `fromJsonString(String input)` +- `empty()` +- `fromJson(Map input)` +- `fromJsonString(String input)` **Overrides**: -- `==` operator (allows for deep equality checking) -- `hashMap` (allows for deep equality checking) -- `toString` (will print out all properties and values instead of `Instance of X`) +- `==` operator (allows for deep equality checking) +- `hashMap` (allows for deep equality checking) +- `toString` (will print out all properties and values instead of `Instance of X`) This library was generated with [Nx](https://nx.dev). diff --git a/languages/dart/dart-codegen/src/_index.ts b/languages/dart/dart-codegen/src/_index.ts index 7cee026d..6e7ff903 100644 --- a/languages/dart/dart-codegen/src/_index.ts +++ b/languages/dart/dart-codegen/src/_index.ts @@ -167,14 +167,17 @@ class ${clientName} { final http.Client? _httpClient; final String _baseUrl; final String _clientVersion = "${context.clientVersion ?? ""}"; - late final FutureOr> Function()? _headers; + final FutureOr> Function()? _headers; + final Function(Object)? _onError; ${clientName}({ http.Client? httpClient, required String baseUrl, FutureOr> Function()? headers, + Function(Object)? onError, }) : _httpClient = httpClient, _baseUrl = baseUrl, - _headers = headers; + _headers = headers, + _onError = onError; ${rpcParts.join("\n\n")} @@ -184,6 +187,7 @@ ${subServices baseUrl: _baseUrl, headers: _headers, httpClient: _httpClient, + onError: _onError, );`, ) .join("\n\n")} diff --git a/languages/dart/dart-codegen/src/procedures.ts b/languages/dart/dart-codegen/src/procedures.ts index 93301a17..27fa8e16 100644 --- a/languages/dart/dart-codegen/src/procedures.ts +++ b/languages/dart/dart-codegen/src/procedures.ts @@ -74,7 +74,16 @@ export function dartHttpRpcFromSchema( onMessage: onMessage, onOpen: onOpen, onClose: onClose, - onError: onError, + onError: onError != null && _onError != null + ? (err, es) { + _onError?.call(onError); + return onError(err, es); + } + : onError != null + ? onError + : _onError != null + ? (err, _) => _onError?.call(err) + : null, ); }`; } @@ -87,6 +96,7 @@ export function dartHttpRpcFromSchema( clientVersion: _clientVersion, ${paramsType ? "params: params.toJson()," : ""} parser: (body) ${schema.response ? `=> ${responseType}.fromJsonString(body)` : "{}"}, + onError: _onError, ); }`; } @@ -120,6 +130,7 @@ export function dartWsRpcFromSchema( clientVersion: _clientVersion, parser: (msg) ${responseType ? `=> ${responseType}.fromJsonString(msg)` : "{}"}, serializer: (msg) ${paramsType ? "=> msg.toJsonString()" : '=> ""'}, + onError: _onError, ); }`; } @@ -180,14 +191,17 @@ export function dartServiceFromSchema( final http.Client? _httpClient; final String _baseUrl; final String _clientVersion = "${context.clientVersion}"; - late final FutureOr> Function()? _headers; + final FutureOr> Function()? _headers; + final Function(Object)? _onError; ${serviceName}({ http.Client? httpClient, required String baseUrl, FutureOr> Function()? headers, + Function(Object)? onError, }) : _httpClient = httpClient, _baseUrl = baseUrl, - _headers = headers; + _headers = headers, + _onError = onError; ${rpcParts.join("\n\n")} @@ -197,6 +211,7 @@ export function dartServiceFromSchema( baseUrl: _baseUrl, headers: _headers, httpClient: _httpClient, + onError: _onError, );`, ) .join("\n\n")} diff --git a/languages/kotlin/kotlin-codegen-reference/src/main/kotlin/ExampleClient.kt b/languages/kotlin/kotlin-codegen-reference/src/main/kotlin/ExampleClient.kt index 73b72c12..8976efd6 100644 --- a/languages/kotlin/kotlin-codegen-reference/src/main/kotlin/ExampleClient.kt +++ b/languages/kotlin/kotlin-codegen-reference/src/main/kotlin/ExampleClient.kt @@ -37,33 +37,40 @@ class ExampleClient( private val httpClient: HttpClient, private val baseUrl: String, private val headers: headersFn, + private val onError: ((err: Exception) -> Unit) = {}, ) { suspend fun sendObject(params: NestedObject): NestedObject { - val response = __prepareRequest( - client = httpClient, - url = "$baseUrl/send-object", - method = HttpMethod.Post, - params = params, - headers = headers?.invoke(), - ).execute() - if (response.headers["Content-Type"] != "application/json") { - throw ExampleClientError( - code = 0, - errorMessage = "Expected server to return Content-Type \"application/json\". Got \"${response.headers["Content-Type"]}\"", - data = JsonPrimitive(response.bodyAsText()), - stack = null, - ) - } - if (response.status.value in 200..299) { - return NestedObject.fromJson(response.bodyAsText()) + try { + val response = __prepareRequest( + client = httpClient, + url = "$baseUrl/send-object", + method = HttpMethod.Post, + params = params, + headers = headers?.invoke(), + ).execute() + if (response.headers["Content-Type"] != "application/json") { + throw ExampleClientError( + code = 0, + errorMessage = "Expected server to return Content-Type \"application/json\". Got \"${response.headers["Content-Type"]}\"", + data = JsonPrimitive(response.bodyAsText()), + stack = null, + ) + } + if (response.status.value in 200..299) { + return NestedObject.fromJson(response.bodyAsText()) + } + throw ExampleClientError.fromJson(response.bodyAsText()) + } catch (e: Exception) { + onError(e) + throw e } - throw ExampleClientError.fromJson(response.bodyAsText()) } val books: ExampleClientBooksService = ExampleClientBooksService( httpClient = httpClient, baseUrl = baseUrl, headers = headers, + onError = onError, ) } @@ -71,30 +78,37 @@ class ExampleClientBooksService( private val httpClient: HttpClient, private val baseUrl: String, private val headers: headersFn, + private val onError: ((err: Exception) -> Unit) = {}, ) { /** * Get a book */ suspend fun getBook(params: BookParams): Book { - val response = __prepareRequest( - client = httpClient, - url = "$baseUrl/books/get-book", - method = HttpMethod.Get, - params = params, - headers = headers?.invoke(), - ).execute() - if (response.headers["Content-Type"] != "application/json") { - throw ExampleClientError( - code = 0, - errorMessage = "Expected server to return Content-Type \"application/json\". Got \"${response.headers["Content-Type"]}\"", - data = JsonPrimitive(response.bodyAsText()), - stack = null, - ) - } - if (response.status.value in 200..299) { - return Book.fromJson(response.bodyAsText()) + try { + val response = __prepareRequest( + client = httpClient, + url = "$baseUrl/books/get-book", + method = HttpMethod.Get, + params = params, + headers = headers?.invoke(), + ).execute() + if (response.headers["Content-Type"] != "application/json") { + throw ExampleClientError( + code = 0, + errorMessage = "Expected server to return Content-Type \"application/json\". Got \"${response.headers["Content-Type"]}\"", + data = JsonPrimitive(response.bodyAsText()), + stack = null, + ) + } + if (response.status.value in 200..299) { + return Book.fromJson(response.bodyAsText()) + } + throw ExampleClientError.fromJson(response.bodyAsText()) + + } catch (e: Exception) { + onError(e) + throw e } - throw ExampleClientError.fromJson(response.bodyAsText()) } /** @@ -102,25 +116,31 @@ class ExampleClientBooksService( */ @Deprecated(message = "This method was marked as deprecated by the server") suspend fun createBook(params: Book): Book { - val response = __prepareRequest( - client = httpClient, - url = "$baseUrl/books/create-book", - method = HttpMethod.Post, - params = params, - headers = headers?.invoke(), - ).execute() - if (response.headers["Content-Type"] != "application/json") { - throw ExampleClientError( - code = 0, - errorMessage = "Expected server to return Content-Type \"application/json\". Got \"${response.headers["Content-Type"]}\"", - data = JsonPrimitive(response.bodyAsText()), - stack = null, - ) - } - if (response.status.value in 200..299) { - return Book.fromJson(response.bodyAsText()) + try { + val response = __prepareRequest( + client = httpClient, + url = "$baseUrl/books/create-book", + method = HttpMethod.Post, + params = params, + headers = headers?.invoke(), + ).execute() + if (response.headers["Content-Type"] != "application/json") { + throw ExampleClientError( + code = 0, + errorMessage = "Expected server to return Content-Type \"application/json\". Got \"${response.headers["Content-Type"]}\"", + data = JsonPrimitive(response.bodyAsText()), + stack = null, + ) + } + if (response.status.value in 200..299) { + return Book.fromJson(response.bodyAsText()) + } + throw ExampleClientError.fromJson(response.bodyAsText()) + + } catch (e: Exception) { + onError(e) + throw e } - throw ExampleClientError.fromJson(response.bodyAsText()) } @Deprecated(message = "This method was marked as deprecated by the server") @@ -147,6 +167,7 @@ class ExampleClientBooksService( bufferCapacity = bufferCapacity, onOpen = onOpen, onClose = onClose, + onError = onError, onRequestError = onRequestError, onResponseError = onResponseError, onData = { str -> @@ -1977,8 +1998,9 @@ private suspend fun __handleSseRequest( onOpen: ((response: HttpResponse) -> Unit) = {}, onClose: (() -> Unit) = {}, onData: ((data: String) -> Unit) = {}, - onRequestError: ((error: Exception) -> Unit) = {}, - onResponseError: ((error: ExampleClientError) -> Unit) = {}, + onError: ((err: Exception) -> Unit) = {}, + onRequestError: ((err: Exception) -> Unit) = {}, + onResponseError: ((err: ExampleClientError) -> Unit) = {}, bufferCapacity: Int, ) { val finalHeaders = headers?.invoke() ?: mutableMapOf() @@ -2013,18 +2035,18 @@ private suspend fun __handleSseRequest( if (httpResponse.status.value !in 200..299) { try { if (httpResponse.headers["Content-Type"] == "application/json") { - onResponseError( - ExampleClientError.fromJson(httpResponse.bodyAsText()) - ) + val err = ExampleClientError.fromJson(httpResponse.bodyAsText()) + onError(err) + onResponseError(err) } else { - onResponseError( - ExampleClientError( - code = httpResponse.status.value, - errorMessage = httpResponse.status.description, - data = JsonPrimitive(httpResponse.bodyAsText()), - stack = null, - ) + val err = ExampleClientError( + code = httpResponse.status.value, + errorMessage = httpResponse.status.description, + data = JsonPrimitive(httpResponse.bodyAsText()), + stack = null, ) + onError(err) + onResponseError(err) } } catch (e: CancellationException) { onClose() @@ -2044,19 +2066,21 @@ private suspend fun __handleSseRequest( onOpen = onOpen, onClose = onClose, onData = onData, + onError = onError, + onRequestError = onRequestError, onResponseError = onResponseError, ) } if (httpResponse.headers["Content-Type"] != "text/event-stream") { try { - onResponseError( - ExampleClientError( - code = 0, - errorMessage = "Expected server to return Content-Type \"text/event-stream\". Got \"${httpResponse.headers["Content-Type"]}\"", - data = JsonPrimitive(httpResponse.bodyAsText()), - stack = null, - ) + val err = ExampleClientError( + code = 0, + errorMessage = "Expected server to return Content-Type \"text/event-stream\". Got \"${httpResponse.headers["Content-Type"]}\"", + data = JsonPrimitive(httpResponse.bodyAsText()), + stack = null, ) + onError(err) + onResponseError(err) } catch (e: CancellationException) { httpResponse.cancel() return@execute @@ -2074,6 +2098,8 @@ private suspend fun __handleSseRequest( onOpen = onOpen, onClose = onClose, onData = onData, + onError = onError, + onRequestError = onRequestError, onResponseError = onResponseError, ) } @@ -2125,10 +2151,13 @@ private suspend fun __handleSseRequest( onOpen = onOpen, onClose = onClose, onData = onData, + onError = onError, + onRequestError = onRequestError, onResponseError = onResponseError, ) } } catch (e: java.net.ConnectException) { + onError(e) onRequestError(e) return __handleSseRequest( httpClient = httpClient, @@ -2143,9 +2172,12 @@ private suspend fun __handleSseRequest( onOpen = onOpen, onClose = onClose, onData = onData, + onError = onError, + onRequestError = onRequestError, onResponseError = onResponseError, ) } catch (e: Exception) { + onError(e) onRequestError(e) return __handleSseRequest( httpClient = httpClient, @@ -2160,6 +2192,8 @@ private suspend fun __handleSseRequest( onOpen = onOpen, onClose = onClose, onData = onData, + onError = onError, + onRequestError = onRequestError, onResponseError = onResponseError, ) } diff --git a/languages/kotlin/kotlin-codegen/README.md b/languages/kotlin/kotlin-codegen/README.md index db190e04..32a5d11b 100644 --- a/languages/kotlin/kotlin-codegen/README.md +++ b/languages/kotlin/kotlin-codegen/README.md @@ -30,8 +30,8 @@ export default defineConfig({ The generated code relies on the following dependencies: -- [kotlinx.serialization](https://github.com/Kotlin/kotlinx.serialization) -- [ktor client](https://ktor.io/docs/client-dependencies.html) +- [kotlinx.serialization](https://github.com/Kotlin/kotlinx.serialization) +- [ktor client](https://ktor.io/docs/client-dependencies.html) ## Using the Generated Code @@ -72,6 +72,15 @@ val service = MyClientUsersService( ) ``` +#### Client / Service Options + +| Name | Type | Description | +| ------------------ | -------------------------------------- | ---------------------------------------------------------------- | +| httpClient | `HttpClient` | An instance of ktor HttpClient | +| baseUrl | `String` | The base URL of the API server | +| headers | `(() -> MutableMap?)?` | A function that returns a map of http headers | +| onError (Optional) | `((err: Exception) -> Unit)` | A hook that fires whenever any exception is thrown by the client | + ### Calling Procedures #### Standard HTTP Procedures @@ -149,19 +158,19 @@ All generated models will be data classes. They will have access to the followin **Methods**: -- `toJson(): String` -- `toUrlQueryParams(): String` +- `toJson(): String` +- `toUrlQueryParams(): String` **Factory Methods**: -- `new()` -- `fromJson(input: String)` -- `fromJsonElement(input: JsonElement, instancePath: String)` +- `new()` +- `fromJson(input: String)` +- `fromJsonElement(input: JsonElement, instancePath: String)` **Other Notes** -- All Enums will have a `serialValue` property. -- Discriminator schemas are converted to sealed classes +- All Enums will have a `serialValue` property. +- Discriminator schemas are converted to sealed classes ## Development diff --git a/languages/kotlin/kotlin-codegen/src/_index.ts b/languages/kotlin/kotlin-codegen/src/_index.ts index cac46b4c..e9ce5eef 100644 --- a/languages/kotlin/kotlin-codegen/src/_index.ts +++ b/languages/kotlin/kotlin-codegen/src/_index.ts @@ -115,6 +115,7 @@ export function kotlinClientFromAppDefinition( httpClient = httpClient, baseUrl = baseUrl, headers = headers, + onError = onError, )`); if (subService.content) { subServiceParts.push(subService.content); @@ -146,6 +147,7 @@ class ${clientName}( private val httpClient: HttpClient, private val baseUrl: String, private val headers: headersFn, + private val onError: ((err: Exception) -> Unit) = {}, ) { ${procedureParts.join("\n\n ")} } @@ -518,8 +520,9 @@ private suspend fun __handleSseRequest( onOpen: ((response: HttpResponse) -> Unit) = {}, onClose: (() -> Unit) = {}, onData: ((data: String) -> Unit) = {}, - onRequestError: ((error: Exception) -> Unit) = {}, - onResponseError: ((error: ${clientName}Error) -> Unit) = {}, + onError: ((err: Exception) -> Unit) = {}, + onRequestError: ((err: Exception) -> Unit) = {}, + onResponseError: ((err: ${clientName}Error) -> Unit) = {}, bufferCapacity: Int, ) { val finalHeaders = headers?.invoke() ?: mutableMapOf() @@ -554,18 +557,18 @@ private suspend fun __handleSseRequest( if (httpResponse.status.value !in 200..299) { try { if (httpResponse.headers["Content-Type"] == "application/json") { - onResponseError( - ${clientName}Error.fromJson(httpResponse.bodyAsText()) - ) + val err = ${clientName}Error.fromJson(httpResponse.bodyAsText()) + onError(err) + onResponseError(err) } else { - onResponseError( - ${clientName}Error( - code = httpResponse.status.value, - errorMessage = httpResponse.status.description, - data = JsonPrimitive(httpResponse.bodyAsText()), - stack = null, - ) + val err = ${clientName}Error( + code = httpResponse.status.value, + errorMessage = httpResponse.status.description, + data = JsonPrimitive(httpResponse.bodyAsText()), + stack = null, ) + onError(err) + onResponseError(err) } } catch (e: CancellationException) { onClose() @@ -585,19 +588,21 @@ private suspend fun __handleSseRequest( onOpen = onOpen, onClose = onClose, onData = onData, + onError = onError, + onRequestError = onRequestError, onResponseError = onResponseError, ) } if (httpResponse.headers["Content-Type"] != "text/event-stream") { try { - onResponseError( - ${clientName}Error( - code = 0, - errorMessage = "Expected server to return Content-Type \\"text/event-stream\\". Got \\"\${httpResponse.headers["Content-Type"]}\\"", - data = JsonPrimitive(httpResponse.bodyAsText()), - stack = null, - ) + val err = ${clientName}Error( + code = 0, + errorMessage = "Expected server to return Content-Type \\"text/event-stream\\". Got \\"\${httpResponse.headers["Content-Type"]}\\"", + data = JsonPrimitive(httpResponse.bodyAsText()), + stack = null, ) + onError(err) + onResponseError(err) } catch (e: CancellationException) { httpResponse.cancel() return@execute @@ -615,6 +620,8 @@ private suspend fun __handleSseRequest( onOpen = onOpen, onClose = onClose, onData = onData, + onError = onError, + onRequestError = onRequestError, onResponseError = onResponseError, ) } @@ -666,10 +673,13 @@ private suspend fun __handleSseRequest( onOpen = onOpen, onClose = onClose, onData = onData, + onError = onError, + onRequestError = onRequestError, onResponseError = onResponseError, ) } } catch (e: java.net.ConnectException) { + onError(e) onRequestError(e) return __handleSseRequest( httpClient = httpClient, @@ -684,9 +694,12 @@ private suspend fun __handleSseRequest( onOpen = onOpen, onClose = onClose, onData = onData, + onError = onError, + onRequestError = onRequestError, onResponseError = onResponseError, ) } catch (e: Exception) { + onError(e) onRequestError(e) return __handleSseRequest( httpClient = httpClient, @@ -701,6 +714,8 @@ private suspend fun __handleSseRequest( onOpen = onOpen, onClose = onClose, onData = onData, + onError = onError, + onRequestError = onRequestError, onResponseError = onResponseError, ) } diff --git a/languages/kotlin/kotlin-codegen/src/procedures.ts b/languages/kotlin/kotlin-codegen/src/procedures.ts index 45bd8293..185b12ad 100644 --- a/languages/kotlin/kotlin-codegen/src/procedures.ts +++ b/languages/kotlin/kotlin-codegen/src/procedures.ts @@ -76,6 +76,7 @@ export function kotlinHttpRpcFromSchema( bufferCapacity = bufferCapacity, onOpen = onOpen, onClose = onClose, + onError = onError, onRequestError = onRequestError, onResponseError = onResponseError, onData = { str -> @@ -94,18 +95,23 @@ export function kotlinHttpRpcFromSchema( ) }`; return `${codeComment}suspend fun ${name}(${params ? `params: ${params}` : ""}): ${response ?? "Unit"} { - val response = __prepareRequest( - client = httpClient, - url = "$baseUrl${schema.path}", - method = HttpMethod.${pascalCase(schema.method, { normalize: true })}, - params = ${params ? "params" : null}, - headers = headers?.invoke(), - ).execute() - ${response ? headingCheck : ""} - if (response.status.value in 200..299) { - return ${response ? `${response}.fromJson(response.bodyAsText())` : ""} + try { + val response = __prepareRequest( + client = httpClient, + url = "$baseUrl${schema.path}", + method = HttpMethod.${pascalCase(schema.method, { normalize: true })}, + params = ${params ? "params" : null}, + headers = headers?.invoke(), + ).execute() + ${response ? headingCheck : ""} + if (response.status.value in 200..299) { + return ${response ? `${response}.fromJson(response.bodyAsText())` : ""} + } + throw ${context.clientName}Error.fromJson(response.bodyAsText()) + } catch (e: Exception) { + onError(e) + throw e } - throw ${context.clientName}Error.fromJson(response.bodyAsText()) }`; } @@ -140,6 +146,7 @@ export function kotlinServiceFromSchema( httpClient = httpClient, baseUrl = baseUrl, headers = headers, + onError = onError, )`); if (subService.content) { subServiceParts.push(subService.content); @@ -163,6 +170,7 @@ export function kotlinServiceFromSchema( private val httpClient: HttpClient, private val baseUrl: String, private val headers: headersFn, + private val onError: ((err: Exception) -> Unit) = {}, ) { ${procedureParts.join("\n\n ")} } diff --git a/languages/swift/swift-client/Sources/ArriClient/ArriClient.swift b/languages/swift/swift-client/Sources/ArriClient/ArriClient.swift index c03e4783..e50c7252 100644 --- a/languages/swift/swift-client/Sources/ArriClient/ArriClient.swift +++ b/languages/swift/swift-client/Sources/ArriClient/ArriClient.swift @@ -18,44 +18,50 @@ public func parsedArriHttpRequest Dictionary, clientVersion: String, params: TParams, - timeoutSeconds: Int64 = 60 + timeoutSeconds: Int64 = 60, + onError: (Error) -> Void ) async throws -> TResponse { - var parsedURL = URLComponents(string: url) - if parsedURL == nil { - throw ArriRequestError.invalidUrl - } - var finalHeaders = headers() - if !clientVersion.isEmpty { - finalHeaders["client-version"] = clientVersion - } - var finalBody: String? - switch method { - case "GET": - if !(params is EmptyArriModel) { - parsedURL!.queryItems = params.toURLQueryParts() - } - break; - default: - if !(params is EmptyArriModel) { - finalHeaders["Content-Type"] = "application/json" - finalBody = params.toJSONString() - } - break; - } - let request = ArriHTTPRequest(url: parsedURL!.url!, method: method, headers: finalHeaders, body: finalBody) - let response = try await delegate.handleHTTPRequest(request: request) - if response.statusCode >= 200 && response.statusCode < 300 { - let result = TResponse.init(JSONData: response.body ?? Data()) - return result - } - var error = ArriResponseError(JSONData: response.body ?? Data()) - if error.code == 0 { - error.code = response.statusCode - } - if error.message.isEmpty { - error.message = response.statusMessage ?? "Unknown error" + do { + var parsedURL = URLComponents(string: url) + if parsedURL == nil { + throw ArriRequestError.invalidUrl + } + var finalHeaders = headers() + if !clientVersion.isEmpty { + finalHeaders["client-version"] = clientVersion + } + var finalBody: String? + switch method { + case "GET": + if !(params is EmptyArriModel) { + parsedURL!.queryItems = params.toURLQueryParts() + } + break; + default: + if !(params is EmptyArriModel) { + finalHeaders["Content-Type"] = "application/json" + finalBody = params.toJSONString() + } + break; + } + let request = ArriHTTPRequest(url: parsedURL!.url!, method: method, headers: finalHeaders, body: finalBody) + let response = try await delegate.handleHTTPRequest(request: request) + if response.statusCode >= 200 && response.statusCode < 300 { + let result = TResponse.init(JSONData: response.body ?? Data()) + return result + } + var error = ArriResponseError(JSONData: response.body ?? Data()) + if error.code == 0 { + error.code = response.statusCode + } + if error.message.isEmpty { + error.message = response.statusMessage ?? "Unknown error" + } + throw error + } catch (let e) { + onError(e) + throw e } - throw error } public enum ArriRequestError: Error { diff --git a/languages/swift/swift-codegen-reference/Sources/SwiftCodegenReference/SwiftCodegenReference.swift b/languages/swift/swift-codegen-reference/Sources/SwiftCodegenReference/SwiftCodegenReference.swift index 046906e1..86493963 100644 --- a/languages/swift/swift-codegen-reference/Sources/SwiftCodegenReference/SwiftCodegenReference.swift +++ b/languages/swift/swift-codegen-reference/Sources/SwiftCodegenReference/SwiftCodegenReference.swift @@ -6,20 +6,24 @@ public class ExampleClient { let baseURL: String let delegate: ArriRequestDelegate let headers: () -> Dictionary + let onError: (Error) -> Void public let books: ExampleClientBooksService public init( baseURL: String, delegate: ArriRequestDelegate, - headers: @escaping () -> Dictionary + headers: @escaping () -> Dictionary, + onError: @escaping ((Error) -> Void) = { _ -> Void in } ) { self.baseURL = baseURL self.delegate = delegate self.headers = headers + self.onError = onError self.books = ExampleClientBooksService( baseURL: baseURL, delegate: delegate, - headers: headers + headers: headers, + onError: onError ) } @@ -30,7 +34,8 @@ public class ExampleClient { method: "POST", headers: self.headers, clientVersion: "20", - params: params + params: params, + onError: onError ) return result } @@ -41,15 +46,18 @@ public class ExampleClientBooksService { let baseURL: String let delegate: ArriRequestDelegate let headers: () -> Dictionary + let onError: (Error) -> Void public init( baseURL: String, delegate: ArriRequestDelegate, - headers: @escaping () -> Dictionary + headers: @escaping () -> Dictionary, + onError: @escaping ((Error) -> Void) = { _ -> Void in } ) { self.baseURL = baseURL self.delegate = delegate self.headers = headers + self.onError = onError } /// Get a book public func getBook(_ params: BookParams) async throws -> Book { @@ -59,7 +67,8 @@ public class ExampleClientBooksService { method: "GET", headers: self.headers, clientVersion: "20", - params: params + params: params, + onError: onError ) return result } @@ -72,7 +81,8 @@ public class ExampleClientBooksService { method: "POST", headers: self.headers, clientVersion: "20", - params: params + params: params, + onError: onError ) return result } diff --git a/languages/swift/swift-codegen/README.md b/languages/swift/swift-codegen/README.md index 97919e83..e1d0e3f3 100644 --- a/languages/swift/swift-codegen/README.md +++ b/languages/swift/swift-codegen/README.md @@ -61,9 +61,13 @@ then add `ArriClient` as a dependency to your target let client = MyClient( baseURL: "https://example.com", delegate: DefaultRequestDelegate(), - headers { + headers: { var headers: Dictionary = Dictionary() return headers + }, + // optional + onError: { err in + // do something } ) @@ -156,14 +160,14 @@ await task.result #### Available Event Source Options -- `onMessage` - Closure that fires whenever a "message" event is received from the server. This is the only required option. -- `onRequest` - Closure that fires when a request has been created but has not been executed yet. -- `onRequestError` - Closure that fires when there was an error in creating the request (i.e. a malformed URL), or if we were unable to connect to the server. (i.e a `connectionRefused` error) -- `onResponse` - Closure that fires when we receive a response from the server -- `onResponseError` - Closure that fires when the server has not responded with status code from `200` - `299` or the `Content-Type` header does not contain `text/event-stream` -- `onClose` - Closure that fires when the EventSource is closed. (This will only fire if the EventSource was already able successfully receive a response from the server.) -- `maxRetryCount` - Limit the number of times that the EventSource tries to reconnect to the server. When set to `nil` it will retry indefinitely. (Default is `nil`) -- `maxRetryInterval` - Set the max delay time between retries in milliseconds. Default is `30000`. +- `onMessage` - Closure that fires whenever a "message" event is received from the server. This is the only required option. +- `onRequest` - Closure that fires when a request has been created but has not been executed yet. +- `onRequestError` - Closure that fires when there was an error in creating the request (i.e. a malformed URL), or if we were unable to connect to the server. (i.e a `connectionRefused` error) +- `onResponse` - Closure that fires when we receive a response from the server +- `onResponseError` - Closure that fires when the server has not responded with status code from `200` - `299` or the `Content-Type` header does not contain `text/event-stream` +- `onClose` - Closure that fires when the EventSource is closed. (This will only fire if the EventSource was already able successfully receive a response from the server.) +- `maxRetryCount` - Limit the number of times that the EventSource tries to reconnect to the server. When set to `nil` it will retry indefinitely. (Default is `nil`) +- `maxRetryInterval` - Set the max delay time between retries in milliseconds. Default is `30000`. ## Additional Notes diff --git a/languages/swift/swift-codegen/src/procedures.ts b/languages/swift/swift-codegen/src/procedures.ts index 5620c95f..1b566d4e 100644 --- a/languages/swift/swift-codegen/src/procedures.ts +++ b/languages/swift/swift-codegen/src/procedures.ts @@ -75,7 +75,8 @@ export function swiftHttpProcedureFromSchema( method: "${schema.method.toUpperCase()}", headers: self.headers, clientVersion: "${context.clientVersion}", - ${params ? `params: params` : "params: EmptyArriModel()"} + ${params ? `params: params` : "params: EmptyArriModel()"}, + onError: onError ) ${response ? `return result` : ""} }`; @@ -175,21 +176,25 @@ public class ${serviceName} { let baseURL: String let delegate: ArriRequestDelegate let headers: () -> Dictionary + let onError: (Error) -> Void ${services.map((service) => ` public let ${service.key}: ${service.typeName}`).join("\n")} public init( baseURL: String, delegate: ArriRequestDelegate, - headers: @escaping () -> Dictionary + headers: @escaping () -> Dictionary, + onError: @escaping ((Error) -> Void) = { _ -> Void in } ) { self.baseURL = baseURL self.delegate = delegate self.headers = headers + self.onError = onError ${services .map( (service) => ` self.${service.key} = ${service.typeName}( baseURL: baseURL, delegate: delegate, - headers: headers + headers: headers, + onError: onError )`, ) .join("\n")} diff --git a/languages/ts/ts-client/src/request.ts b/languages/ts/ts-client/src/request.ts index de06c72d..d91b1019 100644 --- a/languages/ts/ts-client/src/request.ts +++ b/languages/ts/ts-client/src/request.ts @@ -15,6 +15,7 @@ export interface ArriRequestOpts< params?: TParams; responseFromJson: (input: Record) => TType; responseFromString: (input: string) => TType; + onError?: (err: unknown) => void; serializer: ( input: TParams, ) => TParams extends undefined ? undefined : string; @@ -54,10 +55,11 @@ export async function arriRequest< return opts.responseFromJson(result); } catch (err) { const error = err as any as FetchError; + let arriError: ArriErrorInstance; if (isArriError(error.data)) { - throw new ArriErrorInstance(error.data); + arriError = new ArriErrorInstance(error.data); } else { - throw new ArriErrorInstance({ + arriError = new ArriErrorInstance({ code: error.statusCode ?? 500, message: error.statusMessage ?? @@ -67,6 +69,8 @@ export async function arriRequest< stack: error.stack, }); } + if (opts.onError) opts.onError(arriError); + throw arriError; } } diff --git a/languages/ts/ts-client/src/sse.ts b/languages/ts/ts-client/src/sse.ts index 51c92712..9761796a 100644 --- a/languages/ts/ts-client/src/sse.ts +++ b/languages/ts/ts-client/src/sse.ts @@ -95,6 +95,9 @@ export function arriSseRequest< options.onRequest?.(context); }, onRequestError(context) { + if (opts.onError) { + opts.onError(context.error); + } options.onRequestError?.({ ...context, error: new ArriErrorInstance({ @@ -108,6 +111,9 @@ export function arriSseRequest< options.onResponse?.(context); }, async onResponseError(context) { + if (opts.onError) { + opts.onError(context.error); + } if (!options.onResponseError) { return; } diff --git a/languages/ts/ts-client/src/ws.test.ts b/languages/ts/ts-client/src/ws.test.ts index 806e0979..75f3869a 100644 --- a/languages/ts/ts-client/src/ws.test.ts +++ b/languages/ts/ts-client/src/ws.test.ts @@ -1,3 +1,5 @@ +import { expect, test } from "vitest"; + import { parsedWsResponse } from "./ws"; test("Parse WS Response", () => { diff --git a/languages/ts/ts-client/src/ws.ts b/languages/ts/ts-client/src/ws.ts index 0d247744..0c9eb09f 100644 --- a/languages/ts/ts-client/src/ws.ts +++ b/languages/ts/ts-client/src/ws.ts @@ -87,7 +87,6 @@ export async function arriWsRequest< } catch (err) { console.error(err); if (opts.onConnectionError) { - // eslint-disable-next-line @typescript-eslint/no-unsafe-argument opts.onConnectionError(err as any); } return arriWsRequest(opts, retryCount + 1); diff --git a/languages/ts/ts-codegen-reference/src/referenceClient.test.ts b/languages/ts/ts-codegen-reference/src/referenceClient.test.ts index 5667b66f..3896a076 100644 --- a/languages/ts/ts-codegen-reference/src/referenceClient.test.ts +++ b/languages/ts/ts-codegen-reference/src/referenceClient.test.ts @@ -1,6 +1,8 @@ import fs from "node:fs"; import path from "node:path"; +import { describe, expect, test } from "vitest"; + import { $$Book, $$NestedObject, diff --git a/languages/ts/ts-codegen-reference/src/referenceClient.ts b/languages/ts/ts-codegen-reference/src/referenceClient.ts index e0b21615..fc0a9158 100644 --- a/languages/ts/ts-codegen-reference/src/referenceClient.ts +++ b/languages/ts/ts-codegen-reference/src/referenceClient.ts @@ -36,15 +36,18 @@ export class ExampleClient { private readonly _headers: | HeaderMap | (() => HeaderMap | Promise); + private readonly _onError?: (err: unknown) => void; books: ExampleClientBooksService; constructor( options: { baseUrl?: string; headers?: HeaderMap | (() => HeaderMap | Promise); + onError?: (err: unknown) => void; } = {}, ) { this._baseUrl = options.baseUrl ?? ""; this._headers = options.headers ?? {}; + this._onError = options.onError; this.books = new ExampleClientBooksService(options); } @@ -53,6 +56,7 @@ export class ExampleClient { url: `${this._baseUrl}/send-object`, method: "post", headers: this._headers, + onError: this._onError, params: params, responseFromJson: $$NestedObject.fromJson, responseFromString: $$NestedObject.fromJsonString, @@ -67,14 +71,17 @@ export class ExampleClientBooksService { private readonly _headers: | HeaderMap | (() => HeaderMap | Promise); + private readonly _onError?: (err: unknown) => void; constructor( options: { baseUrl?: string; headers?: HeaderMap | (() => HeaderMap | Promise); + onError?: (err: unknown) => void; } = {}, ) { this._baseUrl = options.baseUrl ?? ""; this._headers = options.headers ?? {}; + this._onError = options.onError; } /** * Get a book @@ -84,6 +91,7 @@ export class ExampleClientBooksService { url: `${this._baseUrl}/books/get-book`, method: "get", headers: this._headers, + onError: this._onError, params: params, responseFromJson: $$Book.fromJson, responseFromString: $$Book.fromJsonString, @@ -100,6 +108,7 @@ export class ExampleClientBooksService { url: `${this._baseUrl}/books/create-book`, method: "post", headers: this._headers, + onError: this._onError, params: params, responseFromJson: $$Book.fromJson, responseFromString: $$Book.fromJsonString, @@ -119,6 +128,7 @@ export class ExampleClientBooksService { url: `${this._baseUrl}/books/watch-book`, method: "get", headers: this._headers, + onError: this._onError, params: params, responseFromJson: $$Book.fromJson, responseFromString: $$Book.fromJsonString, diff --git a/languages/ts/ts-codegen/README.md b/languages/ts/ts-codegen/README.md index c354878c..7a96b97d 100644 --- a/languages/ts/ts-codegen/README.md +++ b/languages/ts/ts-codegen/README.md @@ -53,6 +53,10 @@ const client = new MyClient({ Authorization: "", }; }, + // optional + onError: (err) => { + // do something + }, }); await client.myProcedure(); diff --git a/languages/ts/ts-codegen/src/rpc.ts b/languages/ts/ts-codegen/src/rpc.ts index b8a780b1..d86239cc 100644 --- a/languages/ts/ts-codegen/src/rpc.ts +++ b/languages/ts/ts-codegen/src/rpc.ts @@ -59,6 +59,7 @@ export function httpRpcFromDefinition( url: \`\${this._baseUrl}${def.path}\`, method: "${def.method.toLowerCase()}", headers: this._headers, + onError: this._onError, ${params ? "params: params," : ""} responseFromJson: ${response ? `$$${response}.fromJson` : "() => {}"}, responseFromString: ${response ? `$$${response}.fromJsonString` : "() => {}"}, @@ -66,7 +67,7 @@ export function httpRpcFromDefinition( clientVersion: "${context.versionNumber}", }, options, - ) + ); }`; } return `${getJsDocComment({ @@ -77,6 +78,7 @@ export function httpRpcFromDefinition( url: \`\${this._baseUrl}${def.path}\`, method: "${def.method.toLowerCase()}", headers: this._headers, + onError: this._onError, ${params ? "params: params," : ""} responseFromJson: ${response ? `$$${response}.fromJson` : "() => {}"}, responseFromString: ${response ? `$$${response}.fromJsonString` : "() => {}"}, diff --git a/languages/ts/ts-codegen/src/service.ts b/languages/ts/ts-codegen/src/service.ts index f0784fb7..edf51c75 100644 --- a/languages/ts/ts-codegen/src/service.ts +++ b/languages/ts/ts-codegen/src/service.ts @@ -73,15 +73,18 @@ export function tsServiceFromDefinition( content: `export class ${serviceName} { private readonly _baseUrl: string; private readonly _headers: HeaderMap | (() => HeaderMap | Promise); + private readonly _onError?: (err: unknown) => void; ${subServices.map((service) => ` ${service.key}: ${service.name};`).join("\n")} constructor( options: { baseUrl?: string; headers?: HeaderMap | (() => HeaderMap | Promise); + onError?: (err: unknown) => void; } = {}, ) { this._baseUrl = options.baseUrl ?? ""; this._headers = options.headers ?? {}; + this._onError = options.onError; ${subServices.map((service) => ` this.${service.key} = new ${service.name}(options);`).join("\n")} } ${rpcParts.map((rpc) => ` ${rpc}`).join("\n")} diff --git a/tests/clients/dart/lib/test_client.rpc.dart b/tests/clients/dart/lib/test_client.rpc.dart index 48879805..c0d1c8d3 100644 --- a/tests/clients/dart/lib/test_client.rpc.dart +++ b/tests/clients/dart/lib/test_client.rpc.dart @@ -9,25 +9,30 @@ class TestClient { final http.Client? _httpClient; final String _baseUrl; final String _clientVersion = "10"; - late final FutureOr> Function()? _headers; + final FutureOr> Function()? _headers; + final Function(Object)? _onError; TestClient({ http.Client? httpClient, required String baseUrl, FutureOr> Function()? headers, + Function(Object)? onError, }) : _httpClient = httpClient, _baseUrl = baseUrl, - _headers = headers; + _headers = headers, + _onError = onError; TestClientTestsService get tests => TestClientTestsService( baseUrl: _baseUrl, headers: _headers, httpClient: _httpClient, + onError: _onError, ); TestClientUsersService get users => TestClientUsersService( baseUrl: _baseUrl, headers: _headers, httpClient: _httpClient, + onError: _onError, ); } @@ -35,14 +40,17 @@ class TestClientTestsService { final http.Client? _httpClient; final String _baseUrl; final String _clientVersion = "10"; - late final FutureOr> Function()? _headers; + final FutureOr> Function()? _headers; + final Function(Object)? _onError; TestClientTestsService({ http.Client? httpClient, required String baseUrl, FutureOr> Function()? headers, + Function(Object)? onError, }) : _httpClient = httpClient, _baseUrl = baseUrl, - _headers = headers; + _headers = headers, + _onError = onError; Future emptyParamsGetRequest() async { return parsedArriRequest( @@ -52,6 +60,7 @@ class TestClientTestsService { headers: _headers, clientVersion: _clientVersion, parser: (body) => DefaultPayload.fromJsonString(body), + onError: _onError, ); } @@ -63,6 +72,7 @@ class TestClientTestsService { headers: _headers, clientVersion: _clientVersion, parser: (body) => DefaultPayload.fromJsonString(body), + onError: _onError, ); } @@ -75,6 +85,7 @@ class TestClientTestsService { clientVersion: _clientVersion, params: params.toJson(), parser: (body) {}, + onError: _onError, ); } @@ -87,6 +98,7 @@ class TestClientTestsService { clientVersion: _clientVersion, params: params.toJson(), parser: (body) {}, + onError: _onError, ); } @@ -101,6 +113,7 @@ class TestClientTestsService { clientVersion: _clientVersion, params: params.toJson(), parser: (body) {}, + onError: _onError, ); } @@ -113,6 +126,7 @@ class TestClientTestsService { clientVersion: _clientVersion, params: params.toJson(), parser: (body) {}, + onError: _onError, ); } @@ -125,6 +139,7 @@ class TestClientTestsService { clientVersion: _clientVersion, params: params.toJson(), parser: (body) => ObjectWithEveryType.fromJsonString(body), + onError: _onError, ); } @@ -138,6 +153,7 @@ class TestClientTestsService { clientVersion: _clientVersion, params: params.toJson(), parser: (body) => ObjectWithEveryNullableType.fromJsonString(body), + onError: _onError, ); } @@ -151,6 +167,7 @@ class TestClientTestsService { clientVersion: _clientVersion, params: params.toJson(), parser: (body) => ObjectWithPascalCaseKeys.fromJsonString(body), + onError: _onError, ); } @@ -164,6 +181,7 @@ class TestClientTestsService { clientVersion: _clientVersion, params: params.toJson(), parser: (body) => ObjectWithSnakeCaseKeys.fromJsonString(body), + onError: _onError, ); } @@ -177,6 +195,7 @@ class TestClientTestsService { clientVersion: _clientVersion, params: params.toJson(), parser: (body) => ObjectWithEveryOptionalType.fromJsonString(body), + onError: _onError, ); } @@ -189,6 +208,7 @@ class TestClientTestsService { clientVersion: _clientVersion, params: params.toJson(), parser: (body) => RecursiveObject.fromJsonString(body), + onError: _onError, ); } @@ -201,6 +221,7 @@ class TestClientTestsService { clientVersion: _clientVersion, params: params.toJson(), parser: (body) => RecursiveUnion.fromJsonString(body), + onError: _onError, ); } @@ -234,7 +255,16 @@ class TestClientTestsService { onMessage: onMessage, onOpen: onOpen, onClose: onClose, - onError: onError, + onError: onError != null && _onError != null + ? (err, es) { + _onError?.call(onError); + return onError(err, es); + } + : onError != null + ? onError + : _onError != null + ? (err, _) => _onError?.call(err) + : null, ); } @@ -270,7 +300,16 @@ class TestClientTestsService { onMessage: onMessage, onOpen: onOpen, onClose: onClose, - onError: onError, + onError: onError != null && _onError != null + ? (err, es) { + _onError?.call(onError); + return onError(err, es); + } + : onError != null + ? onError + : _onError != null + ? (err, _) => _onError?.call(err) + : null, ); } @@ -303,7 +342,16 @@ class TestClientTestsService { onMessage: onMessage, onOpen: onOpen, onClose: onClose, - onError: onError, + onError: onError != null && _onError != null + ? (err, es) { + _onError?.call(onError); + return onError(err, es); + } + : onError != null + ? onError + : _onError != null + ? (err, _) => _onError?.call(err) + : null, ); } @@ -335,7 +383,16 @@ class TestClientTestsService { onMessage: onMessage, onOpen: onOpen, onClose: onClose, - onError: onError, + onError: onError != null && _onError != null + ? (err, es) { + _onError?.call(onError); + return onError(err, es); + } + : onError != null + ? onError + : _onError != null + ? (err, _) => _onError?.call(err) + : null, ); } @@ -371,7 +428,16 @@ class TestClientTestsService { onMessage: onMessage, onOpen: onOpen, onClose: onClose, - onError: onError, + onError: onError != null && _onError != null + ? (err, es) { + _onError?.call(onError); + return onError(err, es); + } + : onError != null + ? onError + : _onError != null + ? (err, _) => _onError?.call(err) + : null, ); } @@ -402,7 +468,16 @@ class TestClientTestsService { onMessage: onMessage, onOpen: onOpen, onClose: onClose, - onError: onError, + onError: onError != null && _onError != null + ? (err, es) { + _onError?.call(onError); + return onError(err, es); + } + : onError != null + ? onError + : _onError != null + ? (err, _) => _onError?.call(err) + : null, ); } } @@ -411,14 +486,17 @@ class TestClientUsersService { final http.Client? _httpClient; final String _baseUrl; final String _clientVersion = "10"; - late final FutureOr> Function()? _headers; + final FutureOr> Function()? _headers; + final Function(Object)? _onError; TestClientUsersService({ http.Client? httpClient, required String baseUrl, FutureOr> Function()? headers, + Function(Object)? onError, }) : _httpClient = httpClient, _baseUrl = baseUrl, - _headers = headers; + _headers = headers, + _onError = onError; EventSource watchUser( UsersWatchUserParams params, { @@ -450,7 +528,16 @@ class TestClientUsersService { onMessage: onMessage, onOpen: onOpen, onClose: onClose, - onError: onError, + onError: onError != null && _onError != null + ? (err, es) { + _onError?.call(onError); + return onError(err, es); + } + : onError != null + ? onError + : _onError != null + ? (err, _) => _onError?.call(err) + : null, ); } } diff --git a/tests/clients/dart/test/test_client_test.dart b/tests/clients/dart/test/test_client_test.dart index 60705846..97433998 100644 --- a/tests/clients/dart/test/test_client_test.dart +++ b/tests/clients/dart/test/test_client_test.dart @@ -13,7 +13,6 @@ const baseUrl = "http://127.0.0.1:2020"; Future main() async { final client = TestClient(baseUrl: baseUrl, headers: () => {"x-test-header": 'test'}); - final unauthenticatedClient = TestClient(baseUrl: "http://127.0.0.1:2020"); final httpClient = HttpClient(context: SecurityContext(withTrustedRoots: true)); final ioClient = IOClient(httpClient); @@ -147,6 +146,13 @@ Future main() async { expect(result.uint64, equals(input.uint64)); }); test("unauthenticated RPC requests return a 401 error", () async { + bool firedOnErr = false; + final unauthenticatedClient = TestClient( + baseUrl: baseUrl, + onError: (_) { + firedOnErr = true; + }, + ); try { await unauthenticatedClient.tests.sendObject(input); expect(false, equals(true)); @@ -157,6 +163,7 @@ Future main() async { } expect(false, equals(true)); } + expect(firedOnErr, equals(true)); }); test("can send/receive objects with partial fields", () async { final input = ObjectWithEveryOptionalType( @@ -304,6 +311,21 @@ Future main() async { equals(true)); }); + test("onError hook fires", () async { + bool onErrFired = false; + final customClient = TestClient( + baseUrl: baseUrl, + onError: (err) { + onErrFired = true; + expect(err is ArriError, equals(true)); + }, + ); + try { + await customClient.tests.sendObject(input); + } catch (_) {} + expect(onErrFired, equals(true)); + }); + test("[SSE] supports server sent events", () async { int messageCount = 0; final eventSource = client.tests.streamMessages( diff --git a/tests/clients/kotlin/src/main/kotlin/Main.kt b/tests/clients/kotlin/src/main/kotlin/Main.kt index 442c3613..5fd6ada1 100644 --- a/tests/clients/kotlin/src/main/kotlin/Main.kt +++ b/tests/clients/kotlin/src/main/kotlin/Main.kt @@ -17,11 +17,7 @@ fun main() { mutableMapOf(Pair("x-test-header", "12345")) } ) - val unauthenticatedClient = TestClient( - httpClient = httpClient, - baseUrl = "http://localhost:2020", - null - ) + val targetDate = Instant.parse("2002-02-02T08:02:00.000Z") runBlocking { @@ -115,6 +111,15 @@ fun main() { } runBlocking { + var didCallOnErr = false + val unauthenticatedClient = TestClient( + httpClient = httpClient, + baseUrl = "http://localhost:2020", + null, + onError = { _ -> + didCallOnErr = true + } + ) val tag = "UNAUTHENTICATED REQUEST RETURNS ERROR" try { unauthenticatedClient.tests.sendObject(objectInput) @@ -125,6 +130,7 @@ fun main() { // this should never be reached expect(tag, input = false, result = true) } + expect(tag, input = didCallOnErr, result = true) } runBlocking { diff --git a/tests/clients/kotlin/src/main/kotlin/TestClient.rpc.kt b/tests/clients/kotlin/src/main/kotlin/TestClient.rpc.kt index db8c32c9..9c68a743 100644 --- a/tests/clients/kotlin/src/main/kotlin/TestClient.rpc.kt +++ b/tests/clients/kotlin/src/main/kotlin/TestClient.rpc.kt @@ -36,17 +36,20 @@ class TestClient( private val httpClient: HttpClient, private val baseUrl: String, private val headers: headersFn, + private val onError: ((err: Exception) -> Unit) = {}, ) { val tests: TestClientTestsService = TestClientTestsService( httpClient = httpClient, baseUrl = baseUrl, headers = headers, + onError = onError, ) val users: TestClientUsersService = TestClientUsersService( httpClient = httpClient, baseUrl = baseUrl, headers = headers, + onError = onError, ) } @@ -54,16 +57,18 @@ class TestClientTestsService( private val httpClient: HttpClient, private val baseUrl: String, private val headers: headersFn, + private val onError: ((err: Exception) -> Unit) = {}, ) { suspend fun emptyParamsGetRequest(): DefaultPayload { - val response = __prepareRequest( - client = httpClient, - url = "$baseUrl/rpcs/tests/empty-params-get-request", - method = HttpMethod.Get, - params = null, - headers = headers?.invoke(), - ).execute() - if (response.headers["Content-Type"] != "application/json") { + try { + val response = __prepareRequest( + client = httpClient, + url = "$baseUrl/rpcs/tests/empty-params-get-request", + method = HttpMethod.Get, + params = null, + headers = headers?.invoke(), + ).execute() + if (response.headers["Content-Type"] != "application/json") { throw TestClientError( code = 0, errorMessage = "Expected server to return Content-Type \"application/json\". Got \"${response.headers["Content-Type"]}\"", @@ -71,21 +76,26 @@ class TestClientTestsService( stack = null, ) } - if (response.status.value in 200..299) { - return DefaultPayload.fromJson(response.bodyAsText()) + if (response.status.value in 200..299) { + return DefaultPayload.fromJson(response.bodyAsText()) + } + throw TestClientError.fromJson(response.bodyAsText()) + } catch (e: Exception) { + onError(e) + throw e } - throw TestClientError.fromJson(response.bodyAsText()) } suspend fun emptyParamsPostRequest(): DefaultPayload { - val response = __prepareRequest( - client = httpClient, - url = "$baseUrl/rpcs/tests/empty-params-post-request", - method = HttpMethod.Post, - params = null, - headers = headers?.invoke(), - ).execute() - if (response.headers["Content-Type"] != "application/json") { + try { + val response = __prepareRequest( + client = httpClient, + url = "$baseUrl/rpcs/tests/empty-params-post-request", + method = HttpMethod.Post, + params = null, + headers = headers?.invoke(), + ).execute() + if (response.headers["Content-Type"] != "application/json") { throw TestClientError( code = 0, errorMessage = "Expected server to return Content-Type \"application/json\". Got \"${response.headers["Content-Type"]}\"", @@ -93,40 +103,54 @@ class TestClientTestsService( stack = null, ) } - if (response.status.value in 200..299) { - return DefaultPayload.fromJson(response.bodyAsText()) + if (response.status.value in 200..299) { + return DefaultPayload.fromJson(response.bodyAsText()) + } + throw TestClientError.fromJson(response.bodyAsText()) + } catch (e: Exception) { + onError(e) + throw e } - throw TestClientError.fromJson(response.bodyAsText()) } suspend fun emptyResponseGetRequest(params: DefaultPayload): Unit { - val response = __prepareRequest( - client = httpClient, - url = "$baseUrl/rpcs/tests/empty-response-get-request", - method = HttpMethod.Get, - params = params, - headers = headers?.invoke(), - ).execute() - - if (response.status.value in 200..299) { - return + try { + val response = __prepareRequest( + client = httpClient, + url = "$baseUrl/rpcs/tests/empty-response-get-request", + method = HttpMethod.Get, + params = params, + headers = headers?.invoke(), + ).execute() + + if (response.status.value in 200..299) { + return + } + throw TestClientError.fromJson(response.bodyAsText()) + } catch (e: Exception) { + onError(e) + throw e } - throw TestClientError.fromJson(response.bodyAsText()) } suspend fun emptyResponsePostRequest(params: DefaultPayload): Unit { - val response = __prepareRequest( - client = httpClient, - url = "$baseUrl/rpcs/tests/empty-response-post-request", - method = HttpMethod.Post, - params = params, - headers = headers?.invoke(), - ).execute() - - if (response.status.value in 200..299) { - return + try { + val response = __prepareRequest( + client = httpClient, + url = "$baseUrl/rpcs/tests/empty-response-post-request", + method = HttpMethod.Post, + params = params, + headers = headers?.invoke(), + ).execute() + + if (response.status.value in 200..299) { + return + } + throw TestClientError.fromJson(response.bodyAsText()) + } catch (e: Exception) { + onError(e) + throw e } - throw TestClientError.fromJson(response.bodyAsText()) } /** @@ -134,44 +158,55 @@ class TestClientTestsService( */ @Deprecated(message = "This method was marked as deprecated by the server") suspend fun deprecatedRpc(params: DeprecatedRpcParams): Unit { - val response = __prepareRequest( - client = httpClient, - url = "$baseUrl/rpcs/tests/deprecated-rpc", - method = HttpMethod.Post, - params = params, - headers = headers?.invoke(), - ).execute() - - if (response.status.value in 200..299) { - return + try { + val response = __prepareRequest( + client = httpClient, + url = "$baseUrl/rpcs/tests/deprecated-rpc", + method = HttpMethod.Post, + params = params, + headers = headers?.invoke(), + ).execute() + + if (response.status.value in 200..299) { + return + } + throw TestClientError.fromJson(response.bodyAsText()) + } catch (e: Exception) { + onError(e) + throw e } - throw TestClientError.fromJson(response.bodyAsText()) } suspend fun sendError(params: SendErrorParams): Unit { - val response = __prepareRequest( - client = httpClient, - url = "$baseUrl/rpcs/tests/send-error", - method = HttpMethod.Post, - params = params, - headers = headers?.invoke(), - ).execute() - - if (response.status.value in 200..299) { - return + try { + val response = __prepareRequest( + client = httpClient, + url = "$baseUrl/rpcs/tests/send-error", + method = HttpMethod.Post, + params = params, + headers = headers?.invoke(), + ).execute() + + if (response.status.value in 200..299) { + return + } + throw TestClientError.fromJson(response.bodyAsText()) + } catch (e: Exception) { + onError(e) + throw e } - throw TestClientError.fromJson(response.bodyAsText()) } suspend fun sendObject(params: ObjectWithEveryType): ObjectWithEveryType { - val response = __prepareRequest( - client = httpClient, - url = "$baseUrl/rpcs/tests/send-object", - method = HttpMethod.Post, - params = params, - headers = headers?.invoke(), - ).execute() - if (response.headers["Content-Type"] != "application/json") { + try { + val response = __prepareRequest( + client = httpClient, + url = "$baseUrl/rpcs/tests/send-object", + method = HttpMethod.Post, + params = params, + headers = headers?.invoke(), + ).execute() + if (response.headers["Content-Type"] != "application/json") { throw TestClientError( code = 0, errorMessage = "Expected server to return Content-Type \"application/json\". Got \"${response.headers["Content-Type"]}\"", @@ -179,21 +214,26 @@ suspend fun deprecatedRpc(params: DeprecatedRpcParams): Unit { stack = null, ) } - if (response.status.value in 200..299) { - return ObjectWithEveryType.fromJson(response.bodyAsText()) + if (response.status.value in 200..299) { + return ObjectWithEveryType.fromJson(response.bodyAsText()) + } + throw TestClientError.fromJson(response.bodyAsText()) + } catch (e: Exception) { + onError(e) + throw e } - throw TestClientError.fromJson(response.bodyAsText()) } suspend fun sendObjectWithNullableFields(params: ObjectWithEveryNullableType): ObjectWithEveryNullableType { - val response = __prepareRequest( - client = httpClient, - url = "$baseUrl/rpcs/tests/send-object-with-nullable-fields", - method = HttpMethod.Post, - params = params, - headers = headers?.invoke(), - ).execute() - if (response.headers["Content-Type"] != "application/json") { + try { + val response = __prepareRequest( + client = httpClient, + url = "$baseUrl/rpcs/tests/send-object-with-nullable-fields", + method = HttpMethod.Post, + params = params, + headers = headers?.invoke(), + ).execute() + if (response.headers["Content-Type"] != "application/json") { throw TestClientError( code = 0, errorMessage = "Expected server to return Content-Type \"application/json\". Got \"${response.headers["Content-Type"]}\"", @@ -201,21 +241,26 @@ suspend fun deprecatedRpc(params: DeprecatedRpcParams): Unit { stack = null, ) } - if (response.status.value in 200..299) { - return ObjectWithEveryNullableType.fromJson(response.bodyAsText()) + if (response.status.value in 200..299) { + return ObjectWithEveryNullableType.fromJson(response.bodyAsText()) + } + throw TestClientError.fromJson(response.bodyAsText()) + } catch (e: Exception) { + onError(e) + throw e } - throw TestClientError.fromJson(response.bodyAsText()) } suspend fun sendObjectWithPascalCaseKeys(params: ObjectWithPascalCaseKeys): ObjectWithPascalCaseKeys { - val response = __prepareRequest( - client = httpClient, - url = "$baseUrl/rpcs/tests/send-object-with-pascal-case-keys", - method = HttpMethod.Post, - params = params, - headers = headers?.invoke(), - ).execute() - if (response.headers["Content-Type"] != "application/json") { + try { + val response = __prepareRequest( + client = httpClient, + url = "$baseUrl/rpcs/tests/send-object-with-pascal-case-keys", + method = HttpMethod.Post, + params = params, + headers = headers?.invoke(), + ).execute() + if (response.headers["Content-Type"] != "application/json") { throw TestClientError( code = 0, errorMessage = "Expected server to return Content-Type \"application/json\". Got \"${response.headers["Content-Type"]}\"", @@ -223,21 +268,26 @@ suspend fun deprecatedRpc(params: DeprecatedRpcParams): Unit { stack = null, ) } - if (response.status.value in 200..299) { - return ObjectWithPascalCaseKeys.fromJson(response.bodyAsText()) + if (response.status.value in 200..299) { + return ObjectWithPascalCaseKeys.fromJson(response.bodyAsText()) + } + throw TestClientError.fromJson(response.bodyAsText()) + } catch (e: Exception) { + onError(e) + throw e } - throw TestClientError.fromJson(response.bodyAsText()) } suspend fun sendObjectWithSnakeCaseKeys(params: ObjectWithSnakeCaseKeys): ObjectWithSnakeCaseKeys { - val response = __prepareRequest( - client = httpClient, - url = "$baseUrl/rpcs/tests/send-object-with-snake-case-keys", - method = HttpMethod.Post, - params = params, - headers = headers?.invoke(), - ).execute() - if (response.headers["Content-Type"] != "application/json") { + try { + val response = __prepareRequest( + client = httpClient, + url = "$baseUrl/rpcs/tests/send-object-with-snake-case-keys", + method = HttpMethod.Post, + params = params, + headers = headers?.invoke(), + ).execute() + if (response.headers["Content-Type"] != "application/json") { throw TestClientError( code = 0, errorMessage = "Expected server to return Content-Type \"application/json\". Got \"${response.headers["Content-Type"]}\"", @@ -245,21 +295,26 @@ suspend fun deprecatedRpc(params: DeprecatedRpcParams): Unit { stack = null, ) } - if (response.status.value in 200..299) { - return ObjectWithSnakeCaseKeys.fromJson(response.bodyAsText()) + if (response.status.value in 200..299) { + return ObjectWithSnakeCaseKeys.fromJson(response.bodyAsText()) + } + throw TestClientError.fromJson(response.bodyAsText()) + } catch (e: Exception) { + onError(e) + throw e } - throw TestClientError.fromJson(response.bodyAsText()) } suspend fun sendPartialObject(params: ObjectWithEveryOptionalType): ObjectWithEveryOptionalType { - val response = __prepareRequest( - client = httpClient, - url = "$baseUrl/rpcs/tests/send-partial-object", - method = HttpMethod.Post, - params = params, - headers = headers?.invoke(), - ).execute() - if (response.headers["Content-Type"] != "application/json") { + try { + val response = __prepareRequest( + client = httpClient, + url = "$baseUrl/rpcs/tests/send-partial-object", + method = HttpMethod.Post, + params = params, + headers = headers?.invoke(), + ).execute() + if (response.headers["Content-Type"] != "application/json") { throw TestClientError( code = 0, errorMessage = "Expected server to return Content-Type \"application/json\". Got \"${response.headers["Content-Type"]}\"", @@ -267,21 +322,26 @@ suspend fun deprecatedRpc(params: DeprecatedRpcParams): Unit { stack = null, ) } - if (response.status.value in 200..299) { - return ObjectWithEveryOptionalType.fromJson(response.bodyAsText()) + if (response.status.value in 200..299) { + return ObjectWithEveryOptionalType.fromJson(response.bodyAsText()) + } + throw TestClientError.fromJson(response.bodyAsText()) + } catch (e: Exception) { + onError(e) + throw e } - throw TestClientError.fromJson(response.bodyAsText()) } suspend fun sendRecursiveObject(params: RecursiveObject): RecursiveObject { - val response = __prepareRequest( - client = httpClient, - url = "$baseUrl/rpcs/tests/send-recursive-object", - method = HttpMethod.Post, - params = params, - headers = headers?.invoke(), - ).execute() - if (response.headers["Content-Type"] != "application/json") { + try { + val response = __prepareRequest( + client = httpClient, + url = "$baseUrl/rpcs/tests/send-recursive-object", + method = HttpMethod.Post, + params = params, + headers = headers?.invoke(), + ).execute() + if (response.headers["Content-Type"] != "application/json") { throw TestClientError( code = 0, errorMessage = "Expected server to return Content-Type \"application/json\". Got \"${response.headers["Content-Type"]}\"", @@ -289,21 +349,26 @@ suspend fun deprecatedRpc(params: DeprecatedRpcParams): Unit { stack = null, ) } - if (response.status.value in 200..299) { - return RecursiveObject.fromJson(response.bodyAsText()) + if (response.status.value in 200..299) { + return RecursiveObject.fromJson(response.bodyAsText()) + } + throw TestClientError.fromJson(response.bodyAsText()) + } catch (e: Exception) { + onError(e) + throw e } - throw TestClientError.fromJson(response.bodyAsText()) } suspend fun sendRecursiveUnion(params: RecursiveUnion): RecursiveUnion { - val response = __prepareRequest( - client = httpClient, - url = "$baseUrl/rpcs/tests/send-recursive-union", - method = HttpMethod.Post, - params = params, - headers = headers?.invoke(), - ).execute() - if (response.headers["Content-Type"] != "application/json") { + try { + val response = __prepareRequest( + client = httpClient, + url = "$baseUrl/rpcs/tests/send-recursive-union", + method = HttpMethod.Post, + params = params, + headers = headers?.invoke(), + ).execute() + if (response.headers["Content-Type"] != "application/json") { throw TestClientError( code = 0, errorMessage = "Expected server to return Content-Type \"application/json\". Got \"${response.headers["Content-Type"]}\"", @@ -311,10 +376,14 @@ suspend fun deprecatedRpc(params: DeprecatedRpcParams): Unit { stack = null, ) } - if (response.status.value in 200..299) { - return RecursiveUnion.fromJson(response.bodyAsText()) + if (response.status.value in 200..299) { + return RecursiveUnion.fromJson(response.bodyAsText()) + } + throw TestClientError.fromJson(response.bodyAsText()) + } catch (e: Exception) { + onError(e) + throw e } - throw TestClientError.fromJson(response.bodyAsText()) } suspend fun streamAutoReconnect( @@ -340,6 +409,7 @@ suspend fun deprecatedRpc(params: DeprecatedRpcParams): Unit { bufferCapacity = bufferCapacity, onOpen = onOpen, onClose = onClose, + onError = onError, onRequestError = onRequestError, onResponseError = onResponseError, onData = { str -> @@ -375,6 +445,7 @@ suspend fun streamConnectionErrorTest( bufferCapacity = bufferCapacity, onOpen = onOpen, onClose = onClose, + onError = onError, onRequestError = onRequestError, onResponseError = onResponseError, onData = { str -> @@ -410,6 +481,7 @@ suspend fun streamLargeObjects( bufferCapacity = bufferCapacity, onOpen = onOpen, onClose = onClose, + onError = onError, onRequestError = onRequestError, onResponseError = onResponseError, onData = { str -> @@ -442,6 +514,7 @@ suspend fun streamLargeObjects( bufferCapacity = bufferCapacity, onOpen = onOpen, onClose = onClose, + onError = onError, onRequestError = onRequestError, onResponseError = onResponseError, onData = { str -> @@ -474,6 +547,7 @@ suspend fun streamLargeObjects( bufferCapacity = bufferCapacity, onOpen = onOpen, onClose = onClose, + onError = onError, onRequestError = onRequestError, onResponseError = onResponseError, onData = { str -> @@ -509,6 +583,7 @@ suspend fun streamTenEventsThenEnd( bufferCapacity = bufferCapacity, onOpen = onOpen, onClose = onClose, + onError = onError, onRequestError = onRequestError, onResponseError = onResponseError, onData = { str -> @@ -525,6 +600,7 @@ class TestClientUsersService( private val httpClient: HttpClient, private val baseUrl: String, private val headers: headersFn, + private val onError: ((err: Exception) -> Unit) = {}, ) { suspend fun watchUser( params: UsersWatchUserParams, @@ -549,6 +625,7 @@ class TestClientUsersService( bufferCapacity = bufferCapacity, onOpen = onOpen, onClose = onClose, + onError = onError, onRequestError = onRequestError, onResponseError = onResponseError, onData = { str -> @@ -6214,8 +6291,9 @@ private suspend fun __handleSseRequest( onOpen: ((response: HttpResponse) -> Unit) = {}, onClose: (() -> Unit) = {}, onData: ((data: String) -> Unit) = {}, - onRequestError: ((error: Exception) -> Unit) = {}, - onResponseError: ((error: TestClientError) -> Unit) = {}, + onError: ((err: Exception) -> Unit) = {}, + onRequestError: ((err: Exception) -> Unit) = {}, + onResponseError: ((err: TestClientError) -> Unit) = {}, bufferCapacity: Int, ) { val finalHeaders = headers?.invoke() ?: mutableMapOf() @@ -6250,18 +6328,18 @@ private suspend fun __handleSseRequest( if (httpResponse.status.value !in 200..299) { try { if (httpResponse.headers["Content-Type"] == "application/json") { - onResponseError( - TestClientError.fromJson(httpResponse.bodyAsText()) - ) + val err = TestClientError.fromJson(httpResponse.bodyAsText()) + onError(err) + onResponseError(err) } else { - onResponseError( - TestClientError( - code = httpResponse.status.value, - errorMessage = httpResponse.status.description, - data = JsonPrimitive(httpResponse.bodyAsText()), - stack = null, - ) + val err = TestClientError( + code = httpResponse.status.value, + errorMessage = httpResponse.status.description, + data = JsonPrimitive(httpResponse.bodyAsText()), + stack = null, ) + onError(err) + onResponseError(err) } } catch (e: CancellationException) { onClose() @@ -6281,19 +6359,21 @@ private suspend fun __handleSseRequest( onOpen = onOpen, onClose = onClose, onData = onData, + onError = onError, + onRequestError = onRequestError, onResponseError = onResponseError, ) } if (httpResponse.headers["Content-Type"] != "text/event-stream") { try { - onResponseError( - TestClientError( - code = 0, - errorMessage = "Expected server to return Content-Type \"text/event-stream\". Got \"${httpResponse.headers["Content-Type"]}\"", - data = JsonPrimitive(httpResponse.bodyAsText()), - stack = null, - ) + val err = TestClientError( + code = 0, + errorMessage = "Expected server to return Content-Type \"text/event-stream\". Got \"${httpResponse.headers["Content-Type"]}\"", + data = JsonPrimitive(httpResponse.bodyAsText()), + stack = null, ) + onError(err) + onResponseError(err) } catch (e: CancellationException) { httpResponse.cancel() return@execute @@ -6311,6 +6391,8 @@ private suspend fun __handleSseRequest( onOpen = onOpen, onClose = onClose, onData = onData, + onError = onError, + onRequestError = onRequestError, onResponseError = onResponseError, ) } @@ -6362,10 +6444,13 @@ private suspend fun __handleSseRequest( onOpen = onOpen, onClose = onClose, onData = onData, + onError = onError, + onRequestError = onRequestError, onResponseError = onResponseError, ) } } catch (e: java.net.ConnectException) { + onError(e) onRequestError(e) return __handleSseRequest( httpClient = httpClient, @@ -6380,9 +6465,12 @@ private suspend fun __handleSseRequest( onOpen = onOpen, onClose = onClose, onData = onData, + onError = onError, + onRequestError = onRequestError, onResponseError = onResponseError, ) } catch (e: Exception) { + onError(e) onRequestError(e) return __handleSseRequest( httpClient = httpClient, @@ -6397,6 +6485,8 @@ private suspend fun __handleSseRequest( onOpen = onOpen, onClose = onClose, onData = onData, + onError = onError, + onRequestError = onRequestError, onResponseError = onResponseError, ) } diff --git a/tests/clients/swift/Sources/TestClient.g.swift b/tests/clients/swift/Sources/TestClient.g.swift index e5a895d2..1f202e0f 100644 --- a/tests/clients/swift/Sources/TestClient.g.swift +++ b/tests/clients/swift/Sources/TestClient.g.swift @@ -6,25 +6,30 @@ public class TestClient { let baseURL: String let delegate: ArriRequestDelegate let headers: () -> Dictionary + let onError: (Error) -> Void public let tests: TestClientTestsService public let users: TestClientUsersService public init( baseURL: String, delegate: ArriRequestDelegate, - headers: @escaping () -> Dictionary + headers: @escaping () -> Dictionary, + onError: @escaping ((Error) -> Void) = { _ -> Void in } ) { self.baseURL = baseURL self.delegate = delegate self.headers = headers + self.onError = onError self.tests = TestClientTestsService( baseURL: baseURL, delegate: delegate, - headers: headers + headers: headers, + onError: onError ) self.users = TestClientUsersService( baseURL: baseURL, delegate: delegate, - headers: headers + headers: headers, + onError: onError ) } @@ -36,15 +41,18 @@ public class TestClientTestsService { let baseURL: String let delegate: ArriRequestDelegate let headers: () -> Dictionary + let onError: (Error) -> Void public init( baseURL: String, delegate: ArriRequestDelegate, - headers: @escaping () -> Dictionary + headers: @escaping () -> Dictionary, + onError: @escaping ((Error) -> Void) = { _ -> Void in } ) { self.baseURL = baseURL self.delegate = delegate self.headers = headers + self.onError = onError } public func emptyParamsGetRequest() async throws -> DefaultPayload { @@ -54,7 +62,8 @@ public class TestClientTestsService { method: "GET", headers: self.headers, clientVersion: "10", - params: EmptyArriModel() + params: EmptyArriModel(), + onError: onError ) return result } @@ -65,7 +74,8 @@ public class TestClientTestsService { method: "POST", headers: self.headers, clientVersion: "10", - params: EmptyArriModel() + params: EmptyArriModel(), + onError: onError ) return result } @@ -76,7 +86,8 @@ public class TestClientTestsService { method: "GET", headers: self.headers, clientVersion: "10", - params: params + params: params, + onError: onError ) } @@ -87,7 +98,8 @@ public class TestClientTestsService { method: "POST", headers: self.headers, clientVersion: "10", - params: params + params: params, + onError: onError ) } @@ -100,7 +112,8 @@ public class TestClientTestsService { method: "POST", headers: self.headers, clientVersion: "10", - params: params + params: params, + onError: onError ) } @@ -111,7 +124,8 @@ public class TestClientTestsService { method: "POST", headers: self.headers, clientVersion: "10", - params: params + params: params, + onError: onError ) } @@ -122,7 +136,8 @@ public class TestClientTestsService { method: "POST", headers: self.headers, clientVersion: "10", - params: params + params: params, + onError: onError ) return result } @@ -133,7 +148,8 @@ public class TestClientTestsService { method: "POST", headers: self.headers, clientVersion: "10", - params: params + params: params, + onError: onError ) return result } @@ -144,7 +160,8 @@ public class TestClientTestsService { method: "POST", headers: self.headers, clientVersion: "10", - params: params + params: params, + onError: onError ) return result } @@ -155,7 +172,8 @@ public class TestClientTestsService { method: "POST", headers: self.headers, clientVersion: "10", - params: params + params: params, + onError: onError ) return result } @@ -166,7 +184,8 @@ public class TestClientTestsService { method: "POST", headers: self.headers, clientVersion: "10", - params: params + params: params, + onError: onError ) return result } @@ -177,7 +196,8 @@ public class TestClientTestsService { method: "POST", headers: self.headers, clientVersion: "10", - params: params + params: params, + onError: onError ) return result } @@ -188,7 +208,8 @@ public class TestClientTestsService { method: "POST", headers: self.headers, clientVersion: "10", - params: params + params: params, + onError: onError ) return result } @@ -294,15 +315,18 @@ public class TestClientUsersService { let baseURL: String let delegate: ArriRequestDelegate let headers: () -> Dictionary + let onError: (Error) -> Void public init( baseURL: String, delegate: ArriRequestDelegate, - headers: @escaping () -> Dictionary + headers: @escaping () -> Dictionary, + onError: @escaping ((Error) -> Void) = { _ -> Void in } ) { self.baseURL = baseURL self.delegate = delegate self.headers = headers + self.onError = onError } public func watchUser(_ params: UsersWatchUserParams, options: EventSourceOptions) -> Task<(), Never> { diff --git a/tests/clients/swift/Tests/TestClientTests.swift b/tests/clients/swift/Tests/TestClientTests.swift index 50f42dc2..c5d706f9 100644 --- a/tests/clients/swift/Tests/TestClientTests.swift +++ b/tests/clients/swift/Tests/TestClientTests.swift @@ -3,6 +3,7 @@ import ArriClient @testable import TestClientSwift final class TestSwiftClientTests: XCTestCase { + let baseUrl = "http://localhost:2020" let client = TestClient( baseURL: "http://localhost:2020", delegate: DefaultRequestDelegate(), @@ -12,13 +13,7 @@ final class TestSwiftClientTests: XCTestCase { return headers } ) - let unauthenticatedClient = TestClient( - baseURL: "http://localhost:2020", - delegate: DefaultRequestDelegate(), - headers: { - return Dictionary() - } - ) + let testDate = Date(timeIntervalSince1970: 500000) func testSendObject() async throws { let input = ObjectWithEveryType( @@ -180,6 +175,17 @@ final class TestSwiftClientTests: XCTestCase { } func testSendUnauthenticatedRequest() async { + var firedOnError = false + let unauthenticatedClient = TestClient( + baseURL: baseUrl, + delegate: DefaultRequestDelegate(), + headers: { + return Dictionary() + }, + onError: { _ in + firedOnError = true + } + ) var didError = false do { let _ = try await unauthenticatedClient.tests.sendObject(ObjectWithEveryType()) @@ -192,6 +198,7 @@ final class TestSwiftClientTests: XCTestCase { } } XCTAssert(didError) + XCTAssert(firedOnError) } func testRpcWithNoParams() async throws { diff --git a/tests/clients/ts/testClient.rpc.ts b/tests/clients/ts/testClient.rpc.ts index 3b936745..871a9fe4 100644 --- a/tests/clients/ts/testClient.rpc.ts +++ b/tests/clients/ts/testClient.rpc.ts @@ -32,16 +32,19 @@ export class TestClient { private readonly _headers: | HeaderMap | (() => HeaderMap | Promise); + private readonly _onError?: (err: unknown) => void; tests: TestClientTestsService; users: TestClientUsersService; constructor( options: { baseUrl?: string; headers?: HeaderMap | (() => HeaderMap | Promise); + onError?: (err: unknown) => void; } = {}, ) { this._baseUrl = options.baseUrl ?? ""; this._headers = options.headers ?? {}; + this._onError = options.onError; this.tests = new TestClientTestsService(options); this.users = new TestClientUsersService(options); } @@ -52,21 +55,25 @@ export class TestClientTestsService { private readonly _headers: | HeaderMap | (() => HeaderMap | Promise); + private readonly _onError?: (err: unknown) => void; constructor( options: { baseUrl?: string; headers?: HeaderMap | (() => HeaderMap | Promise); + onError?: (err: unknown) => void; } = {}, ) { this._baseUrl = options.baseUrl ?? ""; this._headers = options.headers ?? {}; + this._onError = options.onError; } async emptyParamsGetRequest(): Promise { return arriRequest({ url: `${this._baseUrl}/rpcs/tests/empty-params-get-request`, method: "get", headers: this._headers, + onError: this._onError, responseFromJson: $$DefaultPayload.fromJson, responseFromString: $$DefaultPayload.fromJsonString, @@ -79,6 +86,7 @@ export class TestClientTestsService { url: `${this._baseUrl}/rpcs/tests/empty-params-post-request`, method: "post", headers: this._headers, + onError: this._onError, responseFromJson: $$DefaultPayload.fromJson, responseFromString: $$DefaultPayload.fromJsonString, @@ -91,6 +99,7 @@ export class TestClientTestsService { url: `${this._baseUrl}/rpcs/tests/empty-response-get-request`, method: "get", headers: this._headers, + onError: this._onError, params: params, responseFromJson: () => {}, responseFromString: () => {}, @@ -103,6 +112,7 @@ export class TestClientTestsService { url: `${this._baseUrl}/rpcs/tests/empty-response-post-request`, method: "post", headers: this._headers, + onError: this._onError, params: params, responseFromJson: () => {}, responseFromString: () => {}, @@ -119,6 +129,7 @@ export class TestClientTestsService { url: `${this._baseUrl}/rpcs/tests/deprecated-rpc`, method: "post", headers: this._headers, + onError: this._onError, params: params, responseFromJson: () => {}, responseFromString: () => {}, @@ -131,6 +142,7 @@ export class TestClientTestsService { url: `${this._baseUrl}/rpcs/tests/send-error`, method: "post", headers: this._headers, + onError: this._onError, params: params, responseFromJson: () => {}, responseFromString: () => {}, @@ -145,6 +157,7 @@ export class TestClientTestsService { url: `${this._baseUrl}/rpcs/tests/send-object`, method: "post", headers: this._headers, + onError: this._onError, params: params, responseFromJson: $$ObjectWithEveryType.fromJson, responseFromString: $$ObjectWithEveryType.fromJsonString, @@ -162,6 +175,7 @@ export class TestClientTestsService { url: `${this._baseUrl}/rpcs/tests/send-object-with-nullable-fields`, method: "post", headers: this._headers, + onError: this._onError, params: params, responseFromJson: $$ObjectWithEveryNullableType.fromJson, responseFromString: $$ObjectWithEveryNullableType.fromJsonString, @@ -176,6 +190,7 @@ export class TestClientTestsService { url: `${this._baseUrl}/rpcs/tests/send-object-with-pascal-case-keys`, method: "post", headers: this._headers, + onError: this._onError, params: params, responseFromJson: $$ObjectWithPascalCaseKeys.fromJson, responseFromString: $$ObjectWithPascalCaseKeys.fromJsonString, @@ -190,6 +205,7 @@ export class TestClientTestsService { url: `${this._baseUrl}/rpcs/tests/send-object-with-snake-case-keys`, method: "post", headers: this._headers, + onError: this._onError, params: params, responseFromJson: $$ObjectWithSnakeCaseKeys.fromJson, responseFromString: $$ObjectWithSnakeCaseKeys.fromJsonString, @@ -207,6 +223,7 @@ export class TestClientTestsService { url: `${this._baseUrl}/rpcs/tests/send-partial-object`, method: "post", headers: this._headers, + onError: this._onError, params: params, responseFromJson: $$ObjectWithEveryOptionalType.fromJson, responseFromString: $$ObjectWithEveryOptionalType.fromJsonString, @@ -221,6 +238,7 @@ export class TestClientTestsService { url: `${this._baseUrl}/rpcs/tests/send-recursive-object`, method: "post", headers: this._headers, + onError: this._onError, params: params, responseFromJson: $$RecursiveObject.fromJson, responseFromString: $$RecursiveObject.fromJsonString, @@ -233,6 +251,7 @@ export class TestClientTestsService { url: `${this._baseUrl}/rpcs/tests/send-recursive-union`, method: "post", headers: this._headers, + onError: this._onError, params: params, responseFromJson: $$RecursiveUnion.fromJson, responseFromString: $$RecursiveUnion.fromJsonString, @@ -249,6 +268,7 @@ export class TestClientTestsService { url: `${this._baseUrl}/rpcs/tests/stream-auto-reconnect`, method: "get", headers: this._headers, + onError: this._onError, params: params, responseFromJson: $$AutoReconnectResponse.fromJson, responseFromString: $$AutoReconnectResponse.fromJsonString, @@ -273,6 +293,7 @@ export class TestClientTestsService { url: `${this._baseUrl}/rpcs/tests/stream-connection-error-test`, method: "get", headers: this._headers, + onError: this._onError, params: params, responseFromJson: $$StreamConnectionErrorTestResponse.fromJson, responseFromString: @@ -294,6 +315,7 @@ export class TestClientTestsService { url: `${this._baseUrl}/rpcs/tests/stream-large-objects`, method: "get", headers: this._headers, + onError: this._onError, responseFromJson: $$StreamLargeObjectsResponse.fromJson, responseFromString: $$StreamLargeObjectsResponse.fromJsonString, @@ -312,6 +334,7 @@ export class TestClientTestsService { url: `${this._baseUrl}/rpcs/tests/stream-messages`, method: "get", headers: this._headers, + onError: this._onError, params: params, responseFromJson: $$ChatMessage.fromJson, responseFromString: $$ChatMessage.fromJsonString, @@ -332,6 +355,7 @@ export class TestClientTestsService { url: `${this._baseUrl}/rpcs/tests/stream-retry-with-new-credentials`, method: "get", headers: this._headers, + onError: this._onError, responseFromJson: $$TestsStreamRetryWithNewCredentialsResponse.fromJson, @@ -354,6 +378,7 @@ export class TestClientTestsService { url: `${this._baseUrl}/rpcs/tests/stream-ten-events-then-end`, method: "get", headers: this._headers, + onError: this._onError, responseFromJson: $$ChatMessage.fromJson, responseFromString: $$ChatMessage.fromJsonString, @@ -370,15 +395,18 @@ export class TestClientUsersService { private readonly _headers: | HeaderMap | (() => HeaderMap | Promise); + private readonly _onError?: (err: unknown) => void; constructor( options: { baseUrl?: string; headers?: HeaderMap | (() => HeaderMap | Promise); + onError?: (err: unknown) => void; } = {}, ) { this._baseUrl = options.baseUrl ?? ""; this._headers = options.headers ?? {}; + this._onError = options.onError; } watchUser( params: UsersWatchUserParams, @@ -389,6 +417,7 @@ export class TestClientUsersService { url: `${this._baseUrl}/rpcs/users/watch-user`, method: "get", headers: this._headers, + onError: this._onError, params: params, responseFromJson: $$UsersWatchUserResponse.fromJson, responseFromString: $$UsersWatchUserResponse.fromJsonString, diff --git a/tests/clients/ts/testClient.test.ts b/tests/clients/ts/testClient.test.ts index b333b094..f49f0255 100644 --- a/tests/clients/ts/testClient.test.ts +++ b/tests/clients/ts/testClient.test.ts @@ -31,9 +31,6 @@ const client = new TestClient({ baseUrl, headers, }); -const unauthenticatedClient = new TestClient({ - baseUrl, -}); test("route request", async () => { const result = await ofetch("/routes/hello-world", { @@ -173,6 +170,13 @@ test("returns error if sending nothing when RPC expects body", async () => { } }); test("unauthenticated RPC request returns a 401 error", async () => { + let firedOnErr = false; + const unauthenticatedClient = new TestClient({ + baseUrl, + onError(_) { + firedOnErr = true; + }, + }); try { await unauthenticatedClient.tests.sendObject(input); expect(true).toBe(false); @@ -182,6 +186,7 @@ test("unauthenticated RPC request returns a 401 error", async () => { expect(err.code).toBe(401); } } + expect(firedOnErr).toBe(true); }); test("can use async functions for headers", async () => { const _client = new TestClient({ @@ -328,6 +333,23 @@ test("can send/receive recursive unions", async () => { expect(result).toStrictEqual(payload); }); +test("onError hook fires properly", async () => { + let onErrorFired = false; + const customClient = new TestClient({ + baseUrl, + onError(err) { + onErrorFired = true; + expect(err instanceof ArriErrorInstance).toBe(true); + }, + }); + try { + await customClient.tests.sendObject(input); + } catch (_) { + // do nothing + } + expect(onErrorFired).toBe(true); +}); + test("[SSE] supports server sent events", async () => { let wasConnected = false; let receivedMessageCount = 0;