diff --git a/mesa/visualization/components/altair_components.py b/mesa/visualization/components/altair_components.py index b610e46f0d0..f09167dff76 100644 --- a/mesa/visualization/components/altair_components.py +++ b/mesa/visualization/components/altair_components.py @@ -1,13 +1,10 @@ """Altair based solara components for visualization mesa spaces.""" -import contextlib import warnings +import altair as alt import solara -with contextlib.suppress(ImportError): - import altair as alt - from mesa.experimental.cell_space import DiscreteSpace, Grid from mesa.space import ContinuousSpace, _Grid from mesa.visualization.utils import update_counter @@ -30,7 +27,7 @@ def make_altair_space( Args: agent_portrayal: Function to portray agents. propertylayer_portrayal: not yet implemented - post_process :not yet implemented + post_process :A user specified callable that will be called with the Chart instance from Altair. Allows for fine tuning plots (e.g., control ticks) space_drawing_kwargs : not yet implemented ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", @@ -46,13 +43,15 @@ def agent_portrayal(a): return {"id": a.unique_id} def MakeSpaceAltair(model): - return SpaceAltair(model, agent_portrayal) + return SpaceAltair(model, agent_portrayal, post_process=post_process) return MakeSpaceAltair @solara.component -def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None): +def SpaceAltair( + model, agent_portrayal, dependencies: list[any] | None = None, post_process=None +): """Create an Altair-based space visualization component. Returns: @@ -65,6 +64,9 @@ def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None): space = model.space chart = _draw_grid(space, agent_portrayal) + # Apply post-processing if provided + if post_process is not None: + chart = post_process(chart) solara.FigureAltair(chart) @@ -159,7 +161,7 @@ def _draw_grid(space, agent_portrayal): # no y-axis label "y": alt.Y("y", axis=None, type=x_y_type), "tooltip": [ - alt.Tooltip(key, type=alt.utils.infer_vegalite_type([value])) + alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value])) for key, value in all_agent_data[0].items() if key not in invalid_tooltips ], diff --git a/mesa/visualization/solara_viz.py b/mesa/visualization/solara_viz.py index f5fde84b1a3..23d603aa820 100644 --- a/mesa/visualization/solara_viz.py +++ b/mesa/visualization/solara_viz.py @@ -97,7 +97,11 @@ def SolaraViz( reduce update frequency,resulting in faster execution. """ if components == "default": - components = [components_altair.make_altair_space()] + components = [ + components_altair.make_altair_space( + agent_portrayal=None, propertylayer_portrayal=None, post_process=None + ) + ] if model_params is None: model_params = {} diff --git a/pyproject.toml b/pyproject.toml index be0dc4a3139..b9739c917da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ network = [ viz = [ "matplotlib", "solara", + "altair", ] # Dev and CI stuff dev = [ diff --git a/tests/test_solara_viz.py b/tests/test_solara_viz.py index 3b8d82fb7bc..6e25502e0b7 100644 --- a/tests/test_solara_viz.py +++ b/tests/test_solara_viz.py @@ -9,6 +9,8 @@ import mesa import mesa.visualization.components.altair_components import mesa.visualization.components.matplotlib_components +from mesa.space import MultiGrid +from mesa.visualization.components.altair_components import make_altair_space from mesa.visualization.components.matplotlib_components import make_mpl_space_component from mesa.visualization.solara_viz import ( Slider, @@ -101,17 +103,22 @@ def test_call_space_drawer(mocker): # noqa: D103 mesa.visualization.components.altair_components, "SpaceAltair" ) + class MockAgent(mesa.Agent): + def __init__(self, model): + super().__init__(model) + class MockModel(mesa.Model): def __init__(self, seed=None): super().__init__(seed=seed) + self.grid = MultiGrid(width=10, height=10, torus=True) + a = MockAgent(self) + self.grid.place_agent(a, (5, 5)) model = MockModel() - mocker.patch.object(mesa.Model, "__init__", return_value=None) - agent_portrayal = { - "marker": "circle", - "color": "gray", - } + def agent_portrayal(agent): + return {"marker": "o", "color": "gray"} + propertylayer_portrayal = None # initialize with space drawer unspecified (use default) # component must be rendered for code to run @@ -131,7 +138,33 @@ def __init__(self, seed=None): solara.render(SolaraViz(model)) # should call default method with class instance and agent portrayal assert mock_space_matplotlib.call_count == 0 - assert mock_space_altair.call_count == 0 + assert mock_space_altair.call_count == 1 # altair is the default method + + # checking if SpaceAltair is working as intended with post_process + + mock_post_process = mocker.MagicMock() + solara.render( + SolaraViz( + model, + components=[ + make_altair_space( + agent_portrayal, + propertylayer_portrayal, + mock_post_process, + ) + ], + ) + ) + + args, kwargs = mock_space_altair.call_args + assert args == (model, agent_portrayal) + assert kwargs == {"post_process": mock_post_process} + mock_post_process.assert_called_once() + assert mock_space_matplotlib.call_count == 0 + + mock_space_altair.reset_mock() + mock_space_matplotlib.reset_mock() + mock_post_process.reset_mock() # specify a custom space method class AltSpace: