-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathtrain_osediff_face.py
329 lines (281 loc) · 15.1 KB
/
train_osediff_face.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
import os
import gc
import lpips
import clip
import argparse
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.utils import set_seed
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
import diffusers
from diffusers.utils.import_utils import is_xformers_available
from diffusers.optimization import get_scheduler
from osediff import OSEDiff_reg, OSEDiff_gen
from dataloaders.facesr_dataset import CodeformerTxtDataset
from pathlib import Path
from accelerate.utils import set_seed, ProjectConfiguration
from accelerate import DistributedDataParallelKwargs
def parse_float_list(arg):
    """Convert a comma-separated string into a list of floats.

    Intended for use as an ``argparse`` ``type=`` callable.

    Args:
        arg: Text such as ``"0.5,1,2.25"``.

    Returns:
        list[float]: One float per comma-separated element, in order.

    Raises:
        argparse.ArgumentTypeError: If any element cannot be parsed as a
            float (this includes empty elements, e.g. from a trailing comma).
    """
    try:
        return [float(x) for x in arg.split(',')]
    except ValueError as err:
        # Chain the original error so the offending token stays visible in
        # tracebacks instead of being reduced to implicit context.
        raise argparse.ArgumentTypeError("List elements should be floats") from err
def parse_int_list(arg):
    """Convert a comma-separated string into a list of integers.

    Intended for use as an ``argparse`` ``type=`` callable.

    Args:
        arg: Text such as ``"1,2,3"``.

    Returns:
        list[int]: One int per comma-separated element, in order.

    Raises:
        argparse.ArgumentTypeError: If any element cannot be parsed as an
            integer (this includes empty elements, e.g. from a trailing comma).
    """
    try:
        return [int(x) for x in arg.split(',')]
    except ValueError as err:
        # Chain the original error so the offending token stays visible in
        # tracebacks instead of being reduced to implicit context.
        raise argparse.ArgumentTypeError("List elements should be integers") from err
def parse_str_list(arg):
    """Split a comma-separated command-line value into a list of strings."""
    items = arg.split(',')
    return items
