From d8ecb121915bc892c6e8eeef52d95a1f74db9bdf Mon Sep 17 00:00:00 2001 From: Giovanni Barillari Date: Fri, 15 Mar 2024 13:34:53 +0100 Subject: [PATCH] Move to `pyclass(frozen)` (#232) --- granian/_granian.pyi | 18 ---- granian/asgi.py | 7 +- granian/rsgi.py | 22 +++- src/asgi/callbacks.rs | 83 +++++++------- src/asgi/http.rs | 16 +-- src/asgi/io.rs | 154 +++++++++++++------------- src/asgi/mod.rs | 3 +- src/asgi/serve.rs | 2 +- src/asgi/types.rs | 244 +++++++++++++++++++++--------------------- src/callbacks.rs | 221 ++++++++++++++++++++++++++------------ src/conversion.rs | 2 + src/rsgi/callbacks.rs | 80 +++++++------- src/rsgi/http.rs | 16 ++- src/rsgi/io.rs | 158 +++++++++++++-------------- src/rsgi/mod.rs | 3 +- src/rsgi/serve.rs | 2 +- src/rsgi/types.rs | 148 +++++++++++++------------ src/runtime.rs | 94 ++++++++++------ src/tcp.rs | 2 +- src/workers.rs | 18 ++-- src/wsgi/serve.rs | 2 +- src/wsgi/types.rs | 177 +++++++++++++++--------------- 22 files changed, 790 insertions(+), 682 deletions(-) diff --git a/granian/_granian.pyi b/granian/_granian.pyi index ae1fce4a..f60d9061 100644 --- a/granian/_granian.pyi +++ b/granian/_granian.pyi @@ -4,9 +4,6 @@ from ._types import WebsocketMessage __version__: str -class ASGIScope: - def as_dict(self, root_path: str) -> Dict[str, Any]: ... - class RSGIHeaders: def __contains__(self, key: str) -> bool: ... def keys(self) -> List[str]: ... @@ -14,21 +11,6 @@ class RSGIHeaders: def items(self) -> List[Tuple[str]]: ... def get(self, key: str, default: Any = None) -> Any: ... -class RSGIScope: - proto: str - http_version: str - rsgi_version: str - server: str - client: str - scheme: str - method: str - path: str - query_string: str - authority: Optional[str] - - @property - def headers(self) -> RSGIHeaders: ... - class RSGIHTTPStreamTransport: async def send_bytes(self, data: bytes): ... async def send_str(self, data: str): ... diff --git a/granian/asgi.py b/granian/asgi.py index b70da848..2d6ec01b 100644 --- a/granian/asgi.py +++ b/granian/asgi.py @@ -1,10 +1,15 @@ import asyncio from functools import wraps +from typing import Any, Dict -from ._granian import ASGIScope as Scope from .log import logger +class Scope: + def as_dict(self, root_path: str) -> Dict[str, Any]: + ... + + class LifespanProtocol: error_transition = 'Invalid lifespan state transition' diff --git a/granian/rsgi.py b/granian/rsgi.py index 4f9ecbba..e2db1961 100644 --- a/granian/rsgi.py +++ b/granian/rsgi.py @@ -1,16 +1,32 @@ from enum import Enum -from typing import Union +from typing import Optional, Union from ._granian import ( - RSGIHeaders as Headers, # noqa + RSGIHeaders as Headers, RSGIHTTPProtocol as HTTPProtocol, # noqa RSGIProtocolClosed as ProtocolClosed, # noqa RSGIProtocolError as ProtocolError, # noqa - RSGIScope as Scope, # noqa RSGIWebsocketProtocol as WebsocketProtocol, # noqa ) +class Scope: + proto: str + http_version: str + rsgi_version: str + server: str + client: str + scheme: str + method: str + path: str + query_string: str + authority: Optional[str] + + @property + def headers(self) -> Headers: + ... + + class WebsocketMessageType(int, Enum): close = 0 bytes = 1 diff --git a/src/asgi/callbacks.rs b/src/asgi/callbacks.rs index f79c109e..d5a56bd7 100644 --- a/src/asgi/callbacks.rs +++ b/src/asgi/callbacks.rs @@ -4,7 +4,8 @@ use tokio::sync::oneshot; use super::{ io::{ASGIHTTPProtocol as HTTPProtocol, ASGIWebsocketProtocol as WebsocketProtocol, WebsocketDetachedTransport}, - types::ASGIScope as Scope, + types::ASGIHTTPScope as HTTPScope, + types::ASGIWebsocketScope as WebsocketScope, }; use crate::{ callbacks::{ @@ -17,7 +18,7 @@ use crate::{ ws::{HyperWebsocket, UpgradeData}, }; -#[pyclass] +#[pyclass(frozen)] pub(crate) struct CallbackRunnerHTTP { proto: Py, context: TaskLocals, @@ -25,7 +26,7 @@ pub(crate) struct CallbackRunnerHTTP { } impl CallbackRunnerHTTP { - pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Scope) -> Self { + pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: HTTPScope) -> Self { let pyproto = Py::new(py, proto).unwrap(); Self { proto: pyproto.clone_ref(py), @@ -51,23 +52,21 @@ impl CallbackRunnerHTTP { } macro_rules! callback_impl_done_http { - ($self:expr, $py:expr) => { - if let Ok(mut proto) = $self.proto.as_ref($py).try_borrow_mut() { - if let Some(tx) = proto.tx() { - let _ = tx.send(response_500()); - } + ($self:expr) => { + if let Some(tx) = $self.proto.get().tx() { + let _ = tx.send(response_500()); } }; } macro_rules! callback_impl_done_err { - ($self:expr, $py:expr, $err:expr) => { - $self.done($py); + ($self:expr, $err:expr) => { + $self.done(); log_application_callable_exception($err); }; } -#[pyclass] +#[pyclass(frozen)] pub(crate) struct CallbackTaskHTTP { proto: Py, context: TaskLocals, @@ -86,12 +85,12 @@ impl CallbackTaskHTTP { }) } - fn done(&self, py: Python) { - callback_impl_done_http!(self, py); + fn done(&self) { + callback_impl_done_http!(self); } - fn err(&self, py: Python, err: &PyErr) { - callback_impl_done_err!(self, py, err); + fn err(&self, err: &PyErr) { + callback_impl_done_err!(self, err); } callback_impl_loop_run!(); @@ -109,7 +108,7 @@ impl CallbackTaskHTTP { } } -#[pyclass] +#[pyclass(frozen)] pub(crate) struct CallbackWrappedRunnerHTTP { #[pyo3(get)] proto: Py, @@ -120,7 +119,7 @@ pub(crate) struct CallbackWrappedRunnerHTTP { } impl CallbackWrappedRunnerHTTP { - pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Scope) -> Self { + pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: HTTPScope) -> Self { Self { proto: Py::new(py, proto).unwrap(), context: cb.context, @@ -138,16 +137,16 @@ impl CallbackWrappedRunnerHTTP { callback_impl_loop_pytask!(pyself, py) } - fn done(&self, py: Python) { - callback_impl_done_http!(self, py); + fn done(&self) { + callback_impl_done_http!(self); } - fn err(&self, py: Python, err: &PyAny) { - callback_impl_done_err!(self, py, &PyErr::from_value(err)); + fn err(&self, err: &PyAny) { + callback_impl_done_err!(self, &PyErr::from_value(err)); } } -#[pyclass] +#[pyclass(frozen)] pub(crate) struct CallbackRunnerWebsocket { proto: Py, context: TaskLocals, @@ -155,7 +154,7 @@ pub(crate) struct CallbackRunnerWebsocket { } impl CallbackRunnerWebsocket { - pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Scope) -> Self { + pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: WebsocketScope) -> Self { let pyproto = Py::new(py, proto).unwrap(); Self { proto: pyproto.clone(), @@ -175,16 +174,14 @@ impl CallbackRunnerWebsocket { } macro_rules! callback_impl_done_ws { - ($self:expr, $py:expr) => { - if let Ok(mut proto) = $self.proto.as_ref($py).try_borrow_mut() { - if let (Some(tx), res) = proto.tx() { - let _ = tx.send(res); - } + ($self:expr) => { + if let (Some(tx), res) = $self.proto.get().tx() { + let _ = tx.send(res); } }; } -#[pyclass] +#[pyclass(frozen)] pub(crate) struct CallbackTaskWebsocket { proto: Py, context: TaskLocals, @@ -203,12 +200,12 @@ impl CallbackTaskWebsocket { }) } - fn done(&self, py: Python) { - callback_impl_done_ws!(self, py); + fn done(&self) { + callback_impl_done_ws!(self); } - fn err(&self, py: Python, err: &PyErr) { - callback_impl_done_err!(self, py, err); + fn err(&self, err: &PyErr) { + callback_impl_done_err!(self, err); } callback_impl_loop_run!(); @@ -226,7 +223,7 @@ impl CallbackTaskWebsocket { } } -#[pyclass] +#[pyclass(frozen)] pub(crate) struct CallbackWrappedRunnerWebsocket { #[pyo3(get)] proto: Py, @@ -237,7 +234,7 @@ pub(crate) struct CallbackWrappedRunnerWebsocket { } impl CallbackWrappedRunnerWebsocket { - pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Scope) -> Self { + pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: WebsocketScope) -> Self { Self { proto: Py::new(py, proto).unwrap(), context: cb.context, @@ -255,12 +252,12 @@ impl CallbackWrappedRunnerWebsocket { callback_impl_loop_pytask!(pyself, py) } - fn done(&self, py: Python) { - callback_impl_done_ws!(self, py); + fn done(&self) { + callback_impl_done_ws!(self); } - fn err(&self, py: Python, err: &PyAny) { - callback_impl_done_err!(self, py, &PyErr::from_value(err)); + fn err(&self, err: &PyAny) { + callback_impl_done_err!(self, &PyErr::from_value(err)); } } @@ -293,7 +290,7 @@ macro_rules! call_impl_rtb_http { cb: CallbackWrapper, rt: RuntimeRef, req: HTTPRequest, - scope: Scope, + scope: HTTPScope, ) -> oneshot::Receiver { let (tx, rx) = oneshot::channel(); let protocol = HTTPProtocol::new(rt, req, tx); @@ -313,7 +310,7 @@ macro_rules! call_impl_rtt_http { cb: CallbackWrapper, rt: RuntimeRef, req: HTTPRequest, - scope: Scope, + scope: HTTPScope, ) -> oneshot::Receiver { let (tx, rx) = oneshot::channel(); let protocol = HTTPProtocol::new(rt, req, tx); @@ -336,7 +333,7 @@ macro_rules! call_impl_rtb_ws { rt: RuntimeRef, ws: HyperWebsocket, upgrade: UpgradeData, - scope: Scope, + scope: WebsocketScope, ) -> oneshot::Receiver { let (tx, rx) = oneshot::channel(); let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); @@ -357,7 +354,7 @@ macro_rules! call_impl_rtt_ws { rt: RuntimeRef, ws: HyperWebsocket, upgrade: UpgradeData, - scope: Scope, + scope: WebsocketScope, ) -> oneshot::Receiver { let (tx, rx) = oneshot::channel(); let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); diff --git a/src/asgi/http.rs b/src/asgi/http.rs index c41f575a..9363544d 100644 --- a/src/asgi/http.rs +++ b/src/asgi/http.rs @@ -8,7 +8,8 @@ use super::{ call_rtb_http, call_rtb_http_pyw, call_rtb_ws, call_rtb_ws_pyw, call_rtt_http, call_rtt_http_pyw, call_rtt_ws, call_rtt_ws_pyw, }, - types::ASGIScope as Scope, + types::ASGIHTTPScope as HTTPScope, + types::ASGIWebsocketScope as WebsocketScope, }; use crate::{ callbacks::CallbackWrapper, @@ -17,9 +18,9 @@ use crate::{ ws::{is_upgrade_request as is_ws_upgrade, upgrade_intent as ws_upgrade, UpgradeData}, }; -macro_rules! default_scope { - ($server_addr:expr, $client_addr:expr, $req:expr, $scheme:expr) => { - Scope::new( +macro_rules! build_scope { + ($cls:ty, $server_addr:expr, $client_addr:expr, $req:expr, $scheme:expr) => { + <$cls>::new( $req.version(), $scheme, $req.uri().clone(), @@ -53,7 +54,7 @@ macro_rules! handle_request { req: HTTPRequest, scheme: &str, ) -> HTTPResponse { - let scope = default_scope!(server_addr, client_addr, &req, scheme); + let scope = build_scope!(HTTPScope, server_addr, client_addr, &req, scheme); handle_http_response!($handler, rt, callback, req, scope) } }; @@ -69,10 +70,8 @@ macro_rules! handle_request_with_ws { req: HTTPRequest, scheme: &str, ) -> HTTPResponse { - let mut scope = default_scope!(server_addr, client_addr, &req, scheme); - if is_ws_upgrade(&req) { - scope.set_websocket(); + let scope = build_scope!(WebsocketScope, server_addr, client_addr, &req, scheme); return match ws_upgrade(req, None) { Ok((res, ws)) => { @@ -130,6 +129,7 @@ macro_rules! handle_request_with_ws { }; } + let scope = build_scope!(HTTPScope, server_addr, client_addr, &req, scheme); handle_http_response!($handler_req, rt, callback, req, scope) } }; diff --git a/src/asgi/io.rs b/src/asgi/io.rs index 7c71aef4..45819246 100644 --- a/src/asgi/io.rs +++ b/src/asgi/io.rs @@ -8,10 +8,13 @@ use hyper::{ }; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict}; -use std::{borrow::Cow, sync::Arc}; +use std::{ + borrow::Cow, + sync::{atomic, Arc, Mutex, RwLock}, +}; use tokio::{ fs::File, - sync::{mpsc, oneshot, Mutex, RwLock}, + sync::{mpsc, oneshot, Mutex as AsyncMutex}, }; use tokio_tungstenite::tungstenite::Message; use tokio_util::io::ReaderStream; @@ -30,17 +33,16 @@ use crate::{ const EMPTY_BYTES: Cow<[u8]> = Cow::Borrowed(b""); const EMPTY_STRING: String = String::new(); -#[pyclass(module = "granian._granian")] +#[pyclass(frozen, module = "granian._granian")] pub(crate) struct ASGIHTTPProtocol { rt: RuntimeRef, - tx: Option>, - request_body: Arc>>, - response_started: bool, - response_chunked: bool, - response_status: Option, - response_headers: Option, - body_tx: Option>>, - flow_rx_exhausted: Arc>, + tx: Mutex>>, + request_body: Arc>>, + response_started: atomic::AtomicBool, + response_chunked: atomic::AtomicBool, + response_intent: Mutex>, + body_tx: Mutex>>>, + flow_rx_exhausted: Arc>, flow_tx_waiter: Arc, } @@ -48,21 +50,20 @@ impl ASGIHTTPProtocol { pub fn new(rt: RuntimeRef, request: HTTPRequest, tx: oneshot::Sender) -> Self { Self { rt, - tx: Some(tx), - request_body: Arc::new(Mutex::new(http_body_util::BodyStream::new(request.into_body()))), - response_started: false, - response_chunked: false, - response_status: None, - response_headers: None, - body_tx: None, - flow_rx_exhausted: Arc::new(std::sync::RwLock::new(false)), + tx: Mutex::new(Some(tx)), + request_body: Arc::new(AsyncMutex::new(http_body_util::BodyStream::new(request.into_body()))), + response_started: false.into(), + response_chunked: false.into(), + response_intent: Mutex::new(None), + body_tx: Mutex::new(None), + flow_rx_exhausted: Arc::new(RwLock::new(false)), flow_tx_waiter: Arc::new(tokio::sync::Notify::new()), } } #[inline(always)] - fn send_response(&mut self, status: i16, headers: HeaderMap, body: HTTPResponseBody) { - if let Some(tx) = self.tx.take() { + fn send_response(&self, status: i16, headers: HeaderMap, body: HTTPResponseBody) { + if let Some(tx) = self.tx.lock().unwrap().take() { let mut res = Response::new(body); *res.status_mut() = hyper::StatusCode::from_u16(status as u16).unwrap(); *res.headers_mut() = headers; @@ -88,14 +89,14 @@ impl ASGIHTTPProtocol { }) } - pub fn tx(&mut self) -> Option> { - self.tx.take() + pub fn tx(&self) -> Option> { + self.tx.lock().unwrap().take() } } #[pymethods] impl ASGIHTTPProtocol { - fn receive<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> { + fn receive<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { if *self.flow_rx_exhausted.read().unwrap() { let holder = self.flow_tx_waiter.clone(); return future_into_py_futlike(self.rt.clone(), py, async move { @@ -138,24 +139,28 @@ impl ASGIHTTPProtocol { }) } - fn send<'p>(&mut self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> { + fn send<'p>(&self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> { match adapt_message_type(data) { - Ok(ASGIMessageType::HTTPStart) => match self.response_started { + Ok(ASGIMessageType::HTTPStart) => match self.response_started.load(atomic::Ordering::Relaxed) { false => { - self.response_status = Some(adapt_status_code(py, data)?); - self.response_headers = Some(adapt_headers(py, data)); - self.response_started = true; + let mut response_intent = self.response_intent.lock().unwrap(); + *response_intent = Some((adapt_status_code(py, data)?, adapt_headers(py, data))); + self.response_started.store(true, atomic::Ordering::Relaxed); empty_future_into_py(py) } true => error_flow!(), }, Ok(ASGIMessageType::HTTPBody) => { let (body, more) = adapt_body(py, data); - match (self.response_started, more, self.response_chunked) { + match ( + self.response_started.load(atomic::Ordering::Relaxed), + more, + self.response_chunked.load(atomic::Ordering::Relaxed), + ) { (true, false, false) => { - let headers = self.response_headers.take().unwrap(); + let (status, headers) = self.response_intent.lock().unwrap().take().unwrap(); self.send_response( - self.response_status.unwrap(), + status, headers, http_body_util::Full::new(body::Bytes::from(body)) .map_err(|e| match e {}) @@ -165,24 +170,24 @@ impl ASGIHTTPProtocol { empty_future_into_py(py) } (true, true, false) => { - self.response_chunked = true; - let headers = self.response_headers.take().unwrap(); + self.response_chunked.store(true, atomic::Ordering::Relaxed); + let (status, headers) = self.response_intent.lock().unwrap().take().unwrap(); let (body_tx, body_rx) = mpsc::channel::>(1); let body_stream = http_body_util::StreamBody::new( tokio_stream::wrappers::ReceiverStream::new(body_rx).map_ok(hyper::body::Frame::data), ); - self.body_tx = Some(body_tx.clone()); - self.send_response(self.response_status.unwrap(), headers, BodyExt::boxed(body_stream)); + *self.body_tx.lock().unwrap() = Some(body_tx.clone()); + self.send_response(status, headers, BodyExt::boxed(body_stream)); self.send_body(py, body_tx, body) } - (true, true, true) => match self.body_tx.as_mut() { + (true, true, true) => match &*self.body_tx.lock().unwrap() { Some(tx) => { let tx = tx.clone(); self.send_body(py, tx, body) } _ => error_flow!(), }, - (true, false, true) => match self.body_tx.take() { + (true, false, true) => match self.body_tx.lock().unwrap().take() { Some(tx) => { self.flow_tx_waiter.notify_one(); match body.is_empty() { @@ -195,10 +200,13 @@ impl ASGIHTTPProtocol { _ => error_flow!(), } } - Ok(ASGIMessageType::HTTPFile) => match (self.response_started, adapt_file(py, data), self.tx.take()) { + Ok(ASGIMessageType::HTTPFile) => match ( + self.response_started.load(atomic::Ordering::Relaxed), + adapt_file(py, data), + self.tx.lock().unwrap().take(), + ) { (true, Ok(file_path), Some(tx)) => { - let status = self.response_status.unwrap(); - let headers = self.response_headers.take().unwrap(); + let (status, headers) = self.response_intent.lock().unwrap().take().unwrap(); future_into_py_iter(self.rt.clone(), py, async move { let res = match File::open(&file_path).await { Ok(file) => { @@ -248,16 +256,16 @@ impl WebsocketDetachedTransport { } } -#[pyclass(module = "granian._granian")] +#[pyclass(frozen, module = "granian._granian")] pub(crate) struct ASGIWebsocketProtocol { rt: RuntimeRef, - tx: Option>, - websocket: Option, - upgrade: Option, - ws_rx: Arc>>, - ws_tx: Arc>>, - accepted: Arc>, - closed: Arc>, + tx: Mutex>>, + websocket: Mutex>, + upgrade: Mutex>, + ws_rx: Arc>>, + ws_tx: Arc>>, + accepted: Arc, + closed: Arc, } impl ASGIWebsocketProtocol { @@ -269,20 +277,20 @@ impl ASGIWebsocketProtocol { ) -> Self { Self { rt, - tx: Some(tx), - websocket: Some(websocket), - upgrade: Some(upgrade), - ws_rx: Arc::new(Mutex::new(None)), - ws_tx: Arc::new(Mutex::new(None)), - accepted: Arc::new(RwLock::new(false)), - closed: Arc::new(RwLock::new(false)), + tx: Mutex::new(Some(tx)), + websocket: Mutex::new(Some(websocket)), + upgrade: Mutex::new(Some(upgrade)), + ws_rx: Arc::new(AsyncMutex::new(None)), + ws_tx: Arc::new(AsyncMutex::new(None)), + accepted: Arc::new(false.into()), + closed: Arc::new(false.into()), } } #[inline(always)] - fn accept<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> { - let upgrade = self.upgrade.take(); - let websocket = self.websocket.take(); + fn accept<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { + let upgrade = self.upgrade.lock().unwrap().take(); + let websocket = self.websocket.lock().unwrap().take(); let accepted = self.accepted.clone(); let rx = self.ws_rx.clone(); let tx = self.ws_tx.clone(); @@ -294,11 +302,10 @@ impl ASGIWebsocketProtocol { if let Ok(stream) = websocket.await { let mut wtx = tx.lock().await; let mut wrx = rx.lock().await; - let mut accepted = accepted.write().await; let (tx, rx) = stream.split(); *wtx = Some(tx); *wrx = Some(rx); - *accepted = true; + accepted.store(true, atomic::Ordering::Relaxed); return Ok(()); } } @@ -321,8 +328,7 @@ impl ASGIWebsocketProtocol { match ws.send(message).await { Ok(()) => return Ok(()), _ => { - let closed = closed.read().await; - if *closed { + if closed.load(atomic::Ordering::Relaxed) { log::info!("Attempted to write to a closed websocket"); return Ok(()); } @@ -337,7 +343,7 @@ impl ASGIWebsocketProtocol { } #[inline(always)] - fn close<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> { + fn close<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { let closed = self.closed.clone(); let ws_rx = self.ws_rx.clone(); let ws_tx = self.ws_tx.clone(); @@ -345,8 +351,7 @@ impl ASGIWebsocketProtocol { future_into_py_iter(self.rt.clone(), py, async move { match ws_tx.lock().await.take() { Some(tx) => { - let mut closed = closed.write().await; - *closed = true; + closed.store(true, atomic::Ordering::Relaxed); WebsocketDetachedTransport::new(true, ws_rx.lock().await.take(), Some(tx)) .close() .await; @@ -358,11 +363,11 @@ impl ASGIWebsocketProtocol { } fn consumed(&self) -> bool { - self.upgrade.is_none() + self.upgrade.lock().unwrap().is_none() } pub fn tx( - &mut self, + &self, ) -> ( Option>, WebsocketDetachedTransport, @@ -370,7 +375,7 @@ impl ASGIWebsocketProtocol { let mut ws_rx = self.ws_rx.blocking_lock(); let mut ws_tx = self.ws_tx.blocking_lock(); ( - self.tx.take(), + self.tx.lock().unwrap().take(), WebsocketDetachedTransport::new(self.consumed(), ws_rx.take(), ws_tx.take()), ) } @@ -378,14 +383,14 @@ impl ASGIWebsocketProtocol { #[pymethods] impl ASGIWebsocketProtocol { - fn receive<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> { + fn receive<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { let accepted = self.accepted.clone(); let closed = self.closed.clone(); let transport = self.ws_rx.clone(); future_into_py_futlike(self.rt.clone(), py, async move { - let accepted = accepted.read().await; - if !*accepted { + let accepted = accepted.load(atomic::Ordering::Relaxed); + if !accepted { return Python::with_gil(|py| { let dict = PyDict::new(py); dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.connect"))?; @@ -398,8 +403,7 @@ impl ASGIWebsocketProtocol { match recv { Ok(Message::Ping(_)) => continue, Ok(message @ Message::Close(_)) => { - let mut closed = closed.write().await; - *closed = true; + closed.store(true, atomic::Ordering::Relaxed); return ws_message_into_py(message); } Ok(message) => return ws_message_into_py(message), @@ -411,7 +415,7 @@ impl ASGIWebsocketProtocol { }) } - fn send<'p>(&mut self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> { + fn send<'p>(&self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> { match adapt_message_type(data) { Ok(ASGIMessageType::WSAccept) => self.accept(py), Ok(ASGIMessageType::WSClose) => self.close(py), diff --git a/src/asgi/mod.rs b/src/asgi/mod.rs index 739d265a..3037349a 100644 --- a/src/asgi/mod.rs +++ b/src/asgi/mod.rs @@ -10,7 +10,8 @@ mod types; pub(crate) fn init_pymodule(module: &PyModule) -> PyResult<()> { module.add_class::()?; module.add_class::()?; - module.add_class::()?; + module.add_class::()?; + module.add_class::()?; Ok(()) } diff --git a/src/asgi/serve.rs b/src/asgi/serve.rs index 3d0768f7..76b7e9a0 100644 --- a/src/asgi/serve.rs +++ b/src/asgi/serve.rs @@ -8,7 +8,7 @@ use super::http::{ use crate::conversion::{worker_http1_config_from_py, worker_http2_config_from_py}; use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig, WorkerSignal}; -#[pyclass(module = "granian._granian")] +#[pyclass(frozen, module = "granian._granian")] pub struct ASGIWorker { config: WorkerConfig, } diff --git a/src/asgi/types.rs b/src/asgi/types.rs index a5da81cd..36543727 100644 --- a/src/asgi/types.rs +++ b/src/asgi/types.rs @@ -9,7 +9,10 @@ use pyo3::{ sync::GILOnceCell, types::{PyBytes, PyDict, PyList}, }; -use std::net::{IpAddr, SocketAddr}; +use std::{ + net::{IpAddr, SocketAddr}, + sync::Arc, +}; const SCHEME_HTTPS: &str = "https"; const SCHEME_WS: &str = "ws"; @@ -27,152 +30,153 @@ pub(crate) enum ASGIMessageType { WSMessage, } -#[pyclass(module = "granian._granian")] -pub(crate) struct ASGIScope { - http_version: Version, - scheme: String, - method: String, - uri: Uri, - server_ip: IpAddr, - server_port: u16, - client_ip: IpAddr, - client_port: u16, - headers: HeaderMap, - is_websocket: bool, -} - -impl ASGIScope { - pub fn new( - http_version: Version, - scheme: &str, - uri: Uri, - method: &str, - server: SocketAddr, - client: SocketAddr, - headers: &HeaderMap, - ) -> Self { - Self { - http_version, - scheme: scheme.to_string(), - method: method.to_string(), - uri, - server_ip: server.ip(), - server_port: server.port(), - client_ip: client.ip(), - client_port: client.port(), - headers: headers.clone(), - is_websocket: false, - } - } - - pub fn set_websocket(&mut self) { - self.is_websocket = true; - } - - #[inline(always)] - fn py_proto(&self) -> &str { - match self.is_websocket { - false => "http", - true => "websocket", +macro_rules! asgi_scope_cls { + ($name:ident, $proto:expr) => { + #[pyclass(frozen, module = "granian._granian")] + pub(crate) struct $name { + http_version: Version, + scheme: Arc, + method: Arc, + uri: Uri, + server_ip: IpAddr, + server_port: u16, + client_ip: IpAddr, + client_port: u16, + headers: HeaderMap, } - } - #[inline(always)] - fn py_http_version(&self) -> &str { - match self.http_version { - Version::HTTP_10 => "1", - Version::HTTP_11 => "1.1", - Version::HTTP_2 => "2", - Version::HTTP_3 => "3", - _ => "1", - } - } + impl $name { + pub fn new( + http_version: Version, + scheme: &str, + uri: Uri, + method: &str, + server: SocketAddr, + client: SocketAddr, + headers: &HeaderMap, + ) -> Self { + Self { + http_version, + scheme: scheme.into(), + method: method.into(), + uri, + server_ip: server.ip(), + server_port: server.port(), + client_ip: client.ip(), + client_port: client.port(), + headers: headers.clone(), + } + } - #[inline(always)] - fn py_scheme(&self) -> &str { - let scheme = &self.scheme[..]; - match self.is_websocket { - false => scheme, - true => match scheme { - SCHEME_HTTPS => SCHEME_WSS, - _ => SCHEME_WS, - }, - } - } + #[inline(always)] + fn get_proto(&self) -> &str { + $proto + } - #[inline(always)] - fn py_headers<'p>(&self, py: Python<'p>) -> PyResult<&'p PyList> { - let rv = PyList::empty(py); - for (key, value) in &self.headers { - rv.append(( - PyBytes::new(py, key.as_str().as_bytes()), - PyBytes::new(py, value.as_bytes()), - ))?; - } - if !self.headers.contains_key(header::HOST) { - let host = self.uri.authority().map_or("", Authority::as_str); - rv.insert(0, (PyBytes::new(py, b"host"), PyBytes::new(py, host.as_bytes())))?; + #[inline(always)] + fn py_headers<'p>(&self, py: Python<'p>) -> PyResult<&'p PyList> { + let rv = PyList::empty(py); + for (key, value) in &self.headers { + rv.append(( + PyBytes::new(py, key.as_str().as_bytes()), + PyBytes::new(py, value.as_bytes()), + ))?; + } + if !self.headers.contains_key(header::HOST) { + let host = self.uri.authority().map_or("", Authority::as_str); + rv.insert(0, (PyBytes::new(py, b"host"), PyBytes::new(py, host.as_bytes())))?; + } + Ok(rv) + } } - Ok(rv) - } + }; } -#[pymethods] -impl ASGIScope { - fn as_dict<'p>(&self, py: Python<'p>, url_path_prefix: &'p str, state: &'p PyAny) -> PyResult<&'p PyAny> { - let (path, query_string, proto, http_version, server, client, scheme, method) = py.allow_threads(|| { - let (path, query_string) = self +asgi_scope_cls!(ASGIHTTPScope, "http"); +asgi_scope_cls!(ASGIWebsocketScope, "websocket"); + +macro_rules! asgi_scope_as_dict { + ($self:expr, $py:expr, $url_path_prefix:expr, $state:expr, $dict:expr) => { + let (path, query_string, proto, http_version, server, client) = $py.allow_threads(|| { + let (path, query_string) = $self .uri .path_and_query() .map_or_else(|| ("", ""), |pq| (pq.path(), pq.query().unwrap_or(""))); ( percent_decode_str(path).decode_utf8().unwrap(), query_string, - self.py_proto(), - self.py_http_version(), - (self.server_ip.to_string(), self.server_port), - (self.client_ip.to_string(), self.client_port), - self.py_scheme(), - &self.method[..], + $self.get_proto(), + match $self.http_version { + Version::HTTP_10 => "1", + Version::HTTP_11 => "1.1", + Version::HTTP_2 => "2", + Version::HTTP_3 => "3", + _ => "1", + }, + ($self.server_ip.to_string(), $self.server_port), + ($self.client_ip.to_string(), $self.client_port), ) }); - let dict: &PyDict = PyDict::new(py); - dict.set_item( - pyo3::intern!(py, "asgi"), + $dict.set_item( + pyo3::intern!($py, "asgi"), ASGI_VERSION - .get_or_try_init(py, || { - let rv = PyDict::new(py); + .get_or_try_init($py, || { + let rv = PyDict::new($py); rv.set_item("version", "3.0")?; rv.set_item("spec_version", "2.3")?; Ok::(rv.into()) })? - .as_ref(py), + .as_ref($py), )?; - dict.set_item( - pyo3::intern!(py, "extensions"), + $dict.set_item( + pyo3::intern!($py, "extensions"), ASGI_EXTENSIONS - .get_or_try_init(py, || { - let rv = PyDict::new(py); - rv.set_item("http.response.pathsend", PyDict::new(py))?; + .get_or_try_init($py, || { + let rv = PyDict::new($py); + rv.set_item("http.response.pathsend", PyDict::new($py))?; Ok::(rv.into()) })? - .as_ref(py), + .as_ref($py), + )?; + $dict.set_item(pyo3::intern!($py, "type"), proto)?; + $dict.set_item(pyo3::intern!($py, "http_version"), http_version)?; + $dict.set_item(pyo3::intern!($py, "server"), server)?; + $dict.set_item(pyo3::intern!($py, "client"), client)?; + $dict.set_item(pyo3::intern!($py, "method"), &*$self.method)?; + $dict.set_item(pyo3::intern!($py, "root_path"), $url_path_prefix)?; + $dict.set_item(pyo3::intern!($py, "path"), &path)?; + $dict.set_item(pyo3::intern!($py, "raw_path"), PyBytes::new($py, path.as_bytes()))?; + $dict.set_item( + pyo3::intern!($py, "query_string"), + PyBytes::new($py, query_string.as_bytes()), )?; - dict.set_item(pyo3::intern!(py, "type"), proto)?; - dict.set_item(pyo3::intern!(py, "http_version"), http_version)?; - dict.set_item(pyo3::intern!(py, "server"), server)?; - dict.set_item(pyo3::intern!(py, "client"), client)?; - dict.set_item(pyo3::intern!(py, "scheme"), scheme)?; - dict.set_item(pyo3::intern!(py, "method"), method)?; - dict.set_item(pyo3::intern!(py, "root_path"), url_path_prefix)?; - dict.set_item(pyo3::intern!(py, "path"), &path)?; - dict.set_item(pyo3::intern!(py, "raw_path"), PyBytes::new(py, path.as_bytes()))?; + $dict.set_item(pyo3::intern!($py, "headers"), $self.py_headers($py)?)?; + $dict.set_item(pyo3::intern!($py, "state"), $state)?; + }; +} + +#[pymethods] +impl ASGIHTTPScope { + fn as_dict<'p>(&self, py: Python<'p>, url_path_prefix: &'p str, state: &'p PyAny) -> PyResult<&'p PyAny> { + let dict: &PyDict = PyDict::new(py); + asgi_scope_as_dict!(self, py, url_path_prefix, state, dict); + dict.set_item(pyo3::intern!(py, "scheme"), &*self.scheme)?; + Ok(dict) + } +} + +#[pymethods] +impl ASGIWebsocketScope { + fn as_dict<'p>(&self, py: Python<'p>, url_path_prefix: &'p str, state: &'p PyAny) -> PyResult<&'p PyAny> { + let dict: &PyDict = PyDict::new(py); + asgi_scope_as_dict!(self, py, url_path_prefix, state, dict); dict.set_item( - pyo3::intern!(py, "query_string"), - PyBytes::new(py, query_string.as_bytes()), + pyo3::intern!(py, "scheme"), + match &*self.scheme { + SCHEME_HTTPS => SCHEME_WSS, + _ => SCHEME_WS, + }, )?; - dict.set_item(pyo3::intern!(py, "headers"), self.py_headers(py)?)?; - dict.set_item(pyo3::intern!(py, "state"), state)?; Ok(dict) } } diff --git a/src/callbacks.rs b/src/callbacks.rs index bb0cc1d9..98fb8787 100644 --- a/src/callbacks.rs +++ b/src/callbacks.rs @@ -1,5 +1,6 @@ use pyo3::{prelude::*, pyclass::IterNextOutput, sync::GILOnceCell}; -use std::sync::Arc; + +use std::sync::{atomic, Arc, RwLock}; use tokio::sync::Notify; static CONTEXTVARS: GILOnceCell = GILOnceCell::new(); @@ -20,8 +21,8 @@ impl CallbackWrapper { } } -#[pyclass] -pub(crate) struct PyEmptyAwaitable {} +#[pyclass(frozen)] +pub(crate) struct PyEmptyAwaitable; #[pymethods] impl PyEmptyAwaitable { @@ -38,18 +39,21 @@ impl PyEmptyAwaitable { } } -#[pyclass] +#[pyclass(frozen)] pub(crate) struct PyIterAwaitable { - result: Option>, + result: RwLock>>, } impl PyIterAwaitable { pub(crate) fn new() -> Self { - Self { result: None } + Self { + result: RwLock::new(None), + } } - pub(crate) fn set_result(&mut self, result: PyResult) { - self.result = Some(result); + pub(crate) fn set_result(&self, result: PyResult>) { + let mut res = self.result.write().unwrap(); + *res = Some(Python::with_gil(|py| result.map(|v| v.into_py(py)))); } } @@ -64,46 +68,64 @@ impl PyIterAwaitable { } fn __next__(&self, py: Python) -> PyResult> { - match &self.result { - Some(res) => match res { - Ok(v) => Ok(IterNextOutput::Return(v.clone_ref(py))), - Err(err) => Err(err.clone_ref(py)), - }, - _ => Ok(IterNextOutput::Yield(py.None())), - } + if let Ok(res) = self.result.try_read() { + if let Some(ref res) = *res { + return res + .as_ref() + .map(|v| IterNextOutput::Return(v.clone_ref(py))) + .map_err(|err| err.clone_ref(py)); + } + }; + Ok(IterNextOutput::Yield(py.None())) } } -#[pyclass] +enum PyFutureAwaitableState { + Pending, + Completed(PyResult), + Cancelled, +} + +#[pyclass(frozen)] pub(crate) struct PyFutureAwaitable { - fut_spawner: Option, Arc, Py) + Send>>, - result: Option>, + state: RwLock, event_loop: PyObject, - callback: Option, cancel_tx: Arc, - py_block: bool, - py_cancelled: bool, + py_block: atomic::AtomicBool, + ack: RwLock)>>, } impl PyFutureAwaitable { - pub(crate) fn new( - fut_spawner: Box, Arc, Py) + Send>, - event_loop: PyObject, - ) -> Self { + pub(crate) fn new(event_loop: PyObject) -> Self { Self { - fut_spawner: Some(fut_spawner), - result: None, + state: RwLock::new(PyFutureAwaitableState::Pending), event_loop, - callback: None, cancel_tx: Arc::new(Notify::new()), - py_block: true, - py_cancelled: false, + py_block: true.into(), + ack: RwLock::new(None), } } - pub(crate) fn set_result(mut pyself: PyRefMut<'_, Self>, result: PyResult) -> Option { - pyself.result = Some(result); - pyself.callback.take() + pub fn to_spawn(self, py: Python) -> PyResult<(Py, Arc)> { + let cancel_tx = self.cancel_tx.clone(); + Ok((Py::new(py, self)?, cancel_tx)) + } + + pub(crate) fn set_result(&self, result: PyResult>, aw: Py) { + Python::with_gil(|py| { + let mut state = self.state.write().unwrap(); + *state = PyFutureAwaitableState::Completed(result.map(|v| v.into_py(py))); + + let ack = self.ack.read().unwrap(); + if let Some((cb, ctx)) = &*ack { + let _ = self.event_loop.clone_ref(py).call_method( + py, + pyo3::intern!(py, "call_soon_threadsafe"), + (cb, aw), + Some(ctx.as_ref(py)), + ); + } + }); } } @@ -112,15 +134,31 @@ impl PyFutureAwaitable { fn __await__(pyself: PyRef<'_, Self>) -> PyRef<'_, Self> { pyself } + fn __iter__(pyself: PyRef<'_, Self>) -> PyRef<'_, Self> { + pyself + } + + fn __next__(pyself: PyRef<'_, Self>) -> PyResult, PyObject>> { + let state = pyself.state.read().unwrap(); + if let PyFutureAwaitableState::Completed(res) = &*state { + let py = pyself.py(); + return res + .as_ref() + .map(|v| IterNextOutput::Return(v.clone_ref(py))) + .map_err(|err| err.clone_ref(py)); + }; + drop(state); + Ok(IterNextOutput::Yield(pyself)) + } #[getter(_asyncio_future_blocking)] fn get_block(&self) -> bool { - self.py_block + self.py_block.load(atomic::Ordering::Relaxed) } #[setter(_asyncio_future_blocking)] - fn set_block(&mut self, val: bool) { - self.py_block = val; + fn set_block(&self, val: bool) { + self.py_block.store(val, atomic::Ordering::Relaxed); } fn get_loop(&self, py: Python) -> PyObject { @@ -128,66 +166,107 @@ impl PyFutureAwaitable { } #[pyo3(signature = (cb, context=None))] - fn add_done_callback(mut pyself: PyRefMut<'_, Self>, cb: PyObject, context: Option) { - pyself.callback = Some(cb); - if let Some(spawner) = pyself.fut_spawner.take() { - (spawner)(context, pyself.cancel_tx.clone(), pyself.into()); + fn add_done_callback(pyself: PyRef<'_, Self>, cb: PyObject, context: Option) -> PyResult<()> { + let py = pyself.py(); + let kwctx = pyo3::types::PyDict::new(py); + kwctx.set_item(pyo3::intern!(py, "context"), context)?; + + let state = pyself.state.read().unwrap(); + match &*state { + PyFutureAwaitableState::Pending => { + let mut ack = pyself.ack.write().unwrap(); + *ack = Some((cb, kwctx.into_py(py))); + Ok(()) + } + _ => { + drop(state); + let event_loop = pyself.event_loop.clone_ref(py); + event_loop.call_method(py, pyo3::intern!(py, "call_soon"), (cb, pyself), Some(kwctx))?; + Ok(()) + } } } #[allow(unused)] - fn remove_done_callback(&mut self, cb: PyObject) -> i32 { - self.callback = None; + fn remove_done_callback(&self, cb: PyObject) -> i32 { + let mut ack = self.ack.write().unwrap(); + *ack = None; 1 } #[allow(unused)] #[pyo3(signature = (msg=None))] - fn cancel(&mut self, msg: Option) -> bool { - if self.done() { + fn cancel(&self, msg: Option) -> bool { + let mut state = self.state.write().unwrap(); + if !matches!(&mut *state, PyFutureAwaitableState::Pending) { return false; } - self.py_cancelled = true; + + *state = PyFutureAwaitableState::Cancelled; self.cancel_tx.notify_one(); true } fn done(&self) -> bool { - self.result.is_some() || self.py_cancelled + let state = self.state.read().unwrap(); + !matches!(&*state, PyFutureAwaitableState::Pending) } - fn result(pyself: PyRef<'_, Self>) -> PyResult { - if pyself.py_cancelled { - return Err(pyo3::exceptions::asyncio::CancelledError::new_err("Future cancelled.")); - } - - match &pyself.result { - Some(res) => { - let py = pyself.py(); + fn result(&self, py: Python) -> PyResult { + let state = self.state.read().unwrap(); + match &*state { + PyFutureAwaitableState::Completed(res) => { res.as_ref().map(|v| v.clone_ref(py)).map_err(|err| err.clone_ref(py)) } - _ => Err(pyo3::exceptions::asyncio::InvalidStateError::new_err( + PyFutureAwaitableState::Cancelled => { + Err(pyo3::exceptions::asyncio::CancelledError::new_err("Future cancelled.")) + } + PyFutureAwaitableState::Pending => Err(pyo3::exceptions::asyncio::InvalidStateError::new_err( "Result is not ready.", )), } } - fn exception(&self) {} - - fn __iter__(pyself: PyRef<'_, Self>) -> PyRef<'_, Self> { - pyself + fn exception(&self, py: Python) -> PyResult { + let state = self.state.read().unwrap(); + match &*state { + PyFutureAwaitableState::Completed(res) => res.as_ref().map(|_| py.None()).map_err(|err| err.clone_ref(py)), + PyFutureAwaitableState::Cancelled => { + Err(pyo3::exceptions::asyncio::CancelledError::new_err("Future cancelled.")) + } + PyFutureAwaitableState::Pending => Err(pyo3::exceptions::asyncio::InvalidStateError::new_err( + "Exception is not set.", + )), + } } +} - fn __next__(pyself: PyRef<'_, Self>) -> PyResult, PyObject>> { - match &pyself.result { - Some(res) => { - let py = pyself.py(); - res.as_ref() - .map(|v| IterNextOutput::Return(v.clone_ref(py))) - .map_err(|err| err.clone_ref(py)) - } - _ => Ok(IterNextOutput::Yield(pyself)), +#[pyclass(frozen)] +pub(crate) struct PyFutureDoneCallback { + pub cancel_tx: Arc, +} + +#[pymethods] +impl PyFutureDoneCallback { + pub fn __call__(&self, fut: &PyAny) -> PyResult<()> { + let py = fut.py(); + + if { fut.getattr(pyo3::intern!(py, "cancelled"))?.call0()?.is_true() }.unwrap_or(false) { + self.cancel_tx.notify_one(); } + + Ok(()) + } +} + +#[pyclass(frozen)] +pub(crate) struct PyFutureResultSetter; + +#[pymethods] +impl PyFutureResultSetter { + pub fn __call__(&self, target: &PyAny, value: &PyAny) -> PyResult<()> { + target.call1((value,))?; + Ok(()) } } @@ -296,9 +375,9 @@ macro_rules! callback_impl_loop_step { if (err.is_instance_of::($py) || err.is_instance_of::($py)) { - $pyself.done($py); + $pyself.done(); } else { - $pyself.err($py, &err); + $pyself.err(&err); } Ok(()) } @@ -318,7 +397,7 @@ macro_rules! callback_impl_loop_wake { macro_rules! callback_impl_loop_err { () => { pub fn _loop_err(&self, py: Python, err: PyErr) -> PyResult { - self.err(py, &err); + self.err(&err); let cberr = self.cb.call_method1(py, pyo3::intern!(py, "throw"), (err,)); cberr } diff --git a/src/conversion.rs b/src/conversion.rs index a4ed8d38..d74d8724 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -8,12 +8,14 @@ pub(crate) struct BytesToPy(pub hyper::body::Bytes); impl Deref for BytesToPy { type Target = hyper::body::Bytes; + #[inline] fn deref(&self) -> &Self::Target { &self.0 } } impl DerefMut for BytesToPy { + #[inline] fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } diff --git a/src/rsgi/callbacks.rs b/src/rsgi/callbacks.rs index e8da33ff..17e6e71f 100644 --- a/src/rsgi/callbacks.rs +++ b/src/rsgi/callbacks.rs @@ -4,7 +4,7 @@ use tokio::sync::oneshot; use super::{ io::{RSGIHTTPProtocol as HTTPProtocol, RSGIWebsocketProtocol as WebsocketProtocol, WebsocketDetachedTransport}, - types::{PyResponse, PyResponseBody, RSGIScope as Scope}, + types::{PyResponse, PyResponseBody, RSGIHTTPScope as HTTPScope, RSGIWebsocketScope as WebsocketScope}, }; use crate::{ callbacks::{ @@ -17,7 +17,7 @@ use crate::{ ws::{HyperWebsocket, UpgradeData}, }; -#[pyclass] +#[pyclass(frozen)] pub(crate) struct CallbackRunnerHTTP { proto: Py, context: TaskLocals, @@ -25,7 +25,7 @@ pub(crate) struct CallbackRunnerHTTP { } impl CallbackRunnerHTTP { - pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Scope) -> Self { + pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: HTTPScope) -> Self { let pyproto = Py::new(py, proto).unwrap(); Self { proto: pyproto.clone_ref(py), @@ -51,23 +51,21 @@ impl CallbackRunnerHTTP { } macro_rules! callback_impl_done_http { - ($self:expr, $py:expr) => { - if let Ok(mut proto) = $self.proto.as_ref($py).try_borrow_mut() { - if let Some(tx) = proto.tx() { - let _ = tx.send(PyResponse::Body(PyResponseBody::empty(500, Vec::new()))); - } + ($self:expr) => { + if let Some(tx) = $self.proto.get().tx() { + let _ = tx.send(PyResponse::Body(PyResponseBody::empty(500, Vec::new()))); } }; } macro_rules! callback_impl_done_err { - ($self:expr, $py:expr, $err:expr) => { - $self.done($py); + ($self:expr, $err:expr) => { + $self.done(); log_application_callable_exception($err); }; } -#[pyclass] +#[pyclass(frozen)] pub(crate) struct CallbackTaskHTTP { proto: Py, context: TaskLocals, @@ -86,12 +84,12 @@ impl CallbackTaskHTTP { }) } - fn done(&self, py: Python) { - callback_impl_done_http!(self, py); + fn done(&self) { + callback_impl_done_http!(self); } - fn err(&self, py: Python, err: &PyErr) { - callback_impl_done_err!(self, py, err); + fn err(&self, err: &PyErr) { + callback_impl_done_err!(self, err); } callback_impl_loop_run!(); @@ -109,7 +107,7 @@ impl CallbackTaskHTTP { } } -#[pyclass] +#[pyclass(frozen)] pub(crate) struct CallbackWrappedRunnerHTTP { #[pyo3(get)] proto: Py, @@ -120,7 +118,7 @@ pub(crate) struct CallbackWrappedRunnerHTTP { } impl CallbackWrappedRunnerHTTP { - pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Scope) -> Self { + pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: HTTPScope) -> Self { Self { proto: Py::new(py, proto).unwrap(), context: cb.context, @@ -138,16 +136,16 @@ impl CallbackWrappedRunnerHTTP { callback_impl_loop_pytask!(pyself, py) } - fn done(&self, py: Python) { - callback_impl_done_http!(self, py); + fn done(&self) { + callback_impl_done_http!(self); } - fn err(&self, py: Python, err: &PyAny) { - callback_impl_done_err!(self, py, &PyErr::from_value(err)); + fn err(&self, err: &PyAny) { + callback_impl_done_err!(self, &PyErr::from_value(err)); } } -#[pyclass] +#[pyclass(frozen)] pub(crate) struct CallbackRunnerWebsocket { proto: Py, context: TaskLocals, @@ -155,7 +153,7 @@ pub(crate) struct CallbackRunnerWebsocket { } impl CallbackRunnerWebsocket { - pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Scope) -> Self { + pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: WebsocketScope) -> Self { let pyproto = Py::new(py, proto).unwrap(); Self { proto: pyproto.clone(), @@ -175,14 +173,12 @@ impl CallbackRunnerWebsocket { } macro_rules! callback_impl_done_ws { - ($self:expr, $py:expr) => { - if let Ok(mut proto) = $self.proto.as_ref($py).try_borrow_mut() { - let _ = proto.close($py, None); - } + ($self:expr) => { + let _ = $self.proto.get().close(None); }; } -#[pyclass] +#[pyclass(frozen)] pub(crate) struct CallbackTaskWebsocket { proto: Py, context: TaskLocals, @@ -201,12 +197,12 @@ impl CallbackTaskWebsocket { }) } - fn done(&self, py: Python) { - callback_impl_done_ws!(self, py); + fn done(&self) { + callback_impl_done_ws!(self); } - fn err(&self, py: Python, err: &PyErr) { - callback_impl_done_err!(self, py, err); + fn err(&self, err: &PyErr) { + callback_impl_done_err!(self, err); } callback_impl_loop_run!(); @@ -224,7 +220,7 @@ impl CallbackTaskWebsocket { } } -#[pyclass] +#[pyclass(frozen)] pub(crate) struct CallbackWrappedRunnerWebsocket { #[pyo3(get)] proto: Py, @@ -235,7 +231,7 @@ pub(crate) struct CallbackWrappedRunnerWebsocket { } impl CallbackWrappedRunnerWebsocket { - pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Scope) -> Self { + pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: WebsocketScope) -> Self { Self { proto: Py::new(py, proto).unwrap(), context: cb.context, @@ -253,12 +249,12 @@ impl CallbackWrappedRunnerWebsocket { callback_impl_loop_pytask!(pyself, py) } - fn done(&self, py: Python) { - callback_impl_done_ws!(self, py); + fn done(&self) { + callback_impl_done_ws!(self); } - fn err(&self, py: Python, err: &PyAny) { - callback_impl_done_err!(self, py, &PyErr::from_value(err)); + fn err(&self, err: &PyAny) { + callback_impl_done_err!(self, &PyErr::from_value(err)); } } @@ -268,7 +264,7 @@ macro_rules! call_impl_rtb_http { cb: CallbackWrapper, rt: RuntimeRef, req: HTTPRequest, - scope: Scope, + scope: HTTPScope, ) -> oneshot::Receiver { let (tx, rx) = oneshot::channel(); let protocol = HTTPProtocol::new(rt, tx, req); @@ -288,7 +284,7 @@ macro_rules! call_impl_rtt_http { cb: CallbackWrapper, rt: RuntimeRef, req: HTTPRequest, - scope: Scope, + scope: HTTPScope, ) -> oneshot::Receiver { let (tx, rx) = oneshot::channel(); let protocol = HTTPProtocol::new(rt, tx, req); @@ -311,7 +307,7 @@ macro_rules! call_impl_rtb_ws { rt: RuntimeRef, ws: HyperWebsocket, upgrade: UpgradeData, - scope: Scope, + scope: WebsocketScope, ) -> oneshot::Receiver { let (tx, rx) = oneshot::channel(); let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); @@ -332,7 +328,7 @@ macro_rules! call_impl_rtt_ws { rt: RuntimeRef, ws: HyperWebsocket, upgrade: UpgradeData, - scope: Scope, + scope: WebsocketScope, ) -> oneshot::Receiver { let (tx, rx) = oneshot::channel(); let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); diff --git a/src/rsgi/http.rs b/src/rsgi/http.rs index 61c188ad..0b4aedc8 100644 --- a/src/rsgi/http.rs +++ b/src/rsgi/http.rs @@ -8,7 +8,7 @@ use super::{ call_rtb_http, call_rtb_http_pyw, call_rtb_ws, call_rtb_ws_pyw, call_rtt_http, call_rtt_http_pyw, call_rtt_ws, call_rtt_ws_pyw, }, - types::{PyResponse, RSGIScope as Scope}, + types::{PyResponse, RSGIHTTPScope as HTTPScope, RSGIWebsocketScope as WebsocketScope}, }; use crate::{ callbacks::CallbackWrapper, @@ -17,10 +17,9 @@ use crate::{ ws::{is_upgrade_request as is_ws_upgrade, upgrade_intent as ws_upgrade, UpgradeData}, }; -macro_rules! default_scope { - ($server_addr:expr, $client_addr:expr, $req:expr, $scheme:expr) => { - Scope::new( - "http", +macro_rules! build_scope { + ($cls:ty, $server_addr:expr, $client_addr:expr, $req:expr, $scheme:expr) => { + <$cls>::new( $req.version(), $scheme, $req.uri().clone(), @@ -55,7 +54,7 @@ macro_rules! handle_request { req: HTTPRequest, scheme: &str, ) -> HTTPResponse { - let scope = default_scope!(server_addr, client_addr, &req, scheme); + let scope = build_scope!(HTTPScope, server_addr, client_addr, &req, scheme); handle_http_response!($handler, rt, callback, req, scope) } }; @@ -71,10 +70,8 @@ macro_rules! handle_request_with_ws { req: HTTPRequest, scheme: &str, ) -> HTTPResponse { - let mut scope = default_scope!(server_addr, client_addr, &req, scheme); - if is_ws_upgrade(&req) { - scope.set_proto("ws"); + let scope = build_scope!(WebsocketScope, server_addr, client_addr, &req, scheme); match ws_upgrade(req, None) { Ok((res, ws)) => { @@ -131,6 +128,7 @@ macro_rules! handle_request_with_ws { } } + let scope = build_scope!(HTTPScope, server_addr, client_addr, &req, scheme); handle_http_response!($handler_req, rt, callback, req, scope) } }; diff --git a/src/rsgi/io.rs b/src/rsgi/io.rs index 5dd135f2..dc4f0021 100644 --- a/src/rsgi/io.rs +++ b/src/rsgi/io.rs @@ -3,8 +3,11 @@ use http_body_util::BodyExt; use hyper::body; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyString}; -use std::{borrow::Cow, sync::Arc}; -use tokio::sync::{mpsc, oneshot, Mutex}; +use std::{ + borrow::Cow, + sync::{atomic, Arc, Mutex, RwLock}, +}; +use tokio::sync::{mpsc, oneshot, Mutex as AsyncMutex}; use tokio_tungstenite::tungstenite::Message; use super::{ @@ -20,7 +23,7 @@ use crate::{ pub(crate) type WebsocketDetachedTransport = (i32, bool, Option>); -#[pyclass(module = "granian._granian")] +#[pyclass(frozen, module = "granian._granian")] pub(crate) struct RSGIHTTPStreamTransport { rt: RuntimeRef, tx: mpsc::Sender>, @@ -56,33 +59,33 @@ impl RSGIHTTPStreamTransport { } } -#[pyclass(module = "granian._granian")] +#[pyclass(frozen, module = "granian._granian")] pub(crate) struct RSGIHTTPProtocol { rt: RuntimeRef, - tx: Option>, - body: Option, - body_stream: Option>>>, + tx: Mutex>>, + body: Mutex>, + body_stream: Arc>>>, } impl RSGIHTTPProtocol { pub fn new(rt: RuntimeRef, tx: oneshot::Sender, request: HTTPRequest) -> Self { Self { rt, - tx: Some(tx), - body: Some(request.into_body()), - body_stream: None, + tx: Mutex::new(Some(tx)), + body: Mutex::new(Some(request.into_body())), + body_stream: Arc::new(AsyncMutex::new(None)), } } - pub fn tx(&mut self) -> Option> { - self.tx.take() + pub fn tx(&self) -> Option> { + self.tx.lock().unwrap().take() } } #[pymethods] impl RSGIHTTPProtocol { - fn __call__<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> { - if let Some(body) = self.body.take() { + fn __call__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { + if let Some(body) = self.body.lock().unwrap().take() { return future_into_py_iter(self.rt.clone(), py, async move { let body = body .collect() @@ -94,70 +97,62 @@ impl RSGIHTTPProtocol { error_proto!() } - fn __aiter__(mut pyself: PyRefMut<'_, Self>) -> PyRefMut<'_, Self> { - if let Some(body) = pyself.body.take() { - pyself.body_stream = Some(Arc::new(Mutex::new(http_body_util::BodyStream::new(body)))); + fn __aiter__(pyself: PyRef<'_, Self>) -> PyRef<'_, Self> { + if let Some(body) = pyself.body.lock().unwrap().take() { + let mut stream = pyself.body_stream.blocking_lock(); + *stream = Some(http_body_util::BodyStream::new(body)); } pyself } - fn __anext__<'p>(&mut self, py: Python<'p>) -> PyResult> { - if let Some(body_ref) = &self.body_stream { - let body_ref = body_ref.clone(); - let fut = future_into_py_iter(self.rt.clone(), py, async move { - let mut bodym = body_ref.lock().await; - let body = &mut *bodym; - match body.next().await { - Some(chunk) => { - let chunk = chunk - .map(|buf| buf.into_data().unwrap_or_default()) - .unwrap_or(body::Bytes::new()); - Ok(BytesToPy(chunk)) - } - _ => Err(pyo3::exceptions::PyStopAsyncIteration::new_err("stream exhausted")), - } - })?; - return Ok(Some(fut)); - } - error_proto!() + fn __anext__<'p>(&self, py: Python<'p>) -> PyResult> { + let body_stream = self.body_stream.clone(); + let pyfut = future_into_py_iter(self.rt.clone(), py, async move { + if let Some(stream) = &mut *body_stream.lock().await { + if let Some(chunk) = stream.next().await { + let chunk = chunk + .map(|buf| buf.into_data().unwrap_or_default()) + .unwrap_or(body::Bytes::new()); + return Ok(BytesToPy(chunk)); + }; + return Err(pyo3::exceptions::PyStopAsyncIteration::new_err("stream exhausted")); + } + error_proto!() + })?; + Ok(Some(pyfut)) } #[pyo3(signature = (status=200, headers=vec![]))] - fn response_empty(&mut self, status: u16, headers: Vec<(String, String)>) { - if let Some(tx) = self.tx.take() { + fn response_empty(&self, status: u16, headers: Vec<(String, String)>) { + if let Some(tx) = self.tx.lock().unwrap().take() { let _ = tx.send(PyResponse::Body(PyResponseBody::empty(status, headers))); } } #[pyo3(signature = (status=200, headers=vec![], body=vec![].into()))] - fn response_bytes(&mut self, status: u16, headers: Vec<(String, String)>, body: Cow<[u8]>) { - if let Some(tx) = self.tx.take() { + fn response_bytes(&self, status: u16, headers: Vec<(String, String)>, body: Cow<[u8]>) { + if let Some(tx) = self.tx.lock().unwrap().take() { let _ = tx.send(PyResponse::Body(PyResponseBody::from_bytes(status, headers, body))); } } #[pyo3(signature = (status=200, headers=vec![], body=String::new()))] - fn response_str(&mut self, status: u16, headers: Vec<(String, String)>, body: String) { - if let Some(tx) = self.tx.take() { + fn response_str(&self, status: u16, headers: Vec<(String, String)>, body: String) { + if let Some(tx) = self.tx.lock().unwrap().take() { let _ = tx.send(PyResponse::Body(PyResponseBody::from_string(status, headers, body))); } } #[pyo3(signature = (status, headers, file))] - fn response_file(&mut self, status: u16, headers: Vec<(String, String)>, file: String) { - if let Some(tx) = self.tx.take() { + fn response_file(&self, status: u16, headers: Vec<(String, String)>, file: String) { + if let Some(tx) = self.tx.lock().unwrap().take() { let _ = tx.send(PyResponse::File(PyResponseFile::new(status, headers, file))); } } #[pyo3(signature = (status=200, headers=vec![]))] - fn response_stream<'p>( - &mut self, - py: Python<'p>, - status: u16, - headers: Vec<(String, String)>, - ) -> PyResult<&'p PyAny> { - if let Some(tx) = self.tx.take() { + fn response_stream<'p>(&self, py: Python<'p>, status: u16, headers: Vec<(String, String)>) -> PyResult<&'p PyAny> { + if let Some(tx) = self.tx.lock().unwrap().take() { let (body_tx, body_rx) = mpsc::channel::>(1); let body_stream = http_body_util::StreamBody::new( tokio_stream::wrappers::ReceiverStream::new(body_rx).map_ok(hyper::body::Frame::data), @@ -174,12 +169,12 @@ impl RSGIHTTPProtocol { } } -#[pyclass(module = "granian._granian")] +#[pyclass(frozen, module = "granian._granian")] pub(crate) struct RSGIWebsocketTransport { rt: RuntimeRef, - tx: Arc>, - rx: Arc>, - closed: bool, + tx: Arc>, + rx: Arc>, + closed: atomic::AtomicBool, } impl RSGIWebsocketTransport { @@ -187,17 +182,17 @@ impl RSGIWebsocketTransport { let (tx, rx) = transport.split(); Self { rt, - tx: Arc::new(Mutex::new(tx)), - rx: Arc::new(Mutex::new(rx)), - closed: false, + tx: Arc::new(AsyncMutex::new(tx)), + rx: Arc::new(AsyncMutex::new(rx)), + closed: false.into(), } } - pub fn close(&mut self) -> Option> { - if self.closed { + pub fn close(&self) -> Option> { + if self.closed.load(atomic::Ordering::Relaxed) { return None; } - self.closed = true; + self.closed.store(true, atomic::Ordering::Relaxed); let tx = self.tx.clone(); let handle = self.rt.spawn(async move { @@ -258,14 +253,13 @@ impl RSGIWebsocketTransport { } } -#[pyclass(module = "granian._granian")] +#[pyclass(frozen, module = "granian._granian")] pub(crate) struct RSGIWebsocketProtocol { rt: RuntimeRef, - tx: Option>, - websocket: Arc>, - upgrade: Option, + tx: Mutex>>, + websocket: Arc>, + upgrade: RwLock>, transport: Arc>>>, - status: i32, } impl RSGIWebsocketProtocol { @@ -277,16 +271,15 @@ impl RSGIWebsocketProtocol { ) -> Self { Self { rt, - tx: Some(tx), - websocket: Arc::new(Mutex::new(websocket)), - upgrade: Some(upgrade), + tx: Mutex::new(Some(tx)), + websocket: Arc::new(AsyncMutex::new(websocket)), + upgrade: RwLock::new(Some(upgrade)), transport: Arc::new(Mutex::new(None)), - status: 0, } } fn consumed(&self) -> bool { - self.upgrade.is_none() + self.upgrade.read().unwrap().is_none() } } @@ -296,7 +289,7 @@ enum WebsocketMessageType { Text = 2, } -#[pyclass] +#[pyclass(frozen)] struct WebsocketInboundCloseMessage { #[pyo3(get)] kind: usize, @@ -310,7 +303,7 @@ impl WebsocketInboundCloseMessage { } } -#[pyclass] +#[pyclass(frozen)] struct WebsocketInboundBytesMessage { #[pyo3(get)] kind: usize, @@ -327,7 +320,7 @@ impl WebsocketInboundBytesMessage { } } -#[pyclass] +#[pyclass(frozen)] struct WebsocketInboundTextMessage { #[pyo3(get)] kind: usize, @@ -347,25 +340,22 @@ impl WebsocketInboundTextMessage { #[pymethods] impl RSGIWebsocketProtocol { #[pyo3(signature = (status=None))] - pub fn close(&mut self, py: Python, status: Option) { - self.status = status.unwrap_or(0); - if let Some(tx) = self.tx.take() { + pub fn close(&self, status: Option) { + if let Some(tx) = self.tx.lock().unwrap().take() { let mut handle = None; if let Ok(mut transport) = self.transport.try_lock() { if let Some(transport) = transport.take() { - if let Ok(mut trx) = transport.try_borrow_mut(py) { - handle = trx.close(); - } + handle = transport.get().close(); } } - let _ = tx.send((self.status, self.consumed(), handle)); + let _ = tx.send((status.unwrap_or(0), self.consumed(), handle)); } } - fn accept<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> { + fn accept<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { let rth = self.rt.clone(); - let mut upgrade = self.upgrade.take().unwrap(); + let mut upgrade = self.upgrade.write().unwrap().take().unwrap(); let transport = self.websocket.clone(); let itransport = self.transport.clone(); future_into_py_iter(self.rt.clone(), py, async move { @@ -373,7 +363,7 @@ impl RSGIWebsocketProtocol { match upgrade.send().await { Ok(()) => match (&mut *ws).await { Ok(stream) => { - let mut trx = itransport.lock().await; + let mut trx = itransport.lock().unwrap(); Ok(Python::with_gil(|py| { let pytransport = Py::new(py, RSGIWebsocketTransport::new(rth, stream)).unwrap(); *trx = Some(pytransport.clone()); diff --git a/src/rsgi/mod.rs b/src/rsgi/mod.rs index f79f41c6..5c27c911 100644 --- a/src/rsgi/mod.rs +++ b/src/rsgi/mod.rs @@ -15,7 +15,8 @@ pub(crate) fn init_pymodule(py: Python, module: &PyModule) -> PyResult<()> { module.add_class::()?; module.add_class::()?; module.add_class::()?; - module.add_class::()?; + module.add_class::()?; + module.add_class::()?; Ok(()) } diff --git a/src/rsgi/serve.rs b/src/rsgi/serve.rs index 95c41eb7..ed44d502 100644 --- a/src/rsgi/serve.rs +++ b/src/rsgi/serve.rs @@ -8,7 +8,7 @@ use super::http::{ use crate::conversion::{worker_http1_config_from_py, worker_http2_config_from_py}; use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig, WorkerSignal}; -#[pyclass(module = "granian._granian")] +#[pyclass(frozen, module = "granian._granian")] pub struct RSGIWorker { config: WorkerConfig, } diff --git a/src/rsgi/types.rs b/src/rsgi/types.rs index 33ba55cc..7a4f8dc3 100644 --- a/src/rsgi/types.rs +++ b/src/rsgi/types.rs @@ -16,7 +16,9 @@ use tokio_util::io::ReaderStream; use crate::http::{empty_body, response_404, HTTPResponseBody, HV_SERVER}; -#[pyclass(module = "granian._granian")] +const RSGI_PROTO_VERSION: &str = "1.3"; + +#[pyclass(frozen, module = "granian._granian")] #[derive(Clone)] pub(crate) struct RSGIHeaders { inner: HeaderMap, @@ -85,84 +87,90 @@ impl RSGIHeaders { } } -#[pyclass(module = "granian._granian")] -pub(crate) struct RSGIScope { - #[pyo3(get)] - proto: String, - http_version: Version, - #[pyo3(get)] - rsgi_version: String, - #[pyo3(get)] - scheme: String, - #[pyo3(get)] - method: String, - uri: Uri, - #[pyo3(get)] - server: String, - #[pyo3(get)] - client: String, - #[pyo3(get)] - headers: RSGIHeaders, -} +macro_rules! rsgi_scope_cls { + ($name:ident, $proto:expr) => { + #[pyclass(frozen, module = "granian._granian")] + pub(crate) struct $name { + http_version: Version, + #[pyo3(get)] + scheme: String, + #[pyo3(get)] + method: String, + uri: Uri, + #[pyo3(get)] + server: String, + #[pyo3(get)] + client: String, + #[pyo3(get)] + headers: RSGIHeaders, + } -impl RSGIScope { - pub fn new( - proto: &str, - http_version: Version, - scheme: &str, - uri: Uri, - method: &str, - server: SocketAddr, - client: SocketAddr, - headers: &HeaderMap, - ) -> Self { - Self { - proto: proto.to_string(), - http_version, - rsgi_version: "1.3".to_string(), - scheme: scheme.to_string(), - method: method.to_string(), - uri, - server: server.to_string(), - client: client.to_string(), - headers: RSGIHeaders::new(headers), + impl $name { + pub fn new( + http_version: Version, + scheme: &str, + uri: Uri, + method: &str, + server: SocketAddr, + client: SocketAddr, + headers: &HeaderMap, + ) -> Self { + Self { + http_version, + scheme: scheme.to_string(), + method: method.to_string(), + uri, + server: server.to_string(), + client: client.to_string(), + headers: RSGIHeaders::new(headers), + } + } } - } - pub fn set_proto(&mut self, value: &str) { - self.proto = value.to_string(); - } -} + #[pymethods] + impl $name { + #[getter(proto)] + fn get_proto(&self) -> &str { + $proto + } -#[pymethods] -impl RSGIScope { - #[getter(http_version)] - fn get_http_version(&self) -> &str { - match self.http_version { - Version::HTTP_10 => "1", - Version::HTTP_11 => "1.1", - Version::HTTP_2 => "2", - Version::HTTP_3 => "3", - _ => "1", - } - } + #[getter(rsgi_version)] + fn get_rsgi_version(&self) -> &str { + RSGI_PROTO_VERSION + } - #[getter(authority)] - fn get_authority(&self) -> Option { - self.uri.authority().map(Authority::to_string) - } + #[getter(http_version)] + fn get_http_version(&self) -> &str { + match self.http_version { + Version::HTTP_10 => "1", + Version::HTTP_11 => "1.1", + Version::HTTP_2 => "2", + Version::HTTP_3 => "3", + _ => "1", + } + } - #[getter(path)] - fn get_path(&self) -> Result> { - Ok(percent_decode_str(self.uri.path()).decode_utf8()?) - } + #[getter(authority)] + fn get_authority(&self) -> Option { + self.uri.authority().map(Authority::to_string) + } - #[getter(query_string)] - fn get_query_string(&self) -> &str { - self.uri.query().unwrap_or("") - } + #[getter(path)] + fn get_path(&self) -> Result> { + Ok(percent_decode_str(self.uri.path()).decode_utf8()?) + } + + #[getter(query_string)] + fn get_query_string(&self) -> &str { + self.uri.query().unwrap_or("") + } + } + }; } +rsgi_scope_cls!(RSGIHTTPScope, "http"); +rsgi_scope_cls!(RSGIWebsocketScope, "ws"); + pub(crate) enum PyResponse { Body(PyResponseBody), File(PyResponseFile), diff --git a/src/runtime.rs b/src/runtime.rs index ff463a67..6c3627f3 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -9,11 +9,14 @@ use std::{ }; use tokio::{ runtime::Builder, - sync::Notify, task::{JoinHandle, LocalSet}, }; -use super::callbacks::{PyEmptyAwaitable, PyFutureAwaitable, PyIterAwaitable}; +#[cfg(unix)] +use super::callbacks::PyFutureAwaitable; +use super::callbacks::{PyEmptyAwaitable, PyIterAwaitable}; +#[cfg(windows)] +use super::callbacks::{PyFutureDoneCallback, PyFutureResultSetter}; tokio::task_local! { static TASK_LOCALS: OnceCell; @@ -197,15 +200,12 @@ where F: Future> + Send + 'static, T: IntoPy, { - let aw = PyIterAwaitable::new(); - let py_aw = Py::new(py, aw)?; - let py_fut = py_aw.clone(); + let aw = Py::new(py, PyIterAwaitable::new())?; + let py_fut = aw.clone_ref(py); rt.spawn(async move { let result = fut.await; - Python::with_gil(move |py| { - py_aw.borrow_mut(py).set_result(result.map(|v| v.into_py(py))); - }); + aw.get().set_result(result); }); Ok(py_fut.into_ref(py)) @@ -217,6 +217,7 @@ where // It won't consume more cpu-cycles than standard asyncio implementation, // and for "long" operations it's something like 6% faster than `future_into_py_iter`. #[allow(unused_must_use)] +#[cfg(unix)] pub(crate) fn future_into_py_futlike(rt: R, py: Python, fut: F) -> PyResult<&PyAny> where R: Runtime + ContextExt + Clone, @@ -225,37 +226,64 @@ where { let task_locals = get_current_locals::(py)?; let event_loop = task_locals.event_loop(py).to_object(py); - let event_loop_aw = event_loop.clone(); - let fut_spawner = move |context: Option, cancel_tx: Arc, aw: Py| { - rt.spawn(async move { - let result = tokio::select! { - result = fut => { - result - }, - () = cancel_tx.notified() => { - Err(pyo3::exceptions::asyncio::CancelledError::new_err("Task cancelled")) - } - }; - - Python::with_gil(|py| { - if let Some(cb) = PyFutureAwaitable::set_result(aw.borrow_mut(py), result.map(|v| v.into_py(py))) { - let kwctx = pyo3::types::PyDict::new(py); - kwctx.set_item(pyo3::intern!(py, "context"), context).unwrap(); - let _ = - event_loop.call_method(py, pyo3::intern!(py, "call_soon_threadsafe"), (cb, aw), Some(kwctx)); - } - }); - }); - }; + let (aw, cancel_tx) = PyFutureAwaitable::new(event_loop).to_spawn(py)?; + let aw_ref = aw.clone_ref(py); + let py_fut = aw.clone_ref(py); + + rt.spawn(async move { + tokio::select! { + result = fut => aw.get().set_result(result, aw_ref), + () = cancel_tx.notified() => {} + } + }); + + Ok(py_fut.into_ref(py)) +} + +#[allow(unused_must_use)] +#[cfg(windows)] +pub(crate) fn future_into_py_futlike(rt: R, py: Python, fut: F) -> PyResult<&PyAny> +where + R: Runtime + ContextExt + Clone, + F: Future> + Send + 'static, + T: IntoPy, +{ + let task_locals = get_current_locals::(py)?; + let event_loop = task_locals.event_loop(py); + let event_loop_ref = event_loop.to_object(py); + let cancel_tx = Arc::new(tokio::sync::Notify::new()); + + let py_fut = event_loop.call_method0(pyo3::intern!(py, "create_future"))?; + py_fut.call_method1( + pyo3::intern!(py, "add_done_callback"), + (PyFutureDoneCallback { + cancel_tx: cancel_tx.clone(), + },), + )?; + let fut_ref = PyObject::from(py_fut); + + rt.spawn(async move { + tokio::select! { + result = fut => { + Python::with_gil(|py| { + let (cb, value) = match result { + Ok(val) => (fut_ref.getattr(py, pyo3::intern!(py, "set_result")).unwrap(), val.into_py(py)), + Err(err) => (fut_ref.getattr(py, pyo3::intern!(py, "set_exception")).unwrap(), err.into_py(py)) + }; + let _ = event_loop_ref.call_method1(py, pyo3::intern!(py, "call_soon_threadsafe"), (PyFutureResultSetter, cb, value)); + }); + }, + () = cancel_tx.notified() => {} + } + }); - let aw = PyFutureAwaitable::new(Box::new(fut_spawner), event_loop_aw); - Ok(aw.into_py(py).into_ref(py)) + Ok(py_fut) } #[allow(clippy::unnecessary_wraps)] #[inline(always)] pub(crate) fn empty_future_into_py(py: Python) -> PyResult<&PyAny> { - Ok(PyEmptyAwaitable {}.into_py(py).into_ref(py)) + Ok(PyEmptyAwaitable.into_py(py).into_ref(py)) } #[allow(unused_must_use)] diff --git a/src/tcp.rs b/src/tcp.rs index 8983a02a..57dfda64 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -9,7 +9,7 @@ use std::os::windows::io::{AsRawSocket, FromRawSocket}; use socket2::{Domain, Protocol, Socket, Type}; -#[pyclass(module = "granian._granian")] +#[pyclass(frozen, module = "granian._granian")] pub struct ListenerHolder { socket: TcpListener, } diff --git a/src/workers.rs b/src/workers.rs index 535ff17a..713ae9e1 100644 --- a/src/workers.rs +++ b/src/workers.rs @@ -1,5 +1,6 @@ use pyo3::prelude::*; use std::net::TcpListener; +use std::sync::Mutex; #[cfg(unix)] use std::os::unix::io::FromRawFd; @@ -11,9 +12,9 @@ use super::rsgi::serve::RSGIWorker; use super::tls::{load_certs as tls_load_certs, load_private_key as tls_load_pkey}; use super::wsgi::serve::WSGIWorker; -#[pyclass(module = "granian._granian")] +#[pyclass(frozen, module = "granian._granian")] pub(crate) struct WorkerSignal { - pub rx: Option>, + pub rx: Mutex>>, tx: tokio::sync::watch::Sender, } @@ -22,7 +23,10 @@ impl WorkerSignal { #[new] fn new() -> Self { let (tx, rx) = tokio::sync::watch::channel(false); - Self { rx: Some(rx), tx } + Self { + rx: Mutex::new(Some(rx)), + tx, + } } fn set(&self) { @@ -326,7 +330,7 @@ macro_rules! serve_rth { let http1_opts = self.config.http1_opts.clone(); let http2_opts = self.config.http2_opts.clone(); let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop, context); - let mut pyrx = Python::with_gil(|py| signal.borrow_mut(py).rx.take().unwrap()); + let mut pyrx = signal.get().rx.lock().unwrap().take().unwrap(); let worker_id = self.config.id; log::info!("Started worker-{}", worker_id); @@ -453,7 +457,7 @@ macro_rules! serve_rth_ssl { let http2_opts = self.config.http2_opts.clone(); let tls_cfg = self.config.tls_cfg(); let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop, context); - let mut pyrx = Python::with_gil(|py| signal.borrow_mut(py).rx.take().unwrap()); + let mut pyrx = signal.get().rx.lock().unwrap().take().unwrap(); let worker_id = self.config.id; log::info!("Started worker-{}", worker_id); @@ -581,7 +585,7 @@ macro_rules! serve_wth { log::info!("Started worker-{}", worker_id); let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop, context); - let mut pyrx = Python::with_gil(|py| signal.borrow_mut(py).rx.take().unwrap()); + let mut pyrx = signal.get().rx.lock().unwrap().take().unwrap(); let (stx, srx) = tokio::sync::watch::channel(false); let mut workers = vec![]; @@ -733,7 +737,7 @@ macro_rules! serve_wth_ssl { log::info!("Started worker-{}", worker_id); let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop, context); - let mut pyrx = Python::with_gil(|py| signal.borrow_mut(py).rx.take().unwrap()); + let mut pyrx = signal.get().rx.lock().unwrap().take().unwrap(); let (stx, srx) = tokio::sync::watch::channel(false); let mut workers = vec![]; diff --git a/src/wsgi/serve.rs b/src/wsgi/serve.rs index 99a7c333..3b2b64eb 100644 --- a/src/wsgi/serve.rs +++ b/src/wsgi/serve.rs @@ -5,7 +5,7 @@ use super::http::{handle_rtb, handle_rtt}; use crate::conversion::{worker_http1_config_from_py, worker_http2_config_from_py}; use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig, WorkerSignal}; -#[pyclass(module = "granian._granian")] +#[pyclass(frozen, module = "granian._granian")] pub struct WSGIWorker { config: WorkerConfig, } diff --git a/src/wsgi/types.rs b/src/wsgi/types.rs index 544913c0..1f8377a5 100644 --- a/src/wsgi/types.rs +++ b/src/wsgi/types.rs @@ -11,23 +11,27 @@ use pyo3::types::{PyBytes, PyDict, PyList}; use pyo3::{prelude::*, types::IntoPyDict}; use std::{ borrow::Cow, + cell::RefCell, convert::Infallible, net::{IpAddr, SocketAddr}, + sync::{Arc, Mutex}, task::{Context, Poll}, }; -use crate::http::HTTPRequest; +use crate::{conversion::BytesToPy, http::HTTPRequest}; const LINE_SPLIT: u8 = u8::from_be_bytes(*b"\n"); -#[pyclass(module = "granian._granian")] +#[pyclass(frozen, module = "granian._granian")] pub(crate) struct WSGIBody { - inner: Bytes, + inner: RefCell, } impl WSGIBody { pub fn new(body: Bytes) -> Self { - Self { inner: body } + Self { + inner: RefCell::new(body), + } } } @@ -37,70 +41,69 @@ impl WSGIBody { pyself } - fn __next__<'p>(&mut self, py: Python<'p>) -> Option<&'p PyBytes> { - match self.inner.iter().position(|&c| c == LINE_SPLIT) { - Some(next_split) => { - let bytes = self.inner.split_to(next_split); - Some(PyBytes::new(py, &bytes)) - } - _ => None, - } + fn __next__(&self) -> Option { + let mut inner = self.inner.borrow_mut(); + inner + .iter() + .position(|&c| c == LINE_SPLIT) + .map(|next_split| BytesToPy(inner.split_to(next_split))) } #[pyo3(signature = (size=None))] - fn read<'p>(&mut self, py: Python<'p>, size: Option) -> &'p PyBytes { + fn read(&self, size: Option) -> BytesToPy { match size { None => { - let bytes = self.inner.split_to(self.inner.len()); - PyBytes::new(py, &bytes[..]) + let mut inner = self.inner.borrow_mut(); + let len = inner.len(); + BytesToPy(inner.split_to(len)) } Some(size) => match size { - 0 => PyBytes::new(py, b""), + 0 => BytesToPy(Bytes::new()), size => { - let limit = self.inner.len(); + let mut inner = self.inner.borrow_mut(); + let limit = inner.len(); let rsize = if size > limit { limit } else { size }; - let bytes = self.inner.split_to(rsize); - PyBytes::new(py, &bytes[..]) + BytesToPy(inner.split_to(rsize)) } }, } } - fn readline<'p>(&mut self, py: Python<'p>) -> &'p PyBytes { - match self.inner.iter().position(|&c| c == LINE_SPLIT) { + fn readline(&self) -> BytesToPy { + let mut inner = self.inner.borrow_mut(); + match inner.iter().position(|&c| c == LINE_SPLIT) { Some(next_split) => { - let bytes = self.inner.split_to(next_split); - self.inner = self.inner.slice(1..); - PyBytes::new(py, &bytes[..]) + let bytes = inner.split_to(next_split); + *inner = inner.slice(1..); + BytesToPy(bytes) } - _ => PyBytes::new(py, b""), + _ => BytesToPy(Bytes::new()), } } #[pyo3(signature = (_hint=None))] - fn readlines<'p>(&mut self, py: Python<'p>, _hint: Option) -> &'p PyList { - let lines: Vec<&PyBytes> = self - .inner + fn readlines<'p>(&self, py: Python<'p>, _hint: Option) -> &'p PyList { + let mut inner = self.inner.borrow_mut(); + let lines: Vec<&PyBytes> = inner .split(|&c| c == LINE_SPLIT) .map(|item| PyBytes::new(py, item)) .collect(); - self.inner.clear(); + inner.clear(); PyList::new(py, lines) } } -// TODO: use Arc instead of strings? -#[pyclass(module = "granian._granian")] +#[pyclass(frozen, module = "granian._granian")] pub(crate) struct WSGIScope { http_version: Version, - scheme: String, - method: String, + scheme: Arc, + method: Arc, uri: Uri, server_ip: IpAddr, server_port: u16, client: String, - headers: HeaderMap, - body: Option, + headers: Mutex, + body: Mutex>, } impl WSGIScope { @@ -108,7 +111,7 @@ impl WSGIScope { let http_version = request.version(); let method = request.method().clone(); let uri = request.uri().clone(); - let headers = request.headers().clone(); + let headers = Mutex::new(request.headers().clone()); let body = match method { Method::HEAD | Method::GET | Method::OPTIONS => Bytes::new(), @@ -120,14 +123,14 @@ impl WSGIScope { Self { http_version, - scheme: scheme.to_string(), - method: method.to_string(), + scheme: scheme.into(), + method: method.as_str().into(), uri, server_ip: server.ip(), server_port: server.port(), client: client.to_string(), headers, - body: Some(body), + body: Mutex::new(Some(body)), } } @@ -146,62 +149,50 @@ impl WSGIScope { #[pymethods] impl WSGIScope { - fn to_environ<'p>(&mut self, py: Python<'p>, ret: &'p PyDict) -> PyResult<&'p PyDict> { - let ( - path, - query_string, - http_version, - server, - client, - scheme, - method, - content_type, - content_len, - headers, - body, - ) = py.allow_threads(|| { - let (path, query_string) = self - .uri - .path_and_query() - .map_or_else(|| ("", ""), |pq| (pq.path(), pq.query().unwrap_or(""))); - let content_type = self.headers.remove(CONTENT_TYPE); - let content_len = self.headers.remove(CONTENT_LENGTH); - let mut headers = Vec::with_capacity(self.headers.len()); - - for (key, val) in &self.headers { - headers.push(( - format!("HTTP_{}", key.as_str().replace('-', "_").to_uppercase()), - val.to_str().unwrap_or_default(), - )); - } - if !self.headers.contains_key(HOST) { - let host = self.uri.authority().map_or("", Authority::as_str); - headers.push(("HTTP_HOST".to_string(), host)); - } + fn to_environ<'p>(&self, py: Python<'p>, ret: &'p PyDict) -> PyResult<&'p PyDict> { + let (path, query_string, http_version, server, client, content_type, content_len, headers, body) = py + .allow_threads(|| { + let (path, query_string) = self + .uri + .path_and_query() + .map_or_else(|| ("", ""), |pq| (pq.path(), pq.query().unwrap_or(""))); + let mut source_headers = self.headers.lock().unwrap(); + let content_type = source_headers.remove(CONTENT_TYPE); + let content_len = source_headers.remove(CONTENT_LENGTH); + let mut headers = Vec::with_capacity(source_headers.len()); + + for (key, val) in source_headers.iter() { + headers.push(( + format!("HTTP_{}", key.as_str().replace('-', "_").to_uppercase()), + val.to_str().unwrap_or_default().to_owned(), + )); + } + if !source_headers.contains_key(HOST) { + let host = self.uri.authority().map_or("", Authority::as_str); + headers.push(("HTTP_HOST".to_string(), host.to_owned())); + } - ( - percent_decode_str(path).decode_utf8().unwrap(), - query_string, - self.py_http_version(), - (self.server_ip.to_string(), self.server_port.to_string()), - &self.client[..], - &self.scheme[..], - &self.method[..], - content_type, - content_len, - headers, - WSGIBody::new(self.body.take().unwrap()), - ) - }); + ( + percent_decode_str(path).decode_utf8().unwrap(), + query_string, + self.py_http_version(), + (self.server_ip.to_string(), self.server_port.to_string()), + &self.client[..], + content_type, + content_len, + headers, + WSGIBody::new(self.body.lock().unwrap().take().unwrap()), + ) + }); ret.set_item(pyo3::intern!(py, "SERVER_PROTOCOL"), http_version)?; ret.set_item(pyo3::intern!(py, "SERVER_NAME"), server.0)?; ret.set_item(pyo3::intern!(py, "SERVER_PORT"), server.1)?; ret.set_item(pyo3::intern!(py, "REMOTE_ADDR"), client)?; - ret.set_item(pyo3::intern!(py, "REQUEST_METHOD"), method)?; + ret.set_item(pyo3::intern!(py, "REQUEST_METHOD"), &*self.method)?; ret.set_item(pyo3::intern!(py, "PATH_INFO"), path)?; ret.set_item(pyo3::intern!(py, "QUERY_STRING"), query_string)?; - ret.set_item(pyo3::intern!(py, "wsgi.url_scheme"), scheme)?; + ret.set_item(pyo3::intern!(py, "wsgi.url_scheme"), &*self.scheme)?; ret.set_item(pyo3::intern!(py, "wsgi.input"), Py::new(py, body)?)?; if let Some(content_type) = content_type { @@ -232,6 +223,7 @@ impl WSGIResponseBodyIter { Self { inner: body } } + #[inline] fn close_inner(&self, py: Python) { let _ = self.inner.call_method0(py, pyo3::intern!(py, "close")); } @@ -241,20 +233,21 @@ impl Stream for WSGIResponseBodyIter { type Item = Result, Infallible>; fn poll_next(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Python::with_gil(|py| match self.inner.call_method0(py, pyo3::intern!(py, "__next__")) { + let ret = Python::with_gil(|py| match self.inner.call_method0(py, pyo3::intern!(py, "__next__")) { Ok(chunk_obj) => match chunk_obj.extract::>(py) { - Ok(chunk) => Poll::Ready(Some(Ok(chunk.into()))), + Ok(chunk) => Some(Ok(chunk.into())), _ => { self.close_inner(py); - Poll::Ready(None) + None } }, Err(err) => { if err.is_instance_of::(py) { self.close_inner(py); } - Poll::Ready(None) + None } - }) + }); + Poll::Ready(ret) } }