Skip to content

Commit

Permalink
feat: Update legal hold status on message receive (WPB-5442) (#2286)
Browse files Browse the repository at this point in the history
* feat: add legal hold status to conversation table

* feat: add legal hold status to conversation table

* feat: add missing params of legal hold

* chore: unit test

* chore: address comments

* feat: include legal hold flag when sending/receiving messages

* chore: detekt

* feat: update legal hold status on message receive

* feat: legal hold status for ephemeral messages

* Revert "feat: legal hold status for ephemeral messages"

This reverts commit ef75642.

* chore: detekt
  • Loading branch information
ohassine authored Dec 11, 2023
1 parent ab83fea commit 597af01
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1240,10 +1240,14 @@ class UserSessionScope internal constructor(

private val newMessageHandler: NewMessageEventHandler
get() = NewMessageEventHandlerImpl(
proteusUnpacker, mlsUnpacker, applicationMessageHandler,
proteusUnpacker,
mlsUnpacker,
conversationRepository,
applicationMessageHandler,
{ conversationId, messageId ->
messages.ephemeralMessageDeletionHandler.startSelfDeletion(conversationId, messageId)
}, userId,
},
userId,
staleEpochVerifier
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import com.wire.kalium.cryptography.exceptions.ProteusException
import com.wire.kalium.logger.KaliumLogger
import com.wire.kalium.logic.ProteusFailure
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.event.Event
import com.wire.kalium.logic.data.event.EventLoggingStatus
import com.wire.kalium.logic.data.event.logEventProcessing
Expand All @@ -40,9 +41,11 @@ internal interface NewMessageEventHandler {
suspend fun handleNewMLSMessage(event: Event.Conversation.NewMLSMessage)
}

@Suppress("LongParameterList")
internal class NewMessageEventHandlerImpl(
private val proteusMessageUnpacker: ProteusMessageUnpacker,
private val mlsMessageUnpacker: MLSMessageUnpacker,
private val conversationRepository: ConversationRepository,
private val applicationMessageHandler: ApplicationMessageHandler,
private val enqueueSelfDeletion: (conversationId: ConversationId, messageId: String) -> Unit,
private val selfUserId: UserId,
Expand Down Expand Up @@ -83,7 +86,10 @@ internal class NewMessageEventHandlerImpl(
}.onSuccess {
if (it is MessageUnpackResult.ApplicationMessage) {
handleSuccessfulResult(it)
// TODO(legalhold): update legal hold status in DB
conversationRepository.updateLegalHoldStatus(
conversationId = it.conversationId,
legalHoldStatus = it.content.legalHoldStatus
)
onMessageInserted(it)
}
kaliumLogger
Expand Down Expand Up @@ -130,6 +136,10 @@ internal class NewMessageEventHandlerImpl(
it.forEach {
if (it is MessageUnpackResult.ApplicationMessage) {
handleSuccessfulResult(it)
conversationRepository.updateLegalHoldStatus(
conversationId = it.conversationId,
legalHoldStatus = it.content.legalHoldStatus
)
onMessageInserted(it)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.ProteusFailure
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.message.MessageContent
import com.wire.kalium.logic.data.message.ProtoContent
Expand Down Expand Up @@ -142,11 +143,59 @@ class NewMessageEventHandlerTest {
.wasInvoked(exactly = once)
}

@Test
fun givenUnpackingSuccess_whenHandling_thenHandleContent() = runTest {
val (arrangement, newMessageEventHandler) = Arrangement()
.withUpdateLegalHoldStatusSuccess()
.withMLSUnpackerReturning(
Either.Right(
listOf(
MessageUnpackResult.ApplicationMessage(
conversationId = ConversationId("conversationID", "domain"),
timestampIso = Instant.DISTANT_PAST.toIsoDateTimeString(),
senderUserId = UserId("otherUserId", "domain"),
senderClientId = ClientId("clientID"),
content = ProtoContent.Readable(
messageUid = "messageUID",
messageContent = MessageContent.Text(
value = "messageContent"
),
expectsReadConfirmation = false,
legalHoldStatus = Conversation.LegalHoldStatus.DISABLED,
expiresAfterMillis = 123L
)
)
)
)
)
.arrange()

val newMessageEvent = TestEvent.newMLSMessageEvent(DateTimeUtil.currentInstant())

newMessageEventHandler.handleNewMLSMessage(newMessageEvent)

verify(arrangement.mlsMessageUnpacker)
.suspendFunction(arrangement.mlsMessageUnpacker::unpackMlsMessage)
.with(eq(newMessageEvent))
.wasInvoked(exactly = once)

verify(arrangement.conversationRepository)
.suspendFunction(arrangement.conversationRepository::updateLegalHoldStatus)
.with(any(), eq(Conversation.LegalHoldStatus.DISABLED))
.wasInvoked(exactly = once)

verify(arrangement.applicationMessageHandler)
.suspendFunction(arrangement.applicationMessageHandler::handleContent)
.with(any(), any(), any(), any(), any())
.wasInvoked(exactly = once)
}

@Test
fun givenEphemeralMessageFromSelf_whenHandling_thenEnqueueForSelfDelete() = runTest {
val conversationID = ConversationId("conversationID", "domain")
val senderUserId = SELF_USER_ID
val (arrangement, newMessageEventHandler) = Arrangement()
.withUpdateLegalHoldStatusSuccess()
.withProteusUnpackerReturning(
Either.Right(
MessageUnpackResult.ApplicationMessage(
Expand All @@ -156,9 +205,9 @@ class NewMessageEventHandlerTest {
ClientId("clientID"),
ProtoContent.Readable(
messageUid = "messageUID",
messageContent = MessageContent.Text(
value = "messageContent"
),
messageContent = MessageContent.Text(
value = "messageContent"
),
expectsReadConfirmation = false,
legalHoldStatus = Conversation.LegalHoldStatus.DISABLED,
expiresAfterMillis = 123L
Expand Down Expand Up @@ -193,6 +242,7 @@ class NewMessageEventHandlerTest {
val conversationID = ConversationId("conversationID", "domain")
val senderUserId = UserId("otherUserId", "domain")
val (arrangement, newMessageEventHandler) = Arrangement()
.withUpdateLegalHoldStatusSuccess()
.withProteusUnpackerReturning(
Either.Right(
MessageUnpackResult.ApplicationMessage(
Expand Down Expand Up @@ -239,6 +289,7 @@ class NewMessageEventHandlerTest {
val conversationID = ConversationId("conversationID", "domain")
val senderUserId = SELF_USER_ID
val (arrangement, newMessageEventHandler) = Arrangement()
.withUpdateLegalHoldStatusSuccess()
.withProteusUnpackerReturning(
Either.Right(
MessageUnpackResult.ApplicationMessage(
Expand All @@ -248,9 +299,7 @@ class NewMessageEventHandlerTest {
ClientId("clientID"),
ProtoContent.Readable(
messageUid = "messageUID",
messageContent = MessageContent.Text(
value = "messageContent"
),
messageContent = MessageContent.Text(value = "messageContent"),
expectsReadConfirmation = false,
legalHoldStatus = Conversation.LegalHoldStatus.DISABLED,
expiresAfterMillis = null
Expand All @@ -269,6 +318,11 @@ class NewMessageEventHandlerTest {
.with(eq(newMessageEvent))
.wasInvoked(exactly = once)

verify(arrangement.conversationRepository)
.suspendFunction(arrangement.conversationRepository::updateLegalHoldStatus)
.with(eq(conversationID), eq(Conversation.LegalHoldStatus.DISABLED))
.wasInvoked(exactly = once)

verify(arrangement.applicationMessageHandler)
.suspendFunction(arrangement.applicationMessageHandler::handleDecryptionError)
.with(any(), any(), any(), any(), any(), any())
Expand Down Expand Up @@ -297,26 +351,27 @@ class NewMessageEventHandlerTest {

verify(arrangement.staleEpochVerifier)
.suspendFunction(arrangement.staleEpochVerifier::verifyEpoch)
.with(eq(newMessageEvent.conversationId),eq(newMessageEvent.timestampIso.toInstant()))
.with(eq(newMessageEvent.conversationId), eq(newMessageEvent.timestampIso.toInstant()))
.wasInvoked(exactly = once)
}

@Test
fun givenMLSEventFailsWithWrongEpoch_whenHandling_shouldNotPersistDecryptionErrorMessage() = runTest {
val (arrangement, newMessageEventHandler) = Arrangement()
.withMLSUnpackerReturning(Either.Left(MLSFailure.WrongEpoch))
.withVerifyEpoch(Either.Right(Unit))
.arrange()
fun givenMLSEventFailsWithWrongEpoch_whenHandling_shouldNotPersistDecryptionErrorMessage() =
runTest {
val (arrangement, newMessageEventHandler) = Arrangement()
.withMLSUnpackerReturning(Either.Left(MLSFailure.WrongEpoch))
.withVerifyEpoch(Either.Right(Unit))
.arrange()

val newMessageEvent = TestEvent.newMLSMessageEvent(DateTimeUtil.currentInstant())
val newMessageEvent = TestEvent.newMLSMessageEvent(DateTimeUtil.currentInstant())

newMessageEventHandler.handleNewMLSMessage(newMessageEvent)
newMessageEventHandler.handleNewMLSMessage(newMessageEvent)

verify(arrangement.applicationMessageHandler)
.suspendFunction(arrangement.applicationMessageHandler::handleDecryptionError)
.with(any())
.wasNotInvoked()
}
verify(arrangement.applicationMessageHandler)
.suspendFunction(arrangement.applicationMessageHandler::handleDecryptionError)
.with(any())
.wasNotInvoked()
}

private class Arrangement {

Expand All @@ -326,6 +381,9 @@ class NewMessageEventHandlerTest {
@Mock
val mlsMessageUnpacker = mock(classOf<MLSMessageUnpacker>())

@Mock
val conversationRepository = mock(classOf<ConversationRepository>())

@Mock
val applicationMessageHandler = configure(mock(classOf<ApplicationMessageHandler>())) {
stubsUnitByDefault = true
Expand All @@ -340,8 +398,14 @@ class NewMessageEventHandlerTest {
private val newMessageEventHandler: NewMessageEventHandler = NewMessageEventHandlerImpl(
proteusMessageUnpacker,
mlsMessageUnpacker,
conversationRepository,
applicationMessageHandler,
{ conversationId, messageId -> ephemeralMessageDeletionHandler.startSelfDeletion(conversationId, messageId) },
{ conversationId, messageId ->
ephemeralMessageDeletionHandler.startSelfDeletion(
conversationId,
messageId
)
},
SELF_USER_ID,
staleEpochVerifier
)
Expand All @@ -353,13 +417,21 @@ class NewMessageEventHandlerTest {
.thenReturn(result)
}

fun withMLSUnpackerReturning(result: Either<CoreFailure, List<MessageUnpackResult>>) = apply {
given(mlsMessageUnpacker)
.suspendFunction(mlsMessageUnpacker::unpackMlsMessage)
.whenInvokedWith(any())
.thenReturn(result)
fun withUpdateLegalHoldStatusSuccess() = apply {
given(conversationRepository)
.suspendFunction(conversationRepository::updateLegalHoldStatus)
.whenInvokedWith(any(), any())
.thenReturn(Either.Right(Unit))
}

fun withMLSUnpackerReturning(result: Either<CoreFailure, List<MessageUnpackResult>>) =
apply {
given(mlsMessageUnpacker)
.suspendFunction(mlsMessageUnpacker::unpackMlsMessage)
.whenInvokedWith(any())
.thenReturn(result)
}

fun withVerifyEpoch(result: Either<CoreFailure, Unit>) = apply {
given(staleEpochVerifier)
.suspendFunction(staleEpochVerifier::verifyEpoch)
Expand Down

0 comments on commit 597af01

Please sign in to comment.