Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update isort and black style of the repository #28

Merged
merged 1 commit into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions src/jaxsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

# Follow upstream development in https://github.com/google/jax/pull/13304
def _jnp_options() -> None:

import os

from jax.config import config
Expand All @@ -21,14 +20,12 @@ def _jnp_options() -> None:


def _np_options() -> None:

import numpy as np

np.set_printoptions(precision=5, suppress=True, linewidth=150, threshold=10_000)


def _is_editable() -> bool:

import importlib.util
import pathlib
import site
Expand Down
1 change: 0 additions & 1 deletion src/jaxsim/high_level/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@


class VelRepr(enum.IntEnum):

Body = enum.auto()
Mixed = enum.auto()
Inertial = enum.auto()
14 changes: 0 additions & 14 deletions src/jaxsim/high_level/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,48 +10,38 @@

@jax_dataclasses.pytree_dataclass
class Joint(JaxsimDataclass):

joint_description: descriptions.JointDescription = jax_dataclasses.static_field()
parent_model: "jaxsim.high_level.model.Model" = jax_dataclasses.field(
default=None, repr=False, compare=False
)

def valid(self) -> bool:

return self.parent_model is not None

def index(self) -> int:

return self.joint_description.index

def dofs(self) -> int:

return 1

def name(self) -> str:

return self.joint_description.name

def position(self, dof: int = 0) -> float:

return self.parent_model.joint_positions(joint_names=[self.name()])[dof]

def velocity(self, dof: int = 0) -> float:

return self.parent_model.joint_velocities(joint_names=[self.name()])[dof]

def acceleration(self, dof: int = 0) -> float:

return self.parent_model.joint_accelerations(joint_names=[self.name()])[dof]

def force(self, dof: int = 0) -> float:

return self.parent_model.joint_generalized_forces(joint_names=[self.name()])[
dof
]

def position_limit(self, dof: int = 0) -> Tuple[float, float]:

if dof != 0:
msg = "Only joints with 1 DoF are currently supported"
raise ValueError(msg)
Expand All @@ -63,17 +53,13 @@ def position_limit(self, dof: int = 0) -> Tuple[float, float]:
# =================

def joint_position(self) -> jtp.Vector:

return self.parent_model.joint_positions(joint_names=[self.name()])

def joint_velocity(self) -> jtp.Vector:

return self.parent_model.joint_velocities(joint_names=[self.name()])

def joint_acceleration(self) -> jtp.Vector:

return self.parent_model.joint_accelerations(joint_names=[self.name()])

def joint_force(self) -> jtp.Vector:

return self.parent_model.joint_generalized_forces(joint_names=[self.name()])
29 changes: 0 additions & 29 deletions src/jaxsim/high_level/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,42 +14,35 @@

@jax_dataclasses.pytree_dataclass
class Link(JaxsimDataclass):

link_description: descriptions.LinkDescription = jax_dataclasses.static_field()
parent_model: "jaxsim.high_level.model.Model" = jax_dataclasses.field(
default=None, repr=False, compare=False
)

def valid(self) -> bool:

return self.parent_model is not None

# ==========
# Properties
# ==========

def name(self) -> str:

return self.link_description.name

def index(self) -> int:

return self.link_description.index

# ========
# Dynamics
# ========

def mass(self) -> jtp.Float:

return self.link_description.mass

def spatial_inertia(self) -> jtp.Matrix:

return self.link_description.inertia

def com_position(self, in_link_frame: bool = True) -> jtp.VectorJax:

from jaxsim.math.inertia import Inertia

_, L_p_CoM, _ = Inertia.to_params(M=self.spatial_inertia())
Expand All @@ -67,38 +60,31 @@ def com_position(self, in_link_frame: bool = True) -> jtp.VectorJax:
# ==========

def position(self) -> jtp.Vector:

return self.transform()[0:3, 3]

def orientation(self, dcm: bool = False) -> jtp.Vector:

R = self.transform()[0:3, 0:3]

to_wxyz = np.array([3, 0, 1, 2])
return R if dcm else sixd.so3.SO3.from_matrix(R).as_quaternion_xyzw()[to_wxyz]

def transform(self) -> jtp.Matrix:

return self.parent_model.forward_kinematics()[self.index()]

def velocity(self, vel_repr: VelRepr = None) -> jtp.Vector:

v_WL = (
self.jacobian(output_vel_repr=vel_repr)
@ self.parent_model.generalized_velocity()
)
return v_WL

