Skip to content

Commit

Permalink
fix(src): 🐛 Fix bug in visualizing empty agent sets
Browse files Browse the repository at this point in the history
  • Loading branch information
SongshGeo committed Feb 6, 2025
1 parent 8ad1df7 commit 96abff1
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 30 deletions.
22 changes: 5 additions & 17 deletions abses/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __getitem__(self, breeds: Optional[Breeds]) -> ActorsList[Actor]:
agents = []
for breed in breeds:
breed_type = self._get_breed_type(breed)
agents.extend(self._model.agents_by_type[breed_type])
agents.extend(self._model.agents_by_type.get(breed_type, []))
return ActorsList(model=self.model, objs=agents)

# 单个 breed 的情况
Expand Down Expand Up @@ -144,11 +144,7 @@ def random(self) -> ListRandom:
@property
def is_full(self) -> bool:
"""Whether the container is full."""
return (
False
if self._max_length is None
else len(self) >= self._max_length
)
return False if self._max_length is None else len(self) >= self._max_length

@property
def is_empty(self) -> bool:
Expand Down Expand Up @@ -258,15 +254,9 @@ def new(
agent = self._new_one(agent_cls=breed_cls, **kwargs)
objs.append(agent)
# return the created actor(s).
actors_list: ActorsList[Actor] = ActorsList(
model=self.model, objs=objs
)
actors_list: ActorsList[Actor] = ActorsList(model=self.model, objs=objs)
logger.debug(f"{self} created {num} {breed_cls.__name__}.")
return (
cast(Actor, actors_list.item())
if singleton is True
else actors_list
)
return cast(Actor, actors_list.item()) if singleton is True else actors_list

def remove(self, agent: Actor) -> None:
"""Remove the given agent from the container."""
Expand Down Expand Up @@ -354,9 +344,7 @@ def check_attr(agent, attr, value=True):
agents_set = self._agents
for attr, value in selection.items():
filter_func = partial(check_attr, attr=attr, value=value)
agents_set = agents_set.select(
filter_func=filter_func, **kwargs
)
agents_set = agents_set.select(filter_func=filter_func, **kwargs)
else:
raise TypeError(f"{selection} is not valid selection criteria.")
return ActorsList(model=self.model, objs=agents_set)
Expand Down
121 changes: 108 additions & 13 deletions abses/viz/solara.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,24 @@
# GitHub : https://github.com/SongshGeo
# Website: https://cv.songshgeo.com/

import contextlib
import warnings
from typing import Any, Callable

import matplotlib.pyplot as plt
import numpy as np
import solara
from matplotlib.axes import Axes
from matplotlib.cm import ScalarMappable
from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba
from matplotlib.figure import Figure
from mesa.visualization.mpl_space_drawing import draw_orthogonal_grid
from mesa.visualization.mpl_space_drawing import _scatter
from mesa.visualization.utils import update_counter
from xarray import DataArray

from abses.main import MainModel
from abses.patch import PatchModule

try:
import solara
except ImportError as e:
raise ImportError(
"`solara` is not installed, please install it via `pip install solara`"
) from e


def draw_property_layers(
space: PatchModule,
Expand All @@ -49,11 +45,7 @@ def draw_property_layers(
for layer_name, portrayal in propertylayer_portrayal.items():
layer: DataArray = space.get_xarray(layer_name)

data = (
layer.data.astype(float)
if layer.data.dtype == bool
else layer.data
)
data = layer.data.astype(float) if layer.data.dtype == bool else layer.data

# Get portrayal properties, or use defaults
alpha = portrayal.get("alpha", 1)
Expand Down Expand Up @@ -101,6 +93,109 @@ def draw_property_layers(
)


def collect_agent_data(
space: PatchModule,
agent_portrayal: Callable,
color="tab:blue",
size=25,
marker="o",
zorder: int = 1,
):
"""Collect the plotting data for all agents in the space.
Args:
space: The space containing the Agents.
agent_portrayal: A callable that is called with the agent and returns a dict
color: default color
size: default size
marker: default marker
zorder: default zorder
agent_portrayal should return a dict, limited to size (size of marker), color (color of marker), zorder (z-order),
marker (marker style), alpha, linewidths, and edgecolors
"""
arguments: dict[str, list[Any]] = {
"s": [],
"c": [],
"marker": [],
"zorder": [],
"alpha": [],
"edgecolors": [],
"linewidths": [],
}

for agent in space.agents:
portray = agent_portrayal(agent)
arguments["s"].append(portray.pop("size", size))
arguments["c"].append(portray.pop("color", color))
arguments["marker"].append(portray.pop("marker", marker))
arguments["zorder"].append(portray.pop("zorder", zorder))

for entry in ["alpha", "edgecolors", "linewidths"]:
with contextlib.suppress(KeyError):
arguments[entry].append(portray.pop(entry))

if len(portray) > 0:
ignored_fields = list(portray.keys())
msg = ", ".join(ignored_fields)
warnings.warn(
f"the following fields are not used in agent portrayal and thus ignored: {msg}.",
stacklevel=2,
)
# ensure loc is always a shape of (n, 2) array, even if n=0
result = {k: np.asarray(v) for k, v in arguments.items()}
result["loc"] = space.agents.array("indices")
return result


def draw_orthogonal_grid(
space: PatchModule,
agent_portrayal: Callable,
ax: Axes | None = None,
draw_grid: bool = True,
**kwargs,
):
"""Visualize a orthogonal grid.
Args:
space: the space to visualize
agent_portrayal: a callable that is called with the agent and returns a dict
ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
draw_grid: whether to draw the grid
kwargs: additional keyword arguments passed to ax.scatter
Returns:
Returns the Axes object with the plot drawn onto it.
``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
"size", "marker", and "zorder". Other field are ignored and will result in a user warning.
"""
if ax is None:
fig, ax = plt.subplots()

# gather agent data
s_default = (180 / max(space.width, space.height)) ** 2
arguments = collect_agent_data(space, agent_portrayal, size=s_default)

# plot the agents
_scatter(ax, arguments, **kwargs)

# further styling
ax.set_xlim(-0.5, space.width - 0.5)
ax.set_ylim(-0.5, space.height - 0.5)

if draw_grid:
# Draw grid lines
for x in np.arange(-0.5, space.width - 0.5, 1):
ax.axvline(x, color="gray", linestyle=":")
for y in np.arange(-0.5, space.height - 0.5, 1):
ax.axhline(y, color="gray", linestyle=":")

return ax


@solara.component
def SpaceMatplotlib(
model: MainModel,
Expand Down

0 comments on commit 96abff1

Please sign in to comment.