# genDALLE.py
# forked from lucidrains/DALLE-pytorch
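#
# Generates images from a text prompt with a trained DiscreteVAE + DALLE pair.
# Usage (assumes trained checkpoints exist under ./models):
#   python genDALLE.py "a caption describing the image to generate"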
import time
import sys

import torch
import torchvision.transforms as transforms
from torchvision.utils import save_image

from dalle_pytorch import DiscreteVAE, DALLE
# vae checkpoint
load_epoch = 390
vaename = "vae-cdim256"

# general settings (training leftovers, unused in this generation script)
imgSize = 256
batchSize = 12
n_epochs = 100
log_interval = 10
lr = 2e-5

# dalle checkpoint
dalle_epoch = 220
name = "vae-cdim256"
loadfn = "./models/dalle_" + name + "-" + str(dalle_epoch) + ".pth"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# normalization transform (defined here but never applied in this script)
tf = transforms.Compose([
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 2048,
    codebook_dim = 256,
    hidden_dim = 128,
    temperature = 0.9
)

# load pretrained vae weights (map_location lets the checkpoint load on CPU-only machines)
vae_dict = torch.load("./models/" + vaename + "-" + str(load_epoch) + ".pth", map_location=device)
vae.load_state_dict(vae_dict)
vae.to(device)
dalle = DALLE(
    dim = 256,
    vae = vae,                # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = 10000,  # vocab size for text
    text_seq_len = 256,       # text sequence length
    depth = 6,                # transformer depth (64 in the full-scale model)
    heads = 8,                # attention heads
    dim_head = 64,            # attention head dimension
    attn_dropout = 0.1,       # attention dropout
    ff_dropout = 0.1          # feedforward dropout
)

# load the pretrained dalle weights saved at dalle_epoch
dalle_dict = torch.load(loadfn, map_location=device)
dalle.load_state_dict(dalle_dict)
dalle.to(device)
# get text data: the file contains captions only, one caption per line
lf = open("od-captionsonly.txt", "r")

# build the vocabulary from the caption file
from Vocabulary import Vocabulary
vocab = Vocabulary("captions")

captions = []
for lin in lf:
    captions.append(lin.strip())  # strip the trailing newline so it does not leak into tokens
for caption in captions:
    vocab.add_sentence(caption)
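
# Vocabulary is a local helper module, not part of dalle-pytorch, and is not
# shown here. A minimal stand-in matching the interface used in this script
# (an assumption, not the original implementation) might look like:
#
#   class Vocabulary:
#       def __init__(self, name):
#           self.name = name
#           self.word2index = {}
#
#       def add_sentence(self, sentence):
#           for word in sentence.split(' '):
#               if word not in self.word2index:
#                   self.word2index[word] = len(self.word2index) + 1  # reserve 0 for padding
#
#       def to_index(self, word):
#           return self.word2index.get(word, 0)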
def tokenizer(text):  # simple whitespace tokenizer
    return text.split(' ')
# tokenize the caption given on the command line and map each token to its vocabulary index
inp_text = sys.argv[1]
print(inp_text)
tokens = tokenizer(inp_text)
codes = []
for t in tokens:
    codes.append(vocab.to_index(t))
print(codes)

# zero-pad the token sequence to match text_seq_len
c_tokens = [0] * 256
c_tokens[:len(codes)] = codes
text = torch.LongTensor(c_tokens).unsqueeze(0).to(device)  # a minibatch of one tokenized caption
mask = torch.zeros_like(text).bool()
mask[:, :len(codes)] = True  # attend only to the real tokens, not the padding
# generate images from the tokenized caption and save them with a timestamped filename
oimgs = dalle.generate_images(text, mask = mask)
ts = int(time.time())
print(inp_text, ts)
save_image(oimgs,
           'results/gendalle' + name + '_epoch_' + str(dalle_epoch) + '-' + str(ts) + '.png',
           normalize=True)