JAX Jacobian of from_axis_angle producing NaNs #339

Merged
merged 4 commits into from
Jan 16, 2025

ConnorTingley
Contributor

@ConnorTingley ConnorTingley commented Jan 15, 2025

Taking the JAX Jacobian of the from_axis_angle function produced NaNs when vector = [0,0,0].

I found this: https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
which lays out how to fix this issue by making the internal function safe when the input is 0.

I just added an extra jnp.where() to check before dividing.
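For readers unfamiliar with the pattern from the linked FAQ, here is a minimal sketch of it on a plain normalization function. The names unsafe_unit and safe_unit are hypothetical and are not part of JAXSim; they only illustrate why a single jnp.where on the output is not enough, and why the denominator itself has to be made safe before dividing.

```python
import jax
import jax.numpy as jnp


def unsafe_unit(v):
    # A single where on the output: the forward value is fine, but the
    # untaken branch still computes v / 0, and its backward pass yields 0 / 0.
    theta = jnp.sqrt(jnp.sum(v**2))
    return jnp.where(theta == 0, jnp.zeros_like(v), v / theta)


def safe_unit(v):
    # Extra where before the division: replace the zero input of sqrt (and
    # therefore the zero denominator) with 1.0 so that no NaN or inf is ever
    # produced, then pick the desired output with the outer where.
    sq = jnp.sum(v**2)
    is_zero = sq == 0
    theta = jnp.sqrt(jnp.where(is_zero, 1.0, sq))
    return jnp.where(is_zero, jnp.zeros_like(v), v / theta)


zero = jnp.zeros(3)
jac_unsafe = jax.jacobian(unsafe_unit)(zero)  # contains NaNs
jac_safe = jax.jacobian(safe_unit)(zero)  # finite everywhere
```

Both functions return the same forward values; only the gradients at the zero vector differ.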


📚 Documentation preview 📚: https://jaxsim--339.org.readthedocs.build//339/

@xela-95
Member

xela-95 commented Jan 15, 2025

Thanks a lot for your interest in JAXSim and for your contribution, @ConnorTingley! We will review your PR soon.

Looking at the history of this function, @diegoferigo updated it in #277 to address the same AD issue, and we updated it again more recently in #319, which may have introduced a regression. What do you think, @flferretti?

@CarlottaSartore
Contributor

Hi @ConnorTingley! Thanks so much for opening this PR and contributing to JAXSim! We really appreciate you taking the time to dive into this issue and propose a fix.

Would you mind sharing a minimal working example that reproduces the error you're seeing?

When we call from_axis_angle with vector = [0,0,0], it seems to correctly return the identity matrix. As far as we understand, your issue is instead related to using this function when computing gradients. Having an example of your error would help us better understand the context and ensure your fix works smoothly for all use cases.

@CarlottaSartore CarlottaSartore self-requested a review January 15, 2025 09:28
@flferretti
Collaborator

Hi @ConnorTingley, thanks for reporting the issue! I put together a minimal example that reproduces it:

import jax
from jaxsim.math import Rotation

vector = jax.numpy.zeros(3)

jac = jax.jacobian(Rotation.from_axis_angle)(vector)

print(jac)

>>> Array([[[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]],

       [[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]],

       [[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]]], dtype=float64)

You're right, a matrix of NaNs is produced.

We will check whether the issue propagates to other parts of the code and will get back to you with a potential low-level solution.

Thanks again for working on this! We'll get back to you ASAP.
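One way to check other functions for the same problem is a small finiteness test on the Jacobian. This is only a sketch: jacobian_is_finite and naive_norm are hypothetical names used for illustration and do not exist in JAXSim.

```python
import jax
import jax.numpy as jnp


def jacobian_is_finite(fn, x):
    """Return True if every entry of the Jacobian of fn at x is finite."""
    jac = jax.jacobian(fn)(x)
    return bool(jnp.all(jnp.isfinite(jac)))


def naive_norm(v):
    # Differentiating sqrt at 0 gives inf, which turns into NaN once it is
    # multiplied by the zero derivative of v**2.
    return jnp.sqrt(jnp.sum(v**2))


print(jacobian_is_finite(naive_norm, jnp.zeros(3)))
print(jacobian_is_finite(naive_norm, jnp.ones(3)))
```

The same helper could be pointed at any function that divides by a norm, evaluated at the inputs where the denominator vanishes.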

Collaborator

@flferretti flferretti left a comment

Thanks a lot for working on this @ConnorTingley! Would you mind removing the jnp.where at the return? At this point it is not needed, as the function would still return a jnp.eye(3) when the vector is a jnp.zeros(3). For the rest, LGTM

         vector = vector.squeeze()
 
-        def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix:
+        theta = safe_norm(vector)
 
-            v = axis
-            theta = safe_norm(v)
+        s = jnp.sin(theta)
+        c = jnp.cos(theta)
 
-            s = jnp.sin(theta)
-            c = jnp.cos(theta)
+        c1 = 2 * jnp.sin(theta / 2.0) ** 2
 
-            c1 = 2 * jnp.sin(theta / 2.0) ** 2
+        safe_theta = jnp.where(theta == 0, 1.0, theta)
+        u = vector / safe_theta
+        u = jnp.vstack(u.squeeze())
 
-            safe_theta = jnp.where(theta == 0, 1.0, theta)
-            u = v / safe_theta
-            u = jnp.vstack(u.squeeze())
+        R = c * jnp.eye(3) - s * Skew.wedge(u) + c1 * u @ u.T
 
-            R = c * jnp.eye(3) - s * Skew.wedge(u) + c1 * u @ u.T
-
-            return R.transpose()
-
-        return jnp.where(
-            jnp.allclose(vector, 0.0),
-            # Return an identity rotation matrix when the input vector is zero.
-            jnp.eye(3),
-            theta_is_not_zero(axis=vector),
-        )
+        return R.T
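For anyone who wants to try the simplified version without installing JAXSim, below is a self-contained sketch of the function as it looks after this suggestion. Note that safe_norm and wedge here are hypothetical stand-ins written for this example; they are assumptions about what the real safe_norm and Skew.wedge helpers do and may differ from the library's implementations.

```python
import jax
import jax.numpy as jnp


def safe_norm(v):
    # Stand-in (assumption): a norm whose gradient is finite at the zero vector.
    sq = jnp.sum(v**2)
    is_zero = sq == 0
    return jnp.where(is_zero, 0.0, jnp.sqrt(jnp.where(is_zero, 1.0, sq)))


def wedge(u):
    # Stand-in (assumption): skew-symmetric matrix of a 3D vector.
    return jnp.array(
        [[0.0, -u[2], u[1]], [u[2], 0.0, -u[0]], [-u[1], u[0], 0.0]]
    )


def from_axis_angle(vector):
    vector = vector.squeeze()
    theta = safe_norm(vector)

    s = jnp.sin(theta)
    c = jnp.cos(theta)
    c1 = 2 * jnp.sin(theta / 2.0) ** 2

    # Guard the division: when theta is 0 the denominator becomes 1.0.
    safe_theta = jnp.where(theta == 0, 1.0, theta)
    u = (vector / safe_theta).reshape(3, 1)

    R = c * jnp.eye(3) - s * wedge(u[:, 0]) + c1 * u @ u.T
    return R.T


R0 = from_axis_angle(jnp.zeros(3))  # identity, with no outer where needed
jac0 = jax.jacobian(from_axis_angle)(jnp.zeros(3))  # finite entries
```

With the zero vector, s and c1 are 0 and c is 1, so R reduces to jnp.eye(3) on its own, which is why the outer jnp.where is redundant. The sketch only shows that the Jacobian at zero is finite; it makes no claim that the values match the analytical derivative of the exponential map.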

@flferretti
Collaborator

@ConnorTingley Just a tip: to make development easier, I suggest installing the pre-commit hooks with pre-commit install. This will run the same checks as CI on each commit. You can find more info in the Contributing Guide.

@ConnorTingley
Contributor Author

@ConnorTingley Just a tip: to make development easier, I suggest installing the pre-commit hooks with pre-commit install. This will run the same checks as CI on each commit. You can find more info in the Contributing Guide.

Thanks, will do! I'm new to this :)

Collaborator

@flferretti flferretti left a comment

Thanks again @ConnorTingley, LGTM

@flferretti flferretti enabled auto-merge (squash) January 16, 2025 10:28
@flferretti flferretti merged commit 30892cd into ami-iit:main Jan 16, 2025
13 checks passed
4 participants