Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented post_process in Altair based components #2641

Merged
merged 10 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions mesa/visualization/components/altair_components.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
"""Altair based solara components for visualization mesa spaces."""

import contextlib
import warnings

import altair as alt
import solara

with contextlib.suppress(ImportError):
import altair as alt

from mesa.experimental.cell_space import DiscreteSpace, Grid
from mesa.space import ContinuousSpace, _Grid
from mesa.visualization.utils import update_counter
Expand All @@ -30,7 +27,7 @@ def make_altair_space(
Args:
agent_portrayal: Function to portray agents.
propertylayer_portrayal: not yet implemented
post_process :not yet implemented
post_process :A user specified callable that will be called with the Chart instance from Altair. Allows for fine tuning plots (e.g., control ticks)
space_drawing_kwargs : not yet implemented

``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
Expand All @@ -46,13 +43,15 @@ def agent_portrayal(a):
return {"id": a.unique_id}

def MakeSpaceAltair(model):
return SpaceAltair(model, agent_portrayal)
return SpaceAltair(model, agent_portrayal, post_process=post_process)

return MakeSpaceAltair


@solara.component
def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None):
def SpaceAltair(
model, agent_portrayal, dependencies: list[any] | None = None, post_process=None
):
"""Create an Altair-based space visualization component.

Returns:
Expand All @@ -65,6 +64,9 @@ def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None):
space = model.space

chart = _draw_grid(space, agent_portrayal)
# Apply post-processing if provided
if post_process is not None:
chart = post_process(chart)
solara.FigureAltair(chart)


Expand Down Expand Up @@ -159,7 +161,7 @@ def _draw_grid(space, agent_portrayal):
# no y-axis label
"y": alt.Y("y", axis=None, type=x_y_type),
"tooltip": [
alt.Tooltip(key, type=alt.utils.infer_vegalite_type([value]))
alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value]))
for key, value in all_agent_data[0].items()
if key not in invalid_tooltips
],
Expand Down
6 changes: 5 additions & 1 deletion mesa/visualization/solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,11 @@ def SolaraViz(
reduce update frequency,resulting in faster execution.
"""
if components == "default":
components = [components_altair.make_altair_space()]
components = [
components_altair.make_altair_space(
agent_portrayal=None, propertylayer_portrayal=None, post_process=None
)
]
if model_params is None:
model_params = {}

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ network = [
viz = [
"matplotlib",
"solara",
"altair",
]
# Dev and CI stuff
dev = [
Expand Down
45 changes: 39 additions & 6 deletions tests/test_solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import mesa
import mesa.visualization.components.altair_components
import mesa.visualization.components.matplotlib_components
from mesa.space import MultiGrid
from mesa.visualization.components.altair_components import make_altair_space
from mesa.visualization.components.matplotlib_components import make_mpl_space_component
from mesa.visualization.solara_viz import (
Slider,
Expand Down Expand Up @@ -101,17 +103,22 @@ def test_call_space_drawer(mocker): # noqa: D103
mesa.visualization.components.altair_components, "SpaceAltair"
)

class MockAgent(mesa.Agent):
def __init__(self, model):
super().__init__(model)

class MockModel(mesa.Model):
def __init__(self, seed=None):
super().__init__(seed=seed)
self.grid = MultiGrid(width=10, height=10, torus=True)
a = MockAgent(self)
self.grid.place_agent(a, (5, 5))

model = MockModel()
mocker.patch.object(mesa.Model, "__init__", return_value=None)

agent_portrayal = {
"marker": "circle",
"color": "gray",
}
def agent_portrayal(agent):
return {"marker": "o", "color": "gray"}

propertylayer_portrayal = None
# initialize with space drawer unspecified (use default)
# component must be rendered for code to run
Expand All @@ -131,7 +138,33 @@ def __init__(self, seed=None):
solara.render(SolaraViz(model))
# should call default method with class instance and agent portrayal
assert mock_space_matplotlib.call_count == 0
assert mock_space_altair.call_count == 0
assert mock_space_altair.call_count == 1 # altair is the default method

# checking if SpaceAltair is working as intended with post_process

mock_post_process = mocker.MagicMock()
solara.render(
SolaraViz(
model,
components=[
make_altair_space(
agent_portrayal,
propertylayer_portrayal,
mock_post_process,
)
],
)
)

args, kwargs = mock_space_altair.call_args
assert args == (model, agent_portrayal)
assert kwargs == {"post_process": mock_post_process}
mock_post_process.assert_called_once()
assert mock_space_matplotlib.call_count == 0

mock_space_altair.reset_mock()
mock_space_matplotlib.reset_mock()
mock_post_process.reset_mock()

# specify a custom space method
class AltSpace:
Expand Down
Loading