From 2f6a6a863fdb33589a3de37079f7dfa5b6dc69a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Pito=C5=84?= Date: Wed, 21 Feb 2024 16:18:07 +0100 Subject: [PATCH] Prevent subscription execution without websocket --- ariadne/graphql.py | 40 +++++++++++++++++++ .../__snapshots__/test_query_execution.ambr | 18 +++++++++ tests/asgi/test_query_execution.py | 22 ++++++++++ .../__snapshots__/test_query_execution.ambr | 18 +++++++++ tests/wsgi/test_query_execution.py | 32 +++++++++++++++ 5 files changed, 130 insertions(+) diff --git a/ariadne/graphql.py b/ariadne/graphql.py index 36e4498c6..3496beb3d 100644 --- a/ariadne/graphql.py +++ b/ariadne/graphql.py @@ -185,6 +185,8 @@ async def graphql( if require_query: validate_operation_is_query(document, operation_name) + else: + validate_operation_is_not_subscription(document, operation_name) if callable(root_value): try: @@ -358,6 +360,8 @@ def graphql_sync( if require_query: validate_operation_is_query(document, operation_name) + else: + validate_operation_is_not_subscription(document, operation_name) if callable(root_value): try: @@ -680,3 +684,39 @@ def validate_operation_is_query( raise GraphQLError( "'operationName' is required if 'query' defines multiple operations." ) + + +def validate_operation_is_not_subscription( + document_ast: DocumentNode, operation_name: Optional[str] +): + if operation_name: + validate_named_operation_is_not_subscription(document_ast, operation_name) + else: + validate_anonymous_operation_is_not_subscription(document_ast) + + +def validate_named_operation_is_not_subscription( + document_ast: DocumentNode, operation_name: str +): + for definition in document_ast.definitions: + if ( + isinstance(definition, OperationDefinitionNode) + and definition.name.value == operation_name + and definition.operation.name == "SUBSCRIPTION" + ): + raise GraphQLError( + f"Operation '{operation_name}' is a subscription and can only be " + "executed over a WebSocket connection." + ) + + +def validate_anonymous_operation_is_not_subscription(document_ast: DocumentNode): + operations: List[OperationDefinitionNode] = [] + for definition in document_ast.definitions: + if isinstance(definition, OperationDefinitionNode): + operations.append(definition) + + if len(operations) == 1 and operations[0].operation.name == "SUBSCRIPTION": + raise GraphQLError( + "Subscription operations can only be executed over a WebSocket connection." + ) diff --git a/tests/asgi/__snapshots__/test_query_execution.ambr b/tests/asgi/__snapshots__/test_query_execution.ambr index 37029df30..1bb7d363f 100644 --- a/tests/asgi/__snapshots__/test_query_execution.ambr +++ b/tests/asgi/__snapshots__/test_query_execution.ambr @@ -1,4 +1,13 @@ # serializer version: 1 +# name: test_attempt_execute_anonymous_subscription_over_post_returns_error_json + dict({ + 'errors': list([ + dict({ + 'message': 'Subscription operations can only be executed over a WebSocket connection.', + }), + ]), + }) +# --- # name: test_attempt_execute_complex_query_without_variables_returns_error_json dict({ 'data': None, @@ -61,6 +70,15 @@ ]), }) # --- +# name: test_attempt_execute_subscription_over_post_returns_error_json + dict({ + 'errors': list([ + dict({ + 'message': "Operation 'Test' is a subscription and can only be executed over a WebSocket connection.", + }), + ]), + }) +# --- # name: test_attempt_execute_subscription_with_invalid_query_returns_error_json dict({ 'locations': list([ diff --git a/tests/asgi/test_query_execution.py b/tests/asgi/test_query_execution.py index 910e5bae7..4a9e456c9 100644 --- a/tests/asgi/test_query_execution.py +++ b/tests/asgi/test_query_execution.py @@ -107,6 +107,28 @@ def test_attempt_execute_query_with_invalid_operation_name_type_returns_error_js assert snapshot == response.json() +def test_attempt_execute_anonymous_subscription_over_post_returns_error_json( + client, snapshot +): + response = client.post("/", json={"query": "subscription { ping }"}) + assert response.status_code == 400 + assert snapshot == response.json() + + +def test_attempt_execute_subscription_over_post_returns_error_json( + client, snapshot +): + response = client.post( + "/", + json={ + "query": "subscription Test { ping }", + "operationName": "Test", + }, + ) + assert response.status_code == 400 + assert snapshot == response.json() + + def test_attempt_execute_subscription_with_invalid_query_returns_error_json( client, snapshot ): diff --git a/tests/wsgi/__snapshots__/test_query_execution.ambr b/tests/wsgi/__snapshots__/test_query_execution.ambr index 08ba44326..aff913278 100644 --- a/tests/wsgi/__snapshots__/test_query_execution.ambr +++ b/tests/wsgi/__snapshots__/test_query_execution.ambr @@ -1,4 +1,13 @@ # serializer version: 1 +# name: test_attempt_execute_anonymous_subscription_over_post_returns_error_json + dict({ + 'errors': list([ + dict({ + 'message': 'Subscription operations can only be executed over a WebSocket connection.', + }), + ]), + }) +# --- # name: test_attempt_execute_complex_query_without_variables_returns_error_json dict({ 'data': None, @@ -61,6 +70,15 @@ ]), }) # --- +# name: test_attempt_execute_subscription_over_post_returns_error_json + dict({ + 'errors': list([ + dict({ + 'message': "Operation 'Test' is a subscription and can only be executed over a WebSocket connection.", + }), + ]), + }) +# --- # name: test_complex_query_is_executed_for_post_json_request dict({ 'data': dict({ diff --git a/tests/wsgi/test_query_execution.py b/tests/wsgi/test_query_execution.py index b44bcc714..63c6c70f7 100644 --- a/tests/wsgi/test_query_execution.py +++ b/tests/wsgi/test_query_execution.py @@ -158,6 +158,38 @@ def test_attempt_execute_query_with_invalid_operation_name_type_returns_error_js assert_json_response_equals_snapshot(result) +def test_attempt_execute_anonymous_subscription_over_post_returns_error_json( + middleware, + start_response, + graphql_query_request_factory, + graphql_response_headers, + assert_json_response_equals_snapshot, +): + request = graphql_query_request_factory(query="subscription { ping }") + result = middleware(request, start_response) + start_response.assert_called_once_with( + HTTP_STATUS_400_BAD_REQUEST, graphql_response_headers + ) + assert_json_response_equals_snapshot(result) + + +def test_attempt_execute_subscription_over_post_returns_error_json( + middleware, + start_response, + graphql_query_request_factory, + graphql_response_headers, + assert_json_response_equals_snapshot, +): + request = graphql_query_request_factory( + query="subscription Test { ping }", operationName="Test" + ) + result = middleware(request, start_response) + start_response.assert_called_once_with( + HTTP_STATUS_400_BAD_REQUEST, graphql_response_headers + ) + assert_json_response_equals_snapshot(result) + + def test_query_is_executed_for_multipart_form_request_with_file( middleware, snapshot, start_response, graphql_response_headers ):