Commit e21a404
add test data

bw4sz committed Nov 13, 2024
1 parent 5c5a912
Showing 4 changed files with 20 additions and 15 deletions.
conf/config.yaml (12 changes: 5 additions & 7 deletions)
@@ -35,13 +35,11 @@ train:
 pipeline_evaluation:
   detection_annotations_dir:
   classification_annotations_dir:
-  detection:
-    true_positive_threshold: 0.8
-    false_positive_threshold: 0.5
-  classification:
-    avg_score: 0.5
-    target_labels:
-      - "Bird"
+  detection_true_positive_threshold: 0.8
+  detection_false_positive_threshold: 0.5
+  classification_avg_score: 0.5
+  target_labels:
+    - "Bird"
 
 choose_images:
   images_to_annotate_dir: '/path/to/images'
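The commit flattens the nested detection and classification blocks into top-level keys under pipeline_evaluation, matching the renamed keyword arguments of PipelineEvaluation.__init__ in the next file. A minimal sketch of that mapping, assuming the config is loaded with OmegaConf (the loader choice is an assumption, and model is a stand-in for a detection model constructed elsewhere):

from omegaconf import OmegaConf
from src.pipeline_evaluation import PipelineEvaluation

cfg = OmegaConf.load("conf/config.yaml")
eval_cfg = cfg.pipeline_evaluation

pipeline_evaluation = PipelineEvaluation(
    model=model,  # stand-in: a trained detection model, not shown in this diff
    detection_annotations_dir=eval_cfg.detection_annotations_dir,
    classification_annotations_dir=eval_cfg.classification_annotations_dir,
    detection_true_positive_threshold=eval_cfg.detection_true_positive_threshold,
    detection_false_positive_threshold=eval_cfg.detection_false_positive_threshold,
    classification_avg_score=eval_cfg.classification_avg_score,
    target_labels=eval_cfg.target_labels,
)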
src/pipeline_evaluation.py (23 changes: 15 additions & 8 deletions)
@@ -4,21 +4,28 @@
 from torchmetrics.functional import confusion_matrix
 
 class PipelineEvaluation:
-    def __init__(self, model, detection_annotation_dir=None, classification_annotation_dir=None, iou_threshold=0.4, target_classes=None):
+    def __init__(self, model, detection_annotations_dir=None, classification_annotations_dir=None, detection_true_positive_threshold=0.8, detection_false_positive_threshold=0.5, classification_avg_score=0.5, target_labels=None):
         """Initialize pipeline evaluation"""
-        self.iou_threshold = iou_threshold
-        self.detection_annotations_df = gather_data(detection_annotation_dir)
-        self.classification_annotations_df = gather_data(classification_annotation_dir)
+        self.detection_true_positive_threshold = detection_true_positive_threshold
+        self.detection_false_positive_threshold = detection_false_positive_threshold
+        self.classification_avg_score = classification_avg_score
+
+        self.detection_annotations_df = gather_data(detection_annotations_dir)
+        self.classification_annotations_df = gather_data(classification_annotations_dir)
 
         self.model = model
 
         # Metrics
-        self.mAP = MeanAveragePrecision(box_format="xyxy",extended_summary=True)
-        self.classification_accuracy = Accuracy()
+        self.mAP = MeanAveragePrecision(box_format="xyxy",extended_summary=True, iou_threshold=detection_true_positive_threshold)
+        self.classification_accuracy = Accuracy(average="micro", num_classes=len(target_labels))
 
     def _format_targets(self, annotations_df):
-        return {"boxes": annotations_df["bbox"].tolist(),
-                "labels": annotations_df["label"].tolist()}
+        targets = {}
+        targets["boxes"] = annotations_df[["xmin", "ymin", "xmax",
+                                           "ymax"]].values.astype("float32")
+        targets["labels"] = [self.model.label_dict[x] for x in annotations_df["label"].tolist()]
+
+        return targets
 
     def evaluate_detection(self):
         preds = self.model.predict(self.detection_annotations_df)
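The rewritten _format_targets builds boxes from explicit xmin/ymin/xmax/ymax columns and maps string labels to integer ids through the model's label_dict. A hedged sketch of how such a target dict feeds the MeanAveragePrecision metric used above; the annotations_df and label_dict below are made-up stand-ins, and the values are converted to torch tensors, which torchmetrics expects:

import torch
import pandas as pd
from torchmetrics.detection.mean_ap import MeanAveragePrecision

# Made-up single-row annotations table; the real frame comes from gather_data().
annotations_df = pd.DataFrame({
    "xmin": [10.0], "ymin": [20.0], "xmax": [50.0], "ymax": [60.0],
    "label": ["Bird"],
})
label_dict = {"Bird": 0}  # stand-in for self.model.label_dict

# Mirror _format_targets, then convert to tensors for torchmetrics.
targets = {
    "boxes": torch.tensor(annotations_df[["xmin", "ymin", "xmax", "ymax"]].values.astype("float32")),
    "labels": torch.tensor([label_dict[x] for x in annotations_df["label"].tolist()]),
}
# A perfect prediction for illustration; preds additionally need scores.
preds = {
    "boxes": targets["boxes"].clone(),
    "labels": targets["labels"].clone(),
    "scores": torch.tensor([0.9]),
}

metric = MeanAveragePrecision(box_format="xyxy")
metric.update([preds], [targets])  # update takes lists of per-image dicts
print(metric.compute()["map"])  # tensor(1.) for this perfect match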
Binary file added tests/data/birds.jpg
Binary file added tests/data/birds_val.jpg
