From 24a5309156f7e359f9b82a980da2109cbe37d905 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Sat, 22 Feb 2025 09:23:53 -0800 Subject: [PATCH] don't treat vars as their types for setting state fields --- reflex/components/component.py | 2 +- reflex/components/tags/tag.py | 6 ++-- reflex/state.py | 4 +-- reflex/utils/types.py | 62 ++++++++++++++++++++++++++++------ reflex/vars/base.py | 2 +- 5 files changed, 58 insertions(+), 18 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index af9da1b4e30..ce050f3d3fc 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -192,7 +192,7 @@ def satisfies_type_hint(obj: Any, type_hint: Any) -> bool: Returns: Whether the object satisfies the type hint. """ - return types._isinstance(obj, type_hint, nested=1) + return types._isinstance(obj, type_hint, nested=1, treat_var_as_type=True) def _components_from( diff --git a/reflex/components/tags/tag.py b/reflex/components/tags/tag.py index 7f7a8c74d60..12ff878ed0e 100644 --- a/reflex/components/tags/tag.py +++ b/reflex/components/tags/tag.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Mapping, Optional, Sequence from reflex.event import EventChain -from reflex.utils import format, types +from reflex.utils import format from reflex.vars.base import LiteralVar, Var @@ -103,9 +103,9 @@ def add_props(self, **kwargs: Optional[Any]) -> Tag: { format.to_camel_case(name, treat_hyphens_as_underscores=False): ( prop - if types._isinstance(prop, (EventChain, Mapping)) + if isinstance(prop, (EventChain, Mapping)) else LiteralVar.create(prop) - ) # rx.color is always a string + ) for name, prop in kwargs.items() if self.is_valid_prop(prop) } diff --git a/reflex/state.py b/reflex/state.py index 0f0ba97f9ca..a54ff28223e 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -695,7 +695,7 @@ def _evaluate( def computed_var_func(state: Self): result = f(state) - if not _isinstance(result, of_type): + if not _isinstance(result, of_type, nested=1, treat_var_as_type=False): console.warn( f"Inline ComputedVar {f} expected type {of_type}, got {type(result)}. " "You can specify expected type with `of_type` argument." @@ -1356,7 +1356,7 @@ def __setattr__(self, name: str, value: Any): field_type = _unwrap_field_type(field.outer_type_) if field.allow_none and not is_optional(field_type): field_type = Union[field_type, None] - if not _isinstance(value, field_type): + if not _isinstance(value, field_type, nested=1, treat_var_as_type=False): console.error( f"Expected field '{type(self).__name__}.{name}' to receive type '{field_type}'," f" but got '{value}' of type '{type(value)}'." diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 516b709864d..8f4ac0a6b31 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -510,13 +510,16 @@ def does_obj_satisfy_typed_dict(obj: Any, cls: GenericType) -> bool: return required_keys.issubset(required_keys) -def _isinstance(obj: Any, cls: GenericType, nested: int = 0) -> bool: +def _isinstance( + obj: Any, cls: GenericType, *, nested: int = 0, treat_var_as_type: bool = True +) -> bool: """Check if an object is an instance of a class. Args: obj: The object to check. cls: The class to check against. nested: How many levels deep to check. + treat_var_as_type: Whether to treat Var as the type it represents, i.e. _var_type. Returns: Whether the object is an instance of the class. @@ -529,15 +532,20 @@ def _isinstance(obj: Any, cls: GenericType, nested: int = 0) -> bool: if cls is Var: return isinstance(obj, Var) if isinstance(obj, LiteralVar): - return _isinstance(obj._var_value, cls, nested=nested) + return treat_var_as_type and _isinstance( + obj._var_value, cls, nested=nested, treat_var_as_type=True + ) if isinstance(obj, Var): - return _issubclass(obj._var_type, cls) + return treat_var_as_type and _issubclass(obj._var_type, cls) if cls is None or cls is type(None): return obj is None if cls and is_union(cls): - return any(_isinstance(obj, arg, nested=nested) for arg in get_args(cls)) + return any( + _isinstance(obj, arg, nested=nested, treat_var_as_type=treat_var_as_type) + for arg in get_args(cls) + ) if is_literal(cls): return obj in get_args(cls) @@ -567,37 +575,69 @@ def _isinstance(obj: Any, cls: GenericType, nested: int = 0) -> bool: if nested > 0 and args: if origin is list: return isinstance(obj, list) and all( - _isinstance(item, args[0], nested=nested - 1) for item in obj + _isinstance( + item, + args[0], + nested=nested - 1, + treat_var_as_type=treat_var_as_type, + ) + for item in obj ) if origin is tuple: if args[-1] is Ellipsis: return isinstance(obj, tuple) and all( - _isinstance(item, args[0], nested=nested - 1) for item in obj + _isinstance( + item, + args[0], + nested=nested - 1, + treat_var_as_type=treat_var_as_type, + ) + for item in obj ) return ( isinstance(obj, tuple) and len(obj) == len(args) and all( - _isinstance(item, arg, nested=nested - 1) + _isinstance( + item, + arg, + nested=nested - 1, + treat_var_as_type=treat_var_as_type, + ) for item, arg in zip(obj, args, strict=True) ) ) if origin in (dict, Mapping, Breakpoints): return isinstance(obj, Mapping) and all( - _isinstance(key, args[0], nested=nested - 1) - and _isinstance(value, args[1], nested=nested - 1) + _isinstance( + key, args[0], nested=nested - 1, treat_var_as_type=treat_var_as_type + ) + and _isinstance( + value, + args[1], + nested=nested - 1, + treat_var_as_type=treat_var_as_type, + ) for key, value in obj.items() ) if origin is set: return isinstance(obj, set) and all( - _isinstance(item, args[0], nested=nested - 1) for item in obj + _isinstance( + item, + args[0], + nested=nested - 1, + treat_var_as_type=treat_var_as_type, + ) + for item in obj ) if args: from reflex.vars import Field if origin is Field: - return _isinstance(obj, args[0], nested=nested) + return _isinstance( + obj, args[0], nested=nested, treat_var_as_type=treat_var_as_type + ) return isinstance(obj, get_base_class(cls)) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 6654c7e2228..032cd1ecabd 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -2290,7 +2290,7 @@ def __get__(self, instance: BaseState | None, owner: Type): return value def _check_deprecated_return_type(self, instance: BaseState, value: Any) -> None: - if not _isinstance(value, self._var_type): + if not _isinstance(value, self._var_type, nested=1, treat_var_as_type=False): console.error( f"Computed var '{type(instance).__name__}.{self._js_expr}' must return" f" type '{self._var_type}', got '{type(value)}'."