Skip to content

Commit

Permalink
Introduced new preprocessors.
Browse files Browse the repository at this point in the history
- There are now two types of preprocessors: agent_preprocessors and core_preprocessors.
- The former do not change the dataset returned by the core, while the latter do.
  • Loading branch information
robfiras committed Jan 18, 2024
1 parent b4a0a02 commit f408f2d
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 15 deletions.
2 changes: 1 addition & 1 deletion examples/plotting_and_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def experiment(n_epochs, n_iterations, ep_per_run):

# normalization callback
prepro = MinMaxPreprocessor(mdp_info=mdp.info)
agent.add_preprocessor(prepro)
agent.add_core_preprocessor(prepro)

# plotting callback
plotter = PlotDataset(mdp.info, obs_normalized=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, mdp_info, policy, actor_optimizer, critic_params,
)

# add the standardization preprocessor
self._preprocessors.append(StandardizationPreprocessor(mdp_info))
self._core_preprocessors.append(StandardizationPreprocessor(mdp_info))

def divide_state_to_env_hidden_batch(self, states):
assert len(states.shape) > 1, "This function only divides batches of states."
Expand Down
59 changes: 51 additions & 8 deletions mushroom_rl/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def __init__(self, mdp_info, policy, is_episodic=False, backend='numpy'):
self._agent_backend = ArrayBackend.get_array_backend(backend)
self._env_backend = ArrayBackend.get_array_backend(self.mdp_info.backend)

self._preprocessors = list()
self._core_preprocessors = list()
self._agent_preprocessors = list()

self._logger = None

Expand All @@ -62,7 +63,8 @@ def __init__(self, mdp_info, policy, is_episodic=False, backend='numpy'):
_info='mushroom',
_agent_backend='primitive',
_env_backend='primitive',
_preprocessors='mushroom',
_core_preprocessors='mushroom',
_agent_preprocessors='mushroom',
_logger='none'
)

Expand All @@ -89,8 +91,10 @@ def draw_action(self, state, policy_state=None):
The action to be executed.
"""

if self.next_action is None:
state = self._convert_to_agent_backend(state)
state = self._agent_preprocess(state)
policy_state = self._convert_to_agent_backend(policy_state)
action, next_policy_state = self.policy.draw_action(state, policy_state)
else:
Expand All @@ -100,6 +104,34 @@ def draw_action(self, state, policy_state=None):

return self._convert_to_env_backend(action), self._convert_to_env_backend(next_policy_state)

def _agent_preprocess(self, state):
"""
Applies all the agent's preprocessors to the state.
Args:
state (Array): the state where the agent is;
Returns:
The preprocessed state.
"""
for p in self._agent_preprocessors:
state = p(state)
return state

def _update_agent_preprocessor(self, state):
"""
Updates the stats of all the agent's preprocessors given the state.
Args:
state (Array): the state where the agent is;
"""
for i, p in enumerate(self._agent_preprocessors, 1):
p.update(state)
if i < len(self._agent_preprocessors):
state = p(state)

def episode_start(self, initial_state, episode_info):
"""
Called by the Core when a new episode starts.
Expand Down Expand Up @@ -147,24 +179,35 @@ def set_logger(self, logger):
"""
self._logger = logger

def add_preprocessor(self, preprocessor):
def add_core_preprocessor(self, preprocessor):
"""
Add preprocessor to the core's preprocessor list. The preprocessors are applied in order.
Args:
preprocessor (object): state preprocessors to be applied
to state variables before feeding them to the agent.
"""
self._core_preprocessors.append(preprocessor)

def add_agent_preprocessor(self, preprocessor):
"""
Add preprocessor to the preprocessor list. The preprocessors are applied in order.
Add preprocessor to the agent's preprocessor list. The preprocessors are applied in order.
Args:
preprocessor (object): state preprocessors to be applied
to state variables before feeding them to the agent.
"""
self._preprocessors.append(preprocessor)
self._agent_preprocessors.append(preprocessor)

@property
def preprocessors(self):
def core_preprocessors(self):
"""
Access to state preprocessors stored in the agent.
Access to core's state preprocessors stored in the agent.
"""
return self._preprocessors
return self._core_preprocessors

def _convert_to_env_backend(self, array):
return self._env_backend.to_backend_array(self._agent_backend, array)
Expand Down
2 changes: 1 addition & 1 deletion mushroom_rl/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def _preprocess(self, state):
The preprocessed state.
"""
for p in self.agent.preprocessors:
for p in self.agent.core_preprocessors:
p.update(state)
state = p(state)

Expand Down
2 changes: 1 addition & 1 deletion mushroom_rl/core/vectorized_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def _preprocess(self, states):
The preprocessed states.
"""
for p in self.agent.preprocessors:
for p in self.agent.core_preprocessors:
p.update(states)
states = p(states)

Expand Down
6 changes: 3 additions & 3 deletions tests/utils/test_preprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_normalizing_preprocessor(tmpdir):
agent = DQN(mdp.info, pi, NumpyTorchApproximator, approximator_params=approximator_params, **alg_params)

norm_box = MinMaxPreprocessor(mdp_info=mdp.info, clip_obs=5.0, alpha=0.001)
agent.add_preprocessor(norm_box)
agent.add_core_preprocessor(norm_box)

core = Core(agent, mdp)

Expand All @@ -91,9 +91,9 @@ def test_normalizing_preprocessor(tmpdir):

agent_new = DQN.load(tmpdir / 'agent.msh')

assert len(agent_new.preprocessors) == 1
assert len(agent_new.core_preprocessors) == 1

norm_box_agent = agent_new.preprocessors[0]
norm_box_agent = agent_new.core_preprocessors[0]

state_dict2 = norm_box_new.__dict__
state_dict3 = norm_box_agent.__dict__
Expand Down

0 comments on commit f408f2d

Please sign in to comment.