This repository was archived by the owner on Aug 18, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 234
/
Copy pathtrain_imagenette.py
77 lines (69 loc) · 3.56 KB
/
train_imagenette.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from fastai2.basics import *
from fastai2.vision.all import *
from fastai2.callback.all import *
from fastai2.distributed import *
from fastprogress import fastprogress
from torchvision.models import *
from fastai2.vision.models.xresnet import *
from fastai2.callback.mixup import *
from fastscript import *
torch.backends.cudnn.benchmark = True
fastprogress.MAX_COLS = 80
def get_dbunch(size, woof, bs, sh=0., workers=None):
    "Build the Imagenette/Imagewoof `DataBunch` for training at `size` px."
    # The pre-resized 320px archives are sufficient when training at <=224px.
    if woof: url = URLs.IMAGEWOOF_320   if size<=224 else URLs.IMAGEWOOF
    else:    url = URLs.IMAGENETTE_320  if size<=224 else URLs.IMAGENETTE
    source = untar_data(url)
    if workers is None: workers = min(8, num_cpus())
    # Per-item augmentation: random crop to `size` plus horizontal flip.
    crop_flip = [RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5)]
    # Batch-level random erasing is only enabled when a max proportion `sh` is given.
    erase = RandomErasing(p=0.9, max_count=3, sh=sh) if sh else None
    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                       splitter=GrandparentSplitter(valid_name='val'),
                       get_items=get_image_files, get_y=parent_label)
    return dblock.databunch(source, path=source, bs=bs, num_workers=workers,
                            item_tfms=crop_flip, batch_tfms=erase)
@call_parse
def main(
    gpu:   Param("GPU to run on", int)=None,
    woof:  Param("Use imagewoof (otherwise imagenette)", int)=0,
    lr:    Param("Learning rate", float)=1e-2,
    size:  Param("Size (px: 128,192,256)", int)=128,
    sqrmom:Param("sqr_mom", float)=0.99,
    mom:   Param("Momentum", float)=0.9,
    eps:   Param("epsilon", float)=1e-6,
    epochs:Param("Number of epochs", int)=5,
    bs:    Param("Batch size", int)=64,
    mixup: Param("Mixup", float)=0.,
    opt:   Param("Optimizer (adam,rms,sgd,ranger)", str)='ranger',
    arch:  Param("Architecture", str)='xresnet50',
    sh:    Param("Random erase max proportion", float)=0.,
    sa:    Param("Self-attention", int)=0,
    sym:   Param("Symmetry for self-attention", int)=0,
    beta:  Param("SAdam softplus beta", float)=0.,
    act_fn:Param("Activation function", str)='MishJit',
    fp16:  Param("Use mixed precision training", int)=0,
    pool:  Param("Pooling method", str)='AvgPool',
    dump:  Param("Print model; don't train", int)=0,
    runs:  Param("Number of times to repeat training", int)=1,
    meta:  Param("Metadata (ignored)", str)='',
):
    "Distributed training of Imagenette."
    #gpu = setup_distrib(gpu)
    if gpu is not None: torch.cuda.set_device(gpu)
    # Map the optimizer name to a partially-applied factory. Fail early with a
    # clear message on an unknown name instead of a NameError at Learner creation.
    if   opt=='adam'  : opt_func = partial(Adam, mom=mom, sqr_mom=sqrmom, eps=eps)
    elif opt=='rms'   : opt_func = partial(RMSprop, sqr_mom=sqrmom)
    elif opt=='sgd'   : opt_func = partial(SGD, mom=mom)
    elif opt=='ranger': opt_func = partial(ranger, mom=mom, sqr_mom=sqrmom, eps=eps, beta=beta)
    else: raise ValueError(f"Unknown optimizer '{opt}'; expected one of: adam, rms, sgd, ranger")
    dbunch = get_dbunch(size, woof, bs, sh=sh)
    # Only print hyperparameters from the first process (gpu None or 0).
    if not gpu: print(f'lr: {lr}; size: {size}; sqrmom: {sqrmom}; mom: {mom}; eps: {eps}')
    # Resolve the architecture / activation / pooling names to the objects
    # imported at module level (e.g. xresnet50, MishJit, AvgPool).
    m,act_fn,pool = [globals()[o] for o in (arch,act_fn,pool)]
    for run in range(runs):
        print(f'Run: {run}')
        learn = Learner(dbunch, m(c_out=10, act_cls=act_fn, sa=sa, sym=sym, pool=pool), opt_func=opt_func, \
                metrics=[accuracy,top_k_accuracy], loss_func=LabelSmoothingCrossEntropy())
        if dump: print(learn.model); exit()
        if fp16: learn = learn.to_fp16()
        cbs = MixUp(mixup) if mixup else []
        #n_gpu = torch.cuda.device_count()
        #if gpu is None and n_gpu: learn.to_parallel()
        if num_distrib()>1: learn.to_distributed(gpu) # Requires `-m fastai.launch`
        learn.fit_flat_cos(epochs, lr, wd=1e-2, cbs=cbs)