Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle "+", fix file objects and refactor SmartOpenHttpTest #263

Merged
merged 5 commits into from
Mar 4, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 34 additions & 8 deletions smart_open/smart_open_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,12 +381,36 @@ def _open_binary_stream(uri, mode, **kw):
if host is not None:
kw['endpoint_url'] = _add_scheme_to_host(host)
return smart_open_s3.open(uri.bucket.name, uri.name, mode, **kw), uri.name
elif hasattr(uri, 'read'):
elif _is_stream(uri, mode):
# simply pass-through if already a file-like
filename = '/tmp/unknown'
# we need to return something as the file name, but we don't know what
# so we probe for uri.name (e.g., this works with tempfile.NamedTemporaryFile)
# if the value ends with COMPRESSED_EXT, we will note it in _compression_wrapper()
# if there is no such attribute, we return "unknown" - this effectively disables
# any compression (the actual string does not matter ofc)
filename = getattr(uri, 'name', 'unknown')
return uri, filename
else:
raise TypeError('don\'t know how to handle uri %s' % repr(uri))
raise TypeError("don't know how to handle uri %r in mode %r" % (uri, mode))


def _is_stream(fileobj, mode):
    """
    Detect whether the specified object is a file object with required capabilities
    implied by `mode`.

    The object is only probed for `read` and `write` attributes; nothing is called.

    :param fileobj: the candidate file-like object.
    :param str mode: the requested open mode, e.g. "r", "wb", "a+", "r+b".
    :returns: True if `fileobj` has every capability `mode` asks for, False otherwise
        (including when `mode` is empty or not one of the r/w/a families).
    :rtype: bool
    """
    has_read = hasattr(fileobj, 'read')
    has_write = hasattr(fileobj, 'write')
    if not has_read and not has_write:
        # not file-like at all, regardless of the mode
        return False
    if '+' in mode:
        # Update modes need both directions.  The "+" is looked up anywhere in the
        # string because "r+b" is just as valid a spelling as "rb+".
        return has_read and has_write
    # Slice instead of indexing so an empty mode string cannot raise IndexError.
    kind = mode[:1]
    if kind == 'r':
        return has_read
    if kind in ('w', 'a'):
        return has_write
    # Unrecognised mode.  Answer explicitly rather than via `assert False`: asserts
    # are stripped under `python -O`, which made this path silently return None.
    return False


def _s3_open_uri(parsed_uri, mode, **kwargs):
Expand Down Expand Up @@ -582,6 +606,8 @@ def _compression_wrapper(file_obj, filename, mode):
if _need_to_buffer(file_obj, mode, ext):
warnings.warn('streaming gzip support unavailable, see %s' % _ISSUE_189_URL)
file_obj = io.BytesIO(file_obj.read())
if ext in COMPRESSED_EXT and mode.endswith('+'):
raise ValueError('impossible to open in read+write+compressed mode')

if ext == '.bz2':
return BZ2File(file_obj, mode)
Expand Down Expand Up @@ -622,11 +648,11 @@ def _encoding_wrapper(fileobj, mode, encoding=None, errors=DEFAULT_ERRORS):
if encoding is None:
encoding = SYSTEM_ENCODING

if mode[0] == 'r':
decoder = codecs.getreader(encoding)
else:
decoder = codecs.getwriter(encoding)
return decoder(fileobj, errors=errors)
if mode[0] == 'r' or mode.endswith('+'):
fileobj = codecs.getreader(encoding)(fileobj, errors=errors)
if mode[0] in ('w', 'a') or mode.endswith('+'):
fileobj = codecs.getwriter(encoding)(fileobj, errors=errors)
return fileobj

def _add_scheme_to_host(host):
if host.startswith('http://') or host.startswith('https://'):
Expand Down
212 changes: 147 additions & 65 deletions smart_open/tests/test_smart_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# This code is distributed under the terms and conditions
# from the MIT License (MIT).

from bz2 import compress as bzip2_compress, decompress as bzip2_decompress
import io
import unittest
import logging
Expand Down Expand Up @@ -200,92 +201,173 @@ def test_http_pass(self):
self.assertTrue(actual_request.headers['Authorization'].startswith('Basic '))

