Skip to content

Commit daee92e

Browse files
committed
Merge remote-tracking branch 'origin/pr/144'
* origin/pr/144: Fix flaky qrexec agent tests Check for dom0 messages in more agent tests
2 parents f3a5784 + dfd804f commit daee92e

File tree

1 file changed

+57
-170
lines changed

1 file changed

+57
-170
lines changed

qrexec/tests/socket/agent.py

+57-170
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,30 @@ class TestAgentBase(unittest.TestCase):
4747
target_domain = 43
4848
target_port = 1024
4949

50+
def check_dom0(self, dom0):
51+
self.assertEqual(
52+
dom0.recv_message(),
53+
(
54+
qrexec.MSG_CONNECTION_TERMINATED,
55+
struct.pack("<LL", self.target_domain, self.target_port),
56+
),
57+
)
58+
59+
def assertExpectedStdout(self, target, expected_stdout: bytes, *, exit_code=0):
60+
messages = util.sort_messages(target.recv_all_messages())
61+
self.assertListEqual(messages[-3:],
62+
[
63+
(qrexec.MSG_DATA_STDOUT, b""),
64+
(qrexec.MSG_DATA_STDERR, b""),
65+
(qrexec.MSG_DATA_EXIT_CODE, struct.pack("<L", exit_code))
66+
])
67+
stdout_entries = []
68+
for msg_type, msg_body in messages[:-3]:
69+
# messages before last are not empty, hence truthy
70+
self.assertTrue(msg_body)
71+
self.assertEqual(msg_type, qrexec.MSG_DATA_STDOUT)
72+
stdout_entries.append(msg_body)
73+
5074
def setUp(self):
5175
self.tempdir = tempfile.mkdtemp()
5276
os.mkdir(os.path.join(self.tempdir, "local-rpc"))
@@ -157,14 +181,7 @@ def test_just_exec(self):
157181
lambda: os.path.exists(os.path.join(self.tempdir, "new_file")),
158182
"file created",
159183
)
160-
161-
self.assertEqual(
162-
dom0.recv_message(),
163-
(
164-
qrexec.MSG_CONNECTION_TERMINATED,
165-
struct.pack("<LL", self.target_domain, self.target_port),
166-
),
167-
)
184+
self.check_dom0(dom0)
168185

169186
def test_exec_cmdline(self):
170187
self.start_agent()
@@ -186,24 +203,8 @@ def test_exec_cmdline(self):
186203

187204
target.send_message(qrexec.MSG_DATA_STDIN, b"")
188205

189-
messages = target.recv_all_messages()
190-
self.assertListEqual(
191-
util.sort_messages(messages),
192-
[
193-
(qrexec.MSG_DATA_STDOUT, b"Hello world\n"),
194-
(qrexec.MSG_DATA_STDOUT, b""),
195-
(qrexec.MSG_DATA_STDERR, b""),
196-
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
197-
],
198-
)
199-
200-
self.assertEqual(
201-
dom0.recv_message(),
202-
(
203-
qrexec.MSG_CONNECTION_TERMINATED,
204-
struct.pack("<LL", self.target_domain, self.target_port),
205-
),
206-
)
206+
self.assertExpectedStdout(target, b"Hello world\n")
207+
self.check_dom0(dom0)
207208

208209
def test_trigger_service(self):
209210
self.start_agent()
@@ -229,13 +230,7 @@ def test_trigger_service(self):
229230
)
230231

231232
client.close()
232-
self.assertEqual(
233-
dom0.recv_message(),
234-
(
235-
qrexec.MSG_CONNECTION_TERMINATED,
236-
struct.pack("<LL", self.target_domain, self.target_port),
237-
),
238-
)
233+
self.check_dom0(dom0)
239234

240235
def test_trigger_service_refused(self):
241236
self.start_agent()
@@ -310,15 +305,6 @@ def execute_qubesrpc(self, service: str, src_domain_name: str):
310305
target.handshake()
311306
return target, dom0
312307

