Commit

[Incompatible] refactor inferencer. ClassificationError now takes a vector

ppwwyyxx committed Nov 6, 2016
1 parent 148d7dd commit 740e9d8
Showing 17 changed files with 64 additions and 60 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -29,7 +29,7 @@ Describe your training task with three components:
 + Use Python to easily handle any of your own data format, yet still keep a good training speed thanks to multiprocess prefetch & TF Queue prefetch.
   For example, InceptionV3 can run at the same speed as the official code which reads data using TF operators.

-3. Callbacks, including everything you want to do apart from the training iterations. Such as:
+3. Callbacks, including everything you want to do apart from the training iterations, such as:
 + Change hyperparameters during training
 + Print some variables of interest
 + Run inference on a test dataset
@@ -49,7 +49,7 @@ Multi-GPU training is off-the-shelf by simply switching the trainer.
 pip install --user -r requirements.txt
 pip install --user -r opt-requirements.txt (some optional dependencies, you can install later if needed)
 ```
-+ Use [tcmalloc](http://goog-perftools.sourceforge.net/doc/tcmalloc.html) whenever possible
++ [tcmalloc](http://goog-perftools.sourceforge.net/doc/tcmalloc.html) usually helps.
 + Enable `import tensorpack`:
 ```
 export PYTHONPATH=$PYTHONPATH:`readlink -f path/to/tensorpack`
1 change: 0 additions & 1 deletion examples/DisturbLabel/mnist-disturb.py
@@ -40,7 +40,6 @@ def _build_graph(self, input_vars):
         prob = tf.nn.softmax(logits, name='prob')

         wrong = symbolic_functions.prediction_incorrect(logits, label)
-        nr_wrong = tf.reduce_sum(wrong, name='wrong')
         add_moving_summary(tf.reduce_mean(wrong, name='train_error'))

         cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
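Every example model in this commit receives the same mechanical change, so the pattern is worth spelling out once: the scalar count `nr_wrong` disappears, and the 0/1 per-sample error vector itself gets a stable name for the `ClassificationError` callback to fetch. A minimal sketch of the before and after (TF 0.x API as used throughout the repo; the placeholder shapes are only illustrative):

```python
import tensorflow as tf
from tensorpack.tfutils.symbolic_functions import prediction_incorrect

logits = tf.placeholder(tf.float32, [None, 10])  # NxC class scores
label = tf.placeholder(tf.int32, [None])         # N ground-truth labels

# Old style (removed by this commit): reduce to a scalar count named 'wrong',
# which ClassificationError('wrong:0') then fetched.
#   wrong = prediction_incorrect(logits, label)
#   nr_wrong = tf.reduce_sum(wrong, name='wrong')

# New style: name the 0/1 float vector itself. The callback sums it and
# divides by the number of samples it actually saw.
wrong = prediction_incorrect(logits, label, name='incorrect_vector')
train_error = tf.reduce_mean(wrong, name='train_error')  # summaries unchanged
```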
6 changes: 2 additions & 4 deletions examples/DoReFa-Net/alexnet-dorefa.py
@@ -142,11 +142,9 @@ def activate(x):
         cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')

-        wrong = prediction_incorrect(logits, label, 1)
-        nr_wrong = tf.reduce_sum(wrong, name='wrong-top1')
+        wrong = prediction_incorrect(logits, label, 1, name='wrong-top1')
         add_moving_summary(tf.reduce_mean(wrong, name='train-error-top1'))
-        wrong = prediction_incorrect(logits, label, 5)
-        nr_wrong = tf.reduce_sum(wrong, name='wrong-top5')
+        wrong = prediction_incorrect(logits, label, 5, name='wrong-top5')
         add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5'))

         # weight decay on all W of fc layers
1 change: 0 additions & 1 deletion examples/DoReFa-Net/svhn-digit-dorefa.py
@@ -109,7 +109,6 @@ def activate(x):

         # compute the number of failed samples
         wrong = prediction_incorrect(logits, label)
-        nr_wrong = tf.reduce_sum(wrong, name='wrong')
         # monitor training error
         add_moving_summary(tf.reduce_mean(wrong, name='train_error'))

6 changes: 2 additions & 4 deletions examples/Inception/inception-bn.py
@@ -102,12 +102,10 @@ def inception(name, x, nr1x1, nr3x3r, nr3x3, nr233r, nr233, nrpool, pooltype):
         cost = tf.add_n([loss3, 0.3 * loss2, 0.3 * loss1], name='weighted_cost')
         add_moving_summary([cost, loss1, loss2, loss3])

-        wrong = prediction_incorrect(logits, label, 1)
-        nr_wrong = tf.reduce_sum(wrong, name='wrong-top1')
+        wrong = prediction_incorrect(logits, label, 1, name='wrong-top1')
         add_moving_summary(tf.reduce_mean(wrong, name='train_error_top1'))

-        wrong = prediction_incorrect(logits, label, 5)
-        nr_wrong = tf.reduce_sum(wrong, name='wrong-top5')
+        wrong = prediction_incorrect(logits, label, 5, name='wrong-top5')
         add_moving_summary(tf.reduce_mean(wrong, name='train_error_top5'))

         # weight decay on all W of fc layers
6 changes: 2 additions & 4 deletions examples/Inception/inceptionv3.py
@@ -180,12 +180,10 @@ def proj_277(l, ch_r, ch):
         loss2 = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
         loss2 = tf.reduce_mean(loss2, name='loss2')

-        wrong = prediction_incorrect(logits, label, 1)
-        nr_wrong = tf.reduce_sum(wrong, name='wrong-top1')
+        wrong = prediction_incorrect(logits, label, 1, name='wrong-top1')
         add_moving_summary(tf.reduce_mean(wrong, name='train-error-top1'))

-        wrong = prediction_incorrect(logits, label, 5)
-        nr_wrong = tf.reduce_sum(wrong, name='wrong-top5')
+        wrong = prediction_incorrect(logits, label, 5, name='wrong-top5')
         add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5'))

         # weight decay on all W of fc layers
1 change: 0 additions & 1 deletion examples/ResNet/cifar10-resnet.py
@@ -101,7 +101,6 @@ def residual(name, l, increase_dim=False, first=False):
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')

         wrong = prediction_incorrect(logits, label)
-        nr_wrong = tf.reduce_sum(wrong, name='wrong')
         # monitor training error
         add_moving_summary(tf.reduce_mean(wrong, name='train_error'))

6 changes: 2 additions & 4 deletions examples/ResNet/imagenet-resnet.py
@@ -110,12 +110,10 @@ def layer(l, layername, block_func, features, count, stride, first=False):
         loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
         loss = tf.reduce_mean(loss, name='xentropy-loss')

-        wrong = prediction_incorrect(logits, label, 1)
-        nr_wrong = tf.reduce_sum(wrong, name='wrong-top1')
+        wrong = prediction_incorrect(logits, label, 1, name='wrong-top1')
         add_moving_summary(tf.reduce_mean(wrong, name='train-error-top1'))

-        wrong = prediction_incorrect(logits, label, 5)
-        nr_wrong = tf.reduce_sum(wrong, name='wrong-top5')
+        wrong = prediction_incorrect(logits, label, 5, name='wrong-top5')
         add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5'))

         # weight decay on all W of fc layers
1 change: 0 additions & 1 deletion examples/SpatialTransformer/mnist-addition.py
@@ -76,7 +76,6 @@ def get_stn(image):
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')

         wrong = symbolic_functions.prediction_incorrect(logits, label)
-        nr_wrong = tf.reduce_sum(wrong, name='wrong')
         summary.add_moving_summary(tf.reduce_mean(wrong, name='train_error'))

         wd_cost = tf.mul(1e-5, regularize_cost('fc.*/W', tf.nn.l2_loss),
2 changes: 0 additions & 2 deletions examples/cifar-convnet.py
@@ -59,9 +59,7 @@ def _build_graph(self, input_vars):
         cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')

-        # compute the number of failed samples, for ClassificationError to use at test time
         wrong = symbf.prediction_incorrect(logits, label)
-        nr_wrong = tf.reduce_sum(wrong, name='wrong')
         # monitor training error
         add_moving_summary(tf.reduce_mean(wrong, name='train_error'))

7 changes: 3 additions & 4 deletions examples/mnist-convnet.py
@@ -62,9 +62,8 @@ def _build_graph(self, input_vars):
         cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)  # a vector of length B with loss of each sample
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')  # the average cross-entropy loss

-        # compute the number of failed samples, for the callback ClassificationError to use at test time
-        wrong = symbolic_functions.prediction_incorrect(logits, label)
-        nr_wrong = tf.reduce_sum(wrong, name='wrong')
+        # compute the "incorrect vector", for the callback ClassificationError to use at validation time
+        wrong = symbolic_functions.prediction_incorrect(logits, label, name='incorrect')

         # This will monitor training error (in a moving_average fashion):
         # 1. write the value to tensorboard
@@ -117,7 +116,7 @@ def get_config():
             InferenceRunner(    # run inference(for validation) after every epoch
                 dataset_test,   # the DataFlow instance used for validation
                 # Calculate both the cost and the error for this DataFlow
-                [ScalarStats('cost'), ClassificationError() ]),
+                [ScalarStats('cost'), ClassificationError('incorrect')]),
         ]),
         model=Model(),
         step_per_epoch=step_per_epoch,
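This hunk shows the callback side of the same change: `ClassificationError` is now constructed with the name of the incorrect-vector tensor instead of defaulting to the old scalar `'wrong:0'`. A sketch of the validation wiring following this example's `get_config()` (here `dataset_test` stands for any DataFlow over the validation set, assumed defined elsewhere):

```python
from tensorpack.callbacks.inference import (
    InferenceRunner, ScalarStats, ClassificationError)

# Assumes the model named its error vector as above:
#   wrong = prediction_incorrect(logits, label, name='incorrect')
callbacks = [
    InferenceRunner(
        dataset_test,   # DataFlow used for validation
        # fetch the averaged cost, and the exact error over the whole set
        [ScalarStats('cost'), ClassificationError('incorrect')]),
]
```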
1 change: 0 additions & 1 deletion examples/svhn-digit-convnet.py
@@ -46,7 +46,6 @@ def _build_graph(self, input_vars):

         # compute the number of failed samples, for ClassificationError to use at test time
         wrong = prediction_incorrect(logits, label)
-        nr_wrong = tf.reduce_sum(wrong, name='wrong')
         # monitor training error
         add_moving_summary(tf.reduce_mean(wrong, name='train_error'))

66 changes: 40 additions & 26 deletions tensorpack/callbacks/inference.py
@@ -11,10 +11,11 @@
 from six.moves import zip, map

 from ..dataflow import DataFlow
-from ..utils import get_tqdm_kwargs, logger
+from ..utils import get_tqdm_kwargs, logger, execute_only_once
 from ..utils.stat import RatioCounter, BinaryStatistics
 from ..tfutils import get_op_tensor_name, get_op_var_name
 from .base import Callback
+from .dispatcher import OutputTensorDispatcer

 __all__ = ['InferenceRunner', 'ClassificationError',
            'ScalarStats', 'Inferencer', 'BinaryClassificationStats']
@@ -31,14 +32,14 @@ def before_inference(self):
     def _before_inference(self):
         pass

-    def datapoint(self, dp, output):
+    def datapoint(self, _, output):
         """
         Called after complete running every data point
         """
-        self._datapoint(dp, output)
+        self._datapoint(_, output)

     @abstractmethod
-    def _datapoint(self, dp, output):
+    def _datapoint(self, _, output):
         pass

     def after_inference(self):
@@ -97,21 +98,24 @@ def _find_input_tensors(self):
         self.input_tensors = [x.name for x in input_vars]

     def _find_output_tensors(self):
+        dispatcer = OutputTensorDispatcer()
+        for inf in self.infs:
+            dispatcer.add_entry(inf.get_output_tensors())
+        all_names = dispatcer.get_all_names()
+
         IOTensor = InferenceRunner.IOTensor
-        self.output_tensors = []
-        def find_oid(t):
-            tensorname = get_op_tensor_name(t)[1]
-            if tensorname in self.input_tensors:
-                # this inferencer needs the input dp
-                return IOTensor(self.input_tensors.index(tensorname), False)
-            if t in self.output_tensors:
-                return IOTensor(self.output_tensors.index(t), True)
-            else:
-                self.output_tensors.append(t)
-                return IOTensor(len(self.output_tensors) - 1, True)
-        self.inf_to_tensors = [
-            [find_oid(t) for t in inf.get_output_tensors()]
-            for inf in self.infs]
+        self.output_tensors = list(filter(
+            lambda x: x not in self.input_tensors, all_names))
+        def find_oid(idxs):
+            ret = []
+            for idx in idxs:
+                name = all_names[idx]
+                if name in self.input_tensors:
+                    ret.append(IOTensor(self.input_tensors.index(name), False))
+                else:
+                    ret.append(IOTensor(self.output_tensors.index(name), True))
+            return ret
+        self.inf_to_tensors = [find_oid(t) for t in dispatcer.get_idx_for_each_entry()]
         # list of list of (var_name: IOTensor)

     def _trigger_epoch(self):
@@ -162,7 +166,7 @@ def _get_output_tensors(self):
     def _before_inference(self):
         self.stats = []

-    def _datapoint(self, dp, output):
+    def _datapoint(self, _, output):
         self.stats.append(output)

     def _after_inference(self):
@@ -180,15 +184,16 @@ class ClassificationError(Inferencer):
     """
     Compute classification error in batch mode, from a `wrong` variable

-    The `wrong` variable is supposed to be an integer equal to the number of failed samples in this batch.
-    You can use `tf.nn.in_top_k` to record top-k error as well.
+    The `wrong` tensor is supposed to be a 0/1 integer vector containing
+    whether each sample in the batch is incorrectly classified.
+    You can use `tf.nn.in_top_k` to produce this vector.

     This callback produces the "true" error,
     taking into account that batches might not have the same size in
     testing (because the size of the test set might not be a multiple of the batch size).
     Therefore the result is different from averaging the error rate of each batch.
     """
-    def __init__(self, wrong_var_name='wrong:0', summary_name='val_error'):
+    def __init__(self, wrong_var_name='incorrect_vector', summary_name='val_error'):
         """
         :param wrong_var_name: name of the `wrong` variable
         :param summary_name: the name for logging
@@ -202,9 +207,18 @@ def _get_output_tensors(self):
     def _before_inference(self):
         self.err_stat = RatioCounter()

-    def _datapoint(self, dp, outputs):
-        batch_size = dp[0].shape[0]  # assume batched input
-        wrong = int(outputs[0])
+    def _datapoint(self, _, outputs):
+        vec = outputs[0]
+        if vec.ndim == 0:
+            if execute_only_once():
+                logger.warn("[DEPRECATED] use a 'wrong vector' for ClassificationError instead of nr_wrong")
+            batch_size = _[0].shape[0]  # assume batched input
+            wrong = int(vec)
+        else:
+            # TODO put shape assertion into InferenceRunner
+            assert vec.ndim == 1, "{} is not a vector!".format(self.wrong_var_name)
+            batch_size = len(vec)
+            wrong = np.sum(vec)
         self.err_stat.feed(wrong, batch_size)

     def _after_inference(self):
@@ -230,7 +244,7 @@ def _get_output_tensors(self):
     def _before_inference(self):
         self.stat = BinaryStatistics()

-    def _datapoint(self, dp, outputs):
+    def _datapoint(self, _, outputs):
         pred, label = outputs
         self.stat.feed(pred, label)
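The rewritten `_find_output_tensors` delegates its name bookkeeping to `OutputTensorDispatcer`, a new class in `tensorpack/callbacks/dispatcher.py` whose body is not part of this diff. Judging only from its call sites here (`add_entry`, `get_all_names`, `get_idx_for_each_entry`), it deduplicates the tensor names requested by all inferencers, so a tensor shared by several inferencers is fetched only once per step. A minimal sketch consistent with those call sites (the real class may normalize tensor names and differ in detail):

```python
class OutputTensorDispatcer(object):
    """Sketch of the dispatcher, inferred from its call sites above.

    Collects the output-tensor names requested by each inferencer,
    deduplicates them, and remembers which global name index each
    inferencer's entries map to.
    """
    def __init__(self):
        self._names = []   # all distinct names, in first-seen order
        self._idxs = []    # one list of indices into _names per entry

    def add_entry(self, names):
        entry = []
        for name in names:
            if name not in self._names:
                self._names.append(name)
            entry.append(self._names.index(name))
        self._idxs.append(entry)

    def get_all_names(self):
        return self._names

    def get_idx_for_each_entry(self):
        return self._idxs
```

With this in place, `find_oid` only has to translate each global name index into either an input-tensor index or an output-tensor index, which is exactly what the new code above does.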
5 changes: 3 additions & 2 deletions tensorpack/tfutils/symbolic_functions.py
@@ -6,13 +6,14 @@
 import numpy as np
 from ..utils import logger

-def prediction_incorrect(logits, label, topk=1):
+def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
     """
     :param logits: NxC
     :param label: N
     :returns: a float32 vector of length N with 0/1 values, 1 meaning incorrect prediction
     """
-    return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, topk)), tf.float32)
+    return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, topk)),
+                   tf.float32, name=name)

 def flatten(x):
     """
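To make the new default concrete, here is a small self-contained example of what `prediction_incorrect` computes (values invented for illustration): with `topk=1`, a sample is marked incorrect whenever its true class is not the arg-max of its logits row.

```python
import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.1],   # predicts class 0
                      [0.2, 0.3, 3.0],   # predicts class 2
                      [1.5, 1.4, 0.0]])  # predicts class 0
label = tf.constant([0, 2, 1])           # third sample is misclassified

# Same expression as prediction_incorrect(logits, label, 1):
wrong = tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, 1)),
                tf.float32, name='incorrect_vector')

with tf.Session() as sess:
    print(sess.run(wrong))   # -> [ 0.  0.  1.]
# tf.reduce_mean(wrong) is this batch's error rate (1/3 here);
# ClassificationError aggregates the vector exactly across batches.
```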
4 changes: 2 additions & 2 deletions tensorpack/train/multigpu.py
@@ -143,7 +143,7 @@ def _trigger_epoch(self):
             async_step_total_cnt = int(re.findall(
                 '[0-9]+', self.async_step_counter.__str__())[0])
             self.write_scalar_summary(
-                'async-global-step', async_step_total_cnt)
+                'async_global_step', async_step_total_cnt)
         except:
-            logger.exception("Cannot log async-global-step")
+            logger.exception("Cannot log async_global_step")
         super(AsyncMultiGPUTrainer, self)._trigger_epoch()
1 change: 0 additions & 1 deletion tensorpack/train/trainer.py
@@ -115,7 +115,6 @@ def __init__(self, trainer):

     def run(self):
         self.dataflow.reset_state()
-
         with self.sess.as_default():
             try:
                 while True:
6 changes: 6 additions & 0 deletions tensorpack/utils/utils.py
@@ -94,6 +94,12 @@ def get_rng(obj=None):

 _EXECUTE_HISTORY = set()
 def execute_only_once():
+    """
+    when called with:
+        if execute_only_once():
+            # do something
+    The body is guaranteed to be executed only the first time.
+    """
     f = inspect.currentframe().f_back
     ident = (f.f_code.co_filename, f.f_lineno)
     if ident in _EXECUTE_HISTORY:
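The diff view cuts the function off after the membership test. For the full picture, here is a sketch of a completion that is consistent with the visible lines and with how `ClassificationError._datapoint` uses it (the actual tail of the function is simply not shown on this page):

```python
import inspect

_EXECUTE_HISTORY = set()

def execute_only_once():
    # Key the history on the caller's (filename, lineno), so each distinct
    # call site fires at most once no matter how often it is reached.
    f = inspect.currentframe().f_back
    ident = (f.f_code.co_filename, f.f_lineno)
    if ident in _EXECUTE_HISTORY:
        return False            # this call site has already run
    _EXECUTE_HISTORY.add(ident)
    return True                 # first time through: run the guarded body

# Usage, as in ClassificationError._datapoint above: the deprecation
# warning is printed once even though _datapoint runs for every batch.
if execute_only_once():
    print("warned exactly once")
```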
