Skip to content

Commit

Permalink
update validate_metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilippvK committed May 28, 2024
1 parent 65499ca commit cbff929
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
5 changes: 3 additions & 2 deletions mlonmcu/session/postprocess/postprocesses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1514,6 +1514,7 @@ def post_run(self, report, artifacts):
assert len(outputs_artifact) == 1, "Could not find artifact: outputs.npy"
outputs_artifact = outputs_artifact[0]
outputs = np.load(outputs_artifact.path, allow_pickle=True)
in_data = None
# compared = 0
# matching = 0
missing = 0
Expand Down Expand Up @@ -1611,7 +1612,7 @@ def dequant_helper(quant, data): # TODO: move somewhere else
if quant_ is not None:
out_ref_data_quant = ref_quant_helper(quant_, out_ref_data)
for vm in validate_metrics:
vm.process(out_data, out_ref_data_quant, quant=True)
vm.process(out_data, out_ref_data_quant, in_data=in_data, quant=True)
out_data = dequant_helper(quant_, out_data)
# print("out_data", out_data)
# print("sum(out_data)", np.sum(out_data))
Expand All @@ -1622,7 +1623,7 @@ def dequant_helper(quant, data): # TODO: move somewhere else
assert out_data.shape == out_ref_data.shape, "shape missmatch"

for vm in validate_metrics:
vm.process(out_data, out_ref_data, quant=False)
vm.process(out_data, out_ref_data, in_data=in_data, quant=False)
ii += 1
if self.report:
raise NotImplementedError
Expand Down
29 changes: 17 additions & 12 deletions mlonmcu/session/postprocess/validate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import ast
import numpy as np
from typing import Optional

from mlonmcu.logging import get_logger

Expand All @@ -33,13 +34,13 @@ def __init__(self, name, **cfg):
self.num_total = 0
self.num_correct = 0

def process_(self, out_data, out_data_ref, quant: bool = False):
def process_(self, out_data, out_data_ref, in_data: Optional[np.array] = None, quant: bool = False):
raise NotImplementedError

def check(self, out_data, out_data_ref, quant: bool = False):
return out_data.dtype == out_data_ref.dtype

def process(self, out_data, out_data_ref, quant: bool = False):
def process(self, out_data, out_data_ref, in_data: Optional[np.array] = None, quant: bool = False):
if not self.check(out_data, out_data_ref, quant=quant):
return
self.num_total += 1
Expand Down Expand Up @@ -90,7 +91,7 @@ def __init__(self, name: str, atol: float = 0.0, rtol: float = 0.0):
def check(self, out_data, out_data_ref, quant: bool = False):
return not quant

def process_(self, out_data, out_data_ref, quant: bool = False):
def process_(self, out_data, out_data_ref, in_data: Optional[np.array] = None, quant: bool = False):
return np.allclose(out_data, out_data_ref, rtol=self.rtol, atol=self.atol)


Expand All @@ -106,7 +107,7 @@ def check(self, out_data, out_data_ref, quant: bool = False):
# Probably no classification
return data_len < 25 and not quant

def process_(self, out_data, out_data_ref, quant: bool = False):
def process_(self, out_data, out_data_ref, in_data: Optional[np.array] = None, quant: bool = False):
# TODO: only for classification models!
# TODO: support multi_outputs?
data_sorted_idx = list(reversed(np.argsort(out_data).tolist()[0]))
Expand Down Expand Up @@ -220,7 +221,7 @@ def __init__(self, name: str, thr: int = 0.5):
assert thr >= 0
self.thr = thr

def process_(self, out_data, out_data_ref, quant: bool = False):
def process_(self, out_data, out_data_ref, in_data: Optional[np.array] = None, quant: bool = False):
mse = ((out_data - out_data_ref) ** 2).mean()
return mse < self.thr

Expand All @@ -238,15 +239,19 @@ def check(self, out_data, out_data_ref, quant: bool = False):
data_len = len(out_data.flatten().tolist())
return data_len == 640 and not quant

def process_(self, out_data, out_data_ref, quant: bool = False):
data_flat = out_data.flatten().tolist()
ref_data_flat = out_data_ref.flatten().tolist()
def process_(self, out_data, out_data_ref, in_data: Optional[np.array] = None, quant: bool = False):
assert in_data is not None
in_data_flat = in_data.flatten().tolist()
out_data_flat = out_data.flatten().tolist()
ref_out_data_flat = out_data_ref.flatten().tolist()
res = 0
ref_res = 0
length = len(data_flat)
length = len(out_data_flat)
for jjj in range(length):
res += data_flat[jjj] ** 2
ref_res += ref_data_flat[jjj] ** 2
res = (in_data_flat[jjj] - out_data_flat[jjj])
res += res ** 2
ref_res = (in_data_flat[jjj] - ref_out_data_flat[jjj])
ref_res += ref_res ** 2
res /= length
ref_res /= length
print("res", res)
Expand All @@ -262,7 +267,7 @@ def __init__(self, name: str):
def check(self, out_data, out_data_ref, quant: bool = False):
return "int" in out_data.dtype.str

def process_(self, out_data, out_data_ref, quant: bool = False):
def process_(self, out_data, out_data_ref, in_data: Optional[np.array] = None, quant: bool = False):
data_ = out_data.flatten().tolist()
ref_data_ = out_data_ref.flatten().tolist()

Expand Down

0 comments on commit cbff929

Please sign in to comment.