diff --git a/mesa/visualization/mpl_space_drawing.py b/mesa/visualization/mpl_space_drawing.py index 3ddd9098ecf..cc8d59cdd68 100644 --- a/mesa/visualization/mpl_space_drawing.py +++ b/mesa/visualization/mpl_space_drawing.py @@ -9,8 +9,7 @@ import contextlib import itertools import warnings -from collections.abc import Callable, Iterator -from functools import lru_cache +from collections.abc import Callable from itertools import pairwise from typing import Any @@ -19,7 +18,7 @@ from matplotlib import pyplot as plt from matplotlib.axes import Axes from matplotlib.cm import ScalarMappable -from matplotlib.collections import LineCollection, PatchCollection, PolyCollection +from matplotlib.collections import LineCollection, PatchCollection from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba from matplotlib.patches import Polygon @@ -160,37 +159,6 @@ def draw_space( return ax -@lru_cache(maxsize=1024, typed=True) -def _get_hexmesh( - width: int, height: int, size: float = 1.0 -) -> Iterator[list[tuple[float, float]]]: - """Generate hexagon vertices for the mesh. Yields list of vertex coordinates for each hexagon.""" - - # Helper function for getting the vertices of a hexagon given the center and size - def _get_hex_vertices( - center_x: float, center_y: float, size: float = 1.0 - ) -> list[tuple[float, float]]: - """Get vertices for a hexagon centered at (center_x, center_y).""" - vertices = [ - (center_x, center_y + size), # top - (center_x + size * np.sqrt(3) / 2, center_y + size / 2), # top right - (center_x + size * np.sqrt(3) / 2, center_y - size / 2), # bottom right - (center_x, center_y - size), # bottom - (center_x - size * np.sqrt(3) / 2, center_y - size / 2), # bottom left - (center_x - size * np.sqrt(3) / 2, center_y + size / 2), # top left - ] - return vertices - - x_spacing = np.sqrt(3) * size - y_spacing = 1.5 * size - - for row, col in itertools.product(range(height), range(width)): - # Calculate center position with offset for even rows - x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2) - y = row * y_spacing - yield _get_hex_vertices(x, y, size) - - def draw_property_layers( space, propertylayer_portrayal: dict[str, dict[str, Any]], ax: Axes ): @@ -237,74 +205,46 @@ def draw_property_layers( vmax = portrayal.get("vmax", np.max(data)) colorbar = portrayal.get("colorbar", True) - # Prepare colormap + # Draw the layer if "color" in portrayal: + data = data.T rgba_color = to_rgba(portrayal["color"]) + normalized_data = (data - vmin) / (vmax - vmin) + rgba_data = np.full((*data.shape, 4), rgba_color) + rgba_data[..., 3] *= normalized_data * alpha + rgba_data = np.clip(rgba_data, 0, 1) cmap = LinearSegmentedColormap.from_list( layer_name, [(0, 0, 0, 0), (*rgba_color[:3], alpha)] ) + im = ax.imshow( + rgba_data, + origin="lower", + ) + if colorbar: + norm = Normalize(vmin=vmin, vmax=vmax) + sm = ScalarMappable(norm=norm, cmap=cmap) + sm.set_array([]) + ax.figure.colorbar(sm, ax=ax, orientation="vertical") + elif "colormap" in portrayal: cmap = portrayal.get("colormap", "viridis") if isinstance(cmap, list): cmap = LinearSegmentedColormap.from_list(layer_name, cmap) - elif isinstance(cmap, str): - cmap = plt.get_cmap(cmap) + im = ax.imshow( + data.T, + cmap=cmap, + alpha=alpha, + vmin=vmin, + vmax=vmax, + origin="lower", + ) + if colorbar: + plt.colorbar(im, ax=ax, label=layer_name) else: raise ValueError( f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'." ) - if isinstance(space, OrthogonalGrid): - if "color" in portrayal: - data = data.T - normalized_data = (data - vmin) / (vmax - vmin) - rgba_data = np.full((*data.shape, 4), rgba_color) - rgba_data[..., 3] *= normalized_data * alpha - rgba_data = np.clip(rgba_data, 0, 1) - ax.imshow(rgba_data, origin="lower") - else: - ax.imshow( - data.T, - cmap=cmap, - alpha=alpha, - vmin=vmin, - vmax=vmax, - origin="lower", - ) - - elif isinstance(space, HexGrid): - width, height = data.shape - - # Generate hexagon mesh - hexagons = _get_hexmesh(width, height) - - # Normalize colors - norm = Normalize(vmin=vmin, vmax=vmax) - colors = data.ravel() # flatten data to 1D array - - if "color" in portrayal: - normalized_colors = np.clip(norm(colors), 0, 1) - rgba_colors = np.full((len(colors), 4), rgba_color) - rgba_colors[:, 3] = normalized_colors * alpha - else: - rgba_colors = cmap(norm(colors)) - - # Draw hexagons - collection = PolyCollection(hexagons, facecolors=rgba_colors, zorder=-1) - ax.add_collection(collection) - - else: - raise NotImplementedError( - f"PropertyLayer visualization not implemented for {type(space)}." - ) - - # Add colorbar if requested - if colorbar: - norm = Normalize(vmin=vmin, vmax=vmax) - sm = ScalarMappable(norm=norm, cmap=cmap) - sm.set_array([]) - plt.colorbar(sm, ax=ax, label=layer_name) - def draw_orthogonal_grid( space: OrthogonalGrid, @@ -409,15 +349,39 @@ def draw_hex_grid( def setup_hexmesh(width, height): """Helper function for creating the hexmesh with unique edges.""" edges = set() + size = 1.0 + x_spacing = np.sqrt(3) * size + y_spacing = 1.5 * size + + def get_hex_vertices( + center_x: float, center_y: float + ) -> list[tuple[float, float]]: + """Get vertices for a hexagon centered at (center_x, center_y).""" + vertices = [ + (center_x, center_y + size), # top + (center_x + size * np.sqrt(3) / 2, center_y + size / 2), # top right + (center_x + size * np.sqrt(3) / 2, center_y - size / 2), # bottom right + (center_x, center_y - size), # bottom + (center_x - size * np.sqrt(3) / 2, center_y - size / 2), # bottom left + (center_x - size * np.sqrt(3) / 2, center_y + size / 2), # top left + ] + return vertices # Generate edges for each hexagon - for vertices in _get_hexmesh(width, height): + for row, col in itertools.product(range(height), range(width)): + # Calculate center position for each hexagon with offset for even rows + x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2) + y = row * y_spacing + + vertices = get_hex_vertices(x, y) + # Edge logic, connecting each vertex to the next for v1, v2 in pairwise([*vertices, vertices[0]]): # Sort vertices to ensure consistent edge representation and avoid duplicates. edge = tuple(sorted([tuple(np.round(v1, 6)), tuple(np.round(v2, 6))])) edges.add(edge) + # Return LineCollection for hexmesh return LineCollection(edges, linestyle=":", color="black", linewidth=1, alpha=1) if draw_grid: