Skip to content

Commit

Permalink
Allow all Agent subclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
EwoutH committed Sep 20, 2024
1 parent ee9d09f commit f0c7f0d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
9 changes: 4 additions & 5 deletions mesa/datacollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,15 +256,14 @@ def get_reports(agent):
else:
from mesa import Agent

# Check if agent_type is an Agent subclass
if issubclass(agent_type, Agent):
raise NotImplementedError(
f"Agent type {agent_type} is not in model.agent_types. We might implement using superclasses in the future. For now, use one of {agent_types}."
)
agents = [
agent for agent in model.agents if isinstance(agent, agent_type)
]
else:
# Raise error if agent_type is not in model.agent_types
raise ValueError(

Check warning on line 265 in mesa/datacollection.py

View check run for this annotation

Codecov / codecov/patch

mesa/datacollection.py#L265

Added line #L265 was not covered by tests
f"Agent type {agent_type} is not recognized as an Agent type in the model. Use one of {agent_types}."
f"Agent type {agent_type} is not recognized as an Agent type in the model or Agent subclass. Use an Agent (sub)class, like {agent_types}."
)

agenttype_records = map(get_reports, agents)
Expand Down
16 changes: 12 additions & 4 deletions tests/test_datacollector.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,22 @@ def test_agenttype_reporter_multiple_types(self):
self.assertNotIn("type_b_val", agent_a_data.columns)
self.assertNotIn("type_a_val", agent_b_data.columns)

def test_agenttype_reporter_not_in_model(self):
"""Test NotImplementedError is raised when agent type is not in model.agents_by_type."""
def test_agenttype_superclass_reporter(self):
"""Test adding a reporter for a superclass of an agent type."""
model = MockModelWithAgentTypes()
# MockAgent is a legit Agent subclass, but it is not in model.agents_by_type
model.datacollector._new_agenttype_reporter(MockAgent, "val", lambda a: a.val)
with self.assertRaises(NotImplementedError):
model.datacollector._new_agenttype_reporter(Agent, "val", lambda a: a.val)
for _ in range(3):
model.step()

super_data = model.datacollector.get_agenttype_vars_dataframe(MockAgent)
agent_data = model.datacollector.get_agenttype_vars_dataframe(Agent)
self.assertIn("val", super_data.columns)
self.assertIn("val", agent_data.columns)
self.assertEqual(len(super_data), 30) # 10 agents * 3 steps
self.assertEqual(len(agent_data), 30)
self.assertTrue(super_data.equals(agent_data))


if __name__ == "__main__":
unittest.main()

0 comments on commit f0c7f0d

Please sign in to comment.