Skip to content

Commit e741042

Browse files
authored
Merge pull request #125 from Deci-AI/feature/SG-18_RegSeg
Added RegSeg to codebase
2 parents f38fe06 + 69155a1 commit e741042

File tree

5 files changed

+514
-3
lines changed

5 files changed

+514
-3
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# RegSeg segmentation training example with Cityscapes dataset.
2+
# Reproduction of paper: Rethink Dilated Convolution for Real-time Semantic Segmentation.
3+
#
4+
# Usage RegSeg48:
5+
# python -m torch.distributed.launch --nproc_per_node=4 src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=regseg48_cityscapes
6+
#
7+
#
8+
# Validation mIoU - Cityscapes, training time:
9+
# RegSeg48: input-size: [1024, 2048] mIoU: 78.15 using 4 GeForce RTX 2080 Ti with DDP, ~2 minutes / epoch
10+
#
11+
# Official git repo:
12+
# https://github.com/RolandGao/RegSeg
13+
# Paper:
14+
# https://arxiv.org/pdf/2111.09957.pdf
15+
#
16+
#
17+
# Logs, tensorboards and network checkpoints:
18+
# s3://deci-pretrained-models/regseg48_cityscapes/
19+
#
20+
#
21+
# Learning rate and batch size parameters, using 4 GeForce RTX 2080 Ti with DDP:
22+
# RegSeg48: input-size: [1024, 2048] initial_lr: 0.02 batch-size: 4 * 4gpus = 16
23+
24+
defaults:
25+
- training_hyperparams: default_train_params
26+
- dataset_params: cityscapes_dataset_params
27+
28+
hydra:
29+
searchpath:
30+
- pkg://super_gradients.recipes
31+
32+
project_name: RegSeg
33+
architecture: regseg48
34+
experiment_name: ${architecture}_cityscapes
35+
multi_gpu: AUTO
36+
37+
arch_params:
38+
num_classes: 19
39+
sync_bn: True
40+
strict_load: no_key_matching
41+
42+
dataset_params:
43+
_convert_: all
44+
batch_size: 4
45+
val_batch_size: 4
46+
crop_size: 1024
47+
img_size: 1024
48+
random_scales:
49+
- 0.4
50+
- 1.6
51+
image_mask_transforms_aug:
52+
Compose:
53+
transforms:
54+
- ColorJitterSeg:
55+
brightness: 0.1
56+
contrast: 0.1
57+
saturation: 0.1
58+
59+
- RandomFlipSeg
60+
61+
- RandomRescaleSeg:
62+
scales: ${dataset_params.random_scales}
63+
64+
- PadShortToCropSizeSeg:
65+
crop_size: ${dataset_params.crop_size}
66+
fill_image:
67+
- ${dataset_params.cityscapes_ignored_label}
68+
- 0
69+
- 0
70+
fill_mask: ${dataset_params.cityscapes_ignored_label}
71+
72+
- CropImageAndMaskSeg:
73+
crop_size: ${dataset_params.crop_size}
74+
mode: random
75+
76+
image_mask_transforms:
77+
Compose:
78+
transforms: [ ]
79+
80+
dataset_interface:
81+
cityscapes:
82+
dataset_params: ${dataset_params}
83+
84+
data_loader_num_workers: 8
85+
86+
model_checkpoints_location: local
87+
load_checkpoint: False
88+
89+
training_hyperparams:
90+
max_epochs: 800
91+
lr_mode: poly
92+
initial_lr: 0.02 # for effective batch_size=16
93+
lr_warmup_epochs: 0
94+
optimizer: SGD
95+
optimizer_params:
96+
momentum: 0.9
97+
weight_decay: 5e-4
98+
99+
ema: True
100+
101+
loss: cross_entropy
102+
criterion_params:
103+
ignore_index: ${dataset_params.cityscapes_ignored_label}
104+
105+
train_metrics_list:
106+
- PixelAccuracy:
107+
ignore_label: ${dataset_params.cityscapes_ignored_label}
108+
- IoU:
109+
num_classes: 20
110+
ignore_index: ${dataset_params.cityscapes_ignored_label}
111+
112+
valid_metrics_list:
113+
- PixelAccuracy:
114+
ignore_label: ${dataset_params.cityscapes_ignored_label}
115+
- IoU:
116+
num_classes: 20
117+
ignore_index: ${dataset_params.cityscapes_ignored_label}
118+
119+
metric_to_watch: IoU
120+
greater_metric_to_watch_is_better: True
121+
122+
_convert_: all

src/super_gradients/training/models/all_architectures.py

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from super_gradients.training.models.detection_models.yolov3 import YoloV3, TinyYoloV3
1414
from super_gradients.training.models.detection_models.yolov5 import YoLoV5N, YoLoV5S, YoLoV5M, YoLoV5L, YoLoV5X, Custom_YoLoV5
1515
from super_gradients.training.models.segmentation_models.ddrnet import DDRNet23, DDRNet23Slim, AnyBackBoneDDRNet23
16+
from super_gradients.training.models.segmentation_models.regseg import RegSeg48
1617
from super_gradients.training.models.segmentation_models.shelfnet import ShelfNet18_LW, ShelfNet34_LW, ShelfNet50, \
1718
ShelfNet503343, ShelfNet101
1819
from super_gradients.training.models.segmentation_models.stdc import STDC1Classification, STDC2Classification,\
@@ -105,4 +106,5 @@
105106
"stdc2_seg": STDC2Seg,
106107
"stdc2_seg50": STDC2Seg,
107108
"stdc2_seg75": STDC2Seg,
109+
"regseg48": RegSeg48,
108110
}

0 commit comments

Comments
 (0)