Skip to content

Commit

Permalink
[models] fix bug in model lookup with multiple formats
Browse files Browse the repository at this point in the history
  • Loading branch information
Rafael Stahl committed Jun 29, 2022
1 parent 6a1847f commit 2a7cd76
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions mlonmcu/flow/tvm/backend/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def __init__(self, mod_text, fix_names=False):
super().__init__(in_tensors, out_tensors)


def get_tfgraph_inout(graph, graph_def):
def get_tfgraph_inout(graph):
ops = graph.get_operations()
outputs_set = set(ops)
inputs = []
Expand All @@ -166,7 +166,7 @@ def __init__(self, model_file):
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def)
inputs, outputs = get_tfgraph_inout(graph, graph_def)
inputs, outputs = get_tfgraph_inout(graph)
in_tensors = [TensorInfo(t.name, t.shape.as_list(), t.dtype.name) for op in inputs for t in op.outputs]
out_tensors = [TensorInfo(t.name, t.shape.as_list(), t.dtype.name) for op in outputs for t in op.outputs]
super().__init__(in_tensors, out_tensors)
Expand Down
2 changes: 1 addition & 1 deletion mlonmcu/models/lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def list_models(directory, depth=1, formats=None, config=None): # TODO: get con
config = config if config is not None else {}
formats = formats if formats else [ModelFormats.TFLITE]
assert len(formats) > 0, "No formats provided for model lookup"
models = []
for fmt in formats:
if depth != 1:
raise NotImplementedError # TODO: implement for arm ml zoo
Expand All @@ -68,7 +69,6 @@ def list_models(directory, depth=1, formats=None, config=None): # TODO: get con
logger.debug("Not a directory: %s", str(directory))
return []
subdirs = [Path(directory) / o for o in os.listdir(directory) if os.path.isdir(os.path.join(directory, o))]
models = []
for subdir in subdirs:
dirname = subdir.name
if dirname.startswith("."):
Expand Down

0 comments on commit 2a7cd76

Please sign in to comment.