Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle "+", fix file objects and refactor SmartOpenHttpTest #263

Merged
merged 5 commits into from
Mar 4, 2019
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions smart_open/smart_open_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,10 +383,14 @@ def _open_binary_stream(uri, mode, **kw):
return smart_open_s3.open(uri.bucket.name, uri.name, mode, **kw), uri.name
elif hasattr(uri, 'read'):
# simply pass-through if already a file-like
filename = '/tmp/unknown'
# we need to return something as the file name, but we don't know what
# so we probe for uri.name (e.g., this works with open() or tempfile.NamedTemporaryFile)
# if the value ends with one of COMPRESSED_EXT, we will note it in _compression_wrapper()
# if there is no such attribute, we return "unknown" - this effectively disables any compression
filename = getattr(uri, 'name', 'unknown')
return uri, filename
else:
raise TypeError('don\'t know how to handle uri %s' % repr(uri))
raise TypeError("don't know how to handle uri %r" % uri)


def _s3_open_uri(parsed_uri, mode, **kwargs):
Expand Down Expand Up @@ -582,6 +586,8 @@ def _compression_wrapper(file_obj, filename, mode):
if _need_to_buffer(file_obj, mode, ext):
warnings.warn('streaming gzip support unavailable, see %s' % _ISSUE_189_URL)
file_obj = io.BytesIO(file_obj.read())
if ext in COMPRESSED_EXT and mode.endswith('+'):
raise ValueError('transparent (de)compression unsupported for mode %r' % mode)

if ext == '.bz2':
return BZ2File(file_obj, mode)
Expand Down Expand Up @@ -622,11 +628,11 @@ def _encoding_wrapper(fileobj, mode, encoding=None, errors=DEFAULT_ERRORS):
if encoding is None:
encoding = SYSTEM_ENCODING

if mode[0] == 'r':
decoder = codecs.getreader(encoding)
else:
decoder = codecs.getwriter(encoding)
return decoder(fileobj, errors=errors)
if mode[0] == 'r' or mode.endswith('+'):
fileobj = codecs.getreader(encoding)(fileobj, errors=errors)
if mode[0] in ('w', 'a') or mode.endswith('+'):
fileobj = codecs.getwriter(encoding)(fileobj, errors=errors)
return fileobj

def _add_scheme_to_host(host):
if host.startswith('http://') or host.startswith('https://'):
Expand Down
216 changes: 137 additions & 79 deletions smart_open/tests/test_smart_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
# This code is distributed under the terms and conditions
# from the MIT License (MIT).

import bz2
import io
import functools
import unittest
import logging
import tempfile
Expand All @@ -29,6 +31,8 @@

PY2 = sys.version_info[0] == 2
CURR_DIR = os.path.abspath(os.path.dirname(__file__))
SAMPLE_TEXT = 'Hello, world!'
SAMPLE_BYTES = SAMPLE_TEXT.encode('utf-8')


class ParseUriTest(unittest.TestCase):
Expand Down Expand Up @@ -200,93 +204,150 @@ def test_http_pass(self):
self.assertTrue(actual_request.headers['Authorization'].startswith('Basic '))

@responses.activate
@unittest.skipIf(six.PY2, 'gzip support for Py2 is not implemented yet')
def test_http_gz(self):
"""Can open gzip via http?"""
fpath = os.path.join(CURR_DIR, 'test_data/crlf_at_1k_boundary.warc.gz')
with open(fpath, 'rb') as infile:
data = infile.read()

with gzip.GzipFile(fpath) as fin:
expected_hash = hashlib.md5(fin.read()).hexdigest()
def _test_compressed_http(self, suffix, query):
"""Can open <suffix> via http?"""
raw_data = b'Hello World Compressed.' * 10000
buffer = six.BytesIO()
buffer.name = 'data' + suffix
with smart_open.smart_open(buffer, 'wb') as outfile:
outfile.write(raw_data)
compressed_data = buffer.getvalue()
# check that the string was actually compressed
self.assertNotEqual(compressed_data, raw_data)

responses.add(responses.GET, 'http://127.0.0.1/data' + suffix, body=compressed_data, stream=True)
smart_open_object = smart_open.smart_open(
'http://127.0.0.1/data%s%s' % (suffix, '?some_param=some_val' if query else ''))

responses.add(responses.GET, "http://127.0.0.1/data.gz", body=data, stream=True)
smart_open_object = smart_open.smart_open("http://127.0.0.1/data.gz?some_param=some_val")

m = hashlib.md5(smart_open_object.read())
# decompress the gzip and get the same md5 hash
self.assertEqual(m.hexdigest(), expected_hash)
# decompress the payload and check that we get the original data back
self.assertEqual(smart_open_object.read(), raw_data)

