From 33c87797ea627ab2825c34de3c657f1e3eaa0396 Mon Sep 17 00:00:00 2001
From: boris-il-forte <boris.ilpossente@hotmail.it>
Date: Sat, 12 Oct 2024 18:29:04 +0200
Subject: [PATCH] Added BODY_VEL_WORLD as observation type in Mujoco

- now we provide also the opportunity to get the body velocity info in world frame
---
 .../mujoco_envs/air_hockey/base.py            |  4 ++--
 .../environments/mujoco_envs/ball_in_a_cup.py |  2 +-
 .../utils/mujoco/observation_helper.py        | 19 +++++++++++--------
 3 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/mushroom_rl/environments/mujoco_envs/air_hockey/base.py b/mushroom_rl/environments/mujoco_envs/air_hockey/base.py
index 06de5564..020f81c5 100644
--- a/mushroom_rl/environments/mujoco_envs/air_hockey/base.py
+++ b/mushroom_rl/environments/mujoco_envs/air_hockey/base.py
@@ -57,7 +57,7 @@ def __init__(self, n_agents=1, env_noise=False, obs_noise=False, gamma=0.99, hor
                                  ("robot_1/joint_3_vel", "planar_robot_1/joint_3", ObservationType.JOINT_VEL)]
 
             additional_data += [("robot_1/ee_pos", "planar_robot_1/body_ee", ObservationType.BODY_POS),
-                                ("robot_1/ee_vel", "planar_robot_1/body_ee", ObservationType.BODY_VEL)]
+                                ("robot_1/ee_vel", "planar_robot_1/body_ee", ObservationType.BODY_VEL_WORLD)]
 
             collision_spec += [("robot_1/ee", ["planar_robot_1/ee"])]
 
@@ -76,7 +76,7 @@ def __init__(self, n_agents=1, env_noise=False, obs_noise=False, gamma=0.99, hor
                                      ("robot_2/joint_3_vel", "planar_robot_2/joint_3", ObservationType.JOINT_VEL)]
 
                 additional_data += [("robot_2/ee_pos", "planar_robot_2/body_ee", ObservationType.BODY_POS),
-                                    ("robot_2/ee_vel", "planar_robot_2/body_ee", ObservationType.BODY_VEL)]
+                                    ("robot_2/ee_vel", "planar_robot_2/body_ee", ObservationType.BODY_VEL_WORLD)]
 
                 collision_spec += [("robot_2/ee", ["planar_robot_2/ee"])]
         else:
diff --git a/mushroom_rl/environments/mujoco_envs/ball_in_a_cup.py b/mushroom_rl/environments/mujoco_envs/ball_in_a_cup.py
index c8cf6b46..a0ada7d7 100644
--- a/mushroom_rl/environments/mujoco_envs/ball_in_a_cup.py
+++ b/mushroom_rl/environments/mujoco_envs/ball_in_a_cup.py
@@ -36,7 +36,7 @@ def __init__(self):
                             ("palm_yaw_pos", "wam/palm_yaw_joint", ObservationType.JOINT_POS),
                             ("palm_yaw_vel", "wam/palm_yaw_joint", ObservationType.JOINT_VEL),
                             ("ball_pos", "ball", ObservationType.BODY_POS),
-                            ("ball_vel", "ball", ObservationType.BODY_VEL)]
+                            ("ball_vel", "ball", ObservationType.BODY_VEL_WORLD)]
 
         additional_data_spec = [("ball_pos", "ball", ObservationType.BODY_POS),
                                 ("goal_pos", "cup_goal_final", ObservationType.SITE_POS)]
diff --git a/mushroom_rl/utils/mujoco/observation_helper.py b/mushroom_rl/utils/mujoco/observation_helper.py
index cf1c3f88..77429e85 100644
--- a/mushroom_rl/utils/mujoco/observation_helper.py
+++ b/mushroom_rl/utils/mujoco/observation_helper.py
@@ -11,20 +11,22 @@ class ObservationType(Enum):
     The Observation have the following returns:
         BODY_POS: (3,) x, y, z position of the body
         BODY_ROT: (4,) quaternion of the body
-        BODY_VEL: (6,) first angular velocity around x, y, z. Then linear velocity for x, y, z
+        BODY_VEL: (6,) first angular velocity around x, y, z. Then linear velocity for x, y, z, in local frame
+        BODY_VEL_WORLD: (6,) first angular velocity around x, y, z. Then linear velocity for x, y, z, in world frame
         JOINT_POS: (1,) rotation of the joint OR (7,) position, quaternion of a free joint
         JOINT_VEL: (1,) velocity of the joint OR (6,) FIRST linear then angular velocity !different to BODY_VEL!
         SITE_POS: (3,) x, y, z position of the body
         SITE_ROT: (9,) rotation matrix of the site
     """
-    __order__ = "BODY_POS BODY_ROT BODY_VEL JOINT_POS JOINT_VEL SITE_POS SITE_ROT"
+    __order__ = "BODY_POS BODY_ROT BODY_VEL BODY_VEL_WORLD JOINT_POS JOINT_VEL SITE_POS SITE_ROT"
     BODY_POS = 0
     BODY_ROT = 1
     BODY_VEL = 2
-    JOINT_POS = 3
-    JOINT_VEL = 4
-    SITE_POS = 5
-    SITE_ROT = 6
+    BODY_VEL_WORLD = 3
+    JOINT_POS = 4
+    JOINT_VEL = 5
+    SITE_POS = 6
+    SITE_ROT = 7
 
 
 class ObservationHelper:
@@ -190,9 +192,10 @@ def get_state(self, model, data, name, o_type):
             obs = data.body(name).xpos
         elif o_type == ObservationType.BODY_ROT:
             obs = data.body(name).xquat
-        elif o_type == ObservationType.BODY_VEL:
+        elif o_type == ObservationType.BODY_VEL or o_type == ObservationType.BODY_VEL_WORLD:
+            local = o_type == ObservationType.BODY_VEL
             obs = np.empty(6)
-            mujoco.mj_objectVelocity(model, data, mujoco.mjtObj.mjOBJ_XBODY, data.body(name).id, obs, True)
+            mujoco.mj_objectVelocity(model, data, mujoco.mjtObj.mjOBJ_XBODY, data.body(name).id, obs, local)
         elif o_type == ObservationType.JOINT_POS:
             obs = data.joint(name).qpos
         elif o_type == ObservationType.JOINT_VEL: