Skip to content

Commit

Permalink
Add random starts. Add lr decay
Browse files Browse the repository at this point in the history
  • Loading branch information
amacati committed Jan 26, 2025
1 parent f72cc2b commit 0e3efc4
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions tutorials/ppo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,19 @@ def train_ppo(config: ConfigDict, wandb_log: bool = False):
autoreset = torch.zeros(config.n_envs, dtype=bool, device=config.device)

obs, _ = train_envs.reset(seed=config.seed)
for _ in range(1000):
train_envs.step(torch.tensor(train_envs.action_space.sample()))

for iteration in range(1, config.n_iterations + 1):
# Decay learning rate from initial value to 0.1x over all iterations
if config.lr_decay:
progress = (iteration - 1) / config.n_iterations # 0 to 1
current_lr = config.learning_rate * (1 - 0.8 * progress) # Decay to 0.2x
for param_group in optimizer.param_groups:
param_group["lr"] = current_lr
if wandb_log:
wandb.log({"train/learning_rate": current_lr}, step=global_step)

start_time = time.time()
t_env = 0
steps = torch.zeros(config.n_envs, dtype=torch.int32, device=config.device)
Expand Down Expand Up @@ -375,6 +386,9 @@ def train_ppo(config: ConfigDict, wandb_log: bool = False):
if config.save_model:
save_model(agent, optimizer, train_envs, Path(__file__).parent / "ppo_checkpoint.pt")

if wandb_log:
wandb.finish()

# plot_results(train_rewards_hist, train_rewards_steps, eval_rewards_hist, eval_rewards_steps)


Expand All @@ -384,25 +398,26 @@ def train_ppo(config: ConfigDict, wandb_log: bool = False):
"n_envs": 1024,
"device": "cuda",
"total_timesteps": 1_000_000,
"learning_rate": 1.5e-3,
"learning_rate": 1e-3,
"n_steps": 16, # Number of steps per environment per policy rollout
"gamma": 0.90, # Discount factor
"gae_lambda": 0.95, # Lambda for general advantage estimation
"n_minibatches": 16, # Number of mini-batches
"n_epochs": 15,
"norm_adv": True,
"clip_coef": 0.25,
"clip_coef": 0.275,
"clip_vloss": True,
"ent_coef": 0.01,
"ent_coef": 0.018,
"vf_coef": 0.5,
"max_grad_norm": 5.0,
"max_grad_norm": 1.5,
"target_kl": None,
"seed": 0,
"n_eval_envs": 64,
"n_eval_steps": 1_000,
"save_model": False,
"eval_interval": 999_000,
"lr_decay": False,
}
)

train_ppo(config, wandb_log=False)
train_ppo(config, wandb_log=True)

0 comments on commit 0e3efc4

Please sign in to comment.