
[FRONTEND][TFLITE] get input tensor information from graph #7400

Merged (13 commits) on Feb 15, 2021
67 changes: 64 additions & 3 deletions python/tvm/relay/frontend/tflite.py
@@ -3539,7 +3539,62 @@ def get_tensor_name(subgraph, tensor_idx):
    return subgraph.Tensors(tensor_idx).Name().decode("utf-8")


def from_tflite(model, shape_dict, dtype_dict):
def get_tensor_shape(subgraph, tensor_idx):
    """Get the tensor shape.

    Parameters
    ----------
    subgraph:
        tflite.Subgraph.Subgraph

    tensor:
        tensor index in subgraph

    Returns
    -------
    tensor shape
    """
    return tuple(subgraph.Tensors(tensor_idx).ShapeAsNumpy())

Contributor:

  • minor: the name of the argument doesn't match the actual argument
  • the types are not specified

Please review all docstrings being introduced here for those items.


def get_tensor_type(subgraph, tensor_idx):
    """Get the tensor type.

    Parameters
    ----------
    subgraph:
        tflite.Subgraph.Subgraph

    tensor:
        tensor index in subgraph

    Returns
    -------
    tensor type in string
    """
    from enum import Enum

    class TensorType(Enum):
        """Enum defined in tensorflow lite"""

        FLOAT32 = 0
        FLOAT16 = 1
        INT32 = 2
        UINT8 = 3
        INT64 = 4
        STRING = 5
        BOOL = 6
        INT16 = 7
        COMPLEX64 = 8
        INT8 = 9
        FLOAT64 = 10
        COMPLEX128 = 11
        UINT64 = 12

    return TensorType(subgraph.Tensors(tensor_idx).Type()).name.lower()
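The enum-based decoding used above can be exercised on its own, without a TFLite model; a minimal standalone sketch (the enum values mirror the TFLite schema as listed in the diff, and `decode_tensor_type` is a hypothetical helper name for illustration):

```python
from enum import Enum


class TensorType(Enum):
    """Tensor element types as defined in the TFLite flatbuffer schema."""

    FLOAT32 = 0
    FLOAT16 = 1
    INT32 = 2
    UINT8 = 3
    INT64 = 4
    STRING = 5
    BOOL = 6
    INT16 = 7
    COMPLEX64 = 8
    INT8 = 9
    FLOAT64 = 10
    COMPLEX128 = 11
    UINT64 = 12


def decode_tensor_type(type_code):
    # Map a raw TFLite type code to the lowercase dtype string Relay expects.
    return TensorType(type_code).name.lower()


print(decode_tensor_type(0))  # float32
print(decode_tensor_type(3))  # uint8
```

Lowercasing the enum name is what makes the result line up with Relay dtype strings such as "float32" and "int8".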


def from_tflite(model, shape_dict=None, dtype_dict=None):
"""Convert from tflite model into compatible relay Function.

Parameters
@@ -3588,8 +3643,14 @@ def from_tflite(model, shape_dict, dtype_dict):
    exp_tab = ExprTable()
    for model_input in model_inputs:
        model_input_name = get_tensor_name(subgraph, model_input)
        shape = shape_dict[model_input_name] if model_input_name in shape_dict else None
        dtype = dtype_dict[model_input_name] if model_input_name in dtype_dict else "float32"
        if shape_dict:
            shape = shape_dict[model_input_name] if model_input_name in shape_dict else None
        else:
            shape = get_tensor_shape(subgraph, model_input)
        if dtype_dict:
            dtype = dtype_dict[model_input_name] if model_input_name in dtype_dict else "float32"
        else:
            dtype = get_tensor_type(subgraph, model_input)
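The precedence the hunk introduces — a user-supplied dict wins, and the graph's own metadata is only the fallback when no dict is given — can be sketched in isolation; `resolve_input_spec` and its parameters are hypothetical names for illustration:

```python
def resolve_input_spec(name, shape_dict, dtype_dict, graph_shape, graph_dtype):
    """Pick shape/dtype for one input: user dicts first, graph metadata as fallback."""
    if shape_dict:
        # Dict provided: honor it; a missing entry yields None so Relay can infer.
        shape = shape_dict.get(name)
    else:
        # No dict (None or empty): trust the shape recorded in the graph.
        shape = graph_shape
    if dtype_dict:
        dtype = dtype_dict.get(name, "float32")
    else:
        dtype = graph_dtype
    return shape, dtype


print(resolve_input_spec("x", {}, {}, (1, 3), "int8"))          # ((1, 3), 'int8')
print(resolve_input_spec("x", {"x": (2,)}, None, (1, 3), "int8"))  # ((2,), 'int8')
```

Note that an empty dict is falsy, so passing `{}` behaves like passing `None`: both fall back to what the graph records.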
Contributor:

We have a similar function, which collects the same information being proposed here, in TVMC. I agree we should move what is in there here, to unify the functionality.

Can you have a look at the function I'm pointing to (below), spot why they are so different, and, if you agree on what's the best approach, improve it here and remove it there?

@staticmethod
def _decode_type(n):
    return TFLiteFrontend._tflite_m[n]

@staticmethod
def _input_type(model):
    subgraph_count = model.SubgraphsLength()
    assert subgraph_count > 0
    shape_dict = {}
    dtype_dict = {}
    for subgraph_index in range(subgraph_count):
        subgraph = model.Subgraphs(subgraph_index)
        inputs_count = subgraph.InputsLength()
        assert inputs_count >= 1
        for input_index in range(inputs_count):
            input_ = subgraph.Inputs(input_index)
            assert subgraph.TensorsLength() > input_
            tensor = subgraph.Tensors(input_)
            input_shape = tuple(tensor.ShapeAsNumpy())
            tensor_type = tensor.Type()
            input_name = tensor.Name().decode("utf8")
            shape_dict[input_name] = input_shape
            dtype_dict[input_name] = TFLiteFrontend._decode_type(tensor_type)
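The shape of that multi-subgraph walk can be demonstrated end to end without TVM; in the sketch below, the `_Fake*` classes are stand-ins I invented to mimic the method names of the real TFLite flatbuffer objects (`SubgraphsLength`, `Inputs`, `Tensors`, ...), and `_TYPE_NAMES` is a reduced stand-in for the frontend's type table:

```python
class _FakeTensor:
    def __init__(self, name, shape, type_code):
        self._name, self._shape, self._type = name, shape, type_code

    def Name(self):
        return self._name.encode("utf8")

    def ShapeAsNumpy(self):
        return self._shape

    def Type(self):
        return self._type


class _FakeSubgraph:
    def __init__(self, tensors, inputs):
        self._tensors, self._inputs = tensors, inputs

    def InputsLength(self):
        return len(self._inputs)

    def Inputs(self, i):
        return self._inputs[i]

    def Tensors(self, i):
        return self._tensors[i]


class _FakeModel:
    def __init__(self, subgraphs):
        self._subgraphs = subgraphs

    def SubgraphsLength(self):
        return len(self._subgraphs)

    def Subgraphs(self, i):
        return self._subgraphs[i]


# Reduced stand-in for the frontend's type-code table.
_TYPE_NAMES = {0: "float32", 2: "int32", 3: "uint8"}


def input_type(model):
    """Collect {input_name: shape} and {input_name: dtype} over all subgraphs."""
    shape_dict, dtype_dict = {}, {}
    for subgraph_index in range(model.SubgraphsLength()):
        subgraph = model.Subgraphs(subgraph_index)
        for input_index in range(subgraph.InputsLength()):
            tensor = subgraph.Tensors(subgraph.Inputs(input_index))
            name = tensor.Name().decode("utf8")
            shape_dict[name] = tuple(tensor.ShapeAsNumpy())
            dtype_dict[name] = _TYPE_NAMES[tensor.Type()]
    return shape_dict, dtype_dict


model = _FakeModel([_FakeSubgraph([_FakeTensor("input", (1, 224, 224, 3), 0)], [0])])
shapes, dtypes = input_type(model)
print(shapes)  # {'input': (1, 224, 224, 3)}
print(dtypes)  # {'input': 'float32'}
```

The outer loop over subgraphs is the only real difference from the per-input helpers in the diff above, which is the point the reviewer raises.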

Contributor Author — @euntaik, Feb 10, 2021:

> We have a similar function, that collect the same information being proposed here in TVMC. I agree we should move what is in there, to unify functionality here.

Oh, it was there all along. I think I missed your code, since I was loading my models in a separate script to feed the Relay output into my compile passes.

> Can you have a look on the function I'm pointing here (below) and spot why are they so different,

I don't see much difference, except that your code accounts for models with more than one subgraph.

> and in case you agree on what's the best approach, improve it here and remove it there?

My rationale for putting this code in the tflite.py file was:

  1. use the data in the graph, since it is already embedded there.
  2. place the code inside the frontend code, since it is frontend-dependent.

Contributor:

Cool. I think we both agree that it is better to have the functionality only in tflite.py, and to remove it from TVMC.

So I suggest we keep the version that accounts for many subgraphs, and move it from TVMC to the official frontend. If you agree, feel free to do it in this PR.

Contributor Author:

Thanks, I will update the PR.

        exp_tab.set_expr(model_input_name, _expr.var(model_input_name, shape=shape, dtype=dtype))

    # op code in model