Skip to content

Commit

Permalink
feat: e2ei: Update checking if e2ei required (#2303)
Browse files Browse the repository at this point in the history
* feat(core-crypto): upgrade to rc 21

* respect e2ei enabled in setting credentialtype

* add teamId to enrollment functions

* feat: e2ei: Update checking if E2EI required

* Code-style fixes

---------

Co-authored-by: Mojtaba Chenani <[email protected]>
  • Loading branch information
borichellow and mchenani authored Dec 12, 2023
1 parent 998f48e commit 1e78df6
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1714,7 +1714,13 @@ class UserSessionScope internal constructor(

val isMLSEnabled: IsMLSEnabledUseCase get() = IsMLSEnabledUseCaseImpl(featureSupport, userConfigRepository)

val observeE2EIRequired: ObserveE2EIRequiredUseCase get() = ObserveE2EIRequiredUseCaseImpl(userConfigRepository, featureSupport)
val observeE2EIRequired: ObserveE2EIRequiredUseCase
get() = ObserveE2EIRequiredUseCaseImpl(
userConfigRepository,
featureSupport,
users.getE2EICertificate,
clientIdProvider
)
val markE2EIRequiredAsNotified: MarkEnablingE2EIAsNotifiedUseCase
get() = MarkEnablingE2EIAsNotifiedUseCaseImpl(userConfigRepository)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ interface EnrollE2EIUseCase {
suspend fun finalizeEnrollment(
idToken: String,
initializationResult: E2EIEnrollmentResult.Initialized
): Either<CoreFailure, E2EIEnrollmentResult>
): Either<E2EIFailure, E2EIEnrollmentResult>
}

@Suppress("ReturnCount")
Expand Down Expand Up @@ -91,7 +91,7 @@ class EnrollE2EIUseCaseImpl internal constructor(
override suspend fun finalizeEnrollment(
idToken: String,
initializationResult: E2EIEnrollmentResult.Initialized
): Either<CoreFailure, E2EIEnrollmentResult> {
): Either<E2EIFailure, E2EIEnrollmentResult> {

var prevNonce = initializationResult.lastNonce
val authz = initializationResult.authz
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,14 @@ package com.wire.kalium.logic.feature.user

import com.wire.kalium.logic.configuration.E2EISettings
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.feature.e2ei.CertificateStatus
import com.wire.kalium.logic.feature.e2ei.usecase.GetE2EICertificateUseCaseResult
import com.wire.kalium.logic.feature.e2ei.usecase.GetE2eiCertificateUseCase
import com.wire.kalium.logic.featureFlags.FeatureSupport
import com.wire.kalium.logic.functional.getOrElse
import com.wire.kalium.logic.functional.getOrNull
import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.functional.onlyRight
import com.wire.kalium.util.DateTimeUtil
import com.wire.kalium.util.KaliumDispatcherImpl
Expand Down Expand Up @@ -50,6 +56,8 @@ interface ObserveE2EIRequiredUseCase {
internal class ObserveE2EIRequiredUseCaseImpl(
private val userConfigRepository: UserConfigRepository,
private val featureSupport: FeatureSupport,
private val e2eiCertificate: GetE2eiCertificateUseCase,
private val currentClientIdProvider: CurrentClientIdProvider,
private val dispatcher: CoroutineDispatcher = KaliumDispatcherImpl.io
) : ObserveE2EIRequiredUseCase {

Expand All @@ -63,29 +71,38 @@ internal class ObserveE2EIRequiredUseCaseImpl(
.filterNotNull()
.delayUntilNotifyTime()
.flatMapLatest {
observeE2EISettings().flatMapLatest { setting ->
observeE2EISettings().map { setting ->
if (!setting.isRequired)
flowOf(E2EIRequiredResult.NotRequired)
else
observeCurrentE2EICertificate().map { currentCertificate ->
// TODO check here if current certificate needs to be renewed (soon, or now)
E2EIRequiredResult.NotRequired
else {
currentClientIdProvider()
.map { clientId ->
val certificateResult = e2eiCertificate(clientId)
when {
certificateResult.isValid() -> E2EIRequiredResult.NotRequired

if (setting.gracePeriodEnd == null || setting.gracePeriodEnd <= DateTimeUtil.currentInstant())
E2EIRequiredResult.NoGracePeriod.Create
else E2EIRequiredResult.WithGracePeriod.Create(setting.gracePeriodEnd.minus(DateTimeUtil.currentInstant()))
}
setting.isGracePeriodLeft() -> {
val timeLeft = setting.gracePeriodEnd!!.minus(DateTimeUtil.currentInstant())
if (certificateResult !is GetE2EICertificateUseCaseResult.Failure)
E2EIRequiredResult.WithGracePeriod.Renew(timeLeft)
else E2EIRequiredResult.WithGracePeriod.Create(timeLeft)
}

else -> {
if (certificateResult !is GetE2EICertificateUseCaseResult.Failure)
E2EIRequiredResult.NoGracePeriod.Renew
else E2EIRequiredResult.NoGracePeriod.Create
}
}
}.getOrElse { E2EIRequiredResult.NotRequired }
}
}
}
.flowOn(dispatcher)
}

private fun observeE2EISettings() = userConfigRepository.observeE2EISettings().onlyRight().flowOn(dispatcher)

private fun observeCurrentE2EICertificate(): Flow<Unit> {
// TODO get current client E2EI certificate data here
return flowOf(Unit).flowOn(dispatcher)
}

private fun Flow<Instant>.delayUntilNotifyTime(): Flow<Instant> = flatMapLatest { instant ->
val delayMillis = instant
.minus(DateTimeUtil.currentInstant())
Expand All @@ -94,6 +111,11 @@ internal class ObserveE2EIRequiredUseCaseImpl(
flowOf(instant).onStart { delay(delayMillis) }
}

private fun GetE2EICertificateUseCaseResult.isValid(): Boolean =
this is GetE2EICertificateUseCaseResult.Success && certificate.status == CertificateStatus.VALID

private fun E2EISettings.isGracePeriodLeft(): Boolean = gracePeriodEnd != null && gracePeriodEnd > DateTimeUtil.currentInstant()

companion object {
private const val NO_DELAY_MS = 0L
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,23 @@ package com.wire.kalium.logic.feature.client
import app.cash.turbine.test
import com.wire.kalium.logic.configuration.E2EISettings
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.feature.e2ei.CertificateStatus
import com.wire.kalium.logic.feature.e2ei.E2eiCertificate
import com.wire.kalium.logic.feature.e2ei.usecase.GetE2EICertificateUseCaseResult
import com.wire.kalium.logic.feature.e2ei.usecase.GetE2eiCertificateUseCase
import com.wire.kalium.logic.feature.user.E2EIRequiredResult
import com.wire.kalium.logic.feature.user.ObserveE2EIRequiredUseCase
import com.wire.kalium.logic.feature.user.ObserveE2EIRequiredUseCaseImpl
import com.wire.kalium.logic.featureFlags.FeatureSupport
import com.wire.kalium.logic.framework.TestClient
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.test_util.TestKaliumDispatcher
import com.wire.kalium.util.DateTimeUtil
import io.mockative.Mock
import io.mockative.any
import io.mockative.classOf
import io.mockative.given
import io.mockative.mock
import io.mockative.verify
Expand All @@ -41,6 +49,7 @@ import kotlinx.coroutines.test.advanceUntilIdle
import kotlinx.coroutines.test.runTest
import kotlinx.datetime.Instant
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
import kotlin.time.Duration.Companion.days
import kotlin.time.Duration.Companion.minutes
Expand All @@ -54,6 +63,8 @@ class ObserveE2EIRequiredUseCaseTest {
.withMLSE2EISetting(MLS_E2EI_SETTING)
.withE2EINotificationTime(null)
.withIsMLSSupported(true)
.withCurrentClientProviderSuccess()
.withGetE2EICertificateUseCaseResult(GetE2EICertificateUseCaseResult.Failure.NotActivated)
.arrange()

useCase().test {
Expand All @@ -71,6 +82,8 @@ class ObserveE2EIRequiredUseCaseTest {
.withMLSE2EISetting(setting)
.withE2EINotificationTime(DateTimeUtil.currentInstant())
.withIsMLSSupported(true)
.withCurrentClientProviderSuccess()
.withGetE2EICertificateUseCaseResult(GetE2EICertificateUseCaseResult.Failure.NotActivated)
.arrange()

useCase().test {
Expand All @@ -88,10 +101,12 @@ class ObserveE2EIRequiredUseCaseTest {
.withMLSE2EISetting(setting)
.withE2EINotificationTime(DateTimeUtil.currentInstant())
.withIsMLSSupported(true)
.withCurrentClientProviderSuccess()
.withGetE2EICertificateUseCaseResult(GetE2EICertificateUseCaseResult.Failure.NotActivated)
.arrange()

useCase().test {
assertTrue { awaitItem() == E2EIRequiredResult.NoGracePeriod.Create }
assertEquals(E2EIRequiredResult.NoGracePeriod.Create, awaitItem())
awaitComplete()
}
}
Expand All @@ -106,14 +121,16 @@ class ObserveE2EIRequiredUseCaseTest {
.withMLSE2EISetting(setting)
.withE2EINotificationTime(DateTimeUtil.currentInstant().plus(delayDuration))
.withIsMLSSupported(true)
.withCurrentClientProviderSuccess()
.withGetE2EICertificateUseCaseResult(GetE2EICertificateUseCaseResult.Failure.NotActivated)
.arrange()

useCase().test {
advanceTimeBy(delayDuration.minus(1.minutes).inWholeMilliseconds)
expectNoEvents()

advanceTimeBy(delayDuration.inWholeMilliseconds)
assertTrue { awaitItem() == E2EIRequiredResult.NoGracePeriod.Create }
assertEquals(E2EIRequiredResult.NoGracePeriod.Create, awaitItem())
awaitComplete()
}
}
Expand All @@ -127,11 +144,13 @@ class ObserveE2EIRequiredUseCaseTest {
.withMLSE2EISetting(setting)
.withE2EINotificationTime(DateTimeUtil.currentInstant())
.withIsMLSSupported(true)
.withCurrentClientProviderSuccess()
.withGetE2EICertificateUseCaseResult(GetE2EICertificateUseCaseResult.Failure.NotActivated)
.arrange()

useCase().test {
advanceTimeBy(1000L)
assertTrue { awaitItem() == E2EIRequiredResult.NoGracePeriod.Create }
assertEquals(E2EIRequiredResult.NoGracePeriod.Create, awaitItem())
awaitComplete()
}
}
Expand All @@ -142,6 +161,8 @@ class ObserveE2EIRequiredUseCaseTest {
.withMLSE2EISetting(MLS_E2EI_SETTING)
.withE2EINotificationTime(null)
.withIsMLSSupported(true)
.withCurrentClientProviderSuccess()
.withGetE2EICertificateUseCaseResult(GetE2EICertificateUseCaseResult.Failure.NotActivated)
.arrange()

useCase().test {
Expand All @@ -156,6 +177,8 @@ class ObserveE2EIRequiredUseCaseTest {
.withMLSE2EISetting(setting)
.withE2EINotificationTime(DateTimeUtil.currentInstant())
.withIsMLSSupported(true)
.withCurrentClientProviderSuccess()
.withGetE2EICertificateUseCaseResult(GetE2EICertificateUseCaseResult.Failure.NotActivated)
.arrange()

useCase().test {
Expand All @@ -174,6 +197,7 @@ class ObserveE2EIRequiredUseCaseTest {
.withMLSE2EISetting(setting)
.withE2EINotificationTime(DateTimeUtil.currentInstant().plus(delayDuration))
.withIsMLSSupported(false)
.withCurrentClientProviderSuccess()
.arrange()

useCase().test {
Expand All @@ -187,15 +211,63 @@ class ObserveE2EIRequiredUseCaseTest {
.wasNotInvoked()
}

@Test
fun givenSettingWithNotifyDateInPastAndUserHasCertificate_thenEmitNotRequiredResult() = runTest(TestKaliumDispatcher.io) {
val setting = MLS_E2EI_SETTING.copy(
gracePeriodEnd = DateTimeUtil.currentInstant()
)
val (_, useCase) = Arrangement(TestKaliumDispatcher.io)
.withMLSE2EISetting(setting)
.withE2EINotificationTime(DateTimeUtil.currentInstant())
.withIsMLSSupported(true)
.withCurrentClientProviderSuccess()
.withGetE2EICertificateUseCaseResult(GetE2EICertificateUseCaseResult.Success(VALID_CERTIFICATE))
.arrange()

useCase().test {
advanceTimeBy(1000L)
assertEquals(E2EIRequiredResult.NotRequired, awaitItem())
awaitComplete()
}
}

@Test
fun givenSettingWithNotifyDateInPastAndUserHasExpiredCertificate_thenEmitRequiredResult() = runTest(TestKaliumDispatcher.io) {
val setting = MLS_E2EI_SETTING.copy(
gracePeriodEnd = DateTimeUtil.currentInstant()
)
val (_, useCase) = Arrangement(TestKaliumDispatcher.io)
.withMLSE2EISetting(setting)
.withE2EINotificationTime(DateTimeUtil.currentInstant())
.withIsMLSSupported(true)
.withCurrentClientProviderSuccess()
.withGetE2EICertificateUseCaseResult(
GetE2EICertificateUseCaseResult.Success(VALID_CERTIFICATE.copy(status = CertificateStatus.EXPIRED))
)
.arrange()

useCase().test {
advanceTimeBy(1000L)
assertEquals(E2EIRequiredResult.NoGracePeriod.Renew, awaitItem())
awaitComplete()
}
}

private class Arrangement(testDispatcher: CoroutineDispatcher = UnconfinedTestDispatcher()) {
@Mock
val userConfigRepository = mock(UserConfigRepository::class)

@Mock
val featureSupport = mock(FeatureSupport::class)

@Mock
val e2eiCertificate = mock(GetE2eiCertificateUseCase::class)

@Mock
val currentClientIdProvider = mock(CurrentClientIdProvider::class)

private var observeMLSEnabledUseCase: ObserveE2EIRequiredUseCase =
ObserveE2EIRequiredUseCaseImpl(userConfigRepository, featureSupport, testDispatcher)
ObserveE2EIRequiredUseCaseImpl(userConfigRepository, featureSupport, e2eiCertificate, currentClientIdProvider, testDispatcher)

fun withMLSE2EISetting(setting: E2EISettings) = apply {
given(userConfigRepository)
Expand All @@ -217,10 +289,25 @@ class ObserveE2EIRequiredUseCaseTest {
.thenReturn(supported)
}

fun withCurrentClientProviderSuccess(clientId: ClientId = TestClient.CLIENT_ID) = apply {
given(currentClientIdProvider)
.suspendFunction(currentClientIdProvider::invoke)
.whenInvoked()
.thenReturn(Either.Right(clientId))
}

fun withGetE2EICertificateUseCaseResult(result: GetE2EICertificateUseCaseResult) = apply {
given(e2eiCertificate)
.suspendFunction(e2eiCertificate::invoke)
.whenInvokedWith(any())
.thenReturn(result)
}

fun arrange() = this to observeMLSEnabledUseCase
}

companion object {
private val MLS_E2EI_SETTING = E2EISettings(true, "some_url", null)
private val VALID_CERTIFICATE = E2eiCertificate(status = CertificateStatus.VALID)
}
}

0 comments on commit 1e78df6

Please sign in to comment.