Skip to content

Commit

Permalink
remove resets
Browse files Browse the repository at this point in the history
  • Loading branch information
younik committed Jan 30, 2025
1 parent 7f926e7 commit a8c23cc
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 261 deletions.
264 changes: 81 additions & 183 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import jaxsim.math
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim.utils.tracing import not_tracing

from . import common
from .common import VelRepr
Expand Down Expand Up @@ -55,10 +54,10 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
base_position: jtp.Vector

# Cached computations.
_base_transform: jtp.Matrix = dataclasses.field(repr=False, default=None)
_joint_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
_link_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
_link_velocities: jtp.Matrix = dataclasses.field(repr=False, default=None)
base_transform: jtp.Matrix = dataclasses.field(repr=False, default=None)
joint_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
link_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
link_velocities: jtp.Matrix = dataclasses.field(repr=False, default=None)

@staticmethod
def build(
Expand Down Expand Up @@ -173,10 +172,10 @@ def build(
base_angular_velocity=v_WB[3:6],
joint_velocities=joint_velocities,
velocity_representation=velocity_representation,
_base_transform=W_H_B,
_joint_transforms=joint_transforms,
_link_transforms=link_transforms,
_link_velocities=link_velocities,
base_transform=W_H_B,
joint_transforms=joint_transforms,
link_transforms=link_transforms,
link_velocities=link_velocities,
)

if not model_data.valid(model=model):
Expand Down Expand Up @@ -209,34 +208,6 @@ def zero(
# Extract quantities
# ==================

@property
def base_transform(self) -> jtp.Matrix:
"""
Get the base transform.
"""
return self._base_transform

@property
def joint_transforms(self) -> jtp.Matrix:
"""
Get the joint transforms.
"""
return self._joint_transforms

@property
def link_transforms(self) -> jtp.Matrix:
"""
Get the link transforms.
"""
return self._link_transforms

@property
def link_velocities(self) -> jtp.Matrix:
"""
Get the link velocities.
"""
return self._link_velocities

@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["dcm"])
def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix:
Expand Down Expand Up @@ -282,7 +253,7 @@ def base_velocity(self) -> jtp.Vector:
]
)

W_H_B = self._base_transform
W_H_B = self.base_transform

