Commit

[wip] PPO implementation

amacati committed Jan 24, 2025
1 parent 66083e5 commit b94c402
Showing 9 changed files with 12,582 additions and 8 deletions.
11 changes: 5 additions & 6 deletions crazyflow/gymnasium_envs/crazyflow.py
@@ -11,7 +11,7 @@
from gymnasium.vector import VectorEnv, VectorWrapper
from gymnasium.vector.utils import batch_space
from jax import Array
from scipy.interpolate import splev, splprep
from numpy.typing import NDArray

Check failure on line 14 in crazyflow/gymnasium_envs/crazyflow.py (GitHub Actions / ruff):
Ruff (F401) crazyflow/gymnasium_envs/crazyflow.py:14:26: F401 `numpy.typing.NDArray` imported but unused
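
The failure flags the newly added NDArray import, which is not yet referenced in the file. As a sketch only (not what this commit does), the lint error would typically be cleared either by dropping the import or by deferring it explicitly until the annotation that needs it lands:

from numpy.typing import NDArray  # noqa: F401, keep only if NDArray type hints are about to follow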

from crazyflow.control.control import MAX_THRUST, MIN_THRUST, Control
from crazyflow.sim import Sim
@@ -434,7 +434,7 @@ def __init__(
    def reward(self) -> Array:
        return self._reward(
            self.prev_done, self.terminated, self.sim.data.states, self.trajectory[self.steps]
        )
        ).reshape(-1)

    @staticmethod
    @jax.jit
@@ -495,16 +495,15 @@ def __init__(self, env: VectorEnv):
        self.single_action_space.high = np.ones_like(self.action_sim_high)
        self.action_space = batch_space(self.single_action_space, self.num_envs)

    def step(self, actions: Array) -> tuple[Array, Array, Array, Array, dict]:
    def step(self, actions: Array) -> tuple[dict, Array, Array, Array, dict]:
        actions = np.clip(actions, -1.0, 1.0)
        return self.env.step(self.actions(actions))
        obs, reward, terminated, truncated, info = self.env.step(self.actions(actions))
        return obs, reward, terminated, truncated, info

    def actions(self, actions: Array) -> Array:
        """Rescale and clip actions from [-1, 1] to [action_sim_low, action_sim_high]."""
        # Rescale actions using the computed scale and mean
        rescaled_actions = actions * self.action_scale + self.action_mean

        # Ensure actions are within the valid range of the simulation action space
        rescaled_actions = np.clip(rescaled_actions, self.action_sim_low, self.action_sim_high)

        return rescaled_actions
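
The two methods above rely on action_scale and action_mean, which are set up elsewhere in the wrapper (not part of this hunk). A minimal, self-contained sketch of the usual affine convention they presumably follow, with hypothetical numeric bounds standing in for action_sim_low and action_sim_high:

import numpy as np

# Hypothetical per-motor bounds standing in for action_sim_low / action_sim_high.
action_sim_low = np.full(4, 0.0)
action_sim_high = np.full(4, 0.15)

# Affine map so that -1 lands on the lower bound and +1 on the upper bound.
action_scale = (action_sim_high - action_sim_low) / 2.0
action_mean = (action_sim_high + action_sim_low) / 2.0

normalized = np.array([-1.0, 0.0, 0.5, 1.0])
rescaled = np.clip(normalized * action_scale + action_mean, action_sim_low, action_sim_high)
print(rescaled)  # [0.     0.075  0.1125 0.15  ]

With this convention the clip in actions() is only a safety net: a policy emitting values in [-1, 1] already lands inside the simulator bounds.
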
2 changes: 1 addition & 1 deletion tutorials/LQR_ILQR.ipynb
@@ -785,7 +785,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
"version": "3.11.11"
}
},
"nbformat": 4,
12,083 changes: 12,083 additions & 0 deletions tutorials/PPO.ipynb

Large diffs are not rendered by default.
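
The 12,083-line notebook diff is not rendered, so its contents are not shown here. As orientation only: PPO training alternates environment rollouts with advantage estimation and clipped policy updates. A minimal generalized advantage estimation (GAE) sketch under the usual conventions (rollout tensors of shape (T, n_envs), dones[t] flagging episodes that end at step t), not code taken from the notebook:

import torch

def compute_gae(rewards, values, dones, next_value, gamma=0.99, gae_lambda=0.95):
    """GAE over a rollout; rewards, values, dones have shape (T, n_envs)."""
    T = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    last_gae = torch.zeros_like(next_value)
    for t in reversed(range(T)):
        # Bootstrap from the next state's value unless the episode ended at step t.
        next_values = next_value if t == T - 1 else values[t + 1]
        next_nonterminal = 1.0 - dones[t].float()
        delta = rewards[t] + gamma * next_values * next_nonterminal - values[t]
        last_gae = delta + gamma * gae_lambda * next_nonterminal * last_gae
        advantages[t] = last_gae
    returns = advantages + values
    return advantages, returns

The returns computed this way would serve as regression targets for the critic defined in tutorials/ppo/agent.py.
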

2 changes: 1 addition & 1 deletion tutorials/compare_sim_and_symbolic.ipynb
@@ -364,7 +364,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
"version": "3.11.11"
}
},
"nbformat": 4,
49 changes: 49 additions & 0 deletions tutorials/ppo/agent.py
@@ -0,0 +1,49 @@
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.distributions import Normal


def layer_init(layer: nn.Linear, std: float = np.sqrt(2), bias_const: float = 0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class Agent(nn.Module):
    def __init__(self, envs):
        super().__init__()
        self.critic = nn.Sequential(
            layer_init(nn.Linear(torch.tensor(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )
        self.actor_mean = nn.Sequential(
            layer_init(nn.Linear(torch.tensor(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(
                nn.Linear(64, torch.tensor(envs.single_action_space.shape).prod()), std=0.01
            ),
        )
        self.actor_logstd = nn.Parameter(
            torch.zeros(1, torch.tensor(envs.single_action_space.shape).prod())
        )

    def value(self, x: Tensor) -> Tensor:
        return self.critic(x)

    def action_and_value(
        self, x: Tensor, action: Tensor | None = None, deterministic: bool = False
    ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        action_mean = self.actor_mean(x)
        action_logstd = self.actor_logstd.expand_as(action_mean)
        action_std = torch.exp(action_logstd)
        probs = Normal(action_mean, action_std)
        if action is None:
            action = probs.sample() if not deterministic else action_mean
        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
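
agent.py is a compact actor-critic: actor_mean and actor_logstd parameterize a diagonal Gaussian policy, and critic estimates state values. During the PPO update, action_and_value is called again with the stored rollout actions so the new log-probabilities can be compared against the old ones. A minimal sketch of the clipped surrogate objective built on this class; the clip, value, and entropy coefficients are placeholder hyperparameters, not values taken from the notebook:

import torch

def ppo_loss(agent, obs, actions, old_logprobs, advantages, returns,
             clip_coef=0.2, vf_coef=0.5, ent_coef=0.01):
    """Clipped PPO objective for one minibatch (all tensors already on the agent's device)."""
    _, new_logprobs, entropy, values = agent.action_and_value(obs, action=actions)
    ratio = (new_logprobs - old_logprobs).exp()

    # Normalizing advantages per minibatch is a common, optional stabilizer.
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # Clipped policy surrogate: take the pessimistic (larger) of the two losses.
    pg_loss = torch.max(-advantages * ratio,
                        -advantages * ratio.clamp(1 - clip_coef, 1 + clip_coef)).mean()
    # Value regression and entropy bonus.
    v_loss = 0.5 * (values.squeeze(-1) - returns).pow(2).mean()
    return pg_loss + vf_coef * v_loss - ent_coef * entropy.mean()

In practice this loss is minimized over several epochs of shuffled minibatches per rollout before the next batch of environment steps is collected.
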
Binary file added tutorials/ppo/ppo_checkpoint.pt
Binary file not shown.
81 changes: 81 additions & 0 deletions tutorials/ppo/test.py
@@ -0,0 +1,81 @@
import logging
import random

import gymnasium
import gymnasium.wrappers.vector.jax_to_torch
import numpy as np
import torch
from agent import Agent
from wrappers import FlattenJaxObservation

import crazyflow # noqa: F401, register the gymnasium envs
from crazyflow.gymnasium_envs.crazyflow import CrazyflowRL

# Set up logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
n_envs = 2
seed = 0

# Seeding
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True

# Create and wrap test environment
env_device = "cpu"
test_env = gymnasium.make_vec(
"DroneFigureEightTrajectory-v0",
freq=50,
num_envs=n_envs,
render_samples=True,
device=env_device,
)

test_env = CrazyflowRL(test_env)
test_env = FlattenJaxObservation(test_env)
norm_test_env = gymnasium.wrappers.vector.NormalizeObservation(test_env)
norm_test_env.update_running_mean = False
test_env = gymnasium.wrappers.vector.jax_to_torch.JaxToTorch(norm_test_env, device=device)

# Load checkpoint
checkpoint = torch.load("ppo_checkpoint.pt")

# Create agent and load state
agent = Agent(test_env).to(device)
agent.load_state_dict(checkpoint["model_state_dict"])

# Set normalization parameters
norm_test_env.obs_rms.mean = checkpoint["obs_mean"]
norm_test_env.obs_rms.var = checkpoint["obs_var"]

# Test for 10 episodes
n_episodes = 10
episode_rewards = []
episode_lengths = []

for episode in range(n_episodes):
    obs, _ = test_env.reset(seed=seed + episode)
    done = torch.zeros(n_envs, dtype=bool, device=device)
    episode_reward = 0
    steps = 0

    while not done.all():
        with torch.no_grad():
            action, _, _, _ = agent.action_and_value(obs, deterministic=True)
        obs, reward, terminated, truncated, info = test_env.step(action)
        test_env.render()
        done = terminated | truncated
        episode_reward += reward.mean().item()  # mean step reward across the vectorized envs
        steps += 1

    episode_rewards.append(episode_reward)
    episode_lengths.append(steps)
    print(f"Episode {episode + 1}: Reward = {episode_reward:.2f}, Length = {steps}")

print(f"\nAverage episode reward: {np.mean(episode_rewards):.2f} ± {np.std(episode_rewards):.2f}")
print(f"Average episode length: {np.mean(episode_lengths):.1f} ± {np.std(episode_lengths):.1f}")