Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Collecting data from multiple differentiable classes of agents #1701

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 77 additions & 2 deletions mesa/datacollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@

import pandas as pd

from mesa.agent import Agent


class DataCollector:
"""Class for collecting data generated by a Mesa model.
Expand Down Expand Up @@ -94,18 +96,24 @@ class attributes of a model
"""
self.model_reporters = {}
self.agent_reporters = {}
self.agent_specific_reporters = {}

self.model_vars = {}
self._agent_records = {}
self._agent_specific_records = {}
self.tables = {}

if model_reporters is not None:
for name, reporter in model_reporters.items():
self._new_model_reporter(name, reporter)

if agent_reporters is not None:
for name, reporter in agent_reporters.items():
self._new_agent_reporter(name, reporter)
for k, v in agent_reporters.items():
if isinstance(k, type) and k is not Agent and issubclass(k, Agent):
for name, reporter in v.items():
self._new_agent_specific_reporter(k, name, reporter)
else:
self._new_agent_reporter(k, v)

if tables is not None:
for name, columns in tables.items():
Expand Down Expand Up @@ -136,6 +144,24 @@ def _new_agent_reporter(self, name, reporter):
reporter.attribute_name = attribute_name
self.agent_reporters[name] = reporter

def _new_agent_specific_reporter(self, agent_class, name, reporter):
"""Add a new reporter to collect for specific class of an Agent.
If agent class is Agent then this is function is equivalent to _new_agent_reporter.

Args:
name: Name of the agent-level variable to collect.
agent_class: Class of an agent.
reporter: Attribute string, or function object that returns the
variable when given a model instance.
"""
if type(reporter) is str:
attribute_name = reporter
reporter = partial(self._getattr, reporter)
reporter.attribute_name = attribute_name
if agent_class not in self.agent_specific_reporters:
self.agent_specific_reporters[agent_class] = {}
self.agent_specific_reporters[agent_class][name] = reporter

def _new_table(self, table_name, table_columns):
"""Add a new table that objects can write to.

Expand Down Expand Up @@ -163,6 +189,29 @@ def get_reports(agent):
agent_records = map(get_reports, model.schedule.agents)
return agent_records

def _record_specific_agents(self, model):
"""Record data for agents of each type"""
agent_records = {}
for agent_type in self.agent_specific_reporters:
rep_funcs = self.agent_specific_reporters[agent_type].values()
if all(hasattr(rep, "attribute_name") for rep in rep_funcs):
get_reports = attrgetter(*[func.attribute_name for func in rep_funcs])
else:

def get_reports(agent):
return tuple(rep(agent) for rep in rep_funcs)

agent_records[agent_type] = {}
for agent in model.schedule.agents_type[agent_type]:
report = get_reports(agent)
if not isinstance(report, tuple):
report = [report]
agent_records[agent_type][agent.unique_id] = dict(
zip(self.agent_specific_reporters[agent_type].keys(), report)
)

return agent_records

def collect(self, model):
"""Collect all the data for the given model object."""
if self.model_reporters:
Expand All @@ -181,6 +230,10 @@ def collect(self, model):
else:
self.model_vars[var].append(reporter())

if self.agent_specific_reporters:
agent_records = self._record_specific_agents(model)
self._agent_specific_records[model.schedule.steps] = dict(agent_records)

if self.agent_reporters:
agent_records = self._record_agents(model)
self._agent_records[model.schedule.steps] = list(agent_records)
Expand Down Expand Up @@ -246,6 +299,28 @@ def get_agent_vars_dataframe(self):
)
return df

def get_agent_specific_vars_dataframe(self):
"""Create a pandas DataFrame from the agent variables.

The DataFrame has one column for each variable present in any type of agent, with three additional
columns for tick, agent class and agent_id.
"""
# Check if self.agent_reporters dictionary is empty, if so raise warning
if not self.agent_specific_reporters:
raise UserWarning(
"No agent reporters have been defined in the DataCollector, returning empty DataFrame."
)

return pd.DataFrame.from_dict(
{
(i, j, k): self._agent_specific_records[i][j][k]
for i in self._agent_specific_records
for j in self._agent_specific_records[i]
for k in self._agent_specific_records[i][j]
},
orient="index",
)

def get_table_dataframe(self, table_name):
"""Create a pandas DataFrame from a particular table.

