Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Filippo Luca Ferretti <[email protected]>
  • Loading branch information
diegoferigo and flferretti authored Mar 14, 2024
1 parent 756b69f commit 61dd60d
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 15 deletions.
3 changes: 1 addition & 2 deletions src/jaxsim/api/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int:
.squeeze()
.astype(int)
)
else:
return jnp.array(-1).astype(int)
return jnp.array(-1).astype(int)


def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str:
Expand Down
12 changes: 8 additions & 4 deletions src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]:
)

return KynDynParameters(
link_names=tuple([l.name for l in ordered_links]),
link_names=tuple(l.name for l in ordered_links),
parent_array=parent_array,
support_body_array_bool=support_body_array_bool,
link_parameters=link_parameters,
Expand Down Expand Up @@ -462,8 +462,10 @@ def build_from_inertial_parameters(

return LinkParameters(
mass=jnp.array(m).squeeze().astype(float),
I=jnp.atleast_1d(I[jnp.triu_indices(3)].squeeze()).astype(float),
com=jnp.atleast_1d(c.squeeze()).astype(float),
inertia_elements=jnp.atleast_1d(I[jnp.triu_indices(3)].squeeze()).astype(
float
),
center_of_mass=jnp.atleast_1d(c.squeeze()).astype(float),
)

@staticmethod
Expand All @@ -473,7 +475,9 @@ def build_from_flat_parameters(parameters: jtp.VectorLike) -> LinkParameters:
c = jnp.atleast_1d(parameters[1:4].squeeze()).astype(float)
I = jnp.atleast_1d(parameters[4:].squeeze()).astype(float)

return LinkParameters(mass=m, I=I[jnp.triu_indices(3)], com=c)
return LinkParameters(
mass=m, inertia_elements=I[jnp.triu_indices(3)], center_of_mass=c
)

@staticmethod
def parameters(params: LinkParameters) -> jtp.Vector:
Expand Down
3 changes: 1 addition & 2 deletions src/jaxsim/api/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ def name_to_idx(model: js.model.JaxSimModel, *, link_name: str) -> jtp.Int:
.squeeze()
.astype(int)
)
else:
return jnp.array(-1).astype(int)
return jnp.array(-1).astype(int)


def idx_to_name(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
Expand Down
3 changes: 1 addition & 2 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

import jax
import jax.numpy as jnp
import jaxlie

import jaxsim.api as js
import jaxsim.physics.algos.soft_contacts
import jaxsim.typing as jtp
from jaxsim import VelRepr, integrators
from jaxsim import VelRepr
from jaxsim.integrators.common import Time
from jaxsim.math.quaternion import Quaternion
from jaxsim.physics.algos.soft_contacts import SoftContactsState
Expand Down
8 changes: 3 additions & 5 deletions src/jaxsim/physics/algos/soft_contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,9 @@ def process_point_kinematics(
return W_p_Ci, CW_vl_WCi

# Process all the collidable points in parallel
W_p_Ci, CW_vl_WC = jax.vmap(
lambda Li_p_C, parent_body: process_point_kinematics(
Li_p_C=Li_p_C, parent_body=parent_body
)
)(model.gc.point, jnp.array(model.gc.body))
W_p_Ci, CW_vl_WC = jax.vmap(process_point_kinematics)(Li_p_C, parent_body)(
model.gc.point, jnp.array(model.gc.body)
)

return W_p_Ci.transpose(), CW_vl_WC.transpose()

Expand Down

0 comments on commit 61dd60d

Please sign in to comment.