diff --git a/mesa/datacollection.py b/mesa/datacollection.py index fcaedc7c8c8..f1c238a559d 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -41,6 +41,8 @@ import pandas as pd +from mesa.agent import Agent + class DataCollector: """Class for collecting data generated by a Mesa model. @@ -94,9 +96,11 @@ class attributes of a model """ self.model_reporters = {} self.agent_reporters = {} + self.agent_specific_reporters = {} self.model_vars = {} self._agent_records = {} + self._agent_specific_records = {} self.tables = {} if model_reporters is not None: @@ -104,8 +108,12 @@ class attributes of a model self._new_model_reporter(name, reporter) if agent_reporters is not None: - for name, reporter in agent_reporters.items(): - self._new_agent_reporter(name, reporter) + for k, v in agent_reporters.items(): + if isinstance(k, type) and k is not Agent and issubclass(k, Agent): + for name, reporter in v.items(): + self._new_agent_specific_reporter(k, name, reporter) + else: + self._new_agent_reporter(k, v) if tables is not None: for name, columns in tables.items(): @@ -136,6 +144,24 @@ def _new_agent_reporter(self, name, reporter): reporter.attribute_name = attribute_name self.agent_reporters[name] = reporter + def _new_agent_specific_reporter(self, agent_class, name, reporter): + """Add a new reporter to collect for specific class of an Agent. + If agent class is Agent then this is function is equivalent to _new_agent_reporter. + + Args: + name: Name of the agent-level variable to collect. + agent_class: Class of an agent. + reporter: Attribute string, or function object that returns the + variable when given a model instance. + """ + if type(reporter) is str: + attribute_name = reporter + reporter = partial(self._getattr, reporter) + reporter.attribute_name = attribute_name + if agent_class not in self.agent_specific_reporters: + self.agent_specific_reporters[agent_class] = {} + self.agent_specific_reporters[agent_class][name] = reporter + def _new_table(self, table_name, table_columns): """Add a new table that objects can write to. @@ -163,6 +189,29 @@ def get_reports(agent): agent_records = map(get_reports, model.schedule.agents) return agent_records + def _record_specific_agents(self, model): + """Record data for agents of each type""" + agent_records = {} + for agent_type in self.agent_specific_reporters: + rep_funcs = self.agent_specific_reporters[agent_type].values() + if all(hasattr(rep, "attribute_name") for rep in rep_funcs): + get_reports = attrgetter(*[func.attribute_name for func in rep_funcs]) + else: + + def get_reports(agent): + return tuple(rep(agent) for rep in rep_funcs) + + agent_records[agent_type] = {} + for agent in model.schedule.agents_type[agent_type]: + report = get_reports(agent) + if not isinstance(report, tuple): + report = [report] + agent_records[agent_type][agent.unique_id] = dict( + zip(self.agent_specific_reporters[agent_type].keys(), report) + ) + + return agent_records + def collect(self, model): """Collect all the data for the given model object.""" if self.model_reporters: @@ -181,6 +230,10 @@ def collect(self, model): else: self.model_vars[var].append(reporter()) + if self.agent_specific_reporters: + agent_records = self._record_specific_agents(model) + self._agent_specific_records[model.schedule.steps] = dict(agent_records) + if self.agent_reporters: agent_records = self._record_agents(model) self._agent_records[model.schedule.steps] = list(agent_records) @@ -246,6 +299,28 @@ def get_agent_vars_dataframe(self): ) return df + def get_agent_specific_vars_dataframe(self): + """Create a pandas DataFrame from the agent variables. + + The DataFrame has one column for each variable present in any type of agent, with three additional + columns for tick, agent class and agent_id. + """ + # Check if self.agent_reporters dictionary is empty, if so raise warning + if not self.agent_specific_reporters: + raise UserWarning( + "No agent reporters have been defined in the DataCollector, returning empty DataFrame." + ) + + return pd.DataFrame.from_dict( + { + (i, j, k): self._agent_specific_records[i][j][k] + for i in self._agent_specific_records + for j in self._agent_specific_records[i] + for k in self._agent_specific_records[i][j] + }, + orient="index", + ) + def get_table_dataframe(self, table_name): """Create a pandas DataFrame from a particular table. diff --git a/mesa/time.py b/mesa/time.py index b079e1439f0..2b3c01d00b7 100644 --- a/mesa/time.py +++ b/mesa/time.py @@ -53,6 +53,7 @@ def __init__(self, model: Model) -> None: self.steps = 0 self.time: TimeT = 0 self._agents: dict[int, Agent] = {} + self._agents_type: dict[type[Agent], dict[int, Agent]] = {} def add(self, agent: Agent) -> None: """Add an Agent object to the schedule. @@ -67,6 +68,9 @@ def add(self, agent: Agent) -> None: ) self._agents[agent.unique_id] = agent + if type(agent) not in self._agents_type: + self._agents_type[type(agent)] = {} + self._agents_type[type(agent)][agent.unique_id] = agent def remove(self, agent: Agent) -> None: """Remove all instances of a given agent from the schedule. @@ -75,6 +79,7 @@ def remove(self, agent: Agent) -> None: agent: An agent object. """ del self._agents[agent.unique_id] + del self._agents_type[type(agent)][agent.unique_id] def step(self) -> None: """Execute the step of all the agents, one at a time.""" @@ -91,6 +96,13 @@ def get_agent_count(self) -> int: def agents(self) -> list[Agent]: return list(self._agents.values()) + @property + def agents_type(self) -> dict[type[Agent], list[Agent]]: + return { + agent_class: list(agents.values()) + for agent_class, agents in self._agents_type.items() + } + def agent_buffer(self, shuffled: bool = False) -> Iterator[Agent]: """Simple generator that yields the agents while letting the user remove and/or add agents during stepping. diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index 7ba72c73df5..0772bcf9ddd 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -16,6 +16,7 @@ def __init__(self, unique_id, model, val=0): super().__init__(unique_id, model) self.val = val self.val2 = val + self.val3 = val def step(self): """ @@ -23,6 +24,27 @@ def step(self): """ self.val += 1 self.val2 += 1 + self.val3 += 1 + + def write_final_values(self): + """ + Write the final value to the appropriate table. + """ + row = {"agent_id": self.unique_id, "final_value": self.val} + self.model.datacollector.add_table_row("Final_Values", row) + + +class DifferentMockAgent(Agent): + """ + Minimalistic agent for testing purposes. + """ + + def __init__(self, unique_id, model, val=0): + super().__init__(unique_id, model) + self.val = val + self.val2 = val + self.val4 = val + self.val5 = val def write_final_values(self): """ @@ -43,9 +65,10 @@ def __init__(self): self.schedule = BaseScheduler(self) self.model_val = 100 - for i in range(10): - a = MockAgent(i, self, val=i) - self.schedule.add(a) + n = 5 + for i in range(n): + self.schedule.add(MockAgent(i, self, val=i)) + self.schedule.add(DifferentMockAgent(n + i, self, val=i)) self.initialize_data_collector( { "total_agents": lambda m: m.schedule.get_agent_count(), @@ -54,7 +77,12 @@ def __init__(self): "model_calc_comp": [self.test_model_calc_comp, [3, 4]], "model_calc_fail": [self.test_model_calc_comp, [12, 0]], }, - {"value": lambda a: a.val, "value2": "val2"}, + { + "value": lambda a: a.val, + "value2": "val2", + MockAgent: {"value3": "val3"}, + DifferentMockAgent: {"value4": "val4", "value5": lambda a: a.val5}, + }, {"Final_Values": ["agent_id", "final_value"]}, ) @@ -164,9 +192,11 @@ def test_exports(self): data_collector = self.model.datacollector model_vars = data_collector.get_model_vars_dataframe() agent_vars = data_collector.get_agent_vars_dataframe() + specific_agent_vars = data_collector.get_agent_specific_vars_dataframe() table_df = data_collector.get_table_dataframe("Final_Values") assert model_vars.shape == (8, 5) assert agent_vars.shape == (77, 2) + assert specific_agent_vars.shape == (77, 3) assert table_df.shape == (9, 2) with self.assertRaises(Exception):