Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weโ€™ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PM-15804, PM-17130: Add logic to monitor when the screen on state to ensure the vault locks properly #4618

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package com.x8bit.bitwarden.data.vault.manager

import android.content.BroadcastReceiver
import android.content.Context
import android.content.Intent
import android.content.IntentFilter
import com.bitwarden.core.InitOrgCryptoRequest
import com.bitwarden.core.InitUserCryptoMethod
import com.bitwarden.core.InitUserCryptoRequest
Expand Down Expand Up @@ -50,6 +54,8 @@ import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
import java.time.Clock
import java.util.concurrent.ConcurrentHashMap
import kotlin.time.Duration.Companion.minutes

/**
Expand All @@ -62,6 +68,7 @@ private const val MAXIMUM_INVALID_UNLOCK_ATTEMPTS = 5
*/
@Suppress("TooManyFunctions", "LongParameterList")
class VaultLockManagerImpl(
private val clock: Clock,
private val authDiskSource: AuthDiskSource,
private val authSdkSource: AuthSdkSource,
private val vaultSdkSource: VaultSdkSource,
Expand All @@ -70,13 +77,15 @@ class VaultLockManagerImpl(
private val userLogoutManager: UserLogoutManager,
private val trustedDeviceManager: TrustedDeviceManager,
dispatcherManager: DispatcherManager,
context: Context,
) : VaultLockManager {
private val unconfinedScope = CoroutineScope(dispatcherManager.unconfined)

/**
* This [Map] tracks all active timeout [Job]s that are running using the user ID as the key.
* This [Map] tracks all active timeout [Job]s that are running and their associated data using
* the user ID as the key.
*/
private val userIdTimerJobMap = mutableMapOf<String, Job>()
private val userIdTimerJobMap: MutableMap<String, TimeoutJobData> = ConcurrentHashMap()

private val activeUserId: String? get() = authDiskSource.userState?.activeUserId

Expand All @@ -96,6 +105,10 @@ class VaultLockManagerImpl(
observeUserSwitchingChanges()
observeVaultTimeoutChanges()
observeUserLogoutResults()
context.registerReceiver(
ScreenStateBroadcastReceiver(),
IntentFilter(Intent.ACTION_SCREEN_ON),
)
}

override fun isVaultUnlocked(userId: String): Boolean =
Expand Down Expand Up @@ -363,7 +376,7 @@ class VaultLockManagerImpl(

private fun handleOnForeground() {
val userId = activeUserId ?: return
userIdTimerJobMap[userId]?.cancel()
userIdTimerJobMap.remove(key = userId)?.job?.cancel()
}

private fun observeUserSwitchingChanges() {
Expand Down Expand Up @@ -459,7 +472,7 @@ class VaultLockManagerImpl(
currentActiveUserId: String,
) {
// Make sure to clear the now-active user's timeout job.
userIdTimerJobMap[currentActiveUserId]?.cancel()
userIdTimerJobMap.remove(key = currentActiveUserId)?.job?.cancel()
// Check if the user's timeout action should be performed as we switch away.
checkForVaultTimeout(
userId = previousActiveUserId,
Expand Down Expand Up @@ -529,7 +542,7 @@ class VaultLockManagerImpl(
handleTimeoutActionWithDelay(
userId = userId,
vaultTimeoutAction = vaultTimeoutAction,
delayInMs = vaultTimeout
delayMs = vaultTimeout
.vaultTimeoutInMinutes
?.minutes
?.inWholeMilliseconds
Expand All @@ -542,20 +555,26 @@ class VaultLockManagerImpl(
}

/**
* Performs the [VaultTimeoutAction] for the given [userId] after the [delayInMs] has passed.
* Performs the [VaultTimeoutAction] for the given [userId] after the [delayMs] has passed.
*
* @see handleTimeoutAction
*/
private fun handleTimeoutActionWithDelay(
userId: String,
vaultTimeoutAction: VaultTimeoutAction,
delayInMs: Long,
delayMs: Long,
) {
userIdTimerJobMap[userId]?.cancel()
userIdTimerJobMap[userId] = unconfinedScope.launch {
delay(timeMillis = delayInMs)
handleTimeoutAction(userId = userId, vaultTimeoutAction = vaultTimeoutAction)
}
userIdTimerJobMap.remove(key = userId)?.job?.cancel()
userIdTimerJobMap[userId] = TimeoutJobData(
job = unconfinedScope.launch {
delay(timeMillis = delayMs)
userIdTimerJobMap.remove(key = userId)
handleTimeoutAction(userId = userId, vaultTimeoutAction = vaultTimeoutAction)
},
vaultTimeoutAction = vaultTimeoutAction,
startTimeMs = clock.millis(),
durationMs = delayMs,
)
}

/**
Expand Down Expand Up @@ -601,6 +620,37 @@ class VaultLockManagerImpl(
return (accounts.find { it.userId == userId }?.isLoggedIn) == false
}

/**
* A custom [BroadcastReceiver] that listens for when the screen is powered on and restarts the
* vault timeout jobs to ensure they complete at the correct time.
*
* This is necessary because the [delay] function in a coroutine will not keep accurate time
* when the screen is off. We do not cancel the job when the screen is off, this allows the
* job to complete as-soon-as possible if the screen is powered off for an extended period.
*/
private inner class ScreenStateBroadcastReceiver : BroadcastReceiver() {
override fun onReceive(context: Context, intent: Intent) {
userIdTimerJobMap.map { (userId, data) ->
handleTimeoutActionWithDelay(
userId = userId,
vaultTimeoutAction = data.vaultTimeoutAction,
delayMs = data.durationMs - (clock.millis() - data.startTimeMs)
.coerceAtLeast(minimumValue = 0L),
)
}
}
}

/**
* A wrapper class containing all relevant data concerning a timeout action [Job].
*/
private data class TimeoutJobData(
val job: Job,
val vaultTimeoutAction: VaultTimeoutAction,
val startTimeMs: Long,
val durationMs: Long,
)

/**
* Helper sealed class which denotes the reason to check the vault timeout.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ object VaultManagerModule {
@Provides
@Singleton
fun provideVaultLockManager(
@ApplicationContext context: Context,
clock: Clock,
authDiskSource: AuthDiskSource,
authSdkSource: AuthSdkSource,
vaultSdkSource: VaultSdkSource,
Expand All @@ -81,6 +83,8 @@ object VaultManagerModule {
trustedDeviceManager: TrustedDeviceManager,
): VaultLockManager =
VaultLockManagerImpl(
context = context,
clock = clock,
authDiskSource = authDiskSource,
authSdkSource = authSdkSource,
vaultSdkSource = vaultSdkSource,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package com.x8bit.bitwarden.data.vault.manager

import android.content.BroadcastReceiver
import android.content.Context
import android.content.Intent
import app.cash.turbine.test
import com.bitwarden.core.InitOrgCryptoRequest
import com.bitwarden.core.InitUserCryptoMethod
Expand Down Expand Up @@ -36,6 +39,7 @@ import io.mockk.every
import io.mockk.just
import io.mockk.mockk
import io.mockk.runs
import io.mockk.slot
import io.mockk.verify
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.async
Expand All @@ -49,11 +53,18 @@ import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertFalse
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
import java.time.Clock
import java.time.Instant
import java.time.ZoneOffset
import java.time.ZonedDateTime

@OptIn(ExperimentalCoroutinesApi::class)
@Suppress("LargeClass")
class VaultLockManagerTest {
private val broadcastReceiver = slot<BroadcastReceiver>()
private val context: Context = mockk {
every { registerReceiver(capture(broadcastReceiver), any()) } returns null
}
private val fakeAuthDiskSource = FakeAuthDiskSource()
private val fakeAppStateManager = FakeAppStateManager()
private val authSdkSource: AuthSdkSource = mockk {
Expand Down Expand Up @@ -88,6 +99,8 @@ class VaultLockManagerTest {
private val fakeDispatcherManager = FakeDispatcherManager(unconfined = testDispatcher)

private val vaultLockManager: VaultLockManager = VaultLockManagerImpl(
context = context,
clock = FIXED_CLOCK,
authDiskSource = fakeAuthDiskSource,
authSdkSource = authSdkSource,
vaultSdkSource = vaultSdkSource,
Expand All @@ -98,6 +111,53 @@ class VaultLockManagerTest {
dispatcherManager = fakeDispatcherManager,
)

@Test
fun `broadcast receiver should be registered on initialization`() {
verify(exactly = 1) {
context.registerReceiver(any(), any())
}
}

@Test
fun `broadcast intent should reset active job`() {
setAccountTokens()
fakeAuthDiskSource.userState = MOCK_USER_STATE

// Setup state as unlocked
mutableVaultTimeoutStateFlow.value = VaultTimeout.OneMinute
mutableVaultTimeoutActionStateFlow.value = VaultTimeoutAction.LOCK
fakeAppStateManager.appForegroundState = AppForegroundState.FOREGROUNDED
verifyUnlockedVaultBlocking(userId = USER_ID)
assertTrue(vaultLockManager.isVaultUnlocked(USER_ID))

// Background the app
fakeAppStateManager.appForegroundState = AppForegroundState.BACKGROUNDED

// Advance by 30 seconds (half of what is required to lock the app)
testDispatcher.scheduler.advanceTimeBy(delayTimeMillis = 30 * 1000L)

// Still unlocked
assertTrue(vaultLockManager.isVaultUnlocked(USER_ID))

// Receive the screen on event
broadcastReceiver.captured.onReceive(context, Intent())

// Still unlocked
assertTrue(vaultLockManager.isVaultUnlocked(USER_ID))

// Because the test clock is fixed, this should mean that we need to advance the clock a
// full minute to get the vault to lock.
testDispatcher.scheduler.advanceTimeBy(delayTimeMillis = 30 * 1000L)

// Still unlocked
assertTrue(vaultLockManager.isVaultUnlocked(USER_ID))

testDispatcher.scheduler.advanceTimeBy(delayTimeMillis = 31 * 1000L)

// Finally locked
assertFalse(vaultLockManager.isVaultUnlocked(USER_ID))
}

@Test
fun `vaultStateEventFlow should emit Locked event when vault state changes to locked`() =
runTest {
Expand Down Expand Up @@ -1587,6 +1647,11 @@ class VaultLockManagerTest {
}
}

private val FIXED_CLOCK: Clock = Clock.fixed(
Instant.parse("2023-10-27T12:00:00Z"),
ZoneOffset.UTC,
)

private const val USER_ID = "mockId-1"

private val MOCK_TIMEOUTS = VaultTimeout.Type.entries.map {
Expand Down