diff --git a/mushroom_rl/core/array_backend.py b/mushroom_rl/core/array_backend.py index 706357c4..bd0a8c48 100644 --- a/mushroom_rl/core/array_backend.py +++ b/mushroom_rl/core/array_backend.py @@ -11,6 +11,10 @@ class ArrayBackend(object): def get_backend_name(): raise NotImplementedError + @staticmethod + def get_backend_serialization(): + raise NotImplementedError + @staticmethod def get_array_backend(backend_name): assert type(backend_name) == str, f"Backend has to be string, not {type(backend_name).__name__}." @@ -174,6 +178,10 @@ class NumpyBackend(ArrayBackend): def get_backend_name(): return 'numpy' + @staticmethod + def get_backend_serialization(): + return 'numpy' + @staticmethod def to_numpy(array): return array @@ -303,6 +311,10 @@ class TorchBackend(ArrayBackend): def get_backend_name(): return 'torch' + @staticmethod + def get_backend_serialization(): + return 'torch' + @staticmethod def to_numpy(array): return None if array is None else array.detach().cpu().numpy() @@ -438,6 +450,10 @@ class ListBackend(ArrayBackend): def get_backend_name(): return 'list' + @staticmethod + def get_backend_serialization(): + return 'numpy' + @staticmethod def to_numpy(array): return np.array(array) diff --git a/mushroom_rl/rl_utils/preprocessors.py b/mushroom_rl/rl_utils/preprocessors.py index 921b2e41..30721080 100644 --- a/mushroom_rl/rl_utils/preprocessors.py +++ b/mushroom_rl/rl_utils/preprocessors.py @@ -1,5 +1,3 @@ -import numpy as np - from mushroom_rl.core import Serializable, ArrayBackend from mushroom_rl.rl_utils.running_stats import RunningStandardization @@ -122,9 +120,9 @@ def __init__(self, mdp_info, backend, clip_obs=10., alpha=1e-32): self._add_save_attr( _array_backend='pickle', _run_norm_obs='primitive', - _obs_mask='numpy', - _obs_mean='numpy', - _obs_delta='numpy' + _obs_mask=self._array_backend.get_backend_serialization(), + _obs_mean=self._array_backend.get_backend_serialization(), + _obs_delta=self._array_backend.get_backend_serialization() ) def __call__(self, obs):