Skip to content

Commit

Permalink
fix: fetch group id on protocol change event (#2051)
Browse files Browse the repository at this point in the history
  • Loading branch information
typfel committed Sep 26, 2023
1 parent eb97b77 commit 4c9987e
Show file tree
Hide file tree
Showing 11 changed files with 319 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,24 @@ interface ConversationRepository {
): Either<CoreFailure, OneOnOneMembers>

/**
* Update a conversation's protocol.
* Update a conversation's protocol remotely.
*
* This also fetches the newly assigned `groupID` from the backend, if this operation fails the whole
* operation is cancelled and protocol change is not persisted.
*
* @return **true** if the protocol was changed or **false** if the protocol was unchanged.
*/
suspend fun updateProtocolRemotely(conversationId: ConversationId, protocol: Conversation.Protocol): Either<CoreFailure, Boolean>

/**
* Update a conversation's protocol locally.
*
* This also fetches the newly assigned `groupID` from the backend, if this operation fails the whole
* operation is cancelled and protocol change is not persisted.
*
* @return **true** if the protocol was changed or **false** if the protocol was unchanged.
*/
suspend fun updateProtocol(conversationId: ConversationId, protocol: Conversation.Protocol): Either<CoreFailure, Boolean>
suspend fun updateProtocolLocally(conversationId: ConversationId, protocol: Conversation.Protocol): Either<CoreFailure, Boolean>
}

@Suppress("LongParameterList", "TooManyFunctions")
Expand Down Expand Up @@ -880,7 +893,7 @@ internal class ConversationDataSource internal constructor(
}
}

override suspend fun updateProtocol(
override suspend fun updateProtocolRemotely(
conversationId: ConversationId,
protocol: Conversation.Protocol
): Either<CoreFailure, Boolean> =
Expand All @@ -894,18 +907,31 @@ internal class ConversationDataSource internal constructor(
}

is UpdateConversationProtocolResponse.ProtocolUpdated -> {
when (protocol) {
Conversation.Protocol.PROTEUS -> Either.Right(Unit)
Conversation.Protocol.MIXED -> fetchConversation(conversationId)
Conversation.Protocol.MLS -> {
wrapStorageRequest {
conversationDAO.updateConversationProtocol(
conversationId = conversationId.toDao(),
protocol = protocol.toDao()
)
}.map {}
}
}.map { true }
updateProtocolLocally(conversationId, protocol)
}
}
}

