Skip to content

Commit

Permalink
Revert back to old observation_size API.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 721080636
Change-Id: I09f74527dab8281fea0aa349936130055e13756b
  • Loading branch information
btaba authored and copybara-github committed Jan 29, 2025
1 parent 4424fa1 commit 0f3adda
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_can_create_all_environments(self, env_name: str) -> None:
state = jax.jit(env.reset)(jax.random.PRNGKey(42))
state = jax.jit(env.step)(state, jp.zeros(env.action_size))
self.assertIsNotNone(state)
self.assertEqual(state.obs.shape, env.observation_size)
self.assertEqual(state.obs.shape[0], env.observation_size)
self.assertFalse(jp.isnan(state.data.qpos).any())


Expand Down
1 change: 1 addition & 0 deletions mujoco_playground/_src/locomotion/locomotion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def test_can_create_all_environments(self, env_name: str) -> None:
state = jax.jit(env.step)(state, jp.zeros(env.action_size))
self.assertIsNotNone(state)
obs_shape = jax.tree_util.tree_map(lambda x: x.shape, state.obs)
obs_shape = obs_shape[0] if isinstance(obs_shape, tuple) else obs_shape
self.assertEqual(obs_shape, env.observation_size)
self.assertFalse(jp.isnan(state.data.qpos).any())

Expand Down
1 change: 1 addition & 0 deletions mujoco_playground/_src/manipulation/manipulation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def test_can_create_all_environments(self, env_name: str) -> None:
state = jax.jit(env.step)(state, jp.zeros(env.action_size))
self.assertIsNotNone(state)
obs_shape = jax.tree_util.tree_map(lambda x: x.shape, state.obs)
obs_shape = obs_shape[0] if isinstance(obs_shape, tuple) else obs_shape
self.assertEqual(obs_shape, env.observation_size)
self.assertFalse(jp.isnan(state.data.qpos).any())

Expand Down
7 changes: 5 additions & 2 deletions mujoco_playground/_src/mjx_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,11 @@ def n_substeps(self) -> int:

@property
def observation_size(self) -> ObservationSize:
out = jax.eval_shape(self.reset, jax.random.PRNGKey(0))
return jax.tree_util.tree_map(lambda x: x.shape, out.obs)
abstract_state = jax.eval_shape(self.reset, jax.random.PRNGKey(0))
obs = abstract_state.obs
if isinstance(obs, Mapping):
return jax.tree_util.tree_map(lambda x: x.shape, obs)
return obs.shape[-1]

def render(
self,
Expand Down

0 comments on commit 0f3adda

Please sign in to comment.