Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Aggegrated agent metric in DataCollection, graph in ChartModule #1145

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions mesa/datacollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from operator import attrgetter
import pandas as pd
import types
import statistics


class DataCollector:
Expand Down Expand Up @@ -100,6 +101,8 @@ class attributes of model

self.model_vars = {}
self._agent_records = {}
self.agent_attr_index = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be self.agent_attr_indexes, to make it consistent with existing dict attribute names, e.g. self.model_vars, etc.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought of it as an index between reporter names and their position in the _agent_records. So a single index.

Maybe agent_records_index might be even clearer.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, the usage of the term "an index" as a collection of the mapping is inconsistent with its later use:

        # Get the index of the reporter
        attr_index = self.agent_attr_index[reporter]

Where this is a single index.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have a suggestion for a better name?

self.agent_name_index = {}
self.tables = {}

if model_reporters is not None:
Expand All @@ -108,6 +111,7 @@ class attributes of model

if agent_reporters is not None:
for name, reporter in agent_reporters.items():
self.agent_name_index[name] = reporter
self._new_agent_reporter(name, reporter)

if tables is not None:
Expand Down Expand Up @@ -159,6 +163,7 @@ def _record_agents(self, model):
if all([hasattr(rep, "attribute_name") for rep in rep_funcs]):
prefix = ["model.schedule.steps", "unique_id"]
attributes = [func.attribute_name for func in rep_funcs]
self.agent_attr_index = {k: v for v, k in enumerate(prefix + attributes)}
get_reports = attrgetter(*prefix + attributes)
else:

Expand Down Expand Up @@ -256,3 +261,36 @@ def get_table_dataframe(self, table_name):
if table_name not in self.tables:
raise Exception("No such table.")
return pd.DataFrame(self.tables[table_name])

def get_agent_metric(self, var_name, metric="mean"):
"""Get a single aggegrated value from an agent variable.

Args:
var_name: The name of the variable to aggegrate.
metric: Statistics metric to be used (default: mean)
all functions from built-in statistics module are supported
as well as "min", "max", "sum" and "len"

"""
# Get the reporter from the name
reporter = self.agent_name_index[var_name]

# Get the index of the reporter
attr_index = self.agent_attr_index[reporter]

# Create a dictionary with all attributes from all agents
attr_dict = self._agent_records

# Get the values from all agents in a list
values_tuples = list(attr_dict.values())[-1]

# Get the correct value using the attribute index
values = [value_tuple[attr_index] for value_tuple in values_tuples]

# Calculate the metric among all agents (mean by default)
if metric in ["min", "max", "sum", "len"]:
value = eval(f"{metric}(values)")
else:
stat_function = getattr(statistics, metric)
value = stat_function(values)
return value
18 changes: 14 additions & 4 deletions mesa/visualization/modules/ChartVisualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ class ChartModule(VisualizationElement):
data_collector_name="datacollector")

TODO:
Have it be able to handle agent-level variables as well.
Aggregate agent level variables other than mean (requires ChartModule
API change)

More Pythonic customization; in particular, have both series-level and
chart-level options settable in Python, and passed to the front-end
Expand Down Expand Up @@ -78,9 +79,18 @@ def render(self, model):

for s in self.series:
name = s["Label"]
try:
val = data_collector.model_vars[name][-1] # Latest value
except (IndexError, KeyError):
if name in data_collector.model_vars.keys():
try:
val = data_collector.model_vars[name][-1] # Latest value
except (IndexError, KeyError):
val = 0
elif name in data_collector.agent_name_index.keys():
try:
# Returns mean of latest value of all agents
val = data_collector.get_agent_metric(name)
except (IndexError, KeyError):
val = 0
else:
val = 0
current_values.append(val)
return current_values