diff --git a/messaging/impl_zmq.cc b/messaging/impl_zmq.cc index c47fa94aa3eb7e..c9f124beb9ca41 100644 --- a/messaging/impl_zmq.cc +++ b/messaging/impl_zmq.cc @@ -89,7 +89,7 @@ Message * ZMQSubSocket::receive(bool non_blocking){ // Make a copy to ensure the data is aligned r = new ZMQMessage; r->init((char*)zmq_msg_data(&msg), zmq_msg_size(&msg)); - } + } zmq_msg_close(&msg); return r; diff --git a/messaging/messaging.pxd b/messaging/messaging.pxd index 5e3da7241544cf..f6fd931273eaf2 100644 --- a/messaging/messaging.pxd +++ b/messaging/messaging.pxd @@ -39,4 +39,4 @@ cdef extern from "messaging.hpp": @staticmethod Poller * create() void registerSocket(SubSocket *) - vector[SubSocket*] poll(int) + vector[SubSocket*] poll(int) nogil diff --git a/messaging/messaging_pyx.pyx b/messaging/messaging_pyx.pyx index 78b928ef8a9c50..f6dfaa7e24bd00 100644 --- a/messaging/messaging_pyx.pyx +++ b/messaging/messaging_pyx.pyx @@ -19,6 +19,10 @@ cdef class Context: def __cinit__(self): self.context = cppContext.create() + def term(self): + del self.context + self.context = NULL + def __dealloc__(self): pass # Deleting the context will hang if sockets are still active @@ -43,8 +47,11 @@ cdef class Poller: def poll(self, timeout): sockets = [] + cdef int t = timeout + + with nogil: + result = self.poller.poll(t) - result = self.poller.poll(timeout) for s in result: socket = SubSocket() socket.setPtr(s) diff --git a/messaging/tests/test_poller.py b/messaging/tests/test_poller.py index 609989b584cad5..0f343fb89f0a66 100644 --- a/messaging/tests/test_poller.py +++ b/messaging/tests/test_poller.py @@ -1,24 +1,24 @@ import unittest +import os import time import cereal.messaging as messaging -from multiprocessing import Process, Pipe +import concurrent.futures -def poller(pipe): +def poller(): context = messaging.Context() + p = messaging.Poller() + sub = messaging.SubSocket() sub.connect(context, 'controlsState') - - p = messaging.Poller() p.registerSocket(sub) - while True: - pipe.recv() + socks = p.poll(1000) + r = [s.receive(non_blocking=True) for s in socks] - socks = p.poll(1000) - pipe.send([s.receive(non_blocking=True) for s in socks]) + return r class TestPoller(unittest.TestCase): @@ -28,19 +28,44 @@ def test_poll_once(self): pub = messaging.PubSocket() pub.connect(context, 'controlsState') - pipe, pipe_child = Pipe() - proc = Process(target=poller, args=(pipe_child,)) - proc.start() + with concurrent.futures.ThreadPoolExecutor() as e: + poll = e.submit(poller) + + time.sleep(0.1) # Slow joiner syndrome + + # Send message + pub.send("a") + + # Wait for poll result + result = poll.result() + + del pub + context.term() + + self.assertEqual(result, [b"a"]) + + @unittest.skipIf(os.environ.get('MSGQ'), "fails under msgq") + def test_poll_and_create_many_subscribers(self): + context = messaging.Context() + + pub = messaging.PubSocket() + pub.connect(context, 'controlsState') + + with concurrent.futures.ThreadPoolExecutor() as e: + poll = e.submit(poller) - time.sleep(.1) + time.sleep(0.1) # Slow joiner syndrome + c = messaging.Context() + for _ in range(10): + messaging.SubSocket().connect(c, 'controlsState') - # Start poll - pipe.send("go") + # Send message + pub.send("a") - # Send message - pub.send("a") + # Wait for poll result + result = poll.result() - result = pipe.recv() - proc.kill() + del pub + context.term() self.assertEqual(result, [b"a"])