# train_classifier.py
# Import necessary modules and functions
from tools.classification import load_classification_data, train_classification_model
from model import Classifier  # import the custom model defined in model.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
# Check if CUDA is available and set PyTorch to use GPU or CPU accordingly
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Use load_classification_data from tools.classification to load our dataset
# This function is expected to return a set of data loaders and the number of classes in the dataset
dataloaders, num_classes = load_classification_data()
# Initialize our classifier model with the number of output classes equal to num_classes
model = Classifier(num_classes)  # this loads the small model by default
# model = Classifier(num_classes, backbone='dinov2_b')  # to load the base model
# model = Classifier(num_classes, backbone='dinov2_l')  # to load the large model
# model = Classifier(num_classes, backbone='dinov2_g')  # to load the giant (largest) model
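# Note: the backbone names above presumably map to the DINOv2 ViT-S/B/L/g
# variants; larger backbones tend to be more accurate but need more GPU memory.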
# Move the model to the device (GPU or CPU)
model.to(device)
# Set our loss function to Cross Entropy Loss, a common choice for classification problems
criterion = nn.CrossEntropyLoss()
# Initialize Stochastic Gradient Descent (SGD) as our optimizer
# Set the initial learning rate to 0.001 and momentum to 0.9
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Initialize a learning rate scheduler that reduces the learning rate when a metric stops improving
# Here we monitor the validation loss ('min' mode) with a patience of 7 epochs,
# i.e., the learning rate is reduced if the validation loss does not improve for 7 consecutive epochs
# (note: the verbose argument is deprecated in recent PyTorch releases and can be dropped)
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=7, verbose=True)
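# For reference, a minimal sketch of what train_classification_model presumably
# does each epoch (illustrative only; the real loop lives in tools.classification,
# and the 'train'/'val' loader keys and evaluate() helper are assumptions):
#
#     for inputs, labels in dataloaders['train']:
#         inputs, labels = inputs.to(device), labels.to(device)
#         optimizer.zero_grad()                       # clear accumulated gradients
#         loss = criterion(model(inputs), labels)     # forward pass + cross-entropy loss
#         loss.backward()                             # backpropagate
#         optimizer.step()                            # SGD update
#     val_loss = evaluate(model, dataloaders['val'], criterion, device)  # hypothetical helper
#     scheduler.step(val_loss)  # ReduceLROnPlateau steps on the monitored metric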
# Finally, use train_classification_model from tools.classification to train our model
# The model, dataloaders, loss function, optimizer, learning rate scheduler, and device are passed as arguments
model = train_classification_model(model, dataloaders, criterion, optimizer, scheduler, device)
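# Optionally persist the trained weights for later inference; the path below is
# illustrative, not part of the original pipeline:
# torch.save(model.state_dict(), "classifier_weights.pth")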