Migrate npairs_loss (#309)
WindQAQ authored and Squadrick committed Jun 21, 2019
1 parent df45d88 commit 7a1ed4b
Showing 5 changed files with 170 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tensorflow_addons/losses/BUILD
@@ -10,6 +10,7 @@ py_library(
        "focal_loss.py",
        "lifted.py",
        "metric_learning.py",
        "npairs.py",
        "sparsemax_loss.py",
        "triplet.py",
    ],
@@ -46,6 +47,19 @@ py_test(
    ],
)

py_test(
    name = "npairs_test",
    size = "small",
    srcs = [
        "npairs_test.py",
    ],
    main = "npairs_test.py",
    srcs_version = "PY2AND3",
    deps = [
        ":losses",
    ],
)

py_test(
    name = "sparsemax_loss_test",
    size = "small",
2 changes: 2 additions & 0 deletions tensorflow_addons/losses/README.md
@@ -6,6 +6,7 @@
| contrastive | @WindQAQ | [email protected] |
| focal_loss | | |
| lifted | | |
| npairs | @WindQAQ | [email protected] |
| sparsemax_loss | @AndreasMadsen | [email protected] |
| triplet | | |

@@ -15,6 +16,7 @@
| contrastive | ContrastiveLoss | http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf |
| focal_loss | SigmoidFocalCrossEntropy | https://arxiv.org/abs/1708.02002 |
| lifted | LiftedStructLoss | https://arxiv.org/abs/1511.06452 |
| npairs | NpairsLoss | http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf |
| sparsemax_loss | SparsemaxLoss | https://arxiv.org/abs/1602.02068 |
| triplet | TripletSemiHardLoss | https://arxiv.org/abs/1503.03832 |

1 change: 1 addition & 0 deletions tensorflow_addons/losses/__init__.py
@@ -21,5 +21,6 @@
from tensorflow_addons.losses.contrastive import contrastive_loss, ContrastiveLoss
from tensorflow_addons.losses.focal_loss import sigmoid_focal_crossentropy, SigmoidFocalCrossEntropy
from tensorflow_addons.losses.lifted import lifted_struct_loss, LiftedStructLoss
from tensorflow_addons.losses.npairs import npairs_loss, NpairsLoss
from tensorflow_addons.losses.sparsemax_loss import sparsemax_loss, SparsemaxLoss
from tensorflow_addons.losses.triplet import triplet_semihard_loss, TripletSemiHardLoss
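
With this import line in place, both the functional and class forms are exposed from `tensorflow_addons.losses`. The snippet below is a minimal usage sketch of the functional form, not code from this commit; the random embeddings, shapes, and label values are illustrative assumptions.

```python
import tensorflow as tf
from tensorflow_addons.losses import npairs_loss

# Two embedding matrices of shape [batch_size, hidden_size] (toy values).
a = tf.random.normal([4, 8])
b = tf.random.normal([4, 8])

# Pairwise similarity matrix [batch_size, batch_size], as the docstring describes.
similarity = tf.matmul(a, b, transpose_b=True)

# One distinct integer label per pair in the minibatch.
labels = tf.constant([0, 1, 2, 3], dtype=tf.int64)

loss = npairs_loss(labels, similarity)  # scalar tensor
```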
95 changes: 95 additions & 0 deletions tensorflow_addons/losses/npairs.py
@@ -0,0 +1,95 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements npairs loss."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils import keras_utils


@keras_utils.register_keras_custom_object
@tf.function
def npairs_loss(y_true, y_pred):
"""Computes the npairs loss between `y_true` and `y_pred`.
Npairs loss expects paired data where a pair is composed of samples from
the same labels and each pairs in the minibatch have different labels.
The loss takes each row of the pair-wise similarity matrix, `y_pred`,
as logits and the remapped multi-class labels, `y_true`, as labels.
The similarity matrix `y_pred` between two embedding matrices `a` and `b`
with shape `[batch_size, hidden_size]` can be computed as follows:
```python
# y_pred = a * b^T
y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
```
See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf
Args:
y_true: 1-D integer `Tensor` with shape `[batch_size]` of
multi-class labels.
y_pred: 2-D float `Tensor` with shape `[batch_size, batch_size]` of
similarity matrix between embedding matrices.
Returns:
npairs_loss: float scalar.
"""
y_pred = tf.convert_to_tensor(y_pred)
y_true = tf.cast(y_true, y_pred.dtype)

# Expand to [batch_size, 1]
y_true = tf.expand_dims(y_true, -1)
y_true = tf.cast(tf.equal(y_true, tf.transpose(y_true)), y_pred.dtype)
y_true /= tf.math.reduce_sum(y_true, 1, keepdims=True)

loss = tf.nn.softmax_cross_entropy_with_logits(
logits=y_pred, labels=y_true)

return tf.math.reduce_mean(loss)


@keras_utils.register_keras_custom_object
class NpairsLoss(tf.keras.losses.Loss):
"""Computes the npairs loss between `y_true` and `y_pred`.
Npairs loss expects paired data where a pair is composed of samples from
the same labels and each pairs in the minibatch have different labels.
The loss takes each row of the pair-wise similarity matrix, `y_pred`,
as logits and the remapped multi-class labels, `y_true`, as labels.
The similarity matrix `y_pred` between two embedding matrices `a` and `b`
with shape `[batch_size, hidden_size]` can be computed as follows:
```python
# y_pred = a * b^T
y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
```
See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf
Args:
name: (Optional) name for the loss.
"""

def __init__(self, name="npairs_loss"):
super(NpairsLoss, self).__init__(
reduction=tf.keras.losses.Reduction.NONE, name=name)

def call(self, y_true, y_pred):
return npairs_loss(y_true, y_pred)
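
To make the label remapping in `npairs_loss` concrete, here is a NumPy re-derivation of the same computation, an illustrative sketch rather than code from this commit: integer labels become a row-normalized label-equality matrix, and each row of `y_pred` is scored with softmax cross entropy against that soft target.

```python
import numpy as np

def npairs_loss_reference(y_true, y_pred):
    """NumPy mirror of npairs_loss, for intuition only."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    # 1 where a pair of rows shares a label, 0 otherwise; then row-normalize
    # so each row is a distribution over its positives.
    labels = (y_true[:, None] == y_true[None, :]).astype(np.float64)
    labels /= labels.sum(axis=1, keepdims=True)
    # Row-wise log-softmax of the similarity logits (shifted for stability).
    shifted = y_pred - y_pred.max(axis=1, keepdims=True)
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Cross entropy per row, then the batch mean, matching reduce_mean above.
    return float(-(labels * log_softmax).sum(axis=1).mean())
```

Feeding it the embeddings used in `npairs_test.py` below should reproduce a value of roughly 0.253856.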
58 changes: 58 additions & 0 deletions tensorflow_addons/losses/npairs_test.py
@@ -0,0 +1,58 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for npairs loss."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.losses import npairs
from tensorflow_addons.utils import test_utils


@test_utils.run_all_in_graph_and_eager_modes
class NpairsLossTest(tf.test.TestCase):
    def test_config(self):
        nl_obj = npairs.NpairsLoss(name="nl")
        self.assertEqual(nl_obj.name, "nl")
        self.assertEqual(nl_obj.reduction, tf.keras.losses.Reduction.NONE)

    def test_unweighted(self):
        nl_obj = npairs.NpairsLoss()
        # batch size = 4, hidden size = 2
        y_true = tf.constant([0, 1, 2, 3], dtype=tf.int64)
        # features of anchors
        f = tf.constant([[1., 1.], [1., -1.], [-1., 1.], [-1., -1.]],
                        dtype=tf.float32)
        # features of positive samples
        fp = tf.constant([[1., 1.], [1., -1.], [-1., 1.], [-1., -1.]],
                         dtype=tf.float32)
        # similarity matrix
        y_pred = tf.matmul(f, fp, transpose_a=False, transpose_b=True)
        loss = nl_obj(y_true, y_pred)

        # Loss = 1/4 * \sum_i log(1 + \sum_{j != i} exp(f_i*fp_j^T - f_i*fp_i^T))
        # Compute the loss for i = 0, 1, 2, 3 without the 1/4 multiplier:
        # i = 0 => log(1 + sum([exp(-2), exp(-2), exp(-4)])) = 0.253856
        # i = 1 => log(1 + sum([exp(-2), exp(-4), exp(-2)])) = 0.253856
        # i = 2 => log(1 + sum([exp(-2), exp(-4), exp(-2)])) = 0.253856
        # i = 3 => log(1 + sum([exp(-4), exp(-2), exp(-2)])) = 0.253856
        # Loss = (0.253856 + 0.253856 + 0.253856 + 0.253856) / 4 = 0.253856

        self.assertAllClose(loss, 0.253856)


if __name__ == "__main__":
    tf.test.main()
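
As a quick sanity check of the expected value asserted above, the per-row term can be evaluated in plain Python (illustrative, not part of the commit):

```python
import math

# Every row i contributes log(1 + 2*exp(-2) + exp(-4)); all four rows match.
per_row = math.log(1.0 + 2.0 * math.exp(-2.0) + math.exp(-4.0))
print(per_row)  # ~0.253856, the value checked by assertAllClose
```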
