diff --git a/chia/_tests/core/data_layer/test_data_store.py b/chia/_tests/core/data_layer/test_data_store.py
index 59d86f6a630d..6d15e612fac5 100644
--- a/chia/_tests/core/data_layer/test_data_store.py
+++ b/chia/_tests/core/data_layer/test_data_store.py
@@ -10,7 +10,7 @@
 from dataclasses import dataclass
 from pathlib import Path
 from random import Random
-from typing import Any, Awaitable, Callable, Dict, List, Set, Tuple, cast
+from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, cast
 
 import aiohttp
 import aiosqlite
@@ -22,6 +22,7 @@
 from chia.data_layer.data_layer_util import (
     DiffData,
     InternalNode,
+    Node,
     NodeType,
     OperationType,
     ProofOfInclusion,
@@ -2048,3 +2049,99 @@ async def test_migration_unknown_version(data_store: DataStore) -> None:
     )
     with pytest.raises(Exception, match="Unknown version"):
         await data_store.migrate_db()
+
+
+async def _check_ancestors(
+    data_store: DataStore, store_id: bytes32, root_hash: bytes32
+) -> Dict[bytes32, Optional[bytes32]]:
+    ancestors: Dict[bytes32, Optional[bytes32]] = {}
+    root_node: Node = await data_store.get_node(root_hash)
+    queue: List[Node] = [root_node]
+
+    while queue:
+        node = queue.pop(0)
+        if isinstance(node, InternalNode):
+            left_node = await data_store.get_node(node.left_hash)
+            right_node = await data_store.get_node(node.right_hash)
+            ancestors[left_node.hash] = node.hash
+            ancestors[right_node.hash] = node.hash
+            queue.append(left_node)
+            queue.append(right_node)
+
+    ancestors[root_hash] = None
+    for node_hash, ancestor_hash in ancestors.items():
+        ancestor_node = await data_store._get_one_ancestor(node_hash, store_id)
+        if ancestor_hash is None:
+            assert ancestor_node is None
+        else:
+            assert ancestor_node is not None
+            assert ancestor_node.hash == ancestor_hash
+
+    return ancestors
+
+
+@pytest.mark.anyio
+async def test_build_ancestor_table(data_store: DataStore, store_id: bytes32) -> None:
+    num_values = 1000
+    changelist: List[Dict[str, Any]] = []
+    for value in range(num_values):
+        value_bytes = value.to_bytes(4, byteorder="big")
+        changelist.append({"action": "upsert", "key": value_bytes, "value": value_bytes})
+    await data_store.insert_batch(
+        store_id=store_id,
+        changelist=changelist,
+        status=Status.PENDING,
+    )
+
+    pending_root = await data_store.get_pending_root(store_id=store_id)
+    assert pending_root is not None
+    assert pending_root.node_hash is not None
+    await data_store.change_root_status(pending_root, Status.COMMITTED)
+    await data_store.build_ancestor_table_for_latest_root(store_id=store_id)
+
+    assert pending_root.node_hash is not None
+    await _check_ancestors(data_store, store_id, pending_root.node_hash)
+
+
+@pytest.mark.anyio
+async def test_sparse_ancestor_table(data_store: DataStore, store_id: bytes32) -> None:
+    num_values = 100
+    for value in range(num_values):
+        value_bytes = value.to_bytes(4, byteorder="big")
+        await data_store.autoinsert(
+            key=value_bytes,
+            value=value_bytes,
+            store_id=store_id,
+            status=Status.COMMITTED,
+        )
+    root = await data_store.get_tree_root(store_id=store_id)
+    assert root.node_hash is not None
+    ancestors = await _check_ancestors(data_store, store_id, root.node_hash)
+
+    # Check the ancestor table is sparse
+    root_generation = root.generation
+    current_generation_count = 0
+    previous_generation_count = 0
+    for node_hash, ancestor_hash in ancestors.items():
+        async with data_store.db_wrapper.reader() as reader:
+            if ancestor_hash is not None:
+                cursor = await reader.execute(
+                    "SELECT MAX(generation) AS generation FROM ancestors WHERE hash == :hash AND ancestor == :ancestor",
+                    {"hash": node_hash, "ancestor": ancestor_hash},
+                )
+            else:
+                cursor = await reader.execute(
+                    "SELECT MAX(generation) AS generation FROM ancestors WHERE hash == :hash AND ancestor IS NULL",
+                    {"hash": node_hash},
+                )
+            row = await cursor.fetchone()
+            assert row is not None
+            generation = row["generation"]
+            assert generation <= root_generation
+            if generation == root_generation:
+                current_generation_count += 1
+            else:
+                previous_generation_count += 1
+
+    assert current_generation_count == 15
+    assert previous_generation_count == 184
diff --git a/chia/data_layer/data_store.py b/chia/data_layer/data_store.py
index 7256faeeec1c..4c065e1ff5b1 100644
--- a/chia/data_layer/data_store.py
+++ b/chia/data_layer/data_store.py
@@ -1628,32 +1628,41 @@ async def _get_one_ancestor(
         return InternalNode.from_row(row=row)
 
     async def build_ancestor_table_for_latest_root(self, store_id: bytes32) -> None:
-        async with self.db_wrapper.writer():
+        async with self.db_wrapper.writer() as writer:
             root = await self.get_tree_root(store_id=store_id)
             if root.node_hash is None:
                 return
-            previous_root = await self.get_tree_root(
-                store_id=store_id,
-                generation=max(root.generation - 1, 0),
-            )
-            if previous_root.node_hash is not None:
-                previous_internal_nodes: List[InternalNode] = await self.get_internal_nodes(
-                    store_id=store_id,
-                    root_hash=previous_root.node_hash,
+            await writer.execute(
+                """
+                WITH RECURSIVE tree_from_root_hash AS (
+                    SELECT
+                        node.hash,
+                        node.left,
+                        node.right,
+                        NULL AS ancestor
+                    FROM node
+                    WHERE node.hash = :root_hash
+                    UNION ALL
+                    SELECT
+                        node.hash,
+                        node.left,
+                        node.right,
+                        tree_from_root_hash.hash AS ancestor
+                    FROM node
+                    JOIN tree_from_root_hash ON node.hash = tree_from_root_hash.left
+                    OR node.hash = tree_from_root_hash.right
                 )
-                known_hashes: Set[bytes32] = {node.hash for node in previous_internal_nodes}
-            else:
-                known_hashes = set()
-            internal_nodes: List[InternalNode] = await self.get_internal_nodes(
-                store_id=store_id,
-                root_hash=root.node_hash,
+                INSERT OR REPLACE INTO ancestors (hash, ancestor, tree_id, generation)
+                SELECT
+                    tree_from_root_hash.hash,
+                    tree_from_root_hash.ancestor,
+                    :tree_id,
+                    :generation
+                FROM tree_from_root_hash
+                """,
+                {"root_hash": root.node_hash, "tree_id": store_id, "generation": root.generation},
             )
-            for node in internal_nodes:
-                # We already have the same values in ancestor tables, if we have the same internal node.
-                # Don't reinsert it so we can save DB space.
-                if node.hash not in known_hashes:
-                    await self._insert_ancestor_table(node.left_hash, node.right_hash, store_id, root.generation)
 
     async def insert_root_with_ancestor_table(
         self, store_id: bytes32, node_hash: Optional[bytes32], status: Status = Status.PENDING
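
Below is a minimal, self-contained sketch of the recursive-CTE traversal that the rewritten build_ancestor_table_for_latest_root relies on. It runs against a toy in-memory SQLite database whose node and ancestors tables are simplified stand-ins (TEXT hashes, no constraints, no generation bookkeeping beyond a literal value), and it uses the synchronous sqlite3 module instead of aiosqlite purely to keep the example runnable on its own; it illustrates the shape of the query, not the real DataStore schema or API.

# Sketch only: toy tables named like the ones the query touches, not the DataStore schema.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE node (hash TEXT PRIMARY KEY, left TEXT, right TEXT);
    CREATE TABLE ancestors (hash TEXT, ancestor TEXT, tree_id TEXT, generation INTEGER);
    -- A three-node tree: root R with terminal children A and B.
    INSERT INTO node VALUES ('R', 'A', 'B'), ('A', NULL, NULL), ('B', NULL, NULL);
    """
)

# Walk the tree from the root and record, for every reachable node, the hash of
# the internal node it was reached from (NULL for the root itself), then write
# one ancestors row per visited node in a single statement.
conn.execute(
    """
    WITH RECURSIVE tree_from_root_hash AS (
        SELECT node.hash, node.left, node.right, NULL AS ancestor
        FROM node WHERE node.hash = :root_hash
        UNION ALL
        SELECT node.hash, node.left, node.right, tree_from_root_hash.hash AS ancestor
        FROM node
        JOIN tree_from_root_hash ON node.hash = tree_from_root_hash.left
            OR node.hash = tree_from_root_hash.right
    )
    INSERT INTO ancestors (hash, ancestor, tree_id, generation)
    SELECT tree_from_root_hash.hash, tree_from_root_hash.ancestor, :tree_id, :generation
    FROM tree_from_root_hash
    """,
    {"root_hash": "R", "tree_id": "t1", "generation": 1},
)

print(conn.execute("SELECT hash, ancestor FROM ancestors ORDER BY hash").fetchall())
# Expected: [('A', 'R'), ('B', 'R'), ('R', None)]

The recursion stops on its own at terminal nodes, since a NULL left or right hash never joins back into the node table, so the single statement replaces the Python-side fetch of all internal nodes followed by per-node ancestor inserts.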