diff --git a/python/mxnet/onnx/mx2onnx/_export_onnx.py b/python/mxnet/onnx/mx2onnx/_export_onnx.py index 375e753bec35..8bcb7a6992fb 100644 --- a/python/mxnet/onnx/mx2onnx/_export_onnx.py +++ b/python/mxnet/onnx/mx2onnx/_export_onnx.py @@ -50,6 +50,7 @@ import logging import json +import numpy as np from mxnet import ndarray as nd @@ -290,7 +291,7 @@ def create_onnx_graph_proto(self, sym, params, in_shapes, in_types, verbose=Fals class NodeOutput: def __init__(self, name, dtype): self.name = name - self.dtype = dtype + self.dtype = np.dtype(dtype) initializer = [] all_processed_nodes = []