-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patheval_dinov2_models.py
123 lines (84 loc) · 3.85 KB
/
eval_dinov2_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# grouped & sorted imports
import argparse
import os
from PIL import ImageFile
from torch.utils.data import DataLoader
from dataset import ImageFolder
ImageFile.LOAD_TRUNCATED_IMAGES = True
import torch
from torchvision import transforms, datasets
# simplified functions
def save_img(img, file_dir):
    """Convert a (C, H, W) image tensor to a PIL image and save it to *file_dir*."""
    pil_image = transforms.ToPILImage()(img.cpu())
    pil_image.save(file_dir)
def get_dataloader(dataset, batch_size, shuffle=False):
    """Wrap *dataset* in a DataLoader with the given batch size.

    Shuffling is off by default, which suits deterministic evaluation runs.
    """
    loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)
    return loader
def dinov2_vit_base_patch14():
    """Fetch the DINOv2 ViT-B/14 linear-classifier model from torch.hub.

    NOTE(review): relies on the module-level DEVICE set in the __main__ block.
    """
    return torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_lc').to(DEVICE)
def dino_v2_vit_base_patch14_reg4():
    """Fetch the DINOv2 ViT-B/14 (4 register tokens) linear-classifier model from torch.hub.

    NOTE(review): relies on the module-level DEVICE set in the __main__ block.
    """
    return torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg_lc').to(DEVICE)
def dinov2_vit_small_patch14():
    """Fetch the DINOv2 ViT-S/14 linear-classifier model from torch.hub.

    NOTE(review): relies on the module-level DEVICE set in the __main__ block.
    """
    return torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_lc').to(DEVICE)
def dino_v2_vit_small_patch14_reg4():
    """Fetch the DINOv2 ViT-S/14 (4 register tokens) linear-classifier model from torch.hub.

    NOTE(review): relies on the module-level DEVICE set in the __main__ block.
    """
    return torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg_lc').to(DEVICE)
def dinov2_vit_large_patch14():
    """Fetch the DINOv2 ViT-L/14 linear-classifier model from torch.hub.

    NOTE(review): relies on the module-level DEVICE set in the __main__ block.
    """
    return torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_lc').to(DEVICE)
def dino_v2_vit_large_patch14_reg4():
    """Fetch the DINOv2 ViT-L/14 (4 register tokens) linear-classifier model from torch.hub.

    NOTE(review): relies on the module-level DEVICE set in the __main__ block.
    """
    return torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg_lc').to(DEVICE)
if __name__ == "__main__":
args = argparse.ArgumentParser()
args.add_argument("--dataset_path", default=r'/path/to/dataset', type=str)
args.add_argument("--batch_size", default=64, type=int)
args.add_argument("--save_dir", default="./DINOv2_Results", type=str)
args = args.parse_args()
transform = transforms.Compose([
transforms.Resize(size=518, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(size=(518, 518)),
transforms.ToTensor(),
transforms.Normalize(mean=torch.tensor([0.4850, 0.4560, 0.4060]), std=torch.tensor([0.2290, 0.2240, 0.2250]))
])
# dataset and model creation
dataset = ImageFolder(args.dataset_path, transform)
dataloader = get_dataloader(dataset, args.batch_size)
os.makedirs(args.save_dir, exist_ok=True)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ALL_CLASSES = sorted(os.listdir(args.dataset_path))
# only on dinov2 models
models = [dinov2_vit_base_patch14, dino_v2_vit_base_patch14_reg4, dinov2_vit_small_patch14, dino_v2_vit_small_patch14_reg4,
dinov2_vit_large_patch14,
dino_v2_vit_large_patch14_reg4]
names = ['dinov2_vit_base_patch14', 'dino_v2_vit_base_patch14_reg4', 'dinov2_vit_small_patch14', 'dino_v2_vit_small_patch14_reg4',
'dinov2_vit_large_patch14', 'dino_v2_vit_large_patch14_reg4']
accuracy_per_model = []
with torch.no_grad():
for i, model in enumerate(models):
model = model()
correct = 0
total = 0
for batch_number, (images, labels) in enumerate(dataloader):
images = images.to(DEVICE)
labels = labels.to(DEVICE)
logits = model(images)
_, predicted = torch.max(logits.data, 1)
correct += (predicted == labels).sum().item()
total += labels.size(0)
accuracy = correct / total
accuracy_per_model.append(accuracy)
print(f"Model {names[i]} TF: {transforms}")
with open(f'{args.save_dir}/dinov2_model_results.txt', 'a') as f:
f.write(f"Model {names[i]} Accuracy: {accuracy}\n")
average_accuracy = sum(accuracy_per_model) / len(models)
print(f"Average Accuracy: {average_accuracy}")
with open(f'{args.save_dir}/dinov2_model_results.txt', 'a') as f:
f.write(f"Average Accuracy: {average_accuracy}\n")