-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathaugmentations.py
34 lines (28 loc) · 1.28 KB
/
augmentations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.transforms.autoaugment import AutoAugmentPolicy
def GetAugment(*, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
               crop_size=32, padding=4):
    """Build the three image preprocessing pipelines defined by this module.

    Called with no arguments, the returned pipelines are configured exactly
    as before; the keyword-only parameters only exist so the same recipes
    can be reused with different statistics or crop geometry.

    Args:
        mean: Per-channel means passed to ``transforms.Normalize``.
        std: Per-channel standard deviations passed to
            ``transforms.Normalize``.
        crop_size: Output size passed to ``transforms.RandomCrop``.
        padding: Padding passed to ``transforms.RandomCrop``.

    Returns:
        A 3-tuple ``(plain, baseline, autoaugment)`` of
        ``transforms.Compose`` objects:

        * ``plain`` -- ``ToTensor`` followed by ``Normalize`` only.
        * ``baseline`` -- random horizontal flip and random crop, then
          ``ToTensor``/``Normalize``, then ``RandomErasing``.
        * ``autoaugment`` -- same as ``baseline`` with an ``AutoAugment``
          step (CIFAR10 policy) inserted before ``ToTensor``.
    """
    # NOTE(review): the default mean/std match the widely used ImageNet
    # channel statistics, while the 32-pixel crop and the CIFAR10 policy
    # suggest CIFAR-sized inputs -- confirm these defaults are intended.

    def _flip_and_crop():
        # Geometric augmentations shared by both training pipelines.
        return [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size, padding=padding),
        ]

    def _to_normalized_tensor():
        # Fresh transform objects (and fresh lists) per pipeline so the
        # returned Compose instances share no state with each other.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=list(mean), std=list(std)),
        ]

    plain = transforms.Compose(_to_normalized_tensor())

    baseline = transforms.Compose([
        *_flip_and_crop(),
        *_to_normalized_tensor(),
        # Kept last: it runs on the already converted and normalized tensor.
        transforms.RandomErasing(),
    ])

    autoaugment = transforms.Compose([
        *_flip_and_crop(),
        # Applied before ToTensor, i.e. on the image as loaded.
        transforms.AutoAugment(AutoAugmentPolicy.CIFAR10),
        *_to_normalized_tensor(),
        transforms.RandomErasing(),
    ])

    return plain, baseline, autoaugment