Commit
add full gymnasium support for Atari
AhmedMagdyHendawy committed Feb 5, 2025
1 parent 435fbf7 commit 286ac99
Showing 7 changed files with 44 additions and 65 deletions.
40 changes: 20 additions & 20 deletions examples/atari_dqn.py
@@ -10,12 +10,13 @@

from mushroom_rl.algorithms.value import AveragedDQN, CategoricalDQN, DQN,\
DoubleDQN, MaxminDQN, DuelingDQN, NoisyDQN, QuantileDQN, Rainbow
from mushroom_rl.approximators.parametric import NumpyTorchApproximator
from mushroom_rl.approximators.parametric import NumpyTorchApproximator, TorchApproximator
from mushroom_rl.core import Core, Logger
from mushroom_rl.environments import *
from mushroom_rl.policy import EpsGreedy
from mushroom_rl.rl_utils.parameters import LinearParameter, Parameter
from mushroom_rl.rl_utils.replay_memory import PrioritizedReplayMemory
from mushroom_rl.utils.torch import TorchUtils

"""
This script runs Atari experiments with DQN, and some of its variants, as
@@ -126,7 +127,7 @@ def experiment():
# help='Height of the game screen.')

arg_mem = parser.add_argument_group('Replay Memory')
arg_mem.add_argument("--initial-replay-size", type=int, default=50000,
arg_mem.add_argument("--initial-replay-size", type=int, default=20_000,
help='Initial size of the replay memory.')
arg_mem.add_argument("--max-replay-size", type=int, default=100000, #changed to 100k instead of 500k because of memory restrictions
help='Max size of the replay memory.')
@@ -141,13 +142,13 @@ def experiment():
'rmspropcentered'],
default='adam',
help='Name of the optimizer to use.')
arg_net.add_argument("--learning-rate", type=float, default=.0001,
arg_net.add_argument("--learning-rate", type=float, default=6.25e-5,
help='Learning rate value of the optimizer.')
arg_net.add_argument("--decay", type=float, default=.95,
help='Discount factor for the history coming from the'
'gradient momentum in rmspropcentered and'
'rmsprop')
arg_net.add_argument("--epsilon", type=float, default=1e-8,
arg_net.add_argument("--epsilon", type=float, default=1.5e-4,
help='Epsilon term used in rmspropcentered and'
'rmsprop')

@@ -163,36 +164,36 @@ def experiment():
"AveragedDQN or MaxminDQN.")
arg_alg.add_argument("--batch-size", type=int, default=32,
help='Batch size for each fit of the network.')
arg_alg.add_argument("--history-length", type=int, default=4,
help='Number of frames composing a state.')
arg_alg.add_argument("--target-update-frequency", type=int, default=10000,
# arg_alg.add_argument("--history-length", type=int, default=4,
# help='Number of frames composing a state.')
arg_alg.add_argument("--target-update-frequency", type=int, default=8_000,
help='Number of collected samples before each update'
'of the target network.')
arg_alg.add_argument("--evaluation-frequency", type=int, default=250000,
arg_alg.add_argument("--evaluation-frequency", type=int, default=250_000,
help='Number of collected samples before each'
'evaluation. An epoch ends after this number of'
'steps')
arg_alg.add_argument("--train-frequency", type=int, default=4,
help='Number of collected samples before each fit of'
'the neural network.')
arg_alg.add_argument("--max-steps", type=int, default=50000000,
arg_alg.add_argument("--max-steps", type=int, default=50_000_000,
help='Total number of collected samples.')
arg_alg.add_argument("--final-exploration-frame", type=int, default=1000000,
arg_alg.add_argument("--final-exploration-frame", type=int, default=250_000,
help='Number of collected samples until the exploration'
'rate stops decreasing.')
arg_alg.add_argument("--initial-exploration-rate", type=float, default=1.,
help='Initial value of the exploration rate.')
arg_alg.add_argument("--final-exploration-rate", type=float, default=.1,
arg_alg.add_argument("--final-exploration-rate", type=float, default=.01,
help='Final value of the exploration rate. When it'
'reaches this values, it stays constant.')
arg_alg.add_argument("--test-exploration-rate", type=float, default=.05,
help='Exploration rate used during evaluation.')
arg_alg.add_argument("--test-samples", type=int, default=125000,
arg_alg.add_argument("--test-samples", type=int, default=125_000,
help='Number of collected samples for each'
'evaluation.')
arg_alg.add_argument("--max-no-op-actions", type=int, default=30,
help='Maximum number of no-op actions performed at the'
'beginning of the episodes.')
# arg_alg.add_argument("--max-no-op-actions", type=int, default=30,
# help='Maximum number of no-op actions performed at the'
# 'beginning of the episodes.')
arg_alg.add_argument("--alpha-coeff", type=float, default=.6,
help='Prioritization exponent for prioritized experience replay.')
arg_alg.add_argument("--n-atoms", type=int, default=51,
@@ -318,12 +319,14 @@ def experiment():
output_shape=(mdp.info.action_space.n,),
n_actions=mdp.info.action_space.n,
n_features=Network.n_features,
optimizer=optimizer
optimizer=optimizer,
)
if args.algorithm not in ['cdqn', 'qdqn', 'rainbow']:
approximator_params['loss'] = F.smooth_l1_loss

approximator = NumpyTorchApproximator
approximator = TorchApproximator if args.use_cuda else NumpyTorchApproximator

TorchUtils.set_default_device('cuda:0' if torch.cuda.is_available() and args.use_cuda else 'cpu')

if args.prioritized:
replay_memory = PrioritizedReplayMemory(
Expand Down Expand Up @@ -410,7 +413,6 @@ def experiment():

# Evaluate initial policy
pi.set_epsilon(epsilon_test)
mdp.set_episode_end(False)
dataset = core.evaluate(n_steps=test_samples, render=args.render,
quiet=args.quiet, record=args.record)
scores.append(get_stats(dataset, logger))
@@ -421,7 +423,6 @@ def experiment():
logger.info('- Learning:')
# learning step
pi.set_epsilon(epsilon)
mdp.set_episode_end(True)
core.learn(n_steps=evaluation_frequency,
n_steps_per_fit=train_frequency, quiet=args.quiet)

@@ -431,7 +432,6 @@ def experiment():
logger.info('- Evaluation:')
# evaluation step
pi.set_epsilon(epsilon_test)
mdp.set_episode_end(False)
dataset = core.evaluate(n_steps=test_samples, render=args.render,
quiet=args.quiet)
scores.append(get_stats(dataset, logger))
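For context, a minimal sketch (not part of the diff) of the approximator and device selection introduced above; `use_cuda` here is a hypothetical stand-in for the script's `args.use_cuda` flag:

import torch

from mushroom_rl.approximators.parametric import NumpyTorchApproximator, TorchApproximator
from mushroom_rl.utils.torch import TorchUtils

use_cuda = True  # stand-in for args.use_cuda

# Torch-backed approximator when training on GPU, numpy-backed one otherwise.
approximator = TorchApproximator if use_cuda else NumpyTorchApproximator

# Route newly created torch tensors to the selected device.
TorchUtils.set_default_device('cuda:0' if torch.cuda.is_available() and use_cuda else 'cpu')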
49 changes: 18 additions & 31 deletions mushroom_rl/environments/atari.py
@@ -18,7 +18,14 @@ class Atari(Environment):
2015.
"""
def __init__(self, name,
def __init__(self,
name,
width = 84,
height = 84,
full_action_space = False,
repeat_action_probability = 0.25,
frameskip = 4,
framestack = 4,
headless = False
):
"""
@@ -35,28 +42,19 @@ def __init__(self, name,
assert 'v5' in name, 'This wrapper supports only v5 ALE environments'
self.env = gym.make(
name,
full_action_space=False,
full_action_space=full_action_space,
frameskip=1,
repeat_action_probability=0.25,
repeat_action_probability=repeat_action_probability,
render_mode='rgb_array'
)

# MDP parameters
self.name = name
self.state_height, self.state_width = (84, 84)
self.n_stacked_frames = 4
self.n_skipped_frames = 4
self.state_height, self.state_width = (height, width)
self.n_stacked_frames = frameskip
self.n_skipped_frames = framestack
self._headless = headless
# self._episode_ends_at_life = ends_at_life
# self._max_lives = self.env.unwrapped.ale.lives()
# self._lives = self._max_lives
# self._force_fire = None
# self._real_reset = True
# self._max_no_op_actions = max_no_op_actions
# self._history_length = history_length
# self._current_no_op = None

# assert self.env.unwrapped.get_action_meanings()[0] == 'NOOP'

self.original_state_height, self.original_state_width, _ = self.env.observation_space._shape
self.screen_buffer = [
np.empty((self.original_state_height, self.original_state_width), dtype=np.uint8),
@@ -67,18 +65,18 @@ def __init__(self, name,
action_space = Discrete(self.env.action_space.n)
observation_space = Box(
low=0., high=255., shape=(self.n_stacked_frames, self.state_height, self.state_width))
horizon = 1e4 # instead of np.inf
horizon = 27_000 # instead of np.inf
gamma = .99
dt = 1/60
mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt)

# Viewer
self._viewer = ImageViewer((self.state_height, self.state_width), dt, headless=self._headless)
self._viewer = ImageViewer((self.original_state_width, self.original_state_height), dt, headless=self._headless)

super().__init__(mdp_info)

def reset(self, state=None):
_, info = self.env.reset()
def reset(self, state=None, seed=None):
_, info = self.env.reset(seed=seed)

self.n_steps = 0

Expand Down Expand Up @@ -140,14 +138,3 @@ def render(self, record=False):
def stop(self):
self.env.close()
self._viewer.close()

# def set_episode_end(self, ends_at_life):
# """
# Setter.

# Args:
# ends_at_life (bool): whether the episode ends when a life is
# lost or not.

# """
# self._episode_ends_at_life = ends_at_life
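As a usage illustration (a sketch, not taken from the commit), the updated wrapper can be constructed with the keyword arguments shown in the new signature and reset with an explicit seed:

from mushroom_rl.environments import Atari

# Arguments mirror the new constructor; the wrapper accepts only v5 ALE ids.
# headless=True is chosen here for a display-free run (the default is False).
mdp = Atari(name='ALE/Breakout-v5',
            width=84, height=84,
            full_action_space=False,
            repeat_action_probability=0.25,
            frameskip=4, framestack=4,
            headless=True)

mdp.reset(seed=1)  # the seed is forwarded to gymnasium's reset()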
2 changes: 1 addition & 1 deletion mushroom_rl/environments/gymnasium_env.py
@@ -26,7 +26,7 @@ def __init__(self, name, horizon=None, gamma=0.99, headless=False, wrappers=None
Constructor.
Args:
name (str): gym id of the environment;
name (str): gymnasium id of the environment;
horizon (int): the horizon. If None, use the one from Gymnasium;
gamma (float, 0.99): the discount factor;
headless (bool, False): If True, the rendering is forced to be headless.
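A minimal sketch of the wrapper this docstring belongs to, assuming the class defined in gymnasium_env.py is exported as Gymnasium (name assumed here) and using 'CartPole-v1' as an example id; only parameters visible in the hunk are used:

from mushroom_rl.environments import Gymnasium

# horizon=None keeps the horizon provided by Gymnasium; gamma is the discount factor.
mdp = Gymnasium(name='CartPole-v1', horizon=None, gamma=0.99, headless=True)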
4 changes: 2 additions & 2 deletions tests/core/test_core.py
@@ -25,7 +25,7 @@ def fit(self, dataset):


def test_core():
mdp = Atari(name='BreakoutDeterministic-v4')
mdp = Atari(name='ALE/Breakout-v5', repeat_action_probability=0.0)

agent = DummyAgent(mdp.info)

@@ -45,7 +45,7 @@ def test_core():
info_lives = np.array(dataset.info['lives'])

print(info_lives)
lives_gt = np.array([5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
lives_gt = np.array([5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5.])
assert len(info_lives) == 20
assert np.all(info_lives == lives_gt)
assert len(dataset) == 20
14 changes: 3 additions & 11 deletions tests/environments/test_all_envs.py
@@ -18,23 +18,15 @@

def test_atari():
np.random.seed(1)
mdp = Atari(name='PongDeterministic-v4')
mdp.reset()
mdp = Atari(name='ALE/Pong-v5')
mdp.reset(seed=1)
for i in range(10):
ns, r, ab, _ = mdp.step([np.random.randint(mdp.info.action_space.n)])
ns_test = np.load('tests/environments/test_atari_1.npy')

assert np.allclose(ns, ns_test)

mdp = Atari(name='PongNoFrameskip-v4')
mdp.reset()
for i in range(10):
ns, r, ab, _ = mdp.step([np.random.randint(mdp.info.action_space.n)])
ns_test = np.load('tests/environments/test_atari_2.npy')
ns_test = np.load('tests/environments/test_atari_1.npy')

assert np.allclose(ns, ns_test)


def test_car_on_hill():
np.random.seed(1)
mdp = CarOnHill()
Binary file modified tests/environments/test_atari_1.npy
Binary file removed tests/environments/test_atari_2.npy
