diff --git a/examples/acrobot_a2c.py b/examples/acrobot_a2c.py
index 05bcdd2b..3f73c9d4 100644
--- a/examples/acrobot_a2c.py
+++ b/examples/acrobot_a2c.py
@@ -4,7 +4,7 @@
 
 from mushroom_rl.algorithms.actor_critic import A2C
 from mushroom_rl.core import Core, Logger
-from mushroom_rl.environments import Gym
+from mushroom_rl.environments import Gymnasium
 from mushroom_rl.policy import BoltzmannTorchPolicy
 from mushroom_rl.approximators.parametric.torch_approximator import *
 from mushroom_rl.rl_utils.parameters import Parameter
@@ -47,7 +47,7 @@ def experiment(n_epochs, n_steps, n_steps_per_fit, n_step_test):
     # MDP
     horizon = 1000
     gamma = 0.99
-    mdp = Gym('Acrobot-v1', horizon, gamma)
+    mdp = Gymnasium('Acrobot-v1', horizon, gamma, headless=False)
 
     # Policy
     policy_params = dict(
diff --git a/examples/acrobot_dqn.py b/examples/acrobot_dqn.py
index 81e5ea37..1c272630 100644
--- a/examples/acrobot_dqn.py
+++ b/examples/acrobot_dqn.py
@@ -54,7 +54,7 @@ def experiment(n_epochs, n_steps, n_steps_test):
     # MDP
     horizon = 1000
     gamma = 0.99
-    mdp = Gym('Acrobot-v1', horizon, gamma)
+    mdp = Gymnasium('Acrobot-v1', horizon, gamma, headless=False)
 
     # Policy
     epsilon = LinearParameter(value=1., threshold_value=.01, n=5000)
diff --git a/examples/atari_dqn.py b/examples/atari_dqn.py
index 5fc4cbda..fef112b9 100644
--- a/examples/atari_dqn.py
+++ b/examples/atari_dqn.py
@@ -276,9 +276,9 @@ def experiment():
     max_steps = args.max_steps
 
     # MDP
-    mdp = Atari(args.name, args.screen_width, args.screen_height,
+    mdp = GymnasiumAtari(args.name, args.screen_width, args.screen_height,
                 ends_at_life=True, history_length=args.history_length,
-                max_no_op_actions=args.max_no_op_actions)
+                max_no_op_actions=args.max_no_op_actions, headless=False)
 
     if args.load_path:
         logger = Logger(DQN.__name__, results_dir=None)
@@ -408,7 +408,7 @@ def experiment():
             pi.set_epsilon(epsilon_test)
             mdp.set_episode_end(False)
             dataset = core.evaluate(n_steps=test_samples, render=args.render,
-                                    quiet=args.quiet)
+                                    quiet=args.quiet, record=True)
             scores.append(get_stats(dataset, logger))
 
             np.save(folder_name + '/scores.npy', scores)
diff --git a/examples/mountain_car_sarsa.py b/examples/mountain_car_sarsa.py
index c15c31b5..809e4052 100644
--- a/examples/mountain_car_sarsa.py
+++ b/examples/mountain_car_sarsa.py
@@ -3,7 +3,7 @@
 
 from mushroom_rl.algorithms.value import TrueOnlineSARSALambda
 from mushroom_rl.core import Core, Logger
-from mushroom_rl.environments import Gym
+from mushroom_rl.environments import Gymnasium
 from mushroom_rl.features import Features
 from mushroom_rl.features.tiles import Tiles
 from mushroom_rl.policy import EpsGreedy
@@ -21,7 +21,7 @@ def experiment(alpha):
     np.random.seed()
 
     # MDP
-    mdp = Gym(name='MountainCar-v0', horizon=np.inf, gamma=1.)
+    mdp = Gymnasium(name='MountainCar-v0', horizon=int(1e4), gamma=1., headless=False)
 
     # Policy
     epsilon = Parameter(value=0.)
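Every example above applies the same substitution: the Gym wrapper is replaced by the new Gymnasium class added later in this patch, and rendering is made explicit through the headless flag. A minimal smoke test of that construction pattern (not part of the patch; it assumes gymnasium and its classic-control extras are installed):

    import numpy as np

    from mushroom_rl.environments import Gymnasium

    # Build the MDP exactly as the updated examples do; headless=True never opens a window.
    mdp = Gymnasium('Acrobot-v1', horizon=1000, gamma=0.99, headless=True)

    obs, info = mdp.reset()
    for _ in range(10):
        # Discrete actions are passed as length-1 arrays, as mushroom_rl's Core does.
        action = np.array([np.random.randint(mdp.info.action_space.n)])
        obs, reward, absorbing, step_info = mdp.step(action)
        if absorbing:
            obs, info = mdp.reset()
    mdp.stop()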
diff --git a/examples/pendulum_a2c.py b/examples/pendulum_a2c.py
index 63c20cd2..5a057cc0 100644
--- a/examples/pendulum_a2c.py
+++ b/examples/pendulum_a2c.py
@@ -7,7 +7,7 @@
 
 from tqdm import trange
 
 from mushroom_rl.core import Core, Logger
-from mushroom_rl.environments import Gym
+from mushroom_rl.environments import Gymnasium
 from mushroom_rl.algorithms.actor_critic import A2C
 from mushroom_rl.policy import GaussianTorchPolicy
@@ -45,7 +45,7 @@ def experiment(alg, env_id, horizon, gamma, n_epochs, n_steps, n_steps_per_fit,
     logger.strong_line()
     logger.info('Experiment Algorithm: ' + A2C.__name__)
 
-    mdp = Gym(env_id, horizon, gamma)
+    mdp = Gymnasium(env_id, horizon, gamma, headless=False)
 
     critic_params = dict(network=Network,
                          optimizer={'class': optim.RMSprop,
diff --git a/examples/pendulum_ddpg.py b/examples/pendulum_ddpg.py
index c1adfa88..29ed9642 100644
--- a/examples/pendulum_ddpg.py
+++ b/examples/pendulum_ddpg.py
@@ -7,7 +7,7 @@
 
 from mushroom_rl.algorithms.actor_critic import DDPG, TD3
 from mushroom_rl.core import Core, Logger
-from mushroom_rl.environments.gym_env import Gym
+from mushroom_rl.environments import Gymnasium
 from mushroom_rl.policy import OrnsteinUhlenbeckPolicy
 
 from tqdm import trange
@@ -76,7 +76,7 @@ def experiment(alg, n_epochs, n_steps, n_steps_test):
     # MDP
     horizon = 200
     gamma = 0.99
-    mdp = Gym('Pendulum-v1', horizon, gamma)
+    mdp = Gymnasium('Pendulum-v1', horizon, gamma, headless=False)
 
     # Policy
     policy_class = OrnsteinUhlenbeckPolicy
diff --git a/examples/pendulum_sac.py b/examples/pendulum_sac.py
index 031838bd..beb09028 100644
--- a/examples/pendulum_sac.py
+++ b/examples/pendulum_sac.py
@@ -7,7 +7,7 @@
 
 from mushroom_rl.algorithms.actor_critic import SAC
 from mushroom_rl.core import Core, Logger
-from mushroom_rl.environments.gym_env import Gym
+from mushroom_rl.environments import Gymnasium
 from mushroom_rl.utils import TorchUtils
 
 from tqdm import trange
@@ -76,7 +76,7 @@ def experiment(alg, n_epochs, n_steps, n_steps_test, save, load):
     # MDP
     horizon = 200
     gamma = 0.99
-    mdp = Gym('Pendulum-v1', horizon, gamma)
+    mdp = Gymnasium('Pendulum-v1', horizon, gamma, headless=False)
 
     # Settings
     initial_replay_size = 64
diff --git a/examples/pendulum_trust_region.py b/examples/pendulum_trust_region.py
index 51c6aac0..2e70bf8c 100644
--- a/examples/pendulum_trust_region.py
+++ b/examples/pendulum_trust_region.py
@@ -7,7 +7,7 @@
 
 from tqdm import trange
 
 from mushroom_rl.core import Core, Logger
-from mushroom_rl.environments import Gym
+from mushroom_rl.environments import Gymnasium
 from mushroom_rl.algorithms.actor_critic import PPO, TRPO
 from mushroom_rl.policy import GaussianTorchPolicy
@@ -43,7 +43,7 @@ def experiment(alg, env_id, horizon, gamma, n_epochs, n_steps, n_steps_per_fit,
     logger.strong_line()
     logger.info('Experiment Algorithm: ' + alg.__name__)
 
-    mdp = Gym(env_id, horizon, gamma)
+    mdp = Gymnasium(env_id, horizon, gamma, headless=False)
 
     critic_params = dict(network=Network,
                          optimizer={'class': optim.Adam,
diff --git a/mushroom_rl/environments/__init__.py b/mushroom_rl/environments/__init__.py
index 11ce95b6..1de23d0e 100644
--- a/mushroom_rl/environments/__init__.py
+++ b/mushroom_rl/environments/__init__.py
@@ -12,6 +12,20 @@
 except ImportError:
     pass
 
+try:
+    GymnasiumAtari = None
+    from .gymnasium_atari import GymnasiumAtari
+    GymnasiumAtari.register()
+except ImportError:
+    pass
+
+try:
+    Gymnasium = None
+    from .gymnasium_env import Gymnasium
+    Gymnasium.register()
+except ImportError:
+    pass
+
 try:
     DMControl = None
     from .dm_control_env import DMControl
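The try/except blocks above register the new classes only when gymnasium can be imported, mirroring how the existing environments are exposed. Assuming the registry behaves as it does for the existing Gym class, where Environment.make resolves a '<ClassName>.<env-id>' name, the new environments should also be constructible by name; a hypothetical sketch:

    from mushroom_rl.core import Environment

    # Hypothetical: relies on the '<ClassName>.<env-id>' convention used by the existing Gym class.
    mdp = Environment.make('Gymnasium.CartPole-v1', headless=True)
    print(mdp.info.observation_space.shape)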
diff --git a/mushroom_rl/environments/gymnasium_atari.py b/mushroom_rl/environments/gymnasium_atari.py
new file mode 100644
index 00000000..8f52b9ca
--- /dev/null
+++ b/mushroom_rl/environments/gymnasium_atari.py
@@ -0,0 +1,168 @@
+from copy import deepcopy
+from collections import deque
+
+import gymnasium as gym
+
+from mushroom_rl.core import Environment, MDPInfo
+from mushroom_rl.rl_utils.spaces import *
+from mushroom_rl.utils.frames import LazyFrames, preprocess_frame
+from mushroom_rl.utils.viewer import ImageViewer
+
+class MaxAndSkip(gym.Wrapper):
+    def __init__(self, env, skip, max_pooling=True):
+        gym.Wrapper.__init__(self, env)
+        self._obs_buffer = np.zeros((2,) + env.observation_space.shape,
+                                    dtype=np.uint8)
+        self._skip = skip
+        self._max_pooling = max_pooling
+
+    def reset(self):
+        return self.env.reset()
+
+    def step(self, action):
+        total_reward = 0.
+        for i in range(self._skip):
+            obs, reward, absorbing, _, info = self.env.step(action)
+            if i == self._skip - 2:
+                self._obs_buffer[0] = obs
+            if i == self._skip - 1:
+                self._obs_buffer[1] = obs
+            total_reward += reward
+            if absorbing:
+                break
+        if self._max_pooling:
+            frame = self._obs_buffer.max(axis=0)
+        else:
+            frame = self._obs_buffer.mean(axis=0)
+
+        # Return a 5-tuple so callers can unpack the (unused) truncated flag like any gymnasium env
+        return frame, total_reward, absorbing, False, info
+
+    def reset(self, **kwargs):
+        return self.env.reset(**kwargs)
+
+
+class GymnasiumAtari(Environment):
+    """
+    The Atari environment as presented in:
+    "Human-level control through deep reinforcement learning". Mnih et al.,
+    2015.
+
+    """
+    def __init__(self, name, width=84, height=84, ends_at_life=False,
+                 max_pooling=True, history_length=4, max_no_op_actions=30, headless=False):
+        """
+        Constructor.
+
+        Args:
+            name (str): id name of the Atari game in Gymnasium;
+            width (int, 84): width of the screen;
+            height (int, 84): height of the screen;
+            ends_at_life (bool, False): whether the episode ends when a life is
+                lost or not;
+            max_pooling (bool, True): whether to do max-pooling or
+                average-pooling of the last two frames when using NoFrameskip;
+            history_length (int, 4): number of frames to form a state;
+            max_no_op_actions (int, 30): maximum number of no-op actions to
+                execute at the beginning of an episode;
+            headless (bool, False): if True, the rendering is forced to be headless.
+ + """ + # MPD creation + if 'NoFrameskip' in name: + self.env = MaxAndSkip(gym.make(name, render_mode='rgb_array'), history_length, max_pooling) + else: + self.env = gym.make(name, render_mode='rgb_array') + + # MDP parameters + self._headless = headless + self._img_size = (width, height) + self._episode_ends_at_life = ends_at_life + self._max_lives = self.env.unwrapped.ale.lives() + self._lives = self._max_lives + self._force_fire = None + self._real_reset = True + self._max_no_op_actions = max_no_op_actions + self._history_length = history_length + self._current_no_op = None + + assert self.env.unwrapped.get_action_meanings()[0] == 'NOOP' + + # MDP properties + action_space = Discrete(self.env.action_space.n) + observation_space = Box( + low=0., high=255., shape=(history_length, self._img_size[1], self._img_size[0])) + horizon = 1e4 # instead of np.inf + gamma = .99 + dt = 1/60 + mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt) + + # Viewer + self._viewer = ImageViewer((self._img_size[1], self._img_size[0]), dt, headless=self._headless) + + super().__init__(mdp_info) + + def reset(self, state=None): + if self._real_reset: + state, info = self.env.reset() + self._state = preprocess_frame(state, self._img_size) + self._state = deque([deepcopy( + self._state) for _ in range(self._history_length)], + maxlen=self._history_length + ) + self._lives = self._max_lives + + self._force_fire = self.env.unwrapped.get_action_meanings()[1] == 'FIRE' + + self._current_no_op = np.random.randint(self._max_no_op_actions + 1) + + return LazyFrames(list(self._state), self._history_length), info + + def step(self, action): + action = action[0] + + # Force FIRE action to start episodes in games with lives + if self._force_fire: + obs, _, _, _, _ = self.env.env.step(1) + self._force_fire = False + while self._current_no_op > 0: + obs, _, _, _, _ = self.env.env.step(0) + self._current_no_op -= 1 + + obs, reward, absorbing, _, info = self.env.step(action) + self._real_reset = absorbing + + if info['lives'] != self._lives: + if self._episode_ends_at_life: + absorbing = True + self._lives = info['lives'] + self._force_fire = self.env.unwrapped.get_action_meanings()[1] == 'FIRE' + + self._state.append(preprocess_frame(obs, self._img_size)) + + return LazyFrames(list(self._state), self._history_length), reward, absorbing, info + + def render(self, record=False): + img = self.env.render() + + self._viewer.display(img) + + if record: + return img + else: + return None + + def stop(self): + self.env.close() + self._viewer.close() + self._real_reset = True + + def set_episode_end(self, ends_at_life): + """ + Setter. + + Args: + ends_at_life (bool): whether the episode ends when a life is + lost or not. + + """ + self._episode_ends_at_life = ends_at_life diff --git a/mushroom_rl/environments/gymnasium_env.py b/mushroom_rl/environments/gymnasium_env.py new file mode 100644 index 00000000..80519a12 --- /dev/null +++ b/mushroom_rl/environments/gymnasium_env.py @@ -0,0 +1,157 @@ +import warnings + +import gymnasium as gym +from gymnasium import spaces as gym_spaces + +try: + import pybullet_envs + pybullet_found = True +except ImportError: + pybullet_found = False + +from mushroom_rl.core import Environment, MDPInfo +from mushroom_rl.rl_utils.spaces import * +from mushroom_rl.utils.viewer import ImageViewer + +gym.logger.set_level(40) + + +class Gymnasium(Environment): + """ + Interface for Gymnasium environments. 
diff --git a/mushroom_rl/environments/gymnasium_env.py b/mushroom_rl/environments/gymnasium_env.py
new file mode 100644
index 00000000..80519a12
--- /dev/null
+++ b/mushroom_rl/environments/gymnasium_env.py
@@ -0,0 +1,157 @@
+import warnings
+
+import gymnasium as gym
+from gymnasium import spaces as gym_spaces
+
+try:
+    import pybullet_envs
+    pybullet_found = True
+except ImportError:
+    pybullet_found = False
+
+from mushroom_rl.core import Environment, MDPInfo
+from mushroom_rl.rl_utils.spaces import *
+from mushroom_rl.utils.viewer import ImageViewer
+
+gym.logger.set_level(40)
+
+
+class Gymnasium(Environment):
+    """
+    Interface for Gymnasium environments. It makes it possible to use every
+    Gymnasium environment just by providing the id, except for the Atari games,
+    which are managed in a separate class.
+
+    """
+    def __init__(self, name, horizon=None, gamma=0.99, headless=False, wrappers=None, wrappers_args=None,
+                 **env_args):
+        """
+        Constructor.
+
+        Args:
+            name (str): gymnasium id of the environment;
+            horizon (int): the horizon. If None, use the one from Gymnasium;
+            gamma (float, 0.99): the discount factor;
+            headless (bool, False): if True, the rendering is forced to be headless;
+            wrappers (list, None): list of wrappers to apply over the environment.
+                It is possible to pass arguments to the wrappers by providing
+                a tuple with two elements: the gymnasium wrapper class and a
+                dictionary containing the parameters needed by the wrapper
+                constructor;
+            wrappers_args (list, None): list of list of arguments for each wrapper;
+            **env_args: other gymnasium environment parameters.
+
+        """
+
+        # MDP creation
+        self._not_pybullet = True
+        self._first = True
+        self._headless = headless
+        self._viewer = None
+        if pybullet_found and '- ' + name in pybullet_envs.getList():
+            import pybullet
+            pybullet.connect(pybullet.DIRECT)
+            self._not_pybullet = False
+
+        self.env = gym.make(name, render_mode='rgb_array', **env_args)  # always use the rgb_array render mode
+
+        if wrappers is not None:
+            if wrappers_args is None:
+                wrappers_args = [dict()] * len(wrappers)
+            for wrapper, args in zip(wrappers, wrappers_args):
+                if isinstance(wrapper, tuple):
+                    self.env = wrapper[0](self.env, *args, **wrapper[1])
+                else:
+                    self.env = wrapper(self.env, *args, **env_args)
+
+        horizon = self._set_horizon(self.env, horizon)
+
+        # MDP properties
+        assert not isinstance(self.env.observation_space,
+                              gym_spaces.MultiDiscrete)
+        assert not isinstance(self.env.action_space, gym_spaces.MultiDiscrete)
+
+        dt = self.env.unwrapped.dt if hasattr(self.env.unwrapped, "dt") else 0.1
+        action_space = self._convert_gym_space(self.env.action_space)
+        observation_space = self._convert_gym_space(self.env.observation_space)
+        mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt)
+
+        if isinstance(action_space, Discrete):
+            self._convert_action = lambda a: a[0]
+        else:
+            self._convert_action = lambda a: a
+
+        super().__init__(mdp_info)
+
+    def reset(self, state=None):
+        if state is None:
+            state, info = self.env.reset()
+            return np.atleast_1d(state), info
+        else:
+            _, info = self.env.reset()
+            self.env.state = state
+
+            return np.atleast_1d(state), info
+
+    def step(self, action):
+        action = self._convert_action(action)
+        obs, reward, absorbing, _, info = self.env.step(action)  # the truncated flag is ignored
+
+        return np.atleast_1d(obs), reward, absorbing, info
+
+    def render(self, record=False):
+        if self._first or self._not_pybullet:
+            img = self.env.render()
+
+            if self._first:
+                self._viewer = ImageViewer((img.shape[1], img.shape[0]), self.info.dt, headless=self._headless)
+
+            self._viewer.display(img)
+
+            self._first = False
+
+            if record:
+                return img
+            else:
+                return None
+
+        return None
+
+    def stop(self):
+        try:
+            if self._not_pybullet:
+                self.env.close()
+
+            if self._viewer is not None:
+                self._viewer.close()
+        except:
+            pass
+
+    @staticmethod
+    def _set_horizon(env, horizon):
+
+        while not hasattr(env, '_max_episode_steps') and env.env != env.unwrapped:
+            env = env.env
+
+        if horizon is None:
+            if not hasattr(env, '_max_episode_steps'):
+                raise RuntimeError('This gymnasium environment has no specified time limit!')
+            horizon = env._max_episode_steps
+            if horizon == np.inf:
+                warnings.warn("Horizon cannot be infinity.")
+                horizon = int(1e4)
+
+        if hasattr(env, '_max_episode_steps'):
+            env._max_episode_steps = horizon
+
+        return horizon
+
+    @staticmethod
+    def _convert_gym_space(space):
+        if isinstance(space, gym_spaces.Discrete):
+            return Discrete(space.n)
+        elif isinstance(space, gym_spaces.Box):
+            return Box(low=space.low, high=space.high, shape=space.shape)
+        else:
+            raise ValueError
+
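A short sketch of the wrappers argument documented above (hypothetical usage, not part of the patch): each entry is either a wrapper class or a (wrapper_class, kwargs) tuple applied on top of gym.make(...):

    import gymnasium as gym

    from mushroom_rl.environments import Gymnasium

    # The tuple form passes constructor kwargs to the wrapper; ClipAction needs none.
    mdp = Gymnasium('Pendulum-v1', horizon=200, gamma=0.99, headless=True,
                    wrappers=[(gym.wrappers.ClipAction, {})])

    obs, info = mdp.reset()
    frame = mdp.render(record=True)  # with record=True the rendered frame is returned
    mdp.stop()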
warnings.warn("Horizon can not be infinity.") + horizon = int(1e4) + + if hasattr(env, '_max_episode_steps'): + env._max_episode_steps = horizon + + return horizon + + @staticmethod + def _convert_gym_space(space): + if isinstance(space, gym_spaces.Discrete): + return Discrete(space.n) + elif isinstance(space, gym_spaces.Box): + return Box(low=space.low, high=space.high, shape=space.shape) + else: + raise ValueError + diff --git a/mushroom_rl/utils/viewer.py b/mushroom_rl/utils/viewer.py index 0ffbcc31..26805501 100644 --- a/mushroom_rl/utils/viewer.py +++ b/mushroom_rl/utils/viewer.py @@ -13,19 +13,21 @@ class ImageViewer: Interface to pygame for visualizing plain images. """ - def __init__(self, size, dt): + def __init__(self, size, dt, headless = False): """ Constructor. Args: size ([list, tuple]): size of the displayed image; dt (float): duration of a control step. + headless (bool, False): skip the display. """ self._size = size self._dt = dt self._initialized = False self._screen = None + self._headless = headless def display(self, img): """ @@ -35,6 +37,9 @@ def display(self, img): img: image to display. """ + if self._headless: + return + if not self._initialized: pygame.init() self._initialized = True