Skip to content

Commit

Permalink
Replaced fourier data egenrator with initial JAX implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
AndReGeist committed Mar 26, 2024
1 parent f1a669c commit 29ff77e
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 35 deletions.
10 changes: 5 additions & 5 deletions hitchhiking_rotations/cfgs/cfg_pose_to_fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def get_cfg_pose_to_fourier(device, nb, nf):
return {
"verbose": False,
"batch_size": 64,
"epochs": 400,
"epochs": 250,
"training_data": {
"_target_": "hitchhiking_rotations.datasets.PoseToFourierDataset",
"mode": "train",
Expand All @@ -46,10 +46,10 @@ def get_cfg_pose_to_fourier(device, nb, nf):
"nb": nb,
"nf": nf,
},
"model9": {"_target_": "hitchhiking_rotations.models.MLP", "input_dim": 9, "output_dim": 1},
"model6": {"_target_": "hitchhiking_rotations.models.MLP", "input_dim": 6, "output_dim": 1},
"model4": {"_target_": "hitchhiking_rotations.models.MLP", "input_dim": 4, "output_dim": 1},
"model3": {"_target_": "hitchhiking_rotations.models.MLP", "input_dim": 3, "output_dim": 1},
"model9": {"_target_": "hitchhiking_rotations.models.MLP2", "input_dim": 9, "output_dim": 1},
"model6": {"_target_": "hitchhiking_rotations.models.MLP2", "input_dim": 6, "output_dim": 1},
"model4": {"_target_": "hitchhiking_rotations.models.MLP2", "input_dim": 4, "output_dim": 1},
"model3": {"_target_": "hitchhiking_rotations.models.MLP2", "input_dim": 3, "output_dim": 1},
"logger": {
"_target_": "hitchhiking_rotations.utils.OrientationLogger",
"metrics": ["l2"],
Expand Down
62 changes: 32 additions & 30 deletions hitchhiking_rotations/datasets/fourier_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import roma
import pandas as pd
import seaborn as sns
import jax
import jax.numpy as jnp
import equinox as eqx

jax.config.update("jax_default_device", jax.devices("cpu")[0])

from hitchhiking_rotations import HITCHHIKING_ROOT_DIR
from hitchhiking_rotations.utils import save_pickle, load_pickle
Expand Down Expand Up @@ -44,26 +47,25 @@ def __getitem__(self, idx):
return roma.unitquat_to_rotmat(self.quats[idx]).type(torch.float32), self.features[idx]


class random_fourier_function:
def __init__(self, n_basis, seed, L=np.pi):
np.random.seed(seed)
self.n_basis = n_basis
self.L = L
self.A = np.random.normal(size=n_basis)
self.B = np.random.normal(size=n_basis)
self.matrix = np.random.normal(size=(1, 9))
def random_fourier_function(x, nb, seed):
key = jax.random.PRNGKey(seed)
key1, key2, key3 = jax.random.split(key, 3)
model = eqx.nn.MLP(in_size=9, out_size=1, width_size=50, depth=1, key=key1)
A = jax.random.normal(key=key2, shape=(nb,))
B = jax.random.normal(key=key3, shape=(nb,))

input = model(x)
fFs = 0.0
for k in range(len(A)):
fFs += A[k] * jnp.cos((k + 1) * jnp.pi * input) + B[k] * jnp.sin((k + 1) * jnp.pi * input)
return fFs

def __call__(self, x):
fFs = 0.0
for k in range(len(self.A)):
input = np.matmul(self.matrix, x)
fFs += self.A[k] * np.cos((k + 1) * np.pi * input / self.L) + self.B[k] * np.sin(
(k + 1) * np.pi * input / self.L
)
return fFs

def input_to_fourier(self, x):
return np.matmul(self.matrix, x)
def input_to_fourier(x, seed):
key = jax.random.PRNGKey(seed)
key1, key2, key3 = jax.random.split(key, 3)
model = eqx.nn.MLP(in_size=9, out_size=1, width_size=50, depth=1, key=key1)
return model(x)


def batch_normalize(arr):
Expand All @@ -79,17 +81,15 @@ def create_data(N_points, nb, seed):
Args:
N_points: Number of random rotations to generate
nb: Number of fourier basis that form the target function
seed: Used to randomly initialize fourier function coefficients
seed: Used to randomly initialize fourier function
Returns:
rots: Random rotations
features: Target function evaluated at rots
"""
np.random.seed(seed)
rots = Rotation.random(N_points)
inputs = rots.as_matrix().reshape(N_points, -1)
four_func = random_fourier_function(nb, seed)
features = np.apply_along_axis(four_func, 1, inputs)

features = np.array(jax.vmap(random_fourier_function, in_axes=[0, None, None])(inputs, nb, seed).reshape(-1, 1))
features = batch_normalize(features)
return rots.as_quat().astype(np.float32), features.astype(np.float32)

Expand All @@ -116,9 +116,8 @@ def plot_fourier_func(nb, seed):
"""Plot the target function."""
rots = Rotation.random(400)
inputs = rots.as_matrix().reshape(400, -1)
four_func = random_fourier_function(nb, seed)
four_in = np.apply_along_axis(four_func.input_to_fourier, 1, inputs)
features = np.apply_along_axis(four_func, 1, inputs)
four_in = np.array(jax.vmap(input_to_fourier, [0, None])(inputs, seed))
features = np.array(jax.vmap(random_fourier_function, [0, None, None])(inputs, nb, seed))
features2 = batch_normalize(features)
sorted_indices = np.argsort(four_in, axis=0)

Expand All @@ -127,15 +126,18 @@ def plot_fourier_func(nb, seed):
plt.plot(
four_in[sorted_indices].flatten(), features2[sorted_indices].flatten(), linestyle="-", color="red", marker=None
)
plt.title(f"nb: {nb}, seed: {seed},\n matrix: {four_func.matrix}")
plt.title(f"nb: {nb}, seed: {seed}")
plt.show()


if __name__ == "__main__":
# Analyze created data
for b in range(1, 6):
for s in range(0, 3):
for b in range(1, 7):
for s in range(0, 6):
# rots, features = create_data(N_points=100, nb=b, seed=s)
# data_stats(rots, features)
# plot_fourier_data(rots, features)
print("MLP PyTree used to create Fourier function inputs:")
model = eqx.nn.MLP(in_size=9, out_size=1, width_size=50, depth=1, key=jax.random.PRNGKey(42))
eqx.tree_pprint(model)
plot_fourier_func(b, s)

0 comments on commit 29ff77e

Please sign in to comment.