Skip to content

Commit

Permalink
Added sanity check for unsupported formats to minimize bugs from typo…
Browse files Browse the repository at this point in the history
…s and other human error
  • Loading branch information
jaewon committed May 24, 2022
1 parent de8b35b commit 987cc3d
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions utils/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,28 @@ def get_benchmark_threshold_value(data_path: str, model_name: str) -> t.Union[fl
return map_benchmark_threshold


def get_unsupported_formats(unsupported_arguments: t.Tuple = ('edgetpu', 'tfjs', 'engine', 'coreml')) -> t.Tuple:
def get_unsupported_formats() -> t.Tuple:
# coreml: Exception: Model prediction is only supported on macOS version 10.13 or later.
# engine: Requires gpu and docker container with TensorRT dependencies to run
# tfjs: Conflict with openvino numpy version (openvino < 1.20, tfjs >= 1.20)
# edgetpu: requires coral board, cloud tpu or some other external tpu
return 'edgetpu', 'tfjs', 'engine', 'coreml'


def check_if_formats_exist(unsupported_arguments: t.Tuple):
"""
Check to see if the formats actually exists under export_formats().
An error will be thrown if the argument type does not exist
Args:
unsupported_arguments: A tuple of unsupported export formats
"""
export_formats = export.export_formats()
unsupported = export_formats[export_formats['Argument'].isin(unsupported_arguments)].iloc[:, 1].values.tolist()
return tuple(unsupported)
valid_export_format_arguments = set(export_formats.Argument)
for unsupported_arg in unsupported_arguments:
if unsupported_arg not in valid_export_format_arguments:
raise ValueError(f'Argument: "{unsupported_arg}" is not a valid export format.\n'
f'Valid export formats: {", ".join(valid_export_format_arguments)[: -1]}. \n'
f'See export.export_formats() for more info.')


def get_benchmark_values(
Expand Down Expand Up @@ -139,6 +153,10 @@ def run(
model_name = str(weights).split('/')[-1].split('.')[0]
map_benchmark_threshold = get_benchmark_threshold_value(str(data), model_name)

# get unsupported formats and check if they exist under exports.get_exports()
check_if_formats_exist(get_unsupported_formats())
# check

for i, (name, f, suffix, gpu) in formats.iterrows(): # index, (name, file, suffix, gpu-capable)
if hard_fail:
if f in get_unsupported_formats():
Expand Down

0 comments on commit 987cc3d

Please sign in to comment.