Commit 0ca3146

improve: support new open clip models
1 parent 5599547 commit 0ca3146

File tree: 4 files changed, +69 -26 lines

perceptor/losses/open_clip.py (+19 -11)
@@ -5,30 +5,38 @@
 
 
 class OpenCLIP(LossInterface):
-    def __init__(self, architecture="ViT-B-32", weights="laion2b_e16"):
+    def __init__(
+        self,
+        architecture="ViT-H-14",
+        weights="laion2b_s32b_b79k",
+    ):
         """
         Args:
             architecture (str): name of the clip model
             weights (str): name of the weights
 
         Available weight/model combinations are (in order of relevance):
-        - ("ViT-B-32", "laion2b_e16") (65.62%)
-        - ("ViT-B-16-plus-240", "laion400m_e32") (69.21%)
-        - ("ViT-B-16", "laion400m_e32") (67.07%)
-        - ("ViT-B-32", "laion400m_e32") (62.96%)
-        - ("ViT-L-14", "laion400m_e32") (72.77%)
+        - ("ViT-H-14", "laion2b_s32b_b79k") (78.0%)
+        - ("ViT-g-14", "laion2b_s12b_b42k") (76.6%)
+        - ("ViT-L-14", "laion2b_s32b_b82k") (75.3%)
+        - ("ViT-B-32", "laion2b_s34b_b79k") (66.6%)
+        - ("ViT-B-16-plus-240", "laion400m_e32") (69.2%)
+        - ("ViT-B-32", "laion2b_e16") (65.7%)
+        - ("ViT-B-16", "laion400m_e32") (67.0%)
+        - ("ViT-B-32", "laion400m_e32") (62.9%)
+        - ("ViT-L-14", "laion400m_e32") (72.8%)
         - ("RN101", "yfcc15m") (34.8%)
         - ("RN50", "yfcc15m") (32.7%)
         - ("RN50", "cc12m") (36.45%)
-        - ("RN50-quickgelu", "openai")
+        - ("RN50-quickgelu", "openai") (59.6%)
         - ("RN101-quickgelu", "openai")
         - ("RN50x4", "openai")
         - ("RN50x16", "openai")
         - ("RN50x64", "openai")
-        - ("ViT-B-32-quickgelu", "openai")
-        - ("ViT-B-16", "openai")
-        - ("ViT-L-14", "openai")
-        - ("ViT-L-14-336", "openai")
+        - ("ViT-B-32-quickgelu", "openai") (63.3%)
+        - ("ViT-B-16", "openai") (68.3%)
+        - ("ViT-L-14", "openai") (75.6%)
+        - ("ViT-L-14-336", "openai") (76.6%)
         """
         super().__init__()
         self.architecture = architecture
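
A minimal usage sketch of the new default (the import path is assumed to mirror the file path; the class and argument names come from the diff above, everything else is illustrative):

    # hypothetical usage sketch, not part of the commit
    from perceptor.losses.open_clip import OpenCLIP

    loss = OpenCLIP()  # now defaults to ("ViT-H-14", "laion2b_s32b_b79k")
    # any combination from the docstring can still be passed explicitly
    small_loss = OpenCLIP(architecture="ViT-B-32", weights="laion400m_e32")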

perceptor/models/open_clip.py (+45 -11)
@@ -10,16 +10,24 @@
 @utils.cache
 class OpenCLIP(torch.nn.Module):
     def __init__(
-        self, architecture="ViT-B-32", weights="laion2b_e16", precision=None, jit=False
+        self,
+        architecture="ViT-L-14",
+        weights="laion2b_s32b_b82k",
+        precision=None,
+        jit=False,
     ):
         """
         Args:
             architecture (str): name of the clip model
             weights (str): name of the weights
 
         Available weight/model combinations are (in order of relevance):
-        - ("ViT-B-32", "laion2b_e16") (65.7%)
+        - ("ViT-H-14", "laion2b_s32b_b79k") (78.0%)
+        - ("ViT-g-14", "laion2b_s12b_b42k") (76.6%)
+        - ("ViT-L-14", "laion2b_s32b_b82k") (75.3%)
+        - ("ViT-B-32", "laion2b_s34b_b79k") (66.6%)
         - ("ViT-B-16-plus-240", "laion400m_e32") (69.2%)
+        - ("ViT-B-32", "laion2b_e16") (65.7%)
         - ("ViT-B-16", "laion400m_e32") (67.0%)
         - ("ViT-B-32", "laion400m_e32") (62.9%)
         - ("ViT-L-14", "laion400m_e32") (72.8%)
@@ -43,9 +51,10 @@ def __init__(
         if (architecture, weights) not in open_clip.list_pretrained():
             raise ValueError(f"Invalid architecture/weights: {architecture}/{weights}")
 
+        pretrained_cfg = open_clip.pretrained.get_pretrained_cfg(architecture, weights)
         weights_path = open_clip.pretrained.download_pretrained(
-            open_clip.pretrained.get_pretrained_url(architecture, weights),
-            root="models",
+            pretrained_cfg,
+            cache_dir="models",
         )
 
         # softmax on cpu does not support half precision
@@ -58,6 +67,7 @@ def __init__(
         else:
             precision = "fp32"
 
+        # hack: needed to specify path to weights
         if weights == "openai":
             self.model = open_clip.load_openai_model(
                 weights_path, start_device, jit=jit
@@ -73,12 +83,32 @@ def __init__(
                 jit=jit,
             ).eval()
 
+        # hack: since we specified the weights path instead of the model name the config isn't loaded right
+        setattr(
+            self.model.visual,
+            "image_mean",
+            pretrained_cfg.get(
+                "mean",
+                getattr(self.model.visual, "image_mean", None),
+            )
+            or (0.48145466, 0.4578275, 0.40821073),
+        )
+        setattr(
+            self.model.visual,
+            "image_std",
+            pretrained_cfg.get(
+                "std",
+                getattr(self.model.visual, "image_std", None),
+            )
+            or (0.26862954, 0.26130258, 0.27577711),
+        )
+
         if jit is False:
             self.model = self.model.requires_grad_(False)
 
         self.normalize = transforms.Normalize(
-            (0.48145466, 0.4578275, 0.40821073),
-            (0.26862954, 0.26130258, 0.27577711),
+            self.model.visual.image_mean,
+            self.model.visual.image_std,
         )
 
     def to(self, device):
@@ -90,6 +120,13 @@ def to(self, device):
     def device(self):
         return next(iter(self.parameters())).device
 
+    @property
+    def image_size(self):
+        if isinstance(self.model.visual.image_size, tuple):
+            return self.model.visual.image_size
+        else:
+            return (self.model.visual.image_size, self.model.visual.image_size)
+
     @torch.cuda.amp.autocast()
     def encode_texts(self, text_prompts, normalize=True):
         encodings = self.model.encode_text(
@@ -106,10 +143,7 @@ def encode_images(self, images, normalize=True):
             self.normalize(
                 resize(
                     images.to(self.device),
-                    out_shape=(
-                        self.model.visual.image_size,
-                        self.model.visual.image_size,
-                    ),
+                    out_shape=self.image_size,
                 )
             )
         )
@@ -126,7 +160,7 @@ def forward(self, _):
 def test_open_clip():
     import torch
 
-    model = OpenCLIP("ViT-B-32", "laion2b_e16")
+    model = OpenCLIP()
 
     image = torch.randn((1, 3, 256, 256)).requires_grad_()
     with torch.enable_grad():
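
The constructor changes above track the open-clip-torch 2.x pretrained API, where a weights tag resolves to a config dict carrying the checkpoint URL and, for the newer LAION-2B models, their normalization statistics. A standalone sketch of that flow under the ^2.0.0 pin (only the calls that appear in the diff are used; the fallback values are the OpenAI CLIP mean/std, as in the diff):

    # sketch of the open-clip-torch 2.x download path used by the constructor
    import open_clip

    architecture, weights = "ViT-H-14", "laion2b_s32b_b79k"
    assert (architecture, weights) in open_clip.list_pretrained()

    # resolve the tag to its pretrained config (url, mean, std, ...)
    pretrained_cfg = open_clip.pretrained.get_pretrained_cfg(architecture, weights)

    # download the checkpoint (or reuse a cached copy) under ./models
    weights_path = open_clip.pretrained.download_pretrained(pretrained_cfg, cache_dir="models")

    # per-model normalization, falling back to the OpenAI CLIP statistics
    mean = pretrained_cfg.get("mean") or (0.48145466, 0.4578275, 0.40821073)
    std = pretrained_cfg.get("std") or (0.26862954, 0.26130258, 0.27577711)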

poetry.lock (+3 -2)

Generated file; diff not rendered.

pyproject.toml (+2 -2)
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "perceptor"
-version = "0.5.11"
+version = "0.5.12"
 description = "Modular image generation library"
 authors = ["Richard Löwenström <[email protected]>", "dribnet"]
 readme = "README.md"
@@ -28,10 +28,10 @@ more-itertools = "^8.12.0"
 dill = "^0.3.4"
 ninja = "^1.10.2"
 lpips = "^0.1.4"
-open-clip-torch = "^1.3.0"
 pytorch-lantern = "^0.12.0"
 taming-transformers-rom1504 = "^0.0.6"
 diffusers = "^0.2.4"
+open-clip-torch = "^2.0.0"
 
 [tool.poetry.dev-dependencies]
 ipykernel = "^6.8.0"
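
The dependency bump is what makes the new checkpoints available in the first place; the ViT-H-14 and ViT-g-14 LAION-2B weights used as defaults above are, as far as I know, not listed by the open-clip-torch 1.x series. A quick sanity check, assuming the ^2.0.0 constraint has been resolved and installed:

    # verify that the pinned open-clip-torch exposes the new model/weight pairs
    import open_clip

    pretrained = open_clip.list_pretrained()
    assert ("ViT-H-14", "laion2b_s32b_b79k") in pretrained
    assert ("ViT-g-14", "laion2b_s12b_b42k") in pretrained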
