
Commit f39b71a

Merge pull request #968 from ZeroCool940711/dev
Improved the progress bar for the txt2vid tab and other fixes.
2 parents (3e9cdb1 + 24ddbdc), commit f39b71a

6 files changed: +525 −38 lines

.gitignore (+1)

@@ -62,3 +62,4 @@ condaenv.*.requirements.txt
 /log/log.csv
 /flagged/*
 /gfpgan/*
+/models/custom/

environment.yaml (+1 −1)

@@ -33,7 +33,7 @@ dependencies:
   - python-slugify>=6.1.2
   - streamlit>=1.12.2
   - retry>=0.9.2
-  - diffusers>=0.3.0
+  - diffusers<=0.2.4
   - -e git+https://github.com/CompVis/taming-transformers#egg=taming-transformers
   - -e git+https://github.com/openai/CLIP#egg=clip
   - -e git+https://github.com/TencentARC/GFPGAN#egg=GFPGAN
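
The pin above downgrades diffusers from >=0.3.0 to <=0.2.4, presumably because the new scripts/stable_diffusion_pipeline.py added below is written against 0.2.x-era interfaces (for example scheduler.set_format("pt") and dictionary-style step outputs), which changed in later releases. A minimal, hypothetical guard along these lines could surface a mismatched install early; the check and message are illustrative and are not part of this commit.

# Hypothetical guard (not in this commit): fail fast if the installed diffusers
# is newer than the API the bundled pipeline targets.
from packaging import version

import diffusers

if version.parse(diffusers.__version__) > version.parse("0.2.4"):
    raise RuntimeError(
        f"diffusers {diffusers.__version__} found, but the bundled "
        "StableDiffusionPipeline expects <=0.2.4 (see environment.yaml)."
    )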

frontend/css/streamlit.main.css (+2 −13)

@@ -1,18 +1,7 @@
-.css-18e3th9 {
-    padding-top: 2rem;
-    padding-bottom: 10rem;
-    padding-left: 5rem;
-    padding-right: 5rem;
-}
-.css-1d391kg {
-    padding-top: 3.5rem;
-    padding-right: 1rem;
-    padding-bottom: 3.5rem;
-    padding-left: 1rem;
-}
 button[data-baseweb="tab"] {
     font-size: 25px;
 }
 .css-du1fp8 {
     justify-content: center;
-}
+}
+

scripts/stable_diffusion_pipeline.py (new file, +233)

import inspect
import warnings
from tqdm.auto import tqdm
from typing import List, Optional, Union

import torch
from diffusers import ModelMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import \
    StableDiffusionSafetyChecker
from diffusers.schedulers import (DDIMScheduler, LMSDiscreteScheduler,
                                  PNDMScheduler)
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


class StableDiffusionPipeline(DiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()
        scheduler = scheduler.set_format("pt")
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        text_embeddings: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        **kwargs,
    ):
        if "torch_device" in kwargs:
            device = kwargs.pop("torch_device")
            warnings.warn(
                "`torch_device` is deprecated as an input argument to `__call__` and"
                " will be removed in v0.3.0. Consider using `pipe.to(torch_device)`"
                " instead."
            )

            # Set device as before (to be removed in 0.3.0)
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
            self.to(device)

        if text_embeddings is None:
            if isinstance(prompt, str):
                batch_size = 1
            elif isinstance(prompt, list):
                batch_size = len(prompt)
            else:
                raise ValueError(
                    f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
                )

            if height % 8 != 0 or width % 8 != 0:
                raise ValueError(
                    "`height` and `width` have to be divisible by 8 but are"
                    f" {height} and {width}."
                )

            # get prompt text embeddings
            text_input = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
        else:
            batch_size = text_embeddings.shape[0]

        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            # max_length = text_input.input_ids.shape[-1]
            max_length = 77  # self.tokenizer.model_max_length
            uncond_input = self.tokenizer(
                [""] * batch_size,
                padding="max_length",
                max_length=max_length,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(self.device)
            )[0]

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it
        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
        if latents is None:
            latents = torch.randn(
                latents_shape,
                generator=generator,
                device=self.device,
            )
        else:
            if latents.shape != latents_shape:
                raise ValueError(
                    f"Unexpected latents shape, got {latents.shape}, expected"
                    f" {latents_shape}"
                )
            latents = latents.to(self.device)

        # set timesteps
        accepts_offset = "offset" in set(
            inspect.signature(self.scheduler.set_timesteps).parameters.keys()
        )
        extra_set_kwargs = {}
        if accepts_offset:
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                sigma = self.scheduler.sigmas[i]
                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            )["sample"]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            # compute the previous noisy sample x_t -> x_t-1
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                latents = self.scheduler.step(
                    noise_pred, i, latents, **extra_step_kwargs
                )["prev_sample"]
            else:
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs
                )["prev_sample"]

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        safety_checker_input = self.feature_extractor(
            self.numpy_to_pil(image), return_tensors="pt"
        ).to(self.device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values
        )

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}

    def embed_text(self, text):
        """Helper to embed some text"""
        with torch.autocast("cuda"):
            text_input = self.tokenizer(
                text,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            with torch.no_grad():
                embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
            return embed


class NoCheck(ModelMixin):
    """Can be used in place of safety checker. Use responsibly and at your own risk."""

    def __init__(self):
        super().__init__()
        self.register_parameter(name='asdf', param=torch.nn.Parameter(torch.randn(3)))

    def forward(self, images=None, **kwargs):
        return images, [False]
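
Compared with the stock diffusers pipeline, the version above lets the caller pass precomputed text_embeddings to __call__ (skipping prompt tokenization when they are provided) and exposes embed_text as a thin helper around the CLIP text encoder; this is what allows the txt2vid code to interpolate between prompt embeddings and render one frame per step. The sketch below shows roughly how such a caller might look. It is illustrative only: the model directory, frame count, linear blend, and import path are assumptions, not code from this commit.

# Illustrative usage sketch (not part of this commit): render frames by
# interpolating between two prompt embeddings with the pipeline defined above.
import torch

from scripts.stable_diffusion_pipeline import NoCheck, StableDiffusionPipeline

# Assumption: a Stable Diffusion 1.x checkpoint in diffusers format at this path.
pipe = StableDiffusionPipeline.from_pretrained("./models/custom/stable-diffusion-v1-4")
pipe.to("cuda")
pipe.safety_checker = NoCheck()  # optional: bypass the NSFW filter, at your own risk

start = pipe.embed_text("a calm sea at dawn")   # CLIP embedding, shape (1, 77, 768)
end = pipe.embed_text("a stormy sea at night")

frames = []
num_frames = 30
for i in range(num_frames):
    t = i / (num_frames - 1)
    # simple linear blend; spherical interpolation is a common alternative
    emb = (1 - t) * start + t * end
    out = pipe(text_embeddings=emb, num_inference_steps=50, guidance_scale=7.5)
    frames.append(out["sample"][0])  # one PIL image per interpolation step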
