Skip to content

Commit

Permalink
Fix value restoration after removeState call in `RetainedStateHolde…
Browse files Browse the repository at this point in the history
…r` (#1931)

I discovered an issue while running additional tests with
`RetainedStateHolder` where the handling was incorrect.

```kotlin
@test
fun removedStateShouldNotBeRestored() {
  var increment = 0
  val screen = Screens.Screen1
  var restorableStateHolder: RetainedStateHolder? = null
  var restorableNumberOnScreen1 = -1
  restorationTester.setContent {
    val holder = rememberRetainedStateHolder()
    restorableStateHolder = holder
    holder.RetainedStateProvider(screen.name) {
      restorableNumberOnScreen1 = rememberRetained { increment++ }
    }
  }

  composeTestRule.runOnIdle {
    assertThat(restorableNumberOnScreen1).isEqualTo(0)
    restorableNumberOnScreen1 = -1
    restorableStateHolder!!.removeState(screen.name)
  }

  restorationTester.emulateRetainedInstanceStateRestore()

  composeTestRule.runOnIdle { assertThat(restorableNumberOnScreen1).isEqualTo(1) }
}
```

In this test scenario, the `RetainedStateHolder` holds a child
`RetainedStateRegistry` for Screen1. Since `removeState` is called, the
corresponding `Entry.shouldSave` is set to false, preventing a call to
`saveValue` in the subsequent `DisposableEffect`.

#### RetainedStateHolder.kt (before fix)
```kotlin
@composable
override fun RetainedStateProvider(key: String, content: @composable (() -> Unit)) {
  CompositionLocalProvider(LocalRetainedStateRegistry provides registry) {
    ReusableContent(key) {
      val entry = remember { Entry() }
      val childRegistry = rememberRetained(key = key) { RetainedStateRegistry() }
      CompositionLocalProvider(
        LocalRetainedStateRegistry provides childRegistry,
        LocalCanRetainChecker provides CanRetainChecker.Always,
        content = content,
      )
      DisposableEffect(Unit) {
        entries[key] = entry
        onDispose {
          if (entry.shouldSave) {
            registry.saveValue(key)
          }
          entries -= key
        }
      }
    }
  }
}
```

However, when the parent’s overall `RetainedStateRegistry.saveAll()` is
invoked, the childRegistry—lacking a specifically set
`LocalCanRetainChecker`—does not consider the `Entry.shouldSave` flag
and ends up saving the value.

To fix this, I modified the code to provide a `LocalCanRetainChecker`
for the childRegistry. Additionally, I have added few test logics for
`RetainedStateHolder`.

I wish this had been handled better in #1794; sorry for the hassle, and
please take a look when you have time!
  • Loading branch information
vulpeszerda authored Feb 13, 2025
1 parent 132526b commit 865814d
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package com.slack.circuit.foundation

import androidx.compose.runtime.Composable
import androidx.compose.runtime.DisposableEffect
import androidx.compose.runtime.SideEffect
import androidx.compose.runtime.key
import androidx.compose.runtime.remember
import com.slack.circuit.foundation.internal.withCompositionLocalProvider
Expand Down Expand Up @@ -32,36 +33,36 @@ private class RetainedStateHolderImpl : RetainedStateHolder, RetainedStateRegist

private val registry = RetainedStateRegistry()

private val entries = mutableMapOf<String, Entry>()
private val entryCheckers = mutableMapOf<String, EntryCanRetainChecker>()

@Composable
override fun <T> RetainedStateProvider(key: String, content: @Composable (() -> T)): T {
return withCompositionLocalProvider(LocalRetainedStateRegistry provides registry) {
key(key) {
val entry = remember { Entry() }
val childRegistry = rememberRetained(key = key) { RetainedStateRegistry() }
withCompositionLocalProvider(
LocalRetainedStateRegistry provides childRegistry,
LocalCanRetainChecker provides CanRetainChecker.Always,
content = content,
)
.also {
DisposableEffect(Unit) {
entries[key] = entry
onDispose {
if (entry.shouldSave) {
val entryCanRetainChecker = rememberEntryCanRetainChecker()
withCompositionLocalProvider(LocalCanRetainChecker provides entryCanRetainChecker) {
val childRegistry = rememberRetained(key = key) { RetainedStateRegistry() }
withCompositionLocalProvider(
LocalRetainedStateRegistry provides childRegistry,
LocalCanRetainChecker provides CanRetainChecker.Always,
content = content,
)
.also {
DisposableEffect(Unit) {
entryCheckers[key] = entryCanRetainChecker
onDispose {
registry.saveValue(key)
entryCheckers -= key
}
entries -= key
}
}
}
}
}
}
}

override fun removeState(key: String) {
val entry = entries[key]
val entry = entryCheckers[key]
if (entry != null) {
entry.shouldSave = false
} else {
Expand Down Expand Up @@ -91,6 +92,25 @@ private class RetainedStateHolderImpl : RetainedStateHolder, RetainedStateRegist
override fun forgetUnclaimedValues() {
registry.forgetUnclaimedValues()
}
}

@Composable
private fun rememberEntryCanRetainChecker(): EntryCanRetainChecker {
val parentChecker = LocalCanRetainChecker.current ?: CanRetainChecker.Always
val checker = remember { EntryCanRetainChecker(parentChecker) }
SideEffect { checker.update(parentChecker) }
return checker
}

private class EntryCanRetainChecker(
private var parentChecker: CanRetainChecker,
var shouldSave: Boolean = true,
) : CanRetainChecker {

fun update(parentChecker: CanRetainChecker) {
this.parentChecker = parentChecker
}

private data class Entry(var shouldSave: Boolean = true)
override fun canRetain(registry: RetainedStateRegistry): Boolean =
shouldSave && parentChecker.canRetain(registry)
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import com.slack.circuit.retained.RetainedStateHolder
import com.slack.circuit.retained.RetainedStateRegistry
import com.slack.circuit.retained.rememberRetained
import com.slack.circuit.retained.rememberRetainedStateHolder
import kotlin.random.Random
import leakcanary.DetectLeaksAfterTestSuccess.Companion.detectLeaksAfterTestSuccessWrapping
import org.junit.Rule
import org.junit.Test
Expand Down Expand Up @@ -305,7 +306,7 @@ class RetainedStateHolderTest {
}
}

composeTestRule.runOnIdle { showFirstPage = false }
showFirstPage = false

composeTestRule.runOnIdle {
val savedData = registry.saveAll()
Expand All @@ -315,6 +316,35 @@ class RetainedStateHolderTest {

@Test
fun saveNothingWhenCanRetainCheckerReturnsFalse() {
var showFirstPage by mutableStateOf(true)
val registry = RetainedStateRegistry(emptyMap())
val canRetainChecker = CanRetainChecker { false }
composeTestRule.setContent {
CompositionLocalProvider(LocalRetainedStateRegistry provides registry) {
val holder = rememberRetainedStateHolder()
CompositionLocalProvider(LocalCanRetainChecker provides canRetainChecker) {
holder.RetainedStateProvider(showFirstPage.toString()) {
rememberRetained { Random.nextInt() }
}
}
}
}

composeTestRule.runOnIdle {
val savedData = registry.saveAll()
assertThat(savedData).isEqualTo(emptyMap<String, List<Any?>>())
}

showFirstPage = false

composeTestRule.runOnIdle {
val savedData = registry.saveAll()
assertThat(savedData).isEqualTo(emptyMap<String, List<Any?>>())
}
}

@Test
fun alwaysReinitializeWhenCanRetainCheckerReturnsFalse() {
var increment = 0
var restorableNumber = -1
val registry = RetainedStateRegistry(emptyMap())
Expand Down Expand Up @@ -351,7 +381,12 @@ class RetainedStateHolderTest {
composeTestRule.waitForIdle()
screen = Screens.Screen1

composeTestRule.runOnIdle { assertThat(restorableNumber).isEqualTo(2) }
composeTestRule.runOnIdle {
assertThat(restorableNumber).isEqualTo(2)

val savedData = registry.saveAll()
assertThat(savedData).isEqualTo(emptyMap<String, List<Any?>>())
}
}

@Test
Expand Down Expand Up @@ -396,6 +431,38 @@ class RetainedStateHolderTest {
screen = Screens.Screen1

composeTestRule.runOnIdle { assertThat(restorableNumber).isEqualTo(1) }

canRetainChecker = CanRetainChecker { false }

composeTestRule.runOnIdle {
val savedData = registry.saveAll()
assertThat(savedData).isEqualTo(emptyMap<String, List<Any?>>())
}
}

@Test
fun removedStateShouldNotBeRestored() {
var increment = 0
val screen = Screens.Screen1
var restorableStateHolder: RetainedStateHolder? = null
var restorableNumberOnScreen1 = -1
restorationTester.setContent {
val holder = rememberRetainedStateHolder()
restorableStateHolder = holder
holder.RetainedStateProvider(screen.name) {
restorableNumberOnScreen1 = rememberRetained { increment++ }
}
}

composeTestRule.runOnIdle {
assertThat(restorableNumberOnScreen1).isEqualTo(0)
restorableNumberOnScreen1 = -1
restorableStateHolder!!.removeState(screen.name)
}

restorationTester.emulateRetainedInstanceStateRestore()

composeTestRule.runOnIdle { assertThat(restorableNumberOnScreen1).isEqualTo(1) }
}

class Activity : ComponentActivity() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import androidx.compose.runtime.Composable
import androidx.compose.runtime.CompositionLocalProvider
import androidx.compose.runtime.DisposableEffect
import androidx.compose.runtime.ReusableContent
import androidx.compose.runtime.SideEffect
import androidx.compose.runtime.remember

/**
Expand Down Expand Up @@ -37,34 +38,34 @@ private class RetainedStateHolderImpl : RetainedStateHolder, RetainedStateRegist

private val registry: RetainedStateRegistry = RetainedStateRegistry()

private val entries = mutableMapOf<String, Entry>()
private val entryCheckers = mutableMapOf<String, EntryCanRetainChecker>()

@Composable
override fun RetainedStateProvider(key: String, content: @Composable (() -> Unit)) {
CompositionLocalProvider(LocalRetainedStateRegistry provides registry) {
ReusableContent(key) {
val entry = remember { Entry() }
val childRegistry = rememberRetained(key = key) { RetainedStateRegistry() }
CompositionLocalProvider(
LocalRetainedStateRegistry provides childRegistry,
LocalCanRetainChecker provides CanRetainChecker.Always,
content = content,
)
val entryCanRetainChecker = rememberEntryCanRetainChecker()
CompositionLocalProvider(LocalCanRetainChecker provides entryCanRetainChecker) {
val childRegistry = rememberRetained(key = key) { RetainedStateRegistry() }
CompositionLocalProvider(
LocalRetainedStateRegistry provides childRegistry,
LocalCanRetainChecker provides CanRetainChecker.Always,
content = content,
)
}
DisposableEffect(Unit) {
entries[key] = entry
entryCheckers[key] = entryCanRetainChecker
onDispose {
if (entry.shouldSave) {
registry.saveValue(key)
}
entries -= key
registry.saveValue(key)
entryCheckers -= key
}
}
}
}
}

override fun removeState(key: String) {
val entry = entries[key]
val entry = entryCheckers[key]
if (entry != null) {
entry.shouldSave = false
} else {
Expand Down Expand Up @@ -94,6 +95,25 @@ private class RetainedStateHolderImpl : RetainedStateHolder, RetainedStateRegist
override fun forgetUnclaimedValues() {
registry.forgetUnclaimedValues()
}
}

@Composable
private fun rememberEntryCanRetainChecker(): EntryCanRetainChecker {
val parentChecker = LocalCanRetainChecker.current ?: CanRetainChecker.Always
val checker = remember { EntryCanRetainChecker(parentChecker) }
SideEffect { checker.update(parentChecker) }
return checker
}

private class EntryCanRetainChecker(
private var parentChecker: CanRetainChecker,
var shouldSave: Boolean = true,
) : CanRetainChecker {

fun update(parentChecker: CanRetainChecker) {
this.parentChecker = parentChecker
}

private data class Entry(var shouldSave: Boolean = true)
override fun canRetain(registry: RetainedStateRegistry): Boolean =
shouldSave && parentChecker.canRetain(registry)
}

0 comments on commit 865814d

Please sign in to comment.