Skip to content

Commit

Permalink
GroupBy: Add count and agg methods (#2290)
Browse files Browse the repository at this point in the history
Added two new methods to the `GroupBy` class to enhance aggregation and group operations:

- `count`: Returns the count of agents in each group.
- `agg`: Performs aggregation on a specific attribute across groups, applying a function like `sum`, `min`, `max`, or `mean`.

These methods improve flexibility in applying both group-level and attribute-specific operations.
  • Loading branch information
EwoutH authored Sep 20, 2024
1 parent 9179edb commit 84881e7
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 1 deletion.
25 changes: 24 additions & 1 deletion mesa/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import warnings
import weakref
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, MutableSet, Sequence
from collections.abc import Callable, Hashable, Iterable, Iterator, MutableSet, Sequence
from random import Random

# mypy
Expand Down Expand Up @@ -611,6 +611,29 @@ def do(self, method: Callable | str, *args, **kwargs) -> GroupBy:

return self

def count(self) -> dict[Any, int]:
"""Return the count of agents in each group.
Returns:
dict: A dictionary mapping group names to the number of agents in each group.
"""
return {k: len(v) for k, v in self.groups.items()}

def agg(self, attr_name: str, func: Callable) -> dict[Hashable, Any]:
"""Aggregate the values of a specific attribute across each group using the provided function.
Args:
attr_name (str): The name of the attribute to aggregate.
func (Callable): The function to apply (e.g., sum, min, max, mean).
Returns:
dict[Hashable, Any]: A dictionary mapping group names to the result of applying the aggregation function.
"""
return {
group_name: func([getattr(agent, attr_name) for agent in group])
for group_name, group in self.groups.items()
}

def __iter__(self): # noqa: D105
return iter(self.groups.items())

Expand Down
32 changes: 32 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ class TestAgent(Agent):
def __init__(self, model):
super().__init__(model)
self.even = self.unique_id % 2 == 0
self.value = self.unique_id * 10

def get_unique_identifier(self):
return self.unique_id
Expand Down Expand Up @@ -560,6 +561,37 @@ def get_unique_identifier(self):
another_ref_to_groups = groups.do(lambda x: x.do("step"))
assert groups == another_ref_to_groups

# New tests for count() method
groups = agentset.groupby("even")
count_result = groups.count()
assert count_result == {True: 5, False: 5}

# New tests for agg() method
groups = agentset.groupby("even")
sum_result = groups.agg("value", sum)
assert sum_result[True] == sum(agent.value for agent in agents if agent.even)
assert sum_result[False] == sum(agent.value for agent in agents if not agent.even)

max_result = groups.agg("value", max)
assert max_result[True] == max(agent.value for agent in agents if agent.even)
assert max_result[False] == max(agent.value for agent in agents if not agent.even)

min_result = groups.agg("value", min)
assert min_result[True] == min(agent.value for agent in agents if agent.even)
assert min_result[False] == min(agent.value for agent in agents if not agent.even)

# Test with a custom aggregation function
def custom_agg(values):
return sum(values) / len(values) if values else 0

custom_result = groups.agg("value", custom_agg)
assert custom_result[True] == custom_agg(
[agent.value for agent in agents if agent.even]
)
assert custom_result[False] == custom_agg(
[agent.value for agent in agents if not agent.even]
)


def test_oldstyle_agent_instantiation():
"""Old behavior of Agent creation with unique_id and model as positional arguments.
Expand Down

0 comments on commit 84881e7

Please sign in to comment.