diff --git a/tests/asgi/test_websockets_graphql_transport_ws.py b/tests/asgi/test_websockets_graphql_transport_ws.py index 7e2d3c17..5820d6e1 100644 --- a/tests/asgi/test_websockets_graphql_transport_ws.py +++ b/tests/asgi/test_websockets_graphql_transport_ws.py @@ -11,6 +11,7 @@ from ariadne.asgi.handlers import GraphQLTransportWSHandler from ariadne.exceptions import WebSocketConnectionError from ariadne.utils import get_operation_type +from .websocket_utils import wait_for_condition def test_field_can_be_subscribed_using_websocket_connection_graphql_transport_ws( @@ -681,6 +682,8 @@ async def on_complete(websocket, operation): def test_custom_websocket_on_disconnect_is_called_on_invalid_operation_graphql_transport_ws( schema, + timeout=5, + poll_interval=0.1, ): def on_disconnect(websocket): websocket.scope["on_disconnect"] = True @@ -694,9 +697,15 @@ def on_disconnect(websocket): response = ws.receive_json() assert response["type"] == GraphQLTransportWSHandler.GQL_CONNECTION_ACK ws.send_json({"type": "INVALID"}) - assert "on_disconnect" not in ws.scope + condition_met = wait_for_condition( + lambda: "on_disconnect" in ws.scope, + timeout, + poll_interval, + ) - assert ws.scope["on_disconnect"] is True + assert ( + condition_met and ws.scope.get("on_disconnect") is True + ), "on_disconnect should be set in ws.scope after invalid message" def test_custom_websocket_on_disconnect_is_called_on_connection_close_graphql_transport_ws( @@ -720,6 +729,8 @@ def on_disconnect(websocket): def test_custom_websocket_on_disconnect_is_awaited_if_its_async_graphql_transport_ws( schema, + timeout=5, + poll_interval=0.1, ): async def on_disconnect(websocket): websocket.scope["on_disconnect"] = True @@ -733,9 +744,15 @@ async def on_disconnect(websocket): response = ws.receive_json() assert response["type"] == GraphQLTransportWSHandler.GQL_CONNECTION_ACK ws.send_json({"type": "INVALID"}) - assert "on_disconnect" not in ws.scope + condition_met = wait_for_condition( + lambda: "on_disconnect" in ws.scope, + timeout, + poll_interval, + ) - assert ws.scope["on_disconnect"] is True + assert ( + condition_met and ws.scope.get("on_disconnect") is True + ), "on_disconnect should be set in ws.scope after invalid message" def test_error_in_custom_websocket_on_disconnect_is_handled_graphql_transport_ws( diff --git a/tests/asgi/test_websockets_graphql_ws.py b/tests/asgi/test_websockets_graphql_ws.py index f7a47761..c5510e3b 100644 --- a/tests/asgi/test_websockets_graphql_ws.py +++ b/tests/asgi/test_websockets_graphql_ws.py @@ -8,6 +8,7 @@ from ariadne.asgi import GraphQL from ariadne.asgi.handlers import GraphQLWSHandler from ariadne.exceptions import WebSocketConnectionError +from .websocket_utils import wait_for_condition def test_field_can_be_subscribed_using_websocket_connection(client): @@ -540,7 +541,11 @@ def on_complete(websocket, operation): assert ws.scope["on_complete"] is True -def test_custom_websocket_on_complete_is_called_on_terminate(schema): +def test_custom_websocket_on_complete_is_called_on_terminate( + schema, + timeout=5, + poll_interval=0.1, +): def on_complete(websocket, operation): assert operation.name == "TestOp" websocket.scope["on_complete"] = True @@ -568,9 +573,15 @@ def on_complete(websocket, operation): assert response["id"] == "test1" assert response["payload"]["data"] == {"ping": "pong"} ws.send_json({"type": GraphQLWSHandler.GQL_CONNECTION_TERMINATE}) - assert "on_complete" not in ws.scope + condition_met = wait_for_condition( + lambda: "on_complete" in ws.scope, + timeout, + poll_interval, + ) - assert ws.scope["on_complete"] is True + assert ( + condition_met and ws.scope.get("on_complete") is True + ), "on_complete should be set in ws.scope after invalid message" def test_custom_websocket_on_complete_is_called_on_disconnect(schema): @@ -605,7 +616,11 @@ def on_complete(websocket, operation): assert ws.scope["on_complete"] is True -def test_custom_websocket_on_complete_is_awaited_if_its_async(schema): +def test_custom_websocket_on_complete_is_awaited_if_its_async( + schema, + timeout=5, + poll_interval=0.1, +): async def on_complete(websocket, operation): assert operation.name == "TestOp" websocket.scope["on_complete"] = True @@ -634,9 +649,15 @@ async def on_complete(websocket, operation): assert response["payload"]["data"] == {"ping": "pong"} ws.send_json({"type": GraphQLWSHandler.GQL_STOP}) ws.send_json({"type": GraphQLWSHandler.GQL_CONNECTION_TERMINATE}) - assert "on_complete" in ws.scope + condition_met = wait_for_condition( + lambda: "on_complete" in ws.scope, + timeout, + poll_interval, + ) - assert ws.scope["on_complete"] is True + assert ( + condition_met and ws.scope.get("on_complete") is True + ), "on_complete should be set in ws.scope after invalid message" def test_error_in_custom_websocket_on_complete_is_handled(schema): diff --git a/tests/asgi/websocket_utils.py b/tests/asgi/websocket_utils.py new file mode 100644 index 00000000..a3217317 --- /dev/null +++ b/tests/asgi/websocket_utils.py @@ -0,0 +1,18 @@ +import time + + +def wait_for_condition(condition_func, timeout=5, poll_interval=0.1): + """ + This function is particularly useful in scenarios where asynchronous operations + are involved. For instance, in a WebSocket-based system, certain events or + state changes, like setting a flag in a callback, may not occur instantly. + The wait_for_condition function ensures that the test waits long enough for + these asynchronous events to complete, preventing race conditions or false + negatives in test outcomes. + """ + start_time = time.time() + while time.time() - start_time < timeout: + if condition_func(): + return True + time.sleep(poll_interval) + return False