Skip to content

Commit 21ba01c

Browse files
authored
don't use _outer_type if we don't have to (#4528)
* don't use _outer_type if we don't have to * apparently we should use .annotation, and .allow_none is useless * have a shorter path for get_field_type if it's nice * check against optional in annotation str * add check for default value being null * post merge * we still console erroring * bring back nested * get_type_hints is slow af * simplify value inside optional * optimize get_event_triggers a tad bit * optimize subclass checks * optimize things even more why not * what if we don't validate components
1 parent d97d1d9 commit 21ba01c

File tree

7 files changed

+162
-80
lines changed

7 files changed

+162
-80
lines changed

reflex/components/base/bare.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def create(cls, contents: Any) -> Component:
7676
validate_str(contents)
7777
contents = str(contents) if contents is not None else ""
7878

79-
return cls(contents=contents)
79+
return cls._create(children=[], contents=contents)
8080

8181
def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
8282
"""Include the hooks for the component.

reflex/components/component.py

+69-47
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,12 @@
1919
Sequence,
2020
Set,
2121
Type,
22+
TypeVar,
2223
Union,
2324
get_args,
2425
get_origin,
2526
)
2627

27-
from typing_extensions import Self
28-
2928
import reflex.state
3029
from reflex.base import Base
3130
from reflex.compiler.templates import STATEFUL_COMPONENT
@@ -210,6 +209,27 @@ def _components_from(
210209
return ()
211210

212211

212+
DEFAULT_TRIGGERS: dict[str, types.ArgsSpec | Sequence[types.ArgsSpec]] = {
213+
EventTriggers.ON_FOCUS: no_args_event_spec,
214+
EventTriggers.ON_BLUR: no_args_event_spec,
215+
EventTriggers.ON_CLICK: no_args_event_spec,
216+
EventTriggers.ON_CONTEXT_MENU: no_args_event_spec,
217+
EventTriggers.ON_DOUBLE_CLICK: no_args_event_spec,
218+
EventTriggers.ON_MOUSE_DOWN: no_args_event_spec,
219+
EventTriggers.ON_MOUSE_ENTER: no_args_event_spec,
220+
EventTriggers.ON_MOUSE_LEAVE: no_args_event_spec,
221+
EventTriggers.ON_MOUSE_MOVE: no_args_event_spec,
222+
EventTriggers.ON_MOUSE_OUT: no_args_event_spec,
223+
EventTriggers.ON_MOUSE_OVER: no_args_event_spec,
224+
EventTriggers.ON_MOUSE_UP: no_args_event_spec,
225+
EventTriggers.ON_SCROLL: no_args_event_spec,
226+
EventTriggers.ON_MOUNT: no_args_event_spec,
227+
EventTriggers.ON_UNMOUNT: no_args_event_spec,
228+
}
229+
230+
T = TypeVar("T", bound="Component")
231+
232+
213233
class Component(BaseComponent, ABC):
214234
"""A component with style, event trigger and other props."""
215235

@@ -364,12 +384,16 @@ def __init_subclass__(cls, **kwargs):
364384
if field.name not in props:
365385
continue
366386

387+
field_type = types.value_inside_optional(
388+
types.get_field_type(cls, field.name)
389+
)
390+
367391
# Set default values for any props.
368-
if types._issubclass(field.type_, Var):
392+
if types._issubclass(field_type, Var):
369393
field.required = False
370394
if field.default is not None:
371395
field.default = LiteralVar.create(field.default)
372-
elif types._issubclass(field.type_, EventHandler):
396+
elif types._issubclass(field_type, EventHandler):
373397
field.required = False
374398

375399
# Ensure renamed props from parent classes are applied to the subclass.
@@ -380,7 +404,7 @@ def __init_subclass__(cls, **kwargs):
380404
inherited_rename_props.update(parent._rename_props)
381405
cls._rename_props = inherited_rename_props
382406

383-
def __init__(self, *args, **kwargs):
407+
def _post_init(self, *args, **kwargs):
384408
"""Initialize the component.
385409
386410
Args:
@@ -393,16 +417,6 @@ def __init__(self, *args, **kwargs):
393417
"""
394418
# Set the id and children initially.
395419
children = kwargs.get("children", [])
396-
initial_kwargs = {
397-
"id": kwargs.get("id"),
398-
"children": children,
399-
**{
400-
prop: LiteralVar.create(kwargs[prop])
401-
for prop in self.get_initial_props()
402-
if prop in kwargs
403-
},
404-
}
405-
super().__init__(**initial_kwargs)
406420

407421
self._validate_component_children(children)
408422

@@ -433,7 +447,9 @@ def __init__(self, *args, **kwargs):
433447
field_type = EventChain
434448
elif key in props:
435449
# Set the field type.
436-
field_type = fields[key].type_
450+
field_type = types.value_inside_optional(
451+
types.get_field_type(type(self), key)
452+
)
437453

438454
else:
439455
continue
@@ -455,7 +471,10 @@ def determine_key(value: Any):
455471
try:
456472
kwargs[key] = determine_key(value)
457473

458-
expected_type = fields[key].outer_type_.__args__[0]
474+
expected_type = types.get_args(
475+
types.get_field_type(type(self), key)
476+
)[0]
477+
459478
# validate literal fields.
460479
types.validate_literal(
461480
key, value, expected_type, type(self).__name__
@@ -470,7 +489,7 @@ def determine_key(value: Any):
470489
except TypeError:
471490
# If it is not a valid var, check the base types.
472491
passed_type = type(value)
473-
expected_type = fields[key].outer_type_
492+
expected_type = types.get_field_type(type(self), key)
474493
if types.is_union(passed_type):
475494
# We need to check all possible types in the union.
476495
passed_types = (
@@ -552,7 +571,8 @@ def determine_key(value: Any):
552571
kwargs["class_name"] = " ".join(class_name)
553572

554573
# Construct the component.
555-
super().__init__(*args, **kwargs)
574+
for key, value in kwargs.items():
575+
setattr(self, key, value)
556576

557577
def get_event_triggers(
558578
self,
@@ -562,34 +582,17 @@ def get_event_triggers(
562582
Returns:
563583
The event triggers.
564584
"""
565-
default_triggers: dict[str, types.ArgsSpec | Sequence[types.ArgsSpec]] = {
566-
EventTriggers.ON_FOCUS: no_args_event_spec,
567-
EventTriggers.ON_BLUR: no_args_event_spec,
568-
EventTriggers.ON_CLICK: no_args_event_spec,
569-
EventTriggers.ON_CONTEXT_MENU: no_args_event_spec,
570-
EventTriggers.ON_DOUBLE_CLICK: no_args_event_spec,
571-
EventTriggers.ON_MOUSE_DOWN: no_args_event_spec,
572-
EventTriggers.ON_MOUSE_ENTER: no_args_event_spec,
573-
EventTriggers.ON_MOUSE_LEAVE: no_args_event_spec,
574-
EventTriggers.ON_MOUSE_MOVE: no_args_event_spec,
575-
EventTriggers.ON_MOUSE_OUT: no_args_event_spec,
576-
EventTriggers.ON_MOUSE_OVER: no_args_event_spec,
577-
EventTriggers.ON_MOUSE_UP: no_args_event_spec,
578-
EventTriggers.ON_SCROLL: no_args_event_spec,
579-
EventTriggers.ON_MOUNT: no_args_event_spec,
580-
EventTriggers.ON_UNMOUNT: no_args_event_spec,
581-
}
582-
585+
triggers = DEFAULT_TRIGGERS.copy()
583586
# Look for component specific triggers,
584587
# e.g. variable declared as EventHandler types.
585588
for field in self.get_fields().values():
586-
if types._issubclass(field.outer_type_, EventHandler):
589+
if field.type_ is EventHandler:
587590
args_spec = None
588591
annotation = field.annotation
589592
if (metadata := getattr(annotation, "__metadata__", None)) is not None:
590593
args_spec = metadata[0]
591-
default_triggers[field.name] = args_spec or (no_args_event_spec)
592-
return default_triggers
594+
triggers[field.name] = args_spec or (no_args_event_spec)
595+
return triggers
593596

594597
def __repr__(self) -> str:
595598
"""Represent the component in React.
@@ -703,9 +706,11 @@ def _get_component_prop_names(cls) -> Set[str]:
703706
"""
704707
return {
705708
name
706-
for name, field in cls.get_fields().items()
709+
for name in cls.get_fields()
707710
if name in cls.get_props()
708-
and types._issubclass(field.outer_type_, Component)
711+
and types._issubclass(
712+
types.value_inside_optional(types.get_field_type(cls, name)), Component
713+
)
709714
}
710715

711716
def _get_components_in_props(self) -> Sequence[BaseComponent]:
@@ -729,7 +734,7 @@ def _get_components_in_props(self) -> Sequence[BaseComponent]:
729734
]
730735