313-
def check_dom0(self, dom0):
314-
self.assertEqual(
315-
dom0.recv_message(),
316-
(
317-
qrexec.MSG_CONNECTION_TERMINATED,
318-
struct.pack("<LL", self.target_domain, self.target_port),
319-
),
320-
)
321-
322308
def make_executable_service(self, *args):
323309
util.make_executable_service(self.tempdir, *args)
324310

@@ -332,18 +318,10 @@ def test_exec_service(self):
332318
echo "arg: $1, remote domain: $QREXEC_REMOTE_DOMAIN"
333319
""",
334320
)
335-
target, _ = self.execute_qubesrpc("qubes.Service+arg", "domX")
321+
target, dom0 = self.execute_qubesrpc("qubes.Service+arg", "domX")
336322
target.send_message(qrexec.MSG_DATA_STDIN, b"")
337-
messages = target.recv_all_messages()
338-
self.assertListEqual(
339-
util.sort_messages(messages),
340-
[
341-
(qrexec.MSG_DATA_STDOUT, b"arg: arg, remote domain: domX\n"),
342-
(qrexec.MSG_DATA_STDOUT, b""),
343-
(qrexec.MSG_DATA_STDERR, b""),
344-
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
345-
],
346-
)
323+
self.assertExpectedStdout(target, b"arg: arg, remote domain: domX\n")
324+
self.check_dom0(dom0)
347325

348326
def test_exec_service_keyword(self):
349327
util.make_executable_service(
@@ -361,20 +339,11 @@ def test_exec_service_keyword(self):
361339
)
362340
target, dom0 = self.execute_qubesrpc("qubes.Service", "domX")
363341
target.send_message(qrexec.MSG_DATA_STDIN, b"")
364-
messages = target.recv_all_messages()
365-
self.assertListEqual(
366-
util.sort_messages(messages),
367-
[
368-
(qrexec.MSG_DATA_STDOUT, b"""arg: , remote domain: domX
342+
self.assertExpectedStdout(target, b"""arg: , remote domain: domX
369343
target name: NONAME
370344
target keyword: NOKEYWORD
371345
target type: ''
372-
"""),
373-
(qrexec.MSG_DATA_STDOUT, b""),
374-
(qrexec.MSG_DATA_STDERR, b""),
375-
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
376-
],
377-
)
346+
""")
378347
self.check_dom0(dom0)
379348

380349
def test_exec_service_with_config(self):
@@ -395,16 +364,7 @@ def test_exec_service_with_config(self):
395364
""")
396365
target, dom0 = self.execute_qubesrpc("qubes.Service+arg", "domX")
397366
target.send_message(qrexec.MSG_DATA_STDIN, b"")
398-
messages = target.recv_all_messages()
399-
self.assertListEqual(
400-
util.sort_messages(messages),
401-
[
402-
(qrexec.MSG_DATA_STDOUT, b"arg: arg, remote domain: domX\n"),
403-
(qrexec.MSG_DATA_STDOUT, b""),
404-
(qrexec.MSG_DATA_STDERR, b""),
405-
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
406-
],
407-
)
367+
self.assertExpectedStdout(target, b"arg: arg, remote domain: domX\n")
408368
self.check_dom0(dom0)
409369

