@@ -47,6 +47,30 @@ class TestAgentBase(unittest.TestCase):
47
47
target_domain = 43
48
48
target_port = 1024
49
49
50
+ def check_dom0 (self , dom0 ):
51
+ self .assertEqual (
52
+ dom0 .recv_message (),
53
+ (
54
+ qrexec .MSG_CONNECTION_TERMINATED ,
55
+ struct .pack ("<LL" , self .target_domain , self .target_port ),
56
+ ),
57
+ )
58
+
59
+ def assertExpectedStdout (self , target , expected_stdout : bytes , * , exit_code = 0 ):
60
+ messages = util .sort_messages (target .recv_all_messages ())
61
+ self .assertListEqual (messages [- 3 :],
62
+ [
63
+ (qrexec .MSG_DATA_STDOUT , b"" ),
64
+ (qrexec .MSG_DATA_STDERR , b"" ),
65
+ (qrexec .MSG_DATA_EXIT_CODE , struct .pack ("<L" , exit_code ))
66
+ ])
67
+ stdout_entries = []
68
+ for msg_type , msg_body in messages [:- 3 ]:
69
+ # messages before last are not empty, hence truthy
70
+ self .assertTrue (msg_body )
71
+ self .assertEqual (msg_type , qrexec .MSG_DATA_STDOUT )
72
+ stdout_entries .append (msg_body )
73
+
50
74
def setUp (self ):
51
75
self .tempdir = tempfile .mkdtemp ()
52
76
os .mkdir (os .path .join (self .tempdir , "local-rpc" ))
@@ -157,14 +181,7 @@ def test_just_exec(self):
157
181
lambda : os .path .exists (os .path .join (self .tempdir , "new_file" )),
158
182
"file created" ,
159
183
)
160
-
161
- self .assertEqual (
162
- dom0 .recv_message (),
163
- (
164
- qrexec .MSG_CONNECTION_TERMINATED ,
165
- struct .pack ("<LL" , self .target_domain , self .target_port ),
166
- ),
167
- )
184
+ self .check_dom0 (dom0 )
168
185
169
186
def test_exec_cmdline (self ):
170
187
self .start_agent ()
@@ -186,24 +203,8 @@ def test_exec_cmdline(self):
186
203
187
204
target .send_message (qrexec .MSG_DATA_STDIN , b"" )
188
205
189
- messages = target .recv_all_messages ()
190
- self .assertListEqual (
191
- util .sort_messages (messages ),
192
- [
193
- (qrexec .MSG_DATA_STDOUT , b"Hello world\n " ),
194
- (qrexec .MSG_DATA_STDOUT , b"" ),
195
- (qrexec .MSG_DATA_STDERR , b"" ),
196
- (qrexec .MSG_DATA_EXIT_CODE , b"\0 \0 \0 \0 " ),
197
- ],
198
- )
199
-
200
- self .assertEqual (
201
- dom0 .recv_message (),
202
- (
203
- qrexec .MSG_CONNECTION_TERMINATED ,
204
- struct .pack ("<LL" , self .target_domain , self .target_port ),
205
- ),
206
- )
206
+ self .assertExpectedStdout (target , b"Hello world\n " )
207
+ self .check_dom0 (dom0 )
207
208
208
209
def test_trigger_service (self ):
209
210
self .start_agent ()
@@ -229,13 +230,7 @@ def test_trigger_service(self):
229
230
)
230
231
231
232
client .close ()
232
- self .assertEqual (
233
- dom0 .recv_message (),
234
- (
235
- qrexec .MSG_CONNECTION_TERMINATED ,
236
- struct .pack ("<LL" , self .target_domain , self .target_port ),
237
- ),
238
- )
233
+ self .check_dom0 (dom0 )
239
234
240
235
def test_trigger_service_refused (self ):
241
236
self .start_agent ()
@@ -310,15 +305,6 @@ def execute_qubesrpc(self, service: str, src_domain_name: str):
310
305
target .handshake ()
311
306
return target , dom0
312
307
313
- def check_dom0 (self , dom0 ):
314
- self .assertEqual (
315
- dom0 .recv_message (),
316
- (
317
- qrexec .MSG_CONNECTION_TERMINATED ,
318
- struct .pack ("<LL" , self .target_domain , self .target_port ),
319
- ),
320
- )
321
-
322
308
def make_executable_service (self , * args ):
323
309
util .make_executable_service (self .tempdir , * args )
324
310
@@ -332,18 +318,10 @@ def test_exec_service(self):
332
318
echo "arg: $1, remote domain: $QREXEC_REMOTE_DOMAIN"
333
319
""" ,
334
320
)
335
- target , _ = self .execute_qubesrpc ("qubes.Service+arg" , "domX" )
321
+ target , dom0 = self .execute_qubesrpc ("qubes.Service+arg" , "domX" )
336
322
target .send_message (qrexec .MSG_DATA_STDIN , b"" )
337
- messages = target .recv_all_messages ()
338
- self .assertListEqual (
339
- util .sort_messages (messages ),
340
- [
341
- (qrexec .MSG_DATA_STDOUT , b"arg: arg, remote domain: domX\n " ),
342
- (qrexec .MSG_DATA_STDOUT , b"" ),
343
- (qrexec .MSG_DATA_STDERR , b"" ),
344
- (qrexec .MSG_DATA_EXIT_CODE , b"\0 \0 \0 \0 " ),
345
- ],
346
- )
323
+ self .assertExpectedStdout (target , b"arg: arg, remote domain: domX\n " )
324
+ self .check_dom0 (dom0 )
347
325
348
326
def test_exec_service_keyword (self ):
349
327
util .make_executable_service (
@@ -361,20 +339,11 @@ def test_exec_service_keyword(self):
361
339
)
362
340
target , dom0 = self .execute_qubesrpc ("qubes.Service" , "domX" )
363
341
target .send_message (qrexec .MSG_DATA_STDIN , b"" )
364
- messages = target .recv_all_messages ()
365
- self .assertListEqual (
366
- util .sort_messages (messages ),
367
- [
368
- (qrexec .MSG_DATA_STDOUT , b"""arg: , remote domain: domX
342
+ self .assertExpectedStdout (target , b"""arg: , remote domain: domX
369
343
target name: NONAME
370
344
target keyword: NOKEYWORD
371
345
target type: ''
372
- """ ),
373
- (qrexec .MSG_DATA_STDOUT , b"" ),
374
- (qrexec .MSG_DATA_STDERR , b"" ),
375
- (qrexec .MSG_DATA_EXIT_CODE , b"\0 \0 \0 \0 " ),
376
- ],
377
- )
346
+ """ )
378
347
self .check_dom0 (dom0 )
379
348
380
349
def test_exec_service_with_config (self ):
@@ -395,16 +364,7 @@ def test_exec_service_with_config(self):
395
364
""" )
396
365
target , dom0 = self .execute_qubesrpc ("qubes.Service+arg" , "domX" )
397
366
target .send_message (qrexec .MSG_DATA_STDIN , b"" )
398
- messages = target .recv_all_messages ()
399
- self .assertListEqual (
400
- util .sort_messages (messages ),
401
- [
402
- (qrexec .MSG_DATA_STDOUT , b"arg: arg, remote domain: domX\n " ),
403
- (qrexec .MSG_DATA_STDOUT , b"" ),
404
- (qrexec .MSG_DATA_STDERR , b"" ),
405
- (qrexec .MSG_DATA_EXIT_CODE , b"\0 \0 \0 \0 " ),
406
- ],
407
- )
367
+ self .assertExpectedStdout (target , b"arg: arg, remote domain: domX\n " )
408
368
self .check_dom0 (dom0 )
409
369
410
370
def test_wait_for_session (self ):
@@ -468,20 +428,9 @@ def _test_wait_for_session(self, config_name, service_name="qubes.Service", argu
468
428
# Do not send EOF. Shell read doesn't need it, and this checks that
469
429
# qrexec does not wait for EOF on stdin before sending the exit code
470
430
# from the remote process.
471
- messages = target .recv_all_messages ()
472
- self .assertListEqual (
473
- util .sort_messages (messages ),
474
- [
475
- (
476
- qrexec .MSG_DATA_STDOUT ,
477
- b"arg: " + argument .encode ("ascii" , "strict" )
478
- + b", remote domain: domX, input: stdin data\n " ,
479
- ),
480
- (qrexec .MSG_DATA_STDOUT , b"" ),
481
- (qrexec .MSG_DATA_STDERR , b"" ),
482
- (qrexec .MSG_DATA_EXIT_CODE , b"\0 \0 \0 \0 " ),
483
- ],
484
- )
431
+ expected_stdout = (b"arg: " + argument .encode ("ascii" , "strict" )
432
+ + b", remote domain: domX, input: stdin data\n " )
433
+ self .assertExpectedStdout (target , expected_stdout )
485
434
self .check_dom0 (dom0 )
486
435
487
436
def test_exec_service_fail (self ):
@@ -618,16 +567,7 @@ def test_exec_null_argument_finds_service_for_empty_argument(self):
618
567
)
619
568
target , dom0 = self .execute_qubesrpc ("qubes.Service" , "domX" )
620
569
target .send_message (qrexec .MSG_DATA_STDIN , b"" )
621
- messages = target .recv_all_messages ()
622
- self .assertListEqual (
623
- util .sort_messages (messages ),
624
- [
625
- (qrexec .MSG_DATA_STDOUT , b"specific service: qubes.Service\n " ),
626
- (qrexec .MSG_DATA_STDOUT , b"" ),
627
- (qrexec .MSG_DATA_STDERR , b"" ),
628
- (qrexec .MSG_DATA_EXIT_CODE , b"\0 \0 \0 \0 " ),
629
- ],
630
- )
570
+ self .assertExpectedStdout (target , b"specific service: qubes.Service\n " )
631
571
self .check_dom0 (dom0 )
632
572
633
573
def test_socket_null_argument_finds_service_for_empty_argument (self ):
@@ -793,10 +733,10 @@ def execute(self, cmd: str):
793
733
794
734
target = self .connect_target ()
795
735
target .handshake ()
796
- return target
736
+ return target , dom0
797
737
798
738
def test_stdin_stderr (self ):
799
- target = self .execute ('echo "stdout"; echo "stderr" >&2' )
739
+ target , dom0 = self .execute ('echo "stdout"; echo "stderr" >&2' )
800
740
target .send_message (qrexec .MSG_DATA_STDIN , b"" )
801
741
802
742
messages = target .recv_all_messages ()
@@ -812,7 +752,7 @@ def test_stdin_stderr(self):
812
752
)
813
753
814
754
def test_pass_stdin (self ):
815
- target = self .execute ("cat" )
755
+ target , dom0 = self .execute ("cat" )
816
756
817
757
target .send_message (qrexec .MSG_DATA_STDIN , b"data 1" )
818
758
self .assertEqual (
@@ -825,24 +765,16 @@ def test_pass_stdin(self):
825
765
)
826
766
827
767
target .send_message (qrexec .MSG_DATA_STDIN , b"" )
828
- messages = target .recv_all_messages ()
829
- self .assertListEqual (
830
- util .sort_messages (messages ),
831
- [
832
- (qrexec .MSG_DATA_STDOUT , b"" ),
833
- (qrexec .MSG_DATA_STDERR , b"" ),
834
- (qrexec .MSG_DATA_EXIT_CODE , b"\0 \0 \0 \0 " ),
835
- ],
836
- )
768
+ self .assertExpectedStdout (target , b"" )
837
769
838
770
def test_close_stdin_early (self ):
839
771
# Make sure that we cover the error on writing stdin into living
840
772
# process.
841
- target = self .execute (
773
+ target , dom0 = self .execute (
842
774
"""
843
775
read
844
776
exec <&-
845
- echo closed stdin
777
+ echo " closed stdin"
846
778
sleep 1
847
779
"""
848
780
)
@@ -853,15 +785,8 @@ def test_close_stdin_early(self):
853
785
target .send_message (qrexec .MSG_DATA_STDIN , b"data 2\n " )
854
786
target .send_message (qrexec .MSG_DATA_STDIN , b"" )
855
787
856
- messages = target .recv_all_messages ()
857
- self .assertListEqual (
858
- util .sort_messages (messages ),
859
- [
860
- (qrexec .MSG_DATA_STDOUT , b"" ),
861
- (qrexec .MSG_DATA_STDERR , b"" ),
862
- (qrexec .MSG_DATA_EXIT_CODE , b"\0 \0 \0 \0 " ),
863
- ],
864
- )
788
+ self .assertExpectedStdout (target , b"" )
789
+ self .check_dom0 (dom0 )
865
790
866
791
def test_buffer_stdin (self ):
867
792
# Test to trigger WRITE_STDIN_BUFFERED.
@@ -877,7 +802,7 @@ def test_buffer_stdin(self):
877
802
878
803
fifo = os .path .join (self .tempdir , "fifo" )
879
804
os .mkfifo (fifo )
880
- target = self .execute ("read <{}; cat" .format (fifo ))
805
+ target , dom0 = self .execute ("read <{}; cat" .format (fifo ))
881
806
882
807
for i in range (0 , data_size , msg_size ):
883
808
msg = data [i : i + msg_size ]
@@ -888,32 +813,11 @@ def test_buffer_stdin(self):
888
813
with open (fifo , "a" ) as f :
889
814
f .write ("end\n " )
890
815
f .flush ()
891
-
892
- messages = []
893
- received_data = b""
894
- while len (received_data ) < data_size :
895
- message_type , message = target .recv_message ()
896
- if message_type != qrexec .MSG_DATA_STDOUT :
897
- messages .append ((message_type , message ))
898
- else :
899
- self .assertEqual (message_type , qrexec .MSG_DATA_STDOUT )
900
- received_data += message
901
-
902
- self .assertEqual (len (received_data ), data_size )
903
- self .assertEqual (received_data , data )
904
-
905
- messages += target .recv_all_messages ()
906
- self .assertListEqual (
907
- util .sort_messages (messages ),
908
- [
909
- (qrexec .MSG_DATA_STDOUT , b"" ),
910
- (qrexec .MSG_DATA_STDERR , b"" ),
911
- (qrexec .MSG_DATA_EXIT_CODE , b"\0 \0 \0 \0 " ),
912
- ],
913
- )
816
+ self .assertExpectedStdout (target , data )
817
+ self .check_dom0 (dom0 )
914
818
915
819
def test_close_stdout_stderr_early (self ):
916
- target = self .execute (
820
+ target , dom0 = self .execute (
917
821
"""\
918
822
read
919
823
echo closing stdout
@@ -946,9 +850,10 @@ def test_close_stdout_stderr_early(self):
946
850
target .recv_message (),
947
851
(qrexec .MSG_DATA_EXIT_CODE , struct .pack ("<L" , 42 )),
948
852
)
853
+ self .check_dom0 (dom0 )
949
854
950
855
def test_stdio_socket (self ):
951
- target = self .execute (
856
+ target , dom0 = self .execute (
952
857
"""\
953
858
kill -USR1 $QREXEC_AGENT_PID
954
859
echo hello world >&0
@@ -963,21 +868,13 @@ def test_stdio_socket(self):
963
868
target .send_message (qrexec .MSG_DATA_STDIN , b"stdin\n " )
964
869
target .send_message (qrexec .MSG_DATA_STDIN , b"" )
965
870
966
- messages = target .recv_all_messages ()
967
- self .assertListEqual (
968
- util .sort_messages (messages ),
969
- [
970
- (qrexec .MSG_DATA_STDOUT , b"received: stdin\n " ),
971
- (qrexec .MSG_DATA_STDOUT , b"" ),
972
- (qrexec .MSG_DATA_STDERR , b"" ),
973
- (qrexec .MSG_DATA_EXIT_CODE , b"\0 \0 \0 \0 " ),
974
- ],
975
- )
871
+ self .assertExpectedStdout (target , b"received: stdin\n " )
872
+ self .check_dom0 (dom0 )
976
873
977
874
def test_exit_before_closing_streams (self ):
978
875
fifo = os .path .join (self .tempdir , "fifo" )
979
876
os .mkfifo (fifo )
980
- target = self .execute (
877
+ target , dom0 = self .execute (
981
878
"""\
982
879
# duplicate original stdin to fd 3, because bash will
983
880
# close original stdin in child process
@@ -1009,18 +906,8 @@ def test_exit_before_closing_streams(self):
1009
906
with open (fifo , "a" ) as f :
1010
907
f .write ("end\n " )
1011
908
f .flush ()
1012
- self .assertEqual (
1013
- target .recv_message (), (qrexec .MSG_DATA_STDOUT , b"child exiting\n " )
1014
- )
1015
- messages = target .recv_all_messages ()
1016
- self .assertListEqual (
1017
- util .sort_messages (messages ),
1018
- [
1019
- (qrexec .MSG_DATA_STDOUT , b"" ),
1020
- (qrexec .MSG_DATA_STDERR , b"" ),
1021
- (qrexec .MSG_DATA_EXIT_CODE , struct .pack ("<L" , 42 )),
1022
- ],
1023
- )
909
+ self .assertExpectedStdout (target , b"child exiting\n " , exit_code = 42 )
910
+ self .check_dom0 (dom0 )
1024
911
1025
912
1026
913
@unittest .skipIf (os .environ .get ("SKIP_SOCKET_TESTS" ), "socket tests not set up" )
0 commit comments