Replace most print()s with logging calls
This is a rebase of #42 after it was reverted,
and includes the fix from #64.
akx committed Jul 27, 2023
1 parent 45c443b commit b7384ca
Showing 14 changed files with 135 additions and 100 deletions.
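
The pattern applied across every file below is the standard-library one: each module creates its own logger via logging.getLogger(__name__) and emits messages at a suitable level (debug/info/warning) instead of printing to stdout. As a minimal sketch of how a downstream script might opt in to this output (not part of the commit; the format string and chosen levels are illustrative assumptions):

import logging

# Route INFO-and-above records from all loggers to stderr.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)

# Optionally enable the DEBUG-level messages from the sgm package alone,
# e.g. "Preparing datasets" or "Setting up LambdaLR scheduler...".
logging.getLogger("sgm").setLevel(logging.DEBUG)

Because the library modules only create loggers and never configure them, the calling application keeps full control over verbosity and destination, which is the main benefit over the removed print() calls.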
26 changes: 13 additions & 13 deletions sgm/data/dataset.py
@@ -1,20 +1,22 @@
+import logging
 from typing import Optional
 
 import torchdata.datapipes.iter
 import webdataset as wds
 from omegaconf import DictConfig
 from pytorch_lightning import LightningDataModule
 
+logger = logging.getLogger(__name__)
+
 try:
     from sdata import create_dataset, create_dummy_dataset, create_loader
 except ImportError as e:
-    print("#" * 100)
-    print("Datasets not yet available")
-    print("to enable, we need to add stable-datasets as a submodule")
-    print("please use ``git submodule update --init --recursive``")
-    print("and do ``pip install -e stable-datasets/`` from the root of this repo")
-    print("#" * 100)
-    exit(1)
+    raise NotImplementedError(
+        "Datasets not yet available. "
+        "To enable, we need to add stable-datasets as a submodule; "
+        "please use ``git submodule update --init --recursive`` "
+        "and do ``pip install -e stable-datasets/`` from the root of this repo"
+    ) from e
 
 
 class StableDataModuleFromConfig(LightningDataModule):
@@ -39,8 +41,8 @@ def __init__(
                 "datapipeline" in self.val_config and "loader" in self.val_config
             ), "validation config requires the fields `datapipeline` and `loader`"
         else:
-            print(
-                "Warning: No Validation datapipeline defined, using that one from training"
+            logger.warning(
+                "No Validation datapipeline defined, using that one from training"
             )
             self.val_config = train
 
@@ -52,12 +54,10 @@ def __init__(
 
         self.dummy = dummy
         if self.dummy:
-            print("#" * 100)
-            print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
-            print("#" * 100)
+            logger.warning("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
 
     def setup(self, stage: str) -> None:
-        print("Preparing datasets")
+        logger.debug("Preparing datasets")
         if self.dummy:
             data_fn = create_dummy_dataset
         else:

2 changes: 1 addition & 1 deletion sgm/inference/helpers.py
@@ -229,7 +229,7 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
 
 def get_input_image_tensor(image: Image.Image, device="cuda"):
     w, h = image.size
-    print(f"loaded input image of size ({w}, {h})")
+    logger.info(f"loaded input image of size ({w}, {h})")
     width, height = map(
         lambda x: x - x % 64, (w, h)
     )  # resize to integer multiple of 64

31 changes: 16 additions & 15 deletions sgm/lr_scheduler.py
@@ -1,5 +1,9 @@
+import logging
+
 import numpy as np
 
+logger = logging.getLogger(__name__)
+
 
 class LambdaWarmUpCosineScheduler:
     """
@@ -24,9 +28,8 @@ def __init__(
         self.verbosity_interval = verbosity_interval
 
     def schedule(self, n, **kwargs):
-        if self.verbosity_interval > 0:
-            if n % self.verbosity_interval == 0:
-                print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
+            logger.info(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
         if n < self.lr_warm_up_steps:
             lr = (
                 self.lr_max - self.lr_start
@@ -83,12 +86,11 @@ def find_in_interval(self, n):
     def schedule(self, n, **kwargs):
         cycle = self.find_in_interval(n)
         n = n - self.cum_cycles[cycle]
-        if self.verbosity_interval > 0:
-            if n % self.verbosity_interval == 0:
-                print(
-                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
-                    f"current cycle {cycle}"
-                )
+        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
+            logger.info(
+                f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                f"current cycle {cycle}"
+            )
         if n < self.lr_warm_up_steps[cycle]:
             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
                 cycle
@@ -114,12 +116,11 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
     def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
         n = n - self.cum_cycles[cycle]
-        if self.verbosity_interval > 0:
-            if n % self.verbosity_interval == 0:
-                print(
-                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
-                    f"current cycle {cycle}"
-                )
+        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
+            logger.info(
+                f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                f"current cycle {cycle}"
+            )
 
         if n < self.lr_warm_up_steps[cycle]:
             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[

19 changes: 11 additions & 8 deletions sgm/models/autoencoder.py
@@ -1,3 +1,4 @@
+import logging
 import re
 from abc import abstractmethod
 from contextlib import contextmanager
@@ -14,6 +15,8 @@
 from ..modules.ema import LitEma
 from ..util import default, get_obj_from_str, instantiate_from_config
 
+logger = logging.getLogger(__name__)
+
 
 class AbstractAutoencoder(pl.LightningModule):
     """
@@ -38,7 +41,7 @@ def __init__(
 
         if self.use_ema:
             self.model_ema = LitEma(self, decay=ema_decay)
-            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+            logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
 
         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
@@ -60,16 +63,16 @@ def init_from_ckpt(
         for k in keys:
             for ik in ignore_keys:
                 if re.match(ik, k):
-                    print("Deleting key {} from state_dict.".format(k))
+                    logger.debug(f"Deleting key {k} from state_dict.")
                     del sd[k]
         missing, unexpected = self.load_state_dict(sd, strict=False)
-        print(
+        logger.debug(
             f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
         )
         if len(missing) > 0:
-            print(f"Missing Keys: {missing}")
+            logger.info(f"Missing Keys: {missing}")
         if len(unexpected) > 0:
-            print(f"Unexpected Keys: {unexpected}")
+            logger.info(f"Unexpected Keys: {unexpected}")
 
     @abstractmethod
     def get_input(self, batch) -> Any:
@@ -86,14 +89,14 @@ def ema_scope(self, context=None):
             self.model_ema.store(self.parameters())
             self.model_ema.copy_to(self)
             if context is not None:
-                print(f"{context}: Switched to EMA weights")
+                logger.info(f"{context}: Switched to EMA weights")
         try:
             yield None
         finally:
             if self.use_ema:
                 self.model_ema.restore(self.parameters())
                 if context is not None:
-                    print(f"{context}: Restored training weights")
+                    logger.info(f"{context}: Restored training weights")
 
     @abstractmethod
     def encode(self, *args, **kwargs) -> torch.Tensor:
@@ -104,7 +107,7 @@ def decode(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("decode()-method of abstract base class called")
 
     def instantiate_optimizer_from_config(self, params, lr, cfg):
-        print(f"loading >>> {cfg['target']} <<< optimizer from config")
+        logger.debug(f"loading >>> {cfg['target']} <<< optimizer from config")
         return get_obj_from_str(cfg["target"])(
             params, lr=lr, **cfg.get("params", dict())
         )

17 changes: 10 additions & 7 deletions sgm/models/diffusion.py
@@ -1,3 +1,4 @@
+import logging
 from contextlib import contextmanager
 from typing import Any, Dict, List, Tuple, Union
 
@@ -18,6 +19,8 @@
     log_txt_as_img,
 )
 
+logger = logging.getLogger(__name__)
+
 
 class DiffusionEngine(pl.LightningModule):
     def __init__(
@@ -73,7 +76,7 @@ def __init__(
         self.use_ema = use_ema
         if self.use_ema:
             self.model_ema = LitEma(self.model, decay=ema_decay_rate)
-            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+            logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
 
         self.scale_factor = scale_factor
         self.disable_first_stage_autocast = disable_first_stage_autocast
@@ -94,13 +97,13 @@ def init_from_ckpt(
             raise NotImplementedError
 
         missing, unexpected = self.load_state_dict(sd, strict=False)
-        print(
+        logger.info(
             f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
         )
         if len(missing) > 0:
-            print(f"Missing Keys: {missing}")
+            logger.info(f"Missing Keys: {missing}")
         if len(unexpected) > 0:
-            print(f"Unexpected Keys: {unexpected}")
+            logger.info(f"Unexpected Keys: {unexpected}")
 
     def _init_first_stage(self, config):
         model = instantiate_from_config(config).eval()
@@ -179,14 +182,14 @@ def ema_scope(self, context=None):
             self.model_ema.store(self.model.parameters())
             self.model_ema.copy_to(self.model)
             if context is not None:
-                print(f"{context}: Switched to EMA weights")
+                logger.info(f"{context}: Switched to EMA weights")
         try:
             yield None
         finally:
             if self.use_ema:
                 self.model_ema.restore(self.model.parameters())
                 if context is not None:
-                    print(f"{context}: Restored training weights")
+                    logger.info(f"{context}: Restored training weights")
 
     def instantiate_optimizer_from_config(self, params, lr, cfg):
         return get_obj_from_str(cfg["target"])(
@@ -202,7 +205,7 @@ def configure_optimizers(self):
         opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
         if self.scheduler_config is not None:
             scheduler = instantiate_from_config(self.scheduler_config)
-            print("Setting up LambdaLR scheduler...")
+            logger.debug("Setting up LambdaLR scheduler...")
             scheduler = [
                 {
                     "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),

38 changes: 21 additions & 17 deletions sgm/modules/attention.py
@@ -1,3 +1,4 @@
+import logging
 import math
 from inspect import isfunction
 from typing import Any, Optional
@@ -8,6 +9,10 @@
 from packaging import version
 from torch import nn
 
+
+logger = logging.getLogger(__name__)
+
+
 if version.parse(torch.__version__) >= version.parse("2.0.0"):
     SDP_IS_AVAILABLE = True
     from torch.backends.cuda import SDPBackend, sdp_kernel
@@ -36,9 +41,9 @@
     SDP_IS_AVAILABLE = False
     sdp_kernel = nullcontext
     BACKEND_MAP = {}
-    print(
-        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
-        f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
+    logger.warning(
+        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. "
+        f"In fact, you are using PyTorch {torch.__version__}. You might want to consider upgrading."
     )
 
 try:
@@ -48,7 +53,7 @@
     XFORMERS_IS_AVAILABLE = True
 except:
     XFORMERS_IS_AVAILABLE = False
-    print("no module 'xformers'. Processing without...")
+    logger.debug("no module 'xformers'. Processing without...")
 
 from .diffusionmodules.util import checkpoint
 
@@ -289,7 +294,7 @@ def __init__(
         self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
     ):
         super().__init__()
-        print(
+        logger.info(
             f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
             f"{heads} heads with a dimension of {dim_head}."
         )
@@ -393,22 +398,21 @@ def __init__(
         super().__init__()
         assert attn_mode in self.ATTENTION_MODES
         if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
-            print(
+            logger.warning(
                 f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
                 f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
             )
             attn_mode = "softmax"
         elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
-            print(
+            logger.warning(
                 "We do not support vanilla attention anymore, as it is too expensive. Sorry."
             )
             if not XFORMERS_IS_AVAILABLE:
-                assert (
-                    False
-                ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
-            else:
-                print("Falling back to xformers efficient attention.")
-                attn_mode = "softmax-xformers"
+                raise NotImplementedError(
+                    "Please install xformers via e.g. 'pip install xformers==0.0.16'"
+                )
+            logger.info("Falling back to xformers efficient attention.")
+            attn_mode = "softmax-xformers"
         attn_cls = self.ATTENTION_MODES[attn_mode]
         if version.parse(torch.__version__) >= version.parse("2.0.0"):
             assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
@@ -437,7 +441,7 @@ def __init__(
         self.norm3 = nn.LayerNorm(dim)
         self.checkpoint = checkpoint
         if self.checkpoint:
-            print(f"{self.__class__.__name__} is using checkpointing")
+            logger.info(f"{self.__class__.__name__} is using checkpointing")
 
     def forward(
         self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
@@ -554,7 +558,7 @@ def __init__(
         sdp_backend=None,
     ):
         super().__init__()
-        print(
+        logger.debug(
             f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
         )
         from omegaconf import ListConfig
@@ -563,8 +567,8 @@
             context_dim = [context_dim]
         if exists(context_dim) and isinstance(context_dim, list):
             if depth != len(context_dim):
-                print(
-                    f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
+                logger.warning(
+                    f"{self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
                     f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
                 )
                 # depth does not match context dims.

6 changes: 5 additions & 1 deletion sgm/modules/autoencoding/losses/__init__.py
@@ -1,3 +1,4 @@
+import logging
 from typing import Any, Union
 
 import torch
@@ -10,6 +11,9 @@
 from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
 
 
+logger = logging.getLogger(__name__)
+
+
 def adopt_weight(weight, global_step, threshold=0, value=0.0):
     if global_step < threshold:
         weight = value
@@ -104,7 +108,7 @@ def __init__(
         super().__init__()
         self.dims = dims
         if self.dims > 2:
-            print(
+            logger.info(
                 f"running with dims={dims}. This means that for perceptual loss calculation, "
                 f"the LPIPS loss will be applied to each frame independently. "
             )

(Diff truncated: the remaining 7 changed files are not shown here.)
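
Since these messages now flow through the logging framework, they can also be asserted in tests rather than captured from stdout. A hedged sketch, not from this repository's test suite; the logger name and message are taken from the sgm/data/dataset.py hunk above, and the test emits the record directly as a stand-in for constructing the data module:

import logging
import unittest


class DummyDatasetWarningTest(unittest.TestCase):
    def test_dummy_dataset_warning_is_logged(self):
        logger = logging.getLogger("sgm.data.dataset")
        # assertLogs captures records emitted under the given logger name.
        with self.assertLogs("sgm.data.dataset", level="WARNING") as captured:
            logger.warning("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
        self.assertIn("USING DUMMY DATASET", captured.output[0])


if __name__ == "__main__":
    unittest.main()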
