Skip to content

Commit

Permalink
Merge pull request #70 from IZIVIA/feat/#62-provide-context-with-from…
Browse files Browse the repository at this point in the history
…-to-headers

feat(common) #62: Message routing headers
  • Loading branch information
lilgallon authored Jan 19, 2024
2 parents 4826e59 + 0911532 commit 5ebdf8b
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
package com.izivia.ocpi.toolkit.common

import com.fasterxml.jackson.core.type.TypeReference
import com.izivia.ocpi.toolkit.common.Header.OCPI_FROM_COUNTRY_CODE
import com.izivia.ocpi.toolkit.common.Header.OCPI_FROM_PARTY_ID
import com.izivia.ocpi.toolkit.common.Header.OCPI_TO_COUNTRY_CODE
import com.izivia.ocpi.toolkit.common.Header.OCPI_TO_PARTY_ID
import com.izivia.ocpi.toolkit.common.context.*
import com.izivia.ocpi.toolkit.common.validation.validate
import com.izivia.ocpi.toolkit.common.validation.validateLength
import com.izivia.ocpi.toolkit.modules.credentials.repositories.PartnerRepository
import com.izivia.ocpi.toolkit.modules.versions.domain.ModuleID
import com.izivia.ocpi.toolkit.transport.TransportClient
Expand All @@ -18,12 +25,25 @@ object Header {
const val X_LIMIT = "X-Limit"
const val LINK = "Link"
const val CONTENT_TYPE = "Content-Type"
const val OCPI_TO_PARTY_ID = "OCPI-to-party-id"
const val OCPI_TO_COUNTRY_CODE = "OCPI-to-country-code"
const val OCPI_FROM_PARTY_ID = "OCPI-from-party-id"
const val OCPI_FROM_COUNTRY_CODE = "OCPI-from-country-code"
}

object ContentType {
const val APPLICATION_JSON = "application/json"
}

fun Map<String, String>.validateMessageRoutingHeaders() {
validate {
validateLength(OCPI_TO_PARTY_ID, getByNormalizedKey(OCPI_TO_PARTY_ID).orEmpty(), 3)
validateLength(OCPI_TO_COUNTRY_CODE, getByNormalizedKey(OCPI_TO_COUNTRY_CODE).orEmpty(), 2)
validateLength(OCPI_FROM_PARTY_ID, getByNormalizedKey(OCPI_FROM_PARTY_ID).orEmpty(), 3)
validateLength(OCPI_FROM_COUNTRY_CODE, getByNormalizedKey(OCPI_FROM_COUNTRY_CODE).orEmpty(), 2)
}
}

