-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSVM.py
33 lines (27 loc) · 1013 Bytes
/
SVM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn as nn
import torch.nn.functional as F
class SVM(nn.Module):
    """
    Linear Support Vector Machine
    -----------------------------
    A thin ``nn.Module`` wrapper around a single ``nn.Linear`` layer that
    maps ``input_dim`` features to ``output_dim`` scores.

    The forward pass swaps dimensions 1 and 2 of its input before applying
    the linear layer, so the feature axis is expected at dimension 1.
    """

    def __init__(self, input_dim, output_dim):
        """Create the linear layer.

        Args:
            input_dim: Number of input features consumed by the linear layer.
            output_dim: Number of output scores produced per position.
        """
        super().__init__()  # Register this object as an nn.Module
        self.fully_connected = nn.Linear(
            in_features=input_dim,
            out_features=output_dim,
        )

    def forward(self, x):
        """Apply the linear layer to ``x`` with dimensions 1 and 2 swapped.

        Args:
            x: Input tensor with at least three dimensions.
                NOTE(review): presumably (batch, input_dim, length), which
                would give an output of (batch, length, output_dim) --
                confirm against the caller.

        Returns:
            The output of the linear layer.
        """
        # nn.Linear acts on the last axis, so move the axis at dim 1 there.
        swapped = torch.transpose(x, 1, 2)
        return self.fully_connected(swapped)
class HingeLoss(nn.Module):
    """Mean hinge loss for an SVM: ``mean(relu(1 - y_true * y_pred))``.

    This is not ``nn.HingeEmbeddingLoss``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        """Compute the mean hinge loss over every element.

        Args:
            y_pred: Tensor of predicted scores.
            y_true: Tensor of target labels, either with the same number of
                dimensions as ``y_pred`` or without its trailing dimension.
                NOTE(review): presumably labels are encoded as -1/+1, as is
                usual for hinge loss -- confirm against the caller.

        Returns:
            A zero-dimensional tensor holding the mean of
            ``relu(1 - y_true * y_pred)``.
        """
        # Add the trailing axis only when the labels do not already have the
        # rank of the predictions. Unsqueezing same-rank labels would make
        # the product broadcast into a pairwise (N, N, ...) tensor and
        # silently yield a wrong loss.
        if y_true.dim() != y_pred.dim():
            y_true = y_true.unsqueeze(-1)
        margin_violation = F.relu(1 - y_true * y_pred)
        return torch.mean(margin_violation)