Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move training-only dependencies to [train] extra #60

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ TODO: write this section

## Installation

`k-diffusion` can be installed via PyPI (`pip install k-diffusion`) but it will not include training and inference scripts, only library code that others can depend on. To run the training and inference scripts, clone this repository and run `pip install -e <path to repository>`.
`k-diffusion` can be installed via PyPI (`pip install k-diffusion`) but it will not include training and inference scripts, only library code that others can depend on.

To run the training and inference scripts, clone this repository and run `pip install -e <path to repository>[train]`
(to install with the `train` extra that includes additional libraries required for training).

## Training

Expand Down
11 changes: 8 additions & 3 deletions k_diffusion/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import operator

import numpy as np
from skimage import transform
try:
import skimage.transform as skt
except ImportError:
skt = None
import torch
from torch import nn

Expand Down Expand Up @@ -31,6 +34,8 @@ def rotate2d(theta):

class KarrasAugmentationPipeline:
def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8, disable_all=False):
if not skt:
raise ImportError('Please install scikit-image to use KarrasAugmentationPipeline')
self.a_prob = a_prob
self.a_scale = a_scale
self.a_aniso = a_aniso
Expand Down Expand Up @@ -78,9 +83,9 @@ def __call__(self, image):
image_orig = np.array(image, dtype=np.float32) / 255
if image_orig.ndim == 2:
image_orig = image_orig[..., None]
tf = transform.AffineTransform(mat.numpy())
tf = skt.AffineTransform(mat.numpy())
if not self.disable_all:
image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True)
image = skt.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True)
else:
image = image_orig
cond = torch.zeros_like(cond)
Expand Down
5 changes: 4 additions & 1 deletion k_diffusion/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
from pathlib import Path

from cleanfid.inception_torchscript import InceptionV3W
import clip
import torch
from torch import nn
Expand All @@ -16,6 +15,10 @@
class InceptionV3FeatureExtractor(nn.Module):
def __init__(self, device='cpu'):
super().__init__()
try:
from cleanfid.inception_torchscript import InceptionV3W
except ImportError as ie:
raise ImportError('Please install clean-fid to use InceptionV3FeatureExtractor') from ie
path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion'
url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4'
Expand Down
17 changes: 0 additions & 17 deletions requirements.txt

This file was deleted.

9 changes: 6 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,24 @@ license = MIT

[options]
packages = find:
install_requires =
install_requires =
accelerate
clean-fid
clip-anytorch
dctorch
einops
jsonmerge
kornia
Pillow
safetensors
scikit-image
scipy
torch >= 2.1
torchdiffeq
torchsde
torchvision
tqdm

[options.extras_require]
train =
clean-fid
scikit-image
wandb