Skip to content

Commit

Permalink
full tf concordance
Browse files Browse the repository at this point in the history
  • Loading branch information
sevakon committed Apr 10, 2020
1 parent e7a77e8 commit 4a02b85
Show file tree
Hide file tree
Showing 13 changed files with 163 additions and 93 deletions.
18 changes: 10 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,24 @@ https://arxiv.org/abs/1911.09070
As of the time I started working on this project, there was no PyTorch implementation on GitHub that would match the original paper in the number of the model's parameters.

### Model Zoo
| Model Name | Weights | Params | Params paper | mAP | mAP paper |
| Model Name | Weights | #params | #params paper | mAP | mAP paper |
| :----------: | :--------: | :-----------: | :--------: | :-----: | :-----: |
| D0 | coming soon | 3.875M | 3.9M | soon | 33.8 |
| D1 | coming soon | 6.618M | 6.6M | soon | 39.6 |
| D2 | coming soon | 8.086M | 8.1M | soon | 43.0 |
| D3 | coming soon | 12.01M | 12.0M | soon | 45.8 |
| D4 | coming soon | 20.694M | 20.7M | soon | 49.4 |
| D5 | coming soon | 33.615M | 33.7M | soon | 50.7 |
| D0 | coming soon | 3.878M | 3.9M | soon | 33.5 |
| D1 | coming soon | 6.622M | 6.6M | soon | 39.1 |
| D2 | coming soon | 8.091M | 8.1M | soon | 42.5 |
| D3 | coming soon | 12.022M | 12.0M | soon | 45.9 |
| D4 | coming soon | 20.708M | 20.7M | soon | 49.0 |
| D5 | coming soon | 33.633M | 33.7M | soon | 50.5 |


### RoadMap
- [X] Model Architecture that would match the original paper
- [ ] COCO train and val script
- [X] COCO train and val script
- [X] port weights from TensorFlow
- [ ] Reproduce results from the paper
- [ ] Pre-trained weights release

