Skip to content

Commit 6334cfa

Browse files
authored
allow for event handlers to ignore args (#4282)
* allow for event handlers to ignore args * use a constant * dang it darglint * forgor * keep the tests but move them to valid place
1 parent d9ab3a0 commit 6334cfa

File tree

8 files changed

+127
-116
lines changed

8 files changed

+127
-116
lines changed

reflex/components/component.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Iterator,
1818
List,
1919
Optional,
20+
Sequence,
2021
Set,
2122
Type,
2223
Union,
@@ -38,6 +39,7 @@
3839
PageNames,
3940
)
4041
from reflex.constants.compiler import SpecialAttributes
42+
from reflex.constants.state import FRONTEND_EVENT_STATE
4143
from reflex.event import (
4244
EventCallback,
4345
EventChain,
@@ -533,7 +535,7 @@ def __init__(self, *args, **kwargs):
533535

534536
def _create_event_chain(
535537
self,
536-
args_spec: Any,
538+
args_spec: types.ArgsSpec | Sequence[types.ArgsSpec],
537539
value: Union[
538540
Var,
539541
EventHandler,
@@ -599,7 +601,7 @@ def _create_event_chain(
599601

600602
# If the input is a callable, create an event chain.
601603
elif isinstance(value, Callable):
602-
result = call_event_fn(value, args_spec)
604+
result = call_event_fn(value, args_spec, key=key)
603605
if isinstance(result, Var):
604606
# Recursively call this function if the lambda returned an EventChain Var.
605607
return self._create_event_chain(args_spec, result, key=key)
@@ -629,14 +631,16 @@ def _create_event_chain(
629631
event_actions={},
630632
)
631633

632-
def get_event_triggers(self) -> Dict[str, Any]:
634+
def get_event_triggers(
635+
self,
636+
) -> Dict[str, types.ArgsSpec | Sequence[types.ArgsSpec]]:
633637
"""Get the event triggers for the component.
634638
635639
Returns:
636640
The event triggers.
637641
638642
"""
639-
default_triggers = {
643+
default_triggers: Dict[str, types.ArgsSpec | Sequence[types.ArgsSpec]] = {
640644
EventTriggers.ON_FOCUS: no_args_event_spec,
641645
EventTriggers.ON_BLUR: no_args_event_spec,
642646
EventTriggers.ON_CLICK: no_args_event_spec,
@@ -1142,7 +1146,10 @@ def _event_trigger_values_use_state(self) -> bool:
11421146
if isinstance(event, EventCallback):
11431147
continue
11441148
if isinstance(event, EventSpec):
1145-
if event.handler.state_full_name:
1149+
if (
1150+
event.handler.state_full_name
1151+
and event.handler.state_full_name != FRONTEND_EVENT_STATE
1152+
):
11461153
return True
11471154
else:
11481155
if event._var_state:

reflex/constants/state.py

+4
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,7 @@ class StateManagerMode(str, Enum):
99
DISK = "disk"
1010
MEMORY = "memory"
1111
REDIS = "redis"
12+
13+
14+
# Used for things like console_log, etc.
15+
FRONTEND_EVENT_STATE = "__reflex_internal_frontend_event_state"

reflex/event.py

+73-64
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@
2828
from typing_extensions import ParamSpec, Protocol, get_args, get_origin
2929

3030
from reflex import constants
31+
from reflex.constants.state import FRONTEND_EVENT_STATE
3132
from reflex.utils import console, format
3233
from reflex.utils.exceptions import (
3334
EventFnArgMismatch,
34-
EventHandlerArgMismatch,
3535
EventHandlerArgTypeMismatch,
3636
)
3737
from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass
@@ -662,7 +662,7 @@ def fn():
662662
fn.__qualname__ = name
663663
fn.__signature__ = sig
664664
return EventSpec(
665-
handler=EventHandler(fn=fn),
665+
handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE),
666666
args=tuple(
667667
(
668668
Var(_js_expr=k),
@@ -1092,8 +1092,8 @@ def get_hydrate_event(state) -> str:
10921092

10931093

10941094
def call_event_handler(
1095-
event_handler: EventHandler | EventSpec,
1096-
arg_spec: ArgsSpec | Sequence[ArgsSpec],
1095+
event_callback: EventHandler | EventSpec,
1096+
event_spec: ArgsSpec | Sequence[ArgsSpec],
10971097
key: Optional[str] = None,
10981098
) -> EventSpec:
10991099
"""Call an event handler to get the event spec.
@@ -1103,53 +1103,57 @@ def call_event_handler(
11031103
Otherwise, the event handler will be called with no args.
11041104
11051105
Args:
1106-
event_handler: The event handler.
1107-
arg_spec: The lambda that define the argument(s) to pass to the event handler.
1106+
event_callback: The event handler.
1107+
event_spec: The lambda that define the argument(s) to pass to the event handler.
11081108
key: The key to pass to the event handler.
11091109
1110-
Raises:
1111-
EventHandlerArgMismatch: if number of arguments expected by event_handler doesn't match the spec.
1112-
11131110
Returns:
11141111
The event spec from calling the event handler.
11151112
11161113
# noqa: DAR401 failure
11171114
11181115
"""
1119-
parsed_args = parse_args_spec(arg_spec) # type: ignore
1120-
1121-
if isinstance(event_handler, EventSpec):
1122-
# Handle partial application of EventSpec args
1123-
return event_handler.add_args(*parsed_args)
1124-
1125-
provided_callback_fullspec = inspect.getfullargspec(event_handler.fn)
1126-
1127-
provided_callback_n_args = (
1128-
len(provided_callback_fullspec.args) - 1
1129-
) # subtract 1 for bound self arg
1130-
1131-
if provided_callback_n_args != len(parsed_args):
1132-
raise EventHandlerArgMismatch(
1133-
"The number of arguments accepted by "
1134-
f"{event_handler.fn.__qualname__} ({provided_callback_n_args}) "
1135-
"does not match the arguments passed by the event trigger: "
1136-
f"{[str(v) for v in parsed_args]}\n"
1137-
"See https://reflex.dev/docs/events/event-arguments/"
1116+
event_spec_args = parse_args_spec(event_spec) # type: ignore
1117+
1118+
if isinstance(event_callback, EventSpec):
1119+
check_fn_match_arg_spec(
1120+
event_callback.handler.fn,
1121+
event_spec,
1122+
key,
1123+
bool(event_callback.handler.state_full_name) + len(event_callback.args),
1124+
event_callback.handler.fn.__qualname__,
11381125
)
1126+
# Handle partial application of EventSpec args
1127+
return event_callback.add_args(*event_spec_args)
1128+
1129+
check_fn_match_arg_spec(
1130+
event_callback.fn,
1131+
event_spec,
1132+
key,
1133+
bool(event_callback.state_full_name),
1134+
event_callback.fn.__qualname__,
1135+
)
11391136

1140-
all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec
1137+
all_acceptable_specs = (
1138+
[event_spec] if not isinstance(event_spec, Sequence) else event_spec
1139+
)
11411140

11421141
event_spec_return_types = list(
11431142
filter(
11441143
lambda event_spec_return_type: event_spec_return_type is not None
11451144
and get_origin(event_spec_return_type) is tuple,
1146-
(get_type_hints(arg_spec).get("return", None) for arg_spec in all_arg_spec),
1145+
(
1146+
get_type_hints(arg_spec).get("return", None)
1147+
for arg_spec in all_acceptable_specs
1148+
),
11471149
)
11481150
)
11491151

11501152
if event_spec_return_types:
11511153
failures = []
11521154

1155+
event_callback_spec = inspect.getfullargspec(event_callback.fn)
1156+
11531157
for event_spec_index, event_spec_return_type in enumerate(
11541158
event_spec_return_types
11551159
):
@@ -1160,14 +1164,14 @@ def call_event_handler(
11601164
]
11611165

11621166
try:
1163-
type_hints_of_provided_callback = get_type_hints(event_handler.fn)
1167+
type_hints_of_provided_callback = get_type_hints(event_callback.fn)
11641168
except NameError:
11651169
type_hints_of_provided_callback = {}
11661170

11671171
failed_type_check = False
11681172

11691173
# check that args of event handler are matching the spec if type hints are provided
1170-
for i, arg in enumerate(provided_callback_fullspec.args[1:]):
1174+
for i, arg in enumerate(event_callback_spec.args[1:]):
11711175
if arg not in type_hints_of_provided_callback:
11721176
continue
11731177

@@ -1181,15 +1185,15 @@ def call_event_handler(
11811185
# f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
11821186
# ) from e
11831187
console.warn(
1184-
f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
1188+
f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_callback.fn.__qualname__} provided for {key}."
11851189
)
11861190
compare_result = False
11871191

11881192
if compare_result:
11891193
continue
11901194
else:
11911195
failure = EventHandlerArgTypeMismatch(
1192-
f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead."
1196+
f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_callback.fn.__qualname__} instead."
11931197
)
11941198
failures.append(failure)
11951199
failed_type_check = True
@@ -1210,14 +1214,14 @@ def call_event_handler(
12101214

12111215
given_string = ", ".join(
12121216
repr(type_hints_of_provided_callback.get(arg, Any))
1213-
for arg in provided_callback_fullspec.args[1:]
1217+
for arg in event_callback_spec.args[1:]
12141218
).replace("[", "\\[")
12151219

12161220
console.warn(
1217-
f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_handler.fn.__qualname__} instead. "
1221+
f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_callback.fn.__qualname__} instead. "
12181222
f"This may lead to unexpected behavior but is intentionally ignored for {key}."
12191223
)
1220-
return event_handler(*parsed_args)
1224+
return event_callback(*event_spec_args)
12211225

12221226
if failures:
12231227
console.deprecate(
@@ -1227,7 +1231,7 @@ def call_event_handler(
12271231
"0.7.0",
12281232
)
12291233

1230-
return event_handler(*parsed_args) # type: ignore
1234+
return event_callback(*event_spec_args) # type: ignore
12311235

12321236

12331237
def unwrap_var_annotation(annotation: GenericType):
@@ -1294,45 +1298,46 @@ def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]):
12941298

12951299

12961300
def check_fn_match_arg_spec(
1297-
fn: Callable,
1298-
arg_spec: ArgsSpec,
1299-
key: Optional[str] = None,
1300-
) -> List[Var]:
1301+
user_func: Callable,
1302+
arg_spec: ArgsSpec | Sequence[ArgsSpec],
1303+
key: str | None = None,
1304+
number_of_bound_args: int = 0,
1305+
func_name: str | None = None,
1306+
):
13011307
"""Ensures that the function signature matches the passed argument specification
13021308
or raises an EventFnArgMismatch if they do not.
13031309
13041310
Args:
1305-
fn: The function to be validated.
1311+
user_func: The function to be validated.
13061312
arg_spec: The argument specification for the event trigger.
1307-
key: The key to pass to the event handler.
1308-
1309-
Returns:
1310-
The parsed arguments from the argument specification.
1313+
key: The key of the event trigger.
1314+
number_of_bound_args: The number of bound arguments to the function.
1315+
func_name: The name of the function to be validated.
13111316
13121317
Raises:
13131318
EventFnArgMismatch: Raised if the number of mandatory arguments do not match
13141319
"""
1315-
fn_args = inspect.getfullargspec(fn).args
1316-
fn_defaults_args = inspect.getfullargspec(fn).defaults
1317-
n_fn_args = len(fn_args)
1318-
n_fn_defaults_args = len(fn_defaults_args) if fn_defaults_args else 0
1319-
if isinstance(fn, types.MethodType):
1320-
n_fn_args -= 1 # subtract 1 for bound self arg
1321-
parsed_args = parse_args_spec(arg_spec)
1322-
if not (n_fn_args - n_fn_defaults_args <= len(parsed_args) <= n_fn_args):
1320+
user_args = inspect.getfullargspec(user_func).args
1321+
user_default_args = inspect.getfullargspec(user_func).defaults
1322+
number_of_user_args = len(user_args) - number_of_bound_args
1323+
number_of_user_default_args = len(user_default_args) if user_default_args else 0
1324+
1325+
parsed_event_args = parse_args_spec(arg_spec)
1326+
1327+
number_of_event_args = len(parsed_event_args)
1328+
1329+
if number_of_user_args - number_of_user_default_args > number_of_event_args:
13231330
raise EventFnArgMismatch(
1324-
"The number of mandatory arguments accepted by "
1325-
f"{fn} ({n_fn_args - n_fn_defaults_args}) "
1326-
"does not match the arguments passed by the event trigger: "
1327-
f"{[str(v) for v in parsed_args]}\n"
1331+
f"Event {key} only provides {number_of_event_args} arguments, but "
1332+
f"{func_name or user_func} requires at least {number_of_user_args - number_of_user_default_args} "
1333+
"arguments to be passed to the event handler.\n"
13281334
"See https://reflex.dev/docs/events/event-arguments/"
13291335
)
1330-
return parsed_args
13311336

13321337

13331338
def call_event_fn(
13341339
fn: Callable,
1335-
arg_spec: ArgsSpec,
1340+
arg_spec: ArgsSpec | Sequence[ArgsSpec],
13361341
key: Optional[str] = None,
13371342
) -> list[EventSpec] | Var:
13381343
"""Call a function to a list of event specs.
@@ -1356,10 +1361,14 @@ def call_event_fn(
13561361
from reflex.utils.exceptions import EventHandlerValueError
13571362

13581363
# Check that fn signature matches arg_spec
1359-
parsed_args = check_fn_match_arg_spec(fn, arg_spec, key=key)
1364+
check_fn_match_arg_spec(fn, arg_spec, key=key)
1365+
1366+
parsed_args = parse_args_spec(arg_spec)
1367+
1368+
number_of_fn_args = len(inspect.getfullargspec(fn).args)
13601369

13611370
# Call the function with the parsed args.
1362-
out = fn(*parsed_args)
1371+
out = fn(*[*parsed_args][:number_of_fn_args])
13631372

13641373
# If the function returns a Var, assume it's an EventChain and render it directly.
13651374
if isinstance(out, Var):
@@ -1478,7 +1487,7 @@ def get_fn_signature(fn: Callable) -> inspect.Signature:
14781487
"""
14791488
signature = inspect.signature(fn)
14801489
new_param = inspect.Parameter(
1481-
"state", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any
1490+
FRONTEND_EVENT_STATE, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any
14821491
)
14831492
return signature.replace(parameters=(new_param, *signature.parameters.values()))
14841493

reflex/utils/exceptions.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -89,16 +89,12 @@ class MatchTypeError(ReflexError, TypeError):
8989
"""Raised when the return types of match cases are different."""
9090

9191

92-
class EventHandlerArgMismatch(ReflexError, TypeError):
93-
"""Raised when the number of args accepted by an EventHandler differs from that provided by the event trigger."""
94-
95-
9692
class EventHandlerArgTypeMismatch(ReflexError, TypeError):
9793
"""Raised when the annotations of args accepted by an EventHandler differs from the spec of the event trigger."""
9894

9995

10096
class EventFnArgMismatch(ReflexError, TypeError):
101-
"""Raised when the number of args accepted by a lambda differs from that provided by the event trigger."""
97+
"""Raised when the number of args required by an event handler is more than provided by the event trigger."""
10298

10399

104100
class DynamicRouteArgShadowsStateVar(ReflexError, NameError):

reflex/utils/format.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import TYPE_CHECKING, Any, List, Optional, Union
1010

1111
from reflex import constants
12+
from reflex.constants.state import FRONTEND_EVENT_STATE
1213
from reflex.utils import exceptions
1314
from reflex.utils.console import deprecate
1415

@@ -439,7 +440,7 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]:
439440

440441
from reflex.state import State
441442

442-
if state_full_name == "state" and name not in State.__dict__:
443+
if state_full_name == FRONTEND_EVENT_STATE and name not in State.__dict__:
443444
return ("", to_snake_case(handler.fn.__qualname__))
444445

445446
return (state_full_name, name)

reflex/utils/pyi_generator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from multiprocessing import Pool, cpu_count
1717
from pathlib import Path
1818
from types import ModuleType, SimpleNamespace
19-
from typing import Any, Callable, Iterable, Type, get_args, get_origin
19+
from typing import Any, Callable, Iterable, Sequence, Type, get_args, get_origin
2020

2121
from reflex.components.component import Component
2222
from reflex.utils import types as rx_types
@@ -560,7 +560,7 @@ def figure_out_return_type(annotation: Any):
560560
inspect.signature(event_specs).return_annotation
561561
)
562562
if not isinstance(
563-
event_specs := event_triggers[trigger], tuple
563+
event_specs := event_triggers[trigger], Sequence
564564
)
565565
else ast.Subscript(
566566
ast.Name("Union"),

0 commit comments

Comments
 (0)