Skip to content

Commit

Permalink
Support errors keyword
Browse files Browse the repository at this point in the history
  • Loading branch information
mpenkov committed Dec 3, 2017
1 parent 2c79505 commit 8017985
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 26 deletions.
42 changes: 28 additions & 14 deletions smart_open/smart_open_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@
"Re-open the file without specifying an encoding to suppress this warning."
)

DEFAULT_ERRORS = 'strict'


def smart_open(uri, mode="rb", **kw):
"""
Expand Down Expand Up @@ -169,7 +171,9 @@ def smart_open(uri, mode="rb", **kw):
if parsed_uri.scheme in ("file", ):
# local files -- both read & write supported
# compression, if any, is determined by the filename extension (.gz, .bz2)
return file_smart_open(parsed_uri.uri_path, mode, encoding=kw.pop('encoding', None))
encoding = kw.pop('encoding', None)
errors = kw.pop('errors', DEFAULT_ERRORS)
return file_smart_open(parsed_uri.uri_path, mode, encoding=encoding, errors=errors)
elif parsed_uri.scheme in ("s3", "s3n", 's3u'):
return s3_open_uri(parsed_uri, mode, **kw)
elif parsed_uri.scheme in ("hdfs", ):
Expand Down Expand Up @@ -237,12 +241,12 @@ def s3_open_uri(parsed_uri, mode, **kwargs):
# Codecs work on a byte-level, so the underlying S3 object should
# always be reading bytes.
#
if codec and mode in (smart_open_s3.READ, smart_open_s3.READ_BINARY):
if mode in (smart_open_s3.READ, smart_open_s3.READ_BINARY):
s3_mode = smart_open_s3.READ_BINARY
elif codec and mode in (smart_open_s3.WRITE, smart_open_s3.WRITE_BINARY):
elif mode in (smart_open_s3.WRITE, smart_open_s3.WRITE_BINARY):
s3_mode = smart_open_s3.WRITE_BINARY
else:
s3_mode = mode
raise NotImplementedError('mode %r not implemented for S3' % mode)

#
# TODO: I'm not sure how to handle this with boto3. Any ideas?
Expand All @@ -251,8 +255,12 @@ def s3_open_uri(parsed_uri, mode, **kwargs):
#
# _setup_unsecured_mode()

encoding = kwargs.get('encoding')
errors = kwargs.get('errors', DEFAULT_ERRORS)
fobj = smart_open_s3.open(parsed_uri.bucket_id, parsed_uri.key_id, s3_mode, **kwargs)
return _CODECS[codec](fobj, mode)
decompressed_fobj = _CODECS[codec](fobj, mode)
decoded_fobj = encoding_wrapper(decompressed_fobj, mode, encoding=encoding, errors=errors)
return decoded_fobj


def _setup_unsecured_mode(parsed_uri, kwargs):
Expand Down Expand Up @@ -284,16 +292,20 @@ def s3_open_key(key, mode, **kwargs):
# Codecs work on a byte-level, so the underlying S3 object should
# always be reading bytes.
#
if codec and mode in (smart_open_s3.READ, smart_open_s3.READ_BINARY):
if mode in (smart_open_s3.READ, smart_open_s3.READ_BINARY):
s3_mode = smart_open_s3.READ_BINARY
elif codec and mode in (smart_open_s3.WRITE, smart_open_s3.WRITE_BINARY):
elif mode in (smart_open_s3.WRITE, smart_open_s3.WRITE_BINARY):
s3_mode = smart_open_s3.WRITE_BINARY
else:
s3_mode = mode
raise NotImplementedError('mode %r not implemented for S3' % mode)

logging.debug('codec: %r mode: %r s3_mode: %r', codec, mode, s3_mode)
encoding = kwargs.get('encoding')
errors = kwargs.get('errors', DEFAULT_ERRORS)
fobj = smart_open_s3.open(key.bucket.name, key.name, s3_mode, **kwargs)
return _CODECS[codec](fobj, mode)
decompressed_fobj = _CODECS[codec](fobj, mode)
decoded_fobj = encoding_wrapper(decompressed_fobj, mode, encoding=encoding, errors=errors)
return decoded_fobj


def _detect_codec(filename):
Expand Down Expand Up @@ -594,15 +606,16 @@ def compression_wrapper(file_obj, filename, mode):
return file_obj


def encoding_wrapper(fileobj, mode, encoding=None):
def encoding_wrapper(fileobj, mode, encoding=None, errors=DEFAULT_ERRORS):
"""Decode bytes into text, if necessary.
If mode specifies binary access, does nothing, unless the encoding is
specified. A non-null encoding implies text mode.
:arg fileobj: must quack like a filehandle object.
:arg str mode: is the mode which was originally requested by the user.
:arg encoding: The text encoding to use. If mode is binary, overrides mode.
:arg str encoding: The text encoding to use. If mode is binary, overrides mode.
:arg str errors: The method to use when handling encoding/decoding errors.
:returns: a file object
"""
logger.debug('encoding_wrapper: %r', locals())
Expand All @@ -626,17 +639,18 @@ def encoding_wrapper(fileobj, mode, encoding=None):
decoder = codecs.getreader(encoding)
else:
decoder = codecs.getwriter(encoding)
return decoder(fileobj)
return decoder(fileobj, errors=errors)


def file_smart_open(fname, mode='rb', encoding=None):
def file_smart_open(fname, mode='rb', encoding=None, errors=DEFAULT_ERRORS):
"""
Stream from/to local filesystem, transparently (de)compressing gzip and bz2
files if necessary.
:arg str fname: The path to the file to open.
:arg str mode: The mode in which to open the file.
:arg str encoding: The text encoding to use.
:arg str errors: The method to use when handling encoding/decoding errors.
:returns: A file object
"""
#
Expand All @@ -656,7 +670,7 @@ def file_smart_open(fname, mode='rb', encoding=None):
raw_mode = mode
raw_fobj = open(fname, raw_mode)
decompressed_fobj = compression_wrapper(raw_fobj, fname, raw_mode)
decoded_fobj = encoding_wrapper(decompressed_fobj, mode, encoding=encoding)
decoded_fobj = encoding_wrapper(decompressed_fobj, mode, encoding=encoding, errors=errors)
return decoded_fobj


Expand Down
105 changes: 93 additions & 12 deletions smart_open/tests/test_smart_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,29 +293,29 @@ def test_file(self, mock_smart_open):
smart_open_object = smart_open.smart_open(prefix+full_path, read_mode)
smart_open_object.__iter__()
# called with the correct path?
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None)
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict')

full_path = '/tmp/test#hash##more.txt'
read_mode = "rb"
smart_open_object = smart_open.smart_open(prefix+full_path, read_mode)
smart_open_object.__iter__()
# called with the correct path?
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None)
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict')

full_path = 'aa#aa'
read_mode = "rb"
smart_open_object = smart_open.smart_open(full_path, read_mode)
smart_open_object.__iter__()
# called with the correct path?
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None)
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict')

short_path = "~/tmp/test.txt"
full_path = os.path.expanduser(short_path)

smart_open_object = smart_open.smart_open(prefix+short_path, read_mode)
smart_open_object = smart_open.smart_open(prefix+short_path, read_mode, errors='strict')
smart_open_object.__iter__()
# called with the correct expanded path?
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None)
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict')

# couldn't find any project for mocking up HDFS data
# TODO: we want to test also a content of the files, not just fnc call params
Expand Down Expand Up @@ -485,15 +485,15 @@ def test_file_mode_mock(self, mock_file, mock_boto):

# correct read modes
smart_open.smart_open("blah", "r")
mock_file.assert_called_with("blah", "r", encoding=None)
mock_file.assert_called_with("blah", "r", encoding=None, errors='strict')

smart_open.smart_open("blah", "rb")
mock_file.assert_called_with("blah", "rb", encoding=None)
mock_file.assert_called_with("blah", "rb", encoding=None, errors='strict')

short_path = "~/blah"
full_path = os.path.expanduser(short_path)
smart_open.smart_open(short_path, "rb")
mock_file.assert_called_with(full_path, "rb", encoding=None)
mock_file.assert_called_with(full_path, "rb", encoding=None, errors='strict')

# correct write modes, incorrect scheme
self.assertRaises(NotImplementedError, smart_open.smart_open, "hdfs:///blah.txt", "wb+")
Expand All @@ -502,16 +502,16 @@ def test_file_mode_mock(self, mock_file, mock_boto):

# correct write mode, correct file:// URI
smart_open.smart_open("blah", "w")
mock_file.assert_called_with("blah", "w", encoding=None)
mock_file.assert_called_with("blah", "w", encoding=None, errors='strict')

smart_open.smart_open("file:///some/file.txt", "wb")
mock_file.assert_called_with("/some/file.txt", "wb", encoding=None)
mock_file.assert_called_with("/some/file.txt", "wb", encoding=None, errors='strict')

smart_open.smart_open("file:///some/file.txt", "wb+")
mock_file.assert_called_with("/some/file.txt", "wb+", encoding=None)
mock_file.assert_called_with("/some/file.txt", "wb+", encoding=None, errors='strict')

smart_open.smart_open("file:///some/file.txt", "w+")
mock_file.assert_called_with("/some/file.txt", "w+", encoding=None)
mock_file.assert_called_with("/some/file.txt", "w+", encoding=None, errors='strict')

@mock.patch('boto3.Session')
def test_s3_mode_mock(self, mock_session):
Expand Down Expand Up @@ -595,6 +595,32 @@ def test_s3_modes_moto(self):

self.assertEqual(output, [test_string])

@mock_s3
def test_write_bad_encoding_strict(self):
    """Should abort on encoding error."""
    # None of these characters exist in koi8-r, so encoding must fail.
    unencodable = u'欲しい気持ちが成長しすぎて'

    with tempfile.NamedTemporaryFile('wb', delete=True) as tmp:
        with self.assertRaises(UnicodeEncodeError):
            with smart_open.smart_open(tmp.name, 'w', encoding='koi8-r',
                                       errors='strict') as writer:
                writer.write(unencodable)

@mock_s3
def test_write_bad_encoding_replace(self):
    """Should replace characters that failed to encode."""
    # None of these characters exist in koi8-r; with errors='replace'
    # each one should be written out as a question mark instead.
    original = u'欲しい気持ちが成長しすぎて'

    with tempfile.NamedTemporaryFile('wb', delete=True) as tmp:
        with smart_open.smart_open(tmp.name, 'w', encoding='koi8-r',
                                   errors='replace') as writer:
            writer.write(original)
        with smart_open.smart_open(tmp.name, 'r', encoding='koi8-r') as reader:
            round_tripped = reader.read()

    self.assertEqual(u'?' * len(original), round_tripped)


class WebHdfsWriteTest(unittest.TestCase):
"""
Expand Down Expand Up @@ -1078,6 +1104,61 @@ def test_write_encoding(self):
actual = fin.read()
self.assertEqual(text, actual)

