Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detection recipe enhancements #5715

Merged
merged 3 commits into from
Apr 1, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def main(args):
criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

if args.norm_weight_decay is None:
parameters = model.parameters()
parameters = [p for p in model.parameters() if p.requires_grad]
datumbox marked this conversation as resolved.
Show resolved Hide resolved
else:
param_groups = torchvision.ops._utils.split_normalization_params(model)
wd_groups = [args.norm_weight_decay, args.weight_decay]
Expand Down
23 changes: 22 additions & 1 deletion references/detection/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


class DetectionPresetTrain:
def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
datumbox marked this conversation as resolved.
Show resolved Hide resolved
if data_augmentation == "hflip":
self.transforms = T.Compose(
[
Expand All @@ -12,6 +12,27 @@ def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)
T.ConvertImageDtype(torch.float),
]
)
elif data_augmentation == "lsj":
self.transforms = T.Compose(
[
T.ScaleJitter(target_size=(1024, 1024)),
T.FixedSizeCrop(size=(1024, 1024), fill=mean),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
elif data_augmentation == "multiscale":
self.transforms = T.Compose(
[
T.RandomShortestSize(
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
datumbox marked this conversation as resolved.
Show resolved Hide resolved
elif data_augmentation == "ssd":
self.transforms = T.Compose(
[
Expand Down
25 changes: 23 additions & 2 deletions references/detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def get_args_parser(add_help=True):
parser.add_argument(
"-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"
)
parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
parser.add_argument(
"--lr",
default=0.02,
Expand All @@ -84,6 +85,12 @@ def get_args_parser(add_help=True):
help="weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument(
"--norm-weight-decay",
default=None,
type=float,
help="weight decay for Normalization layers (default: None, same value as --wd)",
)
datumbox marked this conversation as resolved.
Show resolved Hide resolved
parser.add_argument(
"--lr-scheduler", default="multisteplr", type=str, help="name of lr scheduler (default: multisteplr)"
)
Expand Down Expand Up @@ -176,6 +183,8 @@ def main(args):

print("Creating model")
kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
if args.data_augmentation in ["multiscale", "lsj"]:
kwargs["_skip_resize"] = True
datumbox marked this conversation as resolved.
Show resolved Hide resolved
if "rcnn" in args.model:
if args.rpn_score_thresh is not None:
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
Expand All @@ -191,8 +200,20 @@ def main(args):
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
if args.norm_weight_decay is None:
parameters = [p for p in model.parameters() if p.requires_grad]
else:
param_groups = torchvision.ops._utils.split_normalization_params(model)
wd_groups = [args.norm_weight_decay, args.weight_decay]
parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]

opt_name = args.opt.lower()
if opt_name.startswith("sgd"):
optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
elif opt_name == "adamw":
optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
else:
raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD and AdamW are supported.")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Straight copy-paste from classification.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since is copy-pasted I think it can stay as is and if we want to change we can do in a different PR, but still wonder if there is a reson why for sgd we use starts_with and for adamw we use ==?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right! I've indeed copy-pasted it from classification but replaced the previous sgd optimizer line. The problem is that the existing recipe didn't contain the nesterov momentum update. I've just updated the file to support it; it's not something I used so far but it's a simple update.


scaler = torch.cuda.amp.GradScaler() if args.amp else None

Expand Down
1 change: 0 additions & 1 deletion test/test_extended_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def test_get_weight(name, weight):
)
def test_naming_conventions(model_fn):
weights_enum = _get_model_weights(model_fn)
print(weights_enum)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

profit!

assert weights_enum is not None
assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")

Expand Down
3 changes: 2 additions & 1 deletion torchvision/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def __init__(
box_batch_size_per_image=512,
box_positive_fraction=0.25,
bbox_reg_weights=None,
**kwargs,
):

if not hasattr(backbone, "out_channels"):
Expand Down Expand Up @@ -268,7 +269,7 @@ def __init__(
image_mean = [0.485, 0.456, 0.406]
if image_std is None:
image_std = [0.229, 0.224, 0.225]
transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
datumbox marked this conversation as resolved.
Show resolved Hide resolved

super().__init__(backbone, rpn, roi_heads, transform)

Expand Down
3 changes: 2 additions & 1 deletion torchvision/models/detection/fcos.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ def __init__(
nms_thresh: float = 0.6,
detections_per_img: int = 100,
topk_candidates: int = 1000,
**kwargs,
):
super().__init__()
_log_api_usage_once(self)
Expand Down Expand Up @@ -410,7 +411,7 @@ def __init__(
image_mean = [0.485, 0.456, 0.406]
if image_std is None:
image_std = [0.229, 0.224, 0.225]
self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

self.center_sampling_radius = center_sampling_radius
self.score_thresh = score_thresh
Expand Down
2 changes: 2 additions & 0 deletions torchvision/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def __init__(
keypoint_head=None,
keypoint_predictor=None,
num_keypoints=None,
**kwargs,
):

if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
Expand Down Expand Up @@ -259,6 +260,7 @@ def __init__(
box_batch_size_per_image,
box_positive_fraction,
bbox_reg_weights,
**kwargs,
)

self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
Expand Down
2 changes: 2 additions & 0 deletions torchvision/models/detection/mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def __init__(
mask_roi_pool=None,
mask_head=None,
mask_predictor=None,
**kwargs,
):

if not isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None))):
Expand Down Expand Up @@ -254,6 +255,7 @@ def __init__(
box_batch_size_per_image,
box_positive_fraction,
bbox_reg_weights,
**kwargs,
)

self.roi_heads.mask_roi_pool = mask_roi_pool
Expand Down
3 changes: 2 additions & 1 deletion torchvision/models/detection/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ def __init__(
fg_iou_thresh=0.5,
bg_iou_thresh=0.4,
topk_candidates=1000,
**kwargs,
):
super().__init__()
_log_api_usage_once(self)
Expand Down Expand Up @@ -383,7 +384,7 @@ def __init__(
image_mean = [0.485, 0.456, 0.406]
if image_std is None:
image_std = [0.229, 0.224, 0.225]
self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
Expand Down
3 changes: 2 additions & 1 deletion torchvision/models/detection/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def __init__(
iou_thresh: float = 0.5,
topk_candidates: int = 400,
positive_fraction: float = 0.25,
**kwargs: Any,
):
super().__init__()
_log_api_usage_once(self)
Expand Down Expand Up @@ -227,7 +228,7 @@ def __init__(
if image_std is None:
image_std = [0.229, 0.224, 0.225]
self.transform = GeneralizedRCNNTransform(
min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size
min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size, **kwargs
)

self.score_thresh = score_thresh
Expand Down
6 changes: 5 additions & 1 deletion torchvision/models/detection/transform.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import List, Tuple, Dict, Optional
from typing import List, Tuple, Dict, Optional, Any

import torch
import torchvision
Expand Down Expand Up @@ -91,6 +91,7 @@ def __init__(
image_std: List[float],
size_divisible: int = 32,
fixed_size: Optional[Tuple[int, int]] = None,
**kwargs: Any,
):
super().__init__()
if not isinstance(min_size, (list, tuple)):
Expand All @@ -101,6 +102,7 @@ def __init__(
self.image_std = image_std
self.size_divisible = size_divisible
self.fixed_size = fixed_size
self._skip_resize = kwargs.pop("_skip_resize", False)

def forward(
self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
Expand Down Expand Up @@ -170,6 +172,8 @@ def resize(
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
h, w = image.shape[-2:]
if self.training:
if self._skip_resize:
return image, target
datumbox marked this conversation as resolved.
Show resolved Hide resolved
size = float(self.torch_choice(self.min_size))
else:
# FIXME assume for now that testing uses the largest scale
Expand Down
8 changes: 7 additions & 1 deletion torchvision/ops/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,13 @@ def split_normalization_params(
) -> Tuple[List[Tensor], List[Tensor]]:
# Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
if not norm_classes:
norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]
norm_classes = [
nn.modules.batchnorm._BatchNorm,
nn.LayerNorm,
nn.GroupNorm,
nn.modules.instancenorm._InstanceNorm,
nn.LocalResponseNorm,
]
datumbox marked this conversation as resolved.
Show resolved Hide resolved

for t in norm_classes:
if not issubclass(t, nn.Module):
Expand Down