From 9d08a1cdee7b5e163d6fa2c19c8a6468138d7463 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 11 Sep 2018 13:54:46 +0800 Subject: [PATCH 01/28] Implement mkldnn convolution fusion. Implement mkldnn convolution quantization. --- Makefile | 4 +- example/quantization/imagenet_gen_qsym.py | 2 +- .../quantization/imagenet_gen_qsym_mkldnn.py | 236 +++++++ example/quantization/imagenet_inference.py | 2 +- include/mxnet/c_api.h | 40 +- include/mxnet/ndarray.h | 5 + mkldnn.mk | 2 +- python/mxnet/contrib/quantization.py | 81 ++- src/c_api/c_api_symbolic.cc | 28 +- src/ndarray/ndarray.cc | 33 +- src/operator/nn/mkldnn/mkldnn_base-inl.h | 18 + .../nn/mkldnn/mkldnn_convolution-inl.h | 88 ++- src/operator/nn/mkldnn/mkldnn_convolution.cc | 163 +++-- .../quantization/mkldnn/mkldnn_quantize-inl.h | 5 + .../mkldnn/mkldnn_quantized_conv.cc | 89 --- .../quantization/quantize_graph_pass.cc | 122 +++- src/operator/subgraph/mkldnn/mkldnn_conv.cc | 653 ++++++++++++++++++ .../subgraph/mkldnn/mkldnn_conv_property.cc | 246 +++++++ src/operator/subgraph/partition_graph.cc | 6 +- src/operator/subgraph/subgraph_property.h | 16 + 20 files changed, 1598 insertions(+), 241 deletions(-) create mode 100644 example/quantization/imagenet_gen_qsym_mkldnn.py delete mode 100644 src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc create mode 100644 src/operator/subgraph/mkldnn/mkldnn_conv.cc create mode 100644 src/operator/subgraph/mkldnn/mkldnn_conv_property.cc diff --git a/Makefile b/Makefile index 7aa7867f7c18..a1322e914828 100644 --- a/Makefile +++ b/Makefile @@ -66,8 +66,8 @@ $(warning "USE_MKL2017 is deprecated. We will switch to USE_MKLDNN.") endif ifeq ($(USE_MKLDNN), 1) - MKLDNNROOT = $(ROOTDIR)/3rdparty/mkldnn/install - MKLROOT = $(ROOTDIR)/3rdparty/mkldnn/install + MKLDNNROOT = $(ROOTDIR)/3rdparty/mkldnn/build/install + MKLROOT = $(ROOTDIR)/3rdparty/mkldnn/build/install export USE_MKLML = 1 endif diff --git a/example/quantization/imagenet_gen_qsym.py b/example/quantization/imagenet_gen_qsym.py index 85474b663fae..8a2818c4bca0 100644 --- a/example/quantization/imagenet_gen_qsym.py +++ b/example/quantization/imagenet_gen_qsym.py @@ -92,7 +92,7 @@ def save_params(fname, arg_params, aux_params, logger=None): ' thresholds. This mode is expected to produce the best inference accuracy of all three' ' kinds of quantized models if the calibration dataset is representative enough of the' ' inference dataset.') - parser.add_argument('--quantized-dtype', type=str, default='int8', + parser.add_argument('--quantized-dtype', type=str, default='int8', choices=['int8', 'uint8'], help='quantization destination data type for input data') args = parser.parse_args() diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py new file mode 100644 index 000000000000..59444c0ad6df --- /dev/null +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -0,0 +1,236 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse +import os +import logging +from common import modelzoo +import mxnet as mx +from mxnet.contrib.quantization import * +from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array +import ctypes + + +def download_calib_dataset(dataset_url, calib_dataset, logger=None): + if logger is not None: + logger.info('Downloading calibration dataset from %s to %s' % (dataset_url, calib_dataset)) + mx.test_utils.download(dataset_url, calib_dataset) + + +def download_model(model_name, logger=None): + dir_path = os.path.dirname(os.path.realpath(__file__)) + model_path = os.path.join(dir_path, 'model') + if logger is not None: + logger.info('Downloading model %s... into path %s' % (model_name, model_path)) + return modelzoo.download_model(args.model, os.path.join(dir_path, 'model')) + + +def save_symbol(fname, sym, logger=None): + if logger is not None: + logger.info('Saving symbol into file at %s' % fname) + sym.save(fname) + + +def save_params(fname, arg_params, aux_params, logger=None): + if logger is not None: + logger.info('Saving params into file at %s' % fname) + save_dict = {('arg:%s' % k): v.as_in_context(cpu()) for k, v in arg_params.items()} + save_dict.update({('aux:%s' % k): v.as_in_context(cpu()) for k, v in aux_params.items()}) + mx.nd.save(fname, save_dict) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Generate a calibrated quantized model from a FP32 model with mkldnn support') + parser.add_argument('--ctx', type=str, default='cpu') + parser.add_argument('--model', type=str, choices=['imagenet1k-resnet-152', 'ssd', 'imagenet1k-resnet-50', 'imagenet1k-resnet-50-v1', + 'imagenet1k-inception-v3', 'imagenet1k-inception-bn', 'imagenet1k-vgg-16'], + help='currently only supports imagenet1k-resnet-152, imagenet1k-inception-bn or imagenet1k-vgg-16') + parser.add_argument('--batch-size', type=int, default=32) + parser.add_argument('--label-name', type=str, default='softmax_label') + parser.add_argument('--calib-dataset', type=str, default='data/val_256_q90.rec', + help='path of the calibration dataset') + parser.add_argument('--image-shape', type=str, default='3,224,224') + parser.add_argument('--data-nthreads', type=int, default=60, + help='number of threads for data decoding') + parser.add_argument('--num-calib-batches', type=int, default=10, + help='number of batches for calibration') + parser.add_argument('--exclude-first-conv', action='store_true', default=True, + help='excluding quantizing the first conv layer since the' + ' number of channels is usually not a multiple of 4 in that layer' + ' which does not satisfy the requirement of cuDNN') + parser.add_argument('--shuffle-dataset', action='store_true', default=True, + help='shuffle the calibration dataset') + parser.add_argument('--shuffle-chunk-seed', type=int, default=3982304, + help='shuffling chunk seed, see' + ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter' + ' for more details') + parser.add_argument('--shuffle-seed', type=int, default=48564309, + help='shuffling seed, see' + ' 
https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter' + ' for more details') + parser.add_argument('--calib-mode', type=str, default='entropy', + help='calibration mode used for generating calibration table for the quantized symbol; supports' + ' 1. none: no calibration will be used. The thresholds for quantization will be calculated' + ' on the fly. This will result in inference speed slowdown and loss of accuracy' + ' in general.' + ' 2. naive: simply take min and max values of layer outputs as thresholds for' + ' quantization. In general, the inference accuracy worsens with more examples used in' + ' calibration. It is recommended to use `entropy` mode as it produces more accurate' + ' inference results.' + ' 3. entropy: calculate KL divergence of the fp32 output and quantized output for optimal' + ' thresholds. This mode is expected to produce the best inference accuracy of all three' + ' kinds of quantized models if the calibration dataset is representative enough of the' + ' inference dataset.') + parser.add_argument('--quantized-dtype', type=str, default='uint8', + choices=['int8', 'uint8'], + help='quantization destination data type for input data') + parser.add_argument('--disable-requantize', type=bool, default=True, + help='If disable requantize, the OP needed requantize' + ' will output int8 directly and hence requantize ' + 'OP is not needed during quantization. Note: ' + 'calibration mode need to be used if requantize ' + 'is disabled.') + parser.add_argument('--enable-calib-quantize', type=bool, default=True, + help='If enabled, the quantize op will ' + 'be calibrated offline if calibration mode is ' + 'enabled') + args = parser.parse_args() + + if args.ctx == 'gpu': + ctx = mx.gpu(0) + elif args.ctx == 'cpu': + ctx = mx.cpu(0) + else: + raise ValueError('ctx %s is not supported in this script' % args.ctx) + + logging.basicConfig() + logger = logging.getLogger('logger') + logger.setLevel(logging.INFO) + + logger.info('shuffle_dataset=%s' % args.shuffle_dataset) + + calib_mode = args.calib_mode + logger.info('calibration mode set to %s' % calib_mode) + + # download calibration dataset + if calib_mode != 'none': + download_calib_dataset('http://data.mxnet.io/data/val_256_q90.rec', args.calib_dataset) + + # download model + prefix, epoch = download_model(model_name=args.model, logger=logger) + sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) + + out = SymbolHandle() + backend = "MKLDNN" + check_call(_LIB.MXGenBackendSubgraph(c_str(backend), sym.handle, ctypes.byref(out))) + sym = Symbol(out) + + # get batch size + batch_size = args.batch_size + logger.info('batch size = %d for calibration' % batch_size) + + # get number of batches for calibration + num_calib_batches = args.num_calib_batches + if calib_mode != 'none': + logger.info('number of batches = %d for calibration' % num_calib_batches) + + # get number of threads for decoding the dataset + data_nthreads = args.data_nthreads + + # get image shape + image_shape = args.image_shape + + exclude_first_conv = args.exclude_first_conv + excluded_sym_names = [] + if args.model == 'imagenet1k-resnet-152': + rgb_mean = '0,0,0' + if args.ctx == 'gpu': + calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1 + or name.find('sc') != -1 + or name.find('fc') != -1) + else: + calib_layer = lambda name: name.endswith('_output') + excluded_sym_names += ['flatten0', 'fc1'] + if exclude_first_conv: + excluded_sym_names += ['sg_mkldnn_conv_bn_relu_0', 
'pooling0'] + elif args.model == 'imagenet1k-inception-bn': + rgb_mean = '123.68,116.779,103.939' + if args.ctx == 'gpu': + calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1 + or name.find('fc') != -1) + else: + calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1) + excluded_sym_names += ['flatten', 'fc1'] + if exclude_first_conv: + excluded_sym_names += ['conv_1'] + else: + raise ValueError('model %s is not supported in this script' % args.model) + + label_name = args.label_name + logger.info('label_name = %s' % label_name) + + data_shape = tuple([int(i) for i in image_shape.split(',')]) + logger.info('Input data shape = %s' % str(data_shape)) + + logger.info('rgb_mean = %s' % rgb_mean) + rgb_mean = [float(i) for i in rgb_mean.split(',')] + mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]} + + if calib_mode == 'none': + logger.info('Quantizing FP32 model %s' % args.model) + qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params, + ctx=ctx, excluded_sym_names=excluded_sym_names, + calib_mode=calib_mode, quantized_dtype=args.quantized_dtype, + disable_requantize=args.disable_requantize, + logger=logger) + sym_name = '%s-symbol.json' % (prefix + '-quantized') + save_symbol(sym_name, qsym, logger) + else: + logger.info('Creating ImageRecordIter for reading calibration dataset') + data = mx.io.ImageRecordIter(path_imgrec=args.calib_dataset, + label_width=1, + preprocess_threads=data_nthreads, + batch_size=batch_size, + data_shape=data_shape, + label_name=label_name, + rand_crop=False, + rand_mirror=False, + shuffle=args.shuffle_dataset, + shuffle_chunk_seed=args.shuffle_chunk_seed, + seed=args.shuffle_seed, + **mean_args) + + cqsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params, + ctx=ctx, excluded_sym_names=excluded_sym_names, + calib_mode=calib_mode, calib_data=data, + num_calib_examples=num_calib_batches * batch_size, + calib_layer=calib_layer, quantized_dtype=args.quantized_dtype, + disable_requantize=args.disable_requantize, + label_names=(label_name,), + logger=logger) + if calib_mode == 'entropy': + suffix = '-quantized-%dbatches-entropy' % num_calib_batches + elif calib_mode == 'naive': + suffix = '-quantized-%dbatches-naive' % num_calib_batches + else: + raise ValueError('unknow calibration mode %s received, only supports `none`, `naive`, and `entropy`' + % calib_mode) + sym_name = '%s-symbol.json' % (prefix + suffix) + save_symbol(sym_name, cqsym, logger) + + param_name = '%s-%04d.params' % (prefix + '-quantized', epoch) + save_params(param_name, qarg_params, aux_params, logger) diff --git a/example/quantization/imagenet_inference.py b/example/quantization/imagenet_inference.py index 85649530aa0b..286e49ea4401 100644 --- a/example/quantization/imagenet_inference.py +++ b/example/quantization/imagenet_inference.py @@ -129,7 +129,7 @@ def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples, ctx = mx.cpu(0) else: raise ValueError('ctx %s is not supported in this script' % args.ctx) - + logging.basicConfig() logger = logging.getLogger('logger') logger.setLevel(logging.INFO) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 00439962a944..784f88774763 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1502,22 +1502,23 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, int *complete); /*! 
- * \brief Convert a symbol into a quantized symbol where FP32 operators are replaced with INT8 - * \param sym_handle symbol to be converted - * \param ret_sym_handle quantized symbol result - * \param num_excluded_symbols number of layers excluded from being quantized in the input symbol - * \param excluded_symbols array of symbols to be excluded from being quantized - * \param num_offline number of parameters that are quantized offline - * \param offline_params array of c strings representing the names of params quantized offline - * \param quantized_dtype the quantized destination type for input data. - */ -MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, - SymbolHandle *ret_sym_handle, - const mx_uint num_excluded_symbols, - const SymbolHandle *excluded_symbols, - const mx_uint num_offline, - const char **offline_params, - const char *quantized_dtype); +* \brief Convert a symbol into a quantized symbol where FP32 operators are replaced with INT8 +* \param sym_handle symbol to be converted +* \param ret_sym_handle quantized symbol result +* \param num_excluded_symbols number of layers excluded from being quantized in the input symbol +* \param excluded_symbols array of symbols to be excluded from being quantized +* \param num_offline number of parameters that are quantized offline +* \param offline_params array of c strings representing the names of params quantized offline +* \param quantized_dtype the quantized destination type for input data. +* \param disable_requantize whether disable requantize OP during quantization +* \param calib_quantize whether calibrate quantize op with offline calibration data. +*/ +int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, + const mx_uint num_excluded_symbols, + const SymbolHandle *excluded_symbols, + const mx_uint num_offline, const char **offline_params, + const char *quantized_dtype, const bool disable_requantize, + bool calib_quantize); /*! * \brief Set calibration table to node attributes in the sym @@ -1527,13 +1528,18 @@ MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, * \param low_quantiles low quantiles of layers stored in the calibration table * \param high_quantiles high quantiles of layers stored in the calibration table * \param ret_sym_handle returned symbol + * \param disable_requantize whether disable requantize OP during quantization */ MXNET_DLL int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle, const mx_uint num_layers, const char** layer_names, const float* low_quantiles, const float* high_quantiles, - SymbolHandle* ret_sym_handle); + SymbolHandle* ret_sym_handle, + const bool disable_requantize); + +MXNET_DLL int MXGenBackendSubgraph(const char *backend, SymbolHandle sym_handle, + SymbolHandle *ret_sym_handle); //-------------------------------------------- // Part 4: Executor interface diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 6141a4da78ef..47706e8a7947 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -702,6 +702,11 @@ class NDArray { * It's used by FullyConnected right now. */ NDArray MKLDNNDataReshape(const TShape &shape) const; + + /*! + * \ Fix mkldnn memory descriptor mismatch from NDArray. + */ + void UpdateMKLDNNMemDesc(); #endif /*! 
diff --git a/mkldnn.mk b/mkldnn.mk index 1be0704dcde1..d79bbe7d2a0e 100644 --- a/mkldnn.mk +++ b/mkldnn.mk @@ -47,7 +47,7 @@ $(MKLDNN_LIBFILE): mkldnn_clean: $(RM) -r 3rdparty/mkldnn/build - $(RM) -r 3rdparty/mkldnn/install/* + $(RM) -r $(MKLDNNROOT) ifeq ($(USE_MKLDNN), 1) mkldnn: mkldnn_build diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 8df923908fec..d8a91f8d8956 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -40,7 +40,7 @@ from ..module import Module -def _quantize_params(qsym, params): +def _quantize_params(qsym, params, th_dict): """Given a quantized symbol and a dict of params that have not been quantized, generate quantized params. Currently only supports quantizing the arg_params with names of `weight` or `bias`, not aux_params. If `qsym` contains symbols @@ -53,7 +53,9 @@ def _quantize_params(qsym, params): qsym : Symbol Quantized symbol from FP32 symbol. params : dict of str->NDArray + th_dict: dict of min/max pairs of layers' output """ + print(th_dict) inputs_name = qsym.list_arguments() quantized_params = {} for name in inputs_name: @@ -69,11 +71,20 @@ def _quantize_params(qsym, params): quantized_params[name+'_max'] = vmax elif name in params: quantized_params[name] = params[name] + elif name.endswith(('_min')): + output = name[: - len('_min')] + "_output" + if output in th_dict: + print(name) + quantized_params[name] = ndarray.array([th_dict[output][0]]) + elif name.endswith(('_max')): + output = name[: - len('_min')] + "_output" + if output in th_dict: + quantized_params[name] = ndarray.array([th_dict[output][1]]) return quantized_params - def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, - quantized_dtype='int8'): + quantized_dtype='int8', disable_requantize=False, + calib_quantize_op=True): """Given a symbol object representing a neural network of data type FP32, quantize it into a INT8 network. @@ -89,6 +100,10 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, avoided. quantized_dtype: str The quantized destination type for input data. + disable_requantize : bool + Whether disable requantize OP functionality. + calib_quantize_op : bool + Whether perform offline calibration for quantize op. """ num_excluded_symbols = 0 excluded_handles = [] @@ -112,7 +127,8 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, c_array(SymbolHandle, excluded_handles), mx_uint(num_offline), c_array(ctypes.c_char_p, offline), - c_str(quantized_dtype))) + c_str(quantized_dtype), + ctypes.c_bool(disable_requantize))) return Symbol(out) @@ -170,7 +186,7 @@ def collect(self, name, arr): % (name, min_range, max_range)) -def _calibrate_quantized_sym(qsym, th_dict): +def _calibrate_quantized_sym(qsym, th_dict, disable_requantize=False): """Given a dictionary containing the thresholds for quantizing the layers, set the thresholds into the quantized symbol as the params of requantize operators. """ @@ -191,7 +207,8 @@ def _calibrate_quantized_sym(qsym, th_dict): c_str_array(layer_output_names), c_array(ctypes.c_float, min_vals), c_array(ctypes.c_float, max_vals), - ctypes.byref(calibrated_sym))) + ctypes.byref(calibrated_sym), + ctypes.c_bool(disable_requantize))) return Symbol(calibrated_sym) @@ -254,9 +271,6 @@ def _smooth_distribution(p, eps=0.0001): # pylint: disable=line-too-long def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255): """Given a dataset, find the optimal threshold for quantizing it. 
- The reference distribution is `q`, and the candidate distribution is `p`. - `q` is a truncated version of the original distribution. - Ref: http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf """ if isinstance(arr, NDArray): @@ -307,10 +321,10 @@ def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255): right_outlier_count = np.sum(hist[p_bin_idx_stop:]) p[-1] += right_outlier_count # is_nonzeros[k] indicates whether hist[k] is nonzero - is_nonzeros = (p != 0).astype(np.int32) + is_nonzeros = (sliced_nd_hist != 0).astype(np.int32) # calculate how many bins should be merged to generate quantized distribution q - num_merged_bins = sliced_nd_hist.size // num_quantized_bins + num_merged_bins = p.size // num_quantized_bins # merge hist into num_quantized_bins bins for j in range(num_quantized_bins): start = j * num_merged_bins @@ -318,17 +332,17 @@ def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255): quantized_bins[j] = sliced_nd_hist[start:stop].sum() quantized_bins[-1] += sliced_nd_hist[num_quantized_bins * num_merged_bins:].sum() # expand quantized_bins into p.size bins - q = np.zeros(sliced_nd_hist.size, dtype=np.float32) + q = np.zeros(p.size, dtype=np.float32) for j in range(num_quantized_bins): start = j * num_merged_bins if j == num_quantized_bins - 1: - stop = len(is_nonzeros) + stop = -1 else: stop = start + num_merged_bins norm = is_nonzeros[start:stop].sum() if norm != 0: q[start:stop] = float(quantized_bins[j]) / float(norm) - q[p == 0] = 0 + q[sliced_nd_hist == 0] = 0 p = _smooth_distribution(p) # There is a chance that q is an invalid probability distribution. try: @@ -336,6 +350,7 @@ def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255): except ValueError: divergence[i - num_half_quantized_bins] = float("inf") divergence[i - num_half_quantized_bins] = stats.entropy(p, q) + quantized_bins[:] = 0 min_divergence_idx = np.argmin(divergence) min_divergence = divergence[min_divergence_idx] @@ -363,7 +378,10 @@ def _get_optimal_thresholds(nd_dict, num_bins=8001, num_quantized_bins=255, logg _get_optimal_threshold(nd_dict[name], num_bins=num_bins, num_quantized_bins=num_quantized_bins) del nd_dict[name] # release the memory of ndarray - th_dict[name] = (-opt_th, opt_th) + if min_val < 0: + th_dict[name] = (-opt_th, opt_th) + else: + th_dict[name] = (0, opt_th) if logger is not None: logger.info('layer=%s, min_val=%f, max_val=%f, min_divergence=%f, optimal_threshold=%f' % (name, min_val, max_val, min_divergence, opt_th)) @@ -408,12 +426,24 @@ def _load_params(params, logger=logging): raise ValueError('Unsupported params provided. 
Must be either a path to the param file or'
                         ' a pair of dictionaries representing arg_params and aux_params')
 
+def save_params(fname, arg_params, aux_params, logger=None):
+    if logger is not None:
+        logger.info('Saving params into file at %s' % fname)
+    save_dict = {('arg:%s' % k): v.as_in_context(cpu()) for k, v in arg_params.items()}
+    save_dict.update({('aux:%s' % k): v.as_in_context(cpu()) for k, v in aux_params.items()})
+    ndarray.save(fname, save_dict)
+
+def save_symbol(fname, sym, logger=None):
+    if logger is not None:
+        logger.info('Saving symbol into file at %s' % fname)
+    sym.save(fname)
 
 def quantize_model(sym, arg_params, aux_params,
                    data_names=('data',), label_names=('softmax_label',),
                    ctx=cpu(), excluded_sym_names=None, calib_mode='entropy',
                    calib_data=None, num_calib_examples=None, calib_layer=None,
-                   quantized_dtype='int8', logger=logging):
+                   quantized_dtype='int8', disable_requantize=False,
+                   calib_quantize_op=True, logger=logging):
     """User-level API for generating a quantized model from a FP32 model w/ or w/o calibration.
     The backend quantized operators are only enabled for Linux systems. Please do not run
     inference using the quantized models on Windows for now.
@@ -466,6 +496,12 @@ def quantize_model(sym, arg_params, aux_params,
     quantized_dtype : str
         The quantized destination type for input data. Currently support 'int8'
         and 'uint8', default value is 'int8'.
+    disable_requantize : bool
+        Whether to disable the requantize OP during quantization. If disabled, quantized
+        OPs that would otherwise need requantize output int8 directly, so no requantize
+        OP is inserted during symbol quantization.
+    calib_quantize_op : bool
+        Whether to calibrate the quantize op with its input calibration data. The quantize op's input should be included in calib_layer.
     logger : Object
         A logging object for printing information during the process of quantization.
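As a usage illustration only (not part of the diff), a minimal sketch of how the new `disable_requantize` and `calib_quantize_op` flags are passed from user code, following the imagenet_gen_qsym_mkldnn.py example added earlier in this patch; `sym`, `arg_params`, `aux_params` and `calib_data` are assumed to be prepared as in that script, and the excluded layer names are illustrative:

    # Minimal sketch, assuming `import mxnet as mx`, a loaded FP32 model
    # (sym/arg_params/aux_params) and an ImageRecordIter `calib_data`
    # prepared as in imagenet_gen_qsym_mkldnn.py.
    qsym, qarg_params, aux_params = quantize_model(
        sym=sym, arg_params=arg_params, aux_params=aux_params,
        ctx=mx.cpu(0), excluded_sym_names=['flatten0', 'fc1'],
        calib_mode='entropy', calib_data=calib_data, num_calib_examples=320,
        calib_layer=lambda name: name.endswith('_output'),
        quantized_dtype='uint8', disable_requantize=True, calib_quantize_op=True)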
@@ -494,11 +530,11 @@ def quantize_model(sym, arg_params, aux_params, ' expected `int8` or `uint8`' % quantized_dtype) qsym = _quantize_symbol(sym, excluded_symbols=excluded_syms, offline_params=list(arg_params.keys()), - quantized_dtype=quantized_dtype) - - logger.info('Quantizing parameters') - qarg_params = _quantize_params(qsym, arg_params) + quantized_dtype=quantized_dtype, + disable_requantize=disable_requantize, + calib_quantize_op=calib_quantize_op) + th_dict = {} if calib_mode is not None and calib_mode != 'none': if not isinstance(ctx, Context): raise ValueError('currently only supports single ctx, while received %s' % str(ctx)) @@ -535,6 +571,9 @@ def quantize_model(sym, arg_params, aux_params, raise ValueError('unknown calibration mode %s received,' ' expected `none`, `naive`, or `entropy`' % calib_mode) logger.info('Calibrating quantized symbol') - qsym = _calibrate_quantized_sym(qsym, th_dict) + qsym = _calibrate_quantized_sym(qsym, th_dict, disable_requantize) + + logger.info('Quantizing parameters') + qarg_params = _quantize_params(qsym, arg_params, th_dict) return qsym, qarg_params, aux_params diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 35ecec7e11f6..98699ef01aaf 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -31,6 +31,7 @@ #include "./c_api_common.h" #include "../operator/operator_common.h" #include "../executor/exec_pass.h" +#include "../operator/subgraph/subgraph_property.h" namespace mxnet { namespace op { @@ -649,7 +650,9 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, const SymbolHandle *excluded_symbols, const mx_uint num_offline, const char **offline_params, - const char *quantized_dtype) { + const char *quantized_dtype, + const bool disable_requantize, + bool calib_quantize) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); nnvm::Symbol *sym = static_cast(sym_handle); @@ -669,6 +672,7 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, std::string quantized_type(quantized_dtype); g.attrs["offline_params"] = std::make_shared(std::move(offline)); g.attrs["quantized_dtype"] = std::make_shared(std::move(quantized_type)); + g.attrs["calib_quantize"] = std::make_shared(calib_quantize); g = ApplyPass(std::move(g), "QuantizeGraph"); s->outputs = g.outputs; *ret_sym_handle = s; @@ -680,7 +684,8 @@ int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle, const char** layer_names, const float* min_ranges, const float* max_ranges, - SymbolHandle* ret_qsym_handle) { + SymbolHandle* ret_qsym_handle, + const bool disable_requantize) { nnvm::Symbol* s = new nnvm::Symbol(); API_BEGIN(); nnvm::Symbol* sym = static_cast(qsym_handle); @@ -691,8 +696,27 @@ int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle, calib_table.emplace(prefix+layer_names[i], std::make_pair(min_ranges[i], max_ranges[i])); } g.attrs["calib_table"] = std::make_shared(std::move(calib_table)); + g.attrs["disable_requantize"] = std::make_shared(disable_requantize); g = ApplyPass(std::move(g), "SetCalibTableToQuantizedGraph"); s->outputs = g.outputs; *ret_qsym_handle = s; API_END_HANDLE_ERROR(delete s); } + +int MXGenBackendSubgraph(const char *backend, SymbolHandle sym_handle, + SymbolHandle *ret_sym_handle) { + nnvm::Symbol *s = new nnvm::Symbol(); + API_BEGIN(); + nnvm::Symbol *sym = static_cast(sym_handle); + *s = sym->Copy(); + nnvm::Graph g = Symbol2Graph(*s); + mxnet::op::SubgraphPropertyPtr property = + mxnet::op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty( + backend); + g.attrs["subgraph_property"] = + 
std::make_shared(std::move(property)); + g = ApplyPass(std::move(g), "PartitionGraph"); + s->outputs = g.outputs; + *ret_sym_handle = s; + API_END_HANDLE_ERROR(delete s); +} \ No newline at end of file diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 853838a87f4c..0ae4f3adbad1 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -449,24 +449,6 @@ void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) { mkl_mem_.reset(new MKLDNNMemory(pd, shandle.dptr)); } -/* - * Here we want to get MKLDNN memory whose primitive desc is exactly the same as - * the given one. operator== can't guarantee that. == can return true even if - * the formats are different. I need to double check its format. - */ -static inline mkldnn::memory *GetMKLDNNExact( - const mkldnn::memory *mem, mkldnn::memory::primitive_desc desc) { - mkldnn::memory::primitive_desc src_desc = mem->get_primitive_desc(); - if (desc == src_desc && desc.desc().data.format == src_desc.desc().data.format) { - return const_cast(mem); - } else { - std::shared_ptr ret(new mkldnn::memory( - desc, mem->get_data_handle())); - MKLDNNStream::Get()->RegisterMem(ret); - return ret.get(); - } -} - const mkldnn::memory *NDArray::GetMKLDNNData( const mkldnn::memory::primitive_desc &desc) const { if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) { @@ -694,6 +676,21 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc & MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); return ptr_->mkl_mem_->GetRaw(); } + +void NDArray::UpdateMKLDNNMemDesc() { + const mkldnn::memory *mem = GetMKLDNNData(); + auto mem_desc = mem->get_primitive_desc().desc(); + auto this_dtype = get_mkldnn_type(dtype()); + if (this_dtype != mem_desc.data.data_type) { + mkldnn::memory::desc data_md( + mkldnn::memory::dims(mem_desc.data.dims, + mem_desc.data.dims + mem_desc.data.ndims), + this_dtype, static_cast(mem_desc.data.format)); + mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine()); + ptr_->mkl_mem_.reset(new MKLDNNMemory(pd, ptr_->shandle.dptr)); + MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); + } +} #endif void NDArray::SetTBlob() const { diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 6eb90f845d37..b9c2d2599250 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -327,6 +327,24 @@ enum OutDataOp { typedef std::pair mkldnn_output_t; void MKLDNNCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem); +/* + * Here we want to get MKLDNN memory whose primitive desc is exactly the same as + * the given one. operator== can't guarantee that. == can return true even if + * the formats are different. I need to double check its format. + */ +static inline mkldnn::memory *GetMKLDNNExact( + const mkldnn::memory *mem, mkldnn::memory::primitive_desc desc) { + mkldnn::memory::primitive_desc src_desc = mem->get_primitive_desc(); + if (desc == src_desc && desc.desc().data.format == src_desc.desc().data.format) { + return const_cast(mem); + } else { + std::shared_ptr ret(new mkldnn::memory( + desc, mem->get_data_handle())); + MKLDNNStream::Get()->RegisterMem(ret); + return ret.get(); + } +} + /* * These two functions try to create MKLDNN memory in an NDArray based on `req'. 
* The difference is that the first function can create MKLDNN memory with diff --git a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h index 23f2fe694633..9df3806f8b65 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h @@ -35,19 +35,79 @@ namespace mxnet { namespace op { -mkldnn::convolution_forward::primitive_desc GetConvFwdImpl( - const ConvolutionParam& param, const bool is_train, const NDArray &data, - const NDArray &weights, const NDArray *bias, const NDArray &output); +struct MKLDNNConvParam : public dmlc::Parameter { + // When adding more members into this class, please double check GetHash() + // won't overflow. + bool with_bn; + bool with_relu; + bool with_sum; + bool with_postsum_relu; + bool quantized; + bool weight_channelwise_scale; + + dmlc::optional min_calib_range; // min float value calculated from calibration dataset + dmlc::optional max_calib_range; // max float value calculated from calibration dataset + + DMLC_DECLARE_PARAMETER(MKLDNNConvParam) { + DMLC_DECLARE_FIELD(with_bn).set_default(false) + .describe("Add post batchnorm."); + DMLC_DECLARE_FIELD(with_relu).set_default(false) + .describe("Add post relu"); + DMLC_DECLARE_FIELD(with_sum).set_default(false) + .describe("Add post sum"); + DMLC_DECLARE_FIELD(with_postsum_relu).set_default(false) + .describe("Add post relu after sum"); + DMLC_DECLARE_FIELD(quantized).set_default(false) + .describe("enable quantization"); + DMLC_DECLARE_FIELD(weight_channelwise_scale).set_default(true) + .describe("Quantize weight with channel wise scales."); + DMLC_DECLARE_FIELD(min_calib_range) + .set_default(dmlc::optional()) + .describe("The minimum scalar value in the form of float32 obtained " + "through calibration. If present, it will be used to by " + "quantized convolution op to calculate primitive scale"); + DMLC_DECLARE_FIELD(max_calib_range) + .set_default(dmlc::optional()) + .describe("The maximum scalar value in the form of float32 obtained " + "through calibration. If present, it will be used to by " + "quantized convolution op to calculate primitive scale"); + } + const int GetBoolHash() const { + int hash = 0; + hash = hash * 2 + this->with_bn ? 1 : 0; + hash = hash * 2 + this->with_relu ? 1 : 0; + hash = hash * 2 + this->with_sum ? 1 : 0; + hash = hash * 2 + this->with_postsum_relu ? 1 : 0; + hash = hash * 2 + this->quantized ? 
1 : 0; + return hash; + } +}; + +struct MKLDNNConvFullParam { + ConvolutionParam conv_param; + MKLDNNConvParam mkldnn_param; + float sum_scale; + std::vector requantize_scales; +}; + +static inline bool IsOutputUInt8(const MKLDNNConvParam &mkldnn_param) { + return ((!mkldnn_param.with_sum) && mkldnn_param.with_relu) || + mkldnn_param.with_postsum_relu; +} + +mkldnn::convolution_forward::primitive_desc +GetConvFwdImpl(const MKLDNNConvFullParam ¶m, const bool is_train, + const NDArray &data, const NDArray &weights, const NDArray *bias, + const NDArray &output); class MKLDNNConvForward { public: mkldnn::convolution_forward::primitive_desc fwd_pd; - MKLDNNConvForward(const ConvolutionParam& param, const bool is_train, + MKLDNNConvForward(const MKLDNNConvFullParam ¶m, const bool is_train, const NDArray &data, const NDArray &weights, - const NDArray *bias, const NDArray &output): fwd_pd( - GetConvFwdImpl(param, is_train, data, weights, bias, output)) { - } + const NDArray *bias, const NDArray &output) + : fwd_pd(GetConvFwdImpl(param, is_train, data, weights, bias, output)) {} void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, const mkldnn::memory *bias, const mkldnn::memory &output); @@ -66,9 +126,17 @@ class MKLDNNConvForward { typedef ParamOpSign MKLDNNConvSignature; -MKLDNNConvForward &GetConvFwd(const nnvm::NodeAttrs& attrs, - const bool is_train, const NDArray &data, const NDArray &weights, - const NDArray *bias, const NDArray &output); +MKLDNNConvForward &GetConvFwd(const ConvolutionParam ¶m, + const bool is_train, const NDArray &data, + const NDArray &weights, const NDArray *bias, + const NDArray &output); + +void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam ¶m, + const OpContext &ctx, + MKLDNNConvForward &fwd, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data); } // namespace op } // namespace mxnet diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index cf04ea8da3d7..f02fe2bf2458 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -34,55 +34,99 @@ namespace mxnet { namespace op { +DMLC_REGISTER_PARAMETER(MKLDNNConvParam); + bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) { if (params.kernel.ndim() != 2) return false; - return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4; + return input.shape().ndim() == 4; +} + +inline static mkldnn::memory::desc GetInDataMemDesc(const NDArray &arr) { + mkldnn::memory::dims dims(arr.shape().ndim()); + for (size_t i = 0; i < dims.size(); i++) dims[i] = arr.shape()[i]; + int mkldnn_dtype; + // For INT8 case, currently we only support uint8 as input data so need + // to create the memory primitive of uint8 type + if (arr.dtype() == mshadow::kInt8) { + mkldnn_dtype = mshadow::kUint8; + } else { + mkldnn_dtype = arr.dtype(); + } + return mkldnn::memory::desc{dims, get_mkldnn_type(mkldnn_dtype), + mkldnn::memory::format::any}; } mkldnn::convolution_forward::primitive_desc GetConvFwdImpl( - const ConvolutionParam& param, const bool is_train, const NDArray &data, - const NDArray &weights, const NDArray *bias, const NDArray &output) { + const MKLDNNConvFullParam ¶m, const bool is_train, + const NDArray &data, const NDArray &weights, const NDArray *bias, + const NDArray &output) { auto prop = is_train ? 
mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring; - auto data_md = GetMemDesc(data); - auto weight_md = GetWeightDesc(weights, param.num_group); + auto data_md = GetInDataMemDesc(data); + auto weight_md = GetWeightDesc(weights, param.conv_param.num_group); auto out_md = GetMemDesc(output); auto engine = CpuEngine::Get()->get_engine(); - CHECK_GE(param.stride.ndim(), 2U); - CHECK_GE(param.pad.ndim(), 2U); - CHECK_GE(param.dilate.ndim(), 2U); + CHECK_GE(param.conv_param.stride.ndim(), 2U); + CHECK_GE(param.conv_param.pad.ndim(), 2U); + CHECK_GE(param.conv_param.dilate.ndim(), 2U); mkldnn::memory::dims strides{0, 0}; - strides[0] = param.stride[0]; - strides[1] = param.stride[1]; + strides[0] = param.conv_param.stride[0]; + strides[1] = param.conv_param.stride[1]; mkldnn::memory::dims padding{0, 0}; - padding[0] = param.pad[0]; - padding[1] = param.pad[1]; - if (param.dilate.ndim() == 0 && bias == nullptr) { + padding[0] = param.conv_param.pad[0]; + padding[1] = param.conv_param.pad[1]; + mkldnn::primitive_attr attr; + mkldnn::post_ops ops; + if (param.mkldnn_param.with_relu) { + float scale = 1.0f; // for fp32, scale is 1. + float alpha = 0.0f; // negative slope for mkldnn_eltwise_relu. + float beta = 1.0f; // ignored for mkldnn_eltwise_relu. + ops.append_eltwise(scale, eltwise_relu, alpha, beta); + + } + if (param.mkldnn_param.with_sum) { + ops.append_sum(param.sum_scale); + } + if (param.mkldnn_param.with_postsum_relu) { + float scale = 1.0f; // for fp32, scale is 1. + float alpha = 0.0f; // negative slope for mkldnn_eltwise_relu. + float beta = 1.0f; // ignored for mkldnn_eltwise_relu. + ops.append_eltwise(scale, eltwise_relu, alpha, beta); + } + attr.set_post_ops(ops); + + if (param.mkldnn_param.quantized) { + int mask = param.mkldnn_param.weight_channelwise_scale ? 
2 : 0; + attr.set_output_scales(mask, param.requantize_scales); + attr.set_int_output_round_mode(round_nearest); + } + + if (param.conv_param.dilate.ndim() == 0 && bias == nullptr) { mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero); - return mkldnn::convolution_forward::primitive_desc(desc, engine); - } else if (param.dilate.ndim() == 0) { + return mkldnn::convolution_forward::primitive_desc(desc, attr, engine); + } else if (param.conv_param.dilate.ndim() == 0) { auto bias_md = GetMemDesc(*bias); mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, weight_md, bias_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero); - return mkldnn::convolution_forward::primitive_desc(desc, engine); + return mkldnn::convolution_forward::primitive_desc(desc, attr, engine); } else { mkldnn::memory::dims dilates{0, 0}; - dilates[0] = param.dilate[0] - 1; - dilates[1] = param.dilate[1] - 1; + dilates[0] = param.conv_param.dilate[0] - 1; + dilates[1] = param.conv_param.dilate[1] - 1; if (bias == nullptr) { mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, weight_md, out_md, strides, dilates, padding, padding, mkldnn::padding_kind::zero); - return mkldnn::convolution_forward::primitive_desc(desc, engine); + return mkldnn::convolution_forward::primitive_desc(desc, attr, engine); } else { auto bias_md = GetMemDesc(*bias); mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, weight_md, bias_md, out_md, strides, dilates, padding, padding, mkldnn::padding_kind::zero); - return mkldnn::convolution_forward::primitive_desc(desc, engine); + return mkldnn::convolution_forward::primitive_desc(desc, attr, engine); } } } @@ -207,16 +251,16 @@ void MKLDNNConvForward::SetNewMem(const mkldnn::memory &data, } } -MKLDNNConvForward &GetConvFwd(const nnvm::NodeAttrs& attrs, const bool is_train, - const NDArray &data, const NDArray &weights, - const NDArray *bias, const NDArray &output) { +MKLDNNConvForward &GetConvFwd(const MKLDNNConvFullParam ¶m, + const bool is_train, const NDArray &data, + const NDArray &weights, const NDArray *bias, + const NDArray &output) { #if DMLC_CXX11_THREAD_LOCAL static thread_local std::unordered_map fwds; #else static MX_THREAD_LOCAL std::unordered_map fwds; #endif - const ConvolutionParam& param = nnvm::get(attrs.parsed); - MKLDNNConvSignature key(param); + MKLDNNConvSignature key(param.conv_param); key.AddSign(is_train); // Here we can sign the conv op with NDArray because conv primitive will // decide the right layout for the, so we only need to get the shape and the @@ -238,17 +282,17 @@ MKLDNNConvForward &GetConvFwd(const nnvm::NodeAttrs& attrs, const bool is_train, return it->second; } -void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data) { +void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam ¶m, + const OpContext &ctx, + MKLDNNConvForward &fwd, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]); - const ConvolutionParam& param = nnvm::get(attrs.parsed); NDArray weight = in_data[conv::kWeight]; - MKLDNNConvForward &fwd = GetConvFwd(attrs, ctx.is_train, in_data[conv::kData], weight, - param.no_bias ? 
nullptr : &in_data[conv::kBias], out_data[conv::kOut]); - - auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_primitive_desc()); + bool no_bias = param.conv_param.no_bias && !param.mkldnn_param.with_bn; + auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder( + fwd.fwd_pd.src_primitive_desc()); const mkldnn::memory *weight_mem; if (ctx.is_train) { // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it @@ -257,12 +301,14 @@ void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx // This asks the engine to change the layout of the weight array after // it's used. weight.Reorder2DefaultAsync(); - weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), param.num_group); + weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), + param.conv_param.num_group); } else { // For inference, we want to reorder the weight array so we don't need to // reorder data every time. if (weight.IsDefaultData()) { - weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), param.num_group); + weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), + param.conv_param.num_group); // We also need to modify the layout on the original weight array. The // data conversion happens after the weight array is used. weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc()); @@ -271,11 +317,21 @@ void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc()); } } - auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd.fwd_pd.dst_primitive_desc(), - req[conv::kOut]); + mkldnn_output_t out_mem; + if (param.mkldnn_param.with_sum) { + out_mem = mkldnn_output_t( + OutDataOp::Noop, + const_cast(out_data[conv::kOut].GetMKLDNNDataReorder( + fwd.fwd_pd.dst_primitive_desc()))); + } else { + out_mem = CreateMKLDNNMem(out_data[conv::kOut], + fwd.fwd_pd.dst_primitive_desc(), req[conv::kOut]); + } + const mkldnn::memory *bias_mem = nullptr; - if (!param.no_bias) - bias_mem = in_data[conv::kBias].GetMKLDNNDataReorder(fwd.fwd_pd.bias_primitive_desc()); + if (!no_bias) { + bias_mem = in_data[conv::kBias].GetMKLDNNData(); + } fwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); @@ -283,16 +339,35 @@ void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx MKLDNNStream::Get()->Submit(); } +void MKLDNNConvolutionForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + MKLDNNConvFullParam param; + param.conv_param = nnvm::get(attrs.parsed); + param.mkldnn_param.Init(std::unordered_map()); + auto &fwd = GetConvFwd( + param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight], + param.conv_param.no_bias ? nullptr : &in_data[conv::kBias], + out_data[conv::kOut]); + MKLDNNConvolutionForwardFullFeature(param, ctx, fwd, in_data, req, out_data); +} + void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]); const std::vector &in_grad = outputs; - const ConvolutionParam& param = nnvm::get(attrs.parsed); - mkldnn::convolution_forward::primitive_desc fwd_pd = GetConvFwdImpl(param, ctx.is_train, - inputs[conv::kData + 1], inputs[conv::kWeight + 1], - param.no_bias ? 
nullptr : &inputs[conv::kBias + 1], inputs[conv::kOut]); + MKLDNNConvFullParam full_param; + full_param.conv_param = nnvm::get(attrs.parsed); + full_param.mkldnn_param.Init(std::unordered_map()); + mkldnn::convolution_forward::primitive_desc fwd_pd = GetConvFwdImpl( + full_param, ctx.is_train, inputs[conv::kData + 1], inputs[conv::kWeight + 1], + full_param.conv_param.no_bias ? nullptr : &inputs[conv::kBias + 1], + inputs[conv::kOut]); + const ConvolutionParam ¶m = full_param.conv_param; CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace"; mkldnn::convolution_backward_data::primitive_desc bwdData_pd diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h index f7709319d6a2..7a00f621d452 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_quantize-inl.h @@ -75,6 +75,11 @@ static void MKLDNNQuantizeComputeKer(const std::vector& inputs, auto i_mpd = i_mem->get_primitive_desc(); auto i_desc = i_mpd.desc(); mkldnn::memory::format i_fmt = static_cast(i_desc.data.format); + if (i_fmt == mkldnn::memory::format::nchw || + i_fmt == mkldnn::memory::format::nChw8c || + i_fmt == mkldnn_nChw16c) { + i_fmt = mkldnn::memory::format::nhwc; + } size_t i_ndim = in_buffer.shape().ndim(); mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim); for (size_t i = 0; i < i_ndim; i++) { diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc deleted file mode 100644 index fa6a32a47392..000000000000 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file mkldnn_quantized_conv.cc - * \brief - * \author Wenting Jiang, Xinyu Chen -*/ - -#if MXNET_USE_MKLDNN == 1 -#include "../../nn/mkldnn/mkldnn_base-inl.h" -#include "../../nn/mkldnn/mkldnn_convolution-inl.h" -#include "../../nn/convolution-inl.h" -#include "../quantization_utils.h" -#include "../../tensor/matrix_op-inl.h" -#include "../../elemwise_op_common.h" -namespace mxnet { -namespace op { - -static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data) { - CHECK_EQ(in_data[0].dtype(), mshadow::kUint8) - << "mkldnn_quantized_conv op only supports uint8 as input type"; - TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]); - const ConvolutionParam& param = nnvm::get(attrs.parsed); - NDArray weight = in_data[conv::kWeight]; - MKLDNNConvForward &fwd = GetConvFwd(attrs, ctx.is_train, - in_data[conv::kData], weight, - param.no_bias ? 
nullptr : &in_data[conv::kBias], out_data[conv::kOut]); - - auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_primitive_desc()); - const mkldnn::memory *weight_mem; - // For inference, we want to reorder the weight array so we don't need to - // reorder data every time. - if (weight.IsDefaultData()) { - weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), param.num_group); - // We also need to modify the layout on the original weight array. The - // data conversion happens after the weight array is used. - weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc()); - } else { - weight_mem = weight.GetMKLDNNData(); - CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc()); - } - auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd.fwd_pd.dst_primitive_desc(), - req[conv::kOut]); - const mkldnn::memory *bias_mem = nullptr; - if (!param.no_bias) - bias_mem = in_data[conv::kBias].GetMKLDNNDataReorder(fwd.fwd_pd.bias_primitive_desc()); - fwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); - MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); - - CommitOutput(out_data[conv::kOut], out_mem); - MKLDNNStream::Get()->Submit(); - Stream *s = ctx.get_stream(); - const size_t num_inputs = param.no_bias ? 2 : 3; - mxnet_op::Kernel::Launch(s, 1, - out_data[1].data().dptr(), out_data[2].data().dptr(), - in_data[num_inputs].data().dptr(), - in_data[num_inputs+1].data().dptr(), - in_data[num_inputs+2].data().dptr(), - in_data[num_inputs+3].data().dptr()); -} - -NNVM_REGISTER_OP(_contrib_quantized_conv) -.set_attr("FComputeEx", MKLDNNQuantizedConvForward); - -} // namespace op -} // namespace mxnet - -#endif // MXNET_USE_MKLDNN == 1 diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index 10834868d2b5..e1a17ad1faa4 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -94,12 +94,27 @@ inline bool NeedQuantize(NodePtr node, const std::unordered_set exclude return quantized_op_map.count(node->op()) && !excluded_nodes.count(node); } +inline bool ExcludeKey(NodePtr node, NodeEntry e) { + auto findSGConv = node->attrs.name.find("sg_mkldnn_conv_"); + std::vector exclude_key{"weight", "bias", "gamma", "beta", "mean", "var"}; + bool found = false; + if (findSGConv == std::string::npos) return false; + for (size_t i = 0; i < exclude_key.size(); i++) { + if (e.node->attrs.name.find(exclude_key[i]) != std::string::npos) { + found = true; + break; + } + } + return found; +} + Graph QuantizeGraph(Graph &&src) { static auto& quantized_op_map = Op::GetAttr("FQuantizedOp"); static auto& need_requantize_map = Op::GetAttr("FNeedRequantize"); auto offline_params = src.GetAttr>("offline_params"); auto excluded_nodes = src.GetAttr>("excluded_nodes"); auto quantized_dtype = src.GetAttr("quantized_dtype"); + auto calib_quantize = src.GetAttr("calib_quantize"); // mirror_map stores the mapping from the currently visited graph to the newly created quantized // graph. Key is the currently visited graph's node pointer, and value is a copied node of the key @@ -125,23 +140,34 @@ Graph QuantizeGraph(Graph &&src) { // taking mirror_entry as input to generate a quantized NDArray. Save the mapping between // e's source node and the newly created quantize op so that the quantize op can be // reused next time when the same entry is visited again. 
- if (!NeedQuantize(e.node, excluded_nodes) && + if (ExcludeKey(node, e)) { + new_node->inputs.emplace_back(mirror_entry); + } + else if (!NeedQuantize(e.node, excluded_nodes) && (mirror_node->op() == nullptr || mirror_node->op()->name != "_contrib_quantize")) { NodePtr quantize_node = InsertNode("_contrib_quantize", e.node->attrs.name + "_quantize", new_node, mirror_entry); quantize_node->attrs.dict["out_type"] = quantized_dtype; quantize_node->op()->attr_parser(&(quantize_node->attrs)); + if (calib_quantize) { + NodePtr min_var = CreateNode("nullptr", e.node->attrs.name + "_min"); + quantize_node->inputs.emplace_back(NodeEntry{min_var, 0, 0}); + NodePtr max_var = CreateNode("nullptr", e.node->attrs.name + "_max"); + quantize_node->inputs.emplace_back(NodeEntry{max_var, 0, 0}); + } else { + NodePtr min_node = InsertNode("min", + e.node->attrs.name + "_min", quantize_node, mirror_entry); + min_node->op()->attr_parser(&(min_node->attrs)); - NodePtr min_node = InsertNode("min", - e.node->attrs.name + "_min", quantize_node, mirror_entry); - min_node->op()->attr_parser(&(min_node->attrs)); - - NodePtr max_node = InsertNode("max", - e.node->attrs.name + "_max", quantize_node, mirror_entry); - max_node->op()->attr_parser(&(max_node->attrs)); - + NodePtr max_node = InsertNode("max", + e.node->attrs.name + "_max", quantize_node, mirror_entry); + max_node->op()->attr_parser(&(max_node->attrs)); + } mirror_map[e.node.get()] = std::move(quantize_node); + } else if (mirror_node->op() != nullptr + && mirror_node->op()->name == "_contrib_dequantize") { + new_node->inputs.emplace_back(NodeEntry{mirror_node->inputs[0].node, e.index, e.version}); } else { // If the entry e's node needs quantization, or mirror_entry is from a quantize op, // simply add mirror_entry to the input of the new_node. @@ -154,12 +180,18 @@ Graph QuantizeGraph(Graph &&src) { // data1, data2, ..., min1, max1, min2, max2, ... for (const auto& e : node->inputs) { NodePtr mirror_node = mirror_map.at(e.node.get()); + if (mirror_node->op() != nullptr + && mirror_node->op()->name == "_contrib_dequantize") { + mirror_node = mirror_node->inputs[0].node; + } NodeEntry mirror_entry = NodeEntry{ mirror_node, e.index, e.version}; // for quantize node uint32_t min_index = 1; uint32_t max_index = 2; - if (quantized_op_map.count(e.node->op())) { + if (e.node->op() != nullptr && + (quantized_op_map.count(e.node->op()) || + e.node->op()->name != "_contrib_quantize")) { // here we calculate the output number (exclude min/max, in order to // calculate min/max index from mirror node) based on assumption that // there is only 1min and 1max output from mirror node (which is @@ -167,12 +199,9 @@ Graph QuantizeGraph(Graph &&src) { size_t num_outputs = mirror_node->num_outputs() - 2; min_index = num_outputs + 2 * e.index; max_index = num_outputs + 2 * e.index + 1; - } else { - CHECK(mirror_node->op()->name == "_contrib_quantize") - << "The input is not quantize or quantized_op"; + new_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); + new_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); } - new_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); - new_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); } // If the new_node op registered attr FNeedRequantize, insert requantize node after it. @@ -180,17 +209,17 @@ Graph QuantizeGraph(Graph &&src) { // out_data, min_range, and max_range. 
if (need_requantize_map.count(new_node->op()) > 0 && need_requantize_map[new_node->op()](new_node->attrs)) { - NodePtr requantize_node = Node::Create(); - requantize_node->attrs.op = Op::Get("_contrib_requantize"); - requantize_node->attrs.name = "requantize_" + node->attrs.name; - if (requantize_node->op()->attr_parser != nullptr) { - requantize_node->op()->attr_parser(&(requantize_node->attrs)); - } - for (size_t i = 0; i < 3; ++i) { - requantize_node->inputs.emplace_back(NodeEntry{new_node, static_cast(i), 0}); + NodePtr requantize_node = Node::Create(); + requantize_node->attrs.op = Op::Get("_contrib_requantize"); + requantize_node->attrs.name = "requantize_" + node->attrs.name; + if (requantize_node->op()->attr_parser != nullptr) { + requantize_node->op()->attr_parser(&(requantize_node->attrs)); + } + for (size_t i = 0; i < 3; ++i) { + requantize_node->inputs.emplace_back(NodeEntry{new_node, static_cast(i), 0}); + } + new_node = requantize_node; } - new_node = requantize_node; - } } else { // If the currently visited node does not need quantization, copy the current node to become // the new_node. Meanwhile, check whether any inputs of the current node need quantization @@ -204,7 +233,9 @@ Graph QuantizeGraph(Graph &&src) { NodeEntry mirror_entry = NodeEntry{ mirror_node, e.index, e.version}; // if input node is quantized operator, add dequantize node - if (NeedQuantize(e.node, excluded_nodes)) { + if (NeedQuantize(e.node, excluded_nodes) && + (mirror_node->op() == nullptr || + mirror_node->op()->name != "_contrib_dequantize")) { // here we calculate the output number (exclude min/max, in order to // calculate min/max index from mirror node) based on assumption that // there is only 1min and 1max output from mirror node (which is @@ -225,7 +256,8 @@ Graph QuantizeGraph(Graph &&src) { && mirror_node->op()->name == "_contrib_quantize") { new_node->inputs.emplace_back(NodeEntry{mirror_node->inputs[0].node, e.index, e.version}); } else { - new_node->inputs.emplace_back(NodeEntry{mirror_node, e.index, e.version}); + new_node->inputs.emplace_back( + NodeEntry{mirror_node, e.index, e.version}); } } } @@ -268,18 +300,44 @@ Graph SetCalibTableToQuantizedGraph(Graph&& g) { nnvm::Op::GetAttr("FNeedRequantize"); const auto& calib_table = g.GetAttr>>("calib_table"); + auto disable_requantize = g.GetAttr("disable_requantize"); + DFSVisit(g.outputs, [&](const NodePtr& node) { - // If the current op is requantize - // find the thresholds from the calibration table with the key equal - // to the current op's input node name, e.g. a quantized_conv2d node. - if (node->op() != nullptr && node->op()->name == "_contrib_requantize") { - NodePtr quantized_op_node = node->inputs[0].node; + bool found = false; + NodePtr quantized_op_node; + // If requantize is not disabled, find requantize OP and + // the thresholds from the calibration table with the key equal + // to the requantize OP's input node name, e.g. a quantized_conv node. 
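Both branches below look the thresholds up in the calibration table under the key "<node name>_output" and record them as min_calib_range / max_calib_range attributes on the node that consumes them: on the inserted _contrib_requantize node when requantizing is enabled, or directly on the quantized op itself when disable_requantize is set, e.g. (illustrative values):

    // node->attrs.dict["min_calib_range"] = "-3.2";
    // node->attrs.dict["max_calib_range"] = "3.1";

In the disable_requantize case, SgMKLDNNConvOperator::Forward later copies these two attributes straight into the output_min / output_max outputs, so no separate requantize node runs at inference time.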
+ if (!disable_requantize && + node->op() != nullptr && node->op()->name == "_contrib_requantize") { + quantized_op_node = node->inputs[0].node; CHECK(quantized_op_node->op() != nullptr) << quantized_op_node->attrs.name << " must be an quantized op node"; CHECK(need_requantize_map.count(quantized_op_node->op()) > 0 && need_requantize_map[quantized_op_node->op()](quantized_op_node->attrs)) << quantized_op_node->attrs.name << " op must register FNeedRequantize attr" " and the attr func should return true"; + found = true; + // If requantize is disabled, find OPs that needed requantize and + // the thresholds from the calibration table with the key equal + // to the found OP's name, e.g. a quantized_conv node. + } else if (disable_requantize && node->op() != nullptr && + need_requantize_map.count(node->op()) > 0 && + need_requantize_map[node->op()](node->attrs)) { + quantized_op_node = node; + found = true; + } else if (disable_requantize && + node->op() != nullptr && node->op()->name == "_sg_mkldnn_conv" + && !node->attrs.name.find("quantized_")) { + quantized_op_node = node; + std::string out_data_name = quantized_op_node->attrs.name + "_output"; + const auto calib_table_iter = calib_table.find(out_data_name); + if (calib_table_iter != calib_table.end()) { + node->attrs.dict["min_calib_range"] = std::to_string(calib_table_iter->second.first); + node->attrs.dict["max_calib_range"] = std::to_string(calib_table_iter->second.second); + } + } + if (found) { std::string out_data_name = quantized_op_node->attrs.name + "_"; auto list_output_names_func = flist_outputs.get(quantized_op_node->op(), nullptr); // Here it's assumed that the quantized_op node only produces three outputs: diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc new file mode 100644 index 000000000000..ed575bbb9a4e --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -0,0 +1,653 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. 
+*/ + +#if MXNET_USE_MKLDNN == 1 +#include +#include +#include "../common.h" +#include "../../../imperative/imperative_utils.h" +#include "../../../imperative/cached_op.h" +#include "../../nn/convolution-inl.h" +#include "../../nn/batch_norm-inl.h" +#include "../../nn/activation-inl.h" +#include "../../nn/mkldnn/mkldnn_base-inl.h" +#include "../../nn/mkldnn/mkldnn_ops-inl.h" +#include "../../nn/mkldnn/mkldnn_convolution-inl.h" +#include "../../quantization/quantization_utils.h" +namespace mxnet { +namespace op { + +struct MKLDNNConvFusionParam { + MKLDNNConvFullParam full_conv_param; + std::shared_ptr bn_param; +}; + +static const size_t uint8_range = 255; +static const size_t int8_range = 127; + +enum MKLDNNConvOpOutputs { kOut, kMin, kMax }; + +template +static void UpdateConvWeightBias(NDArray &weight, NDArray &bias, bool no_bias, + const NDArray &gamma, const NDArray &beta, + const NDArray &mean, const NDArray &variance, + const BatchNormParam *param) { + // TODO(Zhennan): Handle the case weight is not in dims 4. + NDArray update_weight = NDArray(weight.storage_type(), weight.shape(), + weight.ctx(), true, weight.dtype()); + NDArray update_bias = NDArray(beta.storage_type(), beta.shape(), beta.ctx(), + true, beta.dtype()); + DType *weight_ptr = weight.data().dptr(); + DType *bias_ptr = no_bias ? nullptr : bias.data().dptr(); + DType *gamma_ptr = gamma.Reorder2Default().data().dptr(); + DType *beta_ptr = beta.Reorder2Default().data().dptr(); + DType *mean_ptr = mean.Reorder2Default().data().dptr(); + DType *var_ptr = variance.Reorder2Default().data().dptr(); + DType *update_weight_ptr = update_weight.data().dptr(); + DType *update_bias_ptr = update_bias.data().dptr(); + size_t channel = gamma.shape()[0]; + size_t offset = weight.shape()[1] * weight.shape()[2] * weight.shape()[3]; +#pragma omp parallel for + for (size_t c = 0; c < channel; ++c) { + DType *p1 = reinterpret_cast(weight_ptr + c * offset); + DType *p2 = reinterpret_cast(update_weight_ptr + c * offset); + DType alpha = (param->fix_gamma ? static_cast(1.0f) : gamma_ptr[c]) / + sqrt(var_ptr[c] + param->eps); + + if (bias_ptr) + update_bias_ptr[c] = beta_ptr[c] + alpha * (bias_ptr[c] - mean_ptr[c]); + else + update_bias_ptr[c] = beta_ptr[c] - alpha * mean_ptr[c]; + + for (size_t k = 0; k < offset; ++k) { + p2[k] = p1[k] * alpha; + } + } + weight = update_weight; + bias = update_bias; +} + +static inline size_t GetInSumIndex(const MKLDNNConvFusionParam ¶m) { + return 2 + (param.full_conv_param.conv_param.no_bias ? 0 : 1) + + (param.full_conv_param.mkldnn_param.with_bn ? 4 : 0); +} + +template +static void QuantizeConvWeightBias(NDArray &weight, NDArray &bias, + bool has_bias, float data_min, + float data_max, + bool weight_channelwise_scale, + std::vector &weight_scales) { + using red::limits::MaxValue; + using red::limits::MinValue; + DType *weight_ptr = weight.data().dptr(); + NDArray quantized_weight = NDArray(weight.storage_type(), weight.shape(), + weight.ctx(), true, mshadow::kInt8); + int8_t *quan_weight_ptr = quantized_weight.data().dptr(); + size_t channel = weight.shape()[0]; + + //TODO(Zhennan): Handle the case weight is not in dims 4. 
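A compact summary of the per-channel arithmetic in this file, as a sketch assuming 4-D OIHW weights, with uint8_range = 255 and int8_range = 127 as defined above (the quantized input data is assumed non-negative, i.e. uint8):

    // BatchNorm folding (UpdateConvWeightBias, above):
    //   alpha_c   = gamma_c / sqrt(var_c + eps)            (1 / sqrt(var_c + eps) when fix_gamma)
    //   weight'_c = weight_c * alpha_c
    //   bias'_c   = beta_c + alpha_c * (bias_c - mean_c)   (bias_c = 0 when no_bias)
    // Weight/bias quantization (QuantizeConvWeightBias, below); with
    // weight_channelwise_scale disabled a single scale over all channels is used:
    //   weight_scale_c = int8_range / max(|min(weight'_c)|, |max(weight'_c)|)
    //   q_weight_c     = round(weight_scale_c * weight'_c)                    -> int8
    //   data_scale     = uint8_range / max(|data_min|, |data_max|)
    //   q_bias_c       = round(weight_scale_c * data_scale * bias'_c)         -> int32
    // Requantization scales collected later in SgMKLDNNConvOperator::Forward,
    // with out_min / out_max taken from min_calib_range / max_calib_range and
    // quantized_out_range = 255 for uint8 output, 127 for int8 output:
    //   output_scale         = quantized_out_range / max(|out_min|, |out_max|)
    //   requantize_scales[c] = output_scale / (data_scale * weight_scale_c)
    //   sum_scale            = output_scale / sum_in_scale                    (with_sum only)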
+ size_t offset = weight.shape()[1] * weight.shape()[2] * weight.shape()[3]; + std::vector weight_c_min(channel, MaxValue()); + std::vector weight_c_max(channel, MinValue()); +#pragma omp parallel for + for (size_t c = 0; c < channel; ++c) { + DType *p1 = weight_ptr + c * offset; + for (size_t k = 0; k < offset; ++k) { + if (weight_c_min[c] > p1[k]) + weight_c_min[c] = p1[k]; + if (weight_c_max[c] < p1[k]) + weight_c_max[c] = p1[k]; + } + } + + if (weight_channelwise_scale) { + weight_scales.resize(channel); +#pragma omp parallel for + for (size_t c = 0; c < channel; ++c) { + DType weight_range = MaxAbs(weight_c_min[c], weight_c_max[c]); + weight_scales[c] = int8_range / weight_range; + DType *fp_ptr = weight_ptr + c * offset; + int8_t *quan_ptr = quan_weight_ptr + c * offset; + for (size_t k = 0; k < offset; ++k) { + quan_ptr[k] = std::round(weight_scales[c] * fp_ptr[k]); + } + } + } + else { + DType total_min = weight_c_min[0]; + DType total_max = weight_c_max[0]; + for (size_t c = 0; c < channel; ++c) { + if (total_min > weight_c_min[c]) total_min = weight_c_min[c]; + if (total_max < weight_c_max[c]) total_max = weight_c_max[c]; + } + weight_scales.resize(1); + DType weight_range = MaxAbs(total_min, total_max); + weight_scales[0] = int8_range / weight_range; +#pragma omp parallel for + for (size_t c = 0; c < channel; ++c) { + DType *fp_ptr = weight_ptr + c * offset; + int8_t *quan_ptr = quan_weight_ptr + c * offset; + for (size_t k = 0; k < offset; ++k) { + quan_ptr[k] = std::round(weight_scales[0] * fp_ptr[k]); + } + } + } + + weight = quantized_weight; + if (has_bias) { + DType *bias_ptr = bias.data().dptr(); + NDArray quantized_bias = NDArray(bias.storage_type(), bias.shape(), + bias.ctx(), true, mshadow::kInt32); + int32_t *quan_bias_ptr = quantized_bias.data().dptr(); + DType data_scale = uint8_range / MaxAbs(data_min, data_max); + for (size_t c = 0; c < channel; ++c) { + auto weight_scale = + weight_channelwise_scale ? 
weight_scales[c] : weight_scales[0]; + float bias_scale = weight_scale * data_scale; + quan_bias_ptr[c] = std::round(bias_scale * bias_ptr[c]); + } + bias = quantized_bias; + } +} + +static void ConvFusionFallBackCompute() { + LOG(FATAL) << "Don't know how to do ConvFusionFallBackCompute!"; +} + +static void ConvolutionFusionComputeExCPU(const MKLDNNConvFullParam &full_param, + const OpContext &ctx, + MKLDNNConvForward &fwd, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + if (SupportMKLDNNConv(full_param.conv_param, inputs[0])) { + // MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); + MKLDNNConvolutionForwardFullFeature(full_param, ctx, fwd, inputs, req, outputs); + // MKLDNN_OPCHECK_RUN(ConvolutionCompute, attrs, ctx, inputs, req, + // outputs); + return; + } + ConvFusionFallBackCompute(); +} + +class SgMKLDNNConvOperator { + public: + explicit SgMKLDNNConvOperator(const nnvm::NodeAttrs &attrs) + : initalized(false), + subgraph_sym_(*attrs.subgraphs[0]), + param(nnvm::get(attrs.parsed)) {} + + void Forward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); + + void Backward(const OpContext &ctx, const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + LOG(FATAL) << "Not implemented: subgraph mkldnn Conv only supports " + "inference computation"; + } + + private: + bool initalized; + nnvm::Symbol subgraph_sym_; + MKLDNNConvFusionParam param; + std::shared_ptr fwd; + NDArray cached_weight_; + NDArray cached_bias_; + NDArray cached_data_; + NDArray cached_output_; + float cached_data_min; + float cached_data_max; + float cached_sum_min; + float cached_sum_max; + size_t weight_ver; + size_t bias_ver; + std::vector weight_scales; +}; + +void SgMKLDNNConvOperator::Forward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + auto &full_conv_param = param.full_conv_param; + auto &mkldnn_param = param.full_conv_param.mkldnn_param; + auto &conv_param = param.full_conv_param.conv_param; + auto bn_param = param.bn_param.get(); + size_t input_size = + 2 + (conv_param.no_bias ? 0 : 1) + (mkldnn_param.with_bn ? 4 : 0) + + (mkldnn_param.with_sum ? 1 : 0) + + (mkldnn_param.quantized + ? 2 + (param.full_conv_param.mkldnn_param.with_sum ? 2 : 0) + : 0); + CHECK_EQ(inputs.size(), input_size); + size_t idx = 0; + + auto in_data = idx++; + auto in_weight = idx++; + auto in_bias = conv_param.no_bias ? 0 : (idx++); + auto in_gamma = mkldnn_param.with_bn ? (idx++) : 0; + auto in_beta = mkldnn_param.with_bn ? (idx++) : 0; + auto in_mean = mkldnn_param.with_bn ? (idx++) : 0; + auto in_var = mkldnn_param.with_bn ? (idx++) : 0; + auto in_sum = mkldnn_param.with_sum ? (idx++) : 0; + float data_min = + mkldnn_param.quantized ? inputs[idx++].data().dptr()[0] : 0.0; + float data_max = + mkldnn_param.quantized ? inputs[idx++].data().dptr()[0] : 0.0; + float sum_min = (mkldnn_param.with_sum && mkldnn_param.quantized) + ? inputs[idx++].data().dptr()[0] + : 0.0; + float sum_max = (mkldnn_param.with_sum && mkldnn_param.quantized) + ? inputs[idx++].data().dptr()[0] + : 0.0; + float *out_min_ptr = + mkldnn_param.quantized ? outputs[kMin].data().dptr() : nullptr; + float *out_max_ptr = + mkldnn_param.quantized ? 
outputs[kMax].data().dptr() : nullptr; + CHECK_EQ(input_size, idx); + bool has_bias = mkldnn_param.with_bn || !conv_param.no_bias; + cached_data_ = inputs[in_data]; + if (mkldnn_param.with_sum) + cached_output_ = inputs[in_sum]; + else + cached_output_ = outputs[kOut]; + + // Check data change + // TODO(zhennan): Only update cached_* changed. + if (initalized) { + if (mkldnn_param.with_bn) { + if (weight_ver != inputs[in_weight].version() || + ((!conv_param.no_bias) && bias_ver != inputs[in_bias].version())) { + initalized = false; + } + } + if (initalized && mkldnn_param.quantized) { + if (cached_data_min != data_min || cached_data_max != data_max || + cached_sum_min != sum_min || cached_sum_max != sum_max || + weight_ver != inputs[in_weight].version() || + ((!conv_param.no_bias) && bias_ver != inputs[in_bias].version())) { + initalized = false; + } + } + } + + if (mkldnn_param.quantized) { + *out_min_ptr = mkldnn_param.min_calib_range.has_value() + ? mkldnn_param.min_calib_range.value() + : 0.0; + *out_max_ptr = mkldnn_param.max_calib_range.has_value() + ? mkldnn_param.max_calib_range.value() + : 1.0; + } + + if (!initalized) { + cached_data_min = data_min; + cached_data_max = data_max; + cached_sum_min = sum_min; + cached_sum_max = sum_max; + full_conv_param.sum_scale = 1.0; + cached_weight_ = inputs[in_weight].Reorder2Default(); + weight_ver = inputs[in_weight].version(); + if (!conv_param.no_bias) { + cached_bias_ = inputs[in_bias].Reorder2Default(); + bias_ver = inputs[in_bias].version(); + } else { + cached_bias_ = NDArray(); + } + + // Update weight and bias after bn fusion. + if (mkldnn_param.with_bn) { + CHECK_EQ(inputs[in_weight].dtype(), inputs[in_gamma].dtype()); + CHECK_EQ(inputs[in_weight].dtype(), inputs[in_beta].dtype()); + CHECK_EQ(inputs[in_weight].dtype(), inputs[in_var].dtype()); + MSHADOW_REAL_TYPE_SWITCH(inputs[in_weight].dtype(), DType, { + UpdateConvWeightBias( + cached_weight_, cached_bias_, conv_param.no_bias, inputs[in_gamma], + inputs[in_beta], inputs[in_mean], inputs[in_var], bn_param); + }); + } + // Quantize weight and bias. + if (mkldnn_param.quantized) { + MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, { + QuantizeConvWeightBias( + cached_weight_, cached_bias_, has_bias, data_min, data_max, + mkldnn_param.weight_channelwise_scale, weight_scales); + }); + // Collect scale. + size_t channel = cached_weight_.shape()[0]; + float data_scale = uint8_range / MaxAbs(data_min, data_max); + float sum_in_scale = 1.0; + float out_range; + float quantized_out_range; + if (data_min < 0.0) { + // TODO(zhennan): we need to use offset to convert int8 to uint8. + LOG(FATAL) << "Can't handle negetive value for QuantizeData"; + } + if (mkldnn_param.with_sum) { + auto quantized_sum_range = sum_min < 0 ? int8_range : uint8_range; + sum_in_scale = quantized_sum_range / MaxAbs(sum_min, sum_max); + } + quantized_out_range = + IsOutputUInt8(mkldnn_param) ? uint8_range : int8_range; + out_range = MaxAbs(*out_min_ptr, *out_max_ptr); + float output_scale = quantized_out_range / out_range; + full_conv_param.requantize_scales.resize(channel); + for (size_t c = 0; c < channel; c++) { + auto weight_scale = mkldnn_param.weight_channelwise_scale + ? weight_scales[c] + : weight_scales[0]; + full_conv_param.requantize_scales[c] = + output_scale / data_scale / weight_scale; + } + if (mkldnn_param.with_sum) + full_conv_param.sum_scale = output_scale / sum_in_scale; + } + fwd.reset(new MKLDNNConvForward( + full_conv_param, ctx.is_train, cached_data_, cached_weight_, + has_bias ? 
&cached_bias_ : nullptr, cached_output_)); + } + initalized = true; + std::vector new_inputs; + std::vector new_req; + if (has_bias) { + new_inputs = {cached_data_, cached_weight_, cached_bias_}; + new_req = {req[in_data], req[in_weight], req[in_bias]}; + } else { + new_inputs = {cached_data_, cached_weight_}; + new_req = {req[in_data], req[in_weight]}; + } + ConvolutionFusionComputeExCPU(full_conv_param, ctx, *fwd, new_inputs, new_req, + {cached_output_}); + + if (mkldnn_param.with_sum) { + auto out = const_cast(outputs[kOut]); + out.UpdateMKLDNNMemDesc(); + } +} + +static void SgMKLDNNConvOpForward(const OpStatePtr &state_ptr, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + SgMKLDNNConvOperator &op = state_ptr.get_state(); + op.Forward(ctx, inputs, req, outputs); +} + +static uint32_t SgMKLDNNConvNumInputs(const NodeAttrs &attrs) { + auto const ¶m = nnvm::get(attrs.parsed); + auto num_input = DefaultSubgraphOpNumInputs(attrs); + if (param.full_conv_param.mkldnn_param.quantized) + return num_input + 2 + param.full_conv_param.mkldnn_param.with_sum ? 2 : 0; + else + return num_input; +} + +static void SgMKLDNNConvParamParser(nnvm::NodeAttrs *attrs) { + MKLDNNConvFusionParam param_; + try { + param_.full_conv_param.mkldnn_param.Init(attrs->dict); + } catch (const dmlc::ParamError &e) { + std::ostringstream os; + os << e.what(); + os << ", in operator " << attrs->op->name << "(" + << "name=\"" << attrs->name << "\""; + for (const auto &k : attrs->dict) { + os << ", " << k.first << "=\"" << k.second << "\""; + } + os << ")"; + throw dmlc::ParamError(os.str()); + } + auto subgraph_sym = attrs->subgraphs[0]; + DFSVisit(subgraph_sym->outputs, [&](const nnvm::NodePtr &node) { + if (node->is_variable()) return; + auto &node_name = node->op()->name; + if (node_name == "BatchNorm") { + CHECK_EQ(param_.full_conv_param.mkldnn_param.with_bn, true); + CHECK(param_.bn_param.get() == nullptr); + param_.bn_param = std::make_shared( + nnvm::get(node->attrs.parsed)); + } else if (node_name == "Convolution") { + param_.full_conv_param.conv_param = + nnvm::get(node->attrs.parsed); + } + }); + attrs->parsed = std::move(param_); +} + +static std::vector SgMKLDNNConvListInputNames( + const NodeAttrs &attrs) { + auto const ¶m = nnvm::get(attrs.parsed); + std::vector input_names = DefaultSubgraphOpListInputs(attrs); + if (param.full_conv_param.mkldnn_param.quantized) { + input_names.emplace_back("data_min"); + input_names.emplace_back("data_max"); + if (param.full_conv_param.mkldnn_param.with_sum) { + input_names.emplace_back("sum_min"); + input_names.emplace_back("sum_max"); + } + } + return input_names; +} + +static std::vector SgMKLDNNConvListOutputNames( + const NodeAttrs &attrs) { + auto const ¶m = nnvm::get(attrs.parsed); + if (param.full_conv_param.mkldnn_param.quantized) + return std::vector{"output", "output_min", "output_max"}; + else + return std::vector{"output"}; +} + +static OpStatePtr CreateSgMKLDNNConvState(const nnvm::NodeAttrs &attrs, + Context ctx, + const std::vector &in_shapes, + const std::vector &in_types) { + return OpStatePtr::Create(attrs); +} + +template +static void FilterMinMaxIndice(const MKLDNNConvParam &mkldnn_param, + std::vector *in_shapes, + std::vector *out_shapes, + std::vector &base_in_shapes, + std::vector &base_out_shapes, + std::unordered_set &minmax_indice) { + base_out_shapes.push_back(out_shapes->at(0)); + size_t last = in_shapes->size() - 1; + if (mkldnn_param.with_sum) { + minmax_indice.insert(last); + 
minmax_indice.insert(last - 1); + minmax_indice.insert(last - 2); + minmax_indice.insert(last - 3); + base_in_shapes = + std::vector(in_shapes->begin(), in_shapes->end() - 4); + } else { + minmax_indice.insert(last); + minmax_indice.insert(last - 1); + base_in_shapes = + std::vector(in_shapes->begin(), in_shapes->end() - 2); + } +} + +static bool SgMKLDNNConvInferShape(const nnvm::NodeAttrs &attrs, + std::vector *in_shapes, + std::vector *out_shapes) { + auto const ¶m = nnvm::get(attrs.parsed); + if (param.full_conv_param.mkldnn_param.quantized) { + std::unordered_set minmax_indice; + std::vector base_in_shapes; + std::vector base_out_shapes; + + FilterMinMaxIndice(param.full_conv_param.mkldnn_param, in_shapes, + out_shapes, base_in_shapes, base_out_shapes, + minmax_indice); + bool result = + DefaultSubgraphOpShape(attrs, &base_in_shapes, &base_out_shapes); + size_t base_idx = 0; + for (size_t i = 0; i < in_shapes->size(); ++i) { + if (minmax_indice.count(i)) { + SHAPE_ASSIGN_CHECK(*in_shapes, i, Shape1(1)); + } else { + in_shapes->at(i) = base_in_shapes[base_idx++]; + } + } + out_shapes->at(0) = base_out_shapes[0]; + SHAPE_ASSIGN_CHECK(*out_shapes, 1, Shape1(1)); + SHAPE_ASSIGN_CHECK(*out_shapes, 2, Shape1(1)); + return result; + } else { + return DefaultSubgraphOpShape(attrs, in_shapes, out_shapes); + } +} + +static bool SgMKLDNNConvInferType(const nnvm::NodeAttrs &attrs, + std::vector *in_types, + std::vector *out_types) { + auto const ¶m = nnvm::get(attrs.parsed); + if (param.full_conv_param.mkldnn_param.quantized) { + std::unordered_set minmax_indice; + std::vector base_in_types; + std::vector base_out_types; + FilterMinMaxIndice(param.full_conv_param.mkldnn_param, in_types, + out_types, base_in_types, base_out_types, + minmax_indice); + // Override data type to fp32 for default infer type as bn doesn't support + // uint8. 
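The filtering performed by FilterMinMaxIndice() above, and repeated for type and storage inference below, can be illustrated with the quantized conv + bn + sum case (a sketch; the FP32 names are simply the subgraph inputs in topological order):

    // inputs:  [data, weight, bias, gamma, beta, mean, var, sum_input,   <- FP32 subgraph inputs
    //           data_min, data_max, sum_min, sum_max]                    <- scalar thresholds (last 4,
    //                                                                       or last 2 when with_sum is off)
    // The threshold slots are stripped before delegating to the
    // DefaultSubgraphOpShape/Type/StorageType helpers, then forced to Shape1(1),
    // kFloat32 and default storage; the quantized op itself exposes three
    // outputs: (output, output_min, output_max).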
+ int orig_data = base_in_types[0]; + base_in_types[0] = mshadow::kFloat32; + int orig_sum = base_in_types[0]; + if (param.full_conv_param.mkldnn_param.with_sum) { + auto sum_index = GetInSumIndex(param); + orig_sum = base_in_types[sum_index]; + base_in_types[sum_index] = mshadow::kFloat32; + } + bool result = DefaultSubgraphOpType(attrs, &base_in_types, &base_out_types); + base_in_types[0] = orig_data; + if (param.full_conv_param.mkldnn_param.with_sum) { + auto sum_index = GetInSumIndex(param); + base_in_types[sum_index] = orig_sum; + } + size_t base_idx = 0; + for (size_t i = 0; i < in_types->size(); ++i) { + if (minmax_indice.count(i)) { + TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32); + } else { + in_types->at(i) = base_in_types[base_idx++]; + } + } + if (IsOutputUInt8(param.full_conv_param.mkldnn_param)) { + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8); + } else { + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8); + } + TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32); + return result; + } else { + return DefaultSubgraphOpType(attrs, in_types, out_types); + } +} + +static bool SgMKLDNNConvOpStorageType(const nnvm::NodeAttrs &attrs, + const int dev_mask, + DispatchMode *dispatch_mode, + std::vector *in_stypes, + std::vector *out_stypes) { + auto const ¶m = nnvm::get(attrs.parsed); + if (param.full_conv_param.mkldnn_param.quantized) { + std::unordered_set minmax_indice; + std::vector base_in_stypes; + std::vector base_out_stypes; + FilterMinMaxIndice(param.full_conv_param.mkldnn_param, in_stypes, + out_stypes, base_in_stypes, base_out_stypes, + minmax_indice); + bool result = DefaultSubgraphOpStorageType( + attrs, dev_mask, dispatch_mode, &base_in_stypes, &base_out_stypes); + size_t base_idx = 0; + for (size_t i = 0; i < in_stypes->size(); ++i) { + if (minmax_indice.count(i)) { + type_assign(&in_stypes->at(i), mxnet::kDefaultStorage); + } else { + in_stypes->at(i) = base_in_stypes[base_idx++]; + } + } + out_stypes->at(0) = base_out_stypes[0]; + type_assign(&out_stypes->at(1), mxnet::kDefaultStorage); + type_assign(&out_stypes->at(2), mxnet::kDefaultStorage); + return result; + } else { + return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, + in_stypes, out_stypes); + } +} + +std::vector> SgMKLDNNConvInplaceOption( + const NodeAttrs &attrs) { + auto const ¶m = nnvm::get(attrs.parsed); + if (param.full_conv_param.mkldnn_param.with_sum) { + return std::vector>{{GetInSumIndex(param), 0}}; + } else { + return std::vector>(); + } +} + +nnvm::NodePtr SgMKLDNNConvQuantizedOp(const NodeAttrs& attrs) { + nnvm::NodePtr node = nnvm::Node::Create(); + node->attrs.op = Op::Get("_sg_mkldnn_conv"); + node->attrs.name = "quantized_" + attrs.name; + node->attrs.dict = attrs.dict; + node->attrs.dict["quantized"] = "true"; + node->attrs.subgraphs.reserve(attrs.subgraphs.size()); + for (auto sub : attrs.subgraphs) { + node->attrs.subgraphs.push_back(sub); + } + node->op()->attr_parser(&(node->attrs)); + return node; +} + +NNVM_REGISTER_OP(_sg_mkldnn_conv) +.describe(R"code(_sg_mkldnn_conv)code" ADD_FILELINE) +.set_num_inputs(SgMKLDNNConvNumInputs) +.set_num_outputs([](const NodeAttrs& attrs) { + auto const ¶m = nnvm::get(attrs.parsed); + return param.full_conv_param.mkldnn_param.quantized ? 
3 : 1; +}) +.set_attr_parser(SgMKLDNNConvParamParser) +.set_attr("FListInputNames", SgMKLDNNConvListInputNames) +.set_attr("FListOutputNames", SgMKLDNNConvListOutputNames) +.set_attr("FCreateOpState", CreateSgMKLDNNConvState) +.set_attr("FInferShape", SgMKLDNNConvInferShape) +.set_attr("FInferType", SgMKLDNNConvInferType) +.set_attr("FInferStorageType", SgMKLDNNConvOpStorageType) +.set_attr("FStatefulComputeEx", SgMKLDNNConvOpForward) +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +.set_attr("FMutateInputs", + DefaultSubgraphOpMutableInputs) +.set_attr("key_var_num_args", "num_args") +.set_attr("FInplaceOption", SgMKLDNNConvInplaceOption) +.set_attr("FQuantizedOp", SgMKLDNNConvQuantizedOp); +} // namespace op +} // namespace mxnet +#endif // if MXNET_USE_MKLDNN == 1 diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc new file mode 100644 index 000000000000..2a75a573a52b --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_CONV_H_ +#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_CONV_H_ + +#if MXNET_USE_MKLDNN == 1 +#include +#include +#include "../common.h" +#include "../subgraph_property.h" +#include "../../nn/convolution-inl.h" +#include "../../nn/activation-inl.h" +#include "../../nn/mkldnn/mkldnn_convolution-inl.h" + +namespace mxnet { +namespace op { +class SgMKLDNNConvSelector : public SubgraphSelector { + public: + /*! 
\brief pattern match status */ + enum SelectStatus { + sFail = 0, + sStart, + sBN, + sSum, + sSuccess, + }; + + private: + bool disable_all; + bool disable_conv_bn; + bool disable_conv_relu; + bool disable_conv_sum; + SelectStatus status; + std::vector matched_list; + + public: + SgMKLDNNConvSelector(int dis_all, int dis_conv_bn, int dis_conv_relu, int dis_conv_sum) + : disable_all(dis_all), + disable_conv_bn(dis_conv_bn), + disable_conv_relu(dis_conv_relu), + disable_conv_sum(dis_conv_sum) {} + + virtual bool Select(const nnvm::Node &n) override { + bool match = + (!disable_all) && (!n.is_variable()) && (n.op()->name == "Convolution"); + if (match) { + status = sStart; + matched_list.clear(); + matched_list.push_back(&n); + return true; + } + return false; + } + + virtual bool SelectInput(const nnvm::Node &n, + const nnvm::Node &new_node) override { + return false; + } + + virtual bool SelectOutput(const nnvm::Node &n, + const nnvm::Node &new_node) override { + if (status == sFail || status == sSuccess || new_node.is_variable()) + return false; + // If n isn't the last matched node, then we encoutered a internal + // branch, we should pop out the node behind n and stop fusion. + if (matched_list.back() != &n) { + while (matched_list.back() != &n) { + matched_list.pop_back(); + } + status = sSuccess; + return false; + } + // Use status machine to do selection. The status change is + // sStart -> sBN -> sSum -> sSuccess + switch (status) { + case sStart: + if ((!disable_conv_bn) && new_node.op()->name == "BatchNorm") { + matched_list.push_back(&new_node); + status = sBN; + return true; + } + case sBN: + if ((!disable_conv_sum) && new_node.op()->name == "elemwise_add") { + matched_list.push_back(&new_node); + status = sSum; + return true; + } + case sSum: + default: + if ((!disable_conv_relu) && new_node.op()->name == "Activation") { + const ActivationParam ¶m = + nnvm::get(new_node.attrs.parsed); + if (param.act_type == activation::kReLU) { + matched_list.push_back(&new_node); + // If we find conv+relu, then we can't match bn anymore. + if (status == sStart) status = sBN; + return true; + } else { + status = sSuccess; + return false; + } + } + status = sSuccess; + return false; + } + } + + virtual std::vector Filter( + const std::vector &candidates) override { + if (status == sFail) { + return std::vector(0); + } else { + return candidates; + } + } +}; + +class SgMKLDNNConvProperty : public SubgraphProperty { + public: + SgMKLDNNConvProperty() { + disable_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSION", 0); + disable_conv_bn = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSION_CONV_BN", 0); + disable_conv_relu = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSION_CONV_RELU", 0); + disable_conv_sum = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSION_CONV_SUM", 0); + + if (disable_all) { + LOG(INFO) << "MKLDNN Convolution fusion pass is disabled."; + } else { + LOG(INFO) << "Start to execute MKLDNN Convolution fusion pass."; + } + } + static SubgraphPropertyPtr Create() { + return std::make_shared(); + } + nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym, + const int subgraph_id = 0) const override { + nnvm::NodePtr n = nnvm::Node::Create(); + // This op has single output, remove duplicated. 
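A recap of how the selector above admits a fusion candidate, since the naming code below depends on it (a sketch): SelectOutput walks the convolution's consumers with the state machine

    // sStart --BatchNorm--> sBN --elemwise_add--> sSum --Activation(relu)--> sSuccess

and the switch cases fall through, so later stages can be matched directly; the chains it can accept therefore include conv, conv+bn, conv+relu, conv+add+relu, conv+bn+relu and conv+bn+add+relu. Each stage can be vetoed through the MXNET_DISABLE_MKLDNN_FUSION and MXNET_DISABLE_MKLDNN_FUSION_CONV_{BN,RELU,SUM} environment variables read by SgMKLDNNConvProperty, and Filter() discards the candidate only when matching failed outright (sFail). CreateSubgraphNode() below then encodes whichever chain matched in the node name (for example "sg_mkldnn_conv_bn_add_relu_0") and in the with_bn / with_sum / with_relu / with_postsum_relu attributes.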
+ auto last_node = sym.outputs[0].node; + nnvm::Symbol new_sym; + new_sym.outputs.emplace_back(nnvm::NodeEntry{last_node, 0, 0}); + std::ostringstream node_name; + node_name << "sg_mkldnn_"; + bool _with_sum = false; + DFSVisit(new_sym.outputs, [&](const nnvm::NodePtr &node) { + if (node->is_variable()) return; + auto &sub_name = node->op()->name; + if (sub_name == "Convolution") { + node_name << "conv_"; + } else if (sub_name == "BatchNorm") { + node_name << "bn_"; + n->attrs.dict["with_bn"] = "true"; + } else if (sub_name == "elemwise_add") { + node_name << "add_"; + n->attrs.dict["with_sum"] = "true"; + _with_sum = true; + + } else if (sub_name == "Activation") { + node_name << "relu_"; + if (!_with_sum) { + n->attrs.dict["with_relu"] = "true"; + } else { + n->attrs.dict["with_postsum_relu"] = "true"; + } + } + }); + node_name << std::to_string(subgraph_id); + n->attrs.name = node_name.str(); + n->attrs.op = Op::Get("_sg_mkldnn_conv"); + CHECK(n->attrs.op); + n->attrs.subgraphs.emplace_back(std::make_shared(new_sym)); + n->op()->attr_parser(&(n->attrs)); + return n; + } + + virtual SubgraphSelectorPtr CreateSubgraphSelector() const override { + auto selector = std::make_shared( + disable_all, disable_conv_bn, disable_conv_relu, disable_conv_sum); + return selector; + } + + virtual void ConnectSubgraphOutput( + const nnvm::NodePtr n, + std::vector &output_entries) const override { + // Connect all extern output entries to output[0] + for (size_t i = 0; i < output_entries.size(); ++i) { + *output_entries[i] = nnvm::NodeEntry{n, 0, 0}; + } + } + + virtual void ConnectSubgraphInput( + const nnvm::NodePtr n, std::vector &input_entries, + std::vector &orig_input_entries) const override { + auto sym = n->attrs.subgraphs[0]; + std::unordered_set node_sets; + DFSVisit(sym->outputs, [&](const nnvm::NodePtr &node) { + if (node->is_variable()) return; + node_sets.insert(node.get()); + if (node->op()->name == "elemwise_add") { + // Make sure n is the left operand of sum, if not, + // switch sum operands sequence to ensure that + // the extra sum operand stays in the last of inputs. + if (node_sets.count(node->inputs[1].node.get())) { + auto tmp = node->inputs[1]; + node->inputs[1] = node->inputs[0]; + node->inputs[0] = tmp; + std::rotate(input_entries.begin(), input_entries.begin() + 1, + input_entries.end()); + std::rotate(orig_input_entries.begin(), + orig_input_entries.begin() + 1, orig_input_entries.end()); + } + } + }); + n->inputs = orig_input_entries; + } + + private: + int disable_all; + int disable_conv_bn; + int disable_conv_relu; + int disable_conv_sum; +}; + +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty); + +} // namespace op +} // namespace mxnet +#endif // if MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_CONV_H_ diff --git a/src/operator/subgraph/partition_graph.cc b/src/operator/subgraph/partition_graph.cc index 315f7eec00c6..d7adde34a7db 100644 --- a/src/operator/subgraph/partition_graph.cc +++ b/src/operator/subgraph/partition_graph.cc @@ -653,9 +653,9 @@ void CreateSubgraphNode(Graph* g, nnvm::NodePtr n = subg_prop->CreateSubgraphNode(sym, subgraph_id); // Connect the external nodes to the subgraph node. 
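The replacement below is the point of the two hooks added in this patch: instead of the fixed one-entry-per-output loop, the partition pass now defers to the property. For the MKLDNN conv property, ConnectSubgraphOutput routes every external consumer to entry 0 (the fused node exposes a single data output, plus output_min / output_max only when quantized), and ConnectSubgraphInput swaps the operands of a fused elemwise_add so the convolution result is the left operand and the external sum input is rotated to the end of the input list, which is the in_sum slot SgMKLDNNConvOperator::Forward expects. A sketch of the output hook's effect:

    // default:      *output_entries[i] = NodeEntry{n, i, 0};   // one entry per subgraph output
    // MKLDNN conv:  *output_entries[i] = NodeEntry{n, 0, 0};   // every consumer reads output 0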
- for (size_t i = 0; i < output_entries.size(); ++i) { - *output_entries[i] = nnvm::NodeEntry{n, static_cast(i), 0}; - } + subg_prop->ConnectSubgraphOutput(n, output_entries); + subg_prop->ConnectSubgraphInput(n, input_entries, orig_input_entries); + n->inputs = orig_input_entries; const auto& indexed_graph = g->indexed_graph(); for (size_t i = 0; i < n->inputs.size(); ++i) { diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h index cfbc1f837337..e45a8d99c76e 100644 --- a/src/operator/subgraph/subgraph_property.h +++ b/src/operator/subgraph/subgraph_property.h @@ -92,6 +92,22 @@ class SubgraphProperty { // execute the operators in the subgraph. virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &s, const int subgraph_id = 0) const = 0; + // Connect subgraph internal output with external output entries. By default, + // each output entry will connect to an unique internal output. + virtual void ConnectSubgraphOutput( + const nnvm::NodePtr n, + std::vector &output_entries) const { + for (size_t i = 0; i < output_entries.size(); ++i) { + *output_entries[i] = nnvm::NodeEntry{n, static_cast(i), 0}; + } + } + // Connect subgraph internal input with external input entries. By default, + // each input entry will connect in top sorted order. + virtual void ConnectSubgraphInput( + const nnvm::NodePtr n, std::vector &input_entries, + std::vector &orig_input_entries) const { + n->inputs = orig_input_entries; + } // set an attr with name in the attr map template SubgraphProperty& SetAttr(const std::string& name, const T& value) { From 9134b4d2f1bcd18aca82329baaf49d537361b482 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Wed, 12 Sep 2018 22:17:52 +0800 Subject: [PATCH 02/28] Fix lint --- .../quantization/imagenet_gen_qsym_mkldnn.py | 32 +---- include/mxnet/c_api.h | 2 +- python/mxnet/contrib/quantization.py | 5 +- src/c_api/c_api_symbolic.cc | 4 +- .../nn/mkldnn/mkldnn_convolution-inl.h | 3 +- src/operator/nn/mkldnn/mkldnn_convolution.cc | 23 ++-- .../quantization/quantize_graph_pass.cc | 7 +- src/operator/subgraph/mkldnn/mkldnn_conv.cc | 129 +++++++++--------- .../subgraph/mkldnn/mkldnn_conv_property.cc | 43 +++--- src/operator/subgraph/partition_graph.cc | 4 +- src/operator/subgraph/subgraph_property.h | 13 +- 11 files changed, 123 insertions(+), 142 deletions(-) diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py index 59444c0ad6df..55c57c005e67 100644 --- a/example/quantization/imagenet_gen_qsym_mkldnn.py +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -55,10 +55,8 @@ def save_params(fname, arg_params, aux_params, logger=None): if __name__ == '__main__': parser = argparse.ArgumentParser(description='Generate a calibrated quantized model from a FP32 model with mkldnn support') - parser.add_argument('--ctx', type=str, default='cpu') - parser.add_argument('--model', type=str, choices=['imagenet1k-resnet-152', 'ssd', 'imagenet1k-resnet-50', 'imagenet1k-resnet-50-v1', - 'imagenet1k-inception-v3', 'imagenet1k-inception-bn', 'imagenet1k-vgg-16'], - help='currently only supports imagenet1k-resnet-152, imagenet1k-inception-bn or imagenet1k-vgg-16') + parser.add_argument('--model', type=str, choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'], + help='currently only supports imagenet1k-resnet-152 or imagenet1k-inception-bn') parser.add_argument('--batch-size', type=int, default=32) parser.add_argument('--label-name', type=str, default='softmax_label') 
parser.add_argument('--calib-dataset', type=str, default='data/val_256_q90.rec', @@ -109,14 +107,7 @@ def save_params(fname, arg_params, aux_params, logger=None): 'be calibrated offline if calibration mode is ' 'enabled') args = parser.parse_args() - - if args.ctx == 'gpu': - ctx = mx.gpu(0) - elif args.ctx == 'cpu': - ctx = mx.cpu(0) - else: - raise ValueError('ctx %s is not supported in this script' % args.ctx) - + ctx = mx.cpu(0) logging.basicConfig() logger = logging.getLogger('logger') logger.setLevel(logging.INFO) @@ -158,23 +149,14 @@ def save_params(fname, arg_params, aux_params, logger=None): excluded_sym_names = [] if args.model == 'imagenet1k-resnet-152': rgb_mean = '0,0,0' - if args.ctx == 'gpu': - calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1 - or name.find('sc') != -1 - or name.find('fc') != -1) - else: - calib_layer = lambda name: name.endswith('_output') - excluded_sym_names += ['flatten0', 'fc1'] + calib_layer = lambda name: name.endswith('_output') + excluded_sym_names += ['flatten0', 'fc1'] if exclude_first_conv: excluded_sym_names += ['sg_mkldnn_conv_bn_relu_0', 'pooling0'] elif args.model == 'imagenet1k-inception-bn': rgb_mean = '123.68,116.779,103.939' - if args.ctx == 'gpu': - calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1 - or name.find('fc') != -1) - else: - calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1) - excluded_sym_names += ['flatten', 'fc1'] + calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1) + excluded_sym_names += ['flatten', 'fc1'] if exclude_first_conv: excluded_sym_names += ['conv_1'] else: diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 784f88774763..4d920a4af2bc 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1518,7 +1518,7 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, const SymbolHandle *excluded_symbols, const mx_uint num_offline, const char **offline_params, const char *quantized_dtype, const bool disable_requantize, - bool calib_quantize); + const bool calib_quantize); /*! 
* \brief Set calibration table to node attributes in the sym diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index d8a91f8d8956..8bfa2586919a 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -55,7 +55,6 @@ def _quantize_params(qsym, params, th_dict): params : dict of str->NDArray th_dict: dict of min/max pairs of layers' output """ - print(th_dict) inputs_name = qsym.list_arguments() quantized_params = {} for name in inputs_name: @@ -74,7 +73,6 @@ def _quantize_params(qsym, params, th_dict): elif name.endswith(('_min')): output = name[: - len('_min')] + "_output" if output in th_dict: - print(name) quantized_params[name] = ndarray.array([th_dict[output][0]]) elif name.endswith(('_max')): output = name[: - len('_min')] + "_output" @@ -128,7 +126,8 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, mx_uint(num_offline), c_array(ctypes.c_char_p, offline), c_str(quantized_dtype), - ctypes.c_bool(disable_requantize))) + ctypes.c_bool(disable_requantize), + ctypes.c_bool(calib_quantize_op))) return Symbol(out) diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 98699ef01aaf..0e689c08c3ff 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -652,7 +652,7 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, const char **offline_params, const char *quantized_dtype, const bool disable_requantize, - bool calib_quantize) { + const bool calib_quantize) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); nnvm::Symbol *sym = static_cast(sym_handle); @@ -719,4 +719,4 @@ int MXGenBackendSubgraph(const char *backend, SymbolHandle sym_handle, s->outputs = g.outputs; *ret_sym_handle = s; API_END_HANDLE_ERROR(delete s); -} \ No newline at end of file +} diff --git a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h index 9df3806f8b65..6b3140a9dab0 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h @@ -27,6 +27,7 @@ #if MXNET_USE_MKLDNN == 1 +#include #include #include "../convolution-inl.h" #include "./mkldnn_ops-inl.h" @@ -133,7 +134,7 @@ MKLDNNConvForward &GetConvFwd(const ConvolutionParam ¶m, void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam ¶m, const OpContext &ctx, - MKLDNNConvForward &fwd, + MKLDNNConvForward *fwd, const std::vector &in_data, const std::vector &req, const std::vector &out_data); diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index f02fe2bf2458..8dbbfd86912f 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -82,7 +82,6 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl( float alpha = 0.0f; // negative slope for mkldnn_eltwise_relu. float beta = 1.0f; // ignored for mkldnn_eltwise_relu. 
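The fusion itself is expressed through MKLDNN convolution post-ops: the calls below append an eltwise (ReLU) step and, when with_sum is set, a sum step to a mkldnn::post_ops chain attached to the convolution primitive descriptor through its primitive_attr. A sketch of the pattern used here (variable names mirror the surrounding code):

    // mkldnn::post_ops ops;
    // if (with_relu) ops.append_eltwise(scale, eltwise_relu, alpha, beta);  // alpha: negative slope, beta ignored
    // if (with_sum)  ops.append_sum(sum_scale);      // accumulates into the destination memory
    // attr.set_post_ops(ops);                        // attr is the primitive_attr passed to the conv pd

Because append_sum accumulates in place, the fused operator passes the sum operand as its output buffer (cached_output_ = inputs[in_sum]) and registers the matching FInplaceOption pair {in_sum, 0}.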
ops.append_eltwise(scale, eltwise_relu, alpha, beta); - } if (param.mkldnn_param.with_sum) { ops.append_sum(param.sum_scale); @@ -284,7 +283,7 @@ MKLDNNConvForward &GetConvFwd(const MKLDNNConvFullParam ¶m, void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam ¶m, const OpContext &ctx, - MKLDNNConvForward &fwd, + MKLDNNConvForward *fwd, const std::vector &in_data, const std::vector &req, const std::vector &out_data) { @@ -292,7 +291,7 @@ void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam ¶m, NDArray weight = in_data[conv::kWeight]; bool no_bias = param.conv_param.no_bias && !param.mkldnn_param.with_bn; auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder( - fwd.fwd_pd.src_primitive_desc()); + fwd->fwd_pd.src_primitive_desc()); const mkldnn::memory *weight_mem; if (ctx.is_train) { // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it @@ -301,20 +300,20 @@ void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam ¶m, // This asks the engine to change the layout of the weight array after // it's used. weight.Reorder2DefaultAsync(); - weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), + weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), param.conv_param.num_group); } else { // For inference, we want to reorder the weight array so we don't need to // reorder data every time. if (weight.IsDefaultData()) { - weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), + weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), param.conv_param.num_group); // We also need to modify the layout on the original weight array. The // data conversion happens after the weight array is used. - weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc()); + weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc()); } else { weight_mem = weight.GetMKLDNNData(); - CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc()); + CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc()); } } mkldnn_output_t out_mem; @@ -322,18 +321,18 @@ void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam ¶m, out_mem = mkldnn_output_t( OutDataOp::Noop, const_cast(out_data[conv::kOut].GetMKLDNNDataReorder( - fwd.fwd_pd.dst_primitive_desc()))); + fwd->fwd_pd.dst_primitive_desc()))); } else { out_mem = CreateMKLDNNMem(out_data[conv::kOut], - fwd.fwd_pd.dst_primitive_desc(), req[conv::kOut]); + fwd->fwd_pd.dst_primitive_desc(), req[conv::kOut]); } const mkldnn::memory *bias_mem = nullptr; if (!no_bias) { bias_mem = in_data[conv::kBias].GetMKLDNNData(); } - fwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); - MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); + fwd->SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); + MKLDNNStream::Get()->RegisterPrim(fwd->GetFwd()); CommitOutput(out_data[conv::kOut], out_mem); MKLDNNStream::Get()->Submit(); @@ -351,7 +350,7 @@ void MKLDNNConvolutionForward(const nnvm::NodeAttrs &attrs, param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight], param.conv_param.no_bias ? 
nullptr : &in_data[conv::kBias], out_data[conv::kOut]); - MKLDNNConvolutionForwardFullFeature(param, ctx, fwd, in_data, req, out_data); + MKLDNNConvolutionForwardFullFeature(param, ctx, &fwd, in_data, req, out_data); } void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index e1a17ad1faa4..b7ea14b63231 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -142,10 +142,9 @@ Graph QuantizeGraph(Graph &&src) { // reused next time when the same entry is visited again. if (ExcludeKey(node, e)) { new_node->inputs.emplace_back(mirror_entry); - } - else if (!NeedQuantize(e.node, excluded_nodes) && - (mirror_node->op() == nullptr || - mirror_node->op()->name != "_contrib_quantize")) { + } else if (!NeedQuantize(e.node, excluded_nodes) && + (mirror_node->op() == nullptr || + mirror_node->op()->name != "_contrib_quantize")) { NodePtr quantize_node = InsertNode("_contrib_quantize", e.node->attrs.name + "_quantize", new_node, mirror_entry); quantize_node->attrs.dict["out_type"] = quantized_dtype; diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index ed575bbb9a4e..a99377b9bf92 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -44,17 +44,17 @@ static const size_t int8_range = 127; enum MKLDNNConvOpOutputs { kOut, kMin, kMax }; template -static void UpdateConvWeightBias(NDArray &weight, NDArray &bias, bool no_bias, +static void UpdateConvWeightBias(NDArray *weight, NDArray *bias, bool no_bias, const NDArray &gamma, const NDArray &beta, const NDArray &mean, const NDArray &variance, const BatchNormParam *param) { // TODO(Zhennan): Handle the case weight is not in dims 4. - NDArray update_weight = NDArray(weight.storage_type(), weight.shape(), - weight.ctx(), true, weight.dtype()); + NDArray update_weight = NDArray(weight->storage_type(), weight->shape(), + weight->ctx(), true, weight->dtype()); NDArray update_bias = NDArray(beta.storage_type(), beta.shape(), beta.ctx(), true, beta.dtype()); - DType *weight_ptr = weight.data().dptr(); - DType *bias_ptr = no_bias ? nullptr : bias.data().dptr(); + DType *weight_ptr = weight->data().dptr(); + DType *bias_ptr = no_bias ? 
nullptr : bias->data().dptr(); DType *gamma_ptr = gamma.Reorder2Default().data().dptr(); DType *beta_ptr = beta.Reorder2Default().data().dptr(); DType *mean_ptr = mean.Reorder2Default().data().dptr(); @@ -62,7 +62,7 @@ static void UpdateConvWeightBias(NDArray &weight, NDArray &bias, bool no_bias, DType *update_weight_ptr = update_weight.data().dptr(); DType *update_bias_ptr = update_bias.data().dptr(); size_t channel = gamma.shape()[0]; - size_t offset = weight.shape()[1] * weight.shape()[2] * weight.shape()[3]; + size_t offset = weight->shape()[1] * weight->shape()[2] * weight->shape()[3]; #pragma omp parallel for for (size_t c = 0; c < channel; ++c) { DType *p1 = reinterpret_cast(weight_ptr + c * offset); @@ -79,8 +79,8 @@ static void UpdateConvWeightBias(NDArray &weight, NDArray &bias, bool no_bias, p2[k] = p1[k] * alpha; } } - weight = update_weight; - bias = update_bias; + *weight = update_weight; + *bias = update_bias; } static inline size_t GetInSumIndex(const MKLDNNConvFusionParam ¶m) { @@ -89,21 +89,21 @@ static inline size_t GetInSumIndex(const MKLDNNConvFusionParam ¶m) { } template -static void QuantizeConvWeightBias(NDArray &weight, NDArray &bias, +static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias, bool has_bias, float data_min, float data_max, bool weight_channelwise_scale, - std::vector &weight_scales) { + std::vector *weight_scales) { using red::limits::MaxValue; using red::limits::MinValue; - DType *weight_ptr = weight.data().dptr(); - NDArray quantized_weight = NDArray(weight.storage_type(), weight.shape(), - weight.ctx(), true, mshadow::kInt8); + DType *weight_ptr = weight->data().dptr(); + NDArray quantized_weight = NDArray(weight->storage_type(), weight->shape(), + weight->ctx(), true, mshadow::kInt8); int8_t *quan_weight_ptr = quantized_weight.data().dptr(); - size_t channel = weight.shape()[0]; + size_t channel = weight->shape()[0]; - //TODO(Zhennan): Handle the case weight is not in dims 4. - size_t offset = weight.shape()[1] * weight.shape()[2] * weight.shape()[3]; + // TODO(Zhennan): Handle the case weight is not in dims 4. 
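The signature churn in this hunk follows the cpplint runtime/references rule, which flags non-const reference parameters; output parameters are passed by pointer instead so that mutation is visible at the call site. The arithmetic is unchanged: NDArray& / std::vector& parameters become NDArray* / std::vector* with the matching -> accesses, and the call sites change accordingly, e.g.

    // UpdateConvWeightBias(&cached_weight_, &cached_bias_, conv_param.no_bias, ...);
    // QuantizeConvWeightBias(&cached_weight_, &cached_bias_, has_bias, ..., &weight_scales);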
+ size_t offset = weight->shape()[1] * weight->shape()[2] * weight->shape()[3]; std::vector weight_c_min(channel, MaxValue()); std::vector weight_c_max(channel, MinValue()); #pragma omp parallel for @@ -118,52 +118,51 @@ static void QuantizeConvWeightBias(NDArray &weight, NDArray &bias, } if (weight_channelwise_scale) { - weight_scales.resize(channel); + weight_scales->resize(channel); #pragma omp parallel for for (size_t c = 0; c < channel; ++c) { DType weight_range = MaxAbs(weight_c_min[c], weight_c_max[c]); - weight_scales[c] = int8_range / weight_range; + weight_scales->at(c) = int8_range / weight_range; DType *fp_ptr = weight_ptr + c * offset; int8_t *quan_ptr = quan_weight_ptr + c * offset; for (size_t k = 0; k < offset; ++k) { - quan_ptr[k] = std::round(weight_scales[c] * fp_ptr[k]); + quan_ptr[k] = std::round(weight_scales->at(c) * fp_ptr[k]); } } - } - else { - DType total_min = weight_c_min[0]; - DType total_max = weight_c_max[0]; - for (size_t c = 0; c < channel; ++c) { - if (total_min > weight_c_min[c]) total_min = weight_c_min[c]; - if (total_max < weight_c_max[c]) total_max = weight_c_max[c]; - } - weight_scales.resize(1); + } else { + DType total_min = weight_c_min[0]; + DType total_max = weight_c_max[0]; + for (size_t c = 0; c < channel; ++c) { + if (total_min > weight_c_min[c]) total_min = weight_c_min[c]; + if (total_max < weight_c_max[c]) total_max = weight_c_max[c]; + } + weight_scales->resize(1); DType weight_range = MaxAbs(total_min, total_max); - weight_scales[0] = int8_range / weight_range; + weight_scales->at(0) = int8_range / weight_range; #pragma omp parallel for for (size_t c = 0; c < channel; ++c) { DType *fp_ptr = weight_ptr + c * offset; int8_t *quan_ptr = quan_weight_ptr + c * offset; for (size_t k = 0; k < offset; ++k) { - quan_ptr[k] = std::round(weight_scales[0] * fp_ptr[k]); + quan_ptr[k] = std::round(weight_scales->at(0) * fp_ptr[k]); } } } - weight = quantized_weight; + *weight = quantized_weight; if (has_bias) { - DType *bias_ptr = bias.data().dptr(); - NDArray quantized_bias = NDArray(bias.storage_type(), bias.shape(), - bias.ctx(), true, mshadow::kInt32); + DType *bias_ptr = bias->data().dptr(); + NDArray quantized_bias = NDArray(bias->storage_type(), bias->shape(), + bias->ctx(), true, mshadow::kInt32); int32_t *quan_bias_ptr = quantized_bias.data().dptr(); DType data_scale = uint8_range / MaxAbs(data_min, data_max); for (size_t c = 0; c < channel; ++c) { auto weight_scale = - weight_channelwise_scale ? weight_scales[c] : weight_scales[0]; + weight_channelwise_scale ? 
weight_scales->at(c) : weight_scales->at(0); float bias_scale = weight_scale * data_scale; quan_bias_ptr[c] = std::round(bias_scale * bias_ptr[c]); } - bias = quantized_bias; + *bias = quantized_bias; } } @@ -173,7 +172,7 @@ static void ConvFusionFallBackCompute() { static void ConvolutionFusionComputeExCPU(const MKLDNNConvFullParam &full_param, const OpContext &ctx, - MKLDNNConvForward &fwd, + MKLDNNConvForward *fwd, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { @@ -320,17 +319,19 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, CHECK_EQ(inputs[in_weight].dtype(), inputs[in_beta].dtype()); CHECK_EQ(inputs[in_weight].dtype(), inputs[in_var].dtype()); MSHADOW_REAL_TYPE_SWITCH(inputs[in_weight].dtype(), DType, { - UpdateConvWeightBias( - cached_weight_, cached_bias_, conv_param.no_bias, inputs[in_gamma], - inputs[in_beta], inputs[in_mean], inputs[in_var], bn_param); + UpdateConvWeightBias(&cached_weight_, &cached_bias_, + conv_param.no_bias, inputs[in_gamma], + inputs[in_beta], inputs[in_mean], + inputs[in_var], bn_param); }); } // Quantize weight and bias. if (mkldnn_param.quantized) { MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, { - QuantizeConvWeightBias( - cached_weight_, cached_bias_, has_bias, data_min, data_max, - mkldnn_param.weight_channelwise_scale, weight_scales); + QuantizeConvWeightBias(&cached_weight_, &cached_bias_, + has_bias, data_min, data_max, + mkldnn_param.weight_channelwise_scale, + &weight_scales); }); // Collect scale. size_t channel = cached_weight_.shape()[0]; @@ -375,8 +376,8 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, new_inputs = {cached_data_, cached_weight_}; new_req = {req[in_data], req[in_weight]}; } - ConvolutionFusionComputeExCPU(full_conv_param, ctx, *fwd, new_inputs, new_req, - {cached_output_}); + ConvolutionFusionComputeExCPU(full_conv_param, ctx, fwd.get(), new_inputs, + new_req, {cached_output_}); if (mkldnn_param.with_sum) { auto out = const_cast(outputs[kOut]); @@ -469,22 +470,22 @@ template static void FilterMinMaxIndice(const MKLDNNConvParam &mkldnn_param, std::vector *in_shapes, std::vector *out_shapes, - std::vector &base_in_shapes, - std::vector &base_out_shapes, - std::unordered_set &minmax_indice) { - base_out_shapes.push_back(out_shapes->at(0)); + std::vector *base_in_shapes, + std::vector *base_out_shapes, + std::unordered_set *minmax_indice) { + base_out_shapes->push_back(out_shapes->at(0)); size_t last = in_shapes->size() - 1; if (mkldnn_param.with_sum) { - minmax_indice.insert(last); - minmax_indice.insert(last - 1); - minmax_indice.insert(last - 2); - minmax_indice.insert(last - 3); - base_in_shapes = + minmax_indice->insert(last); + minmax_indice->insert(last - 1); + minmax_indice->insert(last - 2); + minmax_indice->insert(last - 3); + *base_in_shapes = std::vector(in_shapes->begin(), in_shapes->end() - 4); } else { - minmax_indice.insert(last); - minmax_indice.insert(last - 1); - base_in_shapes = + minmax_indice->insert(last); + minmax_indice->insert(last - 1); + *base_in_shapes = std::vector(in_shapes->begin(), in_shapes->end() - 2); } } @@ -499,8 +500,8 @@ static bool SgMKLDNNConvInferShape(const nnvm::NodeAttrs &attrs, std::vector base_out_shapes; FilterMinMaxIndice(param.full_conv_param.mkldnn_param, in_shapes, - out_shapes, base_in_shapes, base_out_shapes, - minmax_indice); + out_shapes, &base_in_shapes, &base_out_shapes, + &minmax_indice); bool result = DefaultSubgraphOpShape(attrs, &base_in_shapes, &base_out_shapes); size_t base_idx = 0; @@ -529,8 +530,8 
@@ static bool SgMKLDNNConvInferType(const nnvm::NodeAttrs &attrs, std::vector base_in_types; std::vector base_out_types; FilterMinMaxIndice(param.full_conv_param.mkldnn_param, in_types, - out_types, base_in_types, base_out_types, - minmax_indice); + out_types, &base_in_types, &base_out_types, + &minmax_indice); // Override data type to fp32 for default infer type as bn doesn't support // uint8. int orig_data = base_in_types[0]; @@ -579,8 +580,8 @@ static bool SgMKLDNNConvOpStorageType(const nnvm::NodeAttrs &attrs, std::vector base_in_stypes; std::vector base_out_stypes; FilterMinMaxIndice(param.full_conv_param.mkldnn_param, in_stypes, - out_stypes, base_in_stypes, base_out_stypes, - minmax_indice); + out_stypes, &base_in_stypes, &base_out_stypes, + &minmax_indice); bool result = DefaultSubgraphOpStorageType( attrs, dev_mask, dispatch_mode, &base_in_stypes, &base_out_stypes); size_t base_idx = 0; @@ -650,4 +651,4 @@ NNVM_REGISTER_OP(_sg_mkldnn_conv) .set_attr("FQuantizedOp", SgMKLDNNConvQuantizedOp); } // namespace op } // namespace mxnet -#endif // if MXNET_USE_MKLDNN == 1 +#endif // if MXNET_USE_MKLDNN == 1 diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc index 2a75a573a52b..e0d8c5b6a87b 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc @@ -57,7 +57,7 @@ class SgMKLDNNConvSelector : public SubgraphSelector { disable_conv_relu(dis_conv_relu), disable_conv_sum(dis_conv_sum) {} - virtual bool Select(const nnvm::Node &n) override { + bool Select(const nnvm::Node &n) override { bool match = (!disable_all) && (!n.is_variable()) && (n.op()->name == "Convolution"); if (match) { @@ -69,13 +69,11 @@ class SgMKLDNNConvSelector : public SubgraphSelector { return false; } - virtual bool SelectInput(const nnvm::Node &n, - const nnvm::Node &new_node) override { + bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override { return false; } - virtual bool SelectOutput(const nnvm::Node &n, - const nnvm::Node &new_node) override { + bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override { if (status == sFail || status == sSuccess || new_node.is_variable()) return false; // If n isn't the last matched node, then we encoutered a internal @@ -122,7 +120,7 @@ class SgMKLDNNConvSelector : public SubgraphSelector { } } - virtual std::vector Filter( + std::vector Filter( const std::vector &candidates) override { if (status == sFail) { return std::vector(0); @@ -190,26 +188,26 @@ class SgMKLDNNConvProperty : public SubgraphProperty { return n; } - virtual SubgraphSelectorPtr CreateSubgraphSelector() const override { + SubgraphSelectorPtr CreateSubgraphSelector() const override { auto selector = std::make_shared( disable_all, disable_conv_bn, disable_conv_relu, disable_conv_sum); return selector; } - virtual void ConnectSubgraphOutput( + void ConnectSubgraphOutput( const nnvm::NodePtr n, - std::vector &output_entries) const override { + std::vector *output_entries) const override { // Connect all extern output entries to output[0] - for (size_t i = 0; i < output_entries.size(); ++i) { - *output_entries[i] = nnvm::NodeEntry{n, 0, 0}; + for (size_t i = 0; i < output_entries->size(); ++i) { + *output_entries->at(i) = nnvm::NodeEntry{n, 0, 0}; } } - virtual void ConnectSubgraphInput( - const nnvm::NodePtr n, std::vector &input_entries, - std::vector &orig_input_entries) const override { + void ConnectSubgraphInput( + const nnvm::NodePtr 
n, std::vector *input_entries, + std::vector *orig_input_entries) const override { auto sym = n->attrs.subgraphs[0]; - std::unordered_set node_sets; + std::unordered_set node_sets; DFSVisit(sym->outputs, [&](const nnvm::NodePtr &node) { if (node->is_variable()) return; node_sets.insert(node.get()); @@ -221,14 +219,15 @@ class SgMKLDNNConvProperty : public SubgraphProperty { auto tmp = node->inputs[1]; node->inputs[1] = node->inputs[0]; node->inputs[0] = tmp; - std::rotate(input_entries.begin(), input_entries.begin() + 1, - input_entries.end()); - std::rotate(orig_input_entries.begin(), - orig_input_entries.begin() + 1, orig_input_entries.end()); + std::rotate(input_entries->begin(), input_entries->begin() + 1, + input_entries->end()); + std::rotate(orig_input_entries->begin(), + orig_input_entries->begin() + 1, + orig_input_entries->end()); } } }); - n->inputs = orig_input_entries; + n->inputs = *orig_input_entries; } private: @@ -240,7 +239,7 @@ class SgMKLDNNConvProperty : public SubgraphProperty { MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty); -} // namespace op -} // namespace mxnet +} // namespace op +} // namespace mxnet #endif // if MXNET_USE_MKLDNN == 1 #endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_CONV_H_ diff --git a/src/operator/subgraph/partition_graph.cc b/src/operator/subgraph/partition_graph.cc index d7adde34a7db..c789d25a1525 100644 --- a/src/operator/subgraph/partition_graph.cc +++ b/src/operator/subgraph/partition_graph.cc @@ -653,8 +653,8 @@ void CreateSubgraphNode(Graph* g, nnvm::NodePtr n = subg_prop->CreateSubgraphNode(sym, subgraph_id); // Connect the external nodes to the subgraph node. - subg_prop->ConnectSubgraphOutput(n, output_entries); - subg_prop->ConnectSubgraphInput(n, input_entries, orig_input_entries); + subg_prop->ConnectSubgraphOutput(n, &output_entries); + subg_prop->ConnectSubgraphInput(n, &input_entries, &orig_input_entries); n->inputs = orig_input_entries; const auto& indexed_graph = g->indexed_graph(); diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h index e45a8d99c76e..c5a04afb63a6 100644 --- a/src/operator/subgraph/subgraph_property.h +++ b/src/operator/subgraph/subgraph_property.h @@ -96,17 +96,17 @@ class SubgraphProperty { // each output entry will connect to an unique internal output. virtual void ConnectSubgraphOutput( const nnvm::NodePtr n, - std::vector &output_entries) const { - for (size_t i = 0; i < output_entries.size(); ++i) { - *output_entries[i] = nnvm::NodeEntry{n, static_cast(i), 0}; + std::vector *output_entries) const { + for (size_t i = 0; i < output_entries->size(); ++i) { + *output_entries->at(i) = nnvm::NodeEntry{n, static_cast(i), 0}; } } // Connect subgraph internal input with external input entries. By default, // each input entry will connect in top sorted order. 
virtual void ConnectSubgraphInput( - const nnvm::NodePtr n, std::vector &input_entries, - std::vector &orig_input_entries) const { - n->inputs = orig_input_entries; + const nnvm::NodePtr n, std::vector *input_entries, + std::vector *orig_input_entries) const { + n->inputs = *orig_input_entries; } // set an attr with name in the attr map template @@ -121,6 +121,7 @@ class SubgraphProperty { CHECK(it != attrs_.end()) << "Cannot find attribute " << name << " in SubgraphProperty"; return nnvm::get(*it->second); } + protected: std::unordered_map> attrs_; }; From 34af2292ffe00a5a3e1e419a33e4f465396e2060 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Thu, 13 Sep 2018 16:45:12 +0800 Subject: [PATCH 03/28] Fix performance regression caused by mkldnn fallback. --- src/executor/attach_op_execs_pass.cc | 22 +++++++++++++------ src/operator/quantization/dequantize.cc | 1 + .../mkldnn/mkldnn_quantized_pooling.cc | 1 + src/operator/quantization/quantize.cc | 1 + src/operator/quantization/requantize.cc | 1 + src/operator/subgraph/mkldnn/mkldnn_conv.cc | 1 + 6 files changed, 20 insertions(+), 7 deletions(-) diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index 0e415ef5112a..a0176fab0a04 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -159,9 +159,13 @@ class StatefulComputeExExecutor : public OpExecutor { op_ctx.run_ctx = rctx; #if MXNET_USE_MKLDNN == 1 InvalidateOutputs(out_array, req); - CreateDefaultInputs(in_array, &in_array_fallback); - fcompute_(state_, op_ctx, in_array_fallback, req, out_array); - return; + // TODO(alex): (MXNET-847) Remove this fallback feature after subgraph implemented + const auto is_mkldnn = Op::GetAttr("TIsMKLDNN"); + if (!is_mkldnn.get(attrs_.op, false)) { + CreateDefaultInputs(in_array, &in_array_fallback); + fcompute_(state_, op_ctx, in_array_fallback, req, out_array); + return; + } #endif fcompute_(state_, op_ctx, in_array, req, out_array); } @@ -180,12 +184,14 @@ class StatefulComputeExExecutor : public OpExecutor { return state_; } - explicit StatefulComputeExExecutor(const OpStatePtr& state, + explicit StatefulComputeExExecutor(const NodeAttrs& attrs, + const OpStatePtr& state, const FStatefulComputeEx& fcompute, ExecType exec_type) - : state_(state), fcompute_(fcompute), exec_type_(exec_type) {} + : attrs_(attrs), state_(state), fcompute_(fcompute), exec_type_(exec_type) {} private: + NodeAttrs attrs_; OpStatePtr state_; FStatefulComputeEx fcompute_; ExecType exec_type_; @@ -302,7 +308,8 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) { op, "FStatefulComputeEx", vctx[i]); // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { - ret[i] = std::make_shared(state, fcompute_ex, exec_type); + ret[i] = std::make_shared(inode.source->attrs, state, + fcompute_ex, exec_type); } else { FStatefulCompute fcompute = common::GetFCompute( op, "FStatefulCompute", vctx[i]); @@ -322,7 +329,8 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) { // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { ret[i] = std::make_shared( - ret[fwd_id].get()->state(), fcompute_ex, exec_type); + inode.source->attrs, ret[fwd_id].get()->state(), fcompute_ex, + exec_type); } else { FStatefulCompute fcompute = common::GetFCompute( op, "FStatefulCompute", 
vctx[i]); diff --git a/src/operator/quantization/dequantize.cc b/src/operator/quantization/dequantize.cc index bbd79417676b..e20bc1722213 100644 --- a/src/operator/quantization/dequantize.cc +++ b/src/operator/quantization/dequantize.cc @@ -72,6 +72,7 @@ by keep zero centered for the quantized value: .set_attr("FInferType", DequantizeType) .set_attr("FInferStorageType", DequantizeStorageType) #if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", MKLDNNDequantizeCompute) #endif .set_attr("FCompute", DequantizeCompute) diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc index b81881af96dc..07e14412618d 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc @@ -46,6 +46,7 @@ static void MKLDNNQuantizedPoolingForward(const nnvm::NodeAttrs& attrs, const Op } NNVM_REGISTER_OP(_contrib_quantized_pooling) +.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", MKLDNNQuantizedPoolingForward); } // namespace op diff --git a/src/operator/quantization/quantize.cc b/src/operator/quantization/quantize.cc index 25fb19dddd12..5227751bc635 100644 --- a/src/operator/quantization/quantize.cc +++ b/src/operator/quantization/quantize.cc @@ -83,6 +83,7 @@ where .set_attr("FInferType", QuantizeType) .set_attr("FInferStorageType", QuantizeStorageType) #if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", MKLDNNQuantizeCompute) #endif .set_attr("FCompute", QuantizeCompute) diff --git a/src/operator/quantization/requantize.cc b/src/operator/quantization/requantize.cc index 5ce0ff0b0209..68b1b65e4e7b 100644 --- a/src/operator/quantization/requantize.cc +++ b/src/operator/quantization/requantize.cc @@ -65,6 +65,7 @@ inference accuracy. .set_attr("FInferType", RequantizeType) .set_attr("FInferStorageType", RequantizeStorageType) #if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", MKLDNNRequantizeForward) #else .set_attr("FCompute", RequantizeForward) diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index a99377b9bf92..c76ff6a64858 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -640,6 +640,7 @@ NNVM_REGISTER_OP(_sg_mkldnn_conv) .set_attr("FInferShape", SgMKLDNNConvInferShape) .set_attr("FInferType", SgMKLDNNConvInferType) .set_attr("FInferStorageType", SgMKLDNNConvOpStorageType) +.set_attr("TIsMKLDNN", true) .set_attr("FStatefulComputeEx", SgMKLDNNConvOpForward) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; From affb56435a26b033a51332d8cfab3055541c97ab Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Sat, 15 Sep 2018 17:52:16 +0800 Subject: [PATCH 04/28] clean up include --- src/operator/subgraph/default_subgraph_property.cc | 2 -- src/operator/subgraph/mkldnn/mkldnn_conv.cc | 4 ---- src/operator/subgraph/mkldnn/mkldnn_conv_property.cc | 4 ---- 3 files changed, 10 deletions(-) diff --git a/src/operator/subgraph/default_subgraph_property.cc b/src/operator/subgraph/default_subgraph_property.cc index c8d3e9ffd438..3bcee715691f 100644 --- a/src/operator/subgraph/default_subgraph_property.cc +++ b/src/operator/subgraph/default_subgraph_property.cc @@ -17,8 +17,6 @@ * under the License. 
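The fallback fix above keys off a per-operator "TIsMKLDNN" attribute: the stateful executor reorders inputs back to the default layout (the CreateDefaultInputs path) only for operators that did not register the flag, while MKL-DNN aware operators such as _sg_mkldnn_conv keep their MKL-DNN layouts. The standalone sketch below models that gate outside of MXNet; the registry type, the function names, and the use_default_layout flag are illustrative assumptions, not code from this patch.

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

// Toy stand-in for the per-op attribute map queried as "TIsMKLDNN".
using IsMKLDNNMap = std::unordered_map<std::string, bool>;

void RunStatefulComputeEx(const IsMKLDNNMap &is_mkldnn, const std::string &op,
                          const std::function<void(bool)> &fcompute) {
  auto it = is_mkldnn.find(op);
  const bool mkldnn_aware = it != is_mkldnn.end() && it->second;
  // Ops that cannot handle MKL-DNN layouts get their inputs converted to the
  // default layout first; MKL-DNN aware ops skip that extra copy and reorder.
  fcompute(/*use_default_layout=*/!mkldnn_aware);
}

int main() {
  IsMKLDNNMap reg{{"_sg_mkldnn_conv", true}, {"_contrib_quantized_pooling", true}};
  RunStatefulComputeEx(reg, "_sg_mkldnn_conv",
                       [](bool use_default) { std::cout << use_default << "\n"; });  // prints 0
  RunStatefulComputeEx(reg, "cpu_only_op",
                       [](bool use_default) { std::cout << use_default << "\n"; });  // prints 1
}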
*/ -#include -#include #include "./common.h" #include "./subgraph_property.h" diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index c76ff6a64858..d874a31b43de 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -18,11 +18,7 @@ */ #if MXNET_USE_MKLDNN == 1 -#include -#include #include "../common.h" -#include "../../../imperative/imperative_utils.h" -#include "../../../imperative/cached_op.h" #include "../../nn/convolution-inl.h" #include "../../nn/batch_norm-inl.h" #include "../../nn/activation-inl.h" diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc index e0d8c5b6a87b..8459d68628d9 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc @@ -21,13 +21,9 @@ #define MXNET_OPERATOR_SUBGRAPH_MKLDNN_CONV_H_ #if MXNET_USE_MKLDNN == 1 -#include -#include #include "../common.h" #include "../subgraph_property.h" -#include "../../nn/convolution-inl.h" #include "../../nn/activation-inl.h" -#include "../../nn/mkldnn/mkldnn_convolution-inl.h" namespace mxnet { namespace op { From 5729b96886ce9d26d3fafdf15f2dccd8ea1c0a26 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Mon, 17 Sep 2018 12:36:42 +0800 Subject: [PATCH 05/28] Fix msbuild on openmp pragma. --- src/operator/subgraph/mkldnn/mkldnn_conv.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index d874a31b43de..1a907fb53190 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -60,7 +60,7 @@ static void UpdateConvWeightBias(NDArray *weight, NDArray *bias, bool no_bias, size_t channel = gamma.shape()[0]; size_t offset = weight->shape()[1] * weight->shape()[2] * weight->shape()[3]; #pragma omp parallel for - for (size_t c = 0; c < channel; ++c) { + for (int c = 0; c < static_cast(channel); ++c) { DType *p1 = reinterpret_cast(weight_ptr + c * offset); DType *p2 = reinterpret_cast(update_weight_ptr + c * offset); DType alpha = (param->fix_gamma ? 
static_cast(1.0f) : gamma_ptr[c]) / @@ -103,7 +103,7 @@ static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias, std::vector weight_c_min(channel, MaxValue()); std::vector weight_c_max(channel, MinValue()); #pragma omp parallel for - for (size_t c = 0; c < channel; ++c) { + for (int c = 0; c < static_cast(channel); ++c) { DType *p1 = weight_ptr + c * offset; for (size_t k = 0; k < offset; ++k) { if (weight_c_min[c] > p1[k]) @@ -116,7 +116,7 @@ static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias, if (weight_channelwise_scale) { weight_scales->resize(channel); #pragma omp parallel for - for (size_t c = 0; c < channel; ++c) { + for (int c = 0; c < static_cast(channel); ++c) { DType weight_range = MaxAbs(weight_c_min[c], weight_c_max[c]); weight_scales->at(c) = int8_range / weight_range; DType *fp_ptr = weight_ptr + c * offset; @@ -136,7 +136,7 @@ static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias, DType weight_range = MaxAbs(total_min, total_max); weight_scales->at(0) = int8_range / weight_range; #pragma omp parallel for - for (size_t c = 0; c < channel; ++c) { + for (int c = 0; c < static_cast(channel); ++c) { DType *fp_ptr = weight_ptr + c * offset; int8_t *quan_ptr = quan_weight_ptr + c * offset; for (size_t k = 0; k < offset; ++k) { From 81e8e28ab5fd9809a6e18c0b15f403a1bc19d5b4 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Mon, 17 Sep 2018 14:30:55 +0800 Subject: [PATCH 06/28] Fix quantization test, allow to use original op names as exclude layer for quantization. --- .../quantization/imagenet_gen_qsym_mkldnn.py | 2 +- include/mxnet/c_api.h | 4 +-- python/mxnet/contrib/quantization.py | 26 ++++++---------- src/c_api/c_api_symbolic.cc | 15 ++++----- .../quantization/quantize_graph_pass.cc | 31 ++++++++++++++++--- .../python/quantization/test_quantization.py | 10 +++--- 6 files changed, 51 insertions(+), 37 deletions(-) diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py index 55c57c005e67..d959ee23fbc6 100644 --- a/example/quantization/imagenet_gen_qsym_mkldnn.py +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -202,7 +202,7 @@ def save_params(fname, arg_params, aux_params, logger=None): num_calib_examples=num_calib_batches * batch_size, calib_layer=calib_layer, quantized_dtype=args.quantized_dtype, disable_requantize=args.disable_requantize, - label_names=(label_name,), + label_names=(label_name,), calib_quantize_op = True, logger=logger) if calib_mode == 'entropy': suffix = '-quantized-%dbatches-entropy' % num_calib_batches diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 4d920a4af2bc..cf1aab14d40b 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1506,7 +1506,7 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, * \param sym_handle symbol to be converted * \param ret_sym_handle quantized symbol result * \param num_excluded_symbols number of layers excluded from being quantized in the input symbol -* \param excluded_symbols array of symbols to be excluded from being quantized +* \param excluded_symbols op names to be excluded from being quantized * \param num_offline number of parameters that are quantized offline * \param offline_params array of c strings representing the names of params quantized offline * \param quantized_dtype the quantized destination type for input data. 
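The OpenMP loops adjusted in the previous commit sit inside QuantizeConvWeightBias, which derives one int8 scale per output channel from the absolute weight range and folds the input data scale into an int32 bias. The standalone restatement below follows that arithmetic; the struct and function names and the 127/255 range constants are assumptions for illustration (a non-zero weight range per channel is also assumed), not code from this patch.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

struct QuantizedConvParams {
  std::vector<int8_t> weight;       // int8 weights, one scale per output channel
  std::vector<int32_t> bias;        // bias scaled by data_scale * weight_scale[c]
  std::vector<float> weight_scale;  // per-channel scales kept for requantization
};

QuantizedConvParams QuantizeChannelwise(const std::vector<float> &weight,
                                        const std::vector<float> &bias,
                                        size_t channels, float data_min, float data_max) {
  const float int8_range = 127.f, uint8_range = 255.f;  // assumed quantized ranges
  const size_t per_channel = weight.size() / channels;
  const float data_scale = uint8_range / std::max(std::abs(data_min), std::abs(data_max));
  QuantizedConvParams out;
  out.weight.resize(weight.size());
  out.bias.resize(bias.size());
  out.weight_scale.resize(channels);
  for (size_t c = 0; c < channels; ++c) {
    float range = 0.f;  // max |w| over this output channel
    for (size_t k = 0; k < per_channel; ++k)
      range = std::max(range, std::abs(weight[c * per_channel + k]));
    out.weight_scale[c] = int8_range / range;
    for (size_t k = 0; k < per_channel; ++k)
      out.weight[c * per_channel + k] = static_cast<int8_t>(
          std::round(out.weight_scale[c] * weight[c * per_channel + k]));
    if (!bias.empty())
      out.bias[c] = static_cast<int32_t>(
          std::round(data_scale * out.weight_scale[c] * bias[c]));
  }
  return out;
}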
@@ -1515,7 +1515,7 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, */ int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, const mx_uint num_excluded_symbols, - const SymbolHandle *excluded_symbols, + const char **excluded_symbols, const mx_uint num_offline, const char **offline_params, const char *quantized_dtype, const bool disable_requantize, const bool calib_quantize); diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 8bfa2586919a..8548e79f284a 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -82,7 +82,7 @@ def _quantize_params(qsym, params, th_dict): def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, quantized_dtype='int8', disable_requantize=False, - calib_quantize_op=True): + calib_quantize_op=False): """Given a symbol object representing a neural network of data type FP32, quantize it into a INT8 network. @@ -90,8 +90,9 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, ---------- sym : Symbol FP32 neural network symbol. - excluded_symbols : list of symbols - Nodes in the network that users do not want to replace with a symbol of INT8 data type. + excluded_sym_names : list of strings + A list of strings representing the names of the symbols that users want to excluding + from being quantized. offline_params : list of strs Names of the parameters that users want to quantize offline. It's always recommended to quantize parameters offline so that quantizing parameters during the inference can be @@ -104,12 +105,11 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, Whether perform offline calibration for quantize op. """ num_excluded_symbols = 0 - excluded_handles = [] if excluded_symbols is not None: assert isinstance(excluded_symbols, list) num_excluded_symbols = len(excluded_symbols) - for s in excluded_symbols: - excluded_handles.append(s.handle) + else: + excluded_symbols = [] num_offline = 0 offline = [] @@ -122,7 +122,7 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, check_call(_LIB.MXQuantizeSymbol(sym.handle, ctypes.byref(out), mx_uint(num_excluded_symbols), - c_array(SymbolHandle, excluded_handles), + c_str_array(excluded_symbols), mx_uint(num_offline), c_array(ctypes.c_char_p, offline), c_str(quantized_dtype), @@ -442,7 +442,7 @@ def quantize_model(sym, arg_params, aux_params, ctx=cpu(), excluded_sym_names=None, calib_mode='entropy', calib_data=None, num_calib_examples=None, calib_layer=None, quantized_dtype='int8', disable_requantize=False, - calib_quantize_op=True, logger=logging): + calib_quantize_op=False, logger=logging): """User-level API for generating a quantized model from a FP32 model w/ or w/o calibration. The backend quantized operators are only enabled for Linux systems. Please do not run inference using the quantized models on Windows for now. 
@@ -516,18 +516,12 @@ def quantize_model(sym, arg_params, aux_params, raise ValueError('excluded_sym_names must be a list of strings representing' ' the names of the symbols that will not be quantized,' ' while received type %s' % str(type(excluded_sym_names))) - excluded_syms = [] - if excluded_sym_names is not None: - for sym_name in excluded_sym_names: - nodes = sym.get_internals() - idx = nodes.list_outputs().index(sym_name + '_output') - excluded_syms.append(nodes[idx]) - logger.info('Quantizing symbol') + logger.info('Quantizing symbol') if quantized_dtype not in ('int8', 'uint8'): raise ValueError('unknown quantized_dtype %s received,' ' expected `int8` or `uint8`' % quantized_dtype) - qsym = _quantize_symbol(sym, excluded_symbols=excluded_syms, + qsym = _quantize_symbol(sym, excluded_symbols=excluded_sym_names, offline_params=list(arg_params.keys()), quantized_dtype=quantized_dtype, disable_requantize=disable_requantize, diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 0e689c08c3ff..ac421a911c8c 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -646,8 +646,8 @@ int MXSymbolGrad(SymbolHandle sym, mx_uint num_wrt, const char** wrt, SymbolHand int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, - const mx_uint num_excluded_symbols, - const SymbolHandle *excluded_symbols, + const mx_uint num_excluded_op_names, + const char **excluded_op_names, const mx_uint num_offline, const char **offline_params, const char *quantized_dtype, @@ -657,19 +657,16 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, API_BEGIN(); nnvm::Symbol *sym = static_cast(sym_handle); nnvm::Graph g = Symbol2Graph(*sym); - std::unordered_set excluded_nodes; - for (size_t i = 0; i < num_excluded_symbols; ++i) { - nnvm::Symbol* sym = static_cast(excluded_symbols[i]); - for (const auto& e : sym->outputs) { - excluded_nodes.emplace(e.node); - } + std::unordered_set excluded_node_names; + for (size_t i = 0; i < num_excluded_op_names; ++i) { + excluded_node_names.emplace(excluded_op_names[i]); } - g.attrs["excluded_nodes"] = std::make_shared(std::move(excluded_nodes)); std::unordered_set offline; for (size_t i = 0; i < num_offline; ++i) { offline.emplace(offline_params[i]); } std::string quantized_type(quantized_dtype); + g.attrs["excluded_nodes"] = std::make_shared(std::move(excluded_node_names)); g.attrs["offline_params"] = std::make_shared(std::move(offline)); g.attrs["quantized_dtype"] = std::make_shared(std::move(quantized_type)); g.attrs["calib_quantize"] = std::make_shared(calib_quantize); diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index b7ea14b63231..325d2e8c7781 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -89,9 +89,32 @@ std::vector OfflineParams(std::vector&& outputs, return outputs; } -inline bool NeedQuantize(NodePtr node, const std::unordered_set excluded_nodes) { - static auto& quantized_op_map = Op::GetAttr("FQuantizedOp"); - return quantized_op_map.count(node->op()) && !excluded_nodes.count(node); +inline bool NeedQuantize(NodePtr node, + const std::unordered_set excluded_nodes) { + static auto& quantized_op_map = + Op::GetAttr("FQuantizedOp"); + if (quantized_op_map.count(node->op())) { + bool excluded = false; + if (node->attrs.subgraphs.size()) { + // This is a subgraph node, try to match subgraph name first, + // and then try to match inner node. 
+ if (excluded_nodes.count(node->attrs.name)) { + excluded = true; + } else { + auto subgraph_sym = node->attrs.subgraphs[0]; + DFSVisit(subgraph_sym->outputs, [&](const nnvm::NodePtr& node) { + if (node->is_variable()) return; + if (excluded_nodes.count(node->attrs.name)) { + excluded = true; + } + }); + } + } else { + excluded = excluded_nodes.count(node->attrs.name); + } + return !excluded; + } + return false; } inline bool ExcludeKey(NodePtr node, NodeEntry e) { @@ -112,7 +135,7 @@ Graph QuantizeGraph(Graph &&src) { static auto& quantized_op_map = Op::GetAttr("FQuantizedOp"); static auto& need_requantize_map = Op::GetAttr("FNeedRequantize"); auto offline_params = src.GetAttr>("offline_params"); - auto excluded_nodes = src.GetAttr>("excluded_nodes"); + auto excluded_nodes = src.GetAttr>("excluded_nodes"); auto quantized_dtype = src.GetAttr("quantized_dtype"); auto calib_quantize = src.GetAttr("calib_quantize"); diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index 369a923c1879..5ae2c6c398e9 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -374,7 +374,7 @@ def test_quantize_params(): for name in offline_params: params[name] = mx.nd.uniform(shape=(2, 2)) qsym = mx.contrib.quant._quantize_symbol(sym, offline_params=offline_params) - qparams = mx.contrib.quant._quantize_params(qsym, params) + qparams = mx.contrib.quant._quantize_params(qsym, params, th_dict = {}) param_names = params.keys() qparam_names = qparams.keys() for name in qparam_names: @@ -406,7 +406,7 @@ def get_fp32_residual(): fc = mx.sym.FullyConnected(pool, num_hidden=10, flatten=True, name='fc') sym = mx.sym.SoftmaxOutput(fc, grad_scale=1, ignore_label=-1, multi_output=False, out_grad=False, preserve_shape=False, use_ignore=False, name='softmax') - return sym + return sym @with_seed() def test_quantize_model(): @@ -418,7 +418,7 @@ def check_params(params, qparams, qsym=None): assert k in qparams assert same(v.asnumpy(), qparams[k].asnumpy()) else: - qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params) + qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params, th_dict = {}) assert len(qparams) == len(qparams_ground_truth) for k, v in qparams_ground_truth.items(): assert k in qparams @@ -494,7 +494,7 @@ def check_params(params, qparams, qsym=None): assert k in qparams assert same(v.asnumpy(), qparams[k].asnumpy()) else: - qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params) + qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params, th_dict = {}) assert len(qparams) == len(qparams_ground_truth) for k, v in qparams_ground_truth.items(): assert k in qparams @@ -525,7 +525,7 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape): mod.forward(batch, is_train=False) for output in mod.get_outputs(): output.wait_to_read() - + sym = get_fp32_residual() mod = Module(symbol=sym) From a245bc308ad781612ed076154b7b647b1d476f5e Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Mon, 17 Sep 2018 16:05:03 +0800 Subject: [PATCH 07/28] Fix unittest. 
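The exclusion logic introduced in the previous commit matches a fused node's own name first and then the names of the operators wrapped inside its subgraph, so layer names from the original FP32 model keep working as exclusions after fusion. A standalone sketch of that lookup follows; the ToyNode type is a plain STL stand-in for an nnvm node, not code from this patch.

#include <string>
#include <unordered_set>
#include <vector>

// Minimal stand-in for a graph node that may wrap a fused subgraph.
struct ToyNode {
  std::string name;                      // e.g. "sg_mkldnn_conv_bn_relu_0"
  std::vector<std::string> inner_names;  // e.g. {"conv0", "bn0", "relu0"} after fusion
};

// True if the node itself, or any node fused into it, was excluded by name.
bool IsExcluded(const ToyNode &node,
                const std::unordered_set<std::string> &excluded) {
  if (excluded.count(node.name)) return true;
  for (const auto &inner : node.inner_names)
    if (excluded.count(inner)) return true;
  return false;
}

With this lookup an entry such as "conv0" in excluded_sym_names still excludes the sg_mkldnn_conv node that conv0 was fused into, which is why the example script above can switch back to the original op names.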
--- src/operator/nn/mkldnn/mkldnn_convolution.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index 0a70de60df06..75e367ed7e07 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -39,7 +39,7 @@ DMLC_REGISTER_PARAMETER(MKLDNNConvParam); bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) { if (params.kernel.ndim() != 2) return false; - return input.shape().ndim() == 4; + return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4; } inline static mkldnn::memory::desc GetInDataMemDesc(const NDArray &arr) { From ff2a3c22b043aacbc76d9a7cce9374d738a4b8f9 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 18 Sep 2018 21:31:52 +0800 Subject: [PATCH 08/28] Fix unittest --- .../quantization/imagenet_gen_qsym_mkldnn.py | 15 +-- include/mxnet/op_attr_types.h | 8 ++ src/ndarray/ndarray.cc | 6 +- src/operator/nn/mkldnn/mkldnn_base-inl.h | 5 + .../nn/mkldnn/mkldnn_convolution-inl.h | 2 +- src/operator/nn/mkldnn/mkldnn_convolution.cc | 2 +- .../mkldnn/mkldnn_quantized_conv.cc | 92 +++++++++++++++++++ .../quantization/quantize_graph_pass.cc | 24 +++-- src/operator/subgraph/mkldnn/mkldnn_conv.cc | 21 ++++- 9 files changed, 157 insertions(+), 18 deletions(-) create mode 100644 src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py index d959ee23fbc6..60611d6d4458 100644 --- a/example/quantization/imagenet_gen_qsym_mkldnn.py +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -54,7 +54,7 @@ def save_params(fname, arg_params, aux_params, logger=None): if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Generate a calibrated quantized model from a FP32 model with mkldnn support') + parser = argparse.ArgumentParser(description='Generate a calibrated quantized model from a FP32 model with MKL-DNN support') parser.add_argument('--model', type=str, choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'], help='currently only supports imagenet1k-resnet-152 or imagenet1k-inception-bn') parser.add_argument('--batch-size', type=int, default=32) @@ -68,8 +68,7 @@ def save_params(fname, arg_params, aux_params, logger=None): help='number of batches for calibration') parser.add_argument('--exclude-first-conv', action='store_true', default=True, help='excluding quantizing the first conv layer since the' - ' number of channels is usually not a multiple of 4 in that layer' - ' which does not satisfy the requirement of cuDNN') + ' input data may have negative value which doesn\'t support at moment' ) parser.add_argument('--shuffle-dataset', action='store_true', default=True, help='shuffle the calibration dataset') parser.add_argument('--shuffle-chunk-seed', type=int, default=3982304, @@ -98,7 +97,7 @@ def save_params(fname, arg_params, aux_params, logger=None): help='quantization destination data type for input data') parser.add_argument('--disable-requantize', type=bool, default=True, help='If disable requantize, the OP needed requantize' - ' will output int8 directly and hence requantize ' + ' will output uint8 directly and hence requantize ' 'OP is not needed during quantization. 
Note: ' 'calibration mode need to be used if requantize ' 'is disabled.') @@ -136,7 +135,9 @@ def save_params(fname, arg_params, aux_params, logger=None): # get number of batches for calibration num_calib_batches = args.num_calib_batches - if calib_mode != 'none': + if calib_mode == 'none': + logger.info('skip calibration step as calib_mode is none') + else: logger.info('number of batches = %d for calibration' % num_calib_batches) # get number of threads for decoding the dataset @@ -152,10 +153,10 @@ def save_params(fname, arg_params, aux_params, logger=None): calib_layer = lambda name: name.endswith('_output') excluded_sym_names += ['flatten0', 'fc1'] if exclude_first_conv: - excluded_sym_names += ['sg_mkldnn_conv_bn_relu_0', 'pooling0'] + excluded_sym_names += ['conv0', 'pooling0'] elif args.model == 'imagenet1k-inception-bn': rgb_mean = '123.68,116.779,103.939' - calib_layer = lambda name: name.endswith('_output') and (name.find('conv') != -1) + calib_layer = lambda name: name.endswith('_output') excluded_sym_names += ['flatten', 'fc1'] if exclude_first_conv: excluded_sym_names += ['conv_1'] diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index aa5d4e6de784..4ddce8643fbc 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -300,6 +300,14 @@ using FQuantizedOp = std::function; */ using FNeedRequantize = std::function; +/*! + * \brief Register a function to determine if the input of a quantized operator + * needs to be quantized. This is usually used for the quantized operators + * which can handle fp32 inputs directly. + */ +using FAvoidQuantizeInput = std::function; + } // namespace mxnet #endif // MXNET_OP_ATTR_TYPES_H_ diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 0ae4f3adbad1..90516c1f562e 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -1561,8 +1561,12 @@ void NDArray::Save(dmlc::Stream *strm) const { save_data = nd_cpu.data(); } else { this->WaitToRead(); - save_data = this->data(); nd_cpu = *this; +#if MXNET_USE_MKLDNN == 1 + if (nd_cpu.IsMKLDNNData()) + nd_cpu = nd_cpu.Reorder2Default(); +#endif + save_data = nd_cpu.data(); } // save type flag diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 2e67b2c20033..f90291479b0f 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -132,6 +132,11 @@ static inline bool SupportMKLDNN(int dtype, const TShape &shape) { return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4); } +static inline bool SupportMKLDNNQuantize(int dtype) { + return dtype == mshadow::kFloat32 || dtype == mshadow::kInt8 || + dtype == mshadow::kUint8; +} + static inline bool SupportMKLDNN(const NDArray &input) { return SupportMKLDNN(input.dtype(), input.shape()) && SupportStorageMKLDNN(input.storage_type()); diff --git a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h index 6b3140a9dab0..2daf60c2adea 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h @@ -127,7 +127,7 @@ class MKLDNNConvForward { typedef ParamOpSign MKLDNNConvSignature; -MKLDNNConvForward &GetConvFwd(const ConvolutionParam ¶m, +MKLDNNConvForward &GetConvFwd(const MKLDNNConvFullParam ¶m, const bool is_train, const NDArray &data, const NDArray &weights, const NDArray *bias, const NDArray &output); diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc 
b/src/operator/nn/mkldnn/mkldnn_convolution.cc index 75e367ed7e07..42614401c41c 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -39,7 +39,7 @@ DMLC_REGISTER_PARAMETER(MKLDNNConvParam); bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) { if (params.kernel.ndim() != 2) return false; - return input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 4; + return SupportMKLDNNQuantize(input.dtype()) && input.shape().ndim() == 4; } inline static mkldnn::memory::desc GetInDataMemDesc(const NDArray &arr) { diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc new file mode 100644 index 000000000000..e4eaf7227f3e --- /dev/null +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file mkldnn_quantized_conv.cc + * \brief + * \author Wenting Jiang, Xinyu Chen +*/ + +#if MXNET_USE_MKLDNN == 1 +#include "../../nn/mkldnn/mkldnn_base-inl.h" +#include "../../nn/mkldnn/mkldnn_convolution-inl.h" +#include "../../nn/convolution-inl.h" +#include "../quantization_utils.h" +#include "../../tensor/matrix_op-inl.h" +#include "../../elemwise_op_common.h" +namespace mxnet { +namespace op { + +static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + CHECK_EQ(in_data[0].dtype(), mshadow::kUint8) + << "mkldnn_quantized_conv op only supports uint8 as input type"; + TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]); + NDArray weight = in_data[conv::kWeight]; + MKLDNNConvFullParam full_param; + full_param.conv_param = nnvm::get(attrs.parsed); + full_param.mkldnn_param.Init(std::unordered_map()); + auto &fwd = GetConvFwd( + full_param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight], + full_param.conv_param.no_bias ? nullptr : &in_data[conv::kBias], + out_data[conv::kOut]); + const ConvolutionParam& param = full_param.conv_param; + auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_primitive_desc()); + const mkldnn::memory *weight_mem; + // For inference, we want to reorder the weight array so we don't need to + // reorder data every time. + if (weight.IsDefaultData()) { + weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), param.num_group); + // We also need to modify the layout on the original weight array. The + // data conversion happens after the weight array is used. 
+ weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc()); + } else { + weight_mem = weight.GetMKLDNNData(); + CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc()); + } + auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd.fwd_pd.dst_primitive_desc(), + req[conv::kOut]); + const mkldnn::memory *bias_mem = nullptr; + if (!param.no_bias) + bias_mem = in_data[conv::kBias].GetMKLDNNDataReorder(fwd.fwd_pd.bias_primitive_desc()); + fwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); + MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); + + CommitOutput(out_data[conv::kOut], out_mem); + MKLDNNStream::Get()->Submit(); + Stream *s = ctx.get_stream(); + const size_t num_inputs = param.no_bias ? 2 : 3; + mxnet_op::Kernel::Launch(s, 1, + out_data[1].data().dptr(), out_data[2].data().dptr(), + in_data[num_inputs].data().dptr(), + in_data[num_inputs+1].data().dptr(), + in_data[num_inputs+2].data().dptr(), + in_data[num_inputs+3].data().dptr()); +} + +NNVM_REGISTER_OP(_contrib_quantized_conv) +.set_attr("FComputeEx", MKLDNNQuantizedConvForward); + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 1 diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index 325d2e8c7781..4222055a0348 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -134,6 +134,8 @@ inline bool ExcludeKey(NodePtr node, NodeEntry e) { Graph QuantizeGraph(Graph &&src) { static auto& quantized_op_map = Op::GetAttr("FQuantizedOp"); static auto& need_requantize_map = Op::GetAttr("FNeedRequantize"); + static auto& avoid_quantize_input_map = + Op::GetAttr("FAvoidQuantizeInput"); auto offline_params = src.GetAttr>("offline_params"); auto excluded_nodes = src.GetAttr>("excluded_nodes"); auto quantized_dtype = src.GetAttr("quantized_dtype"); @@ -163,7 +165,8 @@ Graph QuantizeGraph(Graph &&src) { // taking mirror_entry as input to generate a quantized NDArray. Save the mapping between // e's source node and the newly created quantize op so that the quantize op can be // reused next time when the same entry is visited again. 
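The hunk that follows replaces the hard-coded ExcludeKey check with the FAvoidQuantizeInput attribute declared above, which lets an operator keep selected inputs (weights, bias, batch-norm statistics) in FP32 so the fused convolution can quantize or fold them itself. Below is a standalone sketch of the suffix test behind that decision; the function names are illustrative, and the suffix list mirrors the one registered for _sg_mkldnn_conv later in this patch.

#include <algorithm>
#include <string>
#include <vector>

static bool EndsWith(const std::string &str, const std::string &suffix) {
  return suffix.size() <= str.size() &&
         std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
}

// Inputs whose names end with these suffixes are left in FP32 by the
// quantization pass; the fused convolution handles them internally.
bool AvoidQuantizeInput(const std::string &input_name) {
  static const std::vector<std::string> keep_fp32{
      "weight", "bias", "gamma", "beta", "moving_mean", "moving_var"};
  for (const auto &key : keep_fp32)
    if (EndsWith(input_name, key)) return true;
  return false;
}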
- if (ExcludeKey(node, e)) { + if (avoid_quantize_input_map.count(node->op()) && + avoid_quantize_input_map[node->op()](node->attrs, e.node->attrs)) { new_node->inputs.emplace_back(mirror_entry); } else if (!NeedQuantize(e.node, excluded_nodes) && (mirror_node->op() == nullptr || @@ -211,19 +214,26 @@ Graph QuantizeGraph(Graph &&src) { // for quantize node uint32_t min_index = 1; uint32_t max_index = 2; - if (e.node->op() != nullptr && - (quantized_op_map.count(e.node->op()) || - e.node->op()->name != "_contrib_quantize")) { + if (avoid_quantize_input_map.count(node->op()) && + avoid_quantize_input_map[node->op()](node->attrs, e.node->attrs)) { + // skip non-quantized input + continue; + } + if (quantized_op_map.count(e.node->op())) { // here we calculate the output number (exclude min/max, in order to // calculate min/max index from mirror node) based on assumption that // there is only 1min and 1max output from mirror node (which is // currently true) - size_t num_outputs = mirror_node->num_outputs() - 2; + size_t num_outputs = mirror_node->num_outputs() - 2; min_index = num_outputs + 2 * e.index; max_index = num_outputs + 2 * e.index + 1; - new_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); - new_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); + } else { + CHECK(mirror_node->op() != nullptr && + mirror_node->op()->name == "_contrib_quantize") + << "The input is not quantize or quantized_op"; } + new_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); + new_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); } // If the new_node op registered attr FNeedRequantize, insert requantize node after it. diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index 1a907fb53190..56bf032e52f2 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -622,6 +622,24 @@ nnvm::NodePtr SgMKLDNNConvQuantizedOp(const NodeAttrs& attrs) { return node; } +static inline bool StringEndsWith(std::string const &str, + std::string const &suffix) { + if (suffix.size() > str.size()) return false; + return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); +} + +bool SgMKLDNNAvoidQuantizeInput(const NodeAttrs &attrs, + const NodeAttrs &input_attrs) { + const static std::unordered_set exclude_key{ + "weight", "bias", "gamma", "beta", "moving_mean", "moving_var"}; + for (auto i : exclude_key) { + if (StringEndsWith(input_attrs.name, i)) { + return true; + } + } + return false; +} + NNVM_REGISTER_OP(_sg_mkldnn_conv) .describe(R"code(_sg_mkldnn_conv)code" ADD_FILELINE) .set_num_inputs(SgMKLDNNConvNumInputs) @@ -645,7 +663,8 @@ NNVM_REGISTER_OP(_sg_mkldnn_conv) DefaultSubgraphOpMutableInputs) .set_attr("key_var_num_args", "num_args") .set_attr("FInplaceOption", SgMKLDNNConvInplaceOption) -.set_attr("FQuantizedOp", SgMKLDNNConvQuantizedOp); +.set_attr("FQuantizedOp", SgMKLDNNConvQuantizedOp) +.set_attr("FAvoidQuantizeInput", SgMKLDNNAvoidQuantizeInput); } // namespace op } // namespace mxnet #endif // if MXNET_USE_MKLDNN == 1 From 013bb0d500f01893205276920108d7932aa339c6 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 18 Sep 2018 21:44:05 +0800 Subject: [PATCH 09/28] fix lint --- src/operator/subgraph/mkldnn/mkldnn_conv.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index 56bf032e52f2..cc3ea1808c74 100644 --- 
a/src/operator/subgraph/mkldnn/mkldnn_conv.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
@@ -630,7 +630,7 @@ static inline bool StringEndsWith(std::string const &str,
 
 bool SgMKLDNNAvoidQuantizeInput(const NodeAttrs &attrs,
                                 const NodeAttrs &input_attrs) {
-  const static std::unordered_set<std::string> exclude_key{
+  const std::vector<std::string> exclude_key{
       "weight", "bias", "gamma", "beta", "moving_mean", "moving_var"};
   for (auto i : exclude_key) {
     if (StringEndsWith(input_attrs.name, i)) {

From 3eee14d6b4d8571abb997fc6eb6fae26be64006c Mon Sep 17 00:00:00 2001
From: Zhennan Qin
Date: Fri, 21 Sep 2018 09:48:53 +0800
Subject: [PATCH 10/28] Add post quantize fusion

---
 .../quantization/imagenet_gen_qsym_mkldnn.py |  20 +--
 include/mxnet/c_api.h                        |  24 +--
 python/mxnet/contrib/quantization.py         |  21 +--
 src/c_api/c_api_symbolic.cc                  |   7 +-
 src/operator/nn/mkldnn/mkldnn_convolution.cc |   2 +-
 .../quantization/quantize_graph_pass.cc      | 142 ++++++---------
 .../subgraph/mkldnn/mkldnn_conv-inl.h        |  48 ++++++
 src/operator/subgraph/mkldnn/mkldnn_conv.cc  |  86 ++++-----
 .../mkldnn_conv_post_quantize_property.cc    | 163 ++++++++++++++++++
 .../subgraph/mkldnn/mkldnn_conv_property.cc  |  26 ++-
 10 files changed, 354 insertions(+), 185 deletions(-)
 create mode 100644 src/operator/subgraph/mkldnn/mkldnn_conv-inl.h
 create mode 100644 src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc

diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py
index 60611d6d4458..884891f5934d 100644
--- a/example/quantization/imagenet_gen_qsym_mkldnn.py
+++ b/example/quantization/imagenet_gen_qsym_mkldnn.py
@@ -95,12 +95,6 @@ def save_params(fname, arg_params, aux_params, logger=None):
     parser.add_argument('--quantized-dtype', type=str, default='uint8',
                         choices=['int8', 'uint8'],
                         help='quantization destination data type for input data')
-    parser.add_argument('--disable-requantize', type=bool, default=True,
-                        help='If disable requantize, the OP needed requantize'
-                        ' will output uint8 directly and hence requantize '
-                        'OP is not needed during quantization. 
Note: ' - 'calibration mode need to be used if requantize ' - 'is disabled.') parser.add_argument('--enable-calib-quantize', type=bool, default=True, help='If enabled, the quantize op will ' 'be calibrated offline if calibration mode is ' @@ -126,7 +120,7 @@ def save_params(fname, arg_params, aux_params, logger=None): out = SymbolHandle() backend = "MKLDNN" - check_call(_LIB.MXGenBackendSubgraph(c_str(backend), sym.handle, ctypes.byref(out))) + check_call(_LIB.MXGenBackendSubgraph(sym.handle, c_str(backend), ctypes.byref(out))) sym = Symbol(out) # get batch size @@ -178,10 +172,8 @@ def save_params(fname, arg_params, aux_params, logger=None): qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params, ctx=ctx, excluded_sym_names=excluded_sym_names, calib_mode=calib_mode, quantized_dtype=args.quantized_dtype, - disable_requantize=args.disable_requantize, logger=logger) sym_name = '%s-symbol.json' % (prefix + '-quantized') - save_symbol(sym_name, qsym, logger) else: logger.info('Creating ImageRecordIter for reading calibration dataset') data = mx.io.ImageRecordIter(path_imgrec=args.calib_dataset, @@ -197,12 +189,11 @@ def save_params(fname, arg_params, aux_params, logger=None): seed=args.shuffle_seed, **mean_args) - cqsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params, + qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params, ctx=ctx, excluded_sym_names=excluded_sym_names, calib_mode=calib_mode, calib_data=data, num_calib_examples=num_calib_batches * batch_size, calib_layer=calib_layer, quantized_dtype=args.quantized_dtype, - disable_requantize=args.disable_requantize, label_names=(label_name,), calib_quantize_op = True, logger=logger) if calib_mode == 'entropy': @@ -213,7 +204,10 @@ def save_params(fname, arg_params, aux_params, logger=None): raise ValueError('unknow calibration mode %s received, only supports `none`, `naive`, and `entropy`' % calib_mode) sym_name = '%s-symbol.json' % (prefix + suffix) - save_symbol(sym_name, cqsym, logger) - + # out = SymbolHandle() + # backend = "MKLDNN_POST_QUANTIZE" + # check_call(_LIB.MXGenBackendSubgraph(qsym.handle, c_str(backend), ctypes.byref(out))) + # qsym = Symbol(out) + save_symbol(sym_name, qsym, logger) param_name = '%s-%04d.params' % (prefix + '-quantized', epoch) save_params(param_name, qarg_params, aux_params, logger) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index cf1aab14d40b..7d62103faadb 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1510,15 +1510,13 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, * \param num_offline number of parameters that are quantized offline * \param offline_params array of c strings representing the names of params quantized offline * \param quantized_dtype the quantized destination type for input data. -* \param disable_requantize whether disable requantize OP during quantization * \param calib_quantize whether calibrate quantize op with offline calibration data. 
*/ -int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, - const mx_uint num_excluded_symbols, - const char **excluded_symbols, - const mx_uint num_offline, const char **offline_params, - const char *quantized_dtype, const bool disable_requantize, - const bool calib_quantize); +MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, + const mx_uint num_excluded_symbols, + const char **excluded_symbols, + const mx_uint num_offline, const char **offline_params, + const char *quantized_dtype, const bool calib_quantize); /*! * \brief Set calibration table to node attributes in the sym @@ -1528,17 +1526,21 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, * \param low_quantiles low quantiles of layers stored in the calibration table * \param high_quantiles high quantiles of layers stored in the calibration table * \param ret_sym_handle returned symbol - * \param disable_requantize whether disable requantize OP during quantization */ MXNET_DLL int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle, const mx_uint num_layers, const char** layer_names, const float* low_quantiles, const float* high_quantiles, - SymbolHandle* ret_sym_handle, - const bool disable_requantize); + SymbolHandle* ret_sym_handle); -MXNET_DLL int MXGenBackendSubgraph(const char *backend, SymbolHandle sym_handle, +/*! + * \brief Run subgraph pass based on the backend provided + * \param sym_handle symbol to be converted + * \param backend backend names for subgraph pass + * \param ret_sym_handle returned symbol + */ +MXNET_DLL int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend, SymbolHandle *ret_sym_handle); //-------------------------------------------- diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 8548e79f284a..052554bc4d6a 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -81,8 +81,7 @@ def _quantize_params(qsym, params, th_dict): return quantized_params def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, - quantized_dtype='int8', disable_requantize=False, - calib_quantize_op=False): + quantized_dtype='int8', calib_quantize_op=False): """Given a symbol object representing a neural network of data type FP32, quantize it into a INT8 network. @@ -99,8 +98,6 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, avoided. quantized_dtype: str The quantized destination type for input data. - disable_requantize : bool - Whether disable requantize OP functionality. calib_quantize_op : bool Whether perform offline calibration for quantize op. """ @@ -126,7 +123,6 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None, mx_uint(num_offline), c_array(ctypes.c_char_p, offline), c_str(quantized_dtype), - ctypes.c_bool(disable_requantize), ctypes.c_bool(calib_quantize_op))) return Symbol(out) @@ -185,7 +181,7 @@ def collect(self, name, arr): % (name, min_range, max_range)) -def _calibrate_quantized_sym(qsym, th_dict, disable_requantize=False): +def _calibrate_quantized_sym(qsym, th_dict): """Given a dictionary containing the thresholds for quantizing the layers, set the thresholds into the quantized symbol as the params of requantize operators. 
""" @@ -206,8 +202,7 @@ def _calibrate_quantized_sym(qsym, th_dict, disable_requantize=False): c_str_array(layer_output_names), c_array(ctypes.c_float, min_vals), c_array(ctypes.c_float, max_vals), - ctypes.byref(calibrated_sym), - ctypes.c_bool(disable_requantize))) + ctypes.byref(calibrated_sym))) return Symbol(calibrated_sym) @@ -441,8 +436,7 @@ def quantize_model(sym, arg_params, aux_params, data_names=('data',), label_names=('softmax_label',), ctx=cpu(), excluded_sym_names=None, calib_mode='entropy', calib_data=None, num_calib_examples=None, calib_layer=None, - quantized_dtype='int8', disable_requantize=False, - calib_quantize_op=False, logger=logging): + quantized_dtype='int8', calib_quantize_op=False, logger=logging): """User-level API for generating a quantized model from a FP32 model w/ or w/o calibration. The backend quantized operators are only enabled for Linux systems. Please do not run inference using the quantized models on Windows for now. @@ -495,10 +489,6 @@ def quantize_model(sym, arg_params, aux_params, quantized_dtype : str The quantized destination type for input data. Currently support 'int8' and 'uint8', default value is 'int8'. - disable_requantize : bool - Whether disable requantize OP during quantization. If disabled, the related - quantized OP needed requantize will output int8 directly and hence requantize - OP is not needed during symbol quantization calib_quantize_op: bool Whether calibrate quantize op with its input calibration data. The quantize op's input should be in calib_layer logger : Object @@ -524,7 +514,6 @@ def quantize_model(sym, arg_params, aux_params, qsym = _quantize_symbol(sym, excluded_symbols=excluded_sym_names, offline_params=list(arg_params.keys()), quantized_dtype=quantized_dtype, - disable_requantize=disable_requantize, calib_quantize_op=calib_quantize_op) th_dict = {} @@ -564,7 +553,7 @@ def quantize_model(sym, arg_params, aux_params, raise ValueError('unknown calibration mode %s received,' ' expected `none`, `naive`, or `entropy`' % calib_mode) logger.info('Calibrating quantized symbol') - qsym = _calibrate_quantized_sym(qsym, th_dict, disable_requantize) + qsym = _calibrate_quantized_sym(qsym, th_dict) logger.info('Quantizing parameters') qarg_params = _quantize_params(qsym, arg_params, th_dict) diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index ac421a911c8c..d4625de80110 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -651,7 +651,6 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, const mx_uint num_offline, const char **offline_params, const char *quantized_dtype, - const bool disable_requantize, const bool calib_quantize) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); @@ -681,8 +680,7 @@ int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle, const char** layer_names, const float* min_ranges, const float* max_ranges, - SymbolHandle* ret_qsym_handle, - const bool disable_requantize) { + SymbolHandle* ret_qsym_handle) { nnvm::Symbol* s = new nnvm::Symbol(); API_BEGIN(); nnvm::Symbol* sym = static_cast(qsym_handle); @@ -693,14 +691,13 @@ int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle, calib_table.emplace(prefix+layer_names[i], std::make_pair(min_ranges[i], max_ranges[i])); } g.attrs["calib_table"] = std::make_shared(std::move(calib_table)); - g.attrs["disable_requantize"] = std::make_shared(disable_requantize); g = ApplyPass(std::move(g), "SetCalibTableToQuantizedGraph"); s->outputs = g.outputs; *ret_qsym_handle = s; API_END_HANDLE_ERROR(delete s); } 
-int MXGenBackendSubgraph(const char *backend, SymbolHandle sym_handle, +int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend, SymbolHandle *ret_sym_handle) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index 42614401c41c..5b2bd84f448d 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -94,7 +94,7 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl( } attr.set_post_ops(ops); - if (param.mkldnn_param.quantized) { + if (param.mkldnn_param.quantized && param.requantize_scales.size()) { int mask = param.mkldnn_param.weight_channelwise_scale ? 2 : 0; attr.set_output_scales(mask, param.requantize_scales); attr.set_int_output_round_mode(round_nearest); diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index 4222055a0348..283108cc57f4 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -117,20 +117,6 @@ inline bool NeedQuantize(NodePtr node, return false; } -inline bool ExcludeKey(NodePtr node, NodeEntry e) { - auto findSGConv = node->attrs.name.find("sg_mkldnn_conv_"); - std::vector exclude_key{"weight", "bias", "gamma", "beta", "mean", "var"}; - bool found = false; - if (findSGConv == std::string::npos) return false; - for (size_t i = 0; i < exclude_key.size(); i++) { - if (e.node->attrs.name.find(exclude_key[i]) != std::string::npos) { - found = true; - break; - } - } - return found; -} - Graph QuantizeGraph(Graph &&src) { static auto& quantized_op_map = Op::GetAttr("FQuantizedOp"); static auto& need_requantize_map = Op::GetAttr("FNeedRequantize"); @@ -239,18 +225,19 @@ Graph QuantizeGraph(Graph &&src) { // If the new_node op registered attr FNeedRequantize, insert requantize node after it. // Here it's assumed that the quantized_op node only produces three outputs: // out_data, min_range, and max_range. - if (need_requantize_map.count(new_node->op()) > 0 - && need_requantize_map[new_node->op()](new_node->attrs)) { - NodePtr requantize_node = Node::Create(); - requantize_node->attrs.op = Op::Get("_contrib_requantize"); - requantize_node->attrs.name = "requantize_" + node->attrs.name; - if (requantize_node->op()->attr_parser != nullptr) { - requantize_node->op()->attr_parser(&(requantize_node->attrs)); - } - for (size_t i = 0; i < 3; ++i) { - requantize_node->inputs.emplace_back(NodeEntry{new_node, static_cast(i), 0}); - } - new_node = requantize_node; + if (need_requantize_map.count(new_node->op()) > 0 && + need_requantize_map[new_node->op()](new_node->attrs)) { + NodePtr requantize_node = Node::Create(); + requantize_node->attrs.op = Op::Get("_contrib_requantize"); + requantize_node->attrs.name = "requantize_" + node->attrs.name; + if (requantize_node->op()->attr_parser != nullptr) { + requantize_node->op()->attr_parser(&(requantize_node->attrs)); + } + for (size_t i = 0; i < 3; ++i) { + requantize_node->inputs.emplace_back( + NodeEntry{new_node, static_cast(i), 0}); + } + new_node = requantize_node; } } else { // If the currently visited node does not need quantization, copy the current node to become @@ -260,36 +247,45 @@ Graph QuantizeGraph(Graph &&src) { // the new_node. 
*new_node = *node; new_node->inputs.clear(); - for (const auto& e : node->inputs) { - NodePtr mirror_node = mirror_map.at(e.node.get()); - NodeEntry mirror_entry = NodeEntry{ - mirror_node, e.index, e.version}; - // if input node is quantized operator, add dequantize node - if (NeedQuantize(e.node, excluded_nodes) && - (mirror_node->op() == nullptr || - mirror_node->op()->name != "_contrib_dequantize")) { - // here we calculate the output number (exclude min/max, in order to - // calculate min/max index from mirror node) based on assumption that - // there is only 1min and 1max output from mirror node (which is - // currently true) - size_t num_outputs = mirror_node->num_outputs() - 2; - uint32_t min_index = num_outputs + 2 * e.index; - uint32_t max_index = num_outputs + 2 * e.index + 1; - NodePtr dequantize_node = CreateNode("_contrib_dequantize", - e.node->attrs.name + "_dequantize"); - dequantize_node->inputs.emplace_back(mirror_entry); - dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); - dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); - dequantize_node->op()->attr_parser(&(dequantize_node->attrs)); + if (node->is_variable() && node->attrs.name == "data") { + // Insert identity for data to collect calib info for it. + NodePtr identity_node = + CreateNode("identity", new_node->attrs.name + "_id"); + identity_node->inputs.emplace_back(NodeEntry{new_node, 0, 0}); + new_node = identity_node; + } else { + for (const auto& e : node->inputs) { + NodePtr mirror_node = mirror_map.at(e.node.get()); + NodeEntry mirror_entry = NodeEntry{ + mirror_node, e.index, e.version}; + // if input node is quantized operator, add dequantize node + if (NeedQuantize(e.node, excluded_nodes) && + (mirror_node->op() == nullptr || + mirror_node->op()->name != "_contrib_dequantize")) { + // here we calculate the output number (exclude min/max, in order to + // calculate min/max index from mirror node) based on assumption that + // there is only 1min and 1max output from mirror node (which is + // currently true) + size_t num_outputs = mirror_node->num_outputs() - 2; + uint32_t min_index = num_outputs + 2 * e.index; + uint32_t max_index = num_outputs + 2 * e.index + 1; + NodePtr dequantize_node = CreateNode("_contrib_dequantize", + e.node->attrs.name + "_dequantize"); + dequantize_node->inputs.emplace_back(mirror_entry); + dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); + dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); + dequantize_node->op()->attr_parser(&(dequantize_node->attrs)); - new_node->inputs.emplace_back(NodeEntry{dequantize_node, 0, 0}); - mirror_map[e.node.get()] = std::move(dequantize_node); - } else if (mirror_node->op() != nullptr - && mirror_node->op()->name == "_contrib_quantize") { - new_node->inputs.emplace_back(NodeEntry{mirror_node->inputs[0].node, e.index, e.version}); - } else { - new_node->inputs.emplace_back( - NodeEntry{mirror_node, e.index, e.version}); + new_node->inputs.emplace_back(NodeEntry{dequantize_node, 0, 0}); + mirror_map[e.node.get()] = std::move(dequantize_node); + } else if (mirror_node->op() != nullptr + && mirror_node->op()->name == "_contrib_quantize") { + new_node->inputs.emplace_back( + NodeEntry{mirror_node->inputs[0].node, e.index, e.version}); + } else { + new_node->inputs.emplace_back( + NodeEntry{mirror_node, e.index, e.version}); + } } } } @@ -332,44 +328,18 @@ Graph SetCalibTableToQuantizedGraph(Graph&& g) { nnvm::Op::GetAttr("FNeedRequantize"); const auto&
calib_table = g.GetAttr>>("calib_table"); - auto disable_requantize = g.GetAttr("disable_requantize"); - DFSVisit(g.outputs, [&](const NodePtr& node) { - bool found = false; - NodePtr quantized_op_node; - // If requantize is not disabled, find requantize OP and - // the thresholds from the calibration table with the key equal - // to the requantize OP's input node name, e.g. a quantized_conv node. - if (!disable_requantize && - node->op() != nullptr && node->op()->name == "_contrib_requantize") { - quantized_op_node = node->inputs[0].node; + // If the current op is requantize + // find the thresholds from the calibration table with the key equal + // to the current op's input node name, e.g. a quantized_conv2d node. + if (node->op() != nullptr && node->op()->name == "_contrib_requantize") { + NodePtr quantized_op_node = node->inputs[0].node; CHECK(quantized_op_node->op() != nullptr) << quantized_op_node->attrs.name << " must be an quantized op node"; CHECK(need_requantize_map.count(quantized_op_node->op()) > 0 && need_requantize_map[quantized_op_node->op()](quantized_op_node->attrs)) << quantized_op_node->attrs.name << " op must register FNeedRequantize attr" " and the attr func should return true"; - found = true; - // If requantize is disabled, find OPs that needed requantize and - // the thresholds from the calibration table with the key equal - // to the found OP's name, e.g. a quantized_conv node. - } else if (disable_requantize && node->op() != nullptr && - need_requantize_map.count(node->op()) > 0 && - need_requantize_map[node->op()](node->attrs)) { - quantized_op_node = node; - found = true; - } else if (disable_requantize && - node->op() != nullptr && node->op()->name == "_sg_mkldnn_conv" - && !node->attrs.name.find("quantized_")) { - quantized_op_node = node; - std::string out_data_name = quantized_op_node->attrs.name + "_output"; - const auto calib_table_iter = calib_table.find(out_data_name); - if (calib_table_iter != calib_table.end()) { - node->attrs.dict["min_calib_range"] = std::to_string(calib_table_iter->second.first); - node->attrs.dict["max_calib_range"] = std::to_string(calib_table_iter->second.second); - } - } - if (found) { std::string out_data_name = quantized_op_node->attrs.name + "_"; auto list_output_names_func = flist_outputs.get(quantized_op_node->op(), nullptr); // Here it's assumed that the quantized_op node only produces three outputs: diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h b/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h new file mode 100644 index 000000000000..8675446f5a14 --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_INL_H_ +#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_INL_H_ +#if MXNET_USE_MKLDNN == 1 + +#include +#include +#include +#include "../../nn/convolution-inl.h" +#include "../../nn/batch_norm-inl.h" +#include "../../nn/mkldnn/mkldnn_convolution-inl.h" + +namespace mxnet { +namespace op { + +struct MKLDNNConvFusionParam { + MKLDNNConvFullParam full_conv_param; + std::shared_ptr bn_param; +}; + +static const size_t uint8_range = 255; +static const size_t int8_range = 127; + +enum MKLDNNConvOpOutputs { kOut, kMin, kMax }; + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_INL_H_ diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index cc3ea1808c74..fe9d3ba90ffb 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -18,27 +18,19 @@ */ #if MXNET_USE_MKLDNN == 1 + +#include +#include +#include #include "../common.h" -#include "../../nn/convolution-inl.h" -#include "../../nn/batch_norm-inl.h" -#include "../../nn/activation-inl.h" #include "../../nn/mkldnn/mkldnn_base-inl.h" #include "../../nn/mkldnn/mkldnn_ops-inl.h" -#include "../../nn/mkldnn/mkldnn_convolution-inl.h" #include "../../quantization/quantization_utils.h" +#include "mkldnn_conv-inl.h" + namespace mxnet { namespace op { -struct MKLDNNConvFusionParam { - MKLDNNConvFullParam full_conv_param; - std::shared_ptr bn_param; -}; - -static const size_t uint8_range = 255; -static const size_t int8_range = 127; - -enum MKLDNNConvOpOutputs { kOut, kMin, kMax }; - template static void UpdateConvWeightBias(NDArray *weight, NDArray *bias, bool no_bias, const NDArray &gamma, const NDArray &beta, @@ -59,7 +51,7 @@ static void UpdateConvWeightBias(NDArray *weight, NDArray *bias, bool no_bias, DType *update_bias_ptr = update_bias.data().dptr(); size_t channel = gamma.shape()[0]; size_t offset = weight->shape()[1] * weight->shape()[2] * weight->shape()[3]; -#pragma omp parallel for +#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (int c = 0; c < static_cast(channel); ++c) { DType *p1 = reinterpret_cast(weight_ptr + c * offset); DType *p2 = reinterpret_cast(update_weight_ptr + c * offset); @@ -102,7 +94,7 @@ static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias, size_t offset = weight->shape()[1] * weight->shape()[2] * weight->shape()[3]; std::vector weight_c_min(channel, MaxValue()); std::vector weight_c_max(channel, MinValue()); -#pragma omp parallel for +#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (int c = 0; c < static_cast(channel); ++c) { DType *p1 = weight_ptr + c * offset; for (size_t k = 0; k < offset; ++k) { @@ -115,7 +107,7 @@ static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias, if (weight_channelwise_scale) { weight_scales->resize(channel); -#pragma omp parallel for +#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (int c = 0; c < static_cast(channel); ++c) { DType weight_range = MaxAbs(weight_c_min[c], weight_c_max[c]); weight_scales->at(c) = int8_range / weight_range; @@ -135,7 +127,7 @@ static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias, weight_scales->resize(1); DType weight_range = MaxAbs(total_min, total_max); weight_scales->at(0) = int8_range / weight_range; -#pragma omp parallel for +#pragma omp parallel 
for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (int c = 0; c < static_cast(channel); ++c) { DType *fp_ptr = weight_ptr + c * offset; int8_t *quan_ptr = quan_weight_ptr + c * offset; @@ -284,14 +276,15 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, } } } - + bool post_requantize = false; if (mkldnn_param.quantized) { - *out_min_ptr = mkldnn_param.min_calib_range.has_value() - ? mkldnn_param.min_calib_range.value() - : 0.0; - *out_max_ptr = mkldnn_param.max_calib_range.has_value() - ? mkldnn_param.max_calib_range.value() - : 1.0; + if (mkldnn_param.min_calib_range.has_value() && + mkldnn_param.max_calib_range.has_value()) { + post_requantize = true; + mkldnn_param.weight_channelwise_scale = false; + *out_min_ptr = mkldnn_param.min_calib_range.value(); + *out_max_ptr = mkldnn_param.max_calib_range.value(); + } } if (!initalized) { @@ -335,6 +328,7 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, float sum_in_scale = 1.0; float out_range; float quantized_out_range; + float output_scale; if (data_min < 0.0) { // TODO(zhennan): we need to use offset to convert int8 to uint8. LOG(FATAL) << "Can't handle negative value for QuantizeData"; } if (mkldnn_param.with_sum) { auto quantized_sum_range = sum_min < 0 ? int8_range : uint8_range; sum_in_scale = quantized_sum_range / MaxAbs(sum_min, sum_max); } - quantized_out_range = - IsOutputUInt8(mkldnn_param) ? uint8_range : int8_range; - out_range = MaxAbs(*out_min_ptr, *out_max_ptr); - float output_scale = quantized_out_range / out_range; - full_conv_param.requantize_scales.resize(channel); - for (size_t c = 0; c < channel; c++) { - auto weight_scale = mkldnn_param.weight_channelwise_scale - ? weight_scales[c] - : weight_scales[0]; - full_conv_param.requantize_scales[c] = - output_scale / data_scale / weight_scale; + if (post_requantize) { + quantized_out_range = + IsOutputUInt8(mkldnn_param) ? uint8_range : int8_range; + out_range = MaxAbs(*out_min_ptr, *out_max_ptr); + output_scale = quantized_out_range / out_range; + full_conv_param.requantize_scales.resize(channel); + for (size_t c = 0; c < channel; c++) { + auto weight_scale = mkldnn_param.weight_channelwise_scale + ? 
weight_scales[c] + : weight_scales[0]; + full_conv_param.requantize_scales[c] = + output_scale / data_scale / weight_scale; + } + } else { + output_scale = data_scale * weight_scales[0]; + full_conv_param.requantize_scales.resize(0); } if (mkldnn_param.with_sum) full_conv_param.sum_scale = output_scale / sum_in_scale; @@ -552,11 +551,17 @@ static bool SgMKLDNNConvInferType(const nnvm::NodeAttrs &attrs, in_types->at(i) = base_in_types[base_idx++]; } } - if (IsOutputUInt8(param.full_conv_param.mkldnn_param)) { - TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8); + if (param.full_conv_param.mkldnn_param.min_calib_range.has_value() && + param.full_conv_param.mkldnn_param.max_calib_range.has_value()) { + if (IsOutputUInt8(param.full_conv_param.mkldnn_param)) { + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8); + } else { + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8); + } } else { - TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8); + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32); } + TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32); TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32); return result; @@ -664,7 +669,10 @@ NNVM_REGISTER_OP(_sg_mkldnn_conv) .set_attr("key_var_num_args", "num_args") .set_attr("FInplaceOption", SgMKLDNNConvInplaceOption) .set_attr("FQuantizedOp", SgMKLDNNConvQuantizedOp) +.set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) .set_attr("FAvoidQuantizeInput", SgMKLDNNAvoidQuantizeInput); + } // namespace op } // namespace mxnet + #endif // if MXNET_USE_MKLDNN == 1 diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc b/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc new file mode 100644 index 000000000000..e4effbbcb464 --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#if MXNET_USE_MKLDNN == 1 + +#include "../common.h" +#include "../subgraph_property.h" +#include "../../nn/mkldnn/mkldnn_convolution-inl.h" +#include "mkldnn_conv-inl.h" +#include "../../quantization/requantize-inl.h" + +namespace mxnet { +namespace op { + +class SgMKLDNNConvPostQuantizeSelector : public SubgraphSelector { + public: + /*! 
\brief pattern match status */ + enum SelectStatus { + sFail = 0, + sStart, + sSuccess, + }; + + private: + bool disable_all; + SelectStatus status; + std::vector matched_list; + + public: + explicit SgMKLDNNConvPostQuantizeSelector(int dis_all) + : disable_all(dis_all) {} + + bool Select(const nnvm::Node &n) override { + if ((!disable_all) && n.op() && n.op()->name == "_sg_mkldnn_conv") { + auto const &param = nnvm::get(n.attrs.parsed); + if (param.full_conv_param.mkldnn_param.quantized) { + status = sStart; + matched_list.clear(); + matched_list.push_back(&n); + return true; + } + } + return false; + } + + bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override { + return false; + } + + bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override { + if (status == sFail || status == sSuccess || new_node.is_variable()) + return false; + // If n isn't the last matched node, then we encountered an internal + branch, we should pop out the node behind n and stop fusion. + if (matched_list.back() != &n) { + status = sFail; + return false; + } + if (new_node.op()->name == "_contrib_requantize") { + auto const &param = nnvm::get(new_node.attrs.parsed); + if (param.min_calib_range.has_value() && + param.max_calib_range.has_value()) { + matched_list.push_back(&new_node); + status = sSuccess; + return true; + } else { + status = sFail; + } + } + return false; + } + + std::vector Filter( + const std::vector &candidates) override { + if (status != sSuccess) { + return std::vector(0); + } else { + return candidates; + } + } +}; + +class SgMKLDNNConvPostQuantizeProperty : public SubgraphProperty { + public: + SgMKLDNNConvPostQuantizeProperty() { + disable_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_OPT", 0); + if (disable_all) { + LOG(INFO) << "MKLDNN Convolution post-quantization optimization pass is disabled."; + } else { + LOG(INFO) << "Start to execute MKLDNN Convolution post-quantization optimization pass."; + } + } + static SubgraphPropertyPtr Create() { + return std::make_shared(); + } + nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym, + const int subgraph_id = 0) const override { + nnvm::NodePtr conv_node = nullptr; + nnvm::NodePtr requantize_node = nullptr; + DFSVisit(sym.outputs, [&](const nnvm::NodePtr &node) { + if (node->is_variable()) return; + auto &op_name = node->op()->name; + if (op_name == "_sg_mkldnn_conv") { + conv_node = node; + } else if (op_name == "_contrib_requantize") { + requantize_node = node; + } + }); + CHECK_NOTNULL(conv_node); + CHECK_NOTNULL(requantize_node); + auto const &requantize_param = + nnvm::get(requantize_node->attrs.parsed); + CHECK(requantize_param.min_calib_range.has_value()); + CHECK(requantize_param.max_calib_range.has_value()); + conv_node->attrs.dict["min_calib_range"] = + std::to_string(requantize_param.min_calib_range.value()); + conv_node->attrs.dict["max_calib_range"] = + std::to_string(requantize_param.max_calib_range.value()); + conv_node->op()->attr_parser(&(conv_node->attrs)); + return conv_node; + } + + SubgraphSelectorPtr CreateSubgraphSelector() const override { + auto selector = + std::make_shared(disable_all); + return selector; + } + + void ConnectSubgraphOutput( + const nnvm::NodePtr n, + std::vector *output_entries) const override { + for (size_t i = 0; i < output_entries->size(); ++i) { + auto entry_ptr = output_entries->at(i); + *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0}; + } + } + + private: + int disable_all; +}; + +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_POST_QUANTIZE, 
SgMKLDNNConvPostQuantizeProperty); + +} // namespace op +} // namespace mxnet + +#endif // if MXNET_USE_MKLDNN == 1 diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc index 8459d68628d9..f39b012299cc 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc @@ -17,10 +17,8 @@ * under the License. */ -#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_CONV_H_ -#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_CONV_H_ - #if MXNET_USE_MKLDNN == 1 + #include "../common.h" #include "../subgraph_property.h" #include "../../nn/activation-inl.h" @@ -54,10 +52,8 @@ class SgMKLDNNConvSelector : public SubgraphSelector { disable_conv_sum(dis_conv_sum) {} bool Select(const nnvm::Node &n) override { - bool match = - (!disable_all) && (!n.is_variable()) && (n.op()->name == "Convolution"); - if (match) { - status = sStart; + if (n.op() && n.op()->name == "Convolution") { + status = disable_all ? sSuccess : sStart; matched_list.clear(); matched_list.push_back(&n); return true; @@ -129,15 +125,17 @@ class SgMKLDNNConvSelector : public SubgraphSelector { class SgMKLDNNConvProperty : public SubgraphProperty { public: SgMKLDNNConvProperty() { - disable_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSION", 0); - disable_conv_bn = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSION_CONV_BN", 0); - disable_conv_relu = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSION_CONV_RELU", 0); - disable_conv_sum = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSION_CONV_SUM", 0); + disable_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_OPT", 0); + disable_conv_bn = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_BN", 0); + disable_conv_relu = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_RELU", 0); + disable_conv_sum = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_SUM", 0); + disable_all = + disable_all && disable_conv_bn && disable_conv_relu && disable_conv_sum; if (disable_all) { - LOG(INFO) << "MKLDNN Convolution fusion pass is disabled."; + LOG(INFO) << "MKLDNN Convolution optimization pass is disabled."; } else { - LOG(INFO) << "Start to execute MKLDNN Convolution fusion pass."; + LOG(INFO) << "Start to execute MKLDNN Convolution optimization pass."; } } static SubgraphPropertyPtr Create() { @@ -237,5 +235,5 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty); } // namespace op } // namespace mxnet + #endif // if MXNET_USE_MKLDNN == 1 -#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_CONV_H_ From 8322784da75a20f9d305366019c6d63161e365e2 Mon Sep 17 00:00:00 2001 From: huangzhiyuan Date: Fri, 21 Sep 2018 17:09:11 +0800 Subject: [PATCH 11/28] add test case --- tests/python/mkl/test_subgraph.py | 209 ++++++++++++++++++++++++++++++ 1 file changed, 209 insertions(+) create mode 100644 tests/python/mkl/test_subgraph.py diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py new file mode 100644 index 000000000000..cfa73c9dc1ce --- /dev/null +++ b/tests/python/mkl/test_subgraph.py @@ -0,0 +1,209 @@ +import sys +import os +import mxnet as mx +import numpy as np +import unittest +import ctypes +from mxnet.io import NDArrayIter +from mxnet.module import Module +from mxnet.symbol import Symbol +from importlib import import_module +from numpy.testing import assert_allclose +from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str +from mxnet.test_utils import DummyIter +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.append(os.path.join(curr_path, '../unittest/')) +from common 
import with_seed + +def check_qsym_calibrated(qsym): + attrs = qsym.attr_dict() + if ''.join(qsym.attr_dict().keys()).find('quantized_pool') != -1: + return 0, 0 + assert ''.join(qsym.attr_dict().keys()).find('quantized_') != -1 + for k, v in attrs.items(): + if k.find('requantize_sg_mkldnn_conv') != -1: + assert 'min_calib_range' in v + assert 'max_calib_range' in v + min_value = v['min_calib_range'] + max_value = v['max_calib_range'] + if k.find('_quantize') != -1: + assert v['out_type'] == 'uint8' + return float(min_value), float(max_value) + +def check_qsym_forward(qsym, qarg_params, qaux_params, data_val, data_shape, label_shape): + mod = mx.mod.Module(symbol=qsym, context=mx.current_context()) + mod.bind(for_training=False, + data_shapes=[('data', data_shape)], + label_shapes=[('softmax_label', label_shape)]) + mod.set_params(qarg_params, qaux_params) + batch = mx.io.DataBatch(data_val, []) + mod.forward(batch, is_train=False) + for output in mod.get_outputs(): + output.wait_to_read() + return output + +def check_quantize(sym, data_shape, label_shape, data_val, sym_output): + mod = Module(symbol=sym) + mod.bind(data_shapes=[('data', data_shape)], label_shapes=[('softmax_label', label_shape)], for_training=False) + mod.init_params() + arg_params, aux_params = mod.get_params() + excluded_sym_names = [] + if mx.current_context() == mx.cpu(): + excluded_sym_names += ['fc'] + calib_data = mx.nd.random.uniform(shape=data_shape) + calib_data = NDArrayIter(data=calib_data) + calib_data = DummyIter(calib_data) + calib_layer = lambda name: name.endswith('_output') + qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym, + arg_params=arg_params, + aux_params=aux_params, + ctx=mx.current_context(), + excluded_sym_names=excluded_sym_names, + quantized_dtype='uint8', + calib_mode='naive', + calib_data=calib_data, + calib_layer=calib_layer, + #disable_requantize=True, + calib_quantize_op=True, + num_calib_examples=20) + minVar, maxVar = check_qsym_calibrated(qsym) + rtol = (maxVar - minVar) / 256 + qsym_output = check_qsym_forward(qsym, qarg_params, qaux_params, data_val, data_shape, label_shape) + assert_allclose(qsym_output[0].asnumpy(), sym_output[0].asnumpy(), rtol=rtol) + +def check_fusion(sym, date_shape, label_shape, name, nofusion=False): + exe = sym.simple_bind(mx.cpu(), data=date_shape, grad_req='null') + out = SymbolHandle() + backend = "MKLDNN" + check_call(_LIB.MXGenBackendSubgraph(sym.handle, c_str(backend), ctypes.byref(out))) + sym_sg = Symbol(out) + exe_sg = sym_sg.simple_bind(mx.cpu(), data=date_shape, grad_req='null') + + mx.random.seed(12345) + for k, v in exe.arg_dict.items(): + v = mx.random.uniform(-1.0, 1.0, shape=v.shape) + data_val = [exe.arg_dict['data']] + + fwd = exe.forward(is_train=False) + fwd[0].wait_to_read() + + fwd_sg = exe_sg.forward(is_train=False) + fwd_sg[0].wait_to_read() + + # Check the result accuracy based on fp32 fusion + assert_allclose(fwd[0].asnumpy(), fwd_sg[0].asnumpy(), rtol=0) + attrs=sym_sg.attr_dict() + if not nofusion: + assert ''.join(sym_sg.get_internals().list_outputs()).find('sg_mkldnn_conv') != -1 + for k, v in attrs.items(): + if k.find('sg_mkldnn_conv') != -1: + for attr_op in name: + assert v[attr_op] == 'true' + + # fp32 to uint8 + if nofusion: + check_quantize(sym, date_shape, label_shape, data_val, fwd[0]) + else: check_quantize(sym_sg, date_shape, label_shape, data_val, fwd[0]) + +def single_conv(): + data = mx.symbol.Variable('data') + weight = mx.symbol.Variable('weight') + bn = mx.sym.BatchNorm(data=data, 
fix_gamma=True, eps=2e-5, momentum=0.9, name='bn') + conv = mx.symbol.Convolution(data=bn, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) + fc = mx.sym.FullyConnected(data=conv, num_hidden=10, flatten=True, name='fc') + sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') + return sym + +def conv_bn(): + data = mx.symbol.Variable('data') + weight = mx.symbol.Variable('weight') + bn1 = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn1') + conv = mx.symbol.Convolution(data=bn1, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn = mx.symbol.BatchNorm(data=conv, name="bn") + fc = mx.sym.FullyConnected(data=bn, num_hidden=10, flatten=True, name='fc') + sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') + return sym + +def conv_relu(): + data = mx.symbol.Variable('data') + weight = mx.symbol.Variable('weight') + bn = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn') + conv = mx.symbol.Convolution(data=bn, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) + relu = mx.symbol.Activation(data=conv, name='relu', act_type="relu") + fc = mx.sym.FullyConnected(data=relu, num_hidden=10, flatten=True, name='fc') + sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') + return sym + +def conv_sum(): + data = mx.symbol.Variable('data') + weight = mx.symbol.Variable('weight') + bn = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn') + conv = mx.symbol.Convolution(data=bn, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) + conv1 = mx.symbol.Convolution(data=bn, weight=weight, name='conv1', num_filter=64, kernel=(3, 3), stride=(1, 1)) + sum1 = conv + conv1 + fc = mx.sym.FullyConnected(data=sum1, num_hidden=10, flatten=True, name='fc') + sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') + return sym + +def conv_bn_relu(): + data = mx.symbol.Variable('data') + weight = mx.symbol.Variable('weight') + bn1 = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn1') + conv = mx.symbol.Convolution(data=bn1, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn = mx.symbol.BatchNorm(data=conv, name="bn") + relu = mx.symbol.Activation(data=bn, name='relu', act_type="relu") + fc = mx.sym.FullyConnected(data=relu, num_hidden=10, flatten=True, name='fc') + sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') + return sym + +def conv_bn_sum_relu(): + data = mx.symbol.Variable('data') + weight = mx.symbol.Variable('weight') + bn1 = mx.symbol.BatchNorm(data=data, name="bn1") + conv = mx.symbol.Convolution(data=bn1, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn = mx.symbol.BatchNorm(data=conv, name="bn") + conv1 = mx.symbol.Convolution(data=bn1, weight=weight, name='conv1', num_filter=64, kernel=(3, 3), stride=(1, 1)) + sum1 = bn + conv1 + relu = mx.symbol.Activation(data=sum1, name='relu', act_type="relu") + fc = mx.sym.FullyConnected(data=relu, num_hidden=10, flatten=True, name='fc') + sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') + return sym + +def int8_pooling(): + data = mx.symbol.Variable('data') + bn = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn') + pool = mx.sym.Pooling(data=bn, kernel=(4, 4), pool_type='avg', name='pool') + fc = mx.sym.FullyConnected(data=pool, num_hidden=10, flatten=True, name='fc') + sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') + return sym + +@with_seed() +def test_sugbraph(): + def 
check_test_sugbraph(): + conv_attr = [''] + conv_relu_attr = ['with_relu'] + conv_bn_attr = ['with_bn'] + conv_sum_attr = ['with_sum'] + conv_bn_relu_attr = ['with_bn', 'with_relu'] + conv_bn_sum_relu_attr = ['with_sum', 'with_postsum_relu', 'with_bn'] + + shape = [(4, 4, 10, 10), (32, 3, 24, 24), (64, 8, 64, 64)] + label = [(4, 10), (32, 10), (64, 10)] + + for date_shape, label_shape in zip(shape, label): + net = conv_bn_sum_relu() + check_fusion(net, date_shape, label_shape, conv_bn_sum_relu_attr) + net = single_conv() + check_fusion(net, date_shape, label_shape, conv_attr) + net = conv_relu() + check_fusion(net, date_shape, label_shape, conv_relu_attr) + net = conv_bn() + check_fusion(net, date_shape, label_shape, conv_bn_attr) + net = conv_sum() + check_fusion(net, date_shape, label_shape, conv_sum_attr) + net = conv_bn_relu() + check_fusion(net, date_shape, label_shape, conv_bn_relu_attr) + net = int8_pooling() + check_fusion(net, date_shape, label_shape, '', True) + + check_test_sugbraph() From 302fa65e788b40425c1394baae4b9e0d996f5645 Mon Sep 17 00:00:00 2001 From: huangzhiyuan Date: Fri, 21 Sep 2018 17:14:12 +0800 Subject: [PATCH 12/28] add head license in test case --- tests/python/mkl/test_subgraph.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py index cfa73c9dc1ce..8f7cac57decc 100644 --- a/tests/python/mkl/test_subgraph.py +++ b/tests/python/mkl/test_subgraph.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + import sys import os import mxnet as mx From efc7f1e02610aed0c81c156daa6ba8fb265cea0a Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Sun, 23 Sep 2018 19:42:28 +0800 Subject: [PATCH 13/28] Remove GetBoolHash() --- src/operator/nn/mkldnn/mkldnn_convolution-inl.h | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h index 2daf60c2adea..aa58b655ec39 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h @@ -73,15 +73,6 @@ struct MKLDNNConvParam : public dmlc::Parameter { "through calibration. If present, it will be used to by " "quantized convolution op to calculate primitive scale"); } - const int GetBoolHash() const { - int hash = 0; - hash = hash * 2 + this->with_bn ? 1 : 0; - hash = hash * 2 + this->with_relu ? 1 : 0; - hash = hash * 2 + this->with_sum ? 1 : 0; - hash = hash * 2 + this->with_postsum_relu ? 1 : 0; - hash = hash * 2 + this->quantized ? 
1 : 0; - return hash; - } }; struct MKLDNNConvFullParam { From 741c221c390db84e208e3d73fd5be1360df7dff6 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 25 Sep 2018 13:12:00 +0800 Subject: [PATCH 14/28] Remove mkldnn fallback change. --- src/executor/attach_op_execs_pass.cc | 22 +++++--------- .../nn/mkldnn/mkldnn_convolution-inl.h | 2 +- src/operator/nn/mkldnn/mkldnn_convolution.cc | 11 ++++--- src/operator/quantization/dequantize.cc | 1 - .../mkldnn/mkldnn_quantized_conv.cc | 9 ++---- .../mkldnn/mkldnn_quantized_pooling.cc | 1 - src/operator/quantization/quantize.cc | 1 - src/operator/quantization/requantize.cc | 1 - src/operator/subgraph/mkldnn/mkldnn_conv.cc | 30 ++++++++----------- 9 files changed, 30 insertions(+), 48 deletions(-) diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index a0176fab0a04..0e415ef5112a 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -159,13 +159,9 @@ class StatefulComputeExExecutor : public OpExecutor { op_ctx.run_ctx = rctx; #if MXNET_USE_MKLDNN == 1 InvalidateOutputs(out_array, req); - // TODO(alex): (MXNET-847) Remove this fallback feature after subgraph implemented - const auto is_mkldnn = Op::GetAttr("TIsMKLDNN"); - if (!is_mkldnn.get(attrs_.op, false)) { - CreateDefaultInputs(in_array, &in_array_fallback); - fcompute_(state_, op_ctx, in_array_fallback, req, out_array); - return; - } + CreateDefaultInputs(in_array, &in_array_fallback); + fcompute_(state_, op_ctx, in_array_fallback, req, out_array); + return; #endif fcompute_(state_, op_ctx, in_array, req, out_array); } @@ -184,14 +180,12 @@ class StatefulComputeExExecutor : public OpExecutor { return state_; } - explicit StatefulComputeExExecutor(const NodeAttrs& attrs, - const OpStatePtr& state, + explicit StatefulComputeExExecutor(const OpStatePtr& state, const FStatefulComputeEx& fcompute, ExecType exec_type) - : attrs_(attrs), state_(state), fcompute_(fcompute), exec_type_(exec_type) {} + : state_(state), fcompute_(fcompute), exec_type_(exec_type) {} private: - NodeAttrs attrs_; OpStatePtr state_; FStatefulComputeEx fcompute_; ExecType exec_type_; @@ -308,8 +302,7 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) { op, "FStatefulComputeEx", vctx[i]); // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { - ret[i] = std::make_shared(inode.source->attrs, state, - fcompute_ex, exec_type); + ret[i] = std::make_shared(state, fcompute_ex, exec_type); } else { FStatefulCompute fcompute = common::GetFCompute( op, "FStatefulCompute", vctx[i]); @@ -329,8 +322,7 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) { // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { ret[i] = std::make_shared( - inode.source->attrs, ret[fwd_id].get()->state(), fcompute_ex, - exec_type); + ret[fwd_id].get()->state(), fcompute_ex, exec_type); } else { FStatefulCompute fcompute = common::GetFCompute( op, "FStatefulCompute", vctx[i]); diff --git a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h index aa58b655ec39..4dff9b8d46a2 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h @@ -118,7 +118,7 @@ class MKLDNNConvForward { typedef ParamOpSign MKLDNNConvSignature; 
-MKLDNNConvForward &GetConvFwd(const MKLDNNConvFullParam ¶m, +MKLDNNConvForward &GetConvFwd(const ConvolutionParam ¶m, const bool is_train, const NDArray &data, const NDArray &weights, const NDArray *bias, const NDArray &output); diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index 5b2bd84f448d..639d52b9ad0e 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -250,7 +250,7 @@ void MKLDNNConvForward::SetNewMem(const mkldnn::memory &data, } } -MKLDNNConvForward &GetConvFwd(const MKLDNNConvFullParam ¶m, +MKLDNNConvForward &GetConvFwd(const ConvolutionParam ¶m, const bool is_train, const NDArray &data, const NDArray &weights, const NDArray *bias, const NDArray &output) { @@ -259,7 +259,7 @@ MKLDNNConvForward &GetConvFwd(const MKLDNNConvFullParam ¶m, #else static MX_THREAD_LOCAL std::unordered_map fwds; #endif - MKLDNNConvSignature key(param.conv_param); + MKLDNNConvSignature key(param); key.AddSign(is_train); // Here we can sign the conv op with NDArray because conv primitive will // decide the right layout for the, so we only need to get the shape and the @@ -272,7 +272,10 @@ MKLDNNConvForward &GetConvFwd(const MKLDNNConvFullParam ¶m, auto it = fwds.find(key); if (it == fwds.end()) { - MKLDNNConvForward fwd(param, is_train, data, weights, bias, output); + MKLDNNConvFullParam full_param; + full_param.conv_param = param; + full_param.mkldnn_param.Init(std::unordered_map()); + MKLDNNConvForward fwd(full_param, is_train, data, weights, bias, output); auto ins_ret = fwds.insert( std::pair(key, fwd)); CHECK(ins_ret.second); @@ -347,7 +350,7 @@ void MKLDNNConvolutionForward(const nnvm::NodeAttrs &attrs, param.conv_param = nnvm::get(attrs.parsed); param.mkldnn_param.Init(std::unordered_map()); auto &fwd = GetConvFwd( - param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight], + param.conv_param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight], param.conv_param.no_bias ? 
nullptr : &in_data[conv::kBias], out_data[conv::kOut]); MKLDNNConvolutionForwardFullFeature(param, ctx, &fwd, in_data, req, out_data); diff --git a/src/operator/quantization/dequantize.cc b/src/operator/quantization/dequantize.cc index e20bc1722213..bbd79417676b 100644 --- a/src/operator/quantization/dequantize.cc +++ b/src/operator/quantization/dequantize.cc @@ -72,7 +72,6 @@ by keep zero centered for the quantized value: .set_attr("FInferType", DequantizeType) .set_attr("FInferStorageType", DequantizeStorageType) #if MXNET_USE_MKLDNN == 1 -.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", MKLDNNDequantizeCompute) #endif .set_attr("FCompute", DequantizeCompute) diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc index e4eaf7227f3e..b8c47c3af11b 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc @@ -42,14 +42,11 @@ static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs, << "mkldnn_quantized_conv op only supports uint8 as input type"; TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]); NDArray weight = in_data[conv::kWeight]; - MKLDNNConvFullParam full_param; - full_param.conv_param = nnvm::get(attrs.parsed); - full_param.mkldnn_param.Init(std::unordered_map()); + ConvolutionParam param = nnvm::get(attrs.parsed); auto &fwd = GetConvFwd( - full_param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight], - full_param.conv_param.no_bias ? nullptr : &in_data[conv::kBias], + param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight], + param.no_bias ? nullptr : &in_data[conv::kBias], out_data[conv::kOut]); - const ConvolutionParam& param = full_param.conv_param; auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_primitive_desc()); const mkldnn::memory *weight_mem; // For inference, we want to reorder the weight array so we don't need to diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc index 07e14412618d..b81881af96dc 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc @@ -46,7 +46,6 @@ static void MKLDNNQuantizedPoolingForward(const nnvm::NodeAttrs& attrs, const Op } NNVM_REGISTER_OP(_contrib_quantized_pooling) -.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", MKLDNNQuantizedPoolingForward); } // namespace op diff --git a/src/operator/quantization/quantize.cc b/src/operator/quantization/quantize.cc index 5227751bc635..25fb19dddd12 100644 --- a/src/operator/quantization/quantize.cc +++ b/src/operator/quantization/quantize.cc @@ -83,7 +83,6 @@ where .set_attr("FInferType", QuantizeType) .set_attr("FInferStorageType", QuantizeStorageType) #if MXNET_USE_MKLDNN == 1 -.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", MKLDNNQuantizeCompute) #endif .set_attr("FCompute", QuantizeCompute) diff --git a/src/operator/quantization/requantize.cc b/src/operator/quantization/requantize.cc index 68b1b65e4e7b..5ce0ff0b0209 100644 --- a/src/operator/quantization/requantize.cc +++ b/src/operator/quantization/requantize.cc @@ -65,7 +65,6 @@ inference accuracy. 
.set_attr("FInferType", RequantizeType) .set_attr("FInferStorageType", RequantizeStorageType) #if MXNET_USE_MKLDNN == 1 -.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", MKLDNNRequantizeForward) #else .set_attr("FCompute", RequantizeForward) diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index fe9d3ba90ffb..2b34eb3f8cd1 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -165,10 +165,7 @@ static void ConvolutionFusionComputeExCPU(const MKLDNNConvFullParam &full_param, const std::vector &req, const std::vector &outputs) { if (SupportMKLDNNConv(full_param.conv_param, inputs[0])) { - // MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); MKLDNNConvolutionForwardFullFeature(full_param, ctx, fwd, inputs, req, outputs); - // MKLDNN_OPCHECK_RUN(ConvolutionCompute, attrs, ctx, inputs, req, - // outputs); return; } ConvFusionFallBackCompute(); @@ -200,8 +197,6 @@ class SgMKLDNNConvOperator { std::shared_ptr fwd; NDArray cached_weight_; NDArray cached_bias_; - NDArray cached_data_; - NDArray cached_output_; float cached_data_min; float cached_data_max; float cached_sum_min; @@ -252,13 +247,10 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, mkldnn_param.quantized ? outputs[kMax].data().dptr() : nullptr; CHECK_EQ(input_size, idx); bool has_bias = mkldnn_param.with_bn || !conv_param.no_bias; - cached_data_ = inputs[in_data]; - if (mkldnn_param.with_sum) - cached_output_ = inputs[in_sum]; - else - cached_output_ = outputs[kOut]; + NDArray data_ = inputs[in_data]; + NDArray output_ = mkldnn_param.with_sum ? inputs[in_sum] : outputs[kOut]; - // Check data change + // Check input change // TODO(zhennan): Only update cached_* changed. if (initalized) { if (mkldnn_param.with_bn) { @@ -281,9 +273,12 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, if (mkldnn_param.min_calib_range.has_value() && mkldnn_param.max_calib_range.has_value()) { post_requantize = true; - mkldnn_param.weight_channelwise_scale = false; + mkldnn_param.weight_channelwise_scale = true; *out_min_ptr = mkldnn_param.min_calib_range.value(); *out_max_ptr = mkldnn_param.max_calib_range.value(); + } else { + mkldnn_param.weight_channelwise_scale = false; + } } @@ -358,21 +353,21 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, full_conv_param.sum_scale = output_scale / sum_in_scale; } fwd.reset(new MKLDNNConvForward( - full_conv_param, ctx.is_train, cached_data_, cached_weight_, - has_bias ? &cached_bias_ : nullptr, cached_output_)); + full_conv_param, ctx.is_train, data_, cached_weight_, + has_bias ? 
&cached_bias_ : nullptr, output_)); } initalized = true; std::vector new_inputs; std::vector new_req; if (has_bias) { - new_inputs = {cached_data_, cached_weight_, cached_bias_}; + new_inputs = {data_, cached_weight_, cached_bias_}; new_req = {req[in_data], req[in_weight], req[in_bias]}; } else { - new_inputs = {cached_data_, cached_weight_}; + new_inputs = {data_, cached_weight_}; new_req = {req[in_data], req[in_weight]}; } ConvolutionFusionComputeExCPU(full_conv_param, ctx, fwd.get(), new_inputs, - new_req, {cached_output_}); + new_req, {output_}); if (mkldnn_param.with_sum) { auto out = const_cast(outputs[kOut]); @@ -659,7 +654,6 @@ NNVM_REGISTER_OP(_sg_mkldnn_conv) .set_attr("FInferShape", SgMKLDNNConvInferShape) .set_attr("FInferType", SgMKLDNNConvInferType) .set_attr("FInferStorageType", SgMKLDNNConvOpStorageType) -.set_attr("TIsMKLDNN", true) .set_attr("FStatefulComputeEx", SgMKLDNNConvOpForward) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; From 05e595a4fd673099922cd3cb6636993d6c30f86e Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 25 Sep 2018 20:29:53 +0800 Subject: [PATCH 15/28] Address Haibin's comments. --- .../quantization/imagenet_gen_qsym_mkldnn.py | 6 ++-- python/mxnet/contrib/quantization.py | 12 ------- src/common/utils.h | 8 +++++ .../quantization/quantize_graph_pass.cc | 6 ++-- src/operator/subgraph/mkldnn/mkldnn_conv.cc | 36 ++++++++----------- .../mkldnn_conv_post_quantize_property.cc | 18 +++++----- 6 files changed, 37 insertions(+), 49 deletions(-) diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py index 884891f5934d..0436c4defac2 100644 --- a/example/quantization/imagenet_gen_qsym_mkldnn.py +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -114,9 +114,9 @@ def save_params(fname, arg_params, aux_params, logger=None): if calib_mode != 'none': download_calib_dataset('http://data.mxnet.io/data/val_256_q90.rec', args.calib_dataset) - # download model - prefix, epoch = download_model(model_name=args.model, logger=logger) - sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) + # download model + prefix, epoch = download_model(model_name=args.model, logger=logger) + sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) out = SymbolHandle() backend = "MKLDNN" diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 052554bc4d6a..3b04016351ad 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -420,18 +420,6 @@ def _load_params(params, logger=logging): raise ValueError('Unsupported params provided. 
Must be either a path to the param file or' ' a pair of dictionaries representing arg_params and aux_params') -def save_params(fname, arg_params, aux_params, logger=None): - if logger is not None: - logger.info('Saving params into file at %s' % fname) - save_dict = {('arg:%s' % k): v.as_in_context(cpu()) for k, v in arg_params.items()} - save_dict.update({('aux:%s' % k): v.as_in_context(cpu()) for k, v in aux_params.items()}) - ndarray.save(fname, save_dict) - -def save_symbol(fname, sym, logger=None): - if logger is not None: - logger.info('Saving symbol into file at %s' % fname) - sym.save(fname) - def quantize_model(sym, arg_params, aux_params, data_names=('data',), label_names=('softmax_label',), ctx=cpu(), excluded_sym_names=None, calib_mode='entropy', diff --git a/src/common/utils.h b/src/common/utils.h index 26889792e53d..84e2cbbdc3a5 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -713,6 +713,14 @@ inline void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape } } +/*! + * \brief Return true if str ends with suffix. + */ +inline bool StringEndsWith(std::string const& str, std::string const& suffix) { + if (suffix.size() > str.size()) return false; + return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); +} + } // namespace common } // namespace mxnet #endif // MXNET_COMMON_UTILS_H_ diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index 283108cc57f4..dc549238188e 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -89,10 +89,8 @@ std::vector OfflineParams(std::vector&& outputs, return outputs; } -inline bool NeedQuantize(NodePtr node, - const std::unordered_set excluded_nodes) { - static auto& quantized_op_map = - Op::GetAttr("FQuantizedOp"); +inline bool NeedQuantize(NodePtr node, const std::unordered_set& excluded_nodes) { + static auto& quantized_op_map = Op::GetAttr("FQuantizedOp"); if (quantized_op_map.count(node->op())) { bool excluded = false; if (node->attrs.subgraphs.size()) { diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index 2b34eb3f8cd1..bfd604d329d6 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -26,6 +26,7 @@ #include "../../nn/mkldnn/mkldnn_base-inl.h" #include "../../nn/mkldnn/mkldnn_ops-inl.h" #include "../../quantization/quantization_utils.h" +#include "../../../common/utils.h" #include "mkldnn_conv-inl.h" namespace mxnet { @@ -41,20 +42,20 @@ static void UpdateConvWeightBias(NDArray *weight, NDArray *bias, bool no_bias, weight->ctx(), true, weight->dtype()); NDArray update_bias = NDArray(beta.storage_type(), beta.shape(), beta.ctx(), true, beta.dtype()); - DType *weight_ptr = weight->data().dptr(); - DType *bias_ptr = no_bias ? nullptr : bias->data().dptr(); - DType *gamma_ptr = gamma.Reorder2Default().data().dptr(); - DType *beta_ptr = beta.Reorder2Default().data().dptr(); - DType *mean_ptr = mean.Reorder2Default().data().dptr(); - DType *var_ptr = variance.Reorder2Default().data().dptr(); + const DType *weight_ptr = weight->data().dptr(); + const DType *bias_ptr = no_bias ? 
nullptr : bias->data().dptr(); + const DType *gamma_ptr = gamma.Reorder2Default().data().dptr(); + const DType *beta_ptr = beta.Reorder2Default().data().dptr(); + const DType *mean_ptr = mean.Reorder2Default().data().dptr(); + const DType *var_ptr = variance.Reorder2Default().data().dptr(); DType *update_weight_ptr = update_weight.data().dptr(); DType *update_bias_ptr = update_bias.data().dptr(); size_t channel = gamma.shape()[0]; size_t offset = weight->shape()[1] * weight->shape()[2] * weight->shape()[3]; #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (int c = 0; c < static_cast(channel); ++c) { - DType *p1 = reinterpret_cast(weight_ptr + c * offset); - DType *p2 = reinterpret_cast(update_weight_ptr + c * offset); + const DType *p1 = weight_ptr + c * offset; + DType *p2 = update_weight_ptr + c * offset; DType alpha = (param->fix_gamma ? static_cast(1.0f) : gamma_ptr[c]) / sqrt(var_ptr[c] + param->eps); @@ -84,7 +85,7 @@ static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias, std::vector *weight_scales) { using red::limits::MaxValue; using red::limits::MinValue; - DType *weight_ptr = weight->data().dptr(); + const DType *weight_ptr = weight->data().dptr(); NDArray quantized_weight = NDArray(weight->storage_type(), weight->shape(), weight->ctx(), true, mshadow::kInt8); int8_t *quan_weight_ptr = quantized_weight.data().dptr(); @@ -96,7 +97,7 @@ static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias, std::vector weight_c_max(channel, MinValue()); #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (int c = 0; c < static_cast(channel); ++c) { - DType *p1 = weight_ptr + c * offset; + const DType *p1 = weight_ptr + c * offset; for (size_t k = 0; k < offset; ++k) { if (weight_c_min[c] > p1[k]) weight_c_min[c] = p1[k]; @@ -111,7 +112,7 @@ static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias, for (int c = 0; c < static_cast(channel); ++c) { DType weight_range = MaxAbs(weight_c_min[c], weight_c_max[c]); weight_scales->at(c) = int8_range / weight_range; - DType *fp_ptr = weight_ptr + c * offset; + const DType *fp_ptr = weight_ptr + c * offset; int8_t *quan_ptr = quan_weight_ptr + c * offset; for (size_t k = 0; k < offset; ++k) { quan_ptr[k] = std::round(weight_scales->at(c) * fp_ptr[k]); @@ -129,7 +130,7 @@ static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias, weight_scales->at(0) = int8_range / weight_range; #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (int c = 0; c < static_cast(channel); ++c) { - DType *fp_ptr = weight_ptr + c * offset; + const DType *fp_ptr = weight_ptr + c * offset; int8_t *quan_ptr = quan_weight_ptr + c * offset; for (size_t k = 0; k < offset; ++k) { quan_ptr[k] = std::round(weight_scales->at(0) * fp_ptr[k]); @@ -139,7 +140,7 @@ static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias, *weight = quantized_weight; if (has_bias) { - DType *bias_ptr = bias->data().dptr(); + const DType *bias_ptr = bias->data().dptr(); NDArray quantized_bias = NDArray(bias->storage_type(), bias->shape(), bias->ctx(), true, mshadow::kInt32); int32_t *quan_bias_ptr = quantized_bias.data().dptr(); @@ -278,7 +279,6 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, *out_max_ptr = mkldnn_param.max_calib_range.value(); } else { mkldnn_param.weight_channelwise_scale = false; - } } @@ -622,18 +622,12 @@ nnvm::NodePtr SgMKLDNNConvQuantizedOp(const NodeAttrs& attrs) { return node; } -static inline 
bool StringEndsWith(std::string const &str, - std::string const &suffix) { - if (suffix.size() > str.size()) return false; - return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); -} - bool SgMKLDNNAvoidQuantizeInput(const NodeAttrs &attrs, const NodeAttrs &input_attrs) { const std::vector exclude_key{ "weight", "bias", "gamma", "beta", "moving_mean", "moving_var"}; for (auto i : exclude_key) { - if (StringEndsWith(input_attrs.name, i)) { + if (common::StringEndsWith(input_attrs.name, i)) { return true; } } diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc b/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc index e4effbbcb464..a2c14593e3bd 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc @@ -32,9 +32,9 @@ class SgMKLDNNConvPostQuantizeSelector : public SubgraphSelector { public: /*! \brief pattern match status */ enum SelectStatus { - sFail = 0, - sStart, - sSuccess, + kFail = 0, + kStart, + kSuccess, }; private: @@ -50,7 +50,7 @@ class SgMKLDNNConvPostQuantizeSelector : public SubgraphSelector { if ((!disable_all) && n.op() && n.op()->name == "_sg_mkldnn_conv") { auto const ¶m = nnvm::get(n.attrs.parsed); if (param.full_conv_param.mkldnn_param.quantized) { - status = sStart; + status = kStart; matched_list.clear(); matched_list.push_back(&n); return true; @@ -64,12 +64,12 @@ class SgMKLDNNConvPostQuantizeSelector : public SubgraphSelector { } bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override { - if (status == sFail || status == sSuccess || new_node.is_variable()) + if (status == kFail || status == kSuccess || new_node.is_variable()) return false; // If n isn't the last matched node, then we encoutered a internal // branch, we should pop out the node behind n and stop fusion. if (matched_list.back() != &n) { - status = sFail; + status = kFail; return false; } if (new_node.op()->name == "_contrib_requantize") { @@ -77,10 +77,10 @@ class SgMKLDNNConvPostQuantizeSelector : public SubgraphSelector { if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { matched_list.push_back(&new_node); - status = sSuccess; + status = kSuccess; return true; } else { - status = sFail; + status = kFail; } } return false; @@ -88,7 +88,7 @@ class SgMKLDNNConvPostQuantizeSelector : public SubgraphSelector { std::vector Filter( const std::vector &candidates) override { - if (status != sSuccess) { + if (status != kSuccess) { return std::vector(0); } else { return candidates; From 3f24d9792f2f3ada12d1173f387680c9f8fd634a Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 25 Sep 2018 20:44:34 +0800 Subject: [PATCH 16/28] Add TIsMKLDNN for _sg_mkldnn_conv temporarily. 
--- src/operator/subgraph/mkldnn/mkldnn_conv.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index bfd604d329d6..6a6e9f4fd50f 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -649,6 +649,7 @@ NNVM_REGISTER_OP(_sg_mkldnn_conv) .set_attr("FInferType", SgMKLDNNConvInferType) .set_attr("FInferStorageType", SgMKLDNNConvOpStorageType) .set_attr("FStatefulComputeEx", SgMKLDNNConvOpForward) +.set_attr("TIsMKLDNN", true) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) From 146e157d31a5aeb4d0fac1c8c32e743d137e5e39 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Wed, 26 Sep 2018 10:44:43 +0800 Subject: [PATCH 17/28] Address reminisce's comments. --- include/mxnet/c_api.h | 20 ++--- .../quantization/quantize_graph_pass.cc | 1 + src/operator/subgraph/mkldnn/mkldnn_conv.cc | 86 +++++++++---------- .../mkldnn_conv_post_quantize_property.cc | 2 +- .../subgraph/mkldnn/mkldnn_conv_property.cc | 4 +- src/operator/subgraph/partition_graph.cc | 4 +- src/operator/subgraph/subgraph_property.h | 4 +- 7 files changed, 60 insertions(+), 61 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 7d62103faadb..ecb77c6a9b1e 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1502,16 +1502,16 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, int *complete); /*! -* \brief Convert a symbol into a quantized symbol where FP32 operators are replaced with INT8 -* \param sym_handle symbol to be converted -* \param ret_sym_handle quantized symbol result -* \param num_excluded_symbols number of layers excluded from being quantized in the input symbol -* \param excluded_symbols op names to be excluded from being quantized -* \param num_offline number of parameters that are quantized offline -* \param offline_params array of c strings representing the names of params quantized offline -* \param quantized_dtype the quantized destination type for input data. -* \param calib_quantize whether calibrate quantize op with offline calibration data. -*/ + * \brief Convert a symbol into a quantized symbol where FP32 operators are replaced with INT8 + * \param sym_handle symbol to be converted + * \param ret_sym_handle quantized symbol result + * \param num_excluded_symbols number of layers excluded from being quantized in the input symbol + * \param excluded_symbols op names to be excluded from being quantized + * \param num_offline number of parameters that are quantized offline + * \param offline_params array of c strings representing the names of params quantized offline + * \param quantized_dtype the quantized destination type for input data. + * \param calib_quantize whether calibrate quantize op with offline calibration data. + */ MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle, const mx_uint num_excluded_symbols, const char **excluded_symbols, diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index dc549238188e..da84bb388b02 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -99,6 +99,7 @@ inline bool NeedQuantize(NodePtr node, const std::unordered_set& ex if (excluded_nodes.count(node->attrs.name)) { excluded = true; } else { + // Assume index 0 holds subgraph symbol. 
auto subgraph_sym = node->attrs.subgraphs[0]; DFSVisit(subgraph_sym->outputs, [&](const nnvm::NodePtr& node) { if (node->is_variable()) return; diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index 6a6e9f4fd50f..93434f872790 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -175,9 +175,9 @@ static void ConvolutionFusionComputeExCPU(const MKLDNNConvFullParam &full_param, class SgMKLDNNConvOperator { public: explicit SgMKLDNNConvOperator(const nnvm::NodeAttrs &attrs) - : initalized(false), + : initalized_(false), subgraph_sym_(*attrs.subgraphs[0]), - param(nnvm::get(attrs.parsed)) {} + param_(nnvm::get(attrs.parsed)) {} void Forward(const OpContext &ctx, const std::vector &inputs, @@ -192,35 +192,33 @@ class SgMKLDNNConvOperator { } private: - bool initalized; + bool initalized_; nnvm::Symbol subgraph_sym_; - MKLDNNConvFusionParam param; - std::shared_ptr fwd; + MKLDNNConvFusionParam param_; + std::shared_ptr fwd_; NDArray cached_weight_; NDArray cached_bias_; - float cached_data_min; - float cached_data_max; - float cached_sum_min; - float cached_sum_max; - size_t weight_ver; - size_t bias_ver; - std::vector weight_scales; + float cached_data_min_; + float cached_data_max_; + float cached_sum_min_; + float cached_sum_max_; + size_t weight_ver_; + size_t bias_ver_; + std::vector weight_scales_; }; void SgMKLDNNConvOperator::Forward(const OpContext &ctx, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { - auto &full_conv_param = param.full_conv_param; - auto &mkldnn_param = param.full_conv_param.mkldnn_param; - auto &conv_param = param.full_conv_param.conv_param; - auto bn_param = param.bn_param.get(); + auto &full_conv_param = param_.full_conv_param; + auto &mkldnn_param = param_.full_conv_param.mkldnn_param; + auto &conv_param = param_.full_conv_param.conv_param; + auto bn_param = param_.bn_param.get(); size_t input_size = 2 + (conv_param.no_bias ? 0 : 1) + (mkldnn_param.with_bn ? 4 : 0) + (mkldnn_param.with_sum ? 1 : 0) + - (mkldnn_param.quantized - ? 2 + (param.full_conv_param.mkldnn_param.with_sum ? 2 : 0) - : 0); + (mkldnn_param.quantized ? 2 + (full_conv_param.mkldnn_param.with_sum ? 2 : 0) : 0); CHECK_EQ(inputs.size(), input_size); size_t idx = 0; @@ -253,19 +251,19 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, // Check input change // TODO(zhennan): Only update cached_* changed. 
- if (initalized) { + if (initalized_) { if (mkldnn_param.with_bn) { - if (weight_ver != inputs[in_weight].version() || - ((!conv_param.no_bias) && bias_ver != inputs[in_bias].version())) { - initalized = false; + if (weight_ver_ != inputs[in_weight].version() || + ((!conv_param.no_bias) && bias_ver_ != inputs[in_bias].version())) { + initalized_ = false; } } - if (initalized && mkldnn_param.quantized) { - if (cached_data_min != data_min || cached_data_max != data_max || - cached_sum_min != sum_min || cached_sum_max != sum_max || - weight_ver != inputs[in_weight].version() || - ((!conv_param.no_bias) && bias_ver != inputs[in_bias].version())) { - initalized = false; + if (initalized_ && mkldnn_param.quantized) { + if (cached_data_min_ != data_min || cached_data_max_ != data_max || + cached_sum_min_ != sum_min || cached_sum_max_ != sum_max || + weight_ver_ != inputs[in_weight].version() || + ((!conv_param.no_bias) && bias_ver_ != inputs[in_bias].version())) { + initalized_ = false; } } } @@ -282,17 +280,17 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, } } - if (!initalized) { - cached_data_min = data_min; - cached_data_max = data_max; - cached_sum_min = sum_min; - cached_sum_max = sum_max; + if (!initalized_) { + cached_data_min_ = data_min; + cached_data_max_ = data_max; + cached_sum_min_ = sum_min; + cached_sum_max_ = sum_max; full_conv_param.sum_scale = 1.0; cached_weight_ = inputs[in_weight].Reorder2Default(); - weight_ver = inputs[in_weight].version(); + weight_ver_ = inputs[in_weight].version(); if (!conv_param.no_bias) { cached_bias_ = inputs[in_bias].Reorder2Default(); - bias_ver = inputs[in_bias].version(); + bias_ver_ = inputs[in_bias].version(); } else { cached_bias_ = NDArray(); } @@ -315,7 +313,7 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, QuantizeConvWeightBias(&cached_weight_, &cached_bias_, has_bias, data_min, data_max, mkldnn_param.weight_channelwise_scale, - &weight_scales); + &weight_scales_); }); // Collect scale. size_t channel = cached_weight_.shape()[0]; @@ -340,23 +338,23 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, full_conv_param.requantize_scales.resize(channel); for (size_t c = 0; c < channel; c++) { auto weight_scale = mkldnn_param.weight_channelwise_scale - ? weight_scales[c] - : weight_scales[0]; + ? weight_scales_[c] + : weight_scales_[0]; full_conv_param.requantize_scales[c] = output_scale / data_scale / weight_scale; } } else { - output_scale = data_scale * weight_scales[0]; + output_scale = data_scale * weight_scales_[0]; full_conv_param.requantize_scales.resize(0); } if (mkldnn_param.with_sum) full_conv_param.sum_scale = output_scale / sum_in_scale; } - fwd.reset(new MKLDNNConvForward( + fwd_.reset(new MKLDNNConvForward( full_conv_param, ctx.is_train, data_, cached_weight_, has_bias ? 
&cached_bias_ : nullptr, output_)); } - initalized = true; + initalized_ = true; std::vector new_inputs; std::vector new_req; if (has_bias) { @@ -366,7 +364,7 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, new_inputs = {data_, cached_weight_}; new_req = {req[in_data], req[in_weight]}; } - ConvolutionFusionComputeExCPU(full_conv_param, ctx, fwd.get(), new_inputs, + ConvolutionFusionComputeExCPU(full_conv_param, ctx, fwd_.get(), new_inputs, new_req, {output_}); if (mkldnn_param.with_sum) { @@ -625,7 +623,7 @@ nnvm::NodePtr SgMKLDNNConvQuantizedOp(const NodeAttrs& attrs) { bool SgMKLDNNAvoidQuantizeInput(const NodeAttrs &attrs, const NodeAttrs &input_attrs) { const std::vector exclude_key{ - "weight", "bias", "gamma", "beta", "moving_mean", "moving_var"}; + "weight", "bias", "gamma", "beta", "moving_mean", "moving_var", "running_mean"}; for (auto i : exclude_key) { if (common::StringEndsWith(input_attrs.name, i)) { return true; diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc b/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc index a2c14593e3bd..fc68287b039d 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc @@ -142,7 +142,7 @@ class SgMKLDNNConvPostQuantizeProperty : public SubgraphProperty { return selector; } - void ConnectSubgraphOutput( + void ConnectSubgraphOutputs( const nnvm::NodePtr n, std::vector *output_entries) const override { for (size_t i = 0; i < output_entries->size(); ++i) { diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc index f39b012299cc..eaf67d1ab21a 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc @@ -188,7 +188,7 @@ class SgMKLDNNConvProperty : public SubgraphProperty { return selector; } - void ConnectSubgraphOutput( + void ConnectSubgraphOutputs( const nnvm::NodePtr n, std::vector *output_entries) const override { // Connect all extern output entries to output[0] @@ -197,7 +197,7 @@ class SgMKLDNNConvProperty : public SubgraphProperty { } } - void ConnectSubgraphInput( + void ConnectSubgraphInputs( const nnvm::NodePtr n, std::vector *input_entries, std::vector *orig_input_entries) const override { auto sym = n->attrs.subgraphs[0]; diff --git a/src/operator/subgraph/partition_graph.cc b/src/operator/subgraph/partition_graph.cc index c789d25a1525..57fb47f82933 100644 --- a/src/operator/subgraph/partition_graph.cc +++ b/src/operator/subgraph/partition_graph.cc @@ -653,8 +653,8 @@ void CreateSubgraphNode(Graph* g, nnvm::NodePtr n = subg_prop->CreateSubgraphNode(sym, subgraph_id); // Connect the external nodes to the subgraph node. - subg_prop->ConnectSubgraphOutput(n, &output_entries); - subg_prop->ConnectSubgraphInput(n, &input_entries, &orig_input_entries); + subg_prop->ConnectSubgraphOutputs(n, &output_entries); + subg_prop->ConnectSubgraphInputs(n, &input_entries, &orig_input_entries); n->inputs = orig_input_entries; const auto& indexed_graph = g->indexed_graph(); diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h index c5a04afb63a6..85e9adf4267b 100644 --- a/src/operator/subgraph/subgraph_property.h +++ b/src/operator/subgraph/subgraph_property.h @@ -94,7 +94,7 @@ class SubgraphProperty { const int subgraph_id = 0) const = 0; // Connect subgraph internal output with external output entries. 
By default, // each output entry will connect to an unique internal output. - virtual void ConnectSubgraphOutput( + virtual void ConnectSubgraphOutputs( const nnvm::NodePtr n, std::vector *output_entries) const { for (size_t i = 0; i < output_entries->size(); ++i) { @@ -103,7 +103,7 @@ class SubgraphProperty { } // Connect subgraph internal input with external input entries. By default, // each input entry will connect in top sorted order. - virtual void ConnectSubgraphInput( + virtual void ConnectSubgraphInputs( const nnvm::NodePtr n, std::vector *input_entries, std::vector *orig_input_entries) const { n->inputs = *orig_input_entries; From 966839dd2a4c80e933bb0c387202f3e708296913 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Thu, 27 Sep 2018 13:18:35 +0800 Subject: [PATCH 18/28] Handle the case that inplace fail. --- .../quantization/imagenet_gen_qsym_mkldnn.py | 8 +-- include/mxnet/ndarray.h | 6 +++ src/ndarray/ndarray.cc | 12 +++++ src/operator/nn/mkldnn/mkldnn_base-inl.h | 17 ++++++ src/operator/nn/mkldnn/mkldnn_convolution.cc | 17 +----- src/operator/subgraph/mkldnn/mkldnn_conv.cc | 54 +++++++++++++------ 6 files changed, 77 insertions(+), 37 deletions(-) diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py index 0436c4defac2..35f815d29a4c 100644 --- a/example/quantization/imagenet_gen_qsym_mkldnn.py +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -204,10 +204,10 @@ def save_params(fname, arg_params, aux_params, logger=None): raise ValueError('unknow calibration mode %s received, only supports `none`, `naive`, and `entropy`' % calib_mode) sym_name = '%s-symbol.json' % (prefix + suffix) - # out = SymbolHandle() - # backend = "MKLDNN_POST_QUANTIZE" - # check_call(_LIB.MXGenBackendSubgraph(qsym.handle, c_str(backend), ctypes.byref(out))) - # qsym = Symbol(out) + out = SymbolHandle() + backend = "MKLDNN_POST_QUANTIZE" + check_call(_LIB.MXGenBackendSubgraph(qsym.handle, c_str(backend), ctypes.byref(out))) + qsym = Symbol(out) save_symbol(sym_name, qsym, logger) param_name = '%s-%04d.params' % (prefix + '-quantized', epoch) save_params(param_name, qarg_params, aux_params, logger) diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 47706e8a7947..fda8beab1f4d 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -627,6 +627,12 @@ class NDArray { } #if MXNET_USE_MKLDNN == 1 + /* + * Create NDArray from mkldnn memory. + * mkldnn_mem The mkldnn memory to be managed. + * static_data If true, mkldnn memory won't be freed on destruction. + */ + explicit NDArray(const mkldnn::memory *mkldnn_mem, bool static_data = true); /* * Test if the data is stored in one of special MKLDNN format. 
*/ diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 90516c1f562e..37198833ddea 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -168,6 +168,18 @@ nnvm::Symbol NDArray::get_autograd_symbol() const { #if MXNET_USE_MKLDNN == 1 +NDArray::NDArray(const mkldnn::memory *mkldnn_mem, bool static_data) + : storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) { + auto mem_pd = mkldnn_mem->get_primitive_desc(); + auto mem_desc = mem_pd.desc(); + shape_ = TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims); + dtype_ = get_mxnet_type(mem_desc.data.data_type); + auto data = TBlob(mkldnn_mem->get_data_handle(), shape_, cpu::kDevMask, dtype_); + ptr_ = std::make_shared(data, 0); + ptr_->mkl_mem_ = std::make_shared(mem_pd, ptr_->shandle.dptr); + ptr_->static_data = static_data; +} + NDArray NDArray::MKLDNNDataReshape(const TShape &shape) const { CHECK(!is_none()) << "NDArray is not initialized"; CHECK_GE(shape_.Size(), shape.Size()) diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index f90291479b0f..8a2f4a3e5011 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -191,6 +191,23 @@ static inline mkldnn::memory::data_type get_mkldnn_type(int dtype) { } } +static inline int get_mxnet_type(mkldnn_data_type_t dtype) { + auto mkldnn_dtype = static_cast(dtype); + switch (mkldnn_dtype) { + case mkldnn::memory::data_type::f32: + return mshadow::kFloat32; + case mkldnn::memory::data_type::s32: + return mshadow::kInt32; + case mkldnn::memory::data_type::s8: + return mshadow::kInt8; + case mkldnn::memory::data_type::u8: + return mshadow::kUint8; + default: + LOG(FATAL) << "unknown MKLDNN type"; + return mshadow::kFloat32; + } +} + inline static mkldnn::memory::desc GetMemDesc(const NDArray &arr, int ndim) { mkldnn::memory::dims dims(ndim); for (size_t i = 0; i < dims.size(); i++) dims[i] = arr.shape()[i]; diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index 639d52b9ad0e..a5a61c3cc6dc 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -42,27 +42,12 @@ bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) { return SupportMKLDNNQuantize(input.dtype()) && input.shape().ndim() == 4; } -inline static mkldnn::memory::desc GetInDataMemDesc(const NDArray &arr) { - mkldnn::memory::dims dims(arr.shape().ndim()); - for (size_t i = 0; i < dims.size(); i++) dims[i] = arr.shape()[i]; - int mkldnn_dtype; - // For INT8 case, currently we only support uint8 as input data so need - // to create the memory primitive of uint8 type - if (arr.dtype() == mshadow::kInt8) { - mkldnn_dtype = mshadow::kUint8; - } else { - mkldnn_dtype = arr.dtype(); - } - return mkldnn::memory::desc{dims, get_mkldnn_type(mkldnn_dtype), - mkldnn::memory::format::any}; -} - mkldnn::convolution_forward::primitive_desc GetConvFwdImpl( const MKLDNNConvFullParam ¶m, const bool is_train, const NDArray &data, const NDArray &weights, const NDArray *bias, const NDArray &output) { auto prop = is_train ? 
mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring; - auto data_md = GetInDataMemDesc(data); + auto data_md = GetMemDesc(data); auto weight_md = GetWeightDesc(weights, param.conv_param.num_group); auto out_md = GetMemDesc(output); auto engine = CpuEngine::Get()->get_engine(); diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index 93434f872790..114d35c7b8d1 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -79,8 +79,7 @@ static inline size_t GetInSumIndex(const MKLDNNConvFusionParam ¶m) { template static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias, - bool has_bias, float data_min, - float data_max, + bool has_bias, float data_scale, bool weight_channelwise_scale, std::vector *weight_scales) { using red::limits::MaxValue; @@ -144,7 +143,6 @@ static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias, NDArray quantized_bias = NDArray(bias->storage_type(), bias->shape(), bias->ctx(), true, mshadow::kInt32); int32_t *quan_bias_ptr = quantized_bias.data().dptr(); - DType data_scale = uint8_range / MaxAbs(data_min, data_max); for (size_t c = 0; c < channel; ++c) { auto weight_scale = weight_channelwise_scale ? weight_scales->at(c) : weight_scales->at(0); @@ -177,7 +175,8 @@ class SgMKLDNNConvOperator { explicit SgMKLDNNConvOperator(const nnvm::NodeAttrs &attrs) : initalized_(false), subgraph_sym_(*attrs.subgraphs[0]), - param_(nnvm::get(attrs.parsed)) {} + param_(nnvm::get(attrs.parsed)), + inplace_(false) {} void Forward(const OpContext &ctx, const std::vector &inputs, @@ -205,6 +204,7 @@ class SgMKLDNNConvOperator { size_t weight_ver_; size_t bias_ver_; std::vector weight_scales_; + bool inplace_; }; void SgMKLDNNConvOperator::Forward(const OpContext &ctx, @@ -246,8 +246,26 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, mkldnn_param.quantized ? outputs[kMax].data().dptr() : nullptr; CHECK_EQ(input_size, idx); bool has_bias = mkldnn_param.with_bn || !conv_param.no_bias; - NDArray data_ = inputs[in_data]; - NDArray output_ = mkldnn_param.with_sum ? inputs[in_sum] : outputs[kOut]; + NDArray data = inputs[in_data]; + NDArray output = mkldnn_param.with_sum ? inputs[in_sum] : outputs[kOut]; + + // Copy inputs[in_sum] into outputs[kOut] in case inplace optimization failed. + if (mkldnn_param.with_sum) { + if (!initalized_) { + auto in_mkl_mem = inputs[in_sum].GetMKLDNNData(); + auto out_mkl_mem = outputs[kOut].GetMKLDNNData(); + // TODO(zhennan): Currently, mkldnn fallback mechanism will break inplace option, + // which make check (req[kOut] == ) + if (in_mkl_mem->get_data_handle() == out_mkl_mem->get_data_handle()) { + inplace_ = true; + } + } + if (!inplace_) { + auto in_mkl_mem = inputs[in_sum].GetMKLDNNData(); + const_cast(outputs[kOut]).CopyFrom(*in_mkl_mem); + output = NDArray(outputs[kOut].GetMKLDNNData()); + } + } // Check input change // TODO(zhennan): Only update cached_* changed. @@ -309,26 +327,28 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, } // Quantize weight and bias. if (mkldnn_param.quantized) { + CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8); + auto data_range = (data.dtype() == mshadow::kInt8) ? 
int8_range : uint8_range; + float data_scale = data_range / MaxAbs(cached_data_min_, cached_data_max_); MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, { QuantizeConvWeightBias(&cached_weight_, &cached_bias_, - has_bias, data_min, data_max, + has_bias, data_scale, mkldnn_param.weight_channelwise_scale, &weight_scales_); }); // Collect scale. size_t channel = cached_weight_.shape()[0]; - float data_scale = uint8_range / MaxAbs(data_min, data_max); float sum_in_scale = 1.0; float out_range; float quantized_out_range; float output_scale; - if (data_min < 0.0) { - // TODO(zhennan): we need to use offset to convert int8 to uint8. + if (cached_data_min_ < 0.0) { + // TODO(zhennan): Support int8 input when mkldnn supports. LOG(FATAL) << "Can't handle negetive value for QuantizeData"; } if (mkldnn_param.with_sum) { - auto quantized_sum_range = sum_min < 0 ? int8_range : uint8_range; - sum_in_scale = quantized_sum_range / MaxAbs(sum_min, sum_max); + auto quantized_sum_range = cached_sum_min_ < 0 ? int8_range : uint8_range; + sum_in_scale = quantized_sum_range / MaxAbs(cached_sum_min_, cached_sum_max_); } if (post_requantize) { quantized_out_range = @@ -351,21 +371,21 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, full_conv_param.sum_scale = output_scale / sum_in_scale; } fwd_.reset(new MKLDNNConvForward( - full_conv_param, ctx.is_train, data_, cached_weight_, - has_bias ? &cached_bias_ : nullptr, output_)); + full_conv_param, ctx.is_train, data, cached_weight_, + has_bias ? &cached_bias_ : nullptr, output)); } initalized_ = true; std::vector new_inputs; std::vector new_req; if (has_bias) { - new_inputs = {data_, cached_weight_, cached_bias_}; + new_inputs = {data, cached_weight_, cached_bias_}; new_req = {req[in_data], req[in_weight], req[in_bias]}; } else { - new_inputs = {data_, cached_weight_}; + new_inputs = {data, cached_weight_}; new_req = {req[in_data], req[in_weight]}; } ConvolutionFusionComputeExCPU(full_conv_param, ctx, fwd_.get(), new_inputs, - new_req, {output_}); + new_req, {output}); if (mkldnn_param.with_sum) { auto out = const_cast(outputs[kOut]); From 9ccc4bd701db022603f7cfec69f32c86b3a6477d Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Fri, 28 Sep 2018 11:56:41 +0800 Subject: [PATCH 19/28] pass unit test. --- .../nn/mkldnn/mkldnn_convolution-inl.h | 2 -- src/operator/nn/mkldnn/mkldnn_convolution.cc | 3 +- src/operator/subgraph/mkldnn/mkldnn_conv.cc | 4 +-- .../subgraph/mkldnn/mkldnn_conv_property.cc | 36 +++++++++---------- tests/python/mkl/test_subgraph.py | 6 +++- 5 files changed, 26 insertions(+), 25 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h index 4dff9b8d46a2..971c66ad9dd2 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h @@ -37,8 +37,6 @@ namespace mxnet { namespace op { struct MKLDNNConvParam : public dmlc::Parameter { - // When adding more members into this class, please double check GetHash() - // won't overflow. 
bool with_bn; bool with_relu; bool with_sum; diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index a5a61c3cc6dc..6a70ae40ac8f 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -308,8 +308,7 @@ void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam ¶m, if (param.mkldnn_param.with_sum) { out_mem = mkldnn_output_t( OutDataOp::Noop, - const_cast(out_data[conv::kOut].GetMKLDNNDataReorder( - fwd->fwd_pd.dst_primitive_desc()))); + const_cast(out_data[conv::kOut].GetMKLDNNData())); } else { out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd->fwd_pd.dst_primitive_desc(), req[conv::kOut]); diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index 114d35c7b8d1..35062dc0bfab 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -255,7 +255,7 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, auto in_mkl_mem = inputs[in_sum].GetMKLDNNData(); auto out_mkl_mem = outputs[kOut].GetMKLDNNData(); // TODO(zhennan): Currently, mkldnn fallback mechanism will break inplace option, - // which make check (req[kOut] == ) + // which make check (req[kOut] == kWriteInplace) useless. if (in_mkl_mem->get_data_handle() == out_mkl_mem->get_data_handle()) { inplace_ = true; } @@ -643,7 +643,7 @@ nnvm::NodePtr SgMKLDNNConvQuantizedOp(const NodeAttrs& attrs) { bool SgMKLDNNAvoidQuantizeInput(const NodeAttrs &attrs, const NodeAttrs &input_attrs) { const std::vector exclude_key{ - "weight", "bias", "gamma", "beta", "moving_mean", "moving_var", "running_mean"}; + "weight", "bias", "gamma", "beta", "moving_mean", "moving_var"}; for (auto i : exclude_key) { if (common::StringEndsWith(input_attrs.name, i)) { return true; diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc index eaf67d1ab21a..e5220f24d34d 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc @@ -29,11 +29,11 @@ class SgMKLDNNConvSelector : public SubgraphSelector { public: /*! \brief pattern match status */ enum SelectStatus { - sFail = 0, - sStart, - sBN, - sSum, - sSuccess, + kFail = 0, + kStart, + kBN, + kSum, + kSuccess, }; private: @@ -53,7 +53,7 @@ class SgMKLDNNConvSelector : public SubgraphSelector { bool Select(const nnvm::Node &n) override { if (n.op() && n.op()->name == "Convolution") { - status = disable_all ? sSuccess : sStart; + status = disable_all ? kSuccess : kStart; matched_list.clear(); matched_list.push_back(&n); return true; @@ -66,7 +66,7 @@ class SgMKLDNNConvSelector : public SubgraphSelector { } bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override { - if (status == sFail || status == sSuccess || new_node.is_variable()) + if (status == kFail || status == kSuccess || new_node.is_variable()) return false; // If n isn't the last matched node, then we encoutered a internal // branch, we should pop out the node behind n and stop fusion. @@ -74,25 +74,25 @@ class SgMKLDNNConvSelector : public SubgraphSelector { while (matched_list.back() != &n) { matched_list.pop_back(); } - status = sSuccess; + status = kSuccess; return false; } // Use status machine to do selection. 
The status change is - // sStart -> sBN -> sSum -> sSuccess + // kStart -> kBN -> kSum -> kSuccess switch (status) { - case sStart: + case kStart: if ((!disable_conv_bn) && new_node.op()->name == "BatchNorm") { matched_list.push_back(&new_node); - status = sBN; + status = kBN; return true; } - case sBN: + case kBN: if ((!disable_conv_sum) && new_node.op()->name == "elemwise_add") { matched_list.push_back(&new_node); - status = sSum; + status = kSum; return true; } - case sSum: + case kSum: default: if ((!disable_conv_relu) && new_node.op()->name == "Activation") { const ActivationParam ¶m = @@ -100,21 +100,21 @@ class SgMKLDNNConvSelector : public SubgraphSelector { if (param.act_type == activation::kReLU) { matched_list.push_back(&new_node); // If we find conv+relu, then we can't match bn anymore. - if (status == sStart) status = sBN; + if (status == kStart) status = kBN; return true; } else { - status = sSuccess; + status = kSuccess; return false; } } - status = sSuccess; + status = kSuccess; return false; } } std::vector Filter( const std::vector &candidates) override { - if (status == sFail) { + if (status == kFail) { return std::vector(0); } else { return candidates; diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py index 8f7cac57decc..d09ef650d8bc 100644 --- a/tests/python/mkl/test_subgraph.py +++ b/tests/python/mkl/test_subgraph.py @@ -38,7 +38,7 @@ def check_qsym_calibrated(qsym): return 0, 0 assert ''.join(qsym.attr_dict().keys()).find('quantized_') != -1 for k, v in attrs.items(): - if k.find('requantize_sg_mkldnn_conv') != -1: + if k.find('_sg_mkldnn_conv') != -1: assert 'min_calib_range' in v assert 'max_calib_range' in v min_value = v['min_calib_range'] @@ -83,6 +83,10 @@ def check_quantize(sym, data_shape, label_shape, data_val, sym_output): #disable_requantize=True, calib_quantize_op=True, num_calib_examples=20) + out = SymbolHandle() + backend = "MKLDNN_POST_QUANTIZE" + check_call(_LIB.MXGenBackendSubgraph(qsym.handle, c_str(backend), ctypes.byref(out))) + qsym = Symbol(out) minVar, maxVar = check_qsym_calibrated(qsym) rtol = (maxVar - minVar) / 256 qsym_output = check_qsym_forward(qsym, qarg_params, qaux_params, data_val, data_shape, label_shape) From b00c09ebe3b66eb72827707fac71bd1e7452cfa1 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Fri, 28 Sep 2018 20:25:47 +0800 Subject: [PATCH 20/28] Add symbol api get_backend_symbol() --- .../quantization/imagenet_gen_qsym_mkldnn.py | 10 ++-------- python/mxnet/symbol/symbol.py | 17 +++++++++++++++++ tests/python/mkl/test_subgraph.py | 10 ++-------- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py index 35f815d29a4c..e06276115154 100644 --- a/example/quantization/imagenet_gen_qsym_mkldnn.py +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -118,10 +118,7 @@ def save_params(fname, arg_params, aux_params, logger=None): prefix, epoch = download_model(model_name=args.model, logger=logger) sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) - out = SymbolHandle() - backend = "MKLDNN" - check_call(_LIB.MXGenBackendSubgraph(sym.handle, c_str(backend), ctypes.byref(out))) - sym = Symbol(out) + sym = sym.get_backend_symbol('MKLDNN') # get batch size batch_size = args.batch_size @@ -204,10 +201,7 @@ def save_params(fname, arg_params, aux_params, logger=None): raise ValueError('unknow calibration mode %s received, only supports `none`, `naive`, and `entropy`' % 
calib_mode) sym_name = '%s-symbol.json' % (prefix + suffix) - out = SymbolHandle() - backend = "MKLDNN_POST_QUANTIZE" - check_call(_LIB.MXGenBackendSubgraph(qsym.handle, c_str(backend), ctypes.byref(out))) - qsym = Symbol(out) + qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE') save_symbol(sym_name, qsym, logger) param_name = '%s-%04d.params' % (prefix + '-quantized', epoch) save_params(param_name, qarg_params, aux_params, logger) diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 554539b424ad..eaf22f3bec1a 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -2439,6 +2439,23 @@ def squeeze(self, *args, **kwargs): """ return op.squeeze(self, *args, **kwargs) + def get_backend_symbol(self, backend): + """Return symbol for target backend. + + Parameters + ---------- + backend : str + The backend names. + + Returns + ------- + out : Symbol + The created Symbol for target backend. + """ + out = SymbolHandle() + check_call(_LIB.MXGenBackendSubgraph(self.handle, c_str(backend), ctypes.byref(out))) + return Symbol(out) + def wait_to_read(self): raise NotImplementedForSymbol(self.wait_to_read, None) diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py index d09ef650d8bc..69b2a4ebafd8 100644 --- a/tests/python/mkl/test_subgraph.py +++ b/tests/python/mkl/test_subgraph.py @@ -83,10 +83,7 @@ def check_quantize(sym, data_shape, label_shape, data_val, sym_output): #disable_requantize=True, calib_quantize_op=True, num_calib_examples=20) - out = SymbolHandle() - backend = "MKLDNN_POST_QUANTIZE" - check_call(_LIB.MXGenBackendSubgraph(qsym.handle, c_str(backend), ctypes.byref(out))) - qsym = Symbol(out) + qsym = qsym.get_backend_symbol("MKLDNN_POST_QUANTIZE") minVar, maxVar = check_qsym_calibrated(qsym) rtol = (maxVar - minVar) / 256 qsym_output = check_qsym_forward(qsym, qarg_params, qaux_params, data_val, data_shape, label_shape) @@ -94,10 +91,7 @@ def check_quantize(sym, data_shape, label_shape, data_val, sym_output): def check_fusion(sym, date_shape, label_shape, name, nofusion=False): exe = sym.simple_bind(mx.cpu(), data=date_shape, grad_req='null') - out = SymbolHandle() - backend = "MKLDNN" - check_call(_LIB.MXGenBackendSubgraph(sym.handle, c_str(backend), ctypes.byref(out))) - sym_sg = Symbol(out) + sym_sg = sym.get_backend_symbol("MKLDNN") exe_sg = sym_sg.simple_bind(mx.cpu(), data=date_shape, grad_req='null') mx.random.seed(12345) From 3b7f4f79cada2b1efc6b405cdbef913b9c8b5316 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Fri, 28 Sep 2018 21:25:58 +0800 Subject: [PATCH 21/28] Retrigger ci --- src/operator/subgraph/mkldnn/mkldnn_conv.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index 35062dc0bfab..01c5ebb06f81 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -187,7 +187,7 @@ class SgMKLDNNConvOperator { const std::vector &req, const std::vector &outputs) { LOG(FATAL) << "Not implemented: subgraph mkldnn Conv only supports " - "inference computation"; + "inference computation."; } private: From 97d1841e3ef7f92e16d296b2fb417a041af7d835 Mon Sep 17 00:00:00 2001 From: huangzhiyuan Date: Fri, 28 Sep 2018 22:33:00 +0800 Subject: [PATCH 22/28] update the test case --- tests/python/mkl/test_subgraph.py | 550 ++++++++++++++++++++++-------- 1 file changed, 399 insertions(+), 151 deletions(-) diff --git a/tests/python/mkl/test_subgraph.py 
b/tests/python/mkl/test_subgraph.py index 69b2a4ebafd8..3ca59eab57d4 100644 --- a/tests/python/mkl/test_subgraph.py +++ b/tests/python/mkl/test_subgraph.py @@ -32,193 +32,441 @@ sys.path.append(os.path.join(curr_path, '../unittest/')) from common import with_seed +DATA_SHAPE=[(4, 4, 10, 10), (32, 3, 24, 24), (64, 8, 64, 64)] +DATA_LABEL=[(4, 10), (32, 10), (64, 10)] +MIN_VALUE=-1.0 +MAX_VALUE=1.0 + def check_qsym_calibrated(qsym): - attrs = qsym.attr_dict() - if ''.join(qsym.attr_dict().keys()).find('quantized_pool') != -1: - return 0, 0 - assert ''.join(qsym.attr_dict().keys()).find('quantized_') != -1 - for k, v in attrs.items(): - if k.find('_sg_mkldnn_conv') != -1: + assert ''.join(qsym.attr_dict().keys()).find('quantized_sg_mkldnn_conv') != -1 + for k, v in qsym.attr_dict().items(): + if k.find('quantized_sg_mkldnn_conv') != -1: assert 'min_calib_range' in v assert 'max_calib_range' in v - min_value = v['min_calib_range'] - max_value = v['max_calib_range'] if k.find('_quantize') != -1: assert v['out_type'] == 'uint8' - return float(min_value), float(max_value) -def check_qsym_forward(qsym, qarg_params, qaux_params, data_val, data_shape, label_shape): +def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape): mod = mx.mod.Module(symbol=qsym, context=mx.current_context()) mod.bind(for_training=False, data_shapes=[('data', data_shape)], label_shapes=[('softmax_label', label_shape)]) mod.set_params(qarg_params, qaux_params) - batch = mx.io.DataBatch(data_val, []) mod.forward(batch, is_train=False) for output in mod.get_outputs(): output.wait_to_read() return output -def check_quantize(sym, data_shape, label_shape, data_val, sym_output): - mod = Module(symbol=sym) - mod.bind(data_shapes=[('data', data_shape)], label_shapes=[('softmax_label', label_shape)], for_training=False) - mod.init_params() - arg_params, aux_params = mod.get_params() - excluded_sym_names = [] - if mx.current_context() == mx.cpu(): - excluded_sym_names += ['fc'] - calib_data = mx.nd.random.uniform(shape=data_shape) - calib_data = NDArrayIter(data=calib_data) - calib_data = DummyIter(calib_data) - calib_layer = lambda name: name.endswith('_output') - qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym, - arg_params=arg_params, - aux_params=aux_params, - ctx=mx.current_context(), - excluded_sym_names=excluded_sym_names, - quantized_dtype='uint8', - calib_mode='naive', - calib_data=calib_data, - calib_layer=calib_layer, - #disable_requantize=True, - calib_quantize_op=True, - num_calib_examples=20) - qsym = qsym.get_backend_symbol("MKLDNN_POST_QUANTIZE") - minVar, maxVar = check_qsym_calibrated(qsym) - rtol = (maxVar - minVar) / 256 - qsym_output = check_qsym_forward(qsym, qarg_params, qaux_params, data_val, data_shape, label_shape) - assert_allclose(qsym_output[0].asnumpy(), sym_output[0].asnumpy(), rtol=rtol) - -def check_fusion(sym, date_shape, label_shape, name, nofusion=False): - exe = sym.simple_bind(mx.cpu(), data=date_shape, grad_req='null') - sym_sg = sym.get_backend_symbol("MKLDNN") - exe_sg = sym_sg.simple_bind(mx.cpu(), data=date_shape, grad_req='null') +def check_quantize(sym, arg_params, aux_params, data_shape, label_shape, batch, sym_output): + excluded_sym_names = [] + if mx.current_context() == mx.cpu(): + excluded_sym_names += ['fc'] + calib_data = mx.nd.random.uniform(shape=data_shape) + calib_data = NDArrayIter(data=calib_data) + calib_data = DummyIter(calib_data) + calib_layer = lambda name: name.endswith('_output') + qsym, qarg_params, qaux_params = 
mx.contrib.quant.quantize_model(sym=sym, + arg_params=arg_params, + aux_params=aux_params, + ctx=mx.current_context(), + excluded_sym_names=excluded_sym_names, + quantized_dtype='uint8', + calib_mode='naive', + calib_data=calib_data, + calib_layer=calib_layer, + calib_quantize_op=True, + num_calib_examples=20) + qsym = qsym.get_backend_symbol("MKLDNN_POST_QUANTIZE") + check_qsym_calibrated(qsym) + qsym_output = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape) - mx.random.seed(12345) - for k, v in exe.arg_dict.items(): - v = mx.random.uniform(-1.0, 1.0, shape=v.shape) - data_val = [exe.arg_dict['data']] + diff = mx.nd.abs(sym_output - qsym_output.astype(sym_output.dtype)) + cond = mx.nd.lesser(2, diff).sum().asscalar() + assert cond == 0 - fwd = exe.forward(is_train=False) - fwd[0].wait_to_read() +@with_seed() +def check_fusion(sym, data_shape, label_shape, attrs_op): + dev = mx.cpu() + mod = Module(symbol=sym) + mod.bind(data_shapes=[('data', data_shape)], label_shapes=[('softmax_label', label_shape)]) + mod.init_params(mx.init.Normal(0.5)) + arg_params, aux_params = mod.get_params() - fwd_sg = exe_sg.forward(is_train=False) - fwd_sg[0].wait_to_read() + data = [mx.random.uniform(MIN_VALUE, MAX_VALUE, shape=shape, ctx=dev) for _, shape in mod.data_shapes] + batch = mx.io.DataBatch(data, []) - # Check the result accuracy based on fp32 fusion - assert_allclose(fwd[0].asnumpy(), fwd_sg[0].asnumpy(), rtol=0) - attrs=sym_sg.attr_dict() - if not nofusion: - assert ''.join(sym_sg.get_internals().list_outputs()).find('sg_mkldnn_conv') != -1 - for k, v in attrs.items(): + mod.forward(batch, is_train=False) + for output in mod.get_outputs(): + output.wait_to_read() + + sym_sg = sym.get_backend_symbol("MKLDNN") + mod_sg = Module(symbol=sym) + mod_sg.bind(data_shapes=[('data', data_shape)], label_shapes=[('softmax_label', label_shape)]) + mod_sg.set_params(arg_params, aux_params) + + mod_sg.forward(batch, is_train=False) + for output_sg in mod_sg.get_outputs(): + output_sg.wait_to_read() + + assert ''.join(sym_sg.get_internals().list_outputs()).find('sg_mkldnn_conv') != -1 + + for k, v in sym_sg.attr_dict().items(): if k.find('sg_mkldnn_conv') != -1: - for attr_op in name: + for attr_op in attrs_op: assert v[attr_op] == 'true' + # Check the result accuracy based on fp32 fusion + assert_allclose(output[0].asnumpy(), output_sg[0].asnumpy(), rtol = 0) # fp32 to uint8 - if nofusion: - check_quantize(sym, date_shape, label_shape, data_val, fwd[0]) - else: check_quantize(sym_sg, date_shape, label_shape, data_val, fwd[0]) + check_quantize(sym_sg, arg_params, aux_params, data_shape, label_shape, batch, output_sg) + +def check_neg_fusion(syms, attrs_name=None, excluded_attrs=None, date_shape=(4,4,10,10)): + for sym, attrs, excluded_attr in zip(syms, attrs_name, excluded_attrs): + sym_sg = sym.get_backend_symbol("MKLDNN") + exe_sg = sym_sg.simple_bind(mx.cpu(), data=date_shape, grad_req='null') -def single_conv(): - data = mx.symbol.Variable('data') - weight = mx.symbol.Variable('weight') + attrs_dict = sym_sg.attr_dict() + for k, v in attrs_dict.items(): + if k.find('sg_mkldnn_conv') != -1: + for attr in attrs: + assert v[attr] == 'true' + for exc_attr in excluded_attr: + assert exc_attr not in v.keys() + +def head_symbol(): + data = mx.symbol.Variable('data', dtype='float32') + weight = mx.symbol.Variable('weight', dtype='float32') bn = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn') - conv = mx.symbol.Convolution(data=bn, weight=weight, 
name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) - fc = mx.sym.FullyConnected(data=conv, num_hidden=10, flatten=True, name='fc') - sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') - return sym + return bn, weight -def conv_bn(): - data = mx.symbol.Variable('data') - weight = mx.symbol.Variable('weight') - bn1 = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn1') - conv = mx.symbol.Convolution(data=bn1, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) - bn = mx.symbol.BatchNorm(data=conv, name="bn") - fc = mx.sym.FullyConnected(data=bn, num_hidden=10, flatten=True, name='fc') +def tail_symbol(sym): + fc = mx.sym.FullyConnected(data=sym, num_hidden=10, flatten=True, name='fc') sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') return sym -def conv_relu(): - data = mx.symbol.Variable('data') - weight = mx.symbol.Variable('weight') - bn = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn') - conv = mx.symbol.Convolution(data=bn, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) +# single conv fuision case +def single_conv(no_bias): + conv_attr = [''] + data, weight = head_symbol() + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + sym = tail_symbol(conv) + return sym, conv_attr + +# conv + bn fusion case +def conv_bn(no_bias): + conv_bn_attr = ['with_bn'] + data, weight = head_symbol() + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") + sym = tail_symbol(bn1) + return sym, conv_bn_attr + +# conv + relu fusion case +def conv_relu(no_bias): + conv_relu_attr = ['with_relu'] + data, weight = head_symbol() + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) relu = mx.symbol.Activation(data=conv, name='relu', act_type="relu") - fc = mx.sym.FullyConnected(data=relu, num_hidden=10, flatten=True, name='fc') - sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') - return sym + sym = tail_symbol(relu) + return sym, conv_relu_attr -def conv_sum(): - data = mx.symbol.Variable('data') - weight = mx.symbol.Variable('weight') - bn = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn') - conv = mx.symbol.Convolution(data=bn, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) - conv1 = mx.symbol.Convolution(data=bn, weight=weight, name='conv1', num_filter=64, kernel=(3, 3), stride=(1, 1)) - sum1 = conv + conv1 - fc = mx.sym.FullyConnected(data=sum1, num_hidden=10, flatten=True, name='fc') - sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') - return sym +# conv + add fusion case +def conv_add(no_bias): + conv_add_attr = ['with_sum'] + data, weight = head_symbol() + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, + kernel=(3, 3), stride=(1, 1)) + sum = conv + conv1 + sym = tail_symbol(sum) + return sym, conv_add_attr -def conv_bn_relu(): - data = mx.symbol.Variable('data') - weight = mx.symbol.Variable('weight') - bn1 = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn1') - conv = mx.symbol.Convolution(data=bn1, weight=weight, name='conv', num_filter=64, 
kernel=(3, 3), stride=(1, 1)) - bn = mx.symbol.BatchNorm(data=conv, name="bn") - relu = mx.symbol.Activation(data=bn, name='relu', act_type="relu") - fc = mx.sym.FullyConnected(data=relu, num_hidden=10, flatten=True, name='fc') - sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') - return sym +# conv + bn + relu fusion case +def conv_bn_relu(no_bias): + conv_bn_relu_attr = ['with_bn', 'with_relu'] + data, weight = head_symbol() + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") + relu = mx.symbol.Activation(data=bn1, name='relu', act_type="relu") + sym = tail_symbol(relu) + return sym, conv_bn_relu_attr -def conv_bn_sum_relu(): - data = mx.symbol.Variable('data') - weight = mx.symbol.Variable('weight') - bn1 = mx.symbol.BatchNorm(data=data, name="bn1") - conv = mx.symbol.Convolution(data=bn1, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) - bn = mx.symbol.BatchNorm(data=conv, name="bn") - conv1 = mx.symbol.Convolution(data=bn1, weight=weight, name='conv1', num_filter=64, kernel=(3, 3), stride=(1, 1)) - sum1 = bn + conv1 +# conv + bn + add + relu fusion case +def conv_bn_sum_relu(no_bias): + conv_bn_add_relu_attr = ['with_sum', 'with_postsum_relu', 'with_bn'] + data, weight = head_symbol() + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") + conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, + kernel=(3, 3), stride=(1, 1)) + sum1 = bn1 + conv1 relu = mx.symbol.Activation(data=sum1, name='relu', act_type="relu") - fc = mx.sym.FullyConnected(data=relu, num_hidden=10, flatten=True, name='fc') - sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') - return sym + sym = tail_symbol(relu) + return sym, conv_bn_add_relu_attr -def int8_pooling(): - data = mx.symbol.Variable('data') - bn = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn') - pool = mx.sym.Pooling(data=bn, kernel=(4, 4), pool_type='avg', name='pool') - fc = mx.sym.FullyConnected(data=pool, num_hidden=10, flatten=True, name='fc') - sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') +def tail_neg_symbol(sym1, sym2): + fc1 = mx.sym.FullyConnected(data=sym1, num_hidden=10, flatten=True, name='fc1') + fc2 = mx.sym.FullyConnected(data=sym2, num_hidden=10, flatten=True, name='fc2') + concat = mx.sym.Concat(*[fc1, fc2], name="concat") + sym = mx.sym.SoftmaxOutput(data=concat, name='softmax') return sym +# conv + bn can't be fusion case +# eg.1 +# conv --------- > bn +# | +# | +# -------------> [custom op] +def neg_conv_bn(): + syms = [] + attrs = [] + excluded_attrs = [] + data, weight = head_symbol() + + # eg.1 ([custom op] = pool) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") + pool = mx.sym.Pooling(data=conv, kernel=(4, 4), pool_type='avg', name='pool') + sym = tail_neg_symbol(bn1, pool) + + syms.append(sym) + attrs.append([]) + excluded_attrs.append([]) + return syms, attrs, excluded_attrs + +# conv + relu can't be fusion case +# eg.1 +# conv -----------> relu +# | +# | +# ---------------> [custom op] +def neg_conv_relu(): + syms = [] + attrs = [] + excluded_attrs = [] + data, weight = head_symbol() + + # eg.1 ([custom op] = pool) + conv = 
mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) + relu = mx.symbol.Activation(data=conv, name='relu', act_type="relu") + pool = mx.sym.Pooling(data=conv, kernel=(4, 4), pool_type='avg', name='pool') + sym = tail_neg_symbol(relu, pool) + + syms.append(sym) + attrs.append([]) + excluded_attrs.append([]) + return syms, attrs, excluded_attrs + +# conv + add can't be fusion case +# eg.1 +# ---------------> [custom op] +# | +# | +# conv -----------> add +# | +# | +# added ------------> +def neg_conv_add(): + syms = [] + attrs = [] + excluded_attrs = [] + val = mx.symbol.Variable('addval') + data, weight = head_symbol() + + # eg.1 ([custom op] = pool, [added op] = val) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) + sum1 = conv + val + pool = mx.sym.Pooling(data=conv, kernel=(4, 4), pool_type='avg', name='pool') + sym = tail_neg_symbol(sum1, pool) + + syms.append(sym) + attrs.append([]) + excluded_attrs.append('with_sum') + return syms, attrs, excluded_attrs + +# conv + bn + relu can't be fusion case +# eg.1 +# --------------> [custom op] +# | +# conv -----------> bn -----------> relu +# +# eg.2 +# --------------> [custom op] +# | +# conv -----------> bn -----------> relu +def neg_conv_bn_relu(): + syms = [] + attrs = [] + excluded_attrs = [] + data, weight = head_symbol() + + # eg.1 ([custom op] = pool11) + conv11 = mx.symbol.Convolution(data=data, weight=weight, name='conv11', num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn11 = mx.symbol.BatchNorm(data=conv11, name="bn11") + relu11 = mx.symbol.Activation(data=bn11, name='relu11', act_type="relu") + pool11 = mx.sym.Pooling(data=conv11, kernel=(4, 4), pool_type='avg', name='pool11') + sym1 = tail_neg_symbol(relu11, pool11) + + syms.append(sym1) + attrs.append([]) + excluded_attrs.append([]) + + # eg.2 ([custom op] = pool) + conv21 = mx.symbol.Convolution(data=data, weight=weight, name='conv21', num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn21 = mx.symbol.BatchNorm(data=conv21, name="bn21") + relu21 = mx.symbol.Activation(data=bn21, name='relu21', act_type="relu") + pool21 = mx.sym.Pooling(data=bn21, kernel=(4, 4), pool_type='avg', name='pool21') + sym2 = tail_neg_symbol(relu21, pool21) + + syms.append(sym2) + attrs.append(['with_bn']) + excluded_attrs.append(['with_relu']) + return syms, attrs, excluded_attrs + +# conv + bn + add + relu can't be fusion case +# eg.1 +# --------------> [custom op] +# | +# conv -----------> bn -----------> add -----------> relu +# +# eg.2 +# -------------> [custom op] +# | +# conv -----------> bn -----------> add -----------> relu +# +# eg.3 +# --------------> [custom op] +# | +# conv -----------> bn -----------> add -----------> relu +def neg_conv_bn_add_relu(): + syms = [] + attrs = [] + excluded_attrs = [] + addVal = mx.symbol.Variable('addval') + data, weight = head_symbol() + + # eg.1 + conv11 = mx.symbol.Convolution(data=data, weight=weight, name='conv11', num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn11 = mx.symbol.BatchNorm(data=conv11, name="bn11") + sum11 = bn11 + addVal + relu11 = mx.symbol.Activation(data=sum11, name='relu11', act_type="relu") + pool11 = mx.sym.Pooling(data=conv11, kernel=(4, 4), pool_type='avg', name='pool11') + sym1 = tail_neg_symbol(relu11, pool11) + + syms.append(sym1) + attrs.append([]) + excluded_attrs.append(['with_sum', 'with_postsum_relu', 'with_bn']) + + # eg.2 + conv21 = mx.symbol.Convolution(data=data, weight=weight, name='conv21', 
num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn21 = mx.symbol.BatchNorm(data=conv21, name="bn21") + sum21 = bn21 + addVal + relu21 = mx.symbol.Activation(data=sum21, name='relu21', act_type="relu") + pool21 = mx.sym.Pooling(data=bn21, kernel=(4, 4), pool_type='avg', name='pool21') + sym2 = tail_neg_symbol(relu21, pool21) + + syms.append(sym2) + attrs.append(['with_bn']) + excluded_attrs.append(['with_sum', 'with_postsum_relu']) + + # eg.3 + conv31 = mx.symbol.Convolution(data=data, weight=weight, name='conv31', num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn31 = mx.symbol.BatchNorm(data=conv31, name="bn31") + sum31 = bn31 + addVal + relu31 = mx.symbol.Activation(data=sum31, name='relu31', act_type="relu") + pool31 = mx.sym.Pooling(data=sum31, kernel=(4, 4), pool_type='avg', name='pool31') + sym3 = tail_neg_symbol(relu31, pool31) + + syms.append(sym3) + attrs.append(['with_bn', 'with_sum']) + excluded_attrs.append(['with_postsum_relu']) + return syms, attrs, excluded_attrs + +@with_seed() +def test_pos_single_conv(): + for data_shape, label_shape in zip(DATA_SHAPE, DATA_LABEL): + net, attrs = single_conv(False) + check_fusion(net, data_shape, label_shape, attrs) + net, attrs = single_conv(True) + check_fusion(net, data_shape, label_shape, attrs) + +@with_seed() +def test_pos_conv_relu(): + for data_shape, label_shape in zip(DATA_SHAPE, DATA_LABEL): + net, attrs = conv_relu(False) + check_fusion(net, data_shape, label_shape, attrs) + net, attrs = conv_relu(True) + check_fusion(net, data_shape, label_shape, attrs) + +@with_seed() +def test_pos_conv_bn(): + for data_shape, label_shape in zip(DATA_SHAPE, DATA_LABEL): + net, attrs = conv_bn(False) + check_fusion(net, data_shape, label_shape, attrs) + net, attrs = conv_bn(True) + check_fusion(net, data_shape, label_shape, attrs) + +@with_seed() +def test_pos_conv_add(): + for data_shape, label_shape in zip(DATA_SHAPE, DATA_LABEL): + net, attrs = conv_add(False) + check_fusion(net, data_shape, label_shape, attrs) + net, attrs = conv_add(True) + check_fusion(net, data_shape, label_shape, attrs) + @with_seed() -def test_sugbraph(): - def check_test_sugbraph(): - conv_attr = [''] - conv_relu_attr = ['with_relu'] - conv_bn_attr = ['with_bn'] - conv_sum_attr = ['with_sum'] - conv_bn_relu_attr = ['with_bn', 'with_relu'] - conv_bn_sum_relu_attr = ['with_sum', 'with_postsum_relu', 'with_bn'] - - shape = [(4, 4, 10, 10), (32, 3, 24, 24), (64, 8, 64, 64)] - label = [(4, 10), (32, 10), (64, 10)] - - for date_shape, label_shape in zip(shape, label): - net = conv_bn_sum_relu() - check_fusion(net, date_shape, label_shape, conv_bn_sum_relu_attr) - net = single_conv() - check_fusion(net, date_shape, label_shape, conv_attr) - net = conv_relu() - check_fusion(net, date_shape, label_shape, conv_relu_attr) - net = conv_bn() - check_fusion(net, date_shape, label_shape, conv_bn_attr) - net = conv_sum() - check_fusion(net, date_shape, label_shape, conv_sum_attr) - net = conv_bn_relu() - check_fusion(net, date_shape, label_shape, conv_bn_relu_attr) - net = int8_pooling() - check_fusion(net, date_shape, label_shape, '', True) - - check_test_sugbraph() +def test_pos_conv_bn_relu(): + for data_shape, label_shape in zip(DATA_SHAPE, DATA_LABEL): + net, attrs = conv_bn_relu(False) + check_fusion(net, data_shape, label_shape, attrs) + net, attrs = conv_bn_relu(True) + check_fusion(net, data_shape, label_shape, attrs) + +@with_seed() +def test_pos_conv_bn_sum_relu(): + for data_shape, label_shape in zip(DATA_SHAPE, DATA_LABEL): + net, attrs = conv_bn_sum_relu(False) + 
check_fusion(net, data_shape, label_shape, attrs) + net, attrs = conv_bn_sum_relu(True) + check_fusion(net, data_shape, label_shape, attrs) + +@with_seed() +def test_neg_conv_bn(): + for data_shape in DATA_SHAPE: + syms, attrs, excluded_attrs = neg_conv_bn() + check_neg_fusion(syms, attrs, excluded_attrs, data_shape) + +@with_seed() +def test_neg_conv_relu(): + for data_shape in DATA_SHAPE: + syms, attrs, excluded_attrs = neg_conv_relu() + check_neg_fusion(syms, attrs, excluded_attrs, data_shape) + +@with_seed() +def test_neg_conv_add(): + for data_shape in DATA_SHAPE: + syms, attrs, excluded_attrs = neg_conv_add() + check_neg_fusion(syms, attrs, excluded_attrs, data_shape) + +@with_seed() +def test_neg_conv_bn_relu(): + for data_shape in DATA_SHAPE: + syms, attrs, excluded_attrs = neg_conv_bn_relu() + check_neg_fusion(syms, attrs, excluded_attrs, data_shape) + +@with_seed() +def test_neg_conv_bn_add_relu(): + for data_shape in DATA_SHAPE: + syms, attrs, excluded_attrs = neg_conv_bn_add_relu() + check_neg_fusion(syms, attrs, excluded_attrs, data_shape) + + +if __name__ == "__main__": + import nose + nose.runmodule() \ No newline at end of file From b009f135a0a3cb9afd7e18cec7e18507b7ce7f47 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Sat, 29 Sep 2018 12:57:41 +0800 Subject: [PATCH 23/28] Check subgraph index. --- .../quantization/quantize_graph_pass.cc | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index da84bb388b02..13f2b3c6ba05 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -91,27 +91,27 @@ std::vector OfflineParams(std::vector&& outputs, inline bool NeedQuantize(NodePtr node, const std::unordered_set& excluded_nodes) { static auto& quantized_op_map = Op::GetAttr("FQuantizedOp"); - if (quantized_op_map.count(node->op())) { - bool excluded = false; - if (node->attrs.subgraphs.size()) { - // This is a subgraph node, try to match subgraph name first, - // and then try to match inner node. - if (excluded_nodes.count(node->attrs.name)) { - excluded = true; - } else { - // Assume index 0 holds subgraph symbol. + static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); + const auto& op = node->op(); + if (op && quantized_op_map.count(op)) { + bool need = true; + if (excluded_nodes.count(node->attrs.name)) { + need = false; + } else if (node->attrs.subgraphs.size()) { + ExecType exec_type = fexec_type.count(op) ? fexec_type[op](node->attrs) : ExecType::kSync; + if (exec_type != ExecType::kSubgraphExec) { + // This is a fused subgraph node, try to match inner node. + CHECK_EQ(node->attrs.subgraphs.size(), 1); auto subgraph_sym = node->attrs.subgraphs[0]; - DFSVisit(subgraph_sym->outputs, [&](const nnvm::NodePtr& node) { - if (node->is_variable()) return; - if (excluded_nodes.count(node->attrs.name)) { - excluded = true; + DFSVisit(subgraph_sym->outputs, [&](const nnvm::NodePtr& n) { + if (n->is_variable()) return; + if (excluded_nodes.count(n->attrs.name)) { + need = false; } }); } - } else { - excluded = excluded_nodes.count(node->attrs.name); } - return !excluded; + return need; } return false; } From 813610cbeb73681d6fec7c59f0156a74f714a8f8 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Sat, 29 Sep 2018 15:12:08 +0800 Subject: [PATCH 24/28] Use index as FAvoidQuantizeInput's parameter. 
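Matching input names by suffix ("weight", "bias", "gamma", ...) can misfire on user-chosen variable names; the fused op knows its own input ordering, so deciding by input index is unambiguous. Roughly, the positions that must stay FP32 are as follows (a Python sketch of the same bookkeeping as the C++ change below; the helper name and arguments are illustrative):

    def avoid_quantize_indices(no_bias, with_bn):
        # assumed input order for _sg_mkldnn_conv in this series:
        # data, weight, [bias], [gamma, beta, moving_mean, moving_var], ...
        avoid = set()
        idx = 1                        # index 0 (data) does get quantized
        avoid.add(idx)                 # weight: quantized inside the fused op
        idx += 1
        if not no_bias:
            avoid.add(idx)             # bias: handled together with the weight
            idx += 1
        if with_bn:
            # gamma, beta, moving_mean and moving_var are folded into weight/bias
            avoid.update({idx, idx + 1, idx + 2, idx + 3})
        return avoid

With this, the quantize pass only asks whether a given input index is in the set, instead of inspecting the producer node's name.
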
--- include/mxnet/op_attr_types.h | 2 +- src/common/utils.h | 8 ------ .../quantization/quantize_graph_pass.cc | 10 +++++--- src/operator/subgraph/mkldnn/mkldnn_conv.cc | 25 +++++++++++-------- 4 files changed, 22 insertions(+), 23 deletions(-) diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index 4ddce8643fbc..dd818457f827 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -306,7 +306,7 @@ using FNeedRequantize = std::function; * which can handle fp32 inputs directly. */ using FAvoidQuantizeInput = std::function; + size_t index)>; } // namespace mxnet diff --git a/src/common/utils.h b/src/common/utils.h index 84e2cbbdc3a5..26889792e53d 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -713,14 +713,6 @@ inline void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape } } -/*! - * \brief Return true if str ends with suffix. - */ -inline bool StringEndsWith(std::string const& str, std::string const& suffix) { - if (suffix.size() > str.size()) return false; - return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); -} - } // namespace common } // namespace mxnet #endif // MXNET_COMMON_UTILS_H_ diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index 13f2b3c6ba05..c53382a4e793 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -141,7 +141,8 @@ Graph QuantizeGraph(Graph &&src) { new_node = fquantized_op(node->attrs); // add data into quantized op input - for (const auto& e : node->inputs) { + for (size_t i = 0; i < node->inputs.size(); ++i) { + const auto& e = node->inputs[i]; NodePtr mirror_node = mirror_map.at(e.node.get()); NodeEntry mirror_entry = NodeEntry{ mirror_node, e.index, e.version}; @@ -151,7 +152,7 @@ Graph QuantizeGraph(Graph &&src) { // e's source node and the newly created quantize op so that the quantize op can be // reused next time when the same entry is visited again. if (avoid_quantize_input_map.count(node->op()) && - avoid_quantize_input_map[node->op()](node->attrs, e.node->attrs)) { + avoid_quantize_input_map[node->op()](node->attrs, i)) { new_node->inputs.emplace_back(mirror_entry); } else if (!NeedQuantize(e.node, excluded_nodes) && (mirror_node->op() == nullptr || @@ -188,7 +189,8 @@ Graph QuantizeGraph(Graph &&src) { // add min and max into quantized op input assume order of quantized op inputs is: // data1, data2, ..., min1, max1, min2, max2, ... 
- for (const auto& e : node->inputs) { + for (size_t i = 0; i < node->inputs.size(); ++i) { + const auto& e = node->inputs[i]; NodePtr mirror_node = mirror_map.at(e.node.get()); if (mirror_node->op() != nullptr && mirror_node->op()->name == "_contrib_dequantize") { @@ -200,7 +202,7 @@ Graph QuantizeGraph(Graph &&src) { uint32_t min_index = 1; uint32_t max_index = 2; if (avoid_quantize_input_map.count(node->op()) && - avoid_quantize_input_map[node->op()](node->attrs, e.node->attrs)) { + avoid_quantize_input_map[node->op()](node->attrs, i)) { // skip non-quantized input continue; } diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index 01c5ebb06f81..a1083d09b7b5 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -26,7 +26,6 @@ #include "../../nn/mkldnn/mkldnn_base-inl.h" #include "../../nn/mkldnn/mkldnn_ops-inl.h" #include "../../quantization/quantization_utils.h" -#include "../../../common/utils.h" #include "mkldnn_conv-inl.h" namespace mxnet { @@ -640,16 +639,22 @@ nnvm::NodePtr SgMKLDNNConvQuantizedOp(const NodeAttrs& attrs) { return node; } -bool SgMKLDNNAvoidQuantizeInput(const NodeAttrs &attrs, - const NodeAttrs &input_attrs) { - const std::vector exclude_key{ - "weight", "bias", "gamma", "beta", "moving_mean", "moving_var"}; - for (auto i : exclude_key) { - if (common::StringEndsWith(input_attrs.name, i)) { - return true; - } +bool SgMKLDNNAvoidQuantizeInput(const NodeAttrs &attrs, size_t index) { + auto const ¶m = nnvm::get(attrs.parsed); + std::unordered_set avoid_indice; + size_t idx = 0; + idx++; // data + avoid_indice.insert(idx++); // weight + if (!param.full_conv_param.conv_param.no_bias) { + avoid_indice.insert(idx++); // bias + } + if (param.full_conv_param.mkldnn_param.with_bn) { + avoid_indice.insert(idx++); // gamma + avoid_indice.insert(idx++); // beta + avoid_indice.insert(idx++); // mean + avoid_indice.insert(idx++); // var } - return false; + return avoid_indice.count(index); } NNVM_REGISTER_OP(_sg_mkldnn_conv) From 8833a02ef98906baad09ac54fb77ce935b01fa15 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Sat, 29 Sep 2018 16:38:58 +0800 Subject: [PATCH 25/28] Add mkldnn_hwigo support as quantizaiton needs. --- src/operator/nn/mkldnn/mkldnn_base.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc index 029f23bd8f5e..a60d6555c74d 100644 --- a/src/operator/nn/mkldnn/mkldnn_base.cc +++ b/src/operator/nn/mkldnn/mkldnn_base.cc @@ -332,6 +332,7 @@ mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) { } else if (desc.data.ndims == 5) { switch (desc.data.format) { case mkldnn_goihw: + case mkldnn_hwigo: case mkldnn_gOIhw8i8o: case mkldnn_gOIhw16i16o: case mkldnn_gOIhw4i16o4i: From f89bd28ce861016014fe811d2ab7ea620062a713 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 2 Oct 2018 19:11:43 +0800 Subject: [PATCH 26/28] Address KellenSunderland's comments. 
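
Besides the review fixes (prefer !empty() over size(), spelling, trailing whitespace), it is worth recapping the rule the touched NeedQuantize code implements: a node is only quantized if its op registers FQuantizedOp, and a fused (non-kSubgraphExec) subgraph node is skipped when either its own name or the name of any node captured inside its single subgraph symbol appears in the excluded list. Below is a simplified Python sketch of that exclusion rule, assuming the inner node names have already been collected; the real implementation is C++ and walks the subgraph outputs with DFSVisit.

# Simplified Python sketch of the exclusion rule in NeedQuantize
# (illustration only; names and signature are invented for the example).
def need_quantize(node_name, inner_node_names, excluded_nodes):
    # Skip the fused node if its own name, or any inner node name, is excluded.
    if node_name in excluded_nodes:
        return False
    return not any(name in excluded_nodes for name in inner_node_names)

# Excluding the first conv by its original name also excludes the fused
# subgraph node that now contains it.
assert not need_quantize('sg_mkldnn_conv_0', ['conv0', 'bn0'], {'conv0'})
assert need_quantize('sg_mkldnn_conv_1', ['conv1', 'bn1'], {'conv0'})

Excluding a layer by its original name therefore still takes effect after that layer has been fused into a subgraph node.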
--- src/operator/quantization/quantize_graph_pass.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index c53382a4e793..2fa790dc88ef 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -97,7 +97,7 @@ inline bool NeedQuantize(NodePtr node, const std::unordered_set& ex bool need = true; if (excluded_nodes.count(node->attrs.name)) { need = false; - } else if (node->attrs.subgraphs.size()) { + } else if (!node->attrs.subgraphs.empty()) { ExecType exec_type = fexec_type.count(op) ? fexec_type[op](node->attrs) : ExecType::kSync; if (exec_type != ExecType::kSubgraphExec) { // This is a fused subgraph node, try to match inner node. @@ -239,7 +239,7 @@ Graph QuantizeGraph(Graph &&src) { NodeEntry{new_node, static_cast(i), 0}); } new_node = requantize_node; - } + } } else { // If the currently visited node does not need quantization, copy the current node to become // the new_node. Meanwhile, check whether any inputs of the current node need quantization @@ -249,7 +249,7 @@ Graph QuantizeGraph(Graph &&src) { *new_node = *node; new_node->inputs.clear(); if (node->is_variable() && node->attrs.name == "data") { - // Instert identity for data to collect calib for it. + // Insert identity for data to collect calib for it. NodePtr identity_node = CreateNode("identity", new_node->attrs.name + "_id"); identity_node->inputs.emplace_back(NodeEntry{new_node, 0, 0}); From c5bf05dc161856e03c832428c3780d1483af9e8f Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Mon, 8 Oct 2018 22:09:46 +0800 Subject: [PATCH 27/28] Handle input order change after subgraph pass. --- src/executor/graph_executor.cc | 69 ++++-- src/operator/subgraph/partition_graph.cc | 1 - src/operator/subgraph/subgraph_property.h | 68 ++++-- tests/python/mkl/test_subgraph.py | 249 ++++++++++++---------- 4 files changed, 231 insertions(+), 156 deletions(-) diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 922917f79475..84a51aa9381f 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -1499,6 +1499,7 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, // This is for simple_bind flow. 
static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, const std::string& prop_name, + std::vector *in_args, const std::unordered_map& arg_shape_map, const std::unordered_map& arg_dtype_map, const std::unordered_map& arg_stype_map, @@ -1507,6 +1508,8 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, const std::vector& in_arg_ctxes, const std::vector& aux_state_ctxes) { const std::vector input_names = src.ListInputNames(Symbol::kAll); + const std::vector arg_names = src.ListInputNames(nnvm::Symbol::kReadOnlyArgs); + CHECK_EQ(arg_names.size(), in_args->size()); nnvm::ShapeVector arg_shapes(input_names.size(), TShape()); nnvm::DTypeVector arg_dtypes(input_names.size(), -1); StorageTypeVector arg_stypes(input_names.size(), kUndefinedStorage); @@ -1524,22 +1527,38 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, arg_stypes[i] = it3->second; } } - return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes, - default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes); + // setup in_args_map + std::unordered_map in_args_map; + for (size_t i = 0; i < in_args->size(); ++i) { + in_args_map[arg_names[i]] = in_args->at(i); + } + auto result = PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes, default_ctx, + ctx_map, in_arg_ctxes, aux_state_ctxes); + // Reorder in_args into new_in_args according to partitioned symbol input sequence + std::vector new_in_args(in_args->size()); + // get new symbol in_arg names + std::vector new_arg_names = result.ListInputNames(nnvm::Symbol::kReadOnlyArgs); + CHECK_EQ(arg_names.size(), new_arg_names.size()); + in_args->clear(); + for (auto arg_name : new_arg_names) { + CHECK(in_args_map.count(arg_name)); + in_args->push_back(in_args_map[arg_name]); + } + return result; } // Given input ndarrays, partition the graph using the backend name equal to prop_name. // This is for bind flow. 
static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, const std::string& prop_name, - const std::vector &in_args, + std::vector *in_args, const std::vector &aux_states, const Context& default_ctx, const std::map& ctx_map) { const std::vector input_names = src.ListInputNames(Symbol::kAll); const std::vector arg_names = src.ListInputNames(nnvm::Symbol::kReadOnlyArgs); const std::vector aux_names = src.ListInputNames(nnvm::Symbol::kAuxiliaryStates); - CHECK_EQ(arg_names.size(), in_args.size()); + CHECK_EQ(arg_names.size(), in_args->size()); CHECK_EQ(aux_names.size(), aux_states.size()); nnvm::ShapeVector arg_shapes; // all input shapes arg_shapes.reserve(input_names.size()); @@ -1547,7 +1566,7 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, arg_dtypes.reserve(input_names.size()); StorageTypeVector arg_stypes; // all input stypes arg_stypes.reserve(input_names.size()); - std::vector in_arg_ctxes(in_args.size()); + std::vector in_arg_ctxes(in_args->size()); std::vector aux_state_ctxes(aux_states.size()); size_t i1 = 0, i2 = 0; @@ -1561,15 +1580,32 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, } else { CHECK(i1 < arg_names.size()); CHECK_EQ(arg_names[i1], input_names[i]); - arg_shapes.push_back(in_args[i1].shape()); - arg_dtypes.push_back(in_args[i1].dtype()); - arg_stypes.push_back(in_args[i1].storage_type()); - in_arg_ctxes[i1] = in_args[i1].ctx(); + arg_shapes.push_back(in_args->at(i1).shape()); + arg_dtypes.push_back(in_args->at(i1).dtype()); + arg_stypes.push_back(in_args->at(i1).storage_type()); + in_arg_ctxes[i1] = in_args->at(i1).ctx(); ++i1; } } - return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes, - default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes); + + // setup in_args_map + std::unordered_map in_args_map; + for (size_t i = 0; i < in_args->size(); ++i) { + in_args_map[arg_names[i]] = in_args->at(i); + } + auto result = PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes, default_ctx, + ctx_map, in_arg_ctxes, aux_state_ctxes); + // Reorder in_args into new_in_args according to partitioned symbol input sequence + std::vector new_in_args(in_args->size()); + // get new symbol in_arg names + std::vector new_arg_names = result.ListInputNames(nnvm::Symbol::kReadOnlyArgs); + CHECK_EQ(arg_names.size(), new_arg_names.size()); + in_args->clear(); + for (auto arg_name : new_arg_names) { + CHECK(in_args_map.count(arg_name)); + in_args->push_back(in_args_map[arg_name]); + } + return result; } } // namespace exec @@ -1591,9 +1627,9 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol, Executor* shared_exec) { auto exec = new exec::GraphExecutor(); if (!exec->subgraph_property().empty()) { - symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), arg_shape_map, arg_dtype_map, - arg_stype_map, default_ctx, group2ctx, in_arg_ctxes, - aux_state_ctxes); + symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), in_args, arg_shape_map, + arg_dtype_map, arg_stype_map, default_ctx, group2ctx, + in_arg_ctxes, aux_state_ctxes); } exec->Init(symbol, default_ctx, group2ctx, in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, @@ -1613,12 +1649,13 @@ Executor *Executor::Bind(nnvm::Symbol symbol, const std::vector &aux_states, Executor* shared_exec) { auto exec = new exec::GraphExecutor(); + std::vector tmp_in_args = in_args; if (!exec->subgraph_property().empty()) { - symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), in_args, aux_states, + symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), 
&tmp_in_args, aux_states, default_ctx, group2ctx); } exec->Init(symbol, default_ctx, group2ctx, - in_args, arg_grad_store, grad_req_type, aux_states, + tmp_in_args, arg_grad_store, grad_req_type, aux_states, reinterpret_cast(shared_exec)); return exec; } diff --git a/src/operator/subgraph/partition_graph.cc b/src/operator/subgraph/partition_graph.cc index 57fb47f82933..da9a9f375fa5 100644 --- a/src/operator/subgraph/partition_graph.cc +++ b/src/operator/subgraph/partition_graph.cc @@ -656,7 +656,6 @@ void CreateSubgraphNode(Graph* g, subg_prop->ConnectSubgraphOutputs(n, &output_entries); subg_prop->ConnectSubgraphInputs(n, &input_entries, &orig_input_entries); - n->inputs = orig_input_entries; const auto& indexed_graph = g->indexed_graph(); for (size_t i = 0; i < n->inputs.size(); ++i) { auto& e = n->inputs[i]; diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h index 85e9adf4267b..e9fdd6619275 100644 --- a/src/operator/subgraph/subgraph_property.h +++ b/src/operator/subgraph/subgraph_property.h @@ -62,16 +62,22 @@ class SubgraphSelector { * \brief Determines if to select input_node when traverse to the cur_node. * \param cur_node the node for determining whether its input_node should be selected * \param input_node the input node of the cur_node + * \return true if input_node is selected */ virtual bool SelectInput(const nnvm::Node &cur_node, const nnvm::Node &input_node) = 0; /*! * \brief Determines if to select output_node when traverse to the cur_node. * \param cur_node the node for determining whether its output_node should be selected * \param output_node the output node of the cur_node + * \return true if output_node is selected */ virtual bool SelectOutput(const nnvm::Node &cur_node, const nnvm::Node &output_node) = 0; - // Post processes pre-selected subgraph nodes. Return a list of nodes that - // users want to keep in subgraph(s). + /*! + * \brief Post processes pre-selected subgraph nodes. Return a list of nodes that + * users want to keep in subgraph(s). + * \param candidates re-selected subgraph nodes to filt + * \return a list of nodes to keep + */ virtual std::vector Filter(const std::vector& candidates) { return candidates; } @@ -81,40 +87,58 @@ using SubgraphSelectorPtr = std::shared_ptr; /*! * \brief This provides a set of properties for partitioning a graph into subgraphs, - * reconstructing a new graph from the subgraphs and creating a subgraph - * operator to execute the subgraph. + * reconstructing a new graph from the subgraphs and creating a subgraph + * operator to execute the subgraph. */ class SubgraphProperty { public: - // the criteria of selecting the subgraph nodes. + /*! + * \brief The criteria of selecting the subgraph nodes. + */ virtual SubgraphSelectorPtr CreateSubgraphSelector() const = 0; - // create an nnvm node for a given subgraph. Here users can customize how to - // execute the operators in the subgraph. - virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &s, + /*! + * \brief Create an nnvm node for a given subgraph. Here users can customize how to + * execute the operators in the subgraph. + * \param sym the symbol to create subgraph node + * \param subgraph_id subgraph id + */ + virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym, const int subgraph_id = 0) const = 0; - // Connect subgraph internal output with external output entries. By default, - // each output entry will connect to an unique internal output. 
- virtual void ConnectSubgraphOutputs( - const nnvm::NodePtr n, - std::vector *output_entries) const { + /*! + * \brief Connect subgraph internal output with external output entries. + * By default, each output entry will connect to an unique internal output. + * \param subgraph_node the subgraph node to connect output + * \param output_entries external output entries depending on this subgraph node + */ + virtual void ConnectSubgraphOutputs(const nnvm::NodePtr subgraph_node, + std::vector* output_entries) const { for (size_t i = 0; i < output_entries->size(); ++i) { - *output_entries->at(i) = nnvm::NodeEntry{n, static_cast(i), 0}; + *output_entries->at(i) = nnvm::NodeEntry{subgraph_node, static_cast(i), 0}; } } - // Connect subgraph internal input with external input entries. By default, - // each input entry will connect in top sorted order. - virtual void ConnectSubgraphInputs( - const nnvm::NodePtr n, std::vector *input_entries, - std::vector *orig_input_entries) const { - n->inputs = *orig_input_entries; + /*! + * \brief Connect subgraph internal input with external input entries. + * By default, each input entry will connect in top sorted order. + * \param subgraph_node the subgraph node to connect input + * \param input_entries input entries inside subgraph + * \param orig_input_entries input entries outside subgraph + */ + virtual void ConnectSubgraphInputs(const nnvm::NodePtr subgraph_node, + std::vector* input_entries, + std::vector* orig_input_entries) const { + subgraph_node->inputs = *orig_input_entries; } - // set an attr with name in the attr map + /*! + * \brief Set an attr with name in the attr map. + */ template SubgraphProperty& SetAttr(const std::string& name, const T& value) { attrs_[name] = std::make_shared(value); return *this; } - // get the attr with the name + /*! + * \brief Get the attr with the name. 
+ */ template const T& GetAttr(const std::string& name) const { auto it = attrs_.find(name); diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py index 3ca59eab57d4..5b708216e2ac 100644 --- a/tests/python/mkl/test_subgraph.py +++ b/tests/python/mkl/test_subgraph.py @@ -31,11 +31,9 @@ curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.append(os.path.join(curr_path, '../unittest/')) from common import with_seed +from mxnet.test_utils import assert_almost_equal DATA_SHAPE=[(4, 4, 10, 10), (32, 3, 24, 24), (64, 8, 64, 64)] -DATA_LABEL=[(4, 10), (32, 10), (64, 10)] -MIN_VALUE=-1.0 -MAX_VALUE=1.0 def check_qsym_calibrated(qsym): assert ''.join(qsym.attr_dict().keys()).find('quantized_sg_mkldnn_conv') != -1 @@ -55,17 +53,37 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_ mod.forward(batch, is_train=False) for output in mod.get_outputs(): output.wait_to_read() - return output + return mod.get_outputs() + +def check_quantize(sym, data_shape): + fc = mx.sym.FullyConnected(data=sym, num_hidden=10, flatten=True, name='fc') + sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') + sym_sg = sym.get_backend_symbol("MKLDNN") + label_shape = (data_shape[0], 10) + mod = Module(symbol=sym) + mod.bind(for_training=False, + data_shapes=[('data', data_shape)], + label_shapes=[('softmax_label', label_shape)]) + mod.init_params(mx.init.Normal(0.5)) + arg_params, aux_params = mod.get_params() + + data = [mx.random.uniform(-1, 1, shape=shape, ctx=mx.current_context()) for _, shape in mod.data_shapes] + batch = mx.io.DataBatch(data, []) + + mod.forward(batch, is_train=False) + for output in mod.get_outputs(): + output.wait_to_read() + ref_out = mod.get_outputs() -def check_quantize(sym, arg_params, aux_params, data_shape, label_shape, batch, sym_output): excluded_sym_names = [] if mx.current_context() == mx.cpu(): excluded_sym_names += ['fc'] + calib_data = mx.nd.random.uniform(shape=data_shape) calib_data = NDArrayIter(data=calib_data) calib_data = DummyIter(calib_data) calib_layer = lambda name: name.endswith('_output') - qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym, + qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg, arg_params=arg_params, aux_params=aux_params, ctx=mx.current_context(), @@ -75,50 +93,37 @@ def check_quantize(sym, arg_params, aux_params, data_shape, label_shape, batch, calib_data=calib_data, calib_layer=calib_layer, calib_quantize_op=True, - num_calib_examples=20) + num_calib_examples=5) qsym = qsym.get_backend_symbol("MKLDNN_POST_QUANTIZE") check_qsym_calibrated(qsym) - qsym_output = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape) + quantized_out = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape) + for i in range(len(ref_out)): + assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol = 1) - diff = mx.nd.abs(sym_output - qsym_output.astype(sym_output.dtype)) - cond = mx.nd.lesser(2, diff).sum().asscalar() - assert cond == 0 @with_seed() -def check_fusion(sym, data_shape, label_shape, attrs_op): - dev = mx.cpu() - mod = Module(symbol=sym) - mod.bind(data_shapes=[('data', data_shape)], label_shapes=[('softmax_label', label_shape)]) - mod.init_params(mx.init.Normal(0.5)) - arg_params, aux_params = mod.get_params() - - data = [mx.random.uniform(MIN_VALUE, MAX_VALUE, shape=shape, ctx=dev) for _, shape in mod.data_shapes] - batch = mx.io.DataBatch(data, []) - - 
mod.forward(batch, is_train=False) - for output in mod.get_outputs(): - output.wait_to_read() - +def check_fusion(sym, data_shape, attrs_op): sym_sg = sym.get_backend_symbol("MKLDNN") - mod_sg = Module(symbol=sym) - mod_sg.bind(data_shapes=[('data', data_shape)], label_shapes=[('softmax_label', label_shape)]) - mod_sg.set_params(arg_params, aux_params) - - mod_sg.forward(batch, is_train=False) - for output_sg in mod_sg.get_outputs(): - output_sg.wait_to_read() - assert ''.join(sym_sg.get_internals().list_outputs()).find('sg_mkldnn_conv') != -1 - for k, v in sym_sg.attr_dict().items(): if k.find('sg_mkldnn_conv') != -1: for attr_op in attrs_op: assert v[attr_op] == 'true' - # Check the result accuracy based on fp32 fusion - assert_allclose(output[0].asnumpy(), output_sg[0].asnumpy(), rtol = 0) + arg_shapes, _, aux_shapes = sym.infer_shape() + arg_array = [mx.nd.random.uniform(-1, 1, shape=shape) for shape in arg_shapes] + aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes] + exe = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null') + exe.forward() + os.environ['MXNET_SUBGRAPH_BACKEND'] = 'MKLDNN' + exe_sg = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null') + exe_sg.forward() + del os.environ['MXNET_SUBGRAPH_BACKEND'] + for i in range(len(exe.outputs)): + assert_almost_equal(exe.outputs[i].asnumpy(), exe_sg.outputs[i].asnumpy(), rtol=1e-3, atol=1e-3) + # fp32 to uint8 - check_quantize(sym_sg, arg_params, aux_params, data_shape, label_shape, batch, output_sg) + check_quantize(sym, data_shape) def check_neg_fusion(syms, attrs_name=None, excluded_attrs=None, date_shape=(4,4,10,10)): for sym, attrs, excluded_attr in zip(syms, attrs_name, excluded_attrs): @@ -133,73 +138,76 @@ def check_neg_fusion(syms, attrs_name=None, excluded_attrs=None, date_shape=(4,4 for exc_attr in excluded_attr: assert exc_attr not in v.keys() -def head_symbol(): - data = mx.symbol.Variable('data', dtype='float32') +def head_symbol(data_shape): + data = mx.symbol.Variable('data', shape=data_shape, dtype='float32') weight = mx.symbol.Variable('weight', dtype='float32') bn = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn') return bn, weight -def tail_symbol(sym): - fc = mx.sym.FullyConnected(data=sym, num_hidden=10, flatten=True, name='fc') - sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') - return sym - # single conv fuision case -def single_conv(no_bias): +def single_conv(no_bias, data_shape): conv_attr = [''] - data, weight = head_symbol() + data, weight = head_symbol(data_shape) conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1), no_bias=no_bias) - sym = tail_symbol(conv) - return sym, conv_attr + return conv, conv_attr # conv + bn fusion case -def conv_bn(no_bias): +def conv_bn(no_bias, data_shape): conv_bn_attr = ['with_bn'] - data, weight = head_symbol() + data, weight = head_symbol(data_shape) conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1), no_bias=no_bias) bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") - sym = tail_symbol(bn1) - return sym, conv_bn_attr + return bn1, conv_bn_attr # conv + relu fusion case -def conv_relu(no_bias): +def conv_relu(no_bias, data_shape): conv_relu_attr = ['with_relu'] - data, weight = head_symbol() + data, weight = head_symbol(data_shape) conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 
3), stride=(1, 1), no_bias=no_bias) relu = mx.symbol.Activation(data=conv, name='relu', act_type="relu") - sym = tail_symbol(relu) - return sym, conv_relu_attr + return relu, conv_relu_attr # conv + add fusion case -def conv_add(no_bias): +def conv_add(no_bias, data_shape): conv_add_attr = ['with_sum'] - data, weight = head_symbol() - conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, + data, weight = head_symbol(data_shape) + conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + conv2 = mx.symbol.Convolution(data=data, name='conv2', num_filter=64, + kernel=(3, 3), stride=(1, 1)) + pool = mx.sym.Pooling(data=conv2, kernel=(1, 1), pool_type='avg', name='pool') + sum = conv1 + pool + return sum, conv_add_attr + +# conv + add fusion case 2 +def conv_add2(no_bias, data_shape): + conv_add_attr = ['with_sum'] + data, weight = head_symbol(data_shape) conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, - kernel=(3, 3), stride=(1, 1)) - sum = conv + conv1 - sym = tail_symbol(sum) - return sym, conv_add_attr + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + conv2 = mx.symbol.Convolution(data=data, name='conv2', num_filter=64, + kernel=(3, 3), stride=(1, 1)) + pool = mx.sym.Pooling(data=conv2, kernel=(1, 1), pool_type='avg', name='pool') + sum = pool + conv1 + return sum, conv_add_attr # conv + bn + relu fusion case -def conv_bn_relu(no_bias): +def conv_bn_relu(no_bias, data_shape): conv_bn_relu_attr = ['with_bn', 'with_relu'] - data, weight = head_symbol() + data, weight = head_symbol(data_shape) conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1), no_bias=no_bias) bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") relu = mx.symbol.Activation(data=bn1, name='relu', act_type="relu") - sym = tail_symbol(relu) - return sym, conv_bn_relu_attr + return relu, conv_bn_relu_attr # conv + bn + add + relu fusion case -def conv_bn_sum_relu(no_bias): +def conv_bn_sum_relu(no_bias, data_shape): conv_bn_add_relu_attr = ['with_sum', 'with_postsum_relu', 'with_bn'] - data, weight = head_symbol() + data, weight = head_symbol(data_shape) conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1), no_bias=no_bias) bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") @@ -207,8 +215,7 @@ def conv_bn_sum_relu(no_bias): kernel=(3, 3), stride=(1, 1)) sum1 = bn1 + conv1 relu = mx.symbol.Activation(data=sum1, name='relu', act_type="relu") - sym = tail_symbol(relu) - return sym, conv_bn_add_relu_attr + return relu, conv_bn_add_relu_attr def tail_neg_symbol(sym1, sym2): fc1 = mx.sym.FullyConnected(data=sym1, num_hidden=10, flatten=True, name='fc1') @@ -223,11 +230,11 @@ def tail_neg_symbol(sym1, sym2): # | # | # -------------> [custom op] -def neg_conv_bn(): +def neg_conv_bn(data_shape): syms = [] attrs = [] excluded_attrs = [] - data, weight = head_symbol() + data, weight = head_symbol(data_shape) # eg.1 ([custom op] = pool) conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) @@ -246,11 +253,11 @@ def neg_conv_bn(): # | # | # ---------------> [custom op] -def neg_conv_relu(): +def neg_conv_relu(data_shape): syms = [] attrs = [] excluded_attrs = [] - data, weight = head_symbol() + data, weight = head_symbol(data_shape) # eg.1 ([custom op] = pool) conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', 
num_filter=64, kernel=(3, 3), stride=(1, 1)) @@ -272,12 +279,12 @@ def neg_conv_relu(): # | # | # added ------------> -def neg_conv_add(): +def neg_conv_add(data_shape): syms = [] attrs = [] excluded_attrs = [] val = mx.symbol.Variable('addval') - data, weight = head_symbol() + data, weight = head_symbol(data_shape) # eg.1 ([custom op] = pool, [added op] = val) conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) @@ -300,11 +307,11 @@ def neg_conv_add(): # --------------> [custom op] # | # conv -----------> bn -----------> relu -def neg_conv_bn_relu(): +def neg_conv_bn_relu(data_shape): syms = [] attrs = [] excluded_attrs = [] - data, weight = head_symbol() + data, weight = head_symbol(data_shape) # eg.1 ([custom op] = pool11) conv11 = mx.symbol.Convolution(data=data, weight=weight, name='conv11', num_filter=64, kernel=(3, 3), stride=(1, 1)) @@ -344,12 +351,12 @@ def neg_conv_bn_relu(): # --------------> [custom op] # | # conv -----------> bn -----------> add -----------> relu -def neg_conv_bn_add_relu(): +def neg_conv_bn_add_relu(data_shape): syms = [] attrs = [] excluded_attrs = [] addVal = mx.symbol.Variable('addval') - data, weight = head_symbol() + data, weight = head_symbol(data_shape) # eg.1 conv11 = mx.symbol.Convolution(data=data, weight=weight, name='conv11', num_filter=64, kernel=(3, 3), stride=(1, 1)) @@ -390,83 +397,91 @@ def neg_conv_bn_add_relu(): @with_seed() def test_pos_single_conv(): - for data_shape, label_shape in zip(DATA_SHAPE, DATA_LABEL): - net, attrs = single_conv(False) - check_fusion(net, data_shape, label_shape, attrs) - net, attrs = single_conv(True) - check_fusion(net, data_shape, label_shape, attrs) + for data_shape in DATA_SHAPE: + net, attrs = single_conv(False, data_shape) + check_fusion(net, data_shape, attrs) + net, attrs = single_conv(True, data_shape) + check_fusion(net, data_shape, attrs) @with_seed() def test_pos_conv_relu(): - for data_shape, label_shape in zip(DATA_SHAPE, DATA_LABEL): - net, attrs = conv_relu(False) - check_fusion(net, data_shape, label_shape, attrs) - net, attrs = conv_relu(True) - check_fusion(net, data_shape, label_shape, attrs) + for data_shape in DATA_SHAPE: + net, attrs = conv_relu(False, data_shape) + check_fusion(net, data_shape, attrs) + net, attrs = conv_relu(True, data_shape) + check_fusion(net, data_shape, attrs) @with_seed() def test_pos_conv_bn(): - for data_shape, label_shape in zip(DATA_SHAPE, DATA_LABEL): - net, attrs = conv_bn(False) - check_fusion(net, data_shape, label_shape, attrs) - net, attrs = conv_bn(True) - check_fusion(net, data_shape, label_shape, attrs) + for data_shape in DATA_SHAPE: + net, attrs = conv_bn(False, data_shape) + check_fusion(net, data_shape, attrs) + net, attrs = conv_bn(True, data_shape) + check_fusion(net, data_shape, attrs) @with_seed() def test_pos_conv_add(): - for data_shape, label_shape in zip(DATA_SHAPE, DATA_LABEL): - net, attrs = conv_add(False) - check_fusion(net, data_shape, label_shape, attrs) - net, attrs = conv_add(True) - check_fusion(net, data_shape, label_shape, attrs) + for data_shape in DATA_SHAPE: + net, attrs = conv_add(False, data_shape) + check_fusion(net, data_shape, attrs) + net, attrs = conv_add(True, data_shape) + check_fusion(net, data_shape, attrs) + +@with_seed() +def test_pos_conv_add2(): + for data_shape in DATA_SHAPE: + net, attrs = conv_add2(False, data_shape) + check_fusion(net, data_shape, attrs) + net, attrs = conv_add2(True, data_shape) + check_fusion(net, data_shape, attrs) @with_seed() 
def test_pos_conv_bn_relu(): - for data_shape, label_shape in zip(DATA_SHAPE, DATA_LABEL): - net, attrs = conv_bn_relu(False) - check_fusion(net, data_shape, label_shape, attrs) - net, attrs = conv_bn_relu(True) - check_fusion(net, data_shape, label_shape, attrs) + for data_shape in DATA_SHAPE: + net, attrs = conv_bn_relu(False, data_shape) + check_fusion(net, data_shape, attrs) + net, attrs = conv_bn_relu(True, data_shape) + check_fusion(net, data_shape, attrs) @with_seed() def test_pos_conv_bn_sum_relu(): - for data_shape, label_shape in zip(DATA_SHAPE, DATA_LABEL): - net, attrs = conv_bn_sum_relu(False) - check_fusion(net, data_shape, label_shape, attrs) - net, attrs = conv_bn_sum_relu(True) - check_fusion(net, data_shape, label_shape, attrs) + for data_shape in DATA_SHAPE: + net, attrs = conv_bn_sum_relu(False, data_shape) + check_fusion(net, data_shape, attrs) + net, attrs = conv_bn_sum_relu(True, data_shape) + check_fusion(net, data_shape, attrs) @with_seed() def test_neg_conv_bn(): for data_shape in DATA_SHAPE: - syms, attrs, excluded_attrs = neg_conv_bn() + syms, attrs, excluded_attrs = neg_conv_bn(data_shape) check_neg_fusion(syms, attrs, excluded_attrs, data_shape) @with_seed() def test_neg_conv_relu(): for data_shape in DATA_SHAPE: - syms, attrs, excluded_attrs = neg_conv_relu() + syms, attrs, excluded_attrs = neg_conv_relu(data_shape) check_neg_fusion(syms, attrs, excluded_attrs, data_shape) @with_seed() def test_neg_conv_add(): for data_shape in DATA_SHAPE: - syms, attrs, excluded_attrs = neg_conv_add() + syms, attrs, excluded_attrs = neg_conv_add(data_shape) check_neg_fusion(syms, attrs, excluded_attrs, data_shape) @with_seed() def test_neg_conv_bn_relu(): for data_shape in DATA_SHAPE: - syms, attrs, excluded_attrs = neg_conv_bn_relu() + syms, attrs, excluded_attrs = neg_conv_bn_relu(data_shape) check_neg_fusion(syms, attrs, excluded_attrs, data_shape) @with_seed() def test_neg_conv_bn_add_relu(): for data_shape in DATA_SHAPE: - syms, attrs, excluded_attrs = neg_conv_bn_add_relu() + syms, attrs, excluded_attrs = neg_conv_bn_add_relu(data_shape) check_neg_fusion(syms, attrs, excluded_attrs, data_shape) if __name__ == "__main__": import nose - nose.runmodule() \ No newline at end of file + nose.runmodule() From 8da56c86453d2a0311cf401ae679f807ae76dd22 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 9 Oct 2018 09:08:37 +0800 Subject: [PATCH 28/28] Fix ci test --- src/executor/graph_executor.cc | 29 +++++------------------------ 1 file changed, 5 insertions(+), 24 deletions(-) diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 84a51aa9381f..ed394234c525 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -1499,7 +1499,6 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, // This is for simple_bind flow. 
static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, const std::string& prop_name, - std::vector *in_args, const std::unordered_map& arg_shape_map, const std::unordered_map& arg_dtype_map, const std::unordered_map& arg_stype_map, @@ -1508,8 +1507,6 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, const std::vector& in_arg_ctxes, const std::vector& aux_state_ctxes) { const std::vector input_names = src.ListInputNames(Symbol::kAll); - const std::vector arg_names = src.ListInputNames(nnvm::Symbol::kReadOnlyArgs); - CHECK_EQ(arg_names.size(), in_args->size()); nnvm::ShapeVector arg_shapes(input_names.size(), TShape()); nnvm::DTypeVector arg_dtypes(input_names.size(), -1); StorageTypeVector arg_stypes(input_names.size(), kUndefinedStorage); @@ -1527,24 +1524,8 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, arg_stypes[i] = it3->second; } } - // setup in_args_map - std::unordered_map in_args_map; - for (size_t i = 0; i < in_args->size(); ++i) { - in_args_map[arg_names[i]] = in_args->at(i); - } - auto result = PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes, default_ctx, - ctx_map, in_arg_ctxes, aux_state_ctxes); - // Reorder in_args into new_in_args according to partitioned symbol input sequence - std::vector new_in_args(in_args->size()); - // get new symbol in_arg names - std::vector new_arg_names = result.ListInputNames(nnvm::Symbol::kReadOnlyArgs); - CHECK_EQ(arg_names.size(), new_arg_names.size()); - in_args->clear(); - for (auto arg_name : new_arg_names) { - CHECK(in_args_map.count(arg_name)); - in_args->push_back(in_args_map[arg_name]); - } - return result; + return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes, + default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes); } // Given input ndarrays, partition the graph using the backend name equal to prop_name. @@ -1627,9 +1608,9 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol, Executor* shared_exec) { auto exec = new exec::GraphExecutor(); if (!exec->subgraph_property().empty()) { - symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), in_args, arg_shape_map, - arg_dtype_map, arg_stype_map, default_ctx, group2ctx, - in_arg_ctxes, aux_state_ctxes); + symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), arg_shape_map, arg_dtype_map, + arg_stype_map, default_ctx, group2ctx, in_arg_ctxes, + aux_state_ctxes); } exec->Init(symbol, default_ctx, group2ctx, in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,