Parallelize value functions and J reward (#159)
* Optimize compute_J, GAE, and Monte Carlo advantage

* Update episode and value function tests

* Fix dataset parsing

* Fix GAE sign

* Fix A2C test
paolo-magliano authored Jan 29, 2025
1 parent 63579d7 commit 31596b8
Showing 9 changed files with 286 additions and 41 deletions.
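The common thread in these changes is to stop iterating over individual transitions in Python: flat per-step arrays are split, using the last flags, into a zero-padded (n_episodes, max_episode_steps) layout, and the per-episode recursions then advance one time step at a time across all episodes at once. A minimal sketch of that padding idea in plain PyTorch (the episode data and variable names below are illustrative, not MushroomRL code):

import torch

# Illustrative flat batch: two episodes of lengths 3 and 2, with episode ends marked in last.
reward = torch.tensor([1., 1., 1., 2., 2.])
last = torch.tensor([False, False, True, False, True])

ends = torch.nonzero(last).squeeze(-1)                      # step indices where episodes end
lengths = torch.cat([ends[:1] + 1, ends[1:] - ends[:-1]])   # per-episode lengths: [3, 2]
starts = torch.cat([torch.zeros(1, dtype=torch.long), ends[:-1] + 1])

n_episodes, max_len = len(ends), int(lengths.max())
row = torch.repeat_interleave(torch.arange(n_episodes), lengths)   # episode index of each step
col = torch.arange(len(reward)) - starts[row]                      # position of each step inside its episode

r_ep = torch.zeros(n_episodes, max_len)
r_ep[row, col] = reward
print(r_ep)   # tensor([[1., 1., 1.], [2., 2., 0.]]), zero-padded episodes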
4 changes: 2 additions & 2 deletions mushroom_rl/algorithms/actor_critic/deep_actor_critic/a2c.py
@@ -57,10 +57,10 @@ def __init__(self, mdp_info, policy, actor_optimizer, critic_params,
         )
 
     def fit(self, dataset):
-        state, action, reward, next_state, absorbing, _ = dataset.parse(to='torch')
+        state, action, reward, next_state, absorbing, last = dataset.parse(to='torch')
 
         v, adv = compute_advantage_montecarlo(self._V, state, next_state,
-                                              reward, absorbing,
+                                              reward, absorbing, last,
                                               self.mdp_info.gamma)
         self._V.fit(state, v, **self._critic_fit_params)

32 changes: 30 additions & 2 deletions mushroom_rl/core/array_backend.py
@@ -171,7 +171,14 @@ def shape(array):
     @staticmethod
     def full(shape, value):
         raise NotImplementedError
-
+
+    @staticmethod
+    def nonzero(array):
+        raise NotImplementedError
+
+    @staticmethod
+    def repeat(array, repeats):
+        raise NotImplementedError
 
 class NumpyBackend(ArrayBackend):
     @staticmethod
@@ -188,7 +195,12 @@ def to_numpy(array):

     @staticmethod
     def to_torch(array):
-        return None if array is None else torch.from_numpy(array).to(TorchUtils.get_device())
+        if array is None:
+            return None
+        else:
+            if array.dtype == np.float64:
+                array = array.astype(np.float32)
+            return torch.from_numpy(array).to(TorchUtils.get_device())
 
     @staticmethod
     def convert_to_backend(cls, array):
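The reworked to_torch also downcasts double precision: NumPy creates float64 arrays by default, while torch modules keep float32 parameters, so converting without a cast produces dtype mismatches downstream. A quick illustration in plain NumPy/PyTorch (not the backend code itself):

import numpy as np
import torch

x = np.zeros(3)                  # NumPy defaults to float64
layer = torch.nn.Linear(3, 1)    # torch parameters default to float32

# layer(torch.from_numpy(x)) would fail with a Double/Float dtype mismatch;
# casting to float32 first, as the backend now does, avoids it.
t = torch.from_numpy(x.astype(np.float32))
print(layer(t).dtype)            # torch.float32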
@@ -303,6 +315,14 @@ def shape(array):
     @staticmethod
     def full(shape, value):
         return np.full(shape, value)
+
+    @staticmethod
+    def nonzero(array):
+        return np.flatnonzero(array)
+
+    @staticmethod
+    def repeat(array, repeats):
+        return np.repeat(array, repeats)
 
 
 class TorchBackend(ArrayBackend):
@@ -443,6 +463,14 @@ def shape(array):
     @staticmethod
     def full(shape, value):
         return torch.full(shape, value)
+
+    @staticmethod
+    def nonzero(array):
+        return torch.nonzero(array)
+
+    @staticmethod
+    def repeat(array, repeats):
+        return torch.repeat_interleave(array, repeats)
 
 class ListBackend(ArrayBackend):
 
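The two new backend hooks are thin wrappers used by the episode-splitting utilities below; roughly, they behave as follows (plain NumPy/PyTorch with illustrative values):

import numpy as np
import torch

mask = np.array([0, 1, 0, 1, 1])
print(np.flatnonzero(mask))                  # [1 3 4]        (NumpyBackend.nonzero)
print(np.repeat(np.arange(3), [2, 1, 3]))    # [0 0 1 2 2 2]  (NumpyBackend.repeat)

t = torch.tensor([0, 1, 0, 1, 1])
print(torch.nonzero(t))                      # shape (3, 1)   (TorchBackend.nonzero)
print(torch.repeat_interleave(torch.arange(3), torch.tensor([2, 1, 3])))   # tensor([0, 0, 1, 2, 2, 2])

Note that torch.nonzero returns an (N, 1) tensor rather than a flat one, which is why the caller in mushroom_rl/utils/episodes.py squeezes the result.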
30 changes: 14 additions & 16 deletions mushroom_rl/core/dataset.py
@@ -11,6 +11,7 @@

 from ._impl import *
 
+from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes
 
 class DatasetInfo(Serializable):
     def __init__(self, backend, device, horizon, gamma, state_shape, state_dtype, action_shape, action_dtype,
@@ -473,22 +474,19 @@ def compute_J(self, gamma=1.):
             The cumulative discounted reward of each episode in the dataset.
         """
-        js = list()
-
-        j = 0.
-        episode_steps = 0
-        for i in range(len(self)):
-            j += gamma ** episode_steps * self.reward[i]
-            episode_steps += 1
-            if self.last[i] or i == len(self) - 1:
-                js.append(j)
-                j = 0.
-                episode_steps = 0
-
-        if len(js) == 0:
-            js = [0.]
-
-        return self._array_backend.from_list(js)
+        r_ep = split_episodes(self.last, self.reward)
+
+        if len(r_ep.shape) == 1:
+            r_ep = r_ep.unsqueeze(0)
+        if hasattr(r_ep, 'device'):
+            js = self._array_backend.zeros(r_ep.shape[0], dtype=r_ep.dtype, device=r_ep.device)
+        else:
+            js = self._array_backend.zeros(r_ep.shape[0], dtype=r_ep.dtype)
+
+        for k in range(r_ep.shape[1]):
+            js += gamma ** k * r_ep[..., k]
+
+        return js
 
     def compute_metrics(self, gamma=1.):
         """
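With the rewards split into a padded matrix, the new compute_J accumulates one discounted column per time step instead of walking every transition. The same computation in standalone form (plain torch with toy values, not the Dataset method itself):

import torch

# Padded per-episode rewards, shape (n_episodes, max_episode_steps); zeros pad the
# shorter episode and contribute nothing to the sum.
r_ep = torch.tensor([[1.0, 1.0, 1.0],
                     [2.0, 2.0, 0.0]])
gamma = 0.99

js = torch.zeros(r_ep.shape[0])
for k in range(r_ep.shape[1]):       # loop over time steps, vectorized over episodes
    js += gamma ** k * r_ep[..., k]

print(js)                            # tensor([2.9701, 3.9800])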
40 changes: 23 additions & 17 deletions mushroom_rl/rl_utils/value_functions.py
@@ -1,7 +1,7 @@
 import torch
+from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes
 
-
-def compute_advantage_montecarlo(V, s, ss, r, absorbing, gamma):
+def compute_advantage_montecarlo(V, s, ss, r, absorbing, last, gamma):
     """
     Function to estimate the advantage and new value function target
     over a dataset. The value function is estimated using rollouts
@@ -24,18 +24,21 @@ def compute_advantage_montecarlo(V, s, ss, r, absorbing, gamma):
"""
with torch.no_grad():
r = r.squeeze()
q = torch.zeros(len(r))
v = V(s).squeeze()

q_next = V(ss[-1]).squeeze().item()
for rev_k in range(len(r)):
k = len(r) - rev_k - 1
q_next = r[k] + gamma * q_next * (1 - absorbing[k].int())
q[k] = q_next
r_ep, absorbing_ep, ss_ep = split_episodes(last, r, absorbing, ss)
q_ep = torch.zeros_like(r_ep, dtype=torch.float32)
q_next_ep = V(ss_ep[..., -1, :]).squeeze()

for rev_k in range(r_ep.shape[-1]):
k = r_ep.shape[-1] - rev_k - 1
q_next_ep = r_ep[..., k] + gamma * q_next_ep * (1 - absorbing_ep[..., k].int())
q_ep[..., k] = q_next_ep

q = unsplit_episodes(last, q_ep)
adv = q - v
return q[:, None], adv[:, None]

return q[:, None], adv[:, None]

def compute_advantage(V, s, ss, r, absorbing, gamma):
"""
@@ -97,13 +100,16 @@ def compute_gae(V, s, ss, r, absorbing, last, gamma, lam):
     with torch.no_grad():
         v = V(s)
         v_next = V(ss)
-        gen_adv = torch.empty_like(v)
-        for rev_k in range(len(v)):
-            k = len(v) - rev_k - 1
-            if last[k] or rev_k == 0:
-                gen_adv[k] = r[k] - v[k]
-                if not absorbing[k]:
-                    gen_adv[k] += gamma * v_next[k]
+
+        v_ep, v_next_ep, r_ep, absorbing_ep = split_episodes(last, v.squeeze(), v_next.squeeze(), r, absorbing)
+        gen_adv_ep = torch.zeros_like(v_ep)
+        for rev_k in range(v_ep.shape[-1]):
+            k = v_ep.shape[-1] - rev_k - 1
+            if rev_k == 0:
+                gen_adv_ep[..., k] = r_ep[..., k] - v_ep[..., k] + (1 - absorbing_ep[..., k].int()) * gamma * v_next_ep[..., k]
             else:
-                gen_adv[k] = r[k] + gamma * v_next[k] - v[k] + gamma * lam * gen_adv[k + 1]
+                gen_adv_ep[..., k] = r_ep[..., k] - v_ep[..., k] + (1 - absorbing_ep[..., k].int()) * gamma * v_next_ep[..., k] + gamma * lam * gen_adv_ep[..., k + 1]
+
+        gen_adv = unsplit_episodes(last, gen_adv_ep).unsqueeze(-1)
 
         return gen_adv + v, gen_adv
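Both rewritten functions keep the original backward recursions but run them over all episodes at once, so the Python loop is over time steps rather than over every transition. A standalone sketch of the batched GAE recursion on zero-padded tensors (plain torch; shapes and values are illustrative and this is not the library function itself):

import torch

def batched_gae(r_ep, v_ep, v_next_ep, absorbing_ep, gamma, lam):
    """GAE over padded episodes; every input has shape (n_episodes, T)."""
    gen_adv_ep = torch.zeros_like(v_ep)
    T = v_ep.shape[-1]
    for rev_k in range(T):
        k = T - rev_k - 1
        # TD error for step k of every episode; absorbing steps drop the bootstrap term.
        delta = r_ep[..., k] - v_ep[..., k] + (1 - absorbing_ep[..., k].int()) * gamma * v_next_ep[..., k]
        gen_adv_ep[..., k] = delta if rev_k == 0 else delta + gamma * lam * gen_adv_ep[..., k + 1]
    return gen_adv_ep

# Toy check: two episodes of three steps, zero value estimates, unit rewards.
r = torch.ones(2, 3)
v = torch.zeros(2, 3)
absorbing = torch.zeros(2, 3, dtype=torch.bool)
print(batched_gae(r, v, v, absorbing, gamma=0.99, lam=0.95))

Because split_episodes pads rewards, values, and absorbing flags with zeros, the padded trailing steps of a short episode yield zero advantages and do not leak into the real steps that precede them in the same row.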
61 changes: 61 additions & 0 deletions mushroom_rl/utils/episodes.py
@@ -0,0 +1,61 @@
from mushroom_rl.core.array_backend import ArrayBackend

def split_episodes(last, *arrays):
    """
    Split arrays from the flat shape (n_steps,) to (n_episodes, max_episode_steps).
    """
    backend = ArrayBackend.get_array_backend_from(last)

    if last.sum().item() <= 1:
        return arrays if len(arrays) > 1 else arrays[0]

    row_idx, colum_idx, n_episodes, max_episode_steps = _get_episode_idx(last, backend)
    episodes_arrays = []

    for array in arrays:
        array_ep = backend.zeros(n_episodes, max_episode_steps, *array.shape[1:], dtype=array.dtype, device=array.device if hasattr(array, 'device') else None)

        array_ep[row_idx, colum_idx] = array
        episodes_arrays.append(array_ep)

    return episodes_arrays if len(episodes_arrays) > 1 else episodes_arrays[0]

def unsplit_episodes(last, *episodes_arrays):
    """
    Unsplit arrays from shape (n_episodes, max_episode_steps) back to the flat (n_steps,).
    """

    if last.sum().item() <= 1:
        return episodes_arrays if len(episodes_arrays) > 1 else episodes_arrays[0]

    row_idx, colum_idx, _, _ = _get_episode_idx(last)
    arrays = []

    for episode_array in episodes_arrays:
        array = episode_array[row_idx, colum_idx]
        arrays.append(array)

    return arrays if len(arrays) > 1 else arrays[0]

def _get_episode_idx(last, backend=None):
    if backend is None:
        backend = ArrayBackend.get_array_backend_from(last)

    n_episodes = last.sum()
    last_idx = backend.nonzero(last).squeeze()
    first_steps = backend.from_list([last_idx[0] + 1])
    if hasattr(last, 'device'):
        first_steps = first_steps.to(last.device)
    episode_steps = backend.concatenate([first_steps, last_idx[1:] - last_idx[:-1]])
    max_episode_steps = episode_steps.max()

    start_idx = backend.concatenate([backend.zeros(1, dtype=int, device=last.device if hasattr(last, 'device') else None), last_idx[:-1] + 1])
    range_n_episodes = backend.arange(0, n_episodes, dtype=int)
    range_len = backend.arange(0, last.shape[0], dtype=int)
    if hasattr(last, 'device'):
        range_n_episodes = range_n_episodes.to(last.device)
        range_len = range_len.to(last.device)
    row_idx = backend.repeat(range_n_episodes, episode_steps)
    colum_idx = range_len - start_idx[row_idx]

    return row_idx, colum_idx, n_episodes, max_episode_steps
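A small usage sketch of the new helpers, assuming torch inputs (values are illustrative; split_episodes zero-pads, and unsplit_episodes should recover the original flat layout):

import torch
from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes

# Flat step data for two episodes of lengths 3 and 2.
last = torch.tensor([False, False, True, False, True])
reward = torch.tensor([1., 1., 1., 2., 2.])

r_ep = split_episodes(last, reward)    # shape (2, 3), zero-padded: [[1, 1, 1], [2, 2, 0]]
flat = unsplit_episodes(last, r_ep)    # back to shape (5,)

assert torch.equal(flat, reward)

Passing several arrays at once returns them split in the same order, which is how the value-function code splits rewards, absorbing flags, and next states together.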
4 changes: 3 additions & 1 deletion tests/algorithms/test_a2c.py
@@ -75,7 +75,7 @@ def test_a2c():
     agent = learn_a2c()
 
     w = agent.policy.get_weights()
-    w_test = np.array([0.9382279 , -1.8847059 , -0.13790752, -0.00786441])
+    w_test = np.array([ 0.9389272 , -1.8838323 , -0.13710725, -0.00668973])
 
     assert np.allclose(w, w_test)
 
@@ -95,3 +95,5 @@ def test_a2c_save(tmpdir):
         print(save_attr, load_attr)
 
         tu.assert_eq(save_attr, load_attr)
+
+test_a2c()
4 changes: 1 addition & 3 deletions tests/core/test_dataset.py
@@ -128,6 +128,4 @@ def test_dataset_loading(tmpdir):

     assert len(dataset.info) == len(new_dataset.info)
     for key in dataset.info:
-        assert np.array_equal(dataset.info[key], new_dataset.info[key])
-
-
+        assert np.array_equal(dataset.info[key], new_dataset.info[key])
92 changes: 92 additions & 0 deletions tests/rl_utils/test_value_functions.py
@@ -0,0 +1,92 @@
import torch
from mushroom_rl.policy import DeterministicPolicy
from mushroom_rl.environments.segway import Segway
from mushroom_rl.core import Core, Agent
from mushroom_rl.approximators import Regressor
from mushroom_rl.approximators.parametric import LinearApproximator, TorchApproximator
from mushroom_rl.rl_utils.value_functions import compute_gae, compute_advantage_montecarlo

from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes

def test_compute_advantage_montecarlo():
    def advantage_montecarlo(V, s, ss, r, absorbing, last, gamma):
        with torch.no_grad():
            r = r.squeeze()
            q = torch.zeros(len(r))
            v = V(s).squeeze()

            for rev_k in range(len(r)):
                k = len(r) - rev_k - 1
                if last[k] or rev_k == 0:
                    q_next = V(ss[k]).squeeze().item()
                q_next = r[k] + gamma * q_next * (1 - absorbing[k].int())
                q[k] = q_next

            adv = q - v
            return q[:, None], adv[:, None]

    torch.manual_seed(42)
    _value_functions_tester(compute_advantage_montecarlo, advantage_montecarlo, 0.99)

def test_compute_gae():
    def gae(V, s, ss, r, absorbing, last, gamma, lam):
        with torch.no_grad():
            v = V(s)
            v_next = V(ss)
            gen_adv = torch.empty_like(v)
            for rev_k in range(len(v)):
                k = len(v) - rev_k - 1
                if last[k] or rev_k == 0:
                    gen_adv[k] = r[k] - v[k]
                    if not absorbing[k]:
                        gen_adv[k] += gamma * v_next[k]
                else:
                    gen_adv[k] = r[k] - v[k] + gamma * v_next[k] + gamma * lam * gen_adv[k + 1]
            return gen_adv + v, gen_adv

    torch.manual_seed(42)
    _value_functions_tester(compute_gae, gae, 0.99, 0.95)

def _value_functions_tester(test_fun, correct_fun, *args):
    mdp = Segway()
    V = Regressor(TorchApproximator, input_shape=mdp.info.observation_space.shape, output_shape=(1,), network=Net, loss=torch.nn.MSELoss(), optimizer={'class': torch.optim.Adam, 'params': {'lr': 0.001}})

    state, action, reward, next_state, absorbing, last = _get_episodes(mdp, 10)

    correct_v, correct_adv = correct_fun(V, state, next_state, reward, absorbing, last, *args)
    v, adv = test_fun(V, state, next_state, reward, absorbing, last, *args)

    assert torch.allclose(v, correct_v)
    assert torch.allclose(adv, correct_adv)

    V.fit(state, correct_v)

    correct_v, correct_adv = correct_fun(V, state, next_state, reward, absorbing, last, *args)
    v, adv = test_fun(V, state, next_state, reward, absorbing, last, *args)

    assert torch.allclose(v, correct_v)
    assert torch.allclose(adv, correct_adv)

def _get_episodes(mdp, n_episodes=100):
    mu = torch.tensor([6.31154476, 3.32346271, 0.49648221]).unsqueeze(0)

    approximator = Regressor(LinearApproximator,
                             input_shape=mdp.info.observation_space.shape,
                             output_shape=mdp.info.action_space.shape,
                             weights=mu)

    policy = DeterministicPolicy(approximator)

    agent = Agent(mdp.info, policy)
    core = Core(agent, mdp)
    dataset = core.evaluate(n_episodes=n_episodes)

    return dataset.parse(to='torch')

class Net(torch.nn.Module):
    def __init__(self, input_shape, output_shape, **kwargs):
        super().__init__()
        self._q = torch.nn.Linear(input_shape[0], output_shape[0])

    def forward(self, x):
        return self._q(x.float())