Skip to content
This repository has been archived by the owner on Nov 23, 2017. It is now read-only.

Raise RuntimeError when transport's FD is used with add_reader etc #420

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 78 additions & 48 deletions asyncio/selector_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import functools
import socket
import warnings
import weakref
try:
import ssl
except ImportError: # pragma: no cover
Expand Down Expand Up @@ -64,6 +65,7 @@ def __init__(self, selector=None):
logger.debug('Using selector: %s', selector.__class__.__name__)
self._selector = selector
self._make_self_pipe()
self._transports = weakref.WeakValueDictionary()

def _make_socket_transport(self, sock, protocol, waiter=None, *,
extra=None, server=None):
Expand Down Expand Up @@ -115,7 +117,7 @@ def _socketpair(self):
raise NotImplementedError

def _close_self_pipe(self):
self.remove_reader(self._ssock.fileno())
self._remove_reader(self._ssock.fileno())
self._ssock.close()
self._ssock = None
self._csock.close()
Expand All @@ -128,7 +130,7 @@ def _make_self_pipe(self):
self._ssock.setblocking(False)
self._csock.setblocking(False)
self._internal_fds += 1
self.add_reader(self._ssock.fileno(), self._read_from_self)
self._add_reader(self._ssock.fileno(), self._read_from_self)

def _process_self_data(self, data):
pass
Expand Down Expand Up @@ -163,8 +165,8 @@ def _write_to_self(self):

def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100):
self.add_reader(sock.fileno(), self._accept_connection,
protocol_factory, sock, sslcontext, server, backlog)
self._add_reader(sock.fileno(), self._accept_connection,
protocol_factory, sock, sslcontext, server, backlog)

