Commit 1156dd2: create Matcher

bw4sz committed Jan 28, 2025
1 parent 4985fc8

Showing 5 changed files with 94 additions and 45 deletions.
2 changes: 1 addition & 1 deletion conf/config.yaml
@@ -47,7 +47,7 @@ detection_model:
 train:
   fast_dev_run: False
   epochs: 10
-  lr: 0.0001
+  lr: 0.000001
   workers: 0
 validation:
   val_accuracy_interval: 3
2 changes: 1 addition & 1 deletion src/pipeline.py
@@ -199,7 +199,7 @@ def run(self):
                 uncertain_predictions=uncertain_predictions,
                 pipeline_monitor=pipeline_monitor)

-            reporter.generate_report(create_video=True)
+            reporter.generate_report(create_video=False)
         else:
             print("No images to annotate")

121 changes: 85 additions & 36 deletions src/pipeline_evaluation.py
@@ -1,10 +1,12 @@
 from src.label_studio import gather_data
 from torchmetrics.detection import MeanAveragePrecision
 from torchmetrics.classification import Accuracy
 from torchmetrics.functional import confusion_matrix
-from src.detection import predict
+from src.detection import predict, fix_taxonomy
 import pandas as pd
 import torch
+from torchvision.ops.boxes import box_iou
+from torchvision.models.detection._utils import Matcher

 import os

 class PipelineEvaluation:
@@ -39,12 +41,25 @@ def __init__(self, model, crop_model, image_dir, detect_ground_truth_dir, classi

         # Gather data
         self.detection_annotations = gather_data(detect_ground_truth_dir)
-        self.detection_annotations = self.detection_annotations[self.detection_annotations.label.isin(["Bird","Cetacean","Turtle"])]
+        self.detection_annotations = fix_taxonomy(self.detection_annotations)
+
+        self.detection_annotations = self.detection_annotations[self.detection_annotations.label.isin(self.model.label_dict.keys())]

         self.classification_annotations = gather_data(classify_ground_truth_dir)

         # One caveat for empty frames: assign a label that the dict contains
         self.classification_annotations.loc[self.classification_annotations.label.astype(str) == '0', "label"] = self.model.numeric_to_label_dict[0]

+        # No need to evaluate if there are no annotations for classification
+        self.classification_annotations = self.classification_annotations.loc[
+            ~(
+                (self.classification_annotations.xmin == 0) &
+                (self.classification_annotations.ymin == 0) &
+                (self.classification_annotations.xmax == 0) &
+                (self.classification_annotations.ymax == 0)
+            )
+        ]

         # Prediction container
         self.predictions = []

@@ -54,6 +69,10 @@ def __init__(self, model, crop_model, image_dir, detect_ground_truth_dir, classi
         if self.num_classes == 1:
             self.num_classes = 2

+        # Metrics
+        self.confident_classification_accuracy = Accuracy(average="micro", task="multiclass", num_classes=self.num_classes)
+        self.uncertain_classification_accuracy = Accuracy(average="micro", task="multiclass", num_classes=self.num_classes)

     def _format_targets(self, annotations_df):
         targets = {}

@@ -128,7 +147,7 @@ def predict_classification(self):
             image_paths=full_image_paths,
             patch_size=self.patch_size,
             patch_overlap=self.patch_overlap,
-            batch_size=32
+            batch_size=16
         )
         combined_predictions = pd.concat(predictions)
         self.predictions.append(combined_predictions)
@@ -145,9 +164,51 @@ def predict_classification(self):

         return confident_predictions, uncertain_predictions

+    def match_predictions_and_targets(self, pred, target):
+        """Match predicted bounding boxes to ground-truth boxes using Intersection over Union (IoU).
+
+        Args:
+            pred (dict): Predictions for one image, with "boxes" and "labels" tensors.
+            target (dict): Ground-truth annotations for the same image, with "boxes" and "labels" tensors.
+
+        Returns:
+            DataFrame: A dataframe of matched prediction and target labels.
+        """
+        # Match predictions and targets
+        matcher = Matcher(
+            high_threshold=0.3,
+            low_threshold=0.3,
+            allow_low_quality_matches=False)
+
+        pred_boxes = pred["boxes"]
+        src_boxes = target["boxes"]
+
+        # Rows index ground-truth boxes, columns index predictions
+        match_quality_matrix = box_iou(
+            src_boxes,
+            pred_boxes)
+
+        # One entry per prediction: the matched ground-truth index, or a negative sentinel
+        results = matcher(match_quality_matrix)
+
+        matched_pred = []
+        matched_target = []
+
+        for i, match in enumerate(results):
+            if match >= 0:
+                matched_pred.append(int(pred["labels"][i].item()))
+                matched_target.append(int(target["labels"][match].item()))
+            else:
+                matched_pred.append(int(pred["labels"][i].item()))
+                matched_target.append(None)
+
+        matches = pd.DataFrame({"pred": matched_pred, "target": matched_target})
+
+        # Drop unmatched predictions; accuracy needs a ground-truth label for every box
+        matches = matches.dropna(subset=["target"])
+
+        return matches

     def evaluate_confident_classification(self):
         """Evaluate confident classification performance"""

         targets = []
         preds = []
         for image_path in self.confident_predictions.drop_duplicates("image_path").image_path.tolist():
@@ -159,28 +220,20 @@ def evaluate_confident_classification(self):
             pred = self._format_targets(image_predictions)
             if len(pred["labels"]) == 0:
                 continue
-            targets.append(target)
-            preds.append(pred)
-
-        if len(preds) == 0:
-            return {"confident_classification_accuracy": None}
-        else:
-            # Classification is just the labels dict
-            target_labels = torch.stack([x["labels"] for x in targets])
-            pred_labels = torch.stack([x["labels"] for x in preds])
-
-            self.confident_classification_accuracy = Accuracy(average="micro", task="multiclass", num_classes=self.num_classes)
-
-            self.confident_classification_accuracy.update(preds=pred_labels, target=target_labels)
-            results = {"confident_classification_accuracy": self.confident_classification_accuracy.compute()}
-
+            matches = self.match_predictions_and_targets(pred, target)
+            if len(matches) == 0:
+                continue
+            else:
+                self.confident_classification_accuracy.update(preds=torch.tensor(matches["pred"].values), target=torch.tensor(matches["target"].values))
+
+        results = {"confident_classification_accuracy": self.confident_classification_accuracy.compute()}
+
         return results

     def evaluate_uncertain_classification(self):
         """Evaluate uncertain classification performance"""

-        self.uncertain_classification_accuracy = Accuracy(average="micro", task="multiclass", num_classes=self.num_classes)
-
         targets = []
         preds = []
         for image_path in self.uncertain_predictions.drop_duplicates("image_path").image_path.tolist():
@@ -189,23 +242,19 @@ def evaluate_uncertain_classification(self):
             image_predictions = image_predictions[image_predictions.score > self.min_score]
             if image_predictions.empty:
                 continue
-            target = self._format_targets(image_targets)
-            pred = self._format_targets(image_predictions)
-            targets.append(target)
-            preds.append(pred)
+            targets = self._format_targets(image_targets)
+            preds = self._format_targets(image_predictions)

-        if len(preds) == 0:
-            return {"uncertain_classification_accuracy": None}
-        else:
-            # Classification is just the labels dict
-            target_labels = torch.stack([x["labels"] for x in targets])
-            pred_labels = torch.stack([x["labels"] for x in preds])
-
-            self.uncertain_classification_accuracy.update(preds=pred_labels, target=target_labels)
-            results = {"uncertain_classification_accuracy": self.uncertain_classification_accuracy.compute()}
+            matches = self.match_predictions_and_targets(preds, targets)

-            return results
+            if len(matches) == 0:
+                continue
+            else:
+                self.uncertain_classification_accuracy.update(preds=torch.tensor(matches["pred"].values), target=torch.tensor(matches["target"].values))
+
+        results = {"uncertain_classification_accuracy": self.uncertain_classification_accuracy.compute()}
+
+        return results

     def evaluate(self):
         """
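
For context on the new matching step (a standalone sketch, not part of this commit): torchvision's Matcher consumes an IoU matrix whose rows index ground-truth boxes and whose columns index predictions, and returns one entry per prediction, either the index of its matched ground-truth box or a negative sentinel when no box clears the threshold. A minimal example with invented coordinates:

import torch
from torchvision.ops.boxes import box_iou
from torchvision.models.detection._utils import Matcher

# Ground-truth and predicted boxes in (xmin, ymin, xmax, ymax) format
gt_boxes = torch.tensor([[0., 0., 100., 100.], [200., 200., 300., 300.]])
pred_boxes = torch.tensor([[10., 10., 110., 110.], [400., 400., 500., 500.]])

# Rows index ground truth, columns index predictions
iou = box_iou(gt_boxes, pred_boxes)

# With high and low thresholds both at 0.3, each prediction either matches
# the ground-truth row with the highest IoU (>= 0.3) or gets -1
matcher = Matcher(high_threshold=0.3, low_threshold=0.3, allow_low_quality_matches=False)
print(matcher(iou))  # tensor([ 0, -1]): pred 0 matches gt 0, pred 1 is unmatched

The matched indices are what match_predictions_and_targets uses to pair labels, which in turn feed the torchmetrics Accuracy updates in the evaluate_* methods.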
2 changes: 1 addition & 1 deletion src/reporting.py
@@ -82,7 +82,7 @@ def predict_video_images(self, images):
             m=self.model,
             crop_model=self.classification_model,
             patch_overlap=self.patch_overlap,
-            patch_size=self.patch_overlap,
+            patch_size=self.patch_size,
         )

         predictions = predictions[predictions.score > self.min_score]
12 changes: 6 additions & 6 deletions tests/conftest.py
@@ -42,18 +42,18 @@ def config(tmpdir_factory):
         'image_path': ['empty.jpg', 'birds.jpg', "birds.jpg"],
         'xmin': [20, 200, 150],
         'ymin': [10, 300, 250],
-        'xmax': [40, 300, 250],
-        'ymax': [20, 400, 350],
+        'xmax': [40, 250, 200],
+        'ymax': [20, 350, 300],
         'label': ['FalsePositive', 'Bird', 'Bird2'],
         'annotator': ['test_user', 'test_user', 'test_user']
     }

     val_data = {
         'image_path': ['empty.jpg','birds_val.jpg', 'birds_val.jpg'],
-        'xmin': [None,150, 150],
-        'ymin': [None,250, 250],
-        'xmax': [None,250, 250],
-        'ymax': [None,350, 350],
+        'xmin': [None,200, 150],
+        'ymin': [None,300, 250],
+        'xmax': [None,250, 200],
+        'ymax': [None,350, 300],
         'label': ['Bird','Bird', 'Bird2'],
         'annotator': ['test_user','test_user', 'test_user'],
     }
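
One plausible motivation for the new fixture coordinates (an assumption; the commit message does not say): the old train and val 'Bird' boxes overlap too little to clear the 0.3 IoU threshold used by the new Matcher, while the updated coordinates coincide exactly. A quick check with torchvision:

import torch
from torchvision.ops.boxes import box_iou

old_train_bird = torch.tensor([[200., 300., 300., 400.]])
old_val_bird = torch.tensor([[150., 250., 250., 350.]])
new_shared_bird = torch.tensor([[200., 300., 250., 350.]])

print(box_iou(old_train_bird, old_val_bird))      # ~0.14, below a 0.3 threshold
print(box_iou(new_shared_bird, new_shared_bird))  # 1.0, identical boxes always match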
