Add ResNet, AlexNet, and VGG model definitions and model zoo
Showing 5 changed files with 389 additions and 0 deletions.
__init__.py (+3 lines)
from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152, ResNet
from .alexnet import alexnet, AlexNet
from .vgg import vgg11, vgg13, vgg16, vgg19
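
The package __init__ re-exports the model constructors; note that the _bn VGG variants defined in vgg.py below are not re-exported here. A minimal usage sketch, assuming the package is importable as torchvision.models (the full package path is not shown in this diff):

    from torchvision.models import alexnet, resnet18, vgg16

    model = resnet18(pretrained=True)  # downloads and caches weights via model_zoo
    untrained = vgg16()                # vgg16 takes no pretrained flag in this commit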
alexnet.py (+47 lines)
import torch.nn as nn
from . import model_zoo


class AlexNet(nn.Container):
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x


def alexnet(pretrained=False):
    r"""AlexNet model architecture from the "One weird trick" paper.
    https://arxiv.org/abs/1404.5997
    """
    model = AlexNet()
    if pretrained:
        model.load_state_dict(model_zoo.load('alexnet'))
    return model
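
The hard-coded 256 * 6 * 6 flatten assumes 224x224 inputs: the stride-4 conv brings 224 down to 55, and the three stride-2 max-pools take it through 27 and 13 to a 6x6 feature map. A quick forward-pass sketch, using the Variable API that matches the nn.Container era of this code; the import path is an assumption:

    import torch
    from torch.autograd import Variable
    from torchvision.models import AlexNet  # assumed import path

    model = AlexNet()
    x = Variable(torch.randn(1, 3, 224, 224))  # one 224x224 RGB image
    print(model(x).size())                     # (1, 1000)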
model_zoo.py (+94 lines)
import torch

import hashlib
import os
import re
import shutil
import sys
import tempfile
if sys.version_info[0] == 2:
    from urlparse import urlparse
    from urllib2 import urlopen
else:
    from urllib.request import urlopen
    from urllib.parse import urlparse
try:
    from tqdm import tqdm
except ImportError:
    tqdm = None  # defined below


DEFAULT_MODEL_DIR = os.path.expanduser('~/.torch/models')

models = {
    'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth',
    'alexnet': 'https://s3.amazonaws.com/pytorch/models/alexnet-owt-4df8aa71.pth',
}

# matches bfd8deac from resnet18-bfd8deac.pth
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')


def load(model_name):
    r"""Returns the state_dict for the given model name"""
    return load_url(models[model_name])


def load_url(url, model_dir=None):
    if model_dir is None:
        model_dir = os.getenv('TORCH_MODEL_ZOO', DEFAULT_MODEL_DIR)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = HASH_REGEX.search(filename).group(1)
        download_url_to_file(url, cached_file, hash_prefix)
    return torch.load(cached_file)


def download_url_to_file(url, filename, hash_prefix):
    u = urlopen(url)
    meta = u.info()
    if hasattr(meta, 'getheaders'):
        file_size = int(meta.getheaders("Content-Length")[0])  # Python 2
    else:
        file_size = int(meta.get_all("Content-Length")[0])  # Python 3

    with tempfile.NamedTemporaryFile(delete=False) as f, tqdm(total=file_size) as pbar:
        while True:
            buffer = u.read(8192)
            if len(buffer) == 0:
                break
            f.write(buffer)
            pbar.update(len(buffer))

        f.seek(0)
        sha256 = hashlib.sha256(f.read()).hexdigest()
        # close explicitly before moving; an open file cannot be moved on Windows
        f.close()
        if sha256[:len(hash_prefix)] == hash_prefix:
            shutil.move(f.name, filename)
        else:
            raise RuntimeError('invalid hash value (expected "{}", got "{}")'
                               .format(hash_prefix, sha256))


if tqdm is None:
    # fake tqdm if it's not installed
    class tqdm(object):
        def __init__(self, total):
            self.total = total
            self.n = 0

        def update(self, n):
            self.n += n
            sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
            sys.stderr.flush()

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            sys.stderr.write('\n')
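
load_url caches downloads under ~/.torch/models (overridable via the TORCH_MODEL_ZOO environment variable) and checks a sha256 prefix embedded in the filename; a URL whose filename lacks the -<hash>. pattern would make HASH_REGEX.search return None, so load_url would raise an AttributeError before downloading. A usage sketch, assuming the module is importable as torchvision.models.model_zoo:

    from torchvision.models import model_zoo  # assumed import path

    state_dict = model_zoo.load('resnet18')  # resolves the URL from the models dict
    # or fetch any URL whose filename embeds a hash prefix, e.g. name-<sha prefix>.pth:
    state_dict = model_zoo.load_url(
        'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')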
resnet.py (+166 lines)
import torch.nn as nn
import math
from . import model_zoo


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Container):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Container):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Container):
    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18(pretrained=False):
    model = ResNet(BasicBlock, [2, 2, 2, 2])
    if pretrained:
        model.load_state_dict(model_zoo.load('resnet18'))
    return model


def resnet34():
    return ResNet(BasicBlock, [3, 4, 6, 3])


def resnet50():
    return ResNet(Bottleneck, [3, 4, 6, 3])


def resnet101():
    return ResNet(Bottleneck, [3, 4, 23, 3])


def resnet152():
    return ResNet(Bottleneck, [3, 8, 36, 3])
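
The layers list gives the block count per stage, and _make_layer adds a strided 1x1 downsample branch whenever the stride or channel width changes, so out += residual always sees matching shapes. Only resnet18 accepts a pretrained flag at this point; the other constructors return randomly initialized models. A sketch of the depth arithmetic behind the names, with an assumed import path:

    import torch
    from torch.autograd import Variable
    from torchvision.models import resnet50  # assumed import path

    # resnet18: 8 BasicBlocks x 2 convs, plus conv1 and fc = 18 weight layers
    # resnet50: 16 Bottlenecks x 3 convs, plus conv1 and fc = 50 weight layers
    model = resnet50()
    x = Variable(torch.randn(1, 3, 224, 224))  # AvgPool2d(7) assumes 224x224 inputs
    print(model(x).size())                     # (1, 1000)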
vgg.py (+79 lines)
import torch.nn as nn
from . import model_zoo


class VGG(nn.Container):
    def __init__(self, features):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Linear(4096, 1000),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg11():
    return VGG(make_layers(cfg['A']))


def vgg11_bn():
    return VGG(make_layers(cfg['A'], batch_norm=True))


def vgg13():
    return VGG(make_layers(cfg['B']))


def vgg13_bn():
    return VGG(make_layers(cfg['B'], batch_norm=True))


def vgg16():
    return VGG(make_layers(cfg['D']))


def vgg16_bn():
    return VGG(make_layers(cfg['D'], batch_norm=True))


def vgg19():
    return VGG(make_layers(cfg['E']))


def vgg19_bn():
    return VGG(make_layers(cfg['E'], batch_norm=True))
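
In the cfg table, each number is a conv layer's output channel count and 'M' marks a 2x2 max-pool, so the familiar depth names count the convs plus the three Linear layers in the classifier. A sketch of that arithmetic; the configuration letters A/B/D/E follow the VGG paper's table, which this diff does not spell out, and the import path is an assumption (the _bn constructors are not re-exported from __init__.py):

    from torchvision.models.vgg import cfg, vgg16_bn  # assumed import path

    n_conv = sum(1 for v in cfg['D'] if v != 'M')
    print(n_conv)       # 13 convs; 13 + 3 Linear layers = VGG-16
    model = vgg16_bn()  # _bn variants insert BatchNorm2d after every conv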