def _accept_connection(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100):
Expand Down Expand Up @@ -194,7 +196,7 @@ def _accept_connection(self, protocol_factory, sock,
'exception': exc,
'socket': sock,
})
self.remove_reader(sock.fileno())
self._remove_reader(sock.fileno())
self.call_later(constants.ACCEPT_RETRY_DELAY,
self._start_serving,
protocol_factory, sock, sslcontext, server,
Expand Down Expand Up @@ -244,8 +246,17 @@ def _accept_connection2(self, protocol_factory, conn, extra,
context['transport'] = transport
self.call_exception_handler(context)

def add_reader(self, fd, callback, *args):
"""Add a reader callback."""
def _ensure_fd_no_transport(self, fd):
try:
transport = self._transports[fd]
except KeyError:
pass
else:
raise RuntimeError(
'File descriptor {!r} is used by transport {!r}'.format(
fd, transport))

def _add_reader(self, fd, callback, *args):
self._check_closed()
handle = events.Handle(callback, args, self)
try:
Expand All @@ -260,8 +271,7 @@ def add_reader(self, fd, callback, *args):
if reader is not None:
reader.cancel()

def remove_reader(self, fd):
"""Remove a reader callback."""
def _remove_reader(self, fd):
if self.is_closed():
return False
try:
Expand All @@ -282,8 +292,7 @@ def remove_reader(self, fd):
else:
return False

def add_writer(self, fd, callback, *args):
"""Add a writer callback.."""
def _add_writer(self, fd, callback, *args):
self._check_closed()
handle = events.Handle(callback, args, self)
try:
Expand All @@ -298,7 +307,7 @@ def add_writer(self, fd, callback, *args):
if writer is not None:
writer.cancel()

def remove_writer(self, fd):
def _remove_writer(self, fd):
"""Remove a writer callback."""
if self.is_closed():
return False
Expand All @@ -321,6 +330,26 @@ def remove_writer(self, fd):
else:
return False

def add_reader(self, fd, callback, *args):
"""Add a reader callback."""
self._ensure_fd_no_transport(fd)
return self._add_reader(fd, callback, *args)

def remove_reader(self, fd):
"""Remove a reader callback."""
self._ensure_fd_no_transport(fd)
return self._remove_reader(fd)

def add_writer(self, fd, callback, *args):
"""Add a writer callback.."""
self._ensure_fd_no_transport(fd)
return self._add_writer(fd, callback, *args)

def remove_writer(self, fd):
"""Remove a writer callback."""
self._ensure_fd_no_transport(fd)
return self._remove_writer(fd)

def sock_recv(self, sock, n):
"""Receive data from the socket.

Expand Down Expand Up @@ -494,17 +523,17 @@ def _process_events(self, event_list):
fileobj, (reader, writer) = key.fileobj, key.data
if mask & selectors.EVENT_READ and reader is not None:
if reader._cancelled:
self.remove_reader(fileobj)
self._remove_reader(fileobj)
else:
self._add_callback(reader)
if mask & selectors.EVENT_WRITE and writer is not None:
if writer._cancelled:
self.remove_writer(fileobj)
self._remove_writer(fileobj)
else:
self._add_callback(writer)

def _stop_serving(self, sock):
self.remove_reader(sock.fileno())
self._remove_reader(sock.fileno())
sock.close()


Expand Down Expand Up @@ -539,6 +568,7 @@ def __init__(self, loop, sock, protocol, extra=None, server=None):
self._closing = False # Set when close() called.
if self._server is not None:
self._server._attach()
loop._transports[self._sock_fd] = self

def __repr__(self):
info = [self.__class__.__name__]
Expand Down Expand Up @@ -584,10 +614,10 @@ def close(self):
if self._closing:
return
self._closing = True
self._loop.remove_reader(self._sock_fd)
self._loop._remove_reader(self._sock_fd)
if not self._buffer:
self._conn_lost += 1
self._loop.remove_writer(self._sock_fd)
self._loop._remove_writer(self._sock_fd)
self._loop.call_soon(self._call_connection_lost, None)

# On Python 3.3 and older, objects with a destructor part of a reference
Expand Down Expand Up @@ -618,10 +648,10 @@ def _force_close(self, exc):
return
if self._buffer:
self._buffer.clear()
self._loop.remove_writer(self._sock_fd)
self._loop._remove_writer(self._sock_fd)
if not self._closing:
self._closing = True
self._loop.remove_reader(self._sock_fd)
self._loop._remove_reader(self._sock_fd)
self._conn_lost += 1
self._loop.call_soon(self._call_connection_lost, exc)

Expand Down Expand Up @@ -658,7 +688,7 @@ def __init__(self, loop, sock, protocol, waiter=None,

self._loop.call_soon(self._protocol.connection_made, self)
# only start reading when connection_made() has been called
self._loop.call_soon(self._loop.add_reader,
self._loop.call_soon(self._loop._add_reader,
self._sock_fd, self._read_ready)
if waiter is not None:
# only wake up the waiter when connection_made() has been called
Expand All @@ -671,7 +701,7 @@ def pause_reading(self):
if self._paused:
raise RuntimeError('Already paused')
self._paused = True
self._loop.remove_reader(self._sock_fd)
self._loop._remove_reader(self._sock_fd)
if self._loop.get_debug():
logger.debug("%r pauses reading", self)

Expand All @@ -681,7 +711,7 @@ def resume_reading(self):
self._paused = False
if self._closing:
return
self._loop.add_reader(self._sock_fd, self._read_ready)
self._loop._add_reader(self._sock_fd, self._read_ready)
if self._loop.get_debug():
logger.debug("%r resumes reading", self)

Expand All @@ -705,7 +735,7 @@ def _read_ready(self):
# We're keeping the connection open so the
# protocol can write more, but we still can't
# receive more, so remove the reader callback.
self._loop.remove_reader(self._sock_fd)
self._loop._remove_reader(self._sock_fd)
else:
self.close()

Expand Down Expand Up @@ -738,7 +768,7 @@ def write(self, data):
if not data:
return
# Not all was written; register write handler.
self._loop.add_writer(self._sock_fd, self._write_ready)
self._loop._add_writer(self._sock_fd, self._write_ready)

# Add it to the buffer.
self._buffer.extend(data)
Expand All @@ -754,15 +784,15 @@ def _write_ready(self):
except (BlockingIOError, InterruptedError):
pass
except Exception as exc:
self._loop.remove_writer(self._sock_fd)
self._loop._remove_writer(self._sock_fd)
self._buffer.clear()
self._fatal_error(exc, 'Fatal write error on socket transport')
else:
if n:
del self._buffer[:n]
self._maybe_resume_protocol() # May append to buffer.
if not self._buffer:
self._loop.remove_writer(self._sock_fd)
self._loop._remove_writer(self._sock_fd)
if self._closing:
self._call_connection_lost(None)
elif self._eof:
Expand Down Expand Up @@ -833,28 +863,28 @@ def _on_handshake(self, start_time):
try:
self._sock.do_handshake()
except ssl.SSLWantReadError:
self._loop.add_reader(self._sock_fd,
self._on_handshake, start_time)
self._loop._add_reader(self._sock_fd,
self._on_handshake, start_time)
return
except ssl.SSLWantWriteError:
self._loop.add_writer(self._sock_fd,
self._on_handshake, start_time)
self._loop._add_writer(self._sock_fd,
self._on_handshake, start_time)
return
except BaseException as exc:
if self._loop.get_debug():
logger.warning("%r: SSL handshake failed",
self, exc_info=True)
self._loop.remove_reader(self._sock_fd)
self._loop.remove_writer(self._sock_fd)
self._loop._remove_reader(self._sock_fd)
self._loop._remove_writer(self._sock_fd)
self._sock.close()
self._wakeup_waiter(exc)
if isinstance(exc, Exception):
return
else:
raise

self._loop.remove_reader(self._sock_fd)
self._loop.remove_writer(self._sock_fd)
self._loop._remove_reader(self._sock_fd)
self._loop._remove_writer(self._sock_fd)

peercert = self._sock.getpeercert()
if not hasattr(self._sslcontext, 'check_hostname'):
Expand Down Expand Up @@ -882,7 +912,7 @@ def _on_handshake(self, start_time):

self._read_wants_write = False
self._write_wants_read = False
self._loop.add_reader(self._sock_fd, self._read_ready)
self._loop._add_reader(self._sock_fd, self._read_ready)
self._protocol_connected = True
self._loop.call_soon(self._protocol.connection_made, self)
# only wake up the waiter when connection_made() has been called
Expand All @@ -904,7 +934,7 @@ def pause_reading(self):
if self._paused:
raise RuntimeError('Already paused')
self._paused = True
self._loop.remove_reader(self._sock_fd)
self._loop._remove_reader(self._sock_fd)
if self._loop.get_debug():
logger.debug("%r pauses reading", self)

Expand All @@ -914,7 +944,7 @@ def resume_reading(self):
self._paused = False
if self._closing:
return
self._loop.add_reader(self._sock_fd, self._read_ready)
self._loop._add_reader(self._sock_fd, self._read_ready)
if self._loop.get_debug():
logger.debug("%r resumes reading", self)

Expand All @@ -926,16 +956,16 @@ def _read_ready(self):
self._write_ready()

if self._buffer:
self._loop.add_writer(self._sock_fd, self._write_ready)
self._loop._add_writer(self._sock_fd, self._write_ready)

try:
data = self._sock.recv(self.max_size)
except (BlockingIOError, InterruptedError, ssl.SSLWantReadError):
pass
except ssl.SSLWantWriteError:
self._read_wants_write = True
self._loop.remove_reader(self._sock_fd)
self._loop.add_writer(self._sock_fd, self._write_ready)
self._loop._remove_reader(self._sock_fd)
self._loop._add_writer(self._sock_fd, self._write_ready)
except Exception as exc:
self._fatal_error(exc, 'Fatal read error on SSL transport')
else:
Expand All @@ -960,7 +990,7 @@ def _write_ready(self):
self._read_ready()

if not (self._paused or self._closing):
self._loop.add_reader(self._sock_fd, self._read_ready)
self._loop._add_reader(self._sock_fd, self._read_ready)

if self._buffer:
try:
Expand All @@ -969,10 +999,10 @@ def _write_ready(self):
n = 0
except ssl.SSLWantReadError:
n = 0
self._loop.remove_writer(self._sock_fd)
self._loop._remove_writer(self._sock_fd)
self._write_wants_read = True
except Exception as exc:
self._loop.remove_writer(self._sock_fd)
self._loop._remove_writer(self._sock_fd)
self._buffer.clear()
self._fatal_error(exc, 'Fatal write error on SSL transport')
return
Expand All @@ -983,7 +1013,7 @@ def _write_ready(self):
self._maybe_resume_protocol() # May append to buffer.

if not self._buffer:
self._loop.remove_writer(self._sock_fd)
self._loop._remove_writer(self._sock_fd)
if self._closing:
self._call_connection_lost(None)

Expand All @@ -1001,7 +1031,7 @@ def write(self, data):
return

if not self._buffer:
self._loop.add_writer(self._sock_fd, self._write_ready)
self._loop._add_writer(self._sock_fd, self._write_ready)

# Add it to the buffer.
self._buffer.extend(data)
Expand All @@ -1021,7 +1051,7 @@ def __init__(self, loop, sock, protocol, address=None,
self._address = address
self._loop.call_soon(self._protocol.connection_made, self)
# only start reading when connection_made() has been called
self._loop.call_soon(self._loop.add_reader,
self._loop.call_soon(self._loop._add_reader,
self._sock_fd, self._read_ready)
if waiter is not None:
# only wake up the waiter when connection_made() has been called
Expand Down Expand Up @@ -1071,7 +1101,7 @@ def sendto(self, data, addr=None):
self._sock.sendto(data, addr)
return
except (BlockingIOError, InterruptedError):
self._loop.add_writer(self._sock_fd, self._sendto_ready)
self._loop._add_writer(self._sock_fd, self._sendto_ready)
except OSError as exc:
self._protocol.error_received(exc)
return
Expand Down Expand Up @@ -1105,6 +1135,6 @@ def _sendto_ready(self):

self._maybe_resume_protocol() # May append to buffer.
if not self._buffer:
self._loop.remove_writer(self._sock_fd)
self._loop._remove_writer(self._sock_fd)
if self._closing:
self._call_connection_lost(None)
Loading