Skip to content

Commit

Permalink
use chia_rs proof of inclusion for datalayer (#19327)
Browse files Browse the repository at this point in the history
  • Loading branch information
altendky authored Feb 27, 2025
1 parent 8611f3c commit 1567fa3
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 110 deletions.
17 changes: 8 additions & 9 deletions chia/_tests/core/data_layer/test_data_layer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@

# TODO: update after resolution in https://github.com/pytest-dev/pytest/issues/7469
from _pytest.fixtures import SubRequest
from chia_rs.datalayer import ProofOfInclusion, ProofOfInclusionLayer

from chia._tests.util.misc import Marks, datacases, measure_runtime
from chia.data_layer.data_layer_util import (
ClearPendingRootsRequest,
ClearPendingRootsResponse,
ProofOfInclusion,
ProofOfInclusionLayer,
Root,
Side,
Status,
Expand Down Expand Up @@ -72,16 +71,16 @@ def invalid_proof_of_inclusion_fixture(request: SubRequest, side: Side) -> Proof
a_hash = bytes32(b"f" * 32)

if request.param == "bad root hash":
layers[-1] = dataclasses.replace(layers[-1], combined_hash=a_hash)
return dataclasses.replace(valid_proof_of_inclusion, layers=layers)
layers[-1] = layers[-1].replace(combined_hash=a_hash)
return valid_proof_of_inclusion.replace(layers=layers)
elif request.param == "bad other hash":
layers[1] = dataclasses.replace(layers[1], other_hash=a_hash)
return dataclasses.replace(valid_proof_of_inclusion, layers=layers)
layers[1] = layers[1].replace(other_hash=a_hash)
return valid_proof_of_inclusion.replace(layers=layers)
elif request.param == "bad other side":
layers[1] = dataclasses.replace(layers[1], other_hash_side=layers[1].other_hash_side.other())
return dataclasses.replace(valid_proof_of_inclusion, layers=layers)
layers[1] = layers[1].replace(other_hash_side=Side(layers[1].other_hash_side).other())
return valid_proof_of_inclusion.replace(layers=layers)
elif request.param == "bad node hash":
return dataclasses.replace(valid_proof_of_inclusion, node_hash=a_hash)
return valid_proof_of_inclusion.replace(node_hash=a_hash)

raise Exception(f"Unhandled parametrization: {request.param!r}") # pragma: no cover

Expand Down
12 changes: 5 additions & 7 deletions chia/_tests/core/data_layer/test_data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
InternalNode,
Node,
OperationType,
ProofOfInclusion,
ProofOfInclusionLayer,
Root,
SerializedNode,
ServerInfo,
Expand Down Expand Up @@ -750,24 +748,24 @@ async def test_proof_of_inclusion_by_hash(data_store: DataStore, store_id: bytes
await _debug_dump(db=data_store.db_wrapper)

expected_layers = [
ProofOfInclusionLayer(
chia_rs.datalayer.ProofOfInclusionLayer(
other_hash_side=Side.RIGHT,
other_hash=bytes32.fromhex("fb66fe539b3eb2020dfbfadfd601fa318521292b41f04c2057c16fca6b947ca1"),
combined_hash=bytes32.fromhex("36cb1fc56017944213055da8cb0178fb0938c32df3ec4472f5edf0dff85ba4a3"),
),
ProofOfInclusionLayer(
chia_rs.datalayer.ProofOfInclusionLayer(
other_hash_side=Side.RIGHT,
other_hash=bytes32.fromhex("6d3af8d93db948e8b6aa4386958e137c6be8bab726db86789594b3588b35adcd"),
combined_hash=bytes32.fromhex("5f67a0ab1976e090b834bf70e5ce2a0f0a9cd474e19a905348c44ae12274d30b"),
),
ProofOfInclusionLayer(
chia_rs.datalayer.ProofOfInclusionLayer(
other_hash_side=Side.LEFT,
other_hash=bytes32.fromhex("c852ecd8fb61549a0a42f9eb9dde65e6c94a01934dbd9c1d35ab94e2a0ae58e2"),
combined_hash=bytes32.fromhex("7a5193a4e31a0a72f6623dfeb2876022ab74a48abb5966088a1c6f5451cc5d81"),
),
]

assert proof == ProofOfInclusion(node_hash=node.hash, layers=expected_layers)
assert proof == chia_rs.datalayer.ProofOfInclusion(node_hash=node.hash, layers=expected_layers)


@pytest.mark.anyio
Expand All @@ -780,7 +778,7 @@ async def test_proof_of_inclusion_by_hash_no_ancestors(data_store: DataStore, st

proof = await data_store.get_proof_of_inclusion_by_hash(node_hash=node.hash, store_id=store_id)

assert proof == ProofOfInclusion(node_hash=node.hash, layers=[])
assert proof == chia_rs.datalayer.ProofOfInclusion(node_hash=node.hash, layers=[])


@pytest.mark.anyio
Expand Down
3 changes: 1 addition & 2 deletions chia/data_layer/data_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast, final

import aiohttp
from chia_rs.datalayer import ProofOfInclusion, ProofOfInclusionLayer
from chia_rs.sized_ints import uint32, uint64

from chia.data_layer.data_layer_errors import KeyNotFoundError
Expand All @@ -31,8 +32,6 @@
PluginRemote,
PluginStatus,
Proof,
ProofOfInclusion,
ProofOfInclusionLayer,
Root,
ServerInfo,
Side,
Expand Down
89 changes: 5 additions & 84 deletions chia/data_layer/data_layer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from enum import Enum, IntEnum
from hashlib import sha256
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional, Union

import aiosqlite
import chia_rs.datalayer
from chia_rs.datalayer import ProofOfInclusion, ProofOfInclusionLayer
from chia_rs.sized_ints import uint8, uint64
from typing_extensions import final

Expand All @@ -25,9 +25,6 @@
from chia.data_layer.data_store import DataStore
from chia.wallet.wallet_node import WalletNode

ProofOfInclusionHint = Union["ProofOfInclusion", chia_rs.datalayer.ProofOfInclusion]
ProofOfInclusionLayerHint = Union["ProofOfInclusionLayer", chia_rs.datalayer.ProofOfInclusionLayer]


def internal_hash(left_hash: bytes32, right_hash: bytes32) -> bytes32:
# see test for the definition this is optimized from
Expand Down Expand Up @@ -252,90 +249,14 @@ def from_row(cls, row: aiosqlite.Row) -> TerminalNode:
)


@final
@dataclass(frozen=True)
class ProofOfInclusionLayer:
other_hash_side: Side
other_hash: bytes32
combined_hash: bytes32
def calculate_sibling_sides_integer(proof: ProofOfInclusion) -> int:
return sum((1 << index if layer.other_hash_side == Side.LEFT else 0) for index, layer in enumerate(proof.layers))

@classmethod
def from_internal_node(
cls,
internal_node: InternalNode,
traversal_child_hash: bytes32,
) -> ProofOfInclusionLayer:
return ProofOfInclusionLayer(
other_hash_side=internal_node.other_child_side(hash=traversal_child_hash),
other_hash=internal_node.other_child_hash(hash=traversal_child_hash),
combined_hash=internal_node.hash,
)

@classmethod
def from_hashes(cls, primary_hash: bytes32, other_hash_side: Side, other_hash: bytes32) -> ProofOfInclusionLayer:
combined_hash = calculate_internal_hash(
hash=primary_hash,
other_hash_side=other_hash_side,
other_hash=other_hash,
)

return cls(other_hash_side=other_hash_side, other_hash=other_hash, combined_hash=combined_hash)


def calculate_sibling_sides_integer(proof: ProofOfInclusionHint) -> int:
# casting to workaround this
# class C: ...
# class D: ...
#
# m: list[C | D]
# reveal_type(enumerate(m))
# # main.py:5: note: Revealed type is "builtins.enumerate[Union[__main__.C, __main__.D]]"
#
# n: list[C] | list[D]
# reveal_type(enumerate(n))
# main.py:9: note: Revealed type is "builtins.enumerate[builtins.object]"

return sum(
(1 << index if cast(ProofOfInclusionLayerHint, layer).other_hash_side == Side.LEFT else 0)
for index, layer in enumerate(proof.layers)
)


def collect_sibling_hashes(proof: ProofOfInclusionHint) -> list[bytes32]:
def collect_sibling_hashes(proof: ProofOfInclusion) -> list[bytes32]:
return [layer.other_hash for layer in proof.layers]


@dataclass(frozen=True)
class ProofOfInclusion:
node_hash: bytes32
# children before parents
layers: list[ProofOfInclusionLayer]

def root_hash(self) -> bytes32:
if len(self.layers) == 0:
return self.node_hash

return self.layers[-1].combined_hash

def valid(self) -> bool:
existing_hash = self.node_hash

for layer in self.layers:
calculated_hash = calculate_internal_hash(
hash=existing_hash, other_hash_side=layer.other_hash_side, other_hash=layer.other_hash
)

if calculated_hash != layer.combined_hash:
return False

existing_hash = calculated_hash

if existing_hash != self.root_hash():
return False

return True


@final
@dataclass(frozen=True)
class InternalNode:
Expand Down
3 changes: 2 additions & 1 deletion chia/data_layer/data_layer_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
from typing import TYPE_CHECKING, Any, ClassVar, Optional, cast

from chia_rs import G1Element, G2Element
from chia_rs.datalayer import ProofOfInclusion, ProofOfInclusionLayer
from chia_rs.sized_ints import uint8, uint32, uint64, uint128
from clvm.EvalError import EvalError
from typing_extensions import Unpack, final

from chia.consensus.block_record import BlockRecord
from chia.data_layer.data_layer_errors import LauncherCoinNotFoundError, OfferIntegrityError
from chia.data_layer.data_layer_util import OfferStore, ProofOfInclusion, ProofOfInclusionLayer, StoreProofs, leaf_hash
from chia.data_layer.data_layer_util import OfferStore, StoreProofs, leaf_hash
from chia.data_layer.singleton_record import SingletonRecord
from chia.protocols.wallet_protocol import CoinState
from chia.server.ws_connection import WSChiaConnection
Expand Down
7 changes: 3 additions & 4 deletions chia/data_layer/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import aiosqlite
import chia_rs.datalayer
from chia_rs.datalayer import KeyAlreadyPresentError, KeyId, TreeIndex, ValueId
from chia_rs.datalayer import KeyAlreadyPresentError, KeyId, ProofOfInclusion, TreeIndex, ValueId

from chia.data_layer.data_layer_errors import KeyNotFoundError, MerkleBlobNotFoundError, TreeGenerationIncrementingError
from chia.data_layer.data_layer_util import (
Expand All @@ -27,7 +27,6 @@
Node,
NodeType,
OperationType,
ProofOfInclusionHint,
Root,
SerializedNode,
ServerInfo,
Expand Down Expand Up @@ -1412,7 +1411,7 @@ async def get_proof_of_inclusion_by_hash(
node_hash: bytes32,
store_id: bytes32,
root_hash: Optional[bytes32] = None,
) -> ProofOfInclusionHint:
) -> ProofOfInclusion:
if root_hash is None:
root = await self.get_tree_root(store_id=store_id)
root_hash = root.node_hash
Expand All @@ -1424,7 +1423,7 @@ async def get_proof_of_inclusion_by_key(
self,
key: bytes,
store_id: bytes32,
) -> ProofOfInclusionHint:
) -> ProofOfInclusion:
root = await self.get_tree_root(store_id=store_id)
merkle_blob = await self.get_merkle_blob(root_hash=root.node_hash)
kvid = await self.get_kvid(key, store_id)
Expand Down
4 changes: 3 additions & 1 deletion chia/data_layer/util/merkle_blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
BlockIndexOutOfBoundsError,
KeyAlreadyPresentError,
KeyId,
ProofOfInclusion,
ProofOfInclusionLayer,
TreeIndex,
UnknownKeyError,
ValueId,
)
from chia_rs.sized_ints import int64, uint8, uint32

from chia.data_layer.data_layer_util import ProofOfInclusion, ProofOfInclusionLayer, Side, internal_hash
from chia.data_layer.data_layer_util import Side, internal_hash
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.hash import std_hash
from chia.util.streamable import Streamable, streamable
Expand Down
4 changes: 2 additions & 2 deletions poetry.lock
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ pytest = ">=8.3.3,<9.0.0"

[[package]]
name = "chia_rs"
version = "0.19.1"
version = "0.20.0"
description = "Code useful for implementing chia consensus."
optional = false
python-versions = "*"
Expand All @@ -782,7 +782,7 @@ typing-extensions = "*"
type = "git"
url = "https://github.com/chia-network/chia_rs"
reference = "long_lived/initial_datalayer"
resolved_reference = "da91206dc3b8f8909b4d6025930e13e35935255e"
resolved_reference = "abbb907f565f488f90bbbd050faec984468a5ff7"
subdirectory = "wheel/"

[[package]]
Expand Down

0 comments on commit 1567fa3

Please sign in to comment.