diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 5d15b27fa7ea..05ffb1539cae 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -868,11 +868,13 @@ def _syncParameters(bn1, bn2, ctx): cfgs = [(1, False)] num_gpus = mx.context.num_gpus() + batch_size = 24 for i in range(1, num_gpus + 1): - cfgs.append((i, True)) + if batch_size % i == 0: + cfgs.append((i, True)) for ndev, cuda in cfgs: # check with unsync version - for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 5, 6, 4, 4)]: + for shape in [(batch_size, 2), (batch_size, 3, 4), (batch_size, 4, 4, 4), (batch_size, 5, 6, 4, 4)]: print(str((ndev, cuda, shape))) for i in range(10): _check_batchnorm_result(mx.nd.random.uniform(shape=shape,