diff --git a/environment.yml b/environment.yml index 334dba22c..bc78942a2 100644 --- a/environment.yml +++ b/environment.yml @@ -8,10 +8,10 @@ dependencies: - python >= 3.12.0 - coloredlogs - jax >= 0.4.26 - - jaxopt >= 0.8.0 - jaxlib >= 0.4.26 - jaxlie >= 1.3.0 - jax-dataclasses >= 1.4.0 + - optax >= 0.2.3 - pptree - qpax - rod >= 0.3.3 diff --git a/pyproject.toml b/pyproject.toml index 843c85bfc..b5377e55d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,11 +45,11 @@ classifiers = [ dependencies = [ "coloredlogs", "jax >= 0.4.26", - "jaxopt >= 0.8.0", "jaxlib >= 0.4.26", "jaxlie >= 1.3.0", "jax_dataclasses >= 1.4.0", "pptree", + "optax >= 0.2.3", "qpax", "rod >= 0.3.3", "typing_extensions ; python_version < '3.12'", diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 085bdc50b..f6b410eb8 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -1,12 +1,13 @@ from __future__ import annotations import dataclasses +from collections.abc import Callable from typing import Any import jax import jax.numpy as jnp import jax_dataclasses -import jaxopt +import optax import jaxsim.api as js import jaxsim.typing as jtp @@ -297,24 +298,71 @@ def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array: A = G + R b = CW_al_free_WC - a_ref - objective = lambda x: jnp.sum(jnp.square(A @ x + b)) + objective = lambda x, A, b: jnp.sum(jnp.square(A @ x + b)) - # Compute the 3D linear force in C[W] frame - opt = jaxopt.LBFGS( - fun=objective, - maxiter=self.parameters.max_iterations, - tol=self.parameters.tolerance, - maxls=30, - history_size=10, - max_stepsize=100.0, - ) + def run_optimization( + init_params: jtp.Array, + fun: Callable, + opt: optax.GradientTransformation, + maxiter: jtp.Int, + tol: jtp.Float, + **kwargs, + ): + value_and_grad_fn = optax.value_and_grad_from_state(fun) + + def step(carry): + params, state = carry + value, grad = value_and_grad_fn( + params, + state=state, + A=A, + b=b, + ) + updates, state = opt.update( + updates=grad, + state=state, + params=params, + value=value, + grad=grad, + value_fn=fun, + A=A, + b=b, + ) + params = optax.apply_updates(params, updates) + return params, state + + def continuing_criterion(carry): + _, state = carry + iter_num = optax.tree_utils.tree_get(state, "count") + grad = optax.tree_utils.tree_get(state, "grad") + err = optax.tree_utils.tree_l2_norm(grad) + return (iter_num == 0) | ((iter_num < maxiter) & (err >= tol)) + + init_carry = (init_params, opt.init(init_params)) + final_params, final_state = jax.lax.while_loop( + continuing_criterion, step, init_carry + ) + return final_params, final_state init_params = ( K[:, jnp.newaxis] * jnp.zeros_like(position).at[:, 2].set(δ) + D[:, jnp.newaxis] * velocity ).flatten() - CW_f_Ci = opt.run(init_params=init_params).params.reshape(-1, 3) + # Compute the 3D linear force in C[W] frame + CW_f_Ci, _ = run_optimization( + init_params=init_params, + A=A, + b=b, + maxiter=self.parameters.max_iterations, + opt=optax.lbfgs( + memory_size=10, + ), + fun=objective, + tol=self.parameters.tolerance, + ) + + CW_f_Ci = CW_f_Ci.reshape((-1, 3)) def mixed_to_inertial(W_H_C: jax.Array, CW_fl: jax.Array) -> jax.Array: W_Xf_CW = Adjoint.from_transform(