Commit
Improve hyperparams
amacati committed Jan 26, 2025
1 parent 38ab0f8 commit d0852f1
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions tutorials/ppo/train.py
@@ -4,7 +4,6 @@
 https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_continuous_action.py
 """
 
-
 import logging
 import random
 import time
@@ -99,6 +98,8 @@ def sync_envs(train_envs, eval_envs):
     eval_norm_env = unwrap_norm_env(eval_envs)
     if (train_norm_env is None) != (eval_norm_env is None):
         raise ValueError("Both envs must either have normalization or not have normalization")
+    if train_norm_env is None and eval_norm_env is None:  # No normalization, no sync necessary
+        return
     eval_norm_env.obs_rms.mean = train_norm_env.obs_rms.mean
     eval_norm_env.obs_rms.var = train_norm_env.obs_rms.var

@@ -161,11 +162,10 @@ def train_ppo(config: ConfigDict, wandb_log: bool = False):
         wandb.login(key=wandb_api_key)
         wandb.init(project="crazyflow-ppo", config=None)
         config.update(wandb.config)
+        if config.get("n_train_samples"):
+            config.n_steps = config.n_train_samples // config.n_envs
         wandb.config.update(dict(config))
 
-    if config.get("n_train_samples"):
-        config.n_steps = config.n_train_samples // config.n_envs
-
     set_seeds(config.seed)
     train_envs, eval_envs = make_envs(config.n_envs, config.n_eval_envs, config.device)

@@ -375,25 +375,27 @@ def train_ppo(config: ConfigDict, wandb_log: bool = False):
     if config.save_model:
         save_model(agent, optimizer, train_envs, Path(__file__).parent / "ppo_checkpoint.pt")
 
+    plot_results(train_rewards_hist, train_rewards_steps, eval_rewards_hist, eval_rewards_steps)
+
 
 if __name__ == "__main__":
     config = ConfigDict(
         {
             "n_envs": 32,
             "device": "cuda",
-            "total_timesteps": 4_000_000,
-            "learning_rate": 3e-4,
-            "n_steps": 2048,  # Number of steps per environment per policy rollout
-            "gamma": 0.99,  # Discount factor
-            "gae_lambda": 0.95,  # Lambda for general advantage estimation
-            "n_minibatches": 32,  # Number of mini-batches
-            "n_epochs": 10,
+            "total_timesteps": 2_000_000,
+            "learning_rate": 5e-3,
+            "n_steps": 1024,  # Number of steps per environment per policy rollout
+            "gamma": 0.90,  # Discount factor
+            "gae_lambda": 0.90,  # Lambda for general advantage estimation
+            "n_minibatches": 8,  # Number of mini-batches
+            "n_epochs": 15,
             "norm_adv": True,
-            "clip_coef": 0.2,
+            "clip_coef": 0.25,
             "clip_vloss": True,
             "ent_coef": 0.0,
             "vf_coef": 0.5,
-            "max_grad_norm": 0.5,
+            "max_grad_norm": 5.0,
             "target_kl": None,
             "seed": 0,
             "n_eval_envs": 64,
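One more note on the relocated n_train_samples override above: it now runs before wandb.config.update(dict(config)), so a sweep-supplied sample budget is reflected in the n_steps value that gets logged. The derivation itself is plain integer division; a hypothetical example (the budget here is made up, not a value from the repo):

# Hypothetical example of the n_steps derivation (not part of train.py).
n_envs = 32
n_train_samples = 1_000_000          # assumed value injected via a wandb sweep
n_steps = n_train_samples // n_envs  # 31_250 steps per environment per rollout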