Expand Down
12 changes: 12 additions & 0 deletions mesa/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(self, model: Model) -> None:
self.steps = 0
self.time: TimeT = 0
self._agents: dict[int, Agent] = {}
self._agents_type: dict[type[Agent], dict[int, Agent]] = {}

def add(self, agent: Agent) -> None:
"""Add an Agent object to the schedule.
Expand All @@ -67,6 +68,9 @@ def add(self, agent: Agent) -> None:
)

self._agents[agent.unique_id] = agent
if type(agent) not in self._agents_type:
self._agents_type[type(agent)] = {}
self._agents_type[type(agent)][agent.unique_id] = agent

def remove(self, agent: Agent) -> None:
"""Remove all instances of a given agent from the schedule.
Expand All @@ -75,6 +79,7 @@ def remove(self, agent: Agent) -> None:
agent: An agent object.
"""
del self._agents[agent.unique_id]
del self._agents_type[type(agent)][agent.unique_id]

def step(self) -> None:
"""Execute the step of all the agents, one at a time."""
Expand All @@ -91,6 +96,13 @@ def get_agent_count(self) -> int:
def agents(self) -> list[Agent]:
return list(self._agents.values())

@property
def agents_type(self) -> dict[type[Agent], list[Agent]]:
return {
agent_class: list(agents.values())
for agent_class, agents in self._agents_type.items()
}

def agent_buffer(self, shuffled: bool = False) -> Iterator[Agent]:
"""Simple generator that yields the agents while letting the user
remove and/or add agents during stepping.
Expand Down
38 changes: 34 additions & 4 deletions tests/test_datacollector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,35 @@ def __init__(self, unique_id, model, val=0):
super().__init__(unique_id, model)
self.val = val
self.val2 = val
self.val3 = val

def step(self):
"""
Increment vals by 1.
"""
self.val += 1
self.val2 += 1
self.val3 += 1

def write_final_values(self):
"""
Write the final value to the appropriate table.
"""
row = {"agent_id": self.unique_id, "final_value": self.val}
self.model.datacollector.add_table_row("Final_Values", row)


class DifferentMockAgent(Agent):
"""
Minimalistic agent for testing purposes.
"""

def __init__(self, unique_id, model, val=0):
super().__init__(unique_id, model)
self.val = val
self.val2 = val
self.val4 = val
self.val5 = val

def write_final_values(self):
"""
Expand All @@ -43,9 +65,10 @@ def __init__(self):
self.schedule = BaseScheduler(self)
self.model_val = 100

for i in range(10):
a = MockAgent(i, self, val=i)
self.schedule.add(a)
n = 5
for i in range(n):
self.schedule.add(MockAgent(i, self, val=i))
self.schedule.add(DifferentMockAgent(n + i, self, val=i))
self.initialize_data_collector(
{
"total_agents": lambda m: m.schedule.get_agent_count(),
Expand All @@ -54,7 +77,12 @@ def __init__(self):
"model_calc_comp": [self.test_model_calc_comp, [3, 4]],
"model_calc_fail": [self.test_model_calc_comp, [12, 0]],
},
{"value": lambda a: a.val, "value2": "val2"},
{
"value": lambda a: a.val,
"value2": "val2",
MockAgent: {"value3": "val3"},
DifferentMockAgent: {"value4": "val4", "value5": lambda a: a.val5},
},
{"Final_Values": ["agent_id", "final_value"]},
)

Expand Down Expand Up @@ -164,9 +192,11 @@ def test_exports(self):
data_collector = self.model.datacollector
model_vars = data_collector.get_model_vars_dataframe()
agent_vars = data_collector.get_agent_vars_dataframe()
specific_agent_vars = data_collector.get_agent_specific_vars_dataframe()
table_df = data_collector.get_table_dataframe("Final_Values")
assert model_vars.shape == (8, 5)
assert agent_vars.shape == (77, 2)
assert specific_agent_vars.shape == (77, 3)
assert table_df.shape == (9, 2)

with self.assertRaises(Exception):
Expand Down