diff --git a/ui/revenuecatui/src/main/kotlin/com/revenuecat/purchases/ui/revenuecatui/customercenter/viewmodel/CustomerCenterViewModel.kt b/ui/revenuecatui/src/main/kotlin/com/revenuecat/purchases/ui/revenuecatui/customercenter/viewmodel/CustomerCenterViewModel.kt index 64efb25c9e..5114d65515 100644 --- a/ui/revenuecatui/src/main/kotlin/com/revenuecat/purchases/ui/revenuecatui/customercenter/viewmodel/CustomerCenterViewModel.kt +++ b/ui/revenuecatui/src/main/kotlin/com/revenuecat/purchases/ui/revenuecatui/customercenter/viewmodel/CustomerCenterViewModel.kt @@ -504,24 +504,29 @@ internal class CustomerCenterViewModelImpl( override fun onNavigationButtonPressed(context: Context, onDismiss: () -> Unit) { val currentState = _state.value + // Handle special case for promotional offers first if (currentState is CustomerCenterState.Success && currentState.promotionalOfferData != null) { dismissPromotionalOffer(context, currentState.promotionalOfferData.originalPath) return } - val buttonType = state.value.navigationButtonType - if (buttonType == CustomerCenterState.NavigationButtonType.CLOSE) { - onDismiss() - return - } + + val navigationButtonType = state.value.navigationButtonType + _state.update { state -> when { + // For BACK button: Return to main screen without losing loaded data state is CustomerCenterState.Success && - state.navigationButtonType == CustomerCenterState.NavigationButtonType.BACK -> { + navigationButtonType == CustomerCenterState.NavigationButtonType.BACK -> state.resetToMainScreen() - } + // For all other cases (including CLOSE): Reset to NotLoaded else -> CustomerCenterState.NotLoaded } } + + // Call onDismiss only for the CLOSE button + if (navigationButtonType == CustomerCenterState.NavigationButtonType.CLOSE) { + onDismiss() + } } override suspend fun loadCustomerCenter() { diff --git a/ui/revenuecatui/src/test/kotlin/com/revenuecat/purchases/ui/revenuecatui/customercenter/data/CustomerCenterViewModelTests.kt b/ui/revenuecatui/src/test/kotlin/com/revenuecat/purchases/ui/revenuecatui/customercenter/data/CustomerCenterViewModelTests.kt index f04370c451..46850994c7 100644 --- a/ui/revenuecatui/src/test/kotlin/com/revenuecat/purchases/ui/revenuecatui/customercenter/data/CustomerCenterViewModelTests.kt +++ b/ui/revenuecatui/src/test/kotlin/com/revenuecat/purchases/ui/revenuecatui/customercenter/data/CustomerCenterViewModelTests.kt @@ -1,12 +1,9 @@ package com.revenuecat.purchases.ui.revenuecatui.customercenter.data -import android.content.Context -import androidx.compose.material3.ColorScheme import androidx.test.ext.junit.runners.AndroidJUnit4 import com.revenuecat.purchases.CustomerInfo import com.revenuecat.purchases.EntitlementInfos import com.revenuecat.purchases.ExperimentalPreviewRevenueCatPurchasesAPI -import com.revenuecat.purchases.OwnershipType import com.revenuecat.purchases.PeriodType import com.revenuecat.purchases.PurchasesAreCompletedBy import com.revenuecat.purchases.Store @@ -14,39 +11,26 @@ import com.revenuecat.purchases.SubscriptionInfo import com.revenuecat.purchases.VerificationResult import com.revenuecat.purchases.customercenter.CustomerCenterConfigData import com.revenuecat.purchases.customercenter.CustomerCenterConfigData.HelpPath -import com.revenuecat.purchases.customercenter.CustomerCenterConfigData.HelpPath.PathDetail -import com.revenuecat.purchases.customercenter.CustomerCenterConfigData.HelpPath.PathType -import com.revenuecat.purchases.customercenter.CustomerCenterConfigData.Localization import com.revenuecat.purchases.customercenter.CustomerCenterConfigData.Screen -import com.revenuecat.purchases.customercenter.CustomerCenterConfigData.Support import com.revenuecat.purchases.models.Transaction -import com.revenuecat.purchases.paywalls.PaywallData -import com.revenuecat.purchases.ui.revenuecatui.customercenter.viewmodel.CustomerCenterViewModel import com.revenuecat.purchases.ui.revenuecatui.customercenter.viewmodel.CustomerCenterViewModelImpl -import com.revenuecat.purchases.ui.revenuecatui.data.PaywallState -import com.revenuecat.purchases.ui.revenuecatui.data.PurchasesImpl import com.revenuecat.purchases.ui.revenuecatui.data.PurchasesType import com.revenuecat.purchases.ui.revenuecatui.data.testdata.TestData -import com.revenuecat.purchases.ui.revenuecatui.utils.DateFormatter -import com.revenuecat.purchases.ui.revenuecatui.utils.DefaultDateFormatter import io.mockk.Runs import io.mockk.clearAllMocks import io.mockk.coEvery import io.mockk.every import io.mockk.just import io.mockk.mockk -import junit.framework.TestCase.fail +import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.cancel import kotlinx.coroutines.launch import kotlinx.coroutines.test.runTest -import kotlinx.serialization.SerialName -import kotlinx.serialization.Serializable import org.assertj.core.api.Assertions.assertThat import org.junit.After import org.junit.Before import org.junit.Test import org.junit.runner.RunWith -import java.util.Collections import java.util.Date import java.util.Locale @@ -206,4 +190,125 @@ class CustomerCenterViewModelTests { job.join() } + + @Test + fun `onNavigationButtonPressed handles CLOSE and BACK buttons correctly`() = runTest { + // Setup basic mocks + every { customerInfo.activeSubscriptions } returns setOf("product-id") + every { customerInfo.nonSubscriptionTransactions } returns emptyList() + every { customerInfo.entitlements } returns EntitlementInfos( + emptyMap(), + VerificationResult.VERIFIED, + ) + every { customerInfo.subscriptionsByProductIdentifier } returns emptyMap() + + // Setup customer center data + coEvery { purchases.awaitCustomerCenterConfigData() } returns configData + coEvery { purchases.awaitCustomerInfo(any()) } returns customerInfo + coEvery { purchases.awaitGetProduct(any(), any()) } returns null + + // Setup screen with a path + val testPath = HelpPath( + id = "test_path_id", + title = "Test Path", + type = HelpPath.PathType.CUSTOM_URL + ) + + val managementScreen = Screen( + type = Screen.ScreenType.MANAGEMENT, + title = "Management", + subtitle = null, + paths = listOf(testPath) + ) + + every { configData.getManagementScreen() } returns managementScreen + + // Create the ViewModel + val model = CustomerCenterViewModelImpl( + purchases = purchases, + locale = Locale.US, + colorScheme = TestData.Constants.currentColorScheme, + isDarkMode = false + ) + + // Track state changes + val initialLoadingCompleted = CompletableDeferred() + val onDismissCalled = CompletableDeferred() + val stateResetToNotLoaded = CompletableDeferred() + val stateResetToMainScreen = CompletableDeferred() + val onDismissShouldNotBeCalled = CompletableDeferred() + + val job = launch { + model.state.collect { state -> + when (state) { + is CustomerCenterState.Success -> { + // Track initial load completion + if (!initialLoadingCompleted.isCompleted) { + initialLoadingCompleted.complete(true) + } + + // Track when state is reset to main screen (BACK button) + if (state.navigationButtonType == CustomerCenterState.NavigationButtonType.CLOSE && + !stateResetToMainScreen.isCompleted) { + stateResetToMainScreen.complete(true) + } + } + is CustomerCenterState.NotLoaded -> { + // Track when state is reset to NotLoaded (CLOSE button) + if (onDismissCalled.isCompleted && !stateResetToNotLoaded.isCompleted) { + stateResetToNotLoaded.complete(true) + } + } + else -> {} + } + } + } + + // Wait for initial setup to complete + initialLoadingCompleted.await() + + // Test CLOSE button + model.onNavigationButtonPressed(mockk()) { + onDismissCalled.complete(true) + } + + // Wait for state to be reset to NotLoaded + stateResetToNotLoaded.await() + + // Reload the state for BACK button test + model.loadCustomerCenter() + initialLoadingCompleted.await() + + // Set up state for BACK button test by displaying a feedback survey + model.pathButtonPressed( + mockk(), + HelpPath( + id = "feedback_id", + title = "Feedback", + type = HelpPath.PathType.CUSTOM_URL, + feedbackSurvey = HelpPath.PathDetail.FeedbackSurvey( + title = "Feedback", + options = emptyList() + ) + ), + null + ) + + // Test BACK button - verify onDismiss is not called + model.onNavigationButtonPressed(mockk()) { + onDismissShouldNotBeCalled.complete(true) + } + + // Wait for state to be reset to main screen + stateResetToMainScreen.await() + + // Verify the state transitions + assertThat(onDismissCalled.isCompleted).isTrue() + assertThat(onDismissShouldNotBeCalled.isCompleted).isFalse() + assertThat(stateResetToNotLoaded.isCompleted).isTrue() + assertThat(stateResetToMainScreen.isCompleted).isTrue() + + // Cancel the collection job + job.cancel() + } }