-
Notifications
You must be signed in to change notification settings - Fork 93
/
Copy pathcall_txt2img.py
372 lines (291 loc) · 13.9 KB
/
call_txt2img.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
import json
import requests
import io
import base64
import uuid
import sys, os
from PIL import Image, PngImagePlugin
from model_lists import *
import time
import random
def _post_txt2img(url, payload):
    """POST ``payload`` to ``{url}/sdapi/v1/txt2img`` and return the decoded JSON.

    A response without an ``images`` key is retried: up to 4 attempts in
    total, waiting 10, 20 and 30 seconds between them.

    Raises:
        ValueError: if the 4th attempt still has no ``images`` key.
    """
    for attempt in range(4):
        response = requests.post(url=f'{url}/sdapi/v1/txt2img', json=payload)
        r = response.json()
        if 'images' in r:
            return r
        if attempt == 3:
            print("If this keeps happening: Is WebUI started with --api enabled?")
            print("")
            raise ValueError("API has not been responding after several retries. Stopped processing.")
        print("")
        print("We haven't received an image from the API. Maybe something went wrong. Will retry after waiting a bit.")
        time.sleep(10 * (attempt + 1))  # incremental waiting time


def _save_api_image(url, b64image, path):
    """Decode one base64 image returned by the API and save it as a PNG at ``path``.

    The WebUI ``png-info`` endpoint is asked for the generation parameters,
    which are embedded in the PNG as the ``parameters`` text chunk.

    Returns:
        ``(image, pnginfo)``: the decoded PIL image and the PngInfo that was written.
    """
    # Use the part after the first comma so an optional "data:...;base64,"
    # prefix is skipped; a bare base64 string (no comma) is used whole.
    image = Image.open(io.BytesIO(base64.b64decode(b64image.split(",", 1)[-1])))
    png_payload = {
        "image": "data:image/png;base64," + b64image
    }
    response = requests.post(url=f'{url}/sdapi/v1/png-info', json=png_payload)
    pnginfo = PngImagePlugin.PngInfo()
    # add_text() expects str/bytes; fall back to "" if the API sent no info.
    pnginfo.add_text("parameters", response.json().get("info") or "")
    image.save(path, pnginfo=pnginfo)
    return image, pnginfo


