Skip to content

Commit

Permalink
Clean up _parse_uri_s3x, resolve edge cases (#237)
Browse files Browse the repository at this point in the history
* improve exception handling in the _parse_uri function

* clean up _parse_uri_s3x

fix host:port:junk edge case, improve exception handling, simplify logic

* fix README.rst
  • Loading branch information
mpenkov authored and menshikh-iv committed Oct 1, 2018
1 parent efb74ce commit 7ac3b4e
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 37 deletions.
8 changes: 4 additions & 4 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ What?
... print(line)
>>> # stream from WebHDFS
>>> for line in smart_open('webhdfs://host:port/user/hadoop/my_file.txt'):
>>> for line in smart_open('webhdfs://host:1234/user/hadoop/my_file.txt'):
... print(line)
>>> # stream content *into* S3 (write mode):
Expand All @@ -56,17 +56,17 @@ What?
... fout.write(line)
>>> # stream content *into* HDFS (write mode):
>>> with smart_open('hdfs://host:port/user/hadoop/my_file.txt', 'wb') as fout:
>>> with smart_open('hdfs://host:1234/user/hadoop/my_file.txt', 'wb') as fout:
... for line in [b'first line\n', b'second line\n', b'third line\n']:
... fout.write(line)
>>> # stream content *into* WebHDFS (write mode):
>>> with smart_open('webhdfs://host:port/user/hadoop/my_file.txt', 'wb') as fout:
>>> with smart_open('webhdfs://host:1234/user/hadoop/my_file.txt', 'wb') as fout:
... for line in [b'first line\n', b'second line\n', b'third line\n']:
... fout.write(line)
>>> # stream using a completely custom s3 server, like s3proxy:
>>> for line in smart_open('s3u://user:secret@host:port@mybucket/mykey.txt', 'rb'):
>>> for line in smart_open('s3u://user:secret@host:1234@mybucket/mykey.txt', 'rb'):
... print(line.decode('utf8'))
>>> # you can also use a boto.s3.key.Key instance directly:
Expand Down
69 changes: 38 additions & 31 deletions smart_open/smart_open_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,44 +478,51 @@ def _parse_uri_webhdfs(parsed_uri):


def _parse_uri_s3x(parsed_uri):
#
# Restrictions on bucket names and labels:
#
# - Bucket names must be at least 3 and no more than 63 characters long.
# - Bucket names must be a series of one or more labels.
# - Adjacent labels are separated by a single period (.).
# - Bucket names can contain lowercase letters, numbers, and hyphens.
# - Each label must start and end with a lowercase letter or a number.
#
# We use the above as a guide only, and do not perform any validation. We
# let boto3 take care of that for us.
#
assert parsed_uri.scheme in smart_open_s3.SUPPORTED_SCHEMES

port = 443
host = boto.config.get('s3', 'host', 's3.amazonaws.com')
ordinary_calling_format = False
#
# These defaults tell boto3 to look for credentials elsewhere
#
access_id, access_secret = None, None

#
# Common URI template [key:secret@][host[:port]@]bucket/object
try:
uri = parsed_uri.netloc + parsed_uri.path
# Separate authentication from URI if exist
if ':' in uri.split('@')[0] and '@' in uri:
auth, uri = uri.split('@', 1)
access_id, access_secret = auth.split(':')
else:
# "None" credentials are interpreted as "look for credentials in other locations" by boto
access_id, access_secret = None, None

# Split [host[:port]@]bucket/path
host_bucket, key_id = uri.split('/', 1)
if '@' in host_bucket:
host_port, bucket_id = host_bucket.split('@')
ordinary_calling_format = True
if ':' in host_port:
server = host_port.split(':')
host = server[0]
if len(server) == 2:
port = int(server[1])
else:
host = host_port
else:
bucket_id = host_bucket
except Exception:
# Bucket names must be at least 3 and no more than 63 characters long.
# Bucket names must be a series of one or more labels.
# Adjacent labels are separated by a single period (.).
# Bucket names can contain lowercase letters, numbers, and hyphens.
# Each label must start and end with a lowercase letter or a number.
raise RuntimeError("invalid S3 URI: %s" % str(parsed_uri))
#
# The urlparse function doesn't handle the above schema, so we have to do
# it ourselves.
#
uri = parsed_uri.netloc + parsed_uri.path

if '@' in uri and ':' in uri.split('@')[0]:
auth, uri = uri.split('@', 1)
access_id, access_secret = auth.split(':')

head, key_id = uri.split('/', 1)
if '@' in head and ':' in head:
ordinary_calling_format = True
host_port, bucket_id = head.split('@')
host, port = host_port.split(':', 1)
port = int(port)
elif '@' in head:
ordinary_calling_format = True
host, bucket_id = head.split('@')
else:
bucket_id = head

return Uri(
scheme=parsed_uri.scheme, bucket_id=bucket_id, key_id=key_id,
Expand Down
23 changes: 21 additions & 2 deletions smart_open/tests/test_smart_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ def test_s3_uri_has_atmark_in_key_name2(self):
self.assertEqual(parsed_uri.port, 1234)

def test_s3_invalid_url_atmark_in_bucket_name(self):
self.assertRaises(RuntimeError, smart_open_lib._parse_uri, "s3://access_id:access_secret@my@bucket@port/mykey")
self.assertRaises(ValueError, smart_open_lib._parse_uri, "s3://access_id:access_secret@my@bucket@port/mykey")

def test_s3_invalid_uri_missing_colon(self):
self.assertRaises(RuntimeError, smart_open_lib._parse_uri, "s3://access_id@access_secret@mybucket@port/mykey")
self.assertRaises(ValueError, smart_open_lib._parse_uri, "s3://access_id@access_secret@mybucket@port/mykey")

def test_webhdfs_uri(self):
"""Do webhdfs URIs parse correctly"""
Expand Down Expand Up @@ -137,6 +137,25 @@ def test_s3_uri_with_colon_in_key_name(self):
self.assertEqual(parsed_uri.access_id, None)
self.assertEqual(parsed_uri.access_secret, None)

# Happy path: an s3u URI carrying credentials plus an explicit host and a
# numeric port must parse into each individual field, with the port as an int.
def test_host_and_port(self):
as_string = 's3u://user:secret@host:1234@mybucket/mykey.txt'
uri = smart_open_lib._parse_uri(as_string)
self.assertEqual(uri.scheme, "s3u")
self.assertEqual(uri.bucket_id, "mybucket")
self.assertEqual(uri.key_id, "mykey.txt")
self.assertEqual(uri.access_id, "user")
self.assertEqual(uri.access_secret, "secret")
self.assertEqual(uri.host, "host")
self.assertEqual(uri.port, 1234)

# A port that is not a number (the literal text "port") must make _parse_uri
# raise ValueError.
def test_invalid_port(self):
as_string = 's3u://user:secret@host:port@mybucket/mykey.txt'
self.assertRaises(ValueError, smart_open_lib._parse_uri, as_string)

# Extra colon-separated text after the host ("host:port:foo") must also make
# _parse_uri raise ValueError.
def test_invalid_port2(self):
as_string = 's3u://user:secret@host:port:foo@mybucket/mykey.txt'
self.assertRaises(ValueError, smart_open_lib._parse_uri, as_string)


class SmartOpenHttpTest(unittest.TestCase):
"""
Expand Down

0 comments on commit 7ac3b4e

Please sign in to comment.