Skip to content

Commit

Permalink
[wip,broken] Move to lsy_models
Browse files Browse the repository at this point in the history
  • Loading branch information
amacati committed Feb 18, 2025
1 parent d38bb5c commit 82699ce
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 18 deletions.
25 changes: 17 additions & 8 deletions crazyflow/sim/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
from jax import Array, Device
from jax.scipy.spatial.transform import Rotation as R
from lsy_models.models_numeric import f_first_principles
from mujoco.mjx import Data, Model

from crazyflow.constants import J_INV, MASS, SIGN_MIX_MATRIX, J
Expand Down Expand Up @@ -128,8 +129,8 @@ def build_step_fn(self):
# None is required by jax.lax.scan to unpack the tuple returned by single_step.
def single_step(data: SimData, _: None) -> tuple[SimData, None]:
data = ctrl_fn(data)
data = wrench_fn(data)
data = disturbance_fn(data)
data = wrench_fn(data)
data = physics_fn(data)
data = data.replace(core=data.core.replace(steps=data.core.steps + 1))
# MuJoCo needs to sync after every physics step, so that the next step control, wrench
Expand Down Expand Up @@ -570,13 +571,21 @@ def analytical_wrench(data: SimData) -> SimData:

def analytical_derivative(data: SimData) -> SimData:
"""Compute the derivative of the states."""
quat, mass, J_inv = data.states.quat, data.params.mass, data.params.J_INV
acc = collective_force2acceleration(data.states.force, mass)
ang_vel_deriv = collective_torque2ang_vel_deriv(data.states.torque, quat, J_inv)
vel, ang_vel = (data.states.vel, data.states.ang_vel) # Already given in the states
deriv = data.states_deriv
deriv = deriv.replace(dpos=vel, drot=ang_vel, dvel=acc, dang_vel=ang_vel_deriv)
return data.replace(states_deriv=deriv)
dpos, _, dvel, dang_vel, df_motor = f_first_principles(
data.states.pos,
data.states.quat,
data.states.vel,
data.states.ang_vel,
data.controls.thrust,
data.params,
data.states.motor_forces,
data.states.force,
data.states.torque,
)
states_deriv = data.states_deriv.replace(
dpos=dpos, drot=data.states.ang_vel, dvel=dvel, dang_vel=dang_vel, dmotor_forces=df_motor
)
return data.replace(states_deriv=states_deriv)


def identified_wrench(data: SimData) -> SimData:
Expand Down
19 changes: 10 additions & 9 deletions crazyflow/sim/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,12 @@ class SimState:
"""Velocity of the drone's center of mass in the world frame."""
ang_vel: Array # (N, M, 3)
"""Angular velocity of the drone's center of mass in the world frame."""
motor_forces: Array # (N, M, 4) # Motor forces along body frame z axis
"""Motor forces along body frame z axis."""
force: Array # (N, M, 3) # CoM force
"""Force applied to the drone's center of mass in the world frame."""
torque: Array # (N, M, 3) # CoM torque
"""Torque applied to the drone's center of mass in the world frame."""
motor_forces: Array # (N, M, 4) # Motor forces along body frame z axis
"""Motor forces along body frame z axis."""
motor_torques: Array # (N, M, 4) # Motor torques around the body frame z axis
"""Motor torques around the body frame z axis."""

@staticmethod
def create(n_worlds: int, n_drones: int, device: Device) -> SimState:
Expand All @@ -41,16 +39,14 @@ def create(n_worlds: int, n_drones: int, device: Device) -> SimState:
force = jnp.zeros((n_worlds, n_drones, 3), device=device)
torque = jnp.zeros((n_worlds, n_drones, 3), device=device)
motor_forces = jnp.zeros((n_worlds, n_drones, 4), device=device)
motor_torques = jnp.zeros((n_worlds, n_drones, 4), device=device)
return SimState(
pos=pos,
quat=quat,
vel=vel,
ang_vel=ang_vel,
force=force,
torque=torque,
motor_forces=motor_forces,
motor_torques=motor_torques,
vel=vel,
ang_vel=ang_vel,
)


Expand All @@ -64,6 +60,8 @@ class SimStateDeriv:
"""Derivative of the velocity of the drone's center of mass."""
dang_vel: Array # (N, M, 3)
"""Derivative of the angular velocity of the drone's center of mass."""
dmotor_forces: Array # (N, M, 4)
"""Derivative of the motor forces along body frame z axis."""

@staticmethod
def create(n_worlds: int, n_drones: int, device: Device) -> SimStateDeriv:
Expand All @@ -72,7 +70,10 @@ def create(n_worlds: int, n_drones: int, device: Device) -> SimStateDeriv:
drot = jnp.zeros((n_worlds, n_drones, 3), device=device)
dvel = jnp.zeros((n_worlds, n_drones, 3), device=device)
dang_vel = jnp.zeros((n_worlds, n_drones, 3), device=device)
return SimStateDeriv(dpos=dpos, drot=drot, dvel=dvel, dang_vel=dang_vel)
dmotor_forces = jnp.zeros((n_worlds, n_drones, 4), device=device)
return SimStateDeriv(
dpos=dpos, drot=drot, dvel=dvel, dang_vel=dang_vel, dmotor_forces=dmotor_forces
)


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion examples/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
def main():
"""Spawn 25 drones in one world and render each with a trace behind it."""
n_worlds, n_drones = 1, 25
sim = Sim(n_worlds=n_worlds, n_drones=n_drones, physics=Physics.sys_id, device="cpu")
sim = Sim(n_worlds=n_worlds, n_drones=n_drones, physics=Physics.analytical, device="cpu")
fps = 60
cmd = np.zeros((sim.n_worlds, sim.n_drones, 4))
cmd[..., 0] = MASS * GRAVITY * 1.2
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies = [
"ml_collections",
"casadi",
"numpy",
"lsy_models @ git+https://github.com/utiasDSL/models.git",
]

[project.optional-dependencies]
Expand Down

0 comments on commit 82699ce

Please sign in to comment.