diff --git a/pymobiledevice3/tcp_forwarder.py b/pymobiledevice3/tcp_forwarder.py index f02d4bf85..08680f389 100644 --- a/pymobiledevice3/tcp_forwarder.py +++ b/pymobiledevice3/tcp_forwarder.py @@ -11,7 +11,8 @@ class TcpForwarder: MAX_FORWARDED_CONNECTIONS = 200 TIMEOUT = 1 - def __init__(self, lockdown: LockdownClient, src_port: int, dst_port: int, enable_ssl=False): + def __init__(self, lockdown: LockdownClient, src_port: int, dst_port: int, enable_ssl=False, + listening_event: threading.Event = None): self.logger = logging.getLogger(__name__) self.lockdown = lockdown self.src_port = src_port @@ -20,6 +21,7 @@ def __init__(self, lockdown: LockdownClient, src_port: int, dst_port: int, enabl self.inputs = [] self.enable_ssl = enable_ssl self.stopped = threading.Event() + self.listening_event = listening_event # dictionaries containing the required maps to transfer data between each local # socket to its remote socket and vice versa @@ -37,6 +39,8 @@ def start(self, address='0.0.0.0'): self.server_socket.setblocking(False) self.inputs = [self.server_socket] + if self.listening_event: + self.listening_event.set() while self.inputs: # will only perform the socket select on the inputs. the outputs will handled @@ -45,11 +49,13 @@ def start(self, address='0.0.0.0'): if self.stopped.is_set(): break + closed_sockets = set() for current_sock in readable: if current_sock is self.server_socket: self._handle_server_connection() else: - self._handle_data(current_sock) + if current_sock not in closed_sockets: + self._handle_data(current_sock, closed_sockets) for current_sock in exceptional: self._handle_close_or_error(current_sock) @@ -65,12 +71,14 @@ def _handle_close_or_error(self, from_sock): self.logger.info(f'connection {other_sock} was closed') - def _handle_data(self, from_sock): + def _handle_data(self, from_sock, closed_sockets): data = from_sock.recv(1024) if len(data) == 0: # no data means socket was closed self._handle_close_or_error(from_sock) + closed_sockets.add(from_sock) + closed_sockets.add(self.connections[from_sock]) return # when data is received from one end, just forward it to the other