diff --git a/xpra/net/bytestreams.py b/xpra/net/bytestreams.py index 7b3f0bf58f..4da0f1e391 100644 --- a/xpra/net/bytestreams.py +++ b/xpra/net/bytestreams.py @@ -215,7 +215,7 @@ def read(self, n): self.may_abort("read") return self._read(os.read, self._read_fd, n) - def write(self, buf): + def write(self, buf, packet_type=None): self.may_abort("write") return self._write(os.write, self._write_fd, buf) @@ -344,7 +344,7 @@ def peek(self, n : int): def read(self, n : int): return self._read(self._socket.recv, n) - def write(self, buf): + def write(self, buf, packet_type=None): return self._write(self._socket.send, buf) def close(self): diff --git a/xpra/net/protocol/socket_handler.py b/xpra/net/protocol/socket_handler.py index 7edf19a937..fea7b0e76c 100644 --- a/xpra/net/protocol/socket_handler.py +++ b/xpra/net/protocol/socket_handler.py @@ -351,7 +351,7 @@ def start_network_read_thread(): self._read_thread.start() self.idle_add(start_network_read_thread) if SEND_INVALID_PACKET: - self.timeout_add(SEND_INVALID_PACKET*1000, self.raw_write, "invalid", SEND_INVALID_PACKET_DATA) + self.timeout_add(SEND_INVALID_PACKET*1000, self.raw_write, SEND_INVALID_PACKET_DATA) def send_disconnect(self, reasons, done_callback=None): @@ -482,7 +482,7 @@ def _add_chunks_to_queue(self, packet_type, chunks, items[0] = frame_header + item0 else: items.insert(0, frame_header) - self.raw_write(packet_type, items, start_send_cb, end_send_cb, fail_cb, synchronous, more) + self.raw_write(items, packet_type, start_send_cb, end_send_cb, fail_cb, synchronous, more) def make_xpra_header(self, _packet_type, proto_flags, level, index, payload_size) -> bytes: return pack_header(proto_flags, level, index, payload_size) @@ -496,12 +496,12 @@ def start_write_thread(self): assert not self._write_thread, "write thread already started" self._write_thread = start_thread(self._write_thread_loop, "write", daemon=True) - def raw_write(self, packet_type, items, start_cb=None, end_cb=None, fail_cb=None, synchronous=True, more=False): + def raw_write(self, items, packet_type=None, start_cb=None, end_cb=None, fail_cb=None, synchronous=True, more=False): """ Warning: this bypasses the compression and packet encoder! """ if self._write_thread is None: log("raw_write for %s, starting write thread", packet_type) self.start_write_thread() - self._write_queue.put((items, start_cb, end_cb, fail_cb, synchronous, more)) + self._write_queue.put((items, packet_type, start_cb, end_cb, fail_cb, synchronous, more)) def enable_default_encoder(self): @@ -710,7 +710,7 @@ def _write(self): return False return self.write_items(*items) - def write_items(self, buf_data, start_cb=None, end_cb=None, fail_cb=None, synchronous=True, more=False): + def write_items(self, buf_data, packet_type=None, start_cb=None, end_cb=None, fail_cb=None, synchronous=True, more=False): conn = self._conn if not conn: return False @@ -729,7 +729,7 @@ def write_items(self, buf_data, start_cb=None, end_cb=None, fail_cb=None, synchr except Exception: if not self._closed: log.error(f"Error on write start callback {start_cb}", exc_info=True) - self.write_buffers(buf_data, fail_cb, synchronous) + self.write_buffers(buf_data, packet_type, fail_cb, synchronous) try: if len(buf_data)>1: conn.set_cork(False) @@ -747,13 +747,13 @@ def write_items(self, buf_data, start_cb=None, end_cb=None, fail_cb=None, synchr log.error(f"Error on write end callback {end_cb}", exc_info=True) return True - def write_buffers(self, buf_data, _fail_cb, _synchronous): + def write_buffers(self, buf_data, packet_type, _fail_cb, _synchronous): con = self._conn if not con: return for buf in buf_data: while buf and not self._closed: - written = self.con_write(con, buf) + written = self.con_write(con, buf, packet_type) #example test code, for sending small chunks very slowly: #written = con.write(buf[:1024]) #import time @@ -763,8 +763,8 @@ def write_buffers(self, buf_data, _fail_cb, _synchronous): self.output_raw_packetcount += 1 self.output_packetcount += 1 - def con_write(self, con, buf): - return con.write(buf) + def con_write(self, con, buf, packet_type): + return con.write(buf, packet_type) def _read_thread_loop(self): @@ -1194,7 +1194,7 @@ def packet_queued(*_args): start_send_cb=None, end_send_cb=packet_queued, synchronous=False, more=False) else: - self.raw_write("flush-then-close", (last_packet, )) + self.raw_write((last_packet, ), "flush-then-close") #just in case wait_for_packet_sent never fires: self.timeout_add(5*1000, close_and_release) diff --git a/xpra/net/quic/client.py b/xpra/net/quic/client.py index 67a4ca3055..e63bea733f 100644 --- a/xpra/net/quic/client.py +++ b/xpra/net/quic/client.py @@ -17,6 +17,7 @@ DataReceived, H3Event, HeadersReceived, + PushPromiseReceived, ) from aioquic.tls import SessionTicket from aioquic.quic.logger import QuicLogger @@ -28,7 +29,6 @@ from xpra.net.quic.asyncio_thread import get_threaded_loop from xpra.net.quic.common import USER_AGENT, binary_headers from xpra.util import ellipsizer, envbool -from xpra.os_util import memoryview_to_bytes from xpra.log import Logger log = Logger("quic") @@ -60,22 +60,23 @@ def __init__(self, connection : HttpConnection, stream_id: int, transmit: Callab def flush_writes(self): #flush the buffered writes: - while self.write_buffer.qsize()>0: - buf = self.write_buffer.get() - self.connection.send_data(self.stream_id, memoryview_to_bytes(buf), end_stream=False) - self.transmit() - self.write_buffer = None - - def write(self, buf): - log(f"write(%s) {len(buf)} bytes", ellipsizer(buf)) + try: + while self.write_buffer.qsize()>0: + self.stream_write(*self.write_buffer.get()) + finally: + self.transmit() + self.write_buffer = None + + def write(self, buf, packet_type=None): + log(f"write(%s, %s) {len(buf)} bytes", ellipsizer(buf), packet_type) if self.write_buffer is not None: #buffer it until we are connected and call flush_writes() - self.write_buffer.put(buf) + self.write_buffer.put((buf, packet_type)) return len(buf) - return super().write(buf) + return super().write(buf, packet_type) def http_event_received(self, event: H3Event) -> None: - log("http_event_received(%s)", event) + log("http_event_received(%s)", ellipsizer(event)) if isinstance(event, HeadersReceived): for header, value in event.headers: if header == b"sec-websocket-protocol": @@ -87,6 +88,10 @@ def http_event_received(self, event: H3Event) -> None: self.accepted = True self.flush_writes() return + if isinstance(event, PushPromiseReceived): + log(f"PushPromiseReceived: {event}") + log(f"PushPromiseReceived headers: {event.headers}") + return super().http_event_received(event) @@ -94,6 +99,7 @@ class WebSocketClient(QuicConnectionProtocol): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._http: Optional[HttpConnection] = None + self._push_types: Dict[str, int] = {} self._websockets: Dict[int, ClientWebSocketConnection] = {} if self._quic.configuration.alpn_protocols[0].startswith("hq-"): self._http = H0Connection(self._quic) @@ -121,16 +127,29 @@ def quic_event_received(self, event: QuicEvent) -> None: self.http_event_received(http_event) def http_event_received(self, event: H3Event) -> None: - if isinstance(event, (HeadersReceived, DataReceived)): - stream_id = event.stream_id - if stream_id in self._websockets: - # websocket - websocket : ClientWebSocketConnection = self._websockets[stream_id] - websocket.http_event_received(event) - else: - log.warn(f"Warning: unexpected websocket stream id: {stream_id}") - else: + if not isinstance(event, (HeadersReceived, DataReceived, PushPromiseReceived)): log.warn(f"Warning: unexpected http event type: {event}") + return + stream_id = event.stream_id + websocket : Optional[ClientWebSocketConnection] = self._websockets.get(stream_id) + if not websocket: + #perhaps this is a new substream? + sub = -1 + hdict = {} + if isinstance(event, HeadersReceived): + hdict = dict((k.decode(),v.decode()) for k,v in event.headers) + sub = int(hdict.get("substream", -1)) + if sub<0: + log.warn(f"Warning: unexpected websocket stream id: {stream_id} in {event}") + return + websocket = self._websockets.get(sub) + if not websocket: + log.warn(f"Warning: stream {sub} not found in {self._websockets}") + return + subtype = hdict.get("stream-type") + log.info(f"new substream {stream_id} for {subtype}") + self._websockets[stream_id] = websocket + websocket.http_event_received(event) def quic_connect(host : str, port : int, path : str, diff --git a/xpra/net/quic/connection.py b/xpra/net/quic/connection.py index 3d53b562bc..e3db9062bd 100644 --- a/xpra/net/quic/connection.py +++ b/xpra/net/quic/connection.py @@ -32,6 +32,7 @@ def __init__(self, connection: HttpConnection, stream_id: int, transmit: Callabl self.transmit: Callable[[], None] = transmit self.accepted : bool = False self.closed : bool = False + self._packet_type_streams = {} def __repr__(self): return f"XpraQuicConnection<{self.stream_id}>" @@ -70,29 +71,71 @@ def http_event_received(self, event: H3Event) -> None: def close(self): log("XpraQuicConnection.close()") if not self.closed: + self.closed = True self.send_close() Connection.close(self) + self._packet_type_streams = {} def send_close(self, code : int = 1000, reason : str = ""): - self.closed = True if self.accepted: data = close_packet(code, reason) - self.write(data) + self.write("close", data) else: - self.send_headers({":status" : code}) + self.send_headers(self.stream_id, headers={":status" : code}) self.transmit() - def send_headers(self, headers : dict): - #HttpConnection takes a pair of byte strings: - self.connection.send_headers(stream_id=self.stream_id, headers=binary_headers(headers), end_stream=self.closed) + def send_headers(self, stream_id : int, headers : dict): + self.connection.send_headers( + stream_id=stream_id, + headers=binary_headers(headers), + end_stream=self.closed) - def write(self, buf): - log("XpraQuicConnection.write(%s)", ellipsizer(buf)) + def write(self, buf, packet_type=None): + log("XpraQuicConnection.write(%s, %s)", ellipsizer(buf), packet_type) + try: + return self.stream_write(buf, packet_type) + finally: + self.transmit() + + def stream_write(self, buf, packet_type): data = memoryview_to_bytes(buf) - self.connection.send_data(stream_id=self.stream_id, data=data, end_stream=self.closed) - self.transmit() + stream_id = self.get_packet_stream_id(packet_type) + log("XpraQuicConnection.stream_write(%s, %s) using stream id %s", + ellipsizer(buf), packet_type, stream_id) + self.connection.send_data(stream_id=stream_id, data=data, end_stream=self.closed) return len(buf) + def get_packet_stream_id(self, packet_type): + stream_type = { + "sound-data" : "sound", + "ping" : "ping", + "ping-echo" : "ping", + }.get(packet_type) + stream_id = self._packet_type_streams.setdefault(stream_type, self.stream_id) + if stream_type and stream_id==self.stream_id: + if self.closed: + raise RuntimeError(f"cannot send {packet_type} after connection is closed") + log(f"new quic stream for {packet_type}") + #should use more "correct" values here + #(we don't need those headers, + # but the client would drop the packet without them..) + headers = binary_headers({ + ":method" : "foo", + ":scheme" : "https", + ":authority" : "bar", + ":path" : "/", + }) + stream_id = self.connection.send_push_promise(self.stream_id, headers) + log.error(f"new stream: {stream_id}") + self._packet_type_streams[stream_type] = stream_id + self.send_headers(stream_id=stream_id, headers={ + ":status" : 200, + "substream" : self.stream_id, + "stream-type" : stream_type, + }) + return stream_id + + def read(self, n): log("XpraQuicConnection.read(%s)", n) return self.read_queue.get() diff --git a/xpra/net/quic/websocket.py b/xpra/net/quic/websocket.py index 03e1a62925..5f09a8e70f 100644 --- a/xpra/net/quic/websocket.py +++ b/xpra/net/quic/websocket.py @@ -39,16 +39,16 @@ def http_event_received(self, event: H3Event) -> None: self.close() return log.info("websocket request at %s", self.scope.get("path", "/")) - self.send_accept() + self.accepted = True + self.send_accept(self.stream_id) + self.transmit() return super().http_event_received(event) - def send_accept(self): - self.accepted = True - self.send_headers({ + def send_accept(self, stream_id : int): + self.send_headers(stream_id=stream_id, headers={ ":status" : 200, "server" : SERVER_NAME, "date" : http_date(), "sec-websocket-protocol" : "xpra", }) - self.transmit()