Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unified push unregister #3877

Merged
merged 8 commits into from
Nov 15, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,12 @@ class LoggedInPresenter @Inject constructor(
.also { Timber.tag(pusherTag.value).w("No distributors available") }
.also {
// In this case, consider the push provider is chosen.
pushService.selectPushProvider(matrixClient, pushProvider)
pushService.selectPushProvider(matrixClient.sessionId, pushProvider)
}
.also { pusherRegistrationState.value = AsyncData.Failure(PusherRegistrationFailure.NoDistributorsAvailable()) }
pushService.registerWith(matrixClient, pushProvider, distributor)
} else {
val currentPushDistributor = currentPushProvider.getCurrentDistributor(matrixClient)
val currentPushDistributor = currentPushProvider.getCurrentDistributor(matrixClient.sessionId)
if (currentPushDistributor == null) {
Timber.tag(pusherTag.value).d("Register with the first available distributor")
val distributor = currentPushProvider.getDistributors().firstOrNull()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ class LoggedInPresenterTest {
val lambda = lambdaRecorder<MatrixClient, PushProvider, Distributor, Result<Unit>> { _, _, _ ->
Result.success(Unit)
}
val selectPushProviderLambda = lambdaRecorder<MatrixClient, PushProvider, Unit> { _, _ -> }
val selectPushProviderLambda = lambdaRecorder<SessionId, PushProvider, Unit> { _, _ -> }
val sessionVerificationService = FakeSessionVerificationService(
initialSessionVerifiedStatus = SessionVerifiedStatus.Verified
)
Expand Down Expand Up @@ -408,8 +408,8 @@ class LoggedInPresenterTest {
selectPushProviderLambda.assertions()
.isCalledOnce()
.with(
// MatrixClient
any(),
// SessionId
value(A_SESSION_ID),
// PushProvider
value(pushProvider),
)
Expand Down Expand Up @@ -481,7 +481,7 @@ class LoggedInPresenterTest {
registerWithLambda: (MatrixClient, PushProvider, Distributor) -> Result<Unit> = { _, _, _ ->
Result.success(Unit)
},
selectPushProviderLambda: (MatrixClient, PushProvider) -> Unit = { _, _ -> lambdaError() },
selectPushProviderLambda: (SessionId, PushProvider) -> Unit = { _, _ -> lambdaError() },
currentPushProvider: () -> PushProvider? = { null },
setIgnoreRegistrationErrorLambda: (SessionId, Boolean) -> Unit = { _, _ -> lambdaError() },
): PushService {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class NotificationSettingsPresenter @Inject constructor(

LaunchedEffect(refreshPushProvider) {
val p = pushService.getCurrentPushProvider()
val name = p?.getCurrentDistributor(matrixClient)?.name
val name = p?.getCurrentDistributor(matrixClient.sessionId)?.name
currentDistributorName = if (name != null) {
AsyncData.Success(name)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ interface PushService {
* To be used when there is no distributor available.
*/
suspend fun selectPushProvider(
matrixClient: MatrixClient,
sessionId: SessionId,
pushProvider: PushProvider,
)

Expand Down
2 changes: 2 additions & 0 deletions libraries/push/impl/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ dependencies {
implementation(projects.libraries.matrix.api)
implementation(projects.libraries.matrixui)
implementation(projects.libraries.preferences.api)
implementation(projects.libraries.sessionStorage.api)
implementation(projects.libraries.uiStrings)
implementation(projects.libraries.troubleshoot.api)
implementation(projects.features.call.api)
Expand All @@ -64,6 +65,7 @@ dependencies {
testImplementation(libs.coroutines.test)
testImplementation(projects.libraries.matrix.test)
testImplementation(projects.libraries.preferences.test)
testImplementation(projects.libraries.sessionStorage.test)
testImplementation(projects.libraries.push.test)
testImplementation(projects.libraries.pushproviders.test)
testImplementation(projects.libraries.pushstore.test)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package io.element.android.libraries.push.impl

import com.squareup.anvil.annotations.ContributesBinding
import io.element.android.libraries.di.AppScope
import io.element.android.libraries.di.SingleIn
import io.element.android.libraries.matrix.api.MatrixClient
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.push.api.GetCurrentPushProvider
Expand All @@ -17,17 +18,27 @@ import io.element.android.libraries.push.impl.test.TestPush
import io.element.android.libraries.pushproviders.api.Distributor
import io.element.android.libraries.pushproviders.api.PushProvider
import io.element.android.libraries.pushstore.api.UserPushStoreFactory
import io.element.android.libraries.pushstore.api.clientsecret.PushClientSecretStore
import io.element.android.libraries.sessionstorage.api.observer.SessionListener
import io.element.android.libraries.sessionstorage.api.observer.SessionObserver
import kotlinx.coroutines.flow.Flow
import timber.log.Timber
import javax.inject.Inject

@ContributesBinding(AppScope::class)
@ContributesBinding(AppScope::class, boundType = PushService::class)
@SingleIn(AppScope::class)
class DefaultPushService @Inject constructor(
private val testPush: TestPush,
private val userPushStoreFactory: UserPushStoreFactory,
private val pushProviders: Set<@JvmSuppressWildcards PushProvider>,
private val getCurrentPushProvider: GetCurrentPushProvider,
) : PushService {
private val sessionObserver: SessionObserver,
private val pushClientSecretStore: PushClientSecretStore,
) : PushService, SessionListener {
init {
observeSessions()
}

override suspend fun getCurrentPushProvider(): PushProvider? {
val currentPushProvider = getCurrentPushProvider.getCurrentPushProvider()
return pushProviders.find { it.name == currentPushProvider }
Expand All @@ -47,7 +58,7 @@ class DefaultPushService @Inject constructor(
val userPushStore = userPushStoreFactory.getOrCreate(matrixClient.sessionId)
val currentPushProviderName = userPushStore.getPushProviderName()
val currentPushProvider = pushProviders.find { it.name == currentPushProviderName }
val currentDistributorValue = currentPushProvider?.getCurrentDistributor(matrixClient)?.value
val currentDistributorValue = currentPushProvider?.getCurrentDistributor(matrixClient.sessionId)?.value
if (currentPushProviderName != pushProvider.name || currentDistributorValue != distributor.value) {
// Unregister previous one if any
currentPushProvider
Expand All @@ -65,11 +76,11 @@ class DefaultPushService @Inject constructor(
}

override suspend fun selectPushProvider(
matrixClient: MatrixClient,
sessionId: SessionId,
pushProvider: PushProvider,
) {
Timber.d("Select ${pushProvider.name}")
val userPushStore = userPushStoreFactory.getOrCreate(matrixClient.sessionId)
val userPushStore = userPushStoreFactory.getOrCreate(sessionId)
userPushStore.setPushProviderName(pushProvider.name)
}

Expand All @@ -87,4 +98,31 @@ class DefaultPushService @Inject constructor(
testPush.execute(config)
return true
}

private fun observeSessions() {
sessionObserver.addListener(this)
}

override suspend fun onSessionCreated(userId: String) {
// Nothing to do
}

/**
* The session has been deleted.
* In this case, this is not necessary to unregister the pusher from the homeserver,
* but we need to do some cleanup locally.
* The current push provider may want to take action, and we need to
* cleanup the stores.
*/
override suspend fun onSessionDeleted(userId: String) {
val sessionId = SessionId(userId)
val userPushStore = userPushStoreFactory.getOrCreate(sessionId)
val currentPushProviderName = userPushStore.getPushProviderName()
val currentPushProvider = pushProviders.find { it.name == currentPushProviderName }
// Cleanup the current push provider. They may need the client secret, so delete the secret after.
currentPushProvider?.onSessionDeleted(sessionId)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix 2: let the push provider do any necessary clean up when a session is removed.

Only after cleanup the stores related to the removed session.

// Now we can safely reset the stores.
pushClientSecretStore.resetSecret(sessionId)
userPushStore.reset()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package io.element.android.libraries.push.impl

import com.google.common.truth.Truth.assertThat
import io.element.android.libraries.matrix.api.MatrixClient
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.matrix.test.AN_EXCEPTION
import io.element.android.libraries.matrix.test.A_SESSION_ID
import io.element.android.libraries.matrix.test.FakeMatrixClient
Expand All @@ -22,8 +23,12 @@ import io.element.android.libraries.pushproviders.api.PushProvider
import io.element.android.libraries.pushproviders.test.FakePushProvider
import io.element.android.libraries.pushproviders.test.aCurrentUserPushConfig
import io.element.android.libraries.pushstore.api.UserPushStoreFactory
import io.element.android.libraries.pushstore.api.clientsecret.PushClientSecretStore
import io.element.android.libraries.pushstore.test.userpushstore.FakeUserPushStore
import io.element.android.libraries.pushstore.test.userpushstore.FakeUserPushStoreFactory
import io.element.android.libraries.pushstore.test.userpushstore.clientsecret.InMemoryPushClientSecretStore
import io.element.android.libraries.sessionstorage.api.observer.SessionObserver
import io.element.android.libraries.sessionstorage.test.observer.NoOpSessionObserver
import io.element.android.tests.testutils.lambda.lambdaRecorder
import io.element.android.tests.testutils.lambda.value
import kotlinx.coroutines.flow.first
Expand Down Expand Up @@ -210,17 +215,87 @@ class DefaultPushServiceTest {
assertThat(defaultPushService.ignoreRegistrationError(A_SESSION_ID).first()).isTrue()
}

@Test
fun `onSessionCreated is noop`() = runTest {
val defaultPushService = createDefaultPushService()
defaultPushService.onSessionCreated(A_SESSION_ID.value)
}

@Test
fun `onSessionDeleted should transmit the info to the current push provider and cleanup the stores`() = runTest {
val onSessionDeletedLambda = lambdaRecorder<SessionId, Unit> { }
val aCurrentPushProvider = FakePushProvider(
name = "aCurrentPushProvider",
onSessionDeletedLambda = onSessionDeletedLambda,
)
val userPushStore = FakeUserPushStore(
pushProviderName = aCurrentPushProvider.name,
)
val pushClientSecretStore = InMemoryPushClientSecretStore()
val defaultPushService = createDefaultPushService(
pushProviders = setOf(aCurrentPushProvider),
getCurrentPushProvider = FakeGetCurrentPushProvider(currentPushProvider = aCurrentPushProvider.name),
userPushStoreFactory = FakeUserPushStoreFactory(
userPushStore = { userPushStore },
),
pushClientSecretStore = pushClientSecretStore,
)
defaultPushService.onSessionDeleted(A_SESSION_ID.value)
assertThat(userPushStore.getPushProviderName()).isNull()
assertThat(pushClientSecretStore.getSecret(A_SESSION_ID)).isNull()
onSessionDeletedLambda.assertions().isCalledOnce().with(value(A_SESSION_ID))
}

@Test
fun `onSessionDeleted when there is no push provider should just cleanup the stores`() = runTest {
val userPushStore = FakeUserPushStore(
pushProviderName = null,
)
val pushClientSecretStore = InMemoryPushClientSecretStore()
val defaultPushService = createDefaultPushService(
pushProviders = emptySet(),
getCurrentPushProvider = FakeGetCurrentPushProvider(currentPushProvider = null),
userPushStoreFactory = FakeUserPushStoreFactory(
userPushStore = { userPushStore },
),
pushClientSecretStore = pushClientSecretStore,
)
defaultPushService.onSessionDeleted(A_SESSION_ID.value)
assertThat(userPushStore.getPushProviderName()).isNull()
assertThat(pushClientSecretStore.getSecret(A_SESSION_ID)).isNull()
}

@Test
fun `selectPushProvider should store the data in the store`() = runTest {
val userPushStore = FakeUserPushStore()
val defaultPushService = createDefaultPushService(
userPushStoreFactory = FakeUserPushStoreFactory(
userPushStore = { userPushStore },
),
)
val aPushProvider = FakePushProvider(
name = "aCurrentPushProvider",
)
assertThat(userPushStore.getPushProviderName()).isNull()
defaultPushService.selectPushProvider(A_SESSION_ID, aPushProvider)
assertThat(userPushStore.getPushProviderName()).isEqualTo(aPushProvider.name)
}

private fun createDefaultPushService(
testPush: TestPush = FakeTestPush(),
userPushStoreFactory: UserPushStoreFactory = FakeUserPushStoreFactory(),
pushProviders: Set<@JvmSuppressWildcards PushProvider> = emptySet(),
getCurrentPushProvider: GetCurrentPushProvider = FakeGetCurrentPushProvider(currentPushProvider = null),
sessionObserver: SessionObserver = NoOpSessionObserver(),
pushClientSecretStore: PushClientSecretStore = InMemoryPushClientSecretStore(),
): DefaultPushService {
return DefaultPushService(
testPush = testPush,
userPushStoreFactory = userPushStoreFactory,
pushProviders = pushProviders,
getCurrentPushProvider = getCurrentPushProvider,
sessionObserver = sessionObserver,
pushClientSecretStore = pushClientSecretStore,
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class FakePushService(
Result.success(Unit)
},
private val currentPushProvider: () -> PushProvider? = { availablePushProviders.firstOrNull() },
private val selectPushProviderLambda: suspend (MatrixClient, PushProvider) -> Unit = { _, _ -> lambdaError() },
private val selectPushProviderLambda: suspend (SessionId, PushProvider) -> Unit = { _, _ -> lambdaError() },
private val setIgnoreRegistrationErrorLambda: (SessionId, Boolean) -> Unit = { _, _ -> lambdaError() },
) : PushService {
override suspend fun getCurrentPushProvider(): PushProvider? {
Expand All @@ -50,8 +50,8 @@ class FakePushService(
}
}

override suspend fun selectPushProvider(matrixClient: MatrixClient, pushProvider: PushProvider) {
selectPushProviderLambda(matrixClient, pushProvider)
override suspend fun selectPushProvider(sessionId: SessionId, pushProvider: PushProvider) {
selectPushProviderLambda(sessionId, pushProvider)
}

private val ignoreRegistrationError = MutableStateFlow(false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package io.element.android.libraries.pushproviders.api

import io.element.android.libraries.matrix.api.MatrixClient
import io.element.android.libraries.matrix.api.core.SessionId

/**
* This is the main API for this module.
Expand Down Expand Up @@ -36,13 +37,18 @@ interface PushProvider {
/**
* Return the current distributor, or null if none.
*/
suspend fun getCurrentDistributor(matrixClient: MatrixClient): Distributor?
suspend fun getCurrentDistributor(sessionId: SessionId): Distributor?

/**
* Unregister the pusher.
*/
suspend fun unregister(matrixClient: MatrixClient): Result<Unit>

/**
* To invoke when the session is deleted.
*/
suspend fun onSessionDeleted(sessionId: SessionId)

suspend fun getCurrentUserPushConfig(): CurrentUserPushConfig?

fun canRotateToken(): Boolean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import com.squareup.anvil.annotations.ContributesMultibinding
import io.element.android.libraries.core.log.logger.LoggerTag
import io.element.android.libraries.di.AppScope
import io.element.android.libraries.matrix.api.MatrixClient
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.pushproviders.api.CurrentUserPushConfig
import io.element.android.libraries.pushproviders.api.Distributor
import io.element.android.libraries.pushproviders.api.PushProvider
Expand Down Expand Up @@ -51,7 +52,7 @@ class FirebasePushProvider @Inject constructor(
)
}

override suspend fun getCurrentDistributor(matrixClient: MatrixClient) = firebaseDistributor
override suspend fun getCurrentDistributor(sessionId: SessionId) = firebaseDistributor

override suspend fun unregister(matrixClient: MatrixClient): Result<Unit> {
val pushKey = firebaseStore.getFcmToken()
Expand All @@ -63,6 +64,11 @@ class FirebasePushProvider @Inject constructor(
}
}

/**
* Nothing to clean up here.
*/
override suspend fun onSessionDeleted(sessionId: SessionId) = Unit

override suspend fun getCurrentUserPushConfig(): CurrentUserPushConfig? {
return firebaseStore.getFcmToken()?.let { fcmToken ->
CurrentUserPushConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package io.element.android.libraries.pushproviders.firebase
import com.google.common.truth.Truth.assertThat
import io.element.android.libraries.matrix.api.MatrixClient
import io.element.android.libraries.matrix.test.AN_EXCEPTION
import io.element.android.libraries.matrix.test.A_SESSION_ID
import io.element.android.libraries.matrix.test.FakeMatrixClient
import io.element.android.libraries.push.test.FakePusherSubscriber
import io.element.android.libraries.pushproviders.api.CurrentUserPushConfig
Expand Down Expand Up @@ -47,9 +48,9 @@ class FirebasePushProviderTest {
}

@Test
fun `getCurrentDistributor always return the unique distributor`() = runTest {
fun `getCurrentDistributor always returns the unique distributor`() = runTest {
val firebasePushProvider = createFirebasePushProvider()
val result = firebasePushProvider.getCurrentDistributor(FakeMatrixClient())
val result = firebasePushProvider.getCurrentDistributor(A_SESSION_ID)
assertThat(result).isEqualTo(Distributor("Firebase", "Firebase"))
}

Expand Down Expand Up @@ -176,6 +177,18 @@ class FirebasePushProviderTest {
lambda.assertions().isCalledOnce()
}

@Test
fun `canRotateToken should return true`() = runTest {
val firebasePushProvider = createFirebasePushProvider()
assertThat(firebasePushProvider.canRotateToken()).isTrue()
}

@Test
fun `onSessionDeleted should be noop`() = runTest {
val firebasePushProvider = createFirebasePushProvider()
firebasePushProvider.onSessionDeleted(A_SESSION_ID)
}

private fun createFirebasePushProvider(
firebaseStore: FirebaseStore = InMemoryFirebaseStore(),
pusherSubscriber: PusherSubscriber = FakePusherSubscriber(),
Expand Down
Loading
Loading