Skip to content

Commit

Permalink
Merge branch 'attitude_interface' into notebooks
Browse files Browse the repository at this point in the history
merge attitude interface into notebooks
  • Loading branch information
Lui committed Jan 8, 2025
2 parents 2cb8c0a + 289dfc3 commit 3359bc4
Show file tree
Hide file tree
Showing 43 changed files with 2,417 additions and 2,770 deletions.
35 changes: 15 additions & 20 deletions benchmark/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ml_collections import config_dict

import crazyflow # noqa: F401, ensure gymnasium envs are registered
from crazyflow.sim.core import Sim
from crazyflow.sim import Sim


def analyze_timings(times: list[float], n_steps: int, n_worlds: int, freq: float) -> None:
Expand All @@ -19,8 +19,8 @@ def analyze_timings(times: list[float], n_steps: int, n_worlds: int, freq: float
tmax, idx_tmax = np.max(times), np.argmax(times)

# Check for significant variance
if tmax / tmin > 5:
print("Warning: step time varies by more than 5x. Is JIT compiling during the benchmark?")
if tmax / tmin > 10:
print("Warning: step time varies by more than 10x. Is JIT compiling during the benchmark?")
print(f"Times: max {tmax:.2e} @ {idx_tmax}, min {tmin:.2e} @ {idx_tmin}")

# Performance metrics
Expand All @@ -43,28 +43,23 @@ def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, devic
device = jax.devices(device)[0]

envs = gymnasium.make_vec(
"DroneReachPos-v0",
time_horizon_in_seconds=2,
return_datatype="numpy",
num_envs=sim_config.n_worlds,
**sim_config,
"DroneReachPos-v0", time_horizon_in_seconds=3, num_envs=sim_config.n_worlds, **sim_config
)

# Action for going up (in attitude control)
action = np.zeros((sim_config.n_worlds, 4), dtype=np.float32)
action[..., 0] = 0.3
# Step through env once to ensure JIT compilation
envs.reset_all(seed=42)
envs.step(action)
envs.reset(seed=42)
envs.step(action)

jax.block_until_ready(envs.unwrapped.sim.states.pos) # Ensure JIT compiled dynamics
jax.block_until_ready(envs.unwrapped.sim.data) # Ensure JIT compiled dynamics

# Step through the environment
for _ in range(n_steps):
tstart = time.perf_counter()
envs.step(action)
jax.block_until_ready(envs.unwrapped.sim.states.pos)
jax.block_until_ready(envs.unwrapped.sim.data)
times.append(time.perf_counter() - tstart)

envs.close()
Expand All @@ -83,14 +78,14 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):

