From eaf0a5bf4a70a2684aedbf624d46dbb85b9ffcf3 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Wed, 21 Aug 2024 15:46:21 +0200 Subject: [PATCH] seperate apply and do --- mesa/agent.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/mesa/agent.py b/mesa/agent.py index d46b51bb835..06e38946b93 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -401,7 +401,13 @@ def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy: class GroupBy: - """Helper class for AgentSet.groupby""" + """Helper class for AgentSet.groupby + + + Attributes: + groups (dict): A dictionary with the group_name as key and group as values + + """ def __init__(self, groups: dict[Any, list | AgentSet]): self.groups: dict[Any, list | AgentSet] = groups @@ -436,5 +442,32 @@ def apply(self, method: Callable | str, *args, **kwargs) -> dict[Any, Any]: else: return {k: method(v, *args, **kwargs) for k, v in self.groups.items()} + def do(self, method: Callable | str, *args, **kwargs) -> GroupBy: + """Apply the specified callable to each group + + Args: + method (Callable, str): The callable to apply to each group, + + * if ``method`` is a callable, it will be called it will be called with the group as first argument + * if ``method`` is a str, it should refer to a method on the group + + Additional arguments and keyword arguments will be passed on to the callable. + + Returns: + GroupBy + + """ + if isinstance(method, str): + for v in self.groups.values(): + getattr(v, method)(*args, **kwargs) + else: + for v in self.groups.values(): + method(v, *args, **kwargs) + + return self + def __iter__(self): return iter(self.groups.items()) + + def __len__(self): + return len(self.groups)