Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor for duplicate code in contract internal implementations #3579

Merged
merged 7 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/3579.internal.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Move duplicate code into base class in (1) ContractFunction and AsyncContractFunction, (2) ContractEvents and AsyncContractEvents, and (3) ContractFunctions and AsyncContractFunctions.
darwintree marked this conversation as resolved.
Show resolved Hide resolved
206 changes: 3 additions & 203 deletions web3/contract/async_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@

from eth_typing import (
ABI,
ABIFunction,
ChecksumAddress,
)
from eth_utils import (
combomethod,
)
from eth_utils.abi import (
abi_to_signature,
filter_abi_by_type,
get_abi_input_names,
)
from eth_utils.toolz import (
Expand All @@ -34,7 +32,6 @@

from web3._utils.abi import (
fallback_func_abi_exists,
get_name_from_abi_element_identifier,
receive_func_abi_exists,
)
from web3._utils.abi_element_identifiers import (
Expand All @@ -50,7 +47,6 @@
from web3._utils.contracts import (
async_parse_block_identifier,
copy_contract_event,
copy_contract_function,
)
from web3._utils.datatypes import (
PropertyCheckingFactory,
Expand Down Expand Up @@ -89,12 +85,6 @@
get_function_by_identifier,
)
from web3.exceptions import (
ABIEventNotFound,
ABIFunctionNotFound,
MismatchedABI,
NoABIEventsFound,
NoABIFound,
NoABIFunctionsFound,
Web3AttributeError,
Web3TypeError,
Web3ValidationError,
Expand All @@ -106,12 +96,6 @@
StateOverride,
TxParams,
)
from web3.utils.abi import (
_filter_by_argument_count,
_get_any_abi_signature_with_name,
_mismatched_abi_error_diagnosis,
get_abi_element,
)

if TYPE_CHECKING:
from ens import AsyncENS # noqa: F401
Expand All @@ -122,7 +106,7 @@ class AsyncContractEvent(BaseContractEvent):
# mypy types
w3: "AsyncWeb3"

def __call__(self, *args: Any, **kwargs: Any) -> "AsyncContractEvent":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm inclined to leave this as is since the returned object is a new copy of the event.

Suggested change
def __call__(self, *args: Any, **kwargs: Any) -> "AsyncContractEvent":
def __call__(self, *args: Any, **kwargs: Any) -> "AsyncContractEvent":

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think they are nearly the same and Self can keep consistency. Self is used to indicate type hints, but did not imply if it is self or a copy of it

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fine with me. Should this __call__ and the one in the ContractEvent class move to base_contract? If not, I noticed the ContractEvent.__call__ has the type "ContractEvent" rather than self, probably best to make them consistent.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Moved into BaseContractEvent

def __call__(self, *args: Any, **kwargs: Any) -> Self:
return copy_contract_event(self, *args, **kwargs)

@combomethod
Expand Down Expand Up @@ -255,162 +239,18 @@ def build_filter(self) -> AsyncEventFilterBuilder:
builder.address = self.address
return builder

@classmethod
def factory(cls, class_name: str, **kwargs: Any) -> Self:
return PropertyCheckingFactory(class_name, (cls,), kwargs)()


class AsyncContractEvents(BaseContractEvents):
class AsyncContractEvents(BaseContractEvents[AsyncContractEvent]):
def __init__(
self, abi: ABI, w3: "AsyncWeb3", address: Optional[ChecksumAddress] = None
) -> None:
super().__init__(abi, w3, AsyncContractEvent, address)

def __iter__(self) -> Iterable["AsyncContractEvent"]:
if not hasattr(self, "_events") or not self._events:
return

for event in self._events:
yield self[abi_to_signature(event)]

def __getattr__(self, event_name: str) -> "AsyncContractEvent":
if super().__getattribute__("abi") is None:
raise NoABIFound(
"There is no ABI found for this contract.",
)
elif "_events" not in self.__dict__ or len(self._events) == 0:
raise NoABIEventsFound(
"The abi for this contract contains no event definitions. ",
"Are you sure you provided the correct contract abi?",
)
elif get_name_from_abi_element_identifier(event_name) not in [
get_name_from_abi_element_identifier(event["name"])
for event in self._events
]:
raise ABIEventNotFound(
f"The event '{event_name}' was not found in this contract's abi. ",
"Are you sure you provided the correct contract abi?",
)

if "(" not in event_name:
event_name = _get_any_abi_signature_with_name(event_name, self._events)
else:
event_name = f"_{event_name}"

return super().__getattribute__(event_name)

def __getitem__(self, event_name: str) -> "AsyncContractEvent":
return getattr(self, event_name)


class AsyncContractFunction(BaseContractFunction):
# mypy types
w3: "AsyncWeb3"

def __call__(self, *args: Any, **kwargs: Any) -> "AsyncContractFunction":
# When a function is called, check arguments to obtain the correct function
# in the contract. self will be used if all args and kwargs are
# encodable to self.abi, otherwise the correct function is obtained from
# the contract.
if (
self.abi_element_identifier in [FallbackFn, ReceiveFn]
or self.abi_element_identifier == "constructor"
):
return copy_contract_function(self, *args, **kwargs)

all_functions = cast(
List[ABIFunction],
filter_abi_by_type(
"function",
self.contract_abi,
),
)
# Filter functions by name to obtain function signatures
function_name = get_name_from_abi_element_identifier(
self.abi_element_identifier
)
function_abis = [
function for function in all_functions if function["name"] == function_name
]
num_args = len(args) + len(kwargs)
function_abis_with_arg_count = cast(
List[ABIFunction],
_filter_by_argument_count(
num_args,
function_abis,
),
)

if not len(function_abis_with_arg_count):
# Build an ABI without arguments to determine if one exists
function_abis_with_arg_count = [
ABIFunction({"type": "function", "name": function_name})
]

# Check that arguments in call match a function ABI
num_attempts = 0
function_abi_matches = []
contract_function = None
for abi in function_abis_with_arg_count:
try:
num_attempts += 1

# Search for a function ABI that matches the arguments used
function_abi_matches.append(
cast(
ABIFunction,
get_abi_element(
function_abis,
abi_to_signature(abi),
*args,
abi_codec=self.w3.codec,
**kwargs,
),
)
)
except MismatchedABI:
# ignore exceptions
continue

if len(function_abi_matches) == 1:
function_abi = function_abi_matches[0]
if abi_to_signature(self.abi) == abi_to_signature(function_abi):
contract_function = self
else:
# Found a match that is not self
contract_function = AsyncContractFunction.factory(
abi_to_signature(function_abi),
w3=self.w3,
contract_abi=self.contract_abi,
address=self.address,
abi_element_identifier=abi_to_signature(function_abi),
abi=function_abi,
)
else:
for abi in function_abi_matches:
if abi_to_signature(self.abi) == abi_to_signature(abi):
contract_function = self
break
else:
# Raise exception if multiple found
raise MismatchedABI(
_mismatched_abi_error_diagnosis(
function_name,
self.contract_abi,
len(function_abi_matches),
num_args,
*args,
abi_codec=self.w3.codec,
**kwargs,
)
)

return copy_contract_function(contract_function, *args, **kwargs)

@classmethod
def factory(cls, class_name: str, **kwargs: Any) -> Self:
return PropertyCheckingFactory(class_name, (cls,), kwargs)()

async def call(
self,
transaction: Optional[TxParams] = None,
Expand Down Expand Up @@ -551,7 +391,7 @@ def get_receive_function(
return cast(AsyncContractFunction, NonExistentReceiveFunction())


class AsyncContractFunctions(BaseContractFunctions):
class AsyncContractFunctions(BaseContractFunctions[AsyncContractFunction]):
def __init__(
self,
abi: ABI,
Expand All @@ -561,46 +401,6 @@ def __init__(
) -> None:
super().__init__(abi, w3, AsyncContractFunction, address, decode_tuples)

def __iter__(self) -> Iterable["AsyncContractFunction"]:
if not hasattr(self, "_functions") or not self._functions:
return

for func in self._functions:
yield self[abi_to_signature(func)]

def __getattr__(self, function_name: str) -> "AsyncContractFunction":
if super().__getattribute__("abi") is None:
raise NoABIFound(
"There is no ABI found for this contract.",
)
elif "_functions" not in self.__dict__ or len(self._functions) == 0:
raise NoABIFunctionsFound(
"The abi for this contract contains no function definitions. ",
"Are you sure you provided the correct contract abi?",
)
elif get_name_from_abi_element_identifier(function_name) not in [
get_name_from_abi_element_identifier(function["name"])
for function in self._functions
]:
raise ABIFunctionNotFound(
f"The function '{function_name}' was not found in this ",
"contract's abi.",
)

if "(" not in function_name:
function_name = _get_any_abi_signature_with_name(
function_name, self._functions
)
else:
function_name = f"_{function_name}"

return super().__getattribute__(
function_name,
)

def __getitem__(self, function_name: str) -> "AsyncContractFunction":
return getattr(self, function_name)


class AsyncContract(BaseContract):
functions: AsyncContractFunctions = None
Expand Down
Loading