Skip to content

Commit 81c7e9e

Browse files
committed
use safe_issubclass and handle union in sequences
1 parent 66fa613 commit 81c7e9e

File tree

3 files changed

+81
-35
lines changed

3 files changed

+81
-35
lines changed

reflex/utils/types.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -156,15 +156,7 @@ def get_type_hints(obj: Any) -> Dict[str, Any]:
156156
return get_type_hints_og(obj)
157157

158158

159-
def unionize(*args: GenericType) -> Type:
160-
"""Unionize the types.
161-
162-
Args:
163-
args: The types to unionize.
164-
165-
Returns:
166-
The unionized types.
167-
"""
159+
def _unionize(args: list[GenericType]) -> Type:
168160
if not args:
169161
return Any # pyright: ignore [reportReturnType]
170162
if len(args) == 1:
@@ -176,6 +168,18 @@ def unionize(*args: GenericType) -> Type:
176168
return Union[unionize(*first_half), unionize(*second_half)] # pyright: ignore [reportReturnType]
177169

178170

171+
def unionize(*args: GenericType) -> Type:
172+
"""Unionize the types.
173+
174+
Args:
175+
args: The types to unionize.
176+
177+
Returns:
178+
The unionized types.
179+
"""
180+
return _unionize([arg for arg in args if arg is not NoReturn])
181+
182+
179183
def is_none(cls: GenericType) -> bool:
180184
"""Check if a class is None.
181185

reflex/vars/base.py

+26-18
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def can_use_in_object_var(cls: GenericType) -> bool:
359359
return all(can_use_in_object_var(t) for t in types.get_args(cls))
360360
return (
361361
inspect.isclass(cls)
362-
and not issubclass(cls, Var)
362+
and not safe_issubclass(cls, Var)
363363
and serializers.can_serialize(cls, dict)
364364
)
365365

@@ -796,7 +796,7 @@ def to(
796796

797797
if inspect.isclass(output):
798798
for var_subclass in _var_subclasses[::-1]:
799-
if issubclass(output, var_subclass.var_subclass):
799+
if safe_issubclass(output, var_subclass.var_subclass):
800800
current_var_type = self._var_type
801801
if current_var_type is Any:
802802
new_var_type = var_type
@@ -808,7 +808,7 @@ def to(
808808
return to_operation_return # pyright: ignore [reportReturnType]
809809

810810
# If we can't determine the first argument, we just replace the _var_type.
811-
if not issubclass(output, Var) or var_type is None:
811+
if not safe_issubclass(output, Var) or var_type is None:
812812
return dataclasses.replace(
813813
self,
814814
_var_type=output,
@@ -850,7 +850,6 @@ def guess_type(self) -> Var:
850850
Raises:
851851
TypeError: If the type is not supported for guessing.
852852
"""
853-
from .number import NumberVar
854853
from .object import ObjectVar
855854

856855
var_type = self._var_type
@@ -868,11 +867,20 @@ def guess_type(self) -> Var:
868867

869868
if fixed_type in types.UnionTypes:
870869
inner_types = get_args(var_type)
870+
non_optional_inner_types = [
871+
types.value_inside_optional(inner_type) for inner_type in inner_types
872+
]
873+
fixed_inner_types = [
874+
get_origin(inner_type) or inner_type
875+
for inner_type in non_optional_inner_types
876+
]
871877

872-
if all(
873-
inspect.isclass(t) and issubclass(t, (int, float)) for t in inner_types
874-
):
875-
return self.to(NumberVar, self._var_type)
878+
for var_subclass in _var_subclasses[::-1]:
879+
if all(
880+
safe_issubclass(t, var_subclass.python_types)
881+
for t in fixed_inner_types
882+
):
883+
return self.to(var_subclass.var_subclass, self._var_type)
876884

877885
if can_use_in_object_var(var_type):
878886
return self.to(ObjectVar, self._var_type)
@@ -890,7 +898,7 @@ def guess_type(self) -> Var:
890898
return self.to(None)
891899

892900
for var_subclass in _var_subclasses[::-1]:
893-
if issubclass(fixed_type, var_subclass.python_types):
901+
if safe_issubclass(fixed_type, var_subclass.python_types):
894902
return self.to(var_subclass.var_subclass, self._var_type)
895903

