This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix test_gluon.py:test_sync_batchnorm when number of GPUS > 4 (#16835)
larroy authored and ptrendx committed Nov 21, 2019
1 parent 530bd27 commit 33a3af9
1 changed file: tests/python/unittest/test_gluon.py (4 additions, 2 deletions)
```diff
@@ -868,11 +868,13 @@ def _syncParameters(bn1, bn2, ctx):

     cfgs = [(1, False)]
     num_gpus = mx.context.num_gpus()
+    batch_size = 24
     for i in range(1, num_gpus + 1):
-        cfgs.append((i, True))
+        if batch_size % i == 0:
+            cfgs.append((i, True))
     for ndev, cuda in cfgs:
         # check with unsync version
-        for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 5, 6, 4, 4)]:
+        for shape in [(batch_size, 2), (batch_size, 3, 4), (batch_size, 4, 4, 4), (batch_size, 5, 6, 4, 4)]:
             print(str((ndev, cuda, shape)))
             for i in range(10):
                 _check_batchnorm_result(mx.nd.random.uniform(shape=shape,
```
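The idea behind the fix: the test splits a fixed batch of 24 samples across `ndev` devices, so it only makes sense to test device counts that divide the batch size evenly; on machines with more than 4 GPUs, counts like 5 or 7 produced uneven splits and failed. A minimal, self-contained sketch of the patched config-generation logic (the `make_cfgs` helper and the plain-integer GPU count are illustrative stand-ins for the inline code and `mx.context.num_gpus()` in the diff):

```python
def make_cfgs(num_gpus, batch_size=24):
    """Return (ndev, use_cuda) pairs, keeping only device counts
    that divide batch_size evenly, as in the patched test."""
    cfgs = [(1, False)]  # always include the single-device CPU case
    for i in range(1, num_gpus + 1):
        if batch_size % i == 0:  # skip counts that would split the batch unevenly
            cfgs.append((i, True))
    return cfgs

# With 5 GPUs, ndev=5 is skipped because 24 % 5 != 0:
print(make_cfgs(5))  # [(1, False), (1, True), (2, True), (3, True), (4, True)]
```

Before the patch, every `i` in `range(1, num_gpus + 1)` was appended unconditionally, which is why the failure only surfaced on hosts with more than 4 GPUs.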
