Skip to content

Commit

Permalink
refactor(src): ♻️ refactored final data reporters
Browse files Browse the repository at this point in the history
  • Loading branch information
SongshGeo committed Jun 4, 2024
1 parent b25d6fa commit f0f9603
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 271 deletions.
30 changes: 15 additions & 15 deletions abses/_bases/datacollector.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@

import numpy as np
import pandas as pd

from abses._bases.errors import ABSESpyError
from loguru import logger

if TYPE_CHECKING:
from abses.actor import Actor
Expand Down Expand Up @@ -88,7 +87,6 @@ def __init__(self, reports: Dict[ReportType, Dict[str, Reporter]]):

self._agent_records: Dict[str, List[pd.DataFrame]] = {}
self.model_vars: Dict[str, List[Any]] = {}
self.final_vars: Dict[str, List[Any]] = {}

self.add_reporters("model", reports.get("model", {}))
self.add_reporters("agents", reports.get("agents", {}))
Expand All @@ -109,7 +107,7 @@ def add_reporters(
return
if item == "final":
for name, reporter in reporters.items():
self._new_final_reporter(name, reporter)
self.final_reporters[name] = clean_to_reporter(reporter)
return
if item == "agents":
for breed, tmp_reporters in reporters.items():
Expand All @@ -133,10 +131,6 @@ def _new_model_reporter(self, name: str, reporter: Reporter) -> None:
self.model_reporters[name] = clean_to_reporter(reporter)
self.model_vars[name] = []

def _new_final_reporter(self, name: str, reporter: Reporter) -> None:
self.final_reporters[name] = clean_to_reporter(reporter)
self.final_vars[name] = []

def _record_a_breed_of_agents(
self, time: TimeDriver, breed: str, agents: ActorsList[Actor]
) -> None:
Expand Down Expand Up @@ -178,8 +172,9 @@ def get_model_vars_dataframe(self):
"""
# Check if self.model_reporters dictionary is empty, if so raise warning
if not self.model_reporters:
raise UserWarning(
"No model reporters have been defined in the DataCollector, returning empty DataFrame."
logger.warning(
"No model reporters have been defined"
"returning empty DataFrame."
)

return pd.DataFrame(self.model_vars)
Expand All @@ -194,23 +189,28 @@ def get_agent_vars_dataframe(
for breed in self.agent_reporters
}
if not self.agent_reporters:
raise ABSESpyError(
logger.warning(
"No agent reporters have been defined in the DataCollector."
)
if results := self._agent_records.get(breed):
return pd.concat([pd.DataFrame(res) for res in results])
return pd.DataFrame()

def get_final_vars_report(self, model: MainModel) -> pd.DataFrame:
"""Report at the end of this model."""
if not self.final_reporters:
logger.warning(
"No final reporters have been defined"
"returning empty DataFrame."
)
return {var: func(model) for var, func in self.final_reporters.items()}

def collect(self, model: MainModel):
"""Collect all the data for the given model object."""

if self.model_reporters:
for var, func in self.model_reporters.items():
self.model_vars[var].append(func(model))

if self.final_reporters:
for var, func in self.model_reporters.items():
self.final_vars[var].append(func(model))

if self.agent_reporters:
self._record_agents(model)
18 changes: 13 additions & 5 deletions abses/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,15 +241,22 @@ def _get_logging_mode(self, repeat_id: Optional[int] = None) -> str | bool:
logging = False
return logging

def _update_log_config(self, config, repeat_id: Optional[int] = None):
"""Update the log configuration."""
if isinstance(config, dict):
config = DictConfig(config)
log_name = self._get_logging_mode(repeat_id=repeat_id)
OmegaConf.set_struct(config, False)
logging_cfg = OmegaConf.create({"log": {"name": log_name}})
config = OmegaConf.merge(config, logging_cfg)
return config

def run(
self, cfg: DictConfig, repeat_id: int, outpath: Optional[Path] = None
) -> Tuple[Dict[str, Any], pd.DataFrame]:
"""运行模型一次"""
self._update_log_config(cfg, repeat_id)
# 获取日志
log_name = self._get_logging_mode(repeat_id=repeat_id)
OmegaConf.set_struct(cfg, False)
logging_cfg = OmegaConf.create({"log": {"name": log_name}})
cfg = OmegaConf.merge(cfg, logging_cfg)
model = self.model(
parameters=cfg,
run_id=repeat_id,
Expand All @@ -263,7 +270,8 @@ def run(
df = model.datacollector.get_model_vars_dataframe()
else:
df = pd.DataFrame()
return model.final_report(), df
final_report = model.datacollector.get_final_vars_report(model)
return final_report, df

def _update_result(
self,
Expand Down
30 changes: 8 additions & 22 deletions abses/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import functools
import json
import os
import types
from datetime import datetime
from pathlib import Path
from typing import (
Expand All @@ -41,7 +40,7 @@

from mesa import Model
from mesa.time import BaseScheduler
from omegaconf import DictConfig, OmegaConf
from omegaconf import DictConfig

from abses import __version__
from abses._bases.logging import (
Expand Down Expand Up @@ -378,29 +377,16 @@ def _step(self) -> None:
def _end(self) -> None:
self._do_each("end", order=("nature", "human", "model"))
self._do_each("set_state", code=3)
# result = self.final_report()
# msg = (
# "The model is ended.\n"
# f"Total ticks: {self.time.tick}\n"
# f"Final result: {json.dumps(result, indent=4)}\n"
# )
# log_session(title="Ending Run", msg=msg)
result = self.datacollector.get_final_vars_report(self)
msg = (
"The model is ended.\n"
f"Total ticks: {self.time.tick}\n"
f"Final result: {json.dumps(result, indent=4)}\n"
)
log_session(title="Ending Run", msg=msg)
logger.bind(no_format=True).info(f"{datetime.now()}\n\n\n")
logger.remove()

# def final_report(self) -> Dict[str, Any]:
# """Report at the end of this model."""
# result = {}
# for k, reporter in self._reports["final"].items():
# if isinstance(reporter, str):
# value = getattr(self, reporter)
# elif isinstance(reporter, types.FunctionType):
# value = reporter(self)
# else:
# raise TypeError(f"Invalid final reporter {type(reporter)}.")
# result[k] = value
# return result

def summary(self, verbose: bool = False) -> pd.DataFrame:
"""Report the state of the model."""
print(f"Using ABSESpy version: {self.version}")
Expand Down
Loading

0 comments on commit f0f9603

Please sign in to comment.