diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 65ca481e694f..81b20d4b092d 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -250,6 +250,7 @@ List of Contributors
 * [dithyrambe](https://github.com/dithyrambe)
 * [Piljae Chae](https://github.com/IHateMint)
 * [Oliver Kowalke](https://github.com/olk)
+* [Connor Goggins](https://github.com/connorgoggins)
 
 Label Bot
 ---------
@@ -260,6 +261,6 @@ Label Bot
 - @mxnet-label-bot remove [specify comma separated labels here]
 - @mxnet-label-bot update [specify comma separated labels here]
   (i.e. @mxnet-label-bot update [Bug, Python])
- 
+
 - Available label names which are supported: [Labels](https://github.com/apache/incubator-mxnet/labels)
 - For further details: [My Wiki Page](https://cwiki.apache.org/confluence/display/MXNET/Machine+Learning+Based+GitHub+Bot)
diff --git a/benchmark/opperf/nd_operations/nn_basic_operators.py b/benchmark/opperf/nd_operations/nn_basic_operators.py
index 9a34a9a725ee..a8273d4105dc 100644
--- a/benchmark/opperf/nd_operations/nn_basic_operators.py
+++ b/benchmark/opperf/nd_operations/nn_basic_operators.py
@@ -16,71 +16,61 @@
 # under the License.
 
 import mxnet as mx
-from benchmark.opperf.utils.benchmark_utils import run_performance_test
-from benchmark.opperf.utils.common_utils import merge_map_list
-from benchmark.opperf.rules.default_params import MX_OP_MODULE
+
+from benchmark.opperf.utils.op_registry_utils import get_all_nn_basic_operators
+from benchmark.opperf.utils.benchmark_utils import run_op_benchmarks
 
 """Performance benchmark tests for MXNet NDArray basic NN Operators.
 
 1. FullyConnected
 2. Dropout
 3. BatchNorm
+4. SoftmaxOutput
+5. LinearRegressionOutput
+6. LogisticRegressionOutput
+7. MAERegressionOutput
+8. SVMOutput
+9. L2Normalization
+10. LayerNorm
+11. InstanceNorm
+12. Embedding
+13. Correlation
+14. SpatialTransformer
+15. im2col
+16. col2im
+17. GroupNorm
+18. RNN
+19. LRN
 
 """
 
 
 def run_nn_basic_operators_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='native', warmup=25, runs=100):
-    # FullyConnnected operator benchmarks
-    fc_benchmark_res = run_performance_test([getattr(MX_OP_MODULE, "FullyConnected")],
-                                            run_backward=True,
-                                            dtype=dtype,
-                                            ctx=ctx,
-                                            profiler=profiler,
-                                            inputs=[{"data": (32, 3, 256, 256),
-                                                     "num_hidden": 64,
-                                                     "weight": (64, 3 * 256 * 256),
-                                                     "bias": (64,),
-                                                     "flatten": True},
-                                                    {"data": (32, 3, 256, 256),
-                                                     "num_hidden": 64,
-                                                     "weight": (64, 256),
-                                                     "bias": (64,),
-                                                     "flatten": False}],
-                                            warmup=warmup,
-                                            runs=runs)
+    """Runs benchmarks with the given context and precision (dtype) for all the NN basic
+    operators in MXNet.
+
+    Parameters
+    ----------
+    ctx: mx.ctx
+        Context to run benchmarks
+    dtype: str, default 'float32'
+        Precision to use for benchmarks
+    profiler: str, default 'native'
+        Module to use for tracking benchmark execution time
+    warmup: int, default 25
+        Number of times to run for warmup
+    runs: int, default 100
+        Number of runs to capture benchmark results
+
+    Returns
+    -------
+    Dictionary of results. Key -> Name of the operator, Value -> Benchmark results.
+
+    """
-    # Dropout benchmarks
-    dropout_benchmark_res = run_performance_test([getattr(MX_OP_MODULE, "Dropout")],
-                                                 run_backward=True,
-                                                 dtype=dtype,
-                                                 ctx=ctx,
-                                                 profiler=profiler,
-                                                 inputs=[{"data": (32, 3, 256, 256),
-                                                          "p": 0.5,
-                                                          "mode": "always"},
-                                                         {"data": (10000, 10),
-                                                          "p": 0.5,
-                                                          "mode": "always"}],
-                                                 warmup=warmup,
-                                                 runs=runs)
-    # BatchNorm benchmarks
-    batchnorm_benchmark_res = run_performance_test([getattr(MX_OP_MODULE, "BatchNorm")],
-                                                   run_backward=True,
-                                                   dtype=dtype,
-                                                   ctx=ctx,
-                                                   profiler=profiler,
-                                                   inputs=[{"data": (32, 3, 256, 256),
-                                                            "gamma": (3,),
-                                                            "beta": (3,),
-                                                            "moving_mean": (3,),
-                                                            "moving_var": (3,)},
-                                                           {"data": (32, 3, 10000, 10),
-                                                            "gamma": (3,),
-                                                            "beta": (3,),
-                                                            "moving_mean": (3,),
-                                                            "moving_var": (3,)}],
-                                                   warmup=warmup,
-                                                   runs=runs)
-    # Prepare combined results
-    mx_basic_nn_results = merge_map_list(fc_benchmark_res + dropout_benchmark_res + batchnorm_benchmark_res)
-    return mx_basic_nn_results
+    # Fetch all NN Basic Operators
+    mx_nn_basic_ops = get_all_nn_basic_operators()
+
+    # Run benchmarks
+    mx_nn_basic_op_results = run_op_benchmarks(mx_nn_basic_ops, dtype, ctx, profiler, warmup, runs)
+    return mx_nn_basic_op_results
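For reviewers who want to exercise the refactored entry point, a minimal usage sketch follows. It assumes the script is run from the incubator-mxnet repository root (so the `benchmark` package is importable) and that the merged result is a dict keyed by operator name, as the new docstring states.

```python
# Hedged usage sketch for the refactored entry point; not part of the PR.
import mxnet as mx

from benchmark.opperf.nd_operations.nn_basic_operators import run_nn_basic_operators_benchmarks

# Run all 19 NN basic operator benchmarks on CPU at float32 precision.
results = run_nn_basic_operators_benchmarks(ctx=mx.cpu(), dtype='float32',
                                            profiler='native', warmup=25, runs=100)

# Per the docstring: key -> operator name, value -> benchmark results.
for op_name, op_results in results.items():
    print(op_name, op_results)
```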
diff --git a/benchmark/opperf/rules/default_params.py b/benchmark/opperf/rules/default_params.py
index 31940da8eb77..15bcd72b0553 100644
--- a/benchmark/opperf/rules/default_params.py
+++ b/benchmark/opperf/rules/default_params.py
@@ -81,6 +81,92 @@
 # NOTE: Data used is DEFAULT_DATA
 DEFAULT_AXIS = [0]
 
+# For NN basic operators
+# General
+DEFAULT_DATA_NN_BASIC = [(32, 3, 256, 256), (32, 3, 10000, 10)]
+DEFAULT_NUM_HIDDEN = [64]
+DEFAULT_BIAS = [(64,)]
+DEFAULT_FLATTEN = [True, False]
+DEFAULT_GAMMA = [(3,)]
+DEFAULT_BETA = [(3,)]
+DEFAULT_MOVING_MEAN = [(3,)]
+DEFAULT_MOVING_VAR = [(3,)]
+DEFAULT_LABEL_REG = [(32, 3, 256, 256), (32, 3, 10000, 10)]
+DEFAULT_GRAD_SCALE = [.5]
+DEFAULT_NORMALIZATION = ["batch"]
+DEFAULT_MARGIN = [.5]
+DEFAULT_REG_COEFF = [.5]
+DEFAULT_INPUT_DIM = [3, 16]
+DEFAULT_OUTPUT_DIM = [4, 9]
+DEFAULT_SPARSE_GRAD = [False]
+DEFAULT_KERNEL_SIZE = [3]
+DEFAULT_MAX_DISPLACEMENT = [2]
+DEFAULT_STRIDE_1 = [2]
+DEFAULT_STRIDE_2 = [2]
+DEFAULT_ALPHA = [.001]
+DEFAULT_NSIZE = [3]
+DEFAULT_PARAMETERS = [(7,), (104,)]
+DEFAULT_STATE = [(1, 4, 1), (2, 10000, 4)]
+DEFAULT_MODE = ["rnn_relu", "rnn_tanh"]
+DEFAULT_STATE_SIZE = [1, 4]
+DEFAULT_NUM_LAYERS = [1, 2]
+DEFAULT_NUM_GROUPS = [1, 10]
+DEFAULT_TRANSFORM = ["affine"]
+DEFAULT_SAMPLER = ["bilinear"]
+DEFAULT_DILATE = [(1,), (1, 1)]
+DEFAULT_PAD = [(1,), (1, 1)]
+DEFAULT_OUTPUT_SIZE = [(64, 16, 1), (32, 8, 1)]
+DEFAULT_KERNEL = [(1, 1, 1), (1, 1, 1)]
+DEFAULT_STRIDE = [(2, 2, 2), (1, 1, 1)]
+
+# BatchNorm
+DEFAULT_AXIS_BN = [1]
+
+# LayerNorm
+DEFAULT_GAMMA_LN = [(32,), (32,)]
+DEFAULT_BETA_LN = [(32,), (32,)]
+
+# L2Normalization
+DEFAULT_MODE_L2 = ['channel', 'instance', 'spatial']
+
+# SVMOutput
+DEFAULT_LABEL_SVM = [(32, 3, 256), (32, 3, 10000)]
+
+# SoftmaxOutput
+DEFAULT_LABEL_SM = [(32, 3, 256), (32, 3, 10000)]
+
+# FullyConnected
+DEFAULT_WEIGHT_FC = [(64, 3 * 256 * 256), (64, 10)]
+
+# Embedding
+DEFAULT_WEIGHT_EMBEDDING = [(3, 4), (16, 9)]
+
+# GroupNorm
+DEFAULT_DATA_GN = [(32, 3, 256, 256), (32, 10, 10000, 10)]
+DEFAULT_BETA_GAMMA_GN = [(1,), (10,)]
+
+# Dropout
+DEFAULT_DATA_DROPOUT = [(32, 3, 256, 256), (10000, 10)]
+DEFAULT_MODE_DROPOUT = ["always"]
+
+# SpatialTransformer
+DEFAULT_DATA_ST = [(32, 3, 256, 6), (256, 3, 10000, 6)]
+DEFAULT_LOC_TAR_ST = [(32, 6), (256, 6)]
+
+# im2col
+DEFAULT_KERNEL_I2C = [(3,), (3, 3)]
+DEFAULT_STRIDE_I2C = [(1,), (1, 1)]
+
+# col2im
+DEFAULT_DATA_C2I = [(32, 64, 256), (32, 64, 256)]
+
+# RNN
+DEFAULT_DATA_RNN = [(32, 4, 4), (512, 10000, 10)]
+DEFAULT_P_RNN = [.5]
+
+# LRN
+DEFAULT_BETA_LRN = [.2]
+
 # For optimizer operators
 DEFAULT_WEIGHT = [(1024, 1024), (10000, 1), (10000, 100)]
 DEFAULT_GRAD = [(1024, 1024), (10000, 1), (10000, 100)]
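To make the role of these shape lists concrete: opperf treats each tuple in a `DEFAULT_*` list as a tensor shape and materializes one random NDArray per shape when building operator inputs. The sketch below illustrates that convention; `materialize` is a hypothetical helper written only for this example, not part of opperf.

```python
# Illustrative only: how a DEFAULT_* shape list is consumed downstream.
import mxnet as mx

DEFAULT_DATA_NN_BASIC = [(32, 3, 256, 256), (32, 3, 10000, 10)]

def materialize(shapes, dtype='float32', ctx=mx.cpu()):
    # One random input tensor per default shape.
    return [mx.nd.random.normal(shape=shape, dtype=dtype, ctx=ctx) for shape in shapes]

inputs = materialize(DEFAULT_DATA_NN_BASIC)
print([x.shape for x in inputs])  # [(32, 3, 256, 256), (32, 3, 10000, 10)]
```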
@@ -267,7 +353,85 @@
                   "a": DEFAULT_A,
                   "lhs_fill_element_0index": DEFAULT_LHS_FEI,
                   "rhs_fill_element_0index": DEFAULT_RHS_FEI,
-                  "mhs": DEFAULT_MHS}
+                  "mhs": DEFAULT_MHS,
+                  "data_spatialtransformer": DEFAULT_DATA_ST,
+                  "loc_spatialtransformer": DEFAULT_LOC_TAR_ST,
+                  "target_shape": DEFAULT_LOC_TAR_ST,
+                  "transform_type_spatialtransformer": DEFAULT_TRANSFORM,
+                  "sampler_type": DEFAULT_SAMPLER,
+                  "data_col2im": DEFAULT_DATA_C2I,
+                  "output_size": DEFAULT_OUTPUT_SIZE,
+                  "kernel_col2im": DEFAULT_KERNEL,
+                  "stride_col2im": DEFAULT_STRIDE,
+                  "data_rnn": DEFAULT_DATA_RNN,
+                  "p_rnn": DEFAULT_P_RNN,
+                  "parameters": DEFAULT_PARAMETERS,
+                  "state": DEFAULT_STATE,
+                  "state_size": DEFAULT_STATE_SIZE,
+                  "num_layers": DEFAULT_NUM_LAYERS,
+                  "mode_rnn": DEFAULT_MODE,
+                  "data_groupnorm": DEFAULT_DATA_GN,
+                  "gamma_groupnorm": DEFAULT_BETA_GAMMA_GN,
+                  "beta_groupnorm": DEFAULT_BETA_GAMMA_GN,
+                  "num_groups": DEFAULT_NUM_GROUPS,
+                  "eps": DEFAULT_EPSILON,
+                  "data_dropout": DEFAULT_DATA_DROPOUT,
+                  "mode_dropout": DEFAULT_MODE_DROPOUT,
+                  "p_dropout": DEFAULT_P,
+                  "data_nn_basic": DEFAULT_DATA_NN_BASIC,
+                  "num_hidden": DEFAULT_NUM_HIDDEN,
+                  "data_fullyconnected": DEFAULT_DATA_NN_BASIC,
+                  "weight_fullyconnected": DEFAULT_WEIGHT_FC,
+                  "weight_embedding": DEFAULT_WEIGHT_EMBEDDING,
+                  "bias": DEFAULT_BIAS,
+                  "flatten": DEFAULT_FLATTEN,
+                  "data_batchnorm": DEFAULT_DATA_NN_BASIC,
+                  "gamma_batchnorm": DEFAULT_GAMMA,
+                  "beta_batchnorm": DEFAULT_BETA,
+                  "moving_mean_batchnorm": DEFAULT_MOVING_MEAN,
+                  "moving_var_batchnorm": DEFAULT_MOVING_VAR,
+                  "axis_batchnorm": DEFAULT_AXIS_BN,
+                  "data_softmaxoutput": DEFAULT_DATA_NN_BASIC,
+                  "label_softmaxoutput": DEFAULT_LABEL_SM,
+                  "data_maeregressionoutput": DEFAULT_DATA_NN_BASIC,
+                  "label_maeregressionoutput": DEFAULT_LABEL_REG,
+                  "data_logisticregressionoutput": DEFAULT_DATA_NN_BASIC,
+                  "label_logisticregressionoutput": DEFAULT_LABEL_REG,
+                  "data_linearregressionoutput": DEFAULT_DATA_NN_BASIC,
+                  "label_linearregressionoutput": DEFAULT_LABEL_REG,
+                  "data_svmoutput": DEFAULT_DATA_NN_BASIC,
+                  "label_svmoutput": DEFAULT_LABEL_SVM,
+                  "grad_scale": DEFAULT_GRAD_SCALE,
+                  "normalization": DEFAULT_NORMALIZATION,
+                  "margin": DEFAULT_MARGIN,
+                  "regularization_coefficient": DEFAULT_REG_COEFF,
+                  "data_l2normalization": DEFAULT_DATA_NN_BASIC,
+                  "mode_l2normalization": DEFAULT_MODE_L2,
+                  "gamma_layernorm": DEFAULT_GAMMA_LN,
+                  "beta_layernorm": DEFAULT_BETA_LN,
+                  "data_instancenorm": DEFAULT_DATA_NN_BASIC,
+                  "gamma_instancenorm": DEFAULT_GAMMA,
+                  "beta_instancenorm": DEFAULT_BETA,
+                  "input_dim": DEFAULT_INPUT_DIM,
+                  "output_dim": DEFAULT_OUTPUT_DIM,
+                  "sparse_grad": DEFAULT_SPARSE_GRAD,
+                  "data1": DEFAULT_DATA_NN_BASIC,
+                  "data2": DEFAULT_DATA_NN_BASIC,
+                  "kernel_size": DEFAULT_KERNEL_SIZE,
+                  "max_displacement": DEFAULT_MAX_DISPLACEMENT,
+                  "stride1": DEFAULT_STRIDE_1,
+                  "stride2": DEFAULT_STRIDE_2,
+                  "data_im2col": DEFAULT_DATA_NN_BASIC,
+                  "kernel_im2col": DEFAULT_KERNEL_I2C,
+                  "stride_im2col": DEFAULT_STRIDE_I2C,
+                  "dilate_im2col": DEFAULT_DILATE,
+                  "pad_im2col": DEFAULT_PAD,
+                  "data_lrn": DEFAULT_DATA_NN_BASIC,
+                  "alpha_lrn": DEFAULT_ALPHA,
+                  "beta_lrn": DEFAULT_BETA_LRN,
+                  "nsize": DEFAULT_NSIZE,
+                  "data_layernorm": DEFAULT_DATA_NN_BASIC,
+                  "axis_layernorm": DEFAULT_AXIS}
 
 # These are names of MXNet operator parameters that is of type NDArray.
@@ -282,4 +446,4 @@
                           "v", "z", "g", "delta", "args", "indices", "shape_like", "y",
                           "x", "condition", "a", "index", "raveL_data", "label", "grid",
                           "A", "B", "C", "r1", "r2", "rois", "lrs", "wds", "weights_sum_sq",
-                          "grads_sum_sq", "mhs"]
+                          "grads_sum_sq", "mhs", "data1", "data2", "loc", "parameters", "state"]
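The mapping above relies on an operator-suffixed key convention: a parameter first resolves through "<param>_<lowercased op name>" and only then falls back to the bare parameter name, which is how "data_batchnorm" and the generic "bias" can coexist. The lookup below is a sketch of that convention, mirroring (not copying) the resolution done in `prepare_op_inputs`.

```python
# Hypothetical sketch of the suffixed-key resolution used with DEFAULTS_TABLE.
def lookup_default(defaults_table, op_name, param_name):
    # Prefer the operator-specific entry, e.g. "axis_batchnorm" for
    # BatchNorm's "axis"; fall back to the generic parameter name.
    specific_key = "{}_{}".format(param_name, op_name.lower())
    if specific_key in defaults_table:
        return defaults_table[specific_key]
    return defaults_table.get(param_name)

# BatchNorm's "axis" resolves to DEFAULT_AXIS_BN via "axis_batchnorm",
# while an op with no specific entry falls back to "axis" -> DEFAULT_AXIS.
```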
diff --git a/benchmark/opperf/utils/benchmark_utils.py b/benchmark/opperf/utils/benchmark_utils.py
index 29223ff40aa9..f6cdfe004215 100644
--- a/benchmark/opperf/utils/benchmark_utils.py
+++ b/benchmark/opperf/utils/benchmark_utils.py
@@ -26,7 +26,7 @@
 from benchmark.opperf.rules.default_params import PARAMS_OF_TYPE_NDARRAY
 from .profiler_utils import cpp_profile, python_profile
 
-no_backward = ['gather_nd', 'softmax_cross_entropy', 'linalg_gelqf', 'linalg_slogdet', 'moments', 'SequenceLast']
+no_backward = {'gather_nd', 'softmax_cross_entropy', 'linalg_gelqf', 'linalg_slogdet', 'moments', 'SequenceLast', 'Embedding'}
 
 def _prepare_op_inputs(inputs, run_backward, dtype, ctx):
     mx.random.seed(41)
@@ -163,6 +163,8 @@
     -------
     List of dictionary of benchmark results. key -> name of the operator, Value is benchmark results.
 
+    Note: when run_performance_test is called on the nd.Embedding operator with run_backward=True, an error will
+    be thrown. Track the issue here: https://github.com/apache/incubator-mxnet/issues/11314
     """
     kwargs_list = _prepare_op_inputs(inputs, run_backward, dtype, ctx)
 
@@ -180,24 +182,33 @@
 
 
 def run_op_benchmarks(ops, dtype, ctx, profiler, warmup, runs):
+    # Running SoftmaxOutput backwards on GPU results in errors
+    # Track the issue here: https://github.com/apache/incubator-mxnet/issues/880
+    gpu_backwards_disabled_ops = ['SoftmaxOutput']
+
+    # Running im2col either forwards or backwards on GPU results in errors
+    # Track the issue here: https://github.com/apache/incubator-mxnet/issues/17493
+    gpu_disabled_ops = ['im2col']
+
     # For each operator, run benchmarks
     mx_op_benchmark_results = []
     for op, op_params in ops.items():
-        # Prepare inputs for the operator
-        inputs = prepare_op_inputs(op, op_params)
-
-        # setting backward false for ops with known issue
-        if op in no_backward:
-            op_params["has_backward"] = False
-
-        # Run benchmarks
-        cur_op_res = run_performance_test(op_params["nd_op_handle"],
-                                          run_backward=op_params["has_backward"],
-                                          dtype=dtype, ctx=ctx,
-                                          profiler=profiler,
-                                          inputs=inputs,
-                                          warmup=warmup, runs=runs)
-        mx_op_benchmark_results += cur_op_res
+        if ctx == mx.cpu() or op not in gpu_disabled_ops:
+            # Prepare inputs for the operator
+            inputs = prepare_op_inputs(op, op_params)
+
+            # Set backward to False for ops with a known issue
+            if (ctx == mx.gpu() and op in gpu_backwards_disabled_ops) or op in no_backward:
+                op_params["has_backward"] = False
+
+            # Run benchmarks
+            cur_op_res = run_performance_test(op_params["nd_op_handle"],
+                                              run_backward=op_params["has_backward"],
+                                              dtype=dtype, ctx=ctx,
+                                              profiler=profiler,
+                                              inputs=inputs,
+                                              warmup=warmup, runs=runs)
+            mx_op_benchmark_results += cur_op_res
 
     # Prepare combined results for all operators
     mx_op_benchmark_results = merge_map_list(mx_op_benchmark_results)
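Isolated from the loop above, the new gating reduces each operator to one of three outcomes: skipped entirely on GPU, run forward-only, or run forward and backward. A self-contained sketch of that logic (sets abbreviated; the `plan` helper is illustrative, not opperf code):

```python
import mxnet as mx

no_backward = {'gather_nd', 'Embedding'}          # abbreviated from the diff
gpu_backwards_disabled_ops = {'SoftmaxOutput'}
gpu_disabled_ops = {'im2col'}

def plan(op, ctx):
    # Mirrors the gating in run_op_benchmarks above.
    if ctx != mx.cpu() and op in gpu_disabled_ops:
        return 'skip'
    if (ctx == mx.gpu() and op in gpu_backwards_disabled_ops) or op in no_backward:
        return 'forward-only'
    return 'forward+backward'

assert plan('im2col', mx.gpu()) == 'skip'
assert plan('SoftmaxOutput', mx.gpu()) == 'forward-only'
assert plan('SoftmaxOutput', mx.cpu()) == 'forward+backward'
assert plan('Embedding', mx.cpu()) == 'forward-only'
```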
diff --git a/benchmark/opperf/utils/op_registry_utils.py b/benchmark/opperf/utils/op_registry_utils.py
index de7ad4dcc93f..99678b8d31a9 100644
--- a/benchmark/opperf/utils/op_registry_utils.py
+++ b/benchmark/opperf/utils/op_registry_utils.py
@@ -119,8 +119,11 @@ def prepare_op_inputs(op, arg_params):
     ops_3d = {'CTCLoss', 'ctc_loss'}
 
     # For ops with args that need to change shape/value for different ops
-    custom_data = {'Activation', 'LeakyReLU', 'Softmax', 'BilinearSampler', 'GridGenerator',
-                   'sample_multinomial', 'linalg_maketrian', 'squeeze', 'fill_element_0index'}
+    custom_data = {'Activation', 'LeakyReLU', 'Softmax', 'BilinearSampler', 'GridGenerator', 'sample_multinomial', 'linalg_maketrian',
+                   'SpatialTransformer', 'col2im', 'RNN', 'GroupNorm', 'Dropout', 'FullyConnected',
+                   'SoftmaxOutput', 'LinearRegressionOutput', 'BatchNorm', 'LogisticRegressionOutput',
+                   'MAERegressionOutput', 'SVMOutput', 'L2Normalization', 'LayerNorm', 'InstanceNorm',
+                   'Embedding', 'Correlation', 'im2col', 'LRN', 'squeeze', 'fill_element_0index'}
 
     int_only = {'random_randint'}
     float_only = {'log_softmax', 'softmax', 'softmin'}
@@ -327,6 +330,27 @@
         reduction_mx_operators[op_name] = mx_operators[op_name]
     return reduction_mx_operators
 
+def get_all_nn_basic_operators():
+    """Gets all NN basic operators registered with MXNet.
+
+    Returns
+    -------
+    {"operator_name": {"has_backward", "nd_op_handle", "params"}}
+    """
+    nn_basic_ops = ['FullyConnected', 'Dropout', 'BatchNorm', 'SoftmaxOutput', 'LinearRegressionOutput',
+                    'LogisticRegressionOutput', 'MAERegressionOutput', 'SVMOutput', 'L2Normalization',
+                    'LayerNorm', 'InstanceNorm', 'Embedding', 'Correlation', 'SpatialTransformer', 'im2col',
+                    'col2im', 'GroupNorm', 'RNN', 'LRN']
+
+    # Get all mxnet operators
+    mx_operators = _get_all_mxnet_operators()
+
+    # Filter for NN Basic operators
+    nn_basic_mx_operators = {}
+    for op_name, _ in mx_operators.items():
+        if op_name in nn_basic_ops:
+            nn_basic_mx_operators[op_name] = mx_operators[op_name]
+    return nn_basic_mx_operators
 
 def get_all_nn_activation_operators():
     """Gets all NN Activation operators registered with MXNet.
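Finally, a short usage sketch for the new registry helper, assuming the package layout in this PR; the metadata keys follow the helper's docstring.

```python
# Hedged sketch: inspecting what get_all_nn_basic_operators() returns.
from benchmark.opperf.utils.op_registry_utils import get_all_nn_basic_operators

nn_basic_ops = get_all_nn_basic_operators()
print(sorted(nn_basic_ops))  # expect the 19 operator names listed above

# Per the docstring, each value carries the metadata run_op_benchmarks needs:
# a "has_backward" flag, the "nd_op_handle", and the registered "params".
for op_name, metadata in nn_basic_ops.items():
    print(op_name, metadata["has_backward"])
```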