Skip to content

Commit

Permalink
Send gRPC errors properly
Browse files Browse the repository at this point in the history
According to the [gRPC spec](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md), all
gRPC responses, including errors, should have an HTTP status of 200. gRPC errors are signaled using
the `grpc-status` header. This brings Misk in compliance with the spec, translating HTTP errors into
properly formatted gRPC errors.
  • Loading branch information
squarejesse authored and ewolak-sq committed Jun 18, 2021
1 parent 2afd8c7 commit c7ea3c0
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 13 deletions.
2 changes: 2 additions & 0 deletions misk-grpc-tests/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ dependencies {
implementation(project(":misk-actions"))
implementation(project(":misk-core"))
implementation(project(":misk-inject"))
implementation(project(":misk-metrics"))
implementation(project(":misk-metrics-testing"))
implementation(project(":misk-service"))
implementation(project(":misk-testing"))

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package misk.grpc.miskserver

import misk.exceptions.WebActionException
import misk.web.actions.WebAction
import misk.web.interceptors.LogRequestResponse
import routeguide.Feature
Expand All @@ -10,6 +11,9 @@ import javax.inject.Inject
class GetFeatureGrpcAction @Inject constructor() : WebAction, RouteGuideGetFeatureBlockingServer {
@LogRequestResponse(bodySampling = 1.0, errorBodySampling = 1.0)
override fun GetFeature(request: Point): Feature {
if (request.latitude == -1) {
throw WebActionException(request.longitude ?: 500, "unexpected latitude error!")
}
return Feature(name = "maple tree", location = request)
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package misk.grpc.miskserver

import com.google.inject.Provides
import com.google.inject.util.Modules
import misk.MiskTestingServiceModule
import misk.inject.KAbstractModule
import misk.metrics.FakeMetricsModule
import misk.web.WebActionModule
import misk.web.WebServerTestingModule
import misk.web.jetty.JettyService
Expand All @@ -12,7 +14,7 @@ import javax.inject.Named
class RouteGuideMiskServiceModule : KAbstractModule() {
override fun configure() {
install(WebServerTestingModule(webConfig = WebServerTestingModule.TESTING_WEB_CONFIG.copy(http2 = true)))
install(MiskTestingServiceModule())
install(Modules.override(MiskTestingServiceModule()).with(FakeMetricsModule()))
install(WebActionModule.create<GetFeatureGrpcAction>())
install(WebActionModule.create<RouteChatGrpcAction>())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package misk.grpc

import com.google.inject.Guice
import com.google.inject.util.Modules
import javax.inject.Inject
import javax.inject.Named
import com.squareup.wire.GrpcException
import com.squareup.wire.GrpcStatus
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.runBlocking
Expand All @@ -13,6 +13,7 @@ import misk.grpc.miskserver.RouteChatGrpcAction
import misk.grpc.miskserver.RouteGuideMiskServiceModule
import misk.inject.getInstance
import misk.logging.LogCollectorModule
import misk.metrics.FakeMetrics
import misk.testing.MiskTest
import misk.testing.MiskTestModule
import misk.web.interceptors.RequestLoggingInterceptor
Expand All @@ -25,6 +26,9 @@ import routeguide.Point
import routeguide.RouteGuideClient
import routeguide.RouteNote
import wisp.logging.LogCollector
import javax.inject.Inject
import javax.inject.Named
import kotlin.test.assertFailsWith

@MiskTest(startService = true)
class MiskClientMiskServerTest {
Expand All @@ -37,6 +41,7 @@ class MiskClientMiskServerTest {
@Inject lateinit var logCollector: LogCollector
@Inject lateinit var routeChatGrpcAction: RouteChatGrpcAction
@Inject @field:Named("grpc server") lateinit var serverUrl: HttpUrl
@Inject lateinit var metrics: FakeMetrics

private lateinit var routeGuide: RouteGuideClient
private lateinit var callCounter: RouteGuideCallCounter
Expand Down Expand Up @@ -105,4 +110,55 @@ class MiskClientMiskServerTest {
sendChannel.close()
}
}

@Test
fun serverFailureGeneric() {
val point = Point(
latitude = -1,
longitude = 500
)

runBlocking {
val e = assertFailsWith<GrpcException> {
routeGuide.GetFeature().execute(point)
}
assertThat(e.grpcMessage).isEqualTo("unexpected latitude error!")
assertThat(e.grpcStatus).isEqualTo(GrpcStatus.UNKNOWN)

assertResponseCount(200, 0)
assertResponseCount(500, 1)
}
}

@Test
fun serverFailureNotFound() {
val point = Point(
latitude = -1,
longitude = 404
)

runBlocking {
val e = assertFailsWith<GrpcException> {
routeGuide.GetFeature().execute(point)
}
assertThat(e.grpcMessage).isEqualTo("unexpected latitude error!")
assertThat(e.grpcStatus).isEqualTo(GrpcStatus.UNIMPLEMENTED)
.withFailMessage("wrong gRPC status ${e.grpcStatus.name}")

assertResponseCount(200, 0)
assertResponseCount(404, 1)
}
}

private fun assertResponseCount(code: Int, count: Int) {
val responseCount = metrics.histogramCount(
"http_request_latency_ms",
"action" to "GetFeatureGrpcAction",
"caller" to "unknown",
"code" to code.toString(),
)?.toInt() ?: 0
assertThat(responseCount)
.withFailMessage("Expected metrics to indicate $count responses with HTTP status $code but got $responseCount")
.isEqualTo(count)
}
}
1 change: 1 addition & 0 deletions misk-testing/src/main/kotlin/misk/web/FakeHttpCall.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ data class FakeHttpCall(
override val dispatchMechanism: DispatchMechanism = DispatchMechanism.GET,
override val requestHeaders: Headers = headersOf(),
override var statusCode: Int = 200,
override var networkStatusCode: Int = 200,
val headersBuilder: Headers.Builder = Headers.Builder(),
var sendTrailers: Boolean = false,
val trailersBuilder: Headers.Builder = Headers.Builder(),
Expand Down
5 changes: 5 additions & 0 deletions misk/src/main/kotlin/misk/web/HttpCall.kt
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@ interface HttpCall {
val requestHeaders: Headers

/** The HTTP response under construction. */
/** Meaningful HTTP status about what actually happened */
var statusCode: Int

/** The HTTP status code actually sent over the network. For gRPC, this is *always* 200, even
* for errors. */
var networkStatusCode: Int
val responseHeaders: Headers

fun setResponseHeader(name: String, value: String)
Expand Down
8 changes: 8 additions & 0 deletions misk/src/main/kotlin/misk/web/ServletHttpCall.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,16 @@ internal data class ServletHttpCall(
var responseBody: BufferedSink? = null,
var webSocket: WebSocket? = null
) : HttpCall {
private var _actualStatusCode: Int? = null

override var statusCode: Int
get() = _actualStatusCode ?: upstreamResponse.statusCode
set(value) {
_actualStatusCode = value
upstreamResponse.statusCode = value
}

override var networkStatusCode: Int
get() = upstreamResponse.statusCode
set(value) {
upstreamResponse.statusCode = value
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
package misk.web.exceptions

import com.google.common.util.concurrent.UncheckedExecutionException
import com.squareup.wire.GrpcStatus
import com.squareup.wire.ProtoAdapter
import misk.Action
import misk.exceptions.StatusCode
import misk.exceptions.UnauthenticatedException
import misk.exceptions.UnauthorizedException
import misk.grpc.GrpcMessageSink
import misk.web.DispatchMechanism
import misk.web.HttpCall
import misk.web.NetworkChain
import misk.web.NetworkInterceptor
import misk.web.Response
import misk.web.ResponseBody
import misk.web.mediatype.MediaTypes
import misk.web.toResponseBody
import okhttp3.Headers.Companion.toHeaders
import okio.Buffer
import okio.BufferedSink
import okio.ByteString
import wisp.logging.getLogger
import wisp.logging.log
import java.lang.reflect.InvocationTargetException
Expand All @@ -37,13 +45,67 @@ class ExceptionHandlingInterceptor(
} catch (th: Throwable) {
val response = toResponse(th)
chain.httpCall.statusCode = response.statusCode
chain.httpCall.takeResponseBody()?.use { sink ->
chain.httpCall.addResponseHeaders(response.headers)
(response.body as ResponseBody).writeTo(sink)
if (chain.httpCall.dispatchMechanism == DispatchMechanism.GRPC) {
sendGrpcFailure(chain.httpCall, response)
} else {
sendHttpFailure(chain.httpCall, response)
}
}
}

private fun sendHttpFailure(httpCall: HttpCall, response: Response<*>) {
httpCall.takeResponseBody()?.use { sink ->
httpCall.addResponseHeaders(response.headers)
(response.body as ResponseBody).writeTo(sink)
}
}

/**
* Borrow behavior from [GrpcFeatureBinding] to send a gRPC error with an HTTP 200 status code.
* This is weird but it's how gRPC clients work.
*
* One thing to note is for our metrics we want to pretend that the HTTP code is what we sent.
* Otherwise gRPC requests that crashed and yielded an HTTP 200 code will confuse operators.
*/
private fun sendGrpcFailure(httpCall: HttpCall, response: Response<*>) {
httpCall.networkStatusCode = 200
httpCall.requireTrailers()
httpCall.setResponseHeader("grpc-encoding", "identity")
httpCall.setResponseHeader("Content-Type", MediaTypes.APPLICATION_GRPC)
httpCall.setResponseTrailer(
"grpc-status",
toGrpcStatus(response.statusCode).code.toString()
)
httpCall.setResponseTrailer("grpc-message", this.grpcMessage(response))
httpCall.takeResponseBody()?.use { responseBody: BufferedSink ->
GrpcMessageSink(responseBody, ProtoAdapter.BYTES, grpcEncoding = "identity")
.use { messageSink ->
messageSink.write(ByteString.EMPTY)
}
}
}

private fun grpcMessage(response: Response<*>): String {
val buffer = Buffer()
(response.body as ResponseBody).writeTo(buffer)
return buffer.readUtf8()
}

/** https://grpc.github.io/grpc/core/md_doc_http-grpc-status-mapping.html */
private fun toGrpcStatus(statusCode: Int): GrpcStatus {
return when (statusCode) {
400 -> GrpcStatus.INTERNAL
401 -> GrpcStatus.UNAUTHENTICATED
403 -> GrpcStatus.PERMISSION_DENIED
404 -> GrpcStatus.UNIMPLEMENTED
429 -> GrpcStatus.UNAVAILABLE
502 -> GrpcStatus.UNAVAILABLE
503 -> GrpcStatus.UNAVAILABLE
504 -> GrpcStatus.UNAVAILABLE
else -> GrpcStatus.UNKNOWN
}
}

private fun toResponse(th: Throwable): Response<*> = when (th) {
is UnauthenticatedException -> UNAUTHENTICATED_RESPONSE
is UnauthorizedException -> UNAUTHORIZED_RESPONSE
Expand Down
13 changes: 6 additions & 7 deletions misk/src/test/kotlin/misk/grpc/GrpcConnectivityTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package misk.grpc
import com.google.inject.Guice
import com.squareup.protos.test.grpc.HelloReply
import com.squareup.protos.test.grpc.HelloRequest
import com.squareup.wire.GrpcStatus
import com.squareup.wire.Service
import com.squareup.wire.WireRpc
import misk.MiskTestingServiceModule
Expand All @@ -12,7 +13,6 @@ import misk.testing.MiskTest
import misk.testing.MiskTestModule
import misk.web.WebActionModule
import misk.web.WebServerTestingModule
import misk.web.WebTestingModule
import misk.web.actions.WebAction
import misk.web.jetty.JettyService
import misk.web.mediatype.MediaTypes
Expand Down Expand Up @@ -119,12 +119,11 @@ class GrpcConnectivityTest {
val call = client.newCall(request)
val response = call.execute()
response.use {
assertThat(response.code).isEqualTo(400)
assertThat(response.body!!.string()).isEqualTo("bad request!")
assertThat(response.headers["grpc-status"]).isNull()
assertThat(response.headers["grpc-encoding"]).isNull()
assertThat(response.trailers().size).isEqualTo(0)
assertThat(response.body?.contentType()).isEqualTo("text/plain;charset=utf-8".toMediaType())
assertThat(response.code).isEqualTo(200)
assertThat(response.headers["grpc-encoding"]).isEqualTo("identity")
assertThat(response.body!!.contentType()).isEqualTo("application/grpc".toMediaType())
response.body?.close()
assertThat(response.trailers()["grpc-status"]).isEqualTo(GrpcStatus.INTERNAL.code.toString())
}
}

Expand Down

0 comments on commit c7ea3c0

Please sign in to comment.