Skip to content

Commit

Permalink
add full name of retrieval intent to tracker
Browse files Browse the repository at this point in the history
  • Loading branch information
indam23 committed Mar 17, 2020
1 parent 6b8090b commit 6223f57
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 4 deletions.
17 changes: 16 additions & 1 deletion rasa/nlu/classifiers/diet_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def __init__(
index_label_id_mapping: Optional[Dict[int, Text]] = None,
index_tag_id_mapping: Optional[Dict[int, Text]] = None,
model: Optional[RasaModel] = None,
response_text_key_dict=None,
) -> None:
"""Declare instance variables with default values."""

Expand All @@ -281,6 +282,11 @@ def __init__(
self.index_label_id_mapping = index_label_id_mapping
self.index_tag_id_mapping = index_tag_id_mapping

if response_text_key_dict:
self.response_text_key_dict = response_text_key_dict
else:
self.response_text_key_dict = {}

self.model = model

self.num_tags: Optional[int] = None # number of entity tags
Expand Down Expand Up @@ -812,7 +818,10 @@ def persist(self, file_name: Text, model_dir: Text) -> Dict[Text, Any]:
model_dir / f"{file_name}.index_tag_id_mapping.pkl",
self.index_tag_id_mapping,
)

io_utils.json_pickle(
model_dir / f"{file_name}.response_text_key_dict.pkl",
self.response_text_key_dict,
)
return {"file": file_name}

@classmethod
Expand All @@ -839,6 +848,7 @@ def load(
label_data,
meta,
data_example,
response_text_key_dict,
) = cls._load_from_files(meta, model_dir)

meta = train_utils.update_similarity_type(meta)
Expand All @@ -852,6 +862,7 @@ def load(
index_label_id_mapping=index_label_id_mapping,
index_tag_id_mapping=index_tag_id_mapping,
model=model,
response_text_key_dict=response_text_key_dict,
)

@classmethod
Expand All @@ -868,6 +879,9 @@ def _load_from_files(cls, meta: Dict[Text, Any], model_dir: Text):
index_tag_id_mapping = io_utils.json_unpickle(
model_dir / f"{file_name}.index_tag_id_mapping.pkl"
)
response_text_key_dict = io_utils.json_unpickle(
model_dir / f"{file_name}.response_text_key_dict.pkl"
)

# jsonpickle converts dictionary keys to strings
index_label_id_mapping = {
Expand All @@ -884,6 +898,7 @@ def _load_from_files(cls, meta: Dict[Text, Any], model_dir: Text):
label_data,
meta,
data_example,
response_text_key_dict,
)

@classmethod
Expand Down
36 changes: 34 additions & 2 deletions rasa/nlu/selectors/response_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

import numpy as np
import tensorflow as tf
import os
import pickle
import warnings

from typing import Any, Dict, Optional, Text, Tuple, Union, List, Type

Expand Down Expand Up @@ -64,9 +67,12 @@
from rasa.nlu.constants import (
RESPONSE,
RESPONSE_SELECTOR_PROPERTY_NAME,
RESPONSE_KEY_ATTRIBUTE,
INTENT,
DEFAULT_OPEN_UTTERANCE_TYPE,
TEXT,
)

from rasa.utils.tensorflow.model_data import RasaModelData
from rasa.utils.tensorflow.models import RasaModel

Expand Down Expand Up @@ -194,6 +200,7 @@ def __init__(
index_label_id_mapping: Optional[Dict[int, Text]] = None,
index_tag_id_mapping: Optional[Dict[int, Text]] = None,
model: Optional[RasaModel] = None,
response_text_key_dict=None,
) -> None:

component_config = component_config or {}
Expand All @@ -204,7 +211,11 @@ def __init__(
component_config[BILOU_FLAG] = None

super().__init__(
component_config, index_label_id_mapping, index_tag_id_mapping, model
component_config,
index_label_id_mapping,
index_tag_id_mapping,
model,
response_text_key_dict,
)

@property
Expand All @@ -222,6 +233,20 @@ def _check_config_parameters(self) -> None:
super()._check_config_parameters()
self._load_selector_params(self.component_config)

@staticmethod
def _create_response_text_key_dict(
training_data: "TrainingData",
) -> Dict[Text, Text]:
"""Create response_key dictionary"""

response_text_key_dict = {}
for example in training_data.intent_examples:
response_text_key_dict[
example.get(RESPONSE)
] = f"{example.get(INTENT)}/{example.get(RESPONSE_KEY_ATTRIBUTE)}"

return response_text_key_dict

@staticmethod
def _set_message_property(
message: Message, prediction_dict: Dict[Text, Any], selector_key: Text
Expand Down Expand Up @@ -253,6 +278,7 @@ def preprocess_train_data(self, training_data: TrainingData) -> RasaModelData:
label_id_index_mapping = self._label_id_index_mapping(
training_data, attribute=RESPONSE
)
self.response_text_key_dict = self._create_response_text_key_dict(training_data)

if not label_id_index_mapping:
# no labels are present to train
Expand All @@ -279,6 +305,8 @@ def process(self, message: Message, **kwargs: Any) -> None:

out = self._predict(message)
label, label_ranking = self._predict_label(out)
full_intent_name = self.response_text_key_dict.get(label.get("name"))
# add suffix to label here

selector_key = (
self.retrieval_intent
Expand All @@ -290,7 +318,11 @@ def process(self, message: Message, **kwargs: Any) -> None:
f"Adding following selector key to message property: {selector_key}"
)

prediction_dict = {"response": label, "ranking": label_ranking}
prediction_dict = {
"response": label,
"ranking": label_ranking,
"full_intent_name": full_intent_name,
}

self._set_message_property(message, prediction_dict, selector_key)

Expand Down
9 changes: 8 additions & 1 deletion tests/nlu/selectors/test_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from rasa.nlu.training_data import load_data
from rasa.nlu.train import Trainer, Interpreter
from rasa.utils.tensorflow.constants import EPOCHS
from rasa.nlu.constants import RESPONSE_SELECTOR_PROPERTY_NAME


@pytest.mark.parametrize(
Expand Down Expand Up @@ -33,6 +34,12 @@ def test_train_selector(pipeline, component_builder, tmpdir):
assert trainer.pipeline

loaded = Interpreter.load(persisted_path, component_builder)
parsed = loaded.parse("hello")

assert loaded.pipeline
assert loaded.parse("hello") is not None
assert parsed is not None
assert (
parsed.get(RESPONSE_SELECTOR_PROPERTY_NAME)
.get("default")
.get("full_intent_name")
) is not None

0 comments on commit 6223f57

Please sign in to comment.