@responses.activate
def _test_compressed_http(self, suffix, query):
    """Can open <suffix> via http?

    Compresses a payload in memory, serves it from a mocked HTTP endpoint and
    checks that reading the URL back yields the original payload.

    :param str suffix: file extension appended to the name and URL, e.g. ".gz".
    :param bool query: when true, a query string is appended to the requested URL.
    """
    payload = b'Hello World Compressed.' * 10000

    # Write through smart_open into a named in-memory buffer; the name carries
    # the extension under test.
    buf = six.BytesIO()
    buf.name = 'data' + suffix
    with smart_open.smart_open(buf, 'wb') as fout:
        fout.write(payload)
    body = buf.getvalue()
    # the stored bytes must differ from the payload, i.e. compression took place
    self.assertNotEqual(body, payload)

    responses.add(responses.GET, 'http://127.0.0.1/data' + suffix, body=body, stream=True)
    if query:
        query_string = '?some_param=some_val'
    else:
        query_string = ''
    reader = smart_open.smart_open('http://127.0.0.1/data%s%s' % (suffix, query_string))

    # reading over http must give back the original, uncompressed payload
    self.assertEqual(reader.read(), payload)

@unittest.skipIf(six.PY2, 'gzip support for Py2 is not implemented yet')
def test_http_gz(self):
"""Can open gzip via http?"""
fpath = os.path.join(CURR_DIR, 'test_data/crlf_at_1k_boundary.warc.gz')
with open(fpath, 'rb') as infile:
data = infile.read()
self._test_compressed_http(".gz", False)

with gzip.GzipFile(fpath) as fin:
expected_hash = hashlib.md5(fin.read()).hexdigest()

responses.add(responses.GET, "http://127.0.0.1/data.gz", body=data, stream=True)
smart_open_object = smart_open.smart_open("http://127.0.0.1/data.gz?some_param=some_val")
def test_http_bz2(self):
"""Can open bzip2 via http?"""
self._test_compressed_http(".bz2", False)

m = hashlib.md5(smart_open_object.read())
# decompress the gzip and get the same md5 hash
self.assertEqual(m.hexdigest(), expected_hash)
def test_http_xz(self):
"""Can open xz via http?"""
self._test_compressed_http(".xz", False)

@responses.activate
@unittest.skipIf(six.PY2, 'gzip support for Py2 is not implemented yet')
def test_http_gz_noquerystring(self):
"""Can open gzip via http?"""
fpath = os.path.join(CURR_DIR, 'test_data/crlf_at_1k_boundary.warc.gz')
with open(fpath, 'rb') as infile:
data = infile.read()
def test_http_gz_query(self):
"""Can open gzip via http with a query appended to URI?"""
self._test_compressed_http(".gz", True)

with gzip.GzipFile(fpath) as fin:
expected_hash = hashlib.md5(fin.read()).hexdigest()
def test_http_bz2_query(self):
"""Can open bzip2 via http with a query appended to URI?"""
self._test_compressed_http(".bz2", True)

responses.add(responses.GET, "http://127.0.0.1/data.gz", body=data, stream=True)
smart_open_object = smart_open.smart_open("http://127.0.0.1/data.gz")
def test_http_xz_query(self):
"""Can open xz via http with a query appended to URI?"""
self._test_compressed_http(".xz", True)

m = hashlib.md5(smart_open_object.read())
# decompress the gzip and get the same md5 hash
self.assertEqual(m.hexdigest(), expected_hash)

@responses.activate
def test_http_bz2(self):
"""Can open bz2 via http?"""
test_string = b'Hello World Compressed.'
#
# TODO: why are these tests writing to temporary files? We can do the
# bz2 compression in memory.
#
with tempfile.NamedTemporaryFile('wb', suffix='.bz2', delete=False) as infile:
test_file = infile.name
class SmartOpenFileObjTest(unittest.TestCase):
"""
Test passing raw file objects.
"""

with smart_open.smart_open(test_file, 'wb') as outfile:
outfile.write(test_string)
class IOReadWrapper(object):
def __init__(self):
self.fobj = six.BytesIO()

with open(test_file, 'rb') as infile:
compressed_data = infile.read()
def read(self, size=-1):
return self.fobj.read(size)

if os.path.isfile(test_file):
os.unlink(test_file)
def __enter__(self):
return self

responses.add(responses.GET, "http://127.0.0.1/data.bz2", body=compressed_data, stream=True)
smart_open_object = smart_open.smart_open("http://127.0.0.1/data.bz2")
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

# decompress the bzip2 and get the same md5 hash
self.assertEqual(smart_open_object.read(), test_string)
def close(self):
self.fobj.close()

@responses.activate
def test_http_xz(self):
"""Can open xz via http?"""
test_string = b'Hello World Compressed.'
#
# TODO: why are these tests writing to temporary files? We can do the
# lzma compression in memory.
#
with tempfile.NamedTemporaryFile('wb', suffix='.xz', delete=False) as infile:
test_file = infile.name

