Skip to content

Commit

Permalink
Merge pull request #244 from ami-iit/feature/jaxopt_to_optax
Browse files Browse the repository at this point in the history
Switch from `jaxopt` to `optax` in relaxed rigid contact model
  • Loading branch information
flferretti authored Oct 3, 2024
2 parents 6e23f66 + de1bc2a commit 7a2d193
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 14 deletions.
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ dependencies:
- python >= 3.12.0
- coloredlogs
- jax >= 0.4.26
- jaxopt >= 0.8.0
- jaxlib >= 0.4.26
- jaxlie >= 1.3.0
- jax-dataclasses >= 1.4.0
- optax >= 0.2.3
- pptree
- qpax
- rod >= 0.3.3
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ classifiers = [
dependencies = [
"coloredlogs",
"jax >= 0.4.26",
"jaxopt >= 0.8.0",
"jaxlib >= 0.4.26",
"jaxlie >= 1.3.0",
"jax_dataclasses >= 1.4.0",
"pptree",
"optax >= 0.2.3",
"qpax",
"rod >= 0.3.3",
"typing_extensions ; python_version < '3.12'",
Expand Down
72 changes: 60 additions & 12 deletions src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

import dataclasses
from collections.abc import Callable
from typing import Any

import jax
import jax.numpy as jnp
import jax_dataclasses
import jaxopt
import optax

import jaxsim.api as js
import jaxsim.typing as jtp
Expand Down Expand Up @@ -297,24 +298,71 @@ def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
A = G + R
b = CW_al_free_WC - a_ref

objective = lambda x: jnp.sum(jnp.square(A @ x + b))
objective = lambda x, A, b: jnp.sum(jnp.square(A @ x + b))

# Compute the 3D linear force in C[W] frame
opt = jaxopt.LBFGS(
fun=objective,
maxiter=self.parameters.max_iterations,
tol=self.parameters.tolerance,
maxls=30,
history_size=10,
max_stepsize=100.0,
)
def run_optimization(
init_params: jtp.Array,
fun: Callable,
opt: optax.GradientTransformation,
maxiter: jtp.Int,
tol: jtp.Float,
**kwargs,
):
value_and_grad_fn = optax.value_and_grad_from_state(fun)

def step(carry):
params, state = carry
value, grad = value_and_grad_fn(
params,
state=state,
A=A,
b=b,
)
updates, state = opt.update(
updates=grad,
state=state,
params=params,
value=value,
grad=grad,
value_fn=fun,
A=A,
b=b,
)
params = optax.apply_updates(params, updates)
return params, state

def continuing_criterion(carry):
_, state = carry
iter_num = optax.tree_utils.tree_get(state, "count")
grad = optax.tree_utils.tree_get(state, "grad")
err = optax.tree_utils.tree_l2_norm(grad)
return (iter_num == 0) | ((iter_num < maxiter) & (err >= tol))

init_carry = (init_params, opt.init(init_params))
final_params, final_state = jax.lax.while_loop(
continuing_criterion, step, init_carry
)
return final_params, final_state

init_params = (
K[:, jnp.newaxis] * jnp.zeros_like(position).at[:, 2].set(δ)
+ D[:, jnp.newaxis] * velocity
).flatten()

CW_f_Ci = opt.run(init_params=init_params).params.reshape(-1, 3)
# Compute the 3D linear force in C[W] frame
CW_f_Ci, _ = run_optimization(
init_params=init_params,
A=A,
b=b,
maxiter=self.parameters.max_iterations,
opt=optax.lbfgs(
memory_size=10,
),
fun=objective,
tol=self.parameters.tolerance,
)

CW_f_Ci = CW_f_Ci.reshape((-1, 3))

def mixed_to_inertial(W_H_C: jax.Array, CW_fl: jax.Array) -> jax.Array:
W_Xf_CW = Adjoint.from_transform(
Expand Down

0 comments on commit 7a2d193

Please sign in to comment.