Skip to content

Commit

Permalink
Execute queries over GET
Browse files Browse the repository at this point in the history
  • Loading branch information
rafalp committed Feb 15, 2024
1 parent 934dc4b commit 1ba4295
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 5 deletions.
56 changes: 51 additions & 5 deletions ariadne/asgi/handlers/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class GraphQLHTTPHandler(GraphQLHttpHandlerBase):

def __init__(
self,
execute_get_queries: bool = False,
extensions: Optional[Extensions] = None,
middleware: Optional[Middlewares] = None,
middleware_manager_class: Optional[Type[MiddlewareManager]] = None,
Expand All @@ -43,6 +44,9 @@ def __init__(
# Optional arguments
`execute_get_queries`: a `bool` that controls if `query` operations
sent using the `GET` method should be executed. Defaults to `False`.
`extensions`: an `Extensions` list or callable returning a
list of extensions server should use during query execution. Defaults
to no extensions.
Expand All @@ -58,6 +62,7 @@ def __init__(
"""
super().__init__()

self.execute_get_queries = execute_get_queries
self.extensions = extensions
self.middleware = middleware
self.middleware_manager_class = middleware_manager_class or MiddlewareManager
Expand Down Expand Up @@ -114,9 +119,12 @@ async def handle_request(self, request: Request) -> Response:
`request`: the `Request` instance from Starlette or FastAPI.
"""
if request.method == "GET" and self.introspection and self.explorer:
# only render explorer when introspection is enabled
return await self.render_explorer(request, self.explorer)
if request.method == "GET":
if self.execute_get_queries and request.query_params.get("query"):
return await self.graphql_http_server(request)
elif self.introspection and self.explorer:
# only render explorer when introspection is enabled
return await self.render_explorer(request, self.explorer)

if request.method == "POST":
return await self.graphql_http_server(request)
Expand Down Expand Up @@ -182,14 +190,20 @@ async def extract_data_from_request(self, request: Request):
return await self.extract_data_from_json_request(request)
if content_type == DATA_TYPE_MULTIPART:
return await self.extract_data_from_multipart_request(request)
if (
request.method == "GET"
and self.execute_get_queries
and request.query_params.get("query")
):
return await self.extract_data_from_get_request(request)

raise HttpBadRequestError(
"Posted content must be of type {} or {}".format(
DATA_TYPE_JSON, DATA_TYPE_MULTIPART
)
)

async def extract_data_from_json_request(self, request: Request):
async def extract_data_from_json_request(self, request: Request) -> dict:
"""Extracts GraphQL data from JSON request.
Returns a `dict` with GraphQL query data that was not yet validated.
Expand All @@ -203,7 +217,9 @@ async def extract_data_from_json_request(self, request: Request):
except (TypeError, ValueError) as ex:
raise HttpBadRequestError("Request body is not a valid JSON") from ex

async def extract_data_from_multipart_request(self, request: Request):
async def extract_data_from_multipart_request(
self, request: Request
) -> dict | list:
"""Extracts GraphQL data from `multipart/form-data` request.
Returns an unvalidated `dict` with GraphQL query data.
Expand Down Expand Up @@ -240,6 +256,35 @@ async def extract_data_from_multipart_request(self, request: Request):

return combine_multipart_data(operations, files_map, request_files)

async def extract_data_from_get_request(self, request: Request) -> dict:
"""Extracts GraphQL data from GET request's querystring.
Returns a `dict` with GraphQL query data that was not yet validated.
# Required arguments
`request`: the `Request` instance from Starlette or FastAPI.
"""
query = request.query_params["query"].strip()
operation_name = request.query_params.get("operationName", "").strip()
variables = request.query_params.get("variables", "").strip()

clean_variables = None

if variables:
try:
clean_variables = json.loads(variables)
except (TypeError, ValueError) as ex:
raise HttpBadRequestError(
"Variables query arg is not a valid JSON"
) from ex

return {
"query": query,
"operationName": operation_name or None,
"variables": clean_variables,
}

async def execute_graphql_query(
self,
request: Any,
Expand Down Expand Up @@ -284,6 +329,7 @@ async def execute_graphql_query(
query_validator=self.query_validator,
query_document=query_document,
validation_rules=self.validation_rules,
require_query=request.method == "GET",
debug=self.debug,
introspection=self.introspection,
logger=self.logger,
Expand Down
33 changes: 33 additions & 0 deletions ariadne/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
GraphQLError,
GraphQLSchema,
MiddlewareManager,
OperationDefinitionNode,
TypeInfo,
execute,
execute_sync,
Expand Down Expand Up @@ -71,6 +72,7 @@ async def graphql(
introspection: bool = True,
logger: Union[None, str, Logger, LoggerAdapter] = None,
validation_rules: Optional[ValidationRules] = None,
require_query: bool = False,
error_formatter: ErrorFormatter = format_error,
middleware: MiddlewareList = None,
middleware_manager_class: Optional[Type[MiddlewareManager]] = None,
Expand Down Expand Up @@ -123,6 +125,9 @@ async def graphql(
`validation_rules`: a `list` of or callable returning list of custom
validation rules to use to validate query before it's executed.
`require_query`: a `bool` controlling if GraphQL operation to execute must be
a query (vs. mutation or subscription).
`error_formatter`: an `ErrorFormatter` callable to use to convert GraphQL
errors encountered during query execution to JSON-serializable format.
Expand Down Expand Up @@ -169,6 +174,7 @@ async def graphql(
enable_introspection=introspection,
query_validator=query_validator,
)

if validation_errors:
return handle_graphql_errors(
validation_errors,
Expand All @@ -178,6 +184,9 @@ async def graphql(
extension_manager=extension_manager,
)

if require_query and not validation_errors:
validate_operation_is_query(document, operation_name)

if callable(root_value):
try:
root_value = root_value( # type: ignore
Expand Down Expand Up @@ -639,3 +648,27 @@ def validate_variables(variables) -> None:
def validate_operation_name(operation_name) -> None:
if operation_name is not None and not isinstance(operation_name, str):
raise GraphQLError('"%s" is not a valid operation name.' % operation_name)


def validate_operation_is_query(document_ast: DocumentNode, operation_name: str):
query_operations: List[Optional[str]] = []
for definition in document_ast.definitions:
if (
isinstance(definition, OperationDefinitionNode)
and definition.operation.name == "QUERY"
):
if definition.name:
query_operations.append(definition.name.value)
else:
query_operations.append(None)

if operation_name:
if query_operations not in query_operations:
raise GraphQLError(
f"Operation '{operation_name}' can't be executed using the GET "
"HTTP method. Use POST instead."
)
elif len(query_operations) != 1:
raise GraphQLError(
f"'operationName' is required if 'query' defines multiple operations."
)

0 comments on commit 1ba4295

Please sign in to comment.