Skip to content

Commit

Permalink
[FRONTEND][TFLITE] get input tensor information from graph (#7400)
Browse files Browse the repository at this point in the history
* [FRONTEND][TFLITE] get input tensor information from graph

* remove bare-except

* fix lint

* delete empty line

* comment change

* move some of the tflite frontend code from tvmc to tflite.py

* update shape and dtype when user provided them

* remove unused var. pass user provided shape_dict

* remove duplicate code
  • Loading branch information
eric authored Feb 15, 2021
1 parent 2af3ab1 commit 6187e1c
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 50 deletions.
48 changes: 1 addition & 47 deletions python/tvm/driver/tvmc/frontends.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,19 +198,6 @@ def load(self, path, shape_dict=None):
class TFLiteFrontend(Frontend):
""" TFLite frontend for TVMC """

_tflite_m = {
0: "float32",
1: "float16",
2: "int32",
3: "uint8",
4: "int64",
5: "string",
6: "bool",
7: "int16",
8: "complex64",
9: "int8",
}

@staticmethod
def name():
return "tflite"
Expand Down Expand Up @@ -241,43 +228,10 @@ def load(self, path, shape_dict=None):
if version != 3:
raise TVMCException("input file not tflite version 3")

logger.debug("tflite_input_type")
input_shapes, dtype_dict = TFLiteFrontend._input_type(tflite_model)
if shape_dict is not None:
input_shapes.update(shape_dict)

logger.debug("parse TFLite model and convert into Relay computation graph")
mod, params = relay.frontend.from_tflite(
tflite_model, shape_dict=input_shapes, dtype_dict=dtype_dict
)
mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict)
return mod, params

@staticmethod
def _decode_type(n):
return TFLiteFrontend._tflite_m[n]

@staticmethod
def _input_type(model):
subgraph_count = model.SubgraphsLength()
assert subgraph_count > 0
shape_dict = {}
dtype_dict = {}
for subgraph_index in range(subgraph_count):
subgraph = model.Subgraphs(subgraph_index)
inputs_count = subgraph.InputsLength()
assert inputs_count >= 1
for input_index in range(inputs_count):
input_ = subgraph.Inputs(input_index)
assert subgraph.TensorsLength() > input_
tensor = subgraph.Tensors(input_)
input_shape = tuple(tensor.ShapeAsNumpy())
tensor_type = tensor.Type()
input_name = tensor.Name().decode("utf8")
shape_dict[input_name] = input_shape
dtype_dict[input_name] = TFLiteFrontend._decode_type(tensor_type)

return shape_dict, dtype_dict


class PyTorchFrontend(Frontend):
""" PyTorch frontend for TVMC """
Expand Down
50 changes: 47 additions & 3 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3539,7 +3539,45 @@ def get_tensor_name(subgraph, tensor_idx):
return subgraph.Tensors(tensor_idx).Name().decode("utf-8")


def from_tflite(model, shape_dict, dtype_dict):
def _decode_type(n):
_tflite_m = {
0: "float32",
1: "float16",
2: "int32",
3: "uint8",
4: "int64",
5: "string",
6: "bool",
7: "int16",
8: "complex64",
9: "int8",
}
return _tflite_m[n]


def _input_type(model):
subgraph_count = model.SubgraphsLength()
assert subgraph_count > 0
shape_dict = {}
dtype_dict = {}
for subgraph_index in range(subgraph_count):
subgraph = model.Subgraphs(subgraph_index)
inputs_count = subgraph.InputsLength()
assert inputs_count >= 1
for input_index in range(inputs_count):
input_ = subgraph.Inputs(input_index)
assert subgraph.TensorsLength() > input_
tensor = subgraph.Tensors(input_)
input_shape = tuple(tensor.ShapeAsNumpy())
tensor_type = tensor.Type()
input_name = tensor.Name().decode("utf8")
shape_dict[input_name] = input_shape
dtype_dict[input_name] = _decode_type(tensor_type)

return shape_dict, dtype_dict


def from_tflite(model, shape_dict=None, dtype_dict=None):
"""Convert from tflite model into compatible relay Function.
Parameters
Expand Down Expand Up @@ -3577,6 +3615,12 @@ def from_tflite(model, shape_dict, dtype_dict):

assert isinstance(model, tflite.Model.Model)

_shape_dict, _dtype_dict = _input_type(model)
if shape_dict is not None:
_shape_dict.update(shape_dict)
if dtype_dict is not None:
_dtype_dict.update(dtype_dict)

# keep the same as tflite
assert model.SubgraphsLength() == 1, "only support one subgraph (main subgraph)"
subgraph = model.Subgraphs(0)
Expand All @@ -3588,8 +3632,8 @@ def from_tflite(model, shape_dict, dtype_dict):
exp_tab = ExprTable()
for model_input in model_inputs:
model_input_name = get_tensor_name(subgraph, model_input)
shape = shape_dict[model_input_name] if model_input_name in shape_dict else None
dtype = dtype_dict[model_input_name] if model_input_name in dtype_dict else "float32"
shape = _shape_dict[model_input_name] if model_input_name in _shape_dict else None
dtype = _dtype_dict[model_input_name] if model_input_name in _dtype_dict else "float32"
exp_tab.set_expr(model_input_name, _expr.var(model_input_name, shape=shape, dtype=dtype))

# op code in model
Expand Down

0 comments on commit 6187e1c

Please sign in to comment.