Skip to content

Commit

Permalink
(v3.5.2) SB3 Logger bug fix with wandb (#429)
Browse files Browse the repository at this point in the history
* Scripts: Train updated with traceback show (useful for wandb)

* LoggerEvalCallback: Updated eval freq to avoid empty episodes in any envioronment

* Logger: Native SB3 logger updated for no specify step in wandb log

* Version updated from 3.5.1 to 3.5.2
  • Loading branch information
AlejandroCN7 authored Aug 29, 2024
1 parent e2e37b2 commit 616cb43
Show file tree
Hide file tree
Showing 10 changed files with 38 additions and 12 deletions.
2 changes: 2 additions & 0 deletions scripts/train/train_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import sys
from datetime import datetime
import traceback

import gymnasium as gym
import numpy as np
Expand Down Expand Up @@ -343,6 +344,7 @@ def process_algorithm_parameters(alg_params: dict):

except (Exception, KeyboardInterrupt) as err:
print("Error or interruption in process detected")
print(traceback.print_exc(), file=sys.stderr)

# Current model state save
model.save(env.get_wrapper_attr('workspace_path') + '/model')
Expand Down
8 changes: 6 additions & 2 deletions scripts/train/train_agent_A2C.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"stats_window_size": 100,
"tensorboard_log": null,
"policy_kwargs": null,
"verbose": 0,
"verbose": 1,
"seed": null,
"device": "auto",
"_init_setup_model": true
Expand All @@ -42,7 +42,11 @@
"LoggerWrapper": {
"storage_class": "sinergym.utils.logger.LoggerStorage"
},
"CSVLogger": {}
"CSVLogger": {},
"WandBLogger": {
"entity": "sail_ugr",
"project_name": "test_project"
}
},
"evaluation": {
"eval_freq": 2,
Expand Down
8 changes: 6 additions & 2 deletions scripts/train/train_agent_DDPG.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"optimize_memory_usage": false,
"tensorboard_log": null,
"policy_kwargs": null,
"verbose": 0,
"verbose": 1,
"seed": null,
"device": "auto",
"_init_setup_model": true
Expand All @@ -39,7 +39,11 @@
"LoggerWrapper": {
"storage_class": "sinergym.utils.logger.LoggerStorage"
},
"CSVLogger": {}
"CSVLogger": {},
"WandBLogger": {
"entity": "sail_ugr",
"project_name": "test_project"
}
},
"evaluation": {
"eval_freq": 2,
Expand Down
6 changes: 5 additions & 1 deletion scripts/train/train_agent_DQN.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@
"LoggerWrapper": {
"storage_class": "sinergym.utils.logger.LoggerStorage"
},
"CSVLogger": {}
"CSVLogger": {},
"WandBLogger": {
"entity": "sail_ugr",
"project_name": "test_project"
}
},
"evaluation": {
"eval_freq": 2,
Expand Down
6 changes: 5 additions & 1 deletion scripts/train/train_agent_PPO.json
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@
"LoggerWrapper": {
"storage_class": "sinergym.utils.logger.LoggerStorage"
},
"CSVLogger": {}
"CSVLogger": {},
"WandBLogger":{
"entity": "sail_ugr",
"project_name": "test_project"
}
},
"evaluation": {
"eval_freq": 2,
Expand Down
8 changes: 6 additions & 2 deletions scripts/train/train_agent_SAC.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"stats_window_size": 100,
"tensorboard_log": null,
"policy_kwargs": null,
"verbose": 0,
"verbose": 1,
"seed": null,
"device": "auto",
"_init_setup_model": true
Expand All @@ -46,7 +46,11 @@
"LoggerWrapper": {
"storage_class": "sinergym.utils.logger.LoggerStorage"
},
"CSVLogger": {}
"CSVLogger": {},
"WandBLogger": {
"entity": "sail_ugr",
"project_name": "test_project"
}
},
"evaluation": {
"eval_freq": 2,
Expand Down
6 changes: 5 additions & 1 deletion scripts/train/train_agent_TD3.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@
"LoggerWrapper": {
"storage_class": "sinergym.utils.logger.LoggerStorage"
},
"CSVLogger": {}
"CSVLogger": {},
"WandBLogger": {
"entity": "sail_ugr",
"project_name": "test_project"
}
},
"evaluation": {
"eval_freq": 2,
Expand Down
2 changes: 1 addition & 1 deletion sinergym/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
self.train_env = train_env
self.n_eval_episodes = n_eval_episodes
self.eval_freq = eval_freq_episodes * \
train_env.get_wrapper_attr('timestep_per_episode') - 3
train_env.get_wrapper_attr('timestep_per_episode') - 30
self.save_path = self.train_env.get_wrapper_attr(
'workspace_path') + '/evaluation'
# Make dir if not exists
Expand Down
2 changes: 1 addition & 1 deletion sinergym/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,4 @@ def write(

if isinstance(value, np.ScalarType):
if not isinstance(value, str):
wandb.log({key: value}, step=step)
wandb.log({key: value}, commit=False)
2 changes: 1 addition & 1 deletion sinergym/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.5.1
3.5.2

0 comments on commit 616cb43

Please sign in to comment.