Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AgentSet.groupby #2220

Merged
merged 30 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
d90b0f7
further updates
quaquel Jan 21, 2024
9586490
Update benchmarks/WolfSheep/__init__.py
quaquel Jan 21, 2024
4aaa35d
Merge remote-tracking branch 'upstream/main'
quaquel Feb 21, 2024
d31478c
Merge remote-tracking branch 'upstream/main'
quaquel Apr 22, 2024
6e4c72e
Merge remote-tracking branch 'upstream/main'
quaquel Aug 14, 2024
70fbaf5
Merge remote-tracking branch 'upstream/main'
quaquel Aug 18, 2024
724c8db
Merge remote-tracking branch 'upstream/main'
quaquel Aug 21, 2024
96cd80a
add groupby
quaquel Aug 17, 2024
07467f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 17, 2024
b04099f
Update agent.py
quaquel Aug 18, 2024
6faaf9e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2024
7eed1db
add tests and simplify groupby objects
quaquel Aug 19, 2024
e9be4cc
Update test_agent.py
quaquel Aug 19, 2024
6a329bd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2024
16d75c5
ruff fixes
quaquel Aug 19, 2024
e09f5a2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2024
a4d9559
additional docs
quaquel Aug 19, 2024
7db02b8
add test for callable
quaquel Aug 19, 2024
8aa56da
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2024
943a907
rename groupby and additional docs
quaquel Aug 20, 2024
9fc7963
Update agent.py
quaquel Aug 20, 2024
c481281
add str option to GroupBy.apply
quaquel Aug 21, 2024
585025b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 21, 2024
f0dccf6
remove unnecessary list comprehension
quaquel Aug 21, 2024
eaf0a5b
seperate apply and do
quaquel Aug 21, 2024
7c3a78f
remove unintened update to __init__.py for wolfsheep benchmark
quaquel Aug 21, 2024
a25dd06
additional tests
quaquel Aug 21, 2024
5544cba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 21, 2024
e075bcb
Change apply to map
quaquel Aug 22, 2024
c72aa37
also update tests
quaquel Aug 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 113 additions & 3 deletions mesa/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import copy
import operator
import weakref
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, MutableSet, Sequence
from random import Random

Expand Down Expand Up @@ -357,7 +358,116 @@ def random(self) -> Random:
"""
return self.model.random

def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy:
"""
Group agents by the specified attribute or return from the callable

Args:
by (Callable, str): used to determine what to group agents by

* if ``by`` is a callable, it will be called for each agent and the return is used
for grouping
* if ``by`` is a str, it should refer to an attribute on the agent and the value
of this attribute will be used for grouping
result_type (str, optional): The datatype for the resulting groups {"agentset", "list"}
Returns:
GroupBy


Notes:
There might be performance benefits to using `result_type='list'` if you don't need the advanced functionality
of an AgentSet.

"""
groups = defaultdict(list)

if isinstance(by, Callable):
for agent in self:
groups[by(agent)].append(agent)
else:
for agent in self:
groups[getattr(agent, by)].append(agent)

if result_type == "agentset":
return GroupBy(
{k: AgentSet(v, model=self.model) for k, v in groups.items()}
)
else:
return GroupBy(groups)

# consider adding for performance reasons
# for Sequence: __reversed__, index, and count
# for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__


class GroupBy:
"""Helper class for AgentSet.groupby


Attributes:
groups (dict): A dictionary with the group_name as key and group as values

"""

def __init__(self, groups: dict[Any, list | AgentSet]):
self.groups: dict[Any, list | AgentSet] = groups

def map(self, method: Callable | str, *args, **kwargs) -> dict[Any, Any]:
"""Apply the specified callable to each group and return the results.

Args:
method (Callable, str): The callable to apply to each group,

* if ``method`` is a callable, it will be called it will be called with the group as first argument
* if ``method`` is a str, it should refer to a method on the group

Additional arguments and keyword arguments will be passed on to the callable.

Returns:
dict with group_name as key and the return of the method as value

Notes:
this method is useful for methods or functions that do return something. It
will break method chaining. For that, use ``do`` instead.

"""
if isinstance(method, str):
return {
k: getattr(v, method)(*args, **kwargs) for k, v in self.groups.items()
}
else:
return {k: method(v, *args, **kwargs) for k, v in self.groups.items()}

def do(self, method: Callable | str, *args, **kwargs) -> GroupBy:
"""Apply the specified callable to each group

Args:
method (Callable, str): The callable to apply to each group,

* if ``method`` is a callable, it will be called it will be called with the group as first argument
* if ``method`` is a str, it should refer to a method on the group

Additional arguments and keyword arguments will be passed on to the callable.

Returns:
the original GroupBy instance

Notes:
this method is useful for methods or functions that don't return anything and/or
if you want to chain multiple do calls

"""
if isinstance(method, str):
for v in self.groups.values():
getattr(v, method)(*args, **kwargs)
else:
for v in self.groups.values():
method(v, *args, **kwargs)

return self

def __iter__(self):
return iter(self.groups.items())

# consider adding for performance reasons
# for Sequence: __reversed__, index, and count
# for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__
def __len__(self):
return len(self.groups)
42 changes: 42 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,3 +348,45 @@ def test_agentset_shuffle():
agentset = AgentSet(test_agents, model=model)
agentset.shuffle(inplace=True)
assert not all(a1 == a2 for a1, a2 in zip(test_agents, agentset))


def test_agentset_groupby():
class TestAgent(Agent):
def __init__(self, unique_id, model):
super().__init__(unique_id, model)
self.even = self.unique_id % 2 == 0

def get_unique_identifier(self):
return self.unique_id

model = Model()
agents = [TestAgent(model.next_id(), model) for _ in range(10)]
agentset = AgentSet(agents, model)

groups = agentset.groupby("even")
assert len(groups.groups[True]) == 5
assert len(groups.groups[False]) == 5

groups = agentset.groupby(lambda a: a.unique_id % 2 == 0)
assert len(groups.groups[True]) == 5
assert len(groups.groups[False]) == 5
assert len(groups) == 2

for group_name, group in groups:
assert len(group) == 5
assert group_name in {True, False}

sizes = agentset.groupby("even", result_type="list").map(len)
assert sizes == {True: 5, False: 5}

attributes = agentset.groupby("even", result_type="agentset").map("get", "even")
for group_name, group in attributes.items():
assert all(group_name == entry for entry in group)

groups = agentset.groupby("even", result_type="agentset")
another_ref_to_groups = groups.do("do", "step")
assert groups == another_ref_to_groups

groups = agentset.groupby("even", result_type="agentset")
another_ref_to_groups = groups.do(lambda x: x.do("step"))
assert groups == another_ref_to_groups