Skip to content

Commit

Permalink
[Broken, wip] Fix gym envs and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
amacati committed Dec 10, 2024
1 parent ebe5112 commit bb2feb7
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 38 deletions.
30 changes: 17 additions & 13 deletions crazyflow/gymnasium_envs/crazyflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,25 +175,29 @@ def reset(self, mask: Array) -> None:
minval=jnp.array([-1.0, -1.0, 1.0]), # x,y,z
maxval=jnp.array([1.0, 1.0, 2.0]), # x,y,z
)
self.sim.states = self.sim.states.replace(
pos=jnp.where(mask3d, init_pos, self.sim.states.pos)
self.sim.data = self.sim.data.replace(
states=self.sim.data.states.replace(
pos=jnp.where(mask3d, init_pos, self.sim.data.states.pos)
)
)
# Sample initial vel
self.jax_key, subkey = jax.random.split(self.jax_key)
init_vel = jax.random.uniform(
key=subkey, shape=(self.sim.n_worlds, self.sim.n_drones, 3), minval=-1.0, maxval=1.0
)
self.sim.states = self.sim.states.replace(
vel=jnp.where(mask3d, init_vel, self.sim.states.vel)
self.sim.data = self.sim.data.replace(
states=self.sim.data.states.replace(
vel=jnp.where(mask3d, init_vel, self.sim.data.states.vel)
)
)

@property
def reward(self) -> Array:
return self._reward(self.prev_done, self.terminated, self.sim.states)
return self._reward(self.prev_done, self.terminated, self.sim.data.states)

@property
def terminated(self) -> Array:
return self._terminated(self.prev_done, self.sim.states, self.sim.contacts())
return self._terminated(self.prev_done, self.sim.data.states, self.sim.contacts())

@property
def truncated(self) -> Array:
Expand Down Expand Up @@ -224,7 +228,7 @@ def render(self):
def _obs(self) -> dict[str, Array]:
convert = self.return_datatype == "numpy"
fields = self.obs_keys
states = [maybe_to_numpy(getattr(self.sim.states, field), convert) for field in fields]
states = [maybe_to_numpy(getattr(self.sim.data.states, field), convert) for field in fields]
return {k: v for k, v in zip(fields, states)}


Expand All @@ -243,7 +247,7 @@ def __init__(self, **kwargs: dict):

@property
def reward(self) -> Array:
return self._reward(self.prev_done, self.terminated, self.sim.states, self.goal)
return self._reward(self.prev_done, self.terminated, self.sim.data.states, self.goal)

@staticmethod
@jax.jit
Expand All @@ -269,7 +273,7 @@ def reset(self, mask: Array) -> None:

def _obs(self) -> dict[str, Array]:
obs = super()._obs()
obs["difference_to_goal"] = [self.goal - self.sim.states.pos]
obs["difference_to_goal"] = [self.goal - self.sim.data.states.pos]
return obs


Expand All @@ -287,7 +291,7 @@ def __init__(self, **kwargs: dict):

@property
def reward(self) -> Array:
return self._reward(self.prev_done, self.terminated, self.sim.states, self.target_vel)
return self._reward(self.prev_done, self.terminated, self.sim.data.states, self.target_vel)

@staticmethod
@jax.jit
Expand All @@ -313,7 +317,7 @@ def reset(self, mask: Array) -> None:

def _obs(self) -> dict[str, Array]:
obs = super()._obs()
obs["difference_to_target_vel"] = [self.target_vel - self.sim.states.vel]
obs["difference_to_target_vel"] = [self.target_vel - self.sim.data.states.vel]
return obs


Expand All @@ -335,7 +339,7 @@ def __init__(self, **kwargs: dict):

@property
def reward(self) -> Array:
return self._reward(self.prev_done, self.terminated, self.sim.states, self.goal)
return self._reward(self.prev_done, self.terminated, self.sim.data.states, self.goal)

@staticmethod
@jax.jit
Expand All @@ -352,5 +356,5 @@ def reset(self, mask: Array) -> None:

def _get_obs(self) -> dict[str, Array]:
obs = super()._get_obs()
obs["difference_to_goal"] = [self.goal - self.sim.states.pos]
obs["difference_to_goal"] = [self.goal - self.sim.data.states.pos]
return obs
10 changes: 10 additions & 0 deletions crazyflow/sim/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,16 @@ def time(self) -> Array:
def freq(self) -> int:
return self.data.sim.freq

@property
def control_freq(self) -> int:
if self.control == Control.state:
return self.data.controls.state_freq
if self.control == Control.attitude:
return self.data.controls.attitude_freq
if self.control == Control.thrust:
raise NotImplementedError("Thrust control is not yet supported by the sim config")
raise NotImplementedError(f"Control mode {self.control} not implemented")

@property
def controllable(self) -> Array:
"""Boolean array of shape (n_worlds,) that indicates which worlds are controllable.
Expand Down
2 changes: 1 addition & 1 deletion examples/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def step(cmd: NDArray) -> jax.Array:
sim.state_control(cmd)
sim.step()
sim.step() # We need two steps for the initial step to take effect on the z position
return (sim.states.pos[0, 0, 2] - 1.0) ** 2 # Quadratic cost to reach 1m height
return (sim.data.states.pos[0, 0, 2] - 1.0) ** 2 # Quadratic cost to reach 1m height

step_grad = jax.jit(jax.grad(step))

Expand Down
7 changes: 3 additions & 4 deletions examples/gymnasium_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import numpy as np
from ml_collections import config_dict

from crazyflow.control.controller import Control, Controller
from crazyflow.control.controller import Control
from crazyflow.sim.physics import Physics

# set config for simulation
sim_config = config_dict.ConfigDict()
sim_config.device = "cpu"
sim_config.physics = Physics.sys_id
sim_config.control = Control.default
sim_config.controller = Controller.default
sim_config.control_freq = 50
sim_config.n_drones = 1
sim_config.n_worlds = 20
Expand All @@ -25,9 +24,9 @@
**sim_config,
)

# action for going up (in attitude control). NOTE actions are rescaled in the environment
# Action for going up (in attitude control)
action = np.zeros((sim_config.n_worlds * sim_config.n_drones, 4), dtype=np.float32)
action[..., 0] = -0.2
action[..., 0] = 0.3

obs, info = envs.reset_all(seed=SEED)

Expand Down
24 changes: 4 additions & 20 deletions tests/unit/test_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from flax.serialization import to_state_dict

from crazyflow.control.controller import Control
from crazyflow.exception import ConfigError
from crazyflow.sim.core import Sim
from crazyflow.sim.physics import Physics

Expand Down Expand Up @@ -32,21 +33,13 @@ def skip_unavailable_device(device: str):
@pytest.mark.parametrize("device", ["gpu", "cpu"])
@pytest.mark.parametrize("control", Control)
@pytest.mark.parametrize("n_worlds", [1, 2])
def test_sim_init(
physics: Physics, device: str, control: Control, controller: Controller, n_worlds: int
):
def test_sim_init(physics: Physics, device: str, control: Control, n_worlds: int):
n_drones = 1
skip_unavailable_device(device)

def create_sim() -> Sim:
return Sim(
n_worlds=n_worlds, n_drones=n_drones, physics=physics, device=device, control=control
)
return Sim(n_worlds=n_worlds, physics=physics, device=device, control=control)

if n_drones * n_worlds > 1 and controller == Controller.pycffirmware:
with pytest.raises(ConfigError):
create_sim()
return
if physics != Physics.analytical and control == Control.thrust:
with pytest.raises(ConfigError): # TODO: Remove when supported with sys_id
create_sim()
Expand Down Expand Up @@ -173,18 +166,9 @@ def test_reset_masked(device: str, physics: Physics):
@pytest.mark.parametrize("device", ["gpu", "cpu"])
def test_sim_step(n_worlds: int, n_drones: int, physics: Physics, control: Control, device: str):
skip_unavailable_device(device)
if n_drones * n_worlds > 1 and controller == Controller.pycffirmware:
return # PyCFFirmware does not support multiple drones
if physics != Physics.analytical and control == Control.thrust:
return # TODO: Remove when supported with sys_id
sim = Sim(
n_worlds=n_worlds,
n_drones=n_drones,
physics=physics,
device=device,
control=control,
controller=controller,
)
sim = Sim(n_worlds=n_worlds, n_drones=n_drones, physics=physics, device=device, control=control)
try:
for _ in range(2):
sim.step()
Expand Down

0 comments on commit bb2feb7

Please sign in to comment.