diff --git a/shared-bindings/audiomp3/MP3Decoder.c b/shared-bindings/audiomp3/MP3Decoder.c index 7a596f90dbc5..ff715720d028 100644 --- a/shared-bindings/audiomp3/MP3Decoder.c +++ b/shared-bindings/audiomp3/MP3Decoder.c @@ -72,6 +72,12 @@ //| decoder.file = stream //| //| If the stream is played with ``loop = True``, the loop will start at the beginning. +//| +//| It is possible to stream an mp3 from a socket, including a secure socket. +//| The MP3Decoder may change the timeout and non-blocking status of the socket. +//| Using a larger decode buffer with a stream can be helpful to avoid data underruns. +//| An ``adafruit_requests`` request must be made with ``headers={"Connection": "close"}`` so +//| that the socket closes when the stream ends. //| """ //| ... diff --git a/shared-bindings/ssl/SSLSocket.c b/shared-bindings/ssl/SSLSocket.c index e244f77c76a9..3ed4fa366243 100644 --- a/shared-bindings/ssl/SSLSocket.c +++ b/shared-bindings/ssl/SSLSocket.c @@ -10,10 +10,11 @@ #include #include "shared/runtime/context_manager_helpers.h" -#include "py/objtuple.h" +#include "py/mperrno.h" #include "py/objlist.h" +#include "py/objtuple.h" #include "py/runtime.h" -#include "py/mperrno.h" +#include "py/stream.h" #include "shared/netutils/netutils.h" @@ -247,9 +248,69 @@ static const mp_rom_map_elem_t ssl_sslsocket_locals_dict_table[] = { static MP_DEFINE_CONST_DICT(ssl_sslsocket_locals_dict, ssl_sslsocket_locals_dict_table); +typedef mp_uint_t (*readwrite_func)(ssl_sslsocket_obj_t *, const uint8_t *, mp_uint_t); + +static mp_int_t readwrite_common(mp_obj_t self_in, readwrite_func fn, const uint8_t *buf, size_t size, int *errorcode) { + ssl_sslsocket_obj_t *self = MP_OBJ_TO_PTR(self_in); + mp_int_t ret = -EIO; + nlr_buf_t nlr; + if (nlr_push(&nlr) == 0) { + ret = fn(self, buf, size); + nlr_pop(); + } else { + mp_obj_t exc = MP_OBJ_FROM_PTR(nlr.ret_val); + if (nlr_push(&nlr) == 0) { + ret = -mp_obj_get_int(mp_load_attr(exc, MP_QSTR_errno)); + nlr_pop(); + } + } + if (ret < 0) { + *errorcode = -ret; + return MP_STREAM_ERROR; + } + return ret; +} + +static mp_uint_t sslsocket_read(mp_obj_t self_in, void *buf, mp_uint_t size, int *errorcode) { + return readwrite_common(self_in, (readwrite_func)common_hal_ssl_sslsocket_recv_into, buf, size, errorcode); +} + +static mp_uint_t sslsocket_write(mp_obj_t self_in, const void *buf, mp_uint_t size, int *errorcode) { + return readwrite_common(self_in, common_hal_ssl_sslsocket_send, buf, size, errorcode); +} + +static mp_uint_t sslsocket_ioctl(mp_obj_t self_in, mp_uint_t request, mp_uint_t arg, int *errcode) { + ssl_sslsocket_obj_t *self = MP_OBJ_TO_PTR(self_in); + mp_uint_t ret; + if (request == MP_STREAM_POLL) { + mp_uint_t flags = arg; + ret = 0; + if ((flags & MP_STREAM_POLL_RD) && common_hal_ssl_sslsocket_readable(self) > 0) { + ret |= MP_STREAM_POLL_RD; + } + if ((flags & MP_STREAM_POLL_WR) && common_hal_ssl_sslsocket_writable(self)) { + ret |= MP_STREAM_POLL_WR; + } + } else { + *errcode = MP_EINVAL; + ret = MP_STREAM_ERROR; + } + return ret; +} + + +static const mp_stream_p_t sslsocket_stream_p = { + .read = sslsocket_read, + .write = sslsocket_write, + .ioctl = sslsocket_ioctl, + .is_text = false, +}; + + MP_DEFINE_CONST_OBJ_TYPE( ssl_sslsocket_type, MP_QSTR_SSLSocket, MP_TYPE_FLAG_NONE, - locals_dict, &ssl_sslsocket_locals_dict + locals_dict, &ssl_sslsocket_locals_dict, + protocol, &sslsocket_stream_p ); diff --git a/shared-bindings/ssl/SSLSocket.h b/shared-bindings/ssl/SSLSocket.h index 4079048a6cb1..85467fee089a 100644 --- a/shared-bindings/ssl/SSLSocket.h +++ b/shared-bindings/ssl/SSLSocket.h @@ -20,8 +20,10 @@ void common_hal_ssl_sslsocket_close(ssl_sslsocket_obj_t *self); void common_hal_ssl_sslsocket_connect(ssl_sslsocket_obj_t *self, mp_obj_t addr); bool common_hal_ssl_sslsocket_get_closed(ssl_sslsocket_obj_t *self); bool common_hal_ssl_sslsocket_get_connected(ssl_sslsocket_obj_t *self); +bool common_hal_ssl_sslsocket_readable(ssl_sslsocket_obj_t *self); +bool common_hal_ssl_sslsocket_writable(ssl_sslsocket_obj_t *self); void common_hal_ssl_sslsocket_listen(ssl_sslsocket_obj_t *self, int backlog); -mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t *buf, uint32_t len); -mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t *buf, uint32_t len); +mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t *buf, mp_uint_t len); +mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t *buf, mp_uint_t len); void common_hal_ssl_sslsocket_settimeout(ssl_sslsocket_obj_t *self, mp_obj_t timeout_obj); void common_hal_ssl_sslsocket_setsockopt(ssl_sslsocket_obj_t *self, mp_obj_t level, mp_obj_t optname, mp_obj_t optval); diff --git a/shared-module/audiomp3/MP3Decoder.c b/shared-module/audiomp3/MP3Decoder.c index 19867271f605..f59f4a4510b6 100644 --- a/shared-module/audiomp3/MP3Decoder.c +++ b/shared-module/audiomp3/MP3Decoder.c @@ -95,6 +95,18 @@ static off_t stream_lseek(void *stream, off_t offset, int whence) { #define INPUT_BUFFER_CONSUME(i, n) ((i).read_off += (n)) #define INPUT_BUFFER_CLEAR(i) ((i).read_off = (i).write_off = 0) +static void stream_set_blocking(audiomp3_mp3file_obj_t *self, bool block_ok) { + if (!self->settimeout_args[0]) { + return; + } + if (block_ok == self->block_ok) { + return; + } + self->block_ok = block_ok; + self->settimeout_args[2] = block_ok ? mp_const_none : mp_obj_new_int(0); + mp_call_method_n_kw(1, 0, self->settimeout_args); +} + /** Fill the input buffer unconditionally. * * Returns true if the input buffer contains any useful data, @@ -110,6 +122,8 @@ static bool mp3file_update_inbuf_always(audiomp3_mp3file_obj_t *self, bool block return INPUT_BUFFER_AVAILABLE(self->inbuf) > 0; } + stream_set_blocking(self, block_ok); + // We didn't previously reach EOF and we have input buffer space available // Move the unconsumed portion of the buffer to the start @@ -119,7 +133,7 @@ static bool mp3file_update_inbuf_always(audiomp3_mp3file_obj_t *self, bool block self->inbuf.read_off = 0; } - for (size_t to_read; !self->eof && (to_read = INPUT_BUFFER_SPACE(self->inbuf)) > 0 && (block_ok || stream_readable(self->stream));) { + for (size_t to_read; !self->eof && (to_read = INPUT_BUFFER_SPACE(self->inbuf)) > 0;) { uint8_t *write_ptr = self->inbuf.buf + self->inbuf.write_off; ssize_t n_read = stream_read(self->stream, write_ptr, to_read); @@ -328,9 +342,14 @@ void common_hal_audiomp3_mp3file_set_file(audiomp3_mp3file_obj_t *self, mp_obj_t background_callback_prevent(); self->stream = stream; + mp_load_method_maybe(stream, MP_QSTR_settimeout, self->settimeout_args); INPUT_BUFFER_CLEAR(self->inbuf); self->eof = 0; + + self->block_ok = false; + stream_set_blocking(self, true); + self->other_channel = -1; mp3file_update_inbuf_half(self, true); mp3file_find_sync_word(self, true); @@ -365,6 +384,7 @@ void common_hal_audiomp3_mp3file_deinit(audiomp3_mp3file_obj_t *self) { self->pcm_buffer[0] = NULL; self->pcm_buffer[1] = NULL; self->stream = mp_const_none; + self->settimeout_args[0] = MP_OBJ_NULL; self->samples_decoded = 0; } diff --git a/shared-module/audiomp3/MP3Decoder.h b/shared-module/audiomp3/MP3Decoder.h index 89328deb4a97..9f1b97a5a516 100644 --- a/shared-module/audiomp3/MP3Decoder.h +++ b/shared-module/audiomp3/MP3Decoder.h @@ -35,6 +35,8 @@ typedef struct { uint8_t buffer_index; uint8_t channel_count; bool eof; + bool block_ok; + mp_obj_t settimeout_args[3]; int8_t other_channel; int8_t other_buffer_index; diff --git a/shared-module/ssl/SSLSocket.c b/shared-module/ssl/SSLSocket.c index 9303964af719..0b40bd663c7d 100644 --- a/shared-module/ssl/SSLSocket.c +++ b/shared-module/ssl/SSLSocket.c @@ -22,6 +22,8 @@ #include "mbedtls/version.h" +#define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR) + #if defined(MBEDTLS_ERROR_C) #include "../../lib/mbedtls_errors/mp_mbedtls_errors.c" #endif @@ -220,6 +222,7 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t o->base.type = &ssl_sslsocket_type; o->ssl_context = self; o->sock_obj = socket; + o->poll_mask = 0; mp_load_method(socket, MP_QSTR_accept, o->accept_args); mp_load_method(socket, MP_QSTR_bind, o->bind_args); @@ -330,7 +333,8 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t } } -mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t *buf, uint32_t len) { +mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t *buf, mp_uint_t len) { + self->poll_mask = 0; int ret = mbedtls_ssl_read(&self->ssl, buf, len); DEBUG_PRINT("recv_into mbedtls_ssl_read() -> %d\n", ret); if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { @@ -342,17 +346,24 @@ mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t DEBUG_PRINT("returning %d\n", ret); return ret; } + if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + self->poll_mask = MP_STREAM_POLL_WR; + } DEBUG_PRINT("raising errno [error case] %d\n", ret); mbedtls_raise_error(ret); } -mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t *buf, uint32_t len) { +mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t *buf, mp_uint_t len) { + self->poll_mask = 0; int ret = mbedtls_ssl_write(&self->ssl, buf, len); DEBUG_PRINT("send mbedtls_ssl_write() -> %d\n", ret); if (ret >= 0) { DEBUG_PRINT("returning %d\n", ret); return ret; } + if (ret == MBEDTLS_ERR_SSL_WANT_READ) { + self->poll_mask = MP_STREAM_POLL_RD; + } DEBUG_PRINT("raising errno [error case] %d\n", ret); mbedtls_raise_error(ret); } @@ -448,3 +459,37 @@ void common_hal_ssl_sslsocket_setsockopt(ssl_sslsocket_obj_t *self, mp_obj_t lev void common_hal_ssl_sslsocket_settimeout(ssl_sslsocket_obj_t *self, mp_obj_t timeout_obj) { ssl_socket_settimeout(self, timeout_obj); } + +static bool poll_common(ssl_sslsocket_obj_t *self, uintptr_t arg) { + // Take into account that the library might have buffered data already + int has_pending = 0; + if (arg & MP_STREAM_POLL_RD) { + has_pending = mbedtls_ssl_check_pending(&self->ssl); + if (has_pending) { + // Shortcut if we only need to read and we have buffered data, no need to go to the underlying socket + return true; + } + } + + // If the library signaled us that it needs reading or writing, only + // check that direction + if (self->poll_mask && (arg & MP_STREAM_POLL_RDWR)) { + arg = (arg & ~MP_STREAM_POLL_RDWR) | self->poll_mask; + } + + // If direction the library needed is available, return a fake + // result to the caller so that it reenters a read or a write to + // allow the handshake to progress + const mp_stream_p_t *stream_p = mp_get_stream_raise(self->sock_obj, MP_STREAM_OP_IOCTL); + int errcode; + mp_int_t ret = stream_p->ioctl(self->sock_obj, MP_STREAM_POLL, arg, &errcode); + return ret != 0; +} + +bool common_hal_ssl_sslsocket_readable(ssl_sslsocket_obj_t *self) { + return poll_common(self, MP_STREAM_POLL_RD); +} + +bool common_hal_ssl_sslsocket_writable(ssl_sslsocket_obj_t *self) { + return poll_common(self, MP_STREAM_POLL_WR); +} diff --git a/shared-module/ssl/SSLSocket.h b/shared-module/ssl/SSLSocket.h index b9baad3ce174..f7f3d1ae83ce 100644 --- a/shared-module/ssl/SSLSocket.h +++ b/shared-module/ssl/SSLSocket.h @@ -29,6 +29,7 @@ typedef struct ssl_sslsocket_obj { mbedtls_x509_crt cacert; mbedtls_x509_crt cert; mbedtls_pk_context pkey; + uintptr_t poll_mask; bool closed; mp_obj_t accept_args[2]; mp_obj_t bind_args[3];