Skip to content

Commit

Permalink
Rerun black
Browse files Browse the repository at this point in the history
  • Loading branch information
AndReGeist committed Mar 7, 2024
1 parent 7f1cf0e commit 8ccd724
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 22 deletions.
17 changes: 12 additions & 5 deletions scripts/train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from hitchhiking_rotations import HITCHHIKING_ROOT_DIR
from hitchhiking_rotations.utils import save_pickle
from hitchhiking_rotations.cfgs import (get_cfg_pcd_to_pose, get_cfg_cube_image_to_pose, get_cfg_pose_to_cube_image,
get_cfg_pose_to_fourier)
from hitchhiking_rotations.cfgs import (
get_cfg_pcd_to_pose,
get_cfg_cube_image_to_pose,
get_cfg_pose_to_cube_image,
get_cfg_pose_to_fourier,
)

import numpy as np
import argparse
Expand All @@ -24,9 +28,12 @@
default="pose_to_cube_image",
help="Experiment Configuration",
)
parser.add_argument("--seed", type=int, default=0,
help="Random seed used during training, " +
"for pose_to_fourier the seed is used to select the target function.")
parser.add_argument(
"--seed",
type=int,
default=0,
help="Random seed used during training, " + "for pose_to_fourier the seed is used to select the target function.",
)
args = parser.parse_args()

s = args.seed
Expand Down
42 changes: 25 additions & 17 deletions visu/figure_14.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@
from hitchhiking_rotations.utils import RotRep

nb_max = 6
nb_values = range(1, nb_max+1)
files = [str(s) for nb in nb_values for s in Path(os.path.join(HITCHHIKING_ROOT_DIR, "results", f"pose_to_fourier_{nb}")).rglob("*result.npy")]

#files = [str(s) for s in Path(os.path.join(HITCHHIKING_ROOT_DIR, "results", "pose_to_fourier_1")).rglob("*result.npy")]
nb_values = range(1, nb_max + 1)
files = [
str(s)
for nb in nb_values
for s in Path(os.path.join(HITCHHIKING_ROOT_DIR, "results", f"pose_to_fourier_{nb}")).rglob("*result.npy")
]

# files = [str(s) for s in Path(os.path.join(HITCHHIKING_ROOT_DIR, "results", "pose_to_fourier_1")).rglob("*result.npy")]
results = [np.load(file, allow_pickle=True) for file in files]

# trainer_name
Expand Down Expand Up @@ -69,23 +73,27 @@
sns.set_style("whitegrid")
plt.rcParams.update({"font.size": 11})

plt.figure(figsize=(5.5,1.0))
g = sns.catplot(data=df, x="basis", y="score",
hue="method",
kind="box",
palette='Greens',
flierprops={"markeredgecolor": "grey"},
height=7.,
aspect=2.0)

sns.move_legend(g, "upper left", bbox_to_anchor=(.11, 0.98), ncol=3, title='Network input') # len(names)

for i in range(nb_max-1):
plt.figure(figsize=(5.5, 1.0))
g = sns.catplot(
data=df,
x="basis",
y="score",
hue="method",
kind="box",
palette="Greens",
flierprops={"markeredgecolor": "grey"},
height=7.0,
aspect=2.0,
)

sns.move_legend(g, "upper left", bbox_to_anchor=(0.11, 0.98), ncol=3, title="Network input") # len(names)

for i in range(nb_max - 1):
plt.axvline(0.5 + i, color="lightgrey", dashes=(2, 2))

plt.xlabel(f"Error - {selected_metric}")
plt.ylabel("")
plt.yscale('log')
plt.yscale("log")
plt.tight_layout()

out_p = os.path.join(HITCHHIKING_ROOT_DIR, "results", "figure_14.pdf")
Expand Down

0 comments on commit 8ccd724

Please sign in to comment.