Skip to content

Commit

Permalink
Datacollection: Update _record_agenttype for Agent subclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
EwoutH committed Sep 20, 2024
1 parent 90d9ca5 commit ee9d09f
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 21 deletions.
22 changes: 18 additions & 4 deletions mesa/datacollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,24 @@ def get_reports(agent):
reports = tuple(rep(agent) for rep in rep_funcs)
return _prefix + reports

agenttype_records = map(
get_reports,
model.agents_by_type[agent_type],
)
agent_types = model.agent_types
if agent_type in agent_types:
agents = model.agents_by_type[agent_type]
else:
from mesa import Agent

# Check if agent_type is an Agent subclass
if issubclass(agent_type, Agent):
raise NotImplementedError(
f"Agent type {agent_type} is not in model.agent_types. We might implement using superclasses in the future. For now, use one of {agent_types}."
)
else:
# Raise error if agent_type is not in model.agent_types
raise ValueError(
f"Agent type {agent_type} is not recognized as an Agent type in the model. Use one of {agent_types}."
)

agenttype_records = map(get_reports, agents)
return agenttype_records

def collect(self, model):
Expand Down
37 changes: 20 additions & 17 deletions tests/test_datacollector.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,30 +104,25 @@ class MockModelWithAgentTypes(Model):

def __init__(self): # noqa: D107
super().__init__()
self.schedule = BaseScheduler(self)
self.model_val = 100

for i in range(10):
if i % 2 == 0:
self.schedule.add(MockAgentA(self, val=i))
MockAgentA(self, val=i)
else:
self.schedule.add(MockAgentB(self, val=i))
MockAgentB(self, val=i)

self.datacollector = DataCollector(
model_reporters={
"total_agents": lambda m: m.schedule.get_agent_count(),
},
agent_reporters={
"value": lambda a: a.val,
},
model_reporters={"total_agents": lambda m: len(m.agents)},
agent_reporters={"value": lambda a: a.val},
agenttype_reporters={
MockAgentA: {"type_a_val": lambda a: a.type_a_val},
MockAgentB: {"type_b_val": lambda a: a.type_b_val},
},
)

def step(self): # noqa: D102
self.schedule.step()
self.agents.do("step")
self.datacollector.collect(self)


Expand Down Expand Up @@ -324,14 +319,14 @@ class NonExistentAgent(Agent):

def test_agenttype_reporter_string_attribute(self):
"""Test agent-type-specific reporter with string attribute."""
model = MockModel()
model = MockModelWithAgentTypes()
model.datacollector._new_agenttype_reporter(MockAgentA, "string_attr", "val")
model.step()

agent_a_data = model.datacollector.get_agenttype_vars_dataframe(MockAgentA)
self.assertIn("string_attr", agent_a_data.columns)
for (step, agent_id), value in agent_a_data["string_attr"].items():
expected_value = agent_id + 1 # Initial value + 1 step
for (_step, agent_id), value in agent_a_data["string_attr"].items():
expected_value = agent_id
self.assertEqual(value, expected_value)

def test_agenttype_reporter_function_with_params(self):
Expand All @@ -340,21 +335,21 @@ def test_agenttype_reporter_function_with_params(self):
def test_func(agent, multiplier):
return agent.val * multiplier

model = MockModel()
model = MockModelWithAgentTypes()
model.datacollector._new_agenttype_reporter(
MockAgentB, "func_param", [test_func, [2]]
)
model.step()

agent_b_data = model.datacollector.get_agenttype_vars_dataframe(MockAgentB)
self.assertIn("func_param", agent_b_data.columns)
for (step, agent_id), value in agent_b_data["func_param"].items():
expected_value = (agent_id + 1) * 2 # (Initial value + 1 step) * 2
for (_step, agent_id), value in agent_b_data["func_param"].items():
expected_value = agent_id * 2
self.assertEqual(value, expected_value)

def test_agenttype_reporter_multiple_types(self):
"""Test adding reporters for multiple agent types."""
model = MockModel()
model = MockModelWithAgentTypes()
model.datacollector._new_agenttype_reporter(
MockAgentA, "type_a_val", lambda a: a.type_a_val
)
Expand All @@ -371,6 +366,14 @@ def test_agenttype_reporter_multiple_types(self):
self.assertNotIn("type_b_val", agent_a_data.columns)
self.assertNotIn("type_a_val", agent_b_data.columns)

def test_agenttype_reporter_not_in_model(self):
"""Test NotImplementedError is raised when agent type is not in model.agents_by_type."""
model = MockModelWithAgentTypes()
# MockAgent is a legit Agent subclass, but it is not in model.agents_by_type
model.datacollector._new_agenttype_reporter(MockAgent, "val", lambda a: a.val)
with self.assertRaises(NotImplementedError):
model.step()


if __name__ == "__main__":
unittest.main()

0 comments on commit ee9d09f

Please sign in to comment.