diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index d0830dcc8cae..6cb9fc690b5a 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -26,6 +26,7 @@ import re from collections import OrderedDict +from ..base import mx_real_t from .. import symbol, ndarray, initializer from ..symbol import Symbol from ..ndarray import NDArray @@ -1053,13 +1054,20 @@ def __init__(self, outputs, inputs, params=None): "SymbolBlock doesn't support Parameter '%s' because its storage " \ "type is 'row_sparse'." % j.name - for i in out.list_arguments(): - if i not in input_names: - self.params.get(i, allow_deferred_init=True) + # Infer type of parameters. Without this, every parameter will be created with + # default type i.e., fp32 + arg_params = out.list_arguments() + aux_params = out.list_auxiliary_states() - for i in out.list_auxiliary_states(): - if i not in input_names: - self.params.get(i, grad_req='null', allow_deferred_init=True) + arg_types, aux_types = _infer_param_types(syms, out, arg_params, aux_params) + + for i, arg in enumerate(arg_params): + if arg not in input_names: + self.params.get(arg, allow_deferred_init=True, dtype=arg_types[i]) + + for i, aux in enumerate(aux_params): + if aux not in input_names: + self.params.get(aux, grad_req='null', allow_deferred_init=True, dtype=aux_types[i]) self._cached_graph = syms, out len_prefix = len(_common_prefix(list(self._params.keys()))) @@ -1084,5 +1092,71 @@ def _clear_cached_op(self): super(SymbolBlock, self)._clear_cached_op() self._cached_graph = tmp + def cast(self, dtype): + self._clear_cached_op() + super(SymbolBlock, self).cast(dtype) + def hybrid_forward(self, F, x, *args, **kwargs): raise NotImplementedError + +def _infer_param_types(in_params, out_params, arg_params, aux_params, default_dtype=mx_real_t): + """Utility function that helps in inferring DType of args and auxs params + from given input param. + + Parameters + ---------- + in_params: List of Symbol + List of input symbol variables. + out_params: Symbol + Output symbol variable. + arg_params: List of Str + List of names of argument parametrs. + aux_params: List of Str + List of names of auxiliary parameters. + default_dtype: numpy.dtype or str, default 'float32' + Default data type for arg_params and aux_params, if unable to infer the type. + + Returns + ------- + arg_types: List of numpy.dtype + List of arg_params type. Order is same as arg_params. + Defaults to 'float32', if unable to infer type. + aux_types: List of numpy.dtype + List of aux_params type. Order is same as aux_params. + Defaults to 'float32', if unable to infer type. + """ + arg_types = None + aux_types = None + + # Get Input symbol details. This will be used to infer types of + # other parameters. + input_sym_names = [in_param.name for in_param in in_params] + + # Try to infer input types. If not successful, we will set default dtype. + # If successful, we will try to infer other params in the graph. + input_sym_arg_types = [] + can_infer_input_type = True + for in_param in in_params: + input_sym_arg_type = in_param.infer_type()[0] + if not input_sym_arg_type or len(input_sym_arg_type) < 1: + can_infer_input_type = False + break + else: + input_sym_arg_types.append(in_param.infer_type()[0][0]) + + # Try to infer types of other parameters. + if can_infer_input_type: + params = {k:v for k, v in zip(input_sym_names, input_sym_arg_types)} + arg_types, _, aux_types = out_params.infer_type(**params) + + if arg_types is None or len(arg_types) != len(arg_params): + arg_types = [] + for _ in arg_params: + arg_types.append(default_dtype) + + if aux_types is None or len(aux_types) != len(aux_params): + aux_types = [] + for _ in aux_params: + aux_types.append(default_dtype) + + return (arg_types, aux_types) diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index 24c86f4e0fa7..f53eeb00694a 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -727,6 +727,8 @@ def get(self, name, **kwargs): if matched: param._shape = tuple(inferred_shape) continue + elif k == 'dtype' and np.dtype(v) == np.dtype(existing): + continue assert v is None or v == existing, \ "Cannot retrieve Parameter '%s' because desired attribute " \ diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 42d65dab5fdc..ac7df6257967 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -18,6 +18,7 @@ from __future__ import print_function import sys import os +import tempfile import time import multiprocessing as mp import unittest @@ -198,6 +199,36 @@ def get_num_devices(): _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)), num_devices=ndev, cuda=True) +@with_seed() +def test_symbol_block_fp16(): + # Test case to verify if initializing the SymbolBlock from a model with params + # other than fp32 param dtype. + + # 1. Load a resnet model, cast it to fp16 and export + tmp = tempfile.mkdtemp() + tmpfile = os.path.join(tmp, 'resnet34_fp16') + ctx = mx.gpu(0) + + net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx, root=tmp) + net_fp32.cast('float16') + net_fp32.hybridize() + data = mx.nd.zeros((1,3,224,224), dtype='float16', ctx=ctx) + net_fp32.forward(data) + net_fp32.export(tmpfile, 0) + + # 2. Load the saved model and verify if all the params are loaded correctly. + # and choose one of the param to verify the type if fp16. + sm = mx.sym.load(tmpfile + '-symbol.json') + inputs = mx.sym.var('data', dtype='float16') + net_fp16 = mx.gluon.SymbolBlock(sm, inputs) + net_fp16.collect_params().load(tmpfile + '-0000.params', ctx=ctx) + # 3. Get a conv layer's weight parameter name. Conv layer's weight param is + # expected to be of dtype casted, fp16. + for param_name in net_fp16.params.keys(): + if 'conv' in param_name and 'weight' in param_name: + break + assert np.dtype(net_fp16.params[param_name].dtype) == np.dtype(np.float16) + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 61b441a5f842..4e13fc38e87b 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. +import os +import tempfile + import mxnet as mx from mxnet import gluon from mxnet.gluon import nn @@ -336,6 +339,41 @@ def hybrid_forward(self, F, x): net.hybridize() assert isinstance(net(mx.nd.zeros((16, 10))), mx.nd.NDArray) + # Test case to verify if initializing the SymbolBlock from a model with params + # other than fp32 param dtype. + + # 1. Load a resnet model, cast it to fp64 and export + tmp = tempfile.mkdtemp() + tmpfile = os.path.join(tmp, 'resnet34_fp64') + ctx = mx.cpu(0) + + net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx, root=tmp) + net_fp32.cast('float64') + net_fp32.hybridize() + data = mx.nd.zeros((1,3,224,224), dtype='float64', ctx=ctx) + net_fp32.forward(data) + net_fp32.export(tmpfile, 0) + + # 2. Load the saved model and verify if all the params are loaded correctly. + # and choose one of the param to verify the type if fp64. + sm = mx.sym.load(tmpfile + '-symbol.json') + inputs = mx.sym.var('data', dtype='float64') + net_fp64 = mx.gluon.SymbolBlock(sm, inputs) + net_fp64.collect_params().load(tmpfile + '-0000.params', ctx=ctx) + # 3. Get a conv layer's weight parameter name. Conv layer's weight param is + # expected to be of dtype casted, fp64. + for param_name in net_fp64.params.keys(): + if 'conv' in param_name and 'weight' in param_name: + break + assert np.dtype(net_fp64.params[param_name].dtype) == np.dtype(np.float64) + + # Cast the symbol block to FP32 and try to forward a FP32 data. + # This will verify SymbolBlock.cast() functionality. + net_fp64.cast('float32') + fp32_data = mx.nd.zeros((1,3,224,224), dtype='float32', ctx=ctx) + prediction = net_fp64.forward(fp32_data) + assert np.dtype(prediction.dtype) == np.dtype(np.float32) + @with_seed() @raises(AssertionError) def test_sparse_symbol_block():