diff --git a/misk-actions/src/main/kotlin/misk/exceptions/Exceptions.kt b/misk-actions/src/main/kotlin/misk/exceptions/Exceptions.kt index 1ebce348d50..b8215410178 100644 --- a/misk-actions/src/main/kotlin/misk/exceptions/Exceptions.kt +++ b/misk-actions/src/main/kotlin/misk/exceptions/Exceptions.kt @@ -27,19 +27,25 @@ open class WebActionException( ) : Exception(message, cause) { val isClientError = code in (400..499) val isServerError = code in (500..599) + + constructor( + code: Int, + message: String, + cause: Throwable? = null + ) : this(code, message, message, cause) } /** Base exception for when resources are not found */ open class NotFoundException(message: String = "", cause: Throwable? = null) : - WebActionException(HTTP_NOT_FOUND, message, message, cause) + WebActionException(HTTP_NOT_FOUND, message, cause) /** Base exception for when authentication fails */ open class UnauthenticatedException(message: String = "", cause: Throwable? = null) : - WebActionException(HTTP_UNAUTHORIZED, message, message, cause) + WebActionException(HTTP_UNAUTHORIZED, message, cause) /** Base exception for when authenticated credentials lack access to a resource */ open class UnauthorizedException(message: String = "", cause: Throwable? = null) : - WebActionException(HTTP_FORBIDDEN, message, message, cause) + WebActionException(HTTP_FORBIDDEN, message, cause) /** * Base exception for when a resource is unavailable. The message is not exposed to the caller. @@ -49,19 +55,19 @@ open class ResourceUnavailableException(message: String = "", cause: Throwable? /** Base exception for bad client requests */ open class BadRequestException(message: String = "", cause: Throwable? = null) : - WebActionException(HTTP_BAD_REQUEST, message, message, cause) + WebActionException(HTTP_BAD_REQUEST, message, cause) /** Base exception for when a request causes a conflict */ open class ConflictException(message: String = "", cause: Throwable? = null) : - WebActionException(HTTP_CONFLICT, message, message, cause) + WebActionException(HTTP_CONFLICT, message, cause) /** This exception is custom to Misk. */ open class UnprocessableEntityException(message: String = "", cause: Throwable? = null) : - WebActionException(422, message, message, cause) + WebActionException(422, message, cause) /** This exception is custom to Misk. */ open class TooManyRequestsException(message: String = "", cause: Throwable? = null) : - WebActionException(429, message, message, cause) + WebActionException(429, message, cause) /** This exception is custom to Misk. */ open class ClientClosedRequestException(message: String = "", cause: Throwable? = null) : @@ -75,10 +81,10 @@ open class GatewayTimeoutException(message: String = "", cause: Throwable? = nul WebActionException(HTTP_GATEWAY_TIMEOUT, "GATEWAY_TIMEOUT", message, cause) open class PayloadTooLargeException(message: String = "", cause: Throwable? = null) : - WebActionException(HTTP_ENTITY_TOO_LARGE, message, message, cause) + WebActionException(HTTP_ENTITY_TOO_LARGE, message, cause) open class UnsupportedMediaTypeException(message: String = "", cause: Throwable? = null) : - WebActionException(HTTP_UNSUPPORTED_TYPE, message, message, cause) + WebActionException(HTTP_UNSUPPORTED_TYPE, message, cause) /** Similar to [kotlin.require], but throws [BadRequestException] if the check fails */ inline fun requireRequest(check: Boolean, lazyMessage: () -> String) { diff --git a/misk-grpc-tests/build.gradle.kts b/misk-grpc-tests/build.gradle.kts index a872d491411..a0f1cda1ab6 100644 --- a/misk-grpc-tests/build.gradle.kts +++ b/misk-grpc-tests/build.gradle.kts @@ -62,6 +62,7 @@ sourceSets { dependencies { implementation(Dependencies.assertj) + implementation(Dependencies.awaitility) implementation(Dependencies.junitApi) implementation(Dependencies.kotlinTest) implementation(Dependencies.docker) @@ -76,6 +77,8 @@ dependencies { implementation(project(":misk-actions")) implementation(project(":misk-core")) implementation(project(":misk-inject")) + implementation(project(":misk-metrics")) + implementation(project(":misk-metrics-testing")) implementation(project(":misk-service")) implementation(project(":misk-testing")) diff --git a/misk-grpc-tests/src/main/kotlin/misk/grpc/miskserver/GetFeatureGrpcAction.kt b/misk-grpc-tests/src/main/kotlin/misk/grpc/miskserver/GetFeatureGrpcAction.kt index 2340875f3d5..e3de3733556 100644 --- a/misk-grpc-tests/src/main/kotlin/misk/grpc/miskserver/GetFeatureGrpcAction.kt +++ b/misk-grpc-tests/src/main/kotlin/misk/grpc/miskserver/GetFeatureGrpcAction.kt @@ -1,5 +1,6 @@ package misk.grpc.miskserver +import misk.exceptions.WebActionException import misk.web.actions.WebAction import misk.web.interceptors.LogRequestResponse import routeguide.Feature @@ -10,6 +11,9 @@ import javax.inject.Inject class GetFeatureGrpcAction @Inject constructor() : WebAction, RouteGuideGetFeatureBlockingServer { @LogRequestResponse(bodySampling = 1.0, errorBodySampling = 1.0) override fun GetFeature(request: Point): Feature { + if (request.latitude == -1) { + throw WebActionException(request.longitude ?: 500, "unexpected latitude error!") + } return Feature(name = "maple tree", location = request) } } diff --git a/misk-grpc-tests/src/main/kotlin/misk/grpc/miskserver/RouteGuideMiskServiceModule.kt b/misk-grpc-tests/src/main/kotlin/misk/grpc/miskserver/RouteGuideMiskServiceModule.kt index 01a4c2cd2ff..b87f282970b 100644 --- a/misk-grpc-tests/src/main/kotlin/misk/grpc/miskserver/RouteGuideMiskServiceModule.kt +++ b/misk-grpc-tests/src/main/kotlin/misk/grpc/miskserver/RouteGuideMiskServiceModule.kt @@ -1,8 +1,10 @@ package misk.grpc.miskserver import com.google.inject.Provides +import com.google.inject.util.Modules import misk.MiskTestingServiceModule import misk.inject.KAbstractModule +import misk.metrics.FakeMetricsModule import misk.web.WebActionModule import misk.web.WebServerTestingModule import misk.web.jetty.JettyService @@ -12,7 +14,7 @@ import javax.inject.Named class RouteGuideMiskServiceModule : KAbstractModule() { override fun configure() { install(WebServerTestingModule(webConfig = WebServerTestingModule.TESTING_WEB_CONFIG)) - install(MiskTestingServiceModule()) + install(Modules.override(MiskTestingServiceModule()).with(FakeMetricsModule())) install(WebActionModule.create()) install(WebActionModule.create()) } diff --git a/misk-grpc-tests/src/test/kotlin/misk/grpc/MiskClientMiskServerTest.kt b/misk-grpc-tests/src/test/kotlin/misk/grpc/MiskClientMiskServerTest.kt index 92b019cfb4f..2773eb65a28 100644 --- a/misk-grpc-tests/src/test/kotlin/misk/grpc/MiskClientMiskServerTest.kt +++ b/misk-grpc-tests/src/test/kotlin/misk/grpc/MiskClientMiskServerTest.kt @@ -2,8 +2,8 @@ package misk.grpc import com.google.inject.Guice import com.google.inject.util.Modules -import javax.inject.Inject -import javax.inject.Named +import com.squareup.wire.GrpcException +import com.squareup.wire.GrpcStatus import kotlinx.coroutines.GlobalScope import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.runBlocking @@ -13,11 +13,20 @@ import misk.grpc.miskserver.RouteChatGrpcAction import misk.grpc.miskserver.RouteGuideMiskServiceModule import misk.inject.getInstance import misk.logging.LogCollectorModule +import misk.metrics.FakeMetrics import misk.testing.MiskTest import misk.testing.MiskTestModule import misk.web.interceptors.RequestLoggingInterceptor import okhttp3.HttpUrl import org.assertj.core.api.Assertions.assertThat +import org.awaitility.Durations.ONE_HUNDRED_MILLISECONDS +import org.awaitility.Durations.ONE_MILLISECOND +import org.awaitility.kotlin.atMost +import org.awaitility.kotlin.await +import org.awaitility.kotlin.matches +import org.awaitility.kotlin.untilCallTo +import org.awaitility.kotlin.withPollDelay +import org.awaitility.kotlin.withPollInterval import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import routeguide.Feature @@ -25,6 +34,9 @@ import routeguide.Point import routeguide.RouteGuideClient import routeguide.RouteNote import wisp.logging.LogCollector +import javax.inject.Inject +import javax.inject.Named +import kotlin.test.assertFailsWith @MiskTest(startService = true) class MiskClientMiskServerTest { @@ -37,6 +49,7 @@ class MiskClientMiskServerTest { @Inject lateinit var logCollector: LogCollector @Inject lateinit var routeChatGrpcAction: RouteChatGrpcAction @Inject @field:Named("grpc server") lateinit var serverUrl: HttpUrl + @Inject lateinit var metrics: FakeMetrics private lateinit var routeGuide: RouteGuideClient private lateinit var callCounter: RouteGuideCallCounter @@ -105,4 +118,60 @@ class MiskClientMiskServerTest { sendChannel.close() } } + + @Test + fun serverFailureGeneric() { + val point = Point( + latitude = -1, + longitude = 500 + ) + + runBlocking { + val e = assertFailsWith { + routeGuide.GetFeature().execute(point) + } + assertThat(e.grpcMessage).isEqualTo("Internal Server Error") + assertThat(e.grpcStatus).isEqualTo(GrpcStatus.UNKNOWN) + + // Assert that _metrics_ counted a 500 and no 200s, even though an HTTP 200 was returned + // over HTTP. The 200 is implicitly asserted by the fact that we got a GrpcException, which + // is only thrown if a properly constructed gRPC error is received. + assertResponseCount(200, 0) + assertResponseCount(500, 1) + } + } + + @Test + fun serverFailureNotFound() { + val point = Point( + latitude = -1, + longitude = 404 + ) + + runBlocking { + val e = assertFailsWith { + routeGuide.GetFeature().execute(point) + } + assertThat(e.grpcMessage).isEqualTo("unexpected latitude error!") + assertThat(e.grpcStatus).isEqualTo(GrpcStatus.UNIMPLEMENTED) + .withFailMessage("wrong gRPC status ${e.grpcStatus.name}") + + // Assert that _metrics_ counted a 404 and no 200s, even though an HTTP 200 was returned + // over HTTP. The 200 is implicitly asserted by the fact that we got a GrpcException, which + // is only thrown if a properly constructed gRPC error is received. + assertResponseCount(200, 0) + assertResponseCount(404, 1) + } + } + + private fun assertResponseCount(code: Int, count: Int) { + await withPollInterval ONE_MILLISECOND atMost ONE_HUNDRED_MILLISECONDS untilCallTo { + metrics.histogramCount( + "http_request_latency_ms", + "action" to "GetFeatureGrpcAction", + "caller" to "unknown", + "code" to code.toString(), + )?.toInt() ?: 0 + } matches { it == count } + } } diff --git a/misk-testing/src/main/kotlin/misk/web/FakeHttpCall.kt b/misk-testing/src/main/kotlin/misk/web/FakeHttpCall.kt index 80e62e60691..b7d4eccf516 100644 --- a/misk-testing/src/main/kotlin/misk/web/FakeHttpCall.kt +++ b/misk-testing/src/main/kotlin/misk/web/FakeHttpCall.kt @@ -16,6 +16,7 @@ data class FakeHttpCall( override val dispatchMechanism: DispatchMechanism = DispatchMechanism.GET, override val requestHeaders: Headers = headersOf(), override var statusCode: Int = 200, + override var networkStatusCode: Int = 200, val headersBuilder: Headers.Builder = Headers.Builder(), var sendTrailers: Boolean = false, val trailersBuilder: Headers.Builder = Headers.Builder(), @@ -28,6 +29,11 @@ data class FakeHttpCall( override val responseHeaders: Headers get() = headersBuilder.build() + override fun setStatusCodes(statusCode: Int, networkStatusCode: Int) { + this.statusCode = statusCode + this.networkStatusCode = networkStatusCode + } + override fun setResponseHeader(name: String, value: String) { headersBuilder.set(name, value) } diff --git a/misk/src/main/kotlin/misk/web/HttpCall.kt b/misk/src/main/kotlin/misk/web/HttpCall.kt index 2f487734928..1a1d488bd4f 100644 --- a/misk/src/main/kotlin/misk/web/HttpCall.kt +++ b/misk/src/main/kotlin/misk/web/HttpCall.kt @@ -30,10 +30,20 @@ interface HttpCall { val dispatchMechanism: DispatchMechanism val requestHeaders: Headers - /** The HTTP response under construction. */ + /** Meaningful HTTP status about what actually happened. Not sent over the wire in the case + * of gRPC, which always returns HTTP 200 even for errors. */ var statusCode: Int + + /** The HTTP status code actually sent over the network. For gRPC, this is always 200, even + * for errors, per the spec. **/ + val networkStatusCode: Int + val responseHeaders: Headers + /** Set both the raw network status code and the meaningful status code that's + * recorded in metrics */ + fun setStatusCodes(statusCode: Int, networkStatusCode: Int) + fun setResponseHeader(name: String, value: String) fun addResponseHeaders(headers: Headers) diff --git a/misk/src/main/kotlin/misk/web/ServletHttpCall.kt b/misk/src/main/kotlin/misk/web/ServletHttpCall.kt index c95f4a60b4d..04d8923a8fa 100644 --- a/misk/src/main/kotlin/misk/web/ServletHttpCall.kt +++ b/misk/src/main/kotlin/misk/web/ServletHttpCall.kt @@ -25,16 +25,26 @@ internal data class ServletHttpCall( var responseBody: BufferedSink? = null, var webSocket: WebSocket? = null ) : HttpCall { + private var _actualStatusCode: Int? = null override var statusCode: Int - get() = upstreamResponse.statusCode + get() = _actualStatusCode ?: upstreamResponse.statusCode set(value) { + _actualStatusCode = value upstreamResponse.statusCode = value } + override val networkStatusCode: Int + get() = upstreamResponse.statusCode + override val responseHeaders: Headers get() = upstreamResponse.headers + override fun setStatusCodes(statusCode: Int, networkStatusCode: Int) { + _actualStatusCode = statusCode + upstreamResponse.statusCode = networkStatusCode + } + override fun setResponseHeader(name: String, value: String) { upstreamResponse.setHeader(name, value) } diff --git a/misk/src/main/kotlin/misk/web/exceptions/ExceptionHandlingInterceptor.kt b/misk/src/main/kotlin/misk/web/exceptions/ExceptionHandlingInterceptor.kt index 608ef956534..92b560ea225 100644 --- a/misk/src/main/kotlin/misk/web/exceptions/ExceptionHandlingInterceptor.kt +++ b/misk/src/main/kotlin/misk/web/exceptions/ExceptionHandlingInterceptor.kt @@ -1,9 +1,14 @@ package misk.web.exceptions import com.google.common.util.concurrent.UncheckedExecutionException +import com.squareup.wire.GrpcStatus +import com.squareup.wire.ProtoAdapter import misk.Action import misk.exceptions.UnauthenticatedException import misk.exceptions.UnauthorizedException +import misk.grpc.GrpcMessageSink +import misk.web.DispatchMechanism +import misk.web.HttpCall import misk.web.NetworkChain import misk.web.NetworkInterceptor import misk.web.Response @@ -11,6 +16,9 @@ import misk.web.ResponseBody import misk.web.mediatype.MediaTypes import misk.web.toResponseBody import okhttp3.Headers.Companion.toHeaders +import okio.Buffer +import okio.BufferedSink +import okio.ByteString import wisp.logging.getLogger import wisp.logging.log import java.lang.reflect.InvocationTargetException @@ -37,13 +45,67 @@ class ExceptionHandlingInterceptor( } catch (th: Throwable) { val response = toResponse(th) chain.httpCall.statusCode = response.statusCode - chain.httpCall.takeResponseBody()?.use { sink -> - chain.httpCall.addResponseHeaders(response.headers) - (response.body as ResponseBody).writeTo(sink) + if (chain.httpCall.dispatchMechanism == DispatchMechanism.GRPC) { + sendGrpcFailure(chain.httpCall, response) + } else { + sendHttpFailure(chain.httpCall, response) } } } + private fun sendHttpFailure(httpCall: HttpCall, response: Response<*>) { + httpCall.takeResponseBody()?.use { sink -> + httpCall.addResponseHeaders(response.headers) + (response.body as ResponseBody).writeTo(sink) + } + } + + /** + * Borrow behavior from [GrpcFeatureBinding] to send a gRPC error with an HTTP 200 status code. + * This is weird but it's how gRPC clients work. + * + * One thing to note is for our metrics we want to pretend that the HTTP code is what we sent. + * Otherwise gRPC requests that crashed and yielded an HTTP 200 code will confuse operators. + */ + private fun sendGrpcFailure(httpCall: HttpCall, response: Response<*>) { + httpCall.setStatusCodes(httpCall.statusCode, 200) + httpCall.requireTrailers() + httpCall.setResponseHeader("grpc-encoding", "identity") + httpCall.setResponseHeader("Content-Type", MediaTypes.APPLICATION_GRPC) + httpCall.setResponseTrailer( + "grpc-status", + toGrpcStatus(response.statusCode).code.toString() + ) + httpCall.setResponseTrailer("grpc-message", this.grpcMessage(response)) + httpCall.takeResponseBody()?.use { responseBody: BufferedSink -> + GrpcMessageSink(responseBody, ProtoAdapter.BYTES, grpcEncoding = "identity") + .use { messageSink -> + messageSink.write(ByteString.EMPTY) + } + } + } + + private fun grpcMessage(response: Response<*>): String { + val buffer = Buffer() + (response.body as ResponseBody).writeTo(buffer) + return buffer.readUtf8() + } + + /** https://grpc.github.io/grpc/core/md_doc_http-grpc-status-mapping.html */ + private fun toGrpcStatus(statusCode: Int): GrpcStatus { + return when (statusCode) { + 400 -> GrpcStatus.INTERNAL + 401 -> GrpcStatus.UNAUTHENTICATED + 403 -> GrpcStatus.PERMISSION_DENIED + 404 -> GrpcStatus.UNIMPLEMENTED + 429 -> GrpcStatus.UNAVAILABLE + 502 -> GrpcStatus.UNAVAILABLE + 503 -> GrpcStatus.UNAVAILABLE + 504 -> GrpcStatus.UNAVAILABLE + else -> GrpcStatus.UNKNOWN + } + } + private fun toResponse(th: Throwable): Response<*> = when (th) { is UnauthenticatedException -> UNAUTHENTICATED_RESPONSE is UnauthorizedException -> UNAUTHORIZED_RESPONSE diff --git a/misk/src/test/kotlin/misk/grpc/GrpcConnectivityTest.kt b/misk/src/test/kotlin/misk/grpc/GrpcConnectivityTest.kt index dd1bf779036..4870351e028 100644 --- a/misk/src/test/kotlin/misk/grpc/GrpcConnectivityTest.kt +++ b/misk/src/test/kotlin/misk/grpc/GrpcConnectivityTest.kt @@ -3,6 +3,7 @@ package misk.grpc import com.google.inject.Guice import com.squareup.protos.test.grpc.HelloReply import com.squareup.protos.test.grpc.HelloRequest +import com.squareup.wire.GrpcStatus import com.squareup.wire.Service import com.squareup.wire.WireRpc import misk.MiskTestingServiceModule @@ -12,7 +13,6 @@ import misk.testing.MiskTest import misk.testing.MiskTestModule import misk.web.WebActionModule import misk.web.WebServerTestingModule -import misk.web.WebTestingModule import misk.web.actions.WebAction import misk.web.jetty.JettyService import misk.web.mediatype.MediaTypes @@ -119,12 +119,11 @@ class GrpcConnectivityTest { val call = client.newCall(request) val response = call.execute() response.use { - assertThat(response.code).isEqualTo(400) - assertThat(response.body!!.string()).isEqualTo("bad request!") - assertThat(response.headers["grpc-status"]).isNull() - assertThat(response.headers["grpc-encoding"]).isNull() - assertThat(response.trailers().size).isEqualTo(0) - assertThat(response.body?.contentType()).isEqualTo("text/plain;charset=utf-8".toMediaType()) + assertThat(response.code).isEqualTo(200) + assertThat(response.headers["grpc-encoding"]).isEqualTo("identity") + assertThat(response.body!!.contentType()).isEqualTo("application/grpc".toMediaType()) + response.body?.close() + assertThat(response.trailers()["grpc-status"]).isEqualTo(GrpcStatus.INTERNAL.code.toString()) } }