Skip to content

Commit

Permalink
Improved serialization functions
Browse files Browse the repository at this point in the history
- now the array backend can get the preferred serialization method
- fixed issues in serialization of preprocessors
  • Loading branch information
boris-il-forte committed Jan 22, 2025
1 parent edd0bd5 commit 63579d7
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
16 changes: 16 additions & 0 deletions mushroom_rl/core/array_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ class ArrayBackend(object):
def get_backend_name():
raise NotImplementedError

@staticmethod
def get_backend_serialization():
raise NotImplementedError

@staticmethod
def get_array_backend(backend_name):
assert type(backend_name) == str, f"Backend has to be string, not {type(backend_name).__name__}."
Expand Down Expand Up @@ -174,6 +178,10 @@ class NumpyBackend(ArrayBackend):
def get_backend_name():
return 'numpy'

@staticmethod
def get_backend_serialization():
return 'numpy'

@staticmethod
def to_numpy(array):
return array
Expand Down Expand Up @@ -303,6 +311,10 @@ class TorchBackend(ArrayBackend):
def get_backend_name():
return 'torch'

@staticmethod
def get_backend_serialization():
return 'torch'

@staticmethod
def to_numpy(array):
return None if array is None else array.detach().cpu().numpy()
Expand Down Expand Up @@ -438,6 +450,10 @@ class ListBackend(ArrayBackend):
def get_backend_name():
return 'list'

@staticmethod
def get_backend_serialization():
return 'numpy'

@staticmethod
def to_numpy(array):
return np.array(array)
Expand Down
8 changes: 3 additions & 5 deletions mushroom_rl/rl_utils/preprocessors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import numpy as np

from mushroom_rl.core import Serializable, ArrayBackend
from mushroom_rl.rl_utils.running_stats import RunningStandardization

Expand Down Expand Up @@ -122,9 +120,9 @@ def __init__(self, mdp_info, backend, clip_obs=10., alpha=1e-32):
self._add_save_attr(
_array_backend='pickle',
_run_norm_obs='primitive',
_obs_mask='numpy',
_obs_mean='numpy',
_obs_delta='numpy'
_obs_mask=self._array_backend.get_backend_serialization(),
_obs_mean=self._array_backend.get_backend_serialization(),
_obs_delta=self._array_backend.get_backend_serialization()
)

def __call__(self, obs):
Expand Down

0 comments on commit 63579d7

Please sign in to comment.