return (
JaxSimModelData.inertial_to_other_representation(
Expand All @@ -306,7 +277,7 @@ def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:
A tuple containing the base transform and the joint positions.
"""

return self._base_transform, self.joint_positions
return self.base_transform, self.joint_positions

@js.common.named_scope
@jax.jit
Expand All @@ -330,106 +301,6 @@ def generalized_velocity(self) -> jtp.Vector:
# Store quantities
# ================

@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["joint_names"])
def reset_joint_positions(
self,
positions: jtp.VectorLike,
model: js.model.JaxSimModel | None = None,
joint_names: tuple[str, ...] | None = None,
) -> Self:
"""
Reset the joint positions.
Args:
positions: The joint positions.
model: The model to consider.
joint_names: The names of the joints for which to set the positions.
Returns:
The updated `JaxSimModelData` object.
"""

positions = jnp.atleast_1d(jnp.array(positions).squeeze()).astype(float)

if model is None:
return self.replace(validate=True, joint_positions=positions)

if not_tracing(positions) and not self.valid(model=model):
msg = "The data object is not compatible with the provided model"
raise ValueError(msg)

joint_idxs = (
js.joint.names_to_idxs(joint_names=joint_names, model=model)
if joint_names is not None
else jnp.arange(model.number_of_joints())
)

return self.replace(
validate=True,
joint_positions=self.joint_positions.at[joint_idxs].set(positions),
)

@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["joint_names"])
def reset_joint_velocities(
self,
velocities: jtp.VectorLike,
model: js.model.JaxSimModel | None = None,
joint_names: tuple[str, ...] | None = None,
) -> Self:
"""
Reset the joint velocities.
Args:
velocities: The joint velocities.
model: The model to consider.
joint_names: The names of the joints for which to set the velocities.
Returns:
The updated `JaxSimModelData` object.
"""

velocities = jnp.atleast_1d(jnp.array(velocities).squeeze()).astype(float)

if model is None:
return self.replace(validate=True, joint_velocities=velocities)

if not_tracing(velocities) and not self.valid(model=model):
msg = "The data object is not compatible with the provided model"
raise ValueError(msg)

joint_idxs = (
js.joint.names_to_idxs(joint_names=joint_names, model=model)
if joint_names is not None
else jnp.arange(model.number_of_joints())
)

return self.replace(
validate=True,
joint_velocities=self.joint_velocities.at[joint_idxs].set(velocities),
)

@js.common.named_scope
@jax.jit
def reset_base_position(self, base_position: jtp.VectorLike) -> Self:
"""
Reset the base position.
Args:
base_position: The base position.
Returns:
The updated `JaxSimModelData` object.
"""

base_position = jnp.array(base_position)

return self.replace(
validate=True,
base_position=jnp.atleast_1d(base_position.squeeze()).astype(float),
)

@js.common.named_scope
@jax.jit
def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self:
Expand Down Expand Up @@ -464,13 +335,11 @@ def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self:
"""

base_pose = jnp.array(base_pose)

W_p_B = base_pose[0:3, 3]

W_Q_B = jaxsim.math.Quaternion.from_dcm(dcm=base_pose[0:3, 0:3])

return self.reset_base_position(base_position=W_p_B).reset_base_quaternion(
base_quaternion=W_Q_B
return self.replace(
base_position=W_p_B,
base_quaternion=W_Q_B,
)

@js.common.named_scope
Expand Down Expand Up @@ -568,7 +437,7 @@ def reset_base_velocity(
W_v_WB = self.other_representation_to_inertial(
array=jnp.atleast_1d(base_velocity.squeeze()).astype(float),
other_representation=velocity_representation,
transform=self._base_transform,
transform=self.base_transform,
is_force=False,
)

Expand All @@ -578,6 +447,73 @@ def reset_base_velocity(
base_angular_velocity=W_v_WB[3:6].squeeze().astype(float),
)

def replace(
self,
model: js.model.JaxSimModel,
joint_positions: jtp.Vector | None = None,
joint_velocities: jtp.Vector | None = None,
base_quaternion: jtp.Vector | None = None,
base_linear_velocity: jtp.Vector | None = None,
base_angular_velocity: jtp.Vector | None = None,
base_position: jtp.Vector | None = None,
validate: bool = False,
) -> Self:
"""
Replace the attributes of the `JaxSimModelData` object.
"""
if joint_positions is None:
joint_positions = self.joint_positions
if joint_velocities is None:
joint_velocities = self.joint_velocities
if base_quaternion is None:
base_quaternion = self.base_quaternion
if base_linear_velocity is None:
base_linear_velocity = self.base_linear_velocity
if base_angular_velocity is None:
base_angular_velocity = self.base_angular_velocity
if base_position is None:
base_position = self.base_position

joint_positions = jnp.atleast_1d(joint_positions.squeeze()).astype(float)
joint_velocities = jnp.atleast_1d(joint_velocities.squeeze()).astype(float)
base_quaternion = jnp.atleast_1d(base_quaternion.squeeze()).astype(float)
base_linear_velocity = jnp.atleast_1d(base_linear_velocity.squeeze())
base_linear_velocity = base_linear_velocity.astype(float)
base_angular_velocity = jnp.atleast_1d(base_angular_velocity.squeeze())
base_angular_velocity = base_angular_velocity.astype(float)
base_position = jnp.atleast_1d(base_position.squeeze())
base_position = base_position.astype(float)

base_transform = jaxsim.math.Transform.from_quaternion_and_translation(
translation=base_position, quaternion=base_quaternion
)
joint_transforms = model.kin_dyn_parameters.joint_transforms(
joint_positions=joint_positions, base_transform=base_transform
)
link_transforms, link_velocities = jaxsim.rbda.forward_kinematics_model(
model=model,
base_position=base_position,
base_quaternion=base_quaternion,
joint_positions=joint_positions,
joint_velocities=joint_velocities,
base_linear_velocity=base_linear_velocity,
base_angular_velocity=base_angular_velocity,
)

return super().replace(
joint_positions=joint_positions,
joint_velocities=joint_velocities,
base_quaternion=base_quaternion,
base_linear_velocity=base_linear_velocity,
base_angular_velocity=base_angular_velocity,
base_position=base_position,
validate=validate,
base_transform=base_transform,
joint_transforms=joint_transforms,
link_transforms=link_transforms,
link_velocities=link_velocities,
)

def valid(self, model: js.model.JaxSimModel) -> bool:
"""
Check if the `JaxSimModelData` is valid for a given `JaxSimModel`.
Expand All @@ -604,44 +540,6 @@ def valid(self, model: js.model.JaxSimModel) -> bool:

return True

@js.common.named_scope
@jax.jit
def update_cached(self, model: js.model.JaxSimModel) -> JaxSimModelData:
"""
Update the cached kinematics and dynamics quantities of the model.
Args:
model: the model to consider.
Returns:
The data object with updated quantity.
"""

base_transform = jaxsim.math.Transform.from_quaternion_and_translation(
translation=self.base_position, quaternion=self.base_quaternion
)

joint_transforms = model.kin_dyn_parameters.joint_transforms(
joint_positions=self.joint_positions, base_transform=self._base_transform
)

link_transforms, link_velocities = jaxsim.rbda.forward_kinematics_model(
model=model,
base_position=self.base_position,
base_quaternion=self.base_quaternion,
joint_positions=self.joint_positions,
joint_velocities=self.joint_velocities,
base_linear_velocity=self.base_linear_velocity,
base_angular_velocity=self.base_angular_velocity,
)

return self.replace(
_base_transform=base_transform,
_joint_transforms=joint_transforms,
_link_transforms=link_transforms,
_link_velocities=link_velocities,
)


@functools.partial(jax.jit, static_argnames=["velocity_representation", "base_rpy_seq"])
def random_model_data(
Expand Down
29 changes: 15 additions & 14 deletions src/jaxsim/api/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,18 @@ def semi_implicit_euler_integration(

new_joint_position = data.joint_positions + dt * new_joint_velocities

data = data.replace(
validate=True,
base_quaternion=new_base_quaternion,
base_position=new_base_position,
joint_positions=new_joint_position,
joint_velocities=new_joint_velocities,
base_linear_velocity=base_lin_velocity_inertial,
# Here we use the base angular velocity in mixed representation since
# it's equivalent to the one in inertial representation
# See: S. Traversaro and A. Saccon, “Multibody Dynamics Notation (Version 2), pg.9
base_angular_velocity=base_ang_velocity_mixed,
)

return data
data = data.replace(
model=model,
validate=True,
base_quaternion=new_base_quaternion,
base_position=new_base_position,
joint_positions=new_joint_position,
joint_velocities=new_joint_velocities,
base_linear_velocity=base_lin_velocity_inertial,
# Here we use the base angular velocity in mixed representation since
# it's equivalent to the one in inertial representation
# See: S. Traversaro and A. Saccon, “Multibody Dynamics Notation (Version 2), pg.9
base_angular_velocity=base_ang_velocity_mixed,
)

return data
Loading

0 comments on commit a8c23cc

Please sign in to comment.