From 79f29693014df47330f4ef77566539ca8459b2a7 Mon Sep 17 00:00:00 2001 From: Sahil Chhoker Date: Sun, 26 Jan 2025 18:49:55 +0530 Subject: [PATCH] Fixed hex-space draw function to avoid overlaps (#2609) Replaces RegularPolygon based PatchCollection with individual drawn lines while ensuring that no duplicate lines exist. --- mesa/visualization/mpl_space_drawing.py | 132 +++++++++++++----------- 1 file changed, 70 insertions(+), 62 deletions(-) diff --git a/mesa/visualization/mpl_space_drawing.py b/mesa/visualization/mpl_space_drawing.py index 11100e6d104..cc8d59cdd68 100644 --- a/mesa/visualization/mpl_space_drawing.py +++ b/mesa/visualization/mpl_space_drawing.py @@ -8,9 +8,9 @@ import contextlib import itertools -import math import warnings from collections.abc import Callable +from itertools import pairwise from typing import Any import networkx as nx @@ -18,9 +18,9 @@ from matplotlib import pyplot as plt from matplotlib.axes import Axes from matplotlib.cm import ScalarMappable -from matplotlib.collections import PatchCollection +from matplotlib.collections import LineCollection, PatchCollection from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba -from matplotlib.patches import Polygon, RegularPolygon +from matplotlib.patches import Polygon import mesa from mesa.experimental.cell_space import ( @@ -308,13 +308,6 @@ def draw_hex_grid( ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots draw_grid: whether to draw the grid kwargs: additional keyword arguments passed to ax.scatter - - Returns: - Returns the Axes object with the plot drawn onto it. - - ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", - "size", "marker", and "zorder". Other field are ignored and will result in a user warning. - """ if ax is None: fig, ax = plt.subplots() @@ -323,62 +316,77 @@ def draw_hex_grid( s_default = (180 / max(space.width, space.height)) ** 2 arguments = collect_agent_data(space, agent_portrayal, size=s_default) - # for hexgrids we have to go from logical coordinates to visual coordinates - # this is a bit messy. - - # give all even rows an offset in the x direction - # give all rows an offset in the y direction - - # numbers here are based on a distance of 1 between centers of hexes - offset = math.sqrt(0.75) + # Parameters for hexagon grid + size = 1.0 + x_spacing = np.sqrt(3) * size + y_spacing = 1.5 * size loc = arguments["loc"].astype(float) - - logical = np.mod(loc[:, 1], 2) == 0 - loc[:, 0][logical] += 0.5 - loc[:, 1] *= offset - arguments["loc"] = loc - - # plot the agents - _scatter(ax, arguments, **kwargs) - - # further styling and adding of grid - ax.set_xlim(-1, space.width + 0.5) - ax.set_ylim(-offset, space.height * offset) - - def setup_hexmesh( - width, - height, - ): - """Helper function for creating the hexmaesh.""" - # fixme: this should be done once, rather than in each update - # fixme check coordinate system in hexgrid (see https://www.redblobgames.com/grids/hexagons/#coordinates-offset) - - patches = [] - for x, y in itertools.product(range(width), range(height)): - if y % 2 == 0: - x += 0.5 # noqa: PLW2901 - y *= offset # noqa: PLW2901 - hex = RegularPolygon( - (x, y), - numVertices=6, - radius=math.sqrt(1 / 3), - orientation=np.radians(120), - ) - patches.append(hex) - mesh = PatchCollection( - patches, edgecolor="k", facecolor=(1, 1, 1, 0), linestyle="dotted", lw=1 - ) - return mesh + # Calculate hexagon centers for agents if agents are present and plot them. + if loc.size > 0: + loc[:, 0] = loc[:, 0] * x_spacing + ((loc[:, 1] - 1) % 2) * (x_spacing / 2) + loc[:, 1] = loc[:, 1] * y_spacing + arguments["loc"] = loc + + # plot the agents + _scatter(ax, arguments, **kwargs) + + # Calculate proper bounds that account for the full hexagon width and height + x_max = space.width * x_spacing + (space.height % 2) * (x_spacing / 2) + y_max = space.height * y_spacing + + # Add padding that accounts for the hexagon points + x_padding = ( + size * np.sqrt(3) / 2 + ) # Distance from center to rightmost point of hexagon + y_padding = size # Distance from center to topmost point of hexagon + + # Plot limits to perfectly contain the hexagonal grid + # Determined through physical testing. + ax.set_xlim(-2 * x_padding, x_max + x_padding) + ax.set_ylim(-2 * y_padding, y_max + y_padding) + + def setup_hexmesh(width, height): + """Helper function for creating the hexmesh with unique edges.""" + edges = set() + size = 1.0 + x_spacing = np.sqrt(3) * size + y_spacing = 1.5 * size + + def get_hex_vertices( + center_x: float, center_y: float + ) -> list[tuple[float, float]]: + """Get vertices for a hexagon centered at (center_x, center_y).""" + vertices = [ + (center_x, center_y + size), # top + (center_x + size * np.sqrt(3) / 2, center_y + size / 2), # top right + (center_x + size * np.sqrt(3) / 2, center_y - size / 2), # bottom right + (center_x, center_y - size), # bottom + (center_x - size * np.sqrt(3) / 2, center_y - size / 2), # bottom left + (center_x - size * np.sqrt(3) / 2, center_y + size / 2), # top left + ] + return vertices + + # Generate edges for each hexagon + for row, col in itertools.product(range(height), range(width)): + # Calculate center position for each hexagon with offset for even rows + x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2) + y = row * y_spacing + + vertices = get_hex_vertices(x, y) + + # Edge logic, connecting each vertex to the next + for v1, v2 in pairwise([*vertices, vertices[0]]): + # Sort vertices to ensure consistent edge representation and avoid duplicates. + edge = tuple(sorted([tuple(np.round(v1, 6)), tuple(np.round(v2, 6))])) + edges.add(edge) + + # Return LineCollection for hexmesh + return LineCollection(edges, linestyle=":", color="black", linewidth=1, alpha=1) if draw_grid: - # add grid - ax.add_collection( - setup_hexmesh( - space.width, - space.height, - ) - ) + ax.add_collection(setup_hexmesh(space.width, space.height)) + return ax