# patch_an_edge.py
#%%
import torch as t
from auto_circuit.data import load_datasets_from_json
from auto_circuit.experiment_utils import load_tl_model
from auto_circuit.prune_algos.mask_gradient import mask_gradient_prune_scores
from auto_circuit.types import AblationType, PruneScores
from auto_circuit.utils.ablation_activations import src_ablations
from auto_circuit.utils.graph_utils import patch_mode, patchable_model
from auto_circuit.utils.misc import repo_path_to_abs_path
# Select the compute device: CUDA when torch reports it as available,
# otherwise the CPU.
device = t.device("cpu" if not t.cuda.is_available() else "cuda")

# Resolve the repo-relative IOI prompt file to an absolute path.
path = repo_path_to_abs_path("datasets/ioi/ioi_vanilla_template_prompts.json")

# Load GPT-2 onto the selected device via the project's loader.
model = load_tl_model("gpt2", device)

# Build the train and test dataloaders from the prompt file.
# `return_seq_length=True` and `tail_divergence=True` are presumably what make
# `seq_len` and `kv_cache` available on the loaders (both are read further down
# in this script) -- TODO confirm against load_datasets_from_json.
train_loader, test_loader = load_datasets_from_json(
    model=model,
    path=path,
    device=device,
    prepend_bos=True,
    batch_size=16,
    train_test_size=(128, 128),
    return_seq_length=True,
    tail_divergence=True,
)
# Rebind `model` to its patchable wrapper so that individual edges can be
# patched later on. The sequence length comes from the test loader and the
# KV caches from both loaders; output is sliced with the "last_seq" setting.
model = patchable_model(
    model,
    device=device,
    factorized=True,
    separate_qkv=True,
    slice_output="last_seq",
    seq_len=test_loader.seq_len,
    kv_caches=(train_loader.kv_cache, test_loader.kv_cache),
)
# Activations used as the replacement values on patched edges, computed from
# the test set with the TOKENWISE_MEAN_CORRUPT ablation type.
ablations = src_ablations(model, test_loader, AblationType.TOKENWISE_MEAN_CORRUPT)

# Names of the edges to patch, written as "<source>-><destination>".
patch_edges = [
    "Resid Start->MLP 2",
    "MLP 2->A2.4.Q",
    "A2.4->Resid End",
]

# Run the clean test prompts through the model while the listed edges are
# patched. `patch_mode` is used as a context manager, so the patching
# presumably applies only inside this `with` block -- confirm in graph_utils.
with patch_mode(model, ablations, patch_edges):
    for batch in test_loader:
        # NOTE: `patched_out` is overwritten on every iteration, so only the
        # last batch's output is still bound once the loop finishes.
        patched_out = model(batch.clean)
# Score edges over the test set with the mask-gradient method, using the
# "logit" gradient function and the "avg_diff" answer function. No official
# edge set is supplied.
# NOTE: the variable name is misspelled ("attrution" for "attribution"); it is
# kept as-is because later cells may refer to it by this name.
# NOTE(review): some auto_circuit versions appear to require exactly one of
# `mask_val` / `integrated_grad_steps` here -- confirm against the installed
# version before running.
attrution_patching_scores: PruneScores = mask_gradient_prune_scores(
    model=model,
    dataloader=test_loader,
    grad_function="logit",
    answer_function="avg_diff",
    official_edges=None,
)