Visualization Script for Intermediate Buffers in Attention Estimation of BERT (Fig. 10 BERT ver.) #10

Open · wants to merge 12 commits into base: main
2 changes: 1 addition & 1 deletion .gitignore
@@ -31,4 +31,4 @@ plots/
 submission/
 submission*.zip
 
-hello.png
\ No newline at end of file
+hello.png
2 changes: 1 addition & 1 deletion src/main/visualize/glue.py
@@ -82,4 +82,4 @@ def main(
         'evaluate': args.evaluate
     })
 
-    main(**kwargs)
\ No newline at end of file
+    main(**kwargs)
88 changes: 88 additions & 0 deletions src/main/visualize/glue_for_fig10.py
@@ -0,0 +1,88 @@
import torch
import cv2
import os
import argparse

from ...utils import batch_to
from ...models import perlin_bert
from ...trainer.perlin_trainer import GlueTrainer, add_perlin_model_options, parse_perlin_model_options

from .common import (
    gather_fixed_batch,
)

def main(
    subset = 'mnli',
    checkpoint_path = None,
    evaluate = False,
    **kwargs
):
    trainer = GlueTrainer(
        subset=subset,
        **kwargs
    )
    trainer.load(path=checkpoint_path)

    # Gather a fixed batch of 10 validation samples so the visualization is reproducible.
    batch = gather_fixed_batch(trainer.valid_loader, 10)
    batch = batch_to(batch, trainer.device)

    teacher = trainer.base_model
    bert = trainer.model

    teacher.eval()
    bert.eval()

    with torch.no_grad():
        teacher(**batch)
        # Pass the teacher through the batch so the student records the teacher's attention probs.
        batch['teacher'] = teacher
        bert(**batch)

    # Collect the intermediate attention buffers from every BertSelfAttention module (one per layer).
    attentions = []
    for module in bert.modules():
        if isinstance(module, perlin_bert.BertSelfAttention):
            teacher_attn = module.teacher_attention_prob
            estimated_attn = module.last_perlin_estimated_probs
            estimated_attn_m = module.perlin_output.estimated_attention_probs_m
            dense_attn = module.last_perlin_dense_probs
            partial_attn = module.last_perlin_partial_probs
            partial_attention_mask = module.perlin_output.partial_attention_mask
            partial_attention_mask_m = module.perlin_output.partial_attention_mask_m
            attentions.append({
                'teacher_attn': teacher_attn.cpu(),
                'estimated_attn': estimated_attn.cpu(),
                'dense_attn': dense_attn.cpu(),
                'partial_attn': partial_attn.cpu(),
                'estimated_attn_m': estimated_attn_m.cpu(),
                'partial_attention_mask': partial_attention_mask.cpu(),
                'partial_attention_mask_m': partial_attention_mask_m.cpu(),
            })

    # Dump the layer-1 buffers for Fig. 10 (BERT version).
    os.makedirs('./debug', exist_ok=True)  # ensure the output directory exists before saving
    torch.save({
        'teacher_attn': attentions[1]['teacher_attn'],
        'estimated_attn': attentions[1]['estimated_attn'],
        'estimated_attn_m': attentions[1]['estimated_attn_m'],
        'dense_attn': attentions[1]['dense_attn'],
        'partial_attn': attentions[1]['partial_attn'],
        'partial_attention_mask': attentions[1]['partial_attention_mask'],
        'partial_attention_mask_m': attentions[1]['partial_attention_mask_m'],
        'token_length': batch['attention_mask'][7].sum().item(),
    }, './debug/bert_viz.pth')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--subset', type=str, default='mnli')
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--evaluate', action='store_true')
    add_perlin_model_options(parser)

    args = parser.parse_args()
    print(args)

    kwargs = parse_perlin_model_options(args)
    kwargs.update({
        'subset': args.subset,
        'checkpoint_path': args.checkpoint,
        'evaluate': args.evaluate
    })

    main(**kwargs)
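Note for anyone reproducing the figure: given the relative imports, the script is presumably run as a module from the repository root, e.g. python -m src.main.visualize.glue_for_fig10 --subset mnli --checkpoint <path-to-checkpoint>. Below is a minimal sketch of how the saved buffers might then be rendered. It is not part of this PR; it assumes matplotlib is installed and that each saved buffer is shaped [batch, head, tokens, tokens], with sample index 7 being the one whose token_length the script saves.

# Hypothetical viewer for ./debug/bert_viz.pth -- not part of this PR.
# Assumes matplotlib and buffers shaped [batch, head, T, T];
# sample 7 / head 0 are picked for illustration only.
import torch
import matplotlib.pyplot as plt

data = torch.load('./debug/bert_viz.pth', map_location='cpu')
T = int(data['token_length'])  # number of non-padding tokens in sample 7

names = ['teacher_attn', 'estimated_attn', 'dense_attn', 'partial_attn']
fig, axes = plt.subplots(1, len(names), figsize=(4 * len(names), 4))
for ax, name in zip(axes, names):
    attn = data[name][7, 0, :T, :T]  # sample 7, head 0, cropped to real tokens
    ax.imshow(attn, cmap='viridis')
    ax.set_title(name)
    ax.axis('off')
fig.savefig('./debug/bert_viz_preview.png', bbox_inches='tight')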
2 changes: 1 addition & 1 deletion src/models/perlin_attention/config.py
@@ -58,4 +58,4 @@ def register_default_config(config: PerlinAttentionConfig):
 
 def get_default_config() -> PerlinAttentionConfig:
     global DEFAULT_CONFIG
-    return DEFAULT_CONFIG
\ No newline at end of file
+    return DEFAULT_CONFIG
2 changes: 1 addition & 1 deletion src/trainer/glue_trainer.py
@@ -431,4 +431,4 @@ def main(self):
     trainer = Trainer(
         subset='mnli'
     )
-    trainer.main()
\ No newline at end of file
+    trainer.main()