410370
def test_wait_for_session(self):
@@ -468,20 +428,9 @@ def _test_wait_for_session(self, config_name, service_name="qubes.Service", argu
468428
# Do not send EOF. Shell read doesn't need it, and this checks that
469429
# qrexec does not wait for EOF on stdin before sending the exit code
470430
# from the remote process.
471-
messages = target.recv_all_messages()
472-
self.assertListEqual(
473-
util.sort_messages(messages),
474-
[
475-
(
476-
qrexec.MSG_DATA_STDOUT,
477-
b"arg: " + argument.encode("ascii", "strict")
478-
+ b", remote domain: domX, input: stdin data\n",
479-
),
480-
(qrexec.MSG_DATA_STDOUT, b""),
481-
(qrexec.MSG_DATA_STDERR, b""),
482-
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
483-
],
484-
)
431+
expected_stdout = (b"arg: " + argument.encode("ascii", "strict")
432+
+ b", remote domain: domX, input: stdin data\n")
433+
self.assertExpectedStdout(target, expected_stdout)
485434
self.check_dom0(dom0)
486435

487436
def test_exec_service_fail(self):
@@ -618,16 +567,7 @@ def test_exec_null_argument_finds_service_for_empty_argument(self):
618567
)
619568
target, dom0 = self.execute_qubesrpc("qubes.Service", "domX")
620569
target.send_message(qrexec.MSG_DATA_STDIN, b"")
621-
messages = target.recv_all_messages()
622-
self.assertListEqual(
623-
util.sort_messages(messages),
624-
[
625-
(qrexec.MSG_DATA_STDOUT, b"specific service: qubes.Service\n"),
626-
(qrexec.MSG_DATA_STDOUT, b""),
627-
(qrexec.MSG_DATA_STDERR, b""),
628-
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
629-
],
630-
)
570+
self.assertExpectedStdout(target, b"specific service: qubes.Service\n")
631571
self.check_dom0(dom0)
632572

633573
def test_socket_null_argument_finds_service_for_empty_argument(self):
@@ -793,10 +733,10 @@ def execute(self, cmd: str):
793733

794734
target = self.connect_target()
795735
target.handshake()
796-
return target
736+
return target, dom0
797737

798738
def test_stdin_stderr(self):
799-
target = self.execute('echo "stdout"; echo "stderr" >&2')
739+
target, dom0 = self.execute('echo "stdout"; echo "stderr" >&2')
800740
target.send_message(qrexec.MSG_DATA_STDIN, b"")
801741

802742
messages = target.recv_all_messages()
@@ -812,7 +752,7 @@ def test_stdin_stderr(self):
812752
)
813753

814754
def test_pass_stdin(self):
815-
target = self.execute("cat")
755+
target, dom0 = self.execute("cat")
816756

817757
target.send_message(qrexec.MSG_DATA_STDIN, b"data 1")
818758
self.assertEqual(
@@ -825,24 +765,16 @@ def test_pass_stdin(self):
825765
)
826766

827767
target.send_message(qrexec.MSG_DATA_STDIN, b"")
828-
messages = target.recv_all_messages()
829-
self.assertListEqual(
830-
util.sort_messages(messages),
831-
[
832-
(qrexec.MSG_DATA_STDOUT, b""),
833-
(qrexec.MSG_DATA_STDERR, b""),
834-
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
835-
],
836-
)
768+
self.assertExpectedStdout(target, b"")
837769

