Skip to content

Commit

Permalink
Fix dataset.compute_J for numpy 2.x (#162)
Browse files Browse the repository at this point in the history
* bug fix for numpy backend split_episodes

* fixed bug in compute_j with numpy 2.x
  • Loading branch information
cube1324 authored Feb 10, 2025
1 parent df23975 commit 3295221
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mushroom_rl/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def compute_J(self, gamma=1.):

if len(r_ep.shape) == 1:
r_ep = r_ep.unsqueeze(0)
if hasattr(r_ep, 'device'):
if self._dataset_info.backend == 'torch':
js = self._array_backend.zeros(r_ep.shape[0], dtype=r_ep.dtype, device=r_ep.device)
else:
js = self._array_backend.zeros(r_ep.shape[0], dtype=r_ep.dtype)
Expand Down
2 changes: 1 addition & 1 deletion mushroom_rl/utils/episodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def split_episodes(last, *arrays):
episodes_arrays = []

for array in arrays:
array_ep = backend.zeros(n_episodes, max_episode_steps, *array.shape[1:], dtype=array.dtype, device=array.device if hasattr(array, 'device') else None)
array_ep = backend.zeros(n_episodes, max_episode_steps, *array.shape[1:], dtype=array.dtype, device=array.device if backend.get_backend_name() == "torch" else None)

array_ep[row_idx, colum_idx] = array
episodes_arrays.append(array_ep)
Expand Down

0 comments on commit 3295221

Please sign in to comment.