Skip to content

Commit

Permalink
Merge pull request #4101 from element-hq/feature/bma/unifiedPushGatew…
Browse files Browse the repository at this point in the history
…ayResolverImprovement

Unified push gateway resolver improvement
  • Loading branch information
bmarty authored Jan 2, 2025
2 parents a411ac8 + bbe0f10 commit 5880bbb
Show file tree
Hide file tree
Showing 8 changed files with 244 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,49 +12,63 @@ import io.element.android.libraries.core.coroutine.CoroutineDispatchers
import io.element.android.libraries.core.data.tryOrNull
import io.element.android.libraries.di.AppScope
import kotlinx.coroutines.withContext
import retrofit2.HttpException
import timber.log.Timber
import java.net.HttpURLConnection
import java.net.URL
import javax.inject.Inject

sealed interface UnifiedPushGatewayResolverResult {
data class Success(val gateway: String) : UnifiedPushGatewayResolverResult
data class Error(val gateway: String) : UnifiedPushGatewayResolverResult
data object NoMatrixGateway : UnifiedPushGatewayResolverResult
data object ErrorInvalidUrl : UnifiedPushGatewayResolverResult
}

interface UnifiedPushGatewayResolver {
suspend fun getGateway(endpoint: String): String
suspend fun getGateway(endpoint: String): UnifiedPushGatewayResolverResult
}

