diff --git a/src/taipy/gui/_page.py b/src/taipy/gui/_page.py index f9b5a7d3..45478a37 100644 --- a/src/taipy/gui/_page.py +++ b/src/taipy/gui/_page.py @@ -34,9 +34,8 @@ def render(self, gui: Gui): raise RuntimeError(f"Can't render page {self._route}: no renderer found") with warnings.catch_warnings(record=True) as w: warnings.resetwarnings() - gui._set_locals_context(self._renderer._get_module_name()) - self._rendered_jsx = self._renderer.render(gui) - gui._reset_locals_context() + with gui._set_locals_context(self._renderer._get_module_name()): + self._rendered_jsx = self._renderer.render(gui) if w: s = "\033[1;31m\n" s += ( diff --git a/src/taipy/gui/gui.py b/src/taipy/gui/gui.py index 7b708d0b..488bdf73 100644 --- a/src/taipy/gui/gui.py +++ b/src/taipy/gui/gui.py @@ -510,23 +510,22 @@ def _manage_message(self, msg_type: _WsType, message: dict) -> None: res = self._bindings()._get_or_create_scope(message.get("payload", "")) client_id = res[0] if res[1] else None self.__set_client_id_in_context(client_id or message.get(Gui.__ARG_CLIENT_ID)) - self._set_locals_context(message.get("module_context") or None) - if msg_type == _WsType.UPDATE.value: - payload = message.get("payload", {}) - self.__front_end_update( - str(message.get("name")), - payload.get("value"), - message.get("propagate", True), - payload.get("relvar"), - payload.get("on_change"), - ) - elif msg_type == _WsType.ACTION.value: - self.__on_action(message.get("name"), message.get("payload")) - elif msg_type == _WsType.DATA_UPDATE.value: - self.__request_data_update(str(message.get("name")), message.get("payload")) - elif msg_type == _WsType.REQUEST_UPDATE.value: - self.__request_var_update(message.get("payload")) - self._reset_locals_context() + with self._set_locals_context(message.get("module_context") or None): + if msg_type == _WsType.UPDATE.value: + payload = message.get("payload", {}) + self.__front_end_update( + str(message.get("name")), + payload.get("value"), + message.get("propagate", True), + payload.get("relvar"), + payload.get("on_change"), + ) + elif msg_type == _WsType.ACTION.value: + self.__on_action(message.get("name"), message.get("payload")) + elif msg_type == _WsType.DATA_UPDATE.value: + self.__request_data_update(str(message.get("name")), message.get("payload")) + elif msg_type == _WsType.REQUEST_UPDATE.value: + self.__request_var_update(message.get("payload")) self.__send_ack(message.get("ack_id")) except Exception as e: # pragma: no cover _warn(f"Decoding Message has failed: {message}", e) @@ -1136,15 +1135,17 @@ def _call_function_with_state(self, user_function: t.Callable, args: t.List[t.An args = args[:argcount] return user_function(*args) + def _set_module_context(self, module_context: t.Optional[str]) -> t.ContextManager[None]: + return self._set_locals_context(module_context) if module_context is not None else contextlib.nullcontext() + def _call_user_callback( self, state_id: t.Optional[str], user_callback: t.Callable, args: t.List[t.Any], module_context: t.Optional[str] ) -> t.Any: try: with self.get_flask_app().app_context(): self.__set_client_id_in_context(state_id) - if module_context is not None: - self._set_locals_context(module_context) - return self._call_function_with_state(user_callback, args) + with self._set_module_context(module_context): + return self._call_function_with_state(user_callback, args) except Exception as e: # pragma: no cover if not self._call_on_exception(user_callback.__name__, e): _warn(f"invoke_callback(): Exception raised in '{user_callback.__name__}()'", e) @@ -1153,20 +1154,17 @@ def _call_user_callback( def _call_broadcast_callback( self, user_callback: t.Callable, args: t.List[t.Any], module_context: t.Optional[str] ) -> t.Any: - try: - with self.get_flask_app().app_context(): - # Use global scopes for broadcast callbacks - self.__set_client_id_in_context(_DataScopes._GLOBAL_ID) - if module_context is not None: - self._set_locals_context(module_context) + @contextlib.contextmanager + def _broadcast_callback() -> t.Iterator[None]: + try: setattr(g, Gui.__BRDCST_CALLBACK_G_ID, True) - callback_result = self._call_function_with_state(user_callback, args) + yield + finally: setattr(g, Gui.__BRDCST_CALLBACK_G_ID, False) - return callback_result - except Exception as e: - if not self._call_on_exception(user_callback.__name__, e): - _warn(f"invoke_callback(): Exception raised in '{user_callback.__name__}()':\n{e}") - return None + + with _broadcast_callback(): + # Use global scopes for broadcast callbacks + return self._call_user_callback(_DataScopes._GLOBAL_ID, user_callback, args, module_context) def _is_in_brdcst_callback(self): try: @@ -1331,11 +1329,8 @@ def _get_locals_context(self) -> str: current_context = self.__locals_context.get_context() return current_context if current_context is not None else self.__default_module_name - def _set_locals_context(self, context: t.Optional[str]) -> None: - self.__locals_context.set_locals_context(context) - - def _reset_locals_context(self) -> None: - self.__locals_context.reset_locals_context() + def _set_locals_context(self, context: t.Optional[str]) -> t.ContextManager[None]: + return self.__locals_context.set_locals_context(context) def _get_page_context(self, page_name: str) -> str | None: if page_name not in self._config.routes: diff --git a/src/taipy/gui/state.py b/src/taipy/gui/state.py index b83d3579..d6bb043d 100644 --- a/src/taipy/gui/state.py +++ b/src/taipy/gui/state.py @@ -11,10 +11,12 @@ import inspect import typing as t +from contextlib import nullcontext from operator import attrgetter from types import FrameType from flask import has_app_context +from flask.ctx import AppContext from .utils import _get_module_name_from_frame, _is_in_notebook from .utils._attributes import _attrsetter @@ -83,7 +85,7 @@ def change_variable(state): "get_gui", "refresh", "_set_context", - "_reset_context", + "_notebook_context", "_get_placeholder", "_set_placeholder", "_get_gui_attr", @@ -127,19 +129,9 @@ def __getattribute__(self, name: str) -> t.Any: raise AttributeError(f"Variable '{name}' is not available to be accessed in shared callback.") if name not in super().__getattribute__(State.__attrs[1]): raise AttributeError(f"Variable '{name}' is not defined.") - if not has_app_context() and _is_in_notebook(): - with gui.get_flask_app().app_context(): - # Code duplication is ugly but necessary due to frame resolution - set_context = self._set_context(gui) - encoded_name = gui._bind_var(name) - attr = getattr(gui._bindings(), encoded_name) - self._reset_context(gui, set_context) - return attr - set_context = self._set_context(gui) - encoded_name = gui._bind_var(name) - attr = getattr(gui._bindings(), encoded_name) - self._reset_context(gui, set_context) - return attr + with self._notebook_context(gui), self._set_context(gui): + encoded_name = gui._bind_var(name) + return getattr(gui._bindings(), encoded_name) def __setattr__(self, name: str, value: t.Any) -> None: gui: "Gui" = super().__getattribute__(State.__gui_attr) @@ -149,18 +141,9 @@ def __setattr__(self, name: str, value: t.Any) -> None: raise AttributeError(f"Variable '{name}' is not available to be accessed in shared callback.") if name not in super().__getattribute__(State.__attrs[1]): raise AttributeError(f"Variable '{name}' is not accessible.") - if not has_app_context() and _is_in_notebook(): - with gui.get_flask_app().app_context(): - # Code duplication is ugly but necessary due to frame resolution - set_context = self._set_context(gui) - encoded_name = gui._bind_var(name) - setattr(gui._bindings(), encoded_name, value) - self._reset_context(gui, set_context) - return - set_context = self._set_context(gui) - encoded_name = gui._bind_var(name) - setattr(gui._bindings(), encoded_name, value) - self._reset_context(gui, set_context) + with self._notebook_context(gui), self._set_context(gui): + encoded_name = gui._bind_var(name) + setattr(gui._bindings(), encoded_name, value) def __getitem__(self, key: str): context = key if key in super().__getattribute__(State.__attrs[2]) else None @@ -173,25 +156,21 @@ def __getitem__(self, key: str): self._set_placeholder(State.__placeholder_attrs[1], context) return self - def _set_context(self, gui: "Gui") -> bool: + def _set_context(self, gui: "Gui") -> t.ContextManager[None]: if (pl_ctx := self._get_placeholder(State.__placeholder_attrs[1])) is not None: self._set_placeholder(State.__placeholder_attrs[1], None) if pl_ctx != gui._get_locals_context(): - gui._set_locals_context(pl_ctx) - return True + return gui._set_locals_context(pl_ctx) if len(inspect.stack()) > 1: ctx = _get_module_name_from_frame(t.cast(FrameType, t.cast(FrameType, inspect.stack()[2].frame))) current_context = gui._get_locals_context() # ignore context if the current one starts with the new one (to resolve for class modules) if ctx != current_context and not current_context.startswith(str(ctx)): - gui._set_locals_context(ctx) - return True - return False + return gui._set_locals_context(ctx) + return nullcontext() - def _reset_context(self, gui: "Gui", set_context: bool) -> None: - if not set_context: - return - gui._reset_locals_context() + def _notebook_context(self, gui: "Gui"): + return gui.get_flask_app().app_context() if not has_app_context() and _is_in_notebook() else nullcontext() def _get_placeholder(self, name: str): if name in State.__placeholder_attrs: @@ -255,10 +234,9 @@ def broadcast(self, name: str, value: t.Any): value (Any): The new variable value. """ gui: "Gui" = super().__getattribute__(State.__gui_attr) - set_context = self._set_context(gui) - encoded_name = gui._bind_var(name) - gui._broadcast_all_clients(encoded_name, value) - self._reset_context(gui, set_context) + with self._set_context(gui): + encoded_name = gui._bind_var(name) + gui._broadcast_all_clients(encoded_name, value) def __enter__(self): super().__getattribute__(State.__attrs[0]).__enter__() diff --git a/src/taipy/gui/utils/_locals_context.py b/src/taipy/gui/utils/_locals_context.py index 4f3d6887..2314e744 100644 --- a/src/taipy/gui/utils/_locals_context.py +++ b/src/taipy/gui/utils/_locals_context.py @@ -11,6 +11,7 @@ from __future__ import annotations +import contextlib import typing as t from flask import g @@ -45,11 +46,20 @@ def add(self, context: t.Optional[str], locals_dict: t.Optional[t.Dict[str, t.An if context is not None and locals_dict is not None and context not in self._locals_map: self._locals_map[context] = locals_dict - def set_locals_context(self, context: t.Optional[str]) -> None: - if context in self._locals_map: + @contextlib.contextmanager + def set_locals_context(self, context: t.Optional[str]) -> t.Iterator[None]: + try: + if context in self._locals_map: + if hasattr(g, _LocalsContext.__ctx_g_name): + self._lc_stack.append(getattr(g, _LocalsContext.__ctx_g_name)) + setattr(g, _LocalsContext.__ctx_g_name, context) + yield + finally: if hasattr(g, _LocalsContext.__ctx_g_name): - self._lc_stack.append(getattr(g, _LocalsContext.__ctx_g_name)) - setattr(g, _LocalsContext.__ctx_g_name, context) + if len(self._lc_stack) > 0: + setattr(g, _LocalsContext.__ctx_g_name, self._lc_stack.pop()) + else: + delattr(g, _LocalsContext.__ctx_g_name) def get_locals(self) -> t.Dict[str, t.Any]: return self.get_default() if (context := self.get_context()) is None else self._locals_map[context] @@ -64,10 +74,3 @@ def _get_locals_bind_from_context(self, context: t.Optional[str]): if context is None: context = self.__default_module return self._locals_map[context] - - def reset_locals_context(self) -> None: - if hasattr(g, _LocalsContext.__ctx_g_name): - if len(self._lc_stack) > 0: - setattr(g, _LocalsContext.__ctx_g_name, self._lc_stack.pop()) - else: - delattr(g, _LocalsContext.__ctx_g_name) diff --git a/src/taipy/gui/utils/_variable_directory.py b/src/taipy/gui/utils/_variable_directory.py index 98664499..c80ba0c9 100644 --- a/src/taipy/gui/utils/_variable_directory.py +++ b/src/taipy/gui/utils/_variable_directory.py @@ -46,45 +46,42 @@ def pre_process_module_import_all(self) -> None: continue if module not in self._locals_context._locals_map.keys(): continue - self._locals_context.set_locals_context(module) - additional_var_list.extend( - (v, v, module) for v in self._locals_context.get_locals().keys() if not v.startswith("_") - ) - self._locals_context.reset_locals_context() + with self._locals_context.set_locals_context(module): + additional_var_list.extend( + (v, v, module) for v in self._locals_context.get_locals().keys() if not v.startswith("_") + ) imported_dir.extend(additional_var_list) def process_imported_var(self) -> None: self.pre_process_module_import_all() default_imported_dir = self._imported_var_dir[self._default_module] - self._locals_context.set_locals_context(self._default_module) - for name, asname, module in default_imported_dir: - if name == "*" and asname == "*": - continue - imported_module_name = _get_module_name_from_imported_var( - name, self._locals_context.get_locals().get(asname, None), module - ) - temp_var_name = self.add_var(asname, self._default_module) - self.add_var(name, imported_module_name, temp_var_name) - self._locals_context.reset_locals_context() - - for k, v in self._imported_var_dir.items(): - self._locals_context.set_locals_context(k) - for name, asname, module in v: + with self._locals_context.set_locals_context(self._default_module): + for name, asname, module in default_imported_dir: if name == "*" and asname == "*": continue imported_module_name = _get_module_name_from_imported_var( name, self._locals_context.get_locals().get(asname, None), module ) - var_name = self.get_var(name, imported_module_name) - var_asname = self.get_var(asname, k) - if var_name is None and var_asname is None: - temp_var_name = self.add_var(asname, k) - self.add_var(name, imported_module_name, temp_var_name) - elif var_name is not None: - self.add_var(asname, k, var_name) - else: - self.add_var(name, imported_module_name, var_asname) - self._locals_context.reset_locals_context() + temp_var_name = self.add_var(asname, self._default_module) + self.add_var(name, imported_module_name, temp_var_name) + + for k, v in self._imported_var_dir.items(): + with self._locals_context.set_locals_context(k): + for name, asname, module in v: + if name == "*" and asname == "*": + continue + imported_module_name = _get_module_name_from_imported_var( + name, self._locals_context.get_locals().get(asname, None), module + ) + var_name = self.get_var(name, imported_module_name) + var_asname = self.get_var(asname, k) + if var_name is None and var_asname is None: + temp_var_name = self.add_var(asname, k) + self.add_var(name, imported_module_name, temp_var_name) + elif var_name is not None: + self.add_var(asname, k, var_name) + else: + self.add_var(name, imported_module_name, var_asname) def add_var(self, name: str, module: t.Optional[str], var_name: t.Optional[str] = None) -> str: if module is None: diff --git a/tests/taipy/gui/gui_specific/test_locals_context.py b/tests/taipy/gui/gui_specific/test_locals_context.py index 2f6a92de..85980ef0 100644 --- a/tests/taipy/gui/gui_specific/test_locals_context.py +++ b/tests/taipy/gui/gui_specific/test_locals_context.py @@ -28,10 +28,9 @@ def test_locals_context(gui: Gui): lc.add("test", temp_locals) assert lc.get_context() is None assert lc.get_locals() == current_locals - lc.set_locals_context("test") - assert lc.get_context() == "test" - assert lc.get_locals() == temp_locals - lc.reset_locals_context() + with lc.set_locals_context("test"): + assert lc.get_context() == "test" + assert lc.get_locals() == temp_locals assert lc.get_context() is None assert lc.get_locals() == current_locals assert lc.is_default() is True