Skip to content

Commit

Permalink
Fix docs and formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
amacati committed Feb 4, 2025
1 parent 024b37b commit 50fb8f2
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 21 deletions.
22 changes: 3 additions & 19 deletions crazyflow/gymnasium_envs/crazyflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def __init__(
Args:
num_envs: The number of environments to run in parallel.
time_horizon_in_seconds: The time horizon after which episodes are truncated.
**kwargs: Takes arguments that are passed to the Crazyfly simulation.
physics: The crazyflow physics simulation model.
freq: The frequency at which the environment is run.
device: The device of the environment and the simulation.
"""
self.num_envs = num_envs
self.device = jax.devices(device)[0]
Expand Down Expand Up @@ -530,21 +532,3 @@ def actions(self, actions: Array) -> Array:
# Ensure actions are within the valid range of the simulation action space
rescaled_actions = np.clip(rescaled_actions, self.action_sim_low, self.action_sim_high)
return rescaled_actions


def render_trajectory(viewer: MujocoRenderer | None, pos: Array) -> None:
"""Render trajectory."""
if viewer is None:
return

pos = np.array(pos[0]).transpose(1, 0, 2)
n_trace, n_drones = len(pos) - 1, len(pos[0])

for i in range(n_trace):
for j in range(n_drones):
viewer.viewer.add_marker(
type=mujoco.mjtGeom.mjGEOM_SPHERE,
size=np.array([0.02, 0.02, 0.02]),
pos=pos[i][j],
rgba=np.array([1, 0, 0, 0.8]),
)
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[build-system]
requires = ["setuptools>=61.0.0", "wheel", "numpy"]
build-backend = "setuptools.build_meta"
requires-python = "3.11" # tested in python 3.11
requires-python = "3.11" # tested in python 3.11

[project]
name = "crazyflow"
Expand Down Expand Up @@ -81,6 +81,7 @@ unfixable = []
"benchmark/*" = ["D100", "D103"]
"tests/*" = ["D100", "D103", "D104"]
"examples/*" = ["D100", "D103"]
"tutorials/*" = ["D", "ANN"]
# TODO: Remove once everything is stable and document
"crazyflow/*" = ["D100", "D101", "D102", "D104", "D107"]

Expand Down
3 changes: 2 additions & 1 deletion tutorials/ppo/test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import random
from pathlib import Path

import gymnasium
import gymnasium.wrappers.vector.jax_to_torch
Expand Down Expand Up @@ -43,7 +44,7 @@
test_env = gymnasium.wrappers.vector.jax_to_torch.JaxToTorch(norm_test_env, device=device)

# Load checkpoint
checkpoint = torch.load("ppo_checkpoint.pt")
checkpoint = torch.load(Path(__file__).parent / "ppo_checkpoint.pt")

# Create agent and load state
agent = Agent(test_env).to(device)
Expand Down

0 comments on commit 50fb8f2

Please sign in to comment.