From a9f3f1a56efdec366240eec33885a44d4f7a7ca1 Mon Sep 17 00:00:00 2001 From: boris-il-forte Date: Mon, 22 Jan 2024 16:25:37 +0100 Subject: [PATCH] Work on Actor-Critic algorithms - Minor refactoring in Agent - Created new class OnPolicyDeepAC - Fixed some issues of PPO_BPTT and policy class - PPO_BPTT is now using our latest interface - Implemented agent normalization in PPO, TRPO, PPO_BPTT --- examples/gym_recurrent_ppo.py | 16 +-- .../deep_actor_critic/__init__.py | 4 +- .../deep_actor_critic/deep_actor_critic.py | 21 ++- .../actor_critic/deep_actor_critic/ppo.py | 7 +- .../deep_actor_critic/ppo_bptt.py | 136 +++++++++--------- .../actor_critic/deep_actor_critic/trpo.py | 9 +- mushroom_rl/core/agent.py | 56 ++++---- mushroom_rl/policy/recurrent_torch_policy.py | 2 +- 8 files changed, 135 insertions(+), 116 deletions(-) diff --git a/examples/gym_recurrent_ppo.py b/examples/gym_recurrent_ppo.py index d68d6207..fe234e33 100644 --- a/examples/gym_recurrent_ppo.py +++ b/examples/gym_recurrent_ppo.py @@ -223,7 +223,7 @@ def experiment( # setup critic input_shape_critic = (mdp.info.observation_space.shape[0]+2*n_hidden_features,) critic_params = dict(network=PPOCriticBPTTNetwork, - optimizer={'class': optim.Adam, + optimizer={'class': optim.Adam, 'params': {'lr': lr_critic, 'weight_decay': 0.0}}, loss=torch.nn.MSELoss(), @@ -240,7 +240,7 @@ def experiment( ) alg_params = dict(actor_optimizer={'class': optim.Adam, - 'params': {'lr': lr_actor, + 'params': {'lr': lr_actor, 'weight_decay': 0.0}}, n_epochs_policy=n_epochs_policy, batch_size=batch_size_actor, @@ -258,9 +258,9 @@ def experiment( # Evaluation dataset = core.evaluate(n_episodes=5) - J = np.mean(dataset.discounted_return) - R = np.mean(dataset.undiscounted_return) - L = np.mean(dataset.episodes_length) + J = dataset.discounted_return.mean() + R = dataset.undiscounted_return.mean() + L = dataset.episodes_length.mean() logger.log_numpy(R=R, J=J, L=L) logger.epoch_info(0, R=R, J=J, L=L) @@ -269,9 +269,9 @@ def experiment( # Evaluation dataset = core.evaluate(n_episodes=n_episode_eval) - J = np.mean(dataset.discounted_return) - R = np.mean(dataset.undiscounted_return) - L = np.mean(dataset.episodes_length) + J = dataset.discounted_return.mean() + R = dataset.undiscounted_return.mean() + L = dataset.episodes_length.mean() logger.log_numpy(R=R, J=J, L=L) logger.epoch_info(i, R=R, J=J, L=L) diff --git a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/__init__.py b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/__init__.py index 9740f682..f561b311 100644 --- a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/__init__.py +++ b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/__init__.py @@ -1,4 +1,4 @@ -from .deep_actor_critic import DeepAC +from .deep_actor_critic import OnPolicyDeepAC, DeepAC from .a2c import A2C from .ddpg import DDPG from .td3 import TD3 @@ -6,5 +6,3 @@ from .trpo import TRPO from .ppo import PPO from .ppo_bptt import PPO_BPTT - -__all__ = ['DeepAC', 'A2C', 'DDPG', 'TD3', 'SAC', 'TRPO', 'PPO', 'PPO_BPTT'] \ No newline at end of file diff --git a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/deep_actor_critic.py b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/deep_actor_critic.py index 0ea5a4eb..1ae11b32 100644 --- a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/deep_actor_critic.py +++ b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/deep_actor_critic.py @@ -2,10 +2,27 @@ from mushroom_rl.utils.torch import TorchUtils +class OnPolicyDeepAC(Agent): + def 
_preprocess_state(self, state, next_state, output_old=True): + state_old = None + + if output_old: + state_old = self._agent_preprocess(state) + + self._update_agent_preprocessor(state) + state = self._agent_preprocess(state) + next_state = self._agent_preprocess(next_state) + + if output_old: + return state, next_state, state_old + else: + return state, next_state + + class DeepAC(Agent): """ - Base class for algorithms that uses the reparametrization trick, such as - SAC, DDPG and TD3. + Base class for off policy deep actor-critic algorithms. + These algorithms use the reparametrization trick, such as SAC, DDPG and TD3. """ diff --git a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo.py b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo.py index dd339085..b4524881 100644 --- a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo.py +++ b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo.py @@ -3,7 +3,7 @@ import torch import torch.nn.functional as F -from mushroom_rl.core import Agent +from mushroom_rl.algorithms.actor_critic.deep_actor_critic import OnPolicyDeepAC from mushroom_rl.approximators import Regressor from mushroom_rl.approximators.parametric import TorchApproximator from mushroom_rl.utils.torch import TorchUtils @@ -12,7 +12,7 @@ from mushroom_rl.rl_utils.parameters import to_parameter -class PPO(Agent): +class PPO(OnPolicyDeepAC): """ Proximal Policy Optimization algorithm. "Proximal Policy Optimization Algorithms". @@ -72,6 +72,7 @@ def __init__(self, mdp_info, policy, actor_optimizer, critic_params, def fit(self, dataset): state, action, reward, next_state, absorbing, last = dataset.parse(to='torch') + state, next_state, state_old = self._preprocess_state(state, next_state) v_target, adv = compute_gae(self._V, state, next_state, reward, absorbing, last, self.mdp_info.gamma, self._lambda()) @@ -80,7 +81,7 @@ def fit(self, dataset): adv = adv.detach() v_target = v_target.detach() - old_pol_dist = self.policy.distribution_t(state) + old_pol_dist = self.policy.distribution_t(state_old) old_log_p = old_pol_dist.log_prob(action)[:, None].detach() self._V.fit(state, v_target, **self._critic_fit_params) diff --git a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo_bptt.py b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo_bptt.py index 27b094a8..bfbed4c2 100644 --- a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo_bptt.py +++ b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/ppo_bptt.py @@ -1,15 +1,14 @@ import torch -from mushroom_rl.core import Agent +from mushroom_rl.algorithms.actor_critic.deep_actor_critic import OnPolicyDeepAC from mushroom_rl.approximators import Regressor from mushroom_rl.approximators.parametric import TorchApproximator from mushroom_rl.utils.torch import TorchUtils from mushroom_rl.utils.minibatches import minibatch_generator from mushroom_rl.rl_utils.parameters import to_parameter -from mushroom_rl.rl_utils.preprocessors import StandardizationPreprocessor -class PPO_BPTT(Agent): +class PPO_BPTT(OnPolicyDeepAC): """ Proximal Policy Optimization algorithm. "Proximal Policy Optimization Algorithms". @@ -71,81 +70,84 @@ def __init__(self, mdp_info, policy, actor_optimizer, critic_params, _dim_env_state='primitive' ) - # add the standardization preprocessor - self._core_preprocessors.append(StandardizationPreprocessor(mdp_info)) - - def divide_state_to_env_hidden_batch(self, states): - assert len(states.shape) > 1, "This function only divides batches of states." 
- return states[:, 0:self._dim_env_state], states[:, self._dim_env_state:] - def fit(self, dataset): - obs, act, r, obs_next, absorbing, last = dataset.parse(to='torch') + state, action, reward, next_state, absorbing, last = dataset.parse(to='torch') + state, next_state, state_old = self._preprocess_state(state, next_state) + policy_state, policy_next_state = dataset.parse_policy_state(to='torch') - obs_seq, policy_state_seq, act_seq, obs_next_seq, policy_next_state_seq, lengths = \ - self.transform_to_sequences(obs, policy_state, act, obs_next, policy_next_state, last, absorbing) + state_old_seq, state_seq, policy_state_seq, act_seq, state_next_seq, policy_next_state_seq, lengths = \ + self._transform_to_sequences(state_old, state, policy_state, action, next_state, policy_next_state, + last, absorbing) - v_target, adv = self.compute_gae(self._V, obs_seq, policy_state_seq, obs_next_seq, policy_next_state_seq, - lengths, r, absorbing, last, self.mdp_info.gamma, self._lambda()) + v_target, adv = self.compute_gae(self._V, state_seq, policy_state_seq, state_next_seq, policy_next_state_seq, + lengths, reward, absorbing, last, self.mdp_info.gamma, self._lambda()) adv = (adv - torch.mean(adv)) / (torch.std(adv) + 1e-8) - old_pol_dist = self.policy.distribution_t(obs_seq, policy_state_seq, lengths) - old_log_p = old_pol_dist.log_prob(act)[:, None].detach() + old_pol_dist = self.policy.distribution_t(state_old_seq, policy_state_seq, lengths) + old_log_p = old_pol_dist.log_prob(action)[:, None].detach() - self._V.fit(obs_seq, policy_state_seq, lengths, v_target, **self._critic_fit_params) + self._V.fit(state_seq, policy_state_seq, lengths, v_target, **self._critic_fit_params) - self._update_policy(obs_seq, policy_state_seq, act, lengths, adv, old_log_p) + self._update_policy(state_seq, policy_state_seq, action, lengths, adv, old_log_p) # Print fit information - self._log_info(dataset, obs_seq, policy_state_seq, lengths, v_target, old_pol_dist) + self._log_info(dataset, state_seq, policy_state_seq, lengths, v_target, old_pol_dist) self._iter += 1 - def transform_to_sequences(self, states, policy_states, actions, next_states, policy_next_states, last, absorbing): - - s = torch.empty(len(states), self._truncation_length, states.shape[-1]) - ps = torch.empty(len(states), policy_states.shape[-1]) - a = torch.empty(len(actions), self._truncation_length, actions.shape[-1]) - ss = torch.empty(len(states), self._truncation_length, states.shape[-1]) - pss = torch.empty(len(states), policy_states.shape[-1]) - lengths = torch.empty(len(states), dtype=torch.long) - - for i in range(len(states)): - # determine the begin of a sequence - begin_seq = max(i - self._truncation_length + 1, 0) - end_seq = i + 1 - - # maybe the sequence contains more than one trajectory, so we need to cut it so that it contains only one - lasts_absorbing = last[begin_seq - 1: i].int() + absorbing[begin_seq - 1: i].int() - begin_traj = torch.where(lasts_absorbing > 0) - sequence_is_shorter_than_requested = len(*begin_traj) > 0 - if sequence_is_shorter_than_requested: - begin_seq = begin_seq + begin_traj[0][-1] - - # get the sequences - states_seq = states[begin_seq:end_seq] - actions_seq = actions[begin_seq:end_seq] - next_states_seq = next_states[begin_seq:end_seq] - - # apply padding - length_seq = len(states_seq) - padded_states = torch.concatenate([states_seq, - torch.zeros((self._truncation_length - states_seq.shape[0], - states_seq.shape[1]))]) - padded_next_states = torch.concatenate([next_states_seq, - 
torch.zeros((self._truncation_length - next_states_seq.shape[0], - next_states_seq.shape[1]))]) - padded_action_seq = torch.concatenate([actions_seq, - torch.zeros((self._truncation_length - actions_seq.shape[0], - actions_seq.shape[1]))]) - - s[i] = padded_states - ps[i] = policy_states[begin_seq] - a[i] = padded_action_seq - ss[i] = padded_next_states - pss[i] = policy_next_states[begin_seq] - - lengths[i] = length_seq - - return s.detach(), ps.detach(), a.detach(), ss.detach(), pss.detach(), lengths.detach() + def _transform_to_sequences(self, states_old, states, policy_states, actions, next_states, policy_next_states, + last, absorbing): + with torch.no_grad(): + s_old = torch.empty(len(states), self._truncation_length, states.shape[-1]) + s = torch.empty(len(states), self._truncation_length, states.shape[-1]) + ps = torch.empty(len(states), policy_states.shape[-1]) + a = torch.empty(len(actions), self._truncation_length, actions.shape[-1]) + ss = torch.empty(len(states), self._truncation_length, states.shape[-1]) + pss = torch.empty(len(states), policy_states.shape[-1]) + lengths = torch.empty(len(states), dtype=torch.long) + + for i in range(len(states)): + # determine the begin of a sequence + begin_seq = max(i - self._truncation_length + 1, 0) + end_seq = i + 1 + + # the sequence may contain more than one trajectory, we need to cut it so that it contains only one + lasts_absorbing = last[begin_seq - 1: i].int() + absorbing[begin_seq - 1: i].int() + begin_traj = torch.where(lasts_absorbing > 0) + sequence_is_shorter_than_requested = len(*begin_traj) > 0 + if sequence_is_shorter_than_requested: + begin_seq = begin_seq + begin_traj[0][-1] + + # get the sequences + states_old_seq = states_old[begin_seq:end_seq] + states_seq = states[begin_seq:end_seq] + actions_seq = actions[begin_seq:end_seq] + next_states_seq = next_states[begin_seq:end_seq] + + # apply padding + length_seq = len(states_seq) + padded_states_old = torch.concatenate([states_old_seq, + torch.zeros((self._truncation_length - states_seq.shape[0], + states_seq.shape[1]))]) + padded_states = torch.concatenate([states_seq, + torch.zeros((self._truncation_length - states_seq.shape[0], + states_seq.shape[1]))]) + padded_next_states = torch.concatenate([next_states_seq, + torch.zeros((self._truncation_length - next_states_seq.shape[0], + next_states_seq.shape[1]))]) + padded_action_seq = torch.concatenate([actions_seq, + torch.zeros((self._truncation_length - actions_seq.shape[0], + actions_seq.shape[1]))]) + + s_old[i] = padded_states_old + s[i] = padded_states + ps[i] = policy_states[begin_seq] + a[i] = padded_action_seq + ss[i] = padded_next_states + pss[i] = policy_next_states[begin_seq] + + lengths[i] = length_seq + + return s_old, s, ps, a, ss, pss, lengths def _update_policy(self, obs, pi_h, act, lengths, adv, old_log_p): for epoch in range(self._n_epochs_policy()): diff --git a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/trpo.py b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/trpo.py index 7c0d8698..979a91f6 100644 --- a/mushroom_rl/algorithms/actor_critic/deep_actor_critic/trpo.py +++ b/mushroom_rl/algorithms/actor_critic/deep_actor_critic/trpo.py @@ -5,7 +5,7 @@ import torch import torch.nn.functional as F -from mushroom_rl.core import Agent +from mushroom_rl.algorithms.actor_critic.deep_actor_critic import OnPolicyDeepAC from mushroom_rl.approximators import Regressor from mushroom_rl.approximators.parametric import TorchApproximator from mushroom_rl.utils.torch import TorchUtils @@ -13,7 +13,7 @@ 
from mushroom_rl.rl_utils.parameters import to_parameter -class TRPO(Agent): +class TRPO(OnPolicyDeepAC): """ Trust Region Policy optimization algorithm. "Trust Region Policy Optimization". @@ -83,6 +83,7 @@ def __init__(self, mdp_info, policy, critic_params, ent_coeff=0., max_kl=.001, l def fit(self, dataset): state, action, reward, next_state, absorbing, last = dataset.parse(to='torch') + state, next_state, state_old = self._preprocess_state(state, next_state) v_target, adv = compute_gae(self._V, state, next_state, reward, absorbing, last, self.mdp_info.gamma, self._lambda()) @@ -93,8 +94,8 @@ def fit(self, dataset): # Policy update self._old_policy = deepcopy(self.policy) - old_pol_dist = self._old_policy.distribution_t(state) - old_log_prob = self._old_policy.log_prob_t(state, action).detach() + old_pol_dist = self._old_policy.distribution_t(state_old) + old_log_prob = self._old_policy.log_prob_t(state_old, action).detach() TorchUtils.zero_grad(self.policy.parameters()) loss = self._compute_loss(state, action, adv, old_log_prob) diff --git a/mushroom_rl/core/agent.py b/mushroom_rl/core/agent.py index d2a314ab..f9f526d7 100644 --- a/mushroom_rl/core/agent.py +++ b/mushroom_rl/core/agent.py @@ -103,34 +103,6 @@ def draw_action(self, state, policy_state=None): return self._convert_to_env_backend(action), self._convert_to_env_backend(next_policy_state) - def _agent_preprocess(self, state): - """ - Applies all the agent's preprocessors to the state. - - Args: - state (Array): the state where the agent is; - - Returns: - The preprocessed state. - - """ - for p in self._agent_preprocessors: - state = p(state) - return state - - def _update_agent_preprocessor(self, state): - """ - Updates the stats of all the agent's preprocessors given the state. - - Args: - state (Array): the state where the agent is; - - """ - for i, p in enumerate(self._agent_preprocessors, 1): - p.update(state) - if i < len(self._agent_preprocessors): - state = p(state) - def episode_start(self, initial_state, episode_info): """ Called by the Core when a new episode starts. @@ -214,6 +186,34 @@ def _convert_to_env_backend(self, array): def _convert_to_agent_backend(self, array): return self._agent_backend.convert_to_backend(self._env_backend, array) + def _agent_preprocess(self, state): + """ + Applies all the agent's preprocessors to the state. + + Args: + state (Array): the state where the agent is; + + Returns: + The preprocessed state. + + """ + for p in self._agent_preprocessors: + state = p(state) + return state + + def _update_agent_preprocessor(self, state): + """ + Updates the stats of all the agent's preprocessors given the state. + + Args: + state (Array): the state where the agent is; + + """ + for i, p in enumerate(self._agent_preprocessors, 1): + p.update(state) + if i < len(self._agent_preprocessors): + state = p(state) + @property def info(self): return self._info diff --git a/mushroom_rl/policy/recurrent_torch_policy.py b/mushroom_rl/policy/recurrent_torch_policy.py index 5598704c..51960841 100644 --- a/mushroom_rl/policy/recurrent_torch_policy.py +++ b/mushroom_rl/policy/recurrent_torch_policy.py @@ -22,7 +22,7 @@ def draw_action(self, state, policy_state): state = TorchUtils.to_float_tensor(state) policy_state = torch.as_tensor(policy_state) a, policy_state = self.draw_action_t(state, policy_state) - return torch.squeeze(a, dim=0).detach().cpu().numpy(), policy_state + return torch.squeeze(a, dim=0), policy_state def draw_action_t(self, state, policy_state): lengths = torch.tensor([1])
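
Note (illustrative, not part of the patch): the sketch below is a minimal, self-contained mirror of the _preprocess_state contract introduced in OnPolicyDeepAC, assuming a toy running-standardization preprocessor. RunningStandardizer and preprocess_state are hypothetical names for illustration only, not mushroom_rl API. It shows why the method returns three tensors: the old policy distribution is evaluated on states normalized with the pre-update statistics (state_old), while the critic fit and the policy update use states normalized with the refreshed statistics.

import numpy as np


class RunningStandardizer:
    """Toy stand-in for an agent preprocessor keeping running mean/std statistics."""

    def __init__(self, eps=1e-8):
        self._mean = None
        self._var = None
        self._count = 0
        self._eps = eps

    def __call__(self, state):
        if self._mean is None:
            return state
        return (state - self._mean) / np.sqrt(self._var + self._eps)

    def update(self, state):
        # Combine the running statistics with those of the new batch.
        batch_mean, batch_var, n = state.mean(0), state.var(0), state.shape[0]
        if self._mean is None:
            self._mean, self._var, self._count = batch_mean, batch_var, n
        else:
            tot = self._count + n
            delta = batch_mean - self._mean
            self._var = (self._var * self._count + batch_var * n
                         + delta ** 2 * self._count * n / tot) / tot
            self._mean = self._mean + delta * n / tot
            self._count = tot


def preprocess_state(preprocessor, state, next_state):
    # Same ordering as OnPolicyDeepAC._preprocess_state: normalize with the old
    # statistics first (for the old policy distribution), then update the
    # statistics and normalize the states used by the critic and the update.
    state_old = preprocessor(state)
    preprocessor.update(state)
    return preprocessor(state), preprocessor(next_state), state_old


if __name__ == '__main__':
    prep = RunningStandardizer()
    s = np.random.randn(32, 4) * 3.0 + 1.0
    s_next = np.random.randn(32, 4) * 3.0 + 1.0
    state, next_state, state_old = preprocess_state(prep, s, s_next)
    # state is standardized with the updated statistics; state_old is left
    # unchanged on the first batch because no statistics existed before it.
    print(state.mean(0), state_old.mean(0))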
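
The following sketch, also illustrative rather than part of the patch, captures the core idea of the rewritten _transform_to_sequences: each step becomes a zero-padded window of at most truncation_length past steps that never crosses an episode boundary, together with its true length, which PPO_BPTT forwards to the recurrent networks. build_truncated_sequences is a simplified, single-tensor version under that assumption; the actual method additionally handles old/next states, policy states and absorbing flags.

import torch


def build_truncated_sequences(states, last, truncation_length):
    # states: (n_steps, dim) tensor; last: (n_steps,) bool tensor marking the
    # final step of each episode. Returns zero-padded windows and their lengths.
    n, dim = states.shape
    seqs = torch.zeros(n, truncation_length, dim)
    lengths = torch.empty(n, dtype=torch.long)

    for i in range(n):
        begin = max(i - truncation_length + 1, 0)
        # If an episode ended inside the window (strictly before step i),
        # restart the sequence right after it so it holds a single trajectory.
        for j in range(i - 1, begin - 1, -1):
            if last[j]:
                begin = j + 1
                break
        length = i + 1 - begin
        seqs[i, :length] = states[begin:i + 1]
        lengths[i] = length

    return seqs, lengths


if __name__ == '__main__':
    states = torch.arange(6, dtype=torch.float32).reshape(6, 1)
    last = torch.tensor([False, False, True, False, False, False])
    seqs, lengths = build_truncated_sequences(states, last, truncation_length=3)
    # tensor([1, 2, 3, 1, 2, 3]): step 3 restarts after the episode end at step 2.
    print(lengths)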