sim.reset()
sim.attitude_control(cmd)
sim.step()
jax.block_until_ready(sim.states.pos) # Ensure JIT compiled dynamics
sim.step(sim.freq // sim.control_freq)
jax.block_until_ready(sim.data) # Ensure JIT compiled dynamics

for _ in range(n_steps):
tstart = time.perf_counter()
sim.attitude_control(cmd)
sim.step()
jax.block_until_ready(sim.states.pos)
sim.step(sim.freq // sim.control_freq)
jax.block_until_ready(sim.data)
times.append(time.perf_counter() - tstart)

analyze_timings(times, n_steps, sim.n_worlds, sim.freq)
Expand All @@ -102,16 +97,16 @@ def main():
sim_config = config_dict.ConfigDict()
sim_config.n_worlds = 1
sim_config.n_drones = 1
sim_config.physics = "sys_id"
sim_config.physics = "analytical"
sim_config.control = "attitude"
sim_config.controller = "emulatefirmware"
sim_config.attitude_freq = 500
sim_config.device = device

print("Simulator performance")
profile_step(sim_config, 100, device)
profile_step(sim_config, 1000, device)

print("\nGymnasium environment performance")
profile_gym_env_step(sim_config, 100, device)
profile_gym_env_step(sim_config, 1000, device)


if __name__ == "__main__":
Expand Down
24 changes: 9 additions & 15 deletions benchmark/performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pyinstrument.renderers.html import HTMLRenderer

import crazyflow # noqa: F401, ensure gymnasium envs are registered
from crazyflow.sim.core import Sim
from crazyflow.sim import Sim

if TYPE_CHECKING:
from crazyflow.gymnasium_envs import CrazyflowEnvReachGoal
Expand All @@ -26,9 +26,7 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
sim.reset()
control_fn(cmd)
sim.step()
sim.step()
sim.reset()
jax.block_until_ready(sim.states.pos)
jax.block_until_ready(sim.data)

profiler = Profiler()
profiler.start()
Expand All @@ -37,7 +35,7 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
control_fn(cmd)
# sim.reset()
sim.step()
jax.block_until_ready(sim.states.pos)
jax.block_until_ready(sim.data)
profiler.stop()
renderer = HTMLRenderer()
renderer.open_in_browser(profiler.last_session)
Expand All @@ -47,29 +45,26 @@ def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, devic
device = jax.devices(device)[0]

envs: CrazyflowEnvReachGoal = gymnasium.make_vec(
"DroneReachPos-v0",
time_horizon_in_seconds=2,
return_datatype="numpy",
num_envs=sim_config.n_worlds,
**sim_config,
"DroneReachPos-v0", time_horizon_in_seconds=2, num_envs=sim_config.n_worlds, **sim_config
)

# Action for going up (in attitude control)
action = np.zeros((sim_config.n_worlds, 4), dtype=np.float32)
action[..., 0] = -0.3
action[..., 0] = 0.3

# Step through env once to ensure JIT compilation.
envs.reset_all(seed=42)
envs.reset(seed=42)
envs.step(action)
envs.step(action) # Ensure all paths have been taken at least once
envs.reset_all(seed=42)
envs.reset(seed=42)
jax.block_until_ready(envs.unwrapped.sim.data)

profiler = Profiler()
profiler.start()

for _ in range(n_steps):
envs.step(action)
jax.block_until_ready(envs.unwrapped.sim.states.pos)
jax.block_until_ready(envs.unwrapped.sim.data)

profiler.stop()
renderer = HTMLRenderer()
Expand All @@ -84,7 +79,6 @@ def main():
sim_config.n_drones = 1
sim_config.physics = "analytical"
sim_config.control = "attitude"
sim_config.controller = "emulatefirmware"
sim_config.device = device

profile_step(sim_config, 1000, device)
Expand Down
4 changes: 4 additions & 0 deletions crazyflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
import crazyflow.gymnasium_envs # noqa: F401, ensure gymnasium envs are registered
from crazyflow.control import Control
from crazyflow.sim import Physics, Sim

__all__ = ["Sim", "Physics", "Control"]
2 changes: 1 addition & 1 deletion crazyflow/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
GRAVITY: float = 9.81

# Drone constants
ARM_LEN: float = 0.46
ARM_LEN: float = 0.0325 * jnp.sqrt(2)
MIX_MATRIX: Array = jnp.array([[-0.5, -0.5, -1], [-0.5, 0.5, 1], [0.5, 0.5, -1], [0.5, -0.5, 1]])
SIGN_MIX_MATRIX: Array = jnp.sign(MIX_MATRIX)
MASS: float = 0.027
Expand Down
3 changes: 3 additions & 0 deletions crazyflow/control/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from crazyflow.control.control import Control

__all__ = ["Control"]
41 changes: 27 additions & 14 deletions crazyflow/control/controller.py → crazyflow/control/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,28 @@ class Control(str, Enum):
"""Control type of the simulated onboard controller."""

state = "state"
"""State control takes [x, y, z, vx, vy, vz, ax, ay, az, yaw, roll_rate, pitch_rate, yaw_rate].
Note:
Recommended frequency is >=20 Hz.
Warning:
Currently, we only use positions, velocities, and yaw. The rest of the state is ignored.
This is subject to change in the future.
"""
attitude = "attitude"
thrust = "thrust"
default = attitude
"""Attitude control takes [collective thrust, roll, pitch, yaw].
Note:
Recommended frequency is >=100 Hz.
"""
thrust = "thrust"
"""Thrust control takes [thrust1, thrust2, thrust3, thrust4] for each drone motor.
class Controller(str, Enum):
"""Controller type of the simulated onboard controller."""

pycffirmware = "pycffirmware"
emulatefirmware = "emulatefirmware"
default = emulatefirmware
Note:
Recommended frequency is >=500 Hz.
"""
default = attitude


KF: float = 3.16e-10
Expand Down Expand Up @@ -90,27 +101,28 @@ def state2attitude(

@partial(jnp.vectorize, signature="(4),(4),(3),(3)->(4),(3)", excluded=[4])
def attitude2rpm(
cmd: Array, quat: Array, last_rpy: Array, rpy_err_i: Array, dt: float
controls: Array, quat: Array, last_rpy: Array, rpy_err_i: Array, dt: float
) -> tuple[Array, Array]:
"""Convert the desired attitude and quaternion into motor RPMs."""
"""Convert the desired collective thrust and attitude into motor RPMs."""
rot = R.from_quat(quat)
target_rot = R.from_euler("xyz", cmd[..., 1:])
target_rot = R.from_euler("xyz", controls[1:])
drot = (target_rot.inv() * rot).as_matrix()

# Extract the anti-symmetric part of the relative rotation matrix.
rot_e = jnp.array([drot[2, 1] - drot[1, 2], drot[0, 2] - drot[2, 0], drot[1, 0] - drot[0, 1]])
rpy_rates_e = -(rot.as_euler("xyz") - last_rpy) / dt # Assuming zero rpy_rates target
# TODO: Assumes zero rpy_rates targets for now, use the actual target instead.
rpy_rates_e = -(rot.as_euler("xyz") - last_rpy) / dt
rpy_err_i = rpy_err_i - rot_e * dt
rpy_err_i = jnp.clip(rpy_err_i, -1500.0, 1500.0)
rpy_err_i = rpy_err_i.at[:2].set(jnp.clip(rpy_err_i[:2], -1.0, 1.0))
# PID target torques.
target_torques = -P_T * rot_e + D_T * rpy_rates_e + I_T * rpy_err_i
target_torques = jnp.clip(target_torques, -3200, 3200)
thrust_per_motor = cmd[0] / 4
thrust_per_motor = jnp.atleast_1d(controls[0]) / 4
pwm = jnp.clip(thrust2pwm(thrust_per_motor) + MIX_MATRIX @ target_torques, MIN_PWM, MAX_PWM)
return pwm2rpm(pwm), rpy_err_i


@partial(jnp.vectorize, signature="(4)->(4)")
def thrust2pwm(thrust: Array) -> Array:
"""Convert the desired thrust into motor PWM.
Expand All @@ -124,6 +136,7 @@ def thrust2pwm(thrust: Array) -> Array:
return jnp.clip((jnp.sqrt(thrust / KF) - PWM2RPM_CONST) / PWM2RPM_SCALE, MIN_PWM, MAX_PWM)


@partial(jnp.vectorize, signature="(4)->(4)")
def pwm2rpm(pwm: Array) -> Array:
"""Convert the motors' PWMs into RPMs."""
return PWM2RPM_CONST + PWM2RPM_SCALE * pwm
17 changes: 14 additions & 3 deletions crazyflow/gymnasium_envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
from gymnasium.envs.registration import register

from crazyflow.gymnasium_envs.crazyflow import (
CrazyflowEnvFigureEightTrajectory,
CrazyflowEnvLanding,
CrazyflowEnvReachGoal,
CrazyflowEnvTargetVelocity,
CrazyflowRL
CrazyflowRL,
)

__all__ = ["CrazyflowEnvReachGoal", "CrazyflowEnvTargetVelocity", "CrazyflowEnvLanding", "CrazyflowRL"]
__all__ = [
"CrazyflowEnvReachGoal",
"CrazyflowEnvTargetVelocity",
"CrazyflowEnvLanding",
"CrazyflowRL",
"CrazyflowEnvFigureEightTrajectory",
]

register(
id="DroneReachPos-v0",
Expand All @@ -19,8 +26,12 @@
vector_entry_point="crazyflow.gymnasium_envs.crazyflow:CrazyflowEnvTargetVelocity",
)


register(
id="DroneLanding-v0",
vector_entry_point="crazyflow.gymnasium_envs.crazyflow:CrazyflowEnvLanding",
)

register(
id="DroneFigureEightTrajectory-v0",
vector_entry_point="crazyflow.gymnasium_envs.crazyflow:CrazyflowEnvFigureEightTrajectory",
)
Loading

0 comments on commit 3359bc4

Please sign in to comment.