@mock_s3
def test_write_bad_encoding_strict(self):
    """Should raise UnicodeEncodeError when the text cannot be encoded and errors='strict'."""
    conn = boto.connect_s3()
    conn.create_bucket('bucket')
    key = "s3://bucket/key.txt"
    # Japanese text: not representable in the koi8-r (Cyrillic) codec.
    text = u'欲しい気持ちが成長しすぎて'

    with self.assertRaises(UnicodeEncodeError):
        with smart_open.smart_open(key, 'w', encoding='koi8-r', errors='strict') as fout:
            fout.write(text)

@mock_s3
def test_write_bad_encoding_replace(self):
    """Should write '?' in place of characters that cannot be encoded when errors='replace'."""
    conn = boto.connect_s3()
    conn.create_bucket('bucket')
    key = "s3://bucket/key.txt"
    # Japanese text: not representable in the koi8-r (Cyrillic) codec,
    # so every character should come back as the replacement char '?'.
    text = u'欲しい気持ちが成長しすぎて'
    expected = u'?' * len(text)

    with smart_open.smart_open(key, 'w', encoding='koi8-r', errors='replace') as fout:
        fout.write(text)
    with smart_open.smart_open(key, encoding='koi8-r') as fin:
        actual = fin.read()
    self.assertEqual(expected, actual)

