From c65c5943c6af30e1d0255e4c52d01a1966d1173e Mon Sep 17 00:00:00 2001 From: tabergma Date: Tue, 14 Aug 2018 14:43:14 +0200 Subject: [PATCH] GH-38: Improve label class. --- flair/data.py | 13 +++++++++++-- tests/test_data.py | 14 ++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/flair/data.py b/flair/data.py index 4b5948694a..af9ef6384c 100644 --- a/flair/data.py +++ b/flair/data.py @@ -93,6 +93,10 @@ def load(cls, name: str): class Label: + """ + This class represents a label of a sentence. Each label has a name and optional a confidence value. The confidence + value needs to be between 0.0 and 1.0. Default value for the confidence is 0.0. + """ def __init__(self, name: str, confidence: float = 0.0): self.name = name self.confidence = confidence @@ -103,16 +107,21 @@ def name(self): @name.setter def name(self, name): - self._name = name + if not name: + raise ValueError('Incorrect label name provided. Label name needs to be set.') + else: + self._name = name @property def confidence(self): - return self._name + return self._confidence @confidence.setter def confidence(self, confidence): if 0.0 <= confidence <= 1.0: self._confidence = confidence + else: + self._confidence = 0.0 class Token: diff --git a/tests/test_data.py b/tests/test_data.py index 445a5d7d71..7563a9503c 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -179,6 +179,20 @@ def test_tagged_corpus_make_vocab_dictionary(): assert('.' in vocab.get_items()) +def test_label_set_confidence(): + label = Label('class_1', 3.2) + + assert (0.0 == label.confidence) + assert ('class_1' == label.name) + + label.confidence = 0.2 + + assert (0.2 == label.confidence) + + with pytest.raises(ValueError): + label.name = '' + + def test_tagged_corpus_make_label_dictionary(): sentence_1 = Sentence('sentence 1', labels=[Label('class_1')]) sentence_2 = Sentence('sentence 2', labels=[Label('class_2')])