Commit

Fix use of dones buffer
amacati committed Jan 24, 2025
1 parent eb27b6f commit 9349ff2
Showing 4 changed files with 9 additions and 5 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -13,4 +13,5 @@ build
 !/.devcontainer/devcontainer.linux.json
 !/.devcontainer/devcontainer.wsl2.json
 !/.vscode/launch.json
-**/*.pt
+**/*.pt
+tutorials/ppo/wandb
1 change: 0 additions & 1 deletion crazyflow/gymnasium_envs/crazyflow.py
@@ -11,7 +11,6 @@
 from gymnasium.vector import VectorEnv, VectorWrapper
 from gymnasium.vector.utils import batch_space
 from jax import Array
-from numpy.typing import NDArray
 
 from crazyflow.control.control import MAX_THRUST, MIN_THRUST, Control
 from crazyflow.sim import Sim
2 changes: 1 addition & 1 deletion tutorials/ppo/sweep.py
@@ -29,7 +29,7 @@
 {
     "n_envs": 32,
     "device": "cuda",
-    "total_timesteps": 4_000_000,
+    "total_timesteps": 2_000_000,
     "learning_rate": 3e-4,
     "n_steps": 2048,  # Number of steps per environment per policy rollout
     "gamma": 0.99,  # Discount factor
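With these values, the halved budget amounts to roughly 30 policy updates, assuming the common PPO bookkeeping of one update per n_envs * n_steps rollout (a CleanRL-style convention, not something stated in this diff):

# Back-of-the-envelope check; values taken from the sweep config above.
n_envs = 32
n_steps = 2048
total_timesteps = 2_000_000

rollout_size = n_envs * n_steps                # 65,536 transitions per rollout
num_updates = total_timesteps // rollout_size  # 30 rollout/update cycles
print(rollout_size, num_updates)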
8 changes: 6 additions & 2 deletions tutorials/ppo/train.py
@@ -189,7 +189,9 @@ def train_ppo(config: ConfigDict, wandb_log: bool = False):
     ).to(config.device)
     logprobs_buffer = torch.zeros((config.n_steps, config.n_envs)).to(config.device)
     rewards_buffer = torch.zeros((config.n_steps, config.n_envs)).to(config.device)
+    # TODO: Remove dones buffer
     dones_buffer = torch.zeros((config.n_steps, config.n_envs)).to(config.device)
+    terminated_buffer = torch.zeros((config.n_steps, config.n_envs)).to(config.device)
     values_buffer = torch.zeros((config.n_steps, config.n_envs)).to(config.device)
 
     # Stats tracking setup
@@ -229,6 +231,7 @@ def train_ppo(config: ConfigDict, wandb_log: bool = False):
         mask = active & ~autoreset
         obs_buffer[steps[mask], mask] = obs[mask]
         dones_buffer[steps[mask], mask] = done[mask].float()
+        terminated_buffer[steps[mask], mask] = terminated[mask].float()
         values_buffer[steps[mask], mask] = value[mask].squeeze()
         actions_buffer[steps[mask], mask] = action[mask]
         logprobs_buffer[steps[mask], mask] = logprob[mask].squeeze()
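The new terminated_buffer records only true terminations. Under the Gymnasium step API, terminated and truncated are returned separately and done is their union; the following minimal sketch shows where the flag comes from (a generic placeholder env and variable names, not this repository's rollout loop):

import gymnasium as gym
import numpy as np

envs = gym.make_vec("CartPole-v1", num_envs=4)  # placeholder env, not the crazyflow env
obs, info = envs.reset()
obs, reward, terminated, truncated, info = envs.step(envs.action_space.sample())

done = np.logical_or(terminated, truncated)  # episode ended for any reason
# Only `terminated` means the final state has no future value; `truncated`
# episodes were cut off (e.g. by a time limit) and can still bootstrap.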
@@ -246,10 +249,11 @@ def train_ppo(config: ConfigDict, wandb_log: bool = False):
         lastgaelam = 0
         for t in reversed(range(config.n_steps)):
             if t == config.n_steps - 1:
-                nextnonterminal = 1.0 - dones_buffer[t]  # TODO: Replace with terminated buffer
+                # TODO: Check that terminated is correct instead of dones
+                nextnonterminal = 1.0 - terminated_buffer[t]
                 nextvalues = next_value
             else:
-                nextnonterminal = 1.0 - dones_buffer[t + 1]
+                nextnonterminal = 1.0 - terminated_buffer[t + 1]
                 nextvalues = values_buffer[t + 1]
             delta = (
                 rewards_buffer[t]
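The change above swaps the bootstrap mask from done (terminated or truncated) to terminated only, which is what the GAE recursion wants: truncation ends an episode for bookkeeping reasons, not because the final state has no future return. Below is a self-contained sketch of that recursion using the buffer names from the diff, but otherwise not this repository's code; like the TODO above, it leaves open how to obtain a bootstrap value at truncated steps:

import torch

def compute_gae(rewards, values, terminated, next_value, gamma=0.99, gae_lambda=0.95):
    """GAE over (n_steps, n_envs) buffers, masking bootstraps with `terminated` only."""
    n_steps = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    lastgaelam = torch.zeros_like(next_value)
    for t in reversed(range(n_steps)):
        if t == n_steps - 1:
            nextnonterminal = 1.0 - terminated[t]
            nextvalues = next_value
        else:
            nextnonterminal = 1.0 - terminated[t + 1]
            nextvalues = values[t + 1]
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
        advantages[t] = lastgaelam
    returns = advantages + values
    return advantages, returns

With the shapes allocated earlier in this diff, a call would look like compute_gae(rewards_buffer, values_buffer, terminated_buffer, next_value).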
