From 4f0245a60caae6416109fce2ebd2b05e6210cffb Mon Sep 17 00:00:00 2001 From: Dick Carter Date: Fri, 28 Sep 2018 19:37:16 -0700 Subject: [PATCH 01/13] Bulked op seg size to ignore Variable nodes, limited by MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_{FWD,BWD}. --- src/executor/graph_executor.cc | 75 ++++++++++++++++------------------ 1 file changed, 35 insertions(+), 40 deletions(-) diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 8302dc133c64..70f8e9d4d020 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -1211,63 +1211,58 @@ void GraphExecutor::InitOpSegs() { void GraphExecutor::BulkTrainingOpSegs(size_t total_num_nodes) { - // The maximum number of node in a segment executed in bulk - size_t num_nodes_threshold = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15); + // The maximum number of nodes in a segment executed in bulk (excluding variables). + size_t segment_num_nodes_threshold = + dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15); + // The maximum number of nodes in a segment executed in bulk (excluding variables) in fwd pass. + size_t segment_num_nodes_threshold_fwd = + dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD", segment_num_nodes_threshold); + // The maximum number of nodes in a segment executed in bulk (excluding variables) in bwd pass. 
+ size_t segment_num_nodes_threshold_bwd = + dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD", segment_num_nodes_threshold); // create forward segments for training size_t topo_start = 0; + size_t segment_node_count = 0; for (size_t nid = 0; nid < num_forward_nodes_; nid++) { auto &node = graph_.indexed_graph()[nid].source; auto &op_node = op_nodes_[nid]; - // check if the segment relies on external input, or exceeds maxinum number of node, - // or requires async ops - if (node->is_variable() || nid - topo_start > num_nodes_threshold || - op_node.exec->exec_type() != ExecType::kSync) { - // create a new segment for the previous nodes if the current one cannot be bulked - cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid); + // Variables, such as learned weights, are ignored in the segment_node_count + bool ignore_node = node->is_variable(); + if (!ignore_node) + segment_node_count++; + bool can_bulk = ignore_node || op_node.exec->exec_type() == ExecType::kSync; + // check if we need to create the segment based on properties of this node + if (!can_bulk || nid == num_forward_nodes_ - 1 || + segment_node_count >= segment_num_nodes_threshold_fwd) { + // Create a new segment for the previous nodes- include also this node if it's bulkable + cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, can_bulk ? 
nid + 1 : nid); topo_start = nid + 1; + segment_node_count = 0; } } - // the last segment - if (topo_start != num_forward_nodes_) { - cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, num_forward_nodes_); - } // create backward segments for training - // get all gradient variables - std::unordered_set grad_vars; - for (auto &kv : grad_store_) { - grad_vars.insert(kv.second.var()); - } - auto &idx = graph_.indexed_graph(); topo_start = num_forward_nodes_; + segment_node_count = 0; for (size_t nid = num_forward_nodes_; nid < total_num_nodes; nid++) { + auto &node = graph_.indexed_graph()[nid].source; auto &op_node = op_nodes_[nid]; - if (op_node.skip_exec_node || op_node.exec == nullptr) { - continue; - } - if (idx[nid].source->is_variable() || nid - topo_start > num_nodes_threshold || - op_node.exec->exec_type() != ExecType::kSync) { - cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid); + // Variables, such as learned weights, are ignored in the segment_node_count and + // nodes that are not executed for various reasons. + bool ignore_node = node->is_variable() || op_node.skip_exec_node || op_node.exec == nullptr; + if (!ignore_node) + segment_node_count++; + bool can_bulk = ignore_node || op_node.exec->exec_type() == ExecType::kSync; + // check if we need to create the segment based on properties of this node + if (!can_bulk || nid == total_num_nodes - 1 || + segment_node_count >= segment_num_nodes_threshold_bwd) { + // Create a new segment for the previous nodes- include also this node if it's bulkable + cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, can_bulk ? 
nid + 1 : nid); topo_start = nid + 1; - } else { - // If it produces output gradient, don't include it in the segment - bool output_gradient = false; - for (auto &out_arr : op_node.exec->out_array) { - if (grad_vars.find(out_arr.var()) != grad_vars.end()) { - output_gradient = true; - } - } - if (output_gradient) { - cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid); - topo_start = nid + 1; - } + segment_node_count = 0; } } - // last segment for backward - if (topo_start < total_num_nodes) { - cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, total_num_nodes); - } } void GraphExecutor::BulkInferenceOpSegs() { From 54fd288c7a4bf59d37f793c26ef9a98ed40b0c40 Mon Sep 17 00:00:00 2001 From: Dick Carter Date: Wed, 20 Feb 2019 19:21:09 -0800 Subject: [PATCH 02/13] Document new env variables. Unify operation with imperative. --- docs/faq/env_var.md | 8 +++++++- include/mxnet/imperative.h | 3 ++- src/imperative/cached_op.h | 6 ++++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md index c35d4e5723a5..d1924ea9888c 100644 --- a/docs/faq/env_var.md +++ b/docs/faq/env_var.md @@ -115,7 +115,13 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 - If set to `1`, during training MXNet executes the computation graph as several subgraphs in bulk mode. * MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN - Values: Int ```(default=15)``` - - The maximum number of nodes in the subgraph executed in bulk during training(not inference). Setting this to a larger number may reduce the degree of parallelism for multi-GPU training. + - The maximum number of nodes in the subgraph executed in bulk during training (not inference). Setting this to a larger number may reduce the degree of parallelism for multi-GPU training. +* MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD + - Values: Int ```(default=value of MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN)``` + - The maximum number of nodes in the subgraph executed in bulk during training (not inference) in the forward pass. 
+* MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD + - Values: Int ```(default=value of MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN)``` + - The maximum number of nodes in the subgraph executed in bulk during training (not inference) in the backward pass. ## Control the Data Communication diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h index 7ea60df33028..566103982da7 100644 --- a/include/mxnet/imperative.h +++ b/include/mxnet/imperative.h @@ -135,7 +135,8 @@ class Imperative { /*! \brief make constructor protected. */ Imperative() { if (dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1)) { - backward_bulk_size_ = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15); + backward_bulk_size_ = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD", + dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15)); } } /*! \brief find the input/output ndarrays that are needed for backward */ diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index 3b173c8654a4..3fac1da2c34f 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -53,10 +53,12 @@ struct CachedOpConfig : public dmlc::Parameter { .set_default(2) .describe("Maximum number of operators that can be inlined."); DMLC_DECLARE_FIELD(forward_bulk_size) - .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15)) + .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD", + dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))) .describe("Segment size of bulk execution during forward pass."); DMLC_DECLARE_FIELD(backward_bulk_size) - .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15)) + .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD", + dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))) .describe("Segment size of bulk execution during backward pass."); DMLC_DECLARE_FIELD(data_indices) .set_default(nnvm::Tuple()) From f0e42966bacd4475b37087ea1ea7dd023f8f5aac Mon Sep 17 00:00:00 2001 From: Dick Carter Date: Sat, 23 Feb 2019 18:26:40 -0800 Subject: [PATCH 
03/13] Add timing-based tests of symbol and gluon op bulking. --- tests/python/gpu/test_gluon_gpu.py | 73 ++++++++++++++++++++++++++ tests/python/gpu/test_operator_gpu.py | 74 +++++++++++++++++++++++++++ tests/python/unittest/common.py | 45 ++++++++++++++++ 3 files changed, 192 insertions(+) diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 54bfcee47347..ae46ca7f266f 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -38,6 +38,7 @@ curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.insert(0, os.path.join(curr_path, '../unittest')) from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied +from common import test_in_separate_process from test_gluon import * from test_loss import * from test_gluon_rnn import * @@ -408,6 +409,78 @@ def tensor_size(big_tensor_bytes): # Evaluate model net(data_in).asnumpy() +# isolated execution bulking test function to be invoked with different env var settings +def _test_bulking_in_process(seed, time_per_iteration): + # Use flip since it's a simple function with same-sized I/O unlikely to ever be fused. 
+ class Flip(gluon.HybridBlock): + def __init__(self, **kwargs): + super(Flip, self).__init__(**kwargs) + + def hybrid_forward(self, F, x): + return F.flip(x, axis=0) + + def get_net(num_ops): + net = nn.HybridSequential() + with net.name_scope(): + for _ in range(num_ops): + net.add(Flip()) + return net + + data_shape = (10,) + num_ops = 1000 + num_iterations = 20 + + # build model + x = mx.ndarray.zeros(data_shape) + x.attach_grad() + dy = mx.ndarray.ones(data_shape) + net = get_net(num_ops) + net.hybridize(static_alloc=True, static_shape=True) + + # time a number of forward() and backward() executions after some warm-up iterations + warmups = 1 + for i in range(num_iterations+warmups): + with autograd.record(): + if i == warmups: + start = time.time() + y = net(x) + y.backward(dy) + x.grad.wait_to_read() + + time_per_iteration.value = (time.time() - start) / num_iterations + +@with_seed() +def test_bulking(): + # test case format: (max_fwd_segment_size, max_bwd_segment_size) + test_cases = [(0,0), (1,1), (15,0), (0,15), (15,15)] + times = {} + times_str = '' + for seg_sizes in test_cases: + # Create shared variable to return measured time from test process + time_per_iteration = mp.Manager().Value('d', 0.0) + test_in_separate_process(_test_bulking_in_process, + {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD' : seg_sizes[0], + 'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD' : seg_sizes[1]}, + time_per_iteration) + times[seg_sizes] = time_per_iteration.value + times_str += '\n runtime of (fwd,bwd) seg size max ({},{}) =\t{:.1f} msec'.format( + seg_sizes[0], seg_sizes[1], 1000.0 * times[seg_sizes]) + + fastest_non_bulked_time = min(times[(0,0)], times[(1,1)]) + slowest_half_bulked_time = max(times[(0,15)], times[(15,0)]) + fastest_half_bulked_time = min(times[(0,15)], times[(15,0)]) + fully_bulked_time = times[(15,15)] + + # The non-bulked times[0,0] and times[1,1] should be about the same, + # slower than both half-bulked times[0,15] and times[15,0] + assert 
slowest_half_bulked_time < fastest_non_bulked_time, \ + 'A half-bulked exec time is slower than the non-bulked time by {} secs! {}' \ + .format(slowest_half_bulked_time - fastest_non_bulked_time, times_str) + # The fully bulked time[15,15] should be faster than both half-bulked runs + assert fully_bulked_time < fastest_half_bulked_time, \ + 'The fully-bulked exec time is slower than a half-bulked time by {} secs! {}' \ + .format(fully_bulked_time - fastest_half_bulked_time, times_str) + if __name__ == '__main__': import nose diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 010cf504fe70..c82e8419d8dc 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -2067,6 +2067,80 @@ def test_bilinear_sampler_versions(): assert_almost_equal(exe.grad_dict['grid'].asnumpy(), exe_list[ref_idx].grad_dict['grid'].asnumpy(), rtol=1e-3, atol=1e-5) +@with_seed() +def test_bulking(): + # Return the execution time of a model with the specified limits to the bulked op segments + def test_bulking_helper(data_shape, num_ops, num_iterations, + max_fwd_segment_size, max_bwd_segment_size): + orig_environ = os.environ.copy() + try: + # Explore different ways of setting the env vars. + # The framework does not cache the bulked seg size env var lookups during symbolic. 
+ if max_fwd_segment_size == max_bwd_segment_size: + os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN'] = str(max_fwd_segment_size) + os.environ.pop('MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD', None) + os.environ.pop('MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD', None) + else: + os.environ.pop('MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN', None) + os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD'] = str(max_fwd_segment_size) + os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD'] = str(max_bwd_segment_size) + + ctx = default_context() + # build symbol + X = mx.sym.Variable('X') + sym = mx.sym.flip(X, axis=0) + for _ in range(num_ops-1): + sym = mx.sym.flip(sym, axis=0) + x = mx.ndarray.zeros(data_shape) + dx = mx.ndarray.zeros(data_shape) + dy = mx.ndarray.ones(data_shape) + exe = sym.bind(ctx=ctx, args=[x], args_grad = {'X':dx}) + + # time a number of forward() and backward() executions after some warm-up iterations + warmups = 1 + for i in range(num_iterations+warmups): + if i == warmups: + start = time.time() + exe.forward(is_train=True) + exe.backward(dy) + dx.wait_to_read() + time_per_iteration = (time.time() - start) / num_iterations + finally: + os.environ.clear() + os.environ.update(orig_environ) + return time_per_iteration + + data_shape = (10,) + num_ops = 1000 + num_iterations = 20 + + # test cases are (max_fwd_segment_size, max_bwd_segment_size) + test_cases = [(0,0), (1,1), (15,0), (0,15), (15,15)] + times = {} + times_str = '' + for seg_sizes in test_cases: + times[seg_sizes] = test_bulking_helper(data_shape, num_ops, num_iterations, + seg_sizes[0], seg_sizes[1]) + times_str += '\n runtime of (fwd,bwd) seg size max ({},{}) =\t{:.1f} msec'.format( + seg_sizes[0], seg_sizes[1], 1000.0 * times[seg_sizes]) + + fastest_non_bulked_time = min(times[(0,0)], times[(1,1)]) + slowest_half_bulked_time = max(times[(0,15)], times[(15,0)]) + fastest_half_bulked_time = min(times[(0,15)], times[(15,0)]) + fully_bulked_time = times[(15,15)] + + print(times_str) + # The non-bulked 
times[0,0] and times[1,1] should be about the same, + # slower than both half-bulked times[0,15] and times[15,0] + assert slowest_half_bulked_time < fastest_non_bulked_time,\ + 'A half-bulked exec time is slower than the non-bulked time by {} secs! {}'\ + .format(slowest_half_bulked_time - fastest_non_bulked_time, times_str) + # The fully bulked time[15,15] should be faster than both half-bulked runs + assert fully_bulked_time < fastest_half_bulked_time,\ + 'The fully-bulked exec time is slower than a half-bulked time by {} secs! {}'\ + .format(fully_bulked_time - fastest_half_bulked_time, times_str) + + def test_context_num_gpus(): # Test that num_gpus reports at least one GPU, as the test is run on a GPU host. assert mx.context.num_gpus() > 0 diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py index abfba73ab727..e1736b47ec62 100644 --- a/tests/python/unittest/common.py +++ b/tests/python/unittest/common.py @@ -16,6 +16,7 @@ # under the License. import sys, os, logging +import multiprocessing as mp import mxnet as mx import numpy as np import random @@ -39,6 +40,7 @@ def assertRaises(expected_exception, func, *args, **kwargs): # Did not raise exception assert False, "%s did not raise %s" % (func.__name__, expected_exception.__name__) + def default_logger(): """A logger used to output seed information to nosetests logs.""" logger = logging.getLogger(__name__) @@ -51,6 +53,7 @@ def default_logger(): logger.setLevel(logging.INFO) return logger + @contextmanager def random_seed(seed=None): """ @@ -181,6 +184,7 @@ def test_new(*args, **kwargs): return test_new return test_helper + def setup_module(): """ A function with a 'magic name' executed automatically before each nosetests module @@ -265,3 +269,44 @@ def teardown(): It waits for all operations in one file to finish before carrying on the next. """ mx.nd.waitall() + + +def test_in_separate_process(func, env, *args): + """ + Helper function to run a test in its own process. 
+ + Avoids issues with Singleton- or otherwise-cached environment variable lookups in the backend. + Adds a seed as first arg to propagate determinism. + + Parameters + ---------- + + func : function to run in a spawned process. + + env : dict of additional environment values to set temporarily in the environment before exec. + + args : args to pass to the function. + + This routine calculates a random seed and passes it into the test as a first argument. If the + test uses random values, it should include an outer 'with random_seed(seed):'. If the + test needs to return values to the caller, consider use of shared variable arguments. + """ + try: + mpctx = mp.get_context('spawn') + except: + print('SKIP: python%s.%s lacks the required process fork-exec support ... ' % + sys.version_info[0:2], file=sys.stderr, end='') + else: + seed = np.random.randint(0,1024*1024*1024) + orig_environ = os.environ.copy() + try: + for (key, value) in env.items(): + os.environ[key] = str(value) + # Prepend seed as first arg + p = mpctx.Process(target=func, args=(seed,)+args) + p.start() + p.join() + assert p.exitcode == 0, "Non-zero exit code %d from %s()." % (p.exitcode, func.__name__) + finally: + os.environ.clear() + os.environ.update(orig_environ) From c1b478593f685c9a30b6ea0e879527204a335639 Mon Sep 17 00:00:00 2001 From: Dick Carter Date: Mon, 25 Feb 2019 13:43:37 -0800 Subject: [PATCH 04/13] Rename test_in_separate_process -> run_in_spawned_process. 
--- tests/python/gpu/test_gluon_gpu.py | 8 +++++--- tests/python/unittest/common.py | 10 +++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index ae46ca7f266f..78c72bf18260 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -38,7 +38,7 @@ curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.insert(0, os.path.join(curr_path, '../unittest')) from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied -from common import test_in_separate_process +from common import run_in_spawned_process from test_gluon import * from test_loss import * from test_gluon_rnn import * @@ -458,10 +458,12 @@ def test_bulking(): for seg_sizes in test_cases: # Create shared variable to return measured time from test process time_per_iteration = mp.Manager().Value('d', 0.0) - test_in_separate_process(_test_bulking_in_process, + if not run_in_spawned_process(_test_bulking_in_process, {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD' : seg_sizes[0], 'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD' : seg_sizes[1]}, - time_per_iteration) + time_per_iteration): + # skip test since the python version can't run it properly. Warning msg was logged. + return times[seg_sizes] = time_per_iteration.value times_str += '\n runtime of (fwd,bwd) seg size max ({},{}) =\t{:.1f} msec'.format( seg_sizes[0], seg_sizes[1], 1000.0 * times[seg_sizes]) diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py index e1736b47ec62..104e7a073d6d 100644 --- a/tests/python/unittest/common.py +++ b/tests/python/unittest/common.py @@ -271,7 +271,7 @@ def teardown(): mx.nd.waitall() -def test_in_separate_process(func, env, *args): +def run_in_spawned_process(func, env, *args): """ Helper function to run a test in its own process. 
@@ -282,11 +282,13 @@ def test_in_separate_process(func, env, *args): ---------- func : function to run in a spawned process. - env : dict of additional environment values to set temporarily in the environment before exec. - args : args to pass to the function. + Returns + ------- + Whether the python version supports running the function as a spawned process. + This routine calculates a random seed and passes it into the test as a first argument. If the test uses random values, it should include an outer 'with random_seed(seed):'. If the test needs to return values to the caller, consider use of shared variable arguments. @@ -296,6 +298,7 @@ def test_in_separate_process(func, env, *args): except: print('SKIP: python%s.%s lacks the required process fork-exec support ... ' % sys.version_info[0:2], file=sys.stderr, end='') + return False else: seed = np.random.randint(0,1024*1024*1024) orig_environ = os.environ.copy() @@ -310,3 +313,4 @@ def test_in_separate_process(func, env, *args): finally: os.environ.clear() os.environ.update(orig_environ) + return True \ No newline at end of file From c9d0f2467658a906a22b67930b83a4ffb5eec26e Mon Sep 17 00:00:00 2001 From: Dick Carter Date: Mon, 25 Feb 2019 14:04:46 -0800 Subject: [PATCH 05/13] Remove redundant util test_operator_gpu.py:_test_in_separate_process(). 
--- tests/python/gpu/test_operator_gpu.py | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 4cac645cbd01..71b202528833 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -33,6 +33,7 @@ curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.insert(0, os.path.join(curr_path, '../unittest')) from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied +from common import run_in_spawned_process from test_operator import * from test_optimizer import * from test_random import * @@ -521,24 +522,6 @@ def test_convolution_options(): check_consistency_NxM([sym, sym_no_cudnn], ctx_list) -# Helper function to run tests in a subprocess to avoid save/restore of os.environ. -# Also avoids issues of cached environment variable lookups in the backend. -def _test_in_separate_process(func, env, *args): - try: - mpctx = mp.get_context('spawn') - except: - print('SKIP: python%s.%s lacks the required process fork-exec support ... ' % - sys.version_info[0:2], file=sys.stderr, end='') - else: - seed = np.random.randint(0,1024*1024*1024) - for (key, value) in env.items(): - os.environ[key] = str(value) - # Prepend seed as first arg - p = mpctx.Process(target=func, args=(seed,)+args) - p.start() - p.join() - assert p.exitcode == 0, "Non-zero exit code %d from %s()." 
% (p.exitcode, func.__name__) - def _conv_with_num_streams(seed): with random_seed(seed): # Try to expose timing-dependent improper workspace sharing by parallel dgrad and wgrad @@ -566,7 +549,7 @@ def _conv_with_num_streams(seed): def test_convolution_multiple_streams(): for num_streams in [1, 2]: for engine in ['NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice']: - _test_in_separate_process(_conv_with_num_streams, + run_in_spawned_process(_conv_with_num_streams, {'MXNET_GPU_WORKER_NSTREAMS' : num_streams, 'MXNET_ENGINE_TYPE' : engine}) From f0e1f530a80869af40da83b338cd393c8a49e450 Mon Sep 17 00:00:00 2001 From: Dick Carter Date: Mon, 25 Feb 2019 17:45:18 -0800 Subject: [PATCH 06/13] Consolidate references to env vars that set op-bulking policy. --- include/mxnet/imperative.h | 24 ++++++++++++++++++++---- src/executor/graph_executor.cc | 18 +++++++----------- src/imperative/cached_op.cc | 11 ++++++++++- src/imperative/cached_op.h | 6 ++---- 4 files changed, 39 insertions(+), 20 deletions(-) diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h index 566103982da7..52cedb2fadd9 100644 --- a/include/mxnet/imperative.h +++ b/include/mxnet/imperative.h @@ -129,15 +129,31 @@ class Imperative { bool create_graph); /*! \return AutogradRuntime singleton */ static Imperative* Get(); + /*! \brief Should op execution bulking be employed during inference. */ + static bool PreferBulkExecInference() { + return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_INFERENCE", true); + } + /*! \brief Should op execution bulking be employed during training. */ + static bool PreferBulkExecTrain() { + return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", true); + } + /*! \brief The max number of op nodes in a bulk during forward pass of training. */ + static int BulkExecMaxNodeTrainFwd() { + return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD", + dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15)); + } + /*! 
\brief The max number of op nodes in a bulk during backward pass of training. */ + static int BulkExecMaxNodeTrainBwd() { + return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD", + dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15)); + } private: friend class NDArray; /*! \brief make constructor protected. */ Imperative() { - if (dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1)) { - backward_bulk_size_ = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD", - dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15)); - } + if (PreferBulkExecTrain()) + backward_bulk_size_ = BulkExecMaxNodeTrainBwd(); } /*! \brief find the input/output ndarrays that are needed for backward */ void GetBackwardDependency( diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 70f8e9d4d020..3e091ccde9f3 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -1191,16 +1191,17 @@ void GraphExecutor::InitOpSegs() { cached_seg_opr_.resize(total_num_nodes, p); if (monitor_callback_) return; + // Symbolic bulking is set by the same environment variables as Imperative bulking. 
// Generate segments based on the graph structure - bool prefer_bulk_exec_inference = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_INFERENCE", true); + bool prefer_bulk_exec_inference = Imperative::PreferBulkExecInference(); // Whether to perform bulk exec for training const profiler::Profiler *prof = profiler::Profiler::Get(); - bool prefer_bulk_exec = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1) - && (!prof || !prof->AggregateEnabled()); + bool prefer_bulk_exec_train = Imperative::PreferBulkExecTrain() + && (!prof || !prof->AggregateEnabled()); bool is_training = num_forward_nodes_ != total_num_nodes; - if (prefer_bulk_exec && is_training) { + if (prefer_bulk_exec_train && is_training) { this->BulkTrainingOpSegs(total_num_nodes); } @@ -1211,15 +1212,10 @@ void GraphExecutor::InitOpSegs() { void GraphExecutor::BulkTrainingOpSegs(size_t total_num_nodes) { - // The maximum number of nodes in a segment executed in bulk (excluding variables). - size_t segment_num_nodes_threshold = - dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15); // The maximum number of nodes in a segment executed in bulk (excluding variables) in fwd pass. - size_t segment_num_nodes_threshold_fwd = - dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD", segment_num_nodes_threshold); + size_t segment_num_nodes_threshold_fwd = Imperative::BulkExecMaxNodeTrainFwd(); // The maximum number of nodes in a segment executed in bulk (excluding variables) in bwd pass. 
- size_t segment_num_nodes_threshold_bwd = - dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD", segment_num_nodes_threshold); + size_t segment_num_nodes_threshold_bwd = Imperative::BulkExecMaxNodeTrainBwd(); // create forward segments for training size_t topo_start = 0; diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 8dd0a4deaac3..1f5a7e5c5a30 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -590,9 +590,18 @@ void CachedOp::StaticInitExec( SetupOpExec(g, i, state.execs[i], state.arrays, state.array_reqs); } + // Init bulk_size for Inference mode with bulking enabled (= entire forward graph). size_t bulk_size = idx.num_nodes(); if (recording || keep_fwd) { - bulk_size = keep_fwd ? config_.backward_bulk_size : config_.forward_bulk_size; + // Training mode + if (!Imperative::PreferBulkExecTrain()) + bulk_size = 0; + else + bulk_size = keep_fwd ? config_.backward_bulk_size : config_.forward_bulk_size; + } else { + // Inference mode + if (!Imperative::PreferBulkExecInference()) + bulk_size = 0; } CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size, diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index 3fac1da2c34f..6e75c4687b79 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -53,12 +53,10 @@ struct CachedOpConfig : public dmlc::Parameter { .set_default(2) .describe("Maximum number of operators that can be inlined."); DMLC_DECLARE_FIELD(forward_bulk_size) - .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD", - dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))) + .set_default(Imperative::BulkExecMaxNodeTrainFwd()) .describe("Segment size of bulk execution during forward pass."); DMLC_DECLARE_FIELD(backward_bulk_size) - .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD", - dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))) + .set_default(Imperative::BulkExecMaxNodeTrainBwd()) .describe("Segment size of 
bulk execution during backward pass."); DMLC_DECLARE_FIELD(data_indices) .set_default(nnvm::Tuple()) From 18b444f2ca61d200e8d4665d159f31ff745d76a9 Mon Sep 17 00:00:00 2001 From: Dick Carter Date: Mon, 25 Feb 2019 18:42:37 -0800 Subject: [PATCH 07/13] Test for effect of MXNET_EXEC_BULK_EXEC_TRAIN=0. --- tests/python/gpu/test_gluon_gpu.py | 31 +++++++++++++++------------ tests/python/gpu/test_operator_gpu.py | 28 +++++++++++++----------- 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 78c72bf18260..88b436a0deb2 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -451,8 +451,8 @@ def get_net(num_ops): @with_seed() def test_bulking(): - # test case format: (max_fwd_segment_size, max_bwd_segment_size) - test_cases = [(0,0), (1,1), (15,0), (0,15), (15,15)] + # test case format: (max_fwd_segment_size, max_bwd_segment_size, enable_bulking_in_training) + test_cases = [(0,0,True), (1,1,True), (15,15,False), (15,0,True), (0,15,True), (15,15,True)] times = {} times_str = '' for seg_sizes in test_cases: @@ -460,25 +460,28 @@ def test_bulking(): time_per_iteration = mp.Manager().Value('d', 0.0) if not run_in_spawned_process(_test_bulking_in_process, {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD' : seg_sizes[0], - 'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD' : seg_sizes[1]}, + 'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD' : seg_sizes[1], + 'MXNET_EXEC_BULK_EXEC_TRAIN' : seg_sizes[2]}, time_per_iteration): # skip test since the python version can't run it properly. Warning msg was logged. 
return times[seg_sizes] = time_per_iteration.value - times_str += '\n runtime of (fwd,bwd) seg size max ({},{}) =\t{:.1f} msec'.format( - seg_sizes[0], seg_sizes[1], 1000.0 * times[seg_sizes]) - - fastest_non_bulked_time = min(times[(0,0)], times[(1,1)]) - slowest_half_bulked_time = max(times[(0,15)], times[(15,0)]) - fastest_half_bulked_time = min(times[(0,15)], times[(15,0)]) - fully_bulked_time = times[(15,15)] - - # The non-bulked times[0,0] and times[1,1] should be about the same, - # slower than both half-bulked times[0,15] and times[15,0] + times_str += \ + '\n runtime of (fwd,bwd,enable) op seg setting ({},{},{}) =\t{:.1f} msec'.format( + seg_sizes[0], seg_sizes[1], seg_sizes[2], 1000.0 * times[seg_sizes]) + + fastest_non_bulked_time = min(times[(0,0,True)], times[(1,1,True)], times[(15,15,False)]) + slowest_half_bulked_time = max(times[(0,15,True)], times[(15,0,True)]) + fastest_half_bulked_time = min(times[(0,15,True)], times[(15,0,True)]) + fully_bulked_time = times[(15,15,True)] + + print(times_str) + # Non-bulked times[0,0,True], times[1,1,True] and times[15,15,False] should be about the same, + # slower than both half-bulked times[0,15,True] and times[15,0,True] assert slowest_half_bulked_time < fastest_non_bulked_time, \ 'A half-bulked exec time is slower than the non-bulked time by {} secs! {}' \ .format(slowest_half_bulked_time - fastest_non_bulked_time, times_str) - # The fully bulked time[15,15] should be faster than both half-bulked runs + # The fully bulked times[15,15,True] should be faster than both half-bulked runs assert fully_bulked_time < fastest_half_bulked_time, \ 'The fully-bulked exec time is slower than a half-bulked time by {} secs! 
{}' \ .format(fully_bulked_time - fastest_half_bulked_time, times_str) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 71b202528833..17a97ae90765 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -2104,11 +2104,12 @@ def test_bilinear_sampler_versions(): def test_bulking(): # Return the execution time of a model with the specified limits to the bulked op segments def test_bulking_helper(data_shape, num_ops, num_iterations, - max_fwd_segment_size, max_bwd_segment_size): + max_fwd_segment_size, max_bwd_segment_size, enable_bulking_in_training): orig_environ = os.environ.copy() try: # Explore different ways of setting the env vars. # The framework does not cache the bulked seg size env var lookups during symbolic. + os.environ['MXNET_EXEC_BULK_EXEC_TRAIN'] = str(enable_bulking_in_training) if max_fwd_segment_size == max_bwd_segment_size: os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN'] = str(max_fwd_segment_size) os.environ.pop('MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD', None) @@ -2147,28 +2148,29 @@ def test_bulking_helper(data_shape, num_ops, num_iterations, num_ops = 1000 num_iterations = 20 - # test cases are (max_fwd_segment_size, max_bwd_segment_size) - test_cases = [(0,0), (1,1), (15,0), (0,15), (15,15)] + # test case format: (max_fwd_segment_size, max_bwd_segment_size, enable_bulking_in_training) + test_cases = [(0,0,True), (1,1,True), (15,15,False), (15,0,True), (0,15,True), (15,15,True)] times = {} times_str = '' for seg_sizes in test_cases: times[seg_sizes] = test_bulking_helper(data_shape, num_ops, num_iterations, - seg_sizes[0], seg_sizes[1]) - times_str += '\n runtime of (fwd,bwd) seg size max ({},{}) =\t{:.1f} msec'.format( - seg_sizes[0], seg_sizes[1], 1000.0 * times[seg_sizes]) + seg_sizes[0], seg_sizes[1], seg_sizes[2]) + times_str +=\ + '\n runtime of (fwd,bwd,enable) op seg setting ({},{},{}) =\t{:.1f} msec'.format( + seg_sizes[0], seg_sizes[1], 
seg_sizes[2], 1000.0 * times[seg_sizes]) - fastest_non_bulked_time = min(times[(0,0)], times[(1,1)]) - slowest_half_bulked_time = max(times[(0,15)], times[(15,0)]) - fastest_half_bulked_time = min(times[(0,15)], times[(15,0)]) - fully_bulked_time = times[(15,15)] + fastest_non_bulked_time = min(times[(0,0,True)], times[(1,1,True)], times[(15,15,False)]) + slowest_half_bulked_time = max(times[(0,15,True)], times[(15,0,True)]) + fastest_half_bulked_time = min(times[(0,15,True)], times[(15,0,True)]) + fully_bulked_time = times[(15,15,True)] print(times_str) - # The non-bulked times[0,0] and times[1,1] should be about the same, - # slower than both half-bulked times[0,15] and times[15,0] + # Non-bulked times[0,0,True], times[1,1,True] and times[15,15,False] should be about the same, + # slower than both half-bulked times[0,15,True] and times[15,0,True] assert slowest_half_bulked_time < fastest_non_bulked_time,\ 'A half-bulked exec time is slower than the non-bulked time by {} secs! {}'\ .format(slowest_half_bulked_time - fastest_non_bulked_time, times_str) - # The fully bulked time[15,15] should be faster than both half-bulked runs + # The fully bulked times[15,15,True] should be faster than both half-bulked runs assert fully_bulked_time < fastest_half_bulked_time,\ 'The fully-bulked exec time is slower than a half-bulked time by {} secs! {}'\ .format(fully_bulked_time - fastest_half_bulked_time, times_str) From 4bcef309c3402c4da8b7e486e0c837ebb6064c52 Mon Sep 17 00:00:00 2001 From: Dick Carter Date: Mon, 25 Feb 2019 19:03:20 -0800 Subject: [PATCH 08/13] Fix python2 print() issue. --- tests/python/unittest/common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py index 104e7a073d6d..7cd637da3d4f 100644 --- a/tests/python/unittest/common.py +++ b/tests/python/unittest/common.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
+from __future__ import print_function import sys, os, logging import multiprocessing as mp import mxnet as mx From f9594985062e3380161ba9ff02f898d56cec3e2f Mon Sep 17 00:00:00 2001 From: Dick Carter Date: Wed, 27 Feb 2019 09:35:45 -0800 Subject: [PATCH 09/13] Trigger CI. From 320001a43fb58d96f661c355526f7f07d01b0c5b Mon Sep 17 00:00:00 2001 From: Dick Carter Date: Sun, 3 Mar 2019 13:28:34 -0800 Subject: [PATCH 10/13] Consolidate similar op bulking routines. --- src/executor/graph_executor.cc | 67 +++++----------------------------- src/executor/graph_executor.h | 6 +-- 2 files changed, 12 insertions(+), 61 deletions(-) diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 3e091ccde9f3..3ba899f43f3a 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -1202,57 +1202,32 @@ void GraphExecutor::InitOpSegs() { bool is_training = num_forward_nodes_ != total_num_nodes; if (prefer_bulk_exec_train && is_training) { - this->BulkTrainingOpSegs(total_num_nodes); + // Bulk the forward portion of the graph per the bulk segment max size for forward training + this->BulkOpSegs(0, num_forward_nodes_, Imperative::BulkExecMaxNodeTrainFwd()); + // Bulk the backward portion of the graph per the bulk segment max size for backward training + this->BulkOpSegs(num_forward_nodes_, total_num_nodes, Imperative::BulkExecMaxNodeTrainBwd()); } if (prefer_bulk_exec_inference && !is_training) { - this->BulkInferenceOpSegs(); + // Bulk the entire graph as one bulk segment if possible + this->BulkOpSegs(0, total_num_nodes, total_num_nodes); } } -void GraphExecutor::BulkTrainingOpSegs(size_t total_num_nodes) { - // The maximum number of nodes in a segment executed in bulk (excluding variables) in fwd pass. - size_t segment_num_nodes_threshold_fwd = Imperative::BulkExecMaxNodeTrainFwd(); - // The maximum number of nodes in a segment executed in bulk (excluding variables) in bwd pass. 
- size_t segment_num_nodes_threshold_bwd = Imperative::BulkExecMaxNodeTrainBwd(); - - // create forward segments for training - size_t topo_start = 0; +void GraphExecutor::BulkOpSegs(size_t from_node, size_t up_to_node, size_t segment_num_nodes_max) { + size_t topo_start = from_node; size_t segment_node_count = 0; - for (size_t nid = 0; nid < num_forward_nodes_; nid++) { + for (size_t nid = from_node; nid < up_to_node; nid++) { auto &node = graph_.indexed_graph()[nid].source; auto &op_node = op_nodes_[nid]; // Variables, such as learned weights, are ignored in the segment_node_count - bool ignore_node = node->is_variable(); - if (!ignore_node) - segment_node_count++; - bool can_bulk = ignore_node || op_node.exec->exec_type() == ExecType::kSync; - // check if we need to create the segment based on properties of this node - if (!can_bulk || nid == num_forward_nodes_ - 1 || - segment_node_count >= segment_num_nodes_threshold_fwd) { - // Create a new segment for the previous nodes- include also this node if it's bulkable - cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, can_bulk ? nid + 1 : nid); - topo_start = nid + 1; - segment_node_count = 0; - } - } - - // create backward segments for training - topo_start = num_forward_nodes_; - segment_node_count = 0; - for (size_t nid = num_forward_nodes_; nid < total_num_nodes; nid++) { - auto &node = graph_.indexed_graph()[nid].source; - auto &op_node = op_nodes_[nid]; - // Variables, such as learned weights, are ignored in the segment_node_count and - // nodes that are not executed for various reasons. 
bool ignore_node = node->is_variable() || op_node.skip_exec_node || op_node.exec == nullptr; if (!ignore_node) segment_node_count++; bool can_bulk = ignore_node || op_node.exec->exec_type() == ExecType::kSync; // check if we need to create the segment based on properties of this node - if (!can_bulk || nid == total_num_nodes - 1 || - segment_node_count >= segment_num_nodes_threshold_bwd) { + if (!can_bulk || nid == up_to_node - 1 || segment_node_count >= segment_num_nodes_max) { // Create a new segment for the previous nodes- include also this node if it's bulkable cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, can_bulk ? nid + 1 : nid); topo_start = nid + 1; @@ -1261,28 +1236,6 @@ void GraphExecutor::BulkTrainingOpSegs(size_t total_num_nodes) { } } -void GraphExecutor::BulkInferenceOpSegs() { - // Attempt to bulk the whole graph for inference. We will only create new segments when - // required for non-kSync operations. - size_t topo_start = 0; - for (size_t nid = 0; nid < num_forward_nodes_; nid++) { - auto &node = graph_.indexed_graph()[nid].source; - auto &op_node = op_nodes_[nid]; - - // Variables do not need to be segmented at inference time. 
- if (node->is_variable()) continue; - - if (op_node.exec->exec_type() != ExecType::kSync) { - cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid); - topo_start = nid + 1; - } - } - // The last segment - if (topo_start != num_forward_nodes_) { - cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, num_forward_nodes_); - } -} - void GraphExecutor::ExecuteMonInputCallback(size_t nid) { static const auto& flist_inputs = nnvm::Op::GetAttr("FListInputNames"); diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index c899a6f5b463..a7436fb28d94 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -213,10 +213,8 @@ class GraphExecutor : public Executor { void ExecuteMonInputCallback(size_t nid); // run the monitor callback for output of node `nid` void ExecuteMonOutputCallback(size_t nid); - // peform bulking and segmentation on an inference graph - void BulkInferenceOpSegs(); - // perform bulking and segmentation on a training graph - void BulkTrainingOpSegs(size_t total_num_nodes); + // perform bulking and segmentation on the region [from_node, up_to_node) of a graph + void BulkOpSegs(size_t from_node, size_t up_to_node, size_t segment_num_nodes_max); // indicate whether there is a backward graph for gradients. bool need_grad_; // internal graph From 79a22ed0704d564cfcaff167c6deccdf39db28c7 Mon Sep 17 00:00:00 2001 From: Dick Carter Date: Mon, 4 Mar 2019 09:45:24 -0800 Subject: [PATCH 11/13] Trigger CI. From bba19b74e6efecf57c2d544a1dfb6f46337c152a Mon Sep 17 00:00:00 2001 From: Dick Carter Date: Tue, 5 Mar 2019 11:41:04 -0800 Subject: [PATCH 12/13] Trigger CI. From 30ccbedb299a4a77ccdb8ee3fcf99dd38b968182 Mon Sep 17 00:00:00 2001 From: Dick Carter Date: Wed, 6 Mar 2019 19:04:01 -0800 Subject: [PATCH 13/13] Add instrumentation to debug failing CI. 
--- tests/python/gpu/test_operator_gpu.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 94363860269d..7d7c2ed71216 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -559,8 +559,10 @@ def test_convolution_multiple_streams(): for num_streams in [1, 2]: for engine in engines: + print("Starting engine %s with %d streams." % (engine, num_streams), file=sys.stderr) run_in_spawned_process(_conv_with_num_streams, {'MXNET_GPU_WORKER_NSTREAMS' : num_streams, 'MXNET_ENGINE_TYPE' : engine}) + print("Finished engine %s with %d streams." % (engine, num_streams), file=sys.stderr) # This test is designed to expose an issue with cudnn v7.1.4 algo find() when invoked with large c.