@mock_s3
def test_write_text_gzip(self):
    """Should round-trip unicode text through a gzipped S3 object addressed by URI."""
    conn = boto.connect_s3()
    conn.create_bucket('bucket')
    # .gz suffix triggers transparent (de)compression around the codec layer.
    key = "s3://bucket/key.txt.gz"
    text = u'какая боль, какая боль, аргентина - ямайка, 5-0'

    with smart_open.smart_open(key, 'w', encoding='utf-8') as fout:
        fout.write(text)
    with smart_open.smart_open(key, 'r', encoding='utf-8') as fin:
        actual = fin.read()
    self.assertEqual(text, actual)

@mock_s3
def test_write_text_gzip_key(self):
    """Should open the boto S3 key for writing with the correct encoding."""
    connection = boto.connect_s3()
    bucket = connection.create_bucket('bucket')
    # Address the object via a boto Key instance rather than an s3:// URI.
    gzipped_key = boto.s3.key.Key(bucket, 'key.txt.gz')
    payload = u'какая боль, какая боль, аргентина - ямайка, 5-0'

    with smart_open.smart_open(gzipped_key, 'w', encoding='utf-8') as writer:
        writer.write(payload)
    with smart_open.smart_open(gzipped_key, 'r', encoding='utf-8') as reader:
        read_back = reader.read()
    self.assertEqual(payload, read_back)


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
Expand Down

0 comments on commit 8017985

Please sign in to comment.