Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ResNet, AlexNet, and VGG model definitions and model zoo #23

Merged
merged 1 commit into from
Jan 9, 2017

Conversation

colesbury
Copy link
Member

You can create a ResNet-18 model pre-trained on ImageNet with:

import torchvision.models
resnet = torchvision.models.resnet18(pretrained=True)

@colesbury colesbury force-pushed the zoo branch 2 times, most recently from 236d645 to c96e794 Compare January 6, 2017 18:15
def __init__(self, num_classes=1000):
super(AlexNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

@fmassa
Copy link
Member

fmassa commented Jan 6, 2017

About the padding in alexnet, if I remember properly the original network does not contain padding.
I saw a padded version of alexnet in Fast R-CNN paper, but padding was added even to the pooling layers.
Is it something we want to add here?

@fmassa
Copy link
Member

fmassa commented Jan 6, 2017

I'm wondering, it might be better at some point to move the model_zoo.py to a different place.
The only part specific to vision is the model names and paths, and whenever we want to support models from other domains, we will need such a file.

@soumith
Copy link
Member

soumith commented Jan 6, 2017

good idea. torch.utils?

@colesbury
Copy link
Member Author

colesbury commented Jan 6, 2017

torch.zoo?

@colesbury
Copy link
Member Author

I also think the model zoo is a work-in-progress. There's still some unanswered questions and things I'd like to have:

  1. We don't have a good way of specifying how inputs should be transformed. For example, the ImageNet trained models expect images that have zero mean and unit std.
  2. It would be nice for users to be able to distribute models that aren't included directly as part of PyTorch. For example, you might want to do zoo.load("colesbury/resnet.pth"). This could be tied to GitHub, for example
  3. For (2) to work well, we'd probably need a good way to distribute model definitions (Python code) as well as the weights

@apaszke
Copy link
Contributor

apaszke commented Jan 6, 2017

I love the idea of using github names! I think github supports git lfs, so we only need to standardize the structure of the repo and it should be good.

We could just clone the repo in some local directory where it would be cached, and then load it from there.


def load_url(url, model_dir=None):
if model_dir is None:
model_dir = os.getenv('TORCH_MODEL_ZOO', DEFAULT_MODEL_DIR)

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

else:
file_size = int(meta.get_all("Content-Length")[0])

with tempfile.NamedTemporaryFile(delete=False) as f, tqdm(total=file_size) as pbar:

This comment was marked as off-topic.

return VGG(make_layers(cfg['A']))


def vgg11_bn():

This comment was marked as off-topic.

This comment was marked as off-topic.

@colesbury colesbury force-pushed the zoo branch 4 times, most recently from 8ef76ff to 5d69981 Compare January 9, 2017 18:39
@colesbury
Copy link
Member Author

Updated to use torch.utils.model_zoo

def resnet18(pretrained=False):
model = ResNet(BasicBlock, [2, 2, 2, 2])
if pretrained:
model.load_state_dict(model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth'))

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

@colesbury colesbury merged commit 7d150e8 into pytorch:master Jan 9, 2017
@colesbury colesbury deleted the zoo branch January 17, 2017 19:53
@normster
Copy link

Hi @colesbury, do you remember any more details on how these models (specifically resnet18) were trained? i.e. which year of ImageNet, final train/val accuracy, batch size.

Thank you!

@colesbury
Copy link
Member Author

The ResNet models were trained on ILSVRC2012. They were trained with the default options using https://github.com/pytorch/examples/tree/master/imagenet. (batch size 256, 90 epochs, lr=0.1 decay by 10 every 30 epochs). I think I used 4 GPUs, but it might have been 8.

Accuracy on the validation set is in the table here:
https://pytorch.org/docs/stable/torchvision/models.html

You can verify the accuracy on the validation set by running the example imagenet script with:

python main.py -a resnet18 --evaluate --pretrained [path to imagenet]

You can compute the accuracy on the training set by swapping your train and val directories and running the above command.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants