examples now support gym >= 0.26 (#215)
Handle the interface changes introduced in gym >= 0.26, based on #205
51616 authored Oct 31, 2022
1 parent 6284ebb commit e384d09
Showing 5 changed files with 83 additions and 13 deletions.
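
Every file in this commit applies the same pattern: gym 0.26 changed `reset()` to return `(obs, info)` instead of `obs`, and `step()` to return a five-tuple `(obs, reward, terminated, truncated, info)` instead of the old four-tuple, so each example now branches on the installed gym version. A minimal standalone sketch of that pattern against plain gym (the `CartPole-v1` id and variable names are illustrative, not part of the commit):

```python
import gym
from packaging import version

# Same version check the examples add at module level.
is_legacy_gym = version.parse(gym.__version__) < version.parse("0.26.0")

env = gym.make("CartPole-v1")  # any gym environment works here

if is_legacy_gym:
    obs = env.reset()            # gym < 0.26: obs only
else:
    obs, _ = env.reset()         # gym >= 0.26: (obs, info)

action = env.action_space.sample()
if is_legacy_gym:
    obs, reward, done, info = env.step(action)                    # 4-tuple
else:
    obs, reward, terminated, truncated, info = env.step(action)   # 5-tuple
    done = terminated or truncated
```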
13 changes: 11 additions & 2 deletions examples/acme_examples/helpers.py
@@ -40,13 +40,15 @@
from acme.jax.types import PRNGKey
from acme.utils import loggers
from acme.utils.loggers import aggregators, base, filters, terminal
from packaging import version
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv

import envpool
from envpool.python.protocol import EnvPool

logging.getLogger().setLevel(logging.INFO)
is_legacy_gym = version.parse(gym.__version__) < version.parse("0.26.0")


class TimeStep(dm_env.TimeStep):
@@ -276,7 +278,10 @@ def __init__(

def reset(self) -> TimeStep:
self._reset_next_step = False
observation = self._environment.reset()
if is_legacy_gym:
observation = self._environment.reset()
else:
observation, _ = self._environment.reset()
ts = TimeStep(
step_type=np.full(self._num_envs, dm_env.StepType.FIRST, dtype="int32"),
reward=np.zeros(self._num_envs, dtype="float32"),
@@ -289,7 +294,11 @@ def step(self, action: types.NestedArray) -> TimeStep:
if self._reset_next_step:
return self.reset()
if self._use_env_pool:
observation, reward, done, _ = self._environment.step(action)
if is_legacy_gym:
observation, reward, done, _ = self._environment.step(action)
else:
observation, reward, term, trunc, _ = self._environment.step(action)
done = term + trunc
else:
self._environment.step_async(action)
observation, reward, done, _ = self._environment.step_wait()
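
A note on the `done = term + trunc` idiom used throughout this commit: the vectorized environments return `term` and `trunc` as per-environment boolean arrays, and `+` on boolean NumPy arrays is an elementwise logical OR, so an episode counts as done when it either terminated or was truncated. A small illustration with made-up values:

```python
import numpy as np

term = np.array([True, False, False])   # natural episode terminations
trunc = np.array([False, True, False])  # time-limit truncations
done = term + trunc                     # elementwise OR: [True, True, False]
```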
19 changes: 16 additions & 3 deletions examples/cleanrl_examples/ppo_atari_envpool.py
@@ -31,11 +31,14 @@
import torch
import torch.nn as nn
import torch.optim as optim
from packaging import version
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

import envpool

is_legacy_gym = version.parse(gym.__version__) < version.parse("0.26.0")


def parse_args():
# fmt: off
@@ -221,7 +224,10 @@ def __init__(self, env, deque_size=100):
print("env has lives")

def reset(self, **kwargs):
observations = super(RecordEpisodeStatistics, self).reset(**kwargs)
if is_legacy_gym:
observations = super(RecordEpisodeStatistics, self).reset(**kwargs)
else:
observations, _ = super(RecordEpisodeStatistics, self).reset(**kwargs)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
self.lives = np.zeros(self.num_envs, dtype=np.int32)
@@ -230,8 +236,15 @@ def reset(self, **kwargs):
return observations

def step(self, action):
observations, rewards, dones, infos = super(RecordEpisodeStatistics,
self).step(action)
if is_legacy_gym:
observations, rewards, dones, infos = super(
RecordEpisodeStatistics, self
).step(action)
else:
observations, rewards, term, trunc, infos = super(
RecordEpisodeStatistics, self
).step(action)
dones = term + trunc
self.episode_returns += infos["reward"]
self.episode_lengths += 1
self.returned_episode_returns[:] = self.episode_returns
10 changes: 9 additions & 1 deletion examples/ppo_atari/ppo.py
@@ -15,16 +15,20 @@
import argparse
from typing import Any, Dict, Tuple, Type

import gym
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from gae import compute_gae
from packaging import version
from torch import nn
from torch.utils.tensorboard import SummaryWriter

import envpool

is_legacy_gym = version.parse(gym.__version__) < version.parse("0.26.0")


class CnnActorCritic(nn.Module):

@@ -228,7 +232,11 @@ def run(self) -> None:
while t.n < self.config.step_per_epoch:
# collect
for _ in range(self.config.step_per_collect // self.config.waitnum):
obs, rew, done, info = self.train_envs.recv()
if is_legacy_gym:
obs, rew, done, info = self.train_envs.recv()
else:
obs, rew, term, trunc, info = self.train_envs.recv()
done = term + trunc
env_id = info["env_id"]
obs = torch.tensor(obs, device="cuda")
self.obs_batch.append(obs)
40 changes: 35 additions & 5 deletions examples/sb3_examples/ppo.py
@@ -17,6 +17,7 @@
import gym
import numpy as np
import torch as th
from packaging import version
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
@@ -38,6 +39,7 @@
seed = 0
use_env_pool = True # whether to use EnvPool or Gym for training
render = False # whether to render final policy using Gym
is_legacy_gym = version.parse(gym.__version__) < version.parse("0.26.0")


class VecAdapter(VecEnvWrapper):
@@ -56,14 +58,21 @@ def step_async(self, actions: np.ndarray) -> None:
self.actions = actions

def reset(self) -> VecEnvObs:
return self.venv.reset()
if is_legacy_gym:
return self.venv.reset()
else:
return self.venv.reset()[0]

def seed(self, seed: Optional[int] = None) -> None:
# You can only seed EnvPool env by calling envpool.make()
pass

def step_wait(self) -> VecEnvStepReturn:
obs, rewards, dones, info_dict = self.venv.step(self.actions)
if is_legacy_gym:
obs, rewards, dones, info_dict = self.venv.step(self.actions)
else:
obs, rewards, terms, truncs, info_dict = self.venv.step(self.actions)
dones = terms + truncs
infos = []
# Convert dict to list of dict
# and add terminal observation
@@ -77,8 +86,10 @@ def step_wait(self) -> VecEnvStepReturn:
)
if dones[i]:
infos[i]["terminal_observation"] = obs[i]
obs[i] = self.venv.reset(np.array([i]))

if is_legacy_gym:
obs[i] = self.venv.reset(np.array([i]))
else:
obs[i] = self.venv.reset(np.array([i]))[0]
return obs, rewards, dones, infos


@@ -115,7 +126,26 @@ def step_wait(self) -> VecEnvStepReturn:
pass

# Agent trained on envpool version should also perform well on regular Gym env
test_env = gym.make(env_id)
if not is_legacy_gym:

def legacy_wrap(env):
env.reset_fn = env.reset
env.step_fn = env.step

def legacy_reset():
return env.reset_fn()[0]

def legacy_step(action):
obs, rew, term, trunc, info = env.step_fn(action)
return obs, rew, term + trunc, info

env.reset = legacy_reset
env.step = legacy_step
return env

test_env = legacy_wrap(gym.make(env_id))
else:
test_env = gym.make(env_id)

# Test with EnvPool
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20)
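
The `legacy_wrap` helper added in this file monkey-patches `reset` and `step` on a gym >= 0.26 environment so the rest of the script can keep consuming the legacy return shapes. A rough usage sketch under that assumption (gym >= 0.26 installed; the `CartPole-v1` id is illustrative, the script itself wraps `gym.make(env_id)`):

```python
env = legacy_wrap(gym.make("CartPole-v1"))
obs = env.reset()                                  # legacy-style: obs only
obs, reward, done, info = env.step(env.action_space.sample())  # legacy 4-tuple
```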
14 changes: 12 additions & 2 deletions examples/xla_step.py
@@ -16,11 +16,15 @@
See https://envpool.readthedocs.io/en/latest/content/xla_interface.html
"""

import gym
import jax.numpy as jnp
from jax import jit, lax
from packaging import version

import envpool

is_legacy_gym = version.parse(gym.__version__) < version.parse("0.26.0")


def policy(states: jnp.ndarray) -> jnp.ndarray:
return jnp.zeros(states.shape[0], dtype=jnp.int32)
@@ -35,14 +39,20 @@ def gym_sync_step() -> None:
def actor_step(iter, loop_var):
handle0, states = loop_var
action = policy(states)
handle1, (new_states, rew, done, info) = step(handle0, action)
if is_legacy_gym:
handle1, (new_states, rew, done, info) = step(handle0, action)
else:
handle1, (new_states, rew, term, trunc, info) = step(handle0, action)
return (handle1, new_states)

@jit
def run_actor_loop(num_steps, init_var):
return lax.fori_loop(0, num_steps, actor_step, init_var)

states = env.reset()
if is_legacy_gym:
states = env.reset()
else:
states, _ = env.reset()
run_actor_loop(100, (handle, states))


