From d4338a952246b3bd3125ccf22459797823f9dc19 Mon Sep 17 00:00:00 2001 From: Philipp van Kempen Date: Thu, 7 Dec 2023 12:29:19 +0100 Subject: [PATCH] RelayModelInfo: support non-tensor outputs Thanks to @jokap11 --- mlonmcu/flow/tvm/backend/model_info.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mlonmcu/flow/tvm/backend/model_info.py b/mlonmcu/flow/tvm/backend/model_info.py index 55bdbb008..c94c8d0ab 100644 --- a/mlonmcu/flow/tvm/backend/model_info.py +++ b/mlonmcu/flow/tvm/backend/model_info.py @@ -143,7 +143,7 @@ def parse_relay_main(line): output_tensors_str = re.compile(r"-> (.+) {").findall(line) # The following depends on InferType annocations if len(output_tensors_str) > 0: - output_tensor_strs = re.compile(r"Tensor\[\([\di]+(?:, [\di]+)*\), [a-zA-Z0-9_]+\]").findall( + output_tensor_strs = re.compile(r"Tensor\[\([\di]+(?:, [\di]+)*\), [a-zA-Z0-9_]+\]|(?:u?int\d+)").findall( output_tensors_str[0] ) @@ -156,10 +156,15 @@ def parse_relay_main(line): for i, output_name in enumerate(output_tensor_names): res = re.compile(r"Tensor\[\(([\di]+(?:, [\di]+)*)\), ([a-zA-Z0-9_]+)\]").match(output_tensor_strs[i]) + if res is None: + res = re.compile(r"(u?int\d+)").match(output_tensor_strs[i]) assert res is not None groups = res.groups() - assert len(groups) == 2 - output_shape_str, output_type = groups + assert len(groups) in [1, 2] + if len(groups) == 2: + output_shape_str, output_type = groups + elif len(groups) == 1: + output_shape_str, output_type = "1, 1", groups[0] output_shape = shape_from_str(output_shape_str) output_tensor = TensorInfo(output_name, output_shape, output_type) output_tensors.append(output_tensor)