Skip to content

Commit

Permalink
Tidy plots
Browse files Browse the repository at this point in the history
  • Loading branch information
thecharlieblake committed Aug 13, 2023
1 parent b30cdca commit e139288
Showing 1 changed file with 31 additions and 3 deletions.
34 changes: 31 additions & 3 deletions unit_scaling/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import colorsys
import logging
import re
from math import isnan
from typing import TYPE_CHECKING, Any, List, Optional, Tuple

Expand Down Expand Up @@ -163,11 +164,12 @@ def plot(
prune_same_scale: bool = True,
show_arrows: bool = True,
show_error_bars: bool = True,
show_zero_tensors: bool = False,
xmin: Optional[float] = None,
xmax: Optional[float] = None,
) -> matplotlib.axes.Axes:
"""Generate a :mod:`matplotlib` plot visualising the scales in the forward (and
optionally backward) pass of all tensors in an arbitrary FX graph.
optionally backward) pass of all tensors in an FX graph.
The input graph is intended to have been generated by applying
:func:`unit_scaling.transforms.track_scales` to an arbitrary
Expand Down Expand Up @@ -255,10 +257,33 @@ def plot(
p.xaxis.set_ticks_position("top")
p.xaxis.set_label_position("top")
p.xaxis.grid(False)
p.legend(loc="upper right").set_title("")
if title:
p.set_title(title, fontweight="bold")

label_map = {
"fwd": "forward pass",
"bwd": "backward pass",
"False": "non-weight tensor",
"True": "weight tensor",
}
new_legend_labels = {
label_map[l]: h
for h, l in zip(*p.get_legend_handles_labels())
if l in label_map
}
p.legend(
new_legend_labels.values(), new_legend_labels.keys(), loc="upper right"
).set_title("")

def _rename(s: str) -> str:
s = re.sub(r"(^|_)\d+", "", s)
s = s.replace("self_", "")
s = s.replace("transformer_h_", "")
s = s.replace("transformer_", "")
return s

p.set_yticklabels([_rename(item.get_text()) for item in p.get_yticklabels()])

plt.axvline(2**-14, color="grey", dashes=(3, 1))
plt.axvline(2**-7, color="grey", dashes=(1, 3))
plt.axvline(240, color="grey", dashes=(1, 3))
Expand Down Expand Up @@ -387,6 +412,9 @@ def draw_arrow(node_a: Node, node_b: Node, direction: str) -> None:
color = colors[1]
a_x, a_y, b_x, b_y = b_x, b_y, a_x, a_y

if annotation == "0" and not show_zero_tensors:
return

plt.annotate(
annotation,
color=color,
Expand Down Expand Up @@ -480,7 +508,7 @@ def visualiser(
inputs, attn_mask, labels = example_batch(
tokenizer, batch_size, seq_len, dataset_path, dataset_name
)
tracked_model = track_scales(model)
tracked_model = track_scales(model.to("cpu"))
_, loss = tracked_model(inputs, labels)
if backward:
loss.backward()
Expand Down

0 comments on commit e139288

Please sign in to comment.