diff --git a/workflows/src/main/kotlin/com/r3/corda/lib/accounts/workflows/services/KeyManagementBackedAccountService.kt b/workflows/src/main/kotlin/com/r3/corda/lib/accounts/workflows/services/KeyManagementBackedAccountService.kt index e0a1720e..0875bc2c 100644 --- a/workflows/src/main/kotlin/com/r3/corda/lib/accounts/workflows/services/KeyManagementBackedAccountService.kt +++ b/workflows/src/main/kotlin/com/r3/corda/lib/accounts/workflows/services/KeyManagementBackedAccountService.kt @@ -17,6 +17,8 @@ import net.corda.core.internal.concurrent.doneFuture import net.corda.core.node.AppServiceHub import net.corda.core.node.services.CordaService import net.corda.core.node.services.queryBy +import net.corda.core.node.services.vault.MAX_PAGE_SIZE +import net.corda.core.node.services.vault.PageSpecification import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.utilities.contextLogger import java.security.PublicKey @@ -31,7 +33,8 @@ class KeyManagementBackedAccountService(val services: AppServiceHub) : AccountSe } override fun accountsForHost(host: Party): List> { - return services.vaultService.queryBy(accountBaseCriteria.and(accountHostCriteria(host))).states + return services.vaultService + .queryBy(accountBaseCriteria.and(accountHostCriteria(host)), PageSpecification(1, MAX_PAGE_SIZE)).states } override fun ourAccounts(): List> { @@ -39,7 +42,7 @@ class KeyManagementBackedAccountService(val services: AppServiceHub) : AccountSe } override fun allAccounts(): List> { - return services.vaultService.queryBy(accountBaseCriteria).states + return services.vaultService.queryBy(accountBaseCriteria, PageSpecification(1, MAX_PAGE_SIZE)).states } override fun accountInfo(id: UUID): StateAndRef? { diff --git a/workflows/src/test/kotlin/com/r3/corda/lib/accounts/workflows/test/GetAccountsFlowTests.kt b/workflows/src/test/kotlin/com/r3/corda/lib/accounts/workflows/test/GetAccountsFlowTests.kt index 9c6f4ebf..fad475f0 100644 --- a/workflows/src/test/kotlin/com/r3/corda/lib/accounts/workflows/test/GetAccountsFlowTests.kt +++ b/workflows/src/test/kotlin/com/r3/corda/lib/accounts/workflows/test/GetAccountsFlowTests.kt @@ -5,6 +5,8 @@ import com.r3.corda.lib.accounts.workflows.flows.CreateAccount import com.r3.corda.lib.accounts.workflows.flows.RequestAccountInfo import com.r3.corda.lib.accounts.workflows.flows.ShareAccountInfo import com.r3.corda.lib.accounts.workflows.internal.accountService +import net.corda.core.contracts.StateAndRef +import net.corda.core.node.services.vault.DEFAULT_PAGE_SIZE import net.corda.core.utilities.getOrThrow import net.corda.testing.common.internal.testNetworkParameters import net.corda.testing.node.MockNetwork @@ -17,6 +19,7 @@ import org.junit.After import org.junit.Assert import org.junit.Before import org.junit.Test +import org.junit.jupiter.api.assertDoesNotThrow import java.util.* import kotlin.test.assertEquals @@ -84,6 +87,66 @@ class GetAccountsFlowTests { } + + @Test + fun `should lookup all hosted accounts when the hosted accounts number exceeds DEFAULT_PAGE_SIZE`() { + val hostedAccounts = mutableListOf>() + (1..(DEFAULT_PAGE_SIZE + 1)).forEach { + hostedAccounts.add(a.startFlow(CreateAccount("Test_AccountA_$it")).runAndGet(network)) + } + val account4 = b.startFlow(CreateAccount("Stefano_Account3")).runAndGet(network) + + b.startFlow(ShareAccountInfo(account4, listOf(a.identity()))).runAndGet(network) + + val accountService = a.services.accountService + + a.transaction { + assertDoesNotThrow { val a = accountService.ourAccounts() } + Assert.assertThat(accountService.accountInfo(account4.uuid), `is`(account4)) + Assert.assertThat(accountService.ourAccounts(), containsInAnyOrder(*hostedAccounts.toTypedArray())) + Assert.assertThat(accountService.ourAccounts(), not(hasItem(account4))) + } + + } + + @Test + fun `should lookup all accounts when the number of accounts exceeds DEFAULT_PAGE_SIZE`() { + val account1 = a.startFlow(CreateAccount("A_Account1")).runAndGet(network) + val account2 = a.startFlow(CreateAccount("A_Account2")).runAndGet(network) + val account3 = a.startFlow(CreateAccount("A_Account3")).runAndGet(network) + val accountsHostedByB = mutableListOf>() + (1..DEFAULT_PAGE_SIZE).forEach { + val account = b.startFlow(CreateAccount("B_Account$it")).runAndGet(network) + accountsHostedByB.add(account) + b.startFlow(ShareAccountInfo(account, listOf(a.identity()))).runAndGet(network) + } + + val accountService = a.services.accountService + + a.transaction { + assertDoesNotThrow { accountService.allAccounts() } + Assert.assertThat(accountService.allAccounts(), + containsInAnyOrder(*(listOf(account1, account2, account3) + accountsHostedByB).toTypedArray())) + } + } + + @Test + fun `should lookup all accounts hosted by another node when the number of accounts exceeds DEFAULT_PAGE_SIZE`() { + val accountsHostedByB = mutableListOf>() + (1..(DEFAULT_PAGE_SIZE + 1)).forEach { + val account = b.startFlow(CreateAccount("B_Account$it")).runAndGet(network) + accountsHostedByB.add(account) + b.startFlow(ShareAccountInfo(account, listOf(a.identity()))).runAndGet(network) + } + + val accountService = a.services.accountService + + a.transaction { + assertDoesNotThrow { accountService.accountsForHost(b.identity()) } + Assert.assertThat(accountService.accountsForHost(b.identity()), containsInAnyOrder(*(accountsHostedByB).toTypedArray())) + } + } + @Test fun `should be able to lookup account by UUID from service`() { val future = a.startFlow(CreateAccount("Stefano_Account"))