Skip to content

Commit

Permalink
Prevent subscription execution without websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
rafalp committed Feb 21, 2024
1 parent c5fa431 commit 2f6a6a8
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 0 deletions.
40 changes: 40 additions & 0 deletions ariadne/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ async def graphql(

if require_query:
validate_operation_is_query(document, operation_name)
else:
validate_operation_is_not_subscription(document, operation_name)

if callable(root_value):
try:
Expand Down Expand Up @@ -358,6 +360,8 @@ def graphql_sync(

if require_query:
validate_operation_is_query(document, operation_name)
else:
validate_operation_is_not_subscription(document, operation_name)

if callable(root_value):
try:
Expand Down Expand Up @@ -680,3 +684,39 @@ def validate_operation_is_query(
raise GraphQLError(
"'operationName' is required if 'query' defines multiple operations."
)


def validate_operation_is_not_subscription(
document_ast: DocumentNode, operation_name: Optional[str]
):
if operation_name:
validate_named_operation_is_not_subscription(document_ast, operation_name)
else:
validate_anonymous_operation_is_not_subscription(document_ast)


def validate_named_operation_is_not_subscription(
document_ast: DocumentNode, operation_name: str
):
for definition in document_ast.definitions:
if (
isinstance(definition, OperationDefinitionNode)
and definition.name.value == operation_name
and definition.operation.name == "SUBSCRIPTION"
):
raise GraphQLError(
f"Operation '{operation_name}' is a subscription and can only be "
"executed over a WebSocket connection."
)


def validate_anonymous_operation_is_not_subscription(document_ast: DocumentNode):
operations: List[OperationDefinitionNode] = []
for definition in document_ast.definitions:
if isinstance(definition, OperationDefinitionNode):
operations.append(definition)

if len(operations) == 1 and operations[0].operation.name == "SUBSCRIPTION":
raise GraphQLError(
"Subscription operations can only be executed over a WebSocket connection."
)
18 changes: 18 additions & 0 deletions tests/asgi/__snapshots__/test_query_execution.ambr
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
# serializer version: 1
# name: test_attempt_execute_anonymous_subscription_over_post_returns_error_json
dict({
'errors': list([
dict({
'message': 'Subscription operations can only be executed over a WebSocket connection.',
}),
]),
})
# ---
# name: test_attempt_execute_complex_query_without_variables_returns_error_json
dict({
'data': None,
Expand Down Expand Up @@ -61,6 +70,15 @@
]),
})
# ---
# name: test_attempt_execute_subscription_over_post_returns_error_json
dict({
'errors': list([
dict({
'message': "Operation 'Test' is a subscription and can only be executed over a WebSocket connection.",
}),
]),
})
# ---
# name: test_attempt_execute_subscription_with_invalid_query_returns_error_json
dict({
'locations': list([
Expand Down
22 changes: 22 additions & 0 deletions tests/asgi/test_query_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,28 @@ def test_attempt_execute_query_with_invalid_operation_name_type_returns_error_js
assert snapshot == response.json()


def test_attempt_execute_anonymous_subscription_over_post_returns_error_json(
client, snapshot
):
response = client.post("/", json={"query": "subscription { ping }"})
assert response.status_code == 400
assert snapshot == response.json()


def test_attempt_execute_subscription_over_post_returns_error_json(
client, snapshot
):
response = client.post(
"/",
json={
"query": "subscription Test { ping }",
"operationName": "Test",
},
)
assert response.status_code == 400
assert snapshot == response.json()


def test_attempt_execute_subscription_with_invalid_query_returns_error_json(
client, snapshot
):
Expand Down
18 changes: 18 additions & 0 deletions tests/wsgi/__snapshots__/test_query_execution.ambr
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
# serializer version: 1
# name: test_attempt_execute_anonymous_subscription_over_post_returns_error_json
dict({
'errors': list([
dict({
'message': 'Subscription operations can only be executed over a WebSocket connection.',
}),
]),
})
# ---
# name: test_attempt_execute_complex_query_without_variables_returns_error_json
dict({
'data': None,
Expand Down Expand Up @@ -61,6 +70,15 @@
]),
})
# ---
# name: test_attempt_execute_subscription_over_post_returns_error_json
dict({
'errors': list([
dict({
'message': "Operation 'Test' is a subscription and can only be executed over a WebSocket connection.",
}),
]),
})
# ---
# name: test_complex_query_is_executed_for_post_json_request
dict({
'data': dict({
Expand Down
32 changes: 32 additions & 0 deletions tests/wsgi/test_query_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,38 @@ def test_attempt_execute_query_with_invalid_operation_name_type_returns_error_js
assert_json_response_equals_snapshot(result)


def test_attempt_execute_anonymous_subscription_over_post_returns_error_json(
middleware,
start_response,
graphql_query_request_factory,
graphql_response_headers,
assert_json_response_equals_snapshot,
):
request = graphql_query_request_factory(query="subscription { ping }")
result = middleware(request, start_response)
start_response.assert_called_once_with(
HTTP_STATUS_400_BAD_REQUEST, graphql_response_headers
)
assert_json_response_equals_snapshot(result)


def test_attempt_execute_subscription_over_post_returns_error_json(
middleware,
start_response,
graphql_query_request_factory,
graphql_response_headers,
assert_json_response_equals_snapshot,
):
request = graphql_query_request_factory(
query="subscription Test { ping }", operationName="Test"
)
result = middleware(request, start_response)
start_response.assert_called_once_with(
HTTP_STATUS_400_BAD_REQUEST, graphql_response_headers
)
assert_json_response_equals_snapshot(result)


def test_query_is_executed_for_multipart_form_request_with_file(
middleware, snapshot, start_response, graphql_response_headers
):
Expand Down

0 comments on commit 2f6a6a8

Please sign in to comment.