diff --git a/mujoco_playground/_src/dm_control_suite/dm_control_suite_test.py b/mujoco_playground/_src/dm_control_suite/dm_control_suite_test.py index f8ec3ab..124838e 100644 --- a/mujoco_playground/_src/dm_control_suite/dm_control_suite_test.py +++ b/mujoco_playground/_src/dm_control_suite/dm_control_suite_test.py @@ -33,7 +33,7 @@ def test_can_create_all_environments(self, env_name: str) -> None: state = jax.jit(env.reset)(jax.random.PRNGKey(42)) state = jax.jit(env.step)(state, jp.zeros(env.action_size)) self.assertIsNotNone(state) - self.assertEqual(state.obs.shape, env.observation_size) + self.assertEqual(state.obs.shape[0], env.observation_size) self.assertFalse(jp.isnan(state.data.qpos).any()) diff --git a/mujoco_playground/_src/locomotion/locomotion_test.py b/mujoco_playground/_src/locomotion/locomotion_test.py index 971a849..6afd8f8 100644 --- a/mujoco_playground/_src/locomotion/locomotion_test.py +++ b/mujoco_playground/_src/locomotion/locomotion_test.py @@ -36,6 +36,7 @@ def test_can_create_all_environments(self, env_name: str) -> None: state = jax.jit(env.step)(state, jp.zeros(env.action_size)) self.assertIsNotNone(state) obs_shape = jax.tree_util.tree_map(lambda x: x.shape, state.obs) + obs_shape = obs_shape[0] if isinstance(obs_shape, tuple) else obs_shape self.assertEqual(obs_shape, env.observation_size) self.assertFalse(jp.isnan(state.data.qpos).any()) diff --git a/mujoco_playground/_src/manipulation/manipulation_test.py b/mujoco_playground/_src/manipulation/manipulation_test.py index 971aa19..0b52d2e 100644 --- a/mujoco_playground/_src/manipulation/manipulation_test.py +++ b/mujoco_playground/_src/manipulation/manipulation_test.py @@ -36,6 +36,7 @@ def test_can_create_all_environments(self, env_name: str) -> None: state = jax.jit(env.step)(state, jp.zeros(env.action_size)) self.assertIsNotNone(state) obs_shape = jax.tree_util.tree_map(lambda x: x.shape, state.obs) + obs_shape = obs_shape[0] if isinstance(obs_shape, tuple) else obs_shape self.assertEqual(obs_shape, env.observation_size) self.assertFalse(jp.isnan(state.data.qpos).any()) diff --git a/mujoco_playground/_src/mjx_env.py b/mujoco_playground/_src/mjx_env.py index 4871554..4e558df 100644 --- a/mujoco_playground/_src/mjx_env.py +++ b/mujoco_playground/_src/mjx_env.py @@ -270,8 +270,11 @@ def n_substeps(self) -> int: @property def observation_size(self) -> ObservationSize: - out = jax.eval_shape(self.reset, jax.random.PRNGKey(0)) - return jax.tree_util.tree_map(lambda x: x.shape, out.obs) + abstract_state = jax.eval_shape(self.reset, jax.random.PRNGKey(0)) + obs = abstract_state.obs + if isinstance(obs, Mapping): + return jax.tree_util.tree_map(lambda x: x.shape, obs) + return obs.shape[-1] def render( self,