-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
JAX Jacobian of from_axis_angle producing NaNs #339
Conversation
Thanks a lot for you interest in Jaxsim and you contribution @ConnorTingley!! We will review soon you PR. Looking a bit at past history of this function, @diegoferigo updated this function in #277 to address the same AD issue, then we updated it more recently in #319, maybe introducing some sort of regression. What do you think @flferretti ? |
HI @ConnorTingley!, Thanks so much for opening this PR and contributing to JAXSim! We really appreciate you taking the time to dive into this issue and propose a fix. Would you mind sharing a minimal working example that reproduces the error you're seeing? When we call |
Hi @ConnorTingley, thanks for reporting the issue! I tried to make an example for which the issue is verified: import jax
from jaxsim.math import Rotation
vector = jax.numpy.zeros(3)
jac = jax.jacobian(Rotation.from_axis_angle)(vector)
print(jac)
>>> Array([[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]],
[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]],
[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]], dtype=float64) You're right, a matrix of NaNs is produced. We will work on that to see if the issue is propagated to other parts of the code and will get back to you with a potential low-level solution. Thanks again for working on this! We'll get back to you ASAP |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for working on this @ConnorTingley! Would you mind removing the jnp.where
at the return? At this point it is not needed, as the function would still return a jnp.eye(3)
when the vector is a jnp.zeros(3)
. For the rest, LGTM
vector = vector.squeeze()
- def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix:
+ theta = safe_norm(vector)
- v = axis
- theta = safe_norm(v)
+ s = jnp.sin(theta)
+ c = jnp.cos(theta)
- s = jnp.sin(theta)
- c = jnp.cos(theta)
+ c1 = 2 * jnp.sin(theta / 2.0) ** 2
- c1 = 2 * jnp.sin(theta / 2.0) ** 2
+ safe_theta = jnp.where(theta == 0, 1.0, theta)
+ u = vector / safe_theta
+ u = jnp.vstack(u.squeeze())
- safe_theta = jnp.where(theta == 0, 1.0, theta)
- u = v / safe_theta
- u = jnp.vstack(u.squeeze())
+ R = c * jnp.eye(3) - s * Skew.wedge(u) + c1 * u @ u.T
- R = c * jnp.eye(3) - s * Skew.wedge(u) + c1 * u @ u.T
-
- return R.transpose()
-
- return jnp.where(
- jnp.allclose(vector, 0.0),
- # Return an identity rotation matrix when the input vector is zero.
- jnp.eye(3),
- theta_is_not_zero(axis=vector),
- )
+ return R.T
@ConnorTingley Just a tip: for an easier development, I suggest you to install the |
Thanks, will do! I'm new to this :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks again @ConnorTingley, LGTM
Taking the JAX Jacobian of the
from_axis_angle
function was producing NaNs whenvector = [0,0,0]
I found this: https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
which lays out how to fix this issue by making the internal function safe when the input is 0.
I just added an extra jnp.where() to check before dividing.
📚 Documentation preview 📚: https://jaxsim--339.org.readthedocs.build//339/