Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix to enable model parallelism for non-fp32 data
Browse files Browse the repository at this point in the history
  • Loading branch information
Asmus Hetzel committed Oct 31, 2019
1 parent 27bddf8 commit 2fdef62
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
10 changes: 10 additions & 0 deletions src/operator/cross_device_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ class CrossDeviceCopyProp : public OperatorProperty {
return true;
}

bool InferType(std::vector<int> *in_type,
std::vector<int> *out_type,
std::vector<int> *aux_type) const {
CHECK_EQ(in_type->size(), 1) << "Input:[data]";
if (in_type->at(0) == -1) return false;
out_type->clear();
out_type->push_back(in_type->at(0));
return true;
}

// Clone this property. CrossDeviceCopyProp is stateless, so a fresh
// default-constructed instance is equivalent to *this. Caller owns the
// returned pointer (framework convention for OperatorProperty::Copy).
OperatorProperty* Copy() const override {
return new CrossDeviceCopyProp();
}
Expand Down
28 changes: 17 additions & 11 deletions tests/python/unittest/test_model_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np
import mxnet as mx
from mxnet.test_utils import *

def reldiff(a, b):
diff = np.sum(np.abs(a - b))
Expand All @@ -26,13 +27,11 @@ def reldiff(a, b):
reldiff = diff / norm
return reldiff

def test_chain():
ctx1 = mx.cpu(0)
ctx2 = mx.cpu(1)
def test_chain(ctx1=mx.cpu(0), ctx2=mx.cpu(1), dtype=np.float32):
n = 2
data1 = mx.sym.Variable('data1')
data2 = mx.sym.Variable('data2')
data3 = mx.sym.Variable('data3')
data1 = mx.sym.Variable('data1', dtype=dtype)
data2 = mx.sym.Variable('data2', dtype=dtype)
data3 = mx.sym.Variable('data3', dtype=dtype)
with mx.AttrScope(ctx_group='dev1'):
net = data1 + data2
net = net * 3
Expand All @@ -45,11 +44,11 @@ def test_chain():
shape = (4, 5)
with mx.Context(ctx1):
for i in range(n):
arr.append(mx.nd.empty(shape))
arr_grad.append(mx.nd.empty(shape))
arr.append(mx.nd.empty(shape, dtype=dtype))
arr_grad.append(mx.nd.empty(shape, dtype=dtype))
with mx.Context(ctx2):
arr.append(mx.nd.empty(shape))
arr_grad.append(mx.nd.empty(shape))
arr.append(mx.nd.empty(shape, dtype=dtype))
arr_grad.append(mx.nd.empty(shape, dtype=dtype))

exec1 = net.bind(ctx1,
args=arr,
Expand All @@ -76,6 +75,13 @@ def test_chain():
for a, b in zip(arr_grad, arr_grad2):
assert reldiff(a.asnumpy(), b.asnumpy()) < 1e-6

def test_chain_type_device():
    """Run test_chain across device placements and floating-point dtypes.

    Always exercises the CPU<->CPU split; when the test run targets a GPU,
    additionally covers GPU<->GPU and the two mixed CPU/GPU directions.
    """
    device_pairs = [(mx.cpu(0), mx.cpu(1))]
    if default_context().device_type == 'gpu':
        device_pairs += [
            (mx.gpu(0), mx.gpu(0)),
            (mx.cpu(0), mx.gpu(0)),
            (mx.gpu(0), mx.cpu(0)),
        ]
    for dev_a, dev_b in device_pairs:
        for dt in (np.float16, np.float32, np.float64):
            test_chain(dev_a, dev_b, dt)

if __name__ == '__main__':
    # Allow running this test module directly as a script (outside the
    # test runner): execute both model-parallel chain tests.
    test_chain()
    test_chain_type_device()

0 comments on commit 2fdef62

Please sign in to comment.