Commit
introduce classify metrics
PhilippvK committed May 27, 2024
1 parent 916f38a commit 79baeda
Showing 2 changed files with 118 additions and 15 deletions.
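In brief, the commit adds a ClassifyMetric base class and three label-based metrics (TopKLabelsMetric, AccuracyLabelsMetric, ConfusionMatrixMetric) to validate_metrics.py, registers them in a new LABELS_LOOKUP table, and introduces parse_classify_metrics(), which postprocesses.py now imports next to parse_validate_metrics().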
2 changes: 1 addition & 1 deletion mlonmcu/session/postprocess/postprocesses.py
@@ -33,7 +33,7 @@
from mlonmcu.logging import get_logger

from .postprocess import SessionPostprocess, RunPostprocess
from .validate_metrics import parse_validate_metrics
from .validate_metrics import parse_validate_metrics, parse_classify_metrics

logger = get_logger()

131 changes: 117 additions & 14 deletions mlonmcu/session/postprocess/validate_metrics.py
@@ -52,6 +52,32 @@ def get_summary(self):
return f"{self.num_correct}/{self.num_total} ({int(self.num_correct/self.num_total*100)}%)"


class ClassifyMetric:

def __init__(self, name, **cfg):
self.name = name
self.num_total = 0
self.num_correct = 0

def process_(self, out_data, label_ref, quant: bool = False):
raise NotImplementedError

def check(self, out_data, label_ref, quant: bool = False):
# Base implementation accepts every sample; subclasses narrow this down to classification outputs
return True

def process(self, out_data, label_ref, quant: bool = False):
if not self.check(out_data, label_ref, quant=quant):
return
self.num_total += 1
if self.process_(out_data, label_ref):
self.num_correct += 1

def get_summary(self):
if self.num_total == 0:
return "N/A"
return f"{self.num_correct}/{self.num_total} ({int(self.num_correct/self.num_total*100)}%)"


class AllCloseMetric(ValidationMetric):

def __init__(self, name: str, atol: float = 0.0, rtol: float = 0.0):
@@ -88,28 +114,16 @@ def process_(self, out_data, out_data_ref, quant: bool = False):
k = 0
num_checks = min(self.n, len(data_sorted_idx))
assert len(data_sorted_idx) == len(ref_data_sorted_idx)
# print("data_sorted_idx", data_sorted_idx, type(data_sorted_idx))
# print("ref_data_sorted_idx", ref_data_sorted_idx, type(ref_data_sorted_idx))
# print("num_checks", num_checks)
for j in range(num_checks):
# print("j", j)
# print(f"data_sorted_idx[{j}]", data_sorted_idx[j], type(data_sorted_idx[j]))
idx = data_sorted_idx[j]
# print("idx", idx)
ref_idx = ref_data_sorted_idx[j]
# print("ref_idx", ref_idx)
if idx == ref_idx:
# print("IF")
k += 1
else:
# print("ELSE")
if out_data.tolist()[0][idx] == out_data_ref.tolist()[0][ref_idx]:
# print("SAME")
k += 1
else:
# print("BREAK")
break
# print("k", k)
if k < num_checks:
return False
elif k == num_checks:
@@ -118,12 +132,87 @@ def process_(self, out_data, out_data_ref, quant: bool = False):
assert False


class TopKLabelsMetric(ClassifyMetric):

def __init__(self, name: str, n: int = 2):
super().__init__(name)
assert n >= 1
self.n = n

def check(self, out_data, label_ref, quant: bool = False):
data_len = len(out_data.flatten().tolist())
# Heuristic: tensors with 25 or more elements are probably not classification scores
return data_len < 25

def process_(self, out_data, label_ref, quant: bool = False):
# print("process_")
# print("out_data", out_data)
# print("label_ref", label_ref)
data_sorted_idx = list(reversed(np.argsort(out_data).tolist()[0]))
# print("data_sorted_idx", data_sorted_idx)
data_sorted_idx_trunc = data_sorted_idx[:self.n]
# print("data_sorted_idx_trunc", data_sorted_idx_trunc)
res = label_ref in data_sorted_idx_trunc
# print("res", res)
# TODO: handle same values?
# input("111")
return res
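
A minimal sketch of what this top-k label check computes, using an invented 1x4 score vector (the values and variable names below are illustrative only, not part of the commit):

import numpy as np

out_data = np.array([[0.10, 0.05, 0.60, 0.25]])  # hypothetical scores for 4 classes, batch size 1
label_ref = 3                                    # reference class index

# Class indices sorted by descending score: [2, 3, 0, 1]
data_sorted_idx = list(reversed(np.argsort(out_data).tolist()[0]))

# Top-2 check: the reference label must appear among the two highest-scoring classes
print(label_ref in data_sorted_idx[:2])  # True, since class 3 has the second-highest score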


class ConfusionMatrixMetric(ValidationMetric):

def __init__(self, name: str):
super().__init__(name)
self.temp = {}
self.num_correct_per_class = {}

def check(self, out_data, label_ref, quant: bool = False):
data_len = len(out_data.flatten().tolist())
# Heuristic: tensors with 25 or more elements are probably not classification scores; quantized outputs are skipped
return data_len < 25 and not quant

def process_(self, out_data, label_ref, quant: bool = False):
data_sorted_idx = list(reversed(np.argsort(out_data).tolist()[0]))
label = data_sorted_idx[0]
correct = label_ref == label
# TODO: handle same values?
return correct, label

def process(self, out_data, label_ref, quant: bool = False):
print("ConfusionMatrixMetric.process")
if not self.check(out_data, label_ref, quant=quant):
return
self.num_total += 1
correct, label = self.process_(out_data, label_ref)
if correct:
self.num_correct += 1
if label_ref not in self.num_correct_per_class:
self.num_correct_per_class[label_ref] = 0
self.num_correct_per_class[label_ref] += 1
temp_ = self.temp.get(label_ref, {})
if label not in temp_:
temp_[label] = 0
temp_[label] += 1
self.temp[label_ref] = temp_

def get_summary(self):
if self.num_total == 0:
return "N/A"
return f"{self.temp}"
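
For orientation, self.temp accumulates confusion-matrix rows: each reference label maps to a dict counting how often each predicted label was returned for it. An invented example of what get_summary() would render (labels and counts hypothetical):

confusion = {
    0: {0: 2, 2: 1},  # reference label 0: predicted as 0 twice, as 2 once
    1: {1: 1},        # reference label 1: predicted as 1 once
}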


class AccuracyMetric(TopKMetric):

def __init__(self, name: str):
super().__init__(name, n=1)


class AccuracyLabelsMetric(TopKLabelsMetric):

def __init__(self, name: str):
super().__init__(name, n=1)


class MSEMetric(ValidationMetric):

def __init__(self, name: str, thr: int = 0.5):
@@ -197,6 +286,12 @@ def process_(self, out_data, out_data_ref, quant: bool = False):
"pm1": PlusMinusOneMetric,
}

LABELS_LOOKUP = {
"topk_label": TopKLabelsMetric,
"acc_label": AccuracyLabelsMetric,
"confusion_matrix": ConfusionMatrixMetric,
}


def parse_validate_metric_args(inp):
ret = {}
@@ -212,7 +307,7 @@ def parse_validate_metric_args(inp):
return ret


def parse_validate_metric(inp):
def parse_validate_metric(inp, lookup=LOOKUP):
if "(" in inp:
metric_name, inp_ = inp.split("(")
assert inp_[-1] == ")"
@@ -221,7 +316,7 @@ def parse_validate_metric(inp):
else:
metric_name = inp
metric_args = {}
metric_cls = LOOKUP.get(metric_name, None)
metric_cls = lookup.get(metric_name, None)
assert metric_cls is not None, f"Validate metric not found: {metric_name}"
metric = metric_cls(inp, **metric_args)
return metric
@@ -233,3 +328,11 @@ def parse_validate_metrics(inp):
metric = parse_validate_metric(metric_str)
ret.append(metric)
return ret


def parse_classify_metrics(inp):
ret = []
for metric_str in inp.split(";"):
metric = parse_validate_metric(metric_str, lookup=LABELS_LOOKUP)
ret.append(metric)
return ret
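
A short usage sketch of the two parser entry points, assuming semicolon-separated metric strings as parsed above; the metric names come from the LOOKUP/LABELS_LOOKUP tables, and argument syntax such as topk_label(n=3) is handled by parse_validate_metric_args:

# Classify metrics compare the predicted class index against an integer reference label
cls_metrics = parse_classify_metrics("acc_label;confusion_matrix")
for metric in cls_metrics:
    print(metric.name, type(metric).__name__)
# -> acc_label AccuracyLabelsMetric
#    confusion_matrix ConfusionMatrixMetric

# Validate metrics keep comparing raw output data against a reference output, e.g.:
val_metrics = parse_validate_metrics("pm1")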
