Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport "Remove wait for tunnel up" to prepare-android/2025.1 #7617

Merged
merged 2 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package net.mullvad.talpid

import android.net.VpnService
import android.os.ParcelFileDescriptor
import arrow.core.right
import io.mockk.MockKAnnotations
import io.mockk.coVerify
import io.mockk.every
import io.mockk.mockk
import io.mockk.mockkConstructor
import io.mockk.mockkStatic
import io.mockk.spyk
import java.net.InetAddress
import net.mullvad.mullvadvpn.lib.common.test.assertLists
import net.mullvad.mullvadvpn.lib.common.util.prepareVpnSafe
import net.mullvad.mullvadvpn.lib.model.Prepared
import net.mullvad.talpid.model.CreateTunResult
import net.mullvad.talpid.model.InetNetwork
import net.mullvad.talpid.model.TunConfig
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertInstanceOf

class TalpidVpnServiceFallbackDnsTest {
lateinit var talpidVpnService: TalpidVpnService
var builderMockk = mockk<VpnService.Builder>()

@BeforeEach
fun setup() {
MockKAnnotations.init(this)
mockkStatic(VPN_SERVICE_EXTENSION)

talpidVpnService = spyk<TalpidVpnService>(recordPrivateCalls = true)
every { talpidVpnService.prepareVpnSafe() } returns Prepared.right()
builderMockk = mockk<VpnService.Builder>()

mockkConstructor(VpnService.Builder::class)
every { anyConstructed<VpnService.Builder>().setMtu(any()) } returns builderMockk
every { anyConstructed<VpnService.Builder>().setBlocking(any()) } returns builderMockk
every { anyConstructed<VpnService.Builder>().addAddress(any<InetAddress>(), any()) } returns
builderMockk
every { anyConstructed<VpnService.Builder>().addRoute(any<InetAddress>(), any()) } returns
builderMockk
every {
anyConstructed<VpnService.Builder>()
.addDnsServer(TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER)
} returns builderMockk
val parcelFileDescriptor: ParcelFileDescriptor = mockk()
every { anyConstructed<VpnService.Builder>().establish() } returns parcelFileDescriptor
every { parcelFileDescriptor.detachFd() } returns 1
}

@Test
fun `opening tun with no DnsServers should add fallback DNS server`() {
val tunConfig = baseTunConfig.copy(dnsServers = arrayListOf())

val result = talpidVpnService.openTun(tunConfig)

assertInstanceOf<CreateTunResult.Success>(result)

// Fallback DNS server should be added if no DNS servers are provided
coVerify(exactly = 1) {
anyConstructed<VpnService.Builder>()
.addDnsServer(TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER)
}
}

@Test
fun `opening tun with all bad DnsServers should return InvalidDnsServers and add fallback`() {
val badDns1 = InetAddress.getByName("0.0.0.0")
val badDns2 = InetAddress.getByName("255.255.255.255")
every { anyConstructed<VpnService.Builder>().addDnsServer(badDns1) } throws
IllegalArgumentException()
every { anyConstructed<VpnService.Builder>().addDnsServer(badDns2) } throws
IllegalArgumentException()

val tunConfig = baseTunConfig.copy(dnsServers = arrayListOf(badDns1, badDns2))
val result = talpidVpnService.openTun(tunConfig)

assertInstanceOf<CreateTunResult.InvalidDnsServers>(result)
assertLists(tunConfig.dnsServers, result.addresses)
// Fallback DNS server should be added if no valid DNS servers are provided
coVerify(exactly = 1) {
anyConstructed<VpnService.Builder>()
.addDnsServer(TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER)
}
}

@Test
fun `opening tun with 1 good and 1 bad DnsServers should return InvalidDnsServers`() {
val goodDnsServer = InetAddress.getByName("1.1.1.1")
val badDns = InetAddress.getByName("255.255.255.255")
every { anyConstructed<VpnService.Builder>().addDnsServer(goodDnsServer) } returns
builderMockk
every { anyConstructed<VpnService.Builder>().addDnsServer(badDns) } throws
IllegalArgumentException()

val tunConfig = baseTunConfig.copy(dnsServers = arrayListOf(goodDnsServer, badDns))
val result = talpidVpnService.openTun(tunConfig)

assertInstanceOf<CreateTunResult.InvalidDnsServers>(result)
assertLists(arrayListOf(badDns), result.addresses)

// Fallback DNS server should not be added since we have 1 good DNS server
coVerify(exactly = 0) {
anyConstructed<VpnService.Builder>()
.addDnsServer(TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER)
}
}

@Test
fun `providing good dns servers should not add the fallback dns and return success`() {
val goodDnsServer = InetAddress.getByName("1.1.1.1")
every { anyConstructed<VpnService.Builder>().addDnsServer(goodDnsServer) } returns
builderMockk

val tunConfig = baseTunConfig.copy(dnsServers = arrayListOf(goodDnsServer))
val result = talpidVpnService.openTun(tunConfig)

assertInstanceOf<CreateTunResult.Success>(result)

// Fallback DNS server should not be added since we have good DNS servers.
coVerify(exactly = 0) {
anyConstructed<VpnService.Builder>()
.addDnsServer(TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER)
}
}

companion object {
private const val VPN_SERVICE_EXTENSION =
"net.mullvad.mullvadvpn.lib.common.util.VpnServiceUtilsKt"

val baseTunConfig =
TunConfig(
addresses = arrayListOf(InetAddress.getByName("45.83.223.209")),
dnsServers = arrayListOf(),
routes =
arrayListOf(
InetNetwork(InetAddress.getByName("0.0.0.0"), 0),
InetNetwork(InetAddress.getByName("::"), 0),
),
mtu = 1280,
excludedPackages = arrayListOf(),
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,23 @@ package net.mullvad.mullvadvpn.lib.common.util

import android.content.Context
import android.content.Intent
import android.net.VpnService
import android.net.VpnService.prepare
import android.os.ParcelFileDescriptor
import arrow.core.Either
import arrow.core.flatten
import arrow.core.flatMap
import arrow.core.left
import arrow.core.raise.either
import arrow.core.raise.ensureNotNull
import arrow.core.right
import co.touchlab.kermit.Logger
import net.mullvad.mullvadvpn.lib.common.util.SdkUtils.getInstalledPackagesList
import net.mullvad.mullvadvpn.lib.model.PrepareError
import net.mullvad.mullvadvpn.lib.model.Prepared

/**
* Prepare to establish a VPN connection safely.
*
* Invoking VpnService.prepare() can result in 3 out comes:
* 1. IllegalStateException - There is a legacy VPN profile marked as always on
* 2. Intent
Expand All @@ -34,7 +40,7 @@ fun Context.prepareVpnSafe(): Either<PrepareError, Prepared> =
else -> throw it
}
}
.map { intent ->
.flatMap { intent ->
if (intent == null) {
Prepared.right()
} else {
Expand All @@ -46,7 +52,6 @@ fun Context.prepareVpnSafe(): Either<PrepareError, Prepared> =
}
}
}
.flatten()

fun Context.getAlwaysOnVpnAppName(): String? {
return resolveAlwaysOnVpnPackageName()
Expand All @@ -59,3 +64,38 @@ fun Context.getAlwaysOnVpnAppName(): String? {
?.loadLabel(packageManager)
?.toString()
}

/**
* Establish a VPN connection safely.
*
* This function wraps the [VpnService.Builder.establish] function and catches any exceptions that
* may be thrown and type them to a more specific error.
*
* @return [ParcelFileDescriptor] if successful, [EstablishError] otherwise
*/
fun VpnService.Builder.establishSafe(): Either<EstablishError, ParcelFileDescriptor> = either {
val vpnInterfaceFd =
Either.catch { establish() }
.mapLeft {
when (it) {
is IllegalStateException -> EstablishError.ParameterNotApplied(it)
is IllegalArgumentException -> EstablishError.ParameterNotAccepted(it)
else -> EstablishError.UnknownError(it)
}
}
.bind()

ensureNotNull(vpnInterfaceFd) { EstablishError.NullVpnInterface }

vpnInterfaceFd
}

sealed interface EstablishError {
data class ParameterNotApplied(val exception: IllegalStateException) : EstablishError

data class ParameterNotAccepted(val exception: IllegalArgumentException) : EstablishError

data object NullVpnInterface : EstablishError

data class UnknownError(val error: Throwable) : EstablishError
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ import net.mullvad.mullvadvpn.lib.model.DnsState
import net.mullvad.mullvadvpn.lib.model.Endpoint
import net.mullvad.mullvadvpn.lib.model.ErrorState
import net.mullvad.mullvadvpn.lib.model.ErrorStateCause
import net.mullvad.mullvadvpn.lib.model.ErrorStateCause.AuthFailed
import net.mullvad.mullvadvpn.lib.model.ErrorStateCause.OtherAlwaysOnApp
import net.mullvad.mullvadvpn.lib.model.ErrorStateCause.TunnelParameterError
import net.mullvad.mullvadvpn.lib.model.FeatureIndicator
import net.mullvad.mullvadvpn.lib.model.GeoIpLocation
import net.mullvad.mullvadvpn.lib.model.GeoLocationId
Expand Down Expand Up @@ -125,7 +122,7 @@ private fun ManagementInterface.TunnelState.Error.toDomain(): TunnelState.Error
val otherAlwaysOnAppError =
errorState.let {
if (it.hasOtherAlwaysOnAppError()) {
OtherAlwaysOnApp(it.otherAlwaysOnAppError.appName)
ErrorStateCause.OtherAlwaysOnApp(it.otherAlwaysOnAppError.appName)
} else {
null
}
Expand Down Expand Up @@ -238,7 +235,7 @@ internal fun ManagementInterface.ErrorState.toDomain(
cause =
when (cause!!) {
ManagementInterface.ErrorState.Cause.AUTH_FAILED ->
AuthFailed(authFailedError.toDomain())
ErrorStateCause.AuthFailed(authFailedError.toDomain())
ManagementInterface.ErrorState.Cause.IPV6_UNAVAILABLE ->
ErrorStateCause.Ipv6Unavailable
ManagementInterface.ErrorState.Cause.SET_FIREWALL_POLICY_ERROR ->
Expand All @@ -247,15 +244,14 @@ internal fun ManagementInterface.ErrorState.toDomain(
ManagementInterface.ErrorState.Cause.START_TUNNEL_ERROR ->
ErrorStateCause.StartTunnelError
ManagementInterface.ErrorState.Cause.TUNNEL_PARAMETER_ERROR ->
TunnelParameterError(parameterError.toDomain())
ErrorStateCause.TunnelParameterError(parameterError.toDomain())
ManagementInterface.ErrorState.Cause.IS_OFFLINE -> ErrorStateCause.IsOffline
ManagementInterface.ErrorState.Cause.SPLIT_TUNNEL_ERROR ->
ErrorStateCause.StartTunnelError
ManagementInterface.ErrorState.Cause.UNRECOGNIZED,
ManagementInterface.ErrorState.Cause.NEED_FULL_DISK_PERMISSIONS,
ManagementInterface.ErrorState.Cause.CREATE_TUNNEL_DEVICE ->
throw IllegalArgumentException("Unrecognized error state cause")

ManagementInterface.ErrorState.Cause.NOT_PREPARED -> ErrorStateCause.NotPrepared
ManagementInterface.ErrorState.Cause.OTHER_ALWAYS_ON_APP -> otherAlwaysOnApp!!
ManagementInterface.ErrorState.Cause.OTHER_LEGACY_ALWAYS_ON_VPN ->
Expand Down
Loading
Loading