Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure consistent link forces and fix static joint friction #162

Merged
merged 4 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/jaxsim/api/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def inertial_to_other_representation(
array: jtp.Array,
other_representation: VelRepr,
transform: jtp.Matrix,
is_force: bool = False,
*,
is_force: bool,
) -> jtp.Array:
r"""
Convert a 6D quantity from inertial-fixed to another representation.
Expand Down Expand Up @@ -153,7 +154,8 @@ def other_representation_to_inertial(
array: jtp.Array,
other_representation: VelRepr,
transform: jtp.Matrix,
is_force: bool = False,
*,
is_force: bool,
) -> jtp.Array:
r"""
Convert a 6D quantity from another representation to inertial-fixed.
Expand Down
17 changes: 14 additions & 3 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def system_velocity_dynamics(
).astype(float)

# Build link forces if not provided
W_f_L = (
O_f_L = (
jnp.atleast_2d(link_forces.squeeze())
if link_forces is not None
else jnp.zeros((model.number_of_links(), 6))
Expand All @@ -125,7 +125,7 @@ def system_velocity_dynamics(

# Initialize the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
# with the terrain.
W_f_Li_terrain = jnp.zeros_like(W_f_L).astype(float)
W_f_Li_terrain = jnp.zeros_like(O_f_L).astype(float)

# Initialize the 6D contact forces W_f ∈ ℝ^{n_c × 6} applied to collidable points,
# expressed in the world frame.
Expand Down Expand Up @@ -183,7 +183,7 @@ def system_velocity_dynamics(

# Compute the joint friction torque
τ_friction = -(
jnp.diag(kc) @ jnp.sign(data.state.physics_model.joint_positions)
jnp.diag(kc) @ jnp.sign(data.state.physics_model.joint_velocities)
+ jnp.diag(kv) @ data.state.physics_model.joint_velocities
)

Expand All @@ -194,6 +194,17 @@ def system_velocity_dynamics(
# Compute the total joint forces
τ_total = τ + τ_friction + τ_position_limit

references = js.references.JaxSimModelReferences.build(
model=model,
joint_force_references=τ_total,
link_forces=O_f_L,
data=data,
velocity_representation=data.velocity_representation,
)

with references.switch_velocity_representation(VelRepr.Inertial):
W_f_L = references.link_forces(model=model, data=data)

# Compute the total external 6D forces applied to the links
W_f_L_total = W_f_L + W_f_Li_terrain

Expand Down
81 changes: 81 additions & 0 deletions tests/test_simulations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax
import jax.numpy as jnp
import pytest

Expand Down Expand Up @@ -90,3 +91,83 @@ def test_box_with_external_forces(
assert data.time() == t_ns / 1e9 + dt
assert data.base_position() == pytest.approx(data0.base_position())
assert data.base_orientation() == pytest.approx(data0.base_orientation())


def test_box_with_zero_gravity(
jaxsim_model_box: js.model.JaxSimModel,
velocity_representation: VelRepr,
prng_key: jnp.ndarray,
):

model = jaxsim_model_box

# Split the PRNG key.
key, subkey, subkey2 = jax.random.split(prng_key, num=3)

# Build the data of the model.
data0 = js.data.JaxSimModelData.build(
model=model,
base_position=jax.random.uniform(subkey2, shape=(3,)),
velocity_representation=velocity_representation,
standard_gravity=0.0,
soft_contacts_params=jaxsim.rbda.SoftContactsParams.build(K=0.0, D=0.0, mu=0.0),
)

# Generate a random linear force.
L_f = (
jax.random.uniform(subkey, shape=(model.number_of_links(), 6))
.at[:, 3:]
.set(jnp.zeros(3))
)

# Initialize a references object that simplifies handling external forces.
references = js.references.JaxSimModelReferences.build(
model=model,
data=data0,
velocity_representation=velocity_representation,
)

# Apply a link forces to the base link.
references = references.apply_link_forces(
forces=jnp.atleast_2d(L_f),
link_names=model.link_names(),
model=model,
data=data0,
additive=False,
)

# Create the integrator.
integrator = jaxsim.integrators.fixed_step.RungeKutta4SO3.build(
dynamics=js.ode.wrap_system_dynamics_for_integration(
model=model, data=data0, system_dynamics=js.ode.system_dynamics
)
)

# Initialize the integrator.
tf = 1.0
dt = 0.010
T = jnp.arange(start=0, stop=tf * 1e9, step=dt * 1e9, dtype=int)
integrator_state = integrator.init(x0=data0.state, t0=0.0, dt=dt)

# Copy the initial data...
data = data0.copy()

# ... and step the simulation.
for t_ns in T:

data, integrator_state = js.model.step(
model=model,
data=data,
dt=dt,
integrator=integrator,
integrator_state=integrator_state,
link_forces=references.link_forces(model=model, data=data),
)

# Check that the box moved as expected.
assert data.time() == t_ns / 1e9 + dt
assert data.base_position() == pytest.approx(
data0.base_position()
+ 0.5 * L_f[:, :3].squeeze() / js.model.total_mass(model=model) * tf**2,
rel=1e-4,
)