diff --git a/cli/src/appleMain/kotlin/main.kt b/cli/src/appleMain/kotlin/main.kt index d2b5eb01498..bbd85916860 100644 --- a/cli/src/appleMain/kotlin/main.kt +++ b/cli/src/appleMain/kotlin/main.kt @@ -27,6 +27,7 @@ import com.wire.kalium.cli.commands.MarkAsReadCommand import com.wire.kalium.cli.commands.RefillKeyPackagesCommand import com.wire.kalium.cli.commands.RemoveMemberFromGroupCommand import com.wire.kalium.cli.commands.InteractiveCommand +import com.wire.kalium.cli.commands.UpdateSupportedProtocolsCommand fun main(args: Array) = CLIApplication().subcommands( LoginCommand().subcommands( @@ -37,6 +38,7 @@ fun main(args: Array) = CLIApplication().subcommands( RemoveMemberFromGroupCommand(), RefillKeyPackagesCommand(), MarkAsReadCommand(), - InteractiveCommand() + InteractiveCommand(), + UpdateSupportedProtocolsCommand() ) ).main(args) diff --git a/cli/src/commonMain/kotlin/com/wire/kalium/cli/CLIApplication.kt b/cli/src/commonMain/kotlin/com/wire/kalium/cli/CLIApplication.kt index af05daf47d2..a187a7c39cd 100644 --- a/cli/src/commonMain/kotlin/com/wire/kalium/cli/CLIApplication.kt +++ b/cli/src/commonMain/kotlin/com/wire/kalium/cli/CLIApplication.kt @@ -30,13 +30,25 @@ import com.wire.kalium.logic.CoreLogger import com.wire.kalium.logic.CoreLogic import com.wire.kalium.logic.featureFlags.KaliumConfigs import kotlinx.coroutines.runBlocking +import kotlin.time.Duration class CLIApplication : CliktCommand(allowMultipleSubcommands = true) { - private val logLevel by option(help = "log level").enum().default(KaliumLogLevel.WARN) - private val logOutputFile by option(help = "output file for logs") - private val developmentApiEnabled by option(help = "use development API if supported by backend").flag(default = false) - private val encryptProteusStorage by option(help = "use encrypted storage for proteus sessions and identity").flag(default = false) + private val logLevel by option( + help = "log level" + ).enum().default(KaliumLogLevel.WARN) + private val logOutputFile by option( + help = "output file for logs" + ) + private val developmentApiEnabled by option( + help = "use development API if supported by backend" + ).flag(default = false) + private val encryptProteusStorage by option( + help = "use encrypted storage for proteus sessions and identity" + ).flag(default = false) + private val mlsMigrationInterval by option( + help = "interval at which mls migration is updated" + ).default("24h") private val fileLogger: LogWriter by lazy { fileLogger(logOutputFile ?: "kalium.log") } override fun run() = runBlocking { @@ -45,7 +57,8 @@ class CLIApplication : CliktCommand(allowMultipleSubcommands = true) { rootPath = "$HOME_DIRECTORY/.kalium/accounts", kaliumConfigs = KaliumConfigs( developmentApiEnabled = developmentApiEnabled, - encryptProteusStorage = encryptProteusStorage + encryptProteusStorage = encryptProteusStorage, + mlsMigrationInterval = Duration.parse(mlsMigrationInterval) ) ) } @@ -63,7 +76,6 @@ class CLIApplication : CliktCommand(allowMultipleSubcommands = true) { companion object { val HOME_DIRECTORY: String = homeDirectory() } - } expect fun fileLogger(filePath: String): LogWriter diff --git a/cli/src/commonMain/kotlin/com/wire/kalium/cli/commands/UpdateSupportedProtocolsCommand.kt b/cli/src/commonMain/kotlin/com/wire/kalium/cli/commands/UpdateSupportedProtocolsCommand.kt new file mode 100644 index 00000000000..ae5a9fcd075 --- /dev/null +++ b/cli/src/commonMain/kotlin/com/wire/kalium/cli/commands/UpdateSupportedProtocolsCommand.kt @@ -0,0 +1,40 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ + +package com.wire.kalium.cli.commands + +import com.github.ajalt.clikt.core.CliktCommand +import com.github.ajalt.clikt.core.PrintMessage +import com.github.ajalt.clikt.core.requireObject +import com.wire.kalium.logic.feature.UserSessionScope +import com.wire.kalium.logic.functional.fold +import kotlinx.coroutines.runBlocking + +class UpdateSupportedProtocolsCommand : CliktCommand(name = "update-supported-protocols") { + + private val userSession by requireObject() + + override fun run() = runBlocking { + userSession.syncManager.waitUntilLive() + userSession.users.updateSupportedProtocols().fold({ failure -> + throw PrintMessage("updating supported protocols failed: $failure") + }, { + echo("supported protocols were updated") + }) + } +} diff --git a/cli/src/jvmMain/kotlin/com/wire/kalium/cli/main.kt b/cli/src/jvmMain/kotlin/com/wire/kalium/cli/main.kt index 038c271b5f3..115199d62df 100644 --- a/cli/src/jvmMain/kotlin/com/wire/kalium/cli/main.kt +++ b/cli/src/jvmMain/kotlin/com/wire/kalium/cli/main.kt @@ -28,6 +28,7 @@ import com.wire.kalium.cli.commands.MarkAsReadCommand import com.wire.kalium.cli.commands.ConsoleCommand import com.wire.kalium.cli.commands.RefillKeyPackagesCommand import com.wire.kalium.cli.commands.RemoveMemberFromGroupCommand +import com.wire.kalium.cli.commands.UpdateSupportedProtocolsCommand fun main(args: Array) = CLIApplication().subcommands( LoginCommand().subcommands( @@ -38,6 +39,7 @@ fun main(args: Array) = CLIApplication().subcommands( RemoveMemberFromGroupCommand(), ConsoleCommand(), RefillKeyPackagesCommand(), - MarkAsReadCommand() + MarkAsReadCommand(), + UpdateSupportedProtocolsCommand() ) ).main(args) diff --git a/cryptography/src/commonTest/kotlin/com/wire/kalium/cryptography/E2EIClientTest.kt b/cryptography/src/commonTest/kotlin/com/wire/kalium/cryptography/E2EIClientTest.kt index 07244098ede..642bd05a865 100644 --- a/cryptography/src/commonTest/kotlin/com/wire/kalium/cryptography/E2EIClientTest.kt +++ b/cryptography/src/commonTest/kotlin/com/wire/kalium/cryptography/E2EIClientTest.kt @@ -23,8 +23,8 @@ import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertTrue -@IgnoreJS @IgnoreIOS +@IgnoreJS class E2EIClientTest : BaseMLSClientTest() { data class SampleUser( val id: CryptoQualifiedID, val clientId: CryptoClientId, val name: String, val handle: String diff --git a/logic/src/commonJvmAndroid/kotlin/com/wire/kalium/logic/CoreCryptoExceptionMapper.kt b/logic/src/commonJvmAndroid/kotlin/com/wire/kalium/logic/CoreCryptoExceptionMapper.kt index 2e2783af545..935c35efebc 100644 --- a/logic/src/commonJvmAndroid/kotlin/com/wire/kalium/logic/CoreCryptoExceptionMapper.kt +++ b/logic/src/commonJvmAndroid/kotlin/com/wire/kalium/logic/CoreCryptoExceptionMapper.kt @@ -25,8 +25,10 @@ actual fun mapMLSException(exception: Exception): MLSFailure = when (exception.error) { is CryptoError.WrongEpoch -> MLSFailure.WrongEpoch is CryptoError.DuplicateMessage -> MLSFailure.DuplicateMessage + is CryptoError.BufferedFutureMessage -> MLSFailure.BufferedFutureMessage is CryptoError.SelfCommitIgnored -> MLSFailure.SelfCommitIgnored is CryptoError.UnmergedPendingGroup -> MLSFailure.UnmergedPendingGroup + is CryptoError.ConversationAlreadyExists -> MLSFailure.ConversationAlreadyExists else -> MLSFailure.Generic(exception) } } else { diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/CoreFailure.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/CoreFailure.kt index 1826fa408cb..29a7b6bc540 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/CoreFailure.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/CoreFailure.kt @@ -21,6 +21,7 @@ package com.wire.kalium.logic import com.wire.kalium.cryptography.exceptions.ProteusException import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.functional.Either +import com.wire.kalium.network.exceptions.APINotSupported import com.wire.kalium.network.exceptions.KaliumException import com.wire.kalium.network.exceptions.isFederationDenied import com.wire.kalium.network.utils.NetworkResponse @@ -99,6 +100,11 @@ sealed interface CoreFailure { data object SyncEventOrClientNotFound : FeatureFailure() data object FeatureNotImplemented : FeatureFailure() + /** + * No common Protocol found in order to establish a conversation between parties. + * Could be, for example, that the desired user only supports Proteus, but we only support MLS. + */ + data object NoCommonProtocolFound : FeatureFailure() } sealed class NetworkFailure : CoreFailure { @@ -155,6 +161,10 @@ sealed class NetworkFailure : CoreFailure { } + /** + * Failure due to a feature not supported by the current client/backend. + */ + object FeatureNotSupported : NetworkFailure() } interface MLSFailure : CoreFailure { @@ -163,10 +173,14 @@ interface MLSFailure : CoreFailure { object DuplicateMessage : MLSFailure + object BufferedFutureMessage : MLSFailure + object SelfCommitIgnored : MLSFailure object UnmergedPendingGroup : MLSFailure + object ConversationAlreadyExists : MLSFailure + object ConversationDoesNotSupportMLS : MLSFailure class Generic(internal val exception: Exception) : MLSFailure { @@ -233,6 +247,10 @@ internal inline fun wrapApiRequest(networkCall: () -> NetworkResponse< Either.Left(NetworkFailure.NoNetworkConnection(exception)) } + exception is APINotSupported -> { + Either.Left(NetworkFailure.FeatureNotSupported) + } + else -> { Either.Left(NetworkFailure.ServerMiscommunication(result.kException)) } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/configuration/UserConfigRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/configuration/UserConfigRepository.kt index 390f7af9b88..72cc39a7175 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/configuration/UserConfigRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/configuration/UserConfigRepository.kt @@ -20,6 +20,12 @@ package com.wire.kalium.logic.configuration import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.data.featureConfig.AppLockConfigModel +import com.wire.kalium.logic.data.featureConfig.MLSMigrationModel +import com.wire.kalium.logic.data.featureConfig.toEntity +import com.wire.kalium.logic.data.featureConfig.toModel +import com.wire.kalium.logic.data.user.SupportedProtocol +import com.wire.kalium.logic.data.user.toDao +import com.wire.kalium.logic.data.user.toModel import com.wire.kalium.logic.feature.selfDeletingMessages.SelfDeletionMapper.toSelfDeletionTimerEntity import com.wire.kalium.logic.feature.selfDeletingMessages.SelfDeletionMapper.toTeamSelfDeleteTimer import com.wire.kalium.logic.feature.selfDeletingMessages.TeamSettingsSelfDeletionStatus @@ -57,6 +63,10 @@ interface UserConfigRepository { fun observeE2EISettings(): Flow> fun setE2EISettings(setting: E2EISettings): Either fun snoozeE2EINotification(duration: Duration): Either + fun setDefaultProtocol(protocol: SupportedProtocol): Either + fun getDefaultProtocol(): Either + suspend fun setSupportedProtocols(protocols: Set): Either + suspend fun getSupportedProtocols(): Either> fun setConferenceCallingEnabled(enabled: Boolean): Either fun isConferenceCallingEnabled(): Either fun setSecondFactorPasswordChallengeStatus(isRequired: Boolean): Either @@ -80,6 +90,8 @@ interface UserConfigRepository { suspend fun observeTeamSettingsSelfDeletingStatus(): Flow> fun observeE2EINotificationTime(): Flow> fun setE2EINotificationTime(instant: Instant): Either + suspend fun getMigrationConfiguration(): Either + suspend fun setMigrationConfiguration(configuration: MLSMigrationModel): Either } @Suppress("TooManyFunctions") @@ -186,6 +198,17 @@ class UserConfigDataSource( private fun getE2EINotificationTimeOrNull() = wrapStorageRequest { userConfigStorage.getE2EINotificationTime() }.getOrNull() + override fun setDefaultProtocol(protocol: SupportedProtocol): Either = + wrapStorageRequest { userConfigStorage.persistDefaultProtocol(protocol.toDao()) } + + override fun getDefaultProtocol(): Either = + wrapStorageRequest { userConfigStorage.defaultProtocol().toModel() } + + override suspend fun setSupportedProtocols(protocols: Set): Either = + wrapStorageRequest { userConfigDAO.setSupportedProtocols(protocols.toDao()) } + + override suspend fun getSupportedProtocols(): Either> = + wrapStorageRequest { userConfigDAO.getSupportedProtocols()?.toModel() } override fun setConferenceCallingEnabled(enabled: Boolean): Either = wrapStorageRequest { userConfigStorage.persistConferenceCalling(enabled) @@ -299,4 +322,14 @@ class UserConfigDataSource( } } } + + override suspend fun getMigrationConfiguration(): Either = + wrapStorageRequest { + userConfigDAO.getMigrationConfiguration()?.toModel() + } + + override suspend fun setMigrationConfiguration(configuration: MLSMigrationModel): Either = + wrapStorageRequest { + userConfigDAO.setMigrationConfiguration(configuration.toEntity()) + } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/call/CallRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/call/CallRepository.kt index 4fbc5e0b962..331239be55f 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/call/CallRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/call/CallRepository.kt @@ -633,7 +633,8 @@ internal class CallDataSource( ).flattenConcat() } } ?: Either.Left(CoreFailure.NotSupportedByProteus) - is Conversation.ProtocolInfo.Proteus -> Either.Left(CoreFailure.NotSupportedByProteus) + is Conversation.ProtocolInfo.Proteus, + is Conversation.ProtocolInfo.Mixed -> Either.Left(CoreFailure.NotSupportedByProteus) } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/call/mapper/CallMapper.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/call/mapper/CallMapper.kt index a7fdb4c6438..17bdecad5b9 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/call/mapper/CallMapper.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/call/mapper/CallMapper.kt @@ -125,7 +125,8 @@ class CallMapperImpl( Conversation.Type.GROUP -> { when (conversation.protocol) { is Conversation.ProtocolInfo.MLS -> ConversationType.ConferenceMls - is Conversation.ProtocolInfo.Proteus -> ConversationType.Conference + is Conversation.ProtocolInfo.Proteus, + is Conversation.ProtocolInfo.Mixed -> ConversationType.Conference } } Conversation.Type.ONE_ON_ONE -> ConversationType.OneOnOne diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/ClientMapper.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/ClientMapper.kt index b4b0bb166ae..2e10ea28ed0 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/ClientMapper.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/ClientMapper.kt @@ -71,7 +71,8 @@ class ClientMapper( model = client.model, isVerified = false, isValid = true, - mlsPublicKeys = client.mlsPublicKeys + mlsPublicKeys = client.mlsPublicKeys, + isMLSCapable = client.mlsPublicKeys?.isNotEmpty() ?: false ) fun fromClientEntity(clientEntity: ClientEntity): Client = with(clientEntity) { @@ -85,7 +86,8 @@ class ClientMapper( model = model, isVerified = isProteusVerified, isValid = isValid, - mlsPublicKeys = mlsPublicKeys + mlsPublicKeys = mlsPublicKeys, + isMLSCapable = isMLSCapable ) } @@ -100,7 +102,8 @@ class ClientMapper( model = model, isVerified = false, isValid = true, - mlsPublicKeys = null + mlsPublicKeys = null, + isMLSCapable = false ) } @@ -116,7 +119,8 @@ class ClientMapper( model = null, registrationDate = null, lastActive = null, - mlsPublicKeys = null + mlsPublicKeys = null, + isMLSCapable = false ) } } @@ -132,7 +136,8 @@ class ClientMapper( model = model, registrationDate = Instant.parse(registrationTime), lastActive = lastActive?.let { Instant.parse(it).coerceAtMost(Clock.System.now()) }, - mlsPublicKeys = mlsPublicKeys + mlsPublicKeys = mlsPublicKeys, + isMLSCapable = mlsPublicKeys?.isNotEmpty() ?: false ) } @@ -147,7 +152,8 @@ class ClientMapper( model = null, registrationDate = null, lastActive = null, - mlsPublicKeys = null + mlsPublicKeys = null, + isMLSCapable = false ) } @@ -161,7 +167,8 @@ class ClientMapper( model = event.client.model, registrationDate = event.client.registrationTime, lastActive = event.client.lastActive, - mlsPublicKeys = null + mlsPublicKeys = null, + isMLSCapable = event.client.isMLSCapable ) private fun toClientTypeDTO(clientType: ClientType): ClientTypeDTO = when (clientType) { diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/ClientModel.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/ClientModel.kt index d3c7a695dcc..c58add03455 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/ClientModel.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/ClientModel.kt @@ -20,7 +20,9 @@ package com.wire.kalium.logic.data.client import com.wire.kalium.cryptography.PreKeyCrypto import com.wire.kalium.logic.data.conversation.ClientId +import kotlinx.datetime.Clock import kotlinx.datetime.Instant +import kotlin.time.Duration.Companion.days data class RegisterClientParam( val password: String?, @@ -59,8 +61,13 @@ data class Client( val deviceType: DeviceType?, val label: String?, val model: String?, - val mlsPublicKeys: Map? -) + val mlsPublicKeys: Map?, + val isMLSCapable: Boolean +) { + companion object { + val INACTIVE_DURATION = 28.days + } +} enum class ClientType { Temporary, @@ -86,3 +93,12 @@ data class OtherUserClient( val isValid: Boolean, val isProteusVerified: Boolean ) + +/** + * True if the client is considered to be in active use. + * + * A client is considered active if it has connected to the backend within + * the `INACTIVE_DURATION`. + */ +val Client.isActive: Boolean + get() = lastActive?.let { (Clock.System.now() - it) < Client.INACTIVE_DURATION } ?: false diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/connection/ConnectionRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/connection/ConnectionRepository.kt index b3ed3785d55..a1f935d992f 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/connection/ConnectionRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/connection/ConnectionRepository.kt @@ -39,14 +39,10 @@ import com.wire.kalium.logic.data.user.ConnectionState.NOT_CONNECTED import com.wire.kalium.logic.data.user.ConnectionState.PENDING import com.wire.kalium.logic.data.user.ConnectionState.SENT import com.wire.kalium.logic.data.user.UserId -import com.wire.kalium.logic.data.user.UserMapper -import com.wire.kalium.logic.data.user.type.UserEntityTypeMapper import com.wire.kalium.logic.di.MapperProvider import com.wire.kalium.logic.failure.InvalidMappingFailure -import com.wire.kalium.logic.feature.SelfTeamIdProvider import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMap -import com.wire.kalium.logic.functional.fold import com.wire.kalium.logic.functional.isRight import com.wire.kalium.logic.functional.map import com.wire.kalium.logic.functional.onFailure @@ -57,7 +53,6 @@ import com.wire.kalium.logic.wrapStorageRequest import com.wire.kalium.network.api.base.authenticated.connection.ConnectionApi import com.wire.kalium.network.api.base.authenticated.connection.ConnectionDTO import com.wire.kalium.network.api.base.authenticated.connection.ConnectionStateDTO -import com.wire.kalium.network.api.base.authenticated.userDetails.UserDetailsApi import com.wire.kalium.persistence.dao.ConnectionDAO import com.wire.kalium.persistence.dao.UserDAO import com.wire.kalium.persistence.dao.conversation.ConversationDAO @@ -88,15 +83,10 @@ internal class ConnectionDataSource( private val memberDAO: MemberDAO, private val connectionDAO: ConnectionDAO, private val connectionApi: ConnectionApi, - private val userDetailsApi: UserDetailsApi, private val userDAO: UserDAO, - private val selfUserId: UserId, - private val selfTeamIdProvider: SelfTeamIdProvider, private val conversationRepository: ConversationRepository, private val connectionStatusMapper: ConnectionStatusMapper = MapperProvider.connectionStatusMapper(), - private val connectionMapper: ConnectionMapper = MapperProvider.connectionMapper(), - private val userMapper: UserMapper = MapperProvider.userMapper(), - private val userTypeEntityTypeMapper: UserEntityTypeMapper = MapperProvider.userTypeEntityMapper() + private val connectionMapper: ConnectionMapper = MapperProvider.connectionMapper() ) : ConnectionRepository { override suspend fun fetchSelfUserConnections(): Either { @@ -148,7 +138,6 @@ internal class ConnectionDataSource( val connectionStatus = connectionDTO.copy(status = newConnectionStatus) val connectionModel = connectionMapper.fromApiToModel(connectionDTO) handleUserConnectionStatusPersistence(connectionMapper.fromApiToModel(connectionStatus)) - persistConnection(connectionModel) connectionModel } } @@ -193,7 +182,7 @@ internal class ConnectionDataSource( } override suspend fun insertConnectionFromEvent(event: Event.User.NewConnection): Either = - persistConnection(event.connection) + handleUserConnectionStatusPersistence(event.connection) override suspend fun observeConnectionList(): Flow> { return connectionDAO.getConnections().map { connections -> @@ -203,36 +192,16 @@ internal class ConnectionDataSource( } } - // TODO: Vitor : Instead of duplicating, we could pass selfUser.teamId from the UseCases to this function. - // This way, the UseCases can tie the different Repos together, calling these functions. private suspend fun persistConnection(connection: Connection) = - selfTeamIdProvider().flatMap { teamId -> - // This can fail, but the connection will be there and get synced in worst case scenario in next SlowSync - wrapApiRequest { - userDetailsApi.getUserInfo(connection.qualifiedToId.toApi()) - }.fold({ - wrapStorageRequest { - connectionDAO.insertConnection(connectionMapper.modelToDao(connection)) - } - }, { userProfileDTO -> - wrapStorageRequest { - val userEntity = userMapper.fromUserProfileDtoToUserEntity( - userProfile = userProfileDTO, - connectionState = connectionStatusMapper.toDaoModel(state = connection.status), - userTypeEntity = userTypeEntityTypeMapper.fromTeamAndDomain( - otherUserDomain = userProfileDTO.id.domain, - selfUserTeamId = teamId?.value, - otherUserTeamId = userProfileDTO.teamId, - selfUserDomain = selfUserId.domain, - isService = userProfileDTO.service != null - ) - ) - insertConversationFromConnection(connection) - // should we insert first user before creating conversation ? - userDAO.insertUser(userEntity) - connectionDAO.insertConnection(connectionMapper.modelToDao(connection)) - } - }) + wrapStorageRequest { + val connectionStatus = connectionStatusMapper.toDaoModel(state = connection.status) + userDAO.upsertConnectionStatus(connection.qualifiedToId.toDao(), connectionStatus) + + insertConversationFromConnection(connection) + + if (connection.status != ACCEPTED) { + connectionDAO.insertConnection(connectionMapper.modelToDao(connection)) + } } private suspend fun insertConversationFromConnection(connection: Connection) { @@ -265,10 +234,9 @@ internal class ConnectionDataSource( } ACCEPTED -> { - memberDAO.updateOrInsertOneOnOneMemberWithConnectionStatus( + memberDAO.updateOrInsertOneOnOneMember( member = MemberEntity(user = connection.qualifiedToId.toDao(), MemberEntity.Role.Member), - conversationID = connection.qualifiedConversationId.toDao(), - status = connectionStatusMapper.toDaoModel(connection.status) + conversationID = connection.qualifiedConversationId.toDao() ) } @@ -283,26 +251,13 @@ internal class ConnectionDataSource( connectionDAO.deleteConnectionDataAndConversation(conversationId.toDao()) } - private suspend fun updateConversationMemberFromConnection(connection: Connection) = - wrapStorageRequest { - memberDAO.updateOrInsertOneOnOneMemberWithConnectionStatus( - // TODO(IMPORTANT!!!!!!): setting a default value for member role is incorrect and can lead to unexpected behaviour - member = MemberEntity(user = connection.qualifiedToId.toDao(), MemberEntity.Role.Member), - status = connectionStatusMapper.toDaoModel(connection.status), - conversationID = connection.qualifiedConversationId.toDao() - ) - }.onFailure { - kaliumLogger.e("There was an error when trying to persist the connection: $connection") - } - /** * This will update the connection status on user table and will insert members only * if the [ConnectionDTO.status] is other than [ConnectionStateDTO.PENDING] or [ConnectionStateDTO.SENT] */ private suspend fun handleUserConnectionStatusPersistence(connection: Connection): Either = when (connection.status) { - MISSING_LEGALHOLD_CONSENT, NOT_CONNECTED, PENDING, SENT, BLOCKED, IGNORED -> persistConnection(connection) + ACCEPTED, MISSING_LEGALHOLD_CONSENT, NOT_CONNECTED, PENDING, SENT, BLOCKED, IGNORED -> persistConnection(connection) CANCELLED -> deleteConnection(connection.qualifiedConversationId) - ACCEPTED -> updateConversationMemberFromConnection(connection) } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/Conversation.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/Conversation.kt index c0a79b5eab3..a3160d3b713 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/Conversation.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/Conversation.kt @@ -176,6 +176,12 @@ data class Conversation( CODE; } + enum class Protocol { + PROTEUS, + MIXED, + MLS + } + enum class ReceiptMode { DISABLED, ENABLED } enum class TypingIndicatorMode { STARTED, STOPPED } @@ -190,7 +196,8 @@ data class Conversation( MLS_256_DHKEMX448_AES256GCM_SHA512_Ed448(4), MLS_256_DHKEMP521_AES256GCM_SHA512_P521(5), MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448(6), - MLS_256_DHKEMP384_AES256GCM_SHA384_P384(7); + MLS_256_DHKEMP384_AES256GCM_SHA384_P384(7), + MLS_128_X25519KYBER768DRAFT00_AES128GCM_SHA256_ED25519(61489); companion object { fun fromTag(tag: Int): CipherSuite = values().first { type -> type.tag == tag } @@ -200,28 +207,44 @@ data class Conversation( val supportsUnreadMessageCount get() = type in setOf(Type.ONE_ON_ONE, Type.GROUP) - sealed class ProtocolInfo { - object Proteus : ProtocolInfo() { + sealed interface ProtocolInfo { + object Proteus : ProtocolInfo { override fun name() = "Proteus" } data class MLS( - val groupId: GroupID, - val groupState: GroupState, - val epoch: ULong, - val keyingMaterialLastUpdate: Instant, + override val groupId: GroupID, + override val groupState: MLSCapable.GroupState, + override val epoch: ULong, + override val keyingMaterialLastUpdate: Instant, + override val cipherSuite: CipherSuite + ) : MLSCapable { + override fun name() = "MLS" + } + + data class Mixed( + override val groupId: GroupID, + override val groupState: MLSCapable.GroupState, + override val epoch: ULong, + override val keyingMaterialLastUpdate: Instant, + override val cipherSuite: CipherSuite + ) : MLSCapable { + override fun name() = "Mixed" + } + + sealed interface MLSCapable : ProtocolInfo { + val groupId: GroupID + val groupState: GroupState + val epoch: ULong + val keyingMaterialLastUpdate: Instant val cipherSuite: CipherSuite - ) : ProtocolInfo() { - enum class GroupState { PENDING_CREATION, PENDING_JOIN, PENDING_WELCOME_MESSAGE, ESTABLISHED } - override fun name() = "MLS" + enum class GroupState { PENDING_CREATION, PENDING_JOIN, PENDING_WELCOME_MESSAGE, ESTABLISHED } } - abstract fun name(): String + fun name(): String } - enum class Protocol { PROTEUS, MLS } - data class Member(val id: UserId, val role: Role) { sealed class Role { object Member : Role() diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepository.kt index fcdc7a86cee..a1f16094bad 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepository.kt @@ -98,8 +98,8 @@ internal class ConversationGroupRepositoryImpl( private val newGroupConversationSystemMessagesCreator: Lazy, private val selfUserId: UserId, private val teamIdProvider: SelfTeamIdProvider, - private val conversationMapper: ConversationMapper = MapperProvider.conversationMapper(), - private val eventMapper: EventMapper = MapperProvider.eventMapper(), + private val conversationMapper: ConversationMapper = MapperProvider.conversationMapper(selfUserId), + private val eventMapper: EventMapper = MapperProvider.eventMapper(selfUserId), private val protocolInfoMapper: ProtocolInfoMapper = MapperProvider.protocolInfoMapper(), ) : ConversationGroupRepository { @@ -150,7 +150,7 @@ internal class ConversationGroupRepositoryImpl( ).flatMap { when (protocol) { is Conversation.ProtocolInfo.Proteus -> Either.Right(Unit) - is Conversation.ProtocolInfo.MLS -> mlsConversationRepository.establishMLSGroup( + is Conversation.ProtocolInfo.MLSCapable -> mlsConversationRepository.establishMLSGroup( groupID = protocol.groupId, members = usersList + selfUserId ) @@ -177,6 +177,12 @@ internal class ConversationGroupRepositoryImpl( is ConversationEntity.ProtocolInfo.Proteus -> tryAddMembersToCloudAndStorage(userIdList, conversationId) + is ConversationEntity.ProtocolInfo.Mixed -> + tryAddMembersToCloudAndStorage(userIdList, conversationId) + .flatMap { + mlsConversationRepository.addMemberToMLSGroup(GroupID(protocol.groupId), userIdList) + } + is ConversationEntity.ProtocolInfo.MLS -> { mlsConversationRepository.addMemberToMLSGroup(GroupID(protocol.groupId), userIdList) } @@ -187,7 +193,7 @@ internal class ConversationGroupRepositoryImpl( wrapStorageRequest { conversationDAO.getConversationProtocolInfo(conversationId.toDao()) } .flatMap { protocol -> when (protocol) { - is ConversationEntity.ProtocolInfo.Proteus -> { + is ConversationEntity.ProtocolInfo.Proteus, is ConversationEntity.ProtocolInfo.Mixed -> { wrapApiRequest { conversationApi.addService( AddServiceRequest(id = serviceId.id, provider = serviceId.provider), @@ -199,7 +205,8 @@ internal class ConversationGroupRepositoryImpl( eventMapper.conversationMemberJoin( LocalId.generate(), response.event, - true + true, + false ) ) } @@ -238,7 +245,7 @@ internal class ConversationGroupRepositoryImpl( conversationId: ConversationId ) = if (apiResult.value is ConversationMemberAddedResponse.Changed) { memberJoinEventHandler.handle( - eventMapper.conversationMemberJoin(LocalId.generate(), apiResult.value.event, true) + eventMapper.conversationMemberJoin(LocalId.generate(), apiResult.value.event, true, false) ).flatMap { if (failedUsersList.isNotEmpty()) { newGroupConversationSystemMessagesCreator.value.conversationFailedToAddMembers(conversationId, failedUsersList) @@ -292,18 +299,12 @@ internal class ConversationGroupRepositoryImpl( is ConversationEntity.ProtocolInfo.Proteus -> deleteMemberFromCloudAndStorage(userId, conversationId) + is ConversationEntity.ProtocolInfo.Mixed -> + deleteMemberFromCloudAndStorage(userId, conversationId) + .flatMap { deleteMemberFromMlsGroup(userId, conversationId, protocol) } + is ConversationEntity.ProtocolInfo.MLS -> { - if (userId == selfUserId) { - deleteMemberFromCloudAndStorage(userId, conversationId).flatMap { - mlsConversationRepository.leaveGroup(GroupID(protocol.groupId)) - } - } else { - // when removing a member from an MLS group, don't need to call the api - mlsConversationRepository.removeMembersFromMLSGroup( - GroupID(protocol.groupId), - listOf(userId) - ) - } + deleteMemberFromMlsGroup(userId, conversationId, protocol) } } } @@ -319,17 +320,17 @@ internal class ConversationGroupRepositoryImpl( if (response is ConversationMemberAddedResponse.Changed) { val conversationId = response.event.qualifiedConversation.toModel() - memberJoinEventHandler.handle(eventMapper.conversationMemberJoin(LocalId.generate(), response.event, true)) + memberJoinEventHandler.handle(eventMapper.conversationMemberJoin(LocalId.generate(), response.event, true, false)) .flatMap { wrapStorageRequest { conversationDAO.getConversationProtocolInfo(conversationId.toDao()) } - .flatMap { - when (it) { + .flatMap { protocol -> + when (protocol) { is ConversationEntity.ProtocolInfo.Proteus -> Either.Right(Unit) - is ConversationEntity.ProtocolInfo.MLS -> { + is ConversationEntity.ProtocolInfo.MLSCapable -> { joinExistingMLSConversation(conversationId).flatMap { - addMembers(listOf(selfUserId), conversationId) + mlsConversationRepository.addMemberToMLSGroup(GroupID(protocol.groupId), listOf(selfUserId)) } } } @@ -344,6 +345,20 @@ internal class ConversationGroupRepositoryImpl( ): Either = wrapApiRequest { conversationApi.fetchLimitedInformationViaCode(code, key) } + private suspend fun deleteMemberFromMlsGroup( + userId: UserId, + conversationId: ConversationId, + protocol: ConversationEntity.ProtocolInfo.MLSCapable + ) = + if (userId == selfUserId) { + deleteMemberFromCloudAndStorage(userId, conversationId).flatMap { + mlsConversationRepository.leaveGroup(GroupID(protocol.groupId)) + } + } else { + // when removing a member from an MLS group, don't need to call the api + mlsConversationRepository.removeMembersFromMLSGroup(GroupID(protocol.groupId), listOf(userId)) + } + private suspend fun deleteMemberFromCloudAndStorage(userId: UserId, conversationId: ConversationId) = wrapApiRequest { conversationApi.removeMember(userId.toApi(), conversationId.toApi()) @@ -353,6 +368,7 @@ internal class ConversationGroupRepositoryImpl( eventMapper.conversationMemberLeave( LocalId.generate(), response.event, + false, false ) ) @@ -392,7 +408,8 @@ internal class ConversationGroupRepositoryImpl( eventMapper.conversationMessageTimerUpdate( LocalId.generate(), it, - true + true, + false ) ) } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationMapper.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationMapper.kt index 0b750ce4766..e59dde28bd7 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationMapper.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationMapper.kt @@ -31,6 +31,7 @@ import com.wire.kalium.logic.data.user.BotService import com.wire.kalium.logic.data.user.Connection import com.wire.kalium.logic.data.user.OtherUser import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.data.user.toModel import com.wire.kalium.logic.data.user.type.DomainUserTypeMapper import com.wire.kalium.logic.di.MapperProvider import com.wire.kalium.network.api.base.authenticated.conversation.ConvProtocol @@ -58,9 +59,6 @@ import kotlin.time.toDuration @Suppress("TooManyFunctions") interface ConversationMapper { fun fromApiModelToDaoModel(apiModel: ConversationResponse, mlsGroupState: GroupState?, selfUserTeamId: TeamId?): ConversationEntity - fun fromApiModelToDaoModel(apiModel: ConvProtocol): Protocol - fun fromDaoModel(daoProtocol: Protocol?): Conversation.Protocol? - fun toDaoModel(protocol: Conversation.Protocol?): Protocol? fun fromDaoModel(daoModel: ConversationViewEntity): Conversation fun fromDaoModel(daoModel: ConversationEntity): Conversation fun fromDaoModelToDetails( @@ -72,7 +70,7 @@ interface ConversationMapper { fun fromDaoModel(daoModel: ProposalTimerEntity): ProposalTimer fun toDAOAccess(accessList: Set): List fun toDAOAccessRole(accessRoleList: Set): List - fun toDAOGroupState(groupState: Conversation.ProtocolInfo.MLS.GroupState): GroupState + fun toDAOGroupState(groupState: Conversation.ProtocolInfo.MLSCapable.GroupState): GroupState fun toDAOProposalTimer(proposalTimer: ProposalTimer): ProposalTimerEntity fun toApiModel(access: Conversation.Access): ConversationAccessDTO fun toApiModel(accessRole: Conversation.AccessRole): ConversationAccessRoleDTO @@ -87,6 +85,7 @@ interface ConversationMapper { @Suppress("TooManyFunctions", "LongParameterList") internal class ConversationMapperImpl( + private val selfUserId: UserId, private val idMapper: IdMapper, private val conversationStatusMapper: ConversationStatusMapper, private val protocolInfoMapper: ProtocolInfoMapper, @@ -110,7 +109,7 @@ internal class ConversationMapperImpl( mutedStatus = conversationStatusMapper.fromMutedStatusApiToDaoModel(apiModel.members.self.otrMutedStatus), mutedTime = apiModel.members.self.otrMutedRef?.let { Instant.parse(it) }?.toEpochMilliseconds() ?: 0, removedBy = null, - creatorId = apiModel.creator, + creatorId = apiModel.creator ?: selfUserId.value, // NOTE mls 1-1 does not have the creator field set. lastReadDate = Instant.UNIX_FIRST_DATE, lastNotificationDate = null, lastModifiedDate = apiModel.lastEventTime.toInstant(), @@ -125,23 +124,6 @@ internal class ConversationMapperImpl( verificationStatus = ConversationEntity.VerificationStatus.NOT_VERIFIED ) - override fun fromApiModelToDaoModel(apiModel: ConvProtocol): Protocol = when (apiModel) { - ConvProtocol.PROTEUS -> Protocol.PROTEUS - ConvProtocol.MLS -> Protocol.MLS - } - - override fun fromDaoModel(daoProtocol: Protocol?): Conversation.Protocol? = when (daoProtocol) { - Protocol.PROTEUS -> Conversation.Protocol.PROTEUS - Protocol.MLS -> Conversation.Protocol.MLS - null -> null - } - - override fun toDaoModel(protocol: Conversation.Protocol?): Protocol? = when (protocol) { - Conversation.Protocol.PROTEUS -> Protocol.PROTEUS - Conversation.Protocol.MLS -> Protocol.MLS - null -> null - } - override fun fromDaoModel(daoModel: ConversationViewEntity): Conversation = with(daoModel) { val lastReadDateEntity = if (type == ConversationEntity.Type.CONNECTION_PENDING) UNIX_FIRST_DATE else lastReadDate.toIsoDateTimeString() @@ -225,7 +207,9 @@ internal class ConversationMapperImpl( connectionStatus = connectionStatusMapper.fromDaoModel(connectionStatus), expiresAt = null, defederated = userDefederated ?: false, - isProteusVerified = false + isProteusVerified = false, + supportedProtocols = userSupportedProtocols?.map { it.toModel() }?.toSet(), + activeOneOnOneConversationId = userActiveOneOnOneConversationId?.toModel() ), legalHoldStatus = LegalHoldStatus.DISABLED, userType = domainUserTypeMapper.fromUserTypeEntity(userType), @@ -262,7 +246,8 @@ internal class ConversationMapperImpl( teamId = teamId?.let { TeamId(it) }, expiresAt = null, defederated = userDefederated ?: false, - isProteusVerified = false + isProteusVerified = false, + supportedProtocols = userSupportedProtocols?.map { it.toModel() }?.toSet() ) ConversationDetails.Connection( @@ -312,12 +297,12 @@ internal class ConversationMapperImpl( } } - override fun toDAOGroupState(groupState: Conversation.ProtocolInfo.MLS.GroupState): GroupState = + override fun toDAOGroupState(groupState: Conversation.ProtocolInfo.MLSCapable.GroupState): GroupState = when (groupState) { - Conversation.ProtocolInfo.MLS.GroupState.ESTABLISHED -> GroupState.ESTABLISHED - Conversation.ProtocolInfo.MLS.GroupState.PENDING_JOIN -> GroupState.PENDING_JOIN - Conversation.ProtocolInfo.MLS.GroupState.PENDING_WELCOME_MESSAGE -> GroupState.PENDING_WELCOME_MESSAGE - Conversation.ProtocolInfo.MLS.GroupState.PENDING_CREATION -> GroupState.PENDING_CREATION + Conversation.ProtocolInfo.MLSCapable.GroupState.ESTABLISHED -> GroupState.ESTABLISHED + Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_JOIN -> GroupState.PENDING_JOIN + Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_WELCOME_MESSAGE -> GroupState.PENDING_WELCOME_MESSAGE + Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_CREATION -> GroupState.PENDING_CREATION } override fun toDAOProposalTimer(proposalTimer: ProposalTimer): ProposalTimerEntity = @@ -428,6 +413,14 @@ internal class ConversationMapperImpl( ConversationEntity.CipherSuite.fromTag(mlsCipherSuiteTag) ) + ConvProtocol.MIXED -> ProtocolInfo.Mixed( + groupId ?: "", + mlsGroupState ?: GroupState.PENDING_JOIN, + epoch ?: 0UL, + keyingMaterialLastUpdate = DateTimeUtil.currentInstant(), + ConversationEntity.CipherSuite.fromTag(mlsCipherSuiteTag) + ) + ConvProtocol.PROTEUS -> ProtocolInfo.Proteus } } @@ -499,7 +492,14 @@ private fun ConversationEntity.AccessRole.toDAO(): Conversation.AccessRole = whe ConversationEntity.AccessRole.EXTERNAL -> Conversation.AccessRole.EXTERNAL } -private fun Conversation.Type.toDAO(): ConversationEntity.Type = when (this) { +internal fun Conversation.ProtocolInfo.MLSCapable.GroupState.toDao(): ConversationEntity.GroupState = when (this) { + Conversation.ProtocolInfo.MLSCapable.GroupState.ESTABLISHED -> GroupState.ESTABLISHED + Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_CREATION -> GroupState.PENDING_CREATION + Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_JOIN -> GroupState.PENDING_JOIN + Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_WELCOME_MESSAGE -> GroupState.PENDING_WELCOME_MESSAGE +} + +internal fun Conversation.Type.toDAO(): ConversationEntity.Type = when (this) { Conversation.Type.SELF -> ConversationEntity.Type.SELF Conversation.Type.ONE_ON_ONE -> ConversationEntity.Type.ONE_ON_ONE Conversation.Type.GROUP -> ConversationEntity.Type.GROUP @@ -521,3 +521,27 @@ private fun Conversation.Access.toDAO(): ConversationEntity.Access = when (this) Conversation.Access.LINK -> ConversationEntity.Access.LINK Conversation.Access.CODE -> ConversationEntity.Access.CODE } + +internal fun Conversation.Protocol.toApi(): ConvProtocol = when (this) { + Conversation.Protocol.PROTEUS -> ConvProtocol.PROTEUS + Conversation.Protocol.MIXED -> ConvProtocol.MIXED + Conversation.Protocol.MLS -> ConvProtocol.MLS +} + +internal fun Conversation.Protocol.toDao(): Protocol = when (this) { + Conversation.Protocol.PROTEUS -> Protocol.PROTEUS + Conversation.Protocol.MIXED -> Protocol.MIXED + Conversation.Protocol.MLS -> Protocol.MLS +} + +internal fun ConvProtocol.toModel(): Conversation.Protocol = when (this) { + ConvProtocol.PROTEUS -> Conversation.Protocol.PROTEUS + ConvProtocol.MIXED -> Conversation.Protocol.MIXED + ConvProtocol.MLS -> Conversation.Protocol.MLS +} + +internal fun Protocol.toModel(): Conversation.Protocol = when (this) { + Protocol.PROTEUS -> Conversation.Protocol.PROTEUS + Protocol.MIXED -> Conversation.Protocol.MIXED + Protocol.MLS -> Conversation.Protocol.MLS +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt index b5550138239..b074efa2189 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt @@ -22,12 +22,14 @@ import com.wire.kalium.logger.KaliumLogger.Companion.ApplicationFlow.CONVERSATIO import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.NetworkFailure import com.wire.kalium.logic.StorageFailure +import com.wire.kalium.logic.data.conversation.Conversation.ProtocolInfo.MLSCapable.GroupState import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.id.GroupID import com.wire.kalium.logic.data.id.IdMapper import com.wire.kalium.logic.data.id.NetworkQualifiedId import com.wire.kalium.logic.data.id.QualifiedID +import com.wire.kalium.logic.data.id.TeamId import com.wire.kalium.logic.data.id.toApi import com.wire.kalium.logic.data.id.toCrypto import com.wire.kalium.logic.data.id.toDao @@ -55,10 +57,12 @@ import com.wire.kalium.logic.wrapMLSRequest import com.wire.kalium.logic.wrapStorageRequest import com.wire.kalium.network.api.base.authenticated.client.ClientApi import com.wire.kalium.network.api.base.authenticated.conversation.ConversationApi +import com.wire.kalium.network.api.base.authenticated.conversation.ConversationMemberDTO import com.wire.kalium.network.api.base.authenticated.conversation.ConversationRenameResponse import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponse import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationAccessRequest import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationAccessResponse +import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationProtocolResponse import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationReceiptModeResponse import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationMemberRoleDTO import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationReceiptModeDTO @@ -114,6 +118,14 @@ interface ConversationRepository { suspend fun getConversationList(): Either>> suspend fun observeConversationList(): Flow> suspend fun observeConversationListDetails(fromArchive: Boolean): Flow> + suspend fun getConversationIds( + type: Conversation.Type, + protocol: Conversation.Protocol, + teamId: TeamId? = null + ): Either> + + suspend fun fetchMlsOneToOneConversation(userId: UserId): Either + suspend fun getTeamConversationIdsReadyToCompleteMigration(teamId: TeamId): Either> suspend fun observeConversationDetailsById(conversationID: ConversationId): Flow> suspend fun fetchConversation(conversationID: ConversationId): Either suspend fun fetchSentConnectionConversation(conversationID: ConversationId): Either @@ -145,6 +157,11 @@ interface ConversationRepository { suspend fun deleteMembersFromEvent(userIDList: List, conversationID: ConversationId): Either suspend fun observeOneToOneConversationWithOtherUser(otherUserId: UserId): Flow> + suspend fun getOneOnOneConversationsWithOtherUser( + otherUserId: UserId, + protocol: Conversation.Protocol + ): Either> + suspend fun updateMutedStatusLocally( conversationId: ConversationId, mutedStatus: MutedConversationStatus, @@ -170,9 +187,10 @@ interface ConversationRepository { ): Either suspend fun getConversationsByGroupState( - groupState: Conversation.ProtocolInfo.MLS.GroupState + groupState: GroupState ): Either> + suspend fun updateConversationGroupState(groupID: GroupID, groupState: GroupState): Either suspend fun updateConversationNotificationDate(qualifiedID: QualifiedID): Either suspend fun updateAllConversationsNotificationDate(): Either suspend fun updateConversationModifiedDate(qualifiedID: QualifiedID, date: Instant): Either @@ -200,7 +218,7 @@ interface ConversationRepository { suspend fun deleteUserFromConversations(userId: UserId): Either - suspend fun getConversationIdsByUserId(userId: UserId): Either> + suspend fun getConversationsByUserId(userId: UserId): Either> suspend fun insertConversations(conversations: List): Either suspend fun changeConversationName( conversationId: ConversationId, @@ -238,13 +256,34 @@ interface ConversationRepository { suspend fun getConversationDetailsByMLSGroupId(mlsGroupId: GroupID): Either suspend fun observeUnreadArchivedConversationsCount(): Flow + suspend fun sendTypingIndicatorStatus( conversationId: ConversationId, typingStatus: Conversation.TypingIndicatorMode ): Either + + /** + * Update a conversation's protocol remotely. + * + * This also fetches the newly assigned `groupID` from the backend, if this operation fails the whole + * operation is cancelled and protocol change is not persisted. + * + * @return **true** if the protocol was changed or **false** if the protocol was unchanged. + */ + suspend fun updateProtocolRemotely(conversationId: ConversationId, protocol: Conversation.Protocol): Either + + /** + * Update a conversation's protocol locally. + * + * This also fetches the newly assigned `groupID` from the backend, if this operation fails the whole + * operation is cancelled and protocol change is not persisted. + * + * @return **true** if the protocol was changed or **false** if the protocol was unchanged. + */ + suspend fun updateProtocolLocally(conversationId: ConversationId, protocol: Conversation.Protocol): Either } -@Suppress("LongParameterList", "TooManyFunctions") +@Suppress("LongParameterList", "TooManyFunctions", "LargeClass") internal class ConversationDataSource internal constructor( private val selfUserId: UserId, private val mlsClientProvider: MLSClientProvider, @@ -257,7 +296,7 @@ internal class ConversationDataSource internal constructor( private val clientApi: ClientApi, private val conversationMetaDataDAO: ConversationMetaDataDAO, private val idMapper: IdMapper = MapperProvider.idMapper(), - private val conversationMapper: ConversationMapper = MapperProvider.conversationMapper(), + private val conversationMapper: ConversationMapper = MapperProvider.conversationMapper(selfUserId), private val memberMapper: MemberMapper = MapperProvider.memberMapper(), private val conversationStatusMapper: ConversationStatusMapper = MapperProvider.conversationStatusMapper(), private val conversationRoleMapper: ConversationRoleMapper = MapperProvider.conversationRoleMapper(), @@ -438,6 +477,60 @@ internal class ConversationDataSource internal constructor( } } + override suspend fun fetchMlsOneToOneConversation(userId: UserId): Either = + wrapApiRequest { + conversationApi.fetchMlsOneToOneConversation(userId.toApi()) + }.map { conversationResponse -> + addOtherMemberIfMissing(conversationResponse, userId) + }.flatMap { conversationResponse -> + val selfUserTeamId = selfTeamIdProvider().getOrNull() + persistConversations( + conversations = listOf(conversationResponse), + selfUserTeamId = selfUserTeamId?.value + ).map { conversationResponse } + }.flatMap { response -> + baseInfoById(response.id.toModel()) + } + + private fun addOtherMemberIfMissing( + conversationResponse: ConversationResponse, + otherMemberId: UserId + ): ConversationResponse { + val currentOtherMembers = conversationResponse.members.otherMembers + val hasOtherUser = currentOtherMembers.any { it.id == otherMemberId.toApi() } + val otherMembers = if (hasOtherUser) { + currentOtherMembers + } else { + listOf( + ConversationMemberDTO.Other( + id = otherMemberId.toApi(), + conversationRole = "", + service = null + ) + ) + } + return conversationResponse.copy( + members = conversationResponse.members.copy( + otherMembers = otherMembers + ) + ) + } + + override suspend fun getConversationIds( + type: Conversation.Type, + protocol: Conversation.Protocol, + teamId: TeamId? + ): Either> = + wrapStorageRequest { + conversationDAO.getConversationIds(type.toDAO(), protocol.toDao(), teamId?.value) + .map { it.toModel() } + } + override suspend fun getTeamConversationIdsReadyToCompleteMigration(teamId: TeamId): Either> = + wrapStorageRequest { + conversationDAO.getTeamConversationIdsReadyToCompleteMigration(teamId.value) + .map { it.toModel() } + } + /** * Gets a flow that allows observing of */ @@ -548,13 +641,21 @@ internal class ConversationDataSource internal constructor( } override suspend fun getConversationsByGroupState( - groupState: Conversation.ProtocolInfo.MLS.GroupState + groupState: GroupState ): Either> = wrapStorageRequest { conversationDAO.getConversationsByGroupState(conversationMapper.toDAOGroupState(groupState)) .map(conversationMapper::fromDaoModel) } + override suspend fun updateConversationGroupState( + groupID: GroupID, + groupState: GroupState + ): Either = + wrapStorageRequest { + conversationDAO.updateConversationGroupState(groupState.toDao(), groupID.value) + } + override suspend fun updateConversationNotificationDate( qualifiedID: QualifiedID ): Either = @@ -631,12 +732,21 @@ internal class ConversationDataSource internal constructor( wrapApiRequest { clientApi.listClientsOfUsers(it) }.map { memberMapper.fromMapOfClientsResponseToRecipients(it) } } - override suspend fun observeOneToOneConversationWithOtherUser(otherUserId: UserId): Flow> { - return conversationDAO.observeConversationWithOtherUser(otherUserId.toDao()) + override suspend fun observeOneToOneConversationWithOtherUser( + otherUserId: UserId + ): Flow> { + return conversationDAO.observeOneOnOneConversationWithOtherUser(otherUserId.toDao()) .wrapStorageRequest() .mapRight { conversationMapper.fromDaoModel(it) } } + override suspend fun getOneOnOneConversationsWithOtherUser( + otherUserId: UserId, + protocol: Conversation.Protocol + ): Either> = wrapStorageRequest { + conversationDAO.getOneOnOneConversationIdsWithOtherUser(otherUserId.toDao(), protocol.toDao()).map { it.toModel() } + } + override suspend fun updateMutedStatusLocally( conversationId: ConversationId, mutedStatus: MutedConversationStatus, @@ -709,7 +819,7 @@ internal class ConversationDataSource internal constructor( override suspend fun deleteConversation(conversationId: ConversationId) = getConversationProtocolInfo(conversationId).flatMap { when (it) { - is Conversation.ProtocolInfo.MLS -> + is Conversation.ProtocolInfo.MLSCapable -> mlsClientProvider.getMLSClient().flatMap { mlsClient -> wrapMLSRequest { mlsClient.wipeConversation(it.groupId.toCrypto()) @@ -746,9 +856,9 @@ internal class ConversationDataSource internal constructor( conversationDAO.revokeOneOnOneConversationsWithDeletedUser(userId.toDao()) } - override suspend fun getConversationIdsByUserId(userId: UserId): Either> { - return wrapStorageRequest { conversationDAO.getConversationIdsByUserId(userId.toDao()) } - .map { it.map { conversationIdEntity -> conversationIdEntity.toModel() } } + override suspend fun getConversationsByUserId(userId: UserId): Either> { + return wrapStorageRequest { conversationDAO.getConversationsByUserId(userId.toDao()) } + .map { it.map { entity -> conversationMapper.fromDaoModel(entity) } } } override suspend fun insertConversations(conversations: List): Either { @@ -892,6 +1002,49 @@ internal class ConversationDataSource internal constructor( } } + override suspend fun updateProtocolRemotely( + conversationId: ConversationId, + protocol: Conversation.Protocol + ): Either = + wrapApiRequest { + conversationApi.updateProtocol(conversationId.toApi(), protocol.toApi()) + }.flatMap { response -> + when (response) { + UpdateConversationProtocolResponse.ProtocolUnchanged -> { + // no need to update conversation + Either.Right(false) + } + + is UpdateConversationProtocolResponse.ProtocolUpdated -> { + updateProtocolLocally(conversationId, protocol) + } + } + } + + override suspend fun updateProtocolLocally( + conversationId: ConversationId, + protocol: Conversation.Protocol + ): Either = + wrapApiRequest { + conversationApi.fetchConversationDetails(conversationId.toApi()) + }.flatMap { conversationResponse -> + wrapStorageRequest { + conversationDAO.updateConversationProtocol( + conversationId = conversationId.toDao(), + protocol = protocol.toDao() + ) + }.flatMap { updated -> + if (updated) { + val selfUserTeamId = selfTeamIdProvider().getOrNull() + persistConversations(listOf(conversationResponse), selfUserTeamId?.value, invalidateMembers = true) + } else { + Either.Right(Unit) + }.map { + updated + } + } + } + companion object { const val DEFAULT_MEMBER_ROLE = "wire_member" } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt index e09c00307fe..56a74c31064 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt @@ -23,6 +23,7 @@ import com.wire.kalium.cryptography.CryptoQualifiedClientId import com.wire.kalium.cryptography.CryptoQualifiedID import com.wire.kalium.logger.obfuscateId import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.MLSFailure import com.wire.kalium.logic.NetworkFailure import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.event.Event @@ -115,12 +116,20 @@ private enum class CommitStrategy { ABORT } -private fun CoreFailure.getStrategy(retryOnClientMismatch: Boolean = true): CommitStrategy { - return if (this is NetworkFailure.ServerMiscommunication && this.kaliumException is KaliumException.InvalidRequestError) { +private fun CoreFailure.getStrategy( + remainingAttempts: Int, + retryOnClientMismatch: Boolean = true, + retryOnStaleMessage: Boolean = true +): CommitStrategy { + return if ( + remainingAttempts > 0 && + this is NetworkFailure.ServerMiscommunication && + kaliumException is KaliumException.InvalidRequestError + ) { if (this.kaliumException.isMlsClientMismatch() && retryOnClientMismatch) { CommitStrategy.DISCARD_AND_RETRY } else if ( - this.kaliumException.isMlsStaleMessage() || + this.kaliumException.isMlsStaleMessage() && retryOnStaleMessage || this.kaliumException.isMlsCommitMissingReferences() ) { CommitStrategy.KEEP_AND_RETRY @@ -134,6 +143,7 @@ private fun CoreFailure.getStrategy(retryOnClientMismatch: Boolean = true): Comm @Suppress("TooManyFunctions", "LongParameterList") internal class MLSConversationDataSource( + private val selfUserId: UserId, private val keyPackageRepository: KeyPackageRepository, private val mlsClientProvider: MLSClientProvider, private val mlsMessageApi: MLSMessageApi, @@ -145,7 +155,7 @@ internal class MLSConversationDataSource( private val epochsFlow: MutableSharedFlow, private val proposalTimersFlow: MutableSharedFlow, private val idMapper: IdMapper = MapperProvider.idMapper(), - private val conversationMapper: ConversationMapper = MapperProvider.conversationMapper(), + private val conversationMapper: ConversationMapper = MapperProvider.conversationMapper(selfUserId), private val mlsPublicKeysMapper: MLSPublicKeysMapper = MapperProvider.mlsPublicKeyMapper(), private val mlsCommitBundleMapper: MLSCommitBundleMapper = MapperProvider.mlsCommitBundleMapper(), kaliumDispatcher: KaliumDispatcher = KaliumDispatcherImpl @@ -251,7 +261,6 @@ internal class MLSConversationDataSource( idMapper.toCryptoModel(groupID) ) } - } } @@ -323,7 +332,7 @@ internal class MLSConversationDataSource( private suspend fun processCommitBundleEvents(events: List) { events.forEach { eventContentDTO -> - val event = MapperProvider.eventMapper().fromEventContentDTO("", eventContentDTO, true) + val event = MapperProvider.eventMapper(selfUserId).fromEventContentDTO("", eventContentDTO, true, false) if (event is Event.Conversation) { commitBundleEventReceiver.onEvent(event) } @@ -368,12 +377,16 @@ internal class MLSConversationDataSource( return epochsFlow } - override suspend fun addMemberToMLSGroup( + override suspend fun addMemberToMLSGroup(groupID: GroupID, userIdList: List): Either = + internalAddMemberToMLSGroup(groupID, userIdList, retryOnStaleMessage = true) + + private suspend fun internalAddMemberToMLSGroup( groupID: GroupID, - userIdList: List + userIdList: List, + retryOnStaleMessage: Boolean ): Either = withContext(serialDispatcher) { commitPendingProposals(groupID).flatMap { - retryOnCommitFailure(groupID) { + retryOnCommitFailure(groupID, retryOnStaleMessage = retryOnStaleMessage) { keyPackageRepository.claimKeyPackages(userIdList).flatMap { keyPackages -> mlsClientProvider.getMLSClient().flatMap { mlsClient -> val clientKeyPackageList = keyPackages @@ -472,9 +485,19 @@ internal class MLSConversationDataSource( idMapper.toCryptoModel(groupID), publicKeys.map { mlsPublicKeysMapper.toCrypto(it) } ) + }.flatMapLeft { + if (it is MLSFailure.ConversationAlreadyExists) { + Either.Right(Unit) + } else { + Either.Left(it) + } } }.flatMap { - addMemberToMLSGroup(groupID, members) + internalAddMemberToMLSGroup(groupID, members, retryOnStaleMessage = false).onFailure { + wrapMLSRequest { + mlsClient.wipeConversation(groupID.toCrypto()) + } + } }.flatMap { wrapStorageRequest { conversationDAO.updateConversationGroupState( @@ -497,25 +520,48 @@ internal class MLSConversationDataSource( private suspend fun retryOnCommitFailure( groupID: GroupID, retryOnClientMismatch: Boolean = true, + retryOnStaleMessage: Boolean = true, operation: suspend () -> Either ) = operation() .flatMapLeft { - handleCommitFailure(it, groupID, retryOnClientMismatch, operation) + handleCommitFailure( + failure = it, + groupID = groupID, + remainingAttempts = 2, + retryOnClientMismatch = retryOnClientMismatch, + retryOnStaleMessage = retryOnStaleMessage, + retryOperation = operation + ) } private suspend fun handleCommitFailure( failure: CoreFailure, groupID: GroupID, + remainingAttempts: Int, retryOnClientMismatch: Boolean, + retryOnStaleMessage: Boolean, retryOperation: suspend () -> Either ): Either { - return when (failure.getStrategy(retryOnClientMismatch)) { + return when ( + failure.getStrategy( + remainingAttempts = remainingAttempts, + retryOnClientMismatch = retryOnClientMismatch, + retryOnStaleMessage = retryOnStaleMessage + ) + ) { CommitStrategy.KEEP_AND_RETRY -> keepCommitAndRetry(groupID) CommitStrategy.DISCARD_AND_RETRY -> discardCommitAndRetry(groupID, retryOperation) CommitStrategy.ABORT -> return discardCommit(groupID).flatMap { Either.Left(failure) } }.flatMapLeft { - handleCommitFailure(it, groupID, retryOnClientMismatch, retryOperation) + handleCommitFailure( + failure = it, + groupID = groupID, + remainingAttempts = remainingAttempts - 1, + retryOnClientMismatch = retryOnClientMismatch, + retryOnStaleMessage = retryOnStaleMessage, + retryOperation = retryOperation + ) } } @@ -548,9 +594,13 @@ internal class MLSConversationDataSource( kaliumLogger.w("Discarding the failed commit.") return mlsClientProvider.getMLSClient().flatMap { mlsClient -> - wrapMLSRequest { + @Suppress("TooGenericExceptionCaught") + try { mlsClient.clearPendingCommit(idMapper.toCryptoModel(groupID)) + } catch (error: Throwable) { + kaliumLogger.e("Discarding pending commit failed: $error") } + Either.Right(Unit) } } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/NewGroupConversationSystemMessagesCreator.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/NewGroupConversationSystemMessagesCreator.kt index dc19e42b954..d20972517d0 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/NewGroupConversationSystemMessagesCreator.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/NewGroupConversationSystemMessagesCreator.kt @@ -116,7 +116,7 @@ internal class NewGroupConversationSystemMessagesCreatorImpl( persistReadReceiptSystemMessage( conversationId = conversation.id.toModel(), - creatorId = qualifiedIdMapper.fromStringToQualifiedID(conversation.creator), + creatorId = conversation.creator?.let { qualifiedIdMapper.fromStringToQualifiedID(it) } ?: selfUserId, receiptMode = conversation.receiptMode == ReceiptMode.ENABLED ) } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ProtocolInfoMapper.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ProtocolInfoMapper.kt index 2c7259ef7a0..16d19a7cdaa 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ProtocolInfoMapper.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ProtocolInfoMapper.kt @@ -25,7 +25,6 @@ import com.wire.kalium.persistence.dao.conversation.ConversationEntity interface ProtocolInfoMapper { fun fromEntity(protocolInfo: ConversationEntity.ProtocolInfo): Conversation.ProtocolInfo fun toEntity(protocolInfo: Conversation.ProtocolInfo): ConversationEntity.ProtocolInfo - fun fromInfoToProtocol(protocolInfo: Conversation.ProtocolInfo): Conversation.Protocol } class ProtocolInfoMapperImpl( @@ -36,7 +35,14 @@ class ProtocolInfoMapperImpl( is ConversationEntity.ProtocolInfo.Proteus -> Conversation.ProtocolInfo.Proteus is ConversationEntity.ProtocolInfo.MLS -> Conversation.ProtocolInfo.MLS( idMapper.fromGroupIDEntity(protocolInfo.groupId), - Conversation.ProtocolInfo.MLS.GroupState.valueOf(protocolInfo.groupState.name), + Conversation.ProtocolInfo.MLSCapable.GroupState.valueOf(protocolInfo.groupState.name), + protocolInfo.epoch, + protocolInfo.keyingMaterialLastUpdate, + Conversation.CipherSuite.fromTag(protocolInfo.cipherSuite.cipherSuiteTag) + ) + is ConversationEntity.ProtocolInfo.Mixed -> Conversation.ProtocolInfo.Mixed( + idMapper.fromGroupIDEntity(protocolInfo.groupId), + Conversation.ProtocolInfo.MLSCapable.GroupState.valueOf(protocolInfo.groupState.name), protocolInfo.epoch, protocolInfo.keyingMaterialLastUpdate, Conversation.CipherSuite.fromTag(protocolInfo.cipherSuite.cipherSuiteTag) @@ -53,11 +59,12 @@ class ProtocolInfoMapperImpl( protocolInfo.keyingMaterialLastUpdate, ConversationEntity.CipherSuite.fromTag(protocolInfo.cipherSuite.tag) ) - } - - override fun fromInfoToProtocol(protocolInfo: Conversation.ProtocolInfo): Conversation.Protocol = - when (protocolInfo) { - is Conversation.ProtocolInfo.Proteus -> Conversation.Protocol.PROTEUS - is Conversation.ProtocolInfo.MLS -> Conversation.Protocol.MLS + is Conversation.ProtocolInfo.Mixed -> ConversationEntity.ProtocolInfo.Mixed( + idMapper.toGroupIDEntity(protocolInfo.groupId), + ConversationEntity.GroupState.valueOf(protocolInfo.groupState.name), + protocolInfo.epoch, + protocolInfo.keyingMaterialLastUpdate, + ConversationEntity.CipherSuite.fromTag(protocolInfo.cipherSuite.tag) + ) } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/event/Event.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/event/Event.kt index bb143153e01..17297a83053 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/event/Event.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/event/Event.kt @@ -24,7 +24,7 @@ import com.wire.kalium.logger.obfuscateDomain import com.wire.kalium.logger.obfuscateId import com.wire.kalium.logic.data.client.Client import com.wire.kalium.logic.data.conversation.ClientId -import com.wire.kalium.logic.data.conversation.Conversation +import com.wire.kalium.logic.data.conversation.Conversation.Protocol import com.wire.kalium.logic.data.conversation.Conversation.Member import com.wire.kalium.logic.data.conversation.Conversation.ReceiptMode import com.wire.kalium.logic.data.conversation.Conversation.TypingIndicatorMode @@ -33,18 +33,20 @@ import com.wire.kalium.logic.data.featureConfig.ClassifiedDomainsModel import com.wire.kalium.logic.data.featureConfig.ConferenceCallingModel import com.wire.kalium.logic.data.featureConfig.ConfigsStatusModel import com.wire.kalium.logic.data.featureConfig.E2EIModel +import com.wire.kalium.logic.data.featureConfig.MLSMigrationModel import com.wire.kalium.logic.data.featureConfig.MLSModel import com.wire.kalium.logic.data.featureConfig.SelfDeletingMessagesModel import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.id.SubconversationId import com.wire.kalium.logic.data.user.Connection +import com.wire.kalium.logic.data.user.SupportedProtocol import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponse import com.wire.kalium.util.DateTimeUtil import com.wire.kalium.util.serialization.toJsonElement import kotlinx.serialization.json.JsonNull -sealed class Event(open val id: String, open val transient: Boolean) { +sealed class Event(open val id: String, open val transient: Boolean, open val live: Boolean) { private companion object { const val typeKey = "type" @@ -69,15 +71,17 @@ sealed class Event(open val id: String, open val transient: Boolean) { sealed class Conversation( id: String, override val transient: Boolean, + override val live: Boolean, open val conversationId: ConversationId - ) : Event(id, transient) { + ) : Event(id, transient, live) { data class AccessUpdate( override val id: String, override val conversationId: ConversationId, val data: ConversationResponse, val qualifiedFrom: UserId, - override val transient: Boolean - ) : Conversation(id, transient, conversationId) { + override val transient: Boolean, + override val live: Boolean + ) : Conversation(id, transient, live, conversationId) { override fun toLogMap(): Map = mapOf( typeKey to "Conversation.AccessUpdate", @@ -91,12 +95,13 @@ sealed class Event(open val id: String, open val transient: Boolean) { override val id: String, override val conversationId: ConversationId, override val transient: Boolean, + override val live: Boolean, val senderUserId: UserId, val senderClientId: ClientId, val timestampIso: String, val content: String, val encryptedExternalContent: EncryptedData? - ) : Conversation(id, transient, conversationId) { + ) : Conversation(id, transient, live, conversationId) { override fun toLogMap(): Map = mapOf( typeKey to "Conversation.NewMessage", @@ -112,11 +117,12 @@ sealed class Event(open val id: String, open val transient: Boolean) { override val id: String, override val conversationId: ConversationId, override val transient: Boolean, + override val live: Boolean, val subconversationId: SubconversationId?, val senderUserId: UserId, val timestampIso: String, val content: String - ) : Conversation(id, transient, conversationId) { + ) : Conversation(id, transient, live, conversationId) { override fun toLogMap(): Map = mapOf( typeKey to "Conversation.NewMLSMessage", @@ -131,10 +137,11 @@ sealed class Event(open val id: String, open val transient: Boolean) { override val id: String, override val conversationId: ConversationId, override val transient: Boolean, + override val live: Boolean, val senderUserId: UserId, val timestampIso: String, val conversation: ConversationResponse - ) : Conversation(id, transient, conversationId) { + ) : Conversation(id, transient, live, conversationId) { override fun toLogMap(): Map = mapOf( typeKey to "Conversation.NewConversation", @@ -148,10 +155,11 @@ sealed class Event(open val id: String, open val transient: Boolean) { override val id: String, override val conversationId: ConversationId, override val transient: Boolean, + override val live: Boolean, val addedBy: UserId, val members: List, val timestampIso: String - ) : Conversation(id, transient, conversationId) { + ) : Conversation(id, transient, live, conversationId) { override fun toLogMap(): Map = mapOf( typeKey to "Conversation.MemberJoin", @@ -167,10 +175,11 @@ sealed class Event(open val id: String, open val transient: Boolean) { override val id: String, override val conversationId: ConversationId, override val transient: Boolean, + override val live: Boolean, val removedBy: UserId, val removedList: List, val timestampIso: String - ) : Conversation(id, transient, conversationId) { + ) : Conversation(id, transient, live, conversationId) { override fun toLogMap(): Map = mapOf( typeKey to "Conversation.MemberLeave", @@ -185,15 +194,17 @@ sealed class Event(open val id: String, open val transient: Boolean) { override val id: String, override val conversationId: ConversationId, open val timestampIso: String, - transient: Boolean, - ) : Conversation(id, transient, conversationId) { + override val transient: Boolean, + override val live: Boolean, + ) : Conversation(id, transient, live, conversationId) { class MemberChangedRole( override val id: String, override val conversationId: ConversationId, override val timestampIso: String, override val transient: Boolean, + override val live: Boolean, val member: Member?, - ) : MemberChanged(id, conversationId, timestampIso, transient) { + ) : MemberChanged(id, conversationId, timestampIso, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "Conversation.MemberChangedRole", @@ -209,9 +220,10 @@ sealed class Event(open val id: String, open val transient: Boolean) { override val conversationId: ConversationId, override val timestampIso: String, override val transient: Boolean, + override val live: Boolean, val mutedConversationStatus: MutedConversationStatus, val mutedConversationChangedTime: String - ) : MemberChanged(id, conversationId, timestampIso, transient) { + ) : MemberChanged(id, conversationId, timestampIso, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "Conversation.MemberMutedStatusChanged", @@ -228,9 +240,10 @@ sealed class Event(open val id: String, open val transient: Boolean) { override val conversationId: ConversationId, override val timestampIso: String, override val transient: Boolean, + override val live: Boolean, val archivedConversationChangedTime: String, val isArchiving: Boolean - ) : MemberChanged(id, conversationId, timestampIso, transient) { + ) : MemberChanged(id, conversationId, timestampIso, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "Conversation.MemberArchivedStatusChanged", @@ -245,8 +258,9 @@ sealed class Event(open val id: String, open val transient: Boolean) { data class IgnoredMemberChanged( override val id: String, override val conversationId: ConversationId, - override val transient: Boolean - ) : MemberChanged(id, conversationId, "", transient) { + override val transient: Boolean, + override val live: Boolean, + ) : MemberChanged(id, conversationId, "", transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "Conversation.IgnoredMemberChanged", @@ -260,10 +274,11 @@ sealed class Event(open val id: String, open val transient: Boolean) { override val id: String, override val conversationId: ConversationId, override val transient: Boolean, + override val live: Boolean, val senderUserId: UserId, val message: String, val timestampIso: String = DateTimeUtil.currentIsoDateTimeString() - ) : Conversation(id, transient, conversationId) { + ) : Conversation(id, transient, live, conversationId) { override fun toLogMap(): Map = mapOf( typeKey to "Conversation.MLSWelcome", idKey to id.obfuscateId(), @@ -277,9 +292,10 @@ sealed class Event(open val id: String, open val transient: Boolean) { override val id: String, override val conversationId: ConversationId, override val transient: Boolean, + override val live: Boolean, val senderUserId: UserId, val timestampIso: String, - ) : Conversation(id, transient, conversationId) { + ) : Conversation(id, transient, live, conversationId) { override fun toLogMap(): Map = mapOf( typeKey to "Conversation.DeletedConversation", @@ -294,10 +310,11 @@ sealed class Event(open val id: String, open val transient: Boolean) { override val id: String, override val conversationId: ConversationId, override val transient: Boolean, + override val live: Boolean, val conversationName: String, val senderUserId: UserId, val timestampIso: String, - ) : Conversation(id, transient, conversationId) { + ) : Conversation(id, transient, live, conversationId) { override fun toLogMap(): Map = mapOf( typeKey to "Conversation.RenamedConversation", idKey to id.obfuscateId(), @@ -312,9 +329,10 @@ sealed class Event(open val id: String, open val transient: Boolean) { override val id: String, override val conversationId: ConversationId, override val transient: Boolean, + override val live: Boolean, val receiptMode: ReceiptMode, val senderUserId: UserId - ) : Conversation(id, transient, conversationId) { + ) : Conversation(id, transient, live, conversationId) { override fun toLogMap() = mapOf( typeKey to "Conversation.ConversationReceiptMode", @@ -329,10 +347,11 @@ sealed class Event(open val id: String, open val transient: Boolean) { override val id: String, override val conversationId: ConversationId, override val transient: Boolean, + override val live: Boolean, val messageTimer: Long?, val senderUserId: UserId, val timestampIso: String - ) : Conversation(id, transient, conversationId) { + ) : Conversation(id, transient, live, conversationId) { override fun toLogMap() = mapOf( typeKey to "Conversation.ConversationMessageTimer", @@ -348,11 +367,12 @@ sealed class Event(open val id: String, open val transient: Boolean) { override val id: String, override val conversationId: ConversationId, override val transient: Boolean, + override val live: Boolean, val key: String, val code: String, val uri: String, val isPasswordProtected: Boolean, - ) : Conversation(id, transient, conversationId) { + ) : Conversation(id, transient, live, conversationId) { override fun toLogMap(): Map = mapOf(typeKey to "Conversation.CodeUpdated") } @@ -360,7 +380,8 @@ sealed class Event(open val id: String, open val transient: Boolean) { override val id: String, override val conversationId: ConversationId, override val transient: Boolean, - ) : Conversation(id, transient, conversationId) { + override val live: Boolean, + ) : Conversation(id, transient, live, conversationId) { override fun toLogMap(): Map = mapOf(typeKey to "Conversation.CodeDeleted") } @@ -368,10 +389,11 @@ sealed class Event(open val id: String, open val transient: Boolean) { override val id: String, override val conversationId: ConversationId, override val transient: Boolean, + override val live: Boolean, val senderUserId: UserId, val timestampIso: String, val typingIndicatorMode: TypingIndicatorMode, - ) : Conversation(id, transient, conversationId) { + ) : Conversation(id, transient, live, conversationId) { override fun toLogMap(): Map = mapOf( typeKey to "Conversation.TypingIndicator", conversationIdKey to conversationId.toLogString(), @@ -380,20 +402,39 @@ sealed class Event(open val id: String, open val transient: Boolean) { timestampIsoKey to timestampIso ) } + + data class ConversationProtocol( + override val id: String, + override val conversationId: ConversationId, + override val transient: Boolean, + override val live: Boolean, + val protocol: Protocol, + val senderUserId: UserId + ) : Conversation(id, transient, live, conversationId) { + override fun toLogMap() = mapOf( + typeKey to "Conversation.ConversationProtocol", + idKey to id.obfuscateId(), + conversationIdKey to conversationId.toLogString(), + "protocol" to protocol.name, + senderUserIdKey to senderUserId.toLogString(), + ) + } } sealed class Team( id: String, open val teamId: String, transient: Boolean, - ) : Event(id, transient) { + live: Boolean, + ) : Event(id, transient, live) { data class Update( override val id: String, override val transient: Boolean, + override val live: Boolean, override val teamId: String, val icon: String, val name: String, - ) : Team(id, teamId, transient) { + ) : Team(id, teamId, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "Team.Update", idKey to id.obfuscateId(), @@ -407,8 +448,9 @@ sealed class Event(open val id: String, open val transient: Boolean) { override val id: String, override val teamId: String, override val transient: Boolean, + override val live: Boolean, val memberId: String, - ) : Team(id, teamId, transient) { + ) : Team(id, teamId, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "Team.MemberJoin", idKey to id.obfuscateId(), @@ -420,10 +462,11 @@ sealed class Event(open val id: String, open val transient: Boolean) { data class MemberLeave( override val id: String, override val transient: Boolean, + override val live: Boolean, override val teamId: String, val memberId: String, val timestampIso: String, - ) : Team(id, teamId, transient) { + ) : Team(id, teamId, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "Team.MemberLeave", idKey to id.obfuscateId(), @@ -437,9 +480,10 @@ sealed class Event(open val id: String, open val transient: Boolean) { override val id: String, override val teamId: String, override val transient: Boolean, + override val live: Boolean, val memberId: String, val permissionCode: Int?, - ) : Team(id, teamId, transient) { + ) : Team(id, teamId, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "Team.MemberUpdate", idKey to id.obfuscateId(), @@ -454,12 +498,14 @@ sealed class Event(open val id: String, open val transient: Boolean) { sealed class FeatureConfig( id: String, transient: Boolean, - ) : Event(id, transient) { + live: Boolean, + ) : Event(id, transient, live) { data class FileSharingUpdated( override val id: String, override val transient: Boolean, + override val live: Boolean, val model: ConfigsStatusModel - ) : FeatureConfig(id, transient) { + ) : FeatureConfig(id, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "FeatureConfig.FileSharingUpdated", idKey to id.obfuscateId(), @@ -470,8 +516,9 @@ sealed class Event(open val id: String, open val transient: Boolean) { data class MLSUpdated( override val id: String, override val transient: Boolean, + override val live: Boolean, val model: MLSModel - ) : FeatureConfig(id, transient) { + ) : FeatureConfig(id, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "FeatureConfig.MLSUpdated", idKey to id.obfuscateId(), @@ -480,11 +527,27 @@ sealed class Event(open val id: String, open val transient: Boolean) { ) } + data class MLSMigrationUpdated( + override val id: String, + override val transient: Boolean, + override val live: Boolean, + val model: MLSMigrationModel + ) : FeatureConfig(id, transient, live) { + override fun toLogMap(): Map = mapOf( + typeKey to "FeatureConfig.MLSUpdated", + idKey to id.obfuscateId(), + featureStatusKey to model.status.name, + "startTime" to model.startTime, + "endTime" to model.endTime + ) + } + data class ClassifiedDomainsUpdated( override val id: String, override val transient: Boolean, + override val live: Boolean, val model: ClassifiedDomainsModel, - ) : FeatureConfig(id, transient) { + ) : FeatureConfig(id, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "FeatureConfig.ClassifiedDomainsUpdated", idKey to id.obfuscateId(), @@ -496,8 +559,9 @@ sealed class Event(open val id: String, open val transient: Boolean) { data class ConferenceCallingUpdated( override val id: String, override val transient: Boolean, + override val live: Boolean, val model: ConferenceCallingModel, - ) : FeatureConfig(id, transient) { + ) : FeatureConfig(id, transient, live) { override fun toLogMap() = mapOf( typeKey to "FeatureConfig.ConferenceCallingUpdated", idKey to id.obfuscateId(), @@ -508,8 +572,9 @@ sealed class Event(open val id: String, open val transient: Boolean) { data class GuestRoomLinkUpdated( override val id: String, override val transient: Boolean, + override val live: Boolean, val model: ConfigsStatusModel, - ) : FeatureConfig(id, transient) { + ) : FeatureConfig(id, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "FeatureConfig.GuestRoomLinkUpdated", idKey to id.obfuscateId(), @@ -520,8 +585,9 @@ sealed class Event(open val id: String, open val transient: Boolean) { data class SelfDeletingMessagesConfig( override val id: String, override val transient: Boolean, + override val live: Boolean, val model: SelfDeletingMessagesModel, - ) : FeatureConfig(id, transient) { + ) : FeatureConfig(id, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "FeatureConfig.SelfDeletingMessagesConfig", idKey to id.obfuscateId(), @@ -533,8 +599,9 @@ sealed class Event(open val id: String, open val transient: Boolean) { data class MLSE2EIUpdated( override val id: String, override val transient: Boolean, + override val live: Boolean, val model: E2EIModel - ) : FeatureConfig(id, transient) { + ) : FeatureConfig(id, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "FeatureConfig.MLSE2EIUpdated", idKey to id.obfuscateId(), @@ -546,7 +613,8 @@ sealed class Event(open val id: String, open val transient: Boolean) { data class UnknownFeatureUpdated( override val id: String, override val transient: Boolean, - ) : FeatureConfig(id, transient) { + override val live: Boolean, + ) : FeatureConfig(id, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "FeatureConfig.UnknownFeatureUpdated", idKey to id.obfuscateId(), @@ -556,13 +624,15 @@ sealed class Event(open val id: String, open val transient: Boolean) { sealed class User( id: String, - transient: Boolean - ) : Event(id, transient) { + transient: Boolean, + live: Boolean, + ) : Event(id, transient, live) { data class Update( override val id: String, override val transient: Boolean, - val userId: String, + override val live: Boolean, + val userId: UserId, val accentId: Int?, val ssoIdDeleted: Boolean?, val name: String?, @@ -570,19 +640,21 @@ sealed class Event(open val id: String, open val transient: Boolean) { val email: String?, val previewAssetId: String?, val completeAssetId: String?, - ) : User(id, transient) { + val supportedProtocols: Set? + ) : User(id, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "User.Update", idKey to id.obfuscateId(), - userIdKey to userId.obfuscateId() + userIdKey to userId.toLogString() ) } data class NewConnection( override val transient: Boolean, + override val live: Boolean, override val id: String, val connection: Connection - ) : User(id, transient) { + ) : User(id, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "User.NewConnection", idKey to id.obfuscateId(), @@ -592,9 +664,10 @@ sealed class Event(open val id: String, open val transient: Boolean) { data class ClientRemove( override val transient: Boolean, + override val live: Boolean, override val id: String, val clientId: ClientId - ) : User(id, transient) { + ) : User(id, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "User.ClientRemove", idKey to id.obfuscateId(), @@ -604,10 +677,11 @@ sealed class Event(open val id: String, open val transient: Boolean) { data class UserDelete( override val transient: Boolean, + override val live: Boolean, override val id: String, val userId: UserId, val timestampIso: String = DateTimeUtil.currentIsoDateTimeString() // TODO we are not receiving it from API - ) : User(id, transient) { + ) : User(id, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "User.UserDelete", idKey to id.obfuscateId(), @@ -618,9 +692,10 @@ sealed class Event(open val id: String, open val transient: Boolean) { data class NewClient( override val transient: Boolean, + override val live: Boolean, override val id: String, val client: Client, - ) : User(id, transient) { + ) : User(id, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "User.NewClient", idKey to id.obfuscateId(), @@ -629,21 +704,24 @@ sealed class Event(open val id: String, open val transient: Boolean) { "model" to (client.model ?: ""), "clientType" to client.type, "deviceType" to client.deviceType, - "label" to (client.label ?: "") + "label" to (client.label ?: ""), + "isMLSCapable" to client.isMLSCapable ) } } sealed class UserProperty( id: String, - transient: Boolean - ) : Event(id, transient) { + transient: Boolean, + live: Boolean, + ) : Event(id, transient, live) { data class ReadReceiptModeSet( override val id: String, override val transient: Boolean, + override val live: Boolean, val value: Boolean, - ) : UserProperty(id, transient) { + ) : UserProperty(id, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "User.UserProperty.ReadReceiptModeSet", idKey to id.obfuscateId(), @@ -655,8 +733,9 @@ sealed class Event(open val id: String, open val transient: Boolean) { data class TypingIndicatorModeSet( override val id: String, override val transient: Boolean, + override val live: Boolean, val value: Boolean, - ) : UserProperty(id, transient) { + ) : UserProperty(id, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "User.UserProperty.TypingIndicatorModeSet", idKey to id.obfuscateId(), @@ -669,9 +748,10 @@ sealed class Event(open val id: String, open val transient: Boolean) { data class Unknown( override val id: String, override val transient: Boolean, + override val live: Boolean, val unknownType: String, val cause: String? = null - ) : Event(id, transient) { + ) : Event(id, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "User.UnknownEvent", idKey to id.obfuscateId(), @@ -683,13 +763,15 @@ sealed class Event(open val id: String, open val transient: Boolean) { sealed class Federation( id: String, override val transient: Boolean, - ) : Event(id, transient) { + override val live: Boolean, + ) : Event(id, transient, live) { data class Delete( override val id: String, override val transient: Boolean, + override val live: Boolean, val domain: String, - ) : Federation(id, transient) { + ) : Federation(id, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "Federation.Delete", idKey to id.obfuscateId(), @@ -701,8 +783,9 @@ sealed class Event(open val id: String, open val transient: Boolean) { data class ConnectionRemoved( override val id: String, override val transient: Boolean, + override val live: Boolean, val domains: List, - ) : Federation(id, transient) { + ) : Federation(id, transient, live) { override fun toLogMap(): Map = mapOf( typeKey to "Federation.ConnectionRemoved", idKey to id.obfuscateId(), diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/event/EventMapper.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/event/EventMapper.kt index c915ca50684..87b49e61b6b 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/event/EventMapper.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/event/EventMapper.kt @@ -33,6 +33,8 @@ import com.wire.kalium.logic.data.event.Event.UserProperty.TypingIndicatorModeSe import com.wire.kalium.logic.data.featureConfig.FeatureConfigMapper import com.wire.kalium.logic.data.id.SubconversationId import com.wire.kalium.logic.data.id.toModel +import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.data.user.toModel import com.wire.kalium.logic.di.MapperProvider import com.wire.kalium.logic.util.Base64 import com.wire.kalium.network.api.base.authenticated.featureConfigs.FeatureConfigData @@ -49,85 +51,88 @@ import kotlinx.serialization.InternalSerializationApi import kotlinx.serialization.SerializationException import kotlinx.serialization.serializer -@Suppress("TooManyFunctions") +@Suppress("TooManyFunctions", "LongParameterList") class EventMapper( private val memberMapper: MemberMapper, private val connectionMapper: ConnectionMapper, private val featureConfigMapper: FeatureConfigMapper, private val roleMapper: ConversationRoleMapper, + private val selfUserId: UserId, private val receiptModeMapper: ReceiptModeMapper = MapperProvider.receiptModeMapper(), private val clientMapper: ClientMapper = MapperProvider.clientMapper() ) { - fun fromDTO(eventResponse: EventResponse): List { + fun fromDTO(eventResponse: EventResponse, live: Boolean = false): List { // TODO(edge-case): Multiple payloads in the same event have the same ID, is this an issue when marking lastProcessedEventId? val id = eventResponse.id return eventResponse.payload?.map { eventContentDTO -> - fromEventContentDTO(id, eventContentDTO, eventResponse.transient) + fromEventContentDTO(id, eventContentDTO, eventResponse.transient, live) } ?: listOf() } @Suppress("ComplexMethod") - fun fromEventContentDTO(id: String, eventContentDTO: EventContentDTO, transient: Boolean): Event = + fun fromEventContentDTO(id: String, eventContentDTO: EventContentDTO, transient: Boolean, live: Boolean): Event = when (eventContentDTO) { - is EventContentDTO.Conversation.NewMessageDTO -> newMessage(id, eventContentDTO, transient) - is EventContentDTO.Conversation.NewConversationDTO -> newConversation(id, eventContentDTO, transient) - is EventContentDTO.Conversation.MemberJoinDTO -> conversationMemberJoin(id, eventContentDTO, transient) - is EventContentDTO.Conversation.MemberLeaveDTO -> conversationMemberLeave(id, eventContentDTO, transient) - is EventContentDTO.Conversation.MemberUpdateDTO -> memberUpdate(id, eventContentDTO, transient) - is EventContentDTO.Conversation.MLSWelcomeDTO -> welcomeMessage(id, eventContentDTO, transient) - is EventContentDTO.Conversation.NewMLSMessageDTO -> newMLSMessage(id, eventContentDTO, transient) - is EventContentDTO.User.NewConnectionDTO -> connectionUpdate(id, eventContentDTO, transient) - is EventContentDTO.User.ClientRemoveDTO -> clientRemove(id, eventContentDTO, transient) - is EventContentDTO.User.UserDeleteDTO -> userDelete(id, eventContentDTO, transient) - is EventContentDTO.FeatureConfig.FeatureConfigUpdatedDTO -> featureConfig(id, eventContentDTO, transient) - is EventContentDTO.User.NewClientDTO -> newClient(id, eventContentDTO, transient) - is EventContentDTO.Unknown -> unknown(id, transient, eventContentDTO) - is EventContentDTO.Conversation.AccessUpdate -> unknown(id, transient, eventContentDTO) - is EventContentDTO.Conversation.DeletedConversationDTO -> conversationDeleted(id, eventContentDTO, transient) - - is EventContentDTO.Conversation.ConversationRenameDTO -> conversationRenamed(id, eventContentDTO, transient) - is EventContentDTO.Team.MemberJoin -> teamMemberJoined(id, eventContentDTO, transient) - is EventContentDTO.Team.MemberLeave -> teamMemberLeft(id, eventContentDTO, transient) - is EventContentDTO.Team.MemberUpdate -> teamMemberUpdate(id, eventContentDTO, transient) - is EventContentDTO.Team.Update -> teamUpdate(id, eventContentDTO, transient) - is EventContentDTO.User.UpdateDTO -> userUpdate(id, eventContentDTO, transient) - is EventContentDTO.UserProperty.PropertiesSetDTO -> updateUserProperties(id, eventContentDTO, transient) - is EventContentDTO.UserProperty.PropertiesDeleteDTO -> deleteUserProperties(id, eventContentDTO, transient) - is EventContentDTO.Conversation.ReceiptModeUpdate -> conversationReceiptModeUpdate(id, eventContentDTO, transient) - - is EventContentDTO.Conversation.MessageTimerUpdate -> conversationMessageTimerUpdate(id, eventContentDTO, transient) - - is EventContentDTO.Conversation.CodeDeleted -> conversationCodeDeleted(id, eventContentDTO, transient) - is EventContentDTO.Conversation.CodeUpdated -> conversationCodeUpdated(id, eventContentDTO, transient) - is EventContentDTO.Federation -> federationTerminated(id, eventContentDTO, transient) - is EventContentDTO.Conversation.ConversationTypingDTO -> conversationTyping(id, eventContentDTO, transient) + is EventContentDTO.Conversation.NewMessageDTO -> newMessage(id, eventContentDTO, transient, live) + is EventContentDTO.Conversation.NewConversationDTO -> newConversation(id, eventContentDTO, transient, live) + is EventContentDTO.Conversation.MemberJoinDTO -> conversationMemberJoin(id, eventContentDTO, transient, live) + is EventContentDTO.Conversation.MemberLeaveDTO -> conversationMemberLeave(id, eventContentDTO, transient, live) + is EventContentDTO.Conversation.MemberUpdateDTO -> memberUpdate(id, eventContentDTO, transient, live) + is EventContentDTO.Conversation.MLSWelcomeDTO -> welcomeMessage(id, eventContentDTO, transient, live) + is EventContentDTO.Conversation.NewMLSMessageDTO -> newMLSMessage(id, eventContentDTO, transient, live) + is EventContentDTO.User.NewConnectionDTO -> connectionUpdate(id, eventContentDTO, transient, live) + is EventContentDTO.User.ClientRemoveDTO -> clientRemove(id, eventContentDTO, transient, live) + is EventContentDTO.User.UserDeleteDTO -> userDelete(id, eventContentDTO, transient, live) + is EventContentDTO.FeatureConfig.FeatureConfigUpdatedDTO -> featureConfig(id, eventContentDTO, transient, live) + is EventContentDTO.User.NewClientDTO -> newClient(id, eventContentDTO, transient, live) + is EventContentDTO.Unknown -> unknown(id, transient, live, eventContentDTO) + is EventContentDTO.Conversation.AccessUpdate -> unknown(id, transient, live, eventContentDTO) + is EventContentDTO.Conversation.DeletedConversationDTO -> conversationDeleted(id, eventContentDTO, transient, live) + is EventContentDTO.Conversation.ConversationRenameDTO -> conversationRenamed(id, eventContentDTO, transient, live) + is EventContentDTO.Team.MemberJoin -> teamMemberJoined(id, eventContentDTO, transient, live) + is EventContentDTO.Team.MemberLeave -> teamMemberLeft(id, eventContentDTO, transient, live) + is EventContentDTO.Team.MemberUpdate -> teamMemberUpdate(id, eventContentDTO, transient, live) + is EventContentDTO.Team.Update -> teamUpdate(id, eventContentDTO, transient, live) + is EventContentDTO.User.UpdateDTO -> userUpdate(id, eventContentDTO, transient, live) + is EventContentDTO.UserProperty.PropertiesSetDTO -> updateUserProperties(id, eventContentDTO, transient, live) + is EventContentDTO.UserProperty.PropertiesDeleteDTO -> deleteUserProperties(id, eventContentDTO, transient, live) + is EventContentDTO.Conversation.ReceiptModeUpdate -> conversationReceiptModeUpdate(id, eventContentDTO, transient, live) + is EventContentDTO.Conversation.MessageTimerUpdate -> conversationMessageTimerUpdate(id, eventContentDTO, transient, live) + is EventContentDTO.Conversation.CodeDeleted -> conversationCodeDeleted(id, eventContentDTO, transient, live) + is EventContentDTO.Conversation.CodeUpdated -> conversationCodeUpdated(id, eventContentDTO, transient, live) + is EventContentDTO.Federation -> federationTerminated(id, eventContentDTO, transient, live) + is EventContentDTO.Conversation.ConversationTypingDTO -> conversationTyping(id, eventContentDTO, transient, live) + is EventContentDTO.Conversation.ProtocolUpdate -> conversationProtocolUpdate(id, eventContentDTO, transient, live) } private fun conversationTyping( id: String, eventContentDTO: EventContentDTO.Conversation.ConversationTypingDTO, - transient: Boolean + transient: Boolean, + live: Boolean ): Event = Event.Conversation.TypingIndicator( id, eventContentDTO.qualifiedConversation.toModel(), transient, + live, eventContentDTO.qualifiedFrom.toModel(), eventContentDTO.time, eventContentDTO.status.status.toModel() ) - private fun federationTerminated(id: String, eventContentDTO: EventContentDTO.Federation, transient: Boolean): Event = + private fun federationTerminated(id: String, eventContentDTO: EventContentDTO.Federation, transient: Boolean, live: Boolean): Event = when (eventContentDTO) { is EventContentDTO.Federation.FederationConnectionRemovedDTO -> Event.Federation.ConnectionRemoved( id, transient, + live, eventContentDTO.domains ) is EventContentDTO.Federation.FederationDeleteDTO -> Event.Federation.Delete( id, transient, + live, eventContentDTO.domain ) } @@ -135,17 +140,20 @@ class EventMapper( private fun conversationCodeDeleted( id: String, event: EventContentDTO.Conversation.CodeDeleted, - transient: Boolean + transient: Boolean, + live: Boolean ): Event.Conversation.CodeDeleted = Event.Conversation.CodeDeleted( id = id, transient = transient, + live = live, conversationId = event.qualifiedConversation.toModel() ) private fun conversationCodeUpdated( id: String, event: EventContentDTO.Conversation.CodeUpdated, - transient: Boolean + transient: Boolean, + live: Boolean ): Event.Conversation.CodeUpdated = Event.Conversation.CodeUpdated( id = id, key = event.data.key, @@ -153,18 +161,21 @@ class EventMapper( uri = event.data.uri, isPasswordProtected = event.data.hasPassword, conversationId = event.qualifiedConversation.toModel(), - transient = transient + transient = transient, + live = live, ) @OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class) fun unknown( id: String, transient: Boolean, + live: Boolean, eventContentDTO: EventContentDTO, cause: String? = null ): Event.Unknown = Event.Unknown( id = id, transient = transient, + live = live, unknownType = when (eventContentDTO) { is EventContentDTO.Unknown -> eventContentDTO.type else -> try { @@ -176,14 +187,30 @@ class EventMapper( cause = cause ) + private fun conversationProtocolUpdate( + id: String, + eventContentDTO: EventContentDTO.Conversation.ProtocolUpdate, + transient: Boolean, + live: Boolean + ): Event = Event.Conversation.ConversationProtocol( + id = id, + conversationId = eventContentDTO.qualifiedConversation.toModel(), + transient = transient, + live = live, + protocol = eventContentDTO.data.protocol.toModel(), + senderUserId = eventContentDTO.qualifiedFrom.toModel() + ) + fun conversationMessageTimerUpdate( id: String, eventContentDTO: EventContentDTO.Conversation.MessageTimerUpdate, - transient: Boolean + transient: Boolean, + live: Boolean ) = Event.Conversation.ConversationMessageTimer( id = id, conversationId = eventContentDTO.qualifiedConversation.toModel(), transient = transient, + live = live, messageTimer = eventContentDTO.data.messageTimer, senderUserId = eventContentDTO.qualifiedFrom.toModel(), timestampIso = eventContentDTO.time @@ -192,11 +219,13 @@ class EventMapper( private fun conversationReceiptModeUpdate( id: String, eventContentDTO: EventContentDTO.Conversation.ReceiptModeUpdate, - transient: Boolean + transient: Boolean, + live: Boolean ): Event = Event.Conversation.ConversationReceiptMode( id = id, conversationId = eventContentDTO.qualifiedConversation.toModel(), transient = transient, + live = live, receiptMode = receiptModeMapper.fromApiToModel(eventContentDTO.data.receiptMode), senderUserId = eventContentDTO.qualifiedFrom.toModel() ) @@ -204,7 +233,8 @@ class EventMapper( private fun updateUserProperties( id: String, eventContentDTO: EventContentDTO.UserProperty.PropertiesSetDTO, - transient: Boolean + transient: Boolean, + live: Boolean ): Event { val fieldKeyValue = eventContentDTO.value val key = eventContentDTO.key @@ -214,18 +244,21 @@ class EventMapper( WIRE_RECEIPT_MODE.key -> ReadReceiptModeSet( id, transient, + live, fieldKeyValue.value == 1 ) WIRE_TYPING_INDICATOR_MODE.key -> TypingIndicatorModeSet( id, transient, + live, fieldKeyValue.value != 0 ) else -> unknown( id = id, transient = transient, + live = live, eventContentDTO = eventContentDTO, cause = "Unknown key: $key " ) @@ -235,6 +268,7 @@ class EventMapper( else -> unknown( id = id, transient = transient, + live = live, eventContentDTO = eventContentDTO, cause = "Unknown value type for key: ${eventContentDTO.key} " ) @@ -244,14 +278,16 @@ class EventMapper( private fun deleteUserProperties( id: String, eventContentDTO: EventContentDTO.UserProperty.PropertiesDeleteDTO, - transient: Boolean + transient: Boolean, + live: Boolean ): Event { return when (eventContentDTO.key) { - WIRE_RECEIPT_MODE.key -> ReadReceiptModeSet(id, transient, false) - WIRE_TYPING_INDICATOR_MODE.key -> TypingIndicatorModeSet(id, transient, true) + WIRE_RECEIPT_MODE.key -> ReadReceiptModeSet(id, transient, live, false) + WIRE_TYPING_INDICATOR_MODE.key -> TypingIndicatorModeSet(id, transient, live, true) else -> unknown( id = id, transient = transient, + live = live, eventContentDTO = eventContentDTO, cause = "Unknown key: ${eventContentDTO.key} " ) @@ -261,11 +297,13 @@ class EventMapper( private fun welcomeMessage( id: String, eventContentDTO: EventContentDTO.Conversation.MLSWelcomeDTO, - transient: Boolean + transient: Boolean, + live: Boolean ) = Event.Conversation.MLSWelcome( id, eventContentDTO.qualifiedConversation.toModel(), transient, + live, eventContentDTO.qualifiedFrom.toModel(), eventContentDTO.message, ) @@ -273,11 +311,13 @@ class EventMapper( private fun newMessage( id: String, eventContentDTO: EventContentDTO.Conversation.NewMessageDTO, - transient: Boolean + transient: Boolean, + live: Boolean ) = Event.Conversation.NewMessage( id, eventContentDTO.qualifiedConversation.toModel(), transient, + live, eventContentDTO.qualifiedFrom.toModel(), ClientId(eventContentDTO.data.sender), eventContentDTO.time, @@ -290,11 +330,13 @@ class EventMapper( private fun newMLSMessage( id: String, eventContentDTO: EventContentDTO.Conversation.NewMLSMessageDTO, - transient: Boolean + transient: Boolean, + live: Boolean ) = Event.Conversation.NewMLSMessage( id, eventContentDTO.qualifiedConversation.toModel(), transient, + live, eventContentDTO.subconversation?.let { SubconversationId(it) }, eventContentDTO.qualifiedFrom.toModel(), eventContentDTO.time, @@ -304,32 +346,42 @@ class EventMapper( private fun connectionUpdate( id: String, eventConnectionDTO: EventContentDTO.User.NewConnectionDTO, - transient: Boolean + transient: Boolean, + live: Boolean ) = Event.User.NewConnection( transient, + live, id, connectionMapper.fromApiToModel(eventConnectionDTO.connection) ) - private fun userDelete(id: String, eventUserDelete: EventContentDTO.User.UserDeleteDTO, transient: Boolean): Event.User.UserDelete { - return Event.User.UserDelete(transient, id, eventUserDelete.userId.toModel()) + private fun userDelete( + id: String, + eventUserDelete: EventContentDTO.User.UserDeleteDTO, + transient: Boolean, + live: Boolean + ): Event.User.UserDelete { + return Event.User.UserDelete(transient, live, id, eventUserDelete.userId.toModel()) } private fun clientRemove( id: String, eventClientRemove: EventContentDTO.User.ClientRemoveDTO, - transient: Boolean + transient: Boolean, + live: Boolean ): Event.User.ClientRemove { - return Event.User.ClientRemove(transient, id, ClientId(eventClientRemove.client.clientId)) + return Event.User.ClientRemove(transient, live, id, ClientId(eventClientRemove.client.clientId)) } private fun newClient( id: String, eventNewClient: EventContentDTO.User.NewClientDTO, - transient: Boolean + transient: Boolean, + live: Boolean ): Event.User.NewClient { return Event.User.NewClient( transient = transient, + live = live, id = id, client = clientMapper.fromClientDto(eventNewClient.client) ) @@ -338,11 +390,13 @@ class EventMapper( private fun newConversation( id: String, eventContentDTO: EventContentDTO.Conversation.NewConversationDTO, - transient: Boolean + transient: Boolean, + live: Boolean ) = Event.Conversation.NewConversation( id, eventContentDTO.qualifiedConversation.toModel(), transient, + live, eventContentDTO.qualifiedFrom.toModel(), eventContentDTO.time, eventContentDTO.data @@ -351,33 +405,38 @@ class EventMapper( fun conversationMemberJoin( id: String, eventContentDTO: EventContentDTO.Conversation.MemberJoinDTO, - transient: Boolean + transient: Boolean, + live: Boolean ) = Event.Conversation.MemberJoin( id = id, conversationId = eventContentDTO.qualifiedConversation.toModel(), addedBy = eventContentDTO.qualifiedFrom.toModel(), members = eventContentDTO.members.users.map { memberMapper.fromApiModel(it) }, timestampIso = eventContentDTO.time, - transient = transient + transient = transient, + live = live, ) fun conversationMemberLeave( id: String, eventContentDTO: EventContentDTO.Conversation.MemberLeaveDTO, - transient: Boolean + transient: Boolean, + live: Boolean ) = Event.Conversation.MemberLeave( id = id, conversationId = eventContentDTO.qualifiedConversation.toModel(), removedBy = eventContentDTO.qualifiedFrom.toModel(), removedList = eventContentDTO.members.qualifiedUserIds.map { it.toModel() }, timestampIso = eventContentDTO.time, - transient = transient + transient = transient, + live = live, ) private fun memberUpdate( id: String, eventContentDTO: EventContentDTO.Conversation.MemberUpdateDTO, - transient: Boolean + transient: Boolean, + live: Boolean ): Event.Conversation.MemberChanged { return when { eventContentDTO.roleChange.role?.isNotEmpty() == true -> { @@ -386,6 +445,7 @@ class EventMapper( conversationId = eventContentDTO.qualifiedConversation.toModel(), timestampIso = eventContentDTO.time, transient = transient, + live = live, member = Conversation.Member( id = eventContentDTO.roleChange.qualifiedUserId.toModel(), role = roleMapper.fromApi(eventContentDTO.roleChange.role.orEmpty()) @@ -400,6 +460,7 @@ class EventMapper( timestampIso = eventContentDTO.time, mutedConversationChangedTime = eventContentDTO.roleChange.mutedRef.orEmpty(), transient = transient, + live = live, mutedConversationStatus = mapConversationMutedStatus(eventContentDTO.roleChange.mutedStatus) ) } @@ -410,6 +471,7 @@ class EventMapper( conversationId = eventContentDTO.qualifiedConversation.toModel(), timestampIso = eventContentDTO.time, transient = transient, + live = live, archivedConversationChangedTime = eventContentDTO.roleChange.archivedRef.orEmpty(), isArchiving = eventContentDTO.roleChange.isArchiving ?: false ) @@ -419,7 +481,8 @@ class EventMapper( Event.Conversation.MemberChanged.IgnoredMemberChanged( id, eventContentDTO.qualifiedConversation.toModel(), - transient + transient, + live ) } } @@ -436,132 +499,160 @@ class EventMapper( private fun featureConfig( id: String, featureConfigUpdatedDTO: EventContentDTO.FeatureConfig.FeatureConfigUpdatedDTO, - transient: Boolean + transient: Boolean, + live: Boolean ) = when (featureConfigUpdatedDTO.data) { is FeatureConfigData.FileSharing -> Event.FeatureConfig.FileSharingUpdated( id, transient, + live, featureConfigMapper.fromDTO(featureConfigUpdatedDTO.data as FeatureConfigData.FileSharing) ) is FeatureConfigData.SelfDeletingMessages -> Event.FeatureConfig.SelfDeletingMessagesConfig( id, transient, + live, featureConfigMapper.fromDTO(featureConfigUpdatedDTO.data as FeatureConfigData.SelfDeletingMessages) ) is FeatureConfigData.MLS -> Event.FeatureConfig.MLSUpdated( id, transient, + live, featureConfigMapper.fromDTO(featureConfigUpdatedDTO.data as FeatureConfigData.MLS) ) + is FeatureConfigData.MLSMigration -> Event.FeatureConfig.MLSMigrationUpdated( + id, + transient, + live, + featureConfigMapper.fromDTO(featureConfigUpdatedDTO.data as FeatureConfigData.MLSMigration) + ) + is FeatureConfigData.ClassifiedDomains -> Event.FeatureConfig.ClassifiedDomainsUpdated( id, transient, + live, featureConfigMapper.fromDTO(featureConfigUpdatedDTO.data as FeatureConfigData.ClassifiedDomains) ) is FeatureConfigData.ConferenceCalling -> Event.FeatureConfig.ConferenceCallingUpdated( id, transient, + live, featureConfigMapper.fromDTO(featureConfigUpdatedDTO.data as FeatureConfigData.ConferenceCalling) ) is FeatureConfigData.ConversationGuestLinks -> Event.FeatureConfig.GuestRoomLinkUpdated( id, transient, + live, featureConfigMapper.fromDTO(featureConfigUpdatedDTO.data as FeatureConfigData.ConversationGuestLinks) ) is FeatureConfigData.E2EI -> Event.FeatureConfig.MLSE2EIUpdated( id, transient, + live, featureConfigMapper.fromDTO(featureConfigUpdatedDTO.data as FeatureConfigData.E2EI) ) - else -> Event.FeatureConfig.UnknownFeatureUpdated(id, transient) + else -> Event.FeatureConfig.UnknownFeatureUpdated(id, transient, live) } private fun conversationDeleted( id: String, deletedConversationDTO: EventContentDTO.Conversation.DeletedConversationDTO, - transient: Boolean + transient: Boolean, + live: Boolean ) = Event.Conversation.DeletedConversation( id = id, conversationId = deletedConversationDTO.qualifiedConversation.toModel(), senderUserId = deletedConversationDTO.qualifiedFrom.toModel(), transient = transient, + live = live, timestampIso = deletedConversationDTO.time ) fun conversationRenamed( id: String, event: EventContentDTO.Conversation.ConversationRenameDTO, - transient: Boolean + transient: Boolean, + live: Boolean ) = Event.Conversation.RenamedConversation( id = id, conversationId = event.qualifiedConversation.toModel(), senderUserId = event.qualifiedFrom.toModel(), conversationName = event.updateNameData.conversationName, transient = transient, + live = live, timestampIso = event.time, ) private fun teamMemberJoined( id: String, event: EventContentDTO.Team.MemberJoin, - transient: Boolean + transient: Boolean, + live: Boolean ) = Event.Team.MemberJoin( id = id, teamId = event.teamId, transient = transient, + live = live, memberId = event.teamMember.nonQualifiedUserId ) private fun teamMemberLeft( id: String, event: EventContentDTO.Team.MemberLeave, - transient: Boolean + transient: Boolean, + live: Boolean ) = Event.Team.MemberLeave( id = id, teamId = event.teamId, memberId = event.teamMember.nonQualifiedUserId, transient = transient, + live = live, timestampIso = event.time ) private fun teamMemberUpdate( id: String, event: EventContentDTO.Team.MemberUpdate, - transient: Boolean + transient: Boolean, + live: Boolean ) = Event.Team.MemberUpdate( id = id, teamId = event.teamId, memberId = event.permissionsResponse.nonQualifiedUserId, transient = transient, + live = live, permissionCode = event.permissionsResponse.permissions.own ) private fun teamUpdate( id: String, event: EventContentDTO.Team.Update, - transient: Boolean + transient: Boolean, + live: Boolean ) = Event.Team.Update( id = id, teamId = event.teamId, icon = event.teamUpdate.icon, transient = transient, + live = live, name = event.teamUpdate.name ) private fun userUpdate( id: String, event: EventContentDTO.User.UpdateDTO, - transient: Boolean + transient: Boolean, + live: Boolean ) = Event.User.Update( id = id, - userId = event.userData.nonQualifiedUserId, + userId = UserId(event.userData.nonQualifiedUserId, selfUserId.domain), accentId = event.userData.accentId, ssoIdDeleted = event.userData.ssoIdDeleted, name = event.userData.name, @@ -569,7 +660,9 @@ class EventMapper( email = event.userData.email, previewAssetId = event.userData.assets?.getPreviewAssetOrNull()?.key, transient = transient, - completeAssetId = event.userData.assets?.getCompleteAssetOrNull()?.key + live = live, + completeAssetId = event.userData.assets?.getCompleteAssetOrNull()?.key, + supportedProtocols = event.userData.supportedProtocols?.toModel() ) } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/event/EventRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/event/EventRepository.kt index 655199e9d80..386b2e222ba 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/event/EventRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/event/EventRepository.kt @@ -22,6 +22,7 @@ import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.NetworkFailure import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.data.conversation.ClientId +import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.di.MapperProvider import com.wire.kalium.logic.feature.CurrentClientIdProvider import com.wire.kalium.logic.functional.Either @@ -81,7 +82,8 @@ class EventDataSource( private val notificationApi: NotificationApi, private val metadataDAO: MetadataDAO, private val currentClientId: CurrentClientIdProvider, - private val eventMapper: EventMapper = MapperProvider.eventMapper() + private val selfUserId: UserId, + private val eventMapper: EventMapper = MapperProvider.eventMapper(selfUserId) ) : EventRepository { // TODO(edge-case): handle Missing notification response (notify user that some messages are missing) @@ -108,7 +110,7 @@ class EventDataSource( } is WebSocketEvent.BinaryPayloadReceived -> { - eventMapper.fromDTO(webSocketEvent.payload).asFlow().map { WebSocketEvent.BinaryPayloadReceived(it) } + eventMapper.fromDTO(webSocketEvent.payload, true).asFlow().map { WebSocketEvent.BinaryPayloadReceived(it) } } } }.flattenConcat() diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/featureConfig/FeatureConfigMapper.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/featureConfig/FeatureConfigMapper.kt index 22e98223345..c3c87e02e67 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/featureConfig/FeatureConfigMapper.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/featureConfig/FeatureConfigMapper.kt @@ -19,14 +19,19 @@ package com.wire.kalium.logic.data.featureConfig import com.wire.kalium.logic.data.id.PlainId +import com.wire.kalium.logic.data.user.SupportedProtocol +import com.wire.kalium.logic.data.user.toModel import com.wire.kalium.network.api.base.authenticated.featureConfigs.FeatureConfigData import com.wire.kalium.network.api.base.authenticated.featureConfigs.FeatureConfigResponse import com.wire.kalium.network.api.base.authenticated.featureConfigs.FeatureFlagStatusDTO +import com.wire.kalium.network.api.base.authenticated.featureConfigs.MLSMigrationConfigDTO +import com.wire.kalium.persistence.config.MLSMigrationEntity interface FeatureConfigMapper { fun fromDTO(featureConfigResponse: FeatureConfigResponse): FeatureConfigModel fun fromDTO(status: FeatureFlagStatusDTO): Status fun fromDTO(data: FeatureConfigData.MLS?): MLSModel + fun fromDTO(data: FeatureConfigData.MLSMigration): MLSMigrationModel fun fromDTO(data: FeatureConfigData.AppLock): AppLockModel fun fromDTO(data: FeatureConfigData.ClassifiedDomains): ClassifiedDomainsModel fun fromDTO(data: FeatureConfigData.SelfDeletingMessages): SelfDeletingMessagesModel @@ -34,6 +39,8 @@ interface FeatureConfigMapper { fun fromDTO(data: FeatureConfigData.ConferenceCalling): ConferenceCallingModel fun fromDTO(data: FeatureConfigData.ConversationGuestLinks): ConfigsStatusModel fun fromDTO(data: FeatureConfigData.E2EI?): E2EIModel + fun fromModel(status: Status): FeatureFlagStatusDTO + fun fromModel(model: MLSMigrationModel): FeatureConfigData.MLSMigration } class FeatureConfigMapperImpl : FeatureConfigMapper { @@ -56,7 +63,8 @@ class FeatureConfigMapperImpl : FeatureConfigMapper { ssoModel = ConfigsStatusModel(fromDTO(sso.status)), validateSAMLEmailsModel = ConfigsStatusModel(fromDTO(validateSAMLEmails.status)), mlsModel = fromDTO(mls), - e2EIModel = fromDTO(mlsE2EI) + e2EIModel = fromDTO(mlsE2EI), + mlsMigrationModel = mlsMigration?.let { fromDTO(it) } ) } @@ -70,13 +78,25 @@ class FeatureConfigMapperImpl : FeatureConfigMapper { data?.let { MLSModel( it.config.protocolToggleUsers.map { userId -> PlainId(userId) }, + it.config.defaultProtocol.toModel(), + it.config.supportedProtocols.map { it.toModel() }.toSet(), fromDTO(it.status) ) } ?: MLSModel( listOf(), + SupportedProtocol.PROTEUS, + setOf(SupportedProtocol.PROTEUS), Status.DISABLED ) + @Suppress("MagicNumber") + override fun fromDTO(data: FeatureConfigData.MLSMigration): MLSMigrationModel = + MLSMigrationModel( + data.config.startTime, + data.config.finaliseRegardlessAfter, + fromDTO(data.status) + ) + override fun fromDTO(data: FeatureConfigData.AppLock): AppLockModel = AppLockModel( AppLockConfigModel(data.config.enforceAppLock, data.config.inactivityTimeoutSecs), @@ -118,4 +138,32 @@ class FeatureConfigMapperImpl : FeatureConfigMapper { ), fromDTO(data?.status ?: FeatureFlagStatusDTO.DISABLED) ) + override fun fromModel(status: Status): FeatureFlagStatusDTO = + when (status) { + Status.ENABLED -> FeatureFlagStatusDTO.ENABLED + Status.DISABLED -> FeatureFlagStatusDTO.DISABLED + } + + override fun fromModel(model: MLSMigrationModel): FeatureConfigData.MLSMigration = + FeatureConfigData.MLSMigration( + MLSMigrationConfigDTO( + model.startTime, + model.endTime + ), + fromModel(model.status) + ) } + +fun MLSMigrationModel.toEntity(): MLSMigrationEntity = + MLSMigrationEntity( + status = status.equals(Status.ENABLED), + startTime = startTime, + endTime = endTime + ) + +fun MLSMigrationEntity.toModel(): MLSMigrationModel = + MLSMigrationModel( + status = if (status) Status.ENABLED else Status.DISABLED, + startTime = startTime, + endTime = endTime + ) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/featureConfig/FeatureConfigModel.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/featureConfig/FeatureConfigModel.kt index ba514ec7d1e..b4c35f77003 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/featureConfig/FeatureConfigModel.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/featureConfig/FeatureConfigModel.kt @@ -20,6 +20,8 @@ package com.wire.kalium.logic.data.featureConfig import com.wire.kalium.logic.data.id.PlainId import com.wire.kalium.util.time.Second +import com.wire.kalium.logic.data.user.SupportedProtocol +import kotlinx.datetime.Instant data class FeatureConfigModel( val appLockModel: AppLockModel, @@ -36,7 +38,8 @@ data class FeatureConfigModel( val ssoModel: ConfigsStatusModel, val validateSAMLEmailsModel: ConfigsStatusModel, val mlsModel: MLSModel, - val e2EIModel: E2EIModel + val e2EIModel: E2EIModel, + val mlsMigrationModel: MLSMigrationModel? ) enum class Status { @@ -78,6 +81,14 @@ data class SelfDeletingMessagesConfigModel( data class MLSModel( val allowedUsers: List, + val defaultProtocol: SupportedProtocol, + val supportedProtocols: Set, + val status: Status +) + +data class MLSMigrationModel( + val startTime: Instant?, + val endTime: Instant?, val status: Status ) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/Message.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/Message.kt index d3a8cd5b66e..d8c764d94ef 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/Message.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/Message.kt @@ -305,6 +305,10 @@ sealed interface Message { typeKey to "conversationMightLostHistory" ) + MessageContent.HistoryLostProtocolChanged -> mutableMapOf( + typeKey to "conversationMightLostHistoryProtocolChanged" + ) + is MessageContent.ConversationMessageTimerChanged -> mutableMapOf( typeKey to "conversationMessageTimerChanged" ) @@ -336,9 +340,14 @@ sealed interface Message { is MessageContent.FederationStopped.ConnectionRemoved -> mutableMapOf( typeKey to "federationConnectionRemoved" ) + is MessageContent.FederationStopped.Removed -> mutableMapOf( typeKey to "federationRemoved" ) + + is MessageContent.ConversationProtocolChanged -> mutableMapOf( + typeKey to "conversationProtocolChanged" + ) } val standardProperties = mapOf( diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/MessageContent.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/MessageContent.kt index f73707d7f18..4a5523fb94a 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/MessageContent.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/MessageContent.kt @@ -20,6 +20,7 @@ package com.wire.kalium.logic.data.message import com.wire.kalium.logger.obfuscateId import com.wire.kalium.logic.data.conversation.ClientId +import com.wire.kalium.logic.data.conversation.Conversation import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.id.MessageButtonId import com.wire.kalium.logic.data.id.MessageId @@ -273,6 +274,10 @@ sealed class MessageContent { val messageTimer: Long? ) : System() + data class ConversationProtocolChanged( + val protocol: Conversation.Protocol + ) : System() + // we can add other types to be processed, but signaling ones shouldn't be persisted object Ignored : Signaling() // messages that aren't processed in any way @@ -289,6 +294,8 @@ sealed class MessageContent { object CryptoSessionReset : System() + object HistoryLostProtocolChanged : System() + object HistoryLost : System() object ConversationCreated : System() data object ConversationDegradedMLS : System() @@ -328,6 +335,7 @@ fun MessageContent?.getType() = when (this) { is MessageContent.ConversationRenamed -> "ConversationRenamed" is MessageContent.CryptoSessionReset -> "CryptoSessionReset" is MessageContent.HistoryLost -> "HistoryLost" + is MessageContent.HistoryLostProtocolChanged -> "HistoryLostProtocolChanged" is MessageContent.MemberChange.Added -> "MemberChange.Added" is MessageContent.MemberChange.Removed -> "MemberChange.Removed" is MessageContent.MissedCall -> "MissedCall" @@ -345,6 +353,7 @@ fun MessageContent?.getType() = when (this) { is MessageContent.MemberChange.FederationRemoved -> "MemberChange.FederationRemoved" is MessageContent.FederationStopped.ConnectionRemoved -> "Federation.ConnectionRemoved" is MessageContent.FederationStopped.Removed -> "Federation.Removed" + is MessageContent.ConversationProtocolChanged -> "ConversationProtocolChanged" is MessageContent.Unknown -> "Unknown" MessageContent.ConversationVerifiedMLS -> "ConversationVerification.Verified.MLS" MessageContent.ConversationVerifiedProteus -> "ConversationVerification.Verified.Proteus" diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/MessageMapper.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/MessageMapper.kt index dd01e37ca91..24b4fdad4a0 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/MessageMapper.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/MessageMapper.kt @@ -20,6 +20,8 @@ package com.wire.kalium.logic.data.message import com.wire.kalium.logic.data.asset.AssetMapper import com.wire.kalium.logic.data.conversation.ClientId +import com.wire.kalium.logic.data.conversation.toDao +import com.wire.kalium.logic.data.conversation.toModel import com.wire.kalium.logic.data.id.toDao import com.wire.kalium.logic.data.id.toModel import com.wire.kalium.logic.data.message.AssetContent.AssetMetadata.Audio @@ -251,6 +253,7 @@ class MessageMapperImpl( MessageEntity.ContentType.NEW_CONVERSATION_RECEIPT_MODE -> null MessageEntity.ContentType.CONVERSATION_RECEIPT_MODE_CHANGED -> null MessageEntity.ContentType.HISTORY_LOST -> null + MessageEntity.ContentType.HISTORY_LOST_PROTOCOL_CHANGED -> null MessageEntity.ContentType.CONVERSATION_MESSAGE_TIMER_CHANGED -> null MessageEntity.ContentType.CONVERSATION_CREATED -> null MessageEntity.ContentType.MLS_WRONG_EPOCH_WARNING -> null @@ -260,6 +263,7 @@ class MessageMapperImpl( MessageEntity.ContentType.FEDERATION -> null MessageEntity.ContentType.CONVERSATION_VERIFIED_MLS -> null MessageEntity.ContentType.CONVERSATION_VERIFIED_PREOTEUS -> null + MessageEntity.ContentType.CONVERSATION_PROTOCOL_CHANGED -> null } } @@ -361,6 +365,7 @@ class MessageMapperImpl( is MessageEntityContent.NewConversationReceiptMode -> MessageContent.NewConversationReceiptMode(receiptMode) is MessageEntityContent.ConversationReceiptModeChanged -> MessageContent.ConversationReceiptModeChanged(receiptMode) is MessageEntityContent.HistoryLost -> MessageContent.HistoryLost + is MessageEntityContent.HistoryLostProtocolChanged -> MessageContent.HistoryLostProtocolChanged is MessageEntityContent.ConversationMessageTimerChanged -> MessageContent.ConversationMessageTimerChanged(messageTimer) is MessageEntityContent.ConversationCreated -> MessageContent.ConversationCreated is MessageEntityContent.MLSWrongEpochWarning -> MessageContent.MLSWrongEpochWarning @@ -372,6 +377,7 @@ class MessageMapperImpl( MessageEntity.FederationType.DELETE -> MessageContent.FederationStopped.Removed(domainList.first()) MessageEntity.FederationType.CONNECTION_REMOVED -> MessageContent.FederationStopped.ConnectionRemoved(domainList) } + is MessageEntityContent.ConversationProtocolChanged -> MessageContent.ConversationProtocolChanged(protocol.toModel()) } } @@ -584,4 +590,6 @@ fun MessageContent.System.toMessageEntityContent(): MessageEntityContent.System MessageContent.ConversationVerifiedMLS -> MessageEntityContent.ConversationVerifiedMLS MessageContent.ConversationVerifiedProteus -> MessageEntityContent.ConversationVerifiedProteus + is MessageContent.ConversationProtocolChanged -> MessageEntityContent.ConversationProtocolChanged(protocol.toDao()) + MessageContent.HistoryLostProtocolChanged -> MessageEntityContent.HistoryLostProtocolChanged } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/MessageRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/MessageRepository.kt index 9148acfaa28..f0252254f71 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/MessageRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/MessageRepository.kt @@ -150,7 +150,10 @@ interface MessageRepository { messageOption: BroadcastMessageOption ): Either - suspend fun sendMLSMessage(conversationId: ConversationId, message: MLSMessageApi.Message): Either + suspend fun sendMLSMessage( + conversationId: ConversationId, + message: MLSMessageApi.Message + ): Either suspend fun getAllPendingMessagesFromUser(senderUserId: UserId): Either> suspend fun getPendingConfirmationMessagesByConversationAfterDate( @@ -211,6 +214,11 @@ interface MessageRepository { usersWithFailedDeliveryList: List ): Either + suspend fun moveMessagesToAnotherConversation( + originalConversation: ConversationId, + targetConversation: ConversationId + ): Either + val extensions: MessageRepositoryExtensions } @@ -296,12 +304,18 @@ class MessageDataSource( messageDAO.deleteMessage(messageUuid, conversationId.toDao()) } - override suspend fun markMessageAsDeleted(messageUuid: String, conversationId: ConversationId): Either = + override suspend fun markMessageAsDeleted( + messageUuid: String, + conversationId: ConversationId + ): Either = wrapStorageRequest { messageDAO.markMessageAsDeleted(id = messageUuid, conversationsId = conversationId.toDao()) } - override suspend fun getMessageById(conversationId: ConversationId, messageUuid: String): Either = + override suspend fun getMessageById( + conversationId: ConversationId, + messageUuid: String + ): Either = wrapStorageRequest { messageDAO.getMessageById(messageUuid, conversationId.toDao()) }.map(messageMapper::fromEntityToMessage) @@ -316,7 +330,11 @@ class MessageDataSource( visibility.map { it.toEntityVisibility() } ).map { messageList -> messageList.map(messageMapper::fromEntityToMessage) } - override suspend fun updateMessageStatus(messageStatus: MessageEntity.Status, conversationId: ConversationId, messageUuid: String) = + override suspend fun updateMessageStatus( + messageStatus: MessageEntity.Status, + conversationId: ConversationId, + messageUuid: String + ) = wrapStorageRequest { messageDAO.updateMessageStatus( status = messageStatus, @@ -369,11 +387,12 @@ class MessageDataSource( envelope: MessageEnvelope, messageTarget: MessageTarget ): Either { - val recipientMap: Map> = envelope.recipients.associate { recipientEntry -> - recipientEntry.userId.toApi() to recipientEntry.clientPayloads.associate { clientPayload -> - clientPayload.clientId.value to clientPayload.payload.data + val recipientMap: Map> = + envelope.recipients.associate { recipientEntry -> + recipientEntry.userId.toApi() to recipientEntry.clientPayloads.associate { clientPayload -> + clientPayload.clientId.value to clientPayload.payload.data + } } - } return wrapApiRequest { messageApi.qualifiedSendMessage( @@ -419,11 +438,12 @@ class MessageDataSource( envelope: MessageEnvelope, messageOption: BroadcastMessageOption ): Either { - val recipientMap: Map> = envelope.recipients.associate { recipientEntry -> - recipientEntry.userId.toApi() to recipientEntry.clientPayloads.associate { clientPayload -> - clientPayload.clientId.value to clientPayload.payload.data + val recipientMap: Map> = + envelope.recipients.associate { recipientEntry -> + recipientEntry.userId.toApi() to recipientEntry.clientPayloads.associate { clientPayload -> + clientPayload.clientId.value to clientPayload.payload.data + } } - } val option = when (messageOption) { is BroadcastMessageOption.IgnoreSome -> MessageApi.QualifiedMessageOption.IgnoreSome(messageOption.userIDs.map { it.toApi() }) @@ -458,17 +478,21 @@ class MessageDataSource( }) } - override suspend fun sendMLSMessage(conversationId: ConversationId, message: MLSMessageApi.Message): Either = + override suspend fun sendMLSMessage( + conversationId: ConversationId, + message: MLSMessageApi.Message + ): Either = wrapApiRequest { mlsMessageApi.sendMessage(message) }.flatMap { response -> Either.Right(sendMessagePartialFailureMapper.fromMlsDTO(response)) } - override suspend fun getAllPendingMessagesFromUser(senderUserId: UserId): Either> = wrapStorageRequest { - messageDAO.getAllPendingMessagesFromUser(senderUserId.toDao()) - .map(messageMapper::fromEntityToMessage) - } + override suspend fun getAllPendingMessagesFromUser(senderUserId: UserId): Either> = + wrapStorageRequest { + messageDAO.getAllPendingMessagesFromUser(senderUserId.toDao()) + .map(messageMapper::fromEntityToMessage) + } override suspend fun getPendingConfirmationMessagesByConversationAfterDate( conversationId: ConversationId, @@ -540,9 +564,10 @@ class MessageDataSource( ) } - override suspend fun getEphemeralMessagesMarkedForDeletion(): Either> = wrapStorageRequest { - messageDAO.getEphemeralMessagesMarkedForDeletion().map(messageMapper::fromEntityToMessage) - } + override suspend fun getEphemeralMessagesMarkedForDeletion(): Either> = + wrapStorageRequest { + messageDAO.getEphemeralMessagesMarkedForDeletion().map(messageMapper::fromEntityToMessage) + } override suspend fun markSelfDeletionStartDate( conversationId: ConversationId, @@ -602,4 +627,13 @@ class MessageDataSource( ) } + override suspend fun moveMessagesToAnotherConversation( + originalConversation: ConversationId, + targetConversation: ConversationId + ): Either = wrapStorageRequest { + messageDAO.moveMessages( + from = originalConversation.toDao(), + to = targetConversation.toDao() + ) + } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/PersistMessageUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/PersistMessageUseCase.kt index c414067bf3e..823f05358e0 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/PersistMessageUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/PersistMessageUseCase.kt @@ -92,6 +92,7 @@ internal class PersistMessageUseCaseImpl( is MessageContent.NewConversationReceiptMode -> false is MessageContent.ConversationReceiptModeChanged -> false is MessageContent.HistoryLost -> false + is MessageContent.HistoryLostProtocolChanged -> false is MessageContent.ConversationMessageTimerChanged -> false is MessageContent.MemberChange.CreationAdded -> false is MessageContent.MemberChange.FailedToAdd -> false @@ -107,5 +108,6 @@ internal class PersistMessageUseCaseImpl( is MessageContent.MemberChange.FederationRemoved -> false is MessageContent.FederationStopped.ConnectionRemoved -> false is MessageContent.FederationStopped.Removed -> false + is MessageContent.ConversationProtocolChanged -> false } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/SystemMessageInserter.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/SystemMessageInserter.kt new file mode 100644 index 00000000000..0ea93350798 --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/SystemMessageInserter.kt @@ -0,0 +1,95 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.data.message + +import com.benasher44.uuid.uuid4 +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.data.conversation.Conversation +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.util.DateTimeUtil + +internal interface SystemMessageInserter { + suspend fun insertProtocolChangedSystemMessage( + conversationId: ConversationId, + senderUserId: UserId, + protocol: Conversation.Protocol + ) + suspend fun insertHistoryLostProtocolChangedSystemMessage( + conversationId: ConversationId + ) + + suspend fun insertLostCommitSystemMessage(conversationId: ConversationId, dateIso: String): Either +} + +internal class SystemMessageInserterImpl( + private val selfUserId: UserId, + private val persistMessage: PersistMessageUseCase +) : SystemMessageInserter { + override suspend fun insertProtocolChangedSystemMessage( + conversationId: ConversationId, + senderUserId: UserId, + protocol: Conversation.Protocol + ) { + val message = Message.System( + uuid4().toString(), + MessageContent.ConversationProtocolChanged( + protocol = protocol + ), + conversationId, + DateTimeUtil.currentIsoDateTimeString(), + senderUserId, + Message.Status.Sent, + Message.Visibility.VISIBLE, + null + ) + + persistMessage(message) + } + + override suspend fun insertHistoryLostProtocolChangedSystemMessage(conversationId: ConversationId) { + val message = Message.System( + uuid4().toString(), + MessageContent.HistoryLostProtocolChanged, + conversationId, + DateTimeUtil.currentIsoDateTimeString(), + selfUserId, + Message.Status.Sent, + Message.Visibility.VISIBLE, + null + ) + + persistMessage(message) + } + + override suspend fun insertLostCommitSystemMessage(conversationId: ConversationId, dateIso: String): Either { + val mlsEpochWarningMessage = Message.System( + id = uuid4().toString(), + content = MessageContent.MLSWrongEpochWarning, + conversationId = conversationId, + date = dateIso, + senderUserId = selfUserId, + status = Message.Status.Read(0), + visibility = Message.Visibility.VISIBLE, + senderUserName = null, + expirationData = null + ) + return persistMessage(mlsEpochWarningMessage) + } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/publicuser/PublicUserMapper.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/publicuser/PublicUserMapper.kt index 46a1914da5b..effb61eb294 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/publicuser/PublicUserMapper.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/publicuser/PublicUserMapper.kt @@ -29,8 +29,11 @@ import com.wire.kalium.logic.data.user.ConnectionState import com.wire.kalium.logic.data.user.ConnectionStateMapper import com.wire.kalium.logic.data.user.OtherUser import com.wire.kalium.logic.data.user.OtherUserMinimized +import com.wire.kalium.logic.data.user.SupportedProtocol import com.wire.kalium.logic.data.user.UserAvailabilityStatus import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.data.user.toDao +import com.wire.kalium.logic.data.user.toModel import com.wire.kalium.logic.data.user.type.DomainUserTypeMapper import com.wire.kalium.logic.data.user.type.UserEntityTypeMapper import com.wire.kalium.logic.data.user.type.UserType @@ -84,7 +87,9 @@ class PublicUserMapperImpl( deleted = userEntity.deleted, expiresAt = userEntity.expiresAt, defederated = userEntity.defederated, - isProteusVerified = false + isProteusVerified = false, + supportedProtocols = userEntity.supportedProtocols?.toModel(), + activeOneOnOneConversationId = userEntity.activeOneOnOneConversationId?.toModel() ) override fun fromUserDetailsEntityToOtherUser(userDetailsEntity: UserDetailsEntity) = OtherUser( @@ -104,7 +109,8 @@ class PublicUserMapperImpl( deleted = userDetailsEntity.deleted, expiresAt = userDetailsEntity.expiresAt, defederated = userDetailsEntity.defederated, - isProteusVerified = userDetailsEntity.isProteusVerified + isProteusVerified = userDetailsEntity.isProteusVerified, + supportedProtocols = userDetailsEntity.supportedProtocols?.toModel() ) override fun fromOtherToUserEntity(otherUser: OtherUser): UserEntity = with(otherUser) { @@ -125,7 +131,9 @@ class PublicUserMapperImpl( deleted = deleted, expiresAt = expiresAt, hasIncompleteMetadata = false, - defederated = defederated + defederated = defederated, + supportedProtocols = supportedProtocols?.toDao(), + activeOneOnOneConversationId = activeOneOnOneConversationId?.toDao() ) } @@ -148,7 +156,9 @@ class PublicUserMapperImpl( expiresAt = expiresAt, hasIncompleteMetadata = false, defederated = defederated, - isProteusVerified = otherUser.isProteusVerified + isProteusVerified = otherUser.isProteusVerified, + supportedProtocols = supportedProtocols?.toDao(), + activeOneOnOneConversationId = activeOneOnOneConversationId?.toDao() ) } @@ -180,7 +190,8 @@ class PublicUserMapperImpl( deleted = userDetailResponse.deleted ?: false, expiresAt = userDetailResponse.expiresAt?.toInstant(), defederated = false, - isProteusVerified = false + isProteusVerified = false, + supportedProtocols = userDetailResponse.supportedProtocols?.toModel() ?: setOf(SupportedProtocol.PROTEUS) ) override fun fromEntityToUserSummary(userEntity: UserEntity) = with(userEntity) { diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/sync/SlowSyncStatus.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/sync/SlowSyncStatus.kt index 3ad4881a76b..59603aeddf8 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/sync/SlowSyncStatus.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/sync/SlowSyncStatus.kt @@ -34,9 +34,11 @@ sealed interface SlowSyncStatus { enum class SlowSyncStep { SELF_USER, FEATURE_FLAGS, + UPDATE_SUPPORTED_PROTOCOLS, CONVERSATIONS, CONNECTIONS, SELF_TEAM, CONTACTS, - JOINING_MLS_CONVERSATIONS + JOINING_MLS_CONVERSATIONS, + RESOLVE_ONE_ON_ONE_PROTOCOLS } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/team/TeamRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/team/TeamRepository.kt index 3f0289b3c27..a043436dfb5 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/team/TeamRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/team/TeamRepository.kt @@ -91,19 +91,16 @@ internal class TeamDataSource( */ if (teamMemberList.hasMore.not()) { teamMemberList.members.map { teamMember -> - userMapper.fromTeamMemberToDaoModel( - teamId = teamId, - nonQualifiedUserId = teamMember.nonQualifiedUserId, - permissionCode = teamMember.permissions?.own, - userDomain = userDomain, - ) + val userId = QualifiedIDEntity(teamMember.nonQualifiedUserId, userDomain) + val userType = userTypeEntityTypeMapper.teamRoleCodeToUserType(teamMember.permissions?.own) + userId to userType } } else { listOf() } }.flatMap { teamMembers -> wrapStorageRequest { - userDAO.upsertTeamMembersTypes(teamMembers) + userDAO.upsertTeamMemberUserTypes(teamMembers.toMap()) } } @@ -123,13 +120,9 @@ internal class TeamDataSource( override suspend fun updateMemberRole(teamId: String, userId: String, permissionCode: Int?): Either { return wrapStorageRequest { - val user = userMapper.fromTeamMemberToDaoModel( - teamId = TeamId(teamId), - nonQualifiedUserId = userId, - userDomain = selfUserId.domain, - permissionCode = permissionCode - ) - userDAO.upsertTeamMembersTypes(listOf(user)) + userDAO.upsertTeamMemberUserTypes(mapOf( + QualifiedIDEntity(userId, selfUserId.domain) to userTypeEntityTypeMapper.teamRoleCodeToUserType(permissionCode) + )) } } @@ -139,22 +132,16 @@ internal class TeamDataSource( teamId = teamId, userId = userId, ) - }.flatMap { _ -> + }.flatMap { member -> wrapApiRequest { userDetailsApi.getUserInfo(userId = QualifiedID(userId, selfUserId.domain)) } .flatMap { userProfileDTO -> wrapStorageRequest { val userEntity = userMapper.fromUserProfileDtoToUserEntity( userProfile = userProfileDTO, connectionState = ConnectionEntity.State.ACCEPTED, - userTypeEntity = userTypeEntityTypeMapper.fromTeamAndDomain( - otherUserDomain = userProfileDTO.id.domain, - selfUserTeamId = teamId, - otherUserTeamId = userProfileDTO.teamId, - selfUserDomain = selfUserId.domain, - isService = userProfileDTO.service != null - ) + userTypeEntity = userTypeEntityTypeMapper.teamRoleCodeToUserType(member.permissions?.own) ) - userDAO.insertUser(userEntity) + userDAO.upsertUser(userEntity) } } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/user/UserMapper.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/user/UserMapper.kt index b9f86bc910d..0fd650a89b4 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/user/UserMapper.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/user/UserMapper.kt @@ -24,12 +24,11 @@ import com.wire.kalium.logic.data.id.NetworkQualifiedId import com.wire.kalium.logic.data.id.TeamId import com.wire.kalium.logic.data.id.toDao import com.wire.kalium.logic.data.id.toModel -import com.wire.kalium.logic.data.user.type.UserEntityTypeMapper import com.wire.kalium.logic.di.MapperProvider import com.wire.kalium.network.api.base.authenticated.self.UserUpdateRequest import com.wire.kalium.network.api.base.model.AssetSizeDTO -import com.wire.kalium.network.api.base.model.NonQualifiedUserId import com.wire.kalium.network.api.base.model.SelfUserDTO +import com.wire.kalium.network.api.base.model.SupportedProtocolDTO import com.wire.kalium.network.api.base.model.UserAssetDTO import com.wire.kalium.network.api.base.model.UserAssetTypeDTO import com.wire.kalium.network.api.base.model.UserProfileDTO @@ -37,7 +36,9 @@ import com.wire.kalium.network.api.base.model.getCompleteAssetOrNull import com.wire.kalium.network.api.base.model.getPreviewAssetOrNull import com.wire.kalium.persistence.dao.BotIdEntity import com.wire.kalium.persistence.dao.ConnectionEntity +import com.wire.kalium.persistence.dao.PartialUserEntity import com.wire.kalium.persistence.dao.QualifiedIDEntity +import com.wire.kalium.persistence.dao.SupportedProtocolEntity import com.wire.kalium.persistence.dao.UserAvailabilityStatusEntity import com.wire.kalium.persistence.dao.UserDetailsEntity import com.wire.kalium.persistence.dao.UserEntity @@ -69,15 +70,7 @@ interface UserMapper { updateRequest: UserUpdateRequest ): UserEntity - fun fromTeamMemberToDaoModel( - teamId: TeamId, - nonQualifiedUserId: NonQualifiedUserId, - permissionCode: Int?, - userDomain: String, - ): UserEntity - - fun fromUserUpdateEventToUserEntity(event: Event.User.Update, userEntity: UserEntity): UserEntity - fun fromUserUpdateEventToUserEntity(event: Event.User.Update, userEntity: UserDetailsEntity): UserEntity + fun fromUserUpdateEventToPartialUserEntity(event: Event.User.Update): PartialUserEntity fun fromUserProfileDtoToUserEntity( userProfile: UserProfileDTO, @@ -102,8 +95,7 @@ interface UserMapper { internal class UserMapperImpl( private val idMapper: IdMapper = MapperProvider.idMapper(), private val availabilityStatusMapper: AvailabilityStatusMapper = MapperProvider.availabilityStatusMapper(), - private val connectionStateMapper: ConnectionStateMapper = MapperProvider.connectionStateMapper(), - private val userEntityTypeMapper: UserEntityTypeMapper = MapperProvider.userTypeEntityMapper() + private val connectionStateMapper: ConnectionStateMapper = MapperProvider.connectionStateMapper() ) : UserMapper { override fun fromUserEntityToSelfUser(userEntity: UserEntity) = with(userEntity) { @@ -119,7 +111,8 @@ internal class UserMapperImpl( previewAssetId?.toModel(), completeAssetId?.toModel(), availabilityStatusMapper.fromDaoAvailabilityStatusToModel(availabilityStatus), - expiresAt = expiresAt + expiresAt = expiresAt, + supportedProtocols?.toModel() ) } @@ -136,7 +129,8 @@ internal class UserMapperImpl( previewAssetId?.toModel(), completeAssetId?.toModel(), availabilityStatusMapper.fromDaoAvailabilityStatusToModel(availabilityStatus), - expiresAt = expiresAt + expiresAt = expiresAt, + supportedProtocols?.toModel() ) } @@ -157,7 +151,9 @@ internal class UserMapperImpl( botService = null, deleted = false, expiresAt = expiresAt, - defederated = false + defederated = false, + supportedProtocols = supportedProtocols?.toDao() ?: setOf(SupportedProtocolEntity.PROTEUS), + activeOneOnOneConversationId = null ) } @@ -179,7 +175,9 @@ internal class UserMapperImpl( deleted = false, expiresAt = expiresAt, defederated = false, - isProteusVerified = false + isProteusVerified = false, + supportedProtocols = supportedProtocols?.toDao() ?: setOf(SupportedProtocolEntity.PROTEUS), + activeOneOnOneConversationId = null ) } @@ -199,7 +197,9 @@ internal class UserMapperImpl( botService = null, deleted = userDTO.deleted ?: false, expiresAt = expiresAt?.toInstant(), - defederated = false + defederated = false, + supportedProtocols = supportedProtocols?.toDao() ?: setOf(SupportedProtocolEntity.PROTEUS), + activeOneOnOneConversationId = null ) } @@ -220,7 +220,9 @@ internal class UserMapperImpl( deleted = userDTO.deleted ?: false, expiresAt = expiresAt?.toInstant(), defederated = false, - isProteusVerified = false + isProteusVerified = false, + supportedProtocols = supportedProtocols?.toDao() ?: setOf(SupportedProtocolEntity.PROTEUS), + activeOneOnOneConversationId = null ) } @@ -253,37 +255,6 @@ internal class UserMapperImpl( ) ) - /** - * Null and default/hardcoded values will be replaced later when fetching known users. - */ - override fun fromTeamMemberToDaoModel( - teamId: TeamId, - nonQualifiedUserId: NonQualifiedUserId, - permissionCode: Int?, - userDomain: String, - ): UserEntity = - UserEntity( - id = QualifiedIDEntity( - value = nonQualifiedUserId, - domain = userDomain - ), - name = null, - handle = null, - email = null, - phone = null, - accentId = 1, - team = teamId.value, - connectionStatus = ConnectionEntity.State.ACCEPTED, - previewAssetId = null, - completeAssetId = null, - availabilityStatus = UserAvailabilityStatusEntity.NONE, - userType = userEntityTypeMapper.teamRoleCodeToUserType(permissionCode), - botService = null, - deleted = false, - expiresAt = null, - defederated = false - ) - override fun fromUserProfileDtoToUserEntity( userProfile: UserProfileDTO, connectionState: ConnectionEntity.State, @@ -306,7 +277,9 @@ internal class UserMapperImpl( botService = userProfile.service?.let { BotIdEntity(it.id, it.provider) }, deleted = userProfile.deleted ?: false, expiresAt = userProfile.expiresAt?.toInstant(), - defederated = false + defederated = false, + supportedProtocols = userProfile.supportedProtocols?.toDao() ?: setOf(SupportedProtocolEntity.PROTEUS), + activeOneOnOneConversationId = null ) override fun fromUserProfileDtoToUserDetailsEntity( @@ -332,48 +305,21 @@ internal class UserMapperImpl( deleted = userProfile.deleted ?: false, expiresAt = userProfile.expiresAt?.toInstant(), defederated = false, - isProteusVerified = false + isProteusVerified = false, + supportedProtocols = userProfile.supportedProtocols?.toDao() ?: setOf(SupportedProtocolEntity.PROTEUS), + activeOneOnOneConversationId = null ) - override fun fromUserUpdateEventToUserEntity(event: Event.User.Update, userEntity: UserEntity): UserEntity { - return userEntity.let { persistedEntity -> - persistedEntity.copy( - email = event.email ?: persistedEntity.email, - name = event.name ?: persistedEntity.name, - handle = event.handle ?: persistedEntity.handle, - accentId = event.accentId ?: persistedEntity.accentId, - previewAssetId = event.previewAssetId?.let { QualifiedIDEntity(it, persistedEntity.id.domain) } - ?: persistedEntity.previewAssetId, - completeAssetId = event.completeAssetId?.let { QualifiedIDEntity(it, persistedEntity.id.domain) } - ?: persistedEntity.completeAssetId - ) - } - } - - override fun fromUserUpdateEventToUserEntity(event: Event.User.Update, userEntity: UserDetailsEntity): UserEntity = - userEntity.let { persistedEntity -> - UserEntity( - id = persistedEntity.id, - name = event.name ?: persistedEntity.name, - handle = event.handle ?: persistedEntity.handle, - email = event.email ?: persistedEntity.email, - phone = persistedEntity.phone, - accentId = event.accentId ?: persistedEntity.accentId, - team = persistedEntity.team, - connectionStatus = persistedEntity.connectionStatus, - previewAssetId = event.previewAssetId?.let { QualifiedIDEntity(it, persistedEntity.id.domain) } - ?: persistedEntity.previewAssetId, - completeAssetId = event.completeAssetId?.let { QualifiedIDEntity(it, persistedEntity.id.domain) } - ?: persistedEntity.completeAssetId, - availabilityStatus = persistedEntity.availabilityStatus, - userType = persistedEntity.userType, - botService = persistedEntity.botService, - deleted = persistedEntity.deleted, - hasIncompleteMetadata = persistedEntity.hasIncompleteMetadata, - expiresAt = persistedEntity.expiresAt, - defederated = persistedEntity.defederated - ) - } + override fun fromUserUpdateEventToPartialUserEntity(event: Event.User.Update): PartialUserEntity = + PartialUserEntity( + email = event.email, + name = event.name, + handle = event.handle, + accentId = event.accentId, + previewAssetId = event.previewAssetId?.let { QualifiedIDEntity(it, event.userId.domain) }, + completeAssetId = event.completeAssetId?.let { QualifiedIDEntity(it, event.userId.domain) }, + supportedProtocols = event.supportedProtocols?.toDao() + ) /** * Default values and marked as [UserEntity.hasIncompleteMetadata] = true. @@ -397,7 +343,9 @@ internal class UserMapperImpl( deleted = false, hasIncompleteMetadata = true, expiresAt = null, - defederated = false + defederated = false, + supportedProtocols = null, + activeOneOnOneConversationId = null ) } @@ -420,7 +368,39 @@ internal class UserMapperImpl( hasIncompleteMetadata = true, expiresAt = null, defederated = false, - isProteusVerified = false + isProteusVerified = false, + supportedProtocols = null, + activeOneOnOneConversationId = null ) } } + +fun SupportedProtocol.toApi() = when (this) { + SupportedProtocol.MLS -> SupportedProtocolDTO.MLS + SupportedProtocol.PROTEUS -> SupportedProtocolDTO.PROTEUS +} + +fun SupportedProtocol.toDao() = when (this) { + SupportedProtocol.MLS -> SupportedProtocolEntity.MLS + SupportedProtocol.PROTEUS -> SupportedProtocolEntity.PROTEUS +} + +fun SupportedProtocolDTO.toModel() = when (this) { + SupportedProtocolDTO.MLS -> SupportedProtocol.MLS + SupportedProtocolDTO.PROTEUS -> SupportedProtocol.PROTEUS +} + +fun SupportedProtocolDTO.toDao() = when (this) { + SupportedProtocolDTO.MLS -> SupportedProtocolEntity.MLS + SupportedProtocolDTO.PROTEUS -> SupportedProtocolEntity.PROTEUS +} + +fun SupportedProtocolEntity.toModel() = when (this) { + SupportedProtocolEntity.MLS -> SupportedProtocol.MLS + SupportedProtocolEntity.PROTEUS -> SupportedProtocol.PROTEUS +} + +fun List.toDao() = this.map { it.toDao() }.toSet() +fun List.toModel() = this.map { it.toModel() }.toSet() +fun Set.toDao() = this.map { it.toDao() }.toSet() +fun Set.toModel() = this.map { it.toModel() }.toSet() diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/user/UserModel.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/user/UserModel.kt index ce501d4b2ee..4699562d063 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/user/UserModel.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/user/UserModel.kt @@ -44,6 +44,7 @@ sealed class User { abstract val completePicture: UserAssetId? abstract val availabilityStatus: UserAvailabilityStatus abstract val expiresAt: Instant? + abstract val supportedProtocols: Set? } // TODO we should extract ConnectionModel and ConnectionState to separate logic AR-1734 @@ -80,6 +81,10 @@ enum class UserAvailabilityStatus { NONE, AVAILABLE, BUSY, AWAY } +enum class SupportedProtocol { + PROTEUS, MLS +} + enum class ConnectionState { /** Default - No connection state */ NOT_CONNECTED, @@ -126,7 +131,8 @@ data class SelfUser( override val previewPicture: UserAssetId?, override val completePicture: UserAssetId?, override val availabilityStatus: UserAvailabilityStatus, - override val expiresAt: Instant? = null + override val expiresAt: Instant? = null, + override val supportedProtocols: Set? ) : User() data class OtherUserMinimized( @@ -149,11 +155,13 @@ data class OtherUser( override val completePicture: UserAssetId?, val userType: UserType, override val availabilityStatus: UserAvailabilityStatus, + override val supportedProtocols: Set?, val botService: BotService?, val deleted: Boolean, val defederated: Boolean, override val expiresAt: Instant? = null, - val isProteusVerified: Boolean + val isProteusVerified: Boolean, + val activeOneOnOneConversationId: ConversationId? = null ) : User() { /** diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/user/UserRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/user/UserRepository.kt index 7aca95cf31f..7257b68e72d 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/user/UserRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/user/UserRepository.kt @@ -28,7 +28,6 @@ import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.id.IdMapper import com.wire.kalium.logic.data.id.NetworkQualifiedId import com.wire.kalium.logic.data.id.QualifiedID -import com.wire.kalium.logic.data.id.QualifiedIdMapper import com.wire.kalium.logic.data.id.toApi import com.wire.kalium.logic.data.id.toDao import com.wire.kalium.logic.data.id.toModel @@ -96,6 +95,7 @@ internal interface UserRepository { suspend fun observeAllKnownUsers(): Flow>> suspend fun getKnownUser(userId: UserId): Flow suspend fun getKnownUserMinimized(userId: UserId): OtherUserMinimized? + suspend fun getUsersWithOneOnOneConversation(): List suspend fun observeUser(userId: UserId): Flow suspend fun userById(userId: UserId): Either suspend fun updateOtherUserAvailabilityStatus(userId: UserId, status: UserAvailabilityStatus) @@ -133,6 +133,10 @@ internal interface UserRepository { * Gets users summary by their ids. */ suspend fun getUsersSummaryByIds(userIds: List): Either> + + suspend fun updateSupportedProtocols(protocols: Set): Either + + suspend fun updateActiveOneOnOneConversation(userId: UserId, conversationId: ConversationId): Either } @Suppress("LongParameterList", "TooManyFunctions") @@ -144,7 +148,6 @@ internal class UserDataSource internal constructor( private val userDetailsApi: UserDetailsApi, private val sessionRepository: SessionRepository, private val selfUserId: UserId, - private val qualifiedIdMapper: QualifiedIdMapper, private val selfTeamIdProvider: SelfTeamIdProvider, private val idMapper: IdMapper = MapperProvider.idMapper(), private val userMapper: UserMapper = MapperProvider.userMapper(), @@ -172,7 +175,7 @@ internal class UserDataSource internal constructor( updateSelfUserProviderAccountInfo(userDTO) .map { userMapper.fromSelfUserDtoToUserEntity(userDTO).copy(connectionStatus = ConnectionEntity.State.ACCEPTED) } .flatMap { userEntity -> - wrapStorageRequest { userDAO.insertUser(userEntity) } + wrapStorageRequest { userDAO.upsertUser(userEntity) } .flatMap { wrapStorageRequest { metadataDAO.insertValue(Json.encodeToString(userEntity.id), SELF_USER_ID_KEY) } } @@ -191,6 +194,11 @@ internal class UserDataSource internal constructor( processFederatedUserRefresh(userId, otherUser) } + override suspend fun getUsersWithOneOnOneConversation(): List { + return userDAO.getUsersWithOneOnOneConversation() + .map(publicUserMapper::fromUserEntityToOtherUser) + } + /** * Only in case of federated users and if it's expired or not cached, we fetch and refresh the user info. */ @@ -234,9 +242,7 @@ internal class UserDataSource internal constructor( } private suspend fun persistIncompleteUsers(usersFailed: List) = wrapStorageRequest { - usersFailed.map { userMapper.fromFailedUserToEntity(it) }.forEach { - userDAO.insertUser(it) - } + userDAO.insertOrIgnoreUsers(usersFailed.map { userMapper.fromFailedUserToEntity(it) }) } private suspend fun persistUsers(listUserProfileDTO: List) = wrapStorageRequest { @@ -246,12 +252,14 @@ internal class UserDataSource internal constructor( .filter { userProfileDTO -> isTeamMember(selfUserTeamId, userProfileDTO, selfUserDomain) } val otherUsers = listUserProfileDTO .filter { userProfileDTO -> !isTeamMember(selfUserTeamId, userProfileDTO, selfUserDomain) } - userDAO.upsertTeamMembers( + + userDAO.upsertUsers( teamMembers.map { userProfileDTO -> userMapper.fromUserProfileDtoToUserEntity( userProfile = userProfileDTO, connectionState = ConnectionEntity.State.ACCEPTED, - userTypeEntity = UserTypeEntity.STANDARD + userTypeEntity = userDAO.observeUserDetailsByQualifiedID(userProfileDTO.id.toDao()) + .firstOrNull()?.userType ?: UserTypeEntity.STANDARD ) } ) @@ -337,7 +345,7 @@ internal class UserDataSource internal constructor( .map { userMapper.fromUpdateRequestToDaoModel(user, updateRequest) } .flatMap { userEntity -> wrapStorageRequest { - userDAO.updateUser(userEntity) + userDAO.upsertUser(userEntity) }.map { userMapper.fromUserEntityToSelfUser(userEntity) } } } @@ -396,6 +404,18 @@ internal class UserDataSource internal constructor( userDAO.updateUserAvailabilityStatus(userId.toDao(), availabilityStatusMapper.fromModelAvailabilityStatusToDao(status)) } + override suspend fun updateSupportedProtocols(protocols: Set): Either { + return wrapApiRequest { selfApi.updateSupportedProtocols(protocols.map { it.toApi() }) } + .flatMap { + wrapStorageRequest { + userDAO.updateUserSupportedProtocols(selfUserId.toDao(), protocols.map { it.toDao() }.toSet()) + } + } + } + + override suspend fun updateActiveOneOnOneConversation(userId: UserId, conversationId: ConversationId): Either = + wrapStorageRequest { userDAO.updateActiveOneOnOneConversation(userId.toDao(), conversationId.toDao()) } + override fun observeAllKnownUsersNotInConversation( conversationId: ConversationId ): Flow>> { @@ -428,10 +448,13 @@ internal class UserDataSource internal constructor( } override suspend fun updateUserFromEvent(event: Event.User.Update): Either = wrapStorageRequest { - val userId = qualifiedIdMapper.fromStringToQualifiedID(event.userId) - val user = - userDAO.observeUserDetailsByQualifiedID(userId.toDao()).firstOrNull() ?: return Either.Left(StorageFailure.DataNotFound) - userDAO.updateUser(userMapper.fromUserUpdateEventToUserEntity(event, user)) + userDAO.updateUser(event.userId.toDao(), userMapper.fromUserUpdateEventToPartialUserEntity(event)) + }.flatMap { updated -> + if (!updated) { + Either.Left(StorageFailure.DataNotFound) + } else { + Either.Right(Unit) + } } override suspend fun removeUser(userId: UserId): Either { diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/di/MapperProvider.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/di/MapperProvider.kt index 07edb83f0c1..789995b4b3d 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/di/MapperProvider.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/di/MapperProvider.kt @@ -106,7 +106,7 @@ internal object MapperProvider { fun availabilityStatusMapper(): AvailabilityStatusMapper = AvailabilityStatusMapperImpl() fun connectionStateMapper(): ConnectionStateMapper = ConnectionStateMapperImpl() fun userMapper(): UserMapper = UserMapperImpl( - idMapper(), availabilityStatusMapper(), connectionStateMapper(), userTypeEntityMapper() + idMapper(), availabilityStatusMapper(), connectionStateMapper() ) fun userTypeMapper(): DomainUserTypeMapper = DomainUserTypeMapperImpl() @@ -118,8 +118,9 @@ internal object MapperProvider { ) fun memberMapper(): MemberMapper = MemberMapperImpl(idMapper(), conversationRoleMapper()) - fun conversationMapper(): ConversationMapper = + fun conversationMapper(selfUserId: UserId): ConversationMapper = ConversationMapperImpl( + selfUserId, idMapper(), ConversationStatusMapperImpl(idMapper()), ProtocolInfoMapperImpl(), @@ -134,11 +135,12 @@ internal object MapperProvider { fun sendMessageFailureMapper(): SendMessageFailureMapper = SendMessageFailureMapperImpl() fun assetMapper(): AssetMapper = AssetMapperImpl() fun encryptionAlgorithmMapper(): EncryptionAlgorithmMapper = EncryptionAlgorithmMapper() - fun eventMapper(): EventMapper = EventMapper( + fun eventMapper(selfUserId: UserId): EventMapper = EventMapper( memberMapper(), connectionMapper(), featureConfigMapper(), conversationRoleMapper(), + selfUserId, receiptModeMapper(), ) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/TimestampKeyRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/TimestampKeyRepository.kt index d4c5d9fedea..ba161051bce 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/TimestampKeyRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/TimestampKeyRepository.kt @@ -57,5 +57,6 @@ class TimestampKeyRepositoryImpl( enum class TimestampKeys { LAST_KEYING_MATERIAL_UPDATE_CHECK, LAST_KEY_PACKAGE_COUNT_CHECK, - LAST_MISSING_METADATA_SYNC_CHECK + LAST_MISSING_METADATA_SYNC_CHECK, + LAST_MLS_MIGRATION_CHECK } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt index c1d9b125ac4..71b37ce9660 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt @@ -94,6 +94,7 @@ import com.wire.kalium.logic.data.message.PersistReactionUseCase import com.wire.kalium.logic.data.message.PersistReactionUseCaseImpl import com.wire.kalium.logic.data.message.ProtoContentMapper import com.wire.kalium.logic.data.message.ProtoContentMapperImpl +import com.wire.kalium.logic.data.message.SystemMessageInserterImpl import com.wire.kalium.logic.data.message.reaction.ReactionRepositoryImpl import com.wire.kalium.logic.data.message.receipt.ReceiptRepositoryImpl import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository @@ -179,6 +180,12 @@ import com.wire.kalium.logic.feature.conversation.SyncConversationsUseCaseImpl import com.wire.kalium.logic.feature.conversation.TypingIndicatorSyncManager import com.wire.kalium.logic.feature.conversation.keyingmaterials.KeyingMaterialsManager import com.wire.kalium.logic.feature.conversation.keyingmaterials.KeyingMaterialsManagerImpl +import com.wire.kalium.logic.feature.conversation.mls.MLSOneOnOneConversationResolver +import com.wire.kalium.logic.feature.conversation.mls.MLSOneOnOneConversationResolverImpl +import com.wire.kalium.logic.feature.conversation.mls.OneOnOneMigrator +import com.wire.kalium.logic.feature.conversation.mls.OneOnOneMigratorImpl +import com.wire.kalium.logic.feature.conversation.mls.OneOnOneResolver +import com.wire.kalium.logic.feature.conversation.mls.OneOnOneResolverImpl import com.wire.kalium.logic.feature.debug.DebugScope import com.wire.kalium.logic.feature.e2ei.EnrollE2EIUseCase import com.wire.kalium.logic.feature.e2ei.EnrollE2EIUseCaseImpl @@ -193,6 +200,7 @@ import com.wire.kalium.logic.feature.featureConfig.handler.GuestRoomConfigHandle import com.wire.kalium.logic.feature.featureConfig.handler.MLSConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.SecondFactorPasswordChallengeConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.SelfDeletingMessagesConfigHandler +import com.wire.kalium.logic.feature.featureConfig.handler.MLSMigrationConfigHandler import com.wire.kalium.logic.feature.keypackage.KeyPackageManager import com.wire.kalium.logic.feature.keypackage.KeyPackageManagerImpl import com.wire.kalium.logic.feature.message.AddSystemMessageToAllConversationsUseCase @@ -210,12 +218,21 @@ import com.wire.kalium.logic.feature.message.PersistMigratedMessagesUseCase import com.wire.kalium.logic.feature.message.PersistMigratedMessagesUseCaseImpl import com.wire.kalium.logic.feature.message.SessionEstablisher import com.wire.kalium.logic.feature.message.SessionEstablisherImpl +import com.wire.kalium.logic.feature.message.StaleEpochVerifier +import com.wire.kalium.logic.feature.message.StaleEpochVerifierImpl import com.wire.kalium.logic.feature.migration.MigrationScope +import com.wire.kalium.logic.feature.mlsmigration.MLSMigrationManager +import com.wire.kalium.logic.feature.mlsmigration.MLSMigrationManagerImpl +import com.wire.kalium.logic.feature.mlsmigration.MLSMigrationWorkerImpl +import com.wire.kalium.logic.feature.mlsmigration.MLSMigrator +import com.wire.kalium.logic.feature.mlsmigration.MLSMigratorImpl import com.wire.kalium.logic.feature.notificationToken.PushTokenUpdater import com.wire.kalium.logic.feature.proteus.ProteusPreKeyRefiller import com.wire.kalium.logic.feature.proteus.ProteusPreKeyRefillerImpl import com.wire.kalium.logic.feature.proteus.ProteusSyncWorker import com.wire.kalium.logic.feature.proteus.ProteusSyncWorkerImpl +import com.wire.kalium.logic.feature.protocol.OneOnOneProtocolSelector +import com.wire.kalium.logic.feature.protocol.OneOnOneProtocolSelectorImpl import com.wire.kalium.logic.feature.selfDeletingMessages.ObserveSelfDeletionTimerSettingsForConversationUseCase import com.wire.kalium.logic.feature.selfDeletingMessages.ObserveSelfDeletionTimerSettingsForConversationUseCaseImpl import com.wire.kalium.logic.feature.selfDeletingMessages.ObserveTeamSettingsSelfDeletingStatusUseCase @@ -245,6 +262,10 @@ import com.wire.kalium.logic.feature.user.SyncContactsUseCase import com.wire.kalium.logic.feature.user.SyncContactsUseCaseImpl import com.wire.kalium.logic.feature.user.SyncSelfUserUseCase import com.wire.kalium.logic.feature.user.SyncSelfUserUseCaseImpl +import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsAndResolveOneOnOnesUseCase +import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsAndResolveOneOnOnesUseCaseImpl +import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsUseCase +import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsUseCaseImpl import com.wire.kalium.logic.feature.user.UserScope import com.wire.kalium.logic.feature.user.guestroomlink.MarkGuestLinkFeatureFlagAsNotChangedUseCase import com.wire.kalium.logic.feature.user.guestroomlink.MarkGuestLinkFeatureFlagAsNotChangedUseCaseImpl @@ -314,6 +335,8 @@ import com.wire.kalium.logic.sync.receiver.conversation.MemberLeaveEventHandler import com.wire.kalium.logic.sync.receiver.conversation.MemberLeaveEventHandlerImpl import com.wire.kalium.logic.sync.receiver.conversation.NewConversationEventHandler import com.wire.kalium.logic.sync.receiver.conversation.NewConversationEventHandlerImpl +import com.wire.kalium.logic.sync.receiver.conversation.ProtocolUpdateEventHandler +import com.wire.kalium.logic.sync.receiver.conversation.ProtocolUpdateEventHandlerImpl import com.wire.kalium.logic.sync.receiver.conversation.ReceiptModeUpdateEventHandler import com.wire.kalium.logic.sync.receiver.conversation.ReceiptModeUpdateEventHandlerImpl import com.wire.kalium.logic.sync.receiver.conversation.RenamedConversationEventHandler @@ -322,8 +345,6 @@ import com.wire.kalium.logic.sync.receiver.conversation.message.ApplicationMessa import com.wire.kalium.logic.sync.receiver.conversation.message.ApplicationMessageHandlerImpl import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageUnpacker import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageUnpackerImpl -import com.wire.kalium.logic.sync.receiver.conversation.message.MLSWrongEpochHandler -import com.wire.kalium.logic.sync.receiver.conversation.message.MLSWrongEpochHandlerImpl import com.wire.kalium.logic.sync.receiver.conversation.message.NewMessageEventHandler import com.wire.kalium.logic.sync.receiver.conversation.message.NewMessageEventHandlerImpl import com.wire.kalium.logic.sync.receiver.conversation.message.ProteusMessageUnpacker @@ -524,6 +545,7 @@ class UserSessionScope internal constructor( private val mlsConversationRepository: MLSConversationRepository get() = MLSConversationDataSource( + userId, keyPackageRepository, mlsClientProvider, authenticatedNetworkContainer.mlsMessageApi, @@ -617,7 +639,6 @@ class UserSessionScope internal constructor( authenticatedNetworkContainer.userDetailsApi, globalScope.sessionRepository, userId, - qualifiedIdMapper, selfTeamId ) @@ -652,10 +673,7 @@ class UserSessionScope internal constructor( userStorage.database.memberDAO, userStorage.database.connectionDAO, authenticatedNetworkContainer.connectionApi, - authenticatedNetworkContainer.userDetailsApi, userStorage.database.userDAO, - userId, - selfTeamId, conversationRepository ) @@ -774,8 +792,8 @@ class UserSessionScope internal constructor( private val eventGatherer: EventGatherer get() = EventGathererImpl(eventRepository, incrementalSyncRepository) - private val eventProcessor: EventProcessor - get() = EventProcessorImpl( + private val eventProcessor: EventProcessor by lazy { + EventProcessorImpl( eventRepository, conversationEventReceiver, userEventReceiver, @@ -784,6 +802,7 @@ class UserSessionScope internal constructor( userPropertiesEventReceiver, federationEventReceiver ) + } private val slowSyncCriteriaProvider: SlowSyncCriteriaProvider get() = SlowSlowSyncCriteriaProviderImpl(clientRepository, logoutRepository) @@ -802,7 +821,10 @@ class UserSessionScope internal constructor( ) private val syncConversations: SyncConversationsUseCase - get() = SyncConversationsUseCaseImpl(conversationRepository) + get() = SyncConversationsUseCaseImpl( + conversationRepository, + systemMessageInserter + ) private val syncConnections: SyncConnectionsUseCase get() = SyncConnectionsUseCaseImpl( @@ -823,8 +845,7 @@ class UserSessionScope internal constructor( authenticatedNetworkContainer.conversationApi, clientRepository, conversationRepository, - mlsConversationRepository, - mlsUnpacker + mlsConversationRepository ) private val recoverMLSConversationsUseCase: RecoverMLSConversationsUseCase @@ -861,16 +882,54 @@ class UserSessionScope internal constructor( clientIdProvider, ) + private val mlsOneOnOneConversationResolver: MLSOneOnOneConversationResolver + get() = MLSOneOnOneConversationResolverImpl( + conversationRepository, + joinExistingMLSConversationUseCase + ) + + private val oneOnOneMigrator: OneOnOneMigrator + get() = OneOnOneMigratorImpl( + mlsOneOnOneConversationResolver, + conversationGroupRepository, + conversationRepository, + messageRepository, + userRepository + ) + private val oneOnOneResolver: OneOnOneResolver + get() = OneOnOneResolverImpl( + userRepository, + oneOnOneProtocolSelector, + oneOnOneMigrator, + incrementalSyncRepository + ) + + private val updateSupportedProtocols: UpdateSupportedProtocolsUseCase + get() = UpdateSupportedProtocolsUseCaseImpl( + clientRepository, + userRepository, + userConfigRepository, + featureSupport + ) + + private val updateSupportedProtocolsAndResolveOneOnOnes: UpdateSupportedProtocolsAndResolveOneOnOnesUseCase + get() = UpdateSupportedProtocolsAndResolveOneOnOnesUseCaseImpl( + updateSupportedProtocols, + oneOnOneResolver + ) + private val slowSyncWorker: SlowSyncWorker by lazy { SlowSyncWorkerImpl( eventRepository, syncSelfUser, syncFeatureConfigsUseCase, + updateSupportedProtocols, syncConversations, syncConnections, syncSelfTeamUseCase, syncContacts, - joinExistingMLSConversations + joinExistingMLSConversations, + oneOnOneResolver ) } @@ -948,7 +1007,17 @@ class UserSessionScope internal constructor( private val eventRepository: EventRepository get() = EventDataSource( - authenticatedNetworkContainer.notificationApi, userStorage.database.metadataDAO, clientIdProvider + authenticatedNetworkContainer.notificationApi, userStorage.database.metadataDAO, clientIdProvider, userId + ) + + private val mlsMigrator: MLSMigrator + get() = MLSMigratorImpl( + userId, + selfTeamId, + userRepository, + conversationRepository, + mlsConversationRepository, + systemMessageInserter ) internal val keyPackageManager: KeyPackageManager = KeyPackageManagerImpl(featureSupport, @@ -964,7 +1033,7 @@ class UserSessionScope internal constructor( lazy { users.timestampKeyRepository }) internal val mlsClientManager: MLSClientManager = MLSClientManagerImpl(clientIdProvider, - isMLSEnabled, + isAllowedToRegisterMLSClient, incrementalSyncRepository, lazy { slowSyncRepository }, lazy { clientRepository }, @@ -974,6 +1043,24 @@ class UserSessionScope internal constructor( ) }) + internal val mlsMigrationWorker get() = + MLSMigrationWorkerImpl( + userConfigRepository, + featureConfigRepository, + mlsConfigHandler, + mlsMigrationConfigHandler, + mlsMigrator, + ) + + internal val mlsMigrationManager: MLSMigrationManager = MLSMigrationManagerImpl( + kaliumConfigs, + featureSupport, + incrementalSyncRepository, + lazy { clientRepository }, + lazy { users.timestampKeyRepository }, + lazy { mlsMigrationWorker } + ) + private val mlsPublicKeysRepository: MLSPublicKeysRepository get() = MLSPublicKeysRepositoryImpl( authenticatedNetworkContainer.mlsPublicKeyApi, @@ -1049,6 +1136,8 @@ class UserSessionScope internal constructor( private val messageEncoder get() = MessageContentEncoder() + private val systemMessageInserter get() = SystemMessageInserterImpl(userId, persistMessage) + private val receiptMessageHandler get() = ReceiptMessageHandlerImpl( selfUserId = this.userId, @@ -1093,11 +1182,11 @@ class UserSessionScope internal constructor( userId ) - private val mlsWrongEpochHandler: MLSWrongEpochHandler - get() = MLSWrongEpochHandlerImpl( - selfUserId = userId, - persistMessage = persistMessage, + private val staleEpochVerifier: StaleEpochVerifier + get() = StaleEpochVerifierImpl( + systemMessageInserter = systemMessageInserter, conversationRepository = conversationRepository, + mlsConversationRepository = mlsConversationRepository, joinExistingMLSConversation = joinExistingMLSConversationUseCase ) @@ -1107,7 +1196,7 @@ class UserSessionScope internal constructor( { conversationId, messageId -> messages.ephemeralMessageDeletionHandler.startSelfDeletion(conversationId, messageId) }, userId, - mlsWrongEpochHandler + staleEpochVerifier ) private val newConversationHandler: NewConversationEventHandler @@ -1135,7 +1224,7 @@ class UserSessionScope internal constructor( ) private val mlsWelcomeHandler: MLSWelcomeEventHandler get() = MLSWelcomeEventHandlerImpl( - mlsClientProvider, userStorage.database.conversationDAO, conversationRepository + mlsClientProvider, conversationRepository, oneOnOneResolver ) private val renamedConversationHandler: RenamedConversationEventHandler get() = RenamedConversationEventHandlerImpl( @@ -1167,6 +1256,12 @@ class UserSessionScope internal constructor( private val typingIndicatorHandler: TypingIndicatorHandler get() = TypingIndicatorHandlerImpl(userId, conversations.typingIndicatorIncomingRepository) + private val protocolUpdateEventHandler: ProtocolUpdateEventHandler + get() = ProtocolUpdateEventHandlerImpl( + conversationRepository = conversationRepository, + systemMessageInserter = systemMessageInserter + ) + private val conversationEventReceiver: ConversationEventReceiver by lazy { ConversationEventReceiverImpl( newMessageHandler, @@ -1181,7 +1276,8 @@ class UserSessionScope internal constructor( conversationMessageTimerEventHandler, conversationCodeUpdateHandler, conversationCodeDeletedHandler, - typingIndicatorHandler + typingIndicatorHandler, + protocolUpdateEventHandler ) } @@ -1192,6 +1288,7 @@ class UserSessionScope internal constructor( conversationRepository, userRepository, logout, + oneOnOneResolver, userId, clientIdProvider ) @@ -1208,17 +1305,48 @@ class UserSessionScope internal constructor( private val teamEventReceiver: TeamEventReceiver get() = TeamEventReceiverImpl(teamRepository, conversationRepository, userRepository, persistMessage, userId) + private val guestRoomConfigHandler + get() = GuestRoomConfigHandler(userConfigRepository, kaliumConfigs) + + private val fileSharingConfigHandler + get() = FileSharingConfigHandler(userConfigRepository) + + private val mlsConfigHandler + get() = MLSConfigHandler(userConfigRepository, updateSupportedProtocolsAndResolveOneOnOnes, userId) + + private val mlsMigrationConfigHandler + get() = MLSMigrationConfigHandler(userConfigRepository, updateSupportedProtocolsAndResolveOneOnOnes) + + private val classifiedDomainsConfigHandler + get() = ClassifiedDomainsConfigHandler(userConfigRepository) + + private val conferenceCallingConfigHandler + get() = ConferenceCallingConfigHandler(userConfigRepository) + + private val secondFactorPasswordChallengeConfigHandler + get() = SecondFactorPasswordChallengeConfigHandler(userConfigRepository) + + private val selfDeletingMessagesConfigHandler + get() = SelfDeletingMessagesConfigHandler(userConfigRepository, kaliumConfigs) + + private val e2eiConfigHandler + get() = E2EIConfigHandler(userConfigRepository) + + private val appLockConfigHandler + get() = AppLockConfigHandler(userConfigRepository) + private val featureConfigEventReceiver: FeatureConfigEventReceiver get() = FeatureConfigEventReceiverImpl( - GuestRoomConfigHandler(userConfigRepository, kaliumConfigs), - FileSharingConfigHandler(userConfigRepository), - MLSConfigHandler(userConfigRepository, userId), - ClassifiedDomainsConfigHandler(userConfigRepository), - ConferenceCallingConfigHandler(userConfigRepository), - SecondFactorPasswordChallengeConfigHandler(userConfigRepository), - SelfDeletingMessagesConfigHandler(userConfigRepository, kaliumConfigs), - E2EIConfigHandler(userConfigRepository), - AppLockConfigHandler(userConfigRepository) + guestRoomConfigHandler, + fileSharingConfigHandler, + mlsConfigHandler, + mlsMigrationConfigHandler, + classifiedDomainsConfigHandler, + conferenceCallingConfigHandler, + secondFactorPasswordChallengeConfigHandler, + selfDeletingMessagesConfigHandler, + e2eiConfigHandler, + appLockConfigHandler ) private val preKeyRepository: PreKeyRepository @@ -1275,6 +1403,11 @@ class UserSessionScope internal constructor( protoContentMapper = protoContentMapper ) + private val oneOnOneProtocolSelector: OneOnOneProtocolSelector + get() = OneOnOneProtocolSelectorImpl( + userRepository + ) + @OptIn(DelicateKaliumApi::class) val client: ClientScope get() = ClientScope( @@ -1296,7 +1429,8 @@ class UserSessionScope internal constructor( userRepository, authenticationScope.secondFactorVerificationRepository, slowSyncRepository, - cachedClientIdClearer + cachedClientIdClearer, + updateSupportedProtocolsAndResolveOneOnOnes ) val conversations: ConversationScope by lazy { ConversationScope( @@ -1321,11 +1455,12 @@ class UserSessionScope internal constructor( globalScope.serverConfigRepository, userStorage, userPropertyRepository, + oneOnOneResolver, this ) } - val migration get() = MigrationScope(userStorage.database) + val migration get() = MigrationScope(userId, userStorage.database) val debug: DebugScope get() = DebugScope( messageRepository, @@ -1343,6 +1478,8 @@ class UserSessionScope internal constructor( slowSyncRepository, messageSendingScheduler, selfConversationIdProvider, + staleEpochVerifier, + eventProcessor, this ) val messages: MessageScope @@ -1370,6 +1507,7 @@ class UserSessionScope internal constructor( protoContentMapper, observeSelfDeletingMessages, messageMetadataRepository, + staleEpochVerifier, this ) val users: UserScope @@ -1389,8 +1527,9 @@ class UserSessionScope internal constructor( userPropertyRepository, messages.messageSender, clientIdProvider, + e2eiRepository, team.isSelfATeamMember, - e2eiRepository + updateSupportedProtocols ) private val clearUserData: ClearUserDataUseCase get() = ClearUserDataUseCaseImpl(userStorage) @@ -1460,15 +1599,16 @@ class UserSessionScope internal constructor( private val syncFeatureConfigsUseCase: SyncFeatureConfigsUseCase get() = SyncFeatureConfigsUseCaseImpl( featureConfigRepository, - GuestRoomConfigHandler(userConfigRepository, kaliumConfigs), - FileSharingConfigHandler(userConfigRepository), - MLSConfigHandler(userConfigRepository, userId), - ClassifiedDomainsConfigHandler(userConfigRepository), - ConferenceCallingConfigHandler(userConfigRepository), - SecondFactorPasswordChallengeConfigHandler(userConfigRepository), - SelfDeletingMessagesConfigHandler(userConfigRepository, kaliumConfigs), - E2EIConfigHandler(userConfigRepository), - AppLockConfigHandler(userConfigRepository) + guestRoomConfigHandler, + fileSharingConfigHandler, + mlsConfigHandler, + mlsMigrationConfigHandler, + classifiedDomainsConfigHandler, + conferenceCallingConfigHandler, + secondFactorPasswordChallengeConfigHandler, + selfDeletingMessagesConfigHandler, + e2eiConfigHandler, + appLockConfigHandler ) val team: TeamScope get() = TeamScope(userRepository, teamRepository, conversationRepository, selfTeamId) @@ -1496,7 +1636,13 @@ class UserSessionScope internal constructor( kaliumConfigs ) - val connection: ConnectionScope get() = ConnectionScope(connectionRepository, conversationRepository) + val connection: ConnectionScope + get() = ConnectionScope( + connectionRepository, + conversationRepository, + userRepository, + oneOnOneResolver + ) val observeSecurityClassificationLabel: ObserveSecurityClassificationLabelUseCase get() = ObserveSecurityClassificationLabelUseCaseImpl( diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/auth/LoginUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/auth/LoginUseCase.kt index 3e97470fea2..2188c65ab89 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/auth/LoginUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/auth/LoginUseCase.kt @@ -138,6 +138,7 @@ internal class LoginUseCaseImpl internal constructor( is NetworkFailure.ServerMiscommunication -> handleServerMiscommunication(it, isEmail, cleanUserIdentifier) is NetworkFailure.NoNetworkConnection -> AuthenticationResult.Failure.Generic(it) is NetworkFailure.FederatedBackendFailure -> AuthenticationResult.Failure.Generic(it) + is NetworkFailure.FeatureNotSupported -> AuthenticationResult.Failure.Generic(it) } }, { if (isEmail && clean2FACode != null) { diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/backup/RestoreWebBackupUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/backup/RestoreWebBackupUseCase.kt index c89e3da87f5..99af987f549 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/backup/RestoreWebBackupUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/backup/RestoreWebBackupUseCase.kt @@ -63,12 +63,12 @@ interface RestoreWebBackupUseCase { @Suppress("TooManyFunctions", "LongParameterList", "NestedBlockDepth") internal class RestoreWebBackupUseCaseImpl( private val kaliumFileSystem: KaliumFileSystem, - private val userId: UserId, + private val selfUserId: UserId, private val persistMigratedMessages: PersistMigratedMessagesUseCase, private val restartSlowSyncProcessForRecovery: RestartSlowSyncProcessForRecoveryUseCase, private val migrationDAO: MigrationDAO, private val dispatchers: KaliumDispatcher = KaliumDispatcherImpl, - private val conversationMapper: ConversationMapper = MapperProvider.conversationMapper() + private val conversationMapper: ConversationMapper = MapperProvider.conversationMapper(selfUserId) ) : RestoreWebBackupUseCase { override suspend operator fun invoke(backupRootPath: Path, metadata: BackupMetadata): RestoreBackupResult = @@ -100,7 +100,7 @@ internal class RestoreWebBackupUseCaseImpl( while (iterator.hasNext()) { try { val webConversation = iterator.next() - val migratedConversation = webConversation.toConversation(userId) + val migratedConversation = webConversation.toConversation(selfUserId) if (migratedConversation != null) { migratedConversations.add(migratedConversation) } @@ -128,7 +128,7 @@ internal class RestoreWebBackupUseCaseImpl( while (iterator.hasNext()) { try { val webContent = iterator.next() - val migratedMessage = webContent.toMigratedMessage(userId.domain) + val migratedMessage = webContent.toMigratedMessage(selfUserId.domain) if (migratedMessage != null) { migratedMessagesBatch.add(migratedMessage) } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/ClientScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/ClientScope.kt index f7af6ae4941..258c3da0dbb 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/ClientScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/ClientScope.kt @@ -42,6 +42,7 @@ import com.wire.kalium.logic.feature.keypackage.RefillKeyPackagesUseCaseImpl import com.wire.kalium.logic.feature.session.DeregisterTokenUseCase import com.wire.kalium.logic.feature.session.DeregisterTokenUseCaseImpl import com.wire.kalium.logic.feature.session.UpgradeCurrentSessionUseCase +import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsAndResolveOneOnOnesUseCase import com.wire.kalium.logic.sync.slow.RestartSlowSyncProcessForRecoveryUseCase import com.wire.kalium.logic.sync.slow.RestartSlowSyncProcessForRecoveryUseCaseImpl import com.wire.kalium.util.DelicateKaliumApi @@ -66,7 +67,8 @@ class ClientScope @OptIn(DelicateKaliumApi::class) internal constructor( private val userRepository: UserRepository, private val secondFactorVerificationRepository: SecondFactorVerificationRepository, private val slowSyncRepository: SlowSyncRepository, - private val cachedClientIdClearer: CachedClientIdClearer + private val cachedClientIdClearer: CachedClientIdClearer, + private val updateSupportedProtocolsAndResolveOneOnOnes: UpdateSupportedProtocolsAndResolveOneOnOnesUseCase ) { @OptIn(DelicateKaliumApi::class) val register: RegisterClientUseCase @@ -85,7 +87,11 @@ class ClientScope @OptIn(DelicateKaliumApi::class) internal constructor( val selfClients: FetchSelfClientsFromRemoteUseCase get() = FetchSelfClientsFromRemoteUseCaseImpl(clientRepository, clientIdProvider) val observeClientDetailsUseCase: ObserveClientDetailsUseCase get() = ObserveClientDetailsUseCaseImpl(clientRepository, clientIdProvider) - val deleteClient: DeleteClientUseCase get() = DeleteClientUseCaseImpl(clientRepository) + val deleteClient: DeleteClientUseCase + get() = DeleteClientUseCaseImpl( + clientRepository, + updateSupportedProtocolsAndResolveOneOnOnes, + ) val needsToRegisterClient: NeedsToRegisterClientUseCase get() = NeedsToRegisterClientUseCaseImpl(clientIdProvider, sessionRepository, proteusClientProvider, selfUserId) val deregisterNativePushToken: DeregisterTokenUseCase diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/DeleteClientUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/DeleteClientUseCase.kt index 53408f5fc4a..fda4033b4e7 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/DeleteClientUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/DeleteClientUseCase.kt @@ -22,7 +22,9 @@ import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.NetworkFailure import com.wire.kalium.logic.data.client.ClientRepository import com.wire.kalium.logic.data.client.DeleteClientParam +import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsAndResolveOneOnOnesUseCase import com.wire.kalium.logic.functional.fold +import com.wire.kalium.logic.functional.onSuccess import com.wire.kalium.network.exceptions.KaliumException import com.wire.kalium.network.exceptions.isBadRequest import com.wire.kalium.network.exceptions.isInvalidCredentials @@ -36,9 +38,18 @@ interface DeleteClientUseCase { suspend operator fun invoke(param: DeleteClientParam): DeleteClientResult } -class DeleteClientUseCaseImpl(private val clientRepository: ClientRepository) : DeleteClientUseCase { +internal class DeleteClientUseCaseImpl( + private val clientRepository: ClientRepository, + private val updateSupportedProtocolsAndResolveOneOnOnes: UpdateSupportedProtocolsAndResolveOneOnOnesUseCase, +) : DeleteClientUseCase { override suspend operator fun invoke(param: DeleteClientParam): DeleteClientResult = - clientRepository.deleteClient(param).fold( + clientRepository.deleteClient(param) + .onSuccess { + updateSupportedProtocolsAndResolveOneOnOnes( + synchroniseUsers = true + ) + } + .fold( { handleError(it) }, { diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/MLSClientManager.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/MLSClientManager.kt index 5bf8512dbe7..9f6f50d261e 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/MLSClientManager.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/MLSClientManager.kt @@ -23,7 +23,6 @@ import com.wire.kalium.logic.data.sync.IncrementalSyncRepository import com.wire.kalium.logic.data.sync.IncrementalSyncStatus import com.wire.kalium.logic.data.sync.SlowSyncRepository import com.wire.kalium.logic.feature.CurrentClientIdProvider -import com.wire.kalium.logic.feature.user.IsMLSEnabledUseCase import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMap import com.wire.kalium.logic.functional.onSuccess @@ -45,7 +44,7 @@ interface MLSClientManager @Suppress("LongParameterList") internal class MLSClientManagerImpl( private val currentClientIdProvider: CurrentClientIdProvider, - private val isMLSEnabled: IsMLSEnabledUseCase, + private val isAllowedToRegisterMLSClient: IsAllowedToRegisterMLSClientUseCase, private val incrementalSyncRepository: IncrementalSyncRepository, private val slowSyncRepository: Lazy, private val clientRepository: Lazy, @@ -68,7 +67,7 @@ internal class MLSClientManagerImpl( incrementalSyncRepository.incrementalSyncState.collect { syncState -> ensureActive() if (syncState is IncrementalSyncStatus.Live && - isMLSEnabled() + isAllowedToRegisterMLSClient() ) { registerMLSClientIfNeeded() } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/connection/AcceptConnectionRequestUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/connection/AcceptConnectionRequestUseCase.kt index 6a84d557832..8d60f87aced 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/connection/AcceptConnectionRequestUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/connection/AcceptConnectionRequestUseCase.kt @@ -23,8 +23,10 @@ import com.wire.kalium.logic.data.connection.ConnectionRepository import com.wire.kalium.logic.data.conversation.ConversationRepository import com.wire.kalium.logic.data.user.ConnectionState import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.feature.conversation.mls.OneOnOneResolver import com.wire.kalium.logic.functional.flatMap import com.wire.kalium.logic.functional.fold +import com.wire.kalium.logic.functional.map import com.wire.kalium.logic.kaliumLogger import com.wire.kalium.util.DateTimeUtil @@ -44,16 +46,26 @@ fun interface AcceptConnectionRequestUseCase { internal class AcceptConnectionRequestUseCaseImpl( private val connectionRepository: ConnectionRepository, private val conversationRepository: ConversationRepository, + private val oneOnOneResolver: OneOnOneResolver ) : AcceptConnectionRequestUseCase { override suspend fun invoke(userId: UserId): AcceptConnectionRequestUseCaseResult { return connectionRepository.updateConnectionStatus(userId, ConnectionState.ACCEPTED) - .flatMap { - conversationRepository.fetchConversation(it.qualifiedConversationId) - conversationRepository.updateConversationModifiedDate(it.qualifiedConversationId, DateTimeUtil.currentInstant()) + .flatMap { connection -> + conversationRepository.fetchConversation(connection.qualifiedConversationId) + .flatMap { + conversationRepository.updateConversationModifiedDate( + connection.qualifiedConversationId, + DateTimeUtil.currentInstant() + ) + }.flatMap { + oneOnOneResolver.resolveOneOnOneConversationWithUserId( + connection.qualifiedToId + ).map { } + } } .fold({ - kaliumLogger.e("An error occurred when accepting the connection request from $userId") + kaliumLogger.e("An error occurred when accepting the connection request from ${userId.toLogString()}: $it") AcceptConnectionRequestUseCaseResult.Failure(it) }, { AcceptConnectionRequestUseCaseResult.Success diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/connection/ConnectionScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/connection/ConnectionScope.kt index 5f8abb6c5f1..70b1908beed 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/connection/ConnectionScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/connection/ConnectionScope.kt @@ -20,17 +20,22 @@ package com.wire.kalium.logic.feature.connection import com.wire.kalium.logic.data.connection.ConnectionRepository import com.wire.kalium.logic.data.conversation.ConversationRepository +import com.wire.kalium.logic.data.user.UserRepository +import com.wire.kalium.logic.feature.conversation.mls.OneOnOneResolver -class ConnectionScope( +class ConnectionScope internal constructor( private val connectionRepository: ConnectionRepository, private val conversationRepository: ConversationRepository, + private val userRepository: UserRepository, + private val oneOnOneResolver: OneOnOneResolver ) { - val sendConnectionRequest: SendConnectionRequestUseCase get() = SendConnectionRequestUseCaseImpl(connectionRepository) + val sendConnectionRequest: SendConnectionRequestUseCase get() = SendConnectionRequestUseCaseImpl(connectionRepository, userRepository) val acceptConnectionRequest: AcceptConnectionRequestUseCase get() = AcceptConnectionRequestUseCaseImpl( connectionRepository, - conversationRepository + conversationRepository, + oneOnOneResolver ) val cancelConnectionRequest: CancelConnectionRequestUseCase get() = CancelConnectionRequestUseCaseImpl(connectionRepository) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/connection/SendConnectionRequestUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/connection/SendConnectionRequestUseCase.kt index bde72f515b8..870fc9b1401 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/connection/SendConnectionRequestUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/connection/SendConnectionRequestUseCase.kt @@ -22,6 +22,8 @@ import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.NetworkFailure import com.wire.kalium.logic.data.connection.ConnectionRepository import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.data.user.UserRepository +import com.wire.kalium.logic.functional.flatMap import com.wire.kalium.logic.functional.fold import com.wire.kalium.logic.kaliumLogger @@ -39,22 +41,24 @@ interface SendConnectionRequestUseCase { } internal class SendConnectionRequestUseCaseImpl( - private val connectionRepository: ConnectionRepository + private val connectionRepository: ConnectionRepository, + private val userRepository: UserRepository ) : SendConnectionRequestUseCase { override suspend fun invoke(userId: UserId): SendConnectionRequestResult { - return connectionRepository.sendUserConnection(userId) - .fold({ coreFailure -> - kaliumLogger.e("An error occurred when sending a connection request to $userId") - when (coreFailure) { - is NetworkFailure.FederatedBackendFailure.FederationDenied -> - SendConnectionRequestResult.Failure.FederationDenied + return userRepository.fetchUserInfo(userId).flatMap { + connectionRepository.sendUserConnection(userId) + }.fold({ coreFailure -> + kaliumLogger.e("An error occurred when sending a connection request to $userId") + when (coreFailure) { + is NetworkFailure.FederatedBackendFailure.FederationDenied -> + SendConnectionRequestResult.Failure.FederationDenied - else -> SendConnectionRequestResult.Failure.GenericFailure(coreFailure) - } - }, { - SendConnectionRequestResult.Success - }) + else -> SendConnectionRequestResult.Failure.GenericFailure(coreFailure) + } + }, { + SendConnectionRequestResult.Success + }) } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/CheckConversationInviteCodeUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/CheckConversationInviteCodeUseCase.kt index cd3732f2853..2d86ffb9c4f 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/CheckConversationInviteCodeUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/CheckConversationInviteCodeUseCase.kt @@ -52,6 +52,7 @@ class CheckConversationInviteCodeUseCase internal constructor( when (failure) { is NetworkFailure.NoNetworkConnection, is NetworkFailure.FederatedBackendFailure, + is NetworkFailure.FeatureNotSupported, is NetworkFailure.ProxyError -> Result.Failure.Generic(failure) is NetworkFailure.ServerMiscommunication -> handleServerMissCommunicationError(failure) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/ConversationScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/ConversationScope.kt index 8885c74faf5..c33785dca71 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/ConversationScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/ConversationScope.kt @@ -56,6 +56,7 @@ import com.wire.kalium.logic.feature.conversation.keyingmaterials.UpdateKeyingMa import com.wire.kalium.logic.feature.conversation.keyingmaterials.UpdateKeyingMaterialsUseCaseImpl import com.wire.kalium.logic.feature.conversation.messagetimer.UpdateMessageTimerUseCase import com.wire.kalium.logic.feature.conversation.messagetimer.UpdateMessageTimerUseCaseImpl +import com.wire.kalium.logic.feature.conversation.mls.OneOnOneResolver import com.wire.kalium.logic.feature.message.MessageSender import com.wire.kalium.logic.feature.message.SendConfirmationUseCase import com.wire.kalium.logic.feature.team.DeleteTeamConversationUseCase @@ -91,6 +92,7 @@ class ConversationScope internal constructor( private val serverConfigRepository: ServerConfigRepository, private val userStorage: UserStorage, private val userPropertyRepository: UserPropertyRepository, + private val oneOnOneResolver: OneOnOneResolver, private val scope: CoroutineScope ) { @@ -157,7 +159,11 @@ class ConversationScope internal constructor( get() = AddServiceToConversationUseCase(groupRepository = conversationGroupRepository) val getOrCreateOneToOneConversationUseCase: GetOrCreateOneToOneConversationUseCase - get() = GetOrCreateOneToOneConversationUseCase(conversationRepository, conversationGroupRepository) + get() = GetOrCreateOneToOneConversationUseCaseImpl( + conversationRepository, + userRepository, + oneOnOneResolver + ) val updateConversationMutedStatus: UpdateConversationMutedStatusUseCase get() = UpdateConversationMutedStatusUseCaseImpl(conversationRepository) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/GetOrCreateOneToOneConversationUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/GetOrCreateOneToOneConversationUseCase.kt index b01d2040ac1..d8f8eaf5551 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/GetOrCreateOneToOneConversationUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/GetOrCreateOneToOneConversationUseCase.kt @@ -21,9 +21,12 @@ package com.wire.kalium.logic.feature.conversation import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.data.conversation.Conversation -import com.wire.kalium.logic.data.conversation.ConversationGroupRepository import com.wire.kalium.logic.data.conversation.ConversationRepository import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.data.user.UserRepository +import com.wire.kalium.logic.feature.conversation.mls.OneOnOneResolver +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.flatMap import com.wire.kalium.logic.functional.fold import kotlinx.coroutines.flow.first @@ -35,18 +38,22 @@ import kotlinx.coroutines.flow.first * @return Result with [Conversation] in case of success, or [CoreFailure] if something went wrong: * can't get data from local DB, or can't create a conversation. */ -class GetOrCreateOneToOneConversationUseCase( - private val conversationRepository: ConversationRepository, - private val conversationGroupRepository: ConversationGroupRepository -) { +interface GetOrCreateOneToOneConversationUseCase { + suspend operator fun invoke(otherUserId: UserId): CreateConversationResult +} - suspend operator fun invoke(otherUserId: UserId): CreateConversationResult { - // TODO: filter out self user from the list (just in case of client bug that leads to self user to be included part of the list) +internal class GetOrCreateOneToOneConversationUseCaseImpl( + private val conversationRepository: ConversationRepository, + private val userRepository: UserRepository, + private val oneOnOneResolver: OneOnOneResolver +) : GetOrCreateOneToOneConversationUseCase { + override suspend operator fun invoke(otherUserId: UserId): CreateConversationResult { + // TODO periodically re-resolve one-on-one return conversationRepository.observeOneToOneConversationWithOtherUser(otherUserId) .first() .fold({ conversationFailure -> if (conversationFailure is StorageFailure.DataNotFound) { - conversationGroupRepository.createGroupConversation(usersList = listOf(otherUserId)) + resolveOneOnOneConversationWithUser(otherUserId) .fold( CreateConversationResult::Failure, CreateConversationResult::Success @@ -59,6 +66,14 @@ class GetOrCreateOneToOneConversationUseCase( }) } + private suspend fun resolveOneOnOneConversationWithUser(otherUserId: UserId): Either = + (userRepository.getKnownUser(otherUserId).first()?.let { otherUser -> + // TODO support lazily establishing mls group for team 1-1 + oneOnOneResolver.resolveOneOnOneConversationWithUser(otherUser).flatMap { + conversationRepository.getConversationById(it)?.let { Either.Right(it) } ?: Either.Left(StorageFailure.DataNotFound) + } + } ?: Either.Left(StorageFailure.DataNotFound)) + } sealed class CreateConversationResult { diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationUseCase.kt index 4cda2eef95a..9007ba1cecb 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationUseCase.kt @@ -36,7 +36,6 @@ import com.wire.kalium.logic.functional.getOrElse import com.wire.kalium.logic.kaliumLogger import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageFailureHandler import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageFailureResolution -import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageUnpacker import com.wire.kalium.logic.wrapApiRequest import com.wire.kalium.network.api.base.authenticated.conversation.ConversationApi import com.wire.kalium.network.exceptions.KaliumException @@ -47,10 +46,10 @@ import com.wire.kalium.util.KaliumDispatcherImpl import kotlinx.coroutines.withContext /** - * Send an external commit to join all MLS conversations for which the user is a member, + * Send an external commit to join an MLS conversation for which the user is a member, * but has not yet joined the corresponding MLS group. */ -interface JoinExistingMLSConversationUseCase { +internal interface JoinExistingMLSConversationUseCase { suspend operator fun invoke(conversationId: ConversationId): Either } @@ -61,7 +60,6 @@ internal class JoinExistingMLSConversationUseCaseImpl( private val clientRepository: ClientRepository, private val conversationRepository: ConversationRepository, private val mlsConversationRepository: MLSConversationRepository, - private val mlsMessageUnpacker: MLSMessageUnpacker, kaliumDispatcher: KaliumDispatcher = KaliumDispatcherImpl ) : JoinExistingMLSConversationUseCase { private val dispatcher = kaliumDispatcher.io @@ -70,7 +68,7 @@ internal class JoinExistingMLSConversationUseCaseImpl( if (!featureSupport.isMLSSupported || !clientRepository.hasRegisteredMLSClient().getOrElse(false) ) { - kaliumLogger.d("Skip re-join existing MLS conversation(s), since MLS is not supported.") + kaliumLogger.d("Skip re-join existing MLS conversation, since MLS is not supported.") Either.Right(Unit) } else { conversationRepository.baseInfoById(conversationId).fold({ @@ -91,7 +89,13 @@ internal class JoinExistingMLSConversationUseCaseImpl( if (failure.kaliumException.isMlsStaleMessage()) { kaliumLogger.w("Epoch out of date for conversation ${conversation.id}, re-fetching and re-trying") // Re-fetch current epoch and try again - conversationRepository.fetchConversation(conversation.id).flatMap { + if (conversation.type == Conversation.Type.ONE_ON_ONE) { + conversationRepository.getConversationMembers(conversation.id).flatMap { + conversationRepository.fetchMlsOneToOneConversation(it.first()) + } + } else { + conversationRepository.fetchConversation(conversation.id) + }.flatMap { conversationRepository.baseInfoById(conversation.id).flatMap { conversation -> joinOrEstablishMLSGroup(conversation) } @@ -108,23 +112,19 @@ internal class JoinExistingMLSConversationUseCaseImpl( } private suspend fun joinOrEstablishMLSGroup(conversation: Conversation): Either { - return if (conversation.protocol is Conversation.ProtocolInfo.MLS) { - if (conversation.protocol.epoch == 0UL) { - if (conversation.type == Conversation.Type.SELF) { - kaliumLogger.i("Establish group for ${conversation.type}") - mlsConversationRepository.establishMLSGroup( - conversation.protocol.groupId, - emptyList() - ) - } else { - Either.Right(Unit) - } - } else { + val protocol = conversation.protocol + val type = conversation.type + return when { + protocol !is Conversation.ProtocolInfo.MLSCapable -> Either.Right(Unit) + + protocol.epoch != 0UL -> { + // TODO(refactor): don't use conversationAPI directly + // we could use mlsConversationRepository to solve this wrapApiRequest { conversationApi.fetchGroupInfo(conversation.id.toApi()) }.flatMap { groupInfo -> mlsConversationRepository.joinGroupByExternalCommit( - conversation.protocol.groupId, + protocol.groupId, groupInfo ).flatMapLeft { if (MLSMessageFailureHandler.handleFailure(it) is MLSMessageFailureResolution.Ignore) { @@ -135,8 +135,26 @@ internal class JoinExistingMLSConversationUseCaseImpl( } } } - } else { - Either.Right(Unit) + + type == Conversation.Type.SELF -> { + kaliumLogger.i("Establish group for ${conversation.type}") + mlsConversationRepository.establishMLSGroup( + protocol.groupId, + emptyList() + ) + } + + type == Conversation.Type.ONE_ON_ONE -> { + kaliumLogger.i("Establish group for ${conversation.type}") + conversationRepository.getConversationMembers(conversation.id).flatMap { members -> + mlsConversationRepository.establishMLSGroup( + protocol.groupId, + members + ) + } + } + + else -> Either.Right(Unit) } } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationsUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationsUseCase.kt index 61212aaff2e..b34a9a42eb7 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationsUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationsUseCase.kt @@ -20,11 +20,12 @@ package com.wire.kalium.logic.feature.conversation import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.data.client.ClientRepository -import com.wire.kalium.logic.data.conversation.Conversation.ProtocolInfo.MLS.GroupState +import com.wire.kalium.logic.data.conversation.Conversation.ProtocolInfo.MLSCapable.GroupState import com.wire.kalium.logic.data.conversation.ConversationRepository import com.wire.kalium.logic.featureFlags.FeatureSupport import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMap +import com.wire.kalium.logic.functional.flatMapLeft import com.wire.kalium.logic.functional.foldToEitherWhileRight import com.wire.kalium.logic.functional.getOrElse import com.wire.kalium.logic.kaliumLogger @@ -33,12 +34,12 @@ import com.wire.kalium.logic.kaliumLogger * Send an external commit to join all MLS conversations for which the user is a member, * but has not yet joined the corresponding MLS group. */ -interface JoinExistingMLSConversationsUseCase { +internal interface JoinExistingMLSConversationsUseCase { suspend operator fun invoke(keepRetryingOnFailure: Boolean = true): Either } @Suppress("LongParameterList") -class JoinExistingMLSConversationsUseCaseImpl( +internal class JoinExistingMLSConversationsUseCaseImpl( private val featureSupport: FeatureSupport, private val clientRepository: ClientRepository, private val conversationRepository: ConversationRepository, @@ -57,6 +58,18 @@ class JoinExistingMLSConversationsUseCaseImpl( return pendingConversations.map { conversation -> joinExistingMLSConversationUseCase(conversation.id) + .flatMapLeft { + if (it is CoreFailure.NoKeyPackagesAvailable) { + kaliumLogger.w( + "Failed to establish mls group for ${conversation.id.toLogString()} " + + "since some participants are out of key packages, skipping." + ) + Either.Right(Unit) + } else { + Either.Left(it) + } + + } }.foldToEitherWhileRight(Unit) { value, _ -> value } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/PersistMigratedConversationUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/PersistMigratedConversationUseCase.kt index 1e4aa9fed8c..fe4dc48c5e4 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/PersistMigratedConversationUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/PersistMigratedConversationUseCase.kt @@ -22,6 +22,7 @@ import com.wire.kalium.logger.KaliumLogger.Companion.ApplicationFlow.CONVERSATIO import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.data.conversation.Conversation import com.wire.kalium.logic.data.conversation.ConversationMapper +import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.di.MapperProvider import com.wire.kalium.logic.kaliumLogger import com.wire.kalium.logic.wrapStorageRequest @@ -46,8 +47,9 @@ fun interface PersistMigratedConversationUseCase { } internal class PersistMigratedConversationUseCaseImpl( + private val selfUserId: UserId, private val migrationDAO: MigrationDAO, - private val conversationMapper: ConversationMapper = MapperProvider.conversationMapper() + private val conversationMapper: ConversationMapper = MapperProvider.conversationMapper(selfUserId) ) : PersistMigratedConversationUseCase { val logger by lazy { kaliumLogger.withFeatureId(CONVERSATIONS) } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/RecoverMLSConversationsUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/RecoverMLSConversationsUseCase.kt index 6cb76e886ed..81c4690fe6c 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/RecoverMLSConversationsUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/RecoverMLSConversationsUseCase.kt @@ -21,7 +21,7 @@ package com.wire.kalium.logic.feature.conversation import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.data.client.ClientRepository import com.wire.kalium.logic.data.conversation.Conversation -import com.wire.kalium.logic.data.conversation.Conversation.ProtocolInfo.MLS.GroupState +import com.wire.kalium.logic.data.conversation.Conversation.ProtocolInfo.MLSCapable.GroupState import com.wire.kalium.logic.data.conversation.ConversationRepository import com.wire.kalium.logic.data.conversation.MLSConversationRepository import com.wire.kalium.logic.featureFlags.FeatureSupport @@ -39,15 +39,15 @@ sealed class RecoverMLSConversationsResult { } /** - *Iterate over all MLS Established conversations after 404 sync error and + * Iterate over all MLS Established conversations after 404 sync error and * check for out of sync epochs, if out of sync then it tries to re-join. */ -interface RecoverMLSConversationsUseCase { +internal interface RecoverMLSConversationsUseCase { suspend operator fun invoke(): RecoverMLSConversationsResult } @Suppress("LongParameterList") -class RecoverMLSConversationsUseCaseImpl( +internal class RecoverMLSConversationsUseCaseImpl( private val featureSupport: FeatureSupport, private val clientRepository: ClientRepository, private val conversationRepository: ConversationRepository, diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/RenameConversationUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/RenameConversationUseCase.kt index 4d5f71a3377..862cf2277e1 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/RenameConversationUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/RenameConversationUseCase.kt @@ -47,14 +47,14 @@ internal class RenameConversationUseCaseImpl( val persistMessage: PersistMessageUseCase, private val renamedConversationEventHandler: RenamedConversationEventHandler, val selfUserId: UserId, - private val eventMapper: EventMapper = MapperProvider.eventMapper() + private val eventMapper: EventMapper = MapperProvider.eventMapper(selfUserId) ) : RenameConversationUseCase { override suspend fun invoke(conversationId: ConversationId, conversationName: String): RenamingResult { return conversationRepository.changeConversationName(conversationId, conversationName) .onSuccess { response -> if (response is ConversationRenameResponse.Changed) renamedConversationEventHandler.handle( - eventMapper.conversationRenamed(LocalId.generate(), response.event, true) + eventMapper.conversationRenamed(LocalId.generate(), response.event, true, false) ) } .fold({ diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/SyncConversationsUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/SyncConversationsUseCase.kt index cf3d2536a88..f121d6bdcff 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/SyncConversationsUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/SyncConversationsUseCase.kt @@ -19,20 +19,43 @@ package com.wire.kalium.logic.feature.conversation import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.StorageFailure +import com.wire.kalium.logic.data.conversation.Conversation import com.wire.kalium.logic.data.conversation.ConversationRepository +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.data.message.SystemMessageInserter import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.flatMap internal interface SyncConversationsUseCase { suspend operator fun invoke(): Either } + /** * This use case will sync against the backend the conversations of the current user. */ internal class SyncConversationsUseCaseImpl( - private val conversationRepository: ConversationRepository + private val conversationRepository: ConversationRepository, + private val systemMessageInserter: SystemMessageInserter ) : SyncConversationsUseCase { + override suspend operator fun invoke(): Either = + conversationRepository.getConversationIds(Conversation.Type.GROUP, Conversation.Protocol.PROTEUS) + .flatMap { proteusConversationIds -> + conversationRepository.fetchConversations() + .flatMap { + reportConversationsWithPotentialHistoryLoss(proteusConversationIds) + } + } - override suspend operator fun invoke(): Either { - return conversationRepository.fetchConversations() - } + private suspend fun reportConversationsWithPotentialHistoryLoss( + proteusConversationIds: List + ): Either = + conversationRepository.getConversationIds(Conversation.Type.GROUP, Conversation.Protocol.MLS) + .flatMap { mlsConversationIds -> + val conversationsWithUpgradedProtocol = mlsConversationIds.intersect(proteusConversationIds) + for (conversationId in conversationsWithUpgradedProtocol) { + systemMessageInserter.insertHistoryLostProtocolChangedSystemMessage(conversationId) + } + Either.Right(Unit) + } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/guestroomlink/GenerateGuestRoomLinkUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/guestroomlink/GenerateGuestRoomLinkUseCase.kt index 9e482f86f34..42a6e7a77b4 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/guestroomlink/GenerateGuestRoomLinkUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/guestroomlink/GenerateGuestRoomLinkUseCase.kt @@ -51,6 +51,7 @@ class GenerateGuestRoomLinkUseCaseImpl internal constructor( id = uuid4().toString(), isPasswordProtected = it.data.hasPassword, transient = false, + live = false, key = it.data.key, uri = it.data.uri ) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/mls/MLSOneOnOneConversationResolver.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/mls/MLSOneOnOneConversationResolver.kt new file mode 100644 index 00000000000..55bfea4e331 --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/mls/MLSOneOnOneConversationResolver.kt @@ -0,0 +1,75 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.conversation.mls + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.data.conversation.Conversation +import com.wire.kalium.logic.data.conversation.Conversation.ProtocolInfo.MLSCapable.GroupState +import com.wire.kalium.logic.data.conversation.ConversationRepository +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.feature.conversation.JoinExistingMLSConversationUseCase +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.flatMap +import com.wire.kalium.logic.functional.map +import com.wire.kalium.logic.kaliumLogger + +/** + * Attempts to find an existing MLS-capable one-on-one conversation, + * or creates a new one if none is found. + * In case the conversation already exists, but it's not established yet + * (see [GroupState.ESTABLISHED]), it will attempt to join it, returning failure if it fails. + */ +internal interface MLSOneOnOneConversationResolver { + /** + * Attempts to find an existing MLS-capable one-on-one conversation, + * or creates a new one if none is found. + * In case the conversation already exists, but it's not established yet + * (see [GroupState.ESTABLISHED]), it will attempt to join it, returning failure if it fails. + * @param userId The user ID of the other participant. + */ + suspend operator fun invoke(userId: UserId): Either +} + +internal class MLSOneOnOneConversationResolverImpl( + private val conversationRepository: ConversationRepository, + private val joinExistingMLSConversationUseCase: JoinExistingMLSConversationUseCase, +) : MLSOneOnOneConversationResolver { + + override suspend fun invoke(userId: UserId): Either = + conversationRepository.getConversationsByUserId(userId).flatMap { conversations -> + // Look for an existing MLS-capable conversation one-on-one + val initializedMLSOneOnOne = conversations.firstOrNull { + val isOneOnOne = it.type == Conversation.Type.ONE_ON_ONE + val protocol = it.protocol + val isMLSInitialized = protocol is Conversation.ProtocolInfo.MLSCapable && + protocol.groupState == GroupState.ESTABLISHED + isOneOnOne && isMLSInitialized + } + + if (initializedMLSOneOnOne != null) { + kaliumLogger.d("Already established mls group for one-on-one with ${userId.toLogString()}, skipping.") + Either.Right(initializedMLSOneOnOne.id) + } else { + kaliumLogger.d("Establishing mls group for one-on-one with ${userId.toLogString()}") + conversationRepository.fetchMlsOneToOneConversation(userId).flatMap { conversation -> + joinExistingMLSConversationUseCase(conversation.id).map { conversation.id } + } + } + } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/mls/OneOnOneMigrator.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/mls/OneOnOneMigrator.kt new file mode 100644 index 00000000000..5b15974a670 --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/mls/OneOnOneMigrator.kt @@ -0,0 +1,110 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.conversation.mls + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.StorageFailure +import com.wire.kalium.logic.data.conversation.Conversation +import com.wire.kalium.logic.data.conversation.ConversationGroupRepository +import com.wire.kalium.logic.data.conversation.ConversationRepository +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.data.message.MessageRepository +import com.wire.kalium.logic.data.user.OtherUser +import com.wire.kalium.logic.data.user.UserRepository +import com.wire.kalium.logic.data.user.type.isTeammate +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.flatMap +import com.wire.kalium.logic.functional.fold +import com.wire.kalium.logic.functional.foldToEitherWhileRight +import com.wire.kalium.logic.functional.map +import com.wire.kalium.logic.kaliumLogger + +interface OneOnOneMigrator { + suspend fun migrateToProteus(user: OtherUser): Either + suspend fun migrateToMLS(user: OtherUser): Either +} + +internal class OneOnOneMigratorImpl( + private val getResolvedMLSOneOnOne: MLSOneOnOneConversationResolver, + private val conversationGroupRepository: ConversationGroupRepository, + private val conversationRepository: ConversationRepository, + private val messageRepository: MessageRepository, + private val userRepository: UserRepository +) : OneOnOneMigrator { + + override suspend fun migrateToProteus(user: OtherUser): Either = + conversationRepository.getOneOnOneConversationsWithOtherUser(user.id, Conversation.Protocol.PROTEUS).flatMap { conversationIds -> + if (conversationIds.isNotEmpty()) { + val conversationId = conversationIds.first() + Either.Right(conversationId) + } else { + Either.Left(StorageFailure.DataNotFound) + } + }.fold({ failure -> + if (failure is StorageFailure.DataNotFound && user.userType.isTeammate()) { + conversationGroupRepository.createGroupConversation(usersList = listOf(user.id)).map { it.id } + } else { + Either.Left(failure) + } + }, { + Either.Right(it) + }).flatMap { conversationId -> + if (user.activeOneOnOneConversationId != conversationId) { + kaliumLogger.d("resolved one-on-one to proteus, user = ${user.id.toLogString()}") + userRepository.updateActiveOneOnOneConversation(user.id, conversationId) + } + Either.Right(conversationId) + } + + override suspend fun migrateToMLS(user: OtherUser): Either { + return getResolvedMLSOneOnOne(user.id) + .flatMap { mlsConversation -> + if (user.activeOneOnOneConversationId == mlsConversation) { + return@flatMap Either.Right(mlsConversation) + } + + kaliumLogger.d("resolved one-on-one to MLS, user = ${user.id.toLogString()}") + + migrateOneOnOneHistory(user, mlsConversation) + .flatMap { + userRepository.updateActiveOneOnOneConversation( + conversationId = mlsConversation, + userId = user.id + ).map { + mlsConversation + } + } + } + } + + private suspend fun migrateOneOnOneHistory(user: OtherUser, targetConversation: ConversationId): Either { + return conversationRepository.getOneOnOneConversationsWithOtherUser( + otherUserId = user.id, + protocol = Conversation.Protocol.PROTEUS + ).flatMap { proteusOneOnOneConversations -> + // We can theoretically have more than one proteus 1-1 conversation with + // team members since there was no backend safeguards against this + proteusOneOnOneConversations.foldToEitherWhileRight(Unit) { proteusOneOnOneConversation, _ -> + messageRepository.moveMessagesToAnotherConversation( + originalConversation = proteusOneOnOneConversation, + targetConversation = targetConversation + ) + } + } + } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/mls/OneOnOneResolver.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/mls/OneOnOneResolver.kt new file mode 100644 index 00000000000..b83c41ba3ad --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/conversation/mls/OneOnOneResolver.kt @@ -0,0 +1,118 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.conversation.mls + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.NetworkFailure +import com.wire.kalium.logic.StorageFailure +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.data.sync.IncrementalSyncRepository +import com.wire.kalium.logic.data.sync.IncrementalSyncStatus +import com.wire.kalium.logic.data.user.OtherUser +import com.wire.kalium.logic.data.user.SupportedProtocol +import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.data.user.UserRepository +import com.wire.kalium.logic.feature.protocol.OneOnOneProtocolSelector +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.flatMap +import com.wire.kalium.logic.functional.flatMapLeft +import com.wire.kalium.logic.functional.foldToEitherWhileRight +import com.wire.kalium.logic.functional.map +import com.wire.kalium.logic.kaliumLogger +import com.wire.kalium.util.KaliumDispatcher +import com.wire.kalium.util.KaliumDispatcherImpl +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.Job +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.firstOrNull +import kotlinx.coroutines.launch +import kotlin.time.Duration + +interface OneOnOneResolver { + suspend fun resolveAllOneOnOneConversations(synchronizeUsers: Boolean = false): Either + suspend fun scheduleResolveOneOnOneConversationWithUserId(userId: UserId, delay: Duration = Duration.ZERO): Job + suspend fun resolveOneOnOneConversationWithUserId(userId: UserId): Either + suspend fun resolveOneOnOneConversationWithUser(user: OtherUser): Either +} + +internal class OneOnOneResolverImpl( + private val userRepository: UserRepository, + private val oneOnOneProtocolSelector: OneOnOneProtocolSelector, + private val oneOnOneMigrator: OneOnOneMigrator, + private val incrementalSyncRepository: IncrementalSyncRepository, + kaliumDispatcher: KaliumDispatcher = KaliumDispatcherImpl +) : OneOnOneResolver { + + @OptIn(ExperimentalCoroutinesApi::class) + private val dispatcher = kaliumDispatcher.default.limitedParallelism(1) + private val resolveActiveOneOnOneScope = CoroutineScope(dispatcher) + + override suspend fun resolveAllOneOnOneConversations(synchronizeUsers: Boolean): Either = + if (synchronizeUsers) { + userRepository.fetchAllOtherUsers() + } else { + Either.Right(Unit) + }.flatMap { + val usersWithOneOnOne = userRepository.getUsersWithOneOnOneConversation() + kaliumLogger.i("Resolving one-on-one protocol for ${usersWithOneOnOne.size} user(s)") + usersWithOneOnOne.foldToEitherWhileRight(Unit) { item, _ -> + resolveOneOnOneConversationWithUser(item).flatMapLeft { + when (it) { + is CoreFailure.NoKeyPackagesAvailable, + is NetworkFailure.ServerMiscommunication, + is NetworkFailure.FederatedBackendFailure, + is CoreFailure.NoCommonProtocolFound + -> { + kaliumLogger.e("Resolving one-on-one failed $it, skipping") + Either.Right(Unit) + } + + else -> { + kaliumLogger.e("Resolving one-on-one failed $it, retrying") + Either.Left(it) + } + } + }.map { } + } + } + + override suspend fun scheduleResolveOneOnOneConversationWithUserId(userId: UserId, delay: Duration) = + resolveActiveOneOnOneScope.launch { + kaliumLogger.d("Schedule resolving active one-on-one") + incrementalSyncRepository.incrementalSyncState.first { it is IncrementalSyncStatus.Live } + delay(delay) + resolveOneOnOneConversationWithUserId(userId) + } + + override suspend fun resolveOneOnOneConversationWithUserId(userId: UserId): Either = + userRepository.getKnownUser(userId).firstOrNull()?.let { + resolveOneOnOneConversationWithUser(it) + } ?: Either.Left(StorageFailure.DataNotFound) + + override suspend fun resolveOneOnOneConversationWithUser(user: OtherUser): Either { + kaliumLogger.i("Resolving one-on-one protocol for ${user.id.toLogString()}") + return oneOnOneProtocolSelector.getProtocolForUser(user.id).flatMap { supportedProtocol -> + when (supportedProtocol) { + SupportedProtocol.PROTEUS -> oneOnOneMigrator.migrateToProteus(user) + SupportedProtocol.MLS -> oneOnOneMigrator.migrateToMLS(user) + } + } + } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/debug/DebugScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/debug/DebugScope.kt index 48c25623e41..84891a1c98c 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/debug/DebugScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/debug/DebugScope.kt @@ -46,10 +46,12 @@ import com.wire.kalium.logic.feature.message.MessageSendingInterceptorImpl import com.wire.kalium.logic.feature.message.MessageSendingScheduler import com.wire.kalium.logic.feature.message.SessionEstablisher import com.wire.kalium.logic.feature.message.SessionEstablisherImpl +import com.wire.kalium.logic.feature.message.StaleEpochVerifier import com.wire.kalium.logic.feature.message.ephemeral.DeleteEphemeralMessageForSelfUserAsReceiverUseCaseImpl import com.wire.kalium.logic.feature.message.ephemeral.DeleteEphemeralMessageForSelfUserAsSenderUseCaseImpl import com.wire.kalium.logic.feature.message.ephemeral.EphemeralMessageDeletionHandlerImpl import com.wire.kalium.logic.sync.SyncManager +import com.wire.kalium.logic.sync.incremental.EventProcessor import com.wire.kalium.logic.util.MessageContentEncoder import com.wire.kalium.util.KaliumDispatcher import com.wire.kalium.util.KaliumDispatcherImpl @@ -75,6 +77,8 @@ class DebugScope internal constructor( private val slowSyncRepository: SlowSyncRepository, private val messageSendingScheduler: MessageSendingScheduler, private val selfConversationIdProvider: SelfConversationIdProvider, + private val staleEpochVerifier: StaleEpochVerifier, + private val eventProcessor: EventProcessor, private val scope: CoroutineScope, internal val dispatcher: KaliumDispatcher = KaliumDispatcherImpl ) { @@ -99,6 +103,11 @@ class DebugScope internal constructor( messageSender ) + val disableEventProcessing: DisableEventProcessingUseCase + get() = DisableEventProcessingUseCaseImpl( + eventProcessor = eventProcessor + ) + private val messageSendFailureHandler: MessageSendFailureHandler get() = MessageSendFailureHandlerImpl(userRepository, clientRepository, messageRepository, messageSendingScheduler) @@ -138,6 +147,7 @@ class DebugScope internal constructor( mlsMessageCreator, messageSendingInterceptor, userRepository, + staleEpochVerifier, { message, expirationData -> ephemeralMessageDeletionHandler.enqueueSelfDeletion(message, expirationData) }, scope ) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/debug/DisableEventProcessingUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/debug/DisableEventProcessingUseCase.kt new file mode 100644 index 00000000000..f6bb30dea45 --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/debug/DisableEventProcessingUseCase.kt @@ -0,0 +1,39 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.debug + +import com.wire.kalium.logic.sync.incremental.EventProcessor + +/** + * Disables processing of incoming events but still mark them as processed. + * + * This use case useful for testing error scenarios where messages have been lost, + * putting the client in an inconsistent state with the backend. + */ +interface DisableEventProcessingUseCase { + suspend operator fun invoke(disabled: Boolean) +} + +internal class DisableEventProcessingUseCaseImpl( + private val eventProcessor: EventProcessor +) : DisableEventProcessingUseCase { + + override suspend fun invoke(disabled: Boolean) { + eventProcessor.disableEventProcessing = disabled + } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/featureConfig/SyncFeatureConfigsUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/featureConfig/SyncFeatureConfigsUseCase.kt index 76d1c5f173e..44732d85db6 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/featureConfig/SyncFeatureConfigsUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/featureConfig/SyncFeatureConfigsUseCase.kt @@ -28,6 +28,7 @@ import com.wire.kalium.logic.feature.featureConfig.handler.E2EIConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.FileSharingConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.GuestRoomConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.MLSConfigHandler +import com.wire.kalium.logic.feature.featureConfig.handler.MLSMigrationConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.SecondFactorPasswordChallengeConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.SelfDeletingMessagesConfigHandler import com.wire.kalium.logic.functional.Either @@ -51,6 +52,7 @@ internal class SyncFeatureConfigsUseCaseImpl( private val guestRoomConfigHandler: GuestRoomConfigHandler, private val fileSharingConfigHandler: FileSharingConfigHandler, private val mlsConfigHandler: MLSConfigHandler, + private val mlsMigrationConfigHandler: MLSMigrationConfigHandler, private val classifiedDomainsConfigHandler: ClassifiedDomainsConfigHandler, private val conferenceCallingConfigHandler: ConferenceCallingConfigHandler, private val passwordChallengeConfigHandler: SecondFactorPasswordChallengeConfigHandler, @@ -63,7 +65,8 @@ internal class SyncFeatureConfigsUseCaseImpl( // TODO handle other feature flags and after it bump version in [SlowSyncManager.CURRENT_VERSION] guestRoomConfigHandler.handle(it.guestRoomLinkModel) fileSharingConfigHandler.handle(it.fileSharingModel) - mlsConfigHandler.handle(it.mlsModel) + mlsConfigHandler.handle(it.mlsModel, duringSlowSync = true) + it.mlsMigrationModel?.let { mlsMigrationConfigHandler.handle(it, duringSlowSync = true) } classifiedDomainsConfigHandler.handle(it.classifiedDomainsModel) conferenceCallingConfigHandler.handle(it.conferenceCallingModel) passwordChallengeConfigHandler.handle(it.secondFactorPasswordChallengeModel) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/featureConfig/handler/MLSConfigHandler.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/featureConfig/handler/MLSConfigHandler.kt index bdb276573ef..69f403f4aef 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/featureConfig/handler/MLSConfigHandler.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/featureConfig/handler/MLSConfigHandler.kt @@ -21,16 +21,37 @@ import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.configuration.UserConfigRepository import com.wire.kalium.logic.data.featureConfig.MLSModel import com.wire.kalium.logic.data.featureConfig.Status +import com.wire.kalium.logic.data.user.SupportedProtocol import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsAndResolveOneOnOnesUseCase import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.flatMap +import com.wire.kalium.logic.functional.getOrElse class MLSConfigHandler( private val userConfigRepository: UserConfigRepository, + private val updateSupportedProtocolsAndResolveOneOnOnes: UpdateSupportedProtocolsAndResolveOneOnOnesUseCase, private val selfUserId: UserId ) { - fun handle(mlsConfig: MLSModel): Either { + suspend fun handle(mlsConfig: MLSModel, duringSlowSync: Boolean): Either { val mlsEnabled = mlsConfig.status == Status.ENABLED val selfUserIsWhitelisted = mlsConfig.allowedUsers.contains(selfUserId.toPlainID()) + val previousSupportedProtocols = userConfigRepository.getSupportedProtocols().getOrElse(setOf(SupportedProtocol.PROTEUS)) + val supportedProtocolsHasChanged = !previousSupportedProtocols.equals(mlsConfig.supportedProtocols) + return userConfigRepository.setMLSEnabled(mlsEnabled && selfUserIsWhitelisted) + .flatMap { + userConfigRepository.setDefaultProtocol(if (mlsEnabled) mlsConfig.defaultProtocol else SupportedProtocol.PROTEUS) + }.flatMap { + userConfigRepository.setSupportedProtocols(mlsConfig.supportedProtocols) + }.flatMap { + if (supportedProtocolsHasChanged) { + updateSupportedProtocolsAndResolveOneOnOnes( + synchroniseUsers = !duringSlowSync + ) + } else { + Either.Right(Unit) + } + } } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/featureConfig/handler/MLSMigrationConfigHandler.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/featureConfig/handler/MLSMigrationConfigHandler.kt new file mode 100644 index 00000000000..799b8036cd2 --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/featureConfig/handler/MLSMigrationConfigHandler.kt @@ -0,0 +1,41 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.featureConfig.handler + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.configuration.UserConfigRepository +import com.wire.kalium.logic.data.featureConfig.MLSMigrationModel +import com.wire.kalium.logic.feature.mlsmigration.hasMigrationEnded +import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsAndResolveOneOnOnesUseCase +import com.wire.kalium.logic.functional.Either + +class MLSMigrationConfigHandler( + private val userConfigRepository: UserConfigRepository, + private val updateSupportedProtocolsAndResolveOneOnOnes: UpdateSupportedProtocolsAndResolveOneOnOnesUseCase +) { + + suspend fun handle(mlsMigrationConfig: MLSMigrationModel, duringSlowSync: Boolean): Either { + if (mlsMigrationConfig.hasMigrationEnded() && !duringSlowSync) { + updateSupportedProtocolsAndResolveOneOnOnes( + synchroniseUsers = true + ) + } + + return userConfigRepository.setMigrationConfiguration(mlsMigrationConfig) + } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageScope.kt index 1cb7d29ee36..3313769ccda 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageScope.kt @@ -89,6 +89,7 @@ class MessageScope internal constructor( private val protoContentMapper: ProtoContentMapper, private val observeSelfDeletingMessages: ObserveSelfDeletionTimerSettingsForConversationUseCase, private val messageMetadataRepository: MessageMetadataRepository, + private val staleEpochVerifier: StaleEpochVerifier, private val scope: CoroutineScope, internal val dispatcher: KaliumDispatcher = KaliumDispatcherImpl ) { @@ -145,6 +146,7 @@ class MessageScope internal constructor( mlsMessageCreator, messageSendingInterceptor, userRepository, + staleEpochVerifier, { message, expirationData -> ephemeralMessageDeletionHandler.enqueueSelfDeletion(message, expirationData) }, scope ) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageSender.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageSender.kt index dd8ecc6f797..a600b61a123 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageSender.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageSender.kt @@ -136,6 +136,7 @@ internal class MessageSenderImpl internal constructor( private val mlsMessageCreator: MLSMessageCreator, private val messageSendingInterceptor: MessageSendingInterceptor, private val userRepository: UserRepository, + private val staleEpochVerifier: StaleEpochVerifier, private val enqueueSelfDeletion: (Message, Message.ExpirationData) -> Unit, private val scope: CoroutineScope ) : MessageSender { @@ -225,7 +226,7 @@ internal class MessageSenderImpl internal constructor( attemptToSendWithMLS(protocolInfo.groupId, message) } - is Conversation.ProtocolInfo.Proteus -> { + is Conversation.ProtocolInfo.Proteus, is Conversation.ProtocolInfo.Mixed -> { // TODO(messaging): make this thread safe (per user) attemptToSendWithProteus(message, messageTarget) } @@ -317,10 +318,13 @@ internal class MessageSenderImpl internal constructor( messageRepository.sendMLSMessage(message.conversationId, mlsMessage).fold({ if (it is NetworkFailure.ServerMiscommunication && it.kaliumException is KaliumException.InvalidRequestError) { if (it.kaliumException.isMlsStaleMessage()) { - logger.w("Encrypted MLS message for outdated epoch '${message.id}', re-trying..") - return syncManager.waitUntilLiveOrFailure().flatMap { - attemptToSend(message) - } + logger.w("Encrypted MLS message for stale epoch '${message.id}', re-trying..") + return staleEpochVerifier.verifyEpoch(message.conversationId) + .flatMap { + syncManager.waitUntilLiveOrFailure().flatMap { + attemptToSend(message) + } + } } } Either.Left(it) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/StaleEpochVerifier.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/StaleEpochVerifier.kt new file mode 100644 index 00000000000..fe82c98852e --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/StaleEpochVerifier.kt @@ -0,0 +1,83 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.message + +import com.wire.kalium.logger.KaliumLogger +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.MLSFailure +import com.wire.kalium.logic.data.conversation.Conversation +import com.wire.kalium.logic.data.conversation.ConversationRepository +import com.wire.kalium.logic.data.conversation.MLSConversationRepository +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.data.message.SystemMessageInserter +import com.wire.kalium.logic.feature.conversation.JoinExistingMLSConversationUseCase +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.flatMap +import com.wire.kalium.logic.functional.map +import com.wire.kalium.logic.kaliumLogger +import com.wire.kalium.util.DateTimeUtil.toIsoDateTimeString +import kotlinx.datetime.Clock +import kotlinx.datetime.Instant + +interface StaleEpochVerifier { + suspend fun verifyEpoch(conversationId: ConversationId, timestamp: Instant? = null): Either +} + +internal class StaleEpochVerifierImpl( + private val systemMessageInserter: SystemMessageInserter, + private val conversationRepository: ConversationRepository, + private val mlsConversationRepository: MLSConversationRepository, + private val joinExistingMLSConversation: JoinExistingMLSConversationUseCase +) : StaleEpochVerifier { + + private val logger by lazy { kaliumLogger.withFeatureId(KaliumLogger.Companion.ApplicationFlow.MESSAGES) } + override suspend fun verifyEpoch(conversationId: ConversationId, timestamp: Instant?): Either { + logger.i("Verifying stale epoch") + return getUpdatedConversationProtocolInfo(conversationId).flatMap { protocol -> + if (protocol is Conversation.ProtocolInfo.MLS) { + Either.Right(protocol) + } else { + Either.Left(MLSFailure.ConversationDoesNotSupportMLS) + } + }.flatMap { protocolInfo -> + mlsConversationRepository.isGroupOutOfSync(protocolInfo.groupId, protocolInfo.epoch) + .map { epochIsStale -> + epochIsStale + } + }.flatMap { hasMissedCommits -> + if (hasMissedCommits) { + logger.w("Epoch stale due to missing commits, re-joining") + joinExistingMLSConversation(conversationId).flatMap { + systemMessageInserter.insertLostCommitSystemMessage( + conversationId, + (timestamp ?: Clock.System.now()).toIsoDateTimeString() + ) + } + } else { + logger.i("Epoch stale due to unprocessed events") + Either.Right(Unit) + } + } + } + + private suspend fun getUpdatedConversationProtocolInfo(conversationId: ConversationId): Either { + return conversationRepository.fetchConversation(conversationId).flatMap { + conversationRepository.getConversationProtocolInfo(conversationId) + } + } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/migration/MigrationScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/migration/MigrationScope.kt index 10ff919d385..de4e91eec69 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/migration/MigrationScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/migration/MigrationScope.kt @@ -18,15 +18,17 @@ package com.wire.kalium.logic.feature.migration +import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.feature.conversation.PersistMigratedConversationUseCase import com.wire.kalium.logic.feature.conversation.PersistMigratedConversationUseCaseImpl import com.wire.kalium.persistence.db.UserDatabaseBuilder class MigrationScope( + private val selfUserId: UserId, private val userDatabase: UserDatabaseBuilder ) { val persistMigratedConversation: PersistMigratedConversationUseCase - get() = PersistMigratedConversationUseCaseImpl(userDatabase.migrationDAO) + get() = PersistMigratedConversationUseCaseImpl(selfUserId, userDatabase.migrationDAO) } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManager.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManager.kt new file mode 100644 index 00000000000..2ff18128855 --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManager.kt @@ -0,0 +1,112 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.mlsmigration + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.data.client.ClientRepository +import com.wire.kalium.logic.data.featureConfig.MLSMigrationModel +import com.wire.kalium.logic.data.featureConfig.Status +import com.wire.kalium.logic.data.sync.IncrementalSyncRepository +import com.wire.kalium.logic.data.sync.IncrementalSyncStatus +import com.wire.kalium.logic.feature.TimestampKeyRepository +import com.wire.kalium.logic.feature.TimestampKeys +import com.wire.kalium.logic.featureFlags.FeatureSupport +import com.wire.kalium.logic.featureFlags.KaliumConfigs +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.flatMap +import com.wire.kalium.logic.functional.getOrElse +import com.wire.kalium.logic.functional.onFailure +import com.wire.kalium.logic.functional.onSuccess +import com.wire.kalium.logic.kaliumLogger +import com.wire.kalium.util.KaliumDispatcher +import com.wire.kalium.util.KaliumDispatcherImpl +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.Job +import kotlinx.coroutines.ensureActive +import kotlinx.coroutines.launch +import kotlinx.datetime.Clock + +/** + * Orchestrates the migration from proteus to MLS. + */ +internal interface MLSMigrationManager + +@Suppress("LongParameterList") +internal class MLSMigrationManagerImpl( + private val kaliumConfigs: KaliumConfigs, + private val featureSupport: FeatureSupport, + private val incrementalSyncRepository: IncrementalSyncRepository, + private val clientRepository: Lazy, + private val timestampKeyRepository: Lazy, + private val mlsMigrationWorker: Lazy, + kaliumDispatcher: KaliumDispatcher = KaliumDispatcherImpl +) : MLSMigrationManager { + /** + * A dispatcher with limited parallelism of 1. + * This means using this dispatcher only a single coroutine will be processed at a time. + */ + @OptIn(ExperimentalCoroutinesApi::class) + private val dispatcher = kaliumDispatcher.default.limitedParallelism(1) + + private val mlsMigrationScope = CoroutineScope(dispatcher) + + private var mlsMigrationJob: Job? = null + + init { + mlsMigrationJob = mlsMigrationScope.launch { + incrementalSyncRepository.incrementalSyncState.collect { syncState -> + ensureActive() + if (syncState is IncrementalSyncStatus.Live && + featureSupport.isMLSSupported && + clientRepository.value.hasRegisteredMLSClient().getOrElse(false) + ) { + updateMigration() + } + } + } + } + + private suspend fun updateMigration(): Either = + timestampKeyRepository.value.hasPassed( + TimestampKeys.LAST_MLS_MIGRATION_CHECK, + kaliumConfigs.mlsMigrationInterval + ).flatMap { lastMlsMigrationCheckHasPassed -> + kaliumLogger.d("Migration needs to be updated: $lastMlsMigrationCheckHasPassed") + if (lastMlsMigrationCheckHasPassed) { + kaliumLogger.d("Running mls migration") + mlsMigrationWorker.value.runMigration() + .onSuccess { + kaliumLogger.d("Successfully advanced the mls migration") + timestampKeyRepository.value.reset(TimestampKeys.LAST_MLS_MIGRATION_CHECK) + } + .onFailure { + kaliumLogger.d("Failure while advancing the mls migration: $it") + } + } + Either.Right(Unit) + } +} + +fun MLSMigrationModel.hasMigrationStarted(): Boolean { + return status == Status.ENABLED && startTime?.let { Clock.System.now() > it } ?: false +} + +fun MLSMigrationModel.hasMigrationEnded(): Boolean { + return status == Status.ENABLED && endTime?.let { Clock.System.now() > it } ?: false +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorker.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorker.kt new file mode 100644 index 00000000000..c481ba50771 --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorker.kt @@ -0,0 +1,68 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.mlsmigration + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.configuration.UserConfigRepository +import com.wire.kalium.logic.data.featureConfig.FeatureConfigRepository +import com.wire.kalium.logic.feature.featureConfig.handler.MLSConfigHandler +import com.wire.kalium.logic.feature.featureConfig.handler.MLSMigrationConfigHandler +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.flatMap +import com.wire.kalium.logic.functional.getOrNull +import com.wire.kalium.logic.kaliumLogger + +interface MLSMigrationWorker { + suspend fun runMigration(): Either +} + +internal class MLSMigrationWorkerImpl( + private val userConfigRepository: UserConfigRepository, + private val featureConfigRepository: FeatureConfigRepository, + private val mlsConfigHandler: MLSConfigHandler, + private val mlsMigrationConfigHandler: MLSMigrationConfigHandler, + private val mlsMigrator: MLSMigrator, +) : MLSMigrationWorker { + + override suspend fun runMigration() = + syncMigrationConfigurations().flatMap { + userConfigRepository.getMigrationConfiguration().getOrNull()?.let { configuration -> + if (configuration.hasMigrationStarted()) { + kaliumLogger.i("Running proteus to MLS migration") + mlsMigrator.migrateProteusConversations().flatMap { + if (configuration.hasMigrationEnded()) { + mlsMigrator.finaliseAllProteusConversations() + } else { + mlsMigrator.finaliseProteusConversations() + } + } + } else { + kaliumLogger.i("MLS migration is not enabled") + Either.Right(Unit) + } + } ?: Either.Right(Unit) + } + + private suspend fun syncMigrationConfigurations(): Either = + featureConfigRepository.getFeatureConfigs().flatMap { configurations -> + mlsConfigHandler.handle(configurations.mlsModel, duringSlowSync = false) + .flatMap { configurations.mlsMigrationModel?.let { + mlsMigrationConfigHandler.handle(configurations.mlsMigrationModel, duringSlowSync = false) + } ?: Either.Right(Unit) } + } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrator.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrator.kt new file mode 100644 index 00000000000..272cffa6c9c --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrator.kt @@ -0,0 +1,138 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.mlsmigration + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.StorageFailure +import com.wire.kalium.logic.data.conversation.Conversation +import com.wire.kalium.logic.data.conversation.Conversation.Protocol +import com.wire.kalium.logic.data.conversation.ConversationRepository +import com.wire.kalium.logic.data.conversation.MLSConversationRepository +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.data.message.SystemMessageInserter +import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.data.user.UserRepository +import com.wire.kalium.logic.feature.SelfTeamIdProvider +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.flatMap +import com.wire.kalium.logic.functional.flatMapLeft +import com.wire.kalium.logic.functional.fold +import com.wire.kalium.logic.functional.foldToEitherWhileRight +import com.wire.kalium.logic.kaliumLogger + +interface MLSMigrator { + suspend fun migrateProteusConversations(): Either + suspend fun finaliseProteusConversations(): Either + suspend fun finaliseAllProteusConversations(): Either +} +internal class MLSMigratorImpl( + private val selfUserId: UserId, + private val selfTeamIdProvider: SelfTeamIdProvider, + private val userRepository: UserRepository, + private val conversationRepository: ConversationRepository, + private val mlsConversationRepository: MLSConversationRepository, + private val systemMessageInserter: SystemMessageInserter +) : MLSMigrator { + + override suspend fun migrateProteusConversations(): Either = + selfTeamIdProvider().flatMap { + it?.let { Either.Right(it) } ?: Either.Left(StorageFailure.DataNotFound) + }.flatMap { teamId -> + conversationRepository.getConversationIds(Conversation.Type.GROUP, Protocol.PROTEUS, teamId) + .flatMap { + it.foldToEitherWhileRight(Unit) { conversationId, _ -> + migrate(conversationId) + } + } + } + + override suspend fun finaliseAllProteusConversations(): Either = + selfTeamIdProvider().flatMap { + it?.let { Either.Right(it) } ?: Either.Left(StorageFailure.DataNotFound) + }.flatMap { teamId -> + conversationRepository.getConversationIds(Conversation.Type.GROUP, Protocol.MIXED, teamId) + .flatMap { + it.foldToEitherWhileRight(Unit) { conversationId, _ -> + finalise(conversationId) + } + } + } + + override suspend fun finaliseProteusConversations(): Either = + selfTeamIdProvider().flatMap { + it?.let { Either.Right(it) } ?: Either.Left(StorageFailure.DataNotFound) + }.flatMap { teamId -> + userRepository.fetchAllOtherUsers() + .flatMap { + conversationRepository.getTeamConversationIdsReadyToCompleteMigration(teamId) + .flatMap { + it.foldToEitherWhileRight(Unit) { conversationId, _ -> + finalise(conversationId) + } + } + } + } + + private suspend fun migrate(conversationId: ConversationId): Either { + kaliumLogger.i("migrating ${conversationId.toLogString()} to mixed") + return conversationRepository.updateProtocolRemotely(conversationId, Protocol.MIXED) + .flatMap { updated -> + if (updated) { + systemMessageInserter.insertProtocolChangedSystemMessage( + conversationId, selfUserId, Protocol.MIXED + ) + } + establishConversation(conversationId) + }.flatMapLeft { + kaliumLogger.w("failed to migrate ${conversationId.toLogString()} to mixed: $it") + Either.Right(Unit) + } + } + + private suspend fun finalise(conversationId: ConversationId): Either { + kaliumLogger.i("finalising ${conversationId.toLogString()} to mls") + return conversationRepository.updateProtocolRemotely(conversationId, Protocol.MLS) + .fold({ failure -> + kaliumLogger.w("failed to finalise ${conversationId.toLogString()} to mls: $failure") + Either.Right(Unit) + }, { updated -> + if (updated) { + systemMessageInserter.insertProtocolChangedSystemMessage( + conversationId, selfUserId, Protocol.MLS + ) + } + Either.Right(Unit) + }) + } + + private suspend fun establishConversation(conversationId: ConversationId) = + conversationRepository.getConversationProtocolInfo(conversationId) + .flatMap { protocolInfo -> + when (protocolInfo) { + is Conversation.ProtocolInfo.Mixed -> { + mlsConversationRepository.establishMLSGroup(protocolInfo.groupId, emptyList()) + .flatMap { + conversationRepository.getConversationMembers(conversationId).flatMap { members -> + mlsConversationRepository.addMemberToMLSGroup(protocolInfo.groupId, members) + } + } + } + else -> Either.Right(Unit) + } + } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelector.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelector.kt new file mode 100644 index 00000000000..422cf17a9d6 --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelector.kt @@ -0,0 +1,52 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.protocol + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.data.user.SupportedProtocol +import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.data.user.UserRepository +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.flatMap + +internal interface OneOnOneProtocolSelector { + suspend fun getProtocolForUser(userId: UserId): Either +} + +internal class OneOnOneProtocolSelectorImpl( + private val userRepository: UserRepository +) : OneOnOneProtocolSelector { + override suspend fun getProtocolForUser(userId: UserId): Either = + userRepository.userById(userId).flatMap { otherUser -> + val selfUser = userRepository.getSelfUser() ?: run { + val error = NullPointerException("Self user unobtainable when selecting protocol for user") + return@flatMap Either.Left(CoreFailure.Unknown(error)) + } + + val selfUserProtocols = selfUser.supportedProtocols.orEmpty() + val otherUserProtocols = otherUser.supportedProtocols.orEmpty() + + val commonProtocols = selfUserProtocols.intersect(otherUserProtocols) + + return when { + commonProtocols.contains(SupportedProtocol.MLS) -> Either.Right(SupportedProtocol.MLS) + commonProtocols.contains(SupportedProtocol.PROTEUS) -> Either.Right(SupportedProtocol.PROTEUS) + else -> Either.Left(CoreFailure.NoCommonProtocolFound) + } + } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsAndResolveOneOnOnesUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsAndResolveOneOnOnesUseCase.kt new file mode 100644 index 00000000000..b42ff81ecc3 --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsAndResolveOneOnOnesUseCase.kt @@ -0,0 +1,52 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.user + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.feature.conversation.mls.OneOnOneResolver +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.flatMap + +/** + * Update self supported protocols, and if the supported protocols + * did change we also resolve the active protocol for all one-on-one + * conversations. + */ +interface UpdateSupportedProtocolsAndResolveOneOnOnesUseCase { + + /** + * @param synchroniseUsers if true we synchronize all known users from backend + * in order to have to up-to-date information about which protocols are supported. + */ + suspend operator fun invoke(synchroniseUsers: Boolean): Either +} + +class UpdateSupportedProtocolsAndResolveOneOnOnesUseCaseImpl( + private val updateSupportedProtocols: UpdateSupportedProtocolsUseCase, + private val oneOnOneResolver: OneOnOneResolver +) : UpdateSupportedProtocolsAndResolveOneOnOnesUseCase { + + override suspend operator fun invoke(synchroniseUsers: Boolean) = + updateSupportedProtocols().flatMap { updated -> + if (updated) { + oneOnOneResolver.resolveAllOneOnOneConversations(synchronizeUsers = synchroniseUsers) + } else { + Either.Right(Unit) + } + } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsUseCase.kt new file mode 100644 index 00000000000..d25be6ee326 --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsUseCase.kt @@ -0,0 +1,138 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.user + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.StorageFailure +import com.wire.kalium.logic.configuration.UserConfigRepository +import com.wire.kalium.logic.data.client.Client +import com.wire.kalium.logic.data.client.ClientRepository +import com.wire.kalium.logic.data.client.isActive +import com.wire.kalium.logic.data.featureConfig.MLSMigrationModel +import com.wire.kalium.logic.data.featureConfig.Status +import com.wire.kalium.logic.data.user.SupportedProtocol +import com.wire.kalium.logic.data.user.UserRepository +import com.wire.kalium.logic.feature.mlsmigration.hasMigrationEnded +import com.wire.kalium.logic.featureFlags.FeatureSupport +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.flatMap +import com.wire.kalium.logic.functional.flatMapLeft +import com.wire.kalium.logic.functional.map +import com.wire.kalium.logic.kaliumLogger +import kotlinx.datetime.Instant + +/** + * Updates the supported protocols of the current user. + */ +interface UpdateSupportedProtocolsUseCase { + suspend operator fun invoke(): Either +} + +internal class UpdateSupportedProtocolsUseCaseImpl( + private val clientsRepository: ClientRepository, + private val userRepository: UserRepository, + private val userConfigRepository: UserConfigRepository, + private val featureSupport: FeatureSupport +) : UpdateSupportedProtocolsUseCase { + + override suspend operator fun invoke(): Either { + return if (!featureSupport.isMLSSupported) { + kaliumLogger.d("Skip updating supported protocols, since MLS is not supported.") + Either.Right(false) + } else { + (userRepository.getSelfUser()?.let { selfUser -> + selfSupportedProtocols().flatMap { newSupportedProtocols -> + kaliumLogger.i( + "Updating supported protocols = $newSupportedProtocols previously = ${selfUser.supportedProtocols}" + ) + if (newSupportedProtocols != selfUser.supportedProtocols) { + userRepository.updateSupportedProtocols(newSupportedProtocols).map { true } + } else { + Either.Right(false) + } + }.flatMapLeft { + when (it) { + is StorageFailure.DataNotFound -> { + kaliumLogger.w( + "Skip updating supported protocols since additional protocols are not configured" + ) + Either.Right(false) + } + else -> Either.Left(it) + } + } + } ?: Either.Left(StorageFailure.DataNotFound)) + } + } + + private suspend fun selfSupportedProtocols(): Either> = + clientsRepository.selfListOfClients().flatMap { selfClients -> + userConfigRepository.getMigrationConfiguration() + .flatMapLeft { if (it is StorageFailure.DataNotFound) Either.Right(MIGRATION_CONFIGURATION_DISABLED) else Either.Left(it) } + .flatMap { migrationConfiguration -> + userConfigRepository.getSupportedProtocols().map { supportedProtocols -> + val selfSupportedProtocols = mutableSetOf() + if (proteusIsSupported(supportedProtocols, migrationConfiguration)) { + selfSupportedProtocols.add(SupportedProtocol.PROTEUS) + } + + if (mlsIsSupported(supportedProtocols, migrationConfiguration, selfClients)) { + selfSupportedProtocols.add(SupportedProtocol.MLS) + } + selfSupportedProtocols + } + } + } + + private fun mlsIsSupported( + supportedProtocols: Set, + migrationConfiguration: MLSMigrationModel, + selfClients: List + ): Boolean { + val mlsIsSupported = supportedProtocols.contains(SupportedProtocol.MLS) + val mlsMigrationHasEnded = migrationConfiguration.hasMigrationEnded() + val allSelfClientsAreMLSCapable = selfClients.filter { it.isActive }.all { it.isMLSCapable } + kaliumLogger.d( + "mls is supported = $mlsIsSupported, " + + "all active self clients are mls capable = $allSelfClientsAreMLSCapable " + + "migration has ended = $mlsMigrationHasEnded" + ) + return mlsIsSupported && (mlsMigrationHasEnded || allSelfClientsAreMLSCapable) + } + + private fun proteusIsSupported( + supportedProtocols: Set, + migrationConfiguration: MLSMigrationModel + ): Boolean { + val proteusIsSupported = supportedProtocols.contains(SupportedProtocol.PROTEUS) + val mlsMigrationHasEnded = migrationConfiguration.hasMigrationEnded() + kaliumLogger.d( + "proteus is supported = $proteusIsSupported, " + + "migration has ended = $mlsMigrationHasEnded" + ) + return proteusIsSupported || !mlsMigrationHasEnded + } + + companion object { + val MIGRATION_CONFIGURATION_DISABLED = MLSMigrationModel( + startTime = Instant.DISTANT_FUTURE, + endTime = Instant.DISTANT_FUTURE, + status = Status.DISABLED + ) + } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UserScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UserScope.kt index 467b5a50b85..a0cad86fe92 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UserScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UserScope.kt @@ -82,8 +82,9 @@ class UserScope internal constructor( private val userPropertyRepository: UserPropertyRepository, private val messageSender: MessageSender, private val clientIdProvider: CurrentClientIdProvider, + private val e2EIRepository: E2EIRepository, private val isSelfATeamMember: IsSelfATeamMemberUseCase, - private val e2EIRepository: E2EIRepository + private val updateSupportedProtocolsUseCase: UpdateSupportedProtocolsUseCase, ) { private val validateUserHandleUseCase: ValidateUserHandleUseCase get() = ValidateUserHandleUseCaseImpl() val getSelfUser: GetSelfUserUseCase get() = GetSelfUserUseCaseImpl(userRepository) @@ -155,4 +156,6 @@ class UserScope internal constructor( val getAssetSizeLimit: GetAssetSizeLimitUseCase get() = GetAssetSizeLimitUseCaseImpl(isSelfATeamMember) val deleteAccount: DeleteAccountUseCase get() = DeleteAccountUseCase(accountRepository) + + val updateSupportedProtocols: UpdateSupportedProtocolsUseCase get() = updateSupportedProtocolsUseCase } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/featureFlags/KaliumConfigs.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/featureFlags/KaliumConfigs.kt index e14a6fd9d44..269a5e25572 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/featureFlags/KaliumConfigs.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/featureFlags/KaliumConfigs.kt @@ -20,6 +20,8 @@ package com.wire.kalium.logic.featureFlags import com.wire.kalium.logic.util.KaliumMockEngine import com.wire.kalium.network.NetworkStateObserver +import kotlin.time.Duration +import kotlin.time.Duration.Companion.hours data class KaliumConfigs( val forceConstantBitrateCalls: Boolean = false, @@ -40,7 +42,9 @@ data class KaliumConfigs( val isWebSocketEnabledByDefault: Boolean = false, val certPinningConfig: Map> = emptyMap(), val kaliumMockEngine: KaliumMockEngine? = null, - val mockNetworkStateObserver: NetworkStateObserver? = null + val mockNetworkStateObserver: NetworkStateObserver? = null, + // Interval between attempts to advance the proteus to MLS migration + val mlsMigrationInterval: Duration = 24.hours ) sealed interface BuildFileRestrictionState { diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/network/SessionManagerImpl.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/network/SessionManagerImpl.kt index 6f735922bf8..06e4c367df5 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/network/SessionManagerImpl.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/network/SessionManagerImpl.kt @@ -115,6 +115,7 @@ class SessionManagerImpl internal constructor( is NetworkFailure.NoNetworkConnection -> null is NetworkFailure.ProxyError -> null is NetworkFailure.FederatedBackendFailure -> null + is NetworkFailure.FeatureNotSupported -> null is NetworkFailure.ServerMiscommunication -> { onServerMissCommunication(it) null diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/incremental/EventProcessor.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/incremental/EventProcessor.kt index 892fb75d5f0..ebc747a0c04 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/incremental/EventProcessor.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/incremental/EventProcessor.kt @@ -40,6 +40,12 @@ import com.wire.kalium.util.serialization.toJsonElement * @see [Event] */ internal interface EventProcessor { + + /** + * When enabled events will be consumed but no event processing will occur. + */ + var disableEventProcessing: Boolean + /** * Process the [event], and persist the last processed event ID if the event * is not transient. @@ -66,23 +72,31 @@ internal class EventProcessorImpl( kaliumLogger.withFeatureId(EVENT_RECEIVER) } - override suspend fun processEvent(event: Event): Either = when (event) { - is Event.Conversation -> conversationEventReceiver.onEvent(event) - is Event.User -> userEventReceiver.onEvent(event) - is Event.FeatureConfig -> featureConfigEventReceiver.onEvent(event) - is Event.Unknown -> { - kaliumLogger - .logEventProcessing( - EventLoggingStatus.SKIPPED, - event - ) - // Skipping event = success + override var disableEventProcessing: Boolean = false + + override suspend fun processEvent(event: Event): Either = + if (disableEventProcessing) { + logger.w("Skipping processing of $event due to debug option") Either.Right(Unit) - } + } else { + when (event) { + is Event.Conversation -> conversationEventReceiver.onEvent(event) + is Event.User -> userEventReceiver.onEvent(event) + is Event.FeatureConfig -> featureConfigEventReceiver.onEvent(event) + is Event.Unknown -> { + kaliumLogger + .logEventProcessing( + EventLoggingStatus.SKIPPED, + event + ) + // Skipping event = success + Either.Right(Unit) + } - is Event.Team -> teamEventReceiver.onEvent(event) - is Event.UserProperty -> userPropertiesEventReceiver.onEvent(event) - is Event.Federation -> federationEventReceiver.onEvent(event) + is Event.Team -> teamEventReceiver.onEvent(event) + is Event.UserProperty -> userPropertiesEventReceiver.onEvent(event) + is Event.Federation -> federationEventReceiver.onEvent(event) + } }.onSuccess { val logMap = mapOf( "event" to event.toLogMap() diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/ConversationEventReceiver.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/ConversationEventReceiver.kt index 10069b19a1e..3a8e935ee16 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/ConversationEventReceiver.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/ConversationEventReceiver.kt @@ -28,6 +28,7 @@ import com.wire.kalium.logic.sync.receiver.conversation.MemberChangeEventHandler import com.wire.kalium.logic.sync.receiver.conversation.MemberJoinEventHandler import com.wire.kalium.logic.sync.receiver.conversation.MemberLeaveEventHandler import com.wire.kalium.logic.sync.receiver.conversation.NewConversationEventHandler +import com.wire.kalium.logic.sync.receiver.conversation.ProtocolUpdateEventHandler import com.wire.kalium.logic.sync.receiver.conversation.ReceiptModeUpdateEventHandler import com.wire.kalium.logic.sync.receiver.conversation.RenamedConversationEventHandler import com.wire.kalium.logic.sync.receiver.conversation.message.NewMessageEventHandler @@ -53,7 +54,8 @@ internal class ConversationEventReceiverImpl( private val conversationMessageTimerEventHandler: ConversationMessageTimerEventHandler, private val codeUpdatedHandler: CodeUpdatedHandler, private val codeDeletedHandler: CodeDeletedHandler, - private val typingIndicatorHandler: TypingIndicatorHandler + private val typingIndicatorHandler: TypingIndicatorHandler, + private val protocolUpdateEventHandler: ProtocolUpdateEventHandler ) : ConversationEventReceiver { override suspend fun onEvent(event: Event.Conversation): Either { // TODO: Make sure errors are accounted for by each handler. @@ -114,6 +116,9 @@ internal class ConversationEventReceiverImpl( is Event.Conversation.CodeDeleted -> codeDeletedHandler.handle(event) is Event.Conversation.CodeUpdated -> codeUpdatedHandler.handle(event) is Event.Conversation.TypingIndicator -> typingIndicatorHandler.handle(event) + is Event.Conversation.ConversationProtocol -> { + protocolUpdateEventHandler.handle(event) + } } } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/FeatureConfigEventReceiver.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/FeatureConfigEventReceiver.kt index e4690dc3b44..ac5c525e3ee 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/FeatureConfigEventReceiver.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/FeatureConfigEventReceiver.kt @@ -29,6 +29,7 @@ import com.wire.kalium.logic.feature.featureConfig.handler.E2EIConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.FileSharingConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.GuestRoomConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.MLSConfigHandler +import com.wire.kalium.logic.feature.featureConfig.handler.MLSMigrationConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.SecondFactorPasswordChallengeConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.SelfDeletingMessagesConfigHandler import com.wire.kalium.logic.functional.Either @@ -43,6 +44,7 @@ internal class FeatureConfigEventReceiverImpl internal constructor( private val guestRoomConfigHandler: GuestRoomConfigHandler, private val fileSharingConfigHandler: FileSharingConfigHandler, private val mlsConfigHandler: MLSConfigHandler, + private val mlsMigrationConfigHandler: MLSMigrationConfigHandler, private val classifiedDomainsConfigHandler: ClassifiedDomainsConfigHandler, private val conferenceCallingConfigHandler: ConferenceCallingConfigHandler, private val passwordChallengeConfigHandler: SecondFactorPasswordChallengeConfigHandler, @@ -84,7 +86,8 @@ internal class FeatureConfigEventReceiverImpl internal constructor( private suspend fun handleFeatureConfigEvent(event: Event.FeatureConfig): Either = when (event) { is Event.FeatureConfig.FileSharingUpdated -> fileSharingConfigHandler.handle(event.model) - is Event.FeatureConfig.MLSUpdated -> mlsConfigHandler.handle(event.model) + is Event.FeatureConfig.MLSUpdated -> mlsConfigHandler.handle(event.model, duringSlowSync = false) + is Event.FeatureConfig.MLSMigrationUpdated -> mlsMigrationConfigHandler.handle(event.model, duringSlowSync = false) is Event.FeatureConfig.ClassifiedDomainsUpdated -> classifiedDomainsConfigHandler.handle(event.model) is Event.FeatureConfig.ConferenceCallingUpdated -> conferenceCallingConfigHandler.handle(event.model) is Event.FeatureConfig.GuestRoomLinkUpdated -> guestRoomConfigHandler.handle(event.model) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/TeamEventReceiver.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/TeamEventReceiver.kt index 35477bfd65d..0dad1ff5726 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/TeamEventReceiver.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/TeamEventReceiver.kt @@ -92,13 +92,13 @@ internal class TeamEventReceiverImpl( .onSuccess { val knownUser = userRepository.getKnownUser(userId).first() if (knownUser?.name != null) { - conversationRepository.getConversationIdsByUserId(userId) + conversationRepository.getConversationsByUserId(userId) .onSuccess { - it.forEach { conversationId -> + it.forEach { conversation -> val message = Message.System( id = uuid4().toString(), // We generate a random uuid for this new system message content = MessageContent.TeamMemberRemoved(knownUser.name), - conversationId = conversationId, + conversationId = conversation.id, date = event.timestampIso, senderUserId = userId, status = Message.Status.Sent, diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/UserEventReceiver.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/UserEventReceiver.kt index fd3e2b96e1f..ae2b70d341b 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/UserEventReceiver.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/UserEventReceiver.kt @@ -27,16 +27,21 @@ import com.wire.kalium.logic.data.event.Event import com.wire.kalium.logic.data.event.EventLoggingStatus import com.wire.kalium.logic.data.event.logEventProcessing import com.wire.kalium.logic.data.logout.LogoutReason +import com.wire.kalium.logic.data.user.ConnectionState import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.data.user.UserRepository import com.wire.kalium.logic.feature.CurrentClientIdProvider import com.wire.kalium.logic.feature.auth.LogoutUseCase +import com.wire.kalium.logic.feature.conversation.mls.OneOnOneResolver import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMapLeft +import com.wire.kalium.logic.functional.flatMap import com.wire.kalium.logic.functional.map import com.wire.kalium.logic.functional.onFailure import com.wire.kalium.logic.functional.onSuccess import com.wire.kalium.logic.kaliumLogger +import kotlin.time.Duration.Companion.ZERO +import kotlin.time.Duration.Companion.seconds internal interface UserEventReceiver : EventReceiver @@ -47,8 +52,9 @@ internal class UserEventReceiverImpl internal constructor( private val conversationRepository: ConversationRepository, private val userRepository: UserRepository, private val logout: LogoutUseCase, + private val oneOnOneResolver: OneOnOneResolver, private val selfUserId: UserId, - private val currentClientIdProvider: CurrentClientIdProvider, + private val currentClientIdProvider: CurrentClientIdProvider ) : UserEventReceiver { override suspend fun onEvent(event: Event.User): Either { @@ -89,7 +95,21 @@ internal class UserEventReceiverImpl internal constructor( } private suspend fun handleNewConnection(event: Event.User.NewConnection): Either = - connectionRepository.insertConnectionFromEvent(event) + userRepository.fetchUserInfo(event.connection.qualifiedToId) + .flatMap { + connectionRepository.insertConnectionFromEvent(event) + .flatMap { + if (event.connection.status != ConnectionState.ACCEPTED) { + return@flatMap Either.Right(Unit) + } + + oneOnOneResolver.scheduleResolveOneOnOneConversationWithUserId( + event.connection.qualifiedToId, + delay = if (event.live) 3.seconds else ZERO + ) + Either.Right(Unit) + } + } .onSuccess { kaliumLogger .logEventProcessing( diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MLSWelcomeEventHandler.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MLSWelcomeEventHandler.kt index d364a86a847..3779bfd3bf6 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MLSWelcomeEventHandler.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MLSWelcomeEventHandler.kt @@ -18,33 +18,37 @@ package com.wire.kalium.logic.sync.receiver.conversation -import com.wire.kalium.logger.obfuscateId +import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.data.client.MLSClientProvider +import com.wire.kalium.logic.data.conversation.Conversation +import com.wire.kalium.logic.data.conversation.ConversationDetails import com.wire.kalium.logic.data.conversation.ConversationRepository import com.wire.kalium.logic.data.event.Event import com.wire.kalium.logic.data.event.EventLoggingStatus import com.wire.kalium.logic.data.event.logEventProcessing +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.data.id.GroupID +import com.wire.kalium.logic.feature.conversation.mls.OneOnOneResolver +import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMap import com.wire.kalium.logic.functional.map import com.wire.kalium.logic.functional.onFailure import com.wire.kalium.logic.functional.onSuccess import com.wire.kalium.logic.kaliumLogger import com.wire.kalium.logic.wrapMLSRequest -import com.wire.kalium.logic.wrapStorageRequest -import com.wire.kalium.persistence.dao.conversation.ConversationDAO -import com.wire.kalium.persistence.dao.conversation.ConversationEntity import io.ktor.util.decodeBase64Bytes +import kotlinx.coroutines.flow.first interface MLSWelcomeEventHandler { - suspend fun handle(event: Event.Conversation.MLSWelcome) + suspend fun handle(event: Event.Conversation.MLSWelcome): Either } internal class MLSWelcomeEventHandlerImpl( val mlsClientProvider: MLSClientProvider, - val conversationDAO: ConversationDAO, - val conversationRepository: ConversationRepository + val conversationRepository: ConversationRepository, + val oneOnOneResolver: OneOnOneResolver ) : MLSWelcomeEventHandler { - override suspend fun handle(event: Event.Conversation.MLSWelcome) { + override suspend fun handle(event: Event.Conversation.MLSWelcome): Either = mlsClientProvider .getMLSClient() .flatMap { client -> @@ -52,28 +56,40 @@ internal class MLSWelcomeEventHandlerImpl( client.processWelcomeMessage(event.message.decodeBase64Bytes()) } }.flatMap { groupID -> - val groupIdLogPair = Pair("groupId", groupID.obfuscateId()) - - wrapStorageRequest { - conversationRepository.fetchConversationIfUnknown(event.conversationId).map { - conversationDAO.updateConversationGroupState(ConversationEntity.GroupState.ESTABLISHED, groupID) + conversationRepository.fetchConversationIfUnknown(event.conversationId) + .flatMap { + markConversationAsEstablished(GroupID(groupID)) + }.flatMap { + resolveConversationIfOneOnOne(event.conversationId) } - }.onSuccess { - kaliumLogger - .logEventProcessing( - EventLoggingStatus.SUCCESS, - event, - Pair("info", "Established mls conversation from welcome message"), - groupIdLogPair - ) - }.onFailure { - kaliumLogger - .logEventProcessing( - EventLoggingStatus.FAILURE, - event, - groupIdLogPair - ) + }.onSuccess { + kaliumLogger + .logEventProcessing( + EventLoggingStatus.SUCCESS, + event, + Pair("info", "Established mls conversation from welcome message") + ) + }.onFailure { + kaliumLogger + .logEventProcessing( + EventLoggingStatus.FAILURE, + event, + Pair("failure", it) + ) + } + + private suspend fun markConversationAsEstablished(groupID: GroupID): Either = + conversationRepository.updateConversationGroupState(groupID, Conversation.ProtocolInfo.MLSCapable.GroupState.ESTABLISHED) + + private suspend fun resolveConversationIfOneOnOne(conversationId: ConversationId): Either = + conversationRepository.observeConversationDetailsById(conversationId) + .first() + .flatMap { + if (it is ConversationDetails.OneOne) { + oneOnOneResolver.resolveOneOnOneConversationWithUser(it.otherUser).map { Unit } + } else { + Either.Right(Unit) } } - } + } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/ProtocolUpdateEventHandler.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/ProtocolUpdateEventHandler.kt new file mode 100644 index 00000000000..c5d8b69188e --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/ProtocolUpdateEventHandler.kt @@ -0,0 +1,69 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ + +package com.wire.kalium.logic.sync.receiver.conversation + +import com.wire.kalium.logger.KaliumLogger +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.data.conversation.ConversationRepository +import com.wire.kalium.logic.data.event.Event +import com.wire.kalium.logic.data.event.EventLoggingStatus +import com.wire.kalium.logic.data.event.logEventProcessing +import com.wire.kalium.logic.data.message.SystemMessageInserter +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.map +import com.wire.kalium.logic.functional.onFailure +import com.wire.kalium.logic.functional.onSuccess +import com.wire.kalium.logic.kaliumLogger + +interface ProtocolUpdateEventHandler { + suspend fun handle(event: Event.Conversation.ConversationProtocol): Either +} + +internal class ProtocolUpdateEventHandlerImpl( + private val conversationRepository: ConversationRepository, + private val systemMessageInserter: SystemMessageInserter +) : ProtocolUpdateEventHandler { + + private val logger by lazy { kaliumLogger.withFeatureId(KaliumLogger.Companion.ApplicationFlow.EVENT_RECEIVER) } + + override suspend fun handle(event: Event.Conversation.ConversationProtocol): Either = + conversationRepository.updateProtocolLocally(event.conversationId, event.protocol) + .onSuccess { updated -> + if (updated) { + systemMessageInserter.insertProtocolChangedSystemMessage( + event.conversationId, + event.senderUserId, + event.protocol + ) + } + logger + .logEventProcessing( + EventLoggingStatus.SUCCESS, + event + ) + } + .onFailure { coreFailure -> + logger + .logEventProcessing( + EventLoggingStatus.FAILURE, + event, + Pair("errorInfo", "$coreFailure") + ) + }.map { } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageFailureHandler.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageFailureHandler.kt index 9ba841c7bc4..85bcddeaa35 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageFailureHandler.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageFailureHandler.kt @@ -29,11 +29,13 @@ sealed class MLSMessageFailureResolution { internal object MLSMessageFailureHandler { fun handleFailure(failure: CoreFailure): MLSMessageFailureResolution { return when (failure) { - // Received messages targeting a future epoch, we might have lost messages. + // Received messages targeting a future epoch (outside epoch bounds), we might have lost messages. is MLSFailure.WrongEpoch -> MLSMessageFailureResolution.OutOfSync // Received already sent or received message, can safely be ignored. is MLSFailure.DuplicateMessage -> MLSMessageFailureResolution.Ignore - // Received self commit, any unmerged group has know when merged by CoreCrypto. + // Received message was targeting a future epoch and been buffered, can safely be ignored. + is MLSFailure.BufferedFutureMessage -> MLSMessageFailureResolution.Ignore + // Received self commit, any unmerged group has know been when merged by CoreCrypto. is MLSFailure.SelfCommitIgnored -> MLSMessageFailureResolution.Ignore // Message arrive in an unmerged group, it has been buffered and will be consumed later. is MLSFailure.UnmergedPendingGroup -> MLSMessageFailureResolution.Ignore diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageUnpacker.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageUnpacker.kt index c9c4e01f919..37b381b6182 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageUnpacker.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageUnpacker.kt @@ -122,7 +122,7 @@ internal class MLSMessageUnpackerImpl( mlsConversationRepository.decryptMessage(messageEvent.content.decodeBase64Bytes(), groupID) } } ?: conversationRepository.getConversationProtocolInfo(messageEvent.conversationId).flatMap { protocolInfo -> - if (protocolInfo is Conversation.ProtocolInfo.MLS) { + if (protocolInfo is Conversation.ProtocolInfo.MLSCapable) { logger.d( "Decrypting MLS for " + "converationId = ${messageEvent.conversationId.value.obfuscateId()} " + @@ -130,7 +130,7 @@ internal class MLSMessageUnpackerImpl( ) mlsConversationRepository.decryptMessage(messageEvent.content.decodeBase64Bytes(), protocolInfo.groupId) } else { - Either.Right(emptyList()) + Either.Left(CoreFailure.NotSupportedByProteus) } } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandler.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandler.kt deleted file mode 100644 index a49d8d850d7..00000000000 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandler.kt +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Wire - * Copyright (C) 2023 Wire Swiss GmbH - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see http://www.gnu.org/licenses/. - */ -package com.wire.kalium.logic.sync.receiver.conversation.message - -import com.benasher44.uuid.uuid4 -import com.wire.kalium.logger.KaliumLogger -import com.wire.kalium.logic.CoreFailure -import com.wire.kalium.logic.MLSFailure -import com.wire.kalium.logic.data.conversation.Conversation -import com.wire.kalium.logic.data.conversation.ConversationRepository -import com.wire.kalium.logic.data.id.ConversationId -import com.wire.kalium.logic.data.message.Message -import com.wire.kalium.logic.data.message.MessageContent -import com.wire.kalium.logic.data.message.PersistMessageUseCase -import com.wire.kalium.logic.data.user.UserId -import com.wire.kalium.logic.feature.conversation.JoinExistingMLSConversationUseCase -import com.wire.kalium.logic.functional.Either -import com.wire.kalium.logic.functional.flatMap -import com.wire.kalium.logic.functional.map -import com.wire.kalium.logic.kaliumLogger - -interface MLSWrongEpochHandler { - suspend fun onMLSWrongEpoch( - conversationId: ConversationId, - dateIso: String, - ) -} - -internal class MLSWrongEpochHandlerImpl( - private val selfUserId: UserId, - private val persistMessage: PersistMessageUseCase, - private val conversationRepository: ConversationRepository, - private val joinExistingMLSConversation: JoinExistingMLSConversationUseCase -) : MLSWrongEpochHandler { - - private val logger by lazy { kaliumLogger.withFeatureId(KaliumLogger.Companion.ApplicationFlow.EVENT_RECEIVER) } - - override suspend fun onMLSWrongEpoch( - conversationId: ConversationId, - dateIso: String, - ) { - logger.i("Handling MLS WrongEpoch result") - conversationRepository.getConversationProtocolInfo(conversationId).flatMap { protocol -> - if (protocol is Conversation.ProtocolInfo.MLS) { - Either.Right(protocol) - } else { - Either.Left(MLSFailure.ConversationDoesNotSupportMLS) - } - }.flatMap { currentProtocol -> - getUpdatedConversationEpoch(conversationId).map { updatedEpoch -> - updatedEpoch != null && updatedEpoch != currentProtocol.epoch - } - }.flatMap { isRejoinNeeded -> - if (isRejoinNeeded) { - joinExistingMLSConversation(conversationId) - } else Either.Right(Unit) - }.flatMap { - insertInfoMessage(conversationId, dateIso) - } - } - - private suspend fun getUpdatedConversationEpoch(conversationId: ConversationId): Either { - return conversationRepository.fetchConversation(conversationId).flatMap { - conversationRepository.getConversationProtocolInfo(conversationId) - }.map { updatedProtocol -> - (updatedProtocol as? Conversation.ProtocolInfo.MLS)?.epoch - } - } - - private suspend fun insertInfoMessage(conversationId: ConversationId, dateIso: String): Either { - val mlsEpochWarningMessage = Message.System( - id = uuid4().toString(), - content = MessageContent.MLSWrongEpochWarning, - conversationId = conversationId, - date = dateIso, - senderUserId = selfUserId, - status = Message.Status.Read(0), - visibility = Message.Visibility.VISIBLE, - senderUserName = null, - expirationData = null - ) - return persistMessage(mlsEpochWarningMessage) - } -} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandler.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandler.kt index 21d378b54a3..a67df1e6ac7 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandler.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandler.kt @@ -28,10 +28,12 @@ import com.wire.kalium.logic.data.event.logEventProcessing import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.message.MessageContent import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.feature.message.StaleEpochVerifier import com.wire.kalium.logic.functional.onFailure import com.wire.kalium.logic.functional.onSuccess import com.wire.kalium.logic.kaliumLogger import com.wire.kalium.util.serialization.toJsonElement +import kotlinx.datetime.toInstant internal interface NewMessageEventHandler { suspend fun handleNewProteusMessage(event: Event.Conversation.NewMessage) @@ -44,7 +46,7 @@ internal class NewMessageEventHandlerImpl( private val applicationMessageHandler: ApplicationMessageHandler, private val enqueueSelfDeletion: (conversationId: ConversationId, messageId: String) -> Unit, private val selfUserId: UserId, - private val mlsWrongEpochHandler: MLSWrongEpochHandler + private val staleEpochVerifier: StaleEpochVerifier ) : NewMessageEventHandler { private val logger by lazy { kaliumLogger.withFeatureId(KaliumLogger.Companion.ApplicationFlow.EVENT_RECEIVER) } @@ -120,7 +122,7 @@ internal class NewMessageEventHandlerImpl( } is MLSMessageFailureResolution.OutOfSync -> { logger.i("Epoch out of sync error: ${logMap.toJsonElement()}") - mlsWrongEpochHandler.onMLSWrongEpoch(event.conversationId, event.timestampIso) + staleEpochVerifier.verifyEpoch(event.conversationId, event.timestampIso.toInstant()) } } }.onSuccess { diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/slow/SlowSyncWorker.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/slow/SlowSyncWorker.kt index b56cde27e3c..279f3a2a56f 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/slow/SlowSyncWorker.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/slow/SlowSyncWorker.kt @@ -25,13 +25,16 @@ import com.wire.kalium.logic.data.sync.SlowSyncStep import com.wire.kalium.logic.feature.connection.SyncConnectionsUseCase import com.wire.kalium.logic.feature.conversation.JoinExistingMLSConversationsUseCase import com.wire.kalium.logic.feature.conversation.SyncConversationsUseCase +import com.wire.kalium.logic.feature.conversation.mls.OneOnOneResolver import com.wire.kalium.logic.feature.featureConfig.SyncFeatureConfigsUseCase import com.wire.kalium.logic.feature.team.SyncSelfTeamUseCase import com.wire.kalium.logic.feature.user.SyncContactsUseCase import com.wire.kalium.logic.feature.user.SyncSelfUserUseCase +import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsUseCase import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMap import com.wire.kalium.logic.functional.isRight +import com.wire.kalium.logic.functional.map import com.wire.kalium.logic.functional.nullableFold import com.wire.kalium.logic.functional.onFailure import com.wire.kalium.logic.kaliumLogger @@ -56,11 +59,13 @@ internal class SlowSyncWorkerImpl( private val eventRepository: EventRepository, private val syncSelfUser: SyncSelfUserUseCase, private val syncFeatureConfigs: SyncFeatureConfigsUseCase, + private val updateSupportedProtocols: UpdateSupportedProtocolsUseCase, private val syncConversations: SyncConversationsUseCase, private val syncConnections: SyncConnectionsUseCase, private val syncSelfTeam: SyncSelfTeamUseCase, private val syncContacts: SyncContactsUseCase, - private val joinMLSConversations: JoinExistingMLSConversationsUseCase + private val joinMLSConversations: JoinExistingMLSConversationsUseCase, + private val oneOnOneResolver: OneOnOneResolver, ) : SlowSyncWorker { private val logger = kaliumLogger.withFeatureId(SYNC) @@ -78,11 +83,13 @@ internal class SlowSyncWorkerImpl( performStep(SlowSyncStep.SELF_USER, syncSelfUser::invoke) .continueWithStep(SlowSyncStep.FEATURE_FLAGS, syncFeatureConfigs::invoke) + .continueWithStep(SlowSyncStep.UPDATE_SUPPORTED_PROTOCOLS) { updateSupportedProtocols.invoke().map { } } .continueWithStep(SlowSyncStep.CONVERSATIONS, syncConversations::invoke) .continueWithStep(SlowSyncStep.CONNECTIONS, syncConnections::invoke) .continueWithStep(SlowSyncStep.SELF_TEAM, syncSelfTeam::invoke) .continueWithStep(SlowSyncStep.CONTACTS, syncContacts::invoke) .continueWithStep(SlowSyncStep.JOINING_MLS_CONVERSATIONS, joinMLSConversations::invoke) + .continueWithStep(SlowSyncStep.RESOLVE_ONE_ON_ONE_PROTOCOLS, oneOnOneResolver::resolveAllOneOnOneConversations) .flatMap { saveLastProcessedEventIdIfNeeded(lastProcessedEventIdToSaveOnSuccess) } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/call/CallRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/call/CallRepositoryTest.kt index bc82e028ae9..615d1279689 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/call/CallRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/call/CallRepositoryTest.kt @@ -83,6 +83,7 @@ import kotlinx.coroutines.test.advanceUntilIdle import kotlinx.coroutines.test.runTest import kotlinx.coroutines.yield import kotlinx.datetime.Clock +import kotlinx.datetime.Instant import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFalse @@ -1752,7 +1753,7 @@ class CallRepositoryTest { val mlsProtocolInfo = Conversation.ProtocolInfo.MLS( groupId, - Conversation.ProtocolInfo.MLS.GroupState.ESTABLISHED, + Conversation.ProtocolInfo.MLSCapable.GroupState.ESTABLISHED, 1UL, Clock.System.now(), Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/call/CallingParticipantsOrderTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/call/CallingParticipantsOrderTest.kt index b237ee8a797..5770464e60c 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/call/CallingParticipantsOrderTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/call/CallingParticipantsOrderTest.kt @@ -177,7 +177,8 @@ class CallingParticipantsOrderTest { previewPicture = null, completePicture = null, availabilityStatus = UserAvailabilityStatus.AVAILABLE, - expiresAt = null + expiresAt = null, + supportedProtocols = null ) const val selfClientId = "client1" diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/client/ClientRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/client/ClientRepositoryTest.kt index a1eb24cc5aa..5e7127bb0c2 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/client/ClientRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/client/ClientRepositoryTest.kt @@ -287,7 +287,8 @@ class ClientRepositoryTest { model = "Mac ox", isVerified = false, isValid = true, - mlsPublicKeys = null + mlsPublicKeys = null, + isMLSCapable = false ), Client( id = PlainId(value = "client_id_2"), @@ -299,7 +300,8 @@ class ClientRepositoryTest { model = "iphone 15", isVerified = false, isValid = true, - mlsPublicKeys = null + mlsPublicKeys = null, + isMLSCapable = false ), ) @@ -362,7 +364,8 @@ class ClientRepositoryTest { isProteusVerified = false, isValid = true, userId = userId, - mlsPublicKeys = null + mlsPublicKeys = null, + isMLSCapable = false ) ) @@ -377,7 +380,8 @@ class ClientRepositoryTest { model = null, isVerified = false, isValid = true, - mlsPublicKeys = null + mlsPublicKeys = null, + isMLSCapable = false ) ) diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/client/ClientTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/client/ClientTest.kt new file mode 100644 index 00000000000..eeb5e0b3c36 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/client/ClientTest.kt @@ -0,0 +1,53 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.data.client + +import com.wire.kalium.logic.framework.TestClient +import kotlinx.datetime.Clock +import kotlin.test.Test +import kotlin.test.assertFalse +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.days + +class ClientTest { + + @Test + fun givenLastActiveIsNull_thenIsActiveIsFalse() { + val client = TestClient.CLIENT.copy( + lastActive = null + ) + assertFalse(client.isActive) + } + + @Test + fun givenLastActiveIsOlderThanInactivityDuration_thenIsActiveIsFalse() { + val client = TestClient.CLIENT.copy( + lastActive = Clock.System.now() - (Client.INACTIVE_DURATION + 1.days) + ) + assertFalse(client.isActive) + } + + @Test + fun givenLastActiveIsNewerThanInactivityDuration_thenIsActiveIsTrue() { + val client = TestClient.CLIENT.copy( + lastActive = Clock.System.now() - (Client.INACTIVE_DURATION - 1.days) + ) + assertTrue(client.isActive) + } + +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/connection/ConnectionRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/connection/ConnectionRepositoryTest.kt index 247311a6c02..38436932738 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/connection/ConnectionRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/connection/ConnectionRepositoryTest.kt @@ -84,8 +84,8 @@ class ConnectionRepositoryTest { // then verify(arrangement.memberDAO) - .suspendFunction(arrangement.memberDAO::updateOrInsertOneOnOneMemberWithConnectionStatus) - .with(any(), any(), any()) + .suspendFunction(arrangement.memberDAO::updateOrInsertOneOnOneMember) + .with(any(), any()) .wasInvoked(exactly = twice) // Verifies that when fetching connections, it succeeded @@ -106,8 +106,8 @@ class ConnectionRepositoryTest { // then verify(arrangement.memberDAO) - .suspendFunction(arrangement.memberDAO::updateOrInsertOneOnOneMemberWithConnectionStatus) - .with(any(), any(), any()) + .suspendFunction(arrangement.memberDAO::updateOrInsertOneOnOneMember) + .with(any(), any()) .wasInvoked(exactly = twice) // Verifies that when fetching connections, it succeeded @@ -135,20 +135,6 @@ class ConnectionRepositoryTest { .suspendFunction(arrangement.connectionApi::createConnection) .with(eq(userId)) .wasInvoked(once) - - verify(arrangement.userDAO) - .suspendFunction(arrangement.userDAO::insertUser) - .with(any()) - .wasInvoked(once) - verify(arrangement.userDetailsApi) - .suspendFunction(arrangement.userDetailsApi::getUserInfo) - .with(any()) - .wasInvoked(once) - - verify(arrangement.conversationRepository) - .suspendFunction(arrangement.conversationRepository::fetchConversations) - .wasNotInvoked() - } @Test @@ -171,8 +157,8 @@ class ConnectionRepositoryTest { .with(eq(userId)) .wasInvoked(once) verify(arrangement.memberDAO) - .suspendFunction(arrangement.memberDAO::updateOrInsertOneOnOneMemberWithConnectionStatus) - .with(any(), any(), any()) + .suspendFunction(arrangement.memberDAO::updateOrInsertOneOnOneMember) + .with(any(), any()) .wasNotInvoked() verify(arrangement.conversationRepository) .suspendFunction(arrangement.conversationRepository::fetchConversations) @@ -205,13 +191,6 @@ class ConnectionRepositoryTest { .suspendFunction(arrangement.connectionDAO::insertConnection) .with(any()) .wasInvoked(once) - verify(arrangement.userDAO) - .suspendFunction(arrangement.userDAO::insertUser) - .with(any()) - .wasInvoked(once) - verify(arrangement.conversationRepository) - .suspendFunction(arrangement.conversationRepository::fetchConversations) - .wasNotInvoked() } @Test @@ -235,9 +214,9 @@ class ConnectionRepositoryTest { .with(eq(userId), eq(ConnectionStateDTO.ACCEPTED)) .wasInvoked(once) verify(arrangement.memberDAO) - .suspendFunction(arrangement.memberDAO::updateOrInsertOneOnOneMemberWithConnectionStatus) - .with(any(), any(), any()) - .wasInvoked(exactly = twice) + .suspendFunction(arrangement.memberDAO::updateOrInsertOneOnOneMember) + .with(any(), any()) + .wasInvoked(exactly = once) } @Test @@ -257,8 +236,8 @@ class ConnectionRepositoryTest { .with(eq(userId), eq(ConnectionStateDTO.ACCEPTED)) .wasNotInvoked() verify(arrangement.memberDAO) - .suspendFunction(arrangement.memberDAO::updateOrInsertOneOnOneMemberWithConnectionStatus) - .with(any(), any(), any()) + .suspendFunction(arrangement.memberDAO::updateOrInsertOneOnOneMember) + .with(any(), any()) .wasNotInvoked() } @@ -279,8 +258,8 @@ class ConnectionRepositoryTest { .with(eq(userId), eq(ConnectionStateDTO.ACCEPTED)) .wasInvoked(once) verify(arrangement.memberDAO) - .suspendFunction(arrangement.memberDAO::updateOrInsertOneOnOneMemberWithConnectionStatus) - .with(any(), any(), any()) + .suspendFunction(arrangement.memberDAO::updateOrInsertOneOnOneMember) + .with(any(), any()) .wasNotInvoked() } @@ -301,8 +280,8 @@ class ConnectionRepositoryTest { .with(eq(userId), eq(ConnectionStateDTO.PENDING)) .wasNotInvoked() verify(arrangement.memberDAO) - .suspendFunction(arrangement.memberDAO::updateOrInsertOneOnOneMemberWithConnectionStatus) - .with(any(), any(), any()) + .suspendFunction(arrangement.memberDAO::updateOrInsertOneOnOneMember) + .with(any(), any()) .wasNotInvoked() } @@ -351,10 +330,7 @@ class ConnectionRepositoryTest { conversationDAO = conversationDAO, connectionApi = connectionApi, connectionDAO = connectionDAO, - userDetailsApi = userDetailsApi, userDAO = userDAO, - selfUserId = TestUser.SELF.id, - selfTeamIdProvider = selfTeamIdProvider, memberDAO = memberDAO, conversationRepository = conversationRepository ) @@ -394,11 +370,11 @@ class ConnectionRepositoryTest { email = null, expiresAt = null, nonQualifiedId = "value", - service = null + service = null, + supportedProtocols = null ) val stubUserEntity = TestUser.DETAILS_ENTITY - val stubConversationID1 = QualifiedIDEntity("conversationId1", "domain") val stubConversationID2 = QualifiedIDEntity("conversationId2", "domain") @@ -437,11 +413,10 @@ class ConnectionRepositoryTest { } fun withNotFoundGetConversationError(): Arrangement = apply { - // TODO: user withUpdateOrInsertOneOnOneMemberWithConnectionStatusFailure directly in the test once it is fully refactored - withUpdateOrInsertOneOnOneMemberWithConnectionStatusFailure( + // TODO: use withUpdateOrInsertOneOnOneMemberFailure directly in the test once it is fully refactored + withUpdateOrInsertOneOnOneMemberFailure( error = Exception("error"), member = any(), - status = any(), conversationId = any() ) } @@ -456,11 +431,10 @@ class ConnectionRepositoryTest { } fun withErrorOnPersistingConnectionResponse(userId: NetworkUserId): Arrangement = apply { - // TODO: user withUpdateOrInsertOneOnOneMemberWithConnectionStatusFailure directly in the test once it is fully refactored - withUpdateOrInsertOneOnOneMemberWithConnectionStatusFailure( + // TODO: use withUpdateOrInsertOneOnOneMemberFailure directly in the test once it is fully refactored + withUpdateOrInsertOneOnOneMemberFailure( error = RuntimeException("An error occurred persisting the data"), member = eq(UserIDEntity(userId.value, userId.domain)), - status = any(), conversationId = any() ) } @@ -471,9 +445,8 @@ class ConnectionRepositoryTest { .whenInvokedWith(eq(userId), eq(ConnectionStateDTO.ACCEPTED)) .then { _, _ -> NetworkResponse.Success(stubConnectionOne, mapOf(), 200) } - withUpdateOrInsertOneOnOneMemberWithConnectionStatusSuccess( + withUpdateOrInsertOneOnOneMemberSuccess( member = eq(UserIDEntity(userId.value, userId.domain)), - status = any(), conversationId = any() ) } @@ -508,14 +481,14 @@ class ConnectionRepositoryTest { .whenInvokedWith(any()) .then { NetworkResponse.Success(stubUserProfileDTO, mapOf(), 200) } - withUpdateOrInsertOneOnOneMemberWithConnectionStatusSuccess() + withUpdateOrInsertOneOnOneMemberSuccess() given(userDAO).suspendFunction(userDAO::observeUserDetailsByQualifiedID) .whenInvokedWith(any()) .then { flowOf(stubUserEntity) } given(userDAO) - .suspendFunction(userDAO::insertUser) + .suspendFunction(userDAO::upsertUser) .whenInvokedWith(any()) .then { } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepositoryTest.kt index 22f5af21ef6..17c9a081617 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepositoryTest.kt @@ -284,7 +284,7 @@ class ConversationGroupRepositoryTest { } @Test - fun givenAConversationAndAPISucceedsWithChange_whenAddingMembersToConversation_thenShouldSucceed() = runTest { + fun givenProteusConversation_whenAddingMembersToConversation_thenShouldSucceed() = runTest { val (arrangement, conversationGroupRepository) = Arrangement() .withConversationDetailsById(TestConversation.CONVERSATION) .withProtocolInfoById(PROTEUS_PROTOCOL_INFO) @@ -296,6 +296,11 @@ class ConversationGroupRepositoryTest { conversationGroupRepository.addMembers(listOf(TestConversation.USER_1), TestConversation.ID) .shouldSucceed() + verify(arrangement.conversationApi) + .suspendFunction(arrangement.conversationApi::addMember) + .with(anything(), eq(TestConversation.ID.toApi())) + .wasInvoked(exactly = once) + verify(arrangement.memberJoinEventHandler) .suspendFunction(arrangement.memberJoinEventHandler::handle) .with(anything()) @@ -303,7 +308,7 @@ class ConversationGroupRepositoryTest { } @Test - fun givenAConversationAndAPISucceedsWithChange_whenAddingServiceToConversation_thenShouldSucceed() = runTest { + fun givenProteusConversation_whenAddingServiceToConversation_thenShouldSucceed() = runTest { val serviceID = ServiceId("service-id", "service-provider") val addServiceRequest = AddServiceRequest(id = serviceID.id, provider = serviceID.provider) @@ -330,7 +335,43 @@ class ConversationGroupRepositoryTest { } @Test - fun givenMLSConversation_whenAddingServiceToConversation_theReturnError() = runTest { + fun givenProteusConversationAndUserIsAlreadyAMember_whenAddingMembersToConversation_thenShouldSucceed() = runTest { + val (arrangement, conversationGroupRepository) = Arrangement() + .withConversationDetailsById(TestConversation.CONVERSATION) + .withProtocolInfoById(PROTEUS_PROTOCOL_INFO) + .withFetchUsersIfUnknownByIdsSuccessful() + .withAddMemberAPISucceedUnchanged() + .arrange() + + conversationGroupRepository.addMembers(listOf(TestConversation.USER_1), TestConversation.ID) + .shouldSucceed() + + verify(arrangement.memberJoinEventHandler) + .suspendFunction(arrangement.memberJoinEventHandler::handle) + .with(anything()) + .wasNotInvoked() + } + + @Test + fun givenProteusConversationAndAPICallFails_whenAddingMembersToConversation_thenShouldFail() = runTest { + val (arrangement, conversationGroupRepository) = Arrangement() + .withConversationDetailsById(TestConversation.CONVERSATION) + .withProtocolInfoById(PROTEUS_PROTOCOL_INFO) + .withAddMemberAPIFailed() + .withInsertFailedToAddSystemMessageSuccess() + .arrange() + + conversationGroupRepository.addMembers(listOf(TestConversation.USER_1), TestConversation.ID) + .shouldFail() + + verify(arrangement.memberJoinEventHandler) + .suspendFunction(arrangement.memberJoinEventHandler::handle) + .with(anything()) + .wasNotInvoked() + } + + @Test + fun givenMLSConversation_whenAddingServiceToConversation_thenReturnError() = runTest { val serviceID = ServiceId("service-id", "service-provider") val (arrangement, conversationGroupRepository) = Arrangement() @@ -424,7 +465,36 @@ class ConversationGroupRepositoryTest { } @Test - fun givenAConversationAndAPISucceedsWithChange_whenRemovingMemberFromConversation_thenShouldSucceed() = runTest { + fun givenMixedConversation_whenAddMemberFromConversation_thenShouldSucceed() = runTest { + val (arrangement, conversationGroupRepository) = Arrangement() + .withConversationDetailsById(TestConversation.MIXED_CONVERSATION) + .withProtocolInfoById(MIXED_PROTOCOL_INFO) + .withAddMemberAPISucceedChanged() + .withSuccessfulAddMemberToMLSGroup() + .withSuccessfulHandleMemberJoinEvent() + .arrange() + + conversationGroupRepository.addMembers(listOf(TestConversation.USER_1), TestConversation.ID) + .shouldSucceed() + + verify(arrangement.conversationApi) + .suspendFunction(arrangement.conversationApi::addMember) + .with(anything(), eq(TestConversation.ID.toApi())) + .wasInvoked(exactly = once) + + verify(arrangement.memberJoinEventHandler) + .suspendFunction(arrangement.memberJoinEventHandler::handle) + .with(anything()) + .wasInvoked(exactly = once) + + verify(arrangement.mlsConversationRepository) + .suspendFunction(arrangement.mlsConversationRepository::addMemberToMLSGroup) + .with(eq(GROUP_ID), eq(listOf(TestConversation.USER_1))) + .wasInvoked(exactly = once) + } + + @Test + fun givenProteusConversation_whenRemovingMemberFromConversation_thenShouldSucceed() = runTest { val (arrangement, conversationGroupRepository) = Arrangement() .withConversationDetailsById(TestConversation.CONVERSATION) .withProtocolInfoById(PROTEUS_PROTOCOL_INFO) @@ -435,6 +505,11 @@ class ConversationGroupRepositoryTest { conversationGroupRepository.deleteMember(TestConversation.USER_1, TestConversation.ID) .shouldSucceed() + verify(arrangement.conversationApi) + .suspendFunction(arrangement.conversationApi::removeMember) + .with(eq(TestConversation.USER_1.toApi()), eq(TestConversation.ID.toApi())) + .wasInvoked(exactly = once) + verify(arrangement.memberLeaveEventHandler) .suspendFunction(arrangement.memberLeaveEventHandler::handle) .with(anything()) @@ -442,7 +517,7 @@ class ConversationGroupRepositoryTest { } @Test - fun givenAConversationAndAPISucceedsWithoutChange_whenRemovingMemberFromConversation_thenShouldSucceed() = runTest { + fun givenProteusConversationAndUserIsNotAMember_whenRemovingMemberFromConversation_thenShouldSucceed() = runTest { val (arrangement, conversationGroupRepository) = Arrangement() .withConversationDetailsById(TestConversation.CONVERSATION) .withProtocolInfoById(PROTEUS_PROTOCOL_INFO) @@ -459,7 +534,7 @@ class ConversationGroupRepositoryTest { } @Test - fun givenAConversationAndAPIFailed_whenRemovingMemberFromConversation_thenShouldFail() = runTest { + fun givenProteusConversationAndAPICallFails_whenRemovingMemberFromConversation_thenShouldFail() = runTest { val (arrangement, conversationGroupRepository) = Arrangement() .withConversationDetailsById(TestConversation.CONVERSATION) .withProtocolInfoById(PROTEUS_PROTOCOL_INFO) @@ -476,7 +551,7 @@ class ConversationGroupRepositoryTest { } @Test - fun givenAnMLSConversationAndAPISucceeds_whenRemovingLeavingConversation_thenShouldSucceed() = runTest { + fun givenMLSConversation_whenRemovingLeavingConversation_thenShouldSucceed() = runTest { val (arrangement, conversationGroupRepository) = Arrangement() .withConversationDetailsById(TestConversation.MLS_CONVERSATION) .withProtocolInfoById(MLS_PROTOCOL_INFO) @@ -503,7 +578,7 @@ class ConversationGroupRepositoryTest { } @Test - fun givenAnMLSConversationAndAPISucceeds_whenRemoveMemberFromConversation_thenShouldSucceed() = runTest { + fun givenMLSConversation_whenRemoveMemberFromConversation_thenShouldSucceed() = runTest { val (arrangement, conversationGroupRepository) = Arrangement() .withConversationDetailsById(TestConversation.MLS_CONVERSATION) .withProtocolInfoById(MLS_PROTOCOL_INFO) @@ -524,6 +599,35 @@ class ConversationGroupRepositoryTest { .wasNotInvoked() } + @Test + fun givenMixedConversation_whenRemoveMemberFromConversation_thenShouldSucceed() = runTest { + val (arrangement, conversationGroupRepository) = Arrangement() + .withConversationDetailsById(TestConversation.MIXED_CONVERSATION) + .withProtocolInfoById(MIXED_PROTOCOL_INFO) + .withDeleteMemberAPISucceedChanged() + .withSuccessfulRemoveMemberFromMLSGroup() + .withSuccessfulHandleMemberLeaveEvent() + .arrange() + + conversationGroupRepository.deleteMember(TestConversation.USER_1, TestConversation.ID) + .shouldSucceed() + + verify(arrangement.conversationApi) + .suspendFunction(arrangement.conversationApi::removeMember) + .with(eq(TestConversation.USER_1.toApi()), eq(TestConversation.ID.toApi())) + .wasInvoked(exactly = once) + + verify(arrangement.memberLeaveEventHandler) + .suspendFunction(arrangement.memberLeaveEventHandler::handle) + .with(anything()) + .wasInvoked(exactly = once) + + verify(arrangement.mlsConversationRepository) + .suspendFunction(arrangement.mlsConversationRepository::removeMembersFromMLSGroup) + .with(eq(GROUP_ID), eq(listOf(TestConversation.USER_1))) + .wasInvoked(exactly = once) + } + @Test fun givenProteusConversation_whenJoiningConversationSuccessWithChanged_thenResponseIsHandled() = runTest { val code = "code" @@ -558,7 +662,40 @@ class ConversationGroupRepositoryTest { } @Test - fun givenMlsConversation_whenJoiningConversationSuccessWithChanged_thenAddSelfClientsToMlsGroup() = runTest { + fun givenProteusConversation_whenJoiningConversationSuccessWithUnchanged_thenMemberJoinEventHandlerIsNotInvoked() = runTest { + val code = "code" + val key = "key" + val uri = null + val password = null + + val (arrangement, conversationGroupRepository) = Arrangement() + .withConversationDetailsById(TestConversation.CONVERSATION) + .withConversationDetailsById(TestConversation.GROUP_VIEW_ENTITY(PROTEUS_PROTOCOL_INFO)) + .withJoinConversationAPIResponse( + code, + key, + uri, + NetworkResponse.Success(ConversationMemberAddedResponse.Unchanged, emptyMap(), 204) + ) + .withSuccessfulHandleMemberJoinEvent() + .arrange() + + conversationGroupRepository.joinViaInviteCode(code, key, uri, password) + .shouldSucceed() + + verify(arrangement.conversationApi) + .suspendFunction(arrangement.conversationApi::joinConversation) + .with(eq(code), eq(key), eq(uri)) + .wasInvoked(exactly = once) + + verify(arrangement.memberJoinEventHandler) + .suspendFunction(arrangement.memberJoinEventHandler::handle) + .with(any()) + .wasNotInvoked() + } + + @Test + fun givenMLSConversation_whenJoiningConversationSuccessWithChanged_thenAddSelfClientsToMlsGroup() = runTest { val code = "code" val key = "key" val uri = null @@ -603,7 +740,7 @@ class ConversationGroupRepositoryTest { } @Test - fun givenProteusConversation_whenJoiningConversationSuccessWithUnchanged_thenMemberJoinEventHandlerIsNotInvoked() = runTest { + fun givenMixedConversation_whenJoiningConversationSuccessWithChanged_thenAddSelfClientsToMlsGroup() = runTest { val code = "code" val key = "key" val uri = null @@ -611,14 +748,16 @@ class ConversationGroupRepositoryTest { val (arrangement, conversationGroupRepository) = Arrangement() .withConversationDetailsById(TestConversation.CONVERSATION) - .withConversationDetailsById(TestConversation.GROUP_VIEW_ENTITY(PROTEUS_PROTOCOL_INFO)) + .withProtocolInfoById(MIXED_PROTOCOL_INFO) .withJoinConversationAPIResponse( code, key, uri, - NetworkResponse.Success(ConversationMemberAddedResponse.Unchanged, emptyMap(), 204) + NetworkResponse.Success(ADD_MEMBER_TO_CONVERSATION_SUCCESSFUL_RESPONSE, emptyMap(), 200) ) .withSuccessfulHandleMemberJoinEvent() + .withJoinExistingMlsConversationSucceeds() + .withSuccessfulAddMemberToMLSGroup() .arrange() conversationGroupRepository.joinViaInviteCode(code, key, uri, password) @@ -632,7 +771,17 @@ class ConversationGroupRepositoryTest { verify(arrangement.memberJoinEventHandler) .suspendFunction(arrangement.memberJoinEventHandler::handle) .with(any()) - .wasNotInvoked() + .wasInvoked(exactly = once) + + verify(arrangement.joinExistingMLSConversation) + .suspendFunction(arrangement.joinExistingMLSConversation::invoke) + .with(eq(ADD_MEMBER_TO_CONVERSATION_SUCCESSFUL_RESPONSE.event.qualifiedConversation.toModel())) + .wasInvoked(exactly = once) + + verify(arrangement.mlsConversationRepository) + .suspendFunction(arrangement.mlsConversationRepository::addMemberToMLSGroup) + .with(eq(GroupID(MIXED_PROTOCOL_INFO.groupId)), eq(listOf(TestUser.SELF.id))) + .wasInvoked(exactly = once) } @Test @@ -1406,6 +1555,14 @@ class ConversationGroupRepositoryTest { Instant.parse("2021-03-30T15:36:00.000Z"), cipherSuite = ConversationEntity.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 ) + val MIXED_PROTOCOL_INFO = ConversationEntity.ProtocolInfo + .Mixed( + RAW_GROUP_ID, + groupState = ConversationEntity.GroupState.ESTABLISHED, + 0UL, + Instant.parse("2021-03-30T15:36:00.000Z"), + cipherSuite = ConversationEntity.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 + ) const val GROUP_NAME = "Group Name" diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationMapperTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationMapperTest.kt index 93e6d116758..72d0c5a8cba 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationMapperTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationMapperTest.kt @@ -23,6 +23,7 @@ import com.wire.kalium.logic.data.id.IdMapper import com.wire.kalium.logic.data.id.TeamId import com.wire.kalium.logic.data.user.AvailabilityStatusMapper import com.wire.kalium.logic.data.user.type.DomainUserTypeMapper +import com.wire.kalium.logic.framework.TestUser import com.wire.kalium.network.api.base.authenticated.conversation.ConvProtocol import com.wire.kalium.network.api.base.authenticated.conversation.ConversationMemberDTO import com.wire.kalium.network.api.base.authenticated.conversation.ConversationMembersResponse @@ -74,6 +75,7 @@ class ConversationMapperTest { @BeforeTest fun setup() { conversationMapper = ConversationMapperImpl( + TestUser.SELF.id, idMapper, conversationStatusMapper, protocolInfoMapper, diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt index bcd64a50dc1..bc77d380f8e 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt @@ -45,6 +45,7 @@ import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.sync.receiver.conversation.RenamedConversationEventHandler import com.wire.kalium.logic.util.arrangement.dao.MemberDAOArrangement import com.wire.kalium.logic.util.arrangement.dao.MemberDAOArrangementImpl +import com.wire.kalium.logic.util.shouldFail import com.wire.kalium.logic.util.shouldSucceed import com.wire.kalium.network.api.base.authenticated.client.ClientApi import com.wire.kalium.network.api.base.authenticated.conversation.ConvProtocol @@ -60,13 +61,16 @@ import com.wire.kalium.network.api.base.authenticated.conversation.ConversationR import com.wire.kalium.network.api.base.authenticated.conversation.ReceiptMode import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationAccessRequest import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationAccessResponse +import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationProtocolResponse import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationReceiptModeResponse import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationAccessInfoDTO import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationMemberRoleDTO +import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationProtocolDTO import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationReceiptModeDTO import com.wire.kalium.network.api.base.authenticated.notification.EventContentDTO import com.wire.kalium.network.api.base.model.ConversationAccessDTO import com.wire.kalium.network.api.base.model.ConversationAccessRoleDTO +import com.wire.kalium.network.exceptions.KaliumException import com.wire.kalium.network.utils.NetworkResponse import com.wire.kalium.persistence.dao.ConversationIDEntity import com.wire.kalium.persistence.dao.QualifiedIDEntity @@ -119,7 +123,15 @@ class ConversationRepositoryTest { @Test fun givenNewConversationEvent_whenCallingPersistConversation_thenConversationShouldBePersisted() = runTest { - val event = Event.Conversation.NewConversation("id", TestConversation.ID, false, TestUser.SELF.id, "time", CONVERSATION_RESPONSE) + val event = Event.Conversation.NewConversation( + "id", + TestConversation.ID, + false, + false, + TestUser.SELF.id, + "time", + CONVERSATION_RESPONSE + ) val selfUserFlow = flowOf(TestUser.SELF) val (arrangement, conversationRepository) = Arrangement() .withSelfUserFlow(selfUserFlow) @@ -140,104 +152,129 @@ class ConversationRepositoryTest { } @Test - fun givenNewConversationEvent_whenCallingPersistConversationFromEvent_thenConversationShouldBePersisted() = runTest { - val event = Event.Conversation.NewConversation("id", TestConversation.ID, false, TestUser.SELF.id, "time", CONVERSATION_RESPONSE) - val selfUserFlow = flowOf(TestUser.SELF) - val (arrangement, conversationRepository) = Arrangement() - .withSelfUserFlow(selfUserFlow) - .withExpectedConversationBase(null) - .arrange() + fun givenNewConversationEvent_whenCallingPersistConversationFromEvent_thenConversationShouldBePersisted() = + runTest { + val event = Event.Conversation.NewConversation( + "id", + TestConversation.ID, + false, + false, + TestUser.SELF.id, + "time", + CONVERSATION_RESPONSE + ) + val selfUserFlow = flowOf(TestUser.SELF) + val (arrangement, conversationRepository) = Arrangement() + .withSelfUserFlow(selfUserFlow) + .withExpectedConversationBase(null) + .arrange() - conversationRepository.persistConversation(event.conversation, "teamId") + conversationRepository.persistConversation(event.conversation, "teamId") - with(arrangement) { - verify(conversationDAO) - .suspendFunction(conversationDAO::insertConversation) - .with( - matching { conversation -> - conversation.id.value == CONVERSATION_RESPONSE.id.value - } - ) - .wasInvoked(exactly = once) + with(arrangement) { + verify(conversationDAO) + .suspendFunction(conversationDAO::insertConversation) + .with( + matching { conversation -> + conversation.id.value == CONVERSATION_RESPONSE.id.value + } + ) + .wasInvoked(exactly = once) + } } - } @Test - fun givenNewConversationEvent_whenCallingPersistConversationFromEventAndExists_thenConversationPersistenceShouldBeSkipped() = runTest { - val event = Event.Conversation.NewConversation("id", TestConversation.ID, false, TestUser.SELF.id, "time", CONVERSATION_RESPONSE) - val selfUserFlow = flowOf(TestUser.SELF) - val (arrangement, conversationRepository) = Arrangement() - .withSelfUserFlow(selfUserFlow) - .withExpectedConversationBase(TestConversation.ENTITY) - .arrange() + fun givenNewConversationEvent_whenCallingPersistConversationFromEventAndExists_thenConversationPersistenceShouldBeSkipped() = + runTest { + val event = Event.Conversation.NewConversation( + "id", + TestConversation.ID, + false, + false, + TestUser.SELF.id, + "time", + CONVERSATION_RESPONSE + ) + val selfUserFlow = flowOf(TestUser.SELF) + val (arrangement, conversationRepository) = Arrangement() + .withSelfUserFlow(selfUserFlow) + .withExpectedConversationBase(TestConversation.ENTITY) + .arrange() - conversationRepository.persistConversation(event.conversation, "teamId") + conversationRepository.persistConversation(event.conversation, "teamId") - with(arrangement) { - verify(conversationDAO) - .suspendFunction(conversationDAO::insertConversation) - .with( - matching { conversation -> - conversation.id.value == CONVERSATION_RESPONSE.id.value - } - ) - .wasNotInvoked() + with(arrangement) { + verify(conversationDAO) + .suspendFunction(conversationDAO::insertConversation) + .with( + matching { conversation -> + conversation.id.value == CONVERSATION_RESPONSE.id.value + } + ) + .wasNotInvoked() + } } - } @Test - fun givenNewConversationEventWithMlsConversation_whenCallingInsertConversation_thenMlsGroupExistenceShouldBeQueried() = runTest { - val event = Event.Conversation.NewConversation( - "id", - TestConversation.ID, - false, - TestUser.SELF.id, - "time", - CONVERSATION_RESPONSE.copy( - groupId = RAW_GROUP_ID, - protocol = MLS, - mlsCipherSuiteTag = ConversationEntity.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519.cipherSuiteTag + fun givenNewConversationEventWithMlsConversation_whenCallingInsertConversation_thenMlsGroupExistenceShouldBeQueried() = + runTest { + val event = Event.Conversation.NewConversation( + "id", + TestConversation.ID, + false, + false, + TestUser.SELF.id, + "time", + CONVERSATION_RESPONSE.copy( + groupId = RAW_GROUP_ID, + protocol = MLS, + mlsCipherSuiteTag = ConversationEntity.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519.cipherSuiteTag + ) + ) + val protocolInfo = ConversationEntity.ProtocolInfo.MLS( + RAW_GROUP_ID, + ConversationEntity.GroupState.ESTABLISHED, + 0UL, + Instant.parse("2021-03-30T15:36:00.000Z"), + cipherSuite = ConversationEntity.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 ) - ) - val protocolInfo = ConversationEntity.ProtocolInfo.MLS( - RAW_GROUP_ID, - ConversationEntity.GroupState.ESTABLISHED, - 0UL, - Instant.parse("2021-03-30T15:36:00.000Z"), - cipherSuite = ConversationEntity.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 - ) - val (arrangement, conversationRepository) = Arrangement() - .withSelfUserFlow(flowOf(TestUser.SELF)) - .withHasEstablishedMLSGroup(true) - .arrange() + val (arrangement, conversationRepository) = Arrangement() + .withSelfUserFlow(flowOf(TestUser.SELF)) + .withHasEstablishedMLSGroup(true) + .arrange() - conversationRepository.persistConversations(listOf(event.conversation), "teamId", originatedFromEvent = true) + conversationRepository.persistConversations( + listOf(event.conversation), + "teamId", + originatedFromEvent = true + ) - verify(arrangement.mlsClient) - .suspendFunction(arrangement.mlsClient::conversationExists) - .with(eq(RAW_GROUP_ID)) - .wasInvoked(once) + verify(arrangement.mlsClient) + .suspendFunction(arrangement.mlsClient::conversationExists) + .with(eq(RAW_GROUP_ID)) + .wasInvoked(once) - verify(arrangement.conversationDAO) - .suspendFunction(arrangement.conversationDAO::insertConversations) - .with( - matching { conversations -> - conversations.any { entity -> - entity.id.value == CONVERSATION_RESPONSE.id.value && entity.protocolInfo == protocolInfo.copy( - keyingMaterialLastUpdate = (entity.protocolInfo as ConversationEntity.ProtocolInfo.MLS).keyingMaterialLastUpdate - ) + verify(arrangement.conversationDAO) + .suspendFunction(arrangement.conversationDAO::insertConversations) + .with( + matching { conversations -> + conversations.any { entity -> + entity.id.value == CONVERSATION_RESPONSE.id.value && entity.protocolInfo == protocolInfo.copy( + keyingMaterialLastUpdate = (entity.protocolInfo as ConversationEntity.ProtocolInfo.MLS).keyingMaterialLastUpdate + ) + } } - } - ) - .wasInvoked(once) - } + ) + .wasInvoked(once) + } @Test fun givenTwoPagesOfConversation_whenFetchingConversationsAndItsDetails_thenThePagesShouldBeAddedAndPersistOnlyFounds() = runTest { // given - val response = ConversationPagingResponse(listOf(CONVERSATION_IDS_DTO_ONE, CONVERSATION_IDS_DTO_TWO), false, "") + val response = + ConversationPagingResponse(listOf(CONVERSATION_IDS_DTO_ONE, CONVERSATION_IDS_DTO_TWO), false, "") val (arrangement, conversationRepository) = Arrangement() .withFetchConversationsIds(NetworkResponse.Success(response, emptyMap(), HttpStatusCode.OK.value)) @@ -306,51 +343,54 @@ class ConversationRepositoryTest { } @Test - fun givenConversationDaoReturnsAGroupConversation_whenGettingConversationDetailsById_thenReturnAGroupConversationDetails() = runTest { - val conversationEntity = TestConversation.VIEW_ENTITY.copy(type = ConversationEntity.Type.GROUP) + fun givenConversationDaoReturnsAGroupConversation_whenGettingConversationDetailsById_thenReturnAGroupConversationDetails() = + runTest { + val conversationEntity = TestConversation.VIEW_ENTITY.copy(type = ConversationEntity.Type.GROUP) - val (_, conversationRepository) = Arrangement() - .withExpectedObservableConversation(conversationEntity) - .arrange() + val (_, conversationRepository) = Arrangement() + .withExpectedObservableConversation(conversationEntity) + .arrange() - conversationRepository.observeConversationDetailsById(TestConversation.ID).test { - assertIs>(awaitItem()) - awaitComplete() + conversationRepository.observeConversationDetailsById(TestConversation.ID).test { + assertIs>(awaitItem()) + awaitComplete() + } } - } @Test - fun givenConversationDaoReturnsASelfConversation_whenGettingConversationDetailsById_thenReturnASelfConversationDetails() = runTest { - val conversationEntity = TestConversation.VIEW_ENTITY.copy(type = ConversationEntity.Type.SELF) + fun givenConversationDaoReturnsASelfConversation_whenGettingConversationDetailsById_thenReturnASelfConversationDetails() = + runTest { + val conversationEntity = TestConversation.VIEW_ENTITY.copy(type = ConversationEntity.Type.SELF) - val (_, conversationRepository) = Arrangement() - .withExpectedObservableConversation(conversationEntity) - .arrange() + val (_, conversationRepository) = Arrangement() + .withExpectedObservableConversation(conversationEntity) + .arrange() - conversationRepository.observeConversationDetailsById(TestConversation.ID).test { - assertIs>(awaitItem()) - awaitComplete() + conversationRepository.observeConversationDetailsById(TestConversation.ID).test { + assertIs>(awaitItem()) + awaitComplete() + } } - } @Test - fun givenConversationDaoReturnsAOneOneConversation_whenGettingConversationDetailsById_thenReturnAOneOneConversationDetails() = runTest { - val conversationId = TestConversation.ENTITY_ID - val conversationEntity = TestConversation.VIEW_ENTITY.copy( - id = conversationId, - type = ConversationEntity.Type.ONE_ON_ONE, - otherUserId = QualifiedIDEntity("otherUser", "domain") - ) + fun givenConversationDaoReturnsAOneOneConversation_whenGettingConversationDetailsById_thenReturnAOneOneConversationDetails() = + runTest { + val conversationId = TestConversation.ENTITY_ID + val conversationEntity = TestConversation.VIEW_ENTITY.copy( + id = conversationId, + type = ConversationEntity.Type.ONE_ON_ONE, + otherUserId = QualifiedIDEntity("otherUser", "domain") + ) - val (_, conversationRepository) = Arrangement() - .withExpectedObservableConversation(conversationEntity) - .arrange() + val (_, conversationRepository) = Arrangement() + .withExpectedObservableConversation(conversationEntity) + .arrange() - conversationRepository.observeConversationDetailsById(TestConversation.ID).test { - assertIs>(awaitItem()) - awaitComplete() + conversationRepository.observeConversationDetailsById(TestConversation.ID).test { + assertIs>(awaitItem()) + awaitComplete() + } } - } @Test fun givenUserHasKnownContactAndConversation_WhenGettingConversationDetailsByExistingConversation_ReturnTheCorrectConversation() = @@ -509,121 +549,123 @@ class ConversationRepositoryTest { @Suppress("LongMethod") @Test - fun givenUpdateAccessRoleSuccess_whenUpdatingConversationAccessInfo_thenTheNewAccessSettingsAreUpdatedLocally() = runTest { + fun givenUpdateAccessRoleSuccess_whenUpdatingConversationAccessInfo_thenTheNewAccessSettingsAreUpdatedLocally() = + runTest { - val conversationIdDTO = ConversationIdDTO("conv_id", "conv_domain") - val newAccessInfoDTO = ConversationAccessInfoDTO( - accessRole = setOf( - ConversationAccessRoleDTO.TEAM_MEMBER, - ConversationAccessRoleDTO.NON_TEAM_MEMBER, - ConversationAccessRoleDTO.SERVICE, - ConversationAccessRoleDTO.GUEST, - ), - access = setOf( - ConversationAccessDTO.INVITE, - ConversationAccessDTO.CODE, - ConversationAccessDTO.PRIVATE, - ConversationAccessDTO.LINK + val conversationIdDTO = ConversationIdDTO("conv_id", "conv_domain") + val newAccessInfoDTO = ConversationAccessInfoDTO( + accessRole = setOf( + ConversationAccessRoleDTO.TEAM_MEMBER, + ConversationAccessRoleDTO.NON_TEAM_MEMBER, + ConversationAccessRoleDTO.SERVICE, + ConversationAccessRoleDTO.GUEST, + ), + access = setOf( + ConversationAccessDTO.INVITE, + ConversationAccessDTO.CODE, + ConversationAccessDTO.PRIVATE, + ConversationAccessDTO.LINK + ) ) - ) - val newAccess = UpdateConversationAccessResponse.AccessUpdated( - EventContentDTO.Conversation.AccessUpdate( - conversationIdDTO, - data = newAccessInfoDTO, - qualifiedFrom = com.wire.kalium.network.api.base.model.UserId("from_id", "from_domain") + val newAccess = UpdateConversationAccessResponse.AccessUpdated( + EventContentDTO.Conversation.AccessUpdate( + conversationIdDTO, + data = newAccessInfoDTO, + qualifiedFrom = com.wire.kalium.network.api.base.model.UserId("from_id", "from_domain") + ) ) - ) - val (arrange, conversationRepository) = Arrangement() - .withApiUpdateAccessRoleReturns(NetworkResponse.Success(newAccess, mapOf(), 200)) - .withDaoUpdateAccessSuccess() - .arrange() - - conversationRepository.updateAccessInfo( - conversationID = ConversationId(conversationIdDTO.value, conversationIdDTO.domain), - access = setOf( - Conversation.Access.INVITE, - Conversation.Access.CODE, - Conversation.Access.PRIVATE, - Conversation.Access.LINK - ), - accessRole = setOf( - Conversation.AccessRole.TEAM_MEMBER, - Conversation.AccessRole.NON_TEAM_MEMBER, - Conversation.AccessRole.SERVICE, - Conversation.AccessRole.GUEST - ) - ).shouldSucceed() + val (arrange, conversationRepository) = Arrangement() + .withApiUpdateAccessRoleReturns(NetworkResponse.Success(newAccess, mapOf(), 200)) + .withDaoUpdateAccessSuccess() + .arrange() - with(arrange) { - verify(conversationApi) - .coroutine { - conversationApi.updateAccess( - conversationIdDTO, - UpdateConversationAccessRequest( - newAccessInfoDTO.access, - newAccessInfoDTO.accessRole + conversationRepository.updateAccessInfo( + conversationID = ConversationId(conversationIdDTO.value, conversationIdDTO.domain), + access = setOf( + Conversation.Access.INVITE, + Conversation.Access.CODE, + Conversation.Access.PRIVATE, + Conversation.Access.LINK + ), + accessRole = setOf( + Conversation.AccessRole.TEAM_MEMBER, + Conversation.AccessRole.NON_TEAM_MEMBER, + Conversation.AccessRole.SERVICE, + Conversation.AccessRole.GUEST + ) + ).shouldSucceed() + + with(arrange) { + verify(conversationApi) + .coroutine { + conversationApi.updateAccess( + conversationIdDTO, + UpdateConversationAccessRequest( + newAccessInfoDTO.access, + newAccessInfoDTO.accessRole + ) ) - ) - } - .wasInvoked(exactly = once) - - verify(conversationDAO) - .coroutine { - conversationDAO.updateAccess( - ConversationIDEntity(conversationIdDTO.value, conversationIdDTO.domain), - accessList = listOf( - ConversationEntity.Access.INVITE, - ConversationEntity.Access.CODE, - ConversationEntity.Access.PRIVATE, - ConversationEntity.Access.LINK - ), - accessRoleList = listOf( - ConversationEntity.AccessRole.TEAM_MEMBER, - ConversationEntity.AccessRole.NON_TEAM_MEMBER, - ConversationEntity.AccessRole.SERVICE, - ConversationEntity.AccessRole.GUEST + } + .wasInvoked(exactly = once) + + verify(conversationDAO) + .coroutine { + conversationDAO.updateAccess( + ConversationIDEntity(conversationIdDTO.value, conversationIdDTO.domain), + accessList = listOf( + ConversationEntity.Access.INVITE, + ConversationEntity.Access.CODE, + ConversationEntity.Access.PRIVATE, + ConversationEntity.Access.LINK + ), + accessRoleList = listOf( + ConversationEntity.AccessRole.TEAM_MEMBER, + ConversationEntity.AccessRole.NON_TEAM_MEMBER, + ConversationEntity.AccessRole.SERVICE, + ConversationEntity.AccessRole.GUEST + ) ) - ) - } - .wasInvoked(exactly = once) + } + .wasInvoked(exactly = once) + } } - } @Test - fun givenUpdateConversationMemberRoleSuccess_whenUpdatingConversationMemberRole_thenTheNewRoleIsUpdatedLocally() = runTest { - val (arrange, conversationRepository) = Arrangement() - .withApiUpdateConversationMemberRoleReturns(NetworkResponse.Success(Unit, mapOf(), 200)) - .withDaoUpdateConversationMemberRoleSuccess() - .arrange() - val conversationId = ConversationId("conv_id", "conv_domain") - val userId: UserId = UserId("user_id", "user_domain") - val newRole = Conversation.Member.Role.Admin - - conversationRepository.updateConversationMemberRole(conversationId, userId, newRole).shouldSucceed() - - with(arrange) { - verify(conversationApi) - .coroutine { - conversationApi.updateConversationMemberRole( - conversationId.toApi(), - userId.toApi(), - ConversationMemberRoleDTO(MapperProvider.conversationRoleMapper().toApi(newRole)) - ) - } - .wasInvoked(exactly = once) - - verify(memberDAO) - .coroutine { - memberDAO.updateConversationMemberRole( - conversationId.toDao(), - userId.toDao(), - MapperProvider.conversationRoleMapper().toDAO(newRole) - ) - } - .wasInvoked(exactly = once) + fun givenUpdateConversationMemberRoleSuccess_whenUpdatingConversationMemberRole_thenTheNewRoleIsUpdatedLocally() = + runTest { + val (arrange, conversationRepository) = Arrangement() + .withApiUpdateConversationMemberRoleReturns(NetworkResponse.Success(Unit, mapOf(), 200)) + .withDaoUpdateConversationMemberRoleSuccess() + .arrange() + val conversationId = ConversationId("conv_id", "conv_domain") + val userId: UserId = UserId("user_id", "user_domain") + val newRole = Conversation.Member.Role.Admin + + conversationRepository.updateConversationMemberRole(conversationId, userId, newRole).shouldSucceed() + + with(arrange) { + verify(conversationApi) + .coroutine { + conversationApi.updateConversationMemberRole( + conversationId.toApi(), + userId.toApi(), + ConversationMemberRoleDTO(MapperProvider.conversationRoleMapper().toApi(newRole)) + ) + } + .wasInvoked(exactly = once) + + verify(memberDAO) + .coroutine { + memberDAO.updateConversationMemberRole( + conversationId.toDao(), + userId.toDao(), + MapperProvider.conversationRoleMapper().toDAO(newRole) + ) + } + .wasInvoked(exactly = once) + } } - } @Test fun givenProteusConversation_WhenDeletingTheConversation_ThenShouldBeDeletedLocally() = runTest { @@ -673,16 +715,16 @@ class ConversationRepositoryTest { val shouldFetchFromArchivedConversations = false val messagePreviewEntity = MESSAGE_PREVIEW_ENTITY.copy(conversationId = conversationIdEntity) - val conversationEntity = TestConversation.VIEW_ENTITY.copy( - id = conversationIdEntity, - type = ConversationEntity.Type.GROUP, - ) + val conversationEntity = TestConversation.VIEW_ENTITY.copy( + id = conversationIdEntity, + type = ConversationEntity.Type.GROUP, + ) - val unreadMessagesCount = 5 - val conversationUnreadEventEntity = ConversationUnreadEventEntity( - conversationIdEntity, - mapOf(UnreadEventTypeEntity.MESSAGE to unreadMessagesCount) - ) + val unreadMessagesCount = 5 + val conversationUnreadEventEntity = ConversationUnreadEventEntity( + conversationIdEntity, + mapOf(UnreadEventTypeEntity.MESSAGE to unreadMessagesCount) + ) val (_, conversationRepository) = Arrangement() .withConversations(listOf(conversationEntity)) @@ -694,8 +736,8 @@ class ConversationRepositoryTest { conversationRepository.observeConversationListDetails(shouldFetchFromArchivedConversations).test { val result = awaitItem() - assertContains(result.map { it.conversation.id }, conversationId) - val conversation = result.first { it.conversation.id == conversationId } + assertContains(result.map { it.conversation.id }, conversationId) + val conversation = result.first { it.conversation.id == conversationId } assertIs(conversation) assertEquals(conversation.unreadEventCount[UnreadEventType.MESSAGE], unreadMessagesCount) @@ -704,9 +746,9 @@ class ConversationRepositoryTest { conversation.lastMessage ) - awaitComplete() + awaitComplete() + } } - } @Test fun givenArchivedConversationHasNewMessages_whenGettingConversationDetails_ThenCorrectlyGetUnreadMessageCountAndNullLastMessage() = @@ -754,46 +796,46 @@ class ConversationRepositoryTest { val conversationEntity = TestConversation.VIEW_ENTITY.copy( type = ConversationEntity.Type.GROUP, ) + val (_, conversationRepository) = Arrangement() + .withExpectedObservableConversation(conversationEntity) + .arrange() - val (_, conversationRepository) = Arrangement() - .withExpectedObservableConversation(conversationEntity) - .arrange() - - // when - conversationRepository.observeConversationDetailsById(TestConversation.ID).test { - // then - val conversationDetail = awaitItem() + // when + conversationRepository.observeConversationDetailsById(TestConversation.ID).test { + // then + val conversationDetail = awaitItem() - assertIs>(conversationDetail) - assertTrue { conversationDetail.value.lastMessage == null } + assertIs>(conversationDetail) + assertTrue { conversationDetail.value.lastMessage == null } - awaitComplete() + awaitComplete() + } } - } @Test - fun givenAOneToOneConversationHasNotNewMessages_whenGettingConversationDetails_ThenReturnZeroUnreadMessageCount() = runTest { - // given - val conversationEntity = TestConversation.VIEW_ENTITY.copy( - type = ConversationEntity.Type.ONE_ON_ONE, - otherUserId = QualifiedIDEntity("otherUser", "domain") - ) + fun givenAOneToOneConversationHasNotNewMessages_whenGettingConversationDetails_ThenReturnZeroUnreadMessageCount() = + runTest { + // given + val conversationEntity = TestConversation.VIEW_ENTITY.copy( + type = ConversationEntity.Type.ONE_ON_ONE, + otherUserId = QualifiedIDEntity("otherUser", "domain") + ) - val (_, conversationRepository) = Arrangement() - .withExpectedObservableConversation(conversationEntity) - .arrange() + val (_, conversationRepository) = Arrangement() + .withExpectedObservableConversation(conversationEntity) + .arrange() - // when - conversationRepository.observeConversationDetailsById(TestConversation.ID).test { - // then - val conversationDetail = awaitItem() + // when + conversationRepository.observeConversationDetailsById(TestConversation.ID).test { + // then + val conversationDetail = awaitItem() - assertIs>(conversationDetail) - assertTrue { conversationDetail.value.lastMessage == null } + assertIs>(conversationDetail) + assertTrue { conversationDetail.value.lastMessage == null } - awaitComplete() + awaitComplete() + } } - } @Test fun givenAGroupConversationHasNewMessages_whenObservingConversationListDetails_ThenCorrectlyGetUnreadMessageCount() = runTest { @@ -802,36 +844,36 @@ class ConversationRepositoryTest { val conversationId = QualifiedID("some_value", "some_domain") val shouldFetchFromArchivedConversations = false - val conversationEntity = TestConversation.VIEW_ENTITY.copy( - id = conversationIdEntity, type = ConversationEntity.Type.ONE_ON_ONE, - otherUserId = QualifiedIDEntity("otherUser", "domain") - ) - - val unreadMessagesCount = 5 - val conversationUnreadEventEntity = ConversationUnreadEventEntity( - conversationIdEntity, - mapOf(UnreadEventTypeEntity.MESSAGE to unreadMessagesCount) - ) + val conversationEntity = TestConversation.VIEW_ENTITY.copy( + id = conversationIdEntity, type = ConversationEntity.Type.ONE_ON_ONE, + otherUserId = QualifiedIDEntity("otherUser", "domain") + ) - val (_, conversationRepository) = Arrangement() - .withConversations(listOf(conversationEntity)) - .withLastMessages(listOf()) - .withConversationUnreadEvents(listOf(conversationUnreadEventEntity)) - .arrange() + val unreadMessagesCount = 5 + val conversationUnreadEventEntity = ConversationUnreadEventEntity( + conversationIdEntity, + mapOf(UnreadEventTypeEntity.MESSAGE to unreadMessagesCount) + ) + val (_, conversationRepository) = Arrangement() + .withConversations(listOf(conversationEntity)) + .withLastMessages(listOf()) + .withConversationUnreadEvents(listOf(conversationUnreadEventEntity)) + .arrange() + // when conversationRepository.observeConversationListDetails(shouldFetchFromArchivedConversations).test { val result = awaitItem() - assertContains(result.map { it.conversation.id }, conversationId) - val conversation = result.first { it.conversation.id == conversationId } + assertContains(result.map { it.conversation.id }, conversationId) + val conversation = result.first { it.conversation.id == conversationId } - assertIs(conversation) - assertEquals(conversation.unreadEventCount[UnreadEventType.MESSAGE], unreadMessagesCount) + assertIs(conversation) + assertEquals(conversation.unreadEventCount[UnreadEventType.MESSAGE], unreadMessagesCount) - awaitComplete() + awaitComplete() + } } - } @Test fun givenAConversationDaoFailed_whenUpdatingTheConversationReadDate_thenShouldNotSucceed() = runTest { @@ -910,7 +952,8 @@ class ConversationRepositoryTest { val whoDeletedMe = UserId("deletion-author", "deletion-author-domain") val conversationId = ConversationId("conv_id", "conv_domain") val selfUserFlow = flowOf(TestUser.SELF) - val (arrange, conversationRepository) = Arrangement().withSelfUserFlow(selfUserFlow).withWhoDeletedMe(whoDeletedMe).arrange() + val (arrange, conversationRepository) = Arrangement().withSelfUserFlow(selfUserFlow) + .withWhoDeletedMe(whoDeletedMe).arrange() val result = conversationRepository.whoDeletedMe(conversationId) @@ -935,7 +978,8 @@ class ConversationRepositoryTest { @Test fun givenAConversationId_WhenTheConversationExists_ShouldReturnAConversationInstance() = runTest { val conversationId = ConversationId("conv_id", "conv_domain") - val (_, conversationRepository) = Arrangement().withExpectedObservableConversation(TestConversation.VIEW_ENTITY).arrange() + val (_, conversationRepository) = Arrangement().withExpectedObservableConversation(TestConversation.VIEW_ENTITY) + .arrange() val result = conversationRepository.getConversationById(conversationId) assertNotNull(result) @@ -944,13 +988,15 @@ class ConversationRepositoryTest { @Test fun givenAnUserId_WhenGettingConversationIds_ShouldReturnSuccess() = runTest { val userId = UserId("user_id", "user_domain") - val (arrange, conversationRepository) = Arrangement().withConversationIdsByUserId(listOf(TestConversation.ID)).arrange() + val (arrange, conversationRepository) = Arrangement() + .withConversationsByUserId(listOf(TestConversation.ENTITY)) + .arrange() - val result = conversationRepository.getConversationIdsByUserId(userId) + val result = conversationRepository.getConversationsByUserId(userId) with(result) { shouldSucceed() verify(arrange.conversationDAO) - .suspendFunction(arrange.conversationDAO::getConversationIdsByUserId) + .suspendFunction(arrange.conversationDAO::getConversationsByUserId) .with(any()) .wasInvoked(exactly = once) } @@ -1023,7 +1069,8 @@ class ConversationRepositoryTest { null, null, null, - null + null, + false ) ) @@ -1142,8 +1189,139 @@ class ConversationRepositoryTest { assertEquals(unreadCount, result) } + @Test + fun givenNoChange_whenUpdatingProtocolToMls_thenShouldNotUpdateLocally() = runTest { + // given + val protocol = Conversation.Protocol.MLS + + val (arrange, conversationRepository) = Arrangement() + .withUpdateProtocolResponse(UPDATE_PROTOCOL_UNCHANGED) + .arrange() + + // when + val result = conversationRepository.updateProtocolRemotely(CONVERSATION_ID, protocol) + + // then + with(result) { + shouldSucceed() + verify(arrange.conversationDAO) + .suspendFunction(arrange.conversationDAO::updateConversationProtocol) + .with(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao())) + .wasNotInvoked() + } + } + + @Test + fun givenChange_whenUpdatingProtocol_thenShouldFetchConversationDetails() = runTest { + // given + val protocol = Conversation.Protocol.MIXED + val conversationResponse = NetworkResponse.Success( + TestConversation.CONVERSATION_RESPONSE, + emptyMap(), + HttpStatusCode.OK.value + ) + + val (arrangement, conversationRepository) = Arrangement() + .withUpdateProtocolResponse(UPDATE_PROTOCOL_SUCCESS) + .withFetchConversationsDetails(conversationResponse) + .withDaoUpdateProtocolSuccess() + .arrange() + + // when + val result = conversationRepository.updateProtocolRemotely(CONVERSATION_ID, protocol) + + // then + with(result) { + shouldSucceed() + verify(arrangement.conversationApi) + .suspendFunction(arrangement.conversationApi::fetchConversationDetails) + .with(eq(CONVERSATION_ID.toApi())) + .wasInvoked(exactly = once) + } + } + + @Test + fun givenChange_whenUpdatingProtocol_thenShouldUpdateLocally() = runTest { + // given + val protocol = Conversation.Protocol.MLS + val conversationResponse = NetworkResponse.Success( + TestConversation.CONVERSATION_RESPONSE, + emptyMap(), + HttpStatusCode.OK.value + ) + + val (arrange, conversationRepository) = Arrangement() + .withUpdateProtocolResponse(UPDATE_PROTOCOL_SUCCESS) + .withFetchConversationsDetails(conversationResponse) + .withDaoUpdateProtocolSuccess() + .arrange() + + // when + val result = conversationRepository.updateProtocolRemotely(CONVERSATION_ID, protocol) + + // then + with(result) { + shouldSucceed() + verify(arrange.conversationDAO) + .suspendFunction(arrange.conversationDAO::updateConversationProtocol) + .with(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao())) + .wasInvoked(exactly = once) + } + } + + @Test + fun givenSuccessFetchingConversationDetails_whenUpdatingProtocolLocally_thenShouldUpdateLocally() = runTest { + // given + val protocol = Conversation.Protocol.MLS + val conversationResponse = NetworkResponse.Success( + TestConversation.CONVERSATION_RESPONSE, + emptyMap(), + HttpStatusCode.OK.value + ) + + val (arrange, conversationRepository) = Arrangement() + .withFetchConversationsDetails(conversationResponse) + .withDaoUpdateProtocolSuccess() + .arrange() + + // when + val result = conversationRepository.updateProtocolLocally(CONVERSATION_ID, protocol) + + // then + with(result) { + shouldSucceed() + verify(arrange.conversationDAO) + .suspendFunction(arrange.conversationDAO::updateConversationProtocol) + .with(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao())) + .wasInvoked(exactly = once) + } + } + + @Test + fun givenFailureFetchingConversationDetails_whenUpdatingProtocolLocally_thenShouldNotUpdateLocally() = runTest { + // given + val protocol = Conversation.Protocol.MLS + val (arrange, conversationRepository) = Arrangement() + .withFetchConversationsDetails(NetworkResponse.Error(KaliumException.NoNetwork())) + .withDaoUpdateProtocolSuccess() + .arrange() + + // when + val result = conversationRepository.updateProtocolLocally(CONVERSATION_ID, protocol) + + // then + with(result) { + shouldFail() + verify(arrange.conversationDAO) + .suspendFunction(arrange.conversationDAO::updateConversationProtocol) + .with(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao())) + .wasNotInvoked() + } + } + private class Arrangement : MemberDAOArrangement by MemberDAOArrangementImpl() { + @Mock val userRepository: UserRepository = mock(UserRepository::class) @@ -1175,7 +1353,8 @@ class ConversationRepositoryTest { val conversationMetaDataDAO: ConversationMetaDataDAO = mock(ConversationMetaDataDAO::class) @Mock - val renamedConversationEventHandler = configure(mock(RenamedConversationEventHandler::class)) { stubsUnitByDefault = true } + val renamedConversationEventHandler = + configure(mock(RenamedConversationEventHandler::class)) { stubsUnitByDefault = true } val conversationRepository = ConversationDataSource( @@ -1243,6 +1422,13 @@ class ConversationRepositoryTest { .thenReturn(selfUser) } + fun withFetchConversationsDetails(response: NetworkResponse) = apply { + given(conversationApi) + .suspendFunction(conversationApi::fetchConversationDetails) + .whenInvokedWith(any()) + .thenReturn(response) + } + fun withFetchConversationsIds(response: NetworkResponse) = apply { given(conversationApi) .suspendFunction(conversationApi::fetchConversationsIds) @@ -1262,7 +1448,7 @@ class ConversationRepositoryTest { fun withExpectedConversationWithOtherUser(conversation: ConversationViewEntity?) = apply { given(conversationDAO) - .suspendFunction(conversationDAO::observeConversationWithOtherUser) + .suspendFunction(conversationDAO::observeOneOnOneConversationWithOtherUser) .whenInvokedWith(anything()) .then { flowOf(conversation) } } @@ -1330,6 +1516,13 @@ class ConversationRepositoryTest { .thenReturn(Unit) } + fun withDaoUpdateProtocolSuccess() = apply { + given(conversationDAO) + .suspendFunction(conversationDAO::updateConversationProtocol) + .whenInvokedWith(any(), any()) + .thenReturn(true) + } + fun withGetConversationProtocolInfoReturns(result: ConversationEntity.ProtocolInfo) = apply { given(conversationDAO) .suspendFunction(conversationDAO::getConversationProtocolInfo) @@ -1405,13 +1598,11 @@ class ConversationRepositoryTest { .thenReturn(author) } - fun withConversationIdsByUserId(conversationIds: List) = apply { - val conversationIdEntities = conversationIds.map { it.toDao() } - + fun withConversationsByUserId(conversations: List) = apply { given(conversationDAO) - .suspendFunction(conversationDAO::getConversationIdsByUserId) + .suspendFunction(conversationDAO::getConversationsByUserId) .whenInvokedWith(any()) - .thenReturn(conversationIdEntities) + .thenReturn(conversations) } fun withConversationRenameCall(newName: String = "newName") = apply { @@ -1485,6 +1676,13 @@ class ConversationRepositoryTest { .thenReturn(result) } + fun withUpdateProtocolResponse(response: NetworkResponse) = apply { + given(conversationApi) + .suspendFunction(conversationApi::updateProtocol) + .whenInvokedWith(any(), any()) + .thenReturn(response) + } + fun arrange() = this to conversationRepository } @@ -1567,5 +1765,19 @@ class ConversationRepositoryTest { ) ) + val UPDATE_PROTOCOL_SUCCESS = NetworkResponse.Success( + UpdateConversationProtocolResponse.ProtocolUpdated( + EventContentDTO.Conversation.ProtocolUpdate( + TestConversation.NETWORK_ID, + ConversationProtocolDTO(ConvProtocol.MIXED), + TestUser.NETWORK_ID + ) + ), emptyMap(), 200 + ) + val UPDATE_PROTOCOL_UNCHANGED = NetworkResponse.Success( + UpdateConversationProtocolResponse.ProtocolUnchanged, + emptyMap(), 204 + ) + } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt index ff5e2f974e4..936ca5f44f1 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt @@ -165,6 +165,36 @@ class MLSConversationRepositoryTest { .wasInvoked(twice) } + @Test + fun givenMlsStaleMessageError_whenCallingEstablishMLSGroup_thenAbortCommitAndWipeData() = runTest { + val (arrangement, mlsConversationRepository) = Arrangement() + .withCommitPendingProposalsReturningNothing() + .withClaimKeyPackagesSuccessful() + .withGetMLSClientSuccessful() + .withGetPublicKeysSuccessful() + .withAddMLSMemberSuccessful() + .withSendCommitBundleFailing(Arrangement.MLS_STALE_MESSAGE_ERROR) + .arrange() + + val result = mlsConversationRepository.establishMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_1)) + result.shouldFail() + + verify(arrangement.mlsMessageApi) + .suspendFunction(arrangement.mlsMessageApi::sendCommitBundle) + .with(anyInstanceOf(MLSMessageApi.CommitBundle::class)) + .wasInvoked(once) + + verify(arrangement.mlsClient) + .function(arrangement.mlsClient::clearPendingCommit) + .with(eq(Arrangement.RAW_GROUP_ID)) + .wasInvoked(once) + + verify(arrangement.mlsClient) + .function(arrangement.mlsClient::wipeConversation) + .with(eq(Arrangement.RAW_GROUP_ID)) + .wasInvoked(once) + } + @Test fun givenSuccessfulResponses_whenCallingEstablishMLSGroup_thenKeyPackagesAreClaimedForMembers() = runTest { val (arrangement, mlsConversationRepository) = Arrangement() @@ -389,6 +419,27 @@ class MLSConversationRepositoryTest { .wasInvoked(once) } + @Test + fun givenRetryLimitIsReached_whenCallingAddMemberToMLSGroup_thenClearCommitAndFail() = runTest { + val (arrangement, mlsConversationRepository) = Arrangement() + .withClaimKeyPackagesSuccessful() + .withGetMLSClientSuccessful() + .withAddMLSMemberSuccessful() + .withSendCommitBundleFailing(Arrangement.MLS_STALE_MESSAGE_ERROR, times = Int.MAX_VALUE) + .withCommitPendingProposalsSuccessful() + .withClearProposalTimerSuccessful() + .withWaitUntilLiveSuccessful() + .arrange() + + val result = mlsConversationRepository.addMemberToMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_ID1)) + result.shouldFail() + + verify(arrangement.mlsClient) + .function(arrangement.mlsClient::clearPendingCommit) + .with(eq(Arrangement.RAW_GROUP_ID)) + .wasInvoked(once) + } + @Test fun givenSuccessfulResponses_whenCallingRequestToJoinGroup_ThenGroupStateIsUpdated() = runTest { val (arrangement, mlsConversationRepository) = Arrangement() @@ -536,6 +587,24 @@ class MLSConversationRepositoryTest { .wasInvoked(once) } + @Test + fun givenRetryLimitIsReached_whenCallingCommitPendingProposals_thenClearCommitAndFail() = runTest { + val (arrangement, mlsConversationRepository) = Arrangement() + .withGetMLSClientSuccessful() + .withCommitPendingProposalsSuccessful() + .withSendCommitBundleFailing(Arrangement.MLS_STALE_MESSAGE_ERROR, times = Int.MAX_VALUE) + .withWaitUntilLiveSuccessful() + .arrange() + + val result = mlsConversationRepository.commitPendingProposals(Arrangement.GROUP_ID) + result.shouldFail() + + verify(arrangement.mlsClient) + .function(arrangement.mlsClient::clearPendingCommit) + .with(eq(Arrangement.RAW_GROUP_ID)) + .wasInvoked(once) + } + @Test fun givenSuccessfulResponses_whenCallingRemoveMemberFromGroup_thenCommitBundleIsSentAndAccepted() = runTest { val (arrangement, mlsConversationRepository) = Arrangement() @@ -621,6 +690,27 @@ class MLSConversationRepositoryTest { .wasInvoked(once) } + @Test + fun givenRetryLimitIsReached_whenCallingRemoveMemberFromGroup_thenClearCommitAndFail() = runTest { + val (arrangement, mlsConversationRepository) = Arrangement() + .withCommitPendingProposalsSuccessful() + .withGetMLSClientSuccessful() + .withFetchClientsOfUsersSuccessful() + .withRemoveMemberSuccessful() + .withSendCommitBundleFailing(Arrangement.MLS_STALE_MESSAGE_ERROR, times = Int.MAX_VALUE) + .withWaitUntilLiveSuccessful() + .arrange() + + val users = listOf(TestUser.USER_ID) + val result = mlsConversationRepository.removeMembersFromMLSGroup(Arrangement.GROUP_ID, users) + result.shouldFail() + + verify(arrangement.mlsClient) + .function(arrangement.mlsClient::clearPendingCommit) + .with(eq(Arrangement.RAW_GROUP_ID)) + .wasInvoked(once) + } + @Test fun givenClientMismatchError_whenCallingRemoveMemberFromGroup_thenClearCommitAndRetry() = runTest { val (arrangement, mlsConversationRepository) = Arrangement() @@ -858,6 +948,25 @@ class MLSConversationRepositoryTest { .wasInvoked(once) } + @Test + fun givenRetryLimitIsReached_whenCallingUpdateKeyMaterial_clearCommitAndFail() = runTest { + val (arrangement, mlsConversationRepository) = Arrangement() + .withGetMLSClientSuccessful() + .withUpdateKeyingMaterialSuccessful() + .withCommitPendingProposalsSuccessful() + .withSendCommitBundleFailing(Arrangement.MLS_STALE_MESSAGE_ERROR, times = Int.MAX_VALUE) + .withWaitUntilLiveSuccessful() + .arrange() + + val result = mlsConversationRepository.updateKeyingMaterial(Arrangement.GROUP_ID) + result.shouldFail() + + verify(arrangement.mlsClient) + .function(arrangement.mlsClient::clearPendingCommit) + .with(eq(Arrangement.RAW_GROUP_ID)) + .wasInvoked(once) + } + @Test fun givenConversationWithOutdatedEpoch_whenCallingIsGroupOutOfSync_returnsTrue() = runTest { val returnEpoch = 10UL @@ -1189,6 +1298,7 @@ class MLSConversationRepositoryTest { } fun arrange() = this to MLSConversationDataSource( + TestUser.SELF.id, keyPackageRepository, mlsClientProvider, mlsMessageApi, @@ -1255,6 +1365,7 @@ class MLSConversationRepositoryTest { "eventId", TestConversation.ID, false, + false, TestUser.USER_ID, WELCOME.encodeBase64(), timestampIso = "2022-03-30T15:36:00.000Z" diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ProtocolInfoMapperTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ProtocolInfoMapperTest.kt index 2ec5ac169eb..844ed2d61fd 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ProtocolInfoMapperTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ProtocolInfoMapperTest.kt @@ -29,6 +29,13 @@ import kotlin.test.assertIs class ProtocolInfoMapperTest { private val protocolInfoMapper = ProtocolInfoMapperImpl() + @Test + fun givenConversationMixedProtocolInfo_WhenMapToConversationProtocolInfo_ResultShouldBeEqual() = runTest { + val mappedValue = protocolInfoMapper.toEntity(CONVERSATION_MIXED_PROTOCOL_INFO) + assertIs(mappedValue) + assertEquals(mappedValue, CONV_ENTITY_MIXED_PROTOCOL_INFO) + } + @Test fun givenConversationMLSProtocolInfo_WhenMapToConversationProtocolInfo_ResultShouldBeEqual() = runTest { val mappedValue = protocolInfoMapper.toEntity(CONVERSATION_MLS_PROTOCOL_INFO) @@ -43,6 +50,13 @@ class ProtocolInfoMapperTest { assertEquals(mappedValue, CONV_ENTITY_PROTEUS_PROTOCOL_INFO) } + @Test + fun givenEntityMixedProtocolInfo_WhenMapToConversationProtocolInfo_ResultShouldBeEqual() = runTest { + val mappedValue = protocolInfoMapper.fromEntity(CONV_ENTITY_MIXED_PROTOCOL_INFO) + assertIs(mappedValue) + assertEquals(mappedValue, CONVERSATION_MIXED_PROTOCOL_INFO) + } + @Test fun givenEntityMLSProtocolInfo_WhenMapToConversationProtocolInfo_ResultShouldBeEqual() = runTest { val mappedValue = protocolInfoMapper.fromEntity(CONV_ENTITY_MLS_PROTOCOL_INFO) @@ -58,15 +72,30 @@ class ProtocolInfoMapperTest { } companion object { + val CONVERSATION_MIXED_PROTOCOL_INFO = Conversation.ProtocolInfo.Mixed( + GroupID("GROUP_ID"), + Conversation.ProtocolInfo.MLSCapable.GroupState.ESTABLISHED, + 5UL, + Instant.parse("2021-03-30T15:36:00.000Z"), + cipherSuite = Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 + ) val CONVERSATION_MLS_PROTOCOL_INFO = Conversation.ProtocolInfo.MLS( GroupID("GROUP_ID"), - Conversation.ProtocolInfo.MLS.GroupState.ESTABLISHED, + Conversation.ProtocolInfo.MLSCapable.GroupState.ESTABLISHED, 5UL, Instant.parse("2021-03-30T15:36:00.000Z"), cipherSuite = Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 ) val CONVERSATION_PROTEUS_PROTOCOL_INFO = Conversation.ProtocolInfo.Proteus + val CONV_ENTITY_MIXED_PROTOCOL_INFO = + ConversationEntity.ProtocolInfo.Mixed( + "GROUP_ID", + groupState = ConversationEntity.GroupState.ESTABLISHED, + 5UL, + Instant.parse("2021-03-30T15:36:00.000Z"), + cipherSuite = ConversationEntity.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 + ) val CONV_ENTITY_MLS_PROTOCOL_INFO = ConversationEntity.ProtocolInfo.MLS( "GROUP_ID", @@ -76,6 +105,5 @@ class ProtocolInfoMapperTest { cipherSuite = ConversationEntity.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 ) val CONV_ENTITY_PROTEUS_PROTOCOL_INFO = ConversationEntity.ProtocolInfo.Proteus - } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/event/EventRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/event/EventRepositoryTest.kt index 912a1d161e5..8a29d9c3974 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/event/EventRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/event/EventRepositoryTest.kt @@ -24,6 +24,7 @@ import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.feature.CurrentClientIdProvider import com.wire.kalium.logic.framework.TestClient import com.wire.kalium.logic.framework.TestConversation +import com.wire.kalium.logic.framework.TestUser import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.util.shouldFail import com.wire.kalium.logic.util.shouldSucceed @@ -149,7 +150,7 @@ class EventRepositoryTest { @Mock val clientIdProvider = mock(CurrentClientIdProvider::class) - private val eventRepository: EventRepository = EventDataSource(notificationApi, metaDAO, clientIdProvider) + private val eventRepository: EventRepository = EventDataSource(notificationApi, metaDAO, clientIdProvider, TestUser.SELF.id) init { withCurrentClientIdReturning(TestClient.CLIENT_ID) diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/event/FeatureConfigMapperTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/event/FeatureConfigMapperTest.kt index cfe1aa96e1b..d8efc3718b7 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/event/FeatureConfigMapperTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/event/FeatureConfigMapperTest.kt @@ -30,7 +30,10 @@ import com.wire.kalium.network.api.base.authenticated.featureConfigs.FeatureConf import com.wire.kalium.network.api.base.authenticated.featureConfigs.FeatureFlagStatusDTO import com.wire.kalium.network.api.base.authenticated.featureConfigs.MLSConfigDTO import com.wire.kalium.network.api.base.authenticated.featureConfigs.E2EIConfigDTO +import com.wire.kalium.network.api.base.authenticated.featureConfigs.MLSMigrationConfigDTO import com.wire.kalium.network.api.base.authenticated.featureConfigs.SelfDeletingMessagesConfigDTO +import com.wire.kalium.network.api.base.model.SupportedProtocolDTO +import kotlinx.datetime.Instant import kotlin.test.Test import kotlin.test.assertEquals @@ -64,6 +67,17 @@ class FeatureConfigMapperTest { assertEquals(listOf(PlainId("someId")), model.allowedUsers) } + @Test + fun givenApiModelResponse_whenMappingMLSMigrationStatusToModel_thenShouldBeMappedCorrectly() { + val (arrangement, mapper) = Arrangement().arrange() + + val model = arrangement.featureConfigResponse.mlsMigration?.let { mapper.fromDTO(it) } + + assertEquals(Status.ENABLED, model?.status) + assertEquals(Instant.DISTANT_FUTURE, model?.startTime) + assertEquals(Instant.DISTANT_FUTURE, model?.endTime) + } + @Test fun givenApiModelResponse_whenMappingClassifiedDomainsToModel_thenShouldBeMappedCorrectly() { val (arrangement, mapper) = Arrangement().arrange() @@ -124,9 +138,13 @@ class FeatureConfigMapperTest { private class Arrangement { val featureConfigResponse = FeatureConfigResponse( FeatureConfigData.AppLock( - AppLockConfigDTO(true, 0), FeatureFlagStatusDTO.ENABLED + AppLockConfigDTO(true, 0), + FeatureFlagStatusDTO.ENABLED + ), + FeatureConfigData.ClassifiedDomains( + ClassifiedDomainsConfigDTO(listOf("wire.com")), + FeatureFlagStatusDTO.ENABLED ), - FeatureConfigData.ClassifiedDomains(ClassifiedDomainsConfigDTO(listOf("wire.com")), FeatureFlagStatusDTO.ENABLED), FeatureConfigData.ConferenceCalling(FeatureFlagStatusDTO.ENABLED), FeatureConfigData.ConversationGuestLinks(FeatureFlagStatusDTO.ENABLED), FeatureConfigData.DigitalSignatures(FeatureFlagStatusDTO.ENABLED), @@ -140,7 +158,8 @@ class FeatureConfigMapperTest { FeatureConfigData.MLS( MLSConfigDTO( listOf("someId"), - ConvProtocol.MLS, + SupportedProtocolDTO.MLS, + listOf(SupportedProtocolDTO.MLS), emptyList(), 1 ), FeatureFlagStatusDTO.ENABLED @@ -148,6 +167,13 @@ class FeatureConfigMapperTest { FeatureConfigData.E2EI( E2EIConfigDTO("url", 1_000_000L), FeatureFlagStatusDTO.ENABLED + ), + FeatureConfigData.MLSMigration( + MLSMigrationConfigDTO( + Instant.DISTANT_FUTURE, + Instant.DISTANT_FUTURE + ), + FeatureFlagStatusDTO.ENABLED ) ) @@ -155,5 +181,4 @@ class FeatureConfigMapperTest { fun arrange() = this to mapper } - } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/featureConfig/FeatureConfigRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/featureConfig/FeatureConfigRepositoryTest.kt index fc560b49d8b..d39bb2b8e69 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/featureConfig/FeatureConfigRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/featureConfig/FeatureConfigRepositoryTest.kt @@ -18,6 +18,7 @@ package com.wire.kalium.logic.data.featureConfig +import com.wire.kalium.logic.data.user.SupportedProtocol import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.test_util.TestNetworkException import com.wire.kalium.logic.util.shouldFail @@ -31,7 +32,9 @@ import com.wire.kalium.network.api.base.authenticated.featureConfigs.FeatureConf import com.wire.kalium.network.api.base.authenticated.featureConfigs.FeatureFlagStatusDTO import com.wire.kalium.network.api.base.authenticated.featureConfigs.MLSConfigDTO import com.wire.kalium.network.api.base.authenticated.featureConfigs.E2EIConfigDTO +import com.wire.kalium.network.api.base.authenticated.featureConfigs.MLSMigrationConfigDTO import com.wire.kalium.network.api.base.authenticated.featureConfigs.SelfDeletingMessagesConfigDTO +import com.wire.kalium.network.api.base.model.SupportedProtocolDTO import com.wire.kalium.network.exceptions.KaliumException import com.wire.kalium.network.utils.NetworkResponse import io.mockative.Mock @@ -42,6 +45,8 @@ import io.mockative.once import io.mockative.verify import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.test.runTest +import kotlinx.datetime.Clock +import kotlinx.datetime.Instant import kotlin.test.Test @OptIn(ExperimentalCoroutinesApi::class) @@ -75,10 +80,17 @@ class FeatureConfigRepositoryTest { ConfigsStatusModel(Status.ENABLED), MLSModel( emptyList(), + SupportedProtocol.PROTEUS, + setOf(SupportedProtocol.PROTEUS), Status.ENABLED ), E2EIModel( E2EIConfigModel("url", 1000000L), + com.wire.kalium.logic.data.featureConfig.Status.ENABLED + ), + MLSMigrationModel( + Instant.DISTANT_FUTURE, + Instant.DISTANT_FUTURE, Status.ENABLED ) ) @@ -151,7 +163,8 @@ class FeatureConfigRepositoryTest { FeatureConfigData.MLS( MLSConfigDTO( emptyList(), - ConvProtocol.MLS, + SupportedProtocolDTO.PROTEUS, + listOf(SupportedProtocolDTO.PROTEUS), emptyList(), 1 ), FeatureFlagStatusDTO.ENABLED @@ -159,6 +172,10 @@ class FeatureConfigRepositoryTest { FeatureConfigData.E2EI( E2EIConfigDTO("url", 1000000L), FeatureFlagStatusDTO.ENABLED + ), + FeatureConfigData.MLSMigration( + MLSMigrationConfigDTO(Instant.DISTANT_FUTURE, Instant.DISTANT_FUTURE), + FeatureFlagStatusDTO.ENABLED ) ) diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/featureConfig/FeatureConfigTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/featureConfig/FeatureConfigTest.kt index 3e213094566..b41e37e166b 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/featureConfig/FeatureConfigTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/featureConfig/FeatureConfigTest.kt @@ -17,6 +17,9 @@ */ package com.wire.kalium.logic.data.featureConfig +import com.wire.kalium.logic.data.user.SupportedProtocol +import kotlinx.datetime.Instant + object FeatureConfigTest { @Suppress("LongParameterList") @@ -40,8 +43,13 @@ object FeatureConfigTest { secondFactorPasswordChallengeModel: ConfigsStatusModel = ConfigsStatusModel(Status.ENABLED), ssoModel: ConfigsStatusModel = ConfigsStatusModel(Status.ENABLED), validateSAMLEmailsModel: ConfigsStatusModel = ConfigsStatusModel(Status.ENABLED), - mlsModel: MLSModel = MLSModel(listOf(), Status.ENABLED), - e2EIModel: E2EIModel = E2EIModel(E2EIConfigModel("url", 10000L), Status.ENABLED) + mlsModel: MLSModel = MLSModel(listOf(), SupportedProtocol.PROTEUS, setOf(SupportedProtocol.PROTEUS), Status.ENABLED), + e2EIModel: E2EIModel = E2EIModel(E2EIConfigModel("url", 10000L), Status.ENABLED), + mlsMigrationModel: MLSMigrationModel? = MLSMigrationModel( + Instant.DISTANT_FUTURE, + Instant.DISTANT_FUTURE, + Status.ENABLED + ) ): FeatureConfigModel = FeatureConfigModel( appLockModel, classifiedDomainsModel, @@ -57,6 +65,7 @@ object FeatureConfigTest { ssoModel, validateSAMLEmailsModel, mlsModel, - e2EIModel + e2EIModel, + mlsMigrationModel ) } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/message/MessageRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/message/MessageRepositoryTest.kt index 6c6169f796e..f358ae99fb7 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/message/MessageRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/message/MessageRepositoryTest.kt @@ -18,6 +18,7 @@ package com.wire.kalium.logic.data.message +import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.data.asset.AssetMapper import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.conversation.Recipient @@ -32,6 +33,7 @@ import com.wire.kalium.logic.feature.message.MessageTarget import com.wire.kalium.logic.framework.TestMessage.TEST_MESSAGE_ID import com.wire.kalium.logic.framework.TestUser.OTHER_USER_ID_2 import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.util.shouldFail import com.wire.kalium.logic.util.shouldSucceed import com.wire.kalium.network.api.base.authenticated.message.MLSMessageApi import com.wire.kalium.network.api.base.authenticated.message.MessageApi @@ -63,6 +65,7 @@ import kotlinx.coroutines.test.runTest import kotlinx.datetime.Instant import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertIs import kotlin.test.assertSame import kotlin.test.assertTrue @@ -428,6 +431,47 @@ class MessageRepositoryTest { .wasInvoked(exactly = once) } + @Test + fun givenConversationIds_whenMovingMessages_thenShouldCallDAOWithCorrectParameters() = runTest { + val sourceConversationId = TEST_CONVERSATION_ID.copy(value = "source") + val targetConversationId = TEST_CONVERSATION_ID.copy(value = "target") + + val (arrangement, messageRepository) = Arrangement() + .withMovingToAnotherConversationSucceeding() + .arrange() + + messageRepository.moveMessagesToAnotherConversation( + sourceConversationId, + targetConversationId + ).shouldSucceed() + + verify(arrangement.messageDAO) + .suspendFunction(arrangement.messageDAO::moveMessages) + .with( + eq(sourceConversationId.toDao()), + eq(targetConversationId.toDao()) + ) + .wasInvoked(exactly = once) + } + + @Test + fun givenDAOFails_whenMovingMessages_thenShouldPropagateFailure() = runTest { + val exception = IllegalArgumentException("Oopsie doopsie!") + val (_, messageRepository) = Arrangement() + .withMovingToAnotherConversationFailingWith(exception) + .arrange() + val sourceConversationId = TEST_CONVERSATION_ID.copy(value = "source") + val targetConversationId = TEST_CONVERSATION_ID.copy(value = "target") + + messageRepository.moveMessagesToAnotherConversation( + sourceConversationId, + targetConversationId + ).shouldFail { + assertIs(it) + assertEquals(exception, it.rootCause) + } + } + private class Arrangement { @Mock @@ -547,6 +591,20 @@ class MessageRepositoryTest { .then { _, _, _, _ -> Unit } } + fun withMovingToAnotherConversationSucceeding() = apply { + given(messageDAO) + .suspendFunction(messageDAO::moveMessages) + .whenInvokedWith(any()) + .thenReturn(Unit) + } + + fun withMovingToAnotherConversationFailingWith(throwable: Throwable) = apply { + given(messageDAO) + .suspendFunction(messageDAO::moveMessages) + .whenInvokedWith(any()) + .thenThrow(throwable) + } + fun arrange() = this to MessageDataSource( messageApi = messageApi, mlsMessageApi = mlsMessageApi, @@ -610,3 +668,4 @@ class MessageRepositoryTest { ) } } + diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/publicuser/SearchUserRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/publicuser/SearchUserRepositoryTest.kt index be726620552..693c491f5dc 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/publicuser/SearchUserRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/publicuser/SearchUserRepositoryTest.kt @@ -522,13 +522,13 @@ class SearchUserRepositoryTest { email = null, expiresAt = null, nonQualifiedId = "value", - service = null + service = null, + supportedProtocols = null ) ) ) const val JSON_QUALIFIED_ID = """{"value":"test" , "domain":"test" }""" - } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/publicuser/UserSearchApiWrapperTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/publicuser/UserSearchApiWrapperTest.kt index 27c43e06f3c..f3015479935 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/publicuser/UserSearchApiWrapperTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/publicuser/UserSearchApiWrapperTest.kt @@ -449,7 +449,8 @@ class UserSearchApiWrapperTest { previewPicture = null, completePicture = null, availabilityStatus = UserAvailabilityStatus.AVAILABLE, - expiresAt = null + expiresAt = null, + supportedProtocols = null, ) } @@ -465,7 +466,8 @@ class UserSearchApiWrapperTest { previewPicture = null, completePicture = null, availabilityStatus = UserAvailabilityStatus.AVAILABLE, - expiresAt = null + expiresAt = null, + supportedProtocols = null ) const val JSON_QUALIFIED_ID = """{"value":"test" , "domain":"test" }""" @@ -486,7 +488,9 @@ class UserSearchApiWrapperTest { botService = null, deleted = false, expiresAt = null, - defederated = false + defederated = false, + supportedProtocols = null, + activeOneOnOneConversationId = null ) } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/reaction/ReactionRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/reaction/ReactionRepositoryTest.kt index 882386a818c..d0677566a5a 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/reaction/ReactionRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/reaction/ReactionRepositoryTest.kt @@ -82,7 +82,7 @@ class ReactionRepositoryTest { } suspend fun insertInitialData() { - userDao.insertUser(TEST_SELF_USER_ENTITY) + userDao.upsertUser(TEST_SELF_USER_ENTITY) conversationDao.insertConversation(TEST_CONVERSATION_ENTITY) messageDao.insertOrIgnoreMessage(TEST_MESSAGE_ENTITY) } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/register/RegisterAccountRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/register/RegisterAccountRepositoryTest.kt index e126fa567d1..68f4e71c1c8 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/register/RegisterAccountRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/register/RegisterAccountRepositoryTest.kt @@ -339,7 +339,8 @@ class RegisterAccountRepositoryTest { locale = "", managedByDTO = null, phone = null, - ssoID = null + ssoID = null, + supportedProtocols = null ) } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/team/TeamRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/team/TeamRepositoryTest.kt index e2dc66188b8..c5b4142d858 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/team/TeamRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/team/TeamRepositoryTest.kt @@ -122,7 +122,7 @@ class TeamRepositoryTest { // Verifies that userDAO insertUsers was called with the correct mapped values verify(arrangement.userDAO) - .suspendFunction(arrangement.userDAO::upsertTeamMembersTypes) + .suspendFunction(arrangement.userDAO::upsertTeamMemberUserTypes) .with(any()) .wasInvoked(exactly = once) @@ -236,7 +236,7 @@ class TeamRepositoryTest { result.shouldSucceed() verify(arrangement.userDAO) - .suspendFunction(arrangement.userDAO::insertUser) + .suspendFunction(arrangement.userDAO::upsertUser) .with(any()) .wasInvoked(once) } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/user/UserMapperTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/user/UserMapperTest.kt index 9e6267f440f..4e6ccc95c09 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/user/UserMapperTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/user/UserMapperTest.kt @@ -18,22 +18,13 @@ package com.wire.kalium.logic.data.user -import com.wire.kalium.logic.data.id.TeamId -import com.wire.kalium.logic.data.team.TeamRole -import com.wire.kalium.logic.framework.TestTeam import com.wire.kalium.logic.framework.TestUser -import com.wire.kalium.network.api.base.authenticated.TeamsApi import com.wire.kalium.persistence.dao.ConnectionEntity -import com.wire.kalium.persistence.dao.QualifiedIDEntity -import com.wire.kalium.persistence.dao.UserAvailabilityStatusEntity -import com.wire.kalium.persistence.dao.UserEntity import com.wire.kalium.persistence.dao.UserTypeEntity -import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.test.runTest import kotlin.test.Test import kotlin.test.assertEquals -@OptIn(ExperimentalCoroutinesApi::class) class UserMapperTest { @Test @@ -56,46 +47,6 @@ class UserMapperTest { assertEquals(expectedResult, result) } - @Test - fun givenTeamMemberApiModel_whenMappingFromApiResponse_thenDaoModelIsReturned() = runTest { - val apiModel = TestTeam.memberDTO( - nonQualifiedUserId = "teamMember1", - permissions = TeamsApi.Permissions(TeamRole.Member.value, TeamRole.Member.value) - ) - - val expectedResult = UserEntity( - id = QualifiedIDEntity( - value = "teamMember1", - domain = "userDomain" - ), - name = null, - handle = null, - email = null, - phone = null, - accentId = 1, - team = "teamId", - connectionStatus = ConnectionEntity.State.ACCEPTED, - previewAssetId = null, - completeAssetId = null, - availabilityStatus = UserAvailabilityStatusEntity.NONE, - userType = UserTypeEntity.STANDARD, - botService = null, - deleted = false, - expiresAt = null, - defederated = false - ) - val (_, userMapper) = Arrangement().arrange() - - val result = userMapper.fromTeamMemberToDaoModel( - teamId = TeamId("teamId"), - userDomain = "userDomain", - nonQualifiedUserId = "teamMember1", - permissionCode = apiModel.permissions?.own - ) - - assertEquals(expectedResult, result) - } - private class Arrangement { private val userMapper = UserMapperImpl() diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/user/UserRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/user/UserRepositoryTest.kt index 01955dd8977..7019af4aafe 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/user/UserRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/user/UserRepositoryTest.kt @@ -19,6 +19,7 @@ package com.wire.kalium.logic.data.user import app.cash.turbine.test +import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.data.id.QualifiedID import com.wire.kalium.logic.data.id.QualifiedIdMapper import com.wire.kalium.logic.data.id.toApi @@ -34,7 +35,7 @@ import com.wire.kalium.logic.framework.TestUser import com.wire.kalium.logic.framework.TestUser.LIST_USERS_DTO import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.getOrNull -import com.wire.kalium.logic.sync.receiver.UserEventReceiverTest +import com.wire.kalium.logic.test_util.TestNetworkResponseError import com.wire.kalium.logic.util.shouldFail import com.wire.kalium.logic.util.shouldSucceed import com.wire.kalium.network.api.base.authenticated.self.SelfApi @@ -62,6 +63,7 @@ import io.mockative.given import io.mockative.matching import io.mockative.mock import io.mockative.once +import io.mockative.twice import io.mockative.verify import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.flow.Flow @@ -71,6 +73,7 @@ import kotlinx.coroutines.flow.flowOf import kotlinx.coroutines.test.runTest import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertIs import kotlin.test.assertTrue class UserRepositoryTest { @@ -126,25 +129,16 @@ class UserRepositoryTest { @Test fun givenAUserEvent_whenPersistingTheUser_thenShouldSucceed() = runTest { val (arrangement, userRepository) = Arrangement() - .withMapperQualifiedUserId() + .withUpdateUserReturning(true) .arrange() - val result = userRepository.updateUserFromEvent(TestEvent.updateUser(userId = UserEventReceiverTest.SELF_USER_ID)) + val result = userRepository.updateUserFromEvent(TestEvent.updateUser(userId = SELF_USER.id)) with(result) { shouldSucceed() - - verify(arrangement.qualifiedIdMapper) - .function(arrangement.qualifiedIdMapper::fromStringToQualifiedID) - .with(any()) - .wasInvoked(exactly = once) - verify(arrangement.userDAO) - .suspendFunction(arrangement.userDAO::observeUserDetailsByQualifiedID) - .with(any()) - .wasInvoked(exactly = once) verify(arrangement.userDAO) .suspendFunction(arrangement.userDAO::updateUser) - .with(any()) + .with(any(), any()) .wasInvoked(exactly = once) } } @@ -152,27 +146,17 @@ class UserRepositoryTest { @Test fun givenAUserEvent_whenPersistingTheUserAndNotExists_thenShouldFail() = runTest { val (arrangement, userRepository) = Arrangement() - .withMapperQualifiedUserId() - .withUserDaoReturning(null) + .withUpdateUserReturning(false) .arrange() - val result = userRepository.updateUserFromEvent(TestEvent.updateUser(userId = UserEventReceiverTest.SELF_USER_ID)) + val result = userRepository.updateUserFromEvent(TestEvent.updateUser(userId = SELF_USER.id)) with(result) { shouldFail() - - verify(arrangement.qualifiedIdMapper) - .function(arrangement.qualifiedIdMapper::fromStringToQualifiedID) - .with(any()) - .wasInvoked(exactly = once) - verify(arrangement.userDAO) - .suspendFunction(arrangement.userDAO::observeUserDetailsByQualifiedID) - .with(any()) - .wasInvoked(exactly = once) verify(arrangement.userDAO) .suspendFunction(arrangement.userDAO::updateUser) - .with(any()) - .wasNotInvoked() + .with(any(), any()) + .wasInvoked(exactly = once) } } @@ -294,14 +278,10 @@ class UserRepositoryTest { .suspendFunction(arrangement.userDetailsApi::getUserInfo) .with(any()) .wasInvoked(exactly = once) - verify(arrangement.userDAO) - .suspendFunction(arrangement.userDAO::upsertTeamMembers) - .with(any()) - .wasInvoked(exactly = once) verify(arrangement.userDAO) .suspendFunction(arrangement.userDAO::upsertUsers) .with(any()) - .wasInvoked(exactly = once) + .wasInvoked() } } @@ -319,10 +299,6 @@ class UserRepositoryTest { .suspendFunction(arrangement.userDetailsApi::getUserInfo) .with(any()) .wasNotInvoked() - verify(arrangement.userDAO) - .suspendFunction(arrangement.userDAO::upsertTeamMembers) - .with(any()) - .wasNotInvoked() verify(arrangement.userDAO) .suspendFunction(arrangement.userDAO::upsertUsers) .with(any()) @@ -344,14 +320,10 @@ class UserRepositoryTest { .suspendFunction(arrangement.userDetailsApi::getUserInfo) .with(any()) .wasInvoked(exactly = once) - verify(arrangement.userDAO) - .suspendFunction(arrangement.userDAO::upsertTeamMembers) - .with(any()) - .wasInvoked(exactly = once) verify(arrangement.userDAO) .suspendFunction(arrangement.userDAO::upsertUsers) .with(any()) - .wasInvoked(exactly = once) + .wasInvoked(exactly = twice) } val resultSecondTime = userRepository.getKnownUser(TestUser.USER_ID) @@ -360,10 +332,6 @@ class UserRepositoryTest { .suspendFunction(arrangement.userDetailsApi::getUserInfo) .with(any()) .wasNotInvoked() - verify(arrangement.userDAO) - .suspendFunction(arrangement.userDAO::upsertTeamMembers) - .with(any()) - .wasNotInvoked() verify(arrangement.userDAO) .suspendFunction(arrangement.userDAO::upsertUsers) .with(any()) @@ -391,7 +359,7 @@ class UserRepositoryTest { verify(arrangement.userDAO) .suspendFunction(arrangement.userDAO::upsertUsers) .with(matching { - it.first().name != null + it.firstOrNull()?.name != null }) .wasInvoked(exactly = once) } @@ -418,7 +386,6 @@ class UserRepositoryTest { .wasNotInvoked() } - @Test fun whenRemovingUserBrokenAsset_thenShouldCallDaoAndSucceed() = runTest { // Given @@ -539,6 +506,89 @@ class UserRepositoryTest { .wasInvoked(once) } + @Test + fun givenANewSupportedProtocols_whenUpdatingOk_thenShouldSucceedAndPersistTheSupportedProtocolsLocally() = runTest { + val successResponse = NetworkResponse.Success(Unit, mapOf(), HttpStatusCode.OK.value) + val (arrangement, userRepository) = Arrangement() + .withGetSelfUserId() + .withUpdateSupportedProtocolsApiRequestResponse(successResponse) + .arrange() + + val result = userRepository.updateSupportedProtocols(setOf(SupportedProtocol.MLS)) + + with(result) { + shouldSucceed() + verify(arrangement.selfApi) + .suspendFunction(arrangement.selfApi::updateSupportedProtocols) + .with(any()) + .wasInvoked(exactly = once) + verify(arrangement.userDAO) + .suspendFunction(arrangement.userDAO::updateUserSupportedProtocols) + .with(any(), any()) + .wasInvoked(exactly = once) + } + } + + @Test + fun givenANewSupportedProtocols_whenUpdatingFails_thenShouldNotPersistSupportedProtocolsLocally() = runTest { + val (arrangement, userRepository) = Arrangement() + .withGetSelfUserId() + .withUpdateSupportedProtocolsApiRequestResponse(TestNetworkResponseError.genericResponseError()) + .arrange() + + val result = userRepository.updateSupportedProtocols(setOf(SupportedProtocol.MLS)) + + with(result) { + shouldFail() + verify(arrangement.selfApi) + .suspendFunction(arrangement.selfApi::updateSupportedProtocols) + .with(any()) + .wasInvoked(exactly = once) + verify(arrangement.userDAO) + .suspendFunction(arrangement.userDAO::updateUserSupportedProtocols) + .with(any(), any()) + .wasNotInvoked() + } + } + + @Test + fun givenUserIdAndConversationId_whenUpdatingOneOnOneConversation_thenShouldCallDAOWithCorrectArguments() = runTest { + val userId = TestUser.USER_ID + val conversationId = TestConversation.CONVERSATION.id + + val (arrangement, userRepository) = Arrangement() + .withUpdateOneOnOneConversationSuccess() + .arrange() + + userRepository.updateActiveOneOnOneConversation( + userId, + conversationId + ).shouldSucceed() + + verify(arrangement.userDAO) + .suspendFunction(arrangement.userDAO::updateActiveOneOnOneConversation) + .with(eq(userId.toDao()), eq(conversationId.toDao())) + .wasInvoked(exactly = once) + } + + @Test + fun givenDAOFails_whenUpdatingOneOnOneConversation_thenShouldPropagateException() = runTest { + val exception = IllegalStateException("Oopsie Doopsie!") + val (_, connectionRepository) = Arrangement() + .withUpdateOneOnOneConversationFailing(exception) + .arrange() + val userId = TestUser.USER_ID + val conversationId = TestConversation.CONVERSATION.id + + connectionRepository.updateActiveOneOnOneConversation( + userId, + conversationId + ).shouldFail { + assertIs(it) + assertEquals(exception, it.rootCause) + } + } + private class Arrangement { @Mock val userDAO = configure(mock(classOf())) { stubsUnitByDefault = true } @@ -575,7 +625,6 @@ class UserRepositoryTest { userDetailsApi, sessionRepository, selfUserId, - qualifiedIdMapper, selfTeamIdProvider ) } @@ -613,6 +662,13 @@ class UserRepositoryTest { .thenReturn(flowOf(userEntities)) } + fun withUpdateUserReturning(updated: Boolean) = apply { + given(userDAO) + .suspendFunction(userDAO::updateUser) + .whenInvokedWith(any(), any()) + .thenReturn(updated) + } + fun withSuccessfulGetUsersInfo() = apply { given(userDetailsApi) .suspendFunction(userDetailsApi::getUserInfo) @@ -681,6 +737,13 @@ class UserRepositoryTest { .thenReturn(response) } + fun withUpdateSupportedProtocolsApiRequestResponse(response: NetworkResponse) = apply { + given(selfApi) + .suspendFunction(selfApi::updateSupportedProtocols) + .whenInvokedWith(any()) + .thenReturn(response) + } + fun withRemoteUpdateEmail(result: NetworkResponse) = apply { given(selfApi) .suspendFunction(selfApi::updateEmailAddress) @@ -725,6 +788,20 @@ class UserRepositoryTest { .thenReturn(Unit) } + fun withUpdateOneOnOneConversationSuccess() = apply { + given(userDAO) + .suspendFunction(userDAO::updateActiveOneOnOneConversation) + .whenInvokedWith(any(), any()) + .thenReturn(Unit) + } + + fun withUpdateOneOnOneConversationFailing(exception: Throwable) = apply { + given(userDAO) + .suspendFunction(userDAO::updateActiveOneOnOneConversation) + .whenInvokedWith(any(), any()) + .thenThrow(exception) + } + fun arrange(block: (Arrangement.() -> Unit) = { }): Pair { apply(block) return this to userRepository @@ -732,6 +809,6 @@ class UserRepositoryTest { } private companion object { - val SELF_USER = TestUser.SELF_USER_DTO + val SELF_USER = TestUser.SELF } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/backup/RestoreWebBackupUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/backup/RestoreWebBackupUseCaseTest.kt index e658b20cc6c..026f03aec9f 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/backup/RestoreWebBackupUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/backup/RestoreWebBackupUseCaseTest.kt @@ -192,7 +192,7 @@ class RestoreWebBackupUseCaseTest { fun arrange() = this to RestoreWebBackupUseCaseImpl( kaliumFileSystem = fakeFileSystem, - userId = selfUserId, + selfUserId = selfUserId, migrationDAO = migrationDAO, persistMigratedMessages = persistMigratedMessagesUseCase, restartSlowSyncProcessForRecovery = restartSlowSyncProcessForRecoveryUseCase diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/call/usecase/EndCallOnConversationChangeUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/call/usecase/EndCallOnConversationChangeUseCaseTest.kt index 7b0c5ad8183..5ae1d896529 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/call/usecase/EndCallOnConversationChangeUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/call/usecase/EndCallOnConversationChangeUseCaseTest.kt @@ -11,6 +11,7 @@ import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.id.TeamId import com.wire.kalium.logic.data.user.ConnectionState import com.wire.kalium.logic.data.user.OtherUser +import com.wire.kalium.logic.data.user.SupportedProtocol import com.wire.kalium.logic.data.user.UserAssetId import com.wire.kalium.logic.data.user.UserAvailabilityStatus import com.wire.kalium.logic.data.user.UserId @@ -171,7 +172,8 @@ class EndCallOnConversationChangeUseCaseTest { botService = null, deleted = true, defederated = false, - isProteusVerified = false + isProteusVerified = false, + supportedProtocols = setOf(SupportedProtocol.PROTEUS) ) private val groupConversationDetail = ConversationDetails.Group( diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/client/DeleteClientUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/client/DeleteClientUseCaseTest.kt index 85ff85e434d..c90409ba8a9 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/client/DeleteClientUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/client/DeleteClientUseCaseTest.kt @@ -18,15 +18,22 @@ package com.wire.kalium.logic.feature.client +import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.NetworkFailure import com.wire.kalium.logic.data.client.ClientRepository import com.wire.kalium.logic.data.client.DeleteClientParam +import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsAndResolveOneOnOnesUseCase import com.wire.kalium.logic.framework.TestClient import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.test_util.TestNetworkException +import com.wire.kalium.logic.util.arrangement.UserRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.UserRepositoryArrangementImpl +import com.wire.kalium.logic.util.arrangement.mls.OneOnOneResolverArrangement +import com.wire.kalium.logic.util.arrangement.mls.OneOnOneResolverArrangementImpl import com.wire.kalium.network.exceptions.KaliumException import io.ktor.utils.io.errors.IOException import io.mockative.Mock +import io.mockative.any import io.mockative.anything import io.mockative.classOf import io.mockative.eq @@ -35,35 +42,24 @@ import io.mockative.mock import io.mockative.once import io.mockative.verify import kotlinx.coroutines.test.runTest -import kotlin.test.BeforeTest import kotlin.test.Test import kotlin.test.assertIs import kotlin.test.assertSame class DeleteClientUseCaseTest { - @Mock - private val clientRepository = mock(classOf()) - - private lateinit var deleteClient: DeleteClientUseCase - - @BeforeTest - fun setup() { - deleteClient = DeleteClientUseCaseImpl(clientRepository) - } - @Test fun givenDeleteClientParams_whenDeleting_thenTheRepositoryShouldBeCalledWithCorrectParameters() = runTest { val params = DELETE_CLIENT_PARAMETERS - given(clientRepository) - .suspendFunction(clientRepository::deleteClient) - .whenInvokedWith(anything()) - .then { Either.Left(TEST_FAILURE) } + + val (arrangement, deleteClient) = arrange { + withDeleteClient(Either.Left(TEST_FAILURE)) + } deleteClient(params) - verify(clientRepository) - .suspendFunction(clientRepository::deleteClient) + verify(arrangement.clientRepository) + .suspendFunction(arrangement.clientRepository::deleteClient) .with(eq(params)) .wasInvoked(once) } @@ -71,10 +67,9 @@ class DeleteClientUseCaseTest { @Test fun givenRepositoryDeleteClientFailsDueToGenericError_whenDeleting_thenGenericErrorShouldBeReturned() = runTest { val genericFailure = TEST_FAILURE - given(clientRepository) - .suspendFunction(clientRepository::deleteClient) - .whenInvokedWith(anything()) - .then { Either.Left(genericFailure) } + val (_, deleteClient) = arrange { + withDeleteClient(Either.Left(genericFailure)) + } val result = deleteClient(DELETE_CLIENT_PARAMETERS) @@ -85,10 +80,9 @@ class DeleteClientUseCaseTest { @Test fun givenRepositoryDeleteClientFailsDueToWrongPassword_whenDeleting_thenInvalidCredentialsErrorShouldBeReturned() = runTest { val wrongPasswordFailure = NetworkFailure.ServerMiscommunication(TestNetworkException.invalidCredentials) - given(clientRepository) - .suspendFunction(clientRepository::deleteClient) - .whenInvokedWith(anything()) - .then { Either.Left(wrongPasswordFailure) } + val (_, deleteClient) = arrange { + withDeleteClient(Either.Left(wrongPasswordFailure)) + } val result = deleteClient(DELETE_CLIENT_PARAMETERS) @@ -98,10 +92,9 @@ class DeleteClientUseCaseTest { @Test fun givenRepositoryDeleteClientFailsDueToMissingPassword_whenDeleting_thenPasswordAuthRequiredErrorShouldBeReturned() = runTest { val missingPasswordFailure = NetworkFailure.ServerMiscommunication(TestNetworkException.missingAuth) - given(clientRepository) - .suspendFunction(clientRepository::deleteClient) - .whenInvokedWith(anything()) - .then { Either.Left(missingPasswordFailure) } + val (_, deleteClient) = arrange { + withDeleteClient(Either.Left(missingPasswordFailure)) + } val result = deleteClient(DELETE_CLIENT_PARAMETERS) @@ -111,20 +104,69 @@ class DeleteClientUseCaseTest { @Test fun givenRepositoryDeleteClientFailsDueToBadRequest_whenDeleting_thenInvalidCredentialsErrorShouldBeReturned() = runTest { val badRequest = NetworkFailure.ServerMiscommunication(TestNetworkException.badRequest) - given(clientRepository) - .suspendFunction(clientRepository::deleteClient) - .whenInvokedWith(anything()) - .then { Either.Left(badRequest) } + val (_, deleteClient) = arrange { + withDeleteClient(Either.Left(badRequest)) + } val result = deleteClient(DELETE_CLIENT_PARAMETERS) assertIs(result) } + @Test + fun givenRepositoryDeleteClientSucceeds_whenDeleting_thenUpdateSupportedProtocols() = runTest { + val (arrangement, deleteClient) = arrange { + withDeleteClient(Either.Right(Unit)) + withUpdateSupportedProtocolsAndResolveOneOnOnes(Either.Right(Unit)) + } + + val result = deleteClient(DELETE_CLIENT_PARAMETERS) + + assertIs(result) + verify(arrangement.updateSupportedProtocolsAndResolveOneOnOnes) + .suspendFunction(arrangement.updateSupportedProtocolsAndResolveOneOnOnes::invoke) + .with(eq(true)) + .wasInvoked(exactly = once) + } + + private class Arrangement(private val block: Arrangement.() -> Unit) : + UserRepositoryArrangement by UserRepositoryArrangementImpl(), + OneOnOneResolverArrangement by OneOnOneResolverArrangementImpl() + { + @Mock + val clientRepository = mock(classOf()) + + @Mock + val updateSupportedProtocolsAndResolveOneOnOnes = mock(classOf()) + + fun withDeleteClient(result: Either) { + given(clientRepository) + .suspendFunction(clientRepository::deleteClient) + .whenInvokedWith(anything()) + .then { result } + } + + fun withUpdateSupportedProtocolsAndResolveOneOnOnes(result: Either) { + given(updateSupportedProtocolsAndResolveOneOnOnes) + .suspendFunction(updateSupportedProtocolsAndResolveOneOnOnes::invoke) + .whenInvokedWith(any()) + .thenReturn(result) + } + + fun arrange() = run { + block() + this@Arrangement to DeleteClientUseCaseImpl( + clientRepository = clientRepository, + updateSupportedProtocolsAndResolveOneOnOnes = updateSupportedProtocolsAndResolveOneOnOnes, + ) + } + } + private companion object { + fun arrange(configuration: Arrangement.() -> Unit) = Arrangement(configuration).arrange() + val CLIENT = TestClient.CLIENT val DELETE_CLIENT_PARAMETERS = DeleteClientParam("pass", CLIENT.id) val TEST_FAILURE = NetworkFailure.ServerMiscommunication(KaliumException.GenericError(IOException("no internet"))) - } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/client/MLSClientManagerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/client/MLSClientManagerTest.kt index 1d067f3b8a5..39638999c98 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/client/MLSClientManagerTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/client/MLSClientManagerTest.kt @@ -48,7 +48,7 @@ class MLSClientManagerTest { fun givenMLSSupportIsDisabled_whenObservingSyncFinishes_thenMLSClientIsNotRegistered() = runTest(TestKaliumDispatcher.default) { val (arrangement, _) = Arrangement() - .withIsMLSEnabled(false) + .withIsAllowedToRegisterMLSClient(false) .arrange() arrangement.incrementalSyncRepository.updateIncrementalSyncState(IncrementalSyncStatus.Live) @@ -64,7 +64,7 @@ class MLSClientManagerTest { fun givenMLSClientIsNotRegistered_whenObservingSyncFinishes_thenMLSClientIsRegistered() = runTest(TestKaliumDispatcher.default) { val (arrangement, _) = Arrangement() - .withIsMLSEnabled(true) + .withIsAllowedToRegisterMLSClient(true) .withHasRegisteredMLSClient(Either.Right(false)) .withCurrentClientId(Either.Right(TestClient.CLIENT_ID)) .withRegisterMLSClientSuccessful() @@ -87,7 +87,7 @@ class MLSClientManagerTest { fun givenMLSClientIsRegistered_whenObservingSyncFinishes_thenMLSClientIsNotRegistered() = runTest(TestKaliumDispatcher.default) { val (arrangement, _) = Arrangement() - .withIsMLSEnabled(true) + .withIsAllowedToRegisterMLSClient(true) .withHasRegisteredMLSClient(Either.Right(true)) .arrange() @@ -114,7 +114,7 @@ class MLSClientManagerTest { val clientRepository = mock(classOf()) @Mock - val isMLSEnabled = mock(classOf()) + val isAllowedToRegisterMLSClient = mock(classOf()) @Mock val registerMLSClient = mock(classOf()) @@ -140,16 +140,16 @@ class MLSClientManagerTest { .thenReturn(Either.Right(Unit)) } - fun withIsMLSEnabled(enabled: Boolean) = apply { - given(isMLSEnabled) - .function(isMLSEnabled::invoke) + fun withIsAllowedToRegisterMLSClient(enabled: Boolean) = apply { + given(isAllowedToRegisterMLSClient) + .suspendFunction(isAllowedToRegisterMLSClient::invoke) .whenInvoked() .thenReturn(enabled) } fun arrange() = this to MLSClientManagerImpl( clientIdProvider, - isMLSEnabled, + isAllowedToRegisterMLSClient, incrementalSyncRepository, lazy { slowSyncRepository }, lazy { clientRepository }, diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/connection/AcceptConnectionRequestUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/connection/AcceptConnectionRequestUseCaseTest.kt index 8b824d002c0..de9d672683b 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/connection/AcceptConnectionRequestUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/connection/AcceptConnectionRequestUseCaseTest.kt @@ -19,100 +19,140 @@ package com.wire.kalium.logic.feature.connection import com.wire.kalium.logic.CoreFailure -import com.wire.kalium.logic.data.connection.ConnectionRepository -import com.wire.kalium.logic.data.conversation.ConversationRepository import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.user.Connection import com.wire.kalium.logic.data.user.ConnectionState import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.framework.TestConversation import com.wire.kalium.logic.functional.Either -import io.mockative.Mock +import com.wire.kalium.logic.util.arrangement.mls.OneOnOneResolverArrangement +import com.wire.kalium.logic.util.arrangement.mls.OneOnOneResolverArrangementImpl +import com.wire.kalium.logic.util.arrangement.repository.ConnectionRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.repository.ConnectionRepositoryArrangementImpl +import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangementImpl import io.mockative.any import io.mockative.eq -import io.mockative.given -import io.mockative.mock import io.mockative.once import io.mockative.verify import kotlinx.coroutines.test.runTest -import kotlin.test.BeforeTest import kotlin.test.Test import kotlin.test.assertEquals class AcceptConnectionRequestUseCaseTest { - @Mock - private val connectionRepository: ConnectionRepository = mock(ConnectionRepository::class) - - @Mock - private val conversationRepository: ConversationRepository = mock(ConversationRepository::class) + @Test + fun givenSuccess_whenInvokingUseCase_thenShouldUpdateConnectionStatusToAccepted() = runTest { + // given + val (arrangement, acceptConnectionRequestUseCase) = arrange { + withUpdateConnectionStatus(Either.Right(CONNECTION)) + withFetchConversation(Either.Right(Unit)) + withUpdateConversationModifiedDate(Either.Right(Unit)) + withResolveOneOnOneConversationWithUserIdReturning(Either.Right(TestConversation.ID)) + } - lateinit var acceptConnectionRequestUseCase: AcceptConnectionRequestUseCase + // when + val result = acceptConnectionRequestUseCase(USER_ID) - @BeforeTest - fun setUp() { - acceptConnectionRequestUseCase = AcceptConnectionRequestUseCaseImpl(connectionRepository, conversationRepository) + // then + assertEquals(AcceptConnectionRequestUseCaseResult.Success, result) + verify(arrangement.connectionRepository) + .suspendFunction(arrangement.connectionRepository::updateConnectionStatus) + .with(eq(USER_ID), eq(ConnectionState.ACCEPTED)) + .wasInvoked(once) } @Test - fun givenAConnectionRequest_whenInvokingAcceptConnectionRequestAndOk_thenShouldReturnsASuccessResult() = runTest { + fun givenSuccess_whenInvokingUseCase_thenShouldUpdateConversationModifiedDate() = runTest { // given - given(connectionRepository) - .suspendFunction(connectionRepository::updateConnectionStatus) - .whenInvokedWith(eq(userId), eq(ConnectionState.ACCEPTED)) - .thenReturn(Either.Right(connection)) + val (arrangement, acceptConnectionRequestUseCase) = arrange { + withUpdateConnectionStatus(Either.Right(CONNECTION)) + withFetchConversation(Either.Right(Unit)) + withUpdateConversationModifiedDate(Either.Right(Unit)) + withResolveOneOnOneConversationWithUserIdReturning(Either.Right(TestConversation.ID)) + } - given(conversationRepository) - .suspendFunction(conversationRepository::fetchConversation) - .whenInvokedWith(eq(conversationId)) - .thenReturn(Either.Right(Unit)) + // when + val result = acceptConnectionRequestUseCase(USER_ID) - given(conversationRepository) - .suspendFunction(conversationRepository::updateConversationModifiedDate) - .whenInvokedWith(eq(conversationId), any()) - .thenReturn(Either.Right(Unit)) + // then + assertEquals(AcceptConnectionRequestUseCaseResult.Success, result) + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::updateConversationModifiedDate) + .with(eq(CONNECTION.qualifiedConversationId), any()) + .wasInvoked(once) + } + + @Test + fun givenSuccess_whenInvokingUseCase_thenShouldResolveActiveOneOnOneConversation() = runTest { + // given + val (arrangement, acceptConnectionRequestUseCase) = arrange { + withUpdateConnectionStatus(Either.Right(CONNECTION)) + withFetchConversation(Either.Right(Unit)) + withUpdateConversationModifiedDate(Either.Right(Unit)) + withResolveOneOnOneConversationWithUserIdReturning(Either.Right(TestConversation.ID)) + } // when - val resultOk = acceptConnectionRequestUseCase(userId) + val result = acceptConnectionRequestUseCase(USER_ID) // then - assertEquals(AcceptConnectionRequestUseCaseResult.Success, resultOk) - verify(connectionRepository) - .suspendFunction(connectionRepository::updateConnectionStatus) - .with(eq(userId), eq(ConnectionState.ACCEPTED)) + assertEquals(AcceptConnectionRequestUseCaseResult.Success, result) + verify(arrangement.oneOnOneResolver) + .suspendFunction(arrangement.oneOnOneResolver::resolveOneOnOneConversationWithUserId) + .with(eq(CONNECTION.qualifiedToId)) .wasInvoked(once) } @Test - fun givenAConnectionRequest_whenInvokingAcceptConnectionRequestAndFails_thenShouldReturnsAFailureResult() = runTest { + fun givenFailure_whenInvokingUseCase_thenShouldReturnsAFailureResult() = runTest { // given - given(connectionRepository) - .suspendFunction(connectionRepository::updateConnectionStatus) - .whenInvokedWith(eq(userId), eq(ConnectionState.ACCEPTED)) - .thenReturn(Either.Left(CoreFailure.Unknown(RuntimeException("Some error")))) + val failure = CoreFailure.Unknown(RuntimeException("Some error")) + val (arrangement, acceptConnectionRequestUseCase) = arrange { + withUpdateConnectionStatus(Either.Left(failure)) + } // when - val resultFailure = acceptConnectionRequestUseCase(userId) + val resultFailure = acceptConnectionRequestUseCase(USER_ID) // then assertEquals(AcceptConnectionRequestUseCaseResult.Failure::class, resultFailure::class) - verify(connectionRepository) - .suspendFunction(connectionRepository::updateConnectionStatus) - .with(eq(userId), eq(ConnectionState.ACCEPTED)) + verify(arrangement.connectionRepository) + .suspendFunction(arrangement.connectionRepository::updateConnectionStatus) + .with(eq(USER_ID), eq(ConnectionState.ACCEPTED)) .wasInvoked(once) } + private class Arrangement(private val block: Arrangement.() -> Unit) : + ConnectionRepositoryArrangement by ConnectionRepositoryArrangementImpl(), + ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl(), + OneOnOneResolverArrangement by OneOnOneResolverArrangementImpl() + { + fun arrange() = run { + block() + this@Arrangement to AcceptConnectionRequestUseCaseImpl( + connectionRepository = connectionRepository, + conversationRepository = conversationRepository, + oneOnOneResolver = oneOnOneResolver + ) + } + } + private companion object { - val userId = UserId("some_user", "some_domain") - val conversationId = ConversationId("someId", "someDomain") - val connection = Connection( + fun arrange(configuration: Arrangement.() -> Unit) = Arrangement(configuration).arrange() + + val USER_ID = UserId("some_user", "some_domain") + val CONVERSATION_ID = ConversationId("someId", "someDomain") + val CONNECTION = Connection( "someId", "from", "lastUpdate", - conversationId, - conversationId, + CONVERSATION_ID, + CONVERSATION_ID, ConnectionState.ACCEPTED, "toId", null ) } + } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/connection/SendConnectionRequestUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/connection/SendConnectionRequestUseCaseTest.kt index 802f37f47ef..47303fc917c 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/connection/SendConnectionRequestUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/connection/SendConnectionRequestUseCaseTest.kt @@ -22,8 +22,10 @@ import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.NetworkFailure import com.wire.kalium.logic.data.connection.ConnectionRepository import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.data.user.UserRepository import com.wire.kalium.logic.functional.Either import io.mockative.Mock +import io.mockative.any import io.mockative.classOf import io.mockative.eq import io.mockative.given @@ -41,6 +43,7 @@ class SendConnectionRequestUseCaseTest { // given val (arrangement, sendConnectionRequestUseCase) = Arrangement() .withCreateConnectionResult(Either.Right(Unit)) + .withFetchUserInfoResult(Either.Right(Unit)) .arrange() // when @@ -48,16 +51,44 @@ class SendConnectionRequestUseCaseTest { // then assertEquals(SendConnectionRequestResult.Success, resultOk) + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::fetchUserInfo) + .with(eq(userId)) + .wasInvoked(once) verify(arrangement.connectionRepository) .suspendFunction(arrangement.connectionRepository::sendUserConnection) .with(eq(userId)) .wasInvoked(once) } + @Test + fun givenAConnectionRequest_whenInvokingFetchUserInfoRequestFails_thenShouldReturnsAFailureResult() = runTest { + // given + val (arrangement, sendConnectionRequestUseCase) = Arrangement() + .withFetchUserInfoResult(Either.Left(CoreFailure.Unknown(RuntimeException("Some error")))) + .withCreateConnectionResult(Either.Right(Unit)) + .arrange() + + // when + val resultFailure = sendConnectionRequestUseCase(userId) + + // then + assertEquals(SendConnectionRequestResult.Failure.GenericFailure::class, resultFailure::class) + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::fetchUserInfo) + .with(eq(userId)) + .wasInvoked(once) + verify(arrangement.connectionRepository) + .suspendFunction(arrangement.connectionRepository::sendUserConnection) + .with(eq(userId)) + .wasNotInvoked() + } + @Test fun givenAConnectionRequest_whenInvokingASendAConnectionRequestFails_thenShouldReturnsAFailureResult() = runTest { // given val (arrangement, sendConnectionRequestUseCase) = Arrangement() + .withFetchUserInfoResult(Either.Right(Unit)) .withCreateConnectionResult(Either.Left(CoreFailure.Unknown(RuntimeException("Some error")))) .arrange() @@ -76,6 +107,7 @@ class SendConnectionRequestUseCaseTest { fun givenAConnectionRequest_whenInvokingAndFailsByFederationDenied_thenShouldReturnsAFederationDenied() = runTest { // given val (arrangement, sendConnectionRequestUseCase) = Arrangement() + .withFetchUserInfoResult(Either.Right(Unit)) .withCreateConnectionResult( Either.Left(NetworkFailure.FederatedBackendFailure.FederationDenied("federation-denied")) ) @@ -95,6 +127,9 @@ class SendConnectionRequestUseCaseTest { @Mock val connectionRepository = mock(classOf()) + @Mock + val userRepository = mock(classOf()) + fun withCreateConnectionResult(result: Either) = apply { given(connectionRepository) .suspendFunction(connectionRepository::sendUserConnection) @@ -102,7 +137,14 @@ class SendConnectionRequestUseCaseTest { .thenReturn(result) } - fun arrange() = this to SendConnectionRequestUseCaseImpl(connectionRepository) + fun withFetchUserInfoResult(result: Either) = apply { + given(userRepository) + .suspendFunction(userRepository::fetchUserInfo) + .whenInvokedWith(any()) + .thenReturn(result) + } + + fun arrange() = this to SendConnectionRequestUseCaseImpl(connectionRepository, userRepository) } private companion object { diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/GetAllContactsNotInTheConversationUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/GetAllContactsNotInTheConversationUseCaseTest.kt index af3dae1fdc2..3766553f92c 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/GetAllContactsNotInTheConversationUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/GetAllContactsNotInTheConversationUseCaseTest.kt @@ -93,7 +93,8 @@ class GetAllContactsNotInTheConversationUseCaseTest { botService = null, deleted = false, defederated = false, - isProteusVerified = false + isProteusVerified = false, + supportedProtocols = null ), OtherUser( id = QualifiedID("someAllContactsValue1", "someAllContactsDomain1"), @@ -111,7 +112,8 @@ class GetAllContactsNotInTheConversationUseCaseTest { botService = null, deleted = false, defederated = false, - isProteusVerified = false + isProteusVerified = false, + supportedProtocols = null ) ) } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/GetOrCreateOneToOneConversationUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/GetOrCreateOneToOneConversationUseCaseTest.kt index 0ec4cffb143..4015859c8b3 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/GetOrCreateOneToOneConversationUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/GetOrCreateOneToOneConversationUseCaseTest.kt @@ -18,120 +18,111 @@ package com.wire.kalium.logic.feature.conversation +import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.StorageFailure -import com.wire.kalium.logic.data.conversation.Conversation -import com.wire.kalium.logic.data.conversation.Conversation.ProtocolInfo -import com.wire.kalium.logic.data.conversation.ConversationGroupRepository -import com.wire.kalium.logic.data.conversation.ConversationRepository -import com.wire.kalium.logic.data.conversation.MutedConversationStatus -import com.wire.kalium.logic.data.id.ConversationId -import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.framework.TestConversation +import com.wire.kalium.logic.framework.TestUser import com.wire.kalium.logic.functional.Either -import com.wire.kalium.persistence.dao.conversation.ConversationEntity -import io.mockative.Mock +import com.wire.kalium.logic.util.arrangement.UserRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.UserRepositoryArrangementImpl +import com.wire.kalium.logic.util.arrangement.mls.OneOnOneResolverArrangement +import com.wire.kalium.logic.util.arrangement.mls.OneOnOneResolverArrangementImpl +import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangementImpl import io.mockative.anything -import io.mockative.classOf -import io.mockative.given -import io.mockative.mock +import io.mockative.eq import io.mockative.once import io.mockative.verify import kotlinx.coroutines.flow.flowOf import kotlinx.coroutines.test.runTest -import kotlin.test.BeforeTest import kotlin.test.Test import kotlin.test.assertIs +@Suppress("MaxLineLength") class GetOrCreateOneToOneConversationUseCaseTest { - @Mock - private val conversationRepository = mock(classOf()) - - @Mock - private val conversationGroupRepository = mock(classOf()) - - private lateinit var getOrCreateOneToOneConversationUseCase: GetOrCreateOneToOneConversationUseCase - - @BeforeTest - fun setUp() { - getOrCreateOneToOneConversationUseCase = GetOrCreateOneToOneConversationUseCase( - conversationRepository = conversationRepository, - conversationGroupRepository = conversationGroupRepository - ) - } - @Test - fun givenConversationDoesNotExist_whenCallingTheUseCase_ThenDoNotCreateAConversationButReturnExisting() = runTest { + fun givenConversationExist_whenCallingTheUseCase_ThenReturnExistingConversation() = runTest { // given - given(conversationRepository) - .suspendFunction(conversationRepository::observeOneToOneConversationWithOtherUser) - .whenInvokedWith(anything()) - .thenReturn(flowOf(Either.Right(CONVERSATION))) - - given(conversationRepository) - .suspendFunction(conversationGroupRepository::createGroupConversation) - .whenInvokedWith(anything(), anything(), anything()) - .thenReturn(Either.Right(CONVERSATION)) + val (arrangement, useCase) = arrange { + withObserveOneToOneConversationWithOtherUserReturning(Either.Right(CONVERSATION)) + } + // when - val result = getOrCreateOneToOneConversationUseCase.invoke(USER_ID) + val result = useCase.invoke(OTHER_USER_ID) + // then assertIs(result) - verify(conversationGroupRepository) - .suspendFunction(conversationGroupRepository::createGroupConversation) - .with(anything(), anything(), anything()) + verify(arrangement.oneOnOneResolver) + .suspendFunction(arrangement.oneOnOneResolver::resolveOneOnOneConversationWithUser) + .with(anything()) .wasNotInvoked() - verify(conversationRepository) - .suspendFunction(conversationRepository::observeOneToOneConversationWithOtherUser) + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::observeOneToOneConversationWithOtherUser) .with(anything()) .wasInvoked(exactly = once) } @Test - fun givenConversationExist_whenCallingTheUseCase_ThenCreateAConversationAndReturn() = runTest { + fun givenFailure_whenCallingTheUseCase_ThenErrorIsPropagated() = runTest { + // given + val (_, useCase) = arrange { + withObserveOneToOneConversationWithOtherUserReturning(Either.Left(StorageFailure.DataNotFound)) + withGetKnownUserReturning(flowOf(OTHER_USER)) + withResolveOneOnOneConversationWithUserReturning(Either.Left(CoreFailure.NoCommonProtocolFound)) + } + + // when + val result = useCase.invoke(OTHER_USER_ID) + + // then + assertIs(result) + } + + @Test + fun givenConversationDoesNotExist_whenCallingTheUseCase_ThenResolveOneOnOneConversation() = runTest { // given - given(conversationRepository) - .coroutine { observeOneToOneConversationWithOtherUser(USER_ID) } - .then { flowOf(Either.Left(StorageFailure.DataNotFound)) } - - given(conversationGroupRepository) - .suspendFunction(conversationGroupRepository::createGroupConversation) - .whenInvokedWith(anything(), anything(), anything()) - .thenReturn(Either.Right(CONVERSATION)) + val (arrangement, useCase) = arrange { + withObserveOneToOneConversationWithOtherUserReturning(Either.Left(StorageFailure.DataNotFound)) + withGetKnownUserReturning(flowOf(OTHER_USER)) + withResolveOneOnOneConversationWithUserReturning(Either.Right(CONVERSATION.id)) + withGetConversationByIdReturning(CONVERSATION) + } + // when - val result = getOrCreateOneToOneConversationUseCase.invoke(USER_ID) + val result = useCase.invoke(OTHER_USER_ID) + // then assertIs(result) - verify(conversationGroupRepository) - .coroutine { createGroupConversation(usersList = MEMBER) } + verify(arrangement.oneOnOneResolver) + .suspendFunction(arrangement.oneOnOneResolver::resolveOneOnOneConversationWithUser) + .with(eq(OTHER_USER)) .wasInvoked(exactly = once) } + private fun arrange(block: Arrangement.() -> Unit) = Arrangement(block).arrange() + + internal class Arrangement( + private val block: Arrangement.() -> Unit + ) : ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl(), + UserRepositoryArrangement by UserRepositoryArrangementImpl(), + OneOnOneResolverArrangement by OneOnOneResolverArrangementImpl() { + + fun arrange() = block().run { + this@Arrangement to GetOrCreateOneToOneConversationUseCaseImpl( + conversationRepository = conversationRepository, + userRepository = userRepository, + oneOnOneResolver = oneOnOneResolver + ) + } + } + private companion object { - val USER_ID = UserId(value = "userId", domain = "domainId") - val MEMBER = listOf(USER_ID) - val CONVERSATION_ID = ConversationId(value = "userId", domain = "domainId") - val CONVERSATION = Conversation( - id = CONVERSATION_ID, - name = null, - type = Conversation.Type.ONE_ON_ONE, - teamId = null, - ProtocolInfo.Proteus, - MutedConversationStatus.AllAllowed, - null, - null, - null, - lastReadDate = "2022-03-30T15:36:00.000Z", - access = listOf(Conversation.Access.CODE, Conversation.Access.INVITE), - accessRole = listOf(Conversation.AccessRole.NON_TEAM_MEMBER, Conversation.AccessRole.GUEST), - creatorId = null, - receiptMode = Conversation.ReceiptMode.DISABLED, - messageTimer = null, - userMessageTimer = null, - archived = false, - archivedDateTime = null, - verificationStatus = Conversation.VerificationStatus.NOT_VERIFIED - ) + val OTHER_USER = TestUser.OTHER + val OTHER_USER_ID = OTHER_USER.id + val CONVERSATION = TestConversation.ONE_ON_ONE() } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationUseCaseTest.kt index 3d49a3d1e0d..423d54578d1 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationUseCaseTest.kt @@ -27,11 +27,11 @@ import com.wire.kalium.logic.data.conversation.DecryptedMessageBundle import com.wire.kalium.logic.data.conversation.MLSConversationRepository import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.id.GroupID +import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.featureFlags.FeatureSupport import com.wire.kalium.logic.framework.TestConversation +import com.wire.kalium.logic.framework.TestUser import com.wire.kalium.logic.functional.Either -import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageUnpacker -import com.wire.kalium.logic.sync.receiver.conversation.message.MessageUnpackResult import com.wire.kalium.logic.util.shouldFail import com.wire.kalium.logic.util.shouldSucceed import com.wire.kalium.network.api.base.authenticated.conversation.ConversationApi @@ -50,6 +50,7 @@ import io.mockative.once import io.mockative.twice import io.mockative.verify import kotlinx.coroutines.test.runTest +import kotlinx.datetime.Instant import kotlin.test.Test class JoinExistingMLSConversationUseCaseTest { @@ -109,7 +110,7 @@ class JoinExistingMLSConversationUseCaseTest { } @Test - fun givenGroupConversationWithZeroEpoch_whenInvokingUseCase_ThenDoNotEstablishGroup() = + fun givenGroupConversationWithZeroEpoch_whenInvokingUseCase_ThenDoNotEstablishMlsGroup() = runTest { val (arrangement, joinExistingMLSConversationsUseCase) = Arrangement() .withIsMLSSupported(true) @@ -127,7 +128,7 @@ class JoinExistingMLSConversationUseCaseTest { } @Test - fun givenSelfConversationWithZeroEpoch_whenInvokingUseCase_ThenEstablishGroup() = + fun givenSelfConversationWithZeroEpoch_whenInvokingUseCase_ThenEstablishMlsGroup() = runTest { val (arrangement, joinExistingMLSConversationsUseCase) = Arrangement() .withIsMLSSupported(true) @@ -144,6 +145,26 @@ class JoinExistingMLSConversationUseCaseTest { .wasInvoked(once) } + @Test + fun givenOneOnOneConversationWithZeroEpoch_whenInvokingUseCase_ThenEstablishMlsGroup() = + runTest { + val members = listOf(TestUser.USER_ID, TestUser.OTHER_USER_ID) + val (arrangement, joinExistingMLSConversationsUseCase) = Arrangement() + .withIsMLSSupported(true) + .withHasRegisteredMLSClient(true) + .withGetConversationsByIdSuccessful(Arrangement.MLS_UNESTABLISHED_ONE_ONE_ONE_CONVERSATION) + .withGetConversationMembersSuccessful(members) + .withEstablishMLSGroupSuccessful() + .arrange() + + joinExistingMLSConversationsUseCase(Arrangement.MLS_UNESTABLISHED_ONE_ONE_ONE_CONVERSATION.id).shouldSucceed() + + verify(arrangement.mlsConversationRepository) + .suspendFunction(arrangement.mlsConversationRepository::establishMLSGroup) + .with(eq(Arrangement.GROUP_ID_ONE_ON_ONE), eq(members)) + .wasInvoked(once) + } + @Test fun givenOutOfDateEpochFailure_whenInvokingUseCase_ThenRetryWithNewEpoch() = runTest { val (arrangement, joinExistingMLSConversationsUseCase) = Arrangement() @@ -200,16 +221,12 @@ class JoinExistingMLSConversationUseCaseTest { @Mock val mlsConversationRepository = mock(classOf()) - @Mock - val mlsMessageUnpacker = mock(classOf()) - fun arrange() = this to JoinExistingMLSConversationUseCaseImpl( featureSupport, conversationApi, clientRepository, conversationRepository, - mlsConversationRepository, - mlsMessageUnpacker + mlsConversationRepository ) @Suppress("MaxLineLength") @@ -228,6 +245,13 @@ class JoinExistingMLSConversationUseCaseTest { .then { Either.Right(Unit) } } + fun withGetConversationMembersSuccessful(members: List) = apply { + given(conversationRepository) + .suspendFunction(conversationRepository::getConversationMembers) + .whenInvokedWith(anything()) + .then { Either.Right(members) } + } + fun withEstablishMLSGroupSuccessful() = apply { given(mlsConversationRepository) .suspendFunction(mlsConversationRepository::establishMLSGroup) @@ -270,13 +294,6 @@ class JoinExistingMLSConversationUseCaseTest { .thenReturn(Either.Right(result)) } - fun withUnpackMlsBundleSuccessful() = apply { - given(mlsMessageUnpacker) - .suspendFunction(mlsMessageUnpacker::unpackMlsBundle) - .whenInvokedWith(anything()) - .thenReturn(MessageUnpackResult.HandshakeMessage) - } - companion object { val PUBLIC_GROUP_STATE = "public_group_state".encodeToByteArray() @@ -303,12 +320,13 @@ class JoinExistingMLSConversationUseCaseTest { val GROUP_ID1 = GroupID("group1") val GROUP_ID2 = GroupID("group2") val GROUP_ID3 = GroupID("group3") + val GROUP_ID_ONE_ON_ONE = GroupID("group-one-on-ne") val GROUP_ID_SELF = GroupID("group-self") val MLS_CONVERSATION1 = TestConversation.GROUP( Conversation.ProtocolInfo.MLS( GROUP_ID1, - Conversation.ProtocolInfo.MLS.GroupState.PENDING_JOIN, + Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_JOIN, epoch = 1UL, keyingMaterialLastUpdate = DateTimeUtil.currentInstant(), cipherSuite = Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 @@ -318,7 +336,7 @@ class JoinExistingMLSConversationUseCaseTest { val MLS_CONVERSATION2 = TestConversation.GROUP( Conversation.ProtocolInfo.MLS( GROUP_ID2, - Conversation.ProtocolInfo.MLS.GroupState.PENDING_JOIN, + Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_JOIN, epoch = 1UL, keyingMaterialLastUpdate = DateTimeUtil.currentInstant(), cipherSuite = Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 @@ -328,7 +346,7 @@ class JoinExistingMLSConversationUseCaseTest { val MLS_UNESTABLISHED_GROUP_CONVERSATION = TestConversation.GROUP( Conversation.ProtocolInfo.MLS( GROUP_ID3, - Conversation.ProtocolInfo.MLS.GroupState.PENDING_JOIN, + Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_JOIN, epoch = 0UL, keyingMaterialLastUpdate = DateTimeUtil.currentInstant(), cipherSuite = Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 @@ -338,12 +356,22 @@ class JoinExistingMLSConversationUseCaseTest { val MLS_UNESTABLISHED_SELF_CONVERSATION = TestConversation.SELF( Conversation.ProtocolInfo.MLS( GROUP_ID_SELF, - Conversation.ProtocolInfo.MLS.GroupState.PENDING_JOIN, + Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_JOIN, epoch = 0UL, keyingMaterialLastUpdate = DateTimeUtil.currentInstant(), cipherSuite = Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 ) ).copy(id = ConversationId("self", "domain")) + + val MLS_UNESTABLISHED_ONE_ONE_ONE_CONVERSATION = TestConversation.ONE_ON_ONE( + Conversation.ProtocolInfo.MLS( + GROUP_ID_ONE_ON_ONE, + Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_JOIN, + epoch = 0UL, + keyingMaterialLastUpdate = DateTimeUtil.currentInstant(), + cipherSuite = Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 + ) + ).copy(id = ConversationId("one-on-one", "domain")) } } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationsUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationsUseCaseTest.kt index b83a188f6cb..616404cbe61 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationsUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/JoinExistingMLSConversationsUseCaseTest.kt @@ -39,6 +39,7 @@ import io.mockative.twice import io.mockative.verify import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.test.runTest +import kotlinx.datetime.Instant import kotlin.test.Test @OptIn(ExperimentalCoroutinesApi::class) @@ -167,7 +168,7 @@ class JoinExistingMLSConversationsUseCaseTest { val MLS_CONVERSATION1 = TestConversation.GROUP( Conversation.ProtocolInfo.MLS( GROUP_ID1, - Conversation.ProtocolInfo.MLS.GroupState.PENDING_JOIN, + Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_JOIN, epoch = 1UL, keyingMaterialLastUpdate = DateTimeUtil.currentInstant(), cipherSuite = Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 @@ -177,7 +178,7 @@ class JoinExistingMLSConversationsUseCaseTest { val MLS_CONVERSATION2 = TestConversation.GROUP( Conversation.ProtocolInfo.MLS( GROUP_ID2, - Conversation.ProtocolInfo.MLS.GroupState.PENDING_JOIN, + Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_JOIN, epoch = 1UL, keyingMaterialLastUpdate = DateTimeUtil.currentInstant(), cipherSuite = Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/MembersToMentionUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/MembersToMentionUseCaseTest.kt index b1d1c6e0ea6..4755fbe5a05 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/MembersToMentionUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/MembersToMentionUseCaseTest.kt @@ -122,7 +122,8 @@ class MembersToMentionUseCaseTest { ConnectionState.ACCEPTED, UserAssetId("value1", DOMAIN), UserAssetId("value2", DOMAIN), - UserAvailabilityStatus.NONE + UserAvailabilityStatus.NONE, + supportedProtocols = null ) private val OTHER_USER = OtherUser( UserId(value = "other-id", DOMAIN), @@ -140,7 +141,8 @@ class MembersToMentionUseCaseTest { botService = null, deleted = false, defederated = false, - isProteusVerified = false + isProteusVerified = false, + supportedProtocols = null ) val members = listOf( diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/ObserveConversationListDetailsUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/ObserveConversationListDetailsUseCaseTest.kt index 05fd44e0560..84b839fcc9c 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/ObserveConversationListDetailsUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/ObserveConversationListDetailsUseCaseTest.kt @@ -179,7 +179,7 @@ class ObserveConversationListDetailsUseCaseTest { @Test fun givenSomeConversationsDetailsAreUpdated_whenObservingDetailsList_thenTheUpdateIsPropagatedThroughTheFlow() = runTest { // Given - val oneOnOneConversation = TestConversation.ONE_ON_ONE + val oneOnOneConversation = TestConversation.ONE_ON_ONE() val groupConversation = TestConversation.GROUP() val conversations = listOf(groupConversation, oneOnOneConversation) val fetchArchivedConversations = false @@ -342,9 +342,9 @@ class ObserveConversationListDetailsUseCaseTest { @Test fun givenConversationDetailsFailure_whenObservingDetailsList_thenIgnoreConversationWithFailure() = runTest { // Given - val successConversation = TestConversation.ONE_ON_ONE.copy(id = ConversationId("successId", "domain")) + val successConversation = TestConversation.ONE_ON_ONE().copy(id = ConversationId("successId", "domain")) val successConversationDetails = TestConversationDetails.CONVERSATION_ONE_ONE.copy(conversation = successConversation) - val failureConversation = TestConversation.ONE_ON_ONE.copy(id = ConversationId("failedId", "domain")) + val failureConversation = TestConversation.ONE_ON_ONE().copy(id = ConversationId("failedId", "domain")) val fetchArchivedConversations = false val (_, observeConversationsUseCase) = Arrangement() diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/RecoverMLSConversationsUseCaseTests.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/RecoverMLSConversationsUseCaseTests.kt index 80c9a4b11f8..51abef4ebf7 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/RecoverMLSConversationsUseCaseTests.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/RecoverMLSConversationsUseCaseTests.kt @@ -43,6 +43,7 @@ import io.mockative.twice import io.mockative.verify import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.test.runTest +import kotlinx.datetime.Instant import kotlin.test.Test import kotlin.test.assertIs @@ -260,7 +261,7 @@ class RecoverMLSConversationsUseCaseTests { val MLS_CONVERSATION1 = TestConversation.GROUP( Conversation.ProtocolInfo.MLS( GROUP_ID1, - Conversation.ProtocolInfo.MLS.GroupState.PENDING_JOIN, + Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_JOIN, epoch = 1UL, keyingMaterialLastUpdate = DateTimeUtil.currentInstant(), cipherSuite = Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 @@ -270,7 +271,7 @@ class RecoverMLSConversationsUseCaseTests { val MLS_CONVERSATION2 = TestConversation.GROUP( Conversation.ProtocolInfo.MLS( GROUP_ID2, - Conversation.ProtocolInfo.MLS.GroupState.PENDING_JOIN, + Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_JOIN, epoch = 1UL, keyingMaterialLastUpdate = DateTimeUtil.currentInstant(), cipherSuite = Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/SyncConversationsUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/SyncConversationsUseCaseTest.kt new file mode 100644 index 00000000000..3909ce9a9e7 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/SyncConversationsUseCaseTest.kt @@ -0,0 +1,122 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.conversation + +import com.wire.kalium.logic.data.conversation.Conversation +import com.wire.kalium.logic.data.conversation.ConversationRepository +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.data.message.SystemMessageInserter +import com.wire.kalium.logic.framework.TestConversation +import com.wire.kalium.logic.functional.Either +import io.mockative.any +import io.mockative.eq +import io.mockative.given +import io.mockative.mock +import io.mockative.once +import io.mockative.thenDoNothing +import io.mockative.verify +import kotlinx.coroutines.test.runTest +import kotlin.test.Test + +class SyncConversationsUseCaseTest { + @Test + fun givenUseCase_whenInvoked_thenFetchConversations() = runTest { + + val (arrangement, useCase) = Arrangement() + .withGetConversationsIdsReturning(emptyList()) + .withFetchConversationsSuccessful() + .arrange() + + useCase.invoke() + + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::fetchConversations) + .wasInvoked(exactly = once) + } + + @Test + fun givenProtocolChanges_whenInvoked_thenInsertHistoryLostSystemMessage() = runTest { + val conversationId = TestConversation.ID + val (arrangement, useCase) = Arrangement() + .withGetConversationsIdsReturning(listOf(conversationId), protocol = Conversation.Protocol.PROTEUS) + .withFetchConversationsSuccessful() + .withGetConversationsIdsReturning(listOf(conversationId), protocol = Conversation.Protocol.MLS) + .withInsertHistoryLostProtocolChangedSystemMessageSuccessful() + .arrange() + + useCase.invoke() + + verify(arrangement.systemMessageInserter) + .suspendFunction(arrangement.systemMessageInserter::insertHistoryLostProtocolChangedSystemMessage) + .with(eq(conversationId)) + .wasInvoked(exactly = once) + } + + @Test + fun givenProtocolIsUnchanged_whenInvoked_thenDoNotInsertHistoryLostSystemMessage() = runTest { + val conversationId = TestConversation.ID + val (arrangement, useCase) = Arrangement() + .withGetConversationsIdsReturning(listOf(conversationId), protocol = Conversation.Protocol.PROTEUS) + .withFetchConversationsSuccessful() + .withGetConversationsIdsReturning(emptyList(), protocol = Conversation.Protocol.MLS) + .withInsertHistoryLostProtocolChangedSystemMessageSuccessful() + .arrange() + + useCase.invoke() + + verify(arrangement.systemMessageInserter) + .suspendFunction(arrangement.systemMessageInserter::insertHistoryLostProtocolChangedSystemMessage) + .with(eq(conversationId)) + .wasNotInvoked() + } + + private class Arrangement { + + val conversationRepository = mock(ConversationRepository::class) + val systemMessageInserter = mock(SystemMessageInserter::class) + + fun withFetchConversationsSuccessful() = apply { + given(conversationRepository) + .suspendFunction(conversationRepository::fetchConversations) + .whenInvoked() + .thenReturn(Either.Right(Unit)) + } + + fun withGetConversationsIdsReturning( + conversationIds: List, + protocol: Conversation.Protocol? = null + ) = apply { + given(conversationRepository) + .suspendFunction(conversationRepository::getConversationIds) + .whenInvokedWith(eq(Conversation.Type.GROUP), protocol?.let { eq(it) } ?: any(), eq(null)) + .thenReturn(Either.Right(conversationIds)) + } + + fun withInsertHistoryLostProtocolChangedSystemMessageSuccessful() = apply { + given(systemMessageInserter) + .suspendFunction(systemMessageInserter::insertHistoryLostProtocolChangedSystemMessage) + .whenInvokedWith() + .thenDoNothing() + } + + fun arrange() = this to SyncConversationsUseCaseImpl( + conversationRepository, + systemMessageInserter + ) + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/mls/MLSOneOnOneConversationResolverTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/mls/MLSOneOnOneConversationResolverTest.kt new file mode 100644 index 00000000000..4eca8bb46e2 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/mls/MLSOneOnOneConversationResolverTest.kt @@ -0,0 +1,188 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.conversation.mls + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.data.conversation.Conversation +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.framework.TestConversation +import com.wire.kalium.logic.framework.TestUser +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangementImpl +import com.wire.kalium.logic.util.arrangement.usecase.JoinExistingMLSConversationUseCaseArrangement +import com.wire.kalium.logic.util.arrangement.usecase.JoinExistingMLSConversationUseCaseArrangementImpl +import com.wire.kalium.logic.util.shouldFail +import com.wire.kalium.logic.util.shouldSucceed +import io.mockative.any +import io.mockative.eq +import io.mockative.once +import io.mockative.verify +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals + +class MLSOneOnOneConversationResolverTest { + + @Test + fun givenAUserId_whenInvokingUseCase_shouldPassCorrectUserIdWhenGettingConversationsForUser() = runTest { + val (arrangement, getOrEstablishMlsOneToOneUseCase) = arrange { + withConversationsForUserIdReturning(Either.Right(ALL_CONVERSATIONS)) + } + + getOrEstablishMlsOneToOneUseCase(userId) + + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::getConversationsByUserId) + .with(eq(userId)) + .wasInvoked(exactly = once) + } + + @Test + fun givenFailureWhenGettingConversations_thenShouldPropagateFailureAndAvoidUnnecessaryCalls() = runTest { + val cause = CoreFailure.Unknown(null) + val (arrangement, getOrEstablishMlsOneToOneUseCase) = arrange { + withConversationsForUserIdReturning(Either.Left(cause)) + withJoinExistingMLSConversationUseCaseReturning(Either.Right(Unit)) + } + + val result = getOrEstablishMlsOneToOneUseCase(userId) + + result.shouldFail { + assertEquals(cause, it) + } + + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::fetchMlsOneToOneConversation) + .with(any()) + .wasNotInvoked() + + verify(arrangement.joinExistingMLSConversationUseCase) + .suspendFunction(arrangement.joinExistingMLSConversationUseCase::invoke) + .with(any()) + .wasNotInvoked() + } + + @Test + fun givenOneOnOneMLSConversationAlreadyExists_thenShouldReturnIt() = runTest { + val (_, getOrEstablishMlsOneToOneUseCase) = arrange { + withConversationsForUserIdReturning(Either.Right(ALL_CONVERSATIONS)) + } + + val result = getOrEstablishMlsOneToOneUseCase(userId) + + result.shouldSucceed { + assertEquals(CONVERSATION_ONE_ON_ONE_MLS_ESTABLISHED.id, it) + } + } + + @Test + fun givenNoInitializedMLSAndFetchingFails_thenShouldPropagateFailure() = runTest { + val cause = CoreFailure.Unknown(null) + val (_, getOrEstablishMlsOneToOneUseCase) = arrange { + withConversationsForUserIdReturning( + Either.Right( + ALL_CONVERSATIONS - CONVERSATION_ONE_ON_ONE_MLS_ESTABLISHED + ) + ) + withFetchMlsOneToOneConversation(Either.Left(cause)) + } + + val result = getOrEstablishMlsOneToOneUseCase(userId) + + result.shouldFail { + assertEquals(cause, it) + } + } + + @Test + fun givenNoInitializedMLSAndFetchingSucceeds_thenShouldJoinAndAndReturnIt() = runTest { + val (arrangement, getOrEstablishMlsOneToOneUseCase) = arrange { + withConversationsForUserIdReturning( + Either.Right( + ALL_CONVERSATIONS - CONVERSATION_ONE_ON_ONE_MLS_ESTABLISHED + ) + ) + withFetchMlsOneToOneConversation(Either.Right(CONVERSATION_ONE_ON_ONE_MLS_ESTABLISHED)) + withJoinExistingMLSConversationUseCaseReturning(Either.Right(Unit)) + } + + val result = getOrEstablishMlsOneToOneUseCase(userId) + + result.shouldSucceed { + assertEquals(CONVERSATION_ONE_ON_ONE_MLS_ESTABLISHED.id, it) + } + + verify(arrangement.joinExistingMLSConversationUseCase) + .suspendFunction(arrangement.joinExistingMLSConversationUseCase::invoke) + .with(any()) + .wasInvoked(exactly = once) + } + + private fun arrange(block: Arrangement.() -> Unit) = Arrangement(block).arrange() + + private class Arrangement( + private val block: Arrangement.() -> Unit + ) : ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl(), + JoinExistingMLSConversationUseCaseArrangement by JoinExistingMLSConversationUseCaseArrangementImpl() { + + fun arrange() = block().let { + this to MLSOneOnOneConversationResolverImpl( + conversationRepository = conversationRepository, + joinExistingMLSConversationUseCase = joinExistingMLSConversationUseCase, + ) + } + } + + private companion object { + private val userId = TestUser.USER_ID + + private val CONVERSATION_ONE_ON_ONE_PROTEUS = TestConversation.ONE_ON_ONE().copy( + id = ConversationId("one-on-one-proteus", "test"), + protocol = Conversation.ProtocolInfo.Proteus, + ) + + private val CONVERSATION_ONE_ON_ONE_MLS_NOT_ESTABLISHED = CONVERSATION_ONE_ON_ONE_PROTEUS.copy( + id = ConversationId("one-on-one-mls-NOT-initialized", "test"), + protocol = TestConversation.MLS_PROTOCOL_INFO.copy( + groupState = Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_CREATION, + epoch = 0U + ), + ) + + private val CONVERSATION_ONE_ON_ONE_MLS_ESTABLISHED = CONVERSATION_ONE_ON_ONE_MLS_NOT_ESTABLISHED.copy( + id = ConversationId("one-on-one-mls-initialized", "test"), + protocol = TestConversation.MLS_PROTOCOL_INFO.copy( + groupState = Conversation.ProtocolInfo.MLSCapable.GroupState.ESTABLISHED, + epoch = 0U + ), + ) + + private val CONVERSATION_GROUP_MLS_INITIALIZED = CONVERSATION_ONE_ON_ONE_MLS_ESTABLISHED.copy( + id = ConversationId("group-mls-initialized", "test"), + type = Conversation.Type.GROUP + ) + + private val ALL_CONVERSATIONS = listOf( + CONVERSATION_ONE_ON_ONE_PROTEUS, + CONVERSATION_ONE_ON_ONE_MLS_NOT_ESTABLISHED, + CONVERSATION_ONE_ON_ONE_MLS_ESTABLISHED, + CONVERSATION_GROUP_MLS_INITIALIZED, + ) + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/mls/OneOnOneMigratorTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/mls/OneOnOneMigratorTest.kt new file mode 100644 index 00000000000..3585bdaeff9 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/mls/OneOnOneMigratorTest.kt @@ -0,0 +1,265 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.conversation.mls + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.StorageFailure +import com.wire.kalium.logic.data.conversation.ConversationOptions +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.framework.TestConversation +import com.wire.kalium.logic.framework.TestUser +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.util.arrangement.mls.MLSOneOnOneConversationResolverArrangement +import com.wire.kalium.logic.util.arrangement.mls.MLSOneOnOneConversationResolverArrangementImpl +import com.wire.kalium.logic.util.arrangement.repository.ConversationGroupRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.repository.ConversationGroupRepositoryArrangementImpl +import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangementImpl +import com.wire.kalium.logic.util.arrangement.repository.MessageRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.repository.MessageRepositoryArrangementImpl +import com.wire.kalium.logic.util.arrangement.UserRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.UserRepositoryArrangementImpl +import com.wire.kalium.logic.util.shouldFail +import com.wire.kalium.logic.util.shouldSucceed +import io.mockative.any +import io.mockative.eq +import io.mockative.once +import io.mockative.verify +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals + +class OneOnOneMigratorTest { + + @Test + fun givenOneOnOneIsAlreadyProteus_whenMigratingToProteus_thenShouldNotDoAnythingElseAndSucceed() = runTest { + val user = TestUser.OTHER.copy( + activeOneOnOneConversationId = TestConversation.ID + ) + + val (arrangement, oneOneMigrator) = arrange { + withGetOneOnOneConversationsWithOtherUserReturning(Either.Right(listOf(TestConversation.ID))) + } + + oneOneMigrator.migrateToProteus(user) + .shouldSucceed() + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::updateActiveOneOnOneConversation) + .with(any(), any()) + .wasNotInvoked() + } + + @Test + fun givenUnassignedOneOnOne_whenMigratingToProteus_thenShouldAssignOneOnOneConversation() = runTest { + val user = TestUser.OTHER.copy( + activeOneOnOneConversationId = null + ) + + val (arrangement, oneOneMigrator) = arrange { + withGetOneOnOneConversationsWithOtherUserReturning(Either.Right(listOf(TestConversation.ID))) + withUpdateOneOnOneConversationReturning(Either.Right(Unit)) + } + + oneOneMigrator.migrateToProteus(user) + .shouldSucceed() + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::updateActiveOneOnOneConversation) + .with(eq(user.id), eq(TestConversation.ID)) + .wasInvoked() + } + + @Test + fun givenNoExistingTeamOneOnOne_whenMigratingToProteus_thenShouldCreateGroupConversation() = runTest { + val user = TestUser.OTHER.copy( + activeOneOnOneConversationId = null + ) + + val (arrangement, oneOneMigrator) = arrange { + withGetOneOnOneConversationsWithOtherUserReturning(Either.Right(emptyList())) + withCreateGroupConversationReturning(Either.Right(TestConversation.ONE_ON_ONE())) + withUpdateOneOnOneConversationReturning(Either.Right(Unit)) + } + + oneOneMigrator.migrateToProteus(user) + .shouldSucceed() + + verify(arrangement.conversationGroupRepository) + .suspendFunction(arrangement.conversationGroupRepository::createGroupConversation) + .with(eq(null), eq(listOf(TestUser.OTHER.id)), eq(ConversationOptions())) + .wasInvoked() + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::updateActiveOneOnOneConversation) + .with(eq(TestUser.OTHER.id), eq(TestConversation.ONE_ON_ONE().id)) + .wasInvoked() + } + + @Test + fun givenOneOnOneIsAlreadyMLS_whenMigratingToMLS_thenShouldNotDoAnythingElseAndSucceed() = runTest { + val user = TestUser.OTHER.copy( + activeOneOnOneConversationId = TestConversation.ID + ) + + val (arrangement, oneOneMigrator) = arrange { + withResolveConversationReturning(Either.Right(TestConversation.ID)) + } + + oneOneMigrator.migrateToMLS(user) + .shouldSucceed() + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::updateActiveOneOnOneConversation) + .with(any(), any()) + .wasNotInvoked() + + verify(arrangement.messageRepository) + .suspendFunction(arrangement.messageRepository::moveMessagesToAnotherConversation) + .with(any(), any()) + .wasNotInvoked() + } + + @Test + fun givenResolvingMLSConversationFails_whenMigratingToMLS_thenShouldPropagateFailure() = runTest { + val user = TestUser.OTHER.copy( + activeOneOnOneConversationId = null + ) + val failure = CoreFailure.MissingClientRegistration + + val (_, oneOnOneMigrator) = arrange { + withResolveConversationReturning(Either.Left(failure)) + } + + oneOnOneMigrator.migrateToMLS(user) + .shouldFail { + assertEquals(failure, it) + } + } + + @Test + fun givenMigratingMessagesFails_whenMigratingToMLS_thenShouldPropagateFailureAndNotUpdateConversation() = runTest { + val failure = StorageFailure.DataNotFound + val user = TestUser.OTHER.copy( + activeOneOnOneConversationId = null + ) + val (arrangement, oneOnOneMigrator) = arrange { + withResolveConversationReturning(Either.Right(TestConversation.ID)) + withGetOneOnOneConversationsWithOtherUserReturning(Either.Right(listOf(TestConversation.ID))) + withMoveMessagesToAnotherConversation(Either.Left(failure)) + } + + oneOnOneMigrator.migrateToMLS(user) + .shouldFail { + assertEquals(failure, it) + } + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::updateActiveOneOnOneConversation) + .with(any(), any()) + .wasNotInvoked() + } + + @Test + fun givenUpdatingOneOnOneConversationFails_whenMigratingToMLS_thenShouldPropagateFailure() = runTest { + val failure = StorageFailure.DataNotFound + val user = TestUser.OTHER.copy( + activeOneOnOneConversationId = null + ) + val (_, oneOnOneMigrator) = arrange { + withResolveConversationReturning(Either.Right(TestConversation.ID)) + withGetOneOnOneConversationsWithOtherUserReturning(Either.Right(listOf(TestConversation.ID))) + withMoveMessagesToAnotherConversation(Either.Right(Unit)) + withUpdateOneOnOneConversationReturning(Either.Left(failure)) + } + + oneOnOneMigrator.migrateToMLS(user) + .shouldFail { + assertEquals(failure, it) + } + } + + @Test + fun givenResolvedMLSConversation_whenMigratingToMLS_thenShouldMoveMessagesCorrectly() = runTest { + val originalConversationId = ConversationId("someRandomConversationId", "testDomain") + val resolvedConversationId = ConversationId("resolvedMLSConversationId", "anotherDomain") + val user = TestUser.OTHER.copy( + activeOneOnOneConversationId = null + ) + val (arrangement, oneOnOneMigrator) = arrange { + withResolveConversationReturning(Either.Right(resolvedConversationId)) + withGetOneOnOneConversationsWithOtherUserReturning(Either.Right(listOf(originalConversationId))) + withMoveMessagesToAnotherConversation(Either.Right(Unit)) + withUpdateOneOnOneConversationReturning(Either.Right(Unit)) + } + + oneOnOneMigrator.migrateToMLS(user) + .shouldSucceed() + + verify(arrangement.messageRepository) + .suspendFunction(arrangement.messageRepository::moveMessagesToAnotherConversation) + .with(eq(originalConversationId), eq(resolvedConversationId)) + .wasInvoked(exactly = once) + } + + @Test + fun givenResolvedMLSConversation_whenMigratingToMLS_thenCallRepositoryWithCorrectArguments() = runTest { + val originalConversationId = ConversationId("someRandomConversationId", "testDomain") + val resolvedConversationId = ConversationId("resolvedMLSConversationId", "anotherDomain") + val user = TestUser.OTHER.copy( + activeOneOnOneConversationId = originalConversationId + ) + val (arrangement, oneOnOneMigrator) = arrange { + withResolveConversationReturning(Either.Right(resolvedConversationId)) + withGetOneOnOneConversationsWithOtherUserReturning(Either.Right(listOf(originalConversationId))) + withMoveMessagesToAnotherConversation(Either.Right(Unit)) + withUpdateOneOnOneConversationReturning(Either.Right(Unit)) + } + + oneOnOneMigrator.migrateToMLS(user) + .shouldSucceed() + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::updateActiveOneOnOneConversation) + .with(eq(user.id), eq(resolvedConversationId)) + .wasInvoked(exactly = once) + } + + private class Arrangement(private val block: Arrangement.() -> Unit) : + MLSOneOnOneConversationResolverArrangement by MLSOneOnOneConversationResolverArrangementImpl(), + MessageRepositoryArrangement by MessageRepositoryArrangementImpl(), + ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl(), + ConversationGroupRepositoryArrangement by ConversationGroupRepositoryArrangementImpl(), + UserRepositoryArrangement by UserRepositoryArrangementImpl() + { + fun arrange() = run { + block() + this@Arrangement to OneOnOneMigratorImpl( + getResolvedMLSOneOnOne = mlsOneOnOneConversationResolver, + conversationGroupRepository = conversationGroupRepository, + conversationRepository = conversationRepository, + messageRepository = messageRepository, + userRepository = userRepository + ) + } + } + + private companion object { + fun arrange(configuration: Arrangement.() -> Unit) = Arrangement(configuration).arrange() + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/mls/OneOnOneResolverTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/mls/OneOnOneResolverTest.kt new file mode 100644 index 00000000000..7b0ddf225ab --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/conversation/mls/OneOnOneResolverTest.kt @@ -0,0 +1,143 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.conversation.mls + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.data.user.SupportedProtocol +import com.wire.kalium.logic.framework.TestConversation +import com.wire.kalium.logic.framework.TestUser +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.util.arrangement.IncrementalSyncRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.IncrementalSyncRepositoryArrangementImpl +import com.wire.kalium.logic.util.arrangement.UserRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.UserRepositoryArrangementImpl +import com.wire.kalium.logic.util.arrangement.mls.OneOnOneMigratorArrangement +import com.wire.kalium.logic.util.arrangement.mls.OneOnOneMigratorArrangementImpl +import com.wire.kalium.logic.util.arrangement.protocol.OneOnOneProtocolSelectorArrangement +import com.wire.kalium.logic.util.arrangement.protocol.OneOnOneProtocolSelectorArrangementImpl +import com.wire.kalium.logic.util.shouldFail +import com.wire.kalium.logic.util.shouldSucceed +import io.mockative.eq +import io.mockative.given +import io.mockative.matchers.OneOfMatcher +import io.mockative.once +import io.mockative.twice +import io.mockative.verify +import kotlinx.coroutines.test.runTest +import kotlin.test.Test + +class OneOnOneResolverTest { + + @Test + fun givenListOneOnOneUsers_whenResolveAllOneOnOneConversations_thenResolveOneOnOneForEachUser() = runTest { + // given + val oneOnOneUsers = listOf(TestUser.OTHER.copy(id = TestUser.OTHER_USER_ID), TestUser.OTHER.copy(id = TestUser.OTHER_USER_ID_2)) + val (arrangement, resolver) = arrange { + withGetUsersWithOneOnOneConversationReturning(oneOnOneUsers) + withGetProtocolForUser(Either.Right(SupportedProtocol.MLS)) + withMigrateToMLSReturns(Either.Right(TestConversation.ID)) + } + + // when + resolver.resolveAllOneOnOneConversations().shouldSucceed() + + // then + verify(arrangement.oneOnOneProtocolSelector) + .suspendFunction(arrangement.oneOnOneProtocolSelector::getProtocolForUser) + .with(OneOfMatcher(oneOnOneUsers.map { it.id })) + .wasInvoked(exactly = twice) + } + + @Test + fun givenResolvingOneConversationFails_whenResolveAllOneOnOneConversations_thenTheWholeOperationFails() = runTest { + // given + val oneOnOneUsers = listOf(TestUser.OTHER.copy(id = TestUser.OTHER_USER_ID), TestUser.OTHER.copy(id = TestUser.OTHER_USER_ID_2)) + val (arrangement, resolver) = arrange { + withGetUsersWithOneOnOneConversationReturning(oneOnOneUsers) + withGetProtocolForUser(Either.Right(SupportedProtocol.MLS)) + withMigrateToMLSReturns(Either.Right(TestConversation.ID)) + } + + given(arrangement.oneOnOneMigrator) + .suspendFunction(arrangement.oneOnOneMigrator::migrateToMLS) + .whenInvokedWith(eq(oneOnOneUsers.last())) + .thenReturn(Either.Left(CoreFailure.Unknown(null))) + + // when then + resolver.resolveAllOneOnOneConversations().shouldFail() + } + + @Test + fun givenProtocolResolvesToMLS_whenResolveOneOnOneConversationWithUser_thenMigrateToMLS() = runTest { + // given + val (arrangement, resolver) = arrange { + withGetProtocolForUser(Either.Right(SupportedProtocol.MLS)) + withMigrateToMLSReturns(Either.Right(TestConversation.ID)) + } + + // when + resolver.resolveOneOnOneConversationWithUser(OTHER_USER).shouldSucceed() + + // then + verify(arrangement.oneOnOneMigrator) + .suspendFunction(arrangement.oneOnOneMigrator::migrateToMLS) + .with(eq(OTHER_USER)) + .wasInvoked(exactly = once) + } + + @Test + fun givenProtocolResolvesToProteus_whenResolveOneOnOneConversationWithUser_thenMigrateToProteus() = runTest { + // given + val (arrangement, resolver) = arrange { + withGetProtocolForUser(Either.Right(SupportedProtocol.PROTEUS)) + withMigrateToProteusReturns(Either.Right(TestConversation.ID)) + } + + // when + resolver.resolveOneOnOneConversationWithUser(OTHER_USER).shouldSucceed() + + // then + verify(arrangement.oneOnOneMigrator) + .suspendFunction(arrangement.oneOnOneMigrator::migrateToProteus) + .with(eq(OTHER_USER)) + .wasInvoked(exactly = once) + } + + private class Arrangement(private val block: Arrangement.() -> Unit) : + UserRepositoryArrangement by UserRepositoryArrangementImpl(), + OneOnOneProtocolSelectorArrangement by OneOnOneProtocolSelectorArrangementImpl(), + OneOnOneMigratorArrangement by OneOnOneMigratorArrangementImpl(), + IncrementalSyncRepositoryArrangement by IncrementalSyncRepositoryArrangementImpl() + { + fun arrange() = run { + block() + this@Arrangement to OneOnOneResolverImpl( + userRepository = userRepository, + oneOnOneProtocolSelector = oneOnOneProtocolSelector, + oneOnOneMigrator = oneOnOneMigrator, + incrementalSyncRepository = incrementalSyncRepository + ) + } + } + + private companion object { + fun arrange(configuration: Arrangement.() -> Unit) = Arrangement(configuration).arrange() + + val OTHER_USER = TestUser.OTHER + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/featureConfig/SyncFeatureConfigsUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/featureConfig/SyncFeatureConfigsUseCaseTest.kt index da8c8ff5f2d..9e4e26b3e4f 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/featureConfig/SyncFeatureConfigsUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/featureConfig/SyncFeatureConfigsUseCaseTest.kt @@ -29,7 +29,6 @@ import com.wire.kalium.logic.data.featureConfig.FeatureConfigRepository import com.wire.kalium.logic.data.featureConfig.FeatureConfigTest import com.wire.kalium.logic.data.featureConfig.E2EIConfigModel import com.wire.kalium.logic.data.featureConfig.E2EIModel -import com.wire.kalium.logic.data.featureConfig.MLSModel import com.wire.kalium.logic.data.featureConfig.SelfDeletingMessagesConfigModel import com.wire.kalium.logic.data.featureConfig.SelfDeletingMessagesModel import com.wire.kalium.logic.data.featureConfig.Status @@ -40,10 +39,12 @@ import com.wire.kalium.logic.feature.featureConfig.handler.E2EIConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.FileSharingConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.GuestRoomConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.MLSConfigHandler +import com.wire.kalium.logic.feature.featureConfig.handler.MLSMigrationConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.SecondFactorPasswordChallengeConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.SelfDeletingMessagesConfigHandler import com.wire.kalium.logic.feature.selfDeletingMessages.SelfDeletionMapper.toTeamSelfDeleteTimer import com.wire.kalium.logic.feature.selfDeletingMessages.TeamSelfDeleteTimer +import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsAndResolveOneOnOnesUseCase import com.wire.kalium.logic.featureFlags.BuildFileRestrictionState import com.wire.kalium.logic.featureFlags.KaliumConfigs import com.wire.kalium.logic.framework.TestUser @@ -51,6 +52,7 @@ import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.test_util.TestNetworkException import com.wire.kalium.logic.util.shouldSucceed import com.wire.kalium.persistence.config.inMemoryUserConfigStorage +import com.wire.kalium.persistence.dao.SupportedProtocolEntity import com.wire.kalium.persistence.dao.unread.UserConfigDAO import io.mockative.Mock import io.mockative.any @@ -82,6 +84,7 @@ class SyncFeatureConfigsUseCaseTest { ) ) .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() syncFeatureConfigsUseCase() @@ -101,6 +104,7 @@ class SyncFeatureConfigsUseCaseTest { ) ) ) + .withGetSupportedProtocolsReturning(null) .withGetTeamSettingsSelfDeletionStatusSuccessful() .arrange() @@ -111,54 +115,6 @@ class SyncFeatureConfigsUseCaseTest { } } - @Test - fun givenMlsIsEnabledAndSelfUserIsWhitelisted_whenSyncing_thenItShouldBeStoredAsEnabled() = runTest { - val (arrangement, syncFeatureConfigsUseCase) = Arrangement() - .withRemoteFeatureConfigsSucceeding( - FeatureConfigTest.newModel(mlsModel = MLSModel(listOf(SELF_USER_ID.toPlainID()), Status.ENABLED)) - ) - .withGetTeamSettingsSelfDeletionStatusSuccessful() - .arrange() - - syncFeatureConfigsUseCase() - - arrangement.userConfigRepository.isMLSEnabled().shouldSucceed { - assertTrue(it) - } - } - - @Test - fun givenMlsIsEnabledAndSelfUserIsNotWhitelisted_whenSyncing_thenItShouldBeStoredAsDisabled() = runTest { - val (arrangement, syncFeatureConfigsUseCase) = Arrangement() - .withRemoteFeatureConfigsSucceeding( - FeatureConfigTest.newModel(mlsModel = MLSModel(listOf(), Status.ENABLED)) - ) - .withGetTeamSettingsSelfDeletionStatusSuccessful() - .arrange() - - syncFeatureConfigsUseCase() - - arrangement.userConfigRepository.isMLSEnabled().shouldSucceed { - assertFalse(it) - } - } - - @Test - fun givenMlsIsDisasbled_whenSyncing_thenItShouldBeStoredAsDisabled() = runTest { - val (arrangement, syncFeatureConfigsUseCase) = Arrangement() - .withRemoteFeatureConfigsSucceeding( - FeatureConfigTest.newModel(mlsModel = MLSModel(listOf(), Status.DISABLED)) - ) - .withGetTeamSettingsSelfDeletionStatusSuccessful() - .arrange() - - syncFeatureConfigsUseCase() - - arrangement.userConfigRepository.isMLSEnabled().shouldSucceed { - assertFalse(it) - } - } - @Test fun givenConferenceCallingIsEnabled_whenSyncing_thenItShouldBeStoredAsEnabled() = runTest { val (arrangement, syncFeatureConfigsUseCase) = Arrangement() @@ -166,6 +122,7 @@ class SyncFeatureConfigsUseCaseTest { FeatureConfigTest.newModel(conferenceCallingModel = ConferenceCallingModel(Status.ENABLED)) ) .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() syncFeatureConfigsUseCase() @@ -182,6 +139,7 @@ class SyncFeatureConfigsUseCaseTest { FeatureConfigTest.newModel(conferenceCallingModel = ConferenceCallingModel(Status.DISABLED)) ) .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() @@ -199,6 +157,7 @@ class SyncFeatureConfigsUseCaseTest { FeatureConfigTest.newModel(fileSharingModel = ConfigsStatusModel(Status.ENABLED)) ) .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() syncFeatureConfigsUseCase() @@ -219,6 +178,7 @@ class SyncFeatureConfigsUseCaseTest { isStatusChanged = false ) .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() syncFeatureConfigsUseCase() @@ -235,6 +195,7 @@ class SyncFeatureConfigsUseCaseTest { FeatureConfigTest.newModel(fileSharingModel = ConfigsStatusModel(Status.DISABLED)) ) .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() @@ -256,6 +217,7 @@ class SyncFeatureConfigsUseCaseTest { isStatusChanged = false ) .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() syncFeatureConfigsUseCase() @@ -276,6 +238,7 @@ class SyncFeatureConfigsUseCaseTest { isStatusChanged = false ) .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() syncFeatureConfigsUseCase() @@ -300,6 +263,7 @@ class SyncFeatureConfigsUseCaseTest { isStatusChanged = false ) .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() syncFeatureConfigsUseCase() @@ -323,6 +287,7 @@ class SyncFeatureConfigsUseCaseTest { ) ) .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() syncFeatureConfigsUseCase() @@ -345,6 +310,7 @@ class SyncFeatureConfigsUseCaseTest { ) ) .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() syncFeatureConfigsUseCase() @@ -367,6 +333,7 @@ class SyncFeatureConfigsUseCaseTest { ) ) .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() syncFeatureConfigsUseCase() @@ -390,6 +357,7 @@ class SyncFeatureConfigsUseCaseTest { ) ) .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() syncFeatureConfigsUseCase() @@ -446,6 +414,7 @@ class SyncFeatureConfigsUseCaseTest { ) .withBuildConfigFileSharing(BuildFileRestrictionState.NoRestriction) .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() syncFeatureConfigsUseCase() @@ -467,6 +436,7 @@ class SyncFeatureConfigsUseCaseTest { ) .withBuildConfigFileSharing(BuildFileRestrictionState.AllowSome(listOf("png", "jpg"))) .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() syncFeatureConfigsUseCase() @@ -484,6 +454,7 @@ class SyncFeatureConfigsUseCaseTest { ) .withBuildConfigFileSharing(BuildFileRestrictionState.AllowSome(listOf("png", "jpg"))) .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() syncFeatureConfigsUseCase() @@ -501,6 +472,7 @@ class SyncFeatureConfigsUseCaseTest { ) .withBuildConfigFileSharing(BuildFileRestrictionState.NoRestriction) .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() syncFeatureConfigsUseCase() @@ -516,6 +488,7 @@ class SyncFeatureConfigsUseCaseTest { val (arrangement, getTeamSettingsSelfDeletionStatusUseCase) = Arrangement() .withKaliumConfigs { it.copy(selfDeletingMessages = false) } .withSetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() // When @@ -542,6 +515,7 @@ class SyncFeatureConfigsUseCaseTest { ) .withSetTeamSettingsSelfDeletionStatusSuccessful() .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() // When @@ -568,6 +542,7 @@ class SyncFeatureConfigsUseCaseTest { ) .withSetTeamSettingsSelfDeletionStatusSuccessful() .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() // When @@ -595,6 +570,7 @@ class SyncFeatureConfigsUseCaseTest { ) .withSetTeamSettingsSelfDeletionStatusSuccessful() .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() // When @@ -620,6 +596,7 @@ class SyncFeatureConfigsUseCaseTest { FeatureConfigTest.newModel(e2EIModel = e2EIModel) ) .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() syncFeatureConfigsUseCase() @@ -638,6 +615,7 @@ class SyncFeatureConfigsUseCaseTest { FeatureConfigTest.newModel() ) .withGetTeamSettingsSelfDeletionStatusSuccessful() + .withGetSupportedProtocolsReturning(null) .arrange() syncFeatureConfigsUseCase() @@ -666,6 +644,8 @@ class SyncFeatureConfigsUseCaseTest { @Mock val featureConfigRepository = mock(classOf()) + @Mock + val updateSupportedProtocolsAndResolveOneOnOnes = mock(classOf()) private lateinit var syncFeatureConfigsUseCase: SyncFeatureConfigsUseCase @@ -734,6 +714,13 @@ class SyncFeatureConfigsUseCaseTest { .then { } } + fun withGetSupportedProtocolsReturning(result: Set?) = apply { + given(userConfigDAO) + .suspendFunction(userConfigDAO::getSupportedProtocols) + .whenInvoked() + .thenReturn(result) + } + fun withKaliumConfigs(changeConfigs: (KaliumConfigs) -> KaliumConfigs) = apply { this.kaliumConfigs = changeConfigs(this.kaliumConfigs) } @@ -743,7 +730,8 @@ class SyncFeatureConfigsUseCaseTest { featureConfigRepository, GuestRoomConfigHandler(userConfigRepository, kaliumConfigs), FileSharingConfigHandler(userConfigRepository), - MLSConfigHandler(userConfigRepository, TestUser.SELF.id), + MLSConfigHandler(userConfigRepository, updateSupportedProtocolsAndResolveOneOnOnes, TestUser.SELF.id), + MLSMigrationConfigHandler(userConfigRepository, updateSupportedProtocolsAndResolveOneOnOnes), ClassifiedDomainsConfigHandler(userConfigRepository), ConferenceCallingConfigHandler(userConfigRepository), SecondFactorPasswordChallengeConfigHandler(userConfigRepository), @@ -755,8 +743,4 @@ class SyncFeatureConfigsUseCaseTest { } } - - private companion object { - val SELF_USER_ID = TestUser.USER_ID - } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/featureConfig/handler/MLSConfigHandlerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/featureConfig/handler/MLSConfigHandlerTest.kt new file mode 100644 index 00000000000..51eebdebd76 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/featureConfig/handler/MLSConfigHandlerTest.kt @@ -0,0 +1,223 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.featureConfig.handler + +import com.wire.kalium.logic.data.featureConfig.MLSModel +import com.wire.kalium.logic.data.featureConfig.Status +import com.wire.kalium.logic.data.user.SupportedProtocol +import com.wire.kalium.logic.framework.TestUser +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangementImpl +import com.wire.kalium.logic.util.arrangement.usecase.UpdateSupportedProtocolsAndResolveOneOnOnesArrangement +import com.wire.kalium.logic.util.arrangement.usecase.UpdateSupportedProtocolsAndResolveOneOnOnesArrangementImpl +import io.mockative.eq +import io.mockative.once +import io.mockative.verify +import kotlinx.coroutines.test.runTest +import kotlin.test.Test + +class MLSConfigHandlerTest { + @Test + fun givenMlsIsEnabledAndMlsIsDefaultProtocol_whenSyncing_thenSetMlsAsDefault() = runTest { + val (arrangement, handler) = arrange { + withGetSupportedProtocolsReturning(Either.Right(setOf(SupportedProtocol.PROTEUS))) + withSetSupportedProtocolsSuccessful() + withSetDefaultProtocolSuccessful() + withSetMLSEnabledSuccessful() + } + + handler.handle(MLS_CONFIG.copy( + status = Status.ENABLED, + defaultProtocol = SupportedProtocol.MLS + ), duringSlowSync = false) + + verify(arrangement.userConfigRepository) + .suspendFunction(arrangement.userConfigRepository::setDefaultProtocol) + .with(eq(SupportedProtocol.MLS)) + .wasInvoked(exactly = once) + } + + @Test + fun givenMlsIsEnabledAndProteusIsDefaultProtocol_whenSyncing_thenSetProteusAsDefault() = runTest { + val (arrangement, handler) = arrange { + withGetSupportedProtocolsReturning(Either.Right(setOf(SupportedProtocol.PROTEUS))) + withSetSupportedProtocolsSuccessful() + withSetDefaultProtocolSuccessful() + withSetMLSEnabledSuccessful() + } + + handler.handle(MLS_CONFIG.copy( + status = Status.ENABLED, + defaultProtocol = SupportedProtocol.PROTEUS + ), duringSlowSync = false) + + verify(arrangement.userConfigRepository) + .suspendFunction(arrangement.userConfigRepository::setDefaultProtocol) + .with(eq(SupportedProtocol.PROTEUS)) + .wasInvoked(exactly = once) + } + + @Test + fun givenMlsIsDisabledAndMlsIsDefaultProtocol_whenSyncing_thenSetProteusAsDefault() = runTest { + val (arrangement, handler) = arrange { + withGetSupportedProtocolsReturning(Either.Right(setOf(SupportedProtocol.PROTEUS))) + withSetSupportedProtocolsSuccessful() + withSetDefaultProtocolSuccessful() + withSetMLSEnabledSuccessful() + } + + handler.handle(MLS_CONFIG.copy( + status = Status.DISABLED, + defaultProtocol = SupportedProtocol.MLS + ), duringSlowSync = false) + + verify(arrangement.userConfigRepository) + .suspendFunction(arrangement.userConfigRepository::setDefaultProtocol) + .with(eq(SupportedProtocol.PROTEUS)) + .wasInvoked(exactly = once) + } + + @Test + fun givenMlsIsEnabledAndSelfUserIsWhitelisted_whenSyncing_thenSetMlsEnabled() = runTest { + val (arrangement, handler) = arrange { + withGetSupportedProtocolsReturning(Either.Right(setOf(SupportedProtocol.PROTEUS))) + withSetSupportedProtocolsSuccessful() + withSetDefaultProtocolSuccessful() + withSetMLSEnabledSuccessful() + } + + handler.handle(MLS_CONFIG.copy( + status = Status.ENABLED, + allowedUsers = listOf(SELF_USER_ID.toPlainID()) + ), duringSlowSync = false) + + verify(arrangement.userConfigRepository) + .suspendFunction(arrangement.userConfigRepository::setMLSEnabled) + .with(eq(true)) + .wasInvoked(exactly = once) + } + + @Test + fun givenMlsIsEnabledAndSelfUserIsNotWhitelisted_whenSyncing_thenSetMlsDisabled() = runTest { + val (arrangement, handler) = arrange { + withGetSupportedProtocolsReturning(Either.Right(setOf(SupportedProtocol.PROTEUS))) + withSetSupportedProtocolsSuccessful() + withSetDefaultProtocolSuccessful() + withSetMLSEnabledSuccessful() + } + + handler.handle(MLS_CONFIG.copy( + status = Status.ENABLED, + allowedUsers = listOf(TestUser.OTHER_USER_ID.toPlainID()) + ), duringSlowSync = false) + + verify(arrangement.userConfigRepository) + .suspendFunction(arrangement.userConfigRepository::setMLSEnabled) + .with(eq(false)) + .wasInvoked(exactly = once) + } + + @Test + fun givenMlsIsDisabled_whenSyncing_thenSetMlsDisabled() = runTest { + val (arrangement, handler) = arrange { + withGetSupportedProtocolsReturning(Either.Right(setOf(SupportedProtocol.PROTEUS))) + withSetSupportedProtocolsSuccessful() + withSetDefaultProtocolSuccessful() + withSetMLSEnabledSuccessful() + } + + handler.handle(MLS_CONFIG.copy( + status = Status.DISABLED + ), duringSlowSync = false) + + verify(arrangement.userConfigRepository) + .suspendFunction(arrangement.userConfigRepository::setMLSEnabled) + .with(eq(false)) + .wasInvoked(exactly = once) + } + + @Test + fun givenSupportedProtocolsHasChangedInEvent_whenSyncing_thenUpdateSelfSupportedProtocols() = runTest { + val (arrangement, handler) = arrange { + withGetSupportedProtocolsReturning(Either.Right(setOf(SupportedProtocol.PROTEUS))) + withUpdateSupportedProtocolsAndResolveOneOnOnesSuccessful() + withSetSupportedProtocolsSuccessful() + withSetDefaultProtocolSuccessful() + withSetMLSEnabledSuccessful() + } + + handler.handle(MLS_CONFIG.copy( + status = Status.ENABLED, + supportedProtocols = setOf(SupportedProtocol.PROTEUS, SupportedProtocol.MLS) + ), duringSlowSync = false) + + verify(arrangement.updateSupportedProtocolsAndResolveOneOnOnes) + .suspendFunction(arrangement.updateSupportedProtocolsAndResolveOneOnOnes::invoke) + .with(eq(true)) + .wasInvoked(exactly = once) + } + + @Test + fun givenSupportedProtocolsHasChangedDuringSlowSync_whenSyncing_thenUpdateSelfSupportedProtocols() = runTest { + val (arrangement, handler) = arrange { + withGetSupportedProtocolsReturning(Either.Right(setOf(SupportedProtocol.PROTEUS))) + withUpdateSupportedProtocolsAndResolveOneOnOnesSuccessful() + withSetSupportedProtocolsSuccessful() + withSetDefaultProtocolSuccessful() + withSetMLSEnabledSuccessful() + } + + handler.handle(MLS_CONFIG.copy( + status = Status.ENABLED, + supportedProtocols = setOf(SupportedProtocol.PROTEUS, SupportedProtocol.MLS) + ), duringSlowSync = true) + + verify(arrangement.updateSupportedProtocolsAndResolveOneOnOnes) + .suspendFunction(arrangement.updateSupportedProtocolsAndResolveOneOnOnes::invoke) + .with(eq(false)) + .wasInvoked(exactly = once) + } + + private class Arrangement(private val block: Arrangement.() -> Unit) : + UserConfigRepositoryArrangement by UserConfigRepositoryArrangementImpl(), + UpdateSupportedProtocolsAndResolveOneOnOnesArrangement by UpdateSupportedProtocolsAndResolveOneOnOnesArrangementImpl() + { + fun arrange() = run { + block() + this@Arrangement to MLSConfigHandler( + userConfigRepository = userConfigRepository, + updateSupportedProtocolsAndResolveOneOnOnes = updateSupportedProtocolsAndResolveOneOnOnes, + SELF_USER_ID + ) + } + } + + private companion object { + fun arrange(configuration: Arrangement.() -> Unit) = Arrangement(configuration).arrange() + + val SELF_USER_ID = TestUser.USER_ID + val MLS_CONFIG = MLSModel( + allowedUsers = emptyList(), + defaultProtocol = SupportedProtocol.MLS, + supportedProtocols = setOf(SupportedProtocol.PROTEUS), + status = Status.ENABLED + ) + } + +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/featureConfig/handler/MLSMigrationConfigHandlerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/featureConfig/handler/MLSMigrationConfigHandlerTest.kt new file mode 100644 index 00000000000..e177a705e97 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/featureConfig/handler/MLSMigrationConfigHandlerTest.kt @@ -0,0 +1,109 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.featureConfig.handler + +import com.wire.kalium.logic.data.featureConfig.MLSMigrationModel +import com.wire.kalium.logic.data.featureConfig.Status +import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangementImpl +import com.wire.kalium.logic.util.arrangement.usecase.UpdateSupportedProtocolsAndResolveOneOnOnesArrangement +import com.wire.kalium.logic.util.arrangement.usecase.UpdateSupportedProtocolsAndResolveOneOnOnesArrangementImpl +import io.mockative.any +import io.mockative.eq +import io.mockative.once +import io.mockative.verify +import kotlinx.coroutines.test.runTest +import kotlinx.datetime.Instant +import kotlin.test.Test + +class MLSMigrationConfigHandlerTest { + @Test + fun givenMlsConfiguration_whenHandling_thenSetMlsConfiguration() = runTest { + val (arrangement, handler) = arrange { + withSetMigrationConfigurationSuccessful() + } + + handler.handle(MIGRATION_CONFIG, duringSlowSync = false) + + verify(arrangement.userConfigRepository) + .suspendFunction(arrangement.userConfigRepository::setMigrationConfiguration) + .with(eq(MIGRATION_CONFIG)) + .wasInvoked(exactly = once) + } + + @Test + fun givenMigrationHasEnded_whenHandling_thenUpdateSelfSupportedProtocols() = runTest { + val (arrangement, handler) = arrange { + withUpdateSupportedProtocolsAndResolveOneOnOnesSuccessful() + withSetMigrationConfigurationSuccessful() + } + + handler.handle(MIGRATION_CONFIG.copy( + startTime = Instant.DISTANT_PAST, + endTime = Instant.DISTANT_PAST + ), duringSlowSync = false) + + verify(arrangement.updateSupportedProtocolsAndResolveOneOnOnes) + .suspendFunction(arrangement.updateSupportedProtocolsAndResolveOneOnOnes::invoke) + .with(eq(true)) + .wasInvoked(exactly = once) + } + + @Test + fun givenMigrationHasEndedDuringSlowSync_whenHandling_thenDontUpdateSelfSupportedProtocols() = runTest { + val (arrangement, handler) = arrange { + withUpdateSupportedProtocolsAndResolveOneOnOnesSuccessful() + withSetMigrationConfigurationSuccessful() + } + + handler.handle(MIGRATION_CONFIG.copy( + startTime = Instant.DISTANT_PAST, + endTime = Instant.DISTANT_PAST + ), duringSlowSync = true) + + verify(arrangement.updateSupportedProtocolsAndResolveOneOnOnes) + .suspendFunction(arrangement.updateSupportedProtocolsAndResolveOneOnOnes::invoke) + .with(any()) + .wasNotInvoked() + } + + private class Arrangement(private val block: Arrangement.() -> Unit) : + UserConfigRepositoryArrangement by UserConfigRepositoryArrangementImpl(), + UpdateSupportedProtocolsAndResolveOneOnOnesArrangement by UpdateSupportedProtocolsAndResolveOneOnOnesArrangementImpl() + { + fun arrange() = run { + block() + this@Arrangement to MLSMigrationConfigHandler( + userConfigRepository = userConfigRepository, + updateSupportedProtocolsAndResolveOneOnOnes = updateSupportedProtocolsAndResolveOneOnOnes + ) + } + } + + private companion object { + fun arrange(configuration: Arrangement.() -> Unit) = Arrangement(configuration).arrange() + + val MIGRATION_CONFIG = MLSMigrationModel( + startTime = null, + endTime = null, + status = Status.ENABLED + ) + } + +} + diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/message/MessageSenderTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/message/MessageSenderTest.kt index 037bfbf8a2d..9cd0fcf3e3c 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/message/MessageSenderTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/message/MessageSenderTest.kt @@ -41,11 +41,14 @@ import com.wire.kalium.logic.feature.message.MessageSenderTest.Arrangement.Compa import com.wire.kalium.logic.feature.message.MessageSenderTest.Arrangement.Companion.MESSAGE_SENT_TIME import com.wire.kalium.logic.feature.message.MessageSenderTest.Arrangement.Companion.TEST_MEMBER_2 import com.wire.kalium.logic.feature.message.MessageSenderTest.Arrangement.Companion.TEST_PROTOCOL_INFO_FAILURE +import com.wire.kalium.logic.feature.message.MessageSenderTest.Arrangement.Companion.arrange import com.wire.kalium.logic.feature.message.ephemeral.EphemeralMessageDeletionHandler import com.wire.kalium.logic.framework.TestConversation import com.wire.kalium.logic.framework.TestMessage import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.sync.SyncManager +import com.wire.kalium.logic.util.arrangement.mls.StaleEpochVerifierArrangement +import com.wire.kalium.logic.util.arrangement.mls.StaleEpochVerifierArrangementImpl import com.wire.kalium.logic.util.shouldFail import com.wire.kalium.logic.util.shouldSucceed import com.wire.kalium.network.api.base.authenticated.message.MLSMessageApi @@ -64,7 +67,6 @@ import io.mockative.mock import io.mockative.once import io.mockative.twice import io.mockative.verify -import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.test.TestScope import kotlinx.coroutines.test.runTest import kotlinx.datetime.Instant @@ -72,15 +74,14 @@ import kotlin.test.Test import kotlin.test.assertEquals import kotlin.time.Duration -@OptIn(ExperimentalCoroutinesApi::class) class MessageSenderTest { @Test fun givenAllStepsSucceed_WhenSendingOutgoingMessage_ThenReturnSuccess() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage() - .withPromoteMessageToSentUpdatingServerTime() - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage() + withPromoteMessageToSentUpdatingServerTime() + } arrangement.testScope.runTest { // when @@ -94,9 +95,9 @@ class MessageSenderTest { @Test fun givenGettingConversationProtocolFails_WhenSendingOutgoingMessage_ThenReturnFailureAndHandleFailureProperly() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage(getConversationProtocolFailing = true) - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage(getConversationProtocolFailing = true) + } arrangement.testScope.runTest { // when @@ -114,9 +115,9 @@ class MessageSenderTest { @Test fun givenGettingConversationRecipientsFails_WhenSendingOutgoingMessage_ThenReturnFailureAndHandleFailureProperly() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage(getConversationsRecipientFailing = true) - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage(getConversationsRecipientFailing = true) + } arrangement.testScope.runTest { // when @@ -134,9 +135,9 @@ class MessageSenderTest { @Test fun givenPreparingRecipientsForNewOutgoingMessageFails_WhenSendingOutgoingMessage_ThenReturnFailureAndHandleFailureProperly() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage(prepareRecipientsForNewOutGoingMessageFailing = true) - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage(prepareRecipientsForNewOutGoingMessageFailing = true) + } arrangement.testScope.runTest { // when @@ -154,9 +155,9 @@ class MessageSenderTest { @Test fun givenCreatingOutgoingEnvelopeFails_WhenSendingOutgoingMessage_ThenReturnFailureAndHandleFailureProperly() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage(createOutgoingEnvelopeFailing = true) - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage(createOutgoingEnvelopeFailing = true) + } arrangement.testScope.runTest { // when @@ -175,9 +176,9 @@ class MessageSenderTest { fun givenSendingEnvelopeFails_WhenSendingOutgoingMessage_ThenReturnFailureAndHandleFailureProperly() { // given val failure = CoreFailure.Unknown(Throwable("some exception")) - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage(sendEnvelopeWithResult = Either.Left(failure)) - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage(sendEnvelopeWithResult = Either.Left(failure)) + } arrangement.testScope.runTest { // when @@ -197,10 +198,10 @@ class MessageSenderTest { // given val failure = CoreFailure.Unknown(Throwable("some exception")) - val (arrangement, messageSender) = Arrangement() - .withCommitPendingProposals() - .withSendMlsMessage(sendMlsMessageWithResult = Either.Left(failure)) - .arrange() + val (arrangement, messageSender) = arrange { + withCommitPendingProposals() + withSendMlsMessage(sendMlsMessageWithResult = Either.Left(failure)) + } arrangement.testScope.runTest { // when @@ -219,10 +220,10 @@ class MessageSenderTest { @Test fun givenUpdatingMessageStatusToSuccessFails_WhenSendingOutgoingMessage_ThenReturnSuccess() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage(updateMessageStatusFailing = true) - .withPromoteMessageToSentUpdatingServerTime() - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage(updateMessageStatusFailing = true) + withPromoteMessageToSentUpdatingServerTime() + } arrangement.testScope.runTest { // when @@ -241,9 +242,9 @@ class MessageSenderTest { fun givenSendingOfEnvelopeFailsDueToLackOfConnection_whenSendingOutgoingMessage_thenFailureShouldBeHandledProperly() { // given val failure = NetworkFailure.NoNetworkConnection(null) - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage(sendEnvelopeWithResult = Either.Left(failure)) - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage(sendEnvelopeWithResult = Either.Left(failure)) + } arrangement.testScope.runTest { // when @@ -261,9 +262,9 @@ class MessageSenderTest { fun givenSendingOfEnvelopeFailsDueToLackOfConnection_whenSendingOutgoingMessage_thenFailureShouldBePropagated() { // given val failure = Either.Left(NetworkFailure.NoNetworkConnection(null)) - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage(sendEnvelopeWithResult = failure) - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage(sendEnvelopeWithResult = failure) + } arrangement.testScope.runTest { // when @@ -274,16 +275,42 @@ class MessageSenderTest { } } + @Test + fun givenReceivingStaleMessageError_whenSendingMlsMessage_thenVerifyStaleEpoch() { + // given + val (arrangement, messageSender) = arrange { + withCommitPendingProposals() + withSendMlsMessage() + withSendOutgoingMlsMessage(Either.Left(Arrangement.MLS_STALE_MESSAGE_FAILURE), times = 1) + withWaitUntilLiveOrFailure() + withPromoteMessageToSentUpdatingServerTime() + withVerifyEpoch(Either.Right(Unit)) + } + + arrangement.testScope.runTest { + // when + val result = messageSender.sendPendingMessage(Arrangement.TEST_CONVERSATION_ID, Arrangement.TEST_MESSAGE_UUID) + + // then + result.shouldSucceed() + verify(arrangement.staleEpochVerifier) + .suspendFunction(arrangement.staleEpochVerifier::verifyEpoch) + .with(eq(Arrangement.TEST_CONVERSATION_ID)) + .wasInvoked(once) + } + } + @Test fun givenReceivingStaleMessageError_whenSendingMlsMessage_thenRetryAfterSyncIsLive() { // given - val (arrangement, messageSender) = Arrangement() - .withCommitPendingProposals() - .withSendMlsMessage() - .withSendOutgoingMlsMessage(Either.Left(Arrangement.MLS_STALE_MESSAGE_FAILURE), times = 1) - .withWaitUntilLiveOrFailure() - .withPromoteMessageToSentUpdatingServerTime() - .arrange() + val (arrangement, messageSender) = arrange { + withCommitPendingProposals() + withSendMlsMessage() + withSendOutgoingMlsMessage(Either.Left(Arrangement.MLS_STALE_MESSAGE_FAILURE), times = 1) + withWaitUntilLiveOrFailure() + withPromoteMessageToSentUpdatingServerTime() + withVerifyEpoch(Either.Right(Unit)) + } arrangement.testScope.runTest { // when @@ -301,12 +328,12 @@ class MessageSenderTest { @Test fun givenPendingProposals_whenSendingMlsMessage_thenProposalsAreCommitted() { // given - val (arrangement, messageSender) = Arrangement() - .withCommitPendingProposals() - .withSendMlsMessage() - .withSendOutgoingMlsMessage() - .withPromoteMessageToSentUpdatingServerTime() - .arrange() + val (arrangement, messageSender) = arrange { + withCommitPendingProposals() + withSendMlsMessage() + withSendOutgoingMlsMessage() + withPromoteMessageToSentUpdatingServerTime() + } arrangement.testScope.runTest { // when @@ -324,11 +351,12 @@ class MessageSenderTest { @Test fun givenReceivingStaleMessageError_whenSendingMlsMessage_thenGiveUpIfSyncIsPending() { // given - val (arrangement, messageSender) = Arrangement() - .withCommitPendingProposals() - .withSendMlsMessage(sendMlsMessageWithResult = Either.Left(Arrangement.MLS_STALE_MESSAGE_FAILURE)) - .withWaitUntilLiveOrFailure(failing = true) - .arrange() + val (arrangement, messageSender) = arrange { + withCommitPendingProposals() + withSendMlsMessage(sendMlsMessageWithResult = Either.Left(Arrangement.MLS_STALE_MESSAGE_FAILURE)) + withWaitUntilLiveOrFailure(failing = true) + withVerifyEpoch(Either.Right(Unit)) + } arrangement.testScope.runTest { // when @@ -346,10 +374,10 @@ class MessageSenderTest { @Test fun givenClientTargets_WhenSendingOutgoingMessage_ThenCallSendEnvelopeWithCorrectTargets() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage() - .withPromoteMessageToSentUpdatingServerTime() - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage() + withPromoteMessageToSentUpdatingServerTime() + } val message = Message.Signaling( id = Arrangement.TEST_MESSAGE_UUID, @@ -399,10 +427,10 @@ class MessageSenderTest { @Test fun givenConversationTarget_WhenSendingOutgoingMessage_ThenCallSendEnvelopeWithCorrectTargets() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage() - .withPromoteMessageToSentUpdatingServerTime() - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage() + withPromoteMessageToSentUpdatingServerTime() + } val message = Message.Signaling( id = Arrangement.TEST_MESSAGE_UUID, @@ -448,9 +476,9 @@ class MessageSenderTest { fun givenARemoteProteusConversationFails_WhenSendingOutgoingMessage_ThenReturnFailureAndHandleFailureProperly() { // given val failure = FEDERATION_MESSAGE_FAILURE - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage(sendEnvelopeWithResult = Either.Left(failure)) - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage(sendEnvelopeWithResult = Either.Left(failure)) + } arrangement.testScope.runTest { // when @@ -469,11 +497,11 @@ class MessageSenderTest { fun givenARemoteMLSConversationFails_WhenSendingOutgoingMessage_ThenReturnFailureAndHandleFailureProperly() { // given val failure = FEDERATION_MESSAGE_FAILURE - val (arrangement, messageSender) = Arrangement() - .withCommitPendingProposals() - .withWaitUntilLiveOrFailure() - .withSendMlsMessage(sendMlsMessageWithResult = Either.Left(failure)) - .arrange() + val (arrangement, messageSender) = arrange { + withCommitPendingProposals() + withWaitUntilLiveOrFailure() + withSendMlsMessage(sendMlsMessageWithResult = Either.Left(failure)) + } arrangement.testScope.runTest { // when @@ -491,8 +519,8 @@ class MessageSenderTest { @Test fun givenARemoteProteusConversationPartiallyFails_WhenSendingOutgoingMessage_ThenReturnSuccessAndPersistFailedRecipients() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage( + val (arrangement, messageSender) = arrange { + withSendProteusMessage( sendEnvelopeWithResult = Either.Right( MessageSent( time = MESSAGE_SENT_TIME, @@ -500,10 +528,10 @@ class MessageSenderTest { ) ) ) - .withFailedClientsPartialSuccess() - .withPromoteMessageToSentUpdatingServerTime() - .withSendMessagePartialSuccess() - .arrange() + withFailedClientsPartialSuccess() + withPromoteMessageToSentUpdatingServerTime() + withSendMessagePartialSuccess() + } arrangement.testScope.runTest { // when @@ -522,8 +550,8 @@ class MessageSenderTest { fun givenARemoteProteusConversationPartiallyFails_WithNoClientsWhenSendingAMessage_ThenReturnSuccessAndPersistFailedClientsAndFailedToSend() { // given val failedRecipient = UsersWithoutSessions(listOf(TEST_MEMBER_2)) - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage( + val (arrangement, messageSender) = arrange { + withSendProteusMessage( sendEnvelopeWithResult = Either.Right( MessageSent( time = MESSAGE_SENT_TIME, @@ -531,11 +559,11 @@ class MessageSenderTest { ) ) ) - .withFailedClientsPartialSuccess() - .withPrepareRecipientsForNewOutgoingMessage(false, failedRecipient) - .withPromoteMessageToSentUpdatingServerTime() - .withSendMessagePartialSuccess() - .arrange() + withFailedClientsPartialSuccess() + withPrepareRecipientsForNewOutgoingMessage(false, failedRecipient) + withPromoteMessageToSentUpdatingServerTime() + withSendMessagePartialSuccess() + } arrangement.testScope.runTest { // when @@ -561,13 +589,13 @@ class MessageSenderTest { Arrangement.TEST_RECIPIENT_1, Arrangement.TEST_RECIPIENT_2 ) - val (arrangement, messageSender) = Arrangement() - .withPrepareRecipientsForNewOutgoingMessage() - .withPromoteMessageToSentUpdatingServerTime() - .withCreateOutgoingBroadcastEnvelope() - .withAllRecipients(recipients to listOf()) - .withBroadcastEnvelope() - .arrange() + val (arrangement, messageSender) = arrange { + withPrepareRecipientsForNewOutgoingMessage() + withPromoteMessageToSentUpdatingServerTime() + withCreateOutgoingBroadcastEnvelope() + withAllRecipients(recipients to listOf()) + withBroadcastEnvelope() + } val message = BroadcastMessage( id = Arrangement.TEST_MESSAGE_UUID, @@ -614,13 +642,13 @@ class MessageSenderTest { Arrangement.TEST_RECIPIENT_1, Recipient(senderUserId, listOf(senderClientId, ClientId("mySecondClientId"))) ) - val (arrangement, messageSender) = Arrangement() - .withPrepareRecipientsForNewOutgoingMessage() - .withPromoteMessageToSentUpdatingServerTime() - .withCreateOutgoingBroadcastEnvelope() - .withAllRecipients(recipients to listOf()) - .withBroadcastEnvelope() - .arrange() + val (arrangement, messageSender) = arrange { + withPrepareRecipientsForNewOutgoingMessage() + withPromoteMessageToSentUpdatingServerTime() + withCreateOutgoingBroadcastEnvelope() + withAllRecipients(recipients to listOf()) + withBroadcastEnvelope() + } val message = BroadcastMessage( id = Arrangement.TEST_MESSAGE_UUID, @@ -667,13 +695,13 @@ class MessageSenderTest { Arrangement.TEST_RECIPIENT_1, Arrangement.TEST_RECIPIENT_3, ) - val (arrangement, messageSender) = Arrangement() - .withPrepareRecipientsForNewOutgoingMessage() - .withPromoteMessageToSentUpdatingServerTime() - .withCreateOutgoingBroadcastEnvelope() - .withAllRecipients(teamRecipients to otherRecipients) - .withBroadcastEnvelope() - .arrange() + val (arrangement, messageSender) = arrange { + withPrepareRecipientsForNewOutgoingMessage() + withPromoteMessageToSentUpdatingServerTime() + withCreateOutgoingBroadcastEnvelope() + withAllRecipients(teamRecipients to otherRecipients) + withBroadcastEnvelope() + } val message = BroadcastMessage( id = Arrangement.TEST_MESSAGE_UUID, @@ -710,11 +738,11 @@ class MessageSenderTest { @Test fun givenASuccess_WhenSendingEditMessage_ThenUpdateMessageIdButDoNotUpdateCreationDate() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage() - .withPromoteMessageToSentUpdatingServerTime() - .withUpdateTextMessage() - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage() + withPromoteMessageToSentUpdatingServerTime() + withUpdateTextMessage() + } val originalMessageId = "original_id" val editedMessageId = "edited_id" @@ -751,10 +779,10 @@ class MessageSenderTest { @Test fun givenASuccess_WhenSendingRegularMessage_ThenDoNotUpdateMessageIdButUpdateCreationDateToServerDate() { // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage() - .withPromoteMessageToSentUpdatingServerTime() - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage() + withPromoteMessageToSentUpdatingServerTime() + } val message = TestMessage.TEXT_MESSAGE arrangement.testScope.runTest { @@ -782,10 +810,10 @@ class MessageSenderTest { ) // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage() - .withPromoteMessageToSentUpdatingServerTime() - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage() + withPromoteMessageToSentUpdatingServerTime() + } arrangement.testScope.runTest { // when @@ -808,10 +836,10 @@ class MessageSenderTest { ) // given - val (arrangement, messageSender) = Arrangement() - .withSendProteusMessage(true, true) - .withPromoteMessageToSentUpdatingServerTime() - .arrange() + val (arrangement, messageSender) = arrange { + withSendProteusMessage(true, true) + withPromoteMessageToSentUpdatingServerTime() + } arrangement.testScope.runTest { // when @@ -829,15 +857,15 @@ class MessageSenderTest { @Test fun givenARemoteMlsConversationPartiallyFails_whenSendingAMessage_ThenReturnSuccessAndPersistFailedToSendUsers() { // given - val (arrangement, messageSender) = Arrangement() - .withCommitPendingProposals() - .withSendMlsMessage( + val (arrangement, messageSender) = arrange { + withCommitPendingProposals() + withSendMlsMessage( sendMlsMessageWithResult = Either.Right(MessageSent(MESSAGE_SENT_TIME, listOf(TEST_MEMBER_2))), ) - .withWaitUntilLiveOrFailure() - .withPromoteMessageToSentUpdatingServerTime() - .withSendMessagePartialSuccess() - .arrange() + withWaitUntilLiveOrFailure() + withPromoteMessageToSentUpdatingServerTime() + withSendMessagePartialSuccess() + } arrangement.testScope.runTest { // when @@ -852,7 +880,9 @@ class MessageSenderTest { } } - private class Arrangement { + private class Arrangement(private val block: Arrangement.() -> Unit): + StaleEpochVerifierArrangement by StaleEpochVerifierArrangementImpl() + { @Mock val messageRepository: MessageRepository = mock(MessageRepository::class) @@ -891,25 +921,29 @@ class MessageSenderTest { } } - fun arrange() = this to MessageSenderImpl( - messageRepository = messageRepository, - conversationRepository = conversationRepository, - mlsConversationRepository = mlsConversationRepository, - syncManager = syncManager, - messageSendFailureHandler = messageSendFailureHandler, - sessionEstablisher = sessionEstablisher, - messageEnvelopeCreator = messageEnvelopeCreator, - mlsMessageCreator = mlsMessageCreator, - messageSendingInterceptor = messageSendingInterceptor, - userRepository = userRepository, - enqueueSelfDeletion = { message, expirationData -> - selfDeleteMessageSenderHandler.enqueueSelfDeletion( - message, - expirationData - ) - }, - scope = testScope - ) + fun arrange() = run { + block() + this@Arrangement to MessageSenderImpl( + messageRepository = messageRepository, + conversationRepository = conversationRepository, + mlsConversationRepository = mlsConversationRepository, + syncManager = syncManager, + messageSendFailureHandler = messageSendFailureHandler, + sessionEstablisher = sessionEstablisher, + messageEnvelopeCreator = messageEnvelopeCreator, + mlsMessageCreator = mlsMessageCreator, + messageSendingInterceptor = messageSendingInterceptor, + userRepository = userRepository, + enqueueSelfDeletion = { message, expirationData -> + selfDeleteMessageSenderHandler.enqueueSelfDeletion( + message, + expirationData + ) + }, + staleEpochVerifier = staleEpochVerifier, + scope = testScope + ) + } fun withGetMessageById(failing: Boolean = false) = apply { given(messageRepository) @@ -1089,6 +1123,8 @@ class MessageSenderTest { } companion object { + fun arrange(configuration: Arrangement.() -> Unit) = Arrangement(configuration).arrange() + val TEST_CONVERSATION_ID = TestConversation.ID const val TEST_MESSAGE_UUID = "messageUuid" val MESSAGE_SENT_TIME = DateTimeUtil.currentIsoDateTimeString() @@ -1098,7 +1134,7 @@ class MessageSenderTest { val GROUP_ID = GroupID("groupId") val MLS_PROTOCOL_INFO = Conversation.ProtocolInfo.MLS( GROUP_ID, - Conversation.ProtocolInfo.MLS.GroupState.ESTABLISHED, + Conversation.ProtocolInfo.MLSCapable.GroupState.ESTABLISHED, 0UL, Instant.DISTANT_PAST, Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/message/StaleEpochVerifierTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/message/StaleEpochVerifierTest.kt new file mode 100644 index 00000000000..ec06129eaf4 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/message/StaleEpochVerifierTest.kt @@ -0,0 +1,167 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.message + +import com.wire.kalium.logic.NetworkFailure +import com.wire.kalium.logic.framework.TestConversation +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.util.arrangement.SystemMessageInserterArrangement +import com.wire.kalium.logic.util.arrangement.SystemMessageInserterArrangementImpl +import com.wire.kalium.logic.util.arrangement.mls.MLSConversationRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.mls.MLSConversationRepositoryArrangementImpl +import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangementImpl +import com.wire.kalium.logic.util.arrangement.usecase.JoinExistingMLSConversationUseCaseArrangement +import com.wire.kalium.logic.util.arrangement.usecase.JoinExistingMLSConversationUseCaseArrangementImpl +import com.wire.kalium.logic.util.shouldFail +import com.wire.kalium.logic.util.shouldSucceed +import io.mockative.any +import io.mockative.eq +import io.mockative.once +import io.mockative.verify +import kotlinx.coroutines.test.runTest +import kotlinx.datetime.Clock +import kotlin.test.Test +import kotlin.time.Duration.Companion.minutes + +class StaleEpochVerifierTest { + + @Test + fun givenConversationIsNotMLS_whenHandlingStaleEpoch_thenShouldNotInsertWarning() = runTest { + val (arrangement, staleEpochHandler) = arrange { + withFetchConversation(Either.Right(Unit)) + withGetConversationProtocolInfo(Either.Right(TestConversation.PROTEUS_PROTOCOL_INFO)) + } + + staleEpochHandler.verifyEpoch(CONVERSATION_ID).shouldFail() + + verify(arrangement.systemMessageInserter) + .suspendFunction(arrangement.systemMessageInserter::insertLostCommitSystemMessage) + .with(any(), any()) + .wasNotInvoked() + } + + @Test + fun givenMLSConversation_whenHandlingStaleEpoch_thenShouldFetchConversationAgain() = runTest { + val (arrangement, staleEpochHandler) = arrange { + withIsGroupOutOfSync(Either.Right(false)) + withFetchConversation(Either.Right(Unit)) + withGetConversationProtocolInfo(Either.Right(TestConversation.MLS_PROTOCOL_INFO)) + } + + staleEpochHandler.verifyEpoch(CONVERSATION_ID).shouldSucceed() + + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::fetchConversation) + .with(eq(CONVERSATION_ID)) + .wasInvoked(once) + } + + @Test + fun givenEpochIsLatest_whenHandlingStaleEpoch_thenShouldNotRejoinTheConversation() = runTest { + val (arrangement, staleEpochHandler) = arrange { + withIsGroupOutOfSync(Either.Right(false)) + withFetchConversation(Either.Right(Unit)) + withGetConversationProtocolInfo(Either.Right(TestConversation.MLS_PROTOCOL_INFO)) + } + + staleEpochHandler.verifyEpoch(CONVERSATION_ID).shouldSucceed() + + verify(arrangement.joinExistingMLSConversationUseCase) + .suspendFunction(arrangement.joinExistingMLSConversationUseCase::invoke) + .with(eq(CONVERSATION_ID)) + .wasNotInvoked() + } + + @Test + fun givenStaleEpoch_whenHandlingStaleEpoch_thenShouldRejoinTheConversation() = runTest { + val (arrangement, staleEpochHandler) = arrange { + withIsGroupOutOfSync(Either.Right(true)) + withFetchConversation(Either.Right(Unit)) + withGetConversationProtocolInfo(Either.Right(TestConversation.MLS_PROTOCOL_INFO)) + withJoinExistingMLSConversationUseCaseReturning(Either.Right(Unit)) + withInsertLostCommitSystemMessage(Either.Right(Unit)) + } + + staleEpochHandler.verifyEpoch(CONVERSATION_ID).shouldSucceed() + + verify(arrangement.joinExistingMLSConversationUseCase) + .suspendFunction(arrangement.joinExistingMLSConversationUseCase::invoke) + .with(eq(CONVERSATION_ID)) + .wasInvoked(once) + } + + @Test + fun givenRejoiningFails_whenHandlingStaleEpoch_thenShouldNotInsertLostCommitSystemMessage() = runTest { + val (arrangement, staleEpochHandler) = arrange { + withIsGroupOutOfSync(Either.Right(true)) + withFetchConversation(Either.Right(Unit)) + withGetConversationProtocolInfo(Either.Right(TestConversation.MLS_PROTOCOL_INFO)) + withJoinExistingMLSConversationUseCaseReturning(Either.Left(NetworkFailure.NoNetworkConnection(null))) + } + + staleEpochHandler.verifyEpoch(CONVERSATION_ID).shouldFail() + + verify(arrangement.systemMessageInserter) + .suspendFunction(arrangement.systemMessageInserter::insertLostCommitSystemMessage) + .with(eq(CONVERSATION_ID), any()) + .wasNotInvoked() + } + + @Test + fun givenConversationIsRejoined_whenHandlingStaleEpoch_thenShouldInsertLostCommitSystemMessage() = runTest { + val (arrangement, staleEpochHandler) = arrange { + withIsGroupOutOfSync(Either.Right(true)) + withFetchConversation(Either.Right(Unit)) + withGetConversationProtocolInfo(Either.Right(TestConversation.MLS_PROTOCOL_INFO)) + withJoinExistingMLSConversationUseCaseReturning(Either.Right(Unit)) + withInsertLostCommitSystemMessage(Either.Right(Unit)) + } + + staleEpochHandler.verifyEpoch(CONVERSATION_ID).shouldSucceed() + + verify(arrangement.systemMessageInserter) + .suspendFunction(arrangement.systemMessageInserter::insertLostCommitSystemMessage) + .with(eq(CONVERSATION_ID), any()) + .wasInvoked(once) + } + + + private class Arrangement(private val block: Arrangement.() -> Unit) : + SystemMessageInserterArrangement by SystemMessageInserterArrangementImpl(), + ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl(), + MLSConversationRepositoryArrangement by MLSConversationRepositoryArrangementImpl(), + JoinExistingMLSConversationUseCaseArrangement by JoinExistingMLSConversationUseCaseArrangementImpl() + { + fun arrange() = run { + block() + this@Arrangement to StaleEpochVerifierImpl( + systemMessageInserter = systemMessageInserter, + conversationRepository = conversationRepository, + mlsConversationRepository = mlsConversationRepository, + joinExistingMLSConversation = joinExistingMLSConversationUseCase + ) + } + } + + private companion object { + fun arrange(configuration: Arrangement.() -> Unit) = Arrangement(configuration).arrange() + + val CONVERSATION_ID = TestConversation.ID + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManagerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManagerTest.kt new file mode 100644 index 00000000000..7ccb066128e --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManagerTest.kt @@ -0,0 +1,178 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.mlsmigration + +import com.wire.kalium.logic.data.client.ClientRepository +import com.wire.kalium.logic.data.sync.InMemoryIncrementalSyncRepository +import com.wire.kalium.logic.data.sync.IncrementalSyncRepository +import com.wire.kalium.logic.data.sync.IncrementalSyncStatus +import com.wire.kalium.logic.feature.TimestampKeyRepository +import com.wire.kalium.logic.feature.TimestampKeys +import com.wire.kalium.logic.featureFlags.FeatureSupport +import com.wire.kalium.logic.featureFlags.KaliumConfigs +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.test_util.TestKaliumDispatcher +import io.mockative.Mock +import io.mockative.anything +import io.mockative.classOf +import io.mockative.eq +import io.mockative.given +import io.mockative.mock +import io.mockative.once +import io.mockative.verify +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.yield +import kotlin.test.Test + +@OptIn(ExperimentalCoroutinesApi::class) +class MLSMigrationManagerTest { + + @Test + fun givenMigrationUpdateTimerHasElapsed_whenObservingAndSyncFinishes_migrationIsUpdated() = + runTest(TestKaliumDispatcher.default) { + val (arrangement, _) = Arrangement() + .withIsMLSSupported(true) + .withHasRegisteredMLSClient(true) + .withLastMLSMigrationCheck(true) + .withRunMigrationSucceeds() + .withLastMLSMigrationCheckResetSucceeds() + .arrange() + + arrangement.incrementalSyncRepository.updateIncrementalSyncState(IncrementalSyncStatus.Live) + yield() + + verify(arrangement.mlsMigrationWorker) + .suspendFunction(arrangement.mlsMigrationWorker::runMigration) + .wasInvoked(once) + } + + @Test + fun givenMigrationUpdateTimerHasNotElapsed_whenObservingSyncFinishes_migrationIsNotUpdated() = + runTest(TestKaliumDispatcher.default) { + val (arrangement, _) = Arrangement() + .withIsMLSSupported(true) + .withHasRegisteredMLSClient(true) + .withLastMLSMigrationCheck(false) + .withRunMigrationSucceeds() + .arrange() + + arrangement.incrementalSyncRepository.updateIncrementalSyncState(IncrementalSyncStatus.Live) + yield() + + verify(arrangement.mlsMigrationWorker) + .suspendFunction(arrangement.mlsMigrationWorker::runMigration) + .wasNotInvoked() + } + + @Test + fun givenMLSSupportIsDisabled_whenObservingSyncFinishes_migrationIsNotUpdated() = + runTest(TestKaliumDispatcher.default) { + val (arrangement, _) = Arrangement() + .withIsMLSSupported(false) + .withRunMigrationSucceeds() + .arrange() + + arrangement.incrementalSyncRepository.updateIncrementalSyncState(IncrementalSyncStatus.Live) + yield() + + verify(arrangement.mlsMigrationWorker) + .suspendFunction(arrangement.mlsMigrationWorker::runMigration) + .wasNotInvoked() + } + + @Test + fun givenNoMLSClientIsRegistered_whenObservingSyncFinishes_migrationIsNotUpdated() = + runTest(TestKaliumDispatcher.default) { + val (arrangement, _) = Arrangement() + .withIsMLSSupported(true) + .withHasRegisteredMLSClient(false) + .withRunMigrationSucceeds() + .arrange() + + arrangement.incrementalSyncRepository.updateIncrementalSyncState(IncrementalSyncStatus.Live) + yield() + + verify(arrangement.mlsMigrationWorker) + .suspendFunction(arrangement.mlsMigrationWorker::runMigration) + .wasNotInvoked() + } + + private class Arrangement { + + val incrementalSyncRepository: IncrementalSyncRepository = InMemoryIncrementalSyncRepository() + + val kaliumConfigs = KaliumConfigs() + + @Mock + val clientRepository = mock(classOf()) + + @Mock + val featureSupport = mock(classOf()) + + @Mock + val timestampKeyRepository = mock(classOf()) + + @Mock + val mlsMigrationWorker = mock(classOf()) + + fun withRunMigrationSucceeds() = apply { + given(mlsMigrationWorker) + .suspendFunction(mlsMigrationWorker::runMigration) + .whenInvoked() + .thenReturn(Either.Right(Unit)) + } + + fun withLastMLSMigrationCheck(hasPassed: Boolean) = apply { + given(timestampKeyRepository) + .suspendFunction(timestampKeyRepository::hasPassed) + .whenInvokedWith(eq(TimestampKeys.LAST_MLS_MIGRATION_CHECK), anything()) + .thenReturn(Either.Right(hasPassed)) + } + + fun withLastMLSMigrationCheckResetSucceeds() = apply { + given(timestampKeyRepository) + .suspendFunction(timestampKeyRepository::reset) + .whenInvokedWith(eq(TimestampKeys.LAST_MLS_MIGRATION_CHECK)) + .thenReturn(Either.Right(Unit)) + } + + fun withIsMLSSupported(supported: Boolean) = apply { + given(featureSupport) + .invocation { featureSupport.isMLSSupported } + .thenReturn(supported) + } + + fun withHasRegisteredMLSClient(result: Boolean) = apply { + given(clientRepository) + .suspendFunction(clientRepository::hasRegisteredMLSClient) + .whenInvoked() + .thenReturn(Either.Right(result)) + } + + fun arrange() = this to MLSMigrationManagerImpl( + kaliumConfigs, + featureSupport, + incrementalSyncRepository, + lazy { clientRepository }, + lazy { timestampKeyRepository }, + lazy { mlsMigrationWorker }, + TestKaliumDispatcher + ) + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigratorTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigratorTest.kt new file mode 100644 index 00000000000..290ad13898e --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigratorTest.kt @@ -0,0 +1,274 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.mlsmigration + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.NetworkFailure +import com.wire.kalium.logic.data.conversation.Conversation +import com.wire.kalium.logic.data.conversation.ConversationRepository +import com.wire.kalium.logic.data.conversation.MLSConversationRepository +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.data.message.SystemMessageInserter +import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.data.user.UserRepository +import com.wire.kalium.logic.feature.SelfTeamIdProvider +import com.wire.kalium.logic.framework.TestConversation +import com.wire.kalium.logic.framework.TestTeam +import com.wire.kalium.logic.framework.TestUser +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.test_util.TestNetworkResponseError +import com.wire.kalium.logic.util.shouldSucceed +import com.wire.kalium.network.api.base.model.ErrorResponse +import com.wire.kalium.network.exceptions.KaliumException +import io.mockative.Mock +import io.mockative.any +import io.mockative.anything +import io.mockative.classOf +import io.mockative.eq +import io.mockative.given +import io.mockative.mock +import io.mockative.once +import io.mockative.verify +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.test.runTest +import kotlinx.datetime.Instant +import kotlin.test.Test + +@OptIn(ExperimentalCoroutinesApi::class) +class MLSMigratorTest { + + @Test + fun givenTeamConversation_whenMigrating_thenProtocolIsUpdatedToMixedAndGroupIsEstablished() = runTest { + val conversation = TestConversation.CONVERSATION.copy( + type = Conversation.Type.GROUP, + teamId = TestTeam.TEAM_ID + ) + + val (arrangement, migrator) = Arrangement() + .withGetProteusTeamConversationsReturning(listOf(conversation.id)) + .withUpdateProtocolReturns() + .withFetchConversationSucceeding() + .withGetConversationProtocolInfoReturning(Arrangement.MIXED_PROTOCOL_INFO) + .withEstablishGroupSucceeds() + .withGetConversationMembersReturning(Arrangement.MEMBERS) + .withAddMembersSucceeds() + .arrange() + + migrator.migrateProteusConversations() + + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::updateProtocolRemotely) + .with(eq(conversation.id), eq(Conversation.Protocol.MIXED)) + .wasInvoked(once) + + verify(arrangement.mlsConversationRepository) + .suspendFunction(arrangement.mlsConversationRepository::establishMLSGroup) + .with(eq(Arrangement.MIXED_PROTOCOL_INFO.groupId), eq(emptyList())) + + verify(arrangement.mlsConversationRepository) + .suspendFunction(arrangement.mlsConversationRepository::addMemberToMLSGroup) + .with(eq(Arrangement.MIXED_PROTOCOL_INFO.groupId), eq(Arrangement.MEMBERS)) + } + + @Test + fun givenAnError_whenMigrating_thenStillConsiderItASuccess() = runTest { + val conversation = TestConversation.CONVERSATION.copy( + type = Conversation.Type.GROUP, + teamId = TestTeam.TEAM_ID + ) + + val (_, migrator) = Arrangement() + .withGetProteusTeamConversationsReturning(listOf(conversation.id)) + .withUpdateProtocolReturns() + .withFetchConversationSucceeding() + .withGetConversationProtocolInfoReturning(Arrangement.MIXED_PROTOCOL_INFO) + .withEstablishGroupFails() + .arrange() + + val result = migrator.migrateProteusConversations() + result.shouldSucceed() + } + + @Test + fun givenTeamConversation_whenFinalising_thenKnownUsersAreFetchedAndProtocolIsUpdatedToMls() = runTest { + val conversation = TestConversation.CONVERSATION.copy( + type = Conversation.Type.GROUP, + teamId = TestTeam.TEAM_ID + ) + + val (arrangement, migrator) = Arrangement() + .withFetchAllOtherUsersSucceeding() + .withGetProteusTeamConversationsReadyForFinalisationReturning(listOf(conversation.id)) + .withUpdateProtocolReturns() + .withFetchConversationSucceeding() + .withGetConversationProtocolInfoReturning(Arrangement.MLS_PROTOCOL_INFO) + .arrange() + + migrator.finaliseProteusConversations() + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::fetchAllOtherUsers) + .wasInvoked(once) + + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::updateProtocolRemotely) + .with(eq(conversation.id), eq(Conversation.Protocol.MLS)) + .wasInvoked(once) + } + + @Test + fun givenAnError_whenFinalising_thenStillConsiderItASuccess() = runTest { + val conversation = TestConversation.CONVERSATION.copy( + type = Conversation.Type.GROUP, + teamId = TestTeam.TEAM_ID + ) + + val (_, migrator) = Arrangement() + .withFetchAllOtherUsersSucceeding() + .withGetProteusTeamConversationsReadyForFinalisationReturning(listOf(conversation.id)) + .withUpdateProtocolReturns(Either.Left(TestNetworkResponseError.noNetworkConnection())) + .arrange() + + val result = migrator.finaliseProteusConversations() + result.shouldSucceed() + } + + private class Arrangement { + + @Mock + val userRepository = mock(classOf()) + + @Mock + val conversationRepository = mock(classOf()) + + @Mock + val mlsConversationRepository = mock(classOf()) + + @Mock + val selfTeamIdProvider = mock(classOf()) + + @Mock + val systemMessageInserter = mock(classOf()) + + fun withFetchAllOtherUsersSucceeding() = apply { + given(userRepository) + .suspendFunction(userRepository::fetchAllOtherUsers) + .whenInvoked() + .thenReturn(Either.Right(Unit)) + } + + fun withGetProteusTeamConversationsReturning(conversationsIds: List) = apply { + given(conversationRepository) + .suspendFunction(conversationRepository::getConversationIds) + .whenInvokedWith(eq(Conversation.Type.GROUP), eq(Conversation.Protocol.PROTEUS), anything()) + .thenReturn(Either.Right(conversationsIds)) + } + + fun withGetProteusTeamConversationsReadyForFinalisationReturning(conversationsIds: List) = apply { + given(conversationRepository) + .suspendFunction(conversationRepository::getTeamConversationIdsReadyToCompleteMigration) + .whenInvokedWith(anything()) + .thenReturn(Either.Right(conversationsIds)) + } + + fun withGetConversationProtocolInfoReturning(protocolInfo: Conversation.ProtocolInfo) = apply { + given(conversationRepository) + .suspendFunction(conversationRepository::getConversationProtocolInfo) + .whenInvokedWith(anything()) + .thenReturn(Either.Right(protocolInfo)) + } + + fun withGetConversationMembersReturning(members: List) = apply { + given(conversationRepository) + .suspendFunction(conversationRepository::getConversationMembers) + .whenInvokedWith(anything()) + .thenReturn(Either.Right(members)) + } + + fun withFetchConversationSucceeding() = apply { + given(conversationRepository) + .suspendFunction(conversationRepository::fetchConversation) + .whenInvokedWith(anything()) + .thenReturn(Either.Right(Unit)) + } + fun withUpdateProtocolReturns(result: Either = Either.Right(true)) = apply { + given(conversationRepository) + .suspendFunction(conversationRepository::updateProtocolRemotely) + .whenInvokedWith(any(), any()) + .thenReturn(result) + } + + fun withEstablishGroupSucceeds() = apply { + given(mlsConversationRepository) + .suspendFunction(mlsConversationRepository::establishMLSGroup) + .whenInvokedWith(anything(), anything()) + .thenReturn(Either.Right(Unit)) + } + + fun withEstablishGroupFails() = apply { + given(mlsConversationRepository) + .suspendFunction(mlsConversationRepository::establishMLSGroup) + .whenInvokedWith(anything(), anything()) + .thenReturn(Either.Left(NetworkFailure.ServerMiscommunication(MLS_STALE_MESSAGE_ERROR))) + } + + fun withAddMembersSucceeds() = apply { + given(mlsConversationRepository) + .suspendFunction(mlsConversationRepository::addMemberToMLSGroup) + .whenInvokedWith(anything(), anything()) + .thenReturn(Either.Right(Unit)) + } + + fun arrange() = this to MLSMigratorImpl( + TestUser.SELF.id, + selfTeamIdProvider, + userRepository, + conversationRepository, + mlsConversationRepository, + systemMessageInserter + ) + + init { + given(selfTeamIdProvider) + .suspendFunction(selfTeamIdProvider::invoke) + .whenInvoked() + .thenReturn(Either.Right(TestTeam.TEAM_ID)) + } + + companion object { + val MLS_STALE_MESSAGE_ERROR = KaliumException.InvalidRequestError( + ErrorResponse(409, "", "mls-stale-message") + ) + val MEMBERS = listOf(TestUser.USER_ID) + val MIXED_PROTOCOL_INFO = Conversation.ProtocolInfo.Mixed( + TestConversation.GROUP_ID, + Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_JOIN, + 0UL, + Instant.parse("2021-03-30T15:36:00.000Z"), + cipherSuite = Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 + ) + val MLS_PROTOCOL_INFO = Conversation.ProtocolInfo.MLS( + TestConversation.GROUP_ID, + Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_JOIN, + 0UL, + Instant.parse("2021-03-30T15:36:00.000Z"), + cipherSuite = Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 + ) + } + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelectorTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelectorTest.kt new file mode 100644 index 00000000000..b468208d3e7 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelectorTest.kt @@ -0,0 +1,150 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.protocol + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.StorageFailure +import com.wire.kalium.logic.data.user.SupportedProtocol +import com.wire.kalium.logic.framework.TestUser +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.util.arrangement.UserRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.UserRepositoryArrangementImpl +import com.wire.kalium.logic.util.shouldFail +import com.wire.kalium.logic.util.shouldSucceed +import io.mockative.eq +import io.mockative.once +import io.mockative.verify +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertIs + +class OneOnOneProtocolSelectorTest { + + @Test + fun givenSelfUserIsNull_thenShouldReturnFailure() = runTest { + val (_, oneOnOneProtocolSelector) = arrange { + withUserByIdReturning(Either.Right(TestUser.OTHER)) + withSelfUserReturning(null) + } + + oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) + .shouldFail { + assertIs(it) + assertIs(it.rootCause) + } + } + + @Test + fun givenFailureToFindOtherUser_thenShouldPropagateFailure() = runTest { + val failure = StorageFailure.DataNotFound + val (_, oneOnOneProtocolSelector) = arrange { + withSelfUserReturning(TestUser.SELF) + withUserByIdReturning(Either.Left(failure)) + } + + oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) + .shouldFail { + assertEquals(failure, it) + } + } + + @Test + fun givenOtherUserId_thenShouldCallRepoWithCorrectUserId() = runTest { + val failure = StorageFailure.DataNotFound + val (arrangement, oneOnOneProtocolSelector) = arrange { + withSelfUserReturning(TestUser.SELF) + withUserByIdReturning(Either.Left(failure)) + } + val otherUserId = TestUser.USER_ID + oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::userById) + .with(eq(otherUserId)) + .wasInvoked(exactly = once) + } + + @Test + fun givenBothUsersSupportProteusAndMLS_thenShouldPreferMLS() = runTest { + val supportedProtocols = setOf(SupportedProtocol.MLS, SupportedProtocol.PROTEUS) + val (arrangement, oneOnOneProtocolSelector) = arrange { + withSelfUserReturning(TestUser.SELF.copy(supportedProtocols = supportedProtocols)) + withUserByIdReturning(Either.Right(TestUser.OTHER.copy(supportedProtocols = supportedProtocols))) + } + + oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) + .shouldSucceed { + assertEquals(SupportedProtocol.MLS, it) + } + } + + @Test + fun givenBothUsersSupportProteusAndOnlyOneSupportsMLS_thenShouldPreferProteus() = runTest { + val bothProtocols = setOf(SupportedProtocol.MLS, SupportedProtocol.PROTEUS) + val (_, oneOnOneProtocolSelector) = arrange { + withSelfUserReturning(TestUser.SELF.copy(supportedProtocols = bothProtocols)) + withUserByIdReturning(Either.Right(TestUser.OTHER.copy(supportedProtocols = setOf(SupportedProtocol.PROTEUS)))) + } + + oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) + .shouldSucceed { + assertEquals(SupportedProtocol.PROTEUS, it) + } + } + + @Test + fun givenBothUsersSupportMLS_thenShouldPreferMLS() = runTest { + val mlsSet = setOf(SupportedProtocol.MLS) + val (_, oneOnOneProtocolSelector) = arrange { + withSelfUserReturning(TestUser.SELF.copy(supportedProtocols = mlsSet)) + withUserByIdReturning(Either.Right(TestUser.OTHER.copy(supportedProtocols = mlsSet))) + } + + oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) + .shouldSucceed { + assertEquals(SupportedProtocol.MLS, it) + } + } + + @Test + fun givenUsersHaveNoProtocolInCommon_thenShouldReturnNoCommonProtocol() = runTest { + val (_, oneOnOneProtocolSelector) = arrange { + withSelfUserReturning(TestUser.SELF.copy(supportedProtocols = setOf(SupportedProtocol.MLS))) + withUserByIdReturning(Either.Right(TestUser.OTHER.copy(supportedProtocols = setOf(SupportedProtocol.PROTEUS)))) + } + + oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) + .shouldFail { + assertIs(it) + } + } + + private class Arrangement(private val configure: Arrangement.() -> Unit) : + UserRepositoryArrangement by UserRepositoryArrangementImpl() { + fun arrange(): Pair = run { + configure() + this@Arrangement to OneOnOneProtocolSelectorImpl(userRepository) + } + } + + private companion object { + fun arrange(configure: Arrangement.() -> Unit) = Arrangement(configure).arrange() + } + +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/register/RegisterAccountUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/register/RegisterAccountUseCaseTest.kt index 492dd26ad43..284d378ccea 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/register/RegisterAccountUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/register/RegisterAccountUseCaseTest.kt @@ -214,7 +214,8 @@ class RegisterAccountUseCaseTest { connectionStatus = ConnectionState.ACCEPTED, previewPicture = null, completePicture = null, - availabilityStatus = UserAvailabilityStatus.NONE + availabilityStatus = UserAvailabilityStatus.NONE, + supportedProtocols = null ) val TEST_AUTH_TOKENS = AuthTokens( accessToken = "access_token", diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/SearchKnownUserUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/SearchKnownUserUseCaseTest.kt index ee3128fc633..6900e9baf7b 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/SearchKnownUserUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/SearchKnownUserUseCaseTest.kt @@ -154,7 +154,8 @@ class SearchKnownUserUseCaseTest { botService = null, deleted = false, defederated = false, - isProteusVerified = false + isProteusVerified = false, + supportedProtocols = null ) val (_, searchKnownUsersUseCase) = Arrangement() @@ -345,7 +346,8 @@ class SearchKnownUserUseCaseTest { botService = null, deleted = false, defederated = false, - isProteusVerified = false + isProteusVerified = false, + supportedProtocols = null ) ) ) @@ -381,7 +383,8 @@ class SearchKnownUserUseCaseTest { botService = null, deleted = false, defederated = false, - isProteusVerified = false + isProteusVerified = false, + supportedProtocols = null ) ) diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/SearchUserUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/SearchUserUseCaseTest.kt index dffb5733459..bff73a2cd21 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/SearchUserUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/SearchUserUseCaseTest.kt @@ -302,7 +302,8 @@ class SearchUserUseCaseTest { botService = null, deleted = false, defederated = false, - isProteusVerified = false + isProteusVerified = false, + supportedProtocols = null ) } ) diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsUseCaseTest.kt new file mode 100644 index 00000000000..9bd57bf0003 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsUseCaseTest.kt @@ -0,0 +1,375 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.user + +import com.wire.kalium.logic.StorageFailure +import com.wire.kalium.logic.configuration.UserConfigRepository +import com.wire.kalium.logic.data.client.Client +import com.wire.kalium.logic.data.client.ClientRepository +import com.wire.kalium.logic.data.featureConfig.MLSMigrationModel +import com.wire.kalium.logic.data.featureConfig.Status +import com.wire.kalium.logic.data.user.SupportedProtocol +import com.wire.kalium.logic.data.user.UserRepository +import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsUseCaseTest.Arrangement.Companion.COMPLETED_MIGRATION_CONFIGURATION +import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsUseCaseTest.Arrangement.Companion.DISABLED_MIGRATION_CONFIGURATION +import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsUseCaseTest.Arrangement.Companion.ONGOING_MIGRATION_CONFIGURATION +import com.wire.kalium.logic.featureFlags.FeatureSupport +import com.wire.kalium.logic.framework.TestClient +import com.wire.kalium.logic.framework.TestUser +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.util.shouldSucceed +import io.mockative.Mock +import io.mockative.any +import io.mockative.anything +import io.mockative.given +import io.mockative.matching +import io.mockative.mock +import io.mockative.once +import io.mockative.verify +import kotlinx.coroutines.test.runTest +import kotlinx.datetime.Clock +import kotlinx.datetime.Instant +import kotlin.test.Test + +class UpdateSupportedProtocolsUseCaseTest { + + @Test + fun givenMLSIsNotSupported_whenInvokingUseCase_thenSupportedProtocolsAreNotUpdated() = runTest { + val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(false) + .arrange() + + useCase.invoke().shouldSucceed() + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::updateSupportedProtocols) + .with(anything()) + .wasNotInvoked() + } + + @Test + fun givenSupportedProtocolsHasNotChanged_whenInvokingUseCase_thenSupportedProtocolsAreNotUpdated() = runTest { + val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) + .withGetSelfUserSuccessful(supportedProtocols = setOf(SupportedProtocol.PROTEUS)) + .withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS)) + .withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) + .withGetSelfClientsSuccessful(clients = emptyList()) + .withUpdateSupportedProtocolsSuccessful() + .arrange() + + useCase.invoke() + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::updateSupportedProtocols) + .with(anything()) + .wasNotInvoked() + } + + @Test + fun givenProteusAsSupportedProtocol_whenInvokingUseCase_thenProteusIsIncluded() = runTest { + val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) + .withGetSelfUserSuccessful() + .withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS)) + .withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) + .withGetSelfClientsSuccessful(clients = emptyList()) + .withUpdateSupportedProtocolsSuccessful() + .arrange() + + useCase.invoke() + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::updateSupportedProtocols) + .with(matching { it.contains(SupportedProtocol.PROTEUS) }) + .wasInvoked(exactly = once) + } + + @Test + fun givenProteusIsNotSupportedButMigrationHasNotEnded_whenInvokingUseCase_thenProteusIsIncluded() = runTest { + val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) + .withGetSelfUserSuccessful() + .withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS)) + .withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) + .withGetSelfClientsSuccessful(clients = emptyList()) + .withUpdateSupportedProtocolsSuccessful() + .arrange() + + useCase.invoke() + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::updateSupportedProtocols) + .with(matching { it.contains(SupportedProtocol.PROTEUS) }) + .wasInvoked(exactly = once) + } + + @Test + fun givenProteusIsNotSupported_whenInvokingUseCase_thenProteusIsNotIncluded() = runTest { + val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) + .withGetSelfUserSuccessful() + .withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS)) + .withGetMigrationConfigurationSuccessful(COMPLETED_MIGRATION_CONFIGURATION) + .withGetSelfClientsSuccessful(clients = emptyList()) + .withUpdateSupportedProtocolsSuccessful() + .arrange() + + useCase.invoke() + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::updateSupportedProtocols) + .with(matching { !it.contains(SupportedProtocol.PROTEUS) }) + .wasInvoked(exactly = once) + } + + @Test + fun givenMlsIsSupportedAndAllActiveClientsAreCapable_whenInvokingUseCase_thenMlsIsIncluded() = runTest { + val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) + .withGetSelfUserSuccessful() + .withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS)) + .withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) + .withGetSelfClientsSuccessful(clients = listOf( + TestClient.CLIENT.copy(isMLSCapable = true, lastActive = Clock.System.now()) + )) + .withUpdateSupportedProtocolsSuccessful() + .arrange() + + useCase.invoke() + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::updateSupportedProtocols) + .with(matching { it.contains(SupportedProtocol.MLS) }) + .wasInvoked(exactly = once) + } + + @Test + fun givenMlsIsSupportedAndAnInactiveClientIsNotMlsCapable_whenInvokingUseCase_thenMlsIsIncluded() = runTest { + val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) + .withGetSelfUserSuccessful() + .withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS)) + .withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) + .withGetSelfClientsSuccessful(clients = listOf( + TestClient.CLIENT.copy(isMLSCapable = true, lastActive = Clock.System.now()), + TestClient.CLIENT.copy(isMLSCapable = false, lastActive = Instant.DISTANT_PAST) + )) + .withUpdateSupportedProtocolsSuccessful() + .arrange() + + useCase.invoke() + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::updateSupportedProtocols) + .with(matching { it.contains(SupportedProtocol.MLS) }) + .wasInvoked(exactly = once) + } + + @Test + fun givenMlsIsSupportedAndAllActiveClientsAreNotCapable_whenInvokingUseCase_thenMlsIsNotIncluded() = runTest { + val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) + .withGetSelfUserSuccessful() + .withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS)) + .withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) + .withGetSelfClientsSuccessful(clients = listOf( + TestClient.CLIENT.copy(isMLSCapable = true, lastActive = Clock.System.now()), + TestClient.CLIENT.copy(isMLSCapable = false, lastActive = Clock.System.now()) + )) + .withUpdateSupportedProtocolsSuccessful() + .arrange() + + useCase.invoke() + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::updateSupportedProtocols) + .with(matching { !it.contains(SupportedProtocol.MLS) }) + .wasInvoked(exactly = once) + } + + @Test + fun givenMlsIsSupportedAndMigrationHasEnded_whenInvokingUseCase_thenMlsIsIncluded() = runTest { + val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) + .withGetSelfUserSuccessful() + .withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS)) + .withGetMigrationConfigurationSuccessful(COMPLETED_MIGRATION_CONFIGURATION) + .withGetSelfClientsSuccessful(clients = listOf( + TestClient.CLIENT.copy(isMLSCapable = true), + TestClient.CLIENT.copy(isMLSCapable = false) + )) + .withUpdateSupportedProtocolsSuccessful() + .arrange() + + useCase.invoke() + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::updateSupportedProtocols) + .with(matching { it.contains(SupportedProtocol.MLS) }) + .wasInvoked(exactly = once) + } + + @Test + fun givenMigrationIsMissingAndAllClientsAreCapable_whenInvokingUseCase_thenMlsIsIncluded() = runTest { + val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) + .withGetSelfUserSuccessful() + .withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS, SupportedProtocol.MLS)) + .withGetMigrationConfigurationFailing(StorageFailure.DataNotFound) + .withGetSelfClientsSuccessful(clients = listOf( + TestClient.CLIENT.copy(isMLSCapable = true) + )) + .withUpdateSupportedProtocolsSuccessful() + .arrange() + + useCase.invoke() + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::updateSupportedProtocols) + .with(matching { it.contains(SupportedProtocol.MLS) }) + .wasInvoked(exactly = once) + } + + @Test + fun givenMlsIsNotSupportedAndAllClientsAreCapable_whenInvokingUseCase_thenMlsIsNotIncluded() = runTest { + val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) + .withGetSelfUserSuccessful() + .withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS)) + .withGetMigrationConfigurationSuccessful(DISABLED_MIGRATION_CONFIGURATION) + .withGetSelfClientsSuccessful(clients = listOf( + TestClient.CLIENT.copy(isMLSCapable = true) + )) + .withUpdateSupportedProtocolsSuccessful() + .arrange() + + useCase.invoke() + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::updateSupportedProtocols) + .with(matching { !it.contains(SupportedProtocol.MLS) }) + .wasInvoked(exactly = once) + } + + @Test + fun givenSupportedProtocolsAreNotConfigured_whenInvokingUseCase_thenSupportedProtocolsAreNotUpdated() = runTest { + val (arrangement, useCase) = Arrangement() + .withIsMLSSupported(true) + .withGetSelfUserSuccessful(supportedProtocols = setOf(SupportedProtocol.PROTEUS)) + .withGetSupportedProtocolsFailing(StorageFailure.DataNotFound) + .withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) + .withGetSelfClientsSuccessful(clients = emptyList()) + .withUpdateSupportedProtocolsSuccessful() + .arrange() + + useCase.invoke().shouldSucceed() + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::updateSupportedProtocols) + .with(anything()) + .wasNotInvoked() + } + + private class Arrangement { + @Mock + val clientRepository = mock(ClientRepository::class) + @Mock + val userRepository = mock(UserRepository::class) + @Mock + val userConfigRepository = mock(UserConfigRepository::class) + @Mock + val featureSupport = mock(FeatureSupport::class) + + fun withIsMLSSupported(supported: Boolean) = apply { + given(featureSupport) + .invocation { featureSupport.isMLSSupported } + .thenReturn(supported) + } + + fun withGetSelfUserSuccessful(supportedProtocols: Set? = null) = apply { + given(userRepository) + .suspendFunction(userRepository::getSelfUser) + .whenInvoked() + .thenReturn(TestUser.SELF.copy( + supportedProtocols = supportedProtocols + )) + } + + fun withUpdateSupportedProtocolsSuccessful() = apply { + given(userRepository) + .suspendFunction(userRepository::updateSupportedProtocols) + .whenInvokedWith(any()) + .thenReturn(Either.Right(Unit)) + } + + fun withGetMigrationConfigurationSuccessful(migrationConfiguration: MLSMigrationModel) = apply { + given(userConfigRepository) + .suspendFunction(userConfigRepository::getMigrationConfiguration) + .whenInvoked() + .thenReturn(Either.Right(migrationConfiguration)) + } + + fun withGetMigrationConfigurationFailing(failure: StorageFailure) = apply { + given(userConfigRepository) + .suspendFunction(userConfigRepository::getMigrationConfiguration) + .whenInvoked() + .thenReturn(Either.Left(failure)) + } + + fun withGetSupportedProtocolsSuccessful(supportedProtocols: Set) = apply { + given(userConfigRepository) + .suspendFunction(userConfigRepository::getSupportedProtocols) + .whenInvoked() + .thenReturn(Either.Right(supportedProtocols)) + } + + fun withGetSupportedProtocolsFailing(failure: StorageFailure) = apply { + given(userConfigRepository) + .suspendFunction(userConfigRepository::getSupportedProtocols) + .whenInvoked() + .thenReturn(Either.Left(failure)) + } + + fun withGetSelfClientsSuccessful(clients: List) = apply { + given(clientRepository) + .suspendFunction(clientRepository::selfListOfClients) + .whenInvoked() + .thenReturn(Either.Right(clients)) + } + + fun arrange() = this to UpdateSupportedProtocolsUseCaseImpl( + clientRepository, + userRepository, + userConfigRepository, + featureSupport + ) + + companion object { + val ONGOING_MIGRATION_CONFIGURATION = MLSMigrationModel( + Instant.DISTANT_PAST, + Instant.DISTANT_FUTURE, + Status.ENABLED + ) + val COMPLETED_MIGRATION_CONFIGURATION = ONGOING_MIGRATION_CONFIGURATION + .copy(endTime = Instant.DISTANT_PAST) + val DISABLED_MIGRATION_CONFIGURATION = ONGOING_MIGRATION_CONFIGURATION + .copy(status = Status.DISABLED) + } + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/UploadUserAvatarUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/UploadUserAvatarUseCaseTest.kt index 0be4c2ff7d7..ddbb9b65fe6 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/UploadUserAvatarUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/UploadUserAvatarUseCaseTest.kt @@ -24,6 +24,7 @@ import com.wire.kalium.logic.data.asset.KaliumFileSystem import com.wire.kalium.logic.data.asset.UploadedAssetId import com.wire.kalium.logic.data.user.ConnectionState import com.wire.kalium.logic.data.user.SelfUser +import com.wire.kalium.logic.data.user.SupportedProtocol import com.wire.kalium.logic.data.user.UserAssetId import com.wire.kalium.logic.data.user.UserAvailabilityStatus import com.wire.kalium.logic.data.user.UserId @@ -129,7 +130,9 @@ class UploadUserAvatarUseCaseTest { ConnectionState.ACCEPTED, UserAssetId("value1", "domain"), UserAssetId("value2", "domain"), - UserAvailabilityStatus.NONE + UserAvailabilityStatus.NONE, + null, + setOf(SupportedProtocol.PROTEUS) ) fun withStoredData(data: ByteArray, dataNamePath: Path): Arrangement { diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestClient.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestClient.kt index aae0b69a4db..2754bc7e86d 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestClient.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestClient.kt @@ -38,7 +38,8 @@ object TestClient { label = "label", isVerified = false, isValid = true, - mlsPublicKeys = null + mlsPublicKeys = null, + isMLSCapable = false ) val SELF_USER_ID = UserId("self-user-id", "domain") diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestConversation.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestConversation.kt index faca03c87fa..76694a38ca8 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestConversation.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestConversation.kt @@ -57,12 +57,12 @@ object TestConversation { val ID = ConversationId(conversationValue, conversationDomain) fun id(suffix: Int = 0) = ConversationId("${conversationValue}_$suffix", conversationDomain) - val ONE_ON_ONE = Conversation( + fun ONE_ON_ONE(protocolInfo: ProtocolInfo = ProtocolInfo.Proteus) = Conversation( ID.copy(value = "1O1 ID"), "ONE_ON_ONE Name", Conversation.Type.ONE_ON_ONE, TestTeam.TEAM_ID, - ProtocolInfo.Proteus, + protocolInfo, MutedConversationStatus.AllAllowed, null, null, @@ -163,7 +163,9 @@ object TestConversation { userDefederated = null, archived = false, archivedDateTime = null, - verificationStatus = ConversationEntity.VerificationStatus.NOT_VERIFIED + verificationStatus = ConversationEntity.VerificationStatus.NOT_VERIFIED, + userSupportedProtocols = null, + userActiveOneOnOneConversationId = null, ) fun one_on_one(convId: ConversationId) = Conversation( @@ -316,7 +318,27 @@ object TestConversation { userDefederated = null, archived = false, archivedDateTime = null, - verificationStatus = ConversationEntity.VerificationStatus.NOT_VERIFIED + verificationStatus = ConversationEntity.VerificationStatus.NOT_VERIFIED, + userSupportedProtocols = null, + userActiveOneOnOneConversationId = null, + ) + + val MLS_PROTOCOL_INFO = ProtocolInfo.MLS( + GROUP_ID, + ProtocolInfo.MLSCapable.GroupState.PENDING_JOIN, + 0UL, + Instant.parse("2021-03-30T15:36:00.000Z"), + cipherSuite = Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 + ) + + val PROTEUS_PROTOCOL_INFO = ProtocolInfo.Proteus + + val MIXED_PROTOCOL_INFO = ProtocolInfo.Mixed( + GROUP_ID, + ProtocolInfo.MLSCapable.GroupState.PENDING_JOIN, + 0UL, + Instant.parse("2021-03-30T15:36:00.000Z"), + cipherSuite = Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 ) val CONVERSATION = Conversation( @@ -341,13 +363,6 @@ object TestConversation { verificationStatus = Conversation.VerificationStatus.NOT_VERIFIED ) - val MLS_PROTOCOL_INFO = ProtocolInfo.MLS( - GROUP_ID, - ProtocolInfo.MLS.GroupState.PENDING_JOIN, - 0UL, - Instant.parse("2021-03-30T15:36:00.000Z"), - cipherSuite = Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 - ) val MLS_CONVERSATION = Conversation( ConversationId("conv_id", "domain"), "MLS Name", @@ -371,4 +386,7 @@ object TestConversation { ) val CONVERSATION_CODE_INFO: ConversationCodeInfo = ConversationCodeInfo("conv_id_value", "name") + val MIXED_CONVERSATION = MLS_CONVERSATION.copy( + protocol = MIXED_PROTOCOL_INFO + ) } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestConversationDetails.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestConversationDetails.kt index 8791d0b3256..02ac4d426f2 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestConversationDetails.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestConversationDetails.kt @@ -38,7 +38,7 @@ object TestConversationDetails { ) val CONVERSATION_ONE_ONE = ConversationDetails.OneOne( - TestConversation.ONE_ON_ONE, + TestConversation.ONE_ON_ONE(), TestUser.OTHER, LegalHoldStatus.DISABLED, UserType.EXTERNAL, diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestEvent.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestEvent.kt index c39b7e60cf7..1c85bdef3fc 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestEvent.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestEvent.kt @@ -37,6 +37,7 @@ object TestEvent { eventId, TestConversation.ID, false, + false, TestUser.USER_ID, members, "2022-03-30T15:36:00.000Z" @@ -46,6 +47,7 @@ object TestEvent { eventId, TestConversation.ID, false, + false, TestUser.USER_ID, listOf(), "2022-03-30T15:36:00.000Z" @@ -56,6 +58,7 @@ object TestEvent { TestConversation.ID, "2022-03-30T15:36:00.000Z", false, + false, member ) @@ -64,6 +67,7 @@ object TestEvent { TestConversation.ID, "2022-03-30T15:36:00.000Z", false, + false, MutedConversationStatus.AllAllowed, "2022-03-30T15:36:00.000Zp" ) @@ -74,6 +78,7 @@ object TestEvent { TestConversation.ID, "2022-03-30T15:36:00.000Z", false, + false, "2022-03-31T16:36:00.000Zp", isArchiving, ) @@ -82,20 +87,22 @@ object TestEvent { eventId, TestConversation.ID, false, + false ) - fun clientRemove(eventId: String = "eventId", clientId: ClientId) = Event.User.ClientRemove(false, eventId, clientId) - fun userDelete(eventId: String = "eventId", userId: UserId) = Event.User.UserDelete(false, eventId, userId) + fun clientRemove(eventId: String = "eventId", clientId: ClientId) = Event.User.ClientRemove(false, false, eventId, clientId) + fun userDelete(eventId: String = "eventId", userId: UserId) = Event.User.UserDelete(false, false, eventId, userId) fun updateUser(eventId: String = "eventId", userId: UserId) = Event.User.Update( eventId, - false, userId.toString(), null, false, "newName", null, null, null, null + false, false, userId, null, false, "newName", null, null, null, null, null ) fun newClient(eventId: String = "eventId", clientId: ClientId = ClientId("client")) = Event.User.NewClient( - false, eventId, TestClient.CLIENT + false, false, eventId, TestClient.CLIENT ) - fun newConnection(eventId: String = "eventId") = Event.User.NewConnection( + fun newConnection(eventId: String = "eventId", status: ConnectionState = ConnectionState.PENDING) = Event.User.NewConnection( + false, false, eventId, Connection( @@ -104,7 +111,7 @@ object TestEvent { lastUpdate = "lastUpdate", qualifiedConversationId = TestConversation.ID, qualifiedToId = TestUser.USER_ID, - status = ConnectionState.PENDING, + status = status, toId = "told?" ) ) @@ -113,6 +120,7 @@ object TestEvent { eventId, TestConversation.ID, false, + false, TestUser.USER_ID, "2022-03-30T15:36:00.000Z" ) @@ -121,6 +129,7 @@ object TestEvent { eventId, TestConversation.ID, false, + false, "newName", TestUser.USER_ID, "2022-03-30T15:36:00.000Z" @@ -130,6 +139,7 @@ object TestEvent { eventId, TestConversation.ID, false, + false, receiptMode = Conversation.ReceiptMode.ENABLED, senderUserId = TestUser.USER_ID ) @@ -139,13 +149,15 @@ object TestEvent { teamId = "teamId", name = "teamName", transient = false, + live = false, icon = "icon", ) fun teamMemberJoin(eventId: String = "eventId") = Event.Team.MemberJoin( eventId, teamId = "teamId", - false, + transient = false, + live = false, memberId = "memberId" ) @@ -154,7 +166,8 @@ object TestEvent { teamId = "teamId", memberId = "memberId", timestampIso = "2022-03-30T15:36:00.000Z", - transient = false + transient = false, + live = false ) fun teamMemberUpdate(eventId: String = "eventId", permissionCode: Int) = Event.Team.MemberUpdate( @@ -162,13 +175,15 @@ object TestEvent { teamId = "teamId", memberId = "memberId", permissionCode = permissionCode, - transient = false + transient = false, + live = false ) fun timerChanged(eventId: String = "eventId") = Event.Conversation.ConversationMessageTimer( id = eventId, conversationId = TestConversation.ID, transient = false, + live = false, messageTimer = 3000, senderUserId = TestUser.USER_ID, timestampIso = "2022-03-30T15:36:00.000Z" @@ -177,6 +192,7 @@ object TestEvent { fun userPropertyReadReceiptMode(eventId: String = "eventId") = Event.UserProperty.ReadReceiptModeSet( id = eventId, transient = false, + live = false, value = true ) @@ -188,6 +204,7 @@ object TestEvent { "eventId", TestConversation.ID, false, + false, senderUserId, TestClient.CLIENT_ID, "time", @@ -201,6 +218,7 @@ object TestEvent { "eventId", TestConversation.ID, false, + false, null, TestUser.USER_ID, timestamp.toIsoDateTimeString(), @@ -211,6 +229,7 @@ object TestEvent { id = "eventId", conversationId = TestConversation.ID, transient = false, + live = false, timestampIso = "timestamp", conversation = TestConversation.CONVERSATION_RESPONSE, senderUserId = TestUser.SELF.id @@ -220,6 +239,7 @@ object TestEvent { "eventId", TestConversation.ID, false, + false, TestUser.USER_ID, "dummy-message", timestampIso = "2022-03-30T15:36:00.000Z" @@ -230,13 +250,15 @@ object TestEvent { conversationId = TestConversation.ID, data = TestConversation.CONVERSATION_RESPONSE, qualifiedFrom = TestUser.USER_ID, - transient = false + transient = false, + live = false ) fun codeUpdated() = Event.Conversation.CodeUpdated( id = "eventId", conversationId = TestConversation.ID, transient = false, + live = false, code = "code", key = "key", uri = "uri", @@ -247,14 +269,25 @@ object TestEvent { id = "eventId", conversationId = TestConversation.ID, transient = false, + live = false ) fun typingIndicator(typingIndicatorMode: Conversation.TypingIndicatorMode) = Event.Conversation.TypingIndicator( id = "eventId", conversationId = TestConversation.ID, transient = true, + live = false, senderUserId = TestUser.OTHER_USER_ID, timestampIso = "2022-03-30T15:36:00.000Z", typingIndicatorMode = typingIndicatorMode ) + + fun newConversationProtocolEvent() = Event.Conversation.ConversationProtocol( + id = "eventId", + conversationId = TestConversation.ID, + transient = false, + live = false, + protocol = Conversation.Protocol.MIXED, + senderUserId = TestUser.OTHER_USER_ID + ) } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestUser.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestUser.kt index 62ae24be2b9..31ad4feb87d 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestUser.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/framework/TestUser.kt @@ -22,6 +22,7 @@ import com.wire.kalium.logic.data.id.TeamId import com.wire.kalium.logic.data.user.ConnectionState import com.wire.kalium.logic.data.user.OtherUser import com.wire.kalium.logic.data.user.SelfUser +import com.wire.kalium.logic.data.user.SupportedProtocol import com.wire.kalium.logic.data.user.UserAssetId import com.wire.kalium.logic.data.user.UserAvailabilityStatus import com.wire.kalium.logic.data.user.UserId @@ -30,11 +31,13 @@ import com.wire.kalium.network.api.base.authenticated.userDetails.ListUsersDTO import com.wire.kalium.network.api.base.model.AssetSizeDTO import com.wire.kalium.network.api.base.model.LegalHoldStatusResponse import com.wire.kalium.network.api.base.model.SelfUserDTO +import com.wire.kalium.network.api.base.model.SupportedProtocolDTO import com.wire.kalium.network.api.base.model.UserAssetDTO import com.wire.kalium.network.api.base.model.UserAssetTypeDTO import com.wire.kalium.network.api.base.model.UserProfileDTO import com.wire.kalium.persistence.dao.ConnectionEntity import com.wire.kalium.persistence.dao.QualifiedIDEntity +import com.wire.kalium.persistence.dao.SupportedProtocolEntity import com.wire.kalium.persistence.dao.UserAvailabilityStatusEntity import com.wire.kalium.persistence.dao.UserDetailsEntity import com.wire.kalium.persistence.dao.UserEntity @@ -67,7 +70,8 @@ object TestUser { connectionStatus = ConnectionState.ACCEPTED, previewPicture = UserAssetId("value1", "domain"), completePicture = UserAssetId("value2", "domain"), - availabilityStatus = UserAvailabilityStatus.NONE + availabilityStatus = UserAvailabilityStatus.NONE, + supportedProtocols = setOf(SupportedProtocol.PROTEUS, SupportedProtocol.MLS) ) val OTHER = OtherUser( @@ -86,7 +90,8 @@ object TestUser { botService = null, deleted = false, defederated = false, - isProteusVerified = false + isProteusVerified = false, + supportedProtocols = setOf(SupportedProtocol.PROTEUS) ) val ENTITY = UserEntity( @@ -104,8 +109,11 @@ object TestUser { userType = UserTypeEntity.EXTERNAL, botService = null, deleted = false, + hasIncompleteMetadata = false, expiresAt = null, - defederated = false + defederated = false, + supportedProtocols = setOf(SupportedProtocolEntity.MLS), + activeOneOnOneConversationId = null ) val DETAILS_ENTITY = UserDetailsEntity( @@ -125,7 +133,9 @@ object TestUser { deleted = false, expiresAt = null, defederated = false, - isProteusVerified = false + isProteusVerified = false, + supportedProtocols = setOf(SupportedProtocolEntity.MLS), + activeOneOnOneConversationId = null ) val USER_PROFILE_DTO = UserProfileDTO( @@ -143,7 +153,8 @@ object TestUser { deleted = false, expiresAt = null, nonQualifiedId = NETWORK_ID.value, - service = null + service = null, + supportedProtocols = listOf(SupportedProtocolDTO.MLS) ) val SELF_USER_DTO = SelfUserDTO( @@ -161,7 +172,8 @@ object TestUser { locale = "", managedByDTO = null, phone = null, - ssoID = null + ssoID = null, + supportedProtocols = null ) val LIST_USERS_DTO = ListUsersDTO( diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/ConversationEventReceiverTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/ConversationEventReceiverTest.kt index 5a40636e6d9..eb950571dc0 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/ConversationEventReceiverTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/ConversationEventReceiverTest.kt @@ -30,6 +30,7 @@ import com.wire.kalium.logic.sync.receiver.conversation.MemberChangeEventHandler import com.wire.kalium.logic.sync.receiver.conversation.MemberJoinEventHandler import com.wire.kalium.logic.sync.receiver.conversation.MemberLeaveEventHandler import com.wire.kalium.logic.sync.receiver.conversation.NewConversationEventHandler +import com.wire.kalium.logic.sync.receiver.conversation.ProtocolUpdateEventHandler import com.wire.kalium.logic.sync.receiver.conversation.ReceiptModeUpdateEventHandler import com.wire.kalium.logic.sync.receiver.conversation.RenamedConversationEventHandler import com.wire.kalium.logic.sync.receiver.conversation.message.NewMessageEventHandler @@ -173,12 +174,14 @@ class ConversationEventReceiverTest { fun givenMLSWelcomeEvent_whenOnEventInvoked_thenMlsWelcomeHandlerShouldBeCalled() = runTest { val mlsWelcomeEvent = TestEvent.newMLSWelcomeEvent() - val (arrangement, featureConfigEventReceiver) = Arrangement().arrange() + val (arrangement, featureConfigEventReceiver) = Arrangement() + .withMLSWelcomeEventSucceeded() + .arrange() val result = featureConfigEventReceiver.onEvent(mlsWelcomeEvent) - verify(arrangement.mLSWelcomeEventHandler) - .suspendFunction(arrangement.mLSWelcomeEventHandler::handle) + verify(arrangement.mlsWelcomeEventHandler) + .suspendFunction(arrangement.mlsWelcomeEventHandler::handle) .with(eq(mlsWelcomeEvent)) .wasInvoked(once) result.shouldSucceed() @@ -371,7 +374,7 @@ class ConversationEventReceiverTest { val renamedConversationEventHandler = mock(classOf()) @Mock - val mLSWelcomeEventHandler = mock(classOf()) + val mlsWelcomeEventHandler = mock(classOf()) @Mock val memberChangeEventHandler = mock(classOf()) @@ -394,6 +397,9 @@ class ConversationEventReceiverTest { @Mock val typingIndicatorHandler = mock(classOf()) + @Mock + val protocolUpdateEventHandler = mock(classOf()) + private val conversationEventReceiver: ConversationEventReceiver = ConversationEventReceiverImpl( newMessageHandler = newMessageEventHandler, newConversationHandler = newConversationEventHandler, @@ -401,13 +407,14 @@ class ConversationEventReceiverTest { memberJoinHandler = memberJoinEventHandler, memberLeaveHandler = memberLeaveEventHandler, memberChangeHandler = memberChangeEventHandler, - mlsWelcomeHandler = mLSWelcomeEventHandler, + mlsWelcomeHandler = mlsWelcomeEventHandler, renamedConversationHandler = renamedConversationEventHandler, receiptModeUpdateEventHandler = receiptModeUpdateEventHandler, conversationMessageTimerEventHandler = conversationMessageTimerEventHandler, codeUpdatedHandler = codeUpdatedHandler, codeDeletedHandler = codeDeletedHandler, - typingIndicatorHandler = typingIndicatorHandler + typingIndicatorHandler = typingIndicatorHandler, + protocolUpdateEventHandler = protocolUpdateEventHandler ) fun arrange(block: Arrangement.() -> Unit = {}) = apply(block).run { @@ -441,6 +448,13 @@ class ConversationEventReceiverTest { .whenInvokedWith(any()) .thenReturn(result) } + + fun withMLSWelcomeEventSucceeded() = apply { + given(mlsWelcomeEventHandler) + .suspendFunction(mlsWelcomeEventHandler::handle) + .whenInvokedWith(any()) + .thenReturn(Either.Right(Unit)) + } } companion object { diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/FeatureConfigEventReceiverTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/FeatureConfigEventReceiverTest.kt index 72dac9d63fa..e7836d14f16 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/FeatureConfigEventReceiverTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/FeatureConfigEventReceiverTest.kt @@ -24,7 +24,6 @@ import com.wire.kalium.logic.configuration.UserConfigRepository import com.wire.kalium.logic.data.event.Event import com.wire.kalium.logic.data.featureConfig.ConferenceCallingModel import com.wire.kalium.logic.data.featureConfig.ConfigsStatusModel -import com.wire.kalium.logic.data.featureConfig.MLSModel import com.wire.kalium.logic.data.featureConfig.SelfDeletingMessagesConfigModel import com.wire.kalium.logic.data.featureConfig.SelfDeletingMessagesModel import com.wire.kalium.logic.data.featureConfig.Status @@ -35,10 +34,12 @@ import com.wire.kalium.logic.feature.featureConfig.handler.E2EIConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.FileSharingConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.GuestRoomConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.MLSConfigHandler +import com.wire.kalium.logic.feature.featureConfig.handler.MLSMigrationConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.SecondFactorPasswordChallengeConfigHandler import com.wire.kalium.logic.feature.featureConfig.handler.SelfDeletingMessagesConfigHandler import com.wire.kalium.logic.feature.selfDeletingMessages.TeamSelfDeleteTimer import com.wire.kalium.logic.feature.selfDeletingMessages.TeamSettingsSelfDeletionStatus +import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsAndResolveOneOnOnesUseCase import com.wire.kalium.logic.featureFlags.KaliumConfigs import com.wire.kalium.logic.framework.TestUser import com.wire.kalium.logic.functional.Either @@ -58,53 +59,6 @@ import kotlin.time.toDuration class FeatureConfigEventReceiverTest { - @Test - fun givenMLSUpdatedEventGrantingAccessForSelfUser_whenProcessingEvent_ThenSetMLSEnabledToTrue() = runTest { - val (arrangement, featureConfigEventReceiver) = Arrangement() - .withSettingMLSEnabledSuccessful() - .arrange() - - featureConfigEventReceiver.onEvent( - arrangement.newMLSUpdatedEvent(MLSModel(listOf(TestUser.SELF.id.toPlainID()), Status.ENABLED)) - ) - - verify(arrangement.userConfigRepository) - .function(arrangement.userConfigRepository::setMLSEnabled) - .with(eq(true)) - .wasInvoked(once) - } - - @Test - fun givenMLSUpdatedEventRemovingAccessForSelfUser_whenProcessingEvent_ThenSetMLSEnabledToFalse() = runTest { - val (arrangement, featureConfigEventReceiver) = Arrangement() - .withSettingMLSEnabledSuccessful() - .arrange() - - featureConfigEventReceiver.onEvent(arrangement.newMLSUpdatedEvent(MLSModel(emptyList(), Status.ENABLED))) - - verify(arrangement.userConfigRepository) - .function(arrangement.userConfigRepository::setMLSEnabled) - .with(eq(false)) - .wasInvoked(once) - } - - @Suppress("MaxLineLength") - @Test - fun givenMLSUpdatedEventGrantingAccessForSelfUserButStatusIsDisabled_whenProcessingEvent_ThenSetMLSEnabledToFalse() = runTest { - val (arrangement, featureConfigEventReceiver) = Arrangement() - .withSettingMLSEnabledSuccessful() - .arrange() - - featureConfigEventReceiver.onEvent( - arrangement.newMLSUpdatedEvent(MLSModel(listOf(TestUser.SELF.id.toPlainID()), Status.DISABLED)) - ) - - verify(arrangement.userConfigRepository) - .function(arrangement.userConfigRepository::setMLSEnabled) - .with(eq(false)) - .wasInvoked(once) - } - @Test fun givenFileSharingUpdatedEventWithStatusEnabled_whenProcessingEvent_ThenSetFileSharingStatusToTrue() = runTest { val (arrangement, featureConfigEventReceiver) = Arrangement() @@ -332,11 +286,15 @@ class FeatureConfigEventReceiverTest { @Mock val userConfigRepository = mock(classOf()) + @Mock + val updateSupportedProtocolsAndResolveOneOnOnes = mock(classOf()) + private val featureConfigEventReceiver: FeatureConfigEventReceiver by lazy { FeatureConfigEventReceiverImpl( GuestRoomConfigHandler(userConfigRepository, kaliumConfigs), FileSharingConfigHandler(userConfigRepository), - MLSConfigHandler(userConfigRepository, TestUser.SELF.id), + MLSConfigHandler(userConfigRepository, updateSupportedProtocolsAndResolveOneOnOnes, TestUser.SELF.id), + MLSMigrationConfigHandler(userConfigRepository, updateSupportedProtocolsAndResolveOneOnOnes), ClassifiedDomainsConfigHandler(userConfigRepository), ConferenceCallingConfigHandler(userConfigRepository), SecondFactorPasswordChallengeConfigHandler(userConfigRepository), @@ -346,13 +304,6 @@ class FeatureConfigEventReceiverTest { ) } - fun withSettingMLSEnabledSuccessful() = apply { - given(userConfigRepository) - .function(userConfigRepository::setMLSEnabled) - .whenInvokedWith(any()) - .thenReturn(Either.Right(Unit)) - } - fun withSettingFileSharingEnabledSuccessful() = apply { given(userConfigRepository) .function(userConfigRepository::setFileSharingStatus) @@ -404,21 +355,17 @@ class FeatureConfigEventReceiverTest { .thenReturn(Either.Right(Unit)) } - fun newMLSUpdatedEvent( - model: MLSModel - ) = Event.FeatureConfig.MLSUpdated("eventId", false, model) - fun newFileSharingUpdatedEvent( model: ConfigsStatusModel - ) = Event.FeatureConfig.FileSharingUpdated("eventId", false, model) + ) = Event.FeatureConfig.FileSharingUpdated("eventId", false, false, model) fun newConferenceCallingUpdatedEvent( model: ConferenceCallingModel - ) = Event.FeatureConfig.ConferenceCallingUpdated("eventId", false, model) + ) = Event.FeatureConfig.ConferenceCallingUpdated("eventId", false, false, model) fun newSelfDeletingMessagesUpdatedEvent( model: SelfDeletingMessagesModel - ) = Event.FeatureConfig.SelfDeletingMessagesConfig("eventId", false, model) + ) = Event.FeatureConfig.SelfDeletingMessagesConfig("eventId", false, false, model) fun arrange() = this to featureConfigEventReceiver } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/FederationEventReceiverTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/FederationEventReceiverTest.kt index 0c2c53e712c..b375e0d890a 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/FederationEventReceiverTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/FederationEventReceiverTest.kt @@ -100,6 +100,7 @@ class FederationEventReceiverTest { val event = Event.Federation.Delete( "id", true, + false, defederatedDomain ) @@ -170,6 +171,7 @@ class FederationEventReceiverTest { val event = Event.Federation.ConnectionRemoved( "id", true, + false, listOf(defederatedDomain, defederatedDomainTwo) ) diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/ProtocolUpdateEventHandlerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/ProtocolUpdateEventHandlerTest.kt new file mode 100644 index 00000000000..7fc4f6ea68c --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/ProtocolUpdateEventHandlerTest.kt @@ -0,0 +1,129 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.sync.receiver + +import com.wire.kalium.logic.NetworkFailure +import com.wire.kalium.logic.framework.TestEvent +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.sync.receiver.conversation.ProtocolUpdateEventHandler +import com.wire.kalium.logic.sync.receiver.conversation.ProtocolUpdateEventHandlerImpl +import com.wire.kalium.logic.util.arrangement.SystemMessageInserterArrangement +import com.wire.kalium.logic.util.arrangement.SystemMessageInserterArrangementImpl +import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangementImpl +import com.wire.kalium.logic.util.shouldFail +import com.wire.kalium.logic.util.shouldSucceed +import io.mockative.eq +import io.mockative.once +import io.mockative.verify +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals + +class ProtocolUpdateEventHandlerTest { + + @Test + fun givenEventIsSuccessfullyConsumed_whenHandlerInvoked_thenProtocolIsUpdatedLocally() = runTest { + val event = TestEvent.newConversationProtocolEvent() + + val (arrangement, useCase) = arrange { + withUpdateProtocolLocally(Either.Right(true)) + withInsertProtocolChangedSystemMessage() + } + + useCase.handle(event).shouldSucceed() + + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::updateProtocolLocally) + .with(eq(event.conversationId), eq(event.protocol)) + .wasInvoked(exactly = once) + } + + @Test + fun givenEventFailsToBeConsumed_whenHandlerInvoked_thenErrorIsPropagated() = runTest { + val event = TestEvent.newConversationProtocolEvent() + val failure = NetworkFailure.NoNetworkConnection(null) + + val (arrangement, useCase) = arrange { + withUpdateProtocolLocally(Either.Left(failure)) + withInsertProtocolChangedSystemMessage() + } + + useCase.handle(event).shouldFail { + assertEquals(failure, it) + } + + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::updateProtocolLocally) + .with(eq(event.conversationId), eq(event.protocol)) + .wasInvoked(exactly = once) + } + + @Test + fun givenProtocolWasNotAlreadyUpdated_whenHandlerInvoked_thenSystemMessageIsInserted() = runTest { + val event = TestEvent.newConversationProtocolEvent() + + val (arrangement, useCase) = arrange { + withUpdateProtocolLocally(Either.Right(true)) + withInsertProtocolChangedSystemMessage() + } + + useCase.handle(event).shouldSucceed() + + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::updateProtocolLocally) + .with(eq(event.conversationId), eq(event.protocol)) + .wasInvoked(exactly = once) + } + + @Test + fun givenProtocolWasAlreadyUpdated_whenHandlerInvoked_thenSystemMessageIsNotInserted() = runTest { + val event = TestEvent.newConversationProtocolEvent() + + val (arrangement, useCase) = arrange { + withUpdateProtocolLocally(Either.Right(false)) + withInsertProtocolChangedSystemMessage() + } + + useCase.handle(event).shouldSucceed() + + verify(arrangement.systemMessageInserter) + .suspendFunction(arrangement.systemMessageInserter::insertProtocolChangedSystemMessage) + .with(eq(event.conversationId), eq(event.senderUserId), eq(event.protocol)) + .wasNotInvoked() + } + + private class Arrangement(private val block: Arrangement.() -> Unit) : + ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl(), + SystemMessageInserterArrangement by SystemMessageInserterArrangementImpl() + { + private val protocolUpdateEventHandler: ProtocolUpdateEventHandler = ProtocolUpdateEventHandlerImpl( + conversationRepository, + systemMessageInserter + ) + + fun arrange() = run { + block() + this@Arrangement to protocolUpdateEventHandler + } + } + + companion object { + private fun arrange(configuration: Arrangement.() -> Unit) = Arrangement(configuration).arrange() + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/TeamEventReceiverTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/TeamEventReceiverTest.kt index fb581038b76..77bce7c129d 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/TeamEventReceiverTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/TeamEventReceiverTest.kt @@ -18,8 +18,8 @@ package com.wire.kalium.logic.sync.receiver +import com.wire.kalium.logic.data.conversation.Conversation import com.wire.kalium.logic.data.conversation.ConversationRepository -import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.message.PersistMessageUseCase import com.wire.kalium.logic.data.team.TeamRepository import com.wire.kalium.logic.data.team.TeamRole @@ -76,7 +76,7 @@ class TeamEventReceiverTest { val event = TestEvent.teamMemberLeave() val (arrangement, eventReceiver) = Arrangement() .withMemberLeaveSuccess() - .withConversationIdsByUserId(listOf(TestConversation.ID)) + .withConversationsByUserId(listOf(TestConversation.CONVERSATION)) .withPersistMessageSuccess() .arrange() @@ -141,7 +141,8 @@ class TeamEventReceiverTest { } fun withUpdateTeamSuccess() = apply { - given(teamRepository).suspendFunction(teamRepository::updateTeam).whenInvokedWith(any()).thenReturn(Either.Right(Unit)) + given(teamRepository).suspendFunction(teamRepository::updateTeam).whenInvokedWith(any()) + .thenReturn(Either.Right(Unit)) } fun withMemberJoinSuccess() = apply { @@ -161,8 +162,8 @@ class TeamEventReceiverTest { .whenInvokedWith(any(), any(), any()).thenReturn(Either.Right(Unit)) } - fun withConversationIdsByUserId(conversationIds: List) = apply { - given(conversationRepository).suspendFunction(conversationRepository::getConversationIdsByUserId) + fun withConversationsByUserId(conversationIds: List) = apply { + given(conversationRepository).suspendFunction(conversationRepository::getConversationsByUserId) .whenInvokedWith(any()).thenReturn(Either.Right(conversationIds)) } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/UserEventReceiverTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/UserEventReceiverTest.kt index 4d5747720d1..6a0ee41f62c 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/UserEventReceiverTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/UserEventReceiverTest.kt @@ -18,21 +18,25 @@ package com.wire.kalium.logic.sync.receiver -import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.data.client.ClientRepository import com.wire.kalium.logic.data.connection.ConnectionRepository import com.wire.kalium.logic.data.conversation.ClientId +import com.wire.kalium.logic.data.conversation.Conversation import com.wire.kalium.logic.data.conversation.ConversationRepository -import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.logout.LogoutReason +import com.wire.kalium.logic.data.user.ConnectionState import com.wire.kalium.logic.data.user.UserId -import com.wire.kalium.logic.data.user.UserRepository import com.wire.kalium.logic.feature.CurrentClientIdProvider import com.wire.kalium.logic.feature.auth.LogoutUseCase import com.wire.kalium.logic.framework.TestConversation import com.wire.kalium.logic.framework.TestEvent import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.test_util.TestKaliumDispatcher +import com.wire.kalium.logic.util.arrangement.UserRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.UserRepositoryArrangementImpl +import com.wire.kalium.logic.util.arrangement.mls.OneOnOneResolverArrangement +import com.wire.kalium.logic.util.arrangement.mls.OneOnOneResolverArrangementImpl import io.mockative.Mock import io.mockative.any import io.mockative.classOf @@ -41,19 +45,23 @@ import io.mockative.given import io.mockative.mock import io.mockative.once import io.mockative.verify +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.test.advanceUntilIdle import kotlinx.coroutines.test.runTest import kotlin.test.Test import kotlin.test.assertIs +import kotlin.time.Duration.Companion.ZERO +import kotlin.time.Duration.Companion.seconds class UserEventReceiverTest { @Test fun givenRemoveClientEvent_whenTheClientIdIsEqualCurrentClient_SoftLogoutInvoked() = runTest { val event = TestEvent.clientRemove(EVENT_ID, CLIENT_ID1) - val (arrangement, eventReceiver) = Arrangement() - .withCurrentClientIdIs(CLIENT_ID1) - .withLogoutUseCaseSucceed() - .arrange() + val (arrangement, eventReceiver) = arrange { + withCurrentClientIdIs(CLIENT_ID1) + withLogoutUseCaseSucceed() + } eventReceiver.onEvent(event) @@ -66,10 +74,10 @@ class UserEventReceiverTest { @Test fun givenRemoveClientEvent_whenTheClientIdIsNotEqualCurrentClient_SoftLogoutNotInvoked() = runTest { val event = TestEvent.clientRemove(EVENT_ID, CLIENT_ID1) - val (arrangement, eventReceiver) = Arrangement() - .withCurrentClientIdIs(CLIENT_ID2) - .withLogoutUseCaseSucceed() - .arrange() + val (arrangement, eventReceiver) = arrange { + withCurrentClientIdIs(CLIENT_ID2) + withLogoutUseCaseSucceed() + } eventReceiver.onEvent(event) @@ -82,9 +90,9 @@ class UserEventReceiverTest { @Test fun givenDeleteAccountEvent_SoftLogoutInvoked() = runTest { val event = TestEvent.userDelete(userId = SELF_USER_ID) - val (arrangement, eventReceiver) = Arrangement() - .withLogoutUseCaseSucceed() - .arrange() + val (arrangement, eventReceiver) = arrange { + withLogoutUseCaseSucceed() + } eventReceiver.onEvent(event) @@ -97,10 +105,11 @@ class UserEventReceiverTest { @Test fun givenUserDeleteEvent_RepoAndPersisMessageAreInvoked() = runTest { val event = TestEvent.userDelete(userId = OTHER_USER_ID) - val (arrangement, eventReceiver) = Arrangement() - .withUserDeleteSuccess() - .withConversationIdsByUserId(listOf(TestConversation.ID)) - .arrange() + val (arrangement, eventReceiver) = arrange { + withRemoveUserSuccess() + withDeleteUserFromConversationsSuccess() + withConversationsByUserId(listOf(TestConversation.CONVERSATION)) + } eventReceiver.onEvent(event) @@ -118,9 +127,9 @@ class UserEventReceiverTest { @Test fun givenUserUpdateEvent_RepoIsInvoked() = runTest { val event = TestEvent.updateUser(userId = SELF_USER_ID) - val (arrangement, eventReceiver) = Arrangement() - .withUpdateUserSuccess() - .arrange() + val (arrangement, eventReceiver) = arrange { + withUpdateUserSuccess() + } val result = eventReceiver.onEvent(event) @@ -134,9 +143,9 @@ class UserEventReceiverTest { @Test fun givenUserUpdateEvent_whenUserIsNotFoundInLocalDB_thenShouldIgnoreThisEventFailure() = runTest { val event = TestEvent.updateUser(userId = OTHER_USER_ID) - val (_, eventReceiver) = Arrangement() - .withUpdateUserFailure(StorageFailure.DataNotFound) - .arrange() + val (_, eventReceiver) = arrange { + withUpdateUserFailure(StorageFailure.DataNotFound) + } val result = eventReceiver.onEvent(event) @@ -146,9 +155,9 @@ class UserEventReceiverTest { @Test fun givenUserUpdateEvent_whenFailsWitOtherError_thenShouldFail() = runTest { val event = TestEvent.updateUser(userId = OTHER_USER_ID) - val (_, eventReceiver) = Arrangement() - .withUpdateUserFailure(StorageFailure.Generic(Throwable("error"))) - .arrange() + val (_, eventReceiver) = arrange { + withUpdateUserFailure(StorageFailure.Generic(Throwable("error"))) + } val result = eventReceiver.onEvent(event) @@ -158,8 +167,7 @@ class UserEventReceiverTest { @Test fun givenNewClientEvent_NewClientManagerInvoked() = runTest { val event = TestEvent.newClient() - val (arrangement, eventReceiver) = Arrangement() - .arrange() + val (arrangement, eventReceiver) = arrange { } eventReceiver.onEvent(event) @@ -169,16 +177,84 @@ class UserEventReceiverTest { .wasInvoked(exactly = once) } - private class Arrangement { + @Test + fun givenNewConnectionEvent_thenConnectionIsPersisted() = runTest { + val event = TestEvent.newConnection(status = ConnectionState.PENDING) + val (arrangement, eventReceiver) = arrange { + withFetchUserInfoReturning(Either.Right(Unit)) + withInsertConnectionFromEventSucceeding() + } + + eventReceiver.onEvent(event) + + verify(arrangement.connectionRepository) + .suspendFunction(arrangement.connectionRepository::insertConnectionFromEvent) + .with(any()) + .wasInvoked(exactly = once) + } + + @Test + fun givenNewConnectionEventWithStatusPending_thenActiveOneOnOneConversationIsNotResolved() = runTest { + val event = TestEvent.newConnection(status = ConnectionState.PENDING).copy() + val (arrangement, eventReceiver) = arrange { + withFetchUserInfoReturning(Either.Right(Unit)) + withInsertConnectionFromEventSucceeding() + } + + eventReceiver.onEvent(event) + + verify(arrangement.oneOnOneResolver) + .suspendFunction(arrangement.oneOnOneResolver::resolveOneOnOneConversationWithUser) + .with(any()) + .wasNotInvoked() + } + + @Test + fun givenNewConnectionEventWithStatusAccepted_thenResolveActiveOneOnOneConversationIsScheduled() = runTest { + val event = TestEvent.newConnection(status = ConnectionState.ACCEPTED).copy() + val (arrangement, eventReceiver) = arrange { + withFetchUserInfoReturning(Either.Right(Unit)) + withInsertConnectionFromEventSucceeding() + withScheduleResolveOneOnOneConversationWithUserId() + } + + eventReceiver.onEvent(event) + + verify(arrangement.oneOnOneResolver) + .suspendFunction(arrangement.oneOnOneResolver::scheduleResolveOneOnOneConversationWithUserId) + .with(eq(event.connection.qualifiedToId), eq(ZERO)) + .wasInvoked(exactly = once) + } + + @OptIn(ExperimentalCoroutinesApi::class) + @Test + fun givenLiveNewConnectionEventWithStatusAccepted_thenResolveActiveOneOnOneConversationIsScheduledWithDelay() = runTest(TestKaliumDispatcher.default) { + val event = TestEvent.newConnection(status = ConnectionState.ACCEPTED).copy(live = true) + val (arrangement, eventReceiver) = arrange { + withFetchUserInfoReturning(Either.Right(Unit)) + withInsertConnectionFromEventSucceeding() + withScheduleResolveOneOnOneConversationWithUserId() + } + + eventReceiver.onEvent(event) + advanceUntilIdle() + + verify(arrangement.oneOnOneResolver) + .suspendFunction(arrangement.oneOnOneResolver::scheduleResolveOneOnOneConversationWithUserId) + .with(eq(event.connection.qualifiedToId), eq(3.seconds)) + .wasInvoked(exactly = once) + } + + private class Arrangement(private val block: Arrangement.() -> Unit) : + UserRepositoryArrangement by UserRepositoryArrangementImpl(), + OneOnOneResolverArrangement by OneOnOneResolverArrangementImpl() + { @Mock val connectionRepository = mock(classOf()) @Mock val logoutUseCase = mock(classOf()) - @Mock - val userRepository = mock(classOf()) - @Mock val conversationRepository = mock(classOf()) @@ -194,6 +270,7 @@ class UserEventReceiverTest { conversationRepository, userRepository, logoutUseCase, + oneOnOneResolver, SELF_USER_ID, currentClientIdProvider ) @@ -202,6 +279,13 @@ class UserEventReceiverTest { withSaveNewClientSucceeding() } + fun withInsertConnectionFromEventSucceeding() = apply { + given(connectionRepository) + .suspendFunction(connectionRepository::insertConnectionFromEvent) + .whenInvokedWith(any()) + .thenReturn(Either.Right(Unit)) + } + fun withSaveNewClientSucceeding() = apply { given(clientRepository) .suspendFunction(clientRepository::saveNewClientEvent) @@ -220,31 +304,25 @@ class UserEventReceiverTest { given(logoutUseCase).suspendFunction(logoutUseCase::invoke).whenInvokedWith(any()).thenReturn(Unit) } - fun withUpdateUserSuccess() = apply { - given(userRepository).suspendFunction(userRepository::updateUserFromEvent).whenInvokedWith(any()).thenReturn(Either.Right(Unit)) - } - - fun withUpdateUserFailure(coreFailure: CoreFailure) = apply { - given(userRepository).suspendFunction(userRepository::updateUserFromEvent) - .whenInvokedWith(any()).thenReturn(Either.Left(coreFailure)) - } - - fun withUserDeleteSuccess() = apply { - given(userRepository).suspendFunction(userRepository::removeUser) - .whenInvokedWith(any()).thenReturn(Either.Right(Unit)) + fun withDeleteUserFromConversationsSuccess() = apply { given(conversationRepository).suspendFunction(conversationRepository::deleteUserFromConversations) .whenInvokedWith(any()).thenReturn(Either.Right(Unit)) } - fun withConversationIdsByUserId(conversationIds: List) = apply { - given(conversationRepository).suspendFunction(conversationRepository::getConversationIdsByUserId) + fun withConversationsByUserId(conversationIds: List) = apply { + given(conversationRepository).suspendFunction(conversationRepository::getConversationsByUserId) .whenInvokedWith(any()).thenReturn(Either.Right(conversationIds)) } - fun arrange() = this to userEventReceiver + fun arrange() = run { + block() + this@Arrangement to userEventReceiver + } } companion object { + private fun arrange(configuration: Arrangement.() -> Unit) = Arrangement(configuration).arrange() + const val EVENT_ID = "1234" val SELF_USER_ID = UserId("alice", "wonderland") val OTHER_USER_ID = UserId("john", "public") diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/CodeDeletedHandlerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/CodeDeletedHandlerTest.kt index 9ae28b1835e..400e3034a0a 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/CodeDeletedHandlerTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/CodeDeletedHandlerTest.kt @@ -41,7 +41,8 @@ class CodeDeletedHandlerTest { val event = Event.Conversation.CodeDeleted( conversationId = ConversationId("conversationId", "domain"), id = "event-id", - transient = false + transient = false, + live = false ) handler.handle(event) diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/CodeUpdateHandlerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/CodeUpdateHandlerTest.kt index c40c6e37a7f..9c213574329 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/CodeUpdateHandlerTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/CodeUpdateHandlerTest.kt @@ -45,7 +45,8 @@ class CodeUpdateHandlerTest { code = "code", key = "key", id = "event-id", - transient = false + transient = false, + live = false ) handler.handle(event) diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MLSWelcomeEventHandlerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MLSWelcomeEventHandlerTest.kt index 03461148445..27095c69945 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MLSWelcomeEventHandlerTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MLSWelcomeEventHandlerTest.kt @@ -20,14 +20,22 @@ package com.wire.kalium.logic.sync.receiver.conversation import com.wire.kalium.cryptography.MLSClient import com.wire.kalium.cryptography.MLSGroupId import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.NetworkFailure +import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.data.client.MLSClientProvider -import com.wire.kalium.logic.data.conversation.ConversationRepository +import com.wire.kalium.logic.data.conversation.Conversation import com.wire.kalium.logic.data.event.Event +import com.wire.kalium.logic.data.id.GroupID import com.wire.kalium.logic.framework.TestConversation +import com.wire.kalium.logic.framework.TestConversationDetails import com.wire.kalium.logic.framework.TestUser import com.wire.kalium.logic.functional.Either -import com.wire.kalium.persistence.dao.conversation.ConversationDAO -import com.wire.kalium.persistence.dao.conversation.ConversationEntity +import com.wire.kalium.logic.util.arrangement.mls.OneOnOneResolverArrangement +import com.wire.kalium.logic.util.arrangement.mls.OneOnOneResolverArrangementImpl +import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangementImpl +import com.wire.kalium.logic.util.shouldFail +import com.wire.kalium.logic.util.shouldSucceed import io.ktor.util.encodeBase64 import io.mockative.Mock import io.mockative.any @@ -38,28 +46,24 @@ import io.mockative.given import io.mockative.mock import io.mockative.once import io.mockative.verify -import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.test.runTest -import kotlin.test.Ignore import kotlin.test.Test +import kotlin.test.assertEquals -@OptIn(ExperimentalCoroutinesApi::class) class MLSWelcomeEventHandlerTest { @Test fun givenMLSClientFailsProcessingOfWelcomeMessageFails_thenShouldNotMarkConversationAsEstablished() = runTest { val exception = RuntimeException() - val (arrangement, mlsWelcomeEventHandler) = Arrangement() - .withMLSClientProcessingOfWelcomeMessageFailsWith(exception) - .arrange() + val (arrangement, mlsWelcomeEventHandler) = arrange { + withMLSClientProcessingOfWelcomeMessageFailsWith(exception) + } - // TODO: make sure failure is propagated - // needs refactoring of EventReceiver - mlsWelcomeEventHandler.handle(WELCOME_EVENT) + mlsWelcomeEventHandler.handle(WELCOME_EVENT).shouldFail() - verify(arrangement.conversationDAO) - .suspendFunction(arrangement.conversationDAO::updateConversationGroupState) + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::updateConversationGroupState) .with(any(), any()) .wasNotInvoked() } @@ -67,78 +71,130 @@ class MLSWelcomeEventHandlerTest { @Test fun givenConversationFetchFails_thenShouldNotMarkConversationAsEstablished() = runTest { val failure = CoreFailure.Unknown(null) - val (arrangement, mlsWelcomeEventHandler) = Arrangement() - .withMLSClientProcessingOfWelcomeMessageReturnsSuccessfully(MLS_GROUP_ID) - .withFetchConversationIfUnknownFailingWith(failure) - .arrange() + val (arrangement, mlsWelcomeEventHandler) = arrange { + withMLSClientProcessingOfWelcomeMessageReturnsSuccessfully(MLS_GROUP_ID) + withFetchConversationIfUnknownFailingWith(failure) + } - // TODO: make sure failure is propagated - // needs refactoring of EventReceiver - mlsWelcomeEventHandler.handle(WELCOME_EVENT) + mlsWelcomeEventHandler.handle(WELCOME_EVENT).shouldFail() - verify(arrangement.conversationDAO) - .suspendFunction(arrangement.conversationDAO::updateConversationGroupState) + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::updateConversationGroupState) .with(any(), any()) .wasNotInvoked() } @Test - fun givenProcessingOfWelcomeAndConversationFetchSucceed_thenShouldMarkConversationAsEstablished() = runTest { - val (arrangement, mlsWelcomeEventHandler) = Arrangement() - .withMLSClientProcessingOfWelcomeMessageReturnsSuccessfully(MLS_GROUP_ID) - .withFetchConversationIfUnknownSucceeding() - .withUpdateGroupStateSucceeding() - .arrange() - - // TODO: make sure failure is propagated - // needs refactoring of EventReceiver - mlsWelcomeEventHandler.handle(WELCOME_EVENT) - - verify(arrangement.conversationDAO) - .suspendFunction(arrangement.conversationDAO::updateConversationGroupState) - .with(eq(ConversationEntity.GroupState.ESTABLISHED), eq(MLS_GROUP_ID)) + fun givenProcessingOfWelcomeSucceeds_thenShouldFetchConversationIfUnknown() = runTest { + val (arrangement, mlsWelcomeEventHandler) = arrange { + withMLSClientProcessingOfWelcomeMessageReturnsSuccessfully(MLS_GROUP_ID) + withFetchConversationIfUnknownSucceeding() + withUpdateGroupStateReturning(Either.Right(Unit)) + withObserveConversationDetailsByIdReturning(Either.Right(CONVERSATION_GROUP)) + } + + mlsWelcomeEventHandler.handle(WELCOME_EVENT).shouldSucceed() + + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::fetchConversationIfUnknown) + .with(eq(CONVERSATION_ID)) .wasInvoked(exactly = once) } - // TODO: Implement this test once event handler is refactored - @Ignore + @Test + fun givenProcessingOfWelcomeSucceeds_thenShouldMarkConversationAsEstablished() = runTest { + val (arrangement, mlsWelcomeEventHandler) = arrange { + withMLSClientProcessingOfWelcomeMessageReturnsSuccessfully(MLS_GROUP_ID) + withFetchConversationIfUnknownSucceeding() + withUpdateGroupStateReturning(Either.Right(Unit)) + withObserveConversationDetailsByIdReturning(Either.Right(CONVERSATION_GROUP)) + } + + mlsWelcomeEventHandler.handle(WELCOME_EVENT).shouldSucceed() + + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::updateConversationGroupState) + .with(eq(GroupID(MLS_GROUP_ID)), eq(Conversation.ProtocolInfo.MLSCapable.GroupState.ESTABLISHED)) + .wasInvoked(exactly = once) + } + + @Test + fun givenProcessingOfWelcomeForOneOnOneSucceeds_thenShouldResolveConversation() = runTest { + val (arrangement, mlsWelcomeEventHandler) = arrange { + withMLSClientProcessingOfWelcomeMessageReturnsSuccessfully(MLS_GROUP_ID) + withFetchConversationIfUnknownSucceeding() + withUpdateGroupStateReturning(Either.Right(Unit)) + withObserveConversationDetailsByIdReturning(Either.Right(CONVERSATION_ONE_ONE)) + withResolveOneOnOneConversationWithUserReturning(Either.Right(CONVERSATION_ID)) + } + + mlsWelcomeEventHandler.handle(WELCOME_EVENT).shouldSucceed() + + verify(arrangement.oneOnOneResolver) + .suspendFunction(arrangement.oneOnOneResolver::resolveOneOnOneConversationWithUser) + .with(eq(CONVERSATION_ONE_ONE.otherUser)) + .wasInvoked(exactly = once) + } + + @Test + fun givenProcessingOfWelcomeForGroupSucceeds_thenShouldNotResolveConversation() = runTest { + val (arrangement, mlsWelcomeEventHandler) = arrange { + withMLSClientProcessingOfWelcomeMessageReturnsSuccessfully(MLS_GROUP_ID) + withFetchConversationIfUnknownSucceeding() + withUpdateGroupStateReturning(Either.Right(Unit)) + withObserveConversationDetailsByIdReturning(Either.Right(CONVERSATION_GROUP)) + } + + mlsWelcomeEventHandler.handle(WELCOME_EVENT).shouldSucceed() + + verify(arrangement.oneOnOneResolver) + .suspendFunction(arrangement.oneOnOneResolver::resolveOneOnOneConversationWithUser) + .with(any()) + .wasNotInvoked() + } + @Test fun givenUpdateGroupStateFails_thenShouldPropagateError() = runTest { - val (_, mlsWelcomeEventHandler) = Arrangement() - .withMLSClientProcessingOfWelcomeMessageReturnsSuccessfully(MLS_GROUP_ID) - .withFetchConversationIfUnknownSucceeding() - .withUpdateGroupStateFailingWith(RuntimeException()) - .arrange() - mlsWelcomeEventHandler.handle(WELCOME_EVENT) + val failure = Either.Left(StorageFailure.DataNotFound) + val (_, mlsWelcomeEventHandler) = arrange { + withMLSClientProcessingOfWelcomeMessageReturnsSuccessfully(MLS_GROUP_ID) + withFetchConversationIfUnknownSucceeding() + withUpdateGroupStateReturning(failure) + } + + mlsWelcomeEventHandler.handle(WELCOME_EVENT).shouldFail { + assertEquals(failure.value, it) + } } - // TODO: Implement this test once event handler is refactored - @Ignore @Test - fun givenEverythingSucceeds_thenShouldPropagateSuccess() = runTest { - val (_, mlsWelcomeEventHandler) = Arrangement() - .withMLSClientProcessingOfWelcomeMessageReturnsSuccessfully(MLS_GROUP_ID) - .withFetchConversationIfUnknownSucceeding() - .withUpdateGroupStateSucceeding() - .arrange() - - mlsWelcomeEventHandler.handle(WELCOME_EVENT) + fun givenResolveOneOnOneConversationFails_thenShouldPropagateError() = runTest { + + val failure = Either.Left(NetworkFailure.NoNetworkConnection(null)) + val (_, mlsWelcomeEventHandler) = arrange { + withMLSClientProcessingOfWelcomeMessageReturnsSuccessfully(MLS_GROUP_ID) + withFetchConversationIfUnknownSucceeding() + withUpdateGroupStateReturning(Either.Right(Unit)) + withObserveConversationDetailsByIdReturning(Either.Right(CONVERSATION_ONE_ONE)) + withResolveOneOnOneConversationWithUserReturning(failure) + } + + mlsWelcomeEventHandler.handle(WELCOME_EVENT).shouldFail { + assertEquals(failure.value, it) + } } - private class Arrangement { + private class Arrangement(private val block: Arrangement.() -> Unit) : + ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl(), + OneOnOneResolverArrangement by OneOnOneResolverArrangementImpl() + { @Mock val mlsClient: MLSClient = mock(classOf()) @Mock val mlsClientProvider: MLSClientProvider = mock(classOf()) - @Mock - val conversationDAO: ConversationDAO = mock(classOf()) - - @Mock - val conversationRepository: ConversationRepository = mock(classOf()) - init { withMLSClientProviderReturningMLSClient() } @@ -164,49 +220,29 @@ class MLSWelcomeEventHandlerTest { .thenReturn(mlsGroupId) } - fun withFetchConversationIfUnknownFailingWith(coreFailure: CoreFailure) = apply { - given(conversationRepository) - .suspendFunction(conversationRepository::fetchConversationIfUnknown) - .whenInvokedWith(any()) - .thenReturn(Either.Left(coreFailure)) - } - - fun withFetchConversationIfUnknownSucceeding() = apply { - given(conversationRepository) - .suspendFunction(conversationRepository::fetchConversationIfUnknown) - .whenInvokedWith(any()) - .thenReturn(Either.Right(Unit)) - } - - fun withUpdateGroupStateFailingWith(exception: Exception) = apply { - given(conversationDAO) - .suspendFunction(conversationDAO::updateConversationGroupState) - .whenInvokedWith(any(), any()) - .thenThrow(exception) - } - - fun withUpdateGroupStateSucceeding() = apply { - given(conversationDAO) - .suspendFunction(conversationDAO::updateConversationGroupState) - .whenInvokedWith(any(), any()) - .thenReturn(Unit) + fun arrange() = run { + block() + this@Arrangement to MLSWelcomeEventHandlerImpl( + mlsClientProvider = mlsClientProvider, + conversationRepository = conversationRepository, + oneOnOneResolver = oneOnOneResolver + ) } - - fun arrange() = this to MLSWelcomeEventHandlerImpl( - mlsClientProvider = mlsClientProvider, - conversationDAO = conversationDAO, - conversationRepository = conversationRepository - ) } private companion object { + fun arrange(configuration: Arrangement.() -> Unit) = Arrangement(configuration).arrange() + const val MLS_GROUP_ID: MLSGroupId = "test-mlsGroupId" + val CONVERSATION_ONE_ONE = TestConversationDetails.CONVERSATION_ONE_ONE + val CONVERSATION_GROUP = TestConversationDetails.CONVERSATION_GROUP val CONVERSATION_ID = TestConversation.ID val WELCOME = "welcome".encodeToByteArray() val WELCOME_EVENT = Event.Conversation.MLSWelcome( "eventId", CONVERSATION_ID, false, + false, TestUser.USER_ID, WELCOME.encodeBase64(), timestampIso = "2022-03-30T15:36:00.000Z" diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MemberLeaveEventHandlerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MemberLeaveEventHandlerTest.kt index d5be45c2828..58682933b92 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MemberLeaveEventHandlerTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/MemberLeaveEventHandlerTest.kt @@ -129,6 +129,7 @@ class MemberLeaveEventHandlerTest { id = "id", conversationId = conversationId, transient = false, + live = false, removedBy = userId, removedList = listOf(userId), timestampIso = "timestampIso" diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/NewConversationEventHandlerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/NewConversationEventHandlerTest.kt index 5e4b0ab9cdc..6360cfd7bc7 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/NewConversationEventHandlerTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/NewConversationEventHandlerTest.kt @@ -62,6 +62,7 @@ class NewConversationEventHandlerTest { id = "eventId", conversationId = TestConversation.ID, transient = false, + live = false, timestampIso = "timestamp", conversation = TestConversation.CONVERSATION_RESPONSE, senderUserId = TestUser.SELF.id @@ -103,7 +104,8 @@ class NewConversationEventHandlerTest { val event = Event.Conversation.NewConversation( id = "eventId", conversationId = TestConversation.ID, - false, + transient =false, + live = false, timestampIso = "timestamp", conversation = TestConversation.CONVERSATION_RESPONSE, senderUserId = TestUser.SELF.id @@ -142,6 +144,7 @@ class NewConversationEventHandlerTest { id = "eventId", conversationId = TestConversation.ID, transient = false, + live = false, timestampIso = "timestamp", conversation = TestConversation.CONVERSATION_RESPONSE.copy( creator = "creatorId@creatorDomain", @@ -202,6 +205,7 @@ class NewConversationEventHandlerTest { id = "eventId", conversationId = TestConversation.ID, transient = false, + live = false, timestampIso = "timestamp", conversation = TestConversation.CONVERSATION_RESPONSE.copy( creator = "creatorId@creatorDomain", diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageUnpackerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageUnpackerTest.kt index dc892b96872..c283470f413 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageUnpackerTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageUnpackerTest.kt @@ -31,6 +31,9 @@ import com.wire.kalium.logic.feature.message.PendingProposalScheduler import com.wire.kalium.logic.framework.TestConversation import com.wire.kalium.logic.framework.TestEvent import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.getOrNull +import com.wire.kalium.logic.util.shouldFail +import com.wire.kalium.logic.util.shouldSucceed import com.wire.kalium.util.DateTimeUtil import io.ktor.util.decodeBase64Bytes import io.mockative.Mock @@ -45,10 +48,63 @@ import io.mockative.once import io.mockative.verify import kotlinx.coroutines.test.runTest import kotlin.test.Test +import kotlin.test.assertEquals import kotlin.time.Duration.Companion.seconds class MLSMessageUnpackerTest { + @Test + fun givenConversationWithProteusProtocol_whenUnpacking_thenFailWithNotSupportedByProteus() = runTest { + val eventTimestamp = DateTimeUtil.currentInstant() + + val (_, mlsUnpacker) = Arrangement() + .withMLSClientProviderReturningClient() + .withGetConversationProtocolInfoSuccessful(TestConversation.PROTEUS_PROTOCOL_INFO) + .arrange() + + val messageEvent = TestEvent.newMLSMessageEvent(eventTimestamp) + val result = mlsUnpacker.unpackMlsMessage(messageEvent) + result.shouldFail { failure -> + assertEquals(CoreFailure.NotSupportedByProteus, failure) + } + } + + @Test + fun givenConversationWithMixedProtocol_whenUnpacking_thenSucceed() = runTest { + val eventTimestamp = DateTimeUtil.currentInstant() + val commitDelay: Long = 10 + + val (_, mlsUnpacker) = Arrangement() + .withMLSClientProviderReturningClient() + .withGetConversationProtocolInfoSuccessful(TestConversation.MIXED_PROTOCOL_INFO) + .withDecryptMessageReturning(Either.Right(listOf(DECRYPTED_MESSAGE_BUNDLE.copy(commitDelay = commitDelay)))) + .arrange() + + val messageEvent = TestEvent.newMLSMessageEvent(eventTimestamp) + val result = mlsUnpacker.unpackMlsMessage(messageEvent) + result.shouldSucceed() + + assertEquals(listOf(MessageUnpackResult.HandshakeMessage), result.getOrNull()) + } + + @Test + fun givenConversationWithMLSProtocol_whenUnpacking_thenSucceed() = runTest { + val eventTimestamp = DateTimeUtil.currentInstant() + val commitDelay: Long = 10 + + val (_, mlsUnpacker) = Arrangement() + .withMLSClientProviderReturningClient() + .withGetConversationProtocolInfoSuccessful(TestConversation.MLS_PROTOCOL_INFO) + .withDecryptMessageReturning(Either.Right(listOf(DECRYPTED_MESSAGE_BUNDLE.copy(commitDelay = commitDelay)))) + .arrange() + + val messageEvent = TestEvent.newMLSMessageEvent(eventTimestamp) + val result = mlsUnpacker.unpackMlsMessage(messageEvent) + result.shouldSucceed() + + assertEquals(listOf(MessageUnpackResult.HandshakeMessage), result.getOrNull()) + } + @Test fun givenNewMLSMessageEventWithProposal_whenUnpacking_thenScheduleProposalTimer() = runTest { val eventTimestamp = DateTimeUtil.currentInstant() @@ -145,7 +201,6 @@ class MLSMessageUnpackerTest { } fun arrange() = this to mlsMessageUnpacker - } companion object { val SELF_USER_ID = UserId("user-id", "domain") diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandlerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandlerTest.kt deleted file mode 100644 index 055f37d99b7..00000000000 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandlerTest.kt +++ /dev/null @@ -1,233 +0,0 @@ -/* - * Wire - * Copyright (C) 2023 Wire Swiss GmbH - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see http://www.gnu.org/licenses/. - */ -package com.wire.kalium.logic.sync.receiver.conversation.message - -import com.wire.kalium.logic.CoreFailure -import com.wire.kalium.logic.StorageFailure -import com.wire.kalium.logic.data.conversation.Conversation -import com.wire.kalium.logic.data.conversation.ConversationRepository -import com.wire.kalium.logic.data.message.MessageContent -import com.wire.kalium.logic.data.message.PersistMessageUseCase -import com.wire.kalium.logic.feature.conversation.JoinExistingMLSConversationUseCase -import com.wire.kalium.logic.framework.TestConversation -import com.wire.kalium.logic.framework.TestUser -import com.wire.kalium.logic.functional.Either -import com.wire.kalium.logic.util.thenReturnSequentially -import io.mockative.Mock -import io.mockative.any -import io.mockative.classOf -import io.mockative.eq -import io.mockative.given -import io.mockative.matching -import io.mockative.mock -import io.mockative.once -import io.mockative.verify -import kotlinx.coroutines.test.runTest -import kotlin.test.Test - -class MLSWrongEpochHandlerTest { - - @Test - fun givenConversationIsNotMLS_whenHandlingEpochFailure_thenShouldNotInsertWarning() = runTest { - val (arrangement, mlsWrongEpochHandler) = Arrangement() - .withProtocolByIdReturningSequence(Either.Right(proteusProtocol)) - .arrange() - - mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, "date") - - verify(arrangement.persistMessageUseCase) - .suspendFunction(arrangement.persistMessageUseCase::invoke) - .with(any()) - .wasNotInvoked() - } - - @Test - fun givenConversationIsNotMLS_whenHandlingEpochFailure_thenShouldNotFetchConversationAgain() = runTest { - val (arrangement, mlsWrongEpochHandler) = Arrangement() - .withProtocolByIdReturningSequence(Either.Right(proteusProtocol)) - .arrange() - - mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, "date") - - verify(arrangement.conversationRepository) - .suspendFunction(arrangement.conversationRepository::fetchConversation) - .with(any()) - .wasNotInvoked() - } - - @Test - fun givenMLSConversation_whenHandlingEpochFailure_thenShouldFetchConversationAgain() = runTest { - val (arrangement, mlsWrongEpochHandler) = Arrangement() - .withProtocolByIdReturning(Either.Right(mlsProtocol)) - .arrange() - - mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, "date") - - verify(arrangement.conversationRepository) - .suspendFunction(arrangement.conversationRepository::fetchConversation) - .with(eq(conversationId)) - .wasInvoked(exactly = once) - } - - @Test - fun givenUpdatedMLSConversationHasDifferentEpoch_whenHandlingEpochFailure_thenShouldRejoinTheConversation() = runTest { - val (arrangement, mlsWrongEpochHandler) = Arrangement() - .withProtocolByIdReturningSequence( - Either.Right(mlsProtocol), - Either.Right(mlsProtocolWithUpdatedEpoch) - ) - .arrange() - - mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, "date") - - verify(arrangement.joinExistingMLSConversationUseCase) - .suspendFunction(arrangement.joinExistingMLSConversationUseCase::invoke) - .with(eq(conversationId)) - .wasInvoked(exactly = once) - } - - @Test - fun givenUpdatedMLSConversationHasSameEpoch_whenHandlingEpochFailure_thenShouldNotRejoinTheConversation() = runTest { - val (arrangement, mlsWrongEpochHandler) = Arrangement() - .withProtocolByIdReturning(Either.Right(mlsProtocol)) - .arrange() - - mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, "date") - - verify(arrangement.joinExistingMLSConversationUseCase) - .suspendFunction(arrangement.joinExistingMLSConversationUseCase::invoke) - .with(any()) - .wasNotInvoked() - } - - @Test - fun givenRejoiningFails_whenHandlingEpochFailure_thenShouldNotPersistAnyMessage() = runTest { - val (arrangement, mlsWrongEpochHandler) = Arrangement() - .withProtocolByIdReturningSequence( - Either.Right(mlsProtocol), - Either.Right(mlsProtocolWithUpdatedEpoch) - ) - .withJoinExistingConversationReturning(Either.Left(CoreFailure.Unknown(null))) - .arrange() - - mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, "date") - - verify(arrangement.persistMessageUseCase) - .suspendFunction(arrangement.persistMessageUseCase::invoke) - .with(any()) - .wasNotInvoked() - } - - @Test - fun givenConversationIsRejoined_whenHandlingEpochFailure_thenShouldInsertMLSWarningWithCorrectDateAndConversation() = runTest { - val date = "date" - val (arrangement, mlsWrongEpochHandler) = Arrangement() - .withProtocolByIdReturningSequence( - Either.Right(mlsProtocol), - Either.Right(mlsProtocolWithUpdatedEpoch) - ) - .arrange() - - mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, date) - - verify(arrangement.persistMessageUseCase) - .suspendFunction(arrangement.persistMessageUseCase::invoke) - .with( - matching { - it.conversationId == conversationId && - it.content == MessageContent.MLSWrongEpochWarning && - it.date == date - } - ) - .wasInvoked(exactly = once) - } - - private class Arrangement { - - @Mock - val persistMessageUseCase = mock(classOf()) - - @Mock - val conversationRepository = mock(classOf()) - - @Mock - val joinExistingMLSConversationUseCase = mock(classOf()) - - init { - withFetchByIdSucceeding() - withPersistMessageSucceeding() - withJoinExistingConversationSucceeding() - } - - fun withFetchByIdReturning(result: Either) = apply { - given(conversationRepository) - .suspendFunction(conversationRepository::fetchConversation) - .whenInvokedWith(any()) - .thenReturn(result) - } - - fun withFetchByIdSucceeding() = withFetchByIdReturning(Either.Right(Unit)) - - fun withProtocolByIdReturning(result: Either) = apply { - given(conversationRepository) - .suspendFunction(conversationRepository::getConversationProtocolInfo) - .whenInvokedWith(any()) - .thenReturn(result) - } - - fun withProtocolByIdReturningSequence(vararg results: Either) = apply { - given(conversationRepository) - .suspendFunction(conversationRepository::getConversationProtocolInfo) - .whenInvokedWith(any()) - .thenReturnSequentially(*results) - } - - fun withPersistMessageReturning(result: Either) = apply { - given(persistMessageUseCase) - .suspendFunction(persistMessageUseCase::invoke) - .whenInvokedWith(any()) - .thenReturn(result) - } - - fun withPersistMessageSucceeding() = withPersistMessageReturning(Either.Right(Unit)) - - fun withJoinExistingConversationReturning(result: Either) = apply { - given(joinExistingMLSConversationUseCase) - .suspendFunction(joinExistingMLSConversationUseCase::invoke) - .whenInvokedWith(any()) - .thenReturn(result) - } - - fun withJoinExistingConversationSucceeding() = withJoinExistingConversationReturning(Either.Right(Unit)) - - fun arrange() = this to MLSWrongEpochHandlerImpl( - TestUser.SELF.id, - persistMessageUseCase, - conversationRepository, - joinExistingMLSConversationUseCase - ) - } - - private companion object { - val conversationId = TestConversation.CONVERSATION.id - val proteusProtocol = Conversation.ProtocolInfo.Proteus - - val mlsProtocol = TestConversation.MLS_CONVERSATION.protocol as Conversation.ProtocolInfo.MLS - val mlsProtocolWithUpdatedEpoch = mlsProtocol.copy(epoch = mlsProtocol.epoch + 1U) - } -} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandlerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandlerTest.kt index 2676447ea60..12ff3697e30 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandlerTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandlerTest.kt @@ -27,6 +27,7 @@ import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.message.MessageContent import com.wire.kalium.logic.data.message.ProtoContent import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.feature.message.StaleEpochVerifier import com.wire.kalium.logic.feature.message.ephemeral.EphemeralMessageDeletionHandler import com.wire.kalium.logic.framework.TestEvent import com.wire.kalium.logic.functional.Either @@ -41,12 +42,11 @@ import io.mockative.given import io.mockative.mock import io.mockative.once import io.mockative.verify -import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.test.runTest import kotlinx.datetime.Instant +import kotlinx.datetime.toInstant import kotlin.test.Test -@OptIn(ExperimentalCoroutinesApi::class) class NewMessageEventHandlerTest { @Test @@ -284,15 +284,16 @@ class NewMessageEventHandlerTest { fun givenMLSEventFailsWithWrongEpoch_whenHandling_shouldCallWrongEpochHandler() = runTest { val (arrangement, newMessageEventHandler) = Arrangement() .withMLSUnpackerReturning(Either.Left(MLSFailure.WrongEpoch)) + .withVerifyEpoch(Either.Right(Unit)) .arrange() val newMessageEvent = TestEvent.newMLSMessageEvent(DateTimeUtil.currentInstant()) newMessageEventHandler.handleNewMLSMessage(newMessageEvent) - verify(arrangement.mlsWrongEpochHandler) - .suspendFunction(arrangement.mlsWrongEpochHandler::onMLSWrongEpoch) - .with(eq(newMessageEvent.conversationId),eq(newMessageEvent.timestampIso)) + verify(arrangement.staleEpochVerifier) + .suspendFunction(arrangement.staleEpochVerifier::verifyEpoch) + .with(eq(newMessageEvent.conversationId),eq(newMessageEvent.timestampIso.toInstant())) .wasInvoked(exactly = once) } @@ -300,6 +301,7 @@ class NewMessageEventHandlerTest { fun givenMLSEventFailsWithWrongEpoch_whenHandling_shouldNotPersistDecryptionErrorMessage() = runTest { val (arrangement, newMessageEventHandler) = Arrangement() .withMLSUnpackerReturning(Either.Left(MLSFailure.WrongEpoch)) + .withVerifyEpoch(Either.Right(Unit)) .arrange() val newMessageEvent = TestEvent.newMLSMessageEvent(DateTimeUtil.currentInstant()) @@ -326,7 +328,7 @@ class NewMessageEventHandlerTest { } @Mock - val mlsWrongEpochHandler = mock(classOf()) + val staleEpochVerifier = mock(classOf()) @Mock val ephemeralMessageDeletionHandler = mock(EphemeralMessageDeletionHandler::class) @@ -337,7 +339,7 @@ class NewMessageEventHandlerTest { applicationMessageHandler, { conversationId, messageId -> ephemeralMessageDeletionHandler.startSelfDeletion(conversationId, messageId) }, SELF_USER_ID, - mlsWrongEpochHandler + staleEpochVerifier ) fun withProteusUnpackerReturning(result: Either) = apply { @@ -354,6 +356,13 @@ class NewMessageEventHandlerTest { .thenReturn(result) } + fun withVerifyEpoch(result: Either) = apply { + given(staleEpochVerifier) + .suspendFunction(staleEpochVerifier::verifyEpoch) + .whenInvokedWith(any()) + .thenReturn(result) + } + fun arrange() = this to newMessageEventHandler } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/ReceiptMessageHandlerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/ReceiptMessageHandlerTest.kt index ee32eb25f12..4d28cba895b 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/ReceiptMessageHandlerTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/ReceiptMessageHandlerTest.kt @@ -64,8 +64,8 @@ class ReceiptMessageHandlerTest { private val receiptMessageHandler = ReceiptMessageHandlerImpl(SELF_USER_ID, receiptRepository, messageRepository) private suspend fun insertTestData() { - userDatabase.builder.userDAO.insertUser(TestUser.ENTITY.copy(id = SELF_USER_ID_ENTITY)) - userDatabase.builder.userDAO.insertUser(TestUser.ENTITY.copy(id = OTHER_USER_ID_ENTITY)) + userDatabase.builder.userDAO.upsertUser(TestUser.ENTITY.copy(id = SELF_USER_ID_ENTITY)) + userDatabase.builder.userDAO.upsertUser(TestUser.ENTITY.copy(id = OTHER_USER_ID_ENTITY)) userDatabase.builder.conversationDAO.insertConversation(CONVERSATION_ENTITY) userDatabase.builder.messageDAO.insertOrIgnoreMessage(MESSAGE_ENTITY) } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/slow/SlowSyncWorkerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/slow/SlowSyncWorkerTest.kt index cf6a77cf05f..cde7df34b0c 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/slow/SlowSyncWorkerTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/slow/SlowSyncWorkerTest.kt @@ -23,10 +23,12 @@ import com.wire.kalium.logic.data.sync.SlowSyncStep import com.wire.kalium.logic.feature.connection.SyncConnectionsUseCase import com.wire.kalium.logic.feature.conversation.JoinExistingMLSConversationsUseCase import com.wire.kalium.logic.feature.conversation.SyncConversationsUseCase +import com.wire.kalium.logic.feature.conversation.mls.OneOnOneResolver import com.wire.kalium.logic.feature.featureConfig.SyncFeatureConfigsUseCase import com.wire.kalium.logic.feature.team.SyncSelfTeamUseCase import com.wire.kalium.logic.feature.user.SyncContactsUseCase import com.wire.kalium.logic.feature.user.SyncSelfUserUseCase +import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsUseCase import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.sync.KaliumSyncException import com.wire.kalium.logic.test_util.TestKaliumDispatcher @@ -52,12 +54,14 @@ class SlowSyncWorkerTest { fun givenSuccess_whenPerformingSlowSync_thenRunAllUseCases() = runTest(TestKaliumDispatcher.default) { val (arrangement, worker) = Arrangement() .withSyncSelfUserSuccess() + .withUpdateSupportedProtocolsSuccess() .withSyncFeatureConfigsSuccess() .withSyncConversationsSuccess() .withSyncConnectionsSuccess() .withSyncSelfTeamSuccess() .withSyncContactsSuccess() .withJoinMLSConversationsSuccess() + .withResolveOneOnOneConversationsSuccess() .arrange() worker.slowSyncStepsFlow().collect() @@ -118,11 +122,47 @@ class SlowSyncWorkerTest { .wasNotInvoked() } + @Test + fun givenUpdateSupportedProtocolsFails_whenPerformingSlowSync_thenThrowSyncException() = runTest(TestKaliumDispatcher.default) { + val steps = hashSetOf(SlowSyncStep.SELF_USER, SlowSyncStep.FEATURE_FLAGS, SlowSyncStep.UPDATE_SUPPORTED_PROTOCOLS) + val (arrangement, worker) = Arrangement() + .withSyncSelfUserSuccess() + .withSyncFeatureConfigsSuccess() + .withUpdateSupportedProtocolsFailure() + .arrange() + + assertFailsWith { + worker.slowSyncStepsFlow().collect { + assertTrue { + it in steps + } + } + } + + verify(arrangement.syncSelfUser) + .suspendFunction(arrangement.syncSelfUser::invoke) + .wasInvoked(exactly = once) + + verify(arrangement.syncFeatureConfigs) + .suspendFunction(arrangement.syncFeatureConfigs::invoke) + .wasInvoked(exactly = once) + + verify(arrangement.syncConversations) + .suspendFunction(arrangement.syncConversations::invoke) + .wasNotInvoked() + } + @Test fun givenSyncConversationsFails_whenPerformingSlowSync_thenThrowSyncException() = runTest(TestKaliumDispatcher.default) { - val steps = hashSetOf(SlowSyncStep.SELF_USER, SlowSyncStep.FEATURE_FLAGS, SlowSyncStep.CONVERSATIONS) + val steps = hashSetOf( + SlowSyncStep.SELF_USER, + SlowSyncStep.UPDATE_SUPPORTED_PROTOCOLS, + SlowSyncStep.FEATURE_FLAGS, + SlowSyncStep.CONVERSATIONS + ) val (arrangement, worker) = Arrangement() .withSyncSelfUserSuccess() + .withUpdateSupportedProtocolsSuccess() .withSyncFeatureConfigsSuccess() .withSyncConversationsFailure() .arrange() @@ -143,6 +183,10 @@ class SlowSyncWorkerTest { .suspendFunction(arrangement.syncFeatureConfigs::invoke) .wasInvoked(exactly = once) + verify(arrangement.updateSupportedProtocols) + .suspendFunction(arrangement.updateSupportedProtocols::invoke) + .wasInvoked(exactly = once) + verify(arrangement.syncConversations) .suspendFunction(arrangement.syncConversations::invoke) .wasInvoked(exactly = once) @@ -156,12 +200,14 @@ class SlowSyncWorkerTest { fun givenSyncConnectionsFails_whenPerformingSlowSync_thenThrowSyncException() = runTest(TestKaliumDispatcher.default) { val steps = hashSetOf( SlowSyncStep.SELF_USER, + SlowSyncStep.UPDATE_SUPPORTED_PROTOCOLS, SlowSyncStep.FEATURE_FLAGS, SlowSyncStep.CONVERSATIONS, SlowSyncStep.CONNECTIONS, ) val (arrangement, worker) = Arrangement() .withSyncSelfUserSuccess() + .withUpdateSupportedProtocolsSuccess() .withSyncFeatureConfigsSuccess() .withSyncConversationsSuccess() .withSyncConnectionsFailure() @@ -183,6 +229,10 @@ class SlowSyncWorkerTest { .suspendFunction(arrangement.syncFeatureConfigs::invoke) .wasInvoked(exactly = once) + verify(arrangement.updateSupportedProtocols) + .suspendFunction(arrangement.updateSupportedProtocols::invoke) + .wasInvoked(exactly = once) + verify(arrangement.syncConversations) .suspendFunction(arrangement.syncConversations::invoke) .wasInvoked(exactly = once) @@ -200,6 +250,7 @@ class SlowSyncWorkerTest { fun givenSyncSelfTeamFails_whenPerformingSlowSync_thenThrowSyncException() = runTest(TestKaliumDispatcher.default) { val steps = hashSetOf( SlowSyncStep.SELF_USER, + SlowSyncStep.UPDATE_SUPPORTED_PROTOCOLS, SlowSyncStep.FEATURE_FLAGS, SlowSyncStep.CONVERSATIONS, SlowSyncStep.CONNECTIONS, @@ -207,6 +258,7 @@ class SlowSyncWorkerTest { ) val (arrangement, worker) = Arrangement() .withSyncSelfUserSuccess() + .withUpdateSupportedProtocolsSuccess() .withSyncFeatureConfigsSuccess() .withSyncConversationsSuccess() .withSyncConnectionsSuccess() @@ -229,6 +281,10 @@ class SlowSyncWorkerTest { .suspendFunction(arrangement.syncFeatureConfigs::invoke) .wasInvoked(exactly = once) + verify(arrangement.updateSupportedProtocols) + .suspendFunction(arrangement.updateSupportedProtocols::invoke) + .wasInvoked(exactly = once) + verify(arrangement.syncConversations) .suspendFunction(arrangement.syncConversations::invoke) .wasInvoked(exactly = once) @@ -250,6 +306,7 @@ class SlowSyncWorkerTest { fun givenSyncContactsFails_whenPerformingSlowSync_thenThrowSyncException() = runTest(TestKaliumDispatcher.default) { val steps = hashSetOf( SlowSyncStep.SELF_USER, + SlowSyncStep.UPDATE_SUPPORTED_PROTOCOLS, SlowSyncStep.FEATURE_FLAGS, SlowSyncStep.CONVERSATIONS, SlowSyncStep.CONNECTIONS, @@ -258,6 +315,7 @@ class SlowSyncWorkerTest { ) val (arrangement, worker) = Arrangement() .withSyncSelfUserSuccess() + .withUpdateSupportedProtocolsSuccess() .withSyncFeatureConfigsSuccess() .withSyncConversationsSuccess() .withSyncConnectionsSuccess() @@ -281,6 +339,10 @@ class SlowSyncWorkerTest { .suspendFunction(arrangement.syncFeatureConfigs::invoke) .wasInvoked(exactly = once) + verify(arrangement.updateSupportedProtocols) + .suspendFunction(arrangement.updateSupportedProtocols::invoke) + .wasInvoked(exactly = once) + verify(arrangement.syncConversations) .suspendFunction(arrangement.syncConversations::invoke) .wasInvoked(exactly = once) @@ -307,6 +369,7 @@ class SlowSyncWorkerTest { fun givenJoinMLSConversationsFails_whenPerformingSlowSync_thenThrowSyncException() = runTest(TestKaliumDispatcher.default) { val steps = hashSetOf( SlowSyncStep.SELF_USER, + SlowSyncStep.UPDATE_SUPPORTED_PROTOCOLS, SlowSyncStep.FEATURE_FLAGS, SlowSyncStep.CONVERSATIONS, SlowSyncStep.CONNECTIONS, @@ -316,6 +379,7 @@ class SlowSyncWorkerTest { ) val (arrangement, worker) = Arrangement() .withSyncSelfUserSuccess() + .withUpdateSupportedProtocolsSuccess() .withSyncFeatureConfigsSuccess() .withSyncConversationsSuccess() .withSyncConnectionsSuccess() @@ -399,12 +463,14 @@ class SlowSyncWorkerTest { withFetchMostRecentEventReturning(Either.Right(fetchedEventId)) withUpdateLastProcessedEventIdReturning(Either.Right(Unit)) }.withSyncSelfUserSuccess() + .withUpdateSupportedProtocolsSuccess() .withSyncFeatureConfigsSuccess() .withSyncConversationsSuccess() .withSyncConnectionsSuccess() .withSyncSelfTeamSuccess() .withSyncContactsSuccess() .withJoinMLSConversationsSuccess() + .withResolveOneOnOneConversationsSuccess() .arrange() slowSyncWorker.slowSyncStepsFlow().collect() @@ -424,6 +490,10 @@ class SlowSyncWorkerTest { .suspendFunction(arrangement.syncFeatureConfigs::invoke) .wasInvoked(exactly = once) + verify(arrangement.updateSupportedProtocols) + .suspendFunction(arrangement.updateSupportedProtocols::invoke) + .wasInvoked(exactly = once) + verify(arrangement.syncConversations) .suspendFunction(arrangement.syncConversations::invoke) .wasInvoked(exactly = once) @@ -469,6 +539,12 @@ class SlowSyncWorkerTest { @Mock val joinMLSConversations: JoinExistingMLSConversationsUseCase = mock(JoinExistingMLSConversationsUseCase::class) + @Mock + val updateSupportedProtocols: UpdateSupportedProtocolsUseCase = mock(UpdateSupportedProtocolsUseCase::class) + + @Mock + val oneOnOneResolver: OneOnOneResolver = mock(OneOnOneResolver::class) + init { withLastProcessedEventIdReturning(Either.Right("lastProcessedEventId")) } @@ -481,7 +557,9 @@ class SlowSyncWorkerTest { syncConnections = syncConnections, syncSelfTeam = syncSelfTeam, syncContacts = syncContacts, - joinMLSConversations = joinMLSConversations + joinMLSConversations = joinMLSConversations, + updateSupportedProtocols = updateSupportedProtocols, + oneOnOneResolver = oneOnOneResolver, ) fun withSyncSelfUserFailure() = apply { @@ -512,6 +590,20 @@ class SlowSyncWorkerTest { .thenReturn(success) } + fun withUpdateSupportedProtocolsSuccess() = apply { + given(updateSupportedProtocols) + .suspendFunction(updateSupportedProtocols::invoke) + .whenInvoked() + .thenReturn(Either.Right(true)) + } + + fun withUpdateSupportedProtocolsFailure() = apply { + given(updateSupportedProtocols) + .suspendFunction(updateSupportedProtocols::invoke) + .whenInvoked() + .thenReturn(failure) + } + fun withSyncConversationsFailure() = apply { given(syncConversations) .suspendFunction(syncConversations::invoke) @@ -581,6 +673,13 @@ class SlowSyncWorkerTest { .whenInvokedWith(eq(keepRetryingOnFailure)) .thenReturn(success) } + + fun withResolveOneOnOneConversationsSuccess() = apply { + given(oneOnOneResolver) + .suspendFunction(oneOnOneResolver::resolveAllOneOnOneConversations) + .whenInvokedWith(any()) + .thenReturn(success) + } } companion object { diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/SystemMessageInserterArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/SystemMessageInserterArrangement.kt new file mode 100644 index 00000000000..fbb131b69e7 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/SystemMessageInserterArrangement.kt @@ -0,0 +1,54 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.util.arrangement + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.data.message.SystemMessageInserter +import com.wire.kalium.logic.functional.Either +import io.mockative.Mock +import io.mockative.any +import io.mockative.given +import io.mockative.mock + +internal interface SystemMessageInserterArrangement { + val systemMessageInserter: SystemMessageInserter + + fun withInsertProtocolChangedSystemMessage() + + fun withInsertLostCommitSystemMessage(result: Either) +} + +internal class SystemMessageInserterArrangementImpl: SystemMessageInserterArrangement { + + @Mock + override val systemMessageInserter = mock(SystemMessageInserter::class) + + override fun withInsertProtocolChangedSystemMessage() { + given(systemMessageInserter) + .suspendFunction(systemMessageInserter::insertProtocolChangedSystemMessage) + .whenInvokedWith(any(), any(), any()) + .thenReturn(Unit) + } + + override fun withInsertLostCommitSystemMessage(result: Either) { + given(systemMessageInserter) + .suspendFunction(systemMessageInserter::insertLostCommitSystemMessage) + .whenInvokedWith(any(), any()) + .thenReturn(result) + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/UserRepositoryArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/UserRepositoryArrangement.kt new file mode 100644 index 00000000000..0c113d09cf8 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/UserRepositoryArrangement.kt @@ -0,0 +1,124 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.util.arrangement + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.data.user.OtherUser +import com.wire.kalium.logic.data.user.SelfUser +import com.wire.kalium.logic.data.user.UserRepository +import com.wire.kalium.logic.functional.Either +import io.mockative.Mock +import io.mockative.any +import io.mockative.given +import io.mockative.mock +import kotlinx.coroutines.flow.Flow + +internal interface UserRepositoryArrangement { + val userRepository: UserRepository + + fun withUpdateUserSuccess() + + fun withUpdateUserFailure(coreFailure: CoreFailure) + + fun withRemoveUserSuccess() + + fun withSelfUserReturning(selfUser: SelfUser?) + + fun withUserByIdReturning(result: Either) + + fun withUpdateOneOnOneConversationReturning(result: Either) + + fun withGetKnownUserReturning(result: Flow) + + fun withGetUsersWithOneOnOneConversationReturning(result: List) + + fun withFetchAllOtherUsersReturning(result: Either) + + fun withFetchUserInfoReturning(result: Either) +} + +internal class UserRepositoryArrangementImpl: UserRepositoryArrangement { + + @Mock + override val userRepository: UserRepository = mock(UserRepository::class) + + override fun withUpdateUserSuccess() { + given(userRepository).suspendFunction(userRepository::updateUserFromEvent).whenInvokedWith(any()) + .thenReturn(Either.Right(Unit)) + } + + override fun withUpdateUserFailure(coreFailure: CoreFailure) { + given(userRepository).suspendFunction(userRepository::updateUserFromEvent) + .whenInvokedWith(any()).thenReturn(Either.Left(coreFailure)) + } + + override fun withRemoveUserSuccess() { + given(userRepository).suspendFunction(userRepository::removeUser) + .whenInvokedWith(any()).thenReturn(Either.Right(Unit)) + } + + override fun withSelfUserReturning(selfUser: SelfUser?) { + given(userRepository) + .suspendFunction(userRepository::getSelfUser) + .whenInvoked() + .thenReturn(selfUser) + } + + override fun withUserByIdReturning(result: Either) { + given(userRepository) + .suspendFunction(userRepository::userById) + .whenInvokedWith(any()) + .thenReturn(result) + } + + override fun withUpdateOneOnOneConversationReturning(result: Either) { + given(userRepository) + .suspendFunction(userRepository::updateActiveOneOnOneConversation) + .whenInvokedWith(any()) + .thenReturn(result) + } + + override fun withGetKnownUserReturning(result: Flow) { + given(userRepository) + .suspendFunction(userRepository::getKnownUser) + .whenInvokedWith(any()) + .thenReturn(result) + } + + override fun withGetUsersWithOneOnOneConversationReturning(result: List) { + given(userRepository) + .suspendFunction(userRepository::getUsersWithOneOnOneConversation) + .whenInvoked() + .thenReturn(result) + } + + override fun withFetchAllOtherUsersReturning(result: Either) { + given(userRepository) + .suspendFunction(userRepository::fetchAllOtherUsers) + .whenInvoked() + .thenReturn(result) + } + + override fun withFetchUserInfoReturning(result: Either) { + given(userRepository) + .suspendFunction(userRepository::fetchUserInfo) + .whenInvokedWith(any()) + .thenReturn(result) + } + +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/dao/MemberDAOArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/dao/MemberDAOArrangement.kt index 55c56ec5a59..bf67a3a1eb8 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/dao/MemberDAOArrangement.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/dao/MemberDAOArrangement.kt @@ -17,7 +17,6 @@ */ package com.wire.kalium.logic.util.arrangement.dao -import com.wire.kalium.persistence.dao.ConnectionEntity import com.wire.kalium.persistence.dao.QualifiedIDEntity import com.wire.kalium.persistence.dao.UserIDEntity import com.wire.kalium.persistence.dao.member.MemberDAO @@ -34,16 +33,14 @@ interface MemberDAOArrangement { @Mock val memberDAO: MemberDAO - fun withUpdateOrInsertOneOnOneMemberWithConnectionStatusSuccess( + fun withUpdateOrInsertOneOnOneMemberSuccess( member: Matcher = any(), - status: Matcher = any(), conversationId: Matcher = any() ) - fun withUpdateOrInsertOneOnOneMemberWithConnectionStatusFailure( + fun withUpdateOrInsertOneOnOneMemberFailure( error: Throwable, member: Matcher = any(), - status: Matcher = any(), conversationId: Matcher = any() ) @@ -79,25 +76,23 @@ class MemberDAOArrangementImpl : MemberDAOArrangement { @Mock override val memberDAO: MemberDAO = mock(MemberDAO::class) - override fun withUpdateOrInsertOneOnOneMemberWithConnectionStatusSuccess( + override fun withUpdateOrInsertOneOnOneMemberSuccess( member: Matcher, - status: Matcher, conversationId: Matcher ) { given(memberDAO) - .suspendFunction(memberDAO::updateOrInsertOneOnOneMemberWithConnectionStatus) - .whenInvokedWith(member, status, conversationId) + .suspendFunction(memberDAO::updateOrInsertOneOnOneMember) + .whenInvokedWith(member, conversationId) } - override fun withUpdateOrInsertOneOnOneMemberWithConnectionStatusFailure( + override fun withUpdateOrInsertOneOnOneMemberFailure( error: Throwable, member: Matcher, - status: Matcher, conversationId: Matcher ) { given(memberDAO) - .suspendFunction(memberDAO::updateOrInsertOneOnOneMemberWithConnectionStatus) - .whenInvokedWith(member, status, conversationId) + .suspendFunction(memberDAO::updateOrInsertOneOnOneMember) + .whenInvokedWith(member, conversationId) .thenThrow(error) } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/MLSConversationRepositoryArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/MLSConversationRepositoryArrangement.kt new file mode 100644 index 00000000000..7e35a0dc8f9 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/MLSConversationRepositoryArrangement.kt @@ -0,0 +1,42 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.util.arrangement.mls + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.data.conversation.MLSConversationRepository +import com.wire.kalium.logic.functional.Either +import io.mockative.any +import io.mockative.given +import io.mockative.mock + +interface MLSConversationRepositoryArrangement { + val mlsConversationRepository: MLSConversationRepository + + fun withIsGroupOutOfSync(result: Either) +} + +class MLSConversationRepositoryArrangementImpl : MLSConversationRepositoryArrangement { + override val mlsConversationRepository = mock(MLSConversationRepository::class) + + override fun withIsGroupOutOfSync(result: Either) { + given(mlsConversationRepository) + .suspendFunction(mlsConversationRepository::isGroupOutOfSync) + .whenInvokedWith(any(), any()) + .thenReturn(result) + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/MLSOneOnOneConversationResolverArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/MLSOneOnOneConversationResolverArrangement.kt new file mode 100644 index 00000000000..911f8e53995 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/MLSOneOnOneConversationResolverArrangement.kt @@ -0,0 +1,45 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.util.arrangement.mls + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.feature.conversation.mls.MLSOneOnOneConversationResolver +import com.wire.kalium.logic.functional.Either +import io.mockative.Mock +import io.mockative.any +import io.mockative.given +import io.mockative.mock + +internal interface MLSOneOnOneConversationResolverArrangement { + val mlsOneOnOneConversationResolver: MLSOneOnOneConversationResolver + + fun withResolveConversationReturning(result: Either) +} + +internal class MLSOneOnOneConversationResolverArrangementImpl : MLSOneOnOneConversationResolverArrangement { + @Mock + override val mlsOneOnOneConversationResolver = mock(MLSOneOnOneConversationResolver::class) + + override fun withResolveConversationReturning(result: Either) { + given(mlsOneOnOneConversationResolver) + .suspendFunction(mlsOneOnOneConversationResolver::invoke) + .whenInvokedWith(any()) + .thenReturn(result) + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/OneOnOneMigratorArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/OneOnOneMigratorArrangement.kt new file mode 100644 index 00000000000..0f64b1fc117 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/OneOnOneMigratorArrangement.kt @@ -0,0 +1,56 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.util.arrangement.mls + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.feature.conversation.mls.OneOnOneMigrator +import com.wire.kalium.logic.functional.Either +import io.mockative.Mock +import io.mockative.any +import io.mockative.given +import io.mockative.mock + +interface OneOnOneMigratorArrangement { + + val oneOnOneMigrator: OneOnOneMigrator + + fun withMigrateToMLSReturns(result: Either) + + fun withMigrateToProteusReturns(result: Either) +} + +class OneOnOneMigratorArrangementImpl : OneOnOneMigratorArrangement { + + @Mock + override val oneOnOneMigrator = mock(OneOnOneMigrator::class) + + override fun withMigrateToMLSReturns(result: Either) { + given(oneOnOneMigrator) + .suspendFunction(oneOnOneMigrator::migrateToMLS) + .whenInvokedWith(any()) + .thenReturn(result) + } + + override fun withMigrateToProteusReturns(result: Either) { + given(oneOnOneMigrator) + .suspendFunction(oneOnOneMigrator::migrateToProteus) + .whenInvokedWith(any()) + .thenReturn(result) + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/OneOnOneResolverArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/OneOnOneResolverArrangement.kt new file mode 100644 index 00000000000..0775258e7f6 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/OneOnOneResolverArrangement.kt @@ -0,0 +1,76 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.util.arrangement.mls + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.feature.conversation.mls.OneOnOneResolver +import com.wire.kalium.logic.functional.Either +import io.mockative.Mock +import io.mockative.any +import io.mockative.given +import io.mockative.mock +import kotlinx.coroutines.Job + +interface OneOnOneResolverArrangement { + + val oneOnOneResolver: OneOnOneResolver + + fun withScheduleResolveOneOnOneConversationWithUserId() + fun withResolveOneOnOneConversationWithUserIdReturning(result: Either) + fun withResolveOneOnOneConversationWithUserReturning(result: Either) + fun withResolveAllOneOnOneConversationsReturning(result: Either) + +} + +class OneOnOneResolverArrangementImpl : OneOnOneResolverArrangement { + + @Mock + override val oneOnOneResolver = mock(OneOnOneResolver::class) + override fun withScheduleResolveOneOnOneConversationWithUserId() { + given(oneOnOneResolver) + .suspendFunction(oneOnOneResolver::scheduleResolveOneOnOneConversationWithUserId) + .whenInvokedWith(any(), any()) + .thenReturn(Job()) + } + + override fun withResolveOneOnOneConversationWithUserIdReturning(result: Either) { + given(oneOnOneResolver) + .suspendFunction(oneOnOneResolver::resolveOneOnOneConversationWithUserId) + .whenInvokedWith(any()) + .thenReturn(result) + } + + override fun withResolveOneOnOneConversationWithUserReturning(result: Either) { + given(oneOnOneResolver) + .suspendFunction(oneOnOneResolver::resolveOneOnOneConversationWithUser) + .whenInvokedWith(any()) + .thenReturn(result) + } + + override fun withResolveAllOneOnOneConversationsReturning(result: Either) { + given(oneOnOneResolver) + .suspendFunction(oneOnOneResolver::resolveAllOneOnOneConversations) + .whenInvokedWith(any()) + .thenReturn(result) + } + +} + + + diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/StaleEpochVerifierArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/StaleEpochVerifierArrangement.kt new file mode 100644 index 00000000000..66414035e8a --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/StaleEpochVerifierArrangement.kt @@ -0,0 +1,47 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.util.arrangement.mls + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.feature.message.StaleEpochVerifier +import com.wire.kalium.logic.functional.Either +import io.mockative.Mock +import io.mockative.any +import io.mockative.given +import io.mockative.mock + +interface StaleEpochVerifierArrangement { + + val staleEpochVerifier: StaleEpochVerifier + + fun withVerifyEpoch(result: Either) + +} + +class StaleEpochVerifierArrangementImpl : StaleEpochVerifierArrangement { + + @Mock + override val staleEpochVerifier = mock(StaleEpochVerifier::class) + + override fun withVerifyEpoch(result: Either) { + given(staleEpochVerifier) + .suspendFunction(staleEpochVerifier::verifyEpoch) + .whenInvokedWith(any()) + .thenReturn(result) + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/protocol/OneOnOneProtocolSelectorArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/protocol/OneOnOneProtocolSelectorArrangement.kt new file mode 100644 index 00000000000..19d79550efa --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/protocol/OneOnOneProtocolSelectorArrangement.kt @@ -0,0 +1,44 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.util.arrangement.protocol + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.data.user.SupportedProtocol +import com.wire.kalium.logic.feature.protocol.OneOnOneProtocolSelector +import com.wire.kalium.logic.functional.Either +import io.mockative.Mock +import io.mockative.any +import io.mockative.given +import io.mockative.mock + +internal interface OneOnOneProtocolSelectorArrangement { + val oneOnOneProtocolSelector: OneOnOneProtocolSelector + fun withGetProtocolForUser(result: Either): OneOnOneProtocolSelectorArrangementImpl +} + +internal open class OneOnOneProtocolSelectorArrangementImpl : OneOnOneProtocolSelectorArrangement { + @Mock + override val oneOnOneProtocolSelector: OneOnOneProtocolSelector = mock(OneOnOneProtocolSelector::class) + + override fun withGetProtocolForUser(result: Either) = apply { + given(oneOnOneProtocolSelector) + .suspendFunction(oneOnOneProtocolSelector::getProtocolForUser) + .whenInvokedWith(any()) + .thenReturn(result) + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ConnectionRepositoryArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ConnectionRepositoryArrangement.kt index 7c0f840bdd1..f72a14a6174 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ConnectionRepositoryArrangement.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ConnectionRepositoryArrangement.kt @@ -17,18 +17,22 @@ */ package com.wire.kalium.logic.util.arrangement.repository +import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.data.connection.ConnectionRepository import com.wire.kalium.logic.data.conversation.ConversationDetails import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.data.user.Connection +import com.wire.kalium.logic.data.user.ConnectionState +import com.wire.kalium.logic.feature.connection.AcceptConnectionRequestUseCaseTest import com.wire.kalium.logic.functional.Either import io.mockative.Mock import io.mockative.any +import io.mockative.eq import io.mockative.given import io.mockative.matchers.Matcher import io.mockative.mock import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.flowOf internal interface ConnectionRepositoryArrangement { val connectionRepository: ConnectionRepository @@ -36,6 +40,7 @@ internal interface ConnectionRepositoryArrangement { fun withGetConnections(result: Either>>) fun withDeleteConnection(result: Either, conversationId: Matcher = any()) fun withConnectionList(connectionsFlow: Flow>) + fun withUpdateConnectionStatus(result: Either) } internal open class ConnectionRepositoryArrangementImpl : ConnectionRepositoryArrangement { @@ -67,4 +72,11 @@ internal open class ConnectionRepositoryArrangementImpl : ConnectionRepositoryAr .whenInvoked() .thenReturn(connectionsFlow) } + + override fun withUpdateConnectionStatus(result: Either) { + given(connectionRepository) + .suspendFunction(connectionRepository::updateConnectionStatus) + .whenInvokedWith(any(), any()) + .thenReturn(result) + } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ConversationGroupRepositoryArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ConversationGroupRepositoryArrangement.kt index 177000efae0..d9b4b9a82ae 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ConversationGroupRepositoryArrangement.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ConversationGroupRepositoryArrangement.kt @@ -17,13 +17,16 @@ */ package com.wire.kalium.logic.util.arrangement.repository +import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.NetworkFailure +import com.wire.kalium.logic.data.conversation.Conversation import com.wire.kalium.logic.data.conversation.ConversationGroupRepository import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.functional.Either import com.wire.kalium.network.api.base.authenticated.notification.EventContentDTO import io.mockative.Mock import io.mockative.any +import io.mockative.anything import io.mockative.given import io.mockative.matchers.Matcher import io.mockative.mock @@ -40,9 +43,17 @@ interface ConversationGroupRepositoryArrangement { .whenInvokedWith(conversationId) .thenReturn(result) } + + fun withCreateGroupConversationReturning(result: Either) { + given(conversationGroupRepository) + .suspendFunction(conversationGroupRepository::createGroupConversation) + .whenInvokedWith(anything(), anything(), anything()) + .thenReturn(result) + } } class ConversationGroupRepositoryArrangementImpl : ConversationGroupRepositoryArrangement { @Mock override val conversationGroupRepository = mock(ConversationGroupRepository::class) } + diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ConversationRepositoryArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ConversationRepositoryArrangement.kt index 8830f761c9f..23f8c6cf828 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ConversationRepositoryArrangement.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ConversationRepositoryArrangement.kt @@ -18,13 +18,15 @@ package com.wire.kalium.logic.util.arrangement.repository import com.wire.kalium.logic.CoreFailure -import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.data.conversation.Conversation import com.wire.kalium.logic.data.conversation.ConversationDetails +import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.data.conversation.ConversationRepository -import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.framework.TestConversation +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.data.id.QualifiedID +import com.wire.kalium.logic.feature.connection.AcceptConnectionRequestUseCaseTest import com.wire.kalium.logic.functional.Either import io.mockative.Mock import io.mockative.any @@ -32,6 +34,7 @@ import io.mockative.eq import io.mockative.given import io.mockative.matchers.Matcher import io.mockative.mock +import kotlinx.coroutines.flow.flowOf internal interface ConversationRepositoryArrangement { val conversationRepository: ConversationRepository @@ -54,9 +57,53 @@ internal interface ConversationRepositoryArrangement { fun withConversationProtocolInfo(result: Either): ConversationRepositoryArrangementImpl fun withUpdateVerificationStatus(result: Either): ConversationRepositoryArrangementImpl fun withConversationDetailsByMLSGroupId(result: Either): ConversationRepositoryArrangementImpl + fun withUpdateProtocolLocally(result: Either) + fun withConversationsForUserIdReturning(result: Either>) + fun withFetchMlsOneToOneConversation(result: Either) + fun withFetchConversation(result: Either) + fun withObserveOneToOneConversationWithOtherUserReturning(result: Either) + + fun withObserveConversationDetailsByIdReturning(result: Either) + + fun withGetConversationIdsReturning(result: Either>) + + fun withGetOneOnOneConversationsWithOtherUserReturning(result: Either>) + + fun withGetConversationProtocolInfo(result: Either) + + fun withGetConversationByIdReturning(result: Conversation?) + + fun withFetchConversationIfUnknownFailingWith(coreFailure: CoreFailure) { + given(conversationRepository) + .suspendFunction(conversationRepository::fetchConversationIfUnknown) + .whenInvokedWith(any()) + .thenReturn(Either.Left(coreFailure)) + } + + fun withFetchConversationIfUnknownSucceeding() { + given(conversationRepository) + .suspendFunction(conversationRepository::fetchConversationIfUnknown) + .whenInvokedWith(any()) + .thenReturn(Either.Right(Unit)) + } + + fun withUpdateGroupStateReturning(result: Either) { + given(conversationRepository) + .suspendFunction(conversationRepository::updateConversationGroupState) + .whenInvokedWith(any(), any()) + .thenReturn(result) + } + + fun withUpdateConversationModifiedDate(result: Either) { + given(conversationRepository) + .suspendFunction(conversationRepository::updateConversationModifiedDate) + .whenInvokedWith(any(), any()) + .thenReturn(result) + } } internal open class ConversationRepositoryArrangementImpl : ConversationRepositoryArrangement { + @Mock override val conversationRepository: ConversationRepository = mock(ConversationRepository::class) @@ -136,4 +183,74 @@ internal open class ConversationRepositoryArrangementImpl : ConversationReposito .whenInvokedWith(any()) .thenReturn(result) } + + override fun withUpdateProtocolLocally(result: Either) { + given(conversationRepository) + .suspendFunction(conversationRepository::updateProtocolLocally) + .whenInvokedWith(any(), any()) + .thenReturn(result) + } + + override fun withConversationsForUserIdReturning(result: Either>) { + given(conversationRepository) + .suspendFunction(conversationRepository::getConversationsByUserId) + .whenInvokedWith(any()) + .thenReturn(result) + } + + override fun withFetchMlsOneToOneConversation(result: Either) { + given(conversationRepository) + .suspendFunction(conversationRepository::fetchMlsOneToOneConversation) + .whenInvokedWith(any()) + .thenReturn(result) + } + + override fun withFetchConversation(result: Either) { + given(conversationRepository) + .suspendFunction(conversationRepository::fetchConversation) + .whenInvokedWith(any()) + .thenReturn(result) + } + + override fun withObserveOneToOneConversationWithOtherUserReturning(result: Either) { + given(conversationRepository) + .suspendFunction(conversationRepository::observeOneToOneConversationWithOtherUser) + .whenInvokedWith(any()) + .thenReturn(flowOf(result)) + } + + override fun withObserveConversationDetailsByIdReturning(result: Either) { + given(conversationRepository) + .suspendFunction(conversationRepository::observeConversationDetailsById) + .whenInvokedWith(any()) + .thenReturn(flowOf(result)) + } + + override fun withGetConversationIdsReturning(result: Either>) { + given(conversationRepository) + .suspendFunction(conversationRepository::getConversationIds) + .whenInvokedWith(any()) + .thenReturn(result) + } + + override fun withGetOneOnOneConversationsWithOtherUserReturning(result: Either>) { + given(conversationRepository) + .suspendFunction(conversationRepository::getOneOnOneConversationsWithOtherUser) + .whenInvokedWith(any()) + .thenReturn(result) + } + + override fun withGetConversationProtocolInfo(result: Either) { + given(conversationRepository) + .suspendFunction(conversationRepository::getConversationProtocolInfo) + .whenInvokedWith(any()) + .thenReturn(result) + } + + override fun withGetConversationByIdReturning(result: Conversation?) { + given(conversationRepository) + .suspendFunction(conversationRepository::getConversationById) + .whenInvokedWith(any()) + .thenReturn(result) + } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/MessageRepositoryArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/MessageRepositoryArrangement.kt index 6c38997b67f..f7f6ae54a0b 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/MessageRepositoryArrangement.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/MessageRepositoryArrangement.kt @@ -23,7 +23,6 @@ import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.message.Message import com.wire.kalium.logic.data.message.MessageRepository import com.wire.kalium.logic.data.notification.LocalNotification -import com.wire.kalium.logic.feature.message.GetNotificationsUseCaseTest import com.wire.kalium.logic.functional.Either import io.mockative.Mock import io.mockative.any @@ -31,7 +30,6 @@ import io.mockative.given import io.mockative.matchers.Matcher import io.mockative.mock import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.emptyFlow internal interface MessageRepositoryArrangement { @Mock @@ -56,6 +54,12 @@ internal interface MessageRepositoryArrangement { ) fun withLocalNotifications(list: Either>>) + + fun withMoveMessagesToAnotherConversation( + result: Either, + originalConversation: Matcher = any(), + targetConversation: Matcher = any() + ) } internal open class MessageRepositoryArrangementImpl : MessageRepositoryArrangement { @@ -88,7 +92,7 @@ internal open class MessageRepositoryArrangementImpl : MessageRepositoryArrangem result: Either, messageID: Matcher, conversationId: Matcher - ) { + ) { given(messageRepository) .suspendFunction(messageRepository::markMessageAsDeleted) .whenInvokedWith(messageID, conversationId) @@ -101,4 +105,15 @@ internal open class MessageRepositoryArrangementImpl : MessageRepositoryArrangem .whenInvokedWith(any()) .thenReturn(list) } + + override fun withMoveMessagesToAnotherConversation( + result: Either, + originalConversation: Matcher, + targetConversation: Matcher + ) { + given(messageRepository) + .suspendFunction(messageRepository::moveMessagesToAnotherConversation) + .whenInvokedWith(originalConversation, targetConversation) + .thenReturn(result) + } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserConfigRepositoryArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserConfigRepositoryArrangement.kt new file mode 100644 index 00000000000..1ea16666fe4 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserConfigRepositoryArrangement.kt @@ -0,0 +1,87 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.util.arrangement.repository + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.StorageFailure +import com.wire.kalium.logic.configuration.UserConfigRepository +import com.wire.kalium.logic.data.featureConfig.MLSMigrationModel +import com.wire.kalium.logic.data.user.SupportedProtocol +import com.wire.kalium.logic.functional.Either +import io.mockative.Mock +import io.mockative.any +import io.mockative.given +import io.mockative.mock + +internal interface UserConfigRepositoryArrangement { + val userConfigRepository: UserConfigRepository + + fun withGetSupportedProtocolsReturning(result: Either>) + fun withSetSupportedProtocolsSuccessful() + fun withSetDefaultProtocolSuccessful() + fun withSetMLSEnabledSuccessful() + fun withSetMigrationConfigurationSuccessful() + fun withGetMigrationConfigurationReturning(result: Either) +} + +internal class UserConfigRepositoryArrangementImpl : UserConfigRepositoryArrangement { + @Mock + override val userConfigRepository: UserConfigRepository = mock(UserConfigRepository::class) + + override fun withGetSupportedProtocolsReturning(result: Either>) { + given(userConfigRepository) + .suspendFunction(userConfigRepository::getSupportedProtocols) + .whenInvoked() + .thenReturn(result) + } + + override fun withSetSupportedProtocolsSuccessful() { + given(userConfigRepository) + .suspendFunction(userConfigRepository::setSupportedProtocols) + .whenInvokedWith(any()) + .thenReturn(Either.Right(Unit)) + } + + override fun withSetDefaultProtocolSuccessful() { + given(userConfigRepository) + .function(userConfigRepository::setDefaultProtocol) + .whenInvokedWith(any()) + .thenReturn(Either.Right(Unit)) + } + + override fun withSetMLSEnabledSuccessful() { + given(userConfigRepository) + .function(userConfigRepository::setMLSEnabled) + .whenInvokedWith(any()) + .thenReturn(Either.Right(Unit)) + } + + override fun withSetMigrationConfigurationSuccessful() { + given(userConfigRepository) + .suspendFunction(userConfigRepository::setMigrationConfiguration) + .whenInvokedWith(any()) + .thenReturn(Either.Right(Unit)) + } + + override fun withGetMigrationConfigurationReturning(result: Either) { + given(userConfigRepository) + .suspendFunction(userConfigRepository::getMigrationConfiguration) + .whenInvoked() + .thenReturn(result) + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/usecase/JoinExistingMLSConversationUseCaseArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/usecase/JoinExistingMLSConversationUseCaseArrangement.kt new file mode 100644 index 00000000000..2ea3472d579 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/usecase/JoinExistingMLSConversationUseCaseArrangement.kt @@ -0,0 +1,46 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.util.arrangement.usecase + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.feature.conversation.JoinExistingMLSConversationUseCase +import com.wire.kalium.logic.functional.Either +import io.mockative.Mock +import io.mockative.any +import io.mockative.given +import io.mockative.mock + +internal interface JoinExistingMLSConversationUseCaseArrangement { + val joinExistingMLSConversationUseCase: JoinExistingMLSConversationUseCase + + fun withJoinExistingMLSConversationUseCaseReturning(result: Either) +} + +internal class JoinExistingMLSConversationUseCaseArrangementImpl : JoinExistingMLSConversationUseCaseArrangement { + + @Mock + override val joinExistingMLSConversationUseCase: JoinExistingMLSConversationUseCase = + mock(JoinExistingMLSConversationUseCase::class) + + override fun withJoinExistingMLSConversationUseCaseReturning(result: Either) { + given(joinExistingMLSConversationUseCase) + .suspendFunction(joinExistingMLSConversationUseCase::invoke) + .whenInvokedWith(any()) + .thenReturn(result) + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/usecase/UpdateSupportedProtocolsAndResolveOneOnOnesArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/usecase/UpdateSupportedProtocolsAndResolveOneOnOnesArrangement.kt new file mode 100644 index 00000000000..4de524462de --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/usecase/UpdateSupportedProtocolsAndResolveOneOnOnesArrangement.kt @@ -0,0 +1,46 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.util.arrangement.usecase + +import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsAndResolveOneOnOnesUseCase +import com.wire.kalium.logic.functional.Either +import io.mockative.Mock +import io.mockative.any +import io.mockative.given +import io.mockative.mock + +internal interface UpdateSupportedProtocolsAndResolveOneOnOnesArrangement { + val updateSupportedProtocolsAndResolveOneOnOnes: UpdateSupportedProtocolsAndResolveOneOnOnesUseCase + + fun withUpdateSupportedProtocolsAndResolveOneOnOnesSuccessful() +} + +internal class UpdateSupportedProtocolsAndResolveOneOnOnesArrangementImpl + : UpdateSupportedProtocolsAndResolveOneOnOnesArrangement { + @Mock + override val updateSupportedProtocolsAndResolveOneOnOnes: UpdateSupportedProtocolsAndResolveOneOnOnesUseCase = + mock(UpdateSupportedProtocolsAndResolveOneOnOnesUseCase::class) + + override fun withUpdateSupportedProtocolsAndResolveOneOnOnesSuccessful() { + given(updateSupportedProtocolsAndResolveOneOnOnes) + .suspendFunction(updateSupportedProtocolsAndResolveOneOnOnes::invoke) + .whenInvokedWith(any()) + .thenReturn(Either.Right(Unit)) + } + +} diff --git a/logic/src/jvmTest/kotlin/com/wire/kalium/logic/feature/scenario/OnCloseCallTest.kt b/logic/src/jvmTest/kotlin/com/wire/kalium/logic/feature/scenario/OnCloseCallTest.kt index fc6c104199d..b4db5d64602 100644 --- a/logic/src/jvmTest/kotlin/com/wire/kalium/logic/feature/scenario/OnCloseCallTest.kt +++ b/logic/src/jvmTest/kotlin/com/wire/kalium/logic/feature/scenario/OnCloseCallTest.kt @@ -252,7 +252,7 @@ class OnCloseCallTest { val mlsCall = callMetadata.copy( protocol = Conversation.ProtocolInfo.MLS( groupId = GroupID(""), - groupState = Conversation.ProtocolInfo.MLS.GroupState.ESTABLISHED, + groupState = Conversation.ProtocolInfo.MLSCapable.GroupState.ESTABLISHED, epoch = ULong.MAX_VALUE, keyingMaterialLastUpdate = Instant.DISTANT_FUTURE, cipherSuite = Conversation.CipherSuite.MLS_128_DHKEMP256_AES128GCM_SHA256_P256 diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/BaseApi.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/BaseApi.kt new file mode 100644 index 00000000000..382afd9ec95 --- /dev/null +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/BaseApi.kt @@ -0,0 +1,27 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.network.api.base.authenticated + +import com.wire.kalium.network.exceptions.APINotSupported +import com.wire.kalium.network.utils.NetworkResponse + +interface BaseApi { + fun getApiNotSupportedError(apiName: String, apiVersion: Int) = NetworkResponse.Error( + APINotSupported("${this::class.simpleName}: $apiName api is only available on API V$apiVersion") + ) +} diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/ConversationApi.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/ConversationApi.kt index 6058888dd8a..bcf3664e505 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/ConversationApi.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/ConversationApi.kt @@ -18,6 +18,7 @@ package com.wire.kalium.network.api.base.authenticated.conversation +import com.wire.kalium.network.api.base.authenticated.BaseApi import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationCodeInfo import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationMemberRoleDTO import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationReceiptModeDTO @@ -27,10 +28,11 @@ import com.wire.kalium.network.api.base.model.QualifiedID import com.wire.kalium.network.api.base.model.ServiceAddedResponse import com.wire.kalium.network.api.base.model.SubconversationId import com.wire.kalium.network.api.base.model.UserId +import com.wire.kalium.network.exceptions.APINotSupported import com.wire.kalium.network.utils.NetworkResponse @Suppress("TooManyFunctions") -interface ConversationApi { +interface ConversationApi : BaseApi { /** * Fetch conversations id's in a paginated fashion, including federated conversations @@ -110,6 +112,10 @@ interface ConversationApi { key: String ): NetworkResponse + suspend fun fetchMlsOneToOneConversation( + userId: UserId + ): NetworkResponse + suspend fun fetchSubconversationDetails( conversationId: ConversationId, subconversationId: SubconversationId @@ -152,4 +158,14 @@ interface ConversationApi { conversationId: ConversationId, typingIndicatorMode: TypingIndicatorStatusDTO ): NetworkResponse + suspend fun updateProtocol( + conversationId: ConversationId, + protocol: ConvProtocol + ): NetworkResponse + + companion object { + fun getApiNotSupportError(apiName: String, apiVersion: String = "4") = NetworkResponse.Error( + APINotSupported("${this::class.simpleName}: $apiName api is only available on API V$apiVersion") + ) + } } diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/ConversationResponse.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/ConversationResponse.kt index 6d79062851b..6b759524718 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/ConversationResponse.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/ConversationResponse.kt @@ -37,7 +37,7 @@ import kotlinx.serialization.encoding.Encoder @Serializable data class ConversationResponse( @SerialName("creator") - val creator: String, + val creator: String?, @SerialName("members") val members: ConversationMembersResponse, diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/CreateConversationRequest.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/CreateConversationRequest.kt index b920b96dd4c..c2ad7f71994 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/CreateConversationRequest.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/CreateConversationRequest.kt @@ -89,7 +89,10 @@ enum class ConvProtocol { PROTEUS, @SerialName("mls") - MLS; + MLS, + + @SerialName("mixed") + MIXED; override fun toString(): String { return this.name.lowercase() diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/UpdateConversationProtocolRequest.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/UpdateConversationProtocolRequest.kt new file mode 100644 index 00000000000..7512dd53f32 --- /dev/null +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/UpdateConversationProtocolRequest.kt @@ -0,0 +1,27 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.network.api.base.authenticated.conversation + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable +data class UpdateConversationProtocolRequest( + @SerialName("protocol") + val protocol: ConvProtocol +) diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/UpdateConversationProtocolResponse.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/UpdateConversationProtocolResponse.kt new file mode 100644 index 00000000000..2770ea32204 --- /dev/null +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/UpdateConversationProtocolResponse.kt @@ -0,0 +1,25 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.network.api.base.authenticated.conversation + +import com.wire.kalium.network.api.base.authenticated.notification.EventContentDTO + +sealed class UpdateConversationProtocolResponse { + object ProtocolUnchanged : UpdateConversationProtocolResponse() + data class ProtocolUpdated(val event: EventContentDTO.Conversation.ProtocolUpdate) : UpdateConversationProtocolResponse() +} diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/model/ConversationProtocolDTO.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/model/ConversationProtocolDTO.kt new file mode 100644 index 00000000000..152be4a6953 --- /dev/null +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/conversation/model/ConversationProtocolDTO.kt @@ -0,0 +1,28 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.network.api.base.authenticated.conversation.model + +import com.wire.kalium.network.api.base.authenticated.conversation.ConvProtocol +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable +data class ConversationProtocolDTO constructor( + @SerialName("protocol") + val protocol: ConvProtocol +) diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/featureConfigs/FeatureConfigResponse.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/featureConfigs/FeatureConfigResponse.kt index 33df406a72a..4f20ea8fa29 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/featureConfigs/FeatureConfigResponse.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/featureConfigs/FeatureConfigResponse.kt @@ -18,7 +18,8 @@ package com.wire.kalium.network.api.base.authenticated.featureConfigs -import com.wire.kalium.network.api.base.authenticated.conversation.ConvProtocol +import com.wire.kalium.network.api.base.model.SupportedProtocolDTO +import kotlinx.datetime.Instant import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable @@ -53,13 +54,16 @@ data class FeatureConfigResponse( @SerialName("mls") val mls: FeatureConfigData.MLS?, @SerialName("mlsE2EId") - val mlsE2EI: FeatureConfigData.E2EI? + val mlsE2EI: FeatureConfigData.E2EI?, + @SerialName("mlsMigration") + val mlsMigration: FeatureConfigData.MLSMigration? ) @Serializable enum class FeatureFlagStatusDTO { @SerialName("enabled") ENABLED, + @SerialName("disabled") DISABLED; } @@ -77,18 +81,31 @@ data class ClassifiedDomainsConfigDTO( @SerialName("domains") val domains: List ) + @Serializable data class MLSConfigDTO( @SerialName("protocolToggleUsers") val protocolToggleUsers: List, @SerialName("defaultProtocol") - val defaultProtocol: ConvProtocol, + val defaultProtocol: SupportedProtocolDTO, + @SerialName("supportedProtocols") + val supportedProtocols: List = listOf(SupportedProtocolDTO.PROTEUS), @SerialName("allowedCipherSuites") val allowedCipherSuites: List, @SerialName("defaultCipherSuite") val defaultCipherSuite: Int ) +@Serializable +data class MLSMigrationConfigDTO( + // migration start timestamp + @SerialName("startTime") + val startTime: Instant?, + // timestamp of the date until the migration has to finalise + @SerialName("finaliseRegardlessAfter") + val finaliseRegardlessAfter: Instant? +) + @Serializable data class SelfDeletingMessagesConfigDTO( @SerialName("enforcedTimeoutSeconds") @@ -221,4 +238,13 @@ sealed class FeatureConfigData { @SerialName("status") val status: FeatureFlagStatusDTO ) : FeatureConfigData() + + @SerialName("mlsMigration") + @Serializable + data class MLSMigration( + @SerialName("config") + val config: MLSMigrationConfigDTO, + @SerialName("status") + val status: FeatureFlagStatusDTO + ) : FeatureConfigData() } diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/notification/EventContentDTO.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/notification/EventContentDTO.kt index 41b02bd58ab..1eaab3731c1 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/notification/EventContentDTO.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/notification/EventContentDTO.kt @@ -29,6 +29,7 @@ import com.wire.kalium.network.api.base.authenticated.conversation.TypingIndicat import com.wire.kalium.network.api.base.authenticated.conversation.guestroomlink.ConversationInviteLinkResponse import com.wire.kalium.network.api.base.authenticated.conversation.messagetimer.ConversationMessageTimerDTO import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationAccessInfoDTO +import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationProtocolDTO import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationReceiptModeDTO import com.wire.kalium.network.api.base.authenticated.featureConfigs.FeatureConfigData import com.wire.kalium.network.api.base.authenticated.featureConfigs.FeatureFlagStatusDTO @@ -268,6 +269,14 @@ sealed class EventContentDTO { @SerialName("from") val from: String ) : Conversation() + @Serializable + @SerialName("conversation.protocol-update") + data class ProtocolUpdate( + @SerialName("qualified_conversation") val qualifiedConversation: ConversationId, + @SerialName("data") val data: ConversationProtocolDTO, + @SerialName("qualified_from") val qualifiedFrom: UserId, + ) : Conversation() + } @Serializable diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/notification/user/ClientEventsData.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/notification/user/ClientEventsData.kt index 2b07b726e19..c80a2168209 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/notification/user/ClientEventsData.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/notification/user/ClientEventsData.kt @@ -19,6 +19,7 @@ package com.wire.kalium.network.api.base.authenticated.notification.user import com.wire.kalium.network.api.base.model.NonQualifiedUserId +import com.wire.kalium.network.api.base.model.SupportedProtocolDTO import com.wire.kalium.network.api.base.model.UserAssetDTO import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable @@ -36,5 +37,6 @@ data class UserUpdateEventData( @SerialName("handle") val handle: String?, @SerialName("email") val email: String?, @SerialName("sso_id_deleted") val ssoIdDeleted: Boolean?, - @SerialName("assets") val assets: List? + @SerialName("assets") val assets: List?, + @SerialName("supported_protocols")val supportedProtocols: List? ) diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/self/SelfApi.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/self/SelfApi.kt index caa67b7573c..d31296a31be 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/self/SelfApi.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/self/SelfApi.kt @@ -19,6 +19,8 @@ package com.wire.kalium.network.api.base.authenticated.self import com.wire.kalium.network.api.base.model.SelfUserDTO +import com.wire.kalium.network.api.base.model.SupportedProtocolDTO +import com.wire.kalium.network.api.base.authenticated.BaseApi import com.wire.kalium.network.utils.NetworkResponse import kotlinx.serialization.SerialName @@ -27,7 +29,7 @@ data class ChangeHandleRequest( @SerialName("handle") val handle: String ) -interface SelfApi { +interface SelfApi : BaseApi { suspend fun getSelfInfo(): NetworkResponse suspend fun updateSelf(userUpdateRequest: UserUpdateRequest): NetworkResponse suspend fun changeHandle(request: ChangeHandleRequest): NetworkResponse @@ -40,4 +42,12 @@ interface SelfApi { */ suspend fun updateEmailAddress(email: String): NetworkResponse suspend fun deleteAccount(password: String?): NetworkResponse + + /** + * Update the supported protocols of the current user. + * @param protocols The updated list of supported protocols. + * @return A [NetworkResponse] with the result of the operation. + * true if the protocols were updated. + */ + suspend fun updateSupportedProtocols(protocols: List): NetworkResponse } diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/self/UpdateSupportedProtocolsRequest.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/self/UpdateSupportedProtocolsRequest.kt new file mode 100644 index 00000000000..902284f7df3 --- /dev/null +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/self/UpdateSupportedProtocolsRequest.kt @@ -0,0 +1,28 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.network.api.base.authenticated.self + +import com.wire.kalium.network.api.base.model.SupportedProtocolDTO +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable +data class UpdateSupportedProtocolsRequest( + @SerialName("supported_protocols") + val protocols: List +) diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/model/UserDTO.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/model/UserDTO.kt index 8375a6381af..cb8023fb6d0 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/model/UserDTO.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/model/UserDTO.kt @@ -35,6 +35,7 @@ sealed class UserDTO { abstract val expiresAt: String? abstract val nonQualifiedId: NonQualifiedUserId abstract val service: ServiceDTO? + abstract val supportedProtocols: List? } @Serializable @@ -51,6 +52,7 @@ data class UserProfileDTO( @Deprecated("use id instead", replaceWith = ReplaceWith("this.id")) @SerialName("id") override val nonQualifiedId: NonQualifiedUserId, @SerialName("service") override val service: ServiceDTO?, + @SerialName("supported_protocols") override val supportedProtocols: List?, @SerialName("legalhold_status") val legalHoldStatus: LegalHoldStatusResponse, ) : UserDTO() @@ -68,10 +70,11 @@ data class SelfUserDTO( @Deprecated("use id instead", replaceWith = ReplaceWith("this.id")) @SerialName("id") override val nonQualifiedId: NonQualifiedUserId, @SerialName("service") override val service: ServiceDTO?, + @SerialName("supported_protocols") override val supportedProtocols: List?, @SerialName("locale") val locale: String, @SerialName("managed_by") val managedByDTO: ManagedByDTO?, @SerialName("phone") val phone: String?, - @SerialName("sso_id") val ssoID: UserSsoIdDTO?, + @SerialName("sso_id") val ssoID: UserSsoIdDTO? ) : UserDTO() @Serializable @@ -109,3 +112,16 @@ enum class ManagedByDTO { return this.name.lowercase() } } + +@Serializable +enum class SupportedProtocolDTO { + @SerialName("proteus") + PROTEUS, + + @SerialName("mls") + MLS; + + override fun toString(): String { + return this.name.lowercase() + } +} diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v0/authenticated/ConversationApiV0.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v0/authenticated/ConversationApiV0.kt index 49bba38c383..1975db37aca 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v0/authenticated/ConversationApiV0.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v0/authenticated/ConversationApiV0.kt @@ -21,6 +21,7 @@ package com.wire.kalium.network.api.v0.authenticated import com.wire.kalium.network.AuthenticatedNetworkClient import com.wire.kalium.network.api.base.authenticated.conversation.AddConversationMembersRequest import com.wire.kalium.network.api.base.authenticated.conversation.AddServiceRequest +import com.wire.kalium.network.api.base.authenticated.conversation.ConvProtocol import com.wire.kalium.network.api.base.authenticated.conversation.ConversationApi import com.wire.kalium.network.api.base.authenticated.conversation.ConversationMemberAddedResponse import com.wire.kalium.network.api.base.authenticated.conversation.ConversationMemberRemovedResponse @@ -37,6 +38,7 @@ import com.wire.kalium.network.api.base.authenticated.conversation.Subconversati import com.wire.kalium.network.api.base.authenticated.conversation.TypingIndicatorStatusDTO import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationAccessRequest import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationAccessResponse +import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationProtocolResponse import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationReceiptModeResponse import com.wire.kalium.network.api.base.authenticated.conversation.messagetimer.ConversationMessageTimerDTO import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationCodeInfo @@ -265,38 +267,33 @@ internal open class ConversationApiV0 internal constructor( } } + override suspend fun fetchMlsOneToOneConversation(userId: UserId): NetworkResponse = + getApiNotSupportedError(::fetchMlsOneToOneConversation.name, MIN_API_VERSION_MLS) + override suspend fun fetchSubconversationDetails( conversationId: ConversationId, subconversationId: SubconversationId ): NetworkResponse = - NetworkResponse.Error( - APINotSupported("MLS: fetchSubconversationDetails api is only available on API V5") - ) + getApiNotSupportedError(::fetchSubconversationDetails.name, MIN_API_VERSION_MLS) override suspend fun fetchSubconversationGroupInfo( conversationId: ConversationId, subconversationId: SubconversationId ): NetworkResponse = - NetworkResponse.Error( - APINotSupported("MLS: fetchSubconversationGroupInfo api is only available on API V5") - ) + getApiNotSupportedError(::fetchSubconversationGroupInfo.name, MIN_API_VERSION_MLS) override suspend fun deleteSubconversation( conversationId: ConversationId, subconversationId: SubconversationId, deleteRequest: SubconversationDeleteRequest ): NetworkResponse = - NetworkResponse.Error( - APINotSupported("MLS: deleteSubconversation api is only available on API V5") - ) + getApiNotSupportedError(::deleteSubconversation.name, MIN_API_VERSION_MLS) override suspend fun leaveSubconversation( conversationId: ConversationId, subconversationId: SubconversationId ): NetworkResponse = - NetworkResponse.Error( - APINotSupported("MLS: leaveSubconversation api is only available on API V5") - ) + getApiNotSupportedError(::leaveSubconversation.name, MIN_API_VERSION_MLS) protected suspend fun handleConversationMemberAddedResponse( httpResponse: HttpResponse @@ -394,6 +391,11 @@ internal open class ConversationApiV0 internal constructor( setBody(typingIndicatorMode) } } + override suspend fun updateProtocol( + conversationId: ConversationId, + protocol: ConvProtocol + ): NetworkResponse = + ConversationApi.getApiNotSupportError("updateProtocol") protected companion object { const val PATH_CONVERSATIONS = "conversations" @@ -412,11 +414,8 @@ internal open class ConversationApiV0 internal constructor( const val PATH_BOTS = "bots" const val QUERY_KEY_CODE = "code" const val QUERY_KEY_KEY = "key" - const val QUERY_KEY_START = "start" - const val QUERY_KEY_SIZE = "size" - const val QUERY_KEY_IDS = "qualified_ids" const val PATH_TYPING_NOTIFICATION = "typing" - const val MAX_CONVERSATION_DETAILS_COUNT = 1000 + const val MIN_API_VERSION_MLS = 5 } } diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v0/authenticated/SelfApiV0.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v0/authenticated/SelfApiV0.kt index dcb59570845..7053948c396 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v0/authenticated/SelfApiV0.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v0/authenticated/SelfApiV0.kt @@ -25,6 +25,7 @@ import com.wire.kalium.network.api.base.authenticated.self.UserUpdateRequest import com.wire.kalium.network.api.base.model.DeleteAccountRequest import com.wire.kalium.network.api.base.model.RefreshTokenProperties import com.wire.kalium.network.api.base.model.SelfUserDTO +import com.wire.kalium.network.api.base.model.SupportedProtocolDTO import com.wire.kalium.network.api.base.model.UpdateEmailRequest import com.wire.kalium.network.exceptions.KaliumException import com.wire.kalium.network.session.SessionManager @@ -44,7 +45,7 @@ internal open class SelfApiV0 internal constructor( private val sessionManager: SessionManager ) : SelfApi { - private val httpClient get() = authenticatedNetworkClient.httpClient + internal val httpClient get() = authenticatedNetworkClient.httpClient override suspend fun getSelfInfo(): NetworkResponse = wrapKaliumResponse { httpClient.get(PATH_SELF) @@ -85,10 +86,15 @@ internal open class SelfApiV0 internal constructor( } } - private companion object { + override suspend fun updateSupportedProtocols(protocols: List): NetworkResponse = + getApiNotSupportedError(::updateSupportedProtocols.name, MIN_API_VERSION_SUPPORTED_PROTOCOLS) + + companion object { const val PATH_SELF = "self" const val PATH_HANDLE = "handle" const val PATH_ACCESS = "access" const val PATH_EMAIL = "email" + + const val MIN_API_VERSION_SUPPORTED_PROTOCOLS = 4 } } diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v4/authenticated/ConversationApiV4.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v4/authenticated/ConversationApiV4.kt index 7154b65db19..60d2f76a846 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v4/authenticated/ConversationApiV4.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v4/authenticated/ConversationApiV4.kt @@ -117,9 +117,4 @@ internal open class ConversationApiV4 internal constructor( setBody(typingIndicatorMode) } } - - companion object { - const val PATH_GROUP_INFO = "groupinfo" - const val PATH_SUBCONVERSATIONS = "subconversations" - } } diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v5/authenticated/ConversationApiV5.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v5/authenticated/ConversationApiV5.kt index deca8f058e0..718ea320155 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v5/authenticated/ConversationApiV5.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v5/authenticated/ConversationApiV5.kt @@ -19,17 +19,28 @@ package com.wire.kalium.network.api.v5.authenticated import com.wire.kalium.network.AuthenticatedNetworkClient +import com.wire.kalium.network.api.base.authenticated.conversation.ConvProtocol +import com.wire.kalium.network.api.base.authenticated.conversation.ConversationResponse import com.wire.kalium.network.api.base.authenticated.conversation.SubconversationDeleteRequest import com.wire.kalium.network.api.base.authenticated.conversation.SubconversationResponse +import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationProtocolRequest +import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationProtocolResponse +import com.wire.kalium.network.api.base.authenticated.notification.EventContentDTO import com.wire.kalium.network.api.base.model.ConversationId import com.wire.kalium.network.api.base.model.QualifiedID import com.wire.kalium.network.api.base.model.SubconversationId +import com.wire.kalium.network.api.base.model.UserId import com.wire.kalium.network.api.v4.authenticated.ConversationApiV4 +import com.wire.kalium.network.exceptions.KaliumException import com.wire.kalium.network.utils.NetworkResponse +import com.wire.kalium.network.utils.mapSuccess import com.wire.kalium.network.utils.wrapKaliumResponse import io.ktor.client.request.delete import io.ktor.client.request.get +import io.ktor.client.request.put import io.ktor.client.request.setBody +import io.ktor.http.HttpStatusCode +import io.ktor.utils.io.errors.IOException internal open class ConversationApiV5 internal constructor( authenticatedNetworkClient: AuthenticatedNetworkClient, @@ -87,4 +98,39 @@ internal open class ConversationApiV5 internal constructor( ) } + override suspend fun updateProtocol( + conversationId: ConversationId, + protocol: ConvProtocol + ): NetworkResponse = try { + httpClient.put("$PATH_CONVERSATIONS/${conversationId.domain}/${conversationId.value}/$PATH_PROTOCOL") { + setBody(UpdateConversationProtocolRequest(protocol)) + }.let { httpResponse -> + when (httpResponse.status) { + HttpStatusCode.NoContent -> NetworkResponse.Success( + UpdateConversationProtocolResponse.ProtocolUnchanged, httpResponse + ) + else -> { + wrapKaliumResponse { httpResponse } + .mapSuccess { + UpdateConversationProtocolResponse.ProtocolUpdated(it) + } + } + } + } + } catch (e: IOException) { + NetworkResponse.Error(KaliumException.GenericError(e)) + } + + override suspend fun fetchMlsOneToOneConversation(userId: UserId): NetworkResponse = + wrapKaliumResponse { + httpClient.get("$PATH_CONVERSATIONS/$PATH_ONE_TO_ONE/${userId.domain}/${userId.value}") + } + + companion object { + const val PATH_PROTOCOL = "protocol" + const val PATH_GROUP_INFO = "groupinfo" + const val PATH_SUBCONVERSATIONS = "subconversations" + const val PATH_ONE_TO_ONE = "one2one" + } + } diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v5/authenticated/SelfApiV5.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v5/authenticated/SelfApiV5.kt index a37cf428262..939f4c98c43 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v5/authenticated/SelfApiV5.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v5/authenticated/SelfApiV5.kt @@ -19,10 +19,28 @@ package com.wire.kalium.network.api.v5.authenticated import com.wire.kalium.network.AuthenticatedNetworkClient +import com.wire.kalium.network.api.base.authenticated.self.UpdateSupportedProtocolsRequest +import com.wire.kalium.network.api.base.model.SupportedProtocolDTO import com.wire.kalium.network.api.v4.authenticated.SelfApiV4 import com.wire.kalium.network.session.SessionManager +import com.wire.kalium.network.utils.NetworkResponse +import com.wire.kalium.network.utils.wrapKaliumResponse +import io.ktor.client.request.put +import io.ktor.client.request.setBody internal open class SelfApiV5 internal constructor( authenticatedNetworkClient: AuthenticatedNetworkClient, sessionManager: SessionManager -) : SelfApiV4(authenticatedNetworkClient, sessionManager) +) : SelfApiV4(authenticatedNetworkClient, sessionManager) { + override suspend fun updateSupportedProtocols( + protocols: List + ): NetworkResponse = wrapKaliumResponse { + httpClient.put("$PATH_SELF/$PATH_SUPPORTED_PROTOCOLS") { + setBody(UpdateSupportedProtocolsRequest(protocols)) + } + } + + companion object { + const val PATH_SUPPORTED_PROTOCOLS = "supported-protocols" + } +} diff --git a/network/src/commonTest/kotlin/com/wire/kalium/api/v0/conversation/ConversationApiV0Test.kt b/network/src/commonTest/kotlin/com/wire/kalium/api/v0/conversation/ConversationApiV0Test.kt index b1c0d474712..3eb8ce94ca9 100644 --- a/network/src/commonTest/kotlin/com/wire/kalium/api/v0/conversation/ConversationApiV0Test.kt +++ b/network/src/commonTest/kotlin/com/wire/kalium/api/v0/conversation/ConversationApiV0Test.kt @@ -45,14 +45,17 @@ import com.wire.kalium.network.api.base.model.ConversationAccessDTO import com.wire.kalium.network.api.base.model.ConversationAccessRoleDTO import com.wire.kalium.network.api.base.model.ConversationId import com.wire.kalium.network.api.base.model.JoinConversationRequestV0 +import com.wire.kalium.network.api.base.model.SupportedProtocolDTO import com.wire.kalium.network.api.base.model.UserId import com.wire.kalium.network.api.v0.authenticated.ConversationApiV0 +import com.wire.kalium.network.api.v0.authenticated.SelfApiV0 import com.wire.kalium.network.utils.NetworkResponse import com.wire.kalium.network.utils.isSuccessful import io.ktor.http.HttpStatusCode import kotlinx.coroutines.test.runTest import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertFalse import kotlin.test.assertIs import kotlin.test.assertTrue @@ -442,6 +445,15 @@ internal class ConversationApiV0Test : ApiTest() { assertIs>(response) } + @Test + fun givenRequest_whenFetchingMlsOneToOneConversation_thenRequestShouldFail() = runTest { + val networkClient = mockAuthenticatedNetworkClient(responseBody = "", statusCode = HttpStatusCode.OK) + val conversationApi = ConversationApiV0(networkClient) + val response = conversationApi.fetchMlsOneToOneConversation(UserId("domain", "id")) + + assertFalse(response.isSuccessful()) + } + private companion object { const val PATH_CONVERSATIONS = "/conversations" const val PATH_CONVERSATIONS_LIST_V2 = "/conversations/list/v2" diff --git a/network/src/commonTest/kotlin/com/wire/kalium/api/v0/featureConfig/FeatureConfigApiV0Test.kt b/network/src/commonTest/kotlin/com/wire/kalium/api/v0/featureConfig/FeatureConfigApiV0Test.kt index 57cb78dc566..ea034541a34 100644 --- a/network/src/commonTest/kotlin/com/wire/kalium/api/v0/featureConfig/FeatureConfigApiV0Test.kt +++ b/network/src/commonTest/kotlin/com/wire/kalium/api/v0/featureConfig/FeatureConfigApiV0Test.kt @@ -32,7 +32,7 @@ import kotlin.test.assertTrue internal class FeatureConfigApiV0Test : ApiTest() { @Test - fun givenValidRequest_WhenCallingTheFileSharingApi_SuccessResponseExpected() = runTest { + fun givenValidRequest_WhenCallingTheFeatureConfigApi_SuccessResponseExpected() = runTest { // Given val apiPath = FEATURE_CONFIG val networkClient = mockAuthenticatedNetworkClient( @@ -55,7 +55,7 @@ internal class FeatureConfigApiV0Test : ApiTest() { } @Test - fun givenInValidRequestWithInsufficientPermission_WhenCallingTheFileSharingApi_ErrorResponseExpected() = runTest { + fun givenInValidRequestWithInsufficientPermission_WhenCallingTheFeatureConfigApi_ErrorResponseExpected() = runTest { // Given val apiPath = FEATURE_CONFIG val networkClient = mockAuthenticatedNetworkClient( @@ -79,7 +79,7 @@ internal class FeatureConfigApiV0Test : ApiTest() { } @Test - fun givenInValidRequestWithNoTeam_WhenCallingTheFileSharingApi_ErrorResponseExpected() = runTest { + fun givenInValidRequestWithNoTeam_WhenCallingFeatureConfigApi_ErrorResponseExpected() = runTest { // Given val apiPath = FEATURE_CONFIG val networkClient = mockAuthenticatedNetworkClient( diff --git a/network/src/commonTest/kotlin/com/wire/kalium/api/v0/user/login/LoginApiV0Test.kt b/network/src/commonTest/kotlin/com/wire/kalium/api/v0/user/login/LoginApiV0Test.kt index 4eb1250d5c0..ed872719df4 100644 --- a/network/src/commonTest/kotlin/com/wire/kalium/api/v0/user/login/LoginApiV0Test.kt +++ b/network/src/commonTest/kotlin/com/wire/kalium/api/v0/user/login/LoginApiV0Test.kt @@ -154,7 +154,8 @@ internal class LoginApiV0Test : ApiTest() { locale = "", managedByDTO = null, phone = null, - ssoID = null + ssoID = null, + supportedProtocols = null ) val VALID_ACCESS_TOKEN_RESPONSE = AccessTokenDTOJson.createValid(accessTokenDto) val VALID_SELF_RESPONSE = UserDTOJson.createValid(userDTO) diff --git a/network/src/commonTest/kotlin/com/wire/kalium/api/v0/user/self/SelfApiV0Test.kt b/network/src/commonTest/kotlin/com/wire/kalium/api/v0/user/self/SelfApiV0Test.kt index c7db79cbe44..c952f111a76 100644 --- a/network/src/commonTest/kotlin/com/wire/kalium/api/v0/user/self/SelfApiV0Test.kt +++ b/network/src/commonTest/kotlin/com/wire/kalium/api/v0/user/self/SelfApiV0Test.kt @@ -21,6 +21,7 @@ package com.wire.kalium.api.v0.user.self import com.wire.kalium.api.ApiTest import com.wire.kalium.api.json.model.ErrorResponseJson import com.wire.kalium.model.UserDTOJson +import com.wire.kalium.network.api.base.model.SupportedProtocolDTO import com.wire.kalium.network.api.v0.authenticated.SelfApiV0 import com.wire.kalium.network.exceptions.KaliumException import com.wire.kalium.network.utils.isSuccessful @@ -104,8 +105,6 @@ internal class SelfApiV0Test : ApiTest() { } } - - @Test fun givenUpdateEmailFailure_whenChangingSelfEmail_thenFailureIsReturned() = runTest { val networkClient = mockAuthenticatedNetworkClient( @@ -125,8 +124,14 @@ internal class SelfApiV0Test : ApiTest() { } } + @Test + fun givenRequest_whenUpdatingSupportedProtocols_thenRequestShouldFail() = runTest { + val networkClient = mockAuthenticatedNetworkClient(responseBody = "", statusCode = HttpStatusCode.OK) + val selfApi = SelfApiV0(networkClient, TEST_SESSION_MANAGER) + val response = selfApi.updateSupportedProtocols(listOf(SupportedProtocolDTO.PROTEUS)) - + assertFalse(response.isSuccessful()) + } private companion object { const val PATH_SELF = "/self" diff --git a/network/src/commonTest/kotlin/com/wire/kalium/api/v4/ConversationApiV4Test.kt b/network/src/commonTest/kotlin/com/wire/kalium/api/v4/ConversationApiV4Test.kt index e4176395c38..69180b147a6 100644 --- a/network/src/commonTest/kotlin/com/wire/kalium/api/v4/ConversationApiV4Test.kt +++ b/network/src/commonTest/kotlin/com/wire/kalium/api/v4/ConversationApiV4Test.kt @@ -19,8 +19,8 @@ package com.wire.kalium.api.v4 import com.wire.kalium.api.ApiTest -import com.wire.kalium.api.json.model.ErrorResponseJson import com.wire.kalium.model.EventContentDTOJson +import com.wire.kalium.api.json.model.ErrorResponseJson import com.wire.kalium.model.conversation.CreateConversationRequestJson import com.wire.kalium.model.conversation.SendTypingStatusNotificationRequestJson import com.wire.kalium.network.api.base.authenticated.conversation.AddConversationMembersRequest diff --git a/network/src/commonTest/kotlin/com/wire/kalium/api/v5/ConversationApiV5Test.kt b/network/src/commonTest/kotlin/com/wire/kalium/api/v5/ConversationApiV5Test.kt index 8df6da3b328..0f987c36a13 100644 --- a/network/src/commonTest/kotlin/com/wire/kalium/api/v5/ConversationApiV5Test.kt +++ b/network/src/commonTest/kotlin/com/wire/kalium/api/v5/ConversationApiV5Test.kt @@ -19,22 +19,30 @@ package com.wire.kalium.api.v5 import com.wire.kalium.api.ApiTest -import com.wire.kalium.model.conversation.CreateConversationRequestJson +import com.wire.kalium.api.v4.ConversationApiV4Test +import com.wire.kalium.model.EventContentDTOJson +import com.wire.kalium.model.conversation.ConversationResponseJson import com.wire.kalium.model.conversation.SubconversationDeleteRequestJson import com.wire.kalium.model.conversation.SubconversationDetailsResponseJson +import com.wire.kalium.network.api.base.authenticated.conversation.ConvProtocol import com.wire.kalium.network.api.base.authenticated.conversation.SubconversationDeleteRequest import com.wire.kalium.network.api.base.authenticated.conversation.SubconversationResponse +import com.wire.kalium.network.api.base.authenticated.conversation.UpdateConversationProtocolResponse import com.wire.kalium.network.api.base.model.ConversationId +import com.wire.kalium.network.api.base.model.UserId +import com.wire.kalium.network.api.v4.authenticated.ConversationApiV4 import com.wire.kalium.network.api.v5.authenticated.ConversationApiV5 import com.wire.kalium.network.utils.NetworkResponse +import com.wire.kalium.network.utils.isSuccessful import io.ktor.http.HttpStatusCode import kotlinx.coroutines.test.runTest import kotlin.test.Test +import kotlin.test.assertEquals import kotlin.test.assertIs +import kotlin.test.assertTrue internal class ConversationApiV5Test : ApiTest() { - @Test fun givenRequest_whenFetchingSubconversationDetails_thenRequestIsConfiguredCorrectly() = runTest { val networkClient = mockAuthenticatedNetworkClient( @@ -136,4 +144,75 @@ internal class ConversationApiV5Test : ApiTest() { "sub", ) } + + @Test + fun given200Response_whenUpdatingConversationProtocol_thenEventIsParsedCorrectly() = runTest { + val conversationId = ConversationId("conversationId", "conversationDomain") + + val networkClient = mockAuthenticatedNetworkClient( + EventContentDTOJson.validUpdateProtocol.rawJson, + statusCode = HttpStatusCode.OK, + assertion = { + assertPut() + assertPathEqual("${PATH_CONVERSATIONS}/${conversationId.domain}/${conversationId.value}/${PATH_PROTOCOL}") + } + ) + val conversationApi = ConversationApiV5(networkClient) + val response = conversationApi.updateProtocol(conversationId, ConvProtocol.MIXED) + + assertIs>(response) + assertIs(response.value) + assertEquals( + EventContentDTOJson.validUpdateProtocol.serializableData, + (response.value as UpdateConversationProtocolResponse.ProtocolUpdated).event + ) + } + + @Test + fun given204Response_whenUpdatingConversationProtocol_thenEventIsParsedCorrectly() = runTest { + val conversationId = ConversationId("conversationId", "conversationDomain") + + val networkClient = mockAuthenticatedNetworkClient( + "", + statusCode = HttpStatusCode.NoContent, + assertion = { + assertPut() + assertPathEqual("${PATH_CONVERSATIONS}/${conversationId.domain}/${conversationId.value}/${PATH_PROTOCOL}") + } + ) + val conversationApi = ConversationApiV5(networkClient) + val response = conversationApi.updateProtocol(conversationId, ConvProtocol.MIXED) + + assertIs>(response) + assertIs(response.value) + } + + @Test + fun whenCallingFetchMlsOneToOneConversation_thenTheRequestShouldBeConfiguredOK() = runTest { + val networkClient = mockAuthenticatedNetworkClient( + FETCH_CONVERSATION_RESPONSE, + statusCode = HttpStatusCode.OK, + assertion = { + assertGet() + assertPathEqual("${PATH_CONVERSATIONS}/one2one/${USER_ID.domain}/${USER_ID.value}") + } + ) + val conversationApi = ConversationApiV5(networkClient) + conversationApi.fetchMlsOneToOneConversation(USER_ID) + } + + @Test + fun given200Response_whenCallingFetchMlsOneToOneConversation_thenResponseIsParsedCorrectly() = runTest { + val networkClient = mockAuthenticatedNetworkClient(FETCH_CONVERSATION_RESPONSE, statusCode = HttpStatusCode.OK) + val conversationApi = ConversationApiV5(networkClient) + + assertTrue(conversationApi.fetchMlsOneToOneConversation(USER_ID).isSuccessful()) + } + + companion object { + const val PATH_CONVERSATIONS = "/conversations" + const val PATH_PROTOCOL = "protocol" + val USER_ID = UserId("id", "domain") + val FETCH_CONVERSATION_RESPONSE = ConversationResponseJson.v0.rawJson + } } diff --git a/network/src/commonTest/kotlin/com/wire/kalium/api/v5/SelfApiV5Test.kt b/network/src/commonTest/kotlin/com/wire/kalium/api/v5/SelfApiV5Test.kt new file mode 100644 index 00000000000..f6d35e248d7 --- /dev/null +++ b/network/src/commonTest/kotlin/com/wire/kalium/api/v5/SelfApiV5Test.kt @@ -0,0 +1,54 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ + +package com.wire.kalium.api.v5 + +import com.wire.kalium.api.ApiTest +import com.wire.kalium.network.api.base.model.SupportedProtocolDTO +import com.wire.kalium.network.api.v5.authenticated.SelfApiV5 +import com.wire.kalium.network.utils.isSuccessful +import io.ktor.http.HttpStatusCode +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertTrue + +@ExperimentalCoroutinesApi +internal class SelfApiV5Test : ApiTest() { + @Test + fun givenValidRequest_whenUpdatingSupportedProtocols_theRequestShouldBeConfiguredCorrectly() = + runTest { + val networkClient = mockAuthenticatedNetworkClient( + "", + statusCode = HttpStatusCode.OK, + assertion = { + assertPut() + assertNoQueryParams() + assertPathEqual("$PATH_SELF/$PATH_SUPPORTED_PROTOCOLS") + } + ) + val selfApi = SelfApiV5(networkClient, TEST_SESSION_MANAGER) + val response = selfApi.updateSupportedProtocols(listOf(SupportedProtocolDTO.MLS)) + assertTrue(response.isSuccessful()) + } + + private companion object { + const val PATH_SELF = "/self" + const val PATH_SUPPORTED_PROTOCOLS = "supported-protocols" + } +} diff --git a/network/src/commonTest/kotlin/com/wire/kalium/model/EventContentDTOJson.kt b/network/src/commonTest/kotlin/com/wire/kalium/model/EventContentDTOJson.kt index 53a398c1979..83c7512d705 100644 --- a/network/src/commonTest/kotlin/com/wire/kalium/model/EventContentDTOJson.kt +++ b/network/src/commonTest/kotlin/com/wire/kalium/model/EventContentDTOJson.kt @@ -19,10 +19,12 @@ package com.wire.kalium.model import com.wire.kalium.api.json.ValidJsonProvider +import com.wire.kalium.network.api.base.authenticated.conversation.ConvProtocol import com.wire.kalium.network.api.base.authenticated.conversation.ConversationMembers import com.wire.kalium.network.api.base.authenticated.conversation.ConversationUsers import com.wire.kalium.network.api.base.authenticated.conversation.ReceiptMode import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationAccessInfoDTO +import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationProtocolDTO import com.wire.kalium.network.api.base.authenticated.conversation.model.ConversationReceiptModeDTO import com.wire.kalium.network.api.base.authenticated.notification.EventContentDTO import com.wire.kalium.network.api.base.model.ConversationAccessDTO @@ -138,6 +140,28 @@ object EventContentDTOJson { """.trimMargin() } + private val jsonProviderUpdateConversationProtocol = { serializable: EventContentDTO.Conversation.ProtocolUpdate -> + """ + |{ + | "conversation":"${serializable.qualifiedConversation.value}", + | "data":{ + | "protocol": "mixed" + | }, + | "from":"${serializable.qualifiedFrom.value}", + | "qualified_conversation": { + | "id": "${serializable.qualifiedConversation.value}", + | "domain": "${serializable.qualifiedConversation.domain}" + | }, + | "qualified_from" : { + | "id" : "${serializable.qualifiedFrom.value}", + | "domain" : "${serializable.qualifiedFrom.domain}" + | }, + | "time":"2023-01-27T10:35:10.146Z", + | "type":"conversation.protocol-update" + |} + """.trimMargin() + } + val validAccessUpdate = ValidJsonProvider( EventContentDTO.Conversation.AccessUpdate( qualifiedConversation = ConversationId("ebafd3d4-1548-49f2-ac4e-b2757e6ca44b", "anta.wire.link"), @@ -215,6 +239,21 @@ object EventContentDTOJson { jsonProviderUpdateConversationReceiptMode ) + val validUpdateProtocol = ValidJsonProvider( + EventContentDTO.Conversation.ProtocolUpdate( + qualifiedConversation = QualifiedID( + value = "conversationId", + domain = "conversationDomain" + ), + qualifiedFrom = QualifiedID( + value = "qualifiedFromId", + domain = "qualifiedFromDomain" + ), + data = ConversationProtocolDTO(ConvProtocol.MIXED) + ), + jsonProviderUpdateConversationProtocol + ) + val validGenerateGuestRoomLink = """ |{ | "qualified_conversation" : { diff --git a/network/src/commonTest/kotlin/com/wire/kalium/model/FeatureConfigJson.kt b/network/src/commonTest/kotlin/com/wire/kalium/model/FeatureConfigJson.kt index 13258823e6f..b7a88cc8976 100644 --- a/network/src/commonTest/kotlin/com/wire/kalium/model/FeatureConfigJson.kt +++ b/network/src/commonTest/kotlin/com/wire/kalium/model/FeatureConfigJson.kt @@ -40,8 +40,11 @@ import com.wire.kalium.network.api.base.authenticated.featureConfigs.FeatureConf import com.wire.kalium.network.api.base.authenticated.featureConfigs.FeatureFlagStatusDTO import com.wire.kalium.network.api.base.authenticated.featureConfigs.MLSConfigDTO import com.wire.kalium.network.api.base.authenticated.featureConfigs.E2EIConfigDTO +import com.wire.kalium.network.api.base.authenticated.featureConfigs.MLSMigrationConfigDTO import com.wire.kalium.network.api.base.authenticated.featureConfigs.SelfDeletingMessagesConfigDTO import com.wire.kalium.network.api.base.model.ErrorResponse +import com.wire.kalium.network.api.base.model.SupportedProtocolDTO +import kotlinx.datetime.Instant object FeatureConfigJson { private val featureConfigResponseSerializer = { _: FeatureConfigResponse -> @@ -100,6 +103,7 @@ object FeatureConfigJson { | "config": { | "protocolToggleUsers": ["60368759-d23f-4502-ba6f-68b10e926f7a"], | "defaultProtocol": "proteus", + | "supportedProtocols": ["proteus", "mls"], | "allowedCipherSuites": [1], | "defaultCipherSuite": 1 | } @@ -125,10 +129,14 @@ object FeatureConfigJson { SSO(FeatureFlagStatusDTO.ENABLED), ValidateSAMLEmails(FeatureFlagStatusDTO.ENABLED), MLS( - MLSConfigDTO(emptyList(), ConvProtocol.PROTEUS, listOf(1), 1), + MLSConfigDTO(emptyList(), SupportedProtocolDTO.PROTEUS, listOf(SupportedProtocolDTO.PROTEUS), listOf(1), 1), FeatureFlagStatusDTO.ENABLED ), - FeatureConfigData.E2EI(E2EIConfigDTO("url", 0L), FeatureFlagStatusDTO.ENABLED) + FeatureConfigData.E2EI(E2EIConfigDTO("url", 0L), FeatureFlagStatusDTO.ENABLED), + FeatureConfigData.MLSMigration( + MLSMigrationConfigDTO(Instant.DISTANT_FUTURE, Instant.DISTANT_FUTURE), + FeatureFlagStatusDTO.ENABLED + ) ), featureConfigResponseSerializer ) diff --git a/network/src/commonTest/kotlin/com/wire/kalium/model/ListUsersResponseJson.kt b/network/src/commonTest/kotlin/com/wire/kalium/model/ListUsersResponseJson.kt index 98ec2b08da5..8a3cca3dcdb 100644 --- a/network/src/commonTest/kotlin/com/wire/kalium/model/ListUsersResponseJson.kt +++ b/network/src/commonTest/kotlin/com/wire/kalium/model/ListUsersResponseJson.kt @@ -21,8 +21,11 @@ package com.wire.kalium.model import com.wire.kalium.api.json.ValidJsonProvider import com.wire.kalium.network.api.base.authenticated.userDetails.ListUsersDTO import com.wire.kalium.network.api.base.model.LegalHoldStatusResponse +import com.wire.kalium.network.api.base.model.SupportedProtocolDTO import com.wire.kalium.network.api.base.model.UserId import com.wire.kalium.network.api.base.model.UserProfileDTO +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json object ListUsersResponseJson { @@ -43,7 +46,8 @@ object ListUsersResponseJson { email = null, expiresAt = null, nonQualifiedId = USER_1.value, - service = null + service = null, + supportedProtocols = listOf(SupportedProtocolDTO.PROTEUS, SupportedProtocolDTO.MLS) ), UserProfileDTO( id = USER_2, @@ -57,47 +61,64 @@ object ListUsersResponseJson { email = null, expiresAt = null, nonQualifiedId = USER_2.value, - service = null + service = null, + supportedProtocols = listOf(SupportedProtocolDTO.PROTEUS) ), ) - private val validUserInfoProvider = { userInfo: List -> + private val validUserInfoProviderV0 = { userInfo: UserProfileDTO -> + """ + |{ + | "accent_id": ${userInfo.accentId}, + | "handle": "${userInfo.handle}", + | "legalhold_status": "enabled", + | "name": "${userInfo.name}", + | "assets": ${userInfo.assets}, + | "id": "${userInfo.id.value}", + | "deleted": "false", + | "qualified_id": { + | "domain": "${userInfo.id.domain}", + | "id": "${userInfo.id.value}" + | } + |} + """.trimMargin() + } + + private val validUserInfoProviderV4 = { userInfo: UserProfileDTO -> + """ + |{ + | "accent_id": ${userInfo.accentId}, + | "handle": "${userInfo.handle}", + | "legalhold_status": "enabled", + | "name": "${userInfo.name}", + | "assets": ${userInfo.assets}, + | "id": "${userInfo.id.value}", + | "deleted": "false", + | "supported_protocols": ${Json.encodeToString(userInfo.supportedProtocols)}, + | "qualified_id": { + | "domain": "${userInfo.id.domain}", + | "id": "${userInfo.id.value}" + | } + |} + """.trimMargin() + } + + private val listProvider = { list: List -> """ |[ - | { - | "accent_id": ${userInfo[0].accentId}, - | "handle": "${userInfo[0].handle}", - | "legalhold_status": "enabled", - | "name": "${userInfo[0].name}", - | "assets": ${userInfo[0].assets}, - | "id": "${userInfo[0].id.value}", - | "deleted": "false", - | "qualified_id": { - | "domain": "${userInfo[0].id.domain}", - | "id": "${userInfo[0].id.value}" - | } - | }, - | { - | "accent_id": ${userInfo[1].accentId}, - | "handle": "${userInfo[1].handle}", - | "legalhold_status": "enabled", - | "name": "${userInfo[1].name}", - | "assets": ${userInfo[1].assets}, - | "id": "${userInfo[1].id.value}", - | "deleted": "false", - | "qualified_id": { - | "domain": "${userInfo[1].id.domain}", - | "id": "${userInfo[1].id.value}" - | } - | } + | ${list[0]}, + | ${list[1]} |] """.trimMargin() + } val v0 = ValidJsonProvider( - expectedUsersResponse + expectedUsersResponse.map { + it.copy(supportedProtocols = null) // we don't expect supported_protocols in v0 + } ) { - validUserInfoProvider(it) + listProvider(it.map(validUserInfoProviderV0)) } val v4_withFailedUsers = ValidJsonProvider( @@ -111,7 +132,7 @@ object ListUsersResponseJson { | "id": "${it.usersFailed[0].value}" | } | ], - | "found": ${validUserInfoProvider(it.usersFound)} + | "found": ${ listProvider(it.usersFound.map(validUserInfoProviderV4)) } |} """.trimMargin() } @@ -121,7 +142,7 @@ object ListUsersResponseJson { ) { """ |{ - | "found": ${validUserInfoProvider(it.usersFound)} + | "found": ${ listProvider(it.usersFound.map(validUserInfoProviderV4)) } |} """.trimMargin() } diff --git a/network/src/commonTest/kotlin/com/wire/kalium/model/NotificationEventsResponseJson.kt b/network/src/commonTest/kotlin/com/wire/kalium/model/NotificationEventsResponseJson.kt index 37bd3a68477..5dad94b0bc9 100644 --- a/network/src/commonTest/kotlin/com/wire/kalium/model/NotificationEventsResponseJson.kt +++ b/network/src/commonTest/kotlin/com/wire/kalium/model/NotificationEventsResponseJson.kt @@ -40,6 +40,7 @@ import com.wire.kalium.network.api.base.authenticated.notification.NotificationR import com.wire.kalium.network.api.base.model.ConversationId import com.wire.kalium.network.api.base.model.LocationResponse import com.wire.kalium.network.api.base.model.QualifiedID +import com.wire.kalium.network.api.base.model.SupportedProtocolDTO import kotlinx.serialization.InternalSerializationApi import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json @@ -198,7 +199,7 @@ object NotificationEventsResponseJson { private val newMlsFeatureConfigUpdate = ValidJsonProvider( EventContentDTO.FeatureConfig.FeatureConfigUpdatedDTO( MLS( - MLSConfigDTO(emptyList(), ConvProtocol.MLS, listOf(1), 1), + MLSConfigDTO(emptyList(), SupportedProtocolDTO.MLS, listOf(SupportedProtocolDTO.PROTEUS), listOf(1), 1), FeatureFlagStatusDTO.ENABLED, ) ), diff --git a/network/src/commonTest/kotlin/com/wire/kalium/model/UserDTOJson.kt b/network/src/commonTest/kotlin/com/wire/kalium/model/UserDTOJson.kt index 01f54570f7c..4aec3f2e0e2 100644 --- a/network/src/commonTest/kotlin/com/wire/kalium/model/UserDTOJson.kt +++ b/network/src/commonTest/kotlin/com/wire/kalium/model/UserDTOJson.kt @@ -91,7 +91,8 @@ object UserDTOJson { locale = "", managedByDTO = null, phone = null, - ssoID = null + ssoID = null, + supportedProtocols = null ), jsonProvider ) } diff --git a/persistence/src/androidInstrumentedTest/kotlin/com/wire/kalium/persistence/dao/message/MessageExtensionsTest.kt b/persistence/src/androidInstrumentedTest/kotlin/com/wire/kalium/persistence/dao/message/MessageExtensionsTest.kt index 067162fc28d..11b5bdc00f2 100644 --- a/persistence/src/androidInstrumentedTest/kotlin/com/wire/kalium/persistence/dao/message/MessageExtensionsTest.kt +++ b/persistence/src/androidInstrumentedTest/kotlin/com/wire/kalium/persistence/dao/message/MessageExtensionsTest.kt @@ -136,7 +136,7 @@ class MessageExtensionsTest : BaseDatabaseTest() { private suspend fun populateMessageData() { val userId = UserIDEntity("user", "domain") - userDAO.insertUser(newUserEntity(qualifiedID = userId)) + userDAO.upsertUser(newUserEntity(qualifiedID = userId)) conversationDAO.insertConversation(newConversationEntity(id = CONVERSATION_ID)) val messages = buildList { repeat(MESSAGE_COUNT) { diff --git a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Clients.sq b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Clients.sq index 4138cb864a2..6d3f0298cf4 100644 --- a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Clients.sq +++ b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Clients.sq @@ -19,6 +19,7 @@ CREATE TABLE Client ( model TEXT DEFAULT NULL, last_active INTEGER AS Instant DEFAULT NULL, mls_public_keys TEXT AS Map DEFAULT NULL, + is_mls_capable INTEGER AS Boolean NOT NULL DEFAULT 0, FOREIGN KEY (user_id) REFERENCES User(qualified_id) ON DELETE CASCADE, PRIMARY KEY (user_id, id) ); @@ -33,8 +34,8 @@ deleteClientsOfUser: DELETE FROM Client WHERE user_id = ?; insertClient: -INSERT INTO Client(user_id, id, device_type, client_type, is_valid, registration_date, label, model, last_active, mls_public_keys) -VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?) +INSERT INTO Client(user_id, id, device_type, client_type, is_valid, registration_date, label, model, last_active, mls_public_keys, is_mls_capable) +VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(id, user_id) DO UPDATE SET device_type = coalesce(excluded.device_type, device_type), registration_date = coalesce(excluded.registration_date, registration_date), @@ -42,7 +43,8 @@ label = coalesce(excluded.label, label), model = coalesce(excluded.model, model), is_valid = is_valid, last_active = coalesce(excluded.last_active, last_active), -mls_public_keys = excluded.mls_public_keys; +mls_public_keys = excluded.mls_public_keys, +is_mls_capable = excluded.is_mls_capable OR is_mls_capable; -- it's not possible to remove mls capability once added selectAllClients: SELECT * FROM Client; diff --git a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Connections.sq b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Connections.sq index 7df9f1773a1..715c508427e 100644 --- a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Connections.sq +++ b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Connections.sq @@ -38,6 +38,9 @@ UPDATE Connection SET last_update_date = ? WHERE to_id = ?; updateNotificationFlag: UPDATE Connection SET should_notify = ? WHERE qualified_to = ?; +updateConnectionConversation: +UPDATE Connection SET conversation_id = ?, qualified_conversation = ? WHERE qualified_to = ?; + setAllConnectionsAsNotified: UPDATE Connection SET should_notify = 0 WHERE status = 'PENDING' AND should_notify = 1; diff --git a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq index 98503639ce7..679b54ceb26 100644 --- a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq +++ b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq @@ -183,6 +183,10 @@ CASE (Conversation.type) WHEN 'ONE_ON_ONE' THEN User.defederated WHEN 'CONNECTION_PENDING' THEN connection_user.defederated END AS userDefederated, +CASE (Conversation.type) + WHEN 'ONE_ON_ONE' THEN User.supported_protocols + WHEN 'CONNECTION_PENDING' THEN connection_user.supported_protocols +END AS userSupportedProtocols, CASE (Conversation.type) WHEN 'ONE_ON_ONE' THEN User.connection_status WHEN 'CONNECTION_PENDING' THEN connection_user.connection_status @@ -191,10 +195,18 @@ CASE (Conversation.type) WHEN 'ONE_ON_ONE' THEN User.qualified_id WHEN 'CONNECTION_PENDING' THEN connection_user.qualified_id END AS otherUserId, +CASE (Conversation.type) + WHEN 'ONE_ON_ONE' THEN User.active_one_on_one_conversation_id + WHEN 'CONNECTION_PENDING' THEN connection_user.active_one_on_one_conversation_id +END AS otherUserActiveConversationId, CASE WHEN ((SELECT id FROM SelfUser LIMIT 1) LIKE (Conversation.creator_id || '@%')) THEN 1 ELSE 0 END AS isCreator, +CASE (Conversation.type) + WHEN 'ONE_ON_ONE' THEN coalesce(User.active_one_on_one_conversation_id = Conversation.qualified_id, 0) + ELSE 1 +END AS isActive, Conversation.last_notified_date AS lastNotifiedMessageDate, memberRole. role AS selfRole, Conversation.protocol, @@ -241,13 +253,22 @@ WHERE type IS 'GROUP' OR (type IS NOT 'GROUP' AND (name IS NOT NULL AND otherUserId IS NOT NULL)) -- show 1:1 convos and connection requests if they have user metadata OR (type IS 'ONE_ON_ONE' AND userDeleted = 1) -- show deleted 1:1 convos, to maintain prev, logic ) - AND (protocol IS 'PROTEUS' OR (protocol IS 'MLS' AND mls_group_state IS 'ESTABLISHED')) + AND (protocol IS 'PROTEUS' OR protocol IS 'MIXED' OR (protocol IS 'MLS' AND mls_group_state IS 'ESTABLISHED')) AND archived = :fromArchive + AND isActive ORDER BY lastModifiedDate DESC, name IS NULL, name COLLATE NOCASE ASC; selectAllConversations: SELECT * FROM ConversationDetails WHERE type IS NOT 'CONNECTION_PENDING' ORDER BY last_modified_date DESC, name ASC; +selectAllTeamProteusConversationsReadyForMigration: +SELECT +qualified_id, +(SELECT count(*) FROM Member WHERE conversation = qualified_id) AS memberCount, +(SELECT count(*) FROM Member LEFT JOIN User ON User.qualified_id = Member.user WHERE Member.conversation = Conversation.qualified_id AND (User.supported_protocols = 'MLS' OR User.supported_protocols = 'MLS,PROTEUS' OR User.supported_protocols = 'PROTEUS,MLS')) AS mlsCapableMemberCount +FROM Conversation +WHERE type IS 'GROUP' AND protocol IS 'MIXED' AND team_id = ? AND memberCount = mlsCapableMemberCount; + selectByQualifiedId: SELECT * FROM ConversationDetails WHERE qualifiedId = ?; @@ -265,11 +286,23 @@ selectByGroupId: SELECT * FROM ConversationDetails WHERE mls_group_id = ?; selectByGroupState: -SELECT * FROM ConversationDetails WHERE mls_group_state = ? AND protocol = ?; +SELECT * FROM ConversationDetails WHERE mls_group_state = ? AND (protocol IS "MLS" OR protocol IS "MIXED"); + +selectActiveOneOnOneConversation: +SELECT * FROM ConversationDetails +WHERE qualifiedId = (SELECT active_one_on_one_conversation_id FROM User WHERE qualified_id = :user_id); + +selectOneOnOneConversationIdsByProtocol: +SELECT Member.conversation FROM Member +JOIN Conversation ON Conversation.qualified_id = Member.conversation +WHERE Conversation.type = 'ONE_ON_ONE' AND Conversation.protocol = :protocol AND Member.user = :user; getConversationIdByGroupId: SELECT qualified_id FROM Conversation WHERE mls_group_id = ?; +selectConversationIds: +SELECT qualified_id FROM Conversation WHERE protocol = :protocol AND type = :type AND (:teamId IS NULL OR team_id = :teamId); + updateConversationMutingStatus: UPDATE Conversation SET muted_status = ?, muted_time = ? @@ -287,7 +320,7 @@ updateKeyingMaterialDate: UPDATE Conversation SET mls_last_keying_material_update_date= ? WHERE mls_group_id = ?; selectByKeyingMaterialUpdate: -SELECT mls_group_id FROM Conversation WHERE mls_group_state = ? AND protocol = ? AND mls_last_keying_material_update_date - ? <0 AND mls_group_id IS NOT NULL; +SELECT mls_group_id FROM Conversation WHERE mls_group_state = ? AND (protocol IS "MLS" OR protocol IS "MIXED") AND mls_last_keying_material_update_date - ? <0 AND mls_group_id IS NOT NULL; updateProposalTimer: UPDATE Conversation SET mls_proposal_timer = COALESCE(mls_proposal_timer, ?) WHERE mls_group_id = ?; @@ -296,7 +329,7 @@ clearProposalTimer: UPDATE Conversation SET mls_proposal_timer = NULL WHERE mls_group_id = ?; selectProposalTimers: -SELECT mls_group_id, mls_proposal_timer FROM Conversation WHERE protocol = ? AND mls_group_id IS NOT NULL AND mls_proposal_timer IS NOT NULL; +SELECT mls_group_id, mls_proposal_timer FROM Conversation WHERE (protocol IS "MLS" OR protocol IS "MIXED") AND mls_group_id IS NOT NULL AND mls_proposal_timer IS NOT NULL; whoDeletedMeInConversation: SELECT sender_user_id FROM Message WHERE id IN (SELECT message_id FROM MessageMemberChangeContent WHERE conversation_id = :conversation_id AND member_change_type = 'REMOVED' AND member_change_list LIKE ('%' || :self_user_id || '%')) ORDER BY creation_date DESC LIMIT 1; @@ -311,6 +344,13 @@ UPDATE Conversation SET type = ? WHERE qualified_id = ?; +updateConversationProtocol { +UPDATE Conversation +SET protocol = :protocol +WHERE qualified_id = :qualified_id AND protocol != :protocol; +SELECT changes(); +} + selfConversationId: SELECT qualified_id FROM Conversation WHERE type = 'SELF' AND protocol = ? LIMIT 1; diff --git a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Members.sq b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Members.sq index e2ac5c8c21d..9b0f8179086 100644 --- a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Members.sq +++ b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Members.sq @@ -25,14 +25,34 @@ DELETE FROM Member WHERE conversation = ?; selectAllMembersByConversation: SELECT * FROM Member WHERE conversation = :conversation; -selectConversationByMember: -SELECT * FROM Member -JOIN ConversationDetails ON ConversationDetails.qualifiedId = Member.conversation -WHERE ConversationDetails.type = 'ONE_ON_ONE' AND Member.user = ? -LIMIT 1; - selectConversationsByMember: -SELECT * FROM Member +SELECT + qualified_id, + name, + type, + team_id, + mls_group_id, + mls_group_state, + mls_epoch, + mls_proposal_timer, + protocol, + muted_status, + muted_time, + creator_id, + last_modified_date, + last_notified_date, + last_read_date, + access_list, + access_role_list, + mls_last_keying_material_update_date, + mls_cipher_suite, + receipt_mode, + message_timer, + user_message_timer, + archived, + archived_date_time, + verification_status +FROM Member JOIN Conversation ON Conversation.qualified_id = Member.conversation WHERE Member.user = ?; diff --git a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/MessageDetailsView.sq b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/MessageDetailsView.sq index 5e887719764..167a50796d3 100644 --- a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/MessageDetailsView.sq +++ b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/MessageDetailsView.sq @@ -116,6 +116,7 @@ ConversationReceiptModeChanged.receipt_mode AS conversationReceiptModeChanged, ConversationTimerChangedContent.message_timer AS messageTimerChanged, FailedRecipientsWithNoClients.recipient_failure_list AS recipientsFailedWithNoClientsList, FailedRecipientsDeliveryFailed.recipient_failure_list AS recipientsFailedDeliveryList, + IFNULL( (SELECT '[' || GROUP_CONCAT('{"text":"' || text || '", "id":"' || id || '""is_selected":' || is_selected || '}') @@ -127,7 +128,8 @@ IFNULL( '[]' ) AS buttonsJson, FederationTerminatedContent.domain_list AS federationDomainList, -FederationTerminatedContent.federation_type AS federationType +FederationTerminatedContent.federation_type AS federationType, +ConversationProtocolChangedContent.protocol AS conversationProtocolChanged FROM Message JOIN User ON Message.sender_user_id = User.qualified_id @@ -152,5 +154,6 @@ LEFT JOIN MessageNewConversationReceiptModeContent AS NewConversationReceiptMode LEFT JOIN MessageConversationReceiptModeChangedContent AS ConversationReceiptModeChanged ON Message.id = ConversationReceiptModeChanged.message_id AND Message.conversation_id = ConversationReceiptModeChanged.conversation_id LEFT JOIN MessageConversationTimerChangedContent AS ConversationTimerChangedContent ON Message.id = ConversationTimerChangedContent.message_id AND Message.conversation_id = ConversationTimerChangedContent.conversation_id LEFT JOIN MessageFederationTerminatedContent AS FederationTerminatedContent ON Message.id = FederationTerminatedContent.message_id AND Message.conversation_id = FederationTerminatedContent.conversation_id +LEFT JOIN MessageConversationProtocolChangedContent AS ConversationProtocolChangedContent ON Message.id = ConversationProtocolChangedContent.message_id AND Message.conversation_id = ConversationProtocolChangedContent.conversation_id LEFT JOIN SelfUser; -- TODO: Remove IFNULL functions above if we can force SQLDelight to not unpack as notnull diff --git a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Messages.sq b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Messages.sq index 938c54c71d4..cb110bb1f79 100644 --- a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Messages.sq +++ b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Messages.sq @@ -1,3 +1,4 @@ +import com.wire.kalium.persistence.dao.conversation.ConversationEntity; import com.wire.kalium.persistence.dao.QualifiedIDEntity; import com.wire.kalium.persistence.dao.message.MessageEntity.ContentType; import com.wire.kalium.persistence.dao.message.MessageEntity.FederationType; @@ -211,6 +212,15 @@ CREATE TABLE MessageRecipientFailure ( PRIMARY KEY (message_id, conversation_id, recipient_failure_type) ); +CREATE TABLE MessageConversationProtocolChangedContent ( + message_id TEXT NOT NULL, + conversation_id TEXT AS QualifiedIDEntity NOT NULL, + protocol TEXT AS ConversationEntity.Protocol NOT NULL, + + FOREIGN KEY (message_id, conversation_id) REFERENCES Message(id, conversation_id) ON DELETE CASCADE ON UPDATE CASCADE, + PRIMARY KEY (message_id, conversation_id) +); + needsToBeNotified: WITH targetMessage(isSelfMessage, isMentioningSelfUser, isQuotingSelfUser, mutedStatus) AS ( SELECT isSelfMessage, @@ -369,6 +379,10 @@ insertConversationMessageTimerChanged: INSERT OR IGNORE INTO MessageConversationTimerChangedContent(message_id, conversation_id, message_timer) VALUES(?, ?, ?); +insertConversationProtocolChanged: +INSERT OR IGNORE INTO MessageConversationProtocolChangedContent(message_id, conversation_id, protocol) +VALUES(?, ?, ?); + updateMessageStatus: UPDATE Message SET status = ? @@ -492,3 +506,8 @@ WHERE conversation_id = ? AND id = ?; insertMessageRecipientsFailure: INSERT OR IGNORE INTO MessageRecipientFailure(message_id, conversation_id, recipient_failure_list, recipient_failure_type) VALUES(?, ?, ?, ?); + +moveMessages: +UPDATE OR REPLACE Message +SET conversation_id = :to +WHERE conversation_id = :from; diff --git a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Users.sq b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Users.sq index 5aae6b1f07e..b54af091006 100644 --- a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Users.sq +++ b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Users.sq @@ -3,9 +3,11 @@ import com.wire.kalium.persistence.dao.ConnectionEntity; import com.wire.kalium.persistence.dao.QualifiedIDEntity; import com.wire.kalium.persistence.dao.UserAvailabilityStatusEntity; import com.wire.kalium.persistence.dao.UserTypeEntity; +import com.wire.kalium.persistence.dao.SupportedProtocolEntity; import kotlin.Int; import kotlinx.datetime.Instant; import kotlin.Boolean; +import kotlin.collections.Set; CREATE TABLE User ( qualified_id TEXT AS QualifiedIDEntity NOT NULL PRIMARY KEY, @@ -24,7 +26,9 @@ CREATE TABLE User ( deleted INTEGER AS Boolean NOT NULL DEFAULT 0, incomplete_metadata INTEGER AS Boolean NOT NULL DEFAULT 0, expires_at INTEGER AS Instant, - defederated INTEGER AS Boolean NOT NULL DEFAULT 0 + defederated INTEGER AS Boolean NOT NULL DEFAULT 0, + supported_protocols TEXT AS Set DEFAULT 'PROTEUS', + active_one_on_one_conversation_id TEXT AS QualifiedIDEntity ); CREATE INDEX user_team_index ON User(team); CREATE INDEX user_service_id ON User(bot_service); @@ -33,8 +37,8 @@ deleteUser: DELETE FROM User WHERE qualified_id = ?; insertUser: -INSERT INTO User(qualified_id, name, handle, email, phone, accent_id, team, connection_status, preview_asset_id, complete_asset_id, user_type, bot_service, deleted, incomplete_metadata, expires_at) -VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) +INSERT INTO User(qualified_id, name, handle, email, phone, accent_id, team, connection_status, preview_asset_id, complete_asset_id, user_type, bot_service, deleted, incomplete_metadata, expires_at, supported_protocols, active_one_on_one_conversation_id) +VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(qualified_id) DO UPDATE SET name = excluded.name, handle = excluded.handle, @@ -42,7 +46,6 @@ email = excluded.email, phone = excluded.phone, accent_id = excluded.accent_id, team = excluded.team, -connection_status = excluded.connection_status, preview_asset_id = excluded.preview_asset_id, complete_asset_id = excluded.complete_asset_id, user_type = excluded.user_type, @@ -50,26 +53,33 @@ bot_service = excluded.bot_service, deleted = excluded.deleted, incomplete_metadata = excluded.incomplete_metadata, expires_at = excluded.expires_at, -defederated = 0; +defederated = 0, +supported_protocols = excluded.supported_protocols; insertOrIgnoreUser: -INSERT OR IGNORE INTO User(qualified_id, name, handle, email, phone, accent_id, team, connection_status, preview_asset_id, complete_asset_id, user_type, bot_service, deleted, incomplete_metadata, expires_at) -VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); - -updateUser: -UPDATE User -SET name = ?, handle = ?, email = ?, phone = ?, accent_id = ?, team = ?, preview_asset_id = ?, complete_asset_id = ?, user_type = ?, bot_service = ?, incomplete_metadata = ?, expires_at = ? -WHERE qualified_id = ?; +INSERT OR IGNORE INTO User(qualified_id, name, handle, email, phone, accent_id, team, connection_status, preview_asset_id, complete_asset_id, user_type, bot_service, deleted, incomplete_metadata, expires_at, supported_protocols) +VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); -updateTeamMemberUser: +updateUser { UPDATE User -SET name = ?, handle = ?, email = ?, phone = ?, accent_id = ?, team = ?, preview_asset_id = ?, complete_asset_id = ?, bot_service = ?, incomplete_metadata = 0 +SET +name = coalesce(:name, name), +handle = coalesce(:handle, handle), +email = coalesce(:email, email), +accent_id = coalesce(:accent_id, accent_id), +preview_asset_id = :preview_asset_id, preview_asset_id = coalesce(:preview_asset_id, preview_asset_id), +complete_asset_id = :complete_asset_id, complete_asset_id = coalesce(:complete_asset_id, complete_asset_id), +supported_protocols = :supported_protocols, supported_protocols = coalesce(:supported_protocols, supported_protocols) WHERE qualified_id = ?; +SELECT changes(); +} -updateTeamMemberType: -UPDATE User -SET team = ?, connection_status = ?, user_type = ? -WHERE qualified_id = ?; +upsertTeamMemberUserType: +INSERT INTO User(qualified_id, connection_status, user_type) +VALUES(?, ?, ?) +ON CONFLICT(qualified_id) DO UPDATE SET +connection_status = excluded.connection_status, +user_type = excluded.user_type; markUserAsDeleted: UPDATE User @@ -85,11 +95,6 @@ insertOrIgnoreUserId: INSERT OR IGNORE INTO User(qualified_id, incomplete_metadata) VALUES(?, 1); -updateSelfUser: -UPDATE User -SET name = ?, handle = ?, email = ?, accent_id = ?, preview_asset_id = ?, complete_asset_id = ? -WHERE qualified_id = ?; - insertOrIgnoreUserIdWithConnectionStatus: INSERT OR IGNORE INTO User(qualified_id, connection_status) VALUES(?, ?); @@ -119,6 +124,8 @@ User.deleted, User.incomplete_metadata, User.expires_at, User.defederated, +User.supported_protocols, +User.active_one_on_one_conversation_id, CASE WHEN SUM(Client.is_verified) = COUNT(*) THEN 1 ELSE 0 @@ -203,5 +210,21 @@ selectUsersWithoutMetadata: SELECT * FROM UserDetails AS user WHERE deleted = 0 AND incomplete_metadata = 1; +selectUsersWithOneOnOne: +SELECT * FROM User +WHERE deleted = 0 AND qualified_id != (SELECT id FROM SelfUser) AND qualified_id IN +( + SELECT user FROM Member + JOIN Conversation ON Conversation.qualified_id = Member.conversation + WHERE Conversation.type = 'ONE_ON_ONE' AND Member.user = User.qualified_id + LIMIT 1 +); + userIdsWithoutSelf: SELECT qualified_id FROM User WHERE qualified_id != (SELECT id FROM SelfUser); + +updateUserSupportedProtocols: +UPDATE User SET supported_protocols = ? WHERE qualified_id = ?; + +updateOneOnOnConversationId: +UPDATE User SET active_one_on_one_conversation_id = ? WHERE qualified_id = ?; diff --git a/persistence/src/commonMain/db_user/migrations/60.sqm b/persistence/src/commonMain/db_user/migrations/60.sqm new file mode 100644 index 00000000000..0d269fadf96 --- /dev/null +++ b/persistence/src/commonMain/db_user/migrations/60.sqm @@ -0,0 +1,156 @@ +ALTER TABLE User ADD COLUMN supported_protocols TEXT AS Set DEFAULT 'PROTEUS'; +ALTER TABLE User ADD COLUMN active_one_on_one_conversation_id TEXT AS QualifiedIDEntity; + +-- Re-create ConversationDetails view + +DROP VIEW IF EXISTS ConversationDetails; + +CREATE VIEW IF NOT EXISTS ConversationDetails AS +SELECT +Conversation.qualified_id AS qualifiedId, +CASE (Conversation.type) + WHEN 'ONE_ON_ONE' THEN User.name + WHEN 'CONNECTION_PENDING' THEN connection_user.name + ELSE Conversation.name +END AS name, +Conversation.type, +Call.status AS callStatus, +CASE (Conversation.type) + WHEN 'ONE_ON_ONE' THEN User.preview_asset_id + WHEN 'CONNECTION_PENDING' THEN connection_user.preview_asset_id +END AS previewAssetId, +Conversation.muted_status AS mutedStatus, +CASE (Conversation.type) + WHEN 'ONE_ON_ONE' THEN User.team + ELSE Conversation.team_id +END AS teamId, +CASE (Conversation.type) + WHEN 'CONNECTION_PENDING' THEN Connection.last_update_date + ELSE Conversation.last_modified_date +END AS lastModifiedDate, +Conversation.last_read_date AS lastReadDate, +CASE (Conversation.type) + WHEN 'ONE_ON_ONE' THEN User.user_availability_status + WHEN 'CONNECTION_PENDING' THEN connection_user.user_availability_status +END AS userAvailabilityStatus, +CASE (Conversation.type) + WHEN 'ONE_ON_ONE' THEN User.user_type + WHEN 'CONNECTION_PENDING' THEN connection_user.user_type +END AS userType, +CASE (Conversation.type) + WHEN 'ONE_ON_ONE' THEN User.bot_service + WHEN 'CONNECTION_PENDING' THEN connection_user.bot_service +END AS botService, +CASE (Conversation.type) + WHEN 'ONE_ON_ONE' THEN User.deleted + WHEN 'CONNECTION_PENDING' THEN connection_user.deleted +END AS userDeleted, +CASE (Conversation.type) + WHEN 'ONE_ON_ONE' THEN User.defederated + WHEN 'CONNECTION_PENDING' THEN connection_user.defederated +END AS userDefederated, +CASE (Conversation.type) + WHEN 'ONE_ON_ONE' THEN User.supported_protocols + WHEN 'CONNECTION_PENDING' THEN connection_user.supported_protocols +END AS userSupportedProtocols, +CASE (Conversation.type) + WHEN 'ONE_ON_ONE' THEN User.connection_status + WHEN 'CONNECTION_PENDING' THEN connection_user.connection_status +END AS connectionStatus, +CASE (Conversation.type) + WHEN 'ONE_ON_ONE' THEN User.qualified_id + WHEN 'CONNECTION_PENDING' THEN connection_user.qualified_id +END AS otherUserId, +CASE (Conversation.type) + WHEN 'ONE_ON_ONE' THEN User.active_one_on_one_conversation_id + WHEN 'CONNECTION_PENDING' THEN connection_user.active_one_on_one_conversation_id +END AS otherUserActiveConversationId, +CASE + WHEN ((SELECT id FROM SelfUser LIMIT 1) LIKE (Conversation.creator_id || '@%')) THEN 1 + ELSE 0 +END AS isCreator, +CASE (Conversation.type) + WHEN 'ONE_ON_ONE' THEN coalesce(User.active_one_on_one_conversation_id = Conversation.qualified_id, 0) + ELSE 1 +END AS isActive, +Conversation.last_notified_date AS lastNotifiedMessageDate, +memberRole. role AS selfRole, +Conversation.protocol, +Conversation.mls_cipher_suite, +Conversation.mls_epoch, +Conversation.mls_group_id, +Conversation.mls_last_keying_material_update_date, +Conversation.mls_group_state, +Conversation.access_list, +Conversation.access_role_list, +Conversation.team_id, +Conversation.mls_proposal_timer, +Conversation.muted_time, +Conversation.creator_id, +Conversation.last_modified_date, +Conversation.receipt_mode, +Conversation.message_timer, +Conversation.user_message_timer, +Conversation.incomplete_metadata, +Conversation.archived, +Conversation.archived_date_time, +Conversation.verification_status +FROM Conversation +LEFT JOIN Member ON Conversation.qualified_id = Member.conversation + AND Conversation.type IS 'ONE_ON_ONE' + AND Member.user IS NOT (SELECT SelfUser.id FROM SelfUser LIMIT 1) +LEFT JOIN Member AS memberRole ON Conversation.qualified_id = memberRole.conversation + AND memberRole.user IS (SELECT SelfUser.id FROM SelfUser LIMIT 1) +LEFT JOIN User ON User.qualified_id = Member.user +LEFT JOIN Connection ON Connection.qualified_conversation = Conversation.qualified_id + AND (Connection.status = 'SENT' + OR Connection.status = 'PENDING' + OR Connection.status = 'NOT_CONNECTED' + AND Conversation.type IS 'CONNECTION_PENDING') +LEFT JOIN User AS connection_user ON Connection.qualified_to = connection_user.qualified_id +LEFT JOIN Call ON Call.id IS (SELECT id FROM Call WHERE Call.conversation_id = Conversation.qualified_id AND Call.status IS 'STILL_ONGOING' ORDER BY created_at DESC LIMIT 1); + +-- Re-create UserDetails view + +DROP VIEW IF EXISTS UserDetails; + +CREATE VIEW IF NOT EXISTS UserDetails AS +SELECT +User.qualified_id, +User.name, +User.handle, +User.email, +User.phone, +User.accent_id, +User.team, +User.connection_status, +User.preview_asset_id, +User.complete_asset_id, +User.user_availability_status, +User.user_type, +User.bot_service, +User.deleted, +User.incomplete_metadata, +User.expires_at, +User.defederated, +User.supported_protocols, +User.active_one_on_one_conversation_id, +CASE + WHEN SUM(Client.is_verified) = COUNT(*) THEN 1 + ELSE 0 +END AS is_proteus_verified +FROM User +LEFT JOIN Client ON User.qualified_id = Client.user_id +GROUP BY User.qualified_id; + +-- Populate active_one_on_one_conversation_id for users with existing one-on-one conversations + +UPDATE User +SET active_one_on_one_conversation_id = ( + SELECT Member.conversation FROM Member + JOIN Conversation ON Conversation.qualified_id = Member.conversation + WHERE Conversation.type = 'ONE_ON_ONE' AND Conversation.protocol = 'PROTEUS' AND Member.user = User.qualified_id + LIMIT 1 +) +WHERE qualified_id != (SELECT id FROM SelfUser); + diff --git a/persistence/src/commonMain/db_user/migrations/61.sqm b/persistence/src/commonMain/db_user/migrations/61.sqm new file mode 100644 index 00000000000..3f5d4fca87d --- /dev/null +++ b/persistence/src/commonMain/db_user/migrations/61.sqm @@ -0,0 +1,3 @@ +import kotlin.Boolean; + +ALTER TABLE Client ADD COLUMN is_mls_capable INTEGER AS Boolean NOT NULL DEFAULT 0; diff --git a/persistence/src/commonMain/db_user/migrations/62.sqm b/persistence/src/commonMain/db_user/migrations/62.sqm new file mode 100644 index 00000000000..577c3035d09 --- /dev/null +++ b/persistence/src/commonMain/db_user/migrations/62.sqm @@ -0,0 +1,162 @@ +import com.wire.kalium.persistence.dao.QualifiedIDEntity; +import com.wire.kalium.persistence.dao.conversation.ConversationEntity; + +CREATE TABLE MessageConversationProtocolChangedContent ( + message_id TEXT NOT NULL, + conversation_id TEXT AS QualifiedIDEntity NOT NULL, + protocol TEXT AS ConversationEntity.Protocol NOT NULL, + + FOREIGN KEY (message_id, conversation_id) REFERENCES Message(id, conversation_id) ON DELETE CASCADE ON UPDATE CASCADE, + PRIMARY KEY (message_id, conversation_id) +); + +DROP VIEW IF EXISTS MessageDetailsView; + +CREATE VIEW IF NOT EXISTS MessageDetailsView +AS SELECT +Message.id AS id, +Message.conversation_id AS conversationId, +Message.content_type AS contentType, +Message.creation_date AS date, +Message.sender_user_id AS senderUserId, +Message.sender_client_id AS senderClientId, +Message.status AS status, +Message.last_edit_date AS lastEditTimestamp, +Message.visibility AS visibility, +Message.expects_read_confirmation AS expectsReadConfirmation, +Message.expire_after_millis AS expireAfterMillis, +Message.self_deletion_start_date AS selfDeletionStartDate, +IFNULL ((SELECT COUNT (*) FROM Receipt WHERE message_id = Message.id AND type = "READ"), 0) AS readCount, +User.name AS senderName, +User.handle AS senderHandle, +User.email AS senderEmail, +User.phone AS senderPhone, +User.accent_id AS senderAccentId, +User.team AS senderTeamId, +User.connection_status AS senderConnectionStatus, +User.preview_asset_id AS senderPreviewAssetId, +User.complete_asset_id AS senderCompleteAssetId, +User.user_availability_status AS senderAvailabilityStatus, +User.user_type AS senderUserType, +User.bot_service AS senderBotService, +User.deleted AS senderIsDeleted, +(Message.sender_user_id == SelfUser.id) AS isSelfMessage, +TextContent.text_body AS text, +TextContent.is_quoting_self AS isQuotingSelfUser, +AssetContent.asset_size AS assetSize, +AssetContent.asset_name AS assetName, +AssetContent.asset_mime_type AS assetMimeType, +AssetContent.asset_upload_status AS assetUploadStatus, +AssetContent.asset_download_status AS assetDownloadStatus, +AssetContent.asset_otr_key AS assetOtrKey, +AssetContent.asset_sha256 AS assetSha256, +AssetContent.asset_id AS assetId, +AssetContent.asset_token AS assetToken, +AssetContent.asset_domain AS assetDomain, +AssetContent.asset_encryption_algorithm AS assetEncryptionAlgorithm, +AssetContent.asset_width AS assetWidth, +AssetContent.asset_height AS assetHeight, +AssetContent.asset_duration_ms AS assetDuration, +AssetContent.asset_normalized_loudness AS assetNormalizedLoudness, +MissedCallContent.caller_id AS callerId, +MemberChangeContent.member_change_list AS memberChangeList, +MemberChangeContent.member_change_type AS memberChangeType, +UnknownContent.unknown_type_name AS unknownContentTypeName, +UnknownContent.unknown_encoded_data AS unknownContentData, +RestrictedAssetContent.asset_mime_type AS restrictedAssetMimeType, +RestrictedAssetContent.asset_size AS restrictedAssetSize, +RestrictedAssetContent.asset_name AS restrictedAssetName, +FailedToDecryptContent.unknown_encoded_data AS failedToDecryptData, +FailedToDecryptContent.is_decryption_resolved AS isDecryptionResolved, +ConversationNameChangedContent.conversation_name AS conversationName, +'{' || IFNULL( + (SELECT GROUP_CONCAT('"' || emoji || '":' || count) + FROM ( + SELECT COUNT(*) count, Reaction.emoji emoji + FROM Reaction + WHERE Reaction.message_id = Message.id + AND Reaction.conversation_id = Message.conversation_id + GROUP BY Reaction.emoji + )), + '') +|| '}' AS allReactionsJson, +IFNULL( + (SELECT '[' || GROUP_CONCAT('"' || Reaction.emoji || '"') || ']' + FROM Reaction + WHERE Reaction.message_id = Message.id + AND Reaction.conversation_id = Message.conversation_id + AND Reaction.sender_id = SelfUser.id + ), + '[]' +) AS selfReactionsJson, +IFNULL( + (SELECT '[' || GROUP_CONCAT( + '{"start":' || start || ', "length":' || length || + ', "userId":{"value":"' || replace(substr(user_id, 0, instr(user_id, '@')), '@', '') || '"' || + ',"domain":"' || replace(substr(user_id, instr(user_id, '@')+1, length(user_id)), '@', '') || '"' || + '}' || '}') || ']' + FROM MessageMention + WHERE MessageMention.message_id = Message.id + AND MessageMention.conversation_id = Message.conversation_id + ), + '[]' +) AS mentions, +TextContent.quoted_message_id AS quotedMessageId, +QuotedMessage.sender_user_id AS quotedSenderId, +TextContent.is_quote_verified AS isQuoteVerified, +QuotedSender.name AS quotedSenderName, +QuotedMessage.creation_date AS quotedMessageDateTime, +QuotedMessage.last_edit_date AS quotedMessageEditTimestamp, +QuotedMessage.visibility AS quotedMessageVisibility, +QuotedMessage.content_type AS quotedMessageContentType, +QuotedTextContent.text_body AS quotedTextBody, +QuotedAssetContent.asset_mime_type AS quotedAssetMimeType, +QuotedAssetContent.asset_name AS quotedAssetName, + +NewConversationReceiptMode.receipt_mode AS newConversationReceiptMode, + +ConversationReceiptModeChanged.receipt_mode AS conversationReceiptModeChanged, +ConversationTimerChangedContent.message_timer AS messageTimerChanged, +FailedRecipientsWithNoClients.recipient_failure_list AS recipientsFailedWithNoClientsList, +FailedRecipientsDeliveryFailed.recipient_failure_list AS recipientsFailedDeliveryList, + +IFNULL( + (SELECT '[' || + GROUP_CONCAT('{"text":"' || text || '", "id":"' || id || '""is_selected":' || is_selected || '}') + || ']' + FROM ButtonContent + WHERE ButtonContent.message_id = Message.id + AND ButtonContent.conversation_id = Message.conversation_id + ), + '[]' +) AS buttonsJson, +FederationTerminatedContent.domain_list AS federationDomainList, +FederationTerminatedContent.federation_type AS federationType, +ConversationProtocolChangedContent.protocol AS conversationProtocolChanged + +FROM Message +JOIN User ON Message.sender_user_id = User.qualified_id +LEFT JOIN MessageTextContent AS TextContent ON Message.id = TextContent.message_id AND Message.conversation_id = TextContent.conversation_id +LEFT JOIN MessageAssetContent AS AssetContent ON Message.id = AssetContent.message_id AND Message.conversation_id = AssetContent.conversation_id +LEFT JOIN MessageMissedCallContent AS MissedCallContent ON Message.id = MissedCallContent.message_id AND Message.conversation_id = MissedCallContent.conversation_id +LEFT JOIN MessageMemberChangeContent AS MemberChangeContent ON Message.id = MemberChangeContent.message_id AND Message.conversation_id = MemberChangeContent.conversation_id +LEFT JOIN MessageUnknownContent AS UnknownContent ON Message.id = UnknownContent.message_id AND Message.conversation_id = UnknownContent.conversation_id +LEFT JOIN MessageRestrictedAssetContent AS RestrictedAssetContent ON Message.id = RestrictedAssetContent.message_id AND RestrictedAssetContent.conversation_id = RestrictedAssetContent.conversation_id +LEFT JOIN MessageFailedToDecryptContent AS FailedToDecryptContent ON Message.id = FailedToDecryptContent.message_id AND Message.conversation_id = FailedToDecryptContent.conversation_id +LEFT JOIN MessageConversationChangedContent AS ConversationNameChangedContent ON Message.id = ConversationNameChangedContent.message_id AND Message.conversation_id = ConversationNameChangedContent.conversation_id +LEFT JOIN MessageRecipientFailure AS FailedRecipientsWithNoClients ON Message.id = FailedRecipientsWithNoClients.message_id AND Message.conversation_id = FailedRecipientsWithNoClients.conversation_id AND FailedRecipientsWithNoClients.recipient_failure_type = 'NO_CLIENTS_TO_DELIVER' +LEFT JOIN MessageRecipientFailure AS FailedRecipientsDeliveryFailed ON Message.id = FailedRecipientsDeliveryFailed.message_id AND Message.conversation_id = FailedRecipientsDeliveryFailed.conversation_id AND FailedRecipientsDeliveryFailed.recipient_failure_type = 'MESSAGE_DELIVERY_FAILED' + +-- joins for quoted messages +LEFT JOIN Message AS QuotedMessage ON QuotedMessage.id = TextContent.quoted_message_id AND QuotedMessage.conversation_id = TextContent.conversation_id +LEFT JOIN User AS QuotedSender ON QuotedMessage.sender_user_id = QuotedSender.qualified_id +LEFT JOIN MessageTextContent AS QuotedTextContent ON QuotedTextContent.message_id = QuotedMessage.id AND QuotedMessage.conversation_id = TextContent.conversation_id +LEFT JOIN MessageAssetContent AS QuotedAssetContent ON QuotedAssetContent.message_id = QuotedMessage.id AND QuotedMessage.conversation_id = TextContent.conversation_id +-- end joins for quoted messages +LEFT JOIN MessageNewConversationReceiptModeContent AS NewConversationReceiptMode ON Message.id = NewConversationReceiptMode.message_id AND Message.conversation_id = NewConversationReceiptMode.conversation_id +LEFT JOIN MessageConversationReceiptModeChangedContent AS ConversationReceiptModeChanged ON Message.id = ConversationReceiptModeChanged.message_id AND Message.conversation_id = ConversationReceiptModeChanged.conversation_id +LEFT JOIN MessageConversationTimerChangedContent AS ConversationTimerChangedContent ON Message.id = ConversationTimerChangedContent.message_id AND Message.conversation_id = ConversationTimerChangedContent.conversation_id +LEFT JOIN MessageFederationTerminatedContent AS FederationTerminatedContent ON Message.id = FederationTerminatedContent.message_id AND Message.conversation_id = FederationTerminatedContent.conversation_id +LEFT JOIN MessageConversationProtocolChangedContent AS ConversationProtocolChangedContent ON Message.id = ConversationProtocolChangedContent.message_id AND Message.conversation_id = ConversationProtocolChangedContent.conversation_id +LEFT JOIN SelfUser; +-- TODO: Remove IFNULL functions above if we can force SQLDelight to not unpack as notnull diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/adapter/SupportedProtocolSetAdapter.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/adapter/SupportedProtocolSetAdapter.kt new file mode 100644 index 00000000000..afeaf818de2 --- /dev/null +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/adapter/SupportedProtocolSetAdapter.kt @@ -0,0 +1,33 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.persistence.adapter + +import app.cash.sqldelight.ColumnAdapter +import com.wire.kalium.persistence.dao.SupportedProtocolEntity + +internal object SupportedProtocolSetAdapter : ColumnAdapter, String> { + override fun decode(databaseValue: String): Set { + return databaseValue.split(SEPARATOR).map { SupportedProtocolEntity.valueOf(it) }.toSet() + } + + override fun encode(value: Set): String { + return value.joinToString(SEPARATOR) { it.name } + } + + private const val SEPARATOR = "," +} diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/config/UserConfigStorage.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/config/UserConfigStorage.kt index cdc691e89c3..9d80f1d1b1e 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/config/UserConfigStorage.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/config/UserConfigStorage.kt @@ -18,6 +18,7 @@ package com.wire.kalium.persistence.config +import com.wire.kalium.persistence.dao.SupportedProtocolEntity import com.wire.kalium.persistence.kmmSettings.KaliumPreferences import com.wire.kalium.util.time.Second import kotlinx.coroutines.channels.BufferOverflow @@ -26,6 +27,7 @@ import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.distinctUntilChanged import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.onStart +import kotlinx.datetime.Instant import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import kotlin.time.Duration @@ -95,6 +97,16 @@ interface UserConfigStorage { */ fun isSecondFactorPasswordChallengeRequired(): Boolean + /** + * Save default protocol to use + */ + fun persistDefaultProtocol(protocol: SupportedProtocolEntity) + + /** + * Gets default protocol to use. Defaults to PROTEUS if not default protocol has been saved. + */ + fun defaultProtocol(): SupportedProtocolEntity + /** * Save flag from the user settings to enable and disable MLS */ @@ -213,6 +225,13 @@ sealed class SelfDeletionTimerEntity { data class Enforced(val enforcedDuration: Duration) : SelfDeletionTimerEntity() } +@Serializable +data class MLSMigrationEntity( + @Serializable val status: Boolean, + @Serializable val startTime: Instant?, + @Serializable val endTime: Instant?, +) + @Suppress("TooManyFunctions") class UserConfigStorageImpl( private val kaliumPreferences: KaliumPreferences @@ -332,6 +351,14 @@ class UserConfigStorageImpl( override fun isSecondFactorPasswordChallengeRequired(): Boolean = kaliumPreferences.getBoolean(REQUIRE_SECOND_FACTOR_PASSWORD_CHALLENGE, false) + override fun persistDefaultProtocol(protocol: SupportedProtocolEntity) { + kaliumPreferences.putString(DEFAULT_PROTOCOL, protocol.name) + } + + override fun defaultProtocol(): SupportedProtocolEntity = + kaliumPreferences.getString(DEFAULT_PROTOCOL)?.let { SupportedProtocolEntity.valueOf(it) } + ?: SupportedProtocolEntity.PROTEUS + override fun enableMLS(enabled: Boolean) { kaliumPreferences.putBoolean(ENABLE_MLS, enabled) } @@ -451,5 +478,6 @@ class UserConfigStorageImpl( const val ENABLE_SCREENSHOT_CENSORING = "enable_screenshot_censoring" const val ENABLE_TYPING_INDICATOR = "enable_typing_indicator" const val APP_LOCK = "app_lock" + const val DEFAULT_PROTOCOL = "default_protocol" } } diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/ConnectionDAOImpl.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/ConnectionDAOImpl.kt index b3dc2963c3e..fb6867f8ef3 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/ConnectionDAOImpl.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/ConnectionDAOImpl.kt @@ -70,7 +70,9 @@ private class ConnectionMapper { deleted: Boolean?, incomplete_metadata: Boolean?, expires_at: Instant?, - defederated: Boolean? + defederated: Boolean?, + supportedProtocols: Set?, + oneToOneConversationId: QualifiedIDEntity? ): ConnectionEntity = ConnectionEntity( conversationId = conversation_id, from = from_id, @@ -98,7 +100,9 @@ private class ConnectionMapper { hasIncompleteMetadata = incomplete_metadata.requireField("incomplete_metadata"), expiresAt = expires_at, defederated = defederated.requireField("defederated"), - isProteusVerified = false + isProteusVerified = false, + supportedProtocols = supportedProtocols, + activeOneOnOneConversationId = oneToOneConversationId ) else null ) diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/UserDAO.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/UserDAO.kt index c5c6542df3c..69007ef7f20 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/UserDAO.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/UserDAO.kt @@ -48,6 +48,12 @@ data class QualifiedIDEntity( typealias UserIDEntity = QualifiedIDEntity typealias ConversationIDEntity = QualifiedIDEntity +@Serializable +enum class SupportedProtocolEntity { + @SerialName("PROTEUS") PROTEUS, + @SerialName("MLS") MLS +} + enum class UserAvailabilityStatusEntity { NONE, AVAILABLE, BUSY, AWAY } @@ -71,7 +77,9 @@ data class UserEntity( val deleted: Boolean, val hasIncompleteMetadata: Boolean = false, val expiresAt: Instant?, - val defederated: Boolean + val defederated: Boolean, + val supportedProtocols: Set?, + val activeOneOnOneConversationId: QualifiedIDEntity? ) data class UserDetailsEntity( @@ -94,7 +102,9 @@ data class UserDetailsEntity( val hasIncompleteMetadata: Boolean = false, val expiresAt: Instant?, val defederated: Boolean, - val isProteusVerified: Boolean + val isProteusVerified: Boolean, + val supportedProtocols: Set?, + val activeOneOnOneConversationId: QualifiedIDEntity? ) { fun toSimpleEntity() = UserEntity( id = id, @@ -113,7 +123,9 @@ data class UserDetailsEntity( deleted = deleted, hasIncompleteMetadata = hasIncompleteMetadata, expiresAt = expiresAt, - defederated = defederated + defederated = defederated, + supportedProtocols = supportedProtocols, + activeOneOnOneConversationId = activeOneOnOneConversationId ) } @@ -129,6 +141,16 @@ data class BotIdEntity( val provider: String ) +data class PartialUserEntity( + val name: String?, + val handle: String?, + val email: String?, + val accentId: Int?, + val previewAssetId: UserAssetIdEntity?, + val completeAssetId: UserAssetIdEntity?, + val supportedProtocols: Set? +) + enum class UserTypeEntity { /**Team member with owner permissions */ @@ -180,50 +202,49 @@ internal typealias UserAssetIdEntity = QualifiedIDEntity @Suppress("TooManyFunctions") interface UserDAO { - /** - * Inserts a new user into the local storage - */ - suspend fun insertUser(user: UserEntity) - /** * Inserts each user into the local storage or ignores if already exists */ suspend fun insertOrIgnoreUsers(users: List) /** - * This will update all columns, except [ConnectionEntity.State] or insert a new record with default value - * [ConnectionEntity.State.NOT_CONNECTED] - * An upsert operation is a one that tries to update a record and if fails (not rows affected by change) inserts instead. - * In this case as the transaction can be executed many times, we need to take care for not deleting old data. + * Perform a partial update of an existing user. Only non-null values will be updated otherwise + * the existing value is kept. + * + * @return true if the user was updated */ - suspend fun upsertUsers(users: List) + suspend fun updateUser(id: UserIDEntity, update: PartialUserEntity): Boolean /** - * This will update [UserEntity.team], [UserEntity.userType], [UserEntity.connectionStatus] to [ConnectionEntity.State.ACCEPTED] - * or insert a new record with default values for other columns. + * This will update all columns (or insert a new record), except: + * - [ConnectionEntity.State] + * - [UserEntity.userType] + * - [UserEntity.activeOneOnOneConversationId] + * * An upsert operation is a one that tries to update a record and if fails (not rows affected by change) inserts instead. - * In this case when trying to insert a member, we could already have the record, so we need to pass only the data needed. + * In this case as the transaction can be executed many times, we need to take care for not deleting old data. */ - suspend fun upsertTeamMembersTypes(users: List) + suspend fun upsertUser(user: UserEntity) /** - * This will update all columns, except [UserEntity.userType] or insert a new record with default values + * This will update all columns (or insert a new record), except: + * - [ConnectionEntity.State] + * - [UserEntity.userType] + * - [UserEntity.activeOneOnOneConversationId] + * * An upsert operation is a one that tries to update a record and if fails (not rows affected by change) inserts instead. * In this case as the transaction can be executed many times, we need to take care for not deleting old data. */ - suspend fun upsertTeamMembers(users: List) + suspend fun upsertUsers(users: List) /** - * This will update a user record corresponding to the User, - * The Fields to update are: - * [UserEntity.name] - * [UserEntity.handle] - * [UserEntity.email] - * [UserEntity.accentId] - * [UserEntity.previewAssetId] - * [UserEntity.completeAssetId] + * This will update [UserEntity.team], [UserEntity.userType], [UserEntity.connectionStatus] to [ConnectionEntity.State.ACCEPTED] + * or insert a new record. + * + * An upsert operation is a one that tries to update a record and if fails (not rows affected by change) inserts instead. + * In this case when trying to insert a member, we could already have the record, so we need to pass only the data needed. */ - suspend fun updateUser(user: UserEntity) + suspend fun upsertTeamMemberUserTypes(users: Map) suspend fun getAllUsersDetails(): Flow> suspend fun observeAllUsersDetailsByConnectionStatus(connectionState: ConnectionEntity.State): Flow> suspend fun observeUserDetailsByQualifiedID(qualifiedID: QualifiedIDEntity): Flow @@ -240,6 +261,8 @@ interface UserDAO { connectionStates: List ): Flow> + suspend fun getUsersWithOneOnOneConversation(): List + suspend fun deleteUserByQualifiedID(qualifiedID: QualifiedIDEntity) suspend fun markUserAsDeleted(qualifiedID: QualifiedIDEntity) suspend fun markUserAsDefederated(qualifiedID: QualifiedIDEntity) @@ -265,4 +288,14 @@ interface UserDAO { * the list does not contain self user ID */ suspend fun allOtherUsersId(): List + + suspend fun updateUserSupportedProtocols(selfUserId: QualifiedIDEntity, supportedProtocols: Set) + + /** + * Update which 1-1 conversation is the currently active one. If multiple encryption protocols are enabled + * there can be multiple co-existing 1-1 conversations. + */ + suspend fun updateActiveOneOnOneConversation(userId: QualifiedIDEntity, conversationId: QualifiedIDEntity) + + suspend fun upsertConnectionStatus(userId: QualifiedIDEntity, status: ConnectionEntity.State) } diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/UserDAOImpl.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/UserDAOImpl.kt index d9471b8721f..b1bbadc6421 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/UserDAOImpl.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/UserDAOImpl.kt @@ -55,7 +55,9 @@ class UserMapper { hasIncompleteMetadata = user.incomplete_metadata, expiresAt = user.expires_at, defederated = user.defederated, - isProteusVerified = user.is_proteus_verified == 1L + supportedProtocols = user.supported_protocols, + isProteusVerified = user.is_proteus_verified == 1L, + activeOneOnOneConversationId = user.active_one_on_one_conversation_id ) } @@ -77,7 +79,9 @@ class UserMapper { deleted = user.deleted, hasIncompleteMetadata = user.incomplete_metadata, expiresAt = user.expires_at, - defederated = user.defederated + defederated = user.defederated, + supportedProtocols = user.supported_protocols, + activeOneOnOneConversationId = user.active_one_on_one_conversation_id ) } @@ -100,6 +104,8 @@ class UserMapper { hasIncompleteMetadata: Boolean, expiresAt: Instant?, defederated: Boolean, + supportedProtocols: Set?, + oneOnOneConversationId: QualifiedIDEntity?, isVerifiedProteus: Long, id: String?, teamName: String?, @@ -123,7 +129,9 @@ class UserMapper { hasIncompleteMetadata = hasIncompleteMetadata, expiresAt = expiresAt, defederated = defederated, - isProteusVerified = isVerifiedProteus == 1L + isProteusVerified = isVerifiedProteus == 1L, + supportedProtocols = supportedProtocols, + activeOneOnOneConversationId = oneOnOneConversationId ) val teamEntity = if (team != null && teamName != null && teamIcon != null) { @@ -155,25 +163,8 @@ class UserDAOImpl internal constructor( ) : UserDAO { val mapper = UserMapper() - - override suspend fun insertUser(user: UserEntity) = withContext(queriesContext) { - userQueries.insertUser( - qualified_id = user.id, - name = user.name, - handle = user.handle, - email = user.email, - phone = user.phone, - accent_id = user.accentId, - team = user.team, - preview_asset_id = user.previewAssetId, - complete_asset_id = user.completeAssetId, - user_type = user.userType, - bot_service = user.botService, - incomplete_metadata = user.hasIncompleteMetadata, - expires_at = user.expiresAt, - connection_status = user.connectionStatus, - deleted = user.deleted - ) + override suspend fun upsertUser(user: UserEntity) { + upsertUsers(listOf(user)) } override suspend fun insertOrIgnoreUsers(users: List) = withContext(queriesContext) { @@ -194,55 +185,30 @@ class UserDAOImpl internal constructor( incomplete_metadata = false, expires_at = user.expiresAt, connection_status = user.connectionStatus, - deleted = user.deleted + deleted = user.deleted, + supported_protocols = user.supportedProtocols ) } } } - override suspend fun upsertTeamMembers(users: List) = withContext(queriesContext) { - userQueries.transaction { - for (user: UserEntity in users) { - userQueries.updateTeamMemberUser( - qualified_id = user.id, - name = user.name, - handle = user.handle, - email = user.email, - phone = user.phone, - accent_id = user.accentId, - team = user.team, - preview_asset_id = user.previewAssetId, - complete_asset_id = user.completeAssetId, - bot_service = user.botService, - ) - val recordDidNotExist = userQueries.selectChanges().executeAsOne() == 0L - if (recordDidNotExist) { - userQueries.insertUser( - qualified_id = user.id, - name = user.name, - handle = user.handle, - email = user.email, - phone = user.phone, - accent_id = user.accentId, - team = user.team, - preview_asset_id = user.previewAssetId, - complete_asset_id = user.completeAssetId, - user_type = user.userType, - bot_service = user.botService, - incomplete_metadata = user.hasIncompleteMetadata, - expires_at = user.expiresAt, - connection_status = user.connectionStatus, - deleted = user.deleted - ) - } - } - } + override suspend fun updateUser(id: UserIDEntity, update: PartialUserEntity) = withContext(queriesContext) { + userQueries.updateUser( + name = update.name, + handle = update.handle, + email = update.email, + accent_id = update.accentId?.toLong(), + preview_asset_id = update.previewAssetId, + complete_asset_id = update.completeAssetId, + supported_protocols = update.supportedProtocols, + id + ).executeAsOne() > 0 } override suspend fun upsertUsers(users: List) = withContext(queriesContext) { userQueries.transaction { for (user: UserEntity in users) { - userQueries.updateUser( + userQueries.insertUser( qualified_id = user.id, name = user.name, handle = user.handle, @@ -254,73 +220,25 @@ class UserDAOImpl internal constructor( complete_asset_id = user.completeAssetId, user_type = user.userType, bot_service = user.botService, - incomplete_metadata = false, - expires_at = user.expiresAt + incomplete_metadata = user.hasIncompleteMetadata, + expires_at = user.expiresAt, + connection_status = user.connectionStatus, + deleted = user.deleted, + supported_protocols = user.supportedProtocols, + active_one_on_one_conversation_id = user.activeOneOnOneConversationId ) - val recordDidNotExist = userQueries.selectChanges().executeAsOne() == 0L - if (recordDidNotExist) { - userQueries.insertUser( - qualified_id = user.id, - name = user.name, - handle = user.handle, - email = user.email, - phone = user.phone, - accent_id = user.accentId, - team = user.team, - connection_status = user.connectionStatus, - preview_asset_id = user.previewAssetId, - complete_asset_id = user.completeAssetId, - user_type = user.userType, - bot_service = user.botService, - deleted = user.deleted, - incomplete_metadata = user.hasIncompleteMetadata, - expires_at = user.expiresAt - ) - } } } } - override suspend fun upsertTeamMembersTypes(users: List) { + override suspend fun upsertTeamMemberUserTypes(users: Map) { userQueries.transaction { - for (user: UserEntity in users) { - userQueries.updateTeamMemberType(user.team, user.connectionStatus, user.userType, user.id) - val recordDidNotExist = userQueries.selectChanges().executeAsOne() == 0L - if (recordDidNotExist) { - userQueries.insertUser( - qualified_id = user.id, - name = user.name, - handle = user.handle, - email = user.email, - phone = user.phone, - accent_id = user.accentId, - team = user.team, - connection_status = user.connectionStatus, - preview_asset_id = user.previewAssetId, - complete_asset_id = user.completeAssetId, - user_type = user.userType, - bot_service = user.botService, - deleted = user.deleted, - incomplete_metadata = user.hasIncompleteMetadata, - expires_at = user.expiresAt - ) - } + for (user: Map.Entry in users) { + userQueries.upsertTeamMemberUserType(user.key, ConnectionEntity.State.ACCEPTED, user.value) } } } - override suspend fun updateUser(user: UserEntity) = withContext(queriesContext) { - userQueries.updateSelfUser( - qualified_id = user.id, - name = user.name, - handle = user.handle, - email = user.email, - accent_id = user.accentId, - preview_asset_id = user.previewAssetId, - complete_asset_id = user.completeAssetId, - ) - } - override suspend fun getAllUsersDetails(): Flow> = userQueries.selectAllUsers() .asFlow() .flowOn(queriesContext) @@ -373,6 +291,10 @@ class UserDAOImpl internal constructor( .mapToList() .map { it.map(mapper::toDetailsModel) } + override suspend fun getUsersWithOneOnOneConversation(): List = withContext(queriesContext) { + userQueries.selectUsersWithOneOnOne().executeAsList().map(mapper::toModel) + } + override suspend fun deleteUserByQualifiedID(qualifiedID: QualifiedIDEntity) = withContext(queriesContext) { userQueries.deleteUser(qualifiedID) } @@ -459,4 +381,19 @@ class UserDAOImpl internal constructor( userQueries.userIdsWithoutSelf().executeAsList() } + override suspend fun updateUserSupportedProtocols(selfUserId: QualifiedIDEntity, supportedProtocols: Set) = + withContext(queriesContext) { + userQueries.updateUserSupportedProtocols(supportedProtocols, selfUserId) + } + + override suspend fun updateActiveOneOnOneConversation(userId: QualifiedIDEntity, conversationId: QualifiedIDEntity) = + withContext(queriesContext) { + userQueries.updateOneOnOnConversationId(conversationId, userId) + } + + override suspend fun upsertConnectionStatus(userId: QualifiedIDEntity, status: ConnectionEntity.State) { + withContext(queriesContext) { + userQueries.upsertUserConnectionStatus(userId, status) + } + } } diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/client/ClientDAO.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/client/ClientDAO.kt index 697f5fda77d..22f512ab492 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/client/ClientDAO.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/client/ClientDAO.kt @@ -34,7 +34,8 @@ data class Client( val lastActive: Instant?, val label: String?, val model: String?, - val mlsPublicKeys: Map? + val mlsPublicKeys: Map?, + val isMLSCapable: Boolean ) data class InsertClientParam( @@ -46,7 +47,8 @@ data class InsertClientParam( val registrationDate: Instant?, val lastActive: Instant?, val model: String?, - val mlsPublicKeys: Map? + val mlsPublicKeys: Map?, + val isMLSCapable: Boolean ) enum class DeviceTypeEntity { diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/client/ClientDAOImpl.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/client/ClientDAOImpl.kt index 7b58087e7dc..74d9777d3e1 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/client/ClientDAOImpl.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/client/ClientDAOImpl.kt @@ -43,7 +43,8 @@ internal object ClientMapper { label: String?, model: String?, lastActive: Instant?, - mls_public_keys: Map? + mls_public_keys: Map?, + is_mls_capable: Boolean ): Client = Client( userId = user_id, id = id, @@ -51,11 +52,12 @@ internal object ClientMapper { clientType = client_type, isValid = is_valid, isProteusVerified = is_verified, + isMLSCapable = is_mls_capable, registrationDate = registration_date, label = label, model = model, lastActive = lastActive, - mlsPublicKeys = mls_public_keys + mlsPublicKeys = mls_public_keys, ) } @@ -82,6 +84,7 @@ internal class ClientDAOImpl internal constructor( device_type = deviceType, client_type = clientType, is_valid = true, + is_mls_capable = isMLSCapable, registration_date = registrationDate, last_active = lastActive, model = model, diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt index 77ae54ccd79..d789edc6712 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt @@ -41,11 +41,21 @@ interface ConversationDAO { suspend fun updateAllConversationsNotificationDate() suspend fun getAllConversations(): Flow> suspend fun getAllConversationDetails(fromArchive: Boolean): Flow> + suspend fun getConversationIds( + type: ConversationEntity.Type, + protocol: ConversationEntity.Protocol, + teamId: String? = null + ): List + suspend fun getTeamConversationIdsReadyToCompleteMigration(teamId: String): List suspend fun observeGetConversationByQualifiedID(qualifiedID: QualifiedIDEntity): Flow suspend fun observeGetConversationBaseInfoByQualifiedID(qualifiedID: QualifiedIDEntity): Flow suspend fun getConversationBaseInfoByQualifiedID(qualifiedID: QualifiedIDEntity): ConversationEntity? suspend fun getConversationByQualifiedID(qualifiedID: QualifiedIDEntity): ConversationViewEntity? - suspend fun observeConversationWithOtherUser(userId: UserIDEntity): Flow + suspend fun getOneOnOneConversationIdsWithOtherUser( + userId: UserIDEntity, + protocol: ConversationEntity.Protocol + ): List + suspend fun observeOneOnOneConversationWithOtherUser(userId: UserIDEntity): Flow suspend fun getConversationProtocolInfo(qualifiedID: QualifiedIDEntity): ConversationEntity.ProtocolInfo? suspend fun observeConversationByGroupID(groupID: String): Flow suspend fun getConversationIdByGroupID(groupID: String): QualifiedIDEntity? @@ -78,8 +88,9 @@ interface ConversationDAO { suspend fun whoDeletedMeInConversation(conversationId: QualifiedIDEntity, selfUserIdString: String): UserIDEntity? suspend fun updateConversationName(conversationId: QualifiedIDEntity, conversationName: String, timestamp: String) suspend fun updateConversationType(conversationID: QualifiedIDEntity, type: ConversationEntity.Type) + suspend fun updateConversationProtocol(conversationId: QualifiedIDEntity, protocol: ConversationEntity.Protocol): Boolean suspend fun revokeOneOnOneConversationsWithDeletedUser(userId: UserIDEntity) - suspend fun getConversationIdsByUserId(userId: UserIDEntity): List + suspend fun getConversationsByUserId(userId: UserIDEntity): List suspend fun updateConversationReceiptMode(conversationID: QualifiedIDEntity, receiptMode: ConversationEntity.ReceiptMode) suspend fun updateGuestRoomLink( conversationId: QualifiedIDEntity, diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt index 1325670ed2d..cd5858e79f7 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt @@ -77,14 +77,17 @@ internal class ConversationDAOImpl internal constructor( name, type, teamId, - if (protocolInfo is ConversationEntity.ProtocolInfo.MLS) protocolInfo.groupId + if (protocolInfo is ConversationEntity.ProtocolInfo.MLSCapable) protocolInfo.groupId else null, - if (protocolInfo is ConversationEntity.ProtocolInfo.MLS) protocolInfo.groupState + if (protocolInfo is ConversationEntity.ProtocolInfo.MLSCapable) protocolInfo.groupState else ConversationEntity.GroupState.ESTABLISHED, - if (protocolInfo is ConversationEntity.ProtocolInfo.MLS) protocolInfo.epoch.toLong() + if (protocolInfo is ConversationEntity.ProtocolInfo.MLSCapable) protocolInfo.epoch.toLong() else MLS_DEFAULT_EPOCH, - if (protocolInfo is ConversationEntity.ProtocolInfo.MLS) ConversationEntity.Protocol.MLS - else ConversationEntity.Protocol.PROTEUS, + when (protocolInfo) { + is ConversationEntity.ProtocolInfo.MLS -> ConversationEntity.Protocol.MLS + is ConversationEntity.ProtocolInfo.Mixed -> ConversationEntity.Protocol.MIXED + is ConversationEntity.ProtocolInfo.Proteus -> ConversationEntity.Protocol.PROTEUS + }, mutedStatus, mutedTime, creatorId, @@ -93,9 +96,9 @@ internal class ConversationDAOImpl internal constructor( access, accessRole, lastReadDate, - if (protocolInfo is ConversationEntity.ProtocolInfo.MLS) protocolInfo.keyingMaterialLastUpdate + if (protocolInfo is ConversationEntity.ProtocolInfo.MLSCapable) protocolInfo.keyingMaterialLastUpdate else Instant.fromEpochMilliseconds(MLS_DEFAULT_LAST_KEY_MATERIAL_UPDATE_MILLI), - if (protocolInfo is ConversationEntity.ProtocolInfo.MLS) protocolInfo.cipherSuite + if (protocolInfo is ConversationEntity.ProtocolInfo.MLSCapable) protocolInfo.cipherSuite else MLS_DEFAULT_CIPHER_SUITE, receiptMode, messageTimer, @@ -149,6 +152,24 @@ internal class ConversationDAOImpl internal constructor( .map { list -> list.map { it.let { conversationMapper.toModel(it) } } } } + override suspend fun getConversationIds( + type: ConversationEntity.Type, + protocol: ConversationEntity.Protocol, + teamId: String? + ): List { + return withContext(coroutineContext) { + conversationQueries.selectConversationIds(protocol, type, teamId).executeAsList() + } + } + + override suspend fun getTeamConversationIdsReadyToCompleteMigration(teamId: String): List { + return withContext(coroutineContext) { + conversationQueries.selectAllTeamProteusConversationsReadyForMigration(teamId) + .executeAsList() + .map { it.qualified_id } + } + } + override suspend fun observeGetConversationByQualifiedID(qualifiedID: QualifiedIDEntity): Flow { return conversationQueries.selectByQualifiedId(qualifiedID) .asFlow() @@ -177,12 +198,20 @@ internal class ConversationDAOImpl internal constructor( } } - override suspend fun observeConversationWithOtherUser(userId: UserIDEntity): Flow { - return memberQueries.selectConversationByMember(userId) + override suspend fun getOneOnOneConversationIdsWithOtherUser( + userId: UserIDEntity, + protocol: ConversationEntity.Protocol + ): List = + withContext(coroutineContext) { + conversationQueries.selectOneOnOneConversationIdsByProtocol(protocol, userId).executeAsList() + } + + override suspend fun observeOneOnOneConversationWithOtherUser(userId: UserIDEntity): Flow { + return conversationQueries.selectActiveOneOnOneConversation(userId) .asFlow() .mapToOneOrNull() .flowOn(coroutineContext) - .map { it?.let { conversationMapper.fromOneToOneToModel(it) } } + .map { it?.let { conversationMapper.toModel(it) } } } override suspend fun getConversationProtocolInfo(qualifiedID: QualifiedIDEntity): ConversationEntity.ProtocolInfo? = @@ -210,7 +239,7 @@ internal class ConversationDAOImpl internal constructor( override suspend fun getConversationsByGroupState(groupState: ConversationEntity.GroupState): List = withContext(coroutineContext) { - conversationQueries.selectByGroupState(groupState, ConversationEntity.Protocol.MLS) + conversationQueries.selectByGroupState(groupState) .executeAsList() .map(conversationMapper::toModel) } @@ -259,7 +288,6 @@ internal class ConversationDAOImpl internal constructor( override suspend fun getConversationsByKeyingMaterialUpdate(threshold: Duration): List = withContext(coroutineContext) { conversationQueries.selectByKeyingMaterialUpdate( ConversationEntity.GroupState.ESTABLISHED, - ConversationEntity.Protocol.MLS, DateTimeUtil.currentInstant().minus(threshold) ).executeAsList() } @@ -273,7 +301,7 @@ internal class ConversationDAOImpl internal constructor( } override suspend fun getProposalTimers(): Flow> { - return conversationQueries.selectProposalTimers(ConversationEntity.Protocol.MLS) + return conversationQueries.selectProposalTimers() .asFlow() .flowOn(coroutineContext) .mapToList() @@ -295,12 +323,18 @@ internal class ConversationDAOImpl internal constructor( conversationQueries.updateConversationType(type, conversationID) } + override suspend fun updateConversationProtocol(conversationId: QualifiedIDEntity, protocol: ConversationEntity.Protocol): Boolean { + return withContext(coroutineContext) { + conversationQueries.updateConversationProtocol(protocol, conversationId).executeAsOne() > 0 + } + } + override suspend fun revokeOneOnOneConversationsWithDeletedUser(userId: UserIDEntity) = withContext(coroutineContext) { memberQueries.deleteUserFromGroupConversations(userId, userId) } - override suspend fun getConversationIdsByUserId(userId: UserIDEntity): List = withContext(coroutineContext) { - memberQueries.selectConversationsByMember(userId).executeAsList().map { it.conversation } + override suspend fun getConversationsByUserId(userId: UserIDEntity): List = withContext(coroutineContext) { + memberQueries.selectConversationsByMember(userId, conversationMapper::toModel).executeAsList() } override suspend fun updateConversationReceiptMode(conversationID: QualifiedIDEntity, receiptMode: ConversationEntity.ReceiptMode) = diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationEntity.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationEntity.kt index 2fa67582168..cf445bc12ea 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationEntity.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationEntity.kt @@ -54,7 +54,7 @@ data class ConversationEntity( enum class GroupState { PENDING_CREATION, PENDING_JOIN, PENDING_WELCOME_MESSAGE, ESTABLISHED } - enum class Protocol { PROTEUS, MLS } + enum class Protocol { PROTEUS, MLS, MIXED } enum class ReceiptMode { DISABLED, ENABLED } enum class VerificationStatus { VERIFIED, NOT_VERIFIED, DEGRADED } @@ -67,7 +67,8 @@ data class ConversationEntity( MLS_256_DHKEMX448_AES256GCM_SHA512_Ed448(4), MLS_256_DHKEMP521_AES256GCM_SHA512_P521(5), MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448(6), - MLS_256_DHKEMP384_AES256GCM_SHA384_P384(7); + MLS_256_DHKEMP384_AES256GCM_SHA384_P384(7), + MLS_128_X25519KYBER768DRAFT00_AES128GCM_SHA256_ED25519(61489); companion object { fun fromTag(tag: Int?): CipherSuite = @@ -77,14 +78,29 @@ data class ConversationEntity( enum class MutedStatus { ALL_ALLOWED, ONLY_MENTIONS_AND_REPLIES_ALLOWED, MENTIONS_MUTED, ALL_MUTED } - sealed class ProtocolInfo { - object Proteus : ProtocolInfo() + sealed interface ProtocolInfo { + object Proteus : ProtocolInfo data class MLS( - val groupId: String, - val groupState: GroupState, - val epoch: ULong, - val keyingMaterialLastUpdate: Instant, - val cipherSuite: CipherSuite - ) : ProtocolInfo() + override val groupId: String, + override val groupState: ConversationEntity.GroupState, + override val epoch: ULong, + override val keyingMaterialLastUpdate: Instant, + override val cipherSuite: ConversationEntity.CipherSuite + ) : MLSCapable + data class Mixed( + override val groupId: String, + override val groupState: ConversationEntity.GroupState, + override val epoch: ULong, + override val keyingMaterialLastUpdate: Instant, + override val cipherSuite: ConversationEntity.CipherSuite + ) : MLSCapable + + sealed interface MLSCapable : ProtocolInfo { + val groupId: String + val groupState: ConversationEntity.GroupState + val epoch: ULong + val keyingMaterialLastUpdate: Instant + val cipherSuite: ConversationEntity.CipherSuite + } } } diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationMapper.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationMapper.kt index 66a277f8eac..5db58eec605 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationMapper.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationMapper.kt @@ -17,10 +17,10 @@ */ package com.wire.kalium.persistence.dao.conversation -import com.wire.kalium.persistence.SelectConversationByMember import com.wire.kalium.persistence.dao.QualifiedIDEntity import kotlinx.datetime.Instant import com.wire.kalium.persistence.ConversationDetails as SQLDelightConversationView + internal class ConversationMapper { fun toModel(conversation: SQLDelightConversationView): ConversationViewEntity = with(conversation) { ConversationViewEntity( @@ -67,7 +67,9 @@ internal class ConversationMapper { userDefederated = userDefederated, archived = archived, archivedDateTime = archived_date_time, - verificationStatus = verification_status + verificationStatus = verification_status, + userSupportedProtocols = userSupportedProtocols, + userActiveOneOnOneConversationId = otherUserActiveConversationId ) } @@ -127,57 +129,6 @@ internal class ConversationMapper { verificationStatus = verificationStatus ) - fun fromOneToOneToModel(conversation: SelectConversationByMember?): ConversationViewEntity? { - return conversation?.run { - ConversationViewEntity( - id = qualifiedId, - name = name, - type = type, - teamId = teamId, - protocolInfo = mapProtocolInfo( - protocol, - mls_group_id, - mls_group_state, - mls_epoch, - mls_last_keying_material_update_date, - mls_cipher_suite - ), - isCreator = isCreator, - mutedStatus = mutedStatus, - mutedTime = muted_time, - creatorId = creator_id, - lastNotificationDate = lastNotifiedMessageDate, - lastModifiedDate = last_modified_date, - lastReadDate = lastReadDate, - accessList = access_list, - accessRoleList = access_role_list, - protocol = protocol, - mlsCipherSuite = mls_cipher_suite, - mlsEpoch = mls_epoch, - mlsGroupId = mls_group_id, - mlsLastKeyingMaterialUpdateDate = mls_last_keying_material_update_date, - mlsGroupState = mls_group_state, - mlsProposalTimer = mls_proposal_timer, - callStatus = callStatus, - previewAssetId = previewAssetId, - userAvailabilityStatus = userAvailabilityStatus, - userType = userType, - botService = botService, - userDeleted = userDeleted, - connectionStatus = connectionStatus, - otherUserId = otherUserId, - selfRole = selfRole, - receiptMode = receipt_mode, - messageTimer = message_timer, - userMessageTimer = user_message_timer, - userDefederated = userDefederated, - archived = archived, - archivedDateTime = archived_date_time, - verificationStatus = verification_status - ) - } - } - @Suppress("LongParameterList") fun mapProtocolInfo( protocol: ConversationEntity.Protocol, @@ -196,6 +147,14 @@ internal class ConversationMapper { mlsCipherSuite ) + ConversationEntity.Protocol.MIXED -> ConversationEntity.ProtocolInfo.Mixed( + mlsGroupId ?: "", + mlsGroupState, + mlsEpoch.toULong(), + mlsLastKeyingMaterialUpdate, + mlsCipherSuite + ) + ConversationEntity.Protocol.PROTEUS -> ConversationEntity.ProtocolInfo.Proteus } } diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationViewEntity.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationViewEntity.kt index 15afb228259..c6632975d0f 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationViewEntity.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationViewEntity.kt @@ -19,7 +19,9 @@ package com.wire.kalium.persistence.dao.conversation import com.wire.kalium.persistence.dao.BotIdEntity import com.wire.kalium.persistence.dao.ConnectionEntity +import com.wire.kalium.persistence.dao.ConversationIDEntity import com.wire.kalium.persistence.dao.QualifiedIDEntity +import com.wire.kalium.persistence.dao.SupportedProtocolEntity import com.wire.kalium.persistence.dao.UserAvailabilityStatusEntity import com.wire.kalium.persistence.dao.UserIDEntity import com.wire.kalium.persistence.dao.UserTypeEntity @@ -66,7 +68,9 @@ data class ConversationViewEntity( val userMessageTimer: Long?, val archived: Boolean, val archivedDateTime: Instant?, - val verificationStatus: ConversationEntity.VerificationStatus + val verificationStatus: ConversationEntity.VerificationStatus, + val userSupportedProtocols: Set?, + val userActiveOneOnOneConversationId: ConversationIDEntity? ) { val isMember: Boolean get() = selfRole != null diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/member/MemberDAO.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/member/MemberDAO.kt index a6a305fddaf..debd2018ab4 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/member/MemberDAO.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/member/MemberDAO.kt @@ -21,7 +21,6 @@ import app.cash.sqldelight.coroutines.asFlow import com.wire.kalium.persistence.ConversationsQueries import com.wire.kalium.persistence.MembersQueries import com.wire.kalium.persistence.UsersQueries -import com.wire.kalium.persistence.dao.ConnectionEntity import com.wire.kalium.persistence.dao.ConversationIDEntity import com.wire.kalium.persistence.dao.QualifiedIDEntity import com.wire.kalium.persistence.dao.UserIDEntity @@ -45,9 +44,8 @@ interface MemberDAO { suspend fun deleteMembersByQualifiedID(userIDList: List, conversationID: QualifiedIDEntity) suspend fun observeConversationMembers(qualifiedID: QualifiedIDEntity): Flow> suspend fun updateConversationMemberRole(conversationId: QualifiedIDEntity, userId: UserIDEntity, role: MemberEntity.Role) - suspend fun updateOrInsertOneOnOneMemberWithConnectionStatus( + suspend fun updateOrInsertOneOnOneMember( member: MemberEntity, - status: ConnectionEntity.State, conversationID: QualifiedIDEntity ) @@ -150,13 +148,11 @@ internal class MemberDAOImpl internal constructor( memberQueries.updateMemberRole(role, userId, conversationId) } - override suspend fun updateOrInsertOneOnOneMemberWithConnectionStatus( + override suspend fun updateOrInsertOneOnOneMember( member: MemberEntity, - status: ConnectionEntity.State, conversationID: QualifiedIDEntity ) = withContext(coroutineContext) { memberQueries.transaction { - userQueries.upsertUserConnectionStatus(member.user, status) conversationsQueries.updateConversationType(ConversationEntity.Type.ONE_ON_ONE, conversationID) val conversationRecordExist = conversationsQueries.selectChanges().executeAsOne() != 0L if (conversationRecordExist) { diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageDAO.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageDAO.kt index 06c0aecf59e..54db47261c2 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageDAO.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageDAO.kt @@ -133,5 +133,7 @@ interface MessageDAO { recipientFailureTypeEntity: RecipientFailureTypeEntity ) + suspend fun moveMessages(from: ConversationIDEntity, to: ConversationIDEntity) + val platformExtensions: MessageExtensions } diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageDAOImpl.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageDAOImpl.kt index b827848b52a..60f119f0dd9 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageDAOImpl.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageDAOImpl.kt @@ -394,6 +394,11 @@ internal class MessageDAOImpl internal constructor( queries.insertMessageRecipientsFailure(id, conversationsId, recipientsFailed, recipientFailureTypeEntity) } + override suspend fun moveMessages(from: ConversationIDEntity, to: ConversationIDEntity) = + withContext(coroutineContext) { + queries.moveMessages(to, from) + } + override suspend fun getConversationUnreadEventsCount(conversationId: QualifiedIDEntity): Long = withContext(coroutineContext) { unreadEventsQueries.getConversationUnreadEventsCount(conversationId).executeAsOne() } diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageEntity.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageEntity.kt index 69519e305d9..df12d11c9f7 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageEntity.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageEntity.kt @@ -190,9 +190,10 @@ sealed interface MessageEntity { enum class ContentType { TEXT, ASSET, KNOCK, MEMBER_CHANGE, MISSED_CALL, RESTRICTED_ASSET, CONVERSATION_RENAMED, UNKNOWN, FAILED_DECRYPTION, REMOVED_FROM_TEAM, CRYPTO_SESSION_RESET, - NEW_CONVERSATION_RECEIPT_MODE, CONVERSATION_RECEIPT_MODE_CHANGED, HISTORY_LOST, CONVERSATION_MESSAGE_TIMER_CHANGED, - CONVERSATION_CREATED, MLS_WRONG_EPOCH_WARNING, CONVERSATION_DEGRADED_MLS, CONVERSATION_DEGRADED_PREOTEUS, CONVERSATION_VERIFIED_MLS, - CONVERSATION_VERIFIED_PREOTEUS, COMPOSITE, FEDERATION + NEW_CONVERSATION_RECEIPT_MODE, CONVERSATION_RECEIPT_MODE_CHANGED, HISTORY_LOST, HISTORY_LOST_PROTOCOL_CHANGED, + CONVERSATION_MESSAGE_TIMER_CHANGED, CONVERSATION_CREATED, MLS_WRONG_EPOCH_WARNING, CONVERSATION_DEGRADED_MLS, + CONVERSATION_DEGRADED_PREOTEUS, CONVERSATION_VERIFIED_MLS, CONVERSATION_VERIFIED_PREOTEUS, COMPOSITE, FEDERATION, + CONVERSATION_PROTOCOL_CHANGED } enum class MemberChangeType { @@ -325,6 +326,8 @@ sealed class MessageEntityContent { data class NewConversationReceiptMode(val receiptMode: Boolean) : System() data class ConversationReceiptModeChanged(val receiptMode: Boolean) : System() data class ConversationMessageTimerChanged(val messageTimer: Long?) : System() + data class ConversationProtocolChanged(val protocol: ConversationEntity.Protocol) : System() + object HistoryLostProtocolChanged : System() object HistoryLost : System() object ConversationCreated : System() data object ConversationDegradedMLS : System() diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageInsertExtension.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageInsertExtension.kt index c8890069539..5e9e897f538 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageInsertExtension.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageInsertExtension.kt @@ -203,6 +203,10 @@ internal class MessageInsertExtensionImpl( /* no-op */ } + is MessageEntityContent.HistoryLostProtocolChanged -> { + /* no-op */ + } + is MessageEntityContent.ConversationReceiptModeChanged -> messagesQueries.insertConversationReceiptModeChanged( message_id = message.id, conversation_id = message.conversationId, @@ -251,6 +255,12 @@ internal class MessageInsertExtensionImpl( federation_type = content.type ) } + + is MessageEntityContent.ConversationProtocolChanged -> messagesQueries.insertConversationProtocolChanged( + message_id = message.id, + conversation_id = message.conversationId, + protocol = content.protocol + ) } } @@ -294,8 +304,10 @@ internal class MessageInsertExtensionImpl( is MessageEntityContent.ConversationMessageTimerChanged, is MessageEntityContent.ConversationReceiptModeChanged, is MessageEntityContent.ConversationRenamed, + is MessageEntityContent.ConversationProtocolChanged, MessageEntityContent.CryptoSessionReset, MessageEntityContent.HistoryLost, + MessageEntityContent.HistoryLostProtocolChanged, MessageEntityContent.MLSWrongEpochWarning, is MessageEntityContent.MemberChange, is MessageEntityContent.NewConversationReceiptMode, @@ -385,6 +397,7 @@ internal class MessageInsertExtensionImpl( is MessageEntityContent.NewConversationReceiptMode -> MessageEntity.ContentType.NEW_CONVERSATION_RECEIPT_MODE is MessageEntityContent.ConversationReceiptModeChanged -> MessageEntity.ContentType.CONVERSATION_RECEIPT_MODE_CHANGED is MessageEntityContent.HistoryLost -> MessageEntity.ContentType.HISTORY_LOST + is MessageEntityContent.HistoryLostProtocolChanged -> MessageEntity.ContentType.HISTORY_LOST_PROTOCOL_CHANGED is MessageEntityContent.ConversationMessageTimerChanged -> MessageEntity.ContentType.CONVERSATION_MESSAGE_TIMER_CHANGED is MessageEntityContent.ConversationCreated -> MessageEntity.ContentType.CONVERSATION_CREATED is MessageEntityContent.MLSWrongEpochWarning -> MessageEntity.ContentType.MLS_WRONG_EPOCH_WARNING @@ -394,5 +407,6 @@ internal class MessageInsertExtensionImpl( is MessageEntityContent.Federation -> MessageEntity.ContentType.FEDERATION MessageEntityContent.ConversationVerifiedMLS -> MessageEntity.ContentType.CONVERSATION_VERIFIED_MLS MessageEntityContent.ConversationVerifiedProteus -> MessageEntity.ContentType.CONVERSATION_VERIFIED_PREOTEUS + is MessageEntityContent.ConversationProtocolChanged -> MessageEntity.ContentType.CONVERSATION_PROTOCOL_CHANGED } } diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageMapper.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageMapper.kt index 09d876a112b..0a89eb13b5a 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageMapper.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageMapper.kt @@ -194,6 +194,7 @@ object MessageMapper { MessageEntity.ContentType.NEW_CONVERSATION_RECEIPT_MODE -> MessagePreviewEntityContent.Unknown MessageEntity.ContentType.CONVERSATION_RECEIPT_MODE_CHANGED -> MessagePreviewEntityContent.Unknown MessageEntity.ContentType.HISTORY_LOST -> MessagePreviewEntityContent.Unknown + MessageEntity.ContentType.HISTORY_LOST_PROTOCOL_CHANGED -> MessagePreviewEntityContent.Unknown MessageEntity.ContentType.CONVERSATION_MESSAGE_TIMER_CHANGED -> MessagePreviewEntityContent.Unknown MessageEntity.ContentType.CONVERSATION_CREATED -> MessagePreviewEntityContent.Unknown MessageEntity.ContentType.MLS_WRONG_EPOCH_WARNING -> MessagePreviewEntityContent.Unknown @@ -204,6 +205,7 @@ object MessageMapper { MessageEntity.ContentType.CRYPTO_SESSION_RESET -> MessagePreviewEntityContent.CryptoSessionReset MessageEntity.ContentType.CONVERSATION_VERIFIED_MLS -> MessagePreviewEntityContent.Unknown MessageEntity.ContentType.CONVERSATION_VERIFIED_PREOTEUS -> MessagePreviewEntityContent.Unknown + MessageEntity.ContentType.CONVERSATION_PROTOCOL_CHANGED -> MessagePreviewEntityContent.Unknown } } @@ -441,7 +443,8 @@ object MessageMapper { recipientsFailedDeliveryList: List?, buttonsJson: String, federationDomainList: List?, - federationType: MessageEntity.FederationType? + federationType: MessageEntity.FederationType?, + conversationProtocolChanged: ConversationEntity.Protocol? ): MessageEntity { // If message hsa been deleted, we don't care about the content. Also most of their internal content is null anyways val content = if (visibility == MessageEntity.Visibility.DELETED) { @@ -555,6 +558,7 @@ object MessageMapper { ) MessageEntity.ContentType.HISTORY_LOST -> MessageEntityContent.HistoryLost + MessageEntity.ContentType.HISTORY_LOST_PROTOCOL_CHANGED -> MessageEntityContent.HistoryLostProtocolChanged MessageEntity.ContentType.CONVERSATION_MESSAGE_TIMER_CHANGED -> MessageEntityContent.ConversationMessageTimerChanged( messageTimer = messageTimerChanged ) @@ -569,6 +573,10 @@ object MessageMapper { domainList = federationDomainList.requireField("federationDomainList"), type = federationType.requireField("federationType") ) + + MessageEntity.ContentType.CONVERSATION_PROTOCOL_CHANGED -> MessageEntityContent.ConversationProtocolChanged( + protocol = conversationProtocolChanged ?: ConversationEntity.Protocol.PROTEUS + ) } return createMessageEntity( diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/unread/UserConfigDAO.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/unread/UserConfigDAO.kt index f6255dc67d2..e39e7149c83 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/unread/UserConfigDAO.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/unread/UserConfigDAO.kt @@ -17,9 +17,12 @@ */ package com.wire.kalium.persistence.dao.unread +import com.wire.kalium.persistence.config.MLSMigrationEntity import com.wire.kalium.persistence.config.TeamSettingsSelfDeletionStatusEntity import com.wire.kalium.persistence.dao.MetadataDAO +import com.wire.kalium.persistence.dao.SupportedProtocolEntity import kotlinx.coroutines.flow.Flow +import kotlinx.serialization.builtins.SetSerializer interface UserConfigDAO { @@ -30,6 +33,12 @@ interface UserConfigDAO { suspend fun markTeamSettingsSelfDeletingMessagesStatusAsNotified() suspend fun observeTeamSettingsSelfDeletingStatus(): Flow + + suspend fun getMigrationConfiguration(): MLSMigrationEntity? + suspend fun setMigrationConfiguration(configuration: MLSMigrationEntity) + + suspend fun getSupportedProtocols(): Set? + suspend fun setSupportedProtocols(protocols: Set) } internal class UserConfigDAOImpl internal constructor( @@ -37,23 +46,23 @@ internal class UserConfigDAOImpl internal constructor( ) : UserConfigDAO { override suspend fun getTeamSettingsSelfDeletionStatus(): TeamSettingsSelfDeletionStatusEntity? = - metadataDAO.getSerializable(SELF_DELETING_MESSAGES, TeamSettingsSelfDeletionStatusEntity.serializer()) + metadataDAO.getSerializable(SELF_DELETING_MESSAGES_KEY, TeamSettingsSelfDeletionStatusEntity.serializer()) override suspend fun setTeamSettingsSelfDeletionStatus( teamSettingsSelfDeletionStatusEntity: TeamSettingsSelfDeletionStatusEntity ) { metadataDAO.putSerializable( - key = SELF_DELETING_MESSAGES, + key = SELF_DELETING_MESSAGES_KEY, value = teamSettingsSelfDeletionStatusEntity, TeamSettingsSelfDeletionStatusEntity.serializer() ) } override suspend fun markTeamSettingsSelfDeletingMessagesStatusAsNotified() { - metadataDAO.getSerializable(SELF_DELETING_MESSAGES, TeamSettingsSelfDeletionStatusEntity.serializer()) + metadataDAO.getSerializable(SELF_DELETING_MESSAGES_KEY, TeamSettingsSelfDeletionStatusEntity.serializer()) ?.copy(isStatusChanged = false)?.let { newValue -> metadataDAO.putSerializable( - SELF_DELETING_MESSAGES, + SELF_DELETING_MESSAGES_KEY, newValue, TeamSettingsSelfDeletionStatusEntity.serializer() ) @@ -61,9 +70,23 @@ internal class UserConfigDAOImpl internal constructor( } override suspend fun observeTeamSettingsSelfDeletingStatus(): Flow = - metadataDAO.observeSerializable(SELF_DELETING_MESSAGES, TeamSettingsSelfDeletionStatusEntity.serializer()) + metadataDAO.observeSerializable(SELF_DELETING_MESSAGES_KEY, TeamSettingsSelfDeletionStatusEntity.serializer()) + + override suspend fun getMigrationConfiguration(): MLSMigrationEntity? = + metadataDAO.getSerializable(MLS_MIGRATION_KEY, MLSMigrationEntity.serializer()) + + override suspend fun setMigrationConfiguration(configuration: MLSMigrationEntity) = + metadataDAO.putSerializable(MLS_MIGRATION_KEY, configuration, MLSMigrationEntity.serializer()) + + override suspend fun getSupportedProtocols(): Set? = + metadataDAO.getSerializable(SUPPORTED_PROTOCOLS_KEY, SetSerializer(SupportedProtocolEntity.serializer())) + + override suspend fun setSupportedProtocols(protocols: Set) = + metadataDAO.putSerializable(SUPPORTED_PROTOCOLS_KEY, protocols, SetSerializer(SupportedProtocolEntity.serializer())) private companion object { - private const val SELF_DELETING_MESSAGES = "SELF_DELETING_MESSAGES" + private const val SELF_DELETING_MESSAGES_KEY = "SELF_DELETING_MESSAGES" + private const val MLS_MIGRATION_KEY = "MLS_MIGRATION" + private const val SUPPORTED_PROTOCOLS_KEY = "SUPPORTED_PROTOCOLS" } } diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/db/TableMapper.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/db/TableMapper.kt index 04fa1100167..812e1bc19c8 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/db/TableMapper.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/db/TableMapper.kt @@ -28,6 +28,7 @@ import com.wire.kalium.persistence.Member import com.wire.kalium.persistence.Message import com.wire.kalium.persistence.MessageAssetContent import com.wire.kalium.persistence.MessageConversationChangedContent +import com.wire.kalium.persistence.MessageConversationProtocolChangedContent import com.wire.kalium.persistence.MessageConversationReceiptModeChangedContent import com.wire.kalium.persistence.MessageConversationTimerChangedContent import com.wire.kalium.persistence.MessageFailedToDecryptContent @@ -59,6 +60,7 @@ import com.wire.kalium.persistence.adapter.QualifiedIDListAdapter import com.wire.kalium.persistence.adapter.ServiceTagListAdapter import com.wire.kalium.persistence.content.ButtonContent import com.wire.kalium.persistence.adapter.StringListAdapter +import com.wire.kalium.persistence.adapter.SupportedProtocolSetAdapter internal object TableMapper { val callAdapter = Call.Adapter( @@ -176,7 +178,9 @@ internal object TableMapper { complete_asset_idAdapter = QualifiedIDAdapter, user_typeAdapter = EnumColumnAdapter(), bot_serviceAdapter = BotServiceAdapter(), - expires_atAdapter = InstantTypeAdapter + expires_atAdapter = InstantTypeAdapter, + supported_protocolsAdapter = SupportedProtocolSetAdapter, + active_one_on_one_conversation_idAdapter = QualifiedIDAdapter ) val messageNewConversationReceiptModeContentAdapter = MessageNewConversationReceiptModeContent.Adapter( conversation_idAdapter = QualifiedIDAdapter @@ -187,7 +191,10 @@ internal object TableMapper { val messageConversationTimerChangedContentAdapter = MessageConversationTimerChangedContent.Adapter( conversation_idAdapter = QualifiedIDAdapter ) - + val messageConversationProtocolChangedContentAdapter = MessageConversationProtocolChangedContent.Adapter( + conversation_idAdapter = QualifiedIDAdapter, + protocolAdapter = EnumColumnAdapter() + ) val unreadEventAdapter = UnreadEvent.Adapter( conversation_idAdapter = QualifiedIDAdapter, typeAdapter = EnumColumnAdapter(), diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/db/UserDatabaseBuilder.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/db/UserDatabaseBuilder.kt index 282dae7e286..4bbff76933e 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/db/UserDatabaseBuilder.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/db/UserDatabaseBuilder.kt @@ -152,7 +152,8 @@ class UserDatabaseBuilder internal constructor( NewClientAdapter = TableMapper.newClientAdapter, MessageRecipientFailureAdapter = TableMapper.messageRecipientFailureAdapter, ButtonContentAdapter = TableMapper.buttonContentAdapter, - MessageFederationTerminatedContentAdapter = TableMapper.messageFederationTerminatedContentAdapter + MessageFederationTerminatedContentAdapter = TableMapper.messageFederationTerminatedContentAdapter, + MessageConversationProtocolChangedContentAdapter = TableMapper.messageConversationProtocolChangedContentAdapter ) init { diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/backup/DatabaseExporterTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/backup/DatabaseExporterTest.kt index ab9143dc53d..41aed5a54d7 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/backup/DatabaseExporterTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/backup/DatabaseExporterTest.kt @@ -50,9 +50,9 @@ class DatabaseExporterTest : BaseDatabaseTest() { runTest { with(localDB.userDAO) { - insertUser(SELF_USER) - insertUser(OTHER_USER) - insertUser(OTHER_USER_2) + upsertUser(SELF_USER) + upsertUser(OTHER_USER) + upsertUser(OTHER_USER_2) } with(localDB.conversationDAO) { diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/backup/DatabaseImporterTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/backup/DatabaseImporterTest.kt index 994b54aacd7..b29643b70d0 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/backup/DatabaseImporterTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/backup/DatabaseImporterTest.kt @@ -504,7 +504,7 @@ class DatabaseImporterTest : BaseDatabaseTest() { val uniqueBackupUsers = backupDatabaseDataGenerator.generateAndInsertUsers(uniqueBackupUsersAmount) uniqueBackupUsers.forEach { userEntity -> - backupDatabaseBuilder.userDAO.insertUser(userEntity.toSimpleEntity()) + backupDatabaseBuilder.userDAO.upsertUser(userEntity.toSimpleEntity()) } // when diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/backup/UserDatabaseDataGenerator.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/backup/UserDatabaseDataGenerator.kt index be098ed8dc3..f6a7986d1bd 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/backup/UserDatabaseDataGenerator.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/backup/UserDatabaseDataGenerator.kt @@ -67,7 +67,7 @@ class UserDatabaseDataGenerator( for (index in generatedMessagesCount + 1..amount) { val senderUser = generateUser() - userDatabaseBuilder.userDAO.insertUser(senderUser) + userDatabaseBuilder.userDAO.upsertUser(senderUser) val visibility = MessageEntity.Visibility.values()[index % MessageEntity.Visibility.values().size] @@ -112,7 +112,7 @@ class UserDatabaseDataGenerator( for (index in generatedAssetsCount + 1..amount) { val senderUser = generateUser() - userDatabaseBuilder.userDAO.insertUser(senderUser) + userDatabaseBuilder.userDAO.upsertUser(senderUser) val visibility = MessageEntity.Visibility.values()[index % MessageEntity.Visibility.values().size] @@ -194,7 +194,9 @@ class UserDatabaseDataGenerator( botService = null, hasIncompleteMetadata = false, expiresAt = null, - defederated = false + defederated = false, + supportedProtocols = null, + activeOneOnOneConversationId = null ) } @@ -208,7 +210,7 @@ class UserDatabaseDataGenerator( for (index in generatedMessagesCount + 1..amount) { val senderUser = generateUser() - userDatabaseBuilder.userDAO.insertUser(senderUser) + userDatabaseBuilder.userDAO.upsertUser(senderUser) val visibility = MessageEntity.Visibility.values()[index % MessageEntity.Visibility.values().size] @@ -337,7 +339,7 @@ class UserDatabaseDataGenerator( val callPrefix = "${databasePrefix}Call${generatedCallsCount}" val userEntity = generateUser() - userDatabaseBuilder.userDAO.insertUser(userEntity) + userDatabaseBuilder.userDAO.upsertUser(userEntity) val conversationType = ConversationEntity.Type.values()[generatedCallsCount % ConversationEntity.Type.values().size] val type = CallEntity.Type.values()[generatedCallsCount % CallEntity.Type.values().size] @@ -548,7 +550,7 @@ class UserDatabaseDataGenerator( for (index in generatedUsersCount + 1..amount) { val user = generateUser() - userDatabaseBuilder.userDAO.insertUser(user) + userDatabaseBuilder.userDAO.upsertUser(user) } return userDatabaseBuilder.userDAO.getAllUsersDetails().first() diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/config/UserConfigStorageTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/config/UserConfigStorageTest.kt index d0cb9106f6a..760fff3f88b 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/config/UserConfigStorageTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/config/UserConfigStorageTest.kt @@ -21,6 +21,7 @@ package com.wire.kalium.persistence.config import app.cash.turbine.test import com.russhwolf.settings.MapSettings import com.russhwolf.settings.Settings +import com.wire.kalium.persistence.dao.SupportedProtocolEntity import com.wire.kalium.persistence.kmmSettings.KaliumPreferences import com.wire.kalium.persistence.kmmSettings.KaliumPreferencesSettings import kotlinx.coroutines.ExperimentalCoroutinesApi @@ -178,4 +179,19 @@ class UserConfigStorageTest { } } } + + @Test + fun givenDefaultProtocolIsNotSet_whenGettingItsValue_thenItShouldBeProteus() { + userConfigStorage.defaultProtocol().let { + assertEquals(SupportedProtocolEntity.PROTEUS, it) + } + } + + @Test + fun givenDefaultProtocolIsSetToMls_whenGettingItsValue_thenItShouldBeMls() { + userConfigStorage.persistDefaultProtocol(SupportedProtocolEntity.MLS) + userConfigStorage.defaultProtocol().let { + assertEquals(SupportedProtocolEntity.MLS, it) + } + } } diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConnectionDaoTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConnectionDaoTest.kt index d0c5cf473d4..80b7f9df910 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConnectionDaoTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConnectionDaoTest.kt @@ -84,14 +84,16 @@ class ConnectionDaoTest : BaseDatabaseTest() { } companion object { + val OTHER_USER_ID = QualifiedIDEntity("me", "wire.com") + private fun connectionEntity(id: String = "0") = ConnectionEntity( - conversationId = "$id@wire.com", + conversationId = id, from = "from_string", lastUpdateDate = "2022-03-30T15:36:00.000Z".toInstant(), qualifiedConversationId = QualifiedIDEntity(id, "wire.com"), - qualifiedToId = QualifiedIDEntity("me", "wire.com"), + qualifiedToId = OTHER_USER_ID, status = ConnectionEntity.State.PENDING, - toId = "me@wire.com", + toId = OTHER_USER_ID.value, shouldNotify = true ) } diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt index 58efea75e7b..96aa4f7a409 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt @@ -49,6 +49,7 @@ import kotlin.test.BeforeTest import kotlin.test.Test import kotlin.test.assertContentEquals import kotlin.test.assertEquals +import kotlin.test.assertFalse import kotlin.test.assertNotNull import kotlin.test.assertNull import kotlin.test.assertTrue @@ -81,7 +82,7 @@ class ConversationDAOTest : BaseDatabaseTest() { } @Test - fun givenConversation_ThenConversationCanBeInserted() = runTest { + fun givenConversationIsInserted_whenFetchingById_thenConversationIsReturned() = runTest { conversationDAO.insertConversation(conversationEntity1) insertTeamUserAndMember(team, user1, conversationEntity1.id) val result = conversationDAO.getConversationByQualifiedID(conversationEntity1.id) @@ -131,6 +132,14 @@ class ConversationDAOTest : BaseDatabaseTest() { assertEquals(conversationEntity2.toViewEntity(user2), result) } + @Test + fun givenExistingMixedConversation_ThenConversationIdCanBeRetrievedByGroupID() = runTest { + conversationDAO.insertConversation(conversationEntity6) + insertTeamUserAndMember(team, user2, conversationEntity6.id) + val result = + conversationDAO.getConversationIdByGroupID((conversationEntity6.protocolInfo as ConversationEntity.ProtocolInfo.Mixed).groupId) + assertEquals(conversationEntity6.id, result) + } @Test fun givenExistingMLSConversation_ThenConversationIdCanBeRetrievedByGroupID() = runTest { conversationDAO.insertConversation(conversationEntity2) @@ -140,6 +149,61 @@ class ConversationDAOTest : BaseDatabaseTest() { assertEquals(conversationEntity2.id, result) } + @Test + fun givenExistingMixedConversation_ThenConversationCanBeRetrievedByGroupState() = runTest { + conversationDAO.insertConversation(conversationEntity6) + conversationDAO.insertConversation(conversationEntity3) + insertTeamUserAndMember(team, user2, conversationEntity6.id) + val result = + conversationDAO.getConversationsByGroupState(ConversationEntity.GroupState.ESTABLISHED) + assertEquals(listOf(conversationEntity6.toViewEntity(user2)), result) + } + + @Test + fun givenExistingConversations_WhenGetConversationIds_ThenConversationsWithGivenProtocolIsReturned() = runTest { + conversationDAO.insertConversation(conversationEntity4) + conversationDAO.insertConversation(conversationEntity5) + insertTeamUserAndMember(team, user2, conversationEntity5.id) + val result = + conversationDAO.getConversationIds(ConversationEntity.Type.GROUP, ConversationEntity.Protocol.PROTEUS) + assertEquals(listOf(conversationEntity5.id), result) + } + + @Test + fun givenExistingConversations_WhenGetConversationIds_ThenConversationsWithGivenTeamIdIsReturned() = runTest { + conversationDAO.insertConversation(conversationEntity1) + conversationDAO.insertConversation(conversationEntity4) + conversationDAO.insertConversation(conversationEntity5) + insertTeamUserAndMember(team, user2, conversationEntity5.id) + + val result = + conversationDAO.getConversationIds(ConversationEntity.Type.GROUP, ConversationEntity.Protocol.PROTEUS, teamId) + + assertEquals(listOf(conversationEntity5.id), result) + } + + @Test + fun givenExistingConversations_WhenGetConversationIdsWithoutTeamId_ThenConversationsWithAllTeamIdsAreReturned() = runTest { + conversationDAO.insertConversation(conversationEntity4.copy( protocolInfo = ConversationEntity.ProtocolInfo.Proteus)) + conversationDAO.insertConversation(conversationEntity5.copy( teamId = null)) + insertTeamUserAndMember(team, user2, conversationEntity5.id) + + val result = + conversationDAO.getConversationIds(ConversationEntity.Type.GROUP, ConversationEntity.Protocol.PROTEUS) + + assertEquals(setOf(conversationEntity4.id, conversationEntity5.id), result.toSet()) + } + + @Test + fun givenExistingConversations_WhenGetConversationIds_ThenConversationsWithGivenTypeIsReturned() = runTest { + conversationDAO.insertConversation(conversationEntity1.copy(type = ConversationEntity.Type.SELF)) + conversationDAO.insertConversation(conversationEntity5.copy(type = ConversationEntity.Type.ONE_ON_ONE)) + insertTeamUserAndMember(team, user2, conversationEntity5.id) + val result = + conversationDAO.getConversationIds(ConversationEntity.Type.SELF, ConversationEntity.Protocol.PROTEUS) + assertEquals(listOf(conversationEntity1.id), result) + } + @Test fun givenExistingMLSConversation_ThenConversationCanBeRetrievedByGroupState() = runTest { conversationDAO.insertConversation(conversationEntity2) @@ -150,6 +214,36 @@ class ConversationDAOTest : BaseDatabaseTest() { assertEquals(listOf(conversationEntity2.toViewEntity(user2)), result) } + @Test + fun givenAllMembersAreMlsCapable_WhenGetTeamConversationIdsReadyToBeFinalised_ThenConversationIsReturned() = runTest { + val allProtocols = setOf(SupportedProtocolEntity.PROTEUS, SupportedProtocolEntity.MLS) + val selfUser = user1.copy(id = selfUserId, supportedProtocols = allProtocols) + userDAO.upsertUser(selfUser) + + conversationDAO.insertConversation(conversationEntity6) + insertTeamUserAndMember(team, user2.copy(supportedProtocols = allProtocols), conversationEntity6.id) + insertTeamUserAndMember(team, user3.copy(supportedProtocols = allProtocols), conversationEntity6.id) + + val result = conversationDAO.getTeamConversationIdsReadyToCompleteMigration(teamId) + + assertEquals(listOf(conversationEntity6.id), result) + } + + @Test + fun givenOnlySomeMembersAreMlsCapable_WhenGetTeamConversationIdsReadyToBeFinalised_ThenConversationIsNotReturned() = runTest { + val allProtocols = setOf(SupportedProtocolEntity.PROTEUS, SupportedProtocolEntity.MLS) + val selfUser = user1.copy(id = selfUserId, supportedProtocols = allProtocols) + userDAO.upsertUser(selfUser) + + conversationDAO.insertConversation(conversationEntity5) + insertTeamUserAndMember(team, user2.copy(supportedProtocols = allProtocols), conversationEntity5.id) + insertTeamUserAndMember(team, user3.copy(supportedProtocols = setOf(SupportedProtocolEntity.PROTEUS)), conversationEntity5.id) + + val result = conversationDAO.getTeamConversationIdsReadyToCompleteMigration(teamId) + + assertTrue(result.isEmpty()) + } + @Test fun givenExistingConversation_ThenConversationGroupStateCanBeUpdated() = runTest { conversationDAO.insertConversation(conversationEntity2) @@ -173,7 +267,6 @@ class ConversationDAOTest : BaseDatabaseTest() { assertEquals(updatedConversation1Entity.toViewEntity(user1), result) } - @Test fun givenAnExistingConversation_WhenUpdatingTheMutingStatus_ThenConversationShouldBeUpdated() = runTest { conversationDAO.insertConversation(conversationEntity2) @@ -319,6 +412,29 @@ class ConversationDAOTest : BaseDatabaseTest() { } + @Test + fun givenNewValue_whenUpdatingProtocol_thenItsUpdatedAndReportedAsChanged() = runTest { + val conversation = conversationEntity5 + val updatedProtocol = ConversationEntity.Protocol.MLS + + conversationDAO.insertConversation(conversation) + val changed = conversationDAO.updateConversationProtocol(conversation.id, updatedProtocol) + + assertTrue(changed) + assertEquals(conversationDAO.getConversationByQualifiedID(conversation.id)?.protocol, updatedProtocol) + } + + @Test + fun givenSameValue_whenUpdatingProtocol_thenItsReportedAsUnchanged() = runTest { + val conversation = conversationEntity5 + val updatedProtocol = ConversationEntity.Protocol.PROTEUS + + conversationDAO.insertConversation(conversation) + val changed = conversationDAO.updateConversationProtocol(conversation.id, updatedProtocol) + + assertFalse(changed) + } + @Test fun givenMLSConversation_whenUpdatingKeyingMaterialLastUpdate_thenItsUpdated() = runTest { // given @@ -372,7 +488,7 @@ class ConversationDAOTest : BaseDatabaseTest() { teamDAO.insertTeam(team) conversationDAO.insertConversation(conversation) - userDAO.insertUser(user1) + userDAO.upsertUser(user1) val messages = buildList { repeat(10) { @@ -587,9 +703,9 @@ class ConversationDAOTest : BaseDatabaseTest() { date = secondRemovalDate, conversationId = conversationEntity1.id ) - userDAO.insertUser(user1) - userDAO.insertUser(user2) - userDAO.insertUser(user3) + userDAO.upsertUser(user1) + userDAO.upsertUser(user2) + userDAO.upsertUser(user3) messageDAO.insertOrIgnoreMessage(message1) messageDAO.insertOrIgnoreMessage(message2) @@ -611,9 +727,9 @@ class ConversationDAOTest : BaseDatabaseTest() { memberDAO.insertMember(member1, conversationEntity1.id) memberDAO.insertMember(member3, conversationEntity1.id) memberDAO.insertMember(mySelfMember, conversationEntity1.id) - userDAO.insertUser(user1) - userDAO.insertUser(user2) - userDAO.insertUser(user3) + userDAO.upsertUser(user1) + userDAO.upsertUser(user2) + userDAO.upsertUser(user3) memberDAO.deleteMemberByQualifiedID(member3.user, conversationEntity1.id) val removalMessage = newSystemMessageEntity( @@ -659,10 +775,10 @@ class ConversationDAOTest : BaseDatabaseTest() { memberDAO.insertMember(member2, conversationEntity2.id) // when - val conversationIds = conversationDAO.getConversationIdsByUserId(member1.user) + val conversationIds = conversationDAO.getConversationsByUserId(member1.user) // then - assertContentEquals(listOf(conversationEntity1.id), conversationIds) + assertContentEquals(listOf(conversationEntity1), conversationIds) } @Test @@ -683,7 +799,7 @@ class ConversationDAOTest : BaseDatabaseTest() { // given conversationDAO.insertConversation(conversationEntity3) teamDAO.insertTeam(team) - userDAO.insertUser(user2) + userDAO.upsertUser(user2) memberDAO.insertMember(MemberEntity(user2.id, MemberEntity.Role.Member), conversationEntity3.id) // when @@ -712,7 +828,7 @@ class ConversationDAOTest : BaseDatabaseTest() { // given conversationDAO.insertConversation(conversationEntity3.copy(creatorId = selfUserId.value)) teamDAO.insertTeam(team) - userDAO.insertUser(user2) + userDAO.upsertUser(user2) insertTeamUserAndMember(team, user2, conversationEntity3.id) // when @@ -723,7 +839,19 @@ class ConversationDAOTest : BaseDatabaseTest() { } @Test - fun givenAnMLSConversation_whenGettingConversationProtocolInfo_itReturnsCorrectInfo() = runTest { + fun givenMixedConversation_whenGettingConversationProtocolInfo_itReturnsCorrectInfo() = runTest { + // given + conversationDAO.insertConversation(conversationEntity6) + + // when + val result = conversationDAO.getConversationProtocolInfo(conversationEntity6.id) + + // then + assertEquals(conversationEntity6.protocolInfo, result) + } + + @Test + fun givenMLSConversation_whenGettingConversationProtocolInfo_itReturnsCorrectInfo() = runTest { // given conversationDAO.insertConversation(conversationEntity2) @@ -735,7 +863,7 @@ class ConversationDAOTest : BaseDatabaseTest() { } @Test - fun givenAProteusConversation_whenGettingConversationProtocolInfo_itReturnsCorrectInfo() = runTest { + fun givenProteusConversation_whenGettingConversationProtocolInfo_itReturnsCorrectInfo() = runTest { // given conversationDAO.insertConversation(conversationEntity1) @@ -764,7 +892,7 @@ class ConversationDAOTest : BaseDatabaseTest() { val instant = Clock.System.now() - userDAO.insertUser(user1) + userDAO.upsertUser(user1) newRegularMessageEntity( id = Random.nextBytes(10).decodeToString(), @@ -774,7 +902,7 @@ class ConversationDAOTest : BaseDatabaseTest() { ).also { messageDAO.insertOrIgnoreMessage(it) } // TODO: insert another message from self user to check if it is not ignored - userDAO.insertUser(user1) + userDAO.upsertUser(user1) newRegularMessageEntity( id = Random.nextBytes(10).decodeToString(), @@ -806,7 +934,7 @@ class ConversationDAOTest : BaseDatabaseTest() { toId = user1.id.value, ) - userDAO.insertUser(user1) + userDAO.upsertUser(user1) conversationDAO.insertConversation(conversation) connectionDAO.insertConnection(connectionEntity) @@ -834,7 +962,7 @@ class ConversationDAOTest : BaseDatabaseTest() { toId = user1.id.value, ) - userDAO.insertUser(user1.copy(name = null)) + userDAO.upsertUser(user1.copy(name = null)) conversationDAO.insertConversation(conversation) connectionDAO.insertConnection(connectionEntity) @@ -849,8 +977,8 @@ class ConversationDAOTest : BaseDatabaseTest() { conversationDAO.insertConversation(conversationEntity1) conversationDAO.insertConversation(conversationEntity2) - userDAO.insertUser(user1) // user with metadata - userDAO.insertUser(user2.copy(name = null)) // user without metadata + userDAO.upsertUser(user1.copy(activeOneOnOneConversationId = conversationEntity1.id)) // user with metadata + userDAO.upsertUser(user2.copy(activeOneOnOneConversationId = conversationEntity2.id, name = null)) // user without metadata memberDAO.insertMember(member1, conversationEntity1.id) memberDAO.insertMember(member2, conversationEntity1.id) @@ -868,8 +996,8 @@ class ConversationDAOTest : BaseDatabaseTest() { conversationDAO.insertConversation(conversationEntity2.copy(archived = true)) conversationDAO.insertConversation(conversationEntity3.copy(archived = false)) - userDAO.insertUser(user1) - userDAO.insertUser(user2) + userDAO.upsertUser(user1.copy(activeOneOnOneConversationId = conversationEntity1.id)) + userDAO.upsertUser(user2.copy(activeOneOnOneConversationId = conversationEntity2.id)) memberDAO.insertMember(member1, conversationEntity1.id) memberDAO.insertMember(member2, conversationEntity2.id) @@ -886,8 +1014,8 @@ class ConversationDAOTest : BaseDatabaseTest() { conversationDAO.insertConversation(conversationEntity1.copy(archived = false)) conversationDAO.insertConversation(conversationEntity2.copy(archived = false)) - userDAO.insertUser(user1) - userDAO.insertUser(user2) + userDAO.upsertUser(user1.copy(activeOneOnOneConversationId = conversationEntity1.id)) + userDAO.upsertUser(user2.copy(activeOneOnOneConversationId = conversationEntity2.id)) memberDAO.insertMember(member1, conversationEntity1.id) memberDAO.insertMember(member2, conversationEntity2.id) @@ -1047,9 +1175,49 @@ class ConversationDAOTest : BaseDatabaseTest() { assertTrue(result == 1L) } + @Test + fun givenOneOnOneConversations_whenGettingAllConversations_thenShouldReturnsOnlyActiveConversations() = runTest { + conversationDAO.insertConversation(conversationEntity1) + conversationDAO.insertConversation(conversationEntity2) + + userDAO.upsertUser(user1.copy(activeOneOnOneConversationId = conversationEntity1.id)) // user active one-on-one + userDAO.upsertUser(user2.copy(activeOneOnOneConversationId = null)) // user without active one-on-one + + memberDAO.insertMembersWithQualifiedId(listOf(member1, member2), conversationEntity1.id) + memberDAO.insertMembersWithQualifiedId(listOf(member1, member2), conversationEntity2.id) + + conversationDAO.getAllConversationDetails(fromArchive = false).first().let { + assertEquals(1, it.size) + assertEquals(conversationEntity1.id, it.first().id) + } + } + + @Test + fun givenOneOnOneConversationNotExisting_whenGettingOneOnOneConversationId_thenShouldReturnEmptyList() = runTest { + // given + userDAO.upsertUser(user1.copy(activeOneOnOneConversationId = conversationEntity1.id)) + + // then + assertTrue(conversationDAO.getOneOnOneConversationIdsWithOtherUser(user1.id, protocol = ConversationEntity.Protocol.PROTEUS).isEmpty()) + } + + @Test + fun givenOneOnOneConversationExisting_whenGettingOneOnOneConversationId_thenShouldRespectProtocol() = runTest { + // given + userDAO.upsertUser(user1) + conversationDAO.insertConversation(conversationEntity1) + conversationDAO.insertConversation(conversationEntity2) + memberDAO.insertMember(member1, conversationEntity1.id) + memberDAO.insertMember(member1, conversationEntity2.id) + + // then + assertEquals(listOf(conversationEntity1.id), conversationDAO.getOneOnOneConversationIdsWithOtherUser(user1.id, protocol = ConversationEntity.Protocol.PROTEUS)) + assertEquals(listOf(conversationEntity2.id), conversationDAO.getOneOnOneConversationIdsWithOtherUser(user1.id, protocol = ConversationEntity.Protocol.MLS)) + } + private suspend fun insertTeamUserAndMember(team: TeamEntity, user: UserEntity, conversationId: QualifiedIDEntity) { teamDAO.insertTeam(team) - userDAO.insertUser(user) + userDAO.upsertUser(user) // should be inserted AFTER inserting the conversation!!! memberDAO.insertMembersWithQualifiedId( listOf( @@ -1067,8 +1235,12 @@ class ConversationDAOTest : BaseDatabaseTest() { val mlsGroupState: ConversationEntity.GroupState val protocolInfoTmp = protocolInfo - if (protocolInfoTmp is ConversationEntity.ProtocolInfo.MLS) { - protocol = ConversationEntity.Protocol.MLS + if (protocolInfoTmp is ConversationEntity.ProtocolInfo.MLSCapable) { + protocol = if (protocolInfoTmp is ConversationEntity.ProtocolInfo.MLS) { + ConversationEntity.Protocol.MLS + } else { + ConversationEntity.Protocol.MIXED + } mlsGroupId = protocolInfoTmp.groupId mlsLastKeyingMaterialUpdate = protocolInfoTmp.keyingMaterialLastUpdate mlsGroupState = protocolInfoTmp.groupState @@ -1115,7 +1287,9 @@ class ConversationDAOTest : BaseDatabaseTest() { userDefederated = if (type == ConversationEntity.Type.ONE_ON_ONE) userEntity?.defederated else null, archived = false, archivedDateTime = null, - verificationStatus = ConversationEntity.VerificationStatus.NOT_VERIFIED + verificationStatus = ConversationEntity.VerificationStatus.NOT_VERIFIED, + userSupportedProtocols = if (type == ConversationEntity.Type.ONE_ON_ONE) userEntity?.supportedProtocols else null, + userActiveOneOnOneConversationId = null, ) } @@ -1209,7 +1383,7 @@ class ConversationDAOTest : BaseDatabaseTest() { QualifiedIDEntity("4", "wire.com"), "conversation4", ConversationEntity.Type.GROUP, - null, + teamId, ConversationEntity.ProtocolInfo.MLS( "group4", ConversationEntity.GroupState.ESTABLISHED, @@ -1233,6 +1407,54 @@ class ConversationDAOTest : BaseDatabaseTest() { archivedInstant = null, verificationStatus = ConversationEntity.VerificationStatus.NOT_VERIFIED ) + val conversationEntity5 = ConversationEntity( + QualifiedIDEntity("5", "wire.com"), + "conversation5", + ConversationEntity.Type.GROUP, + teamId, + ConversationEntity.ProtocolInfo.Proteus, + creatorId = "someValue", + lastNotificationDate = null, + lastModifiedDate = "2022-03-30T15:36:00.000Z".toInstant(), + lastReadDate = "2000-01-01T12:00:00.000Z".toInstant(), + mutedStatus = ConversationEntity.MutedStatus.ALL_ALLOWED, + access = listOf(ConversationEntity.Access.LINK, ConversationEntity.Access.INVITE), + accessRole = listOf(ConversationEntity.AccessRole.NON_TEAM_MEMBER, ConversationEntity.AccessRole.TEAM_MEMBER), + receiptMode = ConversationEntity.ReceiptMode.DISABLED, + messageTimer = null, + userMessageTimer = null, + archived = false, + archivedInstant = null, + verificationStatus = ConversationEntity.VerificationStatus.NOT_VERIFIED + ) + val conversationEntity6 = ConversationEntity( + QualifiedIDEntity("6", "wire.com"), + "conversation6", + ConversationEntity.Type.GROUP, + teamId, + ConversationEntity.ProtocolInfo.Mixed( + "group6", + ConversationEntity.GroupState.ESTABLISHED, + 0UL, + Instant.parse("2021-03-30T15:36:00.000Z"), + cipherSuite = ConversationEntity.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 + ), + creatorId = "someValue", + // This conversation was modified after the last time the user was notified about it + lastNotificationDate = "2021-03-30T15:30:00.000Z".toInstant(), + lastModifiedDate = "2021-03-30T15:36:00.000Z".toInstant(), + lastReadDate = "2000-01-01T12:00:00.000Z".toInstant(), + // and it's status is set to be only notified if there is a mention for the user + mutedStatus = ConversationEntity.MutedStatus.ONLY_MENTIONS_AND_REPLIES_ALLOWED, + access = listOf(ConversationEntity.Access.LINK, ConversationEntity.Access.INVITE), + accessRole = listOf(ConversationEntity.AccessRole.NON_TEAM_MEMBER, ConversationEntity.AccessRole.TEAM_MEMBER), + receiptMode = ConversationEntity.ReceiptMode.DISABLED, + messageTimer = null, + userMessageTimer = null, + archived = false, + archivedInstant = null, + verificationStatus = ConversationEntity.VerificationStatus.NOT_VERIFIED + ) val member1 = MemberEntity(user1.id, MemberEntity.Role.Admin) val member2 = MemberEntity(user2.id, MemberEntity.Role.Member) diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/LastMessageListTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/LastMessageListTest.kt index 9cf0159b374..a534678d7b8 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/LastMessageListTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/LastMessageListTest.kt @@ -69,7 +69,7 @@ class LastMessageListTest: BaseDatabaseTest() { ) conversationDAO.insertConversation(conversion) - userDAO.insertUser(user) + userDAO.upsertUser(user) messageDAO.insertOrIgnoreMessage(message) messageDAO.observeLastMessages().first().also { diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/MemberDAOTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/MemberDAOTest.kt index 91d6f0c97e9..7bd49022df5 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/MemberDAOTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/MemberDAOTest.kt @@ -25,7 +25,6 @@ import com.wire.kalium.persistence.dao.member.MemberEntity import com.wire.kalium.persistence.dao.message.MessageDAO import com.wire.kalium.persistence.utils.stubs.TestStubs import kotlinx.coroutines.flow.first -import kotlinx.coroutines.flow.firstOrNull import kotlinx.coroutines.test.runTest import kotlin.test.BeforeTest import kotlin.test.Test @@ -109,62 +108,50 @@ class MemberDAOTest : BaseDatabaseTest() { } @Test - fun givenExistingMLSConversation_whenAddingMembersByGroupId_ThenAllMembersCanBeRetrieved() = runTest { - val conversationEntity2 = TestStubs.conversationEntity2 + fun givenExistingMixedConversation_whenAddingMembersByGroupId_ThenAllMembersCanBeRetrieved() = runTest { + val conversationEntity5 = TestStubs.conversationEntity5 val member1 = TestStubs.member1 val member2 = TestStubs.member2 - conversationDAO.insertConversation(conversationEntity2) + conversationDAO.insertConversation(conversationEntity5) memberDAO.insertMembers( listOf(member1, member2), - (conversationEntity2.protocolInfo as ConversationEntity.ProtocolInfo.MLS).groupId + (conversationEntity5.protocolInfo as ConversationEntity.ProtocolInfo.MLSCapable).groupId ) - assertEquals(listOf(member1, member2), memberDAO.observeConversationMembers(conversationEntity2.id).first()) + assertEquals(listOf(member1, member2), memberDAO.observeConversationMembers(conversationEntity5.id).first()) } @Test - fun givenExistingConversation_ThenInsertedOrUpdatedMembersAreRetrieved() = runTest(dispatcher) { - val conversationEntity1 = TestStubs.conversationEntity1 + fun givenExistingMLSConversation_whenAddingMembersByGroupId_ThenAllMembersCanBeRetrieved() = runTest { + val conversationEntity2 = TestStubs.conversationEntity2 val member1 = TestStubs.member1 + val member2 = TestStubs.member2 - conversationDAO.insertConversation(conversationEntity1) - memberDAO.updateOrInsertOneOnOneMemberWithConnectionStatus( - member = member1, - status = ConnectionEntity.State.ACCEPTED, - conversationID = conversationEntity1.id - ) + conversationDAO.insertConversation(conversationEntity2) - assertEquals( - listOf(member1), memberDAO.observeConversationMembers(conversationEntity1.id).first() + memberDAO.insertMembers( + listOf(member1, member2), + (conversationEntity2.protocolInfo as ConversationEntity.ProtocolInfo.MLS).groupId ) - assertNotNull(userDAO.observeUserDetailsByQualifiedID(member1.user).firstOrNull()) + + assertEquals(listOf(member1, member2), memberDAO.observeConversationMembers(conversationEntity2.id).first()) } @Test - fun givenExistingUser_WhenInsertingToOneOnOneConversationThenConnectionStatusShouldBeAccepted() = runTest(dispatcher) { + fun givenExistingConversation_ThenInsertedOrUpdatedMembersAreRetrieved() = runTest(dispatcher) { val conversationEntity1 = TestStubs.conversationEntity1 val member1 = TestStubs.member1 - val user = TestStubs.user1.copy(connectionStatus = ConnectionEntity.State.NOT_CONNECTED) - val userDetails = TestStubs.userDetails1.copy(connectionStatus = ConnectionEntity.State.NOT_CONNECTED) - userDAO.insertUser(user) + userDAO.upsertUser(TestStubs.user1) conversationDAO.insertConversation(conversationEntity1) - memberDAO.updateOrInsertOneOnOneMemberWithConnectionStatus( + memberDAO.updateOrInsertOneOnOneMember( member = member1, - status = ConnectionEntity.State.ACCEPTED, conversationID = conversationEntity1.id ) - assertEquals( - listOf(member1), - memberDAO.observeConversationMembers(conversationEntity1.id).first() - ) - assertEquals( - userDetails.copy(connectionStatus = ConnectionEntity.State.ACCEPTED), - userDAO.observeUserDetailsByQualifiedID(member1.user).firstOrNull() - ) + assertEquals(listOf(member1), memberDAO.observeConversationMembers(conversationEntity1.id).first()) } @Test @@ -172,9 +159,8 @@ class MemberDAOTest : BaseDatabaseTest() { val conversationEntity1 = TestStubs.conversationEntity1 val member1 = TestStubs.member1 - memberDAO.updateOrInsertOneOnOneMemberWithConnectionStatus( + memberDAO.updateOrInsertOneOnOneMember( member = member1, - status = ConnectionEntity.State.ACCEPTED, conversationID = conversationEntity1.id ) @@ -183,26 +169,6 @@ class MemberDAOTest : BaseDatabaseTest() { ) } - @Test - fun givenExistingConversation_ThenUserTableShouldBeUpdatedOnlyAndNotReplaced() = runTest(dispatcher) { - val conversationEntity1 = TestStubs.conversationEntity1 - val user1 = TestStubs.user1 - val member1 = TestStubs.member1 - - conversationDAO.insertConversation(conversationEntity1) - userDAO.insertUser(user1.copy(connectionStatus = ConnectionEntity.State.NOT_CONNECTED)) - - memberDAO.updateOrInsertOneOnOneMemberWithConnectionStatus( - member = member1, - status = ConnectionEntity.State.SENT, - conversationID = conversationEntity1.id - ) - - assertEquals(listOf(member1), memberDAO.observeConversationMembers(conversationEntity1.id).first()) - assertEquals(ConnectionEntity.State.SENT, userDAO.observeUserDetailsByQualifiedID(user1.id).first()?.connectionStatus) - assertEquals(user1.name, userDAO.observeUserDetailsByQualifiedID(user1.id).first()?.name) - } - @Test fun givenConversation_whenInsertingMembers_thenMembersShouldNotBeDuplicated() = runTest { val member1 = TestStubs.member1 @@ -284,9 +250,9 @@ class MemberDAOTest : BaseDatabaseTest() { val member3 = TestStubs.member3 // given conversationDAO.insertConversation(conversationEntity1) - userDAO.insertUser(user1) - userDAO.insertUser(user2) - userDAO.insertUser(user3) + userDAO.upsertUser(user1) + userDAO.upsertUser(user2) + userDAO.upsertUser(user3) memberDAO.insertMember(member1, conversationEntity1.id) memberDAO.insertMember(member2, conversationEntity1.id) memberDAO.insertMember(member3, conversationEntity1.id) @@ -313,9 +279,9 @@ class MemberDAOTest : BaseDatabaseTest() { // given conversationDAO.insertConversation(conversationEntity1) - userDAO.insertUser(user1) - userDAO.insertUser(user2) - userDAO.insertUser(user3) + userDAO.upsertUser(user1) + userDAO.upsertUser(user2) + userDAO.upsertUser(user3) memberDAO.insertMember(member1, conversationEntity1.id) memberDAO.insertMember(member2, conversationEntity1.id) memberDAO.insertMember(member3, conversationEntity1.id) @@ -341,8 +307,8 @@ class MemberDAOTest : BaseDatabaseTest() { conversationDAO.insertConversation(conversationEntity1) conversationDAO.insertConversation(conversationEntity2.copy(hasIncompleteMetadata = true)) - userDAO.insertUser(user1) - userDAO.insertUser(user2) + userDAO.upsertUser(user1) + userDAO.upsertUser(user2) memberDAO.insertMember(member1, conversationEntity1.id) memberDAO.insertMember(member2, conversationEntity1.id) @@ -383,13 +349,13 @@ class MemberDAOTest : BaseDatabaseTest() { // Insert a conversation, user, and a member into the conversation to test the deletion operation val oldMember = MemberEntity(TestStubs.user3.id, MemberEntity.Role.Member) - userDAO.insertUser(TestStubs.user3) + userDAO.upsertUser(TestStubs.user3) conversationDAO.insertConversation(TestStubs.conversationEntity1) memberDAO.insertMember(oldMember, conversationID) // Ensure all new users are inserted before calling updateFullMemberList memberList.forEach { member -> - userDAO.insertUser(TestStubs.user1.copy(id = member.user)) + userDAO.upsertUser(TestStubs.user1.copy(id = member.user)) } // When @@ -420,9 +386,9 @@ class MemberDAOTest : BaseDatabaseTest() { val user2 = TestStubs.user2 val user3 = TestStubs.user3.copy(id = QualifiedIDEntity("3", secondDomain)) - userDAO.insertUser(user1) - userDAO.insertUser(user2) - userDAO.insertUser(user3) + userDAO.upsertUser(user1) + userDAO.upsertUser(user2) + userDAO.upsertUser(user3) conversationDAO.insertConversation(groupConversationEntity) @@ -455,8 +421,8 @@ class MemberDAOTest : BaseDatabaseTest() { val federatedUser = TestStubs.user1.copy(id = QualifiedIDEntity("fedid", federatedDomain)) val otherUser = TestStubs.user1.copy(id = QualifiedIDEntity("other", "other.com")) - userDAO.insertUser(federatedUser) - userDAO.insertUser(otherUser) + userDAO.upsertUser(federatedUser) + userDAO.upsertUser(otherUser) conversationDAO.insertConversation(oneOnOneConversationEntity) conversationDAO.insertConversation(otherOneOnOneConversationEntity) diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/UserClientDAOIntegrationTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/UserClientDAOIntegrationTest.kt index ee8452d8e76..f72e3cc7a9b 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/UserClientDAOIntegrationTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/UserClientDAOIntegrationTest.kt @@ -46,7 +46,7 @@ class UserClientDAOIntegrationTest : BaseDatabaseTest() { @Test fun givenClientsAreInserted_whenDeletingTheUser_thenTheClientsAreDeleted() = runTest { - userDAO.insertUser(user) + userDAO.upsertUser(user) clientDAO.insertClient(insertClientParam) userDAO.deleteUserByQualifiedID(user.id) @@ -77,7 +77,8 @@ class UserClientDAOIntegrationTest : BaseDatabaseTest() { label = null, clientType = null, model = null, - mlsPublicKeys = null + mlsPublicKeys = null, + isMLSCapable = false ) val insertClientParam = InsertClientParam( client.userId, @@ -88,7 +89,8 @@ class UserClientDAOIntegrationTest : BaseDatabaseTest() { client.registrationDate, client.lastActive, client.model, - null + client.mlsPublicKeys, + client.isMLSCapable ) } } diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/UserConversationDAOIntegrationTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/UserConversationDAOIntegrationTest.kt index d9866f5a0b7..3b6fbe1e423 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/UserConversationDAOIntegrationTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/UserConversationDAOIntegrationTest.kt @@ -25,15 +25,15 @@ import com.wire.kalium.persistence.dao.member.MemberDAO import com.wire.kalium.persistence.dao.member.MemberEntity import com.wire.kalium.persistence.utils.stubs.newConversationEntity import com.wire.kalium.persistence.utils.stubs.newUserEntity -import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.flow.first import kotlinx.coroutines.test.runTest import kotlin.test.BeforeTest import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertNull import kotlin.test.assertTrue -@OptIn(ExperimentalCoroutinesApi::class) class UserConversationDAOIntegrationTest : BaseDatabaseTest() { private val user1 = newUserEntity(id = "1") @@ -59,7 +59,7 @@ class UserConversationDAOIntegrationTest : BaseDatabaseTest() { @Test fun givenUserExists_whenInsertingMember_thenOriginalUserDetailsAreKept() = runTest(dispatcher) { - userDAO.insertUser(user1) + userDAO.upsertUser(user1) conversationDAO.insertConversation(conversationEntity1) memberDAO.insertMember(member1, conversationEntity1.id) @@ -251,6 +251,29 @@ class UserConversationDAOIntegrationTest : BaseDatabaseTest() { } } + @Test + fun givenActiveOneOnOneWasSetForConversation_whenFetchingConversationView_thenActiveOneOnOneShouldMatch() = runTest { + userDAO.upsertUser(user1) + conversationDAO.insertConversation(conversationEntity1) + memberDAO.insertMember(member1, conversationEntity1.id) + + userDAO.updateActiveOneOnOneConversation(user1.id, conversationEntity1.id) + + val result = conversationDAO.getConversationByQualifiedID(conversationEntity1.id) + assertEquals(conversationEntity1.id, result?.userActiveOneOnOneConversationId) + } + + @Test + fun givenActiveOneOnOneWasNotSetForConversation_whenFetchingConversationView_thenActiveOneOnOneShouldBeNull() = runTest { + userDAO.upsertUser(user1) + conversationDAO.insertConversation(conversationEntity1) + memberDAO.insertMember(member1, conversationEntity1.id) + + val result = conversationDAO.getConversationByQualifiedID(conversationEntity1.id) + assertNotNull(result) + assertNull(result.userActiveOneOnOneConversationId) + } + private suspend fun createTestConversation(conversationIDEntity: QualifiedIDEntity, members: List) { conversationDAO.insertConversation( newConversationEntity(conversationIDEntity) diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/UserDAOTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/UserDAOTest.kt index b486991391e..bafa6f5898d 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/UserDAOTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/UserDAOTest.kt @@ -22,6 +22,7 @@ import app.cash.turbine.test import com.wire.kalium.persistence.BaseDatabaseTest import com.wire.kalium.persistence.dao.member.MemberEntity import com.wire.kalium.persistence.db.UserDatabaseBuilder +import com.wire.kalium.persistence.utils.stubs.TestStubs import com.wire.kalium.persistence.utils.stubs.newConversationEntity import com.wire.kalium.persistence.utils.stubs.newUserEntity import kotlinx.coroutines.flow.first @@ -54,7 +55,7 @@ class UserDAOTest : BaseDatabaseTest() { @Test fun givenUser_ThenUserCanBeInserted() = runTest(dispatcher) { - db.userDAO.insertUser(user1) + db.userDAO.upsertUser(user1) val result = db.userDAO.observeUserDetailsByQualifiedID(user1.id).first() assertEquals(result?.toSimpleEntity(), user1) } @@ -72,7 +73,7 @@ class UserDAOTest : BaseDatabaseTest() { @Test fun givenExistingUser_ThenUserCanBeDeleted() = runTest(dispatcher) { - db.userDAO.insertUser(user1) + db.userDAO.upsertUser(user1) db.userDAO.deleteUserByQualifiedID(user1.id) val result = db.userDAO.observeUserDetailsByQualifiedID(user1.id).first() assertNull(result) @@ -80,27 +81,15 @@ class UserDAOTest : BaseDatabaseTest() { @Test fun givenExistingUser_ThenUserCanBeUpdated() = runTest(dispatcher) { - db.userDAO.insertUser(user1) - val updatedUser1 = UserEntity( - user1.id, - "John Doe", - "johndoe", - "email1", - "phone1", - 1, - "team", - ConnectionEntity.State.ACCEPTED, - UserAssetIdEntity("asset1", "domain"), - UserAssetIdEntity("asset1", "domain"), - UserAvailabilityStatusEntity.NONE, - UserTypeEntity.STANDARD, - botService = null, - deleted = false, - hasIncompleteMetadata = false, - expiresAt = null, - defederated = false + db.userDAO.upsertUser(user1) + val updatedUser1 = newUserEntity(user1.id).copy( + name = "John Doe", + handle = "johndoe", + email = "email1", + phone = "phone1", + accentId = 1 ) - db.userDAO.updateUser(updatedUser1) + db.userDAO.upsertUser(updatedUser1) val result = db.userDAO.observeUserDetailsByQualifiedID(user1.id).first() assertEquals(result?.toSimpleEntity(), updatedUser1) } @@ -109,32 +98,20 @@ class UserDAOTest : BaseDatabaseTest() { fun givenRetrievedUser_ThenUpdatesArePropagatedThroughFlow() = runTest(dispatcher) { val collectedValues = mutableListOf() - db.userDAO.insertUser(user1) - - val updatedUser1 = UserEntity( - user1.id, - "John Doe", - "johndoe", - "email1", - "phone1", - 1, - "team", - ConnectionEntity.State.ACCEPTED, - null, - null, - UserAvailabilityStatusEntity.NONE, - UserTypeEntity.STANDARD, - botService = null, - false, - hasIncompleteMetadata = false, - expiresAt = null, - defederated = false + db.userDAO.upsertUser(user1) + + val updatedUser1 = newUserEntity(user1.id).copy( + name = "John Doe", + handle = "johndoe", + email = "email1", + phone = "phone1", + accentId = 1 ) db.userDAO.observeUserDetailsByQualifiedID(user1.id).take(2).collect { collectedValues.add(it?.toSimpleEntity()) if (collectedValues.size == 1) { - db.userDAO.updateUser(updatedUser1) + db.userDAO.upsertUser(updatedUser1) } } assertEquals(user1, collectedValues[0]) @@ -144,7 +121,7 @@ class UserDAOTest : BaseDatabaseTest() { @Test fun givenExistingUser_WhenUpdateUserHandle_ThenUserHandleIsUpdated() = runTest(dispatcher) { // given - db.userDAO.insertUser(user1) + db.userDAO.upsertUser(user1) val updatedHandle = "new-handle" // when @@ -241,44 +218,14 @@ class UserDAOTest : BaseDatabaseTest() { USER_ENTITY_3.copy(email = commonEmailPrefix + "u3@example.org") ) val notCommonEmailUsers = listOf( - UserEntity( - id = QualifiedIDEntity("4", "wire.com"), - name = "testName4", - handle = "testHandle4", - email = "someDifferentEmail1@wire.com", - phone = "testPhone4", - accentId = 4, - team = "testTeam4", - ConnectionEntity.State.ACCEPTED, - null, - null, - UserAvailabilityStatusEntity.NONE, - UserTypeEntity.STANDARD, - botService = null, - false, - hasIncompleteMetadata = false, - expiresAt = null, - defederated = false - ), - UserEntity( - id = QualifiedIDEntity("5", "wire.com"), - name = "testName5", - handle = "testHandle5", - email = "someDifferentEmail2@wire.com", - phone = "testPhone5", - accentId = 5, - team = "testTeam5", - ConnectionEntity.State.ACCEPTED, - null, - null, - UserAvailabilityStatusEntity.NONE, - UserTypeEntity.STANDARD, - botService = null, - deleted = false, - hasIncompleteMetadata = false, - expiresAt = null, - defederated = false - ) + newUserEntity(QualifiedIDEntity("4", "wire.com")) + .copy( + email = "someDifferentEmail1@wire.com", + ), + newUserEntity(QualifiedIDEntity("5", "wire.com")) + .copy( + email = "someDifferentEmail2@wire.com" + ) ) val mockUsers = commonEmailUsers + notCommonEmailUsers @@ -536,7 +483,7 @@ class UserDAOTest : BaseDatabaseTest() { fun givenAExistingUsers_whenUpdatingTheirValuesAndRecordNotExists_ThenResultsOneUpdatedAnotherInserted() = runTest(dispatcher) { // given val newNameA = "new user naming a" - db.userDAO.insertUser(user1) + db.userDAO.upsertUser(user1) // when val updatedUser1 = user1.copy(name = newNameA) db.userDAO.upsertUsers(listOf(updatedUser1, user2)) @@ -548,30 +495,28 @@ class UserDAOTest : BaseDatabaseTest() { } @Test - fun givenAExistingUsers_whenUpsertingTeamMembers_ThenResultsOneUpdatedAnotherInserted() = runTest(dispatcher) { + fun givenAExistingUsers_whenUpsertingTeamMembersUserTypes_ThenUserTypeIsUpdated() = runTest(dispatcher) { // given - val newTeamId = "new user team id" - db.userDAO.insertUser(user1) + val newUserType = UserTypeEntity.ADMIN + db.userDAO.upsertUser(user1) // when - val updatedUser1 = user1.copy(team = newTeamId) - db.userDAO.upsertTeamMembersTypes(listOf(updatedUser1, user2)) + db.userDAO.upsertTeamMemberUserTypes(mapOf(user1.id to newUserType)) // then - val updated1 = db.userDAO.observeUserDetailsByQualifiedID(updatedUser1.id) - val inserted2 = db.userDAO.observeUserDetailsByQualifiedID(user2.id) - assertEquals(newTeamId, updated1.first()?.team) - assertNotNull(inserted2) + val updated = db.userDAO.observeUserDetailsByQualifiedID(user1.id) + assertEquals(newUserType, updated.first()?.userType) + assertEquals(ConnectionEntity.State.ACCEPTED, updated.first()?.connectionStatus) } @Test - fun givenATeamMember_whenUpsertingTeamMember_ThenUserTypeShouldStayTheSame() = runTest(dispatcher) { + fun givenNotExistingUsers_whenUpsertingTeamMembersUserTypes_ThenUserIsInsertedWithCorrectUserType() = runTest(dispatcher) { // given - val externalMember = user1.copy(userType = UserTypeEntity.EXTERNAL) - db.userDAO.upsertTeamMembersTypes(listOf(externalMember)) + val newUserType = UserTypeEntity.ADMIN // when - db.userDAO.upsertTeamMembers(listOf(user1)) + db.userDAO.upsertTeamMemberUserTypes(mapOf(user1.id to newUserType)) // then - val updated1 = db.userDAO.observeUserDetailsByQualifiedID(user1.id) - assertEquals(UserTypeEntity.EXTERNAL, updated1.first()?.userType) + val inserted = db.userDAO.observeUserDetailsByQualifiedID(user1.id) + assertEquals(newUserType, inserted.first()?.userType) + assertEquals(ConnectionEntity.State.ACCEPTED, inserted.first()?.connectionStatus) } @Test @@ -593,7 +538,7 @@ class UserDAOTest : BaseDatabaseTest() { } // when - db.userDAO.upsertTeamMembers(listOf(teamMember)) + db.userDAO.upsertUsers(listOf(teamMember)) // then db.userDAO.getAllUsersDetails().first().also { @@ -609,7 +554,7 @@ class UserDAOTest : BaseDatabaseTest() { fun givenAExistingUsers_whenUpsertingUsers_ThenResultsOneUpdatedAnotherInsertedWithNoConnectionStatusOverride() = runTest(dispatcher) { // given val newTeamId = "new team id" - db.userDAO.insertUser(user1.copy(connectionStatus = ConnectionEntity.State.ACCEPTED)) + db.userDAO.upsertUser(user1.copy(connectionStatus = ConnectionEntity.State.ACCEPTED)) // when val updatedUser1 = user1.copy(team = newTeamId) db.userDAO.upsertUsers(listOf(updatedUser1, user2)) @@ -634,7 +579,7 @@ class UserDAOTest : BaseDatabaseTest() { @Test fun givenUser_WhenMarkingAsDeleted_ThenProperValueShouldBeUpdated() = runTest(dispatcher) { val user = user1 - db.userDAO.insertUser(user) + db.userDAO.upsertUser(user) val deletedUser = user1.copy(deleted = true, team = null, userType = UserTypeEntity.NONE) db.userDAO.markUserAsDeleted(user1.id) val result = db.userDAO.observeUserDetailsByQualifiedID(user1.id).first() @@ -660,7 +605,7 @@ class UserDAOTest : BaseDatabaseTest() { val existingUser = user1 val usersToInsert = listOf(user1.copy(name = "other name to make sure this one wasn't inserted nor edited"), user2) val expected = listOf(user1, user2) - db.userDAO.insertUser(existingUser) + db.userDAO.upsertUser(existingUser) // when db.userDAO.insertOrIgnoreUsers(usersToInsert) // then @@ -672,7 +617,7 @@ class UserDAOTest : BaseDatabaseTest() { fun givenAnExistingUser_whenUpdatingTheDisplayName_thenTheValueShouldBeUpdated() = runTest(dispatcher) { // given val expectedNewDisplayName = "new user display name" - db.userDAO.insertUser(user1) + db.userDAO.upsertUser(user1) // when db.userDAO.updateUserDisplayName(user1.id, expectedNewDisplayName) @@ -685,7 +630,7 @@ class UserDAOTest : BaseDatabaseTest() { @Test fun givenExistingUserWithoutMetadata_whenQueryingThem_thenShouldReturnUsersWithoutMetadata() = runTest(dispatcher) { // given - db.userDAO.insertUser(user1.copy(name = null, handle = null, hasIncompleteMetadata = true)) + db.userDAO.upsertUser(user1.copy(name = null, handle = null, hasIncompleteMetadata = true)) // when val usersWithoutMetadata = db.userDAO.getUsersDetailsWithoutMetadata() @@ -698,7 +643,7 @@ class UserDAOTest : BaseDatabaseTest() { @Test fun givenExistingUser_WhenRemoveUserAsset_ThenUserAssetIsRemoved() = runTest(dispatcher) { // given - db.userDAO.insertUser(user1) + db.userDAO.upsertUser(user1) val assetId = UserAssetIdEntity("asset1", "domain") val updatedUser1 = user1.copy(previewAssetId = assetId) @@ -730,10 +675,10 @@ class UserDAOTest : BaseDatabaseTest() { val user2 = newUserEntity().copy(id = UserIDEntity("user-2", "domain-2")) val user3 = newUserEntity().copy(id = UserIDEntity("user-3", "domain-1")) - db.userDAO.insertUser(selfUser) - db.userDAO.insertUser(user1) - db.userDAO.insertUser(user2) - db.userDAO.insertUser(user3) + db.userDAO.upsertUser(selfUser) + db.userDAO.upsertUser(user1) + db.userDAO.upsertUser(user2) + db.userDAO.upsertUser(user3) db.userDAO.allOtherUsersId().also { result -> assertFalse { @@ -745,23 +690,79 @@ class UserDAOTest : BaseDatabaseTest() { @Test fun givenExistingUser_ThenUserCanBeDefederated() = runTest(dispatcher) { - db.userDAO.insertUser(user1) + db.userDAO.upsertUser(user1) db.userDAO.markUserAsDefederated(user1.id) val result = db.userDAO.observeUserDetailsByQualifiedID(user1.id).first() assertNotNull(result) assertEquals(true, result.defederated) } + @Test + fun givenAnExistingUser_whenUpdatingTheSupportedProtocols_thenTheValueShouldBeUpdated() = runTest(dispatcher) { + // given + val expectedNewSupportedProtocols = setOf(SupportedProtocolEntity.PROTEUS, SupportedProtocolEntity.MLS) + db.userDAO.upsertUser(user1) + + // when + db.userDAO.updateUserSupportedProtocols(user1.id, expectedNewSupportedProtocols) + + // then + val persistedUser = db.userDAO.observeUserDetailsByQualifiedID(user1.id).first() + assertEquals(expectedNewSupportedProtocols, persistedUser?.supportedProtocols) + } @Test fun givenExistingUserIsDefederated_ThenUserCanBeRefederatedAfterUpdate() = runTest(dispatcher) { - db.userDAO.insertUser(user1) + db.userDAO.upsertUser(user1) db.userDAO.markUserAsDefederated(user1.id) - db.userDAO.insertUser(user1) + db.userDAO.upsertUser(user1) val result = db.userDAO.observeUserDetailsByQualifiedID(user1.id).first() assertNotNull(result) assertEquals(false, result.defederated) } + @Test + fun givenAnExistingUser_WhenUpdatingOneOnOneConversationId_ThenItIsUpdated() = runTest(dispatcher) { + // given + val expectedNewOneOnOneConversationId = TestStubs.conversationEntity1.id + db.userDAO.upsertUser(user1) + + // when + db.userDAO.updateActiveOneOnOneConversation(user1.id, expectedNewOneOnOneConversationId) + + // then + val persistedUser = db.userDAO.observeUserDetailsByQualifiedID(user1.id).first() + assertEquals(expectedNewOneOnOneConversationId, persistedUser?.activeOneOnOneConversationId) + } + + @Test + fun givenAnExistingUser_whenPerformingPartialUpdate_thenChangedFieldIsUpdatedOthersAreUnchanged() = runTest(dispatcher) { + // given + val expectedName = "new name" + val update = PartialUserEntity( + name = expectedName, + handle = null, + email = null, + accentId = null, + previewAssetId = null, + completeAssetId = null, + supportedProtocols = null + ) + db.userDAO.upsertUser(user1) + + // when + db.userDAO.updateUser(user1.id, update) + + // then + val persistedUser = db.userDAO.observeUserDetailsByQualifiedID(user1.id).first() + assertEquals(expectedName, persistedUser?.name) + assertEquals(user1.handle, persistedUser?.handle) + assertEquals(user1.email, persistedUser?.email) + assertEquals(user1.accentId, persistedUser?.accentId) + assertEquals(user1.previewAssetId, persistedUser?.previewAssetId) + assertEquals(user1.completeAssetId, persistedUser?.completeAssetId) + assertEquals(user1.supportedProtocols, persistedUser?.supportedProtocols) + } + private companion object { val USER_ENTITY_1 = newUserEntity(QualifiedIDEntity("1", "wire.com")) val USER_ENTITY_2 = newUserEntity(QualifiedIDEntity("2", "wire.com")) diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/client/ClientDAOTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/client/ClientDAOTest.kt index 90f70094237..2f3c9e2be37 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/client/ClientDAOTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/client/ClientDAOTest.kt @@ -66,11 +66,10 @@ class ClientDAOTest : BaseDatabaseTest() { @Test fun givenClientIsInserted_whenFetchingClientsByUserId_thenTheRelevantClientIsReturned() = runTest { + val insertedClient = insertedClient1.copy(user.id, "id1", deviceType = null, isMLSCapable = true) + val expected = client1.copy(user.id, "id1", deviceType = null, isValid = true, isProteusVerified = false, isMLSCapable = true) - val insertedClient = insertedClient1.copy(user.id, "id1", deviceType = null) - val expected = client1.copy(user.id, "id1", deviceType = null, isValid = true, isProteusVerified = false) - - userDAO.insertUser(user) + userDAO.upsertUser(user) clientDAO.insertClient(insertedClient) val result = clientDAO.getClientsOfUserByQualifiedIDFlow(userId).first() @@ -88,7 +87,7 @@ class ClientDAOTest : BaseDatabaseTest() { val insertedClient2 = insertedClient2.copy(user.id, "id2", deviceType = null) val client2 = insertedClient2.toClient() - userDAO.insertUser(user) + userDAO.upsertUser(user) clientDAO.insertClients(listOf(insertedClient, insertedClient2)) val result = clientDAO.getClientsOfUserByQualifiedIDFlow(userId).first() @@ -101,14 +100,14 @@ class ClientDAOTest : BaseDatabaseTest() { @Test fun givenClientsAreInsertedForMultipleUsers_whenFetchingClientsByUserId_thenOnlyTheRelevantClientsAreReturned() = runTest { - userDAO.insertUser(user) + userDAO.upsertUser(user) clientDAO.insertClients(listOf(insertedClient1, insertedClient2)) val unrelatedUserId = QualifiedIDEntity("unrelated", "user") val unrelatedUser = newUserEntity(unrelatedUserId) val unrelatedInsertedClient = insertedClient1.copy(unrelatedUserId, "id1", deviceType = null) - userDAO.insertUser(unrelatedUser) + userDAO.upsertUser(unrelatedUser) clientDAO.insertClient(unrelatedInsertedClient) val result = clientDAO.getClientsOfUserByQualifiedIDFlow(userId).first() @@ -119,7 +118,7 @@ class ClientDAOTest : BaseDatabaseTest() { @Test fun givenClientIsInserted_whenDeletingItSpecifically_thenItShouldNotBeReturnedAnymoreOnNextFetch() = runTest { - userDAO.insertUser(user) + userDAO.upsertUser(user) clientDAO.insertClient(insertedClient1) clientDAO.deleteClient(insertedClient1.userId, insertedClient1.id) @@ -130,7 +129,7 @@ class ClientDAOTest : BaseDatabaseTest() { @Test fun givenClientsAreInserted_whenDeletingClientsOfUser_thenTheyShouldNotBeReturnedAnymoreOnNextFetch() = runTest { - userDAO.insertUser(user) + userDAO.upsertUser(user) clientDAO.insertClients(listOf(insertedClient1, insertedClient2)) clientDAO.deleteClientsOfUserByQualifiedID(insertedClient1.userId) @@ -146,7 +145,7 @@ class ClientDAOTest : BaseDatabaseTest() { val insertClientWithNullType = insertClientWithType.copy(deviceType = null) - userDAO.insertUser(user) + userDAO.upsertUser(user) clientDAO.insertClients(listOf(insertClientWithType)) clientDAO.getClientsOfUserByQualifiedIDFlow(userId).first().also { resultList -> assertEquals(listOf(clientWithType), resultList) @@ -162,7 +161,7 @@ class ClientDAOTest : BaseDatabaseTest() { @Test fun givenClientIsInsertedAndRemoveRedundant_whenFetchingClientsByUserId_thenTheRelevantClientsAreReturned() = runTest { - userDAO.insertUser(user) + userDAO.upsertUser(user) clientDAO.insertClientsAndRemoveRedundant(listOf(insertedClient1, insertedClient2)) val result = clientDAO.getClientsOfUserByQualifiedID(userId) @@ -174,7 +173,7 @@ class ClientDAOTest : BaseDatabaseTest() { @Test fun givenClientIsInsertedAndRemoveRedundant_whenFetchingClientsByUserId_thenTheRedundantClientsAreNotReturned() = runTest { - userDAO.insertUser(user) + userDAO.upsertUser(user) clientDAO.insertClientsAndRemoveRedundant(listOf(insertedClient, insertedClient1)) // this supposes to remove insertedClient1 clientDAO.insertClientsAndRemoveRedundant(listOf(insertedClient, insertedClient2)) @@ -185,10 +184,36 @@ class ClientDAOTest : BaseDatabaseTest() { assertEquals(listOf(client, client2), result) } + @Test + fun givenIsMLSCapableIsFalse_whenUpdatingAClient_thenItShouldUpdatedToTrue() = runTest { + val user = user + userDAO.upsertUser(user) + clientDAO.insertClient(insertedClient.copy( + isMLSCapable = false + )) + clientDAO.insertClient(insertedClient.copy( + isMLSCapable = true + )) + assertTrue { clientDAO.getClientsOfUserByQualifiedID(userId).first().isMLSCapable } + } + + @Test + fun givenIsMLSCapableIsTrue_whenUpdatingAClient_thenItShouldRemainTrue() = runTest { + val user = user + userDAO.upsertUser(user) + clientDAO.insertClient(insertedClient.copy( + isMLSCapable = true + )) + clientDAO.insertClient(insertedClient.copy( + isMLSCapable = false + )) + assertTrue { clientDAO.getClientsOfUserByQualifiedID(userId).first().isMLSCapable } + } + @Test fun whenInsertingANewClient_thenIsMustBeMarkedAsValid() = runTest { val user = user - userDAO.insertUser(user) + userDAO.upsertUser(user) clientDAO.insertClient(insertedClient) assertTrue { clientDAO.getClientsOfUserByQualifiedID(userId).first().isValid } } @@ -196,7 +221,7 @@ class ClientDAOTest : BaseDatabaseTest() { @Test fun givenValidClient_whenMarkingAsInvalid_thenClientInfoIsUpdated() = runTest { val user = user - userDAO.insertUser(user) + userDAO.upsertUser(user) clientDAO.insertClient(insertedClient) clientDAO.tryMarkInvalid(listOf(insertedClient.userId to listOf(insertedClient.id))) assertFalse { clientDAO.getClientsOfUserByQualifiedID(userId).first().isValid } @@ -205,7 +230,7 @@ class ClientDAOTest : BaseDatabaseTest() { @Test fun whenClientIsInsertedTwice_thenIvValidMustNotBeChanged() = runTest { val user = user - userDAO.insertUser(user) + userDAO.upsertUser(user) clientDAO.insertClient(insertedClient) clientDAO.tryMarkInvalid(listOf(insertedClient.userId to listOf(insertedClient.id))) clientDAO.insertClient(insertedClient) @@ -216,7 +241,7 @@ class ClientDAOTest : BaseDatabaseTest() { fun givenInvalidUserClient_whenSelectingConversationRecipients_thenOnlyValidClientAreReturned() = runTest { val user = user val expected: Map> = mapOf(user.id to listOf(client1, client2)) - userDAO.insertUser(user) + userDAO.upsertUser(user) clientDAO.insertClient(insertedClient) clientDAO.insertClient(insertedClient1) clientDAO.insertClient(insertedClient2) @@ -230,7 +255,7 @@ class ClientDAOTest : BaseDatabaseTest() { @Test fun givenNewClientAdded_thenItIsMarkedAsNotVerified() = runTest { val user = user - userDAO.insertUser(user) + userDAO.upsertUser(user) clientDAO.insertClient(insertedClient) assertFalse { clientDAO.getClientsOfUserByQualifiedID(userId).first().isProteusVerified } } @@ -238,7 +263,7 @@ class ClientDAOTest : BaseDatabaseTest() { @Test fun givenClient_whenUpdatingVerificationStatus_thenItIsUpdated() = runTest { val user = user - userDAO.insertUser(user) + userDAO.upsertUser(user) clientDAO.insertClient(insertedClient) clientDAO.updateClientProteusVerificationStatus(user.id, insertedClient.id, true) assertTrue { clientDAO.getClientsOfUserByQualifiedID(userId).first().isProteusVerified } @@ -250,7 +275,7 @@ class ClientDAOTest : BaseDatabaseTest() { @Test fun givenUserId_whenAClientIsAdded_thenNewListIsEmitted() = runTest { val user = user - userDAO.insertUser(user) + userDAO.upsertUser(user) clientDAO.observeClientsByUserId(user.id).test { awaitItem().also { result -> assertEquals(emptyList(), result) } @@ -267,7 +292,7 @@ class ClientDAOTest : BaseDatabaseTest() { @Test fun givenVerifiedClient_whenInsertingTheSameIdAgain_thenVerificationStatusIsNotChanges() = runTest { val user = user - userDAO.insertUser(user) + userDAO.upsertUser(user) clientDAO.insertClient(insertedClient) assertFalse { clientDAO.getClientsOfUserByQualifiedID(userId).first().isProteusVerified } @@ -282,7 +307,7 @@ class ClientDAOTest : BaseDatabaseTest() { @Test fun givenUserIsPartOfConversation_whenGettingRecipient_thenOnlyValidUserClientsAreReturned() = runTest { val user = user - userDAO.insertUser(user) + userDAO.upsertUser(user) conversationDAO.insertConversation(conversationEntity1) memberDAO.insertMember(MemberEntity(user.id, MemberEntity.Role.Admin), conversationEntity1.id) @@ -300,13 +325,13 @@ class ClientDAOTest : BaseDatabaseTest() { @Test fun givenUserIsNotPartOfConversation_whenGettingRecipient_thenTheyAreNotIncludedInTheResult() = runTest { val user = user - userDAO.insertUser(user) + userDAO.upsertUser(user) clientDAO.insertClient(insertedClient) conversationDAO.insertConversation(conversationEntity1) memberDAO.insertMember(MemberEntity(user.id, MemberEntity.Role.Admin), conversationEntity1.id) val user2 = newUserEntity(QualifiedIDEntity("test2", "domain")) - userDAO.insertUser(user2) + userDAO.upsertUser(user2) val insertedClient2 = InsertClientParam( userId = user2.id, id = "id01", @@ -316,7 +341,8 @@ class ClientDAOTest : BaseDatabaseTest() { model = null, registrationDate = null, lastActive = null, - mlsPublicKeys = null + mlsPublicKeys = null, + isMLSCapable = false ) clientDAO.insertClient(insertedClient2) @@ -340,7 +366,8 @@ class ClientDAOTest : BaseDatabaseTest() { model = null, registrationDate = null, lastActive = null, - mlsPublicKeys = null + mlsPublicKeys = null, + isMLSCapable = false ) val client = insertedClient.toClient() @@ -382,6 +409,7 @@ private fun InsertClientParam.toClient(): Client = clientType = clientType, isValid = true, isProteusVerified = false, + isMLSCapable = false, label = label, model = model, registrationDate = registrationDate, diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/CompositeMessageTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/CompositeMessageTest.kt index e11bddaad46..5c7bb4ef7e6 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/CompositeMessageTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/CompositeMessageTest.kt @@ -58,7 +58,7 @@ class CompositeMessageTest : BaseDatabaseTest() { fun givenSuccess_whenInsertingCompositeMessage_thenMessageCanBeRetrieved() = runTest { val conversation = conversationEntity1 conversationDAO.insertConversation(conversation) - userDAO.insertUser(userEntity1) + userDAO.upsertUser(userEntity1) val compositeMessage = newRegularMessageEntity().copy( senderUserId = userEntity1.id, @@ -87,7 +87,7 @@ class CompositeMessageTest : BaseDatabaseTest() { fun givenCompositeMessage_whenMarkingButtonAsSelected_thenOnlyOneItIsMarked() = runTest { val conversation = conversationEntity1 conversationDAO.insertConversation(conversation) - userDAO.insertUser(userEntity1) + userDAO.upsertUser(userEntity1) val compositeMessage = newRegularMessageEntity().copy( senderUserId = userEntity1.id, @@ -124,7 +124,7 @@ class CompositeMessageTest : BaseDatabaseTest() { fun givenCompositeMessageWithSelection_whenResetSelection_thenSelectionIsFalse() = runTest { val conversation = conversationEntity1 conversationDAO.insertConversation(conversation) - userDAO.insertUser(userEntity1) + userDAO.upsertUser(userEntity1) val compositeMessage = newRegularMessageEntity().copy( senderUserId = userEntity1.id, diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/MessageDAOBenchmarkTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/MessageDAOBenchmarkTest.kt index a8b39d9fb85..add186045ef 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/MessageDAOBenchmarkTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/MessageDAOBenchmarkTest.kt @@ -176,8 +176,8 @@ class MessageDAOBenchmarkTest : BaseDatabaseTest() { private suspend fun setupData() { conversationDAO.insertConversations(listOf(conversationEntity1)) - userDAO.insertUser(userEntity1) - userDAO.insertUser(userEntity2) + userDAO.upsertUser(userEntity1) + userDAO.upsertUser(userEntity2) } private companion object { diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/MessageDAOTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/MessageDAOTest.kt index d55460c66b1..03e554265b8 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/MessageDAOTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/MessageDAOTest.kt @@ -401,7 +401,7 @@ class MessageDAOTest : BaseDatabaseTest() { ) ) - userDAO.insertUser(userEntity1) + userDAO.upsertUser(userEntity1) messageDAO.insertOrIgnoreMessages( listOf( @@ -450,7 +450,7 @@ class MessageDAOTest : BaseDatabaseTest() { ) ) - userDAO.insertUser(userEntity1) + userDAO.upsertUser(userEntity1) messageDAO.insertOrIgnoreMessages( listOf( @@ -484,7 +484,7 @@ class MessageDAOTest : BaseDatabaseTest() { ) ) - userDAO.insertUser(userEntity1) + userDAO.upsertUser(userEntity1) val message = buildList { // add 9 Message before the lastReadDate @@ -520,7 +520,7 @@ class MessageDAOTest : BaseDatabaseTest() { ) ) - userDAO.insertUser(userEntity1) + userDAO.upsertUser(userEntity1) val readMessagesCount = 3 val unreadMessagesCount = 2 @@ -570,7 +570,7 @@ class MessageDAOTest : BaseDatabaseTest() { ) ) - userDAO.insertUser(userEntity1) + userDAO.upsertUser(userEntity1) val readMessagesCount = 3 val unreadMessagesCount = 2 @@ -622,7 +622,7 @@ class MessageDAOTest : BaseDatabaseTest() { ) ) - userDAO.insertUser(userEntity1) + userDAO.upsertUser(userEntity1) messageDAO.insertOrIgnoreMessages( listOf( @@ -655,7 +655,7 @@ class MessageDAOTest : BaseDatabaseTest() { ) ) - userDAO.insertUser(userEntity1) + userDAO.upsertUser(userEntity1) val unreadMessagesCount = 2 val message = buildList { @@ -796,7 +796,7 @@ class MessageDAOTest : BaseDatabaseTest() { newConversationEntity(id = conversationId2) ) - userDAO.insertUser(userEntity1) + userDAO.upsertUser(userEntity1) messageDAO.insertOrIgnoreMessages( listOf( newRegularMessageEntity( @@ -876,7 +876,7 @@ class MessageDAOTest : BaseDatabaseTest() { lastReadDate = "2000-01-01T12:00:00.000Z".toInstant() ) ) - userDAO.insertUser(userEntity1) + userDAO.upsertUser(userEntity1) messageDAO.insertOrIgnoreMessages(listOf(previewAssetMessage)) // when @@ -937,7 +937,7 @@ class MessageDAOTest : BaseDatabaseTest() { lastReadDate = "2000-01-01T12:00:00.000Z".toInstant() ) ) - userDAO.insertUser(userEntity1) + userDAO.upsertUser(userEntity1) messageDAO.insertOrIgnoreMessages(listOf(previewAssetMessage)) // when @@ -997,7 +997,7 @@ class MessageDAOTest : BaseDatabaseTest() { lastReadDate = "2000-01-01T12:00:00.000Z".toInstant() ) ) - userDAO.insertUser(userEntity1) + userDAO.upsertUser(userEntity1) messageDAO.insertOrIgnoreMessages(listOf(previewAssetMessage)) // when @@ -1087,7 +1087,7 @@ class MessageDAOTest : BaseDatabaseTest() { lastReadDate = "2000-01-01T12:00:00.000Z".toInstant() ) ) - userDAO.insertUser(userEntity1) + userDAO.upsertUser(userEntity1) messageDAO.insertOrIgnoreMessage(initialAssetMessage) // when @@ -1126,7 +1126,7 @@ class MessageDAOTest : BaseDatabaseTest() { lastReadDate = "2000-01-01T12:00:00.000Z".toInstant(), ) ) - userDAO.insertUser(userEntity1) + userDAO.upsertUser(userEntity1) val message1 = newRegularMessageEntity( id = messageId, @@ -1160,8 +1160,8 @@ class MessageDAOTest : BaseDatabaseTest() { lastReadDate = "2000-01-01T12:00:00.000Z".toInstant(), ) ) - userDAO.insertUser(userEntity1) - userDAO.insertUser(userEntity2) + userDAO.upsertUser(userEntity1) + userDAO.upsertUser(userEntity2) val messageFromUser1 = newRegularMessageEntity( id = messageId, @@ -1318,7 +1318,7 @@ class MessageDAOTest : BaseDatabaseTest() { val conversationId = QualifiedIDEntity("1", "someDomain") val messageId = "ConversationReceiptModeChanged Message" conversationDAO.insertConversation(newConversationEntity(id = conversationId)) - userDAO.insertUser(userEntity1) + userDAO.upsertUser(userEntity1) // when messageDAO.insertOrIgnoreMessages( @@ -1354,7 +1354,7 @@ class MessageDAOTest : BaseDatabaseTest() { conversationDAO.insertConversation(newConversationEntity(id = conversationId2)) val messageId = "systemMessage" - userDAO.insertUser(userEntity1) + userDAO.upsertUser(userEntity1) // when messageDAO.persistSystemMessageToAllConversations( @@ -1410,7 +1410,7 @@ class MessageDAOTest : BaseDatabaseTest() { ) ) val messageId = "systemMessage" - userDAO.insertUser(userEntity1) + userDAO.upsertUser(userEntity1) // when messageDAO.persistSystemMessageToAllConversations( @@ -1516,8 +1516,8 @@ class MessageDAOTest : BaseDatabaseTest() { val conversationId = QualifiedIDEntity("1", "someDomain") val messageId = "Conversation MessageSent With Partial Success" conversationDAO.insertConversation(newConversationEntity(id = conversationId)) - userDAO.insertUser(userEntity1) - userDAO.insertUser(userEntity2) + userDAO.upsertUser(userEntity1) + userDAO.upsertUser(userEntity2) messageDAO.insertOrIgnoreMessages( listOf( @@ -1638,6 +1638,129 @@ class MessageDAOTest : BaseDatabaseTest() { assertTrue(result.readCount == 0L) } + @Test + fun givenExistingMessagesAtSource_whenMovingMessages_thenMessagesAreAccessibleAtDestination() = runTest { + // given + val source = conversationEntity1 + val destination = conversationEntity2 + userDAO.upsertUsers(listOf(userEntity1, userEntity2)) + conversationDAO.insertConversation(source) + conversationDAO.insertConversation(destination) + + val allMessages = listOf( + newRegularMessageEntity( + id = "1", + senderUserId = userEntity1.id, + conversationId = source.id, + content = MessageEntityContent.Text(messageBody = "Message 1") + ), + newRegularMessageEntity( + id = "2", + senderUserId = userEntity1.id, + conversationId = source.id, + content = MessageEntityContent.Text(messageBody = "Message 2") + ) + ) + messageDAO.insertOrIgnoreMessages(allMessages) + + // when + messageDAO.moveMessages(source.id, destination.id) + + // then + val retrievedMessages = messageDAO.getMessagesByConversationAndVisibility( + destination.id, + 10, + 0, + listOf(MessageEntity.Visibility.VISIBLE) + ).first() + + assertEquals( + allMessages.map { it.content }.toSet(), + retrievedMessages.map { it.content }.toSet()) + } + + @Test + fun givenExistingMessagesAtSourceAndDestination_whenMovingMessages_thenMessagesAreAccessibleAtDestination() = runTest { + // given + val source = conversationEntity1 + val destination = conversationEntity2 + userDAO.upsertUsers(listOf(userEntity1, userEntity2)) + conversationDAO.insertConversation(source) + conversationDAO.insertConversation(destination) + + val allMessages = listOf( + newRegularMessageEntity( + id = "1", + senderUserId = userEntity1.id, + conversationId = source.id, + content = MessageEntityContent.Text(messageBody = "Message 1") + ), + newRegularMessageEntity( + id = "2", + senderUserId = userEntity1.id, + conversationId = destination.id, + content = MessageEntityContent.Text(messageBody = "Message 2") + ) + ) + messageDAO.insertOrIgnoreMessages(allMessages) + + // when + messageDAO.moveMessages(source.id, destination.id) + + // then + val retrievedMessages = messageDAO.getMessagesByConversationAndVisibility( + destination.id, + 10, + 0, + listOf(MessageEntity.Visibility.VISIBLE) + ).first() + + assertEquals( + allMessages.map { it.content }.toSet(), + retrievedMessages.map { it.content }.toSet()) + } + + @Test + fun givenNoExistingMessagesAtSource_whenMovingMessages_thenExistingMessagesAreAccessibleAtDestination() = runTest { + // given + val source = conversationEntity1 + val destination = conversationEntity2 + userDAO.upsertUsers(listOf(userEntity1, userEntity2)) + conversationDAO.insertConversation(source) + conversationDAO.insertConversation(destination) + + val allMessages = listOf( + newRegularMessageEntity( + id = "1", + senderUserId = userEntity1.id, + conversationId = destination.id, + content = MessageEntityContent.Text(messageBody = "Message 1") + ), + newRegularMessageEntity( + id = "2", + senderUserId = userEntity1.id, + conversationId = destination.id, + content = MessageEntityContent.Text(messageBody = "Message 2") + ) + ) + messageDAO.insertOrIgnoreMessages(allMessages) + + // when + messageDAO.moveMessages(source.id, destination.id) + + // then + val retrievedMessages = messageDAO.getMessagesByConversationAndVisibility( + destination.id, + 10, + 0, + listOf(MessageEntity.Visibility.VISIBLE) + ).first() + + assertEquals( + allMessages.map { it.content }.toSet(), + retrievedMessages.map { it.content }.toSet()) + } + private suspend fun insertInitialData() { userDAO.upsertUsers(listOf(userEntity1, userEntity2)) conversationDAO.insertConversation( diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/MessageInsertExtensionTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/MessageInsertExtensionTest.kt index a585abe557f..ba21ea62245 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/MessageInsertExtensionTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/MessageInsertExtensionTest.kt @@ -63,7 +63,7 @@ class MessageInsertExtensionTest : BaseDatabaseTest() { @Test fun givenDeletedAssetMessage_whenUpdateUploadStatus_thenFail() = runTest { conversationDAO.insertConversation(conversationEntity1) - userDAO.insertUser(userEntity1) + userDAO.upsertUser(userEntity1) val assetMessage = newRegularMessageEntity( id = "messageId", date = Instant.DISTANT_PAST, diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/MessageMapperTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/MessageMapperTest.kt index 7dc56ebafe0..982682e8f77 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/MessageMapperTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/MessageMapperTest.kt @@ -159,7 +159,8 @@ class MessageMapperTest { recipientsFailedDeliveryList: List? = null, buttonsJson: String = "[]", federationDomainList: List? = null, - federationType: MessageEntity.FederationType? = null + federationType: MessageEntity.FederationType? = null, + conversationProtocolChanged: ConversationEntity.Protocol? = null ): MessageEntity { return MessageMapper.toEntityMessageFromView( id, @@ -238,7 +239,8 @@ class MessageMapperTest { recipientsFailedDeliveryList, buttonsJson, federationDomainList, - federationType + federationType, + conversationProtocolChanged ) } diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/MessageMetadataDAOTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/MessageMetadataDAOTest.kt index b28f1a38055..71efc8f570b 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/MessageMetadataDAOTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/message/MessageMetadataDAOTest.kt @@ -60,7 +60,7 @@ class MessageMetadataDAOTest : BaseDatabaseTest() { val originalUser = userEntity1 conversationDAO.insertConversation(conversationEntity1) - userDAO.insertUser(originalUser) + userDAO.upsertUser(originalUser) val originalMessage = newRegularMessageEntity( id = messageId, diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/newclient/NewClientDAOTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/newclient/NewClientDAOTest.kt index fb11bb83758..dbff72fbad3 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/newclient/NewClientDAOTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/newclient/NewClientDAOTest.kt @@ -44,7 +44,7 @@ class NewClientDAOTest: BaseDatabaseTest() { fun whenANewClientsIsAdded_thenNewClientListIsEmitted() = runTest { newClientDAO.observeNewClients().test { awaitItem().also { result -> assertEquals(emptyList(), result) } - newClientDAO.insertNewClient(insertedClient) + newClientDAO.insertNewClient(insertedClient1) awaitItem().also { result -> assertEquals(listOf(client), result) } } @@ -52,8 +52,8 @@ class NewClientDAOTest: BaseDatabaseTest() { @Test fun givenNewClients_whenClearNewClients_thenNewClientEmptyListIsEmitted() = runTest { - newClientDAO.insertNewClient(insertedClient) newClientDAO.insertNewClient(insertedClient1) + newClientDAO.insertNewClient(insertedClient2) newClientDAO.observeNewClients().test { awaitItem() @@ -66,20 +66,21 @@ class NewClientDAOTest: BaseDatabaseTest() { private companion object { val userId = QualifiedIDEntity("test", "domain") val user = newUserEntity(userId) - val insertedClient = InsertClientParam( + val insertedClient1 = InsertClientParam( userId = user.id, - id = "id0", + id = "id1", deviceType = null, clientType = null, label = null, model = null, registrationDate = null, lastActive = null, - mlsPublicKeys = null + mlsPublicKeys = null, + isMLSCapable = false ) - val insertedClient1 = insertedClient.copy(user.id, "id1", deviceType = null) + val insertedClient2 = insertedClient1.copy(user.id, "id2", deviceType = null) - val client = insertedClient.toClientEntity() + val client = insertedClient1.toClientEntity() } } diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/reaction/ReactionDAOTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/reaction/ReactionDAOTest.kt index 86e09b782cf..f86b8e31ff6 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/reaction/ReactionDAOTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/reaction/ReactionDAOTest.kt @@ -193,8 +193,8 @@ class ReactionDAOTest : BaseDatabaseTest() { } private suspend fun insertTestUsers() { - userDAO.insertUser(SELF_USER) - userDAO.insertUser(OTHER_USER) + userDAO.upsertUser(SELF_USER) + userDAO.upsertUser(OTHER_USER) } private companion object { diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/receipt/ReceiptDAOTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/receipt/ReceiptDAOTest.kt index 5b6107f0993..67bc6893d91 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/receipt/ReceiptDAOTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/receipt/ReceiptDAOTest.kt @@ -169,8 +169,8 @@ class ReceiptDAOTest : BaseDatabaseTest() { } private suspend fun insertTestData() { - userDAO.insertUser(SELF_USER) - userDAO.insertUser(OTHER_USER) + userDAO.upsertUser(SELF_USER) + userDAO.upsertUser(OTHER_USER) conversationDAO.insertConversation(TEST_CONVERSATION) messageDAO.insertOrIgnoreMessage(TEST_MESSAGE) } diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/utils/stubs/ClientStubs.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/utils/stubs/ClientStubs.kt index 2625404f071..02fc21cb632 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/utils/stubs/ClientStubs.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/utils/stubs/ClientStubs.kt @@ -29,5 +29,6 @@ fun insertedClient(userId: QualifiedIDEntity = QualifiedIDEntity("test", "wire.c model = null, registrationDate = null, lastActive = null, - mlsPublicKeys = null + mlsPublicKeys = null, + isMLSCapable = false ) diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/utils/stubs/TestStubs.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/utils/stubs/TestStubs.kt index 7d42c742480..ef52abb3152 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/utils/stubs/TestStubs.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/utils/stubs/TestStubs.kt @@ -141,6 +141,35 @@ internal object TestStubs { verificationStatus = ConversationEntity.VerificationStatus.NOT_VERIFIED ) + val conversationEntity5 = ConversationEntity( + QualifiedIDEntity("5", "wire.com"), + "conversation4", + ConversationEntity.Type.GROUP, + null, + ConversationEntity.ProtocolInfo.Mixed( + "group4", + ConversationEntity.GroupState.ESTABLISHED, + 0UL, + Instant.parse("2021-03-30T15:36:00.000Z"), + cipherSuite = ConversationEntity.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 + ), + creatorId = "someValue", + // This conversation was modified after the last time the user was notified about it + lastNotificationDate = "2021-03-30T15:30:00.000Z".toInstant(), + lastModifiedDate = "2021-03-30T15:36:00.000Z".toInstant(), + lastReadDate = "2000-01-01T12:00:00.000Z".toInstant(), + // and it's status is set to be only notified if there is a mention for the user + mutedStatus = ConversationEntity.MutedStatus.ONLY_MENTIONS_AND_REPLIES_ALLOWED, + access = listOf(ConversationEntity.Access.LINK, ConversationEntity.Access.INVITE), + accessRole = listOf(ConversationEntity.AccessRole.NON_TEAM_MEMBER, ConversationEntity.AccessRole.TEAM_MEMBER), + receiptMode = ConversationEntity.ReceiptMode.DISABLED, + messageTimer = messageTimer, + userMessageTimer = null, + archived = false, + archivedInstant = null, + verificationStatus = ConversationEntity.VerificationStatus.NOT_VERIFIED + ) + val member1 = MemberEntity(user1.id, MemberEntity.Role.Admin) val member2 = MemberEntity(user2.id, MemberEntity.Role.Member) val member3 = MemberEntity(user3.id, MemberEntity.Role.Admin) diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/utils/stubs/UserStubs.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/utils/stubs/UserStubs.kt index 2e0a150d2f7..183d207b410 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/utils/stubs/UserStubs.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/utils/stubs/UserStubs.kt @@ -20,31 +20,13 @@ package com.wire.kalium.persistence.utils.stubs import com.wire.kalium.persistence.dao.ConnectionEntity import com.wire.kalium.persistence.dao.QualifiedIDEntity +import com.wire.kalium.persistence.dao.SupportedProtocolEntity import com.wire.kalium.persistence.dao.UserAvailabilityStatusEntity import com.wire.kalium.persistence.dao.UserDetailsEntity import com.wire.kalium.persistence.dao.UserEntity import com.wire.kalium.persistence.dao.UserTypeEntity -fun newUserEntity(id: String = "test") = - UserEntity( - id = QualifiedIDEntity(id, "wire.com"), - name = "user$id", - handle = "handle$id", - email = "email$id", - phone = "phone$id", - accentId = 1, - team = "team", - ConnectionEntity.State.ACCEPTED, - null, - null, - UserAvailabilityStatusEntity.NONE, - UserTypeEntity.STANDARD, - botService = null, - deleted = false, - hasIncompleteMetadata = false, - expiresAt = null, - defederated = false - ) +fun newUserEntity(id: String = "test") = newUserEntity(QualifiedIDEntity(id, "wire.com"), id) fun newUserEntity(qualifiedID: QualifiedIDEntity, id: String = "test") = UserEntity( @@ -64,7 +46,9 @@ fun newUserEntity(qualifiedID: QualifiedIDEntity, id: String = "test") = deleted = false, hasIncompleteMetadata = false, expiresAt = null, - defederated = false + defederated = false, + supportedProtocols = setOf(SupportedProtocolEntity.PROTEUS), + activeOneOnOneConversationId = null ) fun newUserDetailsEntity(id: String = "test") = @@ -86,5 +70,7 @@ fun newUserDetailsEntity(id: String = "test") = hasIncompleteMetadata = false, expiresAt = null, defederated = false, - isProteusVerified = false + isProteusVerified = false, + supportedProtocols = setOf(SupportedProtocolEntity.PROTEUS), + activeOneOnOneConversationId = null ) diff --git a/tango-tests/src/integrationTest/kotlin/action/LoginActions.kt b/tango-tests/src/integrationTest/kotlin/action/LoginActions.kt index 179654c4846..f7073f8443e 100644 --- a/tango-tests/src/integrationTest/kotlin/action/LoginActions.kt +++ b/tango-tests/src/integrationTest/kotlin/action/LoginActions.kt @@ -86,7 +86,8 @@ object LoginActions { locale = "", managedByDTO = null, phone = null, - ssoID = null + ssoID = null, + supportedProtocols = null ) private val VALID_SELF_RESPONSE = UserDTOJson.createValid(selfUserDTO) diff --git a/tango-tests/src/integrationTest/kotlin/util/ListUsersResponseJson.kt b/tango-tests/src/integrationTest/kotlin/util/ListUsersResponseJson.kt index 33453af2308..cdb471a2954 100644 --- a/tango-tests/src/integrationTest/kotlin/util/ListUsersResponseJson.kt +++ b/tango-tests/src/integrationTest/kotlin/util/ListUsersResponseJson.kt @@ -42,7 +42,8 @@ object ListUsersResponseJson { email = null, expiresAt = null, nonQualifiedId = USER_1.value, - service = null + service = null, + supportedProtocols = null ), UserProfileDTO( id = USER_2, @@ -56,7 +57,8 @@ object ListUsersResponseJson { email = null, expiresAt = null, nonQualifiedId = USER_2.value, - service = null + service = null, + supportedProtocols = null ), ) diff --git a/tango-tests/src/integrationTest/kotlin/util/UserDTOJson.kt b/tango-tests/src/integrationTest/kotlin/util/UserDTOJson.kt index 5967b5efea8..9afcae43c2a 100644 --- a/tango-tests/src/integrationTest/kotlin/util/UserDTOJson.kt +++ b/tango-tests/src/integrationTest/kotlin/util/UserDTOJson.kt @@ -90,7 +90,8 @@ object UserDTOJson { locale = "", managedByDTO = null, phone = null, - ssoID = null + ssoID = null, + supportedProtocols = null ), jsonProvider ) }