Skip to content

Commit

Permalink
Raise warning if entities are not correctly annotated. (#5672)
Browse files Browse the repository at this point in the history
  • Loading branch information
tabergma authored Apr 22, 2020
1 parent be8cbec commit b8f0255
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 5 deletions.
2 changes: 2 additions & 0 deletions changelog/5672.improvement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Raise a warning in ``CRFEntityExtractor`` and ``DIETClassifier`` if entities are not correctly annotated in the
training data, e.g. their start and end values do not match any start and end values of tokens.
2 changes: 2 additions & 0 deletions rasa/nlu/classifiers/diet_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,8 @@ def train(
f"Skipping training of classifier."
)
return
if self.component_config.get(ENTITY_RECOGNITION):
self.check_correct_entity_annotations(training_data)

# keep one example for persisting and loading
self._data_example = model_data.first_data_example()
Expand Down
2 changes: 2 additions & 0 deletions rasa/nlu/extractors/crf_entity_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ def train(
# checks whether there is at least one
# example with an entity annotation
if training_data.entity_examples:
self.check_correct_entity_annotations(training_data)

# filter out pre-trained entity examples
filtered_entity_examples = self.filter_trainable_entities(
training_data.training_examples
Expand Down
38 changes: 37 additions & 1 deletion rasa/nlu/extractors/extractor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
from typing import Any, Dict, List, Text, Tuple, Optional, Union

from rasa.constants import DOCS_URL_TRAINING_DATA_NLU
from rasa.nlu.training_data import TrainingData
from rasa.nlu.tokenizers.tokenizer import Token
from rasa.nlu.components import Component
from rasa.nlu.constants import EXTRACTOR, ENTITIES, TOKENS_NAMES, TEXT
from rasa.nlu.constants import (
EXTRACTOR,
ENTITIES,
TOKENS_NAMES,
TEXT,
ENTITY_ATTRIBUTE_START,
ENTITY_ATTRIBUTE_END,
INTENT,
)
from rasa.nlu.training_data import Message
import rasa.utils.common as common_utils


class EntityExtractor(Component):
Expand Down Expand Up @@ -333,3 +344,28 @@ def filter_trainable_entities(
)

return filtered

@staticmethod
def check_correct_entity_annotations(training_data: TrainingData) -> None:
for example in training_data.entity_examples:
entity_boundaries = [
(entity[ENTITY_ATTRIBUTE_START], entity[ENTITY_ATTRIBUTE_END])
for entity in example.get(ENTITIES)
]
token_start_positions = [t.start for t in example.get(TOKENS_NAMES[TEXT])]
token_end_positions = [t.end for t in example.get(TOKENS_NAMES[TEXT])]

for entity_start, entity_end in entity_boundaries:
if (
entity_start not in token_start_positions
or entity_end not in token_end_positions
):
common_utils.raise_warning(
f"Misaligned entity annotation in message '{example.text}' "
f"with intent '{example.get(INTENT)}'. Make sure the start and "
f"end values of entities in the training data match the token "
f"boundaries (e.g. entities don't include trailing whitespaces "
f"or punctuation).",
docs=DOCS_URL_TRAINING_DATA_NLU,
)
break
14 changes: 12 additions & 2 deletions rasa/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,18 @@
import logging
import scipy.sparse
from typing import Optional, Text, Dict, Any, Union, List

from rasa.constants import DOCS_URL_TRAINING_DATA_NLU
from rasa.nlu.training_data import TrainingData
from rasa.core.constants import DIALOGUE
from rasa.nlu.constants import TEXT
from rasa.nlu.constants import (
TEXT,
ENTITIES,
TOKENS_NAMES,
ENTITY_ATTRIBUTE_START,
ENTITY_ATTRIBUTE_END,
INTENT,
)
from rasa.nlu.tokenizers.tokenizer import Token
import rasa.utils.io as io_utils
from rasa.utils.tensorflow.constants import (
Expand Down Expand Up @@ -33,7 +43,7 @@
INNER,
COSINE,
)

import rasa.utils.common as common_utils

logger = logging.getLogger(__name__)

Expand Down
52 changes: 52 additions & 0 deletions tests/nlu/extractors/test_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from rasa.nlu.tokenizers.tokenizer import Token
from rasa.nlu.training_data import Message
from rasa.nlu.extractors.extractor import EntityExtractor
from rasa.nlu.tokenizers.whitespace_tokenizer import WhitespaceTokenizer
from rasa.nlu.training_data.formats import MarkdownReader


@pytest.mark.parametrize(
Expand Down Expand Up @@ -245,3 +247,53 @@ def test_clean_up_entities(
updated_entities = extractor.clean_up_entities(message, entities, keep)

assert updated_entities == expected_entities


@pytest.mark.parametrize(
"text, warnings",
[
(
"## intent:test\n"
"- I want to fly from [Berlin](location) to [ San Fransisco](location)\n",
1,
),
(
"## intent:test\n"
"- I want to fly from [Berlin ](location) to [San Fransisco](location)\n",
1,
),
(
"## intent:test\n"
"- I want to fly from [Berlin](location) to [San Fransisco.](location)\n"
"- I have nothing to say.",
1,
),
(
"## intent:test\n"
"- I have nothing to say.\n"
"- I want to fly from [Berlin](location) to[San Fransisco](location)\n",
1,
),
(
"## intent:test\n"
"- I want to fly from [Berlin](location) to[San Fransisco](location)\n"
"- Book a flight from [London](location) to [Paris.](location)\n",
2,
),
],
)
def test_check_check_correct_entity_annotations(text: Text, warnings: int):
reader = MarkdownReader()
tokenizer = WhitespaceTokenizer()

training_data = reader.reads(text)
tokenizer.train(training_data)

with pytest.warns(UserWarning) as record:
EntityExtractor.check_correct_entity_annotations(training_data)

assert len(record) == warnings
assert all(
[excerpt in record[0].message.args[0]]
for excerpt in ["Misaligned entity annotation in sentence"]
)
3 changes: 1 addition & 2 deletions tests/nlu/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,7 @@ def test_run_evaluation(unpacked_trained_moodbot_path):
errors=False,
)

assert result["intent_evaluation"]
assert result["entity_evaluation"]["DIETClassifier"]
assert result.get("intent_evaluation")


def test_run_cv_evaluation(pretrained_embeddings_spacy_config):
Expand Down
File renamed without changes.

0 comments on commit b8f0255

Please sign in to comment.