Migrate to boto3. Fix #43 #164

Merged
merged 9 commits on Apr 2, 2018
Changes from 6 commits
2 changes: 2 additions & 0 deletions .travis.yml
@@ -30,6 +30,8 @@ matrix:

install:
- pip install .[test]
- pip uninstall --yes botocore boto3
Contributor: better to edit setup.py as we discussed, I think this will work

- pip install boto3
- pip freeze
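For the reviewer's suggestion above, here is a minimal sketch of what the setup.py edit might look like, assuming a standard setuptools layout; everything other than the boto3 entry is illustrative. With boto3 declared in install_requires, `pip install .[test]` would pull it in directly and the uninstall/reinstall steps above would become unnecessary.

```python
# Hypothetical sketch of the suggested setup.py change -- the package metadata
# and the other dependencies shown here are placeholders, not the project's
# actual configuration.
from setuptools import setup, find_packages

setup(
    name='smart_open',
    packages=find_packages(),
    install_requires=[
        'boto3',  # replaces the legacy boto dependency
        'six',
    ],
    extras_require={
        'test': ['mock', 'moto', 'responses'],
    },
)
```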


143 changes: 141 additions & 2 deletions smart_open/s3.py
@@ -1,17 +1,29 @@
# -*- coding: utf-8 -*-
"""Implements file-like objects for reading and writing from/to S3."""
import boto3
import botocore.client

import io
import contextlib
import functools
import itertools
import logging

import boto3
import botocore.client
import six


logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Multiprocessing is unavailable in App Engine (and possibly other sandboxes).
# The only method currently relying on it is s3_iter_bucket, which is instructed
# whether to use it by the MULTIPROCESSING flag.
_MULTIPROCESSING = False
try:
import multiprocessing.pool
_MULTIPROCESSING = True
Contributor: why can multiprocessing be unavailable?

Collaborator (Author): I trust the above comment. It may be unavailable in certain environments, but I've never been in one.

except ImportError:
logger.warning("multiprocessing could not be imported and won't be used")

START = 0
CURRENT = 1
@@ -447,3 +459,130 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.terminate()
else:
self.close()


def iter_bucket(bucket_name, prefix='', accept_key=lambda key: True,
key_limit=None, workers=16, retries=3):
"""
Iterate and download all S3 files under `bucket/prefix`, yielding
`(key, key content)` 2-tuples (generator).

`accept_key` is a function that accepts a key name (unicode string) and
returns True/False, signalling whether the given key should be downloaded
(default: accept all keys).

If `key_limit` is given, stop after yielding that many results.

The keys are processed in parallel, using `workers` processes (default: 16),
to speed up downloads greatly. If multiprocessing is unavailable
(i.e. `_MULTIPROCESSING` is False), this parameter is ignored.

Example::

TODO
"""
#
# If people insist on giving us bucket instances, silently extract the name
# before moving on. Works for boto3 as well as boto.
#
try:
bucket_name = bucket_name.name
except AttributeError:
pass

total_size, key_no = 0, -1
key_iterator = _list_bucket(bucket_name, prefix=prefix, accept_key=accept_key)
download_key = functools.partial(_download_key, bucket_name=bucket_name, retries=retries)

with _create_process_pool(processes=workers) as pool:
result_iterator = pool.imap_unordered(download_key, key_iterator)
for key_no, (key, content) in enumerate(result_iterator):
if key_no % 1000 == 0:
logger.info(
"yielding key #%i: %s, size %i (total %.1fMB)",
key_no, key, len(content), total_size / 1024.0 ** 2
)
yield key, content
total_size += len(content)

if key_limit is not None and key_no + 1 >= key_limit:
# we were asked to output only a limited number of keys => we're done
break
logger.info("processed %i keys, total size %i" % (key_no + 1, total_size))


def _list_bucket(bucket_name, prefix='', accept_key=lambda k: True):
client = boto3.client('s3')
ctoken = None

while True:
list_kwargs = dict(Bucket=bucket_name, Prefix=prefix)
if ctoken:
    list_kwargs['ContinuationToken'] = ctoken
response = client.list_objects_v2(**list_kwargs)
try:
content = response['Contents']
except KeyError:
pass
else:
for c in content:
key = c['Key']
if accept_key(key):
yield key
ctoken = response.get('NextContinuationToken', None)
if not ctoken:
break


def _download_key(key_name, bucket_name=None, retries=3):
if bucket_name is None:
raise ValueError('bucket_name may not be None')

#
# https://geekpete.com/blog/multithreading-boto3/
#
session = boto3.session.Session()
s3 = session.resource('s3')
bucket = s3.Bucket(bucket_name)

# Sometimes, https://github.com/boto/boto/issues/2409 can happen because of network issues on either side.
# Retry up to `retries` times to make sure it's not a transient issue.
for x in range(retries + 1):
try:
content_bytes = _download_fileobj(bucket, key_name)
except botocore.client.ClientError:
# Actually fail on last pass through the loop
if x == retries:
raise
# Otherwise, try again, as this might be a transient timeout
pass
else:
return key_name, content_bytes


def _download_fileobj(bucket, key_name):
#
# This is a separate function only because it makes it easier to inject
# exceptions during tests.
#
buf = io.BytesIO()
bucket.download_fileobj(key_name, buf)
return buf.getvalue()


class DummyPool(object):
"""A class that mimics multiprocessing.pool.Pool for our purposes."""
def imap_unordered(self, function, items):
return six.moves.map(function, items)

def terminate(self):
pass


@contextlib.contextmanager
def _create_process_pool(processes=1):
if _MULTIPROCESSING and processes:
logger.info("creating pool with %i workers", processes)
pool = multiprocessing.pool.Pool(processes=processes)
else:
logger.info("creating dummy pool")
pool = DummyPool()
yield pool
pool.terminate()
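Relating to the review question earlier about multiprocessing availability: a minimal test sketch, assuming the mock library, that forces the fallback so `_create_process_pool` hands back a DummyPool (plain sequential map) instead of a real process pool.

```python
# Hypothetical test sketch -- forces _MULTIPROCESSING to False so the
# context manager yields a DummyPool rather than multiprocessing.pool.Pool.
import mock  # unittest.mock on Python 3

import smart_open.s3 as smart_open_s3


def test_create_process_pool_falls_back_to_dummy():
    with mock.patch.object(smart_open_s3, '_MULTIPROCESSING', False):
        with smart_open_s3._create_process_pool(processes=4) as pool:
            assert isinstance(pool, smart_open_s3.DummyPool)
            assert list(pool.imap_unordered(abs, [-1, -2, 3])) == [1, 2, 3]
```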
157 changes: 12 additions & 145 deletions smart_open/smart_open_lib.py
@@ -31,9 +31,7 @@
import warnings

from boto.compat import BytesIO, urlsplit, six
import boto.s3.connection
import boto.s3.key
from ssl import SSLError
import sys


Expand All @@ -52,24 +50,14 @@

logger = logging.getLogger(__name__)

# Multiprocessing is unavailable in App Engine (and possibly other sandboxes).
# The only method currently relying on it is s3_iter_bucket, which is instructed
# whether to use it by the MULTIPROCESSING flag.
MULTIPROCESSING = False
try:
import multiprocessing.pool
MULTIPROCESSING = True
except ImportError:
logger.warning("multiprocessing could not be imported and won't be used")
from itertools import imap

if IS_PY2:
from bz2file import BZ2File
else:
from bz2 import BZ2File

import gzip
import smart_open.s3 as smart_open_s3
from smart_open.s3 import iter_bucket as s3_iter_bucket


WEBHDFS_MIN_PART_SIZE = 50 * 1024**2 # minimum part size for HDFS multipart uploads
@@ -248,13 +236,6 @@ def s3_open_uri(parsed_uri, mode, **kwargs):
else:
raise NotImplementedError('mode %r not implemented for S3' % mode)

#
# TODO: I'm not sure how to handle this with boto3. Any ideas?
#
# https://github.com/boto/boto3/issues/334
#
# _setup_unsecured_mode()

encoding = kwargs.get('encoding')
errors = kwargs.get('errors', DEFAULT_ERRORS)
fobj = smart_open_s3.open(parsed_uri.bucket_id, parsed_uri.key_id, s3_mode, **kwargs)
@@ -263,49 +244,18 @@ def s3_open_uri(parsed_uri, mode, **kwargs):
return decoded_fobj


def _setup_unsecured_mode(parsed_uri, kwargs):
port = kwargs.pop('port', parsed_uri.port)
if port != 443:
kwargs['port'] = port

if not kwargs.pop('is_secure', parsed_uri.scheme != 's3u'):
kwargs['is_secure'] = False
# If the security model docker is overridden, honor the host directly.
kwargs['calling_format'] = boto.s3.connection.OrdinaryCallingFormat()


def s3_open_key(key, mode, **kwargs):
logger.debug('%r', locals())
#
# TODO: handle boto3 keys as well
#
host = kwargs.pop('host', None)
if host is not None:
kwargs['endpoint_url'] = 'http://' + host

if kwargs.pop("ignore_extension", False):
codec = None
else:
codec = _detect_codec(key.name)

#
# Codecs work on a byte-level, so the underlying S3 object should
# always be reading bytes.
#
if mode in (smart_open_s3.READ, smart_open_s3.READ_BINARY):
s3_mode = smart_open_s3.READ_BINARY
elif mode in (smart_open_s3.WRITE, smart_open_s3.WRITE_BINARY):
s3_mode = smart_open_s3.WRITE_BINARY
else:
raise NotImplementedError('mode %r not implemented for S3' % mode)

logging.debug('codec: %r mode: %r s3_mode: %r', codec, mode, s3_mode)
encoding = kwargs.get('encoding')
errors = kwargs.get('errors', DEFAULT_ERRORS)
fobj = smart_open_s3.open(key.bucket.name, key.name, s3_mode, **kwargs)
decompressed_fobj = _CODECS[codec](fobj, mode)
decoded_fobj = encoding_wrapper(decompressed_fobj, mode, encoding=encoding, errors=errors)
return decoded_fobj
try:
bucket_name, key_name = key.bucket_name, key.key
logger.warning('inferring S3 URL from boto3 Key object')
except AttributeError:
try:
bucket_name, key_name = key.bucket.name, key.name
logger.warning('inferring S3 URL from boto.s3.key.Key object')
except AttributeError:
raise ValueError('expected %r to be a boto or boto3 Key object' % key)
parsed_uri = ParseUri('s3://%s/%s' % (bucket_name, key_name))
return s3_open_uri(parsed_uri, mode, **kwargs)
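A usage sketch of the duck-typing above, assuming an existing boto connection; the bucket and key names are placeholders. A boto3 object exposing bucket_name and key attributes would take the first branch instead.

```python
# Hypothetical usage sketch -- 'mybucket' and 'foo/bar.txt' are placeholders.
import boto
import smart_open

bucket = boto.connect_s3().get_bucket('mybucket')
key = bucket.get_key('foo/bar.txt')

# The bucket/key pair is extracted from the Key object and re-opened
# through the boto3-backed smart_open.s3 submodule.
with smart_open.smart_open(key) as fin:
    print(fin.read())
```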


def _detect_codec(filename):
@@ -879,89 +829,6 @@ def __exit__(self, type, value, traceback):
self.close()


def s3_iter_bucket_process_key_with_kwargs(kwargs):
return s3_iter_bucket_process_key(**kwargs)


def s3_iter_bucket_process_key(key, retries=3):
"""
Conceptually part of `s3_iter_bucket`, but must remain top-level method because
of pickling visibility.

"""
# Sometimes, https://github.com/boto/boto/issues/2409 can happen because of network issues on either side.
# Retry up to 3 times to ensure its not a transient issue.
for x in range(0, retries + 1):
try:
return key, key.get_contents_as_string()
except SSLError:
# Actually fail on last pass through the loop
if x == retries:
raise
# Otherwise, try again, as this might be a transient timeout
pass


def s3_iter_bucket(bucket, prefix='', accept_key=lambda key: True, key_limit=None, workers=16, retries=3):
"""
Iterate and download all S3 files under `bucket/prefix`, yielding out
`(key, key content)` 2-tuples (generator).

`accept_key` is a function that accepts a key name (unicode string) and
returns True/False, signalling whether the given key should be downloaded out or
not (default: accept all keys).

If `key_limit` is given, stop after yielding out that many results.

The keys are processed in parallel, using `workers` processes (default: 16),
to speed up downloads greatly. If multiprocessing is not available, thus
MULTIPROCESSING is False, this parameter will be ignored.

Example::

>>> mybucket = boto.connect_s3().get_bucket('mybucket')

>>> # get all JSON files under "mybucket/foo/"
>>> for key, content in s3_iter_bucket(mybucket, prefix='foo/', accept_key=lambda key: key.endswith('.json')):
... print key, len(content)

>>> # limit to 10k files, using 32 parallel workers (default is 16)
>>> for key, content in s3_iter_bucket(mybucket, key_limit=10000, workers=32):
... print key, len(content)

"""
total_size, key_no = 0, -1
keys = ({'key': key, 'retries': retries} for key in bucket.list(prefix=prefix) if accept_key(key.name))

if MULTIPROCESSING:
logger.info("iterating over keys from %s with %i workers", bucket, workers)
pool = multiprocessing.pool.Pool(processes=workers)
iterator = pool.imap_unordered(s3_iter_bucket_process_key_with_kwargs, keys)
else:
logger.info("iterating over keys from %s without multiprocessing", bucket)
iterator = imap(s3_iter_bucket_process_key_with_kwargs, keys)

for key_no, (key, content) in enumerate(iterator):
if key_no % 1000 == 0:
logger.info(
"yielding key #%i: %s, size %i (total %.1fMB)",
key_no, key, len(content), total_size / 1024.0 ** 2
)

yield key, content
key.close()
total_size += len(content)

if key_limit is not None and key_no + 1 >= key_limit:
# we were asked to output only a limited number of keys => we're done
break

if MULTIPROCESSING:
pool.terminate()

logger.info("processed %i keys, total size %i" % (key_no + 1, total_size))


class WebHdfsException(Exception):
def __init__(self, msg=str()):
self.msg = msg