### References
- EfficientDet: Scalable and Efficient Object Detection [arXiv:1911.09070](https://arxiv.org/abs/1911.09070)
- EfficientDet implementation in TensorFlow by [Google AutoML](https://github.com/google/automl/tree/master/efficientdet)
- PyTorch EfficientNet implementation by [lukemelas](https://github.com/lukemelas/EfficientNet-PyTorch)
16 changes: 16 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch
from pathlib import Path

MODEL_NAME = 'efficientdet-d0'

BASE_PATH = Path('./')
DATA_PATH = BASE_PATH / 'data'
WEIGHTS_PATH = BASE_PATH / 'weights'
MODEL_WEIGHTS = WEIGHTS_PATH / '{}.pth'.format(MODEL_NAME)

ASPECT_RATIOS = [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]
NUM_SCALES = 3
ANCHOR_SCALE = 4.0

NUM_ANCHORS = len(ASPECT_RATIOS) * NUM_SCALES
NUM_CLASSES = 90
1 change: 1 addition & 0 deletions data/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/coco
1 change: 1 addition & 0 deletions model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from model.det import EfficientDet
55 changes: 43 additions & 12 deletions model/det.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from itertools import chain

import config as cfg
from model.backbone import EfficientNet
from model.bifpn import BiFPN
from model.utils import efficientdet_params, check_model_name
from model.head import Classifier, Regresser
from model.head import HeadNet
from model.module import ChannelAdjuster
from model.utils import efficientdet_params, check_model_name, download_model_weights


class EfficientDet(nn.Module):
def __init__(self, name):
super(EfficientDet, self).__init__()
check_model_name(name)

self.params = efficientdet_params(name)
self.backbone = EfficientNet(self.params['backbone'])

Expand All @@ -22,27 +24,56 @@ def __init__(self, name):
self.bifpn = nn.Sequential(*[BiFPN(self.params['W_bifpn'])
for _ in range(self.params['D_bifpn'])])

self.regresser = Regresser(self.params['W_bifpn'], self.params['D_class'])
self.classifier = Classifier(self.params['W_bifpn'], self.params['D_class'])
self.regresser = HeadNet(n_features=self.params['W_bifpn'],
out_channels=cfg.NUM_ANCHORS * 4,
n_repeats=self.params['D_class'])

self.classifier = HeadNet(n_features=self.params['W_bifpn'],
out_channels=cfg.NUM_ANCHORS * cfg.NUM_CLASSES,
n_repeats=self.params['D_class'])

def forward(self, x):
features = self.backbone(x)

features = self.adjuster(features)
features = self.bifpn(features)

box_outputs, cls_outputs = [], []
for f_map in features:
box_outputs.append(self.regresser(f_map))
cls_outputs.append(self.classifier(f_map))
box_outputs = self.regresser(features)
cls_outputs = self.classifier(features)

return box_outputs, cls_outputs

def initialize_weights(self):
""" Initialize Model Weights before training from scratch """
for module in chain(self.adjuster.modules(),
self.bifpn.modules(),
self.regresser.modules(),
self.classifier.modules()):
if isinstance(module, nn.Conv2d):
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
if isinstance(module, nn.BatchNorm2d):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)

nn.init.zeros_(self.regresser.head.conv_pw.bias)
nn.init.constant_(self.classifier.head.conv_pw.bias, -np.log((1 - 0.01) / 0.01))

def load_backbone(self, path):
self.backbone.load_state_dict(torch.load(path), strict=True)
self.backbone.model.load_state_dict(torch.load(path), strict=True)

def load_weights(self, path):
self.load_state_dict(torch.load(path))

@staticmethod
def load_from_name(name):
pass
def from_pretrained(name=cfg.MODEL_NAME):
check_model_name(name)

if not cfg.MODEL_WEIGHTS.exists():
download_model_weights(name, cfg.MODEL_WEIGHTS)

model_to_return = EfficientDet(name)
model_to_return.load_weights(cfg.MODEL_WEIGHTS)
return model_to_return


if __name__ == '__main__':
Expand Down
54 changes: 25 additions & 29 deletions model/head.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,37 @@
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from model.module import DepthWiseSeparableConvModule as DWSConv


class Regresser(nn.Module):
def __init__(self, n_features, n_repeats, n_anchors=9):
super(Regresser, self).__init__()
layers = [DWSConv(n_features, n_features) for _ in range(n_repeats)]
class HeadNet(nn.Module):
""" Box Regression and Classification Nets """
def __init__(self, n_features, out_channels, n_repeats):
super(HeadNet, self).__init__()
self.convs = nn.ModuleList()
self.bns = nn.ModuleList()

self.layers = nn.Sequential(*layers)
self.head = nn.Sequential(
nn.Conv2d(n_features, n_features, 3, padding=1, groups=n_features),
nn.Conv2d(n_features, n_anchors * 4, 1)
)

def forward(self, inputs):
inputs = self.layers(inputs)
inputs = self.head(inputs)
out = inputs
return out
for _ in range(n_repeats):
self.convs.append(DWSConv(n_features, n_features,
bath_norm=False, relu=False))
bn_levels = nn.ModuleList()
for _ in range(5):
bn = nn.BatchNorm2d(n_features, eps=1e-3, momentum=0.01)
bn_levels.append(bn)
self.bns.append(bn_levels)

self.head = DWSConv(n_features, out_channels, bath_norm=False, relu=False, bias=True)

class Classifier(nn.Module):
def __init__(self, n_features, n_repeats, n_anchors=9, n_classes=90):
super(Classifier, self).__init__()
layers = [DWSConv(n_features, n_features) for _ in range(n_repeats)]
def forward(self, inputs):
outs = []

self.layers = nn.Sequential(*layers)
self.head = nn.Sequential(
nn.Conv2d(n_features, n_features, 3, padding=1, groups=n_features),
nn.Conv2d(n_features, n_anchors * n_classes, 1)
)
for f_idx, f_map in enumerate(inputs):
for conv, bn in zip(self.convs, self.bns):
f_map = conv(f_map)
f_map = bn[f_idx](f_map)
f_map = F.relu(f_map)
outs.append(self.head(f_map))

def forward(self, inputs):
inputs = self.layers(inputs)
inputs = self.head(inputs)
out = inputs
return out
return outs
35 changes: 20 additions & 15 deletions model/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,37 @@ class ConvModule(nn.Module):
""" Regular Convolution with BatchNorm """
def __init__(self, in_channels, out_channels, kernel_size=1, padding=0):
super(ConvModule, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
padding=padding, bias=False),
nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.003)
)
self.conv = nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size,
padding=padding, bias=False)
self.bn = nn.BatchNorm2d(out_channels, eps=1e-3, momentum=0.01)

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x


