Skip to content

Commit

Permalink
Merge pull request #19 from SFI-Visual-Intelligence/christian/simpler-cli
Browse files Browse the repository at this point in the history

Simplify how metrics are parsed
  • Loading branch information
salomaestro authored Jan 30, 2025
2 parents 85f7edf + cba9b80 commit 8693d41
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 34 deletions.
18 changes: 4 additions & 14 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,8 @@ def main():
parser.add_argument('--dataset', type=str, default='svhn',
choices=['svhn'], help='Which dataset to train the model on.')

parser.add_argument('--EntropyPrediction', type=bool, default=True, help='Include the Entropy Prediction metric in evaluation')
parser.add_argument('--F1Score', type=bool, default=True, help='Include the F1Score metric in evaluation')
parser.add_argument('--Recall', type=bool, default=True, help='Include the Recall metric in evaluation')
parser.add_argument('--Precision', type=bool, default=True, help='Include the Precision metric in evaluation')
parser.add_argument('--Accuracy', type=bool, default=True, help='Include the Accuracy metric in evaluation')

parser.add_argument("--metric", type=str, default="entropy", choices=['entropy', 'f1', 'recall', 'precision', 'accuracy'], nargs="+", help='Which metric to use for evaluation')

#Training specific values
parser.add_argument('--epoch', type=int, default=20, help='Amount of training epochs the model will do.')
parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate parameter for model training.')
Expand All @@ -61,13 +57,7 @@ def main():
model = load_model()
model.to(device)

metrics = MetricWrapper(
EntropyPred = args.EntropyPrediction,
F1Score = args.F1Score,
Recall = args.Recall,
Precision = args.Precision,
Accuracy = args.Accuracy
)
metrics = MetricWrapper(*args.metric)

#Dataset
traindata = load_data(args.dataset)
Expand Down Expand Up @@ -126,4 +116,4 @@ def main():


if __name__ == '__main__':
main()
main()
51 changes: 31 additions & 20 deletions utils/load_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,45 @@


class MetricWrapper(nn.Module):
def __init__(self,
EntropyPred:bool = True,
F1Score:bool = True,
Recall:bool = True,
Precision:bool = True,
Accuracy:bool = True):
def __init__(self, *metrics):
super().__init__()
self.metrics = {}

if EntropyPred:
self.metrics['Entropy of Predictions'] = EntropyPrediction()

if F1Score:
self.metrics['F1 Score'] = None

if Recall:
self.metrics['Recall'] = None
for metric in metrics:
self.metrics[metric] = self._get_metric(metric)

if Precision:
self.metrics['Precision'] = None

if Accuracy:
self.metrics['Accuracy'] = None

self.tmp_scores = copy.deepcopy(self.metrics)
for key in self.tmp_scores:
self.tmp_scores[key] = []


def _get_metric(self, key):
    """
    Resolve a metric name to its metric object.

    Args
    ----
    key (str): metric name; matched case-insensitively

    Returns
    -------
    metric (callable): metric function

    Raises
    ------
    NotImplementedError: the metric is recognised but has no implementation yet
    ValueError: the metric name is not recognised
    """
    name = key.lower()

    if name == 'entropy':
        return EntropyPrediction()

    # Recognised metrics that are not implemented yet, mapped to the
    # label used in their error message.
    pending = {
        'f1': 'F1',
        'recall': 'Recall',
        'precision': 'Precision',
        'accuracy': 'Accuracy',
    }
    if name in pending:
        raise NotImplementedError(f"{pending[name]} score not implemented yet")

    raise ValueError(f"Metric {key} not supported")

def __call__(self, y_true, y_pred):
for key in self.metrics:
self.tmp_scores[key].append(self.metrics[key](y_true, y_pred))
Expand Down

0 comments on commit 8693d41

Please sign in to comment.