forked from lucidrains/denoising-diffusion-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmnist_train_step_t.py
69 lines (57 loc) · 1.57 KB
/
mnist_train_step_t.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from denoising_diffusion_pytorch_step_t import Unet, GaussianDiffusion, Trainer
import torchvision
import os
import errno
import shutil
def create_folder(path):
    """Create the directory *path* if it does not already exist.

    Missing parent directories are created too. An existing directory is
    accepted silently; any other failure (including a regular file already
    occupying *path*) is raised as an ``OSError`` instead of being ignored.
    """
    os.makedirs(path, exist_ok=True)
def del_folder(path):
    """Recursively delete the directory tree at *path*, if it exists.

    A missing *path* is ignored so this can be called unconditionally before
    rebuilding a folder. Other ``OSError``s (e.g. permission denied, or *path*
    being a regular file) propagate, so a failed cleanup cannot silently
    leave stale files behind in a folder that is about to be repopulated.
    """
    try:
        shutil.rmtree(path)
    except FileNotFoundError:
        # Nothing to delete on a first run.
        pass
# Set to True to (re)export MNIST as one PNG per image, grouped into one
# sub-folder per label. Off by default, so the export folder must already
# exist from an earlier run.
create = False
if create:
    trainset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True)
    root = './root_mnist/'
    # Start from an empty folder so images from a previous export are not
    # mixed with the new ones.
    del_folder(root)
    create_folder(root)
    # One sub-folder per digit label 0-9.
    for digit in range(10):
        create_folder(os.path.join(root, str(digit)))
    # Index explicitly: the dataset is accessed via __len__/__getitem__,
    # each item being an (image, label) pair. Files land at
    # <root>/<label>/<idx>.png.
    for idx in range(len(trainset)):
        img, label = trainset[idx]
        img.save(os.path.join(root, str(label), f'{idx}.png'))
# Hyper-parameters shared by the model, the diffusion wrapper and the trainer.
# They are defined once so the three constructors below cannot drift apart.
timesteps = 1000
image_size = 32
data_folder = './root_mnist/'  # folder of per-label PNGs (see export above)

# NOTE(review): MNIST digits are 28x28 single-channel images; presumably the
# Trainer's dataset resizes them to image_size and the Unet's input channel
# count accommodates grayscale input — confirm in denoising_diffusion_pytorch_step_t.
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    image_size = image_size,
    timesteps = timesteps,
    with_time_emb = False
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size = image_size,
    timesteps = timesteps,   # number of steps
    loss_type = 'l1'         # L1 or L2
).cuda()

trainer = Trainer(
    diffusion,
    data_folder,
    image_size = image_size,
    train_batch_size = 32,
    train_lr = 2e-5,
    train_num_steps = 700000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    fp16 = True,                      # turn on mixed precision training with apex
    results_folder = './results_mnist_step_t'
)

trainer.train()