731736
@classmethod
732-
def create(cls, *children, **props) -> Self:
737+
def create(cls: Type[T], *children, **props) -> T:
733738
"""Create the component.
734739
735740
Args:
@@ -774,7 +779,22 @@ def validate_children(children: tuple | list):
774779
for child in children
775780
]
776781

777-
return cls(children=children, **props)
782+
return cls._create(children, **props)
783+
784+
@classmethod
785+
def _create(cls: Type[T], children: list[Component], **props: Any) -> T:
786+
"""Create the component.
787+
788+
Args:
789+
children: The children of the component.
790+
**props: The props of the component.
791+
792+
Returns:
793+
The component.
794+
"""
795+
comp = cls.construct(id=props.get("id"), children=children)
796+
comp._post_init(children=children, **props)
797+
return comp
778798

779799
def add_style(self) -> dict[str, Any] | None:
780800
"""Add style to the component.
@@ -1659,7 +1679,7 @@ class CustomComponent(Component):
16591679
# The props of the component.
16601680
props: dict[str, Any] = {}
16611681

1662-
def __init__(self, **kwargs):
1682+
def _post_init(self, **kwargs):
16631683
"""Initialize the custom component.
16641684
16651685
Args:
@@ -1702,7 +1722,7 @@ def get_args_spec(key: str) -> types.ArgsSpec | Sequence[types.ArgsSpec]:
17021722
)
17031723
)
17041724

1705-
super().__init__(
1725+
super()._post_init(
17061726
event_triggers={
17071727
key: EventChain.create(
17081728
value=props[key],
@@ -1863,7 +1883,9 @@ def custom_component(
18631883
def wrapper(*children, **props) -> CustomComponent:
18641884
# Remove the children from the props.
18651885
props.pop("children", None)
1866-
return CustomComponent(component_fn=component_fn, children=children, **props)
1886+
return CustomComponent._create(
1887+
children=list(children), component_fn=component_fn, **props
1888+
)
18671889

18681890
return wrapper
18691891

reflex/config.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,12 @@
3636
from reflex.base import Base
3737
from reflex.utils import console
3838
from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError
39-
from reflex.utils.types import GenericType, is_union, value_inside_optional
39+
from reflex.utils.types import (
40+
GenericType,
41+
is_union,
42+
true_type_for_pydantic_field,
43+
value_inside_optional,
44+
)
4045

4146
try:
4247
from dotenv import load_dotenv # pyright: ignore [reportMissingImports]
@@ -943,7 +948,9 @@ def update_from_env(self) -> dict[str, Any]:
943948
# If the env var is set, override the config value.
944949
if env_var is not None:
945950
# Interpret the value.
946-
value = interpret_env_var_value(env_var, field.outer_type_, field.name)
951+
value = interpret_env_var_value(
952+
env_var, true_type_for_pydantic_field(field), field.name
953+
)
947954

948955
# Set the value.
949956
updated_values[key] = value

reflex/state.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@
8989
from reflex.utils.types import (
9090
_isinstance,
9191
get_origin,
92-
is_optional,
9392
is_union,
9493
override,
94+
true_type_for_pydantic_field,
9595
value_inside_optional,
9696
)
9797
from reflex.vars import VarData
@@ -272,7 +272,11 @@ def __call__(self, *args: Any) -> EventSpec:
272272
return super().__call__(*args)
273273

274274

275-
def _unwrap_field_type(type_: Type) -> Type:
275+
if TYPE_CHECKING:
276+
from pydantic.v1.fields import ModelField
277+
278+
279+
def _unwrap_field_type(type_: types.GenericType) -> Type:
276280
"""Unwrap rx.Field type annotations.
277281
278282
Args:
@@ -303,7 +307,7 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
303307
return dispatch(
304308
field_name=field_name,
305309
var_data=VarData.from_state(cls, f.name),
306-
result_var_type=_unwrap_field_type(f.outer_type_),
310+
result_var_type=_unwrap_field_type(true_type_for_pydantic_field(f)),
307311
)
308312

309313

@@ -1350,9 +1354,7 @@ def __setattr__(self, name: str, value: Any):
13501354

13511355
if name in fields:
13521356
field = fields[name]
1353-
field_type = _unwrap_field_type(field.outer_type_)
1354-
if field.allow_none and not is_optional(field_type):
1355-
field_type = field_type | None
1357+
field_type = _unwrap_field_type(true_type_for_pydantic_field(field))
13561358
if not _isinstance(value, field_type, nested=1, treat_var_as_type=False):
13571359
console.error(
13581360
f"Expected field '{type(self).__name__}.{name}' to receive type '{field_type}',"

reflex/utils/pyi_generator.py

+1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
"Literal",
6666
"Optional",
6767
"Union",
68+
"Annotated",
6869
}
6970

7071
# TODO: fix import ordering and unused imports with ruff later

0 commit comments

Comments
 (0)