with smart_open.smart_open(test_file, 'wb') as outfile:
outfile.write(test_string)
class IOWriteWrapper(object):
def __init__(self):
self.fobj = six.BytesIO()

with open(test_file, 'rb') as infile:
compressed_data = infile.read()
def write(self, b):
return self.fobj.write(b)

if os.path.isfile(test_file):
os.unlink(test_file)
def close(self):
self.fobj.close()

responses.add(responses.GET, "http://127.0.0.1/data.xz", body=compressed_data, stream=True)
smart_open_object = smart_open.smart_open("http://127.0.0.1/data.xz")
def test_read(self):
    """Mode 'r' accepts an object that has `read()` and rejects one that lacks it."""
    expected = 'Hello, world!'
    raw = expected.encode()

    readable = self.IOReadWrapper()
    readable.fobj.write(raw)
    readable.fobj.seek(0)
    with smart_open.smart_open(readable, 'r') as fin:
        actual = fin.read()
    self.assertEqual(actual, expected)

    write_only = self.IOWriteWrapper()  # no `read()`
    write_only.fobj.write(raw)
    write_only.fobj.seek(0)
    with self.assertRaises(TypeError):
        smart_open.smart_open(write_only, 'r')

# decompress the xz and get the same md5 hash
self.assertEqual(smart_open_object.read(), test_string)
def test_write(self):
    """Mode 'w' accepts an object that has `write()` and rejects one that lacks it."""
    expected = 'Hello, world!'

    writable = self.IOWriteWrapper()
    with smart_open.smart_open(writable, 'w') as fout:
        fout.write(expected)
    self.assertEqual(writable.fobj.getvalue(), expected.encode())

    read_only = self.IOReadWrapper()  # no `write()`
    with self.assertRaises(TypeError):
        smart_open.smart_open(read_only, 'w')

def test_read_write(self):
    """Update modes need both `read()` and `write()` on the wrapped object."""
    text = 'Hello, world!'
    update_modes = ('r+', 'w+', 'a+')

    # A BytesIO has both methods: every update mode can read and then write.
    for mode in update_modes:
        buf = six.BytesIO()
        buf.write(text.encode())
        buf.seek(0)
        with smart_open.smart_open(buf, mode) as stream:
            data = stream.read()
            stream.write("hello")
        self.assertEqual(data, text)

    # First a wrapper with no `write()`, then one with no `read()`:
    # each must be refused in every update mode.
    for wrapper_cls in (self.IOReadWrapper, self.IOWriteWrapper):
        one_way = wrapper_cls()
        for mode in update_modes:
            with self.assertRaises(TypeError):
                smart_open.smart_open(one_way, mode)

def test_append(self):
    """Mode 'a' accepts an object that has `write()` and rejects one that lacks it."""
    expected = 'Hello, world!'

    writable = self.IOWriteWrapper()
    with smart_open.smart_open(writable, 'a') as fout:
        fout.write(expected)
    self.assertEqual(writable.fobj.getvalue(), expected.encode())

    read_only = self.IOReadWrapper()  # no `write()`
    with self.assertRaises(TypeError):
        smart_open.smart_open(read_only, 'a')

def test_name_read(self):
    """A file object named `*.bz2` is read back as the decompressed text."""
    expected = 'Hello, world!' * 1000
    compressed = bzip2_compress(expected.encode())

    source = six.BytesIO()
    source.write(compressed)
    source.seek(0)
    # the `name` attribute is the only hint about the compression format
    source.name = 'data.bz2'

    with smart_open.smart_open(source, 'r') as fin:
        actual = fin.read()
    self.assertEqual(actual, expected)

def test_name_write(self):
    """Text written to a file object named `*.bz2` is stored bzip2-compressed."""
    expected = 'Hello, world!' * 1000

    sink = six.BytesIO()
    # the `name` attribute is the only hint about the compression format
    sink.name = 'data.bz2'
    with smart_open.smart_open(sink, 'w') as fout:
        fout.write(expected)

    stored = sink.getvalue()
    self.assertEqual(bzip2_decompress(stored), expected.encode())

def test_name_read_write(self):
    """Update modes are refused with ValueError for a file object named `*.bz2`."""
    named = six.BytesIO()
    named.name = 'data.bz2'
    # the same object is reused for all three modes
    for mode in ('r+', 'w+', 'a+'):
        with self.assertRaises(ValueError):
            smart_open.smart_open(named, mode)


#
Expand Down