diff --git a/benchmarks/test_benchmark_compile_pages.py b/benchmarks/test_benchmark_compile_pages.py index 149fc613007..6cf39f60ca8 100644 --- a/benchmarks/test_benchmark_compile_pages.py +++ b/benchmarks/test_benchmark_compile_pages.py @@ -46,10 +46,26 @@ def render_multiple_pages(app, num: int): class State(rx.State): """The app state.""" - position: str - college: str - age: Tuple[int, int] = (18, 50) - salary: Tuple[int, int] = (0, 25000000) + position: rx.Field[str] + college: rx.Field[str] + age: rx.Field[Tuple[int, int]] = rx.field((18, 50)) + salary: rx.Field[Tuple[int, int]] = rx.field((0, 25000000)) + + @rx.event + def set_position(self, value: str): + self.position = value + + @rx.event + def set_college(self, value: str): + self.college = value + + @rx.event + def set_age(self, value: list[int]): + self.age = (value[0], value[1]) + + @rx.event + def set_salary(self, value: list[int]): + self.salary = (value[0], value[1]) comp1 = rx.center( rx.theme_panel(), @@ -74,13 +90,13 @@ class State(rx.State): rx.select( ["C", "PF", "SF", "PG", "SG"], placeholder="Select a position. (All)", - on_change=State.set_position, # pyright: ignore [reportAttributeAccessIssue] + on_change=State.set_position, size="3", ), rx.select( college, placeholder="Select a college. (All)", - on_change=State.set_college, # pyright: ignore [reportAttributeAccessIssue] + on_change=State.set_college, size="3", ), ), @@ -95,7 +111,7 @@ class State(rx.State): default_value=[18, 50], min=18, max=50, - on_value_commit=State.set_age, # pyright: ignore [reportAttributeAccessIssue] + on_value_commit=State.set_age, ), align_items="left", width="100%", @@ -110,7 +126,7 @@ class State(rx.State): default_value=[0, 25000000], min=0, max=25000000, - on_value_commit=State.set_salary, # pyright: ignore [reportAttributeAccessIssue] + on_value_commit=State.set_salary, ), align_items="left", width="100%", diff --git a/reflex/state.py b/reflex/state.py index 92aaa471041..dceba7e3bc6 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1742,6 +1742,9 @@ async def _process_event( Yields: StateUpdate object + + Raises: + ValueError: If a string value is received for an int or float type and cannot be converted. """ from reflex.utils import telemetry @@ -1779,12 +1782,25 @@ async def _process_event( hinted_args, (Base, BaseModelV1, BaseModelV2) ): payload[arg] = hinted_args(**value) - if isinstance(value, list) and (hinted_args is set or hinted_args is Set): + elif isinstance(value, list) and (hinted_args is set or hinted_args is Set): payload[arg] = set(value) - if isinstance(value, list) and ( + elif isinstance(value, list) and ( hinted_args is tuple or hinted_args is Tuple ): payload[arg] = tuple(value) + elif isinstance(value, str) and ( + hinted_args is int or hinted_args is float + ): + try: + payload[arg] = hinted_args(value) + except ValueError: + raise ValueError( + f"Received a string value ({value}) for {arg} but expected a {hinted_args}" + ) from None + else: + console.warn( + f"Received a string value ({value}) for {arg} but expected a {hinted_args}. A simple conversion was successful." + ) # Wrap the function in a try/except block. try: diff --git a/reflex/vars/base.py b/reflex/vars/base.py index ec65c37111d..f9ee2792b93 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -933,7 +933,7 @@ def _get_setter(self) -> Callable[[BaseState, Any], None]: """ actual_name = self._var_field_name - def setter(state: BaseState, value: Any): + def setter(state: Any, value: Any): """Get the setter for the var. Args: @@ -951,6 +951,8 @@ def setter(state: BaseState, value: Any): else: setattr(state, actual_name, value) + setter.__annotations__["value"] = self._var_type + setter.__qualname__ = self._get_setter_name() return setter diff --git a/tests/integration/test_background_task.py b/tests/integration/test_background_task.py index f312f81221f..91a1b5ae15e 100644 --- a/tests/integration/test_background_task.py +++ b/tests/integration/test_background_task.py @@ -20,7 +20,11 @@ def BackgroundTask(): class State(rx.State): counter: int = 0 _task_id: int = 0 - iterations: int = 10 + iterations: rx.Field[int] = rx.field(10) + + @rx.event + def set_iterations(self, value: str): + self.iterations = int(value) @rx.event(background=True) async def handle_event(self): @@ -125,8 +129,8 @@ def index() -> rx.Component: rx.input( id="iterations", placeholder="Iterations", - value=State.iterations.to_string(), # pyright: ignore [reportAttributeAccessIssue] - on_change=State.set_iterations, # pyright: ignore [reportAttributeAccessIssue] + value=State.iterations.to_string(), + on_change=State.set_iterations, ), rx.button( "Delayed Increment",