def linear_velocity(self, vel_repr: VelRepr = None) -> jtp.Vector:

return self.velocity(vel_repr=vel_repr)[0:3]

def angular_velocity(self, vel_repr: VelRepr = None) -> jtp.Vector:

return self.velocity(vel_repr=vel_repr)[3:6]

def jacobian(self, output_vel_repr: VelRepr = None) -> jtp.Matrix:

if output_vel_repr is None:
output_vel_repr = self.parent_model.velocity_representation

Expand All @@ -110,11 +96,9 @@ def jacobian(self, output_vel_repr: VelRepr = None) -> jtp.Matrix:
)

if self.parent_model.velocity_representation is VelRepr.Body:

L_J_WL_target = L_J_WL_B

elif self.parent_model.velocity_representation is VelRepr.Inertial:

dofs = self.parent_model.dofs()
W_H_B = self.parent_model.base_transform()

Expand All @@ -124,7 +108,6 @@ def jacobian(self, output_vel_repr: VelRepr = None) -> jtp.Matrix:
L_J_WL_target = L_J_WL_B @ B_T_W

elif self.parent_model.velocity_representation is VelRepr.Mixed:

dofs = self.parent_model.dofs()
W_H_B = self.parent_model.base_transform()
BW_H_B = jnp.array(W_H_B).at[0:3, 3].set(jnp.zeros(3))
Expand All @@ -142,13 +125,11 @@ def jacobian(self, output_vel_repr: VelRepr = None) -> jtp.Matrix:
return L_J_WL_target

elif output_vel_repr is VelRepr.Inertial:

W_H_L = self.transform()
W_X_L = sixd.se3.SE3.from_matrix(W_H_L).adjoint()
return W_X_L @ L_J_WL_target

elif output_vel_repr is VelRepr.Mixed:

W_H_L = self.transform()
LW_H_L = jnp.array(W_H_L).at[0:3, 3].set(jnp.zeros(3))
LW_X_L = sixd.se3.SE3.from_matrix(LW_H_L).adjoint()
Expand All @@ -158,14 +139,12 @@ def jacobian(self, output_vel_repr: VelRepr = None) -> jtp.Matrix:
raise ValueError(output_vel_repr)

def external_force(self) -> jtp.Vector:

W_f_ext = self.parent_model.data.model_input.f_ext[self.index()]

if self.parent_model.velocity_representation is VelRepr.Inertial:
return W_f_ext

elif self.parent_model.velocity_representation is VelRepr.Body:

W_H_B = self.parent_model.base_transform()
W_X_B = sixd.se3.SE3.from_matrix(W_H_B).adjoint()

Expand All @@ -180,7 +159,6 @@ def external_force(self) -> jtp.Vector:
def add_external_force(
self, force: jtp.Array = None, torque: jtp.Array = None
) -> None:

force = force if force is not None else jnp.zeros(3)
torque = torque if torque is not None else jnp.zeros(3)

Expand All @@ -190,15 +168,13 @@ def add_external_force(
W_f_ext = f_ext

elif self.parent_model.velocity_representation is VelRepr.Body:

L_f_ext = f_ext
W_H_L = self.transform()
L_X_W = sixd.se3.SE3.from_matrix(W_H_L).inverse().adjoint()

W_f_ext = L_X_W.transpose() @ L_f_ext

elif self.parent_model.velocity_representation is VelRepr.Mixed:

LW_f_ext = f_ext

W_p_L = self.transform()[0:3, 3]
Expand All @@ -221,18 +197,15 @@ def add_external_force(
def add_com_external_force(
self, force: jtp.Array = None, torque: jtp.Array = None
) -> None:

force = force if force is not None else jnp.zeros(3)
torque = torque if torque is not None else jnp.zeros(3)

f_ext = jnp.hstack([force, torque])

if self.parent_model.velocity_representation is VelRepr.Inertial:

W_f_ext = f_ext

elif self.parent_model.velocity_representation is VelRepr.Body:

GL_f_ext = f_ext

W_H_L = self.transform()
Expand All @@ -244,7 +217,6 @@ def add_com_external_force(
W_f_ext = GL_X_W.transpose() @ GL_f_ext

elif self.parent_model.velocity_representation is VelRepr.Mixed:

GW_f_ext = f_ext

W_p_CoM = self.com_position(in_link_frame=False)
Expand All @@ -265,5 +237,4 @@ def add_com_external_force(
)

def in_contact(self) -> jtp.Bool:

return not jnp.allclose(self.external_force(), 0)
Loading