-
Notifications
You must be signed in to change notification settings - Fork 151
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Parallelize value functions and J reward (#159)
* Optimize compute J, gae and montecarlo adv * Update test episode and value function * Fix dataset parse * Fix gae sign * Fix a2c test
- Loading branch information
1 parent
63579d7
commit 31596b8
Showing
9 changed files
with
286 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from mushroom_rl.core.array_backend import ArrayBackend | ||
|
||
def split_episodes(last, *arrays):
    """
    Split arrays from shape (n_steps, ...) to
    (n_episodes, max_episode_steps, ...).

    Args:
        last: per-step flags whose nonzero entries mark the final step of
            each episode;
        *arrays: the arrays to split, indexed by step along the first axis.

    Returns:
        The split array when a single array is given, otherwise a list with
        one split array per input. Each output is allocated with
        ``backend.zeros`` and only the positions belonging to real steps are
        written. If ``last`` flags at most one step, the inputs are returned
        unchanged (a single array, or the tuple of inputs when several are
        given).

    """
    # Nothing to split: checked before the backend lookup so this path needs
    # no backend at all, exactly like unsplit_episodes.
    if last.sum().item() <= 1:
        return arrays if len(arrays) > 1 else arrays[0]

    backend = ArrayBackend.get_array_backend_from(last)
    row_idx, colum_idx, n_episodes, max_episode_steps = _get_episode_idx(last, backend)

    episodes_arrays = []
    for array in arrays:
        # Only some array types expose a device; others get None.
        device = array.device if hasattr(array, 'device') else None
        array_ep = backend.zeros(
            n_episodes, max_episode_steps, *array.shape[1:],
            dtype=array.dtype, device=device
        )

        # Scatter every step to its (episode, step-in-episode) slot.
        array_ep[row_idx, colum_idx] = array
        episodes_arrays.append(array_ep)

    return episodes_arrays if len(episodes_arrays) > 1 else episodes_arrays[0]
|
||
def unsplit_episodes(last, *episodes_arrays):
    """
    Unsplit arrays from shape (n_episodes, max_episode_steps, ...) to
    (n_steps, ...).

    Args:
        last: per-step flags whose nonzero entries mark the final step of
            each episode;
        *episodes_arrays: the per-episode arrays to flatten back to one
            entry per step.

    Returns:
        The flattened array when a single array is given, otherwise a list
        with one flattened array per input. If ``last`` flags at most one
        step, the inputs are returned unchanged (a single array, or the
        tuple of inputs when several are given).

    """
    # With at most one flagged step the arrays are passed through untouched.
    if last.sum().item() <= 1:
        return episodes_arrays if len(episodes_arrays) > 1 else episodes_arrays[0]

    row_idx, colum_idx, _, _ = _get_episode_idx(last)

    # Gather each step back from its (episode, step-in-episode) slot.
    arrays = [episode_array[row_idx, colum_idx] for episode_array in episodes_arrays]

    return arrays if len(arrays) > 1 else arrays[0]
|
||
def _get_episode_idx(last, backend=None):
    """
    Compute, for every step, the episode it belongs to and its position
    inside that episode.

    Args:
        last: per-step flags whose nonzero entries mark the final step of
            each episode;
        backend: the array backend to use; when None it is looked up from
            ``last``.

    Returns:
        A tuple ``(row_idx, colum_idx, n_episodes, max_episode_steps)``:
        the per-step episode index, the per-step offset from the start of
        its episode, the number of flagged steps (``last.sum()``) and the
        length of the longest episode.

    """
    if backend is None:
        backend = ArrayBackend.get_array_backend_from(last)

    def on_last_device(x):
        # NumPy >= 2.0 arrays expose ``device`` but have no ``.to`` method,
        # so require the method as well before moving anything.
        if hasattr(last, 'device') and hasattr(x, 'to'):
            return x.to(last.device)
        return x

    n_episodes = last.sum()
    # NOTE(review): callers only reach this with at least two flagged steps;
    # with a single one ``squeeze`` would presumably collapse the index to
    # 0-d and the indexing below would fail -- confirm before reusing.
    last_idx = backend.nonzero(last).squeeze()

    # Episode lengths: the first episode spans [0, last_idx[0]], every other
    # one spans the gap between consecutive flagged steps.
    first_steps = on_last_device(backend.from_list([last_idx[0] + 1]))
    episode_steps = backend.concatenate([first_steps, last_idx[1:] - last_idx[:-1]])
    max_episode_steps = episode_steps.max()

    # First step of each episode: 0, then one past each previous flagged step.
    zero = backend.zeros(1, dtype=int, device=last.device if hasattr(last, 'device') else None)
    start_idx = backend.concatenate([zero, last_idx[:-1] + 1])

    range_n_episodes = on_last_device(backend.arange(0, n_episodes, dtype=int))
    range_len = on_last_device(backend.arange(0, last.shape[0], dtype=int))

    # NOTE(review): assumes the very last step is flagged in ``last``;
    # otherwise ``row_idx`` would presumably be shorter than ``range_len``
    # -- confirm against callers.
    row_idx = backend.repeat(range_n_episodes, episode_steps)
    colum_idx = range_len - start_idx[row_idx]

    return row_idx, colum_idx, n_episodes, max_episode_steps
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import torch | ||
from mushroom_rl.policy import DeterministicPolicy | ||
from mushroom_rl.environments.segway import Segway | ||
from mushroom_rl.core import Core, Agent | ||
from mushroom_rl.approximators import Regressor | ||
from mushroom_rl.approximators.parametric import LinearApproximator, TorchApproximator | ||
from mushroom_rl.rl_utils.value_functions import compute_gae, compute_advantage_montecarlo | ||
|
||
from mushroom_rl.utils.episodes import split_episodes, unsplit_episodes | ||
|
||
def test_compute_advantage_montecarlo():
    """
    Check compute_advantage_montecarlo against a step-by-step reference
    implementation, using a discount factor of 0.99.
    """
    def advantage_montecarlo(V, s, ss, r, absorbing, last, gamma):
        # Reference implementation: walk the steps backwards, accumulating
        # the discounted return one step at a time.
        with torch.no_grad():
            r = r.squeeze()
            n_steps = len(r)
            q = torch.zeros(n_steps)
            v = V(s).squeeze()

            for k in reversed(range(n_steps)):
                # At the end of an episode (or of the data) restart the
                # accumulation from the value of the next state.
                if last[k] or k == n_steps - 1:
                    q_next = V(ss[k]).squeeze().item()
                # Absorbing steps drop the bootstrap term.
                q_next = r[k] + gamma * q_next * (1 - absorbing[k].int())
                q[k] = q_next

            adv = q - v
            return q[:, None], adv[:, None]

    torch.manual_seed(42)
    _value_functions_tester(compute_advantage_montecarlo, advantage_montecarlo, 0.99)
|
||
def test_compute_gae():
    """
    Check compute_gae against a step-by-step reference implementation,
    using gamma = 0.99 and lambda = 0.95.
    """
    def gae(V, s, ss, r, absorbing, last, gamma, lam):
        # Reference implementation: walk the steps backwards, building the
        # generalized advantage recursively.
        with torch.no_grad():
            v = V(s)
            v_next = V(ss)
            gen_adv = torch.empty_like(v)
            n_steps = len(v)

            for k in reversed(range(n_steps)):
                if last[k] or k == n_steps - 1:
                    # End of an episode (or of the data): no term from the
                    # following step, and no bootstrap if absorbing.
                    gen_adv[k] = r[k] - v[k]
                    if not absorbing[k]:
                        gen_adv[k] += gamma * v_next[k]
                else:
                    gen_adv[k] = r[k] - v[k] + gamma * v_next[k] + gamma * lam * gen_adv[k + 1]

            return gen_adv + v, gen_adv

    torch.manual_seed(42)
    _value_functions_tester(compute_gae, gae, 0.99, 0.95)
|
||
def _value_functions_tester(test_fun, correct_fun, *args):
    """
    Compare a value-function routine with a reference implementation on
    Segway episodes, both before and after fitting the value regressor.

    Args:
        test_fun: the function under test;
        correct_fun: the reference implementation, called with the same
            arguments as ``test_fun``;
        *args: extra positional arguments forwarded to both functions
            after ``last``.

    """
    mdp = Segway()
    V = Regressor(
        TorchApproximator,
        input_shape=mdp.info.observation_space.shape,
        output_shape=(1,),
        network=Net,
        loss=torch.nn.MSELoss(),
        optimizer={'class': torch.optim.Adam, 'params': {'lr': 0.001}}
    )

    state, action, reward, next_state, absorbing, last = _get_episodes(mdp, 10)

    def check():
        # Evaluate both implementations with the current V and compare them.
        correct_v, correct_adv = correct_fun(V, state, next_state, reward, absorbing, last, *args)
        v, adv = test_fun(V, state, next_state, reward, absorbing, last, *args)

        assert torch.allclose(v, correct_v)
        assert torch.allclose(adv, correct_adv)

        return correct_v

    correct_v = check()

    # Change the regressor parameters, then check that the two still agree.
    V.fit(state, correct_v)

    check()
|
||
def _get_episodes(mdp, n_episodes=100):
    """
    Collect episodes on ``mdp`` with a deterministic linear policy with
    fixed weights.

    Args:
        mdp: the environment to run;
        n_episodes (int, 100): the number of episodes to evaluate.

    Returns:
        The collected dataset parsed to torch; the caller in this file
        unpacks it as (state, action, reward, next_state, absorbing, last).

    """
    # Fixed policy weights, so the collected episodes do not depend on a
    # randomly initialised policy.
    mu = torch.tensor([6.31154476, 3.32346271, 0.49648221]).unsqueeze(0)

    approximator = Regressor(LinearApproximator,
                             input_shape=mdp.info.observation_space.shape,
                             output_shape=mdp.info.action_space.shape,
                             weights=mu)

    policy = DeterministicPolicy(approximator)

    agent = Agent(mdp.info, policy)
    core = Core(agent, mdp)
    dataset = core.evaluate(n_episodes=n_episodes)

    return dataset.parse(to='torch')
|
||
class Net(torch.nn.Module):
    """
    Single linear layer network used as the value function approximator.
    """
    def __init__(self, input_shape, output_shape, **kwargs):
        """
        Constructor.

        Args:
            input_shape (tuple): shape of the input; only its first entry
                is used, as the number of input features;
            output_shape (tuple): shape of the output; only its first entry
                is used, as the number of output features;
            **kwargs: accepted and ignored.

        """
        super().__init__()
        self._q = torch.nn.Linear(input_shape[0], output_shape[0])

    def forward(self, x):
        # Cast to float so that non-float inputs can go through the layer.
        return self._q(x.float())
Oops, something went wrong.