From 83e4930dc98e81f331d578dffb5ddbd1eac491b6 Mon Sep 17 00:00:00 2001 From: Vadim Markovtsev Date: Mon, 4 Mar 2019 05:49:41 +0100 Subject: [PATCH] Handle "+", fix file objects and refactor SmartOpenHttpTest (#263) * Handle "+", fix file objects and refactor SmartOpenHttpTest * Apply review suggestions * Apply review suggestions * Apply review suggestions * Apply review suggestions --- smart_open/smart_open_lib.py | 20 ++- smart_open/tests/test_smart_open.py | 233 +++++++++++++++++----------- 2 files changed, 154 insertions(+), 99 deletions(-) diff --git a/smart_open/smart_open_lib.py b/smart_open/smart_open_lib.py index f8770e77..696d6e20 100644 --- a/smart_open/smart_open_lib.py +++ b/smart_open/smart_open_lib.py @@ -381,10 +381,14 @@ def _open_binary_stream(uri, mode, **kw): return smart_open_s3.open(uri.bucket.name, uri.name, mode, **kw), uri.name elif hasattr(uri, 'read'): # simply pass-through if already a file-like - filename = '/tmp/unknown' + # we need to return something as the file name, but we don't know what + # so we probe for uri.name (e.g., this works with open() or tempfile.NamedTemporaryFile) + # if the value ends with COMPRESSED_EXT, we will note it in _compression_wrapper() + # if there is no such an attribute, we return "unknown" - this effectively disables any compression + filename = getattr(uri, 'name', 'unknown') return uri, filename else: - raise TypeError('don\'t know how to handle uri %s' % repr(uri)) + raise TypeError("don't know how to handle uri %r" % uri) def _s3_open_uri(parsed_uri, mode, **kwargs): @@ -580,6 +584,8 @@ def _compression_wrapper(file_obj, filename, mode): if _need_to_buffer(file_obj, mode, ext): warnings.warn('streaming gzip support unavailable, see %s' % _ISSUE_189_URL) file_obj = io.BytesIO(file_obj.read()) + if ext in COMPRESSED_EXT and mode.endswith('+'): + raise ValueError('transparent (de)compression unsupported for mode %r' % mode) if ext == '.bz2': return BZ2File(file_obj, mode) @@ -620,11 +626,11 @@ def _encoding_wrapper(fileobj, mode, encoding=None, errors=DEFAULT_ERRORS): if encoding is None: encoding = SYSTEM_ENCODING - if mode[0] == 'r': - decoder = codecs.getreader(encoding) - else: - decoder = codecs.getwriter(encoding) - return decoder(fileobj, errors=errors) + if mode[0] == 'r' or mode.endswith('+'): + fileobj = codecs.getreader(encoding)(fileobj, errors=errors) + if mode[0] in ('w', 'a') or mode.endswith('+'): + fileobj = codecs.getwriter(encoding)(fileobj, errors=errors) + return fileobj def _add_scheme_to_host(host): if host.startswith('http://') or host.startswith('https://'): diff --git a/smart_open/tests/test_smart_open.py b/smart_open/tests/test_smart_open.py index c29da7cc..0eaba0cc 100644 --- a/smart_open/tests/test_smart_open.py +++ b/smart_open/tests/test_smart_open.py @@ -6,6 +6,7 @@ # This code is distributed under the terms and conditions # from the MIT License (MIT). +import bz2 import io import unittest import logging @@ -27,8 +28,9 @@ logger = logging.getLogger(__name__) -PY2 = sys.version_info[0] == 2 CURR_DIR = os.path.abspath(os.path.dirname(__file__)) +SAMPLE_TEXT = 'Hello, world!' +SAMPLE_BYTES = SAMPLE_TEXT.encode('utf-8') class ParseUriTest(unittest.TestCase): @@ -200,93 +202,149 @@ def test_http_pass(self): self.assertTrue(actual_request.headers['Authorization'].startswith('Basic ')) @responses.activate - @unittest.skipIf(six.PY2, 'gzip support for Py2 is not implemented yet') - def test_http_gz(self): - """Can open gzip via http?""" - fpath = os.path.join(CURR_DIR, 'test_data/crlf_at_1k_boundary.warc.gz') - with open(fpath, 'rb') as infile: - data = infile.read() - - with gzip.GzipFile(fpath) as fin: - expected_hash = hashlib.md5(fin.read()).hexdigest() + def _test_compressed_http(self, suffix, query): + """Can open via http?""" + raw_data = b'Hello World Compressed.' * 10000 + buffer = make_buffer(name='data' + suffix) + with smart_open.smart_open(buffer, 'wb') as outfile: + outfile.write(raw_data) + compressed_data = buffer.getvalue() + # check that the string was actually compressed + self.assertNotEqual(compressed_data, raw_data) + + responses.add(responses.GET, 'http://127.0.0.1/data' + suffix, body=compressed_data, stream=True) + smart_open_object = smart_open.smart_open( + 'http://127.0.0.1/data%s%s' % (suffix, '?some_param=some_val' if query else '')) - responses.add(responses.GET, "http://127.0.0.1/data.gz", body=data, stream=True) - smart_open_object = smart_open.smart_open("http://127.0.0.1/data.gz?some_param=some_val") - - m = hashlib.md5(smart_open_object.read()) - # decompress the gzip and get the same md5 hash - self.assertEqual(m.hexdigest(), expected_hash) + # decompress the xz and get the same md5 hash + self.assertEqual(smart_open_object.read(), raw_data) - @responses.activate @unittest.skipIf(six.PY2, 'gzip support for Py2 is not implemented yet') - def test_http_gz_noquerystring(self): + def test_http_gz(self): """Can open gzip via http?""" - fpath = os.path.join(CURR_DIR, 'test_data/crlf_at_1k_boundary.warc.gz') - with open(fpath, 'rb') as infile: - data = infile.read() + self._test_compressed_http(".gz", False) - with gzip.GzipFile(fpath) as fin: - expected_hash = hashlib.md5(fin.read()).hexdigest() - - responses.add(responses.GET, "http://127.0.0.1/data.gz", body=data, stream=True) - smart_open_object = smart_open.smart_open("http://127.0.0.1/data.gz") - - m = hashlib.md5(smart_open_object.read()) - # decompress the gzip and get the same md5 hash - self.assertEqual(m.hexdigest(), expected_hash) - - @responses.activate def test_http_bz2(self): - """Can open bz2 via http?""" - test_string = b'Hello World Compressed.' - # - # TODO: why are these tests writing to temporary files? We can do the - # bz2 compression in memory. - # - with tempfile.NamedTemporaryFile('wb', suffix='.bz2', delete=False) as infile: - test_file = infile.name + """Can open bzip2 via http?""" + self._test_compressed_http(".bz2", False) - with smart_open.smart_open(test_file, 'wb') as outfile: - outfile.write(test_string) + def test_http_xz(self): + """Can open xz via http?""" + self._test_compressed_http(".xz", False) - with open(test_file, 'rb') as infile: - compressed_data = infile.read() + @unittest.skipIf(six.PY2, 'gzip support for Py2 is not implemented yet') + def test_http_gz_query(self): + """Can open gzip via http with a query appended to URI?""" + self._test_compressed_http(".gz", True) - if os.path.isfile(test_file): - os.unlink(test_file) + def test_http_bz2_query(self): + """Can open bzip2 via http with a query appended to URI?""" + self._test_compressed_http(".bz2", True) - responses.add(responses.GET, "http://127.0.0.1/data.bz2", body=compressed_data, stream=True) - smart_open_object = smart_open.smart_open("http://127.0.0.1/data.bz2") + def test_http_xz_query(self): + """Can open xz via http with a query appended to URI?""" + self._test_compressed_http(".xz", True) - # decompress the bzip2 and get the same md5 hash - self.assertEqual(smart_open_object.read(), test_string) - @responses.activate - def test_http_xz(self): - """Can open xz via http?""" - test_string = b'Hello World Compressed.' - # - # TODO: why are these tests writing to temporary files? We can do the - # lzma compression in memory. - # - with tempfile.NamedTemporaryFile('wb', suffix='.xz', delete=False) as infile: - test_file = infile.name - - with smart_open.smart_open(test_file, 'wb') as outfile: - outfile.write(test_string) +def make_buffer(cls=six.BytesIO, initial_value=None, name=None): + """ + Construct a new in-memory file object aka "buffer". - with open(test_file, 'rb') as infile: - compressed_data = infile.read() + :param cls: Class of the file object. Meaningful values are BytesIO and StringIO. + :param initial_value: Passed directly to the constructor, this is the content of the returned buffer. + :param name: Associated file path. Not assigned if is None (default). + :return: Instance of `cls`. + """ + buf = cls(initial_value) if initial_value else cls() + if name is not None: + buf.name = name + if six.PY2: + buf.__enter__ = lambda: buf + buf.__exit__ = lambda exc_type, exc_val, exc_tb: None + return buf - if os.path.isfile(test_file): - os.unlink(test_file) - responses.add(responses.GET, "http://127.0.0.1/data.xz", body=compressed_data, stream=True) - smart_open_object = smart_open.smart_open("http://127.0.0.1/data.xz") +class SmartOpenFileObjTest(unittest.TestCase): + """ + Test passing raw file objects. + """ - # decompress the xz and get the same md5 hash - self.assertEqual(smart_open_object.read(), test_string) + def test_read_bytes(self): + """Can we read bytes from a byte stream?""" + buffer = make_buffer(initial_value=SAMPLE_BYTES) + with smart_open.smart_open(buffer, 'rb') as sf: + data = sf.read() + self.assertEqual(data, SAMPLE_BYTES) + + def test_write_bytes(self): + """Can we write bytes to a byte stream?""" + buffer = make_buffer() + with smart_open.smart_open(buffer, 'wb') as sf: + sf.write(SAMPLE_BYTES) + self.assertEqual(buffer.getvalue(), SAMPLE_BYTES) + + @unittest.skipIf(six.PY2, "Python 2 does not differentiate between str and bytes") + def test_read_text_stream_fails(self): + """Attempts to read directly from a text stream should fail.""" + buffer = make_buffer(six.StringIO, initial_value=SAMPLE_TEXT) + with smart_open.smart_open(buffer, 'r') as sf: + self.assertRaises(TypeError, sf.read) # we expect binary mode + + @unittest.skipIf(six.PY2, "Python 2 does not differentiate between str and bytes") + def test_write_text_stream_fails(self): + """Attempts to write directly to a text stream should fail.""" + buffer = make_buffer(six.StringIO) + with smart_open.smart_open(buffer, 'w') as sf: + self.assertRaises(TypeError, sf.write, SAMPLE_TEXT) # we expect binary mode + + def test_read_str_from_bytes(self): + """Can we read strings from a byte stream?""" + buffer = make_buffer(initial_value=SAMPLE_BYTES) + with smart_open.smart_open(buffer, 'r') as sf: + data = sf.read() + self.assertEqual(data, SAMPLE_TEXT) + + def test_write_str_to_bytes(self): + """Can we write strings to a byte stream?""" + buffer = make_buffer() + with smart_open.smart_open(buffer, 'w') as sf: + sf.write(SAMPLE_TEXT) + self.assertEqual(buffer.getvalue(), SAMPLE_BYTES) + + def test_name_read(self): + """Can we use the "name" attribute to decompress on the fly?""" + data = SAMPLE_BYTES * 1000 + buffer = make_buffer(initial_value=bz2.compress(data), name='data.bz2') + with smart_open.smart_open(buffer, 'rb') as sf: + data = sf.read() + self.assertEqual(data, data) + + def test_name_write(self): + """Can we use the "name" attribute to compress on the fly?""" + data = SAMPLE_BYTES * 1000 + buffer = make_buffer(name='data.bz2') + with smart_open.smart_open(buffer, 'wb') as sf: + sf.write(data) + self.assertEqual(bz2.decompress(buffer.getvalue()), data) + + def test_open_side_effect(self): + """ + Does our detection of the `name` attribute work with wrapped open()-ed streams? + We `open()` a file with ".bz2" extension, pass the file object to `smart_open()` and check that + we read decompressed data. This behavior is driven by detecting the `name` attribute in + `_open_binary_stream()`. + """ + data = SAMPLE_BYTES * 1000 + with tempfile.NamedTemporaryFile(prefix='smart_open_tests_', suffix=".bz2", delete=False) as tmpf: + tmpf.write(bz2.compress(data)) + try: + with open(tmpf.name, 'rb') as openf: + with smart_open.smart_open(openf) as smartf: + smart_data = smartf.read() + self.assertEqual(data, smart_data) + finally: + os.unlink(tmpf.name) # # What exactly to patch here differs on _how_ we're opening the file. @@ -764,18 +822,18 @@ def test_s3_modes_moto(self): # fake bucket and key s3 = boto3.resource('s3') s3.create_bucket(Bucket='mybucket') - test_string = b"second test" + raw_data = b"second test" # correct write mode, correct s3 URI with smart_open.smart_open("s3://mybucket/newkey", "wb") as fout: logger.debug('fout: %r', fout) - fout.write(test_string) + fout.write(raw_data) logger.debug("write successfully completed") output = list(smart_open.smart_open("s3://mybucket/newkey", "rb")) - self.assertEqual(output, [test_string]) + self.assertEqual(output, [raw_data]) @mock_s3 def test_s3_metadata_write(self): @@ -886,24 +944,15 @@ class CompressionFormatTest(unittest.TestCase): Test that compression """ - TEXT = 'Hello' - def write_read_assertion(self, suffix): - with tempfile.NamedTemporaryFile('wb', suffix=suffix, delete=False) as infile: - test_file = infile.name - - text = self.TEXT.encode('utf8') - with smart_open.smart_open(test_file, 'wb') as fout: # 'b' for binary, needed on Windows - fout.write(text) - - with open(test_file, 'rb') as fin: - self.assertNotEqual(text, fin.read()) - + test_file = make_buffer(name='file' + suffix) + with smart_open.smart_open(test_file, 'wb') as fout: + fout.write(SAMPLE_BYTES) + self.assertNotEqual(SAMPLE_BYTES, test_file.getvalue()) + # we have to recreate the buffer because it is closed + test_file = make_buffer(initial_value=test_file.getvalue(), name=test_file.name) with smart_open.smart_open(test_file, 'rb') as fin: - self.assertEqual(fin.read().decode('utf8'), self.TEXT) - - if os.path.isfile(test_file): - os.unlink(test_file) + self.assertEqual(fin.read(), SAMPLE_BYTES) def test_open_gz(self): """Can open gzip?""" @@ -997,7 +1046,7 @@ def cleanup_temp_bz2(self, test_file): os.unlink(test_file) def test_can_read_multistream_bz2(self): - if PY2: + if six.PY2: # this is a backport from Python 3 from bz2file import BZ2File else: @@ -1010,7 +1059,7 @@ def test_can_read_multistream_bz2(self): def test_python2_stdlib_bz2_cannot_read_multistream(self): # Multistream bzip is included in Python 3 - if not PY2: + if not six.PY2: return import bz2