[FRONTEND][TFLITE] get input tensor information from graph #7400
Conversation
merge upstream main
In general I think it is an improvement to have a default function to come up with shapes in case we don't provide them, so thanks for the initiative @euntaik.
I'm pointing to some similar logic we have elsewhere in the code base that can be ported here. Please have a look.
In any case, it would be good to add some tests as well, to make sure this doesn't break in the future.
python/tvm/relay/frontend/tflite.py
Outdated
@@ -3539,7 +3539,62 @@ def get_tensor_name(subgraph, tensor_idx):
    return subgraph.Tensors(tensor_idx).Name().decode("utf-8")


def from_tflite(model, shape_dict, dtype_dict):
def get_tensor_shape(subgraph, tensor_idx):
    """Get the tensor shape.
- minor: the name of the argument doesn't match the actual argument
- the types are not specified
Please review all docstrings being introduced here for the items above.
python/tvm/relay/frontend/tflite.py
Outdated
if shape_dict:
    shape = shape_dict[model_input_name] if model_input_name in shape_dict else None
else:
    shape = get_tensor_shape(subgraph, model_input)
if dtype_dict:
    dtype = dtype_dict[model_input_name] if model_input_name in dtype_dict else "float32"
else:
    dtype = get_tensor_type(subgraph, model_input)
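The per-input fallback pattern in the diff above can be sketched as two small, self-contained helpers. This is a minimal illustration, not the actual tflite.py code: `resolve_shape`/`resolve_dtype` are hypothetical names, and the graph-derived values are passed in as plain arguments rather than read from the subgraph.

```python
# Sketch of the fallback logic discussed above: a user-supplied dict wins,
# otherwise fall back to the value read from the graph. Names are invented
# for illustration; they are not part of the real frontend.

def resolve_shape(name, shape_dict, graph_shape):
    # Prefer the user-provided shape; None signals "unknown", as in the PR.
    if shape_dict:
        return shape_dict.get(name)  # None if the name is absent
    return graph_shape

def resolve_dtype(name, dtype_dict, graph_dtype):
    # Prefer the user-provided dtype, defaulting to "float32", as in the PR.
    if dtype_dict:
        return dtype_dict.get(name, "float32")
    return graph_dtype
```

Note the asymmetry the reviewer is circling: once a non-empty dict is given, missing names fall back to `None`/`"float32"` rather than to the graph, which is exactly the behavior the diff above encodes.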
We have a similar function in TVMC that collects the same information being proposed here. I agree we should move what is in there, to unify the functionality here.
Can you have a look at the function I'm pointing to here (below), spot why they are so different, and, if you agree on the best approach, improve it here and remove it there?
tvm/python/tvm/driver/tvmc/frontends.py
Lines 255 to 278 in 2999d03
@staticmethod
def _decode_type(n):
    return TFLiteFrontend._tflite_m[n]

@staticmethod
def _input_type(model):
    subgraph_count = model.SubgraphsLength()
    assert subgraph_count > 0
    shape_dict = {}
    dtype_dict = {}
    for subgraph_index in range(subgraph_count):
        subgraph = model.Subgraphs(subgraph_index)
        inputs_count = subgraph.InputsLength()
        assert inputs_count >= 1
        for input_index in range(inputs_count):
            input_ = subgraph.Inputs(input_index)
            assert subgraph.TensorsLength() > input_
            tensor = subgraph.Tensors(input_)
            input_shape = tuple(tensor.ShapeAsNumpy())
            tensor_type = tensor.Type()
            input_name = tensor.Name().decode("utf8")
            shape_dict[input_name] = input_shape
            dtype_dict[input_name] = TFLiteFrontend._decode_type(tensor_type)
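To make the traversal above concrete without the generated TFLite bindings, here is a self-contained sketch using stand-in classes. `FakeTensor`, `FakeSubgraph`, `FakeModel`, and `input_type` are all invented for this example; they only mimic the FlatBuffers accessor methods (`Subgraphs`, `Inputs`, `Tensors`, `ShapeAsNumpy`, …) that the real `_input_type` calls.

```python
# Stand-ins mimicking the FlatBuffers accessors used by _input_type above.
# These classes are invented for illustration; the real objects come from
# the generated tflite Python bindings.

class FakeTensor:
    def __init__(self, name, shape, type_code):
        self._name, self._shape, self._type = name, shape, type_code

    def Name(self):
        return self._name.encode("utf8")

    def ShapeAsNumpy(self):
        return self._shape

    def Type(self):
        return self._type

class FakeSubgraph:
    def __init__(self, tensors, inputs):
        self._tensors, self._inputs = tensors, inputs

    def InputsLength(self):
        return len(self._inputs)

    def Inputs(self, i):
        return self._inputs[i]

    def TensorsLength(self):
        return len(self._tensors)

    def Tensors(self, i):
        return self._tensors[i]

class FakeModel:
    def __init__(self, subgraphs):
        self._subgraphs = subgraphs

    def SubgraphsLength(self):
        return len(self._subgraphs)

    def Subgraphs(self, i):
        return self._subgraphs[i]

# Partial code-to-dtype table; 0 -> "float32" follows the TFLite schema.
_TFLITE_M = {0: "float32", 2: "int32", 3: "uint8"}

def input_type(model):
    """Collect shape/dtype dicts for every subgraph input, as above."""
    shape_dict, dtype_dict = {}, {}
    for si in range(model.SubgraphsLength()):
        subgraph = model.Subgraphs(si)
        for ii in range(subgraph.InputsLength()):
            tensor = subgraph.Tensors(subgraph.Inputs(ii))
            name = tensor.Name().decode("utf8")
            shape_dict[name] = tuple(tensor.ShapeAsNumpy())
            dtype_dict[name] = _TFLITE_M[tensor.Type()]
    return shape_dict, dtype_dict
```

The key point of the TVMC version, which the discussion below settles on keeping, is the outer loop: it walks every subgraph instead of assuming the model has exactly one.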
We have a similar function in TVMC that collects the same information being proposed here. I agree we should move what is in there, to unify the functionality here.
Oh, it was there all along. I think I missed your code since I was loading my models in a separate script to feed the Relay output into my compile passes.
Can you have a look at the function I'm pointing to here (below), spot why they are so different,
I don't see much difference, except that your code accounts for models with more than one subgraph.
and, if you agree on the best approach, improve it here and remove it there?
My rationale for putting this code in the tflite.py file was:
- use the data in the graph, since it is already embedded in it.
- place the code inside the frontend code, since it is dependent on the frontend.
Cool. I think we both agree that it is better to have the functionality only in tflite.py, and remove it from TVMC.
So I suggest we keep the version that accounts for multiple subgraphs, and move it from TVMC to the official frontend. If you agree, feel free to do it in this PR.
Thanks, I will update the PR.
Looks better now. A few comments below, mostly in the TVMC area.
@@ -3539,7 +3539,45 @@ def get_tensor_name(subgraph, tensor_idx):
    return subgraph.Tensors(tensor_idx).Name().decode("utf-8")


def from_tflite(model, shape_dict, dtype_dict):
def _decode_type(n):
    _tflite_m = {
I see this is duplicated in tvmc/frontends.py - is there any reason why we can't reuse this one there?
I missed it. Fixed.
python/tvm/driver/tvmc/frontends.py
Outdated
mod, params = relay.frontend.from_tflite(
    tflite_model, shape_dict=input_shapes, dtype_dict=dtype_dict
)
mod, params = relay.frontend.from_tflite(tflite_model)
Since we merged #7366, users are able to provide shapes to tvmc from the outside; can you have a look at that one and adjust?
mod, params = relay.frontend.from_tflite(tflite_model)
mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=input_shapes, dtype_dict=dtype_dict)
cc @CircleSpin @hogepodge to help
fixed it
Is the from_tflite() now duplicated? (I just looked quickly, might be wrong.)
You are right. Sorry for that.
Any other comments?
Sorry, I forgot to check this again after CI.
@mbaret @FrozenGene can you have a look at this one, and merge if you think it is ok?
This looks like a good change to me.
* [FRONTEND][TFLITE] get input tensor information from graph
* remove bare-except
* fix lint
* delete empty line
* comment change
* move some of the tflite frontend code from tvmc to tflite.py
* update shape and dtype when user provided them
* remove unused var. pass user provided shape_dict
* remove duplicate code