Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Have a dedicated neighborhood property and a get_neighborhood method on Cell #2309

Merged
merged 10 commits into from
Sep 21, 2024
28 changes: 26 additions & 2 deletions mesa/experimental/cell_space/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,34 @@ def is_full(self) -> bool:
def __repr__(self): # noqa
return f"Cell({self.coordinate}, {self.agents})"

@property
@cache
def neighborhood(self) -> CellCollection:
"""Returns the direct neigborhood of the cell

This is equivalent to cell.get_neighborhood(radius=1)

"""
return self.get_neighborhood(radius=1)

# FIXME: Revisit caching strategy on methods
@cache # noqa: B019
def neighborhood(self, radius: int = 1, include_center: bool = False):
"""Returns a list of all neighboring cells."""
def get_neighborhood(
self, radius: int = 1, include_center: bool = False
) -> CellCollection:
"""Returns a list of all neighboring cells for the given radius

For getting the direct neighborhood (i.e., radius=1) you can also use
the `neighborhood` property.

Args:
radius (int): the radius of the neighborhood
include_center (bool): include the center of the neighborhood

Returns:
a list of all neighboring cells

"""
return CellCollection[Cell](
self._neighborhood(radius=radius, include_center=include_center),
random=self.random,
Expand Down
22 changes: 17 additions & 5 deletions tests/test_cell_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,33 +280,45 @@ def test_cell_neighborhood():
height = 10
grid = OrthogonalVonNeumannGrid((width, height), torus=False, capacity=None)
for radius, n in zip(range(1, 4), [2, 5, 9]):
neighborhood = grid._cells[(0, 0)].neighborhood(radius=radius)
if radius == 1:
neighborhood = grid._cells[(0, 0)].neighborhood
else:
neighborhood = grid._cells[(0, 0)].get_neighborhood(radius=radius)
assert len(neighborhood) == n

## Moore
width = 10
height = 10
grid = OrthogonalMooreGrid((width, height), torus=False, capacity=None)
for radius, n in zip(range(1, 4), [3, 8, 15]):
neighborhood = grid._cells[(0, 0)].neighborhood(radius=radius)
if radius == 1:
neighborhood = grid._cells[(0, 0)].neighborhood
else:
neighborhood = grid._cells[(0, 0)].get_neighborhood(radius=radius)
assert len(neighborhood) == n

with pytest.raises(ValueError):
grid._cells[(0, 0)].neighborhood(radius=0)
grid._cells[(0, 0)].get_neighborhood(radius=0)

# hexgrid
width = 10
height = 10
grid = HexGrid((width, height), torus=False, capacity=None)
for radius, n in zip(range(1, 4), [2, 6, 11]):
neighborhood = grid._cells[(0, 0)].neighborhood(radius=radius)
if radius == 1:
neighborhood = grid._cells[(0, 0)].neighborhood
else:
neighborhood = grid._cells[(0, 0)].get_neighborhood(radius=radius)
assert len(neighborhood) == n

width = 10
height = 10
grid = HexGrid((width, height), torus=False, capacity=None)
for radius, n in zip(range(1, 4), [5, 10, 17]):
neighborhood = grid._cells[(1, 0)].neighborhood(radius=radius)
if radius == 1:
neighborhood = grid._cells[(1, 0)].neighborhood
else:
neighborhood = grid._cells[(1, 0)].get_neighborhood(radius=radius)
assert len(neighborhood) == n

# networkgrid
Expand Down
Loading