Skip to content

Commit

Permalink
[PT FE] Torchvision NMS can accept negative scores (openvinotoolkit#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
mvafin authored Sep 14, 2023
1 parent 16adb01 commit 1a950f9
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
6 changes: 4 additions & 2 deletions src/frontends/pytorch/src/op/nms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ OutputVector translate_nms(const NodeContext& context) {
context.mark_node(v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int32_t>::max()}));
auto iou_threshold = context.get_input(2);

auto nms_out =
context.mark_node(std::make_shared<v9::NonMaxSuppression>(boxes, scores, max_output_per_class, iou_threshold));
auto score_threshold =
context.mark_node(v0::Constant::create(element::f32, Shape{}, {std::numeric_limits<float>::lowest()}));
auto nms_out = context.mark_node(
std::make_shared<v9::NonMaxSuppression>(boxes, scores, max_output_per_class, iou_threshold, score_threshold));
auto select = context.mark_node(std::make_shared<v8::Gather>(nms_out, const_2, const_1));

return {context.mark_node(std::make_shared<v0::Squeeze>(select, const_1))};
Expand Down
3 changes: 2 additions & 1 deletion tests/layer_tests/pytorch_tests/test_nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def _prepare_input(self):
# PyTorch requires that boxes are in (x1, y1, x2, y2) format, where 0<=x1<x2 and 0<=y1<y2
boxes = np.array([[np.random.uniform(1, 3), np.random.uniform(2, 6),
np.random.uniform(4, 6), np.random.uniform(7, 9)] for _ in range(self.boxes_num)]).astype(np.float32)
scores = np.abs(np.random.randn(self.boxes_num).astype(np.float32))
# scores can be negative
scores = np.random.randn(self.boxes_num).astype(np.float32)
return (boxes, scores)

def create_model(self):
Expand Down

0 comments on commit 1a950f9

Please sign in to comment.