Skip to content

Commit

Permalink
reset module context after invoking callback + use context manager (A…
Browse files Browse the repository at this point in the history
  • Loading branch information
dinhlongviolin1 authored Oct 20, 2023
1 parent a20dcdd commit cd1ed66
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 124 deletions.
5 changes: 2 additions & 3 deletions src/taipy/gui/_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@ def render(self, gui: Gui):
raise RuntimeError(f"Can't render page {self._route}: no renderer found")
with warnings.catch_warnings(record=True) as w:
warnings.resetwarnings()
gui._set_locals_context(self._renderer._get_module_name())
self._rendered_jsx = self._renderer.render(gui)
gui._reset_locals_context()
with gui._set_locals_context(self._renderer._get_module_name()):
self._rendered_jsx = self._renderer.render(gui)
if w:
s = "\033[1;31m\n"
s += (
Expand Down
69 changes: 32 additions & 37 deletions src/taipy/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,23 +510,22 @@ def _manage_message(self, msg_type: _WsType, message: dict) -> None:
res = self._bindings()._get_or_create_scope(message.get("payload", ""))
client_id = res[0] if res[1] else None
self.__set_client_id_in_context(client_id or message.get(Gui.__ARG_CLIENT_ID))
self._set_locals_context(message.get("module_context") or None)
if msg_type == _WsType.UPDATE.value:
payload = message.get("payload", {})
self.__front_end_update(
str(message.get("name")),
payload.get("value"),
message.get("propagate", True),
payload.get("relvar"),
payload.get("on_change"),
)
elif msg_type == _WsType.ACTION.value:
self.__on_action(message.get("name"), message.get("payload"))
elif msg_type == _WsType.DATA_UPDATE.value:
self.__request_data_update(str(message.get("name")), message.get("payload"))
elif msg_type == _WsType.REQUEST_UPDATE.value:
self.__request_var_update(message.get("payload"))
self._reset_locals_context()
with self._set_locals_context(message.get("module_context") or None):
if msg_type == _WsType.UPDATE.value:
payload = message.get("payload", {})
self.__front_end_update(
str(message.get("name")),
payload.get("value"),
message.get("propagate", True),
payload.get("relvar"),
payload.get("on_change"),
)
elif msg_type == _WsType.ACTION.value:
self.__on_action(message.get("name"), message.get("payload"))
elif msg_type == _WsType.DATA_UPDATE.value:
self.__request_data_update(str(message.get("name")), message.get("payload"))
elif msg_type == _WsType.REQUEST_UPDATE.value:
self.__request_var_update(message.get("payload"))
self.__send_ack(message.get("ack_id"))
except Exception as e: # pragma: no cover
_warn(f"Decoding Message has failed: {message}", e)
Expand Down Expand Up @@ -1136,15 +1135,17 @@ def _call_function_with_state(self, user_function: t.Callable, args: t.List[t.An
args = args[:argcount]
return user_function(*args)

def _set_module_context(self, module_context: t.Optional[str]) -> t.ContextManager[None]:
return self._set_locals_context(module_context) if module_context is not None else contextlib.nullcontext()

def _call_user_callback(
self, state_id: t.Optional[str], user_callback: t.Callable, args: t.List[t.Any], module_context: t.Optional[str]
) -> t.Any:
try:
with self.get_flask_app().app_context():
self.__set_client_id_in_context(state_id)
if module_context is not None:
self._set_locals_context(module_context)
return self._call_function_with_state(user_callback, args)
with self._set_module_context(module_context):
return self._call_function_with_state(user_callback, args)
except Exception as e: # pragma: no cover
if not self._call_on_exception(user_callback.__name__, e):
_warn(f"invoke_callback(): Exception raised in '{user_callback.__name__}()'", e)
Expand All @@ -1153,20 +1154,17 @@ def _call_user_callback(
def _call_broadcast_callback(
self, user_callback: t.Callable, args: t.List[t.Any], module_context: t.Optional[str]
) -> t.Any:
try:
with self.get_flask_app().app_context():
# Use global scopes for broadcast callbacks
self.__set_client_id_in_context(_DataScopes._GLOBAL_ID)
if module_context is not None:
self._set_locals_context(module_context)
@contextlib.contextmanager
def _broadcast_callback() -> t.Iterator[None]:
try:
setattr(g, Gui.__BRDCST_CALLBACK_G_ID, True)
callback_result = self._call_function_with_state(user_callback, args)
yield
finally:
setattr(g, Gui.__BRDCST_CALLBACK_G_ID, False)
return callback_result
except Exception as e:
if not self._call_on_exception(user_callback.__name__, e):
_warn(f"invoke_callback(): Exception raised in '{user_callback.__name__}()':\n{e}")
return None

with _broadcast_callback():
# Use global scopes for broadcast callbacks
return self._call_user_callback(_DataScopes._GLOBAL_ID, user_callback, args, module_context)

def _is_in_brdcst_callback(self):
try:
Expand Down Expand Up @@ -1331,11 +1329,8 @@ def _get_locals_context(self) -> str:
current_context = self.__locals_context.get_context()
return current_context if current_context is not None else self.__default_module_name

def _set_locals_context(self, context: t.Optional[str]) -> None:
self.__locals_context.set_locals_context(context)

def _reset_locals_context(self) -> None:
self.__locals_context.reset_locals_context()
def _set_locals_context(self, context: t.Optional[str]) -> t.ContextManager[None]:
return self.__locals_context.set_locals_context(context)

def _get_page_context(self, page_name: str) -> str | None:
if page_name not in self._config.routes:
Expand Down
58 changes: 18 additions & 40 deletions src/taipy/gui/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@

import inspect
import typing as t
from contextlib import nullcontext
from operator import attrgetter
from types import FrameType

from flask import has_app_context
from flask.ctx import AppContext

from .utils import _get_module_name_from_frame, _is_in_notebook
from .utils._attributes import _attrsetter
Expand Down Expand Up @@ -83,7 +85,7 @@ def change_variable(state):
"get_gui",
"refresh",
"_set_context",
"_reset_context",
"_notebook_context",
"_get_placeholder",
"_set_placeholder",
"_get_gui_attr",
Expand Down Expand Up @@ -127,19 +129,9 @@ def __getattribute__(self, name: str) -> t.Any:
raise AttributeError(f"Variable '{name}' is not available to be accessed in shared callback.")
if name not in super().__getattribute__(State.__attrs[1]):
raise AttributeError(f"Variable '{name}' is not defined.")
if not has_app_context() and _is_in_notebook():
with gui.get_flask_app().app_context():
# Code duplication is ugly but necessary due to frame resolution
set_context = self._set_context(gui)
encoded_name = gui._bind_var(name)
attr = getattr(gui._bindings(), encoded_name)
self._reset_context(gui, set_context)
return attr
set_context = self._set_context(gui)
encoded_name = gui._bind_var(name)
attr = getattr(gui._bindings(), encoded_name)
self._reset_context(gui, set_context)
return attr
with self._notebook_context(gui), self._set_context(gui):
encoded_name = gui._bind_var(name)
return getattr(gui._bindings(), encoded_name)

def __setattr__(self, name: str, value: t.Any) -> None:
gui: "Gui" = super().__getattribute__(State.__gui_attr)
Expand All @@ -149,18 +141,9 @@ def __setattr__(self, name: str, value: t.Any) -> None:
raise AttributeError(f"Variable '{name}' is not available to be accessed in shared callback.")
if name not in super().__getattribute__(State.__attrs[1]):
raise AttributeError(f"Variable '{name}' is not accessible.")
if not has_app_context() and _is_in_notebook():
with gui.get_flask_app().app_context():
# Code duplication is ugly but necessary due to frame resolution
set_context = self._set_context(gui)
encoded_name = gui._bind_var(name)
setattr(gui._bindings(), encoded_name, value)
self._reset_context(gui, set_context)
return
set_context = self._set_context(gui)
encoded_name = gui._bind_var(name)
setattr(gui._bindings(), encoded_name, value)
self._reset_context(gui, set_context)
with self._notebook_context(gui), self._set_context(gui):
encoded_name = gui._bind_var(name)
setattr(gui._bindings(), encoded_name, value)

def __getitem__(self, key: str):
context = key if key in super().__getattribute__(State.__attrs[2]) else None
Expand All @@ -173,25 +156,21 @@ def __getitem__(self, key: str):
self._set_placeholder(State.__placeholder_attrs[1], context)
return self

def _set_context(self, gui: "Gui") -> bool:
def _set_context(self, gui: "Gui") -> t.ContextManager[None]:
if (pl_ctx := self._get_placeholder(State.__placeholder_attrs[1])) is not None:
self._set_placeholder(State.__placeholder_attrs[1], None)
if pl_ctx != gui._get_locals_context():
gui._set_locals_context(pl_ctx)
return True
return gui._set_locals_context(pl_ctx)
if len(inspect.stack()) > 1:
ctx = _get_module_name_from_frame(t.cast(FrameType, t.cast(FrameType, inspect.stack()[2].frame)))
current_context = gui._get_locals_context()
# ignore context if the current one starts with the new one (to resolve for class modules)
if ctx != current_context and not current_context.startswith(str(ctx)):
gui._set_locals_context(ctx)
return True
return False
return gui._set_locals_context(ctx)
return nullcontext()

def _reset_context(self, gui: "Gui", set_context: bool) -> None:
if not set_context:
return
gui._reset_locals_context()
def _notebook_context(self, gui: "Gui"):
return gui.get_flask_app().app_context() if not has_app_context() and _is_in_notebook() else nullcontext()

def _get_placeholder(self, name: str):
if name in State.__placeholder_attrs:
Expand Down Expand Up @@ -255,10 +234,9 @@ def broadcast(self, name: str, value: t.Any):
value (Any): The new variable value.
"""
gui: "Gui" = super().__getattribute__(State.__gui_attr)
set_context = self._set_context(gui)
encoded_name = gui._bind_var(name)
gui._broadcast_all_clients(encoded_name, value)
self._reset_context(gui, set_context)
with self._set_context(gui):
encoded_name = gui._bind_var(name)
gui._broadcast_all_clients(encoded_name, value)

def __enter__(self):
super().__getattribute__(State.__attrs[0]).__enter__()
Expand Down
25 changes: 14 additions & 11 deletions src/taipy/gui/utils/_locals_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from __future__ import annotations

import contextlib
import typing as t

from flask import g
Expand Down Expand Up @@ -45,11 +46,20 @@ def add(self, context: t.Optional[str], locals_dict: t.Optional[t.Dict[str, t.An
if context is not None and locals_dict is not None and context not in self._locals_map:
self._locals_map[context] = locals_dict

def set_locals_context(self, context: t.Optional[str]) -> None:
if context in self._locals_map:
@contextlib.contextmanager
def set_locals_context(self, context: t.Optional[str]) -> t.Iterator[None]:
try:
if context in self._locals_map:
if hasattr(g, _LocalsContext.__ctx_g_name):
self._lc_stack.append(getattr(g, _LocalsContext.__ctx_g_name))
setattr(g, _LocalsContext.__ctx_g_name, context)
yield
finally:
if hasattr(g, _LocalsContext.__ctx_g_name):
self._lc_stack.append(getattr(g, _LocalsContext.__ctx_g_name))
setattr(g, _LocalsContext.__ctx_g_name, context)
if len(self._lc_stack) > 0:
setattr(g, _LocalsContext.__ctx_g_name, self._lc_stack.pop())
else:
delattr(g, _LocalsContext.__ctx_g_name)

def get_locals(self) -> t.Dict[str, t.Any]:
return self.get_default() if (context := self.get_context()) is None else self._locals_map[context]
Expand All @@ -64,10 +74,3 @@ def _get_locals_bind_from_context(self, context: t.Optional[str]):
if context is None:
context = self.__default_module
return self._locals_map[context]

def reset_locals_context(self) -> None:
if hasattr(g, _LocalsContext.__ctx_g_name):
if len(self._lc_stack) > 0:
setattr(g, _LocalsContext.__ctx_g_name, self._lc_stack.pop())
else:
delattr(g, _LocalsContext.__ctx_g_name)
55 changes: 26 additions & 29 deletions src/taipy/gui/utils/_variable_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,45 +46,42 @@ def pre_process_module_import_all(self) -> None:
continue
if module not in self._locals_context._locals_map.keys():
continue
self._locals_context.set_locals_context(module)
additional_var_list.extend(
(v, v, module) for v in self._locals_context.get_locals().keys() if not v.startswith("_")
)
self._locals_context.reset_locals_context()
with self._locals_context.set_locals_context(module):
additional_var_list.extend(
(v, v, module) for v in self._locals_context.get_locals().keys() if not v.startswith("_")
)
imported_dir.extend(additional_var_list)

def process_imported_var(self) -> None:
self.pre_process_module_import_all()
default_imported_dir = self._imported_var_dir[self._default_module]
self._locals_context.set_locals_context(self._default_module)
for name, asname, module in default_imported_dir:
if name == "*" and asname == "*":
continue
imported_module_name = _get_module_name_from_imported_var(
name, self._locals_context.get_locals().get(asname, None), module
)
temp_var_name = self.add_var(asname, self._default_module)
self.add_var(name, imported_module_name, temp_var_name)
self._locals_context.reset_locals_context()

for k, v in self._imported_var_dir.items():
self._locals_context.set_locals_context(k)
for name, asname, module in v:
with self._locals_context.set_locals_context(self._default_module):
for name, asname, module in default_imported_dir:
if name == "*" and asname == "*":
continue
imported_module_name = _get_module_name_from_imported_var(
name, self._locals_context.get_locals().get(asname, None), module
)
var_name = self.get_var(name, imported_module_name)
var_asname = self.get_var(asname, k)
if var_name is None and var_asname is None:
temp_var_name = self.add_var(asname, k)
self.add_var(name, imported_module_name, temp_var_name)
elif var_name is not None:
self.add_var(asname, k, var_name)
else:
self.add_var(name, imported_module_name, var_asname)
self._locals_context.reset_locals_context()
temp_var_name = self.add_var(asname, self._default_module)
self.add_var(name, imported_module_name, temp_var_name)

for k, v in self._imported_var_dir.items():
with self._locals_context.set_locals_context(k):
for name, asname, module in v:
if name == "*" and asname == "*":
continue
imported_module_name = _get_module_name_from_imported_var(
name, self._locals_context.get_locals().get(asname, None), module
)
var_name = self.get_var(name, imported_module_name)
var_asname = self.get_var(asname, k)
if var_name is None and var_asname is None:
temp_var_name = self.add_var(asname, k)
self.add_var(name, imported_module_name, temp_var_name)
elif var_name is not None:
self.add_var(asname, k, var_name)
else:
self.add_var(name, imported_module_name, var_asname)

def add_var(self, name: str, module: t.Optional[str], var_name: t.Optional[str] = None) -> str:
if module is None:
Expand Down
7 changes: 3 additions & 4 deletions tests/taipy/gui/gui_specific/test_locals_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@ def test_locals_context(gui: Gui):
lc.add("test", temp_locals)
assert lc.get_context() is None
assert lc.get_locals() == current_locals
lc.set_locals_context("test")
assert lc.get_context() == "test"
assert lc.get_locals() == temp_locals
lc.reset_locals_context()
with lc.set_locals_context("test"):
assert lc.get_context() == "test"
assert lc.get_locals() == temp_locals
assert lc.get_context() is None
assert lc.get_locals() == current_locals
assert lc.is_default() is True
Expand Down

0 comments on commit cd1ed66

Please sign in to comment.