From 3d50c1b623c70a0542d5f5edcb778342902baf39 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 22 Jan 2025 05:00:49 -0800 Subject: [PATCH 01/12] WiP --- reflex/app.py | 16 +- reflex/compiler/utils.py | 5 +- reflex/middleware/hydrate_middleware.py | 4 +- reflex/state.py | 433 ++++++++++++++---------- reflex/vars/base.py | 303 +++++++++++++---- tests/units/test_app.py | 18 +- tests/units/test_state.py | 119 +++++-- tests/units/test_var.py | 29 +- 8 files changed, 654 insertions(+), 273 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 7e868e73056..6b9a64ca7a1 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -833,11 +833,17 @@ def _validate_var_dependencies( if not var._cache: continue deps = var._deps(objclass=state) - for dep in deps: - if dep not in state.vars and dep not in state.backend_vars: - raise exceptions.VarDependencyError( - f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {dep}" - ) + for state_name, dep_set in deps.items(): + state_cls = ( + state.get_root_state().get_class_substate(state_name) + if state_name != state.get_full_name() + else state + ) + for dep in dep_set: + if dep not in state_cls.vars and dep not in state_cls.backend_vars: + raise exceptions.VarDependencyError( + f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {state_name}.{dep}" + ) for substate in state.class_subclasses: self._validate_var_dependencies(substate) diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index c0ba28f4b36..f5e79a79671 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio from pathlib import Path from typing import Any, Callable, Dict, Optional, Type, Union from urllib.parse import urlparse @@ -29,7 +30,7 @@ ) from reflex.components.component import Component, ComponentStyle, CustomComponent from reflex.istate.storage import Cookie, LocalStorage, SessionStorage -from reflex.state import BaseState +from reflex.state import BaseState, _resolve_delta from reflex.style import Style from reflex.utils import console, format, imports, path_ops from reflex.utils.imports import ImportVar, ParsedImportDict @@ -169,7 +170,7 @@ def compile_state(state: Type[BaseState]) -> dict: initial_state = state(_reflex_internal_init=True).dict( initial=True, include_computed=False ) - return initial_state + return asyncio.run(_resolve_delta(initial_state)) def _compile_client_storage_field( diff --git a/reflex/middleware/hydrate_middleware.py b/reflex/middleware/hydrate_middleware.py index 2198b82c2c8..2dea54e1712 100644 --- a/reflex/middleware/hydrate_middleware.py +++ b/reflex/middleware/hydrate_middleware.py @@ -8,7 +8,7 @@ from reflex import constants from reflex.event import Event, get_hydrate_event from reflex.middleware.middleware import Middleware -from reflex.state import BaseState, StateUpdate +from reflex.state import BaseState, StateUpdate, _resolve_delta if TYPE_CHECKING: from reflex.app import App @@ -42,7 +42,7 @@ async def preprocess( setattr(state, constants.CompileVars.IS_HYDRATED, False) # Get the initial state. - delta = state.dict() + delta = await _resolve_delta(state.dict()) # since a full dict was captured, clean any dirtiness state._clean() diff --git a/reflex/state.py b/reflex/state.py index 66098d23214..6ef3ff3e8db 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -328,6 +328,25 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField): ) +async def _resolve_delta(delta: Delta) -> Delta: + """Await all coroutines in the delta. + + Args: + delta: The delta to process. + + Returns: + The same delta dict with all coroutines resolved to their return value. + """ + tasks = {} + for state_name, state_delta in delta.items(): + for var_name, value in state_delta.items(): + if asyncio.iscoroutine(value): + tasks[state_name, var_name] = asyncio.create_task(value) + for (state_name, var_name), task in tasks.items(): + delta[state_name][var_name] = await task + return delta + + class BaseState(Base, ABC, extra=pydantic.Extra.allow): """The state of the app.""" @@ -355,11 +374,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # A set of subclassses of this class. class_subclasses: ClassVar[Set[Type[BaseState]]] = set() - # Mapping of var name to set of computed variables that depend on it - _computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {} - - # Mapping of var name to set of substates that depend on it - _substate_var_dependencies: ClassVar[Dict[str, Set[str]]] = {} + # Mapping of var name to set of (state_full_name, var_name) that depend on it. + _var_dependencies: ClassVar[Dict[str, Set[Tuple[str, str]]]] = {} # Set of vars which always need to be recomputed _always_dirty_computed_vars: ClassVar[Set[str]] = set() @@ -367,6 +383,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Set of substates which always need to be recomputed _always_dirty_substates: ClassVar[Set[str]] = set() + # Set of states which might need to be recomputed if vars in this state change. + _potentially_dirty_states: ClassVar[Set[str]] = set() + # The parent state. parent_state: Optional[BaseState] = None @@ -518,6 +537,7 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): # Reset dirty substate tracking for this class. cls._always_dirty_substates = set() + cls._potentially_dirty_states = set() # Get the parent vars. parent_state = cls.get_parent_state() @@ -621,8 +641,7 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): setattr(cls, name, handler) # Initialize per-class var dependency tracking. - cls._computed_var_dependencies = defaultdict(set) - cls._substate_var_dependencies = defaultdict(set) + cls._var_dependencies = {} cls._init_var_dependency_dicts() @staticmethod @@ -767,26 +786,25 @@ def _init_var_dependency_dicts(cls): Additional updates tracking dicts for vars and substates that always need to be recomputed. """ - inherited_vars = set(cls.inherited_vars).union( - set(cls.inherited_backend_vars), - ) for cvar_name, cvar in cls.computed_vars.items(): - # Add the dependencies. - for var in cvar._deps(objclass=cls): - cls._computed_var_dependencies[var].add(cvar_name) - if var in inherited_vars: - # track that this substate depends on its parent for this var - state_name = cls.get_name() - parent_state = cls.get_parent_state() - while parent_state is not None and var in { - **parent_state.vars, - **parent_state.backend_vars, + if not cvar._cache: + # Do not perform dep calculation when cache=False (these are always dirty). + continue + for state_name, dvar_set in cvar._deps(objclass=cls).items(): + state_cls = cls.get_root_state().get_class_substate(state_name) + for dvar in dvar_set: + defining_state_cls = state_cls + while dvar in { + *defining_state_cls.inherited_vars, + *defining_state_cls.inherited_backend_vars, }: - parent_state._substate_var_dependencies[var].add(state_name) - state_name, parent_state = ( - parent_state.get_name(), - parent_state.get_parent_state(), - ) + defining_state_cls = defining_state_cls.get_parent_state() + defining_state_cls._var_dependencies.setdefault(dvar, set()).add( + (cls.get_full_name(), cvar_name) + ) + defining_state_cls._potentially_dirty_states.add( + cls.get_full_name() + ) # ComputedVar with cache=False always need to be recomputed cls._always_dirty_computed_vars = { @@ -901,6 +919,17 @@ def get_parent_state(cls) -> Type[BaseState] | None: raise ValueError(f"Only one parent state is allowed {parent_states}.") return parent_states[0] if len(parent_states) == 1 else None # type: ignore + @classmethod + @functools.lru_cache() + def get_root_state(cls) -> Type[BaseState]: + """Get the root state. + + Returns: + The root state. + """ + parent_state = cls.get_parent_state() + return cls if parent_state is None else parent_state.get_root_state() + @classmethod def get_substates(cls) -> set[Type[BaseState]]: """Get the substates of the state. @@ -1353,7 +1382,7 @@ def __setattr__(self, name: str, value: Any): super().__setattr__(name, value) # Add the var to the dirty list. - if name in self.vars or name in self._computed_var_dependencies: + if name in self.base_vars: self.dirty_vars.add(name) self._mark_dirty() @@ -1423,6 +1452,23 @@ def get_substate(self, path: Sequence[str]) -> BaseState: raise ValueError(f"Invalid path: {path}") return self.substates[path[0]].get_substate(path[1:]) + @classmethod + def _get_potentially_dirty_states(cls) -> set[type[BaseState]]: + """Get substates which may have dirty vars due to dependencies. + + Returns: + The set of potentially dirty substate classes. + """ + return { + cls.get_class_substate(substate_name) + for substate_name in cls._always_dirty_substates + }.union( + { + cls.get_root_state().get_class_substate(substate_name) + for substate_name in cls._potentially_dirty_states + } + ) + @classmethod def _get_common_ancestor(cls, other: Type[BaseState]) -> str: """Find the name of the nearest common ancestor shared by this and the other state. @@ -1493,55 +1539,37 @@ def _get_root_state(self) -> BaseState: parent_state = parent_state.parent_state return parent_state - async def _populate_parent_states(self, target_state_cls: Type[BaseState]): - """Populate substates in the tree between the target_state_cls and common ancestor of this state. + async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE: + """Get a state instance from redis. Args: - target_state_cls: The class of the state to populate parent states for. + state_cls: The class of the state. Returns: - The parent state instance of target_state_cls. + The instance of state_cls associated with this state's client_token. Raises: RuntimeError: If redis is not used in this backend process. + StateMismatchError: If the state instance is not of the expected type. """ + # Then get the target state and all its substates. state_manager = get_state_manager() if not isinstance(state_manager, StateManagerRedis): raise RuntimeError( - f"Cannot populate parent states of {target_state_cls.get_full_name()} without redis. " + f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. " "(All states should already be available -- this is likely a bug).", ) + state_in_redis = await state_manager._link_arbitrary_state( + self, + state_cls, + ) - # Find the missing parent states up to the common ancestor. - ( - common_ancestor_name, - missing_parent_states, - ) = self._determine_missing_parent_states(target_state_cls) - - # Fetch all missing parent states and link them up to the common ancestor. - parent_states_tuple = self._get_parent_states() - root_state = parent_states_tuple[-1][1] - parent_states_by_name = dict(parent_states_tuple) - parent_state = parent_states_by_name[common_ancestor_name] - for parent_state_name in missing_parent_states: - try: - parent_state = root_state.get_substate(parent_state_name.split(".")) - # The requested state is already cached, do NOT fetch it again. - continue - except ValueError: - # The requested state is missing, fetch from redis. - pass - parent_state = await state_manager.get_state( - token=_substate_key( - self.router.session.client_token, parent_state_name - ), - top_level=False, - get_substates=False, - parent_state=parent_state, + if not isinstance(state_in_redis, state_cls): + raise StateMismatchError( + f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}." ) - # Return the direct parent of target_state_cls for subsequent linking. - return parent_state + return state_in_redis def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE: """Get a state instance from the cache. @@ -1563,44 +1591,6 @@ def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE: ) return substate - async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE: - """Get a state instance from redis. - - Args: - state_cls: The class of the state. - - Returns: - The instance of state_cls associated with this state's client_token. - - Raises: - RuntimeError: If redis is not used in this backend process. - StateMismatchError: If the state instance is not of the expected type. - """ - # Fetch all missing parent states from redis. - parent_state_of_state_cls = await self._populate_parent_states(state_cls) - - # Then get the target state and all its substates. - state_manager = get_state_manager() - if not isinstance(state_manager, StateManagerRedis): - raise RuntimeError( - f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. " - "(All states should already be available -- this is likely a bug).", - ) - - state_in_redis = await state_manager.get_state( - token=_substate_key(self.router.session.client_token, state_cls), - top_level=False, - get_substates=True, - parent_state=parent_state_of_state_cls, - ) - - if not isinstance(state_in_redis, state_cls): - raise StateMismatchError( - f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}." - ) - - return state_in_redis - async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE: """Get an instance of the state associated with this token. @@ -1737,7 +1727,7 @@ def _is_valid_type(events: Any) -> bool: f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)" ) - def _as_state_update( + async def _as_state_update( self, handler: EventHandler, events: EventSpec | list[EventSpec] | None, @@ -1765,7 +1755,7 @@ def _as_state_update( try: # Get the delta after processing the event. - delta = state.get_delta() + delta = await _resolve_delta(state.get_delta()) state._clean() return StateUpdate( @@ -1865,24 +1855,28 @@ async def _process_event( # Handle async generators. if inspect.isasyncgen(events): async for event in events: - yield state._as_state_update(handler, event, final=False) - yield state._as_state_update(handler, events=None, final=True) + yield await state._as_state_update(handler, event, final=False) + yield await state._as_state_update(handler, events=None, final=True) # Handle regular generators. elif inspect.isgenerator(events): try: while True: - yield state._as_state_update(handler, next(events), final=False) + yield await state._as_state_update( + handler, next(events), final=False + ) except StopIteration as si: # the "return" value of the generator is not available # in the loop, we must catch StopIteration to access it if si.value is not None: - yield state._as_state_update(handler, si.value, final=False) - yield state._as_state_update(handler, events=None, final=True) + yield await state._as_state_update( + handler, si.value, final=False + ) + yield await state._as_state_update(handler, events=None, final=True) # Handle regular event chains. else: - yield state._as_state_update(handler, events, final=True) + yield await state._as_state_update(handler, events, final=True) # If an error occurs, throw a window alert. except Exception as ex: @@ -1892,7 +1886,7 @@ async def _process_event( prerequisites.get_and_validate_app().app.backend_exception_handler(ex) ) - yield state._as_state_update( + yield await state._as_state_update( handler, event_specs, final=True, @@ -1900,15 +1894,28 @@ async def _process_event( def _mark_dirty_computed_vars(self) -> None: """Mark ComputedVars that need to be recalculated based on dirty_vars.""" + # Append expired computed vars to dirty_vars to trigger recalculation + self.dirty_vars.update(self._expired_computed_vars()) + # Append always dirty computed vars to dirty_vars to trigger recalculation + self.dirty_vars.update(self._always_dirty_computed_vars) + dirty_vars = self.dirty_vars while dirty_vars: calc_vars, dirty_vars = dirty_vars, set() - for cvar in self._dirty_computed_vars(from_vars=calc_vars): - self.dirty_vars.add(cvar) + for state_name, cvar in self._dirty_computed_vars(from_vars=calc_vars): + if state_name == self.get_full_name(): + defining_state = self + else: + defining_state = self._get_root_state().get_substate( + tuple(state_name.split(".")) + ) + defining_state.dirty_vars.add(cvar) dirty_vars.add(cvar) - actual_var = self.computed_vars.get(cvar) + actual_var = defining_state.computed_vars.get(cvar) if actual_var is not None: - actual_var.mark_dirty(instance=self) + actual_var.mark_dirty(instance=defining_state) + if defining_state is not self: + defining_state._mark_dirty() def _expired_computed_vars(self) -> set[str]: """Determine ComputedVars that need to be recalculated based on the expiration time. @@ -1924,7 +1931,7 @@ def _expired_computed_vars(self) -> set[str]: def _dirty_computed_vars( self, from_vars: set[str] | None = None, include_backend: bool = True - ) -> set[str]: + ) -> set[tuple[str, str]]: """Determine ComputedVars that need to be recalculated based on the given vars. Args: @@ -1935,32 +1942,59 @@ def _dirty_computed_vars( Set of computed vars to include in the delta. """ return { - cvar + (state_name, cvar) for dirty_var in from_vars or self.dirty_vars - for cvar in self._computed_var_dependencies[dirty_var] + for state_name, cvar in self._var_dependencies.get(dirty_var, set()) if include_backend or not self.computed_vars[cvar]._backend } - @classmethod - def _potentially_dirty_substates(cls) -> set[Type[BaseState]]: - """Determine substates which could be affected by dirty vars in this state. + async def _recursively_populate_dependent_substates( + self, + seen_classes: set[type[BaseState]] | None = None, + ) -> set[type[BaseState]]: + """Fetch all substates that have computed var dependencies on this state. + + Args: + seen_classes: set of classes that have already been seen to prevent infinite recursion. Returns: - Set of State classes that may need to be fetched to recalc computed vars. + The set of classes that were processed (mostly for testability). """ - # _always_dirty_substates need to be fetched to recalc computed vars. - fetch_substates = { - cls.get_class_substate((cls.get_name(), *substate_name.split("."))) - for substate_name in cls._always_dirty_substates - } - for dependent_substates in cls._substate_var_dependencies.values(): - fetch_substates.update( - { - cls.get_class_substate((cls.get_name(), *substate_name.split("."))) - for substate_name in dependent_substates - } + if seen_classes is None: + print( + f"\n\nTop-level _recursively_populate_dependent_substates from {type(self)}:" + ) + seen_classes = set() + if type(self) in seen_classes: + return seen_classes + seen_classes.add(type(self)) + populated_substate_instances = {} + for substate_cls in { + self.get_class_substate((self.get_name(), *substate_name.split("."))) + for substate_name in self._always_dirty_substates + }: + # _always_dirty_substates need to be fetched to recalc computed vars. + if substate_cls not in populated_substate_instances: + print(f"fetching always dirty {substate_cls}") + populated_substate_instances[substate_cls] = await self.get_state( + substate_cls + ) + for dep_set in self._var_dependencies.values(): + for substate_name, _ in dep_set: + if substate_name == self.get_full_name(): + # Do NOT fetch our own state instance. + continue + substate_cls = self.get_root_state().get_class_substate(substate_name) + if substate_cls not in populated_substate_instances: + print(f"fetching dependent {substate_cls}") + populated_substate_instances[substate_cls] = await self.get_state( + substate_cls + ) + for substate in populated_substate_instances.values(): + await substate._recursively_populate_dependent_substates( + seen_classes=seen_classes, ) - return fetch_substates + return seen_classes def get_delta(self) -> Delta: """Get the delta for the state. @@ -1970,21 +2004,15 @@ def get_delta(self) -> Delta: """ delta = {} - # Apply dirty variables down into substates - self.dirty_vars.update(self._always_dirty_computed_vars) - self._mark_dirty() - + self._mark_dirty_computed_vars() frontend_computed_vars: set[str] = { name for name, cv in self.computed_vars.items() if not cv._backend } # Return the dirty vars for this instance, any cached/dependent computed vars, # and always dirty computed vars (cache=False) - delta_vars = ( - self.dirty_vars.intersection(self.base_vars) - .union(self.dirty_vars.intersection(frontend_computed_vars)) - .union(self._dirty_computed_vars(include_backend=False)) - .union(self._always_dirty_computed_vars) + delta_vars = self.dirty_vars.intersection(self.base_vars).union( + self.dirty_vars.intersection(frontend_computed_vars) ) subdelta: Dict[str, Any] = { @@ -2014,23 +2042,9 @@ def _mark_dirty(self): self.parent_state.dirty_substates.add(self.get_name()) self.parent_state._mark_dirty() - # Append expired computed vars to dirty_vars to trigger recalculation - self.dirty_vars.update(self._expired_computed_vars()) - # have to mark computed vars dirty to allow access to newly computed # values within the same ComputedVar function self._mark_dirty_computed_vars() - self._mark_dirty_substates() - - def _mark_dirty_substates(self): - """Propagate dirty var / computed var status into substates.""" - substates = self.substates - for var in self.dirty_vars: - for substate_name in self._substate_var_dependencies[var]: - self.dirty_substates.add(substate_name) - substate = substates[substate_name] - substate.dirty_vars.add(var) - substate._mark_dirty() def _update_was_touched(self): """Update the _was_touched flag based on dirty_vars.""" @@ -2102,11 +2116,7 @@ def dict( The object as a dictionary. """ if include_computed: - # Apply dirty variables down into substates to allow never-cached ComputedVar to - # trigger recalculation of dependent vars - self.dirty_vars.update(self._always_dirty_computed_vars) - self._mark_dirty() - + self._mark_dirty_computed_vars() base_vars = { prop_name: self.get_value(prop_name) for prop_name in self.base_vars } @@ -3339,6 +3349,79 @@ async def _get_parent_state( ) return parent_state + async def _populate_parent_states( + self, calling_state: BaseState, target_state_cls: Type[BaseState] + ): + """Populate substates in the tree between the target_state_cls and common ancestor of calling_state. + + Args: + calling_state: The substate instance requesting subtree population. + target_state_cls: The class of the state to populate parent states for. + + Returns: + The parent state instance of target_state_cls. + """ + # Find the missing parent states up to the common ancestor. + ( + common_ancestor_name, + missing_parent_states, + ) = calling_state._determine_missing_parent_states(target_state_cls) + + # Fetch all missing parent states and link them up to the common ancestor. + parent_states_tuple = calling_state._get_parent_states() + root_state = parent_states_tuple[-1][1] + parent_states_by_name = dict(parent_states_tuple) + parent_state = parent_states_by_name[common_ancestor_name] + for parent_state_name in missing_parent_states: + try: + parent_state = root_state.get_substate(parent_state_name.split(".")) + # The requested state is already cached, do NOT fetch it again. + continue + except ValueError: + # The requested state is missing, fetch from redis. + pass + parent_state = await self.get_state( + token=_substate_key( + calling_state.router.session.client_token, parent_state_name + ), + top_level=False, + get_substates=False, + parent_state=parent_state, + ) + + # Return the direct parent of target_state_cls for subsequent linking. + return parent_state + + async def _link_arbitrary_state( + self, calling_state: BaseState, state_cls: Type[T_STATE] + ) -> T_STATE: + """Get a state instance from redis. + + Args: + calling_state: The state instance requesting the newly linked instance of state_cls. + state_cls: The class of the state to link into the tree. + + Returns: + The instance of state_cls associated with calling_state's client_token. + + Raises: + StateMismatchError: If the state instance is not of the expected type. + """ + # Fetch all missing parent states from redis. + parent_state_of_state_cls = await self._populate_parent_states( + calling_state, state_cls + ) + + # Then get the target state and all its substates. + state_in_redis = await self.get_state( + token=_substate_key(calling_state.router.session.client_token, state_cls), + top_level=False, + get_substates=True, + parent_state=parent_state_of_state_cls, + ) + + return state_in_redis + async def _populate_substates( self, token: str, @@ -3357,30 +3440,40 @@ async def _populate_substates( """ client_token, _ = _split_substate_key(token) + # Only _potentially_dirty_substates need to be fetched to recalc computed vars. + fetch_substates = state._get_potentially_dirty_states() if all_substates: # All substates are requested. - fetch_substates = state.get_substates() - else: - # Only _potentially_dirty_substates need to be fetched to recalc computed vars. - fetch_substates = state._potentially_dirty_substates() + fetch_substates.update(state.get_substates()) tasks = {} + link_tasks = set() # Retrieve the necessary substates from redis. for substate_cls in fetch_substates: if substate_cls.get_name() in state.substates: continue substate_name = substate_cls.get_name() - tasks[substate_name] = asyncio.create_task( - self.get_state( - token=_substate_key(client_token, substate_cls), - top_level=False, - get_substates=all_substates, - parent_state=state, + if substate_cls in state.get_substates(): + tasks[substate_name] = asyncio.create_task( + self.get_state( + token=_substate_key(client_token, substate_cls), + top_level=False, + get_substates=all_substates, + parent_state=state, + ) ) - ) + else: + try: + state._get_root_state().get_substate(substate_name.split(".")) + except ValueError: + # The requested state is missing, so fetch and link it (and its parents). + link_tasks.add( + asyncio.create_task(self._link_arbitrary_state(state, substate_cls)) + ) for substate_name, substate_task in tasks.items(): state.substates[substate_name] = await substate_task + await asyncio.gather(*link_tasks) @override async def get_state( @@ -4153,7 +4246,7 @@ def reload_state_module( if subclass.__module__ == module and module is not None: state.class_subclasses.remove(subclass) state._always_dirty_substates.discard(subclass.get_name()) - state._computed_var_dependencies = defaultdict(set) - state._substate_var_dependencies = defaultdict(set) + state._potentially_dirty_substates.discard(subclass.get_name()) + state._var_dependencies = {} state._init_var_dependency_dicts() state.get_class_substate.cache_clear() diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 12254518763..6bc5b25c41d 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -1826,7 +1826,7 @@ class ComputedVar(Var[RETURN_TYPE]): _initial_value: RETURN_TYPE | types.Unset = dataclasses.field(default=types.Unset()) # Explicit var dependencies to track - _static_deps: set[str] = dataclasses.field(default_factory=set) + _static_deps: dict[str, set[str]] = dataclasses.field(default_factory=dict) # Whether var dependencies should be auto-determined _auto_deps: bool = dataclasses.field(default=True) @@ -1901,21 +1901,40 @@ def __init__( object.__setattr__(self, "_update_interval", interval) - if deps is None: - deps = [] - else: + _static_deps = {} + if isinstance(deps, dict): + # Assume a dict is coming from _replace, so no special processing. + _static_deps = deps + elif deps is not None: for dep in deps: if isinstance(dep, Var): - continue - if isinstance(dep, str) and dep != "": - continue - raise TypeError( - "ComputedVar dependencies must be Var instances or var names (non-empty strings)." - ) + state_name = ( + all_var_data.state + if (all_var_data := dep._get_all_var_data()) + and all_var_data.state + else None + ) + var_name = ( + dep._js_expr[len(formatted_state_prefix) :] + if state_name + and ( + formatted_state_prefix := format_state_name(state_name) + + "." + ) + and dep._js_expr.startswith(formatted_state_prefix) + else dep._js_expr + ) + _static_deps.setdefault(state_name, set()).add(var_name) + elif isinstance(dep, str) and dep != "": + _static_deps.setdefault(None, set()).add(dep) + else: + raise TypeError( + "ComputedVar dependencies must be Var instances or var names (non-empty strings)." + ) object.__setattr__( self, "_static_deps", - {dep._js_expr if isinstance(dep, Var) else dep for dep in deps}, + _static_deps, ) object.__setattr__(self, "_auto_deps", auto_deps) @@ -2081,6 +2100,11 @@ def __get__(self, instance: BaseState | None, owner): setattr(instance, self._last_updated_attr, datetime.datetime.now()) value = getattr(instance, self._cache_attr) + self._check_deprecated_return_type(instance, value) + + return value + + def _check_deprecated_return_type(self, instance, value) -> None: if not _isinstance(value, self._var_type): console.deprecate( "mismatched-computed-var-return", @@ -2090,41 +2114,49 @@ def __get__(self, instance: BaseState | None, owner): "0.7.0", ) - return value - def _deps( self, - objclass: Type, + objclass: BaseState, obj: FunctionType | CodeType | None = None, - self_name: Optional[str] = None, - ) -> set[str]: + self_names: Optional[dict[str, str]] = None, + ) -> dict[str, set[str]]: """Determine var dependencies of this ComputedVar. - Save references to attributes accessed on "self". Recursively called - when the function makes a method call on "self" or define comprehensions - or nested functions that may reference "self". + Save references to attributes accessed on "self" or other fetched states. + + Recursively called when the function makes a method call on "self" or + define comprehensions or nested functions that may reference "self". Args: objclass: the class obj this ComputedVar is attached to. obj: the object to disassemble (defaults to the fget function). - self_name: if specified, look for this name in LOAD_FAST and LOAD_DEREF instructions. + self_names: if specified, look for these names in LOAD_FAST and LOAD_DEREF instructions. Returns: - A set of variable names accessed by the given obj. + A dictionary mapping state names to the set of variable names + accessed by the given obj. Raises: VarValueError: if the function references the get_state, parent_state, or substates attributes (cannot track deps in a related state, only implicitly via parent state). """ + from reflex.state import BaseState + + d = {} + if self._static_deps: + d.update(self._static_deps) + # None is a placeholder for the current state class. + if None in d: + d[objclass.get_full_name()] = d.pop(None) + if not self._auto_deps: - return self._static_deps - d = self._static_deps.copy() + return d if obj is None: fget = self._fget if fget is not None: obj = cast(FunctionType, fget) else: - return set() + return d with contextlib.suppress(AttributeError): # unbox functools.partial obj = cast(FunctionType, obj.func) # type: ignore @@ -2132,76 +2164,150 @@ def _deps( # unbox EventHandler obj = cast(FunctionType, obj.fn) # type: ignore - if self_name is None and isinstance(obj, FunctionType): + if self_names is None and isinstance(obj, FunctionType): try: # the first argument to the function is the name of "self" arg - self_name = obj.__code__.co_varnames[0] + self_names = {obj.__code__.co_varnames[0]: objclass.get_full_name()} except (AttributeError, IndexError): - self_name = None - if self_name is None: + self_names = None + if self_names is None: # cannot reference attributes on self if method takes no args - return set() + return d - invalid_names = ["get_state", "parent_state", "substates", "get_substate"] - self_is_top_of_stack = False + invalid_names = ["parent_state", "substates", "get_substate"] + self_on_top_of_stack = None + getting_state = False + getting_var = False for instruction in dis.get_instructions(obj): + if getting_state: + if instruction.opname == "LOAD_FAST": + raise VarValueError( + f"Dependency detection cannot identify get_state class from local var {instruction.argval}." + ) + if instruction.opname == "LOAD_GLOBAL": + # Special case: referencing state class from global scope. + getting_state = obj.__globals__.get(instruction.argval) + elif instruction.opname == "LOAD_DEREF": + # Special case: referencing state class from closure. + closure = dict(zip(obj.__code__.co_freevars, obj.__closure__)) + try: + getting_state = closure[instruction.argval].cell_contents + except ValueError as ve: + raise VarValueError( + f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?." + ) from ve + elif instruction.opname == "STORE_FAST": + # Storing the result of get_state in a local variable. + if not isinstance(getting_state, type) or not issubclass( + getting_state, BaseState + ): + raise VarValueError( + f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`." + ) + self_names[instruction.argval] = getting_state.get_full_name() + getting_state = False + continue # nothing else happens until we have identified the local var + if getting_var: + if instruction.opname == "CALL": + # get the original source code and eval it + start_line = getting_var[0].positions.lineno + start_column = getting_var[0].positions.col_offset + end_line = getting_var[-1].positions.end_lineno + end_column = getting_var[-1].positions.end_col_offset + source = inspect.getsource(inspect.getmodule(obj)).splitlines(True)[start_line - 1: end_line] + if len(source) > 1: + snipped_source = "".join( + [ + source[0][start_column:], + source[1:-2] if len(source) > 2 else "", + source[-1][:end_column] + ] + ) + else: + snipped_source = source[0][start_column:end_column] + the_var = eval(f"({snipped_source})", obj.__globals__) + print(the_var) + # code = source[start_line - 1] + # bytecode = bytearray((dis.opmap["RESUME"], 0)) + # for ins in getting_var: + # bytecode.append(ins.opcode) + # bytecode.append(ins.arg or 0 & 0xFF) + # bytecode.extend((dis.opmap["RETURN_VALUE"], 0)) + # bc = dis.Bytecode(obj) + # code = bc.codeobj.replace(co_code=bytes(bytecode), co_argcount=0, co_nlocals=0, co_varnames=()) + # breakpoint() + getting_var = False + elif isinstance(getting_var, list): + getting_var.append(instruction) + else: + getting_var = [instruction] + continue if ( instruction.opname in ("LOAD_FAST", "LOAD_DEREF") - and instruction.argval == self_name + and instruction.argval in self_names ): # bytecode loaded the class instance to the top of stack, next load instruction # is referencing an attribute on self - self_is_top_of_stack = True + self_on_top_of_stack = self_names[instruction.argval] continue - if self_is_top_of_stack and instruction.opname in ( + if self_on_top_of_stack and instruction.opname in ( "LOAD_ATTR", "LOAD_METHOD", ): - try: - ref_obj = getattr(objclass, instruction.argval) - except Exception: - ref_obj = None if instruction.argval in invalid_names: raise VarValueError( f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`." ) + if instruction.argval == "get_state": + # Special case: arbitrary state access requested. + getting_state = True + continue + if instruction.argval == "get_var_value": + # Special case: arbitrary var access requested. + getting_var = True + continue + print(f"{self_on_top_of_stack=}") + target_state = objclass.get_root_state().get_class_substate( + self_on_top_of_stack + ) + try: + ref_obj = getattr(target_state, instruction.argval) + except Exception: + ref_obj = None if callable(ref_obj): # recurse into callable attributes - d.update( - self._deps( - objclass=objclass, - obj=ref_obj, - ) - ) + for state_name, dep_name in self._deps( + objclass=target_state, + obj=ref_obj, + ).items(): + d.setdefault(state_name, set()).update(dep_name) # recurse into property fget functions elif isinstance(ref_obj, property) and not isinstance( ref_obj, ComputedVar ): - d.update( - self._deps( - objclass=objclass, - obj=ref_obj.fget, # type: ignore - ) - ) + for state_name, dep_name in self._deps( + objclass=target_state, + obj=ref_obj.fget, # type: ignore + ).items(): + d.setdefault(state_name, set()).update(dep_name) elif ( - instruction.argval in objclass.backend_vars - or instruction.argval in objclass.vars + instruction.argval in target_state.backend_vars + or instruction.argval in target_state.vars ): # var access - d.add(instruction.argval) + d.setdefault(self_on_top_of_stack, set()).add(instruction.argval) elif instruction.opname == "LOAD_CONST" and isinstance( instruction.argval, CodeType ): # recurse into nested functions / comprehensions, which can reference # instance attributes from the outer scope - d.update( - self._deps( - objclass=objclass, - obj=instruction.argval, - self_name=self_name, - ) - ) - self_is_top_of_stack = False + for state_name, dep_name in self._deps( + objclass=objclass, + obj=instruction.argval, + self_names=self_names, + ).items(): + d.setdefault(state_name, set()).update(dep_name) + self_on_top_of_stack = None return d def mark_dirty(self, instance) -> None: @@ -2249,6 +2355,60 @@ class DynamicRouteVar(ComputedVar[Union[str, List[str]]]): pass +@dataclasses.dataclass( + eq=False, + frozen=True, + init=False, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class AsyncComputedVar(ComputedVar[RETURN_TYPE]): + """A computed var that wraps a coroutinefunction.""" + + _fget: Callable[[BaseState], RETURN_TYPE] = dataclasses.field( + default_factory=lambda: lambda _: None + ) # type: ignore + + def __get__(self, instance: BaseState | None, owner): + """Get the ComputedVar value. + + If the value is already cached on the instance, return the cached value. + + Args: + instance: the instance of the class accessing this computed var. + owner: the class that this descriptor is attached to. + + Returns: + The value of the var for the given instance. + """ + if instance is None: + return super(AsyncComputedVar, self).__get__(instance, owner) + + if not self._cache: + + async def _awaitable_result(): + value = await self.fget(instance) + self._check_deprecated_return_type(instance, value) + + return _awaitable_result() + else: + # handle caching + async def _awaitable_result(): + if not hasattr(instance, self._cache_attr) or self.needs_update( + instance + ): + # Set cache attr on state instance. + setattr(instance, self._cache_attr, await self.fget(instance)) + # Ensure the computed var gets serialized to redis. + instance._was_touched = True + # Set the last updated timestamp on the state instance. + setattr(instance, self._last_updated_attr, datetime.datetime.now()) + value = getattr(instance, self._cache_attr) + self._check_deprecated_return_type(instance, value) + return value + + return _awaitable_result() + + if TYPE_CHECKING: BASE_STATE = TypeVar("BASE_STATE", bound=BaseState) @@ -2315,10 +2475,27 @@ def computed_var( raise VarDependencyError("Cannot track dependencies without caching.") if fget is not None: - return ComputedVar(fget, cache=cache) + if inspect.iscoroutinefunction(fget): + computed_var_cls = AsyncComputedVar + else: + computed_var_cls = ComputedVar + return computed_var_cls( + fget, + initial_value=initial_value, + cache=cache, + deps=deps, + auto_deps=auto_deps, + interval=interval, + backend=backend, + **kwargs, + ) def wrapper(fget: Callable[[BASE_STATE], Any]) -> ComputedVar: - return ComputedVar( + if inspect.iscoroutinefunction(fget): + computed_var_cls = AsyncComputedVar + else: + computed_var_cls = ComputedVar + return computed_var_cls( fget, initial_value=initial_value, cache=cache, diff --git a/tests/units/test_app.py b/tests/units/test_app.py index f805f83eca7..bd9872949a6 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -277,9 +277,9 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool): assert app.pages.keys() == {"test/[dynamic]"} assert "dynamic" in app.state.computed_vars assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == { - constants.ROUTER + EmptyState.get_full_name(): {constants.ROUTER}, } - assert constants.ROUTER in app.state()._computed_var_dependencies + assert constants.ROUTER in app.state()._var_dependencies def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool): @@ -997,9 +997,9 @@ async def test_dynamic_route_var_route_change_completed_on_load( assert arg_name in app.state.vars assert arg_name in app.state.computed_vars assert app.state.computed_vars[arg_name]._deps(objclass=DynamicState) == { - constants.ROUTER + DynamicState.get_full_name(): {constants.ROUTER}, } - assert constants.ROUTER in app.state()._computed_var_dependencies + assert constants.ROUTER in app.state()._var_dependencies substate_token = _substate_key(token, DynamicState) sid = "mock_sid" @@ -1557,6 +1557,16 @@ def foo(self) -> str: def bar(self) -> str: return "bar" + class Child1(ValidDepState): + @computed_var(deps=["base", ValidDepState.bar]) + def other(self) -> str: + return "other" + + class Child2(ValidDepState): + @computed_var(deps=["base", Child1.other]) + def other(self) -> str: + return "other" + app.state = ValidDepState app._compile() diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 19f3e42392f..c5e2b12879b 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -1170,13 +1170,11 @@ def rendered_var(self) -> str: ms = MainState() # Initially there are no dirty computed vars. - assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"} - assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"} - assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"} + assert ms._dirty_computed_vars(from_vars={"flag"}) == {(MainState.get_full_name(), "rendered_var")} + assert ms._dirty_computed_vars(from_vars={"t2"}) == {(MainState.get_full_name(), "rendered_var")} + assert ms._dirty_computed_vars(from_vars={"t1"}) == {(MainState.get_full_name(), "rendered_var")} assert ms.computed_vars["rendered_var"]._deps(objclass=MainState) == { - "flag", - "t1", - "t2", + MainState.get_full_name(): {"flag", "t1", "t2"} } @@ -1371,7 +1369,7 @@ def cached_x_side_effect(self) -> int: assert isinstance(HandlerState.handler, EventHandler) s = HandlerState() - assert "cached_x_side_effect" in s._computed_var_dependencies["x"] + assert (HandlerState.get_full_name(), "cached_x_side_effect") in s._var_dependencies["x"] assert s.cached_x_side_effect == 1 assert s.x == 43 s.handler() @@ -1461,15 +1459,15 @@ def comp_z(self) -> List[bool]: return [z in self._z for z in range(5)] cs = ComputedState() - assert cs._computed_var_dependencies["v"] == { - "comp_v", - "comp_v_backend", - "comp_v_via_property", + assert cs._var_dependencies["v"] == { + (ComputedState.get_full_name(), "comp_v"), + (ComputedState.get_full_name(), "comp_v_backend"), + (ComputedState.get_full_name(), "comp_v_via_property"), } - assert cs._computed_var_dependencies["w"] == {"comp_w"} - assert cs._computed_var_dependencies["x"] == {"comp_x"} - assert cs._computed_var_dependencies["y"] == {"comp_y"} - assert cs._computed_var_dependencies["_z"] == {"comp_z"} + assert cs._var_dependencies["w"] == {(ComputedState.get_full_name(), "comp_w")} + assert cs._var_dependencies["x"] == {(ComputedState.get_full_name(), "comp_x")} + assert cs._var_dependencies["y"] == {(ComputedState.get_full_name(), "comp_y")} + assert cs._var_dependencies["_z"] == {(ComputedState.get_full_name(), "comp_z")} def test_backend_method(): @@ -3182,6 +3180,7 @@ class GreatGrandchild3(Grandchild3): RxState = State +@pytest.mark.skip(reason="This test is maybe not relevant anymore.") def test_potentially_dirty_substates(): """Test that potentially_dirty_substates returns the correct substates. @@ -3203,7 +3202,8 @@ def bar(self) -> str: assert C1._potentially_dirty_substates() == set() -def test_router_var_dep() -> None: +@pytest.mark.asyncio +async def test_router_var_dep() -> None: """Test that router var dependencies are correctly tracked.""" class RouterVarParentState(State): @@ -3221,13 +3221,9 @@ def foo(self) -> str: foo = RouterVarDepState.computed_vars["foo"] State._init_var_dependency_dicts() - assert foo._deps(objclass=RouterVarDepState) == {"router"} - assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState} - assert RouterVarParentState._substate_var_dependencies == { - "router": {RouterVarDepState.get_name()} - } - assert RouterVarDepState._computed_var_dependencies == { - "router": {"foo"}, + assert foo._deps(objclass=RouterVarDepState) == {RouterVarDepState.get_full_name(): {"router"}} + assert State._var_dependencies == { + "router": {(RouterVarDepState.get_full_name(), "foo")} } rx_state = State() @@ -3240,11 +3236,15 @@ def foo(self) -> str: state.parent_state = parent_state parent_state.substates = {RouterVarDepState.get_name(): state} + populated_substate_classes = await rx_state._recursively_populate_dependent_substates() + assert populated_substate_classes == {State, RouterVarDepState} + assert state.dirty_vars == set() # Reassign router var state.router = state.router - assert state.dirty_vars == {"foo", "router"} + assert rx_state.dirty_vars == {"router"} + assert state.dirty_vars == {"foo"} assert parent_state.dirty_substates == {RouterVarDepState.get_name()} @@ -3803,3 +3803,74 @@ async def test_get_var_value(state_manager: StateManager, substate_token: str): # Generic Var with no state with pytest.raises(UnretrievableVarValueError): await state.get_var_value(rx.Var("undefined")) + + +@pytest.mark.asyncio +async def test_async_computed_var_get_state(mock_app: rx.App, token: str): + """A test where an async computed var depends on a var in another state. + + Args: + mock_app: An app that will be returned by `get_app()` + token: A token. + """ + + class Parent(BaseState): + """A root state like rx.State.""" + + parent_var: int = 0 + + class Child2(Parent): + """An unconnected child state.""" + + pass + + class Child3(Parent): + """A child state with a computed var causing it to be pre-fetched. + + If child3_var gets set to a value, and `get_state` erroneously + re-fetches it from redis, the value will be lost. + """ + + child3_var: int = 0 + + @rx.var(cache=True) + def v(self): + return self.child3_var + + class Child(Parent): + """A state simulating UpdateVarsInternalState.""" + + @rx.var(cache=True) + async def v(self): + p = await self.get_state(Parent) + child3 = await self.get_state(Child3) + return child3.child3_var + p.parent_var + + mock_app.state_manager.state = mock_app.state = Parent + + # Get the top level state via unconnected sibling. + root = await mock_app.state_manager.get_state(_substate_key(token, Child)) + # Set value in parent_var to assert it does not get refetched later. + root.parent_var = 1 + + if isinstance(mock_app.state_manager, StateManagerRedis): + # When redis is used, only states with uncached computed vars are pre-fetched. + assert Child2.get_name() not in root.substates + assert Child3.get_name() not in root.substates + + # Get the unconnected sibling state, which will be used to `get_state` other instances. + child = root.get_substate(Child.get_full_name().split(".")) + + # Get an uncached child state. + child2 = await child.get_state(Child2) + assert child2.parent_var == 1 + + # Set value on already-cached Child3 state (prefetched because it has a Computed Var). + child3 = await child.get_state(Child3) + child3.child3_var = 1 + + assert await child.v == 2 + assert await child.v == 2 + root.parent_var = 2 + assert await child.v == 3 + diff --git a/tests/units/test_var.py b/tests/units/test_var.py index a8e9cd88cec..e34fb88710e 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -15,6 +15,7 @@ from reflex.utils.imports import ImportVar from reflex.vars import VarData from reflex.vars.base import ( + AsyncComputedVar, ComputedVar, LiteralVar, Var, @@ -1808,9 +1809,9 @@ def cv_fget(state: BaseState) -> int: @pytest.mark.parametrize( "deps,expected", [ - (["a"], {"a"}), - (["b"], {"b"}), - ([ComputedVar(fget=cv_fget)], {"cv_fget"}), + (["a"], {None: {"a"}}), + (["b"], {None: {"b"}}), + ([ComputedVar(fget=cv_fget)], {None: {"cv_fget"}}), ], ) def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]): @@ -1856,3 +1857,25 @@ class TestState(BaseState): single_var = Var.create(Email()) assert single_var._var_type == Email + + +@pytest.mark.asyncio +async def test_async_computed_var(): + side_effect_counter = 0 + + class AsyncComputedVarState(BaseState): + v: int = 1 + + @computed_var(cache=True) + async def async_computed_var(self) -> int: + nonlocal side_effect_counter + side_effect_counter += 1 + return self.v + 1 + + my_state = AsyncComputedVarState() + assert await my_state.async_computed_var == 2 + assert await my_state.async_computed_var == 2 + my_state.v = 2 + assert await my_state.async_computed_var == 3 + assert await my_state.async_computed_var == 3 + assert side_effect_counter == 2 From 3da64264ba56194f71fb22a7e6ba6c3c09e86b76 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 24 Jan 2025 11:46:37 -0800 Subject: [PATCH 02/12] Save the var from get_var_name --- reflex/state.py | 5 +++-- reflex/vars/base.py | 33 ++++++++++----------------------- tests/units/test_state.py | 26 +++++++++++++++++++------- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 6ef3ff3e8db..6705f8a5043 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -15,7 +15,6 @@ import typing import uuid from abc import ABC, abstractmethod -from collections import defaultdict from hashlib import md5 from pathlib import Path from types import FunctionType, MethodType @@ -3468,7 +3467,9 @@ async def _populate_substates( except ValueError: # The requested state is missing, so fetch and link it (and its parents). link_tasks.add( - asyncio.create_task(self._link_arbitrary_state(state, substate_cls)) + asyncio.create_task( + self._link_arbitrary_state(state, substate_cls) + ) ) for substate_name, substate_task in tasks.items(): diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 6bc5b25c41d..7d2e2af8574 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -1914,16 +1914,10 @@ def __init__( and all_var_data.state else None ) - var_name = ( - dep._js_expr[len(formatted_state_prefix) :] - if state_name - and ( - formatted_state_prefix := format_state_name(state_name) - + "." - ) - and dep._js_expr.startswith(formatted_state_prefix) - else dep._js_expr - ) + if all_var_data is not None: + var_name = all_var_data.field_name + else: + var_name = dep._js_expr _static_deps.setdefault(state_name, set()).add(var_name) elif isinstance(dep, str) and dep != "": _static_deps.setdefault(None, set()).add(dep) @@ -2214,28 +2208,22 @@ def _deps( start_column = getting_var[0].positions.col_offset end_line = getting_var[-1].positions.end_lineno end_column = getting_var[-1].positions.end_col_offset - source = inspect.getsource(inspect.getmodule(obj)).splitlines(True)[start_line - 1: end_line] + source = inspect.getsource(inspect.getmodule(obj)).splitlines(True)[ + start_line - 1 : end_line + ] if len(source) > 1: snipped_source = "".join( [ source[0][start_column:], source[1:-2] if len(source) > 2 else "", - source[-1][:end_column] + source[-1][:end_column], ] ) else: snipped_source = source[0][start_column:end_column] the_var = eval(f"({snipped_source})", obj.__globals__) - print(the_var) - # code = source[start_line - 1] - # bytecode = bytearray((dis.opmap["RESUME"], 0)) - # for ins in getting_var: - # bytecode.append(ins.opcode) - # bytecode.append(ins.arg or 0 & 0xFF) - # bytecode.extend((dis.opmap["RETURN_VALUE"], 0)) - # bc = dis.Bytecode(obj) - # code = bc.codeobj.replace(co_code=bytes(bytecode), co_argcount=0, co_nlocals=0, co_varnames=()) - # breakpoint() + the_var_data = the_var._get_all_var_data() + d.setdefault(the_var_data.state, set()).add(the_var_data.field_name) getting_var = False elif isinstance(getting_var, list): getting_var.append(instruction) @@ -2266,7 +2254,6 @@ def _deps( # Special case: arbitrary var access requested. getting_var = True continue - print(f"{self_on_top_of_stack=}") target_state = objclass.get_root_state().get_class_substate( self_on_top_of_stack ) diff --git a/tests/units/test_state.py b/tests/units/test_state.py index c5e2b12879b..10713b2d25b 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -1170,9 +1170,15 @@ def rendered_var(self) -> str: ms = MainState() # Initially there are no dirty computed vars. - assert ms._dirty_computed_vars(from_vars={"flag"}) == {(MainState.get_full_name(), "rendered_var")} - assert ms._dirty_computed_vars(from_vars={"t2"}) == {(MainState.get_full_name(), "rendered_var")} - assert ms._dirty_computed_vars(from_vars={"t1"}) == {(MainState.get_full_name(), "rendered_var")} + assert ms._dirty_computed_vars(from_vars={"flag"}) == { + (MainState.get_full_name(), "rendered_var") + } + assert ms._dirty_computed_vars(from_vars={"t2"}) == { + (MainState.get_full_name(), "rendered_var") + } + assert ms._dirty_computed_vars(from_vars={"t1"}) == { + (MainState.get_full_name(), "rendered_var") + } assert ms.computed_vars["rendered_var"]._deps(objclass=MainState) == { MainState.get_full_name(): {"flag", "t1", "t2"} } @@ -1369,7 +1375,10 @@ def cached_x_side_effect(self) -> int: assert isinstance(HandlerState.handler, EventHandler) s = HandlerState() - assert (HandlerState.get_full_name(), "cached_x_side_effect") in s._var_dependencies["x"] + assert ( + HandlerState.get_full_name(), + "cached_x_side_effect", + ) in s._var_dependencies["x"] assert s.cached_x_side_effect == 1 assert s.x == 43 s.handler() @@ -3221,7 +3230,9 @@ def foo(self) -> str: foo = RouterVarDepState.computed_vars["foo"] State._init_var_dependency_dicts() - assert foo._deps(objclass=RouterVarDepState) == {RouterVarDepState.get_full_name(): {"router"}} + assert foo._deps(objclass=RouterVarDepState) == { + RouterVarDepState.get_full_name(): {"router"} + } assert State._var_dependencies == { "router": {(RouterVarDepState.get_full_name(), "foo")} } @@ -3236,7 +3247,9 @@ def foo(self) -> str: state.parent_state = parent_state parent_state.substates = {RouterVarDepState.get_name(): state} - populated_substate_classes = await rx_state._recursively_populate_dependent_substates() + populated_substate_classes = ( + await rx_state._recursively_populate_dependent_substates() + ) assert populated_substate_classes == {State, RouterVarDepState} assert state.dirty_vars == set() @@ -3873,4 +3886,3 @@ async def v(self): assert await child.v == 2 root.parent_var = 2 assert await child.v == 3 - From ca3c0fd723cc898c5bd902fdea70039e2b32c1b9 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 28 Jan 2025 10:00:23 -0800 Subject: [PATCH 03/12] flatten StateManagerRedis.get_state algorithm simplify fetching of states and avoid repeatedly fetching the same state --- reflex/state.py | 392 +++++++++----------------------------- tests/units/test_state.py | 36 ++-- tests/units/test_var.py | 1 - 3 files changed, 111 insertions(+), 318 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 12c5534fb14..1f162bf5db3 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1465,65 +1465,6 @@ def _get_potentially_dirty_states(cls) -> set[type[BaseState]]: } ) - @classmethod - def _get_common_ancestor(cls, other: Type[BaseState]) -> str: - """Find the name of the nearest common ancestor shared by this and the other state. - - Args: - other: The other state. - - Returns: - Full name of the nearest common ancestor. - """ - common_ancestor_parts = [] - for part1, part2 in zip( - cls.get_full_name().split("."), - other.get_full_name().split("."), - ): - if part1 != part2: - break - common_ancestor_parts.append(part1) - return ".".join(common_ancestor_parts) - - @classmethod - def _determine_missing_parent_states( - cls, target_state_cls: Type[BaseState] - ) -> tuple[str, list[str]]: - """Determine the missing parent states between the target_state_cls and common ancestor of this state. - - Args: - target_state_cls: The class of the state to find missing parent states for. - - Returns: - The name of the common ancestor and the list of missing parent states. - """ - common_ancestor_name = cls._get_common_ancestor(target_state_cls) - common_ancestor_parts = common_ancestor_name.split(".") - target_state_parts = tuple(target_state_cls.get_full_name().split(".")) - relative_target_state_parts = target_state_parts[len(common_ancestor_parts) :] - - # Determine which parent states to fetch from the common ancestor down to the target_state_cls. - fetch_parent_states = [common_ancestor_name] - for relative_parent_state_name in relative_target_state_parts: - fetch_parent_states.append( - ".".join((fetch_parent_states[-1], relative_parent_state_name)) - ) - - return common_ancestor_name, fetch_parent_states[1:-1] - - def _get_parent_states(self) -> list[tuple[str, BaseState]]: - """Get all parent state instances up to the root of the state tree. - - Returns: - A list of tuples containing the name and the instance of each parent state. - """ - parent_states_with_name = [] - parent_state = self - while parent_state.parent_state is not None: - parent_state = parent_state.parent_state - parent_states_with_name.append((parent_state.get_full_name(), parent_state)) - return parent_states_with_name - def _get_root_state(self) -> BaseState: """Get the root state of the state tree. @@ -1555,9 +1496,10 @@ async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE: f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. " "(All states should already be available -- this is likely a bug).", ) - state_in_redis = await state_manager._link_arbitrary_state( - self, - state_cls, + state_in_redis = await state_manager.get_state( + token=_substate_key(self.router.session.client_token, state_cls), + top_level=False, + for_state_instance=self, ) if not isinstance(state_in_redis, state_cls): @@ -1944,54 +1886,6 @@ def _dirty_computed_vars( if include_backend or not self.computed_vars[cvar]._backend } - async def _recursively_populate_dependent_substates( - self, - seen_classes: set[type[BaseState]] | None = None, - ) -> set[type[BaseState]]: - """Fetch all substates that have computed var dependencies on this state. - - Args: - seen_classes: set of classes that have already been seen to prevent infinite recursion. - - Returns: - The set of classes that were processed (mostly for testability). - """ - if seen_classes is None: - print( - f"\n\nTop-level _recursively_populate_dependent_substates from {type(self)}:" - ) - seen_classes = set() - if type(self) in seen_classes: - return seen_classes - seen_classes.add(type(self)) - populated_substate_instances = {} - for substate_cls in { - self.get_class_substate((self.get_name(), *substate_name.split("."))) - for substate_name in self._always_dirty_substates - }: - # _always_dirty_substates need to be fetched to recalc computed vars. - if substate_cls not in populated_substate_instances: - print(f"fetching always dirty {substate_cls}") - populated_substate_instances[substate_cls] = await self.get_state( - substate_cls - ) - for dep_set in self._var_dependencies.values(): - for substate_name, _ in dep_set: - if substate_name == self.get_full_name(): - # Do NOT fetch our own state instance. - continue - substate_cls = self.get_root_state().get_class_substate(substate_name) - if substate_cls not in populated_substate_instances: - print(f"fetching dependent {substate_cls}") - populated_substate_instances[substate_cls] = await self.get_state( - substate_cls - ) - for substate in populated_substate_instances.values(): - await substate._recursively_populate_dependent_substates( - seen_classes=seen_classes, - ) - return seen_classes - def get_delta(self) -> Delta: """Get the delta for the state. @@ -3316,179 +3210,74 @@ class StateManagerRedis(StateManager): b"evicted", } - async def _get_parent_state( - self, token: str, state: BaseState | None = None - ) -> BaseState | None: - """Get the parent state for the state requested in the token. - - Args: - token: The token to get the state for (_substate_key). - state: The state instance to get parent state for. - - Returns: - The parent state for the state requested by the token or None if there is no such parent. - """ - parent_state = None - client_token, state_path = _split_substate_key(token) - parent_state_name = state_path.rpartition(".")[0] - if parent_state_name: - cached_substates = None - if state is not None: - cached_substates = [state] - # Retrieve the parent state to populate event handlers onto this substate. - parent_state = await self.get_state( - token=_substate_key(client_token, parent_state_name), - top_level=False, - get_substates=False, - cached_substates=cached_substates, + def _get_required_state_classes( + self, + target_state_cls: Type[BaseState], + subclasses: bool = False, + required_state_classes: set[Type[BaseState]] | None = None, + ) -> set[Type[BaseState]]: + if required_state_classes is None: + required_state_classes = set() + # Get the substates if requested. + if subclasses: + for substate in target_state_cls.get_substates(): + self._get_required_state_classes( + substate, + subclasses=True, + required_state_classes=required_state_classes, + ) + if target_state_cls in required_state_classes: + return required_state_classes + required_state_classes.add(target_state_cls) + + # Get dependent substates. + for pd_substates in target_state_cls._get_potentially_dirty_states(): + self._get_required_state_classes( + pd_substates, + subclasses=False, + required_state_classes=required_state_classes, ) - return parent_state - - async def _populate_parent_states( - self, calling_state: BaseState, target_state_cls: Type[BaseState] - ): - """Populate substates in the tree between the target_state_cls and common ancestor of calling_state. - Args: - calling_state: The substate instance requesting subtree population. - target_state_cls: The class of the state to populate parent states for. - - Returns: - The parent state instance of target_state_cls. - """ - # Find the missing parent states up to the common ancestor. - ( - common_ancestor_name, - missing_parent_states, - ) = calling_state._determine_missing_parent_states(target_state_cls) - - # Fetch all missing parent states and link them up to the common ancestor. - parent_states_tuple = calling_state._get_parent_states() - root_state = parent_states_tuple[-1][1] - parent_states_by_name = dict(parent_states_tuple) - parent_state = parent_states_by_name[common_ancestor_name] - for parent_state_name in missing_parent_states: - try: - parent_state = root_state.get_substate(parent_state_name.split(".")) - # The requested state is already cached, do NOT fetch it again. - continue - except ValueError: - # The requested state is missing, fetch from redis. - pass - parent_state = await self.get_state( - token=_substate_key( - calling_state.router.session.client_token, parent_state_name - ), - top_level=False, - get_substates=False, - parent_state=parent_state, + # Get the parent state if it exists. + if parent_state := target_state_cls.get_parent_state(): + self._get_required_state_classes( + parent_state, + subclasses=False, + required_state_classes=required_state_classes, ) + return required_state_classes - # Return the direct parent of target_state_cls for subsequent linking. - return parent_state - - async def _link_arbitrary_state( - self, calling_state: BaseState, state_cls: Type[T_STATE] - ) -> T_STATE: - """Get a state instance from redis. - - Args: - calling_state: The state instance requesting the newly linked instance of state_cls. - state_cls: The class of the state to link into the tree. - - Returns: - The instance of state_cls associated with calling_state's client_token. - - Raises: - StateMismatchError: If the state instance is not of the expected type. - """ - # Fetch all missing parent states from redis. - parent_state_of_state_cls = await self._populate_parent_states( - calling_state, state_cls - ) - - # Then get the target state and all its substates. - state_in_redis = await self.get_state( - token=_substate_key(calling_state.router.session.client_token, state_cls), - top_level=False, - get_substates=True, - parent_state=parent_state_of_state_cls, - ) - - return state_in_redis - - async def _populate_substates( + def _get_populated_states( self, - token: str, - state: BaseState, - all_substates: bool = False, - ): - """Fetch and link substates for the given state instance. - - There is no return value; the side-effect is that `state` will have `substates` populated, - and each substate will have its `parent_state` set to `state`. - - Args: - token: The token to get the state for. - state: The state instance to populate substates for. - all_substates: Whether to fetch all substates or just required substates. - """ - client_token, _ = _split_substate_key(token) - - # Only _potentially_dirty_substates need to be fetched to recalc computed vars. - fetch_substates = state._get_potentially_dirty_states() - if all_substates: - # All substates are requested. - fetch_substates.update(state.get_substates()) - - tasks = {} - link_tasks = set() - # Retrieve the necessary substates from redis. - for substate_cls in fetch_substates: - if substate_cls.get_name() in state.substates: - continue - substate_name = substate_cls.get_name() - if substate_cls in state.get_substates(): - tasks[substate_name] = asyncio.create_task( - self.get_state( - token=_substate_key(client_token, substate_cls), - top_level=False, - get_substates=all_substates, - parent_state=state, - ) - ) - else: - try: - state._get_root_state().get_substate(substate_name.split(".")) - except ValueError: - # The requested state is missing, so fetch and link it (and its parents). - link_tasks.add( - asyncio.create_task( - self._link_arbitrary_state(state, substate_cls) - ) - ) - - for substate_name, substate_task in tasks.items(): - state.substates[substate_name] = await substate_task - await asyncio.gather(*link_tasks) + target_state: BaseState, + populated_states: dict[str, BaseState] | None = None, + ) -> dict[str, BaseState]: + if populated_states is None: + populated_states = {} + if target_state.get_full_name() in populated_states: + return populated_states + populated_states[target_state.get_full_name()] = target_state + for substate in target_state.substates.values(): + self._get_populated_states(substate, populated_states=populated_states) + if target_state.parent_state is not None: + self._get_populated_states( + target_state.parent_state, populated_states=populated_states + ) + return populated_states @override async def get_state( self, token: str, top_level: bool = True, - get_substates: bool = True, - parent_state: BaseState | None = None, - cached_substates: list[BaseState] | None = None, + for_state_instance: BaseState | None = None, ) -> BaseState: """Get the state for a token. Args: token: The token to get the state for. top_level: If true, return an instance of the top-level state (self.state). - get_substates: If true, also retrieve substates. - parent_state: If provided, use this parent_state instead of getting it from redis. - cached_substates: If provided, attach these substates to the state. + for_state_instance: If provided, attach the requested states to this existing state tree. Returns: The state for the token. @@ -3497,7 +3286,7 @@ async def get_state( RuntimeError: when the state_cls is not specified in the token """ # Split the actual token from the fully qualified substate name. - _, state_path = _split_substate_key(token) + token, state_path = _split_substate_key(token) if state_path: # Get the State class associated with the given path. state_cls = self.state.get_class_substate(state_path) @@ -3506,37 +3295,44 @@ async def get_state( f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}" ) - # The deserialized or newly created (sub)state instance. - state = None - - # Fetch the serialized substate from redis. - redis_state = await self.redis.get(token) - - if redis_state is not None: - # Deserialize the substate. - with contextlib.suppress(StateSchemaMismatchError): - state = BaseState._deserialize(data=redis_state) - if state is None: - # Key didn't exist or schema mismatch so create a new instance for this token. - state = state_cls( - init_substates=False, - _reflex_internal_init=True, - ) - # Populate parent state if missing and requested. - if parent_state is None: - parent_state = await self._get_parent_state(token, state) - # Set up Bidirectional linkage between this state and its parent. - if parent_state is not None: - parent_state.substates[state.get_name()] = state - state.parent_state = parent_state - # Avoid fetching substates multiple times. - if cached_substates: - for substate in cached_substates: - state.substates[substate.get_name()] = substate - if substate.parent_state is None: - substate.parent_state = state - # Populate substates if requested. - await self._populate_substates(token, state, all_substates=get_substates) + # Determine which states we already have. + flat_state_tree: dict[str, BaseState] = ( + self._get_populated_states(for_state_instance) if for_state_instance else {} + ) + + # Determine which states from the tree need to be fetched. + required_state_classes = self._get_required_state_classes( + state_cls, subclasses=True + ) - {type(s) for s in flat_state_tree.values()} + + for state_cls in sorted( + required_state_classes, key=lambda x: x.get_full_name() + ): + state = None + redis_state = await self.redis.get(_substate_key(token, state_cls)) + + if redis_state is not None: + # Deserialize the substate. + with contextlib.suppress(StateSchemaMismatchError): + state = BaseState._deserialize(data=redis_state) + if state is None: + # Key didn't exist or schema mismatch so create a new instance for this token. + state = state_cls( + init_substates=False, + _reflex_internal_init=True, + ) + flat_state_tree[state.get_full_name()] = state + if state.get_parent_state() is not None: + parent_state_name, _dot, state_name = state.get_full_name().rpartition( + "." + ) + parent_state = flat_state_tree.get(parent_state_name) + if parent_state is None: + raise Exception( + f"Parent state should get fetched first... got {state.get_full_name()} instead" + ) + parent_state.substates[state_name] = state + state.parent_state = parent_state # To retain compatibility with previous implementation, by default, we return # the top-level state by chasing `parent_state` pointers up the tree. diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 6086799e115..0d9d438eaa1 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -3212,8 +3212,13 @@ def bar(self) -> str: @pytest.mark.asyncio -async def test_router_var_dep() -> None: - """Test that router var dependencies are correctly tracked.""" +async def test_router_var_dep(state_manager: StateManager, token: str) -> None: + """Test that router var dependencies are correctly tracked. + + Args: + state_manager: A state manager. + token: A token. + """ class RouterVarParentState(State): """A parent state for testing router var dependency.""" @@ -3233,24 +3238,17 @@ def foo(self) -> str: assert foo._deps(objclass=RouterVarDepState) == { RouterVarDepState.get_full_name(): {"router"} } - assert State._var_dependencies == { - "router": {(RouterVarDepState.get_full_name(), "foo")} - } - - rx_state = State() - parent_state = RouterVarParentState() - state = RouterVarDepState() - - # link states - rx_state.substates = {RouterVarParentState.get_name(): parent_state} - parent_state.parent_state = rx_state - state.parent_state = parent_state - parent_state.substates = {RouterVarDepState.get_name(): state} + assert (RouterVarDepState.get_full_name(), "foo") in State._var_dependencies[ + "router" + ] - populated_substate_classes = ( - await rx_state._recursively_populate_dependent_substates() - ) - assert populated_substate_classes == {State, RouterVarDepState} + # Get state from state manager. + state_manager.state = State + rx_state = await state_manager.get_state(_substate_key(token, State)) + assert RouterVarParentState.get_name() in rx_state.substates + parent_state = rx_state.substates[RouterVarParentState.get_name()] + assert RouterVarDepState.get_name() in parent_state.substates + state = parent_state.substates[RouterVarDepState.get_name()] assert state.dirty_vars == set() diff --git a/tests/units/test_var.py b/tests/units/test_var.py index ab396b15e45..30fbd4e9b6d 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -18,7 +18,6 @@ from reflex.utils.imports import ImportVar from reflex.vars import VarData from reflex.vars.base import ( - AsyncComputedVar, ComputedVar, LiteralVar, Var, From 5143d74dd0129ad59e00f2f31af7611c02ebfb5d Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 28 Jan 2025 10:23:24 -0800 Subject: [PATCH 04/12] Get all the states in a single redis round-trip --- reflex/state.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 1f162bf5db3..046b7b42689 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -3301,15 +3301,21 @@ async def get_state( ) # Determine which states from the tree need to be fetched. - required_state_classes = self._get_required_state_classes( - state_cls, subclasses=True - ) - {type(s) for s in flat_state_tree.values()} + required_state_classes = sorted( + self._get_required_state_classes(state_cls, subclasses=True) + - {type(s) for s in flat_state_tree.values()}, + key=lambda x: x.get_full_name(), + ) + + redis_pipeline = self.redis.pipeline() + for state_cls in required_state_classes: + redis_pipeline.get(_substate_key(token, state_cls)) - for state_cls in sorted( - required_state_classes, key=lambda x: x.get_full_name() + for state_cls, redis_state in zip( + required_state_classes, + await redis_pipeline.execute(), ): state = None - redis_state = await self.redis.get(_substate_key(token, state_cls)) if redis_state is not None: # Deserialize the substate. From 8f4b12d29f76c0464f7c397d25baf56612ef9b4f Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 28 Jan 2025 10:30:31 -0800 Subject: [PATCH 05/12] update docstrings in StateManagerRedis --- reflex/state.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 046b7b42689..afa9dedb4b8 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -3216,6 +3216,19 @@ def _get_required_state_classes( subclasses: bool = False, required_state_classes: set[Type[BaseState]] | None = None, ) -> set[Type[BaseState]]: + """Recursively determine which states are required to fetch the target state. + + This will always include potentially dirty substates that depend on vars + in the target_state_cls. + + Args: + target_state_cls: The target state class being fetched. + subclasses: Whether to include subclasses of the target state. + required_state_classes: Recursive argument tracking state classes that have already been seen. + + Returns: + The set of state classes required to fetch the target state. + """ if required_state_classes is None: required_state_classes = set() # Get the substates if requested. @@ -3252,6 +3265,15 @@ def _get_populated_states( target_state: BaseState, populated_states: dict[str, BaseState] | None = None, ) -> dict[str, BaseState]: + """Recursively determine which states from target_state are already fetched. + + Args: + target_state: The state to check for populated states. + populated_states: Recursive argument tracking states seen in previous calls. + + Returns: + A dictionary of state full name to state instance. + """ if populated_states is None: populated_states = {} if target_state.get_full_name() in populated_states: @@ -3283,7 +3305,8 @@ async def get_state( The state for the token. Raises: - RuntimeError: when the state_cls is not specified in the token + RuntimeError: when the state_cls is not specified in the token, or when the parent state for a + requested state was not fetched. """ # Split the actual token from the fully qualified substate name. token, state_path = _split_substate_key(token) @@ -3334,8 +3357,10 @@ async def get_state( ) parent_state = flat_state_tree.get(parent_state_name) if parent_state is None: - raise Exception( - f"Parent state should get fetched first... got {state.get_full_name()} instead" + raise RuntimeError( + f"Parent state for {state.get_full_name()} was not found " + "in the state tree, but should have already been fetched. " + "This is a bug", ) parent_state.substates[state_name] = state state.parent_state = parent_state From aa5b3037d7ed47a0906153f29547e2ded8a6fc80 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 28 Jan 2025 15:30:27 -0800 Subject: [PATCH 06/12] Move computed var dep tracking to separate module --- reflex/state.py | 16 ++- reflex/vars/base.py | 243 +++++++++++---------------------- reflex/vars/dep_tracking.py | 265 ++++++++++++++++++++++++++++++++++++ 3 files changed, 357 insertions(+), 167 deletions(-) create mode 100644 reflex/vars/dep_tracking.py diff --git a/reflex/state.py b/reflex/state.py index afa9dedb4b8..e2bdb79f5f9 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -797,7 +797,9 @@ def _init_var_dependency_dicts(cls): *defining_state_cls.inherited_vars, *defining_state_cls.inherited_backend_vars, }: - defining_state_cls = defining_state_cls.get_parent_state() + parent_state = defining_state_cls.get_parent_state() + if parent_state is not None: + defining_state_cls = parent_state defining_state_cls._var_dependencies.setdefault(dvar, set()).add( (cls.get_full_name(), cvar_name) ) @@ -2721,7 +2723,7 @@ async def get_state(self, state_cls: Type[BaseState]) -> BaseState: await self.__wrapped__.get_state(state_cls), parent_state_proxy=self ) - def _as_state_update(self, *args, **kwargs) -> StateUpdate: + async def _as_state_update(self, *args, **kwargs) -> StateUpdate: """Temporarily allow mutability to access parent_state. Args: @@ -2734,7 +2736,7 @@ def _as_state_update(self, *args, **kwargs) -> StateUpdate: original_mutable = self._self_mutable self._self_mutable = True try: - return self.__wrapped__._as_state_update(*args, **kwargs) + return await self.__wrapped__._as_state_update(*args, **kwargs) finally: self._self_mutable = original_mutable @@ -3366,10 +3368,10 @@ async def get_state( state.parent_state = parent_state # To retain compatibility with previous implementation, by default, we return - # the top-level state by chasing `parent_state` pointers up the tree. + # the top-level state which should always be fetched or already cached. if top_level: - return state._get_root_state() - return state + return flat_state_tree[self.state.get_full_name()] + return flat_state_tree[state_cls.get_full_name()] @override async def set_state( @@ -4070,7 +4072,7 @@ def reload_state_module( if subclass.__module__ == module and module is not None: state.class_subclasses.remove(subclass) state._always_dirty_substates.discard(subclass.get_name()) - state._potentially_dirty_substates.discard(subclass.get_name()) + state._potentially_dirty_states.discard(subclass.get_name()) state._var_dependencies = {} state._init_var_dependency_dicts() state.get_class_substate.cache_clear() diff --git a/reflex/vars/base.py b/reflex/vars/base.py index ada70aa6dd2..3f5e1ac86d7 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -5,7 +5,6 @@ import contextlib import dataclasses import datetime -import dis import functools import inspect import json @@ -20,6 +19,7 @@ Any, Callable, ClassVar, + Coroutine, Dict, FrozenSet, Generic, @@ -50,7 +50,6 @@ VarAttributeError, VarDependencyError, VarTypeError, - VarValueError, ) from reflex.utils.format import format_state_name from reflex.utils.imports import ( @@ -2073,9 +2072,8 @@ def _check_deprecated_return_type(self, instance, value) -> None: def _deps( self, - objclass: BaseState, + objclass: Type[BaseState], obj: FunctionType | CodeType | None = None, - self_names: Optional[dict[str, str]] = None, ) -> dict[str, set[str]]: """Determine var dependencies of this ComputedVar. @@ -2087,17 +2085,12 @@ def _deps( Args: objclass: the class obj this ComputedVar is attached to. obj: the object to disassemble (defaults to the fget function). - self_names: if specified, look for these names in LOAD_FAST and LOAD_DEREF instructions. Returns: A dictionary mapping state names to the set of variable names accessed by the given obj. - - Raises: - VarValueError: if the function references the get_state, parent_state, or substates attributes - (cannot track deps in a related state, only implicitly via parent state). """ - from reflex.state import BaseState + from .dep_tracking import DependencyTracker d = {} if self._static_deps: @@ -2108,158 +2101,26 @@ def _deps( if not self._auto_deps: return d + if obj is None: fget = self._fget if fget is not None: obj = cast(FunctionType, fget) else: return d - with contextlib.suppress(AttributeError): - # unbox functools.partial - obj = cast(FunctionType, obj.func) # type: ignore - with contextlib.suppress(AttributeError): - # unbox EventHandler - obj = cast(FunctionType, obj.fn) # type: ignore - if self_names is None and isinstance(obj, FunctionType): - try: - # the first argument to the function is the name of "self" arg - self_names = {obj.__code__.co_varnames[0]: objclass.get_full_name()} - except (AttributeError, IndexError): - self_names = None - if self_names is None: - # cannot reference attributes on self if method takes no args + try: + return DependencyTracker( + func=obj, state_cls=objclass, dependencies=d + ).dependencies + except Exception as e: + console.warn( + "Failed to automatically determine dependencies for computed var " + f"{objclass.__name__}.{self._js_expr}: {e}. " + "Provide static_deps and set auto_deps=False to suppress this warning." + ) return d - invalid_names = ["parent_state", "substates", "get_substate"] - self_on_top_of_stack = None - getting_state = False - getting_var = False - for instruction in dis.get_instructions(obj): - if getting_state: - if instruction.opname == "LOAD_FAST": - raise VarValueError( - f"Dependency detection cannot identify get_state class from local var {instruction.argval}." - ) - if instruction.opname == "LOAD_GLOBAL": - # Special case: referencing state class from global scope. - getting_state = obj.__globals__.get(instruction.argval) - elif instruction.opname == "LOAD_DEREF": - # Special case: referencing state class from closure. - closure = dict(zip(obj.__code__.co_freevars, obj.__closure__)) - try: - getting_state = closure[instruction.argval].cell_contents - except ValueError as ve: - raise VarValueError( - f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?." - ) from ve - elif instruction.opname == "STORE_FAST": - # Storing the result of get_state in a local variable. - if not isinstance(getting_state, type) or not issubclass( - getting_state, BaseState - ): - raise VarValueError( - f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`." - ) - self_names[instruction.argval] = getting_state.get_full_name() - getting_state = False - continue # nothing else happens until we have identified the local var - if getting_var: - if instruction.opname == "CALL": - # get the original source code and eval it - start_line = getting_var[0].positions.lineno - start_column = getting_var[0].positions.col_offset - end_line = getting_var[-1].positions.end_lineno - end_column = getting_var[-1].positions.end_col_offset - source = inspect.getsource(inspect.getmodule(obj)).splitlines(True)[ - start_line - 1 : end_line - ] - if len(source) > 1: - snipped_source = "".join( - [ - source[0][start_column:], - source[1:-2] if len(source) > 2 else "", - source[-1][:end_column], - ] - ) - else: - snipped_source = source[0][start_column:end_column] - the_var = eval(f"({snipped_source})", obj.__globals__) - the_var_data = the_var._get_all_var_data() - d.setdefault(the_var_data.state, set()).add(the_var_data.field_name) - getting_var = False - elif isinstance(getting_var, list): - getting_var.append(instruction) - else: - getting_var = [instruction] - continue - if ( - instruction.opname in ("LOAD_FAST", "LOAD_DEREF") - and instruction.argval in self_names - ): - # bytecode loaded the class instance to the top of stack, next load instruction - # is referencing an attribute on self - self_on_top_of_stack = self_names[instruction.argval] - continue - if self_on_top_of_stack and instruction.opname in ( - "LOAD_ATTR", - "LOAD_METHOD", - ): - if instruction.argval in invalid_names: - raise VarValueError( - f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`." - ) - if instruction.argval == "get_state": - # Special case: arbitrary state access requested. - getting_state = True - continue - if instruction.argval == "get_var_value": - # Special case: arbitrary var access requested. - getting_var = True - continue - target_state = objclass.get_root_state().get_class_substate( - self_on_top_of_stack - ) - try: - ref_obj = getattr(target_state, instruction.argval) - except Exception: - ref_obj = None - if callable(ref_obj): - # recurse into callable attributes - for state_name, dep_name in self._deps( - objclass=target_state, - obj=ref_obj, - ).items(): - d.setdefault(state_name, set()).update(dep_name) - # recurse into property fget functions - elif isinstance(ref_obj, property) and not isinstance( - ref_obj, ComputedVar - ): - for state_name, dep_name in self._deps( - objclass=target_state, - obj=ref_obj.fget, # type: ignore - ).items(): - d.setdefault(state_name, set()).update(dep_name) - elif ( - instruction.argval in target_state.backend_vars - or instruction.argval in target_state.vars - ): - # var access - d.setdefault(self_on_top_of_stack, set()).add(instruction.argval) - elif instruction.opname == "LOAD_CONST" and isinstance( - instruction.argval, CodeType - ): - # recurse into nested functions / comprehensions, which can reference - # instance attributes from the outer scope - for state_name, dep_name in self._deps( - objclass=objclass, - obj=instruction.argval, - self_names=self_names, - ).items(): - d.setdefault(state_name, set()).update(dep_name) - self_on_top_of_stack = None - return d - def mark_dirty(self, instance) -> None: """Mark this ComputedVar as dirty. @@ -2314,11 +2175,63 @@ class DynamicRouteVar(ComputedVar[Union[str, List[str]]]): class AsyncComputedVar(ComputedVar[RETURN_TYPE]): """A computed var that wraps a coroutinefunction.""" - _fget: Callable[[BaseState], RETURN_TYPE] = dataclasses.field( - default_factory=lambda: lambda _: None - ) # type: ignore + _fget: Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]] = ( + dataclasses.field() + ) - def __get__(self, instance: BaseState | None, owner): + @overload + def __get__( + self: AsyncComputedVar[int] | AsyncComputedVar[float], + instance: None, + owner: Type, + ) -> NumberVar: ... + + @overload + def __get__( + self: AsyncComputedVar[str], + instance: None, + owner: Type, + ) -> StringVar: ... + + @overload + def __get__( + self: AsyncComputedVar[Mapping[DICT_KEY, DICT_VAL]], + instance: None, + owner: Type, + ) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ... + + @overload + def __get__( + self: AsyncComputedVar[list[LIST_INSIDE]], + instance: None, + owner: Type, + ) -> ArrayVar[list[LIST_INSIDE]]: ... + + @overload + def __get__( + self: AsyncComputedVar[set[LIST_INSIDE]], + instance: None, + owner: Type, + ) -> ArrayVar[set[LIST_INSIDE]]: ... + + @overload + def __get__( + self: AsyncComputedVar[tuple[LIST_INSIDE, ...]], + instance: None, + owner: Type, + ) -> ArrayVar[tuple[LIST_INSIDE, ...]]: ... + + @overload + def __get__(self, instance: None, owner: Type) -> AsyncComputedVar[RETURN_TYPE]: ... + + @overload + def __get__( + self, instance: BaseState, owner: Type + ) -> Coroutine[None, None, RETURN_TYPE]: ... + + def __get__( + self, instance: BaseState | None, owner + ) -> Var | Coroutine[None, None, RETURN_TYPE]: """Get the ComputedVar value. If the value is already cached on the instance, return the cached value. @@ -2335,14 +2248,15 @@ def __get__(self, instance: BaseState | None, owner): if not self._cache: - async def _awaitable_result(): + async def _awaitable_result(instance=instance) -> RETURN_TYPE: value = await self.fget(instance) self._check_deprecated_return_type(instance, value) + return value return _awaitable_result() else: # handle caching - async def _awaitable_result(): + async def _awaitable_result(instance=instance) -> RETURN_TYPE: if not hasattr(instance, self._cache_attr) or self.needs_update( instance ): @@ -2356,7 +2270,16 @@ async def _awaitable_result(): self._check_deprecated_return_type(instance, value) return value - return _awaitable_result() + return _awaitable_result() + + @property + def fget(self) -> Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]]: + """Get the getter function. + + Returns: + The getter function. + """ + return self._fget if TYPE_CHECKING: diff --git a/reflex/vars/dep_tracking.py b/reflex/vars/dep_tracking.py new file mode 100644 index 00000000000..3bd8a6bf042 --- /dev/null +++ b/reflex/vars/dep_tracking.py @@ -0,0 +1,265 @@ +"""Collection of base classes.""" + +from __future__ import annotations + +import contextlib +import dataclasses +import dis +import enum +import inspect +from types import CodeType, FunctionType +from typing import TYPE_CHECKING, ClassVar, Type, cast + +from reflex.utils.exceptions import VarValueError + +if TYPE_CHECKING: + from reflex.state import BaseState + + from .base import Var + + +class ScanStatus(enum.Enum): + """State of the dis instruction scanning loop.""" + + SCANNING = enum.auto() + GETTING_ATTR = enum.auto() + GETTING_STATE = enum.auto() + GETTING_VAR = enum.auto() + + +@dataclasses.dataclass +class DependencyTracker: + """State machine for identifying state attributes that are accessed by a function.""" + + func: FunctionType | CodeType = dataclasses.field() + state_cls: Type[BaseState] = dataclasses.field() + + dependencies: dict[str, set[str]] = dataclasses.field(default_factory=dict) + + scan_status: ScanStatus = dataclasses.field(default=ScanStatus.SCANNING) + top_of_stack: str | None = dataclasses.field(default=None) + + tracked_locals: dict[str, Type[BaseState]] = dataclasses.field(default_factory=dict) + + _getting_state_class: Type[BaseState] | None = dataclasses.field(default=None) + _getting_var_instructions: list[dis.Instruction] = dataclasses.field( + default_factory=list + ) + + INVALID_NAMES: ClassVar[list[str]] = ["parent_state", "substates", "get_substate"] + + def __post_init__(self): + """After initializing, populate the dependencies dict.""" + with contextlib.suppress(AttributeError): + # unbox functools.partial + self.func = cast(FunctionType, self.func.func) # type: ignore + with contextlib.suppress(AttributeError): + # unbox EventHandler + self.func = cast(FunctionType, self.func.fn) # type: ignore + + if isinstance(self.func, FunctionType): + with contextlib.suppress(AttributeError, IndexError): + # the first argument to the function is the name of "self" arg + self.tracked_locals[self.func.__code__.co_varnames[0]] = self.state_cls + + self._populate_dependencies() + + def _merge_deps(self, tracker: DependencyTracker) -> None: + """Merge dependencies from another DependencyTracker. + + Args: + tracker: The DependencyTracker to merge dependencies from. + """ + for state_name, dep_name in tracker.dependencies.items(): + self.dependencies.setdefault(state_name, set()).update(dep_name) + + def load_attr_or_method(self, instruction: dis.Instruction) -> None: + """Handle loading an attribute or method from the object on top of the stack. + + This method directly tracks attributes and recursively merges + dependencies from analyzing the dependencies of any methods called. + + Args: + instruction: The dis instruction to process. + """ + from .base import ComputedVar + + if instruction.argval in self.INVALID_NAMES: + raise VarValueError( + f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`." + ) + if instruction.argval == "get_state": + # Special case: arbitrary state access requested. + self.scan_status = ScanStatus.GETTING_STATE + return + if instruction.argval == "get_var_value": + # Special case: arbitrary var access requested. + self.scan_status = ScanStatus.GETTING_VAR + return + + # Reset status back to SCANNING after attribute is accessed. + self.scan_status = ScanStatus.SCANNING + if not self.top_of_stack: + return + target_state = self.tracked_locals[self.top_of_stack] + ref_obj = getattr(target_state, instruction.argval) + + if isinstance(ref_obj, property) and not isinstance(ref_obj, ComputedVar): + # recurse into property fget functions + ref_obj = ref_obj.fget # type: ignore + if callable(ref_obj): + # recurse into callable attributes + self._merge_deps( + type(self)(func=cast(FunctionType, ref_obj), state_cls=target_state) + ) + elif ( + instruction.argval in target_state.backend_vars + or instruction.argval in target_state.vars + ): + # var access + self.dependencies.setdefault(target_state.get_full_name(), set()).add( + instruction.argval + ) + + def handle_getting_state(self, instruction: dis.Instruction) -> None: + """Handle bytecode analysis when `get_state` was called in the function. + + If the wrapped function is getting an arbitrary state and saving it to a + local variable, this method associates the local variable name with the + state class in self.tracked_locals. + + When an attribute/method is accessed on a tracked local, it will be + associated with this state. + + Args: + instruction: The dis instruction to process. + """ + from reflex.state import BaseState + + if instruction.opname == "LOAD_FAST": + raise VarValueError( + f"Dependency detection cannot identify get_state class from local var {instruction.argval}." + ) + if instruction.opname == "LOAD_GLOBAL": + # Special case: referencing state class from global scope. + self._getting_state_class = self.func.__globals__.get(instruction.argval) + elif instruction.opname == "LOAD_DEREF": + # Special case: referencing state class from closure. + closure = dict(zip(self.func.__code__.co_freevars, self.func.__closure__)) + try: + self._getting_state_class = closure[instruction.argval].cell_contents + except ValueError as ve: + raise VarValueError( + f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?." + ) from ve + elif instruction.opname == "STORE_FAST": + # Storing the result of get_state in a local variable. + if not isinstance(self._getting_state_class, type) or not issubclass( + self._getting_state_class, BaseState + ): + raise VarValueError( + f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`." + ) + self.tracked_locals[instruction.argval] = self._getting_state_class + self.scan_status = ScanStatus.SCANNING + self._getting_state_class = None + + def _eval_var(self) -> Var: + """Evaluate instructions from the wrapped function to get the Var object. + + Returns: + The Var object. + """ + # Get the original source code and eval it to get the Var. + start_line = self._getting_var_instructions[0].positions.lineno + start_column = self._getting_var_instructions[0].positions.col_offset + end_line = self._getting_var_instructions[-1].positions.end_lineno + end_column = self._getting_var_instructions[-1].positions.end_col_offset + source = inspect.getsource(inspect.getmodule(self.func)).splitlines(True)[ + start_line - 1 : end_line + ] + # Create a python source string snippet. + if len(source) > 1: + snipped_source = "".join( + [ + source[0][start_column:], + source[1:-2] if len(source) > 2 else "", + source[-1][:end_column], + ] + ) + else: + snipped_source = source[0][start_column:end_column] + try: + closure = dict(zip(self.func.__code__.co_freevars, self.func.__closure__)) + except Exception: + # Fallback if the closure is not available. + closure = {} + # Evaluate the string in the context of the function's globals and closure. + return eval(f"({snipped_source})", self.func.__globals__, closure) + + def handle_getting_var(self, instruction: dis.Instruction) -> None: + """Handle bytecode analysis when `get_var_value` was called in the function. + + This only really works if the expression passed to `get_var_value` is + evaluable in the function's global scope or closure, so getting the var + value from a var saved in a local variable or in the class instance is + not possible. + + Args: + instruction: The dis instruction to process. + """ + if instruction.opname == "CALL" and self._getting_var_instructions: + if self._getting_var_instructions: + the_var = self._eval_var() + the_var_data = the_var._get_all_var_data() + self.dependencies.setdefault(the_var_data.state, set()).add( + the_var_data.field_name + ) + self._getting_var_instructions.clear() + self.scan_status = ScanStatus.SCANNING + else: + self._getting_var_instructions.append(instruction) + + def _populate_dependencies(self) -> None: + """Update self.dependencies based on the disassembly of self.func. + + Save references to attributes accessed on "self" or other fetched states. + + Recursively called when the function makes a method call on "self" or + define comprehensions or nested functions that may reference "self". + + Raises: + VarValueError: if the function references the get_state, parent_state, or substates attributes + (cannot track deps in a related state, only implicitly via parent state). + """ + for instruction in dis.get_instructions(self.func): + if self.scan_status == ScanStatus.GETTING_STATE: + self.handle_getting_state(instruction) + elif self.scan_status == ScanStatus.GETTING_VAR: + self.handle_getting_var(instruction) + elif ( + instruction.opname in ("LOAD_FAST", "LOAD_DEREF") + and instruction.argval in self.tracked_locals + ): + # bytecode loaded the class instance to the top of stack, next load instruction + # is referencing an attribute on self + self.top_of_stack = instruction.argval + self.scan_status = ScanStatus.GETTING_ATTR + elif self.scan_status == ScanStatus.GETTING_ATTR and instruction.opname in ( + "LOAD_ATTR", + "LOAD_METHOD", + ): + self.load_attr_or_method(instruction) + self.top_of_stack = None + elif instruction.opname == "LOAD_CONST" and isinstance( + instruction.argval, CodeType + ): + # recurse into nested functions / comprehensions, which can reference + # instance attributes from the outer scope + self._merge_deps( + type(self)( + func=instruction.argval, + state_cls=self.state_cls, + tracked_locals=self.tracked_locals, + ) + ) From e74e913b4c215dc6322ff447156b1b2926cb5fbb Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 28 Jan 2025 23:54:11 -0800 Subject: [PATCH 07/12] Fix pre-commit issues --- reflex/vars/dep_tracking.py | 83 +++++++++++++++++++++++++++---------- tests/units/test_state.py | 11 +++-- 2 files changed, 67 insertions(+), 27 deletions(-) diff --git a/reflex/vars/dep_tracking.py b/reflex/vars/dep_tracking.py index 3bd8a6bf042..387de5a6645 100644 --- a/reflex/vars/dep_tracking.py +++ b/reflex/vars/dep_tracking.py @@ -81,6 +81,9 @@ def load_attr_or_method(self, instruction: dis.Instruction) -> None: Args: instruction: The dis instruction to process. + + Raises: + VarValueError: if the attribute is an disallowed name. """ from .base import ComputedVar @@ -133,6 +136,9 @@ def handle_getting_state(self, instruction: dis.Instruction) -> None: Args: instruction: The dis instruction to process. + + Raises: + VarValueError: if the state class cannot be determined from the instruction. """ from reflex.state import BaseState @@ -140,14 +146,25 @@ def handle_getting_state(self, instruction: dis.Instruction) -> None: raise VarValueError( f"Dependency detection cannot identify get_state class from local var {instruction.argval}." ) + if isinstance(self.func, CodeType): + raise VarValueError( + "Dependency detection cannot identify get_state class from a code object." + ) if instruction.opname == "LOAD_GLOBAL": # Special case: referencing state class from global scope. - self._getting_state_class = self.func.__globals__.get(instruction.argval) + try: + self._getting_state_class = inspect.getclosurevars(self.func).globals[ + instruction.argval + ] + except (ValueError, KeyError) as ve: + raise VarValueError( + f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals." + ) from ve elif instruction.opname == "LOAD_DEREF": # Special case: referencing state class from closure. - closure = dict(zip(self.func.__code__.co_freevars, self.func.__closure__)) try: - self._getting_state_class = closure[instruction.argval].cell_contents + closure = inspect.getclosurevars(self.func).nonlocals + self._getting_state_class = closure[instruction.argval] except ValueError as ve: raise VarValueError( f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?." @@ -169,33 +186,54 @@ def _eval_var(self) -> Var: Returns: The Var object. + + Raises: + VarValueError: if the source code for the var cannot be determined. """ # Get the original source code and eval it to get the Var. - start_line = self._getting_var_instructions[0].positions.lineno - start_column = self._getting_var_instructions[0].positions.col_offset - end_line = self._getting_var_instructions[-1].positions.end_lineno - end_column = self._getting_var_instructions[-1].positions.end_col_offset - source = inspect.getsource(inspect.getmodule(self.func)).splitlines(True)[ - start_line - 1 : end_line - ] + module = inspect.getmodule(self.func) + positions = self._getting_var_instructions[0].positions + if module is None or positions is None: + raise VarValueError( + f"Cannot determine the source code for the var in {self.func!r}." + ) + start_line = positions.lineno + start_column = positions.col_offset + end_line = positions.end_lineno + end_column = positions.end_col_offset + if ( + start_line is None + or start_column is None + or end_line is None + or end_column is None + ): + raise VarValueError( + f"Cannot determine the source code for the var in {self.func!r}." + ) + source = inspect.getsource(module).splitlines(True)[start_line - 1 : end_line] # Create a python source string snippet. if len(source) > 1: snipped_source = "".join( [ - source[0][start_column:], - source[1:-2] if len(source) > 2 else "", - source[-1][:end_column], + *source[0][start_column:], + *(source[1:-2] if len(source) > 2 else []), + *source[-1][:end_column], ] ) else: snipped_source = source[0][start_column:end_column] + # Fallback if the closure is not available. + globals = {} + closure = {} try: - closure = dict(zip(self.func.__code__.co_freevars, self.func.__closure__)) + if not isinstance(self.func, CodeType): + closurevars = inspect.getclosurevars(self.func) + closure = closurevars.nonlocals + globals = dict(closurevars.globals) except Exception: - # Fallback if the closure is not available. - closure = {} + pass # Evaluate the string in the context of the function's globals and closure. - return eval(f"({snipped_source})", self.func.__globals__, closure) + return eval(f"({snipped_source})", globals, closure) def handle_getting_var(self, instruction: dis.Instruction) -> None: """Handle bytecode analysis when `get_var_value` was called in the function. @@ -207,11 +245,18 @@ def handle_getting_var(self, instruction: dis.Instruction) -> None: Args: instruction: The dis instruction to process. + + Raises: + VarValueError: if the source code for the var cannot be determined. """ if instruction.opname == "CALL" and self._getting_var_instructions: if self._getting_var_instructions: the_var = self._eval_var() the_var_data = the_var._get_all_var_data() + if the_var_data is None: + raise VarValueError( + f"Cannot determine the source code for the var in {self.func!r}." + ) self.dependencies.setdefault(the_var_data.state, set()).add( the_var_data.field_name ) @@ -227,10 +272,6 @@ def _populate_dependencies(self) -> None: Recursively called when the function makes a method call on "self" or define comprehensions or nested functions that may reference "self". - - Raises: - VarValueError: if the function references the get_state, parent_state, or substates attributes - (cannot track deps in a related state, only implicitly via parent state). """ for instruction in dis.get_instructions(self.func): if self.scan_status == ScanStatus.GETTING_STATE: diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 0d9d438eaa1..2a07f1b2ed3 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -3189,8 +3189,7 @@ class GreatGrandchild3(Grandchild3): RxState = State -@pytest.mark.skip(reason="This test is maybe not relevant anymore.") -def test_potentially_dirty_substates(): +def test_potentially_dirty_states(): """Test that potentially_dirty_substates returns the correct substates. Even if the name "State" is shadowed, it should still work correctly. @@ -3206,9 +3205,9 @@ class C1(State): def bar(self) -> str: return "" - assert RxState._potentially_dirty_substates() == set() - assert State._potentially_dirty_substates() == set() - assert C1._potentially_dirty_substates() == set() + assert RxState._get_potentially_dirty_states() == set() + assert State._get_potentially_dirty_states() == set() + assert C1._get_potentially_dirty_states() == set() @pytest.mark.asyncio @@ -3857,7 +3856,7 @@ async def v(self) -> int: child3 = await self.get_state(Child3) return child3.child3_var + p.parent_var - mock_app.state_manager.state = mock_app.state = Parent + mock_app.state_manager.state = mock_app._state = Parent # Get the top level state via unconnected sibling. root = await mock_app.state_manager.get_state(_substate_key(token, Child)) From 25456a53a9380620e560b87f2b087d17405d39ba Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 29 Jan 2025 01:15:47 -0800 Subject: [PATCH 08/12] ComputedVar.add_dependency: explicitly dependency declaration Allow var dependencies to be added at runtime, for example, when defining a ComponentState that depends on vars that cannot be known statically. Fix more pyright issues. --- reflex/vars/base.py | 31 +++++++++++++ reflex/vars/dep_tracking.py | 87 +++++++++++++++++++++++++------------ tests/units/test_state.py | 56 ++++++++++++++++++++++++ 3 files changed, 147 insertions(+), 27 deletions(-) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 3f5e1ac86d7..44d1764a330 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -2130,6 +2130,37 @@ def mark_dirty(self, instance) -> None: with contextlib.suppress(AttributeError): delattr(instance, self._cache_attr) + def add_dependency(self, objclass: Type[BaseState], dep: Var): + """Explicitly add a dependency to the ComputedVar. + + After adding the dependency, when the `dep` changes, this computed var + will be marked dirty. + + Args: + objclass: The class obj this ComputedVar is attached to. + dep: The dependency to add. + + Raises: + VarDependencyError: If the dependency is not a Var instance with a + state and field name + """ + if all_var_data := dep._get_all_var_data(): + state_name = all_var_data.state + if state_name: + var_name = all_var_data.field_name + if var_name: + self._static_deps.setdefault(state_name, set()).add(var_name) + objclass.get_root_state().get_class_substate( + state_name + )._var_dependencies.setdefault(var_name, set()).add( + (objclass.get_full_name(), self._js_expr) + ) + return + raise VarDependencyError( + "ComputedVar dependencies must be Var instances with a state and " + f"field name, got {dep!r}." + ) + def _determine_var_type(self) -> Type: """Get the type of the var. diff --git a/reflex/vars/dep_tracking.py b/reflex/vars/dep_tracking.py index 387de5a6645..8b002812092 100644 --- a/reflex/vars/dep_tracking.py +++ b/reflex/vars/dep_tracking.py @@ -8,7 +8,7 @@ import enum import inspect from types import CodeType, FunctionType -from typing import TYPE_CHECKING, ClassVar, Type, cast +from typing import TYPE_CHECKING, Any, ClassVar, Type, cast from reflex.utils.exceptions import VarValueError @@ -18,6 +18,24 @@ from .base import Var +CellEmpty = object() + + +def get_cell_value(cell) -> Any: + """Get the value of a cell object. + + Args: + cell: The cell object to get the value from. (func.__closure__ objects) + + Returns: + The value from the cell or CellEmpty if a ValueError is raised. + """ + try: + return cell.cell_contents + except ValueError: + return CellEmpty + + class ScanStatus(enum.Enum): """State of the dis instruction scanning loop.""" @@ -124,6 +142,33 @@ def load_attr_or_method(self, instruction: dis.Instruction) -> None: instruction.argval ) + def _get_globals(self) -> dict[str, Any]: + """Get the globals of the function. + + Returns: + The var names and values in the globals of the function. + """ + if isinstance(self.func, CodeType): + return {} + return self.func.__globals__ # pyright: ignore[reportGeneralTypeIssues] + + def _get_closure(self) -> dict[str, Any]: + """Get the closure of the function, with unbound values omitted. + + Returns: + The var names and values in the closure of the function. + """ + if isinstance(self.func, CodeType): + return {} + return { + var_name: get_cell_value(cell) + for var_name, cell in zip( + self.func.__code__.co_freevars, # pyright: ignore[reportGeneralTypeIssues] + self.func.__closure__, # pyright: ignore[reportGeneralTypeIssues] + ) + if get_cell_value(cell) is not CellEmpty + } + def handle_getting_state(self, instruction: dis.Instruction) -> None: """Handle bytecode analysis when `get_state` was called in the function. @@ -153,9 +198,7 @@ def handle_getting_state(self, instruction: dis.Instruction) -> None: if instruction.opname == "LOAD_GLOBAL": # Special case: referencing state class from global scope. try: - self._getting_state_class = inspect.getclosurevars(self.func).globals[ - instruction.argval - ] + self._getting_state_class = self._get_globals()[instruction.argval] except (ValueError, KeyError) as ve: raise VarValueError( f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals." @@ -163,11 +206,10 @@ def handle_getting_state(self, instruction: dis.Instruction) -> None: elif instruction.opname == "LOAD_DEREF": # Special case: referencing state class from closure. try: - closure = inspect.getclosurevars(self.func).nonlocals - self._getting_state_class = closure[instruction.argval] - except ValueError as ve: + self._getting_state_class = self._get_closure()[instruction.argval] + except (ValueError, KeyError) as ve: raise VarValueError( - f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?." + f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?" ) from ve elif instruction.opname == "STORE_FAST": # Storing the result of get_state in a local variable. @@ -192,15 +234,16 @@ def _eval_var(self) -> Var: """ # Get the original source code and eval it to get the Var. module = inspect.getmodule(self.func) - positions = self._getting_var_instructions[0].positions - if module is None or positions is None: + positions0 = self._getting_var_instructions[0].positions + positions1 = self._getting_var_instructions[-1].positions + if module is None or positions0 is None or positions1 is None: raise VarValueError( f"Cannot determine the source code for the var in {self.func!r}." ) - start_line = positions.lineno - start_column = positions.col_offset - end_line = positions.end_lineno - end_column = positions.end_col_offset + start_line = positions0.lineno + start_column = positions0.col_offset + end_line = positions1.end_lineno + end_column = positions1.end_col_offset if ( start_line is None or start_column is None @@ -217,23 +260,13 @@ def _eval_var(self) -> Var: [ *source[0][start_column:], *(source[1:-2] if len(source) > 2 else []), - *source[-1][:end_column], + *source[-1][: end_column - 1], ] ) else: - snipped_source = source[0][start_column:end_column] - # Fallback if the closure is not available. - globals = {} - closure = {} - try: - if not isinstance(self.func, CodeType): - closurevars = inspect.getclosurevars(self.func) - closure = closurevars.nonlocals - globals = dict(closurevars.globals) - except Exception: - pass + snipped_source = source[0][start_column : end_column - 1] # Evaluate the string in the context of the function's globals and closure. - return eval(f"({snipped_source})", globals, closure) + return eval(f"({snipped_source})", self._get_globals(), self._get_closure()) def handle_getting_var(self, instruction: dis.Instruction) -> None: """Handle bytecode analysis when `get_var_value` was called in the function. diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 2a07f1b2ed3..00b1ac9a0b5 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -14,6 +14,7 @@ Any, AsyncGenerator, Callable, + ClassVar, Dict, List, Optional, @@ -3883,3 +3884,58 @@ async def v(self) -> int: assert await child.v == 2 root.parent_var = 2 assert await child.v == 3 + + +class Table(rx.ComponentState): + """A table state.""" + + data: ClassVar[Var] + + @rx.var(cache=True, auto_deps=False) + async def rows(self) -> List[Dict[str, Any]]: + """Computed var over the given rows. + + Returns: + The data rows. + """ + return await self.get_var_value(self.data) + + @classmethod + def get_component(cls, data: Var) -> rx.Component: + """Get the component for the table. + + Args: + data: The data var. + + Returns: + The component. + """ + cls.data = data + cls.computed_vars["rows"].add_dependency(cls, data) + return rx.foreach(data, lambda d: rx.text(d.to_string())) + + +@pytest.mark.asyncio +async def test_async_computed_var_get_var_value(mock_app: rx.App, token: str): + """A test where an async computed var depends on a var in another state. + + Args: + mock_app: An app that will be returned by `get_app()` + token: A token. + """ + + class OtherState(rx.State): + """A state with a var.""" + + data: List[Dict[str, Any]] = [{"foo": "bar"}] + + mock_app.state_manager.state = mock_app._state = rx.State + comp = Table.create(data=OtherState.data) + state = await mock_app.state_manager.get_state(_substate_key(token, OtherState)) + other_state = await state.get_state(OtherState) + assert comp.State is not None + comp_state = await state.get_state(comp.State) + assert comp_state.dirty_vars == set() + + other_state.data.append({"foo": "baz"}) + assert "rows" in comp_state.dirty_vars From 0aad41681cd56c1e74177ce20b2fa11b172ed4be Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 29 Jan 2025 01:31:53 -0800 Subject: [PATCH 09/12] Fix/ignore more pyright issues from recent merge --- reflex/vars/base.py | 6 +++++- reflex/vars/dep_tracking.py | 8 ++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index cd2baafda4a..f8a26e79510 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -2260,6 +2260,10 @@ class DynamicRouteVar(ComputedVar[Union[str, List[str]]]): pass +async def _default_async_computed_var(_self: BaseState) -> Any: + return None + + @dataclasses.dataclass( eq=False, frozen=True, @@ -2270,7 +2274,7 @@ class AsyncComputedVar(ComputedVar[RETURN_TYPE]): """A computed var that wraps a coroutinefunction.""" _fget: Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]] = ( - dataclasses.field() + dataclasses.field(default=_default_async_computed_var) ) @overload diff --git a/reflex/vars/dep_tracking.py b/reflex/vars/dep_tracking.py index 376dbda9fc2..17a9371a1f8 100644 --- a/reflex/vars/dep_tracking.py +++ b/reflex/vars/dep_tracking.py @@ -70,10 +70,10 @@ def __post_init__(self): """After initializing, populate the dependencies dict.""" with contextlib.suppress(AttributeError): # unbox functools.partial - self.func = cast(FunctionType, self.func.func) # pyright: ignore[reportGeneralTypeIssues] + self.func = cast(FunctionType, self.func.func) # pyright: ignore[reportAttributeAccessIssue] with contextlib.suppress(AttributeError): # unbox EventHandler - self.func = cast(FunctionType, self.func.fn) # pyright: ignore[reportGeneralTypeIssues] + self.func = cast(FunctionType, self.func.fn) # pyright: ignore[reportAttributeAccessIssue] if isinstance(self.func, FunctionType): with contextlib.suppress(AttributeError, IndexError): @@ -150,7 +150,7 @@ def _get_globals(self) -> dict[str, Any]: """ if isinstance(self.func, CodeType): return {} - return self.func.__globals__ # pyright: ignore[reportGeneralTypeIssues] + return self.func.__globals__ # pyright: ignore[reportAttributeAccessIssue] def _get_closure(self) -> dict[str, Any]: """Get the closure of the function, with unbound values omitted. @@ -163,7 +163,7 @@ def _get_closure(self) -> dict[str, Any]: return { var_name: get_cell_value(cell) for var_name, cell in zip( - self.func.__code__.co_freevars, # pyright: ignore[reportGeneralTypeIssues] + self.func.__code__.co_freevars, # pyright: ignore[reportAttributeAccessIssue] self.func.__closure__ or (), strict=False, ) From c30e725715c5b6050ab0154e7ced0cd04caf6167 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 29 Jan 2025 11:00:02 -0800 Subject: [PATCH 10/12] handle cleaning out _potentially_dirty_states on reload --- reflex/state.py | 9 ++++++++- reflex/utils/exec.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 11f5b09b7c3..7688dcf7859 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -4075,12 +4075,19 @@ def reload_state_module( state: Recursive argument for the state class to reload. """ + # Clean out all potentially dirty states of reloaded modules. + for pd_state in tuple(state._potentially_dirty_states): + with contextlib.suppress(ValueError): + if ( + state.get_root_state().get_class_substate(pd_state).__module__ == module + and module is not None + ): + state._potentially_dirty_states.remove(pd_state) for subclass in tuple(state.class_subclasses): reload_state_module(module=module, state=subclass) if subclass.__module__ == module and module is not None: state.class_subclasses.remove(subclass) state._always_dirty_substates.discard(subclass.get_name()) - state._potentially_dirty_states.discard(subclass.get_name()) state._var_dependencies = {} state._init_var_dependency_dicts() state.get_class_substate.cache_clear() diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py index 479ff816aec..67df7ea919e 100644 --- a/reflex/utils/exec.py +++ b/reflex/utils/exec.py @@ -488,7 +488,7 @@ def output_system_info(): dependencies.append(fnm_info) if system == "Linux": - import distro + import distro # pyright: ignore[reportMissingImports] os_version = distro.name(pretty=True) else: From 8f6dfdef9cd224e6eea7c2089a375d1f3bb40751 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 29 Jan 2025 11:03:16 -0800 Subject: [PATCH 11/12] ignore accessed attributes missing on state class these might be added dynamically later in which case we recompute the dependency tracking dicts... if not, they'll blow up anyway at runtime. --- reflex/vars/dep_tracking.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/reflex/vars/dep_tracking.py b/reflex/vars/dep_tracking.py index 17a9371a1f8..0b236779985 100644 --- a/reflex/vars/dep_tracking.py +++ b/reflex/vars/dep_tracking.py @@ -123,7 +123,11 @@ def load_attr_or_method(self, instruction: dis.Instruction) -> None: if not self.top_of_stack: return target_state = self.tracked_locals[self.top_of_stack] - ref_obj = getattr(target_state, instruction.argval) + try: + ref_obj = getattr(target_state, instruction.argval) + except AttributeError: + # Not found on this state class, maybe it is a dynamic attribute that will be picked up later. + ref_obj = None if isinstance(ref_obj, property) and not isinstance(ref_obj, ComputedVar): # recurse into property fget functions From 534d38c8306e0a479de54fd812d7d37847e9a423 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 29 Jan 2025 19:00:03 -0800 Subject: [PATCH 12/12] fix playwright tests, which insist on running an asyncio loop --- reflex/compiler/utils.py | 19 +++++++++++++++++++ .../tests_playwright/test_table.py | 12 +++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 25c44bd742e..9b388400bd4 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -3,12 +3,14 @@ from __future__ import annotations import asyncio +import concurrent.futures import traceback from datetime import datetime from pathlib import Path from typing import Any, Callable, Dict, Optional, Type, Union from urllib.parse import urlparse +from reflex.utils.exec import is_in_app_harness from reflex.utils.prerequisites import get_web_dir from reflex.vars.base import Var @@ -178,6 +180,23 @@ def compile_state(state: Type[BaseState]) -> dict: initial_state = state(_reflex_internal_init=True).dict( initial=True, include_computed=False ) + try: + _ = asyncio.get_running_loop() + except RuntimeError: + pass + else: + if is_in_app_harness(): + # Playwright tests already have an event loop running, so we can't use asyncio.run. + with concurrent.futures.ThreadPoolExecutor() as pool: + resolved_initial_state = pool.submit( + asyncio.run, _resolve_delta(initial_state) + ).result() + console.warn( + f"Had to get initial state in a thread 🤮 {resolved_initial_state}", + ) + return resolved_initial_state + + # Normally the compile runs before any event loop starts, we asyncio.run is available for calling. return asyncio.run(_resolve_delta(initial_state)) diff --git a/tests/integration/tests_playwright/test_table.py b/tests/integration/tests_playwright/test_table.py index bd399a840f5..a88c4a621ad 100644 --- a/tests/integration/tests_playwright/test_table.py +++ b/tests/integration/tests_playwright/test_table.py @@ -3,7 +3,7 @@ from typing import Generator import pytest -from playwright.sync_api import Page +from playwright.sync_api import Page, expect from reflex.testing import AppHarness @@ -87,12 +87,14 @@ def test_table(page: Page, table_app: AppHarness): table = page.get_by_role("table") # Check column headers - headers = table.get_by_role("columnheader").all_inner_texts() - assert headers == expected_col_headers + headers = table.get_by_role("columnheader") + for header, exp_value in zip(headers.all(), expected_col_headers, strict=True): + expect(header).to_have_text(exp_value) # Check rows headers - rows = table.get_by_role("rowheader").all_inner_texts() - assert rows == expected_row_headers + rows = table.get_by_role("rowheader") + for row, expected_row in zip(rows.all(), expected_row_headers, strict=True): + expect(row).to_have_text(expected_row) # Check cells rows = table.get_by_role("cell").all_inner_texts()