Skip to content

Commit

Permalink
fixed reshape bug
Browse files Browse the repository at this point in the history
  • Loading branch information
AndReGeist committed Apr 15, 2024
1 parent 001d165 commit 2212584
Showing 1 changed file with 56 additions and 18 deletions.
74 changes: 56 additions & 18 deletions visu/gso_vs_svd.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,70 @@
import os
import numpy as np
import pandas as pd
from scipy.spatial.transform import Rotation
import jax
import jax.numpy as jnp
from einops import rearrange
import matplotlib.pyplot as plt
import matplotlib.colors as matcolors
from matplotlib.ticker import FormatStrFormatter
import seaborn as sns


from hitchhiking_rotations import HITCHHIKING_ROOT_DIR
from mpl_toolkits.mplot3d import Axes3D
# from hitchhiking_rotations import HITCHHIKING_ROOT_DIR
# from mpl_toolkits.mplot3d import Axes3D

# import lovely_jax as lj
# lj.monkey_patch()

# Script configuration: the sample size and the switches read further down to
# decide which plots are produced. Each switch is assigned exactly once.
N = int(1e3)  # Number of randomly sampled rotations and predicted matrices
plot_frames = False  # Plot frames before/after SVD/GSO transform
plot_frames_with_grads = False  # presumably: same frames plus gradients -- confirm
# NOTE(review): the descriptions of plot_grads and plot_ratios look swapped
# relative to the flag names -- confirm against the plotting code.
plot_grads = False  # Plot ratio between gradients entries
plot_ratios = True  # Plot 2D scatter plot of gradients
plot_condnums = False  # Plot condition numbers of Hessian matrices and max eigen values


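# Helper functions converting between matrices with three rows and flat vectors
# (e.g. 6D vectors <-> 3x2 matrices, or 9D vectors <-> 3x3 matrices),
# where the matrix columns denote the representation vectors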
def mat2vec(mat, dimb=3):
    """Stack the columns of a ``3 x dimb`` matrix into one flat vector.

    The output has ``3 * dimb`` entries: the first three are column 0, the
    next three are column 1, and so on. This matches ``mat.T.reshape(-1)``.

    Args:
        mat: Array with 3 rows and ``dimb`` columns.
        dimb: Number of columns in ``mat`` (default 3).

    Returns:
        The 1-D array of stacked columns.
    """
    stack_columns = "a b -> (b a)"
    return rearrange(mat, stack_columns, a=3, b=dimb)


def vec2mat(vec, dimb=3):
    """Rebuild a ``3 x dimb`` matrix from a flat vector of stacked columns.

    Entries ``0..2`` of ``vec`` become column 0, entries ``3..5`` become
    column 1, and so on, which undoes the stacking performed by ``mat2vec``.

    Args:
        vec: 1-D array with ``3 * dimb`` entries.
        dimb: Number of columns of the resulting matrix (default 3).

    Returns:
        Array with 3 rows and ``dimb`` columns.
    """
    unstack_columns = "(b a) -> a b"
    return rearrange(vec, unstack_columns, a=3, b=dimb)


def bmat2bvec(mat, dimb=3):
    """Column-stack every matrix of a batch into a flat vector.

    The leading axis is the batch axis and is left untouched; each
    ``3 x dimb`` matrix is flattened column after column.

    Args:
        mat: Array of shape ``(batch, 3, dimb)``.
        dimb: Number of columns of each matrix (default 3).

    Returns:
        Array of shape ``(batch, 3 * dimb)``.
    """
    stack_columns = "i a b -> i (b a)"
    return rearrange(mat, stack_columns, a=3, b=dimb)


def bvec2bmat(vec, dimb=3):
    """Rebuild a batch of ``3 x dimb`` matrices from column-stacked vectors.

    The leading axis is the batch axis and is left untouched; within each
    vector, consecutive groups of three entries become consecutive columns.

    Args:
        vec: Array of shape ``(batch, 3 * dimb)``.
        dimb: Number of columns of each resulting matrix (default 3).

    Returns:
        Array of shape ``(batch, 3, dimb)``.
    """
    unstack_columns = "i (b a) -> i a b"
    return rearrange(vec, unstack_columns, a=3, b=dimb)


# Ground-truth data: N random rotations as a batch of 3x3 matrices, plus the
# same batch with every matrix flattened column after column by bmat2bvec.
# NOTE(review): Rotation.random is called without a seed here, whereas predmats
# below uses a fixed PRNGKey -- confirm whether run-to-run variation is intended.
rot = Rotation.random(N) # generate N random rotations
rotmats = jnp.array(rot.as_matrix())
rotmats_vec = bmat2bvec(rotmats)

# "Predicted" matrices: N unconstrained 3x3 matrices with entries drawn
# uniformly between -2 and 2. The commented-out lines are alternative
# sampling schemes (noisy copies of rotmats, or scaled Gaussian entries).
# predmats = rotmats + 1e-1 * jax.random.normal(key=jax.random.PRNGKey(42), shape=(N, 3, 3))
predmats = jax.random.uniform(key=jax.random.PRNGKey(42), shape=(N, 3, 3), minval=-2.0, maxval=2.0)
# predmats = 2*jax.random.normal(key=jax.random.PRNGKey(1), shape=(N, 3, 3))
# Column-stacked version of predmats, one flat vector per sample.
predmats_vec = bmat2bvec(predmats)

# Sanity check of the layout conventions on the first sample: columns 0 and 1
# of the matrix must appear as entries 0..2 and 3..5 of its vector, and the
# single and batched conversion helpers must round-trip consistently.
assert (
    jnp.allclose(predmats[0, :, 0], predmats_vec[0, :3])
    and jnp.allclose(predmats[0, :, 1], predmats_vec[0, 3:6])
    and jnp.allclose(predmats[0, 2, 1], predmats_vec[0, 5])
    and jnp.allclose(predmats[0, :, :], bvec2bmat(predmats_vec)[0])
    and jnp.allclose(predmats[0, :, :], vec2mat(predmats_vec[0]))
    and jnp.allclose(mat2vec(predmats[0, :, :]), predmats_vec[0])
), "Conversion functions are not working"


@jax.jit
Expand Down Expand Up @@ -89,15 +125,15 @@ def plot_matrix(
ax.text(mat[0][i], mat[1][i], mat[2][i], f"$e_{i + 1}$", color="black")


def plot_matrices(ax, r_list, labels, off_list=None):
def plot_matrices(ax, mat_list, labels, off_list=None):
ax.set(xlim=(-1.25, 1.25), ylim=(-1.25, 1.25), zlim=(-1.25, 1.25))
colors = ["r", "g", "b", "y", "m", "c"]

if off_list is None:
off_list = [jnp.zeros((3, 3)) for _ in range(len(r_list))]
off_list = [jnp.zeros((3, 3)) for _ in range(len(mat_list))]

for i in range(len(r_list)):
plot_matrix(ax, r_list[i], colors[i], label=labels[i], offset=off_list[i])
for i in range(len(mat_list)):
plot_matrix(ax, mat_list[i], colors[i], label=labels[i], offset=off_list[i])

ax.set_xlabel("X")
ax.set_ylabel("Y")
Expand Down Expand Up @@ -129,11 +165,11 @@ def norm(mat1: jnp.ndarray, mat2: jnp.ndarray) -> jnp.ndarray:


def norm_gso(predmat_vec, rotmat):
    """Loss between ``rotmat`` and the ``gso`` projection of a 6D prediction.

    ``predmat_vec`` comes first so that ``jax.grad`` differentiates with
    respect to the prediction by default.

    Args:
        predmat_vec: Flat vector of six entries holding the two representation
            vectors stacked column after column.
        rotmat: Target rotation matrix handed to ``norm``.

    Returns:
        The value of ``norm(rotmat, gso(...))``.
    """
    # Unpack with vec2mat: a plain reshape(3, 2) would fill the matrix row by
    # row and scramble the two column-stacked representation vectors.
    return norm(rotmat, gso(vec2mat(predmat_vec, dimb=2)))


def norm_svd(predmat_vec, rotmat):
    """Loss between ``rotmat`` and the ``svd`` projection of a 9D prediction.

    ``predmat_vec`` comes first so that ``jax.grad`` differentiates with
    respect to the prediction by default.

    Args:
        predmat_vec: Flat vector of nine entries holding the three columns of
            the predicted matrix stacked one after another.
        rotmat: Target rotation matrix handed to ``norm``.

    Returns:
        The value of ``norm(rotmat, svd(...))``.
    """
    # Unpack with vec2mat: a plain reshape(3, 3) of a column-stacked vector
    # would yield the transpose of the intended matrix.
    return norm(rotmat, svd(vec2mat(predmat_vec)))


def hess_gso(rotmat: jnp.ndarray, predmat_vec: jnp.ndarray) -> jnp.ndarray:
Expand All @@ -151,8 +187,8 @@ def hess_svd(rotmat: jnp.ndarray, predmat_vec: jnp.ndarray) -> jnp.ndarray:
# Per-sample loss of each projected prediction against its target rotation.
loss_svd = jax.vmap(norm, (0, 0))(rotmats, pred_svd)
loss_gso = jax.vmap(norm, (0, 0))(rotmats, pred_gso)

# Per-sample gradients of the losses with respect to the column-stacked
# predictions. The GSO loss only consumes the first two columns, i.e. the
# first six entries of each vector. Each gradient is computed once, from the
# column-stacked vectors (a row-major reshape of predmats has a different
# entry order than the one vec2mat expects).
grads_gso = jax.vmap(jax.grad(norm_gso), (0, 0))(predmats_vec[:, :6], rotmats)
grads_svd = jax.vmap(jax.grad(norm_svd), (0, 0))(predmats_vec, rotmats)
gradnorm_gso = jnp.linalg.norm(grads_gso, axis=-1)
gradnorm_svd = jnp.linalg.norm(grads_svd, axis=-1)

Expand All @@ -174,12 +210,12 @@ def hess_svd(rotmat: jnp.ndarray, predmat_vec: jnp.ndarray) -> jnp.ndarray:
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d", proj_type="ortho")
g_gso = jnp.c_[
g_gso.reshape(3, 2),
vec2mat(g_gso, dimb=2),
np.zeros(
3,
),
]
r_list = [r0, r_svd, r_gso, -1 * g_svd.reshape(3, 3), -1 * g_gso]
r_list = [r0, r_svd, r_gso, -1 * vec2mat(g_svd), -1 * g_gso]
off_list = [jnp.zeros((3, 3)), jnp.zeros((3, 3)), jnp.zeros((3, 3))] + [r_svd, r_gso]

plot_matrices(ax, r_list, ["rotmat", "svd", "gso", "grad_svd", "grad_gso"], off_list)
Expand All @@ -189,8 +225,8 @@ def hess_svd(rotmat: jnp.ndarray, predmat_vec: jnp.ndarray) -> jnp.ndarray:
# COMPUTE HESSIANS
###############################################################################

# Per-sample Hessian matrices from hess_gso / hess_svd, evaluated on the
# column-stacked predictions (first six entries only for GSO). Each batch is
# computed once, from the column-stacked vectors.
hessmats_gso = jax.vmap(hess_gso, (0, 0))(rotmats, predmats_vec[:, :6])
hessmats_svd = jax.vmap(hess_svd, (0, 0))(rotmats, predmats_vec)

# Eigenvalues of every matrix, sorted in ascending order along the last axis,
# so index 0 is the smallest and index -1 the largest.
# NOTE(review): jnp.linalg.eig returns complex eigenvalues; if these Hessians
# are guaranteed symmetric, jnp.linalg.eigvalsh may be the better fit --
# confirm downstream dtype expectations before switching.
eig_gso = jnp.sort(jax.vmap(jnp.linalg.eig)(hessmats_gso)[0], axis=-1)
eig_svd = jnp.sort(jax.vmap(jnp.linalg.eig)(hessmats_svd)[0], axis=-1)
Expand All @@ -213,6 +249,7 @@ def hess_svd(rotmat: jnp.ndarray, predmat_vec: jnp.ndarray) -> jnp.ndarray:
"condnums": np.r_[condnums_svd, condnums_gso].flatten(),
"eigmin": np.r_[eig_svd[:, 0], eig_gso[:, 0]].flatten(),
"eigmax": np.r_[eig_svd[:, -1], eig_gso[:, -1]].flatten(),
# Note: eig_svd[:, -1] is the max eigenvalue as the eigenvalues are sorted
}
)

Expand Down Expand Up @@ -315,11 +352,12 @@ def plot_2D_paper():
# cbar=True,
clip=((None, None), (None, None)),
antialiased=True,
# thresh=0.1,
)
for c in cnt.collections:
c.set_edgecolor("face")

axs[i].set_ylim(0.08, 20)
# axs[i].set_ylim(0.08, 20)
axs[i].set_xlabel("L2 loss")
axs[i].set_ylabel(None)
axs[i].set_title(datalabels[i])
Expand Down

0 comments on commit 2212584

Please sign in to comment.