896904
if can_use_in_object_var(fixed_type):
@@ -918,17 +926,17 @@ def _get_default_value(self) -> Any:
918926
if type_ is Literal:
919927
args = get_args(self._var_type)
920928
return args[0] if args else None
921-
if issubclass(type_, str):
929+
if safe_issubclass(type_, str):
922930
return ""
923-
if issubclass(type_, types.get_args(int | float)):
931+
if safe_issubclass(type_, types.get_args(int | float)):
924932
return 0
925-
if issubclass(type_, bool):
933+
if safe_issubclass(type_, bool):
926934
return False
927-
if issubclass(type_, list):
935+
if safe_issubclass(type_, list):
928936
return []
929-
if issubclass(type_, Mapping):
937+
if safe_issubclass(type_, Mapping):
930938
return {}
931-
if issubclass(type_, tuple):
939+
if safe_issubclass(type_, tuple):
932940
return ()
933941
if types.is_dataframe(type_):
934942
try:
@@ -939,7 +947,7 @@ def _get_default_value(self) -> Any:
939947
raise ImportError(
940948
"Please install pandas to use dataframes in your app."
941949
) from e
942-
return set() if issubclass(type_, set) else None
950+
return set() if safe_issubclass(type_, set) else None
943951

944952
def _get_setter_name(self, include_state: bool = True) -> str:
945953
"""Get the name of the var's generated setter function.
@@ -1412,7 +1420,7 @@ def __init_subclass__(cls, **kwargs):
14121420
possible_bases = [
14131421
base
14141422
for base in bases_normalized
1415-
if issubclass(base, Var) and base != LiteralVar
1423+
if safe_issubclass(base, Var) and base != LiteralVar
14161424
]
14171425

14181426
if not possible_bases:
@@ -2706,7 +2714,7 @@ def create(
27062714

27072715
def var_operation_return(
27082716
js_expression: str,
2709-
var_type: Type[RETURN] | None = None,
2717+
var_type: Type[RETURN] | GenericType | None = None,
27102718
var_data: VarData | None = None,
27112719
) -> CustomVarOperationReturn[RETURN]:
27122720
"""Shortcut for creating a CustomVarOperationReturn.

reflex/vars/sequence.py

+42-8
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
from __future__ import annotations
44

5+
import collections.abc
56
import dataclasses
67
import inspect
78
import json
89
import re
9-
import typing
1010
from typing import (
1111
TYPE_CHECKING,
1212
Any,
13+
Iterable,
1314
List,
1415
Literal,
1516
Mapping,
@@ -18,6 +19,7 @@
1819
Type,
1920
TypeVar,
2021
Union,
22+
get_args,
2123
overload,
2224
)
2325

@@ -26,6 +28,7 @@
2628
from reflex import constants
2729
from reflex.constants.base import REFLEX_VAR_OPENING_TAG
2830
from reflex.constants.colors import Color
31+
from reflex.utils import types
2932
from reflex.utils.exceptions import VarTypeError
3033
from reflex.utils.types import GenericType, get_origin
3134

@@ -46,7 +49,6 @@
4649
)
4750
from .number import (
4851
BooleanVar,
49-
LiteralNumberVar,
5052
NumberVar,
5153
raise_unsupported_operand_types,
5254
ternary_operation,
@@ -1622,6 +1624,41 @@ def is_tuple_type(t: GenericType) -> bool:
16221624
return get_origin(t) is tuple
16231625

16241626

1627+
def _determine_value_of_array_index(var_type: GenericType, index: int | None = None):
1628+
"""Determine the value of an array index.
1629+
1630+
Args:
1631+
var_type: The type of the array.
1632+
index: The index of the array.
1633+
1634+
Returns:
1635+
The value of the array index.
1636+
"""
1637+
origin_var_type = get_origin(var_type) or var_type
1638+
if origin_var_type in types.UnionTypes:
1639+
return unionize(
1640+
*[
1641+
_determine_value_of_array_index(t, index)
1642+
for t in get_args(var_type)
1643+
if t is not type(None)
1644+
]
1645+
)
1646+
if origin_var_type in [
1647+
Sequence,
1648+
Iterable,
1649+
list,
1650+
tuple,
1651+
collections.abc.Sequence,
1652+
collections.abc.Iterable,
1653+
]:
1654+
args = get_args(var_type)
1655+
return args[0] if args else Any
1656+
if origin_var_type is tuple:
1657+
args = get_args(var_type)
1658+
return args[index % len(args)] if args and index is not None else Any
1659+
return Any
1660+
1661+
16251662
@var_operation
16261663
def array_item_operation(array: ArrayVar, index: NumberVar | int):
16271664
"""Get an item from an array.
@@ -1633,12 +1670,9 @@ def array_item_operation(array: ArrayVar, index: NumberVar | int):
16331670
Returns:
16341671
The item from the array.
16351672
"""
1636-
args = typing.get_args(array._var_type)
1637-
if args and isinstance(index, LiteralNumberVar) and is_tuple_type(array._var_type):
1638-
index_value = int(index._var_value)
1639-
element_type = args[index_value % len(args)]
1640-
else:
1641-
element_type = unionize(*args)
1673+
element_type = _determine_value_of_array_index(
1674+
array._var_type, index if isinstance(index, int) else None
1675+
)
16421676

16431677
return var_operation_return(
16441678
js_expression=f"{array!s}.at({index!s})",

0 commit comments

Comments
 (0)