diff --git a/mesa/visualization/solara_viz.py b/mesa/visualization/solara_viz.py index a180d290be2..6209f558f9e 100644 --- a/mesa/visualization/solara_viz.py +++ b/mesa/visualization/solara_viz.py @@ -25,6 +25,8 @@ import asyncio import inspect +import threading +import time from collections.abc import Callable from typing import TYPE_CHECKING, Literal @@ -57,6 +59,7 @@ def SolaraViz( simulator: Simulator | None = None, model_params=None, name: str | None = None, + use_threads: bool = False, ): """Solara visualization component. @@ -76,6 +79,8 @@ def SolaraViz( This controls the speed of the model's automatic stepping. Defaults to 100 ms. render_interval (int, optional): Controls how often plots are updated during a simulation, allowing users to skip intermediate steps and update graphs less frequently. + use_threads: Flag for indicating whether to utilize multi-threading for model execution. + When checked, the model will utilize multiple threads,adjust based on system capabilities. simulator: A simulator that controls the model (optional) model_params (dict, optional): Parameters for (re-)instantiating a model. Can include user-adjustable parameters and fixed parameters. Defaults to None. @@ -114,6 +119,7 @@ def SolaraViz( reactive_model_parameters = solara.use_reactive({}) reactive_play_interval = solara.use_reactive(play_interval) reactive_render_interval = solara.use_reactive(render_interval) + reactive_use_threads = solara.use_reactive(use_threads) with solara.AppBar(): solara.AppBarTitle(name if name else model.value.__class__.__name__) solara.lab.ThemeToggle() @@ -136,12 +142,25 @@ def SolaraViz( max=100, step=2, ) + if reactive_use_threads.value: + solara.Text("Increase play interval to avoid skipping plots") + + def set_reactive_use_threads(value): + reactive_use_threads.set(value) + + solara.Checkbox( + label="Use Threads", + value=reactive_use_threads, + on_value=set_reactive_use_threads, + ) + if not isinstance(simulator, Simulator): ModelController( model, model_parameters=reactive_model_parameters, play_interval=reactive_play_interval, render_interval=reactive_render_interval, + use_threads=reactive_use_threads, ) else: SimulatorController( @@ -150,6 +169,7 @@ def SolaraViz( model_parameters=reactive_model_parameters, play_interval=reactive_play_interval, render_interval=reactive_render_interval, + use_threads=reactive_use_threads, ) with solara.Card("Model Parameters"): ModelCreator( @@ -211,6 +231,7 @@ def ModelController( model_parameters: dict | solara.Reactive[dict] = None, play_interval: int | solara.Reactive[int] = 100, render_interval: int | solara.Reactive[int] = 1, + use_threads: bool | solara.Reactive[bool] = False, ): """Create controls for model execution (step, play, pause, reset). @@ -219,37 +240,70 @@ def ModelController( model_parameters: Reactive parameters for (re-)instantiating a model. play_interval: Interval for playing the model steps in milliseconds. render_interval: Controls how often the plots are updated during simulation steps.Higher value reduce update frequency. + use_threads: Flag for indicating whether to utilize multi-threading for model execution. """ playing = solara.use_reactive(False) running = solara.use_reactive(True) + if model_parameters is None: model_parameters = {} model_parameters = solara.use_reactive(model_parameters) - - async def step(): - while playing.value and running.value: - await asyncio.sleep(play_interval.value / 1000) - do_step() + visualization_pause_event = solara.use_memo(lambda: threading.Event(), []) + + def step(): + try: + while running.value and playing.value: + time.sleep(play_interval.value / 1000) + do_step() + if use_threads.value: + visualization_pause_event.set() + except Exception as e: + print(f"Error in step: {e}") + return + + def visualization_task(): + if use_threads.value: + try: + while playing.value and running.value: + visualization_pause_event.wait() + visualization_pause_event.clear() + force_update() + except Exception as e: + print(f"Error in visualization_task: {e}") solara.lab.use_task( - step, dependencies=[playing.value, running.value], prefer_threaded=False + step, dependencies=[playing.value, running.value], prefer_threaded=True + ) + + solara.use_thread( + visualization_task, + dependencies=[playing.value, running.value], ) @function_logger(__name__) def do_step(): """Advance the model by the number of steps specified by the render_interval slider.""" - for _ in range(render_interval.value): - model.value.step() + if playing.value: + for _ in range(render_interval.value): + model.value.step() + running.value = model.value.running + if not playing.value: + break + if not use_threads.value: + force_update() - running.value = model.value.running - - force_update() + else: + for _ in range(render_interval.value): + model.value.step() + running.value = model.value.running + force_update() @function_logger(__name__) def do_reset(): """Reset the model to its initial state.""" playing.value = False running.value = True + visualization_pause_event.clear() _mesa_logger.log( 10, f"creating new {model.value.__class__} instance with {model_parameters.value}", @@ -285,6 +339,7 @@ def SimulatorController( model_parameters: dict | solara.Reactive[dict] = None, play_interval: int | solara.Reactive[int] = 100, render_interval: int | solara.Reactive[int] = 1, + use_threads: bool | solara.Reactive[bool] = False, ): """Create controls for model execution (step, play, pause, reset). @@ -294,6 +349,7 @@ def SimulatorController( model_parameters: Reactive parameters for (re-)instantiating a model. play_interval: Interval for playing the model steps in milliseconds. render_interval: Controls how often the plots are updated during simulation steps.Higher values reduce update frequency. + use_threads: Flag for indicating whether to utilize multi-threading for model execution. Notes: The `step button` increments the step by the value specified in the `render_interval` slider. @@ -304,27 +360,66 @@ def SimulatorController( if model_parameters is None: model_parameters = {} model_parameters = solara.use_reactive(model_parameters) - - async def step(): - while playing.value and running.value: - await asyncio.sleep(play_interval.value / 1000) - do_step() + visualization_pause_event = solara.use_memo(lambda: threading.Event(), []) + pause_step_event = solara.use_memo(lambda: threading.Event(), []) + + def step(): + try: + while running.value and playing.value: + time.sleep(play_interval.value / 1000) + if use_threads.value: + pause_step_event.wait() + pause_step_event.clear() + do_step() + if use_threads.value: + visualization_pause_event.set() + except Exception as e: + print(f"Error in step: {e}") + + def visualization_task(): + if use_threads.value: + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + pause_step_event.set() + while playing.value and running.value: + visualization_pause_event.wait() + visualization_pause_event.clear() + force_update() + pause_step_event.set() + except Exception as e: + print(f"Error in visualization_task: {e}") + return solara.lab.use_task( step, dependencies=[playing.value, running.value], prefer_threaded=False ) + solara.lab.use_task(visualization_task, dependencies=[playing.value]) def do_step(): """Advance the model by the number of steps specified by the render_interval slider.""" - simulator.run_for(render_interval.value) - running.value = model.value.running - force_update() + if playing.value: + for _ in range(render_interval.value): + simulator.run_for(1) + running.value = model.value.running + if not playing.value: + break + if not use_threads.value: + force_update() + + else: + for _ in range(render_interval.value): + simulator.run_for(1) + running.value = model.value.running + force_update() def do_reset(): """Reset the model to its initial state.""" playing.value = False running.value = True simulator.reset() + visualization_pause_event.clear() + pause_step_event.clear() model.value = model.value = model.value.__class__( simulator=simulator, **model_parameters.value )