-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FRONTEND][TFLITE] get input tensor information from graph #7400
Changes from 7 commits
3d4e491
bd5a180
d4532d7
d58353b
e4621ec
86c70b7
7ec10c4
c954a95
dc91132
53e044a
2ff9b43
b80e798
bc89d27
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -3539,7 +3539,62 @@ def get_tensor_name(subgraph, tensor_idx): | |||||||||||||||||||||||||||||||||||||||||||||||||
return subgraph.Tensors(tensor_idx).Name().decode("utf-8") | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
def from_tflite(model, shape_dict, dtype_dict): | ||||||||||||||||||||||||||||||||||||||||||||||||||
def get_tensor_shape(subgraph, tensor_idx): | ||||||||||||||||||||||||||||||||||||||||||||||||||
"""Get the tensor shape. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
Parameters | ||||||||||||||||||||||||||||||||||||||||||||||||||
---------- | ||||||||||||||||||||||||||||||||||||||||||||||||||
subgraph: | ||||||||||||||||||||||||||||||||||||||||||||||||||
tflite.Subgraph.Subgraph | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||
tensor index in subgraph | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
Returns | ||||||||||||||||||||||||||||||||||||||||||||||||||
------- | ||||||||||||||||||||||||||||||||||||||||||||||||||
tensor shape | ||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||
return tuple(subgraph.Tensors(tensor_idx).ShapeAsNumpy()) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
def get_tensor_type(subgraph, tensor_idx): | ||||||||||||||||||||||||||||||||||||||||||||||||||
"""Get the tensor type. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
Parameters | ||||||||||||||||||||||||||||||||||||||||||||||||||
---------- | ||||||||||||||||||||||||||||||||||||||||||||||||||
subgraph: | ||||||||||||||||||||||||||||||||||||||||||||||||||
tflite.Subgraph.Subgraph | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||
tensor index in subgraph | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
Returns | ||||||||||||||||||||||||||||||||||||||||||||||||||
------- | ||||||||||||||||||||||||||||||||||||||||||||||||||
tensor type in string | ||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||
from enum import Enum | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
class TensorType(Enum): | ||||||||||||||||||||||||||||||||||||||||||||||||||
""" Enum defined in tensorflow lite """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
FLOAT32 = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||
FLOAT16 = 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||
INT32 = 2 | ||||||||||||||||||||||||||||||||||||||||||||||||||
UINT8 = 3 | ||||||||||||||||||||||||||||||||||||||||||||||||||
INT64 = 4 | ||||||||||||||||||||||||||||||||||||||||||||||||||
STRING = 5 | ||||||||||||||||||||||||||||||||||||||||||||||||||
BOOL = 6 | ||||||||||||||||||||||||||||||||||||||||||||||||||
INT16 = 7 | ||||||||||||||||||||||||||||||||||||||||||||||||||
COMPLEX64 = 8 | ||||||||||||||||||||||||||||||||||||||||||||||||||
INT8 = 9 | ||||||||||||||||||||||||||||||||||||||||||||||||||
FLOAT64 = 10 | ||||||||||||||||||||||||||||||||||||||||||||||||||
COMPLEX128 = 11 | ||||||||||||||||||||||||||||||||||||||||||||||||||
UINT64 = 12 | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
return TensorType(subgraph.Tensors(tensor_idx).Type()).name.lower() | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
def from_tflite(model, shape_dict=None, dtype_dict=None): | ||||||||||||||||||||||||||||||||||||||||||||||||||
"""Convert from tflite model into compatible relay Function. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
Parameters | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -3588,8 +3643,14 @@ def from_tflite(model, shape_dict, dtype_dict): | |||||||||||||||||||||||||||||||||||||||||||||||||
exp_tab = ExprTable() | ||||||||||||||||||||||||||||||||||||||||||||||||||
for model_input in model_inputs: | ||||||||||||||||||||||||||||||||||||||||||||||||||
model_input_name = get_tensor_name(subgraph, model_input) | ||||||||||||||||||||||||||||||||||||||||||||||||||
shape = shape_dict[model_input_name] if model_input_name in shape_dict else None | ||||||||||||||||||||||||||||||||||||||||||||||||||
dtype = dtype_dict[model_input_name] if model_input_name in dtype_dict else "float32" | ||||||||||||||||||||||||||||||||||||||||||||||||||
if shape_dict: | ||||||||||||||||||||||||||||||||||||||||||||||||||
shape = shape_dict[model_input_name] if model_input_name in shape_dict else None | ||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||
shape = get_tensor_shape(subgraph, model_input) | ||||||||||||||||||||||||||||||||||||||||||||||||||
if dtype_dict: | ||||||||||||||||||||||||||||||||||||||||||||||||||
dtype = dtype_dict[model_input_name] if model_input_name in dtype_dict else "float32" | ||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||
dtype = get_tensor_type(subgraph, model_input) | ||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have a similar function, that collect the same information being proposed here in TVMC. I agree we should move what is in there, to unify functionality here. Can you have a look on the function I'm pointing here (below) and spot why are they so different, and in case you agree on what's the best approach, improve it here and remove it there? tvm/python/tvm/driver/tvmc/frontends.py Lines 255 to 278 in 2999d03
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Oh, it was there all along. I think I missed your code since I was loading my models in a separate script to put the relay output into my compile passes.
I don't see much difference except that your code accounts for models with more than one subgraph.
My rationale behind making and putting this code in the tflite.py file was:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Cool. I think we both agree that it is better to have the funcionality only in the tflite.py, and remove it from TVMC. So I suggest we keep the one that accounts for many subgraphs, and move it from TVMC to the official frontend? If you agree, feel free to do it in this PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, I will update the PR. |
||||||||||||||||||||||||||||||||||||||||||||||||||
exp_tab.set_expr(model_input_name, _expr.var(model_input_name, shape=shape, dtype=dtype)) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
# op code in model | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please review all docstring being introduced here for those items above.