Skip to content

Commit

Permalink
Merge pull request #4917 from cianclarke/supportS3Namespaces
Browse files Browse the repository at this point in the history
Support S3 namespaces when retrieving models from buckets
  • Loading branch information
ricwo authored Dec 17, 2019
2 parents 6674b1f + 85ff4ec commit 6449f7f
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 5 deletions.
2 changes: 2 additions & 0 deletions changelog/4917.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
When loading models from S3, namespaces (folders within a bucket) are now respected.
Previously, this would result in an error upon loading the model.
10 changes: 5 additions & 5 deletions rasa/nlu/persistor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def retrieve(self, model_name: Text, target_path: Text) -> None:
tar_name = self._tar_name(model_name)

self._retrieve_tar(tar_name)
self._decompress(tar_name, target_path)
self._decompress(os.path.basename(tar_name), target_path)

def list_models(self) -> List[Text]:
"""Lists all the trained models."""
Expand Down Expand Up @@ -151,11 +151,11 @@ def _persist_tar(self, file_key: Text, tar_path: Text) -> None:
with open(tar_path, "rb") as f:
self.s3.Object(self.bucket_name, file_key).put(Body=f)

def _retrieve_tar(self, target_filename: Text) -> None:
def _retrieve_tar(self, model_path: Text) -> None:
"""Downloads a model that has previously been persisted to s3."""

with open(target_filename, "wb") as f:
self.bucket.download_fileobj(target_filename, f)
tar_name = os.path.basename(model_path)
with open(tar_name, "wb") as f:
self.bucket.download_fileobj(model_path, f)


class GCSPersistor(Persistor):
Expand Down
27 changes: 27 additions & 0 deletions tests/nlu/base/test_persistor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,33 @@ def test_list_models_method_raise_exeception_in_AWSPersistor():
assert result == []


# noinspection PyPep8Naming
@mock_s3
def test_retrieve_tar_archive_with_s3_namespace():
model = "/my/s3/project/model.tar.gz"
destination = "dst"
with patch.object(persistor.AWSPersistor, "_decompress") as decompress:
with patch.object(persistor.AWSPersistor, "_retrieve_tar") as retrieve:
persistor.AWSPersistor("rasa-test").retrieve(model, destination)
decompress.assert_called_once_with("model.tar.gz", destination)
retrieve.assert_called_once_with(model)


# noinspection PyPep8Naming
@mock_s3
def test_s3_private_retrieve_tar():
# Ensure the S3 persistor writes to a filename `model.tar.gz`, whilst
# passing the fully namespaced path to boto3
model = "/my/s3/project/model.tar.gz"
awsPersistor = persistor.AWSPersistor("rasa-test")
with patch.object(awsPersistor.bucket, "download_fileobj") as download_fileobj:
# noinspection PyProtectedMember
awsPersistor._retrieve_tar(model)
retrieveArgs = download_fileobj.call_args[0]
assert retrieveArgs[0] == model
assert retrieveArgs[1].name == "model.tar.gz"


# noinspection PyPep8Naming
def test_list_models_method_in_GCSPersistor():
# noinspection PyUnusedLocal
Expand Down

0 comments on commit 6449f7f

Please sign in to comment.