def parse_args(input_args=None):
    """Parse the command-line options of the OSEDiff face training script.

    Args:
        input_args: Optional list of argument strings. When ``None`` the
            arguments are read from ``sys.argv`` (standard ``argparse``
            behaviour).

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--revision", type=str, default=None,)
    parser.add_argument("--variant", type=str, default=None,)
    parser.add_argument("--tokenizer_name", type=str, default=None)

    # training details
    parser.add_argument("--output_dir", default='experience/osediff_face')
    parser.add_argument("--seed", type=int, default=123, help="A seed for reproducible training.")
    parser.add_argument("--resolution", type=int, default=512,)
    parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.")
    parser.add_argument("--num_training_epochs", type=int, default=10000)
    parser.add_argument("--max_train_steps", type=int, default=100000,)
    parser.add_argument("--checkpointing_steps", type=int, default=500,)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4, help="Number of updates steps to accumulate before performing a backward/update pass.",)
    parser.add_argument("--gradient_checkpointing", action="store_true",)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--lr_scheduler", type=str, default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.")
    parser.add_argument("--lr_num_cycles", type=int, default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument("--dataloader_num_workers", type=int, default=0,)
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--allow_tf32", action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument("--report_to", type=str, default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"],)
    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
    parser.add_argument("--set_grads_to_none", action="store_true",)
    parser.add_argument("--logging_dir", type=str, default="logs")
    parser.add_argument("--tracker_project_name", type=str, default="train_osediff_face", help="The name of the wandb project to log to.")

    # dataset description: the paths are strings, the probabilities are integers
    parser.add_argument('--dataset_txt_paths_list', type=parse_str_list, default=['YOUR TXT FILE PATH'], help='A comma-separated list of dataset txt file paths')
    parser.add_argument('--dataset_prob_paths_list', type=parse_int_list, default=[1], help='A comma-separated list of integers')
    parser.add_argument("--deg_file_path", default="params_codeformer.yml", type=str)

    # model and loss weights
    parser.add_argument("--pretrained_model_name_or_path", default=None, type=str)
    parser.add_argument("--lambda_l2", default=1.0, type=float)
    parser.add_argument("--lambda_lpips", default=2.0, type=float)
    parser.add_argument("--lambda_vsd", default=1.0, type=float)
    parser.add_argument("--lambda_vsd_lora", default=1.0, type=float)
    parser.add_argument("--neg_prompt", default="", type=str)
    parser.add_argument("--cfg_vsd", default=7.5, type=float)

    # lora setting
    parser.add_argument("--lora_rank", default=4, type=int)

    # ram path
    parser.add_argument('--ram_path', type=str, default=None, help='Path to RAM model')

    # argparse falls back to sys.argv[1:] when given None, so a single call
    # covers both the programmatic and the command-line case.
    return parser.parse_args(input_args)
def main(args):
    """Run OSEDiff face training.

    Builds the generator (``OSEDiff_gen``) and the regularizer
    (``OSEDiff_reg``), their AdamW optimizers and LR schedulers, the training
    dataloader, an LPIPS network and a RAM tagging model used to produce text
    prompts from the target images. Each iteration then performs two updates
    inside one gradient-accumulation context:

    1. generator update with L2 + LPIPS + distribution-matching loss;
    2. regularizer LoRA update with ``diff_loss``.

    Losses are logged and generator checkpoints are written by the main
    process.

    Args:
        args: Namespace produced by ``parse_args``.

    Raises:
        ValueError: If xformers attention is requested but xformers is not
            available.
    """
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[ddp_kwargs],
    )

    # Keep library logging verbose on one process per node only.
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
        os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True)

    model_gen = OSEDiff_gen(args)
    model_gen.set_train()
    model_reg = OSEDiff_reg(args=args, accelerator=accelerator)
    model_reg.set_train()

    # LPIPS is only used as a frozen perceptual loss.
    net_lpips = lpips.LPIPS(net='vgg').cuda()
    net_lpips.requires_grad_(False)

    # set vae adapter
    model_gen.vae.set_adapter(['default_encoder'])
    # set gen adapter
    model_gen.unet.set_adapter(['default_encoder', 'default_decoder', 'default_others'])

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            model_gen.unet.enable_xformers_memory_efficient_attention()
            model_reg.unet_fix.enable_xformers_memory_efficient_attention()
            model_reg.unet_update.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available, please install it by running `pip install xformers`")

    if args.gradient_checkpointing:
        model_gen.unet.enable_gradient_checkpointing()
        model_reg.unet_fix.enable_gradient_checkpointing()
        model_reg.unet_update.enable_gradient_checkpointing()

    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # make the optimizer: UNet LoRA weights, the UNet input conv, VAE LoRA weights
    layers_to_opt = [p for n, p in model_gen.unet.named_parameters() if "lora" in n]
    layers_to_opt += list(model_gen.unet.conv_in.parameters())
    layers_to_opt += [p for n, p in model_gen.vae.named_parameters() if "lora" in n]

    optimizer = torch.optim.AdamW(layers_to_opt, lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,)
    lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles, power=args.lr_power,)

    # The regularizer only trains the LoRA weights of its `unet_update`.
    layers_to_opt_reg = [p for n, p in model_reg.unet_update.named_parameters() if "lora" in n]
    optimizer_reg = torch.optim.AdamW(layers_to_opt_reg, lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,)
    lr_scheduler_reg = get_scheduler(args.lr_scheduler, optimizer=optimizer_reg,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles, power=args.lr_power)

    dataset_train = CodeformerTxtDataset(split="train", args=args)
    dl_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers)

    # init vlm model
    from ram.models.ram_lora import ram
    from ram import inference_ram as inference
    ram_transforms = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    model_vlm = ram(pretrained=args.ram_path,
                    pretrained_condition=None,
                    image_size=384,
                    vit='swin_l')
    model_vlm.eval()
    model_vlm.to("cuda", dtype=torch.float16)

    # Prepare everything with our `accelerator`.
    model_gen, model_reg, optimizer, optimizer_reg, dl_train, lr_scheduler, lr_scheduler_reg = accelerator.prepare(
        model_gen, model_reg, optimizer, optimizer_reg, dl_train, lr_scheduler, lr_scheduler_reg
    )
    net_lpips = accelerator.prepare(net_lpips)

    # The custom loss methods live on the underlying module. Whether
    # `prepare` wrapped it (e.g. in DDP) depends on the launch configuration,
    # not on how many GPUs are visible, so ask the accelerator rather than
    # checking torch.cuda.device_count().
    reg_module = accelerator.unwrap_model(model_reg)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        tracker_config = dict(vars(args))
        # Stringify the list-valued options in the tracker copy only, so that
        # `args` keeps the same types on every process.
        tracker_config["dataset_txt_paths_list"] = str(args.dataset_txt_paths_list)
        tracker_config["dataset_prob_paths_list"] = str(args.dataset_prob_paths_list)
        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)

    progress_bar = tqdm(range(0, args.max_train_steps), initial=0, desc="Steps",
        disable=not accelerator.is_local_main_process,)

    # start the training loop
    # NOTE(review): the loop is bounded by num_training_epochs only;
    # max_train_steps sizes the progress bar and the LR schedules but is not
    # used as a stop condition here -- confirm this is intended.
    global_step = 0
    for _epoch in range(0, args.num_training_epochs):
        for batch in dl_train:
            with accelerator.accumulate(model_gen, model_reg):
                x_src = batch["conditioning_pixel_values"]
                x_tgt = batch["output_pixel_values"]

                # get text prompts from GT
                # (x * 0.5 + 0.5 is applied before the ImageNet normalisation;
                # presumably the targets are in [-1, 1] -- TODO confirm)
                x_tgt_ram = ram_transforms(x_tgt*0.5+0.5)
                caption = inference(x_tgt_ram.to(dtype=torch.float16), model_vlm)
                batch["prompt"] = [f'{each_caption}' for each_caption in caption]

                # forward pass
                x_tgt_pred, latents_pred, prompt_embeds, neg_prompt_embeds = model_gen(x_src, batch=batch, args=args)

                # Reconstruction loss
                loss_l2 = F.mse_loss(x_tgt_pred.float(), x_tgt.float(), reduction="mean") * args.lambda_l2
                loss_lpips = net_lpips(x_tgt_pred.float(), x_tgt.float()).mean() * args.lambda_lpips
                loss = loss_l2 + loss_lpips

                # KL loss
                loss_kl = reg_module.distribution_matching_loss(latents=latents_pred, prompt_embeds=prompt_embeds, neg_prompt_embeds=neg_prompt_embeds, args=args) * args.lambda_vsd
                loss = loss + loss_kl

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)

                # diff loss: let lora model closed to generator
                loss_d = reg_module.diff_loss(latents=latents_pred, prompt_embeds=prompt_embeds, args=args)*args.lambda_vsd_lora
                accelerator.backward(loss_d)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model_reg.parameters(), args.max_grad_norm)
                optimizer_reg.step()
                lr_scheduler_reg.step()
                optimizer_reg.zero_grad(set_to_none=args.set_grads_to_none)

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                    # log all the losses
                    logs = {
                        "loss_d": loss_d.detach().item(),
                        "loss_kl": loss_kl.detach().item(),
                        "loss_l2": loss_l2.detach().item(),
                        "loss_lpips": loss_lpips.detach().item(),
                    }
                    progress_bar.set_postfix(**logs)

                    # checkpoint the model
                    # (`== 1` saves at steps 1, 1 + checkpointing_steps, ...)
                    if global_step % args.checkpointing_steps == 1:
                        outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
                        accelerator.unwrap_model(model_gen).save_model(outf)

                    accelerator.log(logs, step=global_step)
# Script entry point: parse the command line and start training.
if __name__ == "__main__":
    main(parse_args())