838770
def test_close_stdin_early(self):
839771
# Make sure that we cover the error on writing stdin into living
840772
# process.
841-
target = self.execute(
773+
target, dom0 = self.execute(
842774
"""
843775
read
844776
exec <&-
845-
echo closed stdin
777+
echo "closed stdin"
846778
sleep 1
847779
"""
848780
)
@@ -853,15 +785,8 @@ def test_close_stdin_early(self):
853785
target.send_message(qrexec.MSG_DATA_STDIN, b"data 2\n")
854786
target.send_message(qrexec.MSG_DATA_STDIN, b"")
855787

856-
messages = target.recv_all_messages()
857-
self.assertListEqual(
858-
util.sort_messages(messages),
859-
[
860-
(qrexec.MSG_DATA_STDOUT, b""),
861-
(qrexec.MSG_DATA_STDERR, b""),
862-
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
863-
],
864-
)
788+
self.assertExpectedStdout(target, b"")
789+
self.check_dom0(dom0)
865790

866791
def test_buffer_stdin(self):
867792
# Test to trigger WRITE_STDIN_BUFFERED.
@@ -877,7 +802,7 @@ def test_buffer_stdin(self):
877802

878803
fifo = os.path.join(self.tempdir, "fifo")
879804
os.mkfifo(fifo)
880-
target = self.execute("read <{}; cat".format(fifo))
805+
target, dom0 = self.execute("read <{}; cat".format(fifo))
881806

882807
for i in range(0, data_size, msg_size):
883808
msg = data[i : i + msg_size]
@@ -888,32 +813,11 @@ def test_buffer_stdin(self):
888813
with open(fifo, "a") as f:
889814
f.write("end\n")
890815
f.flush()
891-
892-
messages = []
893-
received_data = b""
894-
while len(received_data) < data_size:
895-
message_type, message = target.recv_message()
896-
if message_type != qrexec.MSG_DATA_STDOUT:
897-
messages.append((message_type, message))
898-
else:
899-
self.assertEqual(message_type, qrexec.MSG_DATA_STDOUT)
900-
received_data += message
901-
902-
self.assertEqual(len(received_data), data_size)
903-
self.assertEqual(received_data, data)
904-
905-
messages += target.recv_all_messages()
906-
self.assertListEqual(
907-
util.sort_messages(messages),
908-
[
909-
(qrexec.MSG_DATA_STDOUT, b""),
910-
(qrexec.MSG_DATA_STDERR, b""),
911-
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
912-
],
913-
)
816+
self.assertExpectedStdout(target, data)
817+
self.check_dom0(dom0)
914818

915819
def test_close_stdout_stderr_early(self):
916-
target = self.execute(
820+
target, dom0 = self.execute(
917821
"""\
918822
read
919823
echo closing stdout
@@ -946,9 +850,10 @@ def test_close_stdout_stderr_early(self):
946850
target.recv_message(),
947851
(qrexec.MSG_DATA_EXIT_CODE, struct.pack("<L", 42)),
948852
)
853+
self.check_dom0(dom0)
949854

950855
def test_stdio_socket(self):
951-
target = self.execute(
856+
target, dom0 = self.execute(
952857
"""\
953858
kill -USR1 $QREXEC_AGENT_PID
954859
echo hello world >&0
@@ -963,21 +868,13 @@ def test_stdio_socket(self):
963868
target.send_message(qrexec.MSG_DATA_STDIN, b"stdin\n")
964869
target.send_message(qrexec.MSG_DATA_STDIN, b"")
965870

966-
messages = target.recv_all_messages()
967-
self.assertListEqual(
968-
util.sort_messages(messages),
969-
[
970-
(qrexec.MSG_DATA_STDOUT, b"received: stdin\n"),
971-
(qrexec.MSG_DATA_STDOUT, b""),
972-
(qrexec.MSG_DATA_STDERR, b""),
973-
(qrexec.MSG_DATA_EXIT_CODE, b"\0\0\0\0"),
974-
],
975-
)
871+
self.assertExpectedStdout(target, b"received: stdin\n")
872+
self.check_dom0(dom0)
976873

977874
def test_exit_before_closing_streams(self):
978875
fifo = os.path.join(self.tempdir, "fifo")
979876
os.mkfifo(fifo)
980-
target = self.execute(
877+
target, dom0 = self.execute(
981878
"""\
982879
# duplicate original stdin to fd 3, because bash will
983880
# close original stdin in child process
@@ -1009,18 +906,8 @@ def test_exit_before_closing_streams(self):
1009906
with open(fifo, "a") as f:
1010907
f.write("end\n")
1011908
f.flush()
1012-
self.assertEqual(
1013-
target.recv_message(), (qrexec.MSG_DATA_STDOUT, b"child exiting\n")
1014-
)
1015-
messages = target.recv_all_messages()
1016-
self.assertListEqual(
1017-
util.sort_messages(messages),
1018-
[
1019-
(qrexec.MSG_DATA_STDOUT, b""),
1020-
(qrexec.MSG_DATA_STDERR, b""),
1021-
(qrexec.MSG_DATA_EXIT_CODE, struct.pack("<L", 42)),
1022-
],
1023-
)
909+
self.assertExpectedStdout(target, b"child exiting\n", exit_code=42)
910+
self.check_dom0(dom0)
1024911

1025912

1026913
@unittest.skipIf(os.environ.get("SKIP_SOCKET_TESTS"), "socket tests not set up")

0 commit comments

Comments
 (0)