From 1e78df60282a78e6cc2b4061f3d14d702c3bd9d5 Mon Sep 17 00:00:00 2001 From: boris Date: Tue, 12 Dec 2023 18:50:03 +0200 Subject: [PATCH] feat: e2ei: Update checking if e2ei required (#2303) * feat(core-crypto): upgrade to rc 21 * respect e2ei enabled in setting credentialtype * add teamId to enrollment functions * feat: e2ei: Update checking if E2EI required * Code-style fixes --------- Co-authored-by: Mojtaba Chenani --- .../kalium/logic/feature/UserSessionScope.kt | 8 +- .../feature/e2ei/usecase/EnrollE2EIUseCase.kt | 4 +- .../user/ObserveE2EIRequiredUseCase.kt | 50 +++++++--- .../client/ObserveE2EIRequiredUseCaseTest.kt | 95 ++++++++++++++++++- 4 files changed, 136 insertions(+), 21 deletions(-) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt index 432b705c43c..a35930a80b7 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt @@ -1714,7 +1714,13 @@ class UserSessionScope internal constructor( val isMLSEnabled: IsMLSEnabledUseCase get() = IsMLSEnabledUseCaseImpl(featureSupport, userConfigRepository) - val observeE2EIRequired: ObserveE2EIRequiredUseCase get() = ObserveE2EIRequiredUseCaseImpl(userConfigRepository, featureSupport) + val observeE2EIRequired: ObserveE2EIRequiredUseCase + get() = ObserveE2EIRequiredUseCaseImpl( + userConfigRepository, + featureSupport, + users.getE2EICertificate, + clientIdProvider + ) val markE2EIRequiredAsNotified: MarkEnablingE2EIAsNotifiedUseCase get() = MarkEnablingE2EIAsNotifiedUseCaseImpl(userConfigRepository) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/EnrollE2EIUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/EnrollE2EIUseCase.kt index aa21ebd8995..d8bd6598dd2 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/EnrollE2EIUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/e2ei/usecase/EnrollE2EIUseCase.kt @@ -34,7 +34,7 @@ interface EnrollE2EIUseCase { suspend fun finalizeEnrollment( idToken: String, initializationResult: E2EIEnrollmentResult.Initialized - ): Either + ): Either } @Suppress("ReturnCount") @@ -91,7 +91,7 @@ class EnrollE2EIUseCaseImpl internal constructor( override suspend fun finalizeEnrollment( idToken: String, initializationResult: E2EIEnrollmentResult.Initialized - ): Either { + ): Either { var prevNonce = initializationResult.lastNonce val authz = initializationResult.authz diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/ObserveE2EIRequiredUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/ObserveE2EIRequiredUseCase.kt index 0dc2e1e260a..429d2c48946 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/ObserveE2EIRequiredUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/ObserveE2EIRequiredUseCase.kt @@ -19,8 +19,14 @@ package com.wire.kalium.logic.feature.user import com.wire.kalium.logic.configuration.E2EISettings import com.wire.kalium.logic.configuration.UserConfigRepository +import com.wire.kalium.logic.data.id.CurrentClientIdProvider +import com.wire.kalium.logic.feature.e2ei.CertificateStatus +import com.wire.kalium.logic.feature.e2ei.usecase.GetE2EICertificateUseCaseResult +import com.wire.kalium.logic.feature.e2ei.usecase.GetE2eiCertificateUseCase import com.wire.kalium.logic.featureFlags.FeatureSupport +import com.wire.kalium.logic.functional.getOrElse import com.wire.kalium.logic.functional.getOrNull +import com.wire.kalium.logic.functional.map import com.wire.kalium.logic.functional.onlyRight import com.wire.kalium.util.DateTimeUtil import com.wire.kalium.util.KaliumDispatcherImpl @@ -50,6 +56,8 @@ interface ObserveE2EIRequiredUseCase { internal class ObserveE2EIRequiredUseCaseImpl( private val userConfigRepository: UserConfigRepository, private val featureSupport: FeatureSupport, + private val e2eiCertificate: GetE2eiCertificateUseCase, + private val currentClientIdProvider: CurrentClientIdProvider, private val dispatcher: CoroutineDispatcher = KaliumDispatcherImpl.io ) : ObserveE2EIRequiredUseCase { @@ -63,17 +71,31 @@ internal class ObserveE2EIRequiredUseCaseImpl( .filterNotNull() .delayUntilNotifyTime() .flatMapLatest { - observeE2EISettings().flatMapLatest { setting -> + observeE2EISettings().map { setting -> if (!setting.isRequired) - flowOf(E2EIRequiredResult.NotRequired) - else - observeCurrentE2EICertificate().map { currentCertificate -> - // TODO check here if current certificate needs to be renewed (soon, or now) + E2EIRequiredResult.NotRequired + else { + currentClientIdProvider() + .map { clientId -> + val certificateResult = e2eiCertificate(clientId) + when { + certificateResult.isValid() -> E2EIRequiredResult.NotRequired - if (setting.gracePeriodEnd == null || setting.gracePeriodEnd <= DateTimeUtil.currentInstant()) - E2EIRequiredResult.NoGracePeriod.Create - else E2EIRequiredResult.WithGracePeriod.Create(setting.gracePeriodEnd.minus(DateTimeUtil.currentInstant())) - } + setting.isGracePeriodLeft() -> { + val timeLeft = setting.gracePeriodEnd!!.minus(DateTimeUtil.currentInstant()) + if (certificateResult !is GetE2EICertificateUseCaseResult.Failure) + E2EIRequiredResult.WithGracePeriod.Renew(timeLeft) + else E2EIRequiredResult.WithGracePeriod.Create(timeLeft) + } + + else -> { + if (certificateResult !is GetE2EICertificateUseCaseResult.Failure) + E2EIRequiredResult.NoGracePeriod.Renew + else E2EIRequiredResult.NoGracePeriod.Create + } + } + }.getOrElse { E2EIRequiredResult.NotRequired } + } } } .flowOn(dispatcher) @@ -81,11 +103,6 @@ internal class ObserveE2EIRequiredUseCaseImpl( private fun observeE2EISettings() = userConfigRepository.observeE2EISettings().onlyRight().flowOn(dispatcher) - private fun observeCurrentE2EICertificate(): Flow { - // TODO get current client E2EI certificate data here - return flowOf(Unit).flowOn(dispatcher) - } - private fun Flow.delayUntilNotifyTime(): Flow = flatMapLatest { instant -> val delayMillis = instant .minus(DateTimeUtil.currentInstant()) @@ -94,6 +111,11 @@ internal class ObserveE2EIRequiredUseCaseImpl( flowOf(instant).onStart { delay(delayMillis) } } + private fun GetE2EICertificateUseCaseResult.isValid(): Boolean = + this is GetE2EICertificateUseCaseResult.Success && certificate.status == CertificateStatus.VALID + + private fun E2EISettings.isGracePeriodLeft(): Boolean = gracePeriodEnd != null && gracePeriodEnd > DateTimeUtil.currentInstant() + companion object { private const val NO_DELAY_MS = 0L } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/client/ObserveE2EIRequiredUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/client/ObserveE2EIRequiredUseCaseTest.kt index 7feda0ad1a1..3885b0b8b13 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/client/ObserveE2EIRequiredUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/client/ObserveE2EIRequiredUseCaseTest.kt @@ -20,15 +20,23 @@ package com.wire.kalium.logic.feature.client import app.cash.turbine.test import com.wire.kalium.logic.configuration.E2EISettings import com.wire.kalium.logic.configuration.UserConfigRepository +import com.wire.kalium.logic.data.conversation.ClientId +import com.wire.kalium.logic.data.id.CurrentClientIdProvider +import com.wire.kalium.logic.feature.e2ei.CertificateStatus +import com.wire.kalium.logic.feature.e2ei.E2eiCertificate +import com.wire.kalium.logic.feature.e2ei.usecase.GetE2EICertificateUseCaseResult +import com.wire.kalium.logic.feature.e2ei.usecase.GetE2eiCertificateUseCase import com.wire.kalium.logic.feature.user.E2EIRequiredResult import com.wire.kalium.logic.feature.user.ObserveE2EIRequiredUseCase import com.wire.kalium.logic.feature.user.ObserveE2EIRequiredUseCaseImpl import com.wire.kalium.logic.featureFlags.FeatureSupport +import com.wire.kalium.logic.framework.TestClient import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.test_util.TestKaliumDispatcher import com.wire.kalium.util.DateTimeUtil import io.mockative.Mock import io.mockative.any +import io.mockative.classOf import io.mockative.given import io.mockative.mock import io.mockative.verify @@ -41,6 +49,7 @@ import kotlinx.coroutines.test.advanceUntilIdle import kotlinx.coroutines.test.runTest import kotlinx.datetime.Instant import kotlin.test.Test +import kotlin.test.assertEquals import kotlin.test.assertTrue import kotlin.time.Duration.Companion.days import kotlin.time.Duration.Companion.minutes @@ -54,6 +63,8 @@ class ObserveE2EIRequiredUseCaseTest { .withMLSE2EISetting(MLS_E2EI_SETTING) .withE2EINotificationTime(null) .withIsMLSSupported(true) + .withCurrentClientProviderSuccess() + .withGetE2EICertificateUseCaseResult(GetE2EICertificateUseCaseResult.Failure.NotActivated) .arrange() useCase().test { @@ -71,6 +82,8 @@ class ObserveE2EIRequiredUseCaseTest { .withMLSE2EISetting(setting) .withE2EINotificationTime(DateTimeUtil.currentInstant()) .withIsMLSSupported(true) + .withCurrentClientProviderSuccess() + .withGetE2EICertificateUseCaseResult(GetE2EICertificateUseCaseResult.Failure.NotActivated) .arrange() useCase().test { @@ -88,10 +101,12 @@ class ObserveE2EIRequiredUseCaseTest { .withMLSE2EISetting(setting) .withE2EINotificationTime(DateTimeUtil.currentInstant()) .withIsMLSSupported(true) + .withCurrentClientProviderSuccess() + .withGetE2EICertificateUseCaseResult(GetE2EICertificateUseCaseResult.Failure.NotActivated) .arrange() useCase().test { - assertTrue { awaitItem() == E2EIRequiredResult.NoGracePeriod.Create } + assertEquals(E2EIRequiredResult.NoGracePeriod.Create, awaitItem()) awaitComplete() } } @@ -106,6 +121,8 @@ class ObserveE2EIRequiredUseCaseTest { .withMLSE2EISetting(setting) .withE2EINotificationTime(DateTimeUtil.currentInstant().plus(delayDuration)) .withIsMLSSupported(true) + .withCurrentClientProviderSuccess() + .withGetE2EICertificateUseCaseResult(GetE2EICertificateUseCaseResult.Failure.NotActivated) .arrange() useCase().test { @@ -113,7 +130,7 @@ class ObserveE2EIRequiredUseCaseTest { expectNoEvents() advanceTimeBy(delayDuration.inWholeMilliseconds) - assertTrue { awaitItem() == E2EIRequiredResult.NoGracePeriod.Create } + assertEquals(E2EIRequiredResult.NoGracePeriod.Create, awaitItem()) awaitComplete() } } @@ -127,11 +144,13 @@ class ObserveE2EIRequiredUseCaseTest { .withMLSE2EISetting(setting) .withE2EINotificationTime(DateTimeUtil.currentInstant()) .withIsMLSSupported(true) + .withCurrentClientProviderSuccess() + .withGetE2EICertificateUseCaseResult(GetE2EICertificateUseCaseResult.Failure.NotActivated) .arrange() useCase().test { advanceTimeBy(1000L) - assertTrue { awaitItem() == E2EIRequiredResult.NoGracePeriod.Create } + assertEquals(E2EIRequiredResult.NoGracePeriod.Create, awaitItem()) awaitComplete() } } @@ -142,6 +161,8 @@ class ObserveE2EIRequiredUseCaseTest { .withMLSE2EISetting(MLS_E2EI_SETTING) .withE2EINotificationTime(null) .withIsMLSSupported(true) + .withCurrentClientProviderSuccess() + .withGetE2EICertificateUseCaseResult(GetE2EICertificateUseCaseResult.Failure.NotActivated) .arrange() useCase().test { @@ -156,6 +177,8 @@ class ObserveE2EIRequiredUseCaseTest { .withMLSE2EISetting(setting) .withE2EINotificationTime(DateTimeUtil.currentInstant()) .withIsMLSSupported(true) + .withCurrentClientProviderSuccess() + .withGetE2EICertificateUseCaseResult(GetE2EICertificateUseCaseResult.Failure.NotActivated) .arrange() useCase().test { @@ -174,6 +197,7 @@ class ObserveE2EIRequiredUseCaseTest { .withMLSE2EISetting(setting) .withE2EINotificationTime(DateTimeUtil.currentInstant().plus(delayDuration)) .withIsMLSSupported(false) + .withCurrentClientProviderSuccess() .arrange() useCase().test { @@ -187,6 +211,48 @@ class ObserveE2EIRequiredUseCaseTest { .wasNotInvoked() } + @Test + fun givenSettingWithNotifyDateInPastAndUserHasCertificate_thenEmitNotRequiredResult() = runTest(TestKaliumDispatcher.io) { + val setting = MLS_E2EI_SETTING.copy( + gracePeriodEnd = DateTimeUtil.currentInstant() + ) + val (_, useCase) = Arrangement(TestKaliumDispatcher.io) + .withMLSE2EISetting(setting) + .withE2EINotificationTime(DateTimeUtil.currentInstant()) + .withIsMLSSupported(true) + .withCurrentClientProviderSuccess() + .withGetE2EICertificateUseCaseResult(GetE2EICertificateUseCaseResult.Success(VALID_CERTIFICATE)) + .arrange() + + useCase().test { + advanceTimeBy(1000L) + assertEquals(E2EIRequiredResult.NotRequired, awaitItem()) + awaitComplete() + } + } + + @Test + fun givenSettingWithNotifyDateInPastAndUserHasExpiredCertificate_thenEmitRequiredResult() = runTest(TestKaliumDispatcher.io) { + val setting = MLS_E2EI_SETTING.copy( + gracePeriodEnd = DateTimeUtil.currentInstant() + ) + val (_, useCase) = Arrangement(TestKaliumDispatcher.io) + .withMLSE2EISetting(setting) + .withE2EINotificationTime(DateTimeUtil.currentInstant()) + .withIsMLSSupported(true) + .withCurrentClientProviderSuccess() + .withGetE2EICertificateUseCaseResult( + GetE2EICertificateUseCaseResult.Success(VALID_CERTIFICATE.copy(status = CertificateStatus.EXPIRED)) + ) + .arrange() + + useCase().test { + advanceTimeBy(1000L) + assertEquals(E2EIRequiredResult.NoGracePeriod.Renew, awaitItem()) + awaitComplete() + } + } + private class Arrangement(testDispatcher: CoroutineDispatcher = UnconfinedTestDispatcher()) { @Mock val userConfigRepository = mock(UserConfigRepository::class) @@ -194,8 +260,14 @@ class ObserveE2EIRequiredUseCaseTest { @Mock val featureSupport = mock(FeatureSupport::class) + @Mock + val e2eiCertificate = mock(GetE2eiCertificateUseCase::class) + + @Mock + val currentClientIdProvider = mock(CurrentClientIdProvider::class) + private var observeMLSEnabledUseCase: ObserveE2EIRequiredUseCase = - ObserveE2EIRequiredUseCaseImpl(userConfigRepository, featureSupport, testDispatcher) + ObserveE2EIRequiredUseCaseImpl(userConfigRepository, featureSupport, e2eiCertificate, currentClientIdProvider, testDispatcher) fun withMLSE2EISetting(setting: E2EISettings) = apply { given(userConfigRepository) @@ -217,10 +289,25 @@ class ObserveE2EIRequiredUseCaseTest { .thenReturn(supported) } + fun withCurrentClientProviderSuccess(clientId: ClientId = TestClient.CLIENT_ID) = apply { + given(currentClientIdProvider) + .suspendFunction(currentClientIdProvider::invoke) + .whenInvoked() + .thenReturn(Either.Right(clientId)) + } + + fun withGetE2EICertificateUseCaseResult(result: GetE2EICertificateUseCaseResult) = apply { + given(e2eiCertificate) + .suspendFunction(e2eiCertificate::invoke) + .whenInvokedWith(any()) + .thenReturn(result) + } + fun arrange() = this to observeMLSEnabledUseCase } companion object { private val MLS_E2EI_SETTING = E2EISettings(true, "some_url", null) + private val VALID_CERTIFICATE = E2eiCertificate(status = CertificateStatus.VALID) } }