# mnist_test_step_t.py
# Forked from lucidrains/denoising-diffusion-pytorch.
from denoising_diffusion_pytorch_step_t import Unet, GaussianDiffusion, Trainer, GaussianDiffusionIter
import torchvision
import os
import errno
import shutil
def create_folder(path):
    """Create the single directory *path*, tolerating an existing entry.

    Only the last path component is created (``os.mkdir``), so a missing
    parent directory still raises ``FileNotFoundError``. Every ``OSError``
    other than "already exists" is propagated to the caller.
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        # EEXIST is ignored. Note this also applies when *path* already
        # exists as a regular file rather than a directory.
        return
def del_folder(path):
    """Recursively delete the directory tree at *path*, best-effort.

    Any ``OSError`` raised by ``shutil.rmtree`` (for example because *path*
    does not exist) is swallowed, so this call never fails. Because
    ``rmtree`` stops at its first error, a failed removal may leave some
    contents behind.
    """
    try:
        shutil.rmtree(path)
    except OSError:
        return
# Set to a truthy value to (re)build the image folder: the MNIST training
# split is downloaded to ./data and every image is written to
# ./root_mnist/<label>/<index>.png.
create = 0
if create:
    trainset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True)
    root = './root_mnist/'
    # Wipe any previous export, then recreate the root and one folder per digit.
    del_folder(root)
    create_folder(root)
    for digit in range(10):
        create_folder(f'{root}{digit!s}/')
    for idx in range(len(trainset)):
        img, label = trainset[idx]
        # Assumes `img` is a PIL image (no transform is passed to MNIST
        # above), so it provides .save(); TODO confirm if a transform is added.
        img.save(f'{root}{label!s}/{idx!s}.png')
# Shared hyper-parameters: number of diffusion steps and the image size
# passed to the model, the diffusion wrapper and the Trainer.
timesteps = 1000
image_size = 32

# Denoising network. `image_size`, `timesteps` and `with_time_emb` are
# arguments of this fork's Unet (denoising_diffusion_pytorch_step_t).
model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    image_size=image_size,
    timesteps=timesteps,
    with_time_emb=False,
).cuda()

# Diffusion wrapper used for this run. Alternatives kept commented out,
# called with the same arguments:
#   GaussianDiffusion(model, image_size=32, timesteps=timesteps, loss_type='l1').cuda()
#   GaussianDiffusionX0(model, image_size=32, timesteps=timesteps, loss_type='l1').cuda()
# (GaussianDiffusionX0 is not imported by this script.)
diffusion = GaussianDiffusionIter(
    model,
    image_size=image_size,
    timesteps=timesteps,  # number of diffusion steps
    loss_type='l1',  # 'l1' or 'l2'
).cuda()

trainer = Trainer(
    diffusion,
    './root_mnist/',  # image folder written by the optional `create` export
    image_size=image_size,
    train_batch_size=36,
    train_lr=2e-5,
    train_num_steps=700000,  # total training steps
    gradient_accumulate_every=2,  # gradient accumulation steps
    ema_decay=0.995,  # exponential moving average decay
    fp16=True,  # mixed precision (original note: "with apex")
    results_folder='./results_mnist_step_t',
    do_load=True,  # presumably restores a checkpoint from results_folder; TODO confirm
)

# Run the Trainer's test routine (from denoising_diffusion_pytorch_step_t).
trainer.test()