@responses.activate
@unittest.skipIf(six.PY2, 'gzip support for Py2 is not implemented yet')
def test_http_gz_noquerystring(self):
def test_http_gz(self):
"""Can open gzip via http?"""
fpath = os.path.join(CURR_DIR, 'test_data/crlf_at_1k_boundary.warc.gz')
with open(fpath, 'rb') as infile:
data = infile.read()

with gzip.GzipFile(fpath) as fin:
expected_hash = hashlib.md5(fin.read()).hexdigest()

responses.add(responses.GET, "http://127.0.0.1/data.gz", body=data, stream=True)
smart_open_object = smart_open.smart_open("http://127.0.0.1/data.gz")
self._test_compressed_http(".gz", False)

m = hashlib.md5(smart_open_object.read())
# decompress the gzip and get the same md5 hash
self.assertEqual(m.hexdigest(), expected_hash)

@responses.activate
def test_http_bz2(self):
"""Can open bz2 via http?"""
test_string = b'Hello World Compressed.'
#
# TODO: why are these tests writing to temporary files? We can do the
# bz2 compression in memory.
#
with tempfile.NamedTemporaryFile('wb', suffix='.bz2', delete=False) as infile:
test_file = infile.name

with smart_open.smart_open(test_file, 'wb') as outfile:
outfile.write(test_string)
"""Can open bzip2 via http?"""
self._test_compressed_http(".bz2", False)

with open(test_file, 'rb') as infile:
compressed_data = infile.read()
def test_http_xz(self):
"""Can open xz via http?"""
self._test_compressed_http(".xz", False)

if os.path.isfile(test_file):
os.unlink(test_file)
@unittest.skipIf(six.PY2, 'gzip support for Py2 is not implemented yet')
def test_http_gz_query(self):
"""Can open gzip via http with a query appended to URI?"""
self._test_compressed_http(".gz", True)

responses.add(responses.GET, "http://127.0.0.1/data.bz2", body=compressed_data, stream=True)
smart_open_object = smart_open.smart_open("http://127.0.0.1/data.bz2")
def test_http_bz2_query(self):
"""Can open bzip2 via http with a query appended to URI?"""
self._test_compressed_http(".bz2", True)

# decompress the bzip2 and get the same md5 hash
self.assertEqual(smart_open_object.read(), test_string)
def test_http_xz_query(self):
"""Can open xz via http with a query appended to URI?"""
self._test_compressed_http(".xz", True)

@responses.activate
def test_http_xz(self):
"""Can open xz via http?"""
test_string = b'Hello World Compressed.'
#
# TODO: why are these tests writing to temporary files? We can do the
# lzma compression in memory.
#
with tempfile.NamedTemporaryFile('wb', suffix='.xz', delete=False) as infile:
test_file = infile.name

with smart_open.smart_open(test_file, 'wb') as outfile:
outfile.write(test_string)
def make_buffer(cls=six.BytesIO, initial_value=None, name=None):
"""
Construct a new in-memory file object aka "buffer".

with open(test_file, 'rb') as infile:
compressed_data = infile.read()
:param cls: Class of the file object. Meaningful values are BytesIO and StringIO.
:param initial_value: Passed directly to the constructor, this is the content of the returned buffer.
:param name: Associated file path. Not assigned if is None (default).
:return: Instance of `cls`.
"""
buf = cls(initial_value) if initial_value else cls()
if name is not None:
buf.name = name
if PY2:
buf.__enter__ = lambda: buf
buf.__exit__ = lambda exc_type, exc_val, exc_tb: None
return buf

if os.path.isfile(test_file):
os.unlink(test_file)

responses.add(responses.GET, "http://127.0.0.1/data.xz", body=compressed_data, stream=True)
smart_open_object = smart_open.smart_open("http://127.0.0.1/data.xz")
class SmartOpenFileObjTest(unittest.TestCase):
"""
Test passing raw file objects.
"""

# decompress the xz and get the same md5 hash
self.assertEqual(smart_open_object.read(), test_string)
def test_read_bytes(self):
"""Can we read bytes from a byte stream?"""
buffer = make_buffer(initial_value=SAMPLE_BYTES)
with smart_open.smart_open(buffer, 'rb') as sf:
data = sf.read()
self.assertEqual(data, SAMPLE_BYTES)

def test_write_bytes(self):
"""Can we write bytes to a byte stream?"""
buffer = make_buffer()
with smart_open.smart_open(buffer, 'wb') as sf:
sf.write(SAMPLE_BYTES)
self.assertEqual(buffer.getvalue(), SAMPLE_BYTES)

