move validate metrics to separate class & file
PhilippvK committed Apr 11, 2024
1 parent 4345db5 commit 1a8d9eb
Showing 2 changed files with 38 additions and 177 deletions.
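
Note: the new mlonmcu/session/postprocess/validate_metrics.py module imported by this commit is not part of the diff below. Based only on how it is used in postprocesses.py (parse_validate_metrics(), vm.name, vm.process(out_data, out_ref_data, quant=...), vm.get_summary()), a minimal sketch of the assumed interface could look as follows; the class name TopKMetric and all internals are illustrative assumptions, not the actual implementation:

# Minimal sketch of the interface assumed by the updated postprocess below.
# Only parse_validate_metrics, .name, .process() and .get_summary() appear in
# the diff; everything else is an illustrative assumption.
import numpy as np


class TopKMetric:
    """Counts samples whose top-k predicted indices match the reference."""

    def __init__(self, n=1):
        self.name = f"topk(n={n})"
        self.n = n
        self.num_total = 0
        self.num_correct = 0

    def process(self, out_data, out_ref_data, quant=False):
        # How the real implementation uses the quant flag is not visible in this diff.
        del quant
        self.num_total += 1
        k = min(self.n, out_data.size)
        idx = np.argsort(out_data.flatten())[::-1][:k]
        ref_idx = np.argsort(out_ref_data.flatten())[::-1][:k]
        if np.array_equal(idx, ref_idx):
            self.num_correct += 1

    def get_summary(self):
        if self.num_total == 0:
            return "N/A"
        return f"{self.num_correct}/{self.num_total} ({int(self.num_correct / self.num_total * 100)}%)"


def parse_validate_metrics(validate_metrics_str):
    """Turn a config string such as 'topk(n=1);topk(n=2)' into metric objects."""
    ret = []
    for metric_str in validate_metrics_str.split(";"):
        if metric_str.startswith("topk"):
            n = int(metric_str.split("n=")[1].rstrip(")")) if "n=" in metric_str else 1
            ret.append(TopKMetric(n=n))
        else:
            raise NotImplementedError(f"Unsupported metric in this sketch: {metric_str}")
    return ret

The summary format mirrors the inline reporting the removed code produced: "matching/compared (percent%)", or "N/A" for a metric that was never applicable.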
1 change: 1 addition & 0 deletions mlonmcu/models/frontend.py
@@ -659,6 +659,7 @@ def process_metadata(self, model, cfg=None):
f"outputs_ref.{fmt}", raw=raw, fmt=ArtifactFormat.BIN, flags=("outputs_ref", fmt)
)
artifacts.append(outputs_data_artifact)
return artifacts

def generate(self, model) -> Tuple[dict, dict]:
artifacts = []
214 changes: 37 additions & 177 deletions mlonmcu/session/postprocess/postprocesses.py
@@ -31,6 +31,7 @@
from mlonmcu.logging import get_logger

from .postprocess import SessionPostprocess, RunPostprocess
from .validate_metrics import parse_validate_metrics

logger = get_logger()

@@ -1460,22 +1461,20 @@ def post_run(self, report, artifacts):
class ValidateOutputsPostprocess(RunPostprocess):
"""Postprocess for comparing model outputs with golden reference."""

DEFAULTS = {**RunPostprocess.DEFAULTS, "atol": 0.0, "rtol": 0.0, "report": False}
DEFAULTS = {
**RunPostprocess.DEFAULTS,
"report": False,
"validate_metrics": "topk(n=1);topk(n=2)",
}

def __init__(self, features=None, config=None):
super().__init__("validate_outputs", features=features, config=config)

@property
def atol(self):
"""Get atol property."""
value = self.config["atol"]
return float(value)

@property
def rtol(self):
"""Get rtol property."""
value = self.config["rtol"]
return float(value)
def validate_metrics(self):
"""Get validate_metrics property."""
value = self.config["validate_metrics"]
return value

@property
def report(self):
@@ -1507,22 +1506,24 @@ def post_run(self, report, artifacts):
assert len(outputs_artifact) == 1, "Could not find artifact: outputs.npy"
outputs_artifact = outputs_artifact[0]
outputs = np.load(outputs_artifact.path, allow_pickle=True)
compared = 0
# compared = 0
# matching = 0
missing = 0
metrics = {
"allclose(atol=0.0,rtol=0.0)": None,
"allclose(atol=0.05,rtol=0.05)": None,
"allclose(atol=0.1,rtol=0.1)": None,
"topk(n=1)": None,
"topk(n=2)": None,
"topk(n=inf)": None,
"toy": None,
"mse(thr=0.1)": None,
"mse(thr=0.05)": None,
"mse(thr=0.01)": None,
"+-1": None,
}
# metrics = {
# "allclose(atol=0.0,rtol=0.0)": None,
# "allclose(atol=0.05,rtol=0.05)": None,
# "allclose(atol=0.1,rtol=0.1)": None,
# "topk(n=1)": None,
# "topk(n=2)": None,
# "topk(n=inf)": None,
# "toy": None,
# "mse(thr=0.1)": None,
# "mse(thr=0.05)": None,
# "mse(thr=0.01)": None,
# "+-1": None,
# }
validate_metrics_str = self.validate_metrics
validate_metrics = parse_validate_metrics(validate_metrics_str)
for i, output_ref in enumerate(outputs_ref):
if i >= len(outputs):
logger.warning("Missing output sample")
@@ -1546,19 +1547,6 @@ def post_run(self, report, artifacts):
# print("out_data_before_quant", out_data)
# print("sum(out_data_before_quant", np.sum(out_data))

def pm1_helper(data, ref_data):
data_ = data.flatten().tolist()
ref_data_ = ref_data.flatten().tolist()

length = len(data_)
for jjj in range(length):
diff = abs(data_[jjj] - ref_data_[jjj])
print("diff", diff)
if diff > 1:
print("r FALSE")
return False
return True

quant = model_info_data.get("output_quant_details", None)
if quant:
assert ii < len(quant)
@@ -1570,15 +1558,11 @@ def pm1_helper(data, ref_data):
# need to dequantize here
assert out_data.dtype.name in ["int8"], "Dequantization only supported for int8 input"
assert quant_dtype in ["float32"], "Dequantization only supported for float32 output"
for metric_name in metrics:
if metric_name == "+-1":
if metrics[metric_name] is None:
metrics[metric_name] = 0
out_ref_data_quant = np.around(
(out_ref_data / quant_scale) + quant_zero_point
).astype("int8")
if pm1_helper(out_data, out_ref_data_quant):
metrics[metric_name] += 1
out_ref_data_quant = np.around(
(out_ref_data / quant_scale) + quant_zero_point
).astype("int8")
for vm in validate_metrics:
vm.process(out_data, out_ref_data_quant, quant=True)
out_data = (out_data.astype("float32") - quant_zero_point) * quant_scale
# print("out_data", out_data)
# print("sum(out_data)", np.sum(out_data))
@@ -1587,137 +1571,13 @@ def pm1_helper(data, ref_data):
# input("TIAW")
assert out_data.dtype == out_ref_data.dtype, "dtype missmatch"
assert out_data.shape == out_ref_data.shape, "shape missmatch"
# if np.allclose(out_data, out_ref_data, rtol=0, atol=0):
# if np.allclose(out_data, out_ref_data, rtol=0.1, atol=0.1):
# if np.allclose(out_data, out_ref_data, rtol=0.0, atol=0.0):
# if np.allclose(out_data, out_ref_data, rtol=0.01, atol=0.01):

def mse_helper(data, ref_data, thr):
mse = ((data - ref_data) ** 2).mean()
print("mse", mse)
return mse < thr

def toy_helper(data, ref_data, atol, rtol):
data_flat = data.flatten().tolist()
ref_data_flat = ref_data.flatten().tolist()
res = 0
ref_res = 0
length = len(data_flat)
for jjj in range(length):
res += data_flat[jjj] ** 2
ref_res += ref_data_flat[jjj] ** 2
res /= length
ref_res /= length
print("res", res)
print("ref_res", ref_res)
return np.allclose([res], [ref_res], atol=atol, rtol=rtol)

def topk_helper(data, ref_data, n):
# TODO: only for classification models!
# TODO: support multi_outputs?
data_sorted_idx = list(reversed(np.argsort(data).tolist()[0]))
ref_data_sorted_idx = list(reversed(np.argsort(ref_data).tolist()[0]))
k = 0
num_checks = min(n, len(data_sorted_idx))
assert len(data_sorted_idx) == len(ref_data_sorted_idx)
# print("data_sorted_idx", data_sorted_idx, type(data_sorted_idx))
# print("ref_data_sorted_idx", ref_data_sorted_idx, type(ref_data_sorted_idx))
# print("num_checks", num_checks)
for j in range(num_checks):
# print("j", j)
# print(f"data_sorted_idx[{j}]", data_sorted_idx[j], type(data_sorted_idx[j]))
idx = data_sorted_idx[j]
# print("idx", idx)
ref_idx = ref_data_sorted_idx[j]
# print("ref_idx", ref_idx)
if idx == ref_idx:
# print("IF")
k += 1
else:
# print("ELSE")
if data.tolist()[0][idx] == ref_data.tolist()[0][ref_idx]:
# print("SAME")
k += 1
else:
# print("BREAK")
break
# print("k", k)
if k < num_checks:
return False
elif k == num_checks:
return True
else:
assert False

for metric_name in metrics:
if "allclose" in metric_name:
if metrics[metric_name] is None:
metrics[metric_name] = 0
if metric_name == "allclose(atol=0.0,rtol=0.0)":
atol = 0.0
rtol = 0.0
elif metric_name == "allclose(atol=0.05,rtol=0.05)":
atol = 0.05
rtol = 0.05
elif metric_name == "allclose(atol=0.1,rtol=0.1)":
atol = 0.1
rtol = 0.1
else:
raise NotImplementedError
if np.allclose(out_data, out_ref_data, rtol=rtol, atol=atol):
metrics[metric_name] += 1
elif "topk" in metric_name:
data_len = len(out_data.flatten().tolist())
if data_len > 25: # Probably no classification
continue
if metrics[metric_name] is None:
metrics[metric_name] = 0
if metric_name == "topk(n=1)":
n = 1
elif metric_name == "topk(n=2)":
n = 2
elif metric_name == "topk(n=inf)":
n = 1000000
else:
raise NotImplementedError
if topk_helper(out_data, out_ref_data, n):
metrics[metric_name] += 1
elif metric_name == "toy":
data_len = len(out_data.flatten().tolist())
if data_len != 640:
continue
if metrics[metric_name] is None:
metrics[metric_name] = 0
if toy_helper(out_data, out_ref_data, 0.01, 0.01):
metrics[metric_name] += 1
elif metric_name == "mse":
if metrics[metric_name] is None:
metrics[metric_name] = 0
if metric_name == "mse(thr=0.1)":
thr = 0.1
elif metric_name == "mse(thr=0.05)":
thr = 0.05
elif metric_name == "mse(thr=0.01)":
thr = 0.01
else:
raise NotImplementedError
if mse_helper(out_data, out_ref_data, thr):
metrics[metric_name] += 1
elif metric_name == "+-1":
continue
compared += 1

for vm in validate_metrics:
vm.process(out_data, out_ref_data_quant, quant=False)
ii += 1
if self.report:
raise NotImplementedError
if self.atol:
raise NotImplementedError
if self.rtol:
raise NotImplementedError
for metric_name, metric_data in metrics.items():
if metric_data is None:
res = "N/A"
else:
matching = metric_data
res = f"{matching}/{compared} ({int(matching/compared*100)}%)"
report.post_df[f"{metric_name}"] = res
for vm in validate_metrics:
res = vm.get_summary()
report.post_df[f"{vm.name}"] = res
return []
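
For reference, the quantization round trip that remains inside post_run above comes down to the following arithmetic. quant_scale, quant_zero_point, out_data and out_ref_data are the names used in the diff; the numeric values here are made up for illustration:

# Worked example of the quantize/dequantize arithmetic kept in post_run above.
# The scale and zero point are illustrative values, not taken from a real model.
import numpy as np

quant_scale, quant_zero_point = 0.05, -4
out_ref_data = np.array([[0.1, 0.4, -0.2]], dtype="float32")  # float32 golden reference
out_data = np.array([[-2, 4, -8]], dtype="int8")  # raw int8 model output

# Reference quantized like the model output, used for the quant=True pass:
out_ref_data_quant = np.around((out_ref_data / quant_scale) + quant_zero_point).astype("int8")
# -> [[-2  4 -8]]

# Model output dequantized back to float32, used for the quant=False pass:
out_data = (out_data.astype("float32") - quant_zero_point) * quant_scale
# -> [[ 0.1  0.4 -0.2]]

Both vm.process() calls above (once on the quantized data with quant=True, once after out_data has been dequantized with quant=False) feed the same parsed metric objects, and each object's get_summary() ends up as a column in report.post_df under the metric's name.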
