diff --git a/changelog/4917.bugfix.rst b/changelog/4917.bugfix.rst
new file mode 100644
index 000000000000..49a14cbf1e4d
--- /dev/null
+++ b/changelog/4917.bugfix.rst
@@ -0,0 +1,2 @@
+When loading models from S3, namespaces (folders within a bucket) are now respected.
+Previously, this would result in an error upon loading the model.
diff --git a/rasa/nlu/persistor.py b/rasa/nlu/persistor.py
index b9fc28540f73..43fdefc7178d 100644
--- a/rasa/nlu/persistor.py
+++ b/rasa/nlu/persistor.py
@@ -53,7 +53,7 @@ def retrieve(self, model_name: Text, target_path: Text) -> None:
             tar_name = self._tar_name(model_name)
 
         self._retrieve_tar(tar_name)
-        self._decompress(tar_name, target_path)
+        self._decompress(os.path.basename(tar_name), target_path)
 
     def list_models(self) -> List[Text]:
         """Lists all the trained models."""
@@ -151,11 +151,11 @@ def _persist_tar(self, file_key: Text, tar_path: Text) -> None:
         with open(tar_path, "rb") as f:
             self.s3.Object(self.bucket_name, file_key).put(Body=f)
 
-    def _retrieve_tar(self, target_filename: Text) -> None:
+    def _retrieve_tar(self, model_path: Text) -> None:
         """Downloads a model that has previously been persisted to s3."""
-
-        with open(target_filename, "wb") as f:
-            self.bucket.download_fileobj(target_filename, f)
+        tar_name = os.path.basename(model_path)
+        with open(tar_name, "wb") as f:
+            self.bucket.download_fileobj(model_path, f)
 
 
 class GCSPersistor(Persistor):
diff --git a/tests/nlu/base/test_persistor.py b/tests/nlu/base/test_persistor.py
index e212130c5a40..8371060a37bd 100644
--- a/tests/nlu/base/test_persistor.py
+++ b/tests/nlu/base/test_persistor.py
@@ -47,6 +47,33 @@ def test_list_models_method_raise_exeception_in_AWSPersistor():
     assert result == []
 
 
+# noinspection PyPep8Naming
+@mock_s3
+def test_retrieve_tar_archive_with_s3_namespace():
+    model = "/my/s3/project/model.tar.gz"
+    destination = "dst"
+    with patch.object(persistor.AWSPersistor, "_decompress") as decompress:
+        with patch.object(persistor.AWSPersistor, "_retrieve_tar") as retrieve:
+            persistor.AWSPersistor("rasa-test").retrieve(model, destination)
+        decompress.assert_called_once_with("model.tar.gz", destination)
+        retrieve.assert_called_once_with(model)
+
+
+# noinspection PyPep8Naming
+@mock_s3
+def test_s3_private_retrieve_tar():
+    # Ensure the S3 persistor writes to a filename `model.tar.gz`, whilst
+    # passing the fully namespaced path to boto3
+    model = "/my/s3/project/model.tar.gz"
+    awsPersistor = persistor.AWSPersistor("rasa-test")
+    with patch.object(awsPersistor.bucket, "download_fileobj") as download_fileobj:
+        # noinspection PyProtectedMember
+        awsPersistor._retrieve_tar(model)
+    retrieveArgs = download_fileobj.call_args[0]
+    assert retrieveArgs[0] == model
+    assert retrieveArgs[1].name == "model.tar.gz"
+
+
 # noinspection PyPep8Naming
 def test_list_models_method_in_GCSPersistor():
     # noinspection PyUnusedLocal