@utils.cache
class OpenCLIP(torch.nn.Module):
    def __init__(
-        self, architecture="ViT-B-32", weights="laion2b_e16", precision=None, jit=False
+        self,
+        architecture="ViT-L-14",
+        weights="laion2b_s32b_b82k",
+        precision=None,
+        jit=False,
    ):
        """
        Args:
            architecture (str): name of the clip model
            weights (str): name of the weights

        Available weight/model combinations are (in order of relevance):
-        - ("ViT-B-32", "laion2b_e16") (65.7%)
+        - ("ViT-H-14", "laion2b_s32b_b79k") (78.0%)
+        - ("ViT-g-14", "laion2b_s12b_b42k") (76.6%)
+        - ("ViT-L-14", "laion2b_s32b_b82k") (75.3%)
+        - ("ViT-B-32", "laion2b_s34b_b79k") (66.6%)
        - ("ViT-B-16-plus-240", "laion400m_e32") (69.2%)
+        - ("ViT-B-32", "laion2b_e16") (65.7%)
        - ("ViT-B-16", "laion400m_e32") (67.0%)
        - ("ViT-B-32", "laion400m_e32") (62.9%)
        - ("ViT-L-14", "laion400m_e32") (72.8%)
@@ -43,9 +51,10 @@ def __init__(
        if (architecture, weights) not in open_clip.list_pretrained():
            raise ValueError(f"Invalid architecture/weights: {architecture}/{weights}")

+        pretrained_cfg = open_clip.pretrained.get_pretrained_cfg(architecture, weights)
        weights_path = open_clip.pretrained.download_pretrained(
-            open_clip.pretrained.get_pretrained_url(architecture, weights),
-            root="models",
+            pretrained_cfg,
+            cache_dir="models",
        )

        # softmax on cpu does not support half precision
@@ -58,6 +67,7 @@ def __init__(
        else:
            precision = "fp32"

+        # hack: needed to specify path to weights
        if weights == "openai":
            self.model = open_clip.load_openai_model(
                weights_path, start_device, jit=jit
@@ -73,12 +83,32 @@ def __init__(
                jit=jit,
            ).eval()

+        # hack: since we specified the weights path instead of the model name the config isn't loaded right
+        setattr(
+            self.model.visual,
+            "image_mean",
+            pretrained_cfg.get(
+                "mean",
+                getattr(self.model.visual, "image_mean", None),
+            )
+            or (0.48145466, 0.4578275, 0.40821073),
+        )
+        setattr(
+            self.model.visual,
+            "image_std",
+            pretrained_cfg.get(
+                "std",
+                getattr(self.model.visual, "image_std", None),
+            )
+            or (0.26862954, 0.26130258, 0.27577711),
+        )
+
        if jit is False:
            self.model = self.model.requires_grad_(False)

        self.normalize = transforms.Normalize(
-            (0.48145466, 0.4578275, 0.40821073),
-            (0.26862954, 0.26130258, 0.27577711),
+            self.model.visual.image_mean,
+            self.model.visual.image_std,
        )

    def to(self, device):
@@ -90,6 +120,13 @@ def to(self, device):
    def device(self):
        return next(iter(self.parameters())).device

+    @property
+    def image_size(self):
+        if isinstance(self.model.visual.image_size, tuple):
+            return self.model.visual.image_size
+        else:
+            return (self.model.visual.image_size, self.model.visual.image_size)
+
    @torch.cuda.amp.autocast()
    def encode_texts(self, text_prompts, normalize=True):
        encodings = self.model.encode_text(
@@ -106,10 +143,7 @@ def encode_images(self, images, normalize=True):
            self.normalize(
                resize(
                    images.to(self.device),
-                    out_shape=(
-                        self.model.visual.image_size,
-                        self.model.visual.image_size,
-                    ),
+                    out_shape=self.image_size,
                )
            )
        )
@@ -126,7 +160,7 @@ def forward(self, _):
def test_open_clip():
    import torch

-    model = OpenCLIP("ViT-B-32", "laion2b_e16")
+    model = OpenCLIP()

    image = torch.randn((1, 3, 256, 256)).requires_grad_()
    with torch.enable_grad():
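
For reference, a minimal usage sketch of the class after this change. It assumes the file above is importable, that encode_texts accepts raw string prompts (tokenization happens outside the visible hunks), and the example prompts plus the similarity computation are illustrative rather than part of the commit.

import torch

# With the new defaults this loads ViT-L-14 / laion2b_s32b_b82k.
model = OpenCLIP()

# Inputs can be any resolution; encode_images resizes to model.image_size internally.
images = torch.randn((1, 3, 256, 256))
texts = ["a photo of a dog", "a photo of a cat"]  # hypothetical prompts

with torch.no_grad():
    image_encodings = model.encode_images(images)    # normalize=True by default
    text_encodings = model.encode_texts(texts)       # assumed to tokenize internally
    similarity = image_encodings @ text_encodings.T  # cosine similarity when both sides are normalized

print(model.image_size)   # e.g. (224, 224) for ViT-L-14
print(similarity.shape)   # torch.Size([1, 2])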