@unittest.skipIf(PY2, "Python 2 does not differentiate between str and bytes")
def test_read_text_stream_fails(self):
"""Attempts to read directly from a text stream should fail."""
buffer = make_buffer(six.StringIO, SAMPLE_TEXT)
with smart_open.smart_open(buffer, 'r') as sf:
self.assertRaises(TypeError, sf.read) # we expect binary mode

@unittest.skipIf(PY2, "Python 2 does not differentiate between str and bytes")
def test_write_text_stream_fails(self):
"""Attempts to write directly to a text stream should fail."""
buffer = make_buffer(six.StringIO)
with smart_open.smart_open(buffer, 'w') as sf:
self.assertRaises(TypeError, sf.write, SAMPLE_TEXT) # we expect binary mode

def test_read_str_from_bytes(self):
"""Can we read strings from a byte stream?"""
buffer = make_buffer(initial_value=SAMPLE_BYTES)
with smart_open.smart_open(buffer, 'r') as sf:
data = sf.read()
self.assertEqual(data, SAMPLE_TEXT)

def test_write_str_to_bytes(self):
"""Can we write strings to a byte stream?"""
buffer = make_buffer()
with smart_open.smart_open(buffer, 'w') as sf:
sf.write(SAMPLE_TEXT)
self.assertEqual(buffer.getvalue(), SAMPLE_BYTES)

def test_name_read(self):
    """Can we use the "name" attribute to decompress on the fly?

    The buffer holds bz2-compressed bytes and carries a name ending in
    ".bz2"; the test expects reading it through smart_open to give back
    the uncompressed payload.
    """
    expected = SAMPLE_BYTES * 1000
    buffer = make_buffer(initial_value=bz2.compress(expected), name='data.bz2')
    with smart_open.smart_open(buffer, 'rb') as sf:
        actual = sf.read()
    # Keep the expected and actual values under separate names so the
    # assertion cannot trivially compare a value with itself.
    self.assertEqual(actual, expected)

def test_name_write(self):
"""Can we use the "name" attribute to compress on the fly?"""
data = SAMPLE_BYTES * 1000
buffer = make_buffer(name='data.bz2')
with smart_open.smart_open(buffer, 'wb') as sf:
sf.write(data)
self.assertEqual(bz2.decompress(buffer.getvalue()), data)

def test_open_side_effect(self):
"""
Does our detection of the `name` attribute work with wrapped open()-ed streams?

We `open()` a file with ".bz2" extension, pass the file object to `smart_open()` and check that
we read decompressed data. This behavior is driven by detecting the `name` attribute in
`_open_binary_stream()`.
"""
data = SAMPLE_BYTES * 1000
with tempfile.NamedTemporaryFile(prefix='smart_open_tests_', suffix=".bz2", delete=False) as tmpf:
tmpf.write(bz2.compress(data))
try:
with open(tmpf.name, 'rb') as openf:
with smart_open.smart_open(openf) as smartf:
smart_data = smartf.read()
self.assertEqual(data, smart_data)
finally:
os.unlink(tmpf.name)

#
# What exactly to patch here differs on _how_ we're opening the file.
Expand Down Expand Up @@ -764,18 +825,18 @@ def test_s3_modes_moto(self):
# fake bucket and key
s3 = boto3.resource('s3')
s3.create_bucket(Bucket='mybucket')
test_string = b"second test"
raw_data = b"second test"

# correct write mode, correct s3 URI
with smart_open.smart_open("s3://mybucket/newkey", "wb") as fout:
logger.debug('fout: %r', fout)
fout.write(test_string)
fout.write(raw_data)

logger.debug("write successfully completed")

output = list(smart_open.smart_open("s3://mybucket/newkey", "rb"))

self.assertEqual(output, [test_string])
self.assertEqual(output, [raw_data])

@mock_s3
def test_s3_metadata_write(self):
Expand Down Expand Up @@ -886,21 +947,18 @@ class CompressionFormatTest(unittest.TestCase):
Test that data written with compression can be read back unchanged.
"""

TEXT = 'Hello'

def write_read_assertion(self, suffix):
with tempfile.NamedTemporaryFile('wb', suffix=suffix, delete=False) as infile:
test_file = infile.name

text = self.TEXT.encode('utf8')
with smart_open.smart_open(test_file, 'wb') as fout: # 'b' for binary, needed on Windows
fout.write(text)
fout.write(SAMPLE_BYTES)

with open(test_file, 'rb') as fin:
self.assertNotEqual(text, fin.read())
self.assertNotEqual(SAMPLE_BYTES, fin.read())

with smart_open.smart_open(test_file, 'rb') as fin:
self.assertEqual(fin.read().decode('utf8'), self.TEXT)
self.assertEqual(fin.read().decode('utf8'), SAMPLE_TEXT)

if os.path.isfile(test_file):
os.unlink(test_file)
Expand Down