Skip to content

[v2] Add support for stdin/stdout streams for CRT client #8291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changes/next-release/enhancement-s3cp-40300.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"type": "enhancement",
"category": "``s3 cp``",
"description": "Support streaming uploads from stdin and streaming downloads to stdout for CRT transfer client"
}
2 changes: 0 additions & 2 deletions awscli/customizations/s3/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ def create_transfer_manager(self, params, runtime_config,
def _compute_transfer_client_type(self, params, runtime_config):
if params.get('paths_type') == 's3s3':
return constants.DEFAULT_TRANSFER_CLIENT
if params.get('is_stream'):
return constants.DEFAULT_TRANSFER_CLIENT
return runtime_config.get(
'preferred_transfer_client', constants.DEFAULT_TRANSFER_CLIENT)

Expand Down
168 changes: 130 additions & 38 deletions awscli/s3transfer/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,19 +428,12 @@ def _crt_request_from_aws_request(self, aws_request):
headers_list.append((name, str(value, 'utf-8')))

crt_headers = awscrt.http.HttpHeaders(headers_list)
# CRT requires body (if it exists) to be an I/O stream.
crt_body_stream = None
if aws_request.body:
if hasattr(aws_request.body, 'seek'):
crt_body_stream = aws_request.body
else:
crt_body_stream = BytesIO(aws_request.body)

crt_request = awscrt.http.HttpRequest(
method=aws_request.method,
path=crt_path,
headers=crt_headers,
body_stream=crt_body_stream,
body_stream=aws_request.body,
)
return crt_request

Expand All @@ -453,6 +446,25 @@ def _convert_to_crt_http_request(self, botocore_http_request):
crt_request.headers.set("host", url_parts.netloc)
if crt_request.headers.get('Content-MD5') is not None:
crt_request.headers.remove("Content-MD5")

# In general, the CRT S3 client expects a content length header. It
# only expects a missing content length header if the body is not
# seekable. However, botocore does not set the content length header
# for GetObject API requests and so we set the content length to zero
# to meet the CRT S3 client's expectation that the content length
# header is set even if there is no body.
if crt_request.headers.get('Content-Length') is None:
if botocore_http_request.body is None:
crt_request.headers.add('Content-Length', "0")

# Botocore sets the Transfer-Encoding header when it cannot determine
# the content length of the request body (e.g. it's not seekable).
# CRT does not support this header, although it does support
# non-seekable bodies, so we remove the header to avoid causing issues
# in the downstream CRT S3 request.
if crt_request.headers.get('Transfer-Encoding') is not None:
crt_request.headers.remove('Transfer-Encoding')

return crt_request

def _capture_http_request(self, request, **kwargs):
Expand Down Expand Up @@ -555,39 +567,20 @@ def __init__(self, crt_request_serializer, os_utils):
def get_make_request_args(
self, request_type, call_args, coordinator, future, on_done_after_calls
):
recv_filepath = None
send_filepath = None
s3_meta_request_type = getattr(
S3RequestType, request_type.upper(), S3RequestType.DEFAULT
request_args_handler = getattr(
self,
f'_get_make_request_args_{request_type}',
self._default_get_make_request_args,
)
on_done_before_calls = []
if s3_meta_request_type == S3RequestType.GET_OBJECT:
final_filepath = call_args.fileobj
recv_filepath = self._os_utils.get_temp_filename(final_filepath)
file_ondone_call = RenameTempFileHandler(
coordinator, final_filepath, recv_filepath, self._os_utils
)
on_done_before_calls.append(file_ondone_call)
elif s3_meta_request_type == S3RequestType.PUT_OBJECT:
send_filepath = call_args.fileobj
data_len = self._os_utils.get_file_size(send_filepath)
call_args.extra_args["ContentLength"] = data_len

crt_request = self._request_serializer.serialize_http_request(
request_type, future
return request_args_handler(
request_type=request_type,
call_args=call_args,
coordinator=coordinator,
future=future,
on_done_before_calls=[],
on_done_after_calls=on_done_after_calls,
)

return {
'request': crt_request,
'type': s3_meta_request_type,
'recv_filepath': recv_filepath,
'send_filepath': send_filepath,
'on_done': self.get_crt_callback(
future, 'done', on_done_before_calls, on_done_after_calls
),
'on_progress': self.get_crt_callback(future, 'progress'),
}

def get_crt_callback(
self,
future,
Expand All @@ -613,6 +606,97 @@ def invoke_all_callbacks(*args, **kwargs):

return invoke_all_callbacks

def _get_make_request_args_put_object(
    self,
    request_type,
    call_args,
    coordinator,
    future,
    on_done_before_calls,
    on_done_after_calls,
):
    """Build the CRT ``make_request`` keyword arguments for an upload.

    When ``call_args.fileobj`` is a ``str`` it is used as the path of the
    file to send: its size (from ``self._os_utils.get_file_size``) is
    stored in ``call_args.extra_args["ContentLength"]`` and the path is
    returned as ``send_filepath``. For any other ``fileobj`` the object
    itself is stored in ``call_args.extra_args["Body"]`` and
    ``send_filepath`` is ``None``.

    Note that ``call_args.extra_args`` is mutated in place.

    :returns: The dict produced by ``_default_get_make_request_args`` with
        an additional ``'send_filepath'`` entry.
    """
    send_filepath = None
    if isinstance(call_args.fileobj, str):
        send_filepath = call_args.fileobj
        data_len = self._os_utils.get_file_size(send_filepath)
        call_args.extra_args["ContentLength"] = data_len
    else:
        call_args.extra_args["Body"] = call_args.fileobj

    # Suppress botocore's automatic MD5 calculation by setting an override
    # value that will get deleted in the BotocoreCRTRequestSerializer.
    # The CRT S3 client is able to automatically compute checksums as part
    # of requests it makes, and the intention is to configure automatic
    # checksums in a future update.
    call_args.extra_args["ContentMD5"] = "override-to-be-removed"

    make_request_args = self._default_get_make_request_args(
        request_type=request_type,
        call_args=call_args,
        coordinator=coordinator,
        future=future,
        on_done_before_calls=on_done_before_calls,
        on_done_after_calls=on_done_after_calls,
    )
    make_request_args['send_filepath'] = send_filepath
    return make_request_args

def _get_make_request_args_get_object(
    self,
    request_type,
    call_args,
    coordinator,
    future,
    on_done_before_calls,
    on_done_after_calls,
):
    """Build the CRT ``make_request`` keyword arguments for a download.

    When ``call_args.fileobj`` is a ``str`` it is used as the final file
    path: a temporary filename is obtained from
    ``self._os_utils.get_temp_filename`` and returned as ``recv_filepath``,
    and a ``RenameTempFileHandler`` built from the coordinator and both
    paths is appended to ``on_done_before_calls``. For any other
    ``fileobj`` no ``recv_filepath`` is set; instead the object is wrapped
    in an ``OnBodyFileObjWriter`` that is returned as ``on_body``.

    :returns: The dict produced by ``_default_get_make_request_args`` with
        additional ``'recv_filepath'`` and ``'on_body'`` entries (exactly
        one of which is not ``None``).
    """
    recv_filepath = None
    on_body = None
    if isinstance(call_args.fileobj, str):
        final_filepath = call_args.fileobj
        recv_filepath = self._os_utils.get_temp_filename(final_filepath)
        # Must be registered before the default args are built so that the
        # handler is included in the 'done' callback.
        on_done_before_calls.append(
            RenameTempFileHandler(
                coordinator, final_filepath, recv_filepath, self._os_utils
            )
        )
    else:
        on_body = OnBodyFileObjWriter(call_args.fileobj)

    make_request_args = self._default_get_make_request_args(
        request_type=request_type,
        call_args=call_args,
        coordinator=coordinator,
        future=future,
        on_done_before_calls=on_done_before_calls,
        on_done_after_calls=on_done_after_calls,
    )
    make_request_args['recv_filepath'] = recv_filepath
    make_request_args['on_body'] = on_body
    return make_request_args

def _default_get_make_request_args(
    self,
    request_type,
    call_args,
    coordinator,
    future,
    on_done_before_calls,
    on_done_after_calls,
):
    """Build the ``make_request`` keyword arguments shared by all requests.

    ``call_args`` and ``coordinator`` are accepted so that this method has
    the same signature as the request-type-specific variants; they are not
    used here.

    :returns: A dict with these keys:
        ``'request'``: result of serializing ``request_type`` and ``future``
        with ``self._request_serializer``.
        ``'type'``: the ``S3RequestType`` attribute named
        ``request_type.upper()``, or ``S3RequestType.DEFAULT`` if there is
        no such attribute.
        ``'on_done'``: callback from ``get_crt_callback`` for ``'done'``,
        given the before and after call lists.
        ``'on_progress'``: callback from ``get_crt_callback`` for
        ``'progress'``.
    """
    return {
        'request': self._request_serializer.serialize_http_request(
            request_type, future
        ),
        'type': getattr(
            S3RequestType, request_type.upper(), S3RequestType.DEFAULT
        ),
        'on_done': self.get_crt_callback(
            future, 'done', on_done_before_calls, on_done_after_calls
        ),
        'on_progress': self.get_crt_callback(future, 'progress'),
    }


class RenameTempFileHandler:
def __init__(self, coordinator, final_filename, temp_filename, osutil):
Expand Down Expand Up @@ -642,3 +726,11 @@ def __init__(self, coordinator):

def __call__(self, **kwargs):
self._coordinator.set_done_callbacks_complete()


class OnBodyFileObjWriter:
    """Callable that writes each received body chunk to a file-like object."""

    def __init__(self, fileobj):
        """Store the object whose ``write`` method receives each chunk."""
        self._fileobj = fileobj

    def __call__(self, chunk, **kwargs):
        """Write ``chunk`` to the wrapped object.

        Any additional keyword arguments (for example ``offset``) are
        accepted and ignored.
        """
        self._fileobj.write(chunk)
3 changes: 0 additions & 3 deletions awscli/topics/s3-config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,6 @@ files to and from S3. Valid choices are:

* S3 to S3 copies - Falls back to using the ``default`` transfer client

* Streaming uploads from standard input and downloads to standard output -
Falls back to using ``default`` transfer client.

* Region redirects - Transfers fail for requests sent to a region that does
not match the region of the targeted S3 bucket.

Expand Down
22 changes: 19 additions & 3 deletions tests/functional/s3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ def setUp(self):
self.mock_crt_client.return_value.make_request.side_effect = \
self.simulate_make_request_side_effect
self.files = FileCreator()
self.expected_download_content = b'content'

def tearDown(self):
super(BaseCRTTransferClientTest, self).tearDown()
Expand All @@ -456,6 +457,8 @@ def get_config_file_contents(self):
def simulate_make_request_side_effect(self, *args, **kwargs):
if kwargs.get('recv_filepath'):
self.simulate_file_download(kwargs['recv_filepath'])
elif kwargs.get('on_body'):
self.simulate_on_body(kwargs['on_body'])
s3_request = FakeCRTS3Request(
future=FakeCRTFuture(kwargs.get('on_done'))
)
Expand All @@ -465,11 +468,14 @@ def simulate_file_download(self, recv_filepath):
parent_dir = os.path.dirname(recv_filepath)
if not os.path.isdir(parent_dir):
os.makedirs(parent_dir)
with open(recv_filepath, 'w') as f:
with open(recv_filepath, 'wb') as f:
# The content is arbitrary as most functional tests are just going
# to assert the file exists since it is the CRT writing the
# data to the file.
f.write('content')
f.write(self.expected_download_content)

def simulate_on_body(self, on_body_callback):
    """Invoke ``on_body_callback`` once with the expected download content.

    The callback receives ``self.expected_download_content`` as ``chunk``
    and ``0`` as ``offset``, both passed by keyword.
    """
    on_body_callback(chunk=self.expected_download_content, offset=0)

def get_crt_make_request_calls(self):
    """Return the recorded ``make_request`` calls of the mocked CRT client.

    This is the ``call_args_list`` of ``make_request`` on the return value
    of ``self.mock_crt_client``.
    """
    return self.mock_crt_client.return_value.make_request.call_args_list
Expand All @@ -489,7 +495,8 @@ def assert_crt_make_request_call(
self, make_request_call, expected_type, expected_host,
expected_path, expected_http_method=None,
expected_send_filepath=None,
expected_recv_startswith=None):
expected_recv_startswith=None,
expected_body_content=None):
make_request_kwargs = make_request_call[1]
self.assertEqual(
make_request_kwargs['type'], expected_type)
Expand Down Expand Up @@ -522,6 +529,15 @@ def assert_crt_make_request_call(
f"start with {expected_recv_startswith}"
)
)
if expected_body_content is not None:
# Note: The underlying CRT awscrt.io.InputStream does not expose
# a public read method so we have to reach into the private,
# underlying stream to determine the content. We should update
# to use a public interface if a public interface is ever exposed.
self.assertEqual(
make_request_kwargs['request'].body_stream._stream.read(),
expected_body_content
)


class FakeCRTS3Request:
Expand Down
33 changes: 23 additions & 10 deletions tests/functional/s3/test_cp_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -2116,25 +2116,38 @@ def test_does_not_use_crt_client_for_copies(self):
self.assertEqual(self.get_crt_make_request_calls(), [])
self.assert_no_remaining_botocore_responses()

def test_does_not_use_crt_client_for_streaming_upload(self):
def test_streaming_upload_using_crt_client(self):
cmdline = [
's3', 'cp', '-', 's3://bucket/key'
]
self.add_botocore_put_object_response()
with mock.patch('sys.stdin', BufferedBytesIO(b'foo')):
self.run_command(cmdline)
self.assertEqual(self.get_crt_make_request_calls(), [])
self.assert_no_remaining_botocore_responses()
crt_requests = self.get_crt_make_request_calls()
self.assertEqual(len(crt_requests), 1)
self.assert_crt_make_request_call(
crt_requests[0],
expected_type=S3RequestType.PUT_OBJECT,
expected_host=self.get_virtual_s3_host('bucket'),
expected_path='/key',
expected_body_content=b'foo',
)

def test_does_not_use_crt_client_for_streaming_download(self):
def test_streaming_download_using_crt_client(self):
cmdline = [
's3', 'cp', 's3://bucket/key', '-'
]
self.add_botocore_head_object_response()
self.add_botocore_get_object_response()
self.run_command(cmdline)
self.assertEqual(self.get_crt_make_request_calls(), [])
self.assert_no_remaining_botocore_responses()
result = self.run_command(cmdline)
crt_requests = self.get_crt_make_request_calls()
self.assertEqual(len(crt_requests), 1)
self.assert_crt_make_request_call(
crt_requests[0],
expected_type=S3RequestType.GET_OBJECT,
expected_host=self.get_virtual_s3_host('bucket'),
expected_path='/key',
)
self.assertEqual(
result.stdout, self.expected_download_content.decode('utf-8')
)

def test_respects_region_parameter(self):
filename = self.files.create_file('myfile', 'mycontent')
Expand Down
Loading