Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update utils.py with temporal_NMS and utility functions #272

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 141 additions & 1 deletion src/lava/lib/dl/slayer/object_detection/boundingbox/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from PIL.Image import Image
from torchvision import ops, transforms

from .metrics import bbox_iou
from .metrics import bbox_iou, APstats
from collections import deque

"""Utility functions for Object detection."""

Expand All @@ -21,6 +22,145 @@
Height = int


class storeData():
    """Pickle helper: save and load arbitrary Python objects.

    Usage::

        storeData.save(data, file_name)   # save
        data = storeData.load(file_name)  # load

    WARNING: unpickling executes arbitrary code -- only call ``load``
    on trusted files.
    """

    @staticmethod
    def load(fi):
        """Return the object unpickled from the file at path ``fi``."""
        import pickle
        with open(fi, 'rb') as handle:
            return pickle.load(handle)

    @staticmethod
    def save(data, fi):
        """Pickle ``data`` to the file at path ``fi`` (highest protocol)."""
        import pickle
        with open(fi, 'wb') as handle:
            pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)


def accuracy(pred_nms, bboxes, printFlag = False):
    """Average precision (AP at IOU 0.5) of post-NMS predictions vs. GT bboxes.

    Parameters
    ----------
    pred_nms : sequence
        Per-frame predicted bounding boxes (after NMS); indexed by frame.
    bboxes : sequence
        Per-frame ground-truth bounding boxes, indexed in step with
        ``pred_nms``.
    printFlag : bool, optional
        If True, also print the score, by default False.

    Returns
    -------
    float
        The AP score at IOU threshold 0.5.
    """
    ap_stats = APstats(iou_threshold=0.5)
    for t in range(len(pred_nms)):
        ap_stats.update(pred_nms[t], bboxes[t])
    # Compute the score once (original recomputed ap_scores() for both
    # the optional print and the return value).
    score = ap_stats.ap_scores()[0]
    if printFlag:
        print(score)
    return score


