From 7fc625599cc8654af3e7da9a70114361be397b3a Mon Sep 17 00:00:00 2001
From: Pedro Larroy
Date: Sat, 16 Nov 2019 01:30:40 +0000
Subject: [PATCH] Fix test_gluon.py:test_sync_batchnorm when number of GPUS > 4

---
 tests/python/unittest/test_gluon.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 5d15b27fa7ea..05ffb1539cae 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -868,11 +868,13 @@ def _syncParameters(bn1, bn2, ctx):
     cfgs = [(1, False)]
     num_gpus = mx.context.num_gpus()
+    batch_size = 24
     for i in range(1, num_gpus + 1):
-        cfgs.append((i, True))
+        if batch_size % i == 0:
+            cfgs.append((i, True))
     for ndev, cuda in cfgs:
         # check with unsync version
-        for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 5, 6, 4, 4)]:
+        for shape in [(batch_size, 2), (batch_size, 3, 4), (batch_size, 4, 4, 4), (batch_size, 5, 6, 4, 4)]:
             print(str((ndev, cuda, shape)))
             for i in range(10):
                 _check_batchnorm_result(mx.nd.random.uniform(shape=shape,
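
For reference, the change keeps only device counts that divide the fixed batch size of 24, presumably so the batch can be split evenly across devices during the sync-batchnorm check; with more than 4 GPUs, counts such as 5 or 7 would otherwise be tested and fail. Below is a minimal standalone sketch of just that selection logic; `build_cfgs` and the hard-coded `num_gpus=8` are illustrative only, whereas the real test queries `mx.context.num_gpus()`.

```python
# Sketch of the configuration filter introduced by this patch (assumptions noted above).

def build_cfgs(num_gpus, batch_size=24):
    """Return (ndev, cuda) pairs, keeping only device counts that divide batch_size."""
    cfgs = [(1, False)]                      # the unsynced single-device baseline is always tested
    for i in range(1, num_gpus + 1):
        if batch_size % i == 0:              # skip counts that would split the batch unevenly
            cfgs.append((i, True))
    return cfgs

if __name__ == "__main__":
    # With 8 GPUs and batch_size=24, device counts 5 and 7 are skipped.
    print(build_cfgs(num_gpus=8))
    # [(1, False), (1, True), (2, True), (3, True), (4, True), (6, True), (8, True)]
```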