Skip to content

Commit

Permalink
Handle "+", fix file objects and refactor SmartOpenHttpTest (#263)
Browse files Browse the repository at this point in the history
* Handle "+", fix file objects and refactor SmartOpenHttpTest
* Apply review suggestions
* Apply review suggestions
* Apply review suggestions
* Apply review suggestions
  • Loading branch information
vmarkovtsev authored and mpenkov committed Mar 4, 2019
1 parent b69feb3 commit 83e4930
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 99 deletions.
20 changes: 13 additions & 7 deletions smart_open/smart_open_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,10 +381,14 @@ def _open_binary_stream(uri, mode, **kw):
return smart_open_s3.open(uri.bucket.name, uri.name, mode, **kw), uri.name
elif hasattr(uri, 'read'):
# simply pass-through if already a file-like
filename = '/tmp/unknown'
# we need to return something as the file name, but we don't know what
# so we probe for uri.name (e.g., this works with open() or tempfile.NamedTemporaryFile)
# if the value ends with COMPRESSED_EXT, we will note it in _compression_wrapper()
# if there is no such attribute, we return "unknown" - this effectively disables any compression
filename = getattr(uri, 'name', 'unknown')
return uri, filename
else:
raise TypeError('don\'t know how to handle uri %s' % repr(uri))
raise TypeError("don't know how to handle uri %r" % uri)


def _s3_open_uri(parsed_uri, mode, **kwargs):
Expand Down Expand Up @@ -580,6 +584,8 @@ def _compression_wrapper(file_obj, filename, mode):
if _need_to_buffer(file_obj, mode, ext):
warnings.warn('streaming gzip support unavailable, see %s' % _ISSUE_189_URL)
file_obj = io.BytesIO(file_obj.read())
if ext in COMPRESSED_EXT and mode.endswith('+'):
raise ValueError('transparent (de)compression unsupported for mode %r' % mode)

if ext == '.bz2':
return BZ2File(file_obj, mode)
Expand Down Expand Up @@ -620,11 +626,11 @@ def _encoding_wrapper(fileobj, mode, encoding=None, errors=DEFAULT_ERRORS):
if encoding is None:
encoding = SYSTEM_ENCODING

if mode[0] == 'r':
decoder = codecs.getreader(encoding)
else:
decoder = codecs.getwriter(encoding)
return decoder(fileobj, errors=errors)
if mode[0] == 'r' or mode.endswith('+'):
fileobj = codecs.getreader(encoding)(fileobj, errors=errors)
if mode[0] in ('w', 'a') or mode.endswith('+'):
fileobj = codecs.getwriter(encoding)(fileobj, errors=errors)
return fileobj

def _add_scheme_to_host(host):
if host.startswith('http://') or host.startswith('https://'):
Expand Down
233 changes: 141 additions & 92 deletions smart_open/tests/test_smart_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# This code is distributed under the terms and conditions
# from the MIT License (MIT).

import bz2
import io
import unittest
import logging
Expand All @@ -27,8 +28,9 @@

logger = logging.getLogger(__name__)

PY2 = sys.version_info[0] == 2
CURR_DIR = os.path.abspath(os.path.dirname(__file__))
SAMPLE_TEXT = 'Hello, world!'
SAMPLE_BYTES = SAMPLE_TEXT.encode('utf-8')


class ParseUriTest(unittest.TestCase):
Expand Down Expand Up @@ -200,93 +202,149 @@ def test_http_pass(self):
self.assertTrue(actual_request.headers['Authorization'].startswith('Basic '))

@responses.activate
@unittest.skipIf(six.PY2, 'gzip support for Py2 is not implemented yet')
def test_http_gz(self):
"""Can open gzip via http?"""
fpath = os.path.join(CURR_DIR, 'test_data/crlf_at_1k_boundary.warc.gz')
with open(fpath, 'rb') as infile:
data = infile.read()

with gzip.GzipFile(fpath) as fin:
expected_hash = hashlib.md5(fin.read()).hexdigest()
def _test_compressed_http(self, suffix, query):
"""Can open <suffix> via http?"""
raw_data = b'Hello World Compressed.' * 10000
buffer = make_buffer(name='data' + suffix)
with smart_open.smart_open(buffer, 'wb') as outfile:
outfile.write(raw_data)
compressed_data = buffer.getvalue()
# check that the string was actually compressed
self.assertNotEqual(compressed_data, raw_data)

responses.add(responses.GET, 'http://127.0.0.1/data' + suffix, body=compressed_data, stream=True)
smart_open_object = smart_open.smart_open(
'http://127.0.0.1/data%s%s' % (suffix, '?some_param=some_val' if query else ''))

responses.add(responses.GET, "http://127.0.0.1/data.gz", body=data, stream=True)
smart_open_object = smart_open.smart_open("http://127.0.0.1/data.gz?some_param=some_val")

m = hashlib.md5(smart_open_object.read())
# decompress the gzip and get the same md5 hash
self.assertEqual(m.hexdigest(), expected_hash)
# decompress the payload and check that it matches the original raw data
self.assertEqual(smart_open_object.read(), raw_data)

@responses.activate
@unittest.skipIf(six.PY2, 'gzip support for Py2 is not implemented yet')
def test_http_gz_noquerystring(self):
def test_http_gz(self):
"""Can open gzip via http?"""
fpath = os.path.join(CURR_DIR, 'test_data/crlf_at_1k_boundary.warc.gz')
with open(fpath, 'rb') as infile:
data = infile.read()
self._test_compressed_http(".gz", False)

with gzip.GzipFile(fpath) as fin:
expected_hash = hashlib.md5(fin.read()).hexdigest()

responses.add(responses.GET, "http://127.0.0.1/data.gz", body=data, stream=True)
smart_open_object = smart_open.smart_open("http://127.0.0.1/data.gz")

m = hashlib.md5(smart_open_object.read())
# decompress the gzip and get the same md5 hash
self.assertEqual(m.hexdigest(), expected_hash)

@responses.activate
def test_http_bz2(self):
"""Can open bz2 via http?"""
test_string = b'Hello World Compressed.'
#
# TODO: why are these tests writing to temporary files? We can do the
# bz2 compression in memory.
#
with tempfile.NamedTemporaryFile('wb', suffix='.bz2', delete=False) as infile:
test_file = infile.name
"""Can open bzip2 via http?"""
self._test_compressed_http(".bz2", False)

with smart_open.smart_open(test_file, 'wb') as outfile:
outfile.write(test_string)
def test_http_xz(self):
"""Can open xz via http?"""
self._test_compressed_http(".xz", False)

with open(test_file, 'rb') as infile:
compressed_data = infile.read()
@unittest.skipIf(six.PY2, 'gzip support for Py2 is not implemented yet')
def test_http_gz_query(self):
"""Can open gzip via http with a query appended to URI?"""
self._test_compressed_http(".gz", True)

if os.path.isfile(test_file):
os.unlink(test_file)
def test_http_bz2_query(self):
"""Can open bzip2 via http with a query appended to URI?"""
self._test_compressed_http(".bz2", True)

responses.add(responses.GET, "http://127.0.0.1/data.bz2", body=compressed_data, stream=True)
smart_open_object = smart_open.smart_open("http://127.0.0.1/data.bz2")
def test_http_xz_query(self):
"""Can open xz via http with a query appended to URI?"""
self._test_compressed_http(".xz", True)

# decompress the bzip2 and get the same md5 hash
self.assertEqual(smart_open_object.read(), test_string)

@responses.activate
def test_http_xz(self):
"""Can open xz via http?"""
test_string = b'Hello World Compressed.'
#
# TODO: why are these tests writing to temporary files? We can do the
# lzma compression in memory.
#
with tempfile.NamedTemporaryFile('wb', suffix='.xz', delete=False) as infile:
test_file = infile.name

with smart_open.smart_open(test_file, 'wb') as outfile:
outfile.write(test_string)
def make_buffer(cls=six.BytesIO, initial_value=None, name=None):
"""
Construct a new in-memory file object aka "buffer".
with open(test_file, 'rb') as infile:
compressed_data = infile.read()
:param cls: Class of the file object. Meaningful values are BytesIO and StringIO.
:param initial_value: Passed directly to the constructor, this is the content of the returned buffer.
:param name: Associated file path. Not assigned if is None (default).
:return: Instance of `cls`.
"""
buf = cls(initial_value) if initial_value else cls()
if name is not None:
buf.name = name
if six.PY2:
buf.__enter__ = lambda: buf
buf.__exit__ = lambda exc_type, exc_val, exc_tb: None
return buf

if os.path.isfile(test_file):
os.unlink(test_file)

responses.add(responses.GET, "http://127.0.0.1/data.xz", body=compressed_data, stream=True)
smart_open_object = smart_open.smart_open("http://127.0.0.1/data.xz")
class SmartOpenFileObjTest(unittest.TestCase):
"""
Test passing raw file objects.
"""

# decompress the xz and get the same md5 hash
self.assertEqual(smart_open_object.read(), test_string)
def test_read_bytes(self):
"""Can we read bytes from a byte stream?"""
buffer = make_buffer(initial_value=SAMPLE_BYTES)
with smart_open.smart_open(buffer, 'rb') as sf:
data = sf.read()
self.assertEqual(data, SAMPLE_BYTES)

def test_write_bytes(self):
"""Can we write bytes to a byte stream?"""
buffer = make_buffer()
with smart_open.smart_open(buffer, 'wb') as sf:
sf.write(SAMPLE_BYTES)
self.assertEqual(buffer.getvalue(), SAMPLE_BYTES)

@unittest.skipIf(six.PY2, "Python 2 does not differentiate between str and bytes")
def test_read_text_stream_fails(self):
"""Attempts to read directly from a text stream should fail."""
buffer = make_buffer(six.StringIO, initial_value=SAMPLE_TEXT)
with smart_open.smart_open(buffer, 'r') as sf:
self.assertRaises(TypeError, sf.read) # we expect binary mode

@unittest.skipIf(six.PY2, "Python 2 does not differentiate between str and bytes")
def test_write_text_stream_fails(self):
"""Attempts to write directly to a text stream should fail."""
buffer = make_buffer(six.StringIO)
with smart_open.smart_open(buffer, 'w') as sf:
self.assertRaises(TypeError, sf.write, SAMPLE_TEXT) # we expect binary mode

def test_read_str_from_bytes(self):
"""Can we read strings from a byte stream?"""
buffer = make_buffer(initial_value=SAMPLE_BYTES)
with smart_open.smart_open(buffer, 'r') as sf:
data = sf.read()
self.assertEqual(data, SAMPLE_TEXT)

def test_write_str_to_bytes(self):
"""Can we write strings to a byte stream?"""
buffer = make_buffer()
with smart_open.smart_open(buffer, 'w') as sf:
sf.write(SAMPLE_TEXT)
self.assertEqual(buffer.getvalue(), SAMPLE_BYTES)

def test_name_read(self):
    """Can we use the "name" attribute to decompress on the fly?

    The buffer holds bz2-compressed bytes and carries a ".bz2" name, so
    reading it through smart_open should give back the uncompressed payload.
    """
    expected = SAMPLE_BYTES * 1000
    buffer = make_buffer(initial_value=bz2.compress(expected), name='data.bz2')
    with smart_open.smart_open(buffer, 'rb') as sf:
        actual = sf.read()
    # Keep the expected and actual values in separate variables: reusing a
    # single name for both makes the assertion compare a value with itself.
    self.assertEqual(actual, expected)

def test_name_write(self):
"""Can we use the "name" attribute to compress on the fly?"""
data = SAMPLE_BYTES * 1000
buffer = make_buffer(name='data.bz2')
with smart_open.smart_open(buffer, 'wb') as sf:
sf.write(data)
self.assertEqual(bz2.decompress(buffer.getvalue()), data)

def test_open_side_effect(self):
"""
Does our detection of the `name` attribute work with wrapped open()-ed streams?
We `open()` a file with ".bz2" extension, pass the file object to `smart_open()` and check that
we read decompressed data. This behavior is driven by detecting the `name` attribute in
`_open_binary_stream()`.
"""
data = SAMPLE_BYTES * 1000
with tempfile.NamedTemporaryFile(prefix='smart_open_tests_', suffix=".bz2", delete=False) as tmpf:
tmpf.write(bz2.compress(data))
try:
with open(tmpf.name, 'rb') as openf:
with smart_open.smart_open(openf) as smartf:
smart_data = smartf.read()
self.assertEqual(data, smart_data)
finally:
os.unlink(tmpf.name)

#
# What exactly to patch here differs on _how_ we're opening the file.
Expand Down Expand Up @@ -764,18 +822,18 @@ def test_s3_modes_moto(self):
# fake bucket and key
s3 = boto3.resource('s3')
s3.create_bucket(Bucket='mybucket')
test_string = b"second test"
raw_data = b"second test"

# correct write mode, correct s3 URI
with smart_open.smart_open("s3://mybucket/newkey", "wb") as fout:
logger.debug('fout: %r', fout)
fout.write(test_string)
fout.write(raw_data)

logger.debug("write successfully completed")

output = list(smart_open.smart_open("s3://mybucket/newkey", "rb"))

self.assertEqual(output, [test_string])
self.assertEqual(output, [raw_data])

@mock_s3
def test_s3_metadata_write(self):
Expand Down Expand Up @@ -886,24 +944,15 @@ class CompressionFormatTest(unittest.TestCase):
Test that compression
"""

TEXT = 'Hello'

def write_read_assertion(self, suffix):
with tempfile.NamedTemporaryFile('wb', suffix=suffix, delete=False) as infile:
test_file = infile.name

text = self.TEXT.encode('utf8')
with smart_open.smart_open(test_file, 'wb') as fout: # 'b' for binary, needed on Windows
fout.write(text)

with open(test_file, 'rb') as fin:
self.assertNotEqual(text, fin.read())

test_file = make_buffer(name='file' + suffix)
with smart_open.smart_open(test_file, 'wb') as fout:
fout.write(SAMPLE_BYTES)
self.assertNotEqual(SAMPLE_BYTES, test_file.getvalue())
# we have to recreate the buffer because it is closed
test_file = make_buffer(initial_value=test_file.getvalue(), name=test_file.name)
with smart_open.smart_open(test_file, 'rb') as fin:
self.assertEqual(fin.read().decode('utf8'), self.TEXT)

if os.path.isfile(test_file):
os.unlink(test_file)
self.assertEqual(fin.read(), SAMPLE_BYTES)

def test_open_gz(self):
"""Can open gzip?"""
Expand Down Expand Up @@ -997,7 +1046,7 @@ def cleanup_temp_bz2(self, test_file):
os.unlink(test_file)

def test_can_read_multistream_bz2(self):
if PY2:
if six.PY2:
# this is a backport from Python 3
from bz2file import BZ2File
else:
Expand All @@ -1010,7 +1059,7 @@ def test_can_read_multistream_bz2(self):

def test_python2_stdlib_bz2_cannot_read_multistream(self):
# Multistream bzip is included in Python 3
if not PY2:
if not six.PY2:
return
import bz2

Expand Down

0 comments on commit 83e4930

Please sign in to comment.