Skip to content

Commit

Permalink
feat: include legal hold status when sending messages
Browse files Browse the repository at this point in the history
  • Loading branch information
ohassine committed Dec 5, 2023
1 parent 0999a28 commit 3d93546
Show file tree
Hide file tree
Showing 9 changed files with 276 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,19 @@

package com.wire.kalium.logic.data.conversation

import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.message.Message
import com.wire.kalium.logic.data.user.LegalHoldStatus
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.network.api.base.model.LegalHoldStatusDTO

interface LegalHoldStatusMapper {
fun fromApiModel(legalHoldStatusDTO: LegalHoldStatusDTO): LegalHoldStatus
fun mapLegalHoldConversationStatus(
legalHoldStatus: Either<StorageFailure, Conversation.LegalHoldStatus>,
message: Message.Sendable
): Conversation.LegalHoldStatus

}

class LegalHoldStatusMapperImpl : LegalHoldStatusMapper {
Expand All @@ -33,4 +41,17 @@ class LegalHoldStatusMapperImpl : LegalHoldStatusMapper {
LegalHoldStatusDTO.DISABLED -> LegalHoldStatus.DISABLED
LegalHoldStatusDTO.NO_CONSENT -> LegalHoldStatus.NO_CONSENT
}

override fun mapLegalHoldConversationStatus(
legalHoldStatus: Either<StorageFailure, Conversation.LegalHoldStatus>,
message: Message.Sendable
): Conversation.LegalHoldStatus = when (legalHoldStatus) {
is Either.Left -> Conversation.LegalHoldStatus.UNKNOWN
is Either.Right -> {
when (message) {
is Message.Regular -> legalHoldStatus.value
else -> Conversation.LegalHoldStatus.UNKNOWN
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.data.user.UserRepository
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.data.client.ProteusClientProvider
import com.wire.kalium.logic.data.conversation.LegalHoldStatusMapperImpl
import com.wire.kalium.logic.feature.message.MLSMessageCreator
import com.wire.kalium.logic.feature.message.MLSMessageCreatorImpl
import com.wire.kalium.logic.feature.message.MessageEnvelopeCreator
Expand Down Expand Up @@ -119,14 +120,18 @@ class DebugScope internal constructor(

private val messageEnvelopeCreator: MessageEnvelopeCreator
get() = MessageEnvelopeCreatorImpl(
conversationRepository = conversationRepository,
legalHoldStatusMapper = LegalHoldStatusMapperImpl(),
proteusClientProvider = proteusClientProvider,
selfUserId = userId,
protoContentMapper = protoContentMapper
)

private val mlsMessageCreator: MLSMessageCreator
get() = MLSMessageCreatorImpl(
conversationRepository = conversationRepository,
mlsClientProvider = mlsClientProvider,
legalHoldStatusMapper = LegalHoldStatusMapperImpl(),
selfUserId = userId,
protoContentMapper = protoContentMapper
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ package com.wire.kalium.logic.feature.message
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.conversation.LegalHoldStatusMapper
import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.data.id.IdMapper
import com.wire.kalium.logic.data.message.Message
Expand All @@ -34,6 +36,7 @@ import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.logic.wrapMLSRequest
import com.wire.kalium.network.api.base.authenticated.message.MLSMessageApi
import kotlinx.coroutines.flow.first

interface MLSMessageCreator {

Expand All @@ -45,6 +48,8 @@ interface MLSMessageCreator {
}

class MLSMessageCreatorImpl(
private val conversationRepository: ConversationRepository,
private val legalHoldStatusMapper: LegalHoldStatusMapper,
private val mlsClientProvider: MLSClientProvider,
private val selfUserId: UserId,
private val protoContentMapper: ProtoContentMapper = MapperProvider.protoContentMapper(selfUserId = selfUserId),
Expand All @@ -60,10 +65,10 @@ class MLSMessageCreatorImpl(
else -> false
}

// TODO(legalhold) - Get correct legal hold status
val legalHoldStatus = when (message) {
is Message.Regular -> Conversation.LegalHoldStatus.DISABLED
else -> Conversation.LegalHoldStatus.DISABLED
val legalHoldStatus = conversationRepository.observeLegalHoldForConversation(
message.conversationId
).first().let {
legalHoldStatusMapper.mapLegalHoldConversationStatus(it, message)
}

val content = protoContentMapper.encodeToProtobuf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,13 @@ import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.di.MapperProvider
import com.wire.kalium.logic.data.client.ProteusClientProvider
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.conversation.LegalHoldStatusMapper
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.logic.wrapProteusRequest
import kotlinx.coroutines.flow.first

interface MessageEnvelopeCreator {

Expand All @@ -64,6 +67,8 @@ interface MessageEnvelopeCreator {
}

class MessageEnvelopeCreatorImpl(
private val conversationRepository: ConversationRepository,
private val legalHoldStatusMapper: LegalHoldStatusMapper,
private val proteusClientProvider: ProteusClientProvider,
private val selfUserId: UserId,
private val protoContentMapper: ProtoContentMapper = MapperProvider.protoContentMapper(selfUserId = selfUserId),
Expand All @@ -81,10 +86,10 @@ class MessageEnvelopeCreatorImpl(
else -> false
}

// TODO(legalhold) - Get correct legal hold status
val legalHoldStatus = when (message) {
is Message.Regular -> Conversation.LegalHoldStatus.DISABLED
else -> Conversation.LegalHoldStatus.DISABLED
val legalHoldStatus = conversationRepository.observeLegalHoldForConversation(
message.conversationId
).first().let {
legalHoldStatusMapper.mapLegalHoldConversationStatus(it, message)
}

val actualMessageContent = ProtoContent.Readable(
Expand All @@ -105,7 +110,6 @@ class MessageEnvelopeCreatorImpl(
val senderClientId = message.senderClientId
val expectsReadConfirmation = false

// TODO - Get legal hold status
val legalHoldStatus = Conversation.LegalHoldStatus.UNKNOWN

val actualMessageContent = ProtoContent.Readable(message.id, message.content, expectsReadConfirmation, legalHoldStatus)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import com.wire.kalium.logic.data.sync.SlowSyncRepository
import com.wire.kalium.logic.data.user.UserRepository
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.data.client.ProteusClientProvider
import com.wire.kalium.logic.data.conversation.LegalHoldStatusMapperImpl
import com.wire.kalium.logic.data.message.SessionEstablisher
import com.wire.kalium.logic.data.message.SessionEstablisherImpl
import com.wire.kalium.logic.feature.asset.GetAssetMessagesForConversationUseCase
Expand Down Expand Up @@ -106,13 +107,17 @@ class MessageScope internal constructor(

private val messageEnvelopeCreator: MessageEnvelopeCreator
get() = MessageEnvelopeCreatorImpl(
conversationRepository = conversationRepository,
legalHoldStatusMapper = LegalHoldStatusMapperImpl(),
proteusClientProvider = proteusClientProvider,
selfUserId = selfUserId,
protoContentMapper = protoContentMapper
)

private val mlsMessageCreator: MLSMessageCreator
get() = MLSMessageCreatorImpl(
conversationRepository = conversationRepository,
legalHoldStatusMapper = LegalHoldStatusMapperImpl(),
mlsClientProvider = mlsClientProvider,
selfUserId = selfUserId,
protoContentMapper = protoContentMapper
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Wire
* Copyright (C) 2023 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.feature.conversation

import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.conversation.LegalHoldStatusMapperImpl
import com.wire.kalium.logic.data.message.MessageContent
import com.wire.kalium.logic.data.user.LegalHoldStatus
import com.wire.kalium.logic.framework.TestMessage
import com.wire.kalium.logic.framework.TestMessage.signalingMessage
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.network.api.base.model.LegalHoldStatusDTO
import kotlin.test.Test
import kotlin.test.assertEquals

class LegalHoldStatusMapperTest {

@Test
fun givenDTOLegalHoldStatus_whenMappingToDomain_thenMapCorrectly() {
val legalHoldStatusMapper = Arrangement().legalHoldStatusMapper

val resultDisabled = legalHoldStatusMapper.fromApiModel(LegalHoldStatusDTO.DISABLED)
assertEquals(LegalHoldStatus.DISABLED, resultDisabled)

val resultEnabled = legalHoldStatusMapper.fromApiModel(LegalHoldStatusDTO.ENABLED)
assertEquals(LegalHoldStatus.ENABLED, resultEnabled)

val resultNonConsent = legalHoldStatusMapper.fromApiModel(LegalHoldStatusDTO.NO_CONSENT)
assertEquals(LegalHoldStatus.NO_CONSENT, resultNonConsent)

val resultPending = legalHoldStatusMapper.fromApiModel(LegalHoldStatusDTO.PENDING)
assertEquals(LegalHoldStatus.PENDING, resultPending)
}

@Test
fun givenStorageFailure_whenMappingLegalHoldStatus_thenReturnUnknown() {
val legalHoldStatusMapper = Arrangement().legalHoldStatusMapper

val result = legalHoldStatusMapper.mapLegalHoldConversationStatus(
Either.Left(StorageFailure.DataNotFound),
TestMessage.TEXT_MESSAGE
)
assertEquals(Conversation.LegalHoldStatus.UNKNOWN, result)
}

@Test
fun givenRegularMessage_whenMappingLegalHoldStatus_thenReturnLegalHoldStatusOfTheMessage() {
val legalHoldStatusMapper = Arrangement().legalHoldStatusMapper

val result1 = legalHoldStatusMapper.mapLegalHoldConversationStatus(
Either.Right(Conversation.LegalHoldStatus.ENABLED),
TestMessage.TEXT_MESSAGE
)
assertEquals(Conversation.LegalHoldStatus.ENABLED, result1)

val result2 = legalHoldStatusMapper.mapLegalHoldConversationStatus(
Either.Right(Conversation.LegalHoldStatus.DISABLED),
TestMessage.TEXT_MESSAGE
)
assertEquals(Conversation.LegalHoldStatus.DISABLED, result2)
}

@Test
fun givenNonRegularMessage_whenMappingLegalHoldStatus_thenReturnDisabledStatus() {
val legalHoldStatusMapper = Arrangement().legalHoldStatusMapper

val result = legalHoldStatusMapper.mapLegalHoldConversationStatus(
Either.Right(Conversation.LegalHoldStatus.ENABLED),
signalingMessage(
MessageContent.TextEdited(
editMessageId = "ORIGINAL_MESSAGE_ID",
newContent = "some new content",
newMentions = listOf()
)
)
)

assertEquals(Conversation.LegalHoldStatus.UNKNOWN, result)
}

private class Arrangement {
val legalHoldStatusMapper = LegalHoldStatusMapperImpl()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ package com.wire.kalium.logic.feature.message

import com.wire.kalium.cryptography.MLSClient
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.conversation.LegalHoldStatusMapper
import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.data.message.PlainMessageBlob
import com.wire.kalium.logic.data.message.ProtoContentMapper
Expand All @@ -37,11 +40,11 @@ import io.mockative.mock
import io.mockative.once
import io.mockative.verify
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.test.runTest
import kotlin.test.BeforeTest
import kotlin.test.Test

@OptIn(ExperimentalCoroutinesApi::class)
class MLSMessageCreatorTest {

@Mock
Expand All @@ -50,11 +53,24 @@ class MLSMessageCreatorTest {
@Mock
private val protoContentMapper = mock(ProtoContentMapper::class)

@Mock
private val conversationRepository = mock(ConversationRepository::class)

@Mock
private val legalHoldStatusMapper = mock(LegalHoldStatusMapper::class)


private lateinit var mlsMessageCreator: MLSMessageCreator

@BeforeTest
fun setup() {
mlsMessageCreator = MLSMessageCreatorImpl(mlsClientProvider, SELF_USER_ID, protoContentMapper)
mlsMessageCreator = MLSMessageCreatorImpl(
conversationRepository,
legalHoldStatusMapper,
mlsClientProvider,
SELF_USER_ID,
protoContentMapper
)
}

@Test
Expand All @@ -65,6 +81,16 @@ class MLSMessageCreatorTest {
.whenInvokedWith(anything())
.then { Either.Right(MLS_CLIENT) }

given(conversationRepository)
.suspendFunction(conversationRepository::observeLegalHoldForConversation)
.whenInvokedWith(anything())
.then { flowOf(Either.Right(Conversation.LegalHoldStatus.DISABLED)) }

given(legalHoldStatusMapper)
.function(legalHoldStatusMapper::mapLegalHoldConversationStatus)
.whenInvokedWith(anything(), anything())
.thenReturn(Conversation.LegalHoldStatus.DISABLED)

given(MLS_CLIENT)
.suspendFunction(MLS_CLIENT::encryptMessage)
.whenInvokedWith(anything(), anything())
Expand All @@ -82,6 +108,16 @@ class MLSMessageCreatorTest {
.function(MLS_CLIENT::encryptMessage)
.with(eq(CRYPTO_GROUP_ID), eq(plainData))
.wasInvoked(once)

verify(conversationRepository)
.suspendFunction(conversationRepository::observeLegalHoldForConversation)
.with(anything())
.wasInvoked(once)

verify(legalHoldStatusMapper)
.function(legalHoldStatusMapper::mapLegalHoldConversationStatus)
.with(anything(), anything())
.wasInvoked(once)
}

private companion object {
Expand Down
Loading

0 comments on commit 3d93546

Please sign in to comment.