/**
* Parse body of a paginated request. The result will be stored in a SearchResult which contains all pagination
* information.
Expand Down Expand Up @@ -130,14 +150,59 @@ fun HttpRequest.authenticate(token: String): AuthenticatedHttpRequest =
* It adds Content-Type header as "application/json" if the body is not null.
*/
private fun HttpRequest.withContentTypeHeaderIfNeeded(): HttpRequest =
withHeaders(
headers = if (body != null) {
headers.plus(Header.CONTENT_TYPE to ContentType.APPLICATION_JSON)
} else {
headers
}
if (body != null) {
withHeaders(headers = headers.plus(Header.CONTENT_TYPE to ContentType.APPLICATION_JSON))
} else {
this
}

/**
* It adds message routing header if they are set in the current coroutine context
*/
private suspend fun HttpRequest.withRequestMessageRoutingHeadersIfPresent(): HttpRequest {
val requestMessageRoutingHeaders = currentRequestMessageRoutingHeadersOrNull()

return if (requestMessageRoutingHeaders != null) {
withHeaders(headers = headers.plus(requestMessageRoutingHeaders.httpHeaders()))
} else {
this
}
}

/**
* It builds MessageRoutingHeaders from the headers of the request.
*/
fun HttpRequest.messageRoutingHeaders(): RequestMessageRoutingHeaders =
RequestMessageRoutingHeaders(
toPartyId = headers.getByNormalizedKey(OCPI_TO_PARTY_ID),
toCountryCode = headers.getByNormalizedKey(OCPI_TO_COUNTRY_CODE),
fromPartyId = headers.getByNormalizedKey(OCPI_FROM_PARTY_ID),
fromCountryCode = headers.getByNormalizedKey(OCPI_FROM_COUNTRY_CODE)
)

/**
* It builds headers from a ResponseMessageRoutingHeaders
*/
private fun RequestMessageRoutingHeaders.httpHeaders(): Map<String, String> =
mapOf(
OCPI_TO_PARTY_ID to toPartyId,
OCPI_TO_COUNTRY_CODE to toCountryCode,
OCPI_FROM_PARTY_ID to fromPartyId,
OCPI_FROM_COUNTRY_CODE to fromCountryCode
)
.filter { it.value != null }
.mapValues { it.value!! }

fun ResponseMessageRoutingHeaders.httpHeaders(): Map<String, String> =
mapOf(
OCPI_TO_PARTY_ID to toPartyId,
OCPI_TO_COUNTRY_CODE to toCountryCode,
OCPI_FROM_PARTY_ID to fromPartyId,
OCPI_FROM_COUNTRY_CODE to fromCountryCode
)
.filter { it.value != null }
.mapValues { it.value!! }

/**
* For debugging issues, OCPI implementations are required to include unique IDs via HTTP headers in every
* request/response.
Expand All @@ -156,15 +221,17 @@ private fun HttpRequest.withContentTypeHeaderIfNeeded(): HttpRequest =
* Dev note: When the server does a request (not a response), it must keep the same X-Correlation-ID but generate a new
* X-Request-ID. So don't call this method in that case.
*/
fun HttpRequest.withRequiredHeaders(
suspend fun HttpRequest.withRequiredHeaders(
requestId: String,
correlationId: String
): HttpRequest =
withHeaders(
headers = headers
.plus(Header.X_REQUEST_ID to requestId)
.plus(Header.X_CORRELATION_ID to correlationId)
).withContentTypeHeaderIfNeeded()
)
.withContentTypeHeaderIfNeeded()
.withRequestMessageRoutingHeadersIfPresent()

/**
* For debugging issues, OCPI implementations are required to include unique IDs via HTTP headers in every
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.izivia.ocpi.toolkit.common

import com.fasterxml.jackson.core.JsonProcessingException
import com.izivia.ocpi.toolkit.common.context.currentResponseMessageRoutingHeadersOrNull
import com.izivia.ocpi.toolkit.common.validation.toReadableString
import com.izivia.ocpi.toolkit.transport.domain.HttpException
import com.izivia.ocpi.toolkit.transport.domain.HttpRequest
Expand Down Expand Up @@ -139,10 +140,11 @@ suspend fun <T> HttpRequest.httpResponse(fn: suspend () -> OcpiResponseBody<T>):
),
headers = getDebugHeaders()
.plus(Header.CONTENT_TYPE to ContentType.APPLICATION_JSON)
.plus(currentResponseMessageRoutingHeadersOrNull()?.httpHeaders().orEmpty())
).let {
if (isPaginated) {
it.copy(
headers = (ocpiResponseBody as OcpiResponseBody<SearchResult<*>>)
headers = it.headers + (ocpiResponseBody as OcpiResponseBody<SearchResult<*>>)
.getPaginatedHeaders(request = this)
)
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.izivia.ocpi.toolkit.common.context

import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.coroutineContext

/**
* Contains context about the current MessageRoutingHeaders
*/
data class RequestMessageRoutingHeaders(
val toPartyId: String? = null,
val toCountryCode: String? = null,
val fromPartyId: String? = null,
val fromCountryCode: String? = null
) : AbstractCoroutineContextElement(RequestMessageRoutingHeaders) {
companion object Key : CoroutineContext.Key<RequestMessageRoutingHeaders>
}

/**
* Retrieves MessageRoutingHeaders in the current coroutine if it is found.
*/
suspend fun currentRequestMessageRoutingHeadersOrNull(): RequestMessageRoutingHeaders? =
coroutineContext[RequestMessageRoutingHeaders]

/**
* Retrieves MessageRoutingHeaders in the current coroutine, and throws IllegalStateException
* if it could not be found.
*/
suspend fun currentRequestMessageRoutingHeaders(): RequestMessageRoutingHeaders =
coroutineContext[RequestMessageRoutingHeaders]
?: throw IllegalStateException("No MessageRoutingHeaders object in current coroutine context")
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package com.izivia.ocpi.toolkit.common.context

import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.coroutineContext

/**
* Contains context about the current MessageRoutingHeaders
*/
data class ResponseMessageRoutingHeaders(
var toPartyId: String? = null,
var toCountryCode: String? = null,
var fromPartyId: String? = null,
var fromCountryCode: String? = null
) : AbstractCoroutineContextElement(ResponseMessageRoutingHeaders) {
companion object Key : CoroutineContext.Key<ResponseMessageRoutingHeaders> {
/**
* Creates [ResponseMessageRoutingHeaders] by inverting "from" and "to" headers of
* the [RequestMessageRoutingHeaders].
*/
fun invertFromRequest(requestMessageRoutingHeaders: RequestMessageRoutingHeaders) =
ResponseMessageRoutingHeaders(
toPartyId = requestMessageRoutingHeaders.fromPartyId,
toCountryCode = requestMessageRoutingHeaders.fromCountryCode,
fromPartyId = requestMessageRoutingHeaders.toPartyId,
fromCountryCode = requestMessageRoutingHeaders.toCountryCode
)
}
}

/**
* Retrieves MessageRoutingHeaders in the current coroutine if it is found.
*/
suspend fun currentResponseMessageRoutingHeadersOrNull(): ResponseMessageRoutingHeaders? =
coroutineContext[ResponseMessageRoutingHeaders]

/**
* Retrieves MessageRoutingHeaders in the current coroutine, and throws IllegalStateException
* if it could not be found.
*/
suspend fun currentResponseMessageRoutingHeaders(): ResponseMessageRoutingHeaders =
coroutineContext[ResponseMessageRoutingHeaders]
?: throw IllegalStateException("No MessageRoutingHeaders object in current coroutine context")
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
package com.izivia.ocpi.toolkit.samples.common

import com.izivia.ocpi.toolkit.common.OcpiException
import com.izivia.ocpi.toolkit.common.toHttpResponse
import com.izivia.ocpi.toolkit.common.*
import com.izivia.ocpi.toolkit.common.context.ResponseMessageRoutingHeaders
import com.izivia.ocpi.toolkit.common.validation.toReadableString
import com.izivia.ocpi.toolkit.transport.TransportServer
import com.izivia.ocpi.toolkit.transport.domain.*
import kotlinx.coroutines.runBlocking
import org.http4k.core.*
import org.http4k.filter.DebuggingFilters
import org.http4k.routing.RoutingHttpHandler
import org.http4k.routing.bind
import org.http4k.routing.path
import org.http4k.routing.routes
import org.http4k.routing.*
import org.http4k.server.Http4kServer
import org.http4k.server.Netty
import org.http4k.server.asServer
import org.valiktor.ConstraintViolationException

class Http4kTransportServer(
val baseUrl: String,
Expand All @@ -34,7 +33,7 @@ class Http4kTransportServer(
callback: suspend (request: HttpRequest) -> HttpResponse
) {
val pathParams = path
.filterIsInstance(VariablePathSegment::class.java)
.filterIsInstance<VariablePathSegment>()
.map { it.path }

val route = path.joinToString("/") { segment ->
Expand All @@ -58,12 +57,25 @@ class Http4kTransportServer(
.associate { (key, value) -> key to value!! },
body = req.bodyString()
)
.also { httpRequest ->
try {
httpRequest.headers.validateMessageRoutingHeaders()
} catch (e: ConstraintViolationException) {
throw OcpiClientInvalidParametersException(
message = "invalid message routing headers: " + e.toReadableString()
)
}
}
.also { httpRequest ->
runBlocking { secureFilter(httpRequest) }
}
.also { httpRequest -> filters.forEach { filter -> filter(httpRequest) } }
.let { httpRequest ->
httpRequest to runBlocking {
val requestMessageRoutingHeaders = httpRequest.messageRoutingHeaders()
val responseMessageRoutingHeaders = ResponseMessageRoutingHeaders
.invertFromRequest(requestMessageRoutingHeaders)

httpRequest to runBlocking(requestMessageRoutingHeaders + responseMessageRoutingHeaders) {
callback(httpRequest)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package com.izivia.ocpi.toolkit.tests.integration

import com.izivia.ocpi.toolkit.common.Header
import com.izivia.ocpi.toolkit.common.OcpiStatus
import com.izivia.ocpi.toolkit.common.context.RequestMessageRoutingHeaders
import com.izivia.ocpi.toolkit.modules.locations.LocationsCpoServer
import com.izivia.ocpi.toolkit.modules.locations.LocationsEmspClient
import com.izivia.ocpi.toolkit.modules.locations.domain.Location
Expand All @@ -10,6 +12,7 @@ import com.izivia.ocpi.toolkit.modules.versions.repositories.InMemoryVersionsRep
import com.izivia.ocpi.toolkit.samples.common.*
import com.izivia.ocpi.toolkit.tests.integration.common.BaseServerIntegrationTest
import com.izivia.ocpi.toolkit.tests.integration.mock.LocationsCpoMongoRepository
import com.izivia.ocpi.toolkit.transport.domain.HttpMethod
import com.mongodb.client.MongoDatabase
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Test
Expand Down Expand Up @@ -83,8 +86,15 @@ class LocationsIntegrationTest : BaseServerIntegrationTest() {
var dateFrom: Instant? = null
var dateTo: Instant? = null

val requestMessageRoutingHeaders = RequestMessageRoutingHeaders(
toPartyId = "AAA",
toCountryCode = "AA",
fromPartyId = "BBB",
fromCountryCode = "BB"
)

expectThat(
runBlocking {
runBlocking(requestMessageRoutingHeaders) {
locationsEmspClient.getLocations(
dateFrom = dateFrom,
dateTo = dateTo,
Expand Down Expand Up @@ -125,6 +135,36 @@ class LocationsIntegrationTest : BaseServerIntegrationTest() {
}
}

expectThat(cpoServer.requestHistory)
.hasSize(1)[0]
.and {
get { first }.and {
// request
get { method }.isEqualTo(HttpMethod.GET)
get { path }.isEqualTo("/2.2.1/locations")
get { headers[Header.OCPI_FROM_PARTY_ID] }.isNotNull()
.isEqualTo(requestMessageRoutingHeaders.fromPartyId)
get { headers[Header.OCPI_FROM_COUNTRY_CODE] }.isNotNull()
.isEqualTo(requestMessageRoutingHeaders.fromCountryCode)
get { headers[Header.OCPI_TO_PARTY_ID] }.isNotNull()
.isEqualTo(requestMessageRoutingHeaders.toPartyId)
get { headers[Header.OCPI_TO_COUNTRY_CODE] }.isNotNull()
.isEqualTo(requestMessageRoutingHeaders.toCountryCode)
}

get { second }.and {
// response
get { headers[Header.OCPI_FROM_PARTY_ID] }.isNotNull()
.isEqualTo(requestMessageRoutingHeaders.toPartyId)
get { headers[Header.OCPI_FROM_COUNTRY_CODE] }.isNotNull()
.isEqualTo(requestMessageRoutingHeaders.toCountryCode)
get { headers[Header.OCPI_TO_PARTY_ID] }.isNotNull()
.isEqualTo(requestMessageRoutingHeaders.fromPartyId)
get { headers[Header.OCPI_TO_COUNTRY_CODE] }.isNotNull()
.isEqualTo(requestMessageRoutingHeaders.fromCountryCode)
}
}

limit = 100
offset = 100
dateFrom = null
Expand Down

0 comments on commit 5ebdf8b

Please sign in to comment.