Skip to content

Commit

Permalink
Handle "+", fix file objects and refactor SmartOpenHttpTest
Browse files Browse the repository at this point in the history
Signed-off-by: Vadim Markovtsev <[email protected]>
  • Loading branch information
vmarkovtsev committed Feb 23, 2019
1 parent b8d9537 commit 7290845
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 72 deletions.
35 changes: 28 additions & 7 deletions smart_open/smart_open_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,14 +381,33 @@ def _open_binary_stream(uri, mode, **kw):
if host is not None:
kw['endpoint_url'] = _add_scheme_to_host(host)
return smart_open_s3.open(uri.bucket.name, uri.name, mode, **kw), uri.name
elif hasattr(uri, 'read'):
elif _is_stream(uri, mode):
# simply pass-through if already a file-like
filename = '/tmp/unknown'
filename = getattr(uri, 'name', 'unknown')
return uri, filename
else:
raise TypeError('don\'t know how to handle uri %s' % repr(uri))


def _is_stream(fileobj, mode):
"""
Detect whether the specified object is a file object with required capabilities
implied by `mode`.
"""
has_read = hasattr(fileobj, 'read')
has_write = hasattr(fileobj, 'write')
if not has_read and not has_write:
return False
if mode.endswith('+'):
return has_read and has_write
if mode[0] == 'r':
return has_read
if mode[0] in ('w', 'a'):
return has_write
# we should never get here
return False


def _s3_open_uri(parsed_uri, mode, **kwargs):
logger.debug('s3_open_uri: %r', locals())
if mode in ('r', 'w'):
Expand Down Expand Up @@ -582,6 +601,8 @@ def _compression_wrapper(file_obj, filename, mode):
if _need_to_buffer(file_obj, mode, ext):
warnings.warn('streaming gzip support unavailable, see %s' % _ISSUE_189_URL)
file_obj = io.BytesIO(file_obj.read())
if ext in COMPRESSED_EXT and mode.endswith('+'):
raise ValueError('impossible to open in read+write+compressed mode')

if ext == '.bz2':
return BZ2File(file_obj, mode)
Expand Down Expand Up @@ -622,11 +643,11 @@ def _encoding_wrapper(fileobj, mode, encoding=None, errors=DEFAULT_ERRORS):
if encoding is None:
encoding = SYSTEM_ENCODING

if mode[0] == 'r':
decoder = codecs.getreader(encoding)
else:
decoder = codecs.getwriter(encoding)
return decoder(fileobj, errors=errors)
if mode[0] == 'r' or mode.endswith('+'):
fileobj = codecs.getreader(encoding)(fileobj, errors=errors)
if mode[0] in ('w', 'a') or mode.endswith('+'):
fileobj = codecs.getwriter(encoding)(fileobj, errors=errors)
return fileobj

def _add_scheme_to_host(host):
if host.startswith('http://') or host.startswith('https://'):
Expand Down
217 changes: 152 additions & 65 deletions smart_open/tests/test_smart_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@
CURR_DIR = os.path.abspath(os.path.dirname(__file__))


if PY2:
from bz2file import compress as bzip2_compress, decompress as bzip2_decompress
else:
from bz2 import compress as bzip2_compress, decompress as bzip2_decompress


class ParseUriTest(unittest.TestCase):
"""
Test ParseUri class.
Expand Down Expand Up @@ -200,92 +206,173 @@ def test_http_pass(self):
self.assertTrue(actual_request.headers['Authorization'].startswith('Basic '))

@responses.activate
def _test_compressed_http(self, suffix, query):
"""
Shared helper: can a file named ``data<suffix>`` be opened via http?

Writes a test payload through smart_open into an in-memory buffer named
``data<suffix>``, serves the resulting bytes from a mocked GET endpoint,
then reads them back through smart_open and compares with the payload.