class temporal_NMS():
    """Temporal Non-Maximal Suppression over two consecutive frames.

    First, bounding boxes below a minimum confidence threshold are
    discarded. Then a non-maximal suppression is performed matching
    bboxes over two successive frames: an overlap threshold decides
    whether two boxes represent the same object, above (below) which the
    confidence of the box is increased (decreased). Supports batched
    inputs.

    The class is self-initialized -- it is never instantiated. All state
    lives in class attributes and is (re)created lazily on the first call
    to ``next``::

        import temporal_NMS as t_nms
        detections = [t_nms.next(predictions[..., t]) for t in T]
        # to drop the accumulated frame history, call:
        t_nms.reset()

    Parameters
    ----------
    pred : List[torch.tensor]
        List of bounding box predictions per batch in
        (x_center, y_center, width, height) format.
    conf_threshold : float, optional
        Confidence threshold, by default 0.5.
    nms_threshold : float, optional
        Non maximal overlap threshold, by default 0.4.
    merge_conf : bool, optional
        Flag indicating whether to merge objectness score with
        classification confidence, by default True.
    max_iterations_temporal : int, optional
        Maximum number of temporal iterations (default 15) that scale the
        class likelihood given the neighboring bboxes.
    temporal_scaling_threshold : float, optional
        Scaling factor applied to ``nms_threshold`` when thresholding IOUs.
    scaling_prob : List of two floats [float, float]
        Scaling of the keep [0] and remove [1] confidences of
        neighboring bboxes.
    k_frames : int, optional
        Number of past frames kept in the history buffer, by default 2.

    Returns
    -------
    List[torch.tensor]
        Non-maximal filtered prediction outputs per batch.
    """

    # Class-level mutable state, shared by all callers (no instances).
    init = 0             # 0 until buffers are initialized by reset()
    detections = []      # latest filtered detections, one entry per batch element
    dets_ = []           # per-batch deque of the last k_frames raw detections
    k_frames = 2         # temporal window length
    n_batches = None     # batch size, discovered on the first call to next()

    def reset():
        # (Re)initialize the per-batch buffers. Called automatically on the
        # first next() call; call manually to drop the frame history.
        __class__.detections = [[] for _ in range(__class__.n_batches)]
        __class__.dets_ = [deque(maxlen = __class__.k_frames) for _ in range(__class__.n_batches)]
        __class__.init = 1

    def next(pred: List[torch.tensor],
             conf_threshold = 0.5,
             nms_threshold = 0.4,
             merge_conf = True,
             max_iterations_temporal = 15,
             temporal_scaling_threshold = .9,
             scaling_prob = [1.15, .85]) -> List[torch.tensor]:
        """Filter one frame of predictions using the previous frame as context."""
        # Batch size is re-derived from the input on every call; buffers are
        # only (re)built on the very first call.
        # NOTE(review): if the batch size grows between calls without a
        # reset(), the buffers stay sized for the first batch -- confirm
        # callers keep a constant batch size.
        __class__.n_batches = n_batches = pred.shape[0]
        # housekeeping only on first call
        if __class__.init==0:
            __class__.reset()

        dets_, detections = __class__.dets_, __class__.detections

        for b_n, pred_, in enumerate(pred): #along the batch
            # Drop boxes whose objectness score (column 4) is too low.
            filtered = pred_[pred_[:, 4] > conf_threshold]
            obj_conf, labels = torch.max(filtered[:, 5:], dim=1, keepdim=True)
            if merge_conf:
                scores = filtered[:, 4:5] * obj_conf
            else:
                scores = filtered[:, 4:5]
            boxes = filtered[:, :4]
            # Push this frame's [box(4), score(1), label(1)] rows into the
            # per-batch history deque (maxlen=k_frames evicts the oldest).
            dets_[b_n].append(torch.cat([boxes, scores, labels], dim=-1))
            if len(dets_[b_n])==1: # first frame: no previous frame, match against itself
                detections0 = detections1 = dets_[b_n][-1]
            else:
                detections0, detections1 = dets_[b_n][-1], dets_[b_n][-2] # best performer
                # NOTE: matching against the previously *scaled* detections
                # (detections[b_n]) was tried and performed worse.
            for k in range(max_iterations_temporal):
                if k==max_iterations_temporal-1: # last iteration is classic NMS
                    detections1 = detections0
                # Sort both frames by descending confidence so the
                # upper-triangular (triu) suppression below favors the
                # higher-confidence box.
                order0 = torch.argsort(detections0[:,4], descending=True)
                if order0.shape:
                    detections0 = detections0[order0]
                order1 = torch.argsort(detections1[:,4], descending=True)
                if order1.shape:
                    detections1 = detections1[order1]
                ious = bbox_iou(detections1, detections0)
                # A box only suppresses another when their class labels match.
                label_match = (detections1[:, 5].reshape(-1, 1) == detections0[:, 5].reshape(1, -1))
                keep = (
                    ious * label_match > nms_threshold*temporal_scaling_threshold
                ).long().triu(1).sum(dim=0, keepdim=True).T.expand_as(detections0) == 0
                detections01 = detections0[keep].reshape(-1, 6).contiguous()
                detections00 = detections0[~keep].reshape(-1, 6).contiguous()
                # Rescale confidences: boost kept boxes (clamped to 1.0) and
                # damp suppressed ones -- overlapping a same-label box in the
                # neighboring frame raises/lowers the object likelihood.
                detections01[:,4] = torch.minimum(detections01[:,4]*scaling_prob[0], torch.tensor(1.0))
                detections00[:,4] *= scaling_prob[1]
                # Damped boxes survive into the next iteration; the final
                # (classic-NMS) iteration discards them.
                detections0 = torch.cat([detections01, detections00], dim=0) if k < max_iterations_temporal-1 else detections01
            detections[b_n] = detections0.clone()
        return detections[:n_batches]



def non_maximum_suppression(predictions: List[torch.tensor],
conf_threshold: float = 0.5,
nms_threshold: float = 0.4,
Expand Down