def call_txt2img(passingprompt,size,upscale,debugmode,filename="",model = "currently selected model",samplingsteps = "40",cfg= "7",hiressteps ="0",denoisestrength="0.6",samplingmethod="DPM++ SDE Karras", upscaler="R-ESRGAN 4x+",hiresscale="2",apiurl="http://127.0.0.1:7860", qualitygate=False,quality="7.6",runs="5",negativeprompt="",qualityhiresfix = False, qualitymode = "highest", qualitykeep="keep used", basesize="512"):
    """Generate an image with the Stable Diffusion WebUI txt2img API and save it.

    Builds a txt2img payload from the arguments and posts it to ``apiurl``.
    Each returned image is saved as a PNG, with the WebUI generation info
    embedded, under ``automated_outputs/txt2img/`` next to this script. The
    payload finally used is written as JSON to ``automated_outputs/prompts/``.

    Quality gate: when ``qualitygate`` is True and the aesthetic-image-scorer
    extension can be imported, every image is scored. Generation is repeated
    with a new random seed, up to ``runs`` times, until a score (rounded to
    1 decimal) reaches ``quality``. If the scorer cannot be imported, the
    first image is accepted.

    Args:
        passingprompt: Prompt text. Also used (lower-cased) to pick an
            upscaler when ``upscaler == "automatic"``.
        size: Aspect keyword passed to ``setsize``. ``"all"`` and ``"wild"``
            re-roll the size on every run.
        upscale: Value sent as ``enable_hr``. Compared with ``== True`` /
            ``== False`` for the quality-gate hires-fix pass.
        debugmode: ``1`` forces 10 steps, disables hires fix, and accepts
            every scored image.
        filename: Base output name; a uuid4 is used when empty.
        hiressteps: Hires second-pass steps. With ``upscaler="automatic"``,
            half of ``hiressteps`` (or of ``samplingsteps`` when it is 0) is used.
        quality, runs: Score threshold and maximum number of runs (at least 1).
        qualityhiresfix: Generate without hires fix, then re-render the chosen
            seed/size with hires fix enabled (requires ``upscale == True``).
        qualitymode: ``"highest"`` keeps the best image when none passes.
            Any other value discards it and returns ``continuewithnextpart=False``.
        qualitykeep: ``"keep used"`` deletes the generated images that are not kept.
        Other arguments are forwarded to the API payload as given. Numeric
        values may be strings, as the defaults are.

    Returns:
        ``[image_path, pnginfo, continuewithnextpart]``.

    Raises:
        ValueError: if the API returns no images after 4 attempts.
    """
    prompt = passingprompt
    checkprompt = passingprompt.lower()
    url = apiurl

    sampler_index = samplingmethod
    steps = "10" if debugmode == 1 else samplingsteps
    cfg_scale = cfg

    originalsize = size
    width, height = setsize(size, basesize, originalsize)

    # Hires fix is skipped in debug mode and, with the quality-gate hires fix,
    # during the gated runs; it is re-enabled for the final pass below.
    enable_hr = upscale
    if debugmode == 1 or qualityhiresfix == True:
        enable_hr = "False"

    hr_scale = hiresscale
    denoising_strength = denoisestrength
    hr_second_pass_steps = hiressteps
    if upscaler != "automatic":
        hr_upscaler = upscaler
    else:
        upscalerlist = get_upscalers()
        # photos -> 4x-UltraSharp (if installed); anime/cartoon/drawing ->
        # R-ESRGAN 4x+ Anime6B; anything else -> R-ESRGAN 4x+
        anime_words = ("anime", "cartoon", "draw", "vector", "cel shad", "visual novel")
        if "hoto" in checkprompt and "4x-UltraSharp" in upscalerlist:
            hr_upscaler = "4x-UltraSharp"
        elif any(word in checkprompt for word in anime_words):
            hr_upscaler = "R-ESRGAN 4x+ Anime6B"
        else:
            hr_upscaler = "R-ESRGAN 4x+"
        # Steps may arrive as strings (see the defaults), so convert before comparing/dividing.
        if int(hiressteps) == 0:
            hiressteps = samplingsteps
        hr_second_pass_steps = int(hiressteps) // 2
        hr_scale = 2
        if hr_upscaler == "4x-UltraSharp":
            denoising_strength = "0.35"
        elif hr_upscaler == "R-ESRGAN 4x+ Anime6B":
            denoising_strength = "0.6"  # 0.6 is fine for the anime upscaler
        elif hr_upscaler == "R-ESRGAN 4x+":
            denoising_strength = "0.5"  # default 0.6 is a lot and changes a lot of details

    # Output locations (folders end with a separator; names are appended).
    script_dir = os.path.dirname(os.path.abspath(__file__))
    outputTXT2IMGfolder = os.path.join(script_dir, "automated_outputs", "txt2img", "")
    outputTXT2IMGtxtfolder = os.path.join(script_dir, "automated_outputs", "prompts", "")
    os.makedirs(outputTXT2IMGfolder, exist_ok=True)
    os.makedirs(outputTXT2IMGtxtfolder, exist_ok=True)
    if filename == "":
        filename = str(uuid.uuid4())
    outputTXT2IMGtxtFull = outputTXT2IMGtxtfolder + filename + ".txt"

    # Quality-gate bookkeeping: parallel lists, one entry per scored image.
    isGoodNumber = float(quality)
    MaxRuns = max(1, int(runs))  # always generate at least one image
    foundgood = False
    scoredeclist = []
    imagelist = []
    pnginfolist = []
    seedlist = []
    widthlist = []
    heightlist = []
    # Seed/size of the image that is finally used (reused by the hires-fix pass).
    usedseed = -1
    usedwidth = width
    usedheight = height
    imagethatiskept = ""
    continuewithnextpart = True
    outputTXT2IMGFull = ""
    pnginfo = None

    payload = {
        "prompt": prompt,
        "sampler_index": sampler_index,
        "steps": steps,
        "cfg_scale": cfg_scale,
        "width": width,
        "height": height,
        "enable_hr": enable_hr,
        "denoising_strength": denoising_strength,
        "hr_scale": hr_scale,
        "hr_upscaler": hr_upscaler,
        "hr_second_pass_steps": hr_second_pass_steps,
        "seed": -1,
        "hr_prompt": prompt
    }
    if model != "currently selected model":
        payload["sd_model"] = model
    if negativeprompt != "":
        payload["negative_prompt"] = negativeprompt
        payload["hr_negative_prompt"] = negativeprompt

    run = 0
    while run < MaxRuns:
        # unique file name per run: _0, _1, ...
        outputTXT2IMGFull = outputTXT2IMGfolder + filename + "_" + str(run) + ".png"
        seed = random.randrange(1, 4294967295)
        payload["seed"] = seed
        if originalsize == 'all' or originalsize == 'wild':
            width, height = setsize(size, basesize, originalsize)
            payload["width"] = width
            payload["height"] = height

        r = _post_txt2img(url, payload)
        for b64image in r['images']:
            image, pnginfo = _save_api_image(url, b64image, outputTXT2IMGFull)
            # Remember what produced the latest image, so an unscored result
            # can still be re-rendered with hires fix.
            usedseed, usedwidth, usedheight = seed, width, height
            if qualitygate != True:
                foundgood = True  # no quality gate: the first image is good
                continue
            imagescorer_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'stable-diffusion-webui-aesthetic-image-scorer', 'scripts'))
            if imagescorer_path not in sys.path:
                sys.path.append(imagescorer_path)
            try:
                import image_scorer
            except ImportError:
                foundgood = True  # scorer not installed: just continue
                print("Could not find the stable-diffusion-webui-aesthetic-image-scorer extension.")
                print("Install this extension via the WebUI to use Quality Gate")
                continue
            print("Found aesthetic-image-scorer! Using this to measure the results...")
            rawscore = image_scorer.get_score(image)
            score = round(rawscore, 1)
            scoredeclist.append(rawscore)
            seedlist.append(seed)
            widthlist.append(width)
            heightlist.append(height)
            imagelist.append(outputTXT2IMGFull)
            pnginfolist.append(pnginfo)
            print("This image has scored: "+ str(score) + " out of " + str(isGoodNumber))
            if score >= isGoodNumber or debugmode == 1:
                foundgood = True
                print("Yay its good! Keeping this result.")
            else:
                runstodo = MaxRuns - run - 1
                print("Not a good result. Retrying for another " + str(runstodo) + " times or until the image is good enough.")
        run += 1
        if foundgood == True:
            break

    if len(imagelist) > 0:
        if foundgood == True:
            if qualitykeep == "keep used":
                print("Removing any other images generated this run (if any).")
        elif qualitymode == "highest":
            print("")
            print("Stopped trying, keeping the best image we had so far.")
            print("")
        else:
            print("")
            print("Eh, its all pretty bad. Not going forward with any image.")
            print("")
        # index of the first occurrence of the highest (unrounded) score
        keep = scoredeclist.index(max(scoredeclist))
        outputTXT2IMGFull = imagelist[keep]
        pnginfo = pnginfolist[keep]
        usedseed = seedlist[keep]
        usedwidth = widthlist[keep]
        usedheight = heightlist[keep]
        imagethatiskept = imagelist.pop(keep)
        if qualitykeep == "keep used":
            for imagelocation in imagelist:
                os.remove(imagelocation)
        # Make the saved prompt file describe the kept image, not the last run.
        payload["seed"] = usedseed
        payload["width"] = usedwidth
        payload["height"] = usedheight

    if foundgood == False and qualitymode != "highest":
        continuewithnextpart = False
        if imagethatiskept != "" and qualitykeep == "keep used":
            os.remove(imagethatiskept)

    # Optionally re-render the chosen seed/size with hires fix enabled.
    if qualityhiresfix == True and upscale == False and continuewithnextpart == True:
        print("Quality Gate hires fix was enabled, but no hires fix settings were given.")
    if qualityhiresfix == True and upscale == True and continuewithnextpart == True:
        print("Going to run the chosen image with hiresfix")
        payload["seed"] = usedseed
        payload["width"] = usedwidth
        payload["height"] = usedheight
        payload["enable_hr"] = "True"
        outputTXT2IMGFull = outputTXT2IMGfolder + filename + "_hiresfix" + ".png"
        r = _post_txt2img(url, payload)
        for b64image in r['images']:
            _, pnginfo = _save_api_image(url, b64image, outputTXT2IMGFull)

    with open(outputTXT2IMGtxtFull, 'w', encoding="utf8") as txt:
        txt.write(json.dumps(payload, indent=4))
    return [outputTXT2IMGFull, pnginfo, continuewithnextpart]
