diff --git a/rasa/nlu/training_data/training_data.py b/rasa/nlu/training_data/training_data.py index 6ac11824f8d9..787f63a5a734 100644 --- a/rasa/nlu/training_data/training_data.py +++ b/rasa/nlu/training_data/training_data.py @@ -16,6 +16,7 @@ from rasa.nlu.training_data.message import Message from rasa.nlu.training_data.util import check_duplicate_synonym from rasa.nlu.utils import list_to_str +from sklearn.model_selection import train_test_split DEFAULT_TRAINING_DATA_OUTPUT_PATH = "training_data.json" @@ -349,6 +350,7 @@ def train_test_split( preserving the fraction of examples per intent.""" # collect all nlu data + test, train = self.split_nlu_examples(train_frac, random_seed) # collect all nlg stories @@ -395,6 +397,7 @@ def build_nlg_stories_from_examples(examples) -> Dict[Text, list]: ] return nlg_stories + def split_nlu_examples( self, train_frac: float, random_seed: Optional[int] = None ) -> Tuple[list, list]: