WIP: Add MuJoCo Locomotion Environments (#156)
* Add Ant environment
* Add HalfCheetah environment
* Add HalfCheetah example
* Add Ant xml model
* Add Hopper environment
* Add Walker2D environment
* Add __init__.py to data
* Finish locomotion envs
* Add new mujoco envs to __init__
* Revert formatting changes
* Move mujoco locomotion examples to one file
* Add tests for locomotion environments
* Fix minor details in example
* Update test_locomotion.py
* Comment out render for testing
Showing 18 changed files with 1,256 additions and 5 deletions.
@@ -0,0 +1,122 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from mushroom_rl.algorithms.actor_critic import PPO
from mushroom_rl.core import Core, Logger
from mushroom_rl.environments import Ant, HalfCheetah, Hopper, Walker2D
from mushroom_rl.policy import GaussianTorchPolicy

from tqdm import trange


class Network(nn.Module):
    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        super(Network, self).__init__()

        n_input = input_shape[-1]
        n_output = output_shape[0]

        self._h1 = nn.Linear(n_input, n_features)
        self._h2 = nn.Linear(n_features, n_features)
        self._h3 = nn.Linear(n_features, n_output)

        # Xavier initialization with the gain scaled down by 10 keeps the
        # initial network outputs small.
        nn.init.xavier_uniform_(
            self._h1.weight, gain=nn.init.calculate_gain("relu") / 10
        )
        nn.init.xavier_uniform_(
            self._h2.weight, gain=nn.init.calculate_gain("relu") / 10
        )
        nn.init.xavier_uniform_(
            self._h3.weight, gain=nn.init.calculate_gain("linear") / 10
        )

    def forward(self, state, **kwargs):
        features1 = F.relu(self._h1(torch.squeeze(state, 1).float()))
        features2 = F.relu(self._h2(features1))
        a = self._h3(features2)

        return a


def experiment(env, n_epochs, n_steps, n_episodes_test):
    np.random.seed()

    logger = Logger(PPO.__name__, results_dir=None)
    logger.strong_line()
    logger.info("Experiment Algorithm: " + PPO.__name__)

    mdp = env()

    # PPO hyperparameters
    actor_lr = 3e-4
    critic_lr = 3e-4
    n_features = 32
    batch_size = 64
    n_epochs_policy = 10
    eps = 0.2
    lam = 0.95
    std_0 = 1.0
    n_steps_per_fit = 2000

    critic_params = dict(
        network=Network,
        optimizer={"class": optim.Adam, "params": {"lr": critic_lr}},
        loss=F.mse_loss,
        n_features=n_features,
        batch_size=batch_size,
        input_shape=mdp.info.observation_space.shape,
        output_shape=(1,),
    )

    alg_params = dict(
        actor_optimizer={"class": optim.Adam, "params": {"lr": actor_lr}},
        n_epochs_policy=n_epochs_policy,
        batch_size=batch_size,
        eps_ppo=eps,
        lam=lam,
        critic_params=critic_params,
    )

    policy_params = dict(std_0=std_0, n_features=n_features)

    policy = GaussianTorchPolicy(
        Network,
        mdp.info.observation_space.shape,
        mdp.info.action_space.shape,
        **policy_params,
    )

    agent = PPO(mdp.info, policy, **alg_params)

    core = Core(agent, mdp)

    # Initial evaluation before any learning step.
    dataset = core.evaluate(n_episodes=n_episodes_test, render=False)

    J = np.mean(dataset.discounted_return)
    R = np.mean(dataset.undiscounted_return)
    E = agent.policy.entropy()

    logger.epoch_info(0, J=J, R=R, entropy=E)

    for it in trange(n_epochs, leave=False):
        core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit)
        dataset = core.evaluate(n_episodes=n_episodes_test, render=False)

        J = np.mean(dataset.discounted_return)
        R = np.mean(dataset.undiscounted_return)
        E = agent.policy.entropy()

        logger.epoch_info(it + 1, J=J, R=R, entropy=E)

    logger.info("Press ENTER to visualize")
    input()
    core.evaluate(n_episodes=5, render=True)


if __name__ == "__main__":
    envs = [Ant, HalfCheetah, Hopper, Walker2D]
    for env in envs:
        experiment(env=env, n_epochs=50, n_steps=30000, n_episodes_test=10)
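The example discards each trained agent once its loop finishes. A minimal sketch for persisting an agent between runs, assuming MushroomRL's standard Serializable interface; the file name "ppo_ant.msh" is a hypothetical choice:

from mushroom_rl.algorithms.actor_critic import PPO

# ... after core.learn(...) inside experiment():
# serialize the full agent (policy, critic, optimizer state).
agent.save("ppo_ant.msh")

# In a later session, restore the agent and evaluate it again.
agent = PPO.load("ppo_ant.msh")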
@@ -1,8 +1,16 @@
from .ball_in_a_cup import BallInACup
from .air_hockey import AirHockeyHit, AirHockeyDefend, AirHockeyPrepare, AirHockeyRepel
from .ant import Ant
from .half_cheetah import HalfCheetah
from .hopper import Hopper
from .walker_2d import Walker2D

BallInACup.register()
AirHockeyHit.register()
AirHockeyDefend.register()
AirHockeyPrepare.register()
AirHockeyRepel.register()
Ant.register()
HalfCheetah.register()
Hopper.register()
Walker2D.register()
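The register() calls above add each class to MushroomRL's environment registry, so the new environments can also be built by name. A minimal sketch of that usage, assuming the registry API in mushroom_rl.core:

from mushroom_rl.core import Environment

# The string must match the registered class name; keyword arguments
# are forwarded to the environment constructor.
mdp = Environment.make("Ant", gamma=0.99, horizon=1000)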
@@ -0,0 +1,203 @@
from pathlib import Path
from typing import Tuple

import numpy as np
from mushroom_rl.environments.mujoco import MuJoCo, ObservationType
from mushroom_rl.rl_utils.spaces import Box
import mujoco


class Ant(MuJoCo):
    """
    The Ant MuJoCo environment as presented in:
    "High-Dimensional Continuous Control Using Generalized Advantage Estimation",
    John Schulman et al., 2015, and as implemented in Gymnasium.
    """

    def __init__(
        self,
        gamma: float = 0.99,
        horizon: int = 1000,
        forward_reward_weight: float = 1.0,
        ctrl_cost_weight: float = 0.5,
        contact_cost_weight: float = 5e-4,
        healthy_reward: float = 1.0,
        terminate_when_unhealthy: bool = True,
        healthy_z_range: Tuple[float, float] = (0.2, 1.0),
        contact_force_range: Tuple[float, float] = (-1.0, 1.0),
        reset_noise_scale: float = 0.1,
        n_substeps: int = 5,
        exclude_current_positions_from_observation: bool = True,
        use_contact_forces: bool = False,
        **viewer_params,
    ):
        """
        Constructor.
        """
        xml_path = (
            Path(__file__).resolve().parent / "data" / "ant" / "model.xml"
        ).as_posix()

        # This actuator order matches the one specified in Gymnasium.
        actuation_spec = [
            "hip_4",
            "ankle_4",
            "hip_1",
            "ankle_1",
            "hip_2",
            "ankle_2",
            "hip_3",
            "ankle_3",
        ]

        observation_spec = [
            ("root_pose", "root", ObservationType.JOINT_POS),
            ("hip_1_pos", "hip_1", ObservationType.JOINT_POS),
            ("ankle_1_pos", "ankle_1", ObservationType.JOINT_POS),
            ("hip_2_pos", "hip_2", ObservationType.JOINT_POS),
            ("ankle_2_pos", "ankle_2", ObservationType.JOINT_POS),
            ("hip_3_pos", "hip_3", ObservationType.JOINT_POS),
            ("ankle_3_pos", "ankle_3", ObservationType.JOINT_POS),
            ("hip_4_pos", "hip_4", ObservationType.JOINT_POS),
            ("ankle_4_pos", "ankle_4", ObservationType.JOINT_POS),
            ("root_vel", "root", ObservationType.JOINT_VEL),
            ("hip_1_vel", "hip_1", ObservationType.JOINT_VEL),
            ("ankle_1_vel", "ankle_1", ObservationType.JOINT_VEL),
            ("hip_2_vel", "hip_2", ObservationType.JOINT_VEL),
            ("ankle_2_vel", "ankle_2", ObservationType.JOINT_VEL),
            ("hip_3_vel", "hip_3", ObservationType.JOINT_VEL),
            ("ankle_3_vel", "ankle_3", ObservationType.JOINT_VEL),
            ("hip_4_vel", "hip_4", ObservationType.JOINT_VEL),
            ("ankle_4_vel", "ankle_4", ObservationType.JOINT_VEL),
        ]

        additional_data_spec = [
            ("torso_pos", "torso", ObservationType.BODY_POS),
            ("torso_vel", "torso", ObservationType.BODY_VEL_WORLD),
        ]

        collision_groups = [
            ("torso", ["torso_geom"]),
            ("floor", ["floor"]),
        ]

        self._forward_reward_weight = forward_reward_weight
        self._ctrl_cost_weight = ctrl_cost_weight
        self._contact_cost_weight = contact_cost_weight
        self._healthy_reward = healthy_reward
        self._terminate_when_unhealthy = terminate_when_unhealthy
        self._healthy_z_range = healthy_z_range
        self._contact_force_range = contact_force_range
        self._reset_noise_scale = reset_noise_scale
        self._exclude_current_positions_from_observation = (
            exclude_current_positions_from_observation
        )
        self._use_contact_forces = use_contact_forces

        super().__init__(
            xml_file=xml_path,
            gamma=gamma,
            horizon=horizon,
            observation_spec=observation_spec,
            actuation_spec=actuation_spec,
            collision_groups=collision_groups,
            additional_data_spec=additional_data_spec,
            n_substeps=n_substeps,
            **viewer_params,
        )

    def _modify_mdp_info(self, mdp_info):
        # Drop the root's x-y position from the observation, as in Gymnasium.
        if self._exclude_current_positions_from_observation:
            self.obs_helper.remove_obs("root_pose", 0)
            self.obs_helper.remove_obs("root_pose", 1)
        if self._use_contact_forces:
            self.obs_helper.add_obs("collision_force", 6)
        mdp_info = super()._modify_mdp_info(mdp_info)
        mdp_info.observation_space = Box(*self.obs_helper.get_obs_limits())
        return mdp_info

    def _create_observation(self, obs):
        obs = super()._create_observation(obs)
        if self._use_contact_forces:
            collision_force = self._get_collision_force("torso", "floor")
            obs = np.concatenate([obs, collision_force])
        return obs

    def _is_finite(self):
        states = self.get_states()
        return np.isfinite(states).all()

    def _is_within_z_range(self):
        z_pos = self._read_data("torso_pos")[2]
        min_z, max_z = self._healthy_z_range
        return min_z <= z_pos <= max_z

    def _is_healthy(self):
        is_healthy = self._is_finite() and self._is_within_z_range()
        return is_healthy

    def is_absorbing(self, obs):
        absorbing = self._terminate_when_unhealthy and not self._is_healthy()
        return absorbing

    def _get_healthy_reward(self, obs):
        # The boolean factor is 1 only when termination is enabled and the
        # ant is currently healthy.
        return (
            self._terminate_when_unhealthy and self._is_healthy()
        ) * self._healthy_reward

    def _get_forward_reward(self):
        # x-component of the torso's linear velocity in world coordinates.
        forward_reward = self._read_data("torso_vel")[3]
        return self._forward_reward_weight * forward_reward

    def _get_ctrl_cost(self, action):
        ctrl_cost = np.sum(np.square(action))
        return self._ctrl_cost_weight * ctrl_cost

    def _get_contact_cost(self, obs):
        collision_force = self.obs_helper.get_from_obs(obs, "collision_force")
        contact_cost = np.sum(
            np.square(np.clip(collision_force, *self._contact_force_range))
        )
        return self._contact_cost_weight * contact_cost

    def reward(self, obs, action, next_obs, absorbing):
        # r = healthy_reward + forward_reward - ctrl_cost [- contact_cost]
        healthy_reward = self._get_healthy_reward(next_obs)
        forward_reward = self._get_forward_reward()
        cost = self._get_ctrl_cost(action)
        if self._use_contact_forces:
            contact_cost = self._get_contact_cost(next_obs)
            cost += contact_cost
        reward = healthy_reward + forward_reward - cost
        return reward

    def _generate_noise(self):
        # Uniform noise on positions, Gaussian noise on velocities.
        self._data.qpos[:] = self._data.qpos + np.random.uniform(
            -self._reset_noise_scale, self._reset_noise_scale, size=self._model.nq
        )

        self._data.qvel[:] = (
            self._data.qvel
            + self._reset_noise_scale * np.random.standard_normal(self._model.nv)
        )

    def setup(self, obs):
        super().setup(obs)

        self._generate_noise()

        # Propagate the perturbed state through the model before the first step.
        mujoco.mj_forward(self._model, self._data)  # type: ignore

    def _create_info_dictionary(self, obs, action):
        info = {
            "healthy_reward": self._get_healthy_reward(obs),
            "forward_reward": self._get_forward_reward(),
        }
        info["ctrl_cost"] = self._get_ctrl_cost(action)
        if self._use_contact_forces:
            info["contact_cost"] = self._get_contact_cost(obs)
        return info

    def get_states(self):
        """Return the position and velocity joint states of the model."""
        return np.concatenate([self._data.qpos.flat, self._data.qvel.flat])
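A quick way to sanity-check the new environment is a random-action rollout. A minimal sketch, assuming the step/reset signature of MushroomRL's Environment interface as used by Core:

import numpy as np

from mushroom_rl.environments import Ant

env = Ant()
obs = env.reset()
for _ in range(100):
    # Sample a uniform random action within the actuator limits.
    action = np.random.uniform(env.info.action_space.low,
                               env.info.action_space.high)
    obs, reward, absorbing, info = env.step(action)
    if absorbing:
        # The ant fell outside the healthy z-range; start a new episode.
        obs = env.reset()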