Skip to content

Commit

Permalink
feat: Indicate user with valid E2EI certificate (WPB-3228) (#2335)
Browse files Browse the repository at this point in the history
* feat: Indicate user with valid E2EI certificate

* Code style fix

* feat: Indicate user with valid E2EI certificate: review comments

* Review updates

* Review fixes

---------

Co-authored-by: Mojtaba Chenani <[email protected]>
  • Loading branch information
borichellow and mchenani authored Jan 10, 2024
1 parent 0696b8e commit 52f2033
Show file tree
Hide file tree
Showing 12 changed files with 625 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.event.Event
import com.wire.kalium.logic.data.event.Event.Conversation.MLSWelcome
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.data.id.IdMapper
import com.wire.kalium.logic.data.id.QualifiedClientID
import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.data.id.toCrypto
import com.wire.kalium.logic.data.id.toDao
import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysMapper
Expand Down Expand Up @@ -118,6 +120,11 @@ interface MLSConversationRepository {
): Either<CoreFailure, Unit>

suspend fun getClientIdentity(clientId: ClientId): Either<CoreFailure, WireIdentity>
suspend fun getUserIdentity(userId: UserId): Either<CoreFailure, List<WireIdentity>>
suspend fun getMembersIdentities(
conversationId: ConversationId,
userIds: List<UserId>
): Either<CoreFailure, Map<UserId, List<WireIdentity>>>
}

private enum class CommitStrategy {
Expand Down Expand Up @@ -551,6 +558,41 @@ internal class MLSConversationDataSource(
}
}

override suspend fun getUserIdentity(userId: UserId) =
wrapStorageRequest { conversationDAO.getMLSGroupIdByUserId(userId.toDao()) }.flatMap { mlsGroupId ->
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
mlsClient.getUserIdentities(
mlsGroupId,
listOf(userId.toCrypto())
)[userId.value]!!
}
}
}

override suspend fun getMembersIdentities(
conversationId: ConversationId,
userIds: List<UserId>
): Either<CoreFailure, Map<UserId, List<WireIdentity>>> =
wrapStorageRequest {
conversationDAO.getMLSGroupIdByConversationId(conversationId.toDao())!!
}.flatMap { mlsGroupId ->
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
val userIdsAndIdentity = mutableMapOf<UserId, List<WireIdentity>>()

mlsClient.getUserIdentities(mlsGroupId, userIds.map { it.toCrypto() })
.forEach { (userIdValue, identities) ->
userIds.firstOrNull { it.value == userIdValue }?.also {
userIdsAndIdentity[it] = identities
}
}

userIdsAndIdentity
}
}
}