def setsize(ratio,basesize, originalsize):
    """Return ``[width, height]`` as strings for an aspect-ratio keyword.

    Args:
        ratio: ``"wide"``, ``"portrait"``, ``"ultrawide"``, ``"ultraheight"``,
            ``"wild"``; anything else gives a square ``basesize`` image.
        basesize: Base edge length, as str or int (e.g. ``"512"`` or ``1024``).
            ``1024`` selects fixed 1152x896 / 896x1152 wide/portrait sizes.
        originalsize: When ``"all"``, ``ratio`` is replaced by a random choice
            of portrait, wide or square.

    Returns:
        ``[width, height]``, both as strings.
    """
    # Normalize so an int base size (e.g. 1024) hits the same branches as "1024".
    basesize = str(basesize)
    if originalsize == "all":
        ratio = random.choice(["portrait", "wide", "square"])

    if ratio == 'wide':
        if basesize == "1024":
            return ["1152", "896"]
        return [str(int(basesize) + 256), basesize]
    if ratio == 'portrait':
        if basesize == "1024":
            return ["896", "1152"]
        return [basesize, str(int(basesize) + 256)]
    if ratio == 'ultrawide':
        return ["1280", "360"]
    if ratio == 'ultraheight':
        return ["360", "1280"]
    if ratio == 'wild':
        # each side: 0..512 in steps of 128, plus half of the base size
        half = int(basesize) / 2
        width = str(round((random.randint(0, 4) * 128) + half))
        height = str(round((random.randint(0, 4) * 128) + half))
        return [width, height]
    return [basesize, basesize]