From 2fc8eff9ab2999c3393bda4a38ddfb88ae2982b5 Mon Sep 17 00:00:00 2001 From: Sahil Chhoker Date: Sun, 26 Jan 2025 20:28:48 +0530 Subject: [PATCH 1/6] Fix: Property layer vizualization for HexGrid --- mesa/visualization/mpl_space_drawing.py | 131 ++++++++++++++++-------- 1 file changed, 89 insertions(+), 42 deletions(-) diff --git a/mesa/visualization/mpl_space_drawing.py b/mesa/visualization/mpl_space_drawing.py index cc8d59cdd68..91235730a5e 100644 --- a/mesa/visualization/mpl_space_drawing.py +++ b/mesa/visualization/mpl_space_drawing.py @@ -18,7 +18,7 @@ from matplotlib import pyplot as plt from matplotlib.axes import Axes from matplotlib.cm import ScalarMappable -from matplotlib.collections import LineCollection, PatchCollection +from matplotlib.collections import LineCollection, PatchCollection, PolyCollection from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba from matplotlib.patches import Polygon @@ -159,6 +159,22 @@ def draw_space( return ax +# Helper function for getting the vertices of a hexagon given the center and size +def _get_hex_vertices( + center_x: float, center_y: float, size: float +) -> list[tuple[float, float]]: + """Get vertices for a hexagon centered at (center_x, center_y).""" + vertices = [ + (center_x, center_y + size), # top + (center_x + size * np.sqrt(3) / 2, center_y + size / 2), # top right + (center_x + size * np.sqrt(3) / 2, center_y - size / 2), # bottom right + (center_x, center_y - size), # bottom + (center_x - size * np.sqrt(3) / 2, center_y - size / 2), # bottom left + (center_x - size * np.sqrt(3) / 2, center_y + size / 2), # top left + ] + return vertices + + def draw_property_layers( space, propertylayer_portrayal: dict[str, dict[str, Any]], ax: Axes ): @@ -175,6 +191,25 @@ def draw_property_layers( so you can do `{"some_layer":{"colormap":'viridis', 'alpha':.25, "colorbar":False}}` """ + + def get_hexmesh(width: int, height: int) -> list[list[tuple[float, float]]]: + """Create hexagon vertices for the mesh.""" + + hexagons = [] + size = 1 + x_spacing = np.sqrt(3) * size + y_spacing = 1.5 * size + + for row in range(height): + for col in range(width): + # Calculate center position with offset for even rows + x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2) + y = row * y_spacing + vertices = _get_hex_vertices(x, y, size) + hexagons.append(vertices) + + return hexagons + try: # old style spaces property_layers = space.properties @@ -182,6 +217,9 @@ def draw_property_layers( # new style spaces property_layers = space._mesa_property_layers + # Check space type once + is_hex_grid = isinstance(space, HexGrid) + for layer_name, portrayal in propertylayer_portrayal.items(): layer = property_layers.get(layer_name, None) if not isinstance( @@ -205,46 +243,69 @@ def draw_property_layers( vmax = portrayal.get("vmax", np.max(data)) colorbar = portrayal.get("colorbar", True) - # Draw the layer + # Prepare colormap if "color" in portrayal: - data = data.T rgba_color = to_rgba(portrayal["color"]) - normalized_data = (data - vmin) / (vmax - vmin) - rgba_data = np.full((*data.shape, 4), rgba_color) - rgba_data[..., 3] *= normalized_data * alpha - rgba_data = np.clip(rgba_data, 0, 1) cmap = LinearSegmentedColormap.from_list( layer_name, [(0, 0, 0, 0), (*rgba_color[:3], alpha)] ) - im = ax.imshow( - rgba_data, - origin="lower", - ) - if colorbar: - norm = Normalize(vmin=vmin, vmax=vmax) - sm = ScalarMappable(norm=norm, cmap=cmap) - sm.set_array([]) - ax.figure.colorbar(sm, ax=ax, orientation="vertical") - elif "colormap" in portrayal: cmap = portrayal.get("colormap", "viridis") if isinstance(cmap, list): cmap = LinearSegmentedColormap.from_list(layer_name, cmap) - im = ax.imshow( - data.T, - cmap=cmap, - alpha=alpha, - vmin=vmin, - vmax=vmax, - origin="lower", - ) - if colorbar: - plt.colorbar(im, ax=ax, label=layer_name) + elif isinstance(cmap, str): + cmap = plt.get_cmap(cmap) else: raise ValueError( f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'." ) + if is_hex_grid: + width, height = data.shape + + # Generate hexagon mesh + hexagons = get_hexmesh(width, height) + + # Normalize colors + norm = Normalize(vmin=vmin, vmax=vmax) + colors = data.ravel() # flatten data to 1D array + + if "color" in portrayal: + normalized_colors = np.clip(norm(colors), 0, 1) + rgba_colors = np.full((len(colors), 4), rgba_color) + rgba_colors[:, 3] = normalized_colors * alpha + else: + rgba_colors = cmap(norm(colors)) + + # Draw hexagons + collection = PolyCollection(hexagons, facecolors=rgba_colors, zorder=-1) + im = ax.add_collection(collection) + else: + # Rectangular grid rendering + if "color" in portrayal: + data = data.T + normalized_data = (data - vmin) / (vmax - vmin) + rgba_data = np.full((*data.shape, 4), rgba_color) + rgba_data[..., 3] *= normalized_data * alpha + rgba_data = np.clip(rgba_data, 0, 1) + im = ax.imshow(rgba_data, origin="lower") + else: + im = ax.imshow( + data.T, + cmap=cmap, + alpha=alpha, + vmin=vmin, + vmax=vmax, + origin="lower", + ) + + # Add colorbar if requested + if colorbar: + norm = Normalize(vmin=vmin, vmax=vmax) + sm = ScalarMappable(norm=norm, cmap=cmap) + sm.set_array([]) + plt.colorbar(sm, ax=ax, label=layer_name) + def draw_orthogonal_grid( space: OrthogonalGrid, @@ -353,27 +414,13 @@ def setup_hexmesh(width, height): x_spacing = np.sqrt(3) * size y_spacing = 1.5 * size - def get_hex_vertices( - center_x: float, center_y: float - ) -> list[tuple[float, float]]: - """Get vertices for a hexagon centered at (center_x, center_y).""" - vertices = [ - (center_x, center_y + size), # top - (center_x + size * np.sqrt(3) / 2, center_y + size / 2), # top right - (center_x + size * np.sqrt(3) / 2, center_y - size / 2), # bottom right - (center_x, center_y - size), # bottom - (center_x - size * np.sqrt(3) / 2, center_y - size / 2), # bottom left - (center_x - size * np.sqrt(3) / 2, center_y + size / 2), # top left - ] - return vertices - # Generate edges for each hexagon for row, col in itertools.product(range(height), range(width)): # Calculate center position for each hexagon with offset for even rows x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2) y = row * y_spacing - vertices = get_hex_vertices(x, y) + vertices = _get_hex_vertices(x, y, size) # Edge logic, connecting each vertex to the next for v1, v2 in pairwise([*vertices, vertices[0]]): From 491ec07d77d02b9a7cd091a865b5875e292eac09 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 26 Jan 2025 15:11:04 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa/visualization/mpl_space_drawing.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mesa/visualization/mpl_space_drawing.py b/mesa/visualization/mpl_space_drawing.py index 91235730a5e..38e64af420e 100644 --- a/mesa/visualization/mpl_space_drawing.py +++ b/mesa/visualization/mpl_space_drawing.py @@ -194,7 +194,6 @@ def draw_property_layers( def get_hexmesh(width: int, height: int) -> list[list[tuple[float, float]]]: """Create hexagon vertices for the mesh.""" - hexagons = [] size = 1 x_spacing = np.sqrt(3) * size @@ -209,7 +208,7 @@ def get_hexmesh(width: int, height: int) -> list[list[tuple[float, float]]]: hexagons.append(vertices) return hexagons - + try: # old style spaces property_layers = space.properties @@ -262,13 +261,13 @@ def get_hexmesh(width: int, height: int) -> list[list[tuple[float, float]]]: if is_hex_grid: width, height = data.shape - + # Generate hexagon mesh hexagons = get_hexmesh(width, height) # Normalize colors norm = Normalize(vmin=vmin, vmax=vmax) - colors = data.ravel() # flatten data to 1D array + colors = data.ravel() # flatten data to 1D array if "color" in portrayal: normalized_colors = np.clip(norm(colors), 0, 1) From d0bfcbb0930b485955272b5b47d64f72f9a37e8b Mon Sep 17 00:00:00 2001 From: Sahil Chhoker Date: Sun, 26 Jan 2025 20:44:12 +0530 Subject: [PATCH 3/6] added pre-commit suggestions --- mesa/visualization/mpl_space_drawing.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/mesa/visualization/mpl_space_drawing.py b/mesa/visualization/mpl_space_drawing.py index 91235730a5e..00ec3906a18 100644 --- a/mesa/visualization/mpl_space_drawing.py +++ b/mesa/visualization/mpl_space_drawing.py @@ -194,7 +194,6 @@ def draw_property_layers( def get_hexmesh(width: int, height: int) -> list[list[tuple[float, float]]]: """Create hexagon vertices for the mesh.""" - hexagons = [] size = 1 x_spacing = np.sqrt(3) * size @@ -209,7 +208,7 @@ def get_hexmesh(width: int, height: int) -> list[list[tuple[float, float]]]: hexagons.append(vertices) return hexagons - + try: # old style spaces property_layers = space.properties @@ -262,13 +261,13 @@ def get_hexmesh(width: int, height: int) -> list[list[tuple[float, float]]]: if is_hex_grid: width, height = data.shape - + # Generate hexagon mesh hexagons = get_hexmesh(width, height) # Normalize colors norm = Normalize(vmin=vmin, vmax=vmax) - colors = data.ravel() # flatten data to 1D array + colors = data.ravel() # flatten data to 1D array if "color" in portrayal: normalized_colors = np.clip(norm(colors), 0, 1) @@ -279,7 +278,7 @@ def get_hexmesh(width: int, height: int) -> list[list[tuple[float, float]]]: # Draw hexagons collection = PolyCollection(hexagons, facecolors=rgba_colors, zorder=-1) - im = ax.add_collection(collection) + ax.add_collection(collection) else: # Rectangular grid rendering if "color" in portrayal: @@ -288,9 +287,9 @@ def get_hexmesh(width: int, height: int) -> list[list[tuple[float, float]]]: rgba_data = np.full((*data.shape, 4), rgba_color) rgba_data[..., 3] *= normalized_data * alpha rgba_data = np.clip(rgba_data, 0, 1) - im = ax.imshow(rgba_data, origin="lower") + ax.imshow(rgba_data, origin="lower") else: - im = ax.imshow( + ax.imshow( data.T, cmap=cmap, alpha=alpha, From 05ecacdd64ad64e98c9002f37be495d692d04c72 Mon Sep 17 00:00:00 2001 From: Sahil Chhoker Date: Mon, 27 Jan 2025 17:57:46 +0530 Subject: [PATCH 4/6] added reviewed changes --- mesa/visualization/mpl_space_drawing.py | 96 ++++++++++++------------- 1 file changed, 44 insertions(+), 52 deletions(-) diff --git a/mesa/visualization/mpl_space_drawing.py b/mesa/visualization/mpl_space_drawing.py index 00ec3906a18..2f89f421017 100644 --- a/mesa/visualization/mpl_space_drawing.py +++ b/mesa/visualization/mpl_space_drawing.py @@ -11,7 +11,8 @@ import warnings from collections.abc import Callable from itertools import pairwise -from typing import Any +from typing import Any, Iterator +from functools import lru_cache import networkx as nx import numpy as np @@ -175,6 +176,21 @@ def _get_hex_vertices( return vertices +@lru_cache(maxsize=1024, typed=True) +def _get_hexmesh(width: int, height: int, size: float) -> Iterator[list[tuple[float, float]]]: + """ + Generate hexagon vertices for the mesh. Yields list of vertex coordinates for each hexagon + """ + x_spacing = np.sqrt(3) * size + y_spacing = 1.5 * size + + for row, col in itertools.product(range(height), range(width)): + # Calculate center position with offset for even rows + x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2) + y = row * y_spacing + yield _get_hex_vertices(x, y, size) + + def draw_property_layers( space, propertylayer_portrayal: dict[str, dict[str, Any]], ax: Axes ): @@ -192,23 +208,6 @@ def draw_property_layers( """ - def get_hexmesh(width: int, height: int) -> list[list[tuple[float, float]]]: - """Create hexagon vertices for the mesh.""" - hexagons = [] - size = 1 - x_spacing = np.sqrt(3) * size - y_spacing = 1.5 * size - - for row in range(height): - for col in range(width): - # Calculate center position with offset for even rows - x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2) - y = row * y_spacing - vertices = _get_hex_vertices(x, y, size) - hexagons.append(vertices) - - return hexagons - try: # old style spaces property_layers = space.properties @@ -216,9 +215,6 @@ def get_hexmesh(width: int, height: int) -> list[list[tuple[float, float]]]: # new style spaces property_layers = space._mesa_property_layers - # Check space type once - is_hex_grid = isinstance(space, HexGrid) - for layer_name, portrayal in propertylayer_portrayal.items(): layer = property_layers.get(layer_name, None) if not isinstance( @@ -259,11 +255,29 @@ def get_hexmesh(width: int, height: int) -> list[list[tuple[float, float]]]: f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'." ) - if is_hex_grid: + if isinstance(space, OrthogonalGrid): + if "color" in portrayal: + data = data.T + normalized_data = (data - vmin) / (vmax - vmin) + rgba_data = np.full((*data.shape, 4), rgba_color) + rgba_data[..., 3] *= normalized_data * alpha + rgba_data = np.clip(rgba_data, 0, 1) + ax.imshow(rgba_data, origin="lower") + else: + ax.imshow( + data.T, + cmap=cmap, + alpha=alpha, + vmin=vmin, + vmax=vmax, + origin="lower", + ) + + elif isinstance(space, HexGrid): width, height = data.shape # Generate hexagon mesh - hexagons = get_hexmesh(width, height) + hexagons = _get_hexmesh(width, height, size=1) # Normalize colors norm = Normalize(vmin=vmin, vmax=vmax) @@ -279,24 +293,11 @@ def get_hexmesh(width: int, height: int) -> list[list[tuple[float, float]]]: # Draw hexagons collection = PolyCollection(hexagons, facecolors=rgba_colors, zorder=-1) ax.add_collection(collection) + else: - # Rectangular grid rendering - if "color" in portrayal: - data = data.T - normalized_data = (data - vmin) / (vmax - vmin) - rgba_data = np.full((*data.shape, 4), rgba_color) - rgba_data[..., 3] *= normalized_data * alpha - rgba_data = np.clip(rgba_data, 0, 1) - ax.imshow(rgba_data, origin="lower") - else: - ax.imshow( - data.T, - cmap=cmap, - alpha=alpha, - vmin=vmin, - vmax=vmax, - origin="lower", - ) + raise NotImplementedError( + f"PropertyLayer visualization not implemented for {type(space)}." + ) # Add colorbar if requested if colorbar: @@ -410,24 +411,15 @@ def setup_hexmesh(width, height): """Helper function for creating the hexmesh with unique edges.""" edges = set() size = 1.0 - x_spacing = np.sqrt(3) * size - y_spacing = 1.5 * size # Generate edges for each hexagon - for row, col in itertools.product(range(height), range(width)): - # Calculate center position for each hexagon with offset for even rows - x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2) - y = row * y_spacing - - vertices = _get_hex_vertices(x, y, size) - + for vertices in _get_hexmesh(width, height, size): # Edge logic, connecting each vertex to the next for v1, v2 in pairwise([*vertices, vertices[0]]): - # Sort vertices to ensure consistent edge representation and avoid duplicates. + # Sort vertices to ensure consistent edge representation edge = tuple(sorted([tuple(np.round(v1, 6)), tuple(np.round(v2, 6))])) edges.add(edge) - - # Return LineCollection for hexmesh + return LineCollection(edges, linestyle=":", color="black", linewidth=1, alpha=1) if draw_grid: From 7d2bd37d2920a7ce541b869938cc14c838beb9a6 Mon Sep 17 00:00:00 2001 From: Sahil Chhoker Date: Mon, 27 Jan 2025 17:59:48 +0530 Subject: [PATCH 5/6] added pre-commit changes --- mesa/visualization/mpl_space_drawing.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/mesa/visualization/mpl_space_drawing.py b/mesa/visualization/mpl_space_drawing.py index 2f89f421017..2e164ba7594 100644 --- a/mesa/visualization/mpl_space_drawing.py +++ b/mesa/visualization/mpl_space_drawing.py @@ -9,10 +9,10 @@ import contextlib import itertools import warnings -from collections.abc import Callable -from itertools import pairwise -from typing import Any, Iterator +from collections.abc import Callable, Iterator from functools import lru_cache +from itertools import pairwise +from typing import Any import networkx as nx import numpy as np @@ -177,13 +177,13 @@ def _get_hex_vertices( @lru_cache(maxsize=1024, typed=True) -def _get_hexmesh(width: int, height: int, size: float) -> Iterator[list[tuple[float, float]]]: - """ - Generate hexagon vertices for the mesh. Yields list of vertex coordinates for each hexagon - """ +def _get_hexmesh( + width: int, height: int, size: float +) -> Iterator[list[tuple[float, float]]]: + """Generate hexagon vertices for the mesh. Yields list of vertex coordinates for each hexagon.""" x_spacing = np.sqrt(3) * size y_spacing = 1.5 * size - + for row, col in itertools.product(range(height), range(width)): # Calculate center position with offset for even rows x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2) @@ -207,7 +207,6 @@ def draw_property_layers( so you can do `{"some_layer":{"colormap":'viridis', 'alpha':.25, "colorbar":False}}` """ - try: # old style spaces property_layers = space.properties @@ -419,7 +418,7 @@ def setup_hexmesh(width, height): # Sort vertices to ensure consistent edge representation edge = tuple(sorted([tuple(np.round(v1, 6)), tuple(np.round(v2, 6))])) edges.add(edge) - + return LineCollection(edges, linestyle=":", color="black", linewidth=1, alpha=1) if draw_grid: From 57243cffe6b7cb6971385eebb0f87ea97601c649 Mon Sep 17 00:00:00 2001 From: Sahil Chhoker Date: Mon, 27 Jan 2025 23:46:11 +0530 Subject: [PATCH 6/6] added reviewed changes --- mesa/visualization/mpl_space_drawing.py | 41 ++++++++++++------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/mesa/visualization/mpl_space_drawing.py b/mesa/visualization/mpl_space_drawing.py index 2e164ba7594..3ddd9098ecf 100644 --- a/mesa/visualization/mpl_space_drawing.py +++ b/mesa/visualization/mpl_space_drawing.py @@ -160,27 +160,27 @@ def draw_space( return ax -# Helper function for getting the vertices of a hexagon given the center and size -def _get_hex_vertices( - center_x: float, center_y: float, size: float -) -> list[tuple[float, float]]: - """Get vertices for a hexagon centered at (center_x, center_y).""" - vertices = [ - (center_x, center_y + size), # top - (center_x + size * np.sqrt(3) / 2, center_y + size / 2), # top right - (center_x + size * np.sqrt(3) / 2, center_y - size / 2), # bottom right - (center_x, center_y - size), # bottom - (center_x - size * np.sqrt(3) / 2, center_y - size / 2), # bottom left - (center_x - size * np.sqrt(3) / 2, center_y + size / 2), # top left - ] - return vertices - - @lru_cache(maxsize=1024, typed=True) def _get_hexmesh( - width: int, height: int, size: float + width: int, height: int, size: float = 1.0 ) -> Iterator[list[tuple[float, float]]]: """Generate hexagon vertices for the mesh. Yields list of vertex coordinates for each hexagon.""" + + # Helper function for getting the vertices of a hexagon given the center and size + def _get_hex_vertices( + center_x: float, center_y: float, size: float = 1.0 + ) -> list[tuple[float, float]]: + """Get vertices for a hexagon centered at (center_x, center_y).""" + vertices = [ + (center_x, center_y + size), # top + (center_x + size * np.sqrt(3) / 2, center_y + size / 2), # top right + (center_x + size * np.sqrt(3) / 2, center_y - size / 2), # bottom right + (center_x, center_y - size), # bottom + (center_x - size * np.sqrt(3) / 2, center_y - size / 2), # bottom left + (center_x - size * np.sqrt(3) / 2, center_y + size / 2), # top left + ] + return vertices + x_spacing = np.sqrt(3) * size y_spacing = 1.5 * size @@ -276,7 +276,7 @@ def draw_property_layers( width, height = data.shape # Generate hexagon mesh - hexagons = _get_hexmesh(width, height, size=1) + hexagons = _get_hexmesh(width, height) # Normalize colors norm = Normalize(vmin=vmin, vmax=vmax) @@ -409,13 +409,12 @@ def draw_hex_grid( def setup_hexmesh(width, height): """Helper function for creating the hexmesh with unique edges.""" edges = set() - size = 1.0 # Generate edges for each hexagon - for vertices in _get_hexmesh(width, height, size): + for vertices in _get_hexmesh(width, height): # Edge logic, connecting each vertex to the next for v1, v2 in pairwise([*vertices, vertices[0]]): - # Sort vertices to ensure consistent edge representation + # Sort vertices to ensure consistent edge representation and avoid duplicates. edge = tuple(sorted([tuple(np.round(v1, 6)), tuple(np.round(v2, 6))])) edges.add(edge)