Skip to content

Commit

Permalink
Improve graphing utils
Browse files Browse the repository at this point in the history
  • Loading branch information
adamkarvonen committed Jan 17, 2025
1 parent 8c0df93 commit 661920d
Showing 1 changed file with 33 additions and 8 deletions.
41 changes: 33 additions & 8 deletions sae_bench/sae_bench_utils/graphing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"jumprelu": "X",
"topk": "^",
"batch_topk": "s",
"p_anneal": "*",
"p_anneal": "P",
"matryoshka_batch_topk": "*",
"gated": "d",
}
Expand Down Expand Up @@ -459,6 +459,13 @@ def update_trainer_markers_and_colors(
all_trainers = {v["sae_class"] for v in results.values()}
new_trainers = all_trainers - existing_trainers

for trainer in all_trainers:
if trainer in trainer_markers:
if trainer_markers[trainer] in available_markers:
available_markers.remove(trainer_markers[trainer])
if trainer_colors[trainer] in available_colors:
available_colors.remove(trainer_colors[trainer])

for trainer in new_trainers:
trainer_markers[trainer] = available_markers.pop(0)
trainer_colors[trainer] = available_colors.pop(0)
Expand Down Expand Up @@ -645,6 +652,7 @@ def plot_2var_graph(
trainer_markers: Optional[dict[str, str]] = None,
trainer_colors: Optional[dict[str, str]] = None,
return_fig: bool = False,
connect_points: bool = False, # New parameter to control line connections
):
if not trainer_markers:
trainer_markers = TRAINER_MARKERS
Expand All @@ -656,12 +664,8 @@ def plot_2var_graph(
results, trainer_markers, trainer_colors
)

# Extract data from results
l0_values = [data[x_axis_key] for data in results.values()]
custom_metric_values = [data[custom_metric] for data in results.values()]

# Create the scatter plot
fig, ax = plt.subplots(figsize=(10, 6))
# Create the scatter plot with extra width for legend
fig, ax = plt.subplots(figsize=(12, 6))

handles, labels = [], []

Expand All @@ -675,6 +679,22 @@ def plot_2var_graph(
l0_values = [data[x_axis_key] for data in trainer_data.values()]
custom_metric_values = [data[custom_metric] for data in trainer_data.values()]

# Sort points by l0 values for proper line connection
if connect_points and len(l0_values) > 1:
points = sorted(zip(l0_values, custom_metric_values))
l0_values = [p[0] for p in points]
custom_metric_values = [p[1] for p in points]

# Add connecting line
ax.plot(
l0_values,
custom_metric_values,
color=trainer_colors[trainer],
linestyle="-",
alpha=0.5,
zorder=1, # Ensure lines are plotted behind points
)

# Plot data points
scatter = ax.scatter(
l0_values,
Expand All @@ -684,6 +704,7 @@ def plot_2var_graph(
label=trainer,
color=trainer_colors[trainer],
edgecolor="black",
zorder=2, # Ensure points are plotted on top of lines
)

# Create custom legend handle with both marker and color
Expand All @@ -703,12 +724,16 @@ def plot_2var_graph(
ax.set_ylabel(y_label)
ax.set_title(title)

# x log
ax.set_xscale("log")

if baseline_value is not None:
ax.axhline(baseline_value, color="red", linestyle="--", label=baseline_label)
labels.append(baseline_label)
handles.append(Line2D([0], [0], color="red", linestyle="--", label=baseline_label))

ax.legend(handles, labels, loc=legend_location)
# Place legend outside the plot on the right
ax.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 0.5))

# Set axis limits
if xlims:
Expand Down

0 comments on commit 661920d

Please sign in to comment.