From 19491b1e0625159af0b9c7909d2872d1e98f2a77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Saleniuk?= <30429749+saleniuk@users.noreply.github.com> Date: Thu, 28 Dec 2023 12:51:46 +0100 Subject: [PATCH] feat: discover legal hold when sending message [WPB-5999] (#2333) * feat: discover legal hold when receiving message [WPB-5837] * removed unused code * handle only once for each user id * reduce db queries * use date of received message to create system message for conversation * replace DebounceBuffer with simpler TriggerBuffer * handle live messages right away and use message timestamp for new system messages * change name of the buffer * fix detekt * feat: discover legal hold when sending message [WPB-5999] * return dedicated failure type when legal hold enabled * return messageId with legal hold enabled failure * fix detekt issues --- .../LegalHoldEnabledForConversationFailure.kt | 24 +++ .../kalium/logic/feature/UserSessionScope.kt | 4 + .../kalium/logic/feature/debug/DebugScope.kt | 10 +- .../logic/feature/message/MessageScope.kt | 6 + .../message/MessageSendFailureHandler.kt | 20 ++- .../logic/feature/message/MessageSender.kt | 60 +++++-- .../handler/legalhold/LegalHoldHandler.kt | 35 +++- .../prekey/MessageSendFailureHandlerTest.kt | 92 ++++++++-- .../feature/message/MessageSenderTest.kt | 145 ++++++++++++++++ .../handler/legalhold/LegalHoldHandlerTest.kt | 159 +++++++++++++++++- .../repository/ClientRepositoryArrangement.kt | 15 ++ 11 files changed, 526 insertions(+), 44 deletions(-) create mode 100644 logic/src/commonMain/kotlin/com/wire/kalium/logic/failure/LegalHoldEnabledForConversationFailure.kt diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/failure/LegalHoldEnabledForConversationFailure.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/failure/LegalHoldEnabledForConversationFailure.kt new file mode 100644 index 00000000000..0898684db2e --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/failure/LegalHoldEnabledForConversationFailure.kt @@ -0,0 +1,24 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ + +package com.wire.kalium.logic.failure + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.data.id.MessageId + +data class LegalHoldEnabledForConversationFailure(val messageId: MessageId) : CoreFailure.FeatureFailure() diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt index e479264db7a..7fd70b3356a 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt @@ -1600,6 +1600,7 @@ class UserSessionScope internal constructor( conversationRepository, mlsConversationRepository, clientRepository, + clientRemoteRepository, clientIdProvider, proteusClientProvider, mlsClientProvider, @@ -1613,6 +1614,7 @@ class UserSessionScope internal constructor( selfConversationIdProvider, staleEpochVerifier, eventProcessor, + legalHoldHandler, this ) val messages: MessageScope @@ -1625,6 +1627,7 @@ class UserSessionScope internal constructor( conversationRepository, mlsConversationRepository, clientRepository, + clientRemoteRepository, proteusClientProvider, mlsClientProvider, preKeyRepository, @@ -1641,6 +1644,7 @@ class UserSessionScope internal constructor( observeSelfDeletingMessages, messageMetadataRepository, staleEpochVerifier, + legalHoldHandler, this ) val users: UserScope diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/debug/DebugScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/debug/DebugScope.kt index 3968d65709b..e5731bc6570 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/debug/DebugScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/debug/DebugScope.kt @@ -23,7 +23,9 @@ import com.wire.kalium.logic.data.asset.AssetRepository import com.wire.kalium.logic.data.client.ClientRepository import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.client.ProteusClientProvider +import com.wire.kalium.logic.data.client.remote.ClientRemoteRepository import com.wire.kalium.logic.data.conversation.ConversationRepository +import com.wire.kalium.logic.data.conversation.LegalHoldStatusMapperImpl import com.wire.kalium.logic.data.conversation.MLSConversationRepository import com.wire.kalium.logic.data.id.CurrentClientIdProvider import com.wire.kalium.logic.data.message.MessageRepository @@ -35,7 +37,6 @@ import com.wire.kalium.logic.data.prekey.PreKeyRepository import com.wire.kalium.logic.data.sync.SlowSyncRepository import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.data.user.UserRepository -import com.wire.kalium.logic.data.conversation.LegalHoldStatusMapperImpl import com.wire.kalium.logic.feature.message.MLSMessageCreator import com.wire.kalium.logic.feature.message.MLSMessageCreatorImpl import com.wire.kalium.logic.feature.message.MessageEnvelopeCreator @@ -53,6 +54,7 @@ import com.wire.kalium.logic.feature.message.ephemeral.DeleteEphemeralMessageFor import com.wire.kalium.logic.feature.message.ephemeral.EphemeralMessageDeletionHandlerImpl import com.wire.kalium.logic.sync.SyncManager import com.wire.kalium.logic.sync.incremental.EventProcessor +import com.wire.kalium.logic.sync.receiver.handler.legalhold.LegalHoldHandler import com.wire.kalium.logic.util.MessageContentEncoder import com.wire.kalium.util.KaliumDispatcher import com.wire.kalium.util.KaliumDispatcherImpl @@ -67,6 +69,7 @@ class DebugScope internal constructor( private val conversationRepository: ConversationRepository, private val mlsConversationRepository: MLSConversationRepository, private val clientRepository: ClientRepository, + private val clientRemoteRepository: ClientRemoteRepository, private val currentClientIdProvider: CurrentClientIdProvider, private val proteusClientProvider: ProteusClientProvider, private val mlsClientProvider: MLSClientProvider, @@ -80,6 +83,7 @@ class DebugScope internal constructor( private val selfConversationIdProvider: SelfConversationIdProvider, private val staleEpochVerifier: StaleEpochVerifier, private val eventProcessor: EventProcessor, + private val legalHoldHandler: LegalHoldHandler, private val scope: CoroutineScope, internal val dispatcher: KaliumDispatcher = KaliumDispatcherImpl ) { @@ -113,9 +117,10 @@ class DebugScope internal constructor( get() = MessageSendFailureHandlerImpl( userRepository, clientRepository, + clientRemoteRepository, messageRepository, messageSendingScheduler, - conversationRepository + conversationRepository, ) private val sessionEstablisher: SessionEstablisher @@ -153,6 +158,7 @@ class DebugScope internal constructor( mlsConversationRepository, syncManager, messageSendFailureHandler, + legalHoldHandler, sessionEstablisher, messageEnvelopeCreator, mlsMessageCreator, diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageScope.kt index 2a35b757dde..bd0367d3370 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageScope.kt @@ -23,6 +23,7 @@ import com.wire.kalium.logic.data.asset.AssetRepository import com.wire.kalium.logic.data.client.ClientRepository import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.client.ProteusClientProvider +import com.wire.kalium.logic.data.client.remote.ClientRemoteRepository import com.wire.kalium.logic.data.connection.ConnectionRepository import com.wire.kalium.logic.data.conversation.ConversationRepository import com.wire.kalium.logic.data.conversation.LegalHoldStatusMapper @@ -65,6 +66,7 @@ import com.wire.kalium.logic.feature.selfDeletingMessages.ObserveSelfDeletionTim import com.wire.kalium.logic.feature.sessionreset.ResetSessionUseCase import com.wire.kalium.logic.feature.sessionreset.ResetSessionUseCaseImpl import com.wire.kalium.logic.sync.SyncManager +import com.wire.kalium.logic.sync.receiver.handler.legalhold.LegalHoldHandler import com.wire.kalium.logic.util.MessageContentEncoder import com.wire.kalium.util.KaliumDispatcher import com.wire.kalium.util.KaliumDispatcherImpl @@ -80,6 +82,7 @@ class MessageScope internal constructor( private val conversationRepository: ConversationRepository, private val mlsConversationRepository: MLSConversationRepository, private val clientRepository: ClientRepository, + private val clientRemoteRepository: ClientRemoteRepository, private val proteusClientProvider: ProteusClientProvider, private val mlsClientProvider: MLSClientProvider, private val preKeyRepository: PreKeyRepository, @@ -96,6 +99,7 @@ class MessageScope internal constructor( private val observeSelfDeletingMessages: ObserveSelfDeletionTimerSettingsForConversationUseCase, private val messageMetadataRepository: MessageMetadataRepository, private val staleEpochVerifier: StaleEpochVerifier, + private val legalHoldHandler: LegalHoldHandler, private val scope: CoroutineScope, internal val dispatcher: KaliumDispatcher = KaliumDispatcherImpl, private val legalHoldStatusMapper: LegalHoldStatusMapper = LegalHoldStatusMapperImpl @@ -105,6 +109,7 @@ class MessageScope internal constructor( get() = MessageSendFailureHandlerImpl( userRepository, clientRepository, + clientRemoteRepository, messageRepository, messageSendingScheduler, conversationRepository @@ -158,6 +163,7 @@ class MessageScope internal constructor( mlsConversationRepository, syncManager, messageSendFailureHandler, + legalHoldHandler, sessionEstablisher, messageEnvelopeCreator, mlsMessageCreator, diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageSendFailureHandler.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageSendFailureHandler.kt index d1c3ce1c6d9..14718e7dd82 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageSendFailureHandler.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageSendFailureHandler.kt @@ -22,13 +22,16 @@ package com.wire.kalium.logic.feature.message import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.NetworkFailure import com.wire.kalium.logic.StorageFailure +import com.wire.kalium.logic.data.client.ClientMapper import com.wire.kalium.logic.data.client.ClientRepository +import com.wire.kalium.logic.data.client.remote.ClientRemoteRepository import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.conversation.ConversationRepository import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.message.MessageRepository import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.data.user.UserRepository +import com.wire.kalium.logic.di.MapperProvider import com.wire.kalium.logic.failure.ProteusSendMessageFailure import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMap @@ -66,12 +69,15 @@ interface MessageSendFailureHandler { ) } +@Suppress("LongParameterList") class MessageSendFailureHandlerImpl internal constructor( private val userRepository: UserRepository, private val clientRepository: ClientRepository, + private val clientRemoteRepository: ClientRemoteRepository, private val messageRepository: MessageRepository, private val messageSendingScheduler: MessageSendingScheduler, private val conversationRepository: ConversationRepository, + private val clientMapper: ClientMapper = MapperProvider.clientMapper(), ) : MessageSendFailureHandler { override suspend fun handleClientsHaveChangedFailure( @@ -108,10 +114,16 @@ class MessageSendFailureHandlerImpl internal constructor( else userRepository.fetchUsersByIds(userId) } - private suspend fun addMissingClients(missingClients: Map>): Either { - return if (missingClients.isEmpty()) Either.Right(Unit) - else clientRepository.storeMapOfUserToClientId(missingClients) - } + private suspend fun addMissingClients(missingClients: Map>): Either = + if (missingClients.isEmpty()) Either.Right(Unit) + else clientRemoteRepository.fetchOtherUserClients(missingClients.keys.toList()) + .flatMap { + it.map { (userId, clientList) -> clientMapper.toInsertClientParam(clientList, userId) } + .flatten().let { insertClientParamList -> + if (insertClientParamList.isEmpty()) Either.Right(Unit) + else clientRepository.storeUserClientListAndRemoveRedundantClients(insertClientParamList) + } + } override suspend fun handleFailureAndUpdateMessageStatus( failure: CoreFailure, diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageSender.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageSender.kt index a78061bf16f..429d8d7445e 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageSender.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/message/MessageSender.kt @@ -30,6 +30,7 @@ import com.wire.kalium.logic.data.conversation.MLSConversationRepository import com.wire.kalium.logic.data.conversation.Recipient import com.wire.kalium.logic.data.id.ConversationId import com.wire.kalium.logic.data.id.GroupID +import com.wire.kalium.logic.data.id.MessageId import com.wire.kalium.logic.data.message.BroadcastMessage import com.wire.kalium.logic.data.message.BroadcastMessageOption import com.wire.kalium.logic.data.message.BroadcastMessageTarget @@ -44,6 +45,7 @@ import com.wire.kalium.logic.data.message.getType import com.wire.kalium.logic.data.prekey.UsersWithoutSessions import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.data.user.UserRepository +import com.wire.kalium.logic.failure.LegalHoldEnabledForConversationFailure import com.wire.kalium.logic.failure.ProteusSendMessageFailure import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMap @@ -53,6 +55,7 @@ import com.wire.kalium.logic.functional.onFailure import com.wire.kalium.logic.functional.onSuccess import com.wire.kalium.logic.kaliumLogger import com.wire.kalium.logic.sync.SyncManager +import com.wire.kalium.logic.sync.receiver.handler.legalhold.LegalHoldHandler import com.wire.kalium.network.exceptions.KaliumException import com.wire.kalium.network.exceptions.isMlsStaleMessage import com.wire.kalium.util.DateTimeUtil @@ -135,6 +138,7 @@ internal class MessageSenderImpl internal constructor( private val mlsConversationRepository: MLSConversationRepository, private val syncManager: SyncManager, private val messageSendFailureHandler: MessageSendFailureHandler, + private val legalHoldHandler: LegalHoldHandler, private val sessionEstablisher: SessionEstablisher, private val messageEnvelopeCreator: MessageEnvelopeCreator, private val mlsMessageCreator: MLSMessageCreator, @@ -359,6 +363,8 @@ internal class MessageSenderImpl internal constructor( failure = it, action = "Send", messageLogString = message.toLogString(), + messageId = message.id, + messageTimestampIso = message.date, conversationId = message.conversationId, remainingAttempts = remainingAttempts ) { remainingAttempts -> @@ -385,7 +391,7 @@ internal class MessageSenderImpl internal constructor( messageRepository .broadcastEnvelope(envelope, option) .fold({ - handleProteusError(it, "Broadcast", message.toLogString(), null, remainingAttempts = 1) { + handleProteusError(it, "Broadcast", message.toLogString(), message.id, message.date, null, remainingAttempts = 1) { attemptToBroadcastWithProteus( message, target, @@ -401,6 +407,8 @@ internal class MessageSenderImpl internal constructor( failure: CoreFailure, action: String, // Send or Broadcast messageLogString: String, + messageId: MessageId, + messageTimestampIso: String, conversationId: ConversationId?, remainingAttempts: Int, retry: suspend (remainingAttempts: Int) -> Either @@ -410,21 +418,33 @@ internal class MessageSenderImpl internal constructor( logger.w( "Proteus $action Failure: { \"message\" : \"${messageLogString}\", \"errorInfo\" : \"${failure}\" }" ) - messageSendFailureHandler - .handleClientsHaveChangedFailure(failure, conversationId) - .flatMap { - if (remainingAttempts > 0) { - logger.w( - "Retrying (remaining attempts: $remainingAttempts) after Proteus $action " + - "Failure: { \"message\" : \"${messageLogString}\"}" - ) - retry(remainingAttempts - 1) - } else { - logger.e( - "No remaining attempts to retry after Proteus $action " + - "Failure: { \"message\" : \"${messageLogString}\"}" - ) - Either.Left(failure) + handleLegalHoldChanges(conversationId, messageTimestampIso) { + messageSendFailureHandler + .handleClientsHaveChangedFailure(failure, conversationId) + } + .flatMap { legalHoldEnabled -> + when { + legalHoldEnabled -> { + logger.w( + "Legal hold enabled, no retry after Proteus $action " + + "Failure: { \"message\" : \"${messageLogString}\", \"errorInfo\" : \"${failure}\" }" + ) + Either.Left(LegalHoldEnabledForConversationFailure(messageId)) + } + remainingAttempts > 0 -> { + logger.w( + "Retrying (remaining attempts: $remainingAttempts) after Proteus $action " + + "Failure: { \"message\" : \"${messageLogString}\", \"errorInfo\" : \"${failure}\" }" + ) + retry(remainingAttempts - 1) + } + else -> { + logger.e( + "No remaining attempts to retry after Proteus $action " + + "Failure: { \"message\" : \"${messageLogString}\", \"errorInfo\" : \"${failure}\" }" + ) + Either.Left(failure) + } } } .onFailure { @@ -443,6 +463,14 @@ internal class MessageSenderImpl internal constructor( } } + private suspend fun handleLegalHoldChanges( + conversationId: ConversationId?, + messageTimestampIso: String, + handleClientsHaveChangedFailure: suspend () -> Either + ) = + if (conversationId == null) handleClientsHaveChangedFailure().map { false } + else legalHoldHandler.handleMessageSendFailure(conversationId, messageTimestampIso, handleClientsHaveChangedFailure) + private fun getBroadcastParams( selfUserId: UserId, selfClientId: ClientId, diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/handler/legalhold/LegalHoldHandler.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/handler/legalhold/LegalHoldHandler.kt index d2fe61709c8..a62a3c28610 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/handler/legalhold/LegalHoldHandler.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/handler/legalhold/LegalHoldHandler.kt @@ -40,6 +40,7 @@ import com.wire.kalium.logic.sync.ObserveSyncStateUseCase import com.wire.kalium.logic.sync.receiver.conversation.message.MessageUnpackResult import com.wire.kalium.logic.util.TriggerBuffer import com.wire.kalium.util.DateTimeUtil +import com.wire.kalium.util.DateTimeUtil.minusMilliseconds import com.wire.kalium.util.KaliumDispatcher import com.wire.kalium.util.KaliumDispatcherImpl import kotlinx.coroutines.CoroutineScope @@ -52,6 +53,11 @@ internal interface LegalHoldHandler { suspend fun handleEnable(legalHoldEnabled: Event.User.LegalHoldEnabled): Either suspend fun handleDisable(legalHoldDisabled: Event.User.LegalHoldDisabled): Either suspend fun handleNewMessage(message: MessageUnpackResult.ApplicationMessage, live: Boolean): Either + suspend fun handleMessageSendFailure( + conversationId: ConversationId, + messageTimestampIso: String, + handleFailure: suspend () -> Either + ): Either } @Suppress("LongParameterList") @@ -115,18 +121,43 @@ internal class LegalHoldHandlerImpl internal constructor( } override suspend fun handleNewMessage(message: MessageUnpackResult.ApplicationMessage, live: Boolean): Either { + val systemMessageTimestampIso = minusMilliseconds(message.timestampIso, 1) val isStatusChangedForConversation = when (val legalHoldStatus = message.content.legalHoldStatus) { Conversation.LegalHoldStatus.ENABLED, Conversation.LegalHoldStatus.DISABLED -> - handleForConversation(message.conversationId, legalHoldStatus, message.timestampIso) + handleForConversation(message.conversationId, legalHoldStatus, systemMessageTimestampIso) else -> false } if (isStatusChangedForConversation) { - if (live) handleUpdatedConversations(listOf(message.conversationId), message.timestampIso) // handle it right away + if (live) handleUpdatedConversations(listOf(message.conversationId), systemMessageTimestampIso) // handle it right away else bufferedUpdatedConversationIds.add(message.conversationId) // buffer and handle after sync } return Either.Right(Unit) } + override suspend fun handleMessageSendFailure( + conversationId: ConversationId, + messageTimestampIso: String, + handleFailure: suspend () -> Either, + ): Either = + membersHavingLegalHoldClient(conversationId).flatMap { membersHavingLegalHoldClientBefore -> + handleFailure().flatMap { + val systemMessageTimestampIso = minusMilliseconds(messageTimestampIso, 1) + membersHavingLegalHoldClient(conversationId).map { membersHavingLegalHoldClientAfter -> + val newStatus = + if (membersHavingLegalHoldClientAfter.isEmpty()) Conversation.LegalHoldStatus.DISABLED + else Conversation.LegalHoldStatus.ENABLED + val isStatusChangedForConversation = handleForConversation(conversationId, newStatus, systemMessageTimestampIso) + (membersHavingLegalHoldClientBefore - membersHavingLegalHoldClientAfter).forEach { + legalHoldSystemMessagesHandler.handleDisabledForUser(it, systemMessageTimestampIso) + } + (membersHavingLegalHoldClientAfter - membersHavingLegalHoldClientBefore).forEach { + legalHoldSystemMessagesHandler.handleEnabledForUser(it, systemMessageTimestampIso) + } + isStatusChangedForConversation && newStatus == Conversation.LegalHoldStatus.ENABLED + } + } + } + private suspend fun processEvent(selfUserId: UserId, userId: UserId) { if (selfUserId == userId) { userConfigRepository.deleteLegalHoldRequest() diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/prekey/MessageSendFailureHandlerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/prekey/MessageSendFailureHandlerTest.kt index cc6e99c14fd..06a5d3890c6 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/prekey/MessageSendFailureHandlerTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/prekey/MessageSendFailureHandlerTest.kt @@ -21,11 +21,14 @@ package com.wire.kalium.logic.data.prekey import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.NetworkFailure import com.wire.kalium.logic.StorageFailure +import com.wire.kalium.logic.data.client.ClientMapper +import com.wire.kalium.logic.data.client.remote.ClientRemoteRepository import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.conversation.ConversationRepository import com.wire.kalium.logic.data.message.MessageRepository import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.data.user.UserRepository +import com.wire.kalium.logic.di.MapperProvider import com.wire.kalium.logic.failure.ProteusSendMessageFailure import com.wire.kalium.logic.feature.message.MessageSendFailureHandler import com.wire.kalium.logic.feature.message.MessageSendFailureHandlerImpl @@ -37,6 +40,7 @@ import com.wire.kalium.logic.test_util.TestNetworkException import com.wire.kalium.logic.util.arrangement.repository.ClientRepositoryArrangement import com.wire.kalium.logic.util.arrangement.repository.ClientRepositoryArrangementImpl import com.wire.kalium.logic.util.shouldFail +import com.wire.kalium.network.api.base.authenticated.client.SimpleClientResponse import com.wire.kalium.persistence.dao.message.MessageEntity import io.mockative.Mock import io.mockative.any @@ -51,6 +55,7 @@ import kotlinx.coroutines.test.runTest import okio.IOException import kotlin.test.Test import kotlin.test.assertEquals +import com.wire.kalium.network.api.base.model.UserId as UserIdDTO class MessageSendFailureHandlerTest { @@ -59,30 +64,39 @@ class MessageSendFailureHandlerTest { val (arrangement, messageSendFailureHandler) = Arrangement() .arrange { withFetchUsersByIdSuccess() - withStoreMapOfUserToClientId(Either.Right(Unit)) + withFetchOtherUserClients(Either.Right(emptyMap())) + withStoreUserClientListAndRemoveRedundantClients(Either.Right(Unit)) } val failureData = ProteusSendMessageFailure(mapOf(arrangement.userOne, arrangement.userTwo), mapOf(), mapOf(), null) messageSendFailureHandler.handleClientsHaveChangedFailure(failureData, null) + + verify(arrangement.userRepository) + .suspendFunction(arrangement.userRepository::fetchUsersByIds) + .with(eq(failureData.missingClientsOfUsers.keys)) + .wasInvoked(once) } @Test - fun givenMissingContactsAndClients_whenHandlingClientsHaveChangedFailureThenClientsShouldBeAddedToContacts() = runTest { + fun givenMissingClients_whenHandlingClientsHaveChangedFailure_thenSimpleClientsDataShouldBeFetchedAndAddedToContacts() = runTest { val (arrangement, messageSendFailureHandler) = Arrangement() .arrange { withFetchUsersByIdSuccess() - withStoreMapOfUserToClientId(Either.Right(Unit)) + withFetchOtherUserClients(Either.Right(mapOf(userOneDTO, userTwoDTO))) + withStoreUserClientListAndRemoveRedundantClients(Either.Right(Unit)) } - val expected = - mapOf(arrangement.userOne.first to arrangement.userOne.second, arrangement.userTwo.first to arrangement.userTwo.second) val failureData = ProteusSendMessageFailure(mapOf(arrangement.userOne, arrangement.userTwo), mapOf(), mapOf(), null) messageSendFailureHandler.handleClientsHaveChangedFailure(failureData, null) + verify(arrangement.clientRemoteRepository) + .suspendFunction(arrangement.clientRemoteRepository::fetchOtherUserClients) + .with(eq(listOf(arrangement.userOne.first, arrangement.userTwo.first))) + .wasInvoked(once) verify(arrangement.clientRepository) - .suspendFunction(arrangement.clientRepository::storeMapOfUserToClientId) - .with(eq(expected)) + .suspendFunction(arrangement.clientRepository::storeUserClientListAndRemoveRedundantClients) + .with(eq(arrangement.userOneInsertClientParams + arrangement.userTwoInsertClientParams)) .wasInvoked(once) } @@ -101,13 +115,29 @@ class MessageSendFailureHandlerTest { assertEquals(Either.Left(failure), result) } + @Test + fun givenRepositoryFailsToFetchClients_whenHandlingClientsHaveChangedFailure_thenFailureShouldBePropagated() = runTest { + val failure = NetworkFailure.ServerMiscommunication(TestNetworkException.generic) + val (arrangement, messageSendFailureHandler) = Arrangement() + .arrange { + withFetchUsersByIdSuccess() + withFetchOtherUserClients(Either.Left(failure)) + } + val failureData = ProteusSendMessageFailure(mapOf(arrangement.userOne), mapOf(), mapOf(), null) + + val result = messageSendFailureHandler.handleClientsHaveChangedFailure(failureData, null) + result.shouldFail() + assertEquals(Either.Left(failure), result) + } + @Test fun givenRepositoryFailsToAddClientsToContacts_whenHandlingClientsHaveChangedFailure_thenFailureShouldBePropagated() = runTest { val failure = StorageFailure.Generic(IOException()) val (arrangement, messageSendFailureHandler) = Arrangement() .arrange { withFetchUsersByIdSuccess() - withStoreMapOfUserToClientId(Either.Left(failure)) + withFetchOtherUserClients(Either.Right(mapOf(userOneDTO))) + withStoreUserClientListAndRemoveRedundantClients(Either.Left(failure)) } val failureData = ProteusSendMessageFailure(mapOf(arrangement.userOne), mapOf(), mapOf(), null) @@ -183,7 +213,8 @@ class MessageSendFailureHandlerTest { .arrange { withRemoveClientsAndReturnUsersWithNoClients(Either.Right(listOf(userOne.first))) withFetchUsersByIdSuccess() - withStoreMapOfUserToClientId(Either.Right(Unit)) + withFetchOtherUserClients(Either.Right(emptyMap())) + withStoreUserClientListAndRemoveRedundantClients(Either.Right(Unit)) } val failure = ProteusSendMessageFailure( missingClientsOfUsers = mapOf(), @@ -211,7 +242,8 @@ class MessageSendFailureHandlerTest { .arrange { withRemoveClientsAndReturnUsersWithNoClients(Either.Right(emptyList())) withFetchUsersByIdSuccess() - withStoreMapOfUserToClientId(Either.Right(Unit)) + withFetchOtherUserClients(Either.Right(emptyMap())) + withStoreUserClientListAndRemoveRedundantClients(Either.Right(Unit)) } val failure = ProteusSendMessageFailure( @@ -240,7 +272,8 @@ class MessageSendFailureHandlerTest { .arrange { withRemoveClientsAndReturnUsersWithNoClients(Either.Right(listOf(userTwo.first))) withFetchUsersByIdSuccess() - withStoreMapOfUserToClientId(Either.Right(Unit)) + withFetchOtherUserClients(Either.Right(mapOf(userOneDTO))) + withStoreUserClientListAndRemoveRedundantClients(Either.Right(Unit)) } val failure = ProteusSendMessageFailure( @@ -257,9 +290,14 @@ class MessageSendFailureHandlerTest { .with(eq(mapOf(arrangement.userTwo.first to arrangement.userTwo.second))) .wasInvoked(once) + verify(arrangement.clientRemoteRepository) + .suspendFunction(arrangement.clientRemoteRepository::fetchOtherUserClients) + .with(eq(listOf(arrangement.userOne.first))) + .wasInvoked(once) + verify(arrangement.clientRepository) - .suspendFunction(arrangement.clientRepository::storeMapOfUserToClientId) - .with(eq(mapOf(arrangement.userOne.first to arrangement.userOne.second))) + .suspendFunction(arrangement.clientRepository::storeUserClientListAndRemoveRedundantClients) + .with(eq(arrangement.userOneInsertClientParams)) .wasInvoked(once) verify(arrangement.userRepository) @@ -272,7 +310,8 @@ class MessageSendFailureHandlerTest { fun givenMissingClientsError_whenAConversationIdIsProvided_thenUpdateConversationInfo() = runTest { val (arrangement, messageSendFailureHandler) = Arrangement() .arrange { - withStoreMapOfUserToClientId(Either.Right(Unit)) + withFetchOtherUserClients(Either.Right(emptyMap())) + withStoreUserClientListAndRemoveRedundantClients(Either.Right(Unit)) withFetchUsersByIdSuccess() withFetchConversation(Either.Right(Unit)) } @@ -295,7 +334,8 @@ class MessageSendFailureHandlerTest { fun givenMissingClientsError_whenNoConversationIdIsProvided_thenUpdateConversationInfo() = runTest { val (arrangement, messageSendFailureHandler) = Arrangement() .arrange { - withStoreMapOfUserToClientId(Either.Right(Unit)) + withFetchOtherUserClients(Either.Right(emptyMap())) + withStoreUserClientListAndRemoveRedundantClients(Either.Right(Unit)) withFetchUsersByIdSuccess() } val failureData = ProteusSendMessageFailure(mapOf(arrangement.userOne, arrangement.userTwo), mapOf(), mapOf(), null) @@ -346,18 +386,31 @@ class MessageSendFailureHandlerTest { @Mock val conversationRepository = mock(classOf()) + @Mock + val clientRemoteRepository = mock(classOf()) + + val clientMapper: ClientMapper = MapperProvider.clientMapper() + private val messageSendFailureHandler: MessageSendFailureHandler = MessageSendFailureHandlerImpl( userRepository, clientRepository, + clientRemoteRepository, messageRepository, messageSendingScheduler, - conversationRepository + conversationRepository, + clientMapper ) val userOne: Pair> = UserId("userId1", "anta.wire") to listOf(ClientId("clientId"), ClientId("secondClientId")) val userTwo: Pair> = UserId("userId2", "bella.wire") to listOf(ClientId("clientId2"), ClientId("secondClientId2")) + val userOneDTO: Pair> = + UserIdDTO("userId1", "anta.wire") to listOf(SimpleClientResponse("clientId"), SimpleClientResponse("secondClientId")) + val userTwoDTO: Pair> = + UserIdDTO("userId2", "bella.wire") to listOf(SimpleClientResponse("clientId2"), SimpleClientResponse("secondClientId2")) + val userOneInsertClientParams = clientMapper.toInsertClientParam(userOneDTO.second, userOneDTO.first) + val userTwoInsertClientParams = clientMapper.toInsertClientParam(userTwoDTO.second, userTwoDTO.first) val messageId = TestMessage.TEST_MESSAGE_ID val conversationId = TestConversation.ID @@ -390,6 +443,13 @@ class MessageSendFailureHandlerTest { .whenInvokedWith(any()) .thenReturn(result) } + + fun withFetchOtherUserClients(result: Either>>) = apply { + given(clientRemoteRepository) + .suspendFunction(clientRemoteRepository::fetchOtherUserClients) + .whenInvokedWith(any()) + .thenReturn(result) + } } private companion object { diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/message/MessageSenderTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/message/MessageSenderTest.kt index 3d96cd596c3..48409193530 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/message/MessageSenderTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/message/MessageSenderTest.kt @@ -41,6 +41,8 @@ import com.wire.kalium.logic.data.message.SessionEstablisher import com.wire.kalium.logic.data.prekey.UsersWithoutSessions import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.data.user.UserRepository +import com.wire.kalium.logic.failure.LegalHoldEnabledForConversationFailure +import com.wire.kalium.logic.failure.ProteusSendMessageFailure import com.wire.kalium.logic.feature.message.MessageSenderTest.Arrangement.Companion.FEDERATION_MESSAGE_FAILURE import com.wire.kalium.logic.feature.message.MessageSenderTest.Arrangement.Companion.MESSAGE_SENT_TIME import com.wire.kalium.logic.feature.message.MessageSenderTest.Arrangement.Companion.TEST_MEMBER_2 @@ -51,10 +53,12 @@ import com.wire.kalium.logic.framework.TestConversation import com.wire.kalium.logic.framework.TestMessage import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.sync.SyncManager +import com.wire.kalium.logic.sync.receiver.handler.legalhold.LegalHoldHandler import com.wire.kalium.logic.util.arrangement.mls.StaleEpochVerifierArrangement import com.wire.kalium.logic.util.arrangement.mls.StaleEpochVerifierArrangementImpl import com.wire.kalium.logic.util.shouldFail import com.wire.kalium.logic.util.shouldSucceed +import com.wire.kalium.logic.util.thenReturnSequentially import com.wire.kalium.network.api.base.authenticated.message.MLSMessageApi import com.wire.kalium.network.api.base.model.ErrorResponse import com.wire.kalium.network.exceptions.KaliumException @@ -76,6 +80,7 @@ import kotlinx.coroutines.test.runTest import kotlinx.datetime.Instant import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertIs import kotlin.time.Duration class MessageSenderTest { @@ -884,6 +889,111 @@ class MessageSenderTest { } } + @Test + fun givenProteusSendMessageFailure_WhenSendingMessage_ThenHandleFailureProperly() { + // given + val failure = ProteusSendMessageFailure(emptyMap(), emptyMap(), emptyMap(), emptyMap()) + val message = TestMessage.TEXT_MESSAGE + val (arrangement, messageSender) = arrange { + withSendProteusMessage() + withSendEnvelope(Either.Left(failure), Either.Right(MessageSent(MESSAGE_SENT_TIME))) // to avoid loop - fail then succeed + withPromoteMessageToSentUpdatingServerTime() + withHandleLegalHoldMessageSendFailure(Either.Right(false)) + withHandleClientsHaveChangedFailure() + } + arrangement.testScope.runTest { + // when + messageSender.sendMessage(message) + // then + verify(arrangement.messageSendFailureHandler) + .suspendFunction(arrangement.messageSendFailureHandler::handleClientsHaveChangedFailure) + .with(eq(failure), eq(message.conversationId)) + .wasInvoked() + verify(arrangement.legalHoldHandler) + .suspendFunction(arrangement.legalHoldHandler::handleMessageSendFailure) + .with(eq(message.conversationId), eq(message.date), anything()) + .wasInvoked() + } + } + + @Test + fun givenProteusSendMessageFailure_WhenBroadcastingMessage_ThenHandleFailureProperly() { + // given + val failure = ProteusSendMessageFailure(emptyMap(), emptyMap(), emptyMap(), emptyMap()) + val message = TestMessage.BROADCAST_MESSAGE + val (arrangement, messageSender) = arrange { + withSendProteusMessage() + withAllRecipients(listOf(Arrangement.TEST_RECIPIENT_1) to listOf()) + withCreateOutgoingBroadcastEnvelope() + withBroadcastEnvelope(Either.Left(failure), Either.Right(TestMessage.TEST_DATE_STRING)) // to avoid loop - fail then succeed + withHandleLegalHoldMessageSendFailure(Either.Right(false)) + withHandleClientsHaveChangedFailure() + } + arrangement.testScope.runTest { + // when + messageSender.broadcastMessage(message, BroadcastMessageTarget.AllUsers(100)) + // then + verify(arrangement.messageSendFailureHandler) + .suspendFunction(arrangement.messageSendFailureHandler::handleClientsHaveChangedFailure) + .with(eq(failure), eq(null)) + .wasInvoked() + verify(arrangement.legalHoldHandler) + .suspendFunction(arrangement.legalHoldHandler::handleMessageSendFailure) + .with(anything(), anything(), anything()) + .wasNotInvoked() + } + } + + @Test + fun givenProteusSendMessageFailureAndLegalHoldEnabledForConversation_WhenSendingMessage_ThenDoNotRetrySendingAfterHandlingFailure() { + // given + val failure = ProteusSendMessageFailure(emptyMap(), emptyMap(), emptyMap(), emptyMap()) + val message = TestMessage.TEXT_MESSAGE + val (arrangement, messageSender) = arrange { + withSendProteusMessage() + withSendEnvelope(Either.Left(failure), Either.Right(MessageSent(MESSAGE_SENT_TIME))) // to avoid loop - fail then succeed + withHandleLegalHoldMessageSendFailure(Either.Right(true)) + withHandleClientsHaveChangedFailure() + } + arrangement.testScope.runTest { + // when + val result = messageSender.sendMessage(message) + // then + result.shouldFail() { + assertIs(it) + assertEquals(message.id, it.messageId) + } + verify(arrangement.messageRepository) + .suspendFunction(arrangement.messageRepository::sendEnvelope) + .with(eq(message.conversationId), anything(), anything()) + .wasInvoked(exactly = once) + } + } + + @Test + fun givenProteusSendMessageFailureAndLegalHoldNotEnabledForConversation_WhenSendingMessage_ThenRetrySendingAfterHandlingFailure() { + // given + val failure = ProteusSendMessageFailure(emptyMap(), emptyMap(), emptyMap(), emptyMap()) + val message = TestMessage.TEXT_MESSAGE + val (arrangement, messageSender) = arrange { + withSendProteusMessage() + withSendEnvelope(Either.Left(failure), Either.Right(MessageSent(MESSAGE_SENT_TIME))) // to avoid loop - fail then succeed + withPromoteMessageToSentUpdatingServerTime() + withHandleLegalHoldMessageSendFailure(Either.Right(false)) + withHandleClientsHaveChangedFailure() + } + arrangement.testScope.runTest { + // when + val result = messageSender.sendMessage(message) + // then + result.shouldSucceed() + verify(arrangement.messageRepository) + .suspendFunction(arrangement.messageRepository::sendEnvelope) + .with(eq(message.conversationId), anything(), anything()) + .wasInvoked(exactly = twice) + } + } + private class Arrangement(private val block: Arrangement.() -> Unit): StaleEpochVerifierArrangement by StaleEpochVerifierArrangementImpl() { @@ -917,6 +1027,9 @@ class MessageSenderTest { @Mock val selfDeleteMessageSenderHandler = mock(EphemeralMessageDeletionHandler::class) + @Mock + val legalHoldHandler = mock(LegalHoldHandler::class) + val testScope = TestScope() private val messageSendingInterceptor = object : MessageSendingInterceptor { @@ -933,6 +1046,7 @@ class MessageSenderTest { mlsConversationRepository = mlsConversationRepository, syncManager = syncManager, messageSendFailureHandler = messageSendFailureHandler, + legalHoldHandler = legalHoldHandler, sessionEstablisher = sessionEstablisher, messageEnvelopeCreator = messageEnvelopeCreator, mlsMessageCreator = mlsMessageCreator, @@ -1015,6 +1129,13 @@ class MessageSenderTest { .thenReturn(result) } + fun withBroadcastEnvelope(vararg result: Either) = apply { + given(messageRepository) + .suspendFunction(messageRepository::broadcastEnvelope) + .whenInvokedWith(anything(), anything()) + .thenReturnSequentially(*result) + } + fun withCreateOutgoingMlsMessage(failing: Boolean = false) = apply { given(mlsMessageCreator) .suspendFunction(mlsMessageCreator::createOutgoingMLSMessage) @@ -1029,6 +1150,13 @@ class MessageSenderTest { .thenReturn(result) } + fun withSendEnvelope(vararg result: Either) = apply { + given(messageRepository) + .suspendFunction(messageRepository::sendEnvelope) + .whenInvokedWith(anything(), anything(), anything()) + .thenReturnSequentially(*result) + } + fun withSendOutgoingMlsMessage( result: Either = Either.Right(MessageSent(MESSAGE_SENT_TIME)), times: Int = Int.MAX_VALUE @@ -1126,6 +1254,23 @@ class MessageSenderTest { .thenReturn(Either.Right(Unit)) } + fun withHandleLegalHoldMessageSendFailure(result: Either = Either.Right(false)) = apply { + given(legalHoldHandler) + .suspendFunction(legalHoldHandler::handleMessageSendFailure) + .whenInvokedWith(anything(), anything(), anything()) + .then { _, _, handleFailure -> + handleFailure() // simulate the handler calling the handleFailure function + result + } + } + + fun withHandleClientsHaveChangedFailure(result: Either = Either.Right(Unit)) = apply { + given(messageSendFailureHandler) + .suspendFunction(messageSendFailureHandler::handleClientsHaveChangedFailure) + .whenInvokedWith(anything(), anything()) + .thenReturn(result) + } + companion object { fun arrange(configuration: Arrangement.() -> Unit) = Arrangement(configuration).arrange() diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/handler/legalhold/LegalHoldHandlerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/handler/legalhold/LegalHoldHandlerTest.kt index 8acbe380534..f4201323656 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/handler/legalhold/LegalHoldHandlerTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/handler/legalhold/LegalHoldHandlerTest.kt @@ -17,6 +17,7 @@ */ package com.wire.kalium.logic.sync.receiver.handler.legalhold +import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.configuration.UserConfigRepository import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.conversation.Conversation @@ -38,6 +39,10 @@ import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.sync.ObserveSyncStateUseCase import com.wire.kalium.logic.sync.receiver.conversation.message.MessageUnpackResult import com.wire.kalium.logic.test_util.TestKaliumDispatcher +import com.wire.kalium.logic.util.shouldFail +import com.wire.kalium.logic.util.shouldSucceed +import com.wire.kalium.logic.util.thenReturnSequentially +import com.wire.kalium.util.DateTimeUtil.minusMilliseconds import com.wire.kalium.util.DateTimeUtil.toIsoDateTimeString import com.wire.kalium.util.KaliumDispatcher import io.mockative.Mock @@ -60,6 +65,7 @@ import kotlinx.coroutines.test.setMain import kotlinx.datetime.Instant import kotlin.test.BeforeTest import kotlin.test.Test +import kotlin.test.assertEquals class LegalHoldHandlerTest { @@ -419,7 +425,7 @@ class LegalHoldHandlerTest { .wasInvoked() } @Test - fun givenConversation_whenHandlingNewMessageWithChangedLegalHold_thenUseTimestampOfThatMessageToCreateSystemMessage() = runTest { + fun givenConversation_whenHandlingNewMessageWithChangedLegalHold_thenUseTimestampOfMessageMinus1msToCreateSystemMessage() = runTest { // given val (arrangement, handler) = Arrangement() .withGetConversationsByUserIdSuccess(listOf(conversation(legalHoldStatus = Conversation.LegalHoldStatus.DISABLED))) @@ -430,12 +436,12 @@ class LegalHoldHandlerTest { // then verify(arrangement.legalHoldSystemMessagesHandler) .suspendFunction(arrangement.legalHoldSystemMessagesHandler::handleEnabledForConversation) - .with(eq(TestConversation.CONVERSATION.id), eq(message.timestampIso)) + .with(eq(TestConversation.CONVERSATION.id), eq(minusMilliseconds(message.timestampIso, 1))) .wasInvoked() } @OptIn(ExperimentalCoroutinesApi::class) @Test - fun givenNewMessageWithChangedLegalHoldStateAndSyncing_whenHandling_thenBufferAndHandleItWhenSyncStateIsLive() = runTest { + fun givenNewMessageWithChangedLegalHoldStateAndSyncing_whenHandlingNewMessage_thenBufferAndHandleItWhenSyncStateIsLive() = runTest { // given val syncStatesFlow = MutableStateFlow(SyncState.GatheringPendingEvents) val (arrangement, handler) = Arrangement() @@ -464,7 +470,7 @@ class LegalHoldHandlerTest { @OptIn(ExperimentalCoroutinesApi::class) @Test - fun givenNewMessageWithChangedLegalHoldStateAndSynced_whenHandling_thenHandleItRightAway() = runTest { + fun givenNewMessageWithChangedLegalHoldStateAndSynced_whenHandlingNewMessage_thenHandleItRightAway() = runTest { // given val (arrangement, handler) = Arrangement() .withGetConversationsByUserIdSuccess(listOf(conversation(legalHoldStatus = Conversation.LegalHoldStatus.DISABLED))) @@ -484,6 +490,145 @@ class LegalHoldHandlerTest { .wasInvoked() } + @Test + fun givenHandleMessageSendFailureFails_whenHandlingMessageSendFailure_thenPropagateThisFailure() = runTest { + // given + val conversationId = TestConversation.CONVERSATION.id + val failure = CoreFailure.Unknown(null) + val timestampIso = "2022-03-30T15:36:00.000Z" + val handleFailure: () -> Either = { Either.Left(failure) } + val (arrangement, handler) = Arrangement() + .arrange() + // when + val result = handler.handleMessageSendFailure(conversationId, timestampIso, handleFailure) + // then + result.shouldFail() { + assertEquals(failure, it) + } + } + + @Test + fun givenLegalHoldEnabledForConversation_whenHandlingMessageSendFailure_thenHandleItProperlyAndReturnTrue() = runTest { + // given + val conversationId = TestConversation.CONVERSATION.id + val timestampIso = "2022-03-30T15:36:00.000Z" + val handleFailure: () -> Either = { Either.Right(Unit) } + val membersHavingLegalHoldClientBefore = emptyList() + val membersHavingLegalHoldClientAfter = listOf(TestUser.OTHER_USER_ID) + val (arrangement, handler) = Arrangement() + .withMembersHavingLegalHoldClientSuccess(membersHavingLegalHoldClientBefore, membersHavingLegalHoldClientAfter) + .withUpdateLegalHoldStatusSuccess(true) + .arrange() + // when + val result = handler.handleMessageSendFailure(conversationId, timestampIso, handleFailure) + // then + result.shouldSucceed() { + assertEquals(true, it) + } + verify(arrangement.legalHoldSystemMessagesHandler) + .suspendFunction(arrangement.legalHoldSystemMessagesHandler::handleEnabledForConversation) + .with(eq(conversationId), any()) + .wasInvoked() + } + + @Test + fun givenLegalHoldDisabledForConversation_whenHandlingMessageSendFailure_thenHandleItProperlyAndReturnFalse() = runTest { + // given + val conversationId = TestConversation.CONVERSATION.id + val timestampIso = "2022-03-30T15:36:00.000Z" + val handleFailure: () -> Either = { Either.Right(Unit) } + val membersHavingLegalHoldClientBefore = listOf(TestUser.OTHER_USER_ID) + val membersHavingLegalHoldClientAfter = emptyList() + val (arrangement, handler) = Arrangement() + .withMembersHavingLegalHoldClientSuccess(membersHavingLegalHoldClientBefore, membersHavingLegalHoldClientAfter) + .withUpdateLegalHoldStatusSuccess(true) + .arrange() + // when + val result = handler.handleMessageSendFailure(conversationId, timestampIso, handleFailure) + // then + result.shouldSucceed() { + assertEquals(false, it) + } + verify(arrangement.legalHoldSystemMessagesHandler) + .suspendFunction(arrangement.legalHoldSystemMessagesHandler::handleDisabledForConversation) + .with(eq(conversationId), any()) + .wasInvoked() + } + @Test + fun givenLegalHoldChangedForConversation_whenHandlingMessageSendFailure_thenUseTimestampOfMessageMinus1msForSystemMessage() = runTest { + // given + val conversationId = TestConversation.CONVERSATION.id + val timestampIso = "2022-03-30T15:36:00.000Z" + val handleFailure: () -> Either = { Either.Right(Unit) } + val membersHavingLegalHoldClientBefore = emptyList() + val membersHavingLegalHoldClientAfter = listOf(TestUser.OTHER_USER_ID) + val (arrangement, handler) = Arrangement() + .withMembersHavingLegalHoldClientSuccess(membersHavingLegalHoldClientBefore, membersHavingLegalHoldClientAfter) + .withUpdateLegalHoldStatusSuccess(true) + .arrange() + // when + val result = handler.handleMessageSendFailure(conversationId, timestampIso, handleFailure) + // then + verify(arrangement.legalHoldSystemMessagesHandler) + .suspendFunction(arrangement.legalHoldSystemMessagesHandler::handleEnabledForConversation) + .with(eq(conversationId), eq(minusMilliseconds(timestampIso, 1))) + .wasInvoked() + } + + @Test + fun givenLegalHoldNotChangedForConversation_whenHandlingMessageSendFailure_thenHandleItProperlyAndReturnFalse() = runTest { + // given + val conversationId = TestConversation.CONVERSATION.id + val timestampIso = "2022-03-30T15:36:00.000Z" + val handleFailure: () -> Either = { Either.Right(Unit) } + val membersHavingLegalHoldClientBefore = listOf(TestUser.OTHER_USER_ID) + val membersHavingLegalHoldClientAfter = listOf(TestUser.OTHER_USER_ID) + val (arrangement, handler) = Arrangement() + .withMembersHavingLegalHoldClientSuccess(membersHavingLegalHoldClientBefore, membersHavingLegalHoldClientAfter) + .withUpdateLegalHoldStatusSuccess(false) + .arrange() + // when + val result = handler.handleMessageSendFailure(conversationId, timestampIso, handleFailure) + // then + result.shouldSucceed() { + assertEquals(false, it) + } + verify(arrangement.legalHoldSystemMessagesHandler) + .suspendFunction(arrangement.legalHoldSystemMessagesHandler::handleDisabledForConversation) + .with(eq(conversationId), any()) + .wasNotInvoked() + verify(arrangement.legalHoldSystemMessagesHandler) + .suspendFunction(arrangement.legalHoldSystemMessagesHandler::handleEnabledForConversation) + .with(eq(conversationId), any()) + .wasNotInvoked() + } + + @Test + fun givenLegalHoldChangedForMembers_whenHandlingMessageSendFailure_thenHandleItProperly() = runTest { + // given + val conversationId = TestConversation.CONVERSATION.id + val timestampIso = "2022-03-30T15:36:00.000Z" + val handleFailure: () -> Either = { Either.Right(Unit) } + val membersHavingLegalHoldClientBefore = listOf(TestUser.OTHER_USER_ID) + val membersHavingLegalHoldClientAfter = listOf(TestUser.OTHER_USER_ID_2) + val (arrangement, handler) = Arrangement() + .withMembersHavingLegalHoldClientSuccess(membersHavingLegalHoldClientBefore, membersHavingLegalHoldClientAfter) + .withUpdateLegalHoldStatusSuccess() + .arrange() + // when + val result = handler.handleMessageSendFailure(conversationId, timestampIso, handleFailure) + // then + result.shouldSucceed() + verify(arrangement.legalHoldSystemMessagesHandler) + .suspendFunction(arrangement.legalHoldSystemMessagesHandler::handleDisabledForUser) + .with(eq(TestUser.OTHER_USER_ID)) + .wasInvoked() + verify(arrangement.legalHoldSystemMessagesHandler) + .suspendFunction(arrangement.legalHoldSystemMessagesHandler::handleEnabledForUser) + .with(eq(TestUser.OTHER_USER_ID_2)) + .wasInvoked() + } + private class Arrangement { @Mock @@ -574,6 +719,12 @@ class LegalHoldHandlerTest { .whenInvokedWith(any()) .thenReturn(Either.Right(result)) } + fun withMembersHavingLegalHoldClientSuccess(vararg result: List) = apply { + given(membersHavingLegalHoldClient) + .suspendFunction(membersHavingLegalHoldClient::invoke) + .whenInvokedWith(any()) + .thenReturnSequentially(*result.map { Either.Right(it) }.toTypedArray()) + } fun withUpdateLegalHoldStatusSuccess(isChanged: Boolean = true) = apply { given(conversationRepository) .suspendFunction(conversationRepository::updateLegalHoldStatus) diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ClientRepositoryArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ClientRepositoryArrangement.kt index f911f194af8..be89cdc6ea6 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ClientRepositoryArrangement.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/ClientRepositoryArrangement.kt @@ -23,6 +23,7 @@ import com.wire.kalium.logic.data.client.OtherUserClient import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.functional.Either +import com.wire.kalium.persistence.dao.client.InsertClientParam import io.mockative.Mock import io.mockative.any import io.mockative.given @@ -47,6 +48,10 @@ internal interface ClientRepositoryArrangement { result: Either, mapUserToClientId: Matcher>> = any() ) + fun withStoreUserClientListAndRemoveRedundantClients( + result: Either, + clients: Matcher> = any() + ) } internal open class ClientRepositoryArrangementImpl : ClientRepositoryArrangement { @@ -97,4 +102,14 @@ internal open class ClientRepositoryArrangementImpl : ClientRepositoryArrangemen .whenInvokedWith(mapUserToClientId) .thenReturn(result) } + + override fun withStoreUserClientListAndRemoveRedundantClients( + result: Either, + clients: Matcher> + ) { + given(clientRepository) + .suspendFunction(clientRepository::storeUserClientListAndRemoveRedundantClients) + .whenInvokedWith(any()) + .thenReturn(result) + } }