private suspend fun retryOnCommitFailure(
groupID: GroupID,
retryOnClientMismatch: Boolean = true,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Wire
* Copyright (C) 2023 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.feature.e2ei.usecase

import com.wire.kalium.cryptography.WireIdentity
import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.feature.e2ei.CertificateStatus
import com.wire.kalium.logic.feature.e2ei.PemCertificateDecoder
import com.wire.kalium.logic.functional.fold

/**
* This use case is used to get the e2ei certificates of all the users in Conversation.
* Return [Map] where keys are [UserId] and values - nullable [CertificateStatus] of corresponding user.
*/
interface GetMembersE2EICertificateStatusesUseCase {
suspend operator fun invoke(conversationId: ConversationId, userIds: List<UserId>): Map<UserId, CertificateStatus?>
}

class GetMembersE2EICertificateStatusesUseCaseImpl internal constructor(
private val mlsConversationRepository: MLSConversationRepository,
private val pemCertificateDecoder: PemCertificateDecoder
) : GetMembersE2EICertificateStatusesUseCase {
override suspend operator fun invoke(conversationId: ConversationId, userIds: List<UserId>): Map<UserId, CertificateStatus?> =
mlsConversationRepository.getMembersIdentities(conversationId, userIds).fold(
{ mapOf() },
{
it.mapValues { (_, identities) ->
identities.getUserCertificateStatus(pemCertificateDecoder)
}
}
)
}

/**
* @return null if list is empty;
* [CertificateStatus.REVOKED] if any certificate is revoked;
* [CertificateStatus.EXPIRED] if any certificate is expired;
* [CertificateStatus.VALID] otherwise.
*/
fun List<WireIdentity>.getUserCertificateStatus(pemCertificateDecoder: PemCertificateDecoder): CertificateStatus? {
val certificates = this.map { pemCertificateDecoder.decode(it.certificate, it.status) }
return if (certificates.isEmpty()) {
null
} else if (certificates.any { it.status == CertificateStatus.REVOKED }) {
CertificateStatus.REVOKED
} else if (certificates.any { it.status == CertificateStatus.EXPIRED }) {
CertificateStatus.EXPIRED
} else {
CertificateStatus.VALID
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Wire
* Copyright (C) 2023 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.feature.e2ei.usecase

import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.feature.e2ei.CertificateStatus
import com.wire.kalium.logic.feature.e2ei.PemCertificateDecoder
import com.wire.kalium.logic.functional.fold

/**
* This use case is used to get the e2ei certificate status of specific user
*/
interface GetUserE2eiCertificateStatusUseCase {
suspend operator fun invoke(userId: UserId): GetUserE2eiCertificateStatusResult
}

class GetUserE2eiCertificateStatusUseCaseImpl internal constructor(
private val mlsConversationRepository: MLSConversationRepository,
private val pemCertificateDecoder: PemCertificateDecoder
) : GetUserE2eiCertificateStatusUseCase {
override suspend operator fun invoke(userId: UserId): GetUserE2eiCertificateStatusResult =
mlsConversationRepository.getUserIdentity(userId).fold(
{
GetUserE2eiCertificateStatusResult.Failure.NotActivated
},
{ identities ->
identities.getUserCertificateStatus(pemCertificateDecoder)?.let {
GetUserE2eiCertificateStatusResult.Success(it)
} ?: GetUserE2eiCertificateStatusResult.Failure.NotActivated
}
)
}

sealed class GetUserE2eiCertificateStatusResult {
class Success(val status: CertificateStatus) : GetUserE2eiCertificateStatusResult()
sealed class Failure : GetUserE2eiCertificateStatusResult() {
data object NotActivated : Failure()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ import com.wire.kalium.logic.feature.e2ei.usecase.EnrollE2EIUseCase
import com.wire.kalium.logic.feature.e2ei.usecase.EnrollE2EIUseCaseImpl
import com.wire.kalium.logic.feature.e2ei.usecase.GetE2eiCertificateUseCase
import com.wire.kalium.logic.feature.e2ei.usecase.GetE2eiCertificateUseCaseImpl
import com.wire.kalium.logic.feature.e2ei.usecase.GetMembersE2EICertificateStatusesUseCase
import com.wire.kalium.logic.feature.e2ei.usecase.GetMembersE2EICertificateStatusesUseCaseImpl
import com.wire.kalium.logic.feature.e2ei.usecase.GetUserE2eiCertificateStatusUseCase
import com.wire.kalium.logic.feature.e2ei.usecase.GetUserE2eiCertificateStatusUseCaseImpl
import com.wire.kalium.logic.feature.message.MessageSender
import com.wire.kalium.logic.feature.publicuser.GetAllContactsUseCase
import com.wire.kalium.logic.feature.publicuser.GetAllContactsUseCaseImpl
Expand Down Expand Up @@ -113,10 +117,21 @@ class UserScope internal constructor(
private val pemCertificateDecoderImpl by lazy { PemCertificateDecoderImpl() }
val getPublicAsset: GetAvatarAssetUseCase get() = GetAvatarAssetUseCaseImpl(assetRepository, userRepository)
val enrollE2EI: EnrollE2EIUseCase get() = EnrollE2EIUseCaseImpl(e2EIRepository)
val getE2EICertificate: GetE2eiCertificateUseCase get() = GetE2eiCertificateUseCaseImpl(
mlsConversationRepository = mlsConversationRepository,
pemCertificateDecoder = pemCertificateDecoderImpl
)
val getE2EICertificate: GetE2eiCertificateUseCase
get() = GetE2eiCertificateUseCaseImpl(
mlsConversationRepository = mlsConversationRepository,
pemCertificateDecoder = pemCertificateDecoderImpl
)
val getUserE2eiCertificateStatus: GetUserE2eiCertificateStatusUseCase
get() = GetUserE2eiCertificateStatusUseCaseImpl(
mlsConversationRepository = mlsConversationRepository,
pemCertificateDecoder = pemCertificateDecoderImpl
)
val getMembersE2EICertificateStatuses: GetMembersE2EICertificateStatusesUseCase
get() = GetMembersE2EICertificateStatusesUseCaseImpl(
mlsConversationRepository = mlsConversationRepository,
pemCertificateDecoder = pemCertificateDecoderImpl
)
val deleteAsset: DeleteAssetUseCase get() = DeleteAssetUseCaseImpl(assetRepository)
val setUserHandle: SetUserHandleUseCase get() = SetUserHandleUseCase(accountRepository, validateUserHandleUseCase, syncManager)
val getAllKnownUsers: GetAllContactsUseCase get() = GetAllContactsUseCaseImpl(userRepository)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1265,6 +1265,71 @@ class MLSConversationRepositoryTest {
.wasInvoked(once)
}

@Test
fun givenUserId_whenGetMLSGroupIdByUserIdSucceed_thenReturnsIdentities() = runTest {
val groupId = "some_group"
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetMLSClientSuccessful()
.withGetUserIdentitiesReturn(
mapOf(
TestUser.USER_ID.value to listOf(WIRE_IDENTITY),
"some_other_user_id" to listOf(WIRE_IDENTITY.copy(clientId = "another_client_id")),
)
)
.withGetMLSGroupIdByUserIdReturns(groupId)
.arrange()

assertEquals(Either.Right(listOf(WIRE_IDENTITY)), mlsConversationRepository.getUserIdentity(TestUser.USER_ID))

verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(eq(groupId), any())
.wasInvoked(once)

verify(arrangement.conversationDAO)
.suspendFunction(arrangement.conversationDAO::getMLSGroupIdByUserId)
.with(any())
.wasInvoked(once)
}

@Test
fun givenConversationId_whenGetMLSGroupIdByConversationIdSucceed_thenReturnsIdentities() = runTest {
val groupId = "some_group"
val member1 = TestUser.USER_ID
val member2 = TestUser.USER_ID.copy(value = "member_2_id")
val member3 = TestUser.USER_ID.copy(value = "member_3_id")
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetMLSClientSuccessful()
.withGetUserIdentitiesReturn(
mapOf(
member1.value to listOf(WIRE_IDENTITY),
member2.value to listOf(WIRE_IDENTITY.copy(clientId = "member_2_client_id"))
)
)
.withGetMLSGroupIdByConversationIdReturns(groupId)
.arrange()

assertEquals(
Either.Right(
mapOf(
member1 to listOf(WIRE_IDENTITY),
member2 to listOf(WIRE_IDENTITY.copy(clientId = "member_2_client_id"))
)
),
mlsConversationRepository.getMembersIdentities(TestConversation.ID, listOf(member1, member2, member3))
)

verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(eq(groupId), any())
.wasInvoked(once)

verify(arrangement.conversationDAO)
.suspendFunction(arrangement.conversationDAO::getMLSGroupIdByConversationId)
.with(any())
.wasInvoked(once)
}

private class Arrangement {

@Mock
Expand Down Expand Up @@ -1512,6 +1577,27 @@ class MLSConversationRepositoryTest {
.thenReturn(verificationStatus)
}

fun withGetMLSGroupIdByUserIdReturns(result: String?) = apply {
given(conversationDAO)
.suspendFunction(conversationDAO::getMLSGroupIdByUserId)
.whenInvokedWith(anything())
.thenReturn(result)
}

fun withGetMLSGroupIdByConversationIdReturns(result: String?) = apply {
given(conversationDAO)
.suspendFunction(conversationDAO::getMLSGroupIdByConversationId)
.whenInvokedWith(anything())
.thenReturn(result)
}

fun withGetUserIdentitiesReturn(identitiesMap: Map<String, List<WireIdentity>>) = apply {
given(mlsClient)
.suspendFunction(mlsClient::getUserIdentities)
.whenInvokedWith(anything(), anything())
.thenReturn(identitiesMap)
}

fun arrange() = this to MLSConversationDataSource(
TestUser.SELF.id,
keyPackageRepository,
Expand Down
Loading

0 comments on commit 52f2033

Please sign in to comment.