Skip to content

Commit

Permalink
improve var base typing (#4718)
Browse files Browse the repository at this point in the history
* improve var base typing

* fix pyi

* dang it darglint

* drain _process in tests

* fixes #4576

* dang it darglint
  • Loading branch information
adhami3310 authored Jan 31, 2025
1 parent 12a42b6 commit 8663dbc
Show file tree
Hide file tree
Showing 21 changed files with 280 additions and 265 deletions.
3 changes: 2 additions & 1 deletion reflex/components/base/error_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
from reflex.state import FrontendEventExceptionState
from reflex.vars.base import Var
from reflex.vars.function import ArgsFunctionOperation
from reflex.vars.object import ObjectVar


def on_error_spec(
error: Var[Dict[str, str]], info: Var[Dict[str, str]]
error: ObjectVar[Dict[str, str]], info: ObjectVar[Dict[str, str]]
) -> Tuple[Var[str], Var[str]]:
"""The spec for the on_error event handler.
Expand Down
3 changes: 2 additions & 1 deletion reflex/components/base/error_boundary.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ from reflex.components.component import Component
from reflex.event import BASE_STATE, EventType
from reflex.style import Style
from reflex.vars.base import Var
from reflex.vars.object import ObjectVar

def on_error_spec(
error: Var[Dict[str, str]], info: Var[Dict[str, str]]
error: ObjectVar[Dict[str, str]], info: ObjectVar[Dict[str, str]]
) -> Tuple[Var[str], Var[str]]: ...

class ErrorBoundary(Component):
Expand Down
1 change: 1 addition & 0 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -2457,6 +2457,7 @@ def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) ->
@dataclasses.dataclass(
eq=False,
frozen=True,
slots=True,
)
class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
"""A Var that represents a Component."""
Expand Down
12 changes: 10 additions & 2 deletions reflex/components/core/foreach.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from reflex.components.tags import IterTag
from reflex.constants import MemoizationMode
from reflex.state import ComponentState
from reflex.utils.exceptions import UntypedVarError
from reflex.vars.base import LiteralVar, Var


Expand Down Expand Up @@ -51,6 +52,7 @@ def create(
Raises:
ForeachVarError: If the iterable is of type Any.
TypeError: If the render function is a ComponentState.
UntypedVarError: If the iterable is of type Any without a type annotation.
"""
iterable = LiteralVar.create(iterable)
if iterable._var_type == Any:
Expand All @@ -72,8 +74,14 @@ def create(
iterable=iterable,
render_fn=render_fn,
)
# Keep a ref to a rendered component to determine correct imports/hooks/styles.
component.children = [component._render().render_component()]
try:
# Keep a ref to a rendered component to determine correct imports/hooks/styles.
component.children = [component._render().render_component()]
except UntypedVarError as e:
raise UntypedVarError(
f"Could not foreach over var `{iterable!s}` without a type annotation. "
"See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
) from e
return component

def _render(self) -> IterTag:
Expand Down
3 changes: 2 additions & 1 deletion reflex/components/datadisplay/dataeditor.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,8 @@ def create(cls, *children, **props) -> Component:
raise ValueError(
"DataEditor data must be an ArrayVar if rows is not provided."
)
props["rows"] = data.length() if isinstance(data, Var) else len(data)

props["rows"] = data.length() if isinstance(data, ArrayVar) else len(data)

if not isinstance(columns, Var) and len(columns):
if types.is_dataframe(type(data)) or (
Expand Down
12 changes: 8 additions & 4 deletions reflex/components/datadisplay/shiki_code_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,18 +621,22 @@ def add_imports(self) -> dict[str, list[str]]:
Returns:
Imports for the component.
Raises:
ValueError: If the transformers are not of type LiteralVar.
"""
imports = defaultdict(list)
if not isinstance(self.transformers, LiteralVar):
raise ValueError(
f"transformers should be a LiteralVar type. Got {type(self.transformers)} instead."
)
for transformer in self.transformers._var_value:
if isinstance(transformer, ShikiBaseTransformers):
imports[transformer.library].extend(
[ImportVar(tag=str(fn)) for fn in transformer.fns]
)
(
if transformer.library not in self.lib_dependencies:
self.lib_dependencies.append(transformer.library)
if transformer.library not in self.lib_dependencies
else None
)
return imports

@classmethod
Expand Down
25 changes: 17 additions & 8 deletions reflex/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import dataclasses
import inspect
import sys
import types
import urllib.parse
from base64 import b64encode
Expand Down Expand Up @@ -541,7 +540,7 @@ class JavasciptKeyboardEvent:
shiftKey: bool = False # noqa: N815


def input_event(e: Var[JavascriptInputEvent]) -> Tuple[Var[str]]:
def input_event(e: ObjectVar[JavascriptInputEvent]) -> Tuple[Var[str]]:
"""Get the value from an input event.
Args:
Expand All @@ -562,7 +561,9 @@ class KeyInputInfo(TypedDict):
shift_key: bool


def key_event(e: Var[JavasciptKeyboardEvent]) -> Tuple[Var[str], Var[KeyInputInfo]]:
def key_event(
e: ObjectVar[JavasciptKeyboardEvent],
) -> Tuple[Var[str], Var[KeyInputInfo]]:
"""Get the key from a keyboard event.
Args:
Expand All @@ -572,15 +573,15 @@ def key_event(e: Var[JavasciptKeyboardEvent]) -> Tuple[Var[str], Var[KeyInputInf
The key from the keyboard event.
"""
return (
e.key,
e.key.to(str),
Var.create(
{
"alt_key": e.altKey,
"ctrl_key": e.ctrlKey,
"meta_key": e.metaKey,
"shift_key": e.shiftKey,
},
),
).to(KeyInputInfo),
)


Expand Down Expand Up @@ -1354,7 +1355,7 @@ def unwrap_var_annotation(annotation: GenericType):
Returns:
The unwrapped annotation.
"""
if get_origin(annotation) is Var and (args := get_args(annotation)):
if get_origin(annotation) in (Var, ObjectVar) and (args := get_args(annotation)):
return args[0]
return annotation

Expand Down Expand Up @@ -1620,7 +1621,7 @@ class EventVar(ObjectVar, python_types=EventSpec):
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class LiteralEventVar(VarOperationCall, LiteralVar, EventVar):
"""A literal event var."""
Expand Down Expand Up @@ -1681,7 +1682,7 @@ class EventChainVar(BuilderFunctionVar, python_types=EventChain):
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
# Note: LiteralVar is second in the inheritance list allowing it act like a
# CachedVarOperation (ArgsFunctionOperation) and get the _js_expr from the
Expand Down Expand Up @@ -1713,6 +1714,9 @@ def create(
Returns:
The created LiteralEventChainVar instance.
Raises:
ValueError: If the invocation is not a FunctionVar.
"""
arg_spec = (
value.args_spec[0]
Expand Down Expand Up @@ -1740,6 +1744,11 @@ def create(
else:
invocation = value.invocation

if invocation is not None and not isinstance(invocation, FunctionVar):
raise ValueError(
f"EventChain invocation must be a FunctionVar, got {invocation!s} of type {invocation._var_type!s}."
)

return cls(
_js_expr="",
_var_type=EventChain,
Expand Down
3 changes: 1 addition & 2 deletions reflex/experimental/client_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import dataclasses
import re
import sys
from typing import Any, Callable, Union

from reflex import constants
Expand Down Expand Up @@ -49,7 +48,7 @@ def _client_state_ref_dict(var_name: str) -> str:
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class ClientStateVar(Var):
"""A Var that exists on the client via useState."""
Expand Down
6 changes: 4 additions & 2 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1637,9 +1637,11 @@ async def get_var_value(self, var: Var[VAR_TYPE]) -> VAR_TYPE:
if not isinstance(var, Var):
return var

unset = object()

# Fast case: this is a literal var and the value is known.
if hasattr(var, "_var_value"):
return var._var_value
if (var_value := getattr(var, "_var_value", unset)) is not unset:
return var_value # pyright: ignore [reportReturnType]

var_data = var._get_all_var_data()
if var_data is None or not var_data.state:
Expand Down
4 changes: 4 additions & 0 deletions reflex/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ class VarAttributeError(ReflexError, AttributeError):
"""Custom AttributeError for var related errors."""


class UntypedVarError(ReflexError, TypeError):
"""Custom TypeError for untyped var errors."""


class UntypedComputedVarError(ReflexError, TypeError):
"""Custom TypeError for untyped computed var errors."""

Expand Down
Loading

0 comments on commit 8663dbc

Please sign in to comment.