diff --git a/reflex/components/component.py b/reflex/components/component.py index 6a276408267..e25440865a9 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -188,7 +188,7 @@ def satisfies_type_hint(obj: Any, type_hint: Any) -> bool: Returns: Whether the object satisfies the type hint. """ - return types._isinstance(obj, type_hint, nested=1) + return types._isinstance(obj, type_hint, nested=1, treat_var_as_type=True) def _components_from( diff --git a/reflex/components/tags/tag.py b/reflex/components/tags/tag.py index 8bafcd42daa..917796efe74 100644 --- a/reflex/components/tags/tag.py +++ b/reflex/components/tags/tag.py @@ -6,7 +6,7 @@ from typing import Any, List, Mapping, Sequence from reflex.event import EventChain -from reflex.utils import format, types +from reflex.utils import format from reflex.vars.base import LiteralVar, Var @@ -103,9 +103,9 @@ def add_props(self, **kwargs: Any | None) -> Tag: { format.to_camel_case(name, treat_hyphens_as_underscores=False): ( prop - if types._isinstance(prop, (EventChain, Mapping)) + if isinstance(prop, (EventChain, Mapping)) else LiteralVar.create(prop) - ) # rx.color is always a string + ) for name, prop in kwargs.items() if self.is_valid_prop(prop) } diff --git a/reflex/state.py b/reflex/state.py index 70b351af7ef..2be05741d85 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -692,7 +692,7 @@ def _evaluate(cls, f: Callable[[Self], Any], of_type: type | None = None) -> Var def computed_var_func(state: Self): result = f(state) - if not _isinstance(result, of_type): + if not _isinstance(result, of_type, nested=1, treat_var_as_type=False): console.warn( f"Inline ComputedVar {f} expected type {of_type}, got {type(result)}. " "You can specify expected type with `of_type` argument." @@ -1353,7 +1353,7 @@ def __setattr__(self, name: str, value: Any): field_type = _unwrap_field_type(field.outer_type_) if field.allow_none and not is_optional(field_type): field_type = field_type | None - if not _isinstance(value, field_type): + if not _isinstance(value, field_type, nested=1, treat_var_as_type=False): console.error( f"Expected field '{type(self).__name__}.{name}' to receive type '{field_type}'," f" but got '{value}' of type '{type(value)}'." diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 162e3a68fbf..3aa72e32115 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -509,13 +509,16 @@ def does_obj_satisfy_typed_dict(obj: Any, cls: GenericType) -> bool: return required_keys.issubset(required_keys) -def _isinstance(obj: Any, cls: GenericType, nested: int = 0) -> bool: +def _isinstance( + obj: Any, cls: GenericType, *, nested: int = 0, treat_var_as_type: bool = True +) -> bool: """Check if an object is an instance of a class. Args: obj: The object to check. cls: The class to check against. nested: How many levels deep to check. + treat_var_as_type: Whether to treat Var as the type it represents, i.e. _var_type. Returns: Whether the object is an instance of the class. @@ -528,15 +531,20 @@ def _isinstance(obj: Any, cls: GenericType, nested: int = 0) -> bool: if cls is Var: return isinstance(obj, Var) if isinstance(obj, LiteralVar): - return _isinstance(obj._var_value, cls, nested=nested) + return treat_var_as_type and _isinstance( + obj._var_value, cls, nested=nested, treat_var_as_type=True + ) if isinstance(obj, Var): - return _issubclass(obj._var_type, cls) + return treat_var_as_type and _issubclass(obj._var_type, cls) if cls is None or cls is type(None): return obj is None if cls and is_union(cls): - return any(_isinstance(obj, arg, nested=nested) for arg in get_args(cls)) + return any( + _isinstance(obj, arg, nested=nested, treat_var_as_type=treat_var_as_type) + for arg in get_args(cls) + ) if is_literal(cls): return obj in get_args(cls) @@ -566,37 +574,69 @@ def _isinstance(obj: Any, cls: GenericType, nested: int = 0) -> bool: if nested > 0 and args: if origin is list: return isinstance(obj, list) and all( - _isinstance(item, args[0], nested=nested - 1) for item in obj + _isinstance( + item, + args[0], + nested=nested - 1, + treat_var_as_type=treat_var_as_type, + ) + for item in obj ) if origin is tuple: if args[-1] is Ellipsis: return isinstance(obj, tuple) and all( - _isinstance(item, args[0], nested=nested - 1) for item in obj + _isinstance( + item, + args[0], + nested=nested - 1, + treat_var_as_type=treat_var_as_type, + ) + for item in obj ) return ( isinstance(obj, tuple) and len(obj) == len(args) and all( - _isinstance(item, arg, nested=nested - 1) + _isinstance( + item, + arg, + nested=nested - 1, + treat_var_as_type=treat_var_as_type, + ) for item, arg in zip(obj, args, strict=True) ) ) if origin in (dict, Mapping, Breakpoints): return isinstance(obj, Mapping) and all( - _isinstance(key, args[0], nested=nested - 1) - and _isinstance(value, args[1], nested=nested - 1) + _isinstance( + key, args[0], nested=nested - 1, treat_var_as_type=treat_var_as_type + ) + and _isinstance( + value, + args[1], + nested=nested - 1, + treat_var_as_type=treat_var_as_type, + ) for key, value in obj.items() ) if origin is set: return isinstance(obj, set) and all( - _isinstance(item, args[0], nested=nested - 1) for item in obj + _isinstance( + item, + args[0], + nested=nested - 1, + treat_var_as_type=treat_var_as_type, + ) + for item in obj ) if args: from reflex.vars import Field if origin is Field: - return _isinstance(obj, args[0], nested=nested) + return _isinstance( + obj, args[0], nested=nested, treat_var_as_type=treat_var_as_type + ) return isinstance(obj, get_base_class(cls)) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 45b94147b81..974828ab69b 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -2289,7 +2289,7 @@ def __get__(self, instance: BaseState | None, owner: Type): return value def _check_deprecated_return_type(self, instance: BaseState, value: Any) -> None: - if not _isinstance(value, self._var_type): + if not _isinstance(value, self._var_type, nested=1, treat_var_as_type=False): console.error( f"Computed var '{type(instance).__name__}.{self._js_expr}' must return" f" type '{self._var_type}', got '{type(value)}'."