@ContributesBinding(AppScope::class)
class DefaultUnifiedPushGatewayResolver @Inject constructor(
private val unifiedPushApiFactory: UnifiedPushApiFactory,
private val coroutineDispatchers: CoroutineDispatchers,
) : UnifiedPushGatewayResolver {
private val logger = Timber.tag("DefaultUnifiedPushGatewayResolver")

override suspend fun getGateway(endpoint: String): String {
override suspend fun getGateway(endpoint: String): UnifiedPushGatewayResolverResult {
val url = tryOrNull(
onError = { logger.d(it, "Cannot parse endpoint as an URL") }
onError = { Timber.tag("DefaultUnifiedPushGatewayResolver").d(it, "Cannot parse endpoint as an URL") }
) {
URL(endpoint)
}
return if (url == null) {
logger.d("Using default gateway")
UnifiedPushConfig.DEFAULT_PUSH_GATEWAY_HTTP_URL
Timber.tag("DefaultUnifiedPushGatewayResolver").d("ErrorInvalidUrl")
UnifiedPushGatewayResolverResult.ErrorInvalidUrl
} else {
val port = if (url.port != -1) ":${url.port}" else ""
val customBase = "${url.protocol}://${url.host}$port"
val customUrl = "$customBase/_matrix/push/v1/notify"
logger.i("Testing $customUrl")
Timber.tag("DefaultUnifiedPushGatewayResolver").i("Testing $customUrl")
return withContext(coroutineDispatchers.io) {
val api = unifiedPushApiFactory.create(customBase)
try {
val discoveryResponse = api.discover()
if (discoveryResponse.unifiedpush.gateway == "matrix") {
logger.d("The endpoint seems to be a valid UnifiedPush gateway")
Timber.tag("DefaultUnifiedPushGatewayResolver").d("The endpoint seems to be a valid UnifiedPush gateway")
UnifiedPushGatewayResolverResult.Success(customUrl)
} else {
logger.e("The endpoint does not seem to be a valid UnifiedPush gateway")
// The endpoint returned a 200 OK but didn't promote an actual matrix gateway, which means it doesn't have any
Timber.tag("DefaultUnifiedPushGatewayResolver").w("The endpoint does not seem to be a valid UnifiedPush gateway, using fallback")
UnifiedPushGatewayResolverResult.NoMatrixGateway
}
} catch (throwable: Throwable) {
logger.e(throwable, "Error checking for UnifiedPush endpoint")
if ((throwable as? HttpException)?.code() == HttpURLConnection.HTTP_NOT_FOUND) {
Timber.tag("DefaultUnifiedPushGatewayResolver").i("Checking for UnifiedPush endpoint yielded 404, using fallback")
UnifiedPushGatewayResolverResult.NoMatrixGateway
} else {
Timber.tag("DefaultUnifiedPushGatewayResolver").e(throwable, "Error checking for UnifiedPush endpoint")
UnifiedPushGatewayResolverResult.Error(customUrl)
}
}
// Always return the custom url.
customUrl
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright 2024 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only
* Please see LICENSE in the repository root for full details.
*/

package io.element.android.libraries.pushproviders.unifiedpush

import com.squareup.anvil.annotations.ContributesBinding
import io.element.android.libraries.di.AppScope
import javax.inject.Inject

interface UnifiedPushGatewayUrlResolver {
fun resolve(
gatewayResult: UnifiedPushGatewayResolverResult,
instance: String,
): String
}

@ContributesBinding(AppScope::class)
class DefaultUnifiedPushGatewayUrlResolver @Inject constructor(
private val unifiedPushStore: UnifiedPushStore,
) : UnifiedPushGatewayUrlResolver {
override fun resolve(
gatewayResult: UnifiedPushGatewayResolverResult,
instance: String,
): String {
return when (gatewayResult) {
is UnifiedPushGatewayResolverResult.Error -> {
// Use previous gateway if any, or the provided one
unifiedPushStore.getPushGateway(instance) ?: gatewayResult.gateway
}
UnifiedPushGatewayResolverResult.ErrorInvalidUrl,
UnifiedPushGatewayResolverResult.NoMatrixGateway -> {
UnifiedPushConfig.DEFAULT_PUSH_GATEWAY_HTTP_URL
}
is UnifiedPushGatewayResolverResult.Success -> {
gatewayResult.gateway
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class VectorUnifiedPushMessagingReceiver : MessagingReceiver() {
@Inject lateinit var guardServiceStarter: GuardServiceStarter
@Inject lateinit var unifiedPushStore: UnifiedPushStore
@Inject lateinit var unifiedPushGatewayResolver: UnifiedPushGatewayResolver
@Inject lateinit var unifiedPushGatewayUrlResolver: UnifiedPushGatewayUrlResolver
@Inject lateinit var newGatewayHandler: UnifiedPushNewGatewayHandler
@Inject lateinit var endpointRegistrationHandler: EndpointRegistrationHandler
@Inject lateinit var coroutineScope: CoroutineScope
Expand Down Expand Up @@ -64,6 +65,9 @@ class VectorUnifiedPushMessagingReceiver : MessagingReceiver() {
Timber.tag(loggerTag.value).i("onNewEndpoint: $endpoint")
coroutineScope.launch {
val gateway = unifiedPushGatewayResolver.getGateway(endpoint)
.let { gatewayResult ->
unifiedPushGatewayUrlResolver.resolve(gatewayResult, instance)
}
unifiedPushStore.storePushGateway(instance, gateway)
val result = newGatewayHandler.handle(endpoint, gateway, instance)
.onFailure {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ import io.element.android.libraries.pushproviders.unifiedpush.network.DiscoveryU
import io.element.android.tests.testutils.testCoroutineDispatchers
import kotlinx.coroutines.test.TestScope
import kotlinx.coroutines.test.runTest
import okhttp3.ResponseBody.Companion.toResponseBody
import org.junit.Test
import retrofit2.HttpException
import retrofit2.Response
import java.net.HttpURLConnection

internal val matrixDiscoveryResponse = {
DiscoveryResponse(
Expand Down Expand Up @@ -43,7 +47,7 @@ class DefaultUnifiedPushGatewayResolverTest {
)
val result = sut.getGateway("https://custom.url")
assertThat(unifiedPushApiFactory.baseUrlParameter).isEqualTo("https://custom.url")
assertThat(result).isEqualTo("https://custom.url/_matrix/push/v1/notify")
assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.Success("https://custom.url/_matrix/push/v1/notify"))
}

@Test
Expand All @@ -56,7 +60,7 @@ class DefaultUnifiedPushGatewayResolverTest {
)
val result = sut.getGateway("https://custom.url:123")
assertThat(unifiedPushApiFactory.baseUrlParameter).isEqualTo("https://custom.url:123")
assertThat(result).isEqualTo("https://custom.url:123/_matrix/push/v1/notify")
assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.Success("https://custom.url:123/_matrix/push/v1/notify"))
}

@Test
Expand All @@ -69,7 +73,7 @@ class DefaultUnifiedPushGatewayResolverTest {
)
val result = sut.getGateway("https://custom.url:123/some/path")
assertThat(unifiedPushApiFactory.baseUrlParameter).isEqualTo("https://custom.url:123")
assertThat(result).isEqualTo("https://custom.url:123/_matrix/push/v1/notify")
assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.Success("https://custom.url:123/_matrix/push/v1/notify"))
}

@Test
Expand All @@ -82,7 +86,7 @@ class DefaultUnifiedPushGatewayResolverTest {
)
val result = sut.getGateway("http://custom.url:123/some/path")
assertThat(unifiedPushApiFactory.baseUrlParameter).isEqualTo("http://custom.url:123")
assertThat(result).isEqualTo("http://custom.url:123/_matrix/push/v1/notify")
assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.Success("http://custom.url:123/_matrix/push/v1/notify"))
}

@Test
Expand All @@ -95,11 +99,41 @@ class DefaultUnifiedPushGatewayResolverTest {
)
val result = sut.getGateway("http://custom.url")
assertThat(unifiedPushApiFactory.baseUrlParameter).isEqualTo("http://custom.url")
assertThat(result).isEqualTo("http://custom.url/_matrix/push/v1/notify")
assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.Error("http://custom.url/_matrix/push/v1/notify"))
}

@Test
fun `when a custom url is invalid, the default url is returned`() = runTest {
fun `when a custom url is not found (404), NoMatrixGateway is returned`() = runTest {
val unifiedPushApiFactory = FakeUnifiedPushApiFactory(
discoveryResponse = {
throw HttpException(Response.error<Unit>(HttpURLConnection.HTTP_NOT_FOUND, "".toResponseBody()))
}
)
val sut = createDefaultUnifiedPushGatewayResolver(
unifiedPushApiFactory = unifiedPushApiFactory
)
val result = sut.getGateway("http://custom.url")
assertThat(unifiedPushApiFactory.baseUrlParameter).isEqualTo("http://custom.url")
assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.NoMatrixGateway)
}

@Test
fun `when a custom url is forbidden (403), Error is returned`() = runTest {
val unifiedPushApiFactory = FakeUnifiedPushApiFactory(
discoveryResponse = {
throw HttpException(Response.error<Unit>(HttpURLConnection.HTTP_FORBIDDEN, "".toResponseBody()))
}
)
val sut = createDefaultUnifiedPushGatewayResolver(
unifiedPushApiFactory = unifiedPushApiFactory
)
val result = sut.getGateway("http://custom.url")
assertThat(unifiedPushApiFactory.baseUrlParameter).isEqualTo("http://custom.url")
assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.Error("http://custom.url/_matrix/push/v1/notify"))
}

@Test
fun `when a custom url is invalid, ErrorInvalidUrl is returned`() = runTest {
val unifiedPushApiFactory = FakeUnifiedPushApiFactory(
discoveryResponse = matrixDiscoveryResponse
)
Expand All @@ -108,11 +142,11 @@ class DefaultUnifiedPushGatewayResolverTest {
)
val result = sut.getGateway("invalid")
assertThat(unifiedPushApiFactory.baseUrlParameter).isNull()
assertThat(result).isEqualTo(UnifiedPushConfig.DEFAULT_PUSH_GATEWAY_HTTP_URL)
assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.ErrorInvalidUrl)
}

@Test
fun `when a custom url provides a invalid matrix gateway, the custom url is still returned`() = runTest {
fun `when a custom url provides a invalid matrix gateway, NoMatrixGateway is returned`() = runTest {
val unifiedPushApiFactory = FakeUnifiedPushApiFactory(
discoveryResponse = invalidDiscoveryResponse
)
Expand All @@ -121,7 +155,7 @@ class DefaultUnifiedPushGatewayResolverTest {
)
val result = sut.getGateway("https://custom.url")
assertThat(unifiedPushApiFactory.baseUrlParameter).isEqualTo("https://custom.url")
assertThat(result).isEqualTo("https://custom.url/_matrix/push/v1/notify")
assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.NoMatrixGateway)
}

private fun TestScope.createDefaultUnifiedPushGatewayResolver(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright 2024 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only
* Please see LICENSE in the repository root for full details.
*/

package io.element.android.libraries.pushproviders.unifiedpush

import com.google.common.truth.Truth.assertThat
import org.junit.Test

class DefaultUnifiedPushGatewayUrlResolverTest {
@Test
fun `resolve ErrorInvalidUrl returns the default gateway`() {
val sut = createDefaultUnifiedPushGatewayUrlResolver()
val result = sut.resolve(
gatewayResult = UnifiedPushGatewayResolverResult.ErrorInvalidUrl,
instance = "",
)
assertThat(result).isEqualTo(UnifiedPushConfig.DEFAULT_PUSH_GATEWAY_HTTP_URL)
}

@Test
fun `resolve NoMatrixGateway returns the default gateway`() {
val sut = createDefaultUnifiedPushGatewayUrlResolver()
val result = sut.resolve(
gatewayResult = UnifiedPushGatewayResolverResult.NoMatrixGateway,
instance = "",
)
assertThat(result).isEqualTo(UnifiedPushConfig.DEFAULT_PUSH_GATEWAY_HTTP_URL)
}

@Test
fun `resolve Success returns the url`() {
val sut = createDefaultUnifiedPushGatewayUrlResolver()
val result = sut.resolve(
gatewayResult = UnifiedPushGatewayResolverResult.Success("aUrl"),
instance = "",
)
assertThat(result).isEqualTo("aUrl")
}

@Test
fun `resolve Error returns the current url when available`() {
val sut = createDefaultUnifiedPushGatewayUrlResolver(
unifiedPushStore = FakeUnifiedPushStore(
getPushGatewayResult = { instance ->
assertThat(instance).isEqualTo("instance")
"aCurrentUrl"
},
)
)
val result = sut.resolve(
gatewayResult = UnifiedPushGatewayResolverResult.Error("aUrl"),
instance = "instance",
)
assertThat(result).isEqualTo("aCurrentUrl")
}

@Test
fun `resolve Error returns the url if no current url is available`() {
val sut = createDefaultUnifiedPushGatewayUrlResolver(
unifiedPushStore = FakeUnifiedPushStore(
getPushGatewayResult = { instance ->
assertThat(instance).isEqualTo("instance")
null
},
)
)
val result = sut.resolve(
gatewayResult = UnifiedPushGatewayResolverResult.Error("aUrl"),
instance = "instance",
)
assertThat(result).isEqualTo("aUrl")
}

private fun createDefaultUnifiedPushGatewayUrlResolver(
unifiedPushStore: UnifiedPushStore = FakeUnifiedPushStore(),
) = DefaultUnifiedPushGatewayUrlResolver(
unifiedPushStore = unifiedPushStore,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ package io.element.android.libraries.pushproviders.unifiedpush
import io.element.android.tests.testutils.lambda.lambdaError

class FakeUnifiedPushGatewayResolver(
private val getGatewayResult: (String) -> String = { lambdaError() },
private val getGatewayResult: (String) -> UnifiedPushGatewayResolverResult = { lambdaError() },
) : UnifiedPushGatewayResolver {
override suspend fun getGateway(endpoint: String): String {
override suspend fun getGateway(endpoint: String): UnifiedPushGatewayResolverResult {
return getGatewayResult(endpoint)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright 2024 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only
* Please see LICENSE in the repository root for full details.
*/

package io.element.android.libraries.pushproviders.unifiedpush

import io.element.android.tests.testutils.lambda.lambdaError

class FakeUnifiedPushGatewayUrlResolver(
private val resolveResult: (UnifiedPushGatewayResolverResult, String) -> String = { _, _ -> lambdaError() },
) : UnifiedPushGatewayUrlResolver {
override fun resolve(gatewayResult: UnifiedPushGatewayResolverResult, instance: String): String {
return resolveResult(gatewayResult, instance)
}
}
Loading

0 comments on commit 5880bbb

Please sign in to comment.