Skip to content

Commit

Permalink
Exclude combined_non_maximum_supression from Hessian compatible nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
Idan-BenAmi committed Nov 26, 2023
1 parent bf5d8cf commit 519da12
Showing 1 changed file with 32 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,33 @@ def argmax_output_model(input_shape):
model = keras.Model(inputs=inputs, outputs=outputs)
return model

def nms_output_model(input_shape):
inputs = layers.Input(shape=input_shape)
x = layers.Conv2D(1, 3, padding='same')(inputs)
x = layers.ReLU()(x)

# Dummy layers for creating NMS inputs with the required shape
x = tf.squeeze(x, -1)
x = tf.concat([x, x], -1)
y = tf.concat([x, x], -1)
y = tf.concat([y, y], -1)
scores = tf.concat([x, y], -1) # shape = (batch, detections, classes)
boxes, _ = tf.split(x, (4,12), -1)
boxes = tf.expand_dims(boxes, 2) # shape = (batch, detections, 1, box coordinates)

# NMS layer
outputs = tf.image.combined_non_max_suppression(
boxes,
scores,
max_output_size_per_class=300,
max_total_size=300,
iou_threshold=0.7,
score_threshold=0.001,
pad_per_class=False,
clip_boxes=False
)
model = keras.Model(inputs=inputs, outputs=outputs)
return model

def representative_dataset():
yield [np.random.randn(1, 8, 8, 3).astype(np.float32)]
Expand Down Expand Up @@ -71,6 +98,11 @@ def test_not_supported_output_argmax(self):
self.verify_test_for_model(model)
self.assertTrue("All graph outputs should support Hessian computation" in str(e.exception))

def test_not_supported_output_nms(self):
model = nms_output_model((8, 8, 3))
with self.assertRaises(Exception) as e:
self.verify_test_for_model(model)
self.assertTrue("All graph outputs should support Hessian computation" in str(e.exception))

if __name__ == '__main__':
unittest.main()

0 comments on commit 519da12

Please sign in to comment.