-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_sum.py
137 lines (114 loc) · 4.85 KB
/
train_sum.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
Train a diffusion model on summarization data.
"""
import argparse
import json, torch, os
from tqdm import tqdm
from imp_diff.improved_diffusion.summarization_datasets_exp import load_data_summarization
from transformers import AutoTokenizer
from transformers import set_seed
from functools import partial
import torch.distributed as dist
import wandb
import numpy as np
from imp_diff.improved_diffusion import dist_util
from imp_diff.improved_diffusion.image_datasets import load_data
from imp_diff.improved_diffusion.resample import create_named_schedule_sampler
from imp_diff.improved_diffusion.script_util import create_model_and_diffusion, args_to_dict, \
model_and_diffusion_defaults, add_dict_to_argparser
from imp_diff.improved_diffusion.test_util import get_weights, compute_logp
from imp_diff.improved_diffusion.train_util import TrainLoop
from imp_diff.improved_diffusion import logger
def main():
    """Run one training job for the diffusion model on summarization data.

    Steps, in order:
      1. Parse the command-line arguments produced by ``create_argparser``
         and seed the run with ``args.seed``.
      2. Set up distributed state and the logger.
      3. Build the model/diffusion pair and move the model to the device
         returned by ``dist_util.dev()``.
      4. Write the parsed arguments to
         ``<checkpoint_path>/training_args.json`` (the directory is created
         if it does not exist yet) and start a wandb run.
      5. Load the ``train`` split with ``load_data_summarization`` and hand
         everything to ``TrainLoop(...).run_loop()``.

    Raises:
        OSError: if the checkpoint directory cannot be created or the
            hyperparameter file cannot be written.
    """
    args = create_argparser().parse_args()
    set_seed(args.seed)
    dist_util.setup_dist()  # DEBUG **
    logger.configure()

    print("Micro Batch: {}".format(args.microbatch))
    print("Dataset: {}".format(args.roc_train))

    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    model.to(dist_util.dev())

    pytorch_total_params = sum(p.numel() for p in model.parameters())
    logger.log(f'the parameter count is {pytorch_total_params}')

    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)

    logger.log(f'saving the hyperparameters to {args.checkpoint_path}/training_args.json')
    # open(..., 'w') does not create missing parent directories, so a fresh
    # run would fail with FileNotFoundError. exist_ok=True keeps this safe
    # when the directory is already there (e.g. several processes or a
    # resumed run). An empty path is skipped because os.makedirs('') raises.
    if args.checkpoint_path:
        os.makedirs(args.checkpoint_path, exist_ok=True)
    with open(f'{args.checkpoint_path}/training_args.json', 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    wandb.init(
        project=os.getenv("WANDB_PROJECT", "diffusion_lm"),
        name=args.checkpoint_path,
    )
    wandb.config.update(args.__dict__, allow_val_change=True)

    logger.log("creating data loader...")
    print('load data', '*'*50)
    # NOTE(review): the second positional argument (8) is a hard-coded
    # constant; its meaning is defined by load_data_summarization -- confirm
    # there before changing it.
    data = load_data_summarization(
        args.batch_size,
        8,
        args.roc_train,
        split=['train'],
        sent_encoder_type='sbert',
        summary_type='oracle',
        summary_level='sen',
    )

    logger.log("training...")
    TrainLoop(
        model=model,
        diffusion=diffusion,
        data=data,
        batch_size=args.batch_size,
        # NOTE(review): args.microbatch is printed above but not forwarded;
        # the loop always receives 0. Left unchanged -- confirm intent.
        microbatch=0,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
        checkpoint_path=args.checkpoint_path,
        gradient_clipping=args.gradient_clipping,
        eval_data=None,  # no evaluation data is passed from this script
        eval_interval=args.eval_interval,
        epochs=args.epochs,
        tokenizer=None,
    ).run_loop()
def create_argparser():
    """Build the ``argparse`` parser used by this training script.

    Three groups of default values are merged, in this order:
    generic training options, the dict returned by
    ``model_and_diffusion_defaults()``, and text/dataset options. On a key
    clash the later group's value wins. The merged dict is passed to
    ``add_dict_to_argparser`` to populate the parser.

    Returns:
        argparse.ArgumentParser: the populated parser.
    """
    training_options = {
        "data_dir": "",
        "schedule_sampler": "uniform",
        "lr": 1e-4,
        "weight_decay": 0.0,
        "lr_anneal_steps": 0,
        "batch_size": 1,
        "microbatch": -1,  # -1 disables microbatches
        "ema_rate": "0.9999",  # comma-separated list of EMA values
        "log_interval": 50,
        "save_interval": 50000,
        "resume_checkpoint": "",
        "use_fp16": False,
        "fp16_scale_growth": 1e-3,
        "seed": 101,
        "gradient_clipping": -1.0,
        "eval_interval": 2000,
        "checkpoint_path": "diff_models",
        "epochs": 5,
    }

    text_options = {
        "modality": "text",
        "dataset_name": "wikitext",
        "dataset_config_name": "wikitext-2-raw-v1",
        "config": "diffusion_lm/synthetic_data/configs/emnlp2020/experiments/difflm_seed0_m3_k128_trainc20000.yaml",
        "model_name_or_path": "predictability/diff_models/compress_e=5_b=60_m=gpt2_wikitext-103-raw-v1_None",
        "experiment": "gpt2_pre_compress",
        "model_arch": "conv-unet",
        # alternative value kept from the original source:
        # 'diffusion_lm/ROCstory/ROCstory17.csv'
        "roc_train": "diffusion_lm/ROCstory",
        "wiki_train": "diffusion_lm/simple_wiki/data.v1.split/simple.training.txt",
        "e2e_train": "e2e_data",
        "yelp_train": "diffusion_lm/yelpnlg-resources/yelpnlg-corpus",
        "commonGen_train": "diffusion_lm/common-gen/commongen_data",
        "emb_scale_factor": 1.0,
        "noise_level": 0.0,
        "cache_mode": "no",
        "use_bert_tokenizer": "no",
        "padding_mode": "block",
        "preprocessing_num_workers": 1,
    }

    # Unpacking left to right matches successive dict.update() calls: a key
    # keeps the position of its first appearance and takes the last value.
    merged_defaults = {
        **training_options,
        **model_and_diffusion_defaults(),
        **text_options,
    }

    cli_parser = argparse.ArgumentParser()
    add_dict_to_argparser(cli_parser, merged_defaults)
    return cli_parser
# Start training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()