Skip to content

Commit

Permalink
Add sweep. Improve ppo readability
Browse files Browse the repository at this point in the history
  • Loading branch information
amacati committed Jan 24, 2025
1 parent b94c402 commit eb27b6f
Show file tree
Hide file tree
Showing 4 changed files with 433 additions and 308 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ build
!/.devcontainer/devcontainer.json
!/.devcontainer/devcontainer.linux.json
!/.devcontainer/devcontainer.wsl2.json
!/.vscode/launch.json
!/.vscode/launch.json
**/*.pt
Binary file removed tutorials/ppo/ppo_checkpoint.pt
Binary file not shown.
71 changes: 71 additions & 0 deletions tutorials/ppo/sweep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import fire
import wandb
from ml_collections import ConfigDict
from train import train_ppo

# Search space handed to wandb.sweep(): Bayesian search that maximizes the metric
# named "eval/mean_rewards".
sweep_config = {
    "method": "bayes",
    "metric": {"name": "eval/mean_rewards", "goal": "maximize"},
    "parameters": {
        # Discrete parameters: each one lists its candidate values explicitly.
        "n_envs": {"values": [2**k for k in range(5, 14)]},  # 32, 64, ..., 8192
        "n_train_samples": {"values": [2**k for k in range(15, 20)]},  # 2^15, ..., 2^19
        # For log_uniform the bounds are exponents of e, not the learning rates
        # themselves: e^-10 ~= 4.5e-5 and e^-5 ~= 6.7e-3.
        "learning_rate": dict(distribution="log_uniform", min=-10, max=-5),
        "n_minibatches": {"values": [2**k for k in range(3, 8)]},  # 8, 16, ..., 128
        "n_epochs": {"values": list(range(5, 20, 5))},  # 5, 10, 15
        # Continuous parameters, each sampled uniformly between its (low, high) bounds.
        **{
            name: {"distribution": "uniform", "min": low, "max": high}
            for name, (low, high) in {
                "clip_coef": (0.1, 0.3),
                "ent_coef": (0.0, 0.25),
                "vf_coef": (0.4, 0.6),
                "gamma": (0.9, 0.999),
                "gae_lambda": (0.5, 0.99),
                "max_grad_norm": (0.2, 5.0),
            }.items()
        },
    },
}

# Baseline training configuration; main() hands a resolved copy of it to train_ppo
# for every sweep trial.
# NOTE(review): sweep_config samples "n_train_samples", which has no entry here (this
# config carries "n_steps" instead) -- confirm how train_ppo reconciles the two.
config = ConfigDict(
    dict(
        n_envs=32,
        device="cuda",
        total_timesteps=4_000_000,
        learning_rate=3e-4,
        n_steps=2048,  # Steps collected per environment in each policy rollout
        gamma=0.99,  # Discount factor
        gae_lambda=0.95,  # Lambda of the generalized advantage estimation
        n_minibatches=32,  # Number of mini-batches
        n_epochs=10,
        norm_adv=True,
        clip_coef=0.2,
        clip_vloss=True,
        ent_coef=0.0,
        vf_coef=0.5,
        max_grad_norm=0.5,
        target_kl=None,
        seed=0,
        n_eval_envs=64,
        n_eval_steps=1_000,
        save_model=False,
        eval_interval=40_000,
    )
)


def main(n_runs: int | None = None):
    """Create a W&B sweep for the PPO tutorial and run an agent on it.

    Reads the W&B API key from ``wandb_api_key.secret`` (resolved against the current
    working directory), logs in, registers ``sweep_config`` as a new sweep in the
    ``crazyflow-ppo`` project and starts an agent that calls ``train_ppo`` once per
    trial.

    Args:
        n_runs: Forwarded to ``wandb.agent`` as ``count``. ``None`` is passed through
            unchanged, so this function imposes no limit of its own.

    Raises:
        OSError: If ``wandb_api_key.secret`` cannot be opened (for example
            ``FileNotFoundError`` when it is missing).
    """
    # strip() instead of trimming "\n" only: a key file saved with CRLF line endings or
    # stray spaces/tabs would otherwise produce a key with trailing whitespace.
    with open("wandb_api_key.secret", "r", encoding="utf-8") as f:
        wandb_api_key = f.read().strip()
    wandb.login(key=wandb_api_key)

    sweep_id = wandb.sweep(sweep_config, project="crazyflow-ppo")

    def run_trial():
        """Run a single sweep trial on a fresh copy of the baseline config."""
        # NOTE(review): the sampled sweep parameters are not merged into the config
        # here; presumably train_ppo reads them from wandb.config and its second
        # positional argument switches W&B on -- confirm against train.py.
        return train_ppo(config.copy_and_resolve_references(), True)

    wandb.agent(sweep_id, run_trial, count=n_runs, project="crazyflow-ppo")


# Expose main() as a command-line interface via python-fire when run as a script.
if __name__ == "__main__":
    fire.Fire(main)
Loading

0 comments on commit eb27b6f

Please sign in to comment.