Img condition (pesser#1)
* update reqs
* add image variations
* update readme
justinpinkney authored Sep 4, 2022
1 parent 693e713 commit 7e3956e
Showing 10 changed files with 455 additions and 281 deletions.
7 changes: 7 additions & 0 deletions .gitignore
@@ -0,0 +1,7 @@
logs/
dump/
examples/
outputs/
flagged/
*.egg-info
__pycache__
281 changes: 8 additions & 273 deletions README.md

Large diffs are not rendered by default.

Binary file added assets/img-vars.jpg
134 changes: 134 additions & 0 deletions configs/stable-diffusion/sd-image-condition-finetune.yaml
@@ -0,0 +1,134 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "jpg"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215

scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 1000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]

unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False

first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder


data:
target: ldm.data.laion.WebDataModuleFromConfig
params:
tar_base: "/mnt/data_rome/laion/improved_aesthetics_6plus/ims"
batch_size: 6
num_workers: 4
multinode: True
min_size: 256
train:
shards: '{00000..01209}.tar'
shuffle: 10000
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.RandomCrop
params:
size: 512

# NOTE use enough shards to avoid empty validation loops in workers
validation:
shards: '{00000..00008}.tar -'
shuffle: 0
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.CenterCrop
params:
size: 512


lightning:
find_unused_parameters: false
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 1000
max_images: 8
increase_log_steps: False
log_first_step: True
log_images_kwargs:
use_ema_scope: False
inpaint: False
plot_progressive_rows: False
plot_diffusion_rows: False
N: 8
unconditional_guidance_scale: 3.0
unconditional_guidance_label: [""]

trainer:
benchmark: True
val_check_interval: 5000000 # really sorry
num_sanity_val_steps: 0
accumulate_grad_batches: 1
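
The config above wires a frozen CLIP image encoder in as the conditioning stage of an otherwise standard LatentDiffusion setup (crossattn conditioning, context_dim 768). A minimal sketch of building the model from this file the way main.py does, via OmegaConf and instantiate_from_config; note that instantiating the cond stage downloads the ViT-L/14 CLIP weights, and that this is an illustration rather than the repo's training entry point:

```python
# Sketch only: build the image-conditioned model from the new config.
# Assumes the repo's ldm package and its requirements are installed.
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("configs/stable-diffusion/sd-image-condition-finetune.yaml")
model = instantiate_from_config(config.model)  # LatentDiffusion with a FrozenCLIPImageEmbedder cond stage
```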
8 changes: 5 additions & 3 deletions ldm/models/diffusion/ddpm.py
@@ -523,7 +523,7 @@ def __init__(self,
self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
self.bbox_tokenizer = None
self.bbox_tokenizer = None

self.restarted_from_ckpt = False
if ckpt_path is not None:
@@ -904,7 +904,7 @@ def apply_model(self, x_noisy, t, cond, return_ids=False):

if hasattr(self, "split_input_params"):
assert len(cond) == 1 # todo can only deal with one conditioning atm
assert not return_ids
assert not return_ids
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)

@@ -1343,7 +1343,9 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=
log["samples_x0_quantized"] = x_samples

if unconditional_guidance_scale > 1.0:
uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
# uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
# FIXME
uc = torch.zeros_like(c)
with ema_scope("Sampling with classifier-free guidance"):
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta,
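
For classifier-free guidance during image logging, the learned unconditional embedding is replaced (per the FIXME above) by a zero tensor shaped like the CLIP image conditioning. A hedged sketch of the guidance arithmetic this implies; the tensors below are placeholders with the relevant shapes, not the repo's sampler API:

```python
import torch

# Placeholder shapes: conditioning from FrozenCLIPImageEmbedder.encode is (B, 1, 768),
# eps predictions from the U-Net live in latent space (B, 4, 64, 64).
c = torch.randn(8, 1, 768)        # CLIP image embedding of the conditioning image
uc = torch.zeros_like(c)          # the commit's stand-in "unconditional" embedding

eps_cond = torch.randn(8, 4, 64, 64)    # stand-in for eps predicted with c
eps_uncond = torch.randn(8, 4, 64, 64)  # stand-in for eps predicted with uc
scale = 3.0                             # unconditional_guidance_scale from the config
eps = eps_uncond + scale * (eps_cond - eps_uncond)  # classifier-free guidance update
```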
38 changes: 38 additions & 0 deletions ldm/modules/encoders/modules.py
@@ -2,9 +2,11 @@
import torch.nn as nn
import numpy as np
from functools import partial
import kornia

from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
from ldm.util import default
import clip


class AbstractEncoder(nn.Module):
@@ -170,6 +172,42 @@ def forward(self, text):
def encode(self, text):
return self(text)

class FrozenCLIPImageEmbedder(AbstractEncoder):
"""
Uses the CLIP image encoder.
"""
def __init__(
self,
model='ViT-L/14',
jit=False,
device='cuda' if torch.cuda.is_available() else 'cpu',
antialias=False,
):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit)
self.device = device

self.antialias = antialias

self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)

def preprocess(self, x):
# Expects inputs in the range -1, 1
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic',align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x

def forward(self, x):
# x is assumed to be in range [-1,1]
return self.model.encode_image(self.preprocess(x)).float()

def encode(self, im):
return self(im).unsqueeze(1)

class SpatialRescaler(nn.Module):
def __init__(self,
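
A hedged usage sketch of the new embedder: encode() returns a single CLIP image token per sample, i.e. a context of shape (batch, 1, 768) that takes the place of the 77-token text context in the U-Net's cross-attention. The random input is a placeholder; running this needs the clip and kornia packages and fetches the ViT-L/14 weights on first use:

```python
import torch
from ldm.modules.encoders.modules import FrozenCLIPImageEmbedder

embedder = FrozenCLIPImageEmbedder(device="cpu")

# Placeholder batch of 2 RGB images in [-1, 1], as preprocess() expects.
images = torch.rand(2, 3, 256, 256) * 2.0 - 1.0
with torch.no_grad():
    context = embedder.encode(images)

print(context.shape)  # expected: torch.Size([2, 1, 768])
```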
25 changes: 23 additions & 2 deletions main.py
@@ -21,7 +21,7 @@
from ldm.util import instantiate_from_config


MULTINODE_HACKS = True
MULTINODE_HACKS = False


def get_parser(**parser_kwargs):
@@ -36,6 +36,13 @@ def str2bool(v):
raise argparse.ArgumentTypeError("Boolean value expected.")

parser = argparse.ArgumentParser(**parser_kwargs)
parser.add_argument(
"--finetune_from",
type=str,
nargs="?",
default="",
help="path to checkpoint to load model state from"
)
parser.add_argument(
"-n",
"--name",
@@ -644,6 +651,20 @@ def check_frequency(self, check_idx):
# model
model = instantiate_from_config(config.model)

if not opt.finetune_from == "":
print(f"Attempting to load state from {opt.finetune_from}")
old_state = torch.load(opt.finetune_from, map_location="cpu")
if "state_dict" in old_state:
print(f"Found nested key 'state_dict' in checkpoint, loading this instead")
old_state = old_state["state_dict"]
m, u = model.load_state_dict(old_state, strict=False)
if len(m) > 0:
print("missing keys:")
print(m)
if len(u) > 0:
print("unexpected keys:")
print(u)

# trainer and callbacks
trainer_kwargs = dict()

@@ -666,7 +687,7 @@ def check_frequency(self, check_idx):
}
},
}
default_logger_cfg = default_logger_cfgs["testtube"]
default_logger_cfg = default_logger_cfgs["wandb"]
if "logger" in lightning_config:
logger_cfg = lightning_config.logger
else:
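
The new --finetune_from flag loads an existing Stable Diffusion checkpoint with strict=False because the conditioning stage no longer matches: the checkpoint carries the CLIP text encoder under cond_stage_model.*, while this config builds a FrozenCLIPImageEmbedder there. A hedged way to see that mismatch before training; the checkpoint path is hypothetical and the listed key groups are what an SD v1 checkpoint typically contains:

```python
import torch

# Peek at the top-level parameter groups in an SD v1 checkpoint.
state = torch.load("checkpoints/sd-v1-4.ckpt", map_location="cpu")  # hypothetical path
state = state.get("state_dict", state)  # Lightning checkpoints nest weights here

prefixes = sorted({key.split(".")[0] for key in state})
print(prefixes)
# Expect groups such as 'model' (the U-Net), 'first_stage_model' (the autoencoder)
# and 'cond_stage_model' (the CLIP *text* encoder). The mismatches reported by the
# strict=False load in main.py should sit under this last group.
```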
9 changes: 6 additions & 3 deletions requirements.txt
@@ -1,17 +1,20 @@
albumentations==0.4.3
opencv-python
opencv-python==4.5.5.64
pudb==2019.2
imageio==2.9.0
imageio-ffmpeg==0.4.2
pytorch-lightning==1.4.2
torchmetrics==0.6
omegaconf==2.1.1
test-tube>=0.7.5
streamlit>=0.73.1
einops==0.3.0
torch-fidelity==0.3.0
transformers==4.19.2
transformers
kornia==0.6
webdataset==0.2.5
torchmetrics==0.6.0
fire==0.4.0
gradio==3.2
-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-e git+https://github.com/openai/CLIP.git@main#egg=clip
-e .
