Skip to content

Commit

Permalink
Fixed hex-space draw function to avoid overlaps (#2609)
Browse files Browse the repository at this point in the history
Replaces RegularPolygon based PatchCollection with individual drawn lines while ensuring that  no duplicate lines exist.
  • Loading branch information
Sahil-Chhoker authored Jan 26, 2025
1 parent 13518b2 commit 79f2969
Showing 1 changed file with 70 additions and 62 deletions.
132 changes: 70 additions & 62 deletions mesa/visualization/mpl_space_drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@

import contextlib
import itertools
import math
import warnings
from collections.abc import Callable
from itertools import pairwise
from typing import Any

import networkx as nx
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.cm import ScalarMappable
from matplotlib.collections import PatchCollection
from matplotlib.collections import LineCollection, PatchCollection
from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba
from matplotlib.patches import Polygon, RegularPolygon
from matplotlib.patches import Polygon

import mesa
from mesa.experimental.cell_space import (
Expand Down Expand Up @@ -308,13 +308,6 @@ def draw_hex_grid(
ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
draw_grid: whether to draw the grid
kwargs: additional keyword arguments passed to ax.scatter
Returns:
Returns the Axes object with the plot drawn onto it.
``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
"size", "marker", and "zorder". Other field are ignored and will result in a user warning.
"""
if ax is None:
fig, ax = plt.subplots()
Expand All @@ -323,62 +316,77 @@ def draw_hex_grid(
s_default = (180 / max(space.width, space.height)) ** 2
arguments = collect_agent_data(space, agent_portrayal, size=s_default)

# for hexgrids we have to go from logical coordinates to visual coordinates
# this is a bit messy.

# give all even rows an offset in the x direction
# give all rows an offset in the y direction

# numbers here are based on a distance of 1 between centers of hexes
offset = math.sqrt(0.75)
# Parameters for hexagon grid
size = 1.0
x_spacing = np.sqrt(3) * size
y_spacing = 1.5 * size

loc = arguments["loc"].astype(float)

logical = np.mod(loc[:, 1], 2) == 0
loc[:, 0][logical] += 0.5
loc[:, 1] *= offset
arguments["loc"] = loc

# plot the agents
_scatter(ax, arguments, **kwargs)

# further styling and adding of grid
ax.set_xlim(-1, space.width + 0.5)
ax.set_ylim(-offset, space.height * offset)

def setup_hexmesh(
width,
height,
):
"""Helper function for creating the hexmaesh."""
# fixme: this should be done once, rather than in each update
# fixme check coordinate system in hexgrid (see https://www.redblobgames.com/grids/hexagons/#coordinates-offset)

patches = []
for x, y in itertools.product(range(width), range(height)):
if y % 2 == 0:
x += 0.5 # noqa: PLW2901
y *= offset # noqa: PLW2901
hex = RegularPolygon(
(x, y),
numVertices=6,
radius=math.sqrt(1 / 3),
orientation=np.radians(120),
)
patches.append(hex)
mesh = PatchCollection(
patches, edgecolor="k", facecolor=(1, 1, 1, 0), linestyle="dotted", lw=1
)
return mesh
# Calculate hexagon centers for agents if agents are present and plot them.
if loc.size > 0:
loc[:, 0] = loc[:, 0] * x_spacing + ((loc[:, 1] - 1) % 2) * (x_spacing / 2)
loc[:, 1] = loc[:, 1] * y_spacing
arguments["loc"] = loc

# plot the agents
_scatter(ax, arguments, **kwargs)

# Calculate proper bounds that account for the full hexagon width and height
x_max = space.width * x_spacing + (space.height % 2) * (x_spacing / 2)
y_max = space.height * y_spacing

# Add padding that accounts for the hexagon points
x_padding = (
size * np.sqrt(3) / 2
) # Distance from center to rightmost point of hexagon
y_padding = size # Distance from center to topmost point of hexagon

# Plot limits to perfectly contain the hexagonal grid
# Determined through physical testing.
ax.set_xlim(-2 * x_padding, x_max + x_padding)
ax.set_ylim(-2 * y_padding, y_max + y_padding)

def setup_hexmesh(width, height):
"""Helper function for creating the hexmesh with unique edges."""
edges = set()
size = 1.0
x_spacing = np.sqrt(3) * size
y_spacing = 1.5 * size

def get_hex_vertices(
center_x: float, center_y: float
) -> list[tuple[float, float]]:
"""Get vertices for a hexagon centered at (center_x, center_y)."""
vertices = [
(center_x, center_y + size), # top
(center_x + size * np.sqrt(3) / 2, center_y + size / 2), # top right
(center_x + size * np.sqrt(3) / 2, center_y - size / 2), # bottom right
(center_x, center_y - size), # bottom
(center_x - size * np.sqrt(3) / 2, center_y - size / 2), # bottom left
(center_x - size * np.sqrt(3) / 2, center_y + size / 2), # top left
]
return vertices

# Generate edges for each hexagon
for row, col in itertools.product(range(height), range(width)):
# Calculate center position for each hexagon with offset for even rows
x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2)
y = row * y_spacing

vertices = get_hex_vertices(x, y)

# Edge logic, connecting each vertex to the next
for v1, v2 in pairwise([*vertices, vertices[0]]):
# Sort vertices to ensure consistent edge representation and avoid duplicates.
edge = tuple(sorted([tuple(np.round(v1, 6)), tuple(np.round(v2, 6))]))
edges.add(edge)

# Return LineCollection for hexmesh
return LineCollection(edges, linestyle=":", color="black", linewidth=1, alpha=1)

if draw_grid:
# add grid
ax.add_collection(
setup_hexmesh(
space.width,
space.height,
)
)
ax.add_collection(setup_hexmesh(space.width, space.height))

return ax


Expand Down

0 comments on commit 79f2969

Please sign in to comment.