Skip to content

Commit

Permalink
[frontend] add support for pb graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
Rafael Stahl committed Jun 29, 2022
1 parent ba0c6a4 commit d4c1c4e
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 9 deletions.
7 changes: 5 additions & 2 deletions mlonmcu/flow/tvm/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from mlonmcu.setup import utils
from mlonmcu.config import str2bool
from mlonmcu.models.model import ModelFormats
from .model_info import get_tflite_model_info, get_relay_model_info
from .model_info import get_tflite_model_info, get_relay_model_info, get_pb_model_info
from .tuner import TVMTuner
from .python_utils import prepare_python_environment
from .tvmc_utils import (
Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(self, features=None, config=None, context=None):
self.model = None # Actual filename!
self.model_info = None
self.input_shapes = None
self.supported_formats = [ModelFormats.TFLITE, ModelFormats.RELAY]
self.supported_formats = [ModelFormats.TFLITE, ModelFormats.RELAY, ModelFormats.PB]

self.prefix = "default"
self.artifacts = (
Expand Down Expand Up @@ -254,6 +254,9 @@ def load_model(self, model):
with open(model, "r") as handle:
mod_text = handle.read()
self.model_info = get_relay_model_info(mod_text)
elif fmt == ModelFormats.PB:
self.model_format = "pb"
self.model_info = get_pb_model_info(model)
else:
raise RuntimeError(f"Unsupported model format '{fmt.name}' for backend '{self.name}'")
self.input_shapes = {tensor.name: tensor.shape for tensor in self.model_info.in_tensors}
34 changes: 34 additions & 0 deletions mlonmcu/flow/tvm/backend/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import re
import tflite
from tflite.TensorType import TensorType as TType
import tensorflow as tf


class TensorInfo:
Expand Down Expand Up @@ -143,6 +144,34 @@ def __init__(self, mod_text, fix_names=False):
super().__init__(in_tensors, out_tensors)


def get_tfgraph_inout(graph, graph_def):
ops = graph.get_operations()
outputs_set = set(ops)
inputs = []
for op in ops:
if op.type == "Placeholder":
inputs.append(op)
else:
for input_tensor in op.inputs:
if input_tensor.op in outputs_set:
outputs_set.remove(input_tensor.op)
outputs = list(outputs_set)
return (inputs, outputs)


class PBModelInfo(ModelInfo):
def __init__(self, model_file):
with tf.io.gfile.GFile(model_file, "rb") as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def)
inputs, outputs = get_tfgraph_inout(graph, graph_def)
in_tensors = [TensorInfo(t.name, t.shape.as_list(), t.dtype.name) for op in inputs for t in op.outputs]
out_tensors = [TensorInfo(t.name, t.shape.as_list(), t.dtype.name) for op in outputs for t in op.outputs]
super().__init__(in_tensors, out_tensors)


def get_tflite_model_info(model_buf):
tflite_model = tflite.Model.GetRootAsModel(model_buf, 0)
model_info = TfLiteModelInfo(tflite_model)
Expand All @@ -152,3 +181,8 @@ def get_tflite_model_info(model_buf):
def get_relay_model_info(mod_text):
model_info = RelayModelInfo(mod_text)
return model_info


def get_pb_model_info(model_file):
model_info = PBModelInfo(model_file)
return model_info
5 changes: 3 additions & 2 deletions mlonmcu/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
# limitations under the License.
#
from mlonmcu.models.lookup import print_summary
from .frontend import TfLiteFrontend, PackedFrontend, ONNXFrontend, RelayFrontend
from .frontend import PBFrontend, TfLiteFrontend, PackedFrontend, ONNXFrontend, RelayFrontend

SUPPORTED_FRONTENDS = {
"tflite": TfLiteFrontend,
"relay": RelayFrontend,
"packed": PackedFrontend,
"onnx": ONNXFrontend,
"pb": PBFrontend,
} # TODO: use registry instead

__all__ = ["print_summary", "TfLiteFrontend", "PackedFrontend", "ONNXFrontend", "SUPPORTED_FRONTENDS"]
__all__ = ["print_summary", "TfLiteFrontend", "PackedFrontend", "ONNXFrontend", "PBFrontend", "SUPPORTED_FRONTENDS"]
24 changes: 21 additions & 3 deletions mlonmcu/models/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,9 +522,27 @@ class ONNXFrontend(SimpleFrontend):

def __init__(self, features=None, config=None):
super().__init__(
name="onnx",
input_formats=[ModelFormats.ONNX],
output_formats=[ModelFormats.ONNX],
"onnx",
ModelFormats.ONNX,
features=features,
config=config,
)


class PBFrontend(SimpleFrontend):

FEATURES = Frontend.FEATURES + ["visualize"]

DEFAULTS = {
**Frontend.DEFAULTS,
}

REQUIRED = Frontend.REQUIRED + []

def __init__(self, features=None, config=None):
super().__init__(
"pb",
ModelFormats.PB,
features=features,
config=config,
)
2 changes: 1 addition & 1 deletion mlonmcu/models/lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def lookup_models_and_groups(directories, formats):
duplicates = {}
group_duplicates = {}
for directory in directories:
models = list_models(directory)
models = list_models(directory, formats=formats)
if len(all_models) == 0:
all_models = models
else:
Expand Down
1 change: 1 addition & 0 deletions mlonmcu/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def from_extension(cls, ext):
IPYNB = ModelFormat(3, ["ipynb"])
ONNX = ModelFormat(4, ["onnx"])
RELAY = ModelFormat(5, ["relay"])
PB = ModelFormat(6, ["pb"])


def parse_metadata_from_path(path):
Expand Down
2 changes: 1 addition & 1 deletion mlonmcu/session/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def add_backend(self, backend):
# assert len(self.platforms) > 0, "Add at least a platform before adding a backend."
if self.model is not None:
assert self.backend.supports_model(self.model), (
"The added backend does not support the chosen model."
"The added backend does not support the chosen model. "
"Add the backend before adding a model to find a suitable frontend."
)
for platform in self.platforms:
Expand Down

0 comments on commit d4c1c4e

Please sign in to comment.