Skip to content

Commit

Permalink
lint code
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilippvK committed May 29, 2024
1 parent 9196ede commit cbc5a15
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 28 deletions.
24 changes: 19 additions & 5 deletions mlonmcu/models/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,12 @@ def generate_input_data(self, input_names, input_types, input_shapes, input_rang
if CLIP_INPUTS:
arr = np.clip(arr, lower, upper)
else:
assert np.min(arr) >= lower or np.isclose(np.min(arr), lower), "Range missmatch (lower)"
assert np.max(arr) <= upper or np.isclose(np.max(arr), upper), "Range missmatch (upper)"
assert np.min(arr) >= lower or np.isclose(
np.min(arr), lower
), "Range missmatch (lower)"
assert np.max(arr) <= upper or np.isclose(
np.max(arr), upper
), "Range missmatch (upper)"
arr = (arr / scale) + shift
arr = np.around(arr)
arr = arr.astype(dtype)
Expand Down Expand Up @@ -414,7 +418,9 @@ def generate_output_ref_data(
else:
files = out_paths
temp = {}
assert len(inputs_data) <= len(files), f"Missing output data for provided inputs. (Expected: {len(inputs_data)}, Got: {len(files)})"
assert len(inputs_data) <= len(
files
), f"Missing output data for provided inputs. (Expected: {len(inputs_data)}, Got: {len(files)})"
for file in files:
if not isinstance(file, Path):
file = Path(file)
Expand Down Expand Up @@ -492,12 +498,13 @@ def generate_ref_labels(
fmt = ext[1:].lower()
if fmt == "csv":
import pandas as pd

labels_df = pd.read_csv(file, sep=",")
assert "i" in labels_df.columns
assert "label_idx" in labels_df.columns
assert len(inputs_data) <= len(labels_df)
labels_df.sort_values("i", inplace=True)
labels = list(labels_df["label_idx"].astype(int))[:len(inputs_data)]
labels = list(labels_df["label_idx"].astype(int))[: len(inputs_data)]
else:
raise NotImplementedError(f"Fmt not supported: {fmt}")
else:
Expand Down Expand Up @@ -914,7 +921,14 @@ def produce_artifacts(self, model):
# TODO: frontend parsed metadata instead of lookup.py?
# TODO: how to find inout_data?
class TfLiteFrontend(SimpleFrontend):
FEATURES = Frontend.FEATURES | {"visualize", "split_layers", "tflite_analyze", "gen_data", "gen_ref_data", "gen_ref_labels"}
FEATURES = Frontend.FEATURES | {
"visualize",
"split_layers",
"tflite_analyze",
"gen_data",
"gen_ref_data",
"gen_ref_labels",
}

DEFAULTS = {
**Frontend.DEFAULTS,
Expand Down
1 change: 0 additions & 1 deletion mlonmcu/platform/mlif/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ def get_process_outputs_stdout_raw():


class ModelSupport:

def __init__(self, in_interface, out_interface, model_info, target=None, batch_size=None, inputs_data=None):
self.model_info = model_info
self.target = target
Expand Down
18 changes: 13 additions & 5 deletions mlonmcu/session/postprocess/postprocesses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1559,6 +1559,7 @@ def post_run(self, report, artifacts):
quant = model_info_data.get("output_quant_details", None)
rng = model_info_data.get("output_ranges", None)
if quant:

def ref_quant_helper(quant, data): # TODO: move somewhere else
if quant is None:
return data
Expand All @@ -1576,9 +1577,8 @@ def ref_quant_helper(quant, data): # TODO: move somewhere else
assert lower <= upper
assert np.min(data) >= lower and np.max(data) <= upper, "Range missmatch"

return np.around((data / quant_scale) + quant_zero_point).astype(
"int8"
)
return np.around((data / quant_scale) + quant_zero_point).astype("int8")

def dequant_helper(quant, data): # TODO: move somewhere else
if quant is None:
return data
Expand All @@ -1597,6 +1597,7 @@ def dequant_helper(quant, data): # TODO: move somewhere else
assert lower <= upper
assert np.min(ret) >= lower and np.max(ret) <= upper, "Range missmatch"
return ret

assert ii < len(rng)
rng_ = rng[ii]
if rng_ and self.validate_range:
Expand Down Expand Up @@ -1668,7 +1669,9 @@ def post_run(self, report, artifacts):
if len(model_info_data["output_names"]) > 1:
raise NotImplementedError("Multi-outputs not yet supported.")
labels_ref_artifact = lookup_artifacts(artifacts, name="labels_ref.npy", first_only=True)
assert len(labels_ref_artifact) == 1, "Could not find artifact: labels_ref.npy (Run classify_labels postprocess first!)"
assert (
len(labels_ref_artifact) == 1
), "Could not find artifact: labels_ref.npy (Run classify_labels postprocess first!)"
labels_ref_artifact = labels_ref_artifact[0]
import numpy as np

Expand Down Expand Up @@ -1792,6 +1795,7 @@ def post_run(self, report, artifacts):
output = {output_names[idx]: out for idx, out in enumerate(output)}
quant = model_info_data.get("output_quant_details", None)
if quant and not self.skip_dequant:

def dequant_helper(quant, data):
if quant is None:
return data
Expand All @@ -1801,7 +1805,10 @@ def dequant_helper(quant, data):
assert data.dtype.name in ["int8"], "Dequantization only supported for int8 input"
assert quant_dtype in ["float32"], "Dequantization only supported for float32 output"
return (data.astype("float32") - quant_zero_point) * quant_scale
output = {out_name: dequant_helper(quant[j], output[out_name]) for j, out_name in enumerate(output.keys())}

output = {
out_name: dequant_helper(quant[j], output[out_name]) for j, out_name in enumerate(output.keys())
}
if self.fmt == "npy":
raise NotImplementedError("npy export")
elif self.fmt == "bin":
Expand All @@ -1827,6 +1834,7 @@ def dequant_helper(quant, data):
archive_path = f"{dest_}.{archive_fmt}"
if archive_fmt == "tar.gz":
import tarfile

with tarfile.open(archive_path, "w:gz") as tar:
for filename in filenames:
tar.add(filename, arcname=filename.name)
Expand Down
21 changes: 5 additions & 16 deletions mlonmcu/session/postprocess/validate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@


class ValidationMetric:

def __init__(self, name, **cfg):
self.name = name
self.num_total = 0
Expand All @@ -54,7 +53,6 @@ def get_summary(self):


class ClassifyMetric:

def __init__(self, name, **cfg):
self.name = name
self.num_total = 0
Expand All @@ -80,7 +78,6 @@ def get_summary(self):


class AllCloseMetric(ValidationMetric):

def __init__(self, name: str, atol: float = 0.0, rtol: float = 0.0):
super().__init__(name)
assert atol >= 0
Expand All @@ -96,7 +93,6 @@ def process_(self, out_data, out_data_ref, in_data: Optional[np.array] = None, q


class TopKMetric(ValidationMetric):

def __init__(self, name: str, n: int = 2):
super().__init__(name)
assert n >= 1
Expand Down Expand Up @@ -134,7 +130,6 @@ def process_(self, out_data, out_data_ref, in_data: Optional[np.array] = None, q


class TopKLabelsMetric(ClassifyMetric):

def __init__(self, name: str, n: int = 2):
super().__init__(name)
assert n >= 1
Expand All @@ -151,7 +146,7 @@ def process_(self, out_data, label_ref, quant: bool = False):
# print("label_ref", label_ref)
data_sorted_idx = list(reversed(np.argsort(out_data).tolist()[0]))
# print("data_sorted_idx", data_sorted_idx)
data_sorted_idx_trunc = data_sorted_idx[:self.n]
data_sorted_idx_trunc = data_sorted_idx[: self.n]
# print("data_sorted_idx_trunc", data_sorted_idx_trunc)
res = label_ref in data_sorted_idx_trunc
# print("res", res)
Expand All @@ -161,7 +156,6 @@ def process_(self, out_data, label_ref, quant: bool = False):


class ConfusionMatrixMetric(ValidationMetric):

def __init__(self, name: str):
super().__init__(name)
self.temp = {}
Expand Down Expand Up @@ -203,19 +197,16 @@ def get_summary(self):


class AccuracyMetric(TopKMetric):

def __init__(self, name: str):
super().__init__(name, n=1)


class AccuracyLabelsMetric(TopKLabelsMetric):

def __init__(self, name: str):
super().__init__(name, n=1)


class MSEMetric(ValidationMetric):

def __init__(self, name: str, thr: int = 0.5):
super().__init__(name)
assert thr >= 0
Expand All @@ -227,7 +218,6 @@ def process_(self, out_data, out_data_ref, in_data: Optional[np.array] = None, q


class ToyScoreMetric(ValidationMetric):

def __init__(self, name: str, atol: float = 0.1, rtol: float = 0.1):
super().__init__(name)
assert atol >= 0
Expand All @@ -248,10 +238,10 @@ def process_(self, out_data, out_data_ref, in_data: Optional[np.array] = None, q
ref_res = 0
length = len(out_data_flat)
for jjj in range(length):
res = (in_data_flat[jjj] - out_data_flat[jjj])
res += res ** 2
ref_res = (in_data_flat[jjj] - ref_out_data_flat[jjj])
ref_res += ref_res ** 2
res = in_data_flat[jjj] - out_data_flat[jjj]
res += res**2
ref_res = in_data_flat[jjj] - ref_out_data_flat[jjj]
ref_res += ref_res**2
res /= length
ref_res /= length
print("res", res)
Expand All @@ -260,7 +250,6 @@ def process_(self, out_data, out_data_ref, in_data: Optional[np.array] = None, q


class PlusMinusOneMetric(ValidationMetric):

def __init__(self, name: str):
super().__init__(name)

Expand Down
3 changes: 2 additions & 1 deletion mlonmcu/setup/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,8 @@ def execute(
else:
try:
p = subprocess.Popen(
[i for i in args], **kwargs, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT)
[i for i in args], **kwargs, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT
)
if stdin_data:
out_str = p.communicate(input=stdin_data)[0]
else:
Expand Down

0 comments on commit cbc5a15

Please sign in to comment.