[FRONTEND][TFLITE] get input tensor information from graph #7400

Merged: 13 commits, Feb 15, 2021
48 changes: 1 addition & 47 deletions python/tvm/driver/tvmc/frontends.py
@@ -198,19 +198,6 @@ def load(self, path, shape_dict=None):
class TFLiteFrontend(Frontend):
""" TFLite frontend for TVMC """

_tflite_m = {
0: "float32",
1: "float16",
2: "int32",
3: "uint8",
4: "int64",
5: "string",
6: "bool",
7: "int16",
8: "complex64",
9: "int8",
}

@staticmethod
def name():
return "tflite"
@@ -241,43 +228,10 @@ def load(self, path, shape_dict=None):
if version != 3:
raise TVMCException("input file not tflite version 3")

logger.debug("tflite_input_type")
input_shapes, dtype_dict = TFLiteFrontend._input_type(tflite_model)
if shape_dict is not None:
input_shapes.update(shape_dict)

logger.debug("parse TFLite model and convert into Relay computation graph")
mod, params = relay.frontend.from_tflite(
tflite_model, shape_dict=input_shapes, dtype_dict=dtype_dict
)
mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict)
return mod, params

@staticmethod
def _decode_type(n):
return TFLiteFrontend._tflite_m[n]

@staticmethod
def _input_type(model):
subgraph_count = model.SubgraphsLength()
assert subgraph_count > 0
shape_dict = {}
dtype_dict = {}
for subgraph_index in range(subgraph_count):
subgraph = model.Subgraphs(subgraph_index)
inputs_count = subgraph.InputsLength()
assert inputs_count >= 1
for input_index in range(inputs_count):
input_ = subgraph.Inputs(input_index)
assert subgraph.TensorsLength() > input_
tensor = subgraph.Tensors(input_)
input_shape = tuple(tensor.ShapeAsNumpy())
tensor_type = tensor.Type()
input_name = tensor.Name().decode("utf8")
shape_dict[input_name] = input_shape
dtype_dict[input_name] = TFLiteFrontend._decode_type(tensor_type)

return shape_dict, dtype_dict


class PyTorchFrontend(Frontend):
""" PyTorch frontend for TVMC """
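For context, a minimal usage sketch of the simplified TVMC frontend after this change (the model path and the input name in the override are placeholders, not part of the PR): shapes and dtypes no longer have to be supplied by the caller, since relay.frontend.from_tflite now reads them from the graph.

from tvm.driver.tvmc.frontends import TFLiteFrontend

frontend = TFLiteFrontend()

# Shapes/dtypes are read from the .tflite graph inside relay.frontend.from_tflite,
# so no shape_dict is required here.
mod, params = frontend.load("model.tflite")

# An explicit shape_dict still works and overrides what the graph declares
# (the input name and shape below are hypothetical).
mod, params = frontend.load("model.tflite", shape_dict={"input": (1, 224, 224, 3)})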
50 changes: 47 additions & 3 deletions python/tvm/relay/frontend/tflite.py
@@ -3539,7 +3539,45 @@ def get_tensor_name(subgraph, tensor_idx):
return subgraph.Tensors(tensor_idx).Name().decode("utf-8")


def from_tflite(model, shape_dict, dtype_dict):
def _decode_type(n):
_tflite_m = {
[Review comment, Contributor] I see this is duplicated in tvmc/frontends.py - is there any reason why we can't reuse this one there?

[Reply, Contributor Author] I missed it. Fixed.

0: "float32",
1: "float16",
2: "int32",
3: "uint8",
4: "int64",
5: "string",
6: "bool",
7: "int16",
8: "complex64",
9: "int8",
}
return _tflite_m[n]


def _input_type(model):
subgraph_count = model.SubgraphsLength()
assert subgraph_count > 0
shape_dict = {}
dtype_dict = {}
for subgraph_index in range(subgraph_count):
subgraph = model.Subgraphs(subgraph_index)
inputs_count = subgraph.InputsLength()
assert inputs_count >= 1
for input_index in range(inputs_count):
input_ = subgraph.Inputs(input_index)
assert subgraph.TensorsLength() > input_
tensor = subgraph.Tensors(input_)
input_shape = tuple(tensor.ShapeAsNumpy())
tensor_type = tensor.Type()
input_name = tensor.Name().decode("utf8")
shape_dict[input_name] = input_shape
dtype_dict[input_name] = _decode_type(tensor_type)

return shape_dict, dtype_dict


def from_tflite(model, shape_dict=None, dtype_dict=None):
"""Convert from tflite model into compatible relay Function.

Parameters
@@ -3577,6 +3615,12 @@ def from_tflite(model, shape_dict, dtype_dict):

assert isinstance(model, tflite.Model.Model)

_shape_dict, _dtype_dict = _input_type(model)
if shape_dict is not None:
_shape_dict.update(shape_dict)
if dtype_dict is not None:
_dtype_dict.update(dtype_dict)

# keep the same as tflite
assert model.SubgraphsLength() == 1, "only support one subgraph (main subgraph)"
subgraph = model.Subgraphs(0)
@@ -3588,8 +3632,8 @@
exp_tab = ExprTable()
for model_input in model_inputs:
model_input_name = get_tensor_name(subgraph, model_input)
shape = shape_dict[model_input_name] if model_input_name in shape_dict else None
dtype = dtype_dict[model_input_name] if model_input_name in dtype_dict else "float32"
shape = _shape_dict[model_input_name] if model_input_name in _shape_dict else None
dtype = _dtype_dict[model_input_name] if model_input_name in _dtype_dict else "float32"
exp_tab.set_expr(model_input_name, _expr.var(model_input_name, shape=shape, dtype=dtype))

# op code in model
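A minimal sketch of calling the relay frontend directly with the new optional arguments (the file name is a placeholder; the flatbuffer loading step uses the standard generated tflite package API, as implied by the isinstance check above):

import tflite.Model
from tvm import relay

with open("model.tflite", "rb") as f:
    buf = f.read()
tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)

# shape_dict/dtype_dict are now optional: _input_type() fills them in
# from the subgraph's declared input tensors.
mod, params = relay.frontend.from_tflite(tflite_model)

# Explicit entries still take precedence over the values read from the graph
# (the input name below is hypothetical).
mod, params = relay.frontend.from_tflite(
    tflite_model,
    shape_dict={"input": (1, 224, 224, 3)},
    dtype_dict={"input": "float32"},
)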