override suspend fun updateProtocolLocally(
conversationId: ConversationId,
protocol: Conversation.Protocol
): Either<CoreFailure, Boolean> =
wrapApiRequest {
conversationApi.fetchConversationDetails(conversationId.toApi())
}.flatMap { conversationResponse ->
wrapStorageRequest {
conversationDAO.updateConversationProtocol(
conversationId = conversationId.toDao(),
protocol = protocol.toDao()
)
}.flatMap { updated ->
if (updated) {
val selfUserTeamId = selfTeamIdProvider().getOrNull()
persistConversations(listOf(conversationResponse), selfUserTeamId?.value, invalidateMembers = true)
} else {
Either.Right(Unit)
}.map {
updated
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1209,7 +1209,7 @@ class UserSessionScope internal constructor(

private val protocolUpdateEventHandler: ProtocolUpdateEventHandler
get() = ProtocolUpdateEventHandlerImpl(
conversationDAO = userStorage.database.conversationDAO,
conversationRepository = conversationRepository,
systemMessageInserter = systemMessageBuilder
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ internal class MLSMigratorImpl(

private suspend fun migrate(conversationId: ConversationId): Either<CoreFailure, Unit> {
kaliumLogger.i("migrating ${conversationId.toLogString()} to mixed")
return conversationRepository.updateProtocol(conversationId, Protocol.MIXED)
return conversationRepository.updateProtocolRemotely(conversationId, Protocol.MIXED)
.flatMap { updated ->
if (updated) {
systemMessageInserter.insertProtocolChangedSystemMessage(
Expand All @@ -106,7 +106,7 @@ internal class MLSMigratorImpl(

private suspend fun finalise(conversationId: ConversationId): Either<CoreFailure, Unit> {
kaliumLogger.i("finalising ${conversationId.toLogString()} to mls")
return conversationRepository.updateProtocol(conversationId, Protocol.MLS)
return conversationRepository.updateProtocolRemotely(conversationId, Protocol.MLS)
.fold({ failure ->
kaliumLogger.w("failed to finalise ${conversationId.toLogString()} to mls: $failure")
Either.Right(Unit)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ internal class ConversationEventReceiverImpl(
is Event.Conversation.TypingIndicator -> typingIndicatorHandler.handle(event)
is Event.Conversation.ConversationProtocol -> {
protocolUpdateEventHandler.handle(event)
Either.Right(Unit)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,31 @@
package com.wire.kalium.logic.sync.receiver.conversation

import com.wire.kalium.logger.KaliumLogger
import com.wire.kalium.logic.data.conversation.toDao
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.event.Event
import com.wire.kalium.logic.data.event.EventLoggingStatus
import com.wire.kalium.logic.data.event.logEventProcessing
import com.wire.kalium.logic.data.id.toDao
import com.wire.kalium.logic.data.message.SystemMessageInserter
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.functional.onFailure
import com.wire.kalium.logic.functional.onSuccess
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.logic.wrapStorageRequest
import com.wire.kalium.persistence.dao.conversation.ConversationDAO

interface ProtocolUpdateEventHandler {
suspend fun handle(event: Event.Conversation.ConversationProtocol)
suspend fun handle(event: Event.Conversation.ConversationProtocol): Either<CoreFailure, Unit>
}

internal class ProtocolUpdateEventHandlerImpl(
private val conversationDAO: ConversationDAO,
private val conversationRepository: ConversationRepository,
private val systemMessageInserter: SystemMessageInserter
) : ProtocolUpdateEventHandler {

private val logger by lazy { kaliumLogger.withFeatureId(KaliumLogger.Companion.ApplicationFlow.EVENT_RECEIVER) }

override suspend fun handle(event: Event.Conversation.ConversationProtocol) {
updateProtocol(event)
override suspend fun handle(event: Event.Conversation.ConversationProtocol): Either<CoreFailure, Unit> =
conversationRepository.updateProtocolLocally(event.conversationId, event.protocol)
.onSuccess { updated ->
if (updated) {
systemMessageInserter.insertProtocolChangedSystemMessage(
Expand All @@ -65,14 +65,5 @@ internal class ProtocolUpdateEventHandlerImpl(
event,
Pair("errorInfo", "$coreFailure")
)
}
}

private suspend fun updateProtocol(event: Event.Conversation.ConversationProtocol) = wrapStorageRequest {
conversationDAO.updateConversationProtocol(
event.conversationId.toDao(),
event.protocol.toDao()
)
}

}.map { }
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.sync.receiver.conversation.RenamedConversationEventHandler
import com.wire.kalium.logic.util.arrangement.dao.MemberDAOArrangement
import com.wire.kalium.logic.util.arrangement.dao.MemberDAOArrangementImpl
import com.wire.kalium.logic.util.shouldFail
import com.wire.kalium.logic.util.shouldSucceed
import com.wire.kalium.network.api.base.authenticated.client.ClientApi
import com.wire.kalium.network.api.base.authenticated.conversation.ConvProtocol
Expand All @@ -69,6 +70,7 @@ import com.wire.kalium.network.api.base.authenticated.conversation.model.Convers
import com.wire.kalium.network.api.base.authenticated.notification.EventContentDTO
import com.wire.kalium.network.api.base.model.ConversationAccessDTO
import com.wire.kalium.network.api.base.model.ConversationAccessRoleDTO
import com.wire.kalium.network.exceptions.KaliumException
import com.wire.kalium.network.utils.NetworkResponse
import com.wire.kalium.persistence.dao.ConversationIDEntity
import com.wire.kalium.persistence.dao.QualifiedIDEntity
Expand Down Expand Up @@ -1087,78 +1089,135 @@ class ConversationRepositoryTest {
}

@Test
fun givenConversation_whenUpdatingProtocolToMls_thenShouldUpdateLocally() = runTest {
fun givenNoChange_whenUpdatingProtocolToMls_thenShouldNotUpdateLocally() = runTest {
// given
val protocol = Conversation.Protocol.MLS

val (arrange, conversationRepository) = Arrangement()
.withUpdateProtocolResponse(UPDATE_PROTOCOL_SUCCESS)
.withDaoUpdateProtocolSuccess()
.withUpdateProtocolResponse(UPDATE_PROTOCOL_UNCHANGED)
.arrange()

// when
val result = conversationRepository.updateProtocol(CONVERSATION_ID, protocol)
val result = conversationRepository.updateProtocolRemotely(CONVERSATION_ID, protocol)

// then
with(result) {
shouldSucceed()
verify(arrange.conversationDAO)
.suspendFunction(arrange.conversationDAO::updateConversationProtocol)
.with(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao()))
.wasNotInvoked()
}
}

@Test
fun givenChange_whenUpdatingProtocol_thenShouldFetchConversationDetails() = runTest {
// given
val protocol = Conversation.Protocol.MIXED
val conversationResponse = NetworkResponse.Success(
TestConversation.CONVERSATION_RESPONSE,
emptyMap(),
HttpStatusCode.OK.value
)

val (arrangement, conversationRepository) = Arrangement()
.withUpdateProtocolResponse(UPDATE_PROTOCOL_SUCCESS)
.withFetchConversationsDetails(conversationResponse)
.withDaoUpdateProtocolSuccess()
.arrange()

// when
val result = conversationRepository.updateProtocolRemotely(CONVERSATION_ID, protocol)

// then
with(result) {
shouldSucceed()
verify(arrangement.conversationApi)
.suspendFunction(arrangement.conversationApi::fetchConversationDetails)
.with(eq(CONVERSATION_ID.toApi()))
.wasInvoked(exactly = once)
}
}

@Test
fun givenNoChange_whenUpdatingProtocolToMls_thenShouldNotUpdateLocally() = runTest {
fun givenChange_whenUpdatingProtocol_thenShouldUpdateLocally() = runTest {
// given
val protocol = Conversation.Protocol.MLS
val conversationResponse = NetworkResponse.Success(
TestConversation.CONVERSATION_RESPONSE,
emptyMap(),
HttpStatusCode.OK.value
)

val (arrange, conversationRepository) = Arrangement()
.withUpdateProtocolResponse(UPDATE_PROTOCOL_UNCHANGED)
.withUpdateProtocolResponse(UPDATE_PROTOCOL_SUCCESS)
.withFetchConversationsDetails(conversationResponse)
.withDaoUpdateProtocolSuccess()
.arrange()

// when
val result = conversationRepository.updateProtocol(CONVERSATION_ID, protocol)
val result = conversationRepository.updateProtocolRemotely(CONVERSATION_ID, protocol)

// then
with(result) {
shouldSucceed()
verify(arrange.conversationDAO)
.suspendFunction(arrange.conversationDAO::updateConversationProtocol)
.with(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao()))
.wasNotInvoked()
.wasInvoked(exactly = once)
}
}

@Test
fun givenConversation_whenUpdatingProtocolToMixed_thenShouldFetchConversation() = runTest {
fun givenSuccessFetchingConversationDetails_whenUpdatingProtocolLocally_thenShouldUpdateLocally() = runTest {
// given
val protocol = Conversation.Protocol.MIXED
val protocol = Conversation.Protocol.MLS
val conversationResponse = NetworkResponse.Success(
TestConversation.CONVERSATION_RESPONSE,
emptyMap(),
HttpStatusCode.OK.value
)

val (arrangement, conversationRepository) = Arrangement()
.withUpdateProtocolResponse(UPDATE_PROTOCOL_SUCCESS)
val (arrange, conversationRepository) = Arrangement()
.withFetchConversationsDetails(conversationResponse)
.withDaoUpdateProtocolSuccess()
.arrange()

// when
val result = conversationRepository.updateProtocol(CONVERSATION_ID, protocol)
val result = conversationRepository.updateProtocolLocally(CONVERSATION_ID, protocol)

// then
with(result) {
shouldSucceed()
verify(arrangement.conversationApi)
.suspendFunction(arrangement.conversationApi::fetchConversationDetails)
.with(eq(CONVERSATION_ID.toApi()))
verify(arrange.conversationDAO)
.suspendFunction(arrange.conversationDAO::updateConversationProtocol)
.with(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao()))
.wasInvoked(exactly = once)
}
}

@Test
fun givenFailureFetchingConversationDetails_whenUpdatingProtocolLocally_thenShouldNotUpdateLocally() = runTest {
// given
val protocol = Conversation.Protocol.MLS
val (arrange, conversationRepository) = Arrangement()
.withFetchConversationsDetails(NetworkResponse.Error(KaliumException.NoNetwork()))
.withDaoUpdateProtocolSuccess()
.arrange()

// when
val result = conversationRepository.updateProtocolLocally(CONVERSATION_ID, protocol)

// then
with(result) {
shouldFail()
verify(arrange.conversationDAO)
.suspendFunction(arrange.conversationDAO::updateConversationProtocol)
.with(eq(CONVERSATION_ID.toDao()), eq(protocol.toDao()))
.wasNotInvoked()
}
}

private class Arrangement :
MemberDAOArrangement by MemberDAOArrangementImpl() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class MLSMigratorTest {
migrator.migrateProteusConversations()

verify(arrangement.conversationRepository)
.suspendFunction(arrangement.conversationRepository::updateProtocol)
.suspendFunction(arrangement.conversationRepository::updateProtocolRemotely)
.with(eq(conversation.id), eq(Conversation.Protocol.MIXED))
.wasInvoked(once)

Expand Down Expand Up @@ -126,7 +126,7 @@ class MLSMigratorTest {
.wasInvoked(once)

verify(arrangement.conversationRepository)
.suspendFunction(arrangement.conversationRepository::updateProtocol)
.suspendFunction(arrangement.conversationRepository::updateProtocolRemotely)
.with(eq(conversation.id), eq(Conversation.Protocol.MLS))
.wasInvoked(once)
}
Expand Down Expand Up @@ -208,7 +208,7 @@ class MLSMigratorTest {
}
fun withUpdateProtocolReturns(result: Either<CoreFailure, Boolean> = Either.Right(true)) = apply {
given(conversationRepository)
.suspendFunction(conversationRepository::updateProtocol)
.suspendFunction(conversationRepository::updateProtocolRemotely)
.whenInvokedWith(any(), any())
.thenReturn(result)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,4 +256,12 @@ object TestEvent {
timestampIso = "2022-03-30T15:36:00.000Z",
typingIndicatorMode = typingIndicatorMode
)

fun newConversationProtocolEvent() = Event.Conversation.ConversationProtocol(
id = "eventId",
conversationId = TestConversation.ID,
transient = false,
protocol = Conversation.Protocol.MIXED,
senderUserId = TestUser.OTHER_USER_ID
)
}
Loading

0 comments on commit 4c9987e

Please sign in to comment.