Skip to content

Commit 2b7e4d6

Browse files
improve rx.Field ObjectVar typing for sqlalchemy and dataclasses (#4728)
* improve rx.Field ObjectVar typing for sqlalchemy and dataclasses * enable parametrized objectvar tests for sqlamodel and dataclass * improve typing for ObjectVars in ArrayVars * ruffing * drop duplicate objectvar import * remove redundant overload * allow optional hints in rx.Field annotations to resolve to the correct var type
1 parent 15da4e1 commit 2b7e4d6

File tree

3 files changed

+111
-15
lines changed

3 files changed

+111
-15
lines changed

reflex/vars/base.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
overload,
4141
)
4242

43+
from sqlalchemy.orm import DeclarativeBase
4344
from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override
4445

4546
from reflex import constants
@@ -573,15 +574,15 @@ def _replace(
573574

574575
@overload
575576
@classmethod
576-
def create( # type: ignore[override]
577+
def create( # pyright: ignore[reportOverlappingOverload]
577578
cls,
578579
value: bool,
579580
_var_data: VarData | None = None,
580581
) -> BooleanVar: ...
581582

582583
@overload
583584
@classmethod
584-
def create( # type: ignore[override]
585+
def create(
585586
cls,
586587
value: int,
587588
_var_data: VarData | None = None,
@@ -605,7 +606,7 @@ def create( # pyright: ignore [reportOverlappingOverload]
605606

606607
@overload
607608
@classmethod
608-
def create(
609+
def create( # pyright: ignore[reportOverlappingOverload]
609610
cls,
610611
value: None,
611612
_var_data: VarData | None = None,
@@ -3182,10 +3183,16 @@ def dispatch(
31823183

31833184
V = TypeVar("V")
31843185

3185-
BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)
3186+
BASE_TYPE = TypeVar("BASE_TYPE", bound=Base | None)
3187+
SQLA_TYPE = TypeVar("SQLA_TYPE", bound=DeclarativeBase | None)
3188+
3189+
if TYPE_CHECKING:
3190+
from _typeshed import DataclassInstance
3191+
3192+
DATACLASS_TYPE = TypeVar("DATACLASS_TYPE", bound=DataclassInstance | None)
31863193

31873194
FIELD_TYPE = TypeVar("FIELD_TYPE")
3188-
MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping)
3195+
MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping | None)
31893196

31903197

31913198
class Field(Generic[FIELD_TYPE]):
@@ -3230,6 +3237,18 @@ def __get__(
32303237
self: Field[BASE_TYPE], instance: None, owner: Any
32313238
) -> ObjectVar[BASE_TYPE]: ...
32323239

3240+
@overload
3241+
def __get__(
3242+
self: Field[SQLA_TYPE], instance: None, owner: Any
3243+
) -> ObjectVar[SQLA_TYPE]: ...
3244+
3245+
if TYPE_CHECKING:
3246+
3247+
@overload
3248+
def __get__(
3249+
self: Field[DATACLASS_TYPE], instance: None, owner: Any
3250+
) -> ObjectVar[DATACLASS_TYPE]: ...
3251+
32333252
@overload
32343253
def __get__(self, instance: None, owner: Any) -> Var[FIELD_TYPE]: ...
32353254

reflex/vars/sequence.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,11 @@
5353
)
5454

5555
if TYPE_CHECKING:
56+
from .base import BASE_TYPE, DATACLASS_TYPE, SQLA_TYPE
57+
from .function import FunctionVar
5658
from .object import ObjectVar
5759

60+
5861
STRING_TYPE = TypeVar("STRING_TYPE", default=str)
5962

6063

@@ -961,6 +964,24 @@ def __getitem__(
961964
i: int | NumberVar,
962965
) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ...
963966

967+
@overload
968+
def __getitem__(
969+
self: ARRAY_VAR_OF_LIST_ELEMENT[BASE_TYPE],
970+
i: int | NumberVar,
971+
) -> ObjectVar[BASE_TYPE]: ...
972+
973+
@overload
974+
def __getitem__(
975+
self: ARRAY_VAR_OF_LIST_ELEMENT[SQLA_TYPE],
976+
i: int | NumberVar,
977+
) -> ObjectVar[SQLA_TYPE]: ...
978+
979+
@overload
980+
def __getitem__(
981+
self: ARRAY_VAR_OF_LIST_ELEMENT[DATACLASS_TYPE],
982+
i: int | NumberVar,
983+
) -> ObjectVar[DATACLASS_TYPE]: ...
984+
964985
@overload
965986
def __getitem__(self, i: int | NumberVar) -> Var: ...
966987

@@ -1648,10 +1669,6 @@ def repeat_array_operation(
16481669
)
16491670

16501671

1651-
if TYPE_CHECKING:
1652-
from .function import FunctionVar
1653-
1654-
16551672
@var_operation
16561673
def map_array_operation(
16571674
array: ArrayVar[ARRAY_VAR_TYPE],

tests/units/vars/test_object.py

+66-6
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1+
import dataclasses
2+
13
import pytest
4+
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column
25
from typing_extensions import assert_type
36

47
import reflex as rx
58
from reflex.utils.types import GenericType
69
from reflex.vars.base import Var
710
from reflex.vars.object import LiteralObjectVar, ObjectVar
11+
from reflex.vars.sequence import ArrayVar
812

913

1014
class Bare:
@@ -32,14 +36,44 @@ class Base(rx.Base):
3236
quantity: int = 0
3337

3438

39+
class SqlaBase(DeclarativeBase, MappedAsDataclass):
40+
"""Sqlalchemy declarative mapping base class."""
41+
42+
pass
43+
44+
45+
class SqlaModel(SqlaBase):
46+
"""A sqlalchemy model with a single attribute."""
47+
48+
__tablename__: str = "sqla_model"
49+
50+
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True, init=False)
51+
quantity: Mapped[int] = mapped_column(default=0)
52+
53+
54+
@dataclasses.dataclass
55+
class Dataclass:
56+
"""A dataclass with a single attribute."""
57+
58+
quantity: int = 0
59+
60+
3561
class ObjectState(rx.State):
36-
"""A reflex state with bare and base objects."""
62+
"""A reflex state with bare, base and sqlalchemy base vars."""
3763

3864
bare: rx.Field[Bare] = rx.field(Bare())
65+
bare_optional: rx.Field[Bare | None] = rx.field(None)
3966
base: rx.Field[Base] = rx.field(Base())
67+
base_optional: rx.Field[Base | None] = rx.field(None)
68+
sqlamodel: rx.Field[SqlaModel] = rx.field(SqlaModel())
69+
sqlamodel_optional: rx.Field[SqlaModel | None] = rx.field(None)
70+
dataclass: rx.Field[Dataclass] = rx.field(Dataclass())
71+
dataclass_optional: rx.Field[Dataclass | None] = rx.field(None)
72+
73+
base_list: rx.Field[list[Base]] = rx.field([Base()])
4074

4175

42-
@pytest.mark.parametrize("type_", [Base, Bare])
76+
@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
4377
def test_var_create(type_: GenericType) -> None:
4478
my_object = type_()
4579
var = Var.create(my_object)
@@ -49,7 +83,7 @@ def test_var_create(type_: GenericType) -> None:
4983
assert quantity._var_type is int
5084

5185

52-
@pytest.mark.parametrize("type_", [Base, Bare])
86+
@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
5387
def test_literal_create(type_: GenericType) -> None:
5488
my_object = type_()
5589
var = LiteralObjectVar.create(my_object)
@@ -59,7 +93,7 @@ def test_literal_create(type_: GenericType) -> None:
5993
assert quantity._var_type is int
6094

6195

62-
@pytest.mark.parametrize("type_", [Base, Bare])
96+
@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
6397
def test_guess(type_: GenericType) -> None:
6498
my_object = type_()
6599
var = Var.create(my_object)
@@ -70,7 +104,7 @@ def test_guess(type_: GenericType) -> None:
70104
assert quantity._var_type is int
71105

72106

73-
@pytest.mark.parametrize("type_", [Base, Bare])
107+
@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
74108
def test_state(type_: GenericType) -> None:
75109
attr_name = type_.__name__.lower()
76110
var = getattr(ObjectState, attr_name)
@@ -80,7 +114,7 @@ def test_state(type_: GenericType) -> None:
80114
assert quantity._var_type is int
81115

82116

83-
@pytest.mark.parametrize("type_", [Base, Bare])
117+
@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
84118
def test_state_to_operation(type_: GenericType) -> None:
85119
attr_name = type_.__name__.lower()
86120
original_var = getattr(ObjectState, attr_name)
@@ -100,3 +134,29 @@ def test_typing() -> None:
100134
# Base
101135
var = ObjectState.base
102136
_ = assert_type(var, ObjectVar[Base])
137+
optional_var = ObjectState.base_optional
138+
_ = assert_type(optional_var, ObjectVar[Base | None])
139+
list_var = ObjectState.base_list
140+
_ = assert_type(list_var, ArrayVar[list[Base]])
141+
list_var_0 = list_var[0]
142+
_ = assert_type(list_var_0, ObjectVar[Base])
143+
144+
# Sqla
145+
var = ObjectState.sqlamodel
146+
_ = assert_type(var, ObjectVar[SqlaModel])
147+
optional_var = ObjectState.sqlamodel_optional
148+
_ = assert_type(optional_var, ObjectVar[SqlaModel | None])
149+
list_var = ObjectState.base_list
150+
_ = assert_type(list_var, ArrayVar[list[Base]])
151+
list_var_0 = list_var[0]
152+
_ = assert_type(list_var_0, ObjectVar[Base])
153+
154+
# Dataclass
155+
var = ObjectState.dataclass
156+
_ = assert_type(var, ObjectVar[Dataclass])
157+
optional_var = ObjectState.dataclass_optional
158+
_ = assert_type(optional_var, ObjectVar[Dataclass | None])
159+
list_var = ObjectState.base_list
160+
_ = assert_type(list_var, ArrayVar[list[Base]])
161+
list_var_0 = list_var[0]
162+
_ = assert_type(list_var_0, ObjectVar[Base])

0 commit comments

Comments
 (0)