:param str suffix: file extension including the dot, e.g. '.gz', '.bz2', '.xz'.
:param query: if truthy, '?some_param=some_val' is appended to the requested URL.
"""
# payload is repeated so that the stored bytes differ clearly from the input
test_string = b'Hello World Compressed.' * 10000
test_file = six.BytesIO()
# presumably smart_open picks the codec from this name's extension -- confirm in _compression_wrapper
test_file.name = 'data' + suffix
with smart_open.smart_open(test_file, 'wb') as outfile:
outfile.write(test_string)
compressed_data = test_file.getvalue()
# check that the string was actually compressed
self.assertNotEqual(compressed_data, test_string)

# the mocked endpoint is registered without a query string
responses.add(responses.GET, 'http://127.0.0.1/data' + suffix, body=compressed_data, stream=True)
smart_open_object = smart_open.smart_open(
'http://127.0.0.1/data%s%s' % (suffix, '?some_param=some_val' if query else ''))

# read back through smart_open and check we get the original payload
self.assertEqual(smart_open_object.read(), test_string)

@unittest.skipIf(six.PY2, 'gzip support for Py2 is not implemented yet')
def test_http_gz(self):
"""Can open gzip via http?"""
fpath = os.path.join(CURR_DIR, 'test_data/crlf_at_1k_boundary.warc.gz')
with open(fpath, 'rb') as infile:
data = infile.read()

with gzip.GzipFile(fpath) as fin:
expected_hash = hashlib.md5(fin.read()).hexdigest()
self._test_compressed_http(".gz", False)

responses.add(responses.GET, "http://127.0.0.1/data.gz", body=data, stream=True)
smart_open_object = smart_open.smart_open("http://127.0.0.1/data.gz?some_param=some_val")
def test_http_bz2(self):
"""Can open bzip2 via http?"""
self._test_compressed_http(".bz2", False)

m = hashlib.md5(smart_open_object.read())
# decompress the gzip and get the same md5 hash
self.assertEqual(m.hexdigest(), expected_hash)
def test_http_xz(self):
"""Can open xz via http?"""
self._test_compressed_http(".xz", False)

@responses.activate
@unittest.skipIf(six.PY2, 'gzip support for Py2 is not implemented yet')
def test_http_gz_noquerystring(self):
"""Can open gzip via http?"""
fpath = os.path.join(CURR_DIR, 'test_data/crlf_at_1k_boundary.warc.gz')
with open(fpath, 'rb') as infile:
data = infile.read()
def test_http_gz_query(self):
"""Can open gzip via http with a query appended to URI?"""
self._test_compressed_http(".gz", True)

with gzip.GzipFile(fpath) as fin:
expected_hash = hashlib.md5(fin.read()).hexdigest()
def test_http_bz2_query(self):
"""Can open bzip2 via http with a query appended to URI?"""
self._test_compressed_http(".bz2", True)

responses.add(responses.GET, "http://127.0.0.1/data.gz", body=data, stream=True)
smart_open_object = smart_open.smart_open("http://127.0.0.1/data.gz")
def test_http_xz_query(self):
"""Can open xz via http with a query appended to URI?"""
self._test_compressed_http(".xz", True)

m = hashlib.md5(smart_open_object.read())
# decompress the gzip and get the same md5 hash
self.assertEqual(m.hexdigest(), expected_hash)

@responses.activate
def test_http_bz2(self):
"""Can open bz2 via http?"""
test_string = b'Hello World Compressed.'
#
# TODO: why are these tests writing to temporary files? We can do the
# bz2 compression in memory.
#
with tempfile.NamedTemporaryFile('wb', suffix='.bz2', delete=False) as infile:
test_file = infile.name
class SmartOpenFileObjTest(unittest.TestCase):
"""
Test passing raw file objects.
"""

with smart_open.smart_open(test_file, 'wb') as outfile:
outfile.write(test_string)
class IOReadWrapper(object):
def __init__(self):
self.fobj = six.BytesIO()

with open(test_file, 'rb') as infile:
compressed_data = infile.read()
def read(self, size=-1):
return self.fobj.read(size)

if os.path.isfile(test_file):
os.unlink(test_file)

responses.add(responses.GET, "http://127.0.0.1/data.bz2", body=compressed_data, stream=True)
smart_open_object = smart_open.smart_open("http://127.0.0.1/data.bz2")
def __enter__(self):
return self

# decompress the bzip2 and get the same md5 hash
self.assertEqual(smart_open_object.read(), test_string)
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

@responses.activate
def test_http_xz(self):
"""Can open xz via http?"""
test_string = b'Hello World Compressed.'
#
# TODO: why are these tests writing to temporary files? We can do the
# lzma compression in memory.
#
with tempfile.NamedTemporaryFile('wb', suffix='.xz', delete=False) as infile:
test_file = infile.name
def close(self):
self.fobj.close()

with smart_open.smart_open(test_file, 'wb') as outfile:
outfile.write(test_string)
class IOWriteWrapper(object):
def __init__(self):
self.fobj = six.BytesIO()

with open(test_file, 'rb') as infile:
compressed_data = infile.read()
def write(self, b):
return self.fobj.write(b)

if os.path.isfile(test_file):
os.unlink(test_file)
def close(self):
self.fobj.close()

responses.add(responses.GET, "http://127.0.0.1/data.xz", body=compressed_data, stream=True)
smart_open_object = smart_open.smart_open("http://127.0.0.1/data.xz")
def test_read(self):
file = self.IOReadWrapper()
text = 'Hello, world!'
file.fobj.write(text.encode())
file.fobj.seek(0)
with smart_open.smart_open(file, 'r') as sf:
data = sf.read()
self.assertEqual(data, text)
file = self.IOWriteWrapper() # no `read()`
file.fobj.write(text.encode())
file.fobj.seek(0)
self.assertRaises(TypeError, lambda: smart_open.smart_open(file, 'r'))

# decompress the xz and get the same md5 hash
self.assertEqual(smart_open_object.read(), test_string)
def test_write(self):
file = self.IOWriteWrapper()
text = 'Hello, world!'
with smart_open.smart_open(file, 'w') as sf:
sf.write(text)
self.assertEqual(file.fobj.getvalue(), text.encode())
file = self.IOReadWrapper() # no `write()`
self.assertRaises(TypeError, lambda: smart_open.smart_open(file, 'w'))

def test_read_write(self):
file = six.BytesIO()
text = 'Hello, world!'
file.write(text.encode())
file.seek(0)
with smart_open.smart_open(file, 'r+') as sf:
data = sf.read()
sf.write("hello")
self.assertEqual(data, text)

file = six.BytesIO()
file.write(text.encode())
file.seek(0)
with smart_open.smart_open(file, 'w+') as sf:
data = sf.read()
sf.write("hello")
self.assertEqual(data, text)

file = six.BytesIO()
file.write(text.encode())
file.seek(0)
with smart_open.smart_open(file, 'a+') as sf:
data = sf.read()
sf.write("hello")
self.assertEqual(data, text)

file = self.IOReadWrapper() # no `write()`
self.assertRaises(TypeError, lambda: smart_open.smart_open(file, 'r+'))
self.assertRaises(TypeError, lambda: smart_open.smart_open(file, 'w+'))
self.assertRaises(TypeError, lambda: smart_open.smart_open(file, 'a+'))

file = self.IOWriteWrapper() # no `read()`
self.assertRaises(TypeError, lambda: smart_open.smart_open(file, 'r+'))
self.assertRaises(TypeError, lambda: smart_open.smart_open(file, 'w+'))
self.assertRaises(TypeError, lambda: smart_open.smart_open(file, 'a+'))

def test_append(self):
file = self.IOWriteWrapper()
text = 'Hello, world!'
with smart_open.smart_open(file, 'a') as sf:
sf.write(text)
self.assertEqual(file.fobj.getvalue(), text.encode())
file = self.IOReadWrapper() # no `write()`
self.assertRaises(TypeError, lambda: smart_open.smart_open(file, 'a'))

def test_name_read(self):
text = 'Hello, world!' * 1000
file = six.BytesIO()
file.write(bzip2_compress(text.encode()))
file.seek(0)
file.name = 'data.bz2'
with smart_open.smart_open(file, 'r') as sf:
data = sf.read()
self.assertEqual(data, text)

def test_name_write(self):
text = 'Hello, world!' * 1000
file = six.BytesIO()
file.name = 'data.bz2'
with smart_open.smart_open(file, 'w') as sf:
sf.write(text)
self.assertEqual(bzip2_decompress(file.getvalue()), text.encode())

def test_name_read_write(self):
file = six.BytesIO()
file.name = 'data.bz2'
self.assertRaises(ValueError, lambda: smart_open.smart_open(file, 'r+'))
self.assertRaises(ValueError, lambda: smart_open.smart_open(file, 'w+'))
self.assertRaises(ValueError, lambda: smart_open.smart_open(file, 'a+'))


#
Expand Down

0 comments on commit 7290845

Please sign in to comment.