
[example] alexnet on imagenet #144

Merged · 2 commits · Sep 24, 2015
19 changes: 19 additions & 0 deletions example/imagenet/README.md
@@ -0,0 +1,19 @@
# Training Neural Networks on Imagenet

## Prepare Dataset

TODO

## Neural Networks

- [alexnet.py](alexnet.py) : alexnet with 5 convolution layers followed by 3
  fully connected layers

## Results

Machine: Dual Xeon E5-2680 2.8GHz, Dual GTX 980, Ubuntu 14.04, GCC 4.8, MKL, CUDA 7, CUDNN v3

| | val accuracy | 1 x GTX 980 | 2 x GTX 980 |
| --- | ---: | ---: | ---: |
| `alexnet.py` | ? | ? | 400 img/sec |
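
For scale: at the reported 400 img/sec on two GTX 980s, a single pass over the roughly 1.28M ILSVRC-2012 training images takes about 1,280,000 / 400 ≈ 3,200 s ≈ 53 minutes, so the 20 rounds configured in [alexnet.py](alexnet.py) correspond to roughly 18 hours of training.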
61 changes: 61 additions & 0 deletions example/imagenet/alexnet.py
@@ -0,0 +1,61 @@
# pylint: skip-file
from data import ilsvrc12_iterator
import mxnet as mx
import logging

## define alexnet
input_data = mx.symbol.Variable(name="data")
# stage 1: 11x11/4 conv, relu, 3x3/2 max-pool, LRN
conv1 = mx.symbol.Convolution(
    data=input_data, kernel=(11, 11), stride=(4, 4), num_filter=96)
relu1 = mx.symbol.Activation(data=conv1, act_type="relu")
pool1 = mx.symbol.Pooling(
    data=relu1, pool_type="max", kernel=(3, 3), stride=(2, 2))
lrn1 = mx.symbol.LRN(data=pool1, alpha=0.0001, beta=0.75, knorm=1, nsize=5)
# stage 2: 5x5 conv, relu, pool, LRN
conv2 = mx.symbol.Convolution(
    data=lrn1, kernel=(5, 5), pad=(2, 2), num_filter=256)
relu2 = mx.symbol.Activation(data=conv2, act_type="relu")
pool2 = mx.symbol.Pooling(data=relu2, kernel=(3, 3), stride=(2, 2))
lrn2 = mx.symbol.LRN(data=pool2, alpha=0.0001, beta=0.75, knorm=1, nsize=5)
# stage 3: three stacked 3x3 convolutions
conv3 = mx.symbol.Convolution(
    data=lrn2, kernel=(3, 3), pad=(1, 1), num_filter=384)
relu3 = mx.symbol.Activation(data=conv3, act_type="relu")
conv4 = mx.symbol.Convolution(
    data=relu3, kernel=(3, 3), pad=(1, 1), num_filter=384)
relu4 = mx.symbol.Activation(data=conv4, act_type="relu")
conv5 = mx.symbol.Convolution(
    data=relu4, kernel=(3, 3), pad=(1, 1), num_filter=256)
relu5 = mx.symbol.Activation(data=conv5, act_type="relu")
pool3 = mx.symbol.Pooling(data=relu5, kernel=(3, 3), stride=(2, 2))
# stage 4: flatten, then first fully connected layer with dropout
flatten = mx.symbol.Flatten(data=pool3)
fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=4096)
relu6 = mx.symbol.Activation(data=fc1, act_type="relu")
dropout1 = mx.symbol.Dropout(data=relu6, p=0.5)
# stage 5: second fully connected layer with dropout
fc2 = mx.symbol.FullyConnected(data=dropout1, num_hidden=4096)
relu7 = mx.symbol.Activation(data=fc2, act_type="relu")
dropout2 = mx.symbol.Dropout(data=relu7, p=0.5)
# stage 6: 1000-way softmax classifier
fc3 = mx.symbol.FullyConnected(data=dropout2, num_hidden=1000)
softmax = mx.symbol.Softmax(data=fc3)


## data
train, val = ilsvrc12_iterator(batch_size=256, input_shape=(3,224,224))

## train
num_gpus = 2
gpus = [mx.gpu(i) for i in range(num_gpus)]
model = mx.model.FeedForward(
    ctx = gpus,
    symbol = softmax,
    num_round = 20,
    learning_rate = 0.01,
    momentum = 0.9,
    wd = 0.00001)
logging.basicConfig(level = logging.DEBUG)
model.fit(X = train, eval_data = val,
          # Speedometer's first argument is the batch size used to compute
          # img/sec, so it must match the iterator's batch_size of 256
          epoch_end_callback = mx.callback.Speedometer(256))
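
A cheap sanity check on the network definition above is shape inference. A minimal sketch, assuming the `softmax` symbol built in alexnet.py and the same 256 × 3 × 224 × 224 input batch used for training:

```python
# Hedged sketch: infer output shapes for the alexnet symbol defined above.
arg_shapes, out_shapes, aux_shapes = softmax.infer_shape(data=(256, 3, 224, 224))
print(out_shapes)  # expect [(256, 1000)]: one class-probability vector per image
```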
28 changes: 28 additions & 0 deletions example/imagenet/data.py
@@ -0,0 +1,28 @@
# pylint: skip-file
"""Data iterator for ImageNet (ILSVRC12)."""
import sys
sys.path.insert(0, "../../python/")
import mxnet as mx

def ilsvrc12_iterator(batch_size, input_shape):
    """Return train and val iterators for ImageNet."""
    train_dataiter = mx.io.ImageRecordIter(
        path_imgrec = "data/ilsvrc12/train.rec",
        mean_img = "data/ilsvrc12/mean.bin",
        rand_crop = True,    # random crops: train-time augmentation
        rand_mirror = True,  # random horizontal flips
        prefetch_buffer = 4,
        preprocess_threads = 4,
        data_shape = input_shape,
        batch_size = batch_size)
    val_dataiter = mx.io.ImageRecordIter(
        path_imgrec = "data/ilsvrc12/val.rec",
        mean_img = "data/ilsvrc12/mean.bin",
        rand_crop = False,   # no random augmentation for evaluation
        rand_mirror = False,
        prefetch_buffer = 4,
        preprocess_threads = 4,
        data_shape = input_shape,
        batch_size = batch_size)

    return (train_dataiter, val_dataiter)
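
Before launching a multi-hour training run it is worth pulling one batch to confirm the record files are in place. A minimal smoke test, assuming the `data/ilsvrc12/*.rec` and `mean.bin` paths used above and the basic `DataIter` accessors (`reset`, `iter_next`, `getdata`, `getlabel`):

```python
from data import ilsvrc12_iterator

# Hedged sketch: fetch a single batch to validate paths and shapes.
train, _ = ilsvrc12_iterator(batch_size=32, input_shape=(3, 224, 224))
train.reset()
if train.iter_next():
    data = train.getdata()    # expect shape (32, 3, 224, 224)
    label = train.getlabel()  # expect shape (32,)
    print(data.shape, label.shape)
```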
2 changes: 1 addition & 1 deletion example/mnist/README.md
@@ -17,7 +17,7 @@ Cortes, and Christopher J.C. Burges.

Using a minibatch size of 100 and 20 passes over the data (not fine-tuned).

-Machine: Dual Xeon E5-2680 2.8GHz, Dual GTX 980, Ubuntu 14.04, GCC 4.8. Intel MKL, and CUDA 7.0
+Machine: Dual Xeon E5-2680 2.8GHz, Dual GTX 980, Ubuntu 14.04, GCC 4.8.

| | val accuracy | 2 x E5-2680 | 1 x GTX 980 | 2 x GTX 980 |
| --- | ---: | ---: | ---: | ---: |
6 changes: 0 additions & 6 deletions example/mnist/lenet.py
@@ -4,7 +4,6 @@
import logging

## define lenet
-
# input
data = mx.symbol.Variable('data')
# first conv
@@ -27,19 +26,14 @@
lenet = mx.symbol.Softmax(data=fc2)

## data
-
train, val = mnist_iterator(batch_size=100, input_shape=(1,28,28))

## train
-
logging.basicConfig(level=logging.DEBUG)
-
# dev = [mx.gpu(i) for i in range(2)]
dev = mx.gpu()
-
model = mx.model.FeedForward(
    ctx = dev, symbol = lenet, num_round = 20,
    learning_rate = 0.01, momentum = 0.9, wd = 0.00001)
-
model.fit(X=train, eval_data=val,
          epoch_end_callback=mx.callback.Speedometer(100))
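
The commented-out `dev` line above is the only change needed to go data-parallel. A small hedged sketch of the context values `mx.model.FeedForward` accepts for `ctx`:

```python
import mxnet as mx

# Equivalent context choices; pick one and pass it as ctx.
dev = mx.cpu()                       # CPU only
dev = mx.gpu()                       # single GPU, as in lenet.py
dev = [mx.gpu(i) for i in range(2)]  # data-parallel across two GPUs
```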