-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathplot_hyperparameter_sweep.py
executable file
·149 lines (122 loc) · 6.2 KB
/
plot_hyperparameter_sweep.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#!/usr/bin/env python3
"""Code to generate plots for Extended Data Fig. 1."""
import os
import matplotlib
import matplotlib.pyplot as plt
import echonet
def main(root=os.path.join("output", "video"),
fig_root=os.path.join("figure", "hyperparameter"),
FRAMES=(1, 8, 16, 32, 64, 96, None),
PERIOD=(1, 2, 4, 6, 8)
):
"""Generate plots for Extended Data Fig. 1."""
echonet.utils.latexify()
os.makedirs(fig_root, exist_ok=True)
# Parameters for plotting length sweep
MAX = FRAMES[-2]
START = 1 # Starting point for normal range
TERM0 = 104 # Ending point for normal range
BREAK = 112 # Location for break
TERM1 = 120 # Starting point for "all" section
ALL = 128 # Location of "all" point
END = 135 # Ending point for "all" section
RATIO = (BREAK - START) / (END - BREAK)
# Set up figure
fig = plt.figure(figsize=(3 + 2.5 + 1.5, 2.75))
outer = matplotlib.gridspec.GridSpec(1, 3, width_ratios=[3, 2.5, 1.50])
ax = plt.subplot(outer[2]) # Legend
ax2 = plt.subplot(outer[1]) # Period plot
gs = matplotlib.gridspec.GridSpecFromSubplotSpec(
1, 2, subplot_spec=outer[0], width_ratios=[RATIO, 1], wspace=0.020) # Length plot
# Plot legend
for (model, color) in zip(["EchoNet-Dynamic (EF)", "R3D", "MC3"],
matplotlib.colors.TABLEAU_COLORS):
ax.plot([float("nan")], [float("nan")], "-", color=color, label=model)
ax.plot([float("nan")], [float("nan")], "-", color="k", label="Pretrained")
ax.plot([float("nan")], [float("nan")], "--", color="k", label="Random")
ax.set_title("")
ax.axis("off")
ax.legend(loc="center")
# Plot length sweep (panel a)
ax0 = plt.subplot(gs[0])
ax1 = plt.subplot(gs[1], sharey=ax0)
print("FRAMES")
for (model, color) in zip(["r2plus1d_18", "r3d_18", "mc3_18"],
matplotlib.colors.TABLEAU_COLORS):
for pretrained in [True, False]:
loss = [load(root, model, frames, 1, pretrained) for frames in FRAMES]
print(model, pretrained)
print(" ".join(list(map(lambda x: "{:.1f}".format(x) if x is not None else None, loss))))
l0 = loss[-2]
l1 = loss[-1]
ax0.plot(FRAMES[:-1] + (TERM0,),
loss[:-1] + [l0 + (l1 - l0) * (TERM0 - MAX) / (ALL - MAX)],
"-" if pretrained else "--", color=color)
ax1.plot([TERM1, ALL],
[l0 + (l1 - l0) * (TERM1 - MAX) / (ALL - MAX)] + [loss[-1]],
"-" if pretrained else "--", color=color)
ax0.scatter(list(map(lambda x: x if x is not None else ALL, FRAMES)), loss, color=color, s=4)
ax1.scatter(list(map(lambda x: x if x is not None else ALL, FRAMES)), loss, color=color, s=4)
ax0.set_xticks(list(map(lambda x: x if x is not None else ALL, FRAMES)))
ax1.set_xticks(list(map(lambda x: x if x is not None else ALL, FRAMES)))
ax0.set_xticklabels(list(map(lambda x: x if x is not None else "All", FRAMES)))
ax1.set_xticklabels(list(map(lambda x: x if x is not None else "All", FRAMES)))
# https://stackoverflow.com/questions/5656798/python-matplotlib-is-there-a-way-to-make-a-discontinuous-axis/43684155
# zoom-in / limit the view to different portions of the data
ax0.set_xlim(START, BREAK) # most of the data
ax1.set_xlim(BREAK, END)
# hide the spines between ax and ax2
ax0.spines['right'].set_visible(False)
ax1.spines['left'].set_visible(False)
ax1.get_yaxis().set_visible(False)
d = 0.015 # how big to make the diagonal lines in axes coordinates
# arguments to pass plot, just so we don't keep repeating them
kwargs = dict(transform=ax0.transAxes, color='k', clip_on=False, linewidth=1)
x0, x1, y0, y1 = ax0.axis()
scale = (y1 - y0) / (x1 - x0) / 2
ax0.plot((1 - scale * d, 1 + scale * d), (-d, +d), **kwargs) # top-left diagonal
ax0.plot((1 - scale * d, 1 + scale * d), (1 - d, 1 + d), **kwargs) # bottom-left diagonal
kwargs.update(transform=ax1.transAxes) # switch to the bottom 1xes
x0, x1, y0, y1 = ax1.axis()
scale = (y1 - y0) / (x1 - x0) / 2
ax1.plot((-scale * d, scale * d), (-d, +d), **kwargs) # top-right diagonal
ax1.plot((-scale * d, scale * d), (1 - d, 1 + d), **kwargs) # bottom-right diagonal
# ax0.xaxis.label.set_transform(matplotlib.transforms.blended_transform_factory(
# matplotlib.transforms.IdentityTransform(), fig.transFigure # specify x, y transform
# )) # changed from default blend (IdentityTransform(), a[0].transAxes)
ax0.xaxis.label.set_position((0.6, 0.0))
ax0.text(-0.05, 1.10, "(a)", transform=ax0.transAxes)
ax0.set_xlabel("Clip length (frames)")
ax0.set_ylabel("Validation Loss")
# Plot period sweep (panel b)
print("PERIOD")
for (model, color) in zip(["r2plus1d_18", "r3d_18", "mc3_18"], matplotlib.colors.TABLEAU_COLORS):
for pretrained in [True, False]:
loss = [load(root, model, 64 // period, period, pretrained) for period in PERIOD]
print(model, pretrained)
print(" ".join(list(map(lambda x: "{:.1f}".format(x) if x is not None else None, loss))))
ax2.plot(PERIOD, loss, "-" if pretrained else "--", marker=".", color=color)
ax2.set_xticks(PERIOD)
ax2.text(-0.05, 1.10, "(b)", transform=ax2.transAxes)
ax2.set_xlabel("Sampling Period (frames)")
ax2.set_ylabel("Validation Loss")
# Save figure
plt.tight_layout()
plt.savefig(os.path.join(fig_root, "hyperparameter.pdf"))
plt.savefig(os.path.join(fig_root, "hyperparameter.eps"))
plt.savefig(os.path.join(fig_root, "hyperparameter.png"))
plt.close(fig)
def load(root, model, frames, period, pretrained):
"""Loads best validation loss for specified hyperparameter choice."""
pretrained = ("pretrained" if pretrained else "random")
f = os.path.join(
root,
"{}_{}_{}_{}".format(model, frames, period, pretrained),
"log.csv")
with open(f, "r") as f:
for line in f:
if "Best validation loss " in line:
return float(line.split()[3])
raise ValueError("File missing information.")
if __name__ == "__main__":
main()