From f0c7f0dfc384fa8e6246c537206d52bfcceadca1 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Fri, 20 Sep 2024 10:17:02 +0200 Subject: [PATCH] Allow all Agent subclasses --- mesa/datacollection.py | 9 ++++----- tests/test_datacollector.py | 16 ++++++++++++---- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/mesa/datacollection.py b/mesa/datacollection.py index 2156fb462cf..fb9333da898 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -256,15 +256,14 @@ def get_reports(agent): else: from mesa import Agent - # Check if agent_type is an Agent subclass if issubclass(agent_type, Agent): - raise NotImplementedError( - f"Agent type {agent_type} is not in model.agent_types. We might implement using superclasses in the future. For now, use one of {agent_types}." - ) + agents = [ + agent for agent in model.agents if isinstance(agent, agent_type) + ] else: # Raise error if agent_type is not in model.agent_types raise ValueError( - f"Agent type {agent_type} is not recognized as an Agent type in the model. Use one of {agent_types}." + f"Agent type {agent_type} is not recognized as an Agent type in the model or Agent subclass. Use an Agent (sub)class, like {agent_types}." ) agenttype_records = map(get_reports, agents) diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index e8aba12372c..b2760fc1b44 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -366,14 +366,22 @@ def test_agenttype_reporter_multiple_types(self): self.assertNotIn("type_b_val", agent_a_data.columns) self.assertNotIn("type_a_val", agent_b_data.columns) - def test_agenttype_reporter_not_in_model(self): - """Test NotImplementedError is raised when agent type is not in model.agents_by_type.""" + def test_agenttype_superclass_reporter(self): + """Test adding a reporter for a superclass of an agent type.""" model = MockModelWithAgentTypes() - # MockAgent is a legit Agent subclass, but it is not in model.agents_by_type model.datacollector._new_agenttype_reporter(MockAgent, "val", lambda a: a.val) - with self.assertRaises(NotImplementedError): + model.datacollector._new_agenttype_reporter(Agent, "val", lambda a: a.val) + for _ in range(3): model.step() + super_data = model.datacollector.get_agenttype_vars_dataframe(MockAgent) + agent_data = model.datacollector.get_agenttype_vars_dataframe(Agent) + self.assertIn("val", super_data.columns) + self.assertIn("val", agent_data.columns) + self.assertEqual(len(super_data), 30) # 10 agents * 3 steps + self.assertEqual(len(agent_data), 30) + self.assertTrue(super_data.equals(agent_data)) + if __name__ == "__main__": unittest.main()