class DepthWiseSeparableConvModule(nn.Module):
""" DepthWise Separable Convolution with BatchNorm and ReLU activation """
def __init__(self, in_channels, out_channels):
def __init__(self, in_channels, out_channels, bath_norm=True, relu=True, bias=False):
super(DepthWiseSeparableConvModule, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1,
padding=1, groups=in_channels, bias=False),
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1,
padding=0, bias=False),
nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.003)
)
self.conv_dw = nn.Conv2d(in_channels, in_channels, kernel_size=3,
padding=1, groups=in_channels, bias=False)
self.conv_pw = nn.Conv2d(in_channels, out_channels, kernel_size=1,
padding=0, bias=bias)

self.bn = None if bath_norm is False else \
nn.BatchNorm2d(out_channels, eps=1e-3, momentum=0.01)
self.relu = relu

def forward(self, x):
x = self.conv(x)
x = F.relu(x)
x = self.conv_dw(x)
x = self.conv_pw(x)
if self.bn is not None:
x = self.bn(x)
if self.relu:
x = F.relu(x)
return x


Expand Down
15 changes: 15 additions & 0 deletions model/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import numpy as np
from torch.hub import download_url_to_file


def efficientdet_params(model_name):
Expand Down Expand Up @@ -82,6 +83,8 @@ def efficientdet_params(model_name):

def check_model_name(model_name):
possibles = ['efficientdet-d' + str(i) for i in range(7)]
if model_name == 'efficientdet-d6':
raise ValueError('Sorry! EfficientDet D-6 is not yet supported :( ')
if model_name not in possibles:
raise ValueError('Name {} not in {}'.format(model_name, possibles))

Expand All @@ -90,3 +93,15 @@ def count_parameters(model):
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
return params


def download_model_weights(model_name, filename):
model_to_url = {
'efficientdet-d0': '',
'efficientdet-d1': '',
'efficientdet-d2': '',
'efficientdet-d3': '',
'efficientdet-d4': '',
'efficientdet-d5': ''
}
download_url_to_file(model_to_url[model_name], filename)
29 changes: 0 additions & 29 deletions test.py

This file was deleted.

31 changes: 31 additions & 0 deletions test_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch
from model import EfficientDet
import config as cfg
from model.utils import efficientdet_params, count_parameters


""" Quick test on parameters number """


model = EfficientDet.from_pretrained().to('cpu')

model.train()
params = count_parameters(model)

print('Model: {}, params: {:.6f}M, params in paper: {}'.format(cfg.MODEL_NAME, params / 1e6,
efficientdet_params(cfg.MODEL_NAME)['params']))
print(' Backbone: {:.6f}M'.format(count_parameters(model.backbone) / 1e6))
print(' Adjuster: {:.6f}M'.format(count_parameters(model.adjuster) / 1e6))
print(' BiFPN: {:.6f}M'.format(count_parameters(model.bifpn) / 1e6))
print(' Head: {:.6f}M'.format((count_parameters(model.classifier) +
count_parameters(model.regresser)) / 1e6))

# model.initialize_weights()

image_size = efficientdet_params(cfg.MODEL_NAME)['R_input']
x = torch.rand(1, 3, image_size, image_size)
box, cls = model(x)

for b, c in zip(box, cls):
print(b.shape)
print(c.shape)
Empty file added utils/__init__.py
Empty file.
Empty file added utils/loss.py
Empty file.
1 change: 1 addition & 0 deletions weights/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.pth

